Metal backend: Add topk fallback kernel via MPSGraph #18876
manuelcandales merged 10 commits into main from […]
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18876
Note: Links to docs will display an error until the docs builds have been completed. ⏳ No Failures, 156 Pending. As of commit 933122c with merge base 4898af2. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@claude review this code
Claude finished @metascroy's task in 2m 25s — Code Review: Metal backend topk fallback kernel
Overall the implementation is solid and follows the established patterns from the other Metal ops. The transpose-topk-transpose approach for arbitrary dims and the negate trick for smallest mode are correct. A few issues worth addressing:

Issues
1. Memory leak on error path ([…]). If […]
2. The […]
3. Missing […]. The other ops ([…]) […]
4. Missing […]. Other ops call […]
5. Missing […]. The other ops ([…]) […]
6. The […]
7. Redundant auto* stream_sync = getCurrentMetalStream(); — this creates a new […]

Nit
8. extern std::unordered_map<void*, int32_t> memory_to_n_tensor; — this works but is unusual. The other ops don't seem to need this pattern — the […]

Test Coverage
The […]

Build Integration
CMakeLists.txt and […]
```cpp
size_t element_size;
MPSDataType mps_dtype;

if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
```
Nit: could these be factored out into general utilities?
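One possible shape for such a utility, as a minimal sketch rather than the PR's code: the helper name dtype_to_mps is hypothetical, and the SupportedDTypes::BFLOAT16 enumerator is an assumption; SupportedDTypes::FLOAT32 and the MPSDataType mapping follow the snippet quoted above.

```cpp
// Hypothetical utility sketch -- not the PR's actual code.
// Assumes the SupportedDTypes enum from the backend headers is in scope.
#include <cstddef>
#include <cstdint>
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

// Maps an AOTI dtype code to the matching MPSDataType and element size.
// Returns false for unsupported dtypes so callers can surface an error.
inline bool dtype_to_mps(int32_t dtype, MPSDataType& mps_dtype, size_t& element_size) {
  if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
    mps_dtype = MPSDataTypeFloat32;
    element_size = 4;
    return true;
  }
  // BFLOAT16 enumerator assumed here for illustration.
  if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {
    mps_dtype = MPSDataTypeBFloat16;
    element_size = 2;
    return true;
  }
  return false;
}
```

Each op could then call this helper instead of repeating the dtype switch, which is the factoring the nit suggests.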
This PR needs a […]
Adds aoti_torch_mps_topk using MPSGraph's topKWithSourceTensor. Required
for MoE expert routing (torch.topk in SparseMoE.forward). Supports an
arbitrary dim via transpose-topk-transpose, both largest and smallest modes,
and float32 and bfloat16 dtypes. Includes MPSGraph caching and int32-to-int64
index conversion (AOTInductor expects int64; MPSGraph outputs int32).
Authored with Claude.
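For context, here is a minimal sketch of how the transpose-topk-transpose, negate-for-smallest, and int32-to-int64 steps described above could be expressed with MPSGraph. This is not the PR's code: the function name BuildTopKGraph and its parameters (dim, lastDim, k, largest) are illustrative assumptions; only the MPSGraph calls themselves (transposeTensor:dimension:withDimension:name:, negativeWithTensor:name:, topKWithSourceTensor:k:name:, castTensor:toType:name:) are existing API.

```objc
// Hypothetical sketch of the graph construction -- not the PR's implementation.
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

static NSArray<MPSGraphTensor *> *BuildTopKGraph(MPSGraph *graph,
                                                 MPSGraphTensor *input,
                                                 NSUInteger dim,
                                                 NSUInteger lastDim,
                                                 NSUInteger k,
                                                 bool largest) {
  MPSGraphTensor *src = input;

  // Move the reduction dim to the innermost position; the basic
  // topKWithSourceTensor:k:name: variant operates on the last dimension.
  if (dim != lastDim) {
    src = [graph transposeTensor:src dimension:dim withDimension:lastDim name:@"to_last_dim"];
  }

  // Smallest mode via the negate trick: top-k of -x selects the k smallest of x.
  if (!largest) {
    src = [graph negativeWithTensor:src name:@"negate_in"];
  }

  NSArray<MPSGraphTensor *> *topk = [graph topKWithSourceTensor:src k:k name:@"topk"];
  MPSGraphTensor *values = topk[0];
  MPSGraphTensor *indices = topk[1];

  // Undo the negation so the returned values match the original sign.
  if (!largest) {
    values = [graph negativeWithTensor:values name:@"negate_out"];
  }

  // Transpose the results back to the original dim order.
  if (dim != lastDim) {
    values = [graph transposeTensor:values dimension:dim withDimension:lastDim name:@"values_back"];
    indices = [graph transposeTensor:indices dimension:dim withDimension:lastDim name:@"indices_back"];
  }

  // AOTInductor expects int64 indices; MPSGraph produces int32, so cast.
  indices = [graph castTensor:indices toType:MPSDataTypeInt64 name:@"indices_int64"];

  return @[ values, indices ];
}
```

In this sketch the built graph would then be cached keyed on shape/dtype/k/dim/largest and executed against the input buffer, mirroring the MPSGraph caching the description mentions.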