diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp index 41b126a51..5e0a1be7b 100644 --- a/src/models/kv_cache.cpp +++ b/src/models/kv_cache.cpp @@ -6,6 +6,7 @@ #include "kv_cache.h" #include "windowed_kv_cache.h" #include "../openvino/interface.h" +#include "../qnn/interface.h" #include namespace Generators { @@ -562,10 +563,10 @@ bool IsCacheNeeded(const Model& model) { } // namespace std::unique_ptr CreateKeyValueCache(State& state) { - // For OpenVINO Stateful models, they do not contain exposed past/present KV tensors. + // For OpenVINO and QNN Stateful models, they do not contain exposed past/present KV tensors. // In this case, 'IsCacheNeeded' below will return false. But in this case we need to create a // special 'ModelManagedKeyValueCache' object, and so we check this condition first. - if (IsOpenVINOStatefulModel(state.model_)) { + if (IsOpenVINOStatefulModel(state.model_) || IsQNNStatefulModel(state.model_)) { if (g_log.enabled) Log("info", "CreateKeyValueCache: Creating ModelManagedKeyValueCache"); return std::make_unique(state); diff --git a/src/qnn/interface.cpp b/src/qnn/interface.cpp index 0fc746437..f05434b52 100644 --- a/src/qnn/interface.cpp +++ b/src/qnn/interface.cpp @@ -3,6 +3,7 @@ #include "../generators.h" #include "../search.h" +#include "../models/model.h" #include "interface.h" namespace Generators { @@ -78,4 +79,31 @@ DeviceInterface* GetQNNInterface() { return g_device.get(); } +bool IsQNNStatefulModel(const Model& model) { + // Check for both QNN and CPU device types + // When using QNN EP with genai_model=True, the model is stateful regardless of device type (QNN/CPU) + // For QNN models with enable_htp_shared_memory_allocator=1, p_device_ will be QNN type + // For QNN models without shared memory allocator, p_device_ will be CPU type + // Both cases need to be handled the same way for stateful models where KV cache is managed internally + if (model.p_device_->GetType() == DeviceType::QNN || model.p_device_->GetType() == DeviceType::CPU) { + const auto& provider_options = model.config_->model.decoder.session_options.provider_options; + for (const auto& po : provider_options) { + if (po.name == "QNN") { + for (const auto& option : po.options) { + // For QNN, if session option 'genie_model' is set to true, the session will encapsulate + // a stateful model, so KVCache will be managed internally. + if (option.first == "genie_model") { + std::string lower_value(option.second); + std::transform(lower_value.begin(), lower_value.end(), lower_value.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + return lower_value == "true"; + } + } + } + } + } + + return false; +} + } // namespace Generators diff --git a/src/qnn/interface.h b/src/qnn/interface.h index fcbfe1f64..724a6ede8 100644 --- a/src/qnn/interface.h +++ b/src/qnn/interface.h @@ -5,4 +5,7 @@ namespace Generators { DeviceInterface* GetQNNInterface(); +struct Model; +bool IsQNNStatefulModel(const Model& model); + } // namespace Generators \ No newline at end of file