@@ -424,7 +424,7 @@ def forward(
424424 self ,
425425 hidden_states : torch .Tensor ,
426426 encoder_hidden_states : torch .Tensor | None ,
427- temb_mod_params : tuple [ torch .Tensor , torch . Tensor , torch . Tensor ] ,
427+ temb_mod : torch .Tensor ,
428428 image_rotary_emb : tuple [torch .Tensor , torch .Tensor ] | None = None ,
429429 joint_attention_kwargs : dict [str , Any ] | None = None ,
430430 split_hidden_states : bool = False ,
@@ -436,7 +436,7 @@ def forward(
436436 text_seq_len = encoder_hidden_states .shape [1 ]
437437 hidden_states = torch .cat ([encoder_hidden_states , hidden_states ], dim = 1 )
438438
439- mod_shift , mod_scale , mod_gate = temb_mod_params
439+ mod_shift , mod_scale , mod_gate = Flux2Modulation . split ( temb_mod , 1 )[ 0 ]
440440
441441 norm_hidden_states = self .norm (hidden_states )
442442 norm_hidden_states = (1 + mod_scale ) * norm_hidden_states + mod_shift
@@ -498,16 +498,18 @@ def forward(
498498 self ,
499499 hidden_states : torch .Tensor ,
500500 encoder_hidden_states : torch .Tensor ,
501- temb_mod_params_img : tuple [ tuple [ torch .Tensor , torch . Tensor , torch . Tensor ], ...] ,
502- temb_mod_params_txt : tuple [ tuple [ torch .Tensor , torch . Tensor , torch . Tensor ], ...] ,
501+ temb_mod_img : torch .Tensor ,
502+ temb_mod_txt : torch .Tensor ,
503503 image_rotary_emb : tuple [torch .Tensor , torch .Tensor ] | None = None ,
504504 joint_attention_kwargs : dict [str , Any ] | None = None ,
505505 ) -> tuple [torch .Tensor , torch .Tensor ]:
506506 joint_attention_kwargs = joint_attention_kwargs or {}
507507
508508 # Modulation parameters shape: [1, 1, self.dim]
509- (shift_msa , scale_msa , gate_msa ), (shift_mlp , scale_mlp , gate_mlp ) = temb_mod_params_img
510- (c_shift_msa , c_scale_msa , c_gate_msa ), (c_shift_mlp , c_scale_mlp , c_gate_mlp ) = temb_mod_params_txt
509+ (shift_msa , scale_msa , gate_msa ), (shift_mlp , scale_mlp , gate_mlp ) = Flux2Modulation .split (temb_mod_img , 2 )
510+ (c_shift_msa , c_scale_msa , c_gate_msa ), (c_shift_mlp , c_scale_mlp , c_gate_mlp ) = Flux2Modulation .split (
511+ temb_mod_txt , 2
512+ )
511513
512514 # Img stream
513515 norm_hidden_states = self .norm1 (hidden_states )
@@ -627,15 +629,19 @@ def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
627629 self .linear = nn .Linear (dim , dim * 3 * self .mod_param_sets , bias = bias )
628630 self .act_fn = nn .SiLU ()
629631
630- def forward (self , temb : torch .Tensor ) -> tuple [ tuple [ torch .Tensor , torch . Tensor , torch . Tensor ], ...] :
632+ def forward (self , temb : torch .Tensor ) -> torch .Tensor :
631633 mod = self .act_fn (temb )
632634 mod = self .linear (mod )
635+ return mod
633636
637+ @staticmethod
638+ # split inside the transformer blocks, to avoid passing tuples into checkpoints https://github.com/huggingface/diffusers/issues/12776
639+ def split (mod : torch .Tensor , mod_param_sets : int ) -> tuple [tuple [torch .Tensor , torch .Tensor , torch .Tensor ], ...]:
634640 if mod .ndim == 2 :
635641 mod = mod .unsqueeze (1 )
636- mod_params = torch .chunk (mod , 3 * self . mod_param_sets , dim = - 1 )
642+ mod_params = torch .chunk (mod , 3 * mod_param_sets , dim = - 1 )
637643 # Return tuple of 3-tuples of modulation params shift/scale/gate
638- return tuple (mod_params [3 * i : 3 * (i + 1 )] for i in range (self . mod_param_sets ))
644+ return tuple (mod_params [3 * i : 3 * (i + 1 )] for i in range (mod_param_sets ))
639645
640646
641647class Flux2Transformer2DModel (
@@ -824,7 +830,7 @@ def forward(
824830
825831 double_stream_mod_img = self .double_stream_modulation_img (temb )
826832 double_stream_mod_txt = self .double_stream_modulation_txt (temb )
827- single_stream_mod = self .single_stream_modulation (temb )[ 0 ]
833+ single_stream_mod = self .single_stream_modulation (temb )
828834
829835 # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
830836 hidden_states = self .x_embedder (hidden_states )
@@ -861,8 +867,8 @@ def forward(
861867 encoder_hidden_states , hidden_states = block (
862868 hidden_states = hidden_states ,
863869 encoder_hidden_states = encoder_hidden_states ,
864- temb_mod_params_img = double_stream_mod_img ,
865- temb_mod_params_txt = double_stream_mod_txt ,
870+ temb_mod_img = double_stream_mod_img ,
871+ temb_mod_txt = double_stream_mod_txt ,
866872 image_rotary_emb = concat_rotary_emb ,
867873 joint_attention_kwargs = joint_attention_kwargs ,
868874 )
@@ -884,7 +890,7 @@ def forward(
884890 hidden_states = block (
885891 hidden_states = hidden_states ,
886892 encoder_hidden_states = None ,
887- temb_mod_params = single_stream_mod ,
893+ temb_mod = single_stream_mod ,
888894 image_rotary_emb = concat_rotary_emb ,
889895 joint_attention_kwargs = joint_attention_kwargs ,
890896 )