
Commit 9e44025

Merge branch 'main' into ci-pin-setuptools-pkg-resources
2 parents: 00eaa47 + a577ec3 (commit 9e44025)

File tree: 1 file changed (+19, -13 lines)


src/diffusers/models/transformers/transformer_flux2.py

Lines changed: 19 additions & 13 deletions
@@ -424,7 +424,7 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor | None,
-        temb_mod_params: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+        temb_mod: torch.Tensor,
         image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
         joint_attention_kwargs: dict[str, Any] | None = None,
         split_hidden_states: bool = False,
@@ -436,7 +436,7 @@ def forward(
             text_seq_len = encoder_hidden_states.shape[1]
             hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
-        mod_shift, mod_scale, mod_gate = temb_mod_params
+        mod_shift, mod_scale, mod_gate = Flux2Modulation.split(temb_mod, 1)[0]
 
         norm_hidden_states = self.norm(hidden_states)
         norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift
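
Note on the trailing [0]: with a single modulation set, Flux2Modulation.split returns a one-element tuple of (shift, scale, gate) triples, so the single-stream block unpacks its only entry. A minimal illustration, with shapes assumed rather than taken from the file:

# Assumed shapes, for illustration only: temb_mod is the flat output of the
# single-stream Flux2Modulation layer (mod_param_sets=1), i.e. [batch, 3 * dim].
triples = Flux2Modulation.split(temb_mod, 1)  # tuple of length 1
mod_shift, mod_scale, mod_gate = triples[0]   # each [batch, 1, dim] after the internal unsqueeze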
@@ -498,16 +498,18 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
-        temb_mod_params_img: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
-        temb_mod_params_txt: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
+        temb_mod_img: torch.Tensor,
+        temb_mod_txt: torch.Tensor,
         image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
         joint_attention_kwargs: dict[str, Any] | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         joint_attention_kwargs = joint_attention_kwargs or {}
 
         # Modulation parameters shape: [1, 1, self.dim]
-        (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
-        (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt
+        (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = Flux2Modulation.split(temb_mod_img, 2)
+        (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = Flux2Modulation.split(
+            temb_mod_txt, 2
+        )
 
         # Img stream
         norm_hidden_states = self.norm1(hidden_states)
@@ -627,15 +629,19 @@ def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
         self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
         self.act_fn = nn.SiLU()
 
-    def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
+    def forward(self, temb: torch.Tensor) -> torch.Tensor:
         mod = self.act_fn(temb)
         mod = self.linear(mod)
+        return mod
 
+    @staticmethod
+    # split inside the transformer blocks, to avoid passing tuples into checkpoints https://github.com/huggingface/diffusers/issues/12776
+    def split(mod: torch.Tensor, mod_param_sets: int) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
         if mod.ndim == 2:
             mod = mod.unsqueeze(1)
-        mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
+        mod_params = torch.chunk(mod, 3 * mod_param_sets, dim=-1)
         # Return tuple of 3-tuples of modulation params shift/scale/gate
-        return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
+        return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(mod_param_sets))
 
 
 class Flux2Transformer2DModel(
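
For reference, a self-contained sketch of the new split behaviour; the helper below is a stand-in for Flux2Modulation.split, and dim, batch size, and tensor contents are illustrative assumptions only:

import torch

# Stand-alone equivalent of the split logic above, for illustration: the modulation
# projection now returns one flat tensor of shape [B, 3 * mod_param_sets * dim], and
# split() regroups it into mod_param_sets (shift, scale, gate) triples of shape [B, 1, dim].
def split_modulation(mod: torch.Tensor, mod_param_sets: int):
    if mod.ndim == 2:
        mod = mod.unsqueeze(1)
    mod_params = torch.chunk(mod, 3 * mod_param_sets, dim=-1)
    return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(mod_param_sets))

dim = 8
temb_mod = torch.randn(1, 3 * 2 * dim)  # as emitted by a modulation layer with mod_param_sets=2
(shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = split_modulation(temb_mod, 2)
print(shift_msa.shape)  # torch.Size([1, 1, 8])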
@@ -824,7 +830,7 @@ def forward(
 
         double_stream_mod_img = self.double_stream_modulation_img(temb)
         double_stream_mod_txt = self.double_stream_modulation_txt(temb)
-        single_stream_mod = self.single_stream_modulation(temb)[0]
+        single_stream_mod = self.single_stream_modulation(temb)
 
         # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
         hidden_states = self.x_embedder(hidden_states)
@@ -861,8 +867,8 @@ def forward(
             encoder_hidden_states, hidden_states = block(
                 hidden_states=hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
-                temb_mod_params_img=double_stream_mod_img,
-                temb_mod_params_txt=double_stream_mod_txt,
+                temb_mod_img=double_stream_mod_img,
+                temb_mod_txt=double_stream_mod_txt,
                 image_rotary_emb=concat_rotary_emb,
                 joint_attention_kwargs=joint_attention_kwargs,
             )
@@ -884,7 +890,7 @@ def forward(
             hidden_states = block(
                 hidden_states=hidden_states,
                 encoder_hidden_states=None,
-                temb_mod_params=single_stream_mod,
+                temb_mod=single_stream_mod,
                 image_rotary_emb=concat_rotary_emb,
                 joint_attention_kwargs=joint_attention_kwargs,
             )
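
The comment added inside Flux2Modulation states the motivation: the flat tensor is split inside the blocks so that no tuples are passed into checkpointed calls (issue #12776). A hedged sketch of that calling pattern; the helper and argument names are placeholders, not the diffusers API:

from torch.utils.checkpoint import checkpoint

# Placeholder illustration: with the change above, everything forwarded to a
# checkpointed block is a plain tensor (hidden_states, temb_mod, ...) rather than
# a nested tuple of modulation tensors; the split into (shift, scale, gate)
# triples happens inside the block itself.
def run_block(block, hidden_states, temb_mod, gradient_checkpointing=False):
    if gradient_checkpointing:
        return checkpoint(block, hidden_states, temb_mod, use_reentrant=False)
    return block(hidden_states, temb_mod)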
