Spaces:
Sleeping
Sleeping
| """Cross-encoder architecture for SQL error classification.""" | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple | |
| import numpy as np | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.preprocessing import StandardScaler | |
| from src.multi_tower_model import QueryContext, contexts_from_dataframe | |
| from src.sql_features import extract_sql_features | |
# Pretrained cross-encoder checkpoint used (frozen, predict-only) by
# CrossEncoderClassifier to score text pairs.
DEFAULT_CROSS_ENCODER = "cross-encoder/ms-marco-MiniLM-L6-v2"
# Starting checkpoint for FineTunedCrossEncoderClassifier — currently the same
# model as above.
# NOTE(review): this checkpoint is reloaded with num_labels > 1; confirm the
# installed sentence-transformers/transformers versions accept the resized
# classification head for it.
DEFAULT_FINETUNED_CE = DEFAULT_CROSS_ENCODER

# Names of the three scored pairs, in the order they are built and scored.
PAIR_NAMES = ("intent_vs_student", "reference_vs_student", "intent_vs_reference")
@dataclass
class CrossEncoderPair:
    """One (text_a, text_b) input for the cross-encoder, tagged with a pair name.

    Declared as a dataclass so it can be built positionally,
    e.g. ``CrossEncoderPair("intent_vs_student", intent, student)``.
    """

    name: str  # pair identifier, e.g. "intent_vs_student"
    text_a: str  # first text segment fed to the cross-encoder
    text_b: str  # second text segment fed to the cross-encoder
| def _intent_text(ctx: QueryContext) -> str: | |
| return f"QUESTION: {ctx.question} SCHEMA: {ctx.schema}" | |
| def _reference_text(ctx: QueryContext) -> str: | |
| return f"REFERENCE: {ctx.correct_query}" | |
| def _student_text(ctx: QueryContext) -> str: | |
| parts = [f"STUDENT: {ctx.student_query}"] | |
| if ctx.error_message: | |
| parts.append(f"ERROR: {ctx.error_message}") | |
| return " ".join(parts) | |
| def _context_text(ctx: QueryContext) -> str: | |
| """Full task context for fine-tuned cross-encoder.""" | |
| return ( | |
| f"QUESTION: {ctx.question} " | |
| f"SCHEMA: {ctx.schema} " | |
| f"REFERENCE: {ctx.correct_query}" | |
| ) | |
def build_pairs(ctx: QueryContext) -> List[CrossEncoderPair]:
    """
    Build the three cross-encoder input pairs for one context.

    Pairs are returned in ``PAIR_NAMES`` order:
    (intent, student), (reference, student), (intent, reference).
    Callers consume the pairs by position (score column ``i`` is pair ``i``),
    so names are taken from ``PAIR_NAMES`` instead of being repeated as
    literals that could drift out of sync.
    """
    intent = _intent_text(ctx)
    reference = _reference_text(ctx)
    student = _student_text(ctx)
    segments = (
        (intent, student),
        (reference, student),
        (intent, reference),
    )
    return [
        CrossEncoderPair(name, text_a, text_b)
        for name, (text_a, text_b) in zip(PAIR_NAMES, segments)
    ]
class CrossEncoderClassifier:
    """
    Hybrid cross-encoder: frozen pairwise relevance scores + linear head.

    Unlike bi-encoders (multi-tower), the cross-encoder attends jointly over
    each (context, student) pair — better for logical and filtering errors.

    Three pairs are scored per context (see ``build_pairs``):
      1. intent vs student    — does the query address the question?
      2. reference vs student — how far is the student from the answer?
      3. intent vs reference  — task-answer alignment baseline

    The 3 pair scores, 5 derived score contrasts and the SQL rule features
    from ``extract_sql_features`` are standardised and fed to a
    class-balanced LogisticRegression. The pretrained cross-encoder is only
    used for ``predict`` (it is not trained here). The label set is whatever
    ``y`` contains at ``fit`` time and is exposed as ``classes_``.
    """

    def __init__(
        self,
        cross_encoder_name: str = DEFAULT_CROSS_ENCODER,
        batch_size: int = 32,
        max_length: int = 512,
    ):
        """
        Args:
            cross_encoder_name: model id or path passed to
                ``sentence_transformers.CrossEncoder``.
            batch_size: batch size used when scoring pairs.
            max_length: maximum tokenised pair length for the cross-encoder.
        """
        self.cross_encoder_name = cross_encoder_name
        self.batch_size = batch_size
        self.max_length = max_length
        self.cross_encoder = None  # created lazily by _load_cross_encoder
        self.scaler = StandardScaler()
        self.clf = LogisticRegression(
            max_iter=1000,
            solver="lbfgs",
            class_weight="balanced",
            random_state=42,
        )
        self.classes_: Optional[np.ndarray] = None

    def _load_cross_encoder(self) -> None:
        """Create the cross-encoder on first use (sentence_transformers is imported only here)."""
        if self.cross_encoder is None:
            from sentence_transformers import CrossEncoder

            self.cross_encoder = CrossEncoder(
                self.cross_encoder_name,
                max_length=self.max_length,
            )

    def _pair_batches(self, contexts: List[QueryContext]) -> List[List[Tuple[str, str]]]:
        """Group pairs by type: list ``i`` holds the i-th pair of every context."""
        pair_lists: List[List[Tuple[str, str]]] = [[], [], []]
        for ctx in contexts:
            for i, pair in enumerate(build_pairs(ctx)):
                pair_lists[i].append((pair.text_a, pair.text_b))
        return pair_lists

    def _score_pairs(
        self,
        contexts: List[QueryContext],
        show_progress: bool = False,
    ) -> np.ndarray:
        """Return an ``(n_contexts, 3)`` score array; column ``i`` is pair ``i`` of ``build_pairs``."""
        self._load_cross_encoder()
        columns = []
        # One predict() call per pair type, each covering all contexts.
        for batch in self._pair_batches(contexts):
            raw = self.cross_encoder.predict(
                batch,
                batch_size=self.batch_size,
                show_progress_bar=show_progress,
            )
            columns.append(np.asarray(raw, dtype=np.float64).reshape(-1, 1))
        return np.hstack(columns)

    def _build_features(
        self,
        contexts: List[QueryContext],
        show_progress: bool = False,
    ) -> np.ndarray:
        """Unscaled feature matrix: 3 pair scores, 5 derived contrasts, then SQL rule features."""
        pair_scores = self._score_pairs(contexts, show_progress=show_progress)
        s_is, s_rs, s_ir = pair_scores[:, 0], pair_scores[:, 1], pair_scores[:, 2]
        derived = np.column_stack(
            [
                s_rs - s_is,  # > 0 when the student scores higher against the reference than the intent
                s_is - s_ir,  # student-intent score relative to the intent-reference baseline
                s_rs - s_ir,  # student-reference score relative to the same baseline
                s_is * s_rs,  # interaction of the two student scores
                np.abs(s_rs - s_is),  # magnitude of the student's reference-vs-intent score gap
            ]
        )
        sql_feats = np.array(
            [extract_sql_features(c.student_query, c.correct_query) for c in contexts],
            dtype=np.float64,
        )
        return np.hstack([pair_scores, derived, sql_feats])

    @staticmethod
    def _sanitize(X: np.ndarray) -> np.ndarray:
        """Replace NaN/±inf left after scaling; shared by fit and inference so both treat them identically."""
        return np.nan_to_num(X, nan=0.0, posinf=1e3, neginf=-1e3)

    def _prepare_features(self, contexts: List[QueryContext]) -> np.ndarray:
        """Inference features, transformed with the scaler fitted in ``fit``."""
        return self._sanitize(self.scaler.transform(self._build_features(contexts)))

    def fit(self, contexts: List[QueryContext], y: np.ndarray) -> "CrossEncoderClassifier":
        """Score ``contexts``, fit the scaler and the logistic head on labels ``y``; return ``self``."""
        X = self._build_features(contexts, show_progress=True)
        X = self._sanitize(self.scaler.fit_transform(X))
        self.clf.fit(X, y)
        self.classes_ = self.clf.classes_
        return self

    def predict(self, contexts: List[QueryContext]) -> np.ndarray:
        """Predicted class label for each context."""
        return self.clf.predict(self._prepare_features(contexts))

    def predict_proba(self, contexts: List[QueryContext]) -> np.ndarray:
        """Class probabilities per context; columns follow ``classes_``."""
        return self.clf.predict_proba(self._prepare_features(contexts))

    def explain_pair_scores(self, ctx: QueryContext) -> dict:
        """Raw cross-encoder score of each pair for a single context, keyed by pair name."""
        scores = self._score_pairs([ctx])[0]
        return {name: float(score) for name, score in zip(PAIR_NAMES, scores)}
class FineTunedCrossEncoderClassifier:
    """
    End-to-end fine-tuned cross-encoder classifier.

    A single cross-attention pass over ``[task_context | student_query]``
    feeds a classification head. Training labels may be any values
    ``np.unique`` can sort (ints, strings, ...). ``fit`` encodes them to
    contiguous indices ``0..K-1`` for the head, and ``predict`` decodes
    head indices back to the original labels through ``classes_``.
    """

    # File written next to the saved model so ``load`` can restore ``classes_``.
    _CLASSES_FILENAME = "classes.json"

    def __init__(
        self,
        cross_encoder_name: str = DEFAULT_FINETUNED_CE,
        batch_size: int = 16,
        max_length: int = 512,
        num_labels: int = 15,
    ):
        """
        Args:
            cross_encoder_name: starting checkpoint passed to
                ``sentence_transformers.CrossEncoder``.
            batch_size: batch size for training and inference.
            max_length: maximum tokenised pair length.
            num_labels: head width used when the model is created outside ``fit``.
        """
        self.cross_encoder_name = cross_encoder_name
        self.batch_size = batch_size
        self.max_length = max_length
        self.num_labels = num_labels
        self.model = None  # created lazily by _load_model
        self.classes_: Optional[np.ndarray] = None

    def _load_model(self, num_labels: Optional[int] = None) -> None:
        """Create the cross-encoder once. Falsy ``num_labels`` falls back to ``self.num_labels``."""
        if self.model is None:
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                self.cross_encoder_name,
                num_labels=num_labels or self.num_labels,
                max_length=self.max_length,
            )

    def _to_examples(self, contexts: List[QueryContext], labels: Optional[np.ndarray] = None):
        """Wrap contexts as ``InputExample`` pairs. ``labels`` should be encoded class indices; 0.0 when absent."""
        from sentence_transformers import InputExample

        examples = []
        for i, ctx in enumerate(contexts):
            label = float(labels[i]) if labels is not None else 0.0
            examples.append(
                InputExample(
                    texts=[_context_text(ctx), _student_text(ctx)],
                    label=label,
                )
            )
        return examples

    def fit(
        self,
        contexts: List[QueryContext],
        y: np.ndarray,
        epochs: int = 1,
        warmup_steps: int = 100,
        output_path: Optional[Path] = None,
    ) -> "FineTunedCrossEncoderClassifier":
        """
        Fine-tune the cross-encoder on ``contexts`` labelled by ``y``; return ``self``.

        Raises:
            ValueError: if the number of labels differs from the number of contexts.
        """
        from torch.utils.data import DataLoader

        # ``classes`` is the sorted label set; ``encoded`` gives each sample's
        # position in it, i.e. a contiguous 0..K-1 target for the head.
        classes, encoded = np.unique(np.asarray(y), return_inverse=True)
        encoded = encoded.reshape(-1)
        if encoded.shape[0] != len(contexts):
            raise ValueError(
                f"got {len(contexts)} contexts but {encoded.shape[0]} labels"
            )
        # NOTE: if a model was already created (e.g. by an earlier predict),
        # _load_model keeps it and its existing head width.
        self._load_model(num_labels=len(classes))
        train_examples = self._to_examples(contexts, encoded)
        loader = DataLoader(
            train_examples,
            shuffle=True,
            batch_size=self.batch_size,
        )
        self.model.fit(
            train_dataloader=loader,
            epochs=epochs,
            warmup_steps=min(warmup_steps, max(10, len(train_examples) // 10)),
            show_progress_bar=True,
            output_path=str(output_path) if output_path else None,
        )
        self.classes_ = classes
        return self

    def _logits(self, contexts: List[QueryContext]) -> np.ndarray:
        """Raw model outputs for ``contexts``: 1-D for a single-output head, else (n, n_outputs)."""
        self._load_model()
        pairs = [[_context_text(c), _student_text(c)] for c in contexts]
        logits = self.model.predict(
            pairs,
            batch_size=self.batch_size,
            show_progress_bar=False,
            convert_to_numpy=True,
        )
        return np.asarray(logits, dtype=np.float64)

    def _decode(self, idx: np.ndarray) -> np.ndarray:
        """Map head-output indices back to the labels recorded in ``classes_``."""
        if self.classes_ is None:
            return idx
        classes = np.asarray(self.classes_)
        if idx.size and (idx.min() < 0 or idx.max() >= len(classes)):
            # Head is wider than the recorded label set (e.g. the model was
            # created before fit); return raw indices rather than fail.
            return idx
        return classes[idx]

    def predict(self, contexts: List[QueryContext]) -> np.ndarray:
        """Predicted label per context, expressed in the label space seen by ``fit``/``load``."""
        logits = self._logits(contexts)
        if logits.ndim == 1:
            # Single-output head: no class axis to argmax over; truncate the score.
            idx = logits.astype(int)
        else:
            idx = logits.argmax(axis=1)
        return self._decode(idx)

    def predict_proba(self, contexts: List[QueryContext]) -> np.ndarray:
        """
        Class probabilities per context; columns follow ``classes_``.

        Raises:
            IndexError: in the single-output fallback, if a truncated score is
                not a valid position in ``classes_``.
        """
        logits = self._logits(contexts)
        if logits.ndim == 1:
            # binary fallback: one-hot of the truncated single-output score
            n_classes = len(self.classes_)
            idx = logits.astype(int)
            if idx.size and (idx.min() < 0 or idx.max() >= n_classes):
                raise IndexError("single-output score does not map onto a known class")
            probs = np.zeros((len(logits), n_classes))
            probs[np.arange(len(logits)), idx] = 1.0
            return probs
        # numerically stable softmax over the class axis
        exp = np.exp(logits - logits.max(axis=1, keepdims=True))
        return exp / exp.sum(axis=1, keepdims=True)

    def save(self, path: Path) -> Path:
        """Save the model, plus the fitted label set when known, under ``path``; return ``path``."""
        import json

        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        self._load_model()
        self.model.save(str(path))
        if self.classes_ is not None:
            (path / self._CLASSES_FILENAME).write_text(
                json.dumps(np.asarray(self.classes_).tolist()),
                encoding="utf-8",
            )
        return path

    @classmethod
    def load(cls, path: Path) -> "FineTunedCrossEncoderClassifier":
        """Rebuild a classifier from a directory written by ``save``."""
        import json

        from sentence_transformers import CrossEncoder

        path = Path(path)
        instance = cls()
        instance.model = CrossEncoder(str(path))
        classes_file = path / cls._CLASSES_FILENAME
        if classes_file.exists():
            instance.classes_ = np.asarray(
                json.loads(classes_file.read_text(encoding="utf-8"))
            )
        else:
            # No saved label set (older save): labels are the head's output indices.
            instance.classes_ = np.arange(instance.model.config.num_labels)
        return instance