-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathParallelKitNET.h
More file actions
331 lines (283 loc) · 8.31 KB
/
ParallelKitNET.h
File metadata and controls
331 lines (283 loc) · 8.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
#pragma once
#include "KitNET.h"
#include <atomic>
#include <condition_variable>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>
/**
* @brief 队列元素结构体,用于存储输入数据和结果
*/
struct QueueItem {
std::vector<double> x; // 输入特征向量
double result = 0.0; // 处理结果
std::promise<double> resultPromise; // 用于获取异步结果
bool isTraining = false; // 是否处于训练模式
};
/**
* @brief 线程安全的KitNET工作队列
*/
class KitNETWorkQueue {
private:
std::queue<std::shared_ptr<QueueItem>> items;
std::mutex mutex;
std::condition_variable cv;
bool shutdown = false;
public:
/**
* @brief 添加工作项到队列
* @param item 要处理的工作项
*/
void push(std::shared_ptr<QueueItem> item) {
std::unique_lock<std::mutex> lock(mutex);
items.push(item);
cv.notify_one();
}
/**
* @brief 从队列中获取工作项
* @return 工作项,如果队列关闭则返回nullptr
*/
std::shared_ptr<QueueItem> pop() {
std::unique_lock<std::mutex> lock(mutex);
while (items.empty() && !shutdown) {
cv.wait(lock);
}
if (shutdown && items.empty()) {
return nullptr;
}
auto item = items.front();
items.pop();
return item;
}
/**
* @brief 关闭工作队列
*/
void stop() {
std::unique_lock<std::mutex> lock(mutex);
shutdown = true;
cv.notify_all();
}
/**
* @brief 检查队列是否为空
* @return 如果队列为空返回true
*/
bool empty() const {
std::unique_lock<std::mutex> lock(const_cast<std::mutex &>(mutex));
return items.empty();
}
/**
* @brief 获取队列中的项目数
* @return 队列中的项目数
*/
size_t size() const {
std::unique_lock<std::mutex> lock(const_cast<std::mutex &>(mutex));
return items.size();
}
};
/**
* @brief 并行KitNET实现,使用线程池处理数据
*/
class ParallelKitNET {
private:
std::unique_ptr<KitNET> kitnet; // 原始KitNET实例
std::vector<std::thread> workers; // 工作线程池
KitNETWorkQueue workQueue; // 工作队列
std::atomic<bool> running{false}; // 线程池运行状态
std::atomic<size_t> max_queue_size{1000}; // 最大队列大小
std::mutex kitnet_mutex; // 保护KitNET实例的互斥锁
/**
* @brief 工作线程函数
*/
void workerFunction() {
while (running) {
auto item = workQueue.pop();
if (!item) {
break; // 队列已关闭
}
// 处理数据
std::unique_lock<std::mutex> lock(kitnet_mutex);
double result = 0.0;
if (item->isTraining) {
result = kitnet->train(item->x);
} else {
result = kitnet->process(item->x);
}
lock.unlock();
// 设置结果
item->result = result;
item->resultPromise.set_value(result);
}
}
public:
/**
* @brief 构造函数
* @param n 特征数量
* @param num_threads 使用的线程数量,默认为系统核心数
* @param max_autoencoder_size 自动编码器最大大小
* @param FM_grace_period 特征映射学习期
* @param AD_grace_period 异常检测学习期
* @param learning_rate 学习率
* @param hidden_ratio 隐藏比率
* @param feature_map 特征映射
*/
ParallelKitNET(size_t n, size_t num_threads = 0,
size_t max_autoencoder_size = 10, size_t FM_grace_period = 0,
size_t AD_grace_period = 10000, double learning_rate = 0.1,
double hidden_ratio = 0.75,
const std::vector<std::vector<size_t>> &feature_map = {}) {
// 创建KitNET实例
const std::vector<std::vector<size_t>> *feature_map_ptr = nullptr;
if (!feature_map.empty()) {
feature_map_ptr = &feature_map;
}
kitnet = std::unique_ptr<KitNET>(new KitNET(
n, max_autoencoder_size, FM_grace_period, AD_grace_period,
learning_rate, hidden_ratio, feature_map_ptr));
// 如果未指定线程数,使用系统核心数
if (num_threads == 0) {
num_threads = std::thread::hardware_concurrency();
if (num_threads == 0) {
num_threads = 4; // 默认至少使用4个线程
}
}
// 启动工作线程
startThreads(num_threads);
}
/**
* @brief 析构函数,确保所有线程正确关闭
*/
~ParallelKitNET() { stopThreads(); }
/**
* @brief 启动工作线程
* @param num_threads 要启动的线程数
*/
void startThreads(size_t num_threads) {
if (running) {
return;
}
running = true;
for (size_t i = 0; i < num_threads; ++i) {
workers.emplace_back(&ParallelKitNET::workerFunction, this);
}
}
/**
* @brief 停止所有工作线程
*/
void stopThreads() {
if (!running) {
return;
}
running = false;
workQueue.stop();
for (auto &worker : workers) {
if (worker.joinable()) {
worker.join();
}
}
workers.clear();
}
/**
* @brief 处理输入向量x(训练或执行)
* @param x 输入特征向量
* @param block 是否阻塞等待结果
* @return 如果阻塞,直接返回结果;否则返回future
*/
std::future<double> process(const std::vector<double> &x,
bool block = true) {
auto item = std::make_shared<QueueItem>();
item->x = x;
item->isTraining = false;
auto future = item->resultPromise.get_future();
// 添加到工作队列
workQueue.push(item);
// 如果队列太大,等待一些项目完成
while (workQueue.size() > max_queue_size) {
std::this_thread::sleep_for(std::chrono::nanoseconds(100));
}
// 如果阻塞,等待结果并返回
if (block) {
return future;
}
return future;
}
/**
* @brief 强制训练KitNET
* @param x 输入特征向量
* @param block 是否阻塞等待结果
* @return 如果阻塞,直接返回结果;否则返回future
*/
std::future<double> train(const std::vector<double> &x, bool block = true) {
auto item = std::make_shared<QueueItem>();
item->x = x;
item->isTraining = true;
auto future = item->resultPromise.get_future();
// 添加到工作队列
workQueue.push(item);
// 如果队列太大,等待一些项目完成
while (workQueue.size() > max_queue_size) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
// 如果阻塞,等待结果并返回
if (block) {
return future;
}
return future;
}
/**
* @brief 设置最大队列大小
* @param size 最大队列容量
*/
void setMaxQueueSize(size_t size) { max_queue_size = size; }
/**
* @brief 获取当前队列大小
* @return 当前队列中的项目数
*/
size_t getQueueSize() const { return workQueue.size(); }
/**
* @brief 获取工作线程数
* @return 当前工作线程数
*/
size_t getNumThreads() const { return workers.size(); }
/**
* @brief 处理并等待结果(便捷函数)
* @param x 输入特征向量
* @return 处理结果
*/
double processAndWait(const std::vector<double> &x);
/**
* @brief 训练并等待结果(便捷函数)
* @param x 输入特征向量
* @return 训练结果
*/
double trainAndWait(const std::vector<double> &x);
/**
* @brief 批量处理多个特征向量
* @param batch 特征向量批次
* @param block 是否阻塞等待所有结果
* @return 如果阻塞,返回所有结果;否则返回空向量
*/
std::vector<double>
batchProcess(const std::vector<std::vector<double>> &batch,
bool block = true);
/**
* @brief 批量训练多个特征向量
* @param batch 特征向量批次
* @param block 是否阻塞等待所有结果
* @return 如果阻塞,返回所有结果;否则返回空向量
*/
std::vector<double>
batchTrain(const std::vector<std::vector<double>> &batch,
bool block = true);
/**
* @brief 调整工作线程池大小
* @param num_threads 新的线程数
*/
void resizeThreadPool(size_t num_threads);
/**
* @brief 等待所有队列中的任务完成
*/
void waitForAllTasks();
};