
[TRTLLM-11540][feat] Add EAGLE3 dynamic tree speculative decoding support#12062

Merged
sunnyqgg merged 11 commits into NVIDIA:main from sunnyqgg:add_dyanmic_tree_support_one_model
Apr 12, 2026

Conversation

@sunnyqgg
Collaborator

@sunnyqgg sunnyqgg commented Mar 10, 2026

Summary

  • Add dynamic tree speculative decoding support for EAGLE3 (both one-model and two-model flows)
  • Implement Eagle3OneModelDynamicTreeWorker and Eagle3OneModelDynamicTreeSampler for one-model dynamic tree inference
  • Add CUDA kernels for dynamic tree operations (expand, gather, update)
  • Support growing context in dynamic tree mode for improved accept rates
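
To make the new configuration surface concrete, here is a hypothetical sketch of enabling the feature; only use_dynamic_tree, dynamic_tree_max_topK, max_total_draft_tokens, and eagle_choices are named in this PR, and the import path and the speculative_model_dir field are assumptions:

```python
# Illustrative config fragment only; fields other than use_dynamic_tree,
# dynamic_tree_max_topK, max_total_draft_tokens, and eagle_choices are assumptions.
from tensorrt_llm.llmapi import EagleDecodingConfig

spec_config = EagleDecodingConfig(
    speculative_model_dir="/path/to/eagle3-draft-model",  # assumed field name
    use_dynamic_tree=True,       # routes to the dynamic-tree worker/sampler
    dynamic_tree_max_topK=10,    # per-layer expansion width (validated by this PR)
    max_total_draft_tokens=32,   # total draft-token budget per step
    eagle_choices=None,          # must be None in dynamic-tree mode
)
```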

Changes

  • New: tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py — One-model dynamic tree worker and sampler
  • New: tensorrt_llm/_torch/speculative/dynamic_tree_ops.py — Python wrappers for dynamic tree CUDA ops
  • New: cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu/.h — CUDA kernels
  • New: cpp/tensorrt_llm/thop/dynamicTreeOp.cpp — Torch custom op bindings
  • Modified: tensorrt_llm/_torch/speculative/eagle3.py — Refactored Eagle3OneModelWorker with dispatch pattern for linear vs dynamic tree
  • Modified: tensorrt_llm/_torch/speculative/utils.py — Route to dynamic tree components when use_dynamic_tree=True
  • Modified: tensorrt_llm/_torch/speculative/drafting_loops.py — Two-model dynamic tree drafting loop
  • Modified: tensorrt_llm/_torch/speculative/model_drafter.py — Dynamic tree spec tree manager integration
  • Modified: tensorrt_llm/_torch/speculative/spec_tree_manager.py — Support dynamic tree token organization
  • Modified: tensorrt_llm/_torch/pyexecutor/model_engine.py — Dynamic tree detection for target model
  • Modified: tensorrt_llm/_torch/pyexecutor/sampler.py — Dynamic tree batch verification
  • Modified: tensorrt_llm/_torch/models/modeling_speculative.py — Hidden states handling for dynamic tree
  • Modified: tensorrt_llm/llmapi/llm_args.py — Configuration validation for dynamic tree
  • Modified: Attention backend files for dynamic tree metadata support

Test plan

  • Unit tests pass (tests/unittest/_torch/speculative/)
  • One-model dynamic tree EAGLE3 inference matches two-model accept rates
  • Two-model dynamic tree EAGLE3 inference works correctly
  • No regression in standard (non-dynamic-tree) EAGLE3 flows

Summary by CodeRabbit

  • New Features

    • Added dynamic tree speculative decoding support for Eagle3, enabling efficient parallel token prediction with flexible tree structures
    • Added streaming generation support with configurable max draft tokens and speculative decoding parameters
    • Enhanced token verification with CUDA-accelerated kernels for improved inference performance
  • Configuration

    • Added --max_total_draft_tokens parameter for controlling total draft token budget
    • Added --streaming flag for real-time token streaming output

@sunnyqgg
Collaborator Author

/bot run

@sunnyqgg sunnyqgg marked this pull request as draft March 10, 2026 04:33
@tensorrt-cicd
Collaborator

PR_Github #38374 [ run ] triggered by Bot. Commit: b15f7d6 Link to invocation

@coderabbitai
Contributor

coderabbitai bot commented Mar 10, 2026

📝 Walkthrough


This pull request introduces dynamic tree-based speculative decoding for EAGLE3 inference with CUDA-accelerated kernels. It adds tree construction and greedy verification kernels, integrates them with Python/PyTorch layers, implements dynamic tree sampling and acceptance logic, and extends the Eagle3 resource manager and worker infrastructure to support dynamic tree mode alongside static tree mode.
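
For intuition, the greedy verification step can be sketched in pure Python over the left-child/right-sibling arrays (retrieve_next_token, retrieve_next_sibling); this is a simplified model of the kernel's logic, not its actual code:

```python
def verify_greedy(tokens, next_child, next_sibling, target_predict):
    """Walk the draft tree from the root, accepting the child whose draft
    token matches the target model's greedy prediction at the current node.

    tokens[i]        : draft token at node i (node 0 is the verified root)
    next_child[i]    : index of i's first child, or -1
    next_sibling[i]  : index of i's next sibling, or -1
    target_predict[i]: target model's greedy token after node i
    Returns the list of accepted node indices (excluding the root).
    """
    accepted = []
    node = 0
    while True:
        want = target_predict[node]
        child = next_child[node]
        # scan the sibling chain for a child carrying the wanted token
        while child != -1 and tokens[child] != want:
            child = next_sibling[child]
        if child == -1:          # no draft child matches: stop accepting
            break
        accepted.append(child)
        node = child
    return accepted
```

Because at most one child is accepted per level, the accepted nodes always form a single root-to-node path through the draft tree.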

Changes

  • CUDA Kernels & Interface (cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu, cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h): New kernels for dynamic tree construction (packed and non-packed treeMask variants) and greedy tree verification. Includes kernel launchers (invokeBuildDynamicTree, invokeVerifyDynamicTreeGreedy) and the TreeMaskMode enum (QLEN_ONLY, QLEN_ONLY_BITPACKING).
  • Torch C++ Extension (cpp/tensorrt_llm/thop/dynamicTreeOp.cpp, cpp/tensorrt_llm/thop/CMakeLists.txt): PyTorch custom operators binding the CUDA kernels: build_dynamic_tree_op and verify_dynamic_tree_greedy_op. Includes input validation, buffer initialization, and kernel invocation via tensor interfaces.
  • Dynamic Tree Operations (tensorrt_llm/_torch/speculative/dynamic_tree_ops.py): Core Python abstraction layer introducing DynamicTreeBuffers, VerifyTreeResults, DynamicTreeOpsConverter, and the factory function create_dynamic_tree_ops_converter for managing tree construction and verification with preallocated buffers and error handling.
  • Drafting Loop Infrastructure (tensorrt_llm/_torch/speculative/drafting_loops.py): Renames the existing TreeDraftingLoopWrapper to StaticTreeDraftingLoopWrapper and introduces a new DynamicTreeDraftingLoopWrapper for dynamic-tree drafting with per-layer sampling, extensive buffering, and CUDA kernel integration.
  • Dynamic Tree Sampling & Worker (tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py): New module with Eagle3OneModelDynamicTreeSampler and Eagle3OneModelDynamicTreeWorker implementing dynamic-tree verification, KV-cache management, draft token generation, and tree mask/position preparation with extensive internal helpers (sample_dynamic, dt_update_draft_tokens_and_scores, dt_resampling_final_draft_tokens, etc.).
  • Sampler Integration (tensorrt_llm/_torch/pyexecutor/sampler.py): Adds batch-driven dynamic tree verification (_batch_verify_dynamic_tree) and per-request processing (_process_draft_tokens_dynamic_tree) integrated into the main sampling flow. Includes dynamic tree result handling and fallback to the greedy path when unavailable.
  • Eagle3 Configuration & Management (tensorrt_llm/_torch/speculative/eagle3.py, tensorrt_llm/_torch/speculative/spec_tree_manager.py): Introduces Eagle3OneModelDynamicTreeResourceManager and expands Eagle3OneModelSpecMetadata with use_dynamic_tree and eagle_choices fields. Adds dynamic-tree buffer allocation, eagle_paths selection logic, spec_dec_packed_mask computation, and drafter-model offsets in SpecTreeManager.
  • Attention Backend Interface Updates (tensorrt_llm/_torch/attention_backend/interface.py, tensorrt_llm/_torch/attention_backend/sparse/dsa.py, tensorrt_llm/_torch/attention_backend/trtllm.py): Updated update_spec_dec_param signatures: removed the spec_metadata and spec_decoding_tensor parameters, added an is_target_model flag. Backend implementations now reshape position_offsets for dynamic trees and handle spec decoding buffers with spec_tree_manager integration, including support for drafter-layer paths.
  • Model & Executor Updates (tensorrt_llm/_torch/models/modeling_speculative.py, tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/_torch/pyexecutor/py_executor.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py): Minor refactoring in Eagle3ForCausalLM; removed spec_decoding_tensor from PyTorchModelEngine; added a warmup completion flag in PyExecutor; and expanded py_executor_creator to conditionally select DynamicTreeDraftingLoopWrapper or StaticTreeDraftingLoopWrapper based on EagleDecodingConfig flags.
  • Model Drafter (tensorrt_llm/_torch/speculative/model_drafter.py): Added dynamic-tree buffer handling in the static-draft-output path. Extracts and propagates retrieve_index, retrieve_next_token, and retrieve_next_sibling from tree structures to spec_tree_manager per-request state. Updated the prepare_draft_tokens signature to require ResourceManager explicitly.
  • Configuration & Utilities (tensorrt_llm/llmapi/llm_args.py, tensorrt_llm/_torch/speculative/utils.py): Enhanced EagleDecodingConfig.validate_eagle_config with unified dynamic-tree validation: enforces dynamic_tree_max_topK, requires eagle_choices=None for dynamic mode, and defaults/validates max_total_draft_tokens. Expanded utils.py to route to dynamic-tree variants (ResourceManager, Sampler, Worker) when use_dynamic_tree is enabled.
  • Example & Tests (examples/llm-api/quickstart_advanced.py, tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py, tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py): Added --max_total_draft_tokens and --streaming CLI arguments to quickstart_advanced. Test imports updated to use StaticTreeDraftingLoopWrapper and DynamicTreeDraftingLoopWrapper; added new test functions for dynamic tree updates and restructuring with ModelDrafter integration.

Sequence Diagram(s)

sequenceDiagram
    participant Python as Python Layer
    participant Sampler as TorchSampler
    participant DTree as DynamicTreeOpsConverter
    participant Kernel as CUDA Kernel
    participant Buf as GPU Buffers

    Python->>Sampler: update() with requests
    Sampler->>Sampler: _batch_verify_dynamic_tree(requests, tokens)
    Sampler->>DTree: verify_dynamic_tree_greedy(draft_tokens, logits, tree_buffers)
    DTree->>DTree: compute target predictions from logits
    DTree->>Kernel: invoke verify_dynamic_tree_greedy_op()
    Kernel->>Buf: read: candidates, retrieve_index, retrieve_next_sibling, targetPredict
    Kernel->>Buf: greedy tree traversal per batch
    Kernel->>Buf: write: predicts, acceptIndex, acceptTokenNum
    Buf-->>Kernel: results
    Kernel-->>DTree: returns VerifyTreeResults
    DTree-->>Sampler: per-slot (num_accepted_tokens, accept_index)
    Sampler->>Sampler: _process_draft_tokens_dynamic_tree() per request
    Sampler-->>Python: updated tokens and finish reasons
sequenceDiagram
    participant Worker as Eagle3DynamicTreeWorker
    participant DraftModel as Draft Model
    participant DTree as DynamicTreeOpsConverter
    participant Kernel as CUDA Kernel
    participant Cache as KV Cache

    Worker->>Worker: _forward_draft_loop(initial context)
    Worker->>DraftModel: forward(input_ids, position_ids)
    DraftModel-->>Worker: logits
    Worker->>Worker: sample_dynamic(logits, topk)
    Worker->>Worker: dt_update_draft_tokens_and_scores()
    Worker->>DTree: build_dynamic_tree(parent_list, topk_indices, tree_buffers)
    DTree->>Kernel: invoke build_dynamic_tree_op()
    Kernel->>Kernel: construct left-child/right-sibling tree
    Kernel->>Kernel: compute per-node attention masks (treeMask)
    Kernel->>Kernel: compute absolute positions
    Kernel-->>DTree: DynamicTreeBuffers (tree_mask, positions, retrieve_index, etc.)
    DTree-->>Worker: tree structure ready
    Worker->>Worker: dt_prepare_tree_mask_and_position_offset()
    Worker->>DraftModel: forward(growing context with tree topology)
    DraftModel->>Cache: update KV cache with tree positions
    DraftModel-->>Worker: logits per tree node
    Worker->>Worker: _sample_and_accept_dynamic_tree(logits)
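Conceptually, the per-node attention mask (treeMask) computed during tree construction is an ancestor mask: each draft node attends to itself and to its ancestors. A simplified pure-Python sketch, ignoring batching and the QLEN_ONLY_BITPACKING bit-packed variant:

```python
def build_tree_mask(parent):
    """Build the ancestor attention mask for a draft tree.

    parent[i] is the parent index of node i (-1 for the root).
    Returns an n x n 0/1 mask where mask[i][j] == 1 iff node i may
    attend to node j (j is i itself or one of i's ancestors).
    """
    n = len(parent)
    mask = [[0] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:          # walk up the parent chain to the root
            mask[i][j] = 1
            j = parent[j]
    return mask
```
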
sequenceDiagram
    participant App as Application
    participant Executor as PyExecutor
    participant ResourceMgr as Eagle3OneModelDynamicTreeResourceManager
    participant Worker as Eagle3OneModelDynamicTreeWorker
    participant Sampler as Eagle3OneModelDynamicTreeSampler

    App->>Executor: initialize with EagleDecodingConfig(use_dynamic_tree=True)
    Executor->>ResourceMgr: create with SpecTreeManager(use_dynamic_tree=True)
    Executor->>Worker: initialize with spec_config
    Executor->>Sampler: initialize with spec_config
    App->>Executor: generate(requests)
    Executor->>Worker: forward(context & draft loop)
    Worker->>Worker: _forward_dynamic_tree_draft_loop()
    Worker-->>Executor: draft_tokens, dynamic_tree_buffers, accepted_draft_indices
    Executor->>Sampler: sample_and_accept_draft_tokens(logits, buffers)
    Sampler->>Sampler: verify with dynamic tree buffers
    Sampler-->>Executor: accepted tokens, accepted indices
    Executor->>ResourceMgr: get_needed_resource_to_completion(request)
    ResourceMgr-->>Executor: resource estimates
    Executor->>Worker: prepare_1st_drafter_inputs(with tree topology targets)
    Executor-->>App: generated tokens

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 38.89%, which is insufficient; the required threshold is 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description check (✅ Passed): The PR description provides a comprehensive summary of changes, clear objectives, and test coverage details.
  • Title check (✅ Passed): The title clearly and specifically summarizes the main change: adding EAGLE3 dynamic tree speculative decoding support. It is concise and directly related to the substantial feature additions across kernels, ops, and model layers.




@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 17

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)

3346-3386: ⚠️ Potential issue | 🔴 Critical

Zero-draft dynamic-tree requests currently fall into the static-tree verifier.

_batch_verify_dynamic_tree() explicitly skips requests whose draft length is 0, but the fallback here still calls process_draft_tokens(). With a non-null spec_tree_manager, that dispatches to _process_draft_tokens_tree(), which assumes a populated draft tree. If a dynamic-tree request reaches this branch with no drafts, it will fail instead of just emitting the verified token.

🛠️ One possible guard
-                if req.py_seq_slot in dynamic_tree_results:
+                if req.py_seq_slot in dynamic_tree_results:
                     num_accepted = self._process_draft_tokens_dynamic_tree(
                         req, new_tokens_list, finish_reasons, dynamic_tree_results[req.py_seq_slot]
                     )
-
+                elif spec_tree_manager is not None and spec_tree_manager.use_dynamic_tree:
+                    num_accepted = self._process_draft_tokens_greedy(
+                        req, new_tokens=new_tokens_list, finish_reasons=finish_reasons
+                    )
                 else:
                     num_accepted = self.process_draft_tokens(
                         req,
                         new_tokens_tensor=new_tokens,
                         new_tokens_list=new_tokens_list,
🧹 Nitpick comments (5)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)

435-437: Keep the wrapper imports namespaced here too.

Since this is new dispatch code, please import the drafting_loops module and reference these wrappers from that module rather than importing the classes directly. As per coding guidelines, when importing in Python, always maintain the namespace: import the module, not individual classes or functions.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/pyexecutor/py_executor_creator.py` around lines 435 -
437, Replace the direct class imports DynamicTreeDraftingLoopWrapper,
LinearDraftingLoopWrapper, StaticTreeDraftingLoopWrapper with a namespaced
module import for drafting_loops and update all references to use
drafting_loops.DynamicTreeDraftingLoopWrapper,
drafting_loops.LinearDraftingLoopWrapper, and
drafting_loops.StaticTreeDraftingLoopWrapper (e.g., where these classes are used
in the dispatch/registration code inside py_executor_creator.py) so the module
is imported, not individual classes.
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (1)

8-9: Keep the drafting-loop import namespaced.

Please import the drafting_loops module and resolve StaticTreeDraftingLoopWrapper through that namespace here instead of importing the class directly. As per coding guidelines, when importing in Python, always maintain the namespace: import the module, not individual classes or functions.

Also applies to: 56-56

🤖 Prompt for AI Agents

In `@tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py` around
lines 8 - 9, Replace the direct class import with a module-level import for the
drafting_loops module and qualify the class through that namespace: change the
current "from tensorrt_llm._torch.speculative.drafting_loops import
StaticTreeDraftingLoopWrapper" to "import
tensorrt_llm._torch.speculative.drafting_loops as drafting_loops" (or "from
tensorrt_llm._torch.speculative import drafting_loops") and update all usages to
drafting_loops.StaticTreeDraftingLoopWrapper (also fix the similar import/usage
at the other occurrence noted).
tensorrt_llm/_torch/attention_backend/interface.py (1)

371-381: Making trailing parameters keyword-only would improve API safety, but it is not required.

The code is already safe: all call sites either use keyword arguments (model_engine.py:3552, test cases) or correctly pass all 9 positional arguments in order (sparse/dsa.py:541). No stale callers with 7–8 positional arguments (which would silently misbind after inserting is_target_model) exist in the codebase.

If keyword-only enforcement is desired for defensiveness, add * before is_target_model:

def update_spec_dec_param(
        self,
        batch_size,
        is_spec_decoding_enabled,
        is_spec_dec_tree,
        is_spec_dec_dynamic_tree,
        max_draft_len,
        max_total_draft_tokens,
+       *,
        is_target_model: bool = True,
        model_is_wrapped: bool = False,
        spec_tree_manager: Optional['SpecTreeManager'] = None):

This prevents future positional misuse and makes the intent explicit, but the current codebase is already compliant.

🤖 Prompt for AI Agents

In `@tensorrt_llm/_torch/attention_backend/interface.py` around lines 371 - 381,
The function update_spec_dec_param currently accepts many trailing
boolean/optional parameters positionally; make is_target_model,
model_is_wrapped, and spec_tree_manager keyword-only by inserting a bare *
before is_target_model in the signature so callers cannot accidentally bind
those flags positionally—update the signature in the update_spec_dec_param
definition and adjust any internal references accordingly (no other logic
changes).
tensorrt_llm/_torch/attention_backend/trtllm.py (1)

502-509: Consider clarifying the reshape assumption.

The reshape from 1D [max_num_requests * N] to 2D [max_num_requests, N] assumes the 1D tensor was allocated with exactly max_num_requests * (max_total_draft_tokens + 1) elements. This is correct based on line 1463-1465, but the implicit coupling between allocation and reshape could be fragile.

Consider adding an assertion or comment to make this contract explicit:

📝 Suggestion for defensive check
         # For dynamic tree, reshape 1D position_offsets to 2D for C++ kernel compatibility
         position_offsets_for_cpp = self.spec_decoding_position_offsets
         if (self.spec_decoding_position_offsets is not None
                 and self.spec_decoding_position_offsets.dim() == 1):
             # Reshape 1D [max_num_requests * N] to 2D [max_num_requests, N]
             # C++ kernel requires 2D to extract max_generation_length from sizes()[1]
+            assert self.spec_decoding_position_offsets.numel() % self.max_num_requests == 0, \
+                "1D position_offsets size must be divisible by max_num_requests"
             position_offsets_for_cpp = self.spec_decoding_position_offsets.view(
                 self.max_num_requests, -1)
🤖 Prompt for AI Agents

In `@tensorrt_llm/_torch/attention_backend/trtllm.py` around lines 502 - 509, The
reshape from 1D to 2D (position_offsets_for_cpp based on
self.spec_decoding_position_offsets) assumes the 1D tensor length equals
max_num_requests * N; add a defensive check before the view to assert that
self.spec_decoding_position_offsets.numel() is divisible by
self.max_num_requests and (optionally) equals self.max_num_requests *
(self.max_total_draft_tokens + 1) (or raise a clear error mentioning
spec_decoding_position_offsets and max_num_requests) so the implicit
allocation/reshape contract in trtllm.py is explicit and fails fast when
violated.
tensorrt_llm/_torch/pyexecutor/sampler.py (1)

2626-2668: Materialize accept_index before the per-request Python loop.

Line 2649's accept_index[j].item() performs a device read for every accepted token in a hot request loop. Convert the accepted indices once, then iterate over a Python list here, matching the existing new_tokens.tolist() pattern. Based on learnings: In files under tensorrt_llm/_torch/pyexecutor, avoid accessing torch.Tensor objects inside for-loops when iterating over requests. Convert batched tensors to Python lists beforehand using tensor.tolist(), and then iterate over those lists.

🤖 Prompt for AI Agents

In `@tensorrt_llm/_torch/pyexecutor/sampler.py` around lines 2626 - 2668, The loop
in _process_draft_tokens_dynamic_tree repeatedly calls accept_index[j].item(),
causing device reads per-iteration; materialize accept_index to a Python list
once before the request loop (e.g. accept_indices = accept_index.tolist() or
accept_index.cpu().tolist() and cast elements to int), then iterate over
accept_indices for add_token and finish_if_reason calls, and compute
request.py_num_accepted_draft_tokens_indices from that list by subtracting 1 for
positions after the root; keep using the same symbols (accept_index ->
accept_indices, _process_draft_tokens_dynamic_tree, add_token, finish_if_reason,
request.py_num_accepted_draft_tokens_indices) so you only replace tensor
indexing with list indexing.
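The suggested pattern hoists the device-to-host synchronization out of the per-request loop. A toy illustration, with a fake tensor class standing in for a CUDA tensor (the sync counter is purely illustrative, not part of any real API):

```python
class FakeTensor:
    """Minimal stand-in for a device tensor: .item() and .tolist() each
    simulate one device->host synchronization; indexing returns a view."""
    syncs = 0

    def __init__(self, data):
        self.data = list(data)

    def __getitem__(self, j):          # device-side view, no sync yet
        return FakeTensor([self.data[j]])

    def __len__(self):
        return len(self.data)

    def item(self):                    # one sync per scalar read
        FakeTensor.syncs += 1
        return self.data[0]

    def tolist(self):                  # one sync for the whole tensor
        FakeTensor.syncs += 1
        return list(self.data)


accept_index = FakeTensor([4, 1, 3, 0])

# Anti-pattern: one simulated sync per element inside the hot loop.
slow = [accept_index[j].item() for j in range(len(accept_index))]
syncs_after_slow = FakeTensor.syncs

# Suggested pattern: materialize once, then iterate plain Python ints.
accept_indices = accept_index.tolist()
fast = list(accept_indices)
```

With four accepted tokens, the first loop costs four simulated syncs while the tolist() path costs one, which is exactly the trade the review comment describes.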
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu`:
- Around line 112-182: The kernel reuses preallocated topology buffers but
doesn’t clear stale data; before building the tree (inside the tid==0 branch of
the dynamic tree builder) explicitly reinitialize retrieveNextToken and
retrieveNextSibling entries for this batch (bid) to -1 for all draftTokenNum
slots, and clear all words of treeMask for this batch (not just word 0); also
ensure positions/retrieveIndex for all slots are set to sane defaults if needed.
Locate the tid==0 block that sets positions[bid * draftTokenNum] and the loop
that writes retrieveIndex/retrieveNextToken/retrieveNextSibling and add the
resets there (and mirror the same full-reset logic in the other build region
referenced around lines 245-317).
- Around line 191-214: The ancestor-walk loop can run past bounds when a parent
lookup misses: after the for-loop that searches selectedIndex for tokenIdx
(using curPosition and draftTokenNum) add a guard to detect "not found"
(curPosition == draftTokenNum) and break the while loop to avoid reading/writing
past selectedIndex/treeMask; apply the same defensive check to the equivalent
ancestor-walk logic around the other block referenced (uses the same symbols:
treeMask, tokenTreeIdx, curPosition, selectedIndex, draftTokenNum, parentList,
parentTbIdx, bid, topK) so both paths stop if the parent was not resampled.
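
The requested guard can be pictured with a pure-Python sketch (hypothetical names; the real kernel writes treeMask bits during the walk rather than returning a list):

```python
def collect_ancestors(token_idx, selected_index, parent_of, draft_token_num):
    """Walk token_idx's ancestor chain through the resampled node set.

    selected_index : token ids actually kept in this tree (length draft_token_num)
    parent_of      : maps a token id to its parent token id (-1 at the root)
    Returns the ancestor token ids found, stopping early if a parent was
    never resampled (the out-of-bounds case the review flags).
    """
    ancestors = []
    while parent_of.get(token_idx, -1) != -1:
        token_idx = parent_of[token_idx]
        # linear scan mirroring the kernel's search over selectedIndex
        cur_position = 0
        while cur_position < draft_token_num and selected_index[cur_position] != token_idx:
            cur_position += 1
        if cur_position == draft_token_num:   # guard: parent not resampled
            break
        ancestors.append(token_idx)
    return ancestors
```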

In `@cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h`:
- Line 17: The header currently uses `#pragma once` but must follow the repo guard convention; replace the pragma with a preprocessor include guard named TRTLLM_DYNAMICTREEKERNELS_H (matching the filename dynamicTreeKernels.h in ALL CAPS) by adding `#ifndef TRTLLM_DYNAMICTREEKERNELS_H` / `#define TRTLLM_DYNAMICTREEKERNELS_H` at the top and a matching `#endif` at the bottom, ensuring no directory names or trailing underscores are used and keeping the rest of the file (dynamicTreeKernels.h) unchanged.

In `@cpp/tensorrt_llm/thop/dynamicTreeOp.cpp`:
- Around line 34-67: build_dynamic_tree_op (and its sibling
verify_dynamic_tree_greedy_op) currently access raw data pointers and call
at::cuda::getCurrentCUDAStream() without validating devices, dtypes, shapes or
the treeMaskMode enum; add TORCH_CHECKs to ensure all input/output tensors
(parentList, selectedIndex, treeMask, positions, retrieveIndex,
retrieveNextToken, retrieveNextSibling, verifiedSeqLen) are CUDA tensors
(is_cuda()), are on the same device (device.index() equality), and have the
expected scalar types (parentList/selectedIndex int64,
positions/retrieveIndex/retrieveNextToken/retrieveNextSibling/verifiedSeqLen
int32 as used by data_ptr<int32_t/int64_t>()), verify output shapes (batchSize,
numDraftTokens-1, etc.) before zero_/fill_, and check treeMaskMode is within the
valid tk::TreeMaskMode range before static_cast; perform these checks at the
start of build_dynamic_tree_op and verify_dynamic_tree_greedy_op so that
tk::invokeBuildDynamicTree and related kernel calls only receive validated
tensors and enum values.

In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py`:
- Line 540: The forward reference 'SpecTreeManager' used in the signature
(spec_tree_manager: Optional['SpecTreeManager']) is not imported under
TYPE_CHECKING; add "from tensorrt_llm._torch.speculative.spec_tree_manager
import SpecTreeManager" to the existing TYPE_CHECKING import block (alongside
any existing imports such as DecodingBaseConfig) so the name is resolved for
type checking and Ruff F821 is fixed.

In `@tensorrt_llm/_torch/pyexecutor/py_executor_creator.py`:
- Around line 444-461: The leftover variable use_tree_drafter is assigned
earlier but never used after splitting into static_tree_drafter and
dynamic_tree_drafter, causing a linter F841; remove the unused use_tree_drafter
assignment (or fold its logic into the existing predicates) so only
static_tree_drafter and dynamic_tree_drafter derived from draft_spec_config
(EagleDecodingConfig) remain, leaving the branching that returns
StaticTreeDraftingLoopWrapper and DynamicTreeDraftingLoopWrapper unchanged
(references: use_tree_drafter, static_tree_drafter, dynamic_tree_drafter,
draft_spec_config, spec_config, StaticTreeDraftingLoopWrapper,
DynamicTreeDraftingLoopWrapper).

In `@tensorrt_llm/_torch/speculative/drafting_loops.py`:
- Around line 709-718: The code unconditionally overwrites return_draft_logits
with zeros losing real collected logits; change the logic in drafting_loops.py
so that you only allocate the zero tensor as a fallback when return_draft_logits
is missing or has an incompatible shape (e.g., check if return_draft_logits is
None or return_draft_logits.shape != (self.max_total_draft_tokens, batch_size,
vocab_size)); otherwise preserve the existing return_draft_logits from the last
draft layer; when allocating the fallback ensure dtype/device match
(torch.float32 and 'cuda') and add a brief comment referencing
tokens_accumulated to indicate this is a temporary fallback until per-layer
gathering is implemented.
- Around line 1164-1180: The attn_metadata.use_spec_decoding flag is left False
so subsequent drafter forwards ignore the dynamic-tree metadata; set
attn_metadata.use_spec_decoding = True at the end of this preparation block
(after updating kv_lens_cuda, _seq_lens, host_request_types and before leaving
the dynamic-tree growth steps) so the next draft pass uses speculative decoding;
locate the block updating attn_metadata.kv_lens_cuda, attn_metadata._seq_lens,
attn_metadata.host_request_types and set attn_metadata.use_spec_decoding = True
there (ensure this happens before
spec_metadata.eagle3_resource_manager.is_first_draft is toggled).
- Around line 961-975: spec_decoding_position_offsets is being treated as a flat
vector but it’s stored as a 2-D buffer ([max_num_requests,
max_total_draft_tokens+1]); update the code to slice and assign it as 2-D so
rows correspond to requests: read previous_position_offsets =
attn_metadata.spec_decoding_position_offsets[:batch_size,
:num_tokens_previous_layer], build new_position_offsets by concatenating along
dim=1 (using previous_position_offsets and previous_position_offsets[:,
-self.dynamic_tree_max_topK:]+1), then write it back to
attn_metadata.spec_decoding_position_offsets[:batch_size,
:num_tokens_current_layer] (no flattening/view needed) so the correct request
rows are updated.
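
The fix amounts to treating the buffer as request-major and updating whole rows rather than flattening to 1-D; a pure-Python sketch with nested lists standing in for the torch buffer (shapes and numbers are illustrative):

```python
# stand-in for attn_metadata.spec_decoding_position_offsets, shaped
# [max_num_requests, max_total_draft_tokens + 1]; values are illustrative
max_requests, width = 4, 8
batch_size, topk, n_prev = 2, 2, 3   # topk ~ dynamic_tree_max_topK
n_cur = n_prev + topk                # num_tokens_current_layer

buf = [[0] * width for _ in range(max_requests)]
buf[0][:n_prev] = [0, 1, 1]
buf[1][:n_prev] = [0, 1, 2]

# per-row (request-major) update: extend each request's offsets by its own
# last topK entries + 1, instead of flattening the whole buffer to 1-D
for r in range(batch_size):
    prev = buf[r][:n_prev]
    buf[r][:n_cur] = prev + [p + 1 for p in prev[-topk:]]
```

Rows beyond batch_size are untouched, which is the property the flattened-view version could violate.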

In `@tensorrt_llm/_torch/speculative/dynamic_tree_ops.py`:
- Around line 1-12: Add the standard NVIDIA Apache-2.0 license header (with the
latest modification year) at the top of the file before the existing module
docstring in dynamic_tree_ops.py; replace the current file-starting
docstring-only content by prepending the required NVIDIA copyright/license block
so the file begins with the Apache 2.0 header followed by the existing "Dynamic
Tree Operations for EAGLE3 Speculative Decoding" docstring.

In `@tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py`:
- Around line 98-143: The buffers currently fix max_batch_size = 256 which can
overflow for larger deployments; replace the hard-coded max_batch_size with a
runtime value derived from spec_config or the worker's actual max concurrent
sequences (e.g., spec_config.max_batch_size or a passed-in parameter) and use
that variable when allocating dt_draft_tokens_buffer, dt_position_ids_buffer,
history_draft_tokens_buffer, history_score_buffer,
history_draft_tokens_parent_buffer, tree_mask_buffer, tree_mask_init_buffer, and
tree_mask_padding_zeros, and when calling create_dynamic_tree_ops_converter
(preserve device and dtypes). Ensure the new max_batch_size value is validated
(positive int) and that any flattened-size calculations (like tree_mask_buffer
shape) are updated to use the dynamic max_batch_size to avoid out-of-bounds
writes.
- Around line 475-487: The code incorrectly reshapes
spec_decoding_position_offsets into a flat (max_reqs, tokens_per_req) layout
which conflicts with the request-major layout used elsewhere; replace the manual
flattening with a request-major view and index into the existing first
dimension. Concretely, stop computing max_reqs = total_po_size // tokens_per_req
and using pos_2d = attn_metadata.spec_decoding_position_offsets.view(max_reqs,
tokens_per_req); instead treat pos_2d as request-major (e.g., pos_2d =
attn_metadata.spec_decoding_position_offsets.view(-1, tokens_per_req) or simply
use the existing first-dimension shape) and then write pos_2d[req_idx, :n] =
causal_offs[:n] ensuring req_idx is computed as num_contexts + g_idx and within
bounds. Apply the same fix to the other occurrence around lines 1090-1101 to
preserve the [max_num_requests, max_total_draft_tokens + 1] layout everywhere.

In `@tensorrt_llm/_torch/speculative/model_drafter.py`:
- Around line 686-690: Remove the two unused CPU buffer assignments to avoid
dead code: delete the assignments to topk_score_indices and
history_draft_tokens_parent_buffer that read from
dynamic_tree_buffers["topk_score_indices"].cpu() and
dynamic_tree_buffers["history_draft_tokens_parent_buffer"].cpu(). If those
buffers are intended for future use, replace each assignment with a short TODO
comment referencing the buffer name (topk_score_indices,
history_draft_tokens_parent_buffer) and why it will be needed; otherwise simply
remove the two lines. Ensure no other code in the same method depends on these
variables after removal.
- Around line 1004-1005: The prepare_draft_tokens method in ModelDrafter
currently requires resource_manager but the base class Drafter defines it as
optional; change the signature of ModelDrafter.prepare_draft_tokens to accept
resource_manager: Optional[ResourceManager] = None so it matches the base
contract, add or ensure Optional is imported from typing if missing, and mirror
the pattern used in ngram.py; update any internal usage of resource_manager to
handle None safely.
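The signature alignment the comment describes can be sketched as follows (the classes are minimal stand-ins for the real `Drafter`/`ModelDrafter`/`ResourceManager`, and the returned token values are placeholders):

```python
from typing import Any, List, Optional


class ResourceManager:  # minimal stand-in for the real ResourceManager
    pass


class Drafter:
    # Base contract: resource_manager is optional.
    def prepare_draft_tokens(
            self,
            scheduled_requests: Any,
            resource_manager: Optional[ResourceManager] = None) -> List[int]:
        raise NotImplementedError


class ModelDrafter(Drafter):
    # Override matches the base signature, mirroring the ngram.py pattern.
    def prepare_draft_tokens(
            self,
            scheduled_requests: Any,
            resource_manager: Optional[ResourceManager] = None) -> List[int]:
        if resource_manager is None:
            # Handle the None case safely instead of assuming a manager exists.
            return []
        return [101, 102]  # placeholder draft tokens
```

With the default in place, callers that follow the base contract and omit the manager no longer raise.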

In `@tensorrt_llm/llmapi/llm_args.py`:
- Around line 1865-1876: The file contains a duplicate TypeAlias named
SpeculativeConfig that shadows the earlier discriminated union (the one defined
with Field(discriminator="decoding_type")), which removes SADecodingConfig and
PARDDecodingConfig; remove the second SpeculativeConfig definition (or rename it
if you truly need a separate non-discriminated alias) and keep the original
annotated union (including SADecodingConfig and PARDDecodingConfig alongside
DraftTargetDecodingConfig, Eagle3DecodingConfig, EagleDecodingConfig,
LookaheadDecodingConfig, MedusaDecodingConfig, MTPDecodingConfig,
NGramDecodingConfig, UserProvidedDecodingConfig, SaveHiddenStatesDecodingConfig,
AutoDecodingConfig) so the discriminator-based Pydantic union remains intact.
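The shadowing hazard is plain Python semantics: a module-level `TypeAlias` is just an assignment, so a second `SpeculativeConfig = ...` later in the file silently replaces the first. A minimal sketch (two stand-in config classes, with a hand-rolled dispatch approximating what pydantic's `Field(discriminator="decoding_type")` does):

```python
from dataclasses import dataclass
from typing import Literal, Union


@dataclass
class Eagle3DecodingConfig:
    decoding_type: Literal["Eagle3"] = "Eagle3"


@dataclass
class SADecodingConfig:
    decoding_type: Literal["SA"] = "SA"


# A second `SpeculativeConfig = ...` assignment later in the module would
# silently replace this union, which is how members like SADecodingConfig
# get dropped from validation.
SpeculativeConfig = Union[Eagle3DecodingConfig, SADecodingConfig]


def parse_config(payload: dict):
    """Hand-rolled analogue of a pydantic discriminated union: the
    decoding_type field selects which member class to construct."""
    registry = {"Eagle3": Eagle3DecodingConfig, "SA": SADecodingConfig}
    return registry[payload["decoding_type"]]()
```

If the alias is reassigned without `SADecodingConfig`, payloads with `decoding_type="SA"` stop validating, exactly the regression the comment flags.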

---

Nitpick comments:
In `@tensorrt_llm/_torch/attention_backend/interface.py`:
- Around line 371-381: The function update_spec_dec_param currently accepts many
trailing boolean/optional parameters positionally; make is_target_model,
model_is_wrapped, and spec_tree_manager keyword-only by inserting a bare *
before is_target_model in the signature so callers cannot accidentally bind
those flags positionally—update the signature in the update_spec_dec_param
definition and adjust any internal references accordingly (no other logic
changes).
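The bare-`*` change can be sketched like this (the parameter list is simplified; the real `update_spec_dec_param` is a method with more arguments):

```python
def update_spec_dec_param(
        batch_size: int,
        max_draft_len: int,
        *,  # everything after this must be passed by keyword
        is_target_model: bool = False,
        model_is_wrapped: bool = False,
        spec_tree_manager=None) -> dict:
    return {
        "batch_size": batch_size,
        "max_draft_len": max_draft_len,
        "is_target_model": is_target_model,
        "model_is_wrapped": model_is_wrapped,
    }


# Positional binding of the flags now fails fast with a TypeError:
try:
    update_spec_dec_param(8, 4, True)
except TypeError:
    positional_call_rejected = True
```

This prevents a caller from silently binding `True` to the wrong boolean when the parameter order changes.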

In `@tensorrt_llm/_torch/attention_backend/trtllm.py`:
- Around line 502-509: The reshape from 1D to 2D (position_offsets_for_cpp based
on self.spec_decoding_position_offsets) assumes the 1D tensor length equals
max_num_requests * N; add a defensive check before the view to assert that
self.spec_decoding_position_offsets.numel() is divisible by
self.max_num_requests and (optionally) equals self.max_num_requests *
(self.max_total_draft_tokens + 1) (or raise a clear error mentioning
spec_decoding_position_offsets and max_num_requests) so the implicit
allocation/reshape contract in trtllm.py is explicit and fails fast when
violated.
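The defensive check could look like the following sketch (a free function with illustrative names; the real code would validate `self.spec_decoding_position_offsets.numel()` inline before the `view`):

```python
def checked_tokens_per_request(flat_numel: int, max_num_requests: int,
                               max_total_draft_tokens: int) -> int:
    """Validate the 1D buffer length before viewing it as
    [max_num_requests, max_total_draft_tokens + 1]; fail fast otherwise."""
    expected = max_num_requests * (max_total_draft_tokens + 1)
    if flat_numel % max_num_requests != 0 or flat_numel != expected:
        raise ValueError(
            f"spec_decoding_position_offsets has {flat_numel} elements; "
            f"expected {expected} for max_num_requests={max_num_requests} "
            f"and max_total_draft_tokens={max_total_draft_tokens}")
    return flat_numel // max_num_requests
```

A clear error here surfaces a broken allocation contract immediately instead of as a silent mis-shaped view.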

In `@tensorrt_llm/_torch/pyexecutor/py_executor_creator.py`:
- Around line 435-437: Replace the direct class imports
DynamicTreeDraftingLoopWrapper, LinearDraftingLoopWrapper,
StaticTreeDraftingLoopWrapper with a namespaced module import for drafting_loops
and update all references to use drafting_loops.DynamicTreeDraftingLoopWrapper,
drafting_loops.LinearDraftingLoopWrapper, and
drafting_loops.StaticTreeDraftingLoopWrapper (e.g., where these classes are used
in the dispatch/registration code inside py_executor_creator.py) so the module
is imported, not individual classes.

In `@tensorrt_llm/_torch/pyexecutor/sampler.py`:
- Around line 2626-2668: The loop in _process_draft_tokens_dynamic_tree
  repeatedly calls accept_index[j].item(), triggering a device-to-host read on
  every iteration. Materialize accept_index as a Python list once before the
  request loop (e.g., accept_indices = accept_index.tolist() or
  accept_index.cpu().tolist(), casting elements to int), then iterate over
  accept_indices for the add_token and finish_if_reason calls, and compute
  request.py_num_accepted_draft_tokens_indices from that list by subtracting 1
  for positions after the root. Keep the same symbols (accept_index ->
  accept_indices, _process_draft_tokens_dynamic_tree, add_token,
  finish_if_reason, request.py_num_accepted_draft_tokens_indices) so the change
  only replaces tensor indexing with list indexing.
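The cost difference can be demonstrated without CUDA by a pure-Python stand-in that counts simulated device-to-host syncs (`FakeAcceptIndex` is entirely hypothetical; a real `torch.Tensor` on GPU behaves analogously, with each `.item()` forcing a sync):

```python
class FakeAcceptIndex:
    """Stand-in for a CUDA tensor: each .item() and each .tolist()
    counts as one simulated device->host sync."""

    def __init__(self, values):
        self._values = values
        self.sync_count = 0

    def __len__(self):
        return len(self._values)

    def __getitem__(self, j):
        tensor = self

        class _Scalar:
            def item(self):
                tensor.sync_count += 1
                return tensor._values[j]

        return _Scalar()

    def tolist(self):
        self.sync_count += 1  # one sync for the whole tensor
        return [int(v) for v in self._values]


accept_index = FakeAcceptIndex([0, 3, 5, 6])

# Anti-pattern: per-iteration .item() -> one sync per element.
slow = [accept_index[j].item() for j in range(len(accept_index))]
syncs_per_item = accept_index.sync_count

# Suggested fix: materialize once, then index the Python list.
accept_index.sync_count = 0
accept_indices = accept_index.tolist()
fast = [accept_indices[j] for j in range(len(accept_indices))]
syncs_tolist = accept_index.sync_count
```

Both paths yield the same values, but the `.tolist()` path pays a single sync regardless of how many requests are in the batch.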

In `@tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py`:
- Around line 8-9: Replace the direct class import with a module-level import
for the drafting_loops module and qualify the class through that namespace:
change the current "from tensorrt_llm._torch.speculative.drafting_loops import
StaticTreeDraftingLoopWrapper" to "import
tensorrt_llm._torch.speculative.drafting_loops as drafting_loops" (or "from
tensorrt_llm._torch.speculative import drafting_loops") and update all usages to
drafting_loops.StaticTreeDraftingLoopWrapper (also fix the similar import/usage
at the other occurrence noted).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 388f70bd-848c-4020-99de-5838cd97e5b3

📥 Commits

Reviewing files that changed from the base of the PR and between 3139ffa and b15f7d6.

📒 Files selected for processing (23)
  • cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu
  • cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h
  • cpp/tensorrt_llm/thop/CMakeLists.txt
  • cpp/tensorrt_llm/thop/dynamicTreeOp.cpp
  • examples/llm-api/quickstart_advanced.py
  • tensorrt_llm/_torch/attention_backend/interface.py
  • tensorrt_llm/_torch/attention_backend/sparse/dsa.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • tensorrt_llm/_torch/models/modeling_speculative.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tensorrt_llm/_torch/speculative/drafting_loops.py
  • tensorrt_llm/_torch/speculative/dynamic_tree_ops.py
  • tensorrt_llm/_torch/speculative/eagle3.py
  • tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py
  • tensorrt_llm/_torch/speculative/model_drafter.py
  • tensorrt_llm/_torch/speculative/spec_tree_manager.py
  • tensorrt_llm/_torch/speculative/utils.py
  • tensorrt_llm/llmapi/llm_args.py
  • tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
  • tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py

@tensorrt-cicd (Collaborator)

PR_Github #38374 [ run ] completed with state FAILURE. Commit: b15f7d6
/LLM/main/L0_MergeRequest_PR pipeline #29741 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@sunnyqgg sunnyqgg force-pushed the add_dyanmic_tree_support_one_model branch from 99b25a9 to c2c8ef6 Compare March 12, 2026 01:52
@sunnyqgg (Collaborator Author)

/bot run

@sunnyqgg sunnyqgg force-pushed the add_dyanmic_tree_support_one_model branch from c2c8ef6 to abd7543 Compare March 15, 2026 07:36
@sunnyqgg (Collaborator Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #38968 [ run ] triggered by Bot. Commit: abd7543 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #38968 [ run ] completed with state FAILURE. Commit: abd7543
/LLM/main/L0_MergeRequest_PR pipeline #30250 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@sunnyqgg (Collaborator Author)

/bot run

1 similar comment
@sunnyqgg (Collaborator Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #39082 [ run ] triggered by Bot. Commit: bb00556 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #39082 [ run ] completed with state FAILURE. Commit: bb00556
/LLM/main/L0_MergeRequest_PR pipeline #30345 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@sunnyqgg (Collaborator Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #39197 [ run ] triggered by Bot. Commit: a79c91d Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #42211 [ run ] triggered by Bot. Commit: deb3179 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #42211 [ run ] completed with state SUCCESS. Commit: deb3179
/LLM/main/L0_MergeRequest_PR pipeline #33029 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@sunnyqgg (Collaborator Author)

sunnyqgg commented Apr 8, 2026

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #42279 [ run ] triggered by Bot. Commit: deb3179 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #42279 [ run ] completed with state SUCCESS. Commit: deb3179
/LLM/main/L0_MergeRequest_PR pipeline #33076 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@sunnyqgg (Collaborator Author)

sunnyqgg commented Apr 8, 2026

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #42358 [ run ] triggered by Bot. Commit: deb3179 Link to invocation

@sunnyqgg (Collaborator Author)

sunnyqgg commented Apr 9, 2026

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #42507 [ run ] triggered by Bot. Commit: deb3179 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #42507 [ run ] completed with state SUCCESS. Commit: deb3179
/LLM/main/L0_MergeRequest_PR pipeline #33251 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@sunnyqgg (Collaborator Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #42597 [ run ] triggered by Bot. Commit: deb3179 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #42597 [ run ] completed with state SUCCESS. Commit: deb3179
/LLM/main/L0_MergeRequest_PR pipeline #33323 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@sunnyqgg (Collaborator Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #42679 [ run ] triggered by Bot. Commit: deb3179 Link to invocation

…update_spec_dec_param

The parent TrtllmAttentionMetadata.update_spec_dec_param() no longer
accepts spec_decoding_tensor, but the DSA override still passed it
through to super(), causing a TypeError on all B200 8-GPU speculative
decoding tests (TestGLM5FP8, TestDeepSeekV32 DSA/NVFP4).

Signed-off-by: qgai <qgai@nvidia.com>
@sunnyqgg (Collaborator Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #42698 [ run ] triggered by Bot. Commit: cd13f99 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #42679 [ run ] completed with state ABORTED. Commit: deb3179

Link to invocation

…n feature matrices

Signed-off-by: qgai <qgai@nvidia.com>
@sunnyqgg (Collaborator Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #42762 [ run ] triggered by Bot. Commit: 9a9cc5e Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #42762 [ run ] completed with state SUCCESS. Commit: 9a9cc5e
/LLM/main/L0_MergeRequest_PR pipeline #33439 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@sunnyqgg (Collaborator Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #42803 [ run ] triggered by Bot. Commit: 9a9cc5e Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #42803 [ run ] completed with state SUCCESS. Commit: 9a9cc5e
/LLM/main/L0_MergeRequest_PR pipeline #33476 completed with status: 'SUCCESS'

CI Report

Link to invocation

@sunnyqgg sunnyqgg dismissed laikhtewari’s stale review April 12, 2026 01:42

Already changed based on his review

@sunnyqgg sunnyqgg merged commit 4ece13c into NVIDIA:main Apr 12, 2026
5 checks passed
@sunnyqgg (Collaborator Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #42869 [ ] completed with state FAILURE. Commit: 9a9cc5e
Not allowed on merged PR

Link to invocation

8 participants