Skip to content

Commit fcb8b1c

Browse files
authored
[NPUW] Fixed Gemma2 4K sliding window work with short prompts on NPU (#32891)
### Details: - *Relaxed Phi3SlidingMask2 pattern in order to allow `attention_mask` and `position_ids` to be located more flexibly in the model without breaking the sliding window mask calculation* ### Tickets: - *EISW-190610*
1 parent 46209d0 commit fcb8b1c

File tree

1 file changed

+53
-41
lines changed

1 file changed

+53
-41
lines changed

src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp

Lines changed: 53 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "openvino/openvino.hpp"
1717
#include "openvino/opsets/opset13.hpp"
1818
#include "openvino/pass/graph_rewrite.hpp"
19+
#include "openvino/pass/manager.hpp"
1920
#include "openvino/pass/matcher_pass.hpp"
2021
#include "openvino/pass/pattern/op/optional.hpp"
2122
#include "openvino/pass/pattern/op/or.hpp"
@@ -398,12 +399,12 @@ class GemmaSlidingMask : public ov::pass::MatcherPass {
398399
}
399400
};
400401

401-
class Phi3SlidingMask : public ov::pass::MatcherPass {
402+
class OldPhi3SlidingMaskMatcher : public ov::pass::MatcherPass {
402403
public:
403-
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::Phi3SlidingMask");
404+
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::OldPhi3SlidingMaskMatcher");
404405

405-
Phi3SlidingMask() {
406-
// Search for the Phi3 sliding mask pattern to extend it to work with right-padded
406+
OldPhi3SlidingMaskMatcher() {
407+
// Search for the Phi3 old sliding mask pattern to extend it to work with right-padded
407408
// past tokens and left-padded present tokens.
408409
//
409410
// Mask creation is simply done via "less_equal" and "greater" operations between
@@ -587,16 +588,17 @@ class Phi3SlidingMask : public ov::pass::MatcherPass {
587588

588589
return true;
589590
};
590-
register_matcher(std::make_shared<opp::Matcher>(inv_sliding_attention_mask, "Phi3SlidingMask"),
591+
register_matcher(std::make_shared<opp::Matcher>(inv_sliding_attention_mask, "OldPhi3SlidingMaskMatcher"),
591592
std::move(callback));
592593
}
593594
};
594595

595-
class Phi3SlidingMask2 : public ov::pass::MatcherPass {
596+
class Phi3SlidingMaskMatcher : public ov::pass::MatcherPass {
596597
public:
597-
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::Phi3SlidingMask2");
598+
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::Phi3SlidingMaskMatcher");
598599

599-
Phi3SlidingMask2() {
600+
Phi3SlidingMaskMatcher(const std::shared_ptr<ov::Node>& attention_mask_node_ptr,
601+
const std::shared_ptr<ov::Node>& position_ids_node_ptr) {
600602
// Search for the Phi3 sliding mask pattern to extend it to work with right-padded
601603
// past tokens and left-padded present tokens. Logic to replace pattern is the same
602604
// as in Phi3SlidingMask rewriter, but adjusted to another set of operations for
@@ -622,17 +624,12 @@ class Phi3SlidingMask2 : public ov::pass::MatcherPass {
622624
};
623625

624626
auto past_kv_len = opp::wrap_type<ov::op::v8::Gather>({opp::any_input(), opp::any_input(), opp::any_input()});
625-
auto pos_ids_param = opp::wrap_type<ov::op::v0::Parameter>();
626-
auto pos_ids_shape_of = opp::wrap_type<ov::op::v3::ShapeOf>({pos_ids_param});
627-
auto pos_ids_len = opp::wrap_type<ov::op::v8::Gather>({pos_ids_shape_of, opp::any_input(), opp::any_input()});
628-
auto full_ctx_len = opp::wrap_type<ov::op::v1::Add>({past_kv_len, pos_ids_len});
627+
auto full_ctx_len = opp::wrap_type<ov::op::v1::Add>({past_kv_len, opp::any_input()});
629628
auto query_range = opp::wrap_type<ov::op::v4::Range>({past_kv_len, full_ctx_len, opp::any_input()});
630629
auto query_range_column = unsqueeze_sequence(query_range);
631630

632631
auto zero_const = opp::wrap_type<ov::op::v0::Constant>();
633-
auto pos_ids_len_reshaped = opp::wrap_type<ov::op::v1::Reshape>({pos_ids_len, opp::any_input()});
634-
auto pos_ids_len_squeezed = opp::wrap_type<ov::op::v0::Squeeze>({pos_ids_len_reshaped, opp::any_input()});
635-
auto full_ctx_len_2 = opp::wrap_type<ov::op::v1::Add>({pos_ids_len_squeezed, past_kv_len});
632+
auto full_ctx_len_2 = opp::wrap_type<ov::op::v1::Add>({opp::any_input(), past_kv_len});
636633
auto key_range = opp::wrap_type<ov::op::v4::Range>({zero_const, full_ctx_len_2, opp::any_input()});
637634
auto key_range_row = unsqueeze_sequence(key_range);
638635
auto opt_key_range_row_f32 = opp::optional<ov::op::v0::Convert>({key_range_row->output(0)});
@@ -646,36 +643,19 @@ class Phi3SlidingMask2 : public ov::pass::MatcherPass {
646643
auto causal_mask = opp::wrap_type<ov::op::v1::LessEqual>({opt_key_range_row_f32, query_range_column});
647644

648645
auto sliding_and_causal_mask = opp::wrap_type<ov::op::v13::BitwiseAnd>({sliding_and_true, causal_mask});
649-
auto sliding_causal_and_true =
650-
opp::wrap_type<ov::op::v13::BitwiseAnd>({opp::any_input(), sliding_and_causal_mask});
651-
652-
auto atten_mask_param = opp::wrap_type<ov::op::v0::Parameter>();
653-
auto atten_mask_boolean = opp::wrap_type<ov::op::v0::Convert>({atten_mask_param});
654-
auto atten_mask_reshaped = opp::wrap_type<ov::op::v1::Reshape>({atten_mask_boolean, opp::any_input()});
655-
auto atten_mask_gathered =
656-
opp::wrap_type<ov::op::v8::Gather>({atten_mask_reshaped, opp::any_input(), opp::any_input()});
657-
auto atten_mask_reshaped_2 = opp::wrap_type<ov::op::v1::Reshape>({atten_mask_gathered, opp::any_input()});
658-
auto atten_mask_reshaped_3 = opp::wrap_type<ov::op::v1::Reshape>({atten_mask_reshaped_2, opp::any_input()});
659-
660-
auto final_sliding_attention =
661-
opp::wrap_type<ov::op::v13::BitwiseAnd>({sliding_causal_and_true, atten_mask_reshaped_3});
662646

663647
auto callback = [=](ov::pass::pattern::Matcher& m) {
664648
LOG_INFO("Found (4.53) pattern for Phi-3 Sliding Window Attention, will be replaced with custom for static "
665649
"shapes.");
666650
auto& node_to_output = m.get_pattern_value_map();
667651
auto node_past_kv_len = node_to_output.at(past_kv_len).get_node_shared_ptr();
668-
auto node_pos_ids_param = node_to_output.at(pos_ids_param).get_node_shared_ptr();
669652
auto node_full_ctx_len = node_to_output.at(full_ctx_len).get_node_shared_ptr();
670-
auto node_atten_mask_boolean = node_to_output.at(atten_mask_boolean).get_node_shared_ptr();
671653
auto node_neg_window_size = node_to_output.at(neg_window_size).get_node_shared_ptr();
672654
auto node_sliding_mask = node_to_output.at(sliding_mask).get_node_shared_ptr();
673655
auto node_sliding_and_causal_mask = node_to_output.at(sliding_and_causal_mask).get_node_shared_ptr();
674656

675657
auto matched_past_kv_len = std::static_pointer_cast<ov::op::v8::Gather>(node_past_kv_len);
676-
auto matched_pos_ids_input = std::static_pointer_cast<ov::op::v0::Parameter>(node_pos_ids_param);
677658
auto matched_full_ctx_len = std::static_pointer_cast<ov::op::v1::Add>(node_full_ctx_len);
678-
auto matched_atten_mask_boolean = std::static_pointer_cast<ov::op::v0::Parameter>(node_atten_mask_boolean);
679659
std::shared_ptr<ov::Node> matched_key_range_row = nullptr;
680660
if (node_to_output.count(opt_key_range_row_f32)) {
681661
auto node_key_range_row_f32 = node_to_output[opt_key_range_row_f32].get_node_shared_ptr();
@@ -692,14 +672,19 @@ class Phi3SlidingMask2 : public ov::pass::MatcherPass {
692672
"Sliding window size constant must be of size 1, but got " +
693673
std::to_string(matched_neg_window_size->get_output_size()));
694674

675+
std::shared_ptr<ov::Node> passed_attention_mask = attention_mask_node_ptr;
676+
std::shared_ptr<ov::Node> passed_position_ids = position_ids_node_ptr;
677+
OPENVINO_ASSERT(passed_attention_mask, "Passed attention_mask node is nullptr!");
678+
OPENVINO_ASSERT(passed_position_ids, "Passed position_ids node is nullptr!");
679+
695680
auto const_zero = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, 0);
696681
auto const_one = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, 1);
697682
auto const_three = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, 3);
698683

699684
// 1.(K range > (Q_pos range - sliding window).T) & (K range <= Q range.T)
700-
std::shared_ptr<ov::Node> query_range_as_pos_ids = matched_pos_ids_input;
685+
std::shared_ptr<ov::Node> query_range_as_pos_ids = passed_position_ids;
701686
if (matched_neg_window_size->output(0).get_element_type() == ov::element::f32) {
702-
query_range_as_pos_ids = std::make_shared<ov::op::v0::Convert>(matched_pos_ids_input, ov::element::f32);
687+
query_range_as_pos_ids = std::make_shared<ov::op::v0::Convert>(passed_position_ids, ov::element::f32);
703688
}
704689
auto query_range_as_pos_ids_unsqueezed =
705690
std::make_shared<ov::op::v0::Unsqueeze>(query_range_as_pos_ids, const_zero);
@@ -737,7 +722,9 @@ class Phi3SlidingMask2 : public ov::pass::MatcherPass {
737722
auto full_ctx_len_reshaped =
738723
std::make_shared<ov::op::v1::Reshape>(matched_full_ctx_len, shape_rank_one_const, false);
739724
auto const_one_rank_one = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, 1);
740-
auto present_atten_mask_bool = std::make_shared<ov::op::v8::Slice>(matched_atten_mask_boolean,
725+
auto attention_mask_bool =
726+
std::make_shared<ov::op::v0::Convert>(passed_attention_mask, ov::element::boolean);
727+
auto present_atten_mask_bool = std::make_shared<ov::op::v8::Slice>(attention_mask_bool,
741728
past_len_reshaped,
742729
full_ctx_len_reshaped,
743730
const_one_rank_one,
@@ -755,11 +742,38 @@ class Phi3SlidingMask2 : public ov::pass::MatcherPass {
755742

756743
return true;
757744
};
758-
register_matcher(std::make_shared<opp::Matcher>(final_sliding_attention, "Phi3SlidingMask2"),
745+
register_matcher(std::make_shared<opp::Matcher>(sliding_and_causal_mask, "Phi3SlidingMaskMatcher"),
759746
std::move(callback));
760747
}
761748
};
762749

750+
class Phi3SlidingMask : public ov::pass::ModelPass {
751+
public:
752+
OPENVINO_MODEL_PASS_RTTI("ov::npuw::LLMCompiledModel::Phi3SlidingMask");
753+
Phi3SlidingMask() = default;
754+
bool run_on_model(const std::shared_ptr<ov::Model>& model) override {
755+
std::shared_ptr<ov::Node> attention_mask_node_ptr = nullptr;
756+
std::shared_ptr<ov::Node> position_ids_node_ptr = nullptr;
757+
for (const auto& i : model->inputs()) {
758+
if (i.get_any_name() == "attention_mask") {
759+
attention_mask_node_ptr = i.get_node_shared_ptr();
760+
}
761+
if (i.get_any_name() == "position_ids") {
762+
position_ids_node_ptr = i.get_node_shared_ptr();
763+
}
764+
}
765+
OPENVINO_ASSERT(attention_mask_node_ptr, "attention_mask input is not found!");
766+
OPENVINO_ASSERT(position_ids_node_ptr, "position_ids input is not found!");
767+
768+
ov::pass::Manager manager;
769+
manager.set_per_pass_validation(true);
770+
const auto rewriter = manager.register_pass<ov::pass::GraphRewrite>();
771+
rewriter->add_matcher<Phi3SlidingMaskMatcher>(attention_mask_node_ptr, position_ids_node_ptr);
772+
rewriter->add_matcher<OldPhi3SlidingMaskMatcher>();
773+
return manager.run_passes(model);
774+
}
775+
};
776+
763777
namespace {
764778
uint32_t align_to(uint32_t value, uint32_t alignment) {
765779
return (value + alignment - 1) & ~(alignment - 1);
@@ -837,11 +851,9 @@ void patch_phi3_sliding_mask(const std::shared_ptr<ov::Model>& model) {
837851
// Qwen2.5 VL/Omni uses 3D position_ids, which can't be directly used
838852
// in creation of sliding window mask.
839853
if (!ov::npuw::util::has_input(model, "token_type_ids") && !ov::npuw::util::has_input(model, "inputs_embeds")) {
840-
ov::pass::GraphRewrite rewr;
841-
rewr.add_matcher<Phi3SlidingMask2>();
842-
rewr.add_matcher<Phi3SlidingMask>();
843-
rewr.run_on_model(model);
844-
model->validate_nodes_and_infer_types();
854+
ov::pass::Manager manager;
855+
manager.register_pass<Phi3SlidingMask>();
856+
manager.run_passes(model);
845857
}
846858
}
847859

0 commit comments

Comments
 (0)