Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 76 additions & 10 deletions tests/pytorch/test_grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,6 +1497,7 @@ def test_fp8_grouped_gemm(shape, accumulate):
_FUSED_GROUPED_GEMM_ENV = "NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM"
_ALL_BOOLEAN = all_boolean
_mxfp8_available, _reason_for_no_mxfp8 = mxfp8_available, reason_for_no_mxfp8
_nvfp4_available, _reason_for_no_nvfp4 = nvfp4_available, reason_for_no_nvfp4


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -1580,26 +1581,40 @@ def _run_grouped_linear_path(
recipe.MXFP8BlockScaling(),
marks=pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8),
),
pytest.param(
recipe.NVFP4BlockScaling(disable_stochastic_rounding=True),
marks=pytest.mark.skipif(not _nvfp4_available, reason=_reason_for_no_nvfp4),
),
],
ids=["bf16", "mxfp8"],
ids=["bf16", "mxfp8", "nvfp4"],
)
@pytest.mark.parametrize("bias", _ALL_BOOLEAN)
@pytest.mark.parametrize("fp8_model_params", _ALL_BOOLEAN)
@pytest.mark.parametrize("delay_wgrad_compute", _ALL_BOOLEAN)
def test_grouped_linear_grouped_tensor_path_matches_legacy(
fp8_recipe, bias, fp8_model_params, delay_wgrad_compute, monkeypatch
):
if torch.cuda.get_device_capability() < (10, 0):
pytest.skip("GroupedTensor grouped GEMM path requires SM100+")

use_fp8 = fp8_recipe is not None
device_capability = torch.cuda.get_device_capability()
if not (9, 0) <= device_capability <= (11, 0):
pytest.skip(
"GroupedTensor grouped GEMM path requires Hopper (SM90) or Blackwell (SM10x and SM110)."
)
if use_fp8 and device_capability < (10, 0):
pytest.skip("Quantized GroupedTensor grouped GEMM path requires Blackwell (SM100+).")
cublaslt_version = tex.get_cublasLt_version()
if device_capability < (10, 0) and cublaslt_version < 130400:
pytest.skip("Grouped GEMM on Hopper requires cuBLAS 13.4+.")
if cublaslt_version < 130300:
pytest.skip("Grouped GEMM requires cuBLAS 13.3+.")

if fp8_model_params and not use_fp8:
pytest.skip("fp8_model_params requires FP8")

dtype = torch.bfloat16
num_gemms = 3
in_features = 64
out_features = 64
in_features = 128
out_features = 128
m_splits = [128, 256, 384]
total_tokens = sum(m_splits)

Expand Down Expand Up @@ -1683,6 +1698,43 @@ def test_grouped_linear_grouped_tensor_path_single_grouped_bias_delay_wgrad(monk
grouped_linear.backward_dw()


@pytest.mark.skipif(not _nvfp4_available, reason=_reason_for_no_nvfp4)
def test_grouped_linear_grouped_tensor_path_skips_non_rht_nvfp4(monkeypatch):
    """Non-RHT NVFP4 should fall back to legacy path instead of grouped quantization.

    The fused GroupedTensor path is requested through the environment
    variable, but the recipe disables RHT. The test runs one forward and
    backward pass of a bias-free ``GroupedLinear`` and checks that both
    complete without error and produce tensors of the expected shape.

    NOTE: this is a smoke test. It does not itself observe which code path
    was taken; it only verifies that the combination "fused path requested +
    non-RHT NVFP4" runs end to end.
    """
    if torch.cuda.get_device_capability() < (10, 0):
        pytest.skip("NVFP4 GroupedTensor grouped GEMM path requires SM100+")

    # Request the fused path and clear any cached FP8 state before building the module.
    monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "1")
    FP8GlobalStateManager.reset()

    dtype = torch.bfloat16
    num_gemms = 3
    in_features = 128
    out_features = 128
    split_sizes = [128, 256, 384]
    total_tokens = sum(split_sizes)

    x = torch.randn(total_tokens, in_features, dtype=dtype, device="cuda")
    x.requires_grad_(True)
    dy = torch.randn(total_tokens, out_features, dtype=dtype, device="cuda")
    grouped_linear = GroupedLinear(
        num_gemms,
        in_features,
        out_features,
        bias=False,
        params_dtype=dtype,
        device="cuda",
    )

    fp8_recipe = recipe.NVFP4BlockScaling(
        disable_rht=True,
        disable_stochastic_rounding=True,
    )
    with autocast(enabled=True, recipe=fp8_recipe):
        y = grouped_linear(x, split_sizes)
    y.backward(dy)

    # CUDA kernels are launched asynchronously; synchronize so that any kernel
    # failure is reported by this test instead of leaking into a later one.
    torch.cuda.synchronize()

    assert y.shape == (total_tokens, out_features)
    assert x.grad is not None
    assert x.grad.shape == x.shape


@pytest.mark.parametrize(
"fp8_recipe",
[
Expand All @@ -1691,19 +1743,33 @@ def test_grouped_linear_grouped_tensor_path_single_grouped_bias_delay_wgrad(monk
recipe.MXFP8BlockScaling(),
marks=pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8),
),
pytest.param(
recipe.NVFP4BlockScaling(disable_stochastic_rounding=True),
marks=pytest.mark.skipif(not _nvfp4_available, reason=_reason_for_no_nvfp4),
),
],
ids=["bf16", "mxfp8"],
ids=["bf16", "mxfp8", "nvfp4"],
)
@pytest.mark.parametrize("bias", _ALL_BOOLEAN)
def test_grouped_linear_fused_path_cuda_graph_safe(fp8_recipe, bias, monkeypatch):
"""Fused GroupedTensor GEMM path should be CUDA graph capturable."""
if torch.cuda.get_device_capability() < (10, 0):
pytest.skip("GroupedTensor grouped GEMM path requires SM100+")
use_fp8 = fp8_recipe is not None
device_capability = torch.cuda.get_device_capability()
if not (9, 0) <= device_capability <= (11, 0):
pytest.skip(
"GroupedTensor grouped GEMM path requires Hopper (SM90) or Blackwell (SM10x and SM110)."
Comment on lines +1757 to +1760

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that this path just verifies CUDA graph safety, it might also make sense to add the nvfp4 recipe to the test test_grouped_linear_grouped_tensor_path_matches_legacy.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

)
if use_fp8 and device_capability < (10, 0):
pytest.skip("Quantized GroupedTensor grouped GEMM path requires Blackwell (SM100+).")
cublaslt_version = tex.get_cublasLt_version()
if device_capability < (10, 0) and cublaslt_version < 130400:
pytest.skip("Grouped GEMM on Hopper requires cuBLAS 13.4+.")
if cublaslt_version < 130300:
pytest.skip("Grouped GEMM requires cuBLAS 13.3+.")

monkeypatch.setenv(_FUSED_GROUPED_GEMM_ENV, "1")
FP8GlobalStateManager.reset()

use_fp8 = fp8_recipe is not None
dtype = torch.bfloat16
device = "cuda"
num_gemms = 3
Expand Down
15 changes: 8 additions & 7 deletions transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ struct GroupedGemmSetupWorkspace {
}
};

inline bool grouped_gemm_supports_per_group_alpha_beta(int sm) { return sm >= 100; }
inline bool grouped_gemm_supports_per_group_alpha_beta(int sm) { return sm >= 100 && sm <= 110; }

inline size_t validate_grouped_gemm_inputs(
size_t num_tensors, std::initializer_list<const transformer_engine::GroupedTensor *> inputs,
Expand Down Expand Up @@ -335,7 +335,8 @@ inline void check_grouped_gemm_requirements(const char *api_name) {
const int sm = transformer_engine::cuda::sm_arch(current_device);
const int cublas_ver = transformer_engine::cuda::cublas_version();
#if CUBLAS_VERSION >= CUBLAS_GROUPED_GEMM_HOPPER_VERSION
NVTE_CHECK(sm >= 90, api_name, " requires Hopper (SM90) or newer architecture.");
NVTE_CHECK(sm >= 90 && sm <= 110, api_name,
" requires Hopper (SM90) or Blackwell (SM10x and SM110).");
NVTE_CHECK(cublas_ver >= CUBLAS_GROUPED_GEMM_VERSION, api_name,
" requires cuBLAS 13.3+, but run-time cuBLAS version is ", cublas_ver);
if (sm < 100) {
Expand All @@ -344,7 +345,7 @@ inline void check_grouped_gemm_requirements(const char *api_name) {
cublas_ver);
}
#else
NVTE_CHECK(sm >= 100, api_name, " requires Blackwell (SM100) or newer architecture.");
NVTE_CHECK(sm >= 100 && sm <= 110, api_name, " requires Blackwell (SM10x and SM110).");
NVTE_CHECK(cublas_ver >= CUBLAS_GROUPED_GEMM_VERSION, api_name,
" requires cuBLAS 13.3+, but run-time cuBLAS version is ", cublas_ver);
#endif
Expand Down Expand Up @@ -400,7 +401,7 @@ inline void validate_fp8_block_grouped_gemm_support(const GroupedOperandSelectio
"Grouped GEMM: A and B must both use FP8 block scaling or both not.");
NVTE_CHECK(sm == 90,
"Grouped GEMM: FP8 block scaling is only supported on Hopper (SM90); "
"use MXFP8 on Blackwell (SM100) or newer.");
"use MXFP8 on Blackwell (SM10x and SM110).");
}

inline bool is_compatible_grouped_scaling_mode(NVTEScalingMode a_mode, NVTEScalingMode b_mode) {
Expand Down Expand Up @@ -1567,7 +1568,7 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT
NVTE_API_CALL(nvte_grouped_gemm);
using namespace transformer_engine;

// Grouped GEMM requires Blackwell (SM100) or newer with cuBLAS 13.3+,
// Grouped GEMM requires Blackwell (SM10x and SM110) with cuBLAS 13.3+,
// or Hopper (SM90) with cuBLAS 13.4+.
check_grouped_gemm_requirements("nvte_grouped_gemm");

Expand Down Expand Up @@ -1650,7 +1651,7 @@ void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num
NVTE_API_CALL(nvte_grouped_gemm_with_discrete_inputA);
using namespace transformer_engine;

// Grouped GEMM requires Blackwell (SM100) or newer with cuBLAS 13.3+,
// Grouped GEMM requires Blackwell (SM10x and SM110) with cuBLAS 13.3+,
// or Hopper (SM90) with cuBLAS 13.4+.
check_grouped_gemm_requirements("nvte_grouped_gemm_with_discrete_inputA");

Expand Down Expand Up @@ -1801,7 +1802,7 @@ void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa,
NVTE_API_CALL(nvte_grouped_gemm_with_discrete_out);
using namespace transformer_engine;

// Grouped GEMM requires Blackwell (SM100) or newer with cuBLAS 13.3+,
// Grouped GEMM requires Blackwell (SM10x and SM110) with cuBLAS 13.3+,
// or Hopper (SM90) with cuBLAS 13.4+.
check_grouped_gemm_requirements("nvte_grouped_gemm_with_discrete_out");

Expand Down
39 changes: 24 additions & 15 deletions transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@
from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload
from ..triton.grouped_dbias_dscales import compute_grouped_dbias

from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor import Float8CurrentScalingQuantizer, Float8Quantizer, MXFP8Quantizer, NVFP4Quantizer
from ..quantized_tensor import (
QuantizedTensorStorage,
Quantizer,
Expand Down Expand Up @@ -95,19 +94,29 @@ def _is_grouped_tensor_path_supported(
save_original_input: bool,
activation_dtype: torch.dtype,
input_quantizers: List[Optional[Quantizer]],
weight_quantizers: List[Optional[Quantizer]],
output_quantizers: List[Optional[Quantizer]],
grad_output_quantizers: List[Optional[Quantizer]],
) -> bool:
"""Whether to use cublasLt grouped GEMM through GroupedTensor metadata.
"""Whether to use cuBLASLt grouped GEMM through GroupedTensor metadata.

There are no checks whether split sizes are supported. Splits
may be in a CUDA tensor, so checking would hurt performance
and be incompatible with CUDA Graphs.

Supported Compute Capability (CC) and precisions:
* Hopper (CC 9.0): BF16/FP16.
* Blackwell (CC 10.x and 11.0): BF16/FP16/MXFP8/NVFP4 with RHT.
FP8 delayed / current scaling, and FP8 block scaling are not supported because the
corresponding grouped quantization kernels are missing.
Non-RHT NVFP4 falls back to the legacy path because graph-safe grouped quantization
currently requires RHT.

Input/weight/grad_output quantizers are assumed to be of the same type, otherwise it would
trigger a fatal error in the cuBLASLt grouped GEMM check.
"""
# 1. Filter by environment variable
if not bool(int(os.getenv("NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM", "0"))):
return False
# 2. Filter out advanced features
if (
debug
or cpu_offloading
Expand All @@ -116,16 +125,18 @@ def _is_grouped_tensor_path_supported(
or save_original_input
):
return False
if get_device_compute_capability() < (10, 0):
# 3. Filter by compute capability
if not (9, 0) <= get_device_compute_capability() <= (11, 0):
return False
# 4. Output quantization is not supported.
if any(q is not None for q in output_quantizers):
return False
# 5. Filter by quantization recipes.
if fp8:
return (
activation_dtype in (torch.bfloat16, torch.float16)
and all(isinstance(q, MXFP8Quantizer) for q in input_quantizers)
and all(isinstance(q, MXFP8Quantizer) for q in weight_quantizers)
and all(q is None or isinstance(q, MXFP8Quantizer) for q in grad_output_quantizers)
if not (10, 0) <= get_device_compute_capability() <= (11, 0):
return False
return all(isinstance(q, MXFP8Quantizer) for q in input_quantizers) or all(
isinstance(q, NVFP4Quantizer) and q.with_rht for q in input_quantizers
)
return activation_dtype in (torch.bfloat16, torch.float16)

Expand Down Expand Up @@ -234,7 +245,7 @@ def _forward_grouped_tensor(
weights: Tuple[torch.Tensor, ...],
biases: Tuple[torch.Tensor, ...],
) -> Tuple[torch.Tensor, list]:
"""Forward path backed by GroupedTensor + cublasLt grouped GEMM."""
"""Forward path backed by GroupedTensor + cuBLASLt grouped GEMM."""
num_gemms = len(m_splits)
device = inp.device
in_features = weights[0].size(-1)
Expand Down Expand Up @@ -491,9 +502,7 @@ def forward(
save_original_input=save_original_input,
activation_dtype=activation_dtype,
input_quantizers=input_quantizers,
weight_quantizers=weight_quantizers,
output_quantizers=output_quantizers,
grad_output_quantizers=grad_output_quantizers,
):
return _GroupedLinear._forward_grouped_tensor(
ctx,
Expand Down Expand Up @@ -745,7 +754,7 @@ def _backward_grouped_tensor(
columnwise=ctx.weights_requires_grad,
)
grad_output_quantizer.optimize_for_gemm = True
if ctx.use_bias:
if ctx.use_bias and isinstance(grad_output_quantizer, MXFP8Quantizer):
grouped_dy, dbias_packed = tex.bgrad_group_quantize(
dy_2d,
grad_output_quantizer,
Expand Down
36 changes: 25 additions & 11 deletions transformer_engine/pytorch/ops/basic/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload
from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe
from ...quantized_tensor import QuantizedTensorStorage
from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer
from ...tensor import MXFP8Quantizer, MXFP8Tensor, NVFP4Quantizer, Quantizer
from ...utils import (
canonicalize_device,
canonicalize_dtype,
Expand Down Expand Up @@ -764,17 +764,27 @@ def _is_graph_safe_path_supported(

* The graph-safe path dispatches to ``general_grouped_gemm_for_grouped_tensor``,
which is backed by ``nvte_grouped_gemm_with_discrete_inputA`` in the common
library. That kernel requires Blackwell (SM100) or newer with cuBLAS 13.3+.
* Quantized compute is currently MXFP8-only; every other quantization
recipe (fp8 delayed / current scaling, fp8 block scaling, NVFP4, ...)
falls back to the legacy flow.
* Unquantized compute supports BF16/FP16 only -- FP32 is excluded
because the cublasLt grouped GEMM doesn't support it.
library. This filter mirrors cuBLASLt grouped GEMM's architecture
requirement without duplicating its cuBLAS version checks.
* Quantized compute supports MXFP8 and NVFP4 on Blackwell GPUs with Compute Capability (CC)
10.x and 11.0. NVFP4 is supported only with RHT enabled because graph-safe grouped
quantization currently requires RHT.
Every other quantization recipe (fp8 delayed / current scaling, fp8 block scaling, ...)
falls back to the legacy flow because the corresponding grouped quantization kernels are
missing.
* Unquantized compute supports BF16/FP16 on Hopper (CC 9.0) and Blackwell (CC 10.x and 11.0)
-- FP32 is excluded because the cuBLASLt grouped GEMM doesn't support it.
* Input/weight/grad_output quantizers are assumed to be of the same type, otherwise it
would trigger a fatal error in the cuBLASLt grouped GEMM check.
"""
if get_device_compute_capability() < (10, 0):
if not (9, 0) <= get_device_compute_capability() <= (11, 0):
return False
if with_quantized_compute:
return all(isinstance(q, MXFP8Quantizer) for q in input_quantizers)
if not (10, 0) <= get_device_compute_capability() <= (11, 0):
return False
return all(isinstance(q, MXFP8Quantizer) for q in input_quantizers) or all(
isinstance(q, NVFP4Quantizer) and q.with_rht for q in input_quantizers
)
return dtype in (torch.bfloat16, torch.float16)

def _get_grouped_weight_for_gemm(
Expand Down Expand Up @@ -954,7 +964,7 @@ def fuser_forward(

# Dispatch: graph-safe GroupedTensor flow whenever it can be used.
# See ``_is_graph_safe_path_supported`` for the gating rationale --
# in short it requires Blackwell (SM100+) plus a supported dtype /
# in short it requires Hopper (SM90) or Blackwell (SM10x/SM110) plus a supported dtype /
# quantization recipe. Otherwise we fall back to the legacy
# ``tex.split_quantize`` + ``general_grouped_gemm`` flow.
use_grouped_tensor_path = self._is_graph_safe_path_supported(
Expand Down Expand Up @@ -1582,7 +1592,11 @@ def _fuser_backward_grouped_tensor(
)
grad_output_quantizer.optimize_for_gemm = True

if has_bias and not self._scale_bias:
if (
has_bias
and not self._scale_bias
and isinstance(grad_output_quantizer, MXFP8Quantizer)
):
grouped_dy, dbias_packed = tex.bgrad_group_quantize(
dy_2d, grad_output_quantizer, num_groups, split_sizes
)
Expand Down
Loading