Codeseys's picture
Wave 15: 4-angle multi-model self-critique caught 2 math BLOCKERs in primary loss kernels; fixed against upstream byte-for-byte + GSM8K example + ergonomics
e5add15
"""opsd_loss.py — Self-distillation loss, lifted from siyan-zhao/OPSD.
Original source: github.com/siyan-zhao/OPSD::OPSDTrainer.generalized_jsd_loss (MIT).
Verified self-contained via DeepWiki audit on 2026-05-25.
Re-aligned byte-for-byte against upstream `opsd_trainer.py` lines 381-479 on
2026-05-26 after Wave 15 math review found three numerical divergences (mixture
weighting, β coefficient placement, reduction divisor) and one docstring mislabel.
Mathematical reference:
- OPSD paper: Zhao et al., "Self-Distilled Reasoner: On-Policy Self-Distillation
for LLMs", arXiv:2601.18734.
- SDPO paper: Hübotter et al., "Reinforcement Learning via Self-Distillation",
arXiv:2601.20802 (formalizes the same loss as Composer 2.5's "Targeted RL with
Textual Feedback").
The loss computes JSD/KL divergence between a teacher distribution (model
conditioned on privileged information / a hint) and a student distribution
(model on the original context). Both come from the SAME model — the teacher
is just "the model with hint inserted into context."
Composer 2.5 uses this with the privileged information being a "hint" inserted
at the error-turn site. We use the same loss; the data collator constructs
ctx_teacher = ctx_student + hint_at_error_turn for us.
"""
from __future__ import annotations
import torch
import torch.nn.functional as F
def generalized_jsd_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor | None = None,
    beta: float = 0.5,
    temperature: float = 1.0,
    reduction: str = "batchmean",
    logits_are_probs: bool = False,
    top_k: int | None = None,
    token_clip: float | None = None,
) -> torch.Tensor:
    """Generalized Jensen-Shannon Divergence loss between student and teacher.

    Follows `OPSDTrainer.generalized_jsd_loss` (siyan-zhao/OPSD) and the
    definition in https://huggingface.co/papers/2306.13649 Eq. (1). The
    divergence math is unchanged for valid inputs; the deliberate deviations
    are all edge-case handling: invalid `beta` / `temperature` / `reduction`
    raise `ValueError`, `top_k` is capped at the vocabulary size, boolean
    `labels` are accepted as a mask, and a fully-masked batch yields 0
    instead of NaN under "batchmean".

    Args:
        student_logits: student model logits, expected shape (B, T, V).
        teacher_logits: teacher (= same model with hint context) logits,
            same shape as `student_logits`.
        labels: optional (B, T) tensor selecting which positions contribute.
            Integer labels follow the HF convention: positions equal to -100
            are ignored, every other value is kept. NOTE: a 0/1 integer mask
            is NOT interpreted as a mask (0 != -100, so nothing would be
            dropped) -- convert it to -100 labels or pass a boolean tensor.
            A boolean tensor is used directly (True = keep).
        beta: interpolation coefficient in [0, 1]. With
            `F.kl_div(log_q, log_p, log_target=True)` computing KL(p || q):
              beta == 0 -> KL(teacher || student): forward KL, mode-covering
                           for the student.
              beta == 1 -> KL(student || teacher): reverse KL, mode-seeking
                           for the student.
              otherwise -> M = (1-beta)*P_student + beta*P_teacher and
                           jsd = beta*KL(teacher || M) + (1-beta)*KL(student || M)
                           (beta == 0.5 is the symmetric JSD).
        temperature: divides both logit tensors before the softmax. Must be
            > 0. Ignored when `logits_are_probs` is True.
        reduction: "batchmean" | "sum" | "mean" | "none".
            "batchmean" divides the summed divergence by the number of kept
            positions when `labels` is given (floored at 1 so an all-ignored
            batch returns 0 rather than NaN), otherwise by the size of the
            leading dimension. "mean" averages over every remaining element,
            including the vocabulary axis.
        logits_are_probs: if True, the inputs are treated as probabilities;
            they are clamped to >= 1e-8 and logged directly. `temperature`
            and `top_k` are not applied on this path.
        top_k: if a positive int, keep only the teacher's top-k vocabulary
            entries per position (the same indices are gathered from the
            student) and renormalize both via `log_softmax` over those
            entries. Values larger than the vocabulary size are capped.
        token_clip: if given, upper bound applied element-wise to the
            unreduced divergence tensor, i.e. to each per-vocabulary-entry
            contribution BEFORE any summation over the vocabulary axis.

    Returns:
        A scalar tensor for "batchmean" / "sum" / "mean". For
        reduction="none" the unreduced element-wise tensor: shape
        (B, T, V') without labels or (n_kept, V') with labels, where V' is
        the (possibly top-k restricted) vocabulary size.

    Raises:
        ValueError: if `reduction` is unknown, `beta` is outside [0, 1], or
            `temperature` is not positive on the logits path.
    """
    # Validate cheap scalar arguments first so caller bugs surface before any
    # large (B, T, V) tensor work is done.
    if reduction not in ("batchmean", "sum", "mean", "none"):
        raise ValueError(f"Unknown reduction: {reduction}")
    if not 0.0 <= beta <= 1.0:
        # Outside [0, 1] the log-weights below are log of a negative number,
        # which would silently turn the whole loss into NaN.
        raise ValueError(f"beta must be in [0, 1], got {beta}")

    if logits_are_probs:
        # Probabilities in: take the log directly; the clamp avoids log(0).
        student_log_probs = torch.log(student_logits.clamp_min(1e-8))
        teacher_log_probs = torch.log(teacher_logits.clamp_min(1e-8))
    else:
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        # Temperature scaling happens on logits, before normalization.
        student_logits = student_logits / temperature
        teacher_logits = teacher_logits / temperature
        if top_k is not None and top_k > 0:
            # The teacher picks the support; the student is evaluated on the
            # same indices. Capping k keeps torch.topk from raising when the
            # caller asks for more entries than the vocabulary holds.
            k = min(top_k, teacher_logits.size(-1))
            _, top_k_indices = torch.topk(teacher_logits, k=k, dim=-1)
            student_logits = torch.gather(student_logits, dim=-1, index=top_k_indices)
            teacher_logits = torch.gather(teacher_logits, dim=-1, index=top_k_indices)
        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

    if beta == 0:
        # F.kl_div(input=log_q, target=log_p, log_target=True) is, per element,
        # p * (log p - log q). input=student, target=teacher
        # -> KL(teacher || student).
        jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
    elif beta == 1:
        # Arguments swapped -> KL(student || teacher).
        jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
    else:
        # log M for M = (1-beta)*P_student + beta*P_teacher, computed in log
        # space: logsumexp([log P_s + log(1-beta), log P_t + log(beta)]).
        beta_t = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device)
        mixture_log_probs = torch.logsumexp(
            torch.stack(
                [
                    student_log_probs + torch.log1p(-beta_t),
                    teacher_log_probs + torch.log(beta_t),
                ]
            ),
            dim=0,
        )
        # F.kl_div takes (input, target) in the opposite order to the usual
        # KL(target || input) notation, hence the mixture goes first.
        kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
        kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
        # beta weights the teacher leg, (1 - beta) the student leg.
        jsd = beta_t * kl_teacher + (1 - beta_t) * kl_student

    # Element-wise cap on the unreduced divergence contributions.
    if token_clip is not None:
        jsd = jsd.clamp(max=token_clip)

    # Boolean indexing with a (B, T) mask collapses the first two dims, so the
    # result has shape (n_kept, V').
    mask = None
    if labels is not None:
        mask = labels if labels.dtype == torch.bool else labels != -100
        jsd = jsd[mask]

    if reduction == "batchmean":
        if mask is not None:
            # Floor the divisor at 1: with zero kept positions the numerator
            # is an empty sum (0), and 0 / 0 would otherwise produce NaN.
            return jsd.sum() / mask.sum().clamp_min(1)
        return jsd.sum() / jsd.size(0)
    if reduction == "sum":
        return jsd.sum()
    if reduction == "mean":
        return jsd.mean()
    # reduction == "none" (validated above).
    return jsd
# Public API: `from <module> import *` exports only the loss function.
__all__ = ["generalized_jsd_loss"]