LN_segmentation_sweep_v2

A U-Net model for multilabel image segmentation, trained with a sliding-window approach.
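The card does not say how the sliding window is laid out. The sketch below computes the tile grid for one image under two assumptions: square tiles of `img_size` = 128 pixels, and a hypothetical `stride` parameter, since the stride used in training is not stated here. In each row or column, the last tile is shifted back so that it ends flush with the image edge. `tile_starts` and `tile_boxes` are illustrative helpers and are not part of the project's code.

```python
def tile_starts(length: int, tile: int, stride: int) -> list[int]:
    """Start offsets for tiles of size `tile` along one axis of `length` pixels."""
    if length <= tile:
        # A single tile; an image smaller than the tile would need padding first.
        return [0]
    starts = list(range(0, length - tile + 1, stride))
    # If the stride leaves an uncovered strip, add one tile flush with the edge.
    if starts[-1] + tile < length:
        starts.append(length - tile)
    return starts


def tile_boxes(height: int, width: int, tile: int = 128, stride: int = 64):
    """Yield (y0, x0, y1, x1) boxes covering a height x width image."""
    for y0 in tile_starts(height, tile, stride):
        for x0 in tile_starts(width, tile, stride):
            yield (y0, x0, y0 + tile, x0 + tile)


# 2048 x 2048 image, 128-pixel tiles, no overlap (stride == tile).
grid = list(tile_boxes(2048, 2048, tile=128, stride=128))
# 300 x 300 image with 50% overlap: the last tile on each axis starts at 172.
overlap = list(tile_boxes(300, 300, tile=128, stride=64))
```

With `stride` equal to the tile size, the 2048 x 2048 image from the Usage example splits into 256 non-overlapping tiles. A smaller stride costs extra compute, but the resulting overlap can help smooth the seams between tiles.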

Model Description

Weights & Biases (wandb) Parameters

Parameter          Value
-----------------  -----------------------------------------------------
data_path          GleghornLab/Semi-Automated_LN_Segmentation_10_11_2025
img_size           128
downsample_factor  1
num_channels       3
batch_size         16
lr                 1.7122348637490954e-05
epochs             100
patience           10
weight_decay       8.29726636990404e-05
model_type         unet
n_filts            32
t                  3
k                  3
augment            False
norm               True
keep               0.06990272917761037
pruning_factor     0.019243240405735405
output_dir         pooled_metrics_hev_settings
device             None
num_workers        4
prefetch_factor    2
wandb_project      segmentation-sweep
wandb_run_name     hev-only-repro-pooled
wandb_mode         online
push_to_hub        True
hub_model_id       aholk/LN_segmentation_sweep_v2
skip_report        False
sweep_mode         False
num_params         34527236
num_classes        4
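The `epochs` and `patience` settings most likely mean that training runs for at most 100 epochs with early stopping. The sketch below shows patience-based early stopping under one assumption: the monitored quantity is a validation loss, where lower is better. The card does not say which metric the training script actually watches. `EarlyStopping` is a hypothetical helper and is not taken from the project's code.

```python
class EarlyStopping:
    """Stop when the monitored value has not improved for `patience` epochs.

    Assumes lower is better (e.g. a validation loss); this is an assumption,
    not something stated on the model card.
    """

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, value: float) -> bool:
        """Record one epoch's value; return True if training should stop."""
        if value < self.best - self.min_delta:
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Toy run capped at 100 epochs: the loss improves for 5 epochs, then plateaus.
losses = [1.0, 0.8, 0.6, 0.5, 0.45] + [0.45] * 95
stopper = EarlyStopping(patience=10)
stopped_at = None
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch  # 10 epochs without improvement after epoch 5
        break
```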

Model Parameters

Parameter                Value
-----------------------  ----------------------------
num_channels             3
num_classes              4
n_filts                  32
t                        3
k                        3
img_size                 128
norm                     True
model_arch               unet
transformers_version     5.9.0
architectures            ["UNetForSegmentation"]
output_hidden_states     False
return_dict              True
dtype                    float32
chunk_size_feed_forward  0
is_encoder_decoder       False
id2label                 {"0": "LABEL_0", "1": "LABEL_1"}
label2id                 {"LABEL_0": 0, "LABEL_1": 1}
problem_type             None
_name_or_path
batch_size               16
downsample_factor        1.0
model_type               segmentation
output_attentions        False

Performance Metrics

Mean is the unweighted average of the four per-class columns.

Metric   Mean    Class 0  Class 1  Class 2  Class 3
-------  ------  -------  -------  -------  -------
Dice     0.8169  0.7188   0.8196   0.8181   0.9112
IoU      0.6961  0.5610   0.6943   0.6923   0.8369
F1       0.8169  0.7188   0.8196   0.8181   0.9112
MCC      0.8124  0.7261   0.8171   0.8134   0.8928
ROC AUC  0.9768  0.9726   0.9923   0.9535   0.9888
PR AUC   0.8821  0.8046   0.8960   0.8652   0.9627
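The Dice and F1 rows are identical because, for a binary per-class mask, both reduce to 2TP / (2TP + FP + FN). Dice and IoU are also linked by Dice = 2·IoU / (1 + IoU). The table satisfies this per class to within rounding; for class 0, 2 × 0.5610 / 1.5610 ≈ 0.7188. The identity does not carry over to the Mean column, because the mean of the per-class Dice scores is not the Dice of the mean IoU. The minimal sketch below computes these metrics from one class's pixel confusion counts. `binary_metrics` and the counts are illustrative and are not the project's evaluation code.

```python
import math


def binary_metrics(tp: int, fp: int, fn: int, tn: int) -> dict:
    """Dice, IoU, F1 and MCC for one class from pixel confusion counts."""
    dice = 2 * tp / (2 * tp + fp + fn)
    iou = tp / (tp + fp + fn)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    mcc = (tp * tn - fp * fn) / math.sqrt(
        (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    )
    return {"dice": dice, "iou": iou, "f1": f1, "mcc": mcc}


# Illustrative counts for a single class (not taken from this model's evaluation).
m = binary_metrics(tp=800, fp=150, fn=200, tn=8850)

# Dice and F1 coincide, and Dice can be recovered from IoU.
assert math.isclose(m["dice"], m["f1"])
assert math.isclose(m["dice"], 2 * m["iou"] / (1 + m["iou"]))

# The card's class 0 IoU reproduces its Dice score up to rounding.
print(round(2 * 0.5610 / (1 + 0.5610), 4))  # prints 0.7188
```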

Usage

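import numpy as np
# `model` is the project's own module (assumed to come from the GleghornLab/ComputerVision2 repository), not a PyPI package
from model import MODEL_REGISTRY, SegmentationConfig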

# Load the config (holds the model parameters listed above) and the pretrained weights
config = SegmentationConfig.from_pretrained("aholk/LN_segmentation_sweep_v2")
model = MODEL_REGISTRY["unet"].from_pretrained("aholk/LN_segmentation_sweep_v2")
model.eval()

# Run inference on a full image with sliding window
image = np.random.rand(2048, 2048, 3).astype(np.float32)  # replace with your own (H, W, 3) float32 image
probs = model.predict_full_image(
    image,
    dim=128,
    batch_size=16,
    device="cuda"  # or "cpu"
)
# probs shape: (num_classes, H, W) with values in [0, 1]

# Threshold to get binary masks
masks = (probs > 0.5).astype(np.uint8)
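`predict_full_image` handles tiling internally, but this card does not document its strategy (stride, overlap, or blending). The sketch below shows one common approach: run the network on overlapping tiles and average the predictions wherever tiles overlap. `stitch_average`, `predict_tile`, and `toy_predict` are hypothetical names, and the 64-pixel stride is an assumption. The toy predictor works pixel by pixel, so the stitched output should reproduce its input exactly, which makes it a convenient sanity check.

```python
import numpy as np


def stitch_average(image, predict_tile, num_classes, tile=128, stride=64):
    """Sliding-window inference that averages overlapping tile predictions.

    `predict_tile` maps a (tile, tile, C) patch to a (num_classes, tile, tile)
    array of probabilities. Assumes stride <= tile and that the image is at
    least `tile` pixels on each side.
    """
    h, w = image.shape[:2]

    def starts(length):
        s = list(range(0, length - tile + 1, stride))
        if s[-1] + tile < length:  # cover the remaining edge strip
            s.append(length - tile)
        return s

    acc = np.zeros((num_classes, h, w), dtype=np.float32)
    hits = np.zeros((h, w), dtype=np.float32)  # how many tiles cover each pixel
    for y in starts(h):
        for x in starts(w):
            acc[:, y:y + tile, x:x + tile] += predict_tile(image[y:y + tile, x:x + tile])
            hits[y:y + tile, x:x + tile] += 1.0
    return acc / hits  # hits broadcasts over the class axis


# Toy stand-in for the network: every class "probability" is the first channel.
def toy_predict(patch, num_classes=4):
    return np.repeat(patch[None, :, :, 0], num_classes, axis=0)


image = np.random.default_rng(0).random((300, 260, 3)).astype(np.float32)
probs = stitch_average(image, toy_predict, num_classes=4)
masks = (probs > 0.5).astype(np.uint8)  # per-class threshold, as in the Usage code
```

Plain averaging is only one option. Weighted blending, such as a Gaussian window that favors tile centers, is another common way to suppress seams.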

Training Plots

Plots (images not reproduced): Training Loss, Dice Curves, IoU Curves, MCC Curves, Best Validation.

Citation

If you use this model, please cite:

@software{windowz_segmentation,
  title={Multilabel Image Segmentation with Sliding Window U-Net},
  author={Gleghorn Lab},
  year={2025},
  url={https://github.com/GleghornLab/ComputerVision2}
}