Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
File size: 8,820 Bytes
"""real_batch.py — build a real, tokenized 3-channel batch from a HF tokenizer.
Used by Spike 006's smoke to generate inputs for `compose_loss` from a real
chat-template-formatted conversation, NOT random ints.
"""
from __future__ import annotations
from typing import Any
import torch
_SYSTEM_PROMPT = "You are a careful coding assistant."


def _sdpo_conversations(variant: str) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
    """Return fresh ``(student_msgs, teacher_msgs)`` chat lists for a canned variant.

    The teacher conversation repeats the student's system + user turns, then adds
    an extra ``[HINT]`` user turn and a hint-informed assistant answer.

    Raises:
        ValueError: if ``variant`` is not ``"factorial"`` or ``"binary_search"``.
    """
    if variant == "factorial":
        question = "Write a Python function to compute the factorial of n."
        hint = "[HINT] Recursion overflows for n>1000. Use an iterative loop."
        student_answer = "def factorial(n):\n if n <= 1: return 1\n return n * factorial(n - 1)"
        teacher_answer = "def factorial(n):\n result = 1\n for i in range(2, n + 1):\n result *= i\n return result"
    elif variant == "binary_search":
        question = "Implement binary search in Python."
        hint = "[HINT] Use right = len(a) - 1 with inclusive upper bound is more standard."
        student_answer = "def bsearch(a, t):\n l, r = 0, len(a)\n while l < r:\n m = (l + r) // 2\n if a[m] < t: l = m + 1\n else: r = m\n return l"
        teacher_answer = "def bsearch(a, t):\n l, r = 0, len(a) - 1\n while l <= r:\n m = (l + r) // 2\n if a[m] == t: return m\n if a[m] < t: l = m + 1\n else: r = m - 1\n return -1"
    else:
        raise ValueError(f"unknown variant: {variant!r}")
    student_msgs = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": question},
        {"role": "assistant", "content": student_answer},
    ]
    teacher_msgs = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": question},
        {"role": "user", "content": hint},
        {"role": "assistant", "content": teacher_answer},
    ]
    return student_msgs, teacher_msgs


def _dpo_conversations() -> tuple[list[dict[str, str]], list[dict[str, str]]]:
    """Return fresh ``(chosen_msgs, rejected_msgs)`` lists sharing one prompt."""
    prompt = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": "What's the time complexity of binary search?"},
    ]
    chosen = prompt + [
        {"role": "assistant", "content": "Binary search is O(log n) because each comparison halves the search space."},
    ]
    rejected = [dict(m) for m in prompt] + [
        {"role": "assistant", "content": "It's O(n) I think, you have to look at every element."},
    ]
    return chosen, rejected


def _encode_chat(tokenizer: Any, messages: list[dict[str, str]]) -> torch.Tensor:
    """Render ``messages`` with the tokenizer's chat template and return ``input_ids``.

    ``add_special_tokens=False`` because the rendered template text is passed
    straight back to the tokenizer; the returned tensor stays on whatever device
    the tokenizer produced it on.
    """
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    return tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"]


def _pad_token_id(tokenizer: Any) -> int | None:
    """Return the tokenizer's pad id, falling back to its eos id (may be ``None``)."""
    return tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id


def _right_pad(ids: torch.Tensor, length: int, pad_id: int | None) -> torch.Tensor:
    """Right-pad with ``pad_id`` (or truncate) ``ids`` to exactly ``length`` columns.

    The pad block is created on ``ids``' own device and dtype.

    Raises:
        ValueError: if padding is required but ``pad_id`` is ``None``.
    """
    cur = ids.shape[1]
    if cur >= length:
        return ids[:, :length]
    if pad_id is None:
        raise ValueError("tokenizer has neither pad_token_id nor eos_token_id; cannot pad")
    pad = torch.full((ids.shape[0], length - cur), pad_id, dtype=ids.dtype, device=ids.device)
    return torch.cat([ids, pad], dim=1)


def _tail_mask(ids: torch.Tensor, frac: float, real_len: int | None = None) -> torch.Tensor:
    """Return a mask (same shape/dtype as ``ids``) set to 1 on the tail of the real tokens.

    Positions ``int(real_len * frac)`` up to ``real_len`` are 1; everything else,
    including any right-padding beyond ``real_len``, is 0. ``real_len`` defaults to
    the full width of ``ids``. This is a positional heuristic, not an exact
    assistant-turn boundary.
    """
    if real_len is None:
        real_len = ids.shape[1]
    mask = torch.zeros_like(ids)
    mask[:, int(real_len * frac):real_len] = 1
    return mask


def build_batch(
    tokenizer: Any,
    *,
    device: torch.device | str = "cpu",
    seed: int = 42,
    variant: str = "factorial",
    align_sdpo_shapes: bool = False,
) -> dict[str, torch.Tensor]:
    """Construct a full 3-channel input batch from a real tokenizer.

    Returns a dict with all keys `compose_loss` may consume:
        input_ids, response_mask
        ctx_teacher_input_ids, sdpo_loss_mask
        dpo_chosen_input_ids, dpo_chosen_response_mask
        dpo_rejected_input_ids, dpo_rejected_response_mask
        dpo_chosen_ref_logprobs, dpo_rejected_ref_logprobs

    All masks are positional heuristics over each sequence's real (unpadded)
    tokens: the last 30% for ``response_mask`` / ``sdpo_loss_mask``, and the
    last 40% for the DPO response masks. Padding is never marked.

    The DPO ref logprobs are dummy constant tensors (-30.0 / -35.0), not from a
    real reference-policy forward. The smoke test verifies that the loss
    composition wires together, not the reference-policy precompute pipeline.

    Args:
        tokenizer: real HF tokenizer. It needs ``apply_chat_template``,
            ``__call__``, ``pad_token_id`` and ``eos_token_id``.
        device: torch device for the returned tensors.
        seed: passed to ``torch.manual_seed``, which resets torch's *global*
            RNG. Nothing built here is random, so this only matters to callers
            that rely on the reset.
        variant: "factorial" or "binary_search". Picks which canned
            conversation to use. Spike 006-strict alternates batches so the
            loss decrease isn't memorization of a single sample.
        align_sdpo_shapes: if True, truncate or right-pad ctx_teacher_input_ids
            to the input_ids length so the SDPO channel actually fires (no
            shape-mismatch fallback). Used by Spike 006-strict to exercise the
            SDPO loss on a real model.

    Raises:
        ValueError: on an unknown ``variant``, or when padding is needed but the
            tokenizer has neither a pad nor an eos token id.
    """
    # Kept for backward compatibility: callers may rely on the global seed reset.
    torch.manual_seed(seed)
    student_msgs, teacher_msgs = _sdpo_conversations(variant)
    pad_id = _pad_token_id(tokenizer)

    # Conversation 1: student rollout.
    input_ids = _encode_chat(tokenizer, student_msgs).to(device)
    student_len = input_ids.shape[1]
    response_mask = _tail_mask(input_ids, 0.7)

    # Conversation 2: hint-conditioned teacher context (SDPO).
    ctx_teacher_input_ids = _encode_chat(tokenizer, teacher_msgs).to(device)
    teacher_len = ctx_teacher_input_ids.shape[1]
    if align_sdpo_shapes:
        # Correctness-relaxing test mode: force the teacher to the student's
        # length so compose_loss does not fall back to zero on a shape mismatch.
        # Production aligns via the real data collator.
        ctx_teacher_input_ids = _right_pad(ctx_teacher_input_ids, student_len, pad_id)
        # Only the surviving real teacher tokens may be masked, never the padding.
        teacher_len = min(teacher_len, student_len)
    sdpo_loss_mask = _tail_mask(ctx_teacher_input_ids, 0.7, teacher_len)

    # Conversations 3 + 4: DPO chosen / rejected pair, padded to a common length.
    chosen_msgs, rejected_msgs = _dpo_conversations()
    chosen_ids = _encode_chat(tokenizer, chosen_msgs)
    rejected_ids = _encode_chat(tokenizer, rejected_msgs)
    chosen_len, rejected_len = chosen_ids.shape[1], rejected_ids.shape[1]
    pair_len = max(chosen_len, rejected_len)
    dpo_chosen_input_ids = _right_pad(chosen_ids, pair_len, pad_id).to(device)
    dpo_rejected_input_ids = _right_pad(rejected_ids, pair_len, pad_id).to(device)
    # Each mask is relative to its own real length, so one completion's length
    # never shifts (or empties) the other's mask.
    chosen_resp_mask = _tail_mask(dpo_chosen_input_ids, 0.6, chosen_len)
    rejected_resp_mask = _tail_mask(dpo_rejected_input_ids, 0.6, rejected_len)

    # Dummy reference-policy logprobs (in production: precomputed by data collator).
    dpo_chosen_ref_logprobs = torch.tensor([-30.0], device=device)
    dpo_rejected_ref_logprobs = torch.tensor([-35.0], device=device)

    return {
        "input_ids": input_ids,
        "response_mask": response_mask,
        "ctx_teacher_input_ids": ctx_teacher_input_ids,
        "sdpo_loss_mask": sdpo_loss_mask,
        "dpo_chosen_input_ids": dpo_chosen_input_ids,
        "dpo_chosen_response_mask": chosen_resp_mask,
        "dpo_rejected_input_ids": dpo_rejected_input_ids,
        "dpo_rejected_response_mask": rejected_resp_mask,
        "dpo_chosen_ref_logprobs": dpo_chosen_ref_logprobs,
        "dpo_rejected_ref_logprobs": dpo_rejected_ref_logprobs,
    }


__all__ = ["build_batch"]
|