Skip to content

Commit 4fb2802

Browse files
Fix multiple MatMul matching issue in NPUW LM head cutting
1 parent 5d95073 commit 4fb2802

File tree

1 file changed

+94
-62
lines changed

1 file changed

+94
-62
lines changed

src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp

Lines changed: 94 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -833,94 +833,126 @@ void patch_phi3_sliding_mask(const std::shared_ptr<ov::Model>& model) {
833833
}
834834
} // namespace
835835

836-
class CutLMHead : public ov::pass::MatcherPass {
836+
class CutLMHead : public ov::pass::ModelPass {
837837
public:
838-
OPENVINO_MATCHER_PASS_RTTI("npuw::patterns::CutLMHead");
839-
CutLMHead(std::shared_ptr<ov::Model>& lm_head_model) {
840-
// We are interested at first input to MatMul as a cut point
841-
auto matmul = opp::wrap_type<ov::op::v0::MatMul>({opp::any_input(), opp::any_input()});
842-
843-
// There are several patterns for matmul we are looking for:
844-
// Matmul -> Result
845-
// Matmul -> Add -> Result
846-
auto matmul_add = opp::wrap_type<ov::op::v1::Add>({matmul, opp::any_input()});
847-
// Matmul -> Transpose -> Result
848-
auto matmul_transpose = opp::wrap_type<ov::op::v1::Transpose>({matmul, opp::any_input()});
849-
// Matmul -> Convert -> Result
850-
auto matmul_convert = opp::wrap_type<ov::op::v0::Convert>({matmul});
851-
// MatMul -> Divide -> Tanh -> Multiply -> Result
852-
auto div = opp::wrap_type<ov::op::v1::Multiply, ov::op::v1::Divide>({matmul, opp::any_input()});
853-
auto tanh = opp::wrap_type<ov::op::v0::Tanh>({div});
854-
auto matmul_multiply = opp::wrap_type<ov::op::v1::Multiply>({tanh, opp::any_input()});
855-
856-
auto last_op = std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{matmul->output(0),
857-
matmul_add->output(0),
858-
matmul_transpose->output(0),
859-
matmul_convert->output(0),
860-
matmul_multiply->output(0)});
861-
auto res = opp::wrap_type<ov::op::v0::Result>({last_op->output(0)});
862-
863-
auto callback = [=, &lm_head_model](ov::pass::pattern::Matcher& m) {
864-
auto& node_to_output = m.get_pattern_value_map();
865-
866-
auto matched_node_matmul = node_to_output.at(matmul).get_node_shared_ptr();
867-
std::shared_ptr<ov::Node> matched_node_last_op = nullptr;
868-
if (node_to_output.count(matmul_add)) {
869-
matched_node_last_op = node_to_output[matmul_add].get_node_shared_ptr();
870-
} else if (node_to_output.count(matmul_transpose)) {
871-
matched_node_last_op = node_to_output[matmul_transpose].get_node_shared_ptr();
872-
} else if (node_to_output.count(matmul_convert)) {
873-
matched_node_last_op = node_to_output[matmul_convert].get_node_shared_ptr();
874-
} else if (node_to_output.count(matmul_multiply)) {
875-
matched_node_last_op = node_to_output[matmul_multiply].get_node_shared_ptr();
876-
} else {
877-
matched_node_last_op = matched_node_matmul;
838+
OPENVINO_MODEL_PASS_RTTI("npuw::patterns::CutLMHead");
839+
840+
CutLMHead() = default;
841+
842+
bool run_on_model(const std::shared_ptr<ov::Model>& model) override {
843+
std::shared_ptr<ov::op::v0::MatMul> best_matmul = nullptr;
844+
std::shared_ptr<ov::op::v0::Result> best_result = nullptr;
845+
std::shared_ptr<ov::Node> best_last_op = nullptr;
846+
size_t max_matrix_size = 0;
847+
848+
for (const auto& result : model->get_results()) {
849+
auto input_node = result->input(0).get_source_output().get_node_shared_ptr();
850+
851+
std::shared_ptr<ov::op::v0::MatMul> matmul = nullptr;
852+
std::shared_ptr<ov::Node> last_op = nullptr;
853+
854+
if (auto direct_matmul = ov::as_type_ptr<ov::op::v0::MatMul>(input_node)) {
855+
// Pattern 1: MatMul -> Result (direct connection)
856+
matmul = direct_matmul;
857+
last_op = direct_matmul;
858+
} else if (ov::is_type<ov::op::v1::Add>(input_node) || ov::is_type<ov::op::v1::Transpose>(input_node) ||
859+
ov::is_type<ov::op::v0::Convert>(input_node)) {
860+
// Patterns 2: MatMul -> (single intermediate op) -> Result
861+
if (auto matmul_node = ov::as_type_ptr<ov::op::v0::MatMul>(
862+
input_node->input(0).get_source_output().get_node_shared_ptr())) {
863+
matmul = matmul_node;
864+
last_op = input_node;
865+
}
866+
} else if (auto multiply_node = ov::as_type_ptr<ov::op::v1::Multiply>(input_node)) {
867+
// Pattern 3: MatMul -> Divide -> Tanh -> Multiply -> Result
868+
if (auto tanh_node = ov::as_type_ptr<ov::op::v0::Tanh>(
869+
multiply_node->input(0).get_source_output().get_node_shared_ptr())) {
870+
if (auto divide_node = ov::as_type_ptr<ov::op::v1::Divide>(
871+
tanh_node->input(0).get_source_output().get_node_shared_ptr())) {
872+
if (auto matmul_node = ov::as_type_ptr<ov::op::v0::MatMul>(
873+
divide_node->input(0).get_source_output().get_node_shared_ptr())) {
874+
matmul = matmul_node;
875+
last_op = multiply_node;
876+
}
877+
}
878+
}
878879
}
879-
auto matched_node_result = node_to_output.at(res).get_node_shared_ptr();
880880

881-
auto matched_matmul = std::static_pointer_cast<ov::op::v0::MatMul>(matched_node_matmul);
882-
auto matched_result = std::static_pointer_cast<ov::op::v0::Result>(matched_node_result);
881+
if (matmul) {
882+
auto weight_shape = matmul->input(1).get_source_output().get_partial_shape();
883+
size_t current_matrix_size = 0;
884+
if (weight_shape.rank().is_static() && weight_shape.rank().get_length() == 2) {
885+
auto dim0 = weight_shape[0];
886+
auto dim1 = weight_shape[1];
887+
if (dim0.is_static() && dim1.is_static()) {
888+
auto size0 = static_cast<size_t>(dim0.get_length());
889+
auto size1 = static_cast<size_t>(dim1.get_length());
890+
current_matrix_size = std::max(size0, size1);
891+
}
892+
}
883893

894+
// Update best candidate if this one is larger
895+
if (current_matrix_size > max_matrix_size) {
896+
max_matrix_size = current_matrix_size;
897+
best_matmul = matmul;
898+
best_result = result;
899+
best_last_op = last_op;
900+
}
901+
}
902+
}
903+
904+
if (best_matmul) {
884905
// Cut point:
885-
auto matmul_first_source = matched_matmul->input(0).get_source_output();
906+
auto matmul_first_source = best_matmul->input(0).get_source_output();
886907

887908
// Cut original model:
888-
matched_result->input(0).replace_source_output(matmul_first_source);
909+
best_result->input(0).replace_source_output(matmul_first_source);
889910
// FIXME: Somehow for KVCache model result output gets renamed in
890911
// ICompiledModel::ICompiledModel().
891912
// As a WA, setting the same name to output from MatMul
892913
// avoids the issue.
893914
matmul_first_source.set_names({ov::npuw::LLMCompiledModel::output_embeds});
894-
matched_result->output(0).set_names({ov::npuw::LLMCompiledModel::output_embeds});
895-
matched_result->validate_and_infer_types();
915+
best_result->output(0).set_names({ov::npuw::LLMCompiledModel::output_embeds});
916+
best_result->validate_and_infer_types();
896917

897918
// Create an additional model after cut point:
898919
auto new_param = std::make_shared<ov::op::v0::Parameter>(matmul_first_source.get_element_type(),
899920
matmul_first_source.get_partial_shape());
900921
new_param->output(0).add_names({ov::npuw::LLMCompiledModel::output_embeds});
901-
matched_matmul->input(0).replace_source_output(new_param);
902-
auto new_result = std::make_shared<ov::op::v0::Result>(matched_node_last_op);
903-
lm_head_model =
922+
best_matmul->input(0).replace_source_output(new_param);
923+
auto new_result = std::make_shared<ov::op::v0::Result>(best_last_op);
924+
m_lm_head_model =
904925
std::make_shared<ov::Model>(ov::OutputVector{new_result->output(0)}, ov::ParameterVector{new_param});
905926

927+
if (m_lm_head_model) {
928+
m_lm_head_model->set_friendly_name(model->get_friendly_name() + "_lm_head");
929+
}
930+
model->validate_nodes_and_infer_types();
931+
906932
return true;
907-
};
908-
register_matcher(std::make_shared<opp::Matcher>(res, "CutLMHead"), std::move(callback));
933+
}
934+
935+
return false;
936+
}
937+
938+
std::shared_ptr<ov::Model> get_lm_head_model() const {
939+
return m_lm_head_model;
909940
}
941+
942+
private:
943+
std::shared_ptr<ov::Model> m_lm_head_model = nullptr;
910944
};
911945

912946
namespace {
947+
913948
// Runs the CutLMHead model pass on 'model' (cutting it in place at the LM
// head, if one is found) and returns the detached LM-head model, or nullptr
// when no supported LM-head pattern was matched.
std::shared_ptr<ov::Model> cut_lm_head(std::shared_ptr<ov::Model>& model) {
    LOG_DEBUG("Executing LM head cutting transformation using ModelPass");
    LOG_BLOCK();

    CutLMHead cut_pass;
    cut_pass.run_on_model(model);

    return cut_pass.get_lm_head_model();
}
925957

926958
void reshape_to_static(std::shared_ptr<ov::Model> model,

0 commit comments

Comments (0)