Conversation
Pull request overview
This PR adds TBO (Two-Batch Overlap) support to ATOM by introducing a micro-batch (ubatch) dual-thread execution path intended to overlap MoE communication with compute, and wires it into model execution and CUDAGraph capture.
Changes:
- Extend `ForwardContext` to carry ubatch slicing info and add thread-local forward-context support for TBO worker threads.
- Introduce a new `atom.utils.dbo` package implementing ubatch slicing, TBO thread/stream/event coordination, and a `UBatchWrapper` to run micro-batches in threads (including a TBO CUDAGraph capture path).
- Integrate TBO into attention metadata builders, the MORI prepare/finalize async flow, and model runner scheduling/DP synchronization, plus CLI/config flags.
Reviewed changes
Copilot reviewed 17 out of 17 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| atom/utils/forward_context.py | Add ubatch_slices to ForwardContext and thread-local context lookup for TBO threads. |
| atom/utils/dbo/ubatching.py | Implement TBOContext and module helpers for yield/stream switching and recv hooks. |
| atom/utils/dbo/ubatch_wrapper.py | Add UBatchWrapper to run ubatches in threads and optionally capture TBO execution in CUDAGraphs. |
| atom/utils/dbo/ubatch_splitting.py | Provide utilities to split batches into ubatches (decode and token-balanced prefill). |
| atom/utils/dbo/__init__.py | Export the new DBO/TBO public surface. |
| atom/models/deepseek_v2.py | Disable dual-stream MoE path when TBO is enabled. |
| atom/model_ops/moe.py | Enable MORI async/TBO wiring and instantiate per-ubatch MORI ops. |
| atom/model_ops/fused_moe/mori_prepare_finalize.py | Add async prepare/finalize paths (comm-stream + AsyncLL) and per-ubatch MORI op support. |
| atom/model_ops/fused_moe/modular_kernel.py | Route async prepare/finalize through TBO yield + hook mechanism; adjust finalize API usage. |
| atom/model_ops/attentions/backends.py | Add build_ubatch_prefill_metadata using split_attn_metadata. |
| atom/model_ops/attentions/aiter_mla.py | Allocate per-ubatch buffers and build per-ubatch attention metadata for TBO/CUDAGraph decode. |
| atom/model_ops/attentions/aiter_attention.py | Allocate per-ubatch buffers and build per-ubatch attention metadata for TBO/CUDAGraph decode. |
| atom/model_engine/scheduler.py | Extend next-batch info with request counts to support DP sync decisions. |
| atom/model_engine/model_runner.py | Wrap model with UBatchWrapper, create ubatch slices per batch, and add a TBO CUDAGraph capture path. |
| atom/model_engine/engine_core.py | Sync DP state including request counts; run dummy prefill with minimal reqs for TBO agreement. |
| atom/model_engine/arg_utils.py | Add CLI flags --enable-tbo and --low-latency. |
| atom/config.py | Add config toggles enable_tbo and enable_low_latency. |
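As a rough illustration of the token-balanced prefill split that `atom/utils/dbo/ubatch_splitting.py` provides, here is a minimal sketch. The greedy largest-first heuristic and the helper name are assumptions for illustration, not the PR's actual algorithm:

```python
def split_token_balanced(token_counts: list[int]) -> tuple[list[int], list[int]]:
    """Greedily assign request indices to two micro-batches so their
    total token counts stay roughly balanced (illustrative only)."""
    ubatch_a: list[int] = []  # request indices for ubatch 0
    ubatch_b: list[int] = []  # request indices for ubatch 1
    tokens_a = tokens_b = 0
    # Largest-first greedy: place each request into the lighter ubatch.
    for idx in sorted(range(len(token_counts)),
                      key=lambda i: token_counts[i], reverse=True):
        if tokens_a <= tokens_b:
            ubatch_a.append(idx)
            tokens_a += token_counts[idx]
        else:
            ubatch_b.append(idx)
            tokens_b += token_counts[idx]
    return ubatch_a, ubatch_b
```

For example, `[100, 90, 10, 5]` splits into ubatches totalling 105 and 100 tokens.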
```python
if self.config.enable_tbo and (is_prefill or not batch.is_dummy_run):
    tbo_num_reqs = batch.total_seqs_num_prefill if is_prefill else scheduled_bs
    dp_tbo_tensor = (
        prefill_reqs_across_dp if is_prefill else num_tokens_across_dp
    )
    if dp_tbo_tensor is not None:
        # use min across all ranks so all agree on TBO on/off.
        min_reqs = int(torch.min(dp_tbo_tensor).item())
        can_tbo = min_reqs >= 2
    else:
```
In DP mode, the TBO enable/disable agreement for decode is derived from num_tokens_across_dp (token count), but TBO splitting is based on scheduled_bs / tbo_num_reqs (request count). With speculative decode (MTP, max_q_len > 1), a rank can have tbo_num_reqs == 1 while still having num_tokens >= 2, causing some ranks to enter TBO (async MORI path) and others to run non-TBO (sync path), which risks cross-rank mismatch/hangs. Use a DP-reduced/gathered request-count tensor (e.g., DPMetadata.num_tokens_across_dp(scheduled_bs, ...)) for the min check instead of token counts, so all ranks make the same TBO decision based on requests.
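The suggested fix can be illustrated with a tiny sketch, where plain Python lists stand in for the DP-gathered tensors and `agree_on_tbo` is a hypothetical name:

```python
def agree_on_tbo(reqs_across_dp: list[int]) -> bool:
    """All DP ranks compute the same decision from the same gathered
    request counts: TBO is only viable when every rank can split its
    batch into two micro-batches, i.e. has at least 2 requests."""
    return min(reqs_across_dp) >= 2

# With MTP speculative decode (max_q_len = 4), a rank can hold a single
# request that still contributes 4 tokens:
per_rank_reqs = [1, 3]
per_rank_tokens = [4, 12]

assert min(per_rank_tokens) >= 2             # token-based min would enable TBO
assert agree_on_tbo(per_rank_reqs) is False  # request-based min correctly disables it
```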
Force-pushed from 0f6226d to 1072248.
Pull request overview
Copilot reviewed 17 out of 17 changed files in this pull request and generated 7 comments.
```python
config = get_current_atom_config()
if config is None:
```
tbo_enabled() calls get_current_atom_config(), which asserts if the config hasn’t been set; the subsequent if config is None: branch is therefore unreachable and may give a false sense that this helper is safe to call early. Consider either removing the None check, or switching to a non-asserting accessor (or try/except AssertionError) so tbo_enabled() reliably returns False when config isn’t initialized yet.
Suggested change:

```diff
-config = get_current_atom_config()
-if config is None:
+try:
+    config = get_current_atom_config()
+except AssertionError:
```
```python
"""Register a recv completion hook on the NEXT ubatch's context."""
ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % _NUM_UBATCHES]
```
tbo_register_recv_hook() assumes the “next” ubatch context is always present, but _CURRENT_CONTEXTS[(ctx_idx + 1) % _NUM_UBATCHES] can be None (e.g., if the other ubatch thread exits early and clears its slot in __exit__). This will raise an AttributeError and can crash the forward. Add a guard (and decide on a policy such as running the hook immediately when the next context is gone, or storing it in a per-ubatch pending queue that doesn’t require next_ctx to still be alive).
Suggested change:

```diff
-"""Register a recv completion hook on the NEXT ubatch's context."""
+"""Register a recv completion hook on the NEXT ubatch's context.
+
+If the peer ubatch context has already exited and cleared its slot,
+run the hook immediately rather than dereferencing ``None`` and
+crashing the forward.
+"""
 ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
 next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % _NUM_UBATCHES]
+if next_ctx is None:
+    hook()
+    return
```
```python
# All threads reach the barrier, then the main thread wakes thread 0
self.ready_barrier.wait()

# Wait for our turn (thread 0 is woken by the main thread)
self.cpu_wait_event.wait()
self.cpu_wait_event.clear()
```
TBOContext.__enter__() blocks on ready_barrier.wait() and later on cpu_wait_event.wait() with no timeout or failure path. If any ubatch thread raises before reaching the barrier (e.g., device init / stream init failures) the main thread can block forever on its own ready_barrier.wait() in UBatchWrapper, resulting in a hard hang. Consider using a timeout + BrokenBarrierError handling, and/or ensuring threads always reach/abort the barrier in a finally block (e.g., ready_barrier.abort() plus waking any waiters) so errors propagate instead of deadlocking.
```python
try:
    threads = []
    for i in range(N):
        t = threading.Thread(target=_ubatch_thread, args=(i,))
        threads.append(t)
        t.start()

    self.ready_barrier.wait()
    tbo_ctxs[0].cpu_wait_event.set()

    for t in threads:
        t.join()
```
_run_ubatches() waits on self.ready_barrier.wait() after starting worker threads. Because the barrier wait in TBOContext.__enter__() is unconditional and has no timeout, any exception in a worker before it reaches the barrier will cause the main thread to block forever here (and errors[idx] will never be checked). Consider adding timeouts + BrokenBarrierError handling, and aborting the barrier / setting the initial cpu_wait_event in the worker exception path so the system fails fast instead of deadlocking.
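Both comments point at the same failure mode. The fail-fast pattern they suggest can be sketched with `threading.Barrier`; the structure and timeout value are illustrative, not the PR's code:

```python
import threading

def run_workers(n: int, worker, timeout: float = 5.0) -> None:
    """Start n workers that rendezvous on a barrier with the main thread.
    If any worker raises before reaching it, abort the barrier so the
    main thread gets BrokenBarrierError instead of hanging forever."""
    barrier = threading.Barrier(n + 1)  # n workers + main thread
    errors = [None] * n

    def _wrapped(i: int) -> None:
        try:
            worker(i)
            barrier.wait(timeout=timeout)
        except BaseException as e:
            errors[i] = e
            barrier.abort()  # wake everyone blocked on the barrier

    threads = [threading.Thread(target=_wrapped, args=(i,)) for i in range(n)]
    for t in threads:
        t.start()
    try:
        barrier.wait(timeout=timeout)
    except threading.BrokenBarrierError:
        pass  # a worker failed (or timed out); fall through to join + raise
    for t in threads:
        t.join()
    for e in errors:
        if e is not None:
            raise e
```

With this shape, a worker that dies before the rendezvous breaks the barrier, the main thread unblocks, and the original exception is re-raised instead of the process deadlocking.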
```python
if num_reqs >= self.max_num_seqs:
    break
total_tokens += tokens
num_reqs += 1
```
get_next_batch_info() can return (True, 0, 0) when self.waiting is non-empty but the first request would exceed max_num_batched_tokens (the loop breaks before adding any seq). This makes the DP sync state report “prefill with 0 tokens/reqs”, which can cause other ranks to compute global_max_tokens from decode ranks and run an oversized dummy prefill. Consider ensuring the function always reports at least one waiting request (e.g., include the first seq even if it exceeds the limit, or clamp total_tokens to the limit but keep num_reqs=1).
Suggested change:

```diff
 num_reqs += 1
+if num_reqs == 0:
+    first_seq = self.waiting[0]
+    first_tokens = first_seq.num_tokens - first_seq.num_cached_tokens
+    total_tokens = min(first_tokens, self.max_num_batched_tokens)
+    num_reqs = 1
```
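A self-contained sketch of the suggested clamping behavior, with the scheduler fields reduced to plain arguments (not ATOM's actual class):

```python
def next_batch_info(waiting_tokens: list,
                    max_num_batched_tokens: int,
                    max_num_seqs: int) -> tuple:
    """Return (is_prefill, total_tokens, num_reqs), guaranteeing that a
    non-empty waiting queue never reports zero requests."""
    if not waiting_tokens:
        return (False, 0, 0)
    total_tokens = 0
    num_reqs = 0
    for tokens in waiting_tokens:
        if total_tokens + tokens > max_num_batched_tokens:
            break
        if num_reqs >= max_num_seqs:
            break
        total_tokens += tokens
        num_reqs += 1
    if num_reqs == 0:
        # First request alone exceeds the budget: still report it,
        # clamped, so DP peers never see "prefill with 0 tokens/reqs".
        total_tokens = min(waiting_tokens[0], max_num_batched_tokens)
        num_reqs = 1
    return (True, total_tokens, num_reqs)
```

An oversized first request now reports `(True, max_num_batched_tokens, 1)` instead of `(True, 0, 0)`.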
```diff
     output = hidden_states
 else:
-    output = torch.zeros_like(hidden_states)
+    output = None
```
output is set to None when inplace is False, but _finalize() passes it into prepare_finalize.finalize() / finalize_async(), whose interface is typed to require a torch.Tensor. This works today because MoriPrepareAndFinalize ignores the output argument, but it’s a fragile contract and will break as soon as another FusedMoEPrepareAndFinalize implementation relies on the output buffer. Consider either always allocating an output tensor when inplace is False (as before), or updating the prepare/finalize interface to accept Optional[torch.Tensor] and handling None explicitly in all implementations.
Suggested change:

```diff
-output = None
+output = torch.empty_like(hidden_states)
```
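To make the Optional-buffer contract explicit, here is a torch-free sketch in which Python lists stand in for tensors; `finalize` is illustrative, not the real `FusedMoEPrepareAndFinalize` signature:

```python
from typing import Optional

def finalize(output: Optional[list], fused_result: list) -> list:
    """Accept an optional output buffer: allocate one when the caller
    passed None (the inplace=False path), rather than assuming every
    implementation ignores the argument."""
    if output is None:
        # Stand-in for torch.empty_like(hidden_states).
        output = [0.0] * len(fused_result)
    output[:] = fused_result
    return output
```

Typing the parameter `Optional` and handling `None` in every implementation keeps the interface honest; the alternative is to always allocate before the call, as the suggestion above does.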
```python
def get_next_batch_info(self) -> tuple[bool, int, int]:
    if self.waiting:
        # new request is waiting, will do prefill
        seq = self.waiting[0]
        num_tokens = seq.num_tokens - seq.num_cached_tokens
        return (True, num_tokens)
    num_reqs = 0
```
get_next_batch_info() now returns a 3-tuple, but the existing unit tests still expect the old (bool, int) shape (see tests/test_scheduler.py:323-336). Update/add tests to validate the new (is_prefill, total_tokens, num_reqs) contract for empty/waiting/running cases (and any edge cases introduced by the new batching loop).
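A shared helper the updated tests could use, sketched here; the real expectations live in tests/test_scheduler.py and the names are illustrative:

```python
def check_next_batch_info(result: tuple) -> None:
    """Validate the new (is_prefill, total_tokens, num_reqs) contract."""
    assert isinstance(result, tuple) and len(result) == 3
    is_prefill, total_tokens, num_reqs = result
    assert isinstance(is_prefill, bool)
    assert total_tokens >= 0 and num_reqs >= 0
    if not is_prefill:
        # Nothing waiting: decode path reports no prefill work.
        assert total_tokens == 0 and num_reqs == 0

# Shapes the updated tests should cover:
check_next_batch_info((False, 0, 0))   # empty scheduler
check_next_batch_info((True, 300, 2))  # waiting requests -> prefill
```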
Motivation
We enable TBO together with DP attention and MORI:

```
--enable-dp-attention --enable-expert-parallel --enable-tbo
```
DeepSeek:
Performance with MTP=3:
```
============ Serving Benchmark Result ============
Successful requests:              256
Benchmark duration (s):           45.98
Total input tokens:               236166
Total generated tokens:           234328
Request throughput (req/s):       5.57
Output token throughput (tok/s):  5095.77
Total Token throughput (tok/s):   10231.51
```
GPT-OSS:

```
Successful requests:              256
Benchmark duration (s):           23.71
Total input tokens:               236166
Total generated tokens:           234891
Request throughput (req/s):       10.80
Output token throughput (tok/s):  9907.19
Total Token throughput (tok/s):   19868.16
```
Overlap:
prefill:

decode:

Technical Details
Test Plan
Test Result
Submission Checklist