55//! Brute Force KNN
66
77use std:: io:: { stderr, Write } ;
8+ use std:: marker:: PhantomData ;
89
910use crate :: distance_type:: DistanceType ;
10- use crate :: dlpack:: ManagedTensor ;
11+ use crate :: dlpack:: { DatasetOwnership , ManagedTensor } ;
1112use crate :: error:: { check_cuvs, Result } ;
1213use crate :: resources:: Resources ;
1314
1415/// Brute Force KNN Index
16+ ///
17+ /// The brute force C library stores a non-owning view into the original dataset.
18+ /// The lifetime parameter `'a` ensures the dataset outlives the index when built
19+ /// with [`Index::build`]. Use [`Index::build_owned`] for a self-contained index
20+ /// that owns its dataset (e.g., after [`ManagedTensor::to_device`]).
21+ ///
22+ /// # Examples
23+ ///
24+ /// ## Borrowed dataset (compiler enforces lifetime)
25+ ///
26+ /// ```no_run
27+ /// # use cuvs::{ManagedTensor, Resources};
28+ /// # use cuvs::brute_force::Index;
29+ /// # use cuvs::distance_type::DistanceType;
30+ /// let res = Resources::new().unwrap();
31+ /// let arr = ndarray::Array::<f32, _>::zeros((64, 8));
32+ /// let tensor = ManagedTensor::from(&arr);
33+ /// let index = Index::build(&res, DistanceType::L2Expanded, None, &tensor).unwrap();
34+ /// // arr and tensor must remain alive while index is in use
35+ /// ```
36+ ///
37+ /// ## Owned dataset ('static lifetime)
38+ ///
39+ /// ```no_run
40+ /// # use cuvs::{ManagedTensor, Resources};
41+ /// # use cuvs::brute_force::Index;
42+ /// # use cuvs::distance_type::DistanceType;
43+ /// let res = Resources::new().unwrap();
44+ /// let arr = ndarray::Array::<f32, _>::zeros((64, 8));
45+ /// let device_tensor = ManagedTensor::from(&arr).to_device(&res).unwrap();
46+ /// let index = Index::build_owned(&res, DistanceType::L2Expanded, None, device_tensor).unwrap();
47+ /// drop(arr); // Fine — index owns the device copy
48+ /// ```
1549#[ derive( Debug ) ]
16- pub struct Index ( ffi:: cuvsBruteForceIndex_t ) ;
50+ pub struct Index < ' a > {
51+ inner : ffi:: cuvsBruteForceIndex_t ,
52+ _data : DatasetOwnership < ' a > ,
53+ }
1754
18- impl Index {
19- /// Builds a new Brute Force KNN Index from the dataset for efficient search.
55+ impl < ' a > Index < ' a > {
56+ /// Creates a new FFI index handle.
57+ fn create_handle ( ) -> Result < ffi:: cuvsBruteForceIndex_t > {
58+ unsafe {
59+ let mut index = std:: mem:: MaybeUninit :: < ffi:: cuvsBruteForceIndex_t > :: uninit ( ) ;
60+ check_cuvs ( ffi:: cuvsBruteForceIndexCreate ( index. as_mut_ptr ( ) ) ) ?;
61+ Ok ( index. assume_init ( ) )
62+ }
63+ }
64+
65+ /// Builds a new Brute Force KNN Index from a borrowed dataset.
66+ ///
67+ /// The compiler enforces that `dataset` outlives the returned index,
68+ /// preventing use-after-free when the C library dereferences its
69+ /// internal view of the data.
2070 ///
2171 /// # Arguments
2272 ///
2373 /// * `res` - Resources to use
2474 /// * `metric` - DistanceType to use for building the index
2575 /// * `metric_arg` - Optional value of `p` for Minkowski distances
2676 /// * `dataset` - A row-major matrix on either the host or device to index
27- pub fn build < T : Into < ManagedTensor > > (
77+ pub fn build (
2878 res : & Resources ,
2979 metric : DistanceType ,
3080 metric_arg : Option < f32 > ,
31- dataset : T ,
32- ) -> Result < Index > {
33- let dataset: ManagedTensor = dataset. into ( ) ;
34- let index = Index :: new ( ) ?;
81+ dataset : & ' a ManagedTensor ,
82+ ) -> Result < Index < ' a > > {
83+ let inner = Self :: create_handle ( ) ?;
3584 unsafe {
3685 check_cuvs ( ffi:: cuvsBruteForceBuild (
3786 res. 0 ,
3887 dataset. as_ptr ( ) ,
3988 metric,
4089 metric_arg. unwrap_or ( 2.0 ) ,
41- index . 0 ,
90+ inner ,
4291 ) ) ?;
4392 }
44- Ok ( index)
93+ Ok ( Index {
94+ inner,
95+ _data : DatasetOwnership :: Borrowed ( PhantomData ) ,
96+ } )
4597 }
4698
4799 /// Creates a new empty index
48- pub fn new ( ) -> Result < Index > {
49- unsafe {
50- let mut index = std:: mem:: MaybeUninit :: < ffi:: cuvsBruteForceIndex_t > :: uninit ( ) ;
51- check_cuvs ( ffi:: cuvsBruteForceIndexCreate ( index. as_mut_ptr ( ) ) ) ?;
52- Ok ( Index ( index. assume_init ( ) ) )
53- }
100+ pub fn new ( ) -> Result < Index < ' a > > {
101+ Ok ( Index {
102+ inner : Self :: create_handle ( ) ?,
103+ _data : DatasetOwnership :: Borrowed ( PhantomData ) ,
104+ } )
54105 }
55106
56107 /// Perform a Nearest Neighbors search on the Index
@@ -76,7 +127,7 @@ impl Index {
76127
77128 check_cuvs ( ffi:: cuvsBruteForceSearch (
78129 res. 0 ,
79- self . 0 ,
130+ self . inner ,
80131 queries. as_ptr ( ) ,
81132 neighbors. as_ptr ( ) ,
82133 distances. as_ptr ( ) ,
@@ -86,9 +137,46 @@ impl Index {
86137 }
87138}
88139
89- impl Drop for Index {
140+ impl Index < ' static > {
141+ /// Builds a new Brute Force KNN Index from an owned dataset.
142+ ///
143+ /// The index takes ownership of `dataset`, making it self-contained
144+ /// with a `'static` lifetime. This is useful when the dataset is a
145+ /// device copy (from [`ManagedTensor::to_device`]) that should live
146+ /// as long as the index.
147+ ///
148+ /// # Arguments
149+ ///
150+ /// * `res` - Resources to use
151+ /// * `metric` - DistanceType to use for building the index
152+ /// * `metric_arg` - Optional value of `p` for Minkowski distances
153+ /// * `dataset` - A row-major matrix to index (ownership transferred to the index)
154+ pub fn build_owned (
155+ res : & Resources ,
156+ metric : DistanceType ,
157+ metric_arg : Option < f32 > ,
158+ dataset : ManagedTensor ,
159+ ) -> Result < Index < ' static > > {
160+ let inner = Self :: create_handle ( ) ?;
161+ unsafe {
162+ check_cuvs ( ffi:: cuvsBruteForceBuild (
163+ res. 0 ,
164+ dataset. as_ptr ( ) ,
165+ metric,
166+ metric_arg. unwrap_or ( 2.0 ) ,
167+ inner,
168+ ) ) ?;
169+ }
170+ Ok ( Index {
171+ inner,
172+ _data : DatasetOwnership :: Owned ( dataset) ,
173+ } )
174+ }
175+ }
176+
177+ impl Drop for Index < ' _ > {
90178 fn drop ( & mut self ) {
91- if let Err ( e) = check_cuvs ( unsafe { ffi:: cuvsBruteForceIndexDestroy ( self . 0 ) } ) {
179+ if let Err ( e) = check_cuvs ( unsafe { ffi:: cuvsBruteForceIndexDestroy ( self . inner ) } ) {
92180 write ! ( stderr( ) , "failed to call bruteForceIndexDestroy {:?}" , e)
93181 . expect ( "failed to write to stderr" ) ;
94182 }
@@ -116,9 +204,9 @@ mod tests {
116204
117205 println ! ( "dataset {:#?}" , dataset_host) ;
118206
119- // build the brute force index
120- let index =
121- Index :: build ( & res , metric , None , dataset ) . expect ( "failed to create brute force index" ) ;
207+ // build the brute force index (owned — device copy lives in the index)
208+ let index = Index :: build_owned ( & res , metric , None , dataset )
209+ . expect ( "failed to create brute force index" ) ;
122210
123211 res. sync_stream ( ) . unwrap ( ) ;
124212
@@ -173,10 +261,104 @@ mod tests {
173261 test_bfknn ( DistanceType :: L2Expanded ) ;
174262 }
175263
176- // NOTE: brute_force multiple-search test is omitted here because the C++
177- // brute_force::index stores a non-owning view into the dataset. Building
178- // from device data via `build()` drops the ManagedTensor after the call,
179- // leaving a dangling pointer. A follow-up PR will add dataset lifetime
180- // enforcement (DatasetOwnership<'a>) to make this safe.
181- // See: https://github.com/rapidsai/cuvs/issues/1838
264+ /// Test that an index built with build_owned can be searched multiple times.
265+ #[ test]
266+ fn test_brute_force_multiple_searches ( ) {
267+ let res = Resources :: new ( ) . unwrap ( ) ;
268+
269+ // Create a random dataset
270+ let n_datapoints = 64 ;
271+ let n_features = 8 ;
272+ let dataset =
273+ ndarray:: Array :: < f32 , _ > :: random ( ( n_datapoints, n_features) , Uniform :: new ( 0. , 1.0 ) ) ;
274+
275+ // Build the brute force index with owned device memory
276+ let dataset_device = ManagedTensor :: from ( & dataset) . to_device ( & res) . unwrap ( ) ;
277+ let index = Index :: build_owned ( & res, DistanceType :: L2Expanded , None , dataset_device)
278+ . expect ( "failed to create brute force index" ) ;
279+
280+ res. sync_stream ( ) . unwrap ( ) ;
281+
282+ let k = 4 ;
283+
284+ // Perform multiple searches on the same index
285+ for search_iter in 0 ..3 {
286+ let n_queries = 4 ;
287+ let queries = dataset. slice ( s ! [ 0 ..n_queries, ..] ) ;
288+ let queries = ManagedTensor :: from ( & queries) . to_device ( & res) . unwrap ( ) ;
289+
290+ let mut neighbors_host = ndarray:: Array :: < i64 , _ > :: zeros ( ( n_queries, k) ) ;
291+ let neighbors = ManagedTensor :: from ( & neighbors_host)
292+ . to_device ( & res)
293+ . unwrap ( ) ;
294+
295+ let mut distances_host = ndarray:: Array :: < f32 , _ > :: zeros ( ( n_queries, k) ) ;
296+ let distances = ManagedTensor :: from ( & distances_host)
297+ . to_device ( & res)
298+ . unwrap ( ) ;
299+
300+ index
301+ . search ( & res, & queries, & neighbors, & distances)
302+ . unwrap_or_else ( |e| panic ! ( "search iteration {} failed: {}" , search_iter, e) ) ;
303+
304+ // Copy back to host memory
305+ distances. to_host ( & res, & mut distances_host) . unwrap ( ) ;
306+ neighbors. to_host ( & res, & mut neighbors_host) . unwrap ( ) ;
307+ res. sync_stream ( ) . unwrap ( ) ;
308+
309+ // Verify results are consistent
310+ assert_eq ! (
311+ neighbors_host[ [ 0 , 0 ] ] ,
312+ 0 ,
313+ "iteration {}: first query should find itself" ,
314+ search_iter
315+ ) ;
316+ }
317+ }
318+
319+ /// Test that an index built with build (borrowed) ties the dataset lifetime.
320+ #[ test]
321+ fn test_brute_force_borrowed_build ( ) {
322+ let res = Resources :: new ( ) . unwrap ( ) ;
323+
324+ let n_datapoints = 64 ;
325+ let n_features = 8 ;
326+ let dataset_host =
327+ ndarray:: Array :: < f32 , _ > :: random ( ( n_datapoints, n_features) , Uniform :: new ( 0. , 1.0 ) ) ;
328+
329+ // Create a device tensor and borrow it for the index
330+ let dataset_device = ManagedTensor :: from ( & dataset_host) . to_device ( & res) . unwrap ( ) ;
331+ let index = Index :: build ( & res, DistanceType :: L2Expanded , None , & dataset_device)
332+ . expect ( "failed to create brute force index" ) ;
333+
334+ res. sync_stream ( ) . unwrap ( ) ;
335+
336+ // Search while the borrowed dataset is still alive
337+ let n_queries = 4 ;
338+ let k = 4 ;
339+ let queries = dataset_host. slice ( s ! [ 0 ..n_queries, ..] ) ;
340+ let queries = ManagedTensor :: from ( & queries) . to_device ( & res) . unwrap ( ) ;
341+
342+ let mut neighbors_host = ndarray:: Array :: < i64 , _ > :: zeros ( ( n_queries, k) ) ;
343+ let neighbors = ManagedTensor :: from ( & neighbors_host)
344+ . to_device ( & res)
345+ . unwrap ( ) ;
346+
347+ let mut distances_host = ndarray:: Array :: < f32 , _ > :: zeros ( ( n_queries, k) ) ;
348+ let distances = ManagedTensor :: from ( & distances_host)
349+ . to_device ( & res)
350+ . unwrap ( ) ;
351+
352+ index
353+ . search ( & res, & queries, & neighbors, & distances)
354+ . unwrap ( ) ;
355+
356+ distances. to_host ( & res, & mut distances_host) . unwrap ( ) ;
357+ neighbors. to_host ( & res, & mut neighbors_host) . unwrap ( ) ;
358+ res. sync_stream ( ) . unwrap ( ) ;
359+
360+ assert_eq ! ( neighbors_host[ [ 0 , 0 ] ] , 0 ) ;
361+ assert_eq ! ( neighbors_host[ [ 1 , 0 ] ] , 1 ) ;
362+ // dataset_device is still alive here — compiler ensures it
363+ }
182364}
0 commit comments