Make the output of MoE forward method have expected output in non cuda backends#19170

Merged
Gasoonjia merged 8 commits into main from moe-no-float on Apr 29, 2026

Conversation

Gasoonjia (Contributor) commented Apr 27, 2026

Dropped the unconditional `.float()` from the `temperature is None` branch of `Qwen35MoE.forward` so its output keeps the dtype the model author expects.
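
For illustration, a minimal sketch of the shape of the change, using a hypothetical stand-in module (the real `Qwen35MoE.forward` differs in detail; the `lm_head`, the shapes, and the sampling branch below are assumptions, not the PR's code):

```python
from typing import Optional

import torch


class Qwen35MoESketch(torch.nn.Module):
    """Hypothetical stand-in for Qwen35MoE; only the return-path dtype matters here."""

    def __init__(self, dim: int = 4, vocab: int = 8) -> None:
        super().__init__()
        self.lm_head = torch.nn.Linear(dim, vocab, dtype=torch.bfloat16)

    def forward(
        self,
        hidden: torch.Tensor,
        temperature: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        logits = self.lm_head(hidden)  # stays bf16 in MLX/Metal exports

        if temperature is None:
            # Before this PR: `return logits.float()` -- an unconditional upcast
            # to fp32, not the dtype the model author expects downstream.
            # After this PR: return the logits in their native dtype.
            return logits

        # CUDA-style path (untouched by this PR): sample on device instead of
        # handing raw logits back to the host.
        probs = torch.softmax(logits.float() / temperature, dim=-1)
        return torch.multinomial(probs.view(-1, probs.shape[-1]), num_samples=1)


model = Qwen35MoESketch()
hidden = torch.randn(1, 1, 4, dtype=torch.bfloat16)
print(model(hidden).dtype)  # torch.bfloat16, not torch.float32
```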

Qwen 3.5 MoE perf comparison between this PR and e2eb417

I did a detailed performance comparison between this PR and the state before the CUDA sampler was applied (commit e2eb417) to see whether we can bring perf back.

TL;DR: With this PR, perf is the same as or better than the previous state on the tiny model across MLX and Metal, and on the full model with MLX. The full-model Metal export could not be measured because it crashes (OOM during AOTI compilation, see below).

Tiny Model

Setup: M3 Max 128 GB · macOS 26.4 · Xcode 26.4.1 · --tiny-test model · MLX --qlinear 4w --qlinear-group-size 32 · Metal --qlinear fpa4w · all measurements use in-process warmup (MLX: warmup at prefill + decode shapes + force-eval; Metal: --warmup_iters 2 --warmup_decode_steps 4 --ignore_eos) · median of 3-6 trials.

MLX (Python pybindings)

| Config | Metric | Before this PR | After this PR | Δ |
|---|---|---|---|---|
| prompt-len=4, max-new=5 | Prefill tok/s | 1077 | 1195 | +11% |
| prompt-len=4, max-new=5 | Decode tok/s | 294 | 350 | +19% |
| prompt-len=32, max-new=31 | Prefill tok/s | 7060 | 10842 | +54% |
| prompt-len=32, max-new=31 | Decode tok/s | 314 | 267* | −15% (within trial noise, see note) |

* prompt=32 decode trial-by-trial: 281 / 267 / 247 (3 trials). Trial-to-trial spread is ~14%, so the apparent regression is within noise.

Metal (C++ runner)

| Config | Metric | Before this PR (median of 6) | After this PR (median of 6) | Δ |
|---|---|---|---|---|
| prompt-len=32, max-new=31 | Prefill tok/s (mean ex-cold) | 5351 | 5988 | +12% |
| prompt-len=32, max-new=31 | Decode tok/s (mean ex-cold) | 217 | 286 | +32% |
| prompt-len=32, max-new=31 | Decode tok/s (median ex-cold) | 237 | 290 | +22% |

Full Model

Setup: Qwen/Qwen3.5-35B-A3B (40 layers, 2048d, 256 experts top-8, 67 GB safetensors) · M3 Max 128 GB · macOS 26.4 · Xcode 26.4.1 · MLX --qlinear 4w --qlinear-group-size 64 · in-process warmup at prefill+decode shapes + force-eval after prefill · median of 3 trials per config.

MLX (full Qwen 3.5 MoE 35B-A3B)

| Config | Metric | Before this PR | After this PR | Δ |
|---|---|---|---|---|
| prompt=4, max-new=5 | Prefill tok/s | 133.7 | 163.6 | +22% |
| prompt=4, max-new=5 | Decode tok/s | 36.4 | 44.7 | +23% |
| prompt=32, max-new=32 | Prefill tok/s | 404.3 | 443.4 | +10% |
| prompt=32, max-new=32 | Decode tok/s | 37.2 | 43.4 | +17% |
| prompt=128, max-new=64 | Prefill tok/s | 650.3 | 711.5 | +9% |
| prompt=128, max-new=64 | Decode tok/s | 38.5 | 43.1 | +12% |

Trial-to-trial variance is small (≤1 tok/s on decode, ≤5% on prefill) so all deltas are signal.

Metal (full Qwen 3.5 MoE 35B-A3B)

Not measured. Metal export of the 35B model OOM-kills on the 128 GB Mac during AOTI inductor compilation (Killed: 9 exit 137). Confirmed across 3 attempts: default settings, TORCHINDUCTOR_COMPILE_THREADS=1, and --max-seq-len 1024. The transient peak during AOTI lowering exceeds available RAM. Tiny-model Metal A/B (already collected, see prior summary) shows the same pattern: prefill +12%, decode +22~+32%.

Conclusion

No regression on either backend; meaningful uplift on both. MLX shows the cleanest improvement on prefill (+11~+54%) and decode (+19% at small prompt). Metal shows +12% prefill and +22~+32% decode at prompt=32. The single MLX prompt-32 decode delta is within trial-to-trial variance.

@Gasoonjia Gasoonjia requested a review from lucylq as a code owner April 27, 2026 20:43

pytorch-bot Bot commented Apr 27, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19170

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 4 Cancelled Jobs, 1 Pending, 2 Unrelated Failures

As of commit 4e5c02d with merge base 4ac044b:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla Bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Apr 27, 2026
@Gasoonjia Gasoonjia requested a review from metascroy April 27, 2026 20:43
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


// Use a very small temperature for greedy to avoid division by zero
// while keeping the Gumbel noise negligible relative to logit differences.
#ifdef EXECUTORCH_BUILD_CUDA
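
For context, the comment in this hunk refers to the Gumbel-max trick: `argmax(logits / T + gumbel_noise)` draws a token from `softmax(logits / T)`, and as `T` shrinks the scaled logits dwarf the noise, so the draw collapses to plain greedy argmax without needing a separate `T == 0` code path (which would divide by zero). A rough Python sketch of the idea, not the PR's actual CUDA kernel:

```python
import torch


def gumbel_max_sample(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # Gumbel(0, 1) noise via -log(Exp(1)); argmax(logits/T + noise) samples
    # from softmax(logits / T).
    gumbel = -torch.empty_like(logits).exponential_().log()
    return torch.argmax(logits / temperature + gumbel, dim=-1)


logits = torch.tensor([[0.2, 1.5, 0.9, 1.4]])

# temperature = 1.0: a genuine random draw; index 3 sometimes beats index 1.
sampled = gumbel_max_sample(logits, temperature=1.0)

# A tiny but nonzero temperature scales the logits to ~1e5 while the Gumbel
# noise stays O(1), so the draw degenerates to greedy decoding and the
# division by zero that temperature == 0 would cause never happens.
greedy = gumbel_max_sample(logits, temperature=1e-5)
assert greedy.item() == torch.argmax(logits, dim=-1).item()  # index 1
```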

Contributor:

Can we make this more generic? Detect if pte was exported with sampler built in and route appropriately?

Then make fuse-sampler an export arg that is on for cuda and off for mlx/metal for now?

Gasoonjia (Contributor, Author):


I would like to keep the current status; I don't think fusing the sampler into the model's forward method is good practice. This is a temporary solution before device support lands; once we get it in the near future, all modules should return logits and use a sampler tool to do the sampling.

Qwen35MoE.forward currently routes through an Optional[Tensor] temperature
parameter and an if/else that picks between the on-device fused Gumbel-max
sampler (CUDA) and raw logits (non-CUDA). The sampling branch is dead code
for MLX and Metal exports, since those backends sample on the host.

Even though torch.export statically eliminates the branch when temperature
defaults to None, the parameter, default value, and unused else-branch leak
into the exported program: extra placeholder nodes, different graph hashes,
and shifted kernel selection in the lowered MLX/Metal graph. On the tiny
test model this slows MLX prefill ~9-37% and decode ~5-19%, and shows up as
~10-25% noise on Metal.

Bind model.forward to a minimal (tokens, input_pos) -> logits variant inside
_export_mlx and _export_metal before torch.export, so the captured program
matches what the backend kernels are tuned for. Eager-mode callers and the
CUDA export path are unaffected.
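
A rough sketch of what that binding could look like, assuming a wrapper-module approach and a hypothetical helper name (the actual `_export_mlx` / `_export_metal` code differs in structure; only the (tokens, input_pos) -> logits signature is taken from the description above):

```python
import torch


def export_logits_only(
    model: torch.nn.Module,
    example_tokens: torch.Tensor,
    example_pos: torch.Tensor,
) -> torch.export.ExportedProgram:
    """Capture a minimal (tokens, input_pos) -> logits program for MLX/Metal.

    The Optional temperature parameter and the fused-sampler branch never enter
    the traced signature, so no extra placeholders or dead else-branch leak
    into the exported program; the host does the sampling from the logits.
    """

    class LogitsOnly(torch.nn.Module):
        def __init__(self, inner: torch.nn.Module) -> None:
            super().__init__()
            self.inner = inner

        def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
            # Pin the raw-logits path explicitly; the CUDA export keeps its own path.
            return self.inner(tokens, input_pos, temperature=None)

    return torch.export.export(LogitsOnly(model), (example_tokens, example_pos))
```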
metascroy (Contributor):

@Gasoonjia I wouldn't rely on --tiny-test for perf comparisons.

Can you try with the actual model?

@Gasoonjia Gasoonjia merged commit 80997fd into main Apr 29, 2026
246 of 253 checks passed
@Gasoonjia Gasoonjia deleted the moe-no-float branch April 29, 2026 17:53

Labels

ciflow/cuda · ciflow/metal · CLA Signed
