Overview
This issue tracks the implementation of full VectorDistribute pipeline support for ArgCompareOp (argmax/argmin operations) on AMDGPU targets. The goal is to enable efficient GPU code generation for reduction operations that return both a selected value and its corresponding index.
ArgCompareOp Overview
ArgCompareOp is defined in LinalgExtOps.td:645-770 and performs:
- A reduction over a specified dimension of a tensor
- Returns two outputs: the selected value AND its corresponding index
- Uses a user-defined comparator region that receives two values and returns an i1
Key Design Feature: The comparator region provides flexibility to express:
- argmax: arith.cmpf ogt, %a, %b (greater than)
- argmin: arith.cmpf olt, %a, %b (less than)
- Custom logic: any boolean predicate comparing two values
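To make the semantics above concrete, here is a scalar C reference sketch of arg_compare reducing dimension 1 in implicit-index mode. It is not IREE code: arg_cmp_fn, arg_compare_ref and the cmp_* helpers are hypothetical names. The sketch assumes that the comparator returns true when its first operand should be selected over the second, that a later element replaces the running best only when the comparator returns true (so ties keep the smaller index), and that the init values passed in outs can be ignored.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical comparator type: returns true when `a` should be selected
   over `b`, mirroring the i1 yielded by the comparator region (assumed). */
typedef bool (*arg_cmp_fn)(float a, float b);

static float abs_f(float x) { return x < 0.0f ? -x : x; }

static bool cmp_ogt(float a, float b) { return a > b; }  /* argmax */
static bool cmp_olt(float a, float b) { return a < b; }  /* argmin */
static bool cmp_abs_gt(float a, float b) { return abs_f(a) > abs_f(b); }

/* Scalar reference for arg_compare dimension(1), implicit-index mode, on a
   rows x cols row-major input. Indices are positions along the reduced
   dimension. A candidate replaces the running best only when the
   comparator returns true, so ties keep the smaller index. The init values
   in outs are ignored here (assumption). */
static void arg_compare_ref(const float *input, size_t rows, size_t cols,
                            arg_cmp_fn cmp, float *out_val, int32_t *out_idx) {
  assert(cols > 0);
  for (size_t r = 0; r < rows; ++r) {
    const float *row = input + r * cols;
    float best = row[0];
    int32_t best_idx = 0;
    for (size_t c = 1; c < cols; ++c) {
      if (cmp(row[c], best)) {
        best = row[c];
        best_idx = (int32_t)c;
      }
    }
    out_val[r] = best;
    out_idx[r] = best_idx;
  }
}
```

With cmp_ogt, a row {1, 5, 5, 2} yields value 5 at index 1: the second 5 does not compare greater than the first, so the earlier index is kept.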
```mlir
// Example: argmax (select larger value)
iree_linalg_ext.arg_compare dimension(1)
  ins(%input : tensor<2x10xf32>)
  outs(%out_val, %out_idx : tensor<2xf32>, tensor<2xi32>) {
  ^bb0(%a: f32, %b: f32):
    %cmp = arith.cmpf ogt, %a, %b : f32
    iree_linalg_ext.yield %cmp : i1
}

// Example: argmin (select smaller value)
iree_linalg_ext.arg_compare dimension(1)
  ins(%input : tensor<2x10xf32>)
  outs(%out_val, %out_idx : tensor<2xf32>, tensor<2xi32>) {
  ^bb0(%a: f32, %b: f32):
    %cmp = arith.cmpf olt, %a, %b : f32
    iree_linalg_ext.yield %cmp : i1
}

// Example: custom comparator (select value with larger absolute value)
iree_linalg_ext.arg_compare dimension(1)
  ins(%input : tensor<2x10xf32>)
  outs(%out_val, %out_idx : tensor<2xf32>, tensor<2xi32>) {
  ^bb0(%a: f32, %b: f32):
    %abs_a = math.absf %a : f32
    %abs_b = math.absf %b : f32
    %cmp = arith.cmpf ogt, %abs_a, %abs_b : f32
    iree_linalg_ext.yield %cmp : i1
}
```

Current VectorDistribute Pipeline
The VectorDistribute pipeline lowers a regular reduction op as follows:

```
linalg.generic → vector.multi_reduction → gpu.subgroup_reduce → amdgpu.dpp → rocdl.update.dpp
```
For arg_compare, the proposed pipeline is:
```
iree_linalg_ext.arg_compare (implicit-index mode)
  ↓ (TileAndDistributeToWorkgroups - if split reduction needed)
Partial reductions in scf.forall:
  iree_linalg_ext.arg_compare (implicit-index, computes indices from ivs)
  → produces partial (value, index) pairs
  ↓ (PartialReductionOpInterface::mergeReductions)
Merge reduction:
  iree_linalg_ext.arg_compare (explicit-index mode, 2 inputs)
    ins(%partial_values, %partial_indices)
  → merges (value, index) pairs
  ↓ (GenericVectorizationPass - vectorizeArgCompareOp)
iree_vector_ext.arg_compare (vectorized form)
  → vector<...xf16>, vector<...xi32>
  → includes cloned comparator region
  ↓ (LLVMGPUConfigureTensorLayouts - setArgCompareAnchor)
iree_vector_ext.to_layout with NestedLayoutAttr
  → distributed across threads/subgroups
  ↓ (LLVMGPUVectorDistribute - DistributeArgCompare)
DPP butterfly reduction with cloned comparator
  → amdgpu.dpp for cross-lane data movement
  → execute comparator at each reduction stage
  → handle tie-breaking (prefer smaller index)
  ↓ (GPU lowering)
amdgpu.dpp + rocdl.ballot + rocdl.readlane + AMD intrinsics
```
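The split-reduction and merge steps of this pipeline can be sketched in scalar C. This is an illustration under an assumed comparator convention (true means the first operand is preferred), with hypothetical helper names (prefer, arg_compare_merge, split_arg_compare): each chunk is reduced with indices derived from its position (the implicit-index step), and the resulting partial (value, index) pairs are then merged with the indices passed in explicitly, breaking ties toward the smaller index.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/* Assumed convention: returns true when `a` is preferred over `b`. */
typedef bool (*arg_cmp_fn)(float a, float b);

static bool cmp_ogt(float a, float b) { return a > b; }

/* True when (va, ia) should win over (vb, ib): the comparator decides first;
   if neither value is preferred, the smaller index wins. */
static bool prefer(arg_cmp_fn cmp, float va, int32_t ia, float vb, int32_t ib) {
  if (cmp(va, vb)) return true;
  if (cmp(vb, va)) return false;
  return ia < ib;
}

/* Explicit-index merge: values and indices are both inputs, so indices are
   carried through instead of being recomputed from loop positions. */
static void arg_compare_merge(const float *values, const int32_t *indices,
                              size_t n, arg_cmp_fn cmp,
                              float *out_val, int32_t *out_idx) {
  assert(n > 0);
  float best = values[0];
  int32_t best_idx = indices[0];
  for (size_t k = 1; k < n; ++k) {
    if (prefer(cmp, values[k], indices[k], best, best_idx)) {
      best = values[k];
      best_idx = indices[k];
    }
  }
  *out_val = best;
  *out_idx = best_idx;
}

/* Split reduction over one row: each chunk is reduced in implicit-index
   style (global index = chunk offset + position), then the partial
   (value, index) pairs are merged in explicit-index style. */
static void split_arg_compare(const float *row, size_t cols, size_t chunk,
                              arg_cmp_fn cmp, float *out_val,
                              int32_t *out_idx) {
  enum { MAX_PARTS = 64 };
  assert(chunk > 0 && cols % chunk == 0 && cols / chunk <= MAX_PARTS);
  float part_val[MAX_PARTS];
  int32_t part_idx[MAX_PARTS];
  size_t nparts = cols / chunk;
  for (size_t p = 0; p < nparts; ++p) {
    size_t base = p * chunk;
    part_val[p] = row[base];
    part_idx[p] = (int32_t)base;
    for (size_t c = 1; c < chunk; ++c) {
      if (cmp(row[base + c], part_val[p])) {
        part_val[p] = row[base + c];
        part_idx[p] = (int32_t)(base + c);
      }
    }
  }
  arg_compare_merge(part_val, part_idx, nparts, cmp, out_val, out_idx);
}
```

Because partial indices are carried as data in explicit-index mode, they need not arrive in ascending order. For partials {5, 9, 9} with indices {10, 30, 20}, a merge relying only on first-occurrence order would return index 30; the explicit index tie-break returns 20.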
Key implementation:
- A new DistributeArgCompare pattern handles all comparator types uniformly
- DPP butterfly reduction pattern (6 stages for a 64-thread subgroup)
- The comparator region is cloned and executed at each stage
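The stage count follows from the butterfly structure: with exchange offsets 1, 2, 4, 8, 16 and 32, a 64-lane subgroup needs log2(64) = 6 stages. The host-side sketch below emulates that structure for (value, index) pairs. It assumes a 64-lane subgroup, uses an XOR-partner exchange (the __shfl_xor pattern used by the ukernel) as a stand-in for amdgpu.dpp, whereas the actual lowering uses DPP and related intrinsics, and butterfly_arg_reduce and prefer are hypothetical names.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

#define SUBGROUP_SIZE 64 /* assumed wave64 subgroup */

/* Assumed convention: returns true when `a` is preferred over `b`. */
typedef bool (*arg_cmp_fn)(float a, float b);

static bool cmp_ogt(float a, float b) { return a > b; }

/* Comparator first, then the smaller index on ties. */
static bool prefer(arg_cmp_fn cmp, float va, int32_t ia, float vb, int32_t ib) {
  if (cmp(va, vb)) return true;
  if (cmp(vb, va)) return false;
  return ia < ib;
}

/* Host emulation of a butterfly all-reduce of (value, index) pairs across
   one subgroup. At each stage lane L reads the pair held by lane
   L ^ offset and keeps the preferred one. Returns the number of stages. */
static int butterfly_arg_reduce(float val[SUBGROUP_SIZE],
                                int32_t idx[SUBGROUP_SIZE], arg_cmp_fn cmp) {
  int stages = 0;
  for (int offset = 1; offset < SUBGROUP_SIZE; offset *= 2) {
    float next_val[SUBGROUP_SIZE];
    int32_t next_idx[SUBGROUP_SIZE];
    for (int lane = 0; lane < SUBGROUP_SIZE; ++lane) {
      int partner = lane ^ offset;
      /* Both lanes of a pair evaluate the comparator (in swapped order)
         and reach the same choice because prefer is symmetric. */
      bool take = prefer(cmp, val[partner], idx[partner], val[lane], idx[lane]);
      next_val[lane] = take ? val[partner] : val[lane];
      next_idx[lane] = take ? idx[partner] : idx[lane];
    }
    for (int lane = 0; lane < SUBGROUP_SIZE; ++lane) {
      val[lane] = next_val[lane];
      idx[lane] = next_idx[lane];
    }
    ++stages;
  }
  return stages;
}
```

After the last stage every lane holds the same (value, index) result. This relies on the comparator never returning true in both directions; if it is also not transitive (for example ogt with NaN inputs), the butterfly result can differ from a sequential scan.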
Existing ArgMax Implementation (ROCm UKernel)
The ROCm argmax ukernel in iree_uk_amdgpu_argmax_f32i64.c demonstrates the key algorithm:

```c
// Excerpt: laneMax / laneResult hold this lane's local maximum and its index.
// 1. Reduce to find maximum across subgroup
float wgMax = laneMax;
for (int i = 1; i < warpSize; i *= 2) {
  wgMax = __builtin_fmaxf(__shfl_xor_f(wgMax, i), wgMax);
}
// 2. Use ballot to find which lanes have the max
uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax);
// 3. Handle index selection
if (__builtin_popcountll(laneHasMaxValmask) == 1) {
  // Single max holder - direct write
  if (wgMax == laneMax) {
    outputBufferIdx[offset] = laneResult;
  }
} else {
  // Multiple max holders - find smallest index (argmax semantics)
  int64_t indexVal = wgMax == laneMax ? laneResult : INT64_MAX;
  laneResult = __ockl_wfred_min_i64(indexVal);
  if (laneID == 0) {
    outputBufferIdx[offset] = laneResult;
  }
}
```

Implementation Plan
Phase 1: Explicit-Index Mode Foundation ✅ (Completed)
This phase addressed the merge reduction challenge by extending arg_compare to accept optional index inputs.
- Extend op definition
  - [LinalgExt] Extend arg_compare to support both value and index provided (explicit-index mode) #23153
- Update tiling and split reduction
  - [LinalgExt] Support and use arg_compare with explicit-index mode in split reduction #23218
  - [LinalgExt] Extend arg_compare tiling interface for explicit-index mode #23193
- Add verifier checks
  - [LinalgExt] Add verifier check to disallow index_base in explicit-index mode #23198
- GPU Generalization [Deprecated] (PR #23015)
Phase 2: VectorDistribute Pipeline Integration (Current Focus)
- PartialReductionOuterReduction Support
  - [LinalgExt] Add OuterReduction tiling strategy for ArgCompareOp #23102
- Add iree_vector_ext.arg_compare Op
  - [VectorExt] Add iree_vector_ext.arg_compare operation #23386
- Vectorization Support
  - [VectorExt] Add vectorization support for iree_linalg_ext.arg_compare #23440
  - [Codegen] Fix ArgCompare vectorization #23775
- Layout Configuration and Analysis
  - [Codegen][LLVMGPU] Add layout support for ArgCompare operations #23693
- Distribution Pattern Implementation
  - [WIP] Upstream gpu.ballot so that the distribution can be expressed without leaking target-specific ops (rocdl.ballot).
  - [Codegen][GPU] Add DistributeArgCompare pattern #23793
  - File: compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
    - Add the DistributeArgCompare pattern for iree_vector_ext.arg_compare
    - Generate a DPP butterfly reduction (6 stages for a 64-thread subgroup)
    - Clone the comparator region at each DPP stage and execute it inline
    - Use amdgpu.dpp for cross-lane data movement (shuffle both values and indices)
    - Handle tie-breaking: when values are equal, prefer the smaller index
    - Handle local reduction for thread-local elements before the subgroup reduction
    - Lower to amdgpu.dpp + rocdl.ballot + rocdl.readlane intrinsics
- KernelConfig Integration
  - File: compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUSelectLoweringStrategy.cpp
    - Add setArgCompareReductionConfig() to configure the reduction pipeline
    - Set workgroup size, subgroup size, and reduction tile sizes
    - Note: this should be the last step, after the distribution patterns are working
- Testing
  - Unit tests: compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/gpu_vector_distribution.mlir
    - Test argmax (simple comparator: ogt)
    - Test argmin (simple comparator: olt)
    - Test a custom comparator (e.g., absolute value comparison)
    - Test both implicit- and explicit-index modes
    - Verify the generated DPP instructions and tie-breaking logic
  - E2E tests: tests/e2e/linalg_ext/argcompare_amdgpu.mlir
    - Real-world argmax/argmin workloads
    - Attention mechanism with argmax
    - Custom comparator examples
    - Performance validation on MI250/MI300
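For the tie-breaking checks listed above, a host-side oracle could mirror the ukernel's ballot-then-min-index strategy from the Existing ArgMax Implementation section. select_argmax_index is a hypothetical test helper, not part of IREE; it assumes a 64-lane wave and emulates __ballot and __builtin_popcountll with plain loops so it builds with any C compiler.

```c
#include <assert.h>
#include <stdint.h>

#define WAVE_SIZE 64 /* assumed wave64 */

/* Hypothetical oracle mirroring steps 2 and 3 of the ukernel.
   lane_max[l] / lane_idx[l]: lane l's local maximum and its index;
   wg_max: the subgroup-wide maximum produced by step 1. */
static int64_t select_argmax_index(const float lane_max[WAVE_SIZE],
                                   const int64_t lane_idx[WAVE_SIZE],
                                   float wg_max) {
  /* Emulated ballot: bit l is set when lane l holds the maximum. */
  uint64_t mask = 0;
  for (int l = 0; l < WAVE_SIZE; ++l) {
    if (lane_max[l] == wg_max) mask |= (uint64_t)1 << l;
  }
  assert(mask != 0);

  /* Emulated popcount: clear the lowest set bit until none remain. */
  int holders = 0;
  for (uint64_t m = mask; m != 0; m &= m - 1) ++holders;

  /* Single holder: its index is the answer. */
  if (holders == 1) {
    for (int l = 0; l < WAVE_SIZE; ++l) {
      if ((mask >> l) & 1) return lane_idx[l];
    }
  }

  /* Several holders: min-reduce the indices, with non-holders
     contributing INT64_MAX, as in the ukernel. */
  int64_t best = INT64_MAX;
  for (int l = 0; l < WAVE_SIZE; ++l) {
    int64_t candidate = ((mask >> l) & 1) ? lane_idx[l] : INT64_MAX;
    if (candidate < best) best = candidate;
  }
  return best;
}
```

In a test, making lane_idx decrease as the lane number grows checks that the oracle returns the smallest index rather than the lowest lane holding the maximum.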