Skip to content

Commit 7cbb2e1

Browse files
marianotepperJonathan Ellis
andauthored
Improved use ScoreTracker to avoid wasteful searching for very large k (#387)
This improves upon #384 by making the quantiles estimation more lightweight. It models the recent scores as a Normal distribution and uses incremental updates to track sufficient statistics of its mean and variance. Then, quantiles are computed from these statistics. --------- Co-authored-by: Jonathan Ellis <jbellis@datastax.com>
1 parent 9613109 commit 7cbb2e1

File tree

2 files changed

+100
-4
lines changed

2 files changed

+100
-4
lines changed

jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphSearcher.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ private SearchResult resume(int initialVisited, int topK, int rerankK, float thr
269269
// track scores to predict when we are done with threshold queries
270270
var scoreTracker = threshold > 0
271271
? new ScoreTracker.TwoPhaseTracker(threshold)
272-
: PRUNE ? new ScoreTracker.TwoPhaseTracker(1.0) : new ScoreTracker.NoOpTracker();
272+
: PRUNE ? new ScoreTracker.RelaxedMonotonicityTracker(rerankK) : new ScoreTracker.NoOpTracker();
273273
VectorFloat<?> similarities = null;
274274

275275
// add evicted results from the last call back to the candidates

jvector-base/src/main/java/io/github/jbellis/jvector/graph/ScoreTracker.java

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
package io.github.jbellis.jvector.graph;
1818

19-
import io.github.jbellis.jvector.util.AbstractLongHeap;
2019
import io.github.jbellis.jvector.util.BoundedLongHeap;
2120
import org.apache.commons.math3.stat.StatUtils;
2221

@@ -93,8 +92,105 @@ public boolean shouldStop() {
9392
// (paper suggests using the median of recent scores, but experimentally that is too prone to false positives.
9493
// 90th does seem to be enough, but 99th doesn't result in much extra work, so we'll be conservative)
9594
double windowMedian = StatUtils.percentile(recentScores, 99);
96-
double worstBest = sortableIntToFloat((int) bestScores.top());
97-
return windowMedian < worstBest && windowMedian < threshold;
95+
double worstBestScore = sortableIntToFloat((int) bestScores.top());
96+
return windowMedian < worstBestScore && windowMedian < threshold;
97+
}
98+
}
99+
100+
/**
101+
* Follows the methodology of section 3.1 in "VBase: Unifying Online Vector Similarity Search
102+
* and Relational Queries via Relaxed Monotonicity" to determine when we've left phase 1
103+
* (finding the local maximum) and entered phase 2 (mostly just finding worse options)
104+
* To compute quantiles quickly, we treat the distribution of the data as Normal,
105+
* track its mean and variance, and compute quantiles from them as:
106+
* mean + SIGMA_FACTOR * sqrt(variance)
107+
* Empirically, SIGMA_FACTOR=1.75 seems to work reasonably well
108+
* (approximately the 96th percentile of the Normal distribution).
109+
*/
110+
class RelaxedMonotonicityTracker implements ScoreTracker {
111+
static final double SIGMA_FACTOR = 1.75;
112+
113+
// a sliding window of recent scores
114+
private final double[] recentScores;
115+
private int recentEntryIndex;
116+
117+
// Heap of the best scores seen so far
118+
BoundedLongHeap bestScores;
119+
120+
// observation count
121+
private int observationCount;
122+
123+
// the sample mean
124+
private double mean;
125+
126+
// the sample variance multiplied by n-1
127+
private double dSquared;
128+
129+
/**
130+
* Constructor
131+
* @param bestScoresTracked the number of tracked scores used to estimate if we are unlikely to improve
132+
* the results anymore. An empirical rule of thumb is bestScoresTracked=rerankK.
133+
*/
134+
RelaxedMonotonicityTracker(int bestScoresTracked) {
135+
// A quick empirical study yields that the number of recent scores
136+
// that we need to consider grows by a factor of ~sqrt(bestScoresTracked / 2)
137+
int factor = (int) Math.round(Math.sqrt(bestScoresTracked / 2.0));
138+
this.recentScores = new double[200 * factor];
139+
this.bestScores = new BoundedLongHeap(bestScoresTracked);
140+
this.mean = 0;
141+
this.dSquared = 0;
142+
}
143+
144+
@Override
145+
public void track(float score) {
146+
bestScores.push(floatToSortableInt(score));
147+
observationCount++;
148+
149+
// The updates of the sufficient statistics follow
150+
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online
151+
// and
152+
// https://nestedsoftware.com/2019/09/26/incremental-average-and-standard-deviation-with-sliding-window-470k.176143.html
153+
if (observationCount <= this.recentScores.length) {
154+
// if the buffer is not full yet, use standard Welford method
155+
var meanDelta = (score - this.mean) / observationCount;
156+
var newMean = this.mean + meanDelta;
157+
158+
var dSquaredDelta = ((score - newMean) * (score - this.mean));
159+
var newDSquared = this.dSquared + dSquaredDelta;
160+
161+
this.mean = newMean;
162+
this.dSquared = newDSquared;
163+
} else {
164+
// once the buffer is full, adjust Welford method for window size
165+
var oldScore = recentScores[recentEntryIndex];
166+
var meanDelta = (score - oldScore) / this.recentScores.length;
167+
var newMean = this.mean + meanDelta;
168+
169+
var dSquaredDelta = ((score - oldScore) * (score - newMean + oldScore - this.mean));
170+
var newDSquared = this.dSquared + dSquaredDelta;
171+
172+
this.mean = newMean;
173+
this.dSquared = newDSquared;
174+
}
175+
recentScores[recentEntryIndex] = score;
176+
recentEntryIndex = (recentEntryIndex + 1) % this.recentScores.length;
177+
}
178+
179+
@Override
180+
public boolean shouldStop() {
181+
// don't stop if we don't have enough data points
182+
if (observationCount < this.recentScores.length) {
183+
return false;
184+
}
185+
186+
// We're in phase 2 if the q-th percentile of the recent scores evaluated,
187+
// mean + SIGMA_FACTOR * sqrt(variance),
188+
// is lower than the worst of the best scores seen.
189+
// (paper suggests using the median of recent scores, but experimentally that is too prone to false positives)
190+
double std = Math.sqrt(this.dSquared / (this.recentScores.length - 1));
191+
double windowPercentile = this.mean + SIGMA_FACTOR * std;
192+
double worstBestScore = sortableIntToFloat((int) bestScores.top());
193+
return windowPercentile < worstBestScore;
98194
}
99195
}
100196
}

0 commit comments

Comments
 (0)