"""CodeBERT inference Gradio app (optional HF Space for predictions).""" from __future__ import annotations import gradio as gr from src.hf_predict_codebert import CodeBERTSQLErrorClassifier MODEL_DIR = "models/codebert-cross-encoder" try: clf = CodeBERTSQLErrorClassifier(MODEL_DIR) model_status = f"Loaded model from `{MODEL_DIR}`" except Exception as exc: clf = None model_status = f"Model not loaded: {exc}. Train first or set SPACE_MODEL_DIR." EXAMPLE = { "question": "What is the average score of students in each department?", "schema": "students(id, name, score, department_id) | departments(id, name)", "student_sql": "SELECT department_id, SUM(score) FROM students GROUP BY department_id", "correct_sql": "SELECT department_id, AVG(score) FROM students GROUP BY department_id", } def classify(question, schema, student_sql, correct_sql, threshold): if clf is None: return "Train a model first.", "" result = clf.predict( question=question.strip(), schema=schema.strip(), student_sql=student_sql.strip(), correct_sql=correct_sql.strip(), threshold=threshold, ) summary = ( f"**{result['primary_label']}** ({result['primary_confidence']:.1%})\n\n" f"All labels above threshold: {', '.join(result['error_labels']) or 'none'}" ) probs = "\n".join( f"- {k}: {v:.1%}" for k, v in result["probabilities"].items() ) return summary, probs with gr.Blocks(title="SQL Error Classifier") as demo: gr.Markdown(f"# SQL Error Classifier (CodeBERT)\n{model_status}") with gr.Row(): with gr.Column(): question = gr.Textbox(label="Question", lines=2, value=EXAMPLE["question"]) schema = gr.Textbox(label="Schema", lines=2, value=EXAMPLE["schema"]) student_sql = gr.Textbox(label="Student SQL", lines=3, value=EXAMPLE["student_sql"]) correct_sql = gr.Textbox(label="Correct SQL", lines=3, value=EXAMPLE["correct_sql"]) threshold = gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Threshold") btn = gr.Button("Classify", variant="primary") with gr.Column(): prediction = gr.Markdown() probabilities = gr.Markdown() btn.click( classify, [question, schema, student_sql, correct_sql, threshold], [prediction, probabilities], ) if __name__ == "__main__": demo.launch()