1616#include " openvino/openvino.hpp"
1717#include " openvino/opsets/opset13.hpp"
1818#include " openvino/pass/graph_rewrite.hpp"
19+ #include " openvino/pass/manager.hpp"
1920#include " openvino/pass/matcher_pass.hpp"
2021#include " openvino/pass/pattern/op/optional.hpp"
2122#include " openvino/pass/pattern/op/or.hpp"
@@ -398,12 +399,12 @@ class GemmaSlidingMask : public ov::pass::MatcherPass {
398399 }
399400};
400401
401- class Phi3SlidingMask : public ov ::pass::MatcherPass {
402+ class OldPhi3SlidingMaskMatcher : public ov ::pass::MatcherPass {
402403public:
403- OPENVINO_MATCHER_PASS_RTTI (" npuw::LLMCompiledModel::Phi3SlidingMask " );
404+ OPENVINO_MATCHER_PASS_RTTI (" npuw::LLMCompiledModel::OldPhi3SlidingMaskMatcher " );
404405
405- Phi3SlidingMask () {
406- // Search for the Phi3 sliding mask pattern to extend it to work with right-padded
406+ OldPhi3SlidingMaskMatcher () {
407+ // Search for the Phi3 old sliding mask pattern to extend it to work with right-padded
407408 // past tokens and left-padded present tokens.
408409 //
409410 // Mask creation is simply done via "less_equal" and "greater" operations between
@@ -587,16 +588,17 @@ class Phi3SlidingMask : public ov::pass::MatcherPass {
587588
588589 return true ;
589590 };
590- register_matcher (std::make_shared<opp::Matcher>(inv_sliding_attention_mask, " Phi3SlidingMask " ),
591+ register_matcher (std::make_shared<opp::Matcher>(inv_sliding_attention_mask, " OldPhi3SlidingMaskMatcher " ),
591592 std::move (callback));
592593 }
593594};
594595
595- class Phi3SlidingMask2 : public ov ::pass::MatcherPass {
596+ class Phi3SlidingMaskMatcher : public ov ::pass::MatcherPass {
596597public:
597- OPENVINO_MATCHER_PASS_RTTI (" npuw::LLMCompiledModel::Phi3SlidingMask2 " );
598+ OPENVINO_MATCHER_PASS_RTTI (" npuw::LLMCompiledModel::Phi3SlidingMaskMatcher " );
598599
599- Phi3SlidingMask2 () {
600+ Phi3SlidingMaskMatcher (const std::shared_ptr<ov::Node>& attention_mask_node_ptr,
601+ const std::shared_ptr<ov::Node>& position_ids_node_ptr) {
600602 // Search for the Phi3 sliding mask pattern to extend it to work with right-padded
601603 // past tokens and left-padded present tokens. Logic to replace pattern is the same
602604 // as in Phi3SlidingMask rewriter, but adjusted to another set of operations for
@@ -622,17 +624,12 @@ class Phi3SlidingMask2 : public ov::pass::MatcherPass {
622624 };
623625
624626 auto past_kv_len = opp::wrap_type<ov::op::v8::Gather>({opp::any_input (), opp::any_input (), opp::any_input ()});
625- auto pos_ids_param = opp::wrap_type<ov::op::v0::Parameter>();
626- auto pos_ids_shape_of = opp::wrap_type<ov::op::v3::ShapeOf>({pos_ids_param});
627- auto pos_ids_len = opp::wrap_type<ov::op::v8::Gather>({pos_ids_shape_of, opp::any_input (), opp::any_input ()});
628- auto full_ctx_len = opp::wrap_type<ov::op::v1::Add>({past_kv_len, pos_ids_len});
627+ auto full_ctx_len = opp::wrap_type<ov::op::v1::Add>({past_kv_len, opp::any_input ()});
629628 auto query_range = opp::wrap_type<ov::op::v4::Range>({past_kv_len, full_ctx_len, opp::any_input ()});
630629 auto query_range_column = unsqueeze_sequence (query_range);
631630
632631 auto zero_const = opp::wrap_type<ov::op::v0::Constant>();
633- auto pos_ids_len_reshaped = opp::wrap_type<ov::op::v1::Reshape>({pos_ids_len, opp::any_input ()});
634- auto pos_ids_len_squeezed = opp::wrap_type<ov::op::v0::Squeeze>({pos_ids_len_reshaped, opp::any_input ()});
635- auto full_ctx_len_2 = opp::wrap_type<ov::op::v1::Add>({pos_ids_len_squeezed, past_kv_len});
632+ auto full_ctx_len_2 = opp::wrap_type<ov::op::v1::Add>({opp::any_input (), past_kv_len});
636633 auto key_range = opp::wrap_type<ov::op::v4::Range>({zero_const, full_ctx_len_2, opp::any_input ()});
637634 auto key_range_row = unsqueeze_sequence (key_range);
638635 auto opt_key_range_row_f32 = opp::optional<ov::op::v0::Convert>({key_range_row->output (0 )});
@@ -646,36 +643,19 @@ class Phi3SlidingMask2 : public ov::pass::MatcherPass {
646643 auto causal_mask = opp::wrap_type<ov::op::v1::LessEqual>({opt_key_range_row_f32, query_range_column});
647644
648645 auto sliding_and_causal_mask = opp::wrap_type<ov::op::v13::BitwiseAnd>({sliding_and_true, causal_mask});
649- auto sliding_causal_and_true =
650- opp::wrap_type<ov::op::v13::BitwiseAnd>({opp::any_input (), sliding_and_causal_mask});
651-
652- auto atten_mask_param = opp::wrap_type<ov::op::v0::Parameter>();
653- auto atten_mask_boolean = opp::wrap_type<ov::op::v0::Convert>({atten_mask_param});
654- auto atten_mask_reshaped = opp::wrap_type<ov::op::v1::Reshape>({atten_mask_boolean, opp::any_input ()});
655- auto atten_mask_gathered =
656- opp::wrap_type<ov::op::v8::Gather>({atten_mask_reshaped, opp::any_input (), opp::any_input ()});
657- auto atten_mask_reshaped_2 = opp::wrap_type<ov::op::v1::Reshape>({atten_mask_gathered, opp::any_input ()});
658- auto atten_mask_reshaped_3 = opp::wrap_type<ov::op::v1::Reshape>({atten_mask_reshaped_2, opp::any_input ()});
659-
660- auto final_sliding_attention =
661- opp::wrap_type<ov::op::v13::BitwiseAnd>({sliding_causal_and_true, atten_mask_reshaped_3});
662646
663647 auto callback = [=](ov::pass::pattern::Matcher& m) {
664648 LOG_INFO (" Found (4.53) pattern for Phi-3 Sliding Window Attention, will be replaced with custom for static "
665649 " shapes." );
666650 auto & node_to_output = m.get_pattern_value_map ();
667651 auto node_past_kv_len = node_to_output.at (past_kv_len).get_node_shared_ptr ();
668- auto node_pos_ids_param = node_to_output.at (pos_ids_param).get_node_shared_ptr ();
669652 auto node_full_ctx_len = node_to_output.at (full_ctx_len).get_node_shared_ptr ();
670- auto node_atten_mask_boolean = node_to_output.at (atten_mask_boolean).get_node_shared_ptr ();
671653 auto node_neg_window_size = node_to_output.at (neg_window_size).get_node_shared_ptr ();
672654 auto node_sliding_mask = node_to_output.at (sliding_mask).get_node_shared_ptr ();
673655 auto node_sliding_and_causal_mask = node_to_output.at (sliding_and_causal_mask).get_node_shared_ptr ();
674656
675657 auto matched_past_kv_len = std::static_pointer_cast<ov::op::v8::Gather>(node_past_kv_len);
676- auto matched_pos_ids_input = std::static_pointer_cast<ov::op::v0::Parameter>(node_pos_ids_param);
677658 auto matched_full_ctx_len = std::static_pointer_cast<ov::op::v1::Add>(node_full_ctx_len);
678- auto matched_atten_mask_boolean = std::static_pointer_cast<ov::op::v0::Parameter>(node_atten_mask_boolean);
679659 std::shared_ptr<ov::Node> matched_key_range_row = nullptr ;
680660 if (node_to_output.count (opt_key_range_row_f32)) {
681661 auto node_key_range_row_f32 = node_to_output[opt_key_range_row_f32].get_node_shared_ptr ();
@@ -692,14 +672,19 @@ class Phi3SlidingMask2 : public ov::pass::MatcherPass {
692672 " Sliding window size constant must be of size 1, but got " +
693673 std::to_string (matched_neg_window_size->get_output_size ()));
694674
675+ std::shared_ptr<ov::Node> passed_attention_mask = attention_mask_node_ptr;
676+ std::shared_ptr<ov::Node> passed_position_ids = position_ids_node_ptr;
677+ OPENVINO_ASSERT (passed_attention_mask, " Passed attention_mask node is nullptr!" );
678+ OPENVINO_ASSERT (passed_position_ids, " Passed position_ids node is nullptr!" );
679+
695680 auto const_zero = std::make_shared<ov::op::v0::Constant>(ov::element::i64 , ov::Shape{}, 0 );
696681 auto const_one = std::make_shared<ov::op::v0::Constant>(ov::element::i64 , ov::Shape{}, 1 );
697682 auto const_three = std::make_shared<ov::op::v0::Constant>(ov::element::i64 , ov::Shape{}, 3 );
698683
699684 // 1.(K range > (Q_pos range - sliding window).T) & (K range <= Q range.T)
700- std::shared_ptr<ov::Node> query_range_as_pos_ids = matched_pos_ids_input ;
685+ std::shared_ptr<ov::Node> query_range_as_pos_ids = passed_position_ids ;
701686 if (matched_neg_window_size->output (0 ).get_element_type () == ov::element::f32 ) {
702- query_range_as_pos_ids = std::make_shared<ov::op::v0::Convert>(matched_pos_ids_input , ov::element::f32 );
687+ query_range_as_pos_ids = std::make_shared<ov::op::v0::Convert>(passed_position_ids , ov::element::f32 );
703688 }
704689 auto query_range_as_pos_ids_unsqueezed =
705690 std::make_shared<ov::op::v0::Unsqueeze>(query_range_as_pos_ids, const_zero);
@@ -737,7 +722,9 @@ class Phi3SlidingMask2 : public ov::pass::MatcherPass {
737722 auto full_ctx_len_reshaped =
738723 std::make_shared<ov::op::v1::Reshape>(matched_full_ctx_len, shape_rank_one_const, false );
739724 auto const_one_rank_one = std::make_shared<ov::op::v0::Constant>(ov::element::i64 , ov::Shape{1 }, 1 );
740- auto present_atten_mask_bool = std::make_shared<ov::op::v8::Slice>(matched_atten_mask_boolean,
725+ auto attention_mask_bool =
726+ std::make_shared<ov::op::v0::Convert>(passed_attention_mask, ov::element::boolean);
727+ auto present_atten_mask_bool = std::make_shared<ov::op::v8::Slice>(attention_mask_bool,
741728 past_len_reshaped,
742729 full_ctx_len_reshaped,
743730 const_one_rank_one,
@@ -755,11 +742,38 @@ class Phi3SlidingMask2 : public ov::pass::MatcherPass {
755742
756743 return true ;
757744 };
758- register_matcher (std::make_shared<opp::Matcher>(final_sliding_attention , " Phi3SlidingMask2 " ),
745+ register_matcher (std::make_shared<opp::Matcher>(sliding_and_causal_mask , " Phi3SlidingMaskMatcher " ),
759746 std::move (callback));
760747 }
761748};
762749
750+ class Phi3SlidingMask : public ov ::pass::ModelPass {
751+ public:
752+ OPENVINO_MODEL_PASS_RTTI (" ov::npuw::LLMCompiledModel::Phi3SlidingMask" );
753+ Phi3SlidingMask () = default ;
754+ bool run_on_model (const std::shared_ptr<ov::Model>& model) override {
755+ std::shared_ptr<ov::Node> attention_mask_node_ptr = nullptr ;
756+ std::shared_ptr<ov::Node> position_ids_node_ptr = nullptr ;
757+ for (const auto & i : model->inputs ()) {
758+ if (i.get_any_name () == " attention_mask" ) {
759+ attention_mask_node_ptr = i.get_node_shared_ptr ();
760+ }
761+ if (i.get_any_name () == " position_ids" ) {
762+ position_ids_node_ptr = i.get_node_shared_ptr ();
763+ }
764+ }
765+ OPENVINO_ASSERT (attention_mask_node_ptr, " attention_mask input is not found!" );
766+ OPENVINO_ASSERT (position_ids_node_ptr, " position_ids input is not found!" );
767+
768+ ov::pass::Manager manager;
769+ manager.set_per_pass_validation (true );
770+ const auto rewriter = manager.register_pass <ov::pass::GraphRewrite>();
771+ rewriter->add_matcher <Phi3SlidingMaskMatcher>(attention_mask_node_ptr, position_ids_node_ptr);
772+ rewriter->add_matcher <OldPhi3SlidingMaskMatcher>();
773+ return manager.run_passes (model);
774+ }
775+ };
776+
763777namespace {
764778uint32_t align_to (uint32_t value, uint32_t alignment) {
765779 return (value + alignment - 1 ) & ~(alignment - 1 );
@@ -837,11 +851,9 @@ void patch_phi3_sliding_mask(const std::shared_ptr<ov::Model>& model) {
837851 // Qwen2.5 VL/Omni uses 3D position_ids, which can't be directly used
838852 // in creation of sliding window mask.
839853 if (!ov::npuw::util::has_input (model, " token_type_ids" ) && !ov::npuw::util::has_input (model, " inputs_embeds" )) {
840- ov::pass::GraphRewrite rewr;
841- rewr.add_matcher <Phi3SlidingMask2>();
842- rewr.add_matcher <Phi3SlidingMask>();
843- rewr.run_on_model (model);
844- model->validate_nodes_and_infer_types ();
854+ ov::pass::Manager manager;
855+ manager.register_pass <Phi3SlidingMask>();
856+ manager.run_passes (model);
845857 }
846858}
847859
0 commit comments