Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 9 additions & 21 deletions tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@
_quantization_list.append("nvfp4")


@pytest.fixture(autouse=True, scope="class")
def _reset_rng_states_per_test():
"""Restore torch, CUDA, and Python ``random`` before each test in this module."""
reset_rng_states()
yield


def maybe_skip_quantization(
quantization: Optional[str],
*,
Expand Down Expand Up @@ -363,10 +370,6 @@ def test_extra_tensors(self, size: int = 16) -> None:
class TestFuser:
"""Tests for operation fusion infrastructure"""

@staticmethod
def setup_class(cls) -> None:
reset_rng_states()

@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_scale_update(
self,
Expand Down Expand Up @@ -579,10 +582,6 @@ def test_pyt_autocast(
class TestBasicOps:
"""Tests for individual operations"""

@staticmethod
def setup_class(cls) -> None:
reset_rng_states()

@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@pytest.mark.parametrize("quantization", _quantization_list)
Expand Down Expand Up @@ -2326,10 +2325,6 @@ def test_interleaved_scaled_clamped_qgeglu(self):
class TestFusedOps:
"""Tests for fused operations"""

@staticmethod
def setup_class(cls) -> None:
reset_rng_states()

@pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5)))
@pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
Expand Down Expand Up @@ -3034,10 +3029,6 @@ def test_backward_linear_scale(
class TestCheckpointing:
"""Tests for checkpointing"""

@staticmethod
def setup_class(cls) -> None:
reset_rng_states()

@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantized_weight", (False, True))
def test_linear(
Expand Down Expand Up @@ -3150,10 +3141,6 @@ def test_linear(
class TestSequentialModules:
"""Test for larger Sequentials with modules commonly used together"""

@staticmethod
def setup_class(cls) -> None:
reset_rng_states()

@pytest.mark.parametrize("requires_grad", (False, True))
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("quantized_compute", (False, True))
Expand Down Expand Up @@ -3337,13 +3324,14 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
@pytest.mark.parametrize("accumulate_into_main_grad", (False, True))
@pytest.mark.parametrize("glu_interleave_size", (None, 32))
@pytest.mark.parametrize("delay_wgrad_compute", (False, True))
@pytest.mark.parametrize("hidden_size", (128, 256))
@pytest.mark.parametrize("activation", ("scaled_swiglu", "scaled_clamped_qgeglu"))
def test_grouped_mlp(
self,
*,
group_size: int = 4,
bias: bool,
hidden_size: int = 256,
hidden_size: int,
dtype: torch.dtype,
quantization: Optional[str],
single_grouped_weight: bool,
Expand Down
15 changes: 11 additions & 4 deletions tests/pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import logging
import os
import random
import subprocess
from contextlib import contextmanager
from typing import Optional, Sequence, Tuple, Dict, Any, List
Expand Down Expand Up @@ -173,8 +174,8 @@ def skip_unsupported_backward_override(
pytest.skip(f"{layer_type} does not support NVTE_BACKWARD_OVERRIDE={backward_override}.")


# Cached RNG state
_rng_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
# Cached RNG state (torch CPU, torch CUDA, Python ``random``)
_rng_states: Optional[Tuple[torch.Tensor, torch.Tensor, Any]] = None


def reset_rng_states() -> None:
Expand All @@ -183,11 +184,17 @@ def reset_rng_states() -> None:
if _rng_states is None:
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
_rng_states = (torch.get_rng_state(), torch.cuda.get_rng_state())
random.seed(1234)
_rng_states = (
torch.get_rng_state(),
torch.cuda.get_rng_state(),
random.getstate(),
)
else:
cpu_rng_state, cuda_rng_state = _rng_states
cpu_rng_state, cuda_rng_state, random_state = _rng_states
torch.set_rng_state(cpu_rng_state)
torch.cuda.set_rng_state(cuda_rng_state)
random.setstate(random_state)


def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8):
Expand Down
12 changes: 6 additions & 6 deletions transformer_engine/pytorch/ops/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ def get_fp8_meta_from_fp8_tensor(tensor: Float8Tensor) -> tuple[FP8TensorMeta, i
def validate_grouped_mlp_dims(fc1, glu_op, fc2) -> None:
"""Validate FC1 / scaled GLU / FC2 dimensions for fused grouped MLP."""

if fc1.in_features % 256 != 0 or fc1.out_features % 256 != 0:
if fc1.in_features % 64 != 0 or fc1.out_features % 64 != 0:
raise ValueError(
f"Unsupported dims for FC1 (num_groups={fc1.num_groups}, "
f"in_features={fc1.in_features}, out_features={fc1.out_features})."
)
if fc2.in_features % 256 != 0 or fc2.out_features % 256 != 0:
if fc2.in_features % 64 != 0 or fc2.out_features % 64 != 0:
raise ValueError(
f"Unsupported dims for FC2 (num_groups={fc2.num_groups}, "
f"in_features={fc2.in_features}, out_features={fc2.out_features})."
Expand Down Expand Up @@ -176,10 +176,10 @@ def fuse_grouped_mlp_ops(
elif window[0].num_groups != window[2].num_groups:
matches_pattern = False
elif (
window[0].in_features % 256 != 0
or window[0].out_features % 256 != 0
or window[2].in_features % 256 != 0
or window[2].out_features % 256 != 0
window[0].in_features % 64 != 0
or window[0].out_features % 64 != 0
or window[2].in_features % 64 != 0
or window[2].out_features % 64 != 0
):
matches_pattern = False
elif window[1].glu_interleave_size != 32:
Expand Down
12 changes: 6 additions & 6 deletions transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,8 @@ def fuser_backward(
fc2_dy_scales = fc2_dy_scales.view(dtype=torch.float8_e8m0fnu)
fc2_dy_scales = fc2_dy_scales.view(
1,
out_shape[0] // 128,
out_shape[1] // 128,
(out_shape[0] + 127) // 128,
(out_shape[1] + 127) // 128,
MXFP8_BLOCK_SCALING_SIZE,
4,
4,
Expand Down Expand Up @@ -456,8 +456,8 @@ def fuser_backward(
fc2_w_scales = fc2_weight_for_gemm.columnwise_scale_inv.view(dtype=torch.float8_e8m0fnu)
fc2_w_scales = fc2_w_scales.view(
num_groups,
fc2_weight_shape[1] // 128,
fc2_weight_shape[0] // 128,
(fc2_weight_shape[1] + 127) // 128,
(fc2_weight_shape[0] + 127) // 128,
MXFP8_BLOCK_SCALING_SIZE,
4,
4,
Expand Down Expand Up @@ -587,8 +587,8 @@ def fuser_backward(
)
fc1_w_scales = fc1_w_scales.view(
num_groups,
fc1_weight_shape[1] // 128,
fc1_weight_shape[0] // 128,
(fc1_weight_shape[1] + 127) // 128,
(fc1_weight_shape[0] + 127) // 128,
MXFP8_BLOCK_SCALING_SIZE,
4,
4,
Expand Down
13 changes: 7 additions & 6 deletions transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def fuser_forward(
fc2_weight_shape = (fc2_op.out_features, fc2_op.in_features)
input_ = input_.reshape(-1, fc1_weight_shape[1])
in_shape = list(input_.size())
assert in_shape[0] % 128 == 0, "Unsupported input shape for fused grouped MLP."

num_groups = fc1_op.num_groups
fc1_weight_param = fc1_op.weight if fc1_op.single_grouped_weight else fc1_op.weight0
Expand Down Expand Up @@ -312,8 +313,8 @@ def fuser_forward(
fc1_x_scales = fc1_x_scales.view(dtype=torch.float8_e8m0fnu)
fc1_x_scales = fc1_x_scales.view(
1,
in_shape[0] // 128,
in_shape[1] // 128,
(in_shape[0] + 127) // 128,
(in_shape[1] + 127) // 128,
MXFP8_BLOCK_SCALING_SIZE,
4,
4,
Expand Down Expand Up @@ -361,8 +362,8 @@ def fuser_forward(
fc1_w_scales = fc1_weight_for_gemm.scale_inv.view(dtype=torch.float8_e8m0fnu)
fc1_w_scales = fc1_w_scales.view(
num_groups,
fc1_weight_shape[0] // 128,
fc1_weight_shape[1] // 128,
(fc1_weight_shape[0] + 127) // 128,
(fc1_weight_shape[1] + 127) // 128,
MXFP8_BLOCK_SCALING_SIZE,
4,
4,
Expand Down Expand Up @@ -458,8 +459,8 @@ def fuser_forward(
fc2_w_scales = fc2_weight_for_gemm.scale_inv.view(dtype=torch.float8_e8m0fnu)
fc2_w_scales = fc2_w_scales.view(
num_groups,
fc2_weight_shape[0] // 128,
fc2_weight_shape[1] // 128,
(fc2_weight_shape[0] + 127) // 128,
(fc2_weight_shape[1] + 127) // 128,
MXFP8_BLOCK_SCALING_SIZE,
4,
4,
Expand Down
Loading