"""opsd_loss.py — Self-distillation loss, lifted from siyan-zhao/OPSD. Original source: github.com/siyan-zhao/OPSD::OPSDTrainer.generalized_jsd_loss (MIT). Verified self-contained via DeepWiki audit on 2026-05-25. Re-aligned byte-for-byte against upstream `opsd_trainer.py` lines 381-479 on 2026-05-26 after Wave 15 math review found three numerical divergences (mixture weighting, β coefficient placement, reduction divisor) and one docstring mislabel. Mathematical reference: - OPSD paper: Zhao et al., "Self-Distilled Reasoner: On-Policy Self-Distillation for LLMs", arXiv:2601.18734. - SDPO paper: Hübotter et al., "Reinforcement Learning via Self-Distillation", arXiv:2601.20802 (formalizes the same loss as Composer 2.5's "Targeted RL with Textual Feedback"). The loss computes JSD/KL divergence between a teacher distribution (model conditioned on privileged information / a hint) and a student distribution (model on the original context). Both come from the SAME model — the teacher is just "the model with hint inserted into context." Composer 2.5 uses this with the privileged information being a "hint" inserted at the error-turn site. We use the same loss; the data collator constructs ctx_teacher = ctx_student + hint_at_error_turn for us. """ from __future__ import annotations import torch import torch.nn.functional as F def generalized_jsd_loss( student_logits: torch.Tensor, teacher_logits: torch.Tensor, labels: torch.Tensor | None = None, beta: float = 0.5, temperature: float = 1.0, reduction: str = "batchmean", logits_are_probs: bool = False, top_k: int | None = None, token_clip: float | None = None, ) -> torch.Tensor: """Generalized Jensen-Shannon Divergence loss between student and teacher. Byte-for-byte replication of `OPSDTrainer.generalized_jsd_loss` (siyan-zhao/OPSD, opsd_trainer.py lines 381-479). See https://huggingface.co/papers/2306.13649 Eq. (1) for the definition. Args: student_logits: (B, T, V) — student model logits at each token position. teacher_logits: (B, T, V) — teacher (= same model with hint context) logits. labels: (B, T) — token-level mask. Positions with label == -100 are ignored (standard HF padding/ignored convention). For Composer-style hint-distill, mask should be 1 at error-turn tokens AFTER the hint, 0 elsewhere. beta: in [0, 1]. NOTE on direction (per `F.kl_div` semantics, where `F.kl_div(log_q, log_p, log_target=True)` computes KL(p || q)): β = 0 → kl_div(student_log_probs, teacher_log_probs) = KL(teacher || student) (reverse KL — mode-covering for student) β = 1 → kl_div(teacher_log_probs, student_log_probs) = KL(student || teacher) (forward KL — mode-seeking for student) β = 0.5 → symmetric JSD with M = 0.5*(P+Q) General β ∈ (0,1): mixture M = (1-β)·P_student + β·P_teacher and jsd = β·KL(teacher||M) + (1-β)·KL(student||M). temperature: softens distributions; T > 1 encourages distribution-matching on broader tail probabilities. SDPO paper uses 1.0. reduction: "batchmean" | "sum" | "mean" | "none". "batchmean" matches upstream OPSD: divides by `mask.sum()` when labels are given, else by the leading dim of jsd (= batch size). This differs from PyTorch's `KLDivLoss(reduction='batchmean')` (which divides by batch). We match upstream because gradient scale stability matters more than the name. logits_are_probs: if True, inputs are already probabilities (skip softmax). top_k: restrict KL to top-k tokens of the teacher distribution. Saves compute on large vocabularies (Qwen3 vocab = 152K). token_clip: clip per-token JSD to this max. Stabilizes training. SDPO paper does NOT clip; OPSD code defaults to None (no clip). Returns: Scalar loss tensor (or unreduced (B, T, V) tensor for reduction="none"). """ # Path A: probabilities-in. Take log directly with a clamp for stability. if logits_are_probs: student_log_probs = torch.log(student_logits.clamp_min(1e-8)) teacher_log_probs = torch.log(teacher_logits.clamp_min(1e-8)) else: # Apply temperature scaling to logits before computing probabilities. student_logits = student_logits / temperature teacher_logits = teacher_logits / temperature if top_k is not None and top_k > 0: # Restrict to top-k tokens of the teacher distribution and renormalize. _, top_k_indices = torch.topk(teacher_logits, k=top_k, dim=-1) student_logits = torch.gather(student_logits, dim=-1, index=top_k_indices) teacher_logits = torch.gather(teacher_logits, dim=-1, index=top_k_indices) student_log_probs = F.log_softmax(student_logits, dim=-1) teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) if beta == 0: # F.kl_div(input=log_q, target=log_p, log_target=True) computes KL(p || q): # sum_x p(x) * (log p(x) - log q(x)) # With input=student_log_probs, target=teacher_log_probs → KL(teacher || student). jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) elif beta == 1: jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) else: # Compute the log of the β-weighted mixture distribution: # M = (1-β)·P_student + β·P_teacher # log M = logsumexp([log P_student + log(1-β), log P_teacher + log(β)]) beta = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device) mixture_log_probs = torch.logsumexp( torch.stack([student_log_probs + torch.log1p(-beta), teacher_log_probs + torch.log(beta)]), dim=0, ) # Compute KL divergences using F.kl_div. # PyTorch differs from the standard mathematical definition, so the order of # the probability distributions is swapped compared to that defined in the paper. kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True) kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True) # Generalized JSD: β weights the teacher-leg KL (matches upstream). jsd = beta * kl_teacher + (1 - beta) * kl_student # Per-token clipping: cap each token's divergence value. if token_clip is not None: jsd = jsd.clamp(max=token_clip) # Masking. labels has shape (B, T); jsd has shape (B, T, V) (or top_k for V). # `jsd[mask]` indexes the first two dims, yielding shape (n_valid, V). mask = None if labels is not None: mask = labels != -100 jsd = jsd[mask] # Apply reduction (matches upstream byte-for-byte for batchmean/sum/mean). if reduction == "batchmean": if labels is not None: assert mask is not None return jsd.sum() / mask.sum() return jsd.sum() / jsd.size(0) elif reduction == "sum": return jsd.sum() elif reduction == "mean": return jsd.mean() elif reduction == "none": return jsd else: # Upstream falls through to `return jsd` for unknown reductions; we raise # to surface caller bugs instead of silently returning an unreduced tensor. raise ValueError(f"Unknown reduction: {reduction}") __all__ = ["generalized_jsd_loss"]