Skip to content

Commit d471fd9

Browse files
committed
Merge branch 'main-dev' of https://github.com/unum-cloud/usearch into main-dev
2 parents a394983 + 369a553 commit d471fd9

File tree

3 files changed

+51
-3
lines changed

3 files changed

+51
-3
lines changed

cpp/test.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,51 @@ template <typename key_at, typename slot_at> void test_replacing_update() {
10981098
expect_eq(final_search[2].member.key, 44);
10991099
}
11001100

1101+
/**
1102+
* Tests the filtered search functionality of the index.
1103+
*/
1104+
void test_filtered_search() {
1105+
constexpr std::size_t dataset_count = 2048;
1106+
constexpr std::size_t dimensions = 32;
1107+
metric_punned_t metric(dimensions, metric_kind_t::cos_k);
1108+
1109+
std::random_device rd;
1110+
std::mt19937 gen(rd());
1111+
std::uniform_real_distribution<> dis(0.0, 1.0);
1112+
using vector_of_vectors_t = std::vector<std::vector<float>>;
1113+
1114+
vector_of_vectors_t vector_of_vectors(dataset_count);
1115+
for (auto& vector : vector_of_vectors) {
1116+
vector.resize(dimensions);
1117+
std::generate(vector.begin(), vector.end(), [&] { return dis(gen); });
1118+
}
1119+
1120+
index_dense_t index = index_dense_t::make(metric);
1121+
index.reserve(dataset_count);
1122+
for (std::size_t idx = 0; idx < dataset_count; ++idx)
1123+
index.add(idx, vector_of_vectors[idx].data());
1124+
expect_eq(index.size(), dataset_count);
1125+
1126+
{
1127+
auto predicate = [](index_dense_t::key_t key) { return key != 0; };
1128+
auto results = index.filtered_search(vector_of_vectors[0].data(), 10, predicate);
1129+
expect_eq(10, results.size()); // ! Should not contain 0
1130+
for (std::size_t i = 0; i != results.size(); ++i)
1131+
expect(0 != results[i].member.key);
1132+
}
1133+
{
1134+
auto predicate = [](index_dense_t::key_t) { return false; };
1135+
auto results = index.filtered_search(vector_of_vectors[0].data(), 10, predicate);
1136+
expect_eq(0, results.size()); // ! Should not contain 0
1137+
}
1138+
{
1139+
auto predicate = [](index_dense_t::key_t key) { return key == 10; };
1140+
auto results = index.filtered_search(vector_of_vectors[0].data(), 10, predicate);
1141+
expect_eq(1, results.size()); // ! Should not contain 0
1142+
expect_eq(10, results[0].member.key);
1143+
}
1144+
}
1145+
11011146
int main(int, char**) {
11021147
test_uint40();
11031148
test_cosine<float, std::int64_t, uint40_t>(10, 10);
@@ -1174,5 +1219,6 @@ int main(int, char**) {
11741219
test_sets<std::int64_t, slot32_t>(set_size, 20, 30);
11751220
test_strings<std::int64_t, slot32_t>();
11761221

1222+
test_filtered_search();
11771223
return 0;
11781224
}

include/usearch/index.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4178,9 +4178,10 @@ class index_gt {
41784178
// This can substantially grow our priority queue:
41794179
next.insert({-successor_dist, successor_slot});
41804180
if (is_dummy<predicate_at>() ||
4181-
predicate(member_cref_t{node_at_(successor_slot).ckey(), successor_slot}))
4181+
predicate(member_cref_t{node_at_(successor_slot).ckey(), successor_slot})) {
41824182
top.insert({successor_dist, successor_slot}, top_limit);
4183-
radius = top.top().distance;
4183+
radius = top.top().distance;
4184+
}
41844185
}
41854186
}
41864187
}

python/usearch/index.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ def metadata(path_or_buffer: PathOrBuffer) -> Optional[dict]:
604604
raise e
605605

606606
@staticmethod
607-
def restore(path_or_buffer: PathOrBuffer, view: bool = False) -> Optional[Index]:
607+
def restore(path_or_buffer: PathOrBuffer, view: bool = False, **kwargs) -> Optional[Index]:
608608
meta = Index.metadata(path_or_buffer)
609609
if not meta:
610610
return None
@@ -613,6 +613,7 @@ def restore(path_or_buffer: PathOrBuffer, view: bool = False) -> Optional[Index]
613613
ndim=meta["dimensions"],
614614
dtype=meta["kind_scalar"],
615615
metric=meta["kind_metric"],
616+
**kwargs,
616617
)
617618

618619
if view:

0 commit comments

Comments
 (0)