
[PyTorch] Update cuBLASLt grouped gemm filter #3119

Open
yaox12 wants to merge 2 commits into NVIDIA:main from yaox12:xiny/update_cublaslt_grouped_gemm_filter

yaox12 commented Jun 11, 2026

Member

Description

  • Updated grouped-tensor path filters in module and ops grouped linear to match cuBLASLt grouped GEMM support:
    • Hopper SM90: BF16
    • Blackwell SM10x/SM110: BF16/MXFP8/NVFP4.
    • Quantized fused path remains Blackwell-only because there are no grouped quantization kernels for FP8 blockwise scaling or FP8 delayed/current scaling.
  • Added NVFP4 support to the fused grouped tensor path filter.
  • Limited bgrad_group_quantize usage to MXFP8, since that helper only supports MXFP8; NVFP4+bias now falls back to separate grouped bias-grad computation.
  • Updated C++ grouped GEMM architecture checks and comments to cap support at SM110, so SM120 is excluded, which matches the cuBLAS doc.
  • Extended test_grouped_linear_fused_path_cuda_graph_safe:
    • Adds NVFP4 with disable_stochastic_rounding=True.
    • Allows BF16 to run on Hopper.
    • Adds cuBLASLt version skips for Hopper/Blackwell requirements.
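The support matrix listed above can be written down as a small lookup. This is a minimal sketch for illustration only: the names `SUPPORTED_FORMATS`, `arch_family`, and `grouped_gemm_supported` are hypothetical and do not come from Transformer Engine, and the format strings simply transcribe the bullets in this description.

```python
from typing import Tuple

# Formats per architecture family, transcribed from the PR description.
# The family names and this table layout are assumptions made for illustration.
SUPPORTED_FORMATS = {
    "hopper": {"bf16"},
    "blackwell": {"bf16", "mxfp8", "nvfp4"},
}


def arch_family(cc: Tuple[int, int]) -> str:
    """Map a (major, minor) compute capability to a family name, or '' if unsupported."""
    if cc == (9, 0):
        return "hopper"
    if (10, 0) <= cc <= (11, 0):
        return "blackwell"
    return ""  # anything above the SM110 cap (e.g. SM120) or below SM90


def grouped_gemm_supported(cc: Tuple[int, int], fmt: str) -> bool:
    """Return True if `fmt` is listed for the architecture family of `cc`."""
    return fmt in SUPPORTED_FORMATS.get(arch_family(cc), set())
```

For example, `grouped_gemm_supported((9, 0), "mxfp8")` is `False` because the quantized path is Blackwell-only, and any format on `(12, 0)` is rejected because of the SM110 cap.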

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Xin Yao <xiny@nvidia.com>

yaox12 commented Jun 11, 2026

Member Author

/te-ci pytorch


greptile-apps Bot commented Jun 11, 2026

Contributor

Greptile Summary

This PR extends cuBLASLt grouped GEMM support in both the module and ops GroupedLinear paths: the compute-capability filter is widened to include Hopper (SM90) for BF16/FP16 and Blackwell (SM10x/SM110) for MXFP8/NVFP4, while SM120+ is explicitly excluded to match the cuBLAS docs. bgrad_group_quantize usage is correctly restricted to MXFP8 only, and NVFP4+bias falls back to a separate compute_grouped_dbias call.

  • Filter refactor: Both _is_grouped_tensor_path_supported / _is_graph_safe_path_supported now use (9,0) <= CC <= (11,0) guards; weight/grad-output quantizer type checks were removed from the filter and are instead documented as a precondition assumed by the caller.
  • NVFP4 support added: NVFP4 quantizers are now accepted alongside MXFP8 in the quantized grouped-tensor dispatch; the backward bias-grad path correctly falls back to compute_grouped_dbias for NVFP4.
  • Tests updated: test_grouped_linear_fused_path_cuda_graph_safe gains an NVFP4 parametrize, Hopper-capable BF16 run, and version-based skip guards for cuBLAS 13.3+/13.4+.

Confidence Score: 4/5

The functional changes are well-scoped and self-consistent; the only findings are stale doc comments that do not affect runtime behavior.

The SM-range guards align across the C++, Python module, and ops layers; the NVFP4 backward bias fallback to compute_grouped_dbias is correct because dy_2d remains unquantized at that point. Two stale comments in ops/basic/grouped_linear.py do not affect execution.

transformer_engine/pytorch/ops/basic/grouped_linear.py has two stale comments; all other files look correct.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/gemm/cublaslt_grouped_gemm.cu | Updated SM range checks to cap at SM110 in grouped GEMM requirements and per-group alpha/beta support; updated error messages and comments accordingly. |
| transformer_engine/pytorch/module/grouped_linear.py | Extended grouped-tensor path filter to support Hopper (SM90) BF16/FP16 and NVFP4 on Blackwell; removed weight/grad_output quantizer checks from the filter; bgrad_group_quantize usage correctly restricted to MXFP8. |
| transformer_engine/pytorch/ops/basic/grouped_linear.py | Mirrors filter changes from module: Hopper BF16/FP16 allowed, NVFP4 added for Blackwell; bgrad_group_quantize gated to MXFP8; two comments are stale now that NVFP4 uses the same code paths. |
| tests/pytorch/test_grouped_linear.py | Added NVFP4 parametrize to cuda-graph-safe test; extended device-capability guard to include Hopper; added cuBLASLt version skip conditions for Hopper (13.4+) and Blackwell (13.3+). |
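The skip conditions described for the test file can be sketched as a pure function that returns a skip message or `None`. This is only an assumption about the shape of the logic: `skip_reason` is a hypothetical helper, the cuBLASLt version is passed in as a `(major, minor)` tuple because the way the real test obtains it is not shown here, and the 13.4 (Hopper) and 13.3 (Blackwell) thresholds are taken from the summary above.

```python
from typing import Optional, Tuple


def skip_reason(cc: Tuple[int, int], cublaslt_version: Tuple[int, int]) -> Optional[str]:
    """Return a skip message if the grouped GEMM test cannot run, else None.

    cc: (major, minor) compute capability.
    cublaslt_version: (major, minor) cuBLASLt version, supplied by the caller.
    """
    if not ((9, 0) <= cc <= (11, 0)):
        return (
            "GroupedTensor grouped GEMM path requires Hopper (SM90) "
            "or Blackwell (SM10x and SM110)."
        )
    # Inside the supported range, anything below (10, 0) is treated as Hopper.
    required = (13, 4) if cc < (10, 0) else (13, 3)
    if cublaslt_version < required:
        return (
            f"cuBLASLt {required[0]}.{required[1]}+ is required, "
            f"got {cublaslt_version[0]}.{cublaslt_version[1]}."
        )
    return None
```

In a pytest test the returned string could be passed to `pytest.skip(...)` whenever it is not `None`.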

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[GroupedLinear forward/backward] --> B{Env var FUSED_GROUPED_GEMM?}
    B -- No --> LEGACY[Legacy path]
    B -- Yes --> C{Advanced features?}
    C -- Yes --> LEGACY
    C -- No --> D{CC in 9.0 to 11.0?}
    D -- No --> LEGACY
    D -- Yes --> E{output_quantizers set?}
    E -- Yes --> LEGACY
    E -- No --> F{fp8?}
    F -- No --> G{BF16 or FP16?}
    G -- Yes --> GROUPED[GroupedTensor cuBLASLt GEMM]
    G -- No --> LEGACY
    F -- Yes --> H{CC in 10.0 to 11.0 Blackwell?}
    H -- No --> LEGACY
    H -- Yes --> I{all MXFP8 or all NVFP4?}
    I -- Yes --> GROUPED
    I -- No --> LEGACY
    GROUPED --> J{Backward with bias?}
    J -- MXFP8 --> K[bgrad_group_quantize fused]
    J -- NVFP4/BF16/FP16 --> L[group_quantize + compute_grouped_dbias]
    J -- No bias --> M[group_quantize only]
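The forward dispatch in the flowchart can be sketched as a plain Python function. This is a simplified illustration, not the code in `grouped_linear.py`: the `Request` dataclass, its field names, and the recipe strings are hypothetical, and "all MXFP8 or all NVFP4" is collapsed into a single `recipe` value.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Request:
    fused_env_enabled: bool         # the FUSED_GROUPED_GEMM env-var check
    advanced_features: bool         # any feature that forces the legacy path
    cc: Tuple[int, int]             # (major, minor) compute capability
    has_output_quantizers: bool
    dtype: str                      # e.g. "bf16", "fp16", "fp32"
    recipe: Optional[str] = None    # None means no fp8; else e.g. "mxfp8", "nvfp4"


def choose_path(req: Request) -> str:
    """Return "grouped" or "legacy", following the flowchart top to bottom."""
    if not req.fused_env_enabled or req.advanced_features:
        return "legacy"
    if not ((9, 0) <= req.cc <= (11, 0)):
        return "legacy"
    if req.has_output_quantizers:
        return "legacy"
    if req.recipe is None:
        # Unquantized compute: only BF16 and FP16 take the grouped path.
        return "grouped" if req.dtype in ("bf16", "fp16") else "legacy"
    # Quantized compute is Blackwell-only.
    if not ((10, 0) <= req.cc <= (11, 0)):
        return "legacy"
    return "grouped" if req.recipe in ("mxfp8", "nvfp4") else "legacy"
```

Because Python compares tuples lexicographically, `(12, 0) <= (11, 0)` is false, so SM120 falls through to the legacy path without a separate check.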

Comments Outside Diff (2)

  1. transformer_engine/pytorch/ops/basic/grouped_linear.py, line 789-799 (link)

    P2 The docstring for _get_grouped_weight_for_gemm still lists only MXFP8/BF16/FP16 compute paths, but NVFP4 is now a supported quantized-compute path that routes through the same tex.group_quantize branch.

  2. transformer_engine/pytorch/ops/basic/grouped_linear.py, line 1632-1634 (link)

    P2 The comment "BF16/FP16 path" is now misleading: this branch is also reached when NVFP4 + bias is active, because bgrad_group_quantize is restricted to MXFP8Quantizer and dbias_packed is left as None for NVFP4.
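The branch selection described in this comment, and what a separate grouped bias-grad computation produces, can be sketched in pure Python. Both functions are hypothetical illustrations: `use_fused_bgrad` only mirrors the MXFP8-only gating stated in this PR, and `grouped_dbias_reference` is a reference for the usual definition of a bias gradient (the per-group column sum of the grad-output rows). It is not the implementation of `compute_grouped_dbias`.

```python
from typing import List, Optional, Sequence


def use_fused_bgrad(recipe: Optional[str], has_bias: bool) -> bool:
    """Fused bias-grad + quantize is taken only for MXFP8 with a bias."""
    return has_bias and recipe == "mxfp8"


def grouped_dbias_reference(
    dy: Sequence[Sequence[float]], split_sizes: Sequence[int]
) -> List[List[float]]:
    """Sum grad-output rows per group.

    dy: 2D grad output, one row per token, rows of all groups concatenated.
    split_sizes: number of rows that belong to each group, in order.
    """
    if sum(split_sizes) != len(dy):
        raise ValueError("split_sizes must add up to the number of rows in dy")
    n_cols = len(dy[0]) if dy else 0
    out: List[List[float]] = []
    start = 0
    for size in split_sizes:
        rows = dy[start:start + size]
        # Column-wise sum over the rows of this group.
        out.append([sum(row[c] for row in rows) for c in range(n_cols)])
        start += size
    return out
```

With this gating, NVFP4 with bias and BF16/FP16 with bias both land on the separate bias-grad computation, which is why a comment that names only BF16/FP16 no longer describes every caller of that branch.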


vthumbe1503 left a comment

Collaborator

LGTM. I think we should also add the NVFP4 recipe to the test that checks the correctness of grouped_linear with the grouped tensor path.

Comment on lines +1706 to +1709
    device_capability = torch.cuda.get_device_capability()
    if not ((9, 0) <= device_capability <= (11, 0)):
        pytest.skip(
            "GroupedTensor grouped GEMM path requires Hopper (SM90) or Blackwell (SM10x and SM110)."

Collaborator

Given that this test only verifies CUDA graph safety, it might also make sense to add the NVFP4 recipe to test_grouped_linear_grouped_tensor_path_matches_legacy.
