Commit 47455bd

Fix Flash Attention 3 interface for new FA3 return format (#13173)
* Fix Flash Attention 3 interface compatibility for new FA3 versions

  Newer versions of flash-attn (after Dao-AILab/flash-attention@ed20940) no longer return lse by default from flash_attn_3_func: the function now returns just the output tensor unless return_attn_probs=True is passed. Updated _wrapped_flash_attn_3 and _flash_varlen_attention_3 to pass return_attn_probs and to handle both the old (always a tuple) and new (tensor or tuple) return formats gracefully.

  Fixes #12022

* Simplify _wrapped_flash_attn_3 return unpacking

  Since return_attn_probs=True is always passed there, the result is guaranteed to be a tuple, so the isinstance guard is unnecessary and was removed.
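The compatibility handling described in the message can be sketched in isolation. This is a minimal illustration of the pattern, not the diffusers code itself; `fa3_func` and `call_fa3_compat` are hypothetical stand-ins for `flash_attn_3_func` and its wrapper:

```python
# Minimal sketch of the return-format compatibility pattern described above.
# `fa3_func` stands in for flash_attn_3_func: older flash-attn releases always
# return an (out, lse, ...) tuple, while newer ones return just `out` unless
# return_attn_probs=True is passed.

def call_fa3_compat(fa3_func, q, k, v, return_lse=False):
    result = fa3_func(q, k, v, return_attn_probs=return_lse)
    if isinstance(result, tuple):
        # Old format, or new format with attention probs requested.
        out, lse, *_ = result
    else:
        # New format without attention probs: a bare output tensor.
        out, lse = result, None
    return (out, lse) if return_lse else out
```

Because the wrapper branches on the result type rather than on the flash-attn version, the same call site works against both old and new releases.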
1 parent 97c2c6e commit 47455bd

File tree

1 file changed (+10 −2 lines)


src/diffusers/models/attention_dispatch.py

Lines changed: 10 additions & 2 deletions
@@ -733,7 +733,7 @@ def _wrapped_flash_attn_3(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     # Hardcoded for now because pytorch does not support tuple/int type hints
     window_size = (-1, -1)
-    out, lse, *_ = flash_attn_3_func(
+    result = flash_attn_3_func(
         q=q,
         k=k,
         v=v,
@@ -750,7 +750,9 @@ def _wrapped_flash_attn_3(
         pack_gqa=pack_gqa,
         deterministic=deterministic,
         sm_margin=sm_margin,
+        return_attn_probs=True,
     )
+    out, lse, *_ = result
     lse = lse.permute(0, 2, 1)
     return out, lse
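In the hunk above, `return_attn_probs=True` is always passed, so the call is guaranteed to return a tuple and the extended unpacking `out, lse, *_` silently discards any extra trailing values. The `permute(0, 2, 1)` then swaps the last two axes of lse. A small stand-alone sketch (pure Python; the shape convention — lse arriving as (batch, num_heads, seqlen) and being permuted to (batch, seqlen, num_heads) — is an assumption for illustration, and `unpack_fa3_result`/`permuted_shape` are hypothetical helpers, not diffusers APIs):

```python
# Sketch of the unpacking used in _wrapped_flash_attn_3 after the fix: with a
# guaranteed tuple, `out, lse, *_` takes the first two values and ignores any
# extras a newer release might append.
def unpack_fa3_result(result):
    out, lse, *_ = result
    return out, lse

# Toy model of lse.permute(0, 2, 1), acting on a shape tuple instead of a
# tensor (assumption: (batch, num_heads, seqlen) -> (batch, seqlen, num_heads)).
def permuted_shape(shape):
    b, h, s = shape
    return (b, s, h)
```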

@@ -2701,7 +2703,7 @@ def _flash_varlen_attention_3(
     key_packed = torch.cat(key_valid, dim=0)
     value_packed = torch.cat(value_valid, dim=0)
 
-    out, lse, *_ = flash_attn_3_varlen_func(
+    result = flash_attn_3_varlen_func(
         q=query_packed,
         k=key_packed,
         v=value_packed,
@@ -2711,7 +2713,13 @@ def _flash_varlen_attention_3(
         max_seqlen_k=max_seqlen_k,
         softmax_scale=scale,
         causal=is_causal,
+        return_attn_probs=return_lse,
     )
+    if isinstance(result, tuple):
+        out, lse, *_ = result
+    else:
+        out = result
+        lse = None
     out = out.unflatten(0, (batch_size, -1))
 
     return (out, lse) if return_lse else out

0 commit comments
