# Hugging Face file view: nishu08 - "Deploy CodeBERT training Space"
# Commit 8464aea (verified); file size 2.46 kB.
# (Page controls captured with the file: Raw, History, Blame, Contribute, Delete.)
"""CodeBERT inference Gradio app (optional HF Space for predictions)."""
from __future__ import annotations

import os

import gradio as gr

from src.hf_predict_codebert import CodeBERTSQLErrorClassifier
# Directory the classifier is loaded from. The SPACE_MODEL_DIR environment
# variable overrides the built-in default; an unset or empty variable falls
# back to the default. The failure message below tells the operator to set
# this variable, so it has to be honoured here.
_DEFAULT_MODEL_DIR = "models/codebert-cross-encoder"
MODEL_DIR = os.environ.get("SPACE_MODEL_DIR") or _DEFAULT_MODEL_DIR

# Load the classifier once at import time. A failure is deliberately not
# fatal: `clf` is left as None (which `classify` checks for) and the reason is
# kept in `model_status`, which the page header displays.
try:
    clf = CodeBERTSQLErrorClassifier(MODEL_DIR)
except Exception as exc:  # best-effort boundary: report any load error in the UI
    clf = None
    model_status = f"Model not loaded: {exc}. Train first or set SPACE_MODEL_DIR."
else:
    model_status = f"Loaded model from `{MODEL_DIR}`"
# Sample submission used as the initial value of the four text inputs.
# The student query uses SUM(score) where the reference query uses AVG(score).
EXAMPLE = dict(
    question="What is the average score of students in each department?",
    schema="students(id, name, score, department_id) | departments(id, name)",
    student_sql=(
        "SELECT department_id, SUM(score) FROM students GROUP BY department_id"
    ),
    correct_sql=(
        "SELECT department_id, AVG(score) FROM students GROUP BY department_id"
    ),
)
def classify(question, schema, student_sql, correct_sql, threshold):
    """Classify one student query and format the result as Markdown.

    The four text arguments are stripped of surrounding whitespace and passed,
    together with ``threshold``, to ``clf.predict`` as keyword arguments.

    Returns:
        A ``(summary, probabilities)`` pair of Markdown strings. The summary
        shows the primary label with its confidence and the labels listed in
        the result's ``error_labels`` (or ``none``); the second string is one
        bullet per entry of the result's ``probabilities`` mapping. When no
        model is loaded (``clf is None``) the pair is
        ``("Train a model first.", "")``.
    """
    # Guard: without a loaded model there is nothing to run.
    if clf is None:
        return "Train a model first.", ""

    cleaned = {
        "question": question.strip(),
        "schema": schema.strip(),
        "student_sql": student_sql.strip(),
        "correct_sql": correct_sql.strip(),
    }
    outcome = clf.predict(**cleaned, threshold=threshold)

    headline = (
        f"**{outcome['primary_label']}** ({outcome['primary_confidence']:.1%})\n\n"
    )
    # An empty label list joins to "", which falls back to the word "none".
    flagged = ", ".join(outcome["error_labels"]) or "none"
    summary = headline + f"All labels above threshold: {flagged}"

    bullets = [
        f"- {label}: {score:.1%}"
        for label, score in outcome["probabilities"].items()
    ]
    return summary, "\n".join(bullets)
# Build the page: a header carrying the model-load status, then one row with
# two columns -- the input form (with its button) and the rendered results.
with gr.Blocks(title="SQL Error Classifier") as demo:
    gr.Markdown(f"# SQL Error Classifier (CodeBERT)\n{model_status}")

    with gr.Row():
        # First column: form fields pre-filled from EXAMPLE, plus controls.
        with gr.Column():
            question = gr.Textbox(
                label="Question",
                lines=2,
                value=EXAMPLE["question"],
            )
            schema = gr.Textbox(
                label="Schema",
                lines=2,
                value=EXAMPLE["schema"],
            )
            student_sql = gr.Textbox(
                label="Student SQL",
                lines=3,
                value=EXAMPLE["student_sql"],
            )
            correct_sql = gr.Textbox(
                label="Correct SQL",
                lines=3,
                value=EXAMPLE["correct_sql"],
            )
            threshold = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                value=0.5,
                step=0.05,
                label="Threshold",
            )
            btn = gr.Button("Classify", variant="primary")

        # Second column: the two Markdown panes that `classify` fills in.
        with gr.Column():
            prediction = gr.Markdown()
            probabilities = gr.Markdown()

    # Clicking the button sends the five inputs to `classify` and routes its
    # two returned strings to the two Markdown panes, in order.
    btn.click(
        fn=classify,
        inputs=[question, schema, student_sql, correct_sql, threshold],
        outputs=[prediction, probabilities],
    )
# Launch the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    demo.launch()