# API Reference: composer-replication-framework

Complete reference for every public symbol in `composer_replication`. The source of truth is the `.py` files in `composer_replication/`; docstrings have been pulled verbatim where they exist and supplemented where missing.

**Legend**

- ⚠️ **UNTESTED-CONTRACT**: symbol exists and is callable, but its behaviour is not pinned by an automated test in `composer_replication/**/tests/` or `spikes/**/tests/`.
- 🟡 **SKELETON**: class/method body raises `NotImplementedError`; ships as design-of-record per ADR-005 / ADR-006.

**Module groups (in this document)**

1. `composer_replication` (top-level re-exports)
2. `composer_replication.loss`
3. `composer_replication.batch`
4. `composer_replication.opsd`
5. `composer_replication.distillation`
6. `composer_replication.teacher_replay`
7. `composer_replication.replaysim`
8. `composer_replication.ingestion` (+ `.claude_code`)
9. `composer_replication.hint_generator`
10. `composer_replication.trainer` (+ `.composer_trainer`, `.data_collator`)
11. `composer_replication.diloco`
12. `composer_replication.diloco.serverless` (+ `.executor`, `.allreduce`, `.modal`, `.hf_jobs`, `.replica_entrypoint`)
13. `composer_replication.recipes.prime_rl.composer_loss`
14. `composer_replication.recipes.monarch.actors`

---

## 1. `composer_replication`: top-level package

The package re-exports the most common entry points from sub-modules. `__all__` is the canonical list of public top-level names.

### `composer_replication.__version__: str`

Package version string. Currently `"0.1.0"`.

```python
import composer_replication

print(composer_replication.__version__)  # "0.1.0"
```

### `composer_replication._DILOCO_AVAILABLE: bool`

`True` iff `torchft` is importable in the running Python environment (gates `make_diloco_outer_loop`). When `torchft` is missing, this flag is `False` and `make_diloco_outer_loop` is set to `None`.

```python
from composer_replication import _DILOCO_AVAILABLE

if _DILOCO_AVAILABLE:
    from composer_replication import make_diloco_outer_loop
```

### Re-exports

| Name | Source module |
|---|---|
| `compose_loss` | `composer_replication.loss` |
| `LossComponents` | `composer_replication.loss` |
| `build_batch` | `composer_replication.batch` |
| `generalized_jsd_loss` | `composer_replication.opsd` |
| `ClaudeCodeIngester` | `composer_replication.ingestion.claude_code` |
| `IngestionStats` | `composer_replication.ingestion.claude_code` |
| `SYSTEM_PROMPT` | `composer_replication.ingestion.claude_code` |
| `DEFAULT_TEACHERS` | `composer_replication.teacher_replay` |
| `DPOPair` | `composer_replication.teacher_replay` |
| `TeacherCallResult` | `composer_replication.teacher_replay` |
| `TeacherSpec` | `composer_replication.teacher_replay` |
| `TraceState` | `composer_replication.teacher_replay` |
| `extract_dpo_pairs` | `composer_replication.teacher_replay` |
| `replay_trace` | `composer_replication.teacher_replay` |
| `ComposerReplicationTrainer` | `composer_replication.trainer` |
| `make_diloco_outer_loop` | `composer_replication.diloco` (or `None` if `torchft` missing) |

See each source module below for full signatures.

---

## 2. `composer_replication.loss`

Verification-harness 3-channel loss. Free function, does not depend on `trl`.

### `class LossComponents`

```python
@dataclass
class LossComponents:
    lm_ce: torch.Tensor
    sdpo_jsd: torch.Tensor
    trace_replay_dpo: torch.Tensor
    total: torch.Tensor

    def detached(self) -> dict[str, float]: ...
```

Per-channel breakdown of the total loss for logging and ablation. All four fields are scalar `torch.Tensor`s (`shape=()`); `total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo`.

**`detached() -> dict[str, float]`**: returns Python-float copies of all four fields with no grad. Useful for W&B logging.

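To make the weighting concrete, the sketch below recomputes `total` by hand using the default weights documented for `compose_loss` (`alpha_sdpo=0.1`, `beta_replay=0.05`). The per-channel values are made up for illustration, and plain floats stand in for the scalar tensors.

```python
# Illustrative per-channel values (assumed; not produced by a real model).
lm_ce = 2.34
sdpo_jsd = 0.12
trace_replay_dpo = 0.69

# Default channel weights from the compose_loss signature.
alpha_sdpo = 0.1
beta_replay = 0.05

# Same weighted sum that LossComponents.total is documented to hold.
total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo
print(round(total, 4))
```

Setting either weight to `0.0` removes that channel's contribution, which matches the "`0.0` disables" behaviour described for `compose_loss`.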
```python
from composer_replication import compose_loss, build_batch

components = compose_loss(model, build_batch(tokenizer))
print(components.detached())  # {'lm_ce': 2.34, 'sdpo_jsd': 0.12, ...}
components.total.backward()
```

### `compose_loss(model, inputs, *, ...) -> LossComponents`

```python
def compose_loss(
    model: torch.nn.Module,
    inputs: dict[str, torch.Tensor],
    *,
    alpha_sdpo: float = 0.1,
    beta_replay: float = 0.05,
    sdpo_jsd_beta: float = 0.5,
    sdpo_temperature: float = 1.0,
    sdpo_token_clip: float | None = None,
    replay_dpo_beta: float = 0.1,
    lm_ce_label_smoothing: float = 0.0,
    dpo_variant: Literal["dpo", "simpo"] = "dpo",
    sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
    taid_t: float | None = None,
    simpo_beta: float = 2.0,
    simpo_gamma: float = 1.0,
    entropy_opd_h_max: float | None = None,
) -> LossComponents
```

Compute `total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo`.

**Required keys in `inputs`**

- `input_ids`: `(B, T_s)` student rollout token ids.
- `response_mask`: `(B, T_s)` 1 on assistant-response tokens, 0 elsewhere.

**Optional keys** (a channel auto-disables if its keys are missing OR if its weight is 0):

- SDPO: `ctx_teacher_input_ids` `(B, T_t)`, `sdpo_loss_mask` `(B, T_t)`.
- DPO (`dpo_variant="dpo"`): `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`, `dpo_chosen_ref_logprobs`, `dpo_rejected_ref_logprobs` (precomputed).
- SimPO (`dpo_variant="simpo"`): same DPO ids/masks; reference logprobs are silently ignored.
- TAID (`sdpo_wrapper="taid"`): no extra `inputs` keys needed; the optional `sdpo_loss_mask` is reused as the per-token TAID mask. Pass `taid_t` directly (or drive it from `TAIDScheduler`).

**Parameters**

| Name | Type | Default | Meaning |
|---|---|---|---|
| `model` | `torch.nn.Module` | - | HF causal-LM. Must accept `input_ids=` and return an object with `.logits`. |
| `inputs` | `dict[str, torch.Tensor]` | - | Batch dict (see required/optional keys above). |
| `alpha_sdpo` | `float` | `0.1` | Weight on SDPO/JSD channel. `0.0` disables. |
| `beta_replay` | `float` | `0.05` | Weight on trace-replay DPO channel. `0.0` disables. |
| `sdpo_jsd_beta` | `float` | `0.5` | β param for `generalized_jsd_loss` (0=fwd KL, 0.5=JSD, 1=rev KL). Unused when `sdpo_wrapper="taid"`. |
| `sdpo_temperature` | `float` | `1.0` | Softmax temperature in SDPO. Unused when `sdpo_wrapper="taid"`. |
| `sdpo_token_clip` | `float \| None` | `None` | Per-token JSD clamp. |
| `replay_dpo_beta` | `float` | `0.1` | β in standard DPO logit. |
| `lm_ce_label_smoothing` | `float` | `0.0` | `F.cross_entropy(label_smoothing=)`. |
| `dpo_variant` | `Literal["dpo","simpo"]` | `"dpo"` | Channel-3 algorithm. |
| `sdpo_wrapper` | `Literal["none","taid","entropy_opd"]` | `"none"` | Channel-2 wrapper. |
| `taid_t` | `float \| None` | `None` | Current TAID interpolation coefficient in `[0, 1]`. Required when `sdpo_wrapper="taid"`. Drive from `TAIDScheduler` or pass a fixed value. |
| `simpo_beta` | `float` | `2.0` | SimPO β (paper default). |
| `simpo_gamma` | `float` | `1.0` | SimPO target margin γ (paper default). |
| `entropy_opd_h_max` | `float \| None` | `None` | Max-entropy normalizer; `None` ⇒ `log(V)`. |

**Returns** `LossComponents` (see above).

**Raises** `ValueError` if `dpo_variant` or `sdpo_wrapper` is unknown, if `sdpo_wrapper="taid"` is requested without `taid_t`, or if `taid_t` is outside `[0, 1]`.

```python
from composer_replication import compose_loss, build_batch

batch = build_batch(tokenizer)
out = compose_loss(model, batch, alpha_sdpo=0.1, beta_replay=0.05)
out.total.backward()
print(out.detached())
```

---

## 3. `composer_replication.batch`

Verification-harness batch builder.

### `build_batch(tokenizer, *, ...) -> dict[str, torch.Tensor]`

```python
def build_batch(
    tokenizer: Any,
    *,
    device: torch.device | str = "cpu",
    seed: int = 42,
    variant: str = "factorial",
    align_sdpo_shapes: bool = False,
) -> dict[str, torch.Tensor]
```

Construct a full 3-channel batch from a real HF tokenizer. The DPO ref-logprobs are dummy tensors (the smoke test verifies that loss composition wires together, not the reference-policy precompute).

**Returned keys**: `input_ids`, `response_mask`, `ctx_teacher_input_ids`, `sdpo_loss_mask`, `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`, `dpo_chosen_ref_logprobs`, `dpo_rejected_ref_logprobs`.

**Parameters**

| Name | Type | Default | Meaning |
|---|---|---|---|
| `tokenizer` | HF `AutoTokenizer` (duck-typed) | - | Must support `apply_chat_template` and `__call__`. |
| `device` | `torch.device \| str` | `"cpu"` | Target device for all returned tensors. |
| `seed` | `int` | `42` | Fixes `torch.manual_seed`. |
| `variant` | `str` | `"factorial"` | One of `"factorial"`, `"binary_search"`. |
| `align_sdpo_shapes` | `bool` | `False` | If True, truncate/pad `ctx_teacher_input_ids` to `input_ids` length so the SDPO channel actually fires. |

**Raises** `ValueError` if `variant` is unknown.

```python
from transformers import AutoTokenizer
from composer_replication import build_batch

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
batch = build_batch(tok, variant="factorial", align_sdpo_shapes=True)
print({k: v.shape for k, v in batch.items()})
```

---

## 4. `composer_replication.opsd`

Self-distillation generalized-JSD loss, lifted verbatim from `siyan-zhao/OPSD` (MIT) per ADR-006.

### `generalized_jsd_loss(student_logits, teacher_logits, labels=None, beta=0.5, ...) -> torch.Tensor`

```python
def generalized_jsd_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor | None = None,
    beta: float = 0.5,
    temperature: float = 1.0,
    reduction: str = "batchmean",
    logits_are_probs: bool = False,
    top_k: int | None = None,
    token_clip: float | None = None,
) -> torch.Tensor
```

Generalized JSD between student and teacher distributions. In the SDPO recipe, student and teacher logits come from the SAME model evaluated on different contexts.

**Parameters**

| Name | Type | Default | Meaning |
|---|---|---|---|
| `student_logits` | `Tensor (B, T, V)` | - | Student logits with grad. |
| `teacher_logits` | `Tensor (B, T, V)` | - | Teacher logits (no grad in SDPO). |
| `labels` | `Tensor (B, T) \| None` | `None` | Per-token mask. `-100` positions are ignored (HF convention). |
| `beta` | `float` in [0, 1] | `0.5` | 0=fwd KL, 1=rev KL, 0.5=symmetric JSD. |
| `temperature` | `float` | `1.0` | Softmax temperature. |
| `reduction` | `str` | `"batchmean"` | `"batchmean"`, `"sum"`, `"mean"`, `"none"`. |
| `logits_are_probs` | `bool` | `False` | Skip softmax if inputs are already probabilities. |
| `top_k` | `int \| None` | `None` | Restrict KL to teacher's top-k tokens. |
| `token_clip` | `float \| None` | `None` | Clip per-token JSD for stability. |

**Returns** scalar tensor (or `(B, T)` if `reduction="none"`).

**Raises** `ValueError` for unknown `reduction`.

```python
import torch
from composer_replication.opsd import generalized_jsd_loss

s = torch.randn(2, 8, 32, requires_grad=True)
t = torch.randn(2, 8, 32)
loss = generalized_jsd_loss(s, t, beta=0.5, reduction="batchmean")
loss.backward()
```

---

## 5. `composer_replication.distillation`

Pluggable self-distillation losses (ADR-007). All pure PyTorch.

### `simpo_loss(chosen_avg_logprobs, rejected_avg_logprobs, *, beta=2.0, gamma=1.0) -> torch.Tensor`

```python
def simpo_loss(
    chosen_avg_logprobs: torch.Tensor,
    rejected_avg_logprobs: torch.Tensor,
    *,
    beta: float = 2.0,
    gamma: float = 1.0,
) -> torch.Tensor
```

Reference-free DPO with target margin γ (Meng et al., NeurIPS 2024). `L = -log σ(β · (avg_logπ(c) − avg_logπ(r)) − γ)`.

**Parameters**

| Name | Type | Default | Meaning |
|---|---|---|---|
| `chosen_avg_logprobs` | `Tensor (B,)` | - | Per-sequence avg logprob over chosen response tokens. |
| `rejected_avg_logprobs` | `Tensor (B,)` | - | Same for rejected. |
| `beta` | `float` | `2.0` | Scaling factor (paper default). |
| `gamma` | `float` | `1.0` | Target margin (paper default). |

**Returns** scalar; **Raises** `ValueError` if shapes mismatch.

```python
import torch
from composer_replication.distillation import simpo_loss

loss = simpo_loss(
    torch.tensor([-2.1, -1.8]),
    torch.tensor([-3.0, -2.5]),
    beta=2.0,
    gamma=1.0,
)
```

### `avg_sequence_logprob(model_logprobs, response_mask) -> torch.Tensor`

⚠️ UNTESTED-CONTRACT (helper exported from `simpo.py` but not asserted by a test).

```python
def avg_sequence_logprob(
    model_logprobs: torch.Tensor,
    response_mask: torch.Tensor,
) -> torch.Tensor
```

Convert `(B, T)` per-token logprobs + `(B, T)` response mask into `(B,)` per-sequence average over response tokens.

```python
import torch
from composer_replication.distillation.simpo import avg_sequence_logprob

lp = torch.randn(2, 8)
m = torch.tensor([[0, 0, 1, 1, 1, 0, 0, 0], [0, 1, 1, 1, 1, 1, 0, 0]])
out = avg_sequence_logprob(lp, m)  # shape (2,)
```

### `taid_loss(student_logits, teacher_logits, mask=None, *, t) -> torch.Tensor`

```python
def taid_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    mask: torch.Tensor | None = None,
    *,
    t: float | torch.Tensor,
) -> torch.Tensor
```

Faithful port of `SakanaAI/TAID` (arXiv:2501.16937). Forward-KL distillation against a logit-space-interpolated target whose anchor is the **current student detached**:

```
p_t = softmax( (1 - t) · stop_grad(student_logits) + t · teacher_logits )
L   = - mean_token Σ_v p_t(v) · log_softmax(student_logits)(v)
```

At `t=0` the target collapses to the detached student (no teacher signal in the gradient). At `t=1` it reduces to standard forward-KL distillation against the teacher.

**Wave 15 breaking change.** The previous signature `taid_loss(student, teacher, student_init, *, schedule_step, total_steps, schedule, alpha_min, alpha_max, jsd_beta, temperature, reduction)` was algorithmically wrong (probability-space mix, frozen step-0 anchor, JSD criterion). All those kwargs are removed; the schedule is now the caller's responsibility (see `TAIDScheduler` below for the upstream adaptive scheme).

**Parameters**

| Name | Type | Default | Meaning |
|---|---|---|---|
| `student_logits` | `Tensor (B, T, V)` | - | Current student (with grad). |
| `teacher_logits` | `Tensor (B, T, V)` | - | Teacher logits. |
| `mask` | `Tensor (B, T) \| None` | `None` | Token mask. `None` ⇒ all-ones. |
| `t` | `float \| Tensor` | - | Interpolation coefficient in `[0, 1]`. |

**Raises** `ValueError` for shape mismatch.

```python
from composer_replication.distillation import taid_loss

loss = taid_loss(s_logits, t_logits, mask, t=0.4)
```

### `TAIDScheduler(num_train_steps, *, t_start=0.4, t_end=1.0, alpha=5e-4, beta=0.99, disable_adaptive=False)`

Stateful schedule that mirrors upstream `TAID.update_t`. Monotone non-decreasing, bumped above the linear floor by an EMA on the relative loss change.

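The "linear floor" can be pictured as a straight-line ramp from `t_start` to `t_end` over `num_train_steps`. The sketch below assumes that form (the exact upstream expression is not reproduced in this reference) and leaves out the adaptive EMA bump entirely; `linear_floor` is a hypothetical helper, not a function in the package.

```python
def linear_floor(step: int, num_train_steps: int,
                 t_start: float = 0.4, t_end: float = 1.0) -> float:
    """Assumed linear ramp from t_start to t_end, clamped so it never exceeds t_end."""
    progress = min(max(step / num_train_steps, 0.0), 1.0)
    return min(t_end, t_start + (t_end - t_start) * progress)


# Floor at the start, a quarter, half, and the end of a 10k-step run.
floor_values = [linear_floor(s, 10_000) for s in (0, 2_500, 5_000, 10_000)]
print(floor_values)
```

Under this assumption, a scheduler built with `disable_adaptive=True` would report values like these through `sched.t`, while the adaptive mode could only sit at or above them.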
Use as:

```python
from composer_replication.distillation import TAIDScheduler, taid_loss

num_train_steps = 10_000
sched = TAIDScheduler(num_train_steps=num_train_steps)  # paper defaults

for step in range(num_train_steps):
    # s_logits, t_logits and mask come from the current batch's forward passes.
    loss = taid_loss(s_logits, t_logits, mask, t=sched.t)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    sched.update_t(loss.detach(), global_step=step)
```

**Parameters**

| Name | Type | Default | Meaning |
|---|---|---|---|
| `num_train_steps` | `int` | - | Total planned training steps; sets the linear floor. |
| `t_start` | `float` | `0.4` | Initial `t` (paper default). |
| `t_end` | `float` | `1.0` | Terminal `t`; hard ceiling at every step. |
| `alpha` | `float` | `5e-4` | Adaptive bump magnitude. |
| `beta` | `float` | `0.99` | EMA decay on relative-loss-change momentum. |
| `disable_adaptive` | `bool` | `False` | If True, fall back to deterministic linear schedule. |
| `device` | `torch.device \| str` | `"cpu"` | Where to allocate state buffers. |

**Properties / methods**

- `sched.t -> float`: current `t` as a Python float (zero-arg property).
- `sched.update_t(loss, global_step) -> Tensor | None`: update internal state. The first finite-loss call only seeds `prev_loss` and returns `None`; subsequent calls return the (positive) `delta_t` added on top of the linear floor.

### `entropy_aware_opd_loss(student_logits, teacher_logits, *, labels=None, h_max=None, temperature=1.0, reduction="batchmean") -> torch.Tensor`

```python
def entropy_aware_opd_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    *,
    labels: torch.Tensor | None = None,
    h_max: float | None = None,
    temperature: float = 1.0,
    reduction: str = "batchmean",
) -> torch.Tensor
```

Per-token mixture of forward and reverse KL gated by teacher entropy: `w(t) = clamp(H_teacher(t)/h_max, 0, 1)`. High-entropy tokens use forward KL (mode-covering), low-entropy tokens use reverse KL (mode-seeking).

**Parameters**

| Name | Type | Default | Meaning |
|---|---|---|---|
| `student_logits` | `Tensor (B,T,V)` | - | Student logits (grad). |
| `teacher_logits` | `Tensor (B,T,V)` | - | Teacher logits (no grad). |
| `labels` | `Tensor (B,T) \| None` | `None` | 0/1 mask, applied multiplicatively after the per-token mix. |
| `h_max` | `float \| None` | `None` ⇒ `log(V)` | Max-entropy normalizer. |
| `temperature` | `float` | `1.0` | Softmax temperature on both. |
| `reduction` | `str` | `"batchmean"` | `"batchmean"`, `"sum"`, `"mean"`, `"none"`. |

**Raises** `ValueError` on shape mismatch (student vs teacher; labels vs per-token loss) or unknown `reduction`.

```python
from composer_replication.distillation import entropy_aware_opd_loss

loss = entropy_aware_opd_loss(s_logits, t_logits, temperature=1.0)
loss.backward()
```

### `teacher_entropy(teacher_logits) -> torch.Tensor`

⚠️ UNTESTED-CONTRACT (helper exposed from `entropy_aware_opd.py`'s `__all__` but not directly asserted).

Per-token entropy in nats. Input `(B,T,V)`, output `(B,T)`.

```python
from composer_replication.distillation.entropy_aware_opd import teacher_entropy

H = teacher_entropy(teacher_logits)  # (B, T)
```

---

## 6. `composer_replication.teacher_replay`

N-teacher OpenRouter parallel client + DPO-pair extractor. `httpx` is lazy-imported inside `replay_trace`; the deterministic local logic is testable without it.

### `DEFAULT_TEACHERS: list[TeacherSpec]`

Three-teacher default set: `anthropic/claude-opus-4.7`, `openai/gpt-5`, `deepseek/deepseek-v4-pro` with paper-baseline OpenRouter pricing.

```python
from composer_replication.teacher_replay import DEFAULT_TEACHERS

print([t["slug"] for t in DEFAULT_TEACHERS])
```

### `class TeacherSpec(TypedDict)`

```python
class TeacherSpec(TypedDict):
    slug: str
    input_per_mtok: float
    output_per_mtok: float
```

OpenRouter model slug + per-million-token pricing.

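Given per-million-token prices, the dollar cost of one call follows from its token counts. The sketch below shows that arithmetic; it is an assumption about how a `cost_usd` figure could be derived, not a copy of the package's accounting, and `estimate_cost_usd` is a hypothetical helper. The `TeacherSpec` class is redeclared locally so the snippet runs on its own.

```python
from typing import TypedDict


class TeacherSpec(TypedDict):
    slug: str
    input_per_mtok: float
    output_per_mtok: float


def estimate_cost_usd(spec: TeacherSpec, prompt_tokens: int, completion_tokens: int) -> float:
    """Hypothetical helper: price a single call from its token counts."""
    input_cost = prompt_tokens / 1_000_000 * spec["input_per_mtok"]
    output_cost = completion_tokens / 1_000_000 * spec["output_per_mtok"]
    return input_cost + output_cost


gpt5: TeacherSpec = {"slug": "openai/gpt-5", "input_per_mtok": 1.25, "output_per_mtok": 10.0}
cost = estimate_cost_usd(gpt5, prompt_tokens=2_000, completion_tokens=500)
print(f"{cost:.4f}")
```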
```python
spec: TeacherSpec = {"slug": "openai/gpt-5", "input_per_mtok": 1.25, "output_per_mtok": 10.0}
```

### `class TraceState(TypedDict)`

```python
class TraceState(TypedDict):
    state_id: str           # unique within the trace
    messages: list[dict]    # OpenAI-style chat history up to (and incl.) this user prompt
    student_action: str     # what the student actually did at this step
```

One step of a frozen agentic trace. `student_action` is the raw text emitted by the student; teachers are queried with `messages` and asked to predict the assistant's next action.

```python
state: TraceState = {
    "state_id": "ex001::0042",
    "messages": [{"role": "user", "content": "..."}],
    "student_action": "[TOOL_USE] name=Read input={...}",
}
```

### `class TeacherCallResult(TypedDict)`

```python
class TeacherCallResult(TypedDict):
    state_id: str
    teacher_slug: str
    response_text: str | None   # None on error
    latency_s: float
    prompt_tokens: int
    completion_tokens: int
    cost_usd: float
    error: str | None           # None on success
```

One row of N×T results from `replay_trace`.

```python
r: TeacherCallResult = {
    "state_id": "x",
    "teacher_slug": "openai/gpt-5",
    "response_text": "ok",
    "latency_s": 1.2,
    "prompt_tokens": 100,
    "completion_tokens": 5,
    "cost_usd": 0.001,
    "error": None,
}
```

### `class DPOPair(TypedDict)`

```python
class DPOPair(TypedDict):
    state_id: str
    state_messages: list[dict]
    chosen: str               # teacher-consensus action
    rejected: str             # student action
    n_teachers_agreeing: int
```

One preference pair extracted from teacher-vs-student disagreement.

```python
p: DPOPair = {
    "state_id": "x",
    "state_messages": [...],
    "chosen": "...",
    "rejected": "...",
    "n_teachers_agreeing": 2,
}
```

### `async replay_trace(states, teachers=DEFAULT_TEACHERS, max_total_usd=5.0, api_key=None) -> list[TeacherCallResult]`

```python
async def replay_trace(
    states: Sequence[TraceState],
    teachers: Sequence[TeacherSpec] = tuple(DEFAULT_TEACHERS),
    max_total_usd: float = 5.0,
    api_key: str | None = None,
) -> list[TeacherCallResult]
```

For each state, fan out one parallel call per teacher via OpenRouter. Hard-caps cumulative spend at `max_total_usd` (stops after the offending state completes).

**Parameters**

| Name | Type | Default | Meaning |
|---|---|---|---|
| `states` | `Sequence[TraceState]` | - | Frozen trace, one entry per assistant turn. |
| `teachers` | `Sequence[TeacherSpec]` | `DEFAULT_TEACHERS` | Models to query in parallel. |
| `max_total_usd` | `float` | `5.0` | Cumulative spend cap. |
| `api_key` | `str \| None` | `None` | OpenRouter key; defaults to `OPENROUTER_API_KEY` env or `~/.hermes/.env`. |

**Returns** flat list of `TeacherCallResult`s (length `len(states) * len(teachers)` modulo budget cutoff).

**Raises** `RuntimeError` if `OPENROUTER_API_KEY` cannot be found; `ImportError` if `httpx` is missing at call time.

```python
import asyncio
from composer_replication import replay_trace

results = asyncio.run(replay_trace(states=my_trace, max_total_usd=1.0))
```

### `extract_dpo_pairs(states, teacher_actions, agreement_threshold=2) -> list[DPOPair]`

```python
def extract_dpo_pairs(
    states: Sequence[TraceState],
    teacher_actions: Sequence[TeacherCallResult],
    agreement_threshold: int = 2,
) -> list[DPOPair]
```

Group `teacher_actions` by `state_id`, normalize whitespace, and emit one `DPOPair` per state where ≥ `agreement_threshold` teachers agreed on an action that differs from the student's. `chosen` is the original (un-normalized) teacher response text.

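The grouping rule can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the package's implementation: `consensus_pairs` and `normalize` are hypothetical names, rows whose `response_text` is `None` are assumed to be skipped, ties are broken by first occurrence, and a state is dropped when the top-voted action matches the student's.

```python
from collections import Counter, defaultdict


def normalize(text: str) -> str:
    """Collapse whitespace runs so trivially different strings compare equal."""
    return " ".join(text.split())


def consensus_pairs(states, teacher_actions, agreement_threshold=2):
    """Sketch of the grouping logic; plain dicts stand in for the TypedDicts."""
    by_state = defaultdict(list)
    for row in teacher_actions:
        if row["response_text"] is not None:  # assumption: errored calls are ignored
            by_state[row["state_id"]].append(row["response_text"])

    pairs = []
    for state in states:
        responses = by_state.get(state["state_id"], [])
        if not responses:
            continue
        winner, votes = Counter(normalize(r) for r in responses).most_common(1)[0]
        if votes < agreement_threshold:
            continue  # not enough teachers agree
        if winner == normalize(state["student_action"]):
            continue  # teachers agree with the student: no preference signal
        # Keep the original (un-normalized) text of the first matching response.
        chosen = next(r for r in responses if normalize(r) == winner)
        pairs.append({
            "state_id": state["state_id"],
            "state_messages": state["messages"],
            "chosen": chosen,
            "rejected": state["student_action"],
            "n_teachers_agreeing": votes,
        })
    return pairs


demo_states = [{"state_id": "s0",
                "messages": [{"role": "user", "content": "fix the bug"}],
                "student_action": "run tests"}]
demo_actions = [
    {"state_id": "s0", "response_text": "read  the file"},
    {"state_id": "s0", "response_text": "read the file"},
    {"state_id": "s0", "response_text": "run tests"},
]
demo_pairs = consensus_pairs(demo_states, demo_actions)
print(demo_pairs[0]["n_teachers_agreeing"])
```

Two teachers agree once whitespace is collapsed, so one pair is emitted with the first teacher's original text as `chosen` and the student's action as `rejected`.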
**Parameters**

| Name | Type | Default | Meaning |
|---|---|---|---|
| `states` | `Sequence[TraceState]` | - | Same as passed to `replay_trace`. |
| `teacher_actions` | `Sequence[TeacherCallResult]` | - | Output of `replay_trace`. |
| `agreement_threshold` | `int` | `2` | Min teachers that must agree for a pair to fire. |

**Returns** list of `DPOPair`. At most one pair per state (the most-agreed-upon action wins).

```python
from composer_replication import extract_dpo_pairs

pairs = extract_dpo_pairs(my_states, results, agreement_threshold=2)
```

### `save_pairs(pairs, path) -> None`

⚠️ UNTESTED-CONTRACT.

```python
def save_pairs(pairs: Sequence[DPOPair], path: str | Path) -> None
```

Write pairs to JSONL (one dict per line). Creates parent dirs.

```python
from composer_replication.teacher_replay import save_pairs

save_pairs(pairs, "/tmp/dpo_pairs.jsonl")
```

---

## 7. `composer_replication.replaysim`

ADR-004 normalization layer over `teacher_replay`. Re-exports `DPOPair`, `TeacherCallResult`, `extract_dpo_pairs`, `replay_trace` from `teacher_replay`.

### `class NormalizedDPOPair`

```python
@dataclass
class NormalizedDPOPair:
    state_id: str
    state_messages: list[dict[str, Any]]
    chosen_messages: list[dict[str, Any]]
    rejected_messages: list[dict[str, Any]]
    n_teachers_agreeing: int
    metadata: dict[str, Any]
```

Post-normalization shape. `chosen_messages`/`rejected_messages` are chat-format (`[{"role": "assistant", "content": ...}]`). `metadata` carries op-graph provenance, including `{"skipped": True}` when the normalizer was bypassed (`skip_dj=True`).

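The relationship between a raw `DPOPair` and this shape can be sketched for the bypass case. The mapping below is an assumption inferred from the field descriptions (text wrapped as a single assistant message, `metadata={"skipped": True}`); `NormalizedPairSketch` and `passthrough` are hypothetical stand-ins, not names from the package.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class NormalizedPairSketch:
    """Stand-in mirroring the documented NormalizedDPOPair fields."""
    state_id: str
    state_messages: list[dict[str, Any]]
    chosen_messages: list[dict[str, Any]]
    rejected_messages: list[dict[str, Any]]
    n_teachers_agreeing: int
    metadata: dict[str, Any] = field(default_factory=dict)


def passthrough(pair: dict[str, Any]) -> NormalizedPairSketch:
    """Assumed mapping for the bypass case: wrap the text, mark the record as skipped."""
    return NormalizedPairSketch(
        state_id=pair["state_id"],
        state_messages=list(pair["state_messages"]),
        chosen_messages=[{"role": "assistant", "content": pair["chosen"]}],
        rejected_messages=[{"role": "assistant", "content": pair["rejected"]}],
        n_teachers_agreeing=pair["n_teachers_agreeing"],
        metadata={"skipped": True},
    )


raw_pair = {"state_id": "x", "state_messages": [], "chosen": "ok",
            "rejected": "no", "n_teachers_agreeing": 2}
normalized = passthrough(raw_pair)
print(normalized.chosen_messages)
```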
```python
from composer_replication.replaysim import NormalizedDPOPair

n = NormalizedDPOPair(
    state_id="x",
    state_messages=[],
    chosen_messages=[{"role": "assistant", "content": "ok"}],
    rejected_messages=[{"role": "assistant", "content": "no"}],
    n_teachers_agreeing=2,
    metadata={},
)
```

### `class DJNormalizer`

```python
class DJNormalizer:
    DEFAULT_RECIPE: ClassVar[Path]  # composer_replication/recipes/replaysim/default.yaml

    def __init__(
        self,
        recipe_path: str | os.PathLike[str] | None = None,
        *,
        skip_dj: bool = False,
    ) -> None: ...

    def normalize(
        self,
        pairs: Iterable[DPOPair | dict[str, Any]],
    ) -> list[NormalizedDPOPair]: ...
```

`data-juicer`-backed normalizer. Pipeline: each `DPOPair` → JSONL record → `data_juicer.core.DefaultExecutor.run()` against the recipe → JSONL → `NormalizedDPOPair`.

**Constructor parameters**

| Name | Type | Default | Meaning |
|---|---|---|---|
| `recipe_path` | `str \| PathLike \| None` | `None` ⇒ default recipe | data-juicer YAML recipe path. |
| `skip_dj` | `bool` (kw-only) | `False` | If True: passthrough; records get `metadata={"skipped": True}` and no ops run. |

**`normalize(pairs) -> list[NormalizedDPOPair]`** runs the op-graph. Output may be shorter than input if filter ops drop records.

**Raises** `RuntimeError` at construction time if `skip_dj=False` and `data_juicer` is not importable. `FileNotFoundError` if `recipe_path` (default or explicit) is missing and `skip_dj=False`.

```python
from composer_replication.replaysim import DJNormalizer

norm = DJNormalizer(skip_dj=True)
out = norm.normalize(my_pairs)
```

### `async replay_and_normalize_trace(*, states, teachers=None, agreement_threshold=2, max_total_usd=5.0, normalizer=None, **replay_kwargs) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]`

```python
async def replay_and_normalize_trace(
    *,
    states: Any,
    teachers: Any = None,
    agreement_threshold: int = 2,
    max_total_usd: float = 5.0,
    normalizer: DJNormalizer | None = None,
    **replay_kwargs: Any,
) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]
```

End-to-end async: replay → extract pairs → normalize.

**Parameters**

| Name | Type | Default | Meaning |
|---|---|---|---|
| `states` | `Sequence[TraceState]` | - | Frozen trace. |
| `teachers` | `Sequence[TeacherSpec] \| None` | `None` ⇒ defaults | Forwarded to `replay_trace`. |
| `agreement_threshold` | `int` | `2` | Forwarded to `extract_dpo_pairs`. |
| `max_total_usd` | `float` | `5.0` | Spend cap. |
| `normalizer` | `DJNormalizer \| None` | `None` ⇒ `DJNormalizer()` | Pass `DJNormalizer(skip_dj=True)` to bypass. |
| `**replay_kwargs` | `Any` | - | Forwarded to `replay_trace` (e.g. `api_key`). |

**Returns** `(raw_teacher_actions, normalized_pairs)`.

```python
import asyncio
from composer_replication.replaysim import replay_and_normalize_trace, DJNormalizer

raw, norm = asyncio.run(replay_and_normalize_trace(
    states=my_states,
    normalizer=DJNormalizer(skip_dj=True),
))
```

### `replay_and_normalize_trace_sync(*args, **kwargs) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]`

⚠️ UNTESTED-CONTRACT (sync wrapper around the async function; tests call the async form via `asyncio.run`).

```python
def replay_and_normalize_trace_sync(*args, **kwargs) -> ...
```

Sync convenience wrapping `asyncio.run(replay_and_normalize_trace(...))`.

```python
from composer_replication.replaysim.normalize import replay_and_normalize_trace_sync

raw, norm = replay_and_normalize_trace_sync(states=my_states)
```

---

## 8. `composer_replication.ingestion` & `composer_replication.ingestion.claude_code`

Trace-source adapters (ADR-002). v0.1 supports Claude Code session JSONL.

### `SYSTEM_PROMPT: str`

Default synthetic system prompt injected at `messages[0]` for ingested traces (most Claude Code sessions don't write one). Truncated head: `"You are a senior software engineer working as a coding agent in a terminal environment..."`.

```python
from composer_replication import SYSTEM_PROMPT

print(SYSTEM_PROMPT[:60])
```

### `class IngestionStats`

```python
@dataclass
class IngestionStats:
    n_records_total: int = 0
    n_records_skipped: int = 0
    n_states_emitted: int = 0
    n_assistant_turns: int = 0
    n_tool_use_blocks: int = 0
    n_text_blocks: int = 0
    skipped_subagent: int = 0
    skipped_summary: int = 0
    skipped_truncated_lines: int = 0
    version_warnings: list[str] | None = None  # initialized to [] in __post_init__
```

Counters populated by `ClaudeCodeIngester.ingest()` and exposed as `ingester.last_stats`.

```python
from composer_replication import IngestionStats

s = IngestionStats(n_records_total=5)
print(s.version_warnings)  # []
```

### `class ClaudeCodeIngester`

```python
class ClaudeCodeIngester:
    def __init__(
        self,
        *,
        system_prompt: str = SYSTEM_PROMPT,
        skip_sidechain: bool = True,
        strip_thinking: bool = True,
        max_history_tokens: int | None = None,
    ) -> None: ...

    def ingest(self, path: Path) -> Iterator[TraceState]: ...
```

Convert a Claude Code session JSONL to a stream of `TraceState`s, one per assistant TURN (not per `tool_use` block).

**Constructor parameters**

| Name | Type | Default | Meaning |
|---|---|---|---|
| `system_prompt` | `str` | `SYSTEM_PROMPT` | Synthetic system message injected at history[0]. |
| `skip_sidechain` | `bool` | `True` | Skip subagent files (`agent-*.jsonl`) and records with `isSidechain=True`. |
| `strip_thinking` | `bool` | `True` | Remove `[THINKING]` blocks from history handed to teachers (kept inside `student_action`). |
| `max_history_tokens` | `int \| None` | `None` | ⚠️ UNTESTED-CONTRACT: accepted but currently not used to truncate. |

**`ingest(path) -> Iterator[TraceState]`**: generator over `TraceState` objects. Each turn's `state_id` is `f"{path.stem}::{idx:04d}"`. Side effect: replaces `self.last_stats` with a fresh `IngestionStats` and updates it as records stream.

```python
from pathlib import Path
from composer_replication import ClaudeCodeIngester

ing = ClaudeCodeIngester()
for state in ing.ingest(Path("session.jsonl")):
    print(state["state_id"])
print(ing.last_stats.n_states_emitted)
```

---

## 9. `composer_replication.hint_generator`

⚠️ UNTESTED-CONTRACT (entire module: used by the data collator config but not pinned by a test).

Template-based hint registry for SDPO error-site injection.

### `class HintContext(TypedDict, total=False)`

```python
class HintContext(TypedDict, total=False):
    error_kind: str
    error_message: str
    available_tools: list[str]
    tool_name: str
    tool_schema: dict
    intent: str
```

Per-error context dict consumed by hint templates.

### `HINT_TEMPLATES: dict[str, Callable[[HintContext], str]]`

Default registry keys: `"tool_not_found"`, `"json_decode"`, `"type_error"`, `"runtime_error"`, `"repeated_failure"`.

### `dispatch(error_kind, ctx=None) -> str | None`

```python
def dispatch(error_kind: str, ctx: HintContext | None = None) -> str | None
```

Look up `error_kind` in `HINT_TEMPLATES`. Returns the template's hint text, or `None` if the kind is unknown.

```python
from composer_replication.hint_generator import dispatch

hint = dispatch("json_decode")  # "Reminder: tool arguments must be valid JSON. ..."
```

### `register(error_kind, fn) -> None`

```python
def register(error_kind: str, fn: Callable[[HintContext], str]) -> None
```

Add or override a custom hint template.

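The registry pattern behind `dispatch` and `register` can be sketched with a local dictionary. This is an illustration only: `templates`, `register_sketch`, and `dispatch_sketch` are hypothetical stand-ins, and passing an empty dict when `ctx` is `None` is an assumption about how the real `dispatch` treats a missing context.

```python
from typing import Callable, Optional

# Local stand-in for HINT_TEMPLATES; the real registry lives in the package.
templates: dict[str, Callable[[dict], str]] = {
    "json_decode": lambda ctx: "Reminder: tool arguments must be valid JSON.",
}


def register_sketch(error_kind: str, fn: Callable[[dict], str]) -> None:
    """Add a template, or override the one already stored under error_kind."""
    templates[error_kind] = fn


def dispatch_sketch(error_kind: str, ctx: Optional[dict] = None) -> Optional[str]:
    """Return the template's hint text, or None when the kind is unknown."""
    fn = templates.get(error_kind)
    if fn is None:
        return None
    return fn(ctx or {})  # assumption: a missing ctx behaves like an empty one


register_sketch(
    "timeout",
    lambda ctx: f"Reminder: {ctx.get('tool_name', 'the tool')} timed out; retry with a smaller input.",
)
timeout_hint = dispatch_sketch("timeout", {"tool_name": "Read"})
unknown_hint = dispatch_sketch("no_such_kind")
print(timeout_hint)
```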
```python
from composer_replication.hint_generator import register

register("my_error", lambda ctx: "Reminder: try X.")
```

### Individual template functions

⚠️ UNTESTED-CONTRACT: exported only via `HINT_TEMPLATES`, useful as building blocks:

- `hint_tool_not_found(ctx) -> str`
- `hint_json_decode(ctx) -> str`
- `hint_type_error(ctx) -> str`
- `hint_runtime_error(ctx) -> str`
- `hint_repeated_failure(ctx) -> str`

Each accepts a `HintContext` and returns hint text. Signatures are uniform: `Callable[[HintContext], str]`.

```python
from composer_replication.hint_generator import hint_tool_not_found

text = hint_tool_not_found({"available_tools": ["Read", "Write"]})
```

---

## 10. `composer_replication.trainer` & sub-modules

Production trainer (TRL `GRPOTrainer` subclass) plus data collator.

### `class ComposerReplicationTrainer`

```python
class ComposerReplicationTrainer(GRPOTrainer):
    def __init__(
        self,
        *args: Any,
        alpha_sdpo: float = 0.1,
        beta_replay: float = 0.05,
        sdpo_jsd_beta: float = 0.5,
        sdpo_temperature: float = 1.0,
        sdpo_token_clip: float | None = None,
        replay_dpo_beta: float = 0.1,
        **kwargs: Any,
    ) -> None: ...

    def _compute_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor: ...
```

`trl.GRPOTrainer` subclass that overrides `_compute_loss(model, inputs)` to compose `total = grpo + α·sdpo + β·trace_replay_dpo`. When `trl` is not installed, the parent class falls back to `object` so the module still imports, but instantiation will fail because the parent's GRPO machinery is missing.

**Constructor (kw-only beyond GRPOTrainer's own `*args, **kwargs`)**

| Name | Type | Default | Meaning |
|---|---|---|---|
| `alpha_sdpo` | `float` | `0.1` | Channel-2 weight. |
| `beta_replay` | `float` | `0.05` | Channel-3 weight. |
| `sdpo_jsd_beta` | `float` | `0.5` | β for `generalized_jsd_loss`. |
| `sdpo_temperature` | `float` | `1.0` | SDPO softmax temperature. |
| `sdpo_token_clip` | `float \| None` | `None` | Per-token JSD clip. |
| `replay_dpo_beta` | `float` | `0.1` | DPO β. |

**`_compute_loss(model, inputs) -> torch.Tensor`**: overrides `GRPOTrainer._compute_loss`. Calls `super()._compute_loss` for channel 1, then `_compute_sdpo_loss` and `_compute_trace_replay_loss`, then composes. Logs per-channel components every `args.logging_steps` (default 50).

**Raises** whatever `super()` raises (TRL-shaped errors).

**Internal methods (publicly accessible, exercised by spike tests)**

- ⚠️ UNTESTED-CONTRACT `_compute_sdpo_loss(model, inputs) -> torch.Tensor`: generalized JSD between the student forward pass and the `ctx_teacher_input_ids` forward pass. Returns `0.0` (with grad) when `alpha_sdpo == 0`, the key is missing, or shapes mismatch. Logs a warning on shape mismatch.
- ⚠️ UNTESTED-CONTRACT `_compute_trace_replay_loss(model, inputs) -> torch.Tensor`: standard DPO over `dpo_chosen_*` and `dpo_rejected_*`, using precomputed `dpo_chosen_ref_logprobs` / `dpo_rejected_ref_logprobs`.
- ⚠️ UNTESTED-CONTRACT `@staticmethod _sequence_logprobs(model, input_ids, response_mask) -> torch.Tensor`: sums logprobs over response tokens (standard DPO accounting).

```python
from composer_replication import ComposerReplicationTrainer

trainer = ComposerReplicationTrainer(
    model=my_model,
    args=my_grpo_args,
    train_dataset=ds,
    data_collator=my_collator,
    alpha_sdpo=0.1,
    beta_replay=0.05,
)
# trainer.train()  # uses overridden _compute_loss
```

### `class TraceTurn(TypedDict, total=False)` - `trainer.data_collator`

```python
class TraceTurn(TypedDict, total=False):
    role: str  # "user" | "assistant" | "tool"
    content: str
    tool_call: dict | None
    tool_error: str | None
    error_meta: dict
```

One turn of an agentic trace as consumed by `ComposerDataCollator`.

### `class TraceExample(TypedDict, total=False)` - `trainer.data_collator`

```python
class TraceExample(TypedDict, total=False):
    trace_id: str
    turns: list[TraceTurn]
    final_reward: float
    dpo_pairs: list[dict] | None
```

One training example: `(turns, optional dpo_pairs)`.
`dpo_pairs` shape matches `DPOPair`.

### `class TokenizerLike` - `trainer.data_collator`

⚠️ UNTESTED-CONTRACT (duck-typed protocol; used as a type hint).

```python
class TokenizerLike:
    pad_token_id: int

    def __call__(self, text: str | list[str], **kwargs: Any) -> dict[str, list]: ...
    def apply_chat_template(self, messages: list[dict], **kwargs: Any) -> str | list[int]: ...
```

Minimal protocol the collator needs. Compatible with HF `AutoTokenizer`.

### `class CollatorConfig` - `trainer.data_collator`

```python
@dataclass
class CollatorConfig:
    max_seq_len: int = 4096
    max_dpo_seq_len: int = 2048
    pad_token_id: int = 0
    ignore_index: int = -100
    enable_sdpo: bool = True
    hint_generator: Callable[[str, dict], str | None] | None = None
    enable_replay_dpo: bool = True
    rlvr_reward_key: str = "final_reward"
```

Tunables for `ComposerDataCollator`.

| Field | Default | Meaning |
|---|---|---|
| `max_seq_len` | `4096` | Truncation cap for student/teacher sequences. |
| `max_dpo_seq_len` | `2048` | Truncation cap for DPO chosen/rejected sequences. |
| `pad_token_id` | `0` | Padding token id. |
| `ignore_index` | `-100` | HF "ignore in loss" sentinel for SDPO mask. |
| `enable_sdpo` | `True` | Toggle channel-2 fields. |
| `hint_generator` | `None` | Callable `(error_kind, error_meta) -> hint_text` of type `Callable[[str, dict], str \| None]`. SDPO is a no-op without this. |
| `enable_replay_dpo` | `True` | Toggle channel-3 fields. |
| `rlvr_reward_key` | `"final_reward"` | Key in `TraceExample` to read scalar reward. |

```python
from composer_replication.trainer.data_collator import CollatorConfig

cfg = CollatorConfig(max_seq_len=2048, hint_generator=my_dispatch)
```

### `class ComposerDataCollator` - `trainer.data_collator`

```python
@dataclass
class ComposerDataCollator:
    tokenizer: TokenizerLike
    config: CollatorConfig = field(default_factory=CollatorConfig)

    def __call__(
        self, batch: Sequence[TraceExample]
    ) -> dict[str, torch.Tensor]: ...
```

Build trainer-ready batches from raw traces + optional DPO pairs.

**Output dict keys** (tested in `spikes/005-integrated-trainer-skeleton/tests/test_data_collator.py`):

- Channel 1 (always): `input_ids`, `attention_mask`, `response_mask`, `rewards`.
- Channel 2 (when `enable_sdpo=True` AND batch has at least one error site AND `hint_generator` is set): `ctx_teacher_input_ids`, `sdpo_loss_mask`.
- Channel 3 (when `enable_replay_dpo=True` AND batch has at least one `dpo_pair`): `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`. (Reference logprobs are NOT computed here; the trainer does that pass.)

```python
from composer_replication.trainer.data_collator import (
    ComposerDataCollator,
    CollatorConfig,
)

# tok: any TokenizerLike object, e.g. an HF AutoTokenizer instance
collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
batch = collator([{"trace_id": "x", "turns": [...], "final_reward": 1.0}])
```

---

## 11. `composer_replication.diloco`

DiLoCo outer-loop wrapper around `torchft.local_sgd.DiLoCo`. Optional dependency: when `torchft` is missing, the package re-export `composer_replication.make_diloco_outer_loop` is `None`.

### Module-level attributes

- `DiLoCo: Any`: `torchft.local_sgd.DiLoCo` if importable, else `None`.
- `Manager: Any`: `torchft.manager.Manager` if importable, else `None`.
- `_DummyWork: Any`: `torchft.work._DummyWork` if importable, else `None`.
- `_TORCHFT_AVAILABLE: bool`: whether the imports succeeded.

```python
from composer_replication.diloco import _TORCHFT_AVAILABLE, DiLoCo
```

### `make_diloco_outer_loop(manager, model_fragments, inner_optimizer, *, ...) -> torchft.local_sgd.DiLoCo`

```python
def make_diloco_outer_loop(
    manager: Any,
    model_fragments: list[torch.nn.Module],
    inner_optimizer: torch.optim.Optimizer,
    *,
    outer_lr: float = 0.7,
    outer_momentum: float = 0.9,
    nesterov: bool = True,
    sync_every: int = 100,
    fragment_sync_delay: int = 0,
    fragment_update_alpha: float = 0.0,
) -> Any
```

Construct a `torchft.DiLoCo` configured with framework-default hyperparams (DiLoCo paper §3.2: `lr=0.7, momentum=0.9, Nesterov`).

**Parameters**

| Name | Type | Default | Meaning |
|---|---|---|---|
| `manager` | `torchft.Manager` (or duck-typed `MockManager`) | (required) | Provides `allreduce`, `should_commit`, `current_step`, `start_quorum`, etc. |
| `model_fragments` | `list[torch.nn.Module]` | (required) | One module for vanilla DiLoCo; N modules for Streaming DiLoCo. |
| `inner_optimizer` | `torch.optim.Optimizer` | (required) | Inner-step optimizer (steps every batch). |
| `outer_lr` | `float` | `0.7` | Outer SGD lr. |
| `outer_momentum` | `float` | `0.9` | Outer SGD momentum. |
| `nesterov` | `bool` | `True` | Nesterov momentum on outer SGD. |
| `sync_every` | `int` | `100` | Inner steps per outer round. |
| `fragment_sync_delay` | `int` | `0` | 0 = vanilla; >0 = Streaming DiLoCo (requires CUDA streams). |
| `fragment_update_alpha` | `float` | `0.0` | 0 = full replacement on sync; >0 = exponential mix. |

**Returns** a `torchft.local_sgd.DiLoCo` instance, usable as a context manager.

**Raises** `RuntimeError` if `torchft` is not installed.

```python
import torch
from composer_replication.diloco import make_diloco_outer_loop

opt = torch.optim.AdamW(model.parameters(), lr=1e-5)
outer = make_diloco_outer_loop(
    manager=mgr,
    model_fragments=[model],
    inner_optimizer=opt,
    sync_every=100,
)
with outer:
    for _ in range(N):
        opt.zero_grad()
        loss.backward()  # loss: computed from your forward pass for this batch
        opt.step()
```

---

## 12. `composer_replication.diloco.serverless`

ADR-005 serverless DiLoCo executors + object-store all-reduce.
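To make the object-store all-reduce idea concrete, here is a minimal sketch of the rendezvous pattern using only the local filesystem and the standard library. It assumes each rank writes its vector into a per-round directory, polls until every peer's file exists, and then averages. The function `file_allreduce` and its JSON file layout are illustrative assumptions, not the framework's `ObjectStoreAllReduce` implementation.

```python
import json
import tempfile
import threading
import time
from pathlib import Path


def file_allreduce(root: Path, rank: int, world_size: int, values: list[float],
                   round_id: int = 0, timeout_s: float = 10.0,
                   poll_interval_s: float = 0.01) -> list[float]:
    """Average `values` across ranks by exchanging files under `root`."""
    round_dir = root / f"round_{round_id:06d}"
    round_dir.mkdir(parents=True, exist_ok=True)
    # Write to a temp name, then rename, so peers never read a partial file.
    tmp = round_dir / f"rank_{rank:04d}.json.tmp"
    tmp.write_text(json.dumps(values))
    tmp.rename(round_dir / f"rank_{rank:04d}.json")

    peers = [round_dir / f"rank_{r:04d}.json" for r in range(world_size)]
    deadline = time.monotonic() + timeout_s
    while not all(p.exists() for p in peers):
        if time.monotonic() > deadline:
            raise TimeoutError(f"rank {rank}: peers missing in {round_dir}")
        time.sleep(poll_interval_s)

    # Element-wise mean over every rank's contribution.
    contributions = [json.loads(p.read_text()) for p in peers]
    return [sum(col) / world_size for col in zip(*contributions)]


# Demo: two "replicas" run as threads and rendezvous through a temp directory.
results: dict[int, list[float]] = {}
with tempfile.TemporaryDirectory() as d:
    def worker(rank: int, vals: list[float]) -> None:
        results[rank] = file_allreduce(Path(d), rank, 2, vals)

    threads = [threading.Thread(target=worker, args=(0, [1.0, 2.0])),
               threading.Thread(target=worker, args=(1, [3.0, 6.0]))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
```

Because replicas only read and write files, the pattern needs no direct network path between them, which is what makes it a plausible fit for serverless backends.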
### `class ReplicaHandle` - `serverless.executor`

```python
@dataclass
class ReplicaHandle:
    rank: int
    backend_name: str
    metadata: dict[str, Any] = field(default_factory=dict)
```

Opaque handle returned by `ServerlessExecutor.launch_replicas`. `metadata` is backend-specific.

```python
from composer_replication.diloco.serverless import ReplicaHandle

h = ReplicaHandle(rank=0, backend_name="local_process", metadata={"pid": 12345})
```

### `class ServerlessExecutor` (Protocol) - `serverless.executor`

```python
@runtime_checkable
class ServerlessExecutor(Protocol):
    backend_name: str
    supports_inter_replica_network: bool

    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]: ...

    def poll(self, handle: ReplicaHandle) -> str: ...

    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str: ...

    def cancel(self, handle: ReplicaHandle) -> None: ...

    def collect(
        self,
        handles: list[ReplicaHandle],
        *,
        timeout: int | None = None,
    ) -> list[dict[str, Any]]: ...
```

Structural protocol for serverless backends.

- `launch_replicas(...)` returns `list[ReplicaHandle]` of length `n_replicas` in rank order. `entrypoint` is one of: an importable module path (its `main()` is used), a `module.function` path, or a `Callable` (local executor only). `entrypoint_args` may include `rank_env` (default `"REPLICA_RANK"`).
- `poll(handle) -> str`: one of `"pending"`, `"running"`, `"succeeded"`, `"failed"`, `"cancelled"`.
- `stream_logs(handle, n_lines=200) -> str`: best-effort recent stdout/stderr.
- `cancel(handle) -> None`: best-effort.
- `collect(handles, timeout=None) -> list[dict]`: blocks; each result dict has `rank`, `status`, `exit_code`, `error` (and `result` from `LocalProcessExecutor`).
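As a rough illustration of the launch / poll / collect lifecycle, the sketch below implements a toy thread-based executor against a reduced protocol. `Handle`, `MiniExecutor`, and `ThreadExecutor` are hypothetical names; the status strings and result-dict keys follow the list above, but none of this is the framework's own code.

```python
import threading
from dataclasses import dataclass, field
from typing import Any, Callable, Mapping, Protocol, runtime_checkable


@dataclass
class Handle:  # hypothetical analogue of ReplicaHandle
    rank: int
    backend_name: str
    metadata: dict[str, Any] = field(default_factory=dict)


@runtime_checkable
class MiniExecutor(Protocol):  # reduced subset of the documented protocol
    def launch_replicas(self, n_replicas: int, entrypoint: Callable[..., Any],
                        entrypoint_args: Mapping[str, Any]) -> list[Handle]: ...
    def poll(self, handle: Handle) -> str: ...
    def collect(self, handles: list[Handle]) -> list[dict[str, Any]]: ...


class ThreadExecutor:
    backend_name = "thread"

    def launch_replicas(self, n_replicas, entrypoint, entrypoint_args):
        handles = []
        for rank in range(n_replicas):
            h = Handle(rank, self.backend_name, {"status": "running"})

            def run(h: Handle = h) -> None:
                try:
                    h.metadata["result"] = entrypoint(rank=h.rank, **entrypoint_args)
                    h.metadata["status"] = "succeeded"
                except Exception as exc:  # record the failure instead of crashing
                    h.metadata["error"] = repr(exc)
                    h.metadata["status"] = "failed"

            h.metadata["thread"] = threading.Thread(target=run)
            h.metadata["thread"].start()
            handles.append(h)
        return handles

    def poll(self, handle):
        return handle.metadata["status"]

    def collect(self, handles):
        out = []
        for h in handles:
            h.metadata["thread"].join()  # block until this replica finishes
            out.append({"rank": h.rank, "status": self.poll(h),
                        "error": h.metadata.get("error"),
                        "result": h.metadata.get("result")})
        return out


ex = ThreadExecutor()
handles = ex.launch_replicas(2, lambda rank, scale: rank * scale, {"scale": 10})
results = ex.collect(handles)
```

`ThreadExecutor` never subclasses `MiniExecutor`; a `runtime_checkable` protocol accepts it structurally because the required methods are present.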
```python
from composer_replication.diloco.serverless import ServerlessExecutor

def supports(x: ServerlessExecutor) -> bool:
    return isinstance(x, ServerlessExecutor)  # runtime_checkable
```

### `class LocalProcessExecutor` - `serverless.executor`

```python
class LocalProcessExecutor:
    backend_name = "local_process"
    supports_inter_replica_network = True

    def __init__(self) -> None: ...
    # implements ServerlessExecutor protocol
```

Reference implementation using Python `multiprocessing` (`spawn` context). Used for tests, CI smokes, and local development with `file://` rendezvous.

`launch_replicas(...)`: emits a soft warning when `gpu` is not `None` (local processes share whatever GPUs are visible). `metadata = {"pid": ..., "start_ts": ...}`.

```python
from composer_replication.diloco.serverless import LocalProcessExecutor

ex = LocalProcessExecutor()
handles = ex.launch_replicas(
    n_replicas=2,
    entrypoint="composer_replication.diloco.serverless.replica_entrypoint",
    entrypoint_args={
        "rendezvous_uri": "/tmp/run/",
        "world_size": 2,
        "trainer_module": "my.trainer",
    },
)
results = ex.collect(handles, timeout=60)
```

### `class ObjectStoreAllReduce` - `serverless.allreduce`

```python
class ObjectStoreAllReduce:
    def __init__(
        self,
        uri: str,
        rank: int,
        world_size: int,
        *,
        round_id: int | None = None,
        timeout_s: float = 1800.0,
        poll_interval_s: float = 1.0,
    ) -> None: ...

    @property
    def round_id(self) -> int: ...

    def allreduce(
        self,
        tensor: torch.Tensor,
        *,
        name: str | None = None,
    ) -> torch.Tensor: ...
```

fsspec-backed pseudo-gradient rendezvous. `uri` accepts `s3://`, `gs://`, `az://`, `hf://`, `file://`, or a plain local path.

**Constructor parameters**

| Name | Type | Default | Meaning |
|---|---|---|---|
| `uri` | `str` | (required) | fsspec URI or local path. Trailing `/` enforced. |
| `rank` | `int` | (required) | This replica's rank. |
| `world_size` | `int` | (required) | Total replicas. |
| `round_id` | `int \| None` (kw-only) | `None` ⇒ start at 0 | Initial round counter. |
| `timeout_s` | `float` (kw-only) | `1800.0` | Per-`allreduce` timeout. |
| `poll_interval_s` | `float` (kw-only) | `1.0` | Sleep between peer-file existence checks. |

**`allreduce(tensor, name=None) -> torch.Tensor`**: serializes `tensor.detach().cpu()` to `round_NNNNNN/rank_RRRR.pt`, blocks until all peers post, then averages. **Modifies `tensor` in place** AND returns it. Increments the internal `_round_counter`.

**Raises** `ValueError` on invalid `rank`, `RuntimeError` if a non-local URI is requested without `fsspec` installed, `TimeoutError` if peers don't show up before `timeout_s`.

```python
from composer_replication.diloco.serverless import ObjectStoreAllReduce
import torch

store = ObjectStoreAllReduce("/tmp/run/", rank=0, world_size=2)
g = torch.zeros(10)
store.allreduce(g)  # blocks for rank 1
```

### `class MockManager` - `serverless.allreduce`

```python
class MockManager:
    def __init__(self, store: ObjectStoreAllReduce) -> None: ...

    # torchft.Manager-shaped surface:
    num_participants: int
    rank: int
    _use_async_quorum: bool  # always False
    _step: int
    _state_dict_fns: dict[str, tuple[Any, Any]]

    def allreduce(self, tensor: torch.Tensor, **_kwargs: Any) -> "_ImmediateWork": ...
    def should_commit(self) -> bool: ...
    def start_quorum(self) -> None: ...
    def wait_quorum(self) -> int: ...
    def current_step(self) -> int: ...
    def allow_state_dict_read(self) -> None: ...
    def disallow_state_dict_read(self) -> None: ...
    def register_state_dict_fn(self, key: str, load_fn: Any, save_fn: Any) -> None: ...
    def is_leader(self) -> bool: ...
```

Drop-in replacement for `torchft.Manager` that routes `allreduce` through `ObjectStoreAllReduce`. All other methods are no-ops or simple counters appropriate for single-shot serverless DiLoCo.

- `allreduce(tensor)` returns an `_ImmediateWork` whose `.wait()` is a no-op (the tensor is already averaged).
- `should_commit()` always returns `True` (no fault-tolerance failover).
- `start_quorum()` bumps `_step`.
- `is_leader()` returns `rank == 0`.
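The "immediate work" pattern in the bullets above can be sketched without torch or torchft. The example assumes a synchronous averaging callable and uses hypothetical names (`DoneWork`, `TinyManager`); it mirrors the documented behaviour and is not the real `MockManager` or `_ImmediateWork`.

```python
from typing import Callable


class DoneWork:
    """Work-shaped object for a result that is already available."""

    def wait(self) -> bool:
        return True  # nothing to wait for: the reduction ran synchronously


class TinyManager:
    """Manager-shaped shim: synchronous allreduce plus trivial bookkeeping."""

    def __init__(self, reduce_fn: Callable[[list[float]], list[float]], rank: int) -> None:
        self._reduce_fn = reduce_fn
        self.rank = rank
        self._step = 0

    def allreduce(self, values: list[float]) -> DoneWork:
        values[:] = self._reduce_fn(values)  # overwrite the caller's list in place
        return DoneWork()

    def should_commit(self) -> bool:
        return True  # no failover logic in this sketch

    def start_quorum(self) -> None:
        self._step += 1

    def current_step(self) -> int:
        return self._step

    def is_leader(self) -> bool:
        return self.rank == 0


# Pretend one peer contributed all zeros, so the two-rank mean halves each value.
mgr = TinyManager(lambda v: [x / 2 for x in v], rank=0)
grad = [2.0, 4.0]
work = mgr.allreduce(grad)
mgr.start_quorum()
```

Returning an already-finished work object lets calling code keep its usual `work.wait()` call even though the reduction has already happened.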
```python
from composer_replication.diloco.serverless import MockManager, ObjectStoreAllReduce

store = ObjectStoreAllReduce("/tmp/run/", rank=0, world_size=2)
mgr = MockManager(store)
# pass mgr into make_diloco_outer_loop(manager=mgr, ...)
```

### `class _ImmediateWork` - `serverless.allreduce`

⚠️ UNTESTED-CONTRACT internal helper exported from `__all__`. `Work`-shaped wrapper with `.wait() -> True` and `.get_future() -> torch.futures.Future`. Consumed by torchft DiLoCo's `perform_sync`.

```python
from composer_replication.diloco.serverless.allreduce import _ImmediateWork
```

### `class ModalExecutor` - `serverless.modal`

🟡 SKELETON: raises `NotImplementedError`; see ADR-005. Class body documents the v0 implementation pattern (Modal `app.function` + `function.spawn(rank=...)`).

```python
from composer_replication.diloco.serverless.modal import ModalExecutor

# ModalExecutor()  # would raise NotImplementedError when instantiated
```

### `class HFJobsExecutor` - `serverless.hf_jobs`

🟡 SKELETON: raises `NotImplementedError`; see ADR-005. Class body documents the v0 pattern using `huggingface_hub.run_job` against `hf://datasets/.../` rendezvous.

```python
from composer_replication.diloco.serverless.hf_jobs import HFJobsExecutor

# instantiation will fail until v0 implementation lands
```

### `replica_entrypoint.main(...)` - `serverless.replica_entrypoint`

```python
def main(
    rendezvous_uri: str,
    world_size: int,
    trainer_module: str,
    trainer_fn: str = "train",
    trainer_kwargs: dict[str, Any] | None = None,
) -> Any
```

Script run by every replica. Reads the `REPLICA_RANK` env var, builds `ObjectStoreAllReduce` + `MockManager`, imports `trainer_module`, and calls `getattr(mod, trainer_fn)(**trainer_kwargs, manager=..., rank=..., world_size=...)`. Returns whatever the train fn returns.

**Raises** `RuntimeError` if the `REPLICA_RANK` env var is missing; `ValueError` if rank ∉ `[0, world_size)`.
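The control flow just described (read the rank from the environment, validate it, import the trainer module, call the train function) can be sketched as follows. `run_replica` and the in-memory `demo_trainer` module are hypothetical; the real `main` also builds `ObjectStoreAllReduce` and `MockManager` and passes a `manager`, which this sketch omits.

```python
import importlib
import os
import sys
import types
from typing import Any, Optional


def run_replica(world_size: int, trainer_module: str, trainer_fn: str = "train",
                trainer_kwargs: Optional[dict[str, Any]] = None,
                rank_env: str = "REPLICA_RANK") -> Any:
    """Resolve this replica's rank, then dispatch to the named train function."""
    raw = os.environ.get(rank_env)
    if raw is None:
        raise RuntimeError(f"{rank_env} env var is not set")
    rank = int(raw)
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} outside [0, {world_size})")
    mod = importlib.import_module(trainer_module)  # dynamic import by dotted name
    fn = getattr(mod, trainer_fn)
    return fn(**(trainer_kwargs or {}), rank=rank, world_size=world_size)


# Register a throwaway in-memory module so the demo needs no files on disk.
demo = types.ModuleType("demo_trainer")
demo.train = lambda lr, rank, world_size: {"lr": lr, "rank": rank, "of": world_size}
sys.modules["demo_trainer"] = demo

os.environ["REPLICA_RANK"] = "1"
result = run_replica(world_size=2, trainer_module="demo_trainer",
                     trainer_kwargs={"lr": 0.1})
```

Passing the rank through an environment variable keeps the entrypoint arguments identical across replicas, so a backend only has to vary one variable per launch.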
The `if __name__ == "__main__"` block accepts CLI flags `--rendezvous`, `--world-size`, `--trainer-module`, `--trainer-fn`, `--trainer-kwargs-json`.

```python
# In-process invocation
import os

os.environ["REPLICA_RANK"] = "0"

from composer_replication.diloco.serverless.replica_entrypoint import main

result = main(
    rendezvous_uri="/tmp/run/",
    world_size=1,
    trainer_module="my.trainer",
    trainer_fn="train",
)
```

---

## 13. `composer_replication.recipes.prime_rl.composer_loss`

PRIME-RL adapter (ADR-006). Maps PRIME-RL's `LossInputs` struct onto channel 1 (DPPO + KL on the importance ratio, mirroring PRIME-RL's upstream `default_loss_fn` at `prime_rl/trainer/rl/loss.py` lines 116-165). Channel 2 raises `NotImplementedError`; channel 3 is out of scope.

### `loss_fn(inputs, *, alpha_sdpo=0.0, beta_dpo=0.0, dppo_mask_high=0.2, dppo_mask_low=0.2, adv_tau=1.0, kl_tau=1e-3) -> torch.Tensor`

```python
def loss_fn(
    inputs: Any,  # PRIME-RL's LossInputs (duck-typed)
    *,
    alpha_sdpo: float = 0.0,
    beta_dpo: float = 0.0,
    dppo_mask_high: float = 0.2,
    dppo_mask_low: float = 0.2,
    adv_tau: float = 1.0,
    kl_tau: float = 1e-3,
) -> Any  # torch.Tensor scalar
```

PRIME-RL passes per-sample **1-D `(seq,)` tensors** (not batched). The function mirrors PRIME-RL's upstream DPPO+KL formula:

- Mask gate is on **probability-space** `probs_diff = exp(trainer_lp) - exp(inference_lp)` (NOT on the log-ratio).
- A token is dropped iff its advantage sign matches the offending bound: positive-advantage tokens are dropped when `probs_diff > dppo_mask_high`, negative-advantage tokens when `probs_diff < -dppo_mask_low`. (PRIME-RL stores both bounds with `Field(..., ge=0)` and applies the sign internally.)
- The PG term is `keep * (adv_tau * advantages) * exp(trainer_lp - inference_lp)` (importance-ratio corrected, not REINFORCE).
- A KL penalty `kl_tau * log_importance_ratio**2` is added on the full `loss_mask` (DPPO masking does not gate it).
- Reduction is a plain `sum()`; PRIME-RL's outer `compute_loss` divides by `loss_scale`.

**Parameters**

| Name | Type | Default | Meaning |
|---|---|---|---|
| `inputs` | PRIME-RL `LossInputs` (duck-typed) | (required) | Must expose `trainer_logprobs`, `inference_logprobs`, `advantages`, `loss_mask` (all 1-D), and optionally `teacher_logprobs`. |
| `alpha_sdpo` | `float` (kw-only) | `0.0` | Channel-2 weight. Must be `0` in v0; >0 → `NotImplementedError`. |
| `beta_dpo` | `float` (kw-only) | `0.0` | Channel-3 weight. Non-zero emits a `UserWarning`. |
| `dppo_mask_high` | `float` (kw-only), `>= 0` | `0.2` | Upper probability-diff threshold. PRIME-RL `DefaultLossConfig` default. |
| `dppo_mask_low` | `float` (kw-only), `>= 0` | `0.2` | Magnitude of lower probability-diff threshold (sign flipped internally). PRIME-RL default. |
| `adv_tau` | `float` (kw-only), `>= 0` | `1.0` | Advantage temperature. PRIME-RL default. |
| `kl_tau` | `float` (kw-only), `>= 0` | `1e-3` | KL term temperature. PRIME-RL default. |

**Returns** scalar `torch.Tensor` (PRIME-RL's trainer calls `.backward()`).

**Raises** `ValueError` if any of `trainer_logprobs`, `inference_logprobs`, `advantages`, `loss_mask` is not 1-D, or any of the four `>=0`-constrained knobs is negative. `NotImplementedError` if `alpha_sdpo > 0` (channel 2 deferred).

```python
from composer_replication.recipes.prime_rl.composer_loss import loss_fn

# In PRIME-RL config:
# loss:
#   custom:
#     import_path: composer_replication.recipes.prime_rl.composer_loss:loss_fn
#     kwargs:
#       dppo_mask_high: 0.2
#       dppo_mask_low: 0.2
#       adv_tau: 1.0
#       kl_tau: 1.0e-3
```

---

## 14. `composer_replication.recipes.monarch.actors`

🟡 SKELETON module per ADR-006. Importable; classes raise `NotImplementedError` on instantiation. Documents the actor signatures so the recipe matrix is complete.

### `class TrainerActor` 🟡

```python
class TrainerActor:
    backend = "monarch"
    role = "trainer"

    def __init__(self) -> None:
        raise NotImplementedError(...)

    async def train_outer_step(self, batch_id: int) -> dict[str, Any]:
        raise NotImplementedError
```

Hosts the framework's 3-channel composer trainer. Real impl deferred to v0.2+.

### `class GeneratorActor` 🟡

```python
class GeneratorActor:
    backend = "monarch"
    role = "generator"

    def __init__(self) -> None:
        raise NotImplementedError(...)

    async def rollout(self, prompts: list[str]) -> list[str]:
        raise NotImplementedError
```

vLLM-backed rollout actor.

### `class RewarderActor` 🟡

```python
class RewarderActor:
    backend = "monarch"
    role = "rewarder"

    def __init__(self) -> None:
        raise NotImplementedError(...)

    async def score(self, completions: list[str]) -> list[float]:
        raise NotImplementedError
```

verifiers-protocol rewarder.

### `class TeacherPoolActor` 🟡

```python
class TeacherPoolActor:
    backend = "monarch"
    role = "teacher_pool"

    def __init__(self) -> None:
        raise NotImplementedError(...)
```

Channel-3 teacher pool wrapping `composer_replication.teacher_replay`.

```python
# All Monarch actors raise on instantiation in v0:
from composer_replication.recipes.monarch.actors import TrainerActor

# TrainerActor()  # NotImplementedError
```

---

## Notes on test coverage

Tested contracts (referenced spike/test paths):

- `compose_loss` + `LossComponents` + `build_batch`: `composer_replication/tests/test_compose_loss_integration.py`, `spikes/006-real-hf-model-smoke/tests/`.
- `generalized_jsd_loss`: `spikes/005-integrated-trainer-skeleton/tests/test_opsd_loss.py`.
- `simpo_loss`, `taid_loss`, `taid_alpha_schedule`, `taid_blended_logits`, `entropy_aware_opd_loss`: `composer_replication/distillation/tests/test_distillation_losses.py`.
- `replay_trace`, `extract_dpo_pairs`, `DPOPair`, `TraceState`, `TeacherCallResult`, `TeacherSpec`, `DEFAULT_TEACHERS`: `spikes/005-integrated-trainer-skeleton/tests/test_teacher_replay.py`.
- `DJNormalizer`, `NormalizedDPOPair`, `replay_and_normalize_trace`: `composer_replication/replaysim/tests/test_replaysim.py`.
- `ClaudeCodeIngester`, `IngestionStats`, `SYSTEM_PROMPT`: `spikes/007-real-trace-ingestion/tests/`.
- `ComposerDataCollator`, `CollatorConfig`, `TraceTurn`, `TraceExample`: `spikes/005-integrated-trainer-skeleton/tests/test_data_collator.py`.
- `ComposerReplicationTrainer._compute_loss` (composition arithmetic): `spikes/005-integrated-trainer-skeleton/tests/test_loss_composition_smoke.py`.
- `make_diloco_outer_loop` + sign convention: `spikes/008-streaming-diloco/tests/test_diloco_smoke.py`.
- `ObjectStoreAllReduce`, `MockManager`, `LocalProcessExecutor`, `ReplicaHandle`, `ServerlessExecutor`, `replica_entrypoint.main`: `composer_replication/diloco/serverless/tests/test_serverless_local.py`, `test_serverless_diloco_integration.py`.
- `recipes.prime_rl.composer_loss.loss_fn`: `composer_replication/recipes/prime_rl/tests/test_composer_loss.py`.

Untested-contract symbols (⚠️) and skeletons (🟡) are flagged inline above.

---

**Document path**: `/mnt/e/CS/HF/composer-replication-framework/docs/API_REFERENCE.md`