Skip to content

Commit d52c9d9

Browse files
committed
GPU-enabled createdb could crash due to invalid keys
1 parent 0578939 commit d52c9d9

File tree

1 file changed

+20
-18
lines changed

1 file changed

+20
-18
lines changed

src/util/createdb.cpp

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -131,12 +131,15 @@ int mergeSequentialByJointIndex(
131131
char * outHeaderIndexFile,
132132
const char * outLookupFile,
133133
std::vector<unsigned int>* sourceLookup,
134-
size_t totalEntries
134+
size_t totalEntries,
135+
size_t shuffleSplits
135136
) {
136137
struct JointEntry {
137138
unsigned int fileIdx;
138139
unsigned int id;
139140
unsigned length;
141+
JointEntry(unsigned int fileIdx, unsigned int id, unsigned length) : fileIdx(fileIdx), id(id), length(length) {};
142+
140143
bool operator<(JointEntry const &o) const {
141144
if (length != o.length){
142145
return length < o.length;
@@ -145,38 +148,37 @@ int mergeSequentialByJointIndex(
145148
}
146149
};
147150

148-
size_t k = 32;
149-
std::vector<JointEntry> joint(totalEntries);
151+
std::vector<JointEntry> joint;
152+
joint.reserve(totalEntries);
150153
size_t maxLen = 0;
151-
size_t pos = 0;
152-
for (size_t i = 0; i < k; i++) {
154+
for (size_t i = 0; i < shuffleSplits; i++) {
153155
DBReader<unsigned int> reader(
154156
dataFiles[i],
155157
indexFiles[i],
156158
1,
157159
DBReader<uint32_t>::USE_INDEX
158160
);
159161
reader.open(DBReader<uint32_t>::HARDNOSORT);
162+
DBReader<unsigned int>::Index* index = reader.getIndex();
160163
for(size_t j = 0; j < reader.getSize(); j++){
161-
joint[pos] = { (unsigned int)i, reader.getIndex()[j].id, reader.getIndex()[j].length };
162-
maxLen = std::max(maxLen, static_cast<size_t>(reader.getIndex()[j].length));
163-
pos++;
164+
joint.emplace_back((unsigned int)i, index[j].id, index[j].length);
165+
maxLen = std::max(maxLen, static_cast<size_t>(index[j].length));
164166
}
165167
reader.close();
166168
}
167169

168170
SORT_PARALLEL(joint.begin(), joint.end());
169171

170172
// 4) Open each data file once (no fseek later)
171-
std::vector<FILE*> inFileSeq(k);
172-
std::vector<FILE*> inFileHeader(k);
173+
std::vector<FILE*> inFileSeq(shuffleSplits);
174+
std::vector<FILE*> inFileHeader(shuffleSplits);
173175

174-
for (size_t i = 0; i < k; i++) {
176+
for (size_t i = 0; i < shuffleSplits; i++) {
175177
inFileSeq[i] = FileUtil::openFileOrDie(
176-
dataFiles[i], "rb", true
178+
dataFiles[i], "rb", true
177179
);
178180
inFileHeader[i] = FileUtil::openFileOrDie(
179-
dataFilesHeader[i], "rb", true
181+
dataFilesHeader[i], "rb", true
180182
);
181183
}
182184

@@ -214,8 +216,8 @@ int mergeSequentialByJointIndex(
214216
// Create buffers for each header file
215217
std::vector<FileBuffer> headerBuffers;
216218
std::vector<char> writeHeaderBuf;
217-
headerBuffers.reserve(k);
218-
for (size_t i = 0; i < k; i++) {
219+
headerBuffers.reserve(shuffleSplits);
220+
for (size_t i = 0; i < shuffleSplits; i++) {
219221
headerBuffers.emplace_back(inFileHeader[i]);
220222
}
221223
size_t mergedOffset = 0;
@@ -646,7 +648,7 @@ int createdb(int argc, const char **argv, const Command& command) {
646648
}
647649
}
648650
processSeqBatch(par, seqWriter, hdrWriter, subMat, dbType, masker, seqs,
649-
id - batchPos, batchEntries, batchPos, shuffleSplits);
651+
id - (batchPos - 1), batchEntries, batchPos, shuffleSplits);
650652
batchPos = 0;
651653
}
652654
}
@@ -667,7 +669,7 @@ int createdb(int argc, const char **argv, const Command& command) {
667669
}
668670
}
669671
processSeqBatch(par, seqWriter, hdrWriter, subMat, dbType, masker, seqs,
670-
(par.identifierOffset + entries_num) - batchPos, batchEntries, batchPos, shuffleSplits);
672+
(par.identifierOffset + entries_num) - (batchPos - 1), batchEntries, batchPos, shuffleSplits);
671673
}
672674

673675
if (numEntriesInCurrFile == 0) {
@@ -706,7 +708,7 @@ int createdb(int argc, const char **argv, const Command& command) {
706708
hdrWriter.getDataFileNames(), hdrWriter.getIndexFileNames(),
707709
seqWriter.getDataFileName(), seqWriter.getIndexFileName(),
708710
hdrWriter.getDataFileName(), hdrWriter.getIndexFileName(),
709-
lookupFile.c_str(), sourceLookup, entries_num);
711+
lookupFile.c_str(), sourceLookup, entries_num, shuffleSplits);
710712
Debug(Debug::INFO) << "Merge all files " << timer.lap() << "\n";
711713
hdrWriter.clearMemory();
712714
seqWriter.clearMemory();

0 commit comments

Comments
 (0)