Skip to content

Commit 129ee4f

Browse files
added the fix that also checks whether the kernel function pointer has changed
1 parent 2c71899 commit 129ee4f

File tree

1 file changed

+36
-22
lines changed

1 file changed

+36
-22
lines changed

cpp/src/neighbors/detail/smem_utils.cuh

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,25 @@
66

77
#include <raft/core/error.hpp>
88

9-
#include <atomic>
109
#include <cstdint>
1110
#include <mutex>
1211

1312
namespace cuvs::neighbors::detail {
1413

1514
/**
1615
* @brief (Thread-)Safely invoke a kernel with a maximum dynamic shared memory size.
17-
* This is required because the sequence `cudaFuncSetAttribute` + kernel launch is not executed
18-
* atomically.
1916
*
20-
* Used this way, the cudaFuncAttributeMaxDynamicSharedMemorySize can only grow and thus
21-
* guarantees that the kernel is safe to launch.
17+
* Maintains a monotonically growing high-water mark for `cudaFuncAttributeMaxDynamicSharedMemorySize`.
18+
* When the kernel function pointer changes, the new kernel is brought up to the current high-water
19+
* mark; when smem_size exceeds the high-water mark, it is grown for the current kernel.
20+
* This guarantees every kernel's attribute is always >= smem_size at the time of launch.
21+
*
22+
* NB: cudaFuncSetAttribute is per kernel function pointer value, not per type. Multiple kernel
23+
* template instantiations may share the same KernelT type (e.g. function pointers with the same
24+
* signature), so we track the kernel identity alongside the smem high-water mark.
2225
*
2326
* @tparam KernelT The type of the kernel.
24-
* @tparam InvocationT The type of the invocation function.
27+
* @tparam KernelLauncherT The type of the launch function/lambda.
2528
* @param kernel The kernel function address (for whom the smem-size is specified).
2629
* @param smem_size The size of the dynamic shared memory to be set.
2730
* @param launch The kernel launch function/lambda.
@@ -31,31 +34,42 @@ void safely_launch_kernel_with_smem_size(KernelT const& kernel,
3134
uint32_t smem_size,
3235
KernelLauncherT const& launch)
3336
{
34-
// the last smem size is parameterized by the kernel thanks to the template parameter.
35-
static std::atomic<uint32_t> current_smem_size{0};
36-
auto last_smem_size = current_smem_size.load(std::memory_order_relaxed);
37-
if (smem_size > last_smem_size) {
38-
// We still need a mutex for the critical section: actualize last_smem_size and set the
39-
// attribute.
40-
static auto mutex = std::mutex{};
41-
auto guard = std::lock_guard<std::mutex>{mutex};
42-
if (!current_smem_size.compare_exchange_strong(
43-
last_smem_size, smem_size, std::memory_order_relaxed, std::memory_order_relaxed)) {
44-
// The value has been updated by another thread between the load and the mutex acquisition.
45-
if (smem_size > last_smem_size) {
46-
current_smem_size.store(smem_size, std::memory_order_relaxed);
47-
}
37+
// current_smem_size is a monotonically growing high-water mark across all kernel pointers.
38+
// current_kernel tracks which kernel pointer was last used.
39+
static uint32_t current_smem_size{0};
40+
static KernelT current_kernel{KernelT{}};
41+
static std::mutex mutex;
42+
43+
{
44+
std::lock_guard<std::mutex> guard(mutex);
45+
46+
auto last_kernel = current_kernel;
47+
auto last_smem_size = current_smem_size;
48+
49+
// When the kernel function pointer changes, bring the new kernel up to the global high-water
50+
// mark. This is necessary because cudaFuncSetAttribute applies to a specific function pointer,
51+
// not to the pointer type — different template instantiations may share the same KernelT.
52+
if (kernel != last_kernel) {
53+
current_kernel = kernel;
54+
auto launch_status =
55+
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, last_smem_size);
56+
RAFT_EXPECTS(launch_status == cudaSuccess,
57+
"Failed to set max dynamic shared memory size to %u bytes",
58+
last_smem_size);
4859
}
49-
// Only update if the last seen value is smaller than the new one.
60+
// When smem_size exceeds the high-water mark, grow it for the current kernel.
61+
// If the kernel also changed above, this handles the case where smem_size > last_smem_size.
5062
if (smem_size > last_smem_size) {
63+
current_smem_size = smem_size;
5164
auto launch_status =
5265
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
5366
RAFT_EXPECTS(launch_status == cudaSuccess,
5467
"Failed to set max dynamic shared memory size to %u bytes",
5568
smem_size);
5669
}
5770
}
58-
// We don't need to guard the kernel launch because the smem_size can only grow.
71+
// The kernel launch is outside the lock: any concurrent cudaFuncSetAttribute can only increase
72+
// the limit, so the launch is always safe.
5973
return launch(kernel);
6074
}
6175

0 commit comments

Comments
 (0)