WorldStereo-memory-dmd-fp8_e4m3fn_scaled

A scaled FP8 (e4m3fn) quantization of the worldstereo-memory-dmd variant of hanshanxue/WorldStereo. It cuts the weight footprint from 34.86 GB to 20.35 GB, so the 14B WorldStereo DiT fits resident on 24 GB consumer GPUs (RTX 3090, 4090) instead of needing partial-load streaming.

| | bf16 source | fp8_e4m3fn_scaled (this repo) |
|---|---|---|
| Size on disk | 34.86 GB | 20.35 GB (-41.6%) |
| Tensors | 1799 (all BF16) | 551 FP8 + 1800 BF16 |
| Fits in 24 GB VRAM | No (needs partial_load) | Yes (resident) |
| Quality | reference | typical fp8_scaled drop: <1% perceptible on short clips; faint color drift / micro-detail loss possible on long sequences (see Quality) |
| Speed on Ampere (3090) | reference | about the same (no native FP8 matmul; weights are upcast to BF16 per matmul) |
| Speed on Ada/Hopper (4090, H100) | reference | >2x via torch._scaled_mm |

Quantization recipe (verbatim from kijai/ComfyUI-WanVideoWrapper)

For each tensor T in the source state-dict:

  1. If T's name contains any of these keywords -> stored unchanged in BF16:

    norm, bias, time_in, time_, patch_embedding, img_emb, modulation,
    text_embedding, adapter, add, ref_conv, audio_proj
    

    This keeps numerically sensitive tensors (norms, biases, timestep/modulation layers, embeddings) and the I2V image cross-attention projections (add_*) at full precision.

  2. If T is .weight AND has rank >= 2 AND survived the exclusion check -> cast to fp8_e4m3fn with per-tensor scale:

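    import torch

    FP8_MAX = 448.0  # largest finite float8_e4m3fn value
    scale = T.float().abs().amax() / FP8_MAX  # per-tensor fp32 scale (T must not be all zeros)
    T_fp8 = (T.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)  # clamp guards against float rounding just past 448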
    

    Stored as two tensors:

    • <key> (the fp8 weight)
    • <key>.scale_weight (bfloat16 scalar)
  3. A marker tensor scaled_fp8 (bf16 zero) is added so loaders that check "scaled_fp8" in state_dict auto-detect the scaled-fp8 format.
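The three steps can be transcribed into a small planning function that decides, per key, what the output file will contain. This is a hypothetical sketch in plain Python: no tensors are involved, and the function names are illustrative. It assumes the ComfyUI-style scale key, where `<module>.weight` gets a sibling `<module>.scale_weight`. Check the actual file if your loader depends on the exact name.

```python
# Hypothetical helper: a literal transcription of steps 1-3, working on
# (key, shape) pairs only, so no tensor library is required.
EXCLUDE_KEYWORDS = (
    "norm", "bias", "time_in", "time_", "patch_embedding", "img_emb", "modulation",
    "text_embedding", "adapter", "add", "ref_conv", "audio_proj",
)


def plan_tensor(key: str, shape: tuple) -> str:
    """Return 'bf16' (stored unchanged) or 'fp8' (quantized with a scale)."""
    # Step 1: plain substring match, so "add" also catches add_k_proj / add_v_proj.
    if any(word in key for word in EXCLUDE_KEYWORDS):
        return "bf16"
    # Step 2: only .weight tensors of rank >= 2 that survived step 1 are quantized.
    if key.endswith(".weight") and len(shape) >= 2:
        return "fp8"
    return "bf16"


def scale_key(weight_key: str) -> str:
    # ASSUMPTION: ComfyUI-style naming, "<module>.weight" -> "<module>.scale_weight".
    return weight_key[: -len(".weight")] + ".scale_weight"


def plan_state_dict(shapes: dict) -> dict:
    """Map every output key to its stored dtype, including scales and the marker."""
    out = {}
    for key, shape in shapes.items():
        if plan_tensor(key, shape) == "fp8":
            out[key] = "fp8_e4m3fn"
            out[scale_key(key)] = "bf16"  # per-tensor scalar scale
        else:
            out[key] = "bf16"
    out["scaled_fp8"] = "bf16"  # step 3: format marker (a bf16 zero)
    return out
```

Step 1 is a plain substring test, so any key containing "add" is also excluded, including unrelated names such as a hypothetical "padding". The sketch keeps that behavior rather than guessing at a stricter match.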

Inference

At runtime the scale must be re-applied. Two paths:

Ampere (3090) - upcast on matmul

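import torch
def fp8_scaled_linear(x, w_fp8, scale, bias=None):
    # Upcast the fp8 weight to the activation dtype and re-apply the per-tensor scale.
    return torch.nn.functional.linear(x, w_fp8.to(x.dtype) * scale, bias)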

Ada/Hopper (4090, H100) - native fp8 matmul

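import torch
def fp8_scaled_linear_fast(x, w_fp8, scale_w, bias=None, base_dtype=torch.bfloat16):
    in_shape = x.shape
    x_fp8 = x.clamp(-448, 448).to(torch.float8_e4m3fn).reshape(-1, in_shape[-1]).contiguous()  # activation scale of 1
    scale_in = torch.ones((), device=x.device, dtype=torch.float32)
    out = torch._scaled_mm(x_fp8, w_fp8.t(),  # w_fp8.t() is column-major, as _scaled_mm requires
                           scale_a=scale_in, scale_b=scale_w.float(),  # scales must be fp32 (stored scale is bf16)
                           bias=bias, out_dtype=base_dtype)
    if isinstance(out, tuple):  # older PyTorch releases returned (out, amax)
        out = out[0]
    return out.reshape(*in_shape[:-1], w_fp8.shape[0])  # restore the leading (batch, token) dims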
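Both paths work because the scale is a single per-tensor scalar, so it factors out of the matrix product. Write $s_w$ for scale_weight, $Q$ for the stored FP8 weight, $b$ for the bias, and $s_x$ for an activation scale (the fast path above fixes $s_x = 1$):

```latex
\begin{aligned}
W &\approx s_w\, Q, & y &= x W^{\top} + b \approx s_w\,\bigl(x Q^{\top}\bigr) + b,\\
x &\approx s_x\, X_8, & y &\approx s_x s_w\,\bigl(X_8 Q^{\top}\bigr) + b.
\end{aligned}
```

The upcast path evaluates the first form, with $s_w Q$ materialized in BF16. The fast path evaluates the second form. It passes $s_x$ and $s_w$ to torch._scaled_mm as scale_a and scale_b, so the FP8 product is rescaled in the output.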

Reference implementations:

  • kijai/ComfyUI-WanVideoWrapper/fp8_optimization.py (fast path)
  • kijai/ComfyUI-WanVideoWrapper/nodes_model_loading.py _replace_linear (slow path)
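Which path to use depends on the GPU and, for the fast path, on the layer shape. The helper below is a hypothetical sketch built on two assumptions that should be checked against your PyTorch version:

- Native FP8 matmul requires compute capability 8.9 or newer. For reference, the RTX 3090 is 8.6, Ada (4090) is 8.9 and Hopper (H100) is 9.0.
- torch._scaled_mm needs both weight dimensions to be multiples of 16.

```python
# Hypothetical helper: pick between fp8_scaled_linear (upcast) and
# fp8_scaled_linear_fast (scaled_mm).
# ASSUMPTIONS: native FP8 matmul needs compute capability >= 8.9, and
# torch._scaled_mm needs in/out features divisible by 16.

def pick_fp8_path(capability: tuple, out_features: int, in_features: int) -> str:
    """Return 'scaled_mm' for the native fp8 path or 'upcast' for the bf16 fallback."""
    has_fp8_tensor_cores = tuple(capability) >= (8, 9)
    shapes_ok = in_features % 16 == 0 and out_features % 16 == 0
    return "scaled_mm" if has_fp8_tensor_cores and shapes_ok else "upcast"
```

On a live system, capability would come from torch.cuda.get_device_capability(), which returns a (major, minor) tuple, and the feature sizes from w_fp8.shape.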

Naming convention

This file uses diffusers-style keys (blocks.0.attn1.to_q.weight, blocks.0.ffn.net.0.proj.weight), matching the original hanshanxue/WorldStereo source layout. This is different from Kijai's published Wan2.1 fp8 files which use native Wan keys (blocks.0.self_attn.q.weight). If your loader expects native names, run a key remap (see diffusers_to_native_wan in ComfyUI-WorldStereo).
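For the Linear layers named above, the remap is mostly string substitution. The sketch below is hypothetical and intentionally partial. It covers only the self-attention, text cross-attention and FFN Linear layers. The native names come from the upstream Wan2.1 modules as we understand them: self_attn.{q,k,v,o}, cross_attn.{q,k,v,o}, ffn.0 and ffn.2.

It is not the diffusers_to_native_wan function from ComfyUI-WorldStereo, which should be preferred. Norms, embedders, the I2V image projections and the controlnet branch pass through unchanged.

```python
# Hypothetical, partial diffusers -> native Wan2.1 key remap (Linear layers only).
_DIFFUSERS_TO_NATIVE = (
    (".attn1.to_q.", ".self_attn.q."),
    (".attn1.to_k.", ".self_attn.k."),
    (".attn1.to_v.", ".self_attn.v."),
    (".attn1.to_out.0.", ".self_attn.o."),
    (".attn2.to_q.", ".cross_attn.q."),
    (".attn2.to_k.", ".cross_attn.k."),
    (".attn2.to_v.", ".cross_attn.v."),
    (".attn2.to_out.0.", ".cross_attn.o."),
    (".ffn.net.0.proj.", ".ffn.0."),
    (".ffn.net.2.", ".ffn.2."),
)


def remap_key(key: str) -> str:
    """Rename one key; the trailing .weight / .bias / .scale_weight suffix is kept."""
    for old, new in _DIFFUSERS_TO_NATIVE:
        if old in key:
            return key.replace(old, new, 1)
    return key  # anything not covered passes through unchanged


def remap_state_dict(sd: dict) -> dict:
    return {remap_key(k): v for k, v in sd.items()}
```

The patterns match the module path rather than the suffix, so each FP8 weight and its scale_weight are renamed together.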

What's in the file

  • 551 FP8 tensors (14.5 GB): rank-2 Linear weights only (self-attn Q/K/V/O, FFN projections, controlnet linears). Conv weights stay BF16 because no scaled-FP8 conv kernel exists in PyTorch / ComfyUI.
  • 1800 BF16 tensors (5.84 GB):
    • 362 norm tensors (RMSNorm gains): 1.26 GB
    • 80 add_* tensors (I2V cross-attention to image embeddings): 4.19 GB
    • 3 time_* tensors (timestep MLP): 0.37 GB
    • 759 bias + others: ~0.02 GB
    • 551 *.scale_weight scalars: ~1.1 KB (551 × 2 bytes)
  • 1 marker tensor scaled_fp8
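To check these counts and per-dtype sizes against the file itself, a header-only inventory is enough. This is a minimal sketch using only the Python standard library and the published safetensors layout: an 8-byte little-endian header length, then a JSON header whose entries carry dtype, shape and data_offsets. The function names are hypothetical. We expect dtype strings such as BF16 and F8_E4M3, but treat the exact labels as an assumption.

```python
import json
import struct
from collections import Counter


def tally(f) -> dict:
    """Count tensors and payload bytes per dtype, reading only the safetensors header.

    Layout: an 8-byte little-endian header length, then a JSON object mapping
    tensor names to {"dtype", "shape", "data_offsets"}; "__metadata__" is optional.
    """
    (header_len,) = struct.unpack("<Q", f.read(8))
    header = json.loads(f.read(header_len))
    counts, nbytes = Counter(), Counter()
    for name, info in header.items():
        if name == "__metadata__":  # free-form string metadata, not a tensor
            continue
        start, end = info["data_offsets"]  # byte range within the data section
        counts[info["dtype"]] += 1
        nbytes[info["dtype"]] += end - start
    return {"counts": dict(counts), "bytes": dict(nbytes)}


def inventory(path: str) -> dict:
    """Open a .safetensors file and tally it without loading any tensor data."""
    with open(path, "rb") as f:
        return tally(f)
```

Calling inventory("WorldStereo-memory-dmd-fp8_e4m3fn_scaled.safetensors") should show about 14.5e9 bytes under the FP8 dtype if the breakdown above holds. The scale scalars and the marker are counted under BF16 alongside the unchanged tensors.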

Quality

This recipe (Kijai's fp8_e4m3fn_scaled) is the community standard for Wan2.1 14B inference on 24 GB GPUs. Documented behavior:

  • Subjective: indistinguishable from bf16 on most short clips (Kijai's own statement: "never any difference remotely [from fp16]" in his per-tensor-scaled tests).
  • Known artifacts on long clips: faint color drift, occasional minor detail loss; usually not visible side-by-side without close inspection. References:
  • Round-trip on a sample 5120x5120 attention weight: max_abs_err = 0.005 (3.7% of the tensor's max absolute value, typical for per-tensor-scaled e4m3fn).
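The round-trip figure can be reproduced in spirit without a GPU. This is a sketch, not the script behind the number above. It emulates float8_e4m3fn rounding in pure Python: round-to-nearest-even, 3 mantissa bits, smallest normal exponent -6, and saturation at ±448 like the recipe's clamp. It uses float64 arithmetic, so the last bits can differ from PyTorch's float32 path.

For the largest-magnitude weights, the worst-case error is half the spacing of the top binade: 16/448 ≈ 3.6% of amax. Rounding the stored scale to BF16 can add a little more, which is consistent with the reported 3.7%.

```python
import math

FP8_MAX = 448.0  # largest finite float8_e4m3fn value


def round_e4m3fn(v: float) -> float:
    """Round to the nearest float8_e4m3fn value (ties to even), saturating at +-448."""
    if v == 0.0 or math.isnan(v):
        return v
    a = min(abs(v), FP8_MAX)             # same saturation as the recipe's clamp
    exp = max(math.frexp(a)[1] - 1, -6)  # binade of the leading bit; -6 = smallest normal exponent
    step = 2.0 ** (exp - 3)              # 3 mantissa bits -> 8 values per binade
    q = round(a / step) * step           # Python's round() is round-half-to-even
    return math.copysign(min(q, FP8_MAX), v)


def round_trip_max_abs_err(weights: list) -> float:
    """Quantize with a per-tensor scale as in step 2, dequantize, return max |w - w_hat|.

    Assumes at least one non-zero weight (the recipe has the same requirement).
    """
    scale = max(abs(w) for w in weights) / FP8_MAX
    return max(abs(w - round_e4m3fn(w / scale) * scale) for w in weights)
```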

If you observe quality regressions specific to WorldStereo's camera/scene conditioning, fall back to bf16 with partial-load streaming.

Provenance

  • Source: hanshanxue/WorldStereo (worldstereo-memory-dmd variant, commit 2adb716)
  • Source base model: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers (Apache-2.0)
    • Verified: WorldStereo's Wan base is frozen (sampled tensors differ from upstream only by fp16->bf16 rounding). The novel content is the 1.76 GB controlnet branch under controlnet.controlnet_blocks.*, which carries WorldStereo's camera/scene-render conditioning.
  • Quantization tool: cook_worldstereo_fp8.py (in this repo)
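The provenance note doesn't say how the comparison was done. One way to run such a spot check is to confirm that every WorldStereo value is exactly the BF16 rounding of the upstream FP16 value. The sketch assumes you've pulled matching samples of a tensor out of both checkpoints as Python floats. The helpers are hypothetical and emulate BF16 round-to-nearest-even with struct, so no tensor library is needed.

```python
import struct


def to_bf16(v: float) -> float:
    """Round a float to bfloat16 (round-to-nearest-even), returned as a Python float."""
    bits = struct.unpack("<I", struct.pack("<f", v))[0]  # float32 bit pattern
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000  # ties go to an even upper half
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def frozen_like(upstream: list, candidate: list) -> bool:
    """True if each candidate value is exactly the bf16 rounding of its upstream value."""
    return len(upstream) == len(candidate) and all(
        to_bf16(u) == c for u, c in zip(upstream, candidate)
    )
```

Every FP16 value is exactly representable in float32, and BF16 has a wider exponent range than FP16. A frozen base should therefore match exactly, with no tolerance needed.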

License

  • WorldStereo overlay: MIT (inherited from hanshanxue/WorldStereo)
  • Wan2.1 base weights: Apache-2.0 (inherited from Wan-AI/Wan2.1-I2V-14B-480P-Diffusers)

A NOTICE file in this repo provides Apache-2.0 attribution and a statement of changes per section 4 of the license.

Files

  • WorldStereo-memory-dmd-fp8_e4m3fn_scaled.safetensors - the quantized model
  • cook_worldstereo_fp8.py - the quantization script (reproducible)
  • config.json - the source worldstereo-memory-dmd config (carries the scale_map, base_model path, sampling settings)
  • README.md - this file
  • NOTICE - Apache-2.0 attribution
  • LICENSE - MIT (for the WorldStereo overlay portion)