DotCache-Arena / dotcache /decode_reference.py
DeanoCalver's picture
Initial DotCache Arena Space upload
751ad26 verified
from __future__ import annotations
import numpy as np
from .modes.m0_affine import dequantize_group
from .modes.m1_lut import dequantize_group_lut
from .modes.m2_key_sketch import reconstruct_group_m2
from .modes.m4_key_project import reconstruct_group_m4
from .modes.m3_escape import decode_escape_payload
from .modes.turbo3 import dequantize_group_turbo3
from .page_format import load_group_words
from .packing import unpack_bits
from .types import EncodedPage
def decode_group_ref(page: EncodedPage, group_index: int) -> np.ndarray:
page.record_group_decode()
header = page.header
if header.mode_default == "M3":
if page.escape_payload is None:
raise ValueError("escape payload is missing")
start = group_index * header.group_size
end = start + header.group_size
return decode_escape_payload(page.escape_payload, scales=page.escape_scales)[:, start:end]
if header.mode_default == "M2":
if page.m2_sketch is None or page.m2_basis is None:
raise ValueError("M2 page is missing sketch payload")
return reconstruct_group_m2(
page.m2_sketch[:, group_index, :],
basis=page.m2_basis[group_index],
mean=None if page.m2_mean is None else page.m2_mean[group_index],
)
if header.mode_default == "M4":
if page.m2_sketch is None or page.m2_mean is None:
raise ValueError("M4 page is missing projected payload")
return reconstruct_group_m4(
page.m2_sketch[:, group_index, :],
mean=page.m2_mean[group_index],
group_size=header.group_size,
basis_family=header.project_basis,
basis=None if page.m2_basis is None else page.m2_basis[group_index],
)
words = load_group_words(page, group_index)
codes = unpack_bits(words, header.bits, header.group_size)
if header.mode_default == "M1":
if page.codebooks is None:
raise ValueError("M1 page is missing codebooks")
codebook = np.asarray(page.codebooks[group_index], dtype=np.float32)
return dequantize_group_lut(codes, codebook=codebook)
if header.mode_default == "T3":
if page.scales is None or page.codebooks is None:
raise ValueError("T3 page is missing correction metadata")
correction = page.scales[:, group_index].astype(np.float32)
return dequantize_group_turbo3(
codes,
correction=correction,
centroids=np.asarray(page.codebooks, dtype=np.float32),
)
if page.payload is None or page.scales is None:
raise ValueError("M0 page is missing payload or scales")
scales = page.scales[:, group_index].astype(np.float32)[:, None]
bias = None
if page.bias is not None:
bias = page.bias[:, group_index].astype(np.float32)[:, None]
return dequantize_group(
codes,
scales=scales,
bias=bias,
bits=header.bits,
scheme=header.quant_scheme,
)
def decode_page(page: EncodedPage) -> np.ndarray:
page.record_full_decode()
header = page.header
if header.mode_default == "M3":
if page.escape_payload is None:
raise ValueError("escape payload is missing")
return decode_escape_payload(page.escape_payload, head_dim=header.head_dim, scales=page.escape_scales)
groups = [decode_group_ref(page, group_index) for group_index in range(header.num_groups)]
full = np.concatenate(groups, axis=-1)
return full[:, : header.head_dim]