"""opsd_loss.py: Self-distillation loss, lifted from siyan-zhao/OPSD.

Original source: github.com/siyan-zhao/OPSD::OPSDTrainer.generalized_jsd_loss (MIT).
Verified self-contained via DeepWiki audit on 2026-05-25.
Re-aligned byte-for-byte against upstream `opsd_trainer.py` lines 381-479 on
2026-05-26 after Wave 15 math review found three numerical divergences (mixture
weighting, β coefficient placement, reduction divisor) and one docstring mislabel.

Mathematical reference:
- OPSD paper: Zhao et al., "Self-Distilled Reasoner: On-Policy Self-Distillation
  for LLMs", arXiv:2601.18734.
- SDPO paper: Hübotter et al., "Reinforcement Learning via Self-Distillation",
  arXiv:2601.20802 (formalizes the same loss as Composer 2.5's "Targeted RL with
  Textual Feedback").

The loss computes JSD/KL divergence between a teacher distribution (model
conditioned on privileged information / a hint) and a student distribution
(model on the original context). Both come from the SAME model: the teacher
is just "the model with hint inserted into context."
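
As a hedged sanity check of the direction conventions used by the function
below, this sketch confirms that `F.kl_div(log_q, log_p, log_target=True)`
equals the textbook KL(p || q) = sum_x p(x) (log p(x) - log q(x)), and evaluates
the β-mixture JSD on explicit probability vectors. `kl` and `generalized_jsd`
are illustrative helpers, not part of this module or of OPSD:

```python
import torch
import torch.nn.functional as F


def kl(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # Textbook KL(p || q) = sum_x p(x) * (log p(x) - log q(x)) over the last dim.
    return (p * (p.log() - q.log())).sum(-1)


def generalized_jsd(p_student: torch.Tensor, p_teacher: torch.Tensor, beta: float) -> torch.Tensor:
    # Same mixture as this module: M = (1 - beta) * P_student + beta * P_teacher.
    m = (1 - beta) * p_student + beta * p_teacher
    return beta * kl(p_teacher, m) + (1 - beta) * kl(p_student, m)


torch.manual_seed(0)
p_student = torch.softmax(torch.randn(5), dim=-1)
p_teacher = torch.softmax(torch.randn(5), dim=-1)

# `input` holds the log of the approximating distribution (student), `target`
# the log of the reference distribution (teacher): the result is KL(teacher || student).
torch_kl = F.kl_div(p_student.log(), p_teacher.log(), reduction="sum", log_target=True)
assert torch.allclose(torch_kl, kl(p_teacher, p_student))

# At beta = 0.5 the generalized JSD reduces to the ordinary symmetric JSD.
jsd_half = generalized_jsd(p_student, p_teacher, 0.5)
print(f"KL(teacher || student) = {torch_kl.item():.4f}, JSD = {jsd_half.item():.4f}")
```

Swapping the student and teacher arguments at β = 0.5 leaves the value
unchanged, and it stays within [0, ln 2] for natural-log JSD.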
Composer 2.5 uses this with the privileged information being a "hint" inserted
at the error-turn site. We use the same loss; the data collator constructs
ctx_teacher = ctx_student + hint_at_error_turn for us.
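
A minimal sketch of that collator step, assuming token ids are plain lists, the
hint is appended at the end of the student context, and the model is a causal
LM whose logits at position t score token t + 1. `build_pair`,
`response_slice` and `next_token_logits` are hypothetical helpers, not OPSD or
Composer code. Because the hint shifts every response token right by
len(hint), the teacher and student rows for the same response token sit at
different positions and must be sliced into alignment before calling
`generalized_jsd_loss`:

```python
import torch


def build_pair(ctx_student: list[int], hint: list[int], response: list[int]):
    # Hypothetical collator: the teacher context is the student context plus the hint.
    student_ids = ctx_student + response
    teacher_ids = ctx_student + hint + response
    # Causal LM assumption: logits at position t score token t + 1, so the row
    # that predicts the first response token sits one position before it.
    student_start = len(ctx_student) - 1
    teacher_start = len(ctx_student) + len(hint) - 1
    return student_ids, teacher_ids, student_start, teacher_start


def response_slice(logits: torch.Tensor, start: int, length: int) -> torch.Tensor:
    # (B, T, V) -> (B, length, V): keep only the rows that score response tokens.
    return logits[:, start : start + length, :]


def next_token_logits(ids: list[int], vocab: int = 100) -> torch.Tensor:
    # Stand-in for a real model: row t is one-hot on the true next token ids[t + 1],
    # which makes any misalignment easy to detect.
    logits = torch.zeros(1, len(ids), vocab)
    for t in range(len(ids) - 1):
        logits[0, t, ids[t + 1]] = 1.0
    return logits


ctx, hint, resp = [11, 12, 13], [90, 91], [21, 22, 23, 24]
s_ids, t_ids, s_start, t_start = build_pair(ctx, hint, resp)

student_resp = response_slice(next_token_logits(s_ids), s_start, len(resp))
teacher_resp = response_slice(next_token_logits(t_ids), t_start, len(resp))

# Both slices line up token-for-token, so they can serve as
# (student_logits, teacher_logits) for generalized_jsd_loss.
assert torch.equal(student_resp, teacher_resp)
```

In a real pipeline the one-hot stand-in would be replaced by the model's forward
pass on each sequence, and `labels` over the sliced rows would mark positions
that should not be distilled with -100.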
"""
from __future__ import annotations

import torch
import torch.nn.functional as F


def generalized_jsd_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor | None = None,
    beta: float = 0.5,
    temperature: float = 1.0,
    reduction: str = "batchmean",
    logits_are_probs: bool = False,
    top_k: int | None = None,
    token_clip: float | None = None,
) -> torch.Tensor:
    """Generalized Jensen-Shannon Divergence loss between student and teacher.

    Byte-for-byte replication of `OPSDTrainer.generalized_jsd_loss`
    (siyan-zhao/OPSD, opsd_trainer.py lines 381-479). See
    https://huggingface.co/papers/2306.13649 Eq. (1) for the definition.

    Args:
        student_logits: (B, T, V) student model logits at each token position.
        teacher_logits: (B, T, V) teacher logits (the same model with the hint
            in its context).
        labels: (B, T) token-level mask. Positions with label == -100 are
            ignored (standard HF padding/ignore convention); every other value,
            including 0, counts as valid. For Composer-style hint distillation,
            labels should be -100 everywhere except at error-turn tokens AFTER
            the hint.
        beta: in [0, 1]. NOTE on direction (per `F.kl_div` semantics, where
            `F.kl_div(log_q, log_p, log_target=True)` computes KL(p || q)):
              β = 0 → kl_div(student_log_probs, teacher_log_probs)
                    = KL(teacher || student)  (forward KL; mode-covering for student)
              β = 1 → kl_div(teacher_log_probs, student_log_probs)
                    = KL(student || teacher)  (reverse KL; mode-seeking for student)
              β = 0.5 → symmetric JSD with M = 0.5*(P+Q)
            General β ∈ (0,1): mixture M = (1-β)·P_student + β·P_teacher and
            jsd = β·KL(teacher||M) + (1-β)·KL(student||M).
        temperature: softens distributions; T > 1 encourages distribution-matching
            on broader tail probabilities. SDPO paper uses 1.0. Ignored when
            logits_are_probs=True.
        reduction: "batchmean" | "sum" | "mean" | "none". "batchmean" matches
            upstream OPSD: it divides the summed divergence by `mask.sum()` (the
            number of valid tokens) when labels are given, else by the leading
            dim of jsd (= batch size). With labels, this differs from PyTorch's
            `KLDivLoss(reduction='batchmean')`, which always divides by the batch
            size. We match upstream because gradient scale stability matters
            more than the name. If every position is masked, "batchmean"
            computes 0 / 0 and returns NaN.
        logits_are_probs: if True, inputs are already probabilities (skip softmax).
        top_k: restrict KL to the top-k tokens of the teacher distribution; both
            distributions are renormalized over those k entries. Saves compute on
            large vocabularies (Qwen3 vocab = 152K). Ignored when
            logits_are_probs=True.
        token_clip: upper bound applied elementwise to the (B, T, V) divergence
            tensor, i.e. to each vocabulary entry's contribution rather than to
            the per-token sum over V. Stabilizes training. SDPO paper does NOT
            clip; OPSD code defaults to None (no clip).

    Returns:
        Scalar loss tensor, or the unreduced divergence for reduction="none":
        shape (B, T, V) without labels, or (n_valid, V) with labels, where
        n_valid counts positions with label != -100. The last dim is k when
        top_k is set.
    """
    # Path A: probabilities-in. Take log directly with a clamp for stability.
    if logits_are_probs:
        student_log_probs = torch.log(student_logits.clamp_min(1e-8))
        teacher_log_probs = torch.log(teacher_logits.clamp_min(1e-8))
    else:
        # Path B: logits-in. Apply temperature scaling before computing probabilities.
        student_logits = student_logits / temperature
        teacher_logits = teacher_logits / temperature

        if top_k is not None and top_k > 0:
            # Restrict to the teacher's top-k tokens; log_softmax below
            # renormalizes both distributions over those k entries.
            _, top_k_indices = torch.topk(teacher_logits, k=top_k, dim=-1)
            student_logits = torch.gather(student_logits, dim=-1, index=top_k_indices)
            teacher_logits = torch.gather(teacher_logits, dim=-1, index=top_k_indices)

        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

    if beta == 0:
        # F.kl_div(input=log_q, target=log_p, log_target=True) computes KL(p || q):
        #   sum_x p(x) * (log p(x) - log q(x))
        # With input=student_log_probs, target=teacher_log_probs → KL(teacher || student).
        jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
    elif beta == 1:
        # Arguments swapped → KL(student || teacher).
        jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
    else:
        # Compute the log of the β-weighted mixture distribution:
        #   M = (1-β)·P_student + β·P_teacher
        #   log M = logsumexp([log P_student + log(1-β), log P_teacher + log(β)])
        beta = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device)
        mixture_log_probs = torch.logsumexp(
            torch.stack([student_log_probs + torch.log1p(-beta), teacher_log_probs + torch.log(beta)]),
            dim=0,
        )

        # F.kl_div(input, target) computes KL(target || input), the reverse of the
        # paper's argument order, so the mixture goes in `input` for both legs.
        kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)  # KL(teacher || M)
        kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)  # KL(student || M)

        # Generalized JSD: β weights the teacher-leg KL (matches upstream).
        jsd = beta * kl_teacher + (1 - beta) * kl_student

    # Optional clipping: caps each element of the (B, T, V) divergence tensor,
    # i.e. each vocabulary entry's contribution, not the per-token sum over V.
    if token_clip is not None:
        jsd = jsd.clamp(max=token_clip)

    # Masking. labels has shape (B, T); jsd has shape (B, T, V), or (B, T, k) with top_k.
    # Boolean indexing `jsd[mask]` selects along the first two dims → (n_valid, V).
    mask = None
    if labels is not None:
        mask = labels != -100
        jsd = jsd[mask]

    # Apply reduction (matches upstream byte-for-byte for batchmean/sum/mean).
    if reduction == "batchmean":
        if labels is not None:
            assert mask is not None
            return jsd.sum() / mask.sum()
        return jsd.sum() / jsd.size(0)
    elif reduction == "sum":
        return jsd.sum()
    elif reduction == "mean":
        return jsd.mean()
    elif reduction == "none":
        return jsd
    else:
        # Upstream falls through to `return jsd` for unknown reductions; we raise
        # to surface caller bugs instead of silently returning an unreduced tensor.
        raise ValueError(f"Unknown reduction: {reduction}")


__all__ = ["generalized_jsd_loss"]