"""Run a davidclara/building-block-vectorization model from Hugging Face on a map image. Averages predictions across cross-validation folds with a Gaussian-weighted sliding window, thresholds with the ensemble threshold from config.json, and writes a single binary PNG mask. Example: python inference.py \\ --hf-repo davidclara/building-block-vectorization \\ --model-name unet_scse \\ --image map.jpg \\ --out mask.png """ import argparse import inspect import json from pathlib import Path import numpy as np import segmentation_models_pytorch as smp import torch from huggingface_hub import hf_hub_download from PIL import Image from safetensors.torch import load_file Image.MAX_IMAGE_PIXELS = None IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32) IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32) SMP = { "unet": smp.Unet, "unetpp": smp.UnetPlusPlus, "deeplabv3p": smp.DeepLabV3Plus, "fpn": smp.FPN, "pan": smp.PAN, } def build_model(model_cfg: dict) -> torch.nn.Module: cfg = dict(model_cfg) name = cfg.pop("name") cfg["classes"] = cfg.pop("num_classes") cfg["encoder_weights"] = None # weights come from safetensors cls = SMP[name] accepted = set(inspect.signature(cls).parameters) return cls(**{k: v for k, v in cfg.items() if k in accepted}) def gaussian_kernel(size: int, sigma_ratio: float = 0.125) -> np.ndarray: sigma = size * sigma_ratio ax = np.arange(size) - (size - 1) / 2.0 g1d = np.exp(-(ax**2) / (2 * sigma**2)) g2d = np.outer(g1d, g1d) return (g2d / g2d.max()).astype(np.float32) def sliding_window_ensemble(models, img, patch, stride, n_classes, device): _, H, W = img.shape probs_sum = np.zeros((n_classes, H, W), dtype=np.float32) weight_sum = np.zeros((H, W), dtype=np.float32) kernel = gaussian_kernel(patch) rows = sorted({*range(0, max(H - patch, 0) + 1, stride), max(H - patch, 0)}) cols = sorted({*range(0, max(W - patch, 0) + 1, stride), max(W - patch, 0)}) with torch.no_grad(): for r in rows: for c in cols: tile = img[:, r : r + patch, c : c + patch] ph, pw = patch - tile.shape[1], patch - tile.shape[2] if ph or pw: tile = np.pad(tile, ((0, 0), (0, ph), (0, pw))) x = torch.from_numpy(tile).unsqueeze(0).to(device) # average sigmoid(logits) across folds (matches predict.py) fold_probs = None for m in models: logits = m(x).cpu().numpy()[0] p = 1.0 / (1.0 + np.exp(-np.clip(logits, -88, 88))) fold_probs = p if fold_probs is None else fold_probs + p fold_probs = fold_probs / len(models) h, w = patch - ph, patch - pw g = kernel[:h, :w] probs_sum[:, r : r + h, c : c + w] += fold_probs[:, :h, :w] * g[None] weight_sum[r : r + h, c : c + w] += g return probs_sum / np.maximum(weight_sum, 1e-8) def main() -> None: ap = argparse.ArgumentParser() ap.add_argument("--hf-repo", required=True) ap.add_argument("--model-name", required=True) ap.add_argument("--image", required=True) ap.add_argument("--out", default="mask.png") args = ap.parse_args() cfg_path = hf_hub_download(args.hf_repo, f"{args.model_name}/config.json") cfg = json.loads(Path(cfg_path).read_text()) n_folds = cfg["n_folds"] fold_paths = [ hf_hub_download(args.hf_repo, f"{args.model_name}/model_f{i}.safetensors") for i in range(n_folds) ] device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu") classes = cfg["class_names"] patch = cfg["patch_size"] normalize_mode = cfg.get("normalize_mode", "imagenet") thrs = cfg.get("ensemble_thresholds") or {} thr = np.array([thrs.get(n, 0.5) for n in classes], dtype=np.float32).reshape(-1, 1, 1) img = np.asarray(Image.open(args.image).convert("RGB"), dtype=np.float32) / 255.0 if normalize_mode == "imagenet": img = (img - IMAGENET_MEAN) / IMAGENET_STD img = img.transpose(2, 0, 1) models = [] for wts_path in fold_paths: m = build_model(cfg["model"]).to(device).eval() m.load_state_dict(load_file(wts_path)) models.append(m) print(f"Loaded {n_folds} fold(s), device={device}, normalize_mode={normalize_mode}") probs = sliding_window_ensemble(models, img, patch, patch // 2, len(classes), device) binary = (probs > thr).astype(np.uint8)[0] Image.fromarray(binary * 255, "L").save(args.out) print(f"Wrote mask to {args.out}") if __name__ == "__main__": main()