@@ -861,21 +861,44 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
861861
862862
863863class FalconDummyPastKeyValuesGenerator (DummyPastKeyValuesGenerator ):
864- def generate (self , input_name : str , framework : str = "pt" , int_dtype : str = "int64" , float_dtype : str = "fp32" ):
865- self .num_kv_heads = 1
866- head_dim = self .hidden_size // self .num_attention_heads
864+ def __init__ (
865+ self ,
866+ task : str ,
867+ normalized_config : NormalizedTextConfig ,
868+ batch_size : int = DEFAULT_DUMMY_SHAPES ["batch_size" ],
869+ sequence_length : int = DEFAULT_DUMMY_SHAPES ["sequence_length" ],
870+ random_batch_size_range : Optional [Tuple [int , int ]] = None ,
871+ random_sequence_length_range : Optional [Tuple [int , int ]] = None ,
872+ ** kwargs ,
873+ ):
874+ super ().__init__ (
875+ task = task ,
876+ normalized_config = normalized_config ,
877+ batch_size = batch_size ,
878+ sequence_length = sequence_length ,
879+ random_batch_size_range = random_batch_size_range ,
880+ random_sequence_length_range = random_sequence_length_range ,
881+ ** kwargs ,
882+ )
883+ self .num_kv_heads = self .num_kv_heads = (
884+ normalized_config .num_kv_heads
885+ if (normalized_config .new_decoder_architecture or not normalized_config .multi_query )
886+ else 1
887+ )
888+ self .head_dim = self .hidden_size // self .num_attention_heads
867889
890+ def generate (self , input_name : str , framework : str = "pt" , int_dtype : str = "int64" , float_dtype : str = "fp32" ):
868891 past_key_shape = (
869892 self .batch_size ,
870893 self .num_kv_heads ,
871894 self .sequence_length ,
872- head_dim ,
895+ self . head_dim ,
873896 )
874897 past_value_shape = (
875898 self .batch_size ,
876899 self .num_kv_heads ,
877900 self .sequence_length ,
878- head_dim ,
901+ self . head_dim ,
879902 )
880903 return [
881904 (
0 commit comments