Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,22 @@ repos:
args: [--line-length=100, --preview, --enable-unstable-feature=string_processing]
types: [python]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.4
hooks:
- id: ruff
name: Lint python code
types: [python]
files: ^transformer_engine/

- repo: https://github.com/cpplint/cpplint
rev: '1.6.0'
hooks:
- id: cpplint
types_or: [c, c++, cuda]
files: ^transformer_engine/(common|jax|pytorch)/
exclude: ^transformer_engine/build_tools/build/

- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.6
hooks:
Expand Down
38 changes: 0 additions & 38 deletions pylintrc

This file was deleted.

37 changes: 37 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,40 @@ requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "nin

# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"

[tool.ruff]
line-length = 100

[tool.ruff.format]
preview = true
docstring-code-format = true

[tool.ruff.lint]
select = ["E", "F", "W", "PL"]
ignore = [
"E402", # module-level-import-not-at-top (pylint import-outside-toplevel was disabled)
"E501", # line-too-long
"E731", # lambda-assignment
"E741", # ambiguous-variable-name (pylint invalid-name was disabled)
"PLR0904", # too-many-public-methods
"PLR0911", # too-many-return-statements
"PLR0912", # too-many-branches
"PLR0913", # too-many-arguments
"PLR0914", # too-many-locals
"PLR0915", # too-many-statements
"PLR0917", # too-many-positional-arguments
"PLR1702", # too-many-nested-blocks
"PLR1704", # redefined-argument-from-local
"PLR2004", # magic-value-comparison
"PLR5501", # collapsible-else-if
"PLW0602", # global-variable-not-assigned
"PLW0603", # global-statement
"PLW2901", # redefined-loop-name
"PLC0415", # import-outside-toplevel
]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401", "F403"]
"transformer_engine/pytorch/fp8.py" = ["F401"]
"transformer_engine/pytorch/export.py" = ["F401"]
"transformer_engine/pytorch/attention/dot_product_attention/backends.py" = ["F401"]
5 changes: 3 additions & 2 deletions qa/L0_jax_lint/test.sh
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# NOTE: This test is duplicated from pre-commit, and could be deleted.

set -e

: "${TE_PATH:=/opt/transformerengine}"

pip3 install cpplint==1.6.0 pylint==3.3.1
pip3 install cpplint==1.6.0 ruff==0.11.4
if [ -z "${PYTHON_ONLY}" ]
then
cd $TE_PATH
Expand All @@ -20,5 +21,5 @@ if [ -z "${CPP_ONLY}" ]
then
cd $TE_PATH
echo "Checking Python files"
python3 -m pylint --recursive=y transformer_engine/common transformer_engine/jax
python3 -m ruff check transformer_engine/common transformer_engine/jax
fi
5 changes: 3 additions & 2 deletions qa/L0_pytorch_lint/test.sh
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# NOTE: This script is duplicated from pre-commit checks, and could be deleted.

set -e

: "${TE_PATH:=/opt/transformerengine}"

pip3 install cpplint==1.6.0 pylint==3.3.1
pip3 install cpplint==1.6.0 ruff==0.11.4
if [ -z "${PYTHON_ONLY}" ]
then
cd $TE_PATH
Expand All @@ -20,5 +21,5 @@ if [ -z "${CPP_ONLY}" ]
then
cd $TE_PATH
echo "Checking Python files"
python3 -m pylint --recursive=y transformer_engine/common transformer_engine/pytorch transformer_engine/debug
python3 -m ruff check transformer_engine/common transformer_engine/pytorch transformer_engine/debug
fi
2 changes: 1 addition & 1 deletion transformer_engine/debug/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
try:
from . import pytorch
from .pytorch.debug_state import set_weight_tensor_tp_group_reduce
except ImportError as e:
except ImportError:
pass
4 changes: 2 additions & 2 deletions transformer_engine/debug/features/_test_dummy_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def inspect_tensor_enabled(self, config, *_args, **_kwargs):
"""
# Access counter via full module path to ensure we're modifying the same module-level
# variable regardless of import context (debug framework vs test import)
import transformer_engine.debug.features._test_dummy_feature as dummy_feature # pylint: disable=import-self
import transformer_engine.debug.features._test_dummy_feature as dummy_feature # noqa: PLW0406 # pylint: disable=import-self

dummy_feature._inspect_tensor_enabled_call_count += 1

Expand All @@ -56,6 +56,6 @@ def inspect_tensor_enabled(self, config, *_args, **_kwargs):
def inspect_tensor(self, _config, *_args, **_kwargs):
"""This method does nothing but always tracks invocations for testing."""
# Access counter via full module path to ensure shared state across import contexts
import transformer_engine.debug.features._test_dummy_feature as dummy_feature # pylint: disable=import-self
import transformer_engine.debug.features._test_dummy_feature as dummy_feature # noqa: PLW0406 # pylint: disable=import-self

dummy_feature._inspect_tensor_call_count += 1
5 changes: 1 addition & 4 deletions transformer_engine/debug/features/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,7 @@ def output_assertions_hook(self, api_name, ret, **kwargs):
assert ret is None
if api_name == "modify_tensor":
assert type(ret) in get_all_tensor_types()
if (
type(ret) == torch.Tensor # pylint: disable=unidiomatic-typecheck
and "dtype" in kwargs
):
if type(ret) is torch.Tensor and "dtype" in kwargs:
if kwargs["dtype"] is not None:
assert ret.dtype == kwargs["dtype"]

Expand Down
3 changes: 1 addition & 2 deletions transformer_engine/debug/pytorch/debug_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,8 +418,7 @@ def any_feature_enabled(self) -> bool:
self.inspect_tensor_enabled
or self.inspect_tensor_postquantize_enabled_rowwise
or self.inspect_tensor_postquantize_enabled_columnwise
or self.rowwise_tensor_plan == API_CALL_MODIFY
or self.columnwise_tensor_plan == API_CALL_MODIFY
or API_CALL_MODIFY in (self.rowwise_tensor_plan, self.columnwise_tensor_plan)
):
return True
if self.parent_quantizer is not None:
Expand Down
1 change: 0 additions & 1 deletion transformer_engine/jax/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,6 @@ def batcher(
sequence_dim,
is_outer,
):
del transpose_batch_sequence, sequence_dim, is_outer
if GemmPrimitive.outer_primitive is None:
raise RuntimeError("GemmPrimitive.outer_primitive has not been registered")
lhs_bdims, _, rhs_bdims, *_ = batch_dims
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3355,10 +3355,7 @@ def forward(
assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!"
assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
assert (
window_size == (-1, 0)
or window_size == (-1, -1)
or use_fused_attention
or fa_utils.v2_3_plus
window_size in ((-1, 0), (-1, -1)) or use_fused_attention or fa_utils.v2_3_plus
), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!"

flash_attn_fwd = None
Expand Down Expand Up @@ -4061,9 +4058,7 @@ def attn_forward_func_with_cp(
cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None
), "cu_seqlens_padded can not be None for context parallelism and qkv_format = 'thd'!"

sliding_window_attn = (
window_size is not None and window_size != (-1, 0) and window_size != (-1, -1)
)
sliding_window_attn = window_size is not None and window_size not in ((-1, 0), (-1, -1))
assert not sliding_window_attn or cp_comm_type in [
"a2a",
"all_gather",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,7 @@ def qgemm(
return y

# cublas fp8 gemm does not support fp32 bias
use_bias_in_gemm = (
bias is not None and out_dtype != torch.float32 and bias.dtype != torch.float32
)
use_bias_in_gemm = bias is not None and torch.float32 not in (out_dtype, bias.dtype)

# Run quantized gemm: y = qw * qx
scaled_mm_res = torch._scaled_mm(
Expand Down
1 change: 0 additions & 1 deletion transformer_engine/pytorch/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from torch.cuda import _lazy_call, _lazy_init
from torch.utils.checkpoint import detach_variable, noop_context_fn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules

try:
Expand Down
4 changes: 3 additions & 1 deletion transformer_engine/pytorch/ops/basic/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,9 @@ def pre_first_fuser_forward(self) -> None:
f"Weight {group_idx} has requires_grad={weight.requires_grad}, "
f"but expected requires_grad={weight_requires_grad}."
)
if type(weight.data) != weight_tensor_type: # pylint: disable=unidiomatic-typecheck
if (
type(weight.data) is not weight_tensor_type
): # pylint: disable=unidiomatic-typecheck
raise RuntimeError(
f"Weight {group_idx} has invalid tensor type "
f"(expected {weight_tensor_type.__name__}, "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def fuser_forward(
if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments")
if self._op_idxs["activation"] is None:
activation_op = None # pylint: disable=unused-variable
pass # No activation op needed
else:
raise NotImplementedError("Activations are not yet supported")

Expand Down
Loading