"""CodeBERT inference Gradio app for Hugging Face Spaces.""" from __future__ import annotations import os import gradio as gr from src.hf_predict_codebert import CodeBERTSQLErrorClassifier MODEL_ID = os.getenv("SPACE_MODEL_ID", "nishu08/sql-codebert-classifier") try: clf = CodeBERTSQLErrorClassifier(MODEL_ID) model_status = f"Model loaded: **{MODEL_ID}**" except Exception as exc: clf = None model_status = ( f"Could not load `{MODEL_ID}`: {exc}\n\n" "Push your model first: `python scripts/push_codebert_to_hub.py --repo-id nishu08/sql-codebert-classifier`" ) EXAMPLES = [ [ "What is the average score of students in each department?", "students(id, name, score, department_id) | departments(id, name)", "SELECT department_id, SUM(score) FROM students GROUP BY department_id", "SELECT department_id, AVG(score) FROM students GROUP BY department_id", 0.5, ], [ "Find students who have not provided an email address.", "students(id, name, email, phone)", "SELECT name FROM students WHERE email = NULL", "SELECT name FROM students WHERE email IS NULL", 0.5, ], [ "List each student's name along with their department name.", "students(id, name, department_id) | departments(id, name)", "SELECT students.name, departments.name FROM students JOIN departments", "SELECT students.name, departments.name FROM students INNER JOIN departments ON students.department_id = departments.id", 0.5, ], [ "What is the average score of students in each department?", "students(id, name, score, department_id) | departments(id, name)", "SELECT department_id, AVG(score) FROM students GROUP BY department_id", "SELECT department_id, AVG(score) FROM students GROUP BY department_id", 0.5, ], ] def classify(question, schema, student_sql, correct_sql, threshold): if clf is None: return "Model not loaded.", "" result = clf.predict( question=question.strip(), schema=schema.strip(), student_sql=student_sql.strip(), correct_sql=correct_sql.strip(), threshold=threshold, ) if result["primary_label"] == "NO_ERROR": if result.get("match_detected"): summary = ( "### No error\n" "Student SQL matches the correct answer — no mistake to classify." ) else: summary = ( "### No error\n" f"No label exceeded the {threshold:.0%} threshold." ) probs = "_All probabilities below threshold._" else: summary = ( f"### {result['primary_label']}\n" f"Confidence: **{result['primary_confidence']:.1%}**\n\n" f"**Active labels:** {', '.join(result['error_labels']) or 'none'}" ) probs = "\n".join( f"- **{k}**: {v:.1%}" for k, v in result["probabilities"].items() ) return summary, probs with gr.Blocks(title="SQL Error Classifier", theme=gr.themes.Soft()) as demo: gr.Markdown( f""" # SQL Error Classifier CodeBERT cross-encoder — classifies student SQL mistakes into error categories. {model_status} """ ) with gr.Row(): with gr.Column(): question = gr.Textbox(label="Question", lines=2) schema = gr.Textbox(label="Schema", lines=2) student_sql = gr.Textbox(label="Student SQL", lines=4) correct_sql = gr.Textbox(label="Correct SQL", lines=4) threshold = gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Threshold") btn = gr.Button("Classify", variant="primary") with gr.Column(): prediction = gr.Markdown(label="Prediction") probabilities = gr.Markdown(label="Probabilities") gr.Examples(examples=EXAMPLES, inputs=[question, schema, student_sql, correct_sql, threshold]) btn.click( classify, [question, schema, student_sql, correct_sql, threshold], [prediction, probabilities], ) if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=7860)