| """ |
| SOAR-style Program Synthesis for ARC-AGI |
| Evolutionary search: Sample programs → Execute → Refine → Vote |
| |
| Based on: |
| - SOAR (Pourcel et al., 2507.14172) |
| - Product of Experts (Franzen et al., 2505.07859) |
| - CodeIt (Butt et al., 2402.04858) |
| """ |
| import copy |
| import json |
| import time |
| import random |
| import traceback |
| import signal |
| from typing import List, Dict, Tuple, Optional, Any |
| from collections import Counter, defaultdict |
| from concurrent.futures import ProcessPoolExecutor, TimeoutError as FuturesTimeout |
| import numpy as np |
|
|
| from arc_data import ( |
| grids_equal, grid_to_string, grid_to_numpy, numpy_to_grid, |
| D8_TRANSFORMS, augment_task, reverse_d8, reverse_color_permutation, |
| create_color_permutation, apply_color_permutation |
| ) |
|
|
|
|
| |
| |
| |
|
|
| def _execute_transform(code: str, input_grid: List[List[int]], timeout_sec: float = 5.0) -> Optional[List[List[int]]]: |
| """Execute a transform function safely with timeout.""" |
| try: |
| namespace = {} |
| exec(code, namespace) |
| if "transform" not in namespace: |
| return None |
| |
| transform_fn = namespace["transform"] |
| result = transform_fn(copy.deepcopy(input_grid)) |
| |
| |
| if not isinstance(result, list): |
| return None |
| if len(result) == 0: |
| return None |
| for row in result: |
| if not isinstance(row, list): |
| return None |
| for cell in row: |
| if not isinstance(cell, (int, float)): |
| return None |
| if int(cell) < 0 or int(cell) > 9: |
| return None |
| |
| |
| result = [[int(c) for c in row] for row in result] |
| return result |
| |
| except Exception: |
| return None |
|
|
|
|
def evaluate_program_on_task(code: str, task: Dict) -> Tuple[float, Optional[List[List[int]]]]:
    """
    Score a candidate program against a task's training pairs.

    Every pair in ``task["train"]`` is executed; the accuracy is the share of
    pairs whose prediction equals the expected output (``0.0`` when there are
    no pairs). The program is then run on the first entry of ``task["test"]``,
    if any, and that prediction (possibly ``None``) is returned alongside.

    Returns:
        ``(accuracy, test_output)``.
    """
    demos = task["train"]
    n_demos = len(demos)

    n_solved = 0
    for demo in demos:
        predicted = _execute_transform(code, demo["input"])
        if predicted is None:
            continue
        if grids_equal(predicted, demo["output"]):
            n_solved += 1

    accuracy = 0.0 if n_demos <= 0 else n_solved / n_demos

    # Only the first test input is evaluated.
    if task.get("test") and len(task["test"]) > 0:
        return accuracy, _execute_transform(code, task["test"][0]["input"])
    return accuracy, None
|
|
|
|
| |
| |
| |
|
|
# Prompt template used when sampling a fresh candidate program.
# Filled with str.format(); placeholders: {examples}, {code_prefix}.
SAMPLING_PROMPT = """You are an expert at solving abstract reasoning puzzles. Given input-output grid examples, write a Python function that transforms the input grid to the output grid.

Each grid is a 2D list of integers 0-9, where each integer represents a color:
0=black, 1=blue, 2=red, 3=green, 4=yellow, 5=grey, 6=magenta, 7=orange, 8=cyan, 9=maroon

{examples}

Write a Python function `transform(input_grid: list[list[int]]) -> list[list[int]]` that correctly transforms any input to its output. You may use numpy. Think step by step about the pattern, then write the code.

```python
{code_prefix}
```"""
|
|
# Prompt template used when asking the model to repair a previous attempt.
# Filled with str.format(); placeholders: {examples}, {previous_code},
# {error_feedback}, {code_prefix}.
REFINEMENT_PROMPT = """The following Python function was supposed to transform input grids to output grids, but it has errors.

{examples}

Previous attempt:
```python
{previous_code}
```

{error_feedback}

Please fix the function to correctly handle all examples. Write the corrected `transform` function:

```python
{code_prefix}
```"""
|
|
|
|
def format_examples_for_prompt(task: Dict) -> str:
    """Render a task's training pairs (and first test input, if any) as prompt text.

    Each grid is written as a ``<label> (<rows>x<cols>):`` header followed by
    one line per row, using the row's list ``repr``.
    """

    def _grid_block(label: str, grid: List[List[int]]) -> List[str]:
        # Header with the grid dimensions, then one line per row.
        header = f"{label} ({len(grid)}x{len(grid[0])}):"
        return [header] + [f" {row}" for row in grid]

    lines: List[str] = []
    for number, pair in enumerate(task["train"], start=1):
        lines.append(f"Example {number}:")
        lines += _grid_block(" Input", pair["input"])
        lines += _grid_block(" Output", pair["output"])

    tests = task.get("test")
    if tests and len(tests) > 0:
        # The leading newline puts a blank line before the test section.
        lines += _grid_block("\nTest Input", tests[0]["input"])

    return "\n".join(lines)
|
|
|
|
def get_error_feedback(code: str, task: Dict) -> str:
    """Summarize, per training example, how ``code`` performed.

    Each example is re-executed and reported as ``EXECUTION ERROR`` (no usable
    output), ``WRONG OUTPUT`` (with the expected and produced grids), or
    ``CORRECT``. The report lines are joined with newlines.
    """
    report: List[str] = []

    for number, pair in enumerate(task["train"], start=1):
        label = f"Example {number}"
        actual = _execute_transform(code, pair["input"])

        if actual is None:
            report.append(f"{label}: EXECUTION ERROR")
            continue

        if grids_equal(actual, pair["output"]):
            report.append(f"{label}: CORRECT")
            continue

        report.extend([
            f"{label}: WRONG OUTPUT",
            f" Expected: {pair['output']}",
            f" Got: {actual}",
        ])

    return "\n".join(report)
|
|
|
|
| |
| |
| |
|
|
def _compiles(source: str) -> bool:
    """Return True when ``source`` parses as Python; nothing is executed."""
    try:
        compile(source, "<extracted>", "exec")
    except (SyntaxError, ValueError):
        return False
    return True


def _strip_trailing_prose(lines: List[str]) -> List[str]:
    """Drop trailing unindented lines until the remaining lines parse.

    Only blank or unindented tail lines are removed, so a function whose own
    (indented) body is broken is returned unchanged rather than truncated.
    The first line is never removed.
    """
    kept = list(lines)
    while len(kept) > 1:
        if _compiles("\n".join(kept)):
            return kept
        tail = kept[-1]
        if tail.strip() and tail.startswith((" ", "\t")):
            break
        kept.pop()
    return kept if _compiles("\n".join(kept)) else list(lines)


def extract_python_code(text: str) -> Optional[str]:
    """Extract the source of a ``transform`` function from an LLM response.

    Search order:
      1. Fenced code blocks tagged as Python (``python``, ``py``, ``python3``,
         case-insensitive), in order of appearance.
      2. Any other fenced code block, in order of appearance.
      3. Unfenced text starting at the first ``def transform``.

    A final fence that is never closed (e.g. truncated generation) still
    counts as a block. In the unfenced case, collection stops at the next
    top-level ``def``/``class``; trailing prose that prevents the snippet from
    parsing is dropped, and ``import`` lines directly above the function are
    kept so names such as ``np`` remain defined.

    Returns:
        The stripped code containing ``def transform``, or ``None`` when the
        response contains no such definition.
    """
    python_tags = ("python", "py", "python3")

    if "```" in text:
        tagged: List[str] = []
        untagged: List[str] = []
        # Odd-indexed pieces of the split are the insides of fences.
        for block in text.split("```")[1::2]:
            first_line, newline, rest = block.partition("\n")
            if newline and first_line.strip().lower() in python_tags:
                tagged.append(rest.strip())
            elif block.startswith("python"):
                # Tag with code on the same line, e.g. "```python def ...".
                tagged.append(block[len("python"):].strip())
            else:
                untagged.append(block.strip())
        for candidate in tagged + untagged:
            if "def transform" in candidate:
                return candidate

    marker = text.find("def transform")
    if marker == -1:
        return None

    preceding = text[:marker].split("\n")
    following = text[marker:].split("\n")

    func_lines = [following[0]]
    for line in following[1:]:
        at_top_level = bool(line.strip()) and not line.startswith((" ", "\t"))
        if at_top_level and line.startswith(("def ", "class ")):
            break
        func_lines.append(line)
    func_lines = _strip_trailing_prose(func_lines)

    # Recover imports written immediately above the function; only when the
    # definition starts its own line.
    import_lines: List[str] = []
    if preceding[-1] == "":
        for line in reversed(preceding[:-1]):
            if not line.strip():
                import_lines.append(line)
            elif line.startswith(("import ", "from ")) and _compiles(line):
                import_lines.append(line)
            else:
                break
        import_lines.reverse()
        if not any(line.strip() for line in import_lines):
            import_lines = []

    return "\n".join(import_lines + func_lines).strip()
|
|
|
|
| |
| |
| |
|
|
def grid_to_hashable(grid: List[List[int]]) -> tuple:
    """Freeze a grid into a tuple of row tuples so it can serve as a dict key."""
    return tuple(map(tuple, grid))
|
|
|
|
def hashable_to_grid(h: tuple) -> List[List[int]]:
    """Thaw a tuple-of-tuples key back into a list-of-lists grid."""
    return list(map(list, h))
|
|
|
|
def weighted_majority_vote(programs: List[Tuple[str, float, Optional[List[List[int]]]]],
                           top_k: int = 2) -> List[List[List[int]]]:
    """
    Weighted majority voting over program outputs.

    Args:
        programs: ``(code, accuracy, test_output)`` triples. Entries whose
            ``test_output`` is ``None`` are ignored.
        top_k: Maximum number of distinct outputs to return.

    Returns:
        Up to ``top_k`` distinct output grids, best first. An output's score
        is the sum of the accuracies of the programs that produced it. Equal
        scores are ranked by how many programs produced the output, then by
        first appearance, so a pool made only of zero-accuracy programs still
        yields a plain majority vote. Returns ``[]`` when no program produced
        an output.
    """
    scores = defaultdict(float)
    counts = Counter()

    for _code, accuracy, test_output in programs:
        if test_output is None:
            continue
        key = grid_to_hashable(test_output)
        scores[key] += accuracy
        counts[key] += 1

    # sorted() is stable, so remaining ties keep first-seen order.
    ranked = sorted(scores, key=lambda key: (-scores[key], -counts[key]))
    return [hashable_to_grid(key) for key in ranked[:top_k]]
|
|
|
|
| |
| |
| |
|
|
class ProgramSynthesisEngine:
    """
    Evolutionary program synthesis engine for ARC tasks.

    Uses an LLM to sample candidate ``transform`` programs, refines them with
    execution feedback, and votes over the outputs they produce for the test
    input.
    """

    # Code scaffold inserted into both prompt templates.
    _CODE_PREFIX = "import numpy as np\n\ndef transform(input_grid: list[list[int]]) -> list[list[int]]:"

    def __init__(self, model=None, tokenizer=None, max_samples: int = 50,
                 max_refinements: int = 50, temperature: float = 0.8):
        """Store the generation backend and the search budgets.

        Args:
            model: Object exposing ``generate(...)`` and ``device``; without
                it ``generate_program`` always returns ``None``.
            tokenizer: Tokenizer paired with ``model``.
            max_samples: Default number of programs to sample per task.
            max_refinements: Default number of refinement rounds per task.
            temperature: Sampling temperature passed to ``model.generate``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.max_samples = max_samples
        self.max_refinements = max_refinements
        self.temperature = temperature
        self.programs_cache = {}  # not read by any method in this class

    def generate_program(self, prompt: str, max_new_tokens: int = 1024) -> Optional[str]:
        """Generate one completion for ``prompt`` and extract its program.

        Returns:
            The extracted code, or ``None`` when no model/tokenizer is
            configured or the completion contains no ``transform`` function.
        """
        if self.model is None or self.tokenizer is None:
            return None

        # Imported here, as in the rest of the file torch is not imported at
        # module level; it is only needed once a model is attached.
        import torch

        # NOTE(review): truncation=True silently cuts prompts longer than
        # 4096 tokens — confirm which side the tokenizer truncates.
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        prompt_len = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=self.temperature,
                top_p=0.95,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        text = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return extract_python_code(text)

    def sample_programs(self, task: Dict, n_samples: Optional[int] = None) -> List[Tuple[str, float, Optional[List[List[int]]]]]:
        """Sample up to ``n_samples`` programs and score each on the task.

        Stops early as soon as a program solves every training pair.

        Returns:
            ``(code, accuracy, test_output)`` for every completion that
            yielded code.
        """
        if n_samples is None:
            n_samples = self.max_samples

        # The prompt does not change between samples; build it once.
        prompt = SAMPLING_PROMPT.format(
            examples=format_examples_for_prompt(task),
            code_prefix=self._CODE_PREFIX,
        )

        results = []
        for _ in range(n_samples):
            code = self.generate_program(prompt)
            if code is None:
                continue

            accuracy, test_output = evaluate_program_on_task(code, task)
            results.append((code, accuracy, test_output))

            if accuracy == 1.0:
                break

        return results

    def refine_programs(self, task: Dict, programs: List[Tuple[str, float, Optional[List[List[int]]]]],
                        n_refinements: Optional[int] = None) -> List[Tuple[str, float, Optional[List[List[int]]]]]:
        """Refine programs using execution feedback.

        Even rounds refine the highest-accuracy program so far, odd rounds a
        uniformly random one. The loop ends when the pool is empty, the chosen
        parent already has accuracy 1.0, a refinement reaches accuracy 1.0, or
        the budget is spent.

        Returns:
            The input programs followed by every refinement that yielded code.
        """
        if n_refinements is None:
            n_refinements = self.max_refinements

        examples_text = format_examples_for_prompt(task)
        all_programs = list(programs)
        # The same parent is often picked repeatedly; avoid re-executing it
        # on every round just to rebuild identical feedback.
        feedback_by_code: Dict[str, str] = {}

        for i in range(n_refinements):
            if not all_programs:
                break

            if i % 2 == 0:
                parent = max(all_programs, key=lambda x: x[1])
            else:
                parent = random.choice(all_programs)

            parent_code, parent_acc, _ = parent

            if parent_acc == 1.0:
                break

            if parent_code not in feedback_by_code:
                feedback_by_code[parent_code] = get_error_feedback(parent_code, task)

            prompt = REFINEMENT_PROMPT.format(
                examples=examples_text,
                previous_code=parent_code,
                error_feedback=feedback_by_code[parent_code],
                code_prefix=self._CODE_PREFIX,
            )

            new_code = self.generate_program(prompt)
            if new_code is None:
                continue

            accuracy, test_output = evaluate_program_on_task(new_code, task)
            all_programs.append((new_code, accuracy, test_output))

            if accuracy == 1.0:
                break

        return all_programs

    def solve_task(self, task: Dict) -> List[List[List[int]]]:
        """
        Full solve pipeline for a single ARC task:
        1. Sample programs
        2. Refine promising ones
        3. Weighted majority vote → top-2 answers

        Returns:
            Up to two candidate output grids; ``[]`` when no program produced
            a test output.
        """
        programs = self.sample_programs(task)
        programs = self.refine_programs(task, programs)

        # Prefer programs that solved at least one training pair.
        good_programs = [(c, a, o) for c, a, o in programs if a > 0 and o is not None]

        if not good_programs:
            # Otherwise fall back to anything that produced a test output.
            good_programs = [(c, a, o) for c, a, o in programs if o is not None]

        return weighted_majority_vote(good_programs, top_k=2)
|
|
|
|
| |
| |
| |
|
|
class ARCPrimitives:
    """Common ARC transformation primitives that can be composed."""

    @staticmethod
    def get_unique_colors(grid: List[List[int]]) -> set:
        """Return the set of cell values that occur in ``grid``."""
        return {c for row in grid for c in row}

    @staticmethod
    def get_background_color(grid: List[List[int]]) -> int:
        """Return the most frequent cell value (usually the background).

        Raises:
            IndexError: If the grid contains no cells.
        """
        counter = Counter(c for row in grid for c in row)
        return counter.most_common(1)[0][0]

    @staticmethod
    def find_objects(grid: List[List[int]], bg_color: int = 0) -> List[Dict]:
        """Find 4-connected, single-colored components that are not ``bg_color``.

        Cells are scanned in row-major order. Each object is a dict with:
          - ``"color"``: the component's color as a Python ``int``
          - ``"cells"``: list of ``(row, col)`` tuples in fill order
          - ``"bbox"``: ``(min_row, min_col, max_row, max_col)``, inclusive
          - ``"size"``: number of cells
        """
        arr = grid_to_numpy(grid)  # presumably a 2-D array; shape is unpacked as (h, w)
        h, w = arr.shape
        visited = np.zeros_like(arr, dtype=bool)
        objects = []

        def flood_fill(r, c):
            # Stack-based (depth-first) fill over same-colored neighbours.
            color = arr[r, c]
            cells = []
            stack = [(r, c)]
            while stack:
                cr, cc = stack.pop()
                if cr < 0 or cr >= h or cc < 0 or cc >= w:
                    continue
                if visited[cr, cc] or arr[cr, cc] != color:
                    continue
                visited[cr, cc] = True
                cells.append((cr, cc))
                for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                    stack.append((cr + dr, cc + dc))
            return cells

        for r in range(h):
            for c in range(w):
                if visited[r, c] or arr[r, c] == bg_color:
                    continue
                # The seed cell is unvisited, so the fill is never empty.
                cells = flood_fill(r, c)
                rows = [cr for cr, _ in cells]
                cols = [cc for _, cc in cells]
                objects.append({
                    # Cast so the result holds plain Python values rather
                    # than a numpy scalar.
                    "color": int(arr[r, c]),
                    "cells": cells,
                    "bbox": (min(rows), min(cols), max(rows), max(cols)),
                    "size": len(cells),
                })

        return objects

    @staticmethod
    def crop_grid(grid: List[List[int]], r1: int, c1: int, r2: int, c2: int) -> List[List[int]]:
        """Crop ``grid`` to rows ``r1..r2`` and columns ``c1..c2``, both inclusive."""
        return [row[c1:c2 + 1] for row in grid[r1:r2 + 1]]

    @staticmethod
    def scale_grid(grid: List[List[int]], factor: int) -> List[List[int]]:
        """Scale ``grid`` by an integer factor: each cell becomes a factor×factor block."""
        result = []
        for row in grid:
            widened = [c for c in row for _ in range(factor)]
            # Independent copies so output rows never alias each other.
            result.extend(list(widened) for _ in range(factor))
        return result

    @staticmethod
    def tile_grid(grid: List[List[int]], rows: int, cols: int) -> List[List[int]]:
        """Tile ``grid`` in a rows×cols pattern."""
        return [list(row) * cols for _ in range(rows) for row in grid]
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Manual smoke run: needs the project-local arc_data loader and access to
    # the dataset; nothing here asserts, it only prints what it observed.
    from arc_data import load_arc_dataset_from_hf

    print("Loading ARC tasks...")
    tasks = load_arc_dataset_from_hf("arc-agi-community/arc-agi-2", "train")

    # Inspect the first task's demonstration pairs.
    task = tasks[0]
    demos = task["train"]
    print(f"\nTask 0: {len(demos)} demos")
    for i, pair in enumerate(demos):
        grid_in, grid_out = pair["input"], pair["output"]
        print(f" Demo {i}: {len(grid_in)}x{len(grid_in[0])} -> {len(grid_out)}x{len(grid_out[0])}")

    # Evaluate an identity program on the task.
    test_code = """
def transform(input_grid):
    # Simple test: return the input as-is
    return input_grid
"""
    acc, out = evaluate_program_on_task(test_code, task)
    print(f"\nIdentity program accuracy: {acc:.2f}")

    # Voting over hand-made candidates.
    grid_a = [[1, 2], [3, 4]]
    grid_b = [[5, 6], [7, 8]]
    programs = [
        (test_code, 0.5, grid_a),
        (test_code, 0.8, [list(row) for row in grid_a]),
        (test_code, 0.3, grid_b),
    ]
    votes = weighted_majority_vote(programs, top_k=2)
    print(f"Voting result: {len(votes)} candidates")
    print(f" Top: {votes[0]}")

    # Primitives on the first demonstration input.
    prims = ARCPrimitives()
    first_input = demos[0]["input"]
    bg = prims.get_background_color(first_input)
    objects = prims.find_objects(first_input, bg)
    print(f"\nBackground color: {bg}")
    print(f"Objects found: {len(objects)}")
    for obj in objects[:3]:
        print(f" Color {obj['color']}, size {obj['size']}, bbox {obj['bbox']}")

    print("\n✅ Program synthesis module tests passed!")
|
|