# CFPVesselSeg / app.py
# Gradio app for retina vessel segmentation on CFP (fundus) images.
# Source: Hugging Face repository by farrell236, commit e99a83c ("add src").
import argparse
from pathlib import Path
import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from augmentations import IMAGENET_MEAN, IMAGENET_STD
from models import build_model
# Process-wide state shared by the Gradio callbacks. It is filled in once by
# the ``__main__`` block before the UI is built, with the keys:
#   "device"     -- torch.device the model runs on
#   "image_size" -- square side length inputs are resized to
#   "model_name" -- architecture name passed on the command line
#   "backbone"   -- backbone name passed on the command line
#   "model"      -- the loaded network (already in eval mode)
APP_STATE = {}
def load_model(args, device):
    """Instantiate the segmentation network and restore its trained weights.

    Args:
        args: Parsed CLI namespace. Reads ``model``, ``image_size``,
            ``backbone``, ``base_channels``, ``dropout`` and ``checkpoint``.
        device: Target device for the model (e.g. a ``torch.device``).

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    model_kwargs = dict(
        model_name=args.model,
        num_classes=1,
        in_channels=3,
        image_size=args.image_size,
        backbone=args.backbone,
        pretrained=False,  # all weights come from the checkpoint below
        base_channels=args.base_channels,
        dropout=args.dropout,
    )
    network = build_model(**model_kwargs)

    # Map to CPU first so the file opens regardless of the device it was
    # saved from. torch.load unpickles the file, so only load trusted
    # checkpoints. NOTE(review): on PyTorch >= 2.6 torch.load defaults to
    # weights_only=True, so checkpoints holding non-tensor objects may fail
    # to load -- confirm against the training script.
    payload = torch.load(args.checkpoint, map_location="cpu")

    # Accept either a training checkpoint that nests the weights under
    # "model_state_dict" or a bare state dict.
    weights = payload["model_state_dict"] if "model_state_dict" in payload else payload
    network.load_state_dict(weights, strict=True)

    network.to(device)
    network.eval()
    return network
def preprocess_image(image, image_size):
    """Convert an input image into a normalized, batched model tensor.

    Args:
        image: A ``PIL.Image.Image`` (any mode; converted to RGB) or an
            array-like of shape ``(H, W)``, ``(H, W, 1)``, ``(H, W, 3)`` or
            ``(H, W, 4)``. Arrays are assumed to be RGB(A) on a 0-255 scale;
            TODO confirm for non-uint8 callers.
        image_size: Side length of the square the image is resized to before
            it is fed to the model.

    Returns:
        ``(tensor, original_rgb)``:
        - ``tensor`` is a float32 tensor of shape
          ``(1, 3, image_size, image_size)``. It is scaled to [0, 1] and then
          standardized with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
        - ``original_rgb`` is a copy of the un-resized ``(H, W, 3)`` input.

    Raises:
        ValueError: If the input cannot be interpreted as a gray, RGB or
            RGBA image.
    """
    if isinstance(image, Image.Image):
        image = np.array(image.convert("RGB"))
    else:
        image = np.array(image)

    # Normalize the channel layout to (H, W, 3).
    if image.ndim == 3 and image.shape[-1] == 1:
        # (H, W, 1) used to skip the grayscale branch and then failed to
        # broadcast against the 3-channel mean/std. Treat it as plain gray.
        image = image[..., 0]
    if image.ndim == 2:
        # Copy the gray channel into R, G and B. This gives the same result
        # as cv2.COLOR_GRAY2RGB without OpenCV's dtype restrictions on colour
        # conversion.
        image = np.repeat(image[..., None], 3, axis=-1)
    if image.ndim == 3 and image.shape[-1] == 4:
        image = image[..., :3]  # drop the alpha channel
    if image.ndim != 3 or image.shape[-1] != 3:
        raise ValueError(
            "Expected an (H, W), (H, W, 1), (H, W, 3) or (H, W, 4) image, "
            f"got shape {image.shape}"
        )

    original_rgb = image.copy()

    resized = cv2.resize(
        image,
        (image_size, image_size),
        interpolation=cv2.INTER_LINEAR,
    )
    # Scale to [0, 1], then standardize per channel. This assumes
    # IMAGENET_MEAN/STD are expressed on the [0, 1] scale; TODO confirm.
    resized = resized.astype(np.float32) / 255.0
    mean = np.asarray(IMAGENET_MEAN, dtype=np.float32).reshape(1, 1, 3)
    std = np.asarray(IMAGENET_STD, dtype=np.float32).reshape(1, 1, 3)
    resized = (resized - mean) / std

    # HWC -> CHW, then add a batch dimension of 1.
    tensor = torch.from_numpy(resized).permute(2, 0, 1).unsqueeze(0).float()
    return tensor, original_rgb
def overlay_mask(image_rgb, mask, alpha=0.45):
    """Tint the masked pixels of an RGB image red.

    Each pixel is blended as ``rgb * (1 - alpha * m) + red * (alpha * m)``,
    where ``m`` is the mask value at that pixel. A 0/1 mask therefore leaves
    background untouched and mixes foreground with pure red at ``alpha``.

    Args:
        image_rgb: ``(H, W, 3)`` image; it is cast to uint8 before blending.
        mask: ``(H, W)`` array of per-pixel blend weights.
        alpha: Red opacity applied where the mask is 1.

    Returns:
        The blended ``(H, W, 3)`` uint8 image.
    """
    base = image_rgb.astype(np.uint8)

    tint = np.zeros_like(base)
    tint[..., 0] = 255  # pure red in RGB order

    # Per-pixel weight, with a trailing axis so it broadcasts over channels.
    weight = alpha * mask[..., None]
    blended = base * (1 - weight) + tint * weight
    return np.clip(blended, 0, 255).astype(np.uint8)
def run_inference(image, threshold):
    """Segment vessels in one image and return results at its original size.

    Args:
        image: Any input accepted by ``preprocess_image`` (PIL or array).
        threshold: Probability cut-off. Pixels with probability
            ``>= threshold`` are marked as foreground.

    Returns:
        ``(original_rgb, prob_map, pred_mask)``:
        - ``original_rgb`` is the un-resized RGB input.
        - ``prob_map`` is the sigmoid probability map, resized back to the
          input's ``(H, W)``.
        - ``pred_mask`` is a float32 0/1 mask obtained by thresholding
          ``prob_map``.
    """
    batch, original_rgb = preprocess_image(
        image=image,
        image_size=APP_STATE["image_size"],
    )
    batch = batch.to(APP_STATE["device"])

    model = APP_STATE["model"]
    with torch.no_grad():
        # Take channel 0 of the first (only) batch item. This assumes the
        # model emits (N, 1, h, w) logits for the single vessel class;
        # TODO confirm for every architecture.
        prob_small = torch.sigmoid(model(batch))[0, 0].detach().cpu().numpy()

    height, width = original_rgb.shape[:2]
    # cv2.resize takes the target size as (width, height).
    prob_map = cv2.resize(
        prob_small,
        (width, height),
        interpolation=cv2.INTER_LINEAR,
    )
    pred_mask = (prob_map >= threshold).astype(np.float32)
    return original_rgb, prob_map, pred_mask
def predict(image, threshold, alpha):
    """Gradio callback that returns the overlay, probability and mask images.

    Args:
        image: Uploaded image, or ``None`` when nothing is uploaded.
        threshold: Probability cut-off forwarded to ``run_inference``.
        alpha: Red overlay opacity forwarded to ``overlay_mask``.

    Returns:
        ``(overlay, prob_vis, mask_vis)`` as uint8 arrays, or
        ``(None, None, None)`` when ``image`` is ``None``.
    """
    if image is None:
        return None, None, None

    rgb, probabilities, mask = run_inference(image, threshold)
    overlay = overlay_mask(rgb, mask, alpha=alpha)

    # Map [0, 1] values onto 8-bit grayscale for display.
    prob_vis = np.clip(probabilities * 255, 0, 255).astype(np.uint8)
    mask_vis = (mask * 255).astype(np.uint8)
    return overlay, prob_vis, mask_vis
def build_app():
    """Build the Gradio Blocks UI for interactive vessel segmentation.

    The header reads ``APP_STATE["model_name"]``, ``APP_STATE["backbone"]``
    and ``APP_STATE["image_size"]``, so ``APP_STATE`` must be populated
    before this is called.

    Returns:
        The configured, not yet launched, ``gr.Blocks`` instance.
    """
    # Pixel height of each image panel, keyed by elem_id. These values feed
    # both the component ``height=`` argument and the CSS below, so the two
    # cannot drift apart.
    panel_heights = {
        "input_image": 430,
        "overlay_output": 200,
        "prob_output": 200,
        "mask_output": 430,
    }
    # For each panel, pin the container height and letterbox the <img>
    # inside it, instead of hand-writing the same two rules four times.
    css = "\n".join(
        f"#{elem_id} {{\n"
        f"    height: {px}px !important;\n"
        f"}}\n"
        f"#{elem_id} img {{\n"
        f"    object-fit: contain !important;\n"
        f"    max-height: {px}px !important;\n"
        f"}}"
        for elem_id, px in panel_heights.items()
    )

    with gr.Blocks(title="Retina Vessel Segmentation", css=css) as demo:
        gr.Markdown("# Retina Vessel Segmentation")
        gr.Markdown(
            f"Model: `{APP_STATE['model_name']}` | "
            f"Backbone: `{APP_STATE['backbone']}` | "
            f"Image size: `{APP_STATE['image_size']}`"
        )

        with gr.Row(equal_height=False):
            # Gradio expects integer column scales and warns on values like
            # 1.2. A 5:6 split keeps the intended 1:1.2 width ratio.
            with gr.Column(scale=5):
                input_image = gr.Image(
                    type="pil",
                    label="Input CFP Image",
                    elem_id="input_image",
                    height=panel_heights["input_image"],
                )
                threshold = gr.Slider(
                    minimum=0.05,
                    maximum=0.95,
                    value=0.5,
                    step=0.05,
                    label="Prediction Threshold",
                )
                alpha = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.45,
                    step=0.05,
                    label="Overlay Alpha",
                )
                run_button = gr.Button("Segment")

            with gr.Column(scale=6):
                with gr.Row():
                    overlay_output = gr.Image(
                        type="numpy",
                        label="Overlay",
                        elem_id="overlay_output",
                        height=panel_heights["overlay_output"],
                    )
                    prob_output = gr.Image(
                        type="numpy",
                        label="Probability Map",
                        elem_id="prob_output",
                        height=panel_heights["prob_output"],
                    )
                mask_output = gr.Image(
                    type="numpy",
                    label="Binary Mask",
                    elem_id="mask_output",
                    height=panel_heights["mask_output"],
                )

        inputs = [input_image, threshold, alpha]
        outputs = [overlay_output, prob_output, mask_output]
        # Segment when the button is pressed, and re-segment whenever either
        # slider moves. Each trigger calls ``predict``, which re-runs the
        # model, including on alpha-only changes.
        for register in (run_button.click, threshold.change, alpha.change):
            register(fn=predict, inputs=inputs, outputs=outputs)

    return demo
def parse_args():
    """Parse command-line options for the segmentation demo.

    Returns:
        ``argparse.Namespace`` holding the checkpoint path, model
        architecture settings, inference device and Gradio server options.
    """
    parser = argparse.ArgumentParser(description="Gradio app for retina vessel segmentation.")

    # (flag, keyword arguments for add_argument), registered in this order.
    option_specs = (
        ("--checkpoint", {"type": str, "default": "checkpoints/fives_resunet/best.pt"}),
        ("--image-size", {"type": int, "default": 1024}),
        ("--model", {"type": str, "default": "resunet", "choices": ["resunet", "deeplabv3", "vit"]}),
        ("--backbone", {"type": str, "default": "resnet50"}),
        ("--base-channels", {"type": int, "default": 32}),
        ("--dropout", {"type": float, "default": 0.0}),
        ("--device", {"type": str, "default": "cuda"}),
        ("--server-name", {"type": str, "default": "127.0.0.1"}),
        ("--server-port", {"type": int, "default": 7860}),
        ("--share", {"action": "store_true"}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()
if __name__ == "__main__":
    import os

    args = parse_args()

    # Fall back to CPU when a CUDA device was requested but CUDA is not
    # available. Matching the "cuda" prefix also covers explicit indices such
    # as "cuda:1", which previously slipped through and crashed in model.to().
    device = args.device
    if device.startswith("cuda") and not torch.cuda.is_available():
        print(f"CUDA is not available; falling back to CPU (requested {device!r}).")
        device = "cpu"

    checkpoint_path = Path(args.checkpoint)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    APP_STATE["device"] = torch.device(device)
    APP_STATE["image_size"] = args.image_size
    APP_STATE["model_name"] = args.model
    APP_STATE["backbone"] = args.backbone
    APP_STATE["model"] = load_model(
        args=args,
        device=APP_STATE["device"],
    )

    print(f"Loaded checkpoint: {checkpoint_path}")
    print(f"Device: {APP_STATE['device']}")
    print(f"Model: {APP_STATE['model_name']}")
    print(f"Backbone: {APP_STATE['backbone']}")
    print(f"Image size: {APP_STATE['image_size']}")

    demo = build_app()
    # --server-name, --server-port and --share were previously parsed but
    # never forwarded. Gradio's GRADIO_SERVER_NAME / GRADIO_SERVER_PORT
    # environment variables still take precedence. Hosting platforms (e.g.
    # Hugging Face Spaces) typically set these, and an explicit
    # server_name="127.0.0.1" would otherwise make a hosted app unreachable.
    demo.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", args.server_name),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", args.server_port)),
        share=args.share,
    )