sql-error-classifier-train / src /cross_encoder_model.py
nishu08's picture
Deploy CodeBERT training Space
9b2cded verified
Raw
History Blame Contribute Delete
10.5 kB
"""Cross-encoder architecture for SQL error classification."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from src.multi_tower_model import QueryContext, contexts_from_dataframe
from src.sql_features import extract_sql_features
# Default checkpoint for CrossEncoderClassifier's frozen pairwise scorer.
DEFAULT_CROSS_ENCODER = "cross-encoder/ms-marco-MiniLM-L6-v2"
# FineTunedCrossEncoderClassifier starts from the very same checkpoint.
DEFAULT_FINETUNED_CE = DEFAULT_CROSS_ENCODER
# Pair types produced by build_pairs, in the column order used for scores.
PAIR_NAMES = ("intent_vs_student", "reference_vs_student", "intent_vs_reference")
@dataclass(frozen=True)
class CrossEncoderPair:
    """Immutable (text_a, text_b) input for a single cross-encoder score."""

    # Pair type; build_pairs uses the entries of PAIR_NAMES.
    name: str
    # First text segment passed to the cross-encoder.
    text_a: str
    # Second text segment passed to the cross-encoder.
    text_b: str
def _intent_text(ctx: QueryContext) -> str:
    """Describe what the task asks for: the question followed by its schema."""
    question = ctx.question
    schema = ctx.schema
    return f"QUESTION: {question} SCHEMA: {schema}"
def _reference_text(ctx: QueryContext) -> str:
    """Tag the correct (reference) SQL query for the cross-encoder."""
    reference_sql = ctx.correct_query
    return f"REFERENCE: {reference_sql}"
def _student_text(ctx: QueryContext) -> str:
    """Render the student's SQL, appending the error message when one is present.

    A falsy ``error_message`` (None or empty string) is omitted entirely.
    """
    text = f"STUDENT: {ctx.student_query}"
    if ctx.error_message:
        text = f"{text} ERROR: {ctx.error_message}"
    return text
def _context_text(ctx: QueryContext) -> str:
    """Full task context (question, schema, reference SQL) for the fine-tuned cross-encoder."""
    segments = (
        f"QUESTION: {ctx.question}",
        f"SCHEMA: {ctx.schema}",
        f"REFERENCE: {ctx.correct_query}",
    )
    return " ".join(segments)
def build_pairs(ctx: QueryContext) -> List[CrossEncoderPair]:
    """Build the three text pairs scored by the cross-encoder for one context.

    Pair names come from ``PAIR_NAMES`` so that the order here always matches
    the score columns consumed elsewhere (e.g. ``explain_pair_scores``):

    1. intent vs student
    2. reference vs student
    3. intent vs reference
    """
    intent = _intent_text(ctx)
    reference = _reference_text(ctx)
    student = _student_text(ctx)
    # Same order as PAIR_NAMES; keep the two in sync.
    texts = (
        (intent, student),
        (reference, student),
        (intent, reference),
    )
    return [
        CrossEncoderPair(name, text_a, text_b)
        for name, (text_a, text_b) in zip(PAIR_NAMES, texts)
    ]
class CrossEncoderClassifier:
    """
    Hybrid cross-encoder: frozen pairwise relevance + linear head.

    Unlike bi-encoders (multi-tower), the cross-encoder attends jointly over
    each (context, student) pair — better for logical and filtering errors.

    Three pairs are scored (column order follows ``PAIR_NAMES``):
      1. intent vs student     — does the query address the question?
      2. reference vs student  — how far is the student from the answer?
      3. intent vs reference   — task-answer alignment baseline

    Pair scores + derived score contrasts + SQL rule features are
    standardised and fed to a class-balanced LogisticRegression. The output
    classes are whatever labels appear in ``y`` at ``fit`` time. The
    cross-encoder itself is only used for inference here; it is never trained.
    """

    def __init__(
        self,
        cross_encoder_name: str = DEFAULT_CROSS_ENCODER,
        batch_size: int = 32,
        max_length: int = 512,
    ):
        self.cross_encoder_name = cross_encoder_name
        self.batch_size = batch_size
        self.max_length = max_length
        self.cross_encoder = None  # loaded lazily by _load_cross_encoder
        self.scaler = StandardScaler()
        self.clf = LogisticRegression(
            max_iter=1000,
            solver="lbfgs",
            class_weight="balanced",
            random_state=42,
        )
        self.classes_: Optional[np.ndarray] = None

    def _load_cross_encoder(self) -> None:
        """Instantiate the sentence-transformers CrossEncoder on first use."""
        if self.cross_encoder is None:
            from sentence_transformers import CrossEncoder

            self.cross_encoder = CrossEncoder(
                self.cross_encoder_name,
                max_length=self.max_length,
            )

    def _pair_batches(self, contexts: List[QueryContext]) -> List[List[Tuple[str, str]]]:
        """Group pair texts by pair type across all contexts.

        Returns one list of ``(text_a, text_b)`` tuples per entry of
        ``PAIR_NAMES``, each list aligned with ``contexts``.
        """
        pair_lists: List[List[Tuple[str, str]]] = [[] for _ in PAIR_NAMES]
        for ctx in contexts:
            for bucket, pair in zip(pair_lists, build_pairs(ctx)):
                bucket.append((pair.text_a, pair.text_b))
        return pair_lists

    def _score_pairs(
        self,
        contexts: List[QueryContext],
        show_progress: bool = False,
    ) -> np.ndarray:
        """Score every pair type with the frozen cross-encoder.

        Returns:
            A float64 matrix of shape ``(len(contexts), len(PAIR_NAMES))``;
            column ``j`` holds the scores for pair type ``PAIR_NAMES[j]``.

        Raises:
            ValueError: if the cross-encoder returns more than one score per
                pair (a multi-label head cannot be used as a relevance scorer).
        """
        self._load_cross_encoder()
        columns = []
        for name, batch in zip(PAIR_NAMES, self._pair_batches(contexts)):
            raw = np.asarray(
                self.cross_encoder.predict(
                    batch,
                    batch_size=self.batch_size,
                    show_progress_bar=show_progress,
                ),
                dtype=np.float64,
            )
            # A (n, k>1) output would otherwise be flattened into n*k rows and
            # fail much later with a confusing shape mismatch.
            if raw.ndim > 1 and raw.shape[-1] != 1:
                raise ValueError(
                    f"cross-encoder {self.cross_encoder_name!r} returned "
                    f"{raw.shape[-1]} scores per pair for {name!r}; "
                    "expected a single relevance score"
                )
            columns.append(raw.reshape(-1, 1))
        return np.hstack(columns)  # (n, 3)

    def _build_features(
        self,
        contexts: List[QueryContext],
        show_progress: bool = False,
    ) -> np.ndarray:
        """Assemble the raw (unscaled) feature matrix.

        Columns: the 3 pair scores, 5 derived score contrasts, then the vector
        returned by ``extract_sql_features(student_query, correct_query)``.
        """
        pair_scores = self._score_pairs(contexts, show_progress=show_progress)
        s_is, s_rs, s_ir = pair_scores[:, 0], pair_scores[:, 1], pair_scores[:, 2]
        derived = np.column_stack(
            [
                s_rs - s_is,  # does the student match the reference more than the intent?
                s_is - s_ir,  # student-intent score relative to the reference-intent baseline
                s_rs - s_ir,  # student-reference score relative to the same baseline
                s_is * s_rs,  # interaction of the two student scores
                np.abs(s_rs - s_is),  # size of the gap between the two student scores
            ]
        )
        sql_feats = np.array(
            [extract_sql_features(c.student_query, c.correct_query) for c in contexts],
            dtype=np.float64,
        )
        return np.hstack([pair_scores, derived, sql_feats])

    @staticmethod
    def _sanitize(X: np.ndarray) -> np.ndarray:
        """Replace non-finite values so LogisticRegression never sees NaN/inf."""
        return np.nan_to_num(X, nan=0.0, posinf=1e3, neginf=-1e3)

    def _prepare_features(self, contexts: List[QueryContext]) -> np.ndarray:
        """Build, scale (with the scaler fitted in ``fit``) and sanitise features."""
        return self._sanitize(self.scaler.transform(self._build_features(contexts)))

    def fit(self, contexts: List[QueryContext], y: np.ndarray) -> "CrossEncoderClassifier":
        """Fit the scaler and the logistic-regression head; the cross-encoder stays frozen."""
        X = self._build_features(contexts, show_progress=True)
        X = self._sanitize(self.scaler.fit_transform(X))
        self.clf.fit(X, y)
        self.classes_ = self.clf.classes_
        return self

    def predict(self, contexts: List[QueryContext]) -> np.ndarray:
        """Predict one label per context.

        Empty input returns an empty array once the model is fitted (instead
        of calling the cross-encoder with no pairs).
        """
        if len(contexts) == 0 and self.classes_ is not None:
            return np.empty(0, dtype=self.classes_.dtype)
        return self.clf.predict(self._prepare_features(contexts))

    def predict_proba(self, contexts: List[QueryContext]) -> np.ndarray:
        """Class probabilities, one column per entry of ``classes_``.

        Empty input returns a ``(0, len(classes_))`` array once fitted.
        """
        if len(contexts) == 0 and self.classes_ is not None:
            return np.empty((0, len(self.classes_)), dtype=np.float64)
        return self.clf.predict_proba(self._prepare_features(contexts))

    def explain_pair_scores(self, ctx: QueryContext) -> dict:
        """Return the raw cross-encoder score of each pair type for one context."""
        scores = self._score_pairs([ctx])[0]
        return {name: float(score) for name, score in zip(PAIR_NAMES, scores)}
class FineTunedCrossEncoderClassifier:
    """
    End-to-end fine-tuned cross-encoder (highest accuracy).

    Single cross-attention pass over [task_context | student_query] with a
    classification head of ``num_labels`` outputs. Slower to train; best on
    smaller high-quality datasets.

    Labels given to ``fit`` may be any sortable values (ints, strings,
    non-contiguous ids). They are encoded to ``0..k-1`` for training and
    decoded back through ``classes_`` by ``predict``. ``save``/``load``
    persist ``classes_`` next to the model weights when the labels are
    JSON-serialisable.
    """

    # File inside the save directory holding the JSON list of class labels.
    _CLASSES_FILE = "label_classes.json"

    def __init__(
        self,
        cross_encoder_name: str = DEFAULT_FINETUNED_CE,
        batch_size: int = 16,
        max_length: int = 512,
        num_labels: int = 15,
    ):
        self.cross_encoder_name = cross_encoder_name
        self.batch_size = batch_size
        self.max_length = max_length
        self.num_labels = num_labels
        self.model = None  # loaded lazily by _load_model / fit / load
        self.classes_: Optional[np.ndarray] = None

    def _load_model(self, num_labels: Optional[int] = None) -> None:
        """Instantiate the CrossEncoder once; an already-loaded model is reused as-is."""
        if self.model is None:
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                self.cross_encoder_name,
                num_labels=num_labels or self.num_labels,
                max_length=self.max_length,
            )

    def _to_examples(self, contexts: List[QueryContext], labels: Optional[np.ndarray] = None):
        """Wrap contexts as sentence-transformers InputExamples.

        ``labels`` are encoded class indices aligned positionally with
        ``contexts``. When None, every example gets a dummy label of 0.0.
        """
        from sentence_transformers import InputExample

        # Positional access: a pandas Series would otherwise be indexed by label.
        label_arr = None if labels is None else np.asarray(labels)
        return [
            InputExample(
                texts=[_context_text(ctx), _student_text(ctx)],
                label=float(label_arr[i]) if label_arr is not None else 0.0,
            )
            for i, ctx in enumerate(contexts)
        ]

    def fit(
        self,
        contexts: List[QueryContext],
        y: np.ndarray,
        epochs: int = 1,
        warmup_steps: int = 100,
        output_path: Optional[Path] = None,
    ) -> "FineTunedCrossEncoderClassifier":
        """Fine-tune the cross-encoder on ``(contexts, y)``.

        ``y`` is mapped to indices ``0..k-1`` (sorted unique labels) before
        training, and ``classes_`` stores the original labels for decoding.
        If a model is already loaded, it is reused without resizing its head.

        Raises:
            ValueError: if ``contexts`` and ``y`` differ in length.
        """
        from torch.utils.data import DataLoader

        classes, encoded = np.unique(np.asarray(y), return_inverse=True)
        encoded = encoded.reshape(-1)
        if len(encoded) != len(contexts):
            raise ValueError(
                f"got {len(contexts)} contexts but {len(encoded)} labels"
            )
        self._load_model(num_labels=len(classes))
        train_examples = self._to_examples(contexts, encoded)
        loader = DataLoader(
            train_examples,
            shuffle=True,
            batch_size=self.batch_size,
        )
        self.model.fit(
            train_dataloader=loader,
            epochs=epochs,
            # Cap warmup at ~10% of the examples (but at least 10 steps).
            warmup_steps=min(warmup_steps, max(10, len(train_examples) // 10)),
            show_progress_bar=True,
            output_path=str(output_path) if output_path else None,
        )
        self.classes_ = classes
        return self

    @staticmethod
    def _as_2d(raw) -> np.ndarray:
        """Coerce model output to a float64 ``(n, k)`` matrix (1-D → single column)."""
        logits = np.asarray(raw, dtype=np.float64)
        return logits.reshape(-1, 1) if logits.ndim == 1 else logits

    @staticmethod
    def _softmax(logits: np.ndarray) -> np.ndarray:
        """Row-wise, numerically stable softmax."""
        shifted = np.exp(logits - logits.max(axis=1, keepdims=True))
        return shifted / shifted.sum(axis=1, keepdims=True)

    def _logits(self, contexts: List[QueryContext]) -> np.ndarray:
        """Run the model over (context, student) pairs; returns a 2-D float64 array."""
        self._load_model()
        pairs = [[_context_text(c), _student_text(c)] for c in contexts]
        raw = self.model.predict(
            pairs,
            batch_size=self.batch_size,
            show_progress_bar=False,
            convert_to_numpy=True,
        )
        return self._as_2d(raw)

    def _decode(self, logits: np.ndarray) -> np.ndarray:
        """Map argmax column indices to labels in ``classes_``.

        Raw indices are returned when ``classes_`` is unset or does not
        match the number of output columns.
        """
        idx = logits.argmax(axis=1)
        if self.classes_ is not None and len(self.classes_) == logits.shape[1]:
            return self.classes_[idx]
        return idx

    def _n_outputs(self) -> int:
        """Number of probability columns expected from ``predict_proba``."""
        return len(self.classes_) if self.classes_ is not None else self.num_labels

    def predict(self, contexts: List[QueryContext]) -> np.ndarray:
        """Predict one label (from ``classes_``) per context."""
        if len(contexts) == 0:
            dtype = self.classes_.dtype if self.classes_ is not None else np.int64
            return np.empty(0, dtype=dtype)
        return self._decode(self._logits(contexts))

    def predict_proba(self, contexts: List[QueryContext]) -> np.ndarray:
        """Softmax probabilities; column ``j`` corresponds to ``classes_[j]``."""
        if len(contexts) == 0:
            return np.empty((0, self._n_outputs()), dtype=np.float64)
        return self._softmax(self._logits(contexts))

    def save(self, path: Path) -> Path:
        """Save model weights (and ``classes_`` as JSON when possible) under ``path``."""
        import json

        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        self._load_model()
        self.model.save(str(path))
        if self.classes_ is not None:
            try:
                payload = json.dumps(np.asarray(self.classes_).tolist())
            except TypeError:
                # Labels JSON cannot represent are not persisted; load() then
                # falls back to integer classes 0..k-1.
                payload = None
            if payload is not None:
                (path / self._CLASSES_FILE).write_text(payload, encoding="utf-8")
        return path

    @classmethod
    def load(cls, path: Path) -> "FineTunedCrossEncoderClassifier":
        """Load a model saved by ``save``, restoring ``classes_`` if it was stored."""
        import json

        from sentence_transformers import CrossEncoder

        path = Path(path)
        instance = cls()
        instance.model = CrossEncoder(str(path))
        n_labels = instance.model.num_labels
        instance.num_labels = n_labels
        instance.classes_ = np.arange(n_labels)
        classes_file = path / cls._CLASSES_FILE
        if classes_file.is_file():
            stored = np.asarray(json.loads(classes_file.read_text(encoding="utf-8")))
            # Only trust a label file that matches the head size.
            if stored.ndim == 1 and len(stored) == n_labels:
                instance.classes_ = stored
        return instance