-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathFeatureExtractor.cpp
More file actions
268 lines (234 loc) · 7.78 KB
/
FeatureExtractor.cpp
File metadata and controls
268 lines (234 loc) · 7.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
#define _CRT_SECURE_NO_WARNINGS
#include "FeatureExtractor.h"
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <system_error>
#include <vector>
#ifdef _WIN32
#include <windows.h>
#define PATH_SEPARATOR "\\"
#else
#include <unistd.h>
#define PATH_SEPARATOR "/"
#endif
// CSV/TSV parsing helper.
//
// Breaks `str` into the fields separated by `delimiter` and returns them in
// order. Empty fields between two delimiters are kept, but no trailing empty
// field is produced: an empty input yields an empty vector, and an input that
// ends with the delimiter yields nothing after that final delimiter.
std::vector<std::string> split(const std::string &str, char delimiter) {
    std::vector<std::string> fields;
    std::string::size_type field_start = 0;
    while (field_start < str.size()) {
        const std::string::size_type delim_pos = str.find(delimiter, field_start);
        if (delim_pos == std::string::npos) {
            // Last field: everything that is left.
            fields.push_back(str.substr(field_start));
            break;
        }
        fields.push_back(str.substr(field_start, delim_pos - field_start));
        field_start = delim_pos + 1;
    }
    return fields;
}
// Constructor.
//
// Stores the input path and packet limit, prepares the input through _prep()
// (which may throw), and then creates the netStat feature extractor
// (AfterImage).
FeatureExtractor::FeatureExtractor(const std::string &file_path, double limit)
    : path(file_path), limit(limit), curPacketIndx(0), tsvinf(nullptr) {
    // Prepare the capture file for reading.
    _prep();

    // Build the feature extractor (AfterImage). Both size arguments are 1e11;
    // presumably caps on tracked hosts and sessions -- confirm against netStat.
    const double host_cap = 1e11;
    const double session_cap = 1e11;
    nstat.reset(new netStat(NAN, host_cap, session_cap));
}
// Destructor: closes the TSV stream if it is still open.
FeatureExtractor::~FeatureExtractor() {
    if (tsvinf == nullptr) {
        return; // Nothing to release.
    }
    fclose(tsvinf);
}
// Locates the tshark executable.
//
// On Windows this returns the default Wireshark install location without
// checking that it exists. Elsewhere, every directory listed in the PATH
// environment variable is probed in order and the first "<dir>/tshark" that is
// an executable regular file is returned.
//
// Returns an empty string when PATH is unset or no candidate qualifies.
std::string FeatureExtractor::_get_tshark_path() {
#ifdef _WIN32
    return "C:\\Program Files\\Wireshark\\tshark.exe";
#else
    const char *env_path = getenv("PATH");
    if (!env_path) {
        return "";
    }
    const std::string search_path(env_path);

    // Walk the ':'-separated entries by offset; the last entry has no
    // trailing ':' and is handled by the same loop.
    std::string::size_type entry_start = 0;
    while (entry_start <= search_path.size()) {
        std::string::size_type entry_end = search_path.find(':', entry_start);
        if (entry_end == std::string::npos) {
            entry_end = search_path.size();
        }
        const std::string directory =
            search_path.substr(entry_start, entry_end - entry_start);
        entry_start = entry_end + 1;

        // Skip empty entries ("::", or a leading/trailing ':'). Appending
        // "/tshark" to one would probe the filesystem root, which PATH never
        // named.
        if (directory.empty()) {
            continue;
        }

        const std::string candidate = directory + "/tshark";
        // access(X_OK) also succeeds for a searchable directory, so require a
        // regular file as well. The error_code overload keeps this non-throwing.
        std::error_code ec;
        if (std::filesystem::is_regular_file(candidate, ec) &&
            access(candidate.c_str(), X_OK) == 0) {
            return candidate;
        }
    }
    return "";
#endif
}
// Prepares the input file for reading.
//
// Steps:
//   1. Verify that `path` exists.
//   2. Pick the handling from the lower-cased text after the last '.':
//      "tsv" is read directly; "pcap"/"pcapng" is first converted with tshark
//      into "<path>.tsv", and `path` is updated to point at that file.
//   3. Count the lines of the TSV, clamp `limit` to the number of data rows
//      (lines minus the header), open the file as `tsvinf`, and skip the
//      header line.
//
// Throws std::runtime_error when the file is missing, has an unsupported
// extension, tshark cannot be found for a pcap input, or the TSV cannot be
// opened. Errors raised by pcap2tsv_with_tshark() propagate unchanged.
void FeatureExtractor::_prep() {
    // Make sure the input file exists.
    if (!std::filesystem::exists(path)) {
        std::cerr << "File: " << path << " does not exist" << std::endl;
        throw std::runtime_error("File does not exist");
    }

    // Determine the file type. A path without any '.' is compared as a whole,
    // which matches the previous behaviour.
    const std::string::size_type dot_pos = path.rfind('.');
    std::string extension =
        (dot_pos == std::string::npos) ? path : path.substr(dot_pos + 1);
    std::transform(
        extension.begin(), extension.end(), extension.begin(),
        [](unsigned char c) { return static_cast<char>(std::tolower(c)); });

    if (extension == "tsv") {
        // Already a TSV (pre-parsed by the wireshark script).
        parse_type = "tsv";
    } else if (extension == "pcap" || extension == "pcapng") {
        // tshark is only needed for pcap input, so only look for it here.
        const std::string tshark_path = _get_tshark_path();
        if (tshark_path.empty() || !std::filesystem::exists(tshark_path)) {
            std::cerr
                << "tshark not found. Cannot parse pcap file directly in C++."
                << std::endl;
            throw std::runtime_error("tshark not found for pcap parsing");
        }
        pcap2tsv_with_tshark(); // Writes "<path>.tsv" next to the capture.
        path += ".tsv";
        parse_type = "tsv";
    } else {
        std::cerr << "File: " << path << " is not a tsv or pcap file"
                  << std::endl;
        throw std::runtime_error("Invalid file type");
    }

    // Every path that reaches this point has parse_type == "tsv"; the other
    // cases threw above.
    std::cout << "Counting lines in file..." << std::endl;
    std::size_t num_lines = 0;
    {
        std::ifstream file(path);
        std::string line;
        while (std::getline(file, line)) {
            ++num_lines;
        }
    }
    std::cout << "There are " << num_lines << " Packets." << std::endl;

    // The first line is the header. Guard the subtraction: on an empty file
    // `num_lines - 1` would wrap around to the largest size_t value.
    const std::size_t num_packets = (num_lines > 0) ? num_lines - 1 : 0;
    limit = (std::min)(limit, static_cast<double>(num_packets));

    // Open the TSV file for reading.
    tsvinf = fopen(path.c_str(), "r");
    if (!tsvinf) {
        throw std::runtime_error("Failed to open TSV file");
    }

    // Skip the header line whatever its length; an empty file simply hits EOF.
    int ch;
    while ((ch = fgetc(tsvinf)) != EOF && ch != '\n') {
    }
}
namespace {

// Wraps `arg` in quotes so that the command interpreter run by std::system
// passes it through as a single literal word.
std::string quote_for_shell(const std::string &arg) {
#ifdef _WIN32
    // cmd.exe groups words with double quotes; a double quote cannot occur in
    // a valid Windows path, so no inner escaping is attempted.
    return "\"" + arg + "\"";
#else
    // POSIX sh: inside single quotes everything is literal. A single quote in
    // the argument is written by closing the quotes, emitting an escaped
    // quote, and reopening them: ' -> '\''.
    std::string quoted = "'";
    for (const char ch : arg) {
        if (ch == '\'') {
            quoted += "'\\''";
        } else {
            quoted += ch;
        }
    }
    quoted += '\'';
    return quoted;
#endif
}

} // namespace

// Converts the pcap at `path` to a TSV file using tshark.
//
// Runs tshark through std::system with a fixed list of `-e` fields and
// redirects its output to "<path>.tsv". `path` itself is not modified.
//
// Throws std::runtime_error when tshark cannot be located or when the command
// returns a non-zero status.
void FeatureExtractor::pcap2tsv_with_tshark() {
    std::cout << "Parsing with tshark..." << std::endl;

    const std::string tshark_path = _get_tshark_path();
    if (tshark_path.empty()) {
        throw std::runtime_error("tshark not found for pcap parsing");
    }

    // The order of these fields fixes the column order of the TSV output.
    const std::string fields =
        "-e frame.time_epoch -e frame.len -e eth.src -e eth.dst -e ip.src "
        "-e ip.dst -e tcp.srcport -e tcp.dstport -e udp.srcport -e udp.dstport "
        "-e icmp.type -e icmp.code -e arp.opcode -e arp.src.hw_mac "
        "-e arp.src.proto_ipv4 -e arp.dst.hw_mac -e arp.dst.proto_ipv4 "
        "-e ipv6.src -e ipv6.dst";

    const std::string output_path = path + ".tsv";

    // Quote every path: an unquoted path containing spaces or shell
    // metacharacters would break the command or be interpreted by the shell.
    std::string cmd = quote_for_shell(tshark_path) + " -r " +
                      quote_for_shell(path) + " -T fields " + fields +
                      " -E header=y -E occurrence=f > " +
                      quote_for_shell(output_path);
#ifdef _WIN32
    // cmd.exe /c strips the first and last quote of a line that contains more
    // than two quotes, so wrap the whole command in one extra pair.
    cmd = "\"" + cmd + "\"";
#endif

    const int ret = std::system(cmd.c_str());
    if (ret != 0) {
        throw std::runtime_error("tshark command failed with error code: " +
                                 std::to_string(ret));
    }
    std::cout << "tshark parsing complete. File saved as: " << path << ".tsv"
              << std::endl;
}
// Reads the next packet row from the TSV and returns its feature vector.
//
// Returns an empty vector when the packet limit has been reached (the file is
// closed at that point), when the file is exhausted or not open, when
// parse_type is not "tsv", or when netStat::updateGetStats throws a
// std::exception (the error is logged to stderr).
//
// std::stod / std::stoi throw std::invalid_argument or std::out_of_range when
// the timestamp or frame length column is missing or not numeric; those
// exceptions propagate to the caller.
std::vector<double> FeatureExtractor::get_next_vector() {
    // Column layout of a row. It follows the order of the `-e` fields passed
    // to tshark in pcap2tsv_with_tshark(); a pre-parsed TSV is assumed to use
    // the same layout.
    enum : std::size_t {
        kTimeEpoch = 0,
        kFrameLen = 1,
        kEthSrc = 2,
        kEthDst = 3,
        kIp4Src = 4,
        kIp4Dst = 5,
        kTcpSrcPort = 6,
        kTcpDstPort = 7,
        kUdpSrcPort = 8,
        kUdpDstPort = 9,
        kIcmpType = 10,
        kArpOpcode = 12,
        kArpSrcIp = 14,
        kArpDstIp = 16,
        kIp6Src = 17,
        kIp6Dst = 18,
        kColumnCount = 19
    };

    // Stop once floor(limit) packets have been returned. The comparison is
    // done in floating point because casting an out-of-range, infinite or NaN
    // `limit` to an integer is undefined behaviour; for limits in range this
    // is equivalent to `curPacketIndx >= (integer)limit`.
    if (static_cast<double>(curPacketIndx) + 1.0 > limit) {
        if (tsvinf) {
            fclose(tsvinf);
            tsvinf = nullptr;
        }
        return {};
    }

    if (parse_type != "tsv" || !tsvinf) {
        return {};
    }

    // Read the next row.
    char buffer[8192];
    if (!fgets(buffer, sizeof(buffer), tsvinf)) {
        return {};
    }

    // Strip the line terminator, including the '\r' of a CRLF-terminated file.
    std::string line(buffer);
    while (!line.empty() && (line.back() == '\n' || line.back() == '\r')) {
        line.pop_back();
    }

    // split() emits no trailing empty field, and a short or blank line has
    // fewer columns still. Pad with empty strings so every index used below
    // is valid.
    std::vector<std::string> row = split(line, '\t');
    if (row.size() < kColumnCount) {
        row.resize(kColumnCount);
    }

    // 0 = IPv4, 1 = IPv6 (see the branches below); -1 = no network layer
    // found. This used to be initialised from NAN, which is an undefined
    // conversion for an int.
    // NOTE(review): confirm netStat::updateGetStats treats any value other
    // than 0/1 as "no IP layer".
    int IPtype = -1;
    const double timestamp = std::stod(row[kTimeEpoch]);
    const int framelen = std::stoi(row[kFrameLen]);

    std::string srcIP;
    std::string dstIP;
    if (!row[kIp4Src].empty()) { // IPv4
        srcIP = row[kIp4Src];
        dstIP = row[kIp4Dst];
        IPtype = 0;
    } else if (!row[kIp6Src].empty()) { // IPv6
        srcIP = row[kIp6Src];
        dstIP = row[kIp6Dst];
        IPtype = 1;
    }

    // Only one of the TCP/UDP port columns is expected to be filled, so the
    // concatenation yields whichever is present.
    std::string srcproto = row[kTcpSrcPort] + row[kUdpSrcPort];
    std::string dstproto = row[kTcpDstPort] + row[kUdpDstPort];
    const std::string srcMAC = row[kEthSrc];
    const std::string dstMAC = row[kEthDst];

    if (srcproto.empty()) { // No transport port: an L2/L1-level protocol.
        if (!row[kArpOpcode].empty()) { // ARP
            srcproto = "arp";
            dstproto = "arp";
            srcIP = row[kArpSrcIp]; // Source IP (ARP)
            dstIP = row[kArpDstIp]; // Destination IP (ARP)
            IPtype = 0;
        } else if (!row[kIcmpType].empty()) { // ICMP
            srcproto = "icmp";
            dstproto = "icmp";
            IPtype = 0;
        } else if (srcIP.empty() && dstIP.empty() && dstproto.empty()) {
            // Some other protocol: fall back to the MAC addresses.
            srcIP = row[kEthSrc];
            dstIP = row[kEthDst];
        }
    }

    curPacketIndx++;

    // Extract the features.
    try {
        return nstat->updateGetStats(IPtype, srcMAC, dstMAC, srcIP, srcproto,
                                     dstIP, dstproto, framelen, timestamp);
    } catch (const std::exception &e) {
        std::cerr << "Error extracting features: " << e.what() << std::endl;
        return {};
    }
}
// Returns the number of features, i.e. the number of header entries reported
// by the netStat instance.
size_t FeatureExtractor::get_num_features() {
    const auto &headers = nstat->getNetStatHeaders();
    return headers.size();
}