Skip to content

Commit e913531

Browse files
authored
Merge branch 'main' into update-pytest-pin
2 parents af79265 + 105c61e commit e913531

File tree

6 files changed

+258
-23
lines changed

6 files changed

+258
-23
lines changed

cpp/src/neighbors/detail/epsilon_neighborhood.cuh

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -27,6 +27,10 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass {
2727

2828
DataT acc[P::AccRowsPerTh][P::AccColsPerTh];
2929

30+
size_t n_blocks_y;
31+
size_t block_x;
32+
size_t block_y;
33+
3034
public:
3135
DI EpsUnexpL2SqNeighborhood(bool* _adj,
3236
IdxT* _vd,
@@ -36,9 +40,17 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass {
3640
IdxT _n,
3741
IdxT _k,
3842
DataT _eps,
39-
char* _smem)
40-
: BaseClass(_x, _y, _m, _n, _k, _smem), adj(_adj), eps(_eps), vd(_vd), smem(_smem)
43+
char* _smem,
44+
size_t _n_blocks_y)
45+
: BaseClass(_x, _y, _m, _n, _k, _smem),
46+
adj(_adj),
47+
eps(_eps),
48+
vd(_vd),
49+
smem(_smem),
50+
n_blocks_y(_n_blocks_y)
4151
{
52+
block_x = static_cast<size_t>(blockIdx.x) / n_blocks_y;
53+
block_y = static_cast<size_t>(blockIdx.x) % n_blocks_y;
4254
}
4355

4456
DI void run()
@@ -51,7 +63,7 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass {
5163
private:
5264
DI void prolog()
5365
{
54-
this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, 0);
66+
this->ldgXY(block_x * P::Mblk, block_y * P::Nblk, 0);
5567
#pragma unroll
5668
for (int i = 0; i < P::AccRowsPerTh; ++i) {
5769
#pragma unroll
@@ -67,7 +79,7 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass {
6779
DI void loop()
6880
{
6981
for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) {
70-
this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, kidx);
82+
this->ldgXY(block_x * P::Mblk, block_y * P::Nblk, kidx);
7183
accumulate(); // on the previous k-block
7284
this->stsXY();
7385
__syncthreads();
@@ -79,8 +91,8 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass {
7991

8092
DI void epilog()
8193
{
82-
IdxT startx = blockIdx.x * P::Mblk + this->accrowid;
83-
IdxT starty = blockIdx.y * P::Nblk + this->acccolid;
94+
IdxT startx = block_x * P::Mblk + this->accrowid;
95+
IdxT starty = block_y * P::Nblk + this->acccolid;
8496
auto lid = raft::laneId();
8597
IdxT sums[P::AccRowsPerTh];
8698
#pragma unroll
@@ -126,7 +138,7 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass {
126138
__syncthreads(); // so that we can safely reuse smem
127139
int gid = this->accrowid;
128140
int lid = this->acccolid;
129-
auto cidx = IdxT(blockIdx.x) * P::Mblk + gid;
141+
auto cidx = block_x * P::Mblk + gid;
130142
IdxT totalSum = 0;
131143
// update the individual vertex degrees
132144
#pragma unroll
@@ -157,11 +169,18 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass {
157169
}; // struct EpsUnexpL2SqNeighborhood
158170

159171
template <typename DataT, typename IdxT, typename Policy>
160-
__launch_bounds__(Policy::Nthreads, 2) RAFT_KERNEL epsUnexpL2SqNeighKernel(
161-
bool* adj, IdxT* vd, const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k, DataT eps)
172+
__launch_bounds__(Policy::Nthreads, 2) RAFT_KERNEL epsUnexpL2SqNeighKernel(bool* adj,
173+
IdxT* vd,
174+
const DataT* x,
175+
const DataT* y,
176+
IdxT m,
177+
IdxT n,
178+
IdxT k,
179+
DataT eps,
180+
size_t n_blocks_y)
162181
{
163182
extern __shared__ char smem[];
164-
EpsUnexpL2SqNeighborhood<DataT, IdxT, Policy> obj(adj, vd, x, y, m, n, k, eps, smem);
183+
EpsUnexpL2SqNeighborhood<DataT, IdxT, Policy> obj(adj, vd, x, y, m, n, k, eps, smem, n_blocks_y);
165184
obj.run();
166185
}
167186

@@ -177,10 +196,12 @@ void epsUnexpL2SqNeighImpl(bool* adj,
177196
cudaStream_t stream)
178197
{
179198
typedef typename raft::linalg::Policy4x4<DataT, VecLen>::Policy Policy;
180-
dim3 grid(raft::ceildiv<int>(m, Policy::Mblk), raft::ceildiv<int>(n, Policy::Nblk));
199+
size_t n_blocks_x = raft::ceildiv<size_t>(m, Policy::Mblk);
200+
size_t n_blocks_y = raft::ceildiv<size_t>(n, Policy::Nblk);
201+
dim3 grid(n_blocks_x * n_blocks_y);
181202
dim3 blk(Policy::Nthreads);
182203
epsUnexpL2SqNeighKernel<DataT, IdxT, Policy>
183-
<<<grid, blk, Policy::SmemSize, stream>>>(adj, vd, x, y, m, n, k, eps);
204+
<<<grid, blk, Policy::SmemSize, stream>>>(adj, vd, x, y, m, n, k, eps, n_blocks_y);
184205
RAFT_CUDA_TRY(cudaGetLastError());
185206
}
186207

cpp/tests/neighbors/epsilon_neighborhood.cu

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -10,6 +10,7 @@
1010
#include <raft/core/device_mdspan.hpp>
1111
#include <raft/core/host_mdspan.hpp>
1212
#include <raft/core/resource/cuda_stream.hpp>
13+
#include <raft/matrix/init.cuh>
1314
#include <raft/random/make_blobs.cuh>
1415
#include <raft/sparse/convert/csr.cuh>
1516
#include <raft/util/cudart_utils.hpp>
@@ -419,4 +420,41 @@ TEST_P(EpsNeighRbcTestFI, SparseRbcMaxK)
419420

420421
INSTANTIATE_TEST_CASE_P(EpsNeighTests, EpsNeighRbcTestFI, ::testing::ValuesIn(inputsfi_rbc));
421422

423+
TEST(EpsNeighborhood, LargeNDimension)
424+
{
425+
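// n chosen so that ceildiv(n, Nblk) exceeds the 65535 gridDim.y limit; assumes Nblk=16 (TODO confirm against the Policy4x4 actually dispatched, since a larger Nblk needs a larger n)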
426+
int64_t m = 1, n = 65536 * 16 + 1, k = 4;
427+
float eps = 1e10f; // large enough that everything is a neighbor
428+
429+
raft::resources handle;
430+
auto x = raft::make_device_matrix<float, int64_t>(handle, m, k);
431+
auto y = raft::make_device_matrix<float, int64_t>(handle, n, k);
432+
auto adj = raft::make_device_matrix<bool, int64_t>(handle, m, n);
433+
auto vd = raft::make_device_vector<int64_t, int64_t>(handle, m + 1);
434+
435+
// fill x, y with zeros (every pair has distance 0 < eps)
436+
raft::matrix::fill(handle, x.view(), 0.0f);
437+
raft::matrix::fill(handle, y.view(), 0.0f);
438+
439+
cuvs::neighbors::epsilon_neighborhood::compute(handle,
440+
raft::make_const_mdspan(x.view()),
441+
raft::make_const_mdspan(y.view()),
442+
adj.view(),
443+
vd.view(),
444+
eps,
445+
cuvs::distance::DistanceType::L2Unexpanded);
446+
447+
// Verify: with distance=0 and huge eps, every entry in adj should be true
448+
// and vd[0] should equal n
449+
auto adj_expected = raft::make_device_matrix<bool, int64_t>(handle, m, n);
450+
raft::matrix::fill(handle, adj_expected.view(), true);
451+
auto stream = raft::resource::get_cuda_stream(handle);
452+
ASSERT_TRUE(cuvs::devArrMatch(
453+
adj_expected.data_handle(), adj.data_handle(), m * n, cuvs::Compare<bool>(), stream));
454+
455+
int64_t expected_vd0 = n;
456+
ASSERT_TRUE(
457+
cuvs::devArrMatch(&expected_vd0, vd.data_handle(), 1, cuvs::Compare<int64_t>(), stream));
458+
}
459+
422460
}; // namespace cuvs::neighbors::epsilon_neighborhood

rust/cuvs/src/brute_force.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55
//! Brute Force KNN
@@ -62,7 +62,7 @@ impl Index {
6262
/// * `neighbors` - Matrix in device memory that receives the indices of the nearest neighbors
6363
/// * `distances` - Matrix in device memory that receives the distances of the nearest neighbors
6464
pub fn search(
65-
self,
65+
&self,
6666
res: &Resources,
6767
queries: &ManagedTensor,
6868
neighbors: &ManagedTensor,
@@ -89,7 +89,7 @@ impl Index {
8989
impl Drop for Index {
9090
fn drop(&mut self) {
9191
if let Err(e) = check_cuvs(unsafe { ffi::cuvsBruteForceIndexDestroy(self.0) }) {
92-
write!(stderr(), "failed to call cagraIndexDestroy {:?}", e)
92+
write!(stderr(), "failed to call bruteForceIndexDestroy {:?}", e)
9393
.expect("failed to write to stderr");
9494
}
9595
}
@@ -172,4 +172,11 @@ mod tests {
172172
fn test_l2() {
173173
test_bfknn(DistanceType::L2Expanded);
174174
}
175+
176+
// NOTE: brute_force multiple-search test is omitted here because the C++
177+
// brute_force::index stores a non-owning view into the dataset. Building
178+
// from device data via `build()` drops the ManagedTensor after the call,
179+
// leaving a dangling pointer. A follow-up PR will add dataset lifetime
180+
// enforcement (DatasetOwnership<'a>) to make this safe.
181+
// See: https://github.com/rapidsai/cuvs/issues/1838
175182
}

rust/cuvs/src/cagra/index.rs

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -59,7 +59,7 @@ impl Index {
5959
/// * `neighbors` - Matrix in device memory that receives the indices of the nearest neighbors
6060
/// * `distances` - Matrix in device memory that receives the distances of the nearest neighbors
6161
pub fn search(
62-
self,
62+
&self,
6363
res: &Resources,
6464
params: &SearchParams,
6565
queries: &ManagedTensor,
@@ -167,4 +167,59 @@ mod tests {
167167
.set_compression(CompressionParams::new().unwrap());
168168
test_cagra(build_params);
169169
}
170+
171+
/// Test that an index can be searched multiple times without rebuilding.
172+
/// This validates that search() takes &self instead of self.
173+
#[test]
174+
fn test_cagra_multiple_searches() {
175+
let res = Resources::new().unwrap();
176+
let build_params = IndexParams::new().unwrap();
177+
178+
// Create a random dataset
179+
let n_datapoints = 256;
180+
let n_features = 16;
181+
let dataset =
182+
ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));
183+
184+
// Build the index once
185+
let index =
186+
Index::build(&res, &build_params, &dataset).expect("failed to create cagra index");
187+
188+
let search_params = SearchParams::new().unwrap();
189+
let k = 5;
190+
191+
// Perform multiple searches on the same index
192+
for search_iter in 0..3 {
193+
let n_queries = 4;
194+
let queries = dataset.slice(s![0..n_queries, ..]);
195+
let queries = ManagedTensor::from(&queries).to_device(&res).unwrap();
196+
197+
let mut neighbors_host = ndarray::Array::<u32, _>::zeros((n_queries, k));
198+
let neighbors = ManagedTensor::from(&neighbors_host)
199+
.to_device(&res)
200+
.unwrap();
201+
202+
let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
203+
let distances = ManagedTensor::from(&distances_host)
204+
.to_device(&res)
205+
.unwrap();
206+
207+
// This should work on every iteration because search() takes &self
208+
index
209+
.search(&res, &search_params, &queries, &neighbors, &distances)
210+
.expect(&format!("search iteration {} failed", search_iter));
211+
212+
// Copy back to host memory
213+
distances.to_host(&res, &mut distances_host).unwrap();
214+
neighbors.to_host(&res, &mut neighbors_host).unwrap();
215+
216+
// Verify that on every iteration the first query returns itself (index 0) as its nearest neighbor
217+
assert_eq!(
218+
neighbors_host[[0, 0]],
219+
0,
220+
"iteration {}: first query should find itself",
221+
search_iter
222+
);
223+
}
224+
}
170225
}

rust/cuvs/src/ivf_flat/index.rs

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -59,7 +59,7 @@ impl Index {
5959
/// * `neighbors` - Matrix in device memory that receives the indices of the nearest neighbors
6060
/// * `distances` - Matrix in device memory that receives the distances of the nearest neighbors
6161
pub fn search(
62-
self,
62+
&self,
6363
res: &Resources,
6464
params: &SearchParams,
6565
queries: &ManagedTensor,
@@ -157,4 +157,61 @@ mod tests {
157157
assert_eq!(neighbors_host[[2, 0]], 2);
158158
assert_eq!(neighbors_host[[3, 0]], 3);
159159
}
160+
161+
/// Test that an index can be searched multiple times without rebuilding.
162+
/// This validates that search() takes &self instead of self.
163+
#[test]
164+
fn test_ivf_flat_multiple_searches() {
165+
let build_params = IndexParams::new().unwrap().set_n_lists(64);
166+
let res = Resources::new().unwrap();
167+
168+
// Create a random dataset
169+
let n_datapoints = 1024;
170+
let n_features = 16;
171+
let dataset =
172+
ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));
173+
174+
let dataset_device = ManagedTensor::from(&dataset).to_device(&res).unwrap();
175+
176+
// Build the index once
177+
let index = Index::build(&res, &build_params, dataset_device)
178+
.expect("failed to create ivf-flat index");
179+
180+
let search_params = SearchParams::new().unwrap();
181+
let k = 5;
182+
183+
// Perform multiple searches on the same index
184+
for search_iter in 0..3 {
185+
let n_queries = 4;
186+
let queries = dataset.slice(s![0..n_queries, ..]);
187+
let queries = ManagedTensor::from(&queries).to_device(&res).unwrap();
188+
189+
let mut neighbors_host = ndarray::Array::<i64, _>::zeros((n_queries, k));
190+
let neighbors = ManagedTensor::from(&neighbors_host)
191+
.to_device(&res)
192+
.unwrap();
193+
194+
let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
195+
let distances = ManagedTensor::from(&distances_host)
196+
.to_device(&res)
197+
.unwrap();
198+
199+
// This should work on every iteration because search() takes &self
200+
index
201+
.search(&res, &search_params, &queries, &neighbors, &distances)
202+
.expect(&format!("search iteration {} failed", search_iter));
203+
204+
// Copy back to host memory
205+
distances.to_host(&res, &mut distances_host).unwrap();
206+
neighbors.to_host(&res, &mut neighbors_host).unwrap();
207+
208+
// Verify that on every iteration the first query returns itself (index 0) as its nearest neighbor
209+
assert_eq!(
210+
neighbors_host[[0, 0]],
211+
0,
212+
"iteration {}: first query should find itself",
213+
search_iter
214+
);
215+
}
216+
}
160217
}

0 commit comments

Comments
 (0)