TE EP integration to MoEBlock#3116
Conversation
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…em_reloc gating Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
0ff3bff to
bd14fe6
Compare
| # Minimum per-expert slot alignment fed to ``tex.ep_prepare``. Default 0 | ||
| # uses the natural slot count; set to e.g. 128 to satisfy FP8 grouped-GEMM | ||
| # tile alignment. | ||
| align_size: int = 0 |
There was a problem hiding this comment.
Placeholder comment for me to fix this so align_size is inferred automatically based on the recipe and doesn't need to be specified by the user
| nn.with_logical_partitioning(self.bias_init, ("exp",)), | ||
| (self.num_experts,), | ||
| self.dtype, | ||
| jnp.float32, |
There was a problem hiding this comment.
Is the router always in fp32 so this expert bias must also be? If so, can we add a small comment indicating this
There was a problem hiding this comment.
yes I will add a comment
|
|
||
|
|
||
| __all__ = ["moe", "PermutationBackend"] | ||
| def _with_sharding_constraint_cast_bwd(x: jnp.ndarray, sharding) -> jnp.ndarray: |
There was a problem hiding this comment.
Why do we need this utility function? I haven't seen something like this required for our other VJPs
There was a problem hiding this comment.
for MoE particularly, In our _moe_bwd_rule, d_x is built from two cotangent paths:
d_x_from_dispatch = tex.ep_dispatch_bwd(...) in bf16 if x is in bf16.
and
d_x_from_gate = d_logits_2d @ gate_kernel^T, where d_logits_2d comes from tex.fused_topk_with_score_function_bwd, which only ooutput the logits in fp32 (per alp in one of the weekly meetings + I checked the code for the router kernel). So d_x_from_gate is fp32.
Therefore, d_x = d_x_from_dispatch + d_x_from_gate = bf16 + fp32 = fp32, while our x could be bf16. So the cotangent that flows to bwd needs to be constrainted. Other VJPs don't have this issue because the dgrad will just be the same dtype as the activation dtype.
| # is a frozen dataclass of ints); the rest are jnp.ndarray, | ||
| # GroupedNoScaleTensor (already a pytree), or None when aux_loss_coeff == 0. | ||
| @register_pytree_node_class | ||
| @dataclass |
There was a problem hiding this comment.
I think this tree_flatten was from my patch, but looking at the diff I think it'd be better to use the @flax_struct.dataclass you were using on the permutation dataclasses since that seems to auto-populate a default pytree flatten/unflatten for us
| else: | ||
| d_recv_w_from_intermediate = jnp.zeros_like(recv_w_flat) | ||
|
|
||
| # Activation bwd. Mirror the fwd's fp32 promotion of silu+multiply |
There was a problem hiding this comment.
Why is this dtype casting required? I don't recall us needing it for the non-MoE LNMLP block
| # local expert. We must size to that worst case or NCCL EP's HT kernel | ||
| # rejects the dispatch buffer with ``invalid argument``. | ||
| natural_spe = num_ep * max_tokens_per_rank # = (B // dp_size) * S | ||
| # NCCL EP requires each expert-major output block to be at least |
There was a problem hiding this comment.
Do we have a use-case for user-specified alignments beyond 128 currently? If NCCL EP requires an alignment of at least 128, and since an alignment of 128 is sufficient for all TE grouped GEMM types, would it make sense to instead hardcode _ALIGN_SIZE = 128 as a constant at the top of the file for now to simplify this MoEBlock API.
We can always expand the API to support a user-specified align size in the future
| batch_pspec_axis = (*data_parallelism_axes, ep_axis) | ||
| ep3_spec = P(batch_pspec_axis, None, None) | ||
| ep2_spec = P(batch_pspec_axis, None) | ||
| x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, ep3_spec)) |
There was a problem hiding this comment.
Which axis name inputs are physical mesh axes and why can be logical axes? I see above x = with_sharding_constraint_by_logical_axes(x, input_axes) but here we directly use jax.lax.with_sharding_constraint which only supports mesh axes.
No need to make any changes for now, I just want to assess which are which and then we can discuss if it makes sense to support logical on some/all or if some are required to be physical axes. Thanks!
| # `grad_pre_combine * w` sees them. Padded positions in sparse_probs | ||
| # are already zero (routing_map is False there); only the rare | ||
| # underflow path emits NaN. | ||
| sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, sparse_probs).astype(dtype) |
There was a problem hiding this comment.
Is this NaN filtering a debugging artifact or something we need in the final version?
There was a problem hiding this comment.
debugging artifact. Remopving
…IA#3116) Address jberchtold-nvidia's PR NVIDIA#3116 nit "rename use_bias -> use_ffn_bias and use_expert_bias -> use_expert_routing_bias". The two flags are siblings (they enable two different bias buffers) but the old names suggested ``use_bias`` was the general fallback, which wasn't the intent. The new names make the FFN-vs-routing distinction obvious from the call site. * transformer_engine/jax/flax/moe.py use_bias -> use_ffn_bias (dataclass field + branch in __call__ + docstring entry) use_expert_bias -> use_expert_routing_bias (same) * tests/jax/test_te_ep_moe.py _make_block(use_expert_bias=...) -> use_expert_routing_bias sigmoid-bias-strong config key updated _reference_kwargs_from_config now reads use_expert_routing_bias ``_MoEBlock`` is still the experimental underscore-prefixed alias (no public ``MoEBlock`` export yet), so the rename is API-safe. The pre-resync legacy tests (``test_moe_vjp.py``, ``test_multiprocess_moe_vjp.py``) are intentionally not updated -- they already reference removed APIs like ``PermutationBackend`` and need a separate post-resync cleanup pass. Signed-off-by: tdophung <tdophung@nvidia.com>
…and inline justifications) Responds to jberchtold-nvidia's PR NVIDIA#3116 review threads on ``transformer_engine/jax/moe.py``. All changes are confined to a single file because each review thread targets a localized region and splitting mid-file would risk reordering bugs. Per review thread: 1. "Why do we need _with_sharding_constraint_cast_bwd? I haven't seen something like this required for our other VJPs." -- Expand the helper's docstring to spell out exactly why MoE needs it: unlike LN+MLP, the MoE bwd composes a bf16 cotangent from ep_dispatch_bwd with an fp32 cotangent from fused_topk_with_score_function_bwd (which the fwd's logits_2d -> fp32 promotion forces). Without the cast, ``d_x`` surfaces at fp32 even when ``x`` is bf16, doubling activation grad bandwidth and breaking any downstream LN bwd that pins a bf16 layout. (Review thread "Why do we need this utility function?".) 2. "Why is this dtype casting required? I don't recall us needing it for the non-MoE LNMLP block." -- Expand the comment above the bwd activation fp32 promotion to explain the MoE-specific math: LN+MLP's silu sits behind a downstream LN that absorbs the bf16 rounding error, while MoE's silu sits on the *expert* side of routing -- the bf16 rounding rides directly into expert_outputs and is summed across topk experts by ep_combine. Bf16 silu alone drifts ~1% vs fp32 silu and compounds through wo->combine into the ~1.4% per-element parity gap we measured against the pure-JAX softmax reference. Mirroring the fwd's fp32 promotion in the bwd keeps silu' in lock-step with silu. (Review thread on "# Activation bwd. Mirror the fwd's fp32 promotion of silu+multiply".) 3. "Do we have a use-case for user-specified alignments beyond 128 currently? ... it'd make sense to instead hardcode _ALIGN_SIZE = 128 as a constant at the top of the file for now to simplify this MoEBlock API. We can always expand the API to support a user-specified align size in the future." -- Implement the suggestion. Drop ``align_size`` from ``_moe_fwd_rule`` / ``_moe_bwd_rule`` / ``_moe`` / public ``moe()``; shift the ``custom_vjp`` ``nondiff_argnums`` from ``range(9, 27)`` -> ``range(9, 26)``; replace ``effective_align = max(int(align_size), 128)`` with the new module-level ``_ALIGN_SIZE = 128`` constant. Trim the ``moe()`` docstring accordingly. (Review thread on "natural_spe = num_ep * max_tokens_per_rank".) 4. "Which axis name inputs are physical mesh axes and why can be logical axes? ... No need to make any changes for now, I just want to assess which are which and then we can discuss if it makes sense to support logical on some/all or if some are required to be physical axes." -- Add an "Axis-name parameters" section to ``moe()``'s docstring listing which kwargs are physical mesh axes (``ep_axis``, ``data_parallelism_axes`` -- they index ``Mesh.shape`` directly to compute ``num_ep`` / ``dp_size`` and to construct the ``P((dp..., ep), None, None)`` for ``jax.lax.with_sharding_constraint``) vs logical axes (``input_axes``, ``gate_kernel_axes``, ``wi_kernel_axes``, ``wo_kernel_axes`` -- resolved via the Flax logical-axis rules). Also document why ``ep_axis`` / ``data_parallelism_axes`` are intentionally non-logical: the EP comm-group construction (``dp_color = rank // ep_size``) and the bootstrap signature check both require concrete integer sizes. (Review thread on "batch_pspec_axis = (*data_parallelism_axes, ep_axis)".) 5. "Is this NaN filtering a debugging artifact or something we need in the final version?" -- Strengthen the inline comment above ``sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, ...)`` to explicitly call this out as a CORRECTNESS REQUIREMENT, not a debugging artifact: it covers the sigmoid+K>1 underflow path where top-K sigmoid scores all round to zero and the ``weights / (weights.sum + 1e-20)`` normalisation emits NaN. Observationally the filter is a no-op on the dense unit-test distributions, but it must stay in for sparse / production routing. (Review thread on "sparse_probs = jnp.where(jnp.isnan(sparse_probs), ...).") Not addressed in this commit (intentional): * Review thread on the ``align_size: int = 0`` placeholder in ``flax/moe.py`` ("Placeholder comment for me to fix this so align_size is inferred automatically based on the recipe and doesn't need to be specified by the user"). That's jberchtold's own follow-up. * Review thread on the explicit ``tree_flatten`` / ``tree_unflatten`` on ``_Ctx`` ("better to use the ``@flax_struct.dataclass``"). Deferred to a separate, testable commit because changing a ``custom_vjp`` residual's pytree registration touches subtle ordering / None-handling semantics that warrant their own bisect surface. * Review thread on ``use_bias`` / ``use_expert_bias`` renames -- handled in the immediately preceding commit ``jax/flax,tests: rename use_bias/use_expert_bias for symmetry``. * Review thread on the ``expert_bias`` fp32 init -- already resolved during the Phuong PR NVIDIA#3036 resync (the redundant ``jnp.float32`` second-dtype argument on ``self.param`` was dropped; ``expert_bias`` now lives at ``self.dtype``). Signed-off-by: tdophung <tdophung@nvidia.com>
…aN sanitizer * Rewrite the inline justifications added in 078a7d80 so each one reads as standalone code documentation, not as a reply to a reviewer: drop "per PR NVIDIA#3116 review", "review feedback", "Renamed from ... per PR ..." and similar PR/thread references from moe.py, flax/moe.py, and tests/jax/test_te_ep_moe.py. Technical content (why the fp32 promotion is needed for the MoE silu+multiply, why _with_sharding_constraint_cast_bwd exists, physical-vs-logical axis split in moe() docstring, the 128 alignment rationale) is preserved and reframed to be useful to a reader who has no PR context. * Drop the jnp.where(jnp.isnan(sparse_probs), 0, sparse_probs) guard. Tracing fused_topk_with_score_function.cu shows the kernel divides by sum_scores + 1e-20, so finite non-negative sigmoid scores cannot produce NaN here; the filter was only defense against upstream NaNs, which would mask a real regression if anything ever did start producing them. Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…rce at dispatch Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… static layer registration Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…er + NVTEEpHandle struct (NVTE_EP_HANDLE_CACHE_SIZE=-1 disables eviction) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…CCL_EP Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…hout it Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ogging.h Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…_COPY_{ON,OFF}
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…tyAllSymm Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…CUDA Toolkit) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… for wheel install Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…bmodules Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…rop submodule header mirror Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…al CommWindow Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…lint runtime/int) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…pe lifetime) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci Signed-off-by: tdophung <tdophung@nvidia.com>
…16 max_token_dtype Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… with_sharding_constraint Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…trap Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…EpLayerConfig type) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ives (lint 10.00) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…; define NVTE_WITH_NCCL_EP Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ract, drop dead helpers Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Reset 33 local commits onto phuong/ep-3-jax @ c34771d (her latest with EpConfig + EpLayerConfig API, NCCL bumped to 808d2433) and re-applied the three deltas uniquely ours: * transformer_engine/jax/moe.py: replaces upstream's multi-backend MoE block with our TE-EP-only single-custom-vjp rewrite. Adapted to her new API surface: tex.EpLayerConfig replaces tex.ep_make_handle (no more EpHandle pool/cache); 5 EP callsites rewired (cfg passed in place of handle, ep_prepare arg order swapped, top_k= dropped from ep_dispatch_bwd since it's now in cfg. * tests/jax/test_te_ep_moe.py: TE-EP MoE test (kept), with ep_bootstrap kwargs ep_size= and allow_handle_mem_reloc= dropped (no longer supported; ep_size is derived from mesh axes and the handle_mem reloc gating is gone). * tests/jax/run_te_ep_moe.sh: multi-process launcher (kept). Pre-sync state preserved at branch teddy/te_ep_integration.backup-pre-phuong-sync. EOF ) Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci Signed-off-by: tdophung <tdophung@nvidia.com>
…ero) * drop ``TestZZZTeEpMoeBootstrap``: the re-bootstrap mismatch is a one-line guard in ``ep_bootstrap`` and not the MoE block's concern; exercising it from this suite also taints the per-process NCCL bootstrap cache for the rest of the file with no real upside. * drop ``TestTeEpMoEBlockFlax::test_init_apply_parity``: every config in ``_CONFIGS`` already runs ``MoEBlock`` (the Flax wrapper) end-to-end via ``test_forward`` / ``test_backward``, so this was a duplicate of ``softmax`` parity in another wrapper -- leave wrapper refactors to devs without paying for an extra CI run each time. * drop ``sigmoid-bias-zero``: with a zero-init bias buffer the routing math collapses to the no-bias case, so ``sigmoid`` already covers that numerical path. The bias-aware codepath is still exercised by ``sigmoid-bias-strong`` (non-zero bias). * refresh the module-level docstring to list intentional non-coverage so future readers don't re-add these tests. Signed-off-by: tdophung <tdophung@nvidia.com>
… closure)
Two unrelated one-line bugs in the bwd custom_partitioning machinery
that only surface once the MoE block's aux-loss path is lifted out of
shard_map (the custom_partitioning_sharding_rule check is skipped under
shard_map, which is why these never tripped before).
1. FusedMoEAuxLossBwdPrimitive.shardy_sharding_rule:
``grad_aux_loss`` is the cotangent of a scalar loss and is rank-0;
declaring it with a spurious ``grad_one`` factor gave it rank-1 and
tripped JAX's custom_partitioning_sharding_rule rank check at global
view. Change the rule's third operand entry to empty:
"const_buf_one, num_experts, grad_one -> i num_experts"
->
"const_buf_one, num_experts, -> i num_experts"
2. FusedTopkWithScoreFunctionBwdPrimitive.partition:
``del result_infos, routing_map_format`` removed
``routing_map_format`` from the enclosing scope before the nested
``sharded_impl`` closure was invoked. Python closures resolve names
at call time, not definition time, so when XLA finally invoked
``sharded_impl`` for the bwd partitioned impl it raised
``NameError: cannot access free variable 'routing_map_format'``.
Drop ``routing_map_format`` from the ``del`` and leave a NOTE so
future cleanups don't reintroduce the bug. Sibling partition
methods (fwd topk, both aux-loss directions) already only
``del result_infos`` and need no change.
Signed-off-by: tdophung <tdophung@nvidia.com>
A dp_resource or fsdp_resource that exists in the active mesh resource config but is sized 1 in the actual mesh would still be returned by ``_ep_outer_axis()``, pinning EP-output PartitionSpecs to a degenerate axis. JAX collapses size-1 mesh axes during lowering, which made the EP-output specs reference an axis that no longer exists at runtime -- breaking shard_map output stitching on configs where DP or FSDP is optional. Treat a size-1 axis as absent: prefer dp -> fsdp, but only when the candidate axis is actually sized > 1 in the current mesh. Falls back to the previous behaviour when no axis is configured at all. Signed-off-by: tdophung <tdophung@nvidia.com>
After the upstream PR NVIDIA#3036 resync the moe() API surface lost PermutationBackend (TE-EP is the only backend now), gate_inside_vjp (always True), and the per-call quantizer_sets knob (quantization flows through the standard TE autocast / with_quantizer_set context). It also gained apply_topk_weights_early and renamed the wrapper's private _align_size to the public align_size the test suite already uses. The Flax _MoEBlock wrapper was still passing the old kwargs, which broke every test that touched the wrapper. Wrapper changes: * drop "from ..moe import PermutationBackend" plus the dataclass field, the isinstance(..., PermutationBackend) validation in __post_init__, and the pass-through to moe(). * drop "from ..quantize import noop_quantizer_set" and the quantizer_sets=(noop, noop, noop) pass-through. * drop gate_inside_vjp=True. * rename _align_size: int = 0 -> align_size: int = 0 (matches what tests/jax/test_te_ep_moe.py already passes). * add apply_topk_weights_early: bool = False and pass it through to moe(). * refresh class docstring: drop permutation_backend / _align_size / quantizer_sets descriptions, add apply_topk_weights_early / align_size, note that quantization currently flows only through fp8_autocast. Signed-off-by: tdophung <tdophung@nvidia.com>
…ices
Two correctness fixes for the TE-EP MoE custom_vjp that together let
the bwd parity tests pass on 0-token-globally experts, and drop a
workaround that is no longer needed.
(1) Plumb per-expert padded token_counts into grouped_gemm group_sizes.
NCCL EP HT dispatch lays out recv_tokens expert-major as
[expert_0_padded | expert_1_padded | ... | overalloc_tail]
where each per-expert block already includes the
dispatch_output_per_expert_alignment zero-padding and only the trailing
overalloc tail (slack between sum(token_counts) and the worst-case
recv_pr) is unused. Previously _ffn_fwd_per_shard built a static
local_group_sizes = jnp.full((num_local_experts,), slots_per_expert),
which over-counted by the overalloc tail and forced cuBLAS to run the
GEMM for every group including 0-token-routed experts.
Pipe the real per-shard token_counts (1, num_local_experts) from
ep_prepare through _moe_fwd_rule (added to ffn_in_specs/ffn_in_args
with ep2_spec), into _ffn_fwd_per_shard as token_counts_local, and
reshape into local_group_sizes for both grouped_quantize and
grouped_gemm. cuBLAS now skips both 0-token experts and the trailing
overalloc tail. Mirror the residual spec change on the bwd
(local_group_sizes residual moves from P() to ep2_spec).
(2) Per-group jnp.where zero-fill on wgrad outputs.
cuBLAS grouped_gemm skips groups with size_g == 0 without zero-filling
the corresponding out[g, :, :] slice (cublaslt_grouped_gemm.cu lines
2086/2096). For a shard hosting an expert that received zero tokens
globally, d_wo / d_wi_combined for that expert is left uninit, which
propagates NaN straight into the user's optimizer state.
Add wgrad_group_active = (local_group_sizes > 0)[:, None, None] in
_ffn_bwd_per_shard and apply via jnp.where on d_wo (right after the wo
wgrad) and d_wi_combined (right after the fused wi_0+wi_1 wgrad).
Mask shape is (num_local_experts, 1, 1) so cost is negligible.
(3) Drop the lax.cond zero-init guard on r_tok in _moe_fwd_rule._body.
Previously a jax.lax.cond(jnp.any(r_w != 0), identity, zeros_like)
wrapper around recv_tokens worked around tex.ep_dispatch_fwd leaving
the recv buffer uninit on fully-empty-receiver ranks. With (1) in
place, cuBLAS skips experts whose group_sizes == 0 and the per-row
trailing tail of dispatched recv_tokens is unread by every downstream
consumer (subsequent grouped_gemms read only sum(group_sizes) rows;
ep_combine and ep_dispatch_bwd are handle_mem-aware). The only
per-row consumer that would propagate the tail is grouped_dbias
(per-row segment_sum), which only runs when has_bias=True, and that
FFN bias path is currently gated upstream (cuBLAS grouped_gemm has
no fused bias on Hopper yet; PR 3083 adds the pure-JAX bias add).
With (2) handling the user-visible wgrad-NaN risk on 0-token experts,
the lax.cond is now redundant. Replace with a NOTE pointing at the
two follow-ups that would force its reintroduction:
- a future caller that reads the full recv tile non-group-aware
(e.g. an inspect probe), or
- the FFN bias path landing, which would resurrect grouped_dbias.
Also rewrite the _ffn_fwd_per_shard and _ffn_bwd_per_shard docstrings
to spell out the per-row vs per-group uninit semantics so the next
person debugging a NaN here has the invariants written down.
Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci Signed-off-by: tdophung <tdophung@nvidia.com>
…IA#3116) Address jberchtold-nvidia's PR NVIDIA#3116 nit "rename use_bias -> use_ffn_bias and use_expert_bias -> use_expert_routing_bias". The two flags are siblings (they enable two different bias buffers) but the old names suggested ``use_bias`` was the general fallback, which wasn't the intent. The new names make the FFN-vs-routing distinction obvious from the call site. * transformer_engine/jax/flax/moe.py use_bias -> use_ffn_bias (dataclass field + branch in __call__ + docstring entry) use_expert_bias -> use_expert_routing_bias (same) * tests/jax/test_te_ep_moe.py _make_block(use_expert_bias=...) -> use_expert_routing_bias sigmoid-bias-strong config key updated _reference_kwargs_from_config now reads use_expert_routing_bias ``_MoEBlock`` is still the experimental underscore-prefixed alias (no public ``MoEBlock`` export yet), so the rename is API-safe. The pre-resync legacy tests (``test_moe_vjp.py``, ``test_multiprocess_moe_vjp.py``) are intentionally not updated -- they already reference removed APIs like ``PermutationBackend`` and need a separate post-resync cleanup pass. Signed-off-by: tdophung <tdophung@nvidia.com>
…and inline justifications) Responds to jberchtold-nvidia's PR NVIDIA#3116 review threads on ``transformer_engine/jax/moe.py``. All changes are confined to a single file because each review thread targets a localized region and splitting mid-file would risk reordering bugs. Per review thread: 1. "Why do we need _with_sharding_constraint_cast_bwd? I haven't seen something like this required for our other VJPs." -- Expand the helper's docstring to spell out exactly why MoE needs it: unlike LN+MLP, the MoE bwd composes a bf16 cotangent from ep_dispatch_bwd with an fp32 cotangent from fused_topk_with_score_function_bwd (which the fwd's logits_2d -> fp32 promotion forces). Without the cast, ``d_x`` surfaces at fp32 even when ``x`` is bf16, doubling activation grad bandwidth and breaking any downstream LN bwd that pins a bf16 layout. (Review thread "Why do we need this utility function?".) 2. "Why is this dtype casting required? I don't recall us needing it for the non-MoE LNMLP block." -- Expand the comment above the bwd activation fp32 promotion to explain the MoE-specific math: LN+MLP's silu sits behind a downstream LN that absorbs the bf16 rounding error, while MoE's silu sits on the *expert* side of routing -- the bf16 rounding rides directly into expert_outputs and is summed across topk experts by ep_combine. Bf16 silu alone drifts ~1% vs fp32 silu and compounds through wo->combine into the ~1.4% per-element parity gap we measured against the pure-JAX softmax reference. Mirroring the fwd's fp32 promotion in the bwd keeps silu' in lock-step with silu. (Review thread on "# Activation bwd. Mirror the fwd's fp32 promotion of silu+multiply".) 3. "Do we have a use-case for user-specified alignments beyond 128 currently? ... it'd make sense to instead hardcode _ALIGN_SIZE = 128 as a constant at the top of the file for now to simplify this MoEBlock API. We can always expand the API to support a user-specified align size in the future." -- Implement the suggestion. Drop ``align_size`` from ``_moe_fwd_rule`` / ``_moe_bwd_rule`` / ``_moe`` / public ``moe()``; shift the ``custom_vjp`` ``nondiff_argnums`` from ``range(9, 27)`` -> ``range(9, 26)``; replace ``effective_align = max(int(align_size), 128)`` with the new module-level ``_ALIGN_SIZE = 128`` constant. Trim the ``moe()`` docstring accordingly. (Review thread on "natural_spe = num_ep * max_tokens_per_rank".) 4. "Which axis name inputs are physical mesh axes and why can be logical axes? ... No need to make any changes for now, I just want to assess which are which and then we can discuss if it makes sense to support logical on some/all or if some are required to be physical axes." -- Add an "Axis-name parameters" section to ``moe()``'s docstring listing which kwargs are physical mesh axes (``ep_axis``, ``data_parallelism_axes`` -- they index ``Mesh.shape`` directly to compute ``num_ep`` / ``dp_size`` and to construct the ``P((dp..., ep), None, None)`` for ``jax.lax.with_sharding_constraint``) vs logical axes (``input_axes``, ``gate_kernel_axes``, ``wi_kernel_axes``, ``wo_kernel_axes`` -- resolved via the Flax logical-axis rules). Also document why ``ep_axis`` / ``data_parallelism_axes`` are intentionally non-logical: the EP comm-group construction (``dp_color = rank // ep_size``) and the bootstrap signature check both require concrete integer sizes. (Review thread on "batch_pspec_axis = (*data_parallelism_axes, ep_axis)".) 5. "Is this NaN filtering a debugging artifact or something we need in the final version?" -- Strengthen the inline comment above ``sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, ...)`` to explicitly call this out as a CORRECTNESS REQUIREMENT, not a debugging artifact: it covers the sigmoid+K>1 underflow path where top-K sigmoid scores all round to zero and the ``weights / (weights.sum + 1e-20)`` normalisation emits NaN. Observationally the filter is a no-op on the dense unit-test distributions, but it must stay in for sparse / production routing. (Review thread on "sparse_probs = jnp.where(jnp.isnan(sparse_probs), ...).") Not addressed in this commit (intentional): * Review thread on the ``align_size: int = 0`` placeholder in ``flax/moe.py`` ("Placeholder comment for me to fix this so align_size is inferred automatically based on the recipe and doesn't need to be specified by the user"). That's jberchtold's own follow-up. * Review thread on the explicit ``tree_flatten`` / ``tree_unflatten`` on ``_Ctx`` ("better to use the ``@flax_struct.dataclass``"). Deferred to a separate, testable commit because changing a ``custom_vjp`` residual's pytree registration touches subtle ordering / None-handling semantics that warrant their own bisect surface. * Review thread on ``use_bias`` / ``use_expert_bias`` renames -- handled in the immediately preceding commit ``jax/flax,tests: rename use_bias/use_expert_bias for symmetry``. * Review thread on the ``expert_bias`` fp32 init -- already resolved during the Phuong PR NVIDIA#3036 resync (the redundant ``jnp.float32`` second-dtype argument on ``self.param`` was dropped; ``expert_bias`` now lives at ``self.dtype``). Signed-off-by: tdophung <tdophung@nvidia.com>
…aN sanitizer * Rewrite the inline justifications added in 078a7d80 so each one reads as standalone code documentation, not as a reply to a reviewer: drop "per PR NVIDIA#3116 review", "review feedback", "Renamed from ... per PR ..." and similar PR/thread references from moe.py, flax/moe.py, and tests/jax/test_te_ep_moe.py. Technical content (why the fp32 promotion is needed for the MoE silu+multiply, why _with_sharding_constraint_cast_bwd exists, physical-vs-logical axis split in moe() docstring, the 128 alignment rationale) is preserved and reframed to be useful to a reader who has no PR context. * Drop the jnp.where(jnp.isnan(sparse_probs), 0, sparse_probs) guard. Tracing fused_topk_with_score_function.cu shows the kernel divides by sum_scores + 1e-20, so finite non-negative sigmoid scores cannot produce NaN here; the filter was only defense against upstream NaNs, which would mask a real regression if anything ever did start producing them. Signed-off-by: tdophung <tdophung@nvidia.com>
68617ea to
fe44697
Compare
Description
Will rebase and squash the commits on this branch once about to merge
Will also change the JAX APIs if needed when TE EP JAX merge
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: