
Support TBO in ATOM #515

Open
ZhangLirong-amd wants to merge 4 commits into main from zlr/tbo_dev

Conversation


@ZhangLirong-amd ZhangLirong-amd commented Apr 8, 2026

Motivation

We enable TBO (Two-Batch Overlap) with DP attention + MORI:
--enable-dp-attention --enable-expert-parallel --enable-tbo

DeepSeek:

MORI_SHMEM_MODE=ISOLATION python3 -m atom.entrypoints.openai_server --model /data/deepseek-ai/DeepSeek-R1-0528/ -tp 8 --port 5678  --gpu-memory-utilization 0.4  --enable-dp-attention --enable-expert-parallel --enable-tbo --server-port 7777

Performance with MTP=3:
============ Serving Benchmark Result ============
Successful requests: 256
Benchmark duration (s): 45.98
Total input tokens: 236166
Total generated tokens: 234328
Request throughput (req/s): 5.57
Output token throughput (tok/s): 5095.77
Total Token throughput (tok/s): 10231.51

GPT-OSS:

MORI_SHMEM_MODE=ISOLATION python3 -m atom.entrypoints.openai_server --model /data/models/openai/gpt-oss-120b/ -tp 2 --port 5678  --gpu-memory-utilization 0.4  --enable-dp-attention   --server-port 7777 --torch-profiler-dir ./log --enable-expert-parallel --enable-tbo

============ Serving Benchmark Result ============
Successful requests: 256
Benchmark duration (s): 23.71
Total input tokens: 236166
Total generated tokens: 234891
Request throughput (req/s): 10.80
Output token throughput (tok/s): 9907.19
Total Token throughput (tok/s): 19868.16

Overlap:

prefill:
(profiler trace screenshot)

decode:
(profiler trace screenshot)

Technical Details

Test Plan

Test Result

Submission Checklist

Copilot AI review requested due to automatic review settings April 8, 2026 05:38

Copilot AI left a comment


Pull request overview

This PR adds TBO (Two-Batch Overlap) support to ATOM by introducing a micro-batch (ubatch) dual-thread execution path intended to overlap MoE communication with compute, and wires it into model execution and CUDAGraph capture.

Changes:

  • Extend ForwardContext to carry ubatch slicing info and add thread-local forward-context support for TBO worker threads.
  • Introduce a new atom.utils.dbo package implementing ubatch slicing, TBO thread/stream/event coordination, and a UBatchWrapper to run micro-batches in threads (including a TBO CUDAGraph capture path).
  • Integrate TBO into attention metadata builders, MORI prepare/finalize async flow, and model runner scheduling/DP synchronization + CLI/config flags.
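The dual-thread ubatch execution described above can be sketched in pure Python. This is a minimal illustration with hypothetical names, not ATOM's actual `UBatchWrapper`/`TBOContext` (which additionally coordinate CUDA streams and events): two worker threads alternate on CPU events, each yielding to its peer at every simulated communication point so one ubatch's "comm window" overlaps the other's compute.

```python
import threading

# Minimal sketch of two-ubatch ping-pong coordination (hypothetical names;
# the real TBO path also switches CUDA streams and records GPU events).
NUM_UBATCHES = 2

def run_ubatches(work_fn, inputs):
    """Run one worker thread per ubatch; each thread yields to its peer
    at every simulated communication point so compute and comm overlap."""
    events = [threading.Event() for _ in range(NUM_UBATCHES)]
    results = [None] * NUM_UBATCHES
    order = []  # records the interleaving for inspection

    def yield_to_peer(idx):
        # Wake the other ubatch thread, then sleep until it yields back.
        events[(idx + 1) % NUM_UBATCHES].set()
        events[idx].wait()
        events[idx].clear()

    def worker(idx):
        events[idx].wait()           # wait for our turn
        events[idx].clear()
        for step in work_fn(inputs[idx]):
            order.append((idx, step))
            yield_to_peer(idx)       # hand the "comm window" to the peer
        results[idx] = sum(inputs[idx])
        events[(idx + 1) % NUM_UBATCHES].set()  # let the peer finish

    threads = [threading.Thread(target=worker, args=(i,))
               for i in range(NUM_UBATCHES)]
    for t in threads:
        t.start()
    events[0].set()                  # main thread kicks off ubatch 0
    for t in threads:
        t.join()
    return results, order

results, order = run_ubatches(
    lambda xs: iter(["moe_dispatch", "moe_combine"]), [[1, 2], [3, 4]]
)
```

The recorded `order` shows the strict interleaving (ubatch 0 dispatch, ubatch 1 dispatch, ubatch 0 combine, ubatch 1 combine) that lets each ubatch's MoE communication hide behind the other's compute.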

Reviewed changes

Copilot reviewed 17 out of 17 changed files in this pull request and generated 1 comment.

Show a summary per file:

• atom/utils/forward_context.py: Add ubatch_slices to ForwardContext and thread-local context lookup for TBO threads.
• atom/utils/dbo/ubatching.py: Implement TBOContext and module helpers for yield/stream switching and recv hooks.
• atom/utils/dbo/ubatch_wrapper.py: Add UBatchWrapper to run ubatches in threads and optionally capture TBO execution in CUDAGraphs.
• atom/utils/dbo/ubatch_splitting.py: Provide utilities to split batches into ubatches (decode and token-balanced prefill).
• atom/utils/dbo/__init__.py: Export the new DBO/TBO public surface.
• atom/models/deepseek_v2.py: Disable dual-stream MoE path when TBO is enabled.
• atom/model_ops/moe.py: Enable MORI async/TBO wiring and instantiate per-ubatch MORI ops.
• atom/model_ops/fused_moe/mori_prepare_finalize.py: Add async prepare/finalize paths (comm-stream + AsyncLL) and per-ubatch MORI op support.
• atom/model_ops/fused_moe/modular_kernel.py: Route async prepare/finalize through TBO yield + hook mechanism; adjust finalize API usage.
• atom/model_ops/attentions/backends.py: Add build_ubatch_prefill_metadata using split_attn_metadata.
• atom/model_ops/attentions/aiter_mla.py: Allocate per-ubatch buffers and build per-ubatch attention metadata for TBO/CUDAGraph decode.
• atom/model_ops/attentions/aiter_attention.py: Allocate per-ubatch buffers and build per-ubatch attention metadata for TBO/CUDAGraph decode.
• atom/model_engine/scheduler.py: Extend next-batch info with request counts to support DP sync decisions.
• atom/model_engine/model_runner.py: Wrap model with UBatchWrapper, create ubatch slices per batch, and add a TBO CUDAGraph capture path.
• atom/model_engine/engine_core.py: Sync DP state including request counts; run dummy prefill with minimal reqs for TBO agreement.
• atom/model_engine/arg_utils.py: Add CLI flags --enable-tbo and --low-latency.
• atom/config.py: Add config toggles enable_tbo and enable_low_latency.


Comment on lines +1505 to +1514
if self.config.enable_tbo and (is_prefill or not batch.is_dummy_run):
    tbo_num_reqs = batch.total_seqs_num_prefill if is_prefill else scheduled_bs
    dp_tbo_tensor = (
        prefill_reqs_across_dp if is_prefill else num_tokens_across_dp
    )
    if dp_tbo_tensor is not None:
        # use min across all ranks so all agree on TBO on/off.
        min_reqs = int(torch.min(dp_tbo_tensor).item())
        can_tbo = min_reqs >= 2
    else:

Copilot AI Apr 8, 2026


In DP mode, the TBO enable/disable agreement for decode is derived from num_tokens_across_dp (token count), but TBO splitting is based on scheduled_bs / tbo_num_reqs (request count). With speculative decode (MTP, max_q_len > 1), a rank can have tbo_num_reqs == 1 while still having num_tokens >= 2, causing some ranks to enter TBO (async MORI path) and others to run non-TBO (sync path), which risks cross-rank mismatch/hangs. Use a DP-reduced/gathered request-count tensor (e.g., DPMetadata.num_tokens_across_dp(scheduled_bs, ...)) for the min check instead of token counts, so all ranks make the same TBO decision based on requests.
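The agreement rule the comment asks for can be sketched without any distributed machinery. This is a hypothetical standalone helper, not ATOM's code: in the real path the per-rank values would come from a DP-gathered request-count tensor (e.g. `torch.min` over a `num_reqs_across_dp` tensor), but the decision logic is the same.

```python
# Sketch of the DP-wide TBO agreement rule (hypothetical helper name).
# Every rank must reach the same on/off decision, so the decision is a
# pure function of the gathered per-rank *request* counts.
def agree_on_tbo(reqs_across_dp: list[int], min_reqs_for_tbo: int = 2) -> bool:
    """Enable TBO only if every DP rank has enough requests to split
    into two ubatches. Using token counts here can disagree under
    speculative decode, where a rank may have tokens >= 2 but requests == 1."""
    return min(reqs_across_dp) >= min_reqs_for_tbo

# Example: with MTP one rank has 1 request but 4 tokens. A token-count
# check (min tokens = 4 >= 2) would enable TBO on all ranks, while the
# request-count check correctly disables it everywhere.
tokens_across_dp = [4, 8, 8, 8]
reqs_across_dp = [1, 2, 2, 2]
decision = agree_on_tbo(reqs_across_dp)
```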

Copilot AI review requested due to automatic review settings April 8, 2026 09:26

Copilot AI left a comment


Pull request overview

Copilot reviewed 17 out of 17 changed files in this pull request and generated 7 comments.



Comment on lines +17 to +18
config = get_current_atom_config()
if config is None:

Copilot AI Apr 8, 2026


tbo_enabled() calls get_current_atom_config(), which asserts if the config hasn’t been set; the subsequent if config is None: branch is therefore unreachable and may give a false sense that this helper is safe to call early. Consider either removing the None check, or switching to a non-asserting accessor (or try/except AssertionError) so tbo_enabled() reliably returns False when config isn’t initialized yet.

Suggested change:
    try:
        config = get_current_atom_config()
    except AssertionError:

Comment on lines +209 to +211
"""Register a recv completion hook on the NEXT ubatch's context."""
ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % _NUM_UBATCHES]

Copilot AI Apr 8, 2026


tbo_register_recv_hook() assumes the “next” ubatch context is always present, but _CURRENT_CONTEXTS[(ctx_idx + 1) % _NUM_UBATCHES] can be None (e.g., if the other ubatch thread exits early and clears its slot in __exit__). This will raise an AttributeError and can crash the forward. Add a guard (and decide on a policy such as running the hook immediately when the next context is gone, or storing it in a per-ubatch pending queue that doesn’t require next_ctx to still be alive).

Suggested change:
    """Register a recv completion hook on the NEXT ubatch's context.

    If the peer ubatch context has already exited and cleared its slot,
    run the hook immediately rather than dereferencing ``None`` and
    crashing the forward.
    """
    ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
    next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % _NUM_UBATCHES]
    if next_ctx is None:
        hook()
        return

Comment on lines +84 to +90
# All threads reach the barrier, then the main thread wakes thread 0
self.ready_barrier.wait()

# Wait for our turn (thread 0 is woken by the main thread)
self.cpu_wait_event.wait()
self.cpu_wait_event.clear()


Copilot AI Apr 8, 2026


TBOContext.__enter__() blocks on ready_barrier.wait() and later on cpu_wait_event.wait() with no timeout or failure path. If any ubatch thread raises before reaching the barrier (e.g., device init / stream init failures) the main thread can block forever on its own ready_barrier.wait() in UBatchWrapper, resulting in a hard hang. Consider using a timeout + BrokenBarrierError handling, and/or ensuring threads always reach/abort the barrier in a finally block (e.g., ready_barrier.abort() plus waking any waiters) so errors propagate instead of deadlocking.
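The fail-fast pattern the comment suggests can be demonstrated in isolation. This is a hedged sketch with hypothetical names (not ATOM's `TBOContext`/`UBatchWrapper`): a worker that fails before the rendezvous calls `barrier.abort()`, so every other waiter gets `BrokenBarrierError` instead of hanging forever.

```python
import threading

# Sketch of barrier rendezvous with a fail-fast path (hypothetical names).
# A worker that raises before reaching the barrier aborts it, waking all
# blocked waiters with BrokenBarrierError so the error can propagate.
def run_with_failfast_barrier(worker_body, num_workers=2, timeout=5.0):
    barrier = threading.Barrier(num_workers + 1)  # workers + main thread
    errors = [None] * num_workers

    def worker(idx):
        try:
            worker_body(idx)              # e.g. device/stream init
            barrier.wait(timeout)
        except threading.BrokenBarrierError:
            pass                          # a peer already failed; just exit
        except Exception as e:
            errors[idx] = e
            barrier.abort()               # wake everyone blocked on the barrier

    threads = [threading.Thread(target=worker, args=(i,))
               for i in range(num_workers)]
    for t in threads:
        t.start()
    try:
        barrier.wait(timeout)             # main-thread rendezvous
        broken = False
    except threading.BrokenBarrierError:
        broken = True                     # fail fast instead of deadlocking
    for t in threads:
        t.join()
    return broken, errors

def failing_init(idx):
    # Simulate worker 1 dying before it reaches the barrier.
    if idx == 1:
        raise RuntimeError("stream init failed")

broken, errors = run_with_failfast_barrier(failing_init)
```

Without the `abort()` in the worker's exception path, the main thread's `barrier.wait()` would block until the timeout (or forever with no timeout), which is exactly the hang the review describes.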

Comment on lines +145 to +156
try:
    threads = []
    for i in range(N):
        t = threading.Thread(target=_ubatch_thread, args=(i,))
        threads.append(t)
        t.start()

    self.ready_barrier.wait()
    tbo_ctxs[0].cpu_wait_event.set()

    for t in threads:
        t.join()

Copilot AI Apr 8, 2026


_run_ubatches() waits on self.ready_barrier.wait() after starting worker threads. Because the barrier wait in TBOContext.__enter__() is unconditional and has no timeout, any exception in a worker before it reaches the barrier will cause the main thread to block forever here (and errors[idx] will never be checked). Consider adding timeouts + BrokenBarrierError handling, and aborting the barrier / setting the initial cpu_wait_event in the worker exception path so the system fails fast instead of deadlocking.

    if num_reqs >= self.max_num_seqs:
        break
    total_tokens += tokens
    num_reqs += 1

Copilot AI Apr 8, 2026


get_next_batch_info() can return (True, 0, 0) when self.waiting is non-empty but the first request would exceed max_num_batched_tokens (the loop breaks before adding any seq). This makes the DP sync state report “prefill with 0 tokens/reqs”, which can cause other ranks to compute global_max_tokens from decode ranks and run an oversized dummy prefill. Consider ensuring the function always reports at least one waiting request (e.g., include the first seq even if it exceeds the limit, or clamp total_tokens to the limit but keep num_reqs=1).

Suggested change:
        num_reqs += 1
    if num_reqs == 0:
        first_seq = self.waiting[0]
        first_tokens = first_seq.num_tokens - first_seq.num_cached_tokens
        total_tokens = min(first_tokens, self.max_num_batched_tokens)
        num_reqs = 1
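The clamping behavior suggested here can be checked with a small standalone function. This is a hypothetical sketch mirroring the snippet's field names, not the scheduler's actual method: it never reports "prefill with 0 requests" just because the first waiting request alone exceeds the token budget.

```python
# Standalone sketch of the suggested clamp (hypothetical function; the
# real code lives in the scheduler and reads self.waiting / self.max_*).
def next_prefill_info(waiting_token_counts, max_num_batched_tokens, max_num_seqs):
    """Return (total_tokens, num_reqs) for the next prefill batch."""
    total_tokens, num_reqs = 0, 0
    for tokens in waiting_token_counts:
        if total_tokens + tokens > max_num_batched_tokens:
            break
        if num_reqs >= max_num_seqs:
            break
        total_tokens += tokens
        num_reqs += 1
    if num_reqs == 0 and waiting_token_counts:
        # First request alone exceeds the budget: clamp to the budget and
        # still report one request, so DP sync never sees an empty prefill.
        total_tokens = min(waiting_token_counts[0], max_num_batched_tokens)
        num_reqs = 1
    return total_tokens, num_reqs
```

With a 4096-token budget, a single 5000-token request now reports (4096, 1) rather than (0, 0), so peer DP ranks size their dummy prefill from a real value.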

    output = hidden_states
else:
-   output = torch.zeros_like(hidden_states)
+   output = None

Copilot AI Apr 8, 2026


output is set to None when inplace is False, but _finalize() passes it into prepare_finalize.finalize() / finalize_async(), whose interface is typed to require a torch.Tensor. This works today because MoriPrepareAndFinalize ignores the output argument, but it’s a fragile contract and will break as soon as another FusedMoEPrepareAndFinalize implementation relies on the output buffer. Consider either always allocating an output tensor when inplace is False (as before), or updating the prepare/finalize interface to accept Optional[torch.Tensor] and handling None explicitly in all implementations.

Suggested change:
-   output = None
+   output = torch.empty_like(hidden_states)

Copilot uses AI. Check for mistakes.
Comment on lines +587 to +590
def get_next_batch_info(self) -> tuple[bool, int, int]:
    if self.waiting:
        # new request is waiting, will do prefill
        seq = self.waiting[0]
        num_tokens = seq.num_tokens - seq.num_cached_tokens
        return (True, num_tokens)
    num_reqs = 0

Copilot AI Apr 8, 2026


get_next_batch_info() now returns a 3-tuple, but the existing unit tests still expect the old (bool, int) shape (see tests/test_scheduler.py:323-336). Update/add tests to validate the new (is_prefill, total_tokens, num_reqs) contract for empty/waiting/running cases (and any edge cases introduced by the new batching loop).
