|
16 | 16 |
|
17 | 17 | package io.github.jbellis.jvector.graph; |
18 | 18 |
|
19 | | -import io.github.jbellis.jvector.util.AbstractLongHeap; |
20 | 19 | import io.github.jbellis.jvector.util.BoundedLongHeap; |
21 | 20 | import org.apache.commons.math3.stat.StatUtils; |
22 | 21 |
|
@@ -93,8 +92,105 @@ public boolean shouldStop() { |
93 | 92 | // (paper suggests using the median of recent scores, but experimentally that is too prone to false positives. |
94 | 93 | // 90th does seem to be enough, but 99th doesn't result in much extra work, so we'll be conservative) |
95 | 94 | double windowMedian = StatUtils.percentile(recentScores, 99); |
96 | | - double worstBest = sortableIntToFloat((int) bestScores.top()); |
97 | | - return windowMedian < worstBest && windowMedian < threshold; |
| 95 | + double worstBestScore = sortableIntToFloat((int) bestScores.top()); |
| 96 | + return windowMedian < worstBestScore && windowMedian < threshold; |
| 97 | + } |
| 98 | + } |
| 99 | + |
| 100 | + /** |
| 101 | + * Follows the methodology of section 3.1 in "VBase: Unifying Online Vector Similarity Search |
| 102 | + * and Relational Queries via Relaxed Monotonicity" to determine when we've left phase 1 |
| 103 | + * (finding the local maximum) and entered phase 2 (mostly just finding worse options) |
| 104 | + * To compute quantiles quickly, we treat the distribution of the data as Normal, |
| 105 | + * track its mean and variance, and compute quantiles from them as: |
| 106 | + * mean + SIGMA_FACTOR * sqrt(variance) |
| 107 | + * Empirically, SIGMA_FACTOR=1.75 seems to work reasonably well |
| 108 | + * (approximately the 96th percentile of the Normal distribution). |
| 109 | + */ |
| 110 | + class RelaxedMonotonicityTracker implements ScoreTracker { |
| 111 | + static final double SIGMA_FACTOR = 1.75; |
| 112 | + |
| 113 | + // a sliding window of recent scores |
| 114 | + private final double[] recentScores; |
| 115 | + private int recentEntryIndex; |
| 116 | + |
| 117 | + // Heap of the best scores seen so far |
| 118 | + BoundedLongHeap bestScores; |
| 119 | + |
| 120 | + // observation count |
| 121 | + private int observationCount; |
| 122 | + |
| 123 | + // the sample mean |
| 124 | + private double mean; |
| 125 | + |
| 126 | + // the sample variance multiplied by n-1 |
| 127 | + private double dSquared; |
| 128 | + |
| 129 | + /** |
| 130 | + * Constructor |
| 131 | + * @param bestScoresTracked the number of tracked scores used to estimate if we are unlikely to improve |
| 132 | + * the results anymore. An empirical rule of thumb is bestScoresTracked=rerankK. |
| 133 | + */ |
| 134 | + RelaxedMonotonicityTracker(int bestScoresTracked) { |
| 135 | + // A quick empirical study yields that the number of recent scores |
| 136 | + // that we need to consider grows by a factor of ~sqrt(bestScoresTracked / 2) |
| 137 | + int factor = (int) Math.round(Math.sqrt(bestScoresTracked / 2.0)); |
| 138 | + this.recentScores = new double[200 * factor]; |
| 139 | + this.bestScores = new BoundedLongHeap(bestScoresTracked); |
| 140 | + this.mean = 0; |
| 141 | + this.dSquared = 0; |
| 142 | + } |
| 143 | + |
| 144 | + @Override |
| 145 | + public void track(float score) { |
| 146 | + bestScores.push(floatToSortableInt(score)); |
| 147 | + observationCount++; |
| 148 | + |
| 149 | + // The updates of the sufficient statistics follow |
| 150 | + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online |
| 151 | + // and |
| 152 | + // https://nestedsoftware.com/2019/09/26/incremental-average-and-standard-deviation-with-sliding-window-470k.176143.html |
| 153 | + if (observationCount <= this.recentScores.length) { |
| 154 | + // if the buffer is not full yet, use standard Welford method |
| 155 | + var meanDelta = (score - this.mean) / observationCount; |
| 156 | + var newMean = this.mean + meanDelta; |
| 157 | + |
| 158 | + var dSquaredDelta = ((score - newMean) * (score - this.mean)); |
| 159 | + var newDSquared = this.dSquared + dSquaredDelta; |
| 160 | + |
| 161 | + this.mean = newMean; |
| 162 | + this.dSquared = newDSquared; |
| 163 | + } else { |
| 164 | + // once the buffer is full, adjust Welford method for window size |
| 165 | + var oldScore = recentScores[recentEntryIndex]; |
| 166 | + var meanDelta = (score - oldScore) / this.recentScores.length; |
| 167 | + var newMean = this.mean + meanDelta; |
| 168 | + |
| 169 | + var dSquaredDelta = ((score - oldScore) * (score - newMean + oldScore - this.mean)); |
| 170 | + var newDSquared = this.dSquared + dSquaredDelta; |
| 171 | + |
| 172 | + this.mean = newMean; |
| 173 | + this.dSquared = newDSquared; |
| 174 | + } |
| 175 | + recentScores[recentEntryIndex] = score; |
| 176 | + recentEntryIndex = (recentEntryIndex + 1) % this.recentScores.length; |
| 177 | + } |
| 178 | + |
| 179 | + @Override |
| 180 | + public boolean shouldStop() { |
| 181 | + // don't stop if we don't have enough data points |
| 182 | + if (observationCount < this.recentScores.length) { |
| 183 | + return false; |
| 184 | + } |
| 185 | + |
| 186 | + // We're in phase 2 if the q-th percentile of the recent scores evaluated, |
| 187 | + // mean + SIGMA_FACTOR * sqrt(variance), |
| 188 | + // is lower than the worst of the best scores seen. |
| 189 | + // (paper suggests using the median of recent scores, but experimentally that is too prone to false positives) |
| 190 | + double std = Math.sqrt(this.dSquared / (this.recentScores.length - 1)); |
| 191 | + double windowPercentile = this.mean + SIGMA_FACTOR * std; |
| 192 | + double worstBestScore = sortableIntToFloat((int) bestScores.top()); |
| 193 | + return windowPercentile < worstBestScore; |
98 | 194 | } |
99 | 195 | } |
100 | 196 | } |
0 commit comments