diff --git a/.github/workflows/all.yml b/.github/workflows/all.yml index c985d2f..64698de 100644 --- a/.github/workflows/all.yml +++ b/.github/workflows/all.yml @@ -214,14 +214,11 @@ jobs: LD_PRELOAD=$(clang -print-file-name=libclang_rt.asan-x86_64.so) \ tox - run-benchmarks: - strategy: - matrix: - os: ['ubuntu-latest'] - runs-on: ${{ matrix.os }} + run-python-benchmarks: + runs-on: ubuntu-latest continue-on-error: true if: github.event_name == 'pull_request' - name: Run benchmarks on ${{ matrix.os }} + name: Run Python benchmarks steps: - name: Set up Python uses: actions/setup-python@v3 @@ -244,6 +241,55 @@ jobs: asv machine --yes asv continuous --sort name --no-only-changed refs/remotes/origin/main ${{ github.sha }} | tee >(sed '1,/All benchmarks:/d' > $GITHUB_STEP_SUMMARY) + run-java-benchmarks: + runs-on: ubuntu-latest + continue-on-error: true + if: github.event_name == 'pull_request' + name: Run Java benchmarks + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: Set up JDK 17 + uses: actions/setup-java@v3 + with: + java-version: '17' + distribution: 'corretto' + + # Benchmark the base branch first + - name: Checkout base branch + run: git checkout ${{ github.event.pull_request.base.sha }} + - name: Build native library (base) + working-directory: java + run: make + continue-on-error: true + - name: Compile benchmarks (base) + working-directory: java + run: mvn --batch-mode test-compile + continue-on-error: true + - name: Run benchmarks (base) + working-directory: java + run: mvn --batch-mode exec:exec -Dexec.executable=java -Dexec.classpathScope=test -Dexec.args="-cp %classpath org.openjdk.jmh.Main -rf json -rff /tmp/base-results.json -f 1 -wi 2 -i 3 -w 2s -r 2s -jvmArgs -Xms2g -jvmArgs -Xmx2g" + continue-on-error: true + + # Benchmark the PR branch + - name: Checkout PR branch + run: git checkout ${{ github.sha }} + - name: Build native library (PR) + working-directory: java + run: make clean && make + - name: Compile benchmarks (PR) + working-directory: java + run: mvn clean --batch-mode test-compile + - name: Run benchmarks (PR) + working-directory: java + run: mvn --batch-mode exec:exec -Dexec.executable=java -Dexec.classpathScope=test -Dexec.args="-cp %classpath org.openjdk.jmh.Main -rf json -rff /tmp/pr-results.json -f 1 -wi 2 -i 3 -w 2s -r 2s -jvmArgs -Xms2g -jvmArgs -Xmx2g" + + # Compare and report + - name: Compare benchmarks + run: python3 java/scripts/compare_benchmarks.py /tmp/base-results.json /tmp/pr-results.json >> $GITHUB_STEP_SUMMARY + continue-on-error: true + build-python-sdist: needs: [run-python-tests, run-python-tests-with-address-sanitizer] continue-on-error: false diff --git a/java/pom.xml b/java/pom.xml index f8eafd5..aabf9a0 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -103,6 +103,18 @@ 1.3.2 test + + org.openjdk.jmh + jmh-core + 1.37 + test + + + org.openjdk.jmh + jmh-generator-annprocess + 1.37 + test + @@ -134,6 +146,13 @@ 1.8 1.8 + + + org.openjdk.jmh + jmh-generator-annprocess + 1.37 + + diff --git a/java/scripts/compare_benchmarks.py b/java/scripts/compare_benchmarks.py new file mode 100644 index 0000000..e0a8477 --- /dev/null +++ b/java/scripts/compare_benchmarks.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +"""Compare two JMH JSON benchmark result files and output a markdown summary. + +Usage: + python3 compare_benchmarks.py + +If the base results file does not exist (e.g. when benchmarks are first added), +only the PR results are printed. + +Uses only Python standard library (json, sys, os). +""" + +import json +import os +import sys + + +def load_results(path): + """Load JMH JSON results and return a dict keyed by benchmark name + params.""" + with open(path) as f: + data = json.load(f) + + results = {} + for entry in data: + benchmark = entry["benchmark"] + # Extract short method name from fully qualified name + short_name = benchmark.rsplit(".", 1)[-1] + + params = entry.get("params", {}) + param_key = ", ".join(f"{k}={v}" for k, v in sorted(params.items())) + + key = f"{short_name}({param_key})" if param_key else short_name + + score = entry["primaryMetric"]["score"] + error = entry["primaryMetric"]["scoreError"] + unit = entry["primaryMetric"]["scoreUnit"] + + results[key] = {"score": score, "error": error, "unit": unit} + + return results + + +def format_score(score, error): + """Format a score with error margin.""" + return f"{score:.3f} \u00b1 {error:.3f}" + + +def main(): + if len(sys.argv) < 3: + print(f"Usage: {sys.argv[0]} ", file=sys.stderr) + sys.exit(1) + + base_path = sys.argv[1] + pr_path = sys.argv[2] + + if not os.path.exists(pr_path): + print("Error: PR results file not found.", file=sys.stderr) + sys.exit(1) + + pr_results = load_results(pr_path) + + if not os.path.exists(base_path): + # Base results don't exist yet (first PR adding benchmarks) + print("## Java Benchmark Results\n") + print("_No base branch results available for comparison._\n") + print("| Benchmark | Score | Unit |") + print("|-----------|-------|------|") + for name in sorted(pr_results.keys()): + r = pr_results[name] + print(f"| {name} | {format_score(r['score'], r['error'])} | {r['unit']} |") + return + + base_results = load_results(base_path) + + print("## Java Benchmark Comparison\n") + print("| Benchmark | Base | PR | Delta | Status |") + print("|-----------|------|-----|-------|--------|") + + all_keys = sorted(set(list(base_results.keys()) + list(pr_results.keys()))) + + for name in all_keys: + if name not in base_results: + r = pr_results[name] + print( + f"| {name} | _new_ | {format_score(r['score'], r['error'])} {r['unit']}" + f" | - | \U0001f195 |" + ) + continue + + if name not in pr_results: + r = base_results[name] + print( + f"| {name} | {format_score(r['score'], r['error'])} {r['unit']}" + f" | _removed_ | - | - |" + ) + continue + + base = base_results[name] + pr = pr_results[name] + + if base["score"] == 0: + delta_pct = 0.0 + else: + delta_pct = ((pr["score"] - base["score"]) / base["score"]) * 100 + + # Determine if the change is significant by comparing against combined error margins + combined_error = base["error"] + pr["error"] + abs_diff = abs(pr["score"] - base["score"]) + + if abs_diff > combined_error: + # For time-based benchmarks, lower is better + if pr["score"] < base["score"]: + status = "\u2705 faster" + else: + status = "\u26a0\ufe0f slower" + else: + status = "\u2194\ufe0f unchanged" + + print( + f"| {name}" + f" | {format_score(base['score'], base['error'])}" + f" | {format_score(pr['score'], pr['error'])}" + f" | {delta_pct:+.1f}%" + f" | {status} |" + ) + + +if __name__ == "__main__": + main() diff --git a/java/src/test/java/com/spotify/voyager/jni/IndexCreationBenchmark.java b/java/src/test/java/com/spotify/voyager/jni/IndexCreationBenchmark.java new file mode 100644 index 0000000..45468ed --- /dev/null +++ b/java/src/test/java/com/spotify/voyager/jni/IndexCreationBenchmark.java @@ -0,0 +1,99 @@ +/*- + * -\-\- + * voyager + * -- + * Copyright (C) 2016 - 2023 Spotify AB + * -- + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * -/-/- + */ + +package com.spotify.voyager.jni; + +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import org.openjdk.jmh.annotations.*; + +/** + * JMH benchmarks for index creation performance. + * + *

Mirrors the Python benchmark in benchmarks/index_creation.py. Measures the time to add 1024 + * random vectors of 256 dimensions to a fresh index, parameterized over space type and storage data + * type. + */ +@State(Scope.Benchmark) +@BenchmarkMode(Mode.SingleShotTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Fork(2) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +public class IndexCreationBenchmark { + + @Param({"256"}) + public int numDimensions; + + @Param({"1024"}) + public int numElements; + + @Param({"Euclidean", "InnerProduct", "Cosine"}) + public String spaceType; + + @Param({"Float32", "Float8", "E4M3"}) + public String storageDataType; + + @Param({"24"}) + public int efConstruction; + + private static final int M = 20; + private static final long RANDOM_SEED = 4321; + + private float[][] inputData; + private Index index; + + @Setup(Level.Trial) + public void generateData() { + Random rng = new Random(1234); + inputData = new float[numElements][numDimensions]; + boolean isFloat8 = "Float8".equals(storageDataType); + + for (int i = 0; i < numElements; i++) { + for (int j = 0; j < numDimensions; j++) { + float val = rng.nextFloat() * 2 - 1; + if (isFloat8) { + val = Math.round(val * 127f) / 127f; + } + inputData[i][j] = val; + } + } + } + + @Setup(Level.Invocation) + public void createFreshIndex() { + Index.SpaceType space = Index.SpaceType.valueOf(spaceType); + Index.StorageDataType storage = Index.StorageDataType.valueOf(storageDataType); + index = new Index(space, numDimensions, M, efConstruction, RANDOM_SEED, numElements, storage); + } + + @TearDown(Level.Invocation) + public void closeIndex() throws IOException { + if (index != null) { + index.close(); + } + } + + @Benchmark + public void addItems() { + index.addItems(inputData, 1); + } +} diff --git a/java/src/test/java/com/spotify/voyager/jni/IndexQueryBenchmark.java b/java/src/test/java/com/spotify/voyager/jni/IndexQueryBenchmark.java new file mode 100644 index 0000000..963e82b --- /dev/null +++ b/java/src/test/java/com/spotify/voyager/jni/IndexQueryBenchmark.java @@ -0,0 +1,108 @@ +/*- + * -\-\- + * voyager + * -- + * Copyright (C) 2016 - 2023 Spotify AB + * -- + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * -/-/- + */ + +package com.spotify.voyager.jni; + +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +/** + * JMH benchmarks for index query performance. + * + *

Mirrors the Python benchmark in benchmarks/index_query.py. Queries a pre-populated index of + * 4096 random vectors with 256 dimensions, parameterized over space type, storage data type, and k. + */ +@State(Scope.Benchmark) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Fork(2) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +public class IndexQueryBenchmark { + + @Param({"256"}) + public int numDimensions; + + @Param({"4096"}) + public int numElements; + + @Param({"Euclidean", "InnerProduct", "Cosine"}) + public String spaceType; + + @Param({"Float32", "Float8", "E4M3"}) + public String storageDataType; + + @Param({"24"}) + public int efConstruction; + + private static final int M = 20; + private static final long RANDOM_SEED = 4321; + + private Index index; + private float[][] queryVectors; + + @Setup(Level.Trial) + public void buildIndex() { + Random rng = new Random(1234); + boolean isFloat8 = "Float8".equals(storageDataType); + + float[][] inputData = new float[numElements][numDimensions]; + for (int i = 0; i < numElements; i++) { + for (int j = 0; j < numDimensions; j++) { + float val = rng.nextFloat() * 2 - 1; + if (isFloat8) { + val = Math.round(val * 127f) / 127f; + } + inputData[i][j] = val; + } + } + + Index.SpaceType space = Index.SpaceType.valueOf(spaceType); + Index.StorageDataType storage = Index.StorageDataType.valueOf(storageDataType); + index = new Index(space, numDimensions, M, efConstruction, RANDOM_SEED, numElements, storage); + index.addItems(inputData, 1); + + queryVectors = inputData; + } + + @TearDown(Level.Trial) + public void closeIndex() throws IOException { + if (index != null) { + index.close(); + } + } + + @Benchmark + public void queryK1(Blackhole bh) { + for (float[] queryVector : queryVectors) { + bh.consume(index.query(queryVector, 1)); + } + } + + @Benchmark + public void queryK20(Blackhole bh) { + for (float[] queryVector : queryVectors) { + bh.consume(index.query(queryVector, 20)); + } + } +} diff --git a/java/src/test/java/com/spotify/voyager/jni/StringIndexCreationBenchmark.java b/java/src/test/java/com/spotify/voyager/jni/StringIndexCreationBenchmark.java new file mode 100644 index 0000000..7320a79 --- /dev/null +++ b/java/src/test/java/com/spotify/voyager/jni/StringIndexCreationBenchmark.java @@ -0,0 +1,125 @@ +/*- + * -\-\- + * voyager + * -- + * Copyright (C) 2016 - 2023 Spotify AB + * -- + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * -/-/- + */ + +package com.spotify.voyager.jni; + +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import org.openjdk.jmh.annotations.*; + +/** + * JMH benchmarks for StringIndex creation performance. + * + *

Mirrors {@link IndexCreationBenchmark} but uses {@link StringIndex}, which wraps {@link Index} + * with a string-name-to-numeric-ID mapping layer. Measures the overhead of the StringIndex wrapper + * during item insertion. + */ +@State(Scope.Benchmark) +@BenchmarkMode(Mode.SingleShotTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Fork(2) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +public class StringIndexCreationBenchmark { + + @Param({"256"}) + public int numDimensions; + + @Param({"1024"}) + public int numElements; + + @Param({"Euclidean", "InnerProduct", "Cosine"}) + public String spaceType; + + @Param({"Float32", "Float8", "E4M3"}) + public String storageDataType; + + @Param({"24"}) + public int efConstruction; + + private static final int M = 20; + private static final long RANDOM_SEED = 4321; + private static final int NAME_LENGTH = 22; + private static final String NAME_PREFIX = "spotify:track:"; + + private float[][] inputData; + private String[] itemNames; + private StringIndex index; + + @Setup(Level.Trial) + public void generateData() { + Random rng = new Random(1234); + inputData = new float[numElements][numDimensions]; + boolean isFloat8 = "Float8".equals(storageDataType); + + for (int i = 0; i < numElements; i++) { + for (int j = 0; j < numDimensions; j++) { + float val = rng.nextFloat() * 2 - 1; + if (isFloat8) { + val = Math.round(val * 127f) / 127f; + } + inputData[i][j] = val; + } + } + + Random nameRng = new Random(5678); + itemNames = new String[numElements]; + for (int i = 0; i < numElements; i++) { + itemNames[i] = NAME_PREFIX + randomBase62String(nameRng, NAME_LENGTH); + } + } + + @Setup(Level.Invocation) + public void createFreshIndex() { + Index.SpaceType space = Index.SpaceType.valueOf(spaceType); + Index.StorageDataType storage = Index.StorageDataType.valueOf(storageDataType); + index = + new StringIndex(space, numDimensions, M, efConstruction, RANDOM_SEED, numElements, storage); + } + + @TearDown(Level.Invocation) + public void closeIndex() throws IOException { + if (index != null) { + index.close(); + } + } + + /** + * Adds items to the StringIndex one at a time using {@link StringIndex#addItem(String, float[])}. + */ + @Benchmark + public void addItems() { + for (int i = 0; i < numElements; i++) { + index.addItem(itemNames[i], inputData[i]); + } + } + + private static final String BASE62 = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + + private static String randomBase62String(Random rng, int length) { + char[] chars = new char[length]; + for (int i = 0; i < length; i++) { + chars[i] = BASE62.charAt(rng.nextInt(BASE62.length())); + } + return new String(chars); + } +} diff --git a/java/src/test/java/com/spotify/voyager/jni/StringIndexQueryBenchmark.java b/java/src/test/java/com/spotify/voyager/jni/StringIndexQueryBenchmark.java new file mode 100644 index 0000000..5375acf --- /dev/null +++ b/java/src/test/java/com/spotify/voyager/jni/StringIndexQueryBenchmark.java @@ -0,0 +1,142 @@ +/*- + * -\-\- + * voyager + * -- + * Copyright (C) 2016 - 2023 Spotify AB + * -- + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * -/-/- + */ + +package com.spotify.voyager.jni; + +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +/** + * JMH benchmarks for StringIndex query performance. + * + *

Mirrors {@link IndexQueryBenchmark} but uses {@link StringIndex}, which wraps {@link Index} + * with a string-name-to-numeric-ID mapping layer. Queries a pre-populated index of 4096 random + * vectors with 256 dimensions, parameterized over space type and storage data type. + */ +@State(Scope.Benchmark) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Fork(2) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +public class StringIndexQueryBenchmark { + + @Param({"256"}) + public int numDimensions; + + @Param({"4096"}) + public int numElements; + + @Param({"Euclidean", "InnerProduct", "Cosine"}) + public String spaceType; + + @Param({"Float32", "Float8", "E4M3"}) + public String storageDataType; + + @Param({"24"}) + public int efConstruction; + + private static final int M = 20; + private static final long RANDOM_SEED = 4321; + private static final int DEFAULT_EF = -1; + private static final int NAME_LENGTH = 22; + private static final String NAME_PREFIX = "spotify:track:"; + + private StringIndex index; + private float[][] queryVectors; + + @Setup(Level.Trial) + public void buildIndex() { + Random rng = new Random(1234); + boolean isFloat8 = "Float8".equals(storageDataType); + + float[][] inputData = new float[numElements][numDimensions]; + for (int i = 0; i < numElements; i++) { + for (int j = 0; j < numDimensions; j++) { + float val = rng.nextFloat() * 2 - 1; + if (isFloat8) { + val = Math.round(val * 127f) / 127f; + } + inputData[i][j] = val; + } + } + + Random nameRng = new Random(5678); + String[] itemNames = new String[numElements]; + for (int i = 0; i < numElements; i++) { + itemNames[i] = NAME_PREFIX + randomBase62String(nameRng, NAME_LENGTH); + } + + Index.SpaceType space = Index.SpaceType.valueOf(spaceType); + Index.StorageDataType storage = Index.StorageDataType.valueOf(storageDataType); + index = + new StringIndex(space, numDimensions, M, efConstruction, RANDOM_SEED, numElements, storage); + for (int i = 0; i < numElements; i++) { + index.addItem(itemNames[i], inputData[i]); + } + + queryVectors = inputData; + } + + @TearDown(Level.Trial) + public void closeIndex() throws IOException { + if (index != null) { + index.close(); + } + } + + /** + * Queries the StringIndex for 1 nearest neighbor per vector. + * + * @param bh Blackhole to prevent dead-code elimination + */ + @Benchmark + public void queryK1(Blackhole bh) { + for (float[] queryVector : queryVectors) { + bh.consume(index.query(queryVector, 1, DEFAULT_EF)); + } + } + + /** + * Queries the StringIndex for 20 nearest neighbors per vector. + * + * @param bh Blackhole to prevent dead-code elimination + */ + @Benchmark + public void queryK20(Blackhole bh) { + for (float[] queryVector : queryVectors) { + bh.consume(index.query(queryVector, 20, DEFAULT_EF)); + } + } + + private static final String BASE62 = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + + private static String randomBase62String(Random rng, int length) { + char[] chars = new char[length]; + for (int i = 0; i < length; i++) { + chars[i] = BASE62.charAt(rng.nextInt(BASE62.length())); + } + return new String(chars); + } +}