```python
"""opsd_loss.py — Self-distillation loss, lifted from siyan-zhao/OPSD.

Original source: github.com/siyan-zhao/OPSD::OPSDTrainer.generalized_jsd_loss (MIT).
Verified self-contained via DeepWiki audit on 2026-05-25.
Re-aligned byte-for-byte against upstream `opsd_trainer.py` lines 381-479 on
2026-05-26 after the Wave 15 math review found three numerical divergences (mixture
weighting, β coefficient placement, reduction divisor) and one docstring mislabel.

Mathematical reference:
- OPSD paper: Zhao et al., "Self-Distilled Reasoner: On-Policy Self-Distillation
  for LLMs", arXiv:2601.18734.
- SDPO paper: Hübotter et al., "Reinforcement Learning via Self-Distillation",
  arXiv:2601.20802 (formalizes the same loss as Composer 2.5's "Targeted RL with
  Textual Feedback").

The loss computes JSD/KL divergence between a teacher distribution (model
conditioned on privileged information / a hint) and a student distribution
(model on the original context). Both come from the SAME model — the teacher
is just "the model with hint inserted into context."

Composer 2.5 uses this with the privileged information being a "hint" inserted
at the error-turn site. We use the same loss; the data collator constructs
ctx_teacher = ctx_student + hint_at_error_turn for us.
"""
from __future__ import annotations

import torch
import torch.nn.functional as F


def generalized_jsd_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor | None = None,
    beta: float = 0.5,
    temperature: float = 1.0,
    reduction: str = "batchmean",
    logits_are_probs: bool = False,
    top_k: int | None = None,
    token_clip: float | None = None,
) -> torch.Tensor:
    """Generalized Jensen-Shannon Divergence loss between student and teacher.

    Byte-for-byte replication of `OPSDTrainer.generalized_jsd_loss`
    (siyan-zhao/OPSD, opsd_trainer.py lines 381-479). See
    https://huggingface.co/papers/2306.13649 Eq. (1) for the definition.

    Args:
        student_logits: (B, T, V) — student model logits at each token position.
        teacher_logits: (B, T, V) — teacher (= same model with hint context) logits.
        labels: (B, T) — per-token labels, used only as a mask. Positions with
            label == -100 are ignored (standard HF padding/ignore convention); any
            other value marks a valid position. For Composer-style hint-distill,
            set labels to -100 everywhere except the error-turn tokens AFTER the hint.
        beta: in [0, 1]. NOTE on direction (per `F.kl_div` semantics, where
            `F.kl_div(log_q, log_p, log_target=True)` computes KL(p || q)):
                β = 0   → kl_div(student_log_probs, teacher_log_probs)
                        = KL(teacher || student)  (forward KL, mode-covering for student)
                β = 1   → kl_div(teacher_log_probs, student_log_probs)
                        = KL(student || teacher)  (reverse KL, mode-seeking for student)
                β = 0.5 → symmetric JSD with M = 0.5*(P+Q)
            General β ∈ (0,1): mixture M = (1-β)·P_student + β·P_teacher and
            jsd = β·KL(teacher||M) + (1-β)·KL(student||M).
        temperature: softens distributions; T > 1 flattens both distributions, so the
            loss puts more weight on matching tail probabilities. SDPO paper uses 1.0.
            Ignored when logits_are_probs=True.
        reduction: "batchmean" | "sum" | "mean" | "none". "batchmean" matches
            upstream OPSD: divides by `mask.sum()` (the number of valid tokens) when
            labels are given, else by the leading dim of jsd (= batch size). With
            labels this differs from PyTorch's `KLDivLoss(reduction='batchmean')`,
            which always divides by batch size. We match upstream because gradient
            scale stability matters more than the name.
        logits_are_probs: if True, inputs are already probabilities (skip softmax);
            temperature and top_k are then ignored.
        top_k: restrict KL to the top-k tokens of the teacher distribution and
            renormalize both distributions over those k tokens. Saves compute on
            large vocabularies (Qwen3 vocab ≈ 152K).
        token_clip: upper bound applied elementwise to the unreduced (B, T, V)
            divergence, i.e. per token AND per vocabulary entry, before masking.
            Intended to stabilize training. SDPO paper does NOT clip; OPSD code
            defaults to None (no clip).

    Returns:
        Scalar loss tensor, or the unreduced tensor for reduction="none": shape
        (B, T, V) without labels, (n_valid, V) with labels (V becomes top_k if set).
    """
    # Path A: probabilities-in. Take log directly with a clamp for stability.
    if logits_are_probs:
        student_log_probs = torch.log(student_logits.clamp_min(1e-8))
        teacher_log_probs = torch.log(teacher_logits.clamp_min(1e-8))
    else:
        # Apply temperature scaling to logits before computing probabilities.
        student_logits = student_logits / temperature
        teacher_logits = teacher_logits / temperature
        if top_k is not None and top_k > 0:
            # Restrict to top-k tokens of the teacher distribution; the log_softmax
            # below renormalizes both distributions over those k tokens.
            _, top_k_indices = torch.topk(teacher_logits, k=top_k, dim=-1)
            student_logits = torch.gather(student_logits, dim=-1, index=top_k_indices)
            teacher_logits = torch.gather(teacher_logits, dim=-1, index=top_k_indices)
        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

    if beta == 0:
        # F.kl_div(input=log_q, target=log_p, log_target=True) computes KL(p || q):
        #   sum_x p(x) * (log p(x) - log q(x))
        # With input=student_log_probs, target=teacher_log_probs → KL(teacher || student).
        jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
    elif beta == 1:
        # input=teacher_log_probs, target=student_log_probs → KL(student || teacher).
        jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
    else:
        # Compute the log of the β-weighted mixture distribution:
        #   M = (1-β)·P_student + β·P_teacher
        #   log M = logsumexp([log P_student + log(1-β), log P_teacher + log(β)])
        beta = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device)
        mixture_log_probs = torch.logsumexp(
            torch.stack([student_log_probs + torch.log1p(-beta), teacher_log_probs + torch.log(beta)]),
            dim=0,
        )
        # Compute KL divergences using F.kl_div. Because F.kl_div(input, target)
        # computes KL(target || input), the arguments appear in the reverse order of
        # the paper's KL(P || M) notation: kl_teacher = KL(teacher || M).
        kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
        kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
        # Generalized JSD: β weights the teacher-leg KL (matches upstream).
        jsd = beta * kl_teacher + (1 - beta) * kl_student

    # Elementwise clipping: cap each (token, vocabulary entry) divergence term.
    if token_clip is not None:
        jsd = jsd.clamp(max=token_clip)

    # Masking. labels has shape (B, T); jsd has shape (B, T, V), or (B, T, top_k)
    # when top_k is set. `jsd[mask]` indexes the first two dims, yielding (n_valid, V).
    mask = None
    if labels is not None:
        mask = labels != -100
        jsd = jsd[mask]

    # Apply reduction (matches upstream byte-for-byte for batchmean/sum/mean).
    if reduction == "batchmean":
        if labels is not None:
            assert mask is not None
            return jsd.sum() / mask.sum()
        return jsd.sum() / jsd.size(0)
    elif reduction == "sum":
        return jsd.sum()
    elif reduction == "mean":
        return jsd.mean()
    elif reduction == "none":
        return jsd
    else:
        # Upstream falls through to `return jsd` for unknown reductions; we raise
        # to surface caller bugs instead of silently returning an unreduced tensor.
        raise ValueError(f"Unknown reduction: {reduction}")


__all__ = ["generalized_jsd_loss"]
```
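The β cases in the `generalized_jsd_loss` docstring can be cross-checked without PyTorch. The sketch below is a pure-Python reference for a single token position (one distribution over a toy 3-token vocabulary). It was written for this note rather than taken from OPSD or TRL; `softmax`, `kl`, and `reference_generalized_jsd` are illustrative names, and it assumes strictly positive probabilities, which a softmax guarantees.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits (max-subtracted for stability)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    z = sum(exps)
    return [e / z for e in exps]


def kl(p, q):
    """KL(p || q) = sum_x p(x) * (log p(x) - log q(x)); terms where p(x) == 0 contribute 0."""
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)


def reference_generalized_jsd(p_student, p_teacher, beta):
    """Generalized JSD at one token position, summed over the vocabulary.

    beta == 0 and beta == 1 are special-cased to plain KL, as in the torch code;
    otherwise M = (1 - beta) * P_student + beta * P_teacher.
    """
    if beta == 0:
        return kl(p_teacher, p_student)  # forward KL(teacher || student)
    if beta == 1:
        return kl(p_student, p_teacher)  # reverse KL(student || teacher)
    mixture = [(1 - beta) * s + beta * t for s, t in zip(p_student, p_teacher)]
    return beta * kl(p_teacher, mixture) + (1 - beta) * kl(p_student, mixture)


student = softmax([2.0, 0.5, -1.0])
teacher = softmax([0.0, 1.5, 0.3])
for beta in (0.0, 0.5, 1.0):
    print(beta, reference_generalized_jsd(student, teacher, beta))
```

At β = 0.5 swapping student and teacher leaves the value unchanged and it stays at or below log 2; the β = 0 and β = 1 endpoints are the two asymmetric KLs. The general mixture formula itself tends to 0 as β approaches 0 or 1 (for example, M approaches P_student as β approaches 0), so the explicit branches are what produce the plain KL objectives at the endpoints.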
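The masking and `"batchmean"` branches of `generalized_jsd_loss` can be mirrored on plain nested lists to make the divisor explicit. In this sketch `build_hint_labels` is a hypothetical stand-in for the data collator mentioned in the module docstring (it is not part of OPSD): it keeps labels on a half-open span of post-hint error-turn positions and writes -100 everywhere else, since the loss only ignores positions whose label is exactly -100. The per-token divergences are made-up numbers assumed to be already summed over the vocabulary.

```python
IGNORE_INDEX = -100  # HF convention: positions carrying this label are excluded from the loss


def build_hint_labels(input_ids, start, end):
    """Hypothetical collator helper: keep labels on positions [start, end) only
    (the error-turn tokens after the hint); every other position gets IGNORE_INDEX."""
    return [tok if start <= i < end else IGNORE_INDEX for i, tok in enumerate(input_ids)]


def batchmean(per_token_div, labels=None):
    """Pure-Python mirror of the reduction="batchmean" branch.

    per_token_div[b][t] is the divergence at position t, already summed over the
    vocabulary. With labels: sum over valid positions / number of valid positions.
    Without labels: sum over every position / batch size.
    """
    if labels is None:
        return sum(sum(row) for row in per_token_div) / len(per_token_div)
    valid = [
        div
        for div_row, label_row in zip(per_token_div, labels)
        for div, label in zip(div_row, label_row)
        if label != IGNORE_INDEX
    ]
    if not valid:
        # The tensor version would compute 0 / 0 here (NaN); fail loudly instead.
        raise ValueError("no valid (non -100) positions in labels")
    return sum(valid) / len(valid)


input_ids = [[11, 12, 13, 14, 15], [21, 22, 23, 24, 25]]
labels = [build_hint_labels(row, start=2, end=4) for row in input_ids]
per_token_div = [[0.1, 0.2, 0.3, 0.4, 0.5], [0.5, 0.4, 0.3, 0.2, 0.1]]
print(labels[0])  # [-100, -100, 13, 14, -100]
print(batchmean(per_token_div, labels))  # mean over the 4 valid tokens
print(batchmean(per_token_div))  # sum over all 10 tokens / batch size 2
```

With labels the result is a per-valid-token average; without labels it is a per-sequence sum, so its value grows with sequence length. A label of 0 would count as a valid position, which is why the span outside the hint must be filled with -100 rather than 0.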
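The `top_k` branch of `generalized_jsd_loss` gathers both logit tensors at the teacher's top-k indices and then applies `log_softmax` over only those k entries, so both distributions are renormalized on the teacher's support. A minimal pure-Python sketch of that step on a toy 5-token vocabulary follows; `log_softmax` and `restrict_to_teacher_top_k` are illustrative names that mimic `torch.topk` plus `torch.gather` on the last dimension, not functions from OPSD.

```python
import math


def log_softmax(xs):
    """Numerically stable log-softmax over a list of logits."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


def restrict_to_teacher_top_k(student_logits, teacher_logits, k):
    """Gather both logit lists at the indices of the k largest teacher logits
    (sorted by descending teacher logit, like torch.topk's default)."""
    top = sorted(range(len(teacher_logits)), key=lambda i: teacher_logits[i], reverse=True)[:k]
    return [student_logits[i] for i in top], [teacher_logits[i] for i in top], top


student_logits = [1.0, 3.0, 0.5, -2.0, 2.0]
teacher_logits = [2.5, 0.1, 1.8, -1.0, 3.0]
s_k, t_k, top = restrict_to_teacher_top_k(student_logits, teacher_logits, k=2)

# Renormalize over the k kept entries only, as the torch code does after gather.
student_log_probs_k = log_softmax(s_k)
teacher_log_probs_k = log_softmax(t_k)

# Forward KL(teacher || student) restricted to the teacher's top-k support.
kl_forward_k = sum(math.exp(t) * (t - s) for s, t in zip(student_log_probs_k, teacher_log_probs_k))
print(top)  # [4, 0]
print(kl_forward_k)
```

Here the student's largest logit (index 1) lies outside the teacher's top-2, so after renormalization no term in the loss pushes that probability down. That blind spot is the trade-off that comes with the compute savings on large vocabularies.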