fix(cagra): prevent recall degradation with search_width > 1 at large batch sizes #1841
zbennett10 wants to merge 7 commits into rapidsai:main
Conversation
…ith search_width > 1

When the AUTO algorithm selector chooses SINGLE_CTA for large batch sizes, the max_iterations calculation uses `itopk_size / search_width`. With search_width > 1, this produces significantly fewer iterations than MULTI_CTA (which uses an internal mc_search_width=1), causing recall to drop from ~0.30 to ~0.12 at batch_size >= 512.

The fix adds a `search_width <= 1` condition to the AUTO selector so that SINGLE_CTA is only chosen when search_width won't reduce its iteration count below MULTI_CTA's equivalent. When search_width > 1, MULTI_CTA is preferred since it handles parallelism via multiple CTAs per query rather than reduced iterations.

Add a Python regression test for batch recall consistency across different batch sizes and search_width values.

Closes rapidsai#1187
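The selector change described in the commit message can be sketched in a few lines of Python (an illustrative model only — the real condition lives in C++ in `search_plan.cuh`; the function name and argument defaults here are made up):

```python
# Simplified model of the AUTO algorithm choice described above.
# The condition values come from the PR description; everything
# else (names, signature) is illustrative.

def choose_algo(itopk_size, max_queries, num_sm, search_width):
    """Return the kernel the patched AUTO selector would pick."""
    # Patched condition: SINGLE_CTA only when search_width <= 1,
    # so its max_iterations is never cut below MULTI_CTA's.
    if itopk_size <= 512 and max_queries >= num_sm * 2 and search_width <= 1:
        return "SINGLE_CTA"
    return "MULTI_CTA"

# Large batch, search_width=1: SINGLE_CTA is still chosen.
print(choose_algo(64, 512, 128, 1))   # SINGLE_CTA
# Large batch, search_width=8: MULTI_CTA avoids the recall cliff.
print(choose_algo(64, 512, 128, 8))   # MULTI_CTA
```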
/ok to test 197a261

/ok to test bfc9163
device_ndarray from pylibraft does not support Python slice indexing. Slice the host NumPy array first, then wrap with device_ndarray() to transfer each batch to the GPU.
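A minimal sketch of that batching pattern, using a plain `np.asarray` stand-in so the snippet runs without a GPU (the real code would call `pylibraft.common.device_ndarray` at the marked line; the array shapes are illustrative):

```python
import numpy as np

# Stand-in for pylibraft.common.device_ndarray, which copies a host
# array to the GPU but does not itself support Python slice indexing.
device_ndarray = np.asarray  # assumption: illustration only

queries_host = np.random.rand(1000, 64).astype(np.float32)
batch_size = 256

for start in range(0, queries_host.shape[0], batch_size):
    # Slice on the host first...
    batch_host = queries_host[start:start + batch_size]
    # ...then transfer the contiguous batch to the device.
    batch_dev = device_ndarray(batch_host)
    # index.search(params, batch_dev, k) would go here.
```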
/ok to test 7840309
Thank you for bringing this problem up and for the very descriptive proposal.
@achirkin thank you for the feedback - I will look at this and get back to you with a benchmark or two and a possible alternative to what I've implemented here. I am thinking that instead of blocking This, I think, would preserve
…ests

Add benchmarks/cagra_recall_throughput_bench.py comparing MULTI_CTA, SINGLE_CTA (default), SINGLE_CTA (floor@32 iterations), and AUTO across batch sizes and search widths. This provides the performance data requested in review.

Rewrite test_cagra_batch_recall.py with:
- Module-scoped fixture for index build + brute-force ground truth
- test_cagra_batch_recall_consistency: recall must not vary across batch sizes (catches the AUTO algorithm switch cliff)
- test_cagra_search_width_recall_quality: minimum recall thresholds at batch_size=512 (catches the SINGLE_CTA iteration deficit)
- test_cagra_search_width_monotonicity: higher search_width must not reduce recall (catches the core inversion bug)

Remove the cupy dependency (use pylibraft.common.device_ndarray + numpy). Fix dataset_device lifetime: keep the device array alive in the fixture to prevent use-after-free in CAGRA index search.
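The recall metric these tests rely on can be sketched in plain NumPy (a standalone illustration; the helper name is made up, and the real tests compare against brute-force ground truth computed on the actual dataset):

```python
import numpy as np

def recall_at_k(found: np.ndarray, truth: np.ndarray) -> float:
    """Fraction of true k-NN indices recovered, averaged over queries.

    found, truth: (n_queries, k) integer arrays of neighbor indices.
    Order within a row does not matter, only set membership.
    """
    hits = sum(len(np.intersect1d(f, t)) for f, t in zip(found, truth))
    return hits / truth.size

# Tiny example: 2 queries, k=2; one wrong neighbor in the second row.
truth = np.array([[0, 1], [2, 3]])
found = np.array([[1, 0], [2, 9]])
print(recall_at_k(found, truth))  # 0.75
```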
@achirkin — just getting back to you. I've spent some time benchmarking both the current fix and an alternative approach. Here are the full results.

TL;DR
Benchmark Setup
The benchmark script (`benchmarks/cagra_recall_throughput_bench.py`) is included in this PR.

1. Throughput/Recall: SINGLE_CTA vs MULTI_CTA

These numbers are from directly forcing each algorithm.

At search_width=1 (no problem — both algorithms get the same iterations)
At search_width=4 (recall cliff begins)
Recall delta: −0.120 (16% relative drop). SINGLE_CTA computes only 64/4 = 16 base iterations here.

At search_width=8 (severe recall cliff)
Recall delta: −0.226 (26% relative drop). SINGLE_CTA computes only 64/8 = 8 base iterations here.

2. Alternative Fix: Floor SINGLE_CTA iterations at 32

Per your suggestion, I tested whether increasing SINGLE_CTA's iterations to match MULTI_CTA's base count would recover recall while preserving throughput. The alternative: `max_iterations = std::max(itopk_size / search_width, (uint32_t)32);`

This gives SINGLE_CTA the exact same iteration count as MULTI_CTA (32 base + reachability = 36 total on this dataset). Results with floor@32 (at batch_size=512):
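The iteration arithmetic behind the floor@32 alternative can be sketched as follows (a sketch assuming min_iterations=6 from the root-cause analysis; the exact add-on is dataset-dependent, hence the 36 total observed above — the function name is illustrative):

```python
def single_cta_iters(itopk_size, search_width, min_iterations=6, floor=None):
    # SINGLE_CTA divides itopk_size by search_width; the floor@32
    # alternative clamps the base count from below.
    base = itopk_size // search_width
    if floor is not None:
        base = max(base, floor)
    return base + min_iterations

# Default SINGLE_CTA at search_width=8: 64//8 + 6 = 14 iterations.
print(single_cta_iters(64, 8))            # 14
# With the floor: max(8, 32) + 6 = 38, matching MULTI_CTA's 32/1 + 6.
print(single_cta_iters(64, 8, floor=32))  # 38
```

Note that even with equal counts, recall did not recover — which is the architectural point made below.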
The floor@32 fix recovers almost no recall 😢

Why the iteration floor doesn't work

The problem isn't just iteration count — it's architectural.
Even at equal iteration counts, MULTI_CTA's independent-CTA exploration pattern produces better graph traversal coverage than SINGLE_CTA's wide-but-shallow exploration.

3. Recall Consistency Across Batch Sizes (the user-facing bug)

This is what users experience when the AUTO selector silently switches algorithms:

Unpatched (AUTO mode — current code):
Patched (force MULTI_CTA when search_width > 1):
Recall is perfectly consistent across all batch sizes.

4. Throughput Impact

The fix only affects searches with `search_width > 1`; `search_width = 1` behavior is unchanged. A user who wants SINGLE_CTA's throughput with `search_width > 1` can still select SINGLE_CTA explicitly.

5. Test Suite

The updated test_cagra_batch_recall.py covers recall consistency across batch sizes, minimum recall thresholds, and search_width monotonicity.
All 7 tests pass on patched code. On unpatched code, 3 fail (as expected). The recall gap between SINGLE_CTA and MULTI_CTA at search_width=8 is substantial (a 26% relative drop). Happy to discuss further or run additional experiments if helpful.
/ok to test 4b0da50
Fix ruff format violations caught by pre-commit CI:
- Reflow long lines in benchmark and test files
- Add spaces around operators in f-string index expressions
/ok to test 6671ae4

/ok to test be584d8
Closes #1187
Description
Fixes inconsistent CAGRA search recall when query batch size changes, specifically when `search_width > 1`.

Root cause
The `AUTO` algorithm selector in `search_plan.cuh` switches from `MULTI_CTA` to `SINGLE_CTA` when `itopk_size <= 512 && max_queries >= num_sm * 2`. The problem is that these two algorithms compute `max_iterations` very differently when `search_width > 1`:

- `MULTI_CTA` fixes `mc_search_width=1` internally, so `max_iterations = mc_itopk_size / mc_search_width = 32/1 = 32`, then adds `min_iterations` (typically 6). Total: ~38 iterations.
- `SINGLE_CTA` uses `search_width` directly, so `max_iterations = itopk_size / search_width`. With `search_width=8` and `itopk_size=64`: `max_iterations = 64/8 = 8`, plus `min_iterations` (6). Total: ~14 iterations.

This means at large batch sizes (e.g., 512+ on a machine with ≥256 SMs), `AUTO` switches to `SINGLE_CTA`, which gets ~2.7x fewer iterations than the `MULTI_CTA` path used at smaller batch sizes, causing a significant recall drop.
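The iteration counts above can be reproduced with a few lines of arithmetic (a sketch using the constants from the description; function names are illustrative, not the library's API):

```python
def multi_cta_iters(mc_itopk_size=32, mc_search_width=1, min_iterations=6):
    # MULTI_CTA fixes mc_search_width at 1 internally.
    return mc_itopk_size // mc_search_width + min_iterations

def single_cta_iters(itopk_size, search_width, min_iterations=6):
    # SINGLE_CTA applies the user's search_width directly.
    return itopk_size // search_width + min_iterations

multi = multi_cta_iters()         # 32/1 + 6 = 38
single = single_cta_iters(64, 8)  # 64/8 + 6 = 14
print(multi, single, round(multi / single, 1))  # 38 14 2.7
```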
Fix

Added `search_width <= 1` as an additional condition for selecting `SINGLE_CTA` in the `AUTO` path. When `search_width > 1`, the selector now always chooses `MULTI_CTA`, which handles the `search_width` parameter correctly with its own internal `mc_search_width`.
Test

Added `test_cagra_batch_recall.py` — a Python regression test that:
- builds an index and brute-force ground truth in a shared fixture
- checks recall consistency across different batch sizes
- covers `search_width` values of 1, 4, and 8

Checklist