Skip to content

Commit 05ab7f4

Browse files
authored
[Codegen] Update warp reduction config for multiple reduction (#20585)
Updates the kernel configuration in the case of multiple reduction dimensions, with some having dynamic shapes.
1 parent 7456631 commit 05ab7f4

File tree

3 files changed

+62
-6
lines changed

3 files changed

+62
-6
lines changed

compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -2480,14 +2480,28 @@ setWarpReductionConfig(IREE::GPU::TargetAttr target,
   if (numDynamicDims > 0) {
     SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
     int64_t preferredSubgroupSize = target.getPreferredSubgroupSize();
-    reductionTileSizes[reductionDims[0]] = preferredSubgroupSize;
+    // We should set the subgroup size on:
+    // Priority 1: The innermost reduction dimension with static shapes.
+    // Priority 2: If there's no reduction dimension with static shapes
+    // then the innermost reduction dim.
+    unsigned lastNonDynamicReductionDim = reductionDims.back();
+    if (reductionDims.size() > 1) {
+      for (unsigned dim : reductionDims) {
+        if (ShapedType::isDynamic(bounds[dim])) {
+          reductionTileSizes[dim] = 1;
+        } else {
+          lastNonDynamicReductionDim = dim;
+        }
+      }
+    }
+    reductionTileSizes[lastNonDynamicReductionDim] = preferredSubgroupSize;
     TileSizesListType tileSizes;
     tileSizes.emplace_back(std::move(workgroupTileSizes)); // Workgroup level
     tileSizes.emplace_back(std::move(reductionTileSizes)); // Reduction level
     std::array<int64_t, 3> workgroupSize = {preferredSubgroupSize, 1, 1};
     if (failed(setOpConfigAndEntryPointFnTranslation(
             entryPoint, op, tileSizes, CodeGenPipeline::LLVMGPUWarpReduction,
-            workgroupSize))) {
+            workgroupSize, preferredSubgroupSize))) {
       return failure();
     }
     return success();
```

compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -34,7 +34,7 @@ func.func @dynamic_batch_matvec() {
 }

 // CDNA3-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1, 1], [0, 0, 0, 32]{{\]}}>
-// CDNA3-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [32, 1, 1]>
+// CDNA3-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [32, 1, 1] subgroup_size = 32>
 // CDNA3-LABEL: func.func @dynamic_batch_matvec()
 // CDNA3-SAME: translation_info = #[[$TRANSLATION]]
 // CDNA3: linalg.batch_matmul
```
```diff
@@ -367,15 +367,57 @@ func.func @dynamic_parallel_dims(%dynsize : index, %input : tensor<4x?x4096xf16>
   return %2 : tensor<4x?xf32>
 }
 // CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1], [0, 0, 64]{{\]}}
-// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [64, 1, 1]>
+// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [64, 1, 1] subgroup_size = 64>
 // CHECK: func @dynamic_parallel_dims
 // CHECK-SAME: translation_info = #[[TRANSLATION]]
 // CHECK: linalg.generic
 // CHECK-SAME: lowering_config = #[[CONFIG]]

 // CDNA3-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1], [0, 0, 32]{{\]}}
-// CDNA3-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [32, 1, 1]>
+// CDNA3-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [32, 1, 1] subgroup_size = 32>
 // CDNA3: func @dynamic_parallel_dims
 // CDNA3-SAME: translation_info = #[[TRANSLATION]]
 // CDNA3: linalg.generic
 // CDNA3-SAME: lowering_config = #[[CONFIG]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+#map3 = affine_map<(d0, d1) -> (d0, d1)>
+#map4 = affine_map<(d0, d1) -> ()>
+func.func @test_dyn_reduction(%arg0: tensor<128x?x32xf8E4M3FNUZ>, %arg1: tensor<128x?x32x128xf8E4M3FNUZ>, %arg2: tensor<f32>) -> tensor<128x128xf8E4M3FNUZ> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_0 = arith.constant -2.400000e+02 : f8E4M3FNUZ
+  %cst_1 = arith.constant 2.400000e+02 : f8E4M3FNUZ
+  %0 = tensor.empty() : tensor<128x128xf8E4M3FNUZ>
+  %1 = tensor.empty() : tensor<128x128xf32>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %3 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<128x?x32xf8E4M3FNUZ>, tensor<128x?x32x128xf8E4M3FNUZ>) outs(%2 : tensor<128x128xf32>) {
+  ^bb0(%in: f8E4M3FNUZ, %in_2: f8E4M3FNUZ, %out: f32):
+    %5 = arith.extf %in : f8E4M3FNUZ to f32
+    %6 = arith.extf %in_2 : f8E4M3FNUZ to f32
+    %7 = arith.mulf %5, %6 : f32
+    %8 = arith.addf %out, %7 : f32
+    linalg.yield %8 : f32
+  } -> tensor<128x128xf32>
+  %4 = linalg.generic {indexing_maps = [#map3, #map4, #map3], iterator_types = ["parallel", "parallel"]} ins(%3, %arg2 : tensor<128x128xf32>, tensor<f32>) outs(%0 : tensor<128x128xf8E4M3FNUZ>) {
+  ^bb0(%in: f32, %in_2: f32, %out: f8E4M3FNUZ):
+    %5 = arith.truncf %in : f32 to f8E4M3FNUZ
+    %6 = arith.truncf %in_2 : f32 to f8E4M3FNUZ
+    %7 = arith.divf %5, %6 : f8E4M3FNUZ
+    %8 = arith.cmpf ult, %7, %cst_0 : f8E4M3FNUZ
+    %9 = arith.select %8, %cst_0, %7 : f8E4M3FNUZ
+    %10 = arith.cmpf ugt, %9, %cst_1 : f8E4M3FNUZ
+    %11 = arith.select %10, %cst_1, %9 : f8E4M3FNUZ
+    linalg.yield %11 : f8E4M3FNUZ
+  } -> tensor<128x128xf8E4M3FNUZ>
+  return %4 : tensor<128x128xf8E4M3FNUZ>
+}
+// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1], [0, 0, 1, 64]{{\]}}>
+// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [64, 1, 1] subgroup_size = 64>
+// CHECK: func.func @test_dyn_reduction
+// CHECK-SAME: translation_info = #[[$TRANSLATION]]
+// CHECK: linalg.generic
+// CHECK-SAME: lowering_config = #[[$CONFIG]]
```

compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_set_num_workgroups.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
```diff
@@ -699,7 +699,7 @@ func.func @i4_dequant_matvec() {
 }

 // CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1], [0, 0, 32]{{\]}}>
-// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [32, 1, 1]>
+// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [32, 1, 1] subgroup_size = 32>
 // CHECK-LABEL: func.func @i4_dequant_matvec()
 // CHECK-SAME: translation_info = #[[$TRANSLATION]]
 // CHECK: linalg.generic
```

0 commit comments

Comments
 (0)