11/*
2- * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
2+ * SPDX-FileCopyrightText: Copyright (c) 2024-2026 , NVIDIA CORPORATION.
33 * SPDX-License-Identifier: Apache-2.0
44 */
55
99
1010#include < raft/core/resource/cuda_stream.hpp>
1111#include < raft/core/resource/thrust_policy.hpp>
12+ #include < raft/util/cuda_dev_essentials.cuh>
1213
1314#include < cuco/static_map.cuh>
1415#include < thrust/copy.h>
1516#include < thrust/iterator/counting_iterator.h>
1617
18+ #include < cooperative_groups.h>
19+ #include < rmm/device_uvector.hpp>
20+
// This is needed by cuco because keys and values must be bitwise comparable.
// Compilers do not declare float/double as bitwise comparable,
// but that is too strict.
@@ -32,11 +36,19 @@ namespace sparse {
3236template <typename value_idx, typename value_t , int tpb>
3337class hash_strategy : public coo_spmv_strategy <value_idx, value_t , tpb> {
3438 public:
35- using insert_type = typename cuco::legacy::
36- static_map<value_idx, value_t , cuda::thread_scope_block>::device_mutable_view;
37- using smem_type = typename insert_type::slot_type*;
38- using find_type =
39- typename cuco::legacy::static_map<value_idx, value_t , cuda::thread_scope_block>::device_view;
39+ static constexpr value_idx empty_key_sentinel = value_idx{-1 };
40+ static constexpr value_t empty_value_sentinel = value_t {0 };
41+ using probing_scheme_type = cuco::linear_probing<1 , cuco::murmurhash3_32<value_idx>>;
42+ using storage_ref_type =
43+ cuco::bucket_storage_ref<cuco::pair<value_idx, value_t >, 1 , cuco::extent<int >>;
44+ using map_type = cuco::static_map_ref<value_idx,
45+ value_t ,
46+ cuda::thread_scope_block,
47+ cuda::std::equal_to<value_idx>,
48+ probing_scheme_type,
49+ storage_ref_type,
50+ cuco::op::insert_tag,
51+ cuco::op::find_tag>;
4052
4153 hash_strategy (const distances_config_t <value_idx, value_t >& config_,
4254 float capacity_threshold_ = 0.5 ,
@@ -220,32 +232,35 @@ class hash_strategy : public coo_spmv_strategy<value_idx, value_t, tpb> {
220232 }
221233 }
222234
// Builds a block-scoped cuco map reference over caller-provided slot storage
// and initializes it with the calling thread block as the cooperating group.
//
// storage:    raw pointer to the slot array; it is reinterpreted as
//             storage_ref_type::value_type* (cuco::pair<value_idx, value_t>).
// cache_size: number of slots handed to the storage extent.
// returns:    the initialized reference; map_type carries both the insert and
//             the find operator tags, so the same ref serves both operations.
//
// NOTE(review): initialize() receives this_thread_block(), so presumably every
// thread of the block has to reach this call -- confirm against the cuco docs.
__device__ inline map_type init_map(void* storage, const value_idx& cache_size)
{
  // The extent is int-sized. Convert explicitly: brace-initialization rejects
  // an implicit narrowing conversion when value_idx is wider than int.
  auto const num_slots = cuco::extent<int>{static_cast<int>(cache_size)};
  auto* slots          = static_cast<typename storage_ref_type::value_type*>(storage);

  auto map_ref = map_type{cuco::empty_key<value_idx>{empty_key_sentinel},
                          cuco::empty_value<value_t>{empty_value_sentinel},
                          cuda::std::equal_to<value_idx>{},
                          probing_scheme_type{},
                          cuco::cuda_thread_scope<cuda::thread_scope_block>{},
                          storage_ref_type{num_slots, slots}};
  map_ref.initialize(cooperative_groups::this_thread_block());

  return map_ref;
}
231249
// Inserts the (key, value) pair through the given map reference.
// Whatever cuco's insert() reports back is deliberately not inspected.
__device__ inline void insert(map_type& map_ref, const value_idx& key, const value_t& value)
{
  auto const entry = cuco::pair<value_idx, value_t>(key, value);
  map_ref.insert(entry);
}
236254
237- __device__ inline find_type init_find (smem_type cache, const value_idx& cache_size)
238- {
239- return find_type (
240- cache, cache_size, cuco::empty_key{value_idx{-1 }}, cuco::empty_value{value_t {0 }});
241- }
255+ // Note: init_find is now merged with init_map since the new API uses the same ref for both
256+ // operations
242257
// Looks up `key` through the given map reference.
// Returns the stored value when find() yields something other than end(),
// and 0.0 (converted to value_t) otherwise.
__device__ inline value_t find(map_type& map_ref, const value_idx& key)
{
  auto slot = map_ref.find(key);

  return (slot != map_ref.end()) ? slot->second : value_t(0.0);
}
251266
@@ -271,7 +286,7 @@ class hash_strategy : public coo_spmv_strategy<value_idx, value_t, tpb> {
// Number of cuco::pair<value_idx, value_t> slots that fit in the per-block
// shared memory reported by raft::getSharedMemPerBlock(), after setting aside
// one value_t for each warp of the block (tpb / warp_size).
// NOTE(review): the per-warp reservation is presumably scratch space used by
// the kernels of this strategy -- confirm against the callers.
inline static int get_map_size()
{
  auto const reserved_bytes = (tpb / raft::warp_size()) * sizeof(value_t);
  auto const usable_bytes   = raft::getSharedMemPerBlock() - reserved_bytes;

  return usable_bytes / sizeof(cuco::pair<value_idx, value_t>);
}
276291
277292 private:
0 commit comments