"""
Hugging Face Space — CodeBERT SQL Error Classifier Training UI.

Deploy as a Gradio Space with ``app_file: train_space_app.py``.
Set hardware to GPU (t4-small recommended).
Add an ``HF_TOKEN`` secret to push trained models to your Hub account.
"""

from __future__ import annotations

import json
import os
import shutil
import tempfile
from pathlib import Path

import gradio as gr
import pandas as pd

from src.hf_train_codebert import train

PROJECT_ROOT = Path(__file__).parent
DEFAULT_DATA = PROJECT_ROOT / "data" / "sql_errors_dev.parquet"
OUTPUT_DIR = PROJECT_ROOT / "models" / "codebert-cross-encoder"

# Dropdown label -> parquet path of the datasets shipped inside the Space repo.
BUNDLED_DATASETS = {
    "Dev (15K samples)": str(PROJECT_ROOT / "data" / "sql_errors_dev.parquet"),
    "Full (1M samples)": str(PROJECT_ROOT / "data" / "sql_errors_1m.parquet"),
}

# Maximum number of characters of JSON sample shown by the dataset preview.
_PREVIEW_CHARS = 2000

# Markdown strings are kept flush-left so no indentation leaks into the
# rendered output (4+ leading spaces would turn lines into code blocks).
_INTRO_MD = """
# SQL Error Classifier — CodeBERT Training

Train **microsoft/codebert-base** as a cross-encoder on this Space.

**Input format:** `QUESTION` + `SCHEMA` + `STUDENT_SQL` + `CORRECT_SQL` (single sequence)

**GPU recommended** — upgrade Space hardware to `t4-small` or better.
"""

_SETUP_MD = """
### Space setup
1. Create a Gradio Space and push this repo
2. Set **Hardware → GPU (t4-small)**
3. Add secret `HF_TOKEN` (write token) to push models
4. Include `data/sql_errors_dev.parquet` in the repo (or upload at runtime)

### After training
Use the saved model with:
```python
from src.hf_predict_codebert import CodeBERTSQLErrorClassifier
clf = CodeBERTSQLErrorClassifier("models/codebert-cross-encoder")
```
"""


def _upload_path(uploaded_file) -> Path | None:
    """Return the local path of a ``gr.File`` upload, or ``None`` if nothing was uploaded.

    Gradio 3.x hands callbacks a tempfile wrapper exposing ``.name``; Gradio 4+
    (``gr.File`` with the default ``type="filepath"``) hands them a plain ``str``.
    Both shapes are accepted.
    """
    if uploaded_file is None:
        return None
    if isinstance(uploaded_file, (str, os.PathLike)):
        return Path(uploaded_file)
    return Path(uploaded_file.name)


def _resolve_dataset_path(dataset_choice: str, uploaded_file) -> Path:
    """Prefer the uploaded parquet; otherwise use the selected bundled dataset.

    Unknown or empty dropdown values fall back to ``DEFAULT_DATA``.
    """
    uploaded = _upload_path(uploaded_file)
    if uploaded is not None:
        return uploaded
    return Path(BUNDLED_DATASETS.get(dataset_choice, DEFAULT_DATA))


def _json_default(value):
    """Fallback serializer for ``json.dumps`` in the dataset preview (display only).

    Values that expose ``tolist()`` (numpy arrays from parquet list columns,
    numpy scalars) are converted with it; anything else is rendered via ``str``.
    """
    to_list = getattr(value, "tolist", None)
    if callable(to_list):
        return to_list()
    return str(value)


def _format_metrics(metrics: dict) -> str:
    """Render the metrics dict returned by ``train`` as a Markdown summary.

    Reads ``train_samples``/``val_samples``/``test_samples``, the ``validation``
    dict (``eval_``-prefixed keys), the ``test`` dict, and an optional
    ``hub_url``. Missing keys are shown as 0.
    """
    val = metrics.get("validation") or {}
    test = metrics.get("test") or {}
    lines = [
        "## Training complete",
        "",
        f"- Train samples: **{metrics.get('train_samples', 0):,}**",
        f"- Val samples: **{metrics.get('val_samples', 0):,}**",
        f"- Test samples: **{metrics.get('test_samples', 0):,}**",
        "",
        "### Validation",
        f"- F1 macro: **{val.get('eval_f1_macro', 0):.4f}**",
        f"- F1 micro: **{val.get('eval_f1_micro', 0):.4f}**",
        "",
        "### Test",
        f"- F1 macro: **{test.get('f1_macro', 0):.4f}**",
        f"- F1 micro: **{test.get('f1_micro', 0):.4f}**",
        f"- Subset accuracy: **{test.get('subset_accuracy', 0):.4f}**",
        "",
        f"Model saved to `{OUTPUT_DIR}`",
    ]
    if metrics.get("hub_url"):
        lines.append(f"\n**Hub model:** {metrics['hub_url']}")
    return "\n".join(lines)


def run_training(
    dataset_choice: str,
    uploaded_file,
    max_samples: int,
    epochs: float,
    batch_size: int,
    learning_rate: float,
    max_length: int,
    fp16: bool,
    push_to_hub: bool,
    hub_model_id: str,
    progress=gr.Progress(),
):
    """Train CodeBERT from the UI inputs and report the outcome.

    Returns a ``(markdown_summary, metrics_json_path, model_dir)`` tuple matching
    the ``result``, ``metrics_file`` and ``model_dir`` outputs. Input problems and
    training exceptions are reported in the summary, with the other two outputs
    set to ``None``.
    """
    progress(0, desc="Preparing dataset...")
    data_path = _resolve_dataset_path(dataset_choice, uploaded_file)
    if not data_path.exists():
        return (
            f"Dataset not found: `{data_path}`. "
            "Upload a parquet file or include data/ in the Space repo.",
            None,
            None,
        )

    # A Textbox normally yields "", but guard against None before stripping.
    hub_id = (hub_model_id or "").strip()
    hub_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
    if push_to_hub and not hub_token:
        return (
            "Add `HF_TOKEN` to Space secrets to push models to the Hub.",
            None,
            None,
        )
    if push_to_hub and not hub_id:
        return "Enter a Hub model id (e.g. `your-username/sql-codebert-classifier`).", None, None
    # gr.Number yields None when the field is cleared.
    if not learning_rate or learning_rate <= 0:
        return "Learning rate must be a positive number.", None, None

    # Start every run from an empty directory so files from a previous run
    # (e.g. metrics.json) are never reported as this run's output.
    if OUTPUT_DIR.exists():
        shutil.rmtree(OUTPUT_DIR, ignore_errors=True)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # 0 or empty means "use every row".
    samples = int(max_samples) if max_samples and max_samples > 0 else None

    progress(0.1, desc="Starting CodeBERT training...")
    try:
        metrics = train(
            data_path=data_path,
            output_dir=OUTPUT_DIR,
            # UI sliders/numbers may deliver floats; batch size and sequence
            # length are counts, so pass them as ints.
            epochs=float(epochs),
            batch_size=int(batch_size),
            learning_rate=float(learning_rate),
            max_length=int(max_length),
            max_samples=samples,
            fp16=fp16,
            save_strategy="no",
            push_to_hub=push_to_hub,
            hub_model_id=hub_id or None,
            hub_token=hub_token,
        )
    except Exception as exc:  # UI boundary: surface any failure to the user.
        return f"Training failed:\n\n```\n{type(exc).__name__}: {exc}\n```", None, None

    progress(1.0, desc="Done")
    if push_to_hub and hub_id:
        metrics["hub_url"] = f"https://huggingface.co/{hub_id}"
    metrics_path = OUTPUT_DIR / "metrics.json"
    summary = _format_metrics(metrics)
    return summary, str(metrics_path) if metrics_path.exists() else None, str(OUTPUT_DIR)


def load_preview(dataset_choice: str, uploaded_file) -> str:
    """Return a Markdown preview: row count, column names and the first two rows.

    Errors are reported in the returned text rather than raised.
    """
    try:
        path = _resolve_dataset_path(dataset_choice, uploaded_file)
        if not path.exists():
            return f"Dataset not found: {path}"
        df = pd.read_parquet(path)
        cols = list(df.columns)
        sample = df.head(2).to_dict(orient="records")
        sample_json = json.dumps(sample, indent=2, default=_json_default)[:_PREVIEW_CHARS]
        return (
            f"**Rows:** {len(df):,}\n\n"
            f"**Columns:** `{cols}`\n\n"
            f"**Sample:**\n```json\n{sample_json}\n```"
        )
    except Exception as exc:
        return f"Could not load preview: {exc}"


with gr.Blocks(title="SQL Error Classifier — Train") as demo:
    gr.Markdown(_INTRO_MD)

    with gr.Row():
        with gr.Column(scale=1):
            dataset_choice = gr.Dropdown(
                choices=list(BUNDLED_DATASETS.keys()),
                value="Dev (15K samples)",
                label="Bundled dataset",
            )
            uploaded = gr.File(
                label="Or upload parquet",
                file_types=[".parquet"],
            )
            preview_btn = gr.Button("Preview dataset")
            preview_out = gr.Markdown()

            max_samples = gr.Number(
                label="Max samples (0 = all)",
                value=5000,
                precision=0,
            )
            epochs = gr.Slider(1, 10, value=2, step=1, label="Epochs")
            batch_size = gr.Slider(4, 64, value=8, step=4, label="Batch size")
            learning_rate = gr.Number(label="Learning rate", value=2e-5)
            max_length = gr.Slider(128, 512, value=512, step=64, label="Max length")
            fp16 = gr.Checkbox(label="FP16 (GPU only)", value=True)

            push_to_hub = gr.Checkbox(label="Push to Hugging Face Hub", value=False)
            hub_model_id = gr.Textbox(
                label="Hub model id",
                placeholder="your-username/sql-codebert-classifier",
            )

            train_btn = gr.Button("Start Training", variant="primary")

        with gr.Column(scale=1):
            result = gr.Markdown(label="Results")
            metrics_file = gr.File(label="metrics.json")
            model_dir = gr.Textbox(label="Model output path", interactive=False)

    preview_btn.click(load_preview, [dataset_choice, uploaded], preview_out)
    train_btn.click(
        run_training,
        [
            dataset_choice,
            uploaded,
            max_samples,
            epochs,
            batch_size,
            learning_rate,
            max_length,
            fp16,
            push_to_hub,
            hub_model_id,
        ],
        [result, metrics_file, model_dir],
    )

    gr.Markdown(_SETUP_MD)


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)