Spaces:
Sleeping
Sleeping
| """CodeBERT inference Gradio app (optional HF Space for predictions).""" | |
from __future__ import annotations

import logging
import os

import gradio as gr

from src.hf_predict_codebert import CodeBERTSQLErrorClassifier
# Directory holding the fine-tuned CodeBERT checkpoint. The failure message
# below tells users they can "set SPACE_MODEL_DIR", so honour that environment
# variable (an unset or empty value falls back to the in-repo default).
MODEL_DIR = os.environ.get("SPACE_MODEL_DIR") or "models/codebert-cross-encoder"

# Load the classifier once at import time. A missing or untrained checkpoint
# must not take the Space down: keep the UI up, show why loading failed, and
# leave `clf` as None so `classify` can report that no model is available.
try:
    clf = CodeBERTSQLErrorClassifier(MODEL_DIR)
    model_status = f"Loaded model from `{MODEL_DIR}`"
except Exception as exc:  # top-level boundary: any load failure is surfaced, not raised
    # The UI only shows str(exc); keep the full traceback in the logs.
    logging.getLogger(__name__).exception("Failed to load model from %s", MODEL_DIR)
    clf = None
    model_status = f"Model not loaded: {exc}. Train first or set SPACE_MODEL_DIR."
# Pre-filled demo input for the UI. The student query uses SUM where the
# question asks for an average, so it differs from the reference query only
# in the aggregate function.
EXAMPLE = dict(
    question="What is the average score of students in each department?",
    schema="students(id, name, score, department_id) | departments(id, name)",
    student_sql="SELECT department_id, SUM(score) FROM students GROUP BY department_id",
    correct_sql="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
)
def classify(question, schema, student_sql, correct_sql, threshold):
    """Classify one student SQL submission and format the result as Markdown.

    Args:
        question: Question the SQL is meant to answer.
        schema: Schema description, passed through to the classifier.
        student_sql: The query written by the student.
        correct_sql: The reference query to compare against.
        threshold: Probability cut-off forwarded to ``clf.predict``.

    Returns:
        A ``(summary, probabilities)`` pair of Markdown strings.
        ``summary`` shows the primary label with its confidence and every
        label reported above ``threshold``. ``probabilities`` is a bullet list
        of per-label probabilities in the order the classifier returned them.
        If no model is loaded, or either SQL field is empty, ``summary``
        holds a short message and ``probabilities`` is empty.
    """
    if clf is None:
        return "Train a model first.", ""

    # Text fields may arrive as None (e.g. through the Gradio API); normalise
    # every field to a stripped string so `.strip()` never raises.
    question, schema, student_sql, correct_sql = (
        (text or "").strip() for text in (question, schema, student_sql, correct_sql)
    )
    # Comparing against an empty query would produce a meaningless prediction.
    if not student_sql or not correct_sql:
        return "Please provide both the student SQL and the correct SQL.", ""

    result = clf.predict(
        question=question,
        schema=schema,
        student_sql=student_sql,
        correct_sql=correct_sql,
        threshold=threshold,
    )
    # Presumably `predict` returns these four keys (see src.hf_predict_codebert).
    summary = (
        f"**{result['primary_label']}** ({result['primary_confidence']:.1%})\n\n"
        f"All labels above threshold: {', '.join(result['error_labels']) or 'none'}"
    )
    probs = "\n".join(
        f"- {label}: {prob:.1%}" for label, prob in result["probabilities"].items()
    )
    return summary, probs
# Gradio layout: a header with the model-load status, then two columns.
# The left column holds the inputs (pre-filled from EXAMPLE); the right column
# holds the Markdown outputs written by `classify`.
with gr.Blocks(title="SQL Error Classifier") as demo:
    gr.Markdown(value=f"# SQL Error Classifier (CodeBERT)\n{model_status}")

    with gr.Row():
        # Inputs: four text fields, the decision threshold, and the trigger.
        with gr.Column():
            question = gr.Textbox(value=EXAMPLE["question"], label="Question", lines=2)
            schema = gr.Textbox(value=EXAMPLE["schema"], label="Schema", lines=2)
            student_sql = gr.Textbox(
                value=EXAMPLE["student_sql"], label="Student SQL", lines=3
            )
            correct_sql = gr.Textbox(
                value=EXAMPLE["correct_sql"], label="Correct SQL", lines=3
            )
            threshold = gr.Slider(
                minimum=0.1, maximum=0.9, value=0.5, step=0.05, label="Threshold"
            )
            btn = gr.Button(value="Classify", variant="primary")

        # Outputs: the (summary, probabilities) pair returned by `classify`.
        with gr.Column():
            prediction = gr.Markdown()
            probabilities = gr.Markdown()

    # The inputs list is passed positionally to `classify`, so its order must
    # match classify(question, schema, student_sql, correct_sql, threshold).
    btn.click(
        fn=classify,
        inputs=[question, schema, student_sql, correct_sql, threshold],
        outputs=[prediction, probabilities],
    )

if __name__ == "__main__":
    # Start the Gradio server when run as a script.
    demo.launch()