API Reference — composer-replication-framework

Complete reference for every public symbol in composer_replication. The source of truth is the .py files in composer_replication/; docstrings are quoted verbatim where they exist and supplemented where missing.

Legend

  • ⚠️ UNTESTED-CONTRACT — symbol exists and is callable, but its behaviour is not pinned by an automated test in composer_replication/**/tests/ or spikes/**/tests/.
  • 🟡 SKELETON — class/method body raises NotImplementedError; ships as design-of-record per ADR-005 / ADR-006.

Module groups (in this document)

  1. composer_replication (top-level re-exports)
  2. composer_replication.loss
  3. composer_replication.batch
  4. composer_replication.opsd
  5. composer_replication.distillation
  6. composer_replication.teacher_replay
  7. composer_replication.replaysim
  8. composer_replication.ingestion (+ .claude_code)
  9. composer_replication.hint_generator
  10. composer_replication.trainer (+ .composer_trainer, .data_collator)
  11. composer_replication.diloco
  12. composer_replication.diloco.serverless (+ .executor, .allreduce, .modal, .hf_jobs, .replica_entrypoint)
  13. composer_replication.recipes.prime_rl.composer_loss
  14. composer_replication.recipes.monarch.actors

1. composer_replication — top-level package

The package re-exports the most common entry points from sub-modules. __all__ is the canonical list of public top-level names.

composer_replication.__version__: str

Package version string. Currently "0.1.0".

import composer_replication
print(composer_replication.__version__)  # "0.1.0"

composer_replication._DILOCO_AVAILABLE: bool

True iff torchft is importable in the running Python environment (gates make_diloco_outer_loop). When torchft is missing, _DILOCO_AVAILABLE is False and make_diloco_outer_loop is None.

from composer_replication import _DILOCO_AVAILABLE
if _DILOCO_AVAILABLE:
    from composer_replication import make_diloco_outer_loop

Re-exports

Name Source module
compose_loss composer_replication.loss
LossComponents composer_replication.loss
build_batch composer_replication.batch
generalized_jsd_loss composer_replication.opsd
ClaudeCodeIngester composer_replication.ingestion.claude_code
IngestionStats composer_replication.ingestion.claude_code
SYSTEM_PROMPT composer_replication.ingestion.claude_code
DEFAULT_TEACHERS composer_replication.teacher_replay
DPOPair composer_replication.teacher_replay
TeacherCallResult composer_replication.teacher_replay
TeacherSpec composer_replication.teacher_replay
TraceState composer_replication.teacher_replay
extract_dpo_pairs composer_replication.teacher_replay
replay_trace composer_replication.teacher_replay
ComposerReplicationTrainer composer_replication.trainer
make_diloco_outer_loop composer_replication.diloco (or None if torchft missing)

See each source module below for full signatures.


2. composer_replication.loss

Verification-harness 3-channel loss. Free function, does not depend on trl.

class LossComponents

@dataclass
class LossComponents:
    lm_ce: torch.Tensor
    sdpo_jsd: torch.Tensor
    trace_replay_dpo: torch.Tensor
    total: torch.Tensor

    def detached(self) -> dict[str, float]: ...

Per-channel breakdown of the total loss for logging and ablation. All four fields are scalar torch.Tensors (shape=()); total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo.

detached() -> dict[str, float] — returns Python-float copies of all four fields with no grad. Useful for W&B logging.

from composer_replication import compose_loss, build_batch
components = compose_loss(model, build_batch(tokenizer))
print(components.detached())  # {'lm_ce': 2.34, 'sdpo_jsd': 0.12, ...}
components.total.backward()

compose_loss(model, inputs, *, ...) -> LossComponents

def compose_loss(
    model: torch.nn.Module,
    inputs: dict[str, torch.Tensor],
    *,
    alpha_sdpo: float = 0.1,
    beta_replay: float = 0.05,
    sdpo_jsd_beta: float = 0.5,
    sdpo_temperature: float = 1.0,
    sdpo_token_clip: float | None = None,
    replay_dpo_beta: float = 0.1,
    lm_ce_label_smoothing: float = 0.0,
    dpo_variant: Literal["dpo", "simpo"] = "dpo",
    sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
    taid_t: float | None = None,
    simpo_beta: float = 2.0,
    simpo_gamma: float = 1.0,
    entropy_opd_h_max: float | None = None,
) -> LossComponents

Compute total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo.
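
As a worked example of the formula above, the following sketch reproduces the weighted sum with plain floats. The channel values are made up for illustration, and `compose_total` is a hypothetical helper rather than a function in the package.

```python
def compose_total(lm_ce: float, sdpo_jsd: float, trace_replay_dpo: float,
                  alpha_sdpo: float = 0.1, beta_replay: float = 0.05) -> float:
    """total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo."""
    return lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo


# Illustrative channel values (not measured results):
# 2.34 + 0.1 * 0.12 + 0.05 * 0.70 = 2.34 + 0.012 + 0.035 = 2.387
total = compose_total(2.34, 0.12, 0.70)

# Setting both weights to 0.0 leaves only the LM cross-entropy channel.
lm_only = compose_total(2.34, 0.12, 0.70, alpha_sdpo=0.0, beta_replay=0.0)
```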

Required keys in inputs

  • input_ids: (B, T_s) student rollout token ids.
  • response_mask: (B, T_s) 1 on assistant-response tokens, 0 elsewhere.

Optional keys (channel auto-disables if missing OR if its weight = 0):

  • SDPO: ctx_teacher_input_ids (B, T_t), sdpo_loss_mask (B, T_t).
  • DPO (dpo_variant="dpo"): dpo_chosen_input_ids, dpo_chosen_response_mask, dpo_rejected_input_ids, dpo_rejected_response_mask, dpo_chosen_ref_logprobs, dpo_rejected_ref_logprobs (precomputed).
  • SimPO (dpo_variant="simpo"): same DPO ids/masks; reference logprobs are silently ignored.
  • TAID (sdpo_wrapper="taid"): no extra inputs keys needed; the optional sdpo_loss_mask is reused as the per-token TAID mask. Pass taid_t directly (or drive it from TAIDScheduler).
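
The auto-disable rule for the optional channels can be sketched as a small predicate. This is an illustration under stated assumptions, not the package's code: it assumes the SDPO channel only needs ctx_teacher_input_ids (with sdpo_loss_mask optional), that SimPO needs the DPO ids and masks but no reference logprobs, and `active_channels` is a hypothetical helper.

```python
DPO_ID_KEYS = ("dpo_chosen_input_ids", "dpo_chosen_response_mask",
               "dpo_rejected_input_ids", "dpo_rejected_response_mask")
DPO_REF_KEYS = ("dpo_chosen_ref_logprobs", "dpo_rejected_ref_logprobs")


def active_channels(inputs: dict, *, alpha_sdpo: float = 0.1,
                    beta_replay: float = 0.05, dpo_variant: str = "dpo") -> dict:
    """Report which loss channels would fire for a given batch dict."""
    keys = set(inputs)
    # A channel is off if its weight is 0 OR its required keys are missing.
    sdpo_on = alpha_sdpo != 0.0 and "ctx_teacher_input_ids" in keys
    needed = DPO_ID_KEYS + (DPO_REF_KEYS if dpo_variant == "dpo" else ())
    replay_on = beta_replay != 0.0 and all(k in keys for k in needed)
    return {"lm_ce": True, "sdpo_jsd": sdpo_on, "trace_replay_dpo": replay_on}


minimal = active_channels({"input_ids": None, "response_mask": None})
```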

Parameters

Name Type Default Meaning
model torch.nn.Module HF causal-LM. Must accept input_ids= and return an object with .logits.
inputs dict[str, torch.Tensor] Batch dict (see required/optional keys above).
alpha_sdpo float 0.1 Weight on SDPO/JSD channel. 0.0 disables.
beta_replay float 0.05 Weight on trace-replay DPO channel. 0.0 disables.
sdpo_jsd_beta float 0.5 β param for generalized_jsd_loss (0=fwd KL, 0.5=JSD, 1=rev KL). Unused when sdpo_wrapper="taid".
sdpo_temperature float 1.0 Softmax temperature in SDPO. Unused when sdpo_wrapper="taid".
sdpo_token_clip float | None None Per-token JSD clamp.
replay_dpo_beta float 0.1 β in standard DPO logit.
lm_ce_label_smoothing float 0.0 F.cross_entropy(label_smoothing=).
dpo_variant Literal["dpo","simpo"] "dpo" Channel-3 algorithm.
sdpo_wrapper Literal["none","taid","entropy_opd"] "none" Channel-2 wrapper.
taid_t float | None None Current TAID interpolation coefficient in [0, 1]. Required when sdpo_wrapper="taid". Drive from TAIDScheduler or pass a fixed value.
simpo_beta float 2.0 SimPO β (paper default).
simpo_gamma float 1.0 SimPO target margin γ (paper default).
entropy_opd_h_max float | None None Max-entropy normalizer; None ⇒ log(V).

Returns LossComponents (see above).

Raises ValueError if dpo_variant or sdpo_wrapper is unknown, if sdpo_wrapper="taid" is requested without taid_t, or if taid_t is outside [0, 1].

from composer_replication import compose_loss, build_batch
batch = build_batch(tokenizer)
out = compose_loss(model, batch, alpha_sdpo=0.1, beta_replay=0.05)
out.total.backward()
print(out.detached())

3. composer_replication.batch

Verification-harness batch builder.

build_batch(tokenizer, *, ...) -> dict[str, torch.Tensor]

def build_batch(
    tokenizer: Any,
    *,
    device: torch.device | str = "cpu",
    seed: int = 42,
    variant: str = "factorial",
    align_sdpo_shapes: bool = False,
) -> dict[str, torch.Tensor]

Construct a full 3-channel batch from a real HF tokenizer. The DPO ref-logprobs are dummy tensors: the smoke test verifies that the loss composition wires together, not the reference-policy precompute.

Returned keys: input_ids, response_mask, ctx_teacher_input_ids, sdpo_loss_mask, dpo_chosen_input_ids, dpo_chosen_response_mask, dpo_rejected_input_ids, dpo_rejected_response_mask, dpo_chosen_ref_logprobs, dpo_rejected_ref_logprobs.

Parameters

Name Type Default Meaning
tokenizer HF AutoTokenizer (duck-typed) Must support apply_chat_template and __call__.
device torch.device | str "cpu" Target device for all returned tensors.
seed int 42 Fixes torch.manual_seed.
variant str "factorial" One of "factorial", "binary_search".
align_sdpo_shapes bool False If True, truncate/pad ctx_teacher_input_ids to input_ids length so the SDPO channel actually fires.

Raises ValueError if variant is unknown.

from transformers import AutoTokenizer
from composer_replication import build_batch
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
batch = build_batch(tok, variant="factorial", align_sdpo_shapes=True)
print({k: v.shape for k, v in batch.items()})

4. composer_replication.opsd

Self-distillation generalized-JSD loss, lifted verbatim from siyan-zhao/OPSD (MIT) per ADR-006.

generalized_jsd_loss(student_logits, teacher_logits, labels=None, beta=0.5, ...) -> torch.Tensor

def generalized_jsd_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor | None = None,
    beta: float = 0.5,
    temperature: float = 1.0,
    reduction: str = "batchmean",
    logits_are_probs: bool = False,
    top_k: int | None = None,
    token_clip: float | None = None,
) -> torch.Tensor

Generalized JSD between student and teacher distributions. In the SDPO recipe both sets of logits come from the same model (same parameters) evaluated on different contexts.

Parameters

Name Type Default Meaning
student_logits Tensor (B, T, V) Student logits with grad.
teacher_logits Tensor (B, T, V) Teacher logits (no grad in SDPO).
labels Tensor (B, T) | None None Per-token mask. -100 positions are ignored (HF convention).
beta float in [0, 1] 0.5 0=fwd KL, 1=rev KL, 0.5=symmetric JSD.
temperature float 1.0 Softmax temperature.
reduction str "batchmean" "batchmean", "sum", "mean", "none".
logits_are_probs bool False Skip softmax if inputs are already probabilities.
top_k int | None None Restrict KL to teacher's top-k tokens.
token_clip float | None None Clip per-token JSD for stability.

Returns scalar tensor (or (B, T) if reduction="none").

Raises ValueError for unknown reduction.

import torch
from composer_replication.opsd import generalized_jsd_loss
s = torch.randn(2, 8, 32, requires_grad=True)
t = torch.randn(2, 8, 32)
loss = generalized_jsd_loss(s, t, beta=0.5, reduction="batchmean")
loss.backward()
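
To make the role of beta concrete, here is a dependency-free sketch for a single token's distribution. It uses one common generalized-JSD formulation (a beta-weighted mixture with the two endpoints special-cased as plain KL); that choice is an assumption, and the ported OPSD code may differ in details such as temperature handling, masking and reduction. `kl` and `generalized_jsd` are illustrative names.

```python
import math


def kl(p: list, q: list) -> float:
    """KL(p || q) for two discrete distributions with q > 0 wherever p > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def generalized_jsd(student: list, teacher: list, beta: float = 0.5) -> float:
    """beta=0 -> forward KL(teacher || student); beta=1 -> reverse KL; 0.5 -> symmetric JSD."""
    if beta == 0.0:
        return kl(teacher, student)
    if beta == 1.0:
        return kl(student, teacher)
    # Mixture of the two distributions, weighted by beta.
    mix = [beta * t + (1.0 - beta) * s for s, t in zip(student, teacher)]
    return beta * kl(teacher, mix) + (1.0 - beta) * kl(student, mix)


p_student = [0.7, 0.2, 0.1]
p_teacher = [0.1, 0.3, 0.6]
jsd = generalized_jsd(p_student, p_teacher, beta=0.5)
```

At beta=0.5 the value is symmetric in its arguments and bounded above by log(2) nats.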

5. composer_replication.distillation

Pluggable self-distillation losses (ADR-007). All pure PyTorch.

simpo_loss(chosen_avg_logprobs, rejected_avg_logprobs, *, beta=2.0, gamma=1.0) -> torch.Tensor

def simpo_loss(
    chosen_avg_logprobs: torch.Tensor,
    rejected_avg_logprobs: torch.Tensor,
    *,
    beta: float = 2.0,
    gamma: float = 1.0,
) -> torch.Tensor

Reference-free DPO with target margin γ (Meng et al., NeurIPS 2024). L = -log σ(β · (avg_logπ(c) − avg_logπ(r)) − γ).

Parameters

Name Type Default Meaning
chosen_avg_logprobs Tensor (B,) Per-sequence avg logprob over chosen response tokens.
rejected_avg_logprobs Tensor (B,) Same for rejected.
beta float 2.0 Scaling factor (paper default).
gamma float 1.0 Target margin (paper default).

Returns a scalar tensor. Raises ValueError if the input shapes mismatch.

import torch
from composer_replication.distillation import simpo_loss
loss = simpo_loss(torch.tensor([-2.1, -1.8]), torch.tensor([-3.0, -2.5]),
                  beta=2.0, gamma=1.0)
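
The SimPO formula above can be checked by hand for a single pair with the sketch below. It is a scalar illustration only (no batching or reduction), and `simpo_pair_loss` is a hypothetical name.

```python
import math


def simpo_pair_loss(chosen_avg_lp: float, rejected_avg_lp: float,
                    beta: float = 2.0, gamma: float = 1.0) -> float:
    """-log(sigmoid(beta * (chosen - rejected) - gamma)) for one pair."""
    z = beta * (chosen_avg_lp - rejected_avg_lp) - gamma
    # -log(sigmoid(z)) == softplus(-z); written in a numerically stable form.
    return max(-z, 0.0) + math.log1p(math.exp(-abs(z)))


small_margin = simpo_pair_loss(-2.1, -3.0)   # z = 2.0 * 0.9 - 1.0 = 0.8
large_margin = simpo_pair_loss(-1.0, -3.0)   # z = 2.0 * 2.0 - 1.0 = 3.0
```

The loss shrinks as the length-normalized margin between chosen and rejected grows, and gamma shifts how large that margin must be before the loss falls below log(2).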

avg_sequence_logprob(model_logprobs, response_mask) -> torch.Tensor

⚠️ UNTESTED-CONTRACT (helper exported from simpo.py but not asserted by a test).

def avg_sequence_logprob(
    model_logprobs: torch.Tensor,
    response_mask: torch.Tensor,
) -> torch.Tensor

Convert (B, T) per-token logprobs + (B, T) response mask into (B,) per-sequence average over response tokens.

from composer_replication.distillation.simpo import avg_sequence_logprob
import torch
lp = torch.randn(2, 8); m = torch.tensor([[0,0,1,1,1,0,0,0],[0,1,1,1,1,1,0,0]])
out = avg_sequence_logprob(lp, m)  # shape (2,)
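
The reduction this helper performs is a masked mean per row. The sketch below shows the arithmetic on nested lists; guarding against an all-zero mask with a denominator of at least 1 is an assumption here, not documented behaviour, and `masked_mean_rows` is a hypothetical name.

```python
def masked_mean_rows(logprobs: list, mask: list) -> list:
    """Average each row of per-token logprobs over positions where mask == 1."""
    out = []
    for lp_row, m_row in zip(logprobs, mask):
        n_tokens = sum(m_row)
        masked_sum = sum(lp * m for lp, m in zip(lp_row, m_row))
        # Assumption: avoid division by zero for rows with no response tokens.
        out.append(masked_sum / max(n_tokens, 1))
    return out


rows = masked_mean_rows([[-1.0, -2.0, -3.0], [-0.5, -0.5, -4.0]],
                        [[0, 1, 1], [1, 1, 0]])
```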

taid_loss(student_logits, teacher_logits, mask=None, *, t) -> torch.Tensor

def taid_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    mask: torch.Tensor | None = None,
    *,
    t: float | torch.Tensor,
) -> torch.Tensor

Faithful port of SakanaAI/TAID (arXiv:2501.16937). Forward-KL distillation against a logit-space-interpolated target whose anchor is the current student detached:

p_t = softmax( (1 - t) · stop_grad(student_logits) + t · teacher_logits )
L   = - mean_token  Σ_v  p_t(v) · log_softmax(student_logits)(v)

At t=0 the target collapses to the detached student (no teacher signal in the gradient). At t=1 it reduces to standard forward-KL distillation against the teacher.
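
The two limiting cases can be verified numerically for a single token. The sketch below implements the two formulas above with plain floats; since there is no autograd here, the stop-gradient has no counterpart and the student logits are simply reused. `softmax`, `log_softmax` and `taid_token_loss` are illustrative names.

```python
import math


def softmax(x: list) -> list:
    m = max(x)
    e = [math.exp(v - m) for v in x]
    z = sum(e)
    return [v / z for v in e]


def log_softmax(x: list) -> list:
    m = max(x)
    lse = m + math.log(sum(math.exp(v - m) for v in x))
    return [v - lse for v in x]


def taid_token_loss(student_logits: list, teacher_logits: list, t: float) -> float:
    """Cross-entropy of the student against the logit-space interpolated target p_t."""
    mixed = [(1.0 - t) * s + t * q for s, q in zip(student_logits, teacher_logits)]
    p_t = softmax(mixed)
    log_q = log_softmax(student_logits)
    return -sum(p * lq for p, lq in zip(p_t, log_q))


s_logits = [1.0, 0.0, -1.0]
q_logits = [0.0, 2.0, 0.0]
loss_t0 = taid_token_loss(s_logits, q_logits, t=0.0)  # target == student distribution
loss_t1 = taid_token_loss(s_logits, q_logits, t=1.0)  # target == teacher distribution
```

At t=0 the value equals the student's own entropy; at t=1 it equals the cross-entropy of the student against the teacher.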

Wave 15 breaking change. The previous signature taid_loss(student, teacher, student_init, *, schedule_step, total_steps, schedule, alpha_min, alpha_max, jsd_beta, temperature, reduction) was algorithmically wrong (probability-space mix, frozen step-0 anchor, JSD criterion). All those kwargs are removed; the schedule is now the caller's responsibility (see TAIDScheduler below for the upstream adaptive scheme).

Parameters

Name Type Default Meaning
student_logits Tensor (B, T, V) Current student (with grad).
teacher_logits Tensor (B, T, V) Teacher logits.
mask Tensor (B, T) | None None Token mask. None ⇒ all-ones.
t float | Tensor Interpolation coefficient in [0, 1].

Raises ValueError for shape mismatch.

from composer_replication.distillation import taid_loss
loss = taid_loss(s_logits, t_logits, mask, t=0.4)

TAIDScheduler(num_train_steps, *, t_start=0.4, t_end=1.0, alpha=5e-4, beta=0.99, disable_adaptive=False)

Stateful schedule that mirrors upstream TAID.update_t. Monotone non-decreasing, bumped above the linear floor by an EMA on the relative loss change. Use as:

from composer_replication.distillation import TAIDScheduler, taid_loss

num_train_steps = 10_000
sched = TAIDScheduler(num_train_steps=num_train_steps)   # paper defaults
for step in range(num_train_steps):
    optimizer.zero_grad()
    # s / t are this step's student and teacher logits, shape (B, T, V)
    loss = taid_loss(s, t, mask, t=sched.t)
    loss.backward()
    optimizer.step()
    sched.update_t(loss.detach(), global_step=step)

Parameters

Name Type Default Meaning
num_train_steps int Total planned training steps; sets the linear floor.
t_start float 0.4 Initial t (paper default).
t_end float 1.0 Terminal t; hard ceiling at every step.
alpha float 5e-4 Adaptive bump magnitude.
beta float 0.99 EMA decay on relative-loss-change momentum.
disable_adaptive bool False If True, fall back to deterministic linear schedule.
device torch.device | str "cpu" Where to allocate state buffers.

Properties / methods

  • sched.t -> float — current t as a Python float (zero-arg property).
  • sched.update_t(loss, global_step) -> Tensor | None — update internal state. First finite-loss call only seeds prev_loss and returns None; subsequent calls return the (positive) delta_t added on top of the linear floor.
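
The behaviour described above (linear floor, adaptive bump, hard ceiling, first call only seeding prev_loss) can be sketched in plain Python. This is an illustration, not the shipped class: the bump formula alpha * sigmoid(momentum) * (1 - t) is an assumption modelled on the description, and `TAIDSchedulerSketch` is a hypothetical name that works on floats rather than tensors.

```python
import math


class TAIDSchedulerSketch:
    def __init__(self, num_train_steps, t_start=0.4, t_end=1.0, alpha=5e-4, beta=0.99):
        self.num_train_steps = num_train_steps
        self.t_start, self.t_end = t_start, t_end
        self.alpha, self.beta = alpha, beta
        self.t = t_start
        self.momentum = 0.0
        self.prev_loss = None

    def update_t(self, loss, global_step):
        if self.prev_loss is None:           # first call only seeds prev_loss
            self.prev_loss = loss
            return None
        rel_change = (self.prev_loss - loss) / (self.prev_loss + 1e-15)
        # EMA on the relative loss change.
        self.momentum = self.beta * self.momentum + (1.0 - self.beta) * rel_change
        delta_t = self.alpha * (1.0 / (1.0 + math.exp(-self.momentum))) * (1.0 - self.t)
        floor = self.t_start + (self.t_end - self.t_start) * global_step / self.num_train_steps
        # Never below the linear floor, never above t_end.
        self.t = min(self.t_end, max(floor, self.t + delta_t))
        self.prev_loss = loss
        return delta_t


sched_sketch = TAIDSchedulerSketch(num_train_steps=100)
history = []
for step in range(100):
    fake_loss = 2.0 * 0.99 ** step           # synthetic, steadily decreasing loss
    sched_sketch.update_t(fake_loss, global_step=step)
    history.append(sched_sketch.t)
```

Because delta_t is non-negative and the floor only rises, t never decreases in this sketch.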

entropy_aware_opd_loss(student_logits, teacher_logits, *, labels=None, h_max=None, temperature=1.0, reduction="batchmean") -> torch.Tensor

def entropy_aware_opd_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    *,
    labels: torch.Tensor | None = None,
    h_max: float | None = None,
    temperature: float = 1.0,
    reduction: str = "batchmean",
) -> torch.Tensor

Per-token mixture of forward and reverse KL gated by teacher entropy: w(t) = clamp(H_teacher(t)/h_max, 0, 1). High-entropy tokens use forward KL (mode-covering), low-entropy tokens use reverse KL (mode-seeking).

Parameters

Name Type Default Meaning
student_logits Tensor (B,T,V) Student logits (grad).
teacher_logits Tensor (B,T,V) Teacher logits (no grad).
labels Tensor (B,T) | None None 0/1 mask, applied multiplicatively after the per-token mix.
h_max float | None None ⇒ log(V) Max-entropy normalizer.
temperature float 1.0 Softmax temperature on both.
reduction str "batchmean" "batchmean", "sum", "mean", "none".

Raises ValueError on shape mismatch (student vs teacher; labels vs per-token loss) or unknown reduction.

from composer_replication.distillation import entropy_aware_opd_loss
loss = entropy_aware_opd_loss(s_logits, t_logits, temperature=1.0)
loss.backward()
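
The entropy gate is easiest to see on a single token. The sketch below computes w and the resulting forward/reverse KL mix for one pair of probability vectors; it skips temperature, masking and reduction, and `entropy_gate` and `gated_kl` are hypothetical names.

```python
import math


def entropy(p: list) -> float:
    """Shannon entropy in nats."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def kl(p: list, q: list) -> float:
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def entropy_gate(teacher_probs: list, h_max=None) -> float:
    """w = clamp(H_teacher / h_max, 0, 1); h_max defaults to log(V)."""
    if h_max is None:
        h_max = math.log(len(teacher_probs))
    return min(max(entropy(teacher_probs) / h_max, 0.0), 1.0)


def gated_kl(student_probs: list, teacher_probs: list, h_max=None) -> float:
    w = entropy_gate(teacher_probs, h_max)
    forward = kl(teacher_probs, student_probs)   # mode-covering
    reverse = kl(student_probs, teacher_probs)   # mode-seeking
    return w * forward + (1.0 - w) * reverse


uniform_gate = entropy_gate([0.25, 0.25, 0.25, 0.25])   # maximal entropy
peaked_gate = entropy_gate([0.97, 0.01, 0.01, 0.01])    # confident teacher
```

A uniform teacher pushes the gate to about 1 (pure forward KL), while a confident teacher pushes it toward 0 (mostly reverse KL).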

teacher_entropy(teacher_logits) -> torch.Tensor

⚠️ UNTESTED-CONTRACT (helper exposed from entropy_aware_opd.py's __all__ but not directly asserted).

Per-token entropy in nats. Input (B,T,V), output (B,T).

from composer_replication.distillation.entropy_aware_opd import teacher_entropy
H = teacher_entropy(teacher_logits)  # (B, T)

6. composer_replication.teacher_replay

N-teacher OpenRouter parallel client + DPO-pair extractor. httpx is lazy-imported inside replay_trace; the deterministic local logic is testable without it.

DEFAULT_TEACHERS: list[TeacherSpec]

Three-teacher default set: anthropic/claude-opus-4.7, openai/gpt-5, deepseek/deepseek-v4-pro with paper-baseline OpenRouter pricing.

from composer_replication.teacher_replay import DEFAULT_TEACHERS
print([t["slug"] for t in DEFAULT_TEACHERS])

class TeacherSpec(TypedDict)

class TeacherSpec(TypedDict):
    slug: str
    input_per_mtok: float
    output_per_mtok: float

OpenRouter model slug + per-million-token pricing.
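
Per-million-token pricing converts to a per-call cost as sketched below. How replay_trace actually derives cost_usd is not stated here, so treat the formula as an assumption; `estimate_cost_usd` is a hypothetical helper.

```python
def estimate_cost_usd(spec: dict, prompt_tokens: int, completion_tokens: int) -> float:
    """Assumed formula: tokens / 1e6 * price-per-million, summed over input and output."""
    return (prompt_tokens / 1_000_000 * spec["input_per_mtok"]
            + completion_tokens / 1_000_000 * spec["output_per_mtok"])


example_spec = {"slug": "openai/gpt-5", "input_per_mtok": 1.25, "output_per_mtok": 10.0}
# 2000 / 1e6 * 1.25 + 300 / 1e6 * 10.0 = 0.0025 + 0.0030 = 0.0055
cost = estimate_cost_usd(example_spec, prompt_tokens=2000, completion_tokens=300)
```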

spec: TeacherSpec = {"slug": "openai/gpt-5",
                     "input_per_mtok": 1.25, "output_per_mtok": 10.0}

class TraceState(TypedDict)

class TraceState(TypedDict):
    state_id: str          # unique within the trace
    messages: list[dict]   # OpenAI-style chat history up to (and incl.) this user prompt
    student_action: str    # what the student actually did at this step

One step of a frozen agentic trace. student_action is the raw text emitted by the student; teachers are queried with messages and asked to predict the assistant's next action.

state: TraceState = {"state_id": "ex001::0042",
                     "messages": [{"role": "user", "content": "..."}],
                     "student_action": "[TOOL_USE] name=Read input={...}"}

class TeacherCallResult(TypedDict)

class TeacherCallResult(TypedDict):
    state_id: str
    teacher_slug: str
    response_text: str | None    # None on error
    latency_s: float
    prompt_tokens: int
    completion_tokens: int
    cost_usd: float
    error: str | None            # None on success

One row of N×T results from replay_trace.

r: TeacherCallResult = {"state_id": "x", "teacher_slug": "openai/gpt-5",
    "response_text": "ok", "latency_s": 1.2, "prompt_tokens": 100,
    "completion_tokens": 5, "cost_usd": 0.001, "error": None}

class DPOPair(TypedDict)

class DPOPair(TypedDict):
    state_id: str
    state_messages: list[dict]
    chosen: str          # teacher-consensus action
    rejected: str        # student action
    n_teachers_agreeing: int

One preference pair extracted from teacher-vs-student disagreement.

p: DPOPair = {"state_id": "x", "state_messages": [...], "chosen": "...",
              "rejected": "...", "n_teachers_agreeing": 2}

async replay_trace(states, teachers=DEFAULT_TEACHERS, max_total_usd=5.0, api_key=None) -> list[TeacherCallResult]

async def replay_trace(
    states: Sequence[TraceState],
    teachers: Sequence[TeacherSpec] = tuple(DEFAULT_TEACHERS),
    max_total_usd: float = 5.0,
    api_key: str | None = None,
) -> list[TeacherCallResult]

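For each state, fans out one parallel call per teacher via OpenRouter. Cumulative spend is hard-capped at max_total_usd; replay stops after the state that crosses the cap completes, so the final total can overshoot the cap by up to one state's worth of calls.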

Parameters

Name Type Default Meaning
states Sequence[TraceState] Frozen trace, one entry per assistant turn.
teachers Sequence[TeacherSpec] DEFAULT_TEACHERS Models to query in parallel.
max_total_usd float 5.0 Cumulative spend cap.
api_key str | None None OpenRouter key; defaults to OPENROUTER_API_KEY env or ~/.hermes/.env.

Returns a flat list of TeacherCallResults. Its length is len(states) * len(teachers), or fewer if the budget cap stops the replay early.

Raises RuntimeError if no OpenRouter API key can be found; ImportError if httpx is missing at call time.

import asyncio
from composer_replication import replay_trace
results = asyncio.run(replay_trace(states=my_trace, max_total_usd=1.0))
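
The fan-out and spend-cap behaviour can be sketched without any network access. In the illustration below the teacher call is faked with a fixed cost, and the cap is checked between states (an assumption consistent with "stops after the offending state completes"); `fake_teacher_call` and `replay_with_cap` are hypothetical names.

```python
import asyncio


async def fake_teacher_call(state_id: str, slug: str, cost: float) -> dict:
    await asyncio.sleep(0)  # stand-in for the HTTP round trip
    return {"state_id": state_id, "teacher_slug": slug, "cost_usd": cost}


async def replay_with_cap(state_ids, teacher_costs, max_total_usd):
    results, spent = [], 0.0
    for state_id in state_ids:
        if spent >= max_total_usd:      # cap is checked before each new state
            break
        # One concurrent call per teacher for this state.
        batch = await asyncio.gather(
            *(fake_teacher_call(state_id, slug, cost) for slug, cost in teacher_costs.items())
        )
        results.extend(batch)
        spent += sum(r["cost_usd"] for r in batch)
    return results, spent


capped_results, capped_spent = asyncio.run(
    replay_with_cap(["s0", "s1", "s2", "s3"], {"a": 0.3, "b": 0.3}, max_total_usd=1.0)
)
```

With these numbers the second state crosses the 1.0 cap, so two of the four states are replayed and the final spend lands above the cap.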

extract_dpo_pairs(states, teacher_actions, agreement_threshold=2) -> list[DPOPair]

def extract_dpo_pairs(
    states: Sequence[TraceState],
    teacher_actions: Sequence[TeacherCallResult],
    agreement_threshold: int = 2,
) -> list[DPOPair]

Group teacher_actions by state_id, normalize whitespace, and emit one DPOPair per state where ≥agreement_threshold teachers agreed on an action that differs from the student's. chosen is the original (un-normalized) teacher response text.

Parameters

Name Type Default Meaning
states Sequence[TraceState] Same as passed to replay_trace.
teacher_actions Sequence[TeacherCallResult] Output of replay_trace.
agreement_threshold int 2 Min teachers that must agree for a pair to fire.

Returns list of DPOPair. At most one pair per state (the most-agreed-upon action wins).

from composer_replication import extract_dpo_pairs
pairs = extract_dpo_pairs(my_states, results, agreement_threshold=2)
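
The grouping and agreement rule described above can be sketched with the standard library. This is an illustration rather than the package's implementation: it assumes errored results (response_text is None) are skipped and that ties are broken by first occurrence; `normalize_ws` and `extract_pairs_sketch` are hypothetical names.

```python
from collections import Counter, defaultdict


def normalize_ws(text: str) -> str:
    """Collapse all runs of whitespace to single spaces."""
    return " ".join(text.split())


def extract_pairs_sketch(states, teacher_actions, agreement_threshold=2):
    by_state = defaultdict(list)
    for result in teacher_actions:
        if result.get("response_text") is not None:   # assumption: skip failed calls
            by_state[result["state_id"]].append(result["response_text"])
    pairs = []
    for state in states:
        responses = by_state.get(state["state_id"], [])
        if not responses:
            continue
        top, n_agree = Counter(normalize_ws(r) for r in responses).most_common(1)[0]
        if n_agree < agreement_threshold or top == normalize_ws(state["student_action"]):
            continue
        # Keep the original, un-normalized text of the winning action.
        chosen = next(r for r in responses if normalize_ws(r) == top)
        pairs.append({"state_id": state["state_id"], "state_messages": state["messages"],
                      "chosen": chosen, "rejected": state["student_action"],
                      "n_teachers_agreeing": n_agree})
    return pairs


demo_states = [
    {"state_id": "s0", "messages": [], "student_action": "ls -la"},
    {"state_id": "s1", "messages": [], "student_action": "cat a.txt"},
]
demo_actions = [
    {"state_id": "s0", "response_text": "cat  a.txt"},
    {"state_id": "s0", "response_text": "cat a.txt"},
    {"state_id": "s0", "response_text": "ls -la"},
    {"state_id": "s1", "response_text": "cat a.txt"},
    {"state_id": "s1", "response_text": "cat a.txt"},
]
demo_pairs = extract_pairs_sketch(demo_states, demo_actions)
```

Only s0 yields a pair: two teachers agree on an action that differs from the student's, while in s1 the consensus matches the student.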

save_pairs(pairs, path) -> None

⚠️ UNTESTED-CONTRACT.

def save_pairs(pairs: Sequence[DPOPair], path: str | Path) -> None

Write pairs to JSONL (one dict per line). Creates parent dirs.

from composer_replication.teacher_replay import save_pairs
save_pairs(pairs, "/tmp/dpo_pairs.jsonl")
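
For readers who want to inspect or reload the output, the sketch below shows a JSONL writer and matching reader. It mirrors the described behaviour (one dict per line, parent directories created) but is not the package's code; `save_pairs_jsonl` and `load_pairs_jsonl` are hypothetical names.

```python
import json
import tempfile
from pathlib import Path


def save_pairs_jsonl(pairs, path) -> None:
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)   # create missing parent dirs
    with path.open("w", encoding="utf-8") as fh:
        for pair in pairs:
            fh.write(json.dumps(pair, ensure_ascii=False) + "\n")


def load_pairs_jsonl(path) -> list:
    with Path(path).open(encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]


sample_pairs = [{"state_id": "x", "state_messages": [], "chosen": "a",
                 "rejected": "b", "n_teachers_agreeing": 2}]
with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "nested" / "dpo_pairs.jsonl"
    save_pairs_jsonl(sample_pairs, target)
    round_tripped = load_pairs_jsonl(target)
```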

7. composer_replication.replaysim

ADR-004 normalization layer over teacher_replay. Re-exports DPOPair, TeacherCallResult, extract_dpo_pairs, replay_trace from teacher_replay.

class NormalizedDPOPair

@dataclass
class NormalizedDPOPair:
    state_id: str
    state_messages: list[dict[str, Any]]
    chosen_messages: list[dict[str, Any]]
    rejected_messages: list[dict[str, Any]]
    n_teachers_agreeing: int
    metadata: dict[str, Any]

Post-normalization shape. chosen_messages/rejected_messages are chat-format ([{"role": "assistant", "content": ...}]). metadata carries op-graph provenance, including {"skipped": True} when the normalizer was bypassed (skip_dj=True).

from composer_replication.replaysim import NormalizedDPOPair
n = NormalizedDPOPair(state_id="x", state_messages=[],
    chosen_messages=[{"role": "assistant", "content": "ok"}],
    rejected_messages=[{"role": "assistant", "content": "no"}],
    n_teachers_agreeing=2, metadata={})

class DJNormalizer

class DJNormalizer:
    DEFAULT_RECIPE: ClassVar[Path]  # composer_replication/recipes/replaysim/default.yaml

    def __init__(
        self,
        recipe_path: str | os.PathLike[str] | None = None,
        *,
        skip_dj: bool = False,
    ) -> None: ...

    def normalize(
        self,
        pairs: Iterable[DPOPair | dict[str, Any]],
    ) -> list[NormalizedDPOPair]: ...

data-juicer-backed normalizer. Pipeline: each DPOPair → JSONL record → data_juicer.core.DefaultExecutor.run() against the recipe → JSONL → NormalizedDPOPair.

Constructor parameters

Name Type Default Meaning
recipe_path str | PathLike | None None ⇒ default recipe data-juicer YAML recipe path.
skip_dj bool (kw-only) False If True: passthrough; records get metadata={"skipped": True} and no ops run.

normalize(pairs) -> list[NormalizedDPOPair] runs the op-graph. Output may be shorter than input if filter ops drop records.

Raises RuntimeError at construction time if skip_dj=False and data_juicer is not importable; FileNotFoundError if recipe_path (default or explicit) is missing and skip_dj=False.

from composer_replication.replaysim import DJNormalizer
norm = DJNormalizer(skip_dj=True)
out = norm.normalize(my_pairs)
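
The skip_dj=True passthrough amounts to a shape conversion from DPOPair to NormalizedDPOPair. The sketch below illustrates that conversion under the field descriptions above; `NormalizedPairSketch` and `passthrough_normalize` are hypothetical stand-ins, and the real normalizer may carry additional metadata.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class NormalizedPairSketch:
    state_id: str
    state_messages: list
    chosen_messages: list
    rejected_messages: list
    n_teachers_agreeing: int
    metadata: dict = field(default_factory=dict)


def passthrough_normalize(pairs) -> list:
    """Wrap chosen/rejected text as single assistant messages; no ops are run."""
    return [
        NormalizedPairSketch(
            state_id=p["state_id"],
            state_messages=p["state_messages"],
            chosen_messages=[{"role": "assistant", "content": p["chosen"]}],
            rejected_messages=[{"role": "assistant", "content": p["rejected"]}],
            n_teachers_agreeing=p["n_teachers_agreeing"],
            metadata={"skipped": True},
        )
        for p in pairs
    ]


normalized = passthrough_normalize([
    {"state_id": "x", "state_messages": [], "chosen": "ok",
     "rejected": "no", "n_teachers_agreeing": 2},
])
```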

async replay_and_normalize_trace(*, states, teachers=None, agreement_threshold=2, max_total_usd=5.0, normalizer=None, **replay_kwargs) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]

async def replay_and_normalize_trace(
    *,
    states: Any,
    teachers: Any = None,
    agreement_threshold: int = 2,
    max_total_usd: float = 5.0,
    normalizer: DJNormalizer | None = None,
    **replay_kwargs: Any,
) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]

End-to-end async: replay → extract pairs → normalize.

Parameters

Name Type Default Meaning
states Sequence[TraceState] Frozen trace.
teachers Sequence[TeacherSpec] | None None ⇒ defaults Forwarded to replay_trace.
agreement_threshold int 2 Forwarded to extract_dpo_pairs.
max_total_usd float 5.0 Spend cap.
normalizer DJNormalizer | None None ⇒ DJNormalizer() Pass DJNormalizer(skip_dj=True) to bypass.
**replay_kwargs Any Forwarded to replay_trace (e.g. api_key).

Returns (raw_teacher_actions, normalized_pairs).

import asyncio
from composer_replication.replaysim import replay_and_normalize_trace, DJNormalizer
raw, norm = asyncio.run(replay_and_normalize_trace(
    states=my_states, normalizer=DJNormalizer(skip_dj=True)))

replay_and_normalize_trace_sync(*args, **kwargs) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]

⚠️ UNTESTED-CONTRACT (sync wrapper around the async function; tests call the async form via asyncio.run).

def replay_and_normalize_trace_sync(*args, **kwargs) -> ...

Sync convenience wrapping asyncio.run(replay_and_normalize_trace(...)).

from composer_replication.replaysim.normalize import replay_and_normalize_trace_sync
raw, norm = replay_and_normalize_trace_sync(states=my_states)

8. composer_replication.ingestion & composer_replication.ingestion.claude_code

Trace-source adapters (ADR-002). v0.1 supports Claude Code session JSONL.

SYSTEM_PROMPT: str

Default synthetic system prompt injected at messages[0] for ingested traces (most Claude Code sessions don't write one). Truncated head: "You are a senior software engineer working as a coding agent in a terminal environment...".

from composer_replication import SYSTEM_PROMPT
print(SYSTEM_PROMPT[:60])

class IngestionStats

@dataclass
class IngestionStats:
    n_records_total: int = 0
    n_records_skipped: int = 0
    n_states_emitted: int = 0
    n_assistant_turns: int = 0
    n_tool_use_blocks: int = 0
    n_text_blocks: int = 0
    skipped_subagent: int = 0
    skipped_summary: int = 0
    skipped_truncated_lines: int = 0
    version_warnings: list[str] | None = None  # initialized to [] in __post_init__

Counters populated by ClaudeCodeIngester.ingest() and exposed as ingester.last_stats.

from composer_replication import IngestionStats
s = IngestionStats(n_records_total=5)
print(s.version_warnings)  # []
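
The None-then-list default on version_warnings is the usual way to avoid sharing one mutable list across dataclass instances. A minimal sketch of the pattern, using a hypothetical `StatsSketch` class with only two of the fields:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StatsSketch:
    n_records_total: int = 0
    version_warnings: Optional[list] = None

    def __post_init__(self) -> None:
        # Give every instance its own list instead of a shared default.
        if self.version_warnings is None:
            self.version_warnings = []


first, second = StatsSketch(), StatsSketch()
first.version_warnings.append("unknown schema version")
```

Appending to one instance's list leaves the other instance's list empty.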

class ClaudeCodeIngester

class ClaudeCodeIngester:
    def __init__(
        self,
        *,
        system_prompt: str = SYSTEM_PROMPT,
        skip_sidechain: bool = True,
        strip_thinking: bool = True,
        max_history_tokens: int | None = None,
    ) -> None: ...

    def ingest(self, path: Path) -> Iterator[TraceState]: ...

Convert a Claude Code session JSONL to a stream of TraceStates — one per assistant TURN (not per tool_use block).

Constructor parameters

Name Type Default Meaning
system_prompt str SYSTEM_PROMPT Synthetic system message injected at history[0].
skip_sidechain bool True Skip subagent files (agent-*.jsonl) and records with isSidechain=True.
strip_thinking bool True Remove [THINKING] blocks from history handed to teachers (kept inside student_action).
max_history_tokens int | None None ⚠️ UNTESTED-CONTRACT — accepted but currently not used to truncate.

ingest(path) -> Iterator[TraceState]: generator over TraceState objects. Each turn's state_id is f"{path.stem}::{idx:04d}". Side effect: replaces self.last_stats with a fresh IngestionStats and updates it as records stream.

from pathlib import Path
from composer_replication import ClaudeCodeIngester
ing = ClaudeCodeIngester()
for state in ing.ingest(Path("session.jsonl")):
    print(state["state_id"])
print(ing.last_stats.n_states_emitted)
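
The state_id format string can be reproduced directly, which helps when joining ingested states back to their source files. `make_state_id` is a hypothetical helper that applies the documented format.

```python
from pathlib import Path


def make_state_id(path: Path, idx: int) -> str:
    """File stem, a double colon, then the turn index zero-padded to 4 digits."""
    return f"{path.stem}::{idx:04d}"


example_id = make_state_id(Path("sessions/ex001.jsonl"), 42)   # "ex001::0042"
```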

9. composer_replication.hint_generator

⚠️ UNTESTED-CONTRACT (entire module — used by the data collator config but not pinned by a test).

Template-based hint registry for SDPO error-site injection.

class HintContext(TypedDict, total=False)

class HintContext(TypedDict, total=False):
    error_kind: str
    error_message: str
    available_tools: list[str]
    tool_name: str
    tool_schema: dict
    intent: str

Per-error context dict consumed by hint templates.

HINT_TEMPLATES: dict[str, Callable[[HintContext], str]]

Default registry keys: "tool_not_found", "json_decode", "type_error", "runtime_error", "repeated_failure".

dispatch(error_kind, ctx=None) -> str | None

def dispatch(error_kind: str, ctx: HintContext | None = None) -> str | None

Look up error_kind in HINT_TEMPLATES. Returns the template's hint text, or None if the kind is unknown.

from composer_replication.hint_generator import dispatch
hint = dispatch("json_decode")  # "Reminder: tool arguments must be valid JSON. ..."

register(error_kind, fn) -> None

def register(error_kind: str, fn: Callable[[HintContext], str]) -> None

Add or override a custom hint template.

from composer_replication.hint_generator import register
register("my_error", lambda ctx: "Reminder: try X.")
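
The dispatch/register pair is a plain dictionary-backed registry. The sketch below shows the pattern with a private dictionary so it does not touch HINT_TEMPLATES; the template text and the names `_registry_sketch`, `register_sketch` and `dispatch_sketch` are hypothetical.

```python
from typing import Callable, Dict, Optional

_registry_sketch: Dict[str, Callable[[dict], str]] = {}


def register_sketch(error_kind: str, fn: Callable[[dict], str]) -> None:
    """Add a template, or override an existing one with the same key."""
    _registry_sketch[error_kind] = fn


def dispatch_sketch(error_kind: str, ctx: Optional[dict] = None) -> Optional[str]:
    """Return the hint text for a known kind, or None for an unknown kind."""
    fn = _registry_sketch.get(error_kind)
    return None if fn is None else fn(ctx or {})


register_sketch(
    "tool_not_found",
    lambda ctx: "Available tools: " + ", ".join(ctx.get("available_tools", [])),
)
known = dispatch_sketch("tool_not_found", {"available_tools": ["Read", "Write"]})
unknown = dispatch_sketch("no_such_kind")
```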

Individual template functions

⚠️ UNTESTED-CONTRACT — exported only via HINT_TEMPLATES, useful as building blocks:

  • hint_tool_not_found(ctx) -> str
  • hint_json_decode(ctx) -> str
  • hint_type_error(ctx) -> str
  • hint_runtime_error(ctx) -> str
  • hint_repeated_failure(ctx) -> str

Each accepts a HintContext and returns hint text. Signatures are uniform: Callable[[HintContext], str].

from composer_replication.hint_generator import hint_tool_not_found
text = hint_tool_not_found({"available_tools": ["Read", "Write"]})

10. composer_replication.trainer & sub-modules

Production trainer (TRL GRPOTrainer subclass) plus data collator.

class ComposerReplicationTrainer

class ComposerReplicationTrainer(GRPOTrainer):
    def __init__(
        self,
        *args: Any,
        alpha_sdpo: float = 0.1,
        beta_replay: float = 0.05,
        sdpo_jsd_beta: float = 0.5,
        sdpo_temperature: float = 1.0,
        sdpo_token_clip: float | None = None,
        replay_dpo_beta: float = 0.1,
        **kwargs: Any,
    ) -> None: ...

    def _compute_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor: ...

trl.GRPOTrainer subclass that overrides _compute_loss(model, inputs) to compose total = grpo + α·sdpo + β·trace_replay_dpo. When trl is not installed, the parent class falls back to object so the module imports — but instantiation will fail because the parent's GRPO machinery is missing.

Constructor (kw-only beyond GRPOTrainer's own *args, **kwargs)

Name Type Default Meaning
alpha_sdpo float 0.1 Channel-2 weight.
beta_replay float 0.05 Channel-3 weight.
sdpo_jsd_beta float 0.5 β for generalized_jsd_loss.
sdpo_temperature float 1.0 SDPO softmax temperature.
sdpo_token_clip float | None None Per-token JSD clip.
replay_dpo_beta float 0.1 DPO β.

_compute_loss(model, inputs) -> torch.Tensor — overrides GRPOTrainer._compute_loss. Calls super()._compute_loss for channel 1, then _compute_sdpo_loss and _compute_trace_replay_loss, then composes. Logs per-channel components every args.logging_steps (default 50). Raises whatever super() raises (TRL-shaped errors).

Internal methods (publicly accessible, exercised by spike tests)

  • ⚠️ UNTESTED-CONTRACT _compute_sdpo_loss(model, inputs) -> torch.Tensor — generalized-JSD between student forward and ctx_teacher_input_ids forward. Returns 0.0 (with grad) when alpha_sdpo == 0, the key is missing, or shapes mismatch. Logs a warning on shape mismatch.
  • ⚠️ UNTESTED-CONTRACT _compute_trace_replay_loss(model, inputs) -> torch.Tensor — standard DPO over dpo_chosen_* and dpo_rejected_*, using precomputed dpo_chosen_ref_logprobs / dpo_rejected_ref_logprobs.
  • ⚠️ UNTESTED-CONTRACT @staticmethod _sequence_logprobs(model, input_ids, response_mask) -> torch.Tensor — sum logprobs over response tokens; standard DPO accounting.

from composer_replication import ComposerReplicationTrainer
trainer = ComposerReplicationTrainer(
    model=my_model, args=my_grpo_args, train_dataset=ds,
    data_collator=my_collator, alpha_sdpo=0.1, beta_replay=0.05,
)
# trainer.train()  # uses overridden _compute_loss
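
For reference, the standard DPO objective used by the trace-replay channel can be written for a single pair as below. This is the textbook formula with precomputed reference logprobs, offered as a sketch; the trainer's batching and reduction are not shown, and `dpo_pair_loss` is a hypothetical name.

```python
import math


def dpo_pair_loss(policy_chosen_lp: float, policy_rejected_lp: float,
                  ref_chosen_lp: float, ref_rejected_lp: float,
                  beta: float = 0.1) -> float:
    """-log(sigmoid(beta * ((pi_c - ref_c) - (pi_r - ref_r)))) on summed sequence logprobs."""
    z = beta * ((policy_chosen_lp - ref_chosen_lp) - (policy_rejected_lp - ref_rejected_lp))
    # Numerically stable softplus(-z).
    return max(-z, 0.0) + math.log1p(math.exp(-abs(z)))


at_reference = dpo_pair_loss(-10.0, -12.0, -10.0, -12.0)   # policy == reference
improved = dpo_pair_loss(-8.0, -14.0, -10.0, -12.0)        # policy prefers chosen more
```

When the policy matches the reference the loss is log(2); it drops as the policy raises the chosen sequence relative to the rejected one.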

class TraceTurn(TypedDict, total=False) (in trainer.data_collator)

class TraceTurn(TypedDict, total=False):
    role: str                # "user" | "assistant" | "tool"
    content: str
    tool_call: dict | None
    tool_error: str | None
    error_meta: dict

One turn of an agentic trace as consumed by ComposerDataCollator.

class TraceExample(TypedDict, total=False) — trainer.data_collator

class TraceExample(TypedDict, total=False):
    trace_id: str
    turns: list[TraceTurn]
    final_reward: float
    dpo_pairs: list[dict] | None

One training example: (turns, optional dpo_pairs). dpo_pairs shape matches DPOPair.
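
As an illustration, the sketch below mirrors the two TypedDicts locally (so it runs without the package installed) and builds one trace containing a failed tool call. The error_sites helper and the field values are hypothetical; they only show how the optional keys fit together.

```python
from typing import List, Optional, TypedDict

class TraceTurn(TypedDict, total=False):      # local mirror of the documented type
    role: str
    content: str
    tool_call: Optional[dict]
    tool_error: Optional[str]
    error_meta: dict

class TraceExample(TypedDict, total=False):   # local mirror of the documented type
    trace_id: str
    turns: List[TraceTurn]
    final_reward: float
    dpo_pairs: Optional[List[dict]]

def error_sites(example: TraceExample) -> List[int]:
    """Indices of turns that carry a non-empty tool_error (hypothetical helper)."""
    return [i for i, turn in enumerate(example.get("turns", []))
            if turn.get("tool_error")]

example: TraceExample = {
    "trace_id": "demo-001",
    "turns": [
        {"role": "user", "content": "Print the config file."},
        {"role": "assistant", "content": "",
         "tool_call": {"name": "read_file", "args": {"path": "cfg.toml"}}},
        {"role": "tool", "content": "", "tool_error": "file_not_found",
         "error_meta": {"path": "cfg.toml"}},
    ],
    "final_reward": 0.0,
    "dpo_pairs": None,
}
sites = error_sites(example)
```

Because both types use total=False, every key is optional, so consumers should read fields with .get() rather than indexing.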

class TokenizerLike — trainer.data_collator

⚠️ UNTESTED-CONTRACT (duck-typed protocol; used as a type hint).

class TokenizerLike:
    pad_token_id: int
    def __call__(self, text: str | list[str], **kwargs: Any) -> dict[str, list]: ...
    def apply_chat_template(self, messages: list[dict], **kwargs: Any) -> str | list[int]: ...

Minimal protocol the collator needs. Compatible with HF AutoTokenizer.
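
For tests that should not download a real tokenizer, a toy object with the same three members is enough to satisfy the duck-typed contract. The class below is a hypothetical stand-in, not part of the package; its vocabulary handling and chat rendering are deliberately simplistic.

```python
from typing import Any, Dict, List, Union

class WhitespaceTokenizer:
    """Toy tokenizer: one id per whitespace-separated token; id 0 is padding."""
    pad_token_id = 0

    def __init__(self) -> None:
        self._vocab: Dict[str, int] = {"<pad>": 0}

    def _encode(self, text: str) -> List[int]:
        # Grows the vocabulary on the fly; a real tokenizer has a fixed one.
        return [self._vocab.setdefault(tok, len(self._vocab))
                for tok in text.split()]

    def __call__(self, text: Union[str, List[str]], **kwargs: Any) -> Dict[str, list]:
        if isinstance(text, str):
            ids = self._encode(text)
            return {"input_ids": ids, "attention_mask": [1] * len(ids)}
        batch = [self._encode(t) for t in text]
        return {"input_ids": batch,
                "attention_mask": [[1] * len(ids) for ids in batch]}

    def apply_chat_template(self, messages: List[dict], **kwargs: Any):
        rendered = "\n".join(f"<{m['role']}> {m['content']}" for m in messages)
        # Mirrors the str | list[int] return type via an assumed tokenize flag.
        return self._encode(rendered) if kwargs.get("tokenize", False) else rendered

tok = WhitespaceTokenizer()
encoded = tok("read the file the")
chat = tok.apply_chat_template([{"role": "user", "content": "hi"}])
```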

class CollatorConfig — trainer.data_collator

@dataclass
class CollatorConfig:
    max_seq_len: int = 4096
    max_dpo_seq_len: int = 2048
    pad_token_id: int = 0
    ignore_index: int = -100
    enable_sdpo: bool = True
    hint_generator: Callable[[str, dict], str | None] | None = None
    enable_replay_dpo: bool = True
    rlvr_reward_key: str = "final_reward"

Tunables for ComposerDataCollator.

| Field | Default | Meaning |
|---|---|---|
| max_seq_len | 4096 | Truncation cap for student/teacher sequences. |
| max_dpo_seq_len | 2048 | Truncation cap for DPO chosen/rejected sequences. |
| pad_token_id | 0 | Padding token id. |
| ignore_index | -100 | HF "ignore in loss" sentinel for SDPO mask. |
| enable_sdpo | True | Toggle channel-2 fields. |
| hint_generator | None | Callable[[str, dict], str \| None] mapping (error_kind, error_meta) -> hint_text. SDPO is a no-op without it. |
| enable_replay_dpo | True | Toggle channel-3 fields. |
| rlvr_reward_key | "final_reward" | Key in TraceExample to read scalar reward. |

from composer_replication.trainer.data_collator import CollatorConfig
cfg = CollatorConfig(max_seq_len=2048, hint_generator=my_dispatch)
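
One possible shape for the my_dispatch callable used above is a small template lookup keyed by error kind. The error kinds, templates, and metadata keys here are hypothetical; only the (error_kind, error_meta) -> str | None signature comes from CollatorConfig.

```python
from typing import Optional

# Hypothetical hint templates keyed by error kind.
_HINTS = {
    "file_not_found": "The path {path} does not exist; list the directory before reading.",
    "timeout": "The command exceeded {seconds}s; try a narrower invocation.",
}

def my_dispatch(error_kind: str, error_meta: dict) -> Optional[str]:
    """Return a hint string, or None when no hint applies."""
    template = _HINTS.get(error_kind)
    if template is None:
        return None                      # unknown error kind: no hint
    try:
        return template.format(**error_meta)
    except KeyError:
        return None                      # metadata lacks a field the template needs

hint = my_dispatch("file_not_found", {"path": "cfg.toml"})
```

Returning None for unknown kinds or incomplete metadata keeps the generator safe to call on every error site, leaving it to the collator to decide what to do when no hint is available.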

class ComposerDataCollator — trainer.data_collator

@dataclass
class ComposerDataCollator:
    tokenizer: TokenizerLike
    config: CollatorConfig = field(default_factory=CollatorConfig)

    def __call__(
        self, batch: Sequence[TraceExample]
    ) -> dict[str, torch.Tensor]: ...

Build trainer-ready batches from raw traces + optional DPO pairs.

Output dict keys (tested in spikes/005-integrated-trainer-skeleton/tests/test_data_collator.py):

  • Channel 1 (always): input_ids, attention_mask, response_mask, rewards.
  • Channel 2 (when enable_sdpo=True AND batch has at least one error site AND hint_generator is set): ctx_teacher_input_ids, sdpo_loss_mask.
  • Channel 3 (when enable_replay_dpo=True AND batch has at least one dpo_pair): dpo_chosen_input_ids, dpo_chosen_response_mask, dpo_rejected_input_ids, dpo_rejected_response_mask. (Reference logprobs are NOT computed here — the trainer does that pass.)

from composer_replication.trainer.data_collator import (
    ComposerDataCollator, CollatorConfig)
collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
batch = collator([{"trace_id": "x", "turns": [...], "final_reward": 1.0}])
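
The channel-1 keys can be pictured with a small padding routine over plain lists. This sketch assumes right-padding and head truncation, neither of which is confirmed by the text above, and pad_batch is a hypothetical helper rather than the collator's implementation.

```python
def pad_batch(token_ids, response_flags, pad_token_id=0, max_seq_len=4096):
    """Truncate to max_seq_len, then right-pad every row to a common width."""
    rows = [ids[:max_seq_len] for ids in token_ids]
    flags = [f[:max_seq_len] for f in response_flags]
    width = max(len(ids) for ids in rows)
    batch = {"input_ids": [], "attention_mask": [], "response_mask": []}
    for ids, row_flags in zip(rows, flags):
        pad = width - len(ids)
        batch["input_ids"].append(ids + [pad_token_id] * pad)
        batch["attention_mask"].append([1] * len(ids) + [0] * pad)
        # Padding positions never count as response tokens.
        batch["response_mask"].append(row_flags + [0] * pad)
    return batch

padded = pad_batch(
    token_ids=[[5, 6, 7, 8], [9, 10]],
    response_flags=[[0, 0, 1, 1], [0, 1]],
)
```

In the real collator these lists would be torch tensors; the point here is only that response_mask is zero wherever attention_mask is zero.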

11. composer_replication.diloco

DiLoCo outer-loop wrapper around torchft.local_sgd.DiLoCo. Optional dep — when torchft is missing the package re-export composer_replication.make_diloco_outer_loop is None.

Module-level attributes

  • DiLoCo: Any — torchft.local_sgd.DiLoCo if importable else None.
  • Manager: Any — torchft.manager.Manager if importable else None.
  • _DummyWork: Any — torchft.work._DummyWork if importable else None.
  • _TORCHFT_AVAILABLE: bool — whether the imports succeeded.

from composer_replication.diloco import _TORCHFT_AVAILABLE, DiLoCo

make_diloco_outer_loop(manager, model_fragments, inner_optimizer, *, ...) -> torchft.local_sgd.DiLoCo

def make_diloco_outer_loop(
    manager: Any,
    model_fragments: list[torch.nn.Module],
    inner_optimizer: torch.optim.Optimizer,
    *,
    outer_lr: float = 0.7,
    outer_momentum: float = 0.9,
    nesterov: bool = True,
    sync_every: int = 100,
    fragment_sync_delay: int = 0,
    fragment_update_alpha: float = 0.0,
) -> Any

Constructs a torchft.local_sgd.DiLoCo configured with the framework-default hyperparameters (DiLoCo paper §3.2: lr=0.7, momentum=0.9, Nesterov).

Parameters

| Name | Type | Default | Meaning |
|---|---|---|---|
| manager | torchft.Manager (or duck-typed MockManager) | (required) | Provides allreduce, should_commit, current_step, start_quorum, etc. |
| model_fragments | list[torch.nn.Module] | (required) | One module for vanilla DiLoCo; N modules for Streaming DiLoCo. |
| inner_optimizer | torch.optim.Optimizer | (required) | Inner-step optimizer (steps every batch). |
| outer_lr | float | 0.7 | Outer SGD lr. |
| outer_momentum | float | 0.9 | Outer SGD momentum. |
| nesterov | bool | True | Nesterov momentum on outer SGD. |
| sync_every | int | 100 | Inner steps per outer round. |
| fragment_sync_delay | int | 0 | 0 = vanilla; >0 = Streaming DiLoCo (requires CUDA streams). |
| fragment_update_alpha | float | 0.0 | 0 = full replacement on sync; >0 = exponential mix. |

Returns a torchft.local_sgd.DiLoCo instance — usable as a context manager.

Raises RuntimeError if torchft is not installed.

import torch
from composer_replication.diloco import make_diloco_outer_loop

# model, mgr, batches, and compute_loss are supplied by the caller.
opt = torch.optim.AdamW(model.parameters(), lr=1e-5)
outer = make_diloco_outer_loop(manager=mgr, model_fragments=[model],
                               inner_optimizer=opt, sync_every=100)
with outer:
    for batch in batches:
        opt.zero_grad()
        loss = compute_loss(model, batch)  # recompute the loss on every inner step
        loss.backward()
        opt.step()
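
What one outer round does to the parameters can be sketched in scalar Python. The sketch assumes the pseudo-gradient is defined as global minus local parameters, averaged over replicas, and that the outer optimizer follows the usual SGD-with-Nesterov update; outer_step is a hypothetical helper, not torchft's implementation.

```python
def outer_step(global_params, replica_params, momentum_buf,
               outer_lr=0.7, outer_momentum=0.9, nesterov=True):
    """One DiLoCo-style outer update over flat lists of floats."""
    n = len(replica_params)
    new_params, new_buf = [], []
    for i, theta in enumerate(global_params):
        # Averaged pseudo-gradient (assumed sign: global minus local).
        pseudo_grad = sum(theta - local[i] for local in replica_params) / n
        buf = outer_momentum * momentum_buf[i] + pseudo_grad
        step = pseudo_grad + outer_momentum * buf if nesterov else buf
        new_params.append(theta - outer_lr * step)
        new_buf.append(buf)
    return new_params, new_buf

# Two replicas drifted from a shared starting point of 1.0.
averaged, _ = outer_step([1.0], [[0.0], [0.5]], [0.0],
                         outer_lr=1.0, outer_momentum=0.0, nesterov=False)
defaults, buf = outer_step([1.0], [[0.0], [0.5]], [0.0])
```

With outer_lr=1.0 and no momentum the update reduces to plain parameter averaging, which shows why the framework defaults behave like an accelerated average rather than a fresh optimizer.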

12. composer_replication.diloco.serverless

ADR-005 serverless DiLoCo executors + object-store all-reduce.

class ReplicaHandle — serverless.executor

@dataclass
class ReplicaHandle:
    rank: int
    backend_name: str
    metadata: dict[str, Any] = field(default_factory=dict)

Opaque handle returned by ServerlessExecutor.launch_replicas. metadata is backend-specific.

from composer_replication.diloco.serverless import ReplicaHandle
h = ReplicaHandle(rank=0, backend_name="local_process",
                  metadata={"pid": 12345})

class ServerlessExecutor (Protocol) — serverless.executor

@runtime_checkable
class ServerlessExecutor(Protocol):
    backend_name: str
    supports_inter_replica_network: bool

    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]: ...

    def poll(self, handle: ReplicaHandle) -> str: ...
    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str: ...
    def cancel(self, handle: ReplicaHandle) -> None: ...
    def collect(
        self, handles: list[ReplicaHandle], *, timeout: int | None = None,
    ) -> list[dict[str, Any]]: ...

Structural protocol for serverless backends.

  • launch_replicas(...) returns list[ReplicaHandle] of length n_replicas in rank order. entrypoint is one of: an importable module path (its main() is used), a module.function path, or a Callable (LocalProcessExecutor only). entrypoint_args may include rank_env (default "REPLICA_RANK").
  • poll(handle) -> str: one of "pending", "running", "succeeded", "failed", "cancelled".
  • stream_logs(handle, n_lines=200) -> str: best-effort recent stdout/stderr.
  • cancel(handle) -> None: best-effort.
  • collect(handles, timeout=None) -> list[dict]: blocks; each result dict has rank, status, exit_code, and error, plus result when produced by LocalProcessExecutor.

from composer_replication.diloco.serverless import ServerlessExecutor
def supports(x: ServerlessExecutor) -> bool:
    return isinstance(x, ServerlessExecutor)  # runtime_checkable
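
The structural check and the poll contract can be exercised without any backend. The sketch below defines a reduced protocol and a scripted fake executor (both hypothetical, covering only poll and cancel) and a wait loop that stops on the three terminal statuses listed above.

```python
import time
from typing import Any, Protocol, runtime_checkable

TERMINAL = {"succeeded", "failed", "cancelled"}

@runtime_checkable
class PollableExecutor(Protocol):
    """Reduced, hypothetical subset of the ServerlessExecutor surface."""
    backend_name: str
    def poll(self, handle: Any) -> str: ...
    def cancel(self, handle: Any) -> None: ...

class FakeExecutor:
    backend_name = "fake"

    def __init__(self, script):
        self._script = list(script)

    def poll(self, handle):
        # Advance through the scripted statuses, then repeat the last one.
        if len(self._script) > 1:
            return self._script.pop(0)
        return self._script[0]

    def cancel(self, handle):
        self._script = ["cancelled"]

def wait_for(executor, handle, timeout_s=5.0, poll_interval_s=0.0):
    """Poll until a terminal status; cancel once the deadline passes."""
    deadline = time.monotonic() + timeout_s
    while True:
        status = executor.poll(handle)
        if status in TERMINAL:
            return status
        if time.monotonic() >= deadline:
            executor.cancel(handle)
            return executor.poll(handle)
        time.sleep(poll_interval_s)

final = wait_for(FakeExecutor(["pending", "running", "succeeded"]), handle=None)
```

Note that isinstance against a runtime_checkable protocol only checks that the members exist; it does not validate signatures or return types.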

class LocalProcessExecutor — serverless.executor

class LocalProcessExecutor:
    backend_name = "local_process"
    supports_inter_replica_network = True

    def __init__(self) -> None: ...
    # implements ServerlessExecutor protocol

Reference implementation using Python multiprocessing (spawn context). Used for tests, CI smokes, and local development with file:// rendezvous.

launch_replicas(...): emits a soft warning when gpu is not None (local processes share whatever GPUs are visible). Each handle's metadata is {"pid": ..., "start_ts": ...}.

from composer_replication.diloco.serverless import LocalProcessExecutor
ex = LocalProcessExecutor()
handles = ex.launch_replicas(
    n_replicas=2,
    entrypoint="composer_replication.diloco.serverless.replica_entrypoint",
    entrypoint_args={"rendezvous_uri": "/tmp/run/", "world_size": 2,
                     "trainer_module": "my.trainer"},
)
results = ex.collect(handles, timeout=60)
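
How a string entrypoint might be turned into a callable is sketched below, following the three accepted forms (module path using main(), module.function path, or a Callable). resolve_entrypoint is a hypothetical helper written against the standard library; the executor's own resolution logic may differ.

```python
import importlib
from typing import Any, Callable, Union

def resolve_entrypoint(entrypoint: Union[str, Callable[..., Any]]) -> Callable[..., Any]:
    """Map an entrypoint spec to a callable."""
    if callable(entrypoint):
        return entrypoint
    try:
        # Form 1: importable module path, use its main().
        module = importlib.import_module(entrypoint)
        return getattr(module, "main")
    except ModuleNotFoundError:
        # Form 2: module.function path, split off the last dotted component.
        module_path, _, attr = entrypoint.rpartition(".")
        if not module_path:
            raise
        return getattr(importlib.import_module(module_path), attr)

# Standard-library stand-ins for a trainer module and a trainer function.
from_module = resolve_entrypoint("timeit")        # timeit.main
from_function = resolve_entrypoint("json.dumps")  # json.dumps
passthrough = resolve_entrypoint(len)
```

A module that imports cleanly but has no main attribute raises AttributeError here, which surfaces a misconfigured entrypoint early instead of inside a spawned replica.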

class ObjectStoreAllReduce — serverless.allreduce

class ObjectStoreAllReduce:
    def __init__(
        self,
        uri: str,
        rank: int,
        world_size: int,
        *,
        round_id: int | None = None,
        timeout_s: float = 1800.0,
        poll_interval_s: float = 1.0,
    ) -> None: ...

    @property
    def round_id(self) -> int: ...

    def allreduce(
        self, tensor: torch.Tensor, *, name: str | None = None,
    ) -> torch.Tensor: ...

fsspec-backed pseudo-gradient rendezvous. uri accepts s3://, gs://, az://, hf://, file://, or a plain local path.

Constructor parameters

| Name | Type | Default | Meaning |
|---|---|---|---|
| uri | str | (required) | fsspec URI or local path. Trailing / enforced. |
| rank | int | (required) | This replica's rank. |
| world_size | int | (required) | Total replicas. |
| round_id | int \| None (kw-only) | None ⇒ start at 0 | Initial round counter. |
| timeout_s | float (kw-only) | 1800.0 | Per-allreduce timeout, in seconds. |
| poll_interval_s | float (kw-only) | 1.0 | Sleep between peer-file existence checks, in seconds. |

allreduce(tensor, name=None) -> torch.Tensor: serializes tensor.detach().cpu() to round_NNNNNN/rank_RRRR.pt, blocks until all peers post, then averages. Modifies tensor in place AND returns it. Increments the internal _round_counter.

Raises ValueError on invalid rank, RuntimeError if non-local URI is requested without fsspec installed, TimeoutError if peers don't show up before timeout_s.

from composer_replication.diloco.serverless import ObjectStoreAllReduce
import torch
store = ObjectStoreAllReduce("/tmp/run/", rank=0, world_size=2)
g = torch.zeros(10)
store.allreduce(g)  # blocks for rank 1
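
The rendezvous pattern (post your file, wait for every peer's file, then average) can be sketched with the standard library alone. This version uses JSON files on the local filesystem and two threads standing in for replicas; the real class serializes tensors to .pt files through fsspec, and file_allreduce is a hypothetical name.

```python
import json
import os
import tempfile
import threading
import time

def file_allreduce(values, root, rank, world_size, round_id=0,
                   timeout_s=10.0, poll_interval_s=0.01):
    """Average a list of floats across ranks via files under root."""
    round_dir = os.path.join(root, f"round_{round_id:06d}")
    os.makedirs(round_dir, exist_ok=True)
    final = os.path.join(round_dir, f"rank_{rank:04d}.json")
    tmp = final + ".tmp"
    with open(tmp, "w") as fh:
        json.dump(list(values), fh)
    os.replace(tmp, final)  # publish atomically so peers never read a partial file
    peers = [os.path.join(round_dir, f"rank_{r:04d}.json")
             for r in range(world_size)]
    deadline = time.monotonic() + timeout_s
    while not all(os.path.exists(p) for p in peers):
        if time.monotonic() > deadline:
            raise TimeoutError(f"round {round_id}: peers missing after {timeout_s}s")
        time.sleep(poll_interval_s)
    total = [0.0] * len(values)
    for path in peers:
        with open(path) as fh:
            for i, v in enumerate(json.load(fh)):
                total[i] += v
    values[:] = [t / world_size for t in total]  # in place, like the documented method
    return values

root = tempfile.mkdtemp()
inputs = {0: [1.0, 2.0], 1: [3.0, 6.0]}
threads = [threading.Thread(target=file_allreduce, args=(inputs[r], root, r, 2))
           for r in (0, 1)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Writing to a temporary name and renaming is the design choice that matters: existence of the final file then implies the payload is complete, so a simple existence poll is sufficient.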

class MockManager — serverless.allreduce

class MockManager:
    def __init__(self, store: ObjectStoreAllReduce) -> None: ...

    # torchft.Manager-shaped surface:
    num_participants: int
    rank: int
    _use_async_quorum: bool        # always False
    _step: int
    _state_dict_fns: dict[str, tuple[Any, Any]]

    def allreduce(self, tensor: torch.Tensor, **_kwargs: Any) -> "_ImmediateWork": ...
    def should_commit(self) -> bool: ...
    def start_quorum(self) -> None: ...
    def wait_quorum(self) -> int: ...
    def current_step(self) -> int: ...
    def allow_state_dict_read(self) -> None: ...
    def disallow_state_dict_read(self) -> None: ...
    def register_state_dict_fn(self, key: str, load_fn: Any, save_fn: Any) -> None: ...
    def is_leader(self) -> bool: ...

Drop-in replacement for torchft.Manager that routes allreduce through ObjectStoreAllReduce. All other methods are no-ops or simple counters appropriate for single-shot serverless DiLoCo.

  • allreduce(tensor) returns an _ImmediateWork whose .wait() is a no-op (the tensor is already averaged).
  • should_commit() always True (no fault-tolerance failover).
  • start_quorum() bumps _step.
  • is_leader() returns rank == 0.

from composer_replication.diloco.serverless import MockManager, ObjectStoreAllReduce
store = ObjectStoreAllReduce("/tmp/run/", rank=0, world_size=2)
mgr = MockManager(store)
# pass mgr into make_diloco_outer_loop(manager=mgr, ...)
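
The behaviours listed above fit in a few lines. The stub below is a hypothetical, torch-free illustration: it uses concurrent.futures.Future where the real helper returns a torch.futures.Future, and it takes any averaging callable in place of an ObjectStoreAllReduce.

```python
from concurrent.futures import Future

class ImmediateWork:
    """Work-shaped wrapper around a result that is already available."""
    def __init__(self, value):
        self._future = Future()
        self._future.set_result(value)

    def wait(self):
        return True          # nothing to wait for

    def get_future(self):
        return self._future

class SingleShotManager:
    """Manager-shaped stub: synchronous allreduce, always commits."""
    def __init__(self, allreduce_fn, rank, world_size):
        self._allreduce_fn = allreduce_fn
        self.rank = rank
        self.num_participants = world_size
        self._step = 0

    def allreduce(self, values, **_kwargs):
        return ImmediateWork(self._allreduce_fn(values))

    def should_commit(self):
        return True          # no fault-tolerance failover in this model

    def start_quorum(self):
        self._step += 1

    def current_step(self):
        return self._step

    def is_leader(self):
        return self.rank == 0

mgr = SingleShotManager(lambda v: [x / 2 for x in v], rank=0, world_size=2)
work = mgr.allreduce([2.0, 4.0])
mgr.start_quorum()
```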

class _ImmediateWork — serverless.allreduce

⚠️ UNTESTED-CONTRACT internal helper exported from __all__. Work-shaped wrapper with .wait() -> True and .get_future() -> torch.futures.Future. Consumed by torchft DiLoCo's perform_sync.

from composer_replication.diloco.serverless.allreduce import _ImmediateWork

class ModalExecutor — serverless.modal

🟡 SKELETON — raises NotImplementedError; see ADR-005. Class body documents the v0 implementation pattern (Modal app.function + function.spawn(rank=...)).

from composer_replication.diloco.serverless.modal import ModalExecutor
# ModalExecutor()  # would raise NotImplementedError when instantiated

class HFJobsExecutor — serverless.hf_jobs

🟡 SKELETON — raises NotImplementedError; see ADR-005. Class body documents the v0 pattern using huggingface_hub.run_job against hf://datasets/.../ rendezvous.

from composer_replication.diloco.serverless.hf_jobs import HFJobsExecutor
# instantiation will fail until v0 implementation lands

replica_entrypoint.main(...) — serverless.replica_entrypoint

def main(
    rendezvous_uri: str,
    world_size: int,
    trainer_module: str,
    trainer_fn: str = "train",
    trainer_kwargs: dict[str, Any] | None = None,
) -> Any

Script run by every replica. Reads REPLICA_RANK env var, builds ObjectStoreAllReduce + MockManager, imports trainer_module, and calls getattr(mod, trainer_fn)(**trainer_kwargs, manager=..., rank=..., world_size=...). Returns whatever the train fn returns.

Raises RuntimeError if REPLICA_RANK env var is missing; ValueError if rank ∉ [0, world_size).

The if __name__ == "__main__" block accepts CLI flags --rendezvous, --world-size, --trainer-module, --trainer-fn, --trainer-kwargs-json.

# In-process invocation
import os
os.environ["REPLICA_RANK"] = "0"
from composer_replication.diloco.serverless.replica_entrypoint import main
result = main(rendezvous_uri="/tmp/run/", world_size=1,
              trainer_module="my.trainer", trainer_fn="train")
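
The rank handling and the CLI surface described above can be sketched as follows. Which flags are required, their defaults, and the helper names read_rank and build_parser are assumptions made for illustration; only the flag names and the two error types come from the description.

```python
import argparse
import json
import os

def read_rank(world_size, env=None, rank_env="REPLICA_RANK"):
    """Read and validate the replica rank from the environment."""
    env = os.environ if env is None else env
    raw = env.get(rank_env)
    if raw is None:
        raise RuntimeError(f"{rank_env} is not set")
    rank = int(raw)
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is outside [0, {world_size})")
    return rank

def build_parser():
    """Parser exposing the five documented flags (defaults are assumed)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--rendezvous", required=True)
    parser.add_argument("--world-size", type=int, required=True)
    parser.add_argument("--trainer-module", required=True)
    parser.add_argument("--trainer-fn", default="train")
    # The JSON string is decoded into a dict of keyword arguments.
    parser.add_argument("--trainer-kwargs-json", type=json.loads, default="{}")
    return parser

args = build_parser().parse_args([
    "--rendezvous", "/tmp/run/",
    "--world-size", "2",
    "--trainer-module", "my.trainer",
    "--trainer-kwargs-json", '{"lr": 0.001}',
])
rank = read_rank(args.world_size, env={"REPLICA_RANK": "1"})
```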

13. composer_replication.recipes.prime_rl.composer_loss

PRIME-RL adapter (ADR-006). Maps PRIME-RL's LossInputs struct onto channel 1 (DPPO + KL on the importance ratio, mirroring PRIME-RL's upstream default_loss_fn at prime_rl/trainer/rl/loss.py lines 116-165). Channel 2 raises NotImplementedError; channel 3 is out of scope.

loss_fn(inputs, *, alpha_sdpo=0.0, beta_dpo=0.0, dppo_mask_high=0.2, dppo_mask_low=0.2, adv_tau=1.0, kl_tau=1e-3) -> torch.Tensor

def loss_fn(
    inputs: Any,  # PRIME-RL's LossInputs (duck-typed)
    *,
    alpha_sdpo: float = 0.0,
    beta_dpo: float = 0.0,
    dppo_mask_high: float = 0.2,
    dppo_mask_low: float = 0.2,
    adv_tau: float = 1.0,
    kl_tau: float = 1e-3,
) -> Any  # torch.Tensor scalar

PRIME-RL passes per-sample 1-D (seq,) tensors (not batched). The function mirrors PRIME-RL's upstream DPPO+KL formula:

  • Mask gate is on probability-space probs_diff = exp(trainer_lp) - exp(inference_lp) (NOT on the log-ratio).
  • A token is dropped iff its advantage sign matches the offending bound: positive-advantage tokens are dropped when probs_diff > dppo_mask_high, negative-advantage tokens when probs_diff < -dppo_mask_low. (PRIME-RL stores both bounds with Field(..., ge=0) and applies the sign internally.)
  • The PG term is keep * (adv_tau * advantages) * exp(trainer_lp - inference_lp) (importance-ratio corrected, not REINFORCE).
  • A KL penalty kl_tau * log_importance_ratio**2 is added on the full loss_mask (DPPO masking does not gate it).
  • Reduction is a plain sum(); PRIME-RL's outer compute_loss divides by loss_scale.
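
The bullets above translate into the following per-sample sketch over plain Python lists. It assumes the policy-gradient term enters the loss negated (so that minimizing the loss increases the objective), which the text does not state explicitly; dppo_kl_loss is a hypothetical name and this is not PRIME-RL's code.

```python
import math

def dppo_kl_loss(trainer_lp, inference_lp, advantages, loss_mask,
                 dppo_mask_high=0.2, dppo_mask_low=0.2,
                 adv_tau=1.0, kl_tau=1e-3):
    """Summed DPPO + KL loss for one sample (1-D inputs of equal length)."""
    total = 0.0
    for t_lp, i_lp, adv, masked_in in zip(trainer_lp, inference_lp,
                                          advantages, loss_mask):
        if not masked_in:
            continue
        # The gate is evaluated in probability space, not on the log-ratio.
        probs_diff = math.exp(t_lp) - math.exp(i_lp)
        dropped = ((adv > 0 and probs_diff > dppo_mask_high)
                   or (adv < 0 and probs_diff < -dppo_mask_low))
        keep = 0.0 if dropped else 1.0
        log_ratio = t_lp - i_lp
        pg = keep * (adv_tau * adv) * math.exp(log_ratio)
        kl = kl_tau * log_ratio ** 2      # applied on the full loss_mask
        total += -pg + kl                 # assumed sign of the PG term
    return total

on_policy = dppo_kl_loss([math.log(0.5)], [math.log(0.5)], [2.0], [1])
gated = dppo_kl_loss([math.log(0.9)], [math.log(0.5)], [1.0], [1])
```

In the second call the positive-advantage token has moved up by 0.4 in probability, which exceeds dppo_mask_high, so only the KL penalty remains.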

Parameters

| Name | Type | Default | Meaning |
|---|---|---|---|
| inputs | PRIME-RL LossInputs (duck-typed) | (required) | Must expose trainer_logprobs, inference_logprobs, advantages, loss_mask (all 1-D), and optionally teacher_logprobs. |
| alpha_sdpo | float (kw-only) | 0.0 | Channel-2 weight. Must be 0 in v0; >0 → NotImplementedError. |
| beta_dpo | float (kw-only) | 0.0 | Channel-3 weight. Non-zero emits a UserWarning. |
| dppo_mask_high | float (kw-only), >= 0 | 0.2 | Upper probability-diff threshold. PRIME-RL DefaultLossConfig default. |
| dppo_mask_low | float (kw-only), >= 0 | 0.2 | Magnitude of lower probability-diff threshold (sign flipped internally). PRIME-RL default. |
| adv_tau | float (kw-only), >= 0 | 1.0 | Advantage temperature. PRIME-RL default. |
| kl_tau | float (kw-only), >= 0 | 1e-3 | KL term temperature. PRIME-RL default. |

Returns scalar torch.Tensor (PRIME-RL's trainer calls .backward()).

Raises ValueError if any of trainer_logprobs, inference_logprobs, advantages, loss_mask is not 1-D, or any of the four >=0-constrained knobs is negative. NotImplementedError if alpha_sdpo > 0 (channel 2 deferred).

from composer_replication.recipes.prime_rl.composer_loss import loss_fn
# In PRIME-RL config:
#   loss:
#     custom:
#       import_path: composer_replication.recipes.prime_rl.composer_loss:loss_fn
#       kwargs:
#         dppo_mask_high: 0.2
#         dppo_mask_low:  0.2
#         adv_tau:        1.0
#         kl_tau:         1.0e-3

14. composer_replication.recipes.monarch.actors

🟡 SKELETON module per ADR-006. Importable; classes raise NotImplementedError on instantiation. Documents the actor signatures so the recipe matrix is complete.

class TrainerActor 🟡

class TrainerActor:
    backend = "monarch"
    role = "trainer"

    def __init__(self) -> None: raise NotImplementedError(...)
    async def train_outer_step(self, batch_id: int) -> dict[str, Any]: raise NotImplementedError

Hosts the framework's 3-channel composer trainer. Real impl deferred to v0.2+.

class GeneratorActor 🟡

class GeneratorActor:
    backend = "monarch"
    role = "generator"
    def __init__(self) -> None: raise NotImplementedError(...)
    async def rollout(self, prompts: list[str]) -> list[str]: raise NotImplementedError

vLLM-backed rollout actor.

class RewarderActor 🟡

class RewarderActor:
    backend = "monarch"
    role = "rewarder"
    def __init__(self) -> None: raise NotImplementedError(...)
    async def score(self, completions: list[str]) -> list[float]: raise NotImplementedError

verifiers-protocol rewarder.

class TeacherPoolActor 🟡

class TeacherPoolActor:
    backend = "monarch"
    role = "teacher_pool"
    def __init__(self) -> None: raise NotImplementedError(...)

Channel-3 teacher pool wrapping composer_replication.teacher_replay.

# All Monarch actors raise on instantiation in v0:
from composer_replication.recipes.monarch.actors import TrainerActor
# TrainerActor()  # NotImplementedError

Notes on test coverage

Tested contracts (referenced spike/test paths):

  • compose_loss + LossComponents + build_batch: composer_replication/tests/test_compose_loss_integration.py, spikes/006-real-hf-model-smoke/tests/.
  • generalized_jsd_loss: spikes/005-integrated-trainer-skeleton/tests/test_opsd_loss.py.
  • simpo_loss, taid_loss, taid_alpha_schedule, taid_blended_logits, entropy_aware_opd_loss: composer_replication/distillation/tests/test_distillation_losses.py.
  • replay_trace, extract_dpo_pairs, DPOPair, TraceState, TeacherCallResult, TeacherSpec, DEFAULT_TEACHERS: spikes/005-integrated-trainer-skeleton/tests/test_teacher_replay.py.
  • DJNormalizer, NormalizedDPOPair, replay_and_normalize_trace: composer_replication/replaysim/tests/test_replaysim.py.
  • ClaudeCodeIngester, IngestionStats, SYSTEM_PROMPT: spikes/007-real-trace-ingestion/tests/.
  • ComposerDataCollator, CollatorConfig, TraceTurn, TraceExample: spikes/005-integrated-trainer-skeleton/tests/test_data_collator.py.
  • ComposerReplicationTrainer._compute_loss (composition arithmetic): spikes/005-integrated-trainer-skeleton/tests/test_loss_composition_smoke.py.
  • make_diloco_outer_loop + sign convention: spikes/008-streaming-diloco/tests/test_diloco_smoke.py.
  • ObjectStoreAllReduce, MockManager, LocalProcessExecutor, ReplicaHandle, ServerlessExecutor, replica_entrypoint.main: composer_replication/diloco/serverless/tests/test_serverless_local.py, test_serverless_diloco_integration.py.
  • recipes.prime_rl.composer_loss.loss_fn: composer_replication/recipes/prime_rl/tests/test_composer_loss.py.

Untested-contract symbols (⚠️) and skeletons (🟡) are flagged inline above.

