File size: 4,265 Bytes
8a3099e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7aae828
 
 
 
 
 
 
8a3099e
 
 
 
 
 
 
 
 
 
 
 
 
7aae828
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a3099e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""CodeBERT inference Gradio app for Hugging Face Spaces."""

from __future__ import annotations

import os

import gradio as gr

from src.hf_predict_codebert import CodeBERTSQLErrorClassifier

# Hub repo to load; override with the SPACE_MODEL_ID environment variable.
# Using `or` (instead of a getenv default) also falls back when the variable
# is set but empty, which would otherwise yield an empty repo id.
MODEL_ID = os.getenv("SPACE_MODEL_ID") or "nishu08/sql-codebert-classifier"

# Load the classifier once at import time. On failure the app still starts:
# `clf` stays None and `model_status` explains the problem in the UI header.
try:
    clf = CodeBERTSQLErrorClassifier(MODEL_ID)
    model_status = f"Model loaded: **{MODEL_ID}**"
except Exception as exc:  # broad on purpose: any load failure degrades to a status message
    clf = None
    model_status = (
        f"Could not load `{MODEL_ID}`: {exc}\n\n"
        # Point the hint at the repo actually being loaded, not a hard-coded one.
        f"Push your model first: `python scripts/push_codebert_to_hub.py --repo-id {MODEL_ID}`"
    )

# Demo rows for gr.Examples, each shaped as
# [question, schema, student_sql, correct_sql, threshold].
# Every row uses the same 0.5 threshold, so it is appended in the comprehension.
# The last row repeats the reference query as the student query, presumably to
# demonstrate the "no error" / exact-match path.
EXAMPLES = [
    [prompt, tables, attempt, answer, 0.5]
    for prompt, tables, attempt, answer in (
        (
            "What is the average score of students in each department?",
            "students(id, name, score, department_id) | departments(id, name)",
            "SELECT department_id, SUM(score) FROM students GROUP BY department_id",
            "SELECT department_id, AVG(score) FROM students GROUP BY department_id",
        ),
        (
            "Find students who have not provided an email address.",
            "students(id, name, email, phone)",
            "SELECT name FROM students WHERE email = NULL",
            "SELECT name FROM students WHERE email IS NULL",
        ),
        (
            "List each student's name along with their department name.",
            "students(id, name, department_id) | departments(id, name)",
            "SELECT students.name, departments.name FROM students JOIN departments",
            "SELECT students.name, departments.name FROM students INNER JOIN departments ON students.department_id = departments.id",
        ),
        (
            "What is the average score of students in each department?",
            "students(id, name, score, department_id) | departments(id, name)",
            "SELECT department_id, AVG(score) FROM students GROUP BY department_id",
            "SELECT department_id, AVG(score) FROM students GROUP BY department_id",
        ),
    )
]


def classify(question, schema, student_sql, correct_sql, threshold):
    """Classify one student submission and format the result as Markdown.

    Args:
        question: Natural-language question the SQL should answer.
        schema: Schema description for the question.
        student_sql: The student's query.
        correct_sql: The reference query.
        threshold: Probability cut-off forwarded to ``clf.predict`` and shown
            in the "no error" message.

    Returns:
        A ``(summary, probabilities)`` pair of Markdown strings for the two
        output components.
    """
    if clf is None:
        return "Model not loaded.", ""

    # UI or API callers may send None for an empty field; normalise every
    # input to a stripped string so `.strip()` can never raise.
    question, schema, student_sql, correct_sql = (
        (value or "").strip()
        for value in (question, schema, student_sql, correct_sql)
    )
    # Without both queries there is nothing meaningful to compare.
    if not student_sql or not correct_sql:
        return (
            "### Missing input\n"
            "Please enter both the student SQL and the correct SQL.",
            "",
        )

    result = clf.predict(
        question=question,
        schema=schema,
        student_sql=student_sql,
        correct_sql=correct_sql,
        threshold=threshold,
    )
    if result["primary_label"] == "NO_ERROR":
        if result.get("match_detected"):
            summary = (
                "### No error\n"
                "Student SQL matches the correct answer — no mistake to classify."
            )
        else:
            summary = (
                "### No error\n"
                f"No label exceeded the {threshold:.0%} threshold."
            )
        probs = "_All probabilities below threshold._"
    else:
        summary = (
            f"### {result['primary_label']}\n"
            f"Confidence: **{result['primary_confidence']:.1%}**\n\n"
            f"**Active labels:** {', '.join(result['error_labels']) or 'none'}"
        )
        probs = "\n".join(
            f"- **{k}**: {v:.1%}" for k, v in result["probabilities"].items()
        )
    return summary, probs


# Build the UI: input fields on the left, Markdown outputs on the right.
with gr.Blocks(title="SQL Error Classifier", theme=gr.themes.Soft()) as demo:
    # The header is built without source indentation. Lines indented by 4+
    # spaces render as a Markdown code block, and `model_status` can contain
    # unindented lines (the load-failure hint). Those lines would defeat any
    # common-indent stripping of an indented triple-quoted string.
    gr.Markdown(
        "# SQL Error Classifier\n"
        "CodeBERT cross-encoder — classifies student SQL mistakes into error categories.\n"
        "\n"
        f"{model_status}"
    )
    with gr.Row():
        with gr.Column():
            question = gr.Textbox(label="Question", lines=2)
            schema = gr.Textbox(label="Schema", lines=2)
            student_sql = gr.Textbox(label="Student SQL", lines=4)
            correct_sql = gr.Textbox(label="Correct SQL", lines=4)
            threshold = gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Threshold")
            btn = gr.Button("Classify", variant="primary")
        with gr.Column():
            prediction = gr.Markdown(label="Prediction")
            probabilities = gr.Markdown(label="Probabilities")

    # Same component order as classify()'s parameters and each EXAMPLES row.
    classifier_inputs = [question, schema, student_sql, correct_sql, threshold]
    gr.Examples(examples=EXAMPLES, inputs=classifier_inputs)
    btn.click(
        classify,
        classifier_inputs,
        [prediction, probabilities],
    )

if __name__ == "__main__":
    # Listen on all interfaces at port 7860 when run as a script
    # (presumably the port the Space is configured to expose).
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)