| from typing import Dict, List, Any |
| from PIL import Image |
| import base64 |
| import io |
| import os |
| import torch |
|
|
class EndpointHandler:
    """Inference Endpoints handler that runs DocLayout-YOLO layout detection."""

    def __init__(self, path=""):
        """Load the YOLOv10 layout model and choose an inference device.

        Args:
            path: Directory containing
                ``doclayout_yolo_docstructbench_imgsz1024.pt``.
        """
        # Imported inside __init__ so importing this module does not require
        # the doclayout_yolo package.
        from doclayout_yolo import YOLOv10

        model_path = os.path.join(path, "doclayout_yolo_docstructbench_imgsz1024.pt")
        self.model = YOLOv10(model_path)

        # Human-readable label for each class id the model emits.
        self.id_to_names = {
            0: 'title',
            1: 'plain_text',
            2: 'abandon',
            3: 'figure',
            4: 'figure_caption',
            5: 'table',
            6: 'table_caption',
            7: 'table_footnote',
            8: 'isolate_formula',
            9: 'formula_caption',
        }

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Process an image and return layout detections.

        Args:
            data: Dictionary with:
                - "inputs": base64-encoded image string (optionally a
                  ``data:...;base64,`` URL), raw image bytes, or an object
                  passed unchanged to ``model.predict`` (e.g. a PIL Image).
                - "parameters" (optional, may be null): {
                      "confidence": float (default 0.2),
                      "iou_threshold": float (default 0.45)
                  }

        Returns:
            List of detections, each with "label", "score" (4 decimals) and
            "box" ({"x1", "y1", "x2", "y2"}, 2 decimals), sorted by
            descending score.

        Raises:
            ValueError: if "inputs" is missing/empty, if a parameter is not
                numeric, or if a base64 payload is malformed.
        """
        image = self._load_image(data.get("inputs"))

        # "parameters" may be absent or an explicit JSON null.
        params = data.get("parameters") or {}
        conf_threshold = self._float_param(params, "confidence", 0.2)
        iou_threshold = self._float_param(params, "iou_threshold", 0.45)

        results = self.model.predict(
            image,
            imgsz=1024,
            conf=conf_threshold,
            iou=iou_threshold,
            device=self.device,
        )[0]

        detections = self._to_detections(results.boxes)
        detections.sort(key=lambda det: det["score"], reverse=True)
        return detections

    @staticmethod
    def _load_image(image: Any) -> Any:
        """Decode string/bytes inputs into a PIL image; pass other objects through."""
        # Fail loudly instead of handing None or an empty payload to the model.
        if image is None or (isinstance(image, (str, bytes, bytearray)) and not image):
            raise ValueError('Request payload must contain a non-empty "inputs" image.')
        if isinstance(image, str):
            # Strip a data-URL prefix such as "data:image/png;base64,".
            _, sep, payload = image.partition("base64,")
            image = base64.b64decode(payload if sep else image)
        if isinstance(image, (bytes, bytearray)):
            image = Image.open(io.BytesIO(image))
        return image

    @staticmethod
    def _float_param(params: Dict[str, Any], key: str, default: float) -> float:
        """Read numeric parameter *key*, using *default* when absent or null."""
        value = params.get(key)
        return default if value is None else float(value)

    def _to_detections(self, boxes: Any) -> List[Dict[str, Any]]:
        """Convert a result's ``boxes`` (xyxy/conf/cls tensors) into plain dicts."""
        if boxes is None:
            return []
        # One bulk tensor -> list conversion per field instead of several
        # ``.item()`` calls per box (each a separate host transfer on CUDA).
        coords = boxes.xyxy.tolist()
        scores = boxes.conf.tolist()
        class_ids = boxes.cls.tolist()

        detections = []
        for (x1, y1, x2, y2), score, cls_value in zip(coords, scores, class_ids):
            cls_id = int(cls_value)
            detections.append({
                "label": self.id_to_names.get(cls_id, f"class_{cls_id}"),
                "score": round(float(score), 4),
                "box": {
                    "x1": round(float(x1), 2),
                    "y1": round(float(y1), 2),
                    "x2": round(float(x2), 2),
                    "y2": round(float(y2), 2),
                },
            })
        return detections