import pickle
import random
from collections import defaultdict
from typing import List

import torch
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
from transformers import BertForSequenceClassification, BertTokenizer
|
|
app = FastAPI()

# Directory holding the fine-tuned classifier weights and its label encoder.
MODEL_DIR = "best_model"

# NOTE(review): the tokenizer is loaded from the base checkpoint, not from
# MODEL_DIR; this assumes fine-tuning did not alter the vocabulary — confirm.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained(MODEL_DIR)
model.eval()  # inference mode: disables dropout so predictions are stable

# Maps classifier output indices back to disease names (read through
# ``label_encoder.classes_`` by the /predict endpoint).
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# deploy a label_encoder.pkl produced by a trusted training run.
with open(f"{MODEL_DIR}/label_encoder.pkl", "rb") as f:
    label_encoder = pickle.load(f)
|
|
class PredictionResponse(BaseModel):
    """A single ranked disease prediction as serialized by ``/predict``."""

    # Human-readable disease label decoded from the classifier output index.
    disease: str
    # Averaged softmax probability the model assigns to ``disease``.
    probability: float
|
|
# Number of random symptom orderings averaged per request. The model reads the
# symptoms as ordered text, so averaging over orderings reduces order bias.
N_SHUFFLES = 10
# Fixed seed so identical requests always yield identical predictions.
SHUFFLE_SEED = 0
# Maximum tokenized length per input text (longer inputs are truncated).
MAX_LENGTH = 128
# Number of most likely diseases returned.
TOP_K = 3


@app.get("/predict", response_model=List[PredictionResponse])
def predict(symptoms: str = Query(..., description="Comma-separated symptoms")):
    """Return the most likely diseases for a comma-separated symptom list.

    The symptoms are shuffled into several orderings, all orderings are scored
    in a single batched forward pass, and the per-class softmax probabilities
    are averaged across orderings.

    Args:
        symptoms: Comma-separated symptom names; blank entries are ignored.

    Returns:
        Up to ``TOP_K`` dicts ``{"disease": str, "probability": float}``,
        ordered by descending averaged probability.

    Raises:
        HTTPException: 422 if no non-blank symptom is supplied.
    """
    symptoms_list = [s.strip() for s in symptoms.split(",") if s.strip()]
    if not symptoms_list:
        # Without this guard the model would score an empty string and return
        # confident-looking but meaningless predictions.
        raise HTTPException(
            status_code=422,
            detail="At least one non-empty symptom is required",
        )

    # Per-request RNG: reproducible results, and no mutation of the global
    # `random` state shared with the rest of the process.
    rng = random.Random(SHUFFLE_SEED)
    # Orderings only differ with two or more symptoms; with one symptom a
    # single pass yields exactly the same average.
    n_passes = N_SHUFFLES if len(symptoms_list) > 1 else 1
    texts = []
    for _ in range(n_passes):
        rng.shuffle(symptoms_list)
        texts.append(", ".join(symptoms_list))

    # One batched forward pass instead of one pass per ordering; padding plus
    # the attention mask keeps each row's result independent of the others.
    inputs = tokenizer(
        texts,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LENGTH,
    )
    with torch.no_grad():
        logits = model(**inputs).logits  # (n_passes, num_labels)
    # Average over the batch axis; indexing by dim (not .squeeze()) stays
    # correct even when the classifier has a single label.
    mean_probs = torch.softmax(logits, dim=-1).mean(dim=0)

    k = min(TOP_K, mean_probs.numel())
    top_probs, top_idx = torch.topk(mean_probs, k)
    return [
        {"disease": label_encoder.classes_[idx], "probability": prob}
        for idx, prob in zip(top_idx.tolist(), top_probs.tolist())
    ]
|
|