Spaces:
Paused
Paused
File size: 3,544 Bytes
751ad26 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 | from __future__ import annotations
import numpy as np
from .modes.m0_affine import dequantize_group
from .modes.m1_lut import dequantize_group_lut
from .modes.m2_key_sketch import reconstruct_group_m2
from .modes.m4_key_project import reconstruct_group_m4
from .modes.m3_escape import decode_escape_payload
from .modes.turbo3 import dequantize_group_turbo3
from .page_format import load_group_words
from .packing import unpack_bits
from .types import EncodedPage
def _unpack_group_codes(page: EncodedPage, group_index: int):
    """Load the packed words of one group and unpack them into integer codes."""
    header = page.header
    words = load_group_words(page, group_index)
    return unpack_bits(words, header.bits, header.group_size)


def decode_group_ref(page: EncodedPage, group_index: int) -> np.ndarray:
    """Decode a single group of an encoded page (reference implementation).

    Dispatch is on ``page.header.mode_default``:

    * ``"M3"`` - decode the whole escape payload with ``page.escape_scales`` and
      return the column window ``[group_index * group_size,
      (group_index + 1) * group_size)`` of the result.
    * ``"M2"`` - reconstruct from the group's slice of ``page.m2_sketch`` using
      the group's basis and (optional) mean.
    * ``"M4"`` - reconstruct from the group's slice of ``page.m2_sketch`` using
      the group's mean, ``header.project_basis`` and an optional per-group basis.
    * ``"M1"`` - unpack the group's bit-packed codes and dequantize them through
      the group's codebook.
    * ``"T3"`` - unpack the codes and dequantize with a correction taken from
      column ``group_index`` of ``page.scales`` and the whole ``page.codebooks``
      array as centroids.
    * anything else - treated as M0: unpack the codes and apply affine
      dequantization with the group's scales and optional bias.

    Args:
        page: The encoded page. ``page.record_group_decode()`` is called once per
            invocation, before any validation.
        group_index: Index of the group to decode.

    Returns:
        The decoded group, as produced by the mode-specific decoder.

    Raises:
        ValueError: If the page lacks the payload/metadata its mode requires.
    """
    page.record_group_decode()
    header = page.header
    mode = header.mode_default

    if mode == "M3":
        if page.escape_payload is None:
            raise ValueError("escape payload is missing")
        start = group_index * header.group_size
        end = start + header.group_size
        # NOTE(review): the full escape payload is decoded on every call and
        # only this group's column window is kept.
        decoded = decode_escape_payload(page.escape_payload, scales=page.escape_scales)
        return decoded[:, start:end]

    if mode == "M2":
        if page.m2_sketch is None or page.m2_basis is None:
            raise ValueError("M2 page is missing sketch payload")
        return reconstruct_group_m2(
            page.m2_sketch[:, group_index, :],
            basis=page.m2_basis[group_index],
            mean=None if page.m2_mean is None else page.m2_mean[group_index],
        )

    if mode == "M4":
        if page.m2_sketch is None or page.m2_mean is None:
            raise ValueError("M4 page is missing projected payload")
        return reconstruct_group_m4(
            page.m2_sketch[:, group_index, :],
            mean=page.m2_mean[group_index],
            group_size=header.group_size,
            basis_family=header.project_basis,
            basis=None if page.m2_basis is None else page.m2_basis[group_index],
        )

    # Bit-packed modes. Each branch validates its metadata *before* loading and
    # unpacking the group's words. A malformed page then fails fast with the
    # intended ValueError instead of failing (or doing wasted work) inside the
    # loading/unpacking helpers.
    if mode == "M1":
        if page.codebooks is None:
            raise ValueError("M1 page is missing codebooks")
        codes = _unpack_group_codes(page, group_index)
        codebook = np.asarray(page.codebooks[group_index], dtype=np.float32)
        return dequantize_group_lut(codes, codebook=codebook)

    if mode == "T3":
        if page.scales is None or page.codebooks is None:
            raise ValueError("T3 page is missing correction metadata")
        codes = _unpack_group_codes(page, group_index)
        correction = page.scales[:, group_index].astype(np.float32)
        return dequantize_group_turbo3(
            codes,
            correction=correction,
            # The whole codebooks array is passed (not indexed by group).
            centroids=np.asarray(page.codebooks, dtype=np.float32),
        )

    # Default / M0 affine path: any unrecognised mode also lands here.
    if page.payload is None or page.scales is None:
        raise ValueError("M0 page is missing payload or scales")
    codes = _unpack_group_codes(page, group_index)
    # Trailing axis added, presumably so per-row scales/bias broadcast over the
    # group's codes -- confirm against dequantize_group.
    scales = page.scales[:, group_index].astype(np.float32)[:, None]
    bias = None
    if page.bias is not None:
        bias = page.bias[:, group_index].astype(np.float32)[:, None]
    return dequantize_group(
        codes,
        scales=scales,
        bias=bias,
        bits=header.bits,
        scheme=header.quant_scheme,
    )
def decode_page(page: EncodedPage) -> np.ndarray:
    """Decode an entire encoded page into one array.

    ``page.record_full_decode()`` is called once up front.

    * M3 (escape) pages are decoded with a single ``decode_escape_payload`` call
      that receives ``header.head_dim`` and ``page.escape_scales``.
    * Every other mode decodes groups ``0 .. header.num_groups - 1`` one by one
      with ``decode_group_ref``, which also records a per-group decode.
      The groups are concatenated along the last axis, and axis 1 is truncated
      to ``header.head_dim`` (presumably dropping padding from the final group).

    Raises:
        ValueError: If an M3 page has no escape payload. Errors raised by
            ``decode_group_ref`` propagate unchanged. NOTE(review):
            ``np.concatenate`` also raises ValueError when ``num_groups`` is 0.
    """
    page.record_full_decode()
    hdr = page.header

    if hdr.mode_default != "M3":
        decoded_groups = [decode_group_ref(page, idx) for idx in range(hdr.num_groups)]
        stitched = np.concatenate(decoded_groups, axis=-1)
        return stitched[:, : hdr.head_dim]

    if page.escape_payload is None:
        raise ValueError("escape payload is missing")
    return decode_escape_payload(
        page.escape_payload,
        head_dim=hdr.head_dim,
        scales=page.escape_scales,
    )
|