# API Reference — composer-replication-framework
Complete reference for every public symbol in `composer_replication`. Source-of-truth is the `.py` files in `composer_replication/`; docstrings have been pulled verbatim where they exist and supplemented where missing.
**Legend**
- ⚠️ **UNTESTED-CONTRACT** — symbol exists and is callable, but its behaviour is not pinned by an automated test in `composer_replication/**/tests/` or `spikes/**/tests/`.
- 🟡 **SKELETON** — class/method body raises `NotImplementedError`; ships as design-of-record per ADR-005 / ADR-006.
**Module groups (in this document)**
1. `composer_replication` (top-level re-exports)
2. `composer_replication.loss`
3. `composer_replication.batch`
4. `composer_replication.opsd`
5. `composer_replication.distillation`
6. `composer_replication.teacher_replay`
7. `composer_replication.replaysim`
8. `composer_replication.ingestion` (+ `.claude_code`)
9. `composer_replication.hint_generator`
10. `composer_replication.trainer` (+ `.composer_trainer`, `.data_collator`)
11. `composer_replication.diloco`
12. `composer_replication.diloco.serverless` (+ `.executor`, `.allreduce`, `.modal`, `.hf_jobs`, `.replica_entrypoint`)
13. `composer_replication.recipes.prime_rl.composer_loss`
14. `composer_replication.recipes.monarch.actors`
---
## 1. `composer_replication` — top-level package
The package re-exports the most common entry points from sub-modules. `__all__` is the canonical list of public top-level names.
### `composer_replication.__version__: str`
Package version string. Currently `"0.1.0"`.
```python
import composer_replication
print(composer_replication.__version__) # "0.1.0"
```
### `composer_replication._DILOCO_AVAILABLE: bool`
`True` iff `torchft` is importable in the running Python environment; it gates `make_diloco_outer_loop`. When `torchft` is missing, this flag is `False` and `make_diloco_outer_loop` is `None`.
```python
from composer_replication import _DILOCO_AVAILABLE
if _DILOCO_AVAILABLE:
    from composer_replication import make_diloco_outer_loop
```
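The gating described above follows a common optional-dependency pattern. The sketch below is a minimal illustration of that pattern, not the package's actual code: `optional_symbol` is a hypothetical helper, and the module names are stand-ins chosen so that both branches run without `torchft` installed.

```python
import importlib
import importlib.util


def optional_symbol(module_name: str, attr: str):
    """Return (available, symbol); symbol is None when the module is missing.

    Hypothetical helper, shown only to illustrate the gating pattern.
    """
    if importlib.util.find_spec(module_name) is None:
        return False, None
    module = importlib.import_module(module_name)
    return True, getattr(module, attr)


# A module that ships with every CPython install.
json_available, json_dumps = optional_symbol("json", "dumps")

# A module name assumed not to be installed.
missing_available, missing_symbol = optional_symbol("no_such_module_xyz", "anything")
```

Callers then branch on the boolean flag rather than wrapping every use site in `try`/`except ImportError`.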
### Re-exports
| Name | Source module |
|---|---|
| `compose_loss` | `composer_replication.loss` |
| `LossComponents` | `composer_replication.loss` |
| `build_batch` | `composer_replication.batch` |
| `generalized_jsd_loss` | `composer_replication.opsd` |
| `ClaudeCodeIngester` | `composer_replication.ingestion.claude_code` |
| `IngestionStats` | `composer_replication.ingestion.claude_code` |
| `SYSTEM_PROMPT` | `composer_replication.ingestion.claude_code` |
| `DEFAULT_TEACHERS` | `composer_replication.teacher_replay` |
| `DPOPair` | `composer_replication.teacher_replay` |
| `TeacherCallResult` | `composer_replication.teacher_replay` |
| `TeacherSpec` | `composer_replication.teacher_replay` |
| `TraceState` | `composer_replication.teacher_replay` |
| `extract_dpo_pairs` | `composer_replication.teacher_replay` |
| `replay_trace` | `composer_replication.teacher_replay` |
| `ComposerReplicationTrainer` | `composer_replication.trainer` |
| `make_diloco_outer_loop` | `composer_replication.diloco` (or `None` if `torchft` missing) |
See each source module below for full signatures.
---
## 2. `composer_replication.loss`
Verification-harness 3-channel loss. Free function, does not depend on `trl`.
### `class LossComponents`
```python
@dataclass
class LossComponents:
    lm_ce: torch.Tensor
    sdpo_jsd: torch.Tensor
    trace_replay_dpo: torch.Tensor
    total: torch.Tensor

    def detached(self) -> dict[str, float]: ...
```
Per-channel breakdown of the total loss for logging and ablation. All four fields are scalar `torch.Tensor`s (`shape=()`); `total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo`.
**`detached() -> dict[str, float]`** — returns Python-float copies of all four fields with no grad. Useful for W&B logging.
```python
from composer_replication import compose_loss, build_batch
components = compose_loss(model, build_batch(tokenizer))
print(components.detached()) # {'lm_ce': 2.34, 'sdpo_jsd': 0.12, ...}
components.total.backward()
```
### `compose_loss(model, inputs, *, ...) -> LossComponents`
<a id="compose_loss"></a>
```python
def compose_loss(
    model: torch.nn.Module,
    inputs: dict[str, torch.Tensor],
    *,
    alpha_sdpo: float = 0.1,
    beta_replay: float = 0.05,
    sdpo_jsd_beta: float = 0.5,
    sdpo_temperature: float = 1.0,
    sdpo_token_clip: float | None = None,
    replay_dpo_beta: float = 0.1,
    lm_ce_label_smoothing: float = 0.0,
    dpo_variant: Literal["dpo", "simpo"] = "dpo",
    sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
    taid_t: float | None = None,
    simpo_beta: float = 2.0,
    simpo_gamma: float = 1.0,
    entropy_opd_h_max: float | None = None,
) -> LossComponents
```
Compute `total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo`.
**Required keys in `inputs`**
- `input_ids`: `(B, T_s)` student rollout token ids.
- `response_mask`: `(B, T_s)` 1 on assistant-response tokens, 0 elsewhere.
**Optional keys** (channel auto-disables if missing OR if its weight = 0):
- SDPO: `ctx_teacher_input_ids` `(B, T_t)`, `sdpo_loss_mask` `(B, T_t)`.
- DPO (`dpo_variant="dpo"`): `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`, `dpo_chosen_ref_logprobs`, `dpo_rejected_ref_logprobs` (precomputed).
- SimPO (`dpo_variant="simpo"`): same DPO ids/masks; reference logprobs are silently ignored.
- TAID (`sdpo_wrapper="taid"`): no extra `inputs` keys needed; the optional `sdpo_loss_mask` is reused as the per-token TAID mask. Pass `taid_t` directly (or drive it from `TAIDScheduler`).
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `model` | `torch.nn.Module` | — | HF causal-LM. Must accept `input_ids=` and return an object with `.logits`. |
| `inputs` | `dict[str, torch.Tensor]` | — | Batch dict (see required/optional keys above). |
| `alpha_sdpo` | `float` | `0.1` | Weight on SDPO/JSD channel. `0.0` disables. |
| `beta_replay` | `float` | `0.05` | Weight on trace-replay DPO channel. `0.0` disables. |
| `sdpo_jsd_beta` | `float` | `0.5` | β param for `generalized_jsd_loss` (0=fwd KL, 0.5=JSD, 1=rev KL). Unused when `sdpo_wrapper="taid"`. |
| `sdpo_temperature` | `float` | `1.0` | Softmax temperature in SDPO. Unused when `sdpo_wrapper="taid"`. |
| `sdpo_token_clip` | `float \| None` | `None` | Per-token JSD clamp. |
| `replay_dpo_beta` | `float` | `0.1` | β in standard DPO logit. |
| `lm_ce_label_smoothing` | `float` | `0.0` | `F.cross_entropy(label_smoothing=)`. |
| `dpo_variant` | `Literal["dpo","simpo"]` | `"dpo"` | Channel-3 algorithm. |
| `sdpo_wrapper` | `Literal["none","taid","entropy_opd"]` | `"none"` | Channel-2 wrapper. |
| `taid_t` | `float \| None` | `None` | Current TAID interpolation coefficient in `[0, 1]`. Required when `sdpo_wrapper="taid"`. Drive from `TAIDScheduler` or pass a fixed value. |
| `simpo_beta` | `float` | `2.0` | SimPO β (paper default). |
| `simpo_gamma` | `float` | `1.0` | SimPO target margin γ (paper default). |
| `entropy_opd_h_max` | `float \| None` | `None` | Max-entropy normalizer; `None` ⇒ `log(V)`. |
**Returns** `LossComponents` (see above).
**Raises** `ValueError` if `dpo_variant` or `sdpo_wrapper` is unknown, if `sdpo_wrapper="taid"` is requested without `taid_t`, or if `taid_t` is outside `[0, 1]`.
```python
from composer_replication import compose_loss, build_batch
batch = build_batch(tokenizer)
out = compose_loss(model, batch, alpha_sdpo=0.1, beta_replay=0.05)
out.total.backward()
print(out.detached())
```
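To make the composition formula concrete, the following sketch works through it with plain Python floats. It assumes the textbook DPO objective for the trace-replay channel (the reference describes it only as "standard DPO"); `dpo_loss` and `compose_total` are hypothetical names, and the numbers are illustrative, not measured results.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def dpo_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Textbook DPO loss for one pair of sequence log-probabilities."""
    logit = beta * ((pi_chosen - ref_chosen) - (pi_rejected - ref_rejected))
    return -log_sigmoid(logit)


def compose_total(lm_ce, sdpo_jsd, trace_replay_dpo, alpha_sdpo=0.1, beta_replay=0.05):
    """total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo."""
    return lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo


# Illustrative numbers only.
dpo = dpo_loss(pi_chosen=-12.0, pi_rejected=-15.0, ref_chosen=-13.0, ref_rejected=-14.0)
total = compose_total(lm_ce=2.0, sdpo_jsd=0.5, trace_replay_dpo=dpo)

# Setting both weights to zero leaves only the LM cross-entropy channel.
lm_only = compose_total(lm_ce=2.0, sdpo_jsd=0.5, trace_replay_dpo=dpo,
                        alpha_sdpo=0.0, beta_replay=0.0)
```

With the default weights the auxiliary channels act as small corrections on top of `lm_ce`, which is why a zero weight is equivalent to disabling the channel.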
---
## 3. `composer_replication.batch`
Verification-harness batch builder.
### `build_batch(tokenizer, *, ...) -> dict[str, torch.Tensor]`
```python
def build_batch(
    tokenizer: Any,
    *,
    device: torch.device | str = "cpu",
    seed: int = 42,
    variant: str = "factorial",
    align_sdpo_shapes: bool = False,
) -> dict[str, torch.Tensor]
```
Construct a full 3-channel batch from a real HF tokenizer. The DPO ref-logprobs are dummy tensors (the smoke verifies loss composition wires together, not the reference-policy precompute).
**Returned keys**: `input_ids`, `response_mask`, `ctx_teacher_input_ids`, `sdpo_loss_mask`, `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`, `dpo_chosen_ref_logprobs`, `dpo_rejected_ref_logprobs`.
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `tokenizer` | HF `AutoTokenizer` (duck-typed) | — | Must support `apply_chat_template` and `__call__`. |
| `device` | `torch.device \| str` | `"cpu"` | Target device for all returned tensors. |
| `seed` | `int` | `42` | Fixes `torch.manual_seed`. |
| `variant` | `str` | `"factorial"` | One of `"factorial"`, `"binary_search"`. |
| `align_sdpo_shapes` | `bool` | `False` | If True, truncate/pad `ctx_teacher_input_ids` to `input_ids` length so the SDPO channel actually fires. |
**Raises** `ValueError` if `variant` is unknown.
```python
from transformers import AutoTokenizer
from composer_replication import build_batch
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
batch = build_batch(tok, variant="factorial", align_sdpo_shapes=True)
print({k: v.shape for k, v in batch.items()})
```
---
## 4. `composer_replication.opsd`
Self-distillation generalized-JSD loss, lifted verbatim from `siyan-zhao/OPSD` (MIT) per ADR-006.
### `generalized_jsd_loss(student_logits, teacher_logits, labels=None, beta=0.5, ...) -> torch.Tensor`
```python
def generalized_jsd_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor | None = None,
    beta: float = 0.5,
    temperature: float = 1.0,
    reduction: str = "batchmean",
    logits_are_probs: bool = False,
    top_k: int | None = None,
    token_clip: float | None = None,
) -> torch.Tensor
```
Generalized JSD between the student and teacher token distributions. In the SDPO recipe both sets of logits come from the same model (same parameters), evaluated on different contexts.
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `student_logits` | `Tensor (B, T, V)` | — | Student logits with grad. |
| `teacher_logits` | `Tensor (B, T, V)` | — | Teacher logits (no grad in SDPO). |
| `labels` | `Tensor (B, T) \| None` | `None` | Per-token mask. `-100` positions are ignored (HF convention). |
| `beta` | `float` in [0, 1] | `0.5` | 0=fwd KL, 1=rev KL, 0.5=symmetric JSD. |
| `temperature` | `float` | `1.0` | Softmax temperature. |
| `reduction` | `str` | `"batchmean"` | `"batchmean"`, `"sum"`, `"mean"`, `"none"`. |
| `logits_are_probs` | `bool` | `False` | Skip softmax if inputs are already probabilities. |
| `top_k` | `int \| None` | `None` | Restrict KL to teacher's top-k tokens. |
| `token_clip` | `float \| None` | `None` | Clip per-token JSD for stability. |
**Returns** scalar tensor (or `(B, T)` if `reduction="none"`).
**Raises** `ValueError` for unknown `reduction`.
```python
import torch
from composer_replication.opsd import generalized_jsd_loss
s = torch.randn(2, 8, 32, requires_grad=True)
t = torch.randn(2, 8, 32)
loss = generalized_jsd_loss(s, t, beta=0.5, reduction="batchmean")
loss.backward()
```
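The role of `beta` can be illustrated on a single token's distribution. The sketch below follows the convention in the parameter table (0 = forward KL, 1 = reverse KL) and, for interior values, assumes the mixture form commonly used in GKD-style losses; whether the lifted implementation matches this term for term is an assumption. `generalized_jsd` and `kl` are hypothetical pure-Python helpers.

```python
import math


def kl(p, q):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def generalized_jsd(student, teacher, beta=0.5):
    """Generalized JSD for one token (assumed formulation, see lead-in)."""
    if beta == 0.0:
        return kl(teacher, student)  # forward KL
    if beta == 1.0:
        return kl(student, teacher)  # reverse KL
    mix = [(1 - beta) * s + beta * t for s, t in zip(student, teacher)]
    return beta * kl(teacher, mix) + (1 - beta) * kl(student, mix)


student = [0.7, 0.2, 0.1]
teacher = [0.2, 0.5, 0.3]
jsd_half = generalized_jsd(student, teacher, beta=0.5)
```

At `beta=0.5` the value is symmetric in its two arguments and bounded by `log(2)`, which is the usual reason for preferring it over either one-sided KL when the two distributions can be far apart.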
---
## 5. `composer_replication.distillation`
Pluggable self-distillation losses (ADR-007). All pure PyTorch.
### `simpo_loss(chosen_avg_logprobs, rejected_avg_logprobs, *, beta=2.0, gamma=1.0) -> torch.Tensor`
```python
def simpo_loss(
    chosen_avg_logprobs: torch.Tensor,
    rejected_avg_logprobs: torch.Tensor,
    *,
    beta: float = 2.0,
    gamma: float = 1.0,
) -> torch.Tensor
```
Reference-free DPO with target margin γ (Meng et al., NeurIPS 2024). `L = -log σ(β · (avg_logπ(c) − avg_logπ(r)) − γ)`.
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `chosen_avg_logprobs` | `Tensor (B,)` | — | Per-sequence avg logprob over chosen response tokens. |
| `rejected_avg_logprobs` | `Tensor (B,)` | — | Same for rejected. |
| `beta` | `float` | `2.0` | Scaling factor (paper default). |
| `gamma` | `float` | `1.0` | Target margin (paper default). |
**Returns** scalar; **Raises** `ValueError` if shapes mismatch.
```python
import torch
from composer_replication.distillation import simpo_loss
loss = simpo_loss(torch.tensor([-2.1, -1.8]), torch.tensor([-3.0, -2.5]),
beta=2.0, gamma=1.0)
```
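The SimPO formula above can be checked by hand with plain floats. This sketch reuses the numbers from the example and assumes the scalar result is a mean over the batch, which the reference does not state explicitly; `simpo_loss_sketch` is a hypothetical name.

```python
import math


def simpo_loss_sketch(chosen_avg, rejected_avg, beta=2.0, gamma=1.0):
    """Batch mean of -log sigmoid(beta * (chosen - rejected) - gamma)."""
    if len(chosen_avg) != len(rejected_avg):
        raise ValueError("chosen and rejected must have the same length")
    losses = []
    for c, r in zip(chosen_avg, rejected_avg):
        margin = beta * (c - r) - gamma
        losses.append(math.log1p(math.exp(-margin)))  # equals -log(sigmoid(margin))
    return sum(losses) / len(losses)


loss = simpo_loss_sketch([-2.1, -1.8], [-3.0, -2.5])
```

Because `gamma` is subtracted inside the sigmoid, a pair only reaches low loss once the scaled log-probability gap exceeds the target margin.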
### `avg_sequence_logprob(model_logprobs, response_mask) -> torch.Tensor`
⚠️ UNTESTED-CONTRACT (helper exported from `simpo.py` but not asserted by a test).
```python
def avg_sequence_logprob(
    model_logprobs: torch.Tensor,
    response_mask: torch.Tensor,
) -> torch.Tensor
```
Convert `(B, T)` per-token logprobs + `(B, T)` response mask into `(B,)` per-sequence average over response tokens.
```python
from composer_replication.distillation.simpo import avg_sequence_logprob
import torch
lp = torch.randn(2, 8); m = torch.tensor([[0,0,1,1,1,0,0,0],[0,1,1,1,1,1,0,0]])
out = avg_sequence_logprob(lp, m) # shape (2,)
```
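The masked average is simple enough to spell out with nested lists. The sketch below is a pure-Python illustration under one assumption that the reference does not confirm: a sequence with an all-zero mask is divided by 1 rather than 0. `avg_sequence_logprob_sketch` is a hypothetical name.

```python
def avg_sequence_logprob_sketch(logprobs, mask):
    """Average per-token logprobs over positions where mask == 1, per sequence."""
    out = []
    for row_lp, row_mask in zip(logprobs, mask):
        n = sum(row_mask)
        total = sum(lp * m for lp, m in zip(row_lp, row_mask))
        out.append(total / max(n, 1))  # guard for an all-zero mask (assumption)
    return out


logprobs = [[-0.5, -1.0, -2.0, -3.0], [-1.0, -1.0, -4.0, -0.5]]
mask = [[0, 1, 1, 0], [0, 0, 1, 1]]
avgs = avg_sequence_logprob_sketch(logprobs, mask)
```

Averaging rather than summing is what makes the SimPO reward length-normalized, so longer responses are not penalized just for having more tokens.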
### `taid_loss(student_logits, teacher_logits, mask=None, *, t) -> torch.Tensor`
```python
def taid_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    mask: torch.Tensor | None = None,
    *,
    t: float | torch.Tensor,
) -> torch.Tensor
```
Faithful port of `SakanaAI/TAID` (arXiv:2501.16937). Forward-KL distillation against a logit-space-interpolated target whose anchor is the **current student detached**:
```
p_t = softmax( (1 - t) · stop_grad(student_logits) + t · teacher_logits )
L = - mean_token Σ_v p_t(v) · log_softmax(student_logits)(v)
```
At `t=0` the target collapses to the detached student (no teacher signal in the gradient). At `t=1` it reduces to standard forward-KL distillation against the teacher.
**Wave 15 breaking change.** The previous signature `taid_loss(student, teacher, student_init, *, schedule_step, total_steps, schedule, alpha_min, alpha_max, jsd_beta, temperature, reduction)` was algorithmically wrong (probability-space mix, frozen step-0 anchor, JSD criterion). All those kwargs are removed; the schedule is now the caller's responsibility (see `TAIDScheduler` below for the upstream adaptive scheme).
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `student_logits` | `Tensor (B, T, V)` | — | Current student (with grad). |
| `teacher_logits` | `Tensor (B, T, V)` | — | Teacher logits. |
| `mask` | `Tensor (B, T) \| None` | `None` | Token mask. `None` ⇒ all-ones. |
| `t` | `float \| Tensor` | — | Interpolation coefficient in `[0, 1]`. |
**Raises** `ValueError` for shape mismatch.
```python
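import torch
from composer_replication.distillation import taid_loss
s_logits, t_logits = torch.randn(2, 8, 32, requires_grad=True), torch.randn(2, 8, 32)
loss = taid_loss(s_logits, t_logits, mask=None, t=0.4)  # mask=None means all tokens count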
```
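The interpolated target and its two limiting cases can be verified on a single token without PyTorch. The sketch below implements the two formulas above with plain lists; `softmax`, `taid_target`, and `taid_token_loss` are hypothetical helpers, and the stop-gradient is implicit because no autograd is involved.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def taid_target(student_logits, teacher_logits, t):
    """p_t = softmax((1 - t) * student + t * teacher) for a single token."""
    mixed = [(1 - t) * s + t * th for s, th in zip(student_logits, teacher_logits)]
    return softmax(mixed)


def taid_token_loss(student_logits, teacher_logits, t):
    """Cross-entropy of the student against p_t; p_t is treated as a constant."""
    p_t = taid_target(student_logits, teacher_logits, t)
    log_q = [math.log(q) for q in softmax(student_logits)]
    return -sum(p * lq for p, lq in zip(p_t, log_q))


student = [2.0, 0.5, -1.0]
teacher = [0.0, 1.5, 1.0]
target_at_0 = taid_target(student, teacher, 0.0)  # equals softmax(student)
target_at_1 = taid_target(student, teacher, 1.0)  # equals softmax(teacher)
```

Interpolating in logit space keeps the target a valid distribution at every `t`, and the target moves smoothly from the student's own prediction toward the teacher's.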
### `TAIDScheduler(num_train_steps, *, t_start=0.4, t_end=1.0, alpha=5e-4, beta=0.99, disable_adaptive=False, device="cpu")`
Stateful schedule that mirrors upstream `TAID.update_t`. Monotone non-decreasing, bumped above the linear floor by an EMA on the relative loss change. Use as:
```python
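from composer_replication.distillation import TAIDScheduler, taid_loss

num_train_steps = 10_000
sched = TAIDScheduler(num_train_steps=num_train_steps)  # paper defaults
for step in range(num_train_steps):
    # s, t, mask: student logits, teacher logits, and token mask for this step
    loss = taid_loss(s, t, mask, t=sched.t)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    sched.update_t(loss.detach(), global_step=step)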
```
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `num_train_steps` | `int` | — | Total planned training steps; sets the linear floor. |
| `t_start` | `float` | `0.4` | Initial `t` (paper default). |
| `t_end` | `float` | `1.0` | Terminal `t`; hard ceiling at every step. |
| `alpha` | `float` | `5e-4` | Adaptive bump magnitude. |
| `beta` | `float` | `0.99` | EMA decay on relative-loss-change momentum. |
| `disable_adaptive` | `bool` | `False` | If True, fall back to deterministic linear schedule. |
| `device` | `torch.device \| str` | `"cpu"` | Where to allocate state buffers. |
**Properties / methods**
- `sched.t -> float` — current `t` as a Python float (zero-arg property).
- `sched.update_t(loss, global_step) -> Tensor | None` — update internal state. First finite-loss call only seeds `prev_loss` and returns `None`; subsequent calls return the (positive) `delta_t` added on top of the linear floor.
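The linear floor on its own (the behaviour described for `disable_adaptive=True`) can be sketched in a few lines. The interpolation formula below is an assumption based on the parameter descriptions, not a copy of the scheduler's code, and `linear_floor` is a hypothetical name; the adaptive EMA bump is deliberately left out.

```python
def linear_floor(step, num_train_steps, t_start=0.4, t_end=1.0):
    """Linear ramp from t_start to t_end over num_train_steps, capped at t_end."""
    progress = min(max(step / num_train_steps, 0.0), 1.0)
    return min(t_end, t_start + (t_end - t_start) * progress)


# Eleven evenly spaced checkpoints of a 10-step run.
schedule = [linear_floor(s, 10) for s in range(0, 11)]
```

The adaptive variant described above can only raise `t` above this ramp, never lower it, so the ramp acts as a guaranteed minimum rate of progress toward the teacher.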
### `entropy_aware_opd_loss(student_logits, teacher_logits, *, labels=None, h_max=None, temperature=1.0, reduction="batchmean") -> torch.Tensor`
```python
def entropy_aware_opd_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    *,
    labels: torch.Tensor | None = None,
    h_max: float | None = None,
    temperature: float = 1.0,
    reduction: str = "batchmean",
) -> torch.Tensor
```
Per-token mixture of forward and reverse KL gated by teacher entropy: `w(t) = clamp(H_teacher(t)/h_max, 0, 1)`. High-entropy tokens use forward KL (mode-covering), low-entropy tokens use reverse KL (mode-seeking).
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `student_logits` | `Tensor (B,T,V)` | — | Student logits (grad). |
| `teacher_logits` | `Tensor (B,T,V)` | — | Teacher logits (no grad). |
| `labels` | `Tensor (B,T) \| None` | `None` | 0/1 mask, applied multiplicatively after the per-token mix. |
| `h_max` | `float \| None` | `None` ⇒ `log(V)` | Max-entropy normalizer. |
| `temperature` | `float` | `1.0` | Softmax temperature on both. |
| `reduction` | `str` | `"batchmean"` | `"batchmean"`, `"sum"`, `"mean"`, `"none"`. |
**Raises** `ValueError` on shape mismatch (student vs teacher; labels vs per-token loss) or unknown `reduction`.
```python
import torch
from composer_replication.distillation import entropy_aware_opd_loss
s_logits = torch.randn(2, 8, 32, requires_grad=True)
t_logits = torch.randn(2, 8, 32)
loss = entropy_aware_opd_loss(s_logits, t_logits, temperature=1.0)
loss.backward()
```
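The entropy gate is easiest to see on one token at a time. The sketch below computes `w = clamp(H_teacher / h_max, 0, 1)` and the resulting mix for a single distribution; it assumes forward KL means `KL(teacher || student)` and reverse KL means `KL(student || teacher)`. All function names are hypothetical.

```python
import math


def entropy(p):
    """Shannon entropy in nats."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def kl(p, q):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def entropy_gate(teacher_probs, h_max=None):
    """w = clamp(H_teacher / h_max, 0, 1); h_max defaults to log(V)."""
    h_max = math.log(len(teacher_probs)) if h_max is None else h_max
    return min(max(entropy(teacher_probs) / h_max, 0.0), 1.0)


def entropy_aware_token_loss(student, teacher, h_max=None):
    """w * forward KL + (1 - w) * reverse KL for one token (assumed directions)."""
    w = entropy_gate(teacher, h_max)
    return w * kl(teacher, student) + (1 - w) * kl(student, teacher)


student = [0.6, 0.3, 0.1]
uniform_teacher = [1 / 3, 1 / 3, 1 / 3]
peaked_teacher = [0.98, 0.01, 0.01]
w_uniform = entropy_gate(uniform_teacher)
w_peaked = entropy_gate(peaked_teacher)
```

A uniform teacher pushes the gate to 1 (pure forward KL), while a confident teacher pushes it toward 0 (mostly reverse KL).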
### `teacher_entropy(teacher_logits) -> torch.Tensor`
⚠️ UNTESTED-CONTRACT (helper exposed from `entropy_aware_opd.py`'s `__all__` but not directly asserted).
Per-token entropy in nats. Input `(B,T,V)`, output `(B,T)`.
```python
import torch
from composer_replication.distillation.entropy_aware_opd import teacher_entropy
teacher_logits = torch.randn(2, 8, 32)  # (B, T, V)
H = teacher_entropy(teacher_logits)  # (B, T)
```
---
## 6. `composer_replication.teacher_replay`
N-teacher OpenRouter parallel client + DPO-pair extractor. `httpx` is lazy-imported inside `replay_trace`; the deterministic local logic is testable without it.
### `DEFAULT_TEACHERS: list[TeacherSpec]`
Three-teacher default set: `anthropic/claude-opus-4.7`, `openai/gpt-5`, `deepseek/deepseek-v4-pro` with paper-baseline OpenRouter pricing.
```python
from composer_replication.teacher_replay import DEFAULT_TEACHERS
print([t["slug"] for t in DEFAULT_TEACHERS])
```
### `class TeacherSpec(TypedDict)`
```python
class TeacherSpec(TypedDict):
    slug: str
    input_per_mtok: float
    output_per_mtok: float
```
OpenRouter model slug + per-million-token pricing.
```python
spec: TeacherSpec = {"slug": "openai/gpt-5",
"input_per_mtok": 1.25, "output_per_mtok": 10.0}
```
### `class TraceState(TypedDict)`
```python
class TraceState(TypedDict):
    state_id: str         # unique within the trace
    messages: list[dict]  # OpenAI-style chat history up to (and incl.) this user prompt
    student_action: str   # what the student actually did at this step
```
One step of a frozen agentic trace. `student_action` is the raw text emitted by the student; teachers are queried with `messages` and asked to predict the assistant's next action.
```python
state: TraceState = {"state_id": "ex001::0042",
"messages": [{"role": "user", "content": "..."}],
"student_action": "[TOOL_USE] name=Read input={...}"}
```
### `class TeacherCallResult(TypedDict)`
```python
class TeacherCallResult(TypedDict):
    state_id: str
    teacher_slug: str
    response_text: str | None  # None on error
    latency_s: float
    prompt_tokens: int
    completion_tokens: int
    cost_usd: float
    error: str | None          # None on success
```
One row of N×T results from `replay_trace`.
```python
r: TeacherCallResult = {"state_id": "x", "teacher_slug": "openai/gpt-5",
"response_text": "ok", "latency_s": 1.2, "prompt_tokens": 100,
"completion_tokens": 5, "cost_usd": 0.001, "error": None}
```
### `class DPOPair(TypedDict)`
```python
class DPOPair(TypedDict):
    state_id: str
    state_messages: list[dict]
    chosen: str    # teacher-consensus action
    rejected: str  # student action
    n_teachers_agreeing: int
```
One preference pair extracted from teacher-vs-student disagreement.
```python
p: DPOPair = {"state_id": "x", "state_messages": [...], "chosen": "...",
"rejected": "...", "n_teachers_agreeing": 2}
```
### `async replay_trace(states, teachers=DEFAULT_TEACHERS, max_total_usd=5.0, api_key=None) -> list[TeacherCallResult]`
```python
async def replay_trace(
    states: Sequence[TraceState],
    teachers: Sequence[TeacherSpec] = tuple(DEFAULT_TEACHERS),
    max_total_usd: float = 5.0,
    api_key: str | None = None,
) -> list[TeacherCallResult]
```
For each state, fan out one call per teacher in parallel via OpenRouter. Cumulative spend is hard-capped at `max_total_usd`; the run stops after the state that crosses the cap completes.
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `states` | `Sequence[TraceState]` | — | Frozen trace, one entry per assistant turn. |
| `teachers` | `Sequence[TeacherSpec]` | `DEFAULT_TEACHERS` | Models to query in parallel. |
| `max_total_usd` | `float` | `5.0` | Cumulative spend cap. |
| `api_key` | `str \| None` | `None` | OpenRouter key; defaults to `OPENROUTER_API_KEY` env or `~/.hermes/.env`. |
**Returns** flat list of `TeacherCallResult`s (length `len(states) * len(teachers)` modulo budget cutoff).
**Raises** `RuntimeError` if `OPENROUTER_API_KEY` is not findable; `ImportError` if `httpx` is missing at call time.
```python
import asyncio
from composer_replication import replay_trace
results = asyncio.run(replay_trace(states=my_trace, max_total_usd=1.0))
```
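The spend cap's "stop after the offending state completes" behaviour is sketched below without any network calls. Two things are assumptions: the per-call cost formula derived from `input_per_mtok` / `output_per_mtok`, and the use of fixed token counts in place of real API usage. `call_cost_usd` and `replay_with_cap` are hypothetical names, and the teacher slugs and prices are made up.

```python
def call_cost_usd(spec, prompt_tokens, completion_tokens):
    """Cost of one call from per-million-token prices (assumed formula)."""
    return (prompt_tokens * spec["input_per_mtok"]
            + completion_tokens * spec["output_per_mtok"]) / 1_000_000


def replay_with_cap(usage_per_state, teachers, max_total_usd):
    """Process states in order; stop once cumulative spend reaches the cap.

    usage_per_state holds (prompt_tokens, completion_tokens) per state,
    a stand-in for real API responses.
    """
    spent, rows = 0.0, []
    for state_idx, (p_tok, c_tok) in enumerate(usage_per_state):
        for spec in teachers:  # the real client issues these calls in parallel
            cost = call_cost_usd(spec, p_tok, c_tok)
            spent += cost
            rows.append((state_idx, spec["slug"], cost))
        if spent >= max_total_usd:
            break  # the state that crossed the cap still completes
    return rows, spent


teachers = [
    {"slug": "teacher-a", "input_per_mtok": 1.0, "output_per_mtok": 10.0},
    {"slug": "teacher-b", "input_per_mtok": 2.0, "output_per_mtok": 20.0},
]
usage = [(100_000, 10_000)] * 5
rows, spent = replay_with_cap(usage, teachers, max_total_usd=1.0)
```

Checking the budget between states rather than between calls means actual spend can overshoot the cap by up to one state's worth of teacher calls, so set `max_total_usd` with that margin in mind.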
### `extract_dpo_pairs(states, teacher_actions, agreement_threshold=2) -> list[DPOPair]`
```python
def extract_dpo_pairs(
    states: Sequence[TraceState],
    teacher_actions: Sequence[TeacherCallResult],
    agreement_threshold: int = 2,
) -> list[DPOPair]
```
Group `teacher_actions` by `state_id`, normalize whitespace, and emit one `DPOPair` per state where at least `agreement_threshold` teachers agreed on an action that differs from the student's. `chosen` is the original (un-normalized) teacher response text.
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `states` | `Sequence[TraceState]` | — | Same as passed to `replay_trace`. |
| `teacher_actions` | `Sequence[TeacherCallResult]` | — | Output of `replay_trace`. |
| `agreement_threshold` | `int` | `2` | Min teachers that must agree for a pair to fire. |
**Returns** list of `DPOPair`. At most one pair per state (the most-agreed-upon action wins).
```python
from composer_replication import extract_dpo_pairs
pairs = extract_dpo_pairs(my_states, results, agreement_threshold=2)
```
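The grouping and agreement logic described above can be approximated with `collections.Counter`. This is a sketch of the described behaviour, not the module's code: the whitespace normalization (collapsing runs of whitespace) and the choice of the first matching original string as `chosen` are assumptions, and `extract_pairs_sketch` is a hypothetical name.

```python
from collections import Counter, defaultdict


def normalize(text):
    """Collapse runs of whitespace so trivially different strings compare equal."""
    return " ".join(text.split())


def extract_pairs_sketch(states, teacher_actions, agreement_threshold=2):
    by_state = defaultdict(list)
    for row in teacher_actions:
        if row["response_text"] is not None:  # skip failed calls
            by_state[row["state_id"]].append(row["response_text"])

    pairs = []
    for state in states:
        responses = by_state.get(state["state_id"], [])
        if not responses:
            continue
        top, n = Counter(normalize(r) for r in responses).most_common(1)[0]
        if n >= agreement_threshold and top != normalize(state["student_action"]):
            chosen = next(r for r in responses if normalize(r) == top)
            pairs.append({"state_id": state["state_id"],
                          "state_messages": state["messages"],
                          "chosen": chosen,  # original, un-normalized text
                          "rejected": state["student_action"],
                          "n_teachers_agreeing": n})
    return pairs


msgs = [{"role": "user", "content": "fix the bug"}]
states = [{"state_id": "s0", "messages": msgs, "student_action": "run tests"},
          {"state_id": "s1", "messages": msgs, "student_action": "read file"}]
actions = [{"state_id": "s0", "response_text": "read  file"},
           {"state_id": "s0", "response_text": "read file"},
           {"state_id": "s0", "response_text": "run tests"},
           {"state_id": "s1", "response_text": "read file"},
           {"state_id": "s1", "response_text": "read file "},
           {"state_id": "s1", "response_text": None}]
pairs = extract_pairs_sketch(states, actions)
```

State `s1` yields no pair because the teacher consensus matches what the student already did, so there is no preference signal to learn from.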
### `save_pairs(pairs, path) -> None`
⚠️ UNTESTED-CONTRACT.
```python
def save_pairs(pairs: Sequence[DPOPair], path: str | Path) -> None
```
Write pairs to JSONL (one dict per line). Creates parent dirs.
```python
from composer_replication.teacher_replay import save_pairs
save_pairs(pairs, "/tmp/dpo_pairs.jsonl")
```
---
## 7. `composer_replication.replaysim`
ADR-004 normalization layer over `teacher_replay`. Re-exports `DPOPair`, `TeacherCallResult`, `extract_dpo_pairs`, `replay_trace` from `teacher_replay`.
### `class NormalizedDPOPair`
```python
@dataclass
class NormalizedDPOPair:
    state_id: str
    state_messages: list[dict[str, Any]]
    chosen_messages: list[dict[str, Any]]
    rejected_messages: list[dict[str, Any]]
    n_teachers_agreeing: int
    metadata: dict[str, Any]
```
Post-normalization shape. `chosen_messages`/`rejected_messages` are chat-format (`[{"role": "assistant", "content": ...}]`). `metadata` carries op-graph provenance, including `{"skipped": True}` when the normalizer was bypassed (`skip_dj=True`).
```python
from composer_replication.replaysim import NormalizedDPOPair
n = NormalizedDPOPair(state_id="x", state_messages=[],
chosen_messages=[{"role": "assistant", "content": "ok"}],
rejected_messages=[{"role": "assistant", "content": "no"}],
n_teachers_agreeing=2, metadata={})
```
### `class DJNormalizer`
```python
class DJNormalizer:
    DEFAULT_RECIPE: ClassVar[Path]  # composer_replication/recipes/replaysim/default.yaml

    def __init__(
        self,
        recipe_path: str | os.PathLike[str] | None = None,
        *,
        skip_dj: bool = False,
    ) -> None: ...

    def normalize(
        self,
        pairs: Iterable[DPOPair | dict[str, Any]],
    ) -> list[NormalizedDPOPair]: ...
```
`data-juicer`-backed normalizer. Pipeline: each `DPOPair` → JSONL record → `data_juicer.core.DefaultExecutor.run()` against the recipe → JSONL → `NormalizedDPOPair`.
**Constructor parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `recipe_path` | `str \| PathLike \| None` | `None` ⇒ default recipe | data-juicer YAML recipe path. |
| `skip_dj` | `bool` (kw-only) | `False` | If True: passthrough; records get `metadata={"skipped": True}` and no ops run. |
**`normalize(pairs) -> list[NormalizedDPOPair]`** runs the op-graph. Output may be shorter than input if filter ops drop records.
**Raises** `RuntimeError` at construction time if `skip_dj=False` and `data_juicer` is not importable. `FileNotFoundError` if `recipe_path` (default or explicit) is missing and `skip_dj=False`.
```python
from composer_replication.replaysim import DJNormalizer
norm = DJNormalizer(skip_dj=True)
out = norm.normalize(my_pairs)
```
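For orientation, the `skip_dj=True` passthrough amounts to a shape conversion from a raw `DPOPair` to the chat-format fields of `NormalizedDPOPair`. The sketch below uses plain dicts instead of the real dataclass and is an approximation of the documented output shape; `passthrough_normalize` is a hypothetical name.

```python
def passthrough_normalize(pair):
    """Map one raw DPO pair to the normalized, chat-format shape (no ops run)."""
    return {
        "state_id": pair["state_id"],
        "state_messages": list(pair["state_messages"]),
        "chosen_messages": [{"role": "assistant", "content": pair["chosen"]}],
        "rejected_messages": [{"role": "assistant", "content": pair["rejected"]}],
        "n_teachers_agreeing": pair["n_teachers_agreeing"],
        "metadata": {"skipped": True},  # marks that the normalizer was bypassed
    }


raw_pair = {
    "state_id": "ex001::0042",
    "state_messages": [{"role": "user", "content": "fix the bug"}],
    "chosen": "read file",
    "rejected": "run tests",
    "n_teachers_agreeing": 2,
}
normalized = passthrough_normalize(raw_pair)
```

Downstream code can check `metadata["skipped"]` to tell passthrough records apart from ones that went through the data-juicer op-graph.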
### `async replay_and_normalize_trace(*, states, teachers=None, agreement_threshold=2, max_total_usd=5.0, normalizer=None, **replay_kwargs) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]`
```python
async def replay_and_normalize_trace(
    *,
    states: Any,
    teachers: Any = None,
    agreement_threshold: int = 2,
    max_total_usd: float = 5.0,
    normalizer: DJNormalizer | None = None,
    **replay_kwargs: Any,
) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]
```
End-to-end async: replay → extract pairs → normalize.
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `states` | `Sequence[TraceState]` | — | Frozen trace. |
| `teachers` | `Sequence[TeacherSpec] \| None` | `None` ⇒ defaults | Forwarded to `replay_trace`. |
| `agreement_threshold` | `int` | `2` | Forwarded to `extract_dpo_pairs`. |
| `max_total_usd` | `float` | `5.0` | Spend cap. |
| `normalizer` | `DJNormalizer \| None` | `None` ⇒ `DJNormalizer()` | Pass `DJNormalizer(skip_dj=True)` to bypass. |
| `**replay_kwargs` | `Any` | — | Forwarded to `replay_trace` (e.g. `api_key`). |
**Returns** `(raw_teacher_actions, normalized_pairs)`.
```python
import asyncio
from composer_replication.replaysim import replay_and_normalize_trace, DJNormalizer
raw, norm = asyncio.run(replay_and_normalize_trace(
states=my_states, normalizer=DJNormalizer(skip_dj=True)))
```
### `replay_and_normalize_trace_sync(*args, **kwargs) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]`
⚠️ UNTESTED-CONTRACT (sync wrapper around the async function; tests call the async form via `asyncio.run`).
```python
def replay_and_normalize_trace_sync(*args, **kwargs) -> ...
```
Sync convenience wrapping `asyncio.run(replay_and_normalize_trace(...))`.
```python
from composer_replication.replaysim.normalize import replay_and_normalize_trace_sync
raw, norm = replay_and_normalize_trace_sync(states=my_states)
```
---
## 8. `composer_replication.ingestion` & `composer_replication.ingestion.claude_code`
Trace-source adapters (ADR-002). v0.1 supports Claude Code session JSONL.
### `SYSTEM_PROMPT: str`
Default synthetic system prompt injected at `messages[0]` for ingested traces (most Claude Code sessions don't write one). Truncated head: `"You are a senior software engineer working as a coding agent in a terminal environment..."`.
```python
from composer_replication import SYSTEM_PROMPT
print(SYSTEM_PROMPT[:60])
```
### `class IngestionStats`
```python
@dataclass
class IngestionStats:
    n_records_total: int = 0
    n_records_skipped: int = 0
    n_states_emitted: int = 0
    n_assistant_turns: int = 0
    n_tool_use_blocks: int = 0
    n_text_blocks: int = 0
    skipped_subagent: int = 0
    skipped_summary: int = 0
    skipped_truncated_lines: int = 0
    version_warnings: list[str] | None = None  # initialized to [] in __post_init__
```
Counters populated by `ClaudeCodeIngester.ingest()` and exposed as `ingester.last_stats`.
```python
from composer_replication import IngestionStats
s = IngestionStats(n_records_total=5)
print(s.version_warnings) # []
```
### `class ClaudeCodeIngester`
```python
class ClaudeCodeIngester:
    def __init__(
        self,
        *,
        system_prompt: str = SYSTEM_PROMPT,
        skip_sidechain: bool = True,
        strip_thinking: bool = True,
        max_history_tokens: int | None = None,
    ) -> None: ...

    def ingest(self, path: Path) -> Iterator[TraceState]: ...
```
Convert a Claude Code session JSONL to a stream of `TraceState`s — one per assistant TURN (not per `tool_use` block).
**Constructor parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `system_prompt` | `str` | `SYSTEM_PROMPT` | Synthetic system message injected at history[0]. |
| `skip_sidechain` | `bool` | `True` | Skip subagent files (`agent-*.jsonl`) and records with `isSidechain=True`. |
| `strip_thinking` | `bool` | `True` | Remove `[THINKING]` blocks from history handed to teachers (kept inside `student_action`). |
| `max_history_tokens` | `int \| None` | `None` | ⚠️ UNTESTED-CONTRACT — accepted but currently not used to truncate. |
**`ingest(path) -> Iterator[TraceState]`**: generator over `TraceState` objects. Each turn's `state_id` is `f"{path.stem}::{idx:04d}"`. Side effect: replaces `self.last_stats` with a fresh `IngestionStats` and updates it as records stream.
```python
from pathlib import Path
from composer_replication import ClaudeCodeIngester
ing = ClaudeCodeIngester()
for state in ing.ingest(Path("session.jsonl")):
    print(state["state_id"])
print(ing.last_stats.n_states_emitted)
```
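The `state_id` format string can be tried on its own to see how file stems and turn indices combine. `make_state_id` is a hypothetical helper wrapping the f-string quoted above, and the path is illustrative.

```python
from pathlib import Path


def make_state_id(path: Path, idx: int) -> str:
    """Build an id from the file stem and a zero-padded turn index."""
    return f"{path.stem}::{idx:04d}"


state_ids = [make_state_id(Path("logs/ex001.jsonl"), i) for i in (0, 42, 12345)]
```

The `:04d` width is a minimum, so indices above 9999 are not truncated; they simply produce longer ids.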
---
## 9. `composer_replication.hint_generator`
⚠️ UNTESTED-CONTRACT (entire module — used by the data collator config but not pinned by a test).
Template-based hint registry for SDPO error-site injection.
### `class HintContext(TypedDict, total=False)`
```python
class HintContext(TypedDict, total=False):
    error_kind: str
    error_message: str
    available_tools: list[str]
    tool_name: str
    tool_schema: dict
    intent: str
```
Per-error context dict consumed by hint templates.
### `HINT_TEMPLATES: dict[str, Callable[[HintContext], str]]`
Default registry keys: `"tool_not_found"`, `"json_decode"`, `"type_error"`, `"runtime_error"`, `"repeated_failure"`.
### `dispatch(error_kind, ctx=None) -> str | None`
```python
def dispatch(error_kind: str, ctx: HintContext | None = None) -> str | None
```
Look up `error_kind` in `HINT_TEMPLATES`. Returns the template's hint text, or `None` if the kind is unknown.
```python
from composer_replication.hint_generator import dispatch
hint = dispatch("json_decode") # "Reminder: tool arguments must be valid JSON. ..."
```
### `register(error_kind, fn) -> None`
```python
def register(error_kind: str, fn: Callable[[HintContext], str]) -> None
```
Add or override a custom hint template.
```python
from composer_replication.hint_generator import register
register("my_error", lambda ctx: "Reminder: try X.")
```
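`dispatch` and `register` together form a small registry. The sketch below reproduces that pattern in isolation; `TEMPLATES`, `register_sketch`, and `dispatch_sketch` are hypothetical stand-ins, and the single built-in template is a placeholder rather than the module's real text.

```python
from typing import Callable, Optional

HintFn = Callable[[dict], str]

# Stand-in registry with a single placeholder template.
TEMPLATES: dict[str, HintFn] = {
    "json_decode": lambda ctx: "Reminder: tool arguments must be valid JSON.",
}


def register_sketch(error_kind: str, fn: HintFn) -> None:
    """Add or override a template for error_kind."""
    TEMPLATES[error_kind] = fn


def dispatch_sketch(error_kind: str, ctx: Optional[dict] = None) -> Optional[str]:
    """Return the hint text, or None when error_kind has no template."""
    fn = TEMPLATES.get(error_kind)
    return None if fn is None else fn(ctx or {})


register_sketch(
    "my_error",
    lambda ctx: "Try one of: " + ", ".join(ctx.get("available_tools", [])),
)
hint = dispatch_sketch("my_error", {"available_tools": ["Read", "Write"]})
unknown = dispatch_sketch("not_registered")
```

Returning `None` for unknown kinds lets the caller decide whether a missing hint is an error or simply means no hint is injected.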
### Individual template functions
⚠️ UNTESTED-CONTRACT — exported only via `HINT_TEMPLATES`, useful as building blocks:
- `hint_tool_not_found(ctx) -> str`
- `hint_json_decode(ctx) -> str`
- `hint_type_error(ctx) -> str`
- `hint_runtime_error(ctx) -> str`
- `hint_repeated_failure(ctx) -> str`
Each accepts a `HintContext` and returns hint text. Signatures are uniform: `Callable[[HintContext], str]`.
```python
from composer_replication.hint_generator import hint_tool_not_found
text = hint_tool_not_found({"available_tools": ["Read", "Write"]})
```
---
## 10. `composer_replication.trainer` & sub-modules
Production trainer (TRL `GRPOTrainer` subclass) plus data collator.
### `class ComposerReplicationTrainer`
```python
class ComposerReplicationTrainer(GRPOTrainer):
    def __init__(
        self,
        *args: Any,
        alpha_sdpo: float = 0.1,
        beta_replay: float = 0.05,
        sdpo_jsd_beta: float = 0.5,
        sdpo_temperature: float = 1.0,
        sdpo_token_clip: float | None = None,
        replay_dpo_beta: float = 0.1,
        **kwargs: Any,
    ) -> None: ...

    def _compute_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor: ...
```
`trl.GRPOTrainer` subclass that overrides `_compute_loss(model, inputs)` to compose `total = grpo + α·sdpo + β·trace_replay_dpo`. When `trl` is not installed, the parent class falls back to `object` so the module imports — but instantiation will fail because the parent's GRPO machinery is missing.
**Constructor (kw-only beyond GRPOTrainer's own `*args, **kwargs`)**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `alpha_sdpo` | `float` | `0.1` | Channel-2 weight. |
| `beta_replay` | `float` | `0.05` | Channel-3 weight. |
| `sdpo_jsd_beta` | `float` | `0.5` | β for `generalized_jsd_loss`. |
| `sdpo_temperature` | `float` | `1.0` | SDPO softmax temperature. |
| `sdpo_token_clip` | `float \| None` | `None` | Per-token JSD clip. |
| `replay_dpo_beta` | `float` | `0.1` | DPO β. |
**`_compute_loss(model, inputs) -> torch.Tensor`** — overrides `GRPOTrainer._compute_loss`. Calls `super()._compute_loss` for channel 1, then `_compute_sdpo_loss` and `_compute_trace_replay_loss`, then composes. Logs per-channel components every `args.logging_steps` (default 50). **Raises** whatever `super()` raises (TRL-shaped errors).
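The composition arithmetic is a weighted sum of the three channel losses. A minimal sketch on plain floats follows; `compose_total` is a hypothetical name used only for illustration, and the real method operates on `torch.Tensor` scalars.

```python
def compose_total(
    grpo: float,
    sdpo: float,
    replay_dpo: float,
    alpha_sdpo: float = 0.1,
    beta_replay: float = 0.05,
) -> float:
    """total = grpo + alpha_sdpo * sdpo + beta_replay * trace_replay_dpo."""
    return grpo + alpha_sdpo * sdpo + beta_replay * replay_dpo


# Channel 1 alone: both auxiliary weights set to zero.
print(compose_total(1.0, 2.0, 4.0, alpha_sdpo=0.0, beta_replay=0.0))  # 1.0

# Default weights: 1.0 + 0.1 * 2.0 + 0.05 * 4.0
print(compose_total(1.0, 2.0, 4.0))
```

With both weights at zero the result reduces to the plain GRPO loss, which is a convenient sanity check when debugging a run.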
**Internal methods (publicly accessible, exercised by spike tests)**
- ⚠️ UNTESTED-CONTRACT `_compute_sdpo_loss(model, inputs) -> torch.Tensor` — generalized-JSD between student forward and `ctx_teacher_input_ids` forward. Returns `0.0` (with grad) when `alpha_sdpo == 0`, the key is missing, or shapes mismatch. Logs a warning on shape mismatch.
- ⚠️ UNTESTED-CONTRACT `_compute_trace_replay_loss(model, inputs) -> torch.Tensor` — standard DPO over `dpo_chosen_*` and `dpo_rejected_*`, using precomputed `dpo_chosen_ref_logprobs` / `dpo_rejected_ref_logprobs`.
- ⚠️ UNTESTED-CONTRACT `@staticmethod _sequence_logprobs(model, input_ids, response_mask) -> torch.Tensor` — sum logprobs over response tokens; standard DPO accounting.
```python
from composer_replication import ComposerReplicationTrainer
trainer = ComposerReplicationTrainer(
    model=my_model, args=my_grpo_args, train_dataset=ds,
    data_collator=my_collator, alpha_sdpo=0.1, beta_replay=0.05,
)
# trainer.train() # uses overridden _compute_loss
```
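For readers unfamiliar with the "standard DPO" accounting that `_compute_trace_replay_loss` and `_sequence_logprobs` are described as using, here is a pure-Python sketch of the textbook formula. It is not the trainer's code: `sequence_logprob` and `dpo_loss` are hypothetical helpers, and the real methods work on batched tensors.

```python
import math


def sequence_logprob(token_logprobs: list[float], response_mask: list[int]) -> float:
    """Sum per-token log-probs over response positions only."""
    return sum(lp for lp, m in zip(token_logprobs, response_mask) if m)


def dpo_loss(pol_chosen: float, pol_rejected: float,
             ref_chosen: float, ref_rejected: float, beta: float = 0.1) -> float:
    """-log(sigmoid(beta * ((pol_c - pol_r) - (ref_c - ref_r))))."""
    margin = beta * ((pol_chosen - pol_rejected) - (ref_chosen - ref_rejected))
    # -log(sigmoid(x)) == softplus(-x); branch to stay numerically stable.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


chosen = sequence_logprob([-0.5, -1.0, -2.0], [0, 1, 1])    # prompt token masked out
rejected = sequence_logprob([-0.5, -3.0, -4.0], [0, 1, 1])
print(chosen, rejected)  # -3.0 -7.0
print(dpo_loss(chosen, rejected, ref_chosen=-5.0, ref_rejected=-5.0))
```

When policy and reference agree on both sequences the margin is zero and the loss is `log(2)`; it falls below that as the policy prefers the chosen sequence more strongly than the reference does.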
### `class TraceTurn(TypedDict, total=False)` — `trainer.data_collator`
```python
class TraceTurn(TypedDict, total=False):
    role: str  # "user" | "assistant" | "tool"
    content: str
    tool_call: dict | None
    tool_error: str | None
    error_meta: dict
```
One turn of an agentic trace as consumed by `ComposerDataCollator`.
### `class TraceExample(TypedDict, total=False)` — `trainer.data_collator`
```python
class TraceExample(TypedDict, total=False):
    trace_id: str
    turns: list[TraceTurn]
    final_reward: float
    dpo_pairs: list[dict] | None
```
One training example: `(turns, optional dpo_pairs)`. `dpo_pairs` shape matches `DPOPair`.
### `class TokenizerLike` — `trainer.data_collator`
⚠️ UNTESTED-CONTRACT (duck-typed protocol; used as a type hint).
```python
class TokenizerLike:
    pad_token_id: int
    def __call__(self, text: str | list[str], **kwargs: Any) -> dict[str, list]: ...
    def apply_chat_template(self, messages: list[dict], **kwargs: Any) -> str | list[int]: ...
```
Minimal protocol the collator needs. Compatible with HF `AutoTokenizer`.
### `class CollatorConfig` — `trainer.data_collator`
```python
@dataclass
class CollatorConfig:
    max_seq_len: int = 4096
    max_dpo_seq_len: int = 2048
    pad_token_id: int = 0
    ignore_index: int = -100
    enable_sdpo: bool = True
    hint_generator: Callable[[str, dict], str | None] | None = None
    enable_replay_dpo: bool = True
    rlvr_reward_key: str = "final_reward"
```
Tunables for `ComposerDataCollator`.
| Field | Default | Meaning |
|---|---|---|
| `max_seq_len` | `4096` | Truncation cap for student/teacher sequences. |
| `max_dpo_seq_len` | `2048` | Truncation cap for DPO chosen/rejected sequences. |
| `pad_token_id` | `0` | Padding token id. |
| `ignore_index` | `-100` | HF "ignore in loss" sentinel for SDPO mask. |
| `enable_sdpo` | `True` | Toggle channel-2 fields. |
| `hint_generator` | `None` | Callable `(error_kind, error_meta) -> hint_text`. SDPO is a no-op without it. |
| `enable_replay_dpo` | `True` | Toggle channel-3 fields. |
| `rlvr_reward_key` | `"final_reward"` | Key in `TraceExample` to read scalar reward. |
```python
from composer_replication.trainer.data_collator import CollatorConfig
cfg = CollatorConfig(max_seq_len=2048, hint_generator=my_dispatch)
```
### `class ComposerDataCollator` — `trainer.data_collator`
```python
@dataclass
class ComposerDataCollator:
    tokenizer: TokenizerLike
    config: CollatorConfig = field(default_factory=CollatorConfig)
    def __call__(
        self, batch: Sequence[TraceExample]
    ) -> dict[str, torch.Tensor]: ...
```
Build trainer-ready batches from raw traces + optional DPO pairs.
**Output dict keys** (tested in `spikes/005-integrated-trainer-skeleton/tests/test_data_collator.py`):
- Channel 1 (always): `input_ids`, `attention_mask`, `response_mask`, `rewards`.
- Channel 2 (when `enable_sdpo=True` AND batch has at least one error site AND `hint_generator` is set): `ctx_teacher_input_ids`, `sdpo_loss_mask`.
- Channel 3 (when `enable_replay_dpo=True` AND batch has at least one `dpo_pair`): `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`. (Reference logprobs are NOT computed here — the trainer does that pass.)
```python
from composer_replication.trainer.data_collator import (
ComposerDataCollator, CollatorConfig)
collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
batch = collator([{"trace_id": "x", "turns": [...], "final_reward": 1.0}])
```
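The per-channel conditions listed above can be written down as a small predicate. This is a sketch of the documented gating only: `expected_keys` and its boolean flags are hypothetical, and the real collator derives these conditions from the batch contents and `CollatorConfig`.

```python
def expected_keys(
    *,
    enable_sdpo: bool = True,
    has_hint_generator: bool = False,
    has_error_site: bool = False,
    enable_replay_dpo: bool = True,
    has_dpo_pair: bool = False,
) -> list[str]:
    """Return the batch keys the documented gating rules would produce."""
    # Channel 1 is always present.
    keys = ["input_ids", "attention_mask", "response_mask", "rewards"]
    # Channel 2 needs the toggle, a hint generator, and at least one error site.
    if enable_sdpo and has_hint_generator and has_error_site:
        keys += ["ctx_teacher_input_ids", "sdpo_loss_mask"]
    # Channel 3 needs the toggle and at least one DPO pair.
    if enable_replay_dpo and has_dpo_pair:
        keys += [
            "dpo_chosen_input_ids", "dpo_chosen_response_mask",
            "dpo_rejected_input_ids", "dpo_rejected_response_mask",
        ]
    return keys


# Default config with no hint generator: SDPO keys are absent even with errors.
print(expected_keys(has_error_site=True))
```

A quick check like this is useful when a run silently trains on channel 1 alone because `hint_generator` was never set.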
---
## 11. `composer_replication.diloco`
DiLoCo outer-loop wrapper around `torchft.local_sgd.DiLoCo`. Optional dep — when `torchft` is missing the package re-export `composer_replication.make_diloco_outer_loop` is `None`.
### Module-level attributes
- `DiLoCo: Any` — `torchft.local_sgd.DiLoCo` if importable else `None`.
- `Manager: Any` — `torchft.manager.Manager` if importable else `None`.
- `_DummyWork: Any` — `torchft.work._DummyWork` if importable else `None`.
- `_TORCHFT_AVAILABLE: bool` — whether the imports succeeded.
```python
from composer_replication.diloco import _TORCHFT_AVAILABLE, DiLoCo
```
### `make_diloco_outer_loop(manager, model_fragments, inner_optimizer, *, ...) -> torchft.local_sgd.DiLoCo`
```python
def make_diloco_outer_loop(
    manager: Any,
    model_fragments: list[torch.nn.Module],
    inner_optimizer: torch.optim.Optimizer,
    *,
    outer_lr: float = 0.7,
    outer_momentum: float = 0.9,
    nesterov: bool = True,
    sync_every: int = 100,
    fragment_sync_delay: int = 0,
    fragment_update_alpha: float = 0.0,
) -> Any
```
Construct a `torchft.local_sgd.DiLoCo` configured with the framework-default hyperparameters (DiLoCo paper §3.2: `lr=0.7, momentum=0.9, Nesterov`).
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `manager` | `torchft.Manager` (or duck-typed `MockManager`) | — | Provides `allreduce`, `should_commit`, `current_step`, `start_quorum`, etc. |
| `model_fragments` | `list[torch.nn.Module]` | — | One module for vanilla DiLoCo; N modules for Streaming DiLoCo. |
| `inner_optimizer` | `torch.optim.Optimizer` | — | Inner-step optimizer (steps every batch). |
| `outer_lr` | `float` | `0.7` | Outer SGD lr. |
| `outer_momentum` | `float` | `0.9` | Outer SGD momentum. |
| `nesterov` | `bool` | `True` | Nesterov momentum on outer SGD. |
| `sync_every` | `int` | `100` | Inner steps per outer round. |
| `fragment_sync_delay` | `int` | `0` | 0 = vanilla; >0 = Streaming DiLoCo (requires CUDA streams). |
| `fragment_update_alpha` | `float` | `0.0` | 0 = full replacement on sync; >0 = exponential mix. |
**Returns** a `torchft.local_sgd.DiLoCo` instance — usable as a context manager.
**Raises** `RuntimeError` if `torchft` is not installed.
```python
import torch
from composer_replication.diloco import make_diloco_outer_loop
opt = torch.optim.AdamW(model.parameters(), lr=1e-5)
outer = make_diloco_outer_loop(manager=mgr, model_fragments=[model],
                               inner_optimizer=opt, sync_every=100)
with outer:
    for _ in range(N):
        opt.zero_grad()
        loss.backward()
        opt.step()
```
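To make the outer-loop hyperparameters concrete, the sketch below runs one outer round on plain Python lists. It assumes the DiLoCo paper's sign convention (pseudo-gradient = global parameters minus the replica average) and PyTorch-style Nesterov SGD. `outer_step` is a hypothetical helper, not a `torchft` or framework function.

```python
def outer_step(
    global_params: list[float],
    replica_params: list[list[float]],
    momentum_buf: list[float],
    *,
    outer_lr: float = 0.7,
    outer_momentum: float = 0.9,
    nesterov: bool = True,
) -> tuple[list[float], list[float]]:
    """One outer SGD step on the averaged pseudo-gradient."""
    n = len(replica_params)
    new_params, new_buf = [], []
    for i, theta in enumerate(global_params):
        avg_local = sum(r[i] for r in replica_params) / n
        # Assumed sign convention: pseudo-gradient = global - averaged local.
        g = theta - avg_local
        buf = outer_momentum * momentum_buf[i] + g
        step = g + outer_momentum * buf if nesterov else buf
        new_params.append(theta - outer_lr * step)
        new_buf.append(buf)
    return new_params, new_buf


start = [1.0, 2.0]
replicas = [[0.0, 2.0], [1.0, 4.0]]  # two replicas after their inner steps

# With lr=1 and no momentum the outer step is plain parameter averaging.
params, _ = outer_step(start, replicas, [0.0, 0.0],
                       outer_lr=1.0, outer_momentum=0.0, nesterov=False)
print(params)  # [0.5, 3.0]

# Framework defaults: lr=0.7, momentum=0.9, Nesterov.
params, buf = outer_step(start, replicas, [0.0, 0.0])
print(params, buf)
```

The `lr=1`, zero-momentum case is a handy sanity check: if it does not reproduce the replica average, the pseudo-gradient sign is flipped.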
---
## 12. `composer_replication.diloco.serverless`
ADR-005 serverless DiLoCo executors + object-store all-reduce.
### `class ReplicaHandle` — `serverless.executor`
```python
@dataclass
class ReplicaHandle:
    rank: int
    backend_name: str
    metadata: dict[str, Any] = field(default_factory=dict)
```
Opaque handle returned by `ServerlessExecutor.launch_replicas`. `metadata` is backend-specific.
```python
from composer_replication.diloco.serverless import ReplicaHandle
h = ReplicaHandle(rank=0, backend_name="local_process",
metadata={"pid": 12345})
```
### `class ServerlessExecutor` (Protocol) — `serverless.executor`
```python
@runtime_checkable
class ServerlessExecutor(Protocol):
    backend_name: str
    supports_inter_replica_network: bool
    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]: ...
    def poll(self, handle: ReplicaHandle) -> str: ...
    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str: ...
    def cancel(self, handle: ReplicaHandle) -> None: ...
    def collect(
        self, handles: list[ReplicaHandle], *, timeout: int | None = None,
    ) -> list[dict[str, Any]]: ...
```
Structural protocol for serverless backends.
- `launch_replicas(...)` returns `list[ReplicaHandle]` of length `n_replicas` in rank order. `entrypoint` is one of: an importable module path (its `main()` is called), a `module.function` path, or a `Callable` (`LocalProcessExecutor` only). `entrypoint_args` may include `rank_env` (default `"REPLICA_RANK"`).
- `poll(handle) -> str`: one of `"pending"`, `"running"`, `"succeeded"`, `"failed"`, `"cancelled"`.
- `stream_logs(handle, n_lines=200) -> str`: best-effort recent stdout/stderr.
- `cancel(handle) -> None`: best-effort.
- `collect(handles, timeout=None) -> list[dict]`: blocks; each result dict has `rank`, `status`, `exit_code`, `error` (and `result` from `LocalProcessExecutor`).
```python
from composer_replication.diloco.serverless import ServerlessExecutor
def supports(x: ServerlessExecutor) -> bool:
    return isinstance(x, ServerlessExecutor)  # runtime_checkable
```
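The status strings returned by `poll` lend themselves to a simple wait loop. The sketch below uses a scripted fake backend so it runs without the package. `Handle`, `ScriptedExecutor`, and `wait_all` are hypothetical names, and a real caller would normally rely on `collect` instead.

```python
import time
from dataclasses import dataclass

TERMINAL = {"succeeded", "failed", "cancelled"}


@dataclass
class Handle:
    """Hypothetical stand-in for ReplicaHandle (only `rank` is used here)."""
    rank: int


class ScriptedExecutor:
    """Fake backend: each rank replays a fixed list of statuses."""

    def __init__(self, scripts: dict[int, list[str]]) -> None:
        self._scripts = {rank: list(s) for rank, s in scripts.items()}

    def poll(self, handle: Handle) -> str:
        script = self._scripts[handle.rank]
        # Keep returning the last status once the script is exhausted.
        return script.pop(0) if len(script) > 1 else script[0]


def wait_all(executor, handles, *, poll_interval_s: float = 0.0,
             max_polls: int = 100) -> dict[int, str]:
    """Poll every handle until each reaches a terminal status."""
    final: dict[int, str] = {}
    for _ in range(max_polls):
        for h in handles:
            if h.rank not in final:
                status = executor.poll(h)
                if status in TERMINAL:
                    final[h.rank] = status
        if len(final) == len(handles):
            return final
        time.sleep(poll_interval_s)
    raise TimeoutError("replicas did not reach a terminal status")


ex = ScriptedExecutor({0: ["pending", "running", "succeeded"],
                       1: ["running", "failed"]})
print(wait_all(ex, [Handle(0), Handle(1)]))  # {1: 'failed', 0: 'succeeded'}
```

Results are keyed by rank rather than by handle because a plain (non-frozen) dataclass is not hashable.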
### `class LocalProcessExecutor` — `serverless.executor`
```python
class LocalProcessExecutor:
    backend_name = "local_process"
    supports_inter_replica_network = True
    def __init__(self) -> None: ...
    # implements ServerlessExecutor protocol
```
Reference implementation using Python `multiprocessing` (`spawn` context). Used for tests, CI smokes, and local development with `file://` rendezvous.
`launch_replicas(...)` emits a soft warning when `gpu` is not `None` (local processes share whatever GPUs are visible). Each handle's `metadata` is `{"pid": ..., "start_ts": ...}`.
```python
from composer_replication.diloco.serverless import LocalProcessExecutor
ex = LocalProcessExecutor()
handles = ex.launch_replicas(
    n_replicas=2,
    entrypoint="composer_replication.diloco.serverless.replica_entrypoint",
    entrypoint_args={"rendezvous_uri": "/tmp/run/", "world_size": 2,
                     "trainer_module": "my.trainer"},
)
results = ex.collect(handles, timeout=60)
```
### `class ObjectStoreAllReduce` — `serverless.allreduce`
```python
class ObjectStoreAllReduce:
    def __init__(
        self,
        uri: str,
        rank: int,
        world_size: int,
        *,
        round_id: int | None = None,
        timeout_s: float = 1800.0,
        poll_interval_s: float = 1.0,
    ) -> None: ...
    @property
    def round_id(self) -> int: ...
    def allreduce(
        self, tensor: torch.Tensor, *, name: str | None = None,
    ) -> torch.Tensor: ...
```
fsspec-backed pseudo-gradient rendezvous. `uri` accepts `s3://`, `gs://`, `az://`, `hf://`, `file://`, or a plain local path.
**Constructor parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `uri` | `str` | — | fsspec URI or local path. Trailing `/` enforced. |
| `rank` | `int` | — | This replica's rank. |
| `world_size` | `int` | — | Total replicas. |
| `round_id` | `int \| None` (kw-only) | `None` ⇒ start at 0 | Initial round counter. |
| `timeout_s` | `float` (kw-only) | `1800.0` | Per-`allreduce` timeout. |
| `poll_interval_s` | `float` (kw-only) | `1.0` | Sleep between peer-file existence checks. |
**`allreduce(tensor, name=None) -> torch.Tensor`**: serializes `tensor.detach().cpu()` to `round_NNNNNN/rank_RRRR.pt`, blocks until all peers post, then averages. **Modifies `tensor` in place** AND returns it. Increments the internal `_round_counter`.
**Raises** `ValueError` on invalid `rank`, `RuntimeError` if non-local URI is requested without `fsspec` installed, `TimeoutError` if peers don't show up before `timeout_s`.
```python
from composer_replication.diloco.serverless import ObjectStoreAllReduce
import torch
store = ObjectStoreAllReduce("/tmp/run/", rank=0, world_size=2)
g = torch.zeros(10)
store.allreduce(g) # blocks for rank 1
```
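The post-then-poll-then-average pattern can be illustrated with local files and the standard library alone. This sketch is not the class's implementation: it stores JSON lists instead of tensors, `post` and `reduce_mean` are hypothetical helpers, and the atomic rename is a design choice of the sketch rather than a documented behaviour.

```python
import json
import tempfile
import time
from pathlib import Path


def _rank_path(root: str, round_id: int, rank: int) -> Path:
    # Mirrors the round_NNNNNN/rank_RRRR layout, with .json instead of .pt.
    return Path(root) / f"round_{round_id:06d}" / f"rank_{rank:04d}.json"


def post(root: str, round_id: int, rank: int, values: list[float]) -> None:
    """Write this rank's contribution for the given round."""
    path = _rank_path(root, round_id, rank)
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(".tmp")
    tmp.write_text(json.dumps(values))
    tmp.replace(path)  # rename so peers never observe a partial file


def reduce_mean(root: str, round_id: int, world_size: int, *,
                timeout_s: float = 5.0, poll_interval_s: float = 0.01) -> list[float]:
    """Block until every rank has posted, then return the element-wise mean."""
    deadline = time.monotonic() + timeout_s
    paths = [_rank_path(root, round_id, r) for r in range(world_size)]
    while not all(p.exists() for p in paths):
        if time.monotonic() > deadline:
            raise TimeoutError(f"round {round_id}: peers missing after {timeout_s}s")
        time.sleep(poll_interval_s)
    posted = [json.loads(p.read_text()) for p in paths]
    return [sum(col) / world_size for col in zip(*posted)]


with tempfile.TemporaryDirectory() as root:
    post(root, 0, 0, [1.0, 2.0])
    post(root, 0, 1, [3.0, 6.0])
    averaged = reduce_mean(root, 0, world_size=2)
print(averaged)  # [2.0, 4.0]
```

Keeping one directory per round means a slow peer from round `k` can never be mistaken for a contribution to round `k + 1`.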
### `class MockManager` — `serverless.allreduce`
```python
class MockManager:
    def __init__(self, store: ObjectStoreAllReduce) -> None: ...
    # torchft.Manager-shaped surface:
    num_participants: int
    rank: int
    _use_async_quorum: bool  # always False
    _step: int
    _state_dict_fns: dict[str, tuple[Any, Any]]
    def allreduce(self, tensor: torch.Tensor, **_kwargs: Any) -> "_ImmediateWork": ...
    def should_commit(self) -> bool: ...
    def start_quorum(self) -> None: ...
    def wait_quorum(self) -> int: ...
    def current_step(self) -> int: ...
    def allow_state_dict_read(self) -> None: ...
    def disallow_state_dict_read(self) -> None: ...
    def register_state_dict_fn(self, key: str, load_fn: Any, save_fn: Any) -> None: ...
    def is_leader(self) -> bool: ...
```
Drop-in replacement for `torchft.Manager` that routes `allreduce` through `ObjectStoreAllReduce`. All other methods are no-ops or simple counters appropriate for single-shot serverless DiLoCo.
- `allreduce(tensor)` returns an `_ImmediateWork` whose `.wait()` is a no-op (the tensor is already averaged).
- `should_commit()` always `True` (no fault-tolerance failover).
- `start_quorum()` bumps `_step`.
- `is_leader()` returns `rank == 0`.
```python
from composer_replication.diloco.serverless import MockManager, ObjectStoreAllReduce
store = ObjectStoreAllReduce("/tmp/run/", rank=0, world_size=2)
mgr = MockManager(store)
# pass mgr into make_diloco_outer_loop(manager=mgr, ...)
```
### `class _ImmediateWork` — `serverless.allreduce`
⚠️ UNTESTED-CONTRACT. Internal helper listed in `__all__`. `Work`-shaped wrapper whose `.wait()` returns `True` and whose `.get_future()` returns a `torch.futures.Future`. Consumed by torchft DiLoCo's `perform_sync`.
```python
from composer_replication.diloco.serverless.allreduce import _ImmediateWork
```
### `class ModalExecutor` — `serverless.modal`
🟡 SKELETON — raises `NotImplementedError`; see ADR-005. Class body documents the v0 implementation pattern (Modal `app.function` + `function.spawn(rank=...)`).
```python
from composer_replication.diloco.serverless.modal import ModalExecutor
# ModalExecutor()  # raises NotImplementedError on instantiation
```
### `class HFJobsExecutor` — `serverless.hf_jobs`
🟡 SKELETON — raises `NotImplementedError`; see ADR-005. Class body documents the v0 pattern using `huggingface_hub.run_job` against `hf://datasets/.../` rendezvous.
```python
from composer_replication.diloco.serverless.hf_jobs import HFJobsExecutor
# instantiation will fail until v0 implementation lands
```
### `replica_entrypoint.main(...)` — `serverless.replica_entrypoint`
```python
def main(
    rendezvous_uri: str,
    world_size: int,
    trainer_module: str,
    trainer_fn: str = "train",
    trainer_kwargs: dict[str, Any] | None = None,
) -> Any
```
Script run by every replica. Reads `REPLICA_RANK` env var, builds `ObjectStoreAllReduce` + `MockManager`, imports `trainer_module`, and calls `getattr(mod, trainer_fn)(**trainer_kwargs, manager=..., rank=..., world_size=...)`. Returns whatever the train fn returns.
**Raises** `RuntimeError` if `REPLICA_RANK` env var is missing; `ValueError` if rank ∉ `[0, world_size)`.
The `if __name__ == "__main__"` block accepts CLI flags `--rendezvous`, `--world-size`, `--trainer-module`, `--trainer-fn`, `--trainer-kwargs-json`.
```python
# In-process invocation
import os
os.environ["REPLICA_RANK"] = "0"
from composer_replication.diloco.serverless.replica_entrypoint import main
result = main(rendezvous_uri="/tmp/run/", world_size=1,
trainer_module="my.trainer", trainer_fn="train")
```
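The control flow described for `main` (read the rank, validate it, import the trainer, call it) can be sketched without the package. `run_replica` and the in-memory `demo_trainer` module are hypothetical, and `manager=None` stands in for the `MockManager` that the real entrypoint builds.

```python
import importlib
import os
import sys
import types
from typing import Any, Optional


def run_replica(world_size: int, trainer_module: str, trainer_fn: str = "train",
                trainer_kwargs: Optional[dict] = None) -> Any:
    """Resolve the rank from the environment and call the trainer function."""
    raw = os.environ.get("REPLICA_RANK")
    if raw is None:
        raise RuntimeError("REPLICA_RANK env var is not set")
    rank = int(raw)
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} not in [0, {world_size})")
    mod = importlib.import_module(trainer_module)
    fn = getattr(mod, trainer_fn)
    # Placeholder: the real entrypoint passes a MockManager here.
    return fn(**(trainer_kwargs or {}), manager=None, rank=rank, world_size=world_size)


# Demo: register a throwaway module so import_module can find it.
def _train(*, manager: Any, rank: int, world_size: int, steps: int = 1) -> dict:
    return {"rank": rank, "world_size": world_size, "steps": steps}


demo = types.ModuleType("demo_trainer")
demo.train = _train
sys.modules["demo_trainer"] = demo

os.environ["REPLICA_RANK"] = "0"
print(run_replica(1, "demo_trainer", trainer_kwargs={"steps": 3}))
```

Under this pattern the train function only needs to accept `manager`, `rank`, and `world_size` as keyword arguments in addition to its own kwargs.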
---
## 13. `composer_replication.recipes.prime_rl.composer_loss`
PRIME-RL adapter (ADR-006). Maps PRIME-RL's `LossInputs` struct onto channel 1 (DPPO + KL on the importance ratio, mirroring PRIME-RL's upstream `default_loss_fn` at `prime_rl/trainer/rl/loss.py` lines 116-165). Channel 2 raises `NotImplementedError`; channel 3 is out of scope.
### `loss_fn(inputs, *, alpha_sdpo=0.0, beta_dpo=0.0, dppo_mask_high=0.2, dppo_mask_low=0.2, adv_tau=1.0, kl_tau=1e-3) -> torch.Tensor`
```python
def loss_fn(
    inputs: Any,  # PRIME-RL's LossInputs (duck-typed)
    *,
    alpha_sdpo: float = 0.0,
    beta_dpo: float = 0.0,
    dppo_mask_high: float = 0.2,
    dppo_mask_low: float = 0.2,
    adv_tau: float = 1.0,
    kl_tau: float = 1e-3,
) -> Any  # torch.Tensor scalar
```
PRIME-RL passes per-sample **1-D `(seq,)` tensors** (not batched). The function mirrors PRIME-RL's upstream DPPO+KL formula:
- Mask gate is on **probability-space** `probs_diff = exp(trainer_lp) - exp(inference_lp)` (NOT on the log-ratio).
- A token is dropped iff its advantage sign matches the offending bound: positive-advantage tokens are dropped when `probs_diff > dppo_mask_high`, negative-advantage tokens when `probs_diff < -dppo_mask_low`. (PRIME-RL stores both bounds with `Field(..., ge=0)` and applies the sign internally.)
- The PG term is `keep * (adv_tau * advantages) * exp(trainer_lp - inference_lp)` (importance-ratio corrected, not REINFORCE).
- A KL penalty `kl_tau * log_importance_ratio**2` is added on the full `loss_mask` (DPPO masking does not gate it).
- Reduction is a plain `sum()`; PRIME-RL's outer `compute_loss` divides by `loss_scale`.
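The bullets above translate into the following pure-Python sketch for a single sample. It is an illustration, not the adapter's code: `dppo_kl_loss` is a hypothetical name, inputs are plain lists rather than 1-D tensors, and it assumes the PG term enters the loss with a negative sign (the usual convention for a minimized objective), which the bullets do not state explicitly.

```python
import math


def dppo_kl_loss(trainer_lp, inference_lp, advantages, loss_mask, *,
                 dppo_mask_high=0.2, dppo_mask_low=0.2, adv_tau=1.0, kl_tau=1e-3):
    """Sum-reduced DPPO + KL loss over one sequence of per-token values."""
    total = 0.0
    for t_lp, i_lp, adv, m in zip(trainer_lp, inference_lp, advantages, loss_mask):
        if not m:
            continue  # tokens outside loss_mask contribute nothing
        log_ratio = t_lp - i_lp
        # The gate works in probability space, not on the log-ratio.
        probs_diff = math.exp(t_lp) - math.exp(i_lp)
        dropped = (adv > 0 and probs_diff > dppo_mask_high) or \
                  (adv < 0 and probs_diff < -dppo_mask_low)
        keep = 0.0 if dropped else 1.0
        pg = keep * (adv_tau * adv) * math.exp(log_ratio)
        kl = kl_tau * log_ratio ** 2  # applied on the full loss_mask
        total += -pg + kl  # assumption: PG term is negated
    return total


# Token 0 is on-policy; token 1 moved from p=0.5 to p=0.9 with positive advantage,
# so its PG term is masked while its KL penalty remains.
loss = dppo_kl_loss(
    trainer_lp=[math.log(0.5), math.log(0.9)],
    inference_lp=[math.log(0.5), math.log(0.5)],
    advantages=[1.0, 1.0],
    loss_mask=[1, 1],
)
print(loss)
```

In this example only the first token contributes to the PG term, while the second contributes `kl_tau * log(0.9 / 0.5)**2`.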
**Parameters**
| Name | Type | Default | Meaning |
|---|---|---|---|
| `inputs` | PRIME-RL `LossInputs` (duck-typed) | — | Must expose `trainer_logprobs`, `inference_logprobs`, `advantages`, `loss_mask` (all 1-D), and optionally `teacher_logprobs`. |
| `alpha_sdpo` | `float` (kw-only) | `0.0` | Channel-2 weight. Must be `0` in v0; >0 → `NotImplementedError`. |
| `beta_dpo` | `float` (kw-only) | `0.0` | Channel-3 weight. Non-zero emits a `UserWarning`. |
| `dppo_mask_high` | `float` (kw-only), `>= 0` | `0.2` | Upper probability-diff threshold. PRIME-RL `DefaultLossConfig` default. |
| `dppo_mask_low` | `float` (kw-only), `>= 0` | `0.2` | Magnitude of lower probability-diff threshold (sign flipped internally). PRIME-RL default. |
| `adv_tau` | `float` (kw-only), `>= 0` | `1.0` | Advantage temperature. PRIME-RL default. |
| `kl_tau` | `float` (kw-only), `>= 0` | `1e-3` | KL term temperature. PRIME-RL default. |
**Returns** scalar `torch.Tensor` (PRIME-RL's trainer calls `.backward()`).
**Raises** `ValueError` if any of `trainer_logprobs`, `inference_logprobs`, `advantages`, `loss_mask` is not 1-D, or any of the four `>=0`-constrained knobs is negative. `NotImplementedError` if `alpha_sdpo > 0` (channel 2 deferred).
```python
from composer_replication.recipes.prime_rl.composer_loss import loss_fn
# In PRIME-RL config:
# loss:
# custom:
# import_path: composer_replication.recipes.prime_rl.composer_loss:loss_fn
# kwargs:
# dppo_mask_high: 0.2
# dppo_mask_low: 0.2
# adv_tau: 1.0
# kl_tau: 1.0e-3
```
---
## 14. `composer_replication.recipes.monarch.actors`
🟡 SKELETON module per ADR-006. Importable; classes raise `NotImplementedError` on instantiation. Documents the actor signatures so the recipe matrix is complete.
### `class TrainerActor` 🟡
```python
class TrainerActor:
    backend = "monarch"
    role = "trainer"
    def __init__(self) -> None: raise NotImplementedError(...)
    async def train_outer_step(self, batch_id: int) -> dict[str, Any]: raise NotImplementedError
```
Hosts the framework's 3-channel composer trainer. Real impl deferred to v0.2+.
### `class GeneratorActor` 🟡
```python
class GeneratorActor:
    backend = "monarch"
    role = "generator"
    def __init__(self) -> None: raise NotImplementedError(...)
    async def rollout(self, prompts: list[str]) -> list[str]: raise NotImplementedError
```
vLLM-backed rollout actor.
### `class RewarderActor` 🟡
```python
class RewarderActor:
    backend = "monarch"
    role = "rewarder"
    def __init__(self) -> None: raise NotImplementedError(...)
    async def score(self, completions: list[str]) -> list[float]: raise NotImplementedError
```
verifiers-protocol rewarder.
### `class TeacherPoolActor` 🟡
```python
class TeacherPoolActor:
    backend = "monarch"
    role = "teacher_pool"
    def __init__(self) -> None: raise NotImplementedError(...)
```
Channel-3 teacher pool wrapping `composer_replication.teacher_replay`.
```python
# All Monarch actors raise on instantiation in v0:
from composer_replication.recipes.monarch.actors import TrainerActor
# TrainerActor() # NotImplementedError
```
---
## Notes on test coverage
Tested contracts (referenced spike/test paths):
- `compose_loss` + `LossComponents` + `build_batch`: `composer_replication/tests/test_compose_loss_integration.py`, `spikes/006-real-hf-model-smoke/tests/`.
- `generalized_jsd_loss`: `spikes/005-integrated-trainer-skeleton/tests/test_opsd_loss.py`.
- `simpo_loss`, `taid_loss`, `taid_alpha_schedule`, `taid_blended_logits`, `entropy_aware_opd_loss`: `composer_replication/distillation/tests/test_distillation_losses.py`.
- `replay_trace`, `extract_dpo_pairs`, `DPOPair`, `TraceState`, `TeacherCallResult`, `TeacherSpec`, `DEFAULT_TEACHERS`: `spikes/005-integrated-trainer-skeleton/tests/test_teacher_replay.py`.
- `DJNormalizer`, `NormalizedDPOPair`, `replay_and_normalize_trace`: `composer_replication/replaysim/tests/test_replaysim.py`.
- `ClaudeCodeIngester`, `IngestionStats`, `SYSTEM_PROMPT`: `spikes/007-real-trace-ingestion/tests/`.
- `ComposerDataCollator`, `CollatorConfig`, `TraceTurn`, `TraceExample`: `spikes/005-integrated-trainer-skeleton/tests/test_data_collator.py`.
- `ComposerReplicationTrainer._compute_loss` (composition arithmetic): `spikes/005-integrated-trainer-skeleton/tests/test_loss_composition_smoke.py`.
- `make_diloco_outer_loop` + sign convention: `spikes/008-streaming-diloco/tests/test_diloco_smoke.py`.
- `ObjectStoreAllReduce`, `MockManager`, `LocalProcessExecutor`, `ReplicaHandle`, `ServerlessExecutor`, `replica_entrypoint.main`: `composer_replication/diloco/serverless/tests/test_serverless_local.py`, `test_serverless_diloco_integration.py`.
- `recipes.prime_rl.composer_loss.loss_fn`: `composer_replication/recipes/prime_rl/tests/test_composer_loss.py`.
Untested-contract symbols (⚠️) and skeletons (🟡) are flagged inline above.
---
**Document path**: `/mnt/e/CS/HF/composer-replication-framework/docs/API_REFERENCE.md`