Spaces:
Paused
Paused
| from __future__ import annotations | |
| import argparse | |
| import hashlib | |
| import re | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from threading import Thread | |
| from types import SimpleNamespace | |
| from typing import Tuple | |
| import torch | |
| from transformers import TextIteratorStreamer | |
| from unsloth import FastLanguageModel | |
| from permanence.agent_interface.formatter import format_observation | |
| from permanence.agent_interface.parser import parse_agent_output | |
| from permanence.tasks.task_bank import TaskSpec | |
| from permanence.world.state import EmployeeState, ExternalRelationshipState, ProjectState, WorldState | |
| from training.config import TrainingConfig, load_simple_yaml | |
# Text passed to input() every time the interactive loop asks for a scenario.
DEFAULT_SCENARIO_PROMPT = "[JUDGE MODE] Enter a custom corporate crisis scenario: > "
# Directory name appended to the configured output_dir to find the trained weights.
DEFAULT_MODEL_SUFFIX = "final_model"
# Default value of the --max-new-tokens command-line option.
MAX_NEW_TOKENS = 220
# Value handed to FastLanguageModel.from_pretrained as max_seq_length.
MAX_SEQ_LENGTH = 2048
@dataclass(frozen=True)
class JudgeTask:
    """Default description of the single-step judge sandbox task.

    The class is a frozen dataclass, so the defaults below stay readable as
    class attributes (for example ``JudgeTask.available_actions``) while
    instances can also be created with overridden fields and compared by value.
    """

    # Identifier of the sandbox task.
    task_id: str = "judge_sandbox"
    # Human-readable task name.
    name: str = "Judge Sandbox"
    # Short description of what the judge-supplied scenario asks for.
    narrative: str = (
        "A custom corporate crisis scenario supplied by a human judge. "
        "Respond with a concise internal reasoning trace and one concrete corporate action."
    )
    # The sandbox is a one-shot interaction.
    max_steps: int = 1
    # Action ids offered in the sandbox; a tuple so the default is immutable.
    available_actions: Tuple[str, ...] = (
        "draft_internal_memo",
        "brief_internal_stakeholders",
        "prepare_response_draft",
        "send_internal_communication",
        "send_external_communication",
        "issue_public_statement",
        "delay_release",
        "begin_internal_investigation",
    )
def _hash_suffix(text: str) -> str:
    """Return the first eight hex characters of the SHA-1 digest of *text*.

    The text is UTF-8 encoded before hashing.
    """
    hasher = hashlib.sha1()
    hasher.update(text.encode("utf-8"))
    return hasher.hexdigest()[:8]
def _clean_label(text: str, fallback: str) -> str:
    """Build a short snake_case label from the first three words of *text*.

    The text is lower-cased and split into alphanumeric runs; up to three of
    them are joined with underscores. When the text contains no alphanumeric
    run at all, *fallback* is returned unchanged.
    """
    words = re.findall(r"[A-Za-z0-9]+", text.lower())
    return "_".join(words[:3]) if words else fallback
def parse_judge_scenario(raw_text: str) -> WorldState:
    """Turn free-form judge text into a small, single-project ``WorldState``.

    The stripped text is scanned (case-insensitively, by substring) for four
    keyword groups that adjust board, client and project metrics away from
    their baseline values. The groups are applied in a fixed order, so a later
    group overwrites any metric an earlier group already set. The resulting
    world always contains three fixed employees, one project named after the
    scenario (truncated to 80 characters) and one client.

    Args:
        raw_text: Scenario description typed by the judge; may be empty.

    Returns:
        A ``WorldState`` at step 0 for the ``judge_sandbox`` task whose
        ``scenario_id`` is derived from a hash of the stripped text.
    """
    scenario = raw_text.strip()
    haystack = scenario.lower()

    def mentions(*keywords: str) -> bool:
        # Plain substring match against the lower-cased scenario text.
        return any(word in haystack for word in keywords)

    # Baseline metrics used when no keyword group matches.
    levels = {
        "board_trust": 0.5,
        "board_expectation": 0.55,
        "client_standing": 0.55,
        "project_momentum": 0.5,
    }
    if mentions("outage", "incident", "breach", "failure", "downtime"):
        levels.update(
            board_expectation=0.82,
            board_trust=0.38,
            client_standing=0.32,
            project_momentum=0.34,
        )
    if mentions("launch", "release", "shipment", "go-live"):
        levels["project_momentum"] = 0.74
    if mentions("lawsuit", "regulator", "compliance", "audit"):
        levels.update(board_expectation=0.88, board_trust=0.42)
    if mentions("client", "customer", "partner", "escalation"):
        levels["client_standing"] = 0.41

    scenario_id = f"judge:{_hash_suffix(scenario or 'empty')}"
    project_id = f"proj_{_clean_label(scenario, 'crisis') or 'crisis'}"
    project_name = scenario or "Generic Corporate Crisis"

    # (employee id, name, role, trust score, institutional knowledge)
    roster = (
        ("emp_judge_001", "Avery", "incident_lead", 0.71, 0.8),
        ("emp_judge_002", "Blair", "communications", 0.67, 0.76),
        ("emp_judge_003", "Casey", "executive", 0.63, 0.72),
    )
    employees = {
        emp_id: EmployeeState(
            employee_id=emp_id,
            name=emp_name,
            role=emp_role,
            trust_score=trust,
            availability="active",
            current_project=project_id,
            relationship_flags=set(),
            institutional_knowledge=knowledge,
        )
        for emp_id, emp_name, emp_role, trust, knowledge in roster
    }

    crisis_project = ProjectState(
        project_id=project_id,
        name=project_name[:80],
        momentum=levels["project_momentum"],
        resource_level=0.68,
        deadline_pressure=0.9,
        lead_employee_id="emp_judge_001",
        dependencies=["internal_review"],
        external_commitment_made=False,
        status="active",
    )

    external = ExternalRelationshipState(
        board_expectation_level=levels["board_expectation"],
        board_trust_score=levels["board_trust"],
        client_standing={"client_a": levels["client_standing"]},
        public_record=[],
        partner_obligations=[],
    )

    return WorldState(
        employees=employees,
        projects={project_id: crisis_project},
        external=external,
        action_history=[],
        locked_actions={},
        critical_options={
            "notify_board": True,
            "stabilize_operations": True,
            "issue_public_statement": False,
            "preserve_escalation_path": True,
        },
        episode_step=0,
        scenario_id=scenario_id,
        task_id="judge_sandbox",
    )
def _build_task() -> SimpleNamespace:
    """Create the judge sandbox task as a plain attribute namespace.

    A ``TaskSpec`` is built with the sandbox defaults (one step, the action
    list from ``JudgeTask`` and a success function that always returns True)
    and its instance attributes are then copied into a ``SimpleNamespace``.
    """

    def _always_succeeds(world_state, task_spec):
        # The sandbox has no failure condition.
        return True

    task_fields = {
        "task_id": "judge_sandbox",
        "name": "Judge Sandbox",
        "narrative": (
            "A judge-supplied corporate crisis scenario. Analyze the current world state, "
            "explain the reasoning in <thinking>, then emit a single reversible action decision."
        ),
        "max_steps": 1,
        "available_actions": list(JudgeTask.available_actions),
        "preservation_targets": ["notify_board", "stabilize_operations"],
        "success_fn": _always_succeeds,
        "difficulty": 1,
    }
    spec = TaskSpec(**task_fields)
    return SimpleNamespace(**spec.__dict__)
def _load_model_path(config_path: str, model_path: str | None) -> Path:
    """Resolve the directory that holds the final trained model.

    Args:
        config_path: Training config file; only read when no override is given.
        model_path: Explicit model directory. Any non-empty value wins.

    Returns:
        ``Path(model_path)`` when an override is supplied, otherwise the
        configured ``output_dir`` joined with ``DEFAULT_MODEL_SUFFIX``.
    """
    if not model_path:
        settings = TrainingConfig.from_mapping(load_simple_yaml(config_path))
        return Path(settings.output_dir) / DEFAULT_MODEL_SUFFIX
    return Path(model_path)
def load_final_model(model_dir: Path):
    """Load the trained model and tokenizer from *model_dir* for inference.

    Args:
        model_dir: Directory containing the final trained weights.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer is guaranteed to have a
        pad token (the EOS token is reused when none is defined).

    Raises:
        FileNotFoundError: If *model_dir* does not exist.
    """
    import warnings

    if not model_dir.exists():
        raise FileNotFoundError(
            f"Final trained weights not found at {model_dir}. Run training/train.py first to produce final_model."
        )
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=str(model_dir),
        max_seq_length=MAX_SEQ_LENGTH,
        dtype=None,
        load_in_4bit=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Switching to inference mode is best-effort: it is skipped when the
    # installed library has no such hook, and a failure must not abort loading.
    enable_inference = getattr(FastLanguageModel, "for_inference", None)
    if callable(enable_inference):
        try:
            prepared = enable_inference(model)
        except Exception as exc:
            # Surface the problem instead of hiding it, but keep going with
            # the model as loaded.
            warnings.warn(
                f"FastLanguageModel.for_inference failed; continuing without it: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )
        else:
            # The hook may configure the model in place and return nothing;
            # only rebind when it actually hands a model back, so the loaded
            # model is never replaced by None.
            if prepared is not None:
                model = prepared
    return model, tokenizer
def build_prompt(observation: dict, scenario_text: str) -> str:
    """Assemble the prompt sent to the model for one judge scenario.

    Args:
        observation: Mapping whose ``"text"`` entry is the rendered world state.
        scenario_text: Raw scenario from the judge; surrounding whitespace is
            stripped and an empty result is shown as ``(empty scenario)``.

    Returns:
        The fixed instructions followed by the scenario and world-state sections.
    """
    scenario_block = scenario_text.strip()
    if not scenario_block:
        scenario_block = "(empty scenario)"
    world_block = observation["text"]

    instructions = (
        "You are operating in judge sandbox mode.\n"
        "Use the supplied world state to reason about the corporate crisis.\n"
        "Respond only with a <thinking> block, then one <action id=\"...\" .../> tag, then one <reversibility level=\"R1-R5\" confidence=\"0.0-1.0\"/> tag.\n\n"
    )
    scenario_section = f"JUDGE SCENARIO:\n{scenario_block}\n\n"
    world_section = f"WORLD STATE:\n{world_block}\n"
    return instructions + scenario_section + world_section
def _stream_generate(model, tokenizer, prompt: str, max_new_tokens: int) -> str:
    """Generate a completion for *prompt*, echoing it to stdout as it streams.

    Generation runs in a background thread and feeds a ``TextIteratorStreamer``
    that this function consumes, printing each piece as soon as it arrives.

    Args:
        model: Model exposing ``generate`` (and optionally a ``device`` attribute).
        tokenizer: Tokenizer used to encode the prompt and decode the stream.
        prompt: Full prompt text.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The concatenation of all streamed pieces (the prompt itself is skipped).

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread is
            re-raised here once the stream has been drained.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    device = getattr(model, "device", None)
    if device is not None:
        inputs = {key: value.to(device) for key, value in inputs.items()}
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    failures: list[Exception] = []

    def _generate_worker() -> None:
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            # A failed generate() never signals end-of-stream, which would
            # leave the consumer loop below blocked forever. Record the error
            # and close the stream so the caller can re-raise it.
            failures.append(exc)
            streamer.end()

    thread = Thread(target=_generate_worker, daemon=True)
    thread.start()
    pieces: list[str] = []
    print("\n--- MODEL OUTPUT ---")
    for piece in streamer:
        print(piece, end="", flush=True)
        pieces.append(piece)
    print()
    thread.join()
    if failures:
        raise failures[0]
    return "".join(pieces)
def run_judge_session(model, tokenizer, max_new_tokens: int) -> None:
    """Run the interactive read-generate-parse loop until the judge quits.

    Each round reads a scenario from stdin, builds a world state and prompt
    from it, streams the model's answer and prints whatever the parser could
    extract. The loop ends on an empty line, on end-of-input, or on Ctrl-C at
    the prompt.

    Args:
        model: Loaded model used for generation.
        tokenizer: Tokenizer matching *model*.
        max_new_tokens: Generation budget for each scenario.
    """
    judge_task = _build_task()
    while True:
        try:
            entered = input(DEFAULT_SCENARIO_PROMPT)
        except (EOFError, KeyboardInterrupt):
            # Finish the interrupted prompt line before leaving.
            print()
            return
        scenario_text = entered.strip()
        if scenario_text == "":
            print("Exiting judge sandbox.")
            return

        observation = format_observation(
            world_state=parse_judge_scenario(scenario_text),
            task=judge_task,
            step=0,
        )
        raw_output = _stream_generate(
            model,
            tokenizer,
            build_prompt(observation, scenario_text),
            max_new_tokens=max_new_tokens,
        )

        parsed = parse_agent_output(raw_output)
        if thinking := parsed.raw_thinking:
            print(f"[PARSED THINKING] {thinking}")
        if action_id := parsed.action_id:
            print(f"[PARSED ACTION] {action_id}")
        if problems := parsed.parse_errors:
            print(f"[PARSE WARNINGS] {'; '.join(problems)}")
def main() -> None:
    """Command-line entry point: parse options, load the model, start the loop."""
    cli = argparse.ArgumentParser(description="PERMANENCE Judge Sandbox interactive evaluator")
    cli.add_argument(
        "--config",
        default="training/config.yaml",
        help="Training config used to locate final_model.",
    )
    cli.add_argument(
        "--model-path",
        default=None,
        help="Override path to the final trained model directory.",
    )
    cli.add_argument(
        "--max-new-tokens",
        type=int,
        default=MAX_NEW_TOKENS,
        help="Maximum tokens to generate per judge run.",
    )
    options = cli.parse_args()

    weights_dir = _load_model_path(options.config, options.model_path)
    model, tokenizer = load_final_model(weights_dir)
    # Release cached CUDA memory left over from loading before generation starts.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    run_judge_session(model, tokenizer, max_new_tokens=options.max_new_tokens)
# Start the interactive sandbox only when this file is executed as a script.
if __name__ == "__main__":
    main()