
[PyTorch] Relax dimension constraints for using fused grouped MLP#2856

Open
ksivaman wants to merge 5 commits into NVIDIA:main from ksivaman:relax_dim_constraint_for_fused_grouped_mlp

Conversation


@ksivaman ksivaman commented Apr 8, 2026

Description

Previously we required all weight dims to be multiples of 256. This PR changes the requirement to multiples of 64, which is the actual constraint from the kernel.
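To make the relaxed gate concrete, here is a minimal sketch (the function name and structure are hypothetical, not the actual Transformer Engine code, which lives in transformer_engine/pytorch/ops/_common.py):

```python
# Hypothetical sketch of the relaxed dimension gate; the real check
# in _common.py differs in detail.
KERNEL_DIM_ALIGNMENT = 64  # was 256 before this PR

def can_use_fused_grouped_mlp(weight_shapes) -> bool:
    """Return True if every weight dimension is a multiple of 64."""
    return all(
        dim % KERNEL_DIM_ALIGNMENT == 0
        for shape in weight_shapes
        for dim in shape
    )

# Shapes like (192, 64) now qualify; under the old % 256 rule they did not.
assert can_use_fused_grouped_mlp([(192, 64), (256, 128)])
assert not can_use_fused_grouped_mlp([(256, 100)])  # 100 % 64 != 0
```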

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Enable fused grouped MLP path for 64 modulo dims as well.
  • Fix shapes when creating SF views to pass to cutedsl.
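The SF-view fix concerns how many 128-element scale blocks cover a dimension when building the views passed to cutedsl. A minimal illustration of the floor-vs-ceiling difference (the helper name is hypothetical):

```python
def num_sf_blocks(dim: int) -> int:
    """Number of 128-element scale blocks covering `dim`, rounding up."""
    return (dim + 127) // 128  # ceiling division

# Floor division breaks for dims that are multiples of 64 but not 128:
assert 64 // 128 == 0          # old floor division: zero-sized SF view
assert num_sf_blocks(64) == 1  # ceiling division: one (padded) block
assert num_sf_blocks(128) == 1
assert num_sf_blocks(192) == 2
```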

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman requested a review from timmoon10 April 8, 2026 20:07

greptile-apps bot commented Apr 8, 2026

Greptile Summary

This PR relaxes the fused grouped MLP dimension gate from % 256 to % 64 (the actual kernel constraint) and fixes the MXFP8 scale-factor view reshapes to use ceiling division ((dim + 127) // 128) instead of floor division; floor division would produce a zero-sized dimension for any weight size that is a multiple of 64 but not of 128. A new assertion guards that the token-count dimension (in_shape[0]) remains a multiple of 128, which is a separate kernel requirement.

Confidence Score: 5/5

Safe to merge; all remaining findings are P2 style/coverage suggestions that do not affect correctness.

The core logic changes — gate relaxation from % 256 to % 64 and ceiling-division SF views — are correct and consistent across forward and backward paths. Only P2 issues remain: an assert that should be a RuntimeError, stale comments, and a test-coverage gap for non-128-multiple hidden sizes.

forward_grouped_mlp.py (assert → RuntimeError, stale comments); tests/pytorch/test_fusible_ops.py (consider adding hidden_size=64)
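The assert-to-RuntimeError suggestion follows the usual pattern for user-facing input validation, since asserts are stripped when Python runs with -O. A hedged sketch (function name and message are illustrative, not the PR's actual code):

```python
def check_token_count(num_tokens: int) -> None:
    """Raise instead of assert so the check survives `python -O`."""
    if num_tokens % 128 != 0:
        raise RuntimeError(
            "Fused grouped MLP requires the token dimension to be a "
            f"multiple of 128, got {num_tokens}"
        )

check_token_count(256)  # OK: 256 % 128 == 0
```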

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/_common.py Relaxes dimension gating from % 256 to % 64 in both validate_grouped_mlp_dims and fuse_grouped_mlp_ops; logic is consistent across both functions.
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Adds token-count assertion (% 128) and fixes all SF-view reshapes from floor to ceiling division; assert should be a RuntimeError, and scale-shape comments are stale.
transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py Applies the same ceiling-division fix to fc2_dy, fc2_w, and fc1_w scale views; changes are mechanically symmetric with the forward pass.
tests/pytorch/test_fusible_ops.py Replaces per-class setup_class hooks with a module-level autouse fixture, adds hidden_size parametrization (128, 256); both values are multiples of 128 so the ceiling-division fix is not directly exercised.
tests/pytorch/utils.py Extends reset_rng_states to also save/restore Python random module state, making RNG reproducibility more complete.
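The reset_rng_states extension can be illustrated with the Python random module alone (the real helper in tests/pytorch/utils.py also handles the torch CPU and CUDA generator states):

```python
import random

# Capture the Python RNG state once, then restore it on each reset so
# every test starts from the same sequence of draws.
_PY_RNG_STATE = random.getstate()

def reset_rng_states() -> None:
    random.setstate(_PY_RNG_STATE)

random.seed()            # perturb the RNG, as a test body might
reset_rng_states()
first = random.random()
reset_rng_states()
second = random.random()
assert first == second   # restoring the state reproduces the same draw
```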

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["Input tensor\nin_shape = M x K"] --> B{"M % 128 == 0?"}
    B -- No --> ERR["RuntimeError\n(assert currently)"]
    B -- Yes --> C{"Weight dims\n% 64 == 0?\n(was % 256)"}
    C -- No --> SKIP["Skip fused path\n(fallback to unfused)"]
    C -- Yes --> D["Quantize inputs/weights\n(MXFP8 group quantize)"]
    D --> E["Build SF view\nceil(dim/128) blocks\n(was dim//128)"]
    E --> F["Forward GEMM + SwiGLU kernel"]
    F --> G["FC2 GEMM kernel"]
    G --> H["Output tensor"]
    H --> BWD["Backward pass\nSame ceiling-div SF views\nfor fc2_dy, fc1_w, fc2_w"]

Reviews (4). Last reviewed commit: "reset rng as before, assert input dim"

timmoon10 previously approved these changes Apr 8, 2026

@timmoon10 timmoon10 left a comment


LGTM


ksivaman commented Apr 8, 2026

/te-ci pytorch L0

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

ksivaman commented Apr 9, 2026

/te-ci pytorch L0


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
