Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import json | |
| import io | |
| import time | |
| import base64 | |
| import traceback | |
| import contextlib | |
| import tempfile | |
| import mimetypes | |
| from urllib.parse import urlparse, parse_qs | |
| import gradio as gr | |
| import requests | |
| import pandas as pd | |
| # --- Constants --- | |
| DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space" | |
| OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1" | |
| # Fleet of free OpenRouter models. Tried in order. When one rate-limits or errors, | |
| # we fall through to the next one. Mix of strong reasoning + tool use. | |
| TEXT_MODELS = [ | |
| m.strip() for m in os.getenv( | |
| "OPENROUTER_MODELS", | |
| # Currently-valid free OpenRouter models (verify at https://openrouter.ai/models?q=free). | |
| "meta-llama/llama-3.3-70b-instruct:free," | |
| "mistralai/mistral-small-3.2-24b-instruct:free," | |
| "google/gemini-2.0-flash-exp:free," | |
| "qwen/qwen-2.5-72b-instruct:free," | |
| "deepseek/deepseek-r1:free," | |
| "deepseek/deepseek-chat:free" | |
| ).split(",") | |
| if m.strip() | |
| ] | |
| # Vision-capable free model. Gemini Flash is multimodal and free on OpenRouter. | |
| VISION_MODEL = os.getenv("OPENROUTER_VISION_MODEL", "google/gemini-2.0-flash-exp:free") | |
| MAX_TOOL_ITERATIONS = 7 | |
| TOOL_RESULT_MAX_CHARS = 3500 | |
| ANSWER_CACHE_PATH = os.getenv("ANSWER_CACHE_PATH", "/tmp/answers_cache.json") | |
| RESULTS_CSV_PATH = "/tmp/gaia_results.csv" | |
| INTER_QUESTION_SLEEP = float(os.getenv("INTER_QUESTION_SLEEP", "2")) | |
| INTER_TOOL_SLEEP = float(os.getenv("INTER_TOOL_SLEEP", "0.5")) | |
| # Track downloaded task files so vision/audio tools can re-use them by task_id. | |
| _TASK_FILE_CACHE: dict[str, dict] = {} | |
| # --------------------------------------------------------------------------- | |
| # Tool implementations | |
| # --------------------------------------------------------------------------- | |
def _format_search_hit(title, url, snippet) -> str:
    """Render one search hit as a 3-line bullet; None fields become empty strings."""
    return f"- {title or ''}\n {url or ''}\n {(snippet or '')[:400]}"


def tool_web_search(query: str, max_results: int = 5) -> str:
    """Search the web, preferring Tavily and falling back to DuckDuckGo.

    Tavily is used only when ``TAVILY_API_KEY`` is set; if it raises or
    yields nothing, DuckDuckGo (``duckduckgo_search.DDGS``) is tried.

    Args:
        query: search string.
        max_results: hits to request per provider; non-numeric values fall
            back to 5 and values below 1 are raised to 1.

    Returns:
        A ``[provider: ...]``-tagged block of hits (plus Tavily's synthesized
        answer when present), ``"[provider: duckduckgo] No results."``, or a
        ``"web_search error: ..."`` string.
    """
    try:
        max_results = max(1, int(max_results))
    except (TypeError, ValueError):
        max_results = 5
    tavily_key = os.getenv("TAVILY_API_KEY")
    if tavily_key:
        try:
            from tavily import TavilyClient

            client = TavilyClient(api_key=tavily_key)
            res = client.search(
                query=query,
                max_results=max_results,
                search_depth="basic",
                include_answer=True,
            )
            lines = ["[provider: tavily]"]
            if res.get("answer"):
                lines.append(f"Answer: {res['answer']}")
            for r in (res.get("results") or [])[:max_results]:
                # Provider payloads can carry null fields; slicing None would
                # otherwise throw away the whole Tavily result set.
                lines.append(_format_search_hit(r.get("title"), r.get("url"), r.get("content")))
            if len(lines) > 1:
                return "\n".join(lines)
        except Exception as e:
            print(f"tavily search failed, falling back to DDG: {e}")
    else:
        print("[search] TAVILY_API_KEY not set; using DDG.")
    try:
        from duckduckgo_search import DDGS

        results = ["[provider: duckduckgo]"]
        with DDGS() as ddgs:
            for r in ddgs.text(query, max_results=max_results) or []:
                results.append(_format_search_hit(r.get("title"), r.get("href"), r.get("body")))
        if len(results) == 1:
            return "[provider: duckduckgo] No results."
        return "\n".join(results)
    except Exception as e:
        return f"web_search error: {e}"
def _html_to_text(html: str) -> str:
    """Strip script/style/noscript from *html* and return its visible text.

    Uses the ``lxml`` parser when installed, else the stdlib-backed
    ``html.parser`` so a missing lxml no longer breaks every fetch.
    """
    from bs4 import BeautifulSoup, FeatureNotFound

    try:
        soup = BeautifulSoup(html, "lxml")
    except FeatureNotFound:
        soup = BeautifulSoup(html, "html.parser")
    for tag in soup(["script", "style", "noscript"]):
        tag.decompose()
    return soup.get_text(separator="\n")


def _pdf_bytes_to_text(data: bytes, max_chars: int) -> str:
    """Extract text from PDF bytes, stopping once more than *max_chars* is collected."""
    from pypdf import PdfReader

    reader = PdfReader(io.BytesIO(data))
    chunks: list[str] = []
    total = 0
    for page in reader.pages:
        page_text = page.extract_text() or ""
        chunks.append(page_text)
        total += len(page_text)
        if total > max_chars:
            break  # caller truncates anyway; skip parsing the remaining pages
    return "\n".join(chunks)


def tool_fetch_url(url: str, max_chars: int = 3500) -> str:
    """Fetch *url* and return readable text.

    HTML pages are stripped to visible text, PDFs are run through pypdf, and
    anything else is returned as decoded text. Runs of blank lines are
    collapsed.

    Args:
        url: page to fetch (20s timeout, descriptive User-Agent).
        max_chars: output cap; longer text ends with "\\n...[truncated]".

    Returns:
        The page text, or ``"fetch_url error: ..."`` on any failure.
    """
    try:
        headers = {
            "User-Agent": (
                "Mozilla/5.0 (compatible; GAIA-Agent/1.0; "
                "+https://huggingface.co/learn/agents-course)"
            )
        }
        resp = requests.get(url, headers=headers, timeout=20)
        resp.raise_for_status()
        ctype = resp.headers.get("Content-Type", "").lower()
        # Match extensions on the path only, so "?page=2" style queries don't hide them.
        path = urlparse(url).path.lower()
        if "pdf" in ctype or path.endswith(".pdf"):
            text = _pdf_bytes_to_text(resp.content, max_chars)
        elif "html" in ctype or path.endswith((".html", ".htm")) or "<html" in resp.text[:500].lower():
            text = _html_to_text(resp.text)
        else:
            text = resp.text
        text = re.sub(r"\n\s*\n+", "\n\n", text).strip()
        if len(text) > max_chars:
            text = text[:max_chars] + "\n...[truncated]"
        return text
    except Exception as e:
        return f"fetch_url error: {e}"
def tool_wikipedia(query: str, sentences: int = 6) -> str:
    """Return an English Wikipedia summary for *query*.

    Resolution order: a direct lookup with auto-suggest; on a disambiguation
    page, a list of up to 8 candidate titles instead; on a missing page, the
    summary of the top search hit. Any other failure (including the package
    being absent) comes back as ``"wikipedia error: ..."``.
    """
    try:
        import wikipedia as wiki

        wiki.set_lang("en")
        try:
            summary = wiki.summary(query, sentences=sentences, auto_suggest=True, redirect=True)
        except wiki.DisambiguationError as ambiguous:
            return "Disambiguation. Options: " + ", ".join(ambiguous.options[:8])
        except wiki.PageError:
            top_hits = wiki.search(query, results=5)
            if not top_hits:
                return "No Wikipedia page found."
            summary = wiki.summary(top_hits[0], sentences=sentences, auto_suggest=False, redirect=True)
        return summary
    except Exception as err:
        return f"wikipedia error: {err}"
def tool_python(code: str) -> str:
    """Run a small Python snippet and return its stdout (or the value of ``result``).

    The snippet runs in ONE fresh namespace used as both globals and locals.
    With separate dicts (the previous behaviour), top-level names land in the
    locals dict. Functions, lambdas and generator expressions defined by the
    snippet then cannot see the snippet's own imports/variables and fail with
    NameError. ``__name__`` is set to ``"__main__"`` so guarded scripts run.

    SECURITY NOTE: this executes model-generated code in-process with full
    builtins. It is NOT a sandbox and is only acceptable for trusted,
    benchmark-style use.

    Returns:
        Captured stdout (or ``str(result)`` when nothing was printed),
        truncated to 2500 chars; ``"(no output)"`` when neither exists; or
        ``"python error: ..."`` with a short traceback on failure, including
        a ``SystemExit`` raised by the snippet.
    """
    buf = io.StringIO()
    namespace: dict = {"__builtins__": __builtins__, "__name__": "__main__"}
    try:
        with contextlib.redirect_stdout(buf):
            exec(code, namespace)
        out = buf.getvalue().strip()
        if not out and "result" in namespace:
            out = str(namespace["result"])
        return (out or "(no output)")[:2500]
    except (Exception, SystemExit) as e:
        # SystemExit (e.g. a snippet calling exit()) would otherwise escape every
        # `except Exception` up the call chain and abort the whole run.
        return f"python error: {e}\n{traceback.format_exc(limit=2)}"
| def _extract_youtube_id(url: str) -> str | None: | |
| try: | |
| u = urlparse(url) | |
| if "youtu.be" in u.netloc: | |
| return u.path.lstrip("/").split("/")[0] or None | |
| if "youtube.com" in u.netloc: | |
| if u.path == "/watch": | |
| return parse_qs(u.query).get("v", [None])[0] | |
| if u.path.startswith("/embed/") or u.path.startswith("/shorts/"): | |
| return u.path.split("/")[2] | |
| except Exception: | |
| pass | |
| return None | |
def _transcript_segment_text(seg) -> str:
    """Return the spoken text of one transcript segment.

    youtube-transcript-api < 1.0 yields dicts with a "text" key; 1.x yields
    snippet objects exposing a ``.text`` attribute. Both are accepted.
    """
    if isinstance(seg, dict):
        return seg.get("text", "") or ""
    return getattr(seg, "text", "") or ""


def tool_youtube_transcript(url: str, max_chars: int = 3500) -> str:
    """Fetch the spoken transcript of a YouTube video as one whitespace-normalized string.

    Prefers an English transcript ("en", "en-US", "en-GB"); if that fails,
    falls back to the first transcript listed for the video. Works with both
    the legacy static API (``get_transcript`` / ``list_transcripts``) and the
    1.x instance API (``fetch`` / ``list``) of youtube-transcript-api.

    Args:
        url: a YouTube URL or a bare video id.
        max_chars: output cap; longer transcripts end with " ...[truncated]".

    Returns:
        The transcript text, ``"(empty transcript)"``, or
        ``"youtube_transcript error: ..."``.
    """
    try:
        from youtube_transcript_api import YouTubeTranscriptApi

        vid = _extract_youtube_id(url) or url.strip()
        langs = ["en", "en-US", "en-GB"]
        # Newer releases dropped the static helpers in favour of an instance API.
        legacy = hasattr(YouTubeTranscriptApi, "get_transcript")
        api = None if legacy else YouTubeTranscriptApi()
        try:
            if legacy:
                data = YouTubeTranscriptApi.get_transcript(vid, languages=langs)
            else:
                data = api.fetch(vid, languages=langs)
        except Exception:
            # No English transcript: take whatever language is listed first.
            tlist = YouTubeTranscriptApi.list_transcripts(vid) if legacy else api.list(vid)
            t = next(iter(tlist), None)
            data = t.fetch() if t else []
        text = " ".join(_transcript_segment_text(seg) for seg in data).strip()
        text = re.sub(r"\s+", " ", text)
        if len(text) > max_chars:
            text = text[:max_chars] + " ...[truncated]"
        return text or "(empty transcript)"
    except Exception as e:
        return f"youtube_transcript error: {e}"
def _hf_inference(model: str, data: bytes, content_type: str, attempts: int = 3) -> str:
    """POST raw bytes to the HF Inference API for *model* and return the response text.

    Used for Whisper audio transcription. Sends ``HF_TOKEN`` (or
    ``HUGGINGFACEHUB_API_TOKEN``) as a bearer token when set. A 503 means the
    model is still loading: the call waits the server's ``estimated_time``
    (clamped to 3–30s, default 10s) and retries. No wait happens after the
    final attempt, since nothing would follow it.

    Args:
        model: HF model id, e.g. "openai/whisper-large-v3".
        data: request body bytes.
        content_type: MIME type of *data*.
        attempts: total POST attempts (at least 1).

    Raises:
        requests.HTTPError: on a non-503 error status.
        RuntimeError: if every attempt returned 503.
    """
    hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
    headers = {"Content-Type": content_type}
    if hf_token:
        headers["Authorization"] = f"Bearer {hf_token}"
    url = f"https://api-inference.huggingface.co/models/{model}"
    attempts = max(1, int(attempts))
    for attempt in range(attempts):
        resp = requests.post(url, headers=headers, data=data, timeout=120)
        if resp.status_code != 503:
            resp.raise_for_status()
            return resp.text
        if attempt == attempts - 1:
            break
        # Model loading — wait per estimated_time.
        try:
            wait = float(resp.json().get("estimated_time", 10))
        except Exception:
            wait = 10
        wait = min(max(wait, 3), 30)
        print(f"HF model {model} loading; waiting {wait}s...")
        time.sleep(wait)
    raise RuntimeError(f"HF model {model} not ready after retries")
def tool_transcribe_audio(task_id: str) -> str:
    """Transcribe a task's attached audio file with Whisper via the HF Inference API.

    The attachment is taken from ``_TASK_FILE_CACHE`` (downloading it with
    ``tool_get_task_file`` first if needed). The MIME type is chosen from the
    file extension, defaulting to audio/mpeg.

    Returns:
        The transcript (capped at 4000 chars, then " ...[truncated]"),
        ``"(empty transcript)"``, or ``"transcribe_audio error: ..."``. The
        error form is also returned when the API answers with an ``{"error": ...}``
        payload, instead of passing that payload off as a transcript.
    """
    try:
        info = _TASK_FILE_CACHE.get(task_id)
        if not info:
            tool_get_task_file(task_id)
            info = _TASK_FILE_CACHE.get(task_id)
        if not info or not os.path.exists(info.get("path", "")):
            return "transcribe_audio error: no local file for task"
        path = info["path"]
        ext = os.path.splitext(path)[1].lower().lstrip(".")
        ctype_map = {
            "mp3": "audio/mpeg", "wav": "audio/wav", "m4a": "audio/mp4",
            "ogg": "audio/ogg", "flac": "audio/flac", "webm": "audio/webm",
        }
        ctype = ctype_map.get(ext, "audio/mpeg")
        with open(path, "rb") as f:
            data = f.read()
        raw = _hf_inference("openai/whisper-large-v3", data, ctype)
        try:
            obj = json.loads(raw)
        except ValueError:
            obj = None  # not JSON: treat the body itself as the transcript
        if isinstance(obj, dict) and obj.get("error"):
            return f"transcribe_audio error: {obj['error']}"
        if isinstance(obj, dict) and "text" in obj:
            text = obj["text"]
        elif isinstance(obj, list) and obj and isinstance(obj[0], dict) and "text" in obj[0]:
            text = obj[0]["text"]
        else:
            text = raw
        text = str(text or "").strip()
        if len(text) > 4000:
            text = text[:4000] + " ...[truncated]"
        return text or "(empty transcript)"
    except Exception as e:
        return f"transcribe_audio error: {e}"
def tool_view_image(task_id: str, question: str = "") -> str:
    """Ask the OpenRouter vision model about an image attached to a GAIA task.

    The attachment comes from ``_TASK_FILE_CACHE`` (fetched via
    ``tool_get_task_file`` when absent). It is inlined as a base64 data URL and
    sent with *question*, or a generic "describe" prompt when *question* is
    blank, to ``VISION_MODEL``. Only png/jpeg/gif/webp files are accepted.

    Returns:
        The model's stripped reply, or ``"view_image error: ..."``.
    """
    try:
        from openai import OpenAI

        entry = _TASK_FILE_CACHE.get(task_id)
        if not entry:
            tool_get_task_file(task_id)
            entry = _TASK_FILE_CACHE.get(task_id)
        missing = not entry or not os.path.exists(entry.get("path", ""))
        if missing:
            return "view_image error: no local file for task"

        image_path = entry["path"]
        image_kind = os.path.splitext(image_path)[1].lower().lstrip(".")
        if image_kind == "jpg":
            image_kind = "jpeg"  # the MIME subtype is image/jpeg
        if image_kind not in {"png", "jpeg", "gif", "webp"}:
            return f"view_image error: unsupported image type .{image_kind}"

        with open(image_path, "rb") as fh:
            encoded = base64.b64encode(fh.read()).decode("ascii")
        data_url = f"data:image/{image_kind};base64,{encoded}"

        prompt = question.strip()
        if not prompt:
            prompt = "Describe this image in detail, including any text, numbers, or symbols visible."

        client = OpenAI(base_url=OPENROUTER_BASE_URL, api_key=os.getenv("OPENROUTER_API_KEY"))
        user_content = [
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": data_url}},
        ]
        completion = client.chat.completions.create(
            model=VISION_MODEL,
            messages=[{"role": "user", "content": user_content}],
            temperature=0.0,
            max_tokens=600,
            extra_headers={
                "HTTP-Referer": "https://huggingface.co/learn/agents-course",
                "X-Title": "GAIA Agent",
            },
        )
        reply = completion.choices[0].message.content
        return (reply or "").strip()
    except Exception as e:
        return f"view_image error: {e}"
def tool_get_task_file(task_id: str, api_url: str = DEFAULT_API_URL) -> str:
    """Download the file attached to a GAIA task and return a text preview.

    The file is saved to a temp path and recorded in
    ``_TASK_FILE_CACHE[task_id]`` (keys: path, name, ctype, size), so
    ``transcribe_audio`` / ``view_image`` can reuse it. The local path is also
    printed in the preview so the python tool can read the full file.

    Preview by extension:
    - text-like files: first 3000 chars;
    - Excel: the first sheet as CSV (3000-char cap);
    - PDF: text of the first 6 pages (3000-char cap);
    - audio/image: a hint naming the follow-up tool.

    When the Content-Disposition filename has no extension, one is guessed
    from the Content-Type.

    Returns:
        ``"NO_FILE: ..."`` on HTTP 404, the preview on success, or
        ``"get_task_file error: ..."`` on any other failure.
    """
    try:
        resp = requests.get(f"{api_url}/files/{task_id}", timeout=30)
        if resp.status_code == 404:
            return (
                "NO_FILE: This task has no attached file. Do not call get_task_file again. "
                "Answer using web_search / wikipedia / python / your own knowledge."
            )
        resp.raise_for_status()
        ctype = resp.headers.get("Content-Type", "")
        cdisp = resp.headers.get("Content-Disposition", "")
        # Stop at ';' as well as '"' so an unquoted filename followed by other
        # header parameters does not leak them into the name/extension.
        fname_match = re.search(r'filename="?([^";]+)"?', cdisp)
        fname = (fname_match.group(1).strip() if fname_match else "") or f"{task_id}"
        suffix = os.path.splitext(fname)[1].lower()
        if not suffix and ctype:
            # No usable extension in the header: derive one from the MIME type.
            suffix = (mimetypes.guess_extension(ctype.split(";")[0].strip()) or "").lower()
        content = resp.content
        with tempfile.NamedTemporaryFile(prefix=f"{task_id}_", suffix=suffix, delete=False) as tmp:
            tmp.write(content)
        path = tmp.name
        _TASK_FILE_CACHE[task_id] = {
            "path": path,
            "name": fname,
            "ctype": ctype,
            "size": len(content),
        }
        info = (
            f"File: {fname}\nContent-Type: {ctype}\nSize: {len(content)} bytes\n"
            f"Local path: {path}\n"
        )
        if suffix in {".txt", ".md", ".csv", ".json", ".py", ".tsv", ".log", ".xml", ".html"}:
            text = content.decode("utf-8", errors="replace")
            return info + "\n--- preview ---\n" + text[:3000]
        if suffix in {".xlsx", ".xls"}:
            try:
                df = pd.read_excel(path)
                csv = df.to_csv(index=False)
                if len(csv) > 3000:
                    csv = csv[:3000] + "\n...[truncated]"
                return info + "\n--- excel as csv ---\n" + csv
            except Exception as e:
                return info + f"\n(excel parse error: {e})"
        if suffix == ".pdf":
            try:
                from pypdf import PdfReader

                reader = PdfReader(path)
                pages = [p.extract_text() or "" for p in reader.pages[:6]]
                return info + "\n--- pdf text ---\n" + "\n".join(pages)[:3000]
            except Exception as e:
                return info + f"\n(pdf parse error: {e})"
        if suffix in {".mp3", ".wav", ".m4a", ".ogg", ".flac", ".webm"}:
            return info + "\nAudio file. Call transcribe_audio(task_id) to read it."
        if suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp"}:
            return info + "\nImage file. Call view_image(task_id, question='...')."
        return info + "\n(binary file; no preview)"
    except Exception as e:
        return f"get_task_file error: {e}"
| # --------------------------------------------------------------------------- | |
| # Tool schema (OpenAI-compatible) | |
| # --------------------------------------------------------------------------- | |
def _function_tool(name: str, description: str, params: list[tuple[str, str]], required: list[str]) -> dict:
    """Build one OpenAI-compatible function-tool schema entry.

    *params* is an ordered list of (argument name, JSON-schema type) pairs;
    *required* lists the argument names the model must always supply.
    """
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": {arg: {"type": json_type} for arg, json_type in params},
                "required": list(required),
            },
        },
    }


# Tool schemas advertised to the chat model (OpenAI "tools" format). The order
# and names must match the TOOL_FUNCTIONS dispatch table.
TOOLS_SPEC = [
    _function_tool(
        "web_search",
        "Search the web (Tavily preferred, DuckDuckGo fallback). Returns titles, URLs, snippets, and Tavily's synthesized answer.",
        [("query", "string"), ("max_results", "integer")],
        ["query"],
    ),
    _function_tool(
        "fetch_url",
        "Fetch a URL and return cleaned page text. Use after web_search to read a result page.",
        [("url", "string"), ("max_chars", "integer")],
        ["url"],
    ),
    _function_tool(
        "wikipedia",
        "Get a Wikipedia summary for a person, place, work, or topic. Use FIRST for biographical or list questions.",
        [("query", "string"), ("sentences", "integer")],
        ["query"],
    ),
    _function_tool(
        "python",
        "Execute a Python snippet for math, sums, dates, sorting, alphabetizing, parsing, string reversal, set logic. Use print() or assign to `result`.",
        [("code", "string")],
        ["code"],
    ),
    _function_tool(
        "get_task_file",
        "Download the file attached to a GAIA task by task_id. Returns NO_FILE if no file exists.",
        [("task_id", "string")],
        ["task_id"],
    ),
    _function_tool(
        "transcribe_audio",
        "Transcribe an attached audio file (.mp3/.wav/.m4a/.ogg/.flac) using Whisper.",
        [("task_id", "string")],
        ["task_id"],
    ),
    _function_tool(
        "view_image",
        "Inspect an attached image (.png/.jpg/.gif/.webp) using a vision model. Pass a focused question.",
        [("task_id", "string"), ("question", "string")],
        ["task_id"],
    ),
    _function_tool(
        "youtube_transcript",
        "Fetch the spoken transcript of a YouTube video given its URL. Only captures speech, not visual content.",
        [("url", "string"), ("max_chars", "integer")],
        ["url"],
    ),
]
| TOOL_FUNCTIONS = { | |
| "web_search": lambda args: tool_web_search(args["query"], int(args.get("max_results", 5))), | |
| "fetch_url": lambda args: tool_fetch_url(args["url"], int(args.get("max_chars", 3500))), | |
| "wikipedia": lambda args: tool_wikipedia(args["query"], int(args.get("sentences", 6))), | |
| "python": lambda args: tool_python(args["code"]), | |
| "get_task_file": lambda args: tool_get_task_file(args["task_id"]), | |
| "transcribe_audio": lambda args: tool_transcribe_audio(args["task_id"]), | |
| "view_image": lambda args: tool_view_image(args["task_id"], args.get("question", "")), | |
| "youtube_transcript": lambda args: tool_youtube_transcript( | |
| args["url"], int(args.get("max_chars", 3500)) | |
| ), | |
| } | |
# System prompt for the tool-using agent loop. The tool-turn budget is
# interpolated from MAX_TOOL_ITERATIONS so the prompt cannot drift from the
# loop limit the agent actually enforces.
SYSTEM_PROMPT = f"""You are a careful research agent answering GAIA benchmark questions.
Tools: web_search, fetch_url, wikipedia, python, get_task_file, transcribe_audio, view_image, youtube_transcript.
Decision rules:
- If the question references "attached file/image/audio/Excel/PDF/.mp3/.xlsx/.py/recording/photo/image", call get_task_file FIRST.
- Audio (.mp3, .wav, etc.) -> transcribe_audio(task_id) after get_task_file.
- Image (.png, .jpg, etc.) -> view_image(task_id, question="<focused question>") after get_task_file.
- Excel/CSV/text/PDF — the get_task_file preview is enough; use python to compute on it.
- If get_task_file returns NO_FILE, do NOT call it again.
- For YouTube URLs, use youtube_transcript(url) directly. (No get_task_file needed.) The transcript is speech only — for visual questions, give your best estimate.
- For factual lookups about people, places, artists, albums, animals, Wikipedia featured articles: START with wikipedia.
- For everything else research-y: web_search then fetch_url the most relevant URL.
- Use python for ALL arithmetic, sums, date math, sorting, alphabetizing, set/group operations, string reversal. Never compute by hand.
- For Excel/CSV totals, after get_task_file shows the data, ALWAYS use python to compute the sum precisely.
Be decisive — don't repeat the same tool with the same args. You have {MAX_TOOL_ITERATIONS} tool turns.
ANSWER FORMATTING (the grader does an exact-match comparison; sentence answers ALWAYS lose):
Worked examples of correct GAIA format:
- Q: "How many albums..." -> "3" (NOT "3 albums" or "There were 3 albums")
- Q: "Express your answer in USD with two decimal places" -> "89706.00"
- Q: "Give the IOC country code" -> "MLT"
- Q: "Just the city name without abbreviations" -> "Saint Petersburg"
- Q: "Give only the first name" -> "Bartek"
- Q: "Comma separated list ... in alphabetical order" -> "broccoli, celery, fresh basil, lettuce, sweet potatoes, zucchini"
- Q: "Under what NASA award number..." -> "80NSSC21K1130"
- Q: opposite of "left" -> "right"
Strict rules:
- Reply with ONLY the answer. No preamble. No explanation. No quotes. No trailing period.
- Do NOT include "FINAL ANSWER", "Answer:", or any label.
- Numbers: digits only, no commas, no units, no $ — UNLESS the question asks for the unit.
- Currency "two decimal places": e.g. "89706.00".
- Strings: no leading articles ("the", "a") unless required; no abbreviations ("Saint" not "St."); digits as digits.
- Names: read the question carefully ("first name only" / "last name only" / "surname" / "full name").
- Lists: comma-separated, ONE space after each comma. Sort if asked.
"""
| def _maybe_reverse_text(question: str) -> str: | |
| """If the question text looks reversed, flip it.""" | |
| q = question.strip() | |
| if not q: | |
| return question | |
| starts_with_punct = q[0] in ".,;:!?" | |
| reversed_text = q[::-1] | |
| common = (" the ", " of ", " and ", " to ", " is ", " a ", " in ", " for ") | |
| hits = sum(1 for w in common if w in (" " + reversed_text.lower() + " ")) | |
| if starts_with_punct and hits >= 2: | |
| return reversed_text | |
| return question | |
| # --------------------------------------------------------------------------- | |
| # Agent | |
| # --------------------------------------------------------------------------- | |
| class OpenRouterAgent: | |
| def __init__(self): | |
| try: | |
| from openai import OpenAI | |
| except ImportError as e: | |
| raise RuntimeError("openai package not installed") from e | |
| api_key = os.getenv("OPENROUTER_API_KEY") | |
| if not api_key: | |
| raise RuntimeError( | |
| "OPENROUTER_API_KEY is not set. Get one free at https://openrouter.ai/keys " | |
| "and add it as a Secret in your HF Space settings." | |
| ) | |
| self.client = OpenAI(base_url=OPENROUTER_BASE_URL, api_key=api_key) | |
| self.models = list(TEXT_MODELS) | |
| self.exhausted: set[str] = set() | |
| self.extra_headers = { | |
| "HTTP-Referer": "https://huggingface.co/learn/agents-course", | |
| "X-Title": "GAIA Agent", | |
| } | |
| print(f"OpenRouterAgent initialized with model fleet: {self.models}") | |
| def _chat(self, messages, use_tools: bool = True, max_tokens: int = 800): | |
| """Try each model in the fleet. Falls through on rate limit / error.""" | |
| last_error: Exception | None = None | |
| for m in self.models: | |
| if m in self.exhausted: | |
| continue | |
| for attempt in range(2): | |
| try: | |
| kwargs = dict( | |
| model=m, | |
| messages=messages, | |
| temperature=0.0, | |
| max_tokens=max_tokens, | |
| extra_headers=self.extra_headers, | |
| ) | |
| if use_tools: | |
| kwargs["tools"] = TOOLS_SPEC | |
| kwargs["tool_choice"] = "auto" | |
| return self.client.chat.completions.create(**kwargs) | |
| except Exception as e: | |
| msg = str(e) | |
| last_error = e | |
| is_rate = "429" in msg or "rate" in msg.lower() or "limit" in msg.lower() | |
| is_quota = ("daily" in msg.lower() or "quota" in msg.lower() | |
| or "exhausted" in msg.lower()) | |
| # 404 / "No endpoints found" / "model not found" -> dead model, never retry. | |
| is_dead = ( | |
| "404" in msg | |
| or "no endpoints" in msg.lower() | |
| or "not found" in msg.lower() | |
| or "model_not_found" in msg.lower() | |
| ) | |
| if is_dead: | |
| print(f"[{m}] model unavailable (404 / no endpoints); marking exhausted.") | |
| self.exhausted.add(m) | |
| break | |
| if is_rate and is_quota: | |
| print(f"[{m}] daily quota exhausted; switching model.") | |
| self.exhausted.add(m) | |
| break | |
| if is_rate: | |
| wait = 4 * (attempt + 1) | |
| print(f"[{m}] rate-limited; sleeping {wait}s (attempt {attempt + 1}/2)") | |
| time.sleep(wait) | |
| continue | |
| print(f"[{m}] API error: {repr(e)[:240]} — trying next model.") | |
| break | |
| err_str = repr(last_error) if last_error else "no error captured" | |
| raise RuntimeError(f"All OpenRouter models failed. {err_str}") | |
| def __call__(self, question: str, task_id: str | None = None) -> str: | |
| flipped = _maybe_reverse_text(question) | |
| if flipped != question: | |
| print("[reversed-text detected, flipping question]") | |
| question = flipped | |
| user_content = question | |
| if task_id: | |
| user_content = f"task_id: {task_id}\n\nQuestion: {question}" | |
| messages = [ | |
| {"role": "system", "content": SYSTEM_PROMPT}, | |
| {"role": "user", "content": user_content}, | |
| ] | |
| collected_facts: list[str] = [] | |
| seen_calls: set[str] = set() | |
| for step in range(MAX_TOOL_ITERATIONS): | |
| try: | |
| resp = self._chat(messages, use_tools=True, max_tokens=800) | |
| except Exception as e: | |
| print(f"chat at step {step} failed: {e}") | |
| break | |
| msg = resp.choices[0].message | |
| tool_calls = getattr(msg, "tool_calls", None) | |
| if not tool_calls: | |
| answer = (msg.content or "").strip() | |
| if answer: | |
| return self._finalize(answer, question, collected_facts) | |
| break | |
| messages.append( | |
| { | |
| "role": "assistant", | |
| "content": msg.content or "", | |
| "tool_calls": [ | |
| { | |
| "id": tc.id, | |
| "type": "function", | |
| "function": { | |
| "name": tc.function.name, | |
| "arguments": tc.function.arguments, | |
| }, | |
| } | |
| for tc in tool_calls | |
| ], | |
| } | |
| ) | |
| for tc in tool_calls: | |
| name = tc.function.name | |
| try: | |
| args = json.loads(tc.function.arguments or "{}") | |
| except json.JSONDecodeError: | |
| args = {} | |
| call_key = f"{name}|{json.dumps(args, sort_keys=True, default=str)[:300]}" | |
| if call_key in seen_calls: | |
| print(f"[tool] {name}({str(args)[:120]}) [DUPLICATE — skipping]") | |
| result = "DUPLICATE_CALL: you already called this with the same args. Try a different query, a different tool, or give your final answer." | |
| else: | |
| seen_calls.add(call_key) | |
| fn = TOOL_FUNCTIONS.get(name) | |
| print(f"[tool] {name}({str(args)[:200]})") | |
| if fn is None: | |
| result = f"unknown tool: {name}" | |
| else: | |
| try: | |
| result = fn(args) | |
| except Exception as e: | |
| result = f"{name} error: {e}" | |
| if not isinstance(result, str): | |
| result = str(result) | |
| if len(result) > TOOL_RESULT_MAX_CHARS: | |
| result = result[:TOOL_RESULT_MAX_CHARS] + "\n...[truncated]" | |
| collected_facts.append(f"[{name}] {result[:1200]}") | |
| messages.append( | |
| { | |
| "role": "tool", | |
| "tool_call_id": tc.id, | |
| "name": name, | |
| "content": result, | |
| } | |
| ) | |
| if INTER_TOOL_SLEEP > 0: | |
| time.sleep(INTER_TOOL_SLEEP) | |
| return self._synthesize(question, collected_facts) | |
| def _synthesize(self, question: str, facts: list[str]) -> str: | |
| """Final answer pass on a short context. No tools.""" | |
| joined = "\n\n".join(facts[-8:]) | |
| if len(joined) > 5000: | |
| joined = joined[-5000:] | |
| synth_messages = [ | |
| { | |
| "role": "system", | |
| "content": ( | |
| "You are a strict GAIA answer formatter. Read the question and the research " | |
| "notes, then output ONLY the final answer string. No preamble, no labels, no " | |
| "explanation, no quotes, no trailing period. Match the question's required " | |
| "format exactly. If notes are insufficient, give your single best guess based " | |
| "on general knowledge. Never refuse, never apologize, never reply with empty." | |
| ), | |
| }, | |
| { | |
| "role": "user", | |
| "content": ( | |
| f"Question:\n{question}\n\n" | |
| f"Research notes:\n{joined or '(no notes)'}\n\nFinal answer:" | |
| ), | |
| }, | |
| ] | |
| try: | |
| resp = self._chat(synth_messages, use_tools=False, max_tokens=120) | |
| return self._postprocess_answer( | |
| (resp.choices[0].message.content or "").strip(), question | |
| ) or "unknown" | |
| except Exception as e: | |
| print(f"synthesis failed: {e}") | |
| # Last-resort: tiny zero-shot guess | |
| try: | |
| resp = self._chat( | |
| [ | |
| {"role": "system", "content": "Answer in 1-5 words. No explanation."}, | |
| {"role": "user", "content": question[:500]}, | |
| ], | |
| use_tools=False, | |
| max_tokens=40, | |
| ) | |
| return self._postprocess_answer( | |
| (resp.choices[0].message.content or "").strip(), question | |
| ) or "unknown" | |
| except Exception as e2: | |
| print(f"last-resort guess failed: {e2}") | |
| return "unknown" | |
| def _finalize(self, raw: str, question: str, facts: list[str]) -> str: | |
| cleaned = self._postprocess_answer(raw, question) | |
| if not cleaned: | |
| return self._synthesize(question, facts) | |
| looks_sentence = ( | |
| len(cleaned.split()) > 12 | |
| or re.search( | |
| r"\b(because|received|grant|seems|unable|sorry|cannot|provides|indicating|" | |
| r"web_search|youtube_transcript|fetch_url)\b", | |
| cleaned, | |
| re.IGNORECASE, | |
| ) | |
| ) | |
| if looks_sentence: | |
| try: | |
| resp = self._chat( | |
| [ | |
| { | |
| "role": "system", | |
| "content": ( | |
| "Extract ONLY the final answer from the assistant text below, " | |
| "matching the question's required format exactly. No preamble, " | |
| "no explanation, no quotes, no trailing period, no labels." | |
| ), | |
| }, | |
| { | |
| "role": "user", | |
| "content": f"Question: {question}\n\nAssistant text: {cleaned}\n\nFinal answer:", | |
| }, | |
| ], | |
| use_tools=False, | |
| max_tokens=80, | |
| ) | |
| reformat = (resp.choices[0].message.content or "").strip() | |
| reformat = self._postprocess_answer(reformat, question) | |
| if reformat: | |
| return reformat | |
| except Exception as e: | |
| print(f"reformat pass failed: {e}") | |
| return cleaned | |
| def _postprocess_answer(text: str, question: str = "") -> str: | |
| if not text: | |
| return "" | |
| text = text.strip() | |
| text = re.sub( | |
| r"^(final\s*answer|answer|the\s*answer\s*is)\s*[:\-]?\s*", | |
| "", | |
| text, | |
| flags=re.IGNORECASE, | |
| ) | |
| text = text.strip("`") | |
| if len(text) >= 2 and text[0] == text[-1] and text[0] in {'"', "'"}: | |
| text = text[1:-1].strip() | |
| q_lower = question.lower() | |
| wants_number = bool( | |
| re.search(r"\bhow many\b|\bhow much\b|\bwhat number\b|\bcount\b", q_lower) | |
| ) | |
| if wants_number and not re.fullmatch(r"-?\d+(\.\d+)?", text): | |
| m = re.search(r"-?\d+(?:\.\d+)?", text.replace(",", "")) | |
| if m: | |
| text = m.group(0) | |
| if text.endswith(".") and " " not in text: | |
| text = text[:-1] | |
| return text.strip() | |
| # --------------------------------------------------------------------------- | |
| # Cache | |
| # --------------------------------------------------------------------------- | |
| def _load_cache() -> dict: | |
| try: | |
| with open(ANSWER_CACHE_PATH, "r", encoding="utf-8") as f: | |
| return json.load(f) | |
| except (FileNotFoundError, json.JSONDecodeError): | |
| return {} | |
| def _save_cache(cache: dict) -> None: | |
| try: | |
| with open(ANSWER_CACHE_PATH, "w", encoding="utf-8") as f: | |
| json.dump(cache, f, ensure_ascii=False, indent=2) | |
| except Exception as e: | |
| print(f"cache save error: {e}") | |
| # --------------------------------------------------------------------------- | |
| # Gradio submission flow | |
| # --------------------------------------------------------------------------- | |
def _is_reusable_cached_answer(cached) -> bool:
    """Return True if a cached answer can be reused instead of re-running the agent.

    Falsy values, ``AGENT ERROR`` placeholders and the literal ``"unknown"``
    count as misses, so those questions are retried. The checks go through
    ``str()`` so an unexpected non-string value in the JSON cache (e.g. a list)
    cannot raise ``TypeError`` the way a set-membership test on it would.
    """
    if not cached:
        return False
    text = str(cached)
    return not text.startswith("AGENT ERROR") and text != "unknown"


def _http_error_detail(response) -> str:
    """Return the ``Detail:``/``Response:`` suffix describing a failed HTTP response."""
    if response is None:
        return ""
    try:
        body = response.json()
    except ValueError:
        return f" Response: {response.text[:500]}"
    if isinstance(body, dict):
        return f" Detail: {body.get('detail', response.text)}"
    return f" Response: {response.text[:500]}"


def _format_submission_success(response) -> str:
    """Turn the scoring server's reply to an accepted (2xx) submission into a status message."""
    try:
        result_data = response.json()
    except ValueError:
        # The POST itself succeeded. Never resubmit just because the body is not JSON.
        return f"Submission sent, but the server reply could not be parsed: {response.text[:500]}"
    if not isinstance(result_data, dict):
        return f"Submission sent, but the server reply had an unexpected format: {str(result_data)[:500]}"
    return (
        f"Submission Successful!\n"
        f"User: {result_data.get('username')}\n"
        f"Overall Score: {result_data.get('score', 'N/A')}% "
        f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
        f"Message: {result_data.get('message', 'No message received.')}"
    )


def _submit_with_retries(submit_url: str, submission_data: dict, attempts: int = 3) -> str:
    """POST ``submission_data`` to ``submit_url`` and return a status message.

    5xx responses, timeouts and other network errors are retried up to
    ``attempts`` times with a linear back-off (5 s, 10 s, ...). There is no
    sleep after the final attempt. Any other HTTP error (e.g. 4xx, which will
    not succeed on retry) is reported immediately.
    """
    last_error = None
    for attempt in range(1, attempts + 1):
        backoff = 5 * attempt if attempt < attempts else 0
        try:
            response = requests.post(submit_url, json=submission_data, timeout=120)
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            last_error = e
            status = e.response.status_code if e.response is not None else None
            shown_status = status if status is not None else "?"
            print(f"Submission attempt {attempt} failed: HTTP {shown_status}")
            if status is None or not 500 <= status < 600:
                return (
                    f"Submission Failed: Server responded with status {shown_status}."
                    f"{_http_error_detail(e.response)}"
                )
        except requests.exceptions.Timeout as e:
            last_error = e
            print(f"Submission attempt {attempt} timed out.")
        except requests.exceptions.RequestException as e:
            last_error = e
            print(f"Submission attempt {attempt} network error: {e}")
        else:
            return _format_submission_success(response)
        if backoff:
            time.sleep(backoff)
    return f"Submission Failed after retries: {last_error}."


def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Fetch all questions, answer them with the agent, and submit the answers.

    This is the callback of the "Run Evaluation & Submit All Answers" button.
    After every agent run, answers are cached via ``_save_cache`` so an
    interrupted run can resume. A results CSV is written to
    ``RESULTS_CSV_PATH``.

    Args:
        profile: The logged-in Hugging Face profile, or ``None`` when the user
            is not logged in.

    Returns:
        A ``(status_message, results_dataframe, csv_path)`` tuple for the three
        Gradio outputs. The last two are ``None`` when the run aborts before any
        question is processed.
    """
    if not profile:
        return "Please Login to Hugging Face with the button.", None, None
    username = f"{profile.username}"
    print(f"User logged in: {username}")

    space_id = os.getenv("SPACE_ID")
    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"

    try:
        agent = OpenRouterAgent()
    except Exception as e:
        print(f"Error instantiating agent: {e}")
        return f"Error initializing agent: {e}", None, None

    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
    print(agent_code)

    print(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
    except requests.exceptions.RequestException as e:
        return f"Error fetching questions: {e}", None, None
    except Exception as e:
        return f"An unexpected error occurred fetching questions: {e}", None, None
    # The loop below expects a list of dicts. Reject anything else up front.
    if not isinstance(questions_data, list) or not questions_data:
        return "Fetched questions list is empty or invalid format.", None, None
    total = len(questions_data)
    print(f"Fetched {total} questions.")

    results_log = []
    answers_payload = []
    cache = _load_cache()
    if cache:
        print(f"Loaded {len(cache)} cached answers from {ANSWER_CACHE_PATH}")

    print(f"Running agent on {total} questions...")
    for idx, item in enumerate(questions_data, 1):
        if not isinstance(item, dict):
            print(f"Skipping malformed question item: {item!r}")
            continue
        task_id = item.get("task_id")
        question_text = item.get("question")
        if not task_id or question_text is None:
            print(f"Skipping item with missing task_id or question: {item}")
            continue
        print(f"\n=== [{idx}/{total}] task_id={task_id} ===")

        ran_agent = False
        cached = cache.get(task_id)
        if _is_reusable_cached_answer(cached):
            submitted_answer = cached
            print(f"(cache hit) {str(submitted_answer)[:80]}")
        else:
            ran_agent = True
            try:
                submitted_answer = agent(question_text, task_id=task_id)
            except Exception as e:
                print(f"Error running agent on task {task_id}: {e}")
                submitted_answer = f"AGENT ERROR: {e}"
            cache[task_id] = submitted_answer
            _save_cache(cache)  # Persist after each run so a crash loses at most one answer.

        answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
        results_log.append(
            {"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer}
        )
        # Pacing exists for the rate-limited model calls; cache hits make none.
        if ran_agent and INTER_QUESTION_SLEEP > 0 and idx < total:
            time.sleep(INTER_QUESTION_SLEEP)

    df = pd.DataFrame(results_log)
    df.to_csv(RESULTS_CSV_PATH, index=False)
    print(f"Results CSV written to {RESULTS_CSV_PATH}")
    if not answers_payload:
        return "Agent did not produce any answers to submit.", df, RESULTS_CSV_PATH

    submission_data = {
        "username": username.strip(),
        "agent_code": agent_code,
        "answers": answers_payload,
    }
    print(f"Submitting {len(answers_payload)} answers for user '{username}'...")
    return _submit_with_retries(submit_url, submission_data), df, RESULTS_CSV_PATH
# --- Gradio UI ---
# Setup instructions rendered under the page title.
_SETUP_MARKDOWN = """
**Setup**
1. Add a Space secret named `OPENROUTER_API_KEY` (free at [openrouter.ai/keys](https://openrouter.ai/keys)).
2. *Optional but recommended:* `TAVILY_API_KEY` for better search.
3. Optional: `HF_TOKEN` for Whisper audio transcription via HF Inference API.
4. Optional env vars: `OPENROUTER_MODELS` (comma-separated fleet), `OPENROUTER_VISION_MODEL`.
5. Log in to Hugging Face below and click **Run Evaluation & Submit All Answers**.
Tools: `web_search`, `fetch_url`, `wikipedia`, `python`, `get_task_file`,
`transcribe_audio` (HF Whisper), `view_image` (Gemini Flash via OpenRouter), `youtube_transcript`.
Model fleet falls through automatically when one rate-limits.
"""

# Page layout, top to bottom: title, setup notes, login, the run button, and the
# three outputs that run_and_submit_all fills (status text, answers table, CSV file).
with gr.Blocks() as demo:
    gr.Markdown("# GAIA Agent (OpenRouter) — Evaluation Runner")
    gr.Markdown(_SETUP_MARKDOWN)
    gr.LoginButton()
    run_button = gr.Button("Run Evaluation & Submit All Answers")
    status_output = gr.Textbox(
        label="Run Status / Submission Result",
        lines=5,
        interactive=False,
    )
    results_table = gr.DataFrame(
        label="Questions and Agent Answers",
        wrap=True,
    )
    results_csv = gr.File(
        label="Download Results CSV (paste back to me for tuning)",
    )
    # No explicit inputs: the callback's only parameter is the gr.OAuthProfile
    # from the login button above.
    run_button.click(
        fn=run_and_submit_all,
        outputs=[status_output, results_table, results_csv],
    )
if __name__ == "__main__":
    # Startup diagnostics: report Space metadata and any missing API keys,
    # then launch the Gradio app.
    banner_text = " App Starting "
    print("\n" + "-" * 30 + banner_text + "-" * 30)

    space_host_startup = os.getenv("SPACE_HOST")
    space_id_startup = os.getenv("SPACE_ID")

    print(
        f"✅ SPACE_HOST found: {space_host_startup}"
        if space_host_startup
        else "ℹ️ SPACE_HOST not found (running locally?)."
    )

    if not space_id_startup:
        print("ℹ️ SPACE_ID not found (running locally?).")
    else:
        print(f"✅ SPACE_ID found: {space_id_startup}")
        print(f" Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")

    # (environment variable, message printed when it is unset), in report order.
    key_warnings = (
        ("OPENROUTER_API_KEY", "⚠️ OPENROUTER_API_KEY is not set. Set it before running evaluation."),
        ("TAVILY_API_KEY", "ℹ️ TAVILY_API_KEY not set — search will use DuckDuckGo (less reliable)."),
        ("HF_TOKEN", "ℹ️ HF_TOKEN not set — audio transcription may rate-limit on cold starts."),
    )
    for env_name, warning in key_warnings:
        if not os.getenv(env_name):
            print(warning)

    print("-" * (60 + len(banner_text)) + "\n")
    demo.launch(debug=True, share=False)