Commit b188f21
perf(TDDFT): Add CUDA acceleration for snap_psibeta_half function (substantially speeds up snap_psibeta_half) (#6808)
* feat(tddft): Add CUDA acceleration for snap_psibeta with constant memory grids
- Implement GPU-accelerated snap_psibeta_neighbor_batch_kernel
- Use constant memory for Lebedev and Gauss-Legendre integration grids
- Add multi-GPU support via set_device_by_rank
- Initialize/finalize GPU resources in each calculate_HR call
- Remove static global variables for cleaner resource management
- CPU fallback when GPU processing fails
* refactor(snap_psibeta_atom_batch_gpu): add timer and simplify code structure
- Add ModuleBase::timer for snap_psibeta_atom_batch_gpu function
- Remove GPU fallback to CPU design (return true/false in void function)
- Replace fallback returns with error messages and proper early exits
- Ensure timer is properly called on all exit paths
- Simplify code structure for better readability
* remove snap_psibeta_neighbor_batch_gpu
* perf(gpu): optimize snap_psibeta_atom_batch_kernel with loop restructuring
- Move ylm0 computation outside radial loop (saves 140x redundant calculations)
- Hoist A_dot_leb and dR calculations outside inner loop
- Add #pragma unroll hints for radial and m0 loops
Achieves 23.3% speedup on snap_psibeta_gpu (19.27s -> 14.78s).
Numerical correctness verified: energy matches baseline (-756.053 Ry).
* perf(gpu): optimize compute_ylm_gpu with atan2 and sincos
- Replace conditional atan branches with single atan2 call
- Use sincos() instead of separate sin/cos calls
Achieves about 8.3% additional speedup (14.78s -> 13.56s)
Combined with loop restructuring: 29.6% total from baseline
Numerical correctness verified: -756.053 Ry
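The atan-to-atan2 change can be illustrated with a minimal host-side sketch. It is not the code from this commit: the function names and the (-pi, pi] angle convention are assumptions made for illustration only. On the device, the CUDA math function sincos(x, &s, &c) returns both values in one call; the sketch uses std::sin and std::cos because plain C++ has no standard sincos.

```cpp
#include <cmath>

// Hypothetical "before" version: quadrant handling with explicit branches.
double azimuth_branchy_sketch(double x, double y) {
    const double pi = std::acos(-1.0);
    if (x > 0.0) return std::atan(y / x);
    if (x < 0.0) return (y >= 0.0) ? std::atan(y / x) + pi : std::atan(y / x) - pi;
    if (y > 0.0) return pi / 2.0;
    if (y < 0.0) return -pi / 2.0;
    return 0.0;  // x == 0 and y == 0
}

// Hypothetical "after" version: one atan2 call covers all quadrants.
double azimuth_atan2_sketch(double x, double y) {
    return std::atan2(y, x);
}

struct SinCosSketch {
    double s;
    double c;
};

// Host stand-in for the device sincos() call: both values from one helper.
SinCosSketch sincos_sketch(double phi) {
    return SinCosSketch{std::sin(phi), std::cos(phi)};
}
```

Both azimuth functions agree in every quadrant and on the axes, so the branch-free form can replace the branchy one without changing results under this convention.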
* make the code more concise
* perf(cuda): optimize compute_ylm_gpu with template parameters
- Convert compute_ylm_gpu to templated version with L as template param
- Use linear array for Legendre polynomials (reduces from 25 to 15 doubles)
- Add DISPATCH_YLM macro for runtime-to-template dispatch
- Add MAX_M0_SIZE constant for result array sizing
- Replace C++17 constexpr if with regular if for C++14 compatibility
- Enable compiler loop unrolling with #pragma unroll
Performance: snap_psibeta_gpu improved from 13.27s to 9.83s (1.35x speedup)
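The runtime-to-template dispatch idea can be sketched in host C++ as follows. This is an illustration under stated assumptions, not the DISPATCH_YLM macro from the commit: the names legendre_sketch, DISPATCH_LEGENDRE_SKETCH, and MAX_L_SKETCH are hypothetical, the maximum L of 4 is inferred from the 25-to-15 figure, and a plain Legendre recurrence stands in for the real spherical-harmonic computation.

```cpp
#include <cmath>

// Assumed maximum angular momentum (hypothetical constant).
constexpr int MAX_L_SKETCH = 4;

// A triangular (l, m <= l) table needs (L+1)(L+2)/2 entries instead of (L+1)^2.
constexpr int triangular_size(int lmax) { return (lmax + 1) * (lmax + 2) / 2; }
static_assert(triangular_size(MAX_L_SKETCH) == 15, "linear storage: 15 doubles");
static_assert((MAX_L_SKETCH + 1) * (MAX_L_SKETCH + 1) == 25, "square storage: 25 doubles");

// L is a template parameter, so the loop bound is known at compile time
// and the compiler can fully unroll it.
template <int L>
double legendre_sketch(double x) {
    // Regular `if` on a compile-time constant (C++14-friendly, no `if constexpr`).
    if (L == 0) return 1.0;
    double p_prev = 1.0;  // P_0(x)
    double p_curr = x;    // P_1(x)
    for (int l = 2; l <= L; ++l) {
        const double p_next = ((2 * l - 1) * x * p_curr - (l - 1) * p_prev) / l;
        p_prev = p_curr;
        p_curr = p_next;
    }
    return p_curr;
}

// Maps a runtime L onto the matching template instantiation.
#define DISPATCH_LEGENDRE_SKETCH(l_runtime, result, x)               \
    do {                                                             \
        switch (l_runtime) {                                         \
            case 0: (result) = legendre_sketch<0>(x); break;         \
            case 1: (result) = legendre_sketch<1>(x); break;         \
            case 2: (result) = legendre_sketch<2>(x); break;         \
            case 3: (result) = legendre_sketch<3>(x); break;         \
            case 4: (result) = legendre_sketch<4>(x); break;         \
            default: (result) = std::nan(""); break;                 \
        }                                                            \
    } while (0)

// Runtime entry point; returns NaN for an unsupported L.
double legendre_runtime_sketch(int l, double x) {
    double result = 0.0;
    DISPATCH_LEGENDRE_SKETCH(l, result, x);
    return result;
}
```

The switch costs one branch per call, while each instantiated body sees L as a constant, which is what makes the unrolling hints effective.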
* perf(cuda): use warp shuffle for reduction in snap_psibeta kernel
- Replace shared memory tree reduction with warp shuffle reduction
- Use warp_reduce_sum for intra-warp reduction (faster shuffle ops)
- Reduce shared memory from BLOCK_SIZE (2KB) to NUM_WARPS (64 bytes)
- Cross-warp reduction done by first warp reading from shared memory
Reduces register usage from 94 to 88, shared memory from 2KB to 64 bytes.
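The data flow of a shuffle-based warp reduction can be emulated on the host, as in the sketch below. It is not the kernel code: on the device each step would be a __shfl_down_sync call rather than an array read, and the block size of 256 threads with double-precision partial sums is an assumption chosen to match the 2KB and 64-byte figures above. All names are hypothetical.

```cpp
#include <array>
#include <cstddef>

constexpr int WARP_SIZE_SKETCH = 32;
constexpr int BLOCK_SIZE_SKETCH = 256;  // assumed block size
constexpr int NUM_WARPS_SKETCH = BLOCK_SIZE_SKETCH / WARP_SIZE_SKETCH;

// Shared-memory footprint: one double per thread vs. one double per warp.
static_assert(BLOCK_SIZE_SKETCH * sizeof(double) == 2048, "tree reduction: 2KB");
static_assert(NUM_WARPS_SKETCH * sizeof(double) == 64, "shuffle reduction: 64 bytes");

// Host emulation of an intra-warp sum. Each step halves the offset; every
// lane adds the value held by the lane `offset` positions above it.
double warp_reduce_sum_emulated(std::array<double, WARP_SIZE_SKETCH> lanes) {
    for (int offset = WARP_SIZE_SKETCH / 2; offset > 0; offset /= 2) {
        // Snapshot models all lanes reading simultaneously before any write.
        const std::array<double, WARP_SIZE_SKETCH> snapshot = lanes;
        for (int lane = 0; lane < WARP_SIZE_SKETCH; ++lane) {
            const int src = lane + offset;
            // An out-of-range source returns the lane's own value, as on device.
            lanes[static_cast<std::size_t>(lane)] +=
                (src < WARP_SIZE_SKETCH) ? snapshot[static_cast<std::size_t>(src)]
                                         : snapshot[static_cast<std::size_t>(lane)];
        }
    }
    return lanes[0];  // only lane 0 holds the full sum
}
```

After five steps lane 0 holds the sum of all 32 lanes, so only one value per warp needs to be written to shared memory for the cross-warp stage.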
* refactor(cuda): improve snap_psibeta_kernel code organization and documentation
- Add comprehensive file headers explaining purpose and key features
- Organize code into logical sections with clear separators
- Add doxygen-style documentation for all functions, structs, and constants
- Fix inaccurate comments (BLOCK_SIZE requirement, direction vector normalization)
- Remove unused variables (dR, distance01)
- Remove finalize_gpu_resources() as it's not needed for constant memory
- Improve inline comments explaining algorithms and optimizations
* refactor(td_nonlocal_lcao): use runtime check for GPU/CPU branch selection
- Add use_gpu runtime flag that checks both __CUDA macro and PARAM.inp.device
- GPU path is now only enabled when __CUDA is defined AND device == "gpu"
- Makes the conditional logic clearer with if/else instead of nested #ifdef
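A minimal sketch of this kind of runtime check is shown below. The function names are hypothetical, and the device string is passed in as a parameter rather than read from PARAM.inp.device, so the example stays self-contained.

```cpp
#include <string>

// True only when the build has CUDA support AND the user asked for the GPU.
bool use_gpu_sketch(const std::string& device) {
#ifdef __CUDA
    return device == "gpu";
#else
    (void)device;  // CPU-only build: the requested device cannot enable the GPU path
    return false;
#endif
}

// A single if/else at the call site replaces nested #ifdef blocks.
const char* select_branch_sketch(const std::string& device) {
    if (use_gpu_sketch(device)) {
        return "gpu";
    } else {
        return "cpu";
    }
}
```

In a build without __CUDA defined, both "gpu" and "cpu" inputs select the CPU branch, so the preprocessor check stays confined to one small helper.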
* refactor(snap_psibeta): unify CUDA error checking macros
- Move CUDA_CHECK macro to shared header snap_psibeta_kernel.cuh
- Remove duplicate CUDA_CHECK definition from snap_psibeta_gpu.cu
- Remove CUDA_CHECK_KERNEL macro and replace all usages with CUDA_CHECK
- Reduces code duplication and improves consistency
* refactor(snap_psibeta_kernel): use ModuleBase constants
- Replace local PI, FOUR_PI, SQRT2 definitions with ModuleBase:: versions
- Add include for source_base/constants.h
* refactor(snap_psibeta): use ModuleBase::WARNING_QUIT for error handling
- Replace fprintf(stderr, ...) with ModuleBase::WARNING_QUIT
- Update CUDA_CHECK macro to use WARNING_QUIT instead of fprintf
- Add includes for tool_quit.h and string header
- Consistent error handling with ABACUS codebase conventions
1 parent 29f141d commit b188f21
File tree
6 files changed, +1430 -33 lines changed
- source/source_lcao
- module_operator_lcao
- module_rt
- kernels
- cuda
Lines changed: 75 additions & 33 deletions