[TRTLLM-11540][feat] Add EAGLE3 dynamic tree speculative decoding support #12062
sunnyqgg merged 11 commits into NVIDIA:main from
Conversation
/bot run
PR_Github #38374 [ run ] triggered by Bot. Commit:
📝 Walkthrough
This pull request introduces dynamic tree-based speculative decoding for EAGLE3 inference with CUDA-accelerated kernels. It adds tree construction and greedy verification kernels, integrates them with Python/PyTorch layers, implements dynamic tree sampling and acceptance logic, and extends the Eagle3 resource manager and worker infrastructure to support dynamic tree mode alongside static tree mode.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Python as Python Layer
    participant Sampler as TorchSampler
    participant DTree as DynamicTreeOpsConverter
    participant Kernel as CUDA Kernel
    participant Buf as GPU Buffers
    Python->>Sampler: update() with requests
    Sampler->>Sampler: _batch_verify_dynamic_tree(requests, tokens)
    Sampler->>DTree: verify_dynamic_tree_greedy(draft_tokens, logits, tree_buffers)
    DTree->>DTree: compute target predictions from logits
    DTree->>Kernel: invoke verify_dynamic_tree_greedy_op()
    Kernel->>Buf: read: candidates, retrieve_index, retrieve_next_sibling, targetPredict
    Kernel->>Buf: greedy tree traversal per batch
    Kernel->>Buf: write: predicts, acceptIndex, acceptTokenNum
    Buf-->>Kernel: results
    Kernel-->>DTree: returns VerifyTreeResults
    DTree-->>Sampler: per-slot (num_accepted_tokens, accept_index)
    Sampler->>Sampler: _process_draft_tokens_dynamic_tree() per request
    Sampler-->>Python: updated tokens and finish reasons
```
```mermaid
sequenceDiagram
    participant Worker as Eagle3DynamicTreeWorker
    participant DraftModel as Draft Model
    participant DTree as DynamicTreeOpsConverter
    participant Kernel as CUDA Kernel
    participant Cache as KV Cache
    Worker->>Worker: _forward_draft_loop(initial context)
    Worker->>DraftModel: forward(input_ids, position_ids)
    DraftModel-->>Worker: logits
    Worker->>Worker: sample_dynamic(logits, topk)
    Worker->>Worker: dt_update_draft_tokens_and_scores()
    Worker->>DTree: build_dynamic_tree(parent_list, topk_indices, tree_buffers)
    DTree->>Kernel: invoke build_dynamic_tree_op()
    Kernel->>Kernel: construct left-child/right-sibling tree
    Kernel->>Kernel: compute per-node attention masks (treeMask)
    Kernel->>Kernel: compute absolute positions
    Kernel-->>DTree: DynamicTreeBuffers (tree_mask, positions, retrieve_index, etc.)
    DTree-->>Worker: tree structure ready
    Worker->>Worker: dt_prepare_tree_mask_and_position_offset()
    Worker->>DraftModel: forward(growing context with tree topology)
    DraftModel->>Cache: update KV cache with tree positions
    DraftModel-->>Worker: logits per tree node
    Worker->>Worker: _sample_and_accept_dynamic_tree(logits)
```
```mermaid
sequenceDiagram
    participant App as Application
    participant Executor as PyExecutor
    participant ResourceMgr as Eagle3OneModelDynamicTreeResourceManager
    participant Worker as Eagle3OneModelDynamicTreeWorker
    participant Sampler as Eagle3OneModelDynamicTreeSampler
    App->>Executor: initialize with EagleDecodingConfig(use_dynamic_tree=True)
    Executor->>ResourceMgr: create with SpecTreeManager(use_dynamic_tree=True)
    Executor->>Worker: initialize with spec_config
    Executor->>Sampler: initialize with spec_config
    App->>Executor: generate(requests)
    Executor->>Worker: forward(context & draft loop)
    Worker->>Worker: _forward_dynamic_tree_draft_loop()
    Worker-->>Executor: draft_tokens, dynamic_tree_buffers, accepted_draft_indices
    Executor->>Sampler: sample_and_accept_draft_tokens(logits, buffers)
    Sampler->>Sampler: verify with dynamic tree buffers
    Sampler-->>Executor: accepted tokens, accepted indices
    Executor->>ResourceMgr: get_needed_resource_to_completion(request)
    ResourceMgr-->>Executor: resource estimates
    Executor->>Worker: prepare_1st_drafter_inputs(with tree topology targets)
    Executor-->>App: generated tokens
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks: ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 17
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)
3346-3386: ⚠️ Potential issue | 🔴 Critical: Zero-draft dynamic-tree requests currently fall into the static-tree verifier.

`_batch_verify_dynamic_tree()` explicitly skips requests whose draft length is 0, but the fallback here still calls `process_draft_tokens()`. With a non-null `spec_tree_manager`, that dispatches to `_process_draft_tokens_tree()`, which assumes a populated draft tree. If a dynamic-tree request reaches this branch with no drafts, it will fail instead of just emitting the verified token.

🛠️ One possible guard:

```diff
         if req.py_seq_slot in dynamic_tree_results:
             num_accepted = self._process_draft_tokens_dynamic_tree(
                 req, new_tokens_list, finish_reasons,
                 dynamic_tree_results[req.py_seq_slot]
             )
-
+        elif spec_tree_manager is not None and spec_tree_manager.use_dynamic_tree:
+            num_accepted = self._process_draft_tokens_greedy(
+                req, new_tokens=new_tokens_list, finish_reasons=finish_reasons
+            )
         else:
             num_accepted = self.process_draft_tokens(
                 req, new_tokens_tensor=new_tokens, new_tokens_list=new_tokens_list,
```
🧹 Nitpick comments (5)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)
435-437: Keep the wrapper imports namespaced here too.

Since this is new dispatch code, please import the `drafting_loops` module and reference these wrappers from that module rather than importing the classes directly. As per coding guidelines: when importing in Python, always maintain the namespace. Import the module, not individual classes or functions.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tensorrt_llm/_torch/pyexecutor/py_executor_creator.py` around lines 435-437, replace the direct class imports DynamicTreeDraftingLoopWrapper, LinearDraftingLoopWrapper, StaticTreeDraftingLoopWrapper with a namespaced module import for drafting_loops and update all references to use drafting_loops.DynamicTreeDraftingLoopWrapper, drafting_loops.LinearDraftingLoopWrapper, and drafting_loops.StaticTreeDraftingLoopWrapper (e.g., where these classes are used in the dispatch/registration code inside py_executor_creator.py) so the module is imported, not individual classes.
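A minimal sketch of the namespaced-import pattern the guideline asks for (the `pick_wrapper` helper and its boolean flags are illustrative, not the actual dispatch code):

```python
from tensorrt_llm._torch.speculative import drafting_loops

# Resolve the wrappers through the module namespace instead of importing
# the classes directly; hypothetical dispatch helper for illustration only.
def pick_wrapper(static_tree_drafter: bool, dynamic_tree_drafter: bool):
    if dynamic_tree_drafter:
        return drafting_loops.DynamicTreeDraftingLoopWrapper
    if static_tree_drafter:
        return drafting_loops.StaticTreeDraftingLoopWrapper
    return drafting_loops.LinearDraftingLoopWrapper
```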
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (1)

8-9: Keep the drafting-loop import namespaced.

Please import the `drafting_loops` module and resolve `StaticTreeDraftingLoopWrapper` through that namespace here instead of importing the class directly. As per coding guidelines: when importing in Python, always maintain the namespace. Import the module, not individual classes or functions.

Also applies to: 56-56

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py` around lines 8-9, replace the direct class import with a module-level import for the drafting_loops module and qualify the class through that namespace: change the current "from tensorrt_llm._torch.speculative.drafting_loops import StaticTreeDraftingLoopWrapper" to "import tensorrt_llm._torch.speculative.drafting_loops as drafting_loops" (or "from tensorrt_llm._torch.speculative import drafting_loops") and update all usages to drafting_loops.StaticTreeDraftingLoopWrapper (also fix the similar import/usage at the other occurrence noted).

tensorrt_llm/_torch/attention_backend/interface.py (1)
371-381: Making trailing parameters keyword-only would improve API safety, but it is not required.

The code is already safe: all call sites either use keyword arguments (`model_engine.py:3552`, test cases) or correctly pass all 9 positional arguments in order (`sparse/dsa.py:541`). No stale callers with 7-8 positional arguments (which would silently misbind after inserting `is_target_model`) exist in the codebase.

If keyword-only enforcement is desired for defensiveness, add `*` before `is_target_model`:

```diff
 def update_spec_dec_param(
         self,
         batch_size,
         is_spec_decoding_enabled,
         is_spec_dec_tree,
         is_spec_dec_dynamic_tree,
         max_draft_len,
         max_total_draft_tokens,
+        *,
         is_target_model: bool = True,
         model_is_wrapped: bool = False,
         spec_tree_manager: Optional['SpecTreeManager'] = None):
```

This prevents future positional misuse and makes the intent explicit, but the current codebase is already compliant.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tensorrt_llm/_torch/attention_backend/interface.py` around lines 371-381, the function update_spec_dec_param currently accepts many trailing boolean/optional parameters positionally; make is_target_model, model_is_wrapped, and spec_tree_manager keyword-only by inserting a bare * before is_target_model in the signature so callers cannot accidentally bind those flags positionally. Update the signature in the update_spec_dec_param definition and adjust any internal references accordingly (no other logic changes).

tensorrt_llm/_torch/attention_backend/trtllm.py (1)
502-509: Consider clarifying the reshape assumption.

The reshape from 1D `[max_num_requests * N]` to 2D `[max_num_requests, N]` assumes the 1D tensor was allocated with exactly `max_num_requests * (max_total_draft_tokens + 1)` elements. This is correct based on lines 1463-1465, but the implicit coupling between allocation and reshape could be fragile.

Consider adding an assertion or comment to make this contract explicit:

📝 Suggestion for defensive check

```diff
 # For dynamic tree, reshape 1D position_offsets to 2D for C++ kernel compatibility
 position_offsets_for_cpp = self.spec_decoding_position_offsets
 if (self.spec_decoding_position_offsets is not None
         and self.spec_decoding_position_offsets.dim() == 1):
     # Reshape 1D [max_num_requests * N] to 2D [max_num_requests, N]
     # C++ kernel requires 2D to extract max_generation_length from sizes()[1]
+    assert self.spec_decoding_position_offsets.numel() % self.max_num_requests == 0, \
+        "1D position_offsets size must be divisible by max_num_requests"
     position_offsets_for_cpp = self.spec_decoding_position_offsets.view(
         self.max_num_requests, -1)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tensorrt_llm/_torch/attention_backend/trtllm.py` around lines 502-509, the reshape from 1D to 2D (position_offsets_for_cpp based on self.spec_decoding_position_offsets) assumes the 1D tensor length equals max_num_requests * N; add a defensive check before the view to assert that self.spec_decoding_position_offsets.numel() is divisible by self.max_num_requests and (optionally) equals self.max_num_requests * (self.max_total_draft_tokens + 1) (or raise a clear error mentioning spec_decoding_position_offsets and max_num_requests) so the implicit allocation/reshape contract in trtllm.py is explicit and fails fast when violated.

tensorrt_llm/_torch/pyexecutor/sampler.py (1)
2626-2668: Materialize `accept_index` before the per-request Python loop.

Line 2649's `accept_index[j].item()` performs a device read for every accepted token in a hot request loop. Convert the accepted indices once, then iterate over a Python list here, matching the existing `new_tokens.tolist()` pattern. Based on learnings: in files under `tensorrt_llm/_torch/pyexecutor`, avoid accessing `torch.Tensor` objects inside for-loops when iterating over requests. Convert batched tensors to Python lists beforehand using `tensor.tolist()`, and then iterate over those lists.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tensorrt_llm/_torch/pyexecutor/sampler.py` around lines 2626-2668, the loop in _process_draft_tokens_dynamic_tree repeatedly calls accept_index[j].item(), causing device reads per iteration; materialize accept_index to a Python list once before the request loop (e.g. accept_indices = accept_index.tolist() or accept_index.cpu().tolist() and cast elements to int), then iterate over accept_indices for add_token and finish_if_reason calls, and compute request.py_num_accepted_draft_tokens_indices from that list by subtracting 1 for positions after the root; keep using the same symbols (accept_index -> accept_indices, _process_draft_tokens_dynamic_tree, add_token, finish_if_reason, request.py_num_accepted_draft_tokens_indices) so you only replace tensor indexing with list indexing.
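A small sketch of the suggested conversion, assuming `accept_index` holds the accepted tree positions for one request (the helper shape is illustrative, not the actual sampler code):

```python
import torch

def split_accepted(accept_index: torch.Tensor) -> tuple[list[int], list[int]]:
    # One device->host transfer up front instead of one .item() per token.
    accept_indices = [int(i) for i in accept_index.tolist()]
    # Position 0 is the verified root token; later positions map to draft
    # token indices shifted down by one, as the review comment describes.
    draft_token_indices = [i - 1 for i in accept_indices[1:]]
    return accept_indices, draft_token_indices

# Example: root plus two accepted tree nodes at positions 3 and 7.
accepted, draft_idx = split_accepted(torch.tensor([0, 3, 7]))
assert draft_idx == [2, 6]
```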
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu`:
- Around line 112-182: The kernel reuses preallocated topology buffers but
doesn’t clear stale data; before building the tree (inside the tid==0 branch of
the dynamic tree builder) explicitly reinitialize retrieveNextToken and
retrieveNextSibling entries for this batch (bid) to -1 for all draftTokenNum
slots, and clear all words of treeMask for this batch (not just word 0); also
ensure positions/retrieveIndex for all slots are set to sane defaults if needed.
Locate the tid==0 block that sets positions[bid * draftTokenNum] and the loop
that writes retrieveIndex/retrieveNextToken/retrieveNextSibling and add the
resets there (and mirror the same full-reset logic in the other build region
referenced around lines 245-317).
- Around line 191-214: The ancestor-walk loop can run past bounds when a parent
lookup misses: after the for-loop that searches selectedIndex for tokenIdx
(using curPosition and draftTokenNum) add a guard to detect "not found"
(curPosition == draftTokenNum) and break the while loop to avoid reading/writing
past selectedIndex/treeMask; apply the same defensive check to the equivalent
ancestor-walk logic around the other block referenced (uses the same symbols:
treeMask, tokenTreeIdx, curPosition, selectedIndex, draftTokenNum, parentList,
parentTbIdx, bid, topK) so both paths stop if the parent was not resampled.
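The guard described above, sketched in Python for readability (the actual code is a CUDA kernel; the data layout here is simplified and assumed):

```python
def mark_ancestor_path(tree_mask: list[bool], selected_index: list[int],
                       parent_token_of: list[int], token_tree_idx: int,
                       draft_token_num: int) -> None:
    """Walk toward the root, stopping if a parent was never resampled."""
    cur = token_tree_idx
    while cur != 0:  # slot 0 plays the role of the root
        parent_token = parent_token_of[cur]
        # Linear search over the resampled tokens, as in the kernel's for-loop.
        cur_position = draft_token_num
        for i in range(draft_token_num):
            if selected_index[i] == parent_token:
                cur_position = i
                break
        if cur_position == draft_token_num:
            break  # parent not found: stop instead of indexing past the buffers
        tree_mask[cur_position] = True
        cur = cur_position
```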
In `@cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h`:
- Line 17: The header currently uses `#pragma once` but must follow the repo guard
convention; replace the pragma with a preprocessor include guard named
TRTLLM_DYNAMICTREEKERNELS_H (matching the filename dynamicTreeKernels.h in ALL
CAPS) by adding `#ifndef` TRTLLM_DYNAMICTREEKERNELS_H / `#define`
TRTLLM_DYNAMICTREEKERNELS_H at the top and a matching `#endif` at the bottom,
ensuring no directory names or trailing underscores are used and keeping the
rest of the file (dynamicTreeKernels.h) unchanged.
In `@cpp/tensorrt_llm/thop/dynamicTreeOp.cpp`:
- Around line 34-67: build_dynamic_tree_op (and its sibling
verify_dynamic_tree_greedy_op) currently access raw data pointers and call
at::cuda::getCurrentCUDAStream() without validating devices, dtypes, shapes or
the treeMaskMode enum; add TORCH_CHECKs to ensure all input/output tensors
(parentList, selectedIndex, treeMask, positions, retrieveIndex,
retrieveNextToken, retrieveNextSibling, verifiedSeqLen) are CUDA tensors
(is_cuda()), are on the same device (device.index() equality), and have the
expected scalar types (parentList/selectedIndex int64,
positions/retrieveIndex/retrieveNextToken/retrieveNextSibling/verifiedSeqLen
int32 as used by data_ptr<int32_t/int64_t>()), verify output shapes (batchSize,
numDraftTokens-1, etc.) before zero_/fill_, and check treeMaskMode is within the
valid tk::TreeMaskMode range before static_cast; perform these checks at the
start of build_dynamic_tree_op and verify_dynamic_tree_greedy_op so that
tk::invokeBuildDynamicTree and related kernel calls only receive validated
tensors and enum values.
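For illustration, the same defensive checks sketched at the Python wrapper level (the prompt targets the C++ op, where TORCH_CHECK plays the role of these asserts; the function and parameter names below are assumptions):

```python
import torch

def check_build_dynamic_tree_inputs(parent_list: torch.Tensor,
                                    selected_index: torch.Tensor,
                                    tree_mask_mode: int,
                                    num_tree_mask_modes: int) -> None:
    """Validate device, dtype, and enum range before the kernel touches raw pointers."""
    for name, t in (("parentList", parent_list), ("selectedIndex", selected_index)):
        assert t.is_cuda, f"{name} must be a CUDA tensor"
        assert t.dtype == torch.int64, f"{name} must be int64"
    assert parent_list.device == selected_index.device, \
        "all tensors must live on the same CUDA device"
    assert 0 <= tree_mask_mode < num_tree_mask_modes, \
        "treeMaskMode out of range for the TreeMaskMode enum"
```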
In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py`:
- Line 540: The forward reference 'SpecTreeManager' used in the signature
(spec_tree_manager: Optional['SpecTreeManager']) is not imported under
TYPE_CHECKING; add "from tensorrt_llm._torch.speculative.spec_tree_manager
import SpecTreeManager" to the existing TYPE_CHECKING import block (alongside
any existing imports such as DecodingBaseConfig) so the name is resolved for
type checking and Ruff F821 is fixed.
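A sketch of the TYPE_CHECKING import the prompt calls for (the signature is trimmed to the one relevant parameter):

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated only by static type checkers: no runtime import cycle, and the
    # 'SpecTreeManager' forward reference resolves, fixing Ruff F821.
    from tensorrt_llm._torch.speculative.spec_tree_manager import SpecTreeManager

def update_spec_dec_param(
        spec_tree_manager: Optional["SpecTreeManager"] = None) -> None:
    ...
```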
In `@tensorrt_llm/_torch/pyexecutor/py_executor_creator.py`:
- Around line 444-461: The leftover variable use_tree_drafter is assigned
earlier but never used after splitting into static_tree_drafter and
dynamic_tree_drafter, causing a linter F841; remove the unused use_tree_drafter
assignment (or fold its logic into the existing predicates) so only
static_tree_drafter and dynamic_tree_drafter derived from draft_spec_config
(EagleDecodingConfig) remain, leaving the branching that returns
StaticTreeDraftingLoopWrapper and DynamicTreeDraftingLoopWrapper unchanged
(references: use_tree_drafter, static_tree_drafter, dynamic_tree_drafter,
draft_spec_config, spec_config, StaticTreeDraftingLoopWrapper,
DynamicTreeDraftingLoopWrapper).
In `@tensorrt_llm/_torch/speculative/drafting_loops.py`:
- Around line 709-718: The code unconditionally overwrites return_draft_logits
with zeros losing real collected logits; change the logic in drafting_loops.py
so that you only allocate the zero tensor as a fallback when return_draft_logits
is missing or has an incompatible shape (e.g., check if return_draft_logits is
None or return_draft_logits.shape != (self.max_total_draft_tokens, batch_size,
vocab_size)); otherwise preserve the existing return_draft_logits from the last
draft layer; when allocating the fallback ensure dtype/device match
(torch.float32 and 'cuda') and add a brief comment referencing
tokens_accumulated to indicate this is a temporary fallback until per-layer
gathering is implemented.
- Around line 1164-1180: The attn_metadata.use_spec_decoding flag is left False
so subsequent drafter forwards ignore the dynamic-tree metadata; set
attn_metadata.use_spec_decoding = True at the end of this preparation block
(after updating kv_lens_cuda, _seq_lens, host_request_types and before leaving
the dynamic-tree growth steps) so the next draft pass uses speculative decoding;
locate the block updating attn_metadata.kv_lens_cuda, attn_metadata._seq_lens,
attn_metadata.host_request_types and set attn_metadata.use_spec_decoding = True
there (ensure this happens before
spec_metadata.eagle3_resource_manager.is_first_draft is toggled).
- Around line 961-975: spec_decoding_position_offsets is being treated as a flat
vector but it’s stored as a 2-D buffer ([max_num_requests,
max_total_draft_tokens+1]); update the code to slice and assign it as 2-D so
rows correspond to requests: read previous_position_offsets =
attn_metadata.spec_decoding_position_offsets[:batch_size,
:num_tokens_previous_layer], build new_position_offsets by concatenating along
dim=1 (using previous_position_offsets and previous_position_offsets[:,
-self.dynamic_tree_max_topK:]+1), then write it back to
attn_metadata.spec_decoding_position_offsets[:batch_size,
:num_tokens_current_layer] (no flattening/view needed) so the correct request
rows are updated.
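A toy-sized sketch of the 2-D slicing that the last comment above (lines 961-975) describes; buffer sizes are illustrative and names follow the comment:

```python
import torch

max_num_requests, tokens_per_req = 8, 5   # [max_num_requests, max_total_draft_tokens + 1]
batch_size, dynamic_tree_max_topK = 3, 2
spec_decoding_position_offsets = torch.zeros(
    max_num_requests, tokens_per_req, dtype=torch.int32)

num_tokens_previous_layer = 3
num_tokens_current_layer = num_tokens_previous_layer + dynamic_tree_max_topK

# Slice rows per request instead of flattening the 2-D buffer.
prev = spec_decoding_position_offsets[:batch_size, :num_tokens_previous_layer]
# New nodes sit one tree level below the last topK nodes of the previous layer.
new = torch.cat([prev, prev[:, -dynamic_tree_max_topK:] + 1], dim=1)
spec_decoding_position_offsets[:batch_size, :num_tokens_current_layer] = new
```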
In `@tensorrt_llm/_torch/speculative/dynamic_tree_ops.py`:
- Around line 1-12: Add the standard NVIDIA Apache-2.0 license header (with the
latest modification year) at the top of the file before the existing module
docstring in dynamic_tree_ops.py; replace the current file-starting
docstring-only content by prepending the required NVIDIA copyright/license block
so the file begins with the Apache 2.0 header followed by the existing "Dynamic
Tree Operations for EAGLE3 Speculative Decoding" docstring.
In `@tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py`:
- Around line 98-143: The buffers currently fix max_batch_size = 256 which can
overflow for larger deployments; replace the hard-coded max_batch_size with a
runtime value derived from spec_config or the worker's actual max concurrent
sequences (e.g., spec_config.max_batch_size or a passed-in parameter) and use
that variable when allocating dt_draft_tokens_buffer, dt_position_ids_buffer,
history_draft_tokens_buffer, history_score_buffer,
history_draft_tokens_parent_buffer, tree_mask_buffer, tree_mask_init_buffer, and
tree_mask_padding_zeros, and when calling create_dynamic_tree_ops_converter
(preserve device and dtypes). Ensure the new max_batch_size value is validated
(positive int) and that any flattened-size calculations (like tree_mask_buffer
shape) are updated to use the dynamic max_batch_size to avoid out-of-bounds
writes.
- Around line 475-487: The code incorrectly reshapes
spec_decoding_position_offsets into a flat (max_reqs, tokens_per_req) layout
which conflicts with the request-major layout used elsewhere; replace the manual
flattening with a request-major view and index into the existing first
dimension. Concretely, stop computing max_reqs = total_po_size // tokens_per_req
and using pos_2d = attn_metadata.spec_decoding_position_offsets.view(max_reqs,
tokens_per_req); instead treat pos_2d as request-major (e.g., pos_2d =
attn_metadata.spec_decoding_position_offsets.view(-1, tokens_per_req) or simply
use the existing first-dimension shape) and then write pos_2d[req_idx, :n] =
causal_offs[:n] ensuring req_idx is computed as num_contexts + g_idx and within
bounds. Apply the same fix to the other occurrence around lines 1090-1101 to
preserve the [max_num_requests, max_total_draft_tokens + 1] layout everywhere.
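Referring back to the first comment on this file (lines 98-143), a rough sketch of sizing the buffers from configuration rather than a hard-coded 256; attribute names and buffer shapes are assumptions for illustration:

```python
import torch

def allocate_dynamic_tree_buffers(max_batch_size: int,
                                  max_total_draft_tokens: int,
                                  device: str = "cuda"):
    assert isinstance(max_batch_size, int) and max_batch_size > 0, \
        "max_batch_size must be a positive int"
    # Two representative buffers from the list above; any flattened sizes
    # (like the tree mask) must also scale with the runtime max_batch_size.
    dt_draft_tokens_buffer = torch.zeros(
        max_batch_size, max_total_draft_tokens, dtype=torch.long, device=device)
    tree_mask_buffer = torch.zeros(
        max_batch_size * max_total_draft_tokens * max_total_draft_tokens,
        dtype=torch.bool, device=device)
    return dt_draft_tokens_buffer, tree_mask_buffer
```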
In `@tensorrt_llm/_torch/speculative/model_drafter.py`:
- Around line 686-690: Remove the two unused CPU buffer assignments to avoid
dead code: delete the assignments to topk_score_indices and
history_draft_tokens_parent_buffer that read from
dynamic_tree_buffers["topk_score_indices"].cpu() and
dynamic_tree_buffers["history_draft_tokens_parent_buffer"].cpu(). If those
buffers are intended for future use, replace each assignment with a short TODO
comment referencing the buffer name (topk_score_indices,
history_draft_tokens_parent_buffer) and why it will be needed; otherwise simply
remove the two lines. Ensure no other code in the same method depends on these
variables after removal.
- Around line 1004-1005: The prepare_draft_tokens method in ModelDrafter
currently requires resource_manager but the base class Drafter defines it as
optional; change the signature of ModelDrafter.prepare_draft_tokens to accept
resource_manager: Optional[ResourceManager] = None so it matches the base
contract, add or ensure Optional is imported from typing if missing, and mirror
the pattern used in ngram.py; update any internal usage of resource_manager to
handle None safely.
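A sketch of the signature alignment (stub classes stand in for the real Drafter and ResourceManager types):

```python
from typing import Any, Optional

class ResourceManager:  # stand-in for the real type
    pass

class Drafter:  # stand-in base class with the optional parameter
    def prepare_draft_tokens(self, scheduled_requests: Any,
                             resource_manager: Optional["ResourceManager"] = None) -> None:
        raise NotImplementedError

class ModelDrafter(Drafter):
    def prepare_draft_tokens(self, scheduled_requests: Any,
                             resource_manager: Optional["ResourceManager"] = None) -> None:
        # Match the optional base-class contract and tolerate None,
        # mirroring the pattern the review attributes to ngram.py.
        if resource_manager is None:
            return
```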
In `@tensorrt_llm/llmapi/llm_args.py`:
- Around line 1865-1876: The file contains a duplicate TypeAlias named
SpeculativeConfig that shadows the earlier discriminated union (the one defined
with Field(discriminator="decoding_type")), which removes SADecodingConfig and
PARDDecodingConfig; remove the second SpeculativeConfig definition (or rename it
if you truly need a separate non-discriminated alias) and keep the original
annotated union (including SADecodingConfig and PARDDecodingConfig alongside
DraftTargetDecodingConfig, Eagle3DecodingConfig, EagleDecodingConfig,
LookaheadDecodingConfig, MedusaDecodingConfig, MTPDecodingConfig,
NGramDecodingConfig, UserProvidedDecodingConfig, SaveHiddenStatesDecodingConfig,
AutoDecodingConfig) so the discriminator-based Pydantic union remains intact.
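A condensed sketch of the discriminator-based union the comment wants preserved (only two member configs shown; their fields are illustrative):

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import TypeAlias

class EagleDecodingConfig(BaseModel):
    decoding_type: Literal["Eagle"] = "Eagle"
    use_dynamic_tree: bool = False

class NGramDecodingConfig(BaseModel):
    decoding_type: Literal["NGram"] = "NGram"

# A second plain `SpeculativeConfig: TypeAlias = ...` later in the file would
# shadow this union and silently drop members from Pydantic validation.
SpeculativeConfig: TypeAlias = Annotated[
    Union[EagleDecodingConfig, NGramDecodingConfig],
    Field(discriminator="decoding_type"),
]
```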
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 388f70bd-848c-4020-99de-5838cd97e5b3
📒 Files selected for processing (23)
- cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu
- cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h
- cpp/tensorrt_llm/thop/CMakeLists.txt
- cpp/tensorrt_llm/thop/dynamicTreeOp.cpp
- examples/llm-api/quickstart_advanced.py
- tensorrt_llm/_torch/attention_backend/interface.py
- tensorrt_llm/_torch/attention_backend/sparse/dsa.py
- tensorrt_llm/_torch/attention_backend/trtllm.py
- tensorrt_llm/_torch/models/modeling_speculative.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tensorrt_llm/_torch/pyexecutor/py_executor.py
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
- tensorrt_llm/_torch/pyexecutor/sampler.py
- tensorrt_llm/_torch/speculative/drafting_loops.py
- tensorrt_llm/_torch/speculative/dynamic_tree_ops.py
- tensorrt_llm/_torch/speculative/eagle3.py
- tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py
- tensorrt_llm/_torch/speculative/model_drafter.py
- tensorrt_llm/_torch/speculative/spec_tree_manager.py
- tensorrt_llm/_torch/speculative/utils.py
- tensorrt_llm/llmapi/llm_args.py
- tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
- tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py
PR_Github #38374 [ run ] completed with state
99b25a9 to c2c8ef6
/bot run
c2c8ef6 to abd7543
/bot run
PR_Github #38968 [ run ] triggered by Bot. Commit:
PR_Github #38968 [ run ] completed with state
/bot run
1 similar comment
/bot run
PR_Github #39082 [ run ] triggered by Bot. Commit:
PR_Github #39082 [ run ] completed with state
/bot run
PR_Github #39197 [ run ] triggered by Bot. Commit:
PR_Github #42211 [ run ] triggered by Bot. Commit:
PR_Github #42211 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #42279 [ run ] triggered by Bot. Commit:
PR_Github #42279 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #42358 [ run ] triggered by Bot. Commit:
/bot run --disable-fail-fast
PR_Github #42507 [ run ] triggered by Bot. Commit:
PR_Github #42507 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #42597 [ run ] triggered by Bot. Commit:
PR_Github #42597 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #42679 [ run ] triggered by Bot. Commit:
…update_spec_dec_param The parent TrtllmAttentionMetadata.update_spec_dec_param() no longer accepts spec_decoding_tensor, but the DSA override still passed it through to super(), causing a TypeError on all B200 8-GPU speculative decoding tests (TestGLM5FP8, TestDeepSeekV32 DSA/NVFP4). Signed-off-by: qgai <qgai@nvidia.com>
/bot run --disable-fail-fast
PR_Github #42698 [ run ] triggered by Bot. Commit:
PR_Github #42679 [ run ] completed with state
…n feature matrices Signed-off-by: qgai <qgai@nvidia.com>
/bot run --disable-fail-fast
PR_Github #42762 [ run ] triggered by Bot. Commit:
PR_Github #42762 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #42803 [ run ] triggered by Bot. Commit:
PR_Github #42803 [ run ] completed with state
Already changed based on his review
/bot run
PR_Github #42869 [ ] completed with state
Summary
- `Eagle3OneModelDynamicTreeWorker` and `Eagle3OneModelDynamicTreeSampler` for one-model dynamic tree inference

Changes

- tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py — One-model dynamic tree worker and sampler
- tensorrt_llm/_torch/speculative/dynamic_tree_ops.py — Python wrappers for dynamic tree CUDA ops
- cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu/.h — CUDA kernels
- cpp/tensorrt_llm/thop/dynamicTreeOp.cpp — Torch custom op bindings
- tensorrt_llm/_torch/speculative/eagle3.py — Refactored Eagle3OneModelWorker with dispatch pattern for linear vs dynamic tree
- tensorrt_llm/_torch/speculative/utils.py — Route to dynamic tree components when use_dynamic_tree=True
- tensorrt_llm/_torch/speculative/drafting_loops.py — Two-model dynamic tree drafting loop
- tensorrt_llm/_torch/speculative/model_drafter.py — Dynamic tree spec tree manager integration
- tensorrt_llm/_torch/speculative/spec_tree_manager.py — Support dynamic tree token organization
- tensorrt_llm/_torch/pyexecutor/model_engine.py — Dynamic tree detection for target model
- tensorrt_llm/_torch/pyexecutor/sampler.py — Dynamic tree batch verification
- tensorrt_llm/_torch/models/modeling_speculative.py — Hidden states handling for dynamic tree
- tensorrt_llm/llmapi/llm_args.py — Configuration validation for dynamic tree

Test plan
- Unit tests (tests/unittest/_torch/speculative/)

Summary by CodeRabbit

New Features

Configuration

- `--max_total_draft_tokens` parameter for controlling total draft token budget
- `--streaming` flag for real-time token streaming output
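As a rough end-to-end sketch, enabling the feature goes through the decoding config; parameter values are illustrative and the exact EagleDecodingConfig fields should be checked against tensorrt_llm/llmapi/llm_args.py:

```python
from tensorrt_llm.llmapi import EagleDecodingConfig

# use_dynamic_tree=True routes executor setup to the dynamic-tree worker,
# sampler, and resource manager added by this PR (see utils.py above).
spec_config = EagleDecodingConfig(
    max_draft_len=4,                      # illustrative values
    max_total_draft_tokens=32,
    use_dynamic_tree=True,
    speculative_model_dir="<path-to-eagle3-draft-model>",
)
```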