
Fuse expert grouping and padding into the MXFP8 EP dispatch path#4086

Open
MagellaX wants to merge 3 commits into pytorch:main from MagellaX:feat/4050-fused-mxfp8-a2a-padding

Conversation

@MagellaX
Contributor

Summary

This PR wires the fused MXFP8 dispatch-and-group path into the expert-parallel flow.

Instead of doing:

  • MXFP8 all-to-all dispatch
  • separate expert regroup / permutation
  • separate padding for grouped GEMM

the EP path can now dispatch directly into a grouped, padded MXFP8 layout that is ready for grouped GEMM.

What changed

  • Added a2a_dispatch_and_group_mxfp8_fwd_hp_bwd(...) in torchao/prototype/moe_training/ep/a2a_dispatch.py
  • Integrated the fused path into the MXFP8 expert-parallel example
  • Finished the grouped MXFP8 all-to-all kernel plumbing in torchao/prototype/moe_training/kernels/mxfp8/comms.py
  • Added EP-level coverage for:
    • fused dispatch/group metadata
    • forward integration through grouped GEMM / unpermute
    • compiled EP pipeline

Why

#4050 is about removing the extra post-dispatch regroup/padding step from the MXFP8 EP flow.

This change moves that layout work into the dispatch path itself, so the output of dispatch is already:

  • grouped by local expert
  • padded to grouped-GEMM alignment
  • still in MXFP8

Validation

Local:

  • ruff check --fix
  • ruff format
  • python -m compileall ...

Hardware:

  • Modal 4x B200

Notes

  • This PR focuses on correctness and EP-path integration.
  • It does not include a separate end-to-end performance benchmark in this diff.

Addresses #4050.

@pytorch-bot

pytorch-bot bot commented Mar 14, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4086

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 4bda581 with merge base ab4a336:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Mar 14, 2026
@danielvegamyhre
Contributor

Hi @MagellaX looks like this is using my WIP PR #4066 is that right?

@danielvegamyhre danielvegamyhre added the mx, moe, module: training, quantize_ api, and training flow labels Mar 14, 2026
@danielvegamyhre danielvegamyhre added this to the MXFP8 Training milestone Mar 14, 2026
@danielvegamyhre danielvegamyhre self-requested a review March 14, 2026 17:07
Contributor

@danielvegamyhre danielvegamyhre left a comment


Thanks for helping with this work! Based on some parts of the code, I think there may be some confusion about the intended design/goals here. Let me finish my WIP PR to get the initial prototype working and we can go from there.

[input.shape[0]], device=input.device, dtype=torch.int64
)
dist.all_reduce(max_input_rows_per_rank, op=dist.ReduceOp.MAX, group=group)
max_output_rows_per_rank = int(max_input_rows_per_rank.item()) * ep_degree + (
Contributor


d2h sync here (can't do this, kills perf)

Contributor Author


yeah, makes sense, this should not be doing a host sync here. I was using it just to size the buffer conservatively in the EP wrapper, but that’s not the right shape for the final path.

f"kernel output_splits {output_splits_tensor.tolist()} do not match expected {output_splits}"
)

padded_rows = int(padded_group_end_offsets[-1].item())
Contributor


d2h sync here from .item()
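A minimal sketch of the pattern being flagged and one sync-free alternative, using CPU tensors for illustration (`max_rows` and the buffer shapes are hypothetical, not torchao's actual code):

```python
import torch

padded_group_end_offsets = torch.tensor([32, 64, 128], dtype=torch.int64)

# Flagged pattern: .item() copies the scalar to the host. On a GPU tensor
# this blocks the CPU until the device stream catches up (the d2h sync).
padded_rows = int(padded_group_end_offsets[-1].item())

# Sync-free alternative: allocate to a static worst-case bound and keep the
# true row count on device, so kernels that need it read the device scalar.
max_rows = 256                                   # static upper bound
out = torch.empty(max_rows, 8)                   # sized with no sync
padded_rows_dev = padded_group_end_offsets[-1]   # stays a 0-dim tensor
```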

)
permuted_indices = permuted_indices.to(device=input.device, dtype=torch.int32)
assert torch.equal(group_offsets.to(torch.int64), padded_group_end_offsets), (
f"group_offsets {group_offsets.tolist()} do not match padded_group_end_offsets {padded_group_end_offsets.tolist()}"
Contributor


f-strings are not evaluated lazily, the .tolist() will cause a d2h sync

Contributor Author


yep nice catch....
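A pure-Python sketch of the hazard: when the error message is built as an ordinary (eager) argument, any `.tolist()` inside it runs even on the happy path, while deferring the message behind a callable skips it. `FakeGpuTensor` and the `check_*` helpers are illustrative stand-ins, not torchao APIs:

```python
class FakeGpuTensor:
    """Stand-in for a GPU tensor whose .tolist() would force a d2h sync."""
    def __init__(self):
        self.syncs = 0

    def tolist(self):
        self.syncs += 1  # pretend this is the expensive device-to-host copy
        return [1, 2, 3]

def check_eager(cond, msg):
    # msg arrives fully built: the f-string was evaluated by the caller,
    # regardless of whether cond holds.
    if not cond:
        raise AssertionError(msg)

def check_lazy(cond, msg_fn):
    # msg_fn is only invoked on failure, so nothing syncs when the
    # check passes.
    if not cond:
        raise AssertionError(msg_fn())

t = FakeGpuTensor()
check_eager(True, f"values: {t.tolist()}")         # "syncs" despite passing

u = FakeGpuTensor()
check_lazy(True, lambda: f"values: {u.tolist()}")  # no sync on the happy path
```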


padded_rows = int(padded_group_end_offsets[-1].item())
permuted_indices, num_tokens_per_expert_padded, group_offsets = (
generate_permute_indices(
Contributor


this function is used for permuting the tokens from rank-major to expert-major and adding per-group padding; it should not be needed in this impl since the a2a should write directly to expert-major

Contributor Author


agreed, after looking more closely at #4066, this is the clearest signal that I was still partially thinking in terms of the old rank-major -> expert-major regroup path. if the fused a2a writes directly to expert-major padded layout, this extra step should not exist.
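One way to picture the direct-to-expert-major write: if the padded per-expert group offsets are known before the a2a, each incoming token's destination slot can be computed up front, so no post-hoc permute is needed. A hypothetical sketch of that metadata (not the kernel's actual format):

```python
def padded_group_offsets(tokens_per_expert, alignment):
    """Exclusive prefix sum of per-expert token counts, with each group
    rounded up to the grouped-GEMM alignment. Returns the write base for
    each local expert group and the total padded row count."""
    offsets, acc = [], 0
    for n in tokens_per_expert:
        offsets.append(acc)
        acc += -(-n // alignment) * alignment  # ceil-round to alignment
    return offsets, acc

bases, total_rows = padded_group_offsets([10, 3, 16, 0], alignment=8)
# bases == [0, 16, 24, 40], total_rows == 40
```

With these bases in hand, the a2a kernel can scatter token i of expert e straight to `bases[e] + slot`, which is the expert-major, padded layout the grouped GEMM consumes.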

max_input_rows_per_rank = torch.tensor(
[input.shape[0]], device=input.device, dtype=torch.int64
)
dist.all_reduce(max_input_rows_per_rank, op=dist.ReduceOp.MAX, group=group)
Contributor


i'm confused why there is a new all-reduce being added here, we shouldn't need to do this, we just overallocate the sym mem buffer for the worst case of all tokens being routed to the same rank

Contributor Author


yeah, fair point. I added that to derive a max row count for the wrapper allocation, but that’s not the right approach if we’re just overallocating symm-mem for the true worst case....so yeah, this should not be added.
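The overallocation the reviewer describes reduces to local arithmetic. Assuming the per-rank token count is fixed by the batch shape (so every rank already knows it without a collective), a worst-case bound needs no communication; the function and parameter names here are illustrative, not torchao's API:

```python
def worst_case_output_rows(tokens_per_rank: int, ep_degree: int,
                           num_local_experts: int, alignment: int) -> int:
    """Upper bound on post-dispatch rows: every peer routes all of its
    tokens to this rank, plus per-expert padding of at most
    (alignment - 1) rows per local expert group."""
    worst_tokens = tokens_per_rank * ep_degree
    worst_padding = num_local_experts * (alignment - 1)
    return worst_tokens + worst_padding

# e.g. 8192 tokens per rank, EP degree 8, 4 local experts, 16-row alignment
rows = worst_case_output_rows(8192, 8, 4, 16)
# rows == 8192 * 8 + 4 * 15 == 65596
```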

@MagellaX
Contributor Author

> Hi @MagellaX looks like this is using my WIP PR #4066 is that right?

ahh, I didn't even know that PR existed in the first place :) but now, looking at the code, I can see the overlap clearly.

The core overlap is definitely in the low-level fused expert-major MXFP8 a2a work from your #4066; that part is very much the same direction.

What I added on top here is mostly the EP-path wiring around it.
So yeah, the kernel-level idea overlaps with #4066, while this PR is more about carrying that path through the actual EP API + tests.

It wasn't intentionally based on your PR, but after comparing the code, I agree they're close enough... :)

@MagellaX
Contributor Author

hey yo!! @danielvegamyhre let me know

@MagellaX
Contributor Author


lmk any updates @danielvegamyhre! btw, loved your blog on MXFP8 GEMM!
