
[Bug] Critical SM120 (Blackwell) Incompatibility: "Illegal Memory Access" in svdq_quantize_w4a4_act_fuse_lora_cuda #911

@jayhsu0627

Describe the Bug

We are encountering a hard crash (CUDA error: an illegal memory access was encountered) when running the svdq_quantize_w4a4_act_fuse_lora_cuda kernel on NVIDIA Blackwell (SM120) architecture.
This occurs even with perfectly contiguous inputs and valid shapes. The crash is asynchronous: the kernel launch appears to succeed, but the CUDA context is corrupted immediately after, causing the next CUDA synchronization or kernel launch (e.g., torch.randn or torch._scaled_mm) to fail with "Illegal Memory Access" or "CUBLAS_STATUS_NOT_SUPPORTED".

We have confirmed this behavior on an RTX 6000 Blackwell Workstation Edition (Compute Capability 12.0). Independent testing on RTX 5090 (Blackwell) has reproduced the same issue (ref: ComfyUI-nunchaku#476).
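Because the error only surfaces at the next synchronization point, forcing synchronous kernel launches or running under NVIDIA's Compute Sanitizer helps localize the faulting kernel. A minimal debugging workflow (`repro.py` is a placeholder for the reproduction script's filename):

```shell
# Force synchronous kernel launches so the failing launch raises immediately,
# instead of corrupting the context and erroring at a later CUDA call.
CUDA_LAUNCH_BLOCKING=1 python repro.py

# Report the exact kernel and address of the out-of-bounds access.
compute-sanitizer --tool memcheck python repro.py
```

With `CUDA_LAUNCH_BLOCKING=1`, the traceback points at the actual launching call rather than an unrelated downstream op such as `torch.randn`.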

Environment

  • GPU: RTX 6000 Blackwell (Compute Capability 12.0) / RTX 5090
  • CUDA: 12.8
  • PyTorch: 2.12.0.dev20260217+cu128
  • Nunchaku: 1.3.0.dev20260302+cu12.8torch2.12 (Installed via pip install --no-build-isolation -v -e . to ensure ABI compatibility with PyTorch 2.12)
  • OS: Linux 6.14.0-37-generic

Reproduction Steps

The following minimal script reproduces the crash instantly on SM120. It initializes Nunchaku and calls the quantization kernel with dummy data.

import torch

def test_nunchaku_crash():
    print(f"PyTorch Version: {torch.__version__}")
    if not torch.cuda.is_available():
        print("CUDA not available.")
        return

    device = torch.device("cuda")
    
    # Check capability
    cap = torch.cuda.get_device_capability(device)
    print(f"Device: {torch.cuda.get_device_name(device)} | Capability: {cap}")
    
    try:
        from nunchaku.ops.quantize import svdq_quantize_w4a4_act_fuse_lora_cuda
        print("Nunchaku loaded successfully.")
    except ImportError as e:
        print(f"Nunchaku import failed: {e}")
        return

    # Standard Shapes for Flux.2 LoRA Training
    M = 16384
    K = 3072
    rank = 16
    pad = 16
    
    # Create valid dummy inputs (BF16/FP32)
    # Ensure strict contiguity to rule out layout issues
    x_flat = torch.randn((M, K), device=device, dtype=torch.bfloat16).contiguous()
    lora_down = torch.randn((K, rank), device=device, dtype=torch.bfloat16).contiguous()
    lora_act_out = torch.empty((M, rank), device=device, dtype=torch.float32).contiguous()
    
    print("Launching svdq_quantize_w4a4_act_fuse_lora_cuda kernel...")
    
    try:
        # The kernel launch itself usually succeeds (async)
        qact, act_scale, _ = svdq_quantize_w4a4_act_fuse_lora_cuda(
            x_flat,
            lora_down=lora_down,
            lora_act_out=lora_act_out,
            fp4=True,
            pad_size=pad,
        )
        print("Kernel launch returned.")
        
        # Force synchronization to catch the async error
        torch.cuda.synchronize()
        print("Synchronization success (No crash yet).")
        
        # Attempt to access memory results
        print(f"Result qact shape: {qact.shape}")
        
    except RuntimeError as e:
        print(f"\nCRASH DETECTED:\n{e}")
        print("\nNote: 'Illegal memory access' indicates the kernel corrupted the context.")

if __name__ == "__main__":
    test_nunchaku_crash()

Traceback / Output

PyTorch Version: 2.12.0.dev20260217+cu128
Device: RTX 6000 Blackwell | Capability: (12, 0)
Nunchaku loaded successfully.
Launching svdq_quantize_w4a4_act_fuse_lora_cuda kernel...
Kernel launch returned.

CRASH DETECTED:
CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call...

Analysis
We have already verified that the issue is not due to:

  1. PyTorch Version Mismatch: We rebuilt Nunchaku locally against PyTorch 2.12 using --no-build-isolation.
  2. Tensor Layouts: We verified that x_flat and lora_down are contiguous.
  3. Hardware Support: torch._scaled_mm works correctly on this GPU when we bypass Nunchaku and feed it dummy quantized inputs with a Scale A (native) + Scale B (transposed) layout.

The issue therefore appears to be specific to the implementation of the gemm_w4a4_launch_impl.cuh or svdq_quantize kernels on SM120 hardware, possibly related to warp-scheduling or shared-memory access-pattern changes in Blackwell.
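For reference, a minimal sketch of the kind of bypass check described in item 3. This is not the exact script we used: it assumes simple tensor-wise FP8 scales and hypothetical shapes, and guards on CUDA availability so it degrades gracefully on CPU-only machines.

```python
import torch

def scaled_mm_sanity_check(m=64, k=128, n=64):
    """Check that torch._scaled_mm itself runs on this GPU with dummy
    FP8 inputs, independent of Nunchaku's kernels (hypothetical shapes)."""
    if not torch.cuda.is_available():
        return None  # skip on machines without a GPU
    dev = torch.device("cuda")
    # _scaled_mm expects a row-major A and a column-major B.
    a = torch.randn(m, k, device=dev).to(torch.float8_e4m3fn)
    b = torch.randn(n, k, device=dev).to(torch.float8_e4m3fn).t()
    scale_a = torch.tensor(1.0, device=dev)
    scale_b = torch.tensor(1.0, device=dev)
    out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                           out_dtype=torch.bfloat16)
    torch.cuda.synchronize()  # surface any async error here, not later
    return tuple(out.shape)
```

If this check passes while the Nunchaku repro crashes, the hardware and cuBLAS paths are fine and the fault is isolated to the Nunchaku kernels.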
