feat: Add ROCm backend with attention network support#2375
feat: Add ROCm backend with attention network support#2375johnnytshi wants to merge 8 commits intoLeelaChessZero:masterfrom
Conversation
Implemented a complete ROCm backend for AMD GPUs, enabling support for
modern attention-based chess networks on RDNA 3.5 and other AMD architectures.
Implementation Details:
- Added full ROCm backend in src/neural/backends/rocm/
- Implemented attention network architecture (multi-head self-attention, FFN, embeddings)
- Used rocBLAS for GEMM operations and MIOpen for convolutions
- NCHW layout optimized for FP16 performance on RDNA 3.5
- Three backend variants: rocm (FP32), rocm-fp16 (FP16), rocm-auto (auto-detect)
- MIOpen is required dependency (similar to cuDNN for CUDA)
- Automatic AMD GPU architecture detection via rocm_agent_enumerator
- Build option: -Drocm=true -Damd_gfx=gfx1151 (or auto-detect)
Key Files:
- src/neural/backends/rocm/network_rocm.cc - Main network implementation
- src/neural/backends/rocm/layers.{cc,h} - Layer implementations
- src/neural/backends/rocm/*.hip - GPU kernels (FP16 and FP32)
- meson.build, meson_options.txt - Build configuration
Performance Notes:
- FP16 performance: >2000 nps on Strix Halo (Radeon 8060S, gfx1151)
- Automatic batch size tuning (min_batch=64 for RDNA 3.5)
- Tested rocWMMA but rocBLAS provided better performance
OpenCL/SYCL Compatibility:
- Preserved existing OpenCL/SYCL AMD backend (uses hip_* naming)
- ROCm backend separate from SYCL backend (uses rocm_* naming)
Verification (Strix Halo - Radeon 8060S, gfx1151):
- Tested models: 768x15x24h-t82-swa-7464000.pb.gz and maia-1900.pb.gz
- Backend: rocm-fp16 functional and producing correct moves
- ROCm 7.2.53150, MIOpen 3.5.1
- Only tested on RDNA 3.5; other AMD architectures not verified
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
|
There is a lot to discuss here, can you join our Discord chat? |
sure, which channel should i be using? |
|
some background, I also tried adding libTorch backend, but that's about 5-10 times slower than raw ROCm, on the same hardware |
Menkib64
left a comment
There was a problem hiding this comment.
#2365 has an automated conversion from a newer version of cuda backend. I stopped attempting to fix bugs which I found because onnx-migraphx is much faster. My guess is that most kernels would have to be tuned specifically for AMD.
I discovered a few bugs which I documented in the pull request message.
Automatic batch size tuning (min_batch=64 for RDNA 3.5)
This sounds like you might be optimizing for a wrong target. Optimization target is a combination of batch latency and throughput. Search cannot always produce optimally sized batches which requires backend to evaluate suboptimal batches at the lowest possible latency. Backend or user should tell the optimal batch for the search which should be the smallest batch size which reaches the best or close to the best throughput. Running backendbench with --clippy can be used to estimate the best trade of between latency and throughput.
Tested rocWMMA but rocBLAS provided better performance
I haven't used rocWMMA. My assumption is that either rocWMMA, composable kernels or MigraphX offers the best performance potential. User code has to implement the same optimizations which are included in rocBLAS. It can be hard to optimize GEMM correctly because small changes can prevent HW from doing memory operations at the same time as computation. These libraries should allow user to fuse other work with matrix multiplications. Fused kernels avoid overhead from small kernels which allow higher network evaluation performance. The first step would be reaching rocBLAS performance with custom matrix multiplication kernels.
I mostly skipped GPU code. There are too many changes to quickly check everything.
| 'src/neural/backends/rocm/layers.cc', | ||
| 'src/neural/backends/rocm/network_rocm.cc'] | ||
|
|
||
| add_project_arguments('-DUSE_MIOPEN', language : 'cpp') |
There was a problem hiding this comment.
USE_MIOPEN seem redundant. The code can be enabled always.
| includes += include_directories(d, is_system: true) | ||
| endif | ||
| endforeach | ||
| includes += include_directories('src/neural/backends/rocm/') |
There was a problem hiding this comment.
Backend should work without adding an include path.
| endforeach | ||
| includes += include_directories('src/neural/backends/rocm/') | ||
|
|
||
| add_project_arguments('-DUSE_HIP', '-D__HIP_PLATFORM_AMD__', language : 'cpp') |
There was a problem hiding this comment.
USE_HIP looks redundant.
__HIP_PLATFORM_AMD__ can be set using dependency('hip')
| constexpr bool fp16 = std::is_same<half, DataType>::value; | ||
| if (fp16) { | ||
| // FP16: optimal batch 49-73, use 64 as default | ||
| min_batch_size_ = 64; |
There was a problem hiding this comment.
Minimum batch size of 64 doesn't sound optimal when search wants to evaluate only a position. Latency of evaluation matters too.
If kernels would be optimized well for AMD, then I would expect that optimal batch size would be multiple of N * compute_units/2. N would depend on the size of network in use. For BT4 it should be 1 and smaller networks would require higher N.
| miopenConvolutionDescriptor_t convDesc; | ||
| miopenTensorDescriptor_t xDesc; | ||
| // MIOpen uses miopenCreateTensorDescriptor for both filters and tensors | ||
| miopenCreateTensorDescriptor(&wDesc); |
There was a problem hiding this comment.
miopen initialization and allocations should happen only when network needs miopen. The important networks use transformers only currently.
| } | ||
|
|
||
| // Defer policy copy until after all GPU work is queued (Optimization #3) | ||
| ReportHIPErrors(hipMemcpyAsync(io->op_policy_mem_, io->op_policy_mem_gpu_, |
There was a problem hiding this comment.
There should be a separate stream for output downloads. A separate stream allows download happen at the same time while other computations are happening.
The latest version of cuda backend has an example how to do it for both uploads and downloads.
| ReportHIPErrors(hipEventRecord(policy_ready_event_, stream)); | ||
|
|
||
| // Wait ONLY for value (needed for CPU softmax) | ||
| ReportHIPErrors(hipEventSynchronize(value_ready_event_)); |
There was a problem hiding this comment.
hipEventSynchronize should be outside locked section. It will be blocking until GPU has completed operations. We want to queue the next batch before the current batch completes the computation.
| constexpr bool fp16 = std::is_same<half, DataType>::value; | ||
|
|
||
| if (fp16) { | ||
| // Check if the GPU support FP16. |
There was a problem hiding this comment.
Documentation claims fp16 support for all hardware. This makes me thing that support check can be removed.
There was a problem hiding this comment.
This precision format is supported across all GPU architectures. The HIP types and functions are available for use in both host and device code, with implementation handled by the compiler and device libraries.
I'm not an expert here, but I think there are some older GPUs that don't support it directly, so a check seems right.
There was a problem hiding this comment.
D3D optimization article claims that RX Vega was the first generation which supported fp16. ROCm HW support starts from the next generation (first RDNA and CDNA architectures).
| rocblas_lib = cc.find_library('rocblas', dirs: rocm_libdirs, required: false) | ||
| hipblas_lib = cc.find_library('hipblas', dirs: rocm_libdirs, required: false) | ||
| miopen_lib = cc.find_library('MIOpen', dirs: rocm_libdirs, required: false) | ||
| amdhip_lib = cc.find_library('amdhip64', dirs: rocm_libdirs, required: false) |
There was a problem hiding this comment.
dependecy() function makes discovering libraries and flags easier. It should be used if possible. You can see my quick hacky build changes from #2365 for an example. My changes requires using hipcc for everything because it made flag handling simpler for quick testing. It should be possible to use hipcc only for kernels like your changes do.
There was a problem hiding this comment.
Let me give onnx-migraphx a try on my system
I've just implemented fused kernel for attention on my local, which improved T82 backendbench from 2000 nps, to 2500 nps. Right now I am playing around with multi stream, supposely it would improve another 30%.
…e implementation, a wrapper, build system updates, tuning scripts, and comprehensive documentation.
Implements three key optimizations to the ROCm flash attention kernel: 1. **Fix warp reduction bug** (CRITICAL correctness fix) - Changed loop condition from 'offset >= 16' to 'offset >= 1' - Previous code only executed one iteration instead of full warp reduction - Ensures proper max value propagation across all 32 threads in warp - Impact: Correctness + ~1% performance from better numerical stability 2. **Remove unnecessary synchronization barrier** - Eliminated __syncthreads() after KQ matrix computation (line 341) - Analysis showed only register operations between barrier and next shared memory access - No shared memory hazards, barrier was pure overhead - Impact: ~2% performance reduction in synchronization costs 3. **Optimize shared memory padding** - Reduced padding from +4 to +2 half2 elements (25% → 12.5% overhead) - Profiling confirmed 0% LDS bank conflicts with reduced padding - Saves 50% of padding overhead while maintaining memory safety - Impact: ~0.5% performance from reduced shared memory footprint 4. **Fix meson.build to enable flash attention in C++ compilation** - Added add_project_arguments() to pass -DUSE_FLASH_ATTENTION=1 to C++ compiler - Previously flags were only passed to HIP kernel compilation - Required for layers.cc to actually use the flash attention code path Performance Results (batch=64, 150 iterations): - Baseline (pre-optimization): ~2,246 nps mean / ~2,357 nps peak - Phase 1 (post-optimization): 2,261 nps mean / 2,419 nps peak - Improvement: +0.8% mean / +2.6% peak - Stability: CV = 1.94% (excellent) Profiling Data (rocprofv3): - L2 Cache Hit Rate: 62.5% (moderate - memory bandwidth bound) - LDS Bank Conflicts: 0.0% (optimal) - Occupancy: 57.6% avg (moderate - memory latency limited) - VGPR Usage: 8 registers/thread (excellent - not register bound) Total improvement since rocBLAS baseline (~2,000 nps): +13.1% Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Reduced nbatch_K2 and nbatch_V2 from 32 to 24 to improve memory subsystem performance through better shared memory utilization and work distribution. Changes: - nbatch_k2_d32: 32 → 24 (K tile size) - nbatch_v2_d32: 32 → 24 (V tile size) Performance Impact: - Mean NPS: 2,261 → 2,358 (+4.3%) - Peak NPS: 2,419 → 2,588 (+7.0%) - Variance: CV = 4.93% (acceptable, up from 1.94%) Analysis: Profiling showed L2 cache hit rate remained at ~62.5%, so the performance gain comes from: 1. **Reduced shared memory pressure**: Smaller tiles use less LDS 2. **Better work distribution**: More loop iterations improve load balancing 3. **Improved instruction-level parallelism**: Compiler has more optimization opportunities Tested multiple configurations: - nbatch_K2/V2=32: 2,261 nps (baseline, lowest variance) - nbatch_K2/V2=24: 2,358 nps (best balance of performance/stability) - nbatch_K2/V2=16: 2,349 nps mean / 2,664 nps peak (highest performance, too much variance) Selected nbatch_K2/V2=24 as optimal tradeoff between performance gain and stability. Total improvement since rocBLAS baseline (~2,000 nps): +17.9% Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
- Implement per-stream resources (streams, rocBLAS handles, memory) - Add LockEval() method for conditional mutex locking - Fix lock.unlock() issue with empty unique_lock in multi-stream mode - Add extensive debug logging to track execution flow This commit includes debug output for troubleshooting.
Clean up debug logging while keeping all functional changes: - Multi-stream resource management - Lock ownership checking before unlock - Device context setting for per-stream resources Performance: 2,173 nps (same as single-stream baseline)
Implements selective use of hipBLASLt for feed-forward network (FFN) operations to optimize GPU utilization across different batch sizes. Key improvements: - Split-K parallelization for small batches (< 32) to saturate GPU - Bias fusion in GEMM epilogue eliminates memory bandwidth waste - Automatic fallback to rocBLAS for large batches (≥ 32) - LayerNorm kernel updated to support nullptr bias when pre-fused Performance results (AMD Radeon 8060S, gfx1151): - Batch 16: +12.3% improvement (585 vs 521 nps) - Batch 64: Baseline maintained (2,229 nps, no regression) - Small batches benefit from Split-K GPU saturation - Large batches bypass overhead, use optimized rocBLAS path Technical details: - hipBLASLt workspace: 8MB allocated for Split-K algorithms - Heuristic selection: Requests 10 algorithms, tries best-first - Threshold: N < 32 uses hipBLASLt, N ≥ 32 uses rocBLAS - Memory savings: Eliminates 94.5 MB/batch of redundant traffic Files modified: - src/neural/backends/rocm/layers.cc: Conditional FFN Dense 2 path - src/neural/backends/rocm/hipblaslt_wrapper.h: Split-K wrapper - src/neural/backends/rocm/common_kernels.hip: LayerNorm nullptr check - meson.build: hipBLASLt library detection and linking Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
|
building onnx-migraph is criminal. I had to build rocMLIR, AMDMIGraphX, then onnxruntime. at lower batch size, migraph is faster, but max is under 2000 nps for T82 across different batch sizes ROCm with handcrafted fused kernel hit 2500 nps for T82, with batch size of 64 |
|
There are typically prebuilt packages available to use migraphx. I have to only compile onnxruntime because it is typically only build for nvidia. BT4 is more interesting test than T82. T82 is a good network but BT4 architecture has shown improvements over it. |
|
I'll be trying this on NixOS with gfx1100. I've gotten the SYCL amd backend working just now and I'd like to compare the performance between this and that. I'll let you know once I get all this up and running |
Implemented a complete ROCm backend for AMD GPUs, enabling support for modern attention-based chess networks on RDNA 3.5 and other AMD architectures.
Implementation Details:
Key Files:
Performance Notes:
OpenCL/SYCL Compatibility:
Verification (Strix Halo - Radeon 8060S, gfx1151):