File size: 8,820 Bytes
ac05fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d88715c
 
ac05fbf
 
 
 
 
 
 
 
 
 
 
 
 
d88715c
 
 
 
 
 
 
 
 
 
 
 
 
 
ac05fbf
 
 
 
d88715c
ac05fbf
d88715c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac05fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d88715c
 
 
 
 
 
 
 
 
 
 
 
 
 
ac05fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""real_batch.py — build a real, tokenized 3-channel batch from a HF tokenizer.

Used by Spike 006's smoke to generate inputs for `compose_loss` from a real
chat-template-formatted conversation, NOT random ints.
"""
from __future__ import annotations

from typing import Any

import torch


def _canned_conversations(variant: str) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
    """Return the ``(student_msgs, teacher_msgs)`` chat messages for *variant*.

    The teacher conversation repeats the student's system/user turns, adds a
    ``[HINT]`` user turn, and ends with a different assistant answer.

    Raises:
        ValueError: if *variant* is not "factorial" or "binary_search".
    """
    if variant == "factorial":
        student_msgs = [
            {"role": "system", "content": "You are a careful coding assistant."},
            {"role": "user", "content": "Write a Python function to compute the factorial of n."},
            {"role": "assistant", "content": "def factorial(n):\n    if n <= 1: return 1\n    return n * factorial(n - 1)"},
        ]
        teacher_msgs = [
            {"role": "system", "content": "You are a careful coding assistant."},
            {"role": "user", "content": "Write a Python function to compute the factorial of n."},
            {"role": "user", "content": "[HINT] Recursion overflows for n>1000. Use an iterative loop."},
            {"role": "assistant", "content": "def factorial(n):\n    result = 1\n    for i in range(2, n + 1):\n        result *= i\n    return result"},
        ]
        return student_msgs, teacher_msgs
    if variant == "binary_search":
        student_msgs = [
            {"role": "system", "content": "You are a careful coding assistant."},
            {"role": "user", "content": "Implement binary search in Python."},
            {"role": "assistant", "content": "def bsearch(a, t):\n    l, r = 0, len(a)\n    while l < r:\n        m = (l + r) // 2\n        if a[m] < t: l = m + 1\n        else: r = m\n    return l"},
        ]
        teacher_msgs = [
            {"role": "system", "content": "You are a careful coding assistant."},
            {"role": "user", "content": "Implement binary search in Python."},
            {"role": "user", "content": "[HINT] Use right = len(a) - 1 with inclusive upper bound is more standard."},
            {"role": "assistant", "content": "def bsearch(a, t):\n    l, r = 0, len(a) - 1\n    while l <= r:\n        m = (l + r) // 2\n        if a[m] == t: return m\n        if a[m] < t: l = m + 1\n        else: r = m - 1\n    return -1"},
        ]
        return student_msgs, teacher_msgs
    raise ValueError(f"unknown variant: {variant!r}")


def _encode_chat(tokenizer: Any, messages: list[dict[str, str]]) -> torch.Tensor:
    """Render *messages* with the tokenizer's chat template and return the token ids.

    The rendered text is tokenized as-is (``add_special_tokens=False``); the
    result is whatever the tokenizer returns under ``"input_ids"`` for a single
    text with ``return_tensors="pt"``.
    """
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    return tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"]


def _fit_length(ids: torch.Tensor, length: int, pad_id: int | None) -> torch.Tensor:
    """Truncate or right-pad *ids* along dim 1 so it has exactly *length* columns.

    Raises:
        ValueError: if padding is required but *pad_id* is None.
    """
    cur = ids.shape[1]
    if cur >= length:
        return ids[:, :length]
    if pad_id is None:
        # Without this check torch.full would fail with an opaque TypeError.
        raise ValueError("tokenizer defines neither pad_token_id nor eos_token_id; cannot pad")
    filler = torch.full((ids.shape[0], length - cur), pad_id, dtype=ids.dtype, device=ids.device)
    return torch.cat([ids, filler], dim=1)


def _span_mask(ids: torch.Tensor, start: int, stop: int) -> torch.Tensor:
    """Return a tensor shaped like *ids* that is 1 on columns ``[start, stop)`` and 0 elsewhere."""
    mask = torch.zeros_like(ids)
    mask[:, start:stop] = 1
    return mask


def build_batch(
    tokenizer: Any,
    *,
    device: torch.device | str = "cpu",
    seed: int = 42,
    variant: str = "factorial",
    align_sdpo_shapes: bool = False,
) -> dict[str, torch.Tensor]:
    """Construct a full 3-channel input batch from a real tokenizer.

    Returns a dict with all keys `compose_loss` may consume:
        input_ids, response_mask
        ctx_teacher_input_ids, sdpo_loss_mask
        dpo_chosen_input_ids, dpo_chosen_response_mask
        dpo_rejected_input_ids, dpo_rejected_response_mask
        dpo_chosen_ref_logprobs, dpo_rejected_ref_logprobs

    All masks are positional heuristics, not real role boundaries: the
    student and teacher masks mark the last ~30% of tokens, the DPO masks
    start at 60% of the padded pair length and stop at each sequence's own
    unpadded length.

    The DPO ref logprobs are dummy constants (not from a real reference policy
    forward); the smoke is verifying the loss composition wires together,
    not the reference-policy precompute pipeline.

    Args:
        tokenizer: real HF tokenizer (needs `apply_chat_template`, `__call__`,
            `pad_token_id` and `eos_token_id`)
        device: torch device for the returned tensors
        seed: passed to torch.manual_seed. Nothing built here currently
            draws random numbers (the chat-template text is deterministic
            and the ref logprobs are constants), so this only seeds torch's
            global RNG as a side effect.
        variant: "factorial" or "binary_search" — pick which canned
            conversation. Used by Spike 006-strict to alternate batches
            so the loss-decrease isn't memorization of a single sample.
        align_sdpo_shapes: if True, truncate (or right-pad) ctx_teacher_input_ids
            to match input_ids length so the SDPO channel actually fires
            (no shape-mismatch fallback). Used by Spike 006-strict to
            exercise the SDPO loss on a real model.

    Raises:
        ValueError: if `variant` is unknown, or if padding is needed and the
            tokenizer has neither a pad nor an EOS token id.
    """
    torch.manual_seed(seed)

    student_msgs, teacher_msgs = _canned_conversations(variant)

    # Pad id shared by the SDPO alignment and the DPO pair padding; fall back
    # to EOS for tokenizers that define no dedicated pad token.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    # ------------------------------------------------------------------
    # Conversation 1: student rollout
    # ------------------------------------------------------------------
    input_ids = _encode_chat(tokenizer, student_msgs).to(device)
    T = input_ids.shape[1]
    response_mask = _span_mask(input_ids, int(T * 0.7), T)

    # ------------------------------------------------------------------
    # Conversation 2: hint-conditioned teacher context (SDPO)
    # ------------------------------------------------------------------
    ctx_teacher_input_ids = _encode_chat(tokenizer, teacher_msgs)
    if align_sdpo_shapes:
        # Force the teacher context to the student length so SDPO actually fires
        # (compose_loss falls back to zero when shapes mismatch). This is a
        # correctness-relaxing test mode — production will pad/align via the
        # real data collator, but for the smoke we just need the SDPO loss
        # to exercise the generalized_jsd_loss code path on a real HF model.
        ctx_teacher_input_ids = _fit_length(ctx_teacher_input_ids, T, pad_id)
    ctx_teacher_input_ids = ctx_teacher_input_ids.to(device)

    T_t = ctx_teacher_input_ids.shape[1]
    # NOTE(review): when the teacher was right-padded above, this mask also
    # covers the pad positions — confirm that is acceptable for compose_loss.
    sdpo_loss_mask = _span_mask(ctx_teacher_input_ids, int(T_t * 0.7), T_t)

    # ------------------------------------------------------------------
    # Conversation 3 + 4: DPO chosen / rejected pairs
    # ------------------------------------------------------------------
    dpo_chosen_msgs = [
        {"role": "system", "content": "You are a careful coding assistant."},
        {"role": "user", "content": "What's the time complexity of binary search?"},
        {"role": "assistant", "content": "Binary search is O(log n) because each comparison halves the search space."},
    ]
    dpo_rejected_msgs = [
        {"role": "system", "content": "You are a careful coding assistant."},
        {"role": "user", "content": "What's the time complexity of binary search?"},
        {"role": "assistant", "content": "It's O(n) I think, you have to look at every element."},
    ]
    chosen_ids = _encode_chat(tokenizer, dpo_chosen_msgs)
    rejected_ids = _encode_chat(tokenizer, dpo_rejected_msgs)
    chosen_len = chosen_ids.shape[1]
    rejected_len = rejected_ids.shape[1]

    # Right-pad the shorter sequence so both members of the pair share a length.
    L = max(chosen_len, rejected_len)
    dpo_chosen_input_ids = _fit_length(chosen_ids, L, pad_id).to(device)
    dpo_rejected_input_ids = _fit_length(rejected_ids, L, pad_id).to(device)

    # Both masks start at the same heuristic offset and stop before any padding.
    # The start is clamped to the last real token: otherwise a sequence shorter
    # than 60% of the padded length would get an all-zero response mask.
    resp_start = int(L * 0.6)
    chosen_resp_mask = _span_mask(
        dpo_chosen_input_ids, min(resp_start, max(chosen_len - 1, 0)), chosen_len
    )
    rejected_resp_mask = _span_mask(
        dpo_rejected_input_ids, min(resp_start, max(rejected_len - 1, 0)), rejected_len
    )

    # Dummy reference-policy logprobs (in production: precomputed by data collator)
    dpo_chosen_ref_logprobs = torch.tensor([-30.0], device=device)
    dpo_rejected_ref_logprobs = torch.tensor([-35.0], device=device)

    return {
        "input_ids": input_ids,
        "response_mask": response_mask,
        "ctx_teacher_input_ids": ctx_teacher_input_ids,
        "sdpo_loss_mask": sdpo_loss_mask,
        "dpo_chosen_input_ids": dpo_chosen_input_ids,
        "dpo_chosen_response_mask": chosen_resp_mask,
        "dpo_rejected_input_ids": dpo_rejected_input_ids,
        "dpo_rejected_response_mask": rejected_resp_mask,
        "dpo_chosen_ref_logprobs": dpo_chosen_ref_logprobs,
        "dpo_rejected_ref_logprobs": dpo_rejected_ref_logprobs,
    }


# Public API of this module: only `build_batch` is exported via `import *`.
__all__ = ["build_batch"]