
Minor improvements to token_type_ids extension for PA #34661

Open
p-wysocki wants to merge 36 commits into openvinotoolkit:master from p-wysocki:attn_fixes

Conversation

@p-wysocki
Contributor

Details:

Tickets:

  • N/A

Signed-off-by: p-wysocki <przemyslaw.wysocki@intel.com>
@p-wysocki p-wysocki requested review from a team as code owners March 12, 2026 11:34
@github-actions bot added labels "category: Core" (OpenVINO Core, aka ngraph) and "category: transformations" (OpenVINO Runtime library - Transformations) on Mar 12, 2026
// Shared flag to track whether the model is Gemma3, set when any layer matches
// the gptoss_gemma3 sliding window pattern. Combined with the token_type_ids check,
// this uniquely identifies Gemma3 (gpt-oss shares the pattern but lacks token_type_ids).
auto is_gptoss_gemma3 = std::make_shared<bool>(false);
Contributor


Can we define this variable inside the callback?

Contributor


Agree, it looks strange that it's required to be defined outside.

Contributor Author


Gemma3 has a repeating sequence of attention layers: 5x sliding window attention, 1x full attention. The pattern we currently have detects sliding window, but token_type_ids has to be passed to the full attention layers as well.

has_token_type_ids is defined outside of the callback as a shared_ptr because it has to stay consistent across all lambda callbacks - since the lambdas capture by value (=), each one gets its own copy of the shared pointer to the same object. Without it, token_type_ids would be routed to PA only for sliding window PAs, and not for full attention PAs.

Technically we could detect the full attention pattern to avoid the gpt-oss/gemma3 mixup and do the same trick, but then the first 5x sliding window attentions would not receive the token_type_ids input, because only the first full attention layer (6th in line) would set the variable to true.

Summing up, it may not be as clean as I'd like it to be, but it works. If you insist that this piece of code will cause issues, I can keep looking for a universal pattern which would:

  1. separate gpt-oss and gemma3
  2. work for both sliding window and full attention layers

Contributor


Yeah, it's a slightly dirty solution, but if there's no other option, I believe we can live with it.

I can keep looking for a universal pattern

Any ideas of what this could be?

sliding_window = std::make_shared<v1::Subtract>(v0::Constant::create(element::i32, Shape{}, {2}), offset);
} else if (pattern_map.count(gptoss_gemma3_offset)) {
*is_gptoss_gemma3 = true;
is_gemma3 = optional_model_wide_params.count("token_type_ids");
Contributor


In fact, any model with token_type_ids and a matching sliding window pattern will set this is_gemma3 flag to true, so why not simply name this variable has_token_type_ids?
Or set has_sliding_window here instead, and use it below.
Also, is_gemma3 will currently be false for the causal mask case (no sliding window) within the same model.

Contributor Author


I renamed the variable, and regarding the no sliding window case, the explanation is provided in #34661 (comment).

Comment on lines 760 to 763
if (is_gemma3) {
pa_arguments.insert(pa_arguments.begin() + 25, handle_gemma3_token_type_ids(optional_model_wide_params));
} else {
pa_arguments.insert(pa_arguments.begin() + 25, v0::Constant::create(element::i32, Shape{0}, {}));
Contributor


The variable naming is tied to gemma3, but it can be generic for any model having both has_token_type_ids and has_sliding_window true.
It is currently applied to the sliding_window case only, but as a next step it could be extended to the causal case as well; then this if/else would be reduced to a single case:

pa_arguments.insert(pa_arguments.begin() + 25, handle_token_type_ids(optional_model_wide_params));

Suggested change
if (is_gemma3) {
pa_arguments.insert(pa_arguments.begin() + 25, handle_gemma3_token_type_ids(optional_model_wide_params));
} else {
pa_arguments.insert(pa_arguments.begin() + 25, v0::Constant::create(element::i32, Shape{0}, {}));
if (has_sliding_window) {
pa_arguments.insert(pa_arguments.begin() + 25, handle_token_type_ids(optional_model_wide_params));
} else {
pa_arguments.insert(pa_arguments.begin() + 25, v0::Constant::create(element::i32, Shape{0}, {}));

Contributor Author


I changed the variable name. token_type_ids also currently works for the causal case, see #34661 (comment).

Contributor

Copilot AI left a comment


Pull request overview

This PR refines Gemma3 token_type_ids handling for the SDPA→PagedAttention transformation and strengthens PagedAttentionExtension type-propagation coverage around the newly-supported token_type_ids ranks.

Changes:

  • Add type-prop tests validating token_type_ids acceptance for rank-1/rank-2, dynamic shape, and invalid type/rank cases.
  • Simplify token_type_ids retrieval/conversion in the Gemma3 path by assuming presence when the Gemma3 condition is met and avoiding an internal fallback.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

  • src/core/tests/type_prop/paged_attention.cpp — Adds dedicated type-prop tests for token_type_ids rank/type validation.
  • src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp — Adjusts Gemma3 detection flagging and streamlines token_type_ids handling (convert-to-i32 when needed).

Comment on lines +438 to +442
// Set to true once a sliding_attention layer matching the gptoss_gemma3 pattern is found
// alongside a token_type_ids model input - the combination that uniquely identifies Gemma3
// since pattern for full attention mask in Gemma3 is different than sliding window
// it has to be persistent in the callback, so shared_ptr is used
auto has_token_type_ids = std::make_shared<bool>(false);

Copilot AI Mar 31, 2026


[LOW] has_token_type_ids is used as a persisted "Gemma3 detected / enable token_type_ids wiring" flag (it's only updated when the gptoss_gemma3 sliding-window pattern matches), so the name is misleading: there are cases where the model may have a token_type_ids input but this flag stays false until that pattern is seen. Consider renaming it to something like is_gemma3 / enable_gemma3_token_type_ids to reflect the actual semantics and reduce the chance of future misuse.

Copilot generated this review using guidance from repository custom instructions.


5 participants