@@ -833,126 +833,107 @@ void patch_phi3_sliding_mask(const std::shared_ptr<ov::Model>& model) {
833833}
834834} // namespace
835835
836- class CutLMHead : public ov::pass::ModelPass {
836+ class CutLMHead : public ov::pass::MatcherPass {
837837public:
838- OPENVINO_MODEL_PASS_RTTI("npuw::patterns::CutLMHead");
839-
840- CutLMHead () = default ;
841-
842- bool run_on_model (const std::shared_ptr<ov::Model>& model) override {
843- std::shared_ptr<ov::op::v0::MatMul> best_matmul = nullptr ;
844- std::shared_ptr<ov::op::v0::Result> best_result = nullptr ;
845- std::shared_ptr<ov::Node> best_last_op = nullptr ;
846- size_t max_matrix_size = 0 ;
847-
848- for (const auto & result : model->get_results ()) {
849- auto input_node = result->input (0 ).get_source_output ().get_node_shared_ptr ();
850-
851- std::shared_ptr<ov::op::v0::MatMul> matmul = nullptr ;
852- std::shared_ptr<ov::Node> last_op = nullptr ;
853-
854- if (auto direct_matmul = ov::as_type_ptr<ov::op::v0::MatMul>(input_node)) {
855- // Pattern 1: MatMul -> Result (direct connection)
856- matmul = direct_matmul;
857- last_op = direct_matmul;
858- } else if (ov::is_type<ov::op::v1::Add>(input_node) || ov::is_type<ov::op::v1::Transpose>(input_node) ||
859- ov::is_type<ov::op::v0::Convert>(input_node)) {
860- // Patterns 2: MatMul -> (single intermediate op) -> Result
861- if (auto matmul_node = ov::as_type_ptr<ov::op::v0::MatMul>(
862- input_node->input (0 ).get_source_output ().get_node_shared_ptr ())) {
863- matmul = matmul_node;
864- last_op = input_node;
865- }
866- } else if (auto multiply_node = ov::as_type_ptr<ov::op::v1::Multiply>(input_node)) {
867- // Pattern 3: MatMul -> Divide -> Tanh -> Multiply -> Result
868- if (auto tanh_node = ov::as_type_ptr<ov::op::v0::Tanh>(
869- multiply_node->input (0 ).get_source_output ().get_node_shared_ptr ())) {
870- if (auto divide_node = ov::as_type_ptr<ov::op::v1::Divide>(
871- tanh_node->input (0 ).get_source_output ().get_node_shared_ptr ())) {
872- if (auto matmul_node = ov::as_type_ptr<ov::op::v0::MatMul>(
873- divide_node->input (0 ).get_source_output ().get_node_shared_ptr ())) {
874- matmul = matmul_node;
875- last_op = multiply_node;
876- }
877- }
878- }
879- }
838+ OPENVINO_MATCHER_PASS_RTTI("npuw::patterns::CutLMHead");
839+ CutLMHead(std::shared_ptr<ov::Model>& lm_head_model, const ov::AnyMap& model_rt_info) {
840+ // We are interested in the first input to MatMul as a cut point
841+ auto matmul = opp::wrap_type<ov::op::v0::MatMul>({opp::any_input (), opp::any_input ()});
842+
843+ // There are several patterns for matmul we are looking for:
844+ // Matmul -> Result
845+ // Matmul -> Add -> Result
846+ auto matmul_add = opp::wrap_type<ov::op::v1::Add>({matmul, opp::any_input ()});
847+ // Matmul -> Transpose -> Result
848+ auto matmul_transpose = opp::wrap_type<ov::op::v1::Transpose>({matmul, opp::any_input ()});
849+ // Matmul -> Convert -> Result
850+ auto matmul_convert = opp::wrap_type<ov::op::v0::Convert>({matmul});
851+ // MatMul -> Divide -> Tanh -> Multiply -> Result
852+ auto div = opp::wrap_type<ov::op::v1::Multiply, ov::op::v1::Divide>({matmul, opp::any_input ()});
853+ auto tanh = opp::wrap_type<ov::op::v0::Tanh>({div});
854+ auto matmul_multiply = opp::wrap_type<ov::op::v1::Multiply>({tanh, opp::any_input ()});
855+
856+ auto last_op = std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{matmul->output (0 ),
857+ matmul_add->output (0 ),
858+ matmul_transpose->output (0 ),
859+ matmul_convert->output (0 ),
860+ matmul_multiply->output (0 )});
861+ auto res = opp::wrap_type<ov::op::v0::Result>({last_op->output (0 )});
862+
863+ auto callback = [=, &lm_head_model](ov::pass::pattern::Matcher& m) {
864+ auto & node_to_output = m.get_pattern_value_map ();
880865
881- if (matmul) {
882- auto weight_shape = matmul->input (1 ).get_source_output ().get_partial_shape ();
883- size_t current_matrix_size = 0 ;
884- if (weight_shape.rank ().is_static () && weight_shape.rank ().get_length () == 2 ) {
885- auto dim0 = weight_shape[0 ];
886- auto dim1 = weight_shape[1 ];
887- if (dim0.is_static () && dim1.is_static ()) {
888- auto size0 = static_cast <size_t >(dim0.get_length ());
889- auto size1 = static_cast <size_t >(dim1.get_length ());
890- current_matrix_size = std::max (size0, size1);
866+ auto matched_node_matmul = node_to_output.at (matmul).get_node_shared_ptr ();
867+ std::shared_ptr<ov::Node> matched_node_last_op = nullptr ;
868+ if (node_to_output.count (matmul_add)) {
869+ matched_node_last_op = node_to_output[matmul_add].get_node_shared_ptr ();
870+ } else if (node_to_output.count (matmul_transpose)) {
871+ matched_node_last_op = node_to_output[matmul_transpose].get_node_shared_ptr ();
872+ } else if (node_to_output.count (matmul_convert)) {
873+ matched_node_last_op = node_to_output[matmul_convert].get_node_shared_ptr ();
874+ } else if (node_to_output.count (matmul_multiply)) {
875+ matched_node_last_op = node_to_output[matmul_multiply].get_node_shared_ptr ();
876+ } else {
877+ matched_node_last_op = matched_node_matmul;
878+ }
879+ auto matched_node_result = node_to_output.at (res).get_node_shared_ptr ();
880+
881+ auto matched_matmul = std::static_pointer_cast<ov::op::v0::MatMul>(matched_node_matmul);
882+ auto matched_result = std::static_pointer_cast<ov::op::v0::Result>(matched_node_result);
883+
884+ // Some LLMs add intermediate hidden state outputs that can interfere with LM head detection
885+ // Use "hidden_output_name" rt_info to distinguish them from the actual logits output
886+ // For example, Eagle-3 target/draft models add "last_hidden_state" output which should be skipped
887+ if (model_rt_info.count (" hidden_output_name" )) {
888+ const auto & hidden_output_name = model_rt_info.at (" hidden_output_name" ).as <std::string>();
889+ const auto & result_output_names = matched_result->output (0 ).get_names ();
890+ for (const auto & name : result_output_names) {
891+ if (name == hidden_output_name) {
892+ return false ;
891893 }
892894 }
893-
894- // Update best candidate if this one is larger
895- if (current_matrix_size > max_matrix_size) {
896- max_matrix_size = current_matrix_size;
897- best_matmul = matmul;
898- best_result = result;
899- best_last_op = last_op;
900- }
901895 }
902- }
903896
904- if (best_matmul) {
905897 // Cut point:
906- auto matmul_first_source = best_matmul ->input (0 ).get_source_output ();
898+ auto matmul_first_source = matched_matmul ->input (0 ).get_source_output ();
907899
908900 // Cut original model:
909- best_result ->input (0 ).replace_source_output (matmul_first_source);
901+ matched_result ->input (0 ).replace_source_output (matmul_first_source);
910902 // FIXME: Somehow for KVCache model result output gets renamed in
911903 // ICompiledModel::ICompiledModel().
912904 // As a WA, setting the same name to output from MatMul
913905 // avoids the issue.
914906 matmul_first_source.set_names ({ov::npuw::LLMCompiledModel::output_embeds});
915- best_result ->output (0 ).set_names ({ov::npuw::LLMCompiledModel::output_embeds});
916- best_result ->validate_and_infer_types ();
907+ matched_result ->output (0 ).set_names ({ov::npuw::LLMCompiledModel::output_embeds});
908+ matched_result ->validate_and_infer_types ();
917909
918910 // Create an additional model after cut point:
919911 auto new_param = std::make_shared<ov::op::v0::Parameter>(matmul_first_source.get_element_type (),
920912 matmul_first_source.get_partial_shape ());
921913 new_param->output (0 ).add_names ({ov::npuw::LLMCompiledModel::output_embeds});
922- best_matmul ->input (0 ).replace_source_output (new_param);
923- auto new_result = std::make_shared<ov::op::v0::Result>(best_last_op );
924- m_lm_head_model =
914+ matched_matmul ->input (0 ).replace_source_output (new_param);
915+ auto new_result = std::make_shared<ov::op::v0::Result>(matched_node_last_op );
916+ lm_head_model =
925917 std::make_shared<ov::Model>(ov::OutputVector{new_result->output (0 )}, ov::ParameterVector{new_param});
926918
927- if (m_lm_head_model) {
928- m_lm_head_model->set_friendly_name (model->get_friendly_name () + " _lm_head" );
929- }
930- model->validate_nodes_and_infer_types ();
931-
932919 return true ;
933- }
934-
935- return false ;
936- }
937-
938- std::shared_ptr<ov::Model> get_lm_head_model () const {
939- return m_lm_head_model;
920+ };
921+ register_matcher (std::make_shared<opp::Matcher>(res, " CutLMHead" ), std::move (callback));
940922 }
941-
942- private:
943- std::shared_ptr<ov::Model> m_lm_head_model = nullptr ;
944923};
945924
946925namespace {
947-
948926std::shared_ptr<ov::Model> cut_lm_head (std::shared_ptr<ov::Model>& model) {
949- LOG_DEBUG (" Executing LM head cutting transformation using ModelPass" );
950- LOG_BLOCK ();
951-
952- auto cut_pass = std::make_shared<CutLMHead>();
953- cut_pass->run_on_model (model);
927+ ov::pass::GraphRewrite rewr;
928+ std::shared_ptr<ov::Model> lm_head_model = nullptr ;
929+ rewr.add_matcher <CutLMHead>(lm_head_model, model->get_rt_info ());
930+ rewr.run_on_model (model);
931+ if (lm_head_model) {
932+ lm_head_model->set_friendly_name (model->get_friendly_name () + " _lm_head" );
933+ }
934+ model->validate_nodes_and_infer_types ();
954935
955- return cut_pass-> get_lm_head_model () ;
936+ return lm_head_model ;
956937}
957938
958939void reshape_to_static (std::shared_ptr<ov::Model> model,
0 commit comments