Skip to content

Commit 356c490

Browse files
Acceleration of initial centroid selection in kmeans++
Use Java Vector API to accelerate chooseInitialCentroids in KMeansPlusPlusClusterer.
1 parent 1f6c5b1 commit 356c490

File tree

8 files changed

+80
-11
lines changed

8 files changed

+80
-11
lines changed

jvector-base/src/main/java/io/github/jbellis/jvector/quantization/KMeansPlusPlusClusterer.java

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -181,26 +181,26 @@ private static VectorFloat<?> chooseInitialCentroids(VectorFloat<?>[] points, in
181181

182182
float[] distances = new float[points.length];
183183
Arrays.fill(distances, Float.MAX_VALUE);
184+
VectorFloat<?> distancesVector = vectorTypeSupport.createFloatVector(distances);
185+
int distancesLength = points.length;
184186

185187
// Choose the first centroid randomly
186188
VectorFloat<?> firstCentroid = points[random.nextInt(points.length)];
187189
centroids.copyFrom(firstCentroid, 0, 0, firstCentroid.length());
190+
VectorFloat<?> newDistancesVector = vectorTypeSupport.createFloatVector(points.length);
188191
for (int i = 0; i < points.length; i++) {
189-
float distance1 = squareL2Distance(points[i], firstCentroid);
190-
distances[i] = Math.min(distances[i], distance1);
192+
newDistancesVector.set(i, squareL2Distance(points[i], firstCentroid));
191193
}
194+
VectorUtil.minInPlace(distancesVector, newDistancesVector);
192195

193196
// For each subsequent centroid
194197
for (int i = 1; i < k; i++) {
195-
float totalDistance = 0;
196-
for (float distance : distances) {
197-
totalDistance += distance;
198-
}
198+
float totalDistance = VectorUtil.sum(distancesVector);
199199

200200
float r = random.nextFloat() * totalDistance;
201201
int selectedIdx = -1;
202-
for (int j = 0; j < distances.length; j++) {
203-
r -= distances[j];
202+
for (int j = 0; j < distancesLength; j++) {
203+
r -= distancesVector.get(j);
204204
if (r < 1e-6) {
205205
selectedIdx = j;
206206
break;
@@ -215,10 +215,11 @@ private static VectorFloat<?> chooseInitialCentroids(VectorFloat<?>[] points, in
215215
centroids.copyFrom(nextCentroid, 0, i * nextCentroid.length(), nextCentroid.length());
216216

217217
// Update distances, but only if the new centroid provides a closer distance
218-
for (int j = 0; j < points.length; j++) {
219-
float newDistance = squareL2Distance(points[j], nextCentroid);
220-
distances[j] = Math.min(distances[j], newDistance);
218+
// All entries of newDistancesVector is overwritten with the updated squareL2Distance value
219+
for (int j = 0; j < distancesLength; j++) {
220+
newDistancesVector.set(j, squareL2Distance(points[j], nextCentroid));
221221
}
222+
VectorUtil.minInPlace(distancesVector, newDistancesVector);
222223
}
223224
assertFinite(centroids);
224225
return centroids;

jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,13 @@ public VectorFloat<?> sub(VectorFloat<?> a, int aOffset, VectorFloat<?> b, int b
287287
return result;
288288
}
289289

290+
@Override
291+
public void minInPlace(VectorFloat<?> v1, VectorFloat<?> v2) {
292+
for (int i = 0; i < v1.length(); i++) {
293+
v1.set(i, Math.min(v1.get(i), v2.get(i)));
294+
}
295+
}
296+
290297
@Override
291298
public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> baseOffsets) {
292299
float sum = 0f;
@@ -541,4 +548,5 @@ public float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValu
541548

542549
return squaredSum;
543550
}
551+
544552
}

jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,10 @@ public static VectorFloat<?> sub(VectorFloat<?> a, int aOffset, VectorFloat<?> b
158158
return impl.sub(a, aOffset, b, bOffset, length);
159159
}
160160

161+
public static void minInPlace(VectorFloat<?> distances1, VectorFloat<?> distances2) {
162+
impl.minInPlace(distances1, distances2);
163+
}
164+
161165
public static float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> dataOffsets) {
162166
return impl.assembleAndSum(data, dataBase, dataOffsets);
163167
}
@@ -238,4 +242,5 @@ public static float nvqLoss(VectorFloat<?> vector, float growthRate, float midpo
238242
public static float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValue, int nBits) {
239243
return impl.nvqUniformLoss(vector, minValue, maxValue, nBits);
240244
}
245+
241246
}

jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ public interface VectorUtilSupport {
8282
/** @return a - b, element-wise, starting at aOffset and bOffset respectively */
8383
VectorFloat<?> sub(VectorFloat<?> a, int aOffset, VectorFloat<?> b, int bOffset, int length);
8484

85+
/** Calculates the minimum value for every corresponding lane values in v1 and v2, in place (v1 will be modified) */
86+
void minInPlace(VectorFloat<?> v1, VectorFloat<?> v2);
87+
8588
/**
8689
* Calculates the sum of sparse points in a vector.
8790
* <p>
@@ -300,4 +303,5 @@ default float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCoun
300303
* @param nBits the number of bits per dimension
301304
*/
302305
float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValue, int nBits);
306+
303307
}

jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,11 @@ public VectorFloat<?> sub(VectorFloat<?> a, int aOffset, VectorFloat<?> b, int b
110110
return VectorSimdOps.sub((MemorySegmentVectorFloat) a, aOffset, (MemorySegmentVectorFloat) b, bOffset, length);
111111
}
112112

113+
@Override
114+
public void minInPlace(VectorFloat<?> v1, VectorFloat<?> v2) {
115+
VectorSimdOps.minInPlace((MemorySegmentVectorFloat) v1, (MemorySegmentVectorFloat) v2);
116+
}
117+
113118
@Override
114119
public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> baseOffsets) {
115120
assert baseOffsets.offset() == 0 : "Base offsets are expected to have an offset of 0. Found: " + baseOffsets.offset();
@@ -225,4 +230,5 @@ public float nvqLoss(VectorFloat<?> vector, float growthRate, float midpoint, fl
225230
public float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValue, int nBits) {
226231
return VectorSimdOps.nvqUniformLoss((MemorySegmentVectorFloat) vector, minValue, maxValue, nBits);
227232
}
233+
228234
}

jvector-native/src/main/java/io/github/jbellis/jvector/vector/VectorSimdOps.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,26 @@ public static int hammingDistance(long[] a, long[] b) {
598598
return res;
599599
}
600600

601+
static void minInPlace(MemorySegmentVectorFloat v1, MemorySegmentVectorFloat v2) {
602+
if (v1.length() != v2.length()) {
603+
throw new IllegalArgumentException("Vectors must have the same length");
604+
}
605+
606+
int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(v1.length());
607+
608+
// Process the vectorized part
609+
for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
610+
var a = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, v1.get(), v1.offset(i), ByteOrder.LITTLE_ENDIAN);
611+
var b = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, v2.get(), v2.offset(i), ByteOrder.LITTLE_ENDIAN);
612+
a.min(b).intoMemorySegment(v1.get(), v1.offset(i), ByteOrder.LITTLE_ENDIAN);
613+
}
614+
615+
// Process the tail
616+
for (int i = vectorizedLength; i < v1.length(); i++) {
617+
v1.set(i, Math.min(v1.get(i), v2.get(i)));
618+
}
619+
}
620+
601621
public static float max(MemorySegmentVectorFloat vector) {
602622
var accum = FloatVector.broadcast(FloatVector.SPECIES_PREFERRED, -Float.MAX_VALUE);
603623
int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length());

jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ public VectorFloat<?> sub(VectorFloat<?> a, int aOffset, VectorFloat<?> b, int b
105105
return SimdOps.sub((ArrayVectorFloat) a, aOffset, (ArrayVectorFloat) b, bOffset, length);
106106
}
107107

108+
@Override
109+
public void minInPlace(VectorFloat<?> v1, VectorFloat<?> v2) {
110+
SimdOps.minInPlace((ArrayVectorFloat)v1, (ArrayVectorFloat)v2);
111+
}
112+
108113
@Override
109114
public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> baseOffsets) {
110115
return SimdOps.assembleAndSum(((ArrayVectorFloat) data).get(), dataBase, ((ByteSequence<byte[]>) baseOffsets));

jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,26 @@ static VectorFloat<?> sub(ArrayVectorFloat a, int aOffset, float value, int leng
612612
return new ArrayVectorFloat(res);
613613
}
614614

615+
static void minInPlace(ArrayVectorFloat v1, ArrayVectorFloat v2) {
616+
if (v1.length() != v2.length()) {
617+
throw new IllegalArgumentException("Vectors must have the same length");
618+
}
619+
620+
int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(v1.length());
621+
622+
// Process the vectorized part
623+
for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
624+
var a = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, v1.get(), i);
625+
var b = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, v2.get(), i);
626+
a.min(b).intoArray(v1.get(), i);
627+
}
628+
629+
// Process the tail
630+
for (int i = vectorizedLength; i < v1.length(); i++) {
631+
v1.set(i, Math.min(v1.get(i), v2.get(i)));
632+
}
633+
}
634+
615635
static float assembleAndSum(float[] data, int dataBase, ByteSequence<byte[]> baseOffsets) {
616636
return switch (PREFERRED_BIT_SIZE)
617637
{

0 commit comments

Comments
 (0)