@@ -680,6 +680,7 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
680680
681681 Args:
682682 old_state_dict: state dict from the old AutoencoderKL model.
683+ verbose: if True, print diagnostic information about key mismatches.
683684 """
684685
685686 new_state_dict = self.state_dict()
@@ -715,13 +716,39 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
715716 new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias")
716717 new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias")
717718
718- # old version did not have a projection so set these to the identity
719- new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye(
720-     new_state_dict[f"{block}.attn.out_proj.weight"].shape[0]
721- )
722- new_state_dict[f"{block}.attn.out_proj.bias"] = torch.zeros(
723-     new_state_dict[f"{block}.attn.out_proj.bias"].shape
724- )
719+ out_w = f"{block}.attn.out_proj.weight"
720+ out_b = f"{block}.attn.out_proj.bias"
721+ proj_w = f"{block}.proj_attn.weight"
722+ proj_b = f"{block}.proj_attn.bias"
723+
724+ if out_w in new_state_dict:
725+     if proj_w in old_state_dict:
726+         new_state_dict[out_w] = old_state_dict.pop(proj_w)
727+         if proj_b in old_state_dict:
728+             new_state_dict[out_b] = old_state_dict.pop(proj_b)
729+         else:
730+             new_state_dict[out_b] = torch.zeros(
731+                 new_state_dict[out_b].shape,
732+                 dtype=new_state_dict[out_b].dtype,
733+                 device=new_state_dict[out_b].device,
734+             )
735+     else:
736+         # No legacy proj_attn - initialize out_proj to identity/zero
737+         new_state_dict[out_w] = torch.eye(
738+             new_state_dict[out_w].shape[0],
739+             dtype=new_state_dict[out_w].dtype,
740+             device=new_state_dict[out_w].device,
741+         )
742+         new_state_dict[out_b] = torch.zeros(
743+             new_state_dict[out_b].shape,
744+             dtype=new_state_dict[out_b].dtype,
745+             device=new_state_dict[out_b].device,
746+         )
747+ elif proj_w in old_state_dict:
748+     # new model has no out_proj at all - discard the legacy keys so they
749+     # don't surface as "unexpected keys" during load_state_dict
750+     old_state_dict.pop(proj_w)
751+     old_state_dict.pop(proj_b, None)
725752
726753 # fix the upsample conv blocks which were renamed postconv
727754 for k in new_state_dict :
0 commit comments