# Page-header residue from the hosted file view, kept as comments so the module parses:
#   author: nishu08 | commit 9b2cded (verified): "Deploy CodeBERT training Space" | size: 7.19 kB
"""Generate playground exercises: schema + question + correct SQL."""
from __future__ import annotations
import random
from dataclasses import dataclass
from typing import Callable, Dict, List
@dataclass(frozen=True)
class Exercise:
    """One immutable playground exercise: schema, question, and reference SQL.

    Instances are frozen, so fields cannot be reassigned after construction
    and instances are hashable and comparable by value.
    """

    # Schema rendered as a single line of text.
    schema: str
    # Natural-language question posed to the learner.
    question: str
    # Reference SQL statement that answers the question.
    correct_query: str
    # Names of the tables that make up the schema.
    tables: tuple[str, ...]
    # Column names the builder associates with the exercise.
    columns: tuple[str, ...]
def _fmt_schema(tables: Dict[str, List[str]]) -> str:
    """Render a table-to-columns mapping as one line of schema text.

    Each table becomes ``name(col1, col2, ...)``; tables are emitted in the
    mapping's iteration order and separated by ``" | "``. An empty mapping
    yields an empty string.
    """
    rendered = []
    for table_name, column_names in tables.items():
        column_list = ", ".join(column_names)
        rendered.append(f"{table_name}({column_list})")
    return " | ".join(rendered)
# Signature shared by every exercise builder: takes an RNG, returns an Exercise.
ExerciseBuilder = Callable[[random.Random], Exercise]
# Registry of builders: ``_register`` appends to it, ``generate_exercise`` draws from it.
EXERCISE_BUILDERS: List[ExerciseBuilder] = []
def _register(builder: ExerciseBuilder) -> ExerciseBuilder:
    """Decorator that adds ``builder`` to ``EXERCISE_BUILDERS``.

    The builder is returned unchanged, so the decorated name still refers to
    the original function.
    """
    EXERCISE_BUILDERS.append(builder)
    return builder
@_register
def exercise_avg_by_department(rng: random.Random) -> Exercise:
    """Build the GROUP BY exercise: average student score per department.

    ``rng`` is part of the builder signature but is not used here.
    """
    schema_tables = {
        "students": ["id", "name", "email", "score", "department_id"],
        "departments": ["id", "name", "city"],
    }
    sql = " ".join(
        (
            "SELECT department_id, AVG(score)",
            "FROM students",
            "GROUP BY department_id",
        )
    )
    return Exercise(
        schema=_fmt_schema(schema_tables),
        question="What is the average score of students in each department?",
        correct_query=sql,
        tables=tuple(schema_tables.keys()),
        columns=("department_id", "score"),
    )
@_register
def exercise_student_department_names(rng: random.Random) -> Exercise:
    """Build the INNER JOIN exercise pairing student names with department names.

    ``rng`` is part of the builder signature but is not used here.
    """
    schema_tables = {
        "students": ["id", "name", "department_id"],
        "departments": ["id", "name"],
    }
    sql = " ".join(
        (
            "SELECT students.name, departments.name",
            "FROM students",
            "INNER JOIN departments ON students.department_id = departments.id",
        )
    )
    return Exercise(
        schema=_fmt_schema(schema_tables),
        question="List each student's name along with their department name.",
        correct_query=sql,
        tables=tuple(schema_tables.keys()),
        columns=("name", "department_id"),
    )
@_register
def exercise_high_scoring_students(rng: random.Random) -> Exercise:
    """Build the WHERE exercise: adult students above a random score cutoff.

    The cutoff is drawn from ``rng`` as an integer in [70, 90] and appears in
    both the question text and the reference query.
    """
    cutoff = rng.randint(70, 90)
    schema_tables = {"students": ["id", "name", "age", "score", "status"]}
    prompt = f"Find names of students older than 18 with a score above {cutoff}."
    sql = f"SELECT name FROM students WHERE age > 18 AND score > {cutoff}"
    return Exercise(
        schema=_fmt_schema(schema_tables),
        question=prompt,
        correct_query=sql,
        tables=tuple(schema_tables.keys()),
        columns=("name", "age", "score", "status"),
    )
@_register
def exercise_unique_cities(rng: random.Random) -> Exercise:
    """Build the SELECT DISTINCT exercise over student cities.

    ``rng`` is part of the builder signature but is not used here.
    """
    schema_tables = {"students": ["id", "name", "city", "country"]}
    sql = "SELECT DISTINCT city FROM students"
    return Exercise(
        schema=_fmt_schema(schema_tables),
        question="List the unique cities where students live.",
        correct_query=sql,
        tables=tuple(schema_tables.keys()),
        columns=("city",),
    )
@_register
def exercise_top_scorer(rng: random.Random) -> Exercise:
    """Build the scalar-subquery exercise: students matching the top score.

    The reference query compares ``students.score`` against ``MAX(score)``
    taken from the ``grades`` table. ``rng`` is part of the builder signature
    but is not used here.
    """
    schema_tables = {
        "students": ["id", "name", "score"],
        "grades": ["id", "score"],
    }
    sql = " ".join(
        (
            "SELECT name FROM students",
            "WHERE score = (SELECT MAX(score) FROM grades)",
        )
    )
    return Exercise(
        schema=_fmt_schema(schema_tables),
        question="Find students whose score equals the highest score in the class.",
        correct_query=sql,
        tables=tuple(schema_tables.keys()),
        columns=("name", "score"),
    )
@_register
def exercise_departments_over_budget(rng: random.Random) -> Exercise:
    """Build the GROUP BY / HAVING exercise: departments above a headcount.

    The headcount limit is drawn from ``rng`` as an integer in [3, 8] and
    appears in both the question text and the reference query.
    """
    headcount = rng.randint(3, 8)
    schema_tables = {"employees": ["id", "name", "department_id", "salary"]}
    prompt = f"Which departments have more than {headcount} employees?"
    sql = " ".join(
        (
            "SELECT department_id, COUNT(*) AS cnt",
            "FROM employees GROUP BY department_id",
            f"HAVING COUNT(*) > {headcount}",
        )
    )
    return Exercise(
        schema=_fmt_schema(schema_tables),
        question=prompt,
        correct_query=sql,
        tables=tuple(schema_tables.keys()),
        columns=("department_id", "salary"),
    )
@_register
def exercise_recent_orders(rng: random.Random) -> Exercise:
    """Build the date-filter exercise: orders on or after January 1 of a year.

    The year is drawn from ``rng`` as an integer in [2020, 2024] and appears
    in both the question text and the reference query.
    """
    start_year = rng.randint(2020, 2024)
    schema_tables = {
        "orders": ["id", "customer_id", "amount", "order_date", "status"],
    }
    prompt = f"Show orders placed on or after January 1, {start_year}."
    sql = " ".join(
        (
            "SELECT id, amount FROM orders",
            f"WHERE order_date >= DATE '{start_year}-01-01'",
        )
    )
    return Exercise(
        schema=_fmt_schema(schema_tables),
        question=prompt,
        correct_query=sql,
        tables=tuple(schema_tables.keys()),
        columns=("order_date", "amount", "status"),
    )
@_register
def exercise_missing_email(rng: random.Random) -> Exercise:
    """Build the IS NULL exercise: students without an email address.

    ``rng`` is part of the builder signature but is not used here.
    """
    schema_tables = {"students": ["id", "name", "email", "phone"]}
    sql = "SELECT name FROM students WHERE email IS NULL"
    return Exercise(
        schema=_fmt_schema(schema_tables),
        question="Find students who have not provided an email address.",
        correct_query=sql,
        tables=tuple(schema_tables.keys()),
        columns=("email", "name"),
    )
@_register
def exercise_rank_by_score(rng: random.Random) -> Exercise:
    """Build the window-function exercise: rank students within a department.

    ``rng`` is part of the builder signature but is not used here.
    """
    schema_tables = {"students": ["id", "name", "score", "department_id"]}
    sql = " ".join(
        (
            "SELECT name, score,",
            "RANK() OVER (PARTITION BY department_id ORDER BY score DESC) AS rnk",
            "FROM students",
        )
    )
    return Exercise(
        schema=_fmt_schema(schema_tables),
        question="Rank students by score within each department.",
        correct_query=sql,
        tables=tuple(schema_tables.keys()),
        columns=("name", "score", "department_id"),
    )
@_register
def exercise_course_enrollment_count(rng: random.Random) -> Exercise:
    """Build the JOIN + GROUP BY exercise: enrollment count per course title.

    ``rng`` is part of the builder signature but is not used here.
    """
    schema_tables = {
        "courses": ["id", "title"],
        "enrollments": ["id", "course_id", "student_id"],
    }
    sql = " ".join(
        (
            "SELECT courses.title, COUNT(enrollments.student_id) AS enrolled",
            "FROM courses",
            "INNER JOIN enrollments ON courses.id = enrollments.course_id",
            "GROUP BY courses.title",
        )
    )
    return Exercise(
        schema=_fmt_schema(schema_tables),
        question="How many students are enrolled in each course?",
        correct_query=sql,
        tables=tuple(schema_tables.keys()),
        columns=("title", "student_id", "course_id"),
    )
@_register
def exercise_active_employees(rng: random.Random) -> Exercise:
    """Build the SUM + WHERE exercise: total salary of active employees.

    ``rng`` is part of the builder signature but is not used here.
    """
    schema_tables = {
        "employees": ["id", "name", "salary", "status", "hire_date"],
    }
    sql = "SELECT SUM(salary) FROM employees WHERE status = 'active'"
    return Exercise(
        schema=_fmt_schema(schema_tables),
        question="What is the total salary paid to active employees?",
        correct_query=sql,
        tables=tuple(schema_tables.keys()),
        columns=("salary", "status"),
    )
@_register
def exercise_product_price_filter(rng: random.Random) -> Exercise:
    """Build the BETWEEN exercise: products within a random price range.

    The lower bound is drawn from [10, 50] first, then the upper bound from
    [100, 500]; both appear in the question text and the reference query.
    """
    # Draw order matters for reproducibility with a seeded RNG: low, then high.
    low = rng.randint(10, 50)
    high = rng.randint(100, 500)
    schema_tables = {"products": ["id", "name", "price", "category"]}
    prompt = f"List products priced between {low} and {high}."
    sql = f"SELECT name, price FROM products WHERE price BETWEEN {low} AND {high}"
    return Exercise(
        schema=_fmt_schema(schema_tables),
        question=prompt,
        correct_query=sql,
        tables=tuple(schema_tables.keys()),
        columns=("name", "price", "category"),
    )
def generate_exercise(rng: random.Random) -> Exercise:
    """Return an exercise from a builder chosen at random via ``rng``.

    The builder is picked from ``EXERCISE_BUILDERS`` with ``rng.choice`` and
    then called with the same ``rng``, so it can draw its own random values.
    """
    chosen_builder = rng.choice(EXERCISE_BUILDERS)
    return chosen_builder(rng)