Skip to content

Commit e006925

Browse files
committed
feat(rust)!: add compile-time dataset lifetime safety for all index types
Introduces `DatasetOwnership<'a>` enum to track whether an index borrows or owns its dataset. All four index types (brute_force, cagra, ivf_flat, ivf_pq) now use `Index<'a>` with dual constructors:

- `build(&'a ManagedTensor)` — borrowed, compiler enforces dataset outlives index
- `build_owned(ManagedTensor)` — owned ('static), index is self-contained

This prevents use-after-free when the C library stores a non-owning view of the dataset. The previous `build<T: Into<ManagedTensor>>` API could not enforce that the original data remained alive.

BREAKING CHANGE: `Index::build()` now takes `&ManagedTensor` instead of `impl Into<ManagedTensor>`. Use `build_owned()` for the old move semantics.
1 parent 7844bbc commit e006925

File tree

8 files changed

+696
-107
lines changed

8 files changed

+696
-107
lines changed

rust/cuvs/src/brute_force.rs

Lines changed: 211 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,52 +5,103 @@
55
//! Brute Force KNN
66
77
use std::io::{stderr, Write};
8+
use std::marker::PhantomData;
89

910
use crate::distance_type::DistanceType;
10-
use crate::dlpack::ManagedTensor;
11+
use crate::dlpack::{DatasetOwnership, ManagedTensor};
1112
use crate::error::{check_cuvs, Result};
1213
use crate::resources::Resources;
1314

1415
/// Brute Force KNN Index
16+
///
17+
/// The brute force C library stores a non-owning view into the original dataset.
18+
/// The lifetime parameter `'a` ensures the dataset outlives the index when built
19+
/// with [`Index::build`]. Use [`Index::build_owned`] for a self-contained index
20+
/// that owns its dataset (e.g., after [`ManagedTensor::to_device`]).
21+
///
22+
/// # Examples
23+
///
24+
/// ## Borrowed dataset (compiler enforces lifetime)
25+
///
26+
/// ```no_run
27+
/// # use cuvs::{ManagedTensor, Resources};
28+
/// # use cuvs::brute_force::Index;
29+
/// # use cuvs::distance_type::DistanceType;
30+
/// let res = Resources::new().unwrap();
31+
/// let arr = ndarray::Array::<f32, _>::zeros((64, 8));
32+
/// let tensor = ManagedTensor::from(&arr);
33+
/// let index = Index::build(&res, DistanceType::L2Expanded, None, &tensor).unwrap();
34+
/// // arr and tensor must remain alive while index is in use
35+
/// ```
36+
///
37+
/// ## Owned dataset ('static lifetime)
38+
///
39+
/// ```no_run
40+
/// # use cuvs::{ManagedTensor, Resources};
41+
/// # use cuvs::brute_force::Index;
42+
/// # use cuvs::distance_type::DistanceType;
43+
/// let res = Resources::new().unwrap();
44+
/// let arr = ndarray::Array::<f32, _>::zeros((64, 8));
45+
/// let device_tensor = ManagedTensor::from(&arr).to_device(&res).unwrap();
46+
/// let index = Index::build_owned(&res, DistanceType::L2Expanded, None, device_tensor).unwrap();
47+
/// drop(arr); // Fine — index owns the device copy
48+
/// ```
1549
#[derive(Debug)]
16-
pub struct Index(ffi::cuvsBruteForceIndex_t);
50+
pub struct Index<'a> {
51+
inner: ffi::cuvsBruteForceIndex_t,
52+
_data: DatasetOwnership<'a>,
53+
}
1754

18-
impl Index {
19-
/// Builds a new Brute Force KNN Index from the dataset for efficient search.
55+
impl<'a> Index<'a> {
56+
/// Creates a new FFI index handle.
57+
fn create_handle() -> Result<ffi::cuvsBruteForceIndex_t> {
58+
unsafe {
59+
let mut index = std::mem::MaybeUninit::<ffi::cuvsBruteForceIndex_t>::uninit();
60+
check_cuvs(ffi::cuvsBruteForceIndexCreate(index.as_mut_ptr()))?;
61+
Ok(index.assume_init())
62+
}
63+
}
64+
65+
/// Builds a new Brute Force KNN Index from a borrowed dataset.
66+
///
67+
/// The compiler enforces that `dataset` outlives the returned index,
68+
/// preventing use-after-free when the C library dereferences its
69+
/// internal view of the data.
2070
///
2171
/// # Arguments
2272
///
2373
/// * `res` - Resources to use
2474
/// * `metric` - DistanceType to use for building the index
2575
/// * `metric_arg` - Optional value of `p` for Minkowski distances
2676
/// * `dataset` - A row-major matrix on either the host or device to index
27-
pub fn build<T: Into<ManagedTensor>>(
77+
pub fn build(
2878
res: &Resources,
2979
metric: DistanceType,
3080
metric_arg: Option<f32>,
31-
dataset: T,
32-
) -> Result<Index> {
33-
let dataset: ManagedTensor = dataset.into();
34-
let index = Index::new()?;
81+
dataset: &'a ManagedTensor,
82+
) -> Result<Index<'a>> {
83+
let inner = Self::create_handle()?;
3584
unsafe {
3685
check_cuvs(ffi::cuvsBruteForceBuild(
3786
res.0,
3887
dataset.as_ptr(),
3988
metric,
4089
metric_arg.unwrap_or(2.0),
41-
index.0,
90+
inner,
4291
))?;
4392
}
44-
Ok(index)
93+
Ok(Index {
94+
inner,
95+
_data: DatasetOwnership::Borrowed(PhantomData),
96+
})
4597
}
4698

4799
/// Creates a new empty index
48-
pub fn new() -> Result<Index> {
49-
unsafe {
50-
let mut index = std::mem::MaybeUninit::<ffi::cuvsBruteForceIndex_t>::uninit();
51-
check_cuvs(ffi::cuvsBruteForceIndexCreate(index.as_mut_ptr()))?;
52-
Ok(Index(index.assume_init()))
53-
}
100+
pub fn new() -> Result<Index<'a>> {
101+
Ok(Index {
102+
inner: Self::create_handle()?,
103+
_data: DatasetOwnership::Borrowed(PhantomData),
104+
})
54105
}
55106

56107
/// Perform a Nearest Neighbors search on the Index
@@ -76,7 +127,7 @@ impl Index {
76127

77128
check_cuvs(ffi::cuvsBruteForceSearch(
78129
res.0,
79-
self.0,
130+
self.inner,
80131
queries.as_ptr(),
81132
neighbors.as_ptr(),
82133
distances.as_ptr(),
@@ -86,9 +137,46 @@ impl Index {
86137
}
87138
}
88139

89-
impl Drop for Index {
140+
impl Index<'static> {
141+
/// Builds a new Brute Force KNN Index from an owned dataset.
142+
///
143+
/// The index takes ownership of `dataset`, making it self-contained
144+
/// with a `'static` lifetime. This is useful when the dataset is a
145+
/// device copy (from [`ManagedTensor::to_device`]) that should live
146+
/// as long as the index.
147+
///
148+
/// # Arguments
149+
///
150+
/// * `res` - Resources to use
151+
/// * `metric` - DistanceType to use for building the index
152+
/// * `metric_arg` - Optional value of `p` for Minkowski distances
153+
/// * `dataset` - A row-major matrix to index (ownership transferred to the index)
154+
pub fn build_owned(
155+
res: &Resources,
156+
metric: DistanceType,
157+
metric_arg: Option<f32>,
158+
dataset: ManagedTensor,
159+
) -> Result<Index<'static>> {
160+
let inner = Self::create_handle()?;
161+
unsafe {
162+
check_cuvs(ffi::cuvsBruteForceBuild(
163+
res.0,
164+
dataset.as_ptr(),
165+
metric,
166+
metric_arg.unwrap_or(2.0),
167+
inner,
168+
))?;
169+
}
170+
Ok(Index {
171+
inner,
172+
_data: DatasetOwnership::Owned(dataset),
173+
})
174+
}
175+
}
176+
177+
impl Drop for Index<'_> {
90178
fn drop(&mut self) {
91-
if let Err(e) = check_cuvs(unsafe { ffi::cuvsBruteForceIndexDestroy(self.0) }) {
179+
if let Err(e) = check_cuvs(unsafe { ffi::cuvsBruteForceIndexDestroy(self.inner) }) {
92180
write!(stderr(), "failed to call bruteForceIndexDestroy {:?}", e)
93181
.expect("failed to write to stderr");
94182
}
@@ -116,9 +204,9 @@ mod tests {
116204

117205
println!("dataset {:#?}", dataset_host);
118206

119-
// build the brute force index
120-
let index =
121-
Index::build(&res, metric, None, dataset).expect("failed to create brute force index");
207+
// build the brute force index (owned — device copy lives in the index)
208+
let index = Index::build_owned(&res, metric, None, dataset)
209+
.expect("failed to create brute force index");
122210

123211
res.sync_stream().unwrap();
124212

@@ -173,10 +261,104 @@ mod tests {
173261
test_bfknn(DistanceType::L2Expanded);
174262
}
175263

176-
// NOTE: brute_force multiple-search test is omitted here because the C++
177-
// brute_force::index stores a non-owning view into the dataset. Building
178-
// from device data via `build()` drops the ManagedTensor after the call,
179-
// leaving a dangling pointer. A follow-up PR will add dataset lifetime
180-
// enforcement (DatasetOwnership<'a>) to make this safe.
181-
// See: https://github.com/rapidsai/cuvs/issues/1838
264+
/// Test that an index built with build_owned can be searched multiple times.
265+
#[test]
266+
fn test_brute_force_multiple_searches() {
267+
let res = Resources::new().unwrap();
268+
269+
// Create a random dataset
270+
let n_datapoints = 64;
271+
let n_features = 8;
272+
let dataset =
273+
ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));
274+
275+
// Build the brute force index with owned device memory
276+
let dataset_device = ManagedTensor::from(&dataset).to_device(&res).unwrap();
277+
let index = Index::build_owned(&res, DistanceType::L2Expanded, None, dataset_device)
278+
.expect("failed to create brute force index");
279+
280+
res.sync_stream().unwrap();
281+
282+
let k = 4;
283+
284+
// Perform multiple searches on the same index
285+
for search_iter in 0..3 {
286+
let n_queries = 4;
287+
let queries = dataset.slice(s![0..n_queries, ..]);
288+
let queries = ManagedTensor::from(&queries).to_device(&res).unwrap();
289+
290+
let mut neighbors_host = ndarray::Array::<i64, _>::zeros((n_queries, k));
291+
let neighbors = ManagedTensor::from(&neighbors_host)
292+
.to_device(&res)
293+
.unwrap();
294+
295+
let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
296+
let distances = ManagedTensor::from(&distances_host)
297+
.to_device(&res)
298+
.unwrap();
299+
300+
index
301+
.search(&res, &queries, &neighbors, &distances)
302+
.unwrap_or_else(|e| panic!("search iteration {} failed: {}", search_iter, e));
303+
304+
// Copy back to host memory
305+
distances.to_host(&res, &mut distances_host).unwrap();
306+
neighbors.to_host(&res, &mut neighbors_host).unwrap();
307+
res.sync_stream().unwrap();
308+
309+
// Verify results are consistent
310+
assert_eq!(
311+
neighbors_host[[0, 0]],
312+
0,
313+
"iteration {}: first query should find itself",
314+
search_iter
315+
);
316+
}
317+
}
318+
319+
/// Test that an index built with build (borrowed) ties the dataset lifetime.
320+
#[test]
321+
fn test_brute_force_borrowed_build() {
322+
let res = Resources::new().unwrap();
323+
324+
let n_datapoints = 64;
325+
let n_features = 8;
326+
let dataset_host =
327+
ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));
328+
329+
// Create a device tensor and borrow it for the index
330+
let dataset_device = ManagedTensor::from(&dataset_host).to_device(&res).unwrap();
331+
let index = Index::build(&res, DistanceType::L2Expanded, None, &dataset_device)
332+
.expect("failed to create brute force index");
333+
334+
res.sync_stream().unwrap();
335+
336+
// Search while the borrowed dataset is still alive
337+
let n_queries = 4;
338+
let k = 4;
339+
let queries = dataset_host.slice(s![0..n_queries, ..]);
340+
let queries = ManagedTensor::from(&queries).to_device(&res).unwrap();
341+
342+
let mut neighbors_host = ndarray::Array::<i64, _>::zeros((n_queries, k));
343+
let neighbors = ManagedTensor::from(&neighbors_host)
344+
.to_device(&res)
345+
.unwrap();
346+
347+
let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
348+
let distances = ManagedTensor::from(&distances_host)
349+
.to_device(&res)
350+
.unwrap();
351+
352+
index
353+
.search(&res, &queries, &neighbors, &distances)
354+
.unwrap();
355+
356+
distances.to_host(&res, &mut distances_host).unwrap();
357+
neighbors.to_host(&res, &mut neighbors_host).unwrap();
358+
res.sync_stream().unwrap();
359+
360+
assert_eq!(neighbors_host[[0, 0]], 0);
361+
assert_eq!(neighbors_host[[1, 0]], 1);
362+
// dataset_device is still alive here — compiler ensures it
363+
}
182364
}

0 commit comments

Comments
 (0)