"""composer_replication: Composer 2.5 Replication Framework.

A research-grade, open replication of Cursor Composer 2.5's training recipe:
take any HuggingFace model and continue training it with RL, using a
three-channel loss that combines:
1. RLVR / GRPO (channel 1, via TRL)
2. SDPO hint-distillation (channel 2, OPSD-based)
3. Multi-teacher trace-replay DPO (channel 3, this framework's contribution)

Distributed runs can optionally add DiLoCo / Streaming DiLoCo outer-loop sync.
See https://huggingface.co/Codeseys/composer-replication-framework for the
full project README, design docs, ADRs, and verification spikes.
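
The DiLoCo outer-loop sync mentioned above follows a simple pattern. Each worker takes many local optimizer steps. The workers' parameter changes are then averaged into a "pseudo-gradient", which an outer optimizer (Nesterov-momentum SGD in the DiLoCo paper) applies to the shared weights. Streaming DiLoCo applies the same kind of update to parameter fragments on a staggered schedule. The sketch below shows one plain DiLoCo outer step on Python lists. It illustrates the algorithm under those assumptions and is not the API of `make_diloco_outer_loop`, which builds on torchft. `diloco_outer_step` and its default hyperparameters are hypothetical.

```python
from typing import List


def diloco_outer_step(
    global_params: List[float],
    worker_params: List[List[float]],
    momentum_buf: List[float],
    outer_lr: float = 0.7,
    momentum: float = 0.9,
) -> List[float]:
    # Hypothetical helper, not this package's API: one DiLoCo outer step on
    # flat parameter lists. momentum_buf is updated in place.
    n_workers = len(worker_params)
    new_params: List[float] = []
    for i, theta in enumerate(global_params):
        # Pseudo-gradient: shared starting point minus the workers' average,
        # so a descent step moves the shared weights toward the workers.
        avg_local = sum(w[i] for w in worker_params) / n_workers
        pseudo_grad = theta - avg_local
        # Nesterov momentum in the same form as torch.optim.SGD(nesterov=True).
        momentum_buf[i] = momentum * momentum_buf[i] + pseudo_grad
        update = pseudo_grad + momentum * momentum_buf[i]
        new_params.append(theta - outer_lr * update)
    return new_params
```

With `outer_lr=1.0` and `momentum=0.0` the step reduces to plain parameter averaging across workers.
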
## Two API surfaces, on purpose
This package deliberately exposes BOTH a verification-harness API and a
production-trainer API. Pick the one that matches your goal:
### Verification harness (small, easy to call, NOT for real training)
`compose_loss(model, batch, alpha_sdpo, beta_replay)` is a free function
that returns `LossComponents(lm_ce, sdpo_jsd, trace_replay_dpo, total)`.
It stubs the GRPO channel with LM cross-entropy on response tokens (the
limit GRPO converges to under deterministic rewards), so you can verify
that the three channels compose correctly WITHOUT spinning up TRL's full
reward and advantage machinery.
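
Conceptually, `total` combines the stubbed GRPO term with the two auxiliary channels, scaled by `alpha_sdpo` and `beta_replay`. The sketch below assumes a simple linear weighting. The actual formula lives in `composer_replication.loss` and may differ. `LossComponentsSketch` and `compose_loss_sketch` are hypothetical float-only stand-ins, not this package's classes.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LossComponentsSketch:
    # Hypothetical stand-in for LossComponents: same field names as the
    # documented return value, but plain floats instead of tensors.
    lm_ce: float
    sdpo_jsd: float
    trace_replay_dpo: float
    total: float


def compose_loss_sketch(
    lm_ce: float,
    sdpo_jsd: float,
    trace_replay_dpo: float,
    alpha_sdpo: float,
    beta_replay: float,
) -> LossComponentsSketch:
    # Assumed linear weighting: channel 1 (here the CE stub) is unweighted,
    # channels 2 and 3 are scaled by their coefficients.
    total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo
    return LossComponentsSketch(lm_ce, sdpo_jsd, trace_replay_dpo, total)
```

Under this assumption, setting `alpha_sdpo=0` and `beta_replay=0` collapses the total to the CE stub, which isolates channel 1.
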
`build_batch(tokenizer)` produces a real, chat-template-formatted batch
containing every key that `compose_loss` may consume.
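
Cross-entropy "on response tokens" is conventionally implemented by copying the token ids into `labels` and setting prompt positions to `-100`, which is the default `ignore_index` of PyTorch's `cross_entropy`. Hugging Face causal-LM models shift the labels internally, so the labels stay aligned with `input_ids`. The sketch below shows that masking on plain lists, assuming `build_batch` follows the same convention. `mask_prompt_tokens` is a hypothetical helper, not part of this package.

```python
from typing import Dict, List

# Default ignore_index of torch.nn.functional.cross_entropy.
IGNORE_INDEX = -100


def mask_prompt_tokens(prompt_ids: List[int], response_ids: List[int]) -> Dict[str, List[int]]:
    # Hypothetical helper: concatenate prompt and response, and supervise
    # only the response tokens (prompt positions are ignored by the loss).
    input_ids = list(prompt_ids) + list(response_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    attention_mask = [1] * len(input_ids)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
```
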
Use these for:
- CPU smokes on real HF models (Spike 006 / Spike 002a-mini-gpu)
- Unit testing custom loss-composition variants
- Debugging gradient flow through one of the three channels
- Anything where you want to call `backward()` on a real model without
  spinning up TRL
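
When you debug one channel in isolation, a tiny reference implementation to compare against is handy. The sketch below works on plain Python lists and computes two quantities. The first is a beta-weighted generalized Jensen-Shannon divergence between a teacher and a student next-token distribution, in the GKD-style form. This is an assumption about what `generalized_jsd_loss` evaluates per token. The second is the standard sigmoid DPO loss for one chosen/rejected pair, which is assumed here to be the core of channel 3. All function names are hypothetical, and both `beta` parameters are unrelated to `beta_replay`.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    # Numerically stable softmax over one next-token logit vector.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def kl(p: List[float], q: List[float]) -> float:
    # KL(p || q) in nats; entries with p_i == 0 contribute nothing.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def generalized_jsd(teacher: List[float], student: List[float], beta: float = 0.5) -> float:
    # Hypothetical channel-2 reference: beta-weighted JSD against the mixture
    # m = beta * teacher + (1 - beta) * student (assumed GKD-style form).
    if not 0.0 < beta < 1.0:
        raise ValueError("beta must lie strictly between 0 and 1")
    mix = [beta * t + (1.0 - beta) * s for t, s in zip(teacher, student)]
    return beta * kl(teacher, mix) + (1.0 - beta) * kl(student, mix)


def dpo_loss(policy_chosen: float, policy_rejected: float,
             ref_chosen: float, ref_rejected: float, beta: float = 0.1) -> float:
    # Hypothetical channel-3 reference: sigmoid DPO loss for one pair. Inputs
    # are summed log-probs of the chosen / rejected responses under the
    # trained policy and a frozen reference model.
    margin = beta * ((policy_chosen - ref_chosen) - (policy_rejected - ref_rejected))
    # -log(sigmoid(margin)) == softplus(-margin), written to avoid overflow.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))
```
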
### Production trainer (use for actual training runs)
`ComposerReplicationTrainer` is a `trl.GRPOTrainer` subclass that
overrides `_compute_loss(model, inputs)` to compose the same 3 channels
on top of TRL's real GRPO machinery. This is what you train models with.
Use this for:
- Real training runs on HF models with real rollouts + rewards
- Anything where the GRPO channel's policy-gradient signal matters
  (i.e., not a memorization smoke)
The verification harness's `compose_loss` is intentionally NOT a
drop-in replacement for `_compute_loss`: the two target different
phases of the framework's lifecycle.
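
The GRPO machinery that the production trainer inherits from TRL centers on group-relative advantages. Several completions are sampled per prompt, and each completion's reward is normalized against the other completions in its group. The sketch below shows that normalization in plain Python. It illustrates the GRPO idea and is not TRL's code; TRL's exact scaling and epsilon may differ. `group_relative_advantages` is a hypothetical name.

```python
import statistics
from typing import List


def group_relative_advantages(rewards: List[float], group_size: int, eps: float = 1e-4) -> List[float]:
    # Hypothetical helper: rewards are laid out as consecutive groups of
    # group_size completions, each group sampled from the same prompt.
    if len(rewards) % group_size != 0:
        raise ValueError("rewards must contain whole groups")
    advantages: List[float] = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mean = statistics.fmean(group)
        std = statistics.pstdev(group)
        # Center on the group mean, scale by the group's spread.
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages
```

If every completion in a group receives the same reward, the group yields zero advantage and therefore contributes no policy-gradient signal.
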
## Quickstart (verification-harness API)
>>> from composer_replication import compose_loss, build_batch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
>>> batch = build_batch(tokenizer)
>>> components = compose_loss(model, batch, alpha_sdpo=0.1, beta_replay=0.05)
>>> components.total.backward()
See `examples/qwen_05b_quickstart/run.py` in the repo for a complete CPU
smoke (verification harness) and `spikes/002a-mini-gpu-smoke/run_gpu_smoke.py`
for a GPU smoke (verification harness, bf16, 50 steps).
For production-trainer usage, see `docs/INTEGRATION_ARCHITECTURE.md` Recipe A.
"""
from __future__ import annotations
# Loss composition (Spike 006)
from composer_replication.loss import LossComponents, compose_loss
from composer_replication.batch import build_batch

# Trace ingestion (Spike 007)
from composer_replication.ingestion.claude_code import (
    SYSTEM_PROMPT,
    ClaudeCodeIngester,
    IngestionStats,
)
# OPSD / SDPO loss (verified extension from siyan-zhao/OPSD, MIT)
from composer_replication.opsd import generalized_jsd_loss

# Teacher replay (Spike 001 → trainer)
from composer_replication.teacher_replay import (
    DEFAULT_TEACHERS,
    DPOPair,
    TeacherCallResult,
    TeacherSpec,
    TraceState,
    extract_dpo_pairs,
    replay_trace,
)
# Trainer (Spike 005)
from composer_replication.trainer import ComposerReplicationTrainer

# DiLoCo (Spike 008): optional, requires torchft
try:
    from composer_replication.diloco import make_diloco_outer_loop

    _DILOCO_AVAILABLE = True
except ImportError:
    _DILOCO_AVAILABLE = False
    make_diloco_outer_loop = None  # type: ignore[assignment]

__version__ = "0.1.0"

__all__ = [
    # Core loss
    "compose_loss",
    "LossComponents",
    "build_batch",
    "generalized_jsd_loss",
    # Trace ingestion
    "ClaudeCodeIngester",
    "IngestionStats",
    "SYSTEM_PROMPT",
    # Teacher replay
    "DEFAULT_TEACHERS",
    "DPOPair",
    "TeacherCallResult",
    "TeacherSpec",
    "TraceState",
    "extract_dpo_pairs",
    "replay_trace",
    # Trainer
    "ComposerReplicationTrainer",
    # DiLoCo (optional; None when the torchft-backed import fails)
    "make_diloco_outer_loop",
    # Meta
    "_DILOCO_AVAILABLE",
    "__version__",
]