# Page header captured from the Hugging Face file viewer (not part of the program):
#   author avatar: nishu08
#   commit: 7aae828 (verified) - "Deploy CodeBERT inference Space"
#   file size: 4.27 kB
"""CodeBERT inference Gradio app for Hugging Face Spaces."""
from __future__ import annotations
import os
import gradio as gr
from src.hf_predict_codebert import CodeBERTSQLErrorClassifier
# Identifier handed to the classifier; override it with the SPACE_MODEL_ID
# environment variable, otherwise the default below is used.
MODEL_ID = os.getenv("SPACE_MODEL_ID", "nishu08/sql-codebert-classifier")

# Load the model once at import time. Any failure is caught here so the UI can
# still be built: `clf` stays None and `model_status` carries the reason, which
# is rendered in the page header.
try:
    clf = CodeBERTSQLErrorClassifier(MODEL_ID)
except Exception as exc:  # top-level boundary: report the failure instead of crashing
    clf = None
    model_status = (
        f"Could not load `{MODEL_ID}`: {exc}\n\n"
        # Point the hint at the model that actually failed, not only the default id.
        f"Push your model first: `python scripts/push_codebert_to_hub.py --repo-id {MODEL_ID}`"
    )
else:
    model_status = f"Model loaded: **{MODEL_ID}**"
def _example_row(question, schema, student_sql, correct_sql, threshold=0.5):
    """Pack one demo case into the positional row layout used by EXAMPLES."""
    return [question, schema, student_sql, correct_sql, threshold]


# Pre-filled demo cases. Each row is ordered as
# [question, schema, student SQL, correct SQL, threshold], the same order as
# the `inputs=` list given to gr.Examples.
EXAMPLES = [
    # SUM used where the reference query uses AVG.
    _example_row(
        question="What is the average score of students in each department?",
        schema="students(id, name, score, department_id) | departments(id, name)",
        student_sql="SELECT department_id, SUM(score) FROM students GROUP BY department_id",
        correct_sql="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
    ),
    # `= NULL` comparison where the reference query uses `IS NULL`.
    _example_row(
        question="Find students who have not provided an email address.",
        schema="students(id, name, email, phone)",
        student_sql="SELECT name FROM students WHERE email = NULL",
        correct_sql="SELECT name FROM students WHERE email IS NULL",
    ),
    # JOIN written without the ON condition present in the reference query.
    _example_row(
        question="List each student's name along with their department name.",
        schema="students(id, name, department_id) | departments(id, name)",
        student_sql="SELECT students.name, departments.name FROM students JOIN departments",
        correct_sql="SELECT students.name, departments.name FROM students INNER JOIN departments ON students.department_id = departments.id",
    ),
    # Student SQL identical to the reference query.
    _example_row(
        question="What is the average score of students in each department?",
        schema="students(id, name, score, department_id) | departments(id, name)",
        student_sql="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
        correct_sql="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
    ),
]
def classify(
    question: str | None,
    schema: str | None,
    student_sql: str | None,
    correct_sql: str | None,
    threshold: float,
) -> tuple[str, str]:
    """Run the classifier on one submission and render the outcome as Markdown.

    Args:
        question: Task description text from the "Question" box.
        schema: Schema text from the "Schema" box.
        student_sql: The student's SQL from the "Student SQL" box.
        correct_sql: The reference SQL from the "Correct SQL" box.
        threshold: Value forwarded to ``clf.predict`` and echoed in the
            "no label exceeded" message.

    Returns:
        A ``(summary, probabilities)`` pair of Markdown strings. When the
        model failed to load, ``("Model not loaded.", "")`` is returned.
    """
    if clf is None:
        return "Model not loaded.", ""

    def _clean(text):
        # A textbox value can arrive as None rather than ""; treat it as
        # empty instead of failing on `.strip()`.
        return (text or "").strip()

    result = clf.predict(
        question=_clean(question),
        schema=_clean(schema),
        student_sql=_clean(student_sql),
        correct_sql=_clean(correct_sql),
        threshold=threshold,
    )

    if result["primary_label"] == "NO_ERROR":
        # Two ways to land here: the classifier flagged the answers as a
        # match, or it simply produced no label.
        if result.get("match_detected"):
            summary = (
                "### No error\n"
                "Student SQL matches the correct answer — no mistake to classify."
            )
        else:
            summary = (
                "### No error\n"
                f"No label exceeded the {threshold:.0%} threshold."
            )
        probs = "_All probabilities below threshold._"
    else:
        active_labels = ", ".join(result["error_labels"]) or "none"
        summary = (
            f"### {result['primary_label']}\n"
            f"Confidence: **{result['primary_confidence']:.1%}**\n\n"
            f"**Active labels:** {active_labels}"
        )
        # One bullet per entry of the returned probability mapping.
        probs = "\n".join(
            f"- **{label}**: {prob:.1%}"
            for label, prob in result["probabilities"].items()
        )
    return summary, probs
# Header text for the page; the last line reports the outcome of model loading.
_HEADER_MARKDOWN = f"""
# SQL Error Classifier
CodeBERT cross-encoder — classifies student SQL mistakes into error categories.
{model_status}
"""

with gr.Blocks(title="SQL Error Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_HEADER_MARKDOWN)

    with gr.Row():
        # First column: the four text inputs, the threshold slider and the button.
        with gr.Column():
            question = gr.Textbox(label="Question", lines=2)
            schema = gr.Textbox(label="Schema", lines=2)
            student_sql = gr.Textbox(label="Student SQL", lines=4)
            correct_sql = gr.Textbox(label="Correct SQL", lines=4)
            threshold = gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Threshold")
            btn = gr.Button("Classify", variant="primary")
        # Second column: the two Markdown panes filled by `classify`.
        with gr.Column():
            prediction = gr.Markdown(label="Prediction")
            probabilities = gr.Markdown(label="Probabilities")

    # Same component order as the parameters of `classify` and the columns of
    # each EXAMPLES row, so it is spelled out once and reused.
    input_components = [question, schema, student_sql, correct_sql, threshold]
    output_components = [prediction, probabilities]

    gr.Examples(examples=EXAMPLES, inputs=input_components)
    btn.click(classify, input_components, output_components)
if __name__ == "__main__":
    # Only start the server when this file is executed directly.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)