Skip to content

Commit f10086f

Browse files
committed
robust hash insertion
1 parent cfcafa9 commit f10086f

File tree

2 files changed

+207
-67
lines changed

2 files changed

+207
-67
lines changed

tests/nn/test_sparse_generative_features.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
offsets_from_batch_index,
1111
)
1212
from warpconvnet.geometry.coords.ops.serialization import POINT_ORDERING, encode
13+
from warpconvnet.geometry.coords.ops.expand import expand_coords
1314
from warpconvnet.geometry.coords.ops.stride import stride_coords
1415
from warpconvnet.geometry.types.voxels import Voxels
1516

@@ -309,3 +310,131 @@ def test_generate_output_coords_transposed_generative(toy_voxels):
309310
# Kernel map should have valid structure
310311
assert kernel_map is not None
311312
assert len(kernel_map) > 0
313+
314+
315+
def test_large_scale_expand_uniqueness():
    """Ensure expand_coords does not produce duplicates on large input."""
    torch.manual_seed(0)
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    device = torch.device("cuda")

    # Build a large, dense, batch-indexed coordinate set: each row is
    # [batch_index, x, y, z] with coordinates drawn uniformly at random.
    total_points = 1000000
    coords_range = 60

    num_batches = 4
    points_per_batch = total_points // num_batches
    per_batch_rows = []
    for batch_idx in range(num_batches):
        # Random spatial coordinates for this batch
        spatial = torch.randint(
            -coords_range, coords_range, (points_per_batch, 3), device=device, dtype=torch.int32
        )
        batch_col = torch.full((points_per_batch, 1), batch_idx, device=device, dtype=torch.int32)
        per_batch_rows.append(torch.cat([batch_col, spatial], dim=1))

    batch_indexed_coords = torch.cat(per_batch_rows, dim=0)
    # Deduplicate the input so output uniqueness is a meaningful property.
    batch_indexed_coords = torch.unique(batch_indexed_coords, dim=0)

    kernel_size = (3, 3, 3)
    dilation = (1, 1, 1)

    out_coords, out_offsets = expand_coords(
        batch_indexed_coords, kernel_size=kernel_size, kernel_dilation=dilation
    )

    # The expanded output must contain no repeated coordinate rows.
    unique_out, counts = torch.unique(out_coords, dim=0, return_counts=True)
    if unique_out.shape[0] != out_coords.shape[0]:
        num_duplicates = out_coords.shape[0] - unique_out.shape[0]
        duplicate_examples = unique_out[counts > 1]
        max_dups = counts.max().item()
        pytest.fail(
            f"Found {num_duplicates} duplicate coordinates in expanded output. "
            f"Total: {out_coords.shape[0]}, Unique: {unique_out.shape[0]}. "
            f"Max duplicates for a single coord: {max_dups}. "
            f"Example duplicates: {duplicate_examples[:5]}"
        )
361+
362+
363+
def test_large_scale_transposed_generative_duplicates():
    """
    Reproduce duplicate coordinates issue with SpatiallySparseConv
    configured as transposed=True, generative=True, stride=(2,2,2).
    """
    torch.manual_seed(0)
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    device = torch.device("cuda")

    # Build a large, dense coordinate set.
    # N needs to be large enough to cause hash collisions or stress the table resizing
    N = 1000000
    coords_range = 50  # Very dense

    num_batches = 1  # Single batch to focus on collisions within one set
    rows = []
    for batch_idx in range(num_batches):
        spatial = torch.randint(
            -coords_range, coords_range, (N // num_batches, 3), device=device, dtype=torch.int32
        )
        # Ensure uniqueness within batch for valid input
        spatial = torch.unique(spatial, dim=0)
        batch_col = torch.full((spatial.shape[0], 1), batch_idx, device=device, dtype=torch.int32)
        rows.append(torch.cat([batch_col, spatial], dim=1))

    batch_indexed_coords = torch.cat(rows, dim=0)

    # Features are required for the forward pass even though only the
    # output coordinates are under test.
    features = torch.randn(batch_indexed_coords.shape[0], 16, device=device)

    # Split back into per-batch lists for the Voxels constructor.
    coords_per_batch = []
    feats_per_batch = []
    for batch_idx in range(num_batches):
        in_batch = batch_indexed_coords[:, 0] == batch_idx
        coords_per_batch.append(batch_indexed_coords[in_batch, 1:])
        feats_per_batch.append(features[in_batch])

    voxels = Voxels(
        batched_coordinates=coords_per_batch, batched_features=feats_per_batch, device=device
    )

    # The problematic configuration: transposed + generative + stride 2.
    in_channels = 16
    out_channels = 16
    kernel_size = (3, 3, 3)
    stride = (2, 2, 2)

    conv = SpatiallySparseConv(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        generative=True,
        transposed=True,
    ).to(device)

    out_voxels = conv(voxels)

    # Output coordinates must be unique.
    out_coords = out_voxels.batch_indexed_coordinates
    unique_out, counts = torch.unique(out_coords, dim=0, return_counts=True)

    if unique_out.shape[0] != out_coords.shape[0]:
        num_duplicates = out_coords.shape[0] - unique_out.shape[0]
        max_dups = counts.max().item()
        duplicate_examples = unique_out[counts > 1][:5]

        pytest.fail(
            f"Found {num_duplicates} duplicate coordinates in output.\n"
            f"Total output coords: {out_coords.shape[0]}\n"
            f"Unique output coords: {unique_out.shape[0]}\n"
            f"Max duplicates for a single coord: {max_dups}\n"
            f"Examples: {duplicate_examples}"
        )

warpconvnet/csrc/hashmap_kernels.cu

Lines changed: 78 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ struct MurmurHash {
120120
// --- Vector Comparison ---
121121
// a, b: pointers to the start of the vectors
122122
// dim: dimension of the vectors
123-
__device__ inline bool vec_equal(const int* a, const int* b, int dim) {
123+
template <typename T>
124+
__device__ inline bool vec_equal(const T* a, const int* b, int dim) {
124125
for (int i = 0; i < dim; ++i) {
125126
if (a[i] != b[i]) {
126127
return false;
@@ -142,63 +143,93 @@ __device__ inline void set_expand_status(int* status_ptr, ExpandStatus new_statu
142143
atomicCAS(status_ptr, kExpandStatusSuccess, static_cast<int>(new_status));
143144
}
144145

146+
// --- Helper for finding or claiming a slot ---
// Linearly probes the open-addressing table for `key`. Outcomes:
//   * claims an empty slot via atomicCAS and returns its index; the caller
//     owns the slot and must publish the value field (*found_existing false);
//   * the key is already present: returns -1 with *found_existing == true;
//   * the probe wraps around without success: returns -1 with
//     *found_existing == false (table full).
// table_kvs layout: two ints per slot — [key marker, vector index]. A key
// marker of -1 means "empty"; a negative vector index means "claimed but
// value not yet published" (see insert_candidate_if_absent).
template <typename HashFuncT>
__device__ inline int claim_slot_or_find(int* table_kvs,
                                         const int* vector_keys,
                                         const int* key,
                                         int key_dim,
                                         int capacity,
                                         bool* found_existing) {
  int slot = HashFuncT::hash(key, key_dim, capacity);
  int initial_slot = slot;
  int attempts = 0;
  *found_existing = false;

  while (attempts < capacity) {
    int* slot_address = &table_kvs[slot * 2];
    // Atomically claim the slot if its key marker is still -1 (empty).
    int prev = atomicCAS(slot_address, -1, slot);

    if (prev == -1) {
      return slot;  // Successfully claimed
    }

    // Slot occupied (or reserved by another thread just now)
    // Use volatile to ensure we read the latest value from memory
    volatile int* slot_value_ptr = &table_kvs[slot * 2 + 1];
    int vector_index = *slot_value_ptr;

    if (vector_index < 0) {
      // Another thread is writing to this slot. Retry without advancing.
      // NOTE(review): this spin assumes the claimant always publishes the
      // index or rolls the slot back; confirm no code path leaves a slot
      // permanently reserved, or this loop never terminates.
      continue;
    }

    // Published slot: compare the stored key against ours.
    const volatile int* existing_key = &vector_keys[vector_index * key_dim];
    if (vec_equal(existing_key, key, key_dim)) {
      *found_existing = true;
      return -1;  // Key found
    }

    // Collision with different key
    slot = (slot + 1) % capacity;
    if (slot == initial_slot) {
      break;
    }
    attempts++;
  }
  return -1;  // Table full or not found (and couldn't claim)
}
192+
193+
// Inserts candidate_key into the hash table and appends it to vector_keys
// unless an equal key is already present. On table exhaustion or vector
// overflow, records the failure via set_expand_status instead of inserting.
// table_kvs layout: two ints per slot — [key marker, vector index]; the
// vector index stays negative until the key payload is fully written.
template <typename HashFuncT>
__device__ inline void insert_candidate_if_absent(int* table_kvs,
                                                  int* vector_keys,
                                                  const int* candidate_key,
                                                  int key_dim,
                                                  int table_capacity,
                                                  int vector_capacity,
                                                  int* num_entries_ptr,
                                                  int* status_ptr) {
  bool found = false;
  int slot = claim_slot_or_find<HashFuncT>(
      table_kvs, vector_keys, candidate_key, key_dim, table_capacity, &found);

  if (found) {
    return;  // Already present
  }

  if (slot == -1) {
    // Probe wrapped around without claiming a slot: table full.
    set_expand_status(status_ptr, kExpandStatusTableFull);
    return;
  }

  // Slot reserved; allocate a row in vector_keys for the new key.
  int new_index = atomicAdd(num_entries_ptr, 1);
  if (new_index >= vector_capacity) {
    // Roll back the reservation and flag overflow.
    // Ordering is critical: clear the value field FIRST, fence, then release
    // the key marker. Releasing the key marker first would let another thread
    // re-claim this slot and publish a valid index which we would then
    // clobber with -1, corrupting the table (duplicate/lost entries).
    table_kvs[slot * 2 + 1] = -1;
    __threadfence();
    atomicExch(&table_kvs[slot * 2], -1);
    set_expand_status(status_ptr, kExpandStatusVectorOverflow);
    return;
  }

  // Copy the key payload before publishing its index.
  int* dst = &vector_keys[new_index * key_dim];
  for (int d = 0; d < key_dim; ++d) {
    dst[d] = candidate_key[d];
  }

  // Fence so the key bytes are globally visible before the index that points
  // at them; concurrent probers spin while the value field is negative.
  __threadfence();
  table_kvs[slot * 2 + 1] = new_index;
}
203234

204235
// --- Device Function for Hash Table Search ---
@@ -271,40 +302,20 @@ __global__ void insert_kernel_templated(
271302
}
272303

273304
const int* key_to_insert = &vector_keys[idx * key_dim];
274-
// Use the templated hash function directly
275-
int slot = HashFuncT::hash(key_to_insert, key_dim, table_capacity);
276-
int initial_slot = slot;
277-
int attempts = 0;
278-
279-
while (attempts < table_capacity) {
280-
int* slot_address = &table_kvs[slot * 2];
281-
// Store the *original index* (idx) in the compare field, not the slot.
282-
// This prevents overwriting if two different keys hash to the same slot initially.
283-
// We are essentially using the first element of the pair to *reserve* the slot
284-
// via atomicCAS, and the second to store the value (original index).
285-
// We store the actual index idx+1 temporarily to distinguish from initial -1.
286-
// Let's refine this: Store 'slot' in compare field as originally, seems simpler.
287-
int prev = atomicCAS(slot_address, -1, slot); // Try to claim the slot marker
288-
289-
if (prev == -1) {
290-
// Slot claimed successfully, now store the actual value index
291-
table_kvs[slot * 2 + 1] = idx;
292-
// Optional: store the actual hash value in table_kvs[slot*2 + 0] = slot;
293-
// Already done by atomicCAS if successful.
294-
return;
295-
}
305+
bool found = false;
306+
int slot = claim_slot_or_find<HashFuncT>(
307+
table_kvs, vector_keys, key_to_insert, key_dim, table_capacity, &found);
296308

297-
// Collision or slot already claimed
298-
slot = (slot + 1) % table_capacity;
309+
if (found) {
310+
return; // Already present (deduplication)
311+
}
299312

300-
if (slot == initial_slot) {
301-
// Table is full or couldn't find an empty slot after full circle
302-
// Consider adding a mechanism to signal failure if needed.
303-
return;
304-
}
305-
attempts++;
313+
if (slot != -1) {
314+
// Claimed successfully, store the index
315+
table_kvs[slot * 2 + 1] = idx;
306316
}
307-
// Exceeded attempts (should only happen if table is pathologically full)
317+
// If slot == -1 and !found, table is full (fail silently as per original logic, or could add
318+
// error handling)
308319
}
309320

310321
// --- Templated Search Kernel ---

0 commit comments

Comments
 (0)