
Metal backend: Add topk fallback kernel via MPSGraph #18876

Merged

manuelcandales merged 10 commits into main from gh/manuelcandales/171/head on Apr 20, 2026

Conversation

@manuelcandales
Contributor

Adds aoti_torch_mps_topk using MPSGraph's topKWithSourceTensor, required for MoE expert routing (torch.topk in SparseMoE.forward). Supports an arbitrary dim via a transpose-topk-transpose pattern, both largest and smallest modes, and float32 and bfloat16 inputs. Includes MPSGraph caching and an int32-to-int64 indices conversion (AOTInductor expects int64 indices, while MPSGraph outputs int32).

Authored with Claude.
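
For context, here is a minimal sketch of the graph construction described above, with fixed shapes, k, and dim chosen purely for illustration. The placeholder names and structure are not taken from the actual op_topk.mm.

```objc
// Sketch only: topk over dim=0 of an 8x64 float32 tensor with k=4, largest=false.
// MPSGraph's topK operates on the last axis, hence the transpose-topk-transpose.
MPSGraph* graph = [MPSGraph new];
MPSGraphTensor* input =
    [graph placeholderWithShape:@[ @8, @64 ] dataType:MPSDataTypeFloat32 name:nil];

// Move the requested dim to the last axis.
MPSGraphTensor* x = [graph transposeTensor:input dimension:0 withDimension:1 name:nil];

// Smallest mode: negate, take the largest k, then negate the values back.
BOOL largest = NO;
if (!largest) {
  x = [graph negativeWithTensor:x name:nil];
}

NSArray<MPSGraphTensor*>* topk = [graph topKWithSourceTensor:x k:4 name:nil];
MPSGraphTensor* values = topk[0];
MPSGraphTensor* indices = topk[1];

if (!largest) {
  values = [graph negativeWithTensor:values name:nil];
}

// Move the topk axis back to the original dim.
values = [graph transposeTensor:values dimension:0 withDimension:1 name:nil];
indices = [graph transposeTensor:indices dimension:0 withDimension:1 name:nil];
// The op then widens MPSGraph's int32 indices to the int64 AOTInductor expects.
```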

@pytorch-bot

pytorch-bot Bot commented Apr 14, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18876

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 156 Pending

As of commit 933122c with merge base 4898af2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@metascroy
Contributor

@claude review this code

@claude

claude Bot commented Apr 16, 2026

Claude finished @metascroy's task in 2m 25s.


Code Review: Metal backend topk fallback kernel

  • Read PR diff and changed files
  • Review op_topk.mm implementation
  • Review test coverage
  • Review build integration
  • Post review feedback

Overall the implementation is solid and follows the established patterns from the other Metal ops. The transpose-topk-transpose approach for arbitrary dims and the negate trick for smallest mode are correct. A few issues worth addressing:


Issues

1. Memory leak on error path (op_topk.mm:232-237)

If values_handle creation fails, the code frees both buffers but doesn't remove values_ptr from ptr_to_mtl_buffer. However, the bigger issue is earlier: if the MPSGraph execution itself throws (caught by the outer catch), both values_ptr and indices_ptr are leaked — they were allocated at lines 101-102 but the catch blocks at lines 290-296 don't free them.

Fix this →
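
A possible shape for the fix, sketched below. The allocation and free calls, error code, and variable placement are placeholders for whatever the op actually uses; the point is that both buffers are released on every failure path.

```objc
// Sketch: release both output allocations if anything after them throws.
void* values_ptr = nullptr;
void* indices_ptr = nullptr;
@try {
  values_ptr = allocate_metal_buffer(values_nbytes);    // placeholder allocator
  indices_ptr = allocate_metal_buffer(indices_nbytes);  // placeholder allocator
  // ... look up / build the cached MPSGraph and execute it ...
} @catch (NSException* e) {
  if (values_ptr) { free_metal_buffer(values_ptr); }    // placeholder deallocator
  if (indices_ptr) { free_metal_buffer(indices_ptr); }
  ET_LOG(Error, "aoti_torch_mps_topk failed: %s", e.reason.UTF8String);
  return Error::Internal;  // placeholder error code
}
```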

2. sorted parameter is silently ignored (op_topk.mm:25)

The sorted parameter is accepted but never used. MPSGraph's topKWithSourceTensor returns sorted results by default, so this is likely fine in practice, but sorted=false won't actually give unsorted results. Worth either documenting this limitation or validating that sorted=true (and returning an error for sorted=false if unsorted output isn't supported).
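
One way to make the limitation explicit, sketched here; the logging macro and return value are placeholders for whatever the op actually returns.

```objc
// Sketch: MPSGraph's topK always returns sorted output, so reject the one
// configuration we cannot honor instead of silently ignoring it.
if (!sorted) {
  ET_LOG(Error, "aoti_torch_mps_topk: sorted=false is not supported; output is always sorted");
  return Error::InvalidArgument;  // placeholder error code
}
```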

3. Missing endKernelCoalescing() call (op_topk.mm:146,216)

The other ops (op_mm.mm:108, op_bmm.mm, op_convolution.mm) all call stream->endKernelCoalescing() before executing the MPSGraph. This op doesn't. This could cause issues if kernels were being coalesced when topk is called.

Fix this →
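
The fix is presumably a one-liner mirroring the other ops; the executeMPSGraph call shape below is illustrative.

```objc
// Sketch: flush any coalesced kernels before handing work to MPSGraph,
// as op_mm.mm / op_bmm.mm / op_convolution.mm do.
stream->endKernelCoalescing();
executeMPSGraph(stream, graph, feeds, results);  // call shape is illustrative
```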

4. Missing cache_stats.logStats() calls (op_topk.mm:125,148)

Other ops call cache_stats.logStats() after incrementing hits/misses. This op only increments the counters without calling logStats(), so cache statistics won't be periodically logged for topk operations.
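
For consistency with the other ops, something along these lines; the counter member names are assumptions based on the review's description.

```objc
// Sketch: log periodically right after updating the hit/miss counters.
if (cached_graph) {
  cache_stats.hits++;    // member name assumed from the other ops
} else {
  cache_stats.misses++;  // member name assumed from the other ops
}
cache_stats.logStats();
```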

5. Missing [... release] for MPSGraphTensorData objects (op_topk.mm:132-144,201-213)

The other ops (op_mm.mm:274-276, op_bmm.mm) explicitly release MPSGraphTensorData objects after execution. This op relies on the enclosing @autoreleasepool to clean them up, which works but is inconsistent with the existing pattern. Objects created with [[MPSGraphTensorData alloc] init...] are owned, not autoreleased, so the existing pattern of an explicit release after execution is the safer choice.
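
A sketch of the pattern the review refers to; buffer, shape, and dtype values are illustrative, and the explicit release assumes the file is compiled without ARC, as the releases in op_mm.mm suggest.

```objc
// Sketch: alloc'd MPSGraphTensorData is owned, so release it once the graph has run.
MPSGraphTensorData* input_data =
    [[MPSGraphTensorData alloc] initWithMTLBuffer:input_buffer
                                            shape:@[ @8, @64 ]
                                         dataType:MPSDataTypeFloat32];
// ... feed input_data to the graph and execute it ...
[input_data release];  // assumes non-ARC, matching the other ops
```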

6. @try/@catch only wraps cache-miss path (op_topk.mm:151-222)

The @try/@catch for ObjC exceptions only wraps the cache-miss graph construction and execution. The cache-hit execution at line 146 is not protected, so if executeMPSGraph throws an NSException on a cache hit, it won't be caught by the @try block (it will propagate to the C++ catch(...) which may not handle ObjC exceptions correctly on all toolchains).

Fix this →
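
A sketch of restructuring so both execution paths sit inside the same ObjC exception guard; the call shapes and error code are illustrative.

```objc
// Sketch: guard both the cache-hit and cache-miss executions with one @try/@catch.
@try {
  if (cached_graph) {
    executeMPSGraph(stream, cached_graph, feeds, results);  // cache hit
  } else {
    // ... build the graph, insert it into the cache ...
    executeMPSGraph(stream, new_graph, feeds, results);     // cache miss
  }
} @catch (NSException* e) {
  // Free values_ptr / indices_ptr here as well (see issue 1), log, and return.
  return Error::Internal;  // placeholder error code
}
```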

7. Redundant stream variable in int32→int64 conversion (op_topk.mm:251)

auto* stream_sync = getCurrentMetalStream();

This creates a new stream_sync variable, but stream (from line 37) is the same object. Minor, but unnecessary.


Nit

8. extern declaration inside function body (op_topk.mm:240)

extern std::unordered_map<void*, int32_t> memory_to_n_tensor;

This works but is unusual. The other ops don't seem to need this pattern — the memory.h header (already included transitively through common.h) should provide this declaration. If it doesn't, the extern should be at namespace scope, not inside the function.
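
If memory.h really does not declare it, the namespace-scope form would look like this; the namespace name is a guess used only for illustration.

```objc
// Sketch: declare the shared map at namespace scope instead of inside the function.
// The namespace name here is illustrative, not taken from the source.
namespace executorch::backends::aoti {
extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
}  // namespace executorch::backends::aoti
```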


Test Coverage

The TopK test module (test_modules.py:672-689) only tests dim=-1 (last dimension) with largest=True (default). Since the transpose logic for non-last dimensions and the negate trick for largest=False are the most complex parts of this implementation, it would be valuable to add test cases for:

  • A non-last dimension (e.g., dim=0)
  • largest=False

Build Integration

CMakeLists.txt and metal_backend.py changes look correct — op_topk.mm is added to sources and at::_ops::topk::call is registered as a supported fallback kernel.


size_t element_size;
MPSDataType mps_dtype;

if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
Contributor


Nit: could these be factored out into general utilities?
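
A sketch of the suggested shared helper; the function name, its location, and the BFLOAT16 enumerator are assumptions.

```objc
// Sketch of a shared dtype-mapping utility; names are hypothetical.
inline bool resolve_mps_dtype(int32_t dtype, size_t& element_size, MPSDataType& mps_dtype) {
  if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) {
    element_size = 4;
    mps_dtype = MPSDataTypeFloat32;
    return true;
  }
  if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) {  // enumerator assumed
    element_size = 2;
    mps_dtype = MPSDataTypeBFloat16;
    return true;
  }
  return false;  // unsupported dtype
}
```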

Base automatically changed from gh/manuelcandales/170/head to main April 20, 2026 19:30
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@manuelcandales merged commit 66e4656 into main on Apr 20, 2026
174 of 178 checks passed
@manuelcandales deleted the gh/manuelcandales/171/head branch on April 20, 2026 19:55

Labels

CLA Signed
