Skip to content

Commit 5861270

Browse files
split one mutex into two atomics
1 parent 5499e37 commit 5861270

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

cpp/src/neighbors/detail/smem_utils.cuh

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#pragma once
66

77
#include <raft/core/error.hpp>
8-
8+
#include <atomic>
99
#include <cstdint>
1010
#include <mutex>
1111

@@ -37,25 +37,36 @@ void safely_launch_kernel_with_smem_size(KernelT const& kernel,
3737
{
3838
// last_smem_size is a monotonically growing high-water mark across all kernel pointers.
3939
// last_kernel tracks which kernel pointer was last used.
40-
static uint32_t last_smem_size{0};
41-
static KernelT last_kernel{KernelT{}};
40+
static std::atomic<uint32_t> last_smem_size{0};
41+
static std::atomic<KernelT> last_kernel{KernelT{}};
4242
static std::mutex mutex;
43-
43+
bool updated_needed = false;
44+
// When the kernel function pointer changes, bring the new kernel up to the global high-water
45+
// mark. This is necessary because cudaFuncSetAttribute applies to a specific function pointer,
46+
// not to the pointer type — different template instantiations may share the same KernelT.
47+
if (kernel != last_kernel.load(std::memory_order_relaxed))
48+
{
49+
last_kernel.store(kernel, std::memory_order_relaxed);
50+
updated_needed = true;
51+
}
52+
// Since we first read the kernel pointer, and the shem_size can only grow,
53+
// reading an inconsistent state is safe. At worst we will use a larger smem_size
54+
uint32_t cur_smem_size = last_smem_size.load(std::memory_order_relaxed);
55+
if (smem_size > cur_smem_size)
56+
{
57+
last_smem_size.store(smem_size, std::memory_order_relaxed);
58+
cur_smem_size = smem_size;
59+
updated_needed = true;
60+
}
61+
// Mutex-protected cudaFuncSetAttribute
62+
if (updated_needed)
4463
{
4564
std::lock_guard<std::mutex> guard(mutex);
46-
47-
// When the kernel function pointer changes, bring the new kernel up to the global high-water
48-
// mark. This is necessary because cudaFuncSetAttribute applies to a specific function pointer,
49-
// not to the pointer type — different template instantiations may share the same KernelT.
50-
last_kernel = kernel != last_kernel ? kernel : last_kernel;
51-
// When smem_size exceeds the high-water mark, grow it for the current kernel.
52-
// If the kernel also changed above, this handles the case where smem_size > last_smem_size.
53-
last_smem_size = smem_size > last_smem_size ? smem_size : last_smem_size;
5465
auto launch_status =
55-
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, last_smem_size);
66+
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, cur_smem_size);
5667
RAFT_EXPECTS(launch_status == cudaSuccess,
5768
"Failed to set max dynamic shared memory size to %u bytes",
58-
last_smem_size);
69+
cur_smem_size);
5970
}
6071
// The kernel launch is outside the lock: any concurrent cudaFuncSetAttribute can only increase
6172
// the limit, so the launch is always safe.

0 commit comments

Comments
 (0)