nishu08's picture
Deploy CodeBERT training Space
9b2cded verified
Raw
History Blame Contribute Delete
6.6 kB
"""Train the SQL error classifier."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import pandas as pd
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from src.categories import id_to_name, load_categories
from src.cross_encoder_model import (
CrossEncoderClassifier,
FineTunedCrossEncoderClassifier,
)
from src.model import (
DEFAULT_MODEL_PATH,
ModelType,
build_classifier,
combine_features,
save_model,
)
from src.multi_tower_model import MultiTowerClassifier, contexts_from_dataframe
# Repository root: two directory levels above this (resolved) file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Default input dataset and output metrics report, both under the repo root.
DEFAULT_DATA = PROJECT_ROOT.joinpath("data", "sql_errors_1m.parquet")
DEFAULT_METRICS = PROJECT_ROOT.joinpath("models", "metrics.json")

# Model classes that are trained on structured contexts built by
# contexts_from_dataframe() instead of on flattened text; kept as a tuple so
# it can be passed directly to isinstance().
CONTEXT_MODELS = (
    CrossEncoderClassifier,
    FineTunedCrossEncoderClassifier,
    MultiTowerClassifier,
)
def _split_dataframe(
    df: pd.DataFrame, test_size: float, val_size: float, seed: int
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Split *df* into stratified ``(train, val, test)`` frames.

    Both fractions are relative to the whole of *df*: ``test_size`` of the rows
    go to the test frame first, then the remainder is split so that the
    validation frame receives roughly ``val_size`` of the original rows. Both
    splits stratify on the ``label_id`` column and use the same ``seed``.

    Raises:
        ValueError: if ``test_size`` is not in the open interval (0, 1), or if
            ``val_size`` is not positive or ``test_size + val_size`` does not
            leave any rows for training.
    """
    # Validate before splitting so a bad combination is reported in terms of
    # the caller's arguments rather than the derived, rescaled fraction that
    # the second train_test_split call would otherwise complain about.
    if not 0 < test_size < 1:
        raise ValueError(f"test_size must be in (0, 1), got {test_size!r}")
    # val_size is a fraction of the full dataset; rescale it to the rows that
    # remain after the test split has been removed.
    relative_val = val_size / (1 - test_size)
    if not 0 < relative_val < 1:
        raise ValueError(
            "val_size must be positive and test_size + val_size must be < 1, "
            f"got test_size={test_size!r}, val_size={val_size!r}"
        )

    trainval, test = train_test_split(
        df, test_size=test_size, random_state=seed, stratify=df["label_id"]
    )
    train, val = train_test_split(
        trainval,
        test_size=relative_val,
        random_state=seed,
        stratify=trainval["label_id"],
    )
    return train, val, test
def _frame_to_texts(frame: pd.DataFrame) -> list[str]:
    """Flatten *frame* into the combined text inputs consumed by text-only models.

    ``query`` is required; ``error_message``, ``schema`` and ``question`` are
    forwarded to ``combine_features`` when present and as ``None`` otherwise.
    """

    def optional_column(name: str) -> list | None:
        return frame[name].tolist() if name in frame.columns else None

    return combine_features(
        queries=frame["query"].tolist(),
        error_messages=optional_column("error_message"),
        schemas=optional_column("schema"),
        questions=optional_column("question"),
    )


def _fit_and_predict(
    model,
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    test_df: pd.DataFrame,
    model_path: Path,
    epochs: int,
) -> tuple:
    """Fit *model* on ``train_df`` and return its ``(val_preds, test_preds)``.

    Instances of ``CONTEXT_MODELS`` are fed structured contexts built by
    ``contexts_from_dataframe``; every other model is fed flattened text from
    ``_frame_to_texts``. ``epochs`` and ``model_path`` are only used when the
    model is a ``FineTunedCrossEncoderClassifier``.
    """
    y_train = train_df["label_id"].values

    if not isinstance(model, CONTEXT_MODELS):
        model.fit(_frame_to_texts(train_df), y_train)
        return (
            model.predict(_frame_to_texts(val_df)),
            model.predict(_frame_to_texts(test_df)),
        )

    train_ctx = contexts_from_dataframe(train_df)
    val_ctx = contexts_from_dataframe(val_df)
    test_ctx = contexts_from_dataframe(test_df)
    if isinstance(model, FineTunedCrossEncoderClassifier):
        # A ".joblib" model_path is swapped to ".ce" for the fine-tuning output;
        # any other path is passed through unchanged.
        # NOTE(review): presumably this keeps the fine-tuned checkpoint apart
        # from the file save_model() later writes to model_path -- confirm.
        output_path = (
            model_path.with_suffix(".ce")
            if model_path.suffix == ".joblib"
            else model_path
        )
        model.fit(train_ctx, y_train, epochs=epochs, output_path=output_path)
    else:
        model.fit(train_ctx, y_train)
    return model.predict(val_ctx), model.predict(test_ctx)


def train(
    data_path: Path = DEFAULT_DATA,
    model_path: Path = DEFAULT_MODEL_PATH,
    metrics_path: Path = DEFAULT_METRICS,
    test_size: float = 0.1,
    val_size: float = 0.1,
    use_error_message: bool = True,
    max_train_samples: int | None = None,
    model_type: ModelType = "cross_encoder",
    epochs: int = 1,
    seed: int = 42,
) -> dict:
    """Train, evaluate and persist a SQL error classifier.

    Args:
        data_path: Parquet file with the labelled examples; it must contain a
            ``label_id`` column plus whatever feature columns the model reads.
        model_path: Destination passed to ``save_model`` for the trained model.
        metrics_path: Destination of the metrics JSON; parent directories are
            created if missing.
        test_size: Fraction of the (possibly subsampled) data held out for test.
        val_size: Fraction of the (possibly subsampled) data held out for
            validation.
        use_error_message: If False, the ``error_message`` column is dropped
            from the data before splitting.
        max_train_samples: If set and smaller than the dataset, randomly keep
            this many rows *before* splitting, so it caps train + val + test
            combined rather than only the training split.
        model_type: Passed to ``build_classifier`` and ``save_model`` and
            recorded in the metrics.
        epochs: Fine-tuning epochs; only forwarded to ``fit`` when the built
            model is a ``FineTunedCrossEncoderClassifier``, and only recorded in
            the metrics when ``model_type == "cross_encoder_ft"``.
        seed: Random seed for the subsampling and for both splits.

    Returns:
        The metrics dict (split row counts, configuration, validation/test
        ``classification_report`` dicts, label map) that is also written to
        ``metrics_path``.
    """
    print(f"Loading data from {data_path}...")
    df = pd.read_parquet(data_path)
    if max_train_samples and len(df) > max_train_samples:
        df = df.sample(n=max_train_samples, random_state=seed)
    if not use_error_message and "error_message" in df.columns:
        df = df.drop(columns=["error_message"])

    train_df, val_df, test_df = _split_dataframe(df, test_size, val_size, seed)
    print(
        f"Train: {len(train_df):,} | Val: {len(val_df):,} | Test: {len(test_df):,}"
    )

    model = build_classifier(model_type=model_type)
    print(f"Training {model_type} classifier...")
    val_preds, test_preds = _fit_and_predict(
        model, train_df, val_df, test_df, model_path, epochs
    )

    val_report = classification_report(
        val_df["label_id"].values, val_preds, output_dict=True, zero_division=0
    )
    print(f"Validation accuracy: {val_report['accuracy']:.4f}")
    test_report = classification_report(
        test_df["label_id"].values, test_preds, output_dict=True, zero_division=0
    )
    print(f"Test accuracy: {test_report['accuracy']:.4f}")

    save_model(model, model_path, model_type=model_type)
    print(f"Model saved to {model_path}")

    label_map = id_to_name(load_categories())
    metrics = {
        # *_size entries here are row counts, not the split fractions.
        "train_size": len(train_df),
        "val_size": len(val_df),
        "test_size": len(test_df),
        "model_type": model_type,
        "epochs": epochs if model_type == "cross_encoder_ft" else None,
        "use_error_message": use_error_message,
        "validation": val_report,
        "test": test_report,
        # JSON object keys must be strings.
        "label_map": {str(k): v for k, v in label_map.items()},
    }
    metrics_path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit encoding so the report does not depend on the platform default.
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)
    print(f"Metrics saved to {metrics_path}")
    return metrics
def main() -> None:
    """Parse command-line options and forward them to ``train``."""
    cli = argparse.ArgumentParser(description="Train SQL error classifier")

    # Simple typed options with defaults, registered in their original order.
    for flag, value_type, fallback in (
        ("--data", Path, DEFAULT_DATA),
        ("--model", Path, DEFAULT_MODEL_PATH),
        ("--metrics", Path, DEFAULT_METRICS),
        ("--test-size", float, 0.1),
        ("--val-size", float, 0.1),
    ):
        cli.add_argument(flag, type=value_type, default=fallback)

    cli.add_argument("--no-error-message", action="store_true")
    cli.add_argument("--max-samples", type=int, default=None)

    model_type_choices = [
        "cross_encoder",
        "cross_encoder_ft",
        "multi_tower",
        "minilm",
        "tfidf",
    ]
    model_type_help = (
        "cross_encoder (recommended): joint attention pairs; "
        "cross_encoder_ft: fine-tuned end-to-end (best accuracy)"
    )
    cli.add_argument(
        "--model-type",
        choices=model_type_choices,
        default="cross_encoder",
        help=model_type_help,
    )
    cli.add_argument(
        "--epochs",
        type=int,
        default=1,
        help="Epochs for cross_encoder_ft fine-tuning",
    )
    cli.add_argument("--seed", type=int, default=42)

    opts = cli.parse_args()
    # --no-error-message is a negative flag; train() takes the positive form.
    train(
        data_path=opts.data,
        model_path=opts.model,
        metrics_path=opts.metrics,
        test_size=opts.test_size,
        val_size=opts.val_size,
        use_error_message=not opts.no_error_message,
        max_train_samples=opts.max_samples,
        model_type=opts.model_type,
        epochs=opts.epochs,
        seed=opts.seed,
    )


if __name__ == "__main__":
    main()