CharlesCNorton
eval_all: hash-keyed result cache (--cache-dir, --no-cache); README: bit-ordering scope rules; docs/ISA.md: opcode reference and end-to-end tutorial; docs/float-pipeline.md: composition gap notes
597e7c2 | """ | |
| Unified evaluation harness for any threshold-computer variant. | |
| Drops the `--cpu-test` smoke test (which was hardcoded to 16-bit/64KB) and | |
| adds variant-aware sweep modes. The same harness handles every (data_bits, | |
| addr_bits) configuration: it reads the manifest from each safetensors file, | |
| runs the BatchedFitnessEvaluator at the right device, and reports per-file | |
| plus per-category results. | |
| Usage: | |
| python eval_all.py path/to/file.safetensors # one file | |
| python eval_all.py variants/ # every .safetensors in dir | |
| python eval_all.py --device cpu variants/ # CPU only (default) | |
| python eval_all.py --pop_size 32 variants/ # batched pop eval | |
| python eval_all.py --debug path/to/file.safetensors # per-circuit detail | |
| python eval_all.py --cpu-program PATH # also run an assembled program | |
| # through the threshold CPU | |
| # sized to the file's manifest | |
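| Exit code: | |
| 0 if all files PASS (fitness >= 0.9999) | |
| N where N is the number of files whose status is not PASS (FAIL or ERROR) | |
| 2 if no .safetensors files are found (same value as two non-passing files) | |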
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import torch | |
| from safetensors import safe_open | |
| # Reuse eval.py's evaluator (variant-aware) | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| from eval import ( | |
| BatchedFitnessEvaluator, | |
| create_population, | |
| load_model, | |
| get_manifest, | |
| heaviside, | |
| int_to_bits, | |
| bits_to_int, | |
| bits_msb_to_lsb, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Variant-aware threshold ALU + CPU | |
| # --------------------------------------------------------------------------- | |
class GenericThresholdALU:
    """Threshold-gate ALU evaluated directly from a tensor dictionary.

    Every gate is stored as ``<name>.weight`` / ``<name>.bias`` and evaluated
    as ``heaviside(sum(inputs * weight) + bias)``.  ``add8`` / ``sub8`` /
    ``cmp8`` / ``mul8`` are the fixed 8-bit entry points; the ``*_n`` methods
    take the width as an argument and address the ``...{N}bit`` gate families.

    NOTE(review): ``int_to_bits`` / ``bits_to_int`` come from eval.py and are
    presumably MSB-first (the adders reverse them into LSB-first order before
    rippling the carry) -- confirm against eval.py.
    """

    # Both spellings select the equality comparator at every width, so callers
    # do not have to know which gate layout a given width uses.
    _EQ_KINDS = ("eq", "equality")

    def __init__(self, tensors: Dict[str, torch.Tensor], data_bits: int):
        self.T = tensors
        # Stored for callers; no method in this class reads it.
        self.data_bits = data_bits

    def _g(self, name, inputs):
        """Evaluate the gate ``name`` on ``inputs`` and return a Python int."""
        w = self.T[name + ".weight"].view(-1)
        b = self.T[name + ".bias"].view(-1)
        x = torch.tensor(inputs, dtype=torch.float32)
        return int(heaviside((x * w).sum() + b).item())

    def _xor_or_nand(self, prefix, inputs):
        """Two-layer gate: ``layer2(layer1.or(a, b), layer1.nand(a, b))``."""
        a, b_ = inputs
        h_or = self._g(f"{prefix}.layer1.or", [a, b_])
        h_nand = self._g(f"{prefix}.layer1.nand", [a, b_])
        return self._g(f"{prefix}.layer2", [h_or, h_nand])

    def _fa(self, prefix, a, b, cin):
        """Full adder built from two half adders; returns ``(sum, carry_out)``."""
        s1 = self._xor_or_nand(f"{prefix}.ha1.sum", [a, b])
        c1 = self._g(f"{prefix}.ha1.carry", [a, b])
        s2 = self._xor_or_nand(f"{prefix}.ha2.sum", [s1, cin])
        c2 = self._g(f"{prefix}.ha2.carry", [s1, cin])
        cout = self._g(f"{prefix}.carry_or", [c1, c2])
        return s2, cout

    def add8(self, a, b):
        """8-bit ripple-carry add; returns ``(sum, carry_out)``."""
        return self.add_n(a, b, 8)

    def sub8(self, a, b):
        """8-bit subtract via ``arithmetic.sub8bit``; returns ``(diff, carry_out)``.

        Computes ``a + NOT(b) + 1``.  The 8-bit subtractor has its own gate
        naming (``notb{i}``, ``fa{i}.xor1`` ...), so it cannot share ``sub_n``.
        """
        a_lsb = list(reversed(int_to_bits(a, 8)))
        b_lsb = list(reversed(int_to_bits(b, 8)))
        carry = 1  # carry-in of 1 supplies the "+1" of the negation
        d_lsb = []
        for i in range(8):
            notb = self._g(f"arithmetic.sub8bit.notb{i}", [b_lsb[i]])
            x1 = self._xor_or_nand(f"arithmetic.sub8bit.fa{i}.xor1", [a_lsb[i], notb])
            x2 = self._xor_or_nand(f"arithmetic.sub8bit.fa{i}.xor2", [x1, carry])
            and1 = self._g(f"arithmetic.sub8bit.fa{i}.and1", [a_lsb[i], notb])
            and2 = self._g(f"arithmetic.sub8bit.fa{i}.and2", [x1, carry])
            carry = self._g(f"arithmetic.sub8bit.fa{i}.or_carry", [and1, and2])
            d_lsb.append(x2)
        return bits_to_int(list(reversed(d_lsb))), carry

    def cmp8(self, a, b, kind):
        """8-bit comparison; ``kind`` is interpreted exactly as in ``cmp_n``."""
        return self.cmp_n(a, b, kind, 8)

    def mul8(self, a, b):
        """8-bit shift-add multiply; returns the low 8 bits of the product."""
        return self.mul_n(a, b, 8)

    # ----- N-bit primitives (for 16-bit and 32-bit variants) ----------------

    def add_n(self, a: int, b: int, bits: int):
        """Ripple-carry add via ``arithmetic.ripplecarry{bits}bit``.

        Returns ``(sum, carry_out)``.
        """
        prefix = f"arithmetic.ripplecarry{bits}bit"
        a_lsb = list(reversed(int_to_bits(a, bits)))
        b_lsb = list(reversed(int_to_bits(b, bits)))
        carry = 0
        s_lsb = []
        for i in range(bits):
            s, carry = self._fa(f"{prefix}.fa{i}", a_lsb[i], b_lsb[i], carry)
            s_lsb.append(s)
        return bits_to_int(list(reversed(s_lsb))), carry

    def sub_n(self, a: int, b: int, bits: int):
        """Two's-complement subtract via ``arithmetic.sub{bits}bit``.

        Uses one ``not_b.bit{i}`` gate per bit plus the standard full adders
        with carry-in 1.  Returns ``(difference, carry_out)``.  The 8-bit
        subtractor uses a different gate layout; call ``sub8`` for that width.
        """
        prefix = f"arithmetic.sub{bits}bit"
        a_lsb = list(reversed(int_to_bits(a, bits)))
        b_lsb = list(reversed(int_to_bits(b, bits)))
        notb = [self._g(f"{prefix}.not_b.bit{i}", [b_lsb[i]]) for i in range(bits)]
        carry = 1  # carry-in = 1 for two's-complement
        d_lsb = []
        for i in range(bits):
            s, carry = self._fa(f"{prefix}.fa{i}", a_lsb[i], notb[i], carry)
            d_lsb.append(s)
        return bits_to_int(list(reversed(d_lsb))), carry

    def cmp_n(self, a: int, b: int, kind: str, bits: int):
        """Compare ``a`` and ``b`` and return the comparator output (0 or 1).

        ``kind`` is ``"eq"`` or ``"equality"`` for equality at any width.
        For ``bits <= 16`` any other value names a single-layer gate
        ``arithmetic.{kind}{bits}bit``; for wider values only
        ``"greaterthan"`` and ``"lessthan"`` are supported and the comparison
        is cascaded byte by byte.

        Raises:
            ValueError: unsupported ``kind`` on the cascaded (> 16-bit) path.
        """
        wants_eq = kind in self._EQ_KINDS
        a_bits = int_to_bits(a, bits)
        b_bits = int_to_bits(b, bits)

        if bits <= 16:
            inp = a_bits + b_bits
            if wants_eq:
                h_geq = self._g(f"arithmetic.equality{bits}bit.layer1.geq", inp)
                h_leq = self._g(f"arithmetic.equality{bits}bit.layer1.leq", inp)
                return self._g(f"arithmetic.equality{bits}bit.layer2", [h_geq, h_leq])
            return self._g(f"arithmetic.{kind}{bits}bit", inp)

        # Wider than 16 bits: per-byte comparators feeding a cascade.
        prefix = f"arithmetic.cmp{bits}bit"
        num_bytes = bits // 8
        byte_gt, byte_lt, byte_eq = [], [], []
        for bn in range(num_bytes):
            pair = a_bits[bn * 8:(bn + 1) * 8] + b_bits[bn * 8:(bn + 1) * 8]
            byte_gt.append(self._g(f"{prefix}.byte{bn}.gt", pair))
            byte_lt.append(self._g(f"{prefix}.byte{bn}.lt", pair))
            geq = self._g(f"{prefix}.byte{bn}.eq.geq", pair)
            leq = self._g(f"{prefix}.byte{bn}.eq.leq", pair)
            byte_eq.append(self._g(f"{prefix}.byte{bn}.eq.and", [geq, leq]))

        if wants_eq:
            # One gate over all per-byte equality bits.
            # NOTE(review): the earlier comment gave weights all 1 and bias
            # -num_bytes, i.e. it fires only when every byte matches -- confirm.
            return self._g(f"arithmetic.equality{bits}bit", byte_eq)

        # Stage bn contributes only when all earlier bytes compared equal.
        cascade_gt = [byte_gt[0]]
        cascade_lt = [byte_lt[0]]
        for bn in range(1, num_bytes):
            all_eq = self._g(f"{prefix}.cascade.gt.stage{bn}.all_eq", byte_eq[:bn])
            cascade_gt.append(self._g(f"{prefix}.cascade.gt.stage{bn}.and", [all_eq, byte_gt[bn]]))
            all_eq2 = self._g(f"{prefix}.cascade.lt.stage{bn}.all_eq", byte_eq[:bn])
            cascade_lt.append(self._g(f"{prefix}.cascade.lt.stage{bn}.and", [all_eq2, byte_lt[bn]]))
        if kind == "greaterthan":
            return self._g(f"arithmetic.greaterthan{bits}bit", cascade_gt)
        if kind == "lessthan":
            return self._g(f"arithmetic.lessthan{bits}bit", cascade_lt)
        raise ValueError(f"unsupported cmp kind {kind} for bits={bits}")

    def mul_n(self, a: int, b: int, bits: int):
        """Shift-add multiply returning only the low ``bits`` bits.

        Partial products come from ``alu.alu{bits}bit.mul.pp.a{i}b{j}`` gates
        and are accumulated with ``add_n``; rows whose multiplier bit is 0 are
        skipped without evaluating any gate.
        """
        ab = int_to_bits(a, bits)
        bb = int_to_bits(b, bits)
        mask = (1 << bits) - 1
        result = 0
        for j in range(bits):
            if bb[j] == 0:
                continue
            row = 0
            for i in range(bits):
                pp = self._g(f"alu.alu{bits}bit.mul.pp.a{i}b{j}", [ab[i], bb[j]])
                row |= (pp << (bits - 1 - i))
            shift = (bits - 1) - j
            result, _ = self.add_n(result & mask, (row << shift) & mask, bits)
        return result & mask
class GenericThresholdCPU:
    """CPU runtime whose memory and ALU are threshold circuits read from tensors.

    Address width and memory size come from the variant's manifest; register
    and ALU operations are always 8 bits wide.

    Machine state is a dict with keys ``pc``, ``regs``, ``flags`` (order
    ``[z, n, c, v]``), ``mem`` (list of byte values), ``halted`` and, once a
    CALL has executed, ``sp``.
    """

    def __init__(self, tensors: Dict[str, torch.Tensor]):
        manifest = get_manifest(tensors)
        self.T = tensors
        self.data_bits = manifest["data_bits"]
        self.addr_bits = manifest["addr_bits"]
        self.mem_bytes = manifest["memory_bytes"]
        # The 8-bit CPU primitives (ripplecarry8bit, sub8bit, alu.alu8bit.*,
        # memory.*, control.*) are present in every variant regardless of the
        # manifest's data_bits; wider variants only add standalone ALU gates.
        if self.mem_bytes == 0:
            raise NotImplementedError(
                "Pure-ALU variants have no memory; cannot run CPU programs"
            )
        self.alu = GenericThresholdALU(tensors, 8)

    @staticmethod
    def _memory_bit_rows(mem):
        """Return ``mem`` as a float tensor with one row of 8 bits per byte."""
        return torch.tensor([int_to_bits(byte, 8) for byte in mem], dtype=torch.float32)

    def _addr_decode(self, addr):
        """Run ``addr`` through ``memory.addr_decode``; one output per weight row."""
        addr_vec = torch.tensor(int_to_bits(addr, self.addr_bits), dtype=torch.float32)
        weight = self.T["memory.addr_decode.weight"]
        bias = self.T["memory.addr_decode.bias"]
        return heaviside((weight * addr_vec).sum(dim=1) + bias)

    def mem_read(self, mem, addr):
        """Read the byte at ``addr`` through the ``memory.read.*`` gates."""
        select = self._addr_decode(addr)
        cells = self._memory_bit_rows(mem)
        gate_and_w = self.T["memory.read.and.weight"]
        gate_and_b = self.T["memory.read.and.bias"]
        gate_or_w = self.T["memory.read.or.weight"]
        gate_or_b = self.T["memory.read.or.bias"]
        value_bits = []
        for bit in range(8):
            # AND every cell's bit with its select line, then OR across cells.
            pair = torch.stack([cells[:, bit], select], dim=1)
            gated = heaviside((pair * gate_and_w[bit]).sum(dim=1) + gate_and_b[bit])
            merged = heaviside((gated * gate_or_w[bit]).sum() + gate_or_b[bit])
            value_bits.append(int(merged.item()))
        return bits_to_int(value_bits)

    def mem_write(self, mem, addr, value):
        """Return a new memory list with ``value`` written at ``addr``.

        Each cell bit becomes ``(old AND NOT write_sel) OR (data AND write_sel)``
        evaluated with the ``memory.write.*`` gates; write-enable is held at 1.
        """
        select = self._addr_decode(addr)
        value_bits = torch.tensor(int_to_bits(value, 8), dtype=torch.float32)
        cells = self._memory_bit_rows(mem)

        T = self.T
        sel_w, sel_b = T["memory.write.sel.weight"], T["memory.write.sel.bias"]
        nsel_w = T["memory.write.nsel.weight"].squeeze(1)
        nsel_b = T["memory.write.nsel.bias"]
        keep_w, keep_b = T["memory.write.and_old.weight"], T["memory.write.and_old.bias"]
        load_w, load_b = T["memory.write.and_new.weight"], T["memory.write.and_new.bias"]
        or_w, or_b = T["memory.write.or.weight"], T["memory.write.or.bias"]

        write_enable = torch.ones_like(select)
        write_sel = heaviside((torch.stack([select, write_enable], dim=1) * sel_w).sum(dim=1) + sel_b)
        not_sel = heaviside(write_sel * nsel_w + nsel_b)

        for bit in range(8):
            incoming = value_bits[bit].expand(self.mem_bytes)
            keep_in = torch.stack([cells[:, bit], not_sel], dim=1)
            load_in = torch.stack([incoming, write_sel], dim=1)
            kept = heaviside((keep_in * keep_w[:, bit]).sum(dim=1) + keep_b[:, bit])
            loaded = heaviside((load_in * load_w[:, bit]).sum(dim=1) + load_b[:, bit])
            combined = torch.stack([kept, loaded], dim=1)
            cells[:, bit] = heaviside((combined * or_w[:, bit]).sum(dim=1) + or_b[:, bit])

        return [
            bits_to_int([int(v) for v in cells[index].tolist()])
            for index in range(self.mem_bytes)
        ]

    def step(self, state):
        """Execute one instruction and return the resulting state.

        A halted state is returned unchanged (same object).  Otherwise the
        input is not mutated: ``mem``, ``regs`` and ``flags`` are copied first.

        Instructions are 16-bit big-endian words: opcode in bits 15..12, ``rd``
        in 11..10, ``rs`` in 9..8, ``imm`` in 7..0.  Opcodes 0xA..0xE are
        followed by a second 16-bit word holding an address.
        """
        if state["halted"]:
            return state

        nxt = dict(state)
        for field in ("mem", "regs", "flags"):
            nxt[field] = state[field][:]

        addr_mask = (1 << self.addr_bits) - 1
        pc = nxt["pc"]

        # Fetch and decode.
        hi = self.mem_read(nxt["mem"], pc & addr_mask)
        lo = self.mem_read(nxt["mem"], (pc + 1) & addr_mask)
        ir = ((hi & 0xFF) << 8) | (lo & 0xFF)
        opcode = (ir >> 12) & 0xF
        rd = (ir >> 10) & 0x3
        rs = (ir >> 8) & 0x3
        imm = ir & 0xFF
        next_pc = (pc + 2) & addr_mask

        addr = None
        if 0xA <= opcode <= 0xE:
            ah = self.mem_read(nxt["mem"], next_pc)
            al = self.mem_read(nxt["mem"], (next_pc + 1) & addr_mask)
            addr = (((ah & 0xFF) << 8) | (al & 0xFF)) & addr_mask
            next_pc = (next_pc + 2) & addr_mask

        a = nxt["regs"][rd]
        b = nxt["regs"][rs]

        # Control-flow opcodes set the PC themselves and return early.
        if opcode == 0xF:  # halt; PC is left where it was
            nxt["halted"] = True
            return nxt
        if opcode == 0xC:  # unconditional jump
            nxt["pc"] = addr
            return nxt
        if opcode == 0xD:  # conditional jump; imm[2:0] selects flag and polarity
            cond = imm & 0x7
            z, n, c, v = nxt["flags"]
            selected = (z, c, n, v)[cond >> 1]
            wanted = 1 - (cond & 1)  # even codes test "flag set", odd "flag clear"
            nxt["pc"] = addr if selected == wanted else next_pc
            return nxt
        if opcode == 0xE:  # CALL: push return address (high byte first), jump
            ret_addr = next_pc & 0xFFFF
            sp = (nxt.get("sp", addr_mask) - 1) & addr_mask
            nxt["mem"] = self.mem_write(nxt["mem"], sp, (ret_addr >> 8) & 0xFF)
            sp = (sp - 1) & addr_mask
            nxt["mem"] = self.mem_write(nxt["mem"], sp, ret_addr & 0xFF)
            nxt["sp"] = sp
            nxt["pc"] = addr
            return nxt

        if opcode == 0x9:  # compare: flags from a - b, no register write
            diff, borrow = self.alu.sub8(a, b)
            nxt["flags"] = [
                1 if diff == 0 else 0,
                1 if (diff & 0x80) else 0,
                borrow,
                1 if (((a ^ b) & (a ^ diff)) & 0x80) else 0,
            ]
        elif opcode == 0xB:  # store R[rs] at addr
            nxt["mem"] = self.mem_write(nxt["mem"], addr, b & 0xFF)
        else:
            result = a
            carry = 0
            overflow = 0
            if opcode == 0x0:  # add
                result, carry = self.alu.add8(a, b)
                overflow = 1 if (((a ^ result) & (b ^ result)) & 0x80) else 0
            elif opcode == 0x1:  # subtract
                result, carry = self.alu.sub8(a, b)
                overflow = 1 if (((a ^ b) & (a ^ result)) & 0x80) else 0
            elif opcode == 0x2:  # AND
                result = a & b
            elif opcode == 0x3:  # OR
                result = a | b
            elif opcode == 0x4:  # XOR
                result = a ^ b
            elif opcode == 0x5:  # SHL by 1 (8-bit); carry is not stored in flags
                result = (a << 1) & 0xFF
                carry = 1 if (a & 0x80) else 0
            elif opcode == 0x6:  # SHR by 1; carry is not stored in flags
                result = a >> 1
                carry = a & 0x1
            elif opcode == 0x7:  # multiply (low 8 bits)
                result = self.alu.mul8(a, b)
            elif opcode == 0x8:  # DIV: R[d] = R[d] / R[s]; 0xFF on divide by zero
                result = 0xFF if b == 0 else (a // b)
            elif opcode == 0xA:  # load from addr
                result = self.mem_read(nxt["mem"], addr)

            nxt["regs"][rd] = result & 0xFF
            # Only add, subtract and multiply update the flags.
            if opcode in (0x0, 0x1, 0x7):
                nxt["flags"] = [
                    1 if (result & 0xFF) == 0 else 0,
                    1 if (result & 0x80) else 0,
                    carry,
                    overflow,
                ]

        nxt["pc"] = next_pc
        return nxt

    def run(self, state, max_cycles=200):
        """Step until halted or ``max_cycles`` reached; return ``(state, cycles)``."""
        current = state
        executed = 0
        while not current["halted"] and executed < max_cycles:
            current = self.step(current)
            executed += 1
        return current, executed
| def _encode_instr(opcode, rd, rs, imm): | |
| return ((opcode & 0xF) << 12) | ((rd & 0x3) << 10) | ((rs & 0x3) << 8) | (imm & 0xFF) | |
| def _w16(mem, addr, value): | |
| mem[addr] = (value >> 8) & 0xFF | |
| mem[addr + 1] = value & 0xFF | |
PROGRAM_MIN_BYTES = 0x84  # code 0x00..0x1F + data 0x80..0x83


def builtin_program(addr_bits: int) -> Tuple[List[int], int]:
    """Build the built-in test program that sums 5+4+3+2+1 with a loop.

    Returns ``(mem, expected)``: a zero-filled memory image of
    ``1 << addr_bits`` bytes holding the program, and the value (15) the
    program should leave at address 0x83.

    Layout: sixteen 16-bit words of code at 0x00..0x1F and four data bytes at
    0x80..0x83, so ``addr_bits`` must be at least 8.

    Raises:
        ValueError: the address space is smaller than ``PROGRAM_MIN_BYTES``.
    """
    mem_size = 1 << addr_bits
    if mem_size < PROGRAM_MIN_BYTES:
        raise ValueError(f"addr_bits={addr_bits} too small for builtin program")

    mem = [0] * mem_size
    # Data: 0x80 initial counter, 0x81 decrement, 0x82 zero (compare operand
    # and accumulator init).  0x83 receives the output.
    for offset, byte in enumerate((5, 1, 0)):
        mem[0x80 + offset] = byte

    # Code, one 16-bit word per entry, laid out contiguously from 0x0000.
    words = (
        _encode_instr(0xA, 1, 0, 0), 0x0080,      # 0x00: R1 <- [0x80]
        _encode_instr(0xA, 2, 0, 0), 0x0081,      # 0x04: R2 <- [0x81]
        _encode_instr(0xA, 3, 0, 0), 0x0082,      # 0x08: R3 <- [0x82]
        _encode_instr(0xA, 0, 0, 0), 0x0082,      # 0x0C: R0 <- [0x82]
        _encode_instr(0x0, 0, 1, 0),              # 0x10: R0 += R1
        _encode_instr(0x1, 1, 2, 0),              # 0x12: R1 -= R2
        _encode_instr(0x9, 1, 3, 0),              # 0x14: compare R1, R3
        _encode_instr(0xD, 0, 0, 0x01), 0x0010,   # 0x16: if Z clear goto 0x10
        _encode_instr(0xB, 0, 0, 0), 0x0083,      # 0x1A: [0x83] <- R0
        _encode_instr(0xF, 0, 0, 0),              # 0x1E: halt
    )
    for index, word in enumerate(words):
        _w16(mem, 2 * index, word)
    return mem, 15
| # --------------------------------------------------------------------------- | |
| # Eval driver | |
| # --------------------------------------------------------------------------- | |
| def _file_fingerprint(path: Path) -> str: | |
| """Stable cache key for a safetensors file: sha256 of its content. | |
| Hashes are content-addressed so renaming a file doesn't blow the cache, | |
| but mtime-only would re-key on every clone of the repo. The sha256 of a | |
| 30 MB safetensors finishes in tens of milliseconds — small compared to | |
| a 5,900-test fitness run. | |
| """ | |
| import hashlib | |
| h = hashlib.sha256() | |
| with open(path, "rb") as f: | |
| for chunk in iter(lambda: f.read(1 << 20), b""): | |
| h.update(chunk) | |
| return h.hexdigest() | |
def _cache_key(path: Path, opts: Dict[str, Any]) -> str:
    """Build the cache key ``<file sha256>_<8 hex chars of the options hash>``.

    The options are serialised as key-sorted JSON before hashing, so the key
    does not depend on dict insertion order.
    """
    import hashlib

    content_hash = _file_fingerprint(path)
    serialized = json.dumps(opts, sort_keys=True).encode("utf-8")
    opts_tag = hashlib.sha256(serialized).hexdigest()[:8]
    return f"{content_hash}_{opts_tag}"
| def _load_cache(cache_dir: Path, key: str) -> Dict[str, Any] | None: | |
| p = cache_dir / f"{key}.json" | |
| if not p.exists(): | |
| return None | |
| try: | |
| return json.loads(p.read_text(encoding="utf-8")) | |
| except (json.JSONDecodeError, OSError): | |
| return None | |
| def _save_cache(cache_dir: Path, key: str, payload: Dict[str, Any]) -> None: | |
| cache_dir.mkdir(parents=True, exist_ok=True) | |
| p = cache_dir / f"{key}.json" | |
| try: | |
| p.write_text(json.dumps(payload, indent=2, default=str), encoding="utf-8") | |
| except OSError: | |
| pass | |
def list_safetensors(path: Path) -> List[Path]:
    """Resolve ``path`` to the list of files to evaluate.

    A file is returned as a one-element list whatever its suffix; a directory
    yields its direct ``*.safetensors`` children (regular files only) in
    sorted order; anything else yields an empty list.
    """
    if path.is_file():
        return [path]
    if not path.is_dir():
        return []
    matches = [child for child in path.glob("*.safetensors") if child.is_file()]
    matches.sort()
    return matches
def _expect_equal(label, got, want):
    """Raise AssertionError naming the chain step when ``got != want``.

    An explicit raise is used instead of an ``assert`` statement so the check
    still runs when Python is started with ``-O``.
    """
    if got != want:
        raise AssertionError(f"{label}: got {got}, expected {want}")


def _run_alu_chain(alu, bits):
    """Run a short dependent chain of wide ALU operations; raise on mismatch."""
    if bits == 16:
        x, y = 1234, 5678
        z, _ = alu.add_n(x, y, 16)
        _expect_equal("add16", z, (x + y) & 0xFFFF)
        w, _ = alu.sub_n(z, x, 16)
        _expect_equal("sub16", w, (z - x) & 0xFFFF)
        _expect_equal("greaterthan16", alu.cmp_n(z, x, "greaterthan", 16), 1)
        _expect_equal("lessthan16", alu.cmp_n(x, z, "lessthan", 16), 1)
        _expect_equal("eq16", alu.cmp_n(w, y, "eq", 16), 1)
        _expect_equal("mul16", alu.mul_n(123, 5, 16), (123 * 5) & 0xFFFF)
    else:  # 32
        x, y = 1_000_000, 999_000
        z, _ = alu.sub_n(x, y, 32)
        _expect_equal("sub32", z, 1_000)
        s, _ = alu.add_n(z, x, 32)
        _expect_equal("add32", s, 1_001_000)
        p = alu.mul_n(z, 100, 32)
        _expect_equal("mul32", p, 100_000)
        _expect_equal("greaterthan32", alu.cmp_n(x, y, "greaterthan", 32), 1)
        _expect_equal("lessthan32", alu.cmp_n(y, x, "lessthan", 32), 1)
        _expect_equal("equality32", alu.cmp_n(p, 100_000, "equality", 32), 1)


def _run_builtin_cpu_program(tensors, addr_bits):
    """Run the built-in program on the threshold CPU and return its report."""
    cpu_tensors = {k: v.cpu() for k, v in tensors.items()}
    cpu = GenericThresholdCPU(cpu_tensors)
    mem, expected = builtin_program(addr_bits)
    state = {"pc": 0, "regs": [0] * 4, "flags": [0] * 4, "mem": mem, "halted": False}
    t0 = time.perf_counter()
    final, cycles = cpu.run(state, max_cycles=200)
    cpu_elapsed = time.perf_counter() - t0
    got = final["mem"][0x83]
    return {
        "ok": got == expected,
        "got": got,
        "expected": expected,
        "cycles": cycles,
        "elapsed_s": cpu_elapsed,
    }


def evaluate_one(path: Path, device: str, pop_size: int, debug: bool, run_cpu_program: bool) -> Dict:
    """Evaluate one safetensors file and return a JSON-friendly result dict.

    The result always has ``path``, ``filename`` and ``status`` (``PASS``,
    ``FAIL`` or ``ERROR``).  A load or fitness-evaluation exception yields
    ``ERROR`` plus an ``error`` message.  Otherwise the dict carries the
    manifest, size and parameter counts, fitness, per-category scores and:

    * ``cpu_program`` when ``run_cpu_program`` is set: the built-in program
      report, or a ``skipped`` note when the variant's memory is smaller than
      ``PROGRAM_MIN_BYTES``;
    * ``alu_chain_{16,32}`` for variants whose manifest ``data_bits`` is 16
      or 32.

    Any sub-test that mismatches *or raises* downgrades the status to
    ``FAIL``, so a crashing CPU program can no longer be reported as a pass.
    """
    out: Dict = {"path": str(path), "filename": path.name}
    try:
        tensors = load_model(str(path))
    except Exception as e:
        out.update(error=f"load failed: {e}", status="ERROR")
        return out

    manifest = get_manifest(tensors)
    out.update(
        size_mb=path.stat().st_size / (1024 * 1024),
        tensors=len(tensors),
        params=sum(t.numel() for t in tensors.values()),
        manifest=manifest,
    )

    # Move to device
    tensors = {k: v.to(device) for k, v in tensors.items()}
    try:
        evaluator = BatchedFitnessEvaluator(device=device, model_path=str(path), tensors=tensors)
        population = create_population(tensors, pop_size=pop_size, device=device)
        t0 = time.perf_counter()
        fitness = evaluator.evaluate(population, debug=debug)
        elapsed = time.perf_counter() - t0
        f0 = float(fitness[0].item()) if pop_size == 1 else float(fitness.mean().item())
        out.update(
            fitness=f0,
            total_tests=evaluator.total_tests,
            elapsed_s=elapsed,
            categories={k: (float(v[0]), int(v[1])) for k, v in evaluator.category_scores.items()},
            status="PASS" if f0 >= 0.9999 else "FAIL",
        )
    except Exception as e:
        out.update(error=f"eval failed: {type(e).__name__}: {e}", status="ERROR")
        return out

    # Optional: CPU program test (8-bit CPU primitives are in every variant)
    if run_cpu_program:
        if manifest["memory_bytes"] >= PROGRAM_MIN_BYTES:
            try:
                report = _run_builtin_cpu_program(tensors, manifest["addr_bits"])
            except Exception as e:
                # A crash is a failure of the requested test, same as a mismatch.
                out["cpu_program"] = {"error": str(e)}
                out["status"] = "FAIL"
            else:
                out["cpu_program"] = report
                if not report["ok"]:
                    out["status"] = "FAIL"
        else:
            out["cpu_program"] = {"skipped": f"mem={manifest['memory_bytes']}B < {PROGRAM_MIN_BYTES}"}

    # Wider-ALU chain test for 16/32-bit variants
    bits = manifest["data_bits"]
    if bits in (16, 32):
        chain_key = f"alu_chain_{bits}"
        try:
            alu_tensors = {k: v.cpu() for k, v in tensors.items()}
            alu = GenericThresholdALU(alu_tensors, bits)
            t0 = time.perf_counter()
            _run_alu_chain(alu, bits)
            chain_dt = time.perf_counter() - t0
            out[chain_key] = {"ok": True, "elapsed_s": chain_dt}
        except AssertionError as e:
            out[chain_key] = {"ok": False, "error": f"chain mismatch: {e}"}
            out["status"] = "FAIL"
        except Exception as e:
            out[chain_key] = {"ok": False, "error": f"{type(e).__name__}: {e}"}
            out["status"] = "FAIL"
    return out
def _cpu_column(report):
    """Format the CPU-program report as a table column (leading space included)."""
    if report.get("ok"):
        return f" CPU OK ({report['cycles']}cyc/{report['elapsed_s']:.1f}s)"
    if "skipped" in report:
        return " CPU SKIP"
    if "error" in report:
        return " CPU ERR"
    return f" CPU FAIL ({report.get('got')}!={report.get('expected')})"


def _chain_column(r):
    """Format the ALU-chain report; if both widths are present, 32-bit wins."""
    column = ""
    for bits in (16, 32):
        key = f"alu_chain_{bits}"
        if key not in r:
            continue
        report = r[key]
        if report.get("ok"):
            column = f" ALU{bits} OK ({report['elapsed_s']:.2f}s)"
        else:
            column = f" ALU{bits} FAIL"
    return column


def print_row(r: Dict, show_cpu: bool) -> None:
    """Print one table row for a result dict produced by ``evaluate_one``.

    Results carrying a top-level ``error`` print only the filename and the
    first 80 characters of the message.  The CPU-program and ALU-chain
    columns are appended only when ``show_cpu`` is true.
    """
    name = r["filename"]
    if "error" in r:
        print(f" {name:<48} ERROR: {r['error'][:80]}")
        return

    m = r["manifest"]
    fitness = r.get("fitness")
    fit = "n/a" if fitness is None else f"{fitness:.4f}"

    cpu_col = ""
    chain_col = ""
    if show_cpu:
        if "cpu_program" in r:
            cpu_col = _cpu_column(r["cpu_program"])
        chain_col = _chain_column(r)

    print(
        f" {name:<48} d={m['data_bits']:>2}b a={m['addr_bits']:>2}b "
        f"mem={m['memory_bytes']:>6}B size={r['size_mb']:>6.1f}MB "
        f"params={r['params']:>10,} fit={fit:>6} tests={r['total_tests']:>5} "
        f"{r['status']:>5}{cpu_col}{chain_col}"
    )
def _evaluate_files(files, args, cache_dir):
    """Evaluate every file, consulting and filling the result cache.

    ``cache_dir`` is ``None`` when caching is disabled.  With ``--debug`` the
    cache is not read, because a cache hit would skip the evaluation whose
    per-circuit output the user asked for; fresh results are still stored.

    Returns ``(results, cache_hits)``.
    """
    cache_opts = {
        "device": args.device,
        "pop_size": args.pop_size,
        "cpu_program": bool(args.cpu_program),
    }
    results = []
    cache_hits = 0
    for f in files:
        print(f"=== {f.name}")
        key = None
        cached = None
        if cache_dir is not None:
            try:
                key = _cache_key(f, cache_opts)
                if not args.debug:
                    cached = _load_cache(cache_dir, key)
            except OSError:
                cached = None
        if cached is not None:
            r = cached
            cache_hits += 1
            print(" (cache hit)")
        else:
            r = evaluate_one(f, device=args.device, pop_size=args.pop_size,
                             debug=args.debug, run_cpu_program=args.cpu_program)
            if key is not None:
                _save_cache(cache_dir, key, r)
        results.append(r)
        print_row(r, show_cpu=args.cpu_program)
    return results, cache_hits


def main() -> int:
    """Command-line entry point; returns the process exit code.

    Exit code: 0 when every file passes, otherwise the number of files whose
    status is not PASS, capped at 255 so a large count cannot wrap around to
    0 in the OS exit status.  2 is returned when no files are found.

    With ``--json`` only the JSON document is written to stdout; progress
    lines, table rows and any evaluator output go to stderr.
    """
    import contextlib

    parser = argparse.ArgumentParser(description="Variant-agnostic eval harness")
    parser.add_argument("path", help="Path to .safetensors file or directory of files")
    parser.add_argument("--device", default="cpu", help="cpu (default) or cuda")
    parser.add_argument("--pop_size", type=int, default=1)
    parser.add_argument("--debug", action="store_true", help="Per-circuit detail per file")
    parser.add_argument("--cpu-program", action="store_true",
                        help="Also run a small assembled program through the threshold CPU "
                             f"(skipped for variants with less than {PROGRAM_MIN_BYTES} B of memory)")
    parser.add_argument("--json", action="store_true", help="Emit JSON results to stdout instead of a table")
    parser.add_argument("--cache-dir", default=".eval_cache",
                        help="Directory for hash-keyed result cache "
                             "(default: ./.eval_cache). Set to '' to disable.")
    parser.add_argument("--no-cache", action="store_true",
                        help="Disable the result cache for this run.")
    args = parser.parse_args()

    files = list_safetensors(Path(args.path))
    if not files:
        print(f"No .safetensors files found under {args.path}", file=sys.stderr)
        return 2

    cache_enabled = bool(args.cache_dir) and not args.no_cache
    cache_dir = Path(args.cache_dir) if cache_enabled else None

    # In JSON mode stdout must contain the JSON document and nothing else.
    chatter = contextlib.redirect_stdout(sys.stderr) if args.json else contextlib.nullcontext()
    with chatter:
        print(f"Evaluating {len(files)} file(s) on {args.device}\n")
        results, cache_hits = _evaluate_files(files, args, cache_dir)

    fail_count = sum(1 for r in results if r.get("status") != "PASS")
    exit_code = min(fail_count, 255)

    if args.json:
        # Make it serialisable
        for r in results:
            r["manifest"] = {k: (int(v) if isinstance(v, float) and v.is_integer() else v)
                             for k, v in r.get("manifest", {}).items()}
        print(json.dumps(results, indent=2, default=str))
        return exit_code

    # Summary
    print()
    print("=" * 100)
    print(" SUMMARY")
    print("=" * 100)
    for r in results:
        print_row(r, show_cpu=args.cpu_program)
    print()
    if fail_count == 0:
        print(f"ALL {len(files)} variants PASS")
    else:
        print(f"{fail_count}/{len(files)} variants FAIL")
    if cache_enabled:
        print(f"(cache: {cache_hits}/{len(files)} hits, dir={cache_dir})")
    return exit_code
# Script entry point: the process exit status is main()'s return value.
if __name__ == "__main__":
    raise SystemExit(main())