Skip to content

Commit 18de278

Browse files
quic-amitraj and Amit Raj authored
Added blocking support to flux (#679)
Added blocking support to flux --------- Signed-off-by: Amit Raj <amitraj@qti.qualcommm.com> Co-authored-by: Amit Raj <amitraj@qti.qualcommm.com>
1 parent 8721324 commit 18de278

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

QEfficient/diffusers/models/transformers/transformer_flux.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import numpy as np
1010
import torch
11-
from diffusers.models.attention_dispatch import dispatch_attention_fn
1211
from diffusers.models.modeling_outputs import Transformer2DModelOutput
1312
from diffusers.models.transformers.transformer_flux import (
1413
FluxAttention,
@@ -19,6 +18,7 @@
1918
_get_qkv_projections,
2019
)
2120

21+
from QEfficient.diffusers.models.modeling_utils import compute_blocked_attention, get_attention_blocking_config
2222
from QEfficient.utils.logging_utils import logger
2323

2424

@@ -89,9 +89,21 @@ def __call__(
8989
query = qeff_apply_rotary_emb(query, image_rotary_emb)
9090
key = qeff_apply_rotary_emb(key, image_rotary_emb)
9191

92-
hidden_states = dispatch_attention_fn(
93-
query, key, value, attn_mask=attention_mask, backend=self._attention_backend
92+
# Get blocking configuration
93+
blocking_mode, head_block_size, num_kv_blocks, num_q_blocks = get_attention_blocking_config()
94+
# Apply blocking using pipeline_utils
95+
hidden_states = compute_blocked_attention(
96+
query.transpose(1, 2),
97+
key.transpose(1, 2),
98+
value.transpose(1, 2),
99+
blocking_mode=blocking_mode,
100+
head_block_size=head_block_size,
101+
num_kv_blocks=num_kv_blocks,
102+
num_q_blocks=num_q_blocks,
103+
attention_mask=attention_mask,
94104
)
105+
106+
hidden_states = hidden_states.transpose(1, 2)
95107
hidden_states = hidden_states.flatten(2, 3)
96108
hidden_states = hidden_states.to(query.dtype)
97109

0 commit comments

Comments (0)