Description
Checklist
- 1. I have searched for related issues and FAQs (https://nunchaku.tech/docs/nunchaku/faq/faq.html) but was unable to find a solution.
- 2. The issue persists in the latest version.
- 3. Please note that without environment information and a minimal reproducible example, it will be difficult for us to reproduce and address the issue, which may delay our response.
- 4. If your report is a question rather than a bug, please submit it as a discussion at https://github.com/mit-han-lab/nunchaku/discussions/new/choose. Otherwise, this issue will be closed.
- 5. If this is related to ComfyUI, please report it at https://github.com/mit-han-lab/ComfyUI-nunchaku/issues.
- 6. I will do my best to describe the issue in English.
Describe the Bug
We are encountering a hard crash (CUDA error: an illegal memory access was encountered) when running the svdq_quantize_w4a4_act_fuse_lora_cuda kernel on NVIDIA Blackwell (SM120) architecture.
This occurs even with perfectly contiguous inputs and valid shapes. The crash is asynchronous: the kernel launch appears to succeed, but the CUDA context is corrupted immediately after, causing the next CUDA synchronization or kernel launch (e.g., torch.randn or torch._scaled_mm) to fail with "Illegal Memory Access" or "CUBLAS_STATUS_NOT_SUPPORTED".
We have confirmed this behavior on an RTX 6000 Blackwell Workstation Edition (Compute Capability 12.0). Independent testing on RTX 5090 (Blackwell) has reproduced the same issue (ref: ComfyUI-nunchaku#476).
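Because the error is asynchronous, the stack trace points at an unrelated later call. A minimal sketch of how we localize it (the `checked` helper is a hypothetical name for illustration; `CUDA_LAUNCH_BLOCKING=1` must be set before the CUDA context is created):

```python
import os

# Force every kernel launch to synchronize, so the illegal access is
# reported at the offending launch rather than at a later, unrelated
# call such as torch.randn. Must be set before torch touches the GPU.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # noqa: E402


def checked(fn, *args, **kwargs):
    """Run fn, then synchronize immediately to surface async CUDA errors."""
    out = fn(*args, **kwargs)
    if torch.cuda.is_available():
        # Any pending async error from fn's kernels is raised here.
        torch.cuda.synchronize()
    return out
```

Running the reproduction under NVIDIA's `compute-sanitizer` gives the same localization with per-thread access details.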
Environment
- GPU: RTX 6000 Blackwell (Compute Capability 12.0) / RTX 5090
- CUDA: 12.8
- PyTorch: 2.12.0.dev20260217+cu128
- Nunchaku: 1.3.0.dev20260302+cu12.8torch2.12 (installed via `pip install --no-build-isolation -v -e .` to ensure ABI compatibility with PyTorch 2.12)
- OS: Linux 6.14.0-37-generic
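The environment fields above were collected with a helper along these lines (`report_env` is a hypothetical name; the `nunchaku` import is guarded because the package may be absent):

```python
import torch


def report_env():
    """Gather the version/capability fields requested in bug reports."""
    info = {
        "torch": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
    }
    if torch.cuda.is_available():
        info["device"] = torch.cuda.get_device_name(0)
        info["capability"] = torch.cuda.get_device_capability(0)
    try:
        import nunchaku
        info["nunchaku"] = getattr(nunchaku, "__version__", "unknown")
    except ImportError:
        info["nunchaku"] = None
    return info
```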
Reproduction Steps
The following minimal script reproduces the crash instantly on SM120. It initializes Nunchaku and calls the quantization kernel with dummy data.
```python
import torch


def test_nunchaku_crash():
    print(f"PyTorch Version: {torch.__version__}")
    if not torch.cuda.is_available():
        print("CUDA not available.")
        return
    device = torch.device("cuda")

    # Check capability
    cap = torch.cuda.get_device_capability(device)
    print(f"Device: {torch.cuda.get_device_name(device)} | Capability: {cap}")

    try:
        from nunchaku.ops.quantize import svdq_quantize_w4a4_act_fuse_lora_cuda
        print("Nunchaku loaded successfully.")
    except ImportError as e:
        print(f"Nunchaku import failed: {e}")
        return

    # Standard shapes for Flux.2 LoRA training
    M = 16384
    K = 3072
    rank = 16
    pad = 16

    # Create valid dummy inputs (BF16/FP32).
    # Ensure strict contiguity to rule out layout issues.
    x_flat = torch.randn((M, K), device=device, dtype=torch.bfloat16).contiguous()
    lora_down = torch.randn((K, rank), device=device, dtype=torch.bfloat16).contiguous()
    lora_act_out = torch.empty((M, rank), device=device, dtype=torch.float32).contiguous()

    print("Launching svdq_quantize_w4a4_act_fuse_lora_cuda kernel...")
    try:
        # The kernel launch itself usually succeeds (async)
        qact, act_scale, _ = svdq_quantize_w4a4_act_fuse_lora_cuda(
            x_flat,
            lora_down=lora_down,
            lora_act_out=lora_act_out,
            fp4=True,
            pad_size=pad,
        )
        print("Kernel launch returned.")
        # Force synchronization to catch the async error
        torch.cuda.synchronize()
        print("Synchronization success (no crash yet).")
        # Attempt to access the results
        print(f"Result qact shape: {qact.shape}")
    except RuntimeError as e:
        print(f"\nCRASH DETECTED:\n{e}")
        print("\nNote: 'Illegal memory access' indicates the kernel corrupted the context.")


if __name__ == "__main__":
    test_nunchaku_crash()
```
Traceback / Output
PyTorch Version: 2.12.0.dev20260217+cu128
Device: RTX 6000 Blackwell | Capability: (12, 0)
Nunchaku loaded successfully.
Launching svdq_quantize_w4a4_act_fuse_lora_cuda kernel...
Kernel launch returned.
CRASH DETECTED:
CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call...
Analysis
We have already verified that the issue is not due to:
- PyTorch version mismatch: we rebuilt Nunchaku locally against PyTorch 2.12 using `--no-build-isolation`.
- Tensor layouts: we verified that `x_flat` and `lora_down` are contiguous.
- Hardware support: `torch._scaled_mm` works perfectly on this GPU when we bypass Nunchaku and use dummy quantized inputs with the `Scale A (Native)` + `Scale B (Transposed)` layout.

The issue appears to be specific to the implementation of `gemm_w4a4_launch_impl.cuh` or the `svdq_quantize` kernels when running on SM120 hardware, possibly related to warp-scheduling or shared-memory access-pattern changes in Blackwell.
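For reference, the hardware-support check mentioned above can be sketched as follows. This is a minimal per-tensor-scale smoke test, not the exact scale layout we used; shapes, dtypes, and the `scaled_mm_smoke_test` name are illustrative, and the `torch._scaled_mm` keyword signature assumed here is that of recent PyTorch releases:

```python
import torch


def scaled_mm_smoke_test():
    """Sanity-check torch._scaled_mm on this GPU, bypassing Nunchaku.

    Returns the output shape on success, or None when no GPU is present.
    """
    if not torch.cuda.is_available():
        return None
    # FP8 operands; K (=64) must be a multiple of 16 for _scaled_mm.
    a = torch.randn(128, 64, device="cuda").to(torch.float8_e4m3fn)
    # .t() of a contiguous tensor yields the column-major layout
    # _scaled_mm expects for the second operand.
    b = torch.randn(32, 64, device="cuda").to(torch.float8_e4m3fn).t()
    scale_a = torch.tensor(1.0, device="cuda")  # per-tensor scales
    scale_b = torch.tensor(1.0, device="cuda")
    out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                           out_dtype=torch.bfloat16)
    torch.cuda.synchronize()  # would raise here if the context were corrupted
    return out.shape
```

On our SM120 machines this path completes without error, which is why we suspect the Nunchaku kernels rather than the hardware or driver.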