// --- Vector Comparison ---
// Element-wise equality of two vectors of length `dim`.
// a, b: pointers to the start of the vectors. Both element types are
//       template parameters so callers may mix cv-qualifications
//       (e.g. `const volatile int*` vs `const int*`) on either side;
//       the original only templated the first argument.
// dim:  dimension of the vectors; dim <= 0 compares equal (vacuous truth).
// Returns true iff a[i] == b[i] for all i in [0, dim).
template <typename T, typename U>
__device__ inline bool vec_equal(const T* a, const U* b, int dim) {
  for (int i = 0; i < dim; ++i) {
    if (a[i] != b[i]) {
      return false;
    }
  }
  return true;
}
@@ -142,63 +143,93 @@ __device__ inline void set_expand_status(int* status_ptr, ExpandStatus new_statu
142143 atomicCAS (status_ptr, kExpandStatusSuccess , static_cast <int >(new_status));
143144}
144145
// --- Helper for finding or claiming a slot ---
// Lock-free lookup/claim over an open-addressed hash table with linear
// probing. The table stores `capacity` interleaved int pairs in `table_kvs`:
// pair i lives at indices (2*i, 2*i + 1). A key field of -1 marks an empty
// slot; a claimed slot holds `slot` in the key field, and the value field
// holds the entry's index into `vector_keys` (-1 while the claiming thread
// has not yet published it).
//
// table_kvs:      table of (key-marker, vector-index) int pairs.
// vector_keys:    backing key storage, `key_dim` ints per entry.
// key:            key to look up / claim a slot for.
// key_dim:        number of ints per key.
// capacity:       number of slots in `table_kvs`.
// found_existing: out-param; set to true iff `key` was found already present.
//
// Returns: the freshly claimed (previously empty) slot index on success;
// -1 when the key already exists (*found_existing == true) or when the probe
// sequence wrapped without finding a free slot (*found_existing == false).
// The caller that receives a claimed slot is responsible for publishing the
// value field (or rolling the claim back).
template <typename HashFuncT>
__device__ inline int claim_slot_or_find(int* table_kvs,
                                         const int* vector_keys,
                                         const int* key,
                                         int key_dim,
                                         int capacity,
                                         bool* found_existing) {
  int slot = HashFuncT::hash(key, key_dim, capacity);
  int initial_slot = slot;
  int attempts = 0;
  *found_existing = false;

  while (attempts < capacity) {
    int* slot_address = &table_kvs[slot * 2];
    // Atomically claim the key field iff it is still empty (-1).
    int prev = atomicCAS(slot_address, -1, slot);

    if (prev == -1) {
      return slot;  // Successfully claimed
    }

    // Slot occupied (or reserved by another thread just now)
    // Use volatile to ensure we read the latest value from memory
    volatile int* slot_value_ptr = &table_kvs[slot * 2 + 1];
    int vector_index = *slot_value_ptr;

    if (vector_index < 0) {
      // Another thread is writing to this slot. Retry without advancing.
      // NOTE(review): this retry is unbounded; liveness relies on the owning
      // thread eventually publishing an index or rolling the claim back.
      continue;
    }

    const volatile int* existing_key = &vector_keys[vector_index * key_dim];
    if (vec_equal(existing_key, key, key_dim)) {
      *found_existing = true;
      return -1;  // Key found
    }

    // Collision with different key
    slot = (slot + 1) % capacity;
    if (slot == initial_slot) {
      break;  // Probed every slot once; table is effectively full.
    }
    attempts++;
  }
  return -1;  // Table full or not found (and couldn't claim)
}
192+
// Inserts `candidate_key` into the hash table / key vector if it is not
// already present.
//
// table_kvs:       open-addressed table of (key-marker, vector-index) pairs.
// vector_keys:     backing key storage, `key_dim` ints per entry.
// candidate_key:   key to insert.
// key_dim:         number of ints per key.
// table_capacity:  number of slots in `table_kvs`.
// vector_capacity: maximum number of entries in `vector_keys`.
// num_entries_ptr: running count of appended entries (atomicAdd'd).
// status_ptr:      expand-status word updated on table-full / overflow.
template <typename HashFuncT>
__device__ inline void insert_candidate_if_absent(int* table_kvs,
                                                  int* vector_keys,
                                                  const int* candidate_key,
                                                  int key_dim,
                                                  int table_capacity,
                                                  int vector_capacity,
                                                  int* num_entries_ptr,
                                                  int* status_ptr) {
  bool found = false;
  int slot = claim_slot_or_find<HashFuncT>(
      table_kvs, vector_keys, candidate_key, key_dim, table_capacity, &found);

  if (found) {
    return;  // Already present
  }

  if (slot == -1) {
    // Table full
    set_expand_status(status_ptr, kExpandStatusTableFull);
    return;
  }

  // Slot reserved; now allocate an index in the key vector.
  int new_index = atomicAdd(num_entries_ptr, 1);
  if (new_index >= vector_capacity) {
    // Roll back the reservation and flag overflow. The value field was never
    // published, so it is still -1; make that store globally visible BEFORE
    // releasing the key field. (Releasing the key field first and writing -1
    // afterwards — as the previous code did — lets another thread re-claim
    // the slot and publish its index in between, which the late -1 store
    // would then clobber, leaving readers spinning on -1 forever.)
    table_kvs[slot * 2 + 1] = -1;
    __threadfence();
    atomicExch(&table_kvs[slot * 2], -1);
    set_expand_status(status_ptr, kExpandStatusVectorOverflow);
    return;
  }

  int* dst = &vector_keys[new_index * key_dim];
  for (int d = 0; d < key_dim; ++d) {
    dst[d] = candidate_key[d];
  }

  // Publish: the key payload must be globally visible before the index, so a
  // reader that observes `new_index` also observes the complete key.
  __threadfence();
  table_kvs[slot * 2 + 1] = new_index;
}
203234
204235// --- Device Function for Hash Table Search ---
@@ -271,40 +302,20 @@ __global__ void insert_kernel_templated(
271302 }
272303
273304 const int * key_to_insert = &vector_keys[idx * key_dim];
274- // Use the templated hash function directly
275- int slot = HashFuncT::hash (key_to_insert, key_dim, table_capacity);
276- int initial_slot = slot;
277- int attempts = 0 ;
278-
279- while (attempts < table_capacity) {
280- int * slot_address = &table_kvs[slot * 2 ];
281- // Store the *original index* (idx) in the compare field, not the slot.
282- // This prevents overwriting if two different keys hash to the same slot initially.
283- // We are essentially using the first element of the pair to *reserve* the slot
284- // via atomicCAS, and the second to store the value (original index).
285- // We store the actual index idx+1 temporarily to distinguish from initial -1.
286- // Let's refine this: Store 'slot' in compare field as originally, seems simpler.
287- int prev = atomicCAS (slot_address, -1 , slot); // Try to claim the slot marker
288-
289- if (prev == -1 ) {
290- // Slot claimed successfully, now store the actual value index
291- table_kvs[slot * 2 + 1 ] = idx;
292- // Optional: store the actual hash value in table_kvs[slot*2 + 0] = slot;
293- // Already done by atomicCAS if successful.
294- return ;
295- }
305+ bool found = false ;
306+ int slot = claim_slot_or_find<HashFuncT>(
307+ table_kvs, vector_keys, key_to_insert, key_dim, table_capacity, &found);
296308
297- // Collision or slot already claimed
298- slot = (slot + 1 ) % table_capacity;
309+ if (found) {
310+ return ; // Already present (deduplication)
311+ }
299312
300- if (slot == initial_slot) {
301- // Table is full or couldn't find an empty slot after full circle
302- // Consider adding a mechanism to signal failure if needed.
303- return ;
304- }
305- attempts++;
313+ if (slot != -1 ) {
314+ // Claimed successfully, store the index
315+ table_kvs[slot * 2 + 1 ] = idx;
306316 }
307- // Exceeded attempts (should only happen if table is pathologically full)
317+ // If slot == -1 and !found, table is full (fail silently as per original logic, or could add
318+ // error handling)
308319}
309320
310321// --- Templated Search Kernel ---
0 commit comments