| """Run a davidclara/building-block-vectorization model from Hugging Face on a map image. |
| |
Averages sigmoid probabilities across the cross-validation fold models with a
Gaussian-weighted sliding window, thresholds each class with its ensemble
threshold from config.json (0.5 when none is given), and writes the first
class as a single binary (0/255) PNG mask.
| |
| Example: |
| python inference.py \\ |
| --hf-repo davidclara/building-block-vectorization \\ |
| --model-name unet_scse \\ |
| --image map.jpg \\ |
| --out mask.png |
| """ |
|
|
| import argparse |
| import inspect |
| import json |
| from pathlib import Path |
|
|
| import numpy as np |
| import segmentation_models_pytorch as smp |
| import torch |
| from huggingface_hub import hf_hub_download |
| from PIL import Image |
| from safetensors.torch import load_file |
|
|
# Pillow's decompression-bomb guard is switched off so that very large map
# images can be opened without raising.
Image.MAX_IMAGE_PIXELS = None

# Per-channel RGB mean / std applied by main() when config.json requests
# "imagenet" normalization (pixels are first scaled to [0, 1]).
IMAGENET_MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
IMAGENET_STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)

# Architecture name (the "name" entry of config.json's "model" section)
# -> segmentation_models_pytorch model class.
SMP = dict(
    unet=smp.Unet,
    unetpp=smp.UnetPlusPlus,
    deeplabv3p=smp.DeepLabV3Plus,
    fpn=smp.FPN,
    pan=smp.PAN,
)
|
|
|
|
def build_model(model_cfg: dict) -> torch.nn.Module:
    """Instantiate an (untrained) segmentation network from a config section.

    Args:
        model_cfg: The "model" section of config.json. Must contain "name"
            (a key of ``SMP``) and "num_classes"; every other entry is passed
            to the architecture's constructor only if that constructor has a
            parameter of the same name. The caller's dict is not modified.

    Returns:
        A freshly constructed model built with ``encoder_weights=None``; the
        trained weights are expected to be loaded by the caller afterwards.

    Raises:
        KeyError: If "name" or "num_classes" is missing, or if "name" is not a
            supported architecture (the message lists the supported ones).
    """
    cfg = dict(model_cfg)  # shallow copy: the pops below must not touch the caller's dict
    name = cfg.pop("name")
    try:
        cls = SMP[name]
    except KeyError:
        raise KeyError(
            f"unsupported architecture {name!r}; expected one of {sorted(SMP)}"
        ) from None
    cfg["classes"] = cfg.pop("num_classes")
    # No pretrained encoder weights are requested; checkpoints are loaded afterwards.
    cfg["encoder_weights"] = None
    # Forward only the keys this constructor actually declares.
    accepted = set(inspect.signature(cls).parameters)
    return cls(**{k: v for k, v in cfg.items() if k in accepted})
|
|
|
|
def gaussian_kernel(size: int, sigma_ratio: float = 0.125) -> np.ndarray:
    """Return a ``(size, size)`` float32 Gaussian window whose peak value is 1.

    The window is the outer product of a 1-D Gaussian centred on the middle of
    ``size`` samples, with standard deviation ``size * sigma_ratio``.
    """
    half_span = (size - 1) / 2.0
    offsets = np.arange(size) - half_span
    two_sigma_sq = 2 * (size * sigma_ratio) ** 2
    profile = np.exp(-(offsets**2) / two_sigma_sq)
    # Separable 2-D window: outer product of the 1-D profile with itself.
    window = profile[:, None] * profile[None, :]
    return (window / window.max()).astype(np.float32)
|
|
|
|
def sliding_window_ensemble(models, img, patch, stride, n_classes, device):
    """Average sigmoid outputs of several models over a Gaussian-weighted sliding window.

    Args:
        models: Non-empty sequence of callables mapping a ``(1, C, patch, patch)``
            tensor to per-class logits; presumably shaped
            ``(1, n_classes, patch, patch)`` (TODO confirm for every config).
        img: ``(C, H, W)`` numpy array, already normalized; its dtype must
            match what the models expect (float32 as produced by ``main``).
        patch: Side length of the square window fed to the models.
        stride: Step between consecutive window origins; a final window flush
            with the bottom/right border is always added.
        n_classes: Number of output channels accumulated.
        device: torch device the models live on.

    Returns:
        ``(n_classes, H, W)`` float32 array of per-pixel probabilities: for
        each window the fold-mean sigmoid is blended into the image with a
        Gaussian weight, then divided by the accumulated weight.

    Raises:
        ValueError: If ``models`` is empty or ``patch`` / ``stride`` is not positive.
    """
    if not models:
        raise ValueError("sliding_window_ensemble needs at least one model")
    if patch <= 0 or stride <= 0:
        raise ValueError(
            f"patch and stride must be positive, got patch={patch}, stride={stride}"
        )
    _, H, W = img.shape
    probs_sum = np.zeros((n_classes, H, W), dtype=np.float32)
    weight_sum = np.zeros((H, W), dtype=np.float32)
    kernel = gaussian_kernel(patch)
    # Regular grid of origins plus one window flush with the far edge, so every
    # pixel is covered; when the image is smaller than `patch` only origin 0 remains.
    last_r, last_c = max(H - patch, 0), max(W - patch, 0)
    rows = sorted({*range(0, last_r + 1, stride), last_r})
    cols = sorted({*range(0, last_c + 1, stride), last_c})
    with torch.no_grad():
        for r in rows:
            for c in cols:
                tile = img[:, r : r + patch, c : c + patch]
                h, w = tile.shape[1], tile.shape[2]
                if h < patch or w < patch:
                    # Only happens when the image is smaller than the window:
                    # zero-pad bottom/right so the models always see patch x patch.
                    tile = np.pad(tile, ((0, 0), (0, patch - h), (0, patch - w)))
                x = torch.from_numpy(tile).unsqueeze(0).to(device)
                # Sigmoid and fold-averaging stay on the device: one host copy per tile.
                mean_probs = sum(torch.sigmoid(m(x)) for m in models) / len(models)
                mean_probs = mean_probs[0].cpu().numpy()
                g = kernel[:h, :w]
                probs_sum[:, r : r + h, c : c + w] += mean_probs[:, :h, :w] * g[None]
                weight_sum[r : r + h, c : c + w] += g
    return probs_sum / np.maximum(weight_sum, 1e-8)
|
|
|
|
def _pick_device() -> str:
    """Return "cuda" if available, else "mps" if available, else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _load_image(path: str, normalize_mode: str) -> np.ndarray:
    """Read an image as a float32 ``(3, H, W)`` array scaled to [0, 1], ImageNet-normalized if requested."""
    # The context manager closes the file handle Pillow keeps open after open().
    with Image.open(path) as im:
        img = np.asarray(im.convert("RGB"), dtype=np.float32) / 255.0
    if normalize_mode == "imagenet":
        img = (img - IMAGENET_MEAN) / IMAGENET_STD
    return img.transpose(2, 0, 1)


def main() -> None:
    """Download a fold ensemble from the Hub, segment one image, and save a PNG mask.

    The mask is the first class of ``class_names`` thresholded at its ensemble
    threshold (0.5 if absent), written as 0/255 grayscale.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--hf-repo", required=True, help="Hugging Face repo id")
    ap.add_argument(
        "--model-name",
        required=True,
        help="repo sub-folder holding config.json and model_f*.safetensors",
    )
    ap.add_argument("--image", required=True, help="input image path")
    ap.add_argument("--out", default="mask.png", help="output PNG mask path")
    ap.add_argument(
        "--stride",
        type=int,
        default=None,
        help="sliding-window stride in pixels (default: half the patch size)",
    )
    args = ap.parse_args()
    if args.stride is not None and args.stride <= 0:
        ap.error(f"--stride must be a positive integer, got {args.stride}")

    cfg_path = hf_hub_download(args.hf_repo, f"{args.model_name}/config.json")
    cfg = json.loads(Path(cfg_path).read_text())
    n_folds = cfg["n_folds"]
    fold_paths = [
        hf_hub_download(args.hf_repo, f"{args.model_name}/model_f{i}.safetensors")
        for i in range(n_folds)
    ]

    device = _pick_device()
    classes = cfg["class_names"]
    patch = cfg["patch_size"]
    # Default to half-overlapping windows; never let the stride collapse to 0.
    stride = args.stride if args.stride is not None else max(patch // 2, 1)
    normalize_mode = cfg.get("normalize_mode", "imagenet")
    thrs = cfg.get("ensemble_thresholds") or {}
    thr = np.array([thrs.get(n, 0.5) for n in classes], dtype=np.float32).reshape(-1, 1, 1)

    img = _load_image(args.image, normalize_mode)

    models = []
    for wts_path in fold_paths:
        m = build_model(cfg["model"]).to(device).eval()
        m.load_state_dict(load_file(wts_path))
        models.append(m)
    print(f"Loaded {n_folds} fold(s), device={device}, normalize_mode={normalize_mode}")

    probs = sliding_window_ensemble(models, img, patch, stride, len(classes), device)
    # Only the first class is written, so threshold just that channel.
    binary = (probs[0] > thr[0]).astype(np.uint8)
    # A 2-D uint8 array is inferred as mode "L"; passing mode explicitly is deprecated.
    Image.fromarray(binary * 255).save(args.out)
    print(f"Wrote mask to {args.out}")
|
|
|
|
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script.
    main()
|
|