[Common/PyTorch] bugfix: Token-linear fused RoPE impl. for THD tensors. #3057
plugyawn wants to merge 10 commits into
Greptile Summary
This PR fixes a launch-scaling bug in the THD fused RoPE path where the legacy kernel launched
Confidence Score: 4/5
The new kernel logic is correct and well-tested, but the dim3 cast from size_t to unsigned int in both launchers has no explicit guard.
The core math is provably correct: the new kernel reuses the unchanged device helpers with an identical s_id_for_freqs computation, and the bitwise-equality parity test validates this across many shapes. The binary search handles zero-length spans and CP ranks correctly. The only unguarded edge is the static_cast to unsigned int in both launchers: if local_tokens ever exceeds UINT_MAX, the grid silently launches far fewer blocks than needed without any error.
transformer_engine/common/fused_rope/fused_rope.cu: the two static_cast calls in the forward and backward launchers should have an explicit upper-bound check before the cast.
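A minimal sketch of the kind of guard this asks for, written as host-side C++ so it can be read in isolation. The helper name `checked_grid_dim_x` and the use of an exception are assumptions for illustration, not the PR's code. The bound is CUDA's documented maximum x-dimension of a grid (2^31 - 1), which is tighter than UINT_MAX.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>

// Documented CUDA limit on gridDim.x for current compute capabilities.
constexpr std::size_t kMaxGridDimX = 2147483647u;  // 2^31 - 1

// Hypothetical helper (not part of the PR): converts a token count into a
// 1D grid size and fails loudly instead of silently truncating.
inline unsigned int checked_grid_dim_x(std::size_t local_tokens) {
  if (local_tokens > kMaxGridDimX) {
    throw std::overflow_error("local_tokens exceeds the maximum grid x-dimension");
  }
  return static_cast<unsigned int>(local_tokens);
}
```

A launcher would call this on local_tokens before constructing the dim3, so an oversized input is reported as an error rather than producing a truncated grid.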
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["fused_rope_forward / fused_rope_backward"] --> B["Compute local_tokens from input.shape[0] for THD"]
B --> C["fused_rope_forward/backward_launcher"]
C --> D{"fused_rope_thd_use_token_linear?"}
D -- "env=1 or heuristic passes" --> E["Launch token-linear kernel dim3(local_tokens)"]
D -- "env=0 or heuristic fails" --> F["Launch legacy kernel dim3(s, b)"]
E --> G["fused_rope_thd_token_fwd/bwd_kernel blockIdx.x = t_id"]
G --> H["Thread 0: valid_token + seq_id via binary search"]
H --> I["__syncthreads() + early return if invalid"]
I --> J["Compute s_id_for_freqs with CP offsets"]
J --> K["fused_rope_block_forward/backward"]
Reviews (8): Last reviewed commit: "Merge branch 'main' into rope-thd-token-..."
int t_id = blockIdx.x;
int b_id = fused_rope_thd_find_seq_id(cu_seqlens, nseq, t_id, cp_size);
Redundant binary search across all threads in the block
Every thread in the block calls fused_rope_thd_find_seq_id with the same arguments (t_id = blockIdx.x, nseq, cp_size) and produces an identical result. With warps_per_block = 8, that's 256 threads each doing O(log nseq) global-memory reads of cu_seqlens that could be performed once. For nseq=2401 (~12 iterations x 256 threads), each block reads ~3,072 redundant entries from cu_seqlens. Performing the search once in thread 0 and broadcasting the result via shared memory would eliminate that overhead.
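For reference, the lookup being discussed can be sketched on the host as below. This is an illustrative stand-in, not the PR's `fused_rope_thd_find_seq_id`: the name `sketch_find_seq_id`, the `std::vector` input, the `-1` return for out-of-range tokens, and dividing each boundary by `cp_size` to get CP-local offsets are all assumptions. In a kernel, the suggestion above amounts to having thread 0 run one such search and publish the result through shared memory before a `__syncthreads()`.

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side sketch: returns the sequence index i such that
//   cu_seqlens[i] / cp_size <= t_id < cu_seqlens[i + 1] / cp_size,
// or -1 if t_id lies outside every sequence. cu_seqlens is assumed to be
// non-decreasing with nseq + 1 entries; zero-length spans are skipped
// because the search keeps the LAST index whose start is <= t_id.
inline int sketch_find_seq_id(const std::vector<int>& cu_seqlens, int t_id, int cp_size) {
  const int nseq = static_cast<int>(cu_seqlens.size()) - 1;
  if (nseq <= 0 || cp_size <= 0) return -1;
  if (t_id < cu_seqlens[0] / cp_size || t_id >= cu_seqlens[nseq] / cp_size) return -1;

  int lo = 0;     // invariant: start(lo) <= t_id
  int hi = nseq;  // invariant: t_id < start(hi)
  while (hi - lo > 1) {
    const int mid = lo + (hi - lo) / 2;
    if (cu_seqlens[mid] / cp_size <= t_id) {
      lo = mid;
    } else {
      hi = mid;
    }
  }
  return lo;
}
```

The loop runs about log2(nseq) iterations, which matches the estimate of roughly 12 reads per search for nseq=2401.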
@plugyawn Hi, could you sign your commits? See https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst#sign-your-work @sudhakarsingh27 Could you take a look?
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
for more information, see https://pre-commit.ci Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
331a3a0 to 6c46696
Thanks! Signed! FWIW, I think the binary search overhead in normal cases can also be reduced; I'll probably add some improvements.
sudhakarsingh27 left a comment
Posted the RoPE THD token-linear review comments from the local benchmark/coverage analysis. The main concerns are the dispatch heuristic, CP-local token accounting, CP-rank coverage, and benchmark scope.
const int o_stride_h = d;
const int o_stride_d = 1;

if (fused_rope_thd_use_token_linear(qkv_format, b, s, total_tokens)) {
Please make the compact launch decision use the actual local THD rows and the legacy launch blocks. The local patch uses this shape:
const size_t compact_thd_blocks = input.data.shape[0];
const size_t legacy_thd_blocks = static_cast<size_t>(s) * b;
if (fused_rope_thd_use_compact_launch(legacy_thd_blocks, compact_thd_blocks, cp_size)) {
const int t = input.data.shape[0];
dim3 blocks(t);
...
}
This also avoids routing the heuristic through a total_tokens value whose CP/global semantics are easy to confuse.
Fixed. Also renamed the variable from total_tokens, so no CP/global ambiguity. Could you check if it's fine now?
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
Fixed some of the review comments, resolving the rest now. Additional CP-rank validation for the THD token-linear RoPE path on the rebased PR tip:
This closes the earlier proof gap where old-vs-new parity only covered
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
for more information, see https://pre-commit.ci
Description
Adds a token-linear implementation of the existing THD fused RoPE path to remove a launch-scaling bug.
Addresses #2866, which reports that the THD fused RoPE launch scales with freqs_len × n_spans, which is pathological; it should scale with the total token count. I reproduced the issue and found that it causes a noticeable performance drop even on plausibly routine shapes, e.g. the [128/512] and [512/128] cases here.
The new kernel reuses the existing fused_rope_block_forward and fused_rope_block_backward device helpers, so the math doesn't change. All we need to do is add a THD-only path that launches one block per packed token.
The bug is mostly pathological, however, so I've added a condition on the dispatch to avoid the unnecessary binary search overhead, although that overhead appears to be minor. The condition is: token-linear only when b >= 64 and the old launch would issue ≥ 8× as many blocks as there are tokens. I'm not sure if this is the usual shape of TE updates, so I could remove it!
Some more relevant tests:
Microbenchmark on H100 (bf16, h=32, d=d2=128, freqs_len=T_local=65536, single GPU):
Fixes: #2866.
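The dispatch condition described above (b >= 64 and the legacy launch issuing at least 8× as many blocks as there are tokens) can be sketched as a small host-side predicate. The function name `sketch_use_token_linear`, the `Override` enum standing in for the `NVTE_FUSED_ROPE_THD_TOKEN_LINEAR=0|1` setting, and the parameter types are assumptions for illustration; the PR's actual `fused_rope_thd_use_token_linear` may differ.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for the environment override (0 = legacy, 1 = token-linear).
enum class Override { kAuto, kForceLegacy, kForceTokenLinear };

// Hypothetical sketch of the dispatch heuristic.
//   s, b         : legacy grid dimensions, i.e. the legacy launch uses s * b blocks
//   local_tokens : number of packed THD rows on this rank
inline bool sketch_use_token_linear(std::size_t s, std::size_t b, std::size_t local_tokens,
                                    Override ov = Override::kAuto) {
  if (ov == Override::kForceTokenLinear) return true;
  if (ov == Override::kForceLegacy) return false;
  if (b < 64 || local_tokens == 0) return false;

  // Sketch only: assumes s * b does not overflow size_t.
  const std::size_t legacy_blocks = s * b;
  // Integer division avoids computing 8 * local_tokens; for positive integers,
  // legacy_blocks / local_tokens >= 8 holds exactly when legacy_blocks >= 8 * local_tokens.
  return legacy_blocks / local_tokens >= 8;
}
```

With the microbenchmark shape (s = 65536, local_tokens = 65536), b = 64 gives a 64× block ratio and selects the token-linear path, while b = 32 stays on the legacy path regardless of the ratio.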
Type of change
Changes
Please list the changes introduced in this PR:
NVTE_FUSED_ROPE_THD_TOKEN_LINEAR=0|1.
fused_rope_block_forward and fused_rope_block_backward device helpers.
Checklist: