WorldStereo-memory-dmd-fp8_e4m3fn_scaled
A scaled FP8 (e4m3fn) quantization of hanshanxue/WorldStereo's
worldstereo-memory-dmd variant. It shrinks the weights from ~35 GB to ~20 GB so the
14B WorldStereo DiT fits resident on 24 GB consumer GPUs (3090, 4090) instead
of needing partial-load streaming.
| | bf16 source | fp8_e4m3fn_scaled (this repo) |
|---|---|---|
| Size on disk | 34.86 GB | 20.35 GB (-41.6%) |
| Tensors | 1799 (all BF16) | 551 FP8 + 1800 BF16 |
| Fits in 24 GB VRAM | No (needs partial_load) | Yes (resident) |
| Quality | reference | typical fp8_scaled drop: <1% perceptible difference on short clips; faint color drift or micro-detail loss possible on long sequences (see Quality below) |
| Speed on Ampere (3090) | reference | same (no native fp8 matmul; weight upcast to bf16 per matmul) |
| Speed on Ada/Hopper (4090, H100) | reference | >2x via torch._scaled_mm |
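The two speed rows depend on whether the GPU has native FP8 matmul. A minimal sketch for choosing the path at load time, assuming the FP8 kernels behind torch._scaled_mm need CUDA compute capability 8.9 or newer (Ada is 8.9, Hopper 9.0, while the Ampere 3090 is 8.6). `pick_fp8_path` is a hypothetical helper, not part of ComfyUI-WanVideoWrapper:

```python
import torch


def pick_fp8_path(device=None) -> str:
    """Return "fast" (native fp8 via torch._scaled_mm) or "upcast" (bf16 matmul).

    Assumption: the fp8 kernels need CUDA compute capability >= 8.9
    (Ada = 8.9, Hopper = 9.0); Ampere (3090 = 8.6) and CPU upcast instead.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    if device.type != "cuda":
        return "upcast"
    major_minor = torch.cuda.get_device_capability(device)
    return "fast" if major_minor >= (8, 9) else "upcast"
```

Under that assumption a 3090 gets "upcast" and a 4090 or H100 gets "fast"; CPU always upcasts.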
Quantization recipe (verbatim from kijai/ComfyUI-WanVideoWrapper)
For each tensor T in the source state-dict:
- If T's name contains any of these keywords, it is stored unchanged in BF16: `norm`, `bias`, `time_in`, `time_`, `patch_embedding`, `img_emb`, `modulation`, `text_embedding`, `adapter`, `add`, `ref_conv`, `audio_proj`. This keeps numerically sensitive tensors (norms, biases, modulation) and the I2V cross-attn projections at full precision.
- If T is a `.weight`, has rank >= 2, and survived the exclusion check, it is cast to fp8_e4m3fn with a per-tensor scale:
  ```python
  FP8_MAX = 448.0
  scale = T.float().abs().amax() / FP8_MAX
  T_fp8 = (T.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
  ```
  It is stored as two tensors: `<key>` (the fp8 weight) and `<key>.scale_weight` (a bfloat16 scalar).
- A marker tensor `scaled_fp8` (bf16 zero) is added so loaders that check `"scaled_fp8" in state_dict` auto-detect the scaled-fp8 format.
Inference
At runtime the scale must be re-applied. Two paths:
Ampere (3090) - upcast on matmul
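```python
import torch
def fp8_scaled_linear(x, w_fp8, scale, bias):
    # Upcast the fp8 weight to the activation dtype and re-apply the per-tensor scale.
    return torch.nn.functional.linear(x, w_fp8.to(x.dtype) * scale.to(x.dtype), bias)
```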
Ada/Hopper (4090, H100) - native fp8 matmul
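```python
import torch
def fp8_scaled_linear_fast(x, w_fp8, scale_w, bias, base_dtype=torch.bfloat16):
    # Quantize activations to e4m3fn with a unit scale (clamp to the e4m3fn range first).
    x_fp8 = x.reshape(-1, x.shape[-1]).clamp(-448, 448).to(torch.float8_e4m3fn).contiguous()
    scale_in = torch.ones((), device=x.device, dtype=torch.float32)
    # _scaled_mm needs float32 scales, a column-major second operand (w_fp8.t()),
    # and a bias in the output dtype.
    out = torch._scaled_mm(x_fp8, w_fp8.t(), scale_a=scale_in, scale_b=scale_w.float(),
                           bias=None if bias is None else bias.to(base_dtype), out_dtype=base_dtype)
    out = out[0] if isinstance(out, tuple) else out  # some older PyTorch versions return a tuple
    return out.reshape(*x.shape[:-1], w_fp8.shape[0])  # restore the leading (batch, seq) dims
```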
Reference implementations:
- kijai/ComfyUI-WanVideoWrapper/fp8_optimization.py (fast path)
- kijai/ComfyUI-WanVideoWrapper/nodes_model_loading.py, _replace_linear (slow path)
Naming convention
This file uses diffusers-style keys (blocks.0.attn1.to_q.weight,
blocks.0.ffn.net.0.proj.weight), matching the original hanshanxue/WorldStereo
source layout. This is different from Kijai's published Wan2.1 fp8 files which
use native Wan keys (blocks.0.self_attn.q.weight). If your loader expects
native names, run a key remap (see diffusers_to_native_wan in
ComfyUI-WorldStereo).
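The remap is plain string rewriting on keys. The sketch below covers only the attention and FFN fragments, and the table is hypothetical and partial. It is built from the example pair above plus the standard diffusers/native Wan attention and FFN names. It leaves norms, the I2V image projections, embeddings and the controlnet branch untouched, so use diffusers_to_native_wan for a real conversion:

```python
# Hypothetical, partial diffusers -> native Wan key map, for illustration only.
# Trailing dots keep matches on whole module names and also remap ".scale_weight" keys.
DIFFUSERS_TO_NATIVE = (
    ("attn1.to_q.", "self_attn.q."),
    ("attn1.to_k.", "self_attn.k."),
    ("attn1.to_v.", "self_attn.v."),
    ("attn1.to_out.0.", "self_attn.o."),
    ("attn2.to_q.", "cross_attn.q."),
    ("attn2.to_k.", "cross_attn.k."),
    ("attn2.to_v.", "cross_attn.v."),
    ("attn2.to_out.0.", "cross_attn.o."),
    ("ffn.net.0.proj.", "ffn.0."),
    ("ffn.net.2.", "ffn.2."),
)


def remap_key(key: str) -> str:
    """Rewrite the first matching diffusers fragment; unmatched keys pass through."""
    for src, dst in DIFFUSERS_TO_NATIVE:
        if src in key:
            return key.replace(src, dst, 1)
    return key


def remap_state_dict(sd: dict) -> dict:
    return {remap_key(k): v for k, v in sd.items()}
```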
What's in the file
- 551 FP8 tensors (14.5 GB): rank-2 Linear weights only (self-attn Q/K/V/O, FFN projections, controlnet linears). Conv weights stay BF16, since no scaled-fp8 Conv kernel exists in PyTorch / ComfyUI.
- 1800 BF16 tensors (5.84 GB):
  - 362 `norm` tensors (RMSNorm gains): 1.26 GB
  - 80 `add_*` tensors (I2V cross-attention to image embeddings): 4.19 GB
  - 3 `time_*` tensors (timestep MLP): 0.37 GB
  - 759 `bias` + others: ~0.02 GB
  - 551 `*.scale_weight` scalars: ~1.1 KB (551 × 2 bytes)
- 1 marker tensor: `scaled_fp8`
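The following back-of-envelope check reproduces two entries in the list, in decimal GB. It assumes each of the 80 `add_*` tensors is a 5120×5120 matrix, which is the Wan2.1-14B hidden size and also the attention-weight shape quoted in the Quality section. `tensor_gb` is a hypothetical helper:

```python
BYTES_PER_ELEM = {"bf16": 2, "fp8_e4m3fn": 1}


def tensor_gb(shape, dtype):
    """Storage for one dense tensor in decimal GB (1 GB = 1e9 bytes)."""
    n = 1
    for dim in shape:
        n *= dim
    return n * BYTES_PER_ELEM[dtype] / 1e9


# Assumption: the 80 add_* tensors are 5120x5120 I2V image-attention projections.
add_bf16_gb = 80 * tensor_gb((5120, 5120), "bf16")       # ~4.19 GB, as listed
add_fp8_gb = 80 * tensor_gb((5120, 5120), "fp8_e4m3fn")  # ~2.10 GB if quantized too
# One bf16 scalar per fp8 weight: 551 * 2 bytes of tensor data.
scale_bytes = 551 * BYTES_PER_ELEM["bf16"]
```

Because the recipe's `add` keyword deliberately keeps these projections in BF16, roughly 2.1 GB of the remaining BF16 footprint comes from that precision choice.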
Quality
This recipe (Kijai's fp8_e4m3fn_scaled) is the community standard for Wan2.1 14B inference on 24 GB GPUs. Documented behavior:
- Subjective: indistinguishable from bf16 on most short clips (Kijai's own statement: "never any difference remotely [from fp16]" in his per-tensor-scaled tests).
- Known artifacts on long clips: faint color drift, occasional minor detail loss; usually not visible side-by-side without close inspection.
- Round-trip on a sample 5120x5120 attention weight: max_abs_err = 0.005 (3.7% of the tensor's max value, ~standard for e4m3fn-scaled).
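The 3.7% figure can be sanity-checked against the e4m3fn grid. The short derivation below assumes round-to-nearest and the per-tensor scale from the recipe. In the top binade [256, 448], consecutive e4m3fn values are 32 apart, so rounding moves a scaled value by at most 16, and lower binades are spaced more finely. Here $Q$ is the stored fp8 value, $s$ the exact fp32 scale and $\tilde s$ the bfloat16-stored scale (unit roundoff $2^{-8}$):

$$
\begin{aligned}
s &= \frac{\max|T|}{448}, \qquad |T/s - Q| \le \tfrac{1}{2}\cdot 32 = 16,\\
|T - Q\,s| &\le 16\,s = \tfrac{16}{448}\max|T| \approx 3.57\%\,\max|T| \quad \text{(exact scale)},\\
|T - Q\,\tilde s| &\le |T - Q\,s| + |Q|\,|s - \tilde s| \le 16\,s + 448\,s\,2^{-8} = \Bigl(\tfrac{16}{448} + 2^{-8}\Bigr)\max|T| \approx 3.96\%\,\max|T| \quad \text{(bf16 scale)}.
\end{aligned}
$$

The reported 3.7% falls between the two bounds. This is consistent with the bf16-rounded scale contributing part of the measured error, although the measurement does not isolate that effect.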
If you observe quality regressions specific to WorldStereo's camera/scene conditioning, fall back to bf16 with partial-load streaming.
Provenance
- Source: hanshanxue/WorldStereo (worldstereo-memory-dmd variant, commit 2adb716)
- Source base model: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers (Apache-2.0)
- Verified: WorldStereo's Wan base is frozen (sampled tensors differ from upstream only at fp16->bf16 ULPs). The novel content is the 1.76 GB controlnet branch under `controlnet.controlnet_blocks.*`, which carries WorldStereo's camera/scene-render conditioning.
- Quantization tool: cook_worldstereo_fp8.py (in this repo)
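The "Verified" check in the list above can be repeated tensor by tensor. The following is a sketch. It assumes you have already loaded a WorldStereo bf16 tensor and the matching upstream fp16 tensor (key matching not shown). `bf16_ulp_distance` and `matches_upstream` are hypothetical helpers, and the 1-ULP tolerance is an assumption about what "differ only at fp16->bf16 ULPs" means:

```python
import torch


def bf16_ulp_distance(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Element-wise distance between two bf16 tensors in bf16 ULPs (NaN/inf not handled)."""
    def ordered(x: torch.Tensor) -> torch.Tensor:
        bits = x.contiguous().view(torch.int16).to(torch.int32)
        # Map sign-magnitude bit patterns onto a monotonic integer line (+0 and -0 -> 0).
        return torch.where(bits < 0, -(bits & 0x7FFF), bits)
    return (ordered(a) - ordered(b)).abs()


def matches_upstream(ws_bf16: torch.Tensor, upstream_fp16: torch.Tensor, max_ulps: int = 1) -> bool:
    """True if the bf16 tensor equals the fp16 upstream tensor up to bf16 rounding."""
    dist = bf16_ulp_distance(ws_bf16, upstream_fp16.to(torch.bfloat16))
    return bool((dist <= max_ulps).all())
```

Counting in ULPs rather than using an absolute tolerance keeps the check meaningful for both tiny and large weights.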
License
- WorldStereo overlay: MIT (inherited from hanshanxue/WorldStereo)
- Wan2.1 base weights: Apache-2.0 (inherited from Wan-AI/Wan2.1-I2V-14B-480P-Diffusers)
A NOTICE file in this repo provides Apache-2.0 attribution and a statement of changes per section 4 of the license.
Files
- WorldStereo-memory-dmd-fp8_e4m3fn_scaled.safetensors - the quantized model
- cook_worldstereo_fp8.py - the quantization script (reproducible)
- config.json - the source worldstereo-memory-dmd config (carries the scale_map, base_model path, sampling settings)
- README.md - this file
- NOTICE - Apache-2.0 attribution
- LICENSE - MIT (for the WorldStereo overlay portion)
Model tree for apozz/WorldStereo-fp8
Base model
Wan-AI/Wan2.1-I2V-14B-480P-Diffusers