Skip to content

Commit ccf7194

Browse files
Use an rt_info-based solution for LM head detection (rewrite CutLMHead as a MatcherPass)
1 parent 4fb2802 commit ccf7194

File tree

1 file changed

+73
-92
lines changed

1 file changed

+73
-92
lines changed

src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp

Lines changed: 73 additions & 92 deletions
Original file line number | Diff line number | Diff line change
@@ -833,126 +833,107 @@ void patch_phi3_sliding_mask(const std::shared_ptr<ov::Model>& model) {
833833
}
834834
} // namespace
835835

836-
class CutLMHead : public ov::pass::ModelPass {
836+
class CutLMHead : public ov::pass::MatcherPass {
837837
public:
838-
OPENVINO_MODEL_PASS_RTTI("npuw::patterns::CutLMHead");
839-
840-
CutLMHead() = default;
841-
842-
bool run_on_model(const std::shared_ptr<ov::Model>& model) override {
843-
std::shared_ptr<ov::op::v0::MatMul> best_matmul = nullptr;
844-
std::shared_ptr<ov::op::v0::Result> best_result = nullptr;
845-
std::shared_ptr<ov::Node> best_last_op = nullptr;
846-
size_t max_matrix_size = 0;
847-
848-
for (const auto& result : model->get_results()) {
849-
auto input_node = result->input(0).get_source_output().get_node_shared_ptr();
850-
851-
std::shared_ptr<ov::op::v0::MatMul> matmul = nullptr;
852-
std::shared_ptr<ov::Node> last_op = nullptr;
853-
854-
if (auto direct_matmul = ov::as_type_ptr<ov::op::v0::MatMul>(input_node)) {
855-
// Pattern 1: MatMul -> Result (direct connection)
856-
matmul = direct_matmul;
857-
last_op = direct_matmul;
858-
} else if (ov::is_type<ov::op::v1::Add>(input_node) || ov::is_type<ov::op::v1::Transpose>(input_node) ||
859-
ov::is_type<ov::op::v0::Convert>(input_node)) {
860-
// Patterns 2: MatMul -> (single intermediate op) -> Result
861-
if (auto matmul_node = ov::as_type_ptr<ov::op::v0::MatMul>(
862-
input_node->input(0).get_source_output().get_node_shared_ptr())) {
863-
matmul = matmul_node;
864-
last_op = input_node;
865-
}
866-
} else if (auto multiply_node = ov::as_type_ptr<ov::op::v1::Multiply>(input_node)) {
867-
// Pattern 3: MatMul -> Divide -> Tanh -> Multiply -> Result
868-
if (auto tanh_node = ov::as_type_ptr<ov::op::v0::Tanh>(
869-
multiply_node->input(0).get_source_output().get_node_shared_ptr())) {
870-
if (auto divide_node = ov::as_type_ptr<ov::op::v1::Divide>(
871-
tanh_node->input(0).get_source_output().get_node_shared_ptr())) {
872-
if (auto matmul_node = ov::as_type_ptr<ov::op::v0::MatMul>(
873-
divide_node->input(0).get_source_output().get_node_shared_ptr())) {
874-
matmul = matmul_node;
875-
last_op = multiply_node;
876-
}
877-
}
878-
}
879-
}
838+
OPENVINO_MATCHER_PASS_RTTI("npuw::patterns::CutLMHead");
839+
CutLMHead(std::shared_ptr<ov::Model>& lm_head_model, const ov::AnyMap& model_rt_info) {
840+
// We are interested at first input to MatMul as a cut point
841+
auto matmul = opp::wrap_type<ov::op::v0::MatMul>({opp::any_input(), opp::any_input()});
842+
843+
// There are several patterns for matmul we are looking for:
844+
// Matmul -> Result
845+
// Matmul -> Add -> Result
846+
auto matmul_add = opp::wrap_type<ov::op::v1::Add>({matmul, opp::any_input()});
847+
// Matmul -> Transpose -> Result
848+
auto matmul_transpose = opp::wrap_type<ov::op::v1::Transpose>({matmul, opp::any_input()});
849+
// Matmul -> Convert -> Result
850+
auto matmul_convert = opp::wrap_type<ov::op::v0::Convert>({matmul});
851+
// MatMul -> Divide -> Tanh -> Multiply -> Result
852+
auto div = opp::wrap_type<ov::op::v1::Multiply, ov::op::v1::Divide>({matmul, opp::any_input()});
853+
auto tanh = opp::wrap_type<ov::op::v0::Tanh>({div});
854+
auto matmul_multiply = opp::wrap_type<ov::op::v1::Multiply>({tanh, opp::any_input()});
855+
856+
auto last_op = std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{matmul->output(0),
857+
matmul_add->output(0),
858+
matmul_transpose->output(0),
859+
matmul_convert->output(0),
860+
matmul_multiply->output(0)});
861+
auto res = opp::wrap_type<ov::op::v0::Result>({last_op->output(0)});
862+
863+
auto callback = [=, &lm_head_model](ov::pass::pattern::Matcher& m) {
864+
auto& node_to_output = m.get_pattern_value_map();
880865

881-
if (matmul) {
882-
auto weight_shape = matmul->input(1).get_source_output().get_partial_shape();
883-
size_t current_matrix_size = 0;
884-
if (weight_shape.rank().is_static() && weight_shape.rank().get_length() == 2) {
885-
auto dim0 = weight_shape[0];
886-
auto dim1 = weight_shape[1];
887-
if (dim0.is_static() && dim1.is_static()) {
888-
auto size0 = static_cast<size_t>(dim0.get_length());
889-
auto size1 = static_cast<size_t>(dim1.get_length());
890-
current_matrix_size = std::max(size0, size1);
866+
auto matched_node_matmul = node_to_output.at(matmul).get_node_shared_ptr();
867+
std::shared_ptr<ov::Node> matched_node_last_op = nullptr;
868+
if (node_to_output.count(matmul_add)) {
869+
matched_node_last_op = node_to_output[matmul_add].get_node_shared_ptr();
870+
} else if (node_to_output.count(matmul_transpose)) {
871+
matched_node_last_op = node_to_output[matmul_transpose].get_node_shared_ptr();
872+
} else if (node_to_output.count(matmul_convert)) {
873+
matched_node_last_op = node_to_output[matmul_convert].get_node_shared_ptr();
874+
} else if (node_to_output.count(matmul_multiply)) {
875+
matched_node_last_op = node_to_output[matmul_multiply].get_node_shared_ptr();
876+
} else {
877+
matched_node_last_op = matched_node_matmul;
878+
}
879+
auto matched_node_result = node_to_output.at(res).get_node_shared_ptr();
880+
881+
auto matched_matmul = std::static_pointer_cast<ov::op::v0::MatMul>(matched_node_matmul);
882+
auto matched_result = std::static_pointer_cast<ov::op::v0::Result>(matched_node_result);
883+
884+
// Some LLMs add intermediate hidden state outputs that can interfere with LM head detection
885+
// Use "hidden_output_name" rt_info to distinguish them from the actual logits output
886+
// For example, Eagle-3 target/draft models add "last_hidden_state" output which should be skipped
887+
if (model_rt_info.count("hidden_output_name")) {
888+
const auto& hidden_output_name = model_rt_info.at("hidden_output_name").as<std::string>();
889+
const auto& result_output_names = matched_result->output(0).get_names();
890+
for (const auto& name : result_output_names) {
891+
if (name == hidden_output_name) {
892+
return false;
891893
}
892894
}
893-
894-
// Update best candidate if this one is larger
895-
if (current_matrix_size > max_matrix_size) {
896-
max_matrix_size = current_matrix_size;
897-
best_matmul = matmul;
898-
best_result = result;
899-
best_last_op = last_op;
900-
}
901895
}
902-
}
903896

904-
if (best_matmul) {
905897
// Cut point:
906-
auto matmul_first_source = best_matmul->input(0).get_source_output();
898+
auto matmul_first_source = matched_matmul->input(0).get_source_output();
907899

908900
// Cut original model:
909-
best_result->input(0).replace_source_output(matmul_first_source);
901+
matched_result->input(0).replace_source_output(matmul_first_source);
910902
// FIXME: Somehow for KVCache model result output gets renamed in
911903
// ICompiledModel::ICompiledModel().
912904
// As a WA, setting the same name to output from MatMul
913905
// avoids the issue.
914906
matmul_first_source.set_names({ov::npuw::LLMCompiledModel::output_embeds});
915-
best_result->output(0).set_names({ov::npuw::LLMCompiledModel::output_embeds});
916-
best_result->validate_and_infer_types();
907+
matched_result->output(0).set_names({ov::npuw::LLMCompiledModel::output_embeds});
908+
matched_result->validate_and_infer_types();
917909

918910
// Create an additional model after cut point:
919911
auto new_param = std::make_shared<ov::op::v0::Parameter>(matmul_first_source.get_element_type(),
920912
matmul_first_source.get_partial_shape());
921913
new_param->output(0).add_names({ov::npuw::LLMCompiledModel::output_embeds});
922-
best_matmul->input(0).replace_source_output(new_param);
923-
auto new_result = std::make_shared<ov::op::v0::Result>(best_last_op);
924-
m_lm_head_model =
914+
matched_matmul->input(0).replace_source_output(new_param);
915+
auto new_result = std::make_shared<ov::op::v0::Result>(matched_node_last_op);
916+
lm_head_model =
925917
std::make_shared<ov::Model>(ov::OutputVector{new_result->output(0)}, ov::ParameterVector{new_param});
926918

927-
if (m_lm_head_model) {
928-
m_lm_head_model->set_friendly_name(model->get_friendly_name() + "_lm_head");
929-
}
930-
model->validate_nodes_and_infer_types();
931-
932919
return true;
933-
}
934-
935-
return false;
936-
}
937-
938-
std::shared_ptr<ov::Model> get_lm_head_model() const {
939-
return m_lm_head_model;
920+
};
921+
register_matcher(std::make_shared<opp::Matcher>(res, "CutLMHead"), std::move(callback));
940922
}
941-
942-
private:
943-
std::shared_ptr<ov::Model> m_lm_head_model = nullptr;
944923
};
945924

946925
namespace {
947-
948926
std::shared_ptr<ov::Model> cut_lm_head(std::shared_ptr<ov::Model>& model) {
949-
LOG_DEBUG("Executing LM head cutting transformation using ModelPass");
950-
LOG_BLOCK();
951-
952-
auto cut_pass = std::make_shared<CutLMHead>();
953-
cut_pass->run_on_model(model);
927+
ov::pass::GraphRewrite rewr;
928+
std::shared_ptr<ov::Model> lm_head_model = nullptr;
929+
rewr.add_matcher<CutLMHead>(lm_head_model, model->get_rt_info());
930+
rewr.run_on_model(model);
931+
if (lm_head_model) {
932+
lm_head_model->set_friendly_name(model->get_friendly_name() + "_lm_head");
933+
}
934+
model->validate_nodes_and_infer_types();
954935

955-
return cut_pass->get_lm_head_model();
936+
return lm_head_model;
956937
}
957938

958939
void reshape_to_static(std::shared_ptr<ov::Model> model,

0 commit comments

Comments
 (0)