Spaces:
Paused
Paused
| from __future__ import annotations | |
| import math | |
| from dataclasses import asdict, dataclass, field | |
| from pathlib import Path | |
| from typing import Any, Literal | |
| import numpy as np | |
# Codec modes accepted by simulate_state_codec: "M0" applies grouped uniform
# quantization, while "M3" returns the float32 values unchanged and sizes the
# payload by the spec's escape_dtype.
StateMode = Literal["M0", "M3"]
@dataclass
class StateTileSpec:
    """Codec configuration for one state tile.

    The class is a dataclass: ``to_dict`` relies on ``dataclasses.asdict`` and
    callers build instances with keyword arguments, both of which require the
    ``@dataclass`` decorator.
    """

    # Tile dimensions. The codec functions in this module take the shape from
    # the array they are given and do not read these two fields.
    state_rows: int
    state_cols: int
    # Number of consecutive last-axis elements that share one quantization
    # range in mode "M0" (values below 1 are clamped to 1 by the codec).
    group_size: int = 32
    # Bit width of each quantized value in mode "M0"; must be positive.
    bits: int = 8
    # Codec mode: "M0" (grouped quantization) or "M3" (pass-through).
    mode: StateMode = "M0"
    # Storage dtype used to size the "M3" payload: "float16" or "float32".
    escape_dtype: str = "float32"

    def to_dict(self) -> dict[str, int | str]:
        """Return the spec's fields as a plain dictionary."""
        return asdict(self)
@dataclass
class StateLayerRecord:
    """Per-layer record of state sizes, shapes and delta norms.

    The class is a dataclass: the ``field(default_factory=...)`` defaults and
    ``to_dict`` (via ``dataclasses.asdict``) both require the ``@dataclass``
    decorator. Nothing in the code visible in this module populates the
    record, so the per-field notes below are inferred from the field names and
    annotations only -- confirm against the producer.
    """

    layer_id: int
    layer_type: str
    state_family: str
    # Byte counts; presumably layer_state_bytes is the per-layer total --
    # TODO confirm against the code that fills this record.
    conv_state_bytes: int
    recurrent_state_bytes: int
    layer_state_bytes: int
    # Mapping from a state name to its shape as a list of ints.
    state_shapes: dict[str, list[int]] = field(default_factory=dict)
    # One mapping of numeric statistics per entry; the keys are not defined here.
    state_delta_norms: list[dict[str, float | int]] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return the record's fields as a plain (recursively copied) dictionary."""
        return asdict(self)
@dataclass
class StateAblationResult:
    """Error summary for one ablation stage.

    The class is a dataclass: the ``field(default_factory=...)`` defaults and
    ``to_dict`` (via ``dataclasses.asdict``) both require the ``@dataclass``
    decorator. Nothing in the code visible in this module populates the
    result, so the field meanings are inferred from names and annotations
    only -- confirm against the producer.
    """

    stage_name: str
    # Bit width used for the stage, or None when no bit width applies.
    bits: int | None
    max_abs_error: float
    max_rel_error: float
    output_max_abs_error: float
    error_grows_step_to_step: bool
    # Per-layer breakdowns keyed by a string; presumably a layer name or id.
    per_layer_max_abs_error: dict[str, float] = field(default_factory=dict)
    per_layer_max_rel_error: dict[str, float] = field(default_factory=dict)
    per_layer_output_max_abs_error: dict[str, float] = field(default_factory=dict)
    # One value per step.
    per_step_output_max_abs_error: list[float] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return the result's fields as a plain (recursively copied) dictionary."""
        return asdict(self)
@dataclass
class StateSimResult:
    """Outcome of a state-sequence simulation under one codec configuration.

    The class is a dataclass: it is constructed with keyword arguments by
    ``simulate_state_sequence`` and serialized with ``dataclasses.asdict``,
    both of which require the ``@dataclass`` decorator.
    """

    # Codec settings the simulation ran with.
    mode: StateMode
    bits: int
    group_size: int
    # Renormalize rows every this many steps; values <= 0 disable it.
    renorm_interval: int
    # As filled by simulate_state_sequence: twice bytes_per_layer.
    bytes_per_token: int
    # As filled by simulate_state_sequence: payload plus metadata bytes of the
    # encoded state.
    bytes_per_layer: int
    # Dense float32 state size divided by bytes_per_layer.
    effective_compression_ratio: float
    # Per-step maximum absolute difference between approximate and dense state.
    update_error_curve: list[float] = field(default_factory=list)
    # Per-step maximum absolute difference between the two projected readouts.
    readout_error_curve: list[float] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return the result's fields as a plain (recursively copied) dictionary."""
        return asdict(self)
@dataclass
class CapturedStateSample:
    """A captured initial state together with its per-step update deltas.

    The class is a dataclass (it is built with keyword arguments by
    ``load_captured_state_sample``). ``state_rows``, ``state_cols`` and
    ``steps`` are properties: ``to_dict`` reads them as plain attributes, so
    without ``@property`` the dictionary would contain bound methods instead
    of integers.
    """

    source: str
    state_kind: str
    layer_id: int
    prompt_length: int
    token_indices: list[int]
    # State array; the last axis is treated as the column axis.
    initial_state: np.ndarray
    # Update deltas; the leading axis is the step axis.
    update_deltas: np.ndarray

    @property
    def state_rows(self) -> int:
        """Product of every ``initial_state`` axis except the last (1 if ndim <= 1)."""
        if self.initial_state.ndim > 1:
            return int(np.prod(self.initial_state.shape[:-1], dtype=np.int64))
        return 1

    @property
    def state_cols(self) -> int:
        """Length of the last axis of ``initial_state``."""
        return int(self.initial_state.shape[-1])

    @property
    def steps(self) -> int:
        """Length of the leading axis of ``update_deltas``."""
        return int(self.update_deltas.shape[0])

    def to_dict(self) -> dict[str, Any]:
        """Return the metadata and derived sizes; the arrays are not included."""
        return {
            "source": self.source,
            "state_kind": self.state_kind,
            "layer_id": self.layer_id,
            "prompt_length": self.prompt_length,
            # Copy so callers cannot mutate the sample through the dictionary.
            "token_indices": list(self.token_indices),
            "state_rows": self.state_rows,
            "state_cols": self.state_cols,
            "steps": self.steps,
        }
def _codec_dtype_bytes(escape_dtype: str) -> int:
    """Return the per-element byte size for a supported escape dtype name.

    Only "float16" and "float32" are accepted; any other value raises
    ``ValueError``.
    """
    supported = (("float16", np.float16), ("float32", np.float32))
    for dtype_name, scalar_type in supported:
        if escape_dtype == dtype_name:
            return np.dtype(scalar_type).itemsize
    raise ValueError(f"unsupported StateCache escape_dtype {escape_dtype!r}")
def _renorm_rows(array: np.ndarray) -> np.ndarray:
    """Scale every last-axis vector of ``array`` to unit Euclidean length.

    The input is viewed as a 2D stack of rows, copied to float32 and divided
    by each row's norm. Norms are floored at 1e-8 so all-zero rows stay zero
    instead of dividing by zero. The result has the input's shape; the input
    itself is not modified.
    """
    original_shape = array.shape
    rows = array.reshape(-1, original_shape[-1]).astype(np.float32, copy=True)
    lengths = np.maximum(np.linalg.norm(rows, axis=-1, keepdims=True), 1e-8)
    return (rows / lengths).reshape(original_shape)
def simulate_state_codec(tile: np.ndarray, spec: StateTileSpec) -> tuple[np.ndarray, int, int]:
    """Encode and immediately decode ``tile`` under ``spec``.

    Returns ``(decoded, payload_nbytes, metadata_nbytes)`` where ``decoded`` is
    a float32 array with the shape of ``tile``.

    * Mode "M3" returns a float32 copy of the input; the payload is sized as
      one ``spec.escape_dtype`` element per value and there is no metadata.
    * Mode "M0" quantizes each run of ``spec.group_size`` consecutive
      last-axis elements to ``spec.bits`` bits over a range that always
      includes zero. The payload is ``bits`` per value rounded up to whole
      bytes, plus 8 metadata bytes per group.

    Raises ``ValueError`` for input with fewer than two dimensions, an unknown
    mode, or a non-positive bit width in mode "M0".
    """
    values = np.asarray(tile, dtype=np.float32)
    if values.ndim < 2:
        raise ValueError("state tiles must be at least 2D")

    if spec.mode == "M3":
        escape_nbytes = int(values.size * _codec_dtype_bytes(spec.escape_dtype))
        return values.astype(np.float32, copy=True), escape_nbytes, 0
    if spec.mode != "M0":
        raise ValueError(f"unsupported StateCache mode {spec.mode!r}")
    if spec.bits <= 0:
        raise ValueError("spec.bits must be positive")

    rows = values.reshape(-1, values.shape[-1]).astype(np.float32, copy=False)
    n_rows, n_cols = rows.shape
    width = max(int(spec.group_size), 1)
    top_code = max((1 << int(spec.bits)) - 1, 1)
    group_starts = range(0, n_cols, width)

    restored = np.empty_like(rows)
    for row_index, row in enumerate(rows):
        for first in group_starts:
            last = min(first + width, n_cols)
            chunk = row[first:last]
            target = restored[row_index, first:last]
            # initial=0.0 folds zero into both extremes, so the quantization
            # range of every group contains zero.
            floor = float(chunk.min(initial=0.0))
            ceiling = float(chunk.max(initial=0.0))
            if math.isclose(floor, ceiling, rel_tol=0.0, abs_tol=1e-8):
                # Degenerate range: every value in the group decodes to floor.
                target[...] = floor
            else:
                step = (ceiling - floor) / float(top_code)
                codes = np.rint((chunk - floor) / step).clip(0, top_code)
                target[...] = codes.astype(np.float32) * np.float32(step) + np.float32(floor)

    # Every row contributes the same number of groups.
    group_count = n_rows * len(group_starts)
    payload_nbytes = int(math.ceil(values.size * int(spec.bits) / 8.0))
    metadata_nbytes = int(group_count * 8)
    return restored.reshape(values.shape), payload_nbytes, metadata_nbytes
def simulate_state_sequence(
    initial_state: np.ndarray,
    update_deltas: np.ndarray,
    readout_projections: np.ndarray,
    *,
    spec: StateTileSpec,
    renorm_interval: int = 0,
) -> StateSimResult:
    """Run a dense and a codec-approximated state side by side over a sequence.

    At every step the approximate state is passed through the codec, the
    step's delta is added to both states, both are optionally renormalized
    row-wise, and the approximate state is passed through the codec again.
    The maximum absolute state difference and the maximum absolute difference
    of the projected readouts are recorded per step.

    Args:
        initial_state: Starting state; converted to float32.
        update_deltas: Per-step deltas with a leading step axis; the remaining
            axes must equal the shape of ``initial_state``.
        readout_projections: Array of shape [steps, cols, output_dim].
        spec: Codec configuration handed to ``simulate_state_codec``.
        renorm_interval: Renormalize rows after every this many steps; values
            that are not positive disable renormalization.

    Returns:
        A ``StateSimResult`` with the codec settings, byte accounting and the
        two error curves.

    Raises:
        ValueError: If the input shapes are inconsistent.
    """
    reference = np.asarray(initial_state, dtype=np.float32).copy()
    lossy = reference.copy()
    step_deltas = np.asarray(update_deltas, dtype=np.float32)
    step_projections = np.asarray(readout_projections, dtype=np.float32)

    if step_deltas.ndim < 3:
        raise ValueError("update_deltas must have shape [steps, rows, cols]")
    if step_projections.ndim != 3:
        raise ValueError("readout_projections must have shape [steps, cols, output_dim]")
    if step_deltas.shape[0] != step_projections.shape[0]:
        raise ValueError("update_deltas and readout_projections must have the same step count")
    if tuple(step_deltas.shape[1:]) != tuple(reference.shape):
        raise ValueError("update_deltas shape must match initial_state")
    if step_projections.shape[1] != reference.shape[-1]:
        raise ValueError("readout projection input dimension must match state_cols")

    dense_nbytes = reference.nbytes
    # Stays at the dense size when there are no steps to simulate.
    encoded_nbytes = dense_nbytes
    state_errors: list[float] = []
    readout_errors: list[float] = []

    for step_number, (delta, projection) in enumerate(zip(step_deltas, step_projections), start=1):
        # Simulate loading the stored state before the update is applied.
        reloaded, payload_nbytes, metadata_nbytes = simulate_state_codec(lossy, spec)
        encoded_nbytes = int(payload_nbytes + metadata_nbytes)

        reference = reference + delta
        updated = reloaded + delta
        if renorm_interval > 0 and step_number % int(renorm_interval) == 0:
            reference = _renorm_rows(reference)
            updated = _renorm_rows(updated)

        # Simulate storing the updated state again.
        lossy = simulate_state_codec(updated, spec)[0]

        state_gap = np.abs(lossy - reference)
        readout_gap = np.abs(lossy @ projection - reference @ projection)
        state_errors.append(float(np.max(state_gap)))
        readout_errors.append(float(np.max(readout_gap)))

    layer_nbytes = int(encoded_nbytes)
    return StateSimResult(
        mode=spec.mode,
        bits=int(spec.bits),
        group_size=int(spec.group_size),
        renorm_interval=int(renorm_interval),
        bytes_per_token=int(layer_nbytes * 2),
        bytes_per_layer=layer_nbytes,
        effective_compression_ratio=float(dense_nbytes / max(layer_nbytes, 1)),
        update_error_curve=state_errors,
        readout_error_curve=readout_errors,
    )
def load_captured_state_sample(path: str | Path) -> CapturedStateSample:
    """Load a captured state sample from a NumPy archive at ``path``.

    The archive is opened with pickling disabled and must provide the entries
    ``source``, ``state_kind``, ``layer_id``, ``prompt_length``,
    ``token_indices``, ``initial_state`` and ``update_deltas``. Both arrays
    are converted to float32.

    Raises:
        ValueError: If ``initial_state`` has fewer than two dimensions,
            ``update_deltas`` has fewer than three, or the trailing axes of
            ``update_deltas`` do not equal the shape of ``initial_state``.
    """
    with np.load(Path(path), allow_pickle=False) as archive:
        metadata = {
            "source": str(archive["source"].item()),
            "state_kind": str(archive["state_kind"].item()),
            "layer_id": int(archive["layer_id"].item()),
            "prompt_length": int(archive["prompt_length"].item()),
            "token_indices": list(map(int, np.asarray(archive["token_indices"]).tolist())),
        }
        state = np.asarray(archive["initial_state"], dtype=np.float32)
        deltas = np.asarray(archive["update_deltas"], dtype=np.float32)

    if state.ndim < 2:
        raise ValueError("captured initial_state must be at least 2D")
    if deltas.ndim < 3:
        raise ValueError("captured update_deltas must be at least 3D")
    if tuple(deltas.shape[1:]) != tuple(state.shape):
        raise ValueError("captured update_deltas shape must match initial_state shape")

    return CapturedStateSample(**metadata, initial_state=state, update_deltas=deltas)
# Names exported by `from <module> import *`. The underscore-prefixed helpers
# and the StateMode alias are not listed.
__all__ = [
    "CapturedStateSample",
    "StateAblationResult",
    "StateLayerRecord",
    "StateSimResult",
    "StateTileSpec",
    "load_captured_state_sample",
    "simulate_state_codec",
    "simulate_state_sequence",
]