Spaces:
Running
Running
| import modal | |
| import hashlib | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
# Modal application object for this service.
app = modal.App("lean-proof-agent")

# Shell snippets used while building the container image, kept as named
# constants so the build chain below reads as a list of steps.
_ELAN_INSTALL_CMD = (
    "curl https://elan.lean-lang.org/elan-init.sh -sSf | sh -s -- -y --default-toolchain leanprover/lean4:v4.14.0"
)
_IMAGE_PATH = "/root/.elan/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
# Instantiates LeanREPLConfig once at image-build time.
# NOTE(review): presumably this pre-builds/caches the Lean REPL so containers
# start faster -- confirm against the lean-interact documentation.
_WARM_LEAN_REPL_CMD = 'python -c "from lean_interact import LeanREPLConfig; LeanREPLConfig()"'

# Container image: Debian slim + build tools, Python dependencies, the Lean
# toolchain installed through elan (pinned to v4.14.0), elan's bin directory
# on PATH, and finally the REPL warm-up command.
image = (
    modal.Image.debian_slim()
    .apt_install("curl", "git", "build-essential")
    .pip_install("lean-interact", "requests", "fastapi")
    .run_commands(_ELAN_INSTALL_CMD)
    .env({"PATH": _IMAGE_PATH})
    .run_commands(_WARM_LEAN_REPL_CMD)
)

# Chat-completions URL of the LLM server that proposes tactics.
LLAMA_ENDPOINT = "https://no-name13--llama-server-serve.modal.run/v1/chat/completions"

# FastAPI application returned by fastapi_app().
web_app = FastAPI()
# Request payload for the proof-search handler.
# (Documented with comments rather than a docstring so the generated schema
# description is unchanged.)
class ProveRequest(BaseModel):
    # Lean theorem statement; the handler appends " := by sorry" to obtain
    # the initial proof state.
    theorem: str

    # Upper bound on the number of search-loop iterations.
    max_steps: int = 20

    # When True, a fixed list of closing tactics is tried before asking the LLM.
    use_fallbacks: bool = True

    # when True, skip cache read so the full agent loop always runs
    show_reasoning: bool = True
# One entry of the search trace returned to the caller.
# (Documented with comments rather than a docstring so the generated schema
# description is unchanged.)
class StepLog(BaseModel):
    # Zero-based index of the search-loop iteration.
    step: int

    # Goal text shown to the model at this step (open goals joined by newlines).
    goal: str

    # Tactics considered at this step (LLM suggestions, or the single
    # fallback tactic that was applied).
    candidates: list[str]

    # Tactic actually applied; empty string when nothing was applied.
    chosen: str

    # Lean proof status after the step, or "model_unavailable" / "all_failed".
    status: str

    # Error text recorded for the step, if any.
    error: str | None = None
# Result of one proof attempt.
# (Documented with comments rather than a docstring so the generated schema
# description is unchanged.)
class ProveResponse(BaseModel):
    # True when Lean reported the proof as completed (or it came from cache).
    success: bool

    # True when the search gave up because it stopped making progress.
    stuck: bool = False

    # Tactics applied so far, in order.
    tactics: list[str]

    # Per-iteration trace of the search.
    steps: list[StepLog]

    # Human-readable summary of the outcome.
    message: str
def fastapi_app():
    """Start the Lean REPL, define the proof-search helpers, return ``web_app``.

    One ``LeanServer`` and one in-memory ``lemma_cache`` are created per call
    and shared by every invocation of the nested ``prove`` handler.

    NOTE(review): no decorators are visible on this function or on ``prove``
    in this view. Presumably the Modal/ASGI decorators and the route
    registration on ``web_app`` exist in the real file and were lost in
    extraction -- confirm; as written here ``prove`` is never attached to
    ``web_app``.

    Returns:
        The module-level ``web_app`` object.
    """
    from lean_interact import LeanServer, LeanREPLConfig, Command, ProofStep
    from lean_interact.interface import LeanError
    import requests

    config = LeanREPLConfig()
    server = LeanServer(config)

    FALLBACK_TACTICS = ["rfl", "norm_num", "simp", "omega", "contradiction", "assumption"]
    NUM_CANDIDATES = 3
    # Number of visits to one goal / consecutive dead ends that ends the search.
    STUCK_THRESHOLD = 3
    # md5(initial goal text) -> tactic list that completed the proof.
    lemma_cache: dict[str, list[str]] = {}

    # Failures that mean "the LLM call did not yield a usable tactic":
    # network/HTTP errors, undecodable JSON, or an unexpected response shape.
    LLM_ERRORS = (
        requests.RequestException,
        ValueError,
        KeyError,
        IndexError,
        TypeError,
        AttributeError,
    )

    # Static tail of every prompt; built once instead of on every call.
    PROMPT_RULES = (
        "Suggest the next single tactic. Output ONLY the tactic, no backticks, no explanation.\n"
        "RULES:\n"
        "- do NOT use `omega`, `decide`, `tauto`\n"
        "- do NOT use `apply Nat.add_comm` (using a named library lemma as a shortcut)\n"
        "- ALLOWED closing tactics — use these freely when they fit:\n"
        " `exact h` or `exact ⟨h1, h2⟩` (provide a proof term directly)\n"
        " `contradiction` (when context contains P and ¬P)\n"
        " `assumption` (when goal matches a hypothesis exactly)\n"
        " `absurd h1 h2` (derive False from h1 : P and h2 : ¬P)\n"
        "- always use fresh, distinct variable names when introducing (e.g. `intro n`, `intro P`, `intro Q`) — never reuse a name already present in the context\n"
        "- if the goal starts with `∀`, always use `intro` first before anything else\n"
        "- when using induction, always provide full case syntax:\n"
        " induction n with\n | zero => simp\n | succ n ih => simp [ih]"
    )

    def clean_tactic(raw):
        """Strip whitespace, backticks and a leading ``lean``/``lean4`` fence tag."""
        text = raw.strip()
        fenced = text.startswith("```")
        text = text.strip("`").strip()
        if fenced:
            # A fenced reply looks like "```lean\n<tactic>\n```"; after the
            # backticks are stripped the language tag would otherwise be sent
            # to Lean as part of the tactic.
            first_line, _, remainder = text.partition("\n")
            if first_line.strip().lower() in {"lean", "lean4"}:
                text = remainder.strip()
        return text

    def ask_model(goal_state, last_error=None, num_candidates=3):
        """Sample up to ``num_candidates`` distinct tactic suggestions from the LLM.

        Args:
            goal_state: Text of the current proof state.
            last_error: Optional error text from the previous attempt; when
                given it is included in the prompt.
            num_candidates: Number of requests to send (one sample each).

        Returns:
            De-duplicated, non-empty tactic strings in arrival order. The
            list is empty when the first request already fails.
        """
        error_context = ""
        if last_error:
            error_context = (
                f"\nThe previous tactic failed with this error:\n{last_error}\n"
                "IMPORTANT: if the error says 'major premise type is not an inductive type', "
                "it means you must use `intro` to bring variables into context BEFORE using `induction`.\n"
            )
        prompt = (
            f"You are a Lean 4 theorem prover. Given this proof state:\n\n{goal_state}\n"
            f"{error_context}\n"
        ) + PROMPT_RULES

        # The payload is identical for every sample; diversity comes from the
        # non-zero temperature.
        payload = {
            "model": "any",
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 200,
            "stream": False,
            "temperature": 0.8,
            "chat_template_kwargs": {"enable_thinking": False},
        }

        tactics = []
        for _ in range(num_candidates):
            try:
                resp = requests.post(LLAMA_ENDPOINT, json=payload, timeout=30)
                resp.raise_for_status()
                tactic = clean_tactic(resp.json()["choices"][0]["message"]["content"])
            except LLM_ERRORS:
                # Stop sampling after the first failure instead of waiting
                # out the timeout again for each remaining candidate.
                break
            if tactic and tactic not in tactics:
                tactics.append(tactic)
        return tactics

    def try_tactic(tactic, proof_state_id):
        """Run one tactic against a proof state.

        Returns:
            ``(result, None)`` on success, or ``(None, message)`` when the
            server answers with a ``LeanError``.
        """
        result = server.run(ProofStep(tactic=tactic, proof_state=proof_state_id))
        if isinstance(result, LeanError):
            return None, result.message
        return result, None

    def try_fallbacks(proof_state_id, enabled):
        """Try each fallback tactic in order and return the first usable one.

        A result is usable when the tactic ran without a ``LeanError`` and
        its proof status does not mention ``sorry``.

        Returns:
            ``(result, tactic)`` for the first usable fallback, otherwise
            ``(None, None)`` (also when ``enabled`` is false).
        """
        if not enabled:
            return None, None
        for tactic in FALLBACK_TACTICS:
            result, _ = try_tactic(tactic, proof_state_id)
            if result is not None and "sorry" not in result.proof_status:
                return result, tactic
        return None, None

    def prove(req: ProveRequest) -> ProveResponse:
        """Search for a tactic proof of ``req.theorem``.

        Each iteration first tries the fallback tactics (if enabled), then
        asks the LLM for candidates and applies the one that leaves the
        fewest goals. The search ends on completion, when it is judged
        stuck, or after ``req.max_steps`` iterations.

        Returns:
            A ``ProveResponse`` describing the outcome and the step trace.
        """
        steps = []
        # NOTE(review): req.theorem is interpolated verbatim into a Lean
        # command. If this endpoint accepts untrusted input, confirm the
        # container is an adequate sandbox for whatever Lean will execute.
        response = server.run(Command(cmd=f"{req.theorem} := by sorry"))
        if isinstance(response, LeanError):
            # Without this guard the attribute access below would raise.
            return ProveResponse(
                success=False, tactics=[], steps=[],
                message=f"Could not get initial proof state: {response.message}",
            )
        if not response.sorries:
            if any(m.data == "Goals accomplished!" for m in response.messages):
                return ProveResponse(success=True, tactics=[], steps=[], message="Proved trivially!")
            return ProveResponse(success=False, tactics=[], steps=[], message="Could not get initial proof state")

        first_sorry = response.sorries[0]
        proof_state_id = first_sorry.proof_state
        current_goals = [first_sorry.goal]
        # md5 is only a cache key here, not a security primitive.
        goal_hash = hashlib.md5(current_goals[0].encode(), usedforsecurity=False).hexdigest()

        if not req.show_reasoning and goal_hash in lemma_cache:
            cached = lemma_cache[goal_hash]
            return ProveResponse(
                success=True, tactics=cached, steps=[],
                message=f"Proved from cache ({len(cached)} tactic(s))!"
            )

        tactics = []
        last_error = None
        visited = set()  # (proof_state_id, tactic) pairs already attempted
        llm_ever_responded = False
        consecutive_failures = 0  # all_failed steps in a row
        goal_seen: dict[str, int] = {}  # goal text -> times seen

        for step in range(req.max_steps):
            goal_text = "\n".join(current_goals)
            if not goal_text.strip():
                break

            # Stuck-state detection: same goal returning, or consecutive dead ends
            goal_seen[goal_text] = goal_seen.get(goal_text, 0) + 1
            if goal_seen[goal_text] >= STUCK_THRESHOLD or consecutive_failures >= STUCK_THRESHOLD:
                return ProveResponse(
                    success=False, stuck=True, tactics=tactics, steps=steps,
                    message=(
                        "Search stuck — the same goal state recurred with no progress. "
                        "This theorem is likely not provable in the current theory: "
                        "it may require classical logic (Law of Excluded Middle), "
                        "or a tactic the agent is constrained from using."
                    ),
                )

            result, fallback_tactic = try_fallbacks(proof_state_id, req.use_fallbacks)
            if result is not None:
                tactics.append(fallback_tactic)
                steps.append(StepLog(
                    step=step, goal=goal_text,
                    candidates=[fallback_tactic], chosen=fallback_tactic,
                    status=result.proof_status
                ))
                if result.proof_status == "Completed":
                    lemma_cache[goal_hash] = list(tactics)
                    return ProveResponse(success=True, tactics=tactics, steps=steps,
                                         message=f"Proved in {step+1} steps!")
                proof_state_id = result.proof_state
                current_goals = result.goals
                last_error = None
                consecutive_failures = 0
                continue

            candidates = ask_model(goal_text, last_error=last_error, num_candidates=NUM_CANDIDATES)
            if not candidates:
                steps.append(StepLog(
                    step=step, goal=goal_text,
                    candidates=[], chosen="",
                    status="model_unavailable", error="LLM endpoint cold/unavailable — retrying"
                ))
                # Don't count toward stuck-state; just burn a step and retry.
                # The visit recorded above must be undone, otherwise a cold
                # endpoint is reported as "likely not provable" after
                # STUCK_THRESHOLD iterations.
                goal_seen[goal_text] -= 1
                continue

            llm_ever_responded = True
            best_result = None
            best_tactic = None
            step_error = None
            for tactic in candidates:
                key = (proof_state_id, tactic)
                if key in visited:
                    continue
                visited.add(key)
                result, error = try_tactic(tactic, proof_state_id)
                if result is None:
                    last_error = error
                    step_error = error
                    continue
                if "sorry" in result.proof_status:
                    last_error = "That tactic left sorry holes. Provide the full proof of each case inline."
                    # Record the reason in the step log as well, so an
                    # all_failed step is never logged without an error.
                    step_error = last_error
                    continue
                if result.proof_status == "Completed":
                    tactics.append(tactic)
                    steps.append(StepLog(
                        step=step, goal=goal_text,
                        candidates=candidates, chosen=tactic,
                        status="Completed"
                    ))
                    lemma_cache[goal_hash] = list(tactics)
                    return ProveResponse(success=True, tactics=tactics, steps=steps,
                                         message=f"Proved in {step+1} steps!")
                # Keep the candidate that leaves the fewest open goals.
                if best_result is None or len(result.goals) < len(best_result.goals):
                    best_result = result
                    best_tactic = tactic

            if best_result is not None:
                tactics.append(best_tactic)
                steps.append(StepLog(
                    step=step, goal=goal_text,
                    candidates=candidates, chosen=best_tactic,
                    status=best_result.proof_status
                ))
                proof_state_id = best_result.proof_state
                current_goals = best_result.goals
                last_error = None
                consecutive_failures = 0
            else:
                consecutive_failures += 1
                if step_error is None:
                    # Nothing was executed: every candidate was skipped as
                    # already visited.
                    step_error = (
                        "No untried candidates: every suggested tactic was "
                        "already attempted at this proof state."
                    )
                steps.append(StepLog(
                    step=step, goal=goal_text,
                    candidates=candidates, chosen="",
                    status="all_failed", error=step_error
                ))

        fail_msg = (
            "LLM endpoint warming up — fallback tactics only. Failed within max steps."
            if not llm_ever_responded
            else "Failed within max steps"
        )
        return ProveResponse(success=False, tactics=tactics, steps=steps, message=fail_msg)

    return web_app