Fuse expert grouping and padding into the MXFP8 EP dispatch path #4086
MagellaX wants to merge 3 commits into pytorch:main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4086
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit 4bda581 with merge base ab4a336; the following job has failed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
danielvegamyhre left a comment:
Thanks for helping with this work! Based on some parts of the code, I think there may be some confusion about the intended design/goals here. Let me finish my WIP PR to get the initial prototype working, and we can go from there.
```python
    [input.shape[0]], device=input.device, dtype=torch.int64
)
dist.all_reduce(max_input_rows_per_rank, op=dist.ReduceOp.MAX, group=group)
max_output_rows_per_rank = int(max_input_rows_per_rank.item()) * ep_degree + (
```
danielvegamyhre: d2h sync here (can't do this, kills perf)
MagellaX: yeah, makes sense, this should not be doing a host sync here. I was using it just to size the buffer conservatively in the EP wrapper, but that's not the right shape for the final path.
```python
    f"kernel output_splits {output_splits_tensor.tolist()} do not match expected {output_splits}"
)

padded_rows = int(padded_group_end_offsets[-1].item())
```
danielvegamyhre: d2h sync here from .item()
```python
)
permuted_indices = permuted_indices.to(device=input.device, dtype=torch.int32)
assert torch.equal(group_offsets.to(torch.int64), padded_group_end_offsets), (
    f"group_offsets {group_offsets.tolist()} do not match padded_group_end_offsets {padded_group_end_offsets.tolist()}"
)
```
danielvegamyhre: f-strings are not evaluated lazily, the .tolist() will cause a d2h sync
MagellaX: yep, nice catch.
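To spell out the issue: f-string arguments are evaluated eagerly, so the `.tolist()` calls in the assert message run (and force a device-to-host sync on real tensors) even when the assertion passes. A minimal sketch of gating such a check behind a debug flag — the flag name and helper are hypothetical, not from this PR:

```python
DEBUG_CHECKS = False  # hypothetical flag; real code might read an env var

def check_offsets(group_offsets, padded_group_end_offsets):
    # Building the f-string message calls .tolist() (a d2h sync on real
    # tensors) even when the assert passes, so the whole check is gated
    # and the hot path pays no sync cost at all.
    if DEBUG_CHECKS:
        assert list(group_offsets) == list(padded_group_end_offsets), (
            f"group_offsets {list(group_offsets)} do not match "
            f"padded_group_end_offsets {list(padded_group_end_offsets)}"
        )

# With the flag off, this is a no-op even for mismatched offsets.
check_offsets([0, 32, 96], [0, 32, 99])
```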
```python
padded_rows = int(padded_group_end_offsets[-1].item())
permuted_indices, num_tokens_per_expert_padded, group_offsets = (
    generate_permute_indices(
```
danielvegamyhre: This function is used for permuting the tokens from rank-major to expert-major and adding per-group padding. It should not be needed in this impl, since the a2a should write directly to expert-major.
MagellaX: Agreed. After looking more closely at #4066, this is the clearest signal that I was still partially thinking in terms of the old rank-major -> expert-major regroup path. If the fused a2a writes directly to the expert-major padded layout, this extra step should not exist.
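For intuition, here is a toy plain-Python sketch (not the actual kernel) of the rank-major -> expert-major regroup that this step performs, and that a fused a2a writing expert-major output would make unnecessary:

```python
def regroup_expert_major(tokens_per_rank, num_experts):
    # tokens_per_rank: for each source rank, a list of (expert_id, token)
    # pairs arriving in rank-major order. Returns tokens re-sorted so all
    # of expert 0's tokens come first, then expert 1's, and so on.
    groups = [[] for _ in range(num_experts)]
    for rank_tokens in tokens_per_rank:
        for expert_id, token in rank_tokens:
            groups[expert_id].append(token)
    return [tok for group in groups for tok in group]

# Two ranks each send tokens for experts 0 and 1; rank-major arrival
# interleaves experts, so a regroup pass is needed before grouped GEMM.
out = regroup_expert_major([[(0, "a"), (1, "b")], [(0, "c"), (1, "d")]], num_experts=2)
print(out)  # ['a', 'c', 'b', 'd']
```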
```python
max_input_rows_per_rank = torch.tensor(
    [input.shape[0]], device=input.device, dtype=torch.int64
)
dist.all_reduce(max_input_rows_per_rank, op=dist.ReduceOp.MAX, group=group)
```
danielvegamyhre: I'm confused why there is a new all-reduce being added here. We shouldn't need to do this; we just overallocate the symm-mem buffer for the worst case of all tokens being routed to the same rank.
MagellaX: Yeah, fair point. I added that to derive a max row count for the wrapper allocation, but that's not the right approach if we're just overallocating symm-mem for the true worst case, so this should not be added.
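For context, the worst-case bound needs only static shapes, so no all-reduce (and no device sync) is required to size the buffer. A minimal sketch, assuming every remote rank could route its full local batch to this rank:

```python
def worst_case_output_rows(local_rows: int, ep_degree: int) -> int:
    # Worst case: all ep_degree ranks route every one of their tokens to
    # this rank. Pure Python ints, so no communication or host sync.
    return local_rows * ep_degree

# Overallocate the symm-mem buffer once for this bound and reuse it.
print(worst_case_output_rows(4096, 8))  # 32768
```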
MagellaX: Ahh, I didn't even know that PR was there in the first place :) But now, looking at the code, I can see the overlap more clearly. The core of it is definitely in your low-level fused expert-major MXFP8 a2a work; what I added on top here is mostly the EP-path wiring around it. It wasn't intentionally based on your PR, but after comparing the code, I agree they're close enough. :)
hey yo!! @danielvegamyhre let me know
Summary
This PR wires the fused MXFP8 dispatch-and-group path into the expert-parallel flow.
Instead of dispatching and then running a separate regroup/padding step, the EP path can now dispatch directly into a grouped, padded MXFP8 layout that is ready for grouped GEMM.
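To sketch what "grouped and padded, ready for grouped GEMM" means here: each expert's group of rows is rounded up to an alignment boundary. The helper and the alignment of 32 (matching MXFP8's 32-element scaling blocks) are illustrative assumptions, not this PR's API:

```python
def padded_group_end_offsets(group_sizes, alignment=32):
    # Round each expert's row count up to a multiple of `alignment` and
    # return cumulative end offsets into the padded, expert-major buffer.
    offsets, total = [], 0
    for size in group_sizes:
        total += -(-size // alignment) * alignment  # ceil to alignment
        offsets.append(total)
    return offsets

# Experts receiving 5, 40, and 0 rows get padded to 32, 64, and 0 rows.
print(padded_group_end_offsets([5, 40, 0]))  # [32, 96, 96]
```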
What changed
- a2a_dispatch_and_group_mxfp8_fwd_hp_bwd(...) in torchao/prototype/moe_training/ep/a2a_dispatch.py
- torchao/prototype/moe_training/kernels/mxfp8/comms.py

Why
#4050 is about removing the extra post-dispatch regroup/padding step from the MXFP8 EP flow. This change moves that layout work into the dispatch path itself, so the output of dispatch is already grouped per expert and padded for grouped GEMM.
Validation
Local:
- ruff check --fix
- ruff format
- python -m compileall ...

Hardware:
- 4x B200

Notes
addresses #4050