Skip to content

Commit 7cabbe1

Browse files
committed
remove IndexBase and IndexWrapper
1 parent f5cabd8 commit 7cabbe1

File tree

12 files changed

+77
-601
lines changed

12 files changed

+77
-601
lines changed

cpp/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -566,9 +566,7 @@ if(NOT BUILD_CPU_ONLY)
566566
src/neighbors/iface/iface_pq_uint8_t_int64_t.cu
567567
src/neighbors/detail/cagra/topk_for_cagra/topk.cu
568568
src/neighbors/dynamic_batching.cu
569-
src/neighbors/cagra_index_wrapper.cu
570569
src/neighbors/composite/index.cu
571-
src/neighbors/composite/merge.cpp
572570
$<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:src/neighbors/cagra.cpp>
573571
$<$<BOOL:${BUILD_CAGRA_HNSWLIB}>:src/neighbors/hnsw.cpp>
574572
src/neighbors/ivf_common.cu

cpp/bench/ann/src/cuvs/cuvs_cagra_wrapper.h

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#include <cuvs/distance/distance.hpp>
1414
#include <cuvs/neighbors/cagra.hpp>
1515
#include <cuvs/neighbors/common.hpp>
16-
#include <cuvs/neighbors/composite/merge.hpp>
16+
#include <cuvs/neighbors/composite/index.hpp>
1717
#include <cuvs/neighbors/dynamic_batching.hpp>
1818
#include <cuvs/neighbors/ivf_pq.hpp>
1919
#include <cuvs/neighbors/nn_descent.hpp>
@@ -453,33 +453,21 @@ void cuvs_cagra<T, IdxT>::search_base(
453453
} else {
454454
if (index_params_.merge_type == CagraMergeType::kLogical) {
455455
// TODO: index merge must happen outside of search, otherwise what are we benchmarking?
456-
cuvs::neighbors::cagra::merge_params merge_params{cuvs::neighbors::cagra::index_params{}};
457-
merge_params.merge_strategy = cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_LOGICAL;
458-
459-
// Create wrapped indices for composite merge
460-
std::vector<std::shared_ptr<cuvs::neighbors::IndexBase<T, IdxT, algo_base::index_type>>>
461-
wrapped_indices;
462-
wrapped_indices.reserve(sub_indices_.size());
456+
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*> cagra_indices;
457+
cagra_indices.reserve(sub_indices_.size());
463458
for (auto& ptr : sub_indices_) {
464-
auto index_wrapper =
465-
cuvs::neighbors::cagra::make_index_wrapper<T, IdxT, algo_base::index_type>(ptr.get());
466-
wrapped_indices.push_back(index_wrapper);
459+
cagra_indices.push_back(ptr.get());
467460
}
468461

469462
raft::resources composite_handle(handle_);
470-
size_t n_streams = wrapped_indices.size();
463+
size_t n_streams = cagra_indices.size();
471464
raft::resource::set_cuda_stream_pool(composite_handle,
472465
std::make_shared<rmm::cuda_stream_pool>(n_streams));
473466

474-
auto merged_index =
475-
cuvs::neighbors::composite::merge(composite_handle, merge_params, wrapped_indices);
476-
cuvs::neighbors::filtering::none_sample_filter empty_filter;
477-
merged_index->search(composite_handle,
478-
search_params_,
479-
queries_view,
480-
neighbors_view,
481-
distances_view,
482-
empty_filter);
467+
cuvs::neighbors::composite::CompositeIndex<T, IdxT, algo_base::index_type> composite(
468+
cagra_indices);
469+
composite.search(
470+
composite_handle, search_params_, queries_view, neighbors_view, distances_view);
483471
}
484472
}
485473
}

cpp/include/cuvs/neighbors/cagra.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3173,5 +3173,3 @@ void optimize(raft::resources const& handle,
31733173
raft::host_matrix_view<uint32_t, int64_t, raft::row_major> new_graph);
31743174

31753175
} // namespace cuvs::neighbors::cagra::helpers
3176-
3177-
#include <cuvs/neighbors/cagra_index_wrapper.hpp>

cpp/include/cuvs/neighbors/cagra_index_wrapper.hpp

Lines changed: 0 additions & 163 deletions
This file was deleted.

cpp/include/cuvs/neighbors/common.hpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,6 @@ enum class MergeStrategy {
126126
MERGE_STRATEGY_LOGICAL = 1
127127
};
128128

129-
/** Base merge parameters with polymorphic interface. */
130-
struct merge_params {
131-
virtual ~merge_params() = default;
132-
133-
virtual MergeStrategy strategy() const = 0;
134-
};
135-
136129
/** @} */ // end group neighbors_index
137130

138131
/** Two-dimensional dataset; maybe owning, maybe compressed, maybe strided. */
Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,78 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

66
#pragma once
77

88
#include <cuvs/distance/distance.hpp>
9-
#include <cuvs/neighbors/index_base.hpp>
9+
#include <cuvs/neighbors/cagra.hpp>
1010
#include <raft/core/device_mdspan.hpp>
1111

12-
#include <memory>
1312
#include <vector>
1413

1514
namespace cuvs::neighbors::composite {
1615

1716
/**
18-
* @brief Composite index made of other IndexBase implementations.
17+
* @brief Composite index that searches multiple CAGRA sub-indices and merges results.
18+
*
19+
* When the composite index contains multiple sub-indices, the user can set a
20+
* stream pool in the input raft::resource to enable parallel search across
21+
* sub-indices for improved performance.
22+
*
23+
* Usage example:
24+
* @code{.cpp}
25+
* using namespace cuvs::neighbors;
26+
*
27+
* auto index0 = cagra::build(res, params, dataset0);
28+
* auto index1 = cagra::build(res, params, dataset1);
29+
*
30+
* composite::CompositeIndex<float, uint32_t> composite({&index0, &index1});
31+
*
32+
* // optional: create a stream pool to enable parallel search across sub-indices
33+
* size_t n_streams = 2;
34+
* raft::resource::set_cuda_stream_pool(handle,
35+
* std::make_shared<rmm::cuda_stream_pool>(n_streams));
36+
*
37+
* composite.search(handle, search_params, queries, neighbors, distances);
38+
* @endcode
1939
*/
2040
template <typename T, typename IdxT, typename OutputIdxT = IdxT>
21-
class CompositeIndex : public IndexBase<T, IdxT, OutputIdxT> {
41+
class CompositeIndex {
2242
public:
23-
using value_type = typename IndexBase<T, IdxT, OutputIdxT>::value_type;
24-
using index_type = typename IndexBase<T, IdxT, OutputIdxT>::index_type;
25-
using out_index_type = typename IndexBase<T, IdxT, OutputIdxT>::out_index_type;
26-
using matrix_index_type = typename IndexBase<T, IdxT, OutputIdxT>::matrix_index_type;
43+
using value_type = T;
44+
using index_type = IdxT;
45+
using out_index_type = OutputIdxT;
46+
using matrix_index_type = int64_t;
2747

28-
using index_ptr = std::shared_ptr<IndexBase<value_type, index_type, out_index_type>>;
29-
30-
explicit CompositeIndex(std::vector<index_ptr> children) : children_(std::move(children)) {}
48+
explicit CompositeIndex(std::vector<cuvs::neighbors::cagra::index<T, IdxT>*> children)
49+
: children_(std::move(children))
50+
{
51+
}
3152

3253
/**
3354
* @brief Search the composite index for the k nearest neighbors.
3455
*
35-
* When the composite index contains multiple sub-indices, the user can set a
36-
* stream pool in the input raft::resource to enable parallel search across
37-
* sub-indices for improved performance.
38-
*
39-
* Usage example:
40-
* @code{.cpp}
41-
* using namespace cuvs::neighbors;
42-
* // create a composite index with multiple sub-indices
43-
* std::vector<CompositeIndex<T, IdxT>::index_ptr> sub_indices;
44-
* // ... populate sub_indices ...
45-
* auto composite_index = CompositeIndex<T, IdxT>(std::move(sub_indices));
46-
*
47-
* // optional: create a stream pool to enable parallel search across sub-indices
48-
* // recommended stream count: min(number_of_sub_indices, 8)
49-
* size_t n_streams = std::min(sub_indices.size(), size_t(8));
50-
* raft::resource::set_cuda_stream_pool(handle,
51-
* std::make_shared<rmm::cuda_stream_pool>(n_streams));
52-
*
53-
* // perform search with parallel sub-index execution
54-
* composite_index.search(handle, search_params, queries, neighbors, distances);
55-
* @endcode
56+
* Searches each sub-index independently (optionally in parallel via stream pool),
57+
* then selects the top-k results across all sub-indices.
5658
*
5759
* @param[in] handle raft resource handle
58-
* @param[in] params search parameters
60+
* @param[in] params CAGRA search parameters
5961
* @param[in] queries device matrix view of query vectors [n_queries, dim]
6062
* @param[out] neighbors device matrix view for neighbor indices [n_queries, k]
6163
* @param[out] distances device matrix view for distances [n_queries, k]
6264
* @param[in] filter optional filter for search results
6365
*/
6466
void search(
6567
const raft::resources& handle,
66-
const cuvs::neighbors::search_params& params,
68+
const cuvs::neighbors::cagra::search_params& params,
6769
raft::device_matrix_view<const value_type, matrix_index_type, raft::row_major> queries,
6870
raft::device_matrix_view<out_index_type, matrix_index_type, raft::row_major> neighbors,
6971
raft::device_matrix_view<float, matrix_index_type, raft::row_major> distances,
7072
const cuvs::neighbors::filtering::base_filter& filter =
71-
cuvs::neighbors::filtering::none_sample_filter{}) const override;
73+
cuvs::neighbors::filtering::none_sample_filter{}) const;
7274

73-
index_type size() const noexcept override
75+
index_type size() const noexcept
7476
{
7577
index_type total = 0;
7678
for (const auto& c : children_) {
@@ -79,14 +81,14 @@ class CompositeIndex : public IndexBase<T, IdxT, OutputIdxT> {
7981
return total;
8082
}
8183

82-
cuvs::distance::DistanceType metric() const noexcept override
84+
cuvs::distance::DistanceType metric() const noexcept
8385
{
8486
return children_.empty() ? cuvs::distance::DistanceType::L2Expanded
8587
: children_.front()->metric();
8688
}
8789

8890
private:
89-
std::vector<index_ptr> children_;
91+
std::vector<cuvs::neighbors::cagra::index<T, IdxT>*> children_;
9092
};
9193

9294
} // namespace cuvs::neighbors::composite

0 commit comments

Comments
 (0)