Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| """ | |
| Cosmos-T3 — single-file Gradio app for inference/chat. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import sys | |
| import queue | |
| import threading | |
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| from transformers import AutoTokenizer | |
# ─────────────────────────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────────────────────────
# Tokenizer and context window.
TOKENIZER_NAME = "Qwen/Qwen2.5-0.5B"
BLOCK_SIZE = 1024  # not referenced by the code shown in this file
MAX_LEN = 1024

# Default architecture; load_model prefers values stored in the checkpoint config.
D_MODEL, N_LAYERS, N_HEADS, N_KV_HEADS = 768, 12, 12, 4
D_FF = 2048
ROPE_BASE = 10000
DROP_OUT = 0.0

# Engram (hashed n-gram memory) defaults.
USE_ENGRAM = True
ENGRAM_EVERY, ENGRAM_BUCKETS, ENGRAM_DIM, ENGRAM_ORDER = 4, 8192, 64, 3

DEFAULT_SYSTEM_PROMPT = "Enable thinking features: INTUITION"

# Checkpoint file names inside the local work dir, per training stage.
STAGE_CKPT = dict(
    pretrain="Cosmos-T3-Pretrain.resume.pt",
    finetune="Cosmos-T3-Instruct.resume.pt",
)
# Remote paths of the same checkpoints inside the HF bucket HF_BUCKET_ID.
STAGE_BUCKET = dict(
    pretrain="pretrain/checkpoints/Cosmos-T3-Pretrain.resume.pt",
    finetune="finetune/checkpoints/Cosmos-T3-Instruct.resume.pt",
)
HF_BUCKET_ID = "wop/Cosmos-SFT"

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Placeholders; both are reassigned once the tokenizer is loaded at startup.
PAD_ID = 0
STOP_TOKEN_IDS: set[int] = set()
# Serialises use of the single shared model across request threads.
MODEL_LOCK = threading.Lock()
def resolve_checkpoint(stage="finetune", work_dir="cosmos_t3_run", no_bucket=False):
    """Return a local path to the checkpoint for ``stage``, downloading it if needed.

    Args:
        stage: training stage, a key of STAGE_CKPT / STAGE_BUCKET
            ("pretrain" or "finetune").
        work_dir: directory where checkpoints are cached locally.
        no_bucket: when True never download; a missing local file is an error.

    Returns:
        ``Path`` of the local checkpoint file.

    Raises:
        KeyError: ``stage`` is not a known stage.
        FileNotFoundError: the file is missing locally and ``no_bucket`` is True.
        RuntimeError: the download returned but the file is still missing.
    """
    local = Path(work_dir) / STAGE_CKPT[stage]
    if local.exists():
        return local
    if no_bucket:
        raise FileNotFoundError(f"Missing checkpoint: {local}")
    # huggingface_hub picks up HF_TOKEN from the environment on its own. Do NOT
    # inject a placeholder value: a bogus token is sent as real credentials and
    # makes even public downloads fail with "invalid credentials" (401).
    from huggingface_hub import download_bucket_files

    remote = STAGE_BUCKET[stage]
    local.parent.mkdir(parents=True, exist_ok=True)
    print(f"Downloading from bucket: {HF_BUCKET_ID}/{remote}")
    download_bucket_files(HF_BUCKET_ID, files=[(remote, str(local))])
    if not local.exists():
        raise RuntimeError("Bucket download failed")
    return local
# ─────────────────────────────────────────────────────────────
# MODEL CORE
# ─────────────────────────────────────────────────────────────
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-channel scale (no bias).

    ``y = x / sqrt(mean(x**2, last dim) + eps) * weight``
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        inv_rms = (x * x).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return self.weight * (x * inv_rms)
def rotate_half(x):
    """Rotate interleaved channel pairs along the last dim: (a, b) -> (-b, a)."""
    even, odd = x[..., 0::2], x[..., 1::2]
    rotated_pairs = torch.stack((odd.neg(), even), dim=-1)
    return rotated_pairs.flatten(start_dim=-2)
def apply_rope(q, k, cos, sin):
    """Apply rotary position embeddings to queries and keys.

    Each tensor becomes ``t * cos + rotate_half(t) * sin``; ``cos``/``sin``
    must broadcast against ``q`` and ``k``.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
class GQAAttention(nn.Module):
    """Grouped-query self-attention with rotary embeddings and an optional KV cache.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads; each K/V head
    is repeated ``n_heads // n_kv_heads`` times before attention is computed.
    """

    def __init__(self, d_model, n_heads, n_kv_heads, rope_base=10000, dropout=0.0):
        super().__init__()
        assert d_model % n_heads == 0
        assert n_heads % n_kv_heads == 0
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = d_model // n_heads
        self.rope_base = rope_base  # stored only; rotary tables are passed into forward()
        self.dropout = dropout
        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, rope_cos, rope_sin, past_kv=None, use_cache=False):
        """Attend from the new tokens in ``x`` to themselves and any cached keys.

        Args:
            x: hidden states, unpacked as (batch, seq_len, d_model).
            rope_cos, rope_sin: rotary tables for the *new* positions only; they
                must broadcast against (batch, n_heads, seq_len, head_dim).
            past_kv: optional ``(k, v)`` returned by an earlier call with
                ``use_cache=True``, each (batch, n_kv_heads, past_len, head_dim).
            use_cache: also return the updated ``(k, v)`` (before head repetition).

        Returns:
            Output of shape (batch, seq_len, d_model), or ``(output, present_kv)``
            when ``use_cache`` is true.
        """
        bsz, seq_len, _ = x.shape
        q = self.q_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        q, k = apply_rope(q, k, rope_cos, rope_sin)
        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        present_kv = (k, v) if use_cache else None
        if self.n_kv_heads != self.n_heads:
            repeat = self.n_heads // self.n_kv_heads
            k = k.repeat_interleave(repeat, dim=1)
            v = v.repeat_interleave(repeat, dim=1)

        attn_mask = None
        if past_kv is not None and seq_len > 1:
            # Several new tokens on top of a cache: without a mask each new query
            # would also see the *later* new tokens. Query i sits at absolute
            # position past_len + i, so it may attend to keys 0 .. past_len + i
            # (a bottom-right aligned causal mask). One new token needs no mask.
            total_len = k.size(2)
            past_len = total_len - seq_len
            attn_mask = torch.ones(
                seq_len, total_len, dtype=torch.bool, device=q.device
            ).tril(diagonal=past_len)

        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            # Without a cache queries and keys cover the same positions, so the
            # built-in causal mask is correct there.
            is_causal=(past_kv is None),
            dropout_p=self.dropout if self.training else 0.0,
        )
        attn_out = attn_out.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        attn_out = self.o_proj(attn_out)
        return (attn_out, present_kv) if use_cache else attn_out
class SwiGLUMLP(nn.Module):
    """SwiGLU feed-forward: ``down(dropout(silu(gate(x)) * up(x)))``, all without bias."""

    def __init__(self, d_model, hidden_dim, dropout=0.0):
        super().__init__()
        self.gate = nn.Linear(d_model, hidden_dim, bias=False)
        self.up = nn.Linear(d_model, hidden_dim, bias=False)
        self.down = nn.Linear(hidden_dim, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        activated = F.silu(self.gate(x))
        hidden = activated * self.up(x)
        return self.down(self.dropout(hidden))
class EngramMemory(nn.Module):
    """Hashed n-gram memory added residually inside a transformer block.

    For every position, the ids of the last ``order`` tokens (left-padded with
    ``pad_id``) are hashed into one of ``bucket_count`` learned vectors of size
    ``memory_dim``. The looked-up vector is multiplied by ``tanh(query(x))``,
    projected back to ``d_model`` and scaled by ``sigmoid(gate(x))``.
    """

    def __init__(self, d_model, bucket_count, memory_dim, order, pad_id=0, dropout=0.0):
        super().__init__()
        # Per-offset multipliers for the rolling hash. The values must not change,
        # or trained bucket embeddings would be looked up at the wrong rows.
        primes = [1, 1315423911, 2654435761, 97531, 433494437]
        # Validate eagerly: an unsupported order previously only failed at the
        # first forward pass (IndexError on primes, or a negative pad width).
        if not 1 <= order <= len(primes):
            raise ValueError(f"order must be between 1 and {len(primes)}, got {order}")
        if bucket_count < 1:
            raise ValueError(f"bucket_count must be positive, got {bucket_count}")
        self.bucket_count = bucket_count
        self.memory_dim = memory_dim
        self.order = order
        self.pad_id = pad_id
        self.bucket = nn.Embedding(bucket_count, memory_dim)
        self.query = nn.Linear(d_model, memory_dim, bias=False)
        self.project = nn.Linear(memory_dim, d_model, bias=False)
        self.gate = nn.Linear(d_model, d_model, bias=True)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("primes", torch.tensor(primes[:order], dtype=torch.long), persistent=False)

    def hash_tokens(self, idx):
        """Return bucket ids with the same (batch, seq_len) shape as ``idx``.

        Position t hashes tokens t-order+1 .. t; positions before the start of
        the sequence are filled with ``pad_id``. The running value is reduced
        modulo ``bucket_count`` after every step, which keeps the int64
        arithmetic far from overflow.
        """
        batch, seq_len = idx.shape
        pad = torch.full((batch, self.order - 1), self.pad_id, device=idx.device, dtype=idx.dtype)
        history = torch.cat([pad, idx], dim=1)
        hashed = torch.zeros((batch, seq_len), device=idx.device, dtype=torch.long)
        for offset in range(self.order):
            slice_ = history[:, offset: offset + seq_len].long()
            hashed = (hashed * 1315423911 + slice_ * self.primes[offset]) % self.bucket_count
        return hashed

    def forward(self, x, idx):
        """Compute the gated memory contribution for hidden states ``x``.

        ``idx`` may be longer than ``x`` (cached decoding passes the whole
        context so the hash sees full n-gram history); only the hashes of its
        last ``x.size(1)`` positions are used.
        """
        hashed = self.hash_tokens(idx)
        if hashed.size(1) != x.size(1):
            hashed = hashed[:, -x.size(1):]
        query = torch.tanh(self.query(x))
        memory = self.bucket(hashed) * query
        memory = self.project(memory)
        gate = torch.sigmoid(self.gate(x))
        return self.dropout(gate * memory)
class Block(nn.Module):
    """Pre-norm transformer layer: GQA attention, optional Engram memory, SwiGLU MLP.

    Each sub-layer is applied to its own RMSNorm of the running hidden state
    and added back residually. ``norm2`` is always constructed, even when the
    Engram branch is disabled (presumably to keep state_dict keys uniform
    across layers — confirm before removing).
    """

    def __init__(
        self,
        d_model,
        n_heads,
        n_kv_heads,
        d_ff,
        rope_base,
        dropout=0.0,
        use_engram=False,
        engram_bucket_count=4096,
        engram_dim=96,
        engram_order=3,
        pad_id=0,
    ):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = GQAAttention(d_model, n_heads, n_kv_heads, rope_base=rope_base, dropout=dropout)
        self.norm2 = RMSNorm(d_model)
        if use_engram:
            self.engram = EngramMemory(
                d_model,
                engram_bucket_count,
                engram_dim,
                engram_order,
                pad_id=pad_id,
                dropout=dropout,
            )
        else:
            self.engram = None
        self.norm3 = RMSNorm(d_model)
        self.mlp = SwiGLUMLP(d_model, d_ff, dropout=dropout)

    def _memory_and_mlp(self, h, token_ids):
        # Shared tail of both forward paths: optional Engram residual, then MLP residual.
        if self.engram is not None:
            h = h + self.engram(self.norm2(h), token_ids)
        return h + self.mlp(self.norm3(h))

    def forward(self, x, idx, rope_cos, rope_sin):
        """Uncached pass; ``idx`` are the token ids aligned with ``x`` (used by Engram)."""
        h = x + self.attn(self.norm1(x), rope_cos, rope_sin)
        return self._memory_and_mlp(h, idx)

    def forward_cached(self, x, idx_context, rope_cos, rope_sin, past_kv=None):
        """Cached pass over new tokens ``x``.

        ``idx_context`` holds the context token ids ending with the tokens of
        ``x`` (Engram hashes its tail). Returns ``(hidden, present_kv)``.
        """
        attn_out, present_kv = self.attn(
            self.norm1(x), rope_cos, rope_sin, past_kv=past_kv, use_cache=True
        )
        return self._memory_and_mlp(x + attn_out, idx_context), present_kv
class CosmosT2_Accelerate_LLM(nn.Module):
    """Decoder-only transformer language model (GQA attention + optional Engram memory).

    Token embedding -> ``n_layers`` ``Block`` layers -> final RMSNorm -> logits
    computed with the token-embedding matrix (tied input/output weights).
    Layer ``i`` (0-based) carries an Engram memory when ``use_engram`` is true
    and ``(i + 1) % engram_every == 0``.
    """

    def __init__(
        self,
        vocab_size,
        d_model=D_MODEL,
        n_layers=N_LAYERS,
        n_heads=N_HEADS,
        n_kv_heads=N_KV_HEADS,
        d_ff=D_FF,
        max_len=MAX_LEN,
        rope_base=ROPE_BASE,
        dropout=DROP_OUT,
        use_engram=USE_ENGRAM,
        engram_every=ENGRAM_EVERY,
        engram_bucket_count=ENGRAM_BUCKETS,
        engram_dim=ENGRAM_DIM,
        engram_order=ENGRAM_ORDER,
        pad_id=0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.rope_theta = float(rope_base)
        self.head_dim = d_model // n_heads
        self.max_len = max_len
        self.rope_base = rope_base
        self.pad_id = pad_id
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.blocks = nn.ModuleList()
        for layer_index in range(n_layers):
            block_uses_engram = use_engram and ((layer_index + 1) % engram_every == 0)
            self.blocks.append(
                Block(
                    d_model=d_model,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    d_ff=d_ff,
                    rope_base=rope_base,
                    dropout=dropout,
                    use_engram=block_uses_engram,
                    engram_bucket_count=engram_bucket_count,
                    engram_dim=engram_dim,
                    engram_order=engram_order,
                    pad_id=pad_id,
                )
            )
        self.norm_f = RMSNorm(d_model)

    def build_rope(self, seq_len, device, dtype, start_pos=0):
        """Return rotary ``(cos, sin)`` tables of shape (1, 1, seq_len, head_dim).

        Positions run from ``start_pos`` to ``start_pos + seq_len - 1``. Each
        frequency is repeated twice to match the interleaved channel pairs
        rotated by ``rotate_half``.
        """
        inv_freq = 1.0 / (
            self.rope_theta ** (torch.arange(0, self.head_dim, 2, device=device).float() / self.head_dim)
        )
        positions = torch.arange(start_pos, start_pos + seq_len, device=device).float()
        freqs = torch.outer(positions, inv_freq)
        cos = freqs.cos().repeat_interleave(2, dim=-1).to(dtype)[None, None, :, :]
        sin = freqs.sin().repeat_interleave(2, dim=-1).to(dtype)[None, None, :, :]
        return cos, sin

    def forward(self, idx, targets=None):
        """Uncached forward pass over a batch of token ids.

        Args:
            idx: token ids indexed as (batch, seq_len).
            targets: optional target ids, assumed to have the same shape as ``idx``.

        Sequences longer than ``max_len`` keep only their last ``max_len``
        tokens. ``targets`` is cropped the same way so every logit stays paired
        with its target; cropping only ``idx`` would hand ``cross_entropy``
        inputs of different lengths.

        Returns:
            ``(logits, loss)``: logits are (batch, seq_len, vocab_size); loss
            is None when ``targets`` is None.
        """
        if idx.size(1) > self.max_len:
            idx = idx[:, -self.max_len:]
            if targets is not None:
                targets = targets[:, -self.max_len:]
        seq_len = idx.size(1)
        rope_cos, rope_sin = self.build_rope(seq_len, idx.device, self.tok_emb.weight.dtype)
        x = self.tok_emb(idx)
        for block in self.blocks:
            x = block(x, idx, rope_cos, rope_sin)
        x = self.norm_f(x)
        logits = F.linear(x, self.tok_emb.weight)  # tied output projection
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    def trim_kv_cache(self, past_kv, max_tokens):
        """Keep at most the newest ``max_tokens`` cached positions in every layer.

        Returns None for a None cache; ``max_tokens`` <= 0 leaves empty caches.
        """
        if past_kv is None:
            return None
        max_tokens = max(0, int(max_tokens))
        trimmed = []
        for k, v in past_kv:
            if max_tokens == 0:
                # Explicit branch: a slice of -0: would keep everything.
                k = k[:, :, :0, :].contiguous()
                v = v[:, :, :0, :].contiguous()
            elif k.size(2) > max_tokens:
                k = k[:, :, -max_tokens:, :].contiguous()
                v = v[:, :, -max_tokens:, :].contiguous()
            trimmed.append((k, v))
        return trimmed

    def forward_cached(self, idx, past_kv=None, cache_pos=0, max_ctx=None, idx_context=None):
        """Run new tokens ``idx`` through the model, reusing/extending a KV cache.

        Note: switches the module to eval mode.

        Without ``past_kv`` this is a prefill: ``idx`` is cut to ``max_ctx``
        tokens and positions restart at 0. With ``past_kv`` the cache is trimmed
        so cached plus new tokens fit in ``max_ctx``. Rotary positions keep
        counting from ``cache_pos``. ``idx_context`` (token ids ending with
        ``idx``) feeds the Engram hash; if omitted, only ``idx`` is hashed.

        Returns:
            ``(logits, present_kv, new_cache_pos)``.
        """
        self.eval()
        max_ctx = self.max_len if max_ctx is None else int(max_ctx)
        if past_kv is None:
            idx = idx[:, -max_ctx:]
            idx_context = idx
            cache_pos = 0
        else:
            keep_past = max(0, max_ctx - idx.size(1))
            past_kv = self.trim_kv_cache(past_kv, keep_past)
            idx_context = idx if idx_context is None else idx_context[:, -max_ctx:]
        seq_len = idx.size(1)
        rope_cos, rope_sin = self.build_rope(
            seq_len,
            idx.device,
            self.tok_emb.weight.dtype,
            start_pos=cache_pos,
        )
        x = self.tok_emb(idx)
        present_kv = []
        for layer_index, block in enumerate(self.blocks):
            layer_past = None if past_kv is None else past_kv[layer_index]
            x, layer_present = block.forward_cached(
                x,
                idx_context,
                rope_cos,
                rope_sin,
                past_kv=layer_past,
            )
            present_kv.append(layer_present)
        x = self.norm_f(x)
        logits = F.linear(x, self.tok_emb.weight)
        return logits, present_kv, cache_pos + seq_len

    def sample_next(self, logits, temperature=0.8, top_k=50):
        """Sample one token id per batch row from (batch, vocab) or (batch, seq, vocab) logits.

        ``temperature`` <= 1e-6 selects greedily; otherwise logits are divided
        by ``temperature`` and, when ``top_k`` > 0, restricted to the top-k
        before multinomial sampling. Returns ids of shape (batch, 1).
        """
        if logits.dim() == 3:
            logits = logits[:, -1, :]
        if temperature <= 1e-6:
            return torch.argmax(logits, dim=-1, keepdim=True)
        logits = logits / temperature
        if top_k and top_k > 0:
            values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            cutoff = values[:, [-1]]
            logits = logits.masked_fill(logits < cutoff, float("-inf"))
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def prefill_cache(self, idx, max_ctx=None):
        """Build a fresh KV cache from a prompt; returns (last-position logits, cache, cache_pos)."""
        logits, past_kv, cache_pos = self.forward_cached(idx, past_kv=None, cache_pos=0, max_ctx=max_ctx)
        return logits[:, -1, :], past_kv, cache_pos

    def decode_cached(self, idx, past_kv, cache_pos, idx_context, max_ctx=None):
        """Advance the KV cache by ``idx``; returns (last-position logits, cache, cache_pos)."""
        logits, past_kv, cache_pos = self.forward_cached(
            idx,
            past_kv=past_kv,
            cache_pos=cache_pos,
            max_ctx=max_ctx,
            idx_context=idx_context,
        )
        return logits[:, -1, :], past_kv, cache_pos

    def generate(
        self,
        idx,
        max_new_tokens=256,
        temperature=0.9,
        top_k=45,
        max_ctx=None,
        stop_ids=None,
        on_token=None,
    ):
        """Autoregressively sample up to ``max_new_tokens`` tokens after ``idx``.

        The prompt is cut to its last ``max_ctx`` tokens and prefilled into a KV
        cache; afterwards one token is decoded per step. Sampling stops early
        when a sampled id is in ``stop_ids`` (default: module-level
        STOP_TOKEN_IDS); that check, and ``on_token`` (called with each
        accepted id as an int), rely on ``.item()`` and so assume batch size 1.

        Returns:
            The (possibly truncated) prompt followed by the accepted tokens; a
            stop token is not appended.
        """
        self.eval()
        max_ctx = self.max_len if max_ctx is None else int(max_ctx)
        idx = idx[:, -max_ctx:]
        if max_new_tokens <= 0:
            # Nothing to sample: skip the otherwise wasted prefill pass.
            return idx
        logits, past_kv, cache_pos = self.prefill_cache(idx, max_ctx=max_ctx)
        stop_ids = STOP_TOKEN_IDS if stop_ids is None else stop_ids
        for step in range(max_new_tokens):
            nxt = self.sample_next(logits, temperature=temperature, top_k=top_k)
            if stop_ids and nxt.numel() == 1 and int(nxt.item()) in stop_ids:
                break
            if on_token is not None:
                on_token(int(nxt.item()))
            idx = torch.cat([idx, nxt], dim=1)
            # The last step needs no further forward pass: its logits would be unused.
            if step + 1 < max_new_tokens:
                logits, past_kv, cache_pos = self.decode_cached(
                    nxt,
                    past_kv,
                    cache_pos,
                    idx[:, -max_ctx:],
                    max_ctx=max_ctx,
                )
        return idx
| # ───────────────────────────────────────────────────────────── | |
| # HELPERS | |
| # ───────────────────────────────────────────────────────────── | |
| def _resolve_stop_ids(tok): | |
| ids = set() | |
| for t in ("<|im_end|>", "<|endoftext|>"): | |
| i = tok.convert_tokens_to_ids(t) | |
| if isinstance(i, int) and i >= 0 and i != tok.unk_token_id: | |
| ids.add(i) | |
| if tok.eos_token_id is not None: | |
| ids.add(tok.eos_token_id) | |
| return ids | |
| def _looks_like_state_dict(d): | |
| if not isinstance(d, dict) or not d: | |
| return False | |
| tensor_vals = [v for v in d.values() if torch.is_tensor(v)] | |
| if len(tensor_vals) < max(4, 0.5 * len(d)): | |
| return False | |
| return any("." in str(k) for k in d.keys()) | |
def _extract_state_dict(blob):
    """Find the model's parameter dict inside a loaded checkpoint object.

    Search order (first match wins):
      1. ``blob`` itself;
      2. for each well-known key (``model_state_dict``, ``model``, ...): the
         value under that key, then that value's own dict values;
      3. every top-level value of ``blob``, then that value's dict values.
    Matching is decided by ``_looks_like_state_dict``.

    Raises:
        ValueError: ``blob`` is not a dict, or no candidate looks like a state dict.
    """
    if _looks_like_state_dict(blob):
        return blob
    if not isinstance(blob, dict):
        raise ValueError(f"Unexpected checkpoint type: {type(blob)}")

    def _probe(candidate):
        # The candidate itself, else its first matching child if it is a dict.
        if _looks_like_state_dict(candidate):
            return candidate
        if isinstance(candidate, dict):
            for child in candidate.values():
                if _looks_like_state_dict(child):
                    return child
        return None

    preferred_keys = (
        "model_state_dict", "model", "model_state", "state_dict",
        "weights", "net", "module", "ema", "ema_model",
    )
    candidates = [blob.get(key) for key in preferred_keys] + list(blob.values())
    for candidate in candidates:
        found = _probe(candidate)
        if found is not None:
            return found
    raise ValueError(
        "Could not find a model state_dict in the checkpoint. "
        f"Top-level keys were: {list(blob.keys())}"
    )
def _strip_wrapper_prefixes(state):
    """Remove ``_orig_mod.`` / ``module.`` prefixes when *every* key carries them.

    ``torch.compile`` saves weights under ``_orig_mod.`` and DataParallel/DDP
    under ``module.``. Either prefix makes every name miss in
    ``load_state_dict``. Prefixes are removed repeatedly so nested wrappers
    are unwrapped too; a prefix shared by only some keys is left untouched.
    """
    prefixes = ("_orig_mod.", "module.")
    stripped = True
    while stripped and state:
        stripped = False
        for prefix in prefixes:
            if all(isinstance(key, str) and key.startswith(prefix) for key in state):
                state = {key[len(prefix):]: value for key, value in state.items()}
                stripped = True
    return state


def load_model(ckpt_path, tokenizer):
    """Build the model described by a checkpoint and load its weights.

    Hyper-parameters come from a ``model_config`` or ``config`` dict stored in
    the checkpoint (the first one present wins). Missing entries fall back to
    the module-level defaults. The Engram bucket count is read from
    ``engram_buckets`` or, failing that, ``engram_bucket_count``.

    Weights are located with ``_extract_state_dict`` and loaded with
    ``strict=False``. Missing/unexpected keys are printed as warnings, not
    raised. Tensor shape mismatches still raise from ``load_state_dict``.

    Args:
        ckpt_path: path to a ``torch.save`` checkpoint (loaded with map_location="cpu").
        tokenizer: provides ``len()`` (fallback vocab size) and ``pad_token_id``.

    Returns:
        The model in eval mode (device placement is left to the caller).
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects and can
    # execute code embedded in the file. Only load checkpoints from trusted sources.
    blob = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    cfg = {}
    if isinstance(blob, dict):
        for key in ("model_config", "config"):
            if isinstance(blob.get(key), dict):
                cfg = blob[key]
                break
    engram_buckets = cfg.get("engram_buckets", cfg.get("engram_bucket_count", ENGRAM_BUCKETS))
    model = CosmosT2_Accelerate_LLM(
        vocab_size=cfg.get("vocab_size", len(tokenizer)),
        d_model=cfg.get("d_model", D_MODEL),
        n_layers=cfg.get("n_layers", N_LAYERS),
        n_heads=cfg.get("n_heads", N_HEADS),
        n_kv_heads=cfg.get("n_kv_heads", N_KV_HEADS),
        d_ff=cfg.get("d_ff", D_FF),
        max_len=cfg.get("max_len", MAX_LEN),
        rope_base=cfg.get("rope_base", ROPE_BASE),
        dropout=0.0,
        use_engram=cfg.get("use_engram", USE_ENGRAM),
        engram_every=cfg.get("engram_every", ENGRAM_EVERY),
        engram_bucket_count=engram_buckets,
        engram_dim=cfg.get("engram_dim", ENGRAM_DIM),
        engram_order=cfg.get("engram_order", ENGRAM_ORDER),
        pad_id=tokenizer.pad_token_id or 0,
    )
    state = _strip_wrapper_prefixes(_extract_state_dict(blob))
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print(f"[warn] missing keys: {len(missing)} (e.g. {missing[:3]})")
    if unexpected:
        print(f"[warn] unexpected keys: {len(unexpected)} (e.g. {unexpected[:3]})")
    model.eval()
    return model
def build_prompt_ids(tokenizer, user_text, stage, system_prompt, history=None):
    """Turn a user turn (plus optional history) into prompt token ids.

    For ``stage == "pretrain"`` the raw ``user_text`` is tokenized as-is.
    Otherwise a message list is built: an optional system message, then
    ``history`` as (role, content) pairs, then the new user message. It is
    rendered with the tokenizer's chat template with the generation prompt
    appended. Special tokens are never added by the tokenizer call itself.
    """
    def _encode(text):
        return tokenizer(text, add_special_tokens=False, return_attention_mask=False)["input_ids"]

    if stage == "pretrain":
        return _encode(user_text)

    messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
    messages.extend({"role": role, "content": content} for role, content in (history or []))
    messages.append({"role": "user", "content": user_text})
    rendered = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return _encode(rendered)
# ─────────────────────────────────────────────────────────────
# LOAD ON STARTUP
# ─────────────────────────────────────────────────────────────
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    # No dedicated padding token: reuse EOS so pad_token_id is defined.
    tokenizer.pad_token = tokenizer.eos_token
PAD_ID = tokenizer.pad_token_id
STOP_TOKEN_IDS = _resolve_stop_ids(tokenizer)

# Resolve (and download if necessary) the checkpoint before building the model,
# because load_model reads the file directly from this path.
CKPT_PATH = resolve_checkpoint(stage="finetune")
print(f"Loading model checkpoint: {CKPT_PATH}")
model = load_model(CKPT_PATH, tokenizer).to(device)  # Module.to returns the module itself
model.eval()
n_params = sum(param.numel() for param in model.parameters())
print(f"Model ready: {n_params/1e6:.1f}M params | device={device}")
# ─────────────────────────────────────────────────────────────
# GRADIO STREAMING
# ─────────────────────────────────────────────────────────────
def history_to_role_messages(history):
    """Convert Gradio chat history into a list of ``(role, content)`` pairs.

    Both Gradio history layouts are accepted:
      * "tuples":   ``[[user_msg, assistant_msg], ...]``
      * "messages": ``[{"role": ..., "content": ..., ...}, ...]``
    Entries whose role or content is not a plain string are skipped: ``None``
    for a reply that does not exist yet, or file/component payloads. The chat
    template therefore only ever receives text.
    """
    messages = []
    for entry in history or []:
        if isinstance(entry, dict):
            pairs = [(entry.get("role"), entry.get("content"))]
        else:
            user_msg, assistant_msg = entry
            pairs = [("user", user_msg), ("assistant", assistant_msg)]
        messages.extend(
            (role, content)
            for role, content in pairs
            if isinstance(role, str) and isinstance(content, str)
        )
    return messages
def chat_stream(message, history, system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Stream a model reply to ``message`` as progressively longer strings.

    Generation runs on a background thread that holds MODEL_LOCK, so only one
    request uses the model at a time. It hands sampled token ids to this
    generator through a queue.

    The reply is re-decoded from all ids generated so far rather than token
    by token. Byte-level BPE can split one character (emoji, CJK, ...) across
    several tokens, and decoding those pieces separately produces U+FFFD
    replacement characters.

    If the consumer stops iterating (e.g. the client disconnects), the worker
    aborts at its next token and releases MODEL_LOCK. Errors raised during
    generation are re-raised here instead of silently ending the stream.

    Yields:
        The full reply text decoded so far.
    """
    role_history = history_to_role_messages(history)
    prompt_ids = build_prompt_ids(
        tokenizer=tokenizer,
        user_text=message,
        stage="finetune",
        system_prompt=system_prompt,
        history=role_history,
    )
    idx = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    token_queue = queue.Queue()
    end = object()
    cancelled = threading.Event()

    class _Cancelled(Exception):
        """Raised inside the token callback to abort an abandoned request."""

    def on_token(tid):
        if cancelled.is_set():
            raise _Cancelled
        token_queue.put(tid)

    def worker():
        try:
            with MODEL_LOCK, torch.inference_mode():
                model.generate(
                    idx,
                    max_new_tokens=256,
                    temperature=0.9,
                    top_k=45,
                    max_ctx=MAX_LEN,
                    on_token=on_token,
                )
        except _Cancelled:
            pass
        except Exception as exc:  # thread boundary: forward to the consumer, which re-raises
            token_queue.put(exc)
        finally:
            token_queue.put(end)

    threading.Thread(target=worker, daemon=True).start()

    generated = []
    emitted = ""
    try:
        while True:
            item = token_queue.get()
            if item is end:
                break
            if isinstance(item, Exception):
                raise RuntimeError("Generation failed") from item
            generated.append(item)
            text = tokenizer.decode(generated, skip_special_tokens=True)
            # A trailing U+FFFD usually means a multi-byte character is not
            # complete yet; wait for more tokens. Also skip no-op updates.
            if text.endswith("\ufffd") or text == emitted:
                continue
            emitted = text
            yield text
        # Flush anything held back above (e.g. an incomplete final character).
        if generated:
            text = tokenizer.decode(generated, skip_special_tokens=True)
            if text != emitted:
                yield text
    finally:
        # Runs on normal exit, on error, and when the consumer closes the generator.
        cancelled.set()
def chat(message, history):
    """Gradio ChatInterface callback: stream a reply using DEFAULT_SYSTEM_PROMPT.

    Must remain a generator function (it uses ``yield from``) so Gradio
    treats it as a streaming handler.
    """
    stream = chat_stream(message, history, system_prompt=DEFAULT_SYSTEM_PROMPT)
    yield from stream
# ─────────────────────────────────────────────────────────────
# UI
# ─────────────────────────────────────────────────────────────
demo = gr.ChatInterface(
    fn=chat,
    title="Cosmos-T3 API",
    description="Streaming inference API (backend for your frontend)",
)

if __name__ == "__main__":
    # Listen on all interfaces, port 7860; queue() enables request queuing for streaming.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.queue().launch(**launch_options)