"""Train the SQL error classifier.""" from __future__ import annotations import argparse import json from pathlib import Path import pandas as pd from sklearn.metrics import classification_report from sklearn.model_selection import train_test_split from src.categories import id_to_name, load_categories from src.cross_encoder_model import ( CrossEncoderClassifier, FineTunedCrossEncoderClassifier, ) from src.model import ( DEFAULT_MODEL_PATH, ModelType, build_classifier, combine_features, save_model, ) from src.multi_tower_model import MultiTowerClassifier, contexts_from_dataframe PROJECT_ROOT = Path(__file__).resolve().parent.parent DEFAULT_DATA = PROJECT_ROOT / "data" / "sql_errors_1m.parquet" DEFAULT_METRICS = PROJECT_ROOT / "models" / "metrics.json" CONTEXT_MODELS = ( CrossEncoderClassifier, FineTunedCrossEncoderClassifier, MultiTowerClassifier, ) def _split_dataframe( df: pd.DataFrame, test_size: float, val_size: float, seed: int ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: trainval, test = train_test_split( df, test_size=test_size, random_state=seed, stratify=df["label_id"] ) relative_val = val_size / (1 - test_size) train, val = train_test_split( trainval, test_size=relative_val, random_state=seed, stratify=trainval["label_id"], ) return train, val, test def train( data_path: Path = DEFAULT_DATA, model_path: Path = DEFAULT_MODEL_PATH, metrics_path: Path = DEFAULT_METRICS, test_size: float = 0.1, val_size: float = 0.1, use_error_message: bool = True, max_train_samples: int | None = None, model_type: ModelType = "cross_encoder", epochs: int = 1, seed: int = 42, ) -> dict: print(f"Loading data from {data_path}...") df = pd.read_parquet(data_path) if max_train_samples and len(df) > max_train_samples: df = df.sample(n=max_train_samples, random_state=seed) if not use_error_message and "error_message" in df.columns: df = df.drop(columns=["error_message"]) train_df, val_df, test_df = _split_dataframe(df, test_size, val_size, seed) print( f"Train: {len(train_df):,} | Val: {len(val_df):,} | Test: {len(test_df):,}" ) model = build_classifier(model_type=model_type) print(f"Training {model_type} classifier...") if isinstance(model, CONTEXT_MODELS): train_ctx = contexts_from_dataframe(train_df) val_ctx = contexts_from_dataframe(val_df) test_ctx = contexts_from_dataframe(test_df) if isinstance(model, FineTunedCrossEncoderClassifier): model.fit( train_ctx, train_df["label_id"].values, epochs=epochs, output_path=model_path.with_suffix(".ce") if model_path.suffix == ".joblib" else model_path, ) else: model.fit(train_ctx, train_df["label_id"].values) val_preds = model.predict(val_ctx) test_preds = model.predict(test_ctx) y_val = val_df["label_id"].values y_test = test_df["label_id"].values else: def to_texts(frame: pd.DataFrame) -> list[str]: return combine_features( queries=frame["query"].tolist(), error_messages=frame["error_message"].tolist() if "error_message" in frame.columns else None, schemas=frame["schema"].tolist() if "schema" in frame.columns else None, questions=frame["question"].tolist() if "question" in frame.columns else None, ) model.fit(to_texts(train_df), train_df["label_id"].values) val_preds = model.predict(to_texts(val_df)) test_preds = model.predict(to_texts(test_df)) y_val = val_df["label_id"].values y_test = test_df["label_id"].values val_report = classification_report( y_val, val_preds, output_dict=True, zero_division=0 ) print(f"Validation accuracy: {val_report['accuracy']:.4f}") test_report = classification_report( y_test, test_preds, output_dict=True, zero_division=0 ) print(f"Test accuracy: {test_report['accuracy']:.4f}") save_model(model, model_path, model_type=model_type) print(f"Model saved to {model_path}") categories = load_categories() label_map = id_to_name(categories) metrics = { "train_size": len(train_df), "val_size": len(val_df), "test_size": len(test_df), "model_type": model_type, "epochs": epochs if model_type == "cross_encoder_ft" else None, "use_error_message": use_error_message, "validation": val_report, "test": test_report, "label_map": {str(k): v for k, v in label_map.items()}, } metrics_path.parent.mkdir(parents=True, exist_ok=True) with open(metrics_path, "w") as f: json.dump(metrics, f, indent=2) print(f"Metrics saved to {metrics_path}") return metrics def main() -> None: parser = argparse.ArgumentParser(description="Train SQL error classifier") parser.add_argument("--data", type=Path, default=DEFAULT_DATA) parser.add_argument("--model", type=Path, default=DEFAULT_MODEL_PATH) parser.add_argument("--metrics", type=Path, default=DEFAULT_METRICS) parser.add_argument("--test-size", type=float, default=0.1) parser.add_argument("--val-size", type=float, default=0.1) parser.add_argument("--no-error-message", action="store_true") parser.add_argument("--max-samples", type=int, default=None) parser.add_argument( "--model-type", choices=["cross_encoder", "cross_encoder_ft", "multi_tower", "minilm", "tfidf"], default="cross_encoder", help="cross_encoder (recommended): joint attention pairs; " "cross_encoder_ft: fine-tuned end-to-end (best accuracy)", ) parser.add_argument( "--epochs", type=int, default=1, help="Epochs for cross_encoder_ft fine-tuning", ) parser.add_argument("--seed", type=int, default=42) args = parser.parse_args() train( data_path=args.data, model_path=args.model, metrics_path=args.metrics, test_size=args.test_size, val_size=args.val_size, use_error_message=not args.no_error_message, max_train_samples=args.max_samples, model_type=args.model_type, epochs=args.epochs, seed=args.seed, ) if __name__ == "__main__": main()