Spaces:
Running on Zero
Running on Zero
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class SpatialGridMemory(nn.Module):
    """Compress context tokens into a fixed number of memory tokens.

    The context tokens are averaged over frames, pooled along the flattened
    token axis down to ``grid_size * grid_size`` cells, and those cells are
    mixed into ``num_tokens`` memory tokens through a learned matrix that is
    softmax-normalised over the grid axis.
    """

    def __init__(self, dim: int, grid_size: int = 8, num_tokens: int = 64):
        """Create the learned grid-to-token mixing matrix.

        Args:
            dim: Feature size expected on the last axis of the input tokens.
            grid_size: Side length of the pooling grid; the grid has
                ``grid_size ** 2`` cells.
            num_tokens: Number of memory tokens produced per batch element.
        """
        super().__init__()
        self.dim = int(dim)
        self.grid_size = int(grid_size)
        self.num_tokens = int(num_tokens)
        g2 = self.grid_size * self.grid_size
        # Keep key name aligned with ckpt loading in loop_utils.py (spatial_to_tokens).
        self.spatial_to_tokens = nn.Parameter(torch.zeros(g2, self.num_tokens))
        nn.init.normal_(self.spatial_to_tokens, std=0.02)

    def mix(self):
        """Return the mixing parameter under its old name.

        Backward compatibility for code that referenced the old attribute name.
        """
        # NOTE(review): this is a plain method, so ``module.mix`` yields a bound
        # method and callers must write ``module.mix()``. If the old code read it
        # as an attribute, this probably wants ``@property`` -- confirm callers.
        return self.spatial_to_tokens

    def forward(self, x_context: torch.Tensor, num_context_frames: int, h: int, w: int):
        """Build memory tokens from context tokens.

        Args:
            x_context: Context tokens, (B, K*H*W, D).
            num_context_frames: Number of context frames K (values below 1 are
                treated as 1).
            h: Token-grid height H of one frame.
            w: Token-grid width W of one frame.

        Returns:
            Memory tokens of shape (B, num_tokens, D). If ``x_context`` is
            ``None`` or not 3-D it is returned unchanged.

        Raises:
            ValueError: If the last axis of ``x_context`` differs from ``dim``.
        """
        # x_context: (B, K*H*W, D)
        if x_context is None or x_context.ndim != 3:
            return x_context
        b, n, d = x_context.shape
        if d != self.dim:
            raise ValueError(f"SpatialGridMemory dim mismatch: x={d} module={self.dim}")
        k = max(int(num_context_frames), 1)
        spatial = int(h) * int(w)
        if n != k * spatial:
            # Best effort fallback: treat x as a flat token map and pool directly.
            x_mean = x_context
        else:
            # Average the K frames position-by-position.
            x_mean = x_context.reshape(b, k, spatial, d).mean(dim=1)  # (B, S, D)
        g2 = self.grid_size * self.grid_size
        # The flattened token axis is pooled as a 1-D sequence into G2 cells.
        pooled = F.adaptive_avg_pool1d(x_mean.transpose(1, 2), g2).transpose(1, 2)  # (B, G2, D)
        # Each memory token is a convex combination of grid cells (columns sum to 1).
        weights = torch.softmax(self.spatial_to_tokens, dim=0)  # (G2, M)
        mem = torch.einsum("bgd,gm->bmd", pooled, weights)  # (B, M, D)
        return mem

    def load_state_dict(self, state_dict, strict: bool = True, **kwargs):
        """Load weights, accepting checkpoints written under older key names.

        Compatibility:
        - old local key: mix
        - current/baseline key: spatial_to_tokens

        Args:
            state_dict: Mapping of parameter names to tensors. It is copied;
                the caller's mapping is not modified.
            strict: Forwarded to ``nn.Module.load_state_dict``.
            **kwargs: Extra keyword arguments (for example ``assign`` on newer
                PyTorch versions) forwarded unchanged to the base implementation.

        Returns:
            Whatever ``nn.Module.load_state_dict`` returns.
        """
        # NOTE(review): this override presumably only runs when called directly
        # on this module, not when a parent module loads a nested state dict --
        # verify against the checkpoint-loading code.
        sd = dict(state_dict)
        legacy = sd.pop("mix", None)
        # The current key wins; the legacy value is used only when it is the sole
        # source. Popping unconditionally keeps a stale "mix" entry from tripping
        # the unexpected-key check under strict loading.
        if legacy is not None and "spatial_to_tokens" not in sd:
            sd["spatial_to_tokens"] = legacy
        # Ignore deprecated projection keys from prior experiments.
        sd.pop("out.weight", None)
        sd.pop("out.bias", None)
        return super().load_state_dict(sd, strict=bool(strict), **kwargs)
class SpatialCrossAttnReadout(nn.Module):
    """Gated cross-attention from target tokens to memory tokens.

    The attention output is added to the target tokens, scaled by
    ``tanh(gate)``. ``gate`` starts at zero, so the scale starts at zero.
    """

    def __init__(self, dim: int, num_heads: int = 8):
        """Build the attention layer and the scalar gate.

        Args:
            dim: Embedding size of both target and memory tokens.
            num_heads: Number of attention heads.
        """
        super().__init__()
        embed_dim = int(dim)
        heads = int(num_heads)
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=heads,
            batch_first=True,
        )
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, x_target: torch.Tensor, mem_tokens: torch.Tensor):
        """Return ``x_target`` plus the gated attention read-out of ``mem_tokens``.

        ``x_target`` is returned unchanged when either input is ``None`` or
        has no elements.
        """
        nothing_to_read = x_target is None or mem_tokens is None
        if nothing_to_read or x_target.numel() == 0 or mem_tokens.numel() == 0:
            return x_target
        # Queries come from the targets; keys and values both come from memory.
        attended = self.attn(x_target, mem_tokens, mem_tokens, need_weights=False)[0]
        scale = torch.tanh(self.gate)
        return x_target + scale * attended
def apply_spatial_cross_attn_readout(x_target: torch.Tensor, mem_tokens: torch.Tensor, module: nn.Module = None):
    """Run a cross-attention read-out of ``mem_tokens`` into ``x_target``.

    Args:
        x_target: Target tokens; the last axis is used as the embedding size
            when a default module has to be built.
        mem_tokens: Memory tokens passed to the module as its second argument.
        module: Callable taking ``(x_target, mem_tokens)``. When ``None``, a
            fresh ``SpatialCrossAttnReadout`` with 8 heads is created on the
            device and dtype of ``x_target``.

    Returns:
        The result of ``module(x_target, mem_tokens)``. With ``module=None``,
        ``x_target`` is returned unchanged when either input is ``None`` or
        has no elements.
    """
    if module is None:
        # Same pass-through cases as SpatialCrossAttnReadout.forward, checked
        # before x_target.shape is read so that a None target no longer raises
        # and no throwaway module is built.
        if x_target is None or mem_tokens is None:
            return x_target
        if x_target.numel() == 0 or mem_tokens.numel() == 0:
            return x_target
        # NOTE(review): the default module is newly initialised on every call
        # (untrained weights, gate at zero) -- callers that want a learned
        # read-out should pass their own module.
        module = SpatialCrossAttnReadout(dim=int(x_target.shape[-1]), num_heads=8).to(
            device=x_target.device, dtype=x_target.dtype
        )
    return module(x_target, mem_tokens)
def inject_spatial_memory(context: torch.Tensor, mem_tokens: torch.Tensor, mode: str = "concat_text"):
    """Merge memory tokens into a context tensor according to ``mode``.

    Args:
        context: Existing context tokens, or ``None``.
        mem_tokens: Memory tokens to inject, or ``None``.
        mode: Case-insensitive injection mode. A falsy value is treated as
            ``"concat_text"``.

    Returns:
        - ``context`` unchanged when ``mem_tokens`` is ``None``, when the mode
          is ``"none"``, or when the mode is not recognised.
        - ``mem_tokens`` when ``context`` is ``None``.
        - Otherwise, for ``"concat_text"`` and ``"cross_attn_readout"``, the
          two tensors concatenated along axis 1.
    """
    normalized = str(mode or "concat_text").lower()

    if normalized == "none" or mem_tokens is None:
        return context
    if context is None:
        return mem_tokens
    # Any mode other than the two concatenating ones leaves the context as is.
    if normalized != "concat_text" and normalized != "cross_attn_readout":
        return context
    return torch.cat((context, mem_tokens), dim=1)