-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Minor improvements to token_type_ids extension for PA
#34661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
1abe8ec
34763dd
e7f8238
daef794
fdb3a73
829c430
a04d165
96188fb
4d9f607
1f6a6d1
2bf68b5
bcbb855
ed3374d
0fd5001
5fbbf5e
81ab320
4e0d5ac
5f5af24
362cb80
7841c70
9810dc3
890804b
6a62dda
dfc6e1f
a412f6c
810130d
acbd73e
2da2303
03e9935
e2347e1
7176099
af07897
c99b9c5
7e6ba0e
84042e8
6bc0315
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -200,14 +200,11 @@ static std::shared_ptr<ov::Node> handle_baichuan2_13b_alibi( | |
|
|
||
| static std::shared_ptr<ov::Node> handle_gemma3_token_type_ids( | ||
| const std::map<std::string, std::shared_ptr<v0::Parameter>>& optional_model_wide_params) { | ||
| if (optional_model_wide_params.find("token_type_ids") != optional_model_wide_params.end()) { | ||
| auto param = optional_model_wide_params.at("token_type_ids"); | ||
| if (param->get_element_type() != ov::element::i32) { | ||
| return std::make_shared<v0::Convert>(param, ov::element::i32); | ||
| } | ||
| return param; | ||
| auto param = optional_model_wide_params.at("token_type_ids"); | ||
| if (param->get_element_type() != ov::element::i32) { | ||
| return std::make_shared<v0::Convert>(param, ov::element::i32); | ||
| } | ||
| return v0::Constant::create(ov::element::i32, ov::Shape{0}, {}); | ||
| return param; | ||
| } | ||
|
|
||
| static std::tuple<std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>> phi3_sliding_window_pattern() { | ||
|
|
@@ -438,10 +435,11 @@ ov::pass::StateManagementPattern::StateManagementPattern( | |
|
|
||
| auto sdpa_variants = std::make_shared<Or>(OutputVector{sdpa_with_4_inputs, sdpa_with_5_inputs, sdpa_with_6_inputs}); | ||
|
|
||
| // Shared flag to track whether the model is Gemma3, set when any layer matches | ||
| // the gptoss_gemma3 sliding window pattern. Combined with the token_type_ids check, | ||
| // this uniquely identifies Gemma3 (gpt-oss shares the pattern but lacks token_type_ids). | ||
| auto is_gptoss_gemma3 = std::make_shared<bool>(false); | ||
| // Set to true once a sliding_attention layer matching the gptoss_gemma3 pattern is found | ||
| // alongside a token_type_ids model input - the combination that uniquely identifies Gemma3 | ||
| // since pattern for full attention mask in Gemma3 is different than sliding window | ||
| // it has to be persistent in the callback, so shared_ptr is used | ||
| auto has_token_type_ids = std::make_shared<bool>(false); | ||
|
Comment on lines
+438
to
+442
|
||
|
|
||
| ov::matcher_pass_callback callback = [=, | ||
| &kv_parameters, | ||
|
|
@@ -621,7 +619,9 @@ ov::pass::StateManagementPattern::StateManagementPattern( | |
| } | ||
| sliding_window = std::make_shared<v1::Subtract>(v0::Constant::create(element::i32, Shape{}, {2}), offset); | ||
| } else if (pattern_map.count(gptoss_gemma3_offset)) { | ||
| *is_gptoss_gemma3 = true; | ||
| // gptoss_gemma3 pattern + token_type_ids input uniquely identifies Gemma3; | ||
| // gpt-oss shares this sliding window pattern but has no token_type_ids. | ||
| *has_token_type_ids = optional_model_wide_params.count("token_type_ids"); | ||
| auto offset = pattern_map.at(gptoss_gemma3_offset).get_node_shared_ptr(); | ||
| if (pattern_map.at(gptoss_gemma3_offset).get_partial_shape().rank() != 0) { | ||
| offset = std::make_shared<v15::Squeeze>(offset); | ||
|
|
@@ -756,7 +756,7 @@ ov::pass::StateManagementPattern::StateManagementPattern( | |
| } | ||
| OPENVINO_ASSERT(pa_arguments.size() == 25); | ||
|
|
||
| if (*is_gptoss_gemma3) { | ||
| if (*has_token_type_ids) { | ||
| pa_arguments.insert(pa_arguments.begin() + 25, handle_gemma3_token_type_ids(optional_model_wide_params)); | ||
| } else { | ||
| pa_arguments.insert(pa_arguments.begin() + 25, v0::Constant::create(element::i32, Shape{0}, {})); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we define this variable inside the callback?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed — it looks strange that it is required to define it outside.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Gemma3 has a repeating sequence of attention layers: 5x sliding window attention, 1x full attention. The pattern we currently have detects sliding window, but
`token_type_ids` has to be passed to full attention layers as well. `has_token_type_ids` is defined outside of the callback as a `shared_ptr`, because it has to stay consistent between all lambda callbacks — since the lambda's capture is `=`, each callback gets a new shared pointer to the same object. Without it, the `token_type_ids` would be routed to PA only for sliding window PAs, and not for full attention PAs. Technically we could detect the full attention pattern to avoid the gpt-oss/Gemma3 mixup and do the same trick, but then the first 5x sliding window attentions would not receive the
`token_type_ids` input, because only the first full attention layer (6th in line) would set the variable to `true`. Summing up, it may not be as clean as I'd like it to be, but it works. If you insist that this piece of code will cause issues I can keep looking for a universal pattern which would:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, it's a little dirty solution, but if there's no other option, I believe we can live with it.
Any ideas of what this could be?