[PyTorch] Update cuBLASLt grouped gemm filter #3119
Signed-off-by: Xin Yao <xiny@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci pytorch
Greptile Summary
This PR extends cuBLASLt grouped GEMM support in both the module and ops
Confidence Score: 4/5
The functional changes are well-scoped and self-consistent; the only findings are stale doc comments that do not affect runtime behavior.
The SM-range guards align across the C++, Python module, and ops layers; the NVFP4 backward bias fallback to compute_grouped_dbias is correct because dy_2d remains unquantized at that point. Two stale comments in ops/basic/grouped_linear.py do not affect execution.
transformer_engine/pytorch/ops/basic/grouped_linear.py has two stale comments; all other files look correct.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[GroupedLinear forward/backward] --> B{Env var FUSED_GROUPED_GEMM?}
B -- No --> LEGACY[Legacy path]
B -- Yes --> C{Advanced features?}
C -- Yes --> LEGACY
C -- No --> D{CC in 9.0 to 11.0?}
D -- No --> LEGACY
D -- Yes --> E{output_quantizers set?}
E -- Yes --> LEGACY
E -- No --> F{fp8?}
F -- No --> G{BF16 or FP16?}
G -- Yes --> GROUPED[GroupedTensor cuBLASLt GEMM]
G -- No --> LEGACY
F -- Yes --> H{CC in 10.0 to 11.0 Blackwell?}
H -- No --> LEGACY
H -- Yes --> I{all MXFP8 or all NVFP4?}
I -- Yes --> GROUPED
I -- No --> LEGACY
GROUPED --> J{Backward with bias?}
J -- MXFP8 --> K[bgrad_group_quantize fused]
J -- NVFP4/BF16/FP16 --> L[group_quantize + compute_grouped_dbias]
J -- No bias --> M[group_quantize only]
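The decision tree above can be read as a small dispatch function. The following is a minimal sketch for illustration only, not the code in this PR: `GemmRequest`, `select_gemm_path`, `select_backward_quantize`, and the string labels are hypothetical names, and a single `recipe` string stands in for the "all MXFP8 or all NVFP4" check on the quantizers.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class GemmRequest:
    """Hypothetical bundle of the inputs the flowchart branches on."""
    fused_env_enabled: bool          # env var FUSED_GROUPED_GEMM set
    advanced_features: bool          # any feature the grouped path does not cover
    capability: Tuple[int, int]      # (major, minor) compute capability
    has_output_quantizers: bool
    fp8: bool
    dtype: str = "bf16"              # only consulted when fp8 is False
    recipe: Optional[str] = None     # "mxfp8", "nvfp4", or None (assumed uniform)


def select_gemm_path(req: GemmRequest) -> str:
    """Return "grouped" or "legacy", following the flowchart top to bottom."""
    if not req.fused_env_enabled or req.advanced_features:
        return "legacy"
    if not ((9, 0) <= req.capability <= (11, 0)):
        return "legacy"
    if req.has_output_quantizers:
        return "legacy"
    if not req.fp8:
        # High-precision path: only BF16 and FP16 are eligible.
        return "grouped" if req.dtype in ("bf16", "fp16") else "legacy"
    # Quantized path: narrower capability range, uniform MXFP8 or NVFP4 only.
    if not ((10, 0) <= req.capability <= (11, 0)):
        return "legacy"
    return "grouped" if req.recipe in ("mxfp8", "nvfp4") else "legacy"


def select_backward_quantize(recipe: Optional[str], has_bias: bool) -> str:
    """Pick the backward quantization step once the grouped path is chosen."""
    if not has_bias:
        return "group_quantize only"
    if recipe == "mxfp8":
        return "bgrad_group_quantize fused"
    # NVFP4, BF16 and FP16 compute the bias gradient separately.
    return "group_quantize + compute_grouped_dbias"
```

For example, `select_gemm_path(GemmRequest(True, False, (9, 0), False, True, recipe="mxfp8"))` returns `"legacy"`, because in this reading the quantized branch requires a capability of at least (10, 0).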
vthumbe1503
left a comment
LGTM. I think we should also add the NVFP4 recipe to the test that checks the correctness of grouped_linear with the grouped tensor path.
device_capability = torch.cuda.get_device_capability()
if not ((9, 0) <= device_capability <= (11, 0)):
    pytest.skip(
        "GroupedTensor grouped GEMM path requires Hopper (SM90) or Blackwell (SM10x and SM110)."
Given that this path only verifies CUDA graph safety, it might also make sense to add the NVFP4 recipe to the test test_grouped_linear_grouped_tensor_path_matches_legacy.
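The skip guard quoted in this thread relies on Python's lexicographic tuple comparison. A small sketch of which (major, minor) pairs fall inside the range follows; `in_grouped_gemm_range` is a hypothetical helper name, and the sample capabilities are illustrative values rather than a list of supported devices.

```python
def in_grouped_gemm_range(capability):
    """Same comparison as the skip guard: tuples compare element by element."""
    return (9, 0) <= capability <= (11, 0)


# Illustrative (major, minor) pairs, not a list of supported devices.
samples = [(8, 9), (9, 0), (10, 3), (11, 0), (11, 1), (12, 0)]
accepted = [cap for cap in samples if in_grouped_gemm_range(cap)]
```

Because the upper bound is the exact pair (11, 0), any capability with major version 11 and a nonzero minor version falls outside the range, as does any higher major version.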
Description
bgrad_group_quantize usage to MXFP8, since that helper only supports MXFP8; NVFP4+bias now falls back to separate grouped bias-grad computation.
test_grouped_linear_fused_path_cuda_graph_safe: disable_stochastic_rounding=True.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: