Skip to content

Commit 4bb75e2

Browse files
Add eagle_tree_mask as input
1 parent f8117e8 commit 4bb75e2

File tree

4 files changed

+499
-84
lines changed

4 files changed

+499
-84
lines changed

src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp

Lines changed: 19 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1094,7 +1094,8 @@ void reshape_to_static(std::shared_ptr<ov::Model> model,
10941094
const uint32_t kvcache_size,
10951095
const KVAxesPosition& kv_axes_position,
10961096
const uint32_t lora_rank,
1097-
const uint32_t lhs_seq_size = 0) {
1097+
const uint32_t lhs_seq_size = 0,
1098+
const bool is_prefill = false) {
10981099
std::map<std::string, ov::PartialShape> new_shapes;
10991100
for (const auto& input : model->inputs()) {
11001101
const auto& input_name = input.get_any_name();
@@ -1129,8 +1130,9 @@ void reshape_to_static(std::shared_ptr<ov::Model> model,
11291130
const auto& partial_shape = input.get_partial_shape();
11301131
new_shape = partial_shape;
11311132
new_shape[0] = 1; // batch_dim
1132-
} else if (ov::npuw::matchEagle3HiddenStatesString(input_name)) {
1133-
new_shape = ov::npuw::Eagle3Extension::get_static_input(model, input, input_size);
1133+
} else if (ov::npuw::matchEagle3HiddenStatesString(input_name) ||
1134+
ov::npuw::matchEagle3TreeMaskString(input_name)) {
1135+
new_shape = ov::npuw::Eagle3Extension::get_static_input(model, input, input_size, kvcache_size, is_prefill);
11341136
} else if (ov::npuw::util::matchLoRAMatMulAString(input_name)) {
11351137
new_shape = ov::PartialShape({lora_rank, input.get_partial_shape()[1]});
11361138
} else if (ov::npuw::util::matchLoRAMatMulAlphaString(input_name)) {
@@ -2018,18 +2020,21 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
20182020
LOG_DEBUG("Make prefill model with static shapes");
20192021
m_max_lora_rank = m_cfg.get<::intel_npu::NPUW_LLM_MAX_LORA_RANK>();
20202022
if (m_use_chunk_prefill) {
2021-
ReshapeToStatic(static_cast<uint32_t>(m_prefill_chunk_size),
2022-
m_kvcache_desc.max_prompt_size,
2023-
axes,
2024-
m_max_lora_rank)
2025-
.run_on_model(prefill_model);
2023+
reshape_to_static(prefill_model,
2024+
static_cast<uint32_t>(m_prefill_chunk_size),
2025+
m_kvcache_desc.max_prompt_size,
2026+
axes,
2027+
m_max_lora_rank,
2028+
0,
2029+
true);
20262030
} else {
2027-
ReshapeToStatic(m_kvcache_desc.max_prompt_size,
2028-
m_kvcache_desc.max_prompt_size,
2029-
axes,
2030-
m_max_lora_rank,
2031-
whisper_lhs_seq_size)
2032-
.run_on_model(prefill_model);
2031+
reshape_to_static(prefill_model,
2032+
m_kvcache_desc.max_prompt_size,
2033+
m_kvcache_desc.max_prompt_size,
2034+
axes,
2035+
m_max_lora_rank,
2036+
whisper_lhs_seq_size,
2037+
true);
20332038
}
20342039
LOG_DEBUG("Make kvcache model with static shapes");
20352040

0 commit comments

Comments
 (0)