Skip to content

Commit 0af609a

Browse files
anirudlappathi and anirudlappathi authored
Improve: bf16 API for index_dense (#561)
Closes #553

Co-authored-by: anirudlappathi <anirud.lappathi+anirudlappathi@users.noreply.github.com>
1 parent ed1cc8f commit 0af609a

File tree

3 files changed

+26
-5
lines changed

3 files changed

+26
-5
lines changed

cpp/test.cpp

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -872,16 +872,18 @@ void test_absurd(std::size_t dimensions, std::size_t connectivity, std::size_t e
872872
* @param dataset_count Number of vectors in the dataset.
873873
* @param queries_count Number of query vectors.
874874
* @param wanted_count Number of top matches required from each query.
875+
* @tparam scalar_at Data type of the elements in the vectors.
875876
*/
877+
template <typename scalar_at>
876878
void test_exact_search(std::size_t dataset_count, std::size_t queries_count, std::size_t wanted_count) {
877879
std::size_t dimensions = 32;
878880
metric_punned_t metric(dimensions, metric_kind_t::cos_k);
879881

880882
std::random_device rd;
881883
std::mt19937 gen(rd());
882884
std::uniform_real_distribution<> dis(0.0, 1.0);
883-
std::vector<float> dataset(dataset_count * dimensions);
884-
std::generate(dataset.begin(), dataset.end(), [&] { return dis(gen); });
885+
std::vector<scalar_at> dataset(dataset_count * dimensions);
886+
std::generate(dataset.begin(), dataset.end(), [&] { return static_cast<scalar_at>(dis(gen)); });
885887

886888
exact_search_t search;
887889
auto results = search( //
@@ -1099,6 +1101,7 @@ template <typename key_at, typename slot_at> void test_replacing_update() {
10991101
int main(int, char**) {
11001102
test_uint40();
11011103
test_cosine<float, std::int64_t, uint40_t>(10, 10);
1104+
test_cosine<bf16_t, std::int64_t, uint40_t>(10, 10);
11021105

11031106
// Test plugins, like K-Means clustering.
11041107
{
@@ -1121,8 +1124,10 @@ int main(int, char**) {
11211124
std::printf("Testing exact search\n");
11221125
for (std::size_t dataset_count : {10, 100})
11231126
for (std::size_t queries_count : {1, 10})
1124-
for (std::size_t wanted_count : {1, 5})
1125-
test_exact_search(dataset_count, queries_count, wanted_count);
1127+
for (std::size_t wanted_count : {1, 5}) {
1128+
test_exact_search<float>(dataset_count, queries_count, wanted_count);
1129+
test_exact_search<bf16_t>(dataset_count, queries_count, wanted_count);
1130+
}
11261131

11271132
// Make sure the initializers and the algorithms can work with inadequately small values.
11281133
// Be warned - this combinatorial explosion of tests produces close to __500'000__ tests!
@@ -1149,6 +1154,8 @@ int main(int, char**) {
11491154
test_cosine<float, std::int64_t, slot32_t>(collection_size, dimensions);
11501155
std::printf("- Indexing %zu vectors with cos: <float, std::int64_t, uint40_t> \n", collection_size);
11511156
test_cosine<float, std::int64_t, uint40_t>(collection_size, dimensions);
1157+
std::printf("- Indexing %zu vectors with cos: <bf16, std::int64_t, uint40_t> \n", collection_size);
1158+
test_cosine<bf16_t, std::int64_t, uint40_t>(collection_size, dimensions);
11521159
}
11531160

11541161
// Test with binary vectors

include/usearch/index_dense.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,36 +760,42 @@ class index_dense_gt {
760760
add_result_t add(vector_key_t key, b1x8_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.b1x8); }
761761
add_result_t add(vector_key_t key, i8_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.i8); }
762762
add_result_t add(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.f16); }
763+
add_result_t add(vector_key_t key, bf16_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.bf16); }
763764
add_result_t add(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.f32); }
764765
add_result_t add(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.f64); }
765766

766767
search_result_t search(b1x8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.b1x8); }
767768
search_result_t search(i8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.i8); }
768769
search_result_t search(f16_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.f16); }
770+
search_result_t search(bf16_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.bf16); }
769771
search_result_t search(f32_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.f32); }
770772
search_result_t search(f64_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.f64); }
771773

772774
template <typename predicate_at> search_result_t filtered_search(b1x8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from.b1x8); }
773775
template <typename predicate_at> search_result_t filtered_search(i8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from.i8); }
774776
template <typename predicate_at> search_result_t filtered_search(f16_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from.f16); }
777+
template <typename predicate_at> search_result_t filtered_search(bf16_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from.bf16); }
775778
template <typename predicate_at> search_result_t filtered_search(f32_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from.f32); }
776779
template <typename predicate_at> search_result_t filtered_search(f64_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from.f64); }
777780

778781
std::size_t get(vector_key_t key, b1x8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.b1x8); }
779782
std::size_t get(vector_key_t key, i8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.i8); }
780783
std::size_t get(vector_key_t key, f16_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.f16); }
784+
std::size_t get(vector_key_t key, bf16_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.bf16); }
781785
std::size_t get(vector_key_t key, f32_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.f32); }
782786
std::size_t get(vector_key_t key, f64_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.f64); }
783787

784788
cluster_result_t cluster(b1x8_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.b1x8); }
785789
cluster_result_t cluster(i8_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.i8); }
786790
cluster_result_t cluster(f16_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.f16); }
791+
cluster_result_t cluster(bf16_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.bf16); }
787792
cluster_result_t cluster(f32_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.f32); }
788793
cluster_result_t cluster(f64_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.f64); }
789794

790795
aggregated_distances_t distance_between(vector_key_t key, b1x8_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.b1x8); }
791796
aggregated_distances_t distance_between(vector_key_t key, i8_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.i8); }
792797
aggregated_distances_t distance_between(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.f16); }
798+
aggregated_distances_t distance_between(vector_key_t key, bf16_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.bf16); }
793799
aggregated_distances_t distance_between(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.f32); }
794800
aggregated_distances_t distance_between(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.f64); }
795801
// clang-format on

include/usearch/index_plugins.hpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,11 @@ class bf16_bits_t {
581581
uint16_ = f32_to_bf16(v / bf16_to_f32(uint16_));
582582
return *this;
583583
}
584+
585+
inline bf16_bits_t& operator=(float v) noexcept {
586+
uint16_ = f32_to_bf16(v);
587+
return *this;
588+
}
584589
};
585590

586591
/**
@@ -1223,6 +1228,7 @@ struct casts_punned_t {
12231228
cast_punned_t b1x8{};
12241229
cast_punned_t i8{};
12251230
cast_punned_t f16{};
1231+
cast_punned_t bf16{};
12261232
cast_punned_t f32{};
12271233
cast_punned_t f64{};
12281234

@@ -1231,7 +1237,7 @@ struct casts_punned_t {
12311237
case scalar_kind_t::f64_k: return f64;
12321238
case scalar_kind_t::f32_k: return f32;
12331239
case scalar_kind_t::f16_k: return f16;
1234-
case scalar_kind_t::bf16_k: return f16;
1240+
case scalar_kind_t::bf16_k: return bf16;
12351241
case scalar_kind_t::i8_k: return i8;
12361242
case scalar_kind_t::b1x8_k: return b1x8;
12371243
default: return nullptr;
@@ -1246,12 +1252,14 @@ struct casts_punned_t {
12461252
result.from.b1x8 = &cast_gt<b1x8_t, scalar_at>::try_;
12471253
result.from.i8 = &cast_gt<i8_t, scalar_at>::try_;
12481254
result.from.f16 = &cast_gt<f16_t, scalar_at>::try_;
1255+
result.from.bf16 = &cast_gt<bf16_t, scalar_at>::try_;
12491256
result.from.f32 = &cast_gt<f32_t, scalar_at>::try_;
12501257
result.from.f64 = &cast_gt<f64_t, scalar_at>::try_;
12511258

12521259
result.to.b1x8 = &cast_gt<scalar_at, b1x8_t>::try_;
12531260
result.to.i8 = &cast_gt<scalar_at, i8_t>::try_;
12541261
result.to.f16 = &cast_gt<scalar_at, f16_t>::try_;
1262+
result.to.bf16 = &cast_gt<scalar_at, bf16_t>::try_;
12551263
result.to.f32 = &cast_gt<scalar_at, f32_t>::try_;
12561264
result.to.f64 = &cast_gt<scalar_at, f64_t>::try_;
12571265

0 commit comments

Comments
 (0)