[PyTorch] Relax dimension constraints for using fused grouped MLP #2856
ksivaman wants to merge 5 commits into NVIDIA:main
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Greptile Summary

This PR relaxes the fused grouped MLP dimension gate from `% 256` to `% 64`.

Confidence Score: 5/5

Safe to merge; all remaining findings are P2 style/coverage suggestions that do not affect correctness. The core logic changes, the gate relaxation from `% 256` to `% 64` and the ceiling-division SF views, are correct and consistent across the forward and backward paths. Only P2 issues remain: an assert that should be a RuntimeError, stale comments, and a test-coverage gap for hidden sizes that are not multiples of 128.

Important Files Changed

- forward_grouped_mlp.py (assert → RuntimeError, stale comments)
- tests/pytorch/test_fusible_ops.py (consider adding hidden_size=64)
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["Input tensor\nin_shape = M x K"] --> B{"M % 128 == 0?"}
    B -- No --> ERR["RuntimeError\n(assert currently)"]
    B -- Yes --> C{"Weight dims\n% 64 == 0?\n(was % 256)"}
    C -- No --> SKIP["Skip fused path\n(fallback to unfused)"]
    C -- Yes --> D["Quantize inputs/weights\n(MXFP8 group quantize)"]
    D --> E["Build SF view\nceil(dim/128) blocks\n(was dim//128)"]
    E --> F["Forward GEMM + SwiGLU kernel"]
    F --> G["FC2 GEMM kernel"]
    G --> H["Output tensor"]
    H --> BWD["Backward pass\nSame ceiling-div SF views\nfor fc2_dy, fc1_w, fc2_w"]
```
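The SF-view change in the flowchart (ceil(dim/128) blocks, previously dim//128) can be illustrated with a small sketch. The helper name `num_sf_blocks` and the block size constant are assumptions for illustration; the 128-element scale-factor block along this dimension is taken from the diagram, not from the kernel source.

```python
SF_BLOCK = 128  # assumed MXFP8 scale-factor block size along this dimension

def num_sf_blocks(dim: int) -> int:
    """Number of scale-factor blocks covering `dim` elements.

    Ceiling division, computed without floating point. With floor
    division (dim // SF_BLOCK), a dim that is a multiple of 64 but not
    of 128 would lose its final partial block; ceiling division keeps it.
    """
    return (dim + SF_BLOCK - 1) // SF_BLOCK

# dim = 192 is a multiple of 64 but not of 128:
# floor division would give 192 // 128 == 1 block, leaving 64 elements
# uncovered, while ceiling division yields 2 blocks for full coverage.
```

This is why relaxing the weight-dimension gate to `% 64` requires the ceiling-division views: once dims need not be multiples of 128, the last SF block can be partial.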
Reviews (4): Last reviewed commit: "reset rng as before, assert input dim"
/te-ci pytorch L0
Description
Previously we required all weight dimensions to be multiples of 256; this PR relaxes the constraint to multiples of 64, which is the actual constraint imposed by the kernel.
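A minimal sketch of the relaxed gate described above. The function name `fused_grouped_mlp_gate` is hypothetical; the `M % 128` input check (raising a RuntimeError, per the review feedback) and the `% 64` weight check follow the behavior described in this PR.

```python
from typing import Sequence

def fused_grouped_mlp_gate(m: int, weight_dims: Sequence[int]) -> bool:
    """Decide whether the fused grouped MLP path can be used.

    Hypothetical sketch: the input dim M must be a multiple of 128
    (hard error), while weight dims only need to be multiples of 64
    (previously 256); otherwise we fall back to the unfused path.
    """
    if m % 128 != 0:
        raise RuntimeError(f"Input dim M ({m}) must be a multiple of 128")
    return all(d % 64 == 0 for d in weight_dims)
```

Under the old `% 256` gate, a hidden size such as 192 would have fallen back to the unfused path even though the kernel itself only requires multiples of 64.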
Type of change
Changes
Checklist: