Commit c0c079b

Metal backend: Add v2 entry point with enable_gqa (#19145)
1 parent d7f8718 · commit c0c079b

2 files changed: 51 additions & 0 deletions

backends/apple/metal/metal_backend.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
             "aoti_torch_mps_convolution": None,
             "aoti_torch_mps_mm_out": None,
             "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None,
+            "at::_ops::_scaled_dot_product_attention_math_for_mps_v2::call": None,
             "torchao::_linear_fp_act_4bit_weight": None,
             "at::_ops::topk::call": None,
             "metal::gather_qmv": None,

backends/apple/metal/runtime/ops/op_sdpa.mm

Lines changed: 50 additions & 0 deletions
@@ -251,6 +251,21 @@
 
 extern "C" {
 
+// Forward declaration of the implementation shared by both v1 and v2.
+static AOTITorchError sdpa_mps_impl(
+    AOTITensorHandle query,
+    AOTITensorHandle key,
+    AOTITensorHandle value,
+    AOTITensorHandle* attn_mask,
+    double dropout_p,
+    int32_t is_causal,
+    AOTITensorHandle* dropout_mask,
+    double* scale,
+    int32_t enable_gqa,
+    AOTITensorHandle* ret0,
+    AOTITensorHandle* ret1);
+
+// v1: Original signature without enable_gqa (for old .pte files).
 AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
     AOTITensorHandle query,
     AOTITensorHandle key,
@@ -262,6 +277,41 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
     double* scale,
     AOTITensorHandle* ret0,
     AOTITensorHandle* ret1) {
+  return sdpa_mps_impl(
+      query, key, value, attn_mask, dropout_p, is_causal,
+      dropout_mask, scale, /*enable_gqa=*/0, ret0, ret1);
+}
+
+// v2: New signature with enable_gqa (for new .pte files).
+AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps_v2(
+    AOTITensorHandle query,
+    AOTITensorHandle key,
+    AOTITensorHandle value,
+    AOTITensorHandle* attn_mask,
+    double dropout_p,
+    int32_t is_causal,
+    AOTITensorHandle* dropout_mask,
+    double* scale,
+    int32_t enable_gqa,
+    AOTITensorHandle* ret0,
+    AOTITensorHandle* ret1) {
+  return sdpa_mps_impl(
+      query, key, value, attn_mask, dropout_p, is_causal,
+      dropout_mask, scale, enable_gqa, ret0, ret1);
+}
+
+static AOTITorchError sdpa_mps_impl(
+    AOTITensorHandle query,
+    AOTITensorHandle key,
+    AOTITensorHandle value,
+    AOTITensorHandle* attn_mask,
+    double dropout_p,
+    int32_t is_causal,
+    AOTITensorHandle* dropout_mask,
+    double* scale,
+    int32_t enable_gqa,
+    AOTITensorHandle* ret0,
+    AOTITensorHandle* ret1) {
 
   ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Starting with Metal kernel implementation");

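The pattern in this commit is a standard way to evolve a C ABI without breaking older consumers: keep the old exported symbol as a thin wrapper that pins the new parameter to its backward-compatible default, add a versioned symbol that accepts the new parameter, and route both through one shared implementation. Below is a minimal, self-contained C++ sketch of that pattern; the Error and TensorHandle aliases and the sdpa_v1/sdpa_v2 names are simplified stand-ins for illustration, not the real AOTI types or exported symbols.

// Minimal sketch of the versioned-entry-point pattern used in this commit.
// "Error" and "TensorHandle" are simplified stand-ins, not real AOTI types.
#include <cstdint>
#include <cstdio>

using Error = int32_t;       // stand-in for AOTITorchError
using TensorHandle = void*;  // stand-in for AOTITensorHandle

// Single shared implementation; it always takes the full v2 parameter list.
static Error sdpa_impl(TensorHandle q, TensorHandle k, TensorHandle v,
                       int32_t is_causal, int32_t enable_gqa) {
  std::printf("sdpa_impl: is_causal=%d enable_gqa=%d\n",
              (int)is_causal, (int)enable_gqa);
  return 0;  // success
}

extern "C" {

// v1 keeps the old ABI alive for previously exported artifacts by
// pinning the new flag to its backward-compatible default.
Error sdpa_v1(TensorHandle q, TensorHandle k, TensorHandle v,
              int32_t is_causal) {
  return sdpa_impl(q, k, v, is_causal, /*enable_gqa=*/0);
}

// v2 exposes the new flag to newly exported artifacts.
Error sdpa_v2(TensorHandle q, TensorHandle k, TensorHandle v,
              int32_t is_causal, int32_t enable_gqa) {
  return sdpa_impl(q, k, v, is_causal, enable_gqa);
}

}  // extern "C"

int main() {
  sdpa_v1(nullptr, nullptr, nullptr, /*is_causal=*/1);  // old callers: gqa off
  sdpa_v2(nullptr, nullptr, nullptr, /*is_causal=*/1,
          /*enable_gqa=*/1);                            // new callers opt in
}

Because both symbols stay exported (and the v2 kernel name is registered as a supported fallback in metal_backend.py), an older .pte that references the v1 name keeps resolving, while newly exported programs can bind the v2 name and pass enable_gqa through.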