# Spaces: Sleeping -- appears to be page-banner residue captured with this file; not part of the module.
| """Inference API for SQL error classification.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from dataclasses import asdict, dataclass | |
| from pathlib import Path | |
| from typing import List, Optional | |
| from src.categories import id_to_name, load_categories | |
| from src.model import DEFAULT_MODEL_PATH, combine_features, load_model | |
| from src.cross_encoder_model import ( | |
| CrossEncoderClassifier, | |
| FineTunedCrossEncoderClassifier, | |
| ) | |
| from src.multi_tower_model import MultiTowerClassifier, QueryContext | |
# Model classes that SQLErrorClassifier.predict feeds with a QueryContext
# (question, schema, correct query, student query, error message) instead of
# a single combined feature string.
CONTEXT_MODELS = (
    CrossEncoderClassifier,
    FineTunedCrossEncoderClassifier,
    MultiTowerClassifier,
)
@dataclass
class Prediction:
    """Result of a single classification call.

    Declared as a dataclass so that it can be constructed from keyword
    arguments and serialized with ``dataclasses.asdict``.
    """

    # Class id with the highest predicted probability.
    label_id: int
    # Category name looked up for ``label_id``.
    label_name: str
    # Probability the model assigned to ``label_id``.
    confidence: float
    # Best candidates in descending probability order; each entry is a dict
    # with the keys "label_id", "label_name" and "confidence".
    top_k: List[dict]
    # Result of the model's ``explain_similarities`` call, when one was made.
    similarities: Optional[dict] = None
    # Result of the model's ``explain_pair_scores`` call, when one was made.
    pair_scores: Optional[dict] = None
class SQLErrorClassifier:
    """Classifier wrapper for playground integration.

    Loads a model and the category id -> name mapping once, then serves
    predictions through :meth:`predict` for both context models (see
    ``CONTEXT_MODELS``) and models that consume combined feature text.
    """

    def __init__(self, model_path: Path = DEFAULT_MODEL_PATH):
        """Load the model at ``model_path`` and the label-name mapping."""
        self.model = load_model(model_path)
        self.label_map = id_to_name(load_categories())

    def predict(
        self,
        query: str,
        error_message: Optional[str] = None,
        schema: Optional[str] = None,
        question: Optional[str] = None,
        correct_query: Optional[str] = None,
        top_k: int = 3,
    ) -> Prediction:
        """Classify the error in ``query`` and return the ranked result.

        Args:
            query: The query to classify (passed to context models as the
                student query).
            error_message: Optional error text associated with the query.
            schema: Optional schema text; required by context models.
            question: Optional question text; required by context models.
            correct_query: Reference query; required by context models and
                not used on the combined-text path.
            top_k: Maximum number of ranked candidates to include.

        Returns:
            A ``Prediction`` holding the best class, its probability, the
            ``top_k`` best candidates in descending probability order, and
            any explanation data the model provided.

        Raises:
            ValueError: If ``top_k`` is negative, or if the loaded model is
                a context model and ``schema``, ``question`` or
                ``correct_query`` is missing or empty.
        """
        # A negative bound would make the slice below drop candidates from
        # the end of the ranking instead of limiting its length.
        if top_k < 0:
            raise ValueError("top_k must be non-negative")

        similarities = None
        pair_scores = None

        if isinstance(self.model, CONTEXT_MODELS):
            if not all([schema, question, correct_query]):
                raise ValueError(
                    "context models require schema, question, and correct_query"
                )
            ctx = QueryContext(
                question=question,
                schema=schema,
                correct_query=correct_query,
                student_query=query,
                error_message=error_message,
            )
            proba = self.model.predict_proba([ctx])[0]
            if isinstance(self.model, MultiTowerClassifier):
                similarities = self.model.explain_similarities(ctx)
            # NOTE(review): FineTunedCrossEncoderClassifier only reaches this
            # branch if it subclasses CrossEncoderClassifier -- confirm.
            if isinstance(self.model, CrossEncoderClassifier):
                pair_scores = self.model.explain_pair_scores(ctx)
        else:
            # Empty or missing optional inputs are passed as None rather
            # than as a one-element list.
            text = combine_features(
                queries=[query],
                error_messages=[error_message] if error_message else None,
                schemas=[schema] if schema else None,
                questions=[question] if question else None,
            )[0]
            proba = self.model.predict_proba([text])[0]

        # Pair every class id with its probability, most likely first.
        classes = self.model.classes_
        ranked = sorted(zip(classes, proba), key=lambda x: x[1], reverse=True)
        best_id = int(ranked[0][0])

        return Prediction(
            label_id=best_id,
            label_name=self.label_map[best_id],
            confidence=float(ranked[0][1]),
            top_k=[
                {
                    "label_id": int(cls),
                    "label_name": self.label_map[int(cls)],
                    "confidence": float(p),
                }
                for cls, p in ranked[:top_k]
            ],
            similarities=similarities,
            pair_scores=pair_scores,
        )
def main() -> None:
    """Command-line entry point: classify one query and print JSON.

    Parses the command-line flags, builds a ``SQLErrorClassifier`` from the
    ``--model`` path, runs one prediction and prints the result to stdout
    as JSON indented by two spaces.
    """
    cli = argparse.ArgumentParser(description="Classify SQL error type")
    cli.add_argument("--query", type=str, required=True)
    # The optional text inputs all share the same declaration.
    for flag in ("--correct-query", "--error-message", "--schema", "--question"):
        cli.add_argument(flag, type=str, default=None)
    cli.add_argument("--model", type=Path, default=DEFAULT_MODEL_PATH)
    cli.add_argument("--top-k", type=int, default=3)
    opts = cli.parse_args()

    classifier = SQLErrorClassifier(opts.model)
    prediction = classifier.predict(
        query=opts.query,
        error_message=opts.error_message,
        schema=opts.schema,
        question=opts.question,
        correct_query=opts.correct_query,
        top_k=opts.top_k,
    )
    payload = asdict(prediction)
    print(json.dumps(payload, indent=2))
# Run the command-line interface only when this file is executed as a
# script, so importing the module has no side effects.
if __name__ == "__main__":
    main()