Skip to content

Commit 4ece13c

Browse files
authored
[TRTLLM-11540][feat] Add EAGLE3 dynamic tree speculative decoding support (#12062)
Signed-off-by: qgai <qgai@nvidia.com>
1 parent 9f9feb9 commit 4ece13c

36 files changed

+3566
-414
lines changed
Lines changed: 356 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,356 @@
1+
/*
2+
* Copyright (c) 2024-2026, NVIDIA CORPORATION. All rights reserved.
3+
* Portions Copyright (c) 2025 by SGLang team (original implementation).
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
#include "dynamicTreeKernels.h"
19+
#include "tensorrt_llm/common/assert.h"
20+
#include "tensorrt_llm/common/cudaUtils.h"
21+
22+
TRTLLM_NAMESPACE_BEGIN
23+
24+
using namespace tensorrt_llm::runtime;
25+
26+
namespace kernels::speculative_decoding
27+
{
28+
29+
//! \param parentList [in] layer-wise parent indices [bs, topK*(depth-1)+1]
30+
//! \param selectedIndex [in] resampled history buffer indices [bs, draftTokenNum-1]
31+
//! \param treeMask [out] attention mask (which nodes each node can see)
32+
//! \param positions [out] position id per node [bs, draftTokenNum]
33+
//! \param retrieveIndex [out] tree node -> local index mapping [bs, draftTokenNum]
34+
//! \param retrieveNextToken [out] first-child pointer [bs, draftTokenNum], -1=none
35+
//! \param retrieveNextSibling [out] next-sibling pointer [bs, draftTokenNum], -1=none
36+
//! \param topK top-K value per layer
37+
//! \param depth max tree depth (number of draft layers)
38+
//! \param draftTokenNum total tree nodes per batch (including root)
39+
__global__ void buildDynamicTreeKernel(int64_t const* parentList, int64_t const* selectedIndex, int32_t* treeMask,
40+
int32_t* positions, int32_t* retrieveIndex, int32_t* retrieveNextToken, int32_t* retrieveNextSibling,
41+
SizeType32 topK, SizeType32 depth, SizeType32 draftTokenNum)
42+
{
43+
int32_t bid = blockIdx.x;
44+
int32_t tid = threadIdx.x;
45+
46+
if (tid >= draftTokenNum)
47+
{
48+
return;
49+
}
50+
51+
// treeMask layout: [batchSize, draftTokenNum, draftTokenNum] (QLEN_ONLY mode)
52+
int32_t tokenTreeIdx = draftTokenNum * draftTokenNum * bid + draftTokenNum * tid + 1;
53+
54+
treeMask[tokenTreeIdx - 1] = 1; // self-attention diagonal
55+
for (int32_t i = 0; i < draftTokenNum - 1; i++)
56+
{
57+
treeMask[tokenTreeIdx + i] = 0;
58+
}
59+
60+
int32_t position = 0;
61+
62+
if (tid == 0)
63+
{
64+
positions[bid * draftTokenNum] = 0;
65+
66+
// Reverse iteration: inserting at list head produces forward sibling order
67+
for (int32_t i = draftTokenNum - 1; i > 0; --i)
68+
{
69+
retrieveIndex[bid * draftTokenNum + i] = i;
70+
71+
int64_t parentTbIdx = selectedIndex[bid * (draftTokenNum - 1) + i - 1] / topK;
72+
int32_t parentPosition = 0;
73+
74+
if (parentTbIdx > 0)
75+
{
76+
int64_t parentTokenIdx = parentList[bid * (topK * (depth - 1) + 1) + parentTbIdx];
77+
for (; parentPosition < draftTokenNum; ++parentPosition)
78+
{
79+
if (selectedIndex[bid * (draftTokenNum - 1) + parentPosition] == parentTokenIdx)
80+
{
81+
++parentPosition; // +1 because position 0 is root
82+
break;
83+
}
84+
}
85+
}
86+
87+
if (parentPosition == draftTokenNum)
88+
{
89+
printf(
90+
"WARNING: Invalid dynamic tree! Detected a token with no parent token selected. "
91+
"Please check if the logprob has nan. The token will be ignored.\n");
92+
continue;
93+
}
94+
95+
if (retrieveNextToken[bid * draftTokenNum + parentPosition] == -1)
96+
{
97+
retrieveNextToken[bid * draftTokenNum + parentPosition] = i;
98+
}
99+
else
100+
{
101+
int32_t originNextToken = retrieveNextToken[bid * draftTokenNum + parentPosition];
102+
retrieveNextToken[bid * draftTokenNum + parentPosition] = i;
103+
retrieveNextSibling[bid * draftTokenNum + i] = originNextToken;
104+
}
105+
}
106+
retrieveIndex[bid * draftTokenNum] = 0;
107+
}
108+
else
109+
{
110+
// Walk up to root, setting treeMask ancestor bits and counting depth
111+
int32_t curPosition = tid - 1;
112+
while (position < depth + 1)
113+
{
114+
position += 1;
115+
treeMask[tokenTreeIdx + curPosition] = 1;
116+
117+
int64_t parentTbIdx = selectedIndex[bid * (draftTokenNum - 1) + curPosition] / topK;
118+
if (parentTbIdx == 0)
119+
{
120+
break;
121+
}
122+
123+
int64_t tokenIdx = parentList[bid * (topK * (depth - 1) + 1) + parentTbIdx];
124+
for (curPosition = 0; curPosition < draftTokenNum; ++curPosition)
125+
{
126+
if (selectedIndex[bid * (draftTokenNum - 1) + curPosition] == tokenIdx)
127+
{
128+
break;
129+
}
130+
}
131+
if (curPosition == draftTokenNum)
132+
{
133+
break;
134+
}
135+
}
136+
positions[bid * draftTokenNum + tid] = position;
137+
}
138+
}
139+
140+
//! Bit-packed variant of buildDynamicTreeKernel.
141+
//! \param numInt32PerRow int32 count per treeMask row (buffer stride; >= ceil(draftTokenNum/32) if padded)
142+
__global__ void buildDynamicTreeKernelPacked(int64_t const* parentList, int64_t const* selectedIndex, int32_t* treeMask,
143+
int32_t* positions, int32_t* retrieveIndex, int32_t* retrieveNextToken, int32_t* retrieveNextSibling,
144+
SizeType32 topK, SizeType32 depth, SizeType32 draftTokenNum, SizeType32 numInt32PerRow)
145+
{
146+
int32_t bid = blockIdx.x;
147+
int32_t tid = threadIdx.x;
148+
149+
if (tid >= draftTokenNum)
150+
{
151+
return;
152+
}
153+
154+
int32_t rowBaseIdx = (bid * draftTokenNum + tid) * numInt32PerRow;
155+
156+
treeMask[rowBaseIdx] = 1; // bit 0 = root, always visible
157+
158+
int32_t position = 0;
159+
160+
if (tid == 0)
161+
{
162+
positions[bid * draftTokenNum] = 0;
163+
164+
for (int32_t i = draftTokenNum - 1; i > 0; --i)
165+
{
166+
retrieveIndex[bid * draftTokenNum + i] = i;
167+
168+
int64_t parentTbIdx = selectedIndex[bid * (draftTokenNum - 1) + i - 1] / topK;
169+
int32_t parentPosition = 0;
170+
171+
if (parentTbIdx > 0)
172+
{
173+
int64_t parentTokenIdx = parentList[bid * (topK * (depth - 1) + 1) + parentTbIdx];
174+
for (; parentPosition < draftTokenNum; ++parentPosition)
175+
{
176+
if (selectedIndex[bid * (draftTokenNum - 1) + parentPosition] == parentTokenIdx)
177+
{
178+
++parentPosition;
179+
break;
180+
}
181+
}
182+
}
183+
184+
if (parentPosition == draftTokenNum)
185+
{
186+
printf("WARNING: Invalid dynamic tree! Detected a token with no parent token selected.\n");
187+
continue;
188+
}
189+
190+
if (retrieveNextToken[bid * draftTokenNum + parentPosition] == -1)
191+
{
192+
retrieveNextToken[bid * draftTokenNum + parentPosition] = i;
193+
}
194+
else
195+
{
196+
int32_t originNextToken = retrieveNextToken[bid * draftTokenNum + parentPosition];
197+
retrieveNextToken[bid * draftTokenNum + parentPosition] = i;
198+
retrieveNextSibling[bid * draftTokenNum + i] = originNextToken;
199+
}
200+
}
201+
retrieveIndex[bid * draftTokenNum] = 0;
202+
}
203+
else
204+
{
205+
int32_t curPosition = tid - 1;
206+
while (position < depth + 1)
207+
{
208+
position += 1;
209+
210+
int32_t bitPosition = curPosition + 1; // +1 because bit 0 is root
211+
int32_t int32Idx = bitPosition / 32;
212+
int32_t bitIdx = bitPosition % 32;
213+
214+
if (int32Idx < numInt32PerRow)
215+
{
216+
atomicOr(&treeMask[rowBaseIdx + int32Idx], 1 << bitIdx);
217+
}
218+
219+
int64_t parentTbIdx = selectedIndex[bid * (draftTokenNum - 1) + curPosition] / topK;
220+
if (parentTbIdx == 0)
221+
{
222+
break;
223+
}
224+
225+
int64_t tokenIdx = parentList[bid * (topK * (depth - 1) + 1) + parentTbIdx];
226+
for (curPosition = 0; curPosition < draftTokenNum; ++curPosition)
227+
{
228+
if (selectedIndex[bid * (draftTokenNum - 1) + curPosition] == tokenIdx)
229+
{
230+
break;
231+
}
232+
}
233+
if (curPosition == draftTokenNum)
234+
{
235+
break;
236+
}
237+
}
238+
positions[bid * draftTokenNum + tid] = position;
239+
}
240+
}
241+
242+
void invokeBuildDynamicTree(int64_t const* parentList, int64_t const* selectedIndex, void* treeMask, int32_t* positions,
243+
int32_t* retrieveIndex, int32_t* retrieveNextToken, int32_t* retrieveNextSibling, SizeType32 batchSize,
244+
SizeType32 topK, SizeType32 depth, SizeType32 numDraftTokens, TreeMaskMode treeMaskMode, cudaStream_t stream,
245+
SizeType32 numInt32PerRow)
246+
{
247+
dim3 grid(batchSize);
248+
dim3 block(numDraftTokens);
249+
250+
if (treeMaskMode == TreeMaskMode::QLEN_ONLY_BITPACKING)
251+
{
252+
TLLM_CHECK_WITH_INFO(
253+
numInt32PerRow > 0, "numInt32PerRow must be the packed treeMask row stride in int32s (from buffer shape).");
254+
buildDynamicTreeKernelPacked<<<grid, block, 0, stream>>>(parentList, selectedIndex,
255+
static_cast<int32_t*>(treeMask), positions, retrieveIndex, retrieveNextToken, retrieveNextSibling, topK,
256+
depth, numDraftTokens, numInt32PerRow);
257+
}
258+
else
259+
{
260+
buildDynamicTreeKernel<<<grid, block, 0, stream>>>(parentList, selectedIndex, static_cast<int32_t*>(treeMask),
261+
positions, retrieveIndex, retrieveNextToken, retrieveNextSibling, topK, depth, numDraftTokens);
262+
}
263+
264+
sync_check_cuda_error(stream);
265+
}
266+
267+
//! \param predicts [out] accepted token ids + bonus token [bs * numDraftTokens]
268+
//! \param acceptIndex [out] accepted path as local tree positions [bs, numSpeculativeTokens]
269+
//! \param acceptTokenNum [out] number of accepted draft tokens per batch [bs]
270+
//! \param candidates [in] candidate token id per tree node [bs, numDraftTokens]
271+
//! \param retrieveIndex [in] tree node -> local index [bs, numDraftTokens]
272+
//! \param retrieveNextToken [in] first-child pointer [bs, numDraftTokens], -1=none
273+
//! \param retrieveNextSibling [in] next-sibling pointer [bs, numDraftTokens], -1=none
274+
//! \param targetPredict [in] target model prediction per position [bs * numDraftTokens]
275+
//! \param batchSize batch size
276+
//! \param numSpeculativeTokens second dim of acceptIndex (>= max possible accepts + 1)
277+
//! \param numDraftTokens total tree nodes per batch (including root)
278+
__global__ void verifyDynamicTreeGreedyKernel(int64_t* predicts, int64_t* acceptIndex, int64_t* acceptTokenNum,
279+
int64_t* acceptToken, int64_t const* candidates, int32_t const* retrieveIndex, int32_t const* retrieveNextToken,
280+
int32_t const* retrieveNextSibling, int64_t const* targetPredict, bool const* treeValid, uint32_t batchSize,
281+
uint32_t numSpeculativeTokens, uint32_t numDraftTokens)
282+
{
283+
uint32_t bx = blockIdx.x;
284+
uint32_t batchOffset = bx * numDraftTokens;
285+
286+
// First-gen or dummy request: no valid tree, accept only the bonus token
287+
if (treeValid != nullptr && !treeValid[bx])
288+
{
289+
acceptTokenNum[bx] = 0;
290+
acceptIndex[bx * numSpeculativeTokens] = 0;
291+
acceptToken[bx * numSpeculativeTokens] = targetPredict[batchOffset];
292+
predicts[batchOffset] = targetPredict[batchOffset];
293+
return;
294+
}
295+
296+
int32_t lastAcceptedLocalIdx = retrieveIndex[batchOffset];
297+
acceptIndex[bx * numSpeculativeTokens] = lastAcceptedLocalIdx;
298+
uint32_t numAcceptedTokens = 0;
299+
int32_t curIndex = 0;
300+
301+
// Root token: target prediction at root position
302+
acceptToken[bx * numSpeculativeTokens] = targetPredict[batchOffset + lastAcceptedLocalIdx];
303+
304+
for (uint32_t j = 1; j < numSpeculativeTokens; ++j)
305+
{
306+
curIndex = retrieveNextToken[batchOffset + curIndex];
307+
308+
while (curIndex != -1)
309+
{
310+
int32_t draftLocalIdx = retrieveIndex[batchOffset + curIndex];
311+
int64_t draftTokenId = candidates[batchOffset + curIndex];
312+
int64_t targetTokenId = targetPredict[batchOffset + lastAcceptedLocalIdx];
313+
314+
if (draftTokenId == targetTokenId)
315+
{
316+
predicts[batchOffset + lastAcceptedLocalIdx] = targetTokenId;
317+
++numAcceptedTokens;
318+
acceptIndex[bx * numSpeculativeTokens + numAcceptedTokens] = draftLocalIdx;
319+
// Accepted token: target prediction at accepted draft position
320+
acceptToken[bx * numSpeculativeTokens + numAcceptedTokens] = targetPredict[batchOffset + draftLocalIdx];
321+
lastAcceptedLocalIdx = draftLocalIdx;
322+
break;
323+
}
324+
else
325+
{
326+
curIndex = retrieveNextSibling[batchOffset + curIndex];
327+
}
328+
}
329+
330+
if (curIndex == -1)
331+
break;
332+
}
333+
334+
acceptTokenNum[bx] = numAcceptedTokens;
335+
// Bonus token from target model at the last accepted position
336+
predicts[batchOffset + lastAcceptedLocalIdx] = targetPredict[batchOffset + lastAcceptedLocalIdx];
337+
}
338+
339+
void invokeVerifyDynamicTreeGreedy(int64_t* predicts, int64_t* acceptIndex, int64_t* acceptTokenNum,
340+
int64_t* acceptToken, int64_t const* candidates, int32_t const* retrieveIndex, int32_t const* retrieveNextToken,
341+
int32_t const* retrieveNextSibling, int64_t const* targetPredict, bool const* treeValid, SizeType32 batchSize,
342+
SizeType32 numDraftTokens, SizeType32 numSpecStep, cudaStream_t stream)
343+
{
344+
dim3 grid(batchSize);
345+
dim3 block(1);
346+
347+
verifyDynamicTreeGreedyKernel<<<grid, block, 0, stream>>>(predicts, acceptIndex, acceptTokenNum, acceptToken,
348+
candidates, retrieveIndex, retrieveNextToken, retrieveNextSibling, targetPredict, treeValid, batchSize,
349+
numSpecStep, numDraftTokens);
350+
351+
sync_check_cuda_error(stream);
352+
}
353+
354+
} // namespace kernels::speculative_decoding
355+
356+
TRTLLM_NAMESPACE_END

0 commit comments

Comments
 (0)