Skip to content

[Common/PyTorch] bugfix: Token-linear fused RoPE impl. for THD tensors.#3057

Open
plugyawn wants to merge 10 commits into
NVIDIA:mainfrom
plugyawn:rope-thd-token-linear
Open

[Common/PyTorch] bugfix: Token-linear fused RoPE impl. for THD tensors.#3057
plugyawn wants to merge 10 commits into
NVIDIA:mainfrom
plugyawn:rope-thd-token-linear

Conversation

@plugyawn

@plugyawn plugyawn commented May 28, 2026

Copy link
Copy Markdown

Description

Adds a token-linear implementation of the existing THD fused RoPE path to remove a launch-scaling bug.

Addresses #2866, which finds an interesting case with RoPE scales by freqs_len × n_spans, which is pathological; it should scale by total tokens. I reproduced the issue and found that it's causing a noticeable drops on even plausibly routine shapes. For eg: the [128/512] and [512/128] cases here.

The new kernel reuses the existing fused_rope_block_forward and fused_rope_block_backward device helpers, so the math doesn't change. All we need to do is add a THD-only path that launches one bloc/packed token.

n_seqs max span old layer fwd+bwd (ms) new layer fwd+bwd (ms) layer speedup old paired-RoPE share new paired-RoPE share
128 512 41.8151 23.0284 1.816x 49.12% 6.14%
512 128 102.1047 23.0167 4.436x 79.38% 6.59%
1024 64 182.9933 23.3783 7.827x 88.36% 6.77%
2401 28 401.0516 24.5668 16.325x 94.40% 6.41%

This is mostly pathological, however, so I've added a condition on the dispatch to avoid the unnecessary binary search overhead, although the overhead appears to be not-that-relevant. The condition is: token-linear only when b >= 64 and the old launch would issue ≥ 8× as many blocks as there are tokens. I'm not sure if this the usual shape of TE updates, so I could remove it!

Some more relevant tests:
Microbenchmark on H100 (bf16, h=32, d=d2=128, freqs_len=T_local=65536, single GPU):

n_seqs old fwd+bwd (ms) new fwd+bwd (ms) speedup
1 1.2746 1.2734 1.001x
8 1.8860 1.3827 1.364x
32 3.9359 1.4462 2.722x
128 12.1849 1.5024 8.110x
512 44.9411 1.5600 28.808x
1024 89.1110 1.5919 55.977x
2401 208.4182 1.6373 127.296x

Fixes: #2866.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add token-linear THD fused RoPE forward/backward kernels that launch one CUDA block per packed local token row.
  • Add NVTE_FUSED_ROPE_THD_TOKEN_LINEAR=0|1.
  • Reuses existing fused_rope_block_forward and fused_rope_block_backward device helpers.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation <<(none?)>>
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 28, 2026
@greptile-apps

greptile-apps Bot commented May 28, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR fixes a launch-scaling bug in the THD fused RoPE path where the legacy kernel launched freqs_len × n_seqs CUDA blocks instead of the actual number of packed local tokens, causing severe over-launch for many-sequence workloads. The fix adds a token-linear forward/backward kernel pair that launches one block per packed token row and reuses the existing fused_rope_block_forward/fused_rope_block_backward device helpers unchanged.

  • New CUDA kernels (fused_rope_thd_token_forward/backward_kernel): each block does a binary search over divided cu_seqlens boundaries to find its owning sequence, applies the same CP offset formula as the legacy kernel, and calls the shared device helper.
  • Heuristic dispatcher (fused_rope_thd_use_token_linear): activates the new path when legacy_blocks > 2 × cp_size × token_linear_blocks; overridable via NVTE_FUSED_ROPE_THD_TOKEN_LINEAR=0|1.
  • Tests: a new parity test forces both paths back-to-back and asserts bitwise equality across many shapes including zero-length spans and CP ranks.

Confidence Score: 4/5

The new kernel logic is correct and well-tested, but the dim3 cast from size_t to unsigned int in both launchers has no explicit guard.

The core math is provably correct — the new kernel reuses the unchanged device helpers with an identical s_id_for_freqs computation, and the bitwise-equality parity test validates this across many shapes. The binary search handles zero-length spans and CP ranks correctly. The only unguarded edge is the static_cast to unsigned int in both launchers: if local_tokens ever exceeds UINT_MAX, the grid silently launches far fewer blocks than needed without any error.

transformer_engine/common/fused_rope/fused_rope.cu — the two static_cast calls in the forward and backward launchers should have an explicit upper-bound check before the cast.

Important Files Changed

Filename Overview
transformer_engine/common/fused_rope/fused_rope.cu Adds token-linear THD forward/backward kernels with correct binary-search sequence lookup, shared-memory valid_token guard, and heuristic dispatcher; logic correctly mirrors the existing per-block kernels
tests/pytorch/test_fused_rope.py Adds parity test that forces both kernel paths and asserts bitwise equality; existing test_fused_rope_thd now only exercises the new path via monkeypatch.setenv; zero-length span coverage included
benchmarks/attention/benchmark_rope_thd_token_linear.py New standalone microbenchmark sweeping n_seqs; uses context manager for env overrides, correct cu_seqlens construction, and structured CSV output

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["fused_rope_forward / fused_rope_backward"] --> B["Compute local_tokens from input.shape[0] for THD"]
    B --> C["fused_rope_forward/backward_launcher"]
    C --> D{"fused_rope_thd_use_token_linear?"}
    D -- "env=1 or heuristic passes" --> E["Launch token-linear kernel dim3(local_tokens)"]
    D -- "env=0 or heuristic fails" --> F["Launch legacy kernel dim3(s, b)"]
    E --> G["fused_rope_thd_token_fwd/bwd_kernel blockIdx.x = t_id"]
    G --> H["Thread 0: valid_token + seq_id via binary search"]
    H --> I["__syncthreads() + early return if invalid"]
    I --> J["Compute s_id_for_freqs with CP offsets"]
    J --> K["fused_rope_block_forward/backward"]
Loading

Reviews (8): Last reviewed commit: "Merge branch 'main' into rope-thd-token-..." | Re-trigger Greptile

Comment on lines +250 to +251
int t_id = blockIdx.x;
int b_id = fused_rope_thd_find_seq_id(cu_seqlens, nseq, t_id, cp_size);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Redundant binary search across all threads in the block

Every thread in the block calls fused_rope_thd_find_seq_id with the same arguments (t_id = blockIdx.x, nseq, cp_size) and produces an identical result. With warps_per_block = 8, that's 256 threads each doing O(log nseq) global-memory reads of cu_seqlens that could be performed once. For nseq=2401 (~12 iterations x 256 threads), each block reads ~3,072 redundant entries from cu_seqlens. Performing the search once in thread 0 and broadcasting the result via shared memory would eliminate that overhead.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Smart bot!

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved!

Comment thread transformer_engine/common/fused_rope/fused_rope.cu
@ptrendx

ptrendx commented May 28, 2026

Copy link
Copy Markdown
Member

@plugyawn Hi, could you sign your commits? See https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst#sign-your-work
Nice improvement :-).

@sudhakarsingh27 Could you take a look?

plugyawn and others added 3 commits May 29, 2026 03:23
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
for more information, see https://pre-commit.ci

Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
@plugyawn plugyawn force-pushed the rope-thd-token-linear branch from 331a3a0 to 6c46696 Compare May 28, 2026 21:55
@plugyawn

plugyawn commented May 28, 2026

Copy link
Copy Markdown
Author

Thanks! Signed!

fwiw I think the binary search overhead on normal cases can be reduced also, I'll probably add some improvements.

@sudhakarsingh27 sudhakarsingh27 self-requested a review June 3, 2026 22:08

@sudhakarsingh27 sudhakarsingh27 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Posted the RoPE THD token-linear review comments from the local benchmark/coverage analysis. The main concerns are the dispatch heuristic, CP-local token accounting, CP-rank coverage, and benchmark scope.

Comment thread transformer_engine/common/fused_rope/fused_rope.cu Outdated
const int o_stride_h = d;
const int o_stride_d = 1;

if (fused_rope_thd_use_token_linear(qkv_format, b, s, total_tokens)) {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make the compact launch decision use the actual local THD rows and the legacy launch blocks. The local patch uses this shape:

const size_t compact_thd_blocks = input.data.shape[0];
const size_t legacy_thd_blocks = static_cast<size_t>(s) * b;

if (fused_rope_thd_use_compact_launch(legacy_thd_blocks, compact_thd_blocks, cp_size)) {
  const int t = input.data.shape[0];
  dim3 blocks(t);
  ...
}

This also avoids routing the heuristic through a total_tokens value whose CP/global semantics are easy to confuse.

@plugyawn plugyawn Jun 9, 2026

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Also renamed the variable from total_tokens, so no CP/global ambiguity. Could you check if it's fine now?

Comment thread tests/pytorch/test_fused_rope.py Outdated
Comment thread benchmarks/attention/benchmark_rope_thd_token_linear.py Outdated
Comment thread benchmarks/attention/benchmark_rope_thd_full_layer.py Outdated
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
@plugyawn

plugyawn commented Jun 9, 2026

Copy link
Copy Markdown
Author

Fixed some of the review comments, resolving the rest now.


Additional CP-rank validation for the THD token-linear RoPE path on the rebased PR tip:

  • Commit: eaee5a1731141654a006f9872fcfe10132cdcf76 (Cover CP ranks in THD RoPE token-linear tests)
  • Hardware/runtime: Prime Datacrunch A100 80GB, driver 580.126.09, CUDA 12.8, PyTorch 2.8.0+cu128
  • Build: editable TE PyTorch build passed; fused_rope.cu and apply_rope.cpp compiled on this exact tip
  • test_fused_rope_thd_token_linear_parity: 288 passed / 96 skipped / 0 failed. The skips are invalid cp_rank >= cp_size; the JUnit/log includes 96 passing cp_rank=1, cp_size=2 cases.
  • test_fused_rope_thd with the token-linear path forced: 384 passed / 0 failed

This closes the earlier proof gap where old-vs-new parity only covered cp_rank=0.

@plugyawn plugyawn requested a review from sudhakarsingh27 June 9, 2026 08:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Performance] Fused RoPE THD kernel becomes dominant bottleneck in long-context training with many packed sequences

3 participants