"""Cross-encoder architecture for SQL error classification.""" from __future__ import annotations from dataclasses import dataclass from pathlib import Path from typing import List, Optional, Tuple import numpy as np from sklearn.linear_model import LogisticRegression from sklearn.preprocessing import StandardScaler from src.multi_tower_model import QueryContext, contexts_from_dataframe from src.sql_features import extract_sql_features DEFAULT_CROSS_ENCODER = "cross-encoder/ms-marco-MiniLM-L6-v2" DEFAULT_FINETUNED_CE = "cross-encoder/ms-marco-MiniLM-L6-v2" PAIR_NAMES = ( "intent_vs_student", "reference_vs_student", "intent_vs_reference", ) @dataclass(frozen=True) class CrossEncoderPair: name: str text_a: str text_b: str def _intent_text(ctx: QueryContext) -> str: return f"QUESTION: {ctx.question} SCHEMA: {ctx.schema}" def _reference_text(ctx: QueryContext) -> str: return f"REFERENCE: {ctx.correct_query}" def _student_text(ctx: QueryContext) -> str: parts = [f"STUDENT: {ctx.student_query}"] if ctx.error_message: parts.append(f"ERROR: {ctx.error_message}") return " ".join(parts) def _context_text(ctx: QueryContext) -> str: """Full task context for fine-tuned cross-encoder.""" return ( f"QUESTION: {ctx.question} " f"SCHEMA: {ctx.schema} " f"REFERENCE: {ctx.correct_query}" ) def build_pairs(ctx: QueryContext) -> List[CrossEncoderPair]: intent, reference, student = ( _intent_text(ctx), _reference_text(ctx), _student_text(ctx), ) return [ CrossEncoderPair("intent_vs_student", intent, student), CrossEncoderPair("reference_vs_student", reference, student), CrossEncoderPair("intent_vs_reference", intent, reference), ] class CrossEncoderClassifier: """ Hybrid cross-encoder: frozen pairwise relevance + linear head. Unlike bi-encoders (multi-tower), the cross-encoder attends jointly over each (context, student) pair — better for logical and filtering errors. Three pairs are scored: 1. intent vs student — does the query address the question? 2. reference vs student — how far is the student from the answer? 3. intent vs reference — task-answer alignment baseline Pair scores + SQL rule features → LogisticRegression → 15 classes. """ def __init__( self, cross_encoder_name: str = DEFAULT_CROSS_ENCODER, batch_size: int = 32, max_length: int = 512, ): self.cross_encoder_name = cross_encoder_name self.batch_size = batch_size self.max_length = max_length self.cross_encoder = None self.scaler = StandardScaler() self.clf = LogisticRegression( max_iter=1000, solver="lbfgs", class_weight="balanced", random_state=42, ) self.classes_: Optional[np.ndarray] = None def _load_cross_encoder(self): if self.cross_encoder is None: from sentence_transformers import CrossEncoder self.cross_encoder = CrossEncoder( self.cross_encoder_name, max_length=self.max_length, ) def _pair_batches(self, contexts: List[QueryContext]) -> List[List[Tuple[str, str]]]: """One batch list per pair type across all contexts.""" pair_lists: List[List[Tuple[str, str]]] = [[], [], []] for ctx in contexts: pairs = build_pairs(ctx) for i, pair in enumerate(pairs): pair_lists[i].append((pair.text_a, pair.text_b)) return pair_lists def _score_pairs( self, contexts: List[QueryContext], show_progress: bool = False, ) -> np.ndarray: self._load_cross_encoder() pair_batches = self._pair_batches(contexts) scores = [] for batch in pair_batches: raw = self.cross_encoder.predict( batch, batch_size=self.batch_size, show_progress_bar=show_progress, ) scores.append(np.asarray(raw, dtype=np.float64).reshape(-1, 1)) return np.hstack(scores) # (n, 3) def _build_features( self, contexts: List[QueryContext], show_progress: bool = False, ) -> np.ndarray: pair_scores = self._score_pairs(contexts, show_progress=show_progress) s_is, s_rs, s_ir = pair_scores[:, 0], pair_scores[:, 1], pair_scores[:, 2] derived = np.column_stack( [ s_rs - s_is, # reference closer than intent? s_is - s_ir, # student-intent gap vs baseline s_rs - s_ir, # student-reference gap vs baseline s_is * s_rs, # interaction np.abs(s_rs - s_is), # intent-reference disagreement ] ) sql_feats = np.array( [extract_sql_features(c.student_query, c.correct_query) for c in contexts], dtype=np.float64, ) return np.hstack([pair_scores, derived, sql_feats]) def _prepare_features(self, contexts: List[QueryContext]) -> np.ndarray: X = self.scaler.transform(self._build_features(contexts)) return np.nan_to_num(X, nan=0.0, posinf=1e3, neginf=-1e3) def fit(self, contexts: List[QueryContext], y: np.ndarray) -> "CrossEncoderClassifier": X = self._build_features(contexts, show_progress=True) X = self.scaler.fit_transform(X) X = np.nan_to_num(X, nan=0.0, posinf=1e3, neginf=-1e3) self.clf.fit(X, y) self.classes_ = self.clf.classes_ return self def predict(self, contexts: List[QueryContext]) -> np.ndarray: return self.clf.predict(self._prepare_features(contexts)) def predict_proba(self, contexts: List[QueryContext]) -> np.ndarray: return self.clf.predict_proba(self._prepare_features(contexts)) def explain_pair_scores(self, ctx: QueryContext) -> dict: scores = self._score_pairs([ctx])[0] return { PAIR_NAMES[0]: float(scores[0]), PAIR_NAMES[1]: float(scores[1]), PAIR_NAMES[2]: float(scores[2]), } class FineTunedCrossEncoderClassifier: """ End-to-end fine-tuned cross-encoder (highest accuracy). Single cross-attention pass over [task_context | student_query] with num_labels=15. Slower to train; best on smaller high-quality datasets. """ def __init__( self, cross_encoder_name: str = DEFAULT_FINETUNED_CE, batch_size: int = 16, max_length: int = 512, num_labels: int = 15, ): self.cross_encoder_name = cross_encoder_name self.batch_size = batch_size self.max_length = max_length self.num_labels = num_labels self.model = None self.classes_: Optional[np.ndarray] = None def _load_model(self, num_labels: Optional[int] = None): if self.model is None: from sentence_transformers import CrossEncoder self.model = CrossEncoder( self.cross_encoder_name, num_labels=num_labels or self.num_labels, max_length=self.max_length, ) def _to_examples(self, contexts: List[QueryContext], labels: Optional[np.ndarray] = None): from sentence_transformers import InputExample examples = [] for i, ctx in enumerate(contexts): label = float(labels[i]) if labels is not None else 0.0 examples.append( InputExample( texts=[_context_text(ctx), _student_text(ctx)], label=label, ) ) return examples def fit( self, contexts: List[QueryContext], y: np.ndarray, epochs: int = 1, warmup_steps: int = 100, output_path: Optional[Path] = None, ) -> "FineTunedCrossEncoderClassifier": from torch.utils.data import DataLoader self._load_model(num_labels=len(np.unique(y))) train_examples = self._to_examples(contexts, y) loader = DataLoader( train_examples, shuffle=True, batch_size=self.batch_size, ) self.model.fit( train_dataloader=loader, epochs=epochs, warmup_steps=min(warmup_steps, max(10, len(train_examples) // 10)), show_progress_bar=True, output_path=str(output_path) if output_path else None, ) self.classes_ = np.sort(np.unique(y)) return self def predict(self, contexts: List[QueryContext]) -> np.ndarray: self._load_model() pairs = [[_context_text(c), _student_text(c)] for c in contexts] logits = self.model.predict( pairs, batch_size=self.batch_size, show_progress_bar=False, convert_to_numpy=True, ) logits = np.asarray(logits) if logits.ndim == 1: return logits.astype(int) return logits.argmax(axis=1) def predict_proba(self, contexts: List[QueryContext]) -> np.ndarray: self._load_model() pairs = [[_context_text(c), _student_text(c)] for c in contexts] logits = self.model.predict( pairs, batch_size=self.batch_size, show_progress_bar=False, convert_to_numpy=True, ) logits = np.asarray(logits, dtype=np.float64) if logits.ndim == 1: # binary fallback probs = np.zeros((len(contexts), len(self.classes_))) for i, pred in enumerate(logits.astype(int)): idx = np.where(self.classes_ == pred)[0][0] probs[i, idx] = 1.0 return probs # softmax exp = np.exp(logits - logits.max(axis=1, keepdims=True)) return exp / exp.sum(axis=1, keepdims=True) def save(self, path: Path) -> Path: path.mkdir(parents=True, exist_ok=True) self._load_model() self.model.save(str(path)) return path @classmethod def load(cls, path: Path) -> "FineTunedCrossEncoderClassifier": from sentence_transformers import CrossEncoder instance = cls() instance.model = CrossEncoder(str(path)) instance.classes_ = np.arange(instance.model.num_labels) return instance