Spaces:
Sleeping
Sleeping
File size: 6,111 Bytes
"""Multi-tower semantic comparison architecture for SQL error classification."""
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from src.model import DEFAULT_ENCODER
from src.sql_features import extract_sql_features
@dataclass
class QueryContext:
    """One student attempt, as seen by the SQL playground at inference time.

    Positional construction follows declaration order:
    ``question, schema, correct_query, student_query, error_message``.
    """

    # Task text the attempt is answering.
    question: str
    # Schema text paired with the question.
    schema: str
    # Reference SQL used as the ground-truth comparison target.
    correct_query: str
    # SQL submitted by the student.
    student_query: str
    # Optional error text tied to the student query; defaults to None.
    error_message: Optional[str] = None
def _cosine(a: np.ndarray, b: np.ndarray) -> np.ndarray:
denom = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
denom = np.maximum(denom, 1e-8)
return np.sum(a * b, axis=1) / denom
class MultiTowerClassifier:
    """
    Multi-tower semantic comparison classifier for SQL error categories.

    Three semantic towers share one SentenceTransformer encoder
    (``encoder_name``; ``DEFAULT_ENCODER`` unless overridden):
      1. Intent tower    - question + schema -> what should be answered
      2. Reference tower - correct_query     -> ground-truth solution
      3. Student tower   - student_query (+ error_message when present)
                           -> what the student wrote
    The comparison layer concatenates, in this order:
      - the intent, reference and student embeddings
      - |student - reference|   (element-wise; what changed)
      - student * reference     (element-wise interaction)
      - cosine similarities student/reference, student/intent and
        reference/intent        (semantic alignment)
      - SQL structural features from ``extract_sql_features``
    The fused vector is standardized and fed to a class-balanced
    logistic-regression head whose classes are the labels passed to ``fit``.
    """

    def __init__(
        self,
        encoder_name: str = DEFAULT_ENCODER,
        batch_size: int = 256,
    ):
        """
        Args:
            encoder_name: SentenceTransformer model name used by all towers.
            batch_size: Batch size forwarded to the encoder's ``encode``.
        """
        self.encoder_name = encoder_name
        self.batch_size = batch_size
        # Loaded lazily on first encode so construction stays cheap.
        self.encoder = None
        self.scaler = StandardScaler()
        self.clf = LogisticRegression(
            max_iter=1000,
            solver="lbfgs",
            class_weight="balanced",
            random_state=42,
        )
        # Populated by ``fit`` from the fitted logistic-regression head.
        self.classes_: Optional[np.ndarray] = None

    def _load_encoder(self) -> None:
        """Instantiate the SentenceTransformer on first use."""
        if self.encoder is None:
            # Deferred import: sentence_transformers is only needed to encode.
            from sentence_transformers import SentenceTransformer

            self.encoder = SentenceTransformer(self.encoder_name)

    def _encode(self, texts: List[str], show_progress: bool = False) -> np.ndarray:
        """Embed ``texts`` with the shared encoder and return a NumPy array."""
        self._load_encoder()
        return self.encoder.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=show_progress,
            convert_to_numpy=True,
        )

    @staticmethod
    def _intent_text(ctx: QueryContext) -> str:
        """Text for the intent tower: question plus schema."""
        return f"QUESTION: {ctx.question} SCHEMA: {ctx.schema}"

    @staticmethod
    def _reference_text(ctx: QueryContext) -> str:
        """Text for the reference tower: the correct query."""
        return f"REFERENCE: {ctx.correct_query}"

    @staticmethod
    def _student_text(ctx: QueryContext) -> str:
        """Text for the student tower: the student query plus any error."""
        parts = [f"STUDENT: {ctx.student_query}"]
        # None or "" means no error segment is appended.
        if ctx.error_message:
            parts.append(f"ERROR: {ctx.error_message}")
        return " ".join(parts)

    def _encode_towers(
        self,
        contexts: List[QueryContext],
        show_progress: bool = False,
    ):
        """Return ``(intent_emb, ref_emb, student_emb)`` for ``contexts``.

        Each tower is encoded in its own call. Only the first call (intent)
        honours ``show_progress``, so a single progress bar is shown.
        """
        intent_emb = self._encode(
            [self._intent_text(c) for c in contexts], show_progress
        )
        ref_emb = self._encode(
            [self._reference_text(c) for c in contexts], show_progress=False
        )
        student_emb = self._encode(
            [self._student_text(c) for c in contexts], show_progress=False
        )
        return intent_emb, ref_emb, student_emb

    def _build_feature_matrix(
        self,
        contexts: List[QueryContext],
        show_progress: bool = False,
    ) -> np.ndarray:
        """Build the fused (unscaled) feature matrix, one row per context.

        Raises:
            ValueError: if ``contexts`` is empty.
        """
        if len(contexts) == 0:
            # Fail fast with a clear message rather than an obscure shape
            # error further down the pipeline.
            raise ValueError("contexts must contain at least one QueryContext")

        intent_emb, ref_emb, student_emb = self._encode_towers(
            contexts, show_progress
        )

        diff = np.abs(student_emb - ref_emb)
        prod = student_emb * ref_emb
        cos_sr = _cosine(student_emb, ref_emb).reshape(-1, 1)
        cos_si = _cosine(student_emb, intent_emb).reshape(-1, 1)
        cos_ri = _cosine(ref_emb, intent_emb).reshape(-1, 1)

        sql_feats = np.array(
            [
                extract_sql_features(c.student_query, c.correct_query)
                for c in contexts
            ],
            dtype=np.float64,
        )

        # Column order is part of the fitted model's contract; do not reorder.
        return np.hstack(
            [intent_emb, ref_emb, student_emb, diff, prod, cos_sr, cos_si, cos_ri, sql_feats]
        )

    @staticmethod
    def _sanitize(X: np.ndarray) -> np.ndarray:
        """Replace non-finite values in a scaled feature matrix.

        StandardScaler keeps NaNs through ``transform``. After
        standardization 0.0 is the feature mean, so NaN becomes 0.0 and
        +/-inf is clipped to +/-1e3.
        """
        return np.nan_to_num(X, nan=0.0, posinf=1e3, neginf=-1e3)

    def fit(self, contexts: List[QueryContext], y: np.ndarray) -> "MultiTowerClassifier":
        """Fit the scaler and the classification head on ``contexts`` / ``y``.

        Returns:
            ``self``, with ``classes_`` set to the head's fitted classes.
        """
        X = self._build_feature_matrix(contexts, show_progress=True)
        # Sanitize exactly as at inference time so training and prediction
        # see the same features, and the head never receives NaN.
        X = self._sanitize(self.scaler.fit_transform(X))
        self.clf.fit(X, y)
        self.classes_ = self.clf.classes_
        return self

    def _prepare_features(self, contexts: List[QueryContext]) -> np.ndarray:
        """Build, scale (with the fitted scaler) and sanitize features."""
        return self._sanitize(
            self.scaler.transform(self._build_feature_matrix(contexts))
        )

    def predict(self, contexts: List[QueryContext]) -> np.ndarray:
        """Predict one error category per context (requires a prior ``fit``)."""
        return self.clf.predict(self._prepare_features(contexts))

    def predict_proba(self, contexts: List[QueryContext]) -> np.ndarray:
        """Class probabilities per context; columns follow ``self.classes_``."""
        return self.clf.predict_proba(self._prepare_features(contexts))

    def explain_similarities(self, ctx: QueryContext) -> dict:
        """Diagnostic tower similarities for the playground UI.

        Encodes ``ctx`` through the same tower path as the feature matrix,
        but skips SQL feature extraction. The model does not need to be
        fitted.

        Returns:
            A dict with float cosine similarities under the keys
            ``student_vs_reference``, ``student_vs_intent`` and
            ``reference_vs_intent``.
        """
        intent_emb, ref_emb, student_emb = self._encode_towers([ctx])
        return {
            "student_vs_reference": float(_cosine(student_emb, ref_emb)[0]),
            "student_vs_intent": float(_cosine(student_emb, intent_emb)[0]),
            "reference_vs_intent": float(_cosine(ref_emb, intent_emb)[0]),
        }
def contexts_from_dataframe(df) -> List[QueryContext]:
    """Build one ``QueryContext`` per row of a training dataframe.

    Reads the ``question``, ``schema``, ``correct_query`` and ``query``
    columns, with ``query`` becoming ``student_query``. The
    ``error_message`` column is optional. Missing error messages
    (None / NaN / pd.NA) are normalized to ``None``.

    Raises:
        KeyError: if one of the required columns is absent.
    """
    import pandas as pd

    has_error = "error_message" in df.columns

    def _error_of(row: dict) -> Optional[str]:
        # Return the row's error message, or None when absent/missing.
        if not has_error:
            return None
        value = row["error_message"]
        # pandas stores missing cells as NaN or pd.NA. NaN is truthy, which
        # would append "ERROR: nan" to the student text, and bool(pd.NA)
        # raises, so both are mapped to None.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return None
        return value

    return [
        QueryContext(
            question=row["question"],
            schema=row["schema"],
            correct_query=row["correct_query"],
            student_query=row["query"],
            error_message=_error_of(row),
        )
        for row in df.to_dict("records")
    ]
|