Skip to content

Commit b9d437f

Browse files
committed
[Codegen] Introduce PropagateDispatchConfig pass.
The pass clones the dispatch metadata and slice computation from IREE::Codegen::DispatchConfig op to HAL::ExportOp. It works at `IREE::HAL::ExecutableVariantOp` scope because it needs to access the export op and configs within the inner module. Signed-off-by: hanhanW <hanhan0912@gmail.com>
1 parent d7da789 commit b9d437f

File tree

7 files changed

+324
-1
lines changed

7 files changed

+324
-1
lines changed

compiler/src/iree/compiler/Codegen/Common/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ iree_compiler_cc_library(
147147
"Passes.cpp",
148148
"PatchFuncOps.cpp",
149149
"PropagateConstantOffsets.cpp",
150+
"PropagateDispatchConfig.cpp",
150151
"PropagateDispatchSizeBounds.cpp",
151152
"PropagateReshapesByExpansion.cpp",
152153
"ReconcileTranslationInfo.cpp",

compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ iree_cc_library(
140140
"Passes.cpp"
141141
"PatchFuncOps.cpp"
142142
"PropagateConstantOffsets.cpp"
143+
"PropagateDispatchConfig.cpp"
143144
"PropagateDispatchSizeBounds.cpp"
144145
"PropagateReshapesByExpansion.cpp"
145146
"ReconcileTranslationInfo.cpp"

compiler/src/iree/compiler/Codegen/Common/Passes.td

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,19 @@ def CreateDispatchConfigPass
483483
];
484484
}
485485

486-
486+
def PropagateDispatchConfigPass
487+
: Pass<"iree-codegen-propagate-dispatch-config", "IREE::HAL::ExecutableVariantOp"> {
488+
let summary = "Propagate dispatch_config into hal.executable.export count regions.";
489+
let description = [{
490+
For each `iree_codegen.dispatch_config` in the inner module:
491+
1. Find the matching `hal.executable.export` by function_ref name.
492+
2. Replace the export count region body with the dispatch_config
493+
body, mapping block args (offset by 1 for `!hal.device`) and
494+
replacing `iree_codegen.yield` with `hal.return`.
495+
3. Set workgroup_size / subgroup_size on the export.
496+
4. Erase the dispatch_config op.
497+
}];
498+
}
487499
def ReplaceSlowMinMaxOpsPass
488500
: InterfacePass<"iree-codegen-replace-slow-min-max-ops", "mlir::FunctionOpInterface"> {
489501
let summary =
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// Copyright 2026 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Codegen/Common/Passes.h"
8+
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
9+
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
10+
#include "llvm/ADT/SmallVectorExtras.h"
11+
#include "llvm/Support/DebugLog.h"
12+
#include "mlir/IR/IRMapping.h"
13+
14+
#define DEBUG_TYPE "iree-codegen-propagate-dispatch-config"
15+
16+
namespace mlir::iree_compiler {
17+
18+
#define GEN_PASS_DEF_PROPAGATEDISPATCHCONFIGPASS
19+
#include "iree/compiler/Codegen/Common/Passes.h.inc"
20+
21+
namespace {
22+
23+
class PropagateDispatchConfigPass final
24+
: public impl::PropagateDispatchConfigPassBase<
25+
PropagateDispatchConfigPass> {
26+
public:
27+
using Base::Base;
28+
void runOnOperation() override;
29+
};
30+
31+
void PropagateDispatchConfigPass::runOnOperation() {
32+
IREE::HAL::ExecutableVariantOp variantOp = getOperation();
33+
ModuleOp innerModule = variantOp.getInnerModule();
34+
if (!innerModule) {
35+
return;
36+
}
37+
38+
// Collect all dispatch_config ops.
39+
SmallVector<IREE::Codegen::DispatchConfigOp> configOps;
40+
innerModule->walk(
41+
[&](IREE::Codegen::DispatchConfigOp op) { configOps.push_back(op); });
42+
if (configOps.empty()) {
43+
return;
44+
}
45+
46+
// Build a map from export name to export op.
47+
DenseMap<StringRef, IREE::HAL::ExecutableExportOp> exportMap;
48+
for (auto exportOp : variantOp.getExportOps()) {
49+
exportMap[exportOp.getSymName()] = exportOp;
50+
}
51+
52+
for (IREE::Codegen::DispatchConfigOp configOp : configOps) {
53+
StringRef funcRef = configOp.getFunctionRef();
54+
55+
if (!exportMap.contains(funcRef)) {
56+
// No export for this function (e.g. a helper that is not an entry
57+
// point). Erase the dispatch_config and move on.
58+
configOp.erase();
59+
continue;
60+
}
61+
IREE::HAL::ExecutableExportOp exportOp = exportMap[funcRef];
62+
63+
// Replace the export count region body with the dispatch_config body.
64+
// Map dispatch_config block args to export block args (offset by 1
65+
// for !hal.device at position 0).
66+
Region &countRegion = exportOp.getWorkgroupCount();
67+
OpBuilder builder(&getContext());
68+
69+
if (!countRegion.empty()) {
70+
Block &configBlock = configOp.getBody().front();
71+
Block *exportBlock = exportOp.getWorkgroupCountBody();
72+
unsigned configArity = configBlock.getNumArguments();
73+
unsigned exportArity = exportBlock->getNumArguments();
74+
// Export count region has !hal.device as block arg 0, then workloads.
75+
if (configArity + 1 > exportArity) {
76+
configOp.emitError("workload arity mismatch: dispatch_config has ")
77+
<< configArity << " args but export count region has "
78+
<< exportArity << " (expected >= " << configArity + 1
79+
<< " = config args + !hal.device)";
80+
return signalPassFailure();
81+
}
82+
exportBlock->clear();
83+
IRMapping mapping;
84+
for (unsigned i = 0; i < configArity; ++i) {
85+
mapping.map(configBlock.getArgument(i),
86+
exportBlock->getArgument(i + 1));
87+
}
88+
builder.setInsertionPointToEnd(exportBlock);
89+
for (Operation &op : configBlock.without_terminator()) {
90+
builder.clone(op, mapping);
91+
}
92+
// Replace iree_codegen.yield with hal.return.
93+
auto yieldOp = cast<IREE::Codegen::YieldOp>(configBlock.getTerminator());
94+
auto returnValues =
95+
llvm::map_to_vector(yieldOp.getOperands(), [&](Value v) {
96+
return mapping.lookupOrDefault(v);
97+
});
98+
IREE::HAL::ReturnOp::create(builder, yieldOp.getLoc(), returnValues);
99+
}
100+
101+
// Set workgroup_size and subgroup_size on the export.
102+
auto wgSize = configOp.getWorkgroupSize();
103+
if (!wgSize) {
104+
configOp.emitError("missing workgroup_size attribute");
105+
return signalPassFailure();
106+
}
107+
SmallVector<int64_t, 3> wgSizePadded(wgSize->begin(), wgSize->end());
108+
while (wgSizePadded.size() < 3) {
109+
wgSizePadded.push_back(1);
110+
}
111+
exportOp.setWorkgroupSizeAttr(builder.getIndexArrayAttr(wgSizePadded));
112+
if (auto subgroupSize = configOp.getSubgroupSize()) {
113+
exportOp.setSubgroupSizeAttr(builder.getIndexAttr(subgroupSize.value()));
114+
}
115+
116+
configOp.erase();
117+
}
118+
}
119+
120+
} // namespace
121+
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ iree_lit_test_suite(
110110
"pad_dynamic_alloc.mlir",
111111
"patch_func_ops.mlir",
112112
"propagate_constant_offsets.mlir",
113+
"propagate_dispatch_config.mlir",
113114
"propagate_dispatch_size_bounds.mlir",
114115
"propagate_reshapes_by_expansion.mlir",
115116
"reconcile_translation_info.mlir",

compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ iree_lit_test_suite(
105105
"pad_dynamic_alloc.mlir"
106106
"patch_func_ops.mlir"
107107
"propagate_constant_offsets.mlir"
108+
"propagate_dispatch_config.mlir"
108109
"propagate_dispatch_size_bounds.mlir"
109110
"propagate_reshapes_by_expansion.mlir"
110111
"reconcile_translation_info.mlir"
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-codegen-propagate-dispatch-config)))" --verify-diagnostics %s | FileCheck %s
2+
3+
// Basic: single dispatch_config + export.
4+
hal.executable private @basic_exe {
5+
hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) {
6+
hal.executable.export @matmul ordinal(0)
7+
layout(#hal.pipeline.layout<bindings = [
8+
#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
9+
#hal.pipeline.binding<storage_buffer, Indirect>
10+
]>)
11+
count(%device: !hal.device, %w0: index, %w1: index)
12+
-> (index, index, index) {
13+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice(%w0, %w1)
14+
hal.return %x, %y, %z : index, index, index
15+
}
16+
builtin.module {
17+
func.func @matmul() {
18+
return
19+
}
20+
iree_codegen.dispatch_config @matmul
21+
workgroup_size = [64, 16, 1] subgroup_size = 64 {
22+
^bb0(%w0: index, %w1: index):
23+
%c1 = arith.constant 1 : index
24+
%0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 128)>()[%w0]
25+
iree_codegen.yield %0, %w1, %c1 : index, index, index
26+
}
27+
}
28+
}
29+
}
30+
// CHECK-LABEL: hal.executable private @basic_exe
31+
// CHECK: hal.executable.export public @matmul
32+
// CHECK: count(%[[DEV:.+]]: !hal.device, %[[W0:.+]]: index, %[[W1:.+]]: index)
33+
// CHECK: %[[C1:.+]] = arith.constant 1 : index
34+
// CHECK: %[[X:.+]] = affine.apply
35+
// CHECK: hal.return %[[X]], %[[W1]], %[[C1]]
36+
// CHECK: } attributes {subgroup_size = 64 : index, workgroup_size = [64 : index, 16 : index, 1 : index]}
37+
// CHECK: builtin.module
38+
// CHECK: func.func @matmul()
39+
// CHECK-NOT: iree_codegen.dispatch_config
40+
41+
// -----
42+
43+
// Multiple dispatch_configs (specialization).
44+
hal.executable private @specialized_exe {
45+
hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) {
46+
hal.executable.export @matmul ordinal(0)
47+
layout(#hal.pipeline.layout<bindings = [
48+
#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
49+
#hal.pipeline.binding<storage_buffer, Indirect>
50+
]>)
51+
count(%device: !hal.device, %w0: index)
52+
-> (index, index, index) {
53+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice(%w0)
54+
hal.return %x, %y, %z : index, index, index
55+
}
56+
hal.executable.export @matmul_0 ordinal(1)
57+
layout(#hal.pipeline.layout<bindings = [
58+
#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
59+
#hal.pipeline.binding<storage_buffer, Indirect>
60+
]>)
61+
count(%device: !hal.device, %w0: index)
62+
-> (index, index, index) {
63+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice(%w0)
64+
hal.return %x, %y, %z : index, index, index
65+
}
66+
builtin.module {
67+
func.func @matmul() {
68+
return
69+
}
70+
func.func @matmul_0() {
71+
return
72+
}
73+
iree_codegen.dispatch_config @matmul
74+
workgroup_size = [64, 16, 1] subgroup_size = 64 {
75+
^bb0(%w0: index):
76+
%c1 = arith.constant 1 : index
77+
iree_codegen.yield %w0, %c1, %c1 : index, index, index
78+
}
79+
iree_codegen.dispatch_config @matmul_0
80+
workgroup_size = [256, 1, 1] subgroup_size = 64 {
81+
^bb0(%w0: index):
82+
%c1 = arith.constant 1 : index
83+
%0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%w0]
84+
iree_codegen.yield %0, %c1, %c1 : index, index, index
85+
}
86+
}
87+
}
88+
}
89+
// CHECK-LABEL: hal.executable private @specialized_exe
90+
// CHECK: hal.executable.export public @matmul
91+
// CHECK: count(%{{.+}}: !hal.device, %[[W0A:.+]]: index)
92+
// CHECK: %[[C1A:.+]] = arith.constant 1 : index
93+
// CHECK: hal.return %[[W0A]], %[[C1A]], %[[C1A]]
94+
// CHECK: } attributes {subgroup_size = 64 : index, workgroup_size = [64 : index, 16 : index, 1 : index]}
95+
// CHECK: hal.executable.export public @matmul_0
96+
// CHECK: count(%{{.+}}: !hal.device, %[[W0B:.+]]: index)
97+
// CHECK: %[[C1B:.+]] = arith.constant 1 : index
98+
// CHECK: %[[X:.+]] = affine.apply
99+
// CHECK: hal.return %[[X]], %[[C1B]], %[[C1B]]
100+
// CHECK: } attributes {subgroup_size = 64 : index, workgroup_size = [256 : index, 1 : index, 1 : index]}
101+
102+
// -----
103+
104+
// No subgroup_size attribute.
105+
hal.executable private @no_subgroup_exe {
106+
hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) {
107+
hal.executable.export @entry ordinal(0)
108+
layout(#hal.pipeline.layout<bindings = [
109+
#hal.pipeline.binding<storage_buffer, Indirect>
110+
]>)
111+
count(%device: !hal.device, %w0: index)
112+
-> (index, index, index) {
113+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice(%w0)
114+
hal.return %x, %y, %z : index, index, index
115+
}
116+
builtin.module {
117+
func.func @entry() {
118+
return
119+
}
120+
iree_codegen.dispatch_config @entry
121+
workgroup_size = [1024, 1, 1] {
122+
^bb0(%w0: index):
123+
%c1 = arith.constant 1 : index
124+
iree_codegen.yield %w0, %c1, %c1 : index, index, index
125+
}
126+
}
127+
}
128+
}
129+
// CHECK-LABEL: hal.executable private @no_subgroup_exe
130+
// CHECK: hal.executable.export public @entry
131+
// CHECK: hal.return
132+
// CHECK: } attributes {workgroup_size = [1024 : index, 1 : index, 1 : index]}
133+
// CHECK-NOT: subgroup_size
134+
135+
// -----
136+
137+
// No dispatch_config ops — pass is a no-op.
138+
hal.executable private @noop_exe {
139+
hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) {
140+
hal.executable.export @entry ordinal(0)
141+
layout(#hal.pipeline.layout<bindings = [
142+
#hal.pipeline.binding<storage_buffer, Indirect>
143+
]>)
144+
count(%device: !hal.device, %w0: index)
145+
-> (index, index, index) {
146+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice(%w0)
147+
hal.return %x, %y, %z : index, index, index
148+
}
149+
builtin.module {
150+
func.func @entry() {
151+
return
152+
}
153+
}
154+
}
155+
}
156+
// CHECK-LABEL: hal.executable private @noop_exe
157+
// CHECK: iree_tensor_ext.dispatch.workgroup_count_from_slice
158+
159+
// -----
160+
161+
// Error: arity mismatch.
162+
hal.executable private @arity_mismatch_exe {
163+
hal.executable.variant public @rocm target(<"rocm", "rocm-hsaco-fb">) {
164+
hal.executable.export @entry ordinal(0)
165+
layout(#hal.pipeline.layout<bindings = [
166+
#hal.pipeline.binding<storage_buffer, Indirect>
167+
]>)
168+
count(%device: !hal.device, %w0: index)
169+
-> (index, index, index) {
170+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice(%w0)
171+
hal.return %x, %y, %z : index, index, index
172+
}
173+
builtin.module {
174+
func.func @entry() {
175+
return
176+
}
177+
// expected-error @+1 {{workload arity mismatch}}
178+
iree_codegen.dispatch_config @entry
179+
workgroup_size = [64, 1, 1] {
180+
^bb0(%w0: index, %w1: index, %w2: index):
181+
%c1 = arith.constant 1 : index
182+
iree_codegen.yield %w0, %c1, %c1 : index, index, index
183+
}
184+
}
185+
}
186+
}

0 commit comments

Comments
 (0)