server-backend / app.py
wop's picture
Update app.py
69e260d verified
Raw
History Blame Contribute Delete
23.6 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Cosmos-T3 — single-file Gradio app for inference/chat.
"""
from __future__ import annotations
import os
import sys
import queue
import threading
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoTokenizer
# ─────────────────────────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────────────────────────
# Tokenizer whose vocabulary and chat template the checkpoints are used with.
TOKENIZER_NAME = "Qwen/Qwen2.5-0.5B"

# Architecture defaults. load_model() prefers values stored in the checkpoint.
BLOCK_SIZE = 1024  # not referenced by the code visible in this file
MAX_LEN = 1024     # longest token window the model attends over
D_MODEL = 768
N_LAYERS = 12
N_HEADS = 12
N_KV_HEADS = 4     # grouped-query attention: 12 query heads share 4 K/V heads
D_FF = 2048
ROPE_BASE = 10000
DROP_OUT = 0.0

# Engram hashed n-gram memory, attached to every ENGRAM_EVERY-th block.
USE_ENGRAM = True
ENGRAM_EVERY = 4
ENGRAM_BUCKETS = 8192
ENGRAM_DIM = 64
ENGRAM_ORDER = 3

DEFAULT_SYSTEM_PROMPT = "Enable thinking features: INTUITION"

# Local checkpoint file names (inside the work dir), per training stage.
STAGE_CKPT = dict(
    pretrain="Cosmos-T3-Pretrain.resume.pt",
    finetune="Cosmos-T3-Instruct.resume.pt",
)
# Location of the same checkpoints inside the HF storage bucket.
STAGE_BUCKET = dict(
    pretrain="pretrain/checkpoints/Cosmos-T3-Pretrain.resume.pt",
    finetune="finetune/checkpoints/Cosmos-T3-Instruct.resume.pt",
)
HF_BUCKET_ID = "wop/Cosmos-SFT"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PAD_ID = 0  # placeholder; replaced with the tokenizer's pad id at startup
STOP_TOKEN_IDS: set[int] = set()  # placeholder; filled from the tokenizer at startup
MODEL_LOCK = threading.Lock()  # held by chat worker threads around model.generate()
def resolve_checkpoint(stage="finetune", work_dir="cosmos_t3_run", no_bucket=False):
    """Return a local path to the checkpoint for *stage*, downloading it if needed.

    Args:
        stage: key into STAGE_CKPT / STAGE_BUCKET ("pretrain" or "finetune").
        work_dir: directory in which the checkpoint is cached.
        no_bucket: if True, never download; fail when the local file is missing.

    Returns:
        pathlib.Path of the existing local checkpoint file.

    Raises:
        KeyError: unknown *stage*.
        FileNotFoundError: the file is missing and *no_bucket* is True.
        RuntimeError: the bucket download did not produce the file.
    """
    local = Path(work_dir) / STAGE_CKPT[stage]
    if local.exists():
        return local
    if no_bucket:
        raise FileNotFoundError(f"Missing checkpoint: {local}")
    # huggingface_hub reads HF_TOKEN from the environment on its own. The
    # previous code wrote a literal "empty" token there when none was set. That
    # is never a valid credential, and it then leaked into every later hub call
    # in the process instead of falling back to anonymous access.
    from huggingface_hub import download_bucket_files

    remote = STAGE_BUCKET[stage]
    local.parent.mkdir(parents=True, exist_ok=True)
    print(f"Downloading from bucket: {HF_BUCKET_ID}/{remote}")
    download_bucket_files(HF_BUCKET_ID, files=[(remote, str(local))])
    if not local.exists():
        raise RuntimeError("Bucket download failed")
    return local
# ─────────────────────────────────────────────────────────────
# MODEL CORE
# ─────────────────────────────────────────────────────────────
class RMSNorm(nn.Module):
    """Root-mean-square norm over the last dim with a learned gain (no bias, no centering)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))  # learned gain, initialised to 1
        self.eps = eps  # added to the mean square before rsqrt

    def forward(self, x):
        # Rescale each vector along the last dim to unit RMS, then apply the gain.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normed = x * inv_rms
        return normed * self.weight
def rotate_half(x):
    """Rotate adjacent channel pairs by 90 degrees: (a, b) -> (-b, a) on the last dim.

    Pairs are channels (0, 1), (2, 3), ...; this matches the interleaved cos/sin
    layout produced by ``repeat_interleave(2, dim=-1)`` in ``build_rope``.
    """
    evens = x[..., 0::2]
    odds = x[..., 1::2]
    rotated_pairs = torch.stack((odds.neg(), evens), dim=-1)
    return rotated_pairs.flatten(start_dim=-2)
def apply_rope(q, k, cos, sin):
    """Apply rotary position embeddings to queries and keys.

    ``cos`` and ``sin`` must broadcast against ``q`` and ``k``. Returns the
    rotated ``(q, k)`` pair; the inputs are not modified in place.
    """
    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
class GQAAttention(nn.Module):
    """Grouped-query self-attention with rotary embeddings and an optional KV cache.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads. Each K/V head
    is repeated ``n_heads // n_kv_heads`` times right before attention, so the
    cache stores only the un-repeated K/V.
    """

    def __init__(self, d_model, n_heads, n_kv_heads, rope_base=10000, dropout=0.0):
        super().__init__()
        assert d_model % n_heads == 0
        assert n_heads % n_kv_heads == 0
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = d_model // n_heads
        # Stored only; the RoPE tables are passed into forward() by the caller.
        self.rope_base = rope_base
        self.dropout = dropout
        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, rope_cos, rope_sin, past_kv=None, use_cache=False):
        """Attend over ``x`` (and any cached keys/values).

        Args:
            x: hidden states, unpacked as (batch, seq_len, d_model).
            rope_cos, rope_sin: RoPE tables broadcastable to
                (batch, heads, seq_len, head_dim), already offset to x's positions.
            past_kv: optional ``(k, v)`` from earlier calls, each
                (batch, n_kv_heads, past_len, head_dim).
            use_cache: if True, also return the updated ``(k, v)``.

        Returns:
            Output of shape (batch, seq_len, d_model), or ``(output, (k, v))``
            when ``use_cache`` is True.
        """
        bsz, seq_len, _ = x.shape
        q = self.q_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        q, k = apply_rope(q, k, rope_cos, rope_sin)
        past_len = 0
        if past_kv is not None:
            past_k, past_v = past_kv
            past_len = past_k.size(2)
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        present_kv = (k, v) if use_cache else None
        if self.n_kv_heads != self.n_heads:
            repeat = self.n_heads // self.n_kv_heads
            k = k.repeat_interleave(repeat, dim=1)
            v = v.repeat_interleave(repeat, dim=1)

        # Causal masking. SDPA's is_causal aligns its mask to the top-left
        # corner. That is only correct when queries and keys cover the same
        # positions, i.e. when there is no (non-empty) cache. With a cache,
        # query i sits at absolute position past_len + i and may see keys
        # 0..past_len + i. When more than one new token is fed, build that
        # bottom-right-aligned mask explicitly. A single new token may attend
        # to every key, so it needs no mask.
        attn_mask = None
        is_causal = past_len == 0
        if past_len > 0 and seq_len > 1:
            attn_mask = torch.ones(
                seq_len, past_len + seq_len, dtype=torch.bool, device=q.device
            ).tril(diagonal=past_len)

        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            is_causal=is_causal,
            dropout_p=self.dropout if self.training else 0.0,
        )
        attn_out = attn_out.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        attn_out = self.o_proj(attn_out)
        return (attn_out, present_kv) if use_cache else attn_out
class SwiGLUMLP(nn.Module):
    """SwiGLU feed-forward network: ``down(dropout(silu(gate(x)) * up(x)))``."""

    def __init__(self, d_model, hidden_dim, dropout=0.0):
        super().__init__()
        # Attribute names form the state_dict keys, and registration order sets
        # the init order, so both are kept stable.
        self.gate = nn.Linear(d_model, hidden_dim, bias=False)
        self.up = nn.Linear(d_model, hidden_dim, bias=False)
        self.down = nn.Linear(hidden_dim, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        activated = F.silu(self.gate(x))
        hidden = activated * self.up(x)
        hidden = self.dropout(hidden)
        return self.down(hidden)
class EngramMemory(nn.Module):
    """Hashed n-gram memory keyed on the last ``order`` token ids at each position.

    Each position's id n-gram (left-padded with ``pad_id``) is hashed into one
    of ``bucket_count`` learned embedding rows. The row is multiplied by a tanh
    query of the hidden state, projected back to ``d_model``, and scaled by a
    sigmoid gate computed from the same hidden state.
    """

    # One multiplier per n-gram offset; ``order`` selects a prefix of this tuple.
    _PRIMES = (1, 1315423911, 2654435761, 97531, 433494437)

    def __init__(self, d_model, bucket_count, memory_dim, order, pad_id=0, dropout=0.0):
        super().__init__()
        if not 1 <= order <= len(self._PRIMES):
            raise ValueError(
                f"order must be between 1 and {len(self._PRIMES)}, got {order}"
            )
        self.bucket_count = bucket_count
        self.memory_dim = memory_dim
        self.order = order
        self.pad_id = pad_id
        self.bucket = nn.Embedding(bucket_count, memory_dim)
        self.query = nn.Linear(d_model, memory_dim, bias=False)
        self.project = nn.Linear(memory_dim, d_model, bias=False)
        self.gate = nn.Linear(d_model, d_model, bias=True)
        self.dropout = nn.Dropout(dropout)
        # Non-persistent: a derived constant that is not saved in checkpoints.
        self.register_buffer(
            "primes", torch.tensor(self._PRIMES[:order], dtype=torch.long), persistent=False
        )

    def hash_tokens(self, idx):
        """Hash each position's trailing ``order``-gram of ids into [0, bucket_count).

        ``idx`` is a 2-D integer tensor (batch, seq_len). Returns a long tensor
        of the same shape. The first positions see ``pad_id`` in place of
        missing history.
        """
        batch, seq_len = idx.shape
        pad = torch.full((batch, self.order - 1), self.pad_id, device=idx.device, dtype=idx.dtype)
        history = torch.cat([pad, idx], dim=1)
        hashed = torch.zeros((batch, seq_len), device=idx.device, dtype=torch.long)
        for offset in range(self.order):
            slice_ = history[:, offset: offset + seq_len].long()
            # Reducing modulo bucket_count every round keeps intermediates bounded.
            hashed = (hashed * 1315423911 + slice_ * self.primes[offset]) % self.bucket_count
        return hashed

    def forward(self, x, idx):
        """Return the gated memory contribution for hidden states ``x``.

        Args:
            x: (batch, t, d_model) hidden states for the last ``t`` positions of ``idx``.
            idx: (batch, n) token ids ending at x's last position (n >= t).
        """
        t = x.size(1)
        # Only the last t hashes are used, and each one depends on at most
        # `order` trailing ids. Hashing just that tail gives the same values as
        # hashing the whole context, but costs O(t + order) instead of O(n).
        # That matters for cached decoding, where t == 1 on every step.
        window = t + self.order - 1
        if idx.size(1) > window:
            idx = idx[:, -window:]
        hashed = self.hash_tokens(idx)
        if hashed.size(1) != t:
            hashed = hashed[:, -t:]
        query = torch.tanh(self.query(x))
        memory = self.bucket(hashed) * query
        memory = self.project(memory)
        gate = torch.sigmoid(self.gate(x))
        return self.dropout(gate * memory)
class Block(nn.Module):
    """Pre-norm transformer block: attention, optional Engram memory, then SwiGLU MLP.

    Every sub-layer reads an RMS-normalised copy of the residual stream and adds
    its output back. ``norm2`` is always created, even when the block has no
    Engram memory.
    """

    def __init__(
        self,
        d_model,
        n_heads,
        n_kv_heads,
        d_ff,
        rope_base,
        dropout=0.0,
        use_engram=False,
        engram_bucket_count=4096,
        engram_dim=96,
        engram_order=3,
        pad_id=0,
    ):
        super().__init__()
        # Registration order and names are the checkpoint keys; keep them as-is.
        self.norm1 = RMSNorm(d_model)
        self.attn = GQAAttention(d_model, n_heads, n_kv_heads, rope_base=rope_base, dropout=dropout)
        self.norm2 = RMSNorm(d_model)
        if use_engram:
            self.engram = EngramMemory(
                d_model, engram_bucket_count, engram_dim, engram_order, pad_id=pad_id, dropout=dropout
            )
        else:
            self.engram = None
        self.norm3 = RMSNorm(d_model)
        self.mlp = SwiGLUMLP(d_model, d_ff, dropout=dropout)

    def _post_attention(self, h, ids):
        """Add the optional Engram term, then the MLP term, to residual stream ``h``."""
        if self.engram is not None:
            h = h + self.engram(self.norm2(h), ids)
        return h + self.mlp(self.norm3(h))

    def forward(self, x, idx, rope_cos, rope_sin):
        """Full-sequence (uncached) pass; ``idx`` feeds the Engram hash."""
        h = x + self.attn(self.norm1(x), rope_cos, rope_sin)
        return self._post_attention(h, idx)

    def forward_cached(self, x, idx_context, rope_cos, rope_sin, past_kv=None):
        """Incremental pass that extends ``past_kv``; returns ``(hidden, present_kv)``."""
        attn_out, present_kv = self.attn(
            self.norm1(x), rope_cos, rope_sin, past_kv=past_kv, use_cache=True
        )
        return self._post_attention(x + attn_out, idx_context), present_kv
class CosmosT2_Accelerate_LLM(nn.Module):
    """Decoder-only language model with tied input/output embeddings.

    Layout: token embedding -> ``n_layers`` ``Block``s -> final RMSNorm ->
    logits via ``tok_emb.weight``. Blocks whose 1-based index is a multiple of
    ``engram_every`` get an Engram memory when ``use_engram`` is True.

    Offers a full-sequence ``forward`` (optionally with cross-entropy loss) and
    KV-cached incremental decoding via ``prefill_cache`` / ``decode_cached`` /
    ``generate``.
    """

    def __init__(
        self,
        vocab_size,
        d_model=D_MODEL,
        n_layers=N_LAYERS,
        n_heads=N_HEADS,
        n_kv_heads=N_KV_HEADS,
        d_ff=D_FF,
        max_len=MAX_LEN,
        rope_base=ROPE_BASE,
        dropout=DROP_OUT,
        use_engram=USE_ENGRAM,
        engram_every=ENGRAM_EVERY,
        engram_bucket_count=ENGRAM_BUCKETS,
        engram_dim=ENGRAM_DIM,
        engram_order=ENGRAM_ORDER,
        pad_id=0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.rope_theta = float(rope_base)
        self.head_dim = d_model // n_heads
        self.max_len = max_len
        self.rope_base = rope_base
        self.pad_id = pad_id
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.blocks = nn.ModuleList()
        for layer_index in range(n_layers):
            # Engram goes on every `engram_every`-th block (1-based count).
            block_uses_engram = use_engram and ((layer_index + 1) % engram_every == 0)
            self.blocks.append(
                Block(
                    d_model=d_model,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    d_ff=d_ff,
                    rope_base=rope_base,
                    dropout=dropout,
                    use_engram=block_uses_engram,
                    engram_bucket_count=engram_bucket_count,
                    engram_dim=engram_dim,
                    engram_order=engram_order,
                    pad_id=pad_id,
                )
            )
        self.norm_f = RMSNorm(d_model)
        # No separate LM head: logits are computed with tok_emb.weight (tied).

    def build_rope(self, seq_len, device, dtype, start_pos=0):
        """Return ``(cos, sin)`` RoPE tables for positions [start_pos, start_pos + seq_len).

        Both have shape (1, 1, seq_len, head_dim). Each frequency is repeated
        twice so the tables line up with the interleaved pairs of ``rotate_half``.
        """
        inv_freq = 1.0 / (
            self.rope_theta ** (torch.arange(0, self.head_dim, 2, device=device).float() / self.head_dim)
        )
        positions = torch.arange(start_pos, start_pos + seq_len, device=device).float()
        freqs = torch.outer(positions, inv_freq)
        cos = freqs.cos().repeat_interleave(2, dim=-1).to(dtype)[None, None, :, :]
        sin = freqs.sin().repeat_interleave(2, dim=-1).to(dtype)[None, None, :, :]
        return cos, sin

    def forward(self, idx, targets=None):
        """Full (uncached) forward pass.

        Args:
            idx: token ids (batch, seq_len). Only the last ``max_len`` are used.
            targets: optional ids aligned position-for-position with ``idx``,
                used for a cross-entropy loss.

        Returns:
            ``(logits, loss)``: logits of shape (batch, seq_len', vocab_size),
            and ``loss`` is None when no targets are given.
        """
        if idx.size(1) > self.max_len:
            idx = idx[:, -self.max_len:]
            # Keep 2-D targets aligned with the truncated inputs. Previously they
            # stayed full-length and cross_entropy failed with a size mismatch.
            if targets is not None and targets.dim() == 2:
                targets = targets[:, -self.max_len:]
        seq_len = idx.size(1)
        rope_cos, rope_sin = self.build_rope(seq_len, idx.device, self.tok_emb.weight.dtype)
        x = self.tok_emb(idx)
        for block in self.blocks:
            x = block(x, idx, rope_cos, rope_sin)
        x = self.norm_f(x)
        logits = F.linear(x, self.tok_emb.weight)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    def trim_kv_cache(self, past_kv, max_tokens):
        """Keep only the most recent ``max_tokens`` positions of every layer's (k, v).

        Returns None when ``past_kv`` is None, otherwise a new list.
        """
        if past_kv is None:
            return None
        max_tokens = max(0, int(max_tokens))
        trimmed = []
        for k, v in past_kv:
            if max_tokens == 0:
                # Slicing with -0 would keep everything, so empty explicitly.
                k = k[:, :, :0, :].contiguous()
                v = v[:, :, :0, :].contiguous()
            elif k.size(2) > max_tokens:
                k = k[:, :, -max_tokens:, :].contiguous()
                v = v[:, :, -max_tokens:, :].contiguous()
            trimmed.append((k, v))
        return trimmed

    @torch.no_grad()
    def forward_cached(self, idx, past_kv=None, cache_pos=0, max_ctx=None, idx_context=None):
        """Run ``idx`` through the model while creating or extending a KV cache.

        Puts the model in eval mode.

        Args:
            idx: new token ids (batch, n). Without a cache this is the whole
                prompt and is truncated to the last ``max_ctx`` ids.
            past_kv: per-layer ``(k, v)`` list from an earlier call, or None.
            cache_pos: absolute position of idx's first token (for RoPE).
                It is reset to 0 when there is no cache.
            max_ctx: attention window (defaults to ``max_len``). The cache is
                trimmed to at most ``max_ctx - n`` positions.
            idx_context: ids ending at idx's last token, used for Engram
                hashing. Defaults to ``idx``.

        Returns:
            ``(logits, present_kv, cache_pos + n)``.
        """
        self.eval()
        max_ctx = self.max_len if max_ctx is None else int(max_ctx)
        if past_kv is None:
            idx = idx[:, -max_ctx:]
            idx_context = idx
            cache_pos = 0
        else:
            keep_past = max(0, max_ctx - idx.size(1))
            past_kv = self.trim_kv_cache(past_kv, keep_past)
            idx_context = idx if idx_context is None else idx_context[:, -max_ctx:]
        seq_len = idx.size(1)
        rope_cos, rope_sin = self.build_rope(
            seq_len,
            idx.device,
            self.tok_emb.weight.dtype,
            start_pos=cache_pos,
        )
        x = self.tok_emb(idx)
        present_kv = []
        for layer_index, block in enumerate(self.blocks):
            layer_past = None if past_kv is None else past_kv[layer_index]
            x, layer_present = block.forward_cached(
                x,
                idx_context,
                rope_cos,
                rope_sin,
                past_kv=layer_past,
            )
            present_kv.append(layer_present)
        x = self.norm_f(x)
        logits = F.linear(x, self.tok_emb.weight)
        return logits, present_kv, cache_pos + seq_len

    def sample_next(self, logits, temperature=0.8, top_k=50):
        """Sample one token id per row.

        3-D logits are reduced to their last position. ``temperature <= 1e-6``
        means greedy argmax. ``top_k > 0`` keeps only logits at or above the
        k-th largest value. Returns a (batch, 1) tensor of ids.
        """
        if logits.dim() == 3:
            logits = logits[:, -1, :]
        if temperature <= 1e-6:
            return torch.argmax(logits, dim=-1, keepdim=True)
        logits = logits / temperature
        if top_k and top_k > 0:
            values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            cutoff = values[:, [-1]]
            logits = logits.masked_fill(logits < cutoff, float("-inf"))
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def prefill_cache(self, idx, max_ctx=None):
        """Build a fresh cache from prompt ``idx``; return (last logits, cache, next pos)."""
        logits, past_kv, cache_pos = self.forward_cached(idx, past_kv=None, cache_pos=0, max_ctx=max_ctx)
        return logits[:, -1, :], past_kv, cache_pos

    @torch.no_grad()
    def decode_cached(self, idx, past_kv, cache_pos, idx_context, max_ctx=None):
        """Feed new ids ``idx`` through the cache; return (last logits, cache, next pos)."""
        logits, past_kv, cache_pos = self.forward_cached(
            idx,
            past_kv=past_kv,
            cache_pos=cache_pos,
            max_ctx=max_ctx,
            idx_context=idx_context,
        )
        return logits[:, -1, :], past_kv, cache_pos

    @torch.no_grad()
    def generate(
        self,
        idx,
        max_new_tokens=256,
        temperature=0.9,
        top_k=45,
        max_ctx=None,
        stop_ids=None,
        on_token=None,
    ):
        """Autoregressively sample up to ``max_new_tokens`` ids after ``idx``.

        Args:
            idx: prompt ids (batch, n); truncated to the last ``max_ctx`` ids.
            stop_ids: ids that end generation (default: module STOP_TOKEN_IDS).
                Only checked when one id is sampled per step (batch size 1).
                The stop id itself is neither reported nor appended.
            on_token: optional callback invoked with each kept id (as int).

        Returns:
            The (truncated) prompt with the generated ids appended.
        """
        self.eval()
        max_ctx = self.max_len if max_ctx is None else int(max_ctx)
        idx = idx[:, -max_ctx:]
        logits, past_kv, cache_pos = self.prefill_cache(idx, max_ctx=max_ctx)
        stop_ids = STOP_TOKEN_IDS if stop_ids is None else stop_ids
        for step in range(max_new_tokens):
            nxt = self.sample_next(logits, temperature=temperature, top_k=top_k)
            if stop_ids and nxt.numel() == 1 and int(nxt.item()) in stop_ids:
                break
            if on_token is not None:
                on_token(int(nxt.item()))
            idx = torch.cat([idx, nxt], dim=1)
            # Skip the model call after the final token; its logits would be unused.
            if step + 1 < max_new_tokens:
                logits, past_kv, cache_pos = self.decode_cached(
                    nxt,
                    past_kv,
                    cache_pos,
                    idx[:, -max_ctx:],
                    max_ctx=max_ctx,
                )
        return idx
# ─────────────────────────────────────────────────────────────
# HELPERS
# ─────────────────────────────────────────────────────────────
def _resolve_stop_ids(tok):
ids = set()
for t in ("<|im_end|>", "<|endoftext|>"):
i = tok.convert_tokens_to_ids(t)
if isinstance(i, int) and i >= 0 and i != tok.unk_token_id:
ids.add(i)
if tok.eos_token_id is not None:
ids.add(tok.eos_token_id)
return ids
def _looks_like_state_dict(d):
if not isinstance(d, dict) or not d:
return False
tensor_vals = [v for v in d.values() if torch.is_tensor(v)]
if len(tensor_vals) < max(4, 0.5 * len(d)):
return False
return any("." in str(k) for k in d.keys())
def _extract_state_dict(blob):
    """Find the model weights inside a loaded checkpoint object.

    Search order: ``blob`` itself; then the well-known wrapper keys (each
    value, and one level inside it); then every top-level value (and one level
    inside it). The first match wins.

    Raises:
        ValueError: ``blob`` is not a dict, or no state_dict-like mapping was found.
    """
    if _looks_like_state_dict(blob):
        return blob
    if not isinstance(blob, dict):
        raise ValueError(f"Unexpected checkpoint type: {type(blob)}")

    def _first_match(candidates):
        # A candidate matches directly, or through one of its own values.
        for candidate in candidates:
            if _looks_like_state_dict(candidate):
                return candidate
            if isinstance(candidate, dict):
                for nested in candidate.values():
                    if _looks_like_state_dict(nested):
                        return nested
        return None

    wrapper_keys = (
        "model_state_dict", "model", "model_state", "state_dict",
        "weights", "net", "module", "ema", "ema_model",
    )
    found = _first_match(blob.get(key) for key in wrapper_keys)
    if found is None:
        found = _first_match(blob.values())
    if found is None:
        raise ValueError(
            "Could not find a model state_dict in the checkpoint. "
            f"Top-level keys were: {list(blob.keys())}"
        )
    return found
def _strip_wrapper_prefixes(state):
    """Remove ``_orig_mod.`` / ``module.`` prefixes when *every* key carries them.

    torch.compile and (Distributed)DataParallel wrappers prefix parameter
    names this way. Left in place, no key would match the bare model, and
    ``load_state_dict(strict=False)`` would load nothing without failing.
    """
    prefixes = ("_orig_mod.", "module.")
    while state:
        prefix = next(
            (p for p in prefixes if all(isinstance(k, str) and k.startswith(p) for k in state)),
            None,
        )
        if prefix is None:
            break
        state = {k[len(prefix):]: v for k, v in state.items()}
    return state


def load_model(ckpt_path, tokenizer):
    """Build a CosmosT2_Accelerate_LLM from ``ckpt_path`` and load its weights.

    Hyper-parameters come from the checkpoint's ``model_config`` / ``config``
    dict when present, otherwise from the module defaults. The vocabulary size
    is read from the stored ``tok_emb.weight`` when available, so the embedding
    always matches the checkpoint.

    Returns:
        The model on the CPU, in eval mode.

    Raises:
        ValueError: no state_dict could be located in the checkpoint.
    """
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only load checkpoints from a trusted
    # source (here: the local work dir or the configured HF bucket).
    blob = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    cfg = {}
    if isinstance(blob, dict):
        for key in ("model_config", "config"):
            if isinstance(blob.get(key), dict):
                cfg = blob[key]
                break
    state = _strip_wrapper_prefixes(_extract_state_dict(blob))

    # load_state_dict raises on a shape mismatch even with strict=False, so
    # size the vocabulary from the stored embedding instead of trusting
    # len(tokenizer), which can differ from a padded embedding matrix.
    emb = state.get("tok_emb.weight")
    vocab_size = emb.shape[0] if torch.is_tensor(emb) else cfg.get("vocab_size", len(tokenizer))

    model = CosmosT2_Accelerate_LLM(
        vocab_size=vocab_size,
        d_model=cfg.get("d_model", D_MODEL),
        n_layers=cfg.get("n_layers", N_LAYERS),
        n_heads=cfg.get("n_heads", N_HEADS),
        n_kv_heads=cfg.get("n_kv_heads", N_KV_HEADS),
        d_ff=cfg.get("d_ff", D_FF),
        max_len=cfg.get("max_len", MAX_LEN),
        rope_base=cfg.get("rope_base", ROPE_BASE),
        dropout=0.0,
        use_engram=cfg.get("use_engram", USE_ENGRAM),
        engram_every=cfg.get("engram_every", ENGRAM_EVERY),
        # Also accept the constructor's own argument name for the bucket count.
        engram_bucket_count=cfg.get("engram_buckets", cfg.get("engram_bucket_count", ENGRAM_BUCKETS)),
        engram_dim=cfg.get("engram_dim", ENGRAM_DIM),
        engram_order=cfg.get("engram_order", ENGRAM_ORDER),
        pad_id=tokenizer.pad_token_id or 0,
    )
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print(f"[warn] missing keys: {len(missing)} (e.g. {missing[:3]})")
    if unexpected:
        print(f"[warn] unexpected keys: {len(unexpected)} (e.g. {unexpected[:3]})")
    model.eval()
    return model
def build_prompt_ids(tokenizer, user_text, stage, system_prompt, history=None):
    """Turn a user message into prompt token ids for the given training stage.

    ``"pretrain"``: the raw text is tokenized without special tokens.
    Any other stage: an optional system message, the ``(role, content)``
    history pairs and the new user message are rendered with the tokenizer's
    chat template (generation prompt appended), then tokenized without extra
    special tokens.
    """
    def _encode(text):
        return tokenizer(text, add_special_tokens=False, return_attention_mask=False)["input_ids"]

    if stage == "pretrain":
        return _encode(user_text)

    messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
    messages.extend({"role": role, "content": content} for role, content in (history or []))
    messages.append({"role": "user", "content": user_text})
    rendered = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return _encode(rendered)
# ─────────────────────────────────────────────────────────────
# LOAD ON STARTUP
# ─────────────────────────────────────────────────────────────
# Runs at import time: loads the tokenizer, resolves the fine-tuned checkpoint
# and builds the global `model` that chat_stream() generates with.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, trust_remote_code=True)
# Fall back to EOS as the padding token when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Replace the placeholder globals from the CONFIG section with tokenizer-derived
# values; generate() reads STOP_TOKEN_IDS when no stop_ids are passed.
PAD_ID = tokenizer.pad_token_id
STOP_TOKEN_IDS = _resolve_stop_ids(tokenizer)
# ─────────────────────────────
# Resolve the checkpoint path first (downloading it from the HF bucket when it
# is not cached locally), then build the model from that file.
# ─────────────────────────────
CKPT_PATH = resolve_checkpoint(stage="finetune")
print(f"Loading model checkpoint: {CKPT_PATH}")
model = load_model(CKPT_PATH, tokenizer)
model.to(device)
model.eval()
n_params = sum(p.numel() for p in model.parameters())
print(f"Model ready: {n_params/1e6:.1f}M params | device={device}")
# ─────────────────────────────────────────────────────────────
# GRADIO STREAMING
# ─────────────────────────────────────────────────────────────
def history_to_role_messages(history):
    """Convert Gradio chat history into a list of ``(role, content)`` pairs.

    Both Gradio history formats are accepted:
      * "tuples": ``[[user_msg, assistant_msg], ...]``. The assistant slot can
        be ``None`` for a turn that has no reply yet.
      * "messages": ``[{"role": ..., "content": ...}, ...]``.

    Entries whose content is not a string (``None``, file attachments) or
    whose role is not user/assistant/system are skipped, because the chat
    template only renders text turns. ``None`` history yields ``[]``.
    """
    messages = []
    for item in history or []:
        if isinstance(item, dict):
            pairs = [(item.get("role"), item.get("content"))]
        else:
            user_msg, assistant_msg = item
            pairs = [("user", user_msg), ("assistant", assistant_msg)]
        for role, content in pairs:
            if role in ("user", "assistant", "system") and isinstance(content, str):
                messages.append((role, content))
    return messages
def chat_stream(message, history, system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Stream the model's reply to ``message`` as progressively longer strings.

    Generation runs in a background thread (serialised by MODEL_LOCK) and hands
    token ids back through a queue. This generator yields the whole reply
    decoded so far each time the visible text changes.

    The generated ids are decoded together rather than one at a time. With a
    byte-level BPE tokenizer (such as Qwen's), a single token can hold only
    part of a multi-byte UTF-8 character, and decoding it on its own yields
    U+FFFD replacement characters. Text ending in U+FFFD is therefore held back
    until the character completes, and flushed at the end.

    If this generator is closed before generation finishes, the worker is told
    to abort at its next token so it releases MODEL_LOCK.

    Raises:
        Any exception raised during generation, re-raised in the caller's thread.
    """
    role_history = history_to_role_messages(history)
    prompt_ids = build_prompt_ids(
        tokenizer=tokenizer,
        user_text=message,
        stage="finetune",
        system_prompt=system_prompt,
        history=role_history,
    )
    idx = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    class _Cancelled(Exception):
        """Raised inside on_token to abort generation once the consumer is gone."""

    token_q = queue.Queue()
    end = object()  # sentinel: the worker has finished, normally or not
    errors = []  # exceptions raised by the worker, surfaced below
    cancelled = threading.Event()

    def on_token(tid: int):
        if cancelled.is_set():
            raise _Cancelled
        token_q.put(tid)

    def worker():
        try:
            with MODEL_LOCK:
                with torch.inference_mode():
                    model.generate(
                        idx,
                        max_new_tokens=256,
                        temperature=0.9,
                        top_k=45,
                        max_ctx=MAX_LEN,
                        on_token=on_token,
                    )
        except _Cancelled:
            pass
        except Exception as exc:  # re-raised in the consumer thread
            errors.append(exc)
        finally:
            token_q.put(end)

    threading.Thread(target=worker, daemon=True).start()

    generated = []
    output = ""
    try:
        while True:
            item = token_q.get()
            if item is end:
                break
            generated.append(item)
            text = tokenizer.decode(generated, skip_special_tokens=True)
            # Hold back a trailing incomplete character; skip no-op updates.
            if text.endswith("\ufffd") or text == output:
                continue
            output = text
            yield output
    finally:
        # Also runs when the generator is closed or garbage-collected early.
        cancelled.set()

    if errors:
        raise errors[0]
    final = tokenizer.decode(generated, skip_special_tokens=True) if generated else ""
    if final != output:
        yield final
def chat(message, history):
    """ChatInterface callback: stream a reply using DEFAULT_SYSTEM_PROMPT.

    Written as a generator (``yield from``) so every partial reply produced by
    chat_stream() is passed straight through to the caller.
    """
    yield from chat_stream(message, history, system_prompt=DEFAULT_SYSTEM_PROMPT)
# ─────────────────────────────────────────────────────────────
# UI
# ─────────────────────────────────────────────────────────────
# Chat UI backed by the streaming `chat` callback defined above.
demo = gr.ChatInterface(
    fn=chat,
    title="Cosmos-T3 API",
    description="Streaming inference API (backend for your frontend)",
)
if __name__ == "__main__":
    # Enable Gradio's request queue and serve on all network interfaces
    # (0.0.0.0) at port 7860.
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
    )