Skip to content

Commit 46b5c40

Browse files
committed
feat(rust): add validated IndexParamsBuilder / SearchParamsBuilder
Add `IndexParams::builder()` / `SearchParams::builder()` entry points to the CAGRA, IVF-PQ, IVF-Flat, and Vamana Rust bindings. ## Problem The existing `IndexParams::new()?.set_graph_degree(0)` setter chain accepts invalid values silently. Errors surface as an opaque CUDA assertion inside `Index::build()` — up to 1.8s after the bad config was written, after GPU memory is already allocated. ## Solution Each `XxxParamsBuilder`: - Stores parameters as Rust-native values (no FFI struct allocated yet) - Exposes `validate() -> Result<()>`: pure-Rust constraint checks, no GPU work; callable before any device is present - Exposes `build() -> Result<XxxParams>`: calls `validate()`, then allocates the FFI struct via the existing `new()?.set_*()` chain - Fully additive — the existing `new()` + setter API is unchanged Validation rules enforced: - CAGRA `IndexParams`: graph_degree > 0, intermediate_graph_degree >= graph_degree, nn_descent_niter > 0 - CAGRA `SearchParams`: itopk_size is a power of 2 (or 0 for auto), team_size ∈ {0,4,8,16,32}, hashmap_max_fill_rate ∈ (0.1, 0.9) - IVF-PQ / IVF-Flat `IndexParams`: n_lists > 0, kmeans_trainset_fraction ∈ (0, 1] - Vamana `IndexParams`: graph_degree > 0, visited_size >= graph_degree, alpha > 0 ## Changes - `src/error.rs`: add `impl From<String> for Error` for ergonomic validation errors (private-field CuvsError constructed in-module) - `src/cagra/index_params.rs`: `IndexParamsBuilder` + `IndexParams::builder()` - `src/cagra/search_params.rs`: `SearchParamsBuilder` + `SearchParams::builder()` - `src/cagra/mod.rs`: re-export `IndexParamsBuilder`, `SearchParamsBuilder` - `examples/cagra.rs`: updated to demonstrate builder API - `src/ivf_pq/index_params.rs`: `IndexParamsBuilder` - `src/ivf_flat/index_params.rs`: `IndexParamsBuilder` - `src/vamana/index_params.rs`: `IndexParamsBuilder` - `src/{ivf_pq,ivf_flat,vamana}/mod.rs`: re-export `IndexParamsBuilder` ## Tests added (all 6 PR gate tests from issue spec) - `builder_rejects_zero_graph_degree` (validate only, no GPU) - `builder_rejects_invalid_intermediate_degree` (validate only) - `builder_rejects_zero_niter` (validate only) - `builder_accepts_valid_params` (validate only) - `builder_round_trips_to_ffi` (requires GPU — compares builder output to manual setter chain at FFI struct level) - `existing_setter_api_unchanged` (requires GPU — confirms no regression)
1 parent 105c61e commit 46b5c40

File tree

11 files changed

+985
-8
lines changed

11 files changed

+985
-8
lines changed

rust/cuvs/examples/cagra.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@ use ndarray::s;
1010
use ndarray_rand::rand_distr::Uniform;
1111
use ndarray_rand::RandomExt;
1212

13-
/// Example showing how to index and search data with CAGRA
13+
/// Example showing how to index and search data with CAGRA using the validated builder API.
14+
///
15+
/// `IndexParams::builder()` validates parameters before any GPU allocation, surfacing
16+
/// misconfiguration immediately with a clear error message instead of an opaque CUDA
17+
/// assertion 1-2 seconds into `Index::build()`.
1418
fn cagra_example() -> Result<()> {
1519
let res = Resources::new()?;
1620

@@ -20,8 +24,14 @@ fn cagra_example() -> Result<()> {
2024
let dataset =
2125
ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));
2226

23-
// build the cagra index
24-
let build_params = IndexParams::new()?;
27+
// Build the CAGRA index using the validated builder.
28+
// Parameters are checked in Rust before any FFI call — invalid values (e.g.
29+
// graph_degree=0) produce an error here, not inside Index::build().
30+
let build_params = IndexParams::builder()
31+
.graph_degree(32)
32+
.intermediate_graph_degree(64)
33+
.nn_descent_niter(20)
34+
.build()?;
2535
let index = Index::build(&res, &build_params, &dataset)?;
2636
println!(
2737
"Indexed {}x{} datapoints into cagra index",

rust/cuvs/src/cagra/index_params.rs

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,31 @@ impl IndexParams {
137137
}
138138
}
139139

140+
impl IndexParams {
141+
/// Returns a builder for constructing [`IndexParams`] with validated parameters.
142+
///
143+
/// Unlike the `IndexParams::new()?.set_*()` setter chain, [`IndexParamsBuilder::build`]
144+
/// validates all parameters in Rust before any FFI allocation. Invalid values produce a
145+
/// clear error message naming the offending field and its valid range, before any GPU
146+
/// work begins.
147+
///
148+
/// # Example
149+
///
150+
/// ```no_run
151+
/// use cuvs::cagra::IndexParams;
152+
///
153+
/// let params = IndexParams::builder()
154+
/// .graph_degree(32)
155+
/// .intermediate_graph_degree(64)
156+
/// .nn_descent_niter(20)
157+
/// .build()
158+
/// .unwrap();
159+
/// ```
160+
pub fn builder() -> IndexParamsBuilder {
161+
IndexParamsBuilder::default()
162+
}
163+
}
164+
140165
impl fmt::Debug for IndexParams {
141166
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
142167
// custom debug trait here, default value will show the pointer address
@@ -177,6 +202,115 @@ impl Drop for CompressionParams {
177202
}
178203
}
179204

205+
/// Builder for [`IndexParams`] with pre-validated parameters.
206+
///
207+
/// Construct via [`IndexParams::builder()`]. Call [`IndexParamsBuilder::build`] to
208+
/// validate all parameters and allocate the FFI struct in one step.
209+
///
210+
/// Defaults match the cuVS C API defaults: `graph_degree=64`,
211+
/// `intermediate_graph_degree=128`, `nn_descent_niter=20`.
212+
pub struct IndexParamsBuilder {
213+
graph_degree: usize,
214+
intermediate_graph_degree: usize,
215+
nn_descent_niter: usize,
216+
build_algo: Option<BuildAlgo>,
217+
compression: Option<CompressionParams>,
218+
}
219+
220+
impl Default for IndexParamsBuilder {
221+
fn default() -> Self {
222+
Self {
223+
graph_degree: 64,
224+
intermediate_graph_degree: 128,
225+
nn_descent_niter: 20,
226+
build_algo: None,
227+
compression: None,
228+
}
229+
}
230+
}
231+
232+
impl IndexParamsBuilder {
233+
/// Degree of output graph.
234+
///
235+
/// Must be > 0. Values that are multiples of 32 are preferred for warp alignment.
236+
pub fn graph_degree(mut self, v: usize) -> Self {
237+
self.graph_degree = v;
238+
self
239+
}
240+
241+
/// Degree of input graph for pruning.
242+
///
243+
/// Must be >= `graph_degree`.
244+
pub fn intermediate_graph_degree(mut self, v: usize) -> Self {
245+
self.intermediate_graph_degree = v;
246+
self
247+
}
248+
249+
/// Number of iterations to run if building with NN_DESCENT.
250+
///
251+
/// Must be > 0.
252+
pub fn nn_descent_niter(mut self, v: usize) -> Self {
253+
self.nn_descent_niter = v;
254+
self
255+
}
256+
257+
/// ANN algorithm to build knn graph.
258+
pub fn build_algo(mut self, v: BuildAlgo) -> Self {
259+
self.build_algo = Some(v);
260+
self
261+
}
262+
263+
/// Vector compression parameters.
264+
pub fn compression(mut self, v: CompressionParams) -> Self {
265+
self.compression = Some(v);
266+
self
267+
}
268+
269+
/// Validate all parameters without allocating any GPU resources.
270+
///
271+
/// Returns `Ok(())` if all parameters are valid, or `Err` with a message naming
272+
/// the offending field and its valid range.
273+
pub fn validate(&self) -> crate::error::Result<()> {
274+
if self.graph_degree == 0 {
275+
return Err(format!("graph_degree must be > 0; got {}", self.graph_degree).into());
276+
}
277+
if self.intermediate_graph_degree < self.graph_degree {
278+
return Err(format!(
279+
"intermediate_graph_degree ({}) must be >= graph_degree ({})",
280+
self.intermediate_graph_degree, self.graph_degree
281+
)
282+
.into());
283+
}
284+
if self.nn_descent_niter == 0 {
285+
return Err(format!(
286+
"nn_descent_niter must be > 0; got {}",
287+
self.nn_descent_niter
288+
)
289+
.into());
290+
}
291+
Ok(())
292+
}
293+
294+
/// Validate all parameters and allocate the FFI struct.
295+
///
296+
/// Returns `Err` with a message naming the offending field and its valid range
297+
/// before any GPU work begins.
298+
pub fn build(self) -> crate::error::Result<IndexParams> {
299+
self.validate()?;
300+
let mut params = IndexParams::new()?
301+
.set_graph_degree(self.graph_degree)
302+
.set_intermediate_graph_degree(self.intermediate_graph_degree)
303+
.set_nn_descent_niter(self.nn_descent_niter);
304+
if let Some(algo) = self.build_algo {
305+
params = params.set_build_algo(algo);
306+
}
307+
if let Some(compression) = self.compression {
308+
params = params.set_compression(compression);
309+
}
310+
Ok(params)
311+
}
312+
}
313+
180314
#[cfg(test)]
181315
mod tests {
182316
use super::*;
@@ -206,4 +340,95 @@ mod tests {
206340
assert_eq!((*(*params.0).compression).pq_bits, 4);
207341
}
208342
}
343+
344+
// --- IndexParamsBuilder tests ---
345+
346+
#[test]
347+
fn builder_rejects_zero_graph_degree() {
348+
let err = IndexParams::builder()
349+
.graph_degree(0)
350+
.validate()
351+
.unwrap_err();
352+
assert!(
353+
err.to_string().contains("graph_degree"),
354+
"error message should name the field: {err}"
355+
);
356+
}
357+
358+
#[test]
359+
fn builder_rejects_invalid_intermediate_degree() {
360+
let err = IndexParams::builder()
361+
.graph_degree(32)
362+
.intermediate_graph_degree(16)
363+
.validate()
364+
.unwrap_err();
365+
assert!(
366+
err.to_string().contains("intermediate_graph_degree"),
367+
"error message should name the field: {err}"
368+
);
369+
}
370+
371+
#[test]
372+
fn builder_rejects_zero_niter() {
373+
let err = IndexParams::builder()
374+
.nn_descent_niter(0)
375+
.validate()
376+
.unwrap_err();
377+
assert!(
378+
err.to_string().contains("nn_descent_niter"),
379+
"error message should name the field: {err}"
380+
);
381+
}
382+
383+
#[test]
384+
fn builder_accepts_valid_params() {
385+
assert!(IndexParams::builder()
386+
.graph_degree(32)
387+
.intermediate_graph_degree(64)
388+
.nn_descent_niter(20)
389+
.validate()
390+
.is_ok());
391+
}
392+
393+
#[test]
394+
fn builder_round_trips_to_ffi() {
395+
// Built params must produce the same FFI struct values as the manual setter chain.
396+
let via_builder = IndexParams::builder()
397+
.graph_degree(32)
398+
.intermediate_graph_degree(64)
399+
.nn_descent_niter(20)
400+
.build()
401+
.unwrap();
402+
let via_setters = IndexParams::new()
403+
.unwrap()
404+
.set_graph_degree(32)
405+
.set_intermediate_graph_degree(64)
406+
.set_nn_descent_niter(20);
407+
unsafe {
408+
assert_eq!((*via_builder.0).graph_degree, (*via_setters.0).graph_degree);
409+
assert_eq!(
410+
(*via_builder.0).intermediate_graph_degree,
411+
(*via_setters.0).intermediate_graph_degree
412+
);
413+
assert_eq!(
414+
(*via_builder.0).nn_descent_niter,
415+
(*via_setters.0).nn_descent_niter
416+
);
417+
}
418+
}
419+
420+
#[test]
421+
fn existing_setter_api_unchanged() {
422+
// Ensure the original API still compiles and sets values correctly.
423+
let params = IndexParams::new()
424+
.unwrap()
425+
.set_graph_degree(32)
426+
.set_intermediate_graph_degree(64)
427+
.set_nn_descent_niter(20);
428+
unsafe {
429+
assert_eq!((*params.0).graph_degree, 32);
430+
assert_eq!((*params.0).intermediate_graph_degree, 64);
431+
assert_eq!((*params.0).nn_descent_niter, 20);
432+
}
433+
}
209434
}

rust/cuvs/src/cagra/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,5 +71,5 @@ mod index_params;
7171
mod search_params;
7272

7373
pub use index::Index;
74-
pub use index_params::{BuildAlgo, CompressionParams, IndexParams};
75-
pub use search_params::{HashMode, SearchAlgo, SearchParams};
74+
pub use index_params::{BuildAlgo, CompressionParams, IndexParams, IndexParamsBuilder};
75+
pub use search_params::{HashMode, SearchAlgo, SearchParams, SearchParamsBuilder};

0 commit comments

Comments
 (0)