Make the MoE forward method return the expected output on non-CUDA backends #19170
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19170
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV; if your PR is affected, please view it below.
❌ 4 Cancelled Jobs, 1 Pending, 2 Unrelated Failures as of commit 4e5c02d with merge base 4ac044b.
CANCELLED JOBS: the following jobs were cancelled. Please retry.
BROKEN TRUNK: the following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs a `release notes:` label.
```cpp
// Use a very small temperature for greedy to avoid division by zero
// while keeping the Gumbel noise negligible relative to logit differences.
#ifdef EXECUTORCH_BUILD_CUDA
```
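For context, the trick these comment lines describe is Gumbel-max sampling. A minimal eager-PyTorch sketch (the function name and the 1e-7 floor here are illustrative, not the kernel's actual values):

```python
import torch

def gumbel_max_sample(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # Gumbel-max trick: argmax(logits / t + g) with g ~ Gumbel(0, 1) draws a
    # token from softmax(logits / t) without materializing the softmax.
    # Flooring t at a tiny positive value keeps greedy decoding (t == 0)
    # free of division by zero, while the Gumbel noise becomes negligible
    # next to the hugely scaled logit gaps, so argmax reduces to greedy.
    t = max(temperature, 1e-7)
    u = torch.rand_like(logits).clamp_min(1e-20)  # avoid log(0)
    gumbel = -torch.log(-torch.log(u))
    return torch.argmax(logits / t + gumbel, dim=-1)
```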
Can we make this more generic? Detect if the .pte was exported with the sampler built in and route appropriately?
Then make fuse-sampler an export arg that is on for CUDA and off for MLX/Metal for now?
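Roughly what that suggestion could look like as an export-time default (the `resolve_fuse_sampler` helper and the backend strings are hypothetical, not an existing ExecuTorch API):

```python
from typing import Optional

def resolve_fuse_sampler(backend: str, fuse_sampler: Optional[bool] = None) -> bool:
    # Hypothetical export arg: fuse the sampler into forward only where an
    # on-device sampling kernel exists (CUDA today); sample on the host for
    # MLX/Metal. An explicit user choice overrides the per-backend default.
    if fuse_sampler is None:
        return backend == "cuda"
    return fuse_sampler
```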
I would like to keep the current status; I don't think fusing the sampler into the model's forward method is good practice. This is a temporary solution until device-side sampling support lands; once we get it in the near future, all modules should return logits and use a sampler tool to do the sampling.
Qwen35MoE.forward currently routes through an Optional[Tensor] temperature parameter and an if/else that picks between the on-device fused Gumbel-max sampler (CUDA) and raw logits (non-CUDA). The sampling branch is dead code for MLX and Metal exports, since those backends sample on the host.

Even though torch.export statically eliminates the branch when temperature defaults to None, the parameter, default value, and unused else-branch leak into the exported program: extra placeholder nodes, different graph hashes, and shifted kernel selection in the lowered MLX/Metal graph. On the tiny test model this slows MLX prefill ~9-37% and decode ~5-19%, and shows up as ~10-25% noise on Metal.

Bind model.forward to a minimal (tokens, input_pos) -> logits variant inside _export_mlx and _export_metal before torch.export, so the captured program matches what the backend kernels are tuned for. Eager-mode callers and the CUDA export path are unaffected.
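A minimal sketch of that rebinding (the helper name is invented here; only the (tokens, input_pos) -> logits shape of the bound variant comes from this PR, and it assumes torch.export traces through the instance-bound forward):

```python
import types
import torch

def bind_minimal_forward(model: torch.nn.Module) -> None:
    # Illustrative helper: shadow the class-level forward with a bound
    # variant exposing only (tokens, input_pos) -> logits, so the Optional
    # temperature parameter, its None default, and the dead sampling branch
    # never appear in the program torch.export captures for MLX/Metal.
    full_forward = type(model).forward

    def minimal_forward(self, tokens: torch.Tensor,
                        input_pos: torch.Tensor) -> torch.Tensor:
        return full_forward(self, tokens, input_pos, temperature=None)

    model.forward = types.MethodType(minimal_forward, model)

# Sketch of use inside a hypothetical _export_mlx, before capture:
#   bind_minimal_forward(model)
#   exported = torch.export.export(model, (tokens, input_pos))
```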
@Gasoonjia I wouldn't rely on `--tiny-test` for perf comparisons. Can you try with the actual model?
Dropped the unconditional `.float()` from the `temperature is None` branch of `Qwen35MoE.forward` to keep its output in the model author's expected dtype.
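As a toy illustration of the dtype concern (`TinyHead` is a made-up stand-in, not the real Qwen35MoE):

```python
from typing import Optional

import torch

class TinyHead(torch.nn.Module):
    # Hypothetical stand-in for the model's output head, kept in bf16.
    def __init__(self, dim: int = 8, vocab: int = 16):
        super().__init__()
        self.output = torch.nn.Linear(dim, vocab, dtype=torch.bfloat16)

    def forward(self, h: torch.Tensor,
                temperature: Optional[float] = None) -> torch.Tensor:
        logits = self.output(h)
        if temperature is None:
            # Return logits in the module's own dtype; an unconditional
            # .float() here would silently upcast bf16 outputs to fp32 and
            # change what downstream backends are handed.
            return logits
        t = max(temperature, 1e-7)
        return torch.argmax(logits.float() / t, dim=-1)
```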
Qwen 3.5 MoE perf comparison between this PR and e2eb417

I did a detailed performance comparison between this PR and the state before the CUDA sampler was applied (commit e2eb417) to see whether we could bring perf back.

TLDR: With this PR, perf is the same as or better than the previous state on the tiny model across MLX and Metal, and on the full model with MLX; the full-model Metal export crashed before it could be measured (details below).
Tiny Model
Setup: M3 Max 128 GB · macOS 26.4 · Xcode 26.4.1 · `--tiny-test` model · MLX `--qlinear 4w --qlinear-group-size 32` · Metal `--qlinear fpa4w` · all measurements use in-process warmup (MLX: warmup at prefill + decode shapes + force-eval; Metal: `--warmup_iters 2 --warmup_decode_steps 4 --ignore_eos`) · median of 3-6 trials.

MLX (Python pybindings)
* prompt=32 decode trial-by-trial: 281 / 267 / 247 (3 trials). Trial-to-trial spread is ~14%, so the apparent regression is within noise.
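For reference, the in-process warmup pattern described in the setup looks roughly like this (a generic PyTorch-style sketch, not the actual MLX pybindings or Metal runner):

```python
import time

import torch

def measure_decode_toks_per_s(model, prompt: torch.Tensor, steps: int = 32,
                              warmup_runs: int = 2) -> float:
    # Generic sketch: run warmup passes at the same prefill and decode
    # shapes as the timed run, so one-time compilation and cache effects
    # don't land in the measured trial; then time only the decode loop.
    def prefill_and_decode(timed: bool) -> float:
        logits = model(prompt, torch.tensor([0]))          # prefill shape
        tok = logits[:, -1:].argmax(dim=-1)
        start = time.perf_counter() if timed else 0.0
        for i in range(steps):                             # decode shape
            pos = torch.tensor([prompt.shape[1] + i])
            tok = model(tok, pos)[:, -1:].argmax(dim=-1)
        return time.perf_counter() - start if timed else 0.0

    for _ in range(warmup_runs):
        prefill_and_decode(timed=False)
    return steps / prefill_and_decode(timed=True)
```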
Metal (C++ runner)
Full Model
Setup: Qwen/Qwen3.5-35B-A3B (40 layers, 2048d, 256 experts top-8, 67 GB safetensors) · M3 Max 128 GB · macOS 26.4 · Xcode 26.4.1 · MLX `--qlinear 4w --qlinear-group-size 64` · in-process warmup at prefill + decode shapes + force-eval after prefill · median of 3 trials per config.

MLX (full Qwen 3.5 MoE 35B-A3B)
Trial-to-trial variance is small (≤1 tok/s on decode, ≤5% on prefill), so all deltas are signal.
Metal (full Qwen 3.5 MoE 35B-A3B)
Not measured. Metal export of the 35B model OOM-kills on the 128 GB Mac during AOTI inductor compilation (`Killed: 9`, exit 137). Confirmed across 3 attempts: default settings, `TORCHINDUCTOR_COMPILE_THREADS=1`, and `--max-seq-len 1024`. The transient peak during AOTI lowering exceeds available RAM. Tiny-model Metal A/B (already collected, see prior summary) shows the same pattern: prefill +12%, decode +22~+32%.

Conclusion
No regression on either backend; meaningful uplift on both. MLX shows the cleanest improvement on prefill (+11~+54%) and decode (+19% at small prompt). Metal shows +12% prefill and +22~+32% decode at prompt=32. The single MLX prompt-32 decode delta is within trial-to-trial variance.