11/*
2- * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
2+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026 , NVIDIA CORPORATION.
33 * SPDX-License-Identifier: Apache-2.0
44 */
55
66#pragma once
77
88#include < cuvs/distance/distance.hpp>
9- #include < cuvs/neighbors/index_base .hpp>
9+ #include < cuvs/neighbors/cagra .hpp>
1010#include < raft/core/device_mdspan.hpp>
1111
12- #include < memory>
1312#include < vector>
1413
1514namespace cuvs ::neighbors::composite {
1615
1716/* *
18- * @brief Composite index made of other IndexBase implementations.
17+ * @brief Composite index that searches multiple CAGRA sub-indices and merges results.
18+ *
19+ * When the composite index contains multiple sub-indices, the user can set a
20+ * stream pool in the input raft::resource to enable parallel search across
21+ * sub-indices for improved performance.
22+ *
23+ * Usage example:
24+ * @code{.cpp}
25+ * using namespace cuvs::neighbors;
26+ *
27+ * auto index0 = cagra::build(res, params, dataset0);
28+ * auto index1 = cagra::build(res, params, dataset1);
29+ *
30+ * composite::CompositeIndex<float, uint32_t> composite({&index0, &index1});
31+ *
32+ * // optional: create a stream pool to enable parallel search across sub-indices
33+ * size_t n_streams = 2;
34+ * raft::resource::set_cuda_stream_pool(handle,
35+ * std::make_shared<rmm::cuda_stream_pool>(n_streams));
36+ *
37+ * composite.search(handle, search_params, queries, neighbors, distances);
38+ * @endcode
1939 */
2040template <typename T, typename IdxT, typename OutputIdxT = IdxT>
21- class CompositeIndex : public IndexBase <T, IdxT, OutputIdxT> {
41+ class CompositeIndex {
2242 public:
23- using value_type = typename IndexBase<T, IdxT, OutputIdxT>::value_type ;
24- using index_type = typename IndexBase<T, IdxT, OutputIdxT>::index_type ;
25- using out_index_type = typename IndexBase<T, IdxT, OutputIdxT>::out_index_type ;
26- using matrix_index_type = typename IndexBase<T, IdxT, OutputIdxT>::matrix_index_type ;
43+ using value_type = T ;
44+ using index_type = IdxT;
45+ using out_index_type = OutputIdxT;
46+ using matrix_index_type = int64_t ;
2747
28- using index_ptr = std::shared_ptr<IndexBase<value_type, index_type, out_index_type>>;
29-
30- explicit CompositeIndex (std::vector<index_ptr> children) : children_(std::move(children)) {}
48+ explicit CompositeIndex (std::vector<cuvs::neighbors::cagra::index<T, IdxT>*> children)
49+ : children_(std::move(children))
50+ {
51+ }
3152
3253 /* *
3354 * @brief Search the composite index for the k nearest neighbors.
3455 *
35- * When the composite index contains multiple sub-indices, the user can set a
36- * stream pool in the input raft::resource to enable parallel search across
37- * sub-indices for improved performance.
38- *
39- * Usage example:
40- * @code{.cpp}
41- * using namespace cuvs::neighbors;
42- * // create a composite index with multiple sub-indices
43- * std::vector<CompositeIndex<T, IdxT>::index_ptr> sub_indices;
44- * // ... populate sub_indices ...
45- * auto composite_index = CompositeIndex<T, IdxT>(std::move(sub_indices));
46- *
47- * // optional: create a stream pool to enable parallel search across sub-indices
48- * // recommended stream count: min(number_of_sub_indices, 8)
49- * size_t n_streams = std::min(sub_indices.size(), size_t(8));
50- * raft::resource::set_cuda_stream_pool(handle,
51- * std::make_shared<rmm::cuda_stream_pool>(n_streams));
52- *
53- * // perform search with parallel sub-index execution
54- * composite_index.search(handle, search_params, queries, neighbors, distances);
55- * @endcode
56+ * Searches each sub-index independently (optionally in parallel via stream pool),
57+ * then selects the top-k results across all sub-indices.
5658 *
5759 * @param[in] handle raft resource handle
58- * @param[in] params search parameters
60+ * @param[in] params CAGRA search parameters
5961 * @param[in] queries device matrix view of query vectors [n_queries, dim]
6062 * @param[out] neighbors device matrix view for neighbor indices [n_queries, k]
6163 * @param[out] distances device matrix view for distances [n_queries, k]
6264 * @param[in] filter optional filter for search results
6365 */
6466 void search (
6567 const raft::resources& handle,
66- const cuvs::neighbors::search_params& params,
68+ const cuvs::neighbors::cagra:: search_params& params,
6769 raft::device_matrix_view<const value_type, matrix_index_type, raft::row_major> queries,
6870 raft::device_matrix_view<out_index_type, matrix_index_type, raft::row_major> neighbors,
6971 raft::device_matrix_view<float , matrix_index_type, raft::row_major> distances,
7072 const cuvs::neighbors::filtering::base_filter& filter =
71- cuvs::neighbors::filtering::none_sample_filter{}) const override ;
73+ cuvs::neighbors::filtering::none_sample_filter{}) const ;
7274
73- index_type size () const noexcept override
75+ index_type size () const noexcept
7476 {
7577 index_type total = 0 ;
7678 for (const auto & c : children_) {
@@ -79,14 +81,14 @@ class CompositeIndex : public IndexBase<T, IdxT, OutputIdxT> {
7981 return total;
8082 }
8183
82- cuvs::distance::DistanceType metric () const noexcept override
84+ cuvs::distance::DistanceType metric () const noexcept
8385 {
8486 return children_.empty () ? cuvs::distance::DistanceType::L2Expanded
8587 : children_.front ()->metric ();
8688 }
8789
8890 private:
89- std::vector<index_ptr > children_;
91+ std::vector<cuvs::neighbors::cagra::index<T, IdxT>* > children_;
9092};
9193
9294} // namespace cuvs::neighbors::composite
0 commit comments