251251
252252extern " C" {
253253
254+ // Forward declaration of the implementation shared by both v1 and v2.
255+ static AOTITorchError sdpa_mps_impl (
256+ AOTITensorHandle query,
257+ AOTITensorHandle key,
258+ AOTITensorHandle value,
259+ AOTITensorHandle* attn_mask,
260+ double dropout_p,
261+ int32_t is_causal,
262+ AOTITensorHandle* dropout_mask,
263+ double * scale,
264+ int32_t enable_gqa,
265+ AOTITensorHandle* ret0,
266+ AOTITensorHandle* ret1);
267+
268+ // v1: Original signature without enable_gqa (for old .pte files).
254269AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps (
255270 AOTITensorHandle query,
256271 AOTITensorHandle key,
@@ -262,6 +277,41 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
262277 double * scale,
263278 AOTITensorHandle* ret0,
264279 AOTITensorHandle* ret1) {
280+ return sdpa_mps_impl (
281+ query, key, value, attn_mask, dropout_p, is_causal,
282+ dropout_mask, scale, /* enable_gqa=*/ 0 , ret0, ret1);
283+ }
284+
285+ // v2: New signature with enable_gqa (for new .pte files).
286+ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps_v2 (
287+ AOTITensorHandle query,
288+ AOTITensorHandle key,
289+ AOTITensorHandle value,
290+ AOTITensorHandle* attn_mask,
291+ double dropout_p,
292+ int32_t is_causal,
293+ AOTITensorHandle* dropout_mask,
294+ double * scale,
295+ int32_t enable_gqa,
296+ AOTITensorHandle* ret0,
297+ AOTITensorHandle* ret1) {
298+ return sdpa_mps_impl (
299+ query, key, value, attn_mask, dropout_p, is_causal,
300+ dropout_mask, scale, enable_gqa, ret0, ret1);
301+ }
302+
303+ static AOTITorchError sdpa_mps_impl (
304+ AOTITensorHandle query,
305+ AOTITensorHandle key,
306+ AOTITensorHandle value,
307+ AOTITensorHandle* attn_mask,
308+ double dropout_p,
309+ int32_t is_causal,
310+ AOTITensorHandle* dropout_mask,
311+ double * scale,
312+ int32_t enable_gqa,
313+ AOTITensorHandle* ret0,
314+ AOTITensorHandle* ret1) {
265315
266316 ET_LOG (Debug, " aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Starting with Metal kernel implementation" );
267317
0 commit comments