# QED / back.py
# SPP · commit ed428ff · "Q.E.D — initial submission" · 10.8 kB
import modal
import hashlib
from fastapi import FastAPI
from pydantic import BaseModel
app = modal.App("lean-proof-agent")

# Container image, built step by step. Each builder call returns a new image,
# so rebinding `image` after every step yields the same final image as a
# single chained expression.
image = modal.Image.debian_slim()
# System tools needed by the elan installer and by building the Lean REPL.
image = image.apt_install("curl", "git", "build-essential")
# Python dependencies used by the web endpoint.
image = image.pip_install("lean-interact", "requests", "fastapi")
# Install elan with the Lean 4 v4.14.0 toolchain as its default.
image = image.run_commands(
    "curl https://elan.lean-lang.org/elan-init.sh -sSf | sh -s -- -y --default-toolchain leanprover/lean4:v4.14.0",
)
# Put elan's bin directory (under /root/.elan) ahead of the system directories.
image = image.env(
    {"PATH": "/root/.elan/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"}
)
# Construct a LeanREPLConfig once at build time. Presumably this lets
# lean_interact fetch/prepare its REPL inside the image instead of on the
# first request (TODO confirm against lean-interact docs).
image = image.run_commands(
    'python -c "from lean_interact import LeanREPLConfig; LeanREPLConfig()"'
)

# Chat-completions endpoint of the model server queried for tactic suggestions.
LLAMA_ENDPOINT = "https://no-name13--llama-server-serve.modal.run/v1/chat/completions"

web_app = FastAPI()
class ProveRequest(BaseModel):
    """Request body for the POST /prove endpoint."""

    # Lean declaration *without* a proof. The server appends " := by sorry" to
    # it, so presumably a header such as "theorem foo (n : Nat) : n + 0 = n".
    theorem: str
    # Maximum number of search steps; each step applies at most one tactic.
    max_steps: int = 20
    # When True, the built-in fallback tactics are tried before the model is
    # queried at every step.
    use_fallbacks: bool = True
    show_reasoning: bool = True  # when True, skip cache read so the full agent loop always runs
class StepLog(BaseModel):
    """One entry in the per-step trace of a proof search."""

    # Zero-based index of the search step.
    step: int
    # Open goals at the start of this step, joined with newlines.
    goal: str
    # Tactics considered this step (model samples, or the single fallback that applied).
    candidates: list[str]
    # Tactic that was applied; "" when none was.
    chosen: str
    # Proof status reported for the applied tactic, or "all_failed" /
    # "model_unavailable" when nothing was applied.
    status: str
    # Error detail for unsuccessful steps (last Lean error or an endpoint notice).
    error: str | None = None
class ProveResponse(BaseModel):
    """Response body of the POST /prove endpoint."""

    # True when a proof was found (including trivial and cached results).
    success: bool
    # True when the search was aborted early because it stopped making progress.
    stuck: bool = False
    # Tactics applied, in order; on success this is the proof script.
    tactics: list[str]
    # Per-step trace; empty when no search was run (cache hit, trivial, setup failure).
    steps: list[StepLog]
    # Human-readable summary of the outcome.
    message: str
@app.function(
    image=image,
    timeout=300,
    min_containers=1,
)
@modal.asgi_app()
def fastapi_app():
    """Create the ASGI app that serves the LLM-guided Lean 4 proof search.

    Runs once per container. It starts one Lean REPL (``LeanServer``) and an
    in-memory tactic cache, which are shared by every request the container
    then serves. It registers ``POST /prove`` on ``web_app`` and returns it.

    NOTE(review): FastAPI runs plain ``def`` endpoints in a thread pool, so
    concurrent requests may call ``server.run`` at the same time. Confirm that
    ``LeanServer`` tolerates this, or serialise access with a lock.
    """
    # Imported here rather than at module level, presumably because these
    # packages are only installed inside the Modal image.
    from lean_interact import LeanServer, LeanREPLConfig, Command, ProofStep
    from lean_interact.interface import LeanError
    import requests

    config = LeanREPLConfig()
    server = LeanServer(config)

    # Cheap tactics tried, in this order, before the model is queried at each step.
    FALLBACK_TACTICS = ["rfl", "norm_num", "simp", "omega", "contradiction", "assumption"]
    # Model samples requested per step (duplicate suggestions are discarded).
    NUM_CANDIDATES = 3
    # md5(initial goal text) -> tactic list that proved it. Lives only as long
    # as this container and is never evicted.
    lemma_cache: dict[str, list[str]] = {}

    def _clean_tactic(text):
        """Reduce a raw model reply to a bare tactic string.

        Trims whitespace and surrounding backticks. If the reply is a fenced
        code block (first line "```" or "```lean"), that opening fence line is
        dropped first so its language tag does not leak into the tactic.
        """
        text = text.strip()
        if text.startswith("```") and "\n" in text:
            text = text.split("\n", 1)[1]
        return text.strip().strip("`").strip()

    def ask_model(goal_state, last_error=None, num_candidates=3):
        """Sample up to ``num_candidates`` distinct next-tactic suggestions.

        Args:
            goal_state: Pretty-printed Lean goal(s) shown to the model.
            last_error: Error from the previously tried tactic, fed back as a hint.
            num_candidates: Number of sampling requests to make.

        Returns:
            Unique, non-empty tactic strings in the order received. Sampling
            stops at the first failed request, so the list may be shorter than
            requested, or empty when the endpoint is cold or unreachable.
        """
        error_context = ""
        if last_error:
            error_context = (
                f"\nThe previous tactic failed with this error:\n{last_error}\n"
                f"IMPORTANT: if the error says 'major premise type is not an inductive type', "
                f"it means you must use `intro` to bring variables into context BEFORE using `induction`.\n"
            )
        prompt = (
            f"You are a Lean 4 theorem prover. Given this proof state:\n\n{goal_state}\n"
            f"{error_context}\n"
            f"Suggest the next single tactic. Output ONLY the tactic, no backticks, no explanation.\n"
            f"RULES:\n"
            f"- do NOT use `omega`, `decide`, `tauto`\n"
            f"- do NOT use `apply Nat.add_comm` (using a named library lemma as a shortcut)\n"
            f"- ALLOWED closing tactics — use these freely when they fit:\n"
            f" `exact h` or `exact ⟨h1, h2⟩` (provide a proof term directly)\n"
            f" `contradiction` (when context contains P and ¬P)\n"
            f" `assumption` (when goal matches a hypothesis exactly)\n"
            f" `absurd h1 h2` (derive False from h1 : P and h2 : ¬P)\n"
            f"- always use fresh, distinct variable names when introducing (e.g. `intro n`, `intro P`, `intro Q`) — never reuse a name already present in the context\n"
            f"- if the goal starts with `∀`, always use `intro` first before anything else\n"
            f"- when using induction, always provide full case syntax:\n"
            f" induction n with\n | zero => simp\n | succ n ih => simp [ih]"
        )
        tactics = []
        for _ in range(num_candidates):
            try:
                payload = {
                    "model": "any",
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": 200,
                    "stream": False,
                    # Non-zero temperature so repeated samples can differ.
                    "temperature": 0.8,
                    "chat_template_kwargs": {"enable_thinking": False},
                }
                resp = requests.post(LLAMA_ENDPOINT, json=payload, timeout=30)
                resp.raise_for_status()
                tactic = _clean_tactic(resp.json()["choices"][0]["message"]["content"])
            except Exception:
                # Deliberate best effort: any transport, HTTP or response-shape
                # failure ends sampling. The caller treats an empty result as
                # "model unavailable".
                break
            if tactic and tactic not in tactics:
                tactics.append(tactic)
        return tactics

    def try_tactic(tactic, proof_state_id):
        """Run ``tactic`` against the REPL proof state ``proof_state_id``.

        Returns ``(result, None)`` on success, or ``(None, error_message)`` when
        the REPL answers with a ``LeanError``.
        """
        result = server.run(ProofStep(tactic=tactic, proof_state=proof_state_id))
        if isinstance(result, LeanError):
            return None, result.message
        return result, None

    def try_fallbacks(proof_state_id, enabled):
        """Return ``(result, tactic)`` for the first fallback that applies cleanly.

        A fallback qualifies only if it runs without error and its proof status
        does not mention ``sorry``. Returns ``(None, None)`` when fallbacks are
        disabled or none qualifies.
        """
        if not enabled:
            return None, None
        for tactic in FALLBACK_TACTICS:
            result, _ = try_tactic(tactic, proof_state_id)
            if result is not None and "sorry" not in result.proof_status:
                return result, tactic
        return None, None

    @web_app.post("/prove", response_model=ProveResponse)
    def prove(req: ProveRequest):
        """Search for a tactic proof of ``req.theorem``.

        At each step, the fallback tactics are tried first (if enabled). If
        none applies, the model is asked for candidates and the candidate that
        leaves the fewest open goals is applied. The search ends when the
        proof is completed, when it stops making progress (``stuck=True``), or
        after ``req.max_steps`` steps.
        """
        steps = []
        # Elaborate the statement with a placeholder proof to obtain the initial goal.
        response = server.run(Command(cmd=f"{req.theorem} := by sorry"))
        if isinstance(response, LeanError):
            # REPL-level failure: there is no `.sorries` to inspect.
            return ProveResponse(success=False, tactics=[], steps=[],
                                 message=f"Lean REPL error: {response.message}")
        if not response.sorries:
            if any(m.data == "Goals accomplished!" for m in response.messages):
                return ProveResponse(success=True, tactics=[], steps=[], message="Proved trivially!")
            return ProveResponse(success=False, tactics=[], steps=[], message="Could not get initial proof state")
        proof_state_id = response.sorries[0].proof_state
        current_goals = [response.sorries[0].goal]
        goal_hash = hashlib.md5(current_goals[0].encode()).hexdigest()
        # show_reasoning=True bypasses the cache so that a full step trace is produced.
        if not req.show_reasoning and goal_hash in lemma_cache:
            cached = lemma_cache[goal_hash]
            return ProveResponse(
                success=True, tactics=cached, steps=[],
                message=f"Proved from cache ({len(cached)} tactic(s))!"
            )
        tactics = []
        last_error = None
        visited = set()  # (proof_state_id, tactic) pairs already tried
        llm_ever_responded = False
        consecutive_failures = 0  # all_failed steps in a row
        goal_seen: dict[str, int] = {}  # goal text → times seen
        STUCK_THRESHOLD = 3
        for step in range(req.max_steps):
            goal_text = "\n".join(current_goals)
            if not goal_text.strip():
                break
            # Stuck-state detection: same goal returning, or consecutive dead ends
            goal_seen[goal_text] = goal_seen.get(goal_text, 0) + 1
            if goal_seen[goal_text] >= STUCK_THRESHOLD or consecutive_failures >= STUCK_THRESHOLD:
                return ProveResponse(
                    success=False, stuck=True, tactics=tactics, steps=steps,
                    message=(
                        "Search stuck — the same goal state recurred with no progress. "
                        "This theorem is likely not provable in the current theory: "
                        "it may require classical logic (Law of Excluded Middle), "
                        "or a tactic the agent is constrained from using."
                    ),
                )

            # 1) Cheap fallbacks first; any clean application counts as progress.
            result, fallback_tactic = try_fallbacks(proof_state_id, req.use_fallbacks)
            if result is not None:
                tactics.append(fallback_tactic)
                steps.append(StepLog(
                    step=step, goal=goal_text,
                    candidates=[fallback_tactic], chosen=fallback_tactic,
                    status=result.proof_status
                ))
                if result.proof_status == "Completed":
                    lemma_cache[goal_hash] = list(tactics)
                    return ProveResponse(success=True, tactics=tactics, steps=steps,
                                         message=f"Proved in {step+1} steps!")
                proof_state_id = result.proof_state
                current_goals = result.goals
                last_error = None
                consecutive_failures = 0
                continue

            # 2) Ask the model for candidate tactics.
            candidates = ask_model(goal_text, last_error=last_error, num_candidates=NUM_CANDIDATES)
            if not candidates:
                # A cold/unreachable endpoint says nothing about provability, so
                # undo this step's visit of the goal. Otherwise three retries on an
                # unchanged goal would trip STUCK_THRESHOLD and report "stuck".
                # NOTE(review): each retry can block for the 30 s request timeout,
                # so a long cold start may run into the 300 s function timeout.
                goal_seen[goal_text] -= 1
                steps.append(StepLog(
                    step=step, goal=goal_text,
                    candidates=[], chosen="",
                    status="model_unavailable", error="LLM endpoint cold/unavailable — retrying"
                ))
                continue  # don't count toward stuck-state; just burn a step and retry
            llm_ever_responded = True

            # 3) Try each candidate not yet tried at this proof state; return on
            #    completion, otherwise keep the one leaving the fewest goals.
            best_result = None
            best_tactic = None
            step_error = None
            for tactic in candidates:
                key = (proof_state_id, tactic)
                if key in visited:
                    continue
                visited.add(key)
                result, error = try_tactic(tactic, proof_state_id)
                if result is None:
                    last_error = error
                    step_error = error
                    continue
                if "sorry" in result.proof_status:
                    last_error = "That tactic left sorry holes. Provide the full proof of each case inline."
                    # Record it for the step log too, so an all_failed step explains itself.
                    step_error = last_error
                    continue
                if result.proof_status == "Completed":
                    tactics.append(tactic)
                    steps.append(StepLog(
                        step=step, goal=goal_text,
                        candidates=candidates, chosen=tactic,
                        status="Completed"
                    ))
                    lemma_cache[goal_hash] = list(tactics)
                    return ProveResponse(success=True, tactics=tactics, steps=steps,
                                         message=f"Proved in {step+1} steps!")
                if best_result is None or len(result.goals) < len(best_result.goals):
                    best_result = result
                    best_tactic = tactic

            if best_result is not None:
                tactics.append(best_tactic)
                steps.append(StepLog(
                    step=step, goal=goal_text,
                    candidates=candidates, chosen=best_tactic,
                    status=best_result.proof_status
                ))
                proof_state_id = best_result.proof_state
                current_goals = best_result.goals
                last_error = None
                consecutive_failures = 0
            else:
                consecutive_failures += 1
                steps.append(StepLog(
                    step=step, goal=goal_text,
                    candidates=candidates, chosen="",
                    status="all_failed", error=step_error
                ))
        fail_msg = (
            "LLM endpoint warming up — fallback tactics only. Failed within max steps."
            if not llm_ever_responded
            else "Failed within max steps"
        )
        return ProveResponse(success=False, tactics=tactics, steps=steps, message=fail_msg)

    return web_app