|
5 | 5 | #pragma once |
6 | 6 |
|
7 | 7 | #include <raft/core/error.hpp> |
8 | | - |
| 8 | +#include <atomic> |
9 | 9 | #include <cstdint> |
10 | 10 | #include <mutex> |
11 | 11 |
|
@@ -37,25 +37,36 @@ void safely_launch_kernel_with_smem_size(KernelT const& kernel, |
37 | 37 | { |
38 | 38 | // last_smem_size is a monotonically growing high-water mark across all kernel pointers. |
39 | 39 | // last_kernel tracks which kernel pointer was last used. |
40 | | - static uint32_t last_smem_size{0}; |
41 | | - static KernelT last_kernel{KernelT{}}; |
| 40 | + static std::atomic<uint32_t> last_smem_size{0}; |
| 41 | + static std::atomic<KernelT> last_kernel{KernelT{}}; |
42 | 42 | static std::mutex mutex; |
43 | | - |
| 43 | + bool updated_needed = false; |
| 44 | + // When the kernel function pointer changes, bring the new kernel up to the global high-water |
| 45 | + // mark. This is necessary because cudaFuncSetAttribute applies to a specific function pointer, |
| 46 | + // not to the pointer type — different template instantiations may share the same KernelT. |
| 47 | + if (kernel != last_kernel.load(std::memory_order_relaxed)) |
| 48 | + { |
| 49 | + last_kernel.store(kernel, std::memory_order_relaxed); |
| 50 | + updated_needed = true; |
| 51 | + } |
| 52 | + // Since we read the kernel pointer first, and the smem_size can only grow, |
| 53 | + // reading an inconsistent state is safe: at worst we will use a larger smem_size. |
| 54 | + uint32_t cur_smem_size = last_smem_size.load(std::memory_order_relaxed); |
| 55 | + if (smem_size > cur_smem_size) |
| 56 | + { |
| 57 | + last_smem_size.store(smem_size, std::memory_order_relaxed); |
| 58 | + cur_smem_size = smem_size; |
| 59 | + updated_needed = true; |
| 60 | + } |
| 61 | + // Mutex-protected cudaFuncSetAttribute |
| 62 | + if (updated_needed) |
44 | 63 | { |
45 | 64 | std::lock_guard<std::mutex> guard(mutex); |
46 | | - |
47 | | - // When the kernel function pointer changes, bring the new kernel up to the global high-water |
48 | | - // mark. This is necessary because cudaFuncSetAttribute applies to a specific function pointer, |
49 | | - // not to the pointer type — different template instantiations may share the same KernelT. |
50 | | - last_kernel = kernel != last_kernel ? kernel : last_kernel; |
51 | | - // When smem_size exceeds the high-water mark, grow it for the current kernel. |
52 | | - // If the kernel also changed above, this handles the case where smem_size > last_smem_size. |
53 | | - last_smem_size = smem_size > last_smem_size ? smem_size : last_smem_size; |
54 | 65 | auto launch_status = |
55 | | - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, last_smem_size); |
| 66 | + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, cur_smem_size); |
56 | 67 | RAFT_EXPECTS(launch_status == cudaSuccess, |
57 | 68 | "Failed to set max dynamic shared memory size to %u bytes", |
58 | | - last_smem_size); |
| 69 | + cur_smem_size); |
59 | 70 | } |
60 | 71 | // The kernel launch is outside the lock: any concurrent cudaFuncSetAttribute can only increase |
61 | 72 | // the limit, so the launch is always safe. |
|
0 commit comments