Spaces:
Sleeping
Sleeping
| """Multi-tower semantic comparison architecture for SQL error classification.""" | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import List, Optional | |
| import numpy as np | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.preprocessing import StandardScaler | |
| from src.model import DEFAULT_ENCODER | |
| from src.sql_features import extract_sql_features | |
@dataclass
class QueryContext:
    """Inputs available in the SQL playground at inference time.

    Declared as a dataclass so instances can be built with keyword
    arguments (as ``contexts_from_dataframe`` does) and get a generated
    ``__init__``, ``__repr__`` and ``__eq__``.
    """

    # Natural-language task statement shown to the student.
    question: str
    # Textual description of the database schema the queries run against.
    schema: str
    # Reference (ground-truth) SQL solution for the question.
    correct_query: str
    # SQL the student actually submitted.
    student_query: str
    # Error text associated with the student query, if any; None when absent.
    error_message: Optional[str] = None
| def _cosine(a: np.ndarray, b: np.ndarray) -> np.ndarray: | |
| denom = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1) | |
| denom = np.maximum(denom, 1e-8) | |
| return np.sum(a * b, axis=1) / denom | |
class MultiTowerClassifier:
    """Multi-tower semantic comparison classifier for SQL errors.

    Three text "towers" are embedded with one shared SentenceTransformer
    encoder (selected by ``encoder_name``):

    1. Intent tower    - question + schema: what should be answered.
    2. Reference tower - correct_query: the ground-truth solution.
    3. Student tower   - student_query (+ error message): what was written.

    The comparison layer concatenates:

    - the three tower embeddings,
    - ``|student - reference|`` (what changed),
    - ``student * reference`` (element-wise interaction),
    - three pairwise cosine similarities (semantic alignment),
    - SQL structural features from ``extract_sql_features``.

    A light linear head (standardisation followed by logistic regression)
    maps the fused vector to the error categories present in the training
    labels (the design notes mention 15 categories).
    """

    def __init__(
        self,
        encoder_name: str = DEFAULT_ENCODER,
        batch_size: int = 256,
    ):
        """Configure the classifier; the encoder itself is loaded lazily.

        Args:
            encoder_name: Model identifier passed to ``SentenceTransformer``.
            batch_size: Batch size forwarded to ``encoder.encode``.
        """
        self.encoder_name = encoder_name
        self.batch_size = batch_size
        # Created on first use by _load_encoder().
        self.encoder = None
        self.scaler = StandardScaler()
        self.clf = LogisticRegression(
            max_iter=1000,
            solver="lbfgs",
            class_weight="balanced",
            random_state=42,
        )
        # Populated by fit() from the fitted logistic regression.
        self.classes_: Optional[np.ndarray] = None

    def _load_encoder(self):
        """Instantiate the sentence encoder on first use."""
        if self.encoder is None:
            # Imported here so the dependency is only needed once encoding
            # is actually requested.
            from sentence_transformers import SentenceTransformer

            self.encoder = SentenceTransformer(self.encoder_name)

    def _encode(self, texts: List[str], show_progress: bool = False) -> np.ndarray:
        """Embed ``texts`` with the shared encoder and return a NumPy array."""
        self._load_encoder()
        return self.encoder.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=show_progress,
            convert_to_numpy=True,
        )

    # The three text builders do not touch instance state. They are declared
    # static so that ``self._intent_text(ctx)`` passes ``ctx`` as the only
    # argument instead of failing with a TypeError.
    @staticmethod
    def _intent_text(ctx: QueryContext) -> str:
        """Build the intent-tower input from the question and schema."""
        return f"QUESTION: {ctx.question} SCHEMA: {ctx.schema}"

    @staticmethod
    def _reference_text(ctx: QueryContext) -> str:
        """Build the reference-tower input from the correct query."""
        return f"REFERENCE: {ctx.correct_query}"

    @staticmethod
    def _student_text(ctx: QueryContext) -> str:
        """Build the student-tower input; appends the error message if set."""
        parts = [f"STUDENT: {ctx.student_query}"]
        if ctx.error_message:
            parts.append(f"ERROR: {ctx.error_message}")
        return " ".join(parts)

    @staticmethod
    def _sanitize(X: np.ndarray) -> np.ndarray:
        """Replace NaN with 0 and clamp +/-inf to +/-1e3."""
        return np.nan_to_num(X, nan=0.0, posinf=1e3, neginf=-1e3)

    def _build_feature_matrix(
        self,
        contexts: List[QueryContext],
        show_progress: bool = False,
    ) -> np.ndarray:
        """Return the fused (unscaled) feature matrix, one row per context.

        Column order: intent, reference and student embeddings, absolute
        difference, element-wise product, the three cosine similarities
        (student/reference, student/intent, reference/intent), then the SQL
        structural features.

        Args:
            contexts: Examples to featurise.
            show_progress: Forwarded to the intent-tower encoding only; the
                other two towers are always encoded without a progress bar.
        """
        intent_texts = [self._intent_text(c) for c in contexts]
        ref_texts = [self._reference_text(c) for c in contexts]
        student_texts = [self._student_text(c) for c in contexts]

        intent_emb = self._encode(intent_texts, show_progress)
        ref_emb = self._encode(ref_texts, show_progress=False)
        student_emb = self._encode(student_texts, show_progress=False)

        diff = np.abs(student_emb - ref_emb)
        prod = student_emb * ref_emb

        # Each similarity becomes a single column so it can be hstacked.
        cos_sr = _cosine(student_emb, ref_emb).reshape(-1, 1)
        cos_si = _cosine(student_emb, intent_emb).reshape(-1, 1)
        cos_ri = _cosine(ref_emb, intent_emb).reshape(-1, 1)

        sql_feats = np.array(
            [
                extract_sql_features(c.student_query, c.correct_query)
                for c in contexts
            ],
            dtype=np.float64,
        )
        return np.hstack(
            [intent_emb, ref_emb, student_emb, diff, prod, cos_sr, cos_si, cos_ri, sql_feats]
        )

    def fit(self, contexts: List[QueryContext], y: np.ndarray) -> "MultiTowerClassifier":
        """Fit the scaler and the logistic-regression head.

        Args:
            contexts: Training examples.
            y: Target labels, one per context.

        Returns:
            ``self``, to allow call chaining.
        """
        X = self._build_feature_matrix(contexts, show_progress=True)
        # Apply the same sanitisation as at inference time so training and
        # prediction see identically preprocessed features.
        X = self._sanitize(self.scaler.fit_transform(X))
        self.clf.fit(X, y)
        self.classes_ = self.clf.classes_
        return self

    def _prepare_features(self, contexts: List[QueryContext]) -> np.ndarray:
        """Featurise, scale with the fitted scaler, and sanitise."""
        X = self.scaler.transform(self._build_feature_matrix(contexts))
        return self._sanitize(X)

    def predict(self, contexts: List[QueryContext]) -> np.ndarray:
        """Return the predicted label for each context."""
        return self.clf.predict(self._prepare_features(contexts))

    def predict_proba(self, contexts: List[QueryContext]) -> np.ndarray:
        """Return per-class probabilities for each context."""
        return self.clf.predict_proba(self._prepare_features(contexts))

    def explain_similarities(self, ctx: QueryContext) -> dict:
        """Diagnostic scores for the playground UI.

        Only the three tower embeddings are computed; the full feature
        matrix is not needed for these scores.

        Returns:
            A dict with float cosine similarities under the keys
            ``student_vs_reference``, ``student_vs_intent`` and
            ``reference_vs_intent``.
        """
        intent_emb = self._encode([self._intent_text(ctx)])
        ref_emb = self._encode([self._reference_text(ctx)])
        student_emb = self._encode([self._student_text(ctx)])
        return {
            "student_vs_reference": float(_cosine(student_emb, ref_emb)[0]),
            "student_vs_intent": float(_cosine(student_emb, intent_emb)[0]),
            "reference_vs_intent": float(_cosine(ref_emb, intent_emb)[0]),
        }
def contexts_from_dataframe(df) -> List[QueryContext]:
    """Build QueryContext list from a training dataframe.

    Reads the columns ``question``, ``schema``, ``correct_query`` and
    ``query`` (the student's SQL). The ``error_message`` column is optional.

    Missing error messages (``None``, ``NaN``, ``pd.NA``) are normalised to
    ``None``. Without this a NaN cell, which is truthy, would later be
    rendered into the student text as ``"ERROR: nan"``.

    Args:
        df: A pandas DataFrame with one row per example.

    Returns:
        One ``QueryContext`` per row, in row order.
    """
    # Local import: pandas is only needed here for missing-value detection,
    # and a DataFrame argument implies it is installed.
    import pandas as pd

    has_error = "error_message" in df.columns

    def _error_or_none(row):
        if not has_error:
            return None
        value = row["error_message"]
        # Strings are never "missing"; skip pd.isna for the common case.
        if isinstance(value, str):
            return value
        return None if pd.isna(value) else value

    return [
        QueryContext(
            question=row["question"],
            schema=row["schema"],
            correct_query=row["correct_query"],
            student_query=row["query"],
            error_message=_error_or_none(row),
        )
        for row in df.to_dict("records")
    ]