File size: 2,958 Bytes
ca8fa7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92b5cbe
ca8fa7a
 
3bd068b
ca8fa7a
 
 
 
 
 
92b5cbe
ca8fa7a
 
 
 
 
 
 
 
92b5cbe
 
ca8fa7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09330e4
 
 
 
 
 
 
 
92b5cbe
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gradio as gr
from seg import U2NETP

# Image processing utilities
def load_image(path: str) -> np.ndarray:
    """Load an image from *path* and return it as an RGB float array in [0, 1].

    Raises:
        FileNotFoundError: if the file is missing or cannot be decoded —
            cv2.imread signals both cases by returning None, which previously
            surfaced as an opaque cv2.error inside cvtColor.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    # OpenCV loads BGR; convert to RGB for the rest of the pipeline.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img / 255.0

def save_image(image: np.ndarray, path: str):
    """Write an RGB float image in [0, 1] to *path* as an 8-bit file.

    The array is rescaled to uint8 and converted back to OpenCV's BGR
    channel order before writing.
    """
    as_uint8 = (image * 255).astype(np.uint8)
    cv2.imwrite(path, cv2.cvtColor(as_uint8, cv2.COLOR_RGB2BGR))

# Document Segmentation Model
class U2NETP_DocSeg(nn.Module):
    """Thin wrapper around U2NETP exposing only its primary mask output.

    U2NETP returns multiple side outputs; callers of this module only need
    the first (full-resolution) prediction.
    """

    def __init__(self, num_classes=1):
        super(U2NETP_DocSeg, self).__init__()
        self.u2netp = U2NETP(out_ch=num_classes)

    def forward(self, x):
        # Keep only the first prediction; discard the auxiliary side outputs.
        primary, *_side_outputs = self.u2netp(x)
        return primary

# Initialize the document segmentation model (single-class document mask).
docseg = U2NETP_DocSeg(num_classes=1)
# Load pretrained weights onto CPU; the checkpoint stores the weights under
# the "model_state_dict" key.
docseg_weight_path = './weights/u2netp_docseg_epoch_225_date_2026-01-02.pth'
checkpoint = torch.load(docseg_weight_path, map_location=torch.device('cpu'))
# Fixed: key was written as a needless f-string (f"model_state_dict") with no
# placeholders — a plain literal is equivalent and intent-revealing.
docseg.load_state_dict(checkpoint["model_state_dict"])
docseg.eval()  # inference mode: freezes dropout / batch-norm statistics

# Get segmentation mask
def get_mask(image, confidence=0.5):
    """Predict a document mask for an HxWx3 float image in [0, 1].

    The image is resized to the model's 288x288 input, thresholded at
    *confidence*, and the binary mask is resized back to the original
    resolution (bilinear resize re-introduces fractional edge values).

    Returns:
        An HxW torch tensor (float) on CPU.
    """
    height_width = image.shape[:2]
    batch = torch.from_numpy(image).float().permute(2, 0, 1).unsqueeze(0)
    batch = F.interpolate(batch, size=(288, 288), mode='bilinear')
    # inference_mode is slightly faster than no_grad for pure inference.
    with torch.inference_mode():
        pred = docseg(batch)
        pred = (pred > confidence).float()
        pred = F.interpolate(pred, size=height_width, mode='bilinear')
    return pred[0, 0]  # drop batch and channel dims; keep as tensor

def overlay_mask(image, mask):
    """Blend a red highlight over the masked region of an RGB image.

    Args:
        image: HxWx3 float numpy array in [0, 1].
        mask: HxW torch tensor; entries > 0 mark document pixels.

    Returns:
        HxWx3 float numpy array: a 70/30 blend of the original image and a
        version whose masked pixels are painted pure red.
    """
    base = torch.from_numpy(image).float().permute(2, 0, 1).unsqueeze(0)  # (1, 3, H, W)
    region = mask.view(1, 1, *mask.shape)                                 # (1, 1, H, W)
    red = torch.tensor([1.0, 0.0, 0.0]).view(1, 3, 1, 1)
    painted = torch.where(region > 0, red, base)
    blended = 0.7 * base + 0.3 * painted
    return blended.squeeze(0).permute(1, 2, 0).cpu().numpy()

def segment_image(image):
    """Gradio handler: normalize a uint8 image, segment it, yield the overlay.

    Implemented as a generator so Gradio streams the result to the output
    component.
    """
    normalized = image.astype(np.float32) / 255.0  # uint8 -> float in [0, 1]
    document_mask = get_mask(normalized, confidence=0.5)
    yield overlay_mask(normalized, document_mask)

# Gradio UI: two-panel layout (input image, overlay output) plus example
# images; segmentation re-runs whenever the input image changes.
with gr.Blocks() as demo:
    gr.Markdown("## Real-time Document Segmentation")
    with gr.Row():
        # type="numpy" delivers the upload as a uint8 HxWxC array,
        # matching what segment_image expects.
        input_image = gr.Image(label="Input Image", type="numpy")
        output_image = gr.Image(label="Segmentation Overlay", type="numpy")
    # Example paths are relative to the working directory at launch time.
    examples = gr.Examples(
        examples=[
            "./examples/sample.jpg",
            "./examples/manga.png",
            "./examples/invoice.png"
        ],
        inputs=input_image
    )
    # segment_image is a generator, so its yielded overlay streams to the
    # output component.
    input_image.change(segment_image, inputs=input_image, outputs=output_image)

demo.launch()