| # API Reference — composer-replication-framework | |
| Complete reference for every public symbol in `composer_replication`. Source-of-truth is the `.py` files in `composer_replication/`; docstrings have been pulled verbatim where they exist and supplemented where missing. | |
| **Legend** | |
| - ⚠️ **UNTESTED-CONTRACT** — symbol exists and is callable, but its behaviour is not pinned by an automated test in `composer_replication/**/tests/` or `spikes/**/tests/`. | |
| - 🟡 **SKELETON** — class/method body raises `NotImplementedError`; ships as design-of-record per ADR-005 / ADR-006. | |
| **Module groups (in this document)** | |
| 1. `composer_replication` (top-level re-exports) | |
| 2. `composer_replication.loss` | |
| 3. `composer_replication.batch` | |
| 4. `composer_replication.opsd` | |
| 5. `composer_replication.distillation` | |
| 6. `composer_replication.teacher_replay` | |
| 7. `composer_replication.replaysim` | |
| 8. `composer_replication.ingestion` (+ `.claude_code`) | |
| 9. `composer_replication.hint_generator` | |
| 10. `composer_replication.trainer` (+ `.composer_trainer`, `.data_collator`) | |
| 11. `composer_replication.diloco` | |
| 12. `composer_replication.diloco.serverless` (+ `.executor`, `.allreduce`, `.modal`, `.hf_jobs`, `.replica_entrypoint`) | |
| 13. `composer_replication.recipes.prime_rl.composer_loss` | |
| 14. `composer_replication.recipes.monarch.actors` | |
| --- | |
| ## 1. `composer_replication` — top-level package | |
| The package re-exports the most common entry points from sub-modules. `__all__` is the canonical list of public top-level names. | |
| ### `composer_replication.__version__: str` | |
| Package version string. Currently `"0.1.0"`. | |
| ```python | |
| import composer_replication | |
| print(composer_replication.__version__) # "0.1.0" | |
| ``` | |
| ### `composer_replication._DILOCO_AVAILABLE: bool` | |
| `True` iff `torchft` is importable in the running Python environment; it gates `make_diloco_outer_loop`. When `torchft` is missing, the flag is `False` and `make_diloco_outer_loop` is `None`. | |
| ```python | |
| from composer_replication import _DILOCO_AVAILABLE | |
| if _DILOCO_AVAILABLE: | |
|     from composer_replication import make_diloco_outer_loop | |
| ``` | |
| ### Re-exports | |
| | Name | Source module | | |
| |---|---| | |
| | `compose_loss` | `composer_replication.loss` | | |
| | `LossComponents` | `composer_replication.loss` | | |
| | `build_batch` | `composer_replication.batch` | | |
| | `generalized_jsd_loss` | `composer_replication.opsd` | | |
| | `ClaudeCodeIngester` | `composer_replication.ingestion.claude_code` | | |
| | `IngestionStats` | `composer_replication.ingestion.claude_code` | | |
| | `SYSTEM_PROMPT` | `composer_replication.ingestion.claude_code` | | |
| | `DEFAULT_TEACHERS` | `composer_replication.teacher_replay` | | |
| | `DPOPair` | `composer_replication.teacher_replay` | | |
| | `TeacherCallResult` | `composer_replication.teacher_replay` | | |
| | `TeacherSpec` | `composer_replication.teacher_replay` | | |
| | `TraceState` | `composer_replication.teacher_replay` | | |
| | `extract_dpo_pairs` | `composer_replication.teacher_replay` | | |
| | `replay_trace` | `composer_replication.teacher_replay` | | |
| | `ComposerReplicationTrainer` | `composer_replication.trainer` | | |
| | `make_diloco_outer_loop` | `composer_replication.diloco` (or `None` if `torchft` missing) | | |
| See each source module below for full signatures. | |
| --- | |
| ## 2. `composer_replication.loss` | |
| Verification-harness 3-channel loss. Free function, does not depend on `trl`. | |
| ### `class LossComponents` | |
| ```python | |
| @dataclass | |
| class LossComponents: | |
|     lm_ce: torch.Tensor | |
|     sdpo_jsd: torch.Tensor | |
|     trace_replay_dpo: torch.Tensor | |
|     total: torch.Tensor | |
|     def detached(self) -> dict[str, float]: ... | |
| ``` | |
| Per-channel breakdown of the total loss for logging and ablation. All four fields are scalar `torch.Tensor`s (`shape=()`); `total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo`. | |
| **`detached() -> dict[str, float]`** — returns Python-float copies of all four fields with no grad. Useful for W&B logging. | |
| ```python | |
| from composer_replication import compose_loss, build_batch | |
| components = compose_loss(model, build_batch(tokenizer)) | |
| print(components.detached()) # {'lm_ce': 2.34, 'sdpo_jsd': 0.12, ...} | |
| components.total.backward() | |
| ``` | |
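The weighting described above can be checked by hand. The following is a minimal float-only sketch, not the library's implementation: `ChannelValues` and `combine` are hypothetical names, the weights are the `compose_loss` defaults, and the channel values are illustrative.

```python
from dataclasses import dataclass


@dataclass
class ChannelValues:
    """Hypothetical float-only stand-in for LossComponents (no tensors, no grad)."""
    lm_ce: float
    sdpo_jsd: float
    trace_replay_dpo: float


def combine(ch: ChannelValues, alpha_sdpo: float = 0.1, beta_replay: float = 0.05) -> float:
    # total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo
    return ch.lm_ce + alpha_sdpo * ch.sdpo_jsd + beta_replay * ch.trace_replay_dpo


# Illustrative values only; they are not measurements from any run.
values = ChannelValues(lm_ce=2.34, sdpo_jsd=0.12, trace_replay_dpo=0.70)
total = combine(values)                                     # 2.34 + 0.012 + 0.035
lm_only = combine(values, alpha_sdpo=0.0, beta_replay=0.0)  # both extra channels weighted out
```

Setting a weight to `0.0` removes that channel's contribution, which is how a per-channel ablation would be expressed in this arithmetic.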
| ### `compose_loss(model, inputs, *, ...) -> LossComponents` | |
| <a id="compose_loss"></a> | |
| ```python | |
| def compose_loss( | |
| model: torch.nn.Module, | |
| inputs: dict[str, torch.Tensor], | |
| *, | |
| alpha_sdpo: float = 0.1, | |
| beta_replay: float = 0.05, | |
| sdpo_jsd_beta: float = 0.5, | |
| sdpo_temperature: float = 1.0, | |
| sdpo_token_clip: float | None = None, | |
| replay_dpo_beta: float = 0.1, | |
| lm_ce_label_smoothing: float = 0.0, | |
| dpo_variant: Literal["dpo", "simpo"] = "dpo", | |
| sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none", | |
| taid_t: float | None = None, | |
| simpo_beta: float = 2.0, | |
| simpo_gamma: float = 1.0, | |
| entropy_opd_h_max: float | None = None, | |
| ) -> LossComponents | |
| ``` | |
| Compute `total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo`. | |
| **Required keys in `inputs`** | |
| - `input_ids`: `(B, T_s)` student rollout token ids. | |
| - `response_mask`: `(B, T_s)` 1 on assistant-response tokens, 0 elsewhere. | |
| **Optional keys** (channel auto-disables if missing OR if its weight = 0): | |
| - SDPO: `ctx_teacher_input_ids` `(B, T_t)`, `sdpo_loss_mask` `(B, T_t)`. | |
| - DPO (`dpo_variant="dpo"`): `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`, `dpo_chosen_ref_logprobs`, `dpo_rejected_ref_logprobs` (precomputed). | |
| - SimPO (`dpo_variant="simpo"`): same DPO ids/masks; reference logprobs are silently ignored. | |
| - TAID (`sdpo_wrapper="taid"`): no extra `inputs` keys needed; the optional `sdpo_loss_mask` is reused as the per-token TAID mask. Pass `taid_t` directly (or drive it from `TAIDScheduler`). | |
| **Parameters** | |
| | Name | Type | Default | Meaning | | |
| |---|---|---|---| | |
| | `model` | `torch.nn.Module` | — | HF causal-LM. Must accept `input_ids=` and return an object with `.logits`. | | |
| | `inputs` | `dict[str, torch.Tensor]` | — | Batch dict (see required/optional keys above). | | |
| | `alpha_sdpo` | `float` | `0.1` | Weight on SDPO/JSD channel. `0.0` disables. | | |
| | `beta_replay` | `float` | `0.05` | Weight on trace-replay DPO channel. `0.0` disables. | | |
| | `sdpo_jsd_beta` | `float` | `0.5` | β param for `generalized_jsd_loss` (0=fwd KL, 0.5=JSD, 1=rev KL). Unused when `sdpo_wrapper="taid"`. | | |
| | `sdpo_temperature` | `float` | `1.0` | Softmax temperature in SDPO. Unused when `sdpo_wrapper="taid"`. | | |
| | `sdpo_token_clip` | `float \| None` | `None` | Per-token JSD clamp. | | |
| | `replay_dpo_beta` | `float` | `0.1` | β in standard DPO logit. | | |
| | `lm_ce_label_smoothing` | `float` | `0.0` | `F.cross_entropy(label_smoothing=)`. | | |
| | `dpo_variant` | `Literal["dpo","simpo"]` | `"dpo"` | Channel-3 algorithm. | | |
| | `sdpo_wrapper` | `Literal["none","taid","entropy_opd"]` | `"none"` | Channel-2 wrapper. | | |
| | `taid_t` | `float \| None` | `None` | Current TAID interpolation coefficient in `[0, 1]`. Required when `sdpo_wrapper="taid"`. Drive from `TAIDScheduler` or pass a fixed value. | | |
| | `simpo_beta` | `float` | `2.0` | SimPO β (paper default). | | |
| | `simpo_gamma` | `float` | `1.0` | SimPO target margin γ (paper default). | | |
| | `entropy_opd_h_max` | `float \| None` | `None` | Max-entropy normalizer; `None` ⇒ `log(V)`. | | |
| **Returns** `LossComponents` (see above). | |
| **Raises** `ValueError` if `dpo_variant` or `sdpo_wrapper` is unknown, if `sdpo_wrapper="taid"` is requested without `taid_t`, or if `taid_t` is outside `[0, 1]`. | |
| ```python | |
| from composer_replication import compose_loss, build_batch | |
| batch = build_batch(tokenizer) | |
| out = compose_loss(model, batch, alpha_sdpo=0.1, beta_replay=0.05) | |
| out.total.backward() | |
| print(out.detached()) | |
| ``` | |
| --- | |
| ## 3. `composer_replication.batch` | |
| Verification-harness batch builder. | |
| ### `build_batch(tokenizer, *, ...) -> dict[str, torch.Tensor]` | |
| ```python | |
| def build_batch( | |
| tokenizer: Any, | |
| *, | |
| device: torch.device | str = "cpu", | |
| seed: int = 42, | |
| variant: str = "factorial", | |
| align_sdpo_shapes: bool = False, | |
| ) -> dict[str, torch.Tensor] | |
| ``` | |
| Construct a full 3-channel batch from a real HF tokenizer. The DPO ref-logprobs are dummy tensors: the smoke test verifies that the loss composition wires together, not the reference-policy precompute. | |
| **Returned keys**: `input_ids`, `response_mask`, `ctx_teacher_input_ids`, `sdpo_loss_mask`, `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`, `dpo_chosen_ref_logprobs`, `dpo_rejected_ref_logprobs`. | |
| **Parameters** | |
| | Name | Type | Default | Meaning | | |
| |---|---|---|---| | |
| | `tokenizer` | HF `AutoTokenizer` (duck-typed) | — | Must support `apply_chat_template` and `__call__`. | | |
| | `device` | `torch.device \| str` | `"cpu"` | Target device for all returned tensors. | | |
| | `seed` | `int` | `42` | Fixes `torch.manual_seed`. | | |
| | `variant` | `str` | `"factorial"` | One of `"factorial"`, `"binary_search"`. | | |
| | `align_sdpo_shapes` | `bool` | `False` | If True, truncate/pad `ctx_teacher_input_ids` to `input_ids` length so the SDPO channel actually fires. | | |
| **Raises** `ValueError` if `variant` is unknown. | |
| ```python | |
| from transformers import AutoTokenizer | |
| from composer_replication import build_batch | |
| tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") | |
| batch = build_batch(tok, variant="factorial", align_sdpo_shapes=True) | |
| print({k: v.shape for k, v in batch.items()}) | |
| ``` | |
| --- | |
| ## 4. `composer_replication.opsd` | |
| Self-distillation generalized-JSD loss, lifted verbatim from `siyan-zhao/OPSD` (MIT) per ADR-006. | |
| ### `generalized_jsd_loss(student_logits, teacher_logits, labels=None, beta=0.5, ...) -> torch.Tensor` | |
| ```python | |
| def generalized_jsd_loss( | |
| student_logits: torch.Tensor, | |
| teacher_logits: torch.Tensor, | |
| labels: torch.Tensor | None = None, | |
| beta: float = 0.5, | |
| temperature: float = 1.0, | |
| reduction: str = "batchmean", | |
| logits_are_probs: bool = False, | |
| top_k: int | None = None, | |
| token_clip: float | None = None, | |
| ) -> torch.Tensor | |
| ``` | |
| Generalized JSD between student and teacher distributions. In the SDPO recipe the student and teacher are the same model run on different contexts, so both sets of logits come from the same parameters. | |
| **Parameters** | |
| | Name | Type | Default | Meaning | | |
| |---|---|---|---| | |
| | `student_logits` | `Tensor (B, T, V)` | — | Student logits with grad. | | |
| | `teacher_logits` | `Tensor (B, T, V)` | — | Teacher logits (no grad in SDPO). | | |
| | `labels` | `Tensor (B, T) \| None` | `None` | Per-token mask. `-100` positions are ignored (HF convention). | | |
| | `beta` | `float` in [0, 1] | `0.5` | 0=fwd KL, 1=rev KL, 0.5=symmetric JSD. | | |
| | `temperature` | `float` | `1.0` | Softmax temperature. | | |
| | `reduction` | `str` | `"batchmean"` | `"batchmean"`, `"sum"`, `"mean"`, `"none"`. | | |
| | `logits_are_probs` | `bool` | `False` | Skip softmax if inputs are already probabilities. | | |
| | `top_k` | `int \| None` | `None` | Restrict KL to teacher's top-k tokens. | | |
| | `token_clip` | `float \| None` | `None` | Clip per-token JSD for stability. | | |
| **Returns** scalar tensor (or `(B, T)` if `reduction="none"`). | |
| **Raises** `ValueError` for unknown `reduction`. | |
| ```python | |
| import torch | |
| from composer_replication.opsd import generalized_jsd_loss | |
| s = torch.randn(2, 8, 32, requires_grad=True) | |
| t = torch.randn(2, 8, 32) | |
| loss = generalized_jsd_loss(s, t, beta=0.5, reduction="batchmean") | |
| loss.backward() | |
| ``` | |
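To make the `beta` convention concrete, here is a single-token sketch in plain Python. It assumes one common formulation of the generalized JSD (mixture `beta * teacher + (1 - beta) * student`, with the two endpoints handled as pure KL terms); the source does not spell out the exact interpolation used by `generalized_jsd_loss`, so treat the interior formula as an assumption. All function names here are hypothetical.

```python
import math


def softmax(logits):
    # Numerically stable softmax over a plain list of floats.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def kl(p, q):
    # KL(p || q); assumes q is strictly positive wherever p is.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def generalized_jsd_token(student_logits, teacher_logits, beta=0.5):
    s, t = softmax(student_logits), softmax(teacher_logits)
    if beta == 0.0:
        return kl(t, s)  # forward KL: teacher || student
    if beta == 1.0:
        return kl(s, t)  # reverse KL: student || teacher
    # Assumed interior form: weighted KLs against the beta-mixture.
    mix = [beta * ti + (1.0 - beta) * si for si, ti in zip(s, t)]
    return beta * kl(t, mix) + (1.0 - beta) * kl(s, mix)


student = [2.0, 0.5, -1.0]
teacher = [0.0, 3.0, 0.5]
jsd_sym = generalized_jsd_token(student, teacher, beta=0.5)
fwd = generalized_jsd_token(student, teacher, beta=0.0)
rev = generalized_jsd_token(student, teacher, beta=1.0)
```

At `beta=0.5` the value is symmetric in its two arguments and bounded by `log(2)`, which is one reason a per-token clip is easier to reason about for the JSD than for either KL endpoint.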
| --- | |
| ## 5. `composer_replication.distillation` | |
| Pluggable self-distillation losses (ADR-007). All pure PyTorch. | |
| ### `simpo_loss(chosen_avg_logprobs, rejected_avg_logprobs, *, beta=2.0, gamma=1.0) -> torch.Tensor` | |
| ```python | |
| def simpo_loss( | |
| chosen_avg_logprobs: torch.Tensor, | |
| rejected_avg_logprobs: torch.Tensor, | |
| *, | |
| beta: float = 2.0, | |
| gamma: float = 1.0, | |
| ) -> torch.Tensor | |
| ``` | |
| Reference-free DPO with target margin γ (Meng et al., NeurIPS 2024). `L = -log σ(β · (avg_logπ(c) − avg_logπ(r)) − γ)`. | |
| **Parameters** | |
| | Name | Type | Default | Meaning | | |
| |---|---|---|---| | |
| | `chosen_avg_logprobs` | `Tensor (B,)` | — | Per-sequence avg logprob over chosen response tokens. | | |
| | `rejected_avg_logprobs` | `Tensor (B,)` | — | Same for rejected. | | |
| | `beta` | `float` | `2.0` | Scaling factor (paper default). | | |
| | `gamma` | `float` | `1.0` | Target margin (paper default). | | |
| **Returns** scalar; **Raises** `ValueError` if shapes mismatch. | |
| ```python | |
| import torch | |
| from composer_replication.distillation import simpo_loss | |
| loss = simpo_loss(torch.tensor([-2.1, -1.8]), torch.tensor([-3.0, -2.5]), | |
| beta=2.0, gamma=1.0) | |
| ``` | |
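The formula above can be evaluated by hand for the same inputs. This is a plain-Python sketch of the per-pair term only; `simpo_per_pair` is a hypothetical helper, and reducing the batch by a mean is an assumption, since the source only says the function returns a scalar.

```python
import math


def simpo_per_pair(chosen_avg, rejected_avg, beta=2.0, gamma=1.0):
    # margin = beta * (avg_logp(chosen) - avg_logp(rejected)) - gamma
    margin = beta * (chosen_avg - rejected_avg) - gamma
    # -log(sigmoid(x)) == softplus(-x), written in a numerically stable form.
    if margin >= 0.0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


chosen = [-2.1, -1.8]
rejected = [-3.0, -2.5]
per_pair = [simpo_per_pair(c, r) for c, r in zip(chosen, rejected)]
batch_mean = sum(per_pair) / len(per_pair)  # assumed reduction: mean over pairs
```

With the defaults, the first pair has margin `2.0 * 0.9 - 1.0 = 0.8` and the second `2.0 * 0.7 - 1.0 = 0.4`, so the second pair contributes the larger loss.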
| ### `avg_sequence_logprob(model_logprobs, response_mask) -> torch.Tensor` | |
| ⚠️ UNTESTED-CONTRACT (helper exported from `simpo.py` but not asserted by a test). | |
| ```python | |
| def avg_sequence_logprob( | |
| model_logprobs: torch.Tensor, | |
| response_mask: torch.Tensor, | |
| ) -> torch.Tensor | |
| ``` | |
| Convert `(B, T)` per-token logprobs + `(B, T)` response mask into `(B,)` per-sequence average over response tokens. | |
| ```python | |
| from composer_replication.distillation.simpo import avg_sequence_logprob | |
| import torch | |
| lp = torch.randn(2, 8); m = torch.tensor([[0,0,1,1,1,0,0,0],[0,1,1,1,1,1,0,0]]) | |
| out = avg_sequence_logprob(lp, m) # shape (2,) | |
| ``` | |
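The reduction described above is a masked mean along the token axis. A list-based sketch follows; `avg_sequence_logprob_sketch` is a hypothetical name, and guarding an all-zero mask by dividing by at least 1 is an assumption, not documented behaviour.

```python
def avg_sequence_logprob_sketch(logprobs, mask):
    """Masked per-sequence mean: (B, T) logprobs and (B, T) 0/1 mask -> (B,) averages."""
    out = []
    for row_lp, row_m in zip(logprobs, mask):
        n_response = sum(row_m)
        total = sum(lp * m for lp, m in zip(row_lp, row_m))
        # Assumption: an empty mask yields 0.0 instead of a division by zero.
        out.append(total / max(n_response, 1))
    return out


logprobs = [[-0.5, -1.0, -2.0, -4.0],
            [-1.0, -1.0, -3.0, -5.0]]
mask = [[0, 1, 1, 0],
        [0, 0, 1, 1]]
averages = avg_sequence_logprob_sketch(logprobs, mask)
```

Only positions with mask value 1 enter the average, so prompt and padding tokens do not dilute the per-sequence score that SimPO consumes.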
| ### `taid_loss(student_logits, teacher_logits, mask=None, *, t) -> torch.Tensor` | |
| ```python | |
| def taid_loss( | |
| student_logits: torch.Tensor, | |
| teacher_logits: torch.Tensor, | |
| mask: torch.Tensor | None = None, | |
| *, | |
| t: float | torch.Tensor, | |
| ) -> torch.Tensor | |
| ``` | |
| Faithful port of `SakanaAI/TAID` (arXiv:2501.16937). Forward-KL distillation against a logit-space-interpolated target whose anchor is the **current student detached**: | |
| ``` | |
| p_t = softmax( (1 - t) · stop_grad(student_logits) + t · teacher_logits ) | |
| L = - mean_token Σ_v p_t(v) · log_softmax(student_logits)(v) | |
| ``` | |
| At `t=0` the target collapses to the detached student (no teacher signal in the gradient). At `t=1` it reduces to standard forward-KL distillation against the teacher. | |
| **Wave 15 breaking change.** The previous signature `taid_loss(student, teacher, student_init, *, schedule_step, total_steps, schedule, alpha_min, alpha_max, jsd_beta, temperature, reduction)` was algorithmically wrong (probability-space mix, frozen step-0 anchor, JSD criterion). All those kwargs are removed; the schedule is now the caller's responsibility (see `TAIDScheduler` below for the upstream adaptive scheme). | |
| **Parameters** | |
| | Name | Type | Default | Meaning | | |
| |---|---|---|---| | |
| | `student_logits` | `Tensor (B, T, V)` | — | Current student (with grad). | | |
| | `teacher_logits` | `Tensor (B, T, V)` | — | Teacher logits. | | |
| | `mask` | `Tensor (B, T) \| None` | `None` | Token mask. `None` ⇒ all-ones. | | |
| | `t` | `float \| Tensor` | — | Interpolation coefficient in `[0, 1]`. | | |
| **Raises** `ValueError` for shape mismatch. | |
| ```python | |
| from composer_replication.distillation import taid_loss | |
| loss = taid_loss(s_logits, t_logits, mask, t=0.4) | |
| ``` | |
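The two limiting cases of `t` can be verified for a single token with plain Python. This sketch follows the formula above; because there is no autograd here, the stop-gradient on the student anchor is implicit (the target is simply computed as a constant). Function names are hypothetical.

```python
import math


def log_softmax(logits):
    # Numerically stable log-softmax over a plain list of floats.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def taid_token_loss(student_logits, teacher_logits, t):
    # Logit-space interpolation between the (detached) student and the teacher.
    mixed = [(1.0 - t) * s + t * te for s, te in zip(student_logits, teacher_logits)]
    target = [math.exp(lp) for lp in log_softmax(mixed)]  # p_t, treated as a constant
    student_logp = log_softmax(student_logits)
    # Cross-entropy of the student against the interpolated target.
    return -sum(p * lp for p, lp in zip(target, student_logp))


student = [2.0, 0.5, -1.0]
teacher = [0.0, 3.0, 0.5]
loss_t0 = taid_token_loss(student, teacher, t=0.0)  # target == student distribution
loss_t1 = taid_token_loss(student, teacher, t=1.0)  # target == teacher distribution
```

At `t=0` the value equals the student's own entropy, and at `t=1` it is the cross-entropy against the teacher, which differs from the forward KL only by the teacher's entropy, a term that carries no gradient.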
| ### `TAIDScheduler(num_train_steps, *, t_start=0.4, t_end=1.0, alpha=5e-4, beta=0.99, disable_adaptive=False, device="cpu")` | |
| Stateful schedule that mirrors upstream `TAID.update_t`. Monotone non-decreasing, bumped above the linear floor by an EMA on the relative loss change. Use as: | |
| ```python | |
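| from composer_replication.distillation import TAIDScheduler, taid_loss | |
| # s, t, mask, optimizer: student logits, teacher logits, token mask, and optimizer from your training loop | |
| num_train_steps = 10_000 | |
| sched = TAIDScheduler(num_train_steps=num_train_steps) # paper defaults | |
| for step in range(num_train_steps): | |
|     loss = taid_loss(s, t, mask, t=sched.t) | |
|     loss.backward(); optimizer.step(); optimizer.zero_grad() | |
|     sched.update_t(loss.detach(), global_step=step) | |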
| ``` | |
| **Parameters** | |
| | Name | Type | Default | Meaning | | |
| |---|---|---|---| | |
| | `num_train_steps` | `int` | — | Total planned training steps; sets the linear floor. | | |
| | `t_start` | `float` | `0.4` | Initial `t` (paper default). | | |
| | `t_end` | `float` | `1.0` | Terminal `t`; hard ceiling at every step. | | |
| | `alpha` | `float` | `5e-4` | Adaptive bump magnitude. | | |
| | `beta` | `float` | `0.99` | EMA decay on relative-loss-change momentum. | | |
| | `disable_adaptive` | `bool` | `False` | If True, fall back to deterministic linear schedule. | | |
| | `device` | `torch.device \| str` | `"cpu"` | Where to allocate state buffers. | | |
| **Properties / methods** | |
| - `sched.t -> float` — current `t` as a Python float (zero-arg property). | |
| - `sched.update_t(loss, global_step) -> Tensor | None` — update internal state. First finite-loss call only seeds `prev_loss` and returns `None`; subsequent calls return the (positive) `delta_t` added on top of the linear floor. | |
| ### `entropy_aware_opd_loss(student_logits, teacher_logits, *, labels=None, h_max=None, temperature=1.0, reduction="batchmean") -> torch.Tensor` | |
| ```python | |
| def entropy_aware_opd_loss( | |
| student_logits: torch.Tensor, | |
| teacher_logits: torch.Tensor, | |
| *, | |
| labels: torch.Tensor | None = None, | |
| h_max: float | None = None, | |
| temperature: float = 1.0, | |
| reduction: str = "batchmean", | |
| ) -> torch.Tensor | |
| ``` | |
| Per-token mixture of forward and reverse KL gated by teacher entropy: `w(t) = clamp(H_teacher(t)/h_max, 0, 1)`. High-entropy tokens use forward KL (mode-covering), low-entropy tokens use reverse KL (mode-seeking). | |
| **Parameters** | |
| | Name | Type | Default | Meaning | | |
| |---|---|---|---| | |
| | `student_logits` | `Tensor (B,T,V)` | — | Student logits (grad). | | |
| | `teacher_logits` | `Tensor (B,T,V)` | — | Teacher logits (no grad). | | |
| | `labels` | `Tensor (B,T) \| None` | `None` | 0/1 mask, applied multiplicatively after the per-token mix. | | |
| | `h_max` | `float \| None` | `None` ⇒ `log(V)` | Max-entropy normalizer. | | |
| | `temperature` | `float` | `1.0` | Softmax temperature on both. | | |
| | `reduction` | `str` | `"batchmean"` | `"batchmean"`, `"sum"`, `"mean"`, `"none"`. | | |
| **Raises** `ValueError` on shape mismatch (student vs teacher; labels vs per-token loss) or unknown `reduction`. | |
| ```python | |
| from composer_replication.distillation import entropy_aware_opd_loss | |
| loss = entropy_aware_opd_loss(s_logits, t_logits, temperature=1.0) | |
| loss.backward() | |
| ``` | |
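The entropy gate can be illustrated for one token in plain Python. This sketch assumes forward KL means `KL(teacher || student)` and reverse KL means `KL(student || teacher)`, and that the mix is `w * forward + (1 - w) * reverse`; the source states the gating rule but not these exact expressions. Names are hypothetical.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def kl(p, q):
    # KL(p || q); assumes q is strictly positive wherever p is.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def entropy_aware_token_loss(student_logits, teacher_logits, h_max=None):
    s, t = softmax(student_logits), softmax(teacher_logits)
    h_teacher = -sum(p * math.log(p) for p in t if p > 0.0)  # entropy in nats
    if h_max is None:
        h_max = math.log(len(t))  # entropy of the uniform distribution over V tokens
    w = min(max(h_teacher / h_max, 0.0), 1.0)
    # Assumed mix: high teacher entropy leans on forward KL, low entropy on reverse KL.
    loss = w * kl(t, s) + (1.0 - w) * kl(s, t)
    return loss, w


student = [1.0, 0.2, -0.3, 0.0]
loss_uniform, w_uniform = entropy_aware_token_loss(student, [0.0, 0.0, 0.0, 0.0])
loss_peaked, w_peaked = entropy_aware_token_loss(student, [10.0, 0.0, 0.0, 0.0])
```

A uniform teacher drives the weight to 1 (pure forward KL), while a sharply peaked teacher drives it toward 0 (almost pure reverse KL).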
| ### `teacher_entropy(teacher_logits) -> torch.Tensor` | |
| ⚠️ UNTESTED-CONTRACT (helper exposed from `entropy_aware_opd.py`'s `__all__` but not directly asserted). | |
| Per-token entropy in nats. Input `(B,T,V)`, output `(B,T)`. | |
| ```python | |
| from composer_replication.distillation.entropy_aware_opd import teacher_entropy | |
| H = teacher_entropy(teacher_logits) # (B, T) | |
| ``` | |
| --- | |
| ## 6. `composer_replication.teacher_replay` | |
| N-teacher OpenRouter parallel client + DPO-pair extractor. `httpx` is lazy-imported inside `replay_trace`; the deterministic local logic is testable without it. | |
| ### `DEFAULT_TEACHERS: list[TeacherSpec]` | |
| Three-teacher default set: `anthropic/claude-opus-4.7`, `openai/gpt-5`, `deepseek/deepseek-v4-pro` with paper-baseline OpenRouter pricing. | |
| ```python | |
| from composer_replication.teacher_replay import DEFAULT_TEACHERS | |
| print([t["slug"] for t in DEFAULT_TEACHERS]) | |
| ``` | |
| ### `class TeacherSpec(TypedDict)` | |
| ```python | |
| class TeacherSpec(TypedDict): | |
|     slug: str | |
|     input_per_mtok: float | |
|     output_per_mtok: float | |
| ``` | |
| OpenRouter model slug + per-million-token pricing. | |
| ```python | |
| spec: TeacherSpec = {"slug": "openai/gpt-5", | |
| "input_per_mtok": 1.25, "output_per_mtok": 10.0} | |
| ``` | |
| ### `class TraceState(TypedDict)` | |
| ```python | |
| class TraceState(TypedDict): | |
|     state_id: str # unique within the trace | |
|     messages: list[dict] # OpenAI-style chat history up to (and incl.) this user prompt | |
|     student_action: str # what the student actually did at this step | |
| ``` | |
| One step of a frozen agentic trace. `student_action` is the raw text emitted by the student; teachers are queried with `messages` and asked to predict the assistant's next action. | |
| ```python | |
| state: TraceState = {"state_id": "ex001::0042", | |
| "messages": [{"role": "user", "content": "..."}], | |
| "student_action": "[TOOL_USE] name=Read input={...}"} | |
| ``` | |
| ### `class TeacherCallResult(TypedDict)` | |
| ```python | |
| class TeacherCallResult(TypedDict): | |
|     state_id: str | |
|     teacher_slug: str | |
|     response_text: str | None # None on error | |
|     latency_s: float | |
|     prompt_tokens: int | |
|     completion_tokens: int | |
|     cost_usd: float | |
|     error: str | None # None on success | |
| ``` | |
| One row of N×T results from `replay_trace`. | |
| ```python | |
| r: TeacherCallResult = {"state_id": "x", "teacher_slug": "openai/gpt-5", | |
| "response_text": "ok", "latency_s": 1.2, "prompt_tokens": 100, | |
| "completion_tokens": 5, "cost_usd": 0.001, "error": None} | |
| ``` | |
| ### `class DPOPair(TypedDict)` | |
| ```python | |
| class DPOPair(TypedDict): | |
|     state_id: str | |
|     state_messages: list[dict] | |
|     chosen: str # teacher-consensus action | |
|     rejected: str # student action | |
|     n_teachers_agreeing: int | |
| ``` | |
| One preference pair extracted from teacher-vs-student disagreement. | |
| ```python | |
| p: DPOPair = {"state_id": "x", "state_messages": [...], "chosen": "...", | |
| "rejected": "...", "n_teachers_agreeing": 2} | |
| ``` | |
| ### `async replay_trace(states, teachers=DEFAULT_TEACHERS, max_total_usd=5.0, api_key=None) -> list[TeacherCallResult]` | |
| ```python | |
| async def replay_trace( | |
| states: Sequence[TraceState], | |
| teachers: Sequence[TeacherSpec] = tuple(DEFAULT_TEACHERS), | |
| max_total_usd: float = 5.0, | |
| api_key: str | None = None, | |
| ) -> list[TeacherCallResult] | |
| ``` | |
| For each state, fan out one call per teacher in parallel via OpenRouter. Cumulative spend is hard-capped at `max_total_usd`: the state that crosses the cap is allowed to complete, and no further states are queried. | |
| **Parameters** | |
| | Name | Type | Default | Meaning | | |
| |---|---|---|---| | |
| | `states` | `Sequence[TraceState]` | — | Frozen trace, one entry per assistant turn. | | |
| | `teachers` | `Sequence[TeacherSpec]` | `DEFAULT_TEACHERS` | Models to query in parallel. | | |
| | `max_total_usd` | `float` | `5.0` | Cumulative spend cap. | | |
| | `api_key` | `str \| None` | `None` | OpenRouter key; defaults to `OPENROUTER_API_KEY` env or `~/.hermes/.env`. | | |
| **Returns** flat list of `TeacherCallResult`s (length `len(states) * len(teachers)` modulo budget cutoff). | |
| **Raises** `RuntimeError` if no OpenRouter API key can be found; `ImportError` if `httpx` is missing at call time. | |
| ```python | |
| import asyncio | |
| from composer_replication import replay_trace | |
| results = asyncio.run(replay_trace(states=my_trace, max_total_usd=1.0)) | |
| ``` | |
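The pricing fields and the spend cap interact as simple arithmetic. The sketch below is not the library's code: `call_cost_usd` and `states_processed_under_budget` are hypothetical helpers, the per-call cost formula is inferred from the per-million-token fields of `TeacherSpec`, and the `>=` comparison at the cap is an assumption.

```python
def call_cost_usd(prompt_tokens, completion_tokens, input_per_mtok, output_per_mtok):
    # Prices are quoted per million tokens, so scale token counts by 1e-6.
    return (prompt_tokens * input_per_mtok + completion_tokens * output_per_mtok) / 1_000_000


def states_processed_under_budget(per_state_costs, max_total_usd):
    """Count states processed when the state that crosses the cap still completes."""
    spent = 0.0
    processed = 0
    for cost in per_state_costs:
        spent += cost          # all teacher calls for this state finish first
        processed += 1
        if spent >= max_total_usd:
            break              # assumed check: stop before starting the next state
    return processed, spent


one_call = call_cost_usd(100, 5, input_per_mtok=1.25, output_per_mtok=10.0)
processed, spent = states_processed_under_budget([0.4, 0.4, 0.4, 0.4], max_total_usd=1.0)
```

Because the offending state completes, actual spend can exceed `max_total_usd` by up to one state's worth of teacher calls, which is worth budgeting for when traces have long histories.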
| ### `extract_dpo_pairs(states, teacher_actions, agreement_threshold=2) -> list[DPOPair]` | |
| ```python | |
| def extract_dpo_pairs( | |
| states: Sequence[TraceState], | |
| teacher_actions: Sequence[TeacherCallResult], | |
| agreement_threshold: int = 2, | |
| ) -> list[DPOPair] | |
| ``` | |
| Group `teacher_actions` by `state_id`, normalize whitespace, and emit one `DPOPair` per state where at least `agreement_threshold` teachers agreed on an action that differs from the student's. `chosen` is the original (un-normalized) teacher response text. | |
| **Parameters** | |
| | Name | Type | Default | Meaning | | |
| |---|---|---|---| | |
| | `states` | `Sequence[TraceState]` | — | Same as passed to `replay_trace`. | | |
| | `teacher_actions` | `Sequence[TeacherCallResult]` | — | Output of `replay_trace`. | | |
| | `agreement_threshold` | `int` | `2` | Min teachers that must agree for a pair to fire. | | |
| **Returns** list of `DPOPair`. At most one pair per state (the most-agreed-upon action wins). | |
| ```python | |
| from composer_replication import extract_dpo_pairs | |
| pairs = extract_dpo_pairs(my_states, results, agreement_threshold=2) | |
| ``` | |
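The grouping and voting logic described above can be sketched with the standard library. This is an illustration under assumptions, not the library's implementation: `extract_pairs_sketch` is a hypothetical name, whitespace normalization is taken to mean collapsing runs of whitespace, errored results (`response_text is None`) are skipped, and only the single most common action per state is considered.

```python
from collections import Counter, defaultdict


def normalize_ws(text):
    # Assumed normalization: collapse whitespace runs and strip the ends.
    return " ".join(text.split())


def extract_pairs_sketch(states, teacher_actions, agreement_threshold=2):
    by_state = defaultdict(list)
    for result in teacher_actions:
        if result.get("response_text") is not None:  # skip errored calls
            by_state[result["state_id"]].append(result["response_text"])

    pairs = []
    for state in states:
        responses = by_state.get(state["state_id"], [])
        if not responses:
            continue
        winner, votes = Counter(normalize_ws(r) for r in responses).most_common(1)[0]
        if votes < agreement_threshold or winner == normalize_ws(state["student_action"]):
            continue
        # Keep the original, un-normalized text of the first matching response.
        chosen = next(r for r in responses if normalize_ws(r) == winner)
        pairs.append({"state_id": state["state_id"], "state_messages": state["messages"],
                      "chosen": chosen, "rejected": state["student_action"],
                      "n_teachers_agreeing": votes})
    return pairs


demo_states = [
    {"state_id": "ex001::0000", "messages": [], "student_action": "ls"},
    {"state_id": "ex001::0001", "messages": [], "student_action": "cat a.py"},
]
demo_results = [
    {"state_id": "ex001::0000", "response_text": "cat  README.md"},
    {"state_id": "ex001::0000", "response_text": "cat README.md\n"},
    {"state_id": "ex001::0000", "response_text": "ls"},
    {"state_id": "ex001::0001", "response_text": "cat a.py"},
    {"state_id": "ex001::0001", "response_text": "cat a.py"},
    {"state_id": "ex001::0001", "response_text": None},
]
demo_pairs = extract_pairs_sketch(demo_states, demo_results)
```

In the demo, the first state yields a pair because two teachers agree on an action the student did not take; the second yields none because the consensus matches the student.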
| ### `save_pairs(pairs, path) -> None` | |
| ⚠️ UNTESTED-CONTRACT. | |
| ```python | |
| def save_pairs(pairs: Sequence[DPOPair], path: str | Path) -> None | |
| ``` | |
| Write pairs to JSONL (one dict per line). Creates parent dirs. | |
| ```python | |
| from composer_replication.teacher_replay import save_pairs | |
| save_pairs(pairs, "/tmp/dpo_pairs.jsonl") | |
| ``` | |
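A JSONL writer with the behaviour described above (one dict per line, parent directories created) takes only a few lines. The sketch below is a stand-in, not the library function; `save_pairs_sketch` and `load_pairs_sketch` are hypothetical names, and UTF-8 encoding is an assumption.

```python
import json
import tempfile
from pathlib import Path


def save_pairs_sketch(pairs, path):
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)  # create missing parent dirs
    with path.open("w", encoding="utf-8") as fh:
        for pair in pairs:
            fh.write(json.dumps(pair, ensure_ascii=False) + "\n")  # one dict per line


def load_pairs_sketch(path):
    with Path(path).open(encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]


pairs = [{"state_id": "ex001::0000", "state_messages": [],
          "chosen": "cat README.md", "rejected": "ls", "n_teachers_agreeing": 2}]

# Round-trip through a temporary directory to show the nested path is created.
with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "nested" / "dpo_pairs.jsonl"
    save_pairs_sketch(pairs, target)
    roundtrip = load_pairs_sketch(target)
```

One record per line keeps the file appendable and lets downstream tools stream it without loading the whole dataset.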
| --- | |
| ## 7. `composer_replication.replaysim` | |
| ADR-004 normalization layer over `teacher_replay`. Re-exports `DPOPair`, `TeacherCallResult`, `extract_dpo_pairs`, `replay_trace` from `teacher_replay`. | |
| ### `class NormalizedDPOPair` | |
| ```python | |
| @dataclass | |
| class NormalizedDPOPair: | |
|     state_id: str | |
|     state_messages: list[dict[str, Any]] | |
|     chosen_messages: list[dict[str, Any]] | |
|     rejected_messages: list[dict[str, Any]] | |
|     n_teachers_agreeing: int | |
|     metadata: dict[str, Any] | |
| ``` | |
| Post-normalization shape. `chosen_messages`/`rejected_messages` are chat-format (`[{"role": "assistant", "content": ...}]`). `metadata` carries op-graph provenance, including `{"skipped": True}` when the normalizer was bypassed (`skip_dj=True`). | |
| ```python | |
| from composer_replication.replaysim import NormalizedDPOPair | |
| n = NormalizedDPOPair(state_id="x", state_messages=[], | |
| chosen_messages=[{"role": "assistant", "content": "ok"}], | |
| rejected_messages=[{"role": "assistant", "content": "no"}], | |
| n_teachers_agreeing=2, metadata={}) | |
| ``` | |
| ### `class DJNormalizer` | |
| ```python | |
| class DJNormalizer: | |
|     DEFAULT_RECIPE: ClassVar[Path] # composer_replication/recipes/replaysim/default.yaml | |
|     def __init__( | |
|         self, | |
|         recipe_path: str | os.PathLike[str] | None = None, | |
|         *, | |
|         skip_dj: bool = False, | |
|     ) -> None: ... | |
|     def normalize( | |
|         self, | |
|         pairs: Iterable[DPOPair | dict[str, Any]], | |
|     ) -> list[NormalizedDPOPair]: ... | |
| ``` | |
| `data-juicer`-backed normalizer. Pipeline: each `DPOPair` → JSONL record → `data_juicer.core.DefaultExecutor.run()` against the recipe → JSONL → `NormalizedDPOPair`. | |
| **Constructor parameters** | |
| | Name | Type | Default | Meaning | | |
| |---|---|---|---| | |
| | `recipe_path` | `str \| PathLike \| None` | `None` ⇒ default recipe | data-juicer YAML recipe path. | | |
| | `skip_dj` | `bool` (kw-only) | `False` | If True: passthrough; records get `metadata={"skipped": True}` and no ops run. | | |
| **`normalize(pairs) -> list[NormalizedDPOPair]`** runs the op-graph. Output may be shorter than input if filter ops drop records. | |
| **Raises** `RuntimeError` at construction time if `skip_dj=False` and `data_juicer` is not importable. `FileNotFoundError` if `recipe_path` (default or explicit) is missing and `skip_dj=False`. | |
| ```python | |
| from composer_replication.replaysim import DJNormalizer | |
| norm = DJNormalizer(skip_dj=True) | |
| out = norm.normalize(my_pairs) | |
| ``` | |
| ### `async replay_and_normalize_trace(*, states, teachers=None, agreement_threshold=2, max_total_usd=5.0, normalizer=None, **replay_kwargs) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]` | |
| ```python | |
| async def replay_and_normalize_trace( | |
| *, | |
| states: Any, | |
| teachers: Any = None, | |
| agreement_threshold: int = 2, | |
| max_total_usd: float = 5.0, | |
| normalizer: DJNormalizer | None = None, | |
| **replay_kwargs: Any, | |
| ) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]] | |
| ``` | |
| End-to-end async: replay → extract pairs → normalize. | |
| **Parameters** | |
| | Name | Type | Default | Meaning | | |
| |---|---|---|---| | |
| | `states` | `Sequence[TraceState]` | — | Frozen trace. | | |
| | `teachers` | `Sequence[TeacherSpec] \| None` | `None` ⇒ defaults | Forwarded to `replay_trace`. | | |
| | `agreement_threshold` | `int` | `2` | Forwarded to `extract_dpo_pairs`. | | |
| | `max_total_usd` | `float` | `5.0` | Spend cap. | | |
| | `normalizer` | `DJNormalizer \| None` | `None` ⇒ `DJNormalizer()` | Pass `DJNormalizer(skip_dj=True)` to bypass. | | |
| | `**replay_kwargs` | `Any` | — | Forwarded to `replay_trace` (e.g. `api_key`). | | |
| **Returns** `(raw_teacher_actions, normalized_pairs)`. | |
| ```python | |
| import asyncio | |
| from composer_replication.replaysim import replay_and_normalize_trace, DJNormalizer | |
| raw, norm = asyncio.run(replay_and_normalize_trace( | |
| states=my_states, normalizer=DJNormalizer(skip_dj=True))) | |
| ``` | |
| ### `replay_and_normalize_trace_sync(*args, **kwargs) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]` | |
| ⚠️ UNTESTED-CONTRACT (sync wrapper around the async function; tests call the async form via `asyncio.run`). | |
| ```python | |
| def replay_and_normalize_trace_sync(*args, **kwargs) -> ... | |
| ``` | |
| Sync convenience wrapping `asyncio.run(replay_and_normalize_trace(...))`. | |
| ```python | |
| from composer_replication.replaysim.normalize import replay_and_normalize_trace_sync | |
| raw, norm = replay_and_normalize_trace_sync(states=my_states) | |
| ``` | |
| --- | |
| ## 8. `composer_replication.ingestion` & `composer_replication.ingestion.claude_code` | |
| Trace-source adapters (ADR-002). v0.1 supports Claude Code session JSONL. | |
| ### `SYSTEM_PROMPT: str` | |
| Default synthetic system prompt injected at `messages[0]` for ingested traces (most Claude Code sessions don't write one). Truncated head: `"You are a senior software engineer working as a coding agent in a terminal environment..."`. | |
| ```python | |
| from composer_replication import SYSTEM_PROMPT | |
| print(SYSTEM_PROMPT[:60]) | |
| ``` | |
| ### `class IngestionStats` | |
| ```python | |
| @dataclass | |
| class IngestionStats: | |
|     n_records_total: int = 0 | |
|     n_records_skipped: int = 0 | |
|     n_states_emitted: int = 0 | |
|     n_assistant_turns: int = 0 | |
|     n_tool_use_blocks: int = 0 | |
|     n_text_blocks: int = 0 | |
|     skipped_subagent: int = 0 | |
|     skipped_summary: int = 0 | |
|     skipped_truncated_lines: int = 0 | |
|     version_warnings: list[str] | None = None # initialized to [] in __post_init__ | |
| ``` | |
| Counters populated by `ClaudeCodeIngester.ingest()` and exposed as `ingester.last_stats`. | |
| ```python | |
| from composer_replication import IngestionStats | |
| s = IngestionStats(n_records_total=5) | |
| print(s.version_warnings) # [] | |
| ``` | |
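The `None` default combined with `__post_init__` is a common way to give each dataclass instance its own list. A minimal sketch of the pattern follows; `StatsSketch` is a hypothetical two-field analogue, not the real `IngestionStats`.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class StatsSketch:
    """Hypothetical cut-down analogue of IngestionStats."""

    n_records_total: int = 0
    version_warnings: list[str] | None = None

    def __post_init__(self) -> None:
        # Swap the None sentinel for a fresh list so instances never share state.
        if self.version_warnings is None:
            self.version_warnings = []


a = StatsSketch(n_records_total=5)
b = StatsSketch()
a.version_warnings.append("unknown version")
```

Appending to `a.version_warnings` leaves `b.version_warnings` empty, which is the point of the sentinel.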
| ### `class ClaudeCodeIngester` | |
| ```python | |
| class ClaudeCodeIngester: | |
|     def __init__( | |
|         self, | |
|         *, | |
|         system_prompt: str = SYSTEM_PROMPT, | |
|         skip_sidechain: bool = True, | |
|         strip_thinking: bool = True, | |
|         max_history_tokens: int | None = None, | |
|     ) -> None: ... | |
|     def ingest(self, path: Path) -> Iterator[TraceState]: ... | |
| ``` | |
| Convert a Claude Code session JSONL to a stream of `TraceState`s — one per assistant TURN (not per `tool_use` block). | |
| **Constructor parameters** | |
| | Name | Type | Default | Meaning | | |
| |---|---|---|---| | |
| | `system_prompt` | `str` | `SYSTEM_PROMPT` | Synthetic system message injected at history[0]. | | |
| | `skip_sidechain` | `bool` | `True` | Skip subagent files (`agent-*.jsonl`) and records with `isSidechain=True`. | | |
| | `strip_thinking` | `bool` | `True` | Remove `[THINKING]` blocks from history handed to teachers (kept inside `student_action`). | | |
| | `max_history_tokens` | `int \| None` | `None` | ⚠️ UNTESTED-CONTRACT — accepted but currently not used to truncate. | | |
| **`ingest(path) -> Iterator[TraceState]`**: generator over `TraceState` objects. Each turn's `state_id` is `f"{path.stem}::{idx:04d}"`. Side effect: replaces `self.last_stats` with a fresh `IngestionStats` and updates it as records stream. | |
| ```python | |
| from pathlib import Path | |
| from composer_replication import ClaudeCodeIngester | |
| ing = ClaudeCodeIngester() | |
| for state in ing.ingest(Path("session.jsonl")): | |
|     print(state["state_id"]) | |
| print(ing.last_stats.n_states_emitted) | |
| ``` | |
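The `state_id` format can be reproduced and parsed with plain string operations. The sketch below uses the documented f-string and assumes a zero-based turn index; `make_state_id` and `split_state_id` are hypothetical helpers that do not exist in the package.

```python
from pathlib import Path


def make_state_id(path: Path, idx: int) -> str:
    # Documented format: file stem, "::", then the turn index zero-padded to 4 digits.
    return f"{path.stem}::{idx:04d}"


def split_state_id(state_id: str) -> tuple[str, int]:
    # rsplit keeps any "::" that happens to appear inside the file stem.
    stem, idx = state_id.rsplit("::", 1)
    return stem, int(idx)


ids = [make_state_id(Path("logs/session.jsonl"), i) for i in range(3)]
stem, idx = split_state_id(ids[2])
```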
| --- | |
| ## 9. `composer_replication.hint_generator` | |
| ⚠️ UNTESTED-CONTRACT (entire module — used by the data collator config but not pinned by a test). | |
| Template-based hint registry for SDPO error-site injection. | |
| ### `class HintContext(TypedDict, total=False)` | |
| ```python | |
| class HintContext(TypedDict, total=False): | |
|     error_kind: str | |
|     error_message: str | |
|     available_tools: list[str] | |
|     tool_name: str | |
|     tool_schema: dict | |
|     intent: str | |
| ``` | |
| Per-error context dict consumed by hint templates. | |
| ### `HINT_TEMPLATES: dict[str, Callable[[HintContext], str]]` | |
| Default registry keys: `"tool_not_found"`, `"json_decode"`, `"type_error"`, `"runtime_error"`, `"repeated_failure"`. | |
| ### `dispatch(error_kind, ctx=None) -> str | None` | |
| ```python | |
| def dispatch(error_kind: str, ctx: HintContext | None = None) -> str | None | |
| ``` | |
| Look up `error_kind` in `HINT_TEMPLATES`. Returns the template's hint text, or `None` if the kind is unknown. | |
| ```python | |
| from composer_replication.hint_generator import dispatch | |
| hint = dispatch("json_decode") # "Reminder: tool arguments must be valid JSON. ..." | |
| ``` | |
| ### `register(error_kind, fn) -> None` | |
| ```python | |
| def register(error_kind: str, fn: Callable[[HintContext], str]) -> None | |
| ``` | |
| Add or override a custom hint template. | |
| ```python | |
| from composer_replication.hint_generator import register | |
| register("my_error", lambda ctx: "Reminder: try X.") | |
| ``` | |
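The `dispatch` / `register` pair describes a plain dictionary registry. The sketch below shows the lookup semantics documented above (unknown kind returns `None`, registering an existing key overrides it); every name and hint string here is hypothetical, and the module's real templates are not reproduced.

```python
from typing import Callable, Optional

Ctx = dict  # stand-in for HintContext
SKETCH_TEMPLATES: dict[str, Callable[[Ctx], str]] = {}


def sketch_register(error_kind: str, fn: Callable[[Ctx], str]) -> None:
    # Adding an existing key silently overrides the earlier template.
    SKETCH_TEMPLATES[error_kind] = fn


def sketch_dispatch(error_kind: str, ctx: Optional[Ctx] = None) -> Optional[str]:
    fn = SKETCH_TEMPLATES.get(error_kind)
    if fn is None:
        return None  # unknown kinds yield no hint
    return fn(ctx or {})


sketch_register(
    "tool_not_found",
    lambda ctx: "Available tools: " + ", ".join(ctx.get("available_tools", [])) + ".",
)
hint = sketch_dispatch("tool_not_found", {"available_tools": ["Read", "Write"]})
missing = sketch_dispatch("no_such_kind")
```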
| ### Individual template functions | |
| ⚠️ UNTESTED-CONTRACT — exported only via `HINT_TEMPLATES`, useful as building blocks: | |
| - `hint_tool_not_found(ctx) -> str` | |
| - `hint_json_decode(ctx) -> str` | |
| - `hint_type_error(ctx) -> str` | |
| - `hint_runtime_error(ctx) -> str` | |
| - `hint_repeated_failure(ctx) -> str` | |
| Each accepts a `HintContext` and returns hint text. Signatures are uniform: `Callable[[HintContext], str]`. | |
| ```python | |
| from composer_replication.hint_generator import hint_tool_not_found | |
| text = hint_tool_not_found({"available_tools": ["Read", "Write"]}) | |
| ``` | |
| --- | |
| ## 10. `composer_replication.trainer` & sub-modules | |
| Production trainer (TRL `GRPOTrainer` subclass) plus data collator. | |
| ### `class ComposerReplicationTrainer` | |
| ```python | |
| class ComposerReplicationTrainer(GRPOTrainer): | |
|     def __init__( | |
|         self, | |
|         *args: Any, | |
|         alpha_sdpo: float = 0.1, | |
|         beta_replay: float = 0.05, | |
|         sdpo_jsd_beta: float = 0.5, | |
|         sdpo_temperature: float = 1.0, | |
|         sdpo_token_clip: float | None = None, | |
|         replay_dpo_beta: float = 0.1, | |
|         **kwargs: Any, | |
|     ) -> None: ... | |
|     def _compute_loss( | |
|         self, | |
|         model: torch.nn.Module, | |
|         inputs: dict[str, torch.Tensor], | |
|     ) -> torch.Tensor: ... | |
| ``` | |
| `trl.GRPOTrainer` subclass that overrides `_compute_loss(model, inputs)` to compose `total = grpo + α·sdpo + β·trace_replay_dpo`. When `trl` is not installed, the parent class falls back to `object` so the module imports — but instantiation will fail because the parent's GRPO machinery is missing. | |
| **Constructor (kw-only beyond GRPOTrainer's own `*args, **kwargs`)** | |
| | Name | Type | Default | Meaning | | |
| |---|---|---|---| | |
| | `alpha_sdpo` | `float` | `0.1` | Channel-2 weight. | | |
| | `beta_replay` | `float` | `0.05` | Channel-3 weight. | | |
| | `sdpo_jsd_beta` | `float` | `0.5` | β for `generalized_jsd_loss`. | | |
| | `sdpo_temperature` | `float` | `1.0` | SDPO softmax temperature. | | |
| | `sdpo_token_clip` | `float \| None` | `None` | Per-token JSD clip. | | |
| | `replay_dpo_beta` | `float` | `0.1` | DPO β. | | |
| **`_compute_loss(model, inputs) -> torch.Tensor`** — overrides `GRPOTrainer._compute_loss`. Calls `super()._compute_loss` for channel 1, then `_compute_sdpo_loss` and `_compute_trace_replay_loss`, then composes. Logs per-channel components every `args.logging_steps` (default 50). **Raises** whatever `super()` raises (TRL-shaped errors). | |
| **Internal methods (publicly accessible, exercised by spike tests)** | |
| - ⚠️ UNTESTED-CONTRACT `_compute_sdpo_loss(model, inputs) -> torch.Tensor` — generalized-JSD between student forward and `ctx_teacher_input_ids` forward. Returns `0.0` (with grad) when `alpha_sdpo == 0`, the key is missing, or shapes mismatch. Logs a warning on shape mismatch. | |
| - ⚠️ UNTESTED-CONTRACT `_compute_trace_replay_loss(model, inputs) -> torch.Tensor` — standard DPO over `dpo_chosen_*` and `dpo_rejected_*`, using precomputed `dpo_chosen_ref_logprobs` / `dpo_rejected_ref_logprobs`. | |
| - ⚠️ UNTESTED-CONTRACT `@staticmethod _sequence_logprobs(model, input_ids, response_mask) -> torch.Tensor` — sum logprobs over response tokens; standard DPO accounting. | |
| ```python | |
| from composer_replication import ComposerReplicationTrainer | |
| trainer = ComposerReplicationTrainer( | |
| model=my_model, args=my_grpo_args, train_dataset=ds, | |
| data_collator=my_collator, alpha_sdpo=0.1, beta_replay=0.05, | |
| ) | |
| # trainer.train() # uses overridden _compute_loss | |
| ``` | |
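The composition `total = grpo + α·sdpo + β·trace_replay_dpo` is simple enough to check by hand. The sketch below works through it on plain floats with the documented default weights; `compose_total` is a hypothetical helper, and the real trainer operates on tensors.

```python
def compose_total(
    grpo: float,
    sdpo: float,
    replay: float,
    alpha_sdpo: float = 0.1,
    beta_replay: float = 0.05,
) -> float:
    """Weighted sum of the three channel losses (scalar illustration only)."""
    return grpo + alpha_sdpo * sdpo + beta_replay * replay


# 1.0 + 0.1 * 2.0 + 0.05 * 4.0
total = compose_total(grpo=1.0, sdpo=2.0, replay=4.0)

# Zeroing both weights leaves only the GRPO channel.
grpo_only = compose_total(1.0, 2.0, 4.0, alpha_sdpo=0.0, beta_replay=0.0)
```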
| ### `class TraceTurn(TypedDict, total=False)` — `trainer.data_collator` | |
| ```python | |
| class TraceTurn(TypedDict, total=False): | |
|     role: str  # "user" | "assistant" | "tool" | |
|     content: str | |
|     tool_call: dict | None | |
|     tool_error: str | None | |
|     error_meta: dict | |
| ``` | |
| One turn of an agentic trace as consumed by `ComposerDataCollator`. | |
| ### `class TraceExample(TypedDict, total=False)` — `trainer.data_collator` | |
| ```python | |
| class TraceExample(TypedDict, total=False): | |
|     trace_id: str | |
|     turns: list[TraceTurn] | |
|     final_reward: float | |
|     dpo_pairs: list[dict] | None | |
| ``` | |
| One training example: `(turns, optional dpo_pairs)`. `dpo_pairs` shape matches `DPOPair`. | |
| ### `class TokenizerLike` — `trainer.data_collator` | |
| ⚠️ UNTESTED-CONTRACT (duck-typed protocol; used as a type hint). | |
| ```python | |
| class TokenizerLike: | |
|     pad_token_id: int | |
|     def __call__(self, text: str | list[str], **kwargs: Any) -> dict[str, list]: ... | |
|     def apply_chat_template(self, messages: list[dict], **kwargs: Any) -> str | list[int]: ... | |
| ``` | |
| Minimal protocol the collator needs. Compatible with HF `AutoTokenizer`. | |
| ### `class CollatorConfig` — `trainer.data_collator` | |
| ```python | |
| @dataclass | |
| class CollatorConfig: | |
|     max_seq_len: int = 4096 | |
|     max_dpo_seq_len: int = 2048 | |
|     pad_token_id: int = 0 | |
|     ignore_index: int = -100 | |
|     enable_sdpo: bool = True | |
|     hint_generator: Callable[[str, dict], str | None] | None = None | |
|     enable_replay_dpo: bool = True | |
|     rlvr_reward_key: str = "final_reward" | |
| ``` | |
| Tunables for `ComposerDataCollator`. | |
| | Field | Default | Meaning | | |
| |---|---|---| | |
| | `max_seq_len` | `4096` | Truncation cap for student/teacher sequences. | | |
| | `max_dpo_seq_len` | `2048` | Truncation cap for DPO chosen/rejected sequences. | | |
| | `pad_token_id` | `0` | Padding token id. | | |
| | `ignore_index` | `-100` | HF "ignore in loss" sentinel for SDPO mask. | | |
| | `enable_sdpo` | `True` | Toggle channel-2 fields. | | |
| | `hint_generator` | `None` | Optional `Callable[[str, dict], str \| None]` mapping `(error_kind, error_meta) -> hint_text`. SDPO is a no-op without it. | |
| | `enable_replay_dpo` | `True` | Toggle channel-3 fields. | | |
| | `rlvr_reward_key` | `"final_reward"` | Key in `TraceExample` to read scalar reward. | | |
| ```python | |
| from composer_replication.trainer.data_collator import CollatorConfig | |
| cfg = CollatorConfig(max_seq_len=2048, hint_generator=my_dispatch) | |
| ``` | |
| ### `class ComposerDataCollator` — `trainer.data_collator` | |
| ```python | |
| @dataclass | |
| class ComposerDataCollator: | |
|     tokenizer: TokenizerLike | |
|     config: CollatorConfig = field(default_factory=CollatorConfig) | |
|     def __call__( | |
|         self, batch: Sequence[TraceExample] | |
|     ) -> dict[str, torch.Tensor]: ... | |
| ``` | |
| Build trainer-ready batches from raw traces + optional DPO pairs. | |
| **Output dict keys** (tested in `spikes/005-integrated-trainer-skeleton/tests/test_data_collator.py`): | |
| - Channel 1 (always): `input_ids`, `attention_mask`, `response_mask`, `rewards`. | |
| - Channel 2 (when `enable_sdpo=True` AND batch has at least one error site AND `hint_generator` is set): `ctx_teacher_input_ids`, `sdpo_loss_mask`. | |
| - Channel 3 (when `enable_replay_dpo=True` AND batch has at least one `dpo_pair`): `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`. (Reference logprobs are NOT computed here — the trainer does that pass.) | |
| ```python | |
| from composer_replication.trainer.data_collator import ( | |
| ComposerDataCollator, CollatorConfig) | |
| collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig()) | |
| batch = collator([{"trace_id": "x", "turns": [...], "final_reward": 1.0}]) | |
| ``` | |
| --- | |
| ## 11. `composer_replication.diloco` | |
| DiLoCo outer-loop wrapper around `torchft.local_sgd.DiLoCo`. Optional dep — when `torchft` is missing the package re-export `composer_replication.make_diloco_outer_loop` is `None`. | |
| ### Module-level attributes | |
| - `DiLoCo: Any` — `torchft.local_sgd.DiLoCo` if importable else `None`. | |
| - `Manager: Any` — `torchft.manager.Manager` if importable else `None`. | |
| - `_DummyWork: Any` — `torchft.work._DummyWork` if importable else `None`. | |
| - `_TORCHFT_AVAILABLE: bool` — whether the imports succeeded. | |
| ```python | |
| from composer_replication.diloco import _TORCHFT_AVAILABLE, DiLoCo | |
| ``` | |
| ### `make_diloco_outer_loop(manager, model_fragments, inner_optimizer, *, ...) -> torchft.local_sgd.DiLoCo` | |
| ```python | |
| def make_diloco_outer_loop( | |
| manager: Any, | |
| model_fragments: list[torch.nn.Module], | |
| inner_optimizer: torch.optim.Optimizer, | |
| *, | |
| outer_lr: float = 0.7, | |
| outer_momentum: float = 0.9, | |
| nesterov: bool = True, | |
| sync_every: int = 100, | |
| fragment_sync_delay: int = 0, | |
| fragment_update_alpha: float = 0.0, | |
| ) -> Any | |
| ``` | |
| Construct a `torchft.DiLoCo` configured with framework-default hyperparams (DiLoCo paper §3.2: `lr=0.7, momentum=0.9, Nesterov`). | |
| **Parameters** | |
| | Name | Type | Default | Meaning | | |
| |---|---|---|---| | |
| | `manager` | `torchft.Manager` (or duck-typed `MockManager`) | — | Provides `allreduce`, `should_commit`, `current_step`, `start_quorum`, etc. | | |
| | `model_fragments` | `list[torch.nn.Module]` | — | One module for vanilla DiLoCo; N modules for Streaming DiLoCo. | | |
| | `inner_optimizer` | `torch.optim.Optimizer` | — | Inner-step optimizer (steps every batch). | | |
| | `outer_lr` | `float` | `0.7` | Outer SGD lr. | | |
| | `outer_momentum` | `float` | `0.9` | Outer SGD momentum. | | |
| | `nesterov` | `bool` | `True` | Nesterov momentum on outer SGD. | | |
| | `sync_every` | `int` | `100` | Inner steps per outer round. | | |
| | `fragment_sync_delay` | `int` | `0` | 0 = vanilla; >0 = Streaming DiLoCo (requires CUDA streams). | | |
| | `fragment_update_alpha` | `float` | `0.0` | 0 = full replacement on sync; >0 = exponential mix. | | |
| **Returns** a `torchft.local_sgd.DiLoCo` instance — usable as a context manager. | |
| **Raises** `RuntimeError` if `torchft` is not installed. | |
| ```python | |
| import torch | |
| from composer_replication.diloco import make_diloco_outer_loop | |
| # `model`, `mgr`, `N`, and `compute_loss` are placeholders supplied by the caller. | |
| opt = torch.optim.AdamW(model.parameters(), lr=1e-5) | |
| outer = make_diloco_outer_loop(manager=mgr, model_fragments=[model], | |
|                                inner_optimizer=opt, sync_every=100) | |
| with outer: | |
|     for _ in range(N): | |
|         opt.zero_grad() | |
|         loss = compute_loss(model)  # recompute the loss on every inner step | |
|         loss.backward() | |
|         opt.step() | |
| ``` | |
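To make the outer hyperparameters concrete, here is a scalar sketch of one outer round. It assumes the pseudo-gradient convention from the DiLoCo paper (global parameter minus local parameter, averaged over replicas) and the update rule of `torch.optim.SGD` with `nesterov=True`. It is not taken from `torchft`, whose internals may differ, and `outer_step` is a hypothetical name.

```python
from typing import Optional


def outer_step(
    global_param: float,
    local_params: list[float],
    momentum_buf: Optional[float],
    *,
    outer_lr: float = 0.7,
    outer_momentum: float = 0.9,
    nesterov: bool = True,
) -> tuple[float, float]:
    """One outer SGD step on a single scalar parameter (illustration only)."""
    # Pseudo-gradient: how far the replicas moved away from the global value, averaged.
    delta = sum(global_param - p for p in local_params) / len(local_params)
    # SGD momentum as in torch.optim.SGD: the buffer starts out equal to the first gradient.
    buf = delta if momentum_buf is None else outer_momentum * momentum_buf + delta
    direction = delta + outer_momentum * buf if nesterov else buf
    return global_param - outer_lr * direction, buf


# Two replicas drifted from 1.0 to 0.8 and 0.6 during their inner steps.
new_param, buf = outer_step(1.0, [0.8, 0.6], None)
```

With these numbers the averaged pseudo-gradient is 0.3, the Nesterov direction is 0.57, and the global parameter moves from 1.0 to about 0.601.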
| --- | |
| ## 12. `composer_replication.diloco.serverless` | |
| ADR-005 serverless DiLoCo executors + object-store all-reduce. | |
| ### `class ReplicaHandle` — `serverless.executor` | |
| ```python | |
| @dataclass | |
| class ReplicaHandle: | |
|     rank: int | |
|     backend_name: str | |
|     metadata: dict[str, Any] = field(default_factory=dict) | |
| ``` | |
| Opaque handle returned by `ServerlessExecutor.launch_replicas`. `metadata` is backend-specific. | |
| ```python | |
| from composer_replication.diloco.serverless import ReplicaHandle | |
| h = ReplicaHandle(rank=0, backend_name="local_process", | |
| metadata={"pid": 12345}) | |
| ``` | |
| ### `class ServerlessExecutor` (Protocol) — `serverless.executor` | |
| ```python | |
| @runtime_checkable | |
| class ServerlessExecutor(Protocol): | |
|     backend_name: str | |
|     supports_inter_replica_network: bool | |
|     def launch_replicas( | |
|         self, | |
|         n_replicas: int, | |
|         entrypoint: str | Callable[..., Any], | |
|         entrypoint_args: Mapping[str, Any], | |
|         *, | |
|         gpu: str | None = None, | |
|         timeout: int = 3600, | |
|     ) -> list[ReplicaHandle]: ... | |
|     def poll(self, handle: ReplicaHandle) -> str: ... | |
|     def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str: ... | |
|     def cancel(self, handle: ReplicaHandle) -> None: ... | |
|     def collect( | |
|         self, handles: list[ReplicaHandle], *, timeout: int | None = None, | |
|     ) -> list[dict[str, Any]]: ... | |
| ``` | |
| Structural protocol for serverless backends. | |
| - `launch_replicas(...)` returns a `list[ReplicaHandle]` of length `n_replicas`, in rank order. `entrypoint` is one of: an importable module path (its `main()` is called), a `module.function` path, or a `Callable` (local executor only). `entrypoint_args` may include `rank_env` (default `"REPLICA_RANK"`). | |
| - `poll(handle) -> str`: one of `"pending"`, `"running"`, `"succeeded"`, `"failed"`, `"cancelled"`. | |
| - `stream_logs(handle, n_lines=200) -> str`: best-effort recent stdout/stderr. | |
| - `cancel(handle) -> None`: best-effort. | |
| - `collect(handles, timeout=None) -> list[dict]`: blocks; each result dict has `rank`, `status`, `exit_code`, `error` (and `result` from `LocalProcessExecutor`). | |
| ```python | |
| from composer_replication.diloco.serverless import ServerlessExecutor | |
| def supports(x: ServerlessExecutor) -> bool: | |
|     return isinstance(x, ServerlessExecutor)  # runtime_checkable | |
| ``` | |
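Because the protocol is `@runtime_checkable`, any object with the right members passes `isinstance` without inheriting from it. The sketch below shows this with a cut-down protocol; `ExecutorLike`, `GoodExecutor`, and `BadExecutor` are hypothetical and carry only two of the real protocol's members.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class ExecutorLike(Protocol):
    """Hypothetical two-member slice of a ServerlessExecutor-style protocol."""

    backend_name: str

    def poll(self, handle: object) -> str: ...


class GoodExecutor:
    backend_name = "dummy"

    def poll(self, handle: object) -> str:
        return "succeeded"


class BadExecutor:
    backend_name = "dummy"  # has the attribute but no poll() method


good_ok = isinstance(GoodExecutor(), ExecutorLike)
bad_ok = isinstance(BadExecutor(), ExecutorLike)
```

The runtime check only verifies that the members exist; it does not compare signatures or return types, so it is a coarse guard rather than a full conformance test.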
| ### `class LocalProcessExecutor` — `serverless.executor` | |
| ```python | |
| class LocalProcessExecutor: | |
|     backend_name = "local_process" | |
|     supports_inter_replica_network = True | |
|     def __init__(self) -> None: ... | |
|     # implements ServerlessExecutor protocol | |
| ``` | |
| Reference implementation using Python `multiprocessing` (`spawn` context). Used for tests, CI smokes, and local development with `file://` rendezvous. | |
| `launch_replicas(...)`: emits a soft warning when `gpu` is not `None` (local processes share whatever GPUs are visible). Each returned handle carries `metadata = {"pid": ..., "start_ts": ...}`. | |
| ```python | |
| from composer_replication.diloco.serverless import LocalProcessExecutor | |
| ex = LocalProcessExecutor() | |
| handles = ex.launch_replicas( | |
| n_replicas=2, | |
| entrypoint="composer_replication.diloco.serverless.replica_entrypoint", | |
| entrypoint_args={"rendezvous_uri": "/tmp/run/", "world_size": 2, | |
| "trainer_module": "my.trainer"}, | |
| ) | |
| results = ex.collect(handles, timeout=60) | |
| ``` | |
| ### `class ObjectStoreAllReduce` — `serverless.allreduce` | |
| ```python | |
| class ObjectStoreAllReduce: | |
|     def __init__( | |
|         self, | |
|         uri: str, | |
|         rank: int, | |
|         world_size: int, | |
|         *, | |
|         round_id: int | None = None, | |
|         timeout_s: float = 1800.0, | |
|         poll_interval_s: float = 1.0, | |
|     ) -> None: ... | |
|     @property | |
|     def round_id(self) -> int: ... | |
|     def allreduce( | |
|         self, tensor: torch.Tensor, *, name: str | None = None, | |
|     ) -> torch.Tensor: ... | |
| ``` | |
| fsspec-backed pseudo-gradient rendezvous. `uri` accepts `s3://`, `gs://`, `az://`, `hf://`, `file://`, or a plain local path. | |
| **Constructor parameters** | |
| | Name | Type | Default | Meaning | | |
| |---|---|---|---| | |
| | `uri` | `str` | — | fsspec URI or local path. Trailing `/` enforced. | | |
| | `rank` | `int` | — | This replica's rank. | | |
| | `world_size` | `int` | — | Total replicas. | | |
| | `round_id` | `int \| None` (kw-only) | `None` ⇒ start at 0 | Initial round counter. | | |
| | `timeout_s` | `float` (kw-only) | `1800.0` | Per-`allreduce` timeout. | | |
| | `poll_interval_s` | `float` (kw-only) | `1.0` | Sleep between peer-file existence checks. | | |
| **`allreduce(tensor, name=None) -> torch.Tensor`**: serializes `tensor.detach().cpu()` to `round_NNNNNN/rank_RRRR.pt`, blocks until all peers post, then averages. **Modifies `tensor` in place** AND returns it. Increments the internal `_round_counter`. | |
| **Raises** `ValueError` on invalid `rank`, `RuntimeError` if non-local URI is requested without `fsspec` installed, `TimeoutError` if peers don't show up before `timeout_s`. | |
| ```python | |
| from composer_replication.diloco.serverless import ObjectStoreAllReduce | |
| import torch | |
| store = ObjectStoreAllReduce("/tmp/run/", rank=0, world_size=2) | |
| g = torch.zeros(10) | |
| store.allreduce(g)  # blocks until rank 1 posts its tensor for this round; g is averaged in place | |
| ``` | |
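The write, poll, and average protocol described above can be sketched with the standard library alone. The version below is an illustration under stated assumptions, not the class's implementation: it stores JSON lists instead of `.pt` tensors, uses a local directory instead of fsspec, and reads the `round_NNNNNN/rank_RRRR` placeholders as 6-digit and 4-digit zero padding. `file_allreduce` and `run_demo` are hypothetical names.

```python
import json
import os
import tempfile
import threading
import time


def file_allreduce(root, rank, world_size, round_id, values,
                   timeout_s=10.0, poll_interval_s=0.01):
    """Average `values` across ranks via files under `root` (illustration only)."""
    round_dir = os.path.join(root, f"round_{round_id:06d}")
    os.makedirs(round_dir, exist_ok=True)
    paths = [os.path.join(round_dir, f"rank_{r:04d}.json") for r in range(world_size)]

    # Write under a temporary name, then rename, so peers never see a partial file.
    tmp_path = paths[rank] + ".tmp"
    with open(tmp_path, "w") as fh:
        json.dump(values, fh)
    os.replace(tmp_path, paths[rank])

    # Poll until every peer has posted, or give up after timeout_s.
    deadline = time.monotonic() + timeout_s
    while not all(os.path.exists(p) for p in paths):
        if time.monotonic() > deadline:
            raise TimeoutError(f"round {round_id}: peers missing after {timeout_s}s")
        time.sleep(poll_interval_s)

    # Sum element-wise over all ranks, then divide by the world size.
    totals = [0.0] * len(values)
    for p in paths:
        with open(p) as fh:
            for i, v in enumerate(json.load(fh)):
                totals[i] += v
    return [t / world_size for t in totals]


def run_demo():
    """Run two 'replicas' as threads against a temporary directory."""
    inputs = [[1.0, 3.0], [3.0, 5.0]]
    results = [None, None]
    with tempfile.TemporaryDirectory() as root:
        def worker(rank):
            results[rank] = file_allreduce(root, rank, 2, 0, inputs[rank])

        threads = [threading.Thread(target=worker, args=(r,)) for r in range(2)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
    return results


demo_results = run_demo()
```

Both ranks end up with the same element-wise mean, which is the property the outer optimizer relies on.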
| ### `class MockManager` — `serverless.allreduce` | |
| ```python | |
| class MockManager: | |
|     def __init__(self, store: ObjectStoreAllReduce) -> None: ... | |
|     # torchft.Manager-shaped surface: | |
|     num_participants: int | |
|     rank: int | |
|     _use_async_quorum: bool  # always False | |
|     _step: int | |
|     _state_dict_fns: dict[str, tuple[Any, Any]] | |
|     def allreduce(self, tensor: torch.Tensor, **_kwargs: Any) -> "_ImmediateWork": ... | |
|     def should_commit(self) -> bool: ... | |
|     def start_quorum(self) -> None: ... | |
|     def wait_quorum(self) -> int: ... | |
|     def current_step(self) -> int: ... | |
|     def allow_state_dict_read(self) -> None: ... | |
|     def disallow_state_dict_read(self) -> None: ... | |
|     def register_state_dict_fn(self, key: str, load_fn: Any, save_fn: Any) -> None: ... | |
|     def is_leader(self) -> bool: ... | |
| ``` | |
| Drop-in replacement for `torchft.Manager` that routes `allreduce` through `ObjectStoreAllReduce`. All other methods are no-ops or simple counters appropriate for single-shot serverless DiLoCo. | |
| - `allreduce(tensor)` returns an `_ImmediateWork` whose `.wait()` is a no-op (the tensor is already averaged). | |
| - `should_commit()` always `True` (no fault-tolerance failover). | |
| - `start_quorum()` bumps `_step`. | |
| - `is_leader()` returns `rank == 0`. | |
| ```python | |
| from composer_replication.diloco.serverless import MockManager, ObjectStoreAllReduce | |
| store = ObjectStoreAllReduce("/tmp/run/", rank=0, world_size=2) | |
| mgr = MockManager(store) | |
| # pass mgr into make_diloco_outer_loop(manager=mgr, ...) | |
| ``` | |
| ### `class _ImmediateWork` — `serverless.allreduce` | |
| ⚠️ UNTESTED-CONTRACT internal helper exported from `__all__`. `Work`-shaped wrapper with `.wait() -> True` and `.get_future() -> torch.futures.Future`. Consumed by torchft DiLoCo's `perform_sync`. | |
| ```python | |
| from composer_replication.diloco.serverless.allreduce import _ImmediateWork | |
| ``` | |
| ### `class ModalExecutor` — `serverless.modal` | |
| 🟡 SKELETON — raises `NotImplementedError`; see ADR-005. Class body documents the v0 implementation pattern (Modal `app.function` + `function.spawn(rank=...)`). | |
| ```python | |
| from composer_replication.diloco.serverless.modal import ModalExecutor | |
| # ModalExecutor()  # would raise NotImplementedError when instantiated | |
| ``` | |
| ### `class HFJobsExecutor` — `serverless.hf_jobs` | |
| 🟡 SKELETON — raises `NotImplementedError`; see ADR-005. Class body documents the v0 pattern using `huggingface_hub.run_job` against `hf://datasets/.../` rendezvous. | |
| ```python | |
| from composer_replication.diloco.serverless.hf_jobs import HFJobsExecutor | |
| # instantiation will fail until v0 implementation lands | |
| ``` | |
| ### `replica_entrypoint.main(...)` — `serverless.replica_entrypoint` | |
| ```python | |
| def main( | |
| rendezvous_uri: str, | |
| world_size: int, | |
| trainer_module: str, | |
| trainer_fn: str = "train", | |
| trainer_kwargs: dict[str, Any] | None = None, | |
| ) -> Any | |
| ``` | |
| Script run by every replica. Reads `REPLICA_RANK` env var, builds `ObjectStoreAllReduce` + `MockManager`, imports `trainer_module`, and calls `getattr(mod, trainer_fn)(**trainer_kwargs, manager=..., rank=..., world_size=...)`. Returns whatever the train fn returns. | |
| **Raises** `RuntimeError` if `REPLICA_RANK` env var is missing; `ValueError` if rank ∉ `[0, world_size)`. | |
| The `if __name__ == "__main__"` block accepts CLI flags `--rendezvous`, `--world-size`, `--trainer-module`, `--trainer-fn`, `--trainer-kwargs-json`. | |
| ```python | |
| # In-process invocation | |
| import os | |
| os.environ["REPLICA_RANK"] = "0" | |
| from composer_replication.diloco.serverless.replica_entrypoint import main | |
| result = main(rendezvous_uri="/tmp/run/", world_size=1, | |
| trainer_module="my.trainer", trainer_fn="train") | |
| ``` | |
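The two mechanical steps in `main` (reading and validating the rank, then resolving the trainer function by name) can be sketched as follows. `resolve_rank` and `call_entry` are hypothetical helpers that mirror the documented error behaviour. The demo resolves `json.dumps` from the standard library because no real trainer module is assumed to exist.

```python
import importlib
from typing import Any, Mapping


def resolve_rank(env: Mapping[str, str], world_size: int,
                 rank_env: str = "REPLICA_RANK") -> int:
    """Read the replica rank from an environment mapping and range-check it."""
    if rank_env not in env:
        raise RuntimeError(f"{rank_env} is not set")
    rank = int(env[rank_env])
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is outside [0, {world_size})")
    return rank


def call_entry(module_name: str, fn_name: str, **kwargs: Any) -> Any:
    """Import a module by dotted path and call one of its functions by name."""
    mod = importlib.import_module(module_name)
    return getattr(mod, fn_name)(**kwargs)


rank = resolve_rank({"REPLICA_RANK": "1"}, world_size=2)
payload = call_entry("json", "dumps", obj={"rank": rank})
```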
| --- | |
| ## 13. `composer_replication.recipes.prime_rl.composer_loss` | |
| PRIME-RL adapter (ADR-006). Maps PRIME-RL's `LossInputs` struct onto channel 1 (DPPO + KL on the importance ratio, mirroring PRIME-RL's upstream `default_loss_fn` at `prime_rl/trainer/rl/loss.py` lines 116-165). Channel 2 raises `NotImplementedError`; channel 3 is out of scope. | |
| ### `loss_fn(inputs, *, alpha_sdpo=0.0, beta_dpo=0.0, dppo_mask_high=0.2, dppo_mask_low=0.2, adv_tau=1.0, kl_tau=1e-3) -> torch.Tensor` | |
| ```python | |
| def loss_fn( | |
| inputs: Any, # PRIME-RL's LossInputs (duck-typed) | |
| *, | |
| alpha_sdpo: float = 0.0, | |
| beta_dpo: float = 0.0, | |
| dppo_mask_high: float = 0.2, | |
| dppo_mask_low: float = 0.2, | |
| adv_tau: float = 1.0, | |
| kl_tau: float = 1e-3, | |
| ) -> Any # torch.Tensor scalar | |
| ``` | |
| PRIME-RL passes per-sample **1-D `(seq,)` tensors** (not batched). The function mirrors PRIME-RL's upstream DPPO+KL formula: | |
| - Mask gate is on **probability-space** `probs_diff = exp(trainer_lp) - exp(inference_lp)` (NOT on the log-ratio). | |
| - A token is dropped iff its advantage sign matches the offending bound: positive-advantage tokens are dropped when `probs_diff > dppo_mask_high`, negative-advantage tokens when `probs_diff < -dppo_mask_low`. (PRIME-RL stores both bounds with `Field(..., ge=0)` and applies the sign internally.) | |
| - The PG term is `keep * (adv_tau * advantages) * exp(trainer_lp - inference_lp)` (importance-ratio corrected, not REINFORCE). | |
| - A KL penalty `kl_tau * log_importance_ratio**2` is added on the full `loss_mask` (DPPO masking does not gate it). | |
| - Reduction is a plain `sum()`; PRIME-RL's outer `compute_loss` divides by `loss_scale`. | |
| **Parameters** | |
| | Name | Type | Default | Meaning | | |
| |---|---|---|---| | |
| | `inputs` | PRIME-RL `LossInputs` (duck-typed) | — | Must expose `trainer_logprobs`, `inference_logprobs`, `advantages`, `loss_mask` (all 1-D), and optionally `teacher_logprobs`. | | |
| | `alpha_sdpo` | `float` (kw-only) | `0.0` | Channel-2 weight. Must be `0` in v0; >0 → `NotImplementedError`. | | |
| | `beta_dpo` | `float` (kw-only) | `0.0` | Channel-3 weight. Non-zero emits a `UserWarning`. | | |
| | `dppo_mask_high` | `float` (kw-only), `>= 0` | `0.2` | Upper probability-diff threshold. PRIME-RL `DefaultLossConfig` default. | | |
| | `dppo_mask_low` | `float` (kw-only), `>= 0` | `0.2` | Magnitude of lower probability-diff threshold (sign flipped internally). PRIME-RL default. | | |
| | `adv_tau` | `float` (kw-only), `>= 0` | `1.0` | Advantage temperature. PRIME-RL default. | | |
| | `kl_tau` | `float` (kw-only), `>= 0` | `1e-3` | KL term temperature. PRIME-RL default. | | |
| **Returns** scalar `torch.Tensor` (PRIME-RL's trainer calls `.backward()`). | |
| **Raises** `ValueError` if any of `trainer_logprobs`, `inference_logprobs`, `advantages`, `loss_mask` is not 1-D, or any of the four `>=0`-constrained knobs is negative. `NotImplementedError` if `alpha_sdpo > 0` (channel 2 deferred). | |
| ```python | |
| from composer_replication.recipes.prime_rl.composer_loss import loss_fn | |
| # In PRIME-RL config: | |
| # loss: | |
| # custom: | |
| # import_path: composer_replication.recipes.prime_rl.composer_loss:loss_fn | |
| # kwargs: | |
| # dppo_mask_high: 0.2 | |
| # dppo_mask_low: 0.2 | |
| # adv_tau: 1.0 | |
| # kl_tau: 1.0e-3 | |
| ``` | |
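The bullet points above fully specify the per-token arithmetic, so it can be traced on plain floats. The sketch below is a pure-Python illustration, not the adapter's tensor code. It assumes the policy-gradient term enters the loss with a negative sign (so that minimizing the loss favours positive-advantage tokens), which the text above does not state explicitly. `dppo_kl_loss` is a hypothetical name.

```python
import math


def dppo_kl_loss(trainer_lp, inference_lp, advantages, loss_mask,
                 dppo_mask_high=0.2, dppo_mask_low=0.2, adv_tau=1.0, kl_tau=1e-3):
    """Scalar walk-through of the DPPO + KL formula for one sample."""
    total = 0.0
    for t_lp, i_lp, adv, m in zip(trainer_lp, inference_lp, advantages, loss_mask):
        if not m:
            continue  # tokens outside loss_mask contribute nothing
        log_ratio = t_lp - i_lp
        # The gate is on the probability-space difference, not on the log-ratio.
        probs_diff = math.exp(t_lp) - math.exp(i_lp)
        dropped = (adv > 0 and probs_diff > dppo_mask_high) or (
            adv < 0 and probs_diff < -dppo_mask_low)
        keep = 0.0 if dropped else 1.0
        pg = keep * (adv_tau * adv) * math.exp(log_ratio)
        # ASSUMPTION: the PG term is negated; the KL penalty ignores the DPPO gate.
        total += -pg + kl_tau * log_ratio ** 2
    return total  # plain sum; no normalisation here


lp_new, lp_old = math.log(0.9), math.log(0.5)  # probs_diff = 0.4 > 0.2
loss = dppo_kl_loss([lp_new, lp_new], [lp_old, lp_old],
                    advantages=[1.0, -1.0], loss_mask=[1, 1])
```

In this example the positive-advantage token is dropped from the PG term because its probability rose by more than `dppo_mask_high`, while the negative-advantage token is kept; both tokens still pay the KL penalty.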
| --- | |
| ## 14. `composer_replication.recipes.monarch.actors` | |
| 🟡 SKELETON module per ADR-006. Importable; classes raise `NotImplementedError` on instantiation. Documents the actor signatures so the recipe matrix is complete. | |
| ### `class TrainerActor` 🟡 | |
| ```python | |
| class TrainerActor: | |
|     backend = "monarch" | |
|     role = "trainer" | |
|     def __init__(self) -> None: raise NotImplementedError(...) | |
|     async def train_outer_step(self, batch_id: int) -> dict[str, Any]: raise NotImplementedError | |
| ``` | |
| Hosts the framework's 3-channel composer trainer. Real impl deferred to v0.2+. | |
| ### `class GeneratorActor` 🟡 | |
| ```python | |
| class GeneratorActor: | |
|     backend = "monarch" | |
|     role = "generator" | |
|     def __init__(self) -> None: raise NotImplementedError(...) | |
|     async def rollout(self, prompts: list[str]) -> list[str]: raise NotImplementedError | |
| ``` | |
| vLLM-backed rollout actor. | |
| ### `class RewarderActor` 🟡 | |
| ```python | |
| class RewarderActor: | |
|     backend = "monarch" | |
|     role = "rewarder" | |
|     def __init__(self) -> None: raise NotImplementedError(...) | |
|     async def score(self, completions: list[str]) -> list[float]: raise NotImplementedError | |
| ``` | |
| verifiers-protocol rewarder. | |
| ### `class TeacherPoolActor` 🟡 | |
| ```python | |
| class TeacherPoolActor: | |
|     backend = "monarch" | |
|     role = "teacher_pool" | |
|     def __init__(self) -> None: raise NotImplementedError(...) | |
| ``` | |
| Channel-3 teacher pool wrapping `composer_replication.teacher_replay`. | |
| ```python | |
| # All Monarch actors raise on instantiation in v0: | |
| from composer_replication.recipes.monarch.actors import TrainerActor | |
| # TrainerActor() # NotImplementedError | |
| ``` | |
| --- | |
| ## Notes on test coverage | |
| Tested contracts (referenced spike/test paths): | |
| - `compose_loss` + `LossComponents` + `build_batch`: `composer_replication/tests/test_compose_loss_integration.py`, `spikes/006-real-hf-model-smoke/tests/`. | |
| - `generalized_jsd_loss`: `spikes/005-integrated-trainer-skeleton/tests/test_opsd_loss.py`. | |
| - `simpo_loss`, `taid_loss`, `taid_alpha_schedule`, `taid_blended_logits`, `entropy_aware_opd_loss`: `composer_replication/distillation/tests/test_distillation_losses.py`. | |
| - `replay_trace`, `extract_dpo_pairs`, `DPOPair`, `TraceState`, `TeacherCallResult`, `TeacherSpec`, `DEFAULT_TEACHERS`: `spikes/005-integrated-trainer-skeleton/tests/test_teacher_replay.py`. | |
| - `DJNormalizer`, `NormalizedDPOPair`, `replay_and_normalize_trace`: `composer_replication/replaysim/tests/test_replaysim.py`. | |
| - `ClaudeCodeIngester`, `IngestionStats`, `SYSTEM_PROMPT`: `spikes/007-real-trace-ingestion/tests/`. | |
| - `ComposerDataCollator`, `CollatorConfig`, `TraceTurn`, `TraceExample`: `spikes/005-integrated-trainer-skeleton/tests/test_data_collator.py`. | |
| - `ComposerReplicationTrainer._compute_loss` (composition arithmetic): `spikes/005-integrated-trainer-skeleton/tests/test_loss_composition_smoke.py`. | |
| - `make_diloco_outer_loop` + sign convention: `spikes/008-streaming-diloco/tests/test_diloco_smoke.py`. | |
| - `ObjectStoreAllReduce`, `MockManager`, `LocalProcessExecutor`, `ReplicaHandle`, `ServerlessExecutor`, `replica_entrypoint.main`: `composer_replication/diloco/serverless/tests/test_serverless_local.py`, `test_serverless_diloco_integration.py`. | |
| - `recipes.prime_rl.composer_loss.loss_fn`: `composer_replication/recipes/prime_rl/tests/test_composer_loss.py`. | |
| Untested-contract symbols (⚠️) and skeletons (🟡) are flagged inline above. | |
| --- | |
| **Document path**: `/mnt/e/CS/HF/composer-replication-framework/docs/API_REFERENCE.md` | |