# Provenance (from the Hugging Face file view this was copied from):
# uploaded by nishu08, commit "Deploy CodeBERT training Space" (9b2cded,
# verified), file size 4.22 kB.
"""Inference API for SQL error classification."""
from __future__ import annotations
import argparse
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import List, Optional
from src.categories import id_to_name, load_categories
from src.model import DEFAULT_MODEL_PATH, combine_features, load_model
from src.cross_encoder_model import (
CrossEncoderClassifier,
FineTunedCrossEncoderClassifier,
)
from src.multi_tower_model import MultiTowerClassifier, QueryContext
# Model classes that SQLErrorClassifier.predict feeds a QueryContext
# (question, schema, reference query, student query, error message) instead
# of the single flattened string produced by combine_features(); predict()
# dispatches on isinstance() against this tuple.
CONTEXT_MODELS = (CrossEncoderClassifier, FineTunedCrossEncoderClassifier,
                  MultiTowerClassifier)
@dataclass
class Prediction:
    """Outcome of classifying one student SQL query.

    Attributes:
        label_id: Integer id of the highest-probability category.
        label_name: Display name for ``label_id``.
        confidence: Probability the model assigned to ``label_id``.
        top_k: Ranked candidates, best first; each entry is a dict with the
            keys ``label_id``, ``label_name`` and ``confidence``.
        similarities: Explanation returned by the model's
            ``explain_similarities`` when SQLErrorClassifier used a
            MultiTowerClassifier; otherwise None.
        pair_scores: Explanation returned by the model's
            ``explain_pair_scores`` when SQLErrorClassifier used a
            CrossEncoderClassifier; otherwise None.
    """

    label_id: int
    label_name: str
    confidence: float
    top_k: List[dict]

    # Optional, model-specific explanation payloads.
    similarities: Optional[dict] = None
    pair_scores: Optional[dict] = None
class SQLErrorClassifier:
    """Classifier wrapper for playground integration.

    Loads a trained model and turns its class probabilities into a
    :class:`Prediction`. The input path depends on the loaded model's type:

    * models listed in ``CONTEXT_MODELS`` receive a ``QueryContext`` built
      from the question, schema, reference query and student query;
    * every other model receives a single text string built by
      ``combine_features``.
    """

    def __init__(self, model_path: Path = DEFAULT_MODEL_PATH):
        """Load the model at *model_path* and the category id -> name map."""
        self.model = load_model(model_path)
        self.label_map = id_to_name(load_categories())

    def predict(
        self,
        query: str,
        error_message: Optional[str] = None,
        schema: Optional[str] = None,
        question: Optional[str] = None,
        correct_query: Optional[str] = None,
        top_k: Optional[int] = 3,
    ) -> Prediction:
        """Classify the error type of a student SQL query.

        Args:
            query: The student's SQL query.
            error_message: Optional database error text for the query.
            schema: Schema text. Required (non-empty) for context models.
            question: Task description. Required (non-empty) for context
                models.
            correct_query: Reference query. Required (non-empty) for context
                models; ignored by text models.
            top_k: Maximum number of ranked candidates placed in
                ``Prediction.top_k``. Must be non-negative; ``None`` keeps
                every class.

        Returns:
            A Prediction for the highest-probability class, plus the model's
            explanation payloads where the model type provides them.

        Raises:
            ValueError: If ``top_k`` is negative, or if the model is a context
                model and any of ``schema``/``question``/``correct_query`` is
                missing or empty.
        """
        # A negative bound would turn ``ranked[:top_k]`` into "all but the
        # last |top_k| classes" rather than a limit, so reject it up front.
        if top_k is not None and top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")

        if isinstance(self.model, CONTEXT_MODELS):
            proba, similarities, pair_scores = self._predict_with_context(
                query, error_message, schema, question, correct_query
            )
        else:
            proba = self._predict_from_text(query, error_message, schema, question)
            similarities = None
            pair_scores = None

        return self._build_prediction(proba, top_k, similarities, pair_scores)

    def _predict_with_context(
        self,
        query: str,
        error_message: Optional[str],
        schema: Optional[str],
        question: Optional[str],
        correct_query: Optional[str],
    ):
        """Run a context model; return (probabilities, similarities, pair_scores)."""
        # all() also rejects empty strings, not just None.
        if not all([schema, question, correct_query]):
            raise ValueError(
                "context models require schema, question, and correct_query"
            )
        ctx = QueryContext(
            question=question,
            schema=schema,
            correct_query=correct_query,
            student_query=query,
            error_message=error_message,
        )
        proba = self.model.predict_proba([ctx])[0]
        # Explanation hooks are only called on the model types checked below.
        similarities = (
            self.model.explain_similarities(ctx)
            if isinstance(self.model, MultiTowerClassifier)
            else None
        )
        # NOTE(review): FineTunedCrossEncoderClassifier gets pair scores only
        # if it subclasses CrossEncoderClassifier — confirm that is intended.
        pair_scores = (
            self.model.explain_pair_scores(ctx)
            if isinstance(self.model, CrossEncoderClassifier)
            else None
        )
        return proba, similarities, pair_scores

    def _predict_from_text(
        self,
        query: str,
        error_message: Optional[str],
        schema: Optional[str],
        question: Optional[str],
    ):
        """Run a text model on the combine_features() string; return probabilities."""
        # Empty/None optional fields are passed as None, not as [""].
        text = combine_features(
            queries=[query],
            error_messages=[error_message] if error_message else None,
            schemas=[schema] if schema else None,
            questions=[question] if question else None,
        )[0]
        return self.model.predict_proba([text])[0]

    def _build_prediction(
        self,
        proba,
        top_k: Optional[int],
        similarities: Optional[dict],
        pair_scores: Optional[dict],
    ) -> Prediction:
        """Rank classes by probability and package them as a Prediction."""
        # Pair each class id with its probability, best first. sorted() is
        # stable, so tied classes keep the model's classes_ order.
        ranked = sorted(
            zip(self.model.classes_, proba), key=lambda pair: pair[1], reverse=True
        )
        best_cls, best_p = ranked[0]
        best_id = int(best_cls)
        return Prediction(
            label_id=best_id,
            label_name=self.label_map[best_id],
            confidence=float(best_p),
            top_k=[
                {
                    "label_id": int(cls),
                    "label_name": self.label_map[int(cls)],
                    "confidence": float(p),
                }
                for cls, p in ranked[:top_k]
            ],
            similarities=similarities,
            pair_scores=pair_scores,
        )
def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point: classify one SQL query and print JSON.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``, e.g. for tests
            or programmatic use. ``None`` (the default) reads the real
            command line.
    """
    parser = argparse.ArgumentParser(description="Classify SQL error type")
    parser.add_argument("--query", type=str, required=True)
    parser.add_argument(
        "--correct-query",
        type=str,
        default=None,
        help="reference query (required by context models)",
    )
    parser.add_argument("--error-message", type=str, default=None)
    parser.add_argument(
        "--schema",
        type=str,
        default=None,
        help="schema text (required by context models)",
    )
    parser.add_argument(
        "--question",
        type=str,
        default=None,
        help="task description (required by context models)",
    )
    parser.add_argument("--model", type=Path, default=DEFAULT_MODEL_PATH)
    parser.add_argument("--top-k", type=int, default=3)
    args = parser.parse_args(argv)

    # A negative value would act as a negative slice bound downstream.
    if args.top_k < 0:
        parser.error("--top-k must be non-negative")

    clf = SQLErrorClassifier(args.model)
    # Pass by keyword: predict() takes several adjacent optional str
    # parameters, so a positional call is easy to mis-order silently.
    result = clf.predict(
        query=args.query,
        error_message=args.error_message,
        schema=args.schema,
        question=args.question,
        correct_query=args.correct_query,
        top_k=args.top_k,
    )
    print(json.dumps(asdict(result), indent=2))
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()