"""Modal-hosted FastAPI service that searches for Lean 4 tactic proofs.

An LLM proposes candidate tactics, a Lean REPL (via lean_interact) checks them,
and a small set of fallback tactics is tried first at each step.
"""

import hashlib

import modal
from fastapi import FastAPI
from pydantic import BaseModel

app = modal.App("lean-proof-agent")

# Container image: Lean 4 toolchain (installed with elan) plus lean-interact.
image = (
    modal.Image.debian_slim()
    .apt_install("curl", "git", "build-essential")
    .pip_install("lean-interact", "requests", "fastapi")
    .run_commands(
        "curl https://elan.lean-lang.org/elan-init.sh -sSf | sh -s -- -y --default-toolchain leanprover/lean4:v4.14.0",
    )
    .env({"PATH": "/root/.elan/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"})
    # Instantiate the default REPL config once at image-build time, presumably
    # so its setup work is baked into the image instead of repeated per container.
    .run_commands(
        'python -c "from lean_interact import LeanREPLConfig; LeanREPLConfig()"'
    )
)

LLAMA_ENDPOINT = "https://no-name13--llama-server-serve.modal.run/v1/chat/completions"

web_app = FastAPI()


class ProveRequest(BaseModel):
    """Body of a ``POST /prove`` request."""

    theorem: str  # Lean 4 theorem header; " := by sorry" is appended before elaboration
    max_steps: int = 20  # upper bound on proof-search loop iterations
    use_fallbacks: bool = True  # try FALLBACK_TACTICS before asking the LLM at each step
    show_reasoning: bool = True  # when True, skip cache read so the full agent loop always runs


class StepLog(BaseModel):
    """Trace entry for one iteration of the proof-search loop."""

    step: int  # 0-based loop iteration index
    goal: str  # goal text at the start of the step (all goals joined by newlines)
    candidates: list[str]  # tactics considered during this step
    chosen: str  # tactic that was applied, or "" if none was
    status: str  # Lean proof status, or "model_unavailable" / "all_failed"
    error: str | None = None  # Lean or LLM error associated with a failed step


class ProveResponse(BaseModel):
    """Result returned by ``POST /prove``."""

    success: bool
    stuck: bool = False  # True only when stuck-state detection aborted the search
    tactics: list[str]  # tactics applied, in order
    steps: list[StepLog]  # per-step trace (empty for trivial proofs and cache hits)
    message: str


@app.function(
    image=image,
    timeout=300,
    min_containers=1,
)
@modal.asgi_app()
def fastapi_app():
    """Build the FastAPI app serving the Lean proof-search agent.

    One Lean REPL server and one in-memory proof cache are created here and
    captured by the ``/prove`` handler, so every request handled by this app
    instance shares them.

    Returns:
        The module-level ``web_app`` with the ``/prove`` route registered.
    """
    from lean_interact import LeanServer, LeanREPLConfig, Command, ProofStep
    from lean_interact.interface import LeanError
    import requests

    config = LeanREPLConfig()
    server = LeanServer(config)

    # Cheap tactics tried, in this order, before consulting the LLM at every step.
    FALLBACK_TACTICS = ["rfl", "norm_num", "simp", "omega", "contradiction", "assumption"]
    NUM_CANDIDATES = 3
    # A goal seen this many times, or this many all-failed steps in a row, aborts the search.
    STUCK_THRESHOLD = 3
    # md5(initial goal text) -> tactic list that proved it. In-process only; not persisted.
    lemma_cache: dict[str, list[str]] = {}

    def ask_model(goal_state, last_error=None, num_candidates=3):
        """Ask the LLM for up to ``num_candidates`` distinct next tactics.

        Args:
            goal_state: Lean goal text to continue the proof from.
            last_error: Error from the previously tried tactic, fed back to the
                model as extra context (omitted when falsy).
            num_candidates: Number of completions to request (one HTTP call each).

        Returns:
            De-duplicated, backtick-stripped tactic strings in the order
            received. Sampling stops at the first failed request, so this is
            empty when the endpoint is unreachable (e.g. cold).
        """
        error_context = ""
        if last_error:
            error_context = (
                f"\nThe previous tactic failed with this error:\n{last_error}\n"
                f"IMPORTANT: if the error says 'major premise type is not an inductive type', "
                f"it means you must use `intro` to bring variables into context BEFORE using `induction`.\n"
            )
        prompt = (
            f"You are a Lean 4 theorem prover. Given this proof state:\n\n{goal_state}\n"
            f"{error_context}\n"
            f"Suggest the next single tactic. Output ONLY the tactic, no backticks, no explanation.\n"
            f"RULES:\n"
            f"- do NOT use `omega`, `decide`, `tauto`\n"
            f"- do NOT use `apply Nat.add_comm` (using a named library lemma as a shortcut)\n"
            f"- ALLOWED closing tactics — use these freely when they fit:\n"
            f" `exact h` or `exact ⟨h1, h2⟩` (provide a proof term directly)\n"
            f" `contradiction` (when context contains P and ¬P)\n"
            f" `assumption` (when goal matches a hypothesis exactly)\n"
            f" `absurd h1 h2` (derive False from h1 : P and h2 : ¬P)\n"
            f"- always use fresh, distinct variable names when introducing (e.g. `intro n`, `intro P`, `intro Q`) — never reuse a name already present in the context\n"
            f"- if the goal starts with `∀`, always use `intro` first before anything else\n"
            f"- when using induction, always provide full case syntax:\n"
            f" induction n with\n | zero => simp\n | succ n ih => simp [ih]"
        )
        # The request body is identical for every sample; a non-zero temperature
        # is what lets repeated calls return different tactics.
        payload = {
            "model": "any",
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 200,
            "stream": False,
            "temperature": 0.8,
            "chat_template_kwargs": {"enable_thinking": False},
        }
        tactics: list[str] = []
        for _ in range(num_candidates):
            try:
                resp = requests.post(LLAMA_ENDPOINT, json=payload, timeout=30)
                resp.raise_for_status()
                tactic = resp.json()["choices"][0]["message"]["content"].strip().strip("`").strip()
            except Exception as exc:
                # Best effort: network errors, HTTP errors and malformed replies
                # all end sampling early; the caller treats [] as "model unavailable".
                print(f"ask_model: LLM request failed: {exc!r}")
                break
            if tactic and tactic not in tactics:
                tactics.append(tactic)
        return tactics

    def try_tactic(tactic, proof_state_id):
        """Run one tactic against a REPL proof state.

        Returns:
            ``(response, None)`` on success, or ``(None, error_message)`` when
            the REPL answers with a ``LeanError``.
        """
        result = server.run(ProofStep(tactic=tactic, proof_state=proof_state_id))
        if isinstance(result, LeanError):
            return None, result.message
        return result, None

    def try_fallbacks(proof_state_id, enabled):
        """Return the first fallback tactic that applies without leaving a ``sorry``.

        Returns:
            ``(response, tactic)`` for the first acceptable fallback, or
            ``(None, None)`` when disabled or none applies.
        """
        if not enabled:
            return None, None
        for tactic in FALLBACK_TACTICS:
            result, _ = try_tactic(tactic, proof_state_id)
            if result is not None and "sorry" not in result.proof_status:
                return result, tactic
        return None, None

    # NOTE(review): this is a sync handler sharing one LeanServer and one
    # lemma_cache; if requests can be served concurrently, confirm the server
    # tolerates interleaved use.
    @web_app.post("/prove", response_model=ProveResponse)
    def prove(req: ProveRequest):
        """Search for a tactic proof of ``req.theorem``.

        Each step first tries the fallback tactics (if enabled). Otherwise it
        asks the LLM for candidates, applies each to the current proof state,
        and keeps the one leaving the fewest goals. The search ends when a
        tactic completes the proof, when stuck-state detection fires, or
        after ``req.max_steps`` iterations.

        Successful proofs are cached under the md5 of the initial goal text.
        The cache is only read when ``req.show_reasoning`` is False.
        """
        steps: list[StepLog] = []

        response = server.run(Command(cmd=f"{req.theorem} := by sorry"))
        if isinstance(response, LeanError):
            # The REPL rejected the command outright; there is no proof state to search.
            return ProveResponse(
                success=False,
                tactics=[],
                steps=[],
                message=f"Could not get initial proof state: {response.message}",
            )
        if not response.sorries:
            if any(m.data == "Goals accomplished!" for m in response.messages):
                return ProveResponse(success=True, tactics=[], steps=[], message="Proved trivially!")
            return ProveResponse(success=False, tactics=[], steps=[], message="Could not get initial proof state")

        proof_state_id = response.sorries[0].proof_state
        current_goals = [response.sorries[0].goal]
        goal_hash = hashlib.md5(current_goals[0].encode()).hexdigest()

        if not req.show_reasoning and goal_hash in lemma_cache:
            cached = lemma_cache[goal_hash]
            return ProveResponse(
                success=True,
                tactics=cached,
                steps=[],
                message=f"Proved from cache ({len(cached)} tactic(s))!",
            )

        tactics = []
        last_error = None
        visited = set()  # (proof_state_id, tactic) pairs already tried
        llm_ever_responded = False
        consecutive_failures = 0  # all_failed steps in a row
        goal_seen: dict[str, int] = {}  # goal text → times seen

        for step in range(req.max_steps):
            goal_text = "\n".join(current_goals)
            if not goal_text.strip():
                break

            # Stuck-state detection: same goal returning, or consecutive dead ends
            goal_seen[goal_text] = goal_seen.get(goal_text, 0) + 1
            if goal_seen[goal_text] >= STUCK_THRESHOLD or consecutive_failures >= STUCK_THRESHOLD:
                return ProveResponse(
                    success=False,
                    stuck=True,
                    tactics=tactics,
                    steps=steps,
                    message=(
                        "Search stuck — the same goal state recurred with no progress. "
                        "This theorem is likely not provable in the current theory: "
                        "it may require classical logic (Law of Excluded Middle), "
                        "or a tactic the agent is constrained from using."
                    ),
                )

            result, fallback_tactic = try_fallbacks(proof_state_id, req.use_fallbacks)
            if result is not None:
                tactics.append(fallback_tactic)
                steps.append(StepLog(
                    step=step,
                    goal=goal_text,
                    candidates=[fallback_tactic],
                    chosen=fallback_tactic,
                    status=result.proof_status,
                ))
                if result.proof_status == "Completed":
                    lemma_cache[goal_hash] = list(tactics)
                    return ProveResponse(success=True, tactics=tactics, steps=steps, message=f"Proved in {step+1} steps!")
                proof_state_id = result.proof_state
                current_goals = result.goals
                last_error = None
                consecutive_failures = 0
                continue

            candidates = ask_model(goal_text, last_error=last_error, num_candidates=NUM_CANDIDATES)
            if not candidates:
                # Don't count toward stuck-state; just burn a step and retry.
                # That includes the recurrence count taken above: an unreachable
                # LLM says nothing about the goal, and without this undo a cold
                # endpoint on an unchanged goal tripped STUCK_THRESHOLD.
                goal_seen[goal_text] -= 1
                steps.append(StepLog(
                    step=step,
                    goal=goal_text,
                    candidates=[],
                    chosen="",
                    status="model_unavailable",
                    error="LLM endpoint cold/unavailable — retrying",
                ))
                continue

            llm_ever_responded = True
            best_result = None
            best_tactic = None
            step_error = None

            for tactic in candidates:
                key = (proof_state_id, tactic)
                if key in visited:
                    continue
                visited.add(key)

                result, error = try_tactic(tactic, proof_state_id)
                if result is None:
                    last_error = error
                    step_error = error
                    continue
                if "sorry" in result.proof_status:
                    # Reject proofs that only "succeed" by deferring cases to sorry.
                    last_error = "That tactic left sorry holes. Provide the full proof of each case inline."
                    continue
                if result.proof_status == "Completed":
                    tactics.append(tactic)
                    steps.append(StepLog(
                        step=step,
                        goal=goal_text,
                        candidates=candidates,
                        chosen=tactic,
                        status="Completed",
                    ))
                    lemma_cache[goal_hash] = list(tactics)
                    return ProveResponse(success=True, tactics=tactics, steps=steps, message=f"Proved in {step+1} steps!")
                # Greedy choice: prefer the candidate that leaves the fewest open goals.
                if best_result is None or len(result.goals) < len(best_result.goals):
                    best_result = result
                    best_tactic = tactic

            if best_result is not None:
                tactics.append(best_tactic)
                steps.append(StepLog(
                    step=step,
                    goal=goal_text,
                    candidates=candidates,
                    chosen=best_tactic,
                    status=best_result.proof_status,
                ))
                proof_state_id = best_result.proof_state
                current_goals = best_result.goals
                last_error = None
                consecutive_failures = 0
            else:
                consecutive_failures += 1
                steps.append(StepLog(
                    step=step,
                    goal=goal_text,
                    candidates=candidates,
                    chosen="",
                    status="all_failed",
                    error=step_error,
                ))

        fail_msg = (
            "LLM endpoint warming up — fallback tactics only. Failed within max steps."
            if not llm_ever_responded
            else "Failed within max steps"
        )
        return ProveResponse(success=False, tactics=tactics, steps=steps, message=fail_msg)

    return web_app