Tags: Text Generation, Transformers, Safetensors, English, Arabic, quasar_long, silx-ai, quasar-preview, quasar, foundation-model, Mixture of Experts, 18b, 2b-active, long-context, bittensor, sn24, decentralized-training, distillation, hybrid-transformer, loop-transformer, safe-nope, drope, conversational, custom_code
Instructions for using zsjTiger/Quasar-Preview with libraries, notebooks, and local apps.
- Libraries
- Transformers
How to use zsjTiger/Quasar-Preview with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="zsjTiger/Quasar-Preview", trust_remote_code=True)
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)
```

```python
# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("zsjTiger/Quasar-Preview", trust_remote_code=True, dtype="auto")
```

- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use zsjTiger/Quasar-Preview with vLLM:
Install from pip and serve model
```shell
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "zsjTiger/Quasar-Preview"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "zsjTiger/Quasar-Preview",
    "messages": [
      {
        "role": "user",
        "content": "What is the capital of France?"
      }
    ]
  }'
```

Use Docker
```shell
docker model run hf.co/zsjTiger/Quasar-Preview
```
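The curl call above can also be issued from Python. A minimal standard-library sketch, assuming a server is already running at the given base URL (`http://localhost:8000` for the vLLM command above; the SGLang server in the next section listens on port 30000). `build_chat_request` and `send_chat_request` are hypothetical helper names, and the response is assumed to follow the OpenAI chat-completions shape (`choices[0].message.content`).

```python
import json
import urllib.request


def build_chat_request(base_url: str, model: str, prompt: str) -> urllib.request.Request:
    """Build (but do not send) an OpenAI-style chat completions POST request."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
    }
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send_chat_request(req: urllib.request.Request) -> str:
    """Send the request and return the first choice's message content."""
    with urllib.request.urlopen(req) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]["content"]


req = build_chat_request("http://localhost:8000", "zsjTiger/Quasar-Preview",
                         "What is the capital of France?")
# print(send_chat_request(req))  # needs a running server
```

Building the request separately from sending it keeps the payload inspectable without a live server.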
- SGLang
How to use zsjTiger/Quasar-Preview with SGLang:
Install from pip and serve model
```shell
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "zsjTiger/Quasar-Preview" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "zsjTiger/Quasar-Preview",
    "messages": [
      {
        "role": "user",
        "content": "What is the capital of France?"
      }
    ]
  }'
```

Use Docker images
```shell
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "zsjTiger/Quasar-Preview" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "zsjTiger/Quasar-Preview",
    "messages": [
      {
        "role": "user",
        "content": "What is the capital of France?"
      }
    ]
  }'
```

- Docker Model Runner
How to use zsjTiger/Quasar-Preview with Docker Model Runner:
```shell
docker model run hf.co/zsjTiger/Quasar-Preview
```
```python
"""
EngramModule: Conditional N-gram Memory for Quasar-RoPE
Implements Engram from DeepSeek-AI (arXiv:2601.07372).

Design constraints:
- No Python loops over T (sequence length) or B (batch).
- N-gram extraction via torch.unfold (single vectorized op).
- Hash computed via vectorized XOR reduction (loop over n=2..3 only, compile-time constant).
- Embedding lookup via batched advanced indexing; no loop over T.
- Optional Triton kernel fuses hash + lookup into a single pass (no intermediate
  [B, T, n] tensor written to HBM).
- Zero output at init: conv.weight=0 and out_proj.weight=0.
"""
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
```
```python
# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────
def _next_prime(n: int) -> int:
    """Smallest prime >= n."""
    def _is_prime(x: int) -> bool:
        if x < 2:
            return False
        if x == 2:
            return True
        if x % 2 == 0:
            return False
        for i in range(3, int(x ** 0.5) + 1, 2):
            if x % i == 0:
                return False
        return True

    n = max(n, 2)
    while not _is_prime(n):
        n += 1
    return n
```
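`EngramModule.__init__` sizes each hash table with `_next_prime`: the `target_slots` budget is split evenly across `num_ngram_orders * num_heads` tables and rounded up to a prime, and `d_slot` is `d_mem // num_tables` with a floor of 16. A standard-library sketch of that arithmetic using the constructor defaults (`num_heads=8`, `ngram_orders=[2, 3]`, `target_slots=5_700_000`) and an assumed `d_mem=512`, which the source does not specify; `engram_table_sizes` is a hypothetical helper.

```python
def is_prime(x: int) -> bool:
    """Trial division, same contract as the nested _is_prime helper."""
    if x < 2:
        return False
    if x % 2 == 0:
        return x == 2
    i = 3
    while i * i <= x:
        if x % i == 0:
            return False
        i += 2
    return True


def next_prime(n: int) -> int:
    """Smallest prime >= n (mirrors _next_prime)."""
    n = max(n, 2)
    while not is_prime(n):
        n += 1
    return n


def engram_table_sizes(target_slots: int, num_heads: int, ngram_orders: list, d_mem: int) -> dict:
    """Mirror of the sizing rules in EngramModule.__init__ (hypothetical helper)."""
    num_tables = len(ngram_orders) * num_heads
    slots_per_table = max(1, target_slots // num_tables)
    M = next_prime(slots_per_table)
    d_slot = max(16, d_mem // max(1, num_tables))
    return {
        "num_tables": num_tables,
        "slots_per_table": slots_per_table,
        "M": M,
        "d_slot": d_slot,
        "total_embed_dim": num_tables * d_slot,
        "table_params": num_tables * M * d_slot,
    }


sizes = engram_table_sizes(5_700_000, num_heads=8, ngram_orders=[2, 3], d_mem=512)
```

Under these assumptions there are 16 tables of just over 356,250 slots each, so `embed_tables` holds roughly 16 × 356,250 × 32 ≈ 182 million values.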
```python
# ─────────────────────────────────────────────────────────────────────────────
# Triton Kernel: Fused N-gram Hash + Embedding Lookup
#
# Grid: (B, T). Each program handles one (batch, position) pair.
# For each of the `num_tables` embedding tables:
#   1. Load the suffix n-gram ending at position t (causal, no future tokens).
#   2. Compute the XOR-multiplicative hash (loop over n ≤ 3, constexpr-unrolled).
#   3. Index into the embedding table and write directly to the output.
# One pass: no intermediate [B, T, n] tensor, no extra round-trip to HBM.
# ─────────────────────────────────────────────────────────────────────────────
if HAS_TRITON:
    @triton.jit
    def _engram_hash_lookup_kernel(
        # [B, T] canonical token IDs (int32 on device)
        canonical_ptr, stride_cb, stride_ct,
        # [B, T, num_tables * d_slot] output (same dtype as embed_tables)
        output_ptr, stride_ob, stride_ot,
        # [num_tables, M, d_slot] embedding tables
        tables_ptr, stride_tn, stride_tm, stride_td,
        # [num_tables] per-table seeds (int64)
        seeds_ptr,
        # [num_ngram_orders] n-gram order values, e.g. [2, 3]
        ngrams_ptr,
        # Scalars (T is constexpr: one compilation per distinct sequence length)
        B, T: tl.constexpr, M, d_slot: tl.constexpr,
        num_tables: tl.constexpr,
        num_ngram_orders: tl.constexpr,  # ≤ 4
        num_heads: tl.constexpr,         # ≤ 16
        MAX_N: tl.constexpr,             # max(ngram_orders), e.g. 3
        BLOCK_D: tl.constexpr,           # power of 2 ≥ d_slot
    ):
        b_idx = tl.program_id(0)
        t_idx = tl.program_id(1)
        d_offs = tl.arange(0, BLOCK_D)
        d_mask = d_offs < d_slot

        # Pre-load the last MAX_N canonical tokens ending at t_idx (causal).
        # Positions before 0 are treated as padding (0). Unrolled to scalars for Triton compatibility.
        # Safety: clamp pos with tl.where so the pointer arithmetic is never negative.
        c0 = tl.full((), 0, dtype=tl.int64)
        c1 = tl.full((), 0, dtype=tl.int64)
        c2 = tl.full((), 0, dtype=tl.int64)
        c3 = tl.full((), 0, dtype=tl.int64)
        if MAX_N >= 1:
            pos_raw = t_idx - (MAX_N - 1 - 0)
            valid = pos_raw >= 0
            pos = tl.where(valid, pos_raw, 0)
            tok = tl.load(
                canonical_ptr + b_idx * stride_cb + pos * stride_ct,
                mask=valid, other=0,
            )
            c0 = tl.where(valid, tok.to(tl.int64), tl.full((), 0, dtype=tl.int64))
        if MAX_N >= 2:
            pos_raw = t_idx - (MAX_N - 1 - 1)
            valid = pos_raw >= 0
            pos = tl.where(valid, pos_raw, 0)
            tok = tl.load(
                canonical_ptr + b_idx * stride_cb + pos * stride_ct,
                mask=valid, other=0,
            )
            c1 = tl.where(valid, tok.to(tl.int64), tl.full((), 0, dtype=tl.int64))
        if MAX_N >= 3:
            pos_raw = t_idx - (MAX_N - 1 - 2)
            valid = pos_raw >= 0
            pos = tl.where(valid, pos_raw, 0)
            tok = tl.load(
                canonical_ptr + b_idx * stride_cb + pos * stride_ct,
                mask=valid, other=0,
            )
            c2 = tl.where(valid, tok.to(tl.int64), tl.full((), 0, dtype=tl.int64))
        if MAX_N >= 4:
            pos_raw = t_idx - (MAX_N - 1 - 3)
            valid = pos_raw >= 0
            pos = tl.where(valid, pos_raw, 0)
            tok = tl.load(
                canonical_ptr + b_idx * stride_cb + pos * stride_ct,
                mask=valid, other=0,
            )
            c3 = tl.where(valid, tok.to(tl.int64), tl.full((), 0, dtype=tl.int64))

        # Iterate over all tables; loop bounds are constexpr, so the compiler fully unrolls them.
        for n_ord in tl.static_range(4):  # ≤ num_ngram_orders
            if n_ord < num_ngram_orders:
                n = tl.load(ngrams_ptr + n_ord).to(tl.int32)
                for k in tl.static_range(16):  # ≤ num_heads
                    if k < num_heads:
                        # Safety: derive table_idx directly from the static loop indices instead of
                        # a mutable counter (avoids register-allocation bugs in compilation).
                        table_idx = n_ord * num_heads + k
                        seed = tl.load(seeds_ptr + table_idx).to(tl.int64)
                        # XOR-multiplicative hash over the suffix n-gram.
                        # Loop over the MAX_N window positions; positions outside the suffix are skipped.
                        h = seed
                        for i in tl.static_range(MAX_N):
                            include = i >= (MAX_N - n)
                            tok = tl.full((), 0, dtype=tl.int64)
                            if i == 0:
                                tok = c0
                            elif i == 1:
                                tok = c1
                            elif i == 2:
                                tok = c2
                            elif i == 3:
                                tok = c3
                            new_h = h * 2654435761 ^ tok
                            h = tl.where(include, new_h, h)
                        # Absolute value via tl.where (works across Triton versions), then reduce modulo M.
                        idx = tl.where(h >= 0, h, -h) % M
                        # Load d_slot values from embed_tables[table_idx, idx]
                        emb_base = table_idx * stride_tn + idx * stride_tm
                        emb = tl.load(
                            tables_ptr + emb_base + d_offs * stride_td,
                            mask=d_mask, other=0.0,
                        )
                        # Write to output[b, t, table_idx*d_slot : (table_idx+1)*d_slot]
                        out_base = b_idx * stride_ob + t_idx * stride_ot + table_idx * d_slot
                        tl.store(
                            output_ptr + out_base + d_offs,
                            emb.to(output_ptr.dtype.element_ty),
                            mask=d_mask,
                        )
```
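The kernel hashes a fixed window of `MAX_N` tokens and skips the leading slots with `include = i >= (MAX_N - n)`, while the PyTorch path (`F.pad` + `unfold` + `_hash_ngrams`) hashes exactly `n` tokens. A standard-library sketch, assuming int64 overflow wraps in two's complement on both paths and ignoring the single edge case `h == -2**63` (where `abs` itself overflows), checks that both formulations pick the same slot. All function names here are hypothetical.

```python
MULT = 2654435761          # multiplier used by the source hash
MASK64 = (1 << 64) - 1


def wrap_int64(x: int) -> int:
    """Emulate signed 64-bit wraparound, as int64 tensor arithmetic does."""
    x &= MASK64
    return x - (1 << 64) if x >= (1 << 63) else x


def hash_unfold(tokens, t, n, seed, M):
    """PyTorch-path view: left-pad with n-1 zeros, hash the n tokens ending at t."""
    h = seed
    for i in range(n):
        pos = t - (n - 1) + i
        tok = tokens[pos] if pos >= 0 else 0
        h = wrap_int64(h * MULT) ^ tok
    return abs(h) % M


def hash_window(tokens, t, n, seed, M, max_n):
    """Kernel view: load max_n tokens ending at t, skip the first max_n - n slots."""
    window = [tokens[p] if p >= 0 else 0 for p in range(t - (max_n - 1), t + 1)]
    h = seed
    for i in range(max_n):
        if i >= max_n - n:  # same predicate as `include` in the kernel
            h = wrap_int64(h * MULT) ^ window[i]
    return abs(h) % M


tokens = [17, 4, 99, 4, 17, 4, 99]
M = 101
MAX_N = 3
slot = hash_unfold(tokens, 6, 3, seed=12345, M=M)
```

Because the hash depends only on the n tokens and the seed, a repeated n-gram (here `17, 4, 99` at positions 0 to 2 and 4 to 6) always lands in the same slot of a given table.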
```python
# ─────────────────────────────────────────────────────────────────────────────
# Custom Autograd Function for Fused Triton Training Lookup
# ─────────────────────────────────────────────────────────────────────────────
class FusedEngramLookupFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        canonical,
        embed_tables,
        seeds,
        ngram_orders_buf,
        M,
        d_slot,
        num_tables,
        num_ngram_orders,
        num_heads,
        ngram_orders,
    ):
        ctx.save_for_backward(canonical, seeds, ngram_orders_buf)
        ctx.M = M
        ctx.d_slot = d_slot
        ctx.num_tables = num_tables
        ctx.num_ngram_orders = num_ngram_orders
        ctx.num_heads = num_heads
        ctx.ngram_orders = ngram_orders
        ctx.embed_tables_shape = embed_tables.shape
        ctx.embed_tables_dtype = embed_tables.dtype

        B, T = canonical.shape
        BLOCK_D = triton.next_power_of_2(d_slot)
        out = torch.empty(
            B, T, num_tables * d_slot,
            device=canonical.device, dtype=embed_tables.dtype,
        )
        tables = embed_tables.contiguous()
        # Materialize the int32 copy first so the strides passed to the kernel
        # describe the tensor the kernel actually reads.
        canonical_i32 = canonical.int().contiguous()
        _engram_hash_lookup_kernel[(B, T)](
            canonical_i32, canonical_i32.stride(0), canonical_i32.stride(1),
            out, out.stride(0), out.stride(1),
            tables, tables.stride(0), tables.stride(1), tables.stride(2),
            seeds.contiguous(),
            ngram_orders_buf.contiguous(),
            B, T, M, d_slot,
            num_tables, num_ngram_orders, num_heads,
            MAX_N=max(ngram_orders),
            BLOCK_D=BLOCK_D,
        )
        return out

    @staticmethod
    def backward(ctx, grad_output):
        canonical, seeds, ngram_orders_buf = ctx.saved_tensors
        B, T = canonical.shape
        device = canonical.device

        # 1. Recompute the hash indices for every table (same hash as the kernel, vectorized over B and T).
        all_hashes = torch.empty(ctx.num_tables, B * T, dtype=torch.long, device=device)
        table_idx = 0
        for n in ctx.ngram_orders:
            padded = F.pad(canonical, (n - 1, 0), value=0)
            ngrams = padded.unfold(dimension=1, size=n, step=1)
            for k in range(ctx.num_heads):
                seed = int(seeds[table_idx].item())
                h = torch.full(ngrams.shape[:2], seed, dtype=torch.long, device=device)
                for i in range(ngrams.shape[-1]):
                    h = h * 2654435761 ^ ngrams[..., i]
                h = h.abs() % ctx.M
                all_hashes[table_idx] = h.view(B * T)
                table_idx += 1

        # 2. Reshape grad_output from [B, T, num_tables * d_slot] to [num_tables, B * T, d_slot].
        grad_out_reshaped = (
            grad_output.reshape(B, T, ctx.num_tables, ctx.d_slot)
            .permute(2, 0, 1, 3)
            .reshape(ctx.num_tables, B * T, ctx.d_slot)
        )

        # 3. Scatter-add into grad_embed_tables with index_put_(accumulate=True), so that
        #    positions hashing to the same slot sum their gradients instead of overwriting them.
        grad_embed_tables = torch.zeros(
            ctx.embed_tables_shape, dtype=ctx.embed_tables_dtype, device=device,
        )
        tbl_idx = torch.arange(ctx.num_tables, device=device).unsqueeze(1).expand(ctx.num_tables, B * T)
        grad_embed_tables.index_put_(
            (tbl_idx, all_hashes), grad_out_reshaped.to(ctx.embed_tables_dtype), accumulate=True,
        )

        # One gradient per forward argument (None for non-differentiable / constant arguments).
        return None, grad_embed_tables, None, None, None, None, None, None, None, None
```
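Different positions (and colliding n-grams) can hash to the same slot, so the backward pass has to sum their gradients rather than keep one of them; that is why it uses `index_put_(..., accumulate=True)`. A standard-library sketch of the difference for a single table, with made-up gradient rows and hash indices:

```python
def scatter_add_rows(num_slots, d_slot, indices, grad_rows):
    """grad_table[idx] += grad_row for every (idx, grad_row) pair; duplicates accumulate."""
    grad_table = [[0.0] * d_slot for _ in range(num_slots)]
    for idx, row in zip(indices, grad_rows):
        for j, g in enumerate(row):
            grad_table[idx][j] += g
    return grad_table


def scatter_overwrite_rows(num_slots, d_slot, indices, grad_rows):
    """Plain assignment: in this loop the last write to a slot wins, dropping earlier gradients."""
    grad_table = [[0.0] * d_slot for _ in range(num_slots)]
    for idx, row in zip(indices, grad_rows):
        grad_table[idx] = list(row)
    return grad_table


# Four positions (B*T = 4); positions 0 and 2 hash to slot 3.
indices = [3, 0, 3, 5]
grad_rows = [[1.0, 2.0], [0.5, 0.5], [10.0, 20.0], [-1.0, 1.0]]
acc = scatter_add_rows(8, 2, indices, grad_rows)
over = scatter_overwrite_rows(8, 2, indices, grad_rows)
```

The accumulated slot 3 receives the sum of both contributions, which matches the gradient of a gather that reads slot 3 twice.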
```python
# ─────────────────────────────────────────────────────────────────────────────
# Lightweight RMSNorm (standalone; avoids circular import from quasar_rope)
# ─────────────────────────────────────────────────────────────────────────────
class _RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return (self.weight * x).to(dtype)
```
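`_RMSNorm` rescales each vector by the inverse of its root-mean-square (no mean subtraction, unlike LayerNorm) and then applies a learned per-channel weight. A standard-library sketch of the same formula for one vector, with unit weights:

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """y_i = weight_i * x_i / sqrt(mean(x^2) + eps), as in _RMSNorm.forward."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v * inv_rms for w, v in zip(weight, x)]


x = [3.0, -4.0, 0.0, 12.0]          # mean(x^2) = 169 / 4, so RMS(x) = 6.5
y = rms_norm(x, weight=[1.0] * len(x))
```

With unit weights the output has RMS 1, and scaling the input by a constant leaves the output unchanged (up to `eps`).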
```python
# ─────────────────────────────────────────────────────────────────────────────
# EngramModule
# ─────────────────────────────────────────────────────────────────────────────
class EngramModule(nn.Module):
    """
    Engram Conditional Memory Module (DeepSeek-AI, arXiv:2601.07372).

    Serves static N-gram patterns from O(1) hash-table lookups gated into the
    hidden state, instead of relying on expensive attention layers for them.

    All operations are fully vectorized, with no Python loops over T or B:
      • N-gram extraction: torch.unfold (single op)
      • Hash computation: vectorized XOR accumulation (Python loops only over orders, heads and n ≤ 3)
      • Embedding lookup: batched advanced indexing (single gather)
      • Conv: nn.Conv1d with causal pad + slice
    """
    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        d_mem: int,
        num_heads: int = 8,
        ngram_orders: list = None,
        target_slots: int = 5_700_000,
        n_layers: int = 24,
    ):
        super().__init__()
        if ngram_orders is None:
            ngram_orders = [2, 3]
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_mem = d_mem
        self.num_heads = num_heads
        self.ngram_orders = list(ngram_orders)
        self.num_ngram_orders = len(ngram_orders)
        self.num_tables = self.num_ngram_orders * num_heads
        self.n_layers = n_layers

        # ── A. Tokenizer Compression Buffer ─────────────────────────────────
        # Surjective map P: V → V' (~23% compression).
        # Deterministic multiplicative hash; no tokenizer object is needed at
        # construction time (avoids FSDP serialization problems).
        compressed_size = max(1, int(vocab_size * 0.77))
        self.compressed_vocab_size = compressed_size
        token_map = (
            torch.arange(vocab_size, dtype=torch.long) * 2654435761
        ) % compressed_size
        self.register_buffer('token_map', token_map)

        # ── B. Embedding Tables ─────────────────────────────────────────────
        # All num_tables tables share the same prime size M, enabling vectorized indexing.
        slots_per_table = max(1, target_slots // self.num_tables)
        self.M = _next_prime(slots_per_table)
        self.d_slot = max(16, d_mem // max(1, self.num_tables))
        self.total_embed_dim = self.num_tables * self.d_slot
        # Single parameter tensor; enables a batched advanced-index gather.
        self.embed_tables = nn.Parameter(
            torch.empty(self.num_tables, self.M, self.d_slot)
        )
        # Per-table hashing seeds (non-trainable; overwritten deterministically in _init_weights).
        seeds = torch.randint(1, 2 ** 31 - 1, (self.num_tables,), dtype=torch.long)
        self.register_buffer('seeds', seeds)
        # N-gram order list as a buffer for the Triton kernel.
        self.register_buffer(
            'ngram_orders_buf',
            torch.tensor(self.ngram_orders, dtype=torch.long),
        )

        # ── C. Projection total_embed_dim → d_mem ───────────────────────────
        self.embed_proj = nn.Linear(self.total_embed_dim, d_mem, bias=False)

        # ── D. Context-aware gating ─────────────────────────────────────────
        self.q_proj = nn.Linear(d_model, d_mem, bias=False)
        self.W_K = nn.Linear(d_mem, d_mem, bias=False)
        self.W_V = nn.Linear(d_mem, d_mem, bias=False)

        # ── E. Causal depthwise Conv1d ──────────────────────────────────────
        # kernel=4, dilation=3 → causal receptive field = 1 + (4-1)*3 = 10
        self.kernel_size = 4
        self.dilation = 3
        self.conv_norm = _RMSNorm(d_mem)
        self.conv = nn.Conv1d(
            d_mem, d_mem,
            kernel_size=self.kernel_size,
            dilation=self.dilation,
            groups=d_mem,  # depthwise
            bias=False,
        )

        # ── F. Output projection d_mem → d_model ────────────────────────────
        self.out_proj = nn.Linear(d_mem, d_model, bias=False)

        # The Triton path is eligible only when the compile-time bounds fit the kernel.
        self._triton_ok = (
            HAS_TRITON
            and self.num_ngram_orders <= 4
            and num_heads <= 16
            and max(ngram_orders) <= 3
        )
        self.triton_training = True
        self._init_weights()
```
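The compression buffer maps token ID i to (i × 2654435761) mod V', with V' = int(0.77 × V). This is surjective whenever the multiplier is coprime to V': the first V' token IDs then already form a permutation of 0..V'-1, and any two token IDs collide exactly when they are congruent modulo V'. A standard-library check on an assumed toy vocabulary of 1,000 tokens (the real vocabulary size is not stated here); `token_map` below is a pure-Python mirror, not the module's buffer.

```python
import math
from collections import Counter

MULT = 2654435761


def token_map(vocab_size: int, ratio: float = 0.77) -> list:
    """Pure-Python mirror of the token_map buffer: (i * MULT) % compressed_size."""
    compressed_size = max(1, int(vocab_size * ratio))
    return [(i * MULT) % compressed_size for i in range(vocab_size)]


V = 1000
tm = token_map(V)
compressed = max(1, int(V * 0.77))
bucket_sizes = Counter(tm)
```

So each compressed ID is shared by at most ceil(V / V') original IDs, which is 2 for a 0.77 ratio. The map is purely arithmetic: which IDs end up merged depends on their numeric values, not on what the tokens mean.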
```python
    # ── Initialization ──────────────────────────────────────────────────────
    def _init_weights(self):
        # 1. Deterministic buffer re-population (buffers created on the meta device
        #    hold no data, so they are recomputed here rather than trusted).
        if hasattr(self, "token_map") and self.token_map is not None:
            dev = "cpu" if self.token_map.device.type == "meta" else self.token_map.device
            t_map = (torch.arange(self.vocab_size, dtype=torch.long, device=dev) * 2654435761) % self.compressed_vocab_size
            self.token_map.data.copy_(t_map)
        if hasattr(self, "seeds") and self.seeds is not None:
            # Deterministic hash seeds across all ranks. The generator is a CPU generator,
            # so draw on CPU and let copy_ move the values to the buffer's device.
            g = torch.Generator().manual_seed(42)
            s_t = torch.randint(1, 2 ** 31 - 1, (self.num_tables,), dtype=torch.long, generator=g)
            self.seeds.data.copy_(s_t)
        if hasattr(self, "ngram_orders_buf") and self.ngram_orders_buf is not None:
            dev = "cpu" if self.ngram_orders_buf.device.type == "meta" else self.ngram_orders_buf.device
            ord_buf = torch.tensor(self.ngram_orders, dtype=torch.long, device=dev)
            self.ngram_orders_buf.data.copy_(ord_buf)

        trinity_std = 0.5 / math.sqrt(self.d_model)
        scale_factor = 1.0 / math.sqrt(2 * self.n_layers)  # computed but currently unused

        # 2. Output projection: zero init, so the module's output is exactly zero at step 0.
        nn.init.zeros_(self.out_proj.weight)
        # Gating projections: standard Trinity init
        nn.init.normal_(self.q_proj.weight, std=trinity_std)
        nn.init.normal_(self.W_K.weight, std=trinity_std)
        nn.init.normal_(self.W_V.weight, std=trinity_std)
        # embed_proj: standard Trinity init
        nn.init.normal_(self.embed_proj.weight, std=trinity_std)
        # Conv: zero init, so silu(conv(.)) = 0 and Y = gated at step 0 (identity pass-through)
        nn.init.zeros_(self.conv.weight)
        # Embedding tables: small normal init (paper standard)
        nn.init.normal_(self.embed_tables, std=0.01)
        # Conv norm: unit scale
        if hasattr(self.conv_norm, "weight") and self.conv_norm.weight is not None:
            nn.init.ones_(self.conv_norm.weight)

        # Guard: replace any non-finite initial values with zeros.
        for name, p in [("out_proj", self.out_proj.weight), ("q_proj", self.q_proj.weight),
                        ("W_K", self.W_K.weight), ("W_V", self.W_V.weight),
                        ("embed_proj", self.embed_proj.weight), ("conv", self.conv.weight),
                        ("embed_tables", self.embed_tables), ("conv_norm", self.conv_norm.weight)]:
            if p.device.type != "meta":
                if not torch.isfinite(p).all():
                    print(f"[engram-init-warn] Parameter {name} contains non-finite values! Re-initializing with zeros.", flush=True)
                    nn.init.zeros_(p)
```
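Zero-initializing `out_proj` makes the module's output exactly zero at step 0 without freezing it: for a bias-free linear layer out = W·y, the gradient with respect to W is the outer product of the upstream gradient and y, which is non-zero whenever y is. The gradient flowing back into y (and so into the gate, conv and tables) is Wᵀ·upstream, which is zero until W moves. A standard-library illustration with made-up numbers for one token:

```python
def matvec(W, y):
    """out[i] = sum_j W[i][j] * y[j] (a bias-free nn.Linear applied to one token)."""
    return [sum(w * v for w, v in zip(row, y)) for row in W]


def grad_wrt_weight(upstream, y):
    """dL/dW[i][j] = upstream[i] * y[j] for out = W @ y."""
    return [[g * v for v in y] for g in upstream]


def grad_wrt_input(W, upstream):
    """dL/dy[j] = sum_i W[i][j] * upstream[i]; zero while W is zero."""
    return [sum(W[i][j] * upstream[i] for i in range(len(W))) for j in range(len(W[0]))]


d_mem, d_model = 3, 2
W_out = [[0.0] * d_mem for _ in range(d_model)]   # zero-initialized out_proj
Y = [0.5, -1.0, 2.0]                               # stand-in for the gated + conv branch
upstream = [1.0, -0.5]                             # stand-in for dL/d(engram_out)

out = matvec(W_out, Y)
dW = grad_wrt_weight(upstream, Y)
dY = grad_wrt_input(W_out, upstream)
```

At step 0 only `out_proj` receives a gradient; the rest of the module starts learning once that projection becomes non-zero.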
```python
    # ── Core helpers ────────────────────────────────────────────────────────
    @staticmethod
    def _hash_ngrams(ngrams: torch.Tensor, table_size: int, seed: int) -> torch.Tensor:
        """
        Vectorized XOR-multiplicative hash.
        ngrams: [B, T, n], n ∈ {2, 3} (a small constant).
        Returns: [B, T] indices into the embedding table.
        No loop over T or B; only loops over n (≤ 3).
        """
        h = torch.full(ngrams.shape[:2], seed, dtype=torch.long, device=ngrams.device)
        for i in range(ngrams.shape[-1]):  # n iterations, NOT T
            h = h * 2654435761 ^ ngrams[..., i]
        return h.abs() % table_size

    def _lookup_pytorch(self, canonical: torch.Tensor) -> torch.Tensor:
        """
        Pure-PyTorch path: fully vectorized, no T/B loops.
        Steps:
          1. For each n-gram order, extract suffix n-grams via unfold → [B, T, n]
          2. Hash all (n, k) pairs → [num_tables, B*T]
          3. Batched advanced-index gather from embed_tables → [num_tables, B*T, d_slot]
          4. Reshape to [B, T, total_embed_dim]
        """
        B, T = canonical.shape
        device = canonical.device
        # Steps 1+2: collect hashes for all tables; loops over num_tables
        # (16 with the defaults of 2 orders × 8 heads), never over T.
        all_hashes = torch.empty(self.num_tables, B * T, dtype=torch.long, device=device)
        table_idx = 0
        seeds_cpu = self.seeds.cpu().tolist() if hasattr(self, "seeds") and self.seeds is not None else []
        for n in self.ngram_orders:  # one iteration per n-gram order
            # Vectorized n-gram extraction: unfold over T → [B, T, n]
            padded = F.pad(canonical, (n - 1, 0), value=0)       # [B, T+n-1]
            ngrams = padded.unfold(dimension=1, size=n, step=1)  # [B, T, n]
            for k in range(self.num_heads):  # num_heads iterations
                seed = seeds_cpu[table_idx] if table_idx < len(seeds_cpu) else 42
                h = self._hash_ngrams(ngrams, self.M, seed)  # [B, T]
                all_hashes[table_idx] = h.view(B * T)
                table_idx += 1

        # Step 3: single batched gather, no loop over T
        #   embed_tables: [num_tables, M, d_slot]
        #   all_hashes:   [num_tables, B*T]
        # Expand the table index for advanced indexing.
        tbl_idx = torch.arange(self.num_tables, device=device).unsqueeze(1).expand(
            self.num_tables, B * T
        )  # [num_tables, B*T]
        embeddings = self.embed_tables[tbl_idx, all_hashes]  # [num_tables, B*T, d_slot]

        # Step 4: reshape to [B, T, total_embed_dim]
        embeddings = embeddings.permute(1, 0, 2)  # [B*T, num_tables, d_slot]
        return embeddings.reshape(B, T, self.total_embed_dim)

    def _lookup_triton(self, canonical: torch.Tensor) -> torch.Tensor:
        """
        Triton path: fused hash + lookup in a single pass.
        Wrapped in FusedEngramLookupFunction so training gets exact gradients for
        embed_tables (the backward recomputes the hashes in PyTorch).
        """
        return FusedEngramLookupFunction.apply(
            canonical,
            self.embed_tables,
            self.seeds,
            self.ngram_orders_buf,
            self.M,
            self.d_slot,
            self.num_tables,
            self.num_ngram_orders,
            self.num_heads,
            self.ngram_orders,
        )
```
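Both lookup paths produce the same layout: table k's `d_slot` values for position (b, t) occupy columns `k*d_slot : (k+1)*d_slot` of `raw_embed[b, t]`. The PyTorch path gets there by gathering to [num_tables, B*T, d_slot], permuting to [B*T, num_tables, d_slot] and reshaping, while the kernel writes at offset `table_idx * d_slot` directly. A standard-library sketch with tiny, made-up tables and hash indices checks that the two index computations agree:

```python
from itertools import product

num_tables, M, d_slot, B, T = 3, 5, 2, 2, 4

# Made-up tables[k][m][j] and hash indices hashes[k][b*T + t]
tables = [[[100 * k + 10 * m + j for j in range(d_slot)] for m in range(M)] for k in range(num_tables)]
hashes = [[(7 * k + 3 * p) % M for p in range(B * T)] for k in range(num_tables)]

# PyTorch-path view: gather → [num_tables, B*T, d_slot]
gathered = [[tables[k][hashes[k][p]] for p in range(B * T)] for k in range(num_tables)]
# permute(1, 0, 2) → [B*T, num_tables, d_slot], then reshape → [B, T, num_tables*d_slot]
flat = [[v for k in range(num_tables) for v in gathered[k][p]] for p in range(B * T)]
torch_layout = [[flat[b * T + t] for t in range(T)] for b in range(B)]

# Kernel view: out[b][t][k*d_slot + j] = tables[k][hash][j]
kernel_layout = [[[0] * (num_tables * d_slot) for _ in range(T)] for _ in range(B)]
for b, t, k in product(range(B), range(T), range(num_tables)):
    row = tables[k][hashes[k][b * T + t]]
    for j in range(d_slot):
        kernel_layout[b][t][k * d_slot + j] = row[j]
```

Because the layouts match, `embed_proj` sees the same feature order whichever path `forward` selects.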
```python
    # ── Forward ─────────────────────────────────────────────────────────────
    def forward(
        self,
        input_ids: torch.Tensor,      # [B, T] raw token IDs
        hidden_states: torch.Tensor,  # [B, T, d_model]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            engram_out : [B, T, d_model], to be added to the residual stream
            alpha_mean : [B, T] per-token gate values (not reduced to a scalar here),
                         used by LatentMemory for suppression
        """
        B, T = input_ids.shape
        orig_dtype = hidden_states.dtype

        # ── A. Token compression ────────────────────────────────────────────
        # Single gather op, no loop
        canonical = self.token_map[input_ids.clamp(0, self.vocab_size - 1)]  # [B, T]

        # ── B+C. Hash → lookup → project to d_mem ───────────────────────────
        use_triton_lookup = self._triton_ok and canonical.is_cuda and (
            not self.training or bool(getattr(self, "triton_training", False))
        )
        if use_triton_lookup:
            raw_embed = self._lookup_triton(canonical)   # [B, T, total_embed_dim]
        else:
            raw_embed = self._lookup_pytorch(canonical)  # [B, T, total_embed_dim]
        raw_embed = raw_embed.to(orig_dtype)

        debug_engram = bool(int(os.environ.get("ENGRAM_DEBUG", "0")))
        if debug_engram:
            print(f"[engram-debug] embed_tables: finite={torch.isfinite(self.embed_tables).all().item()} min={self.embed_tables.float().min().item():.6g} max={self.embed_tables.float().max().item():.6g}", flush=True)
            print(f"[engram-debug] hidden_states: finite={torch.isfinite(hidden_states).all().item()} min={hidden_states.float().min().item():.6g} max={hidden_states.float().max().item():.6g}", flush=True)
            print(f"[engram-debug] raw_embed: finite={torch.isfinite(raw_embed).all().item()} min={raw_embed.float().min().item():.6g} max={raw_embed.float().max().item():.6g}", flush=True)
        raw_embed = torch.nan_to_num(raw_embed, nan=0.0, posinf=0.0, neginf=0.0).clamp_(-10.0, 10.0)
        e_t = self.embed_proj(raw_embed)  # [B, T, d_mem]
        if debug_engram:
            print(f"[engram-debug] e_t: finite={torch.isfinite(e_t).all().item()} min={e_t.float().min().item():.6g} max={e_t.float().max().item():.6g}", flush=True)
        e_t = torch.nan_to_num(e_t, nan=0.0, posinf=0.0, neginf=0.0).clamp_(-100.0, 100.0)

        # ── D. Context-aware gating ─────────────────────────────────────────
        h_proj = self.q_proj(hidden_states)  # [B, T, d_mem]
        k_t = self.W_K(e_t)                  # [B, T, d_mem]
        v_t = self.W_V(e_t)                  # [B, T, d_mem]
        if debug_engram:
            print(f"[engram-debug] h_proj: finite={torch.isfinite(h_proj).all().item()} min={h_proj.float().min().item():.6g} max={h_proj.float().max().item():.6g}", flush=True)
            print(f"[engram-debug] k_t: finite={torch.isfinite(k_t).all().item()} min={k_t.float().min().item():.6g} max={k_t.float().max().item():.6g}", flush=True)
            print(f"[engram-debug] v_t: finite={torch.isfinite(v_t).all().item()} min={v_t.float().min().item():.6g} max={v_t.float().max().item():.6g}", flush=True)
        h_proj = torch.nan_to_num(h_proj, nan=0.0, posinf=0.0, neginf=0.0).clamp_(-100.0, 100.0)
        k_t = torch.nan_to_num(k_t, nan=0.0, posinf=0.0, neginf=0.0).clamp_(-100.0, 100.0)
        v_t = torch.nan_to_num(v_t, nan=0.0, posinf=0.0, neginf=0.0).clamp_(-100.0, 100.0)

        # L2-normalize for stability (matches Quasar key normalization)
        q_norm = F.normalize(h_proj.float(), dim=-1, eps=1e-6).to(orig_dtype)
        k_norm = F.normalize(k_t.float(), dim=-1, eps=1e-6).to(orig_dtype)
        # One scalar gate per token position
        alpha_logits = (q_norm * k_norm).sum(-1, keepdim=True).float() / math.sqrt(self.d_mem)
        alpha_t = torch.sigmoid(alpha_logits.clamp_(-30.0, 30.0)).to(orig_dtype)  # [B, T, 1]
        if debug_engram:
            print(f"[engram-debug] alpha_t: finite={torch.isfinite(alpha_t).all().item()} min={alpha_t.float().min().item():.6g} max={alpha_t.float().max().item():.6g}", flush=True)
        gated = alpha_t * v_t  # [B, T, d_mem]
        gated = torch.nan_to_num(gated, nan=0.0, posinf=0.0, neginf=0.0).clamp_(-100.0, 100.0)
```
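Because `q_norm` and `k_norm` are unit vectors, their dot product is a cosine similarity in [-1, 1]; after dividing by sqrt(d_mem) the logit lies in [-1/sqrt(d_mem), 1/sqrt(d_mem)], so the sigmoid gate stays close to 0.5 and the ±30 clamp does not bind. A standard-library sketch of the gate for one token, with an assumed `d_mem` of 512 (not stated in the source); `engram_gate` is a hypothetical helper mirroring the lines above.

```python
import math


def l2_normalize(v, eps=1e-6):
    """Mirror of F.normalize(v, dim=-1, eps=eps): v / max(||v||, eps)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]


def engram_gate(h_proj, k_t, d_mem):
    """alpha = sigmoid(clamp(<q_hat, k_hat> / sqrt(d_mem), -30, 30)) for one token."""
    q = l2_normalize(h_proj)
    k = l2_normalize(k_t)
    logit = sum(a * b for a, b in zip(q, k)) / math.sqrt(d_mem)
    logit = max(-30.0, min(30.0, logit))
    return 1.0 / (1.0 + math.exp(-logit))


d_mem = 512
aligned = [1.0] * d_mem
opposite = [-1.0] * d_mem
alpha_hi = engram_gate(aligned, aligned, d_mem)    # cosine = +1, the largest possible gate
alpha_lo = engram_gate(aligned, opposite, d_mem)   # cosine = -1, the smallest possible gate
bound = 1.0 / (1.0 + math.exp(-1.0 / math.sqrt(d_mem)))
```

Under this assumption every gate value falls between about 0.489 and 0.511, whatever the hidden state.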
```python
        # ── E. Causal depthwise conv ────────────────────────────────────────
        # Fully vectorized: F.pad + Conv1d + slice, no loop over T
        causal_pad = (self.kernel_size - 1) * self.dilation
        g_norm = self.conv_norm(gated)  # [B, T, d_mem]
        if debug_engram:
            print(f"[engram-debug] gated: finite={torch.isfinite(gated).all().item()} min={gated.float().min().item():.6g} max={gated.float().max().item():.6g}", flush=True)
            print(f"[engram-debug] conv_norm.weight: finite={torch.isfinite(self.conv_norm.weight).all().item()} min={self.conv_norm.weight.float().min().item():.6g} max={self.conv_norm.weight.float().max().item():.6g}", flush=True)
            print(f"[engram-debug] g_norm: finite={torch.isfinite(g_norm).all().item()} min={g_norm.float().min().item():.6g} max={g_norm.float().max().item():.6g}", flush=True)
        g_norm = torch.nan_to_num(g_norm, nan=0.0, posinf=0.0, neginf=0.0).clamp_(-100.0, 100.0)
        g_t = g_norm.transpose(1, 2)       # [B, d_mem, T]
        g_t = F.pad(g_t, (causal_pad, 0))  # [B, d_mem, T+pad]
        g_t = self.conv(g_t)[..., :T]      # [B, d_mem, T]
        g_t = F.silu(g_t).transpose(1, 2)  # [B, T, d_mem]
        Y = g_t + gated  # residual
        if debug_engram:
            print(f"[engram-debug] Y: finite={torch.isfinite(Y).all().item()} min={Y.float().min().item():.6g} max={Y.float().max().item():.6g}", flush=True)
        Y = torch.nan_to_num(Y, nan=0.0, posinf=0.0, neginf=0.0).clamp_(-100.0, 100.0)

        # ── F. Output projection ────────────────────────────────────────────
        engram_out = self.out_proj(Y)  # [B, T, d_model]
        if debug_engram:
            print(f"[engram-debug] engram_out: finite={torch.isfinite(engram_out).all().item()} min={engram_out.float().min().item():.6g} max={engram_out.float().max().item():.6g}", flush=True)
        engram_out = torch.nan_to_num(engram_out, nan=0.0, posinf=0.0, neginf=0.0).clamp_(-100.0, 100.0)

        # alpha_mean: per-token gate activity, used by LatentMemory for suppression
        alpha_mean = alpha_t.squeeze(-1)  # [B, T]
        return engram_out, alpha_mean
```
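The conv step left-pads by (kernel_size - 1) × dilation = 9 positions and then applies an unpadded dilated convolution, so output t sees only inputs t, t-3, t-6 and t-9 (receptive field 10, matching the constructor comment). With this padding the convolution already returns exactly T outputs, so the trailing `[..., :T]` slice is a safeguard rather than a trim. A single-channel standard-library sketch (depthwise means each channel is handled like this independently), with made-up weights and an impulse input:

```python
def causal_dilated_conv1d(x, weight, dilation):
    """Single-channel mirror of F.pad(x, (pad, 0)) followed by an unpadded Conv1d.

    Conv1d computes a cross-correlation: out[t] = sum_k weight[k] * x_pad[t + k * dilation].
    """
    k_size = len(weight)
    pad = (k_size - 1) * dilation
    x_pad = [0.0] * pad + list(x)
    out_len = len(x_pad) - (k_size - 1) * dilation  # equals len(x)
    return [
        sum(weight[k] * x_pad[t + k * dilation] for k in range(k_size))
        for t in range(out_len)
    ]


T = 16
impulse = [0.0] * T
impulse[2] = 1.0                # single non-zero input at position 2
w = [1.0, 2.0, 3.0, 4.0]        # kernel_size=4, made-up weights
y = causal_dilated_conv1d(impulse, w, dilation=3)
```

The impulse at position 2 reaches outputs 2, 5, 8 and 11 only: nothing leaks to earlier positions, and its influence ends 9 steps later.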