|
8 | 8 |
|
9 | 9 | import numpy as np |
10 | 10 | import torch |
11 | | -from diffusers.models.attention_dispatch import dispatch_attention_fn |
12 | 11 | from diffusers.models.modeling_outputs import Transformer2DModelOutput |
13 | 12 | from diffusers.models.transformers.transformer_flux import ( |
14 | 13 | FluxAttention, |
|
19 | 18 | _get_qkv_projections, |
20 | 19 | ) |
21 | 20 |
|
| 21 | +from QEfficient.diffusers.models.modeling_utils import compute_blocked_attention, get_attention_blocking_config |
22 | 22 | from QEfficient.utils.logging_utils import logger |
23 | 23 |
|
24 | 24 |
|
@@ -89,9 +89,21 @@ def __call__( |
89 | 89 | query = qeff_apply_rotary_emb(query, image_rotary_emb) |
90 | 90 | key = qeff_apply_rotary_emb(key, image_rotary_emb) |
91 | 91 |
|
92 | | - hidden_states = dispatch_attention_fn( |
93 | | - query, key, value, attn_mask=attention_mask, backend=self._attention_backend |
| 92 | + # Get blocking configuration |
| 93 | + blocking_mode, head_block_size, num_kv_blocks, num_q_blocks = get_attention_blocking_config() |
| 94 | +        # Apply blocked attention from modeling_utils (expects batch, heads, seq, head_dim) |
| 95 | + hidden_states = compute_blocked_attention( |
| 96 | + query.transpose(1, 2), |
| 97 | + key.transpose(1, 2), |
| 98 | + value.transpose(1, 2), |
| 99 | + blocking_mode=blocking_mode, |
| 100 | + head_block_size=head_block_size, |
| 101 | + num_kv_blocks=num_kv_blocks, |
| 102 | + num_q_blocks=num_q_blocks, |
| 103 | + attention_mask=attention_mask, |
94 | 104 | ) |
| 105 | + |
| 106 | +        hidden_states = hidden_states.transpose(1, 2)  # back to (batch, seq, heads, head_dim) |
95 | 107 | hidden_states = hidden_states.flatten(2, 3) |
96 | 108 | hidden_states = hidden_states.to(query.dtype) |
97 | 109 |
|
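For context, `get_attention_blocking_config()` returns a 4-tuple of blocking parameters. A minimal sketch of what such a helper might look like, assuming the values come from environment variables; the actual QEfficient implementation may source them from a pipeline or model config instead, and every name here other than the returned tuple is hypothetical:

```python
import os

# Hypothetical sketch only: the real get_attention_blocking_config may read a
# pipeline/model config rather than environment variables.
def get_attention_blocking_config_sketch():
    # blocking_mode selects how attention is tiled (e.g. over heads, query
    # blocks, or key/value blocks); the remaining values size those tiles.
    blocking_mode = os.environ.get("QEFF_ATTN_BLOCKING_MODE", "none")
    head_block_size = int(os.environ.get("QEFF_ATTN_HEAD_BLOCK_SIZE", "0"))
    num_kv_blocks = int(os.environ.get("QEFF_ATTN_NUM_KV_BLOCKS", "1"))
    num_q_blocks = int(os.environ.get("QEFF_ATTN_NUM_Q_BLOCKS", "1"))
    return blocking_mode, head_block_size, num_kv_blocks, num_q_blocks
```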
|
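Likewise, a minimal sketch of the kind of computation a blocked attention helper along the lines of `compute_blocked_attention` performs, assuming inputs in (batch, heads, seq, head_dim) layout (matching the transposes in the diff) and query-only blocking; the real helper also accepts `blocking_mode`, `head_block_size`, and `num_kv_blocks`, which are omitted here, and this sketch is not the QEfficient implementation:

```python
import torch
import torch.nn.functional as F

# Hypothetical sketch only: split the query along the sequence dimension and
# attend each chunk against the full key/value.
def compute_blocked_attention_sketch(query, key, value, num_q_blocks=4, attention_mask=None):
    # query/key/value: (batch, heads, seq_len, head_dim)
    outputs = []
    start = 0
    for q_chunk in torch.chunk(query, num_q_blocks, dim=2):
        chunk_len = q_chunk.shape[2]
        mask_chunk = None
        if attention_mask is not None:
            # Slice the mask rows to match this query chunk.
            mask_chunk = attention_mask[..., start : start + chunk_len, :]
        outputs.append(
            F.scaled_dot_product_attention(q_chunk, key, value, attn_mask=mask_chunk)
        )
        start += chunk_len
    return torch.cat(outputs, dim=2)
```

Blocking of this kind trades one large attention matmul for several smaller ones, which can help keep intermediate tensors within on-chip memory limits on accelerators.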