@@ -833,94 +833,126 @@ void patch_phi3_sliding_mask(const std::shared_ptr<ov::Model>& model) {
833833}
834834} // namespace
835835
836- class CutLMHead : public ov ::pass::MatcherPass {
836+ class CutLMHead : public ov ::pass::ModelPass {
837837public:
838- OPENVINO_MATCHER_PASS_RTTI (" npuw::patterns::CutLMHead" );
839- CutLMHead (std::shared_ptr<ov::Model>& lm_head_model) {
840- // We are interested at first input to MatMul as a cut point
841- auto matmul = opp::wrap_type<ov::op::v0::MatMul>({opp::any_input (), opp::any_input ()});
842-
843- // There are several patterns for matmul we are looking for:
844- // Matmul -> Result
845- // Matmul -> Add -> Result
846- auto matmul_add = opp::wrap_type<ov::op::v1::Add>({matmul, opp::any_input ()});
847- // Matmul -> Transpose -> Result
848- auto matmul_transpose = opp::wrap_type<ov::op::v1::Transpose>({matmul, opp::any_input ()});
849- // Matmul -> Convert -> Result
850- auto matmul_convert = opp::wrap_type<ov::op::v0::Convert>({matmul});
851- // MatMul -> Divide -> Tanh -> Multiply -> Result
852- auto div = opp::wrap_type<ov::op::v1::Multiply, ov::op::v1::Divide>({matmul, opp::any_input ()});
853- auto tanh = opp::wrap_type<ov::op::v0::Tanh>({div});
854- auto matmul_multiply = opp::wrap_type<ov::op::v1::Multiply>({tanh, opp::any_input ()});
855-
856- auto last_op = std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{matmul->output (0 ),
857- matmul_add->output (0 ),
858- matmul_transpose->output (0 ),
859- matmul_convert->output (0 ),
860- matmul_multiply->output (0 )});
861- auto res = opp::wrap_type<ov::op::v0::Result>({last_op->output (0 )});
862-
863- auto callback = [=, &lm_head_model](ov::pass::pattern::Matcher& m) {
864- auto & node_to_output = m.get_pattern_value_map ();
865-
866- auto matched_node_matmul = node_to_output.at (matmul).get_node_shared_ptr ();
867- std::shared_ptr<ov::Node> matched_node_last_op = nullptr ;
868- if (node_to_output.count (matmul_add)) {
869- matched_node_last_op = node_to_output[matmul_add].get_node_shared_ptr ();
870- } else if (node_to_output.count (matmul_transpose)) {
871- matched_node_last_op = node_to_output[matmul_transpose].get_node_shared_ptr ();
872- } else if (node_to_output.count (matmul_convert)) {
873- matched_node_last_op = node_to_output[matmul_convert].get_node_shared_ptr ();
874- } else if (node_to_output.count (matmul_multiply)) {
875- matched_node_last_op = node_to_output[matmul_multiply].get_node_shared_ptr ();
876- } else {
877- matched_node_last_op = matched_node_matmul;
838+ OPENVINO_MODEL_PASS_RTTI (" npuw::patterns::CutLMHead" );
839+
840+ CutLMHead () = default ;
841+
842+ bool run_on_model (const std::shared_ptr<ov::Model>& model) override {
843+ std::shared_ptr<ov::op::v0::MatMul> best_matmul = nullptr ;
844+ std::shared_ptr<ov::op::v0::Result> best_result = nullptr ;
845+ std::shared_ptr<ov::Node> best_last_op = nullptr ;
846+ size_t max_matrix_size = 0 ;
847+
848+ for (const auto & result : model->get_results ()) {
849+ auto input_node = result->input (0 ).get_source_output ().get_node_shared_ptr ();
850+
851+ std::shared_ptr<ov::op::v0::MatMul> matmul = nullptr ;
852+ std::shared_ptr<ov::Node> last_op = nullptr ;
853+
854+ if (auto direct_matmul = ov::as_type_ptr<ov::op::v0::MatMul>(input_node)) {
855+ // Pattern 1: MatMul -> Result (direct connection)
856+ matmul = direct_matmul;
857+ last_op = direct_matmul;
858+ } else if (ov::is_type<ov::op::v1::Add>(input_node) || ov::is_type<ov::op::v1::Transpose>(input_node) ||
859+ ov::is_type<ov::op::v0::Convert>(input_node)) {
860+ // Patterns 2: MatMul -> (single intermediate op) -> Result
861+ if (auto matmul_node = ov::as_type_ptr<ov::op::v0::MatMul>(
862+ input_node->input (0 ).get_source_output ().get_node_shared_ptr ())) {
863+ matmul = matmul_node;
864+ last_op = input_node;
865+ }
866+ } else if (auto multiply_node = ov::as_type_ptr<ov::op::v1::Multiply>(input_node)) {
867+ // Pattern 3: MatMul -> Divide -> Tanh -> Multiply -> Result
868+ if (auto tanh_node = ov::as_type_ptr<ov::op::v0::Tanh>(
869+ multiply_node->input (0 ).get_source_output ().get_node_shared_ptr ())) {
870+ if (auto divide_node = ov::as_type_ptr<ov::op::v1::Divide>(
871+ tanh_node->input (0 ).get_source_output ().get_node_shared_ptr ())) {
872+ if (auto matmul_node = ov::as_type_ptr<ov::op::v0::MatMul>(
873+ divide_node->input (0 ).get_source_output ().get_node_shared_ptr ())) {
874+ matmul = matmul_node;
875+ last_op = multiply_node;
876+ }
877+ }
878+ }
878879 }
879- auto matched_node_result = node_to_output.at (res).get_node_shared_ptr ();
880880
881- auto matched_matmul = std::static_pointer_cast<ov::op::v0::MatMul>(matched_node_matmul);
882- auto matched_result = std::static_pointer_cast<ov::op::v0::Result>(matched_node_result);
881+ if (matmul) {
882+ auto weight_shape = matmul->input (1 ).get_source_output ().get_partial_shape ();
883+ size_t current_matrix_size = 0 ;
884+ if (weight_shape.rank ().is_static () && weight_shape.rank ().get_length () == 2 ) {
885+ auto dim0 = weight_shape[0 ];
886+ auto dim1 = weight_shape[1 ];
887+ if (dim0.is_static () && dim1.is_static ()) {
888+ auto size0 = static_cast <size_t >(dim0.get_length ());
889+ auto size1 = static_cast <size_t >(dim1.get_length ());
890+ current_matrix_size = std::max (size0, size1);
891+ }
892+ }
883893
894+ // Update best candidate if this one is larger
895+ if (current_matrix_size > max_matrix_size) {
896+ max_matrix_size = current_matrix_size;
897+ best_matmul = matmul;
898+ best_result = result;
899+ best_last_op = last_op;
900+ }
901+ }
902+ }
903+
904+ if (best_matmul) {
884905 // Cut point:
885- auto matmul_first_source = matched_matmul ->input (0 ).get_source_output ();
906+ auto matmul_first_source = best_matmul ->input (0 ).get_source_output ();
886907
887908 // Cut original model:
888- matched_result ->input (0 ).replace_source_output (matmul_first_source);
909+ best_result ->input (0 ).replace_source_output (matmul_first_source);
889910 // FIXME: Somehow for KVCache model result output gets renamed in
890911 // ICompiledModel::ICompiledModel().
891912 // As a WA, setting the same name to output from MatMul
892913 // avoids the issue.
893914 matmul_first_source.set_names ({ov::npuw::LLMCompiledModel::output_embeds});
894- matched_result ->output (0 ).set_names ({ov::npuw::LLMCompiledModel::output_embeds});
895- matched_result ->validate_and_infer_types ();
915+ best_result ->output (0 ).set_names ({ov::npuw::LLMCompiledModel::output_embeds});
916+ best_result ->validate_and_infer_types ();
896917
897918 // Create an additional model after cut point:
898919 auto new_param = std::make_shared<ov::op::v0::Parameter>(matmul_first_source.get_element_type (),
899920 matmul_first_source.get_partial_shape ());
900921 new_param->output (0 ).add_names ({ov::npuw::LLMCompiledModel::output_embeds});
901- matched_matmul ->input (0 ).replace_source_output (new_param);
902- auto new_result = std::make_shared<ov::op::v0::Result>(matched_node_last_op );
903- lm_head_model =
922+ best_matmul ->input (0 ).replace_source_output (new_param);
923+ auto new_result = std::make_shared<ov::op::v0::Result>(best_last_op );
924+ m_lm_head_model =
904925 std::make_shared<ov::Model>(ov::OutputVector{new_result->output (0 )}, ov::ParameterVector{new_param});
905926
927+ if (m_lm_head_model) {
928+ m_lm_head_model->set_friendly_name (model->get_friendly_name () + " _lm_head" );
929+ }
930+ model->validate_nodes_and_infer_types ();
931+
906932 return true ;
907- };
908- register_matcher (std::make_shared<opp::Matcher>(res, " CutLMHead" ), std::move (callback));
933+ }
934+
935+ return false ;
936+ }
937+
938+ std::shared_ptr<ov::Model> get_lm_head_model () const {
939+ return m_lm_head_model;
909940 }
941+
942+ private:
943+ std::shared_ptr<ov::Model> m_lm_head_model = nullptr ;
910944};
911945
912946namespace {
947+
913948std::shared_ptr<ov::Model> cut_lm_head (std::shared_ptr<ov::Model>& model) {
914- ov::pass::GraphRewrite rewr;
915- std::shared_ptr<ov::Model> lm_head_model = nullptr ;
916- rewr.add_matcher <CutLMHead>(lm_head_model);
917- rewr.run_on_model (model);
918- if (lm_head_model) {
919- lm_head_model->set_friendly_name (model->get_friendly_name () + " _lm_head" );
920- }
921- model->validate_nodes_and_infer_types ();
949+ LOG_DEBUG (" Executing LM head cutting transformation using ModelPass" );
950+ LOG_BLOCK ();
951+
952+ auto cut_pass = std::make_shared<CutLMHead>();
953+ cut_pass->run_on_model (model);
922954
923- return lm_head_model ;
955+ return cut_pass-> get_lm_head_model () ;
924956}
925957
926958void reshape_to_static (std::shared_ptr<ov::Model> model,
0 commit comments