66
77#include < raft/core/error.hpp>
88
9- #include < atomic>
109#include < cstdint>
1110#include < mutex>
1211
1312namespace cuvs ::neighbors::detail {
1413
1514/* *
1615 * @brief (Thread-)Safely invoke a kernel with a maximum dynamic shared memory size.
17- * This is required because the sequence `cudaFuncSetAttribute` + kernel launch is not executed
18- * atomically.
1916 *
20- * Used this way, the cudaFuncAttributeMaxDynamicSharedMemorySize can only grow and thus
21- * guarantees that the kernel is safe to launch.
17+ * Maintains a monotonically growing high-water mark for `cudaFuncAttributeMaxDynamicSharedMemorySize`.
18+ * When the kernel function pointer changes, the new kernel is brought up to the current high-water
19+ * mark; when smem_size exceeds the high-water mark, it is grown for the current kernel.
20+ * This guarantees every kernel's attribute is always >= smem_size at the time of launch.
21+ *
22+ * NB: cudaFuncSetAttribute is per kernel function pointer value, not per type. Multiple kernel
23+ * template instantiations may share the same KernelT type (e.g. function pointers with the same
24+ * signature), so we track the kernel identity alongside the smem high-water mark.
2225 *
2326 * @tparam KernelT The type of the kernel.
24- * @tparam InvocationT The type of the invocation function.
27+ * @tparam KernelLauncherT The type of the launch function/lambda.
2528 * @param kernel The kernel function address (for which the smem-size is specified).
2629 * @param smem_size The size of the dynamic shared memory to be set.
2730 * @param launch The kernel launch function/lambda.
@@ -31,31 +34,42 @@ void safely_launch_kernel_with_smem_size(KernelT const& kernel,
3134 uint32_t smem_size,
3235 KernelLauncherT const & launch)
3336{
34- // the last smem size is parameterized by the kernel thanks to the template parameter.
35- static std::atomic<uint32_t > current_smem_size{0 };
36- auto last_smem_size = current_smem_size.load (std::memory_order_relaxed);
37- if (smem_size > last_smem_size) {
38- // We still need a mutex for the critical section: actualize last_smem_size and set the
39- // attribute.
40- static auto mutex = std::mutex{};
41- auto guard = std::lock_guard<std::mutex>{mutex};
42- if (!current_smem_size.compare_exchange_strong (
43- last_smem_size, smem_size, std::memory_order_relaxed, std::memory_order_relaxed)) {
44- // The value has been updated by another thread between the load and the mutex acquisition.
45- if (smem_size > last_smem_size) {
46- current_smem_size.store (smem_size, std::memory_order_relaxed);
47- }
37+ // current_smem_size is a monotonically growing high-water mark across all kernel pointers.
38+ // current_kernel tracks which kernel pointer was last used.
39+ static uint32_t current_smem_size{0 };
40+ static KernelT current_kernel{KernelT{}};
41+ static std::mutex mutex;
42+
43+ {
44+ std::lock_guard<std::mutex> guard (mutex);
45+
46+ auto last_kernel = current_kernel;
47+ auto last_smem_size = current_smem_size;
48+
49+ // When the kernel function pointer changes, bring the new kernel up to the global high-water
50+ // mark. This is necessary because cudaFuncSetAttribute applies to a specific function pointer,
51+ // not to the pointer type — different template instantiations may share the same KernelT.
52+ if (kernel != last_kernel) {
53+ current_kernel = kernel;
54+ auto launch_status =
55+ cudaFuncSetAttribute (kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, last_smem_size);
56+ RAFT_EXPECTS (launch_status == cudaSuccess,
57+ " Failed to set max dynamic shared memory size to %u bytes" ,
58+ last_smem_size);
4859 }
49- // Only update if the last seen value is smaller than the new one.
60+ // When smem_size exceeds the high-water mark, grow it for the current kernel.
61+ // If the kernel also changed above, this handles the case where smem_size > last_smem_size.
5062 if (smem_size > last_smem_size) {
63+ current_smem_size = smem_size;
5164 auto launch_status =
5265 cudaFuncSetAttribute (kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
5366 RAFT_EXPECTS (launch_status == cudaSuccess,
5467 " Failed to set max dynamic shared memory size to %u bytes" ,
5568 smem_size);
5669 }
5770 }
58- // We don't need to guard the kernel launch because the smem_size can only grow.
71+ // The kernel launch is outside the lock: any concurrent cudaFuncSetAttribute can only increase
72+ // the limit, so the launch is always safe.
5973 return launch (kernel);
6074}
6175
0 commit comments