Skip to content

Commit 2e8308e

Browse files
committed
Falcon BetterTransformer requires transformers>=4.34 (#1431)
* Falcon BetterTransformer requires transformers>=4.34
* More fixes
1 parent 42924f8 commit 2e8308e

File tree

7 files changed

+41
-11
lines changed

7 files changed

+41
-11
lines changed

optimum/bettertransformer/models/attention.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,7 @@ def falcon_forward(
913913
alibi: Optional[torch.Tensor],
914914
attention_mask: torch.Tensor,
915915
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
916+
position_ids: Optional[torch.LongTensor] = None,
916917
head_mask: Optional[torch.Tensor] = None,
917918
use_cache: bool = False,
918919
output_attentions: bool = False,
@@ -937,7 +938,7 @@ def falcon_forward(
937938
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
938939

939940
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
940-
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
941+
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)
941942

942943
if layer_past is not None:
943944
past_key, past_value = layer_past

optimum/bettertransformer/models/decoder_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
else:
4545
from ...utils.dummy_bettertransformer_objects import BarkSelfAttention
4646

47-
if check_if_transformers_greater("4.32"):
47+
if check_if_transformers_greater("4.34"):
4848
from transformers.models.falcon.modeling_falcon import FalconAttention
4949
else:
5050
from ...utils.dummy_bettertransformer_objects import FalconAttention

optimum/utils/dummy_bettertransformer_objects.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ def __init__(self, *args, **kwargs):
1616

1717

1818
class FalconAttention(metaclass=DummyObject):
19-
_backends = ["transformers_432"]
19+
_backends = ["transformers_434"]
2020

2121
def __init__(self, *args, **kwargs):
22-
requires_backends(self, ["transformers_432"])
22+
requires_backends(self, ["transformers_434"])
2323

2424

2525
def _llama_prepare_decoder_attention_mask(*args, **kwargs):

optimum/utils/import_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,10 @@ def require_numpy_strictly_lower(version: str, message: str):
201201
"transformers_432",
202202
(lambda: check_if_transformers_greater("4.32"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.32")),
203203
),
204+
(
205+
"transformers_434",
206+
(lambda: check_if_transformers_greater("4.34"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.34")),
207+
),
204208
]
205209
)
206210

optimum/utils/input_generators.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -861,21 +861,44 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
861861

862862

863863
class FalconDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
864-
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
865-
self.num_kv_heads = 1
866-
head_dim = self.hidden_size // self.num_attention_heads
864+
def __init__(
865+
self,
866+
task: str,
867+
normalized_config: NormalizedTextConfig,
868+
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
869+
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
870+
random_batch_size_range: Optional[Tuple[int, int]] = None,
871+
random_sequence_length_range: Optional[Tuple[int, int]] = None,
872+
**kwargs,
873+
):
874+
super().__init__(
875+
task=task,
876+
normalized_config=normalized_config,
877+
batch_size=batch_size,
878+
sequence_length=sequence_length,
879+
random_batch_size_range=random_batch_size_range,
880+
random_sequence_length_range=random_sequence_length_range,
881+
**kwargs,
882+
)
883+
self.num_kv_heads = self.num_kv_heads = (
884+
normalized_config.num_kv_heads
885+
if (normalized_config.new_decoder_architecture or not normalized_config.multi_query)
886+
else 1
887+
)
888+
self.head_dim = self.hidden_size // self.num_attention_heads
867889

890+
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
868891
past_key_shape = (
869892
self.batch_size,
870893
self.num_kv_heads,
871894
self.sequence_length,
872-
head_dim,
895+
self.head_dim,
873896
)
874897
past_value_shape = (
875898
self.batch_size,
876899
self.num_kv_heads,
877900
self.sequence_length,
878-
head_dim,
901+
self.head_dim,
879902
)
880903
return [
881904
(

optimum/utils/normalized_config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,9 @@ class NormalizedConfigManager:
211211
"blenderbot": BartLikeNormalizedTextConfig,
212212
"blenderbot_small": BartLikeNormalizedTextConfig,
213213
"bloom": NormalizedTextConfig.with_args(num_layers="n_layer"),
214-
"falcon": NormalizedTextConfig.with_args(num_layers="num_hidden_layers", num_attention_heads="num_kv_heads"),
214+
"falcon": NormalizedTextConfig.with_args(
215+
num_layers="num_hidden_layers", num_attention_heads="num_attention_heads"
216+
),
215217
"camembert": NormalizedTextConfig,
216218
"codegen": GPT2LikeNormalizedTextConfig,
217219
"cvt": NormalizedVisionConfig,

tests/bettertransformer/testing_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
"distilbert": "hf-internal-testing/tiny-random-DistilBertModel",
4444
"electra": "hf-internal-testing/tiny-random-ElectraModel",
4545
"ernie": "hf-internal-testing/tiny-random-ErnieModel",
46-
"falcon": "Rocketknight1/tiny-random-falcon-7b",
46+
"falcon": "fxmarty/really-tiny-falcon-testing",
4747
"fsmt": "hf-internal-testing/tiny-random-FSMTModel",
4848
"gpt2": "hf-internal-testing/tiny-random-GPT2Model",
4949
# NOTE: this tiny model does not use attention_softmax_in_fp32=True (contrary to e.g. starcoder)

0 commit comments

Comments (0)