Spaces:
Running on Zero
Running on Zero
import os

# Configure the PyTorch CUDA caching allocator before torch is imported further
# down, so the setting is present when torch reads it. setdefault keeps any value
# already supplied by the environment.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import spaces  # MUST come before torch / any CUDA-touching import
import sys
import json
import math
import tempfile
import numpy as np
import torch
from PIL import Image
import gradio as gr

# ── Ensure repo root is in sys.path so diffsynth / env / src are importable ──
# Inserted at index 0 so the repo's local packages win over same-named installed
# ones; this has to run before the project imports that follow.
_repo_root = os.path.dirname(os.path.abspath(__file__))
if _repo_root not in sys.path:
    sys.path.insert(0, _repo_root)
from huggingface_hub import snapshot_download
# ── Constants ────────────────────────────────────────────────────────────────
# Hugging Face Hub sources: Wan 2.1 backbone + Echo-Memory fine-tuned weights.
WAN_BASE_MODEL_ID = "Wan-AI/Wan2.1-T2V-1.3B"
ECHO_CKPT_REPO = "Echo-Team/Echo-Memory"
ECHO_CKPT_PATH = "context_k1/epoch-0.safetensors"  # file path inside ECHO_CKPT_REPO

DEFAULT_NEGATIVE_PROMPT = "oversaturated colors, overexposed, static, blurry details"

# Output video geometry (pixels), length (frames) and playback rate.
HEIGHT = 352
WIDTH = 640
NUM_FRAMES = 81
FPS = 15
# ── Model loading at module scope (ZeroGPU: .to("cuda") is intercepted) ──────
print("[app] Downloading Wan2.1-T2V-1.3B base model...")
_base_dir = snapshot_download(WAN_BASE_MODEL_ID)
print("[app] Downloading Echo-Memory checkpoint...")
# Fetch only the folder that holds the checkpoint we load instead of the whole
# repository; fall back to the bare file when it sits at the repo root.
_ckpt_subdir = os.path.dirname(ECHO_CKPT_PATH)
_ckpt_dir = snapshot_download(
    ECHO_CKPT_REPO,
    allow_patterns=[f"{_ckpt_subdir}/*"] if _ckpt_subdir else [ECHO_CKPT_PATH],
)
_ckpt_path = os.path.join(_ckpt_dir, ECHO_CKPT_PATH)
_dit_path = os.path.join(_base_dir, "diffusion_pytorch_model.safetensors")
_text_encoder_path = os.path.join(_base_dir, "models_t5_umt5-xxl-enc-bf16.pth")
_vae_path = os.path.join(_base_dir, "Wan2.1_VAE.pth")
_tokenizer_path = os.path.join(_base_dir, "google", "umt5-xxl")
# Fail at startup with a clear message rather than deep inside the loader.
for _required_file in (_ckpt_path, _dit_path, _text_encoder_path, _vae_path):
    if not os.path.isfile(_required_file):
        raise FileNotFoundError(f"[app] Required model file not found: {_required_file}")
| from env.loop_utils import load_pipeline_and_ckpt | |
| from env.run_replay_loop_two_chunk import run_one_chunk, encode_context_frames_per_frame | |
| from env.memory_baseline_runtime import MemoryProfile, infer_memory_profile_spec | |
| from diffsynth import save_video | |
| from src.model_training.fov_retrieval import compute_rotation_list | |
# ── Inline helpers from inference/unified_inference.py ──────────────────────
def resolve_memory_profile(memory_type: str, ckpt_path: str) -> MemoryProfile:
    """Map a memory-type name to the MemoryProfile used to configure the pipeline.

    Args:
        memory_type: ``"context_k1"``, ``"context_k5"`` or ``"context_k20"`` select
            a profile with the matching ``context_override``; any other value
            triggers auto-detection from the checkpoint.
        ckpt_path: checkpoint consulted by ``infer_memory_profile_spec`` when
            ``memory_type`` is not one of the named context profiles.

    Returns:
        The named profile, the auto-detected spec's profile, or a default
        ``MemoryProfile()`` when detection returns None.
    """
    context_sizes = {"context_k1": 1, "context_k5": 5, "context_k20": 20}
    if memory_type in context_sizes:
        print(f"[app] Using context learning profile: {memory_type}")
        return MemoryProfile(context_override=context_sizes[memory_type])

    detected = infer_memory_profile_spec(ckpt_path)
    if detected is None:
        return MemoryProfile()
    print(f"[app] Auto-detected memory profile: {detected.profile_id}")
    return detected.profile
def apply_profile_to_pipe(pipe, profile: MemoryProfile) -> None:
    """Copy the memory-related settings of ``profile`` onto ``pipe`` attributes.

    - Boolean switches are coerced with ``bool()``.
    - Numeric fields use ``value or default``, so a falsy profile value falls back
      to 1.0 (temporal decay, attention weight), 2 (framepack ratio) or 64
      (spatial memory tokens).
    - ``spatial_memory_inject_mode`` is only written when the profile value is
      truthy; otherwise the pipe keeps whatever it already had.
    - ``use_block_wise_ssm`` / ``use_videossm_hybrid`` are optional on the
      profile and default to False.
    """

    def _or_default(value, fallback, cast):
        # NOTE(review): `or` replaces 0 / 0.0 as well as None with the fallback — confirm intended.
        return cast(value or fallback)

    pipe.use_framepack_memory = bool(profile.use_framepack_memory)
    pipe.context_temporal_decay = _or_default(profile.context_temporal_decay, 1.0, float)
    pipe.context_attention_weight = _or_default(profile.context_attention_weight, 1.0, float)
    pipe.use_framepack_length_compress = bool(profile.use_framepack_length_compress)
    pipe.framepack_ratio = _or_default(profile.framepack_ratio, 2, int)
    pipe.use_spatial_memory = bool(profile.use_spatial_memory)
    pipe.spatial_memory_tokens = _or_default(profile.spatial_memory_tokens, 64, int)

    inject_mode = profile.spatial_memory_inject_mode
    if inject_mode:
        pipe.spatial_memory_inject_mode = str(inject_mode)

    pipe.use_spatial_memory_legacy = bool(profile.use_spatial_memory_legacy)
    for optional_flag in ("use_block_wise_ssm", "use_videossm_hybrid"):
        setattr(pipe, optional_flag, bool(getattr(profile, optional_flag, False)))
print("[app] Loading pipeline (DiT -> cuda)...")
pipe = load_pipeline_and_ckpt(
    ckpt_path=_ckpt_path,
    dit_path=_dit_path,
    text_encoder_path=_text_encoder_path,
    vae_path=_vae_path,
    device="cuda",
    add_action_attn=False,
    action_use_temporal_attention=True,
    tokenizer_path=_tokenizer_path,
)
# Derive the memory type from the checkpoint's top-level folder (currently
# "context_k1") so the profile cannot drift from ECHO_CKPT_PATH. A folder name
# that is not a known context profile falls through to auto-detection inside
# resolve_memory_profile.
_memory_type = ECHO_CKPT_PATH.split("/", 1)[0]
_profile = resolve_memory_profile(_memory_type, _ckpt_path)
apply_profile_to_pipe(pipe, _profile)
print("[app] Model loaded and memory profile applied.")
def _build_rotation_action(deg: float, clockwise: bool, num_frames: int = 81) -> dict:
    """Return a per-frame yaw-rotation action table for one chunk.

    The yaw ramps linearly from 0 at frame 0 to the full signed rotation at the
    last frame. Clockwise uses a negative yaw; counter-clockwise uses a positive one.

    Args:
        deg: total rotation magnitude in degrees.
        clockwise: rotate clockwise (negative yaw) when True.
        num_frames: number of frames in the chunk.

    Returns:
        dict mapping the frame index (as a str) to the 12-D RT list produced by
        ``compute_rotation_list``.
    """
    signed_total = -deg if clockwise else deg
    steps = max(1, num_frames - 1)  # avoid dividing by zero for a 1-frame chunk
    # NOTE(review): the three leading zeros are presumably the translation
    # components; only the 4th entry (yaw) varies — confirm against compute_rotation_list.
    return {
        str(frame): compute_rotation_list([0.0, 0.0, 0.0, (frame / steps) * signed_total])
        for frame in range(num_frames)
    }
def _is_clockwise(rotation_direction: str | None) -> bool:
    """Return True when ``rotation_direction`` requests a clockwise (rightward) turn.

    The UI labels are "Left (CCW)" and "Right (CW)". "CCW" contains the substring
    "CW", so counter-clockwise markers are checked first; a plain ``"CW" in label``
    test would misread "Left (CCW)" as clockwise. Matching is case-insensitive;
    None or an empty label counts as counter-clockwise.
    """
    label = (rotation_direction or "").upper()
    if "CCW" in label or "LEFT" in label:
        return False
    return "CW" in label or "RIGHT" in label


# On ZeroGPU a GPU is only attached while a @spaces.GPU-decorated call runs.
# duration (seconds) is the per-call GPU budget; tune it to observed generation time.
@spaces.GPU(duration=120)
def generate(
    context_image: Image.Image | None,
    prompt: str,
    rotation_direction: str,
    rotation_degrees: float,
    seed: int,
    num_inference_steps: int,
    cfg_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an action-conditioned video from an initial frame and a text prompt.

    Args:
        context_image: Initial frame (first image of the video).
        prompt: Text description of the scene.
        rotation_direction: Camera rotation direction ("Left (CCW)" or "Right (CW)").
        rotation_degrees: Total rotation in degrees (e.g. 45).
        seed: RNG seed for reproducibility.
        num_inference_steps: Number of diffusion denoising steps.
        cfg_scale: Classifier-free guidance scale.
        progress: Gradio progress tracker; track_tqdm mirrors tqdm bars in the UI.

    Returns:
        ``(video_path, status_message)``. ``video_path`` is the saved .mp4, or None
        when input validation fails.
    """
    if context_image is None:
        return None, "Please provide an initial image."
    if not prompt or not prompt.strip():
        return None, "Please provide a text prompt."
    if seed is None:
        return None, "Please provide a seed."

    clockwise = _is_clockwise(rotation_direction)
    deg = float(rotation_degrees)

    # Resize context image to model resolution
    ctx_pil = context_image.convert("RGB").resize((WIDTH, HEIGHT), Image.LANCZOS)

    # Encode context frame through VAE
    print("[generate] Encoding context image...")
    pipe.load_models_to_device(["vae"])
    with torch.no_grad():
        context_latents = encode_context_frames_per_frame(pipe, [ctx_pil], pipe.device)

    # A single context frame, paired with an identity RT action.
    num_context_frames = 1
    identity_rt = [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
    context_actions_t = torch.tensor([identity_rt], dtype=torch.float32)

    # run_one_chunk takes the camera actions as a JSON file path.
    cam_pose_actions = _build_rotation_action(deg, clockwise, NUM_FRAMES)
    with tempfile.NamedTemporaryFile(suffix=".json", delete=False, mode="w") as action_tmp:
        json.dump(cam_pose_actions, action_tmp)
        action_path = action_tmp.name

    try:
        print(
            f"[generate] Generating {NUM_FRAMES} frames @ {WIDTH}x{HEIGHT}, "
            f"rotation={deg}° {'CW' if clockwise else 'CCW'}"
        )
        frames = run_one_chunk(
            pipe=pipe,
            prompt=prompt,
            use_negative_prompt=DEFAULT_NEGATIVE_PROMPT,
            action_path=action_path,
            context_latents=context_latents,
            num_context_frames=num_context_frames,
            context_actions_t=context_actions_t,
            chunk_frames=NUM_FRAMES,
            h=HEIGHT,
            w=WIDTH,
            seed=int(seed),
            sigma_shift=15.0,
            num_inference_steps=int(num_inference_steps),
            cfg_scale=float(cfg_scale),
            log_prefix="[generate]",
        )
    finally:
        # The action file is only handed to run_one_chunk, so remove it once that
        # call finishes (or fails). Cleanup is best-effort and never masks the result.
        try:
            os.remove(action_path)
        except OSError:
            pass

    # Gradio serves the returned path, so the mp4 must outlive this call (delete=False).
    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    tmp.close()
    save_video(frames, tmp.name, fps=FPS, quality=5)
    print(f"[generate] Video saved to {tmp.name}")
    direction_word = "clockwise" if clockwise else "counter-clockwise"
    return tmp.name, f"Generated {len(frames)} frames with {deg}° {direction_word} rotation."
# ── Gradio UI ────────────────────────────────────────────────────────────────
# "#col-container" caps and centres the page width; the outer Column below
# carries that elem_id so the rule actually applies.
CSS = """
#col-container { max-width: 1100px; margin: 0 auto; }
.dark .gradio-container { color: var(--body-text-color); }
"""

with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# 🧠 Echo-Memory: Action-Conditioned World Model")
        gr.Markdown(
            "Generate a video from an initial frame, a text prompt, and a camera rotation action. "
            "Based on the [Echo-Memory](https://huggingface.co/papers/2606.09803) paper — "
            "a controlled study of memory in action world models using the Wan 2.1 1.3B backbone."
        )
        with gr.Row():
            with gr.Column(scale=1):
                context_image = gr.Image(label="Initial Frame", type="pil", height=300)
                prompt = gr.Textbox(
                    label="Text Prompt",
                    placeholder="A toy bear on a table, the camera rotates around it",
                    lines=2,
                )
                with gr.Row():
                    rotation_direction = gr.Radio(
                        label="Camera Rotation",
                        choices=["Left (CCW)", "Right (CW)"],
                        value="Left (CCW)",
                    )
                    rotation_degrees = gr.Slider(
                        label="Rotation Degrees",
                        minimum=5,
                        maximum=90,
                        value=45,
                        step=5,
                    )
                run_btn = gr.Button("Generate Video", variant="primary")
            with gr.Column(scale=1):
                video_output = gr.Video(label="Generated Video", height=300)
                status_text = gr.Textbox(label="Status", interactive=False)
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                seed = gr.Number(label="Seed", value=42, precision=0)
                num_inference_steps = gr.Slider(
                    label="Inference Steps",
                    minimum=10,
                    maximum=100,
                    value=50,
                    step=5,
                )
                cfg_scale = gr.Slider(
                    label="CFG Scale",
                    minimum=1.0,
                    maximum=10.0,
                    value=5.0,
                    step=0.5,
                )
        # Examples only pre-fill the inputs; there is no fn, so nothing is generated or cached.
        gr.Examples(
            examples=[
                ["examples/1774363417.png", "A toy bear on a table, the camera rotates around it", "Left (CCW)", 45],
                ["examples/1774363487.png", "A decorative object on a surface, rotating view", "Right (CW)", 45],
                ["examples/1774363572.png", "A scene with objects on a table, camera pans", "Left (CCW)", 30],
            ],
            inputs=[context_image, prompt, rotation_direction, rotation_degrees],
        )
        gr.Markdown(
            "---\n"
            "**Model:** [Echo-Team/Echo-Memory](https://huggingface.co/Echo-Team/Echo-Memory) · "
            "**Backbone:** [Wan-AI/Wan2.1-T2V-1.3B](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) · "
            "**Paper:** [arXiv:2606.09803](https://arxiv.org/abs/2606.09803) · "
            "**Code:** [GitHub](https://github.com/Echo-Team-Joy-Future-Academy-JD/Echo-Memory)"
        )
    run_btn.click(
        fn=generate,
        inputs=[context_image, prompt, rotation_direction, rotation_degrees, seed, num_inference_steps, cfg_scale],
        outputs=[video_output, status_text],
        api_name="generate",
    )
demo.launch(theme=gr.themes.Citrus(), css=CSS, mcp_server=True)