DotCache-Arena / dotcache /state_cache_sim.py
DeanoCalver's picture
Initial DotCache Arena Space upload
751ad26 verified
from __future__ import annotations
import math
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Literal
import numpy as np
# Codec modes understood by simulate_state_codec: "M0" applies grouped uniform
# quantization; "M3" passes values through unquantized and sizes the payload
# by the spec's escape_dtype.
StateMode = Literal["M0", "M3"]
@dataclass(slots=True)
class StateTileSpec:
state_rows: int
state_cols: int
group_size: int = 32
bits: int = 8
mode: StateMode = "M0"
escape_dtype: str = "float32"
def to_dict(self) -> dict[str, int | str]:
return asdict(self)
@dataclass(slots=True)
class StateLayerRecord:
layer_id: int
layer_type: str
state_family: str
conv_state_bytes: int
recurrent_state_bytes: int
layer_state_bytes: int
state_shapes: dict[str, list[int]] = field(default_factory=dict)
state_delta_norms: list[dict[str, float | int]] = field(default_factory=list)
def to_dict(self) -> dict[str, Any]:
return asdict(self)
@dataclass(slots=True)
class StateAblationResult:
stage_name: str
bits: int | None
max_abs_error: float
max_rel_error: float
output_max_abs_error: float
error_grows_step_to_step: bool
per_layer_max_abs_error: dict[str, float] = field(default_factory=dict)
per_layer_max_rel_error: dict[str, float] = field(default_factory=dict)
per_layer_output_max_abs_error: dict[str, float] = field(default_factory=dict)
per_step_output_max_abs_error: list[float] = field(default_factory=list)
def to_dict(self) -> dict[str, Any]:
return asdict(self)
@dataclass(slots=True)
class StateSimResult:
mode: StateMode
bits: int
group_size: int
renorm_interval: int
bytes_per_token: int
bytes_per_layer: int
effective_compression_ratio: float
update_error_curve: list[float] = field(default_factory=list)
readout_error_curve: list[float] = field(default_factory=list)
def to_dict(self) -> dict[str, Any]:
return asdict(self)
@dataclass(slots=True)
class CapturedStateSample:
source: str
state_kind: str
layer_id: int
prompt_length: int
token_indices: list[int]
initial_state: np.ndarray
update_deltas: np.ndarray
@property
def state_rows(self) -> int:
return int(np.prod(self.initial_state.shape[:-1], dtype=np.int64)) if self.initial_state.ndim > 1 else 1
@property
def state_cols(self) -> int:
return int(self.initial_state.shape[-1])
@property
def steps(self) -> int:
return int(self.update_deltas.shape[0])
def to_dict(self) -> dict[str, Any]:
return {
"source": self.source,
"state_kind": self.state_kind,
"layer_id": self.layer_id,
"prompt_length": self.prompt_length,
"token_indices": list(self.token_indices),
"state_rows": self.state_rows,
"state_cols": self.state_cols,
"steps": self.steps,
}
def _codec_dtype_bytes(escape_dtype: str) -> int:
if escape_dtype == "float16":
return np.dtype(np.float16).itemsize
if escape_dtype == "float32":
return np.dtype(np.float32).itemsize
raise ValueError(f"unsupported StateCache escape_dtype {escape_dtype!r}")
def _renorm_rows(array: np.ndarray) -> np.ndarray:
flat = array.reshape(-1, array.shape[-1]).astype(np.float32, copy=True)
row_norms = np.linalg.norm(flat, axis=-1, keepdims=True)
row_norms = np.maximum(row_norms, 1e-8)
flat = flat / row_norms
return flat.reshape(array.shape)
def simulate_state_codec(tile: np.ndarray, spec: StateTileSpec) -> tuple[np.ndarray, int, int]:
    """Simulate one encode/decode round trip of a state tile.

    Mode ``"M3"`` returns the values as an unmodified float32 copy, charges
    ``spec.escape_dtype`` bytes per element, and reports no metadata.
    NOTE(review): the values are not rounded to ``escape_dtype``, so with
    ``"float16"`` the reported size and the returned precision disagree --
    confirm this is intended.

    Mode ``"M0"`` splits the last axis into runs of ``spec.group_size``
    elements (clamped to at least 1) and quantizes each run uniformly onto
    ``2**bits - 1`` steps. The range of a run always includes ``0.0``. Runs
    whose range is within ``1e-8`` are reconstructed as a constant.

    Args:
        tile: Array-like with at least two dimensions; converted to float32.
        spec: Codec settings (``mode``, ``bits``, ``group_size``,
            ``escape_dtype`` are read).

    Returns:
        ``(decoded, payload_nbytes, metadata_nbytes)`` where ``decoded`` is a
        float32 array shaped like ``tile``.

    Raises:
        ValueError: If ``tile`` has fewer than two dimensions, the mode is
            unknown, or ``spec.bits`` is not positive.
    """
    values = np.asarray(tile, dtype=np.float32)
    if values.ndim < 2:
        raise ValueError("state tiles must be at least 2D")

    if spec.mode == "M3":
        escape_nbytes = int(values.size * _codec_dtype_bytes(spec.escape_dtype))
        return values.astype(np.float32, copy=True), escape_nbytes, 0
    if spec.mode != "M0":
        raise ValueError(f"unsupported StateCache mode {spec.mode!r}")
    if spec.bits <= 0:
        raise ValueError("spec.bits must be positive")

    rows = values.reshape(-1, values.shape[-1]).astype(np.float32, copy=False)
    run_length = max(int(spec.group_size), 1)
    top_code = max((1 << int(spec.bits)) - 1, 1)
    row_count, width = rows.shape
    restored = np.empty_like(rows)

    # Iterating a 2D array yields row views, so writes to `out_row` land in `restored`.
    for in_row, out_row in zip(rows, restored):
        for begin in range(0, width, run_length):
            window = slice(begin, min(begin + run_length, width))
            run = in_row[window]
            # `initial=0.0` widens the range so that it always contains zero.
            floor = float(run.min(initial=0.0))
            ceiling = float(run.max(initial=0.0))
            if math.isclose(floor, ceiling, rel_tol=0.0, abs_tol=1e-8):
                # Degenerate range: avoid dividing by a (near-)zero step.
                out_row[window] = floor
            else:
                step = (ceiling - floor) / float(top_code)
                codes = np.clip(np.rint((run - floor) / step), 0, top_code)
                out_row[window] = codes.astype(np.float32) * np.float32(step) + np.float32(floor)

    runs_per_row = -(-width // run_length)  # ceiling division
    payload_nbytes = int(math.ceil(values.size * int(spec.bits) / 8.0))
    # Eight bytes of side information are charged per run.
    metadata_nbytes = int(row_count * runs_per_row * 8)
    return restored.reshape(values.shape), payload_nbytes, metadata_nbytes
def simulate_state_sequence(
    initial_state: np.ndarray,
    update_deltas: np.ndarray,
    readout_projections: np.ndarray,
    *,
    spec: StateTileSpec,
    renorm_interval: int = 0,
) -> StateSimResult:
    """Track how codec error evolves while a state is repeatedly updated.

    Two copies of the state are advanced in lockstep: an exact float32
    reference and a lossy copy that passes through ``simulate_state_codec``
    before and after each update. After every step the largest absolute
    difference between the two is recorded, both on the state itself and on
    the state multiplied by that step's readout projection.

    Args:
        initial_state: Starting state; its last axis must match the
            projection input dimension.
        update_deltas: Per-step additive updates, shaped
            ``[steps, *initial_state.shape]``.
        readout_projections: Per-step projections, shaped
            ``[steps, cols, output_dim]``.
        spec: Codec settings applied to the lossy copy.
        renorm_interval: When positive, both copies have their rows
            renormalized every ``renorm_interval`` steps.

    Returns:
        A ``StateSimResult`` holding the byte accounting and both error
        curves.

    Raises:
        ValueError: If the input shapes are inconsistent.
    """
    reference = np.asarray(initial_state, dtype=np.float32).copy()
    lossy = reference.copy()
    deltas = np.asarray(update_deltas, dtype=np.float32)
    projections = np.asarray(readout_projections, dtype=np.float32)

    if deltas.ndim < 3:
        raise ValueError("update_deltas must have shape [steps, rows, cols]")
    if projections.ndim != 3:
        raise ValueError("readout_projections must have shape [steps, cols, output_dim]")
    if deltas.shape[0] != projections.shape[0]:
        raise ValueError("update_deltas and readout_projections must have the same step count")
    if tuple(deltas.shape[1:]) != tuple(reference.shape):
        raise ValueError("update_deltas shape must match initial_state")
    if projections.shape[1] != reference.shape[-1]:
        raise ValueError("readout projection input dimension must match state_cols")

    state_errors: list[float] = []
    readout_errors: list[float] = []
    dense_nbytes = reference.nbytes
    # With zero steps nothing is ever encoded, so the dense size is reported.
    encoded_nbytes = dense_nbytes

    for step_number, (delta, projection) in enumerate(zip(deltas, projections), start=1):
        # Read the stored state back, then apply this step's update to both copies.
        restored, payload_nbytes, metadata_nbytes = simulate_state_codec(lossy, spec)
        encoded_nbytes = int(payload_nbytes + metadata_nbytes)
        reference = reference + delta
        updated = restored + delta

        if renorm_interval > 0 and step_number % int(renorm_interval) == 0:
            reference = _renorm_rows(reference)
            updated = _renorm_rows(updated)

        # Store the updated state again; only the decoded values are needed here.
        lossy = simulate_state_codec(updated, spec)[0]

        exact_readout = reference @ projection
        lossy_readout = lossy @ projection
        state_errors.append(float(np.abs(lossy - reference).max()))
        readout_errors.append(float(np.abs(lossy_readout - exact_readout).max()))

    bytes_per_layer = int(encoded_nbytes)
    return StateSimResult(
        mode=spec.mode,
        bits=int(spec.bits),
        group_size=int(spec.group_size),
        renorm_interval=int(renorm_interval),
        bytes_per_token=int(bytes_per_layer * 2),
        bytes_per_layer=bytes_per_layer,
        effective_compression_ratio=float(dense_nbytes / max(bytes_per_layer, 1)),
        update_error_curve=state_errors,
        readout_error_curve=readout_errors,
    )
def load_captured_state_sample(path: str | Path) -> CapturedStateSample:
    """Load a captured state sample from a NumPy ``.npz`` archive.

    The archive is opened with pickling disabled and must contain the keys
    ``source``, ``state_kind``, ``layer_id``, ``prompt_length``,
    ``token_indices``, ``initial_state`` and ``update_deltas``. Both arrays
    are converted to float32.

    Args:
        path: Location of the archive.

    Returns:
        The populated ``CapturedStateSample``.

    Raises:
        ValueError: If ``initial_state`` has fewer than two dimensions,
            ``update_deltas`` has fewer than three, or the trailing shape of
            ``update_deltas`` differs from the shape of ``initial_state``.
    """
    with np.load(Path(path), allow_pickle=False) as archive:
        # Scalars are stored as 0-d arrays; `.item()` unwraps them.
        source, state_kind = (str(archive[key].item()) for key in ("source", "state_kind"))
        layer_id, prompt_length = (int(archive[key].item()) for key in ("layer_id", "prompt_length"))
        token_indices = [int(index) for index in np.asarray(archive["token_indices"]).tolist()]
        initial_state, update_deltas = (
            np.asarray(archive[key], dtype=np.float32) for key in ("initial_state", "update_deltas")
        )

    if initial_state.ndim < 2:
        raise ValueError("captured initial_state must be at least 2D")
    if update_deltas.ndim < 3:
        raise ValueError("captured update_deltas must be at least 3D")
    per_step_shape = tuple(update_deltas.shape[1:])
    if per_step_shape != tuple(initial_state.shape):
        raise ValueError("captured update_deltas shape must match initial_state shape")

    return CapturedStateSample(
        source=source,
        state_kind=state_kind,
        layer_id=layer_id,
        prompt_length=prompt_length,
        token_indices=token_indices,
        initial_state=initial_state,
        update_deltas=update_deltas,
    )
# Names exported by `from <module> import *`; the underscore-prefixed helpers
# are intentionally left out.
__all__ = [
"CapturedStateSample",
"StateAblationResult",
"StateLayerRecord",
"StateSimResult",
"StateTileSpec",
"load_captured_state_sample",
"simulate_state_codec",
"simulate_state_sequence",
]