@@ -1098,6 +1098,51 @@ template <typename key_at, typename slot_at> void test_replacing_update() {
10981098 expect_eq (final_search[2 ].member .key , 44 );
10991099}
11001100
1101+ /* *
1102+ * Tests the filtered search functionality of the index.
1103+ */
1104+ void test_filtered_search () {
1105+ constexpr std::size_t dataset_count = 2048 ;
1106+ constexpr std::size_t dimensions = 32 ;
1107+ metric_punned_t metric (dimensions, metric_kind_t ::cos_k);
1108+
1109+ std::random_device rd;
1110+ std::mt19937 gen (rd ());
1111+ std::uniform_real_distribution<> dis (0.0 , 1.0 );
1112+ using vector_of_vectors_t = std::vector<std::vector<float >>;
1113+
1114+ vector_of_vectors_t vector_of_vectors (dataset_count);
1115+ for (auto & vector : vector_of_vectors) {
1116+ vector.resize (dimensions);
1117+ std::generate (vector.begin (), vector.end (), [&] { return dis (gen); });
1118+ }
1119+
1120+ index_dense_t index = index_dense_t::make (metric);
1121+ index.reserve (dataset_count);
1122+ for (std::size_t idx = 0 ; idx < dataset_count; ++idx)
1123+ index.add (idx, vector_of_vectors[idx].data ());
1124+ expect_eq (index.size (), dataset_count);
1125+
1126+ {
1127+ auto predicate = [](index_dense_t ::key_t key) { return key != 0 ; };
1128+ auto results = index.filtered_search (vector_of_vectors[0 ].data (), 10 , predicate);
1129+ expect_eq (10 , results.size ()); // ! Should not contain 0
1130+ for (std::size_t i = 0 ; i != results.size (); ++i)
1131+ expect (0 != results[i].member .key );
1132+ }
1133+ {
1134+ auto predicate = [](index_dense_t ::key_t ) { return false ; };
1135+ auto results = index.filtered_search (vector_of_vectors[0 ].data (), 10 , predicate);
1136+ expect_eq (0 , results.size ()); // ! Should not contain 0
1137+ }
1138+ {
1139+ auto predicate = [](index_dense_t ::key_t key) { return key == 10 ; };
1140+ auto results = index.filtered_search (vector_of_vectors[0 ].data (), 10 , predicate);
1141+ expect_eq (1 , results.size ()); // ! Should not contain 0
1142+ expect_eq (10 , results[0 ].member .key );
1143+ }
1144+ }
1145+
11011146int main (int , char **) {
11021147 test_uint40 ();
11031148 test_cosine<float , std::int64_t , uint40_t >(10 , 10 );
@@ -1174,5 +1219,6 @@ int main(int, char**) {
11741219 test_sets<std::int64_t , slot32_t >(set_size, 20 , 30 );
11751220 test_strings<std::int64_t , slot32_t >();
11761221
1222+ test_filtered_search ();
11771223 return 0 ;
11781224}
0 commit comments