Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 53 additions & 41 deletions src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "openvino/openvino.hpp"
#include "openvino/opsets/opset13.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/matcher_pass.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/or.hpp"
Expand Down Expand Up @@ -398,12 +399,12 @@ class GemmaSlidingMask : public ov::pass::MatcherPass {
}
};

class Phi3SlidingMask : public ov::pass::MatcherPass {
class OldPhi3SlidingMaskMatcher : public ov::pass::MatcherPass {
public:
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::Phi3SlidingMask");
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::OldPhi3SlidingMaskMatcher");

Phi3SlidingMask() {
// Search for the Phi3 sliding mask pattern to extend it to work with right-padded
OldPhi3SlidingMaskMatcher() {
// Search for the Phi3 old sliding mask pattern to extend it to work with right-padded
// past tokens and left-padded present tokens.
//
// Mask creation is simply done via "less_equal" and "greater" operations between
Expand Down Expand Up @@ -587,16 +588,17 @@ class Phi3SlidingMask : public ov::pass::MatcherPass {

return true;
};
register_matcher(std::make_shared<opp::Matcher>(inv_sliding_attention_mask, "Phi3SlidingMask"),
register_matcher(std::make_shared<opp::Matcher>(inv_sliding_attention_mask, "OldPhi3SlidingMaskMatcher"),
std::move(callback));
}
};

class Phi3SlidingMask2 : public ov::pass::MatcherPass {
class Phi3SlidingMaskMatcher : public ov::pass::MatcherPass {
public:
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::Phi3SlidingMask2");
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::Phi3SlidingMaskMatcher");

Phi3SlidingMask2() {
Phi3SlidingMaskMatcher(const std::shared_ptr<ov::Node>& attention_mask_node_ptr,
const std::shared_ptr<ov::Node>& position_ids_node_ptr) {
// Search for the Phi3 sliding mask pattern to extend it to work with right-padded
// past tokens and left-padded present tokens. Logic to replace pattern is the same
// as in Phi3SlidingMask rewriter, but adjusted to another set of operations for
Expand All @@ -622,17 +624,12 @@ class Phi3SlidingMask2 : public ov::pass::MatcherPass {
};

auto past_kv_len = opp::wrap_type<ov::op::v8::Gather>({opp::any_input(), opp::any_input(), opp::any_input()});
auto pos_ids_param = opp::wrap_type<ov::op::v0::Parameter>();
auto pos_ids_shape_of = opp::wrap_type<ov::op::v3::ShapeOf>({pos_ids_param});
auto pos_ids_len = opp::wrap_type<ov::op::v8::Gather>({pos_ids_shape_of, opp::any_input(), opp::any_input()});
auto full_ctx_len = opp::wrap_type<ov::op::v1::Add>({past_kv_len, pos_ids_len});
auto full_ctx_len = opp::wrap_type<ov::op::v1::Add>({past_kv_len, opp::any_input()});
auto query_range = opp::wrap_type<ov::op::v4::Range>({past_kv_len, full_ctx_len, opp::any_input()});
auto query_range_column = unsqueeze_sequence(query_range);

auto zero_const = opp::wrap_type<ov::op::v0::Constant>();
auto pos_ids_len_reshaped = opp::wrap_type<ov::op::v1::Reshape>({pos_ids_len, opp::any_input()});
auto pos_ids_len_squeezed = opp::wrap_type<ov::op::v0::Squeeze>({pos_ids_len_reshaped, opp::any_input()});
auto full_ctx_len_2 = opp::wrap_type<ov::op::v1::Add>({pos_ids_len_squeezed, past_kv_len});
auto full_ctx_len_2 = opp::wrap_type<ov::op::v1::Add>({opp::any_input(), past_kv_len});
auto key_range = opp::wrap_type<ov::op::v4::Range>({zero_const, full_ctx_len_2, opp::any_input()});
auto key_range_row = unsqueeze_sequence(key_range);
auto opt_key_range_row_f32 = opp::optional<ov::op::v0::Convert>({key_range_row->output(0)});
Expand All @@ -646,36 +643,19 @@ class Phi3SlidingMask2 : public ov::pass::MatcherPass {
auto causal_mask = opp::wrap_type<ov::op::v1::LessEqual>({opt_key_range_row_f32, query_range_column});

auto sliding_and_causal_mask = opp::wrap_type<ov::op::v13::BitwiseAnd>({sliding_and_true, causal_mask});
auto sliding_causal_and_true =
opp::wrap_type<ov::op::v13::BitwiseAnd>({opp::any_input(), sliding_and_causal_mask});

auto atten_mask_param = opp::wrap_type<ov::op::v0::Parameter>();
auto atten_mask_boolean = opp::wrap_type<ov::op::v0::Convert>({atten_mask_param});
auto atten_mask_reshaped = opp::wrap_type<ov::op::v1::Reshape>({atten_mask_boolean, opp::any_input()});
auto atten_mask_gathered =
opp::wrap_type<ov::op::v8::Gather>({atten_mask_reshaped, opp::any_input(), opp::any_input()});
auto atten_mask_reshaped_2 = opp::wrap_type<ov::op::v1::Reshape>({atten_mask_gathered, opp::any_input()});
auto atten_mask_reshaped_3 = opp::wrap_type<ov::op::v1::Reshape>({atten_mask_reshaped_2, opp::any_input()});

auto final_sliding_attention =
opp::wrap_type<ov::op::v13::BitwiseAnd>({sliding_causal_and_true, atten_mask_reshaped_3});

auto callback = [=](ov::pass::pattern::Matcher& m) {
LOG_INFO("Found (4.53) pattern for Phi-3 Sliding Window Attention, will be replaced with custom for static "
"shapes.");
auto& node_to_output = m.get_pattern_value_map();
auto node_past_kv_len = node_to_output.at(past_kv_len).get_node_shared_ptr();
auto node_pos_ids_param = node_to_output.at(pos_ids_param).get_node_shared_ptr();
auto node_full_ctx_len = node_to_output.at(full_ctx_len).get_node_shared_ptr();
auto node_atten_mask_boolean = node_to_output.at(atten_mask_boolean).get_node_shared_ptr();
auto node_neg_window_size = node_to_output.at(neg_window_size).get_node_shared_ptr();
auto node_sliding_mask = node_to_output.at(sliding_mask).get_node_shared_ptr();
auto node_sliding_and_causal_mask = node_to_output.at(sliding_and_causal_mask).get_node_shared_ptr();

auto matched_past_kv_len = std::static_pointer_cast<ov::op::v8::Gather>(node_past_kv_len);
auto matched_pos_ids_input = std::static_pointer_cast<ov::op::v0::Parameter>(node_pos_ids_param);
auto matched_full_ctx_len = std::static_pointer_cast<ov::op::v1::Add>(node_full_ctx_len);
auto matched_atten_mask_boolean = std::static_pointer_cast<ov::op::v0::Parameter>(node_atten_mask_boolean);
std::shared_ptr<ov::Node> matched_key_range_row = nullptr;
if (node_to_output.count(opt_key_range_row_f32)) {
auto node_key_range_row_f32 = node_to_output[opt_key_range_row_f32].get_node_shared_ptr();
Expand All @@ -692,14 +672,19 @@ class Phi3SlidingMask2 : public ov::pass::MatcherPass {
"Sliding window size constant must be of size 1, but got " +
std::to_string(matched_neg_window_size->get_output_size()));

std::shared_ptr<ov::Node> passed_attention_mask = attention_mask_node_ptr;
std::shared_ptr<ov::Node> passed_position_ids = position_ids_node_ptr;
OPENVINO_ASSERT(passed_attention_mask, "Passed attention_mask node is nullptr!");
OPENVINO_ASSERT(passed_position_ids, "Passed position_ids node is nullptr!");

auto const_zero = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, 0);
auto const_one = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, 1);
auto const_three = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, 3);

// 1.(K range > (Q_pos range - sliding window).T) & (K range <= Q range.T)
std::shared_ptr<ov::Node> query_range_as_pos_ids = matched_pos_ids_input;
std::shared_ptr<ov::Node> query_range_as_pos_ids = passed_position_ids;
if (matched_neg_window_size->output(0).get_element_type() == ov::element::f32) {
query_range_as_pos_ids = std::make_shared<ov::op::v0::Convert>(matched_pos_ids_input, ov::element::f32);
query_range_as_pos_ids = std::make_shared<ov::op::v0::Convert>(passed_position_ids, ov::element::f32);
}
auto query_range_as_pos_ids_unsqueezed =
std::make_shared<ov::op::v0::Unsqueeze>(query_range_as_pos_ids, const_zero);
Expand Down Expand Up @@ -737,7 +722,9 @@ class Phi3SlidingMask2 : public ov::pass::MatcherPass {
auto full_ctx_len_reshaped =
std::make_shared<ov::op::v1::Reshape>(matched_full_ctx_len, shape_rank_one_const, false);
auto const_one_rank_one = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, 1);
auto present_atten_mask_bool = std::make_shared<ov::op::v8::Slice>(matched_atten_mask_boolean,
auto attention_mask_bool =
std::make_shared<ov::op::v0::Convert>(passed_attention_mask, ov::element::boolean);
auto present_atten_mask_bool = std::make_shared<ov::op::v8::Slice>(attention_mask_bool,
past_len_reshaped,
full_ctx_len_reshaped,
const_one_rank_one,
Expand All @@ -755,11 +742,38 @@ class Phi3SlidingMask2 : public ov::pass::MatcherPass {

return true;
};
register_matcher(std::make_shared<opp::Matcher>(final_sliding_attention, "Phi3SlidingMask2"),
register_matcher(std::make_shared<opp::Matcher>(sliding_and_causal_mask, "Phi3SlidingMaskMatcher"),
std::move(callback));
}
};

class Phi3SlidingMask : public ov::pass::ModelPass {
public:
OPENVINO_MODEL_PASS_RTTI("ov::npuw::LLMCompiledModel::Phi3SlidingMask");
Phi3SlidingMask() = default;
bool run_on_model(const std::shared_ptr<ov::Model>& model) override {
std::shared_ptr<ov::Node> attention_mask_node_ptr = nullptr;
std::shared_ptr<ov::Node> position_ids_node_ptr = nullptr;
for (const auto& i : model->inputs()) {
if (i.get_any_name() == "attention_mask") {
attention_mask_node_ptr = i.get_node_shared_ptr();
}
if (i.get_any_name() == "position_ids") {
position_ids_node_ptr = i.get_node_shared_ptr();
}
}
OPENVINO_ASSERT(attention_mask_node_ptr, "attention_mask input is not found!");
OPENVINO_ASSERT(position_ids_node_ptr, "position_ids input is not found!");

ov::pass::Manager manager;
manager.set_per_pass_validation(true);
const auto rewriter = manager.register_pass<ov::pass::GraphRewrite>();
rewriter->add_matcher<Phi3SlidingMaskMatcher>(attention_mask_node_ptr, position_ids_node_ptr);
rewriter->add_matcher<OldPhi3SlidingMaskMatcher>();
return manager.run_passes(model);
}
};

namespace {
uint32_t align_to(uint32_t value, uint32_t alignment) {
return (value + alignment - 1) & ~(alignment - 1);
Expand Down Expand Up @@ -837,11 +851,9 @@ void patch_phi3_sliding_mask(const std::shared_ptr<ov::Model>& model) {
// Qwen2.5 VL/Omni uses 3D position_ids, which can't be directly used
// in creation of sliding window mask.
if (!ov::npuw::util::has_input(model, "token_type_ids") && !ov::npuw::util::has_input(model, "inputs_embeds")) {
ov::pass::GraphRewrite rewr;
rewr.add_matcher<Phi3SlidingMask2>();
rewr.add_matcher<Phi3SlidingMask>();
rewr.run_on_model(model);
model->validate_nodes_and_infer_types();
ov::pass::Manager manager;
manager.register_pass<Phi3SlidingMask>();
manager.run_passes(model);
}
}

Expand Down
Loading