"""
SOAR-style Program Synthesis for ARC-AGI

Evolutionary search: Sample programs → Execute → Refine → Vote

Based on:
- SOAR (Pourcel et al., 2507.14172)
- Product of Experts (Franzen et al., 2505.07859)
- CodeIt (Butt et al., 2402.04858)
"""

import contextlib
import copy
import json
import math
import random
import re
import signal
import threading
import time
import traceback
from typing import List, Dict, Tuple, Optional, Any
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, TimeoutError as FuturesTimeout

import numpy as np

from arc_data import (
    grids_equal, grid_to_string, grid_to_numpy, numpy_to_grid,
    D8_TRANSFORMS, augment_task, reverse_d8, reverse_color_permutation,
    create_color_permutation, apply_color_permutation
)


# ============================================================
# Safe code execution with timeout
# ============================================================

class _TransformTimeout(BaseException):
    """Raised by the SIGALRM handler when a generated program exceeds its budget.

    Derives from BaseException (not Exception) so that a generated program's
    own ``except Exception:`` clauses cannot swallow the timeout.
    """


@contextlib.contextmanager
def _time_limit(seconds: float):
    """Raise _TransformTimeout inside the managed block after ``seconds`` seconds.

    Implemented with SIGALRM + ``signal.setitimer``, which is only available on
    platforms that provide them (not Windows) and only from the main thread.
    Where that is not possible the block runs without a limit (best effort).

    Caveats: a long-running C call (e.g. one huge numpy operation) is only
    interrupted once control returns to the interpreter, and any ITIMER_REAL
    timer the caller had armed is cancelled.
    """
    usable = (
        seconds is not None
        and seconds > 0
        and math.isfinite(seconds)
        and hasattr(signal, "setitimer")
        and hasattr(signal, "SIGALRM")
        and threading.current_thread() is threading.main_thread()
    )
    if not usable:
        yield
        return

    def _on_alarm(signum, frame):
        raise _TransformTimeout()

    previous = signal.signal(signal.SIGALRM, _on_alarm)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        # Nested try so the previous handler is restored even if the alarm
        # fires in the tiny window before the timer is cancelled.
        try:
            signal.setitimer(signal.ITIMER_REAL, 0)
        finally:
            # signal.signal returns None for handlers not installed from Python.
            signal.signal(signal.SIGALRM,
                          previous if previous is not None else signal.SIG_DFL)


def _normalize_grid(result: Any) -> Optional[List[List[int]]]:
    """Validate a ``transform`` return value and convert it to list-of-lists of int.

    Accepts a list/tuple of rows or a 2-D numpy array (the prompt invites numpy).
    Cells may be Python/numpy integers or floats with an integral value.

    Returns:
        The normalized grid, or None unless the value is a non-empty,
        rectangular grid whose cells all lie in 0..9.
    """
    if isinstance(result, np.ndarray):
        if result.ndim != 2:
            return None
        result = result.tolist()
    if not isinstance(result, (list, tuple)) or len(result) == 0:
        return None

    grid: List[List[int]] = []
    width: Optional[int] = None
    for row in result:
        if isinstance(row, np.ndarray):
            row = row.tolist()
        if not isinstance(row, (list, tuple)) or len(row) == 0:
            return None
        if width is None:
            width = len(row)
        elif len(row) != width:
            return None  # ragged rows are not a valid ARC grid
        new_row = []
        for cell in row:
            if isinstance(cell, (int, np.integer)):
                value = int(cell)
            elif isinstance(cell, (float, np.floating)):
                # Reject 1.7 etc. instead of silently truncating it to 1.
                if not math.isfinite(cell) or not float(cell).is_integer():
                    return None
                value = int(cell)
            else:
                return None
            if not 0 <= value <= 9:
                return None
            new_row.append(value)
        grid.append(new_row)
    return grid


def _execute_transform(code: str, input_grid: List[List[int]],
                       timeout_sec: float = 5.0) -> Optional[List[List[int]]]:
    """Execute the ``transform`` function defined by ``code`` on one input grid.

    SECURITY NOTE: ``code`` is LLM-generated and is run with ``exec`` in this
    process with full builtins. This is NOT a sandbox; run untrusted programs in
    an isolated process/container if that matters for your deployment.

    Args:
        code: Python source that must define ``transform(grid)``.
        input_grid: Grid passed to ``transform``; deep-copied so the program
            cannot mutate the caller's data.
        timeout_sec: Wall-clock budget for defining and running the program.
            Enforced via SIGALRM where available (see ``_time_limit``).

    Returns:
        The validated output grid (list of lists of ints in 0..9), or None if
        the code fails to compile, raises, times out, calls ``exit()``, lacks a
        callable ``transform``, or returns something that is not a valid grid.
    """
    try:
        with _time_limit(timeout_sec):
            namespace: Dict[str, Any] = {}
            exec(code, namespace)
            transform_fn = namespace.get("transform")
            if not callable(transform_fn):
                return None
            result = transform_fn(copy.deepcopy(input_grid))
            return _normalize_grid(result)
    except (_TransformTimeout, SystemExit):
        # Timeouts and exit()/sys.exit() in generated code mean "no prediction";
        # SystemExit must not terminate the whole search process.
        return None
    except Exception:
        # Any other failure inside generated code also means "no prediction".
        return None


def evaluate_program_on_task(code: str, task: Dict) -> Tuple[float, Optional[List[List[int]]]]:
    """
    Evaluate a program on all training pairs.

    Returns (accuracy, test_output).
    accuracy = fraction of training pairs correctly predicted (0.0 if the task
    has no training pairs). test_output is the program's prediction for the
    first test input only, or None if there is no test input or execution fails.
    """
    train_pairs = task["train"]
    total = len(train_pairs)
    correct = 0
    for pair in train_pairs:
        pred = _execute_transform(code, pair["input"])
        if pred is not None and grids_equal(pred, pair["output"]):
            correct += 1
    accuracy = correct / total if total > 0 else 0.0

    test_output = None
    if task.get("test"):
        test_output = _execute_transform(code, task["test"][0]["input"])
    return accuracy, test_output


# ============================================================
# Program generation prompts
# ============================================================

# Function header shown to the model at the end of every prompt.
_CODE_PREFIX = "import numpy as np\n\ndef transform(input_grid: list[list[int]]) -> list[list[int]]:"

SAMPLING_PROMPT = """You are an expert at solving abstract reasoning puzzles.
Given input-output grid examples, write a Python function that transforms the input grid to the output grid.

Each grid is a 2D list of integers 0-9, where each integer represents a color:
0=black, 1=blue, 2=red, 3=green, 4=yellow, 5=grey, 6=magenta, 7=orange, 8=cyan, 9=maroon

{examples}

Write a Python function `transform(input_grid: list[list[int]]) -> list[list[int]]` that correctly transforms any input to its output.
You may use numpy.

Think step by step about the pattern, then write the code.

```python
{code_prefix}
```"""

REFINEMENT_PROMPT = """The following Python function was supposed to transform input grids to output grids, but it has errors.

{examples}

Previous attempt:
```python
{previous_code}
```

{error_feedback}

Please fix the function to correctly handle all examples.
Write the corrected `transform` function:

```python
{code_prefix}
```"""


def _grid_dims(grid: List[List[int]]) -> Tuple[int, int]:
    """Return (rows, cols) of a grid; cols is 0 for an empty grid."""
    return len(grid), (len(grid[0]) if grid else 0)


def format_examples_for_prompt(task: Dict) -> str:
    """Format the task's training pairs (and first test input) as prompt text."""
    parts = []
    for i, pair in enumerate(task["train"]):
        parts.append(f"Example {i+1}:")
        for label, key in (("Input", "input"), ("Output", "output")):
            grid = pair[key]
            h, w = _grid_dims(grid)
            parts.append(f" {label} ({h}x{w}):")
            parts.extend(f" {row}" for row in grid)

    if task.get("test"):
        test_inp = task["test"][0]["input"]
        h, w = _grid_dims(test_inp)
        parts.append(f"\nTest Input ({h}x{w}):")
        parts.extend(f" {row}" for row in test_inp)

    return "\n".join(parts)


def get_error_feedback(code: str, task: Dict) -> str:
    """Describe, per training pair, whether ``code`` errors, is wrong, or is correct.

    Wrong outputs include the expected and produced grids so the refinement
    prompt can show the model what went wrong.
    """
    feedback_parts = []
    for i, pair in enumerate(task["train"]):
        pred = _execute_transform(code, pair["input"])
        if pred is None:
            feedback_parts.append(f"Example {i+1}: EXECUTION ERROR")
        elif not grids_equal(pred, pair["output"]):
            feedback_parts.append(f"Example {i+1}: WRONG OUTPUT")
            feedback_parts.append(f" Expected: {pair['output']}")
            feedback_parts.append(f" Got: {pred}")
        else:
            feedback_parts.append(f"Example {i+1}: CORRECT")
    return "\n".join(feedback_parts)


# ============================================================
# Program extraction from LLM output
# ============================================================

# Language tags that may follow an opening ``` fence.
_FENCE_LANG_TAGS = ("python3", "python", "py")
# Column-0 import statements worth carrying along with a bare function.
_IMPORT_LINE_RE = re.compile(r"^(?:import\s+\w|from\s+[\w.]+\s+import\s)")


def _strip_fence_language_tag(block: str) -> str:
    """Drop a leading language-tag line such as ``python`` from a fenced block."""
    first, sep, rest = block.partition("\n")
    if sep and first.strip().lower() in _FENCE_LANG_TAGS:
        return rest
    return block


def _extract_bare_transform(text: str) -> str:
    """Cut ``def transform`` out of unfenced text, keeping preceding imports."""
    start = text.index("def transform")
    lines = text[start:].split("\n")
    func_lines = [lines[0]]
    for line in lines[1:]:
        # The first non-blank line at column 0 ends the function body; this
        # keeps trailing prose ("This function ...") out of the code.
        if line.strip() and not line.startswith((" ", "\t")):
            break
        func_lines.append(line)
    body = "\n".join(func_lines).strip()

    # Imports such as "import numpy as np" written before the function are
    # needed for it to run.
    imports = [ln.strip() for ln in text[:start].split("\n") if _IMPORT_LINE_RE.match(ln)]
    if imports:
        return "\n".join(imports) + "\n\n" + body
    return body


def extract_python_code(text: str) -> Optional[str]:
    """Extract Python source defining ``transform`` from an LLM response.

    Tries, in order: the first ```python fenced block; any fenced block
    (including an unterminated one from a truncated generation); and finally a
    bare ``def transform`` in the text.

    Returns:
        The code string, or None if no ``def transform`` is found.
    """
    if not text:
        return None

    if "```python" in text:
        start = text.index("```python") + len("```python")
        end = text.find("```", start)
        if end != -1:
            code = text[start:end].strip()
            if "def transform" in code:
                return code

    if "```" in text:
        # After splitting on fences, odd-indexed segments are fence contents.
        for part in text.split("```")[1::2]:
            code = _strip_fence_language_tag(part.strip())
            if "def transform" in code:
                return code

    if "def transform" in text:
        return _extract_bare_transform(text)

    return None


# ============================================================
# Weighted Majority Voting (SOAR-style ensembling)
# ============================================================

def grid_to_hashable(grid: List[List[int]]) -> tuple:
    """Convert grid to hashable tuple for voting."""
    return tuple(tuple(row) for row in grid)


def hashable_to_grid(h: tuple) -> List[List[int]]:
    """Convert hashable tuple back to grid."""
    return [list(row) for row in h]


def weighted_majority_vote(programs: List[Tuple[str, float, Optional[List[List[int]]]]],
                           top_k: int = 2) -> List[List[List[int]]]:
    """
    Weighted majority voting over program outputs.

    Each program is (code, accuracy, test_output); programs with no test
    output are ignored. Score = sum of accuracies of programs producing the same
    output. Ties are broken by the number of programs producing that output,
    then by first appearance.

    Returns up to top_k distinct outputs, best first.
    """
    vote_scores: Dict[tuple, float] = defaultdict(float)
    vote_counts: Counter = Counter()
    for _code, accuracy, test_output in programs:
        if test_output is None:
            continue
        key = grid_to_hashable(test_output)
        vote_scores[key] += accuracy
        vote_counts[key] += 1

    if not vote_scores:
        return []

    # The count tiebreak matters when every accuracy is 0 (solve_task's
    # fallback): without it the "vote" would just return the first outputs seen.
    ranked = sorted(vote_scores, key=lambda k: (-vote_scores[k], -vote_counts[k]))
    return [hashable_to_grid(key) for key in ranked[:top_k]]


# ============================================================
# Program Synthesis Search (simplified SOAR)
# ============================================================

class ProgramSynthesisEngine:
    """
    Evolutionary program synthesis engine for ARC tasks.

    Uses an LLM to sample and refine Python programs. Without a model and
    tokenizer, no programs are generated and solve_task returns [].
    """

    def __init__(self, model=None, tokenizer=None, max_samples: int = 50,
                 max_refinements: int = 50, temperature: float = 0.8):
        self.model = model
        self.tokenizer = tokenizer
        self.max_samples = max_samples
        self.max_refinements = max_refinements
        self.temperature = temperature
        self.programs_cache = {}  # currently unused by this class

    def _has_llm(self) -> bool:
        """True when both a model and a tokenizer are attached."""
        return self.model is not None and self.tokenizer is not None

    def generate_program(self, prompt: str, max_new_tokens: int = 1024) -> Optional[str]:
        """Sample one completion for ``prompt`` and extract its ``transform`` code.

        Returns None when no LLM is attached or no code can be extracted.
        """
        if not self._has_llm():
            return None
        import torch  # local: torch is only needed when an LLM is attached

        # NOTE(review): Hugging Face tokenizers default to truncation_side="right",
        # so a prompt longer than 4096 tokens loses its instructions and code
        # prefix; consider left truncation or trimming the examples.
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        prompt_len = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=self.temperature,
                top_p=0.95,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        text = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return extract_python_code(text)

    def sample_programs(self, task: Dict,
                        n_samples: int = None) -> List[Tuple[str, float, Optional[List[List[int]]]]]:
        """Sample up to ``n_samples`` programs; stop early on a perfect one.

        Returns a list of (code, train_accuracy, test_output) tuples.
        """
        if n_samples is None:
            n_samples = self.max_samples

        # The prompt does not change between samples; only sampling randomness does.
        prompt = SAMPLING_PROMPT.format(
            examples=format_examples_for_prompt(task),
            code_prefix=_CODE_PREFIX,
        )

        results = []
        for _ in range(n_samples):
            code = self.generate_program(prompt)
            if code is None:
                continue
            accuracy, test_output = evaluate_program_on_task(code, task)
            results.append((code, accuracy, test_output))
            if accuracy == 1.0:
                break  # early termination on a perfect program
        return results

    def refine_programs(self, task: Dict,
                        programs: List[Tuple[str, float, Optional[List[List[int]]]]],
                        n_refinements: int = None) -> List[Tuple[str, float, Optional[List[List[int]]]]]:
        """Refine programs using execution feedback.

        Parent selection alternates between the current best program (exploit,
        even iterations) and a uniformly random one (explore, odd iterations) —
        a simplification of REX's bandit selection. Stops once a parent or a
        refined child reaches training accuracy 1.0.

        Returns the input programs followed by every refined child.
        """
        if n_refinements is None:
            n_refinements = self.max_refinements

        all_programs = list(programs)
        if not self._has_llm():
            # No generator: every iteration would only re-execute programs for
            # feedback that is never used.
            return all_programs

        examples_text = format_examples_for_prompt(task)

        for i in range(n_refinements):
            if not all_programs:
                break

            if i % 2 == 0:
                parent = max(all_programs, key=lambda x: x[1])  # exploit
            else:
                parent = random.choice(all_programs)  # explore

            parent_code, parent_acc, _ = parent
            if parent_acc == 1.0:
                break  # already perfect

            prompt = REFINEMENT_PROMPT.format(
                examples=examples_text,
                previous_code=parent_code,
                error_feedback=get_error_feedback(parent_code, task),
                code_prefix=_CODE_PREFIX,
            )
            new_code = self.generate_program(prompt)
            if new_code is None:
                continue

            accuracy, test_output = evaluate_program_on_task(new_code, task)
            all_programs.append((new_code, accuracy, test_output))
            if accuracy == 1.0:
                break

        return all_programs

    def solve_task(self, task: Dict) -> List[List[List[int]]]:
        """
        Full solve pipeline for a single ARC task:
        1. Sample programs
        2. Refine promising ones
        3. Weighted majority vote → top-2 answers
        """
        programs = self.sample_programs(task)
        programs = self.refine_programs(task, programs)

        # Prefer programs that get at least some training examples right...
        good_programs = [(c, a, o) for c, a, o in programs if a > 0 and o is not None]
        if not good_programs:
            # ...otherwise fall back to any program that produced a test output.
            good_programs = [(c, a, o) for c, a, o in programs if o is not None]

        return weighted_majority_vote(good_programs, top_k=2)


# ============================================================
# Handcrafted DSL primitives for common ARC patterns
# ============================================================

class ARCPrimitives:
    """Common ARC transformation primitives that can be composed."""

    @staticmethod
    def get_unique_colors(grid: List[List[int]]) -> set:
        """Return the set of colors present in the grid."""
        return set(c for row in grid for c in row)

    @staticmethod
    def get_background_color(grid: List[List[int]]) -> int:
        """Most common color is usually background; 0 for an empty grid."""
        counter = Counter(c for row in grid for c in row)
        if not counter:
            return 0
        return counter.most_common(1)[0][0]

    @staticmethod
    def find_objects(grid: List[List[int]], bg_color: int = 0) -> List[Dict]:
        """Find single-color, 4-connected components that are not ``bg_color``.

        Returns a list of dicts with keys "color" (int), "cells" (list of
        (row, col)), "bbox" ((min_r, min_c, max_r, max_c)) and "size".
        """
        arr = grid_to_numpy(grid)
        h, w = arr.shape
        visited = np.zeros_like(arr, dtype=bool)
        objects = []

        def flood_fill(r0, c0):
            # Iterative (stack-based) fill over same-colored 4-neighbours.
            color = arr[r0, c0]
            cells = []
            stack = [(r0, c0)]
            while stack:
                r, c = stack.pop()
                if not (0 <= r < h and 0 <= c < w):
                    continue
                if visited[r, c] or arr[r, c] != color:
                    continue
                visited[r, c] = True
                cells.append((r, c))
                stack.extend(((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)))
            return cells

        for r in range(h):
            for c in range(w):
                if visited[r, c] or arr[r, c] == bg_color:
                    continue
                cells = flood_fill(r, c)
                rows = [cr for cr, _ in cells]
                cols = [cc for _, cc in cells]
                objects.append({
                    # int() so callers get a plain Python int, not a numpy scalar.
                    "color": int(arr[r, c]),
                    "cells": cells,
                    "bbox": (min(rows), min(cols), max(rows), max(cols)),
                    "size": len(cells),
                })
        return objects

    @staticmethod
    def crop_grid(grid: List[List[int]], r1: int, c1: int, r2: int, c2: int) -> List[List[int]]:
        """Crop grid to the inclusive bounding box (r1, c1)-(r2, c2)."""
        return [row[c1:c2+1] for row in grid[r1:r2+1]]

    @staticmethod
    def scale_grid(grid: List[List[int]], factor: int) -> List[List[int]]:
        """Scale grid by an integer factor (each cell becomes a factor×factor block)."""
        result = []
        for row in grid:
            new_row = [c for c in row for _ in range(factor)]
            result.extend(list(new_row) for _ in range(factor))
        return result

    @staticmethod
    def tile_grid(grid: List[List[int]], rows: int, cols: int) -> List[List[int]]:
        """Tile grid in a rows×cols pattern."""
        return [list(row) * cols for _ in range(rows) for row in grid]


# ============================================================
# Quick test
# ============================================================

if __name__ == "__main__":
    from arc_data import load_arc_dataset_from_hf

    print("Loading ARC tasks...")
    tasks = load_arc_dataset_from_hf("arc-agi-community/arc-agi-2", "train")

    # Test program evaluation
    task = tasks[0]
    print(f"\nTask 0: {len(task['train'])} demos")
    for i, pair in enumerate(task["train"]):
        print(f" Demo {i}: {len(pair['input'])}x{len(pair['input'][0])} -> {len(pair['output'])}x{len(pair['output'][0])}")

    # Test a simple program
    test_code = """
def transform(input_grid):
    # Simple test: return the input as-is
    return input_grid
"""
    acc, out = evaluate_program_on_task(test_code, task)
    print(f"\nIdentity program accuracy: {acc:.2f}")

    # Test voting
    programs = [
        (test_code, 0.5, [[1, 2], [3, 4]]),
        (test_code, 0.8, [[1, 2], [3, 4]]),
        (test_code, 0.3, [[5, 6], [7, 8]]),
    ]
    votes = weighted_majority_vote(programs, top_k=2)
    print(f"Voting result: {len(votes)} candidates")
    print(f" Top: {votes[0]}")

    # Test primitives
    prims = ARCPrimitives()
    bg = prims.get_background_color(task["train"][0]["input"])
    objects = prims.find_objects(task["train"][0]["input"], bg)
    print(f"\nBackground color: {bg}")
    print(f"Objects found: {len(objects)}")
    for obj in objects[:3]:
        print(f" Color {obj['color']}, size {obj['size']}, bbox {obj['bbox']}")

    print("\n✅ Program synthesis module tests passed!")