"""Unit tests for proactive_cache."""
import pytest
import numpy as np
import torch
from proactive_cache.eviction import score_tokens, select_indices, prune_kv_cache, evict
from proactive_cache.prototypes import build_prototypes, save_prototypes, load_prototypes
from proactive_cache.utils import to_tuple_kv, to_dynamic_cache
# ── Fixtures ──────────────────────────────────────────────────────────────────
def make_dummy_patterns(num_docs=5, num_layers=2, num_heads=4, seq_len=64):
    """Build random, normalised attention patterns for use as test input.

    Returns a list with ``num_docs`` entries. Each entry is a dict keyed by
    ``(layer, head)`` whose value is a float32 vector of length ``seq_len``
    that has been divided by its own sum.
    """

    def _random_distribution():
        # Uniform random weights, rescaled so the vector sums to one.
        weights = np.random.rand(seq_len).astype(np.float32)
        return weights / weights.sum()

    return [
        {
            (layer, head): _random_distribution()
            for layer in range(num_layers)
            for head in range(num_heads)
        }
        for _ in range(num_docs)
    ]
def make_dummy_kv_cache(num_layers=2, num_heads=4, seq_len=64, head_dim=32, device="cpu"):
    """Build a random KV cache as a tuple of per-layer tensor pairs.

    The result has ``num_layers`` entries; each entry is a 2-tuple of tensors
    of shape ``(1, num_heads, seq_len, head_dim)`` allocated on ``device``.
    """
    shape = (1, num_heads, seq_len, head_dim)
    layers = []
    for _ in range(num_layers):
        first = torch.randn(*shape, device=device)
        second = torch.randn(*shape, device=device)
        layers.append((first, second))
    return tuple(layers)
# ── Eviction tests ────────────────────────────────────────────────────────────
class TestScoreTokens:
    """Checks on the per-position scores returned by ``score_tokens``."""

    def test_returns_correct_shape(self):
        """One score is returned per sequence position."""
        seq_len = 128
        result = score_tokens(None, seq_len=seq_len, budget=64)
        assert result.shape == (seq_len,)

    def test_token_zero_has_highest_score(self):
        """Position 0 ranks among the ``budget`` highest-scoring positions."""
        budget = 64
        result = score_tokens(None, seq_len=128, budget=budget)
        # The sink boost is expected to keep position 0 in the retained set.
        kept = np.argsort(result)[-budget:]
        assert 0 in kept, "Token 0 (attention sink) must always be selected"

    def test_recency_tokens_kept(self):
        """The final position ranks among the ``budget`` highest scores."""
        seq_len = 128
        budget = 64
        result = score_tokens(None, seq_len=seq_len, budget=budget)
        kept = np.argsort(result)[-budget:]
        newest = seq_len - 1
        assert newest in kept, "Most recent token must always be kept"

    def test_with_prototypes(self):
        """Scoring with built prototypes gives finite scores of full length."""
        seq_len = 64
        prototypes = build_prototypes(
            make_dummy_patterns(seq_len=seq_len), n_clusters=2, max_seq_len=seq_len
        )
        result = score_tokens(prototypes, seq_len=seq_len, budget=32)
        assert result.shape == (seq_len,)
        assert np.all(np.isfinite(result)), "Scores must be finite"

    def test_budget_proportional_recency(self):
        """Two different budgets both run and give scores of equal shape.

        This is only a smoke test: it does not inspect how the budget
        changes the score values.
        """
        small, large = (
            score_tokens(None, seq_len=512, budget=budget) for budget in (128, 256)
        )
        assert small.shape == large.shape == (512,)
class TestSelectIndices:
    """Checks on the index selection performed by ``select_indices``."""

    def test_returns_sorted(self):
        """Selected indices come back in ascending order."""
        scores = np.random.rand(100)
        idx = select_indices(scores, budget=20)
        # Compare as a plain list. If select_indices returns a NumPy array,
        # `idx == sorted(idx)` is elementwise and asserting on the resulting
        # array raises ValueError instead of reporting a pass or failure.
        assert list(idx) == sorted(idx), "Indices must be in ascending order"

    def test_correct_count(self):
        """Exactly ``budget`` indices are selected when enough are available."""
        scores = np.random.rand(100)
        idx = select_indices(scores, budget=30)
        assert len(idx) == 30

    def test_budget_larger_than_seq(self):
        """A budget above the number of scores is clipped to that number."""
        scores = np.random.rand(10)
        idx = select_indices(scores, budget=50)
        assert len(idx) == 10  # clipped to seq_len
class TestPruneKVCache:
    """Checks on ``prune_kv_cache`` and the under-budget path of ``evict``."""

    def test_prunes_to_budget(self):
        """Pruning leaves as many positions as there are kept indices."""
        kv = make_dummy_kv_cache(num_layers=3, num_heads=4, seq_len=128)
        indices = list(range(0, 64, 2))  # 32 indices
        pruned = prune_kv_cache(kv, indices, device=torch.device("cpu"))
        pruned_tuple = to_tuple_kv(pruned)
        # Dimension 2 is the sequence axis of the tensors built by
        # make_dummy_kv_cache.
        assert pruned_tuple[0][0].shape[2] == len(indices), "Pruned KV must have budget tokens"

    def test_all_layers_pruned(self):
        """Every layer survives pruning and every layer is actually pruned."""
        num_layers = 4
        kv = make_dummy_kv_cache(num_layers=num_layers, seq_len=100)
        indices = list(range(50))
        pruned_tuple = to_tuple_kv(prune_kv_cache(kv, indices, torch.device("cpu")))
        assert len(pruned_tuple) == num_layers
        # The layer count alone would also pass if pruning did nothing, so
        # check the sequence length of both tensors in every layer.
        for layer_idx, layer in enumerate(pruned_tuple):
            assert layer[0].shape[2] == len(indices), f"Layer {layer_idx} keys not pruned"
            assert layer[1].shape[2] == len(indices), f"Layer {layer_idx} values not pruned"

    def test_no_prune_when_under_budget(self):
        """``evict`` keeps the full sequence when it already fits the budget."""
        kv = make_dummy_kv_cache(seq_len=32)
        result = evict(kv, budget=64, prototypes=None, seq_len=32, device=torch.device("cpu"))
        # Should return unchanged (seq_len <= budget)
        assert to_tuple_kv(result)[0][0].shape[2] == 32
# ── Prototype tests ───────────────────────────────────────────────────────────
class TestPrototypes:
    """Checks on building, saving and loading prototypes."""

    def test_build_returns_dict(self):
        """``build_prototypes`` returns a non-empty dict."""
        result = build_prototypes(make_dummy_patterns(), n_clusters=2, max_seq_len=64)
        assert isinstance(result, dict)
        assert len(result) > 0

    def test_centroid_shapes(self):
        """Each entry's ``"centroids"`` array has shape (n_clusters, max_seq_len)."""
        n_clusters = 3
        seq_len = 64
        dummy = make_dummy_patterns(num_layers=2, num_heads=4, seq_len=seq_len)
        result = build_prototypes(dummy, n_clusters=n_clusters, max_seq_len=seq_len)
        for entry in result.values():
            centroids = entry["centroids"]
            assert centroids.shape == (n_clusters, seq_len), f"Wrong centroid shape: {centroids.shape}"

    def test_save_load_roundtrip(self, tmp_path):
        """Saving then loading prototypes preserves the set of keys."""
        original = build_prototypes(make_dummy_patterns(), n_clusters=2, max_seq_len=64)
        target = str(tmp_path / "test_protos.pkl")
        save_prototypes(original, target)
        restored = load_prototypes(target)
        assert set(restored.keys()) == set(original.keys())

    def test_load_missing_raises(self, tmp_path):
        """Loading from a path that does not exist raises FileNotFoundError."""
        missing = str(tmp_path / "does_not_exist.pkl")
        with pytest.raises(FileNotFoundError):
            load_prototypes(missing)

    def test_empty_patterns_raises(self):
        """Building prototypes from an empty pattern list raises ValueError."""
        with pytest.raises(ValueError):
            build_prototypes([], n_clusters=2)
# ── Utils tests ───────────────────────────────────────────────────────────────
class TestUtils:
    """Checks on the KV cache conversion helpers."""

    def test_to_tuple_kv_from_tuple(self):
        """A tuple-form cache converts to one tuple entry per layer."""
        converted = to_tuple_kv(make_dummy_kv_cache(num_layers=2))
        assert len(converted) == 2
        first_layer = converted[0]
        assert isinstance(first_layer, tuple)

    def test_to_dynamic_cache_roundtrip(self):
        """Converting tuple -> dynamic cache -> tuple preserves tensor shape."""
        source = to_tuple_kv(make_dummy_kv_cache(num_layers=2, seq_len=32))
        restored = to_tuple_kv(to_dynamic_cache(source))
        # Only the first tensor of the first layer is compared.
        assert restored[0][0].shape == source[0][0].shape