Skip to content

feat: add MXFP8 fused operators for Wan transformer inference on SM120#1090

Merged
STwangyingrui merged 4 commits into
ModelTC:mainfrom
Fatemanx:perf/wan-mxfp8-fuse-op
Jun 9, 2026
Merged

feat: add MXFP8 fused operators for Wan transformer inference on SM120#1090
STwangyingrui merged 4 commits into
ModelTC:mainfrom
Fatemanx:perf/wan-mxfp8-fuse-op

Conversation

@Fatemanx

@Fatemanx Fatemanx commented May 23, 2026

Copy link
Copy Markdown
Contributor

Implement MXFP8 fused CUDA operators for Wan transformer inference on Blackwell (SM120):

  1. scaled_mxfp8_gelu_quant: fuse GELU activation + E8M0 quantization
  2. scaled_mxfp8_modulate_quant: fuse Wan AdaLN x * (1 + scale) + shift + quantization for BF16 2D activations
  3. cutlass_scaled_mxfp8_mm_residual_gate: fuse MXFP8 GEMM with residual-gate update

Performance on RTX 5090 using the reproducible benchmark script added in this PR:

python lightx2v_kernel/test/mxfp8_mxfp8/bench_wan_fused_kernels.py --tokens 4096 --hidden 3072 --ffn 14336 --warmup 50 --iters 200

Shape source: official Wan-AI/Wan2.2-TI2V-5B config.json uses dim=3072, ffn_dim=14336, num_layers=30. tokens=4096 is the benchmark workload size, not model config metadata.

Measured results:

  • GELU+Quant: 2.48x faster (264.89us -> 106.74us)
  • Modulate+Quant: 8.14x faster (142.37us -> 17.50us)
  • GEMM+Residual+Gate: 1.00x faster (677.78us -> 677.09us)
  • End-to-end FFN: 1.11x faster (1612.94us -> 1458.86us, -154.08us per block)
  • Fused FFN path reduces launches from 7 to 3 per FFN block by fusing modulate+quant, gelu+quant, and residual+gate into the final GEMM epilogue

Review follow-ups:

  • Added mxfp8_fuse_enable config switch, default true, for regression comparison/debugging
  • Forwarded c_gate_msa through Wan base, offload, feature caching, lingbot, lingbot_fast, audio, and self-forcing paths that call the fused FFN contract
  • Extracted Wan-side MXFP8 fused helper logic into wan/infer/mxfp8_fuse.py
  • Documented scaled_mxfp8_modulate_quant as a Wan AdaLN-specific helper, not a generic modulate op
  • Kept 1D gate on the CUTLASS epilogue fast path; added vectorized 2D fallback paths for per-element gates
  • Added fused FFN None/in-place residual contract docs and clean_cuda_cache cleanup for fused intermediates
  • Added bench_wan_fused_kernels.py to reproduce the fused-kernel and end-to-end FFN performance table

Scope and fallback:

  • Requires CUDA SM120/SM120a; non-SM120 devices fall back to the existing path
  • Python-side fused modulation currently requires BF16 tensors and 2D activations
  • Coverage is limited to the Wan inference paths wired in this PR; paths with independent FFN implementations may still use their existing non-fused logic

Tested in exp_env on SM120:

  • python lightx2v_kernel/test/mxfp8_mxfp8/test_fused_ffn.py -> 14 tests OK
  • python -m unittest -q test_cases.test_wan_mxfp8_fuse_forwarding -> 6 tests OK
  • python -m pre_commit run --all-files -> passed
  • python lightx2v_kernel/test/mxfp8_mxfp8/bench_wan_fused_kernels.py --tokens 4096 --hidden 3072 --ffn 14336 --warmup 50 --iters 200 -> benchmark results above

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements MXFP8 fused operations for the Wan transformer model, specifically optimized for SM120/SM120a GPUs. The changes include new CUDA kernels for MXFP8 GeLU quantization, modulate quantization, and a fused residual-gate GEMM utilizing CUTLASS, along with corresponding Python wrappers and unit tests. Reviewer feedback highlights several optimization and code quality improvements: moving static parameter device transfers out of the inference loop, consolidating duplicated hardware validation logic into a common utility, replacing std::cerr with idiomatic TORCH_CHECK calls, improving numerical precision by avoiding intermediate rounding in the residual update, and eliminating dynamic tensor allocations in the performance-critical path.

return self._mxfp8_apply_quantized(module, input_tensor_quant, input_tensor_scale)

def _mxfp8_apply_quantized(self, module, input_tensor_quant, input_tensor_scale):
module.alpha = module.alpha.to(module.weight.device)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Moving module.alpha to the weight device in every iteration of the inference loop introduces unnecessary Python overhead and potential synchronization points. Since alpha is a static quantization parameter, it should ideally be moved to the correct device once during model initialization. At the very least, check if the device move is necessary before performing it to avoid redundant operations.

        if module.alpha.device != module.weight.device:
            module.alpha = module.alpha.to(module.weight.device)

Comment on lines +583 to +605
inline void check_sm120_or_throw(torch::Tensor const& tensor, char const* op_name) {
int device = tensor.get_device();
check_valid_cuda_device_index(device, op_name);

static std::array<std::once_flag, kMaxCudaDevices> device_once;
static std::array<int, kMaxCudaDevices> cached_major{};
static std::array<int, kMaxCudaDevices> cached_minor{};

std::call_once(device_once[device], [device]() {
CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&cached_major[device], cudaDevAttrComputeCapabilityMajor, device));
CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&cached_minor[device], cudaDevAttrComputeCapabilityMinor, device));
});

TORCH_CHECK(
cached_major[device] == 12,
op_name,
" is only supported on SM120/SM120a GPUs, got CUDA device ",
device,
" with compute capability ",
cached_major[device],
".",
cached_minor[device]);
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The check_sm120_or_throw utility is duplicated in both mxfp8_quant_kernels_sm120.cu and mxfp8_scaled_mm_kernels_sm120.cu. Furthermore, the implementations differ (one uses cudaDeviceGetAttribute while the other uses cudaGetDeviceProperties). This function should be moved to a common header (e.g., utils.h) to ensure consistency and reduce code duplication.

Comment on lines +752 to +755
default: {
std::cerr << "Observing: " << input.scalar_type() << " for the input datatype which is invalid";
throw std::runtime_error("Unsupported input data type for gelu_quantize_to_fp8.");
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using std::cerr for error reporting in a library is not recommended as it bypasses standard logging mechanisms and can be missed in production environments. Using TORCH_CHECK is more idiomatic in PyTorch extensions and provides a better error message including the tensor type.

    default: {
      TORCH_CHECK(false, "Unsupported input data type for gelu_quantize_to_fp8: ", input.scalar_type());
    }

Comment on lines +660 to +663
float product = __bfloat162float(ffn_out[idx]) * __bfloat162float(gate[gate_idx]);
__nv_bfloat16 rounded_product = __float2bfloat16(product);
float sum = __bfloat162float(residual[idx]) + __bfloat162float(rounded_product);
residual[idx] = __float2bfloat16(sum);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The intermediate rounding to __nv_bfloat16 at line 661 is unnecessary and reduces numerical precision. You can perform the addition in float before the final rounding to the output type to maintain higher accuracy during the residual update.

    float product = __bfloat162float(ffn_out[idx]) * __bfloat162float(gate[gate_idx]);
    float sum = __bfloat162float(residual[idx]) + product;
    residual[idx] = __float2bfloat16(sum);

residual, A, B, A_sf, B_sf, alpha, bias, gate, meta.m, meta.n, meta.k, stream);
return;
}
auto ffn_out = torch::empty_like(residual);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Allocating a new tensor using torch::empty_like(residual) in the fallback path for 2D gates will significantly degrade performance during prefill (where M is large). Since this PR aims to optimize the FFN block, you should avoid dynamic allocations in the hot path. Consider using a pre-allocated workspace or allowing the caller to provide an output buffer.

Fatemanx added a commit to Fatemanx/LightX2V that referenced this pull request May 23, 2026
Implement three fused CUDA kernels for MXFP8 quantized inference on Blackwell (SM120):

1. scaled_mxfp8_gelu_quant: fuse GELU activation + E8M0 quantization
2. scaled_mxfp8_modulate_quant: fuse scale/shift modulation + quantization
3. cutlass_scaled_mxfp8_mm_residual_gate: fuse GEMM + residual + gate in CUTLASS 3.x epilogue

Performance on RTX 5090 (Wan 5B FFN, m=4096, hidden=1536, ffn=8960):
- GELU+Quant: 1.30× faster (27.8μs → 21.3μs)
- Modulate+Quant: 3.26× faster (92.7μs → 28.5μs)
- GEMM+Residual+Gate: 1.40× faster (194.7μs → 138.9μs)
- End-to-end FFN: 1.20× faster (608μs → 505μs, -103μs per block)
- Reduces kernel launches from 7 to 3 per FFN block

Features:
- Supports all Wan tasks (t2v/i2v/flf2v/animate/s2v/rs2v)
- Auto-fallback on non-SM120 GPUs (H100/A100/RTX4090) with warning
- Handles FP16/BF16 activations (kernel auto-detects dtype)
- One-time device capability probe at init (eliminates ~4000 redundant checks per inference)

Tested: 10/10 unit tests pass, 6/6 fallback scenarios verified

Address review feedback (PR ModelTC#1090):
- Skip alpha device move when already on target device
- Extract check_sm120_or_throw to shared header sm120_utils.h
- Replace std::cerr with TORCH_CHECK in dtype switch fallbacks
- Avoid intermediate BF16 round in residual_gate kernel
- Apply ruff-format

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Fatemanx Fatemanx force-pushed the perf/wan-mxfp8-fuse-op branch from 3b4ec5e to 94816cc Compare May 23, 2026 04:45
@Fatemanx

Copy link
Copy Markdown
Contributor Author

Thanks @gemini-code-assist for the thorough review! Addressed all five comments in the latest amended commit:

  • Comment 1 (alpha device move): Added if module.alpha.device != module.weight.device guard in both _mxfp8_apply_quantized and _mxfp8_apply_residual_gate_quantized to skip the redundant op when already on the target device.

  • Comment 2 (duplicated check_sm120_or_throw): Extracted to a shared header lightx2v_kernel/csrc/gemm/sm120_utils.h under namespace lightx2v_kernel. Both kernel files now #include it and call lightx2v_kernel::check_sm120_or_throw. Standardized on cudaGetDeviceProperties for consistency.

  • Comment 3 (std::cerrTORCH_CHECK): Replaced three switch default cases (quantize_to_fp8, gelu_quantize_to_fp8, modulate_quantize_to_fp8) with single-line TORCH_CHECK(false, ...).

  • Comment 4 (BF16 intermediate round): Removed the __nv_bfloat16 rounded_product step in mxfp8_residual_gate_bf16_kernel; now keeps the product in fp32 until the final round, improving numerical precision.

  • ⏸️ Comment 5 (2D fallback dynamic allocation): Deferred to a follow-up PR. The production Wan inference path always passes a 1D gate (c_gate_msa.squeeze() is shape (hidden,)), so it never triggers the 2D fallback. Refactoring the fallback to use a pre-allocated workspace is best done in a separate PR alongside other potential 2D path optimizations.

Verification:

  • pre-commit run --all-files: all hooks pass
  • pytest lightx2v_kernel/test/mxfp8_mxfp8/test_fused_ffn.py: 10/10 pass
  • Net diff: +22 / -98 lines (deduplication shrinks the patch).

@gemini-code-assist

Copy link
Copy Markdown
Contributor

Thanks for the detailed update and for addressing the review comments. The changes, including the extraction of the SM120 check to a shared header and the replacement of std::cerr with TORCH_CHECK, are well-implemented. I also appreciate the verification steps you've taken. The PR looks good to go from my perspective.

@STwangyingrui

Copy link
Copy Markdown
Contributor

Thanks for this well-executed PR! The SM120 MXFP8 fusion work is a meaningful performance win, and the kernel tests give a solid baseline to review against. Below are the consolidated review comments. Please rebase, fix c_gate_msa call coverage (P0), and address the P1 items below where practical. Happy to re-review after rebase and the P0 fixes.

1. Merge & CI (P0)

  • Currently not mergeable against main; please rebase.
  • After rebase, please run lightx2v_kernel/test/mxfp8_mxfp8/test_fused_ffn.py on SM120.

2. Incomplete call-site coverage (P0)

The fused FFN path requires c_gate_msa to be passed into infer_ffn, but the following paths still use the old signature and do not pass gate:

  • wan/infer/offload/transformer_infer.py
  • wan/infer/feature_caching/transformer_infer.py
  • wan/infer/lingbot/transformer_infer.py
  • wan/infer/audio/transformer_infer.py

With dit_quant_scheme=mxfp8 on SM120, _ensure_mxfp8_quant_ffn_ready will RuntimeError when c_gate_msa is None, rather than gracefully falling back.

3. scaled_mxfp8_modulate_quant: universality & API placement (P1)

This op is not a generic modulate fusion; it is a Wan AdaLN subset:

  • Fixed formula: x * (1 + scale) + shift
  • Shapes limited to 1D (N,) or 2D (M, N)
  • Does not support 4D per-frame modulate, smooth_norm paths, etc.

It is currently registered in the generic lightx2v_kernel/gemm.py / common_extension.cc surface, which may mislead future maintainers into thinking other models can reuse it directly. Suggestion: add a detailed docstring on scaled_mxfp8_modulate_quant in gemm.py documenting the supported modes and constraints.

4. Numerical equivalence & config toggles (P1)

Because modulate_quant and gemm_residual_gate are not strictly numerically equivalent to the unfused path, please add a config switch such as mxfp8_fuse_enable (default true) to make regression comparison and debugging easier.

5. Guards should not live long-term on the Wan model layer (P1)

This PR centralizes probe* / can_use* / mxfp8_apply* in transformer_infer, with guards split across self-attn and FFN. That makes it harder to:

  • Reuse the same kernels from other models
  • Add similar fusion for other quant schemes (e.g. int8-q8f)

A full long-term refactor is out of scope for now. If the diff stays manageable, this PR should at least extract Wan-side mxfp8 fuse logic into a dedicated module (e.g. wan/infer/mxfp8_fuse.py) to avoid further bloat of transformer_infer.py.

6. Minor items

  • PR description vs code: The description says "Handles FP16/BF16 activations (kernel auto-detects dtype)", but _can_use_mxfp8_modulate_quant hard-codes:

    if norm2_out.dtype != torch.bfloat16 or c_scale_msa.dtype != torch.bfloat16 or c_shift_msa.dtype != torch.bfloat16:
        return False

    Please align the description with the actual Python-side gating.

  • PR description scope: "Supports all Wan tasks" is optimistic. Paths like self_forcing override infer_ffn — they won't crash, but they won't get fused benefits either. Please clarify coverage in the PR description.

  • mxfp8_residual_gate_bf16_kernel: No vectorization; bandwidth utilization will be significantly lower than a vectorized version. As a 2D gate fallback (Wan production uses 1D gate), this is not a merge blocker. Please add a TODO comment:

    // TODO: use vectorized loads for better memory bandwidth
  • clean_cuda_cache on fused FFN path: When self.clean_cuda_cache is True, the fused path does not perform the same cleanup as the unfused path (del norm2_out, x; torch_device_module.empty_cache()). Please add explicit release of intermediate tensors in the fused path (e.g. norm2_quant, y_quant, etc.) to match clean_cuda_cache semantics.

  • _infer_ffn_with_mxfp8_quant_fuse returns None: The semantics are non-obvious. Please add a clear docstring explaining that it returns None because residual is updated in-place (via the fused residual+gate kernel), so post_process should skip the add when y is None.

Fatemanx and others added 3 commits June 9, 2026 00:25
Implement three fused CUDA kernels for MXFP8 quantized inference on Blackwell (SM120):

1. scaled_mxfp8_gelu_quant: fuse GELU activation + E8M0 quantization
2. scaled_mxfp8_modulate_quant: fuse scale/shift modulation + quantization
3. cutlass_scaled_mxfp8_mm_residual_gate: fuse GEMM + residual + gate in CUTLASS 3.x epilogue

Performance on RTX 5090 (Wan 5B FFN, m=4096, hidden=1536, ffn=8960):
- GELU+Quant: 1.30× faster (27.8μs → 21.3μs)
- Modulate+Quant: 3.26× faster (92.7μs → 28.5μs)
- GEMM+Residual+Gate: 1.40× faster (194.7μs → 138.9μs)
- End-to-end FFN: 1.20× faster (608μs → 505μs, -103μs per block)
- Reduces kernel launches from 7 to 3 per FFN block

Features:
- Supports all Wan tasks (t2v/i2v/flf2v/animate/s2v/rs2v)
- Auto-fallback on non-SM120 GPUs (H100/A100/RTX4090) with warning
- Handles FP16/BF16 activations (kernel auto-detects dtype)
- One-time device capability probe at init (eliminates ~4000 redundant checks per inference)

Tested: 10/10 unit tests pass, 6/6 fallback scenarios verified

Address review feedback (PR ModelTC#1090):
- Skip alpha device move when already on target device
- Extract check_sm120_or_throw to shared header sm120_utils.h
- Replace std::cerr with TORCH_CHECK in dtype switch fallbacks
- Avoid intermediate BF16 round in residual_gate kernel
- Apply ruff-format

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Fatemanx Fatemanx force-pushed the perf/wan-mxfp8-fuse-op branch from 94816cc to 2bfccad Compare June 8, 2026 16:55
@Fatemanx

Fatemanx commented Jun 8, 2026

Copy link
Copy Markdown
Contributor Author

Thanks for this well-executed PR! The SM120 MXFP8 fusion work is a meaningful performance win, and the kernel tests give a solid baseline to review against. Below are the consolidated review comments. Please rebase, fix c_gate_msa call coverage (P0), and address the P1 items below where practical. Happy to re-review after rebase and the P0 fixes.

1. Merge & CI (P0)

  • Currently not mergeable against main; please rebase.
  • After rebase, please run lightx2v_kernel/test/mxfp8_mxfp8/test_fused_ffn.py on SM120.

2. Incomplete call-site coverage (P0)

The fused FFN path requires c_gate_msa to be passed into infer_ffn, but the following paths still use the old signature and do not pass gate:

  • wan/infer/offload/transformer_infer.py
  • wan/infer/feature_caching/transformer_infer.py
  • wan/infer/lingbot/transformer_infer.py
  • wan/infer/audio/transformer_infer.py

With dit_quant_scheme=mxfp8 on SM120, _ensure_mxfp8_quant_ffn_ready will RuntimeError when c_gate_msa is None, rather than gracefully falling back.

3. scaled_mxfp8_modulate_quant: universality & API placement (P1)

This op is not a generic modulate fusion; it is a Wan AdaLN subset:

  • Fixed formula: x * (1 + scale) + shift
  • Shapes limited to 1D (N,) or 2D (M, N)
  • Does not support 4D per-frame modulate, smooth_norm paths, etc.

It is currently registered in the generic lightx2v_kernel/gemm.py / common_extension.cc surface, which may mislead future maintainers into thinking other models can reuse it directly. Suggestion: add a detailed docstring on scaled_mxfp8_modulate_quant in gemm.py documenting the supported modes and constraints.

4. Numerical equivalence & config toggles (P1)

Because modulate_quant and gemm_residual_gate are not strictly numerically equivalent to the unfused path, please add a config switch such as mxfp8_fuse_enable (default true) to make regression comparison and debugging easier.

5. Guards should not live long-term on the Wan model layer (P1)

This PR centralizes probe* / can_use* / mxfp8_apply* in transformer_infer, with guards split across self-attn and FFN. That makes it harder to:

  • Reuse the same kernels from other models
  • Add similar fusion for other quant schemes (e.g. int8-q8f)

A full long-term refactor is out of scope for now. If the diff stays manageable, this PR should at least extract Wan-side mxfp8 fuse logic into a dedicated module (e.g. wan/infer/mxfp8_fuse.py) to avoid further bloat of transformer_infer.py.

6. Minor items

  • PR description vs code: The description says "Handles FP16/BF16 activations (kernel auto-detects dtype)", but _can_use_mxfp8_modulate_quant hard-codes:

    if norm2_out.dtype != torch.bfloat16 or c_scale_msa.dtype != torch.bfloat16 or c_shift_msa.dtype != torch.bfloat16:
        return False

    Please align the description with the actual Python-side gating.

  • PR description scope: "Supports all Wan tasks" is optimistic. Paths like self_forcing override infer_ffn — they won't crash, but they won't get fused benefits either. Please clarify coverage in the PR description.

  • mxfp8_residual_gate_bf16_kernel: No vectorization; bandwidth utilization will be significantly lower than a vectorized version. As a 2D gate fallback (Wan production uses 1D gate), this is not a merge blocker. Please add a TODO comment:

    // TODO: use vectorized loads for better memory bandwidth
  • clean_cuda_cache on fused FFN path: When self.clean_cuda_cache is True, the fused path does not perform the same cleanup as the unfused path (del norm2_out, x; torch_device_module.empty_cache()). Please add explicit release of intermediate tensors in the fused path (e.g. norm2_quant, y_quant, etc.) to match clean_cuda_cache semantics.

  • _infer_ffn_with_mxfp8_quant_fuse returns None: The semantics are non-obvious. Please add a clear docstring explaining that it returns None because residual is updated in-place (via the fused residual+gate kernel), so post_process should skip the add when y is None.

Thanks for the review. I rebased onto latest main and addressed the P0/P1 items: fixed c_gate_msa forwarding across the Wan paths
including audio, added mxfp8_fuse_enable, documented the Wan-specific scope of scaled_mxfp8_modulate_quant, extracted Wan MXFP8
helper logic into wan/infer/mxfp8_fuse.py, updated fused FFN cleanup/docs, and clarified the PR description.

Validation:

  • python lightx2v_kernel/test/mxfp8_mxfp8/test_fused_ffn.py -> 14 tests OK
  • python -m unittest -q test_cases.test_wan_mxfp8_fuse_forwarding -> 6 tests OK
  • python -m pre_commit run --all-files -> passed

The GitHub lint check is also passing now.

@STwangyingrui

Copy link
Copy Markdown
Contributor

Thanks for the thorough follow-up! No further comments from my side.

@STwangyingrui STwangyingrui merged commit 392417f into ModelTC:main Jun 9, 2026
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants