"""Multi-tower semantic comparison architecture for SQL error classification.""" from __future__ import annotations from dataclasses import dataclass from typing import List, Optional import numpy as np from sklearn.linear_model import LogisticRegression from sklearn.preprocessing import StandardScaler from src.model import DEFAULT_ENCODER from src.sql_features import extract_sql_features @dataclass class QueryContext: """Inputs available in the SQL playground at inference time.""" question: str schema: str correct_query: str student_query: str error_message: Optional[str] = None def _cosine(a: np.ndarray, b: np.ndarray) -> np.ndarray: denom = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1) denom = np.maximum(denom, 1e-8) return np.sum(a * b, axis=1) / denom class MultiTowerClassifier: """ Recommended architecture for SQL error classification. Three semantic towers (shared MiniLM encoder): 1. Intent tower — question + schema → what should be answered 2. Reference tower — correct_query → ground-truth solution 3. Student tower — student_query → what the student wrote Comparison layer fuses: - tower embeddings - |student − reference| (what changed) - student ⊙ reference (interaction) - cosine similarities (semantic alignment) - SQL structural features (join/null/agg rules) A light linear head maps the fused vector → 15 error categories. """ def __init__( self, encoder_name: str = DEFAULT_ENCODER, batch_size: int = 256, ): self.encoder_name = encoder_name self.batch_size = batch_size self.encoder = None self.scaler = StandardScaler() self.clf = LogisticRegression( max_iter=1000, solver="lbfgs", class_weight="balanced", random_state=42, ) self.classes_: Optional[np.ndarray] = None def _load_encoder(self): if self.encoder is None: from sentence_transformers import SentenceTransformer self.encoder = SentenceTransformer(self.encoder_name) def _encode(self, texts: List[str], show_progress: bool = False) -> np.ndarray: self._load_encoder() return self.encoder.encode( texts, batch_size=self.batch_size, show_progress_bar=show_progress, convert_to_numpy=True, ) @staticmethod def _intent_text(ctx: QueryContext) -> str: return f"QUESTION: {ctx.question} SCHEMA: {ctx.schema}" @staticmethod def _reference_text(ctx: QueryContext) -> str: return f"REFERENCE: {ctx.correct_query}" @staticmethod def _student_text(ctx: QueryContext) -> str: parts = [f"STUDENT: {ctx.student_query}"] if ctx.error_message: parts.append(f"ERROR: {ctx.error_message}") return " ".join(parts) def _build_feature_matrix( self, contexts: List[QueryContext], show_progress: bool = False, ) -> np.ndarray: intent_texts = [self._intent_text(c) for c in contexts] ref_texts = [self._reference_text(c) for c in contexts] student_texts = [self._student_text(c) for c in contexts] intent_emb = self._encode(intent_texts, show_progress) ref_emb = self._encode(ref_texts, show_progress=False) student_emb = self._encode(student_texts, show_progress=False) diff = np.abs(student_emb - ref_emb) prod = student_emb * ref_emb cos_sr = _cosine(student_emb, ref_emb).reshape(-1, 1) cos_si = _cosine(student_emb, intent_emb).reshape(-1, 1) cos_ri = _cosine(ref_emb, intent_emb).reshape(-1, 1) sql_feats = np.array( [ extract_sql_features(c.student_query, c.correct_query) for c in contexts ], dtype=np.float64, ) return np.hstack( [intent_emb, ref_emb, student_emb, diff, prod, cos_sr, cos_si, cos_ri, sql_feats] ) def fit(self, contexts: List[QueryContext], y: np.ndarray) -> "MultiTowerClassifier": X = self._build_feature_matrix(contexts, show_progress=True) X = self.scaler.fit_transform(X) self.clf.fit(X, y) self.classes_ = self.clf.classes_ return self def _prepare_features(self, contexts: List[QueryContext]) -> np.ndarray: X = self.scaler.transform(self._build_feature_matrix(contexts)) return np.nan_to_num(X, nan=0.0, posinf=1e3, neginf=-1e3) def predict(self, contexts: List[QueryContext]) -> np.ndarray: return self.clf.predict(self._prepare_features(contexts)) def predict_proba(self, contexts: List[QueryContext]) -> np.ndarray: return self.clf.predict_proba(self._prepare_features(contexts)) def explain_similarities(self, ctx: QueryContext) -> dict: """Diagnostic scores for the playground UI.""" emb = self._build_feature_matrix([ctx]) intent_texts = [self._intent_text(ctx)] ref_texts = [self._reference_text(ctx)] student_texts = [self._student_text(ctx)] intent_emb = self._encode(intent_texts) ref_emb = self._encode(ref_texts) student_emb = self._encode(student_texts) return { "student_vs_reference": float(_cosine(student_emb, ref_emb)[0]), "student_vs_intent": float(_cosine(student_emb, intent_emb)[0]), "reference_vs_intent": float(_cosine(ref_emb, intent_emb)[0]), } def contexts_from_dataframe(df) -> List[QueryContext]: """Build QueryContext list from a training dataframe.""" has_error = "error_message" in df.columns return [ QueryContext( question=row["question"], schema=row["schema"], correct_query=row["correct_query"], student_query=row["query"], error_message=row["error_message"] if has_error else None, ) for row in df.to_dict("records") ]