@@ -467,7 +467,7 @@ struct KVAxesPosition {
467467class CutLMHead : public ov ::pass::MatcherPass {
468468public:
469469 OPENVINO_MATCHER_PASS_RTTI (" npuw::patterns::CutLMHead" );
470-
470+
471471 struct CutContext {
472472 std::shared_ptr<ov::Model> lm_head_model = nullptr ;
473473 std::shared_ptr<ov::Model> source_model = nullptr ;
@@ -478,71 +478,70 @@ class CutLMHead : public ov::pass::MatcherPass {
478478 // Create a dummy pattern that matches any Result node
479479 // We'll do the actual logic in the callback
480480 auto result_pattern = opp::wrap_type<ov::op::v0::Result>();
481-
481+
482482 auto callback = [this ](ov::pass::pattern::Matcher& m) -> bool {
483483 // Only execute the transformation once
484484 if (m_executed) {
485485 return false ;
486486 }
487487 m_executed = true ;
488-
488+
489489 // Use the source model provided in context
490490 if (!m_cut_context.get ().source_model ) {
491491 return false ;
492492 }
493-
493+
494494 return execute_cut_logic (m_cut_context.get ().source_model );
495495 };
496-
496+
497497 register_matcher (std::make_shared<opp::Matcher>(result_pattern, " CutLMHead" ), std::move (callback));
498498 }
499499
500500private:
501501 CutContext::Ref m_cut_context;
502502 bool m_executed = false ;
503-
503+
504504 bool execute_cut_logic (std::shared_ptr<ov::Model> model) {
505505 std::shared_ptr<ov::op::v0::MatMul> best_matmul = nullptr ;
506506 std::shared_ptr<ov::op::v0::Result> best_result = nullptr ;
507507 std::shared_ptr<ov::Node> best_last_op = nullptr ;
508508 size_t max_matrix_size = 0 ;
509-
509+
510510 for (const auto & result : model->get_results ()) {
511511 auto input_node = result->input (0 ).get_source_output ().get_node_shared_ptr ();
512-
512+
513513 std::shared_ptr<ov::op::v0::MatMul> matmul = nullptr ;
514514 std::shared_ptr<ov::Node> last_op = nullptr ;
515-
515+
516516 // Pattern 1: MatMul -> Result (direct connection)
517517 if (auto direct_matmul = ov::as_type_ptr<ov::op::v0::MatMul>(input_node)) {
518518 matmul = direct_matmul;
519519 last_op = direct_matmul;
520520 }
521521 // Pattern 2: MatMul -> (single intermediate op) -> Result
522- else if (ov::is_type<ov::op::v1::Add>(input_node) ||
523- ov::is_type<ov::op::v1::Transpose>(input_node) ||
522+ else if (ov::is_type<ov::op::v1::Add>(input_node) || ov::is_type<ov::op::v1::Transpose>(input_node) ||
524523 ov::is_type<ov::op::v0::Convert>(input_node)) {
525524 if (auto matmul_node = ov::as_type_ptr<ov::op::v0::MatMul>(
526- input_node->input (0 ).get_source_output ().get_node_shared_ptr ())) {
525+ input_node->input (0 ).get_source_output ().get_node_shared_ptr ())) {
527526 matmul = matmul_node;
528527 last_op = input_node;
529528 }
530529 }
531530 // Pattern 3: MatMul -> Divide -> Tanh -> Multiply -> Result
532531 else if (auto multiply_node = ov::as_type_ptr<ov::op::v1::Multiply>(input_node)) {
533532 if (auto tanh_node = ov::as_type_ptr<ov::op::v0::Tanh>(
534- multiply_node->input (0 ).get_source_output ().get_node_shared_ptr ())) {
533+ multiply_node->input (0 ).get_source_output ().get_node_shared_ptr ())) {
535534 if (auto divide_node = ov::as_type_ptr<ov::op::v1::Divide>(
536- tanh_node->input (0 ).get_source_output ().get_node_shared_ptr ())) {
535+ tanh_node->input (0 ).get_source_output ().get_node_shared_ptr ())) {
537536 if (auto matmul_node = ov::as_type_ptr<ov::op::v0::MatMul>(
538- divide_node->input (0 ).get_source_output ().get_node_shared_ptr ())) {
537+ divide_node->input (0 ).get_source_output ().get_node_shared_ptr ())) {
539538 matmul = matmul_node;
540539 last_op = multiply_node;
541540 }
542541 }
543542 }
544543 }
545-
544+
546545 if (matmul) {
547546 auto weight_shape = matmul->input (1 ).get_source_output ().get_partial_shape ();
548547 size_t current_matrix_size = 0 ;
@@ -555,7 +554,7 @@ class CutLMHead : public ov::pass::MatcherPass {
555554 current_matrix_size = std::max (size0, size1);
556555 }
557556 }
558-
557+
559558 // Update best candidate if this one is larger
560559 if (current_matrix_size > max_matrix_size) {
561560 max_matrix_size = current_matrix_size;
@@ -565,11 +564,11 @@ class CutLMHead : public ov::pass::MatcherPass {
565564 }
566565 }
567566 }
568-
567+
569568 if (best_matmul) {
570569 // Cut point:
571570 auto matmul_first_source = best_matmul->input (0 ).get_source_output ();
572-
571+
573572 // Cut original model:
574573 best_result->input (0 ).replace_source_output (matmul_first_source);
575574 // FIXME: Somehow for KVCache model result output gets renamed in
@@ -579,23 +578,24 @@ class CutLMHead : public ov::pass::MatcherPass {
579578 matmul_first_source.set_names ({ov::npuw::LLMCompiledModel::output_embeds});
580579 best_result->output (0 ).set_names ({ov::npuw::LLMCompiledModel::output_embeds});
581580 best_result->validate_and_infer_types ();
582-
581+
583582 // Create an additional model after cut point:
584583 auto new_param = std::make_shared<ov::op::v0::Parameter>(matmul_first_source.get_element_type (),
585584 matmul_first_source.get_partial_shape ());
586585 new_param->output (0 ).add_names ({ov::npuw::LLMCompiledModel::output_embeds});
587586 best_matmul->input (0 ).replace_source_output (new_param);
588587 auto new_result = std::make_shared<ov::op::v0::Result>(best_last_op);
589- m_cut_context.get ().lm_head_model = std::make_shared<ov::Model>(ov::OutputVector{new_result->output (0 )}, ov::ParameterVector{new_param});
590-
588+ m_cut_context.get ().lm_head_model =
589+ std::make_shared<ov::Model>(ov::OutputVector{new_result->output (0 )}, ov::ParameterVector{new_param});
590+
591591 if (m_cut_context.get ().lm_head_model ) {
592592 m_cut_context.get ().lm_head_model ->set_friendly_name (model->get_friendly_name () + " _lm_head" );
593593 }
594594 model->validate_nodes_and_infer_types ();
595-
595+
596596 return true ;
597597 }
598-
598+
599599 return false ;
600600 }
601601};
@@ -605,13 +605,13 @@ namespace {
// Runs the CutLMHead matcher pass over `model` and returns the model that the
// pass stored in its context.
//
// `model` is handed to the pass twice: as the context's `source_model` (which
// the pass callback forwards to its cut logic) and as the graph that
// GraphRewrite iterates over.
//
// Returns `cut_ctx.lm_head_model`. That field is default-initialized to nullptr
// in CutContext, so nullptr comes back when the pass did not assign it.
// NOTE(review): on a successful cut the pass rewires a Result input of `model`
// in place, so callers should expect `model` to be modified - confirm against
// the full CutLMHead implementation.
605605std::shared_ptr<ov::Model> cut_lm_head (std::shared_ptr<ov::Model>& model) {
606606 LOG_DEBUG (" Executing LM head cutting transformation using MatcherPass" );
607607 LOG_BLOCK ();
608-
608+
 // The context carries the input model into the pass and the result back out.
609609 CutLMHead::CutContext cut_ctx;
610610 cut_ctx.source_model = model;
611611 ov::pass::GraphRewrite rewr;
 // std::ref: the pass gets a reference to this local context, so the value it
 // writes to lm_head_model is visible here after run_on_model() returns.
612612 rewr.add_matcher <CutLMHead>(std::ref (cut_ctx));
613613 rewr.run_on_model (model);
614-
614+
615615 return cut_ctx.lm_head_model ;
616616}
617617
0 commit comments