# Source: inference.py uploaded to Hugging Face by davidclara
# (commit 6fbf967 "Upload inference.py", 4.79 kB).
"""Run a davidclara/building-block-vectorization model from Hugging Face on a map image.
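Averages sigmoid probabilities across the cross-validation fold models, blends
overlapping tiles with a Gaussian-weighted sliding window, thresholds with the
per-class ensemble thresholds from config.json (0.5 for a class without one),
and writes the first class's prediction as a single binary PNG mask.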
Example:
python inference.py \\
--hf-repo davidclara/building-block-vectorization \\
--model-name unet_scse \\
--image map.jpg \\
--out mask.png
"""
import argparse
import inspect
import json
from pathlib import Path
import numpy as np
import segmentation_models_pytorch as smp
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
# Map images can be very large; disable Pillow's decompression-bomb pixel
# limit because the input image is chosen explicitly by the user on the CLI.
Image.MAX_IMAGE_PIXELS = None

# Per-channel RGB mean / std applied when config "normalize_mode" is "imagenet".
IMAGENET_MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
IMAGENET_STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)

# Architecture names accepted in the model config's "name" field, mapped to
# the segmentation_models_pytorch class that builds them.
SMP = dict(
    unet=smp.Unet,
    unetpp=smp.UnetPlusPlus,
    deeplabv3p=smp.DeepLabV3Plus,
    fpn=smp.FPN,
    pan=smp.PAN,
)
def build_model(model_cfg: dict, registry: "dict | None" = None) -> torch.nn.Module:
    """Instantiate a segmentation model from the ``model`` section of config.json.

    Args:
        model_cfg: Model settings. ``name`` selects the architecture from
            *registry*; ``num_classes`` (or an already-present ``classes``)
            sets the number of output channels. Keys the architecture's
            constructor does not declare are dropped. The dict is not modified.
        registry: Mapping from architecture name to model class. Defaults to
            the module-level ``SMP`` table.

    Returns:
        A freshly constructed model built with ``encoder_weights=None``; its
        parameters are expected to be loaded from a safetensors checkpoint.

    Raises:
        KeyError: if ``name`` is missing or not in *registry*, or if neither
            ``num_classes`` nor ``classes`` is present.
    """
    if registry is None:
        registry = SMP
    cfg = dict(model_cfg)  # shallow copy so the caller's config stays intact
    name = cfg.pop("name")
    try:
        cls = registry[name]
    except KeyError:
        raise KeyError(
            f"unsupported model architecture {name!r}; expected one of {sorted(registry)}"
        ) from None
    if "num_classes" in cfg:
        cfg["classes"] = cfg.pop("num_classes")
    elif "classes" not in cfg:
        raise KeyError("model config needs 'num_classes' (or 'classes')")
    # Weights come from safetensors, so never download pretrained encoder weights.
    cfg["encoder_weights"] = None
    # The config may carry keys that are not constructor arguments; pass only
    # the ones the class signature declares.
    accepted = set(inspect.signature(cls).parameters)
    return cls(**{k: v for k, v in cfg.items() if k in accepted})
def gaussian_kernel(size: int, sigma_ratio: float = 0.125) -> np.ndarray:
    """Return a ``size`` x ``size`` separable Gaussian window whose peak is 1.

    The standard deviation is ``size * sigma_ratio`` pixels and the window is
    centred on the middle of the grid (between the two middle pixels when
    ``size`` is even). The result is float32.
    """
    half_span = (size - 1) / 2.0
    offsets = np.arange(size) - half_span
    std = sigma_ratio * size
    profile = np.exp(-(offsets**2) / (2 * std**2))
    # Broadcasting the column and row profiles forms the same outer product.
    window = profile[:, np.newaxis] * profile[np.newaxis, :]
    peak = window.max()
    return (window / peak).astype(np.float32)
def _tile_origins(length, patch, stride):
    """Start offsets of tiles covering [0, length); the last tile is flush with the end."""
    last = max(length - patch, 0)
    return sorted({*range(0, last + 1, stride), last})


def _mean_fold_probs(models, x):
    """Average sigmoid(logits) of every model on one batch-of-one tile, as NumPy."""
    total = None
    for m in models:
        logits = m(x).cpu().numpy()[0]
        # Clip before exp so float32 cannot overflow (matches predict.py).
        p = 1.0 / (1.0 + np.exp(-np.clip(logits, -88, 88)))
        total = p if total is None else total + p
    return total / len(models)


def sliding_window_ensemble(models, img, patch, stride, n_classes, device):
    """Predict per-class probabilities for a whole image by tiling it.

    Every ``patch`` x ``patch`` tile is run through each model. A tile is
    zero-padded at the bottom/right when the image is smaller than ``patch``
    along that axis. Sigmoid outputs are averaged across models, and
    overlapping tiles are blended with a Gaussian window so tile centres
    outweigh tile borders.

    Args:
        models: Non-empty sequence of callables mapping a ``(1, C, patch, patch)``
            float tensor on ``device`` to logits expected to be shaped
            ``(1, n_classes, patch, patch)``.
        img: Preprocessed image as a ``(C, H, W)`` float32 array.
        patch: Tile side length in pixels; must be >= 1.
        stride: Step between tile origins; must satisfy ``1 <= stride <= patch``
            so that no pixel is left uncovered.
        n_classes: Number of output channels.
        device: Torch device the tiles are moved to.

    Returns:
        A ``(n_classes, H, W)`` float32 array of blended probabilities.

    Raises:
        ValueError: if ``models`` is empty or ``patch`` / ``stride`` is out of range.
    """
    if len(models) == 0:
        raise ValueError("models must contain at least one model")
    if patch < 1:
        raise ValueError(f"patch must be >= 1, got {patch}")
    if not 1 <= stride <= patch:
        raise ValueError(f"stride must be in [1, patch={patch}], got {stride}")

    _, H, W = img.shape
    probs_sum = np.zeros((n_classes, H, W), dtype=np.float32)
    weight_sum = np.zeros((H, W), dtype=np.float32)
    kernel = gaussian_kernel(patch)
    rows = _tile_origins(H, patch, stride)
    cols = _tile_origins(W, patch, stride)
    with torch.no_grad():
        for r in rows:
            for c in cols:
                tile = img[:, r : r + patch, c : c + patch]
                # Real (unpadded) extent; smaller than patch only when the
                # image itself is smaller than patch along that axis.
                h, w = tile.shape[1], tile.shape[2]
                if h < patch or w < patch:
                    tile = np.pad(tile, ((0, 0), (0, patch - h), (0, patch - w)))
                x = torch.from_numpy(tile).unsqueeze(0).to(device)
                probs = _mean_fold_probs(models, x)
                g = kernel[:h, :w]
                probs_sum[:, r : r + h, c : c + w] += probs[:, :h, :w] * g[None]
                weight_sum[r : r + h, c : c + w] += g
    return probs_sum / np.maximum(weight_sum, 1e-8)
def main() -> None:
    """Command-line entry point: predict a binary PNG mask for one image.

    Downloads ``<model-name>/config.json`` and one ``model_f<i>.safetensors``
    per fold from the Hugging Face repo, preprocesses the image, runs the fold
    ensemble with a sliding window, thresholds the probabilities per class,
    and saves the first class's mask as a black/white PNG.
    """
    ap = argparse.ArgumentParser(description="Predict a binary PNG mask for a map image.")
    ap.add_argument("--hf-repo", required=True, help="Hugging Face repo id")
    ap.add_argument("--model-name", required=True, help="model folder inside the repo")
    ap.add_argument("--image", required=True, help="input image path")
    ap.add_argument("--out", default="mask.png", help="output PNG path (default: %(default)s)")
    ap.add_argument(
        "--stride",
        type=int,
        default=None,
        help="sliding-window step in pixels (default: half the patch size)",
    )
    args = ap.parse_args()

    cfg_path = hf_hub_download(args.hf_repo, f"{args.model_name}/config.json")
    # JSON is UTF-8; don't depend on the platform's default text encoding.
    cfg = json.loads(Path(cfg_path).read_text(encoding="utf-8"))
    n_folds = cfg["n_folds"]
    fold_paths = [
        hf_hub_download(args.hf_repo, f"{args.model_name}/model_f{i}.safetensors")
        for i in range(n_folds)
    ]

    # Prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    classes = cfg["class_names"]
    patch = cfg["patch_size"]
    # Half-patch overlap by default, but never 0 (an invalid step for patch == 1).
    stride = args.stride if args.stride is not None else max(1, patch // 2)
    normalize_mode = cfg.get("normalize_mode", "imagenet")
    thrs = cfg.get("ensemble_thresholds") or {}
    # One threshold per class (0.5 when missing), shaped (C, 1, 1) so it
    # broadcasts against the (C, H, W) probability map.
    thr = np.array([thrs.get(n, 0.5) for n in classes], dtype=np.float32).reshape(-1, 1, 1)

    # The context manager closes the file handle once the pixels are decoded.
    with Image.open(args.image) as im:
        img = np.asarray(im.convert("RGB"), dtype=np.float32) / 255.0
    if normalize_mode == "imagenet":
        img = (img - IMAGENET_MEAN) / IMAGENET_STD
    # Any other normalize_mode leaves pixels scaled to [0, 1].
    img = img.transpose(2, 0, 1)  # (H, W, 3) -> (3, H, W)

    models = []
    for wts_path in fold_paths:
        m = build_model(cfg["model"]).to(device).eval()
        m.load_state_dict(load_file(wts_path))
        models.append(m)
    print(f"Loaded {n_folds} fold(s), device={device}, normalize_mode={normalize_mode}")

    probs = sliding_window_ensemble(models, img, patch, stride, len(classes), device)
    # Only the first class's channel is written out.
    binary = (probs > thr).astype(np.uint8)[0]
    # A 2-D uint8 array is inferred as Pillow mode "L"; passing the mode
    # argument explicitly is deprecated in recent Pillow releases.
    Image.fromarray(binary * 255).save(args.out)
    print(f"Wrote mask to {args.out}")
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()