arc-agi-2-solver / program_synthesis.py
Interstellar007's picture
Upload program_synthesis.py
a7e2535 verified
"""
SOAR-style Program Synthesis for ARC-AGI
Evolutionary search: Sample programs → Execute → Refine → Vote
Based on:
- SOAR (Pourcel et al., 2507.14172)
- Product of Experts (Franzen et al., 2505.07859)
- CodeIt (Butt et al., 2402.04858)
"""
import copy
import json
import time
import random
import traceback
import signal
from typing import List, Dict, Tuple, Optional, Any
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, TimeoutError as FuturesTimeout
import numpy as np
from arc_data import (
grids_equal, grid_to_string, grid_to_numpy, numpy_to_grid,
D8_TRANSFORMS, augment_task, reverse_d8, reverse_color_permutation,
create_color_permutation, apply_color_permutation
)
# ============================================================
# Safe code execution with timeout
# ============================================================
def _execute_transform(code: str, input_grid: List[List[int]], timeout_sec: float = 5.0) -> Optional[List[List[int]]]:
"""Execute a transform function safely with timeout."""
try:
namespace = {}
exec(code, namespace)
if "transform" not in namespace:
return None
transform_fn = namespace["transform"]
result = transform_fn(copy.deepcopy(input_grid))
# Validate output
if not isinstance(result, list):
return None
if len(result) == 0:
return None
for row in result:
if not isinstance(row, list):
return None
for cell in row:
if not isinstance(cell, (int, float)):
return None
if int(cell) < 0 or int(cell) > 9:
return None
# Normalize to int
result = [[int(c) for c in row] for row in result]
return result
except Exception:
return None
def evaluate_program_on_task(code: str, task: Dict) -> Tuple[float, Optional[List[List[int]]]]:
    """
    Score *code* on every training pair of *task* and run it on the first test input.

    Returns:
        (accuracy, test_output): accuracy is the fraction of training pairs
        whose predicted grid equals the expected output (0.0 when there are no
        training pairs); test_output is the program's prediction for
        ``task["test"][0]``, or None when the task has no test inputs or the
        program fails on it.
    """
    def _solves(demo: Dict) -> bool:
        prediction = _execute_transform(code, demo["input"])
        return prediction is not None and grids_equal(prediction, demo["output"])

    demos = task["train"]
    accuracy = sum(map(_solves, demos)) / len(demos) if demos else 0.0

    tests = task.get("test")
    test_output = _execute_transform(code, tests[0]["input"]) if tests else None
    return accuracy, test_output
# ============================================================
# Program generation prompts
# ============================================================
# Prompt for sampling a fresh program. Filled with str.format, so any literal
# brace added to this template must be doubled. Placeholders:
#   examples    - text from format_examples_for_prompt (train pairs + first test input)
#   code_prefix - start of the expected solution (imports + transform signature)
# NOTE(review): the prefix sits inside a *closed* ```python fence, and only the
# newly generated tokens are decoded, so the extracted program does not
# automatically contain the prefix's imports.
SAMPLING_PROMPT = """You are an expert at solving abstract reasoning puzzles. Given input-output grid examples, write a Python function that transforms the input grid to the output grid.
Each grid is a 2D list of integers 0-9, where each integer represents a color:
0=black, 1=blue, 2=red, 3=green, 4=yellow, 5=grey, 6=magenta, 7=orange, 8=cyan, 9=maroon
{examples}
Write a Python function `transform(input_grid: list[list[int]]) -> list[list[int]]` that correctly transforms any input to its output. You may use numpy. Think step by step about the pattern, then write the code.
```python
{code_prefix}
```"""
# Prompt for repairing a failed program. Filled with str.format. Placeholders:
#   examples       - same formatted examples as SAMPLING_PROMPT
#   previous_code  - the parent program being refined
#   error_feedback - per-example verdicts from get_error_feedback
#   code_prefix    - same solution prefix as SAMPLING_PROMPT
REFINEMENT_PROMPT = """The following Python function was supposed to transform input grids to output grids, but it has errors.
{examples}
Previous attempt:
```python
{previous_code}
```
{error_feedback}
Please fix the function to correctly handle all examples. Write the corrected `transform` function:
```python
{code_prefix}
```"""
def format_examples_for_prompt(task: Dict) -> str:
    """
    Render the training pairs (and the first test input, if any) as prompt text.

    Each grid is preceded by a header giving its height x width and printed
    one row per line as a Python list literal.
    """
    def dims(grid: List[List[int]]) -> str:
        return f"{len(grid)}x{len(grid[0])}"

    lines: List[str] = []
    for idx, example in enumerate(task["train"], start=1):
        lines.append(f"Example {idx}:")
        lines.append(f" Input ({dims(example['input'])}):")
        lines.extend(f" {row}" for row in example["input"])
        lines.append(f" Output ({dims(example['output'])}):")
        lines.extend(f" {row}" for row in example["output"])

    tests = task.get("test")
    if tests:
        query = tests[0]["input"]
        lines.append(f"\nTest Input ({dims(query)}):")
        lines.extend(f" {row}" for row in query)

    return "\n".join(lines)
def get_error_feedback(code: str, task: Dict) -> str:
    """
    Build a per-example report of how *code* performs on the training pairs of *task*.

    Each pair yields "Example k: CORRECT", "Example k: EXECUTION ERROR" (no
    valid grid was produced), or "Example k: WRONG OUTPUT" followed by the
    expected and actual grids. The report lines are newline-joined.
    """
    report = []
    for number, pair in enumerate(task["train"], start=1):
        predicted = _execute_transform(code, pair["input"])
        if predicted is None:
            report.append(f"Example {number}: EXECUTION ERROR")
            continue
        if grids_equal(predicted, pair["output"]):
            report.append(f"Example {number}: CORRECT")
            continue
        report += [
            f"Example {number}: WRONG OUTPUT",
            f" Expected: {pair['output']}",
            f" Got: {predicted}",
        ]
    return "\n".join(report)
# ============================================================
# Program extraction from LLM output
# ============================================================
def extract_python_code(text: str) -> Optional[str]:
    """
    Extract the Python source of a ``transform`` function from an LLM response.

    Strategies, tried in order:
      1. The first ```python fenced block, if it contains ``def transform``.
      2. Any fenced block (an optional leading ``python`` tag is dropped) that
         contains ``def transform``.
      3. Unfenced text: the ``def transform`` function itself, ending at the
         first non-blank line that is not indented, prefixed by any top-level
         import statements that appear earlier in the response.

    Returns:
        The extracted source, or None if the text contains no ``def transform``.
    """
    import ast  # only needed to recognise import lines in strategy 3

    # 1. An explicitly tagged python block.
    if "```python" in text:
        start = text.index("```python") + len("```python")
        end = text.find("```", start)
        if end != -1:
            code = text[start:end].strip()
            if "def transform" in code:
                return code

    # 2. Any fenced block: odd-numbered segments of a split on ``` are inside fences.
    if "```" in text:
        for segment in text.split("```")[1::2]:
            code = segment.strip()
            if code.startswith("python\n"):
                code = code[len("python\n"):]
            if "def transform" in code:
                return code

    # 3. Unfenced answer.
    if "def transform" not in text:
        return None
    start = text.index("def transform")
    lines = text[start:].split("\n")
    func_lines = [lines[0]]
    for line in lines[1:]:
        # A non-blank line at column 0 ends the function body. Keeping it
        # (e.g. trailing prose) would make the snippet fail to compile.
        if line.strip() and not line.startswith((" ", "\t")):
            break
        func_lines.append(line)
    code = "\n".join(func_lines).strip()

    # Carry over top-level imports written before the function (commonly
    # ``import numpy as np``); without them the program hits a NameError.
    header_imports = []
    for line in text[:start].split("\n"):
        stripped = line.rstrip()
        if not stripped or line.startswith((" ", "\t")):
            continue
        try:
            body = ast.parse(stripped).body
        except (SyntaxError, ValueError):
            continue  # prose or other non-Python text
        if len(body) == 1 and isinstance(body[0], (ast.Import, ast.ImportFrom)):
            header_imports.append(stripped)
    if header_imports:
        code = "\n".join(dict.fromkeys(header_imports)) + "\n\n" + code
    return code
# ============================================================
# Weighted Majority Voting (SOAR-style ensembling)
# ============================================================
def grid_to_hashable(grid: List[List[int]]) -> tuple:
    """Freeze *grid* into a tuple of row tuples so it can serve as a dict/vote key."""
    return tuple(map(tuple, grid))
def hashable_to_grid(h: tuple) -> List[List[int]]:
    """Thaw a tuple-of-tuples grid back into a list of mutable row lists."""
    return list(map(list, h))
def weighted_majority_vote(programs: List[Tuple[str, float, Optional[List[List[int]]]]],
                           top_k: int = 2) -> List[List[List[int]]]:
    """
    Rank candidate test outputs by accuracy-weighted votes (SOAR-style ensembling).

    Each entry of *programs* is ``(code, accuracy, test_output)``; entries whose
    output is None are ignored. An output's score is the sum of the training
    accuracies of every program that produced it. Outputs are ranked by score,
    then by raw vote count (so agreement still matters when every accuracy is
    0, as in the solver's fallback path), then by first appearance.

    Args:
        programs: Evaluated programs.
        top_k: Maximum number of outputs to return; values <= 0 yield [].

    Returns:
        Up to *top_k* distinct output grids, best first.
    """
    if top_k <= 0:
        return []

    scores: Dict[tuple, float] = defaultdict(float)
    counts: Counter = Counter()
    for _code, accuracy, test_output in programs:
        if test_output is None:
            continue
        key = grid_to_hashable(test_output)
        scores[key] += accuracy
        counts[key] += 1

    # sorted() is stable and dicts keep insertion order, so any remaining ties
    # fall back to the order in which outputs were first seen.
    ranked = sorted(scores, key=lambda k: (-scores[k], -counts[k]))
    return [hashable_to_grid(key) for key in ranked[:top_k]]
# ============================================================
# Program Synthesis Search (simplified SOAR)
# ============================================================
class ProgramSynthesisEngine:
    """
    Evolutionary program synthesis engine for ARC tasks (simplified SOAR).

    Samples candidate ``transform`` programs from an LLM, refines them using
    execution feedback on the training pairs, and combines their test
    predictions by accuracy-weighted majority vote.

    Attributes:
        model, tokenizer: Hugging Face-style model/tokenizer pair (presumably a
            causal LM). When either is None, generation yields nothing and the
            engine produces no programs.
        max_samples: Default number of sampling attempts per task.
        max_refinements: Default number of refinement attempts per task.
        temperature: Sampling temperature passed to ``model.generate``.
        programs_cache: Reserved dict; not read or written after ``__init__``.
    """

    # Start of the expected solution, shown to the model in both prompts.
    _CODE_PREFIX = "import numpy as np\n\ndef transform(input_grid: list[list[int]]) -> list[list[int]]:"

    def __init__(self, model=None, tokenizer=None, max_samples: int = 50,
                 max_refinements: int = 50, temperature: float = 0.8):
        self.model = model
        self.tokenizer = tokenizer
        self.max_samples = max_samples
        self.max_refinements = max_refinements
        self.temperature = temperature
        self.programs_cache = {}

    def generate_program(self, prompt: str, max_new_tokens: int = 1024) -> Optional[str]:
        """
        Sample one completion for *prompt* and extract a ``transform`` program from it.

        Returns None when no model/tokenizer is configured or when the
        completion contains no ``def transform``.
        """
        if self.model is None or self.tokenizer is None:
            return None
        import torch  # deferred: the module itself does not require torch

        # NOTE(review): truncation=True cuts from the tokenizer's truncation side
        # (right by default for HF tokenizers), so prompts over 4096 tokens lose
        # the test input, instructions and code prefix. Confirm this is intended.
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=self.temperature,
                top_p=0.95,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_len = inputs["input_ids"].shape[1]
        text = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return extract_python_code(text)

    def sample_programs(self, task: Dict, n_samples: Optional[int] = None) -> List[Tuple[str, float, Optional[List[List[int]]]]]:
        """
        Draw up to *n_samples* fresh programs for *task* and evaluate each one.

        Stops early once a program fits every training pair. Generations with
        no extractable program are skipped but still count towards *n_samples*
        (defaults to ``self.max_samples``).

        Returns:
            ``(code, train_accuracy, test_output)`` triples in generation order.
        """
        if n_samples is None:
            n_samples = self.max_samples
        # The prompt is identical for every sample; diversity comes from sampling.
        prompt = SAMPLING_PROMPT.format(
            examples=format_examples_for_prompt(task),
            code_prefix=self._CODE_PREFIX,
        )
        results = []
        for _ in range(n_samples):
            code = self.generate_program(prompt)
            if code is None:
                continue
            accuracy, test_output = evaluate_program_on_task(code, task)
            results.append((code, accuracy, test_output))
            if accuracy == 1.0:
                break  # fits all training pairs; stop sampling
        return results

    def refine_programs(self, task: Dict, programs: List[Tuple[str, float, Optional[List[List[int]]]]],
                        n_refinements: Optional[int] = None) -> List[Tuple[str, float, Optional[List[List[int]]]]]:
        """
        Iteratively repair programs using execution feedback (simplified REX).

        Even-numbered steps refine the highest-accuracy program (exploit); odd
        steps refine a uniformly random one (explore). Each refinement is
        evaluated and added to the pool. Stops when the pool is empty, when the
        chosen parent or a new program fits every training pair, or after
        *n_refinements* steps (defaults to ``self.max_refinements``).

        Returns:
            A new list: the input programs followed by all refinements.
        """
        if n_refinements is None:
            n_refinements = self.max_refinements
        examples_text = format_examples_for_prompt(task)
        pool = list(programs)  # copy so the caller's list is never mutated
        if not pool:
            return pool
        for step in range(n_refinements):
            # Thompson sampling would be closer to REX; strict alternation keeps it simple.
            if step % 2 == 0:
                parent = max(pool, key=lambda p: p[1])
            else:
                parent = random.choice(pool)
            parent_code, parent_acc, _ = parent
            if parent_acc == 1.0:
                break  # already solved; nothing to refine
            prompt = REFINEMENT_PROMPT.format(
                examples=examples_text,
                previous_code=parent_code,
                error_feedback=get_error_feedback(parent_code, task),
                code_prefix=self._CODE_PREFIX,
            )
            new_code = self.generate_program(prompt)
            if new_code is None:
                continue
            accuracy, test_output = evaluate_program_on_task(new_code, task)
            pool.append((new_code, accuracy, test_output))
            if accuracy == 1.0:
                break
        return pool

    def solve_task(self, task: Dict) -> List[List[List[int]]]:
        """
        Full solve pipeline for a single ARC task:
        1. Sample programs
        2. Refine promising ones
        3. Weighted majority vote -> up to two candidate test outputs

        Programs that fit at least one training pair vote; if none do, every
        program that produced a test output votes instead.
        """
        programs = self.refine_programs(task, self.sample_programs(task))
        voters = [p for p in programs if p[1] > 0 and p[2] is not None]
        if not voters:
            voters = [p for p in programs if p[2] is not None]
        return weighted_majority_vote(voters, top_k=2)
# ============================================================
# Handcrafted DSL primitives for common ARC patterns
# ============================================================
class ARCPrimitives:
    """Common ARC transformation primitives that can be composed.

    All methods are static; grids are lists of rows of ints.
    """

    @staticmethod
    def get_unique_colors(grid: List[List[int]]) -> set:
        """Return the set of distinct colors present in *grid*."""
        return {c for row in grid for c in row}

    @staticmethod
    def get_background_color(grid: List[List[int]]) -> int:
        """Return the most frequent color, a common heuristic for the background.

        Ties go to the color seen first in row-major order (Counter.most_common
        ordering). Raises IndexError for a grid with no cells.
        """
        counter = Counter(c for row in grid for c in row)
        return counter.most_common(1)[0][0]

    @staticmethod
    def find_objects(grid: List[List[int]], bg_color: int = 0, diagonal: bool = False) -> List[Dict]:
        """Find single-color connected components ("objects") that are not background.

        Args:
            grid: Input grid; grid_to_numpy is expected to return a 2-D array
                (it is unpacked into height and width below).
            bg_color: Color treated as background; its cells never form objects.
            diagonal: If True, diagonal neighbours also connect (8-connectivity).
                The default False uses 4-connectivity.

        Returns:
            One dict per component, ordered by the row-major position of each
            component's first cell, with keys ``color`` (int), ``cells``
            (list of (row, col)), ``bbox`` ((min_row, min_col, max_row,
            max_col), inclusive) and ``size`` (number of cells).
        """
        arr = grid_to_numpy(grid)
        h, w = arr.shape
        visited = np.zeros_like(arr, dtype=bool)
        steps = [(-1, 0), (1, 0), (0, -1), (0, 1)]
        if diagonal:
            steps += [(-1, -1), (-1, 1), (1, -1), (1, 1)]
        objects = []

        def flood_fill(r, c):
            # Iterative DFS with an explicit stack (no recursion-depth limit).
            # Out-of-bounds, visited and differently-colored cells are
            # filtered when popped.
            color = arr[r, c]
            cells = []
            stack = [(r, c)]
            while stack:
                cr, cc = stack.pop()
                if cr < 0 or cr >= h or cc < 0 or cc >= w:
                    continue
                if visited[cr, cc] or arr[cr, cc] != color:
                    continue
                visited[cr, cc] = True
                cells.append((cr, cc))
                for dr, dc in steps:
                    stack.append((cr + dr, cc + dc))
            return cells

        for r in range(h):
            for c in range(w):
                if visited[r, c] or arr[r, c] == bg_color:
                    continue
                # The seed cell is unvisited and matches its own color, so
                # flood_fill always returns at least one cell.
                cells = flood_fill(r, c)
                rows = [cr for cr, _ in cells]
                cols = [cc for _, cc in cells]
                objects.append({
                    # int(): arr holds NumPy scalars; keep every field a plain
                    # Python value so objects can be passed to json.dumps.
                    "color": int(arr[r, c]),
                    "cells": cells,
                    "bbox": (min(rows), min(cols), max(rows), max(cols)),
                    "size": len(cells),
                })
        return objects

    @staticmethod
    def crop_grid(grid: List[List[int]], r1: int, c1: int, r2: int, c2: int) -> List[List[int]]:
        """Crop grid to the inclusive bounding box (r1, c1)-(r2, c2)."""
        return [row[c1:c2 + 1] for row in grid[r1:r2 + 1]]

    @staticmethod
    def scale_grid(grid: List[List[int]], factor: int) -> List[List[int]]:
        """Scale grid by an integer factor: each cell becomes a factor x factor block.

        Every output row is a distinct list, so rows can be mutated independently.
        """
        result = []
        for row in grid:
            wide_row = [c for c in row for _ in range(factor)]
            result.extend(list(wide_row) for _ in range(factor))
        return result

    @staticmethod
    def tile_grid(grid: List[List[int]], rows: int, cols: int) -> List[List[int]]:
        """Tile grid in a rows x cols pattern; every output row is a new list."""
        return [list(row) * cols for _ in range(rows) for row in grid]
# ============================================================
# Quick test
# ============================================================
if __name__ == "__main__":
    from arc_data import load_arc_dataset_from_hf

    print("Loading ARC tasks...")
    tasks = load_arc_dataset_from_hf("arc-agi-community/arc-agi-2", "train")

    # Test program evaluation
    task = tasks[0]
    print(f"\nTask 0: {len(task['train'])} demos")
    for i, pair in enumerate(task["train"]):
        print(f" Demo {i}: {len(pair['input'])}x{len(pair['input'][0])} -> {len(pair['output'])}x{len(pair['output'][0])}")

    # Test a simple program. The function body inside the string must be
    # indented; otherwise exec() raises IndentationError and the program
    # silently scores 0.
    test_code = """
def transform(input_grid):
    # Simple test: return the input as-is
    return input_grid
"""
    acc, out = evaluate_program_on_task(test_code, task)
    print(f"\nIdentity program accuracy: {acc:.2f}")

    # Test voting: [[1,2],[3,4]] scores 0.5 + 0.8 = 1.3 and must outrank 0.3.
    programs = [
        (test_code, 0.5, [[1, 2], [3, 4]]),
        (test_code, 0.8, [[1, 2], [3, 4]]),
        (test_code, 0.3, [[5, 6], [7, 8]]),
    ]
    votes = weighted_majority_vote(programs, top_k=2)
    print(f"Voting result: {len(votes)} candidates")
    assert votes == [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], votes
    print(f" Top: {votes[0]}")

    # Test primitives
    prims = ARCPrimitives()
    bg = prims.get_background_color(task["train"][0]["input"])
    objects = prims.find_objects(task["train"][0]["input"], bg)
    print(f"\nBackground color: {bg}")
    print(f"Objects found: {len(objects)}")
    for obj in objects[:3]:
        print(f" Color {obj['color']}, size {obj['size']}, bbox {obj['bbox']}")
    # Objects are never background-colored and always contain at least one cell.
    assert all(obj["size"] > 0 and obj["color"] != bg for obj in objects)

    print("\n✅ Program synthesis module tests passed!")