Skip to content

Commit 0962ddb

Browse files
Remove query-time usage of ByteSequence::slice to reduce object allocations (#403)
1 parent 3642fc9 commit 0962ddb

File tree

8 files changed

+141
-55
lines changed

8 files changed

+141
-55
lines changed

jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQDecoder.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ protected CachingDecoder(PQVectors cv, VectorFloat<?> query, VectorSimilarityFun
5353
}
5454
}
5555

56-
protected float decodedSimilarity(ByteSequence<?> encoded) {
57-
return VectorUtil.assembleAndSum(partialSums, cv.pq.getClusterCount(), encoded);
56+
protected float decodedSimilarity(ByteSequence<?> encoded, int offset, int length) {
57+
return VectorUtil.assembleAndSum(partialSums, cv.pq.getClusterCount(), encoded, offset, length);
5858
}
5959
}
6060

@@ -65,7 +65,7 @@ public DotProductDecoder(PQVectors cv, VectorFloat<?> query) {
6565

6666
@Override
6767
public float similarityTo(int node2) {
68-
return (1 + decodedSimilarity(cv.get(node2))) / 2;
68+
return (1 + decodedSimilarity(cv.getChunk(node2), cv.getOffsetInChunk(node2), cv.pq.getSubspaceCount())) / 2;
6969
}
7070
}
7171

@@ -76,7 +76,7 @@ public EuclideanDecoder(PQVectors cv, VectorFloat<?> query) {
7676

7777
@Override
7878
public float similarityTo(int node2) {
79-
return 1 / (1 + decodedSimilarity(cv.get(node2)));
79+
return 1 / (1 + decodedSimilarity(cv.getChunk(node2), cv.getOffsetInChunk(node2), cv.pq.getSubspaceCount()));
8080
}
8181
}
8282

@@ -132,9 +132,10 @@ public float similarityTo(int node2) {
132132

133133
protected float decodedCosine(int node2) {
134134

135-
ByteSequence<?> encoded = cv.get(node2);
135+
ByteSequence<?> encoded = cv.getChunk(node2);
136+
int offset = cv.getOffsetInChunk(node2);
136137

137-
return VectorUtil.pqDecodedCosineSimilarity(encoded, cv.pq.getClusterCount(), partialSums, aMagnitude, bMagnitude);
138+
return VectorUtil.pqDecodedCosineSimilarity(encoded, offset, cv.pq.getSubspaceCount(), cv.pq.getClusterCount(), partialSums, aMagnitude, bMagnitude);
138139
}
139140
}
140141
}

jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -229,11 +229,12 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
229229
switch (similarityFunction) {
230230
case DOT_PRODUCT:
231231
return (node2) -> {
232-
var encoded = get(node2);
232+
var encodedChunk = getChunk(node2);
233+
var encodedOffset = getOffsetInChunk(node2);
233234
// compute the dot product of the query and the codebook centroids corresponding to the encoded points
234235
float dp = 0;
235236
for (int m = 0; m < pq.getSubspaceCount(); m++) {
236-
int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
237+
int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
237238
int centroidLength = pq.subvectorSizesAndOffsets[m][0];
238239
int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
239240
dp += VectorUtil.dotProduct(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength);
@@ -244,12 +245,13 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
244245
case COSINE:
245246
float norm1 = VectorUtil.dotProduct(centeredQuery, centeredQuery);
246247
return (node2) -> {
247-
var encoded = get(node2);
248+
var encodedChunk = getChunk(node2);
249+
var encodedOffset = getOffsetInChunk(node2);
248250
// compute the dot product of the query and the codebook centroids corresponding to the encoded points
249251
float sum = 0;
250252
float norm2 = 0;
251253
for (int m = 0; m < pq.getSubspaceCount(); m++) {
252-
int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
254+
int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
253255
int centroidLength = pq.subvectorSizesAndOffsets[m][0];
254256
int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
255257
var codebookOffset = centroidIndex * centroidLength;
@@ -262,11 +264,12 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
262264
};
263265
case EUCLIDEAN:
264266
return (node2) -> {
265-
var encoded = get(node2);
267+
var encodedChunk = getChunk(node2);
268+
var encodedOffset = getOffsetInChunk(node2);
266269
// compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points
267270
float sum = 0;
268271
for (int m = 0; m < pq.getSubspaceCount(); m++) {
269-
int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
272+
int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
270273
int centroidLength = pq.subvectorSizesAndOffsets[m][0];
271274
int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
272275
sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength);
@@ -279,17 +282,49 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
279282
}
280283
}
281284

285+
/**
286+
* Returns a {@link ByteSequence} for the given ordinal.
287+
* @param ordinal the vector's ordinal
288+
* @return the {@link ByteSequence}
289+
*/
282290
public ByteSequence<?> get(int ordinal) {
283291
if (ordinal < 0 || ordinal >= count())
284292
throw new IndexOutOfBoundsException("Ordinal " + ordinal + " out of bounds for vector count " + count());
285293
return get(compressedDataChunks, ordinal, vectorsPerChunk, pq.getSubspaceCount());
286294
}
287295

288296
static ByteSequence<?> get(ByteSequence<?>[] chunks, int ordinal, int vectorsPerChunk, int subspaceCount) {
289-
int chunkIndex = ordinal / vectorsPerChunk;
290297
int vectorIndexInChunk = ordinal % vectorsPerChunk;
291298
int start = vectorIndexInChunk * subspaceCount;
292-
return chunks[chunkIndex].slice(start, subspaceCount);
299+
return getChunk(chunks, ordinal, vectorsPerChunk).slice(start, subspaceCount);
300+
}
301+
302+
/**
303+
* Returns a reference to the {@link ByteSequence} chunk containing the vector for the given ordinal. Only intended for use where
304+
* the caller wants to avoid an allocation for the slice object. After getting the chunk, callers should use the
305+
* {@link #getOffsetInChunk(int)} method to get the offset of the vector within the chunk and then use the pq's
306+
* {@link ProductQuantization#getSubspaceCount()} to get the length of the vector.
307+
* @param ordinal the vector's ordinal
308+
* @return the {@link ByteSequence} chunk containing the vector
309+
*/
310+
ByteSequence<?> getChunk(int ordinal) {
311+
if (ordinal < 0 || ordinal >= count())
312+
throw new IndexOutOfBoundsException("Ordinal " + ordinal + " out of bounds for vector count " + count());
313+
314+
return getChunk(compressedDataChunks, ordinal, vectorsPerChunk);
315+
}
316+
317+
int getOffsetInChunk(int ordinal) {
318+
if (ordinal < 0 || ordinal >= count())
319+
throw new IndexOutOfBoundsException("Ordinal " + ordinal + " out of bounds for vector count " + count());
320+
321+
int vectorIndexInChunk = ordinal % vectorsPerChunk;
322+
return vectorIndexInChunk * pq.getSubspaceCount();
323+
}
324+
325+
static ByteSequence<?> getChunk(ByteSequence<?>[] chunks, int ordinal, int vectorsPerChunk) {
326+
int chunkIndex = ordinal / vectorsPerChunk;
327+
return chunks[chunkIndex];
293328
}
294329

295330

jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,14 @@ public void minInPlace(VectorFloat<?> v1, VectorFloat<?> v2) {
296296

297297
@Override
298298
public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> baseOffsets) {
299+
return assembleAndSum(data, dataBase, baseOffsets, 0, baseOffsets.length());
300+
}
301+
302+
@Override
303+
public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) {
299304
float sum = 0f;
300-
for (int i = 0; i < baseOffsets.length(); i++) {
301-
sum += data.get(dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i)));
305+
for (int i = 0; i < baseOffsetsLength; i++) {
306+
sum += data.get(dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i + baseOffsetsOffset)));
302307
}
303308
return sum;
304309
}

jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,10 @@ public static float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequen
166166
return impl.assembleAndSum(data, dataBase, dataOffsets);
167167
}
168168

169+
public static float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> dataOffsets, int dataOffsetsOffset, int dataOffsetsLength) {
170+
return impl.assembleAndSum(data, dataBase, dataOffsets, dataOffsetsOffset, dataOffsetsLength);
171+
}
172+
169173
public static void bulkShuffleQuantizedSimilarity(ByteSequence<?> shuffles, int codebookCount, ByteSequence<?> quantizedPartials, float delta, float minDistance, VectorFloat<?> results, VectorSimilarityFunction vsf) {
170174
impl.bulkShuffleQuantizedSimilarity(shuffles, codebookCount, quantizedPartials, delta, minDistance, vsf, results);
171175
}
@@ -215,6 +219,10 @@ public static float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clust
215219
return impl.pqDecodedCosineSimilarity(encoded, clusterCount, partialSums, aMagnitude, bMagnitude);
216220
}
217221

222+
public static float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int encodedOffset, int encodedLength, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude) {
223+
return impl.pqDecodedCosineSimilarity(encoded, encodedOffset, encodedLength, clusterCount, partialSums, aMagnitude, bMagnitude);
224+
}
225+
218226
public static float nvqDotProduct8bit(VectorFloat<?> vector, ByteSequence<?> bytes, float growthRate, float midpoint, float minValue, float maxValue) {
219227
return impl.nvqDotProduct8bit(vector, bytes, growthRate, midpoint, minValue, maxValue);
220228
}

jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,19 @@ public interface VectorUtilSupport {
100100
*/
101101
float assembleAndSum(VectorFloat<?> data, int baseIndex, ByteSequence<?> baseOffsets);
102102

103+
/**
104+
* Calculates the sum of sparse points in a vector.
105+
*
106+
* @param data the vector of all datapoints
107+
* @param baseIndex the start of the data in the offset table
108+
* (scaled by the index of the lookup table)
109+
* @param baseOffsets bytes that represent offsets from the baseIndex
110+
* @param baseOffsetsOffset the offset into the baseOffsets ByteSequence
111+
* @param baseOffsetsLength the length of the baseOffsets ByteSequence to use
112+
* @return the sum of the points
113+
*/
114+
float assembleAndSum(VectorFloat<?> data, int baseIndex, ByteSequence<?> baseOffsets, int baseOffsetsOffset, int baseOffsetsLength);
115+
103116
int hammingDistance(long[] v1, long[] v2);
104117

105118

@@ -212,12 +225,17 @@ default void bulkShuffleQuantizedSimilarityCosine(ByteSequence<?> shuffles, int
212225
float min(VectorFloat<?> v);
213226

214227
default float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
228+
{
229+
return pqDecodedCosineSimilarity(encoded, 0, encoded.length(), clusterCount, partialSums, aMagnitude, bMagnitude);
230+
}
231+
232+
default float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int encodedOffset, int encodedLength, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
215233
{
216234
float sum = 0.0f;
217235
float aMag = 0.0f;
218236

219-
for (int m = 0; m < encoded.length(); ++m) {
220-
int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
237+
for (int m = 0; m < encodedLength; ++m) {
238+
int centroidIndex = Byte.toUnsignedInt(encoded.get(m + encodedOffset));
221239
var index = m * clusterCount + centroidIndex;
222240
sum += partialSums.get(index);
223241
aMag += aMagnitude.get(index);

jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,14 @@ public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> b
121121
return NativeSimdOps.assemble_and_sum_f32_512(((MemorySegmentVectorFloat)data).get(), dataBase, ((MemorySegmentByteSequence)baseOffsets).get(), baseOffsets.length());
122122
}
123123

124+
@Override
125+
public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> baseOffsets, int baseOffsetsOffset, int baseOffsetsLength)
126+
{
127+
assert baseOffsetsOffset == 0;
128+
assert baseOffsetsLength == baseOffsets.length();
129+
return assembleAndSum(data, dataBase, baseOffsets);
130+
}
131+
124132
@Override
125133
public int hammingDistance(long[] v1, long[] v2) {
126134
return VectorSimdOps.hammingDistance(v1, v2);

jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,13 @@ public void minInPlace(VectorFloat<?> v1, VectorFloat<?> v2) {
112112

113113
@Override
114114
public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> baseOffsets) {
115-
return SimdOps.assembleAndSum(((ArrayVectorFloat) data).get(), dataBase, ((ByteSequence<byte[]>) baseOffsets));
115+
return SimdOps.assembleAndSum(((ArrayVectorFloat) data).get(), dataBase, ((ByteSequence<byte[]>) baseOffsets),
116+
0, baseOffsets.length());
117+
}
118+
119+
@Override
120+
public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) {
121+
return SimdOps.assembleAndSum(((ArrayVectorFloat) data).get(), dataBase, ((ByteSequence<byte[]>) baseOffsets), baseOffsetsOffset, baseOffsetsLength);
116122
}
117123

118124
@Override
@@ -177,9 +183,14 @@ public void quantizePartials(float delta, VectorFloat<?> partials, VectorFloat<?
177183
}
178184

179185
@Override
180-
public float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
186+
public float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude) {
187+
return pqDecodedCosineSimilarity(encoded, 0, encoded.length(), clusterCount, partialSums, aMagnitude, bMagnitude);
188+
}
189+
190+
@Override
191+
public float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int encodedOffset, int encodedLength, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
181192
{
182-
return SimdOps.pqDecodedCosineSimilarity((ByteSequence<byte[]>) encoded, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude);
193+
return SimdOps.pqDecodedCosineSimilarity((ByteSequence<byte[]>) encoded, encodedOffset, encodedLength, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude);
183194
}
184195

185196
@Override

0 commit comments

Comments
 (0)