@@ -1094,7 +1094,8 @@ void reshape_to_static(std::shared_ptr<ov::Model> model,
10941094 const uint32_t kvcache_size,
10951095 const KVAxesPosition& kv_axes_position,
10961096 const uint32_t lora_rank,
1097- const uint32_t lhs_seq_size = 0 ) {
1097+ const uint32_t lhs_seq_size = 0 ,
1098+ const bool is_prefill = false ) {
10981099 std::map<std::string, ov::PartialShape> new_shapes;
10991100 for (const auto & input : model->inputs ()) {
11001101 const auto & input_name = input.get_any_name ();
@@ -1129,8 +1130,9 @@ void reshape_to_static(std::shared_ptr<ov::Model> model,
11291130 const auto & partial_shape = input.get_partial_shape ();
11301131 new_shape = partial_shape;
11311132 new_shape[0 ] = 1 ; // batch_dim
1132- } else if (ov::npuw::matchEagle3HiddenStatesString (input_name)) {
1133- new_shape = ov::npuw::Eagle3Extension::get_static_input (model, input, input_size);
1133+ } else if (ov::npuw::matchEagle3HiddenStatesString (input_name) ||
1134+ ov::npuw::matchEagle3TreeMaskString (input_name)) {
1135+ new_shape = ov::npuw::Eagle3Extension::get_static_input (model, input, input_size, kvcache_size, is_prefill);
11341136 } else if (ov::npuw::util::matchLoRAMatMulAString (input_name)) {
11351137 new_shape = ov::PartialShape ({lora_rank, input.get_partial_shape ()[1 ]});
11361138 } else if (ov::npuw::util::matchLoRAMatMulAlphaString (input_name)) {
@@ -2018,18 +2020,21 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
20182020 LOG_DEBUG (" Make prefill model with static shapes" );
20192021 m_max_lora_rank = m_cfg.get <::intel_npu::NPUW_LLM_MAX_LORA_RANK>();
20202022 if (m_use_chunk_prefill) {
2021- ReshapeToStatic (static_cast <uint32_t >(m_prefill_chunk_size),
2022- m_kvcache_desc.max_prompt_size ,
2023- axes,
2024- m_max_lora_rank)
2025- .run_on_model (prefill_model);
2023+ reshape_to_static (prefill_model,
2024+ static_cast <uint32_t >(m_prefill_chunk_size),
2025+ m_kvcache_desc.max_prompt_size ,
2026+ axes,
2027+ m_max_lora_rank,
2028+ 0 ,
2029+ true );
20262030 } else {
2027- ReshapeToStatic (m_kvcache_desc.max_prompt_size ,
2028- m_kvcache_desc.max_prompt_size ,
2029- axes,
2030- m_max_lora_rank,
2031- whisper_lhs_seq_size)
2032- .run_on_model (prefill_model);
2031+ reshape_to_static (prefill_model,
2032+ m_kvcache_desc.max_prompt_size ,
2033+ m_kvcache_desc.max_prompt_size ,
2034+ axes,
2035+ m_max_lora_rank,
2036+ whisper_lhs_seq_size,
2037+ true );
20332038 }
20342039 LOG_DEBUG (" Make kvcache model with static shapes" );
20352040
0 commit comments