Minor improvements to token_type_ids extension for PA #34661

p-wysocki wants to merge 36 commits into openvinotoolkit:master

Conversation
Signed-off-by: p-wysocki <przemyslaw.wysocki@intel.com>
```cpp
// Shared flag to track whether the model is Gemma3, set when any layer matches
// the gptoss_gemma3 sliding window pattern. Combined with the token_type_ids check,
// this uniquely identifies Gemma3 (gpt-oss shares the pattern but lacks token_type_ids).
auto is_gptoss_gemma3 = std::make_shared<bool>(false);
```
Can we define this variable inside the callback?
Agreed, it does look strange that it has to be defined outside.
Gemma3 has a repeating sequence of attention layers: 5x sliding window attention, 1x full attention. The pattern we currently have detects sliding window attention, but token_type_ids has to be passed to the full attention layers as well.

has_token_type_ids is defined outside of the callback as a shared_ptr because it has to stay consistent across all lambda callbacks: since the lambdas capture by value (=), each one gets its own copy of the shared pointer, but all copies point at the same object. Without it, token_type_ids would be routed only to the sliding window PAs, and not to the full attention PAs.

Technically we could detect the full attention pattern to avoid the gpt-oss/Gemma3 mixup and do the same trick, but then the first 5x sliding window attention layers would not receive the token_type_ids input, because only the first full attention layer (6th in line) would set the variable to true.

Summing up, it may not be as clean as I'd like it to be, but it works. If you insist that this piece of code will cause issues, I can keep looking for a universal pattern which would:
- separate gpt-oss and gemma3
- work for both sliding window and full attention layers
Yeah, it's a bit of a dirty solution, but if there's no other option, I believe we can live with it.
> I can keep looking for a universal pattern

Any ideas of what this could be?
...mon/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp
```cpp
    sliding_window = std::make_shared<v1::Subtract>(v0::Constant::create(element::i32, Shape{}, {2}), offset);
} else if (pattern_map.count(gptoss_gemma3_offset)) {
    *is_gptoss_gemma3 = true;
    is_gemma3 = optional_model_wide_params.count("token_type_ids");
```
In fact, any model with token_type_ids and a matching sliding window pattern will set this is_gemma3 flag to true, so why not simply name this variable has_token_type_ids?
Or set has_sliding_window here instead and use it below.
Also, is_gemma3 will currently be false for the causal mask case (no sliding window) within the same model.
I renamed the variable, and regarding the no sliding window case, the explanation is provided in #34661 (comment).
```cpp
if (is_gemma3) {
    pa_arguments.insert(pa_arguments.begin() + 25, handle_gemma3_token_type_ids(optional_model_wide_params));
} else {
    pa_arguments.insert(pa_arguments.begin() + 25, v0::Constant::create(element::i32, Shape{0}, {}));
```
The variable naming is tied to Gemma3, but it can be generic for any model having both has_token_type_ids and has_sliding_window true.
It is currently applied for the sliding_window case only, but as a next step it could be extended to the causal case as well; then this if/else would reduce to a single call:

```cpp
pa_arguments.insert(pa_arguments.begin() + 25, handle_token_type_ids(optional_model_wide_params));
```
Suggested change:

```diff
-if (is_gemma3) {
-    pa_arguments.insert(pa_arguments.begin() + 25, handle_gemma3_token_type_ids(optional_model_wide_params));
-} else {
-    pa_arguments.insert(pa_arguments.begin() + 25, v0::Constant::create(element::i32, Shape{0}, {}));
+if (has_sliding_window) {
+    pa_arguments.insert(pa_arguments.begin() + 25, handle_token_type_ids(optional_model_wide_params));
+} else {
+    pa_arguments.insert(pa_arguments.begin() + 25, v0::Constant::create(element::i32, Shape{0}, {}));
```
I changed the variable name. token_type_ids currently also works for the causal case, see #34661 (comment).
Pull request overview
This PR refines Gemma3 token_type_ids handling for the SDPA→PagedAttention transformation and strengthens PagedAttentionExtension type-propagation coverage around the newly-supported token_type_ids ranks.
Changes:
- Add type-prop tests validating token_type_ids acceptance for rank-1/rank-2, dynamic shape, and invalid type/rank cases.
- Simplify token_type_ids retrieval/conversion in the Gemma3 path by assuming presence when the Gemma3 condition is met and avoiding an internal fallback.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| src/core/tests/type_prop/paged_attention.cpp | Adds dedicated type-prop tests for token_type_ids rank/type validation. |
| src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp | Adjusts Gemma3 detection flagging and streamlines token_type_ids handling (convert-to-i32 when needed). |
```cpp
// Set to true once a sliding_attention layer matching the gptoss_gemma3 pattern is found
// alongside a token_type_ids model input - the combination that uniquely identifies Gemma3,
// since the pattern for the full attention mask in Gemma3 differs from the sliding window one.
// It has to persist across callback invocations, so a shared_ptr is used.
auto has_token_type_ids = std::make_shared<bool>(false);
```
[LOW] has_token_type_ids is used as a persisted "Gemma3 detected / enable token_type_ids wiring" flag (it's only updated when the gptoss_gemma3 sliding-window pattern matches), so the name is misleading: there are cases where the model has a token_type_ids input but this flag stays false until that pattern is seen. Consider renaming it to something like is_gemma3 / enable_gemma3_token_type_ids to reflect the actual semantics and reduce the chance of future misuse.