import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
|
app = FastAPI()

# Checkpoint identifier passed to `from_pretrained` (a Hub repo id or a local
# path). Both the classifier and its preprocessing config are loaded from it
# once, when this module is imported — model first, then the extractor.
# NOTE(review): newer transformers releases deprecate feature extractors for
# vision models in favour of `AutoImageProcessor`; consider switching once the
# minimum supported transformers version allows it.
model_name = "mmuratarat/kvasir-v2-classifier"
model, feature_extractor = (
    AutoModelForImageClassification.from_pretrained(model_name),
    AutoFeatureExtractor.from_pretrained(model_name),
)
|
|
| |
| id2label = { |
| '0': 'dyed-lifted-polyps', |
| '1': 'dyed-resection-margins', |
| '2': 'esophagitis', |
| '3': 'normal-cecum', |
| '4': 'normal-pylorus', |
| '5': 'normal-z-line', |
| '6': 'polyps', |
| '7': 'ulcerative-colitis' |
| } |
|
|
# Collapses each fine-grained class name into the binary status returned by
# the API. Every label in `id2label` must appear here, otherwise the lookup in
# `predict` raises KeyError.
# NOTE(review): 'dyed-resection-margins' is reported as "Polyp Present";
# confirm this is intended, since such images may show a site after polyp
# removal rather than a remaining polyp.
polyp_mapping = {
    **dict.fromkeys(
        (
            'dyed-lifted-polyps',
            'dyed-resection-margins',
            'polyps',
        ),
        "Polyp Present",
    ),
    **dict.fromkeys(
        (
            'ulcerative-colitis',
            'esophagitis',
            'normal-cecum',
            'normal-pylorus',
            'normal-z-line',
        ),
        "Polyp Absent",
    ),
}
|
|
def _decode_image(image_bytes: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        OSError: if Pillow cannot identify or fully read the data
            (``PIL.UnidentifiedImageError`` is a subclass of ``OSError``).
        Image.DecompressionBombError: if the image exceeds Pillow's
            decompression-bomb pixel limit.
    """
    # convert() forces the pixel data to be read, so truncated files fail
    # here (and become a 400) rather than later inside the feature extractor.
    return Image.open(io.BytesIO(image_bytes)).convert("RGB")


def _classify(image: Image.Image) -> str:
    """Run the classifier on *image* and return the predicted class name."""
    inputs = feature_extractor(image, return_tensors="pt")
    # Inference only: disable autograd so no gradient graph is built.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_label = logits.argmax(-1).item()
    return id2label[str(predicted_label)]


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and report whether a polyp is present.

    Returns:
        A dict with ``"predicted_class"`` (a value from ``id2label``) and
        ``"polyp_status"`` (that class mapped through ``polyp_mapping``).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    image_bytes = await file.read()

    # Decoding and the model forward pass are blocking, CPU-bound calls.
    # Running them in the threadpool keeps the event loop free to serve other
    # requests meanwhile.
    # NOTE(review): concurrent requests now share `model` across worker
    # threads (as a plain `def` endpoint would); confirm this is acceptable
    # for the deployment.
    try:
        image = await run_in_threadpool(_decode_image, image_bytes)
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file could not be decoded as an image.",
        ) from err

    predicted_class = await run_in_threadpool(_classify, image)
    polyp_status = polyp_mapping[predicted_class]

    return {
        "predicted_class": predicted_class,
        "polyp_status": polyp_status,
    }
|
|
|
|