DotCache-Arena / dotcache /encode.py
DeanoCalver's picture
Initial DotCache Arena Space upload
751ad26 verified
from __future__ import annotations
import numpy as np
from .config import DotCacheConfig
from .planner import PageModeSpec
from .modes.m0_affine import quantize_tensor
from .modes.m1_lut import quantize_tensor_lut
from .modes.m2_key_sketch import quantize_tensor_m2, reconstruct_group_m2
from .modes.m4_key_project import quantize_tensor_m4, reconstruct_group_m4
from .modes.m3_escape import encode_escape_storage
from .modes.turbo3 import quantize_tensor_turbo3
from .page_format import build_payload
from .packing import words_per_group
from .types import EncodedPage, Kind, PageHeader
# Default value of the `sketch_rows` argument of `_build_runtime_page_sketch`:
# the maximum number of row-chunk means kept in a page's runtime sketch.
DEFAULT_RUNTIME_SKETCH_ROWS = 4
def _reconstruct_lut_page(codes: np.ndarray, codebooks: np.ndarray) -> np.ndarray:
    """Expand LUT codes back into a dense float32 page.

    ``codes`` is unpacked as ``(token_count, num_groups, group_size)`` and
    ``codebooks`` is indexed by group along its first axis.  A group whose
    codebook is 1-D is looked up directly.  Otherwise the codebook's first
    axis is treated as segments, and each token is assigned to a segment in
    proportion to its position within the page.

    Returns an array of shape ``(token_count, num_groups * group_size)``.
    """
    n_tokens, n_groups, width = codes.shape
    page = np.zeros((n_tokens, n_groups * width), dtype=np.float32)
    token_positions = np.arange(n_tokens, dtype=np.int64)

    for g in range(n_groups):
        table = codebooks[g].astype(np.float32)
        indices = codes[:, g].astype(np.int64)
        if table.ndim == 1:
            looked_up = table[indices]
        else:
            n_segments = table.shape[0]
            # max(..., 1) keeps the division defined for an empty page.
            segment_of_token = (token_positions * n_segments) // max(n_tokens, 1)
            looked_up = table[segment_of_token[:, None], indices]
        page[:, g * width : (g + 1) * width] = looked_up

    return page
def _reconstruct_m2_page(
    coeffs: np.ndarray,
    basis: np.ndarray,
    mean: np.ndarray | None,
    *,
    group_size: int,
) -> np.ndarray:
    """Rebuild a dense float32 page from per-group M2 coefficients.

    Each group is reconstructed by ``reconstruct_group_m2`` from that group's
    coefficients, basis, and (when ``mean`` is not ``None``) mean, then
    written into its ``group_size``-wide column range.

    Returns an array of shape ``(token_count, num_groups * group_size)``
    where the first two dimensions of ``coeffs`` give ``token_count`` and
    ``num_groups``.
    """
    n_tokens, n_groups, _ = coeffs.shape
    page = np.zeros((n_tokens, n_groups * group_size), dtype=np.float32)
    for g in range(n_groups):
        group_mean = mean[g] if mean is not None else None
        column_start = g * group_size
        page[:, column_start : column_start + group_size] = reconstruct_group_m2(
            coeffs[:, g, :],
            basis=basis[g],
            mean=group_mean,
        )
    return page
def _reconstruct_m4_page(
    coeffs: np.ndarray,
    mean: np.ndarray,
    *,
    group_size: int,
) -> np.ndarray:
    """Rebuild a dense float32 page from per-group M4 coefficients.

    Each group is reconstructed by ``reconstruct_group_m4`` from that group's
    coefficients and mean, then written into its ``group_size``-wide column
    range.

    Returns an array of shape ``(token_count, num_groups * group_size)``
    where the first two dimensions of ``coeffs`` give ``token_count`` and
    ``num_groups``.
    """
    n_tokens, n_groups, _ = coeffs.shape
    page = np.zeros((n_tokens, n_groups * group_size), dtype=np.float32)
    for g in range(n_groups):
        column_start = g * group_size
        page[:, column_start : column_start + group_size] = reconstruct_group_m4(
            coeffs[:, g, :],
            mean=mean[g],
            group_size=group_size,
        )
    return page
def _build_runtime_page_sketch(
    values: np.ndarray,
    *,
    sketch_rows: int = DEFAULT_RUNTIME_SKETCH_ROWS,
) -> tuple[np.ndarray, np.ndarray]:
    """Summarise a page as its overall mean plus a few row-chunk means.

    The rows of ``values`` are split along axis 0 into at most
    ``sketch_rows`` chunks (at least one is requested, and never more than
    there are rows); each chunk is averaged over axis 0.

    Returns ``(page_mean, sketch)``, both cast to float16.  ``sketch`` has
    one row per chunk.
    """
    requested_rows = max(1, sketch_rows)
    n_chunks = min(requested_rows, values.shape[0])
    chunk_means = [
        piece.mean(axis=0) for piece in np.array_split(values, n_chunks, axis=0)
    ]
    sketch = np.stack(chunk_means, axis=0).astype(np.float16)
    page_mean = values.mean(axis=0).astype(np.float16)
    return page_mean, sketch
def _build_runtime_page_envelope(values: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return the column-wise ``(min, max)`` of ``values`` as float16 arrays.

    Both reductions run over axis 0.
    """
    lower = np.min(values, axis=0)
    upper = np.max(values, axis=0)
    return lower.astype(np.float16), upper.astype(np.float16)
def _candidate_m2_segment_counts(max_segment_count: int) -> list[int]:
    """List the segment counts to try, in increasing order.

    The result is every power of two strictly below the limit, followed by
    the limit itself.  The limit is ``max_segment_count`` coerced to ``int``
    and clamped to at least 1, so the list always starts with 1.
    """
    limit = max(1, int(max_segment_count))
    below_limit: list[int] = []
    power = 1
    while power < limit:
        below_limit.append(power)
        power <<= 1
    return below_limit + [limit]
def _encode_m2_tensor(values: np.ndarray, config: DotCacheConfig) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]:
    """Quantize ``values`` with the M2 key sketch, optionally tuning segments.

    When ``config.m2_adaptive_segments_k`` is falsy, or
    ``config.m2_segment_count_k <= 1``, a single ``quantize_tensor_m2`` call
    is made and its result returned as is.

    Otherwise the single-segment encoding is the baseline, and each larger
    count from ``_candidate_m2_segment_counts`` is tried in order.  A trial
    replaces the current best only if it lowers the error by at least the
    fraction ``config.m2_adaptive_min_improvement_k`` of the current best
    error.  The error is the mean absolute reconstruction error over the
    first ``config.head_dim`` columns, divided by the RMS of ``values``
    (floored at 1e-6).

    Returns:
        ``(coeffs, basis, mean, padded_head_dim)`` of the selected encoding.
        All four come from the same ``quantize_tensor_m2`` call.
    """

    def _quantize(segment_count: int):
        return quantize_tensor_m2(
            values,
            group_size=config.group_size,
            sketch_dim=config.m2_sketch_dim_k,
            center=config.m2_center_k,
            segment_count=segment_count,
        )

    best = _quantize(1 if config.m2_adaptive_segments_k else config.m2_segment_count_k)
    if not config.m2_adaptive_segments_k or config.m2_segment_count_k <= 1:
        return best

    rms = float(np.sqrt(np.mean(np.square(values), dtype=np.float64)))
    normalizer = max(rms, 1e-6)

    def _relative_error(coeffs: np.ndarray, basis: np.ndarray, mean: np.ndarray | None) -> float:
        # Drop padding columns so only real head dimensions are scored.
        rebuilt = _reconstruct_m2_page(coeffs, basis, mean, group_size=config.group_size)[:, : config.head_dim]
        return float(np.mean(np.abs(values - rebuilt), dtype=np.float64) / normalizer)

    best_error = _relative_error(*best[:3])
    for segment_count in _candidate_m2_segment_counts(config.m2_segment_count_k)[1:]:
        trial = _quantize(segment_count)
        trial_error = _relative_error(*trial[:3])
        if (best_error - trial_error) / max(best_error, 1e-6) >= config.m2_adaptive_min_improvement_k:
            # Keep the whole tuple so padded_head_dim always belongs to the
            # accepted encoding rather than to whichever trial ran last.
            best = trial
            best_error = trial_error
    return best
def encode_page(
    tensor_slice: np.ndarray,
    config: DotCacheConfig,
    *,
    kind: Kind,
    layer_id: int = 0,
    kv_head_id: int = 0,
    token_start: int = 0,
    mode: str | None = None,
    page_mode: PageModeSpec | None = None,
    layout: str | None = None,
    quant_scheme: str | None = None,
    build_runtime_metadata: bool = True,
    build_m2_sidecar: bool | None = None,
    m4_basis_override: np.ndarray | None = None,
) -> EncodedPage:
    """Encode one ``[token_count, head_dim]`` slice into an ``EncodedPage``.

    The page mode comes from ``page_mode`` when given, otherwise from
    ``mode``, otherwise from the per-kind default in ``config``.  Supported
    modes are ``"M0"``, ``"M1"``, ``"M2"``, ``"M3"``, ``"M4"`` and ``"T3"``.
    An ``"M1"`` page is re-encoded as ``"M0"`` with the ``"affine"`` scheme
    when ``config.m1_fallback_to_m0`` is set and either trial error exceeds
    its configured threshold.

    Args:
        tensor_slice: 2-D array-like, converted to float32.  Its second
            dimension must equal ``config.head_dim``.
        config: Encoding configuration.
        kind: ``"K"`` selects the key-side config fields; any other value
            selects the value-side fields.
        layer_id: Stored in the header; also used to resolve the M4
            projection settings.
        kv_head_id: Stored in the header.
        token_start: Stored in the header.
        mode: Mode name used when ``page_mode`` is ``None``.
        page_mode: Optional spec.  When given it supplies the mode, bit
            width, quant scheme, policy header fields and (for M3) an
            optional escape dtype; ``mode`` and ``quant_scheme`` are ignored.
        layout: Payload layout; defaults to the per-kind config layout.
        quant_scheme: Quantization scheme used when ``page_mode`` is
            ``None``; defaults to the per-kind config scheme.
        build_runtime_metadata: When true, attach the float16 page mean,
            sketch, min and max computed from the input.
        build_m2_sidecar: Whether M0/M1/T3 key pages also carry an M2
            sidecar.  ``None`` enables it iff
            ``config.m2_prefilter_top_k > 0``.  Never built for non-K pages.
        m4_basis_override: Forwarded to ``quantize_tensor_m4``.

    Returns:
        The encoded page.  ``requested_mode`` is always the mode name that
        was requested, before any M1-to-M0 fallback.

    Raises:
        ValueError: If the input is not 2-D, its head dimension does not
            match the config, M2 or M4 is requested for a non-K page, or the
            mode is not supported.
    """
    values = np.asarray(tensor_slice, dtype=np.float32)
    if values.ndim != 2:
        raise ValueError("tensor_slice must have shape [token_count, head_dim]")
    if values.shape[1] != config.head_dim:
        raise ValueError("tensor_slice head_dim must match config.head_dim")

    bits = config.bits_k if kind == "K" else config.bits_v
    default_mode = config.default_mode_k if kind == "K" else config.default_mode_v
    selected_page_mode = page_mode
    page_mode_name = selected_page_mode.mode if selected_page_mode is not None else (mode or default_mode)
    if selected_page_mode is not None:
        bits = int(selected_page_mode.bits)
    page_layout = layout or (config.payload_layout_k if kind == "K" else config.payload_layout_v)
    scheme = (
        selected_page_mode.quant_scheme
        if selected_page_mode is not None
        else quant_scheme or (config.quant_scheme_k if kind == "K" else config.quant_scheme_v)
    )
    token_count = values.shape[0]
    # Captured before the M1 fallback can rewrite page_mode_name.
    requested_mode = page_mode_name
    trial_quant_error = None
    trial_token_p95_error = None

    # Runtime metadata attached to every page, whatever the mode.
    runtime_fields = {
        "runtime_page_mean": None,
        "runtime_page_sketch": None,
        "runtime_page_min": None,
        "runtime_page_max": None,
    }
    if build_runtime_metadata:
        runtime_fields["runtime_page_mean"], runtime_fields["runtime_page_sketch"] = (
            _build_runtime_page_sketch(values)
        )
        runtime_fields["runtime_page_min"], runtime_fields["runtime_page_max"] = (
            _build_runtime_page_envelope(values)
        )

    def _build_m2_sidecar() -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
        """Return float16 ``(sketch, basis, mean)`` for K pages, else Nones."""
        if build_m2_sidecar is None:
            sidecar_enabled = config.m2_prefilter_top_k > 0
        else:
            sidecar_enabled = bool(build_m2_sidecar)
        if kind != "K" or not sidecar_enabled:
            return None, None, None
        coeffs, basis, mean, _ = _encode_m2_tensor(values, config)
        return (
            coeffs.astype(np.float16, copy=False),
            basis.astype(np.float16, copy=False),
            mean.astype(np.float16, copy=False),
        )

    # Header fields that are identical for every mode.
    common_header = {
        "layer_id": layer_id,
        "kv_head_id": kv_head_id,
        "kind": kind,
        "token_start": token_start,
        "token_count": token_count,
        "head_dim": config.head_dim,
        "group_size": config.group_size,
        "num_groups": config.num_groups,
        "layout": page_layout,
        "policy_id": selected_page_mode.policy_id if selected_page_mode is not None else "exact_baseline",
        "sensitivity_tier": selected_page_mode.sensitivity_tier if selected_page_mode is not None else "exact",
        "fallback_reason": selected_page_mode.fallback_reason if selected_page_mode is not None else "",
        "age_bucket": selected_page_mode.age_bucket if selected_page_mode is not None else "aged",
    }

    if page_mode_name == "M3":
        escape_dtype = (
            selected_page_mode.escape_dtype
            if selected_page_mode is not None and selected_page_mode.escape_dtype is not None
            else config.escape_dtype
        )
        header = PageHeader(
            **common_header,
            padded_head_dim=config.padded_head_dim,
            bits=bits,
            words_per_group=words_per_group(config.group_size, bits),
            mode_default="M3",
            quant_scheme=scheme,
            escape_dtype=escape_dtype,
        )
        escape_payload, escape_scales = encode_escape_storage(values, dtype=escape_dtype)
        return EncodedPage(
            header=header,
            escape_payload=escape_payload,
            escape_scales=escape_scales,
            # Mode name, consistent with the M0/M1/T3 branches (this used to
            # pass the PageModeSpec argument instead).
            requested_mode=requested_mode,
            **runtime_fields,
        )

    if page_mode_name == "M2":
        if kind != "K":
            raise ValueError("M2 is only supported for K pages in this phase")
        coeffs, basis, mean, padded_head_dim = _encode_m2_tensor(values, config)
        header = PageHeader(
            **common_header,
            padded_head_dim=padded_head_dim,
            bits=bits,
            words_per_group=0,
            mode_default="M2",
            quant_scheme="sketch",
            escape_dtype=config.escape_dtype,
        )
        return EncodedPage(
            header=header,
            m2_sketch=coeffs.astype(np.float16, copy=False),
            m2_basis=basis.astype(np.float16, copy=False),
            m2_mean=mean.astype(np.float16, copy=False),
            requested_mode=requested_mode,
            **runtime_fields,
        )

    if page_mode_name == "M4":
        if kind != "K":
            raise ValueError("M4 is only supported for K pages in this phase")
        project_dim = config.resolve_m4_project_dim_k(layer_id=layer_id)
        # Resolved once and shared by the quantizer and the header.
        project_basis = config.resolve_m4_project_basis_k(layer_id=layer_id)
        coeffs, basis, mean, padded_head_dim = quantize_tensor_m4(
            values,
            group_size=config.group_size,
            project_dim=project_dim,
            basis_family=project_basis,
            basis_override=m4_basis_override,
        )
        header = PageHeader(
            **common_header,
            padded_head_dim=padded_head_dim,
            bits=bits,
            words_per_group=0,
            mode_default="M4",
            quant_scheme="project",
            project_basis=project_basis,
            escape_dtype=config.escape_dtype,
        )
        # M4 stores its coefficients, basis and mean in the m2_* page fields.
        return EncodedPage(
            header=header,
            m2_sketch=coeffs.astype(np.float16, copy=False),
            m2_basis=None if basis is None else basis.astype(np.float16, copy=False),
            m2_mean=mean.astype(np.float16, copy=False),
            requested_mode=requested_mode,
            **runtime_fields,
        )

    if page_mode_name == "M1":
        codes, codebooks, padded_head_dim = quantize_tensor_lut(
            values,
            group_size=config.group_size,
            bits=bits,
            segment_count=config.m1_segment_count_k if kind == "K" else config.m1_segment_count_v,
            refine_steps=config.lut_refine_steps,
            preconditioner=config.preconditioner,
            precondition_strength=config.precondition_strength,
        )
        if config.m1_fallback_to_m0:
            reconstructed = _reconstruct_lut_page(codes, codebooks)[:, : config.head_dim]
            rms = float(np.sqrt(np.mean(np.square(values), dtype=np.float64)))
            trial_quant_error = float(np.mean(np.abs(values - reconstructed), dtype=np.float64) / max(rms, 1e-6))
            token_norms = np.linalg.norm(values, axis=1)
            token_rel_error = np.linalg.norm(values - reconstructed, axis=1) / np.maximum(token_norms, 1e-6)
            trial_token_p95_error = float(np.percentile(token_rel_error, 95))
            if (
                trial_quant_error > config.m1_error_threshold
                or trial_token_p95_error > config.m1_token_p95_error_threshold
            ):
                # LUT quality is too low: fall through to the M0 path below.
                page_mode_name = "M0"
                scheme = "affine"

    if page_mode_name == "M1":
        sidecar_sketch, sidecar_basis, sidecar_mean = _build_m2_sidecar()
        payload = build_payload(codes, bits, page_layout)
        header = PageHeader(
            **common_header,
            padded_head_dim=padded_head_dim,
            bits=bits,
            words_per_group=words_per_group(config.group_size, bits),
            mode_default="M1",
            quant_scheme="lut",
            escape_dtype=config.escape_dtype,
        )
        return EncodedPage(
            header=header,
            payload=payload,
            codebooks=codebooks.astype(np.float16),
            m2_sketch=sidecar_sketch,
            m2_basis=sidecar_basis,
            m2_mean=sidecar_mean,
            lut_segment_count=int(codebooks.shape[1]) if codebooks.ndim == 3 else 1,
            requested_mode=requested_mode,
            trial_quant_error=trial_quant_error,
            trial_token_p95_error=trial_token_p95_error,
            **runtime_fields,
        )

    if page_mode_name == "T3":
        codes, correction, centroids, padded_head_dim = quantize_tensor_turbo3(
            values,
            group_size=config.group_size,
        )
        sidecar_sketch, sidecar_basis, sidecar_mean = _build_m2_sidecar()
        # T3 always uses 3-bit codes, regardless of the configured bit width.
        payload = build_payload(codes, 3, page_layout)
        header = PageHeader(
            **common_header,
            padded_head_dim=padded_head_dim,
            bits=3,
            words_per_group=words_per_group(config.group_size, 3),
            mode_default="T3",
            quant_scheme="turbo3",
            escape_dtype=config.escape_dtype,
        )
        return EncodedPage(
            header=header,
            payload=payload,
            scales=correction,
            codebooks=centroids,
            m2_sketch=sidecar_sketch,
            m2_basis=sidecar_basis,
            m2_mean=sidecar_mean,
            requested_mode=requested_mode,
            **runtime_fields,
        )

    if page_mode_name != "M0":
        raise ValueError("only M0, M1, M2, M3, M4, and T3 are supported in this bootstrap")

    codes, scales, bias, padded_head_dim = quantize_tensor(
        values,
        group_size=config.group_size,
        bits=bits,
        scheme=scheme,
    )
    payload = build_payload(codes, bits, page_layout)
    header = PageHeader(
        **common_header,
        padded_head_dim=padded_head_dim,
        bits=bits,
        words_per_group=words_per_group(config.group_size, bits),
        mode_default="M0",
        quant_scheme=scheme,
        escape_dtype=config.escape_dtype,
    )
    stored_scales = scales.astype(np.float16)
    stored_bias = None if bias is None else bias.astype(np.float16)
    sidecar_sketch, sidecar_basis, sidecar_mean = _build_m2_sidecar()
    return EncodedPage(
        header=header,
        payload=payload,
        scales=stored_scales,
        bias=stored_bias,
        m2_sketch=sidecar_sketch,
        m2_basis=sidecar_basis,
        m2_mean=sidecar_mean,
        requested_mode=requested_mode,
        # Both trial errors are non-None only after an M1-to-M0 fallback.
        trial_quant_error=trial_quant_error,
        trial_token_p95_error=trial_token_p95_error,
        **runtime_fields,
    )