sql-error-classifier-train / src /multi_tower_model.py
nishu08's picture
Deploy CodeBERT training Space
9b2cded verified
Raw
History Blame Contribute Delete
6.11 kB
"""Multi-tower semantic comparison architecture for SQL error classification."""
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from src.model import DEFAULT_ENCODER
from src.sql_features import extract_sql_features
@dataclass
class QueryContext:
    """One student attempt, as seen by the SQL playground when it asks for a prediction.

    Fields are positional in declaration order; only ``error_message`` has a
    default (``None``).
    """

    # Natural-language task the student was asked to solve.
    question: str
    # Schema text accompanying the question.
    schema: str
    # Ground-truth SQL solution for the question.
    correct_query: str
    # SQL the student actually submitted.
    student_query: str
    # Optional error message tied to the student's query; None when not provided.
    error_message: Optional[str] = None
def _cosine(a: np.ndarray, b: np.ndarray) -> np.ndarray:
denom = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
denom = np.maximum(denom, 1e-8)
return np.sum(a * b, axis=1) / denom
class MultiTowerClassifier:
    """
    Multi-tower semantic comparison classifier for SQL error categories.

    Three semantic towers share one SentenceTransformer encoder
    (``encoder_name``, default ``DEFAULT_ENCODER``):
      1. Intent tower — question + schema → what should be answered
      2. Reference tower — correct_query → ground-truth solution
      3. Student tower — student_query (+ optional error message) → what the
         student wrote

    The comparison layer concatenates, per example:
      - the three tower embeddings
      - |student − reference| (what changed)
      - student ⊙ reference (element-wise interaction)
      - cosine similarities student/reference, student/intent and
        reference/intent (semantic alignment)
      - SQL structural features from ``extract_sql_features``

    The fused vector is standardized (``StandardScaler``) and mapped by a
    class-balanced logistic-regression head onto the label set present in
    ``y`` at ``fit`` time. NOTE(review): the original doc said 15 error
    categories; the actual count is whatever ``y`` contains.
    """

    def __init__(
        self,
        encoder_name: str = DEFAULT_ENCODER,
        batch_size: int = 256,
    ):
        self.encoder_name = encoder_name
        self.batch_size = batch_size
        # Loaded lazily on first use (see _load_encoder).
        self.encoder = None
        self.scaler = StandardScaler()
        self.clf = LogisticRegression(
            max_iter=1000,
            solver="lbfgs",
            class_weight="balanced",
            random_state=42,
        )
        # Mirrors self.clf.classes_ once fit() has run.
        self.classes_: Optional[np.ndarray] = None

    def _load_encoder(self) -> None:
        """Instantiate the SentenceTransformer on first use.

        The import is local so importing this module does not require
        ``sentence_transformers`` until embeddings are actually needed.
        """
        if self.encoder is None:
            from sentence_transformers import SentenceTransformer

            self.encoder = SentenceTransformer(self.encoder_name)

    def _encode(self, texts: List[str], show_progress: bool = False) -> np.ndarray:
        """Embed ``texts`` with the shared encoder and return a NumPy array."""
        self._load_encoder()
        return self.encoder.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=show_progress,
            convert_to_numpy=True,
        )

    @staticmethod
    def _intent_text(ctx: QueryContext) -> str:
        """Text for the intent tower: the question plus the schema."""
        return f"QUESTION: {ctx.question} SCHEMA: {ctx.schema}"

    @staticmethod
    def _reference_text(ctx: QueryContext) -> str:
        """Text for the reference tower: the correct query."""
        return f"REFERENCE: {ctx.correct_query}"

    @staticmethod
    def _student_text(ctx: QueryContext) -> str:
        """Text for the student tower; appends the error message when it is truthy."""
        parts = [f"STUDENT: {ctx.student_query}"]
        if ctx.error_message:
            parts.append(f"ERROR: {ctx.error_message}")
        return " ".join(parts)

    @staticmethod
    def _sanitize(X: np.ndarray) -> np.ndarray:
        """Replace NaN/±inf in a scaled feature matrix with finite values.

        StandardScaler keeps NaN inputs as NaN in its output, and
        LogisticRegression rejects NaN. fit() and inference both go through
        this helper so the classifier sees identical replacement values.
        """
        return np.nan_to_num(X, nan=0.0, posinf=1e3, neginf=-1e3)

    def _check_fitted(self) -> None:
        """Raise sklearn's NotFittedError before any (expensive) encoding work."""
        from sklearn.utils.validation import check_is_fitted

        check_is_fitted(self.clf)

    def _build_feature_matrix(
        self,
        contexts: List[QueryContext],
        show_progress: bool = False,
    ) -> np.ndarray:
        """Encode the three towers and fuse them into one feature row per context.

        Column order (must stay identical between fit and inference):
        intent, reference, and student embeddings; |student − reference|;
        student * reference; cos(student, reference); cos(student, intent);
        cos(reference, intent); then the SQL structural features.

        Only the intent-tower encoding shows a progress bar when
        ``show_progress`` is set.
        """
        intent_texts = [self._intent_text(c) for c in contexts]
        ref_texts = [self._reference_text(c) for c in contexts]
        student_texts = [self._student_text(c) for c in contexts]

        intent_emb = self._encode(intent_texts, show_progress)
        ref_emb = self._encode(ref_texts, show_progress=False)
        student_emb = self._encode(student_texts, show_progress=False)

        diff = np.abs(student_emb - ref_emb)
        prod = student_emb * ref_emb
        cos_sr = _cosine(student_emb, ref_emb).reshape(-1, 1)
        cos_si = _cosine(student_emb, intent_emb).reshape(-1, 1)
        cos_ri = _cosine(ref_emb, intent_emb).reshape(-1, 1)

        sql_feats = np.array(
            [
                extract_sql_features(c.student_query, c.correct_query)
                for c in contexts
            ],
            dtype=np.float64,
        )
        return np.hstack(
            [intent_emb, ref_emb, student_emb, diff, prod, cos_sr, cos_si, cos_ri, sql_feats]
        )

    def fit(self, contexts: List[QueryContext], y: np.ndarray) -> "MultiTowerClassifier":
        """Fit the scaler and the logistic-regression head; return ``self``.

        ``y`` holds one label per context; its distinct values become
        ``classes_``.
        """
        X = self._build_feature_matrix(contexts, show_progress=True)
        # Same post-scaling sanitization as _prepare_features, so training and
        # inference inputs agree (and NaN never reaches LogisticRegression).
        X = self._sanitize(self.scaler.fit_transform(X))
        self.clf.fit(X, y)
        self.classes_ = self.clf.classes_
        return self

    def _prepare_features(self, contexts: List[QueryContext]) -> np.ndarray:
        """Build, scale and sanitize features for inference (requires fit)."""
        self._check_fitted()
        X = self.scaler.transform(self._build_feature_matrix(contexts))
        return self._sanitize(X)

    def predict(self, contexts: List[QueryContext]) -> np.ndarray:
        """Return the predicted label per context; an empty batch yields an empty array."""
        if len(contexts) == 0:
            self._check_fitted()
            return np.empty(0, dtype=self.clf.classes_.dtype)
        return self.clf.predict(self._prepare_features(contexts))

    def predict_proba(self, contexts: List[QueryContext]) -> np.ndarray:
        """Return class probabilities, one column per entry of ``classes_``.

        An empty batch yields an array of shape ``(0, n_classes)``.
        """
        if len(contexts) == 0:
            self._check_fitted()
            return np.empty((0, len(self.clf.classes_)), dtype=np.float64)
        return self.clf.predict_proba(self._prepare_features(contexts))

    def explain_similarities(self, ctx: QueryContext) -> dict:
        """Diagnostic cosine scores for the playground UI.

        Returns the same three raw-embedding (unscaled) cosine similarities
        that _build_feature_matrix feeds to the classifier. Does not require
        ``fit``.
        """
        intent_emb = self._encode([self._intent_text(ctx)])
        ref_emb = self._encode([self._reference_text(ctx)])
        student_emb = self._encode([self._student_text(ctx)])
        return {
            "student_vs_reference": float(_cosine(student_emb, ref_emb)[0]),
            "student_vs_intent": float(_cosine(student_emb, intent_emb)[0]),
            "reference_vs_intent": float(_cosine(ref_emb, intent_emb)[0]),
        }
def contexts_from_dataframe(df) -> List[QueryContext]:
    """Build a ``QueryContext`` list from a training dataframe.

    Reads the ``question``, ``schema``, ``correct_query`` and ``query``
    columns (``query`` becomes ``student_query``). The ``error_message``
    column is optional; when it exists, missing cells (``None``, ``NaN``,
    ``pd.NA``) are normalized to ``None``.

    Args:
        df: presumably a pandas DataFrame (it is read via ``df.columns`` and
            ``df.to_dict("records")``).

    Returns:
        One ``QueryContext`` per row, in row order.
    """
    import pandas as pd

    has_error = "error_message" in df.columns

    def _error_message(row: dict) -> Optional[str]:
        # pandas stores missing text as NaN (truthy, so a check like
        # `if ctx.error_message:` passes and "nan" gets rendered) or pd.NA
        # (raises on bool()); neither is a real message.
        if not has_error:
            return None
        msg = row["error_message"]
        return None if pd.isna(msg) else msg

    return [
        QueryContext(
            question=row["question"],
            schema=row["schema"],
            correct_query=row["correct_query"],
            student_query=row["query"],
            error_message=_error_message(row),
        )
        for row in df.to_dict("records")
    ]