@@ -329,7 +329,11 @@ class _HubKernelConfig:
329329_HUB_KERNELS_REGISTRY : dict ["AttentionBackendName" , _HubKernelConfig ] = {
330330 # TODO: temporary revision for now. Remove when merged upstream into `main`.
331331 AttentionBackendName ._FLASH_3_HUB : _HubKernelConfig (
332- repo_id = "kernels-community/flash-attn3" , function_attr = "flash_attn_func" , revision = "fake-ops-return-probs"
332+ repo_id = "kernels-community/flash-attn3" ,
333+ function_attr = "flash_attn_func" ,
334+ revision = "fake-ops-return-probs" ,
335+ wrapped_forward_attr = "flash_attn_interface._flash_attn_forward" ,
336+ wrapped_backward_attr = "flash_attn_interface._flash_attn_backward" ,
333337 ),
334338 AttentionBackendName ._FLASH_3_VARLEN_HUB : _HubKernelConfig (
335339 repo_id = "kernels-community/flash-attn3" ,
@@ -729,7 +733,7 @@ def _wrapped_flash_attn_3(
729733) -> tuple [torch .Tensor , torch .Tensor ]:
730734 # Hardcoded for now because pytorch does not support tuple/int type hints
731735 window_size = (- 1 , - 1 )
732- out , lse , * _ = flash_attn_3_func (
736+ result = flash_attn_3_func (
733737 q = q ,
734738 k = k ,
735739 v = v ,
@@ -746,7 +750,9 @@ def _wrapped_flash_attn_3(
746750 pack_gqa = pack_gqa ,
747751 deterministic = deterministic ,
748752 sm_margin = sm_margin ,
753+ return_attn_probs = True ,
749754 )
755+ out , lse , * _ = result
750756 lse = lse .permute (0 , 2 , 1 )
751757 return out , lse
752758
@@ -1290,36 +1296,62 @@ def _flash_attention_3_hub_forward_op(
12901296 if enable_gqa :
12911297 raise ValueError ("`enable_gqa` is not yet supported for flash-attn 3 hub kernels." )
12921298
1293- func = _HUB_KERNELS_REGISTRY [AttentionBackendName ._FLASH_3_HUB ].kernel_fn
1294- out = func (
1295- q = query ,
1296- k = key ,
1297- v = value ,
1298- softmax_scale = scale ,
1299+ config = _HUB_KERNELS_REGISTRY [AttentionBackendName ._FLASH_3_HUB ]
1300+ wrapped_forward_fn = config .wrapped_forward_fn
1301+ if wrapped_forward_fn is None :
1302+ raise RuntimeError (
1303+ "Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_forward` "
1304+ "for context parallel execution."
1305+ )
1306+
1307+ if scale is None :
1308+ scale = query .shape [- 1 ] ** (- 0.5 )
1309+
1310+ out , softmax_lse , * _ = wrapped_forward_fn (
1311+ query ,
1312+ key ,
1313+ value ,
1314+ None ,
1315+ None , # k_new, v_new
1316+ None , # qv
1317+ None , # out
1318+ None ,
1319+ None ,
1320+ None , # cu_seqlens_q/k/k_new
1321+ None ,
1322+ None , # seqused_q/k
1323+ None ,
1324+ None , # max_seqlen_q/k
1325+ None ,
1326+ None ,
1327+ None , # page_table, kv_batch_idx, leftpad_k
1328+ None ,
1329+ None ,
1330+ None , # rotary_cos/sin, seqlens_rotary
1331+ None ,
1332+ None ,
1333+ None , # q_descale, k_descale, v_descale
1334+ scale ,
12991335 causal = is_causal ,
1300- qv = None ,
1301- q_descale = None ,
1302- k_descale = None ,
1303- v_descale = None ,
1304- window_size = window_size ,
1336+ window_size_left = window_size [0 ],
1337+ window_size_right = window_size [1 ],
1338+ attention_chunk = 0 ,
13051339 softcap = softcap ,
13061340 num_splits = num_splits ,
13071341 pack_gqa = pack_gqa ,
1308- deterministic = deterministic ,
13091342 sm_margin = sm_margin ,
1310- return_attn_probs = return_lse ,
13111343 )
13121344
1313- lse = None
1314- if return_lse :
1315- out , lse = out
1316- lse = lse .permute (0 , 2 , 1 ).contiguous ()
1345+ lse = softmax_lse .permute (0 , 2 , 1 ).contiguous () if return_lse else None
13171346
13181347 if _save_ctx :
1319- ctx .save_for_backward (query , key , value )
1348+ ctx .save_for_backward (query , key , value , out , softmax_lse )
13201349 ctx .scale = scale
13211350 ctx .is_causal = is_causal
1322- ctx ._hub_kernel = func
1351+ ctx .window_size = window_size
1352+ ctx .softcap = softcap
1353+ ctx .deterministic = deterministic
1354+ ctx .sm_margin = sm_margin
13231355
13241356 return (out , lse ) if return_lse else out
13251357
@@ -1328,55 +1360,50 @@ def _flash_attention_3_hub_backward_op(
13281360 ctx : torch .autograd .function .FunctionCtx ,
13291361 grad_out : torch .Tensor ,
13301362 * args ,
1331- window_size : tuple [int , int ] = (- 1 , - 1 ),
1332- softcap : float = 0.0 ,
1333- num_splits : int = 1 ,
1334- pack_gqa : bool | None = None ,
1335- deterministic : bool = False ,
1336- sm_margin : int = 0 ,
1363+ ** kwargs ,
13371364):
1338- query , key , value = ctx .saved_tensors
1339- kernel_fn = ctx ._hub_kernel
1340- # NOTE: Unlike the FA2 hub kernel, the FA3 hub kernel does not expose separate wrapped forward/backward
1341- # primitives (no `wrapped_forward_attr`/`wrapped_backward_attr` in its `_HubKernelConfig`). We
1342- # therefore rerun the forward pass under `torch.enable_grad()` and differentiate through it with
1343- # `torch.autograd.grad()`. This is a second forward pass during backward; it can be avoided once
1344- # the FA3 hub exposes a dedicated fused backward kernel (analogous to `_wrapped_flash_attn_backward`
1345- # in the FA2 hub), at which point this can be refactored to match `_flash_attention_hub_backward_op`.
1346- with torch .enable_grad ():
1347- query_r = query .detach ().requires_grad_ (True )
1348- key_r = key .detach ().requires_grad_ (True )
1349- value_r = value .detach ().requires_grad_ (True )
1350-
1351- out = kernel_fn (
1352- q = query_r ,
1353- k = key_r ,
1354- v = value_r ,
1355- softmax_scale = ctx .scale ,
1356- causal = ctx .is_causal ,
1357- qv = None ,
1358- q_descale = None ,
1359- k_descale = None ,
1360- v_descale = None ,
1361- window_size = window_size ,
1362- softcap = softcap ,
1363- num_splits = num_splits ,
1364- pack_gqa = pack_gqa ,
1365- deterministic = deterministic ,
1366- sm_margin = sm_margin ,
1367- return_attn_probs = False ,
1368- )
1369- if isinstance (out , tuple ):
1370- out = out [0 ]
1371-
1372- grad_query , grad_key , grad_value = torch .autograd .grad (
1373- out ,
1374- (query_r , key_r , value_r ),
1375- grad_out ,
1376- retain_graph = False ,
1377- allow_unused = False ,
1365+ config = _HUB_KERNELS_REGISTRY [AttentionBackendName ._FLASH_3_HUB ]
1366+ wrapped_backward_fn = config .wrapped_backward_fn
1367+ if wrapped_backward_fn is None :
1368+ raise RuntimeError (
1369+ "Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_backward` "
1370+ "for context parallel execution."
13781371 )
13791372
1373+ query , key , value , out , softmax_lse = ctx .saved_tensors
1374+ grad_query = torch .empty_like (query )
1375+ grad_key = torch .empty_like (key )
1376+ grad_value = torch .empty_like (value )
1377+
1378+ wrapped_backward_fn (
1379+ grad_out ,
1380+ query ,
1381+ key ,
1382+ value ,
1383+ out ,
1384+ softmax_lse ,
1385+ None ,
1386+ None , # cu_seqlens_q, cu_seqlens_k
1387+ None ,
1388+ None , # seqused_q, seqused_k
1389+ None ,
1390+ None , # max_seqlen_q, max_seqlen_k
1391+ grad_query ,
1392+ grad_key ,
1393+ grad_value ,
1394+ ctx .scale ,
1395+ ctx .is_causal ,
1396+ ctx .window_size [0 ],
1397+ ctx .window_size [1 ],
1398+ ctx .softcap ,
1399+ ctx .deterministic ,
1400+ ctx .sm_margin ,
1401+ )
1402+
1403+ grad_query = grad_query [..., : grad_out .shape [- 1 ]]
1404+ grad_key = grad_key [..., : grad_out .shape [- 1 ]]
1405+ grad_value = grad_value [..., : grad_out .shape [- 1 ]]
1406+
13801407 return grad_query , grad_key , grad_value
13811408
13821409
@@ -2676,7 +2703,7 @@ def _flash_varlen_attention_3(
26762703 key_packed = torch .cat (key_valid , dim = 0 )
26772704 value_packed = torch .cat (value_valid , dim = 0 )
26782705
2679- out , lse , * _ = flash_attn_3_varlen_func (
2706+ result = flash_attn_3_varlen_func (
26802707 q = query_packed ,
26812708 k = key_packed ,
26822709 v = value_packed ,
@@ -2686,7 +2713,13 @@ def _flash_varlen_attention_3(
26862713 max_seqlen_k = max_seqlen_k ,
26872714 softmax_scale = scale ,
26882715 causal = is_causal ,
2716+ return_attn_probs = return_lse ,
26892717 )
2718+ if isinstance (result , tuple ):
2719+ out , lse , * _ = result
2720+ else :
2721+ out = result
2722+ lse = None
26902723 out = out .unflatten (0 , (batch_size , - 1 ))
26912724
26922725 return (out , lse ) if return_lse else out
0 commit comments