File size: 7,659 Bytes
ac05fbf
 
 
 
e5add15
 
 
ac05fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5add15
 
 
 
ac05fbf
 
 
 
 
 
e5add15
 
 
 
 
 
 
 
 
ac05fbf
 
e5add15
 
 
 
 
ac05fbf
 
 
 
 
 
 
e5add15
ac05fbf
e5add15
 
 
 
 
 
ac05fbf
 
 
e5add15
 
 
 
 
 
ac05fbf
 
 
e5add15
 
 
 
 
 
 
ac05fbf
e5add15
 
 
 
 
 
 
ac05fbf
e5add15
 
 
 
 
 
 
 
 
 
 
ac05fbf
e5add15
ac05fbf
e5add15
 
 
ac05fbf
e5add15
 
ac05fbf
e5add15
ac05fbf
e5add15
 
 
 
ac05fbf
e5add15
ac05fbf
e5add15
ac05fbf
e5add15
ac05fbf
e5add15
 
ac05fbf
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""opsd_loss.py — Self-distillation loss, lifted from siyan-zhao/OPSD.

Original source: github.com/siyan-zhao/OPSD::OPSDTrainer.generalized_jsd_loss (MIT).
Verified self-contained via DeepWiki audit on 2026-05-25.
Re-aligned byte-for-byte against upstream `opsd_trainer.py` lines 381-479 on
2026-05-26 after Wave 15 math review found three numerical divergences (mixture
weighting, β coefficient placement, reduction divisor) and one docstring mislabel.

Mathematical reference:
- OPSD paper: Zhao et al., "Self-Distilled Reasoner: On-Policy Self-Distillation
  for LLMs", arXiv:2601.18734.
- SDPO paper: Hübotter et al., "Reinforcement Learning via Self-Distillation",
  arXiv:2601.20802 (formalizes the same loss as Composer 2.5's "Targeted RL with
  Textual Feedback").

The loss computes JSD/KL divergence between a teacher distribution (model
conditioned on privileged information / a hint) and a student distribution
(model on the original context). Both come from the SAME model — the teacher
is just "the model with hint inserted into context."

Composer 2.5 uses this with the privileged information being a "hint" inserted
at the error-turn site. We use the same loss; the data collator constructs
ctx_teacher = ctx_student + hint_at_error_turn for us.
"""

from __future__ import annotations

import torch
import torch.nn.functional as F


def generalized_jsd_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor | None = None,
    beta: float = 0.5,
    temperature: float = 1.0,
    reduction: str = "batchmean",
    logits_are_probs: bool = False,
    top_k: int | None = None,
    token_clip: float | None = None,
) -> torch.Tensor:
    """Generalized Jensen-Shannon Divergence loss between student and teacher.

    Follows `OPSDTrainer.generalized_jsd_loss` (siyan-zhao/OPSD,
    opsd_trainer.py lines 381-479). See https://huggingface.co/papers/2306.13649
    Eq. (1) for the definition. Deliberate deviations, all on inputs for which
    the upstream code returns garbage, NaN or crashes:
      * unknown ``reduction``, ``beta`` outside [0, 1], or ``temperature <= 0``
        (on the logits path) raise ``ValueError``;
      * a ``torch.bool`` ``labels`` tensor is used directly as a keep-mask;
      * ``reduction="batchmean"`` with a mask that keeps no positions returns 0
        instead of 0/0 = NaN;
      * ``top_k`` larger than the vocabulary is clamped to the vocabulary size.

    Args:
        student_logits: (B, T, V) student logits (or probabilities, see
            ``logits_are_probs``).
        teacher_logits: (B, T, V) teacher (= same model with hint context) logits.
        labels: optional (B, T) tensor selecting which positions contribute.
            Either
              * an integer label tensor: positions equal to -100 are ignored
                (HF convention). Every other value, INCLUDING 0, is kept, so a
                0/1 integer mask does NOT work; or
              * a ``torch.bool`` keep-mask (True = include the position).
            For Composer-style hint-distill, keep only the error-turn tokens
            AFTER the hint. Either set all other positions to -100 or pass a
            bool mask.
        beta: in [0, 1]. Direction follows `F.kl_div` semantics, where
            `F.kl_div(log_q, log_p, log_target=True)` computes KL(p || q):
              β = 0  → KL(teacher || student): forward KL, mode-covering for
                       the student.
              β = 1  → KL(student || teacher): reverse KL, mode-seeking for
                       the student.
              β ∈ (0,1) → mixture M = (1-β)·P_student + β·P_teacher and
                       jsd = β·KL(teacher||M) + (1-β)·KL(student||M).
                       β = 0.5 is the symmetric JSD.
            β = 0 and β = 1 are special-cased. They are not the limits of the
            mixture formula, which tends to 0 at both ends.
        temperature: logits are divided by this before the softmax. Must be > 0.
            T > 1 puts more weight on tail probabilities. SDPO paper uses 1.0.
            Ignored when ``logits_are_probs=True``.
        reduction: "batchmean" | "sum" | "mean" | "none".
            "batchmean": with labels, sum / number of kept positions. This is a
                per-token mean of the vocab-summed divergence, and 0 if nothing
                is kept. Without labels, sum / size(0) (= B). That is PyTorch's
                `KLDivLoss(reduction="batchmean")` convention, so it grows with T.
            "sum": sum of all remaining elements.
            "mean": mean over all remaining elements, including the vocab
                dimension.
            "none": the element-wise divergence tensor.
        logits_are_probs: if True, inputs are already probabilities. Their log
            is taken after clamping at 1e-8. ``temperature`` and ``top_k`` are
            ignored on this path.
        top_k: restrict the divergence to the teacher's top-k token indices and
            renormalize both distributions over that subset. Values >= V are
            clamped to V. Ignored when ``logits_are_probs=True``.
        token_clip: upper clamp applied ELEMENT-WISE to the (..., V) divergence
            terms before any sum over the vocabulary. It is not a cap on a
            token's total divergence, because individual terms can be negative.
            None disables it. SDPO paper does NOT clip; OPSD defaults to None.

    Returns:
        A scalar for "batchmean" / "sum" / "mean". For "none": shape (B, T, V')
        without labels, or (n_kept, V') with labels, where V' is V or
        min(top_k, V).

    Raises:
        ValueError: on an unknown ``reduction``, ``beta`` outside [0, 1], or
            ``temperature <= 0`` when logits are given.
    """
    # Validate before doing any tensor work.
    if reduction not in ("batchmean", "sum", "mean", "none"):
        # Upstream falls through to `return jsd` for unknown reductions; we raise
        # to surface caller bugs instead of silently returning an unreduced tensor.
        raise ValueError(f"Unknown reduction: {reduction}")
    if not 0.0 <= beta <= 1.0:
        # Outside [0, 1] the mixture weights log(β) / log(1-β) are NaN.
        raise ValueError(f"beta must be in [0, 1], got {beta}")

    if logits_are_probs:
        # Path A: probabilities in. The clamp avoids log(0) = -inf.
        student_log_probs = torch.log(student_logits.clamp_min(1e-8))
        teacher_log_probs = torch.log(teacher_logits.clamp_min(1e-8))
    else:
        if temperature <= 0:
            # Dividing by 0 (or flipping the sign) yields inf/NaN or an inverted
            # distribution.
            raise ValueError(f"temperature must be > 0, got {temperature}")
        # Apply temperature scaling to logits before computing probabilities.
        student_logits = student_logits / temperature
        teacher_logits = teacher_logits / temperature

        if top_k is not None and top_k > 0:
            # Restrict to the teacher's top-k tokens; log_softmax below renormalizes.
            # k >= V would make torch.topk raise, and it is the full vocabulary anyway.
            k = min(top_k, teacher_logits.size(-1))
            _, top_k_indices = torch.topk(teacher_logits, k=k, dim=-1)
            student_logits = torch.gather(student_logits, dim=-1, index=top_k_indices)
            teacher_logits = torch.gather(teacher_logits, dim=-1, index=top_k_indices)

        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

    if beta == 0:
        # F.kl_div(input=log_q, target=log_p, log_target=True) computes KL(p || q):
        #   sum_x p(x) * (log p(x) - log q(x))
        # With input=student_log_probs, target=teacher_log_probs → KL(teacher || student).
        jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
    elif beta == 1:
        # Arguments swapped → KL(student || teacher).
        jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
    else:
        # Compute the log of the β-weighted mixture distribution:
        #   M = (1-β)·P_student + β·P_teacher
        #   log M = logsumexp([log P_student + log(1-β), log P_teacher + log(β)])
        beta = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device)
        mixture_log_probs = torch.logsumexp(
            torch.stack([student_log_probs + torch.log1p(-beta), teacher_log_probs + torch.log(beta)]),
            dim=0,
        )

        # F.kl_div takes (input=log q, target=log p) and returns KL(p || q), so the
        # argument order is swapped relative to the paper's notation.
        kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
        kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)

        # Generalized JSD: β weights the teacher-leg KL (matches upstream).
        jsd = beta * kl_teacher + (1 - beta) * kl_student

    # Element-wise clamp on the (..., V) terms (see `token_clip` in the docstring).
    if token_clip is not None:
        jsd = jsd.clamp(max=token_clip)

    # Masking. labels is (B, T) and jsd is (B, T, V'). `jsd[mask]` indexes the
    # first two dims, giving shape (n_kept, V').
    mask = None
    if labels is not None:
        # A bool tensor is already a keep-mask. Comparing it with -100 would keep
        # every position, since both True and False != -100.
        mask = labels if labels.dtype == torch.bool else labels != -100
        jsd = jsd[mask]

    if reduction == "batchmean":
        if mask is not None:
            # clamp_min(1): a batch with no kept positions gives 0 / 1 = 0 rather
            # than 0 / 0 = NaN. Otherwise identical to dividing by mask.sum().
            return jsd.sum() / mask.sum().clamp_min(1)
        return jsd.sum() / jsd.size(0)
    if reduction == "sum":
        return jsd.sum()
    if reduction == "mean":
        return jsd.mean()
    # reduction == "none" (validated at the top).
    return jsd


# Public API of this module: only the loss function is exported.
__all__ = [
    "generalized_jsd_loss",
]