-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_corClust.cpp
More file actions
143 lines (121 loc) · 4.53 KB
/
Copy pathtest_corClust.cpp
File metadata and controls
143 lines (121 loc) · 4.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#include "corClust.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <random>
/// @brief Draw one pseudo-random double from a uniform distribution.
///
/// A single Mersenne Twister engine is shared by every call; it is seeded
/// exactly once, from std::random_device, the first time the function runs.
///
/// @param min Lower bound handed to std::uniform_real_distribution.
/// @param max Upper bound handed to std::uniform_real_distribution.
/// @return A value drawn from the distribution built over (min, max).
double random_double(double min, double max) {
  // Function-local static: constructed on first use, reused afterwards.
  static std::mt19937 engine(std::random_device{}());
  std::uniform_real_distribution<double> range(min, max);
  return range(engine);
}
/// @brief Build a synthetic data set whose leading features are correlated.
///
/// Each sample draws three latent base values in [-10, 10]. The first ten
/// features are derived from them as `base * scale + noise`, with the noise
/// drawn from [-0.5, 0.5]:
///   - features 0-2 follow base 0 (scales 1.0, 1.5, 0.8)
///   - features 3-5 follow base 1 (scales 1.0, 2.0, 0.5)
///   - features 6-9 follow base 2 (scales 1.0, -1.0, 1.2, -0.7; the negative
///     scales give negatively correlated features)
/// Any feature with index >= 10 is an independent draw from [-10, 10].
///
/// Only the features that actually exist are written, so any n_features
/// (including 0, 1 and 2) is safe. Writing features 0-2 unconditionally would
/// index past the end of each row when n_features < 3.
///
/// @param n_samples  Number of rows to generate.
/// @param n_features Number of columns in every row.
/// @return An n_samples x n_features matrix stored as a vector of rows.
std::vector<std::vector<double>> generate_correlated_data(size_t n_samples,
                                                          size_t n_features) {
  // Which latent base drives each correlated feature, and by what factor.
  struct FeatureSpec {
    size_t group;
    double scale;
  };
  static constexpr FeatureSpec kCorrelated[] = {
      {0, 1.0}, {0, 1.5},  {0, 0.8},            // group 1: features 0-2
      {1, 1.0}, {1, 2.0},  {1, 0.5},            // group 2: features 3-5
      {2, 1.0}, {2, -1.0}, {2, 1.2}, {2, -0.7}, // group 3: features 6-9
  };
  constexpr size_t kCorrelatedCount =
      sizeof(kCorrelated) / sizeof(kCorrelated[0]);

  std::vector<std::vector<double>> data(n_samples,
                                        std::vector<double>(n_features));
  for (auto &row : data) {
    // One latent value per correlated group, drawn before any feature noise.
    const double bases[] = {random_double(-10, 10), random_double(-10, 10),
                            random_double(-10, 10)};

    for (size_t j = 0; j < n_features; ++j) {
      if (j < kCorrelatedCount) {
        const FeatureSpec &spec = kCorrelated[j];
        row[j] = bases[spec.group] * spec.scale + random_double(-0.5, 0.5);
      } else {
        // Features beyond the correlated block are mutually independent.
        row[j] = random_double(-10, 10);
      }
    }
  }
  return data;
}
/// @brief Write a clustering result to standard output.
///
/// Prints a heading line, then one line per cluster showing the cluster's
/// position in `clusters` followed by its feature indices inside brackets.
///
/// @param clusters One inner vector of feature indices per cluster.
void print_clusters(const std::vector<std::vector<size_t>> &clusters) {
  std::cout << "聚类结果:" << std::endl;

  size_t cluster_no = 0;
  for (const auto &members : clusters) {
    std::cout << " 聚类 " << cluster_no << ": [ ";
    for (size_t k = 0; k < members.size(); ++k) {
      std::cout << members[k] << " ";
    }
    std::cout << "]" << std::endl;
    ++cluster_no;
  }
}
int main() {
// 参数设置
const size_t n_features = 10;
const size_t n_samples = 1000;
const size_t max_cluster_size = 4;
std::cout << "测试 corClust 类" << std::endl;
std::cout << "--------------------------------" << std::endl;
std::cout << "特征数量: " << n_features << std::endl;
std::cout << "样本数量: " << n_samples << std::endl;
std::cout << "最大聚类大小: " << max_cluster_size << std::endl;
std::cout << "--------------------------------" << std::endl;
// 创建corClust实例
corClust cc(n_features);
// 生成相关数据
std::cout << "生成相关测试数据..." << std::endl;
auto data = generate_correlated_data(n_samples, n_features);
// 计时更新过程
auto start_time = std::chrono::high_resolution_clock::now();
// 更新相关矩阵
std::cout << "更新相关矩阵中..." << std::endl;
for (const auto &sample : data) {
cc.update(sample);
}
// 计算并打印相关距离矩阵
std::cout << "\n相关距离矩阵:" << std::endl;
auto D = cc.corrDist();
std::cout << " ";
for (size_t i = 0; i < n_features; ++i) {
std::cout << std::setw(8) << i << " ";
}
std::cout << std::endl;
for (size_t i = 0; i < n_features; ++i) {
std::cout << std::setw(4) << i << " ";
for (size_t j = 0; j < n_features; ++j) {
std::cout << std::fixed << std::setprecision(4) << std::setw(8)
<< D[i][j] << " ";
}
std::cout << std::endl;
}
// 执行聚类
std::cout << "\n执行特征聚类..." << std::endl;
auto clusters = cc.cluster(max_cluster_size);
auto end_time = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(
end_time - start_time)
.count();
// 打印聚类结果
print_clusters(clusters);
std::cout << "\n处理时间: " << duration << " 毫秒" << std::endl;
std::cout << "--------------------------------" << std::endl;
return 0;
}