@@ -35,39 +35,26 @@ void safely_launch_kernel_with_smem_size(KernelT const& kernel,
3535 uint32_t smem_size,
3636 KernelLauncherT const & launch)
3737{
38- // current_smem_size is a monotonically growing high-water mark across all kernel pointers.
39- // current_kernel tracks which kernel pointer was last used.
40- static uint32_t current_smem_size {0 };
41- static KernelT current_kernel {KernelT{}};
38+ // last_smem_size is a monotonically growing high-water mark across all kernel pointers.
39+ // last_kernel tracks which kernel pointer was last used.
40+ static uint32_t last_smem_size {0 };
41+ static KernelT last_kernel {KernelT{}};
4242 static std::mutex mutex;
4343
4444 {
4545 std::lock_guard<std::mutex> guard (mutex);
4646
47- auto last_kernel = current_kernel;
48- auto last_smem_size = current_smem_size;
49-
5047 // When the kernel function pointer changes, bring the new kernel up to the global high-water
5148 // mark. This is necessary because cudaFuncSetAttribute applies to a specific function pointer,
5249 // not to the pointer type — different template instantiations may share the same KernelT.
53- if (kernel != last_kernel) {
54- current_kernel = kernel;
55- auto launch_status =
56- cudaFuncSetAttribute (kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, last_smem_size);
57- RAFT_EXPECTS (launch_status == cudaSuccess,
58- " Failed to set max dynamic shared memory size to %u bytes" ,
59- last_smem_size);
60- }
50+ last_kernel = kernel != last_kernel ? kernel : last_kernel;
6151 // When smem_size exceeds the high-water mark, grow it for the current kernel.
6252 // If the kernel also changed above, this handles the case where smem_size > last_smem_size.
63- if (smem_size > last_smem_size) {
64- current_smem_size = smem_size;
65- auto launch_status =
66- cudaFuncSetAttribute (kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
67- RAFT_EXPECTS (launch_status == cudaSuccess,
68- " Failed to set max dynamic shared memory size to %u bytes" ,
69- smem_size);
70- }
53+ last_smem_size = smem_size > last_smem_size ? smem_size : last_smem_size;
54+ auto launch_status = cudaFuncSetAttribute (kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, last_smem_size);
55+ RAFT_EXPECTS (launch_status == cudaSuccess,
56+ " Failed to set max dynamic shared memory size to %u bytes" ,
57+ last_smem_size);
7158 }
7259 // The kernel launch is outside the lock: any concurrent cudaFuncSetAttribute can only increase
7360 // the limit, so the launch is always safe.
0 commit comments