
fix(cagra): prevent recall degradation with search_width > 1 at large batch sizes#1841

Open
zbennett10 wants to merge 7 commits into rapidsai:main from zbennett10:fix/cagra-batch-recall-1187

Conversation

@zbennett10
Contributor

Closes #1187

Description

Fixes inconsistent CAGRA search recall when query batch size changes, specifically when search_width > 1.

Root cause

The AUTO algorithm selector in search_plan.cuh switches from MULTI_CTA to SINGLE_CTA when itopk_size <= 512 && max_queries >= num_sm * 2. The problem is that these two algorithms compute max_iterations very differently when search_width > 1:

  • MULTI_CTA: Uses its own internal mc_search_width=1, so max_iterations = mc_itopk_size / mc_search_width = 32/1 = 32, then adds min_iterations (typically 6). Total: ~38 iterations.
  • SINGLE_CTA: Uses the user's search_width directly, so max_iterations = itopk_size / search_width. With search_width=8 and itopk_size=64: max_iterations = 64/8 = 8, plus min_iterations (6). Total: ~14 iterations.

This means that at large batch sizes (any batch of at least 2 × num_sm queries, e.g., 512 queries on a GPU with 256 SMs), AUTO switches to SINGLE_CTA, which gets ~2.7x fewer iterations than the MULTI_CTA path used at smaller batch sizes, causing a significant recall drop.
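For illustration, the two iteration formulas can be modeled in a few lines of Python. This is a sketch, not the actual search_plan.cuh code; the constants (mc_itopk_size=32, mc_search_width=1, min_iterations=6) are taken from the description above:

```python
# Toy model of the two max_iterations formulas described above.
# Constants come from this PR description, not from the real code.
MC_ITOPK_SIZE = 32      # MULTI_CTA's hardcoded internal itopk size
MC_SEARCH_WIDTH = 1     # MULTI_CTA's hardcoded internal search width
MIN_ITERATIONS = 6      # typical min_iterations value per the description

def multi_cta_iterations() -> int:
    # MULTI_CTA ignores the user's search_width for this calculation
    return MC_ITOPK_SIZE // MC_SEARCH_WIDTH + MIN_ITERATIONS

def single_cta_iterations(itopk_size: int, search_width: int) -> int:
    # SINGLE_CTA divides by the user's search_width directly
    return itopk_size // search_width + MIN_ITERATIONS

print(multi_cta_iterations())        # 38
print(single_cta_iterations(64, 8))  # 14
```

The ~2.7x gap (38 vs 14 iterations) falls directly out of these two formulas.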

Fix

Added search_width <= 1 as an additional condition for selecting SINGLE_CTA in the AUTO path. When search_width > 1, the selector now always chooses MULTI_CTA, which handles the search_width parameter correctly with its own internal mc_search_width.
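The before/after selector behavior can be sketched as a toy Python model (again, a sketch of the condition described above, not the C++ implementation; function names are hypothetical):

```python
# Toy model of the AUTO selector condition described in this PR.
# `use_single_cta_patched` adds the new `search_width <= 1` clause.
def use_single_cta(itopk_size: int, max_queries: int, num_sm: int) -> bool:
    return itopk_size <= 512 and max_queries >= num_sm * 2

def use_single_cta_patched(itopk_size: int, max_queries: int,
                           num_sm: int, search_width: int) -> bool:
    return use_single_cta(itopk_size, max_queries, num_sm) and search_width <= 1

# Before: a large batch flips to SINGLE_CTA regardless of search_width.
print(use_single_cta(64, 512, 40))             # True
# After: search_width=8 keeps MULTI_CTA even at large batches...
print(use_single_cta_patched(64, 512, 40, 8))  # False
# ...while search_width=1 still gets SINGLE_CTA's throughput benefit.
print(use_single_cta_patched(64, 512, 40, 1))  # True
```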

Test

Added test_cagra_batch_recall.py — a Python regression test that:

  • Tests search_width values of 1, 4, and 8
  • Compares recall across batch sizes [64, 256, 512, 1024]
  • Uses single-query ground truth (always MULTI_CTA) as reference
  • Asserts recall std < 0.02 and recall range < 0.05 across batch sizes
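The consistency assertion boils down to a check like the following sketch (the recall values here are hypothetical placeholders; the real test computes them via CAGRA search against the ground truth):

```python
# Sketch of the consistency assertion with hypothetical recall values.
import numpy as np

recalls_by_batch = {64: 0.750, 256: 0.755, 512: 0.752, 1024: 0.748}
values = np.array(list(recalls_by_batch.values()))

# Recall must be stable across batch sizes.
assert values.std() < 0.02, f"recall std too high: {values.std():.4f}"
assert values.max() - values.min() < 0.05, "recall range too wide"
```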

Checklist

  • I am familiar with the Contributing Guidelines
  • New or existing tests cover these changes
  • The documentation is up to date with these changes

…ith search_width > 1

When the AUTO algorithm selector chooses SINGLE_CTA for large batch
sizes, the max_iterations calculation uses itopk_size / search_width.
With search_width > 1, this produces significantly fewer iterations
than MULTI_CTA (which uses internal mc_search_width=1), causing
recall to drop from ~0.30 to ~0.12 at batch_size >= 512.

The fix adds a search_width <= 1 condition to the AUTO selector so
that SINGLE_CTA is only chosen when search_width won't reduce its
iteration count below MULTI_CTA's equivalent. When search_width > 1,
MULTI_CTA is preferred since it handles parallelism via multiple CTAs
per query rather than reduced iterations.

Add Python regression test for batch recall consistency across
different batch sizes and search_width values.

Closes rapidsai#1187
@copy-pr-bot

copy-pr-bot bot commented Feb 23, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@aamijar aamijar added non-breaking Introduces a non-breaking change Rust improvement Improves an existing functionality labels Feb 23, 2026
@aamijar
Member

aamijar commented Feb 23, 2026

/ok to test 197a261

@aamijar
Member

aamijar commented Feb 23, 2026

/ok to test bfc9163

device_ndarray from pylibraft does not support Python slice indexing.
Slice the host NumPy array first, then wrap with device_ndarray()
to transfer each batch to the GPU.
@aamijar
Member

aamijar commented Feb 24, 2026

/ok to test 7840309

@achirkin
Contributor

achirkin commented Feb 24, 2026

Thank you for bringing this problem up and for the very descriptive proposal.
Before merging this in, however, I'd like to see a bit more of performance analysis. In the gist, the problem is that we apparently have very different throughput/recall default configs for the SINGLE_CTA and MULTI_CTA algorithms. This is especially misleading for the users due to automatic switching between the two algorithms. What I miss so far is comparison of throughput/recall curves before/after the change. From the description I assume we have a sharp drop in recall when switching to SINGLE_CTA from MULTI_CTA; does it come with a sharp increase in throughput?; maybe just changing other parameter defaults would fix that (i.e. increasing the number of iterations)? Normally we assume SINGLE_CTA overall performs better than MULTI_CTA for big enough batches, hence I hesitate to constrain the usage of SINGLE_CTA over MULTI_CTA right away.

@zbennett10
Contributor Author

zbennett10 commented Feb 24, 2026

Thank you for bringing this problem up and for the very descriptive proposal. Before merging this in, however, I'd like to see a bit more performance analysis. The gist of the problem is that we apparently have very different throughput/recall default configs for the SINGLE_CTA and MULTI_CTA algorithms. This is especially misleading for users due to the automatic switching between the two algorithms. What I'm missing so far is a comparison of throughput/recall curves before/after the change. From the description I assume we have a sharp drop in recall when switching from MULTI_CTA to SINGLE_CTA; does it come with a sharp increase in throughput? Maybe just changing other parameter defaults would fix that (i.e., increasing the number of iterations)? Normally we assume SINGLE_CTA overall performs better than MULTI_CTA for big enough batches, hence I hesitate to constrain the usage of SINGLE_CTA over MULTI_CTA right away.

@achirkin thank you for the feedback - I will look at this and get back to you with a benchmark or two and a possible alternative to what I've implemented here.

I am thinking that instead of blocking SINGLE_CTA we can floor its iterations at MULTI_CTA's base:
_max_iterations = std::max(itopk_size / search_width, (uint32_t)32);

This, I think, would preserve SINGLE_CTA's throughput advantage for large batches while ensuring it gets enough iterations. I'll explore it and write a benchmark of some sort.

…ests

Add benchmarks/cagra_recall_throughput_bench.py comparing MULTI_CTA,
SINGLE_CTA (default), SINGLE_CTA (floor@32 iterations), and AUTO
across batch sizes and search widths. This provides the performance
data requested in review.

Rewrite test_cagra_batch_recall.py with:
- Module-scoped fixture for index build + brute-force ground truth
- test_cagra_batch_recall_consistency: recall must not vary across
  batch sizes (catches the AUTO algorithm switch cliff)
- test_cagra_search_width_recall_quality: minimum recall thresholds
  at batch_size=512 (catches SINGLE_CTA iteration deficit)
- test_cagra_search_width_monotonicity: higher search_width must not
  reduce recall (catches the core inversion bug)

Remove cupy dependency (use pylibraft.common.device_ndarray + numpy).
Fix dataset_device lifetime: keep device array alive in fixture to
prevent use-after-free in CAGRA index search.
@zbennett10
Contributor Author

zbennett10 commented Feb 24, 2026

@achirkin, just getting back to you. I've spent some time benchmarking both the current fix and an alternative approach. Here are the full results.

TL;DR

  • SINGLE_CTA is 2–3x faster than MULTI_CTA at large batch sizes
  • But with search_width > 1, SINGLE_CTA suffers a 0.19–0.30 recall drop that cannot be recovered by increasing iterations alone
  • The iteration-floor alternative fix (max(itopk_size/search_width, 32)) was tested and does not work — the recall gap is architectural, not just an iteration count issue
  • The original fix (block SINGLE_CTA when search_width > 1) seems to be the most correct approach.

Benchmark Setup

  • GPU: Tesla T4 (40 SMs, compute capability 7.5)
  • Dataset: 50,000 vectors, 64 dimensions, float32, normally distributed
  • Queries: 512 independent queries
  • Index: CAGRA with graph_degree=32, intermediate_graph_degree=64
  • Search: itopk_size=64, varying search_width and batch_size
  • Ground truth: Brute-force L2 nearest neighbors via NumPy
  • Recall metric: recall@10 (fraction of true top-10 neighbors found)
  • SINGLE_CTA threshold on T4: max_queries >= 2 * 40 = 80, so the switch happens at batch_size=128
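The recall@10 metric used throughout can be reproduced with a small NumPy brute-force ground truth, in the spirit of the setup above (a sketch with a smaller dataset for speed; the real benchmark compares CAGRA's results against this kind of reference):

```python
# Minimal recall@10 computation with a NumPy brute-force L2 ground truth.
import numpy as np

rng = np.random.default_rng(0)
dataset = rng.standard_normal((5000, 64)).astype(np.float32)  # 5k instead of 50k
queries = rng.standard_normal((512, 64)).astype(np.float32)
k = 10

# Brute-force squared-L2 distances: ||q||^2 - 2 q.d + ||d||^2 per pair.
q2 = (queries ** 2).sum(axis=1, keepdims=True)       # (512, 1)
d2 = (dataset ** 2).sum(axis=1)                      # (5000,)
dists = q2 - 2.0 * queries @ dataset.T + d2          # (512, 5000)
gt = np.argsort(dists, axis=1)[:, :k]                # true top-10 per query

def recall_at_k(found: np.ndarray, truth: np.ndarray) -> float:
    # Fraction of true top-k neighbors present in the returned top-k.
    hits = sum(len(set(f) & set(t)) for f, t in zip(found, truth))
    return hits / truth.size

# Sanity check: the ground truth scored against itself is perfect.
print(recall_at_k(gt, gt))  # 1.0
```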

The benchmark script (benchmarks/cagra_recall_throughput_bench.py) is included in this push and can be run on any GPU machine with cuVS installed.


1. Throughput/Recall: SINGLE_CTA vs MULTI_CTA

These numbers are from directly forcing algo="multi_cta" or algo="single_cta" to isolate each algorithm's behavior, collected on the same T4 GPU.

search_width=1 (no problem — both algorithms get the same iterations)

Batch Size | MULTI_CTA recall | MULTI_CTA QPS | SINGLE_CTA recall | SINGLE_CTA QPS
512        | 0.623            | 84,598        | 0.623             | 122,474

At search_width=1: itopk_size/search_width = 64/1 = 64 and mc_itopk_size/mc_search_width = 32/1 = 32. SINGLE_CTA actually gets more iterations than MULTI_CTA, so recall is equivalent. SINGLE_CTA delivers 1.4x throughput — exactly the expected behavior.

search_width=4 (recall cliff begins)

Batch Size | MULTI_CTA recall | MULTI_CTA QPS | SINGLE_CTA recall | SINGLE_CTA QPS
512        | 0.751            | 45,994        | 0.631             | 111,513

Recall delta: −0.120 (16% relative drop). SINGLE_CTA computes 64/4 = 16 base iterations vs MULTI_CTA's hardcoded 32/1 = 32.

search_width=8 (severe recall cliff)

Batch Size | MULTI_CTA recall | MULTI_CTA QPS | SINGLE_CTA recall | SINGLE_CTA QPS
512        | 0.873            | 25,136        | 0.647             | 73,064

Recall delta: −0.226 (26% relative drop). SINGLE_CTA computes 64/8 = 8 base iterations vs MULTI_CTA's 32.


2. Alternative Fix: Floor SINGLE_CTA iterations at 32

Per your suggestion, I tested whether increasing SINGLE_CTA's iterations to match MULTI_CTA's base count would recover recall while preserving throughput. The alternative:

_max_iterations = std::max(itopk_size / search_width, (uint32_t)32);

This gives SINGLE_CTA the exact same iteration count as MULTI_CTA (32 base + reachability = 36 total on this dataset).
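The effect of the floor on the base iteration count is simple arithmetic (a quick Python check of the formula above, with itopk_size=64 matching the benchmark setup):

```python
# Arithmetic check of the floor@32 formula with itopk_size=64.
def floored_iters(itopk_size: int, search_width: int) -> int:
    return max(itopk_size // search_width, 32)

print(floored_iters(64, 1))  # 64 -- unchanged for search_width=1
print(floored_iters(64, 4))  # 32 -- raised from 16
print(floored_iters(64, 8))  # 32 -- raised from 8
```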

Results with floor@32 (at batch_size=512):

Config                     | sw=4 recall | sw=4 QPS | sw=8 recall | sw=8 QPS
MULTI_CTA                  | 0.751       | 45,994   | 0.873       | 25,136
SINGLE_CTA (default iters) | 0.631       | 111,513  | 0.647       | 73,064
SINGLE_CTA (floor@32)      | 0.636       | 103,845  | 0.653       | 71,928

The floor@32 fix recovers almost no recall 😢. At search_width=8, even with the same 36 iterations, SINGLE_CTA gets 0.653 vs MULTI_CTA's 0.873, leaving a 0.220 recall gap.

Why the iteration floor doesn't work

The problem isn't just iteration count — it's architectural. Looking at search_plan.cuh:

  • MULTI_CTA uses hardcoded internal parameters: mc_itopk_size=32, mc_search_width=1. Each CTA independently explores one neighbor per iteration step, and results are merged across CTAs. The user-facing search_width only affects how many graph neighbors are explored per step, but each CTA processes them with mc_search_width=1.

  • SINGLE_CTA uses the user's search_width directly in its kernel. With search_width=8, each iteration step tries to explore 8 neighbors simultaneously within a single CTA. The parallelism model is fundamentally different — the wider search within a single CTA does not achieve the same exploration quality as multiple independent CTAs each exploring 1 neighbor.

Even at equal iteration counts, MULTI_CTA's independent-CTA exploration pattern produces better graph traversal coverage than SINGLE_CTA's wide-but-shallow exploration.


3. Recall Consistency Across Batch Sizes (the user-facing bug)

This is what users experience when the AUTO selector silently switches algorithms:

Unpatched (AUTO mode — current code):

search_width | bs=32  | bs=64  | bs=128 | bs=256 | bs=512 | std    | range
1            | 0.5408 | 0.5404 | 0.5377 | 0.5377 | 0.5377 | 0.0014 | 0.0031
4            | 0.7590 | 0.7545 | 0.5662 | 0.5662 | 0.5662 | 0.0933 | 0.1928
8            | 0.9021 | 0.9029 | 0.5988 | 0.5988 | 0.5988 | 0.1488 | 0.3041

At search_width=8, recall drops from 0.90 to 0.60 simply because the batch grew past the SINGLE_CTA threshold. Users tuning search_width to improve recall would see their recall get worse as their workload scales up.

Patched (force MULTI_CTA when search_width > 1):

search_width | bs=32  | bs=64  | bs=128 | bs=256 | bs=512 | std    | range
1            | 0.5408 | 0.5404 | 0.5377 | 0.5377 | 0.5377 | 0.0014 | 0.0031
4            | 0.7588 | 0.7533 | 0.7535 | 0.7562 | 0.7504 | 0.0029 | 0.0084
8            | 0.9031 | 0.9035 | 0.9041 | 0.9055 | 0.9041 | 0.0008 | 0.0023

Recall is perfectly consistent across all batch sizes. search_width=1 (the common case and the default) is completely unaffected — it still uses SINGLE_CTA at large batches and gets the throughput benefit.


4. Throughput Impact

The fix only affects search_width > 1 cases. For search_width=1, SINGLE_CTA continues to be used and delivers its throughput advantage.

For search_width > 1, yes, we lose SINGLE_CTA's throughput (2–3x faster). But the recall loss makes SINGLE_CTA results essentially unusable at search_width > 1 — users who set search_width=8 expect ~0.90 recall, not 0.60. The throughput gain is meaningless if the results are wrong.

A user who wants SINGLE_CTA's throughput with search_width > 1 can still explicitly set algo="single_cta" and accept the recall tradeoff knowingly. This fix only changes the AUTO selector's behavior.


5. Test Suite

The updated test_cagra_batch_recall.py includes three tests (7 parametrized cases):

  1. test_cagra_batch_recall_consistency (search_width=[1, 4, 8]): Recall must not vary more than std=0.02 / range=0.05 across batch sizes. Catches the AUTO algorithm switch cliff.

  2. test_cagra_search_width_recall_quality (sw=1→0.4, sw=4→0.6, sw=8→0.7): Minimum recall thresholds at batch_size=512. Catches SINGLE_CTA's iteration deficit producing abnormally low recall.

  3. test_cagra_search_width_monotonicity: Recall must not decrease as search_width increases (within tolerance). Catches the core inversion where higher search_width paradoxically reduces recall.
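The monotonicity check reduces to an assertion of this shape (recall values here are hypothetical placeholders; the real test derives them from CAGRA searches at batch_size=512):

```python
# Sketch of the monotonicity assertion with hypothetical recall values.
TOLERANCE = 0.02  # small slack for run-to-run noise

recall_by_search_width = {1: 0.54, 4: 0.75, 8: 0.90}

widths = sorted(recall_by_search_width)
for lo, hi in zip(widths, widths[1:]):
    assert recall_by_search_width[hi] >= recall_by_search_width[lo] - TOLERANCE, (
        f"recall dropped from search_width={lo} to {hi}"
    )
print("monotonicity holds")
```

On unpatched code, the AUTO switch at large batches makes the search_width=8 recall fall below search_width=4, which is exactly what this assertion catches.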

All 7 tests pass on patched code. On unpatched code, 3 fail (as expected).


The recall gap between SINGLE_CTA and MULTI_CTA at search_width > 1 seems to be architectural and not an iteration count issue. The only viable fix (that I can see, at least 😄) is to avoid SINGLE_CTA when search_width > 1. Do you have any other ideas?

Happy to discuss further or run additional experiments if helpful.

@aamijar
Member

aamijar commented Feb 24, 2026

/ok to test 4b0da50

zbennett10 and others added 2 commits February 24, 2026 14:01
Fix ruff format violations caught by pre-commit CI:
- Reflow long lines in benchmark and test files
- Add spaces around operators in f-string index expressions
@aamijar
Member

aamijar commented Feb 25, 2026

/ok to test 6671ae4

@aamijar
Member

aamijar commented Feb 25, 2026

/ok to test be584d8


Labels

C++ improvement Improves an existing functionality non-breaking Introduces a non-breaking change Python

Development

Successfully merging this pull request may close these issues.

[BUG] CAGRA search recall inconsistent with query batch size.
