Spaces:
Sleeping
Sleeping
| """Train the SQL error classifier.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| import pandas as pd | |
| from sklearn.metrics import classification_report | |
| from sklearn.model_selection import train_test_split | |
| from src.categories import id_to_name, load_categories | |
| from src.cross_encoder_model import ( | |
| CrossEncoderClassifier, | |
| FineTunedCrossEncoderClassifier, | |
| ) | |
| from src.model import ( | |
| DEFAULT_MODEL_PATH, | |
| ModelType, | |
| build_classifier, | |
| combine_features, | |
| save_model, | |
| ) | |
| from src.multi_tower_model import MultiTowerClassifier, contexts_from_dataframe | |
# Directory one level above the one containing this file; used as the base
# for the default data and metrics paths below.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default parquet dataset read by ``train``.
DEFAULT_DATA = PROJECT_ROOT / "data" / "sql_errors_1m.parquet"
# Default location of the JSON metrics report written by ``train``.
DEFAULT_METRICS = PROJECT_ROOT / "models" / "metrics.json"
# Classifier classes that ``train`` feeds with the output of
# ``contexts_from_dataframe`` rather than with combined text features.
CONTEXT_MODELS = (
    CrossEncoderClassifier,
    FineTunedCrossEncoderClassifier,
    MultiTowerClassifier,
)
def _split_dataframe(
    df: pd.DataFrame, test_size: float, val_size: float, seed: int
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Split ``df`` into stratified train, validation and test frames.

    Both splits are stratified on the ``label_id`` column and use ``seed`` as
    their random state. ``test_size`` and ``val_size`` are fractions of the
    whole frame; the validation fraction is rescaled before the second split
    because that split only sees the rows left after the test set was removed.

    Args:
        df: Frame to split; must contain a ``label_id`` column.
        test_size: Fraction of ``df`` held out as the test set.
        val_size: Fraction of ``df`` held out as the validation set.
        seed: Random state passed to both ``train_test_split`` calls.

    Returns:
        The ``(train, val, test)`` frames.

    Raises:
        ValueError: If ``test_size`` is not strictly between 0 and 1, or if
            ``val_size / (1 - test_size)`` is not strictly between 0 and 1
            (``val_size`` is not positive, or the two fractions together
            leave nothing for training).
    """
    # Validate before touching the data: the first split can be expensive and
    # the errors raised later would refer to derived values, not the inputs.
    if not 0 < test_size < 1:
        raise ValueError(
            f"test_size must be a fraction strictly between 0 and 1, got {test_size!r}"
        )
    # Fraction of the remaining (non-test) rows that becomes the validation set.
    relative_val = val_size / (1 - test_size)
    if not 0 < relative_val < 1:
        raise ValueError(
            "val_size must be positive and test_size + val_size must be below 1, "
            f"got test_size={test_size!r}, val_size={val_size!r}"
        )

    trainval, test = train_test_split(
        df, test_size=test_size, random_state=seed, stratify=df["label_id"]
    )
    train, val = train_test_split(
        trainval,
        test_size=relative_val,
        random_state=seed,
        stratify=trainval["label_id"],
    )
    return train, val, test
def train(
    data_path: Path = DEFAULT_DATA,
    model_path: Path = DEFAULT_MODEL_PATH,
    metrics_path: Path = DEFAULT_METRICS,
    test_size: float = 0.1,
    val_size: float = 0.1,
    use_error_message: bool = True,
    max_train_samples: int | None = None,
    model_type: ModelType = "cross_encoder",
    epochs: int = 1,
    seed: int = 42,
) -> dict:
    """Train a classifier, evaluate it, persist it and return the metrics.

    Steps: read the parquet file, optionally down-sample it and drop the
    ``error_message`` column, split it into train/val/test, fit the model
    built by ``build_classifier``, score it on the validation and test
    splits, save the model and write the metrics as JSON.

    Args:
        data_path: Parquet file to load.
        model_path: Destination handed to ``save_model``.
        metrics_path: JSON file the metrics are written to; its parent
            directory is created if missing.
        test_size: Test fraction forwarded to ``_split_dataframe``.
        val_size: Validation fraction forwarded to ``_split_dataframe``.
        use_error_message: When false, the ``error_message`` column (if
            present) is dropped before splitting.
        max_train_samples: If truthy and smaller than the row count, the
            frame is randomly sampled down to this many rows before the split.
        model_type: Identifier passed to ``build_classifier`` and
            ``save_model``.
        epochs: Passed to ``fit`` only for ``FineTunedCrossEncoderClassifier``
            models; recorded in the metrics only when ``model_type`` is
            ``"cross_encoder_ft"``.
        seed: Random state for sampling and splitting.

    Returns:
        The metrics dictionary that was written to ``metrics_path``.
    """
    print(f"Loading data from {data_path}...")
    frame = pd.read_parquet(data_path)

    if max_train_samples and len(frame) > max_train_samples:
        frame = frame.sample(n=max_train_samples, random_state=seed)
    if not use_error_message and "error_message" in frame.columns:
        frame = frame.drop(columns=["error_message"])

    train_df, val_df, test_df = _split_dataframe(frame, test_size, val_size, seed)
    print(
        f"Train: {len(train_df):,} | Val: {len(val_df):,} | Test: {len(test_df):,}"
    )

    model = build_classifier(model_type=model_type)
    print(f"Training {model_type} classifier...")

    # Label arrays are the same regardless of how the inputs are encoded.
    y_train = train_df["label_id"].values
    y_val = val_df["label_id"].values
    y_test = test_df["label_id"].values

    if isinstance(model, CONTEXT_MODELS):
        # Context-based models: all three splits are converted up front.
        train_inputs = contexts_from_dataframe(train_df)
        val_inputs = contexts_from_dataframe(val_df)
        test_inputs = contexts_from_dataframe(test_df)

        fit_options = {}
        if isinstance(model, FineTunedCrossEncoderClassifier):
            # The fine-tuned model gets an output path: a ``.joblib`` suffix
            # is swapped for ``.ce``, any other path is used unchanged.
            if model_path.suffix == ".joblib":
                checkpoint_path = model_path.with_suffix(".ce")
            else:
                checkpoint_path = model_path
            fit_options = {"epochs": epochs, "output_path": checkpoint_path}
        model.fit(train_inputs, y_train, **fit_options)

        val_preds = model.predict(val_inputs)
        test_preds = model.predict(test_inputs)
    else:
        def as_texts(part: pd.DataFrame) -> list[str]:
            """Build combined text features from whichever columns exist."""
            available = part.columns
            queries = part["query"].tolist()
            errors = (
                part["error_message"].tolist()
                if "error_message" in available
                else None
            )
            schemas = part["schema"].tolist() if "schema" in available else None
            questions = (
                part["question"].tolist() if "question" in available else None
            )
            return combine_features(
                queries=queries,
                error_messages=errors,
                schemas=schemas,
                questions=questions,
            )

        model.fit(as_texts(train_df), y_train)
        val_preds = model.predict(as_texts(val_df))
        test_preds = model.predict(as_texts(test_df))

    val_report = classification_report(
        y_val, val_preds, output_dict=True, zero_division=0
    )
    print(f"Validation accuracy: {val_report['accuracy']:.4f}")
    test_report = classification_report(
        y_test, test_preds, output_dict=True, zero_division=0
    )
    print(f"Test accuracy: {test_report['accuracy']:.4f}")

    save_model(model, model_path, model_type=model_type)
    print(f"Model saved to {model_path}")

    names_by_id = id_to_name(load_categories())
    if model_type == "cross_encoder_ft":
        recorded_epochs = epochs
    else:
        recorded_epochs = None

    metrics = {
        "train_size": len(train_df),
        "val_size": len(val_df),
        "test_size": len(test_df),
        "model_type": model_type,
        "epochs": recorded_epochs,
        "use_error_message": use_error_message,
        "validation": val_report,
        "test": test_report,
        # JSON object keys must be strings.
        "label_map": {str(label_id): name for label_id, name in names_by_id.items()},
    }

    metrics_path.parent.mkdir(parents=True, exist_ok=True)
    with metrics_path.open("w") as handle:
        json.dump(metrics, handle, indent=2)
    print(f"Metrics saved to {metrics_path}")
    return metrics
def main() -> None:
    """Parse the command-line options and run ``train`` with them."""
    cli = argparse.ArgumentParser(description="Train SQL error classifier")

    # Options that differ only in flag, value type and default; registered in
    # this order so the generated help text keeps its layout.
    for flag, value_type, fallback in (
        ("--data", Path, DEFAULT_DATA),
        ("--model", Path, DEFAULT_MODEL_PATH),
        ("--metrics", Path, DEFAULT_METRICS),
        ("--test-size", float, 0.1),
        ("--val-size", float, 0.1),
    ):
        cli.add_argument(flag, type=value_type, default=fallback)

    cli.add_argument("--no-error-message", action="store_true")
    cli.add_argument("--max-samples", type=int, default=None)
    cli.add_argument(
        "--model-type",
        default="cross_encoder",
        choices=["cross_encoder", "cross_encoder_ft", "multi_tower", "minilm", "tfidf"],
        help=(
            "cross_encoder (recommended): joint attention pairs; "
            "cross_encoder_ft: fine-tuned end-to-end (best accuracy)"
        ),
    )
    cli.add_argument(
        "--epochs",
        default=1,
        type=int,
        help="Epochs for cross_encoder_ft fine-tuning",
    )
    cli.add_argument("--seed", type=int, default=42)

    options = cli.parse_args()

    # The CLI flag is negative ("--no-error-message"); train() takes the
    # positive form, so the value is inverted here.
    train_kwargs = {
        "data_path": options.data,
        "model_path": options.model,
        "metrics_path": options.metrics,
        "test_size": options.test_size,
        "val_size": options.val_size,
        "use_error_message": not options.no_error_message,
        "max_train_samples": options.max_samples,
        "model_type": options.model_type,
        "epochs": options.epochs,
        "seed": options.seed,
    }
    train(**train_kwargs)
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()