avi080704's picture
Update app.py
bdc34ca verified
Raw
History Blame Contribute Delete
42.7 kB
import os
import re
import json
import io
import time
import base64
import traceback
import contextlib
import tempfile
import mimetypes
from urllib.parse import urlparse, parse_qs
import gradio as gr
import requests
import pandas as pd
# --- Constants ---
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
# Fleet of free OpenRouter models. Tried in order. When one rate-limits or errors,
# we fall through to the next one. Mix of strong reasoning + tool use.
TEXT_MODELS = [
m.strip() for m in os.getenv(
"OPENROUTER_MODELS",
# Currently-valid free OpenRouter models (verify at https://openrouter.ai/models?q=free).
"meta-llama/llama-3.3-70b-instruct:free,"
"mistralai/mistral-small-3.2-24b-instruct:free,"
"google/gemini-2.0-flash-exp:free,"
"qwen/qwen-2.5-72b-instruct:free,"
"deepseek/deepseek-r1:free,"
"deepseek/deepseek-chat:free"
).split(",")
if m.strip()
]
# Vision-capable free model. Gemini Flash is multimodal and free on OpenRouter.
VISION_MODEL = os.getenv("OPENROUTER_VISION_MODEL", "google/gemini-2.0-flash-exp:free")
MAX_TOOL_ITERATIONS = 7
TOOL_RESULT_MAX_CHARS = 3500
ANSWER_CACHE_PATH = os.getenv("ANSWER_CACHE_PATH", "/tmp/answers_cache.json")
RESULTS_CSV_PATH = "/tmp/gaia_results.csv"
INTER_QUESTION_SLEEP = float(os.getenv("INTER_QUESTION_SLEEP", "2"))
INTER_TOOL_SLEEP = float(os.getenv("INTER_TOOL_SLEEP", "0.5"))
# Track downloaded task files so vision/audio tools can re-use them by task_id.
_TASK_FILE_CACHE: dict[str, dict] = {}
# ---------------------------------------------------------------------------
# Tool implementations
# ---------------------------------------------------------------------------
def tool_web_search(query: str, max_results: int = 5) -> str:
"""Web search. Tries Tavily first, falls back to DuckDuckGo."""
tavily_key = os.getenv("TAVILY_API_KEY")
if tavily_key:
try:
from tavily import TavilyClient
client = TavilyClient(api_key=tavily_key)
res = client.search(
query=query,
max_results=max_results,
search_depth="basic",
include_answer=True,
)
lines = ["[provider: tavily]"]
if res.get("answer"):
lines.append(f"Answer: {res['answer']}")
for r in res.get("results", [])[:max_results]:
lines.append(
f"- {r.get('title', '')}\n {r.get('url', '')}\n {r.get('content', '')[:400]}"
)
if len(lines) > 1:
return "\n".join(lines)
except Exception as e:
print(f"tavily search failed, falling back to DDG: {e}")
else:
print("[search] TAVILY_API_KEY not set; using DDG.")
try:
from duckduckgo_search import DDGS
results = ["[provider: duckduckgo]"]
with DDGS() as ddgs:
for r in ddgs.text(query, max_results=max_results):
results.append(
f"- {r.get('title', '')}\n {r.get('href', '')}\n {r.get('body', '')[:400]}"
)
if len(results) == 1:
return "[provider: duckduckgo] No results."
return "\n".join(results)
except Exception as e:
return f"web_search error: {e}"
def tool_fetch_url(url: str, max_chars: int = 3500) -> str:
"""Fetch a URL and return readable text (HTML stripped)."""
try:
from bs4 import BeautifulSoup
headers = {
"User-Agent": (
"Mozilla/5.0 (compatible; GAIA-Agent/1.0; "
"+https://huggingface.co/learn/agents-course)"
)
}
resp = requests.get(url, headers=headers, timeout=20)
resp.raise_for_status()
ctype = resp.headers.get("Content-Type", "")
if "html" in ctype or url.endswith((".html", ".htm")) or "<html" in resp.text[:500].lower():
soup = BeautifulSoup(resp.text, "lxml")
for tag in soup(["script", "style", "noscript"]):
tag.decompose()
text = soup.get_text(separator="\n")
else:
text = resp.text
text = re.sub(r"\n\s*\n+", "\n\n", text).strip()
if len(text) > max_chars:
text = text[:max_chars] + "\n...[truncated]"
return text
except Exception as e:
return f"fetch_url error: {e}"
def tool_wikipedia(query: str, sentences: int = 6) -> str:
"""Look up a topic on Wikipedia and return a summary."""
try:
import wikipedia
wikipedia.set_lang("en")
try:
return wikipedia.summary(query, sentences=sentences, auto_suggest=True, redirect=True)
except wikipedia.DisambiguationError as de:
options = ", ".join(de.options[:8])
return f"Disambiguation. Options: {options}"
except wikipedia.PageError:
hits = wikipedia.search(query, results=5)
if not hits:
return "No Wikipedia page found."
return wikipedia.summary(hits[0], sentences=sentences, auto_suggest=False, redirect=True)
except Exception as e:
return f"wikipedia error: {e}"
def tool_python(code: str) -> str:
"""Run a small Python snippet and return stdout (or the value of `result`)."""
buf = io.StringIO()
local_ns: dict = {}
try:
with contextlib.redirect_stdout(buf):
exec(code, {"__builtins__": __builtins__}, local_ns)
out = buf.getvalue().strip()
if not out and "result" in local_ns:
out = str(local_ns["result"])
return (out or "(no output)")[:2500]
except Exception as e:
return f"python error: {e}\n{traceback.format_exc(limit=2)}"
def _extract_youtube_id(url: str) -> str | None:
try:
u = urlparse(url)
if "youtu.be" in u.netloc:
return u.path.lstrip("/").split("/")[0] or None
if "youtube.com" in u.netloc:
if u.path == "/watch":
return parse_qs(u.query).get("v", [None])[0]
if u.path.startswith("/embed/") or u.path.startswith("/shorts/"):
return u.path.split("/")[2]
except Exception:
pass
return None
def tool_youtube_transcript(url: str, max_chars: int = 3500) -> str:
"""Fetch the spoken transcript of a YouTube video."""
try:
from youtube_transcript_api import YouTubeTranscriptApi
vid = _extract_youtube_id(url) or url.strip()
try:
data = YouTubeTranscriptApi.get_transcript(vid, languages=["en", "en-US", "en-GB"])
except Exception:
tlist = YouTubeTranscriptApi.list_transcripts(vid)
t = next(iter(tlist), None)
data = t.fetch() if t else []
text = " ".join(seg.get("text", "") for seg in data).strip()
text = re.sub(r"\s+", " ", text)
if len(text) > max_chars:
text = text[:max_chars] + " ...[truncated]"
return text or "(empty transcript)"
except Exception as e:
return f"youtube_transcript error: {e}"
def _hf_inference(model: str, data: bytes, content_type: str) -> str:
"""Call HF Inference API with raw bytes (used for Whisper audio transcription)."""
hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
headers = {"Content-Type": content_type}
if hf_token:
headers["Authorization"] = f"Bearer {hf_token}"
url = f"https://api-inference.huggingface.co/models/{model}"
# HF inference can be cold-started; retry a few times.
for attempt in range(3):
resp = requests.post(url, headers=headers, data=data, timeout=120)
if resp.status_code == 503:
# Model loading — wait per estimated_time.
try:
wait = float(resp.json().get("estimated_time", 10))
except Exception:
wait = 10
wait = min(max(wait, 3), 30)
print(f"HF model {model} loading; waiting {wait}s...")
time.sleep(wait)
continue
resp.raise_for_status()
return resp.text
raise RuntimeError(f"HF model {model} not ready after retries")
def tool_transcribe_audio(task_id: str) -> str:
"""Transcribe an attached audio file using HF Whisper Inference API."""
try:
info = _TASK_FILE_CACHE.get(task_id)
if not info:
tool_get_task_file(task_id)
info = _TASK_FILE_CACHE.get(task_id)
if not info or not os.path.exists(info.get("path", "")):
return "transcribe_audio error: no local file for task"
path = info["path"]
ext = os.path.splitext(path)[1].lower().lstrip(".")
ctype_map = {
"mp3": "audio/mpeg", "wav": "audio/wav", "m4a": "audio/mp4",
"ogg": "audio/ogg", "flac": "audio/flac", "webm": "audio/webm",
}
ctype = ctype_map.get(ext, "audio/mpeg")
with open(path, "rb") as f:
data = f.read()
raw = _hf_inference("openai/whisper-large-v3", data, ctype)
try:
obj = json.loads(raw)
if isinstance(obj, dict) and "text" in obj:
text = obj["text"]
elif isinstance(obj, list) and obj and "text" in obj[0]:
text = obj[0]["text"]
else:
text = raw
except Exception:
text = raw
text = (text or "").strip()
if len(text) > 4000:
text = text[:4000] + " ...[truncated]"
return text or "(empty transcript)"
except Exception as e:
return f"transcribe_audio error: {e}"
def tool_view_image(task_id: str, question: str = "") -> str:
"""Inspect an image attached to a GAIA task using a vision-capable LLM via OpenRouter."""
try:
from openai import OpenAI
info = _TASK_FILE_CACHE.get(task_id)
if not info:
tool_get_task_file(task_id)
info = _TASK_FILE_CACHE.get(task_id)
if not info or not os.path.exists(info.get("path", "")):
return "view_image error: no local file for task"
suffix = os.path.splitext(info["path"])[1].lower().lstrip(".")
if suffix == "jpg":
suffix = "jpeg"
if suffix not in {"png", "jpeg", "gif", "webp"}:
return f"view_image error: unsupported image type .{suffix}"
with open(info["path"], "rb") as f:
b64 = base64.b64encode(f.read()).decode("ascii")
data_url = f"data:image/{suffix};base64,{b64}"
prompt = (
question.strip()
or "Describe this image in detail, including any text, numbers, or symbols visible."
)
client = OpenAI(
base_url=OPENROUTER_BASE_URL,
api_key=os.getenv("OPENROUTER_API_KEY"),
)
resp = client.chat.completions.create(
model=VISION_MODEL,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": data_url}},
],
}
],
temperature=0.0,
max_tokens=600,
extra_headers={
"HTTP-Referer": "https://huggingface.co/learn/agents-course",
"X-Title": "GAIA Agent",
},
)
return (resp.choices[0].message.content or "").strip()
except Exception as e:
return f"view_image error: {e}"
def tool_get_task_file(task_id: str, api_url: str = DEFAULT_API_URL) -> str:
"""Download the file attached to a task and return a text preview."""
try:
resp = requests.get(f"{api_url}/files/{task_id}", timeout=30)
if resp.status_code == 404:
return (
"NO_FILE: This task has no attached file. Do not call get_task_file again. "
"Answer using web_search / wikipedia / python / your own knowledge."
)
resp.raise_for_status()
ctype = resp.headers.get("Content-Type", "")
cdisp = resp.headers.get("Content-Disposition", "")
fname_match = re.search(r'filename="?([^"]+)"?', cdisp)
fname = fname_match.group(1) if fname_match else f"{task_id}"
suffix = os.path.splitext(fname)[1].lower()
tmp = tempfile.NamedTemporaryFile(prefix=f"{task_id}_", suffix=suffix, delete=False)
tmp.write(resp.content)
tmp.close()
_TASK_FILE_CACHE[task_id] = {
"path": tmp.name,
"name": fname,
"ctype": ctype,
"size": len(resp.content),
}
info = (
f"File: {fname}\nContent-Type: {ctype}\nSize: {len(resp.content)} bytes\n"
)
if suffix in {".txt", ".md", ".csv", ".json", ".py", ".tsv", ".log", ".xml", ".html"}:
try:
text = resp.content.decode("utf-8", errors="replace")
except Exception:
text = resp.text
return info + "\n--- preview ---\n" + text[:3000]
if suffix in {".xlsx", ".xls"}:
try:
df = pd.read_excel(tmp.name)
csv = df.to_csv(index=False)
if len(csv) > 3000:
csv = csv[:3000] + "\n...[truncated]"
return info + "\n--- excel as csv ---\n" + csv
except Exception as e:
return info + f"\n(excel parse error: {e})"
if suffix == ".pdf":
try:
from pypdf import PdfReader
reader = PdfReader(tmp.name)
pages = [p.extract_text() or "" for p in reader.pages[:6]]
return info + "\n--- pdf text ---\n" + "\n".join(pages)[:3000]
except Exception as e:
return info + f"\n(pdf parse error: {e})"
if suffix in {".mp3", ".wav", ".m4a", ".ogg", ".flac", ".webm"}:
return info + "\nAudio file. Call transcribe_audio(task_id) to read it."
if suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp"}:
return info + "\nImage file. Call view_image(task_id, question='...')."
return info + "\n(binary file; no preview)"
except Exception as e:
return f"get_task_file error: {e}"
# ---------------------------------------------------------------------------
# Tool schema (OpenAI-compatible)
# ---------------------------------------------------------------------------
TOOLS_SPEC = [
{
"type": "function",
"function": {
"name": "web_search",
"description": "Search the web (Tavily preferred, DuckDuckGo fallback). Returns titles, URLs, snippets, and Tavily's synthesized answer.",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string"},
"max_results": {"type": "integer"},
},
"required": ["query"],
},
},
},
{
"type": "function",
"function": {
"name": "fetch_url",
"description": "Fetch a URL and return cleaned page text. Use after web_search to read a result page.",
"parameters": {
"type": "object",
"properties": {
"url": {"type": "string"},
"max_chars": {"type": "integer"},
},
"required": ["url"],
},
},
},
{
"type": "function",
"function": {
"name": "wikipedia",
"description": "Get a Wikipedia summary for a person, place, work, or topic. Use FIRST for biographical or list questions.",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string"},
"sentences": {"type": "integer"},
},
"required": ["query"],
},
},
},
{
"type": "function",
"function": {
"name": "python",
"description": "Execute a Python snippet for math, sums, dates, sorting, alphabetizing, parsing, string reversal, set logic. Use print() or assign to `result`.",
"parameters": {
"type": "object",
"properties": {"code": {"type": "string"}},
"required": ["code"],
},
},
},
{
"type": "function",
"function": {
"name": "get_task_file",
"description": "Download the file attached to a GAIA task by task_id. Returns NO_FILE if no file exists.",
"parameters": {
"type": "object",
"properties": {"task_id": {"type": "string"}},
"required": ["task_id"],
},
},
},
{
"type": "function",
"function": {
"name": "transcribe_audio",
"description": "Transcribe an attached audio file (.mp3/.wav/.m4a/.ogg/.flac) using Whisper.",
"parameters": {
"type": "object",
"properties": {"task_id": {"type": "string"}},
"required": ["task_id"],
},
},
},
{
"type": "function",
"function": {
"name": "view_image",
"description": "Inspect an attached image (.png/.jpg/.gif/.webp) using a vision model. Pass a focused question.",
"parameters": {
"type": "object",
"properties": {
"task_id": {"type": "string"},
"question": {"type": "string"},
},
"required": ["task_id"],
},
},
},
{
"type": "function",
"function": {
"name": "youtube_transcript",
"description": "Fetch the spoken transcript of a YouTube video given its URL. Only captures speech, not visual content.",
"parameters": {
"type": "object",
"properties": {
"url": {"type": "string"},
"max_chars": {"type": "integer"},
},
"required": ["url"],
},
},
},
]
TOOL_FUNCTIONS = {
"web_search": lambda args: tool_web_search(args["query"], int(args.get("max_results", 5))),
"fetch_url": lambda args: tool_fetch_url(args["url"], int(args.get("max_chars", 3500))),
"wikipedia": lambda args: tool_wikipedia(args["query"], int(args.get("sentences", 6))),
"python": lambda args: tool_python(args["code"]),
"get_task_file": lambda args: tool_get_task_file(args["task_id"]),
"transcribe_audio": lambda args: tool_transcribe_audio(args["task_id"]),
"view_image": lambda args: tool_view_image(args["task_id"], args.get("question", "")),
"youtube_transcript": lambda args: tool_youtube_transcript(
args["url"], int(args.get("max_chars", 3500))
),
}
SYSTEM_PROMPT = """You are a careful research agent answering GAIA benchmark questions.
Tools: web_search, fetch_url, wikipedia, python, get_task_file, transcribe_audio, view_image, youtube_transcript.
Decision rules:
- If the question references "attached file/image/audio/Excel/PDF/.mp3/.xlsx/.py/recording/photo/image", call get_task_file FIRST.
- Audio (.mp3, .wav, etc.) -> transcribe_audio(task_id) after get_task_file.
- Image (.png, .jpg, etc.) -> view_image(task_id, question="<focused question>") after get_task_file.
- Excel/CSV/text/PDF — the get_task_file preview is enough; use python to compute on it.
- If get_task_file returns NO_FILE, do NOT call it again.
- For YouTube URLs, use youtube_transcript(url) directly. (No get_task_file needed.) The transcript is speech only — for visual questions, give your best estimate.
- For factual lookups about people, places, artists, albums, animals, Wikipedia featured articles: START with wikipedia.
- For everything else research-y: web_search then fetch_url the most relevant URL.
- Use python for ALL arithmetic, sums, date math, sorting, alphabetizing, set/group operations, string reversal. Never compute by hand.
- For Excel/CSV totals, after get_task_file shows the data, ALWAYS use python to compute the sum precisely.
Be decisive — don't repeat the same tool with the same args. You have 7 tool turns.
ANSWER FORMATTING (the grader does an exact-match comparison; sentence answers ALWAYS lose):
Worked examples of correct GAIA format:
- Q: "How many albums..." -> "3" (NOT "3 albums" or "There were 3 albums")
- Q: "Express your answer in USD with two decimal places" -> "89706.00"
- Q: "Give the IOC country code" -> "MLT"
- Q: "Just the city name without abbreviations" -> "Saint Petersburg"
- Q: "Give only the first name" -> "Bartek"
- Q: "Comma separated list ... in alphabetical order" -> "broccoli, celery, fresh basil, lettuce, sweet potatoes, zucchini"
- Q: "Under what NASA award number..." -> "80NSSC21K1130"
- Q: opposite of "left" -> "right"
Strict rules:
- Reply with ONLY the answer. No preamble. No explanation. No quotes. No trailing period.
- Do NOT include "FINAL ANSWER", "Answer:", or any label.
- Numbers: digits only, no commas, no units, no $ — UNLESS the question asks for the unit.
- Currency "two decimal places": e.g. "89706.00".
- Strings: no leading articles ("the", "a") unless required; no abbreviations ("Saint" not "St."); digits as digits.
- Names: read the question carefully ("first name only" / "last name only" / "surname" / "full name").
- Lists: comma-separated, ONE space after each comma. Sort if asked.
"""
def _maybe_reverse_text(question: str) -> str:
"""If the question text looks reversed, flip it."""
q = question.strip()
if not q:
return question
starts_with_punct = q[0] in ".,;:!?"
reversed_text = q[::-1]
common = (" the ", " of ", " and ", " to ", " is ", " a ", " in ", " for ")
hits = sum(1 for w in common if w in (" " + reversed_text.lower() + " "))
if starts_with_punct and hits >= 2:
return reversed_text
return question
# ---------------------------------------------------------------------------
# Agent
# ---------------------------------------------------------------------------
class OpenRouterAgent:
def __init__(self):
try:
from openai import OpenAI
except ImportError as e:
raise RuntimeError("openai package not installed") from e
api_key = os.getenv("OPENROUTER_API_KEY")
if not api_key:
raise RuntimeError(
"OPENROUTER_API_KEY is not set. Get one free at https://openrouter.ai/keys "
"and add it as a Secret in your HF Space settings."
)
self.client = OpenAI(base_url=OPENROUTER_BASE_URL, api_key=api_key)
self.models = list(TEXT_MODELS)
self.exhausted: set[str] = set()
self.extra_headers = {
"HTTP-Referer": "https://huggingface.co/learn/agents-course",
"X-Title": "GAIA Agent",
}
print(f"OpenRouterAgent initialized with model fleet: {self.models}")
def _chat(self, messages, use_tools: bool = True, max_tokens: int = 800):
"""Try each model in the fleet. Falls through on rate limit / error."""
last_error: Exception | None = None
for m in self.models:
if m in self.exhausted:
continue
for attempt in range(2):
try:
kwargs = dict(
model=m,
messages=messages,
temperature=0.0,
max_tokens=max_tokens,
extra_headers=self.extra_headers,
)
if use_tools:
kwargs["tools"] = TOOLS_SPEC
kwargs["tool_choice"] = "auto"
return self.client.chat.completions.create(**kwargs)
except Exception as e:
msg = str(e)
last_error = e
is_rate = "429" in msg or "rate" in msg.lower() or "limit" in msg.lower()
is_quota = ("daily" in msg.lower() or "quota" in msg.lower()
or "exhausted" in msg.lower())
# 404 / "No endpoints found" / "model not found" -> dead model, never retry.
is_dead = (
"404" in msg
or "no endpoints" in msg.lower()
or "not found" in msg.lower()
or "model_not_found" in msg.lower()
)
if is_dead:
print(f"[{m}] model unavailable (404 / no endpoints); marking exhausted.")
self.exhausted.add(m)
break
if is_rate and is_quota:
print(f"[{m}] daily quota exhausted; switching model.")
self.exhausted.add(m)
break
if is_rate:
wait = 4 * (attempt + 1)
print(f"[{m}] rate-limited; sleeping {wait}s (attempt {attempt + 1}/2)")
time.sleep(wait)
continue
print(f"[{m}] API error: {repr(e)[:240]} — trying next model.")
break
err_str = repr(last_error) if last_error else "no error captured"
raise RuntimeError(f"All OpenRouter models failed. {err_str}")
def __call__(self, question: str, task_id: str | None = None) -> str:
flipped = _maybe_reverse_text(question)
if flipped != question:
print("[reversed-text detected, flipping question]")
question = flipped
user_content = question
if task_id:
user_content = f"task_id: {task_id}\n\nQuestion: {question}"
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_content},
]
collected_facts: list[str] = []
seen_calls: set[str] = set()
for step in range(MAX_TOOL_ITERATIONS):
try:
resp = self._chat(messages, use_tools=True, max_tokens=800)
except Exception as e:
print(f"chat at step {step} failed: {e}")
break
msg = resp.choices[0].message
tool_calls = getattr(msg, "tool_calls", None)
if not tool_calls:
answer = (msg.content or "").strip()
if answer:
return self._finalize(answer, question, collected_facts)
break
messages.append(
{
"role": "assistant",
"content": msg.content or "",
"tool_calls": [
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in tool_calls
],
}
)
for tc in tool_calls:
name = tc.function.name
try:
args = json.loads(tc.function.arguments or "{}")
except json.JSONDecodeError:
args = {}
call_key = f"{name}|{json.dumps(args, sort_keys=True, default=str)[:300]}"
if call_key in seen_calls:
print(f"[tool] {name}({str(args)[:120]}) [DUPLICATE — skipping]")
result = "DUPLICATE_CALL: you already called this with the same args. Try a different query, a different tool, or give your final answer."
else:
seen_calls.add(call_key)
fn = TOOL_FUNCTIONS.get(name)
print(f"[tool] {name}({str(args)[:200]})")
if fn is None:
result = f"unknown tool: {name}"
else:
try:
result = fn(args)
except Exception as e:
result = f"{name} error: {e}"
if not isinstance(result, str):
result = str(result)
if len(result) > TOOL_RESULT_MAX_CHARS:
result = result[:TOOL_RESULT_MAX_CHARS] + "\n...[truncated]"
collected_facts.append(f"[{name}] {result[:1200]}")
messages.append(
{
"role": "tool",
"tool_call_id": tc.id,
"name": name,
"content": result,
}
)
if INTER_TOOL_SLEEP > 0:
time.sleep(INTER_TOOL_SLEEP)
return self._synthesize(question, collected_facts)
def _synthesize(self, question: str, facts: list[str]) -> str:
"""Final answer pass on a short context. No tools."""
joined = "\n\n".join(facts[-8:])
if len(joined) > 5000:
joined = joined[-5000:]
synth_messages = [
{
"role": "system",
"content": (
"You are a strict GAIA answer formatter. Read the question and the research "
"notes, then output ONLY the final answer string. No preamble, no labels, no "
"explanation, no quotes, no trailing period. Match the question's required "
"format exactly. If notes are insufficient, give your single best guess based "
"on general knowledge. Never refuse, never apologize, never reply with empty."
),
},
{
"role": "user",
"content": (
f"Question:\n{question}\n\n"
f"Research notes:\n{joined or '(no notes)'}\n\nFinal answer:"
),
},
]
try:
resp = self._chat(synth_messages, use_tools=False, max_tokens=120)
return self._postprocess_answer(
(resp.choices[0].message.content or "").strip(), question
) or "unknown"
except Exception as e:
print(f"synthesis failed: {e}")
# Last-resort: tiny zero-shot guess
try:
resp = self._chat(
[
{"role": "system", "content": "Answer in 1-5 words. No explanation."},
{"role": "user", "content": question[:500]},
],
use_tools=False,
max_tokens=40,
)
return self._postprocess_answer(
(resp.choices[0].message.content or "").strip(), question
) or "unknown"
except Exception as e2:
print(f"last-resort guess failed: {e2}")
return "unknown"
def _finalize(self, raw: str, question: str, facts: list[str]) -> str:
cleaned = self._postprocess_answer(raw, question)
if not cleaned:
return self._synthesize(question, facts)
looks_sentence = (
len(cleaned.split()) > 12
or re.search(
r"\b(because|received|grant|seems|unable|sorry|cannot|provides|indicating|"
r"web_search|youtube_transcript|fetch_url)\b",
cleaned,
re.IGNORECASE,
)
)
if looks_sentence:
try:
resp = self._chat(
[
{
"role": "system",
"content": (
"Extract ONLY the final answer from the assistant text below, "
"matching the question's required format exactly. No preamble, "
"no explanation, no quotes, no trailing period, no labels."
),
},
{
"role": "user",
"content": f"Question: {question}\n\nAssistant text: {cleaned}\n\nFinal answer:",
},
],
use_tools=False,
max_tokens=80,
)
reformat = (resp.choices[0].message.content or "").strip()
reformat = self._postprocess_answer(reformat, question)
if reformat:
return reformat
except Exception as e:
print(f"reformat pass failed: {e}")
return cleaned
@staticmethod
def _postprocess_answer(text: str, question: str = "") -> str:
if not text:
return ""
text = text.strip()
text = re.sub(
r"^(final\s*answer|answer|the\s*answer\s*is)\s*[:\-]?\s*",
"",
text,
flags=re.IGNORECASE,
)
text = text.strip("`")
if len(text) >= 2 and text[0] == text[-1] and text[0] in {'"', "'"}:
text = text[1:-1].strip()
q_lower = question.lower()
wants_number = bool(
re.search(r"\bhow many\b|\bhow much\b|\bwhat number\b|\bcount\b", q_lower)
)
if wants_number and not re.fullmatch(r"-?\d+(\.\d+)?", text):
m = re.search(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
if m:
text = m.group(0)
if text.endswith(".") and " " not in text:
text = text[:-1]
return text.strip()
# ---------------------------------------------------------------------------
# Cache
# ---------------------------------------------------------------------------
def _load_cache() -> dict:
try:
with open(ANSWER_CACHE_PATH, "r", encoding="utf-8") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError):
return {}
def _save_cache(cache: dict) -> None:
try:
with open(ANSWER_CACHE_PATH, "w", encoding="utf-8") as f:
json.dump(cache, f, ensure_ascii=False, indent=2)
except Exception as e:
print(f"cache save error: {e}")
# ---------------------------------------------------------------------------
# Gradio submission flow
# ---------------------------------------------------------------------------
def run_and_submit_all(profile: gr.OAuthProfile | None):
space_id = os.getenv("SPACE_ID")
if profile:
username = f"{profile.username}"
print(f"User logged in: {username}")
else:
return "Please Login to Hugging Face with the button.", None, None
api_url = DEFAULT_API_URL
questions_url = f"{api_url}/questions"
submit_url = f"{api_url}/submit"
try:
agent = OpenRouterAgent()
except Exception as e:
print(f"Error instantiating agent: {e}")
return f"Error initializing agent: {e}", None, None
agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
print(agent_code)
print(f"Fetching questions from: {questions_url}")
try:
response = requests.get(questions_url, timeout=15)
response.raise_for_status()
questions_data = response.json()
if not questions_data:
return "Fetched questions list is empty or invalid format.", None, None
print(f"Fetched {len(questions_data)} questions.")
except requests.exceptions.RequestException as e:
return f"Error fetching questions: {e}", None, None
except Exception as e:
return f"An unexpected error occurred fetching questions: {e}", None, None
results_log = []
answers_payload = []
cache = _load_cache()
if cache:
print(f"Loaded {len(cache)} cached answers from {ANSWER_CACHE_PATH}")
print(f"Running agent on {len(questions_data)} questions...")
for idx, item in enumerate(questions_data, 1):
task_id = item.get("task_id")
question_text = item.get("question")
if not task_id or question_text is None:
print(f"Skipping item with missing task_id or question: {item}")
continue
print(f"\n=== [{idx}/{len(questions_data)}] task_id={task_id} ===")
cached = cache.get(task_id)
if cached and not str(cached).startswith("AGENT ERROR") and cached not in {"", "unknown"}:
submitted_answer = cached
print(f"(cache hit) {submitted_answer[:80]}")
else:
try:
submitted_answer = agent(question_text, task_id=task_id)
except Exception as e:
print(f"Error running agent on task {task_id}: {e}")
submitted_answer = f"AGENT ERROR: {e}"
cache[task_id] = submitted_answer
_save_cache(cache)
answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
results_log.append(
{"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer}
)
if INTER_QUESTION_SLEEP > 0 and idx < len(questions_data):
time.sleep(INTER_QUESTION_SLEEP)
if not answers_payload:
df = pd.DataFrame(results_log)
df.to_csv(RESULTS_CSV_PATH, index=False)
return "Agent did not produce any answers to submit.", df, RESULTS_CSV_PATH
df = pd.DataFrame(results_log)
df.to_csv(RESULTS_CSV_PATH, index=False)
print(f"Results CSV written to {RESULTS_CSV_PATH}")
submission_data = {
"username": username.strip(),
"agent_code": agent_code,
"answers": answers_payload,
}
print(f"Submitting {len(answers_payload)} answers for user '{username}'...")
last_error = None
for attempt in range(3):
try:
response = requests.post(submit_url, json=submission_data, timeout=120)
response.raise_for_status()
result_data = response.json()
final_status = (
f"Submission Successful!\n"
f"User: {result_data.get('username')}\n"
f"Overall Score: {result_data.get('score', 'N/A')}% "
f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
f"Message: {result_data.get('message', 'No message received.')}"
)
return final_status, df, RESULTS_CSV_PATH
except requests.exceptions.HTTPError as e:
status = e.response.status_code if e.response is not None else "?"
last_error = e
print(f"Submission attempt {attempt + 1} failed: HTTP {status}")
if status and 500 <= int(status) < 600:
time.sleep(5 * (attempt + 1))
continue
error_detail = f"Server responded with status {status}."
try:
error_detail += f" Detail: {e.response.json().get('detail', e.response.text)}"
except Exception:
error_detail += f" Response: {e.response.text[:500] if e.response is not None else ''}"
return f"Submission Failed: {error_detail}", df, RESULTS_CSV_PATH
except requests.exceptions.Timeout as e:
last_error = e
print(f"Submission attempt {attempt + 1} timed out.")
time.sleep(5 * (attempt + 1))
continue
except requests.exceptions.RequestException as e:
last_error = e
print(f"Submission attempt {attempt + 1} network error: {e}")
time.sleep(5 * (attempt + 1))
continue
return (
f"Submission Failed after retries: {last_error}.",
df,
RESULTS_CSV_PATH,
)
# --- Gradio UI ---
with gr.Blocks() as demo:
gr.Markdown("# GAIA Agent (OpenRouter) — Evaluation Runner")
gr.Markdown(
"""
**Setup**
1. Add a Space secret named `OPENROUTER_API_KEY` (free at [openrouter.ai/keys](https://openrouter.ai/keys)).
2. *Optional but recommended:* `TAVILY_API_KEY` for better search.
3. Optional: `HF_TOKEN` for Whisper audio transcription via HF Inference API.
4. Optional env vars: `OPENROUTER_MODELS` (comma-separated fleet), `OPENROUTER_VISION_MODEL`.
5. Log in to Hugging Face below and click **Run Evaluation & Submit All Answers**.
Tools: `web_search`, `fetch_url`, `wikipedia`, `python`, `get_task_file`,
`transcribe_audio` (HF Whisper), `view_image` (Gemini Flash via OpenRouter), `youtube_transcript`.
Model fleet falls through automatically when one rate-limits.
"""
)
gr.LoginButton()
run_button = gr.Button("Run Evaluation & Submit All Answers")
status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
results_csv = gr.File(label="Download Results CSV (paste back to me for tuning)")
run_button.click(fn=run_and_submit_all, outputs=[status_output, results_table, results_csv])
if __name__ == "__main__":
print("\n" + "-" * 30 + " App Starting " + "-" * 30)
space_host_startup = os.getenv("SPACE_HOST")
space_id_startup = os.getenv("SPACE_ID")
if space_host_startup:
print(f"✅ SPACE_HOST found: {space_host_startup}")
else:
print("ℹ️ SPACE_HOST not found (running locally?).")
if space_id_startup:
print(f"✅ SPACE_ID found: {space_id_startup}")
print(f" Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")
else:
print("ℹ️ SPACE_ID not found (running locally?).")
if not os.getenv("OPENROUTER_API_KEY"):
print("⚠️ OPENROUTER_API_KEY is not set. Set it before running evaluation.")
if not os.getenv("TAVILY_API_KEY"):
print("ℹ️ TAVILY_API_KEY not set — search will use DuckDuckGo (less reliable).")
if not os.getenv("HF_TOKEN"):
print("ℹ️ HF_TOKEN not set — audio transcription may rate-limit on cold starts.")
print("-" * (60 + len(" App Starting ")) + "\n")
demo.launch(debug=True, share=False)