Skip to content

Commit 5efef96

Browse files
formate
1 parent 006639b commit 5efef96

File tree

1 file changed

+26
-26
lines changed

1 file changed

+26
-26
lines changed

src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ struct KVAxesPosition {
467467
class CutLMHead : public ov::pass::MatcherPass {
468468
public:
469469
OPENVINO_MATCHER_PASS_RTTI("npuw::patterns::CutLMHead");
470-
470+
471471
struct CutContext {
472472
std::shared_ptr<ov::Model> lm_head_model = nullptr;
473473
std::shared_ptr<ov::Model> source_model = nullptr;
@@ -478,71 +478,70 @@ class CutLMHead : public ov::pass::MatcherPass {
478478
// Create a dummy pattern that matches any Result node
479479
// We'll do the actual logic in the callback
480480
auto result_pattern = opp::wrap_type<ov::op::v0::Result>();
481-
481+
482482
auto callback = [this](ov::pass::pattern::Matcher& m) -> bool {
483483
// Only execute the transformation once
484484
if (m_executed) {
485485
return false;
486486
}
487487
m_executed = true;
488-
488+
489489
// Use the source model provided in context
490490
if (!m_cut_context.get().source_model) {
491491
return false;
492492
}
493-
493+
494494
return execute_cut_logic(m_cut_context.get().source_model);
495495
};
496-
496+
497497
register_matcher(std::make_shared<opp::Matcher>(result_pattern, "CutLMHead"), std::move(callback));
498498
}
499499

500500
private:
501501
CutContext::Ref m_cut_context;
502502
bool m_executed = false;
503-
503+
504504
bool execute_cut_logic(std::shared_ptr<ov::Model> model) {
505505
std::shared_ptr<ov::op::v0::MatMul> best_matmul = nullptr;
506506
std::shared_ptr<ov::op::v0::Result> best_result = nullptr;
507507
std::shared_ptr<ov::Node> best_last_op = nullptr;
508508
size_t max_matrix_size = 0;
509-
509+
510510
for (const auto& result : model->get_results()) {
511511
auto input_node = result->input(0).get_source_output().get_node_shared_ptr();
512-
512+
513513
std::shared_ptr<ov::op::v0::MatMul> matmul = nullptr;
514514
std::shared_ptr<ov::Node> last_op = nullptr;
515-
515+
516516
// Pattern 1: MatMul -> Result (direct connection)
517517
if (auto direct_matmul = ov::as_type_ptr<ov::op::v0::MatMul>(input_node)) {
518518
matmul = direct_matmul;
519519
last_op = direct_matmul;
520520
}
521521
// Patterns 2: MatMul -> (single intermediate op) -> Result
522-
else if (ov::is_type<ov::op::v1::Add>(input_node) ||
523-
ov::is_type<ov::op::v1::Transpose>(input_node) ||
522+
else if (ov::is_type<ov::op::v1::Add>(input_node) || ov::is_type<ov::op::v1::Transpose>(input_node) ||
524523
ov::is_type<ov::op::v0::Convert>(input_node)) {
525524
if (auto matmul_node = ov::as_type_ptr<ov::op::v0::MatMul>(
526-
input_node->input(0).get_source_output().get_node_shared_ptr())) {
525+
input_node->input(0).get_source_output().get_node_shared_ptr())) {
527526
matmul = matmul_node;
528527
last_op = input_node;
529528
}
530529
}
531530
// Pattern 3: MatMul -> Divide -> Tanh -> Multiply -> Result
532531
else if (auto multiply_node = ov::as_type_ptr<ov::op::v1::Multiply>(input_node)) {
533532
if (auto tanh_node = ov::as_type_ptr<ov::op::v0::Tanh>(
534-
multiply_node->input(0).get_source_output().get_node_shared_ptr())) {
533+
multiply_node->input(0).get_source_output().get_node_shared_ptr())) {
535534
if (auto divide_node = ov::as_type_ptr<ov::op::v1::Divide>(
536-
tanh_node->input(0).get_source_output().get_node_shared_ptr())) {
535+
tanh_node->input(0).get_source_output().get_node_shared_ptr())) {
537536
if (auto matmul_node = ov::as_type_ptr<ov::op::v0::MatMul>(
538-
divide_node->input(0).get_source_output().get_node_shared_ptr())) {
537+
divide_node->input(0).get_source_output().get_node_shared_ptr())) {
539538
matmul = matmul_node;
540539
last_op = multiply_node;
541540
}
542541
}
543542
}
544543
}
545-
544+
546545
if (matmul) {
547546
auto weight_shape = matmul->input(1).get_source_output().get_partial_shape();
548547
size_t current_matrix_size = 0;
@@ -555,7 +554,7 @@ class CutLMHead : public ov::pass::MatcherPass {
555554
current_matrix_size = std::max(size0, size1);
556555
}
557556
}
558-
557+
559558
// Update best candidate if this one is larger
560559
if (current_matrix_size > max_matrix_size) {
561560
max_matrix_size = current_matrix_size;
@@ -565,11 +564,11 @@ class CutLMHead : public ov::pass::MatcherPass {
565564
}
566565
}
567566
}
568-
567+
569568
if (best_matmul) {
570569
// Cut point:
571570
auto matmul_first_source = best_matmul->input(0).get_source_output();
572-
571+
573572
// Cut original model:
574573
best_result->input(0).replace_source_output(matmul_first_source);
575574
// FIXME: Somehow for KVCache model result output gets renamed in
@@ -579,23 +578,24 @@ class CutLMHead : public ov::pass::MatcherPass {
579578
matmul_first_source.set_names({ov::npuw::LLMCompiledModel::output_embeds});
580579
best_result->output(0).set_names({ov::npuw::LLMCompiledModel::output_embeds});
581580
best_result->validate_and_infer_types();
582-
581+
583582
// Create an additional model after cut point:
584583
auto new_param = std::make_shared<ov::op::v0::Parameter>(matmul_first_source.get_element_type(),
585584
matmul_first_source.get_partial_shape());
586585
new_param->output(0).add_names({ov::npuw::LLMCompiledModel::output_embeds});
587586
best_matmul->input(0).replace_source_output(new_param);
588587
auto new_result = std::make_shared<ov::op::v0::Result>(best_last_op);
589-
m_cut_context.get().lm_head_model = std::make_shared<ov::Model>(ov::OutputVector{new_result->output(0)}, ov::ParameterVector{new_param});
590-
588+
m_cut_context.get().lm_head_model =
589+
std::make_shared<ov::Model>(ov::OutputVector{new_result->output(0)}, ov::ParameterVector{new_param});
590+
591591
if (m_cut_context.get().lm_head_model) {
592592
m_cut_context.get().lm_head_model->set_friendly_name(model->get_friendly_name() + "_lm_head");
593593
}
594594
model->validate_nodes_and_infer_types();
595-
595+
596596
return true;
597597
}
598-
598+
599599
return false;
600600
}
601601
};
@@ -605,13 +605,13 @@ namespace {
605605
std::shared_ptr<ov::Model> cut_lm_head(std::shared_ptr<ov::Model>& model) {
606606
LOG_DEBUG("Executing LM head cutting transformation using MatcherPass");
607607
LOG_BLOCK();
608-
608+
609609
CutLMHead::CutContext cut_ctx;
610610
cut_ctx.source_model = model;
611611
ov::pass::GraphRewrite rewr;
612612
rewr.add_matcher<CutLMHead>(std::ref(cut_ctx));
613613
rewr.run_on_model(model);
614-
614+
615615
return cut_ctx.lm_head_model;
616616
}
617617

0 commit comments

Comments
 (0)