Skip to content

Commit 9d44197

Browse files
committed
add support for QNN stateful models
1 parent 3ee4d48 commit 9d44197

File tree

4 files changed

+30
-4
lines changed

4 files changed

+30
-4
lines changed

src/models/kv_cache.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "kv_cache.h"
77
#include "windowed_kv_cache.h"
88
#include "../openvino/interface.h"
9+
#include "../qnn/interface.h"
910
#include <algorithm>
1011

1112
namespace Generators {
@@ -507,10 +508,10 @@ bool IsCacheNeeded(const Model& model) {
507508
} // namespace
508509

509510
std::unique_ptr<KeyValueCache> CreateKeyValueCache(State& state) {
510-
// For OpenVINO Stateful models, they do not contain exposed past/present KV tensors.
511+
// OpenVINO and QNN stateful models do not contain exposed past/present KV tensors.
511512
// In this case, 'IsCacheNeeded' below will return false. However, we still need to create a
512513
// special 'ModelManagedKeyValueCache' object, and so we check this condition first.
513-
if (IsOpenVINOStatefulModel(state.model_)) {
514+
if (IsOpenVINOStatefulModel(state.model_) || IsQNNStatefulModel(state.model_)) {
514515
if (g_log.enabled)
515516
Log("info", "CreateKeyValueCache: Creating ModelManagedKeyValueCache");
516517
return std::make_unique<ModelManagedKeyValueCache>(state);

src/models/logits.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "model.h"
77
#include "logits.h"
88
#include "../openvino/interface.h"
9+
#include "../qnn/interface.h"
910

1011
namespace Generators {
1112

@@ -17,8 +18,8 @@ Logits::Logits(State& state)
1718

1819
input_sequence_lengths.resize(state_.params_->search.batch_size);
1920

20-
if (IsOpenVINOStatefulModel(state.model_) || state.model_.IsPruned()) {
21-
// In the case of OpenVINO stateful models, or any model whose ONNX graph
21+
if (IsOpenVINOStatefulModel(state.model_) || IsQNNStatefulModel(state.model_) || state.model_.IsPruned()) {
22+
// In the case of OpenVINO and QNN stateful models, or any model whose ONNX graph
2223
// has been patched to only output last-token logits (logits dim[1]==1), they only return the
2324
// sliced logits needed for sampling. For example, given 43 prompt tokens, instead of returning
2425
// logits of the shape: [1,43,<vocab_size>]

src/qnn/interface.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "../generators.h"
55
#include "../search.h"
6+
#include "../models/model.h"
67
#include "interface.h"
78

89
namespace Generators {
@@ -78,4 +79,24 @@ DeviceInterface* GetQNNInterface() {
7879
return g_device.get();
7980
}
8081

82+
bool IsQNNStatefulModel(const Model& model) {
83+
if (model.p_device_->GetType() == DeviceType::QNN || model.p_device_->GetType() == DeviceType::CPU) {
84+
const auto& provider_options = model.config_->model.decoder.session_options.provider_options;
85+
for (auto& po : provider_options) {
86+
if (po.name == "QNN") {
87+
const auto& qnn_options = po.options;
88+
for (auto& option : qnn_options) {
89+
// For QNN, if session option 'genai_model' is set, the session will encapsulate
90+
// a stateful model, so KVCache will be managed internally.
91+
if (option.first == "genai_model" && option.second == "True") {
92+
return true;
93+
}
94+
}
95+
}
96+
}
97+
}
98+
99+
return false;
100+
}
101+
81102
} // namespace Generators

src/qnn/interface.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,7 @@ namespace Generators {
55

66
DeviceInterface* GetQNNInterface();
77

8+
struct Model;
9+
// Returns true when `model` uses a QNN or CPU device and a "QNN" provider-options entry in its
// decoder session options sets 'genai_model' to "True", in which case the session encapsulates
// a stateful model and the KV cache is managed internally. Returns false otherwise.
bool IsQNNStatefulModel(const Model& model);
10+
811
} // namespace Generators

0 commit comments

Comments
 (0)