echo-memory / app.py
multimodalart's picture
multimodalart HF Staff
Upload app.py with huggingface_hub
9df870d verified
Raw
History Blame Contribute Delete
11.5 kB
import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import spaces # MUST come before torch / any CUDA-touching import
import sys
import json
import math
import tempfile
import numpy as np
import torch
from PIL import Image
import gradio as gr
# ── Make the directory holding this file importable ──────────────────────────
# The project-local packages imported further down (diffsynth / env / src) are
# resolved through sys.path, so this directory is put at the front of it.
_repo_root = os.path.dirname(os.path.abspath(__file__))
if sys.path.count(_repo_root) == 0:
    # Prepend (rather than append) so this directory is searched first.
    sys.path.insert(0, _repo_root)
from huggingface_hub import snapshot_download
# ── Constants ─────────────────────────────────────────────────────────────────
# Hub repo id passed to snapshot_download() for the base model files.
WAN_BASE_MODEL_ID = "Wan-AI/Wan2.1-T2V-1.3B"
# Hub repo id passed to snapshot_download() for the Echo-Memory checkpoint.
ECHO_CKPT_REPO = "Echo-Team/Echo-Memory"
# Path of the checkpoint file, relative to the root of ECHO_CKPT_REPO.
ECHO_CKPT_PATH = "context_k1/epoch-0.safetensors"
# Negative prompt handed to run_one_chunk() on every generation request.
DEFAULT_NEGATIVE_PROMPT = "oversaturated colors, overexposed, static, blurry details"
# Resolution the context image is resized to and the video is generated at.
HEIGHT = 352
WIDTH = 640
# Number of frames generated per request.
NUM_FRAMES = 81
# Frame rate passed to save_video() when writing the output file.
FPS = 15
# ── Model loading at module scope (ZeroGPU: .to("cuda") is intercepted) ──────
print("[app] Downloading Wan2.1-T2V-1.3B base model...")
_base_dir = snapshot_download(WAN_BASE_MODEL_ID)
print("[app] Downloading Echo-Memory checkpoint...")
# Only ECHO_CKPT_PATH is read from the checkpoint repo, so restrict the
# snapshot to that single file instead of downloading every file in the repo.
_ckpt_dir = snapshot_download(ECHO_CKPT_REPO, allow_patterns=[ECHO_CKPT_PATH])

# Resolve the individual files the pipeline loader is given below.
_ckpt_path = os.path.join(_ckpt_dir, ECHO_CKPT_PATH)
_dit_path = os.path.join(_base_dir, "diffusion_pytorch_model.safetensors")
_text_encoder_path = os.path.join(_base_dir, "models_t5_umt5-xxl-enc-bf16.pth")
_vae_path = os.path.join(_base_dir, "Wan2.1_VAE.pth")
_tokenizer_path = os.path.join(_base_dir, "google", "umt5-xxl")

# Fail fast with a clear message if a snapshot is missing an expected file,
# rather than surfacing a less specific error from inside the model loader.
_missing_paths = [
    path
    for path in (_ckpt_path, _dit_path, _text_encoder_path, _vae_path, _tokenizer_path)
    if not os.path.exists(path)
]
if _missing_paths:
    raise FileNotFoundError(
        "Missing model files after download: " + ", ".join(_missing_paths)
    )
from env.loop_utils import load_pipeline_and_ckpt
from env.run_replay_loop_two_chunk import run_one_chunk, encode_context_frames_per_frame
from env.memory_baseline_runtime import MemoryProfile, infer_memory_profile_spec
from diffsynth import save_video
from src.model_training.fov_retrieval import compute_rotation_list
# ── Inline helpers from inference/unified_inference.py ──────────────────────
def resolve_memory_profile(memory_type: str, ckpt_path: str) -> MemoryProfile:
    """Pick the MemoryProfile to use for a given memory type and checkpoint.

    Args:
        memory_type: Name of the memory configuration. "context_k1",
            "context_k5" and "context_k20" map to built-in presets that only
            set ``context_override`` (to 1, 5 and 20 respectively).
        ckpt_path: Checkpoint path handed to ``infer_memory_profile_spec`` when
            ``memory_type`` is not one of the presets.

    Returns:
        The matching preset, otherwise the profile of the spec detected from
        ``ckpt_path``, otherwise a default-constructed ``MemoryProfile``.
    """
    context_presets = {
        f"context_k{k}": MemoryProfile(context_override=k) for k in (1, 5, 20)
    }
    preset = context_presets.get(memory_type)
    if preset is not None:
        print(f"[app] Using context learning profile: {memory_type}")
        return preset

    # Not a known preset: try to derive the profile from the checkpoint path.
    detected = infer_memory_profile_spec(ckpt_path)
    if detected is None:
        return MemoryProfile()
    print(f"[app] Auto-detected memory profile: {detected.profile_id}")
    return detected.profile
def apply_profile_to_pipe(pipe, profile: MemoryProfile) -> None:
    """Copy the memory-related settings of `profile` onto attributes of `pipe`.

    Boolean settings are coerced with ``bool``. Numeric settings whose profile
    value is falsy (``None`` or zero) are replaced by a fixed default before
    being cast. ``spatial_memory_inject_mode`` is written only when the profile
    provides a truthy value; otherwise the attribute on `pipe` is left alone.

    Args:
        pipe: Pipeline object whose attributes are assigned in place.
        profile: Source of the settings.
    """

    def _number(value, fallback, cast):
        # Same semantics as ``cast(value or fallback)``.
        return cast(value if value else fallback)

    # FramePack / context-attention settings.
    pipe.use_framepack_memory = bool(profile.use_framepack_memory)
    pipe.context_temporal_decay = _number(profile.context_temporal_decay, 1.0, float)
    pipe.context_attention_weight = _number(profile.context_attention_weight, 1.0, float)
    pipe.use_framepack_length_compress = bool(profile.use_framepack_length_compress)
    pipe.framepack_ratio = _number(profile.framepack_ratio, 2, int)

    # Spatial-memory settings.
    pipe.use_spatial_memory = bool(profile.use_spatial_memory)
    pipe.spatial_memory_tokens = _number(profile.spatial_memory_tokens, 64, int)
    inject_mode = profile.spatial_memory_inject_mode
    if inject_mode:
        pipe.spatial_memory_inject_mode = str(inject_mode)
    pipe.use_spatial_memory_legacy = bool(profile.use_spatial_memory_legacy)

    # These two flags are read with a default, so a profile without them
    # simply yields False.
    for optional_flag in ("use_block_wise_ssm", "use_videossm_hybrid"):
        setattr(pipe, optional_flag, bool(getattr(profile, optional_flag, False)))
# Build the pipeline once at import time from the files resolved above.
print("[app] Loading pipeline (DiT -> cuda)...")
pipe = load_pipeline_and_ckpt(
    # Weight files
    dit_path=_dit_path,
    text_encoder_path=_text_encoder_path,
    vae_path=_vae_path,
    tokenizer_path=_tokenizer_path,
    ckpt_path=_ckpt_path,
    # Runtime / architecture options
    device="cuda",
    add_action_attn=False,
    action_use_temporal_attention=True,
)

# Configure the pipeline with the "context_k1" memory profile, matching the
# checkpoint selected by ECHO_CKPT_PATH.
_profile = resolve_memory_profile("context_k1", _ckpt_path)
apply_profile_to_pipe(pipe, _profile)
print("[app] Model loaded and memory profile applied.")
def _build_rotation_action(deg: float, clockwise: bool, num_frames: int = 81) -> dict:
    """Build per-frame actions for a yaw rotation that grows linearly over a chunk.

    Frame 0 gets a yaw of 0 and the last frame gets the full signed angle;
    frames in between are interpolated linearly by frame index.

    Args:
        deg: rotation magnitude in degrees.
        clockwise: if True the yaw is negated (clockwise); otherwise it is
            used as given (counter-clockwise).
        num_frames: number of frames in the chunk.

    Returns:
        dict mapping the frame index (as a str) to the result of
        ``compute_rotation_list([0.0, 0.0, 0.0, yaw])`` for that frame
        (a 12-D RT list, per the original documentation of this helper).
    """
    signed_total = -deg if clockwise else deg
    # Guard against division by zero when the chunk holds a single frame.
    last_index = max(1, num_frames - 1)
    return {
        str(idx): compute_rotation_list(
            [0.0, 0.0, 0.0, (idx / last_index) * signed_total]
        )
        for idx in range(num_frames)
    }
def _is_clockwise(rotation_direction: str) -> bool:
    """Return True when the direction label asks for a clockwise (right) turn."""
    # Compare whole words: a plain substring test for "CW" also matches the
    # "CCW" inside "Left (CCW)", which would make every request clockwise.
    label = (rotation_direction or "").lower()
    words = set(label.replace("(", " ").replace(")", " ").split())
    return not words.isdisjoint({"right", "cw", "clockwise"})


@spaces.GPU(duration=180)
def generate(
    context_image: Image.Image | None,
    prompt: str,
    rotation_direction: str,
    rotation_degrees: float,
    seed: int,
    num_inference_steps: int,
    cfg_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an action-conditioned video from an initial frame and a text prompt.

    Args:
        context_image: Initial frame (first image of the video).
        prompt: Text description of the scene.
        rotation_direction: Camera rotation direction ("Left (CCW)" or "Right (CW)").
        rotation_degrees: Total rotation in degrees (e.g. 45).
        seed: RNG seed for reproducibility.
        num_inference_steps: Number of diffusion denoising steps.
        cfg_scale: Classifier-free guidance scale.

    Returns:
        A ``(video_path, status_message)`` tuple. ``video_path`` is ``None``
        when the image or the prompt is missing.
    """
    if context_image is None:
        return None, "Please provide an initial image."
    if not prompt or not prompt.strip():
        return None, "Please provide a text prompt."

    clockwise = _is_clockwise(rotation_direction)
    deg = float(rotation_degrees)

    # Resize context image to model resolution
    ctx_pil = context_image.convert("RGB").resize((WIDTH, HEIGHT), Image.LANCZOS)

    # Encode context frame through VAE
    print("[generate] Encoding context image...")
    pipe.load_models_to_device(["vae"])
    with torch.no_grad():
        context_latents = encode_context_frames_per_frame(pipe, [ctx_pil], pipe.device)
    num_context_frames = 1
    # Flattened 3x3 identity followed by three zeros, used as the action of
    # the single context frame.
    identity_rt = [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
    context_actions_t = torch.tensor([identity_rt], dtype=torch.float32)

    # Build the camera rotation actions; run_one_chunk takes them as a file path.
    cam_pose_actions = _build_rotation_action(deg, clockwise, NUM_FRAMES)
    action_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".json", delete=False, mode="w") as action_tmp:
            action_path = action_tmp.name
            json.dump(cam_pose_actions, action_tmp)

        # Generate video
        print(f"[generate] Generating {NUM_FRAMES} frames @ {WIDTH}x{HEIGHT}, rotation={deg}° {'CW' if clockwise else 'CCW'}")
        frames = run_one_chunk(
            pipe=pipe,
            prompt=prompt,
            use_negative_prompt=DEFAULT_NEGATIVE_PROMPT,
            action_path=action_path,
            context_latents=context_latents,
            num_context_frames=num_context_frames,
            context_actions_t=context_actions_t,
            chunk_frames=NUM_FRAMES,
            h=HEIGHT,
            w=WIDTH,
            seed=int(seed),
            sigma_shift=15.0,
            num_inference_steps=int(num_inference_steps),
            cfg_scale=float(cfg_scale),
            log_prefix="[generate]",
        )
    finally:
        # The action file is only needed during generation; remove it so
        # repeated requests do not pile up temp files. Cleanup is best-effort.
        if action_path is not None:
            try:
                os.remove(action_path)
            except OSError:
                pass

    # Save to temporary file (kept on disk: the path is returned to the caller)
    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    tmp.close()
    save_video(frames, tmp.name, fps=FPS, quality=5)
    print(f"[generate] Video saved to {tmp.name}")
    return tmp.name, f"Generated {len(frames)} frames with {deg}° {'clockwise' if clockwise else 'counter-clockwise'} rotation."
# ── Gradio UI ────────────────────────────────────────────────────────────────
CSS = """
#col-container { max-width: 1100px; margin: 0 auto; }
.dark .gradio-container { color: var(--body-text-color); }
"""

with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Echo-Memory: Action-Conditioned World Model")
    gr.Markdown(
        "Generate a video from an initial frame, a text prompt, and a camera rotation action. "
        "Based on the [Echo-Memory](https://huggingface.co/papers/2606.09803) paper — "
        "a controlled study of memory in action world models using the Wan 2.1 1.3B backbone."
    )

    with gr.Row():
        # Left column: inputs
        with gr.Column(scale=1):
            context_image = gr.Image(label="Initial Frame", type="pil", height=300)
            prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="A toy bear on a table, the camera rotates around it",
                lines=2,
            )
            with gr.Row():
                rotation_direction = gr.Radio(
                    label="Camera Rotation",
                    choices=["Left (CCW)", "Right (CW)"],
                    value="Left (CCW)",
                )
                rotation_degrees = gr.Slider(
                    label="Rotation Degrees",
                    minimum=5,
                    maximum=90,
                    value=45,
                    step=5,
                )
            run_btn = gr.Button("Generate Video", variant="primary")

        # Right column: outputs
        with gr.Column(scale=1):
            video_output = gr.Video(label="Generated Video", height=300)
            status_text = gr.Textbox(label="Status", interactive=False)

    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            seed = gr.Number(label="Seed", value=42, precision=0)
            num_inference_steps = gr.Slider(
                label="Inference Steps",
                minimum=10,
                maximum=100,
                value=50,
                step=5,
            )
            cfg_scale = gr.Slider(
                label="CFG Scale",
                minimum=1.0,
                maximum=10.0,
                value=5.0,
                step=0.5,
            )

    # Examples fill only the first four inputs; the advanced settings keep
    # whatever values they currently hold.
    gr.Examples(
        examples=[
            ["examples/1774363417.png", "A toy bear on a table, the camera rotates around it", "Left (CCW)", 45],
            ["examples/1774363487.png", "A decorative object on a surface, rotating view", "Right (CW)", 45],
            ["examples/1774363572.png", "A scene with objects on a table, camera pans", "Left (CCW)", 30],
        ],
        inputs=[context_image, prompt, rotation_direction, rotation_degrees],
    )

    gr.Markdown(
        "---\n"
        "**Model:** [Echo-Team/Echo-Memory](https://huggingface.co/Echo-Team/Echo-Memory) · "
        "**Backbone:** [Wan-AI/Wan2.1-T2V-1.3B](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) · "
        "**Paper:** [arXiv:2606.09803](https://arxiv.org/abs/2606.09803) · "
        "**Code:** [GitHub](https://github.com/Echo-Team-Joy-Future-Academy-JD/Echo-Memory)"
    )

    # Input order must match the positional parameters of generate().
    run_btn.click(
        fn=generate,
        inputs=[context_image, prompt, rotation_direction, rotation_degrees, seed, num_inference_steps, cfg_scale],
        outputs=[video_output, status_text],
        api_name="generate",
    )

demo.launch(theme=gr.themes.Citrus(), css=CSS, mcp_server=True)