Changes from all commits (36 commits)
1abe8ec - WIP (p-wysocki, Feb 9, 2026)
34763dd - WIP (p-wysocki, Feb 9, 2026)
e7f8238 - WIP (p-wysocki, Feb 9, 2026)
daef794 - WIP (p-wysocki, Feb 12, 2026)
fdb3a73 - Add tests (p-wysocki, Feb 12, 2026)
829c430 - initial clenaup (p-wysocki, Feb 12, 2026)
a04d165 - Set input as optional (p-wysocki, Feb 12, 2026)
96188fb - Correct tests (p-wysocki, Feb 12, 2026)
4d9f607 - Remove reshape from graph (p-wysocki, Feb 12, 2026)
1f6a6d1 - Remove debug prints (p-wysocki, Feb 12, 2026)
2bf68b5 - Clenaup (p-wysocki, Feb 13, 2026)
bcbb855 - Merge branch 'master' into attn_idea_2 (p-wysocki, Feb 13, 2026)
ed3374d - Sliding window working (p-wysocki, Feb 18, 2026)
0fd5001 - Move sw to gptoss logic (p-wysocki, Feb 18, 2026)
5fbbf5e - Working, with debug prints (p-wysocki, Feb 18, 2026)
81ab320 - Cleanup (p-wysocki, Feb 18, 2026)
4e0d5ac - Merge branch 'attn_idea_2' of https://github.com/p-wysocki/openvino i… (p-wysocki, Feb 18, 2026)
5f5af24 - Cleanup (p-wysocki, Feb 18, 2026)
362cb80 - update copyright (p-wysocki, Feb 18, 2026)
7841c70 - Fix transformation tests, add new one (p-wysocki, Feb 18, 2026)
9810dc3 - Fix convert input tests (p-wysocki, Feb 18, 2026)
890804b - Fix clang (p-wysocki, Feb 18, 2026)
6a62dda - Fix smoke tests (p-wysocki, Feb 18, 2026)
dfc6e1f - Fix smoke test (p-wysocki, Feb 18, 2026)
a412f6c - Update GPU input count (p-wysocki, Feb 25, 2026)
810130d - CR (p-wysocki, Mar 2, 2026)
acbd73e - Add token_type_ids to gemma only (p-wysocki, Mar 2, 2026)
2da2303 - Fix gpu test (p-wysocki, Mar 2, 2026)
03e9935 - Merge branch 'master' into attn_idea_2 (p-wysocki, Mar 10, 2026)
e2347e1 - Merge branch 'master' of https://github.com/openvinotoolkit/openvino … (p-wysocki, Mar 11, 2026)
7176099 - Fix conflict issues (p-wysocki, Mar 11, 2026)
af07897 - Apply CR (p-wysocki, Mar 12, 2026)
c99b9c5 - Merge branch 'master' of https://github.com/openvinotoolkit/openvino … (p-wysocki, Mar 12, 2026)
7e6ba0e - Merge branch 'master' of https://github.com/openvinotoolkit/openvino … (p-wysocki, Mar 19, 2026)
84042e8 - working (p-wysocki, Mar 19, 2026)
6bc0315 - Apply CR (p-wysocki, Mar 19, 2026)
@@ -200,14 +200,11 @@ static std::shared_ptr<ov::Node> handle_baichuan2_13b_alibi(

static std::shared_ptr<ov::Node> handle_gemma3_token_type_ids(
const std::map<std::string, std::shared_ptr<v0::Parameter>>& optional_model_wide_params) {
- if (optional_model_wide_params.find("token_type_ids") != optional_model_wide_params.end()) {
-     auto param = optional_model_wide_params.at("token_type_ids");
-     if (param->get_element_type() != ov::element::i32) {
-         return std::make_shared<v0::Convert>(param, ov::element::i32);
-     }
-     return param;
- }
- return v0::Constant::create(ov::element::i32, ov::Shape{0}, {});
+ auto param = optional_model_wide_params.at("token_type_ids");
+ if (param->get_element_type() != ov::element::i32) {
+     return std::make_shared<v0::Convert>(param, ov::element::i32);
+ }
+ return param;
}

static std::tuple<std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>> phi3_sliding_window_pattern() {
@@ -438,10 +435,11 @@ ov::pass::StateManagementPattern::StateManagementPattern(

auto sdpa_variants = std::make_shared<Or>(OutputVector{sdpa_with_4_inputs, sdpa_with_5_inputs, sdpa_with_6_inputs});

- // Shared flag to track whether the model is Gemma3, set when any layer matches
- // the gptoss_gemma3 sliding window pattern. Combined with the token_type_ids check,
- // this uniquely identifies Gemma3 (gpt-oss shares the pattern but lacks token_type_ids).
- auto is_gptoss_gemma3 = std::make_shared<bool>(false);
Contributor: Can we define this variable inside the callback?
Contributor: Agreed, it does look strange that it is required to define it outside.
Contributor (author): Gemma3 has a repeating sequence of attention layers: 5x sliding window attention, 1x full attention. The pattern we currently have detects sliding window attention, but token_type_ids has to be passed to the full attention layers as well.

has_token_type_ids is defined outside of the callback as a shared_ptr because it has to stay consistent between all lambda callbacks: since the lambda's capture is =, each callback gets its own copy of the shared pointer to the same object. Without it, token_type_ids would be routed only to the sliding window PAs, and not to the full attention PAs.

Technically we could detect the full attention pattern to avoid the gpt-oss/gemma3 mixup and do the same trick, but then the first 5x sliding window attention layers would not receive the token_type_ids input, because only the first full attention layer (6th in line) would set the variable to true.

Summing up, it may not be as clean as I'd like it to be, but it works. If you insist that this piece of code will cause issues, I can keep looking for a universal pattern which would:

  1. separate gpt-oss and gemma3
  2. work for both sliding window and full attention layers

Contributor: Yeah, it's a slightly dirty solution, but if there's no other option, I believe we can live with it.

"I can keep looking for a universal pattern"

Any ideas of what this could be?

+ // Set to true once a sliding_attention layer matching the gptoss_gemma3 pattern is found
+ // alongside a token_type_ids model input - the combination that uniquely identifies Gemma3
+ // since pattern for full attention mask in Gemma3 is different than sliding window
+ // it has to be persistent in the callback, so shared_ptr is used
+ auto has_token_type_ids = std::make_shared<bool>(false);
Comment on lines +438 to +442

Copilot AI (Mar 31, 2026): [LOW] has_token_type_ids is used as a persisted "Gemma3 detected / enable token_type_ids wiring" flag (it is only updated when the gptoss_gemma3 sliding-window pattern matches), so the name is misleading: there are cases where the model may have a token_type_ids input but this flag stays false until that pattern is seen. Consider renaming it to something like is_gemma3 / enable_gemma3_token_type_ids to reflect the actual semantics and reduce the chance of future misuse.

ov::matcher_pass_callback callback = [=,
&kv_parameters,
Expand Down Expand Up @@ -621,7 +619,9 @@ ov::pass::StateManagementPattern::StateManagementPattern(
}
sliding_window = std::make_shared<v1::Subtract>(v0::Constant::create(element::i32, Shape{}, {2}), offset);
} else if (pattern_map.count(gptoss_gemma3_offset)) {
- *is_gptoss_gemma3 = true;
+ // gptoss_gemma3 pattern + token_type_ids input uniquely identifies Gemma3;
+ // gpt-oss shares this sliding window pattern but has no token_type_ids.
+ *has_token_type_ids = optional_model_wide_params.count("token_type_ids");
auto offset = pattern_map.at(gptoss_gemma3_offset).get_node_shared_ptr();
if (pattern_map.at(gptoss_gemma3_offset).get_partial_shape().rank() != 0) {
offset = std::make_shared<v15::Squeeze>(offset);
@@ -756,7 +756,7 @@ ov::pass::StateManagementPattern::StateManagementPattern(
}
OPENVINO_ASSERT(pa_arguments.size() == 25);

- if (*is_gptoss_gemma3) {
+ if (*has_token_type_ids) {
pa_arguments.insert(pa_arguments.begin() + 25, handle_gemma3_token_type_ids(optional_model_wide_params));
} else {
pa_arguments.insert(pa_arguments.begin() + 25, v0::Constant::create(element::i32, Shape{0}, {}));
93 changes: 93 additions & 0 deletions src/core/tests/type_prop/paged_attention.cpp
@@ -270,5 +270,98 @@ TEST(type_prop, paged_attention_invalid_rank_key_cache) {
EXPECT_THROW(std::ignore = std::make_shared<op::PagedAttentionExtension>(args), ov::NodeValidationFailure);
}

static ov::OutputVector make_args_with_token_type(const std::shared_ptr<ov::op::v0::Parameter>& token_type_ids) {
using namespace ov::op;
const auto query = std::make_shared<v0::Parameter>(ov::element::f32, ov::PartialShape{3, 4});
const auto key = std::make_shared<v0::Parameter>(ov::element::f32, ov::PartialShape{3, 4});
const auto value = std::make_shared<v0::Parameter>(ov::element::f32, ov::PartialShape{3, 4});
const auto key_cache = std::make_shared<v0::Parameter>(ov::element::f32, ov::PartialShape{6, 2, 5, 4});
const auto value_cache = std::make_shared<v0::Parameter>(ov::element::f32, ov::PartialShape{6, 2, 5, 4});
const auto past_lens = std::make_shared<v0::Parameter>(ov::element::i32, ov::PartialShape{5});
const auto subsequence_begins = std::make_shared<v0::Parameter>(ov::element::i32, ov::PartialShape{5});
const auto block_indices = std::make_shared<v0::Parameter>(ov::element::i32, ov::PartialShape{15});
const auto block_indices_begins = std::make_shared<v0::Parameter>(ov::element::i32, ov::PartialShape{8});
const auto scale = std::make_shared<v0::Parameter>(ov::element::f32, ov::PartialShape{});
const auto sliding_window = std::make_shared<v0::Parameter>(ov::element::i32, ov::PartialShape{});
const auto alibi_slopes = std::make_shared<v0::Parameter>(ov::element::f32, ov::PartialShape{9});
const auto max_context_len = std::make_shared<v0::Parameter>(ov::element::i32, ov::PartialShape{});
const auto score_aggregation_window = std::make_shared<v0::Parameter>(ov::element::i32, ov::PartialShape{5});
const auto rotated_block_indices = std::make_shared<v0::Parameter>(ov::element::i32, ov::PartialShape{3});
const auto rotation_deltas = std::make_shared<v0::Parameter>(ov::element::i32, ov::PartialShape{12, 1});
const auto rotation_trig_lut = std::make_shared<v0::Parameter>(ov::element::f32, ov::PartialShape{256, 4});
const auto xattention_threshold = std::make_shared<v0::Parameter>(ov::element::f32, ov::PartialShape{5});
const auto xattention_block_size = std::make_shared<v0::Parameter>(ov::element::i32, ov::PartialShape{});
const auto xattention_stride = std::make_shared<v0::Parameter>(ov::element::i32, ov::PartialShape{});
const auto sinks = std::make_shared<v0::Parameter>(ov::element::f32, ov::PartialShape{1, 2, 1, 1});
const auto adaptive_rkv_start_size = std::make_shared<v0::Parameter>(ov::element::i32, ov::PartialShape{});
const auto adaptive_rkv_evictable_sizes = std::make_shared<v0::Parameter>(ov::element::i32, ov::PartialShape{5});
const auto adaptive_rkv_diversity_block_set_indices =
std::make_shared<v0::Parameter>(ov::element::i32, ov::PartialShape{10});
const auto adaptive_rkv_diversity_block_set_indices_begins =
std::make_shared<v0::Parameter>(ov::element::i32, ov::PartialShape{5});

return {query,
key,
value,
key_cache,
value_cache,
past_lens,
subsequence_begins,
block_indices,
block_indices_begins,
scale,
sliding_window,
alibi_slopes,
max_context_len,
score_aggregation_window,
rotated_block_indices,
rotation_deltas,
rotation_trig_lut,
xattention_threshold,
xattention_block_size,
xattention_stride,
sinks,
adaptive_rkv_start_size,
adaptive_rkv_evictable_sizes,
adaptive_rkv_diversity_block_set_indices,
adaptive_rkv_diversity_block_set_indices_begins,
token_type_ids};
}

TEST(type_prop, paged_attention_token_type_ids_1d) {
const auto token_type_ids = std::make_shared<op::v0::Parameter>(ov::element::i32, ov::PartialShape{3});
const auto args = make_args_with_token_type(token_type_ids);
const auto op = std::make_shared<op::PagedAttentionExtension>(args);
EXPECT_EQ(op->get_output_element_type(0), ov::element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (ov::PartialShape{3, 4}));
}

TEST(type_prop, paged_attention_token_type_ids_2d) {
const auto token_type_ids = std::make_shared<op::v0::Parameter>(ov::element::i32, ov::PartialShape{1, 3});
const auto args = make_args_with_token_type(token_type_ids);
const auto op = std::make_shared<op::PagedAttentionExtension>(args);
EXPECT_EQ(op->get_output_element_type(0), ov::element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (ov::PartialShape{3, 4}));
}

TEST(type_prop, paged_attention_token_type_ids_dynamic_shape) {
const auto token_type_ids =
std::make_shared<op::v0::Parameter>(ov::element::i32, ov::PartialShape{ov::Dimension::dynamic()});
const auto args = make_args_with_token_type(token_type_ids);
EXPECT_NO_THROW(std::ignore = std::make_shared<op::PagedAttentionExtension>(args));
}

TEST(type_prop, paged_attention_invalid_type_token_type_ids) {
const auto token_type_ids = std::make_shared<op::v0::Parameter>(ov::element::f32, ov::PartialShape{3});
const auto args = make_args_with_token_type(token_type_ids);
EXPECT_THROW(std::ignore = std::make_shared<op::PagedAttentionExtension>(args), ov::NodeValidationFailure);
}

TEST(type_prop, paged_attention_invalid_rank_token_type_ids) {
const auto token_type_ids = std::make_shared<op::v0::Parameter>(ov::element::i32, ov::PartialShape{1, 1, 3});
const auto args = make_args_with_token_type(token_type_ids);
EXPECT_THROW(std::ignore = std::make_shared<op::PagedAttentionExtension>(args), ov::NodeValidationFailure);
}

} // namespace testing
} // namespace ov