File tree Expand file tree Collapse file tree 4 files changed +30
-4
lines changed
Expand file tree Collapse file tree 4 files changed +30
-4
lines changed Original file line number Diff line number Diff line change 66#include " kv_cache.h"
77#include " windowed_kv_cache.h"
88#include " ../openvino/interface.h"
9+ #include " ../qnn/interface.h"
910#include < algorithm>
1011
1112namespace Generators {
@@ -507,10 +508,10 @@ bool IsCacheNeeded(const Model& model) {
507508} // namespace
508509
509510std::unique_ptr<KeyValueCache> CreateKeyValueCache (State& state) {
510- // For OpenVINO Stateful models, they do not contain exposed past/present KV tensors.
511+ // For OpenVINO and QNN Stateful models, they do not contain exposed past/present KV tensors.
511512 // In this case, 'IsCacheNeeded' below will return false. But in this case we need to create a
512513 // special 'ModelManagedKeyValueCache' object, and so we check this condition first.
513- if (IsOpenVINOStatefulModel (state.model_ )) {
514+ if (IsOpenVINOStatefulModel (state.model_ ) || IsQNNStatefulModel (state. model_ ) ) {
514515 if (g_log.enabled )
515516 Log (" info" , " CreateKeyValueCache: Creating ModelManagedKeyValueCache" );
516517 return std::make_unique<ModelManagedKeyValueCache>(state);
Original file line number Diff line number Diff line change 66#include " model.h"
77#include " logits.h"
88#include " ../openvino/interface.h"
9+ #include " ../qnn/interface.h"
910
1011namespace Generators {
1112
@@ -17,8 +18,8 @@ Logits::Logits(State& state)
1718
1819 input_sequence_lengths.resize (state_.params_ ->search .batch_size );
1920
20- if (IsOpenVINOStatefulModel (state.model_ ) || state.model_ .IsPruned ()) {
21- // In the case of OpenVINO stateful models, or any model whose ONNX graph
21+ if (IsOpenVINOStatefulModel (state.model_ ) || IsQNNStatefulModel (state. model_ ) || state.model_ .IsPruned ()) {
22+ // In the case of OpenVINO and QNN stateful models, or any model whose ONNX graph
2223 // has been patched to only output last-token logits (logits dim[1]==1), they only return the
2324 // sliced logits needed for sampling. For example, given 43 prompt tokens, instead of returning
2425 // logits of the shape: [1,43,<vocab_size>]
Original file line number Diff line number Diff line change 33
44#include " ../generators.h"
55#include " ../search.h"
6+ #include " ../models/model.h"
67#include " interface.h"
78
89namespace Generators {
@@ -78,4 +79,24 @@ DeviceInterface* GetQNNInterface() {
7879 return g_device.get ();
7980}
8081
82+ bool IsQNNStatefulModel (const Model& model) {
83+ if (model.p_device_ ->GetType () == DeviceType::QNN || model.p_device_ ->GetType () == DeviceType::CPU) {
84+ const auto & provider_options = model.config_ ->model .decoder .session_options .provider_options ;
85+ for (auto & po : provider_options) {
86+ if (po.name == " QNN" ) {
87+ const auto & qnn_options = po.options ;
88+ for (auto & option : qnn_options) {
89+ // For QNN, if session option 'genai_model' is set, the session will encapsulate
90+ // a stateful model, so KVCache will be managed internally.
91+ if (option.first == " genai_model" && option.second == " True" ) {
92+ return true ;
93+ }
94+ }
95+ }
96+ }
97+ }
98+
99+ return false ;
100+ }
101+
81102} // namespace Generators
Original file line number Diff line number Diff line change @@ -5,4 +5,7 @@ namespace Generators {
55
66DeviceInterface* GetQNNInterface ();
77
8+ struct Model ;
9+ bool IsQNNStatefulModel (const Model& model);
10+
811} // namespace Generators
You can’t perform that action at this time.
0 commit comments