Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """real_batch.py — build a real, tokenized 3-channel batch from a HF tokenizer. | |
| Used by Spike 006's smoke to generate inputs for `compose_loss` from a real | |
| chat-template-formatted conversation, NOT random ints. | |
| """ | |
| from __future__ import annotations | |
| from typing import Any | |
| import torch | |
# Heuristic fractions of a tokenized conversation treated as "prompt"; the
# remaining tail is marked as the response. These are positional guesses,
# not real assistant-turn boundaries.
_ROLLOUT_PROMPT_FRACTION = 0.7
_DPO_PROMPT_FRACTION = 0.6


def _canned_conversations(variant: str) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
    """Return the (student, teacher) message lists for the requested canned variant."""
    if variant == "factorial":
        student_msgs = [
            {"role": "system", "content": "You are a careful coding assistant."},
            {"role": "user", "content": "Write a Python function to compute the factorial of n."},
            {"role": "assistant", "content": "def factorial(n):\n if n <= 1: return 1\n return n * factorial(n - 1)"},
        ]
        teacher_msgs = [
            {"role": "system", "content": "You are a careful coding assistant."},
            {"role": "user", "content": "Write a Python function to compute the factorial of n."},
            {"role": "user", "content": "[HINT] Recursion overflows for n>1000. Use an iterative loop."},
            {"role": "assistant", "content": "def factorial(n):\n result = 1\n for i in range(2, n + 1):\n result *= i\n return result"},
        ]
        return student_msgs, teacher_msgs
    if variant == "binary_search":
        student_msgs = [
            {"role": "system", "content": "You are a careful coding assistant."},
            {"role": "user", "content": "Implement binary search in Python."},
            {"role": "assistant", "content": "def bsearch(a, t):\n l, r = 0, len(a)\n while l < r:\n m = (l + r) // 2\n if a[m] < t: l = m + 1\n else: r = m\n return l"},
        ]
        teacher_msgs = [
            {"role": "system", "content": "You are a careful coding assistant."},
            {"role": "user", "content": "Implement binary search in Python."},
            {"role": "user", "content": "[HINT] Use right = len(a) - 1 with inclusive upper bound is more standard."},
            {"role": "assistant", "content": "def bsearch(a, t):\n l, r = 0, len(a) - 1\n while l <= r:\n m = (l + r) // 2\n if a[m] == t: return m\n if a[m] < t: l = m + 1\n else: r = m - 1\n return -1"},
        ]
        return student_msgs, teacher_msgs
    raise ValueError(f"unknown variant: {variant!r}")


def _encode_chat(tokenizer: Any, messages: list[dict[str, str]]) -> torch.Tensor:
    """Render *messages* with the chat template and tokenize them into a (1, T) id tensor."""
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    # The chat template already emits its own special tokens, so none are added here.
    return tokenizer(text, return_tensors="pt", add_special_tokens=False, padding=False)["input_ids"]


def _pad_right(ids: torch.Tensor, length: int, pad_id: int | None) -> torch.Tensor:
    """Truncate or right-pad a (1, T) id tensor to exactly *length* columns."""
    cur = ids.shape[1]
    if cur >= length:
        return ids[:, :length]
    if pad_id is None:
        # Fail with a clear message instead of an opaque torch.full error.
        raise ValueError("tokenizer defines neither pad_token_id nor eos_token_id; cannot pad")
    filler = torch.full((1, length - cur), pad_id, dtype=ids.dtype, device=ids.device)
    return torch.cat([ids, filler], dim=1)


def _dpo_response_mask(padded_ids: torch.Tensor, real_len: int, shared_start: int) -> torch.Tensor:
    """Mark positions [start, real_len) as response; padding past real_len stays 0."""
    # Normally both sequences use the shared start. If that start lies at or
    # beyond this sequence's real length the slice would be empty (an all-zero
    # mask), so fall back to this sequence's own prompt fraction.
    start = shared_start if shared_start < real_len else int(real_len * _DPO_PROMPT_FRACTION)
    mask = torch.zeros_like(padded_ids)
    mask[:, start:real_len] = 1
    return mask


def build_batch(
    tokenizer: Any,
    *,
    device: torch.device | str = "cpu",
    seed: int = 42,
    variant: str = "factorial",
    align_sdpo_shapes: bool = False,
) -> dict[str, torch.Tensor]:
    """Construct a full 3-channel input batch from a real tokenizer.

    Returns a dict with all keys `compose_loss` may consume:
        input_ids, response_mask
        ctx_teacher_input_ids, sdpo_loss_mask
        dpo_chosen_input_ids, dpo_chosen_response_mask
        dpo_rejected_input_ids, dpo_rejected_response_mask
        dpo_chosen_ref_logprobs, dpo_rejected_ref_logprobs

    Every sequence tensor has a batch dimension of 1. The response / loss
    masks are positional heuristics (the last ~30% of the rollout and teacher
    sequences, and roughly the last 40% of each DPO sequence), not exact
    assistant-turn boundaries.

    The DPO ref logprobs are fixed dummy constants (not from a real reference
    policy forward); the smoke is verifying the loss composition wires
    together, not the reference-policy precompute pipeline.

    Args:
        tokenizer: real HF tokenizer; must provide `apply_chat_template`,
            be callable on text, and expose `pad_token_id` / `eos_token_id`.
        device: torch device for the returned tensors.
        seed: passed to `torch.manual_seed` before anything else (a global
            RNG side effect). Nothing in the batch is currently sampled, so
            the returned tensors do not depend on it.
        variant: "factorial" or "binary_search" — picks which canned
            conversation is used. Used by Spike 006-strict to alternate
            batches so the loss-decrease isn't memorization of a single
            sample.
        align_sdpo_shapes: if True, truncate or right-pad
            ctx_teacher_input_ids to the length of input_ids so the SDPO
            channel actually fires (no shape-mismatch fallback). Used by
            Spike 006-strict to exercise the SDPO loss on a real model.

    Raises:
        ValueError: if `variant` is unknown, or if padding is required but
            the tokenizer has neither a pad token id nor an EOS token id.
    """
    torch.manual_seed(seed)
    student_msgs, teacher_msgs = _canned_conversations(variant)

    # Pad with the pad token when available, otherwise reuse EOS.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    # ------------------------------------------------------------------
    # Conversation 1: student rollout
    # ------------------------------------------------------------------
    input_ids = _encode_chat(tokenizer, student_msgs).to(device)
    T = input_ids.shape[1]
    response_mask = torch.zeros_like(input_ids)
    response_mask[:, int(T * _ROLLOUT_PROMPT_FRACTION):] = 1

    # ------------------------------------------------------------------
    # Conversation 2: hint-conditioned teacher context (SDPO)
    # ------------------------------------------------------------------
    ctx_teacher_input_ids = _encode_chat(tokenizer, teacher_msgs).to(device)
    if align_sdpo_shapes:
        # Correctness-relaxing test mode: force the teacher context to the
        # student length (truncating its tail or right-padding it) so the
        # SDPO code path runs instead of the shape-mismatch fallback.
        ctx_teacher_input_ids = _pad_right(ctx_teacher_input_ids, T, pad_id)
    T_t = ctx_teacher_input_ids.shape[1]
    sdpo_loss_mask = torch.zeros_like(ctx_teacher_input_ids)
    # NOTE(review): when the teacher was right-padded above, this tail mask
    # also covers the padding positions — confirm compose_loss tolerates that.
    sdpo_loss_mask[:, int(T_t * _ROLLOUT_PROMPT_FRACTION):] = 1

    # ------------------------------------------------------------------
    # Conversation 3 + 4: DPO chosen / rejected pairs
    # ------------------------------------------------------------------
    dpo_chosen_msgs = [
        {"role": "system", "content": "You are a careful coding assistant."},
        {"role": "user", "content": "What's the time complexity of binary search?"},
        {"role": "assistant", "content": "Binary search is O(log n) because each comparison halves the search space."},
    ]
    dpo_rejected_msgs = [
        {"role": "system", "content": "You are a careful coding assistant."},
        {"role": "user", "content": "What's the time complexity of binary search?"},
        {"role": "assistant", "content": "It's O(n) I think, you have to look at every element."},
    ]
    chosen_ids = _encode_chat(tokenizer, dpo_chosen_msgs)
    rejected_ids = _encode_chat(tokenizer, dpo_rejected_msgs)
    chosen_len = chosen_ids.shape[1]
    rejected_len = rejected_ids.shape[1]

    # Right-pad both sequences to a common length L so they share a shape.
    L = max(chosen_len, rejected_len)
    dpo_chosen_input_ids = _pad_right(chosen_ids, L, pad_id).to(device)
    dpo_rejected_input_ids = _pad_right(rejected_ids, L, pad_id).to(device)

    shared_start = int(L * _DPO_PROMPT_FRACTION)
    chosen_resp_mask = _dpo_response_mask(dpo_chosen_input_ids, chosen_len, shared_start)
    rejected_resp_mask = _dpo_response_mask(dpo_rejected_input_ids, rejected_len, shared_start)

    # Dummy reference-policy logprobs (in production: precomputed by data collator)
    dpo_chosen_ref_logprobs = torch.tensor([-30.0], device=device)
    dpo_rejected_ref_logprobs = torch.tensor([-35.0], device=device)

    return {
        "input_ids": input_ids,
        "response_mask": response_mask,
        "ctx_teacher_input_ids": ctx_teacher_input_ids,
        "sdpo_loss_mask": sdpo_loss_mask,
        "dpo_chosen_input_ids": dpo_chosen_input_ids,
        "dpo_chosen_response_mask": chosen_resp_mask,
        "dpo_rejected_input_ids": dpo_rejected_input_ids,
        "dpo_rejected_response_mask": rejected_resp_mask,
        "dpo_chosen_ref_logprobs": dpo_chosen_ref_logprobs,
        "dpo_rejected_ref_logprobs": dpo_rejected_ref_logprobs,
    }
# Public API: `build_batch` is the only name exported by `from <module> import *`.
__all__ = ["build_batch"]