Skip to content

Commit dd2e749

Browse files
authored
Migrate hash strategy to use the new cuco::static_map (#1462)
This PR migrates usage from the legacy static_map to the new design, with no algorithmic changes. Authors: - Yunsong Wang (https://github.com/PointKernel) Approvers: - Divye Gala (https://github.com/divyegala) URL: #1462
1 parent 1915ccf commit dd2e749

File tree

3 files changed

+54
-45
lines changed

3 files changed

+54
-45
lines changed

cpp/src/distance/detail/sparse/coo_spmv_kernel.cuh

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

66
#pragma once
77

88
#include <raft/core/detail/macros.hpp>
9+
#include <raft/util/cuda_dev_essentials.cuh>
910

1011
#include <cub/block/block_load.cuh>
1112
#include <cub/block/block_radix_sort.cuh>
@@ -120,10 +121,11 @@ RAFT_KERNEL balanced_coo_generalized_spmv_kernel(strategy_t strategy,
120121

121122
extern __shared__ char smem[];
122123

123-
typename strategy_t::smem_type A = (typename strategy_t::smem_type)(smem);
124-
typename warp_reduce::TempStorage* temp_storage = (typename warp_reduce::TempStorage*)(A + dim);
124+
void* A = smem;
125+
typename warp_reduce::TempStorage* temp_storage =
126+
(typename warp_reduce::TempStorage*)((char*)A + dim);
125127

126-
auto inserter = strategy.init_insert(A, dim);
128+
auto map_ref = strategy.init_map(A, dim);
127129

128130
__syncthreads();
129131

@@ -134,13 +136,11 @@ RAFT_KERNEL balanced_coo_generalized_spmv_kernel(strategy_t strategy,
134136

135137
// Convert current row vector in A to dense
136138
for (int i = tid; i <= (stop_offset_a - start_offset_a); i += blockDim.x) {
137-
strategy.insert(inserter, indicesA[start_offset_a + i], dataA[start_offset_a + i]);
139+
strategy.insert(map_ref, indicesA[start_offset_a + i], dataA[start_offset_a + i]);
138140
}
139141

140142
__syncthreads();
141143

142-
auto finder = strategy.init_find(A, dim);
143-
144144
if (cur_row_a > m || cur_chunk_offset > n_blocks_per_row) return;
145145
if (ind >= nnz_b) return;
146146

@@ -166,7 +166,7 @@ RAFT_KERNEL balanced_coo_generalized_spmv_kernel(strategy_t strategy,
166166
auto in_bounds = indptrA.check_indices_bounds(start_index_a, stop_index_a, index_b);
167167

168168
if (in_bounds) {
169-
value_t a_col = strategy.find(finder, index_b);
169+
value_t a_col = strategy.find(map_ref, index_b);
170170
if (!rev || a_col == 0.0) { c = product_func(a_col, dataB[ind]); }
171171
}
172172
}
@@ -204,7 +204,7 @@ RAFT_KERNEL balanced_coo_generalized_spmv_kernel(strategy_t strategy,
204204
auto index_b = indicesB[ind];
205205
auto in_bounds = indptrA.check_indices_bounds(start_index_a, stop_index_a, index_b);
206206
if (in_bounds) {
207-
value_t a_col = strategy.find(finder, index_b);
207+
value_t a_col = strategy.find(map_ref, index_b);
208208

209209
if (!rev || a_col == 0.0) { c = accum_func(c, product_func(a_col, dataB[ind])); }
210210
}

cpp/src/distance/detail/sparse/coo_spmv_strategies/dense_smem_strategy.cuh

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

66
#pragma once
77

88
#include "base_strategy.cuh"
99

10-
#include <raft/util/cuda_dev_essentials.cuh> // raft::ceildiv
10+
#include <raft/util/cuda_dev_essentials.cuh>
1111

1212
namespace cuvs {
1313
namespace distance {
@@ -17,9 +17,7 @@ namespace sparse {
1717
template <typename value_idx, typename value_t, int tpb>
1818
class dense_smem_strategy : public coo_spmv_strategy<value_idx, value_t, tpb> {
1919
public:
20-
using smem_type = value_t*;
21-
using insert_type = smem_type;
22-
using find_type = smem_type;
20+
using map_type = value_t*;
2321

2422
dense_smem_strategy(const distances_config_t<value_idx, value_t>& config_)
2523
: coo_spmv_strategy<value_idx, value_t, tpb>(config_)
@@ -83,25 +81,21 @@ class dense_smem_strategy : public coo_spmv_strategy<value_idx, value_t, tpb> {
8381
n_blocks_per_row);
8482
}
8583

86-
__device__ inline insert_type init_insert(smem_type cache, const value_idx& cache_size)
84+
__device__ inline map_type init_map(void* storage, const value_idx& cache_size)
8785
{
86+
auto cache = static_cast<value_t*>(storage);
8887
for (int k = threadIdx.x; k < cache_size; k += blockDim.x) {
8988
cache[k] = 0.0;
9089
}
9190
return cache;
9291
}
9392

94-
__device__ inline void insert(insert_type cache, const value_idx& key, const value_t& value)
93+
__device__ inline void insert(map_type& cache, const value_idx& key, const value_t& value)
9594
{
9695
cache[key] = value;
9796
}
9897

99-
__device__ inline find_type init_find(smem_type cache, const value_idx& cache_size)
100-
{
101-
return cache;
102-
}
103-
104-
__device__ inline value_t find(find_type cache, const value_idx& key) { return cache[key]; }
98+
__device__ inline value_t find(map_type& cache, const value_idx& key) { return cache[key]; }
10599
};
106100

107101
} // namespace sparse

cpp/src/distance/detail/sparse/coo_spmv_strategies/hash_strategy.cuh

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -9,11 +9,15 @@
99

1010
#include <raft/core/resource/cuda_stream.hpp>
1111
#include <raft/core/resource/thrust_policy.hpp>
12+
#include <raft/util/cuda_dev_essentials.cuh>
1213

1314
#include <cuco/static_map.cuh>
1415
#include <thrust/copy.h>
1516
#include <thrust/iterator/counting_iterator.h>
1617

18+
#include <cooperative_groups.h>
19+
#include <rmm/device_uvector.hpp>
20+
1721
// this is needed by cuco as key, value must be bitwise comparable.
1822
// compilers don't declare float/double as bitwise comparable
1923
// but that is too strict
@@ -32,11 +36,19 @@ namespace sparse {
3236
template <typename value_idx, typename value_t, int tpb>
3337
class hash_strategy : public coo_spmv_strategy<value_idx, value_t, tpb> {
3438
public:
35-
using insert_type = typename cuco::legacy::
36-
static_map<value_idx, value_t, cuda::thread_scope_block>::device_mutable_view;
37-
using smem_type = typename insert_type::slot_type*;
38-
using find_type =
39-
typename cuco::legacy::static_map<value_idx, value_t, cuda::thread_scope_block>::device_view;
39+
static constexpr value_idx empty_key_sentinel = value_idx{-1};
40+
static constexpr value_t empty_value_sentinel = value_t{0};
41+
using probing_scheme_type = cuco::linear_probing<1, cuco::murmurhash3_32<value_idx>>;
42+
using storage_ref_type =
43+
cuco::bucket_storage_ref<cuco::pair<value_idx, value_t>, 1, cuco::extent<int>>;
44+
using map_type = cuco::static_map_ref<value_idx,
45+
value_t,
46+
cuda::thread_scope_block,
47+
cuda::std::equal_to<value_idx>,
48+
probing_scheme_type,
49+
storage_ref_type,
50+
cuco::op::insert_tag,
51+
cuco::op::find_tag>;
4052

4153
hash_strategy(const distances_config_t<value_idx, value_t>& config_,
4254
float capacity_threshold_ = 0.5,
@@ -220,32 +232,35 @@ class hash_strategy : public coo_spmv_strategy<value_idx, value_t, tpb> {
220232
}
221233
}
222234

223-
__device__ inline insert_type init_insert(smem_type cache, const value_idx& cache_size)
235+
__device__ inline map_type init_map(void* storage, const value_idx& cache_size)
224236
{
225-
return insert_type::make_from_uninitialized_slots(cooperative_groups::this_thread_block(),
226-
cache,
227-
cache_size,
228-
cuco::empty_key{value_idx{-1}},
229-
cuco::empty_value{value_t{0}});
237+
auto map_ref =
238+
map_type{cuco::empty_key<value_idx>{empty_key_sentinel},
239+
cuco::empty_value<value_t>{empty_value_sentinel},
240+
cuda::std::equal_to<value_idx>{},
241+
probing_scheme_type{},
242+
cuco::cuda_thread_scope<cuda::thread_scope_block>{},
243+
storage_ref_type{cuco::extent<int>{cache_size},
244+
static_cast<typename storage_ref_type::value_type*>(storage)}};
245+
map_ref.initialize(cooperative_groups::this_thread_block());
246+
247+
return map_ref;
230248
}
231249

232-
__device__ inline void insert(insert_type cache, const value_idx& key, const value_t& value)
250+
__device__ inline void insert(map_type& map_ref, const value_idx& key, const value_t& value)
233251
{
234-
auto success = cache.insert(cuco::pair<value_idx, value_t>(key, value));
252+
map_ref.insert(cuco::pair{key, value});
235253
}
236254

237-
__device__ inline find_type init_find(smem_type cache, const value_idx& cache_size)
238-
{
239-
return find_type(
240-
cache, cache_size, cuco::empty_key{value_idx{-1}}, cuco::empty_value{value_t{0}});
241-
}
255+
// Note: init_find is now merged with init_map since the new API uses the same ref for both
256+
// operations
242257

243-
__device__ inline value_t find(find_type cache, const value_idx& key)
258+
__device__ inline value_t find(map_type& map_ref, const value_idx& key)
244259
{
245-
auto a_pair = cache.find(key);
260+
auto a_pair = map_ref.find(key);
246261

247262
value_t a_col = 0.0;
248-
if (a_pair != cache.end()) { a_col = a_pair->second; }
263+
if (a_pair != map_ref.end()) { a_col = a_pair->second; }
249264
return a_col;
250265
}
251266

@@ -271,7 +286,7 @@ class hash_strategy : public coo_spmv_strategy<value_idx, value_t, tpb> {
271286
inline static int get_map_size()
272287
{
273288
return (raft::getSharedMemPerBlock() - ((tpb / raft::warp_size()) * sizeof(value_t))) /
274-
sizeof(typename insert_type::slot_type);
289+
sizeof(cuco::pair<value_idx, value_t>);
275290
}
276291

277292
private:

0 commit comments

Comments
 (0)