Spaces:
Sleeping
Sleeping
| """CodeBERT inference Gradio app for Hugging Face Spaces.""" | |
| from __future__ import annotations | |
| import os | |
| import gradio as gr | |
| from src.hf_predict_codebert import CodeBERTSQLErrorClassifier | |
# Hub repository the classifier is loaded from; the SPACE_MODEL_ID environment
# variable overrides the default.
MODEL_ID = os.getenv("SPACE_MODEL_ID", "nishu08/sql-codebert-classifier")

# Load the model once at import time. A load failure must not take the app
# down: ``clf`` is left as None and the reason is kept in ``model_status`` so
# it can be displayed to the user.
try:
    clf = CodeBERTSQLErrorClassifier(MODEL_ID)
except Exception as exc:  # top-level boundary: report any load failure in the UI
    clf = None
    model_status = (
        f"Could not load `{MODEL_ID}`: {exc}\n\n"
        # Point the hint at the repository that was actually requested, not
        # the default one, so it stays correct when SPACE_MODEL_ID is set.
        "Push your model first: "
        f"`python scripts/push_codebert_to_hub.py --repo-id {MODEL_ID}`"
    )
else:
    model_status = f"Model loaded: **{MODEL_ID}**"
# Pre-filled examples for the form. Each row is ordered as
# [question, schema, student SQL, correct SQL, threshold]; every example uses
# the same 0.5 threshold, so it is appended once when the rows are built.
EXAMPLES = [
    [question_text, schema_text, student_query, correct_query, 0.5]
    for question_text, schema_text, student_query, correct_query in (
        # Student query uses SUM where the correct query uses AVG.
        (
            "What is the average score of students in each department?",
            "students(id, name, score, department_id) | departments(id, name)",
            "SELECT department_id, SUM(score) FROM students GROUP BY department_id",
            "SELECT department_id, AVG(score) FROM students GROUP BY department_id",
        ),
        # Student query compares with "= NULL" instead of "IS NULL".
        (
            "Find students who have not provided an email address.",
            "students(id, name, email, phone)",
            "SELECT name FROM students WHERE email = NULL",
            "SELECT name FROM students WHERE email IS NULL",
        ),
        # Student query joins without an ON condition.
        (
            "List each student's name along with their department name.",
            "students(id, name, department_id) | departments(id, name)",
            "SELECT students.name, departments.name FROM students JOIN departments",
            "SELECT students.name, departments.name FROM students INNER JOIN departments ON students.department_id = departments.id",
        ),
        # Student query is identical to the correct query.
        (
            "What is the average score of students in each department?",
            "students(id, name, score, department_id) | departments(id, name)",
            "SELECT department_id, AVG(score) FROM students GROUP BY department_id",
            "SELECT department_id, AVG(score) FROM students GROUP BY department_id",
        ),
    )
]
def classify(question, schema, student_sql, correct_sql, threshold):
    """Classify one student query and format the result as Markdown.

    The four text inputs are stripped of surrounding whitespace and passed,
    together with ``threshold``, to the module-level classifier's ``predict``.

    Returns:
        A ``(summary, probabilities)`` pair of Markdown strings. When the
        model failed to load, the pair is ``("Model not loaded.", "")``.
    """
    if clf is None:
        return "Model not loaded.", ""

    cleaned_inputs = {
        "question": question.strip(),
        "schema": schema.strip(),
        "student_sql": student_sql.strip(),
        "correct_sql": correct_sql.strip(),
    }
    outcome = clf.predict(**cleaned_inputs, threshold=threshold)

    if outcome["primary_label"] != "NO_ERROR":
        # An error category was predicted: show it, its confidence, every
        # active label, and the per-label probability list.
        heading = f"### {outcome['primary_label']}\n"
        confidence = f"Confidence: **{outcome['primary_confidence']:.1%}**\n\n"
        active = ", ".join(outcome["error_labels"]) or "none"
        bullet_lines = [
            f"- **{label}**: {prob:.1%}"
            for label, prob in outcome["probabilities"].items()
        ]
        return heading + confidence + f"**Active labels:** {active}", "\n".join(bullet_lines)

    # NO_ERROR: explain whether the queries matched or nothing crossed the
    # threshold. Both cases share the same probabilities placeholder.
    if outcome.get("match_detected"):
        detail = "Student SQL matches the correct answer — no mistake to classify."
    else:
        detail = f"No label exceeded the {threshold:.0%} threshold."
    return "### No error\n" + detail, "_All probabilities below threshold._"
# Build the page: header, input form next to the result panes, the example
# rows, and the button wiring.
# NOTE(review): the nesting below (button inside the left column, examples
# placed after the row) is reconstructed from statement order — confirm it
# matches the intended layout.
with gr.Blocks(title="SQL Error Classifier", theme=gr.themes.Soft()) as demo:
    # Header text; includes the model load status computed at import time.
    gr.Markdown(
        f"""
# SQL Error Classifier
CodeBERT cross-encoder — classifies student SQL mistakes into error categories.
{model_status}
"""
    )

    with gr.Row():
        # Input form.
        with gr.Column():
            question = gr.Textbox(label="Question", lines=2)
            schema = gr.Textbox(label="Schema", lines=2)
            student_sql = gr.Textbox(label="Student SQL", lines=4)
            correct_sql = gr.Textbox(label="Correct SQL", lines=4)
            threshold = gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Threshold")
            btn = gr.Button("Classify", variant="primary")
        # Result panes filled by ``classify``.
        with gr.Column():
            prediction = gr.Markdown(label="Prediction")
            probabilities = gr.Markdown(label="Probabilities")

    # The examples and the click handler take the same inputs in the same
    # order as the parameters of ``classify``.
    form_inputs = [question, schema, student_sql, correct_sql, threshold]
    gr.Examples(examples=EXAMPLES, inputs=form_inputs)
    btn.click(
        fn=classify,
        inputs=form_inputs,
        outputs=[prediction, probabilities],
    )


if __name__ == "__main__":
    # Serve on all interfaces at port 7860 when run as a script.
    demo.launch(server_name="0.0.0.0", server_port=7860)