File size: 4,220 Bytes
9b2cded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""Inference API for SQL error classification."""

from __future__ import annotations

import argparse
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import List, Optional

from src.categories import id_to_name, load_categories
from src.model import DEFAULT_MODEL_PATH, combine_features, load_model
from src.cross_encoder_model import (
    CrossEncoderClassifier,
    FineTunedCrossEncoderClassifier,
)
from src.multi_tower_model import MultiTowerClassifier, QueryContext

# Model classes that SQLErrorClassifier.predict feeds with a QueryContext
# (rather than a combined feature string). For these models, predict requires
# schema, question, and correct_query to be supplied.
CONTEXT_MODELS = (
    CrossEncoderClassifier,
    FineTunedCrossEncoderClassifier,
    MultiTowerClassifier,
)


@dataclass
class Prediction:
    """Result of a single classification call.

    Instances are built by ``SQLErrorClassifier.predict`` and serialised to
    JSON by ``main`` via ``dataclasses.asdict``.
    """

    # Class id with the highest score from the model's predict_proba.
    label_id: int
    # Name looked up for label_id in the classifier's label map.
    label_name: str
    # predict_proba score of the best class, as a float.
    confidence: float
    # Highest-scoring classes in descending score order; each entry is a dict
    # with the keys "label_id", "label_name", and "confidence".
    top_k: List[dict]
    # Result of the model's explain_similarities; only populated when the
    # loaded model is a MultiTowerClassifier, otherwise None.
    similarities: Optional[dict] = None
    # Result of the model's explain_pair_scores; only populated when the
    # loaded model is a CrossEncoderClassifier, otherwise None.
    pair_scores: Optional[dict] = None


class SQLErrorClassifier:
    """Classifier wrapper for playground integration.

    Loads a model once at construction time and exposes ``predict``, which
    returns a ``Prediction`` holding the best label and a ranked top-k list.
    """

    def __init__(self, model_path: Path = DEFAULT_MODEL_PATH):
        """Load the model from *model_path* and build the id-to-name label map."""
        self.model = load_model(model_path)
        self.label_map = id_to_name(load_categories())

    def predict(
        self,
        query: str,
        error_message: Optional[str] = None,
        schema: Optional[str] = None,
        question: Optional[str] = None,
        correct_query: Optional[str] = None,
        top_k: int = 3,
    ) -> Prediction:
        """Classify the error in *query* and return the ranked prediction.

        Args:
            query: The SQL query to classify (passed as ``student_query`` to
                context models).
            error_message: Optional error text associated with the query.
            schema: Optional schema text; required for context models.
            question: Optional task/question text; required for context
                models.
            correct_query: Optional reference query; required for context
                models and not used by the text-feature path.
            top_k: Maximum number of ranked labels to include in
                ``Prediction.top_k``. ``0`` yields an empty list.

        Returns:
            A ``Prediction`` with the highest-scoring label, its score, the
            ``top_k`` ranking, and (for the model types that support them)
            similarity / pair-score explanations.

        Raises:
            ValueError: If *top_k* is negative, or if the loaded model is one
                of ``CONTEXT_MODELS`` and any of *schema*, *question*, or
                *correct_query* is missing or empty.
        """
        # A negative value would otherwise be used as a slice bound below and
        # silently drop entries from the *end* of the ranking.
        if top_k < 0:
            raise ValueError("top_k must be non-negative")

        # Explanations are only available for specific model types.
        similarities = None
        pair_scores = None

        if isinstance(self.model, CONTEXT_MODELS):
            if not all([schema, question, correct_query]):
                raise ValueError(
                    "context models require schema, question, and correct_query"
                )
            ctx = QueryContext(
                question=question,
                schema=schema,
                correct_query=correct_query,
                student_query=query,
                error_message=error_message,
            )
            proba = self.model.predict_proba([ctx])[0]
            if isinstance(self.model, MultiTowerClassifier):
                similarities = self.model.explain_similarities(ctx)
            # NOTE(review): only CrossEncoderClassifier is checked here; whether
            # FineTunedCrossEncoderClassifier should also report pair scores
            # depends on its class hierarchy -- confirm.
            if isinstance(self.model, CrossEncoderClassifier):
                pair_scores = self.model.explain_pair_scores(ctx)
        else:
            # Text-feature models take a single combined string; optional
            # parts that were not supplied are passed as None.
            text = combine_features(
                queries=[query],
                error_messages=[error_message] if error_message else None,
                schemas=[schema] if schema else None,
                questions=[question] if question else None,
            )[0]
            proba = self.model.predict_proba([text])[0]

        # Pair each class id with its score and order best-first.
        classes = self.model.classes_
        ranked = sorted(zip(classes, proba), key=lambda x: x[1], reverse=True)
        best_id = int(ranked[0][0])

        return Prediction(
            label_id=best_id,
            label_name=self.label_map[best_id],
            confidence=float(ranked[0][1]),
            top_k=[
                {
                    "label_id": int(cls),
                    "label_name": self.label_map[int(cls)],
                    "confidence": float(p),
                }
                for cls, p in ranked[:top_k]
            ],
            similarities=similarities,
            pair_scores=pair_scores,
        )


def main() -> None:
    """Command-line entry point: classify one query and print the result as JSON.

    Parses the CLI flags, loads the classifier from ``--model``, runs a single
    prediction, and prints the ``Prediction`` as indented JSON on stdout.
    """
    cli = argparse.ArgumentParser(description="Classify SQL error type")
    cli.add_argument("--query", type=str, required=True)
    # The optional free-text inputs all share the same argument shape.
    for flag in ("--correct-query", "--error-message", "--schema", "--question"):
        cli.add_argument(flag, type=str, default=None)
    cli.add_argument("--model", type=Path, default=DEFAULT_MODEL_PATH)
    cli.add_argument("--top-k", type=int, default=3)
    opts = cli.parse_args()

    classifier = SQLErrorClassifier(opts.model)
    prediction = classifier.predict(
        query=opts.query,
        error_message=opts.error_message,
        schema=opts.schema,
        question=opts.question,
        correct_query=opts.correct_query,
        top_k=opts.top_k,
    )
    payload = asdict(prediction)
    print(json.dumps(payload, indent=2))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()