File size: 6,111 Bytes
9b2cded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
"""Multi-tower semantic comparison architecture for SQL error classification."""

from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

from src.model import DEFAULT_ENCODER
from src.sql_features import extract_sql_features


@dataclass
class QueryContext:
    """One student attempt, as seen by the SQL playground at inference time.

    The four required fields are plain strings; ``error_message`` is the only
    optional one and defaults to ``None`` when no error text is available.
    """

    # Task description and the schema text it refers to.
    question: str
    schema: str

    # Reference solution and the query the student submitted.
    correct_query: str
    student_query: str

    # Optional error text accompanying the student query.
    error_message: Optional[str] = None


def _cosine(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return the row-wise cosine similarity between two 2-D arrays.

    Row ``i`` of the result is the cosine of the angle between ``a[i]`` and
    ``b[i]``. The product of the norms is floored at ``1e-8`` so an all-zero
    row yields ``0.0`` rather than a division by zero.
    """
    dot_products = (a * b).sum(axis=1)
    norm_products = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
    return dot_products / np.maximum(norm_products, 1e-8)


class MultiTowerClassifier:
    """
    Recommended architecture for SQL error classification.

    Three semantic towers (shared MiniLM encoder):
      1. Intent tower    — question + schema → what should be answered
      2. Reference tower — correct_query     → ground-truth solution
      3. Student tower   — student_query     → what the student wrote

    Comparison layer fuses:
      - tower embeddings
      - |student − reference|  (what changed)
      - student ⊙ reference    (interaction)
      - cosine similarities    (semantic alignment)
      - SQL structural features (join/null/agg rules)

    A light linear head maps the fused vector → 15 error categories.
    """

    def __init__(
        self,
        encoder_name: str = DEFAULT_ENCODER,
        batch_size: int = 256,
    ):
        """Store the configuration and build the (unfitted) scaler and head.

        Args:
            encoder_name: Model name handed to ``SentenceTransformer`` when
                the encoder is first needed.
            batch_size: Batch size forwarded to ``encoder.encode``.
        """
        self.encoder_name = encoder_name
        self.batch_size = batch_size
        # Loaded lazily by _load_encoder on the first encode call.
        self.encoder = None
        self.scaler = StandardScaler()
        self.clf = LogisticRegression(
            max_iter=1000,
            solver="lbfgs",
            class_weight="balanced",
            random_state=42,
        )
        # Populated by fit() from the fitted head.
        self.classes_: Optional[np.ndarray] = None

    def _load_encoder(self):
        """Instantiate the sentence encoder once, on first use."""
        if self.encoder is None:
            # Function-scope import: sentence_transformers is only needed
            # once something is actually encoded.
            from sentence_transformers import SentenceTransformer

            self.encoder = SentenceTransformer(self.encoder_name)

    def _encode(self, texts: List[str], show_progress: bool = False) -> np.ndarray:
        """Embed ``texts`` with the shared encoder and return a numpy array."""
        self._load_encoder()
        return self.encoder.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=show_progress,
            convert_to_numpy=True,
        )

    @staticmethod
    def _intent_text(ctx: QueryContext) -> str:
        """Text fed to the intent tower: question plus schema."""
        return f"QUESTION: {ctx.question} SCHEMA: {ctx.schema}"

    @staticmethod
    def _reference_text(ctx: QueryContext) -> str:
        """Text fed to the reference tower: the correct query."""
        return f"REFERENCE: {ctx.correct_query}"

    @staticmethod
    def _student_text(ctx: QueryContext) -> str:
        """Text fed to the student tower: the student query, plus the error
        message when one is present (truthy)."""
        parts = [f"STUDENT: {ctx.student_query}"]
        if ctx.error_message:
            parts.append(f"ERROR: {ctx.error_message}")
        return " ".join(parts)

    @staticmethod
    def _sanitize(X: np.ndarray) -> np.ndarray:
        """Replace NaN with 0 and clamp +/-inf to +/-1e3 in scaled features."""
        return np.nan_to_num(X, nan=0.0, posinf=1e3, neginf=-1e3)

    def _tower_embeddings(
        self,
        contexts: List[QueryContext],
        show_progress: bool = False,
    ):
        """Encode all three towers for ``contexts``.

        Returns:
            ``(intent_emb, ref_emb, student_emb)``, one row per context.
        """
        # Only the first encode call may show a progress bar.
        intent_emb = self._encode(
            [self._intent_text(c) for c in contexts], show_progress
        )
        ref_emb = self._encode(
            [self._reference_text(c) for c in contexts], show_progress=False
        )
        student_emb = self._encode(
            [self._student_text(c) for c in contexts], show_progress=False
        )
        return intent_emb, ref_emb, student_emb

    def _build_feature_matrix(
        self,
        contexts: List[QueryContext],
        show_progress: bool = False,
    ) -> np.ndarray:
        """Build the fused (unscaled) feature matrix, one row per context.

        Column layout, left to right: intent embedding, reference embedding,
        student embedding, ``|student - reference|``, ``student * reference``,
        three cosine similarities (student/reference, student/intent,
        reference/intent), then the SQL structural features.
        """
        intent_emb, ref_emb, student_emb = self._tower_embeddings(
            contexts, show_progress
        )

        diff = np.abs(student_emb - ref_emb)
        prod = student_emb * ref_emb
        cos_sr = _cosine(student_emb, ref_emb).reshape(-1, 1)
        cos_si = _cosine(student_emb, intent_emb).reshape(-1, 1)
        cos_ri = _cosine(ref_emb, intent_emb).reshape(-1, 1)

        # NOTE(review): assumes extract_sql_features returns a fixed-length
        # numeric sequence per pair, so the rows stack into a 2-D array.
        sql_feats = np.array(
            [
                extract_sql_features(c.student_query, c.correct_query)
                for c in contexts
            ],
            dtype=np.float64,
        )

        return np.hstack(
            [intent_emb, ref_emb, student_emb, diff, prod, cos_sr, cos_si, cos_ri, sql_feats]
        )

    def fit(self, contexts: List[QueryContext], y: np.ndarray) -> "MultiTowerClassifier":
        """Fit the scaler and the linear head on ``contexts`` / ``y``.

        Returns:
            ``self``, so calls can be chained.
        """
        X = self._build_feature_matrix(contexts, show_progress=True)
        # Apply the same clean-up as inference so the head is trained on
        # features prepared exactly like the ones it will later be shown.
        X = self._sanitize(self.scaler.fit_transform(X))
        self.clf.fit(X, y)
        self.classes_ = self.clf.classes_
        return self

    def _prepare_features(self, contexts: List[QueryContext]) -> np.ndarray:
        """Build, scale and sanitize features for inference."""
        X = self.scaler.transform(self._build_feature_matrix(contexts))
        return self._sanitize(X)

    def predict(self, contexts: List[QueryContext]) -> np.ndarray:
        """Return the predicted class label for each context."""
        return self.clf.predict(self._prepare_features(contexts))

    def predict_proba(self, contexts: List[QueryContext]) -> np.ndarray:
        """Return per-class probabilities for each context."""
        return self.clf.predict_proba(self._prepare_features(contexts))

    def explain_similarities(self, ctx: QueryContext) -> dict:
        """Diagnostic scores for the playground UI.

        Returns:
            Dict with the three pairwise tower cosine similarities as floats,
            under the keys ``student_vs_reference``, ``student_vs_intent``
            and ``reference_vs_intent``.
        """
        # Encode each tower exactly once; the full feature matrix is not
        # needed for these diagnostics.
        intent_emb, ref_emb, student_emb = self._tower_embeddings([ctx])
        return {
            "student_vs_reference": float(_cosine(student_emb, ref_emb)[0]),
            "student_vs_intent": float(_cosine(student_emb, intent_emb)[0]),
            "reference_vs_intent": float(_cosine(ref_emb, intent_emb)[0]),
        }


def contexts_from_dataframe(df) -> List[QueryContext]:
    """Build a QueryContext list from a training dataframe.

    Reads the columns ``question``, ``schema``, ``correct_query`` and
    ``query`` (mapped to ``student_query``). The ``error_message`` column is
    optional; when it is absent every context gets ``error_message=None``.

    Missing error cells (``None`` / NaN / ``pd.NA``) are normalised to
    ``None``. NaN is truthy, so passing it through would make the student
    tower text end in a spurious "ERROR: nan".

    Args:
        df: Dataframe exposing ``columns`` and ``to_dict("records")``.

    Returns:
        One ``QueryContext`` per row, in row order.
    """
    # Function-scope import: only this helper needs pandas directly.
    import pandas as pd

    has_error = "error_message" in df.columns

    def _error_of(row):
        """Return the row's error message, or None when absent or missing."""
        if not has_error:
            return None
        value = row["error_message"]
        # is_scalar guards pd.isna against returning an array for list cells.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return None
        return value

    return [
        QueryContext(
            question=row["question"],
            schema=row["schema"],
            correct_query=row["correct_query"],
            student_query=row["query"],
            error_message=_error_of(row),
        )
        for row in df.to_dict("records")
    ]