API Reference — composer-replication-framework
Complete reference for every public symbol in composer_replication. Source-of-truth is the .py files in composer_replication/; docstrings have been pulled verbatim where they exist and supplemented where missing.
Legend
- ⚠️ UNTESTED-CONTRACT: symbol exists and is callable, but its behaviour is not pinned by an automated test in `composer_replication/**/tests/` or `spikes/**/tests/`.
- 🟡 SKELETON: class/method body raises `NotImplementedError`; ships as design-of-record per ADR-005 / ADR-006.
Module groups (in this document)
- `composer_replication` (top-level re-exports)
- `composer_replication.loss`
- `composer_replication.batch`
- `composer_replication.opsd`
- `composer_replication.distillation`
- `composer_replication.teacher_replay`
- `composer_replication.replaysim`
- `composer_replication.ingestion` (+ `.claude_code`)
- `composer_replication.hint_generator`
- `composer_replication.trainer` (+ `.composer_trainer`, `.data_collator`)
- `composer_replication.diloco`
- `composer_replication.diloco.serverless` (+ `.executor`, `.allreduce`, `.modal`, `.hf_jobs`, `.replica_entrypoint`)
- `composer_replication.recipes.prime_rl.composer_loss`
- `composer_replication.recipes.monarch.actors`
1. composer_replication — top-level package
The package re-exports the most common entry points from sub-modules. __all__ is the canonical list of public top-level names.
composer_replication.__version__: str
Package version string. Currently "0.1.0".
```python
import composer_replication

print(composer_replication.__version__)  # "0.1.0"
```
composer_replication._DILOCO_AVAILABLE: bool
`True` iff `torchft` is importable in the running Python environment; it gates `make_diloco_outer_loop`. When `torchft` is missing, the flag is `False` and `make_diloco_outer_loop` is set to `None`.
```python
from composer_replication import _DILOCO_AVAILABLE

if _DILOCO_AVAILABLE:
    from composer_replication import make_diloco_outer_loop
```
Re-exports
| Name | Source module |
|---|---|
| `compose_loss` | `composer_replication.loss` |
| `LossComponents` | `composer_replication.loss` |
| `build_batch` | `composer_replication.batch` |
| `generalized_jsd_loss` | `composer_replication.opsd` |
| `ClaudeCodeIngester` | `composer_replication.ingestion.claude_code` |
| `IngestionStats` | `composer_replication.ingestion.claude_code` |
| `SYSTEM_PROMPT` | `composer_replication.ingestion.claude_code` |
| `DEFAULT_TEACHERS` | `composer_replication.teacher_replay` |
| `DPOPair` | `composer_replication.teacher_replay` |
| `TeacherCallResult` | `composer_replication.teacher_replay` |
| `TeacherSpec` | `composer_replication.teacher_replay` |
| `TraceState` | `composer_replication.teacher_replay` |
| `extract_dpo_pairs` | `composer_replication.teacher_replay` |
| `replay_trace` | `composer_replication.teacher_replay` |
| `ComposerReplicationTrainer` | `composer_replication.trainer` |
| `make_diloco_outer_loop` | `composer_replication.diloco` (or `None` if `torchft` is missing) |
See each source module below for full signatures.
2. composer_replication.loss
Verification-harness 3-channel loss. Free function, does not depend on trl.
class LossComponents
```python
@dataclass
class LossComponents:
    lm_ce: torch.Tensor
    sdpo_jsd: torch.Tensor
    trace_replay_dpo: torch.Tensor
    total: torch.Tensor

    def detached(self) -> dict[str, float]: ...
```
Per-channel breakdown of the total loss for logging and ablation. All four fields are scalar torch.Tensors (shape=()); total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo.
detached() -> dict[str, float] — returns Python-float copies of all four fields with no grad. Useful for W&B logging.
```python
from composer_replication import compose_loss, build_batch

# model / tokenizer: an HF causal LM and its tokenizer, loaded beforehand
components = compose_loss(model, build_batch(tokenizer))
print(components.detached())  # {'lm_ce': 2.34, 'sdpo_jsd': 0.12, ...}
components.total.backward()
```
compose_loss(model, inputs, *, ...) -> LossComponents
```python
def compose_loss(
    model: torch.nn.Module,
    inputs: dict[str, torch.Tensor],
    *,
    alpha_sdpo: float = 0.1,
    beta_replay: float = 0.05,
    sdpo_jsd_beta: float = 0.5,
    sdpo_temperature: float = 1.0,
    sdpo_token_clip: float | None = None,
    replay_dpo_beta: float = 0.1,
    lm_ce_label_smoothing: float = 0.0,
    dpo_variant: Literal["dpo", "simpo"] = "dpo",
    sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
    taid_t: float | None = None,
    simpo_beta: float = 2.0,
    simpo_gamma: float = 1.0,
    entropy_opd_h_max: float | None = None,
) -> LossComponents: ...
```
Compute total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo.
Required keys in `inputs`:
- `input_ids` `(B, T_s)`: student rollout token ids.
- `response_mask` `(B, T_s)`: 1 on assistant-response tokens, 0 elsewhere.

Optional keys (a channel auto-disables if its keys are missing OR if its weight is 0):
- SDPO: `ctx_teacher_input_ids` `(B, T_t)`, `sdpo_loss_mask` `(B, T_t)`.
- DPO (`dpo_variant="dpo"`): `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`, `dpo_chosen_ref_logprobs`, `dpo_rejected_ref_logprobs` (precomputed).
- SimPO (`dpo_variant="simpo"`): same DPO ids/masks; reference logprobs are silently ignored.
- TAID (`sdpo_wrapper="taid"`): no extra `inputs` keys needed; the optional `sdpo_loss_mask` is reused as the per-token TAID mask. Pass `taid_t` directly (or drive it from `TAIDScheduler`).
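The weighted sum and the auto-disable rule above can be sketched in plain Python. This is a minimal illustration, not the library's code: `compose_total` is a hypothetical helper, channel losses are plain floats rather than tensors, `None` stands for "the channel's input keys are missing", and a disabled channel is assumed to contribute exactly 0.0.

```python
from typing import Optional


def compose_total(
    lm_ce: float,
    sdpo_jsd: Optional[float],
    trace_replay_dpo: Optional[float],
    alpha_sdpo: float = 0.1,
    beta_replay: float = 0.05,
) -> dict:
    """Hypothetical sketch: total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo.

    A channel whose inputs are missing (None) or whose weight is 0
    contributes 0.0 (assumed disable behaviour).
    """
    sdpo = sdpo_jsd if (sdpo_jsd is not None and alpha_sdpo != 0.0) else 0.0
    dpo = trace_replay_dpo if (trace_replay_dpo is not None and beta_replay != 0.0) else 0.0
    total = lm_ce + alpha_sdpo * sdpo + beta_replay * dpo
    return {"lm_ce": lm_ce, "sdpo_jsd": sdpo, "trace_replay_dpo": dpo, "total": total}


# All three channels active: 2.0 + 0.1 * 0.5 + 0.05 * 1.0
full = compose_total(2.0, 0.5, 1.0)
# DPO keys absent from the batch: channel 3 drops out
no_dpo = compose_total(2.0, 0.5, None)
```

Setting a weight to 0.0 and omitting the channel's keys end up in the same place, which is why ablations can be run by weight alone.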
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| `model` | `torch.nn.Module` | (required) | HF causal LM. Must accept `input_ids=` and return an object with `.logits`. |
| `inputs` | `dict[str, torch.Tensor]` | (required) | Batch dict (see required/optional keys above). |
| `alpha_sdpo` | `float` | `0.1` | Weight on the SDPO/JSD channel; `0.0` disables it. |
| `beta_replay` | `float` | `0.05` | Weight on the trace-replay DPO channel; `0.0` disables it. |
| `sdpo_jsd_beta` | `float` | `0.5` | β for `generalized_jsd_loss` (0 = forward KL, 0.5 = JSD, 1 = reverse KL). Unused when `sdpo_wrapper="taid"`. |
| `sdpo_temperature` | `float` | `1.0` | Softmax temperature in SDPO. Unused when `sdpo_wrapper="taid"`. |
| `sdpo_token_clip` | `float \| None` | `None` | Per-token JSD clamp. |
| `replay_dpo_beta` | `float` | `0.1` | β in the standard DPO logit. |
| `lm_ce_label_smoothing` | `float` | `0.0` | Passed as `F.cross_entropy(label_smoothing=...)`. |
| `dpo_variant` | `Literal["dpo", "simpo"]` | `"dpo"` | Channel-3 algorithm. |
| `sdpo_wrapper` | `Literal["none", "taid", "entropy_opd"]` | `"none"` | Channel-2 wrapper. |
| `taid_t` | `float \| None` | `None` | Current TAID interpolation coefficient in [0, 1]. Required when `sdpo_wrapper="taid"`. Drive it from `TAIDScheduler` or pass a fixed value. |
| `simpo_beta` | `float` | `2.0` | SimPO β (paper default). |
| `simpo_gamma` | `float` | `1.0` | SimPO target margin γ (paper default). |
| `entropy_opd_h_max` | `float \| None` | `None` | Max-entropy normalizer; `None` ⇒ log(V). |
Returns LossComponents (see above).
Raises ValueError if dpo_variant or sdpo_wrapper is unknown, if sdpo_wrapper="taid" is requested without taid_t, or if taid_t is outside [0, 1].
```python
from composer_replication import compose_loss, build_batch

batch = build_batch(tokenizer)
out = compose_loss(model, batch, alpha_sdpo=0.1, beta_replay=0.05)
out.total.backward()
print(out.detached())
```
3. composer_replication.batch
Verification-harness batch builder.
build_batch(tokenizer, *, ...) -> dict[str, torch.Tensor]
```python
def build_batch(
    tokenizer: Any,
    *,
    device: torch.device | str = "cpu",
    seed: int = 42,
    variant: str = "factorial",
    align_sdpo_shapes: bool = False,
) -> dict[str, torch.Tensor]: ...
```
Construct a full 3-channel batch from a real HF tokenizer. The DPO reference logprobs are dummy tensors: the smoke test verifies that the loss composition wires together, not the reference-policy precompute.
Returned keys: input_ids, response_mask, ctx_teacher_input_ids, sdpo_loss_mask, dpo_chosen_input_ids, dpo_chosen_response_mask, dpo_rejected_input_ids, dpo_rejected_response_mask, dpo_chosen_ref_logprobs, dpo_rejected_ref_logprobs.
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| `tokenizer` | HF `AutoTokenizer` (duck-typed) | (required) | Must support `apply_chat_template` and `__call__`. |
| `device` | `torch.device \| str` | `"cpu"` | Target device for all returned tensors. |
| `seed` | `int` | `42` | Seed passed to `torch.manual_seed`. |
| `variant` | `str` | `"factorial"` | One of `"factorial"`, `"binary_search"`. |
| `align_sdpo_shapes` | `bool` | `False` | If `True`, truncate/pad `ctx_teacher_input_ids` to the `input_ids` length so the SDPO channel actually fires. |
Raises ValueError if variant is unknown.
```python
from transformers import AutoTokenizer
from composer_replication import build_batch

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
batch = build_batch(tok, variant="factorial", align_sdpo_shapes=True)
print({k: v.shape for k, v in batch.items()})
```
4. composer_replication.opsd
Self-distillation generalized-JSD loss, lifted verbatim from siyan-zhao/OPSD (MIT) per ADR-006.
generalized_jsd_loss(student_logits, teacher_logits, labels=None, beta=0.5, ...) -> torch.Tensor
```python
def generalized_jsd_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor | None = None,
    beta: float = 0.5,
    temperature: float = 1.0,
    reduction: str = "batchmean",
    logits_are_probs: bool = False,
    top_k: int | None = None,
    token_clip: float | None = None,
) -> torch.Tensor: ...
```
Generalized JSD between student and teacher distributions. In the SDPO recipe, student and teacher are the SAME model (same parameters) run on different contexts.
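For a single token, the β interpolation can be sketched in plain Python. This assumes OPSD follows the generalized-JSD formulation used by TRL's GKD trainer: a mixture M = (1 - β)·S + β·T, with β = 0 and β = 1 special-cased to forward and reverse KL. `generalized_jsd_token` is a hypothetical name; masking, top-k, clipping and reduction are omitted.

```python
import math


def softmax(logits, temperature=1.0):
    m = max(logits)
    exps = [math.exp((x - m) / temperature) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def kl(p, q):
    """KL(p || q) in nats; terms with p_i == 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def generalized_jsd_token(student_logits, teacher_logits, beta=0.5, temperature=1.0):
    s = softmax(student_logits, temperature)
    t = softmax(teacher_logits, temperature)
    if beta == 0.0:
        return kl(t, s)  # forward KL(teacher || student)
    if beta == 1.0:
        return kl(s, t)  # reverse KL(student || teacher)
    # Mixture distribution, then a beta-weighted sum of KLs to it
    m = [(1.0 - beta) * si + beta * ti for si, ti in zip(s, t)]
    return beta * kl(t, m) + (1.0 - beta) * kl(s, m)


student = [2.0, 0.5, -1.0]
teacher = [0.1, 1.5, 0.3]
jsd = generalized_jsd_token(student, teacher)  # beta=0.5: symmetric, bounded by log(2)
```

At β = 0.5 the value is symmetric in student and teacher and never exceeds log 2, which is what makes it a convenient default for the SDPO channel.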
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| `student_logits` | Tensor `(B, T, V)` | (required) | Student logits (with grad). |
| `teacher_logits` | Tensor `(B, T, V)` | (required) | Teacher logits (no grad in SDPO). |
| `labels` | `Tensor (B, T) \| None` | `None` | Per-token mask; positions equal to `-100` are ignored (HF convention). |
| `beta` | `float` in [0, 1] | `0.5` | 0 = forward KL, 1 = reverse KL, 0.5 = symmetric JSD. |
| `temperature` | `float` | `1.0` | Softmax temperature. |
| `reduction` | `str` | `"batchmean"` | One of `"batchmean"`, `"sum"`, `"mean"`, `"none"`. |
| `logits_are_probs` | `bool` | `False` | Skip the softmax if inputs are already probabilities. |
| `top_k` | `int \| None` | `None` | Restrict the KL to the teacher's top-k tokens. |
| `token_clip` | `float \| None` | `None` | Clip the per-token JSD for stability. |
Returns scalar tensor (or (B, T) if reduction="none").
Raises ValueError for unknown reduction.
```python
import torch
from composer_replication.opsd import generalized_jsd_loss

s = torch.randn(2, 8, 32, requires_grad=True)
t = torch.randn(2, 8, 32)
loss = generalized_jsd_loss(s, t, beta=0.5, reduction="batchmean")
loss.backward()
```
5. composer_replication.distillation
Pluggable self-distillation losses (ADR-007). All pure PyTorch.
simpo_loss(chosen_avg_logprobs, rejected_avg_logprobs, *, beta=2.0, gamma=1.0) -> torch.Tensor
```python
def simpo_loss(
    chosen_avg_logprobs: torch.Tensor,
    rejected_avg_logprobs: torch.Tensor,
    *,
    beta: float = 2.0,
    gamma: float = 1.0,
) -> torch.Tensor: ...
```
Reference-free DPO with target margin γ (Meng et al., NeurIPS 2024). L = -log σ(β · (avg_logπ(c) − avg_logπ(r)) − γ).
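The SimPO formula can be checked numerically with the values from the usage example below. A minimal plain-Python sketch: `simpo_loss_ref` is hypothetical, and reducing the batch with a mean is an assumption about how `simpo_loss` arrives at a scalar.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def simpo_loss_ref(chosen_avg, rejected_avg, beta=2.0, gamma=1.0):
    if len(chosen_avg) != len(rejected_avg):
        raise ValueError("chosen/rejected shapes mismatch")
    # L = -log sigma(beta * (avg_logpi(c) - avg_logpi(r)) - gamma), per pair
    losses = [-log_sigmoid(beta * (c - r) - gamma) for c, r in zip(chosen_avg, rejected_avg)]
    return sum(losses) / len(losses)  # assumed batch-mean reduction


loss = simpo_loss_ref([-2.1, -1.8], [-3.0, -2.5], beta=2.0, gamma=1.0)
```

With a zero margin the loss is log(1 + e^γ), and it only drops below log 2 once the average-logprob margin exceeds γ/β; that is the sense in which γ acts as a target margin.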
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
chosen_avg_logprobs |
Tensor (B,) |
— | Per-sequence avg logprob over chosen response tokens. |
rejected_avg_logprobs |
Tensor (B,) |
— | Same for rejected. |
beta |
float |
2.0 |
Scaling factor (paper default). |
gamma |
float |
1.0 |
Target margin (paper default). |
Returns a scalar tensor. Raises ValueError if the shapes mismatch.
```python
import torch
from composer_replication.distillation import simpo_loss

loss = simpo_loss(torch.tensor([-2.1, -1.8]), torch.tensor([-3.0, -2.5]),
                  beta=2.0, gamma=1.0)
```
avg_sequence_logprob(model_logprobs, response_mask) -> torch.Tensor
⚠️ UNTESTED-CONTRACT (helper exported from simpo.py but not asserted by a test).
```python
def avg_sequence_logprob(
    model_logprobs: torch.Tensor,
    response_mask: torch.Tensor,
) -> torch.Tensor: ...
```
Convert (B, T) per-token logprobs + (B, T) response mask into (B,) per-sequence average over response tokens.
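The masked average is simple enough to state directly. A plain-Python sketch on nested lists; the name `avg_sequence_logprob_ref` and the guard against an all-zero mask row are assumptions, not documented behaviour.

```python
def avg_sequence_logprob_ref(logprobs, mask):
    """(B, T) per-token logprobs + (B, T) 0/1 mask -> (B,) mean over masked tokens."""
    out = []
    for row_lp, row_m in zip(logprobs, mask):
        n_tokens = sum(row_m)
        total = sum(lp * m for lp, m in zip(row_lp, row_m))
        out.append(total / max(n_tokens, 1))  # assumed guard for empty responses
    return out


avg = avg_sequence_logprob_ref(
    [[-1.0, -2.0, -3.0, -4.0], [-0.5, -0.5, -1.0, -9.0]],
    [[0, 1, 1, 0], [1, 1, 1, 0]],
)
```

Averaging (rather than summing, as standard DPO does) is what makes SimPO length-normalized.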
```python
import torch
from composer_replication.distillation.simpo import avg_sequence_logprob

lp = torch.randn(2, 8)
m = torch.tensor([[0, 0, 1, 1, 1, 0, 0, 0],
                  [0, 1, 1, 1, 1, 1, 0, 0]])
out = avg_sequence_logprob(lp, m)  # shape (2,)
```
taid_loss(student_logits, teacher_logits, mask=None, *, t) -> torch.Tensor
```python
def taid_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    mask: torch.Tensor | None = None,
    *,
    t: float | torch.Tensor,
) -> torch.Tensor: ...
```
Faithful port of SakanaAI/TAID (arXiv:2501.16937). Forward-KL distillation against a logit-space-interpolated target whose anchor is the current student detached:
p_t = softmax( (1 - t) · stop_grad(student_logits) + t · teacher_logits )
L = - mean_token Σ_v p_t(v) · log_softmax(student_logits)(v)
At t=0 the target collapses to the detached student (no teacher signal in the gradient). At t=1 it reduces to standard forward-KL distillation against the teacher.
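Both limits can be verified for a single token in plain Python. This is a sketch of the formula as written (logit-space mix, detached student anchor), not the library code; without autograd the stop-gradient is a no-op here, and the helper names are hypothetical.

```python
import math


def softmax(z):
    m = max(z)
    e = [math.exp(x - m) for x in z]
    s = sum(e)
    return [x / s for x in e]


def log_softmax(z):
    m = max(z)
    lse = m + math.log(sum(math.exp(x - m) for x in z))
    return [x - lse for x in z]


def taid_target(student_logits, teacher_logits, t):
    # Interpolate LOGITS, then normalize: softmax((1 - t) * sg(student) + t * teacher)
    mixed = [(1.0 - t) * s + t * u for s, u in zip(student_logits, teacher_logits)]
    return softmax(mixed)


def taid_token_loss(student_logits, teacher_logits, t):
    # Cross-entropy of the student against the interpolated target p_t
    p_t = taid_target(student_logits, teacher_logits, t)
    log_q = log_softmax(student_logits)
    return -sum(p * lq for p, lq in zip(p_t, log_q))


student = [1.0, 0.0, -1.0]
teacher = [-1.0, 0.5, 2.0]
p_half = taid_target(student, teacher, 0.5)
# The pre-Wave-15 probability-space mix gives a different target
p_prob_mix = [0.5 * a + 0.5 * b for a, b in zip(softmax(student), softmax(teacher))]
```

The last two lines illustrate why the breaking change below matters: mixing logits and mixing probabilities produce different targets at the same t.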
Wave 15 breaking change. The previous signature taid_loss(student, teacher, student_init, *, schedule_step, total_steps, schedule, alpha_min, alpha_max, jsd_beta, temperature, reduction) was algorithmically wrong (probability-space mix, frozen step-0 anchor, JSD criterion). All those kwargs are removed; the schedule is now the caller's responsibility (see TAIDScheduler below for the upstream adaptive scheme).
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| `student_logits` | Tensor `(B, T, V)` | (required) | Current student (with grad). |
| `teacher_logits` | Tensor `(B, T, V)` | (required) | Teacher logits. |
| `mask` | `Tensor (B, T) \| None` | `None` | Token mask; `None` ⇒ all ones. |
| `t` | `float \| Tensor` | (required, kw-only) | Interpolation coefficient in [0, 1]. |
Raises ValueError for shape mismatch.
```python
from composer_replication.distillation import taid_loss

loss = taid_loss(s_logits, t_logits, mask, t=0.4)
```
TAIDScheduler(num_train_steps, *, t_start=0.4, t_end=1.0, alpha=5e-4, beta=0.99, disable_adaptive=False, device="cpu")
Stateful schedule that mirrors upstream TAID.update_t. Monotone non-decreasing, bumped above the linear floor by an EMA on the relative loss change. Use as:
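```python
from composer_replication.distillation import TAIDScheduler, taid_loss

num_train_steps = 10_000
sched = TAIDScheduler(num_train_steps=num_train_steps)  # paper defaults
for step in range(num_train_steps):
    # s_logits, t_logits, mask and optimizer come from your own training loop
    loss = taid_loss(s_logits, t_logits, mask, t=sched.t)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    sched.update_t(loss.detach(), global_step=step)
```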
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| `num_train_steps` | `int` | (required) | Total planned training steps; sets the linear floor. |
| `t_start` | `float` | `0.4` | Initial t (paper default). |
| `t_end` | `float` | `1.0` | Terminal t; also a hard ceiling at every step. |
| `alpha` | `float` | `5e-4` | Adaptive bump magnitude. |
| `beta` | `float` | `0.99` | EMA decay on the relative-loss-change momentum. |
| `disable_adaptive` | `bool` | `False` | If `True`, fall back to the deterministic linear schedule. |
| `device` | `torch.device \| str` | `"cpu"` | Where to allocate state buffers. |
Properties / methods
- `sched.t -> float`: current t as a Python float (zero-arg property).
- `sched.update_t(loss, global_step) -> Tensor | None`: updates the internal state. The first finite-loss call only seeds `prev_loss` and returns `None`; subsequent calls return the (positive) `delta_t` added on top of the linear floor.
entropy_aware_opd_loss(student_logits, teacher_logits, *, labels=None, h_max=None, temperature=1.0, reduction="batchmean") -> torch.Tensor
```python
def entropy_aware_opd_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    *,
    labels: torch.Tensor | None = None,
    h_max: float | None = None,
    temperature: float = 1.0,
    reduction: str = "batchmean",
) -> torch.Tensor: ...
```
Per-token mixture of forward and reverse KL gated by teacher entropy: w(t) = clamp(H_teacher(t)/h_max, 0, 1). High-entropy tokens use forward KL (mode-covering), low-entropy tokens use reverse KL (mode-seeking).
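Per token, the gating can be sketched in plain Python. Assumptions: the mixture is the convex combination w·KL(teacher ‖ student) + (1 - w)·KL(student ‖ teacher), with forward KL taken as KL(teacher ‖ student); masking and reduction are omitted, and all helper names are hypothetical.

```python
import math


def softmax(z, temperature=1.0):
    m = max(z)
    e = [math.exp((x - m) / temperature) for x in z]
    s = sum(e)
    return [x / s for x in e]


def entropy(p):
    """Shannon entropy in nats (the quantity teacher_entropy reports per token)."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0.0)


def kl(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def entropy_aware_token_loss(student_logits, teacher_logits, h_max=None, temperature=1.0):
    s = softmax(student_logits, temperature)
    t = softmax(teacher_logits, temperature)
    if h_max is None:
        h_max = math.log(len(t))  # None => log(V)
    w = min(max(entropy(t) / h_max, 0.0), 1.0)  # w = clamp(H_teacher / h_max, 0, 1)
    forward = kl(t, s)  # mode-covering; dominates on high-entropy tokens
    reverse = kl(s, t)  # mode-seeking; dominates on low-entropy tokens
    return w * forward + (1.0 - w) * reverse


student = [0.0, 1.0, 0.0, 0.0]
flat_teacher = [0.0, 0.0, 0.0, 0.0]     # H = log(4): w = 1, pure forward KL
peaked_teacher = [50.0, 0.0, 0.0, 0.0]  # H ~ 0: w ~ 0, essentially reverse KL
```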
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| `student_logits` | Tensor `(B, T, V)` | (required) | Student logits (with grad). |
| `teacher_logits` | Tensor `(B, T, V)` | (required) | Teacher logits (no grad). |
| `labels` | `Tensor (B, T) \| None` | `None` | 0/1 mask, applied multiplicatively after the per-token mix. |
| `h_max` | `float \| None` | `None` (⇒ log(V)) | Max-entropy normalizer. |
| `temperature` | `float` | `1.0` | Softmax temperature applied to both student and teacher. |
| `reduction` | `str` | `"batchmean"` | One of `"batchmean"`, `"sum"`, `"mean"`, `"none"`. |
Raises ValueError on shape mismatch (student vs teacher; labels vs per-token loss) or unknown reduction.
```python
from composer_replication.distillation import entropy_aware_opd_loss

loss = entropy_aware_opd_loss(s_logits, t_logits, temperature=1.0)
loss.backward()
```
teacher_entropy(teacher_logits) -> torch.Tensor
⚠️ UNTESTED-CONTRACT (helper exposed from entropy_aware_opd.py's __all__ but not directly asserted).
Per-token entropy in nats. Input (B,T,V), output (B,T).
```python
from composer_replication.distillation.entropy_aware_opd import teacher_entropy

H = teacher_entropy(teacher_logits)  # (B, T)
```
6. composer_replication.teacher_replay
N-teacher OpenRouter parallel client + DPO-pair extractor. httpx is lazy-imported inside replay_trace; the deterministic local logic is testable without it.
DEFAULT_TEACHERS: list[TeacherSpec]
Three-teacher default set: anthropic/claude-opus-4.7, openai/gpt-5, deepseek/deepseek-v4-pro with paper-baseline OpenRouter pricing.
```python
from composer_replication.teacher_replay import DEFAULT_TEACHERS

print([t["slug"] for t in DEFAULT_TEACHERS])
```
class TeacherSpec(TypedDict)
```python
class TeacherSpec(TypedDict):
    slug: str
    input_per_mtok: float
    output_per_mtok: float
```
OpenRouter model slug + per-million-token pricing.
```python
from composer_replication import TeacherSpec

spec: TeacherSpec = {"slug": "openai/gpt-5",
                     "input_per_mtok": 1.25, "output_per_mtok": 10.0}
```
class TraceState(TypedDict)
```python
class TraceState(TypedDict):
    state_id: str           # unique within the trace
    messages: list[dict]    # OpenAI-style chat history up to (and incl.) this user prompt
    student_action: str     # what the student actually did at this step
```
One step of a frozen agentic trace. student_action is the raw text emitted by the student; teachers are queried with messages and asked to predict the assistant's next action.
```python
from composer_replication import TraceState

state: TraceState = {"state_id": "ex001::0042",
                     "messages": [{"role": "user", "content": "..."}],
                     "student_action": "[TOOL_USE] name=Read input={...}"}
```
class TeacherCallResult(TypedDict)
```python
class TeacherCallResult(TypedDict):
    state_id: str
    teacher_slug: str
    response_text: str | None  # None on error
    latency_s: float
    prompt_tokens: int
    completion_tokens: int
    cost_usd: float
    error: str | None  # None on success
```
One row of the N×T results (N states × T teachers) returned by replay_trace.
```python
from composer_replication import TeacherCallResult

r: TeacherCallResult = {"state_id": "x", "teacher_slug": "openai/gpt-5",
                        "response_text": "ok", "latency_s": 1.2, "prompt_tokens": 100,
                        "completion_tokens": 5, "cost_usd": 0.001, "error": None}
```
class DPOPair(TypedDict)
```python
class DPOPair(TypedDict):
    state_id: str
    state_messages: list[dict]
    chosen: str    # teacher-consensus action
    rejected: str  # student action
    n_teachers_agreeing: int
```
One preference pair extracted from teacher-vs-student disagreement.
```python
from composer_replication import DPOPair

p: DPOPair = {"state_id": "x", "state_messages": [...], "chosen": "...",
              "rejected": "...", "n_teachers_agreeing": 2}
```
async replay_trace(states, teachers=DEFAULT_TEACHERS, max_total_usd=5.0, api_key=None) -> list[TeacherCallResult]
```python
async def replay_trace(
    states: Sequence[TraceState],
    teachers: Sequence[TeacherSpec] = tuple(DEFAULT_TEACHERS),
    max_total_usd: float = 5.0,
    api_key: str | None = None,
) -> list[TeacherCallResult]: ...
```
For each state, fan out one parallel call per teacher via OpenRouter. Cumulative spend is hard-capped at max_total_usd: the run stops after the offending state (the one that pushes spend over the cap) completes.
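The fan-out and the spend cap can be sketched with asyncio and a fake teacher call. Everything here is hypothetical: `fake_teacher_call` stands in for the OpenRouter request, cost is assumed to be tokens / 1e6 × the TeacherSpec per-million-token price, and the "offending state" is read as the first state after which cumulative spend exceeds the cap.

```python
import asyncio


def call_cost_usd(prompt_tokens, completion_tokens, input_per_mtok, output_per_mtok):
    # Assumed pricing model: per-million-token rates from TeacherSpec
    return prompt_tokens / 1e6 * input_per_mtok + completion_tokens / 1e6 * output_per_mtok


async def fake_teacher_call(state_id, slug, cost_usd):
    await asyncio.sleep(0)  # stand-in for the HTTP round trip
    return {"state_id": state_id, "teacher_slug": slug, "cost_usd": cost_usd, "error": None}


async def replay_sketch(state_ids, slugs, cost_per_call, max_total_usd):
    results, spent = [], 0.0
    for state_id in state_ids:
        # One concurrent call per teacher for this state; gather keeps argument order
        rows = await asyncio.gather(
            *(fake_teacher_call(state_id, slug, cost_per_call) for slug in slugs)
        )
        results.extend(rows)
        spent += sum(r["cost_usd"] for r in rows)
        if spent > max_total_usd:  # the offending state has already completed
            break
    return results


rows = asyncio.run(replay_sketch(["s0", "s1", "s2"], ["a", "b"], 0.3, max_total_usd=1.0))
```

With two teachers at $0.30 per call, spend is $0.60 after s0 and $1.20 after s1, so s2 is never queried and 4 rows come back instead of 6. This is the "modulo budget cutoff" case for the return length.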
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| `states` | `Sequence[TraceState]` | (required) | Frozen trace, one entry per assistant turn. |
| `teachers` | `Sequence[TeacherSpec]` | `DEFAULT_TEACHERS` | Models to query in parallel. |
| `max_total_usd` | `float` | `5.0` | Cumulative spend cap (USD). |
| `api_key` | `str \| None` | `None` | OpenRouter key; defaults to the `OPENROUTER_API_KEY` env var or `~/.hermes/.env`. |
Returns a flat list of TeacherCallResults of length len(states) * len(teachers), or shorter if the budget cap stops the run early.
Raises RuntimeError if OPENROUTER_API_KEY cannot be found; ImportError if httpx is missing at call time.
```python
import asyncio
from composer_replication import replay_trace

results = asyncio.run(replay_trace(states=my_trace, max_total_usd=1.0))
```
extract_dpo_pairs(states, teacher_actions, agreement_threshold=2) -> list[DPOPair]
```python
def extract_dpo_pairs(
    states: Sequence[TraceState],
    teacher_actions: Sequence[TeacherCallResult],
    agreement_threshold: int = 2,
) -> list[DPOPair]: ...
```
Group teacher_actions by state_id, normalize whitespace, and emit one DPOPair per state where ≥agreement_threshold teachers agreed on an action that differs from the student's. chosen is the original (un-normalized) teacher response text.
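One reading of that rule, sketched in plain Python. Assumptions: "normalize whitespace" means collapsing runs of whitespace and stripping the ends, errored calls (`response_text` is `None`) are ignored, and the student's own action is excluded before picking the most-agreed action. `extract_pairs_sketch` is hypothetical.

```python
from collections import Counter, defaultdict


def normalize_ws(text):
    return " ".join(text.split())


def extract_pairs_sketch(states, teacher_actions, agreement_threshold=2):
    by_state = defaultdict(list)
    for row in teacher_actions:
        if row.get("response_text") is not None:  # skip errored calls
            by_state[row["state_id"]].append(row["response_text"])

    pairs = []
    for state in states:
        responses = by_state.get(state["state_id"], [])
        counts = Counter(normalize_ws(r) for r in responses)
        # Teachers that agree with the student give no preference signal
        counts.pop(normalize_ws(state["student_action"]), None)
        if not counts:
            continue
        action, n_agree = counts.most_common(1)[0]
        if n_agree < agreement_threshold:
            continue
        # chosen keeps the first matching teacher's original, un-normalized text
        chosen = next(r for r in responses if normalize_ws(r) == action)
        pairs.append({
            "state_id": state["state_id"],
            "state_messages": state["messages"],
            "chosen": chosen,
            "rejected": state["student_action"],
            "n_teachers_agreeing": n_agree,
        })
    return pairs


states = [
    {"state_id": "s0", "messages": [], "student_action": "ls"},
    {"state_id": "s1", "messages": [], "student_action": "cat a.py"},
]
results = [
    {"state_id": "s0", "response_text": "pwd"},
    {"state_id": "s0", "response_text": "pwd  "},
    {"state_id": "s0", "response_text": "ls"},
    {"state_id": "s1", "response_text": "cat  a.py"},
    {"state_id": "s1", "response_text": "cat a.py"},
    {"state_id": "s1", "response_text": None},
]
pairs = extract_pairs_sketch(states, results)
```

Here s0 yields a pair (two teachers agree on `pwd` once whitespace is normalized), while s1 yields none because the teacher consensus matches what the student already did.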
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| `states` | `Sequence[TraceState]` | (required) | Same states as passed to `replay_trace`. |
| `teacher_actions` | `Sequence[TeacherCallResult]` | (required) | Output of `replay_trace`. |
| `agreement_threshold` | `int` | `2` | Minimum number of teachers that must agree for a pair to be emitted. |
Returns list of DPOPair. At most one pair per state (the most-agreed-upon action wins).
```python
from composer_replication import extract_dpo_pairs

pairs = extract_dpo_pairs(my_states, results, agreement_threshold=2)
```
save_pairs(pairs, path) -> None
⚠️ UNTESTED-CONTRACT.
```python
def save_pairs(pairs: Sequence[DPOPair], path: str | Path) -> None: ...
```
Write pairs to JSONL (one dict per line). Creates parent dirs.
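A JSONL writer with that behaviour is short. This plain-Python sketch (hypothetical `save_pairs_sketch`, plus a `load_pairs` reader for the round trip) assumes one `json.dumps` object per line with UTF-8 encoding.

```python
import json
import tempfile
from pathlib import Path


def save_pairs_sketch(pairs, path):
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)  # creates parent dirs
    with path.open("w", encoding="utf-8") as f:
        for pair in pairs:
            f.write(json.dumps(pair) + "\n")


def load_pairs(path):
    with Path(path).open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


pairs = [{"state_id": "x", "state_messages": [], "chosen": "pwd",
          "rejected": "ls", "n_teachers_agreeing": 2}]
with tempfile.TemporaryDirectory() as tmp:
    out_path = Path(tmp) / "nested" / "dpo_pairs.jsonl"
    save_pairs_sketch(pairs, out_path)
    round_trip = load_pairs(out_path)
```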
```python
from composer_replication.teacher_replay import save_pairs

save_pairs(pairs, "/tmp/dpo_pairs.jsonl")
```
7. composer_replication.replaysim
ADR-004 normalization layer over teacher_replay. Re-exports DPOPair, TeacherCallResult, extract_dpo_pairs, replay_trace from teacher_replay.
class NormalizedDPOPair
```python
@dataclass
class NormalizedDPOPair:
    state_id: str
    state_messages: list[dict[str, Any]]
    chosen_messages: list[dict[str, Any]]
    rejected_messages: list[dict[str, Any]]
    n_teachers_agreeing: int
    metadata: dict[str, Any]
```
Post-normalization shape. chosen_messages/rejected_messages are chat-format ([{"role": "assistant", "content": ...}]). metadata carries op-graph provenance, including {"skipped": True} when the normalizer was bypassed (skip_dj=True).
```python
from composer_replication.replaysim import NormalizedDPOPair

n = NormalizedDPOPair(state_id="x", state_messages=[],
                      chosen_messages=[{"role": "assistant", "content": "ok"}],
                      rejected_messages=[{"role": "assistant", "content": "no"}],
                      n_teachers_agreeing=2, metadata={})
```
class DJNormalizer
```python
class DJNormalizer:
    DEFAULT_RECIPE: ClassVar[Path]  # composer_replication/recipes/replaysim/default.yaml

    def __init__(
        self,
        recipe_path: str | os.PathLike[str] | None = None,
        *,
        skip_dj: bool = False,
    ) -> None: ...

    def normalize(
        self,
        pairs: Iterable[DPOPair | dict[str, Any]],
    ) -> list[NormalizedDPOPair]: ...
```
data-juicer-backed normalizer. Pipeline: each DPOPair → JSONL record → data_juicer.core.DefaultExecutor.run() against the recipe → JSONL → NormalizedDPOPair.
Constructor parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| `recipe_path` | `str \| PathLike \| None` | `None` (⇒ default recipe) | data-juicer YAML recipe path. |
| `skip_dj` | `bool` (kw-only) | `False` | If `True`, pass records through: they get `metadata={"skipped": True}` and no ops run. |
normalize(pairs) -> list[NormalizedDPOPair] runs the op-graph. Output may be shorter than input if filter ops drop records.
Raises RuntimeError at construction time if skip_dj=False and data_juicer is not importable. Raises FileNotFoundError if recipe_path (default or explicit) is missing and skip_dj=False.
```python
from composer_replication.replaysim import DJNormalizer

norm = DJNormalizer(skip_dj=True)
out = norm.normalize(my_pairs)
```
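With skip_dj=True the op-graph never runs, so the only work left is reshaping each DPOPair into chat format. A plain-Python sketch of that passthrough: `passthrough_normalize` and `NormalizedPairSketch` are hypothetical stand-ins that mirror the NormalizedDPOPair fields, and copying `state_messages` unchanged is an assumption.

```python
from dataclasses import dataclass, field


@dataclass
class NormalizedPairSketch:
    state_id: str
    state_messages: list
    chosen_messages: list
    rejected_messages: list
    n_teachers_agreeing: int
    metadata: dict = field(default_factory=dict)


def passthrough_normalize(pairs):
    out = []
    for p in pairs:
        out.append(NormalizedPairSketch(
            state_id=p["state_id"],
            state_messages=list(p["state_messages"]),
            # Bare action strings become single assistant-turn chat messages
            chosen_messages=[{"role": "assistant", "content": p["chosen"]}],
            rejected_messages=[{"role": "assistant", "content": p["rejected"]}],
            n_teachers_agreeing=p["n_teachers_agreeing"],
            metadata={"skipped": True},  # marks that no data-juicer ops ran
        ))
    return out


normalized = passthrough_normalize([{"state_id": "x", "state_messages": [], "chosen": "ok",
                                     "rejected": "no", "n_teachers_agreeing": 2}])
```

Unlike the full pipeline, this path never drops records, so the output length always equals the input length.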
async replay_and_normalize_trace(*, states, teachers=None, agreement_threshold=2, max_total_usd=5.0, normalizer=None, **replay_kwargs) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]
```python
async def replay_and_normalize_trace(
    *,
    states: Any,
    teachers: Any = None,
    agreement_threshold: int = 2,
    max_total_usd: float = 5.0,
    normalizer: DJNormalizer | None = None,
    **replay_kwargs: Any,
) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]: ...
```
End-to-end async: replay → extract pairs → normalize.
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| `states` | `Sequence[TraceState]` | (required) | Frozen trace. |
| `teachers` | `Sequence[TeacherSpec] \| None` | `None` (⇒ defaults) | Forwarded to `replay_trace`. |
| `agreement_threshold` | `int` | `2` | Forwarded to `extract_dpo_pairs`. |
| `max_total_usd` | `float` | `5.0` | Spend cap (USD). |
| `normalizer` | `DJNormalizer \| None` | `None` (⇒ `DJNormalizer()`) | Pass `DJNormalizer(skip_dj=True)` to bypass. |
| `**replay_kwargs` | `Any` | n/a | Forwarded to `replay_trace` (e.g. `api_key`). |
Returns (raw_teacher_actions, normalized_pairs).
```python
import asyncio
from composer_replication.replaysim import replay_and_normalize_trace, DJNormalizer

raw, norm = asyncio.run(replay_and_normalize_trace(
    states=my_states, normalizer=DJNormalizer(skip_dj=True)))
```
replay_and_normalize_trace_sync(*args, **kwargs) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]
⚠️ UNTESTED-CONTRACT (sync wrapper around the async function; tests call the async form via asyncio.run).
```python
def replay_and_normalize_trace_sync(*args, **kwargs) -> tuple[list[TeacherCallResult], list[NormalizedDPOPair]]: ...
```
Sync convenience wrapping asyncio.run(replay_and_normalize_trace(...)).
```python
from composer_replication.replaysim.normalize import replay_and_normalize_trace_sync

raw, norm = replay_and_normalize_trace_sync(states=my_states)
```
8. composer_replication.ingestion & composer_replication.ingestion.claude_code
Trace-source adapters (ADR-002). v0.1 supports Claude Code session JSONL.
SYSTEM_PROMPT: str
Default synthetic system prompt injected at messages[0] for ingested traces (most Claude Code sessions don't write one). Truncated head: "You are a senior software engineer working as a coding agent in a terminal environment...".
```python
from composer_replication import SYSTEM_PROMPT

print(SYSTEM_PROMPT[:60])
```
class IngestionStats
```python
@dataclass
class IngestionStats:
    n_records_total: int = 0
    n_records_skipped: int = 0
    n_states_emitted: int = 0
    n_assistant_turns: int = 0
    n_tool_use_blocks: int = 0
    n_text_blocks: int = 0
    skipped_subagent: int = 0
    skipped_summary: int = 0
    skipped_truncated_lines: int = 0
    version_warnings: list[str] | None = None  # initialized to [] in __post_init__
```
Counters populated by ClaudeCodeIngester.ingest() and exposed as ingester.last_stats.
```python
from composer_replication import IngestionStats

s = IngestionStats(n_records_total=5)
print(s.version_warnings)  # []
```
class ClaudeCodeIngester
```python
class ClaudeCodeIngester:
    def __init__(
        self,
        *,
        system_prompt: str = SYSTEM_PROMPT,
        skip_sidechain: bool = True,
        strip_thinking: bool = True,
        max_history_tokens: int | None = None,
    ) -> None: ...

    def ingest(self, path: Path) -> Iterator[TraceState]: ...
```
Convert a Claude Code session JSONL to a stream of TraceStates — one per assistant TURN (not per tool_use block).
Constructor parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| `system_prompt` | `str` | `SYSTEM_PROMPT` | Synthetic system message injected at `history[0]`. |
| `skip_sidechain` | `bool` | `True` | Skip subagent files (`agent-*.jsonl`) and records with `isSidechain=True`. |
| `strip_thinking` | `bool` | `True` | Remove `[THINKING]` blocks from the history handed to teachers (they are kept inside `student_action`). |
| `max_history_tokens` | `int \| None` | `None` | ⚠️ UNTESTED-CONTRACT: accepted, but currently not used to truncate. |
ingest(path) -> Iterator[TraceState]: generator over TraceState objects. Each turn's state_id is f"{path.stem}::{idx:04d}". Side effect: replaces self.last_stats with a fresh IngestionStats and updates it as records stream.
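The state_id format and the sidechain filter can be illustrated in plain Python. The helper names are hypothetical, and matching subagent files with an `agent-*.jsonl` glob on the file name is an assumption based on the skip_sidechain row above.

```python
from fnmatch import fnmatch
from pathlib import Path


def make_state_id(path: Path, idx: int) -> str:
    # Documented format: "<file stem>::<zero-padded 4-digit turn index>"
    return f"{path.stem}::{idx:04d}"


def is_subagent_file(path: Path) -> bool:
    return fnmatch(path.name, "agent-*.jsonl")


def keep_record(record: dict, skip_sidechain: bool = True) -> bool:
    return not (skip_sidechain and record.get("isSidechain") is True)


state_id = make_state_id(Path("traces/ex001.jsonl"), 42)
```

Because the id is built from the file stem, two sessions with the same file name in different directories would collide; keep stems unique if traces are pooled.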
```python
from pathlib import Path
from composer_replication import ClaudeCodeIngester

ing = ClaudeCodeIngester()
for state in ing.ingest(Path("session.jsonl")):
    print(state["state_id"])
print(ing.last_stats.n_states_emitted)  # complete once the generator is exhausted
```
9. composer_replication.hint_generator
⚠️ UNTESTED-CONTRACT (entire module — used by the data collator config but not pinned by a test).
Template-based hint registry for SDPO error-site injection.
class HintContext(TypedDict, total=False)
```python
class HintContext(TypedDict, total=False):
    error_kind: str
    error_message: str
    available_tools: list[str]
    tool_name: str
    tool_schema: dict
    intent: str
```
Per-error context dict consumed by hint templates.
HINT_TEMPLATES: dict[str, Callable[[HintContext], str]]
Default registry keys: "tool_not_found", "json_decode", "type_error", "runtime_error", "repeated_failure".
dispatch(error_kind, ctx=None) -> str | None
```python
def dispatch(error_kind: str, ctx: HintContext | None = None) -> str | None: ...
```
Look up error_kind in HINT_TEMPLATES. Returns the template's hint text, or None if the kind is unknown.
```python
from composer_replication.hint_generator import dispatch

hint = dispatch("json_decode")  # "Reminder: tool arguments must be valid JSON. ..."
```
register(error_kind, fn) -> None
```python
def register(error_kind: str, fn: Callable[[HintContext], str]) -> None: ...
```
Add or override a custom hint template.
```python
from composer_replication.hint_generator import register

register("my_error", lambda ctx: "Reminder: try X.")
```
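The dispatch/register contract amounts to a dictionary keyed by error kind. A self-contained sketch of that pattern: every name here is hypothetical, the hint text is invented for illustration, and passing an empty context when `ctx` is `None` is an assumption about `dispatch`.

```python
from typing import Callable, Dict, Optional, TypedDict


class HintCtx(TypedDict, total=False):
    error_kind: str
    available_tools: list


HintFn = Callable[[HintCtx], str]
REGISTRY: Dict[str, HintFn] = {}


def register_hint(error_kind: str, fn: HintFn) -> None:
    REGISTRY[error_kind] = fn  # adds a new kind or overrides an existing one


def dispatch_hint(error_kind: str, ctx: Optional[HintCtx] = None) -> Optional[str]:
    fn = REGISTRY.get(error_kind)
    if fn is None:
        return None  # unknown kind: no hint
    return fn(ctx or {})


register_hint(
    "tool_not_found",
    lambda ctx: "Available tools: " + ", ".join(ctx.get("available_tools", [])),
)
hint = dispatch_hint("tool_not_found", {"available_tools": ["Read", "Write"]})
```

Returning `None` for unknown kinds lets a caller such as the collator simply skip hint injection instead of handling an exception.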
Individual template functions
⚠️ UNTESTED-CONTRACT — exported only via HINT_TEMPLATES, useful as building blocks:
- `hint_tool_not_found(ctx) -> str`
- `hint_json_decode(ctx) -> str`
- `hint_type_error(ctx) -> str`
- `hint_runtime_error(ctx) -> str`
- `hint_repeated_failure(ctx) -> str`
Each accepts a HintContext and returns hint text. Signatures are uniform: Callable[[HintContext], str].
```python
from composer_replication.hint_generator import hint_tool_not_found

text = hint_tool_not_found({"available_tools": ["Read", "Write"]})
```
10. composer_replication.trainer & sub-modules
Production trainer (TRL GRPOTrainer subclass) plus data collator.
class ComposerReplicationTrainer
```python
class ComposerReplicationTrainer(GRPOTrainer):
    def __init__(
        self,
        *args: Any,
        alpha_sdpo: float = 0.1,
        beta_replay: float = 0.05,
        sdpo_jsd_beta: float = 0.5,
        sdpo_temperature: float = 1.0,
        sdpo_token_clip: float | None = None,
        replay_dpo_beta: float = 0.1,
        **kwargs: Any,
    ) -> None: ...

    def _compute_loss(
        self,
        model: torch.nn.Module,
        inputs: dict[str, torch.Tensor],
    ) -> torch.Tensor: ...
```
trl.GRPOTrainer subclass that overrides _compute_loss(model, inputs) to compose total = grpo + α·sdpo + β·trace_replay_dpo. When trl is not installed, the parent class falls back to object so the module imports — but instantiation will fail because the parent's GRPO machinery is missing.
Constructor (kw-only beyond GRPOTrainer's own *args, **kwargs)
| Name | Type | Default | Meaning |
|---|---|---|---|
| `alpha_sdpo` | `float` | `0.1` | Channel-2 weight. |
| `beta_replay` | `float` | `0.05` | Channel-3 weight. |
| `sdpo_jsd_beta` | `float` | `0.5` | β for `generalized_jsd_loss`. |
| `sdpo_temperature` | `float` | `1.0` | SDPO softmax temperature. |
| `sdpo_token_clip` | `float \| None` | `None` | Per-token JSD clip. |
| `replay_dpo_beta` | `float` | `0.1` | DPO β. |
_compute_loss(model, inputs) -> torch.Tensor — overrides GRPOTrainer._compute_loss. Calls super()._compute_loss for channel 1, then _compute_sdpo_loss and _compute_trace_replay_loss, then composes. Logs per-channel components every args.logging_steps (default 50). Raises whatever super() raises (TRL-shaped errors).
Internal methods (publicly accessible, exercised by spike tests)
- ⚠️ UNTESTED-CONTRACT `_compute_sdpo_loss(model, inputs) -> torch.Tensor`: generalized JSD between the student forward pass and the `ctx_teacher_input_ids` forward pass. Returns `0.0` (with grad) when `alpha_sdpo == 0`, the key is missing, or the shapes mismatch; logs a warning on shape mismatch.
- ⚠️ UNTESTED-CONTRACT `_compute_trace_replay_loss(model, inputs) -> torch.Tensor`: standard DPO over `dpo_chosen_*` and `dpo_rejected_*`, using the precomputed `dpo_chosen_ref_logprobs` / `dpo_rejected_ref_logprobs`.
- ⚠️ UNTESTED-CONTRACT `@staticmethod _sequence_logprobs(model, input_ids, response_mask) -> torch.Tensor`: sums logprobs over response tokens (standard DPO accounting).
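The DPO accounting these methods describe (summed response-token logprobs, precomputed reference logprobs, a β-scaled margin) can be sketched in plain Python. This is the textbook DPO objective under those inputs, not the trainer's code; the helper names are hypothetical and the β default mirrors replay_dpo_beta.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def sequence_logprob(token_logprobs, response_mask):
    # SUM (not mean) over response tokens, as described for _sequence_logprobs
    return sum(lp * m for lp, m in zip(token_logprobs, response_mask))


def dpo_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=0.1):
    # -log sigma(beta * [(pi_c - ref_c) - (pi_r - ref_r)])
    margin = (pi_chosen - ref_chosen) - (pi_rejected - ref_rejected)
    return -log_sigmoid(beta * margin)


pi_c = sequence_logprob([-0.2, -1.0, -0.5], [0, 1, 1])
pi_r = sequence_logprob([-0.2, -3.0, -2.0], [0, 1, 1])
loss = dpo_loss(pi_c, pi_r, ref_chosen=-2.0, ref_rejected=-2.0)
```

When the policy matches the reference on both sequences the loss is log 2; it falls below that only as the policy shifts probability toward the chosen action relative to the reference.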
```python
from composer_replication import ComposerReplicationTrainer

trainer = ComposerReplicationTrainer(
    model=my_model, args=my_grpo_args, train_dataset=ds,
    data_collator=my_collator, alpha_sdpo=0.1, beta_replay=0.05,
)
# trainer.train()  # uses the overridden _compute_loss
```
class TraceTurn(TypedDict, total=False) — trainer.data_collator
```python
class TraceTurn(TypedDict, total=False):
    role: str  # "user" | "assistant" | "tool"
    content: str
    tool_call: dict | None
    tool_error: str | None
    error_meta: dict
```
One turn of an agentic trace as consumed by ComposerDataCollator.
class TraceExample(TypedDict, total=False) — trainer.data_collator
```python
class TraceExample(TypedDict, total=False):
    trace_id: str
    turns: list[TraceTurn]
    final_reward: float
    dpo_pairs: list[dict] | None
```
One training example: (turns, optional dpo_pairs). dpo_pairs shape matches DPOPair.
class TokenizerLike — trainer.data_collator
⚠️ UNTESTED-CONTRACT (duck-typed protocol; used as a type hint).
```python
class TokenizerLike:
    pad_token_id: int
    def __call__(self, text: str | list[str], **kwargs: Any) -> dict[str, list]: ...
    def apply_chat_template(self, messages: list[dict], **kwargs: Any) -> str | list[int]: ...
```
Minimal protocol the collator needs. Compatible with HF AutoTokenizer.
class CollatorConfig — trainer.data_collator
```python
@dataclass
class CollatorConfig:
    max_seq_len: int = 4096
    max_dpo_seq_len: int = 2048
    pad_token_id: int = 0
    ignore_index: int = -100
    enable_sdpo: bool = True
    hint_generator: Callable[[str, dict], str | None] | None = None
    enable_replay_dpo: bool = True
    rlvr_reward_key: str = "final_reward"
```
Tunables for ComposerDataCollator.
| Field | Default | Meaning |
|---|---|---|
| `max_seq_len` | `4096` | Truncation cap for student/teacher sequences. |
| `max_dpo_seq_len` | `2048` | Truncation cap for DPO chosen/rejected sequences. |
| `pad_token_id` | `0` | Padding token id. |
| `ignore_index` | `-100` | HF "ignore in loss" sentinel for the SDPO mask. |
| `enable_sdpo` | `True` | Toggle channel-2 fields. |
| `hint_generator` | `None` (type `Callable[[str, dict], str \| None] \| None`) | `(error_kind, error_meta) -> hint_text`. SDPO is a no-op without this. |
| `enable_replay_dpo` | `True` | Toggle channel-3 fields. |
| `rlvr_reward_key` | `"final_reward"` | Key in `TraceExample` from which the scalar reward is read. |
```python
from composer_replication.trainer.data_collator import CollatorConfig

cfg = CollatorConfig(max_seq_len=2048, hint_generator=my_dispatch)
```
class ComposerDataCollator — trainer.data_collator
```python
@dataclass
class ComposerDataCollator:
    tokenizer: TokenizerLike
    config: CollatorConfig = field(default_factory=CollatorConfig)

    def __call__(
        self, batch: Sequence[TraceExample]
    ) -> dict[str, torch.Tensor]: ...
```
Build trainer-ready batches from raw traces + optional DPO pairs.
Output dict keys (tested in spikes/005-integrated-trainer-skeleton/tests/test_data_collator.py):
- Channel 1 (always): `input_ids`, `attention_mask`, `response_mask`, `rewards`.
- Channel 2 (when `enable_sdpo=True` AND the batch has at least one error site AND `hint_generator` is set): `ctx_teacher_input_ids`, `sdpo_loss_mask`.
- Channel 3 (when `enable_replay_dpo=True` AND the batch has at least one `dpo_pair`): `dpo_chosen_input_ids`, `dpo_chosen_response_mask`, `dpo_rejected_input_ids`, `dpo_rejected_response_mask`. (Reference logprobs are NOT computed here; the trainer does that pass.)
```python
from composer_replication.trainer.data_collator import (
    ComposerDataCollator, CollatorConfig)

collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
batch = collator([{"trace_id": "x", "turns": [...], "final_reward": 1.0}])
```
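The channel-1 keys listed above can be pictured with a small padding helper. This is a hypothetical sketch, not the collator's implementation: right padding, keeping the first `max_seq_len` tokens on truncation, and a per-token "produced by the assistant" flag as the source of `response_mask` are all assumptions.

```python
import torch


def pad_channel1(sequences: list[list[int]], response_flags: list[list[int]],
                 rewards: list[float], *, pad_token_id: int = 0,
                 max_seq_len: int = 4096) -> dict[str, torch.Tensor]:
    """Hypothetical right-padding helper producing the channel-1 keys.

    sequences: token ids per example; response_flags: 1 where the token was
    produced by the assistant (i.e. should be trained on), else 0.
    """
    seqs = [s[:max_seq_len] for s in sequences]
    flags = [f[:max_seq_len] for f in response_flags]
    width = max(len(s) for s in seqs)
    input_ids = torch.full((len(seqs), width), pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((len(seqs), width), dtype=torch.long)
    response_mask = torch.zeros((len(seqs), width), dtype=torch.long)
    for i, (s, f) in enumerate(zip(seqs, flags)):
        input_ids[i, : len(s)] = torch.tensor(s, dtype=torch.long)
        attention_mask[i, : len(s)] = 1  # real tokens only; padding stays 0
        response_mask[i, : len(f)] = torch.tensor(f, dtype=torch.long)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "response_mask": response_mask,
        "rewards": torch.tensor(rewards, dtype=torch.float32),
    }
```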
11. composer_replication.diloco
DiLoCo outer-loop wrapper around torchft.local_sgd.DiLoCo. Optional dep — when torchft is missing the package re-export composer_replication.make_diloco_outer_loop is None.
Module-level attributes
- `DiLoCo: Any`: `torchft.local_sgd.DiLoCo` if importable, else `None`.
- `Manager: Any`: `torchft.manager.Manager` if importable, else `None`.
- `_DummyWork: Any`: `torchft.work._DummyWork` if importable, else `None`.
- `_TORCHFT_AVAILABLE: bool`: whether the imports succeeded.
```python
from composer_replication.diloco import _TORCHFT_AVAILABLE, DiLoCo
```
make_diloco_outer_loop(manager, model_fragments, inner_optimizer, *, ...) -> torchft.local_sgd.DiLoCo
```python
def make_diloco_outer_loop(
    manager: Any,
    model_fragments: list[torch.nn.Module],
    inner_optimizer: torch.optim.Optimizer,
    *,
    outer_lr: float = 0.7,
    outer_momentum: float = 0.9,
    nesterov: bool = True,
    sync_every: int = 100,
    fragment_sync_delay: int = 0,
    fragment_update_alpha: float = 0.0,
) -> Any: ...
```
Construct a torchft.DiLoCo configured with framework-default hyperparams (DiLoCo paper §3.2: lr=0.7, momentum=0.9, Nesterov).
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| `manager` | `torchft.Manager` (or duck-typed `MockManager`) | required | Provides `allreduce`, `should_commit`, `current_step`, `start_quorum`, etc. |
| `model_fragments` | `list[torch.nn.Module]` | required | One module for vanilla DiLoCo; N modules for Streaming DiLoCo. |
| `inner_optimizer` | `torch.optim.Optimizer` | required | Inner-step optimizer (steps every batch). |
| `outer_lr` | `float` | `0.7` | Outer SGD learning rate. |
| `outer_momentum` | `float` | `0.9` | Outer SGD momentum. |
| `nesterov` | `bool` | `True` | Use Nesterov momentum in the outer SGD. |
| `sync_every` | `int` | `100` | Inner steps per outer round. |
| `fragment_sync_delay` | `int` | `0` | `0` = vanilla; `>0` = Streaming DiLoCo (requires CUDA streams). |
| `fragment_update_alpha` | `float` | `0.0` | `0` = full replacement on sync; `>0` = exponential mix. |
Returns a torchft.local_sgd.DiLoCo instance — usable as a context manager.
Raises RuntimeError if torchft is not installed.
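```python
import torch
from composer_replication.diloco import make_diloco_outer_loop

# model, mgr (a torchft.Manager or MockManager), batches and compute_loss
# are placeholders supplied by your training script.
opt = torch.optim.AdamW(model.parameters(), lr=1e-5)
outer = make_diloco_outer_loop(manager=mgr, model_fragments=[model],
                               inner_optimizer=opt, sync_every=100)
with outer:
    for batch in batches:
        opt.zero_grad()
        loss = compute_loss(model, batch)
        loss.backward()
        opt.step()  # the outer sync runs every `sync_every` inner steps
```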
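What one outer round computes can be sketched without `torchft`, assuming the common DiLoCo formulation: the pseudo-gradient is the global parameters minus each replica's locally trained parameters, averaged across replicas, then applied by SGD with Nesterov momentum (`outer_lr=0.7`, `outer_momentum=0.9`). `outer_step` is a hypothetical helper, the all-reduce is simulated with a plain mean, and `torchft`'s real implementation (fragments, quorum, commit checks) does considerably more.

```python
import torch


def outer_step(global_params: list[torch.Tensor],
               local_params_per_replica: list[list[torch.Tensor]],
               outer_opt: torch.optim.Optimizer) -> None:
    """One DiLoCo outer update, with the all-reduce simulated by a mean.

    global_params are the tensors registered with outer_opt; they hold the
    weights every replica started the round from.
    """
    with torch.no_grad():
        for j, p in enumerate(global_params):
            # Pseudo-gradient: global minus local, so an SGD step (p -= lr * grad)
            # moves the global weights toward where the replicas ended up.
            deltas = [p - local[j] for local in local_params_per_replica]
            p.grad = torch.stack(deltas).mean(dim=0)
    outer_opt.step()
    outer_opt.zero_grad()
```

With `lr=1.0` and no momentum, the update reduces to plain parameter averaging; the momentum term is what lets DiLoCo take larger outer steps.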
12. composer_replication.diloco.serverless
ADR-005 serverless DiLoCo executors + object-store all-reduce.
class ReplicaHandle — serverless.executor
@dataclass
class ReplicaHandle:
rank: int
backend_name: str
metadata: dict[str, Any] = field(default_factory=dict)
Opaque handle returned by ServerlessExecutor.launch_replicas. metadata is backend-specific.
```python
from composer_replication.diloco.serverless import ReplicaHandle

h = ReplicaHandle(rank=0, backend_name="local_process",
                  metadata={"pid": 12345})
```
class ServerlessExecutor (Protocol) — serverless.executor
```python
@runtime_checkable
class ServerlessExecutor(Protocol):
    backend_name: str
    supports_inter_replica_network: bool

    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]: ...

    def poll(self, handle: ReplicaHandle) -> str: ...
    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str: ...
    def cancel(self, handle: ReplicaHandle) -> None: ...

    def collect(
        self, handles: list[ReplicaHandle], *, timeout: int | None = None,
    ) -> list[dict[str, Any]]: ...
```
Structural protocol for serverless backends.
- `launch_replicas(...)` returns `list[ReplicaHandle]` of length `n_replicas`, in rank order. `entrypoint` is an importable module path (its `main()` is used), a `module.function` path, or a `Callable` (local executor only). `entrypoint_args` may include `rank_env` (default `"REPLICA_RANK"`).
- `poll(handle) -> str`: one of `"pending"`, `"running"`, `"succeeded"`, `"failed"`, `"cancelled"`.
- `stream_logs(handle, n_lines=200) -> str`: best-effort recent stdout/stderr.
- `cancel(handle) -> None`: best-effort.
- `collect(handles, timeout=None) -> list[dict]`: blocks; each result dict has `rank`, `status`, `exit_code`, `error` (plus `result` from `LocalProcessExecutor`).
```python
from composer_replication.diloco.serverless import ServerlessExecutor

def supports(x: ServerlessExecutor) -> bool:
    return isinstance(x, ServerlessExecutor)  # runtime_checkable
```
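Because `collect` blocks, a caller that wants progress reporting could poll instead. The helper below is a hypothetical sketch written against the documented `poll` contract: the terminal-state set comes from the status strings listed above, and `FakeHandle` / `FakeExecutor` exist only so the example runs without any backend.

```python
import time
from dataclasses import dataclass, field
from typing import Any

TERMINAL = {"succeeded", "failed", "cancelled"}


def wait_all(executor: Any, handles: list, *, timeout: float = 3600.0,
             interval: float = 1.0) -> dict[int, str]:
    """Poll until every handle reaches a terminal status; return rank -> status."""
    deadline = time.monotonic() + timeout
    while True:
        statuses = {h.rank: executor.poll(h) for h in handles}
        if all(s in TERMINAL for s in statuses.values()):
            return statuses
        if time.monotonic() >= deadline:
            raise TimeoutError(f"replicas still running: {statuses}")
        time.sleep(interval)


@dataclass
class FakeHandle:
    rank: int
    backend_name: str = "fake"
    metadata: dict[str, Any] = field(default_factory=dict)


class FakeExecutor:
    """Reports 'running' on the first poll of each handle, then 'succeeded'."""

    def __init__(self) -> None:
        self.calls: dict[int, int] = {}

    def poll(self, handle: FakeHandle) -> str:
        self.calls[handle.rank] = self.calls.get(handle.rank, 0) + 1
        return "running" if self.calls[handle.rank] == 1 else "succeeded"
```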
class LocalProcessExecutor — serverless.executor
```python
class LocalProcessExecutor:
    backend_name = "local_process"
    supports_inter_replica_network = True

    def __init__(self) -> None: ...
    # implements ServerlessExecutor protocol
```
Reference implementation using Python multiprocessing (spawn context). Used for tests, CI smokes, and local development with file:// rendezvous.
launch_replicas(...): emits a soft warning on gpu != None (local processes share whatever GPUs are visible). metadata = {"pid": ..., "start_ts": ...}.
```python
from composer_replication.diloco.serverless import LocalProcessExecutor

ex = LocalProcessExecutor()
handles = ex.launch_replicas(
    n_replicas=2,
    entrypoint="composer_replication.diloco.serverless.replica_entrypoint",
    entrypoint_args={"rendezvous_uri": "/tmp/run/", "world_size": 2,
                     "trainer_module": "my.trainer"},
)
results = ex.collect(handles, timeout=60)
```
class ObjectStoreAllReduce — serverless.allreduce
```python
class ObjectStoreAllReduce:
    def __init__(
        self,
        uri: str,
        rank: int,
        world_size: int,
        *,
        round_id: int | None = None,
        timeout_s: float = 1800.0,
        poll_interval_s: float = 1.0,
    ) -> None: ...

    @property
    def round_id(self) -> int: ...

    def allreduce(
        self, tensor: torch.Tensor, *, name: str | None = None,
    ) -> torch.Tensor: ...
```
fsspec-backed pseudo-gradient rendezvous. uri accepts s3://, gs://, az://, hf://, file://, or a plain local path.
Constructor parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| `uri` | `str` | required | fsspec URI or local path. A trailing `/` is enforced. |
| `rank` | `int` | required | This replica's rank. |
| `world_size` | `int` | required | Total number of replicas. |
| `round_id` | `int \| None` (kw-only) | `None` (start at 0) | Initial round counter. |
| `timeout_s` | `float` (kw-only) | `1800.0` | Per-allreduce timeout, in seconds. |
| `poll_interval_s` | `float` (kw-only) | `1.0` | Sleep between peer-file existence checks, in seconds. |
allreduce(tensor, name=None) -> torch.Tensor: serializes tensor.detach().cpu() to round_NNNNNN/rank_RRRR.pt, blocks until all peers post, then averages. Modifies tensor in place AND returns it. Increments the internal _round_counter.
Raises ValueError on invalid rank, RuntimeError if non-local URI is requested without fsspec installed, TimeoutError if peers don't show up before timeout_s.
```python
from composer_replication.diloco.serverless import ObjectStoreAllReduce
import torch

store = ObjectStoreAllReduce("/tmp/run/", rank=0, world_size=2)
g = torch.zeros(10)
store.allreduce(g)  # blocks until rank 1 posts its tensor for this round
```
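The rendezvous described for `allreduce` can be sketched with the local filesystem alone. This is a simplified, hypothetical re-creation (no fsspec, no cleanup of old rounds); only the `round_NNNNNN/rank_RRRR.pt` layout, the poll loop and the averaging follow the description above. Writing through a temporary file plus `os.replace` is an added assumption so peers never read a half-written file, and two threads stand in for two replicas.

```python
import os
import tempfile
import threading
import time

import torch


def file_allreduce(root: str, rank: int, world_size: int, round_id: int,
                   tensor: torch.Tensor, *, timeout_s: float = 30.0,
                   poll_interval_s: float = 0.05) -> torch.Tensor:
    """Average `tensor` across ranks through files; modifies it in place."""
    round_dir = os.path.join(root, f"round_{round_id:06d}")
    os.makedirs(round_dir, exist_ok=True)
    final = os.path.join(round_dir, f"rank_{rank:04d}.pt")
    tmp = final + ".tmp"
    torch.save(tensor.detach().cpu(), tmp)
    os.replace(tmp, final)  # atomic rename: peers never see a partial file

    peers = [os.path.join(round_dir, f"rank_{r:04d}.pt") for r in range(world_size)]
    deadline = time.monotonic() + timeout_s
    while not all(os.path.exists(p) for p in peers):
        if time.monotonic() >= deadline:
            raise TimeoutError(f"round {round_id}: peers missing in {round_dir}")
        time.sleep(poll_interval_s)

    total = sum(torch.load(p, weights_only=True) for p in peers)
    tensor.copy_(total / world_size)
    return tensor


def demo_two_replicas() -> tuple[torch.Tensor, torch.Tensor]:
    """Run two simulated replicas as threads against a temporary directory."""
    a, b = torch.tensor([1.0, 2.0]), torch.tensor([3.0, 6.0])
    with tempfile.TemporaryDirectory() as root:
        workers = [threading.Thread(target=file_allreduce, args=(root, r, 2, 0, t))
                   for r, t in enumerate((a, b))]
        for w in workers:
            w.start()
        for w in workers:
            w.join()
    return a, b
```

After `demo_two_replicas()`, both tensors hold the element-wise mean `[2.0, 4.0]`.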
class MockManager — serverless.allreduce
```python
class MockManager:
    def __init__(self, store: ObjectStoreAllReduce) -> None: ...

    # torchft.Manager-shaped surface:
    num_participants: int
    rank: int
    _use_async_quorum: bool  # always False
    _step: int
    _state_dict_fns: dict[str, tuple[Any, Any]]

    def allreduce(self, tensor: torch.Tensor, **_kwargs: Any) -> "_ImmediateWork": ...
    def should_commit(self) -> bool: ...
    def start_quorum(self) -> None: ...
    def wait_quorum(self) -> int: ...
    def current_step(self) -> int: ...
    def allow_state_dict_read(self) -> None: ...
    def disallow_state_dict_read(self) -> None: ...
    def register_state_dict_fn(self, key: str, load_fn: Any, save_fn: Any) -> None: ...
    def is_leader(self) -> bool: ...
```
Drop-in replacement for torchft.Manager that routes allreduce through ObjectStoreAllReduce. All other methods are no-ops or simple counters appropriate for single-shot serverless DiLoCo.
- `allreduce(tensor)` returns an `_ImmediateWork` whose `.wait()` is a no-op (the tensor is already averaged).
- `should_commit()` always returns `True` (no fault-tolerance failover).
- `start_quorum()` bumps `_step`.
- `is_leader()` returns `rank == 0`.
```python
from composer_replication.diloco.serverless import MockManager, ObjectStoreAllReduce

store = ObjectStoreAllReduce("/tmp/run/", rank=0, world_size=2)
mgr = MockManager(store)
# pass mgr into make_diloco_outer_loop(manager=mgr, ...)
```
class _ImmediateWork — serverless.allreduce
⚠️ UNTESTED-CONTRACT internal helper exported from __all__. Work-shaped wrapper with .wait() -> True and .get_future() -> torch.futures.Future. Consumed by torchft DiLoCo's perform_sync.
```python
from composer_replication.diloco.serverless.allreduce import _ImmediateWork
```
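The "already finished" work object can be sketched with `torch.futures.Future`, which is part of PyTorch. `ImmediateWorkSketch` is a hypothetical stand-in, not the framework's `_ImmediateWork`. It only mirrors the documented shape (`wait()` returns `True`, `get_future()` returns a `Future`); having the future hold the averaged tensor is an assumption.

```python
import torch


class ImmediateWorkSketch:
    """Hypothetical Work-like wrapper for a result that is already computed."""

    def __init__(self, tensor: torch.Tensor) -> None:
        self._future: torch.futures.Future = torch.futures.Future()
        self._future.set_result(tensor)  # completed before anyone waits

    def wait(self, timeout=None) -> bool:
        # Nothing is in flight, so waiting is a no-op.
        return True

    def get_future(self) -> torch.futures.Future:
        return self._future
```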
class ModalExecutor — serverless.modal
🟡 SKELETON — raises NotImplementedError; see ADR-005. Class body documents the v0 implementation pattern (Modal app.function + function.spawn(rank=...)).
```python
from composer_replication.diloco.serverless.modal import ModalExecutor
# ModalExecutor()  # raises NotImplementedError when instantiated
```
class HFJobsExecutor — serverless.hf_jobs
🟡 SKELETON — raises NotImplementedError; see ADR-005. Class body documents the v0 pattern using huggingface_hub.run_job against hf://datasets/.../ rendezvous.
```python
from composer_replication.diloco.serverless.hf_jobs import HFJobsExecutor
# Instantiation will fail until the v0 implementation lands.
```
replica_entrypoint.main(...) — serverless.replica_entrypoint
```python
def main(
    rendezvous_uri: str,
    world_size: int,
    trainer_module: str,
    trainer_fn: str = "train",
    trainer_kwargs: dict[str, Any] | None = None,
) -> Any: ...
```
Script run by every replica. Reads REPLICA_RANK env var, builds ObjectStoreAllReduce + MockManager, imports trainer_module, and calls getattr(mod, trainer_fn)(**trainer_kwargs, manager=..., rank=..., world_size=...). Returns whatever the train fn returns.
Raises RuntimeError if REPLICA_RANK env var is missing; ValueError if rank ∉ [0, world_size).
The if __name__ == "__main__" block accepts CLI flags --rendezvous, --world-size, --trainer-module, --trainer-fn, --trainer-kwargs-json.
```python
# In-process invocation
import os

os.environ["REPLICA_RANK"] = "0"

from composer_replication.diloco.serverless.replica_entrypoint import main

result = main(rendezvous_uri="/tmp/run/", world_size=1,
              trainer_module="my.trainer", trainer_fn="train")
```
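Because `main` calls `getattr(mod, trainer_fn)(**trainer_kwargs, manager=..., rank=..., world_size=...)`, the trainer function must accept those three keywords. Below is a minimal hypothetical body for the `my.trainer` module and `train` function named in the example above. `steps` is an illustrative `trainer_kwargs` entry, the loop only exercises `manager.allreduce(...).wait()` instead of training a model, and `SingleReplicaManager` is a made-up stand-in so the sketch runs without `MockManager`.

```python
from typing import Any

import torch


def train(*, manager: Any, rank: int, world_size: int, steps: int = 3) -> dict[str, Any]:
    """Hypothetical trainer entry point for replica_entrypoint.main.

    manager, rank and world_size are injected by main; steps would arrive
    via trainer_kwargs. Each step averages a rank-dependent tensor.
    """
    value = torch.full((4,), float(rank))
    for _ in range(steps):
        manager.allreduce(value).wait()  # in-place average across replicas
    return {"rank": rank, "world_size": world_size, "mean": value.mean().item()}


class SingleReplicaManager:
    """Stand-in for MockManager when world_size == 1: allreduce is the identity."""

    class _Done:
        def wait(self) -> bool:
            return True

    def allreduce(self, tensor: torch.Tensor, **_kwargs: Any) -> "_Done":
        return self._Done()
```

Whatever `train` returns becomes the return value of `main`.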
13. composer_replication.recipes.prime_rl.composer_loss
PRIME-RL adapter (ADR-006). Maps PRIME-RL's LossInputs struct onto channel 1 (DPPO + KL on the importance ratio, mirroring PRIME-RL's upstream default_loss_fn at prime_rl/trainer/rl/loss.py lines 116-165). Channel 2 raises NotImplementedError; channel 3 is out of scope.
loss_fn(inputs, *, alpha_sdpo=0.0, beta_dpo=0.0, dppo_mask_high=0.2, dppo_mask_low=0.2, adv_tau=1.0, kl_tau=1e-3) -> torch.Tensor
```python
def loss_fn(
    inputs: Any,  # PRIME-RL's LossInputs (duck-typed)
    *,
    alpha_sdpo: float = 0.0,
    beta_dpo: float = 0.0,
    dppo_mask_high: float = 0.2,
    dppo_mask_low: float = 0.2,
    adv_tau: float = 1.0,
    kl_tau: float = 1e-3,
) -> Any: ...  # torch.Tensor scalar
```
PRIME-RL passes per-sample 1-D (seq,) tensors (not batched). The function mirrors PRIME-RL's upstream DPPO+KL formula:
- The mask gate is in probability space: `probs_diff = exp(trainer_lp) - exp(inference_lp)` (NOT on the log-ratio).
- A token is dropped iff its advantage sign matches the offending bound: positive-advantage tokens are dropped when `probs_diff > dppo_mask_high`, negative-advantage tokens when `probs_diff < -dppo_mask_low`. (PRIME-RL stores both bounds with `Field(..., ge=0)` and applies the sign internally.)
- The PG term is `keep * (adv_tau * advantages) * exp(trainer_lp - inference_lp)` (importance-ratio corrected, not REINFORCE).
- A KL penalty `kl_tau * log_importance_ratio**2` is added on the full `loss_mask` (DPPO masking does not gate it).
- Reduction is a plain `sum()`; PRIME-RL's outer `compute_loss` divides by `loss_scale`.
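The five rules above can be transcribed into a short PyTorch sketch. It illustrates the documented formula and is not the framework's `loss_fn`. The function name is hypothetical, and two details are assumptions not stated above: the PG term enters the loss with a negative sign (so minimizing it raises the advantage-weighted ratio), and `keep` also includes `loss_mask`.

```python
import torch


def dppo_kl_loss_sketch(trainer_lp: torch.Tensor, inference_lp: torch.Tensor,
                        advantages: torch.Tensor, loss_mask: torch.Tensor, *,
                        dppo_mask_high: float = 0.2, dppo_mask_low: float = 0.2,
                        adv_tau: float = 1.0, kl_tau: float = 1e-3) -> torch.Tensor:
    """Per-sample DPPO + KL loss on 1-D (seq,) tensors, reduced with sum()."""
    mask = loss_mask.to(trainer_lp.dtype)
    log_ratio = trainer_lp - inference_lp
    ratio = torch.exp(log_ratio)

    # Gate in probability space, not on the log-ratio.
    probs_diff = torch.exp(trainer_lp) - torch.exp(inference_lp)
    drop = (((advantages > 0) & (probs_diff > dppo_mask_high))
            | ((advantages < 0) & (probs_diff < -dppo_mask_low)))
    keep = mask * (~drop).to(mask.dtype)

    # Importance-ratio-corrected policy-gradient term (negated for minimization).
    pg = -(keep * (adv_tau * advantages) * ratio)
    # KL penalty on the full loss_mask; DPPO masking does not gate it.
    kl = kl_tau * log_ratio.pow(2) * mask
    return (pg + kl).sum()
```

For example, a positive-advantage token whose trainer probability rose from 0.5 to 0.9 (`probs_diff = 0.4 > 0.2`) contributes no PG term, but still pays the KL penalty.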
Parameters
| Name | Type | Default | Meaning |
|---|---|---|---|
| `inputs` | PRIME-RL `LossInputs` (duck-typed) | required | Must expose `trainer_logprobs`, `inference_logprobs`, `advantages`, `loss_mask` (all 1-D), and optionally `teacher_logprobs`. |
| `alpha_sdpo` | `float` (kw-only) | `0.0` | Channel-2 weight. Must be 0 in v0; `> 0` raises `NotImplementedError`. |
| `beta_dpo` | `float` (kw-only) | `0.0` | Channel-3 weight. A non-zero value emits a `UserWarning`. |
| `dppo_mask_high` | `float` (kw-only), `>= 0` | `0.2` | Upper probability-difference threshold. PRIME-RL `DefaultLossConfig` default. |
| `dppo_mask_low` | `float` (kw-only), `>= 0` | `0.2` | Magnitude of the lower probability-difference threshold (sign flipped internally). PRIME-RL default. |
| `adv_tau` | `float` (kw-only), `>= 0` | `1.0` | Advantage temperature. PRIME-RL default. |
| `kl_tau` | `float` (kw-only), `>= 0` | `1e-3` | KL-term temperature. PRIME-RL default. |
Returns scalar torch.Tensor (PRIME-RL's trainer calls .backward()).
Raises `ValueError` if any of `trainer_logprobs`, `inference_logprobs`, `advantages`, `loss_mask` is not 1-D or any of the four `>= 0`-constrained knobs is negative, and `NotImplementedError` if `alpha_sdpo > 0` (channel 2 is deferred).
```python
from composer_replication.recipes.prime_rl.composer_loss import loss_fn

# In PRIME-RL config:
# loss:
#   custom:
#     import_path: composer_replication.recipes.prime_rl.composer_loss:loss_fn
#     kwargs:
#       dppo_mask_high: 0.2
#       dppo_mask_low: 0.2
#       adv_tau: 1.0
#       kl_tau: 1.0e-3
```
14. composer_replication.recipes.monarch.actors
🟡 SKELETON module per ADR-006. Importable; classes raise NotImplementedError on instantiation. Documents the actor signatures so the recipe matrix is complete.
class TrainerActor 🟡
```python
class TrainerActor:
    backend = "monarch"
    role = "trainer"
    def __init__(self) -> None: raise NotImplementedError(...)
    async def train_outer_step(self, batch_id: int) -> dict[str, Any]: raise NotImplementedError
```
Hosts the framework's 3-channel composer trainer. Real impl deferred to v0.2+.
class GeneratorActor 🟡
```python
class GeneratorActor:
    backend = "monarch"
    role = "generator"
    def __init__(self) -> None: raise NotImplementedError(...)
    async def rollout(self, prompts: list[str]) -> list[str]: raise NotImplementedError
```
vLLM-backed rollout actor.
class RewarderActor 🟡
```python
class RewarderActor:
    backend = "monarch"
    role = "rewarder"
    def __init__(self) -> None: raise NotImplementedError(...)
    async def score(self, completions: list[str]) -> list[float]: raise NotImplementedError
```
verifiers-protocol rewarder.
class TeacherPoolActor 🟡
```python
class TeacherPoolActor:
    backend = "monarch"
    role = "teacher_pool"
    def __init__(self) -> None: raise NotImplementedError(...)
```
Channel-3 teacher pool wrapping composer_replication.teacher_replay.
```python
# All Monarch actors raise on instantiation in v0:
from composer_replication.recipes.monarch.actors import TrainerActor
# TrainerActor()  # NotImplementedError
```
Notes on test coverage
Tested contracts (referenced spike/test paths):
- `compose_loss` + `LossComponents` + `build_batch`: `composer_replication/tests/test_compose_loss_integration.py`, `spikes/006-real-hf-model-smoke/tests/`.
- `generalized_jsd_loss`: `spikes/005-integrated-trainer-skeleton/tests/test_opsd_loss.py`.
- `simpo_loss`, `taid_loss`, `taid_alpha_schedule`, `taid_blended_logits`, `entropy_aware_opd_loss`: `composer_replication/distillation/tests/test_distillation_losses.py`.
- `replay_trace`, `extract_dpo_pairs`, `DPOPair`, `TraceState`, `TeacherCallResult`, `TeacherSpec`, `DEFAULT_TEACHERS`: `spikes/005-integrated-trainer-skeleton/tests/test_teacher_replay.py`.
- `DJNormalizer`, `NormalizedDPOPair`, `replay_and_normalize_trace`: `composer_replication/replaysim/tests/test_replaysim.py`.
- `ClaudeCodeIngester`, `IngestionStats`, `SYSTEM_PROMPT`: `spikes/007-real-trace-ingestion/tests/`.
- `ComposerDataCollator`, `CollatorConfig`, `TraceTurn`, `TraceExample`: `spikes/005-integrated-trainer-skeleton/tests/test_data_collator.py`.
- `ComposerReplicationTrainer._compute_loss` (composition arithmetic): `spikes/005-integrated-trainer-skeleton/tests/test_loss_composition_smoke.py`.
- `make_diloco_outer_loop` + sign convention: `spikes/008-streaming-diloco/tests/test_diloco_smoke.py`.
- `ObjectStoreAllReduce`, `MockManager`, `LocalProcessExecutor`, `ReplicaHandle`, `ServerlessExecutor`, `replica_entrypoint.main`: `composer_replication/diloco/serverless/tests/test_serverless_local.py`, `test_serverless_diloco_integration.py`.
- `recipes.prime_rl.composer_loss.loss_fn`: `composer_replication/recipes/prime_rl/tests/test_composer_loss.py`.
Untested-contract symbols (⚠️) and skeletons (🟡) are flagged inline above.
Document path: /mnt/e/CS/HF/composer-replication-framework/docs/API_REFERENCE.md