"""Generate playground exercises: schema + question + correct SQL.""" from __future__ import annotations import random from dataclasses import dataclass from typing import Callable, Dict, List @dataclass(frozen=True) class Exercise: schema: str question: str correct_query: str tables: tuple[str, ...] columns: tuple[str, ...] def _fmt_schema(tables: Dict[str, List[str]]) -> str: parts = [f"{name}({', '.join(cols)})" for name, cols in tables.items()] return " | ".join(parts) ExerciseBuilder = Callable[[random.Random], Exercise] EXERCISE_BUILDERS: List[ExerciseBuilder] = [] def _register(builder: ExerciseBuilder) -> ExerciseBuilder: EXERCISE_BUILDERS.append(builder) return builder @_register def exercise_avg_by_department(rng: random.Random) -> Exercise: tables = { "students": ["id", "name", "email", "score", "department_id"], "departments": ["id", "name", "city"], } return Exercise( schema=_fmt_schema(tables), question="What is the average score of students in each department?", correct_query=( "SELECT department_id, AVG(score) " "FROM students GROUP BY department_id" ), tables=tuple(tables), columns=("department_id", "score"), ) @_register def exercise_student_department_names(rng: random.Random) -> Exercise: tables = { "students": ["id", "name", "department_id"], "departments": ["id", "name"], } return Exercise( schema=_fmt_schema(tables), question="List each student's name along with their department name.", correct_query=( "SELECT students.name, departments.name " "FROM students " "INNER JOIN departments ON students.department_id = departments.id" ), tables=tuple(tables), columns=("name", "department_id"), ) @_register def exercise_high_scoring_students(rng: random.Random) -> Exercise: threshold = rng.randint(70, 90) tables = {"students": ["id", "name", "age", "score", "status"]} return Exercise( schema=_fmt_schema(tables), question=( f"Find names of students older than 18 with a score above {threshold}." ), correct_query=( f"SELECT name FROM students " f"WHERE age > 18 AND score > {threshold}" ), tables=tuple(tables), columns=("name", "age", "score", "status"), ) @_register def exercise_unique_cities(rng: random.Random) -> Exercise: tables = {"students": ["id", "name", "city", "country"]} return Exercise( schema=_fmt_schema(tables), question="List the unique cities where students live.", correct_query="SELECT DISTINCT city FROM students", tables=tuple(tables), columns=("city",), ) @_register def exercise_top_scorer(rng: random.Random) -> Exercise: tables = {"students": ["id", "name", "score"], "grades": ["id", "score"]} return Exercise( schema=_fmt_schema(tables), question="Find students whose score equals the highest score in the class.", correct_query=( "SELECT name FROM students " "WHERE score = (SELECT MAX(score) FROM grades)" ), tables=tuple(tables), columns=("name", "score"), ) @_register def exercise_departments_over_budget(rng: random.Random) -> Exercise: budget = rng.randint(3, 8) tables = {"employees": ["id", "name", "department_id", "salary"]} return Exercise( schema=_fmt_schema(tables), question=f"Which departments have more than {budget} employees?", correct_query=( f"SELECT department_id, COUNT(*) AS cnt " f"FROM employees GROUP BY department_id " f"HAVING COUNT(*) > {budget}" ), tables=tuple(tables), columns=("department_id", "salary"), ) @_register def exercise_recent_orders(rng: random.Random) -> Exercise: year = rng.randint(2020, 2024) tables = {"orders": ["id", "customer_id", "amount", "order_date", "status"]} return Exercise( schema=_fmt_schema(tables), question=f"Show orders placed on or after January 1, {year}.", correct_query=( f"SELECT id, amount FROM orders " f"WHERE order_date >= DATE '{year}-01-01'" ), tables=tuple(tables), columns=("order_date", "amount", "status"), ) @_register def exercise_missing_email(rng: random.Random) -> Exercise: tables = {"students": ["id", "name", "email", "phone"]} return Exercise( schema=_fmt_schema(tables), question="Find students who have not provided an email address.", correct_query="SELECT name FROM students WHERE email IS NULL", tables=tuple(tables), columns=("email", "name"), ) @_register def exercise_rank_by_score(rng: random.Random) -> Exercise: tables = {"students": ["id", "name", "score", "department_id"]} return Exercise( schema=_fmt_schema(tables), question="Rank students by score within each department.", correct_query=( "SELECT name, score, " "RANK() OVER (PARTITION BY department_id ORDER BY score DESC) AS rnk " "FROM students" ), tables=tuple(tables), columns=("name", "score", "department_id"), ) @_register def exercise_course_enrollment_count(rng: random.Random) -> Exercise: tables = { "courses": ["id", "title"], "enrollments": ["id", "course_id", "student_id"], } return Exercise( schema=_fmt_schema(tables), question="How many students are enrolled in each course?", correct_query=( "SELECT courses.title, COUNT(enrollments.student_id) AS enrolled " "FROM courses " "INNER JOIN enrollments ON courses.id = enrollments.course_id " "GROUP BY courses.title" ), tables=tuple(tables), columns=("title", "student_id", "course_id"), ) @_register def exercise_active_employees(rng: random.Random) -> Exercise: tables = {"employees": ["id", "name", "salary", "status", "hire_date"]} return Exercise( schema=_fmt_schema(tables), question="What is the total salary paid to active employees?", correct_query=( "SELECT SUM(salary) FROM employees WHERE status = 'active'" ), tables=tuple(tables), columns=("salary", "status"), ) @_register def exercise_product_price_filter(rng: random.Random) -> Exercise: lo, hi = rng.randint(10, 50), rng.randint(100, 500) tables = {"products": ["id", "name", "price", "category"]} return Exercise( schema=_fmt_schema(tables), question=f"List products priced between {lo} and {hi}.", correct_query=( f"SELECT name, price FROM products " f"WHERE price BETWEEN {lo} AND {hi}" ), tables=tuple(tables), columns=("name", "price", "category"), ) def generate_exercise(rng: random.Random) -> Exercise: return rng.choice(EXERCISE_BUILDERS)(rng)