fix: add __launch_bounds__ to winograd kernels for Blackwell GPUs#2391
Open
stondo wants to merge 1 commit intoLeelaChessZero:masterfrom
Open
fix: add __launch_bounds__ to winograd kernels for Blackwell GPUs#2391stondo wants to merge 1 commit intoLeelaChessZero:masterfrom
stondo wants to merge 1 commit intoLeelaChessZero:masterfrom
Conversation
InputTransform_kernel and OutputTransform_kernel exceed Blackwell sm_121 per-block resource limits without explicit launch bounds, causing "too many resources requested for launch" at runtime. Adding __launch_bounds__(1024) constrains register allocation and fixes the crash. No impact on pre-Blackwell architectures. Same class of fix as PyTorch #150266 for Blackwell compatibility. Tested on NVIDIA GB10 (sm_121) with T78 (2,466 NPS) and BT4-1740 (2,583 NPS) networks, cuda-fp16 backend.
Menkib64
requested changes
Feb 28, 2026
| // - producing 4 x 6x6 elements | ||
| template <typename T, bool nhcw> | ||
| __global__ void InputTransform_kernel(int N, int C, const T* input, T* output) { | ||
| __global__ __launch_bounds__(1024) void InputTransform_kernel(int N, int C, const T* input, T* output) { |
Contributor
There was a problem hiding this comment.
1024 is too high limit. It will limit kernel to 64 registers. T78 uses only 512 channels which uses block size 512. Kernels are a little faster when using 128 registers.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
__launch_bounds__(1024)toInputTransform_kernelandOutputTransform_kernelinwinograd_helper.incDetails
Blackwell (sm_120/sm_121) has different per-SM resource limits compared to previous architectures. Without explicit
__launch_bounds__, the CUDA compiler over-allocates registers for the Winograd transform kernels, causing them to exceed per-block resource limits at launch time.This is the same class of issue documented in pytorch/pytorch#150266.
Adding
__launch_bounds__(1024)constrains the compiler's register allocation without affecting kernel behavior. The value 1024 matchescudaDeviceProp::maxThreadsPerBlockand is the standard approach for this category of Blackwell compatibility fix.Without this fix, lc0 crashes immediately on any Blackwell GPU (RTX 5090, RTX 5080, RTX 5070 Ti, GB10, GB200, etc.).
Benchmarks (NVIDIA GB10, sm_121, CUDA 13.0)
Test plan