
TE EP integration to MoEBlock#3116

Draft
tdophung wants to merge 61 commits into
NVIDIA:mainfrom
tdophung:teddy/te_ep_integration

Conversation

@tdophung tdophung commented Jun 10, 2026

Collaborator

Description

Will rebase and squash the commits on this branch when it is about to merge.
Will also change the JAX APIs if needed once TE EP JAX merges.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

phu0ngng added 2 commits June 9, 2026 18:27
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…em_reloc gating

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@tdophung tdophung force-pushed the teddy/te_ep_integration branch from 0ff3bff to bd14fe6 Compare June 10, 2026 21:58
Comment thread transformer_engine/jax/flax/moe.py Outdated
Comment thread transformer_engine/jax/flax/moe.py Outdated
# Minimum per-expert slot alignment fed to ``tex.ep_prepare``. Default 0
# uses the natural slot count; set to e.g. 128 to satisfy FP8 grouped-GEMM
# tile alignment.
align_size: int = 0

Collaborator

Placeholder comment for me to fix this so align_size is inferred automatically based on the recipe and doesn't need to be specified by the user

Comment thread transformer_engine/jax/flax/moe.py Outdated
nn.with_logical_partitioning(self.bias_init, ("exp",)),
(self.num_experts,),
self.dtype,
jnp.float32,

Collaborator

Is the router always in fp32, so this expert bias must also be? If so, can we add a small comment indicating this?

Collaborator Author

yes I will add a comment



__all__ = ["moe", "PermutationBackend"]
def _with_sharding_constraint_cast_bwd(x: jnp.ndarray, sharding) -> jnp.ndarray:

Collaborator

Why do we need this utility function? I haven't seen something like this required for our other VJPs

Collaborator Author

For MoE in particular: in our _moe_bwd_rule, d_x is built from two cotangent paths:

d_x_from_dispatch = tex.ep_dispatch_bwd(...), which is bf16 if x is bf16,
and
d_x_from_gate = d_logits_2d @ gate_kernel^T, where d_logits_2d comes from tex.fused_topk_with_score_function_bwd, which only outputs the logits in fp32 (per alp in one of the weekly meetings, and I checked the code for the router kernel). So d_x_from_gate is fp32.

Therefore d_x = d_x_from_dispatch + d_x_from_gate = bf16 + fp32 = fp32, while our x could be bf16. So the cotangent that flows to the bwd needs to be constrained (cast back to the dtype of x). Other VJPs don't have this issue because the dgrad is already in the same dtype as the activation.
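The promotion described in this reply can be reproduced in a few lines. This is only a sketch under stated assumptions, not the PR's code: NumPy has no bfloat16, so float16 stands in for the bf16 activation dtype, and the helper name `cast_cotangent_like` is hypothetical.

```python
import numpy as np


def cast_cotangent_like(cotangent, primal):
    """Cast a cotangent back to the dtype of the primal it belongs to.

    Hypothetical helper for illustration; it is not the PR's
    _with_sharding_constraint_cast_bwd.
    """
    return cotangent.astype(primal.dtype)


# Assumption: float16 stands in for bf16 because NumPy lacks bfloat16.
x = np.ones((2, 4), dtype=np.float16)

# Two cotangent paths with different dtypes, as in the reply above.
d_x_from_dispatch = np.full((2, 4), 0.5, dtype=np.float16)  # low-precision path
d_x_from_gate = np.full((2, 4), 0.25, dtype=np.float32)     # fp32 router path

# Adding them promotes the result to float32 ...
d_x = d_x_from_dispatch + d_x_from_gate

# ... so it is cast back to the activation dtype before being returned.
d_x_cast = cast_cotangent_like(d_x, x)
```

The sum is float32 even though `x` is half precision, which is the mismatch the reply describes; the cast restores the activation dtype.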

# is a frozen dataclass of ints); the rest are jnp.ndarray,
# GroupedNoScaleTensor (already a pytree), or None when aux_loss_coeff == 0.
@register_pytree_node_class
@dataclass

Collaborator

I think this tree_flatten was from my patch, but looking at the diff I think it'd be better to use the @flax_struct.dataclass you were using on the permutation dataclasses since that seems to auto-populate a default pytree flatten/unflatten for us

Comment thread transformer_engine/jax/moe.py Outdated
else:
d_recv_w_from_intermediate = jnp.zeros_like(recv_w_flat)

# Activation bwd. Mirror the fwd's fp32 promotion of silu+multiply

Collaborator

Why is this dtype casting required? I don't recall us needing it for the non-MoE LNMLP block

# local expert. We must size to that worst case or NCCL EP's HT kernel
# rejects the dispatch buffer with ``invalid argument``.
natural_spe = num_ep * max_tokens_per_rank # = (B // dp_size) * S
# NCCL EP requires each expert-major output block to be at least

Collaborator

Do we have a use-case for user-specified alignments beyond 128 currently? If NCCL EP requires an alignment of at least 128, and since an alignment of 128 is sufficient for all TE grouped GEMM types, would it make sense to instead hardcode _ALIGN_SIZE = 128 as a constant at the top of the file for now to simplify this MoEBlock API?

We can always expand the API to support a user-specified align size in the future
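A minimal sketch of what a hardcoded alignment constant could look like. It assumes that "alignment" here means rounding each per-expert slot count up to the next multiple of the constant; the helper `align_up` and the sample counts are hypothetical and do not come from the PR.

```python
_ALIGN_SIZE = 128  # module-level constant, as proposed in the comment above


def align_up(count: int, align: int = _ALIGN_SIZE) -> int:
    """Round count up to the next multiple of align (hypothetical helper)."""
    if align <= 0:
        raise ValueError("align must be positive")
    return ((count + align - 1) // align) * align


# Illustrative per-expert slot counts before and after alignment.
slots = [align_up(n) for n in (1, 128, 129, 300)]
```

With a single constant, callers no longer pass an alignment argument, which is the API simplification the comment asks for.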

batch_pspec_axis = (*data_parallelism_axes, ep_axis)
ep3_spec = P(batch_pspec_axis, None, None)
ep2_spec = P(batch_pspec_axis, None)
x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, ep3_spec))

Collaborator

Which axis name inputs are physical mesh axes and which can be logical axes? I see above x = with_sharding_constraint_by_logical_axes(x, input_axes) but here we directly use jax.lax.with_sharding_constraint which only supports mesh axes.

No need to make any changes for now, I just want to assess which are which and then we can discuss if it makes sense to support logical on some/all or if some are required to be physical axes. Thanks!

Comment thread transformer_engine/jax/moe.py Outdated
# `grad_pre_combine * w` sees them. Padded positions in sparse_probs
# are already zero (routing_map is False there); only the rare
# underflow path emits NaN.
sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, sparse_probs).astype(dtype)

Collaborator

Is this NaN filtering a debugging artifact or something we need in the final version?

Collaborator Author

Debugging artifact. Removing.

tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 12, 2026
…IA#3116)

Address jberchtold-nvidia's PR NVIDIA#3116 nit "rename use_bias ->
use_ffn_bias and use_expert_bias -> use_expert_routing_bias". The
two flags are siblings (they enable two different bias buffers) but
the old names suggested ``use_bias`` was the general fallback, which
wasn't the intent. The new names make the FFN-vs-routing distinction
obvious from the call site.

* transformer_engine/jax/flax/moe.py
    use_bias -> use_ffn_bias  (dataclass field + branch in __call__
    + docstring entry)
    use_expert_bias -> use_expert_routing_bias  (same)
* tests/jax/test_te_ep_moe.py
    _make_block(use_expert_bias=...) -> use_expert_routing_bias
    sigmoid-bias-strong config key updated
    _reference_kwargs_from_config now reads use_expert_routing_bias

``_MoEBlock`` is still the experimental underscore-prefixed alias
(no public ``MoEBlock`` export yet), so the rename is API-safe.

The pre-resync legacy tests (``test_moe_vjp.py``,
``test_multiprocess_moe_vjp.py``) are intentionally not updated --
they already reference removed APIs like ``PermutationBackend`` and
need a separate post-resync cleanup pass.

Signed-off-by: tdophung <tdophung@nvidia.com>
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 12, 2026
…and inline justifications)

Responds to jberchtold-nvidia's PR NVIDIA#3116 review threads on
``transformer_engine/jax/moe.py``. All changes are confined to a
single file because each review thread targets a localized region
and splitting mid-file would risk reordering bugs.

Per review thread:

1. "Why do we need _with_sharding_constraint_cast_bwd? I haven't
    seen something like this required for our other VJPs."
   -- Expand the helper's docstring to spell out exactly why MoE
   needs it: unlike LN+MLP, the MoE bwd composes a bf16 cotangent
   from ep_dispatch_bwd with an fp32 cotangent from
   fused_topk_with_score_function_bwd (which the fwd's
   logits_2d -> fp32 promotion forces). Without the cast, ``d_x``
   surfaces at fp32 even when ``x`` is bf16, doubling activation
   grad bandwidth and breaking any downstream LN bwd that pins a
   bf16 layout. (Review thread "Why do we need this utility
   function?".)

2. "Why is this dtype casting required? I don't recall us needing
    it for the non-MoE LNMLP block."
   -- Expand the comment above the bwd activation fp32 promotion
   to explain the MoE-specific math: LN+MLP's silu sits behind a
   downstream LN that absorbs the bf16 rounding error, while
   MoE's silu sits on the *expert* side of routing -- the bf16
   rounding rides directly into expert_outputs and is summed
   across topk experts by ep_combine. Bf16 silu alone drifts ~1%
   vs fp32 silu and compounds through wo->combine into the ~1.4%
   per-element parity gap we measured against the pure-JAX
   softmax reference. Mirroring the fwd's fp32 promotion in the
   bwd keeps silu' in lock-step with silu. (Review thread on
   "# Activation bwd. Mirror the fwd's fp32 promotion of
   silu+multiply".)

3. "Do we have a use-case for user-specified alignments beyond
    128 currently? ... it'd make sense to instead hardcode
    _ALIGN_SIZE = 128 as a constant at the top of the file for
    now to simplify this MoEBlock API. We can always expand the
    API to support a user-specified align size in the future."
   -- Implement the suggestion. Drop ``align_size`` from
   ``_moe_fwd_rule`` / ``_moe_bwd_rule`` / ``_moe`` / public
   ``moe()``; shift the ``custom_vjp`` ``nondiff_argnums`` from
   ``range(9, 27)`` -> ``range(9, 26)``; replace ``effective_align
   = max(int(align_size), 128)`` with the new module-level
   ``_ALIGN_SIZE = 128`` constant. Trim the ``moe()`` docstring
   accordingly. (Review thread on
   "natural_spe = num_ep * max_tokens_per_rank".)

4. "Which axis name inputs are physical mesh axes and which can be
    logical axes? ... No need to make any changes for now, I just
    want to assess which are which and then we can discuss if it
    makes sense to support logical on some/all or if some are
    required to be physical axes."
   -- Add an "Axis-name parameters" section to ``moe()``'s
   docstring listing which kwargs are physical mesh axes
   (``ep_axis``, ``data_parallelism_axes`` -- they index
   ``Mesh.shape`` directly to compute ``num_ep`` / ``dp_size``
   and to construct the ``P((dp..., ep), None, None)`` for
   ``jax.lax.with_sharding_constraint``) vs logical axes
   (``input_axes``, ``gate_kernel_axes``, ``wi_kernel_axes``,
   ``wo_kernel_axes`` -- resolved via the Flax logical-axis
   rules). Also document why ``ep_axis`` / ``data_parallelism_axes``
   are intentionally non-logical: the EP comm-group construction
   (``dp_color = rank // ep_size``) and the bootstrap signature
   check both require concrete integer sizes. (Review thread on
   "batch_pspec_axis = (*data_parallelism_axes, ep_axis)".)

5. "Is this NaN filtering a debugging artifact or something we
    need in the final version?"
   -- Strengthen the inline comment above
   ``sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, ...)``
   to explicitly call this out as a CORRECTNESS REQUIREMENT, not
   a debugging artifact: it covers the sigmoid+K>1 underflow
   path where top-K sigmoid scores all round to zero and the
   ``weights / (weights.sum + 1e-20)`` normalisation emits NaN.
   Observationally the filter is a no-op on the dense unit-test
   distributions, but it must stay in for sparse / production
   routing. (Review thread on
   "sparse_probs = jnp.where(jnp.isnan(sparse_probs), ...).")

Not addressed in this commit (intentional):

* Review thread on the ``align_size: int = 0`` placeholder in
  ``flax/moe.py`` ("Placeholder comment for me to fix this so
  align_size is inferred automatically based on the recipe and
  doesn't need to be specified by the user"). That's
  jberchtold's own follow-up.
* Review thread on the explicit ``tree_flatten`` /
  ``tree_unflatten`` on ``_Ctx`` ("better to use the
  ``@flax_struct.dataclass``"). Deferred to a separate, testable
  commit because changing a ``custom_vjp`` residual's pytree
  registration touches subtle ordering / None-handling semantics
  that warrant their own bisect surface.
* Review thread on ``use_bias`` / ``use_expert_bias`` renames --
  handled in the immediately preceding commit
  ``jax/flax,tests: rename use_bias/use_expert_bias for symmetry``.
* Review thread on the ``expert_bias`` fp32 init -- already
  resolved during the Phuong PR NVIDIA#3036 resync (the redundant
  ``jnp.float32`` second-dtype argument on ``self.param`` was
  dropped; ``expert_bias`` now lives at ``self.dtype``).
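Item 4 above says the physical axis names are used to index the mesh shape and to derive ``dp_color = rank // ep_size``. The sketch below illustrates why that needs concrete mesh axis names. It is an illustration only: a plain dict stands in for ``Mesh.shape``, the function ``ep_layout`` is hypothetical, and it assumes ``ep_size`` equals the size of ``ep_axis``.

```python
def ep_layout(mesh_shape, ep_axis, data_parallelism_axes, rank):
    """Derive integer sizes from physical mesh axis names (hypothetical helper).

    mesh_shape is a plain dict standing in for Mesh.shape (axis name -> size).
    """
    num_ep = mesh_shape[ep_axis]  # KeyError if ep_axis is not a mesh axis
    dp_size = 1
    for axis in data_parallelism_axes:
        dp_size *= mesh_shape[axis]
    # Colour formula quoted in the commit message; ep_size == num_ep is assumed.
    dp_color = rank // num_ep
    return num_ep, dp_size, dp_color


mesh_shape = {"dp": 2, "ep": 4}  # illustrative 2 x 4 mesh
layout = ep_layout(mesh_shape, "ep", ("dp",), rank=5)
```

A logical name that is not a key of the mesh shape fails the lookup, since there is no integer size to read until a rule maps it to a mesh axis.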

Signed-off-by: tdophung <tdophung@nvidia.com>
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 12, 2026
…aN sanitizer

* Rewrite the inline justifications added in 078a7d80 so each one
  reads as standalone code documentation, not as a reply to a
  reviewer: drop "per PR NVIDIA#3116 review", "review feedback",
  "Renamed from ... per PR ..." and similar PR/thread references
  from moe.py, flax/moe.py, and tests/jax/test_te_ep_moe.py.
  Technical content (why the fp32 promotion is needed for the MoE
  silu+multiply, why _with_sharding_constraint_cast_bwd exists,
  physical-vs-logical axis split in moe() docstring, the 128
  alignment rationale) is preserved and reframed to be useful to
  a reader who has no PR context.

* Drop the jnp.where(jnp.isnan(sparse_probs), 0, sparse_probs)
  guard. Tracing fused_topk_with_score_function.cu shows the
  kernel divides by sum_scores + 1e-20, so finite non-negative
  sigmoid scores cannot produce NaN here; the filter was only
  defense against upstream NaNs, which would mask a real
  regression if anything ever did start producing them.
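The argument in the second bullet can be checked with plain floats. This is a sketch, not the kernel: ``normalise_topk`` is a hypothetical stand-in for the normalisation step, and Python floats are fp64 rather than whatever precision the kernel uses.

```python
import math


def normalise_topk(scores, eps=1e-20):
    """Divide each score by (sum of scores + eps), as described above."""
    total = sum(scores)
    return [s / (total + eps) for s in scores]


# Every top-K score underflowed to zero: 0 / (0 + eps) is 0, not NaN.
underflowed = normalise_topk([0.0, 0.0])

# Ordinary finite, non-negative scores normalise as expected.
typical = normalise_topk([0.25, 0.75])

# A NaN that arrives from upstream still propagates; this is the only
# case the removed guard would have hidden.
poisoned = normalise_topk([float("nan"), 1.0])
```

Without the guard, an upstream NaN stays visible instead of being silently replaced with zero.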

Signed-off-by: tdophung <tdophung@nvidia.com>
pre-commit-ci Bot and others added 23 commits June 11, 2026 17:15
for more information, see https://pre-commit.ci

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…rce at dispatch

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… static layer registration

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…er + NVTEEpHandle struct (NVTE_EP_HANDLE_CACHE_SIZE=-1 disables eviction)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…CCL_EP

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…hout it

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ogging.h

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…_COPY_{ON,OFF}

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…tyAllSymm

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…CUDA Toolkit)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… for wheel install

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…bmodules

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…rop submodule header mirror

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…al CommWindow

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
phu0ngng and others added 29 commits June 11, 2026 17:15
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…lint runtime/int)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…pe lifetime)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: tdophung <tdophung@nvidia.com>
…16 max_token_dtype

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… with_sharding_constraint

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…trap

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…EpLayerConfig type)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ives (lint 10.00)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…; define NVTE_WITH_NCCL_EP

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ract, drop dead helpers

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Reset 33 local commits onto phuong/ep-3-jax @ c34771d (her latest with
EpConfig + EpLayerConfig API, NCCL bumped to 808d2433) and re-applied
the three deltas uniquely ours:

  * transformer_engine/jax/moe.py: replaces upstream's multi-backend
    MoE block with our TE-EP-only single-custom-vjp rewrite. Adapted
    to her new API surface: tex.EpLayerConfig replaces tex.ep_make_handle
    (no more EpHandle pool/cache); 5 EP callsites rewired (cfg passed
    in place of handle, ep_prepare arg order swapped, top_k= dropped
    from ep_dispatch_bwd since it's now in cfg).
  * tests/jax/test_te_ep_moe.py: TE-EP MoE test (kept), with
    ep_bootstrap kwargs ep_size= and allow_handle_mem_reloc= dropped
    (no longer supported; ep_size is derived from mesh axes and the
    handle_mem reloc gating is gone).
  * tests/jax/run_te_ep_moe.sh: multi-process launcher (kept).

Pre-sync state preserved at branch
teddy/te_ep_integration.backup-pre-phuong-sync.

Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: tdophung <tdophung@nvidia.com>
…ero)

* drop ``TestZZZTeEpMoeBootstrap``: the re-bootstrap mismatch is a
  one-line guard in ``ep_bootstrap`` and not the MoE block's concern;
  exercising it from this suite also taints the per-process NCCL
  bootstrap cache for the rest of the file with no real upside.
* drop ``TestTeEpMoEBlockFlax::test_init_apply_parity``: every config
  in ``_CONFIGS`` already runs ``MoEBlock`` (the Flax wrapper)
  end-to-end via ``test_forward`` / ``test_backward``, so this was a
  duplicate of ``softmax`` parity in another wrapper -- leave wrapper
  refactors to devs without paying for an extra CI run each time.
* drop ``sigmoid-bias-zero``: with a zero-init bias buffer the routing
  math collapses to the no-bias case, so ``sigmoid`` already covers
  that numerical path. The bias-aware codepath is still exercised by
  ``sigmoid-bias-strong`` (non-zero bias).
* refresh the module-level docstring to list intentional
  non-coverage so future readers don't re-add these tests.

Signed-off-by: tdophung <tdophung@nvidia.com>
… closure)

Two unrelated one-line bugs in the bwd custom_partitioning machinery
that only surface once the MoE block's aux-loss path is lifted out of
shard_map (the custom_partitioning_sharding_rule check is skipped under
shard_map, which is why these never tripped before).

1. FusedMoEAuxLossBwdPrimitive.shardy_sharding_rule:
   ``grad_aux_loss`` is the cotangent of a scalar loss and is rank-0;
   declaring it with a spurious ``grad_one`` factor gave it rank-1 and
   tripped JAX's custom_partitioning_sharding_rule rank check at global
   view. Change the rule's third operand entry to empty:

     "const_buf_one, num_experts, grad_one -> i num_experts"
   ->
     "const_buf_one, num_experts, -> i num_experts"

2. FusedTopkWithScoreFunctionBwdPrimitive.partition:
   ``del result_infos, routing_map_format`` removed
   ``routing_map_format`` from the enclosing scope before the nested
   ``sharded_impl`` closure was invoked. Python closures resolve names
   at call time, not definition time, so when XLA finally invoked
   ``sharded_impl`` for the bwd partitioned impl it raised
   ``NameError: cannot access free variable 'routing_map_format'``.
   Drop ``routing_map_format`` from the ``del`` and leave a NOTE so
   future cleanups don't reintroduce the bug. Sibling partition
   methods (fwd topk, both aux-loss directions) already only
   ``del result_infos`` and need no change.
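The closure bug in item 2 can be reproduced without JAX. The sketch below uses hypothetical function names and a placeholder value for ``routing_map_format``; only the ``del``-before-call pattern is taken from the commit message.

```python
def partition_buggy():
    routing_map_format = "placeholder"  # illustrative value (assumption)

    def sharded_impl():
        # The free variable is looked up when this function is called.
        return routing_map_format

    del routing_map_format  # unbinds the name the closure still needs
    return sharded_impl


def partition_fixed():
    result_infos = object()
    routing_map_format = "placeholder"

    def sharded_impl():
        return routing_map_format

    del result_infos  # NOTE: do not delete names used by sharded_impl
    return sharded_impl


try:
    partition_buggy()()
    buggy_outcome = "no error"
except NameError as exc:
    buggy_outcome = type(exc).__name__

fixed_outcome = partition_fixed()()
```

Defining ``sharded_impl`` succeeds in both versions; the buggy one fails only when the returned closure is finally invoked, which matches the late failure described above.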

Signed-off-by: tdophung <tdophung@nvidia.com>
A dp_resource or fsdp_resource that exists in the active mesh resource
config but is sized 1 in the actual mesh would still be returned by
``_ep_outer_axis()``, pinning EP-output PartitionSpecs to a degenerate
axis. JAX collapses size-1 mesh axes during lowering, which made the
EP-output specs reference an axis that no longer exists at runtime --
breaking shard_map output stitching on configs where DP or FSDP is
optional.

Treat a size-1 axis as absent: prefer dp -> fsdp, but only when the
candidate axis is actually sized > 1 in the current mesh. Falls back
to the previous behaviour when no axis is configured at all.
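A sketch of the selection rule described above, with hypothetical names: ``pick_outer_axis`` is not the real ``_ep_outer_axis()``, a plain dict stands in for the mesh shape, and returning ``None`` when nothing qualifies is an assumption about the fallback.

```python
from typing import Mapping, Optional


def pick_outer_axis(
    mesh_shape: Mapping[str, int],
    dp_resource: Optional[str],
    fsdp_resource: Optional[str],
) -> Optional[str]:
    """Prefer dp, then fsdp, skipping axes that are missing or sized 1."""
    for axis in (dp_resource, fsdp_resource):
        if axis is not None and mesh_shape.get(axis, 1) > 1:
            return axis
    return None  # assumption: caller handles the "no usable axis" case


# dp is configured but degenerate, so fsdp is chosen instead.
chosen = pick_outer_axis({"dp": 1, "fsdp": 4, "ep": 2}, "dp", "fsdp")
```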

Signed-off-by: tdophung <tdophung@nvidia.com>
After the upstream PR NVIDIA#3036 resync the moe() API surface lost
PermutationBackend (TE-EP is the only backend now), gate_inside_vjp
(always True), and the per-call quantizer_sets knob (quantization
flows through the standard TE autocast / with_quantizer_set context).
It also gained apply_topk_weights_early and renamed the wrapper's
private _align_size to the public align_size the test suite already
uses. The Flax _MoEBlock wrapper was still passing the old kwargs,
which broke every test that touched the wrapper.

Wrapper changes:
  * drop "from ..moe import PermutationBackend" plus the dataclass
    field, the isinstance(..., PermutationBackend) validation in
    __post_init__, and the pass-through to moe().
  * drop "from ..quantize import noop_quantizer_set" and the
    quantizer_sets=(noop, noop, noop) pass-through.
  * drop gate_inside_vjp=True.
  * rename _align_size: int = 0 -> align_size: int = 0 (matches
    what tests/jax/test_te_ep_moe.py already passes).
  * add apply_topk_weights_early: bool = False and pass it through
    to moe().
  * refresh class docstring: drop permutation_backend / _align_size
    / quantizer_sets descriptions, add apply_topk_weights_early /
    align_size, note that quantization currently flows only through
    fp8_autocast.

Signed-off-by: tdophung <tdophung@nvidia.com>
…ices

Two correctness fixes for the TE-EP MoE custom_vjp that together let
the bwd parity tests pass on 0-token-globally experts, and drop a
workaround that is no longer needed.

(1) Plumb per-expert padded token_counts into grouped_gemm group_sizes.

NCCL EP HT dispatch lays out recv_tokens expert-major as
  [expert_0_padded | expert_1_padded | ... | overalloc_tail]
where each per-expert block already includes the
dispatch_output_per_expert_alignment zero-padding and only the trailing
overalloc tail (slack between sum(token_counts) and the worst-case
recv_pr) is unused. Previously _ffn_fwd_per_shard built a static
local_group_sizes = jnp.full((num_local_experts,), slots_per_expert),
which over-counted by the overalloc tail and forced cuBLAS to run the
GEMM for every group including 0-token-routed experts.

Pipe the real per-shard token_counts (1, num_local_experts) from
ep_prepare through _moe_fwd_rule (added to ffn_in_specs/ffn_in_args
with ep2_spec), into _ffn_fwd_per_shard as token_counts_local, and
reshape into local_group_sizes for both grouped_quantize and
grouped_gemm. cuBLAS now skips both 0-token experts and the trailing
overalloc tail. Mirror the residual spec change on the bwd
(local_group_sizes residual moves from P() to ep2_spec).
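To make the over-count concrete, here is a small sketch comparing the static group sizes with the real per-expert counts. All numbers are illustrative and the variable names only loosely follow the commit message.

```python
num_local_experts = 4
slots_per_expert = 256  # worst-case per-expert capacity (illustrative)

# Padded per-expert token counts, as ep_prepare would report them
# (illustrative values; expert 1 received no tokens).
token_counts = [128, 0, 256, 128]

# Old behaviour: every group is sized to the worst case.
static_group_sizes = [slots_per_expert] * num_local_experts

# New behaviour: group sizes follow the real counts.
real_group_sizes = list(token_counts)

rows_static = sum(static_group_sizes)  # rows the GEMM was asked to cover
rows_real = sum(real_group_sizes)      # rows that actually hold tokens
empty_groups = [g for g, n in enumerate(real_group_sizes) if n == 0]
```

The gap between ``rows_static`` and ``rows_real`` is the overalloc tail plus the empty experts that no longer need to be processed.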

(2) Per-group jnp.where zero-fill on wgrad outputs.

cuBLAS grouped_gemm skips groups with size_g == 0 without zero-filling
the corresponding out[g, :, :] slice (cublaslt_grouped_gemm.cu lines
2086/2096). For a shard hosting an expert that received zero tokens
globally, d_wo / d_wi_combined for that expert is left uninit, which
propagates NaN straight into the user's optimizer state.

Add wgrad_group_active = (local_group_sizes > 0)[:, None, None] in
_ffn_bwd_per_shard and apply via jnp.where on d_wo (right after the wo
wgrad) and d_wi_combined (right after the fused wi_0+wi_1 wgrad).
Mask shape is (num_local_experts, 1, 1) so cost is negligible.
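The masking step can be sketched with NumPy, whose ``np.where`` broadcasts like ``jnp.where`` for this purpose. Shapes and values are illustrative, and the NaN slice merely simulates an output that a skipped group left uninitialised.

```python
import numpy as np

local_group_sizes = np.array([3, 0, 5])  # expert 1 received zero tokens

# Stand-in wgrad with one slice per local expert.
d_wo = np.ones((3, 2, 2), dtype=np.float32)
d_wo[1] = np.nan  # simulate the unwritten slice of the skipped group

# Per-group mask, broadcast over the two trailing weight dimensions.
wgrad_group_active = (local_group_sizes > 0)[:, None, None]

# Keep active groups as-is and zero-fill the inactive ones.
d_wo_clean = np.where(wgrad_group_active, d_wo, 0.0)
```

Because the mask has one entry per expert, the select adds little work compared with the grouped GEMM itself.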

(3) Drop the lax.cond zero-init guard on r_tok in _moe_fwd_rule._body.

Previously a jax.lax.cond(jnp.any(r_w != 0), identity, zeros_like)
wrapper around recv_tokens worked around tex.ep_dispatch_fwd leaving
the recv buffer uninit on fully-empty-receiver ranks. With (1) in
place, cuBLAS skips experts whose group_sizes == 0 and the per-row
trailing tail of dispatched recv_tokens is unread by every downstream
consumer (subsequent grouped_gemms read only sum(group_sizes) rows;
ep_combine and ep_dispatch_bwd are handle_mem-aware). The only
per-row consumer that would propagate the tail is grouped_dbias
(per-row segment_sum), which only runs when has_bias=True, and that
FFN bias path is currently gated upstream (cuBLAS grouped_gemm has
no fused bias on Hopper yet; PR 3083 adds the pure-JAX bias add).
With (2) handling the user-visible wgrad-NaN risk on 0-token experts,
the lax.cond is now redundant. Replace with a NOTE pointing at the
two follow-ups that would force its reintroduction:
  - a future caller that reads the full recv tile non-group-aware
    (e.g. an inspect probe), or
  - the FFN bias path landing, which would resurrect grouped_dbias.

Also rewrite the _ffn_fwd_per_shard and _ffn_bwd_per_shard docstrings
to spell out the per-row vs per-group uninit semantics so the next
person debugging a NaN here has the invariants written down.

Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung force-pushed the teddy/te_ep_integration branch from 68617ea to fe44697 Compare June 12, 2026 00:15

3 participants