Skip to content

Commit 37a5bff

Browse files
committed
Do full propagation.
Signed-off-by: hanhanW <hanhan0912@gmail.com>
1 parent 62665df commit 37a5bff

File tree

2 files changed

+26
-46
lines changed

2 files changed

+26
-46
lines changed

compiler/src/iree/compiler/Codegen/Common/PropagateDispatchConfig.cpp

Lines changed: 20 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -7,9 +7,7 @@
77
#include "iree/compiler/Codegen/Common/Passes.h"
88
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
99
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
10-
#include "llvm/ADT/SmallVectorExtras.h"
1110
#include "llvm/Support/DebugLog.h"
12-
#include "mlir/IR/IRMapping.h"
1311

1412
#define DEBUG_TYPE "iree-codegen-propagate-dispatch-config"
1513

@@ -43,59 +41,41 @@ void PropagateDispatchConfigPass::runOnOperation() {
4341
return;
4442
}
4543

46-
// Build a map from export name to export op.
47-
DenseMap<StringRef, IREE::HAL::ExecutableExportOp> exportMap;
48-
for (auto exportOp : variantOp.getExportOps()) {
49-
exportMap[exportOp.getSymName()] = exportOp;
50-
}
51-
44+
SymbolTable symbolTable(variantOp);
5245
for (IREE::Codegen::DispatchConfigOp configOp : configOps) {
5346
StringRef funcRef = configOp.getFunctionRef();
54-
55-
if (!exportMap.contains(funcRef)) {
56-
// No export for this function (e.g. a helper that is not an entry
57-
// point). Erase the dispatch_config and move on.
47+
auto exportOp =
48+
symbolTable.lookup<IREE::HAL::ExecutableExportOp>(funcRef);
49+
if (!exportOp) {
50+
// No export for this function, so erase the dispatch_config and move on.
5851
configOp.erase();
5952
continue;
6053
}
61-
IREE::HAL::ExecutableExportOp exportOp = exportMap[funcRef];
6254

63-
// Replace the export count region body with the dispatch_config body.
64-
// Map dispatch_config block args to export block args (offset by 1
65-
// for !hal.device at position 0).
55+
// Move the dispatch_config region body into the export count region,
56+
// replacing iree_codegen.yield with hal.return.
6657
Region &countRegion = exportOp.getWorkgroupCount();
6758
OpBuilder builder(&getContext());
6859

6960
if (!countRegion.empty()) {
7061
Block &configBlock = configOp.getBody().front();
7162
Block *exportBlock = exportOp.getWorkgroupCountBody();
72-
unsigned configArity = configBlock.getNumArguments();
73-
unsigned exportArity = exportBlock->getNumArguments();
74-
// Export count region has !hal.device as block arg 0, then workloads.
75-
if (configArity + 1 > exportArity) {
76-
configOp.emitError("workload arity mismatch: dispatch_config has ")
77-
<< configArity << " args but export count region has "
78-
<< exportArity << " (expected >= " << configArity + 1
79-
<< " = config args + !hal.device)";
63+
TypeRange configArgTypes = configBlock.getArgumentTypes();
64+
TypeRange exportArgTypes = exportBlock->getArgumentTypes();
65+
if (configArgTypes != exportArgTypes) {
66+
configOp.emitError("block argument mismatch: dispatch_config has (")
67+
<< configArgTypes << ") but export count region has ("
68+
<< exportArgTypes << ")";
8069
return signalPassFailure();
8170
}
82-
exportBlock->clear();
83-
IRMapping mapping;
84-
for (unsigned i = 0; i < configArity; ++i) {
85-
mapping.map(configBlock.getArgument(i),
86-
exportBlock->getArgument(i + 1));
87-
}
88-
builder.setInsertionPointToEnd(exportBlock);
89-
for (Operation &op : configBlock.without_terminator()) {
90-
builder.clone(op, mapping);
91-
}
71+
countRegion.takeBody(configOp.getBody());
9272
// Replace iree_codegen.yield with hal.return.
93-
auto yieldOp = cast<IREE::Codegen::YieldOp>(configBlock.getTerminator());
94-
auto returnValues =
95-
llvm::map_to_vector(yieldOp.getOperands(), [&](Value v) {
96-
return mapping.lookupOrDefault(v);
97-
});
98-
IREE::HAL::ReturnOp::create(builder, yieldOp.getLoc(), returnValues);
73+
Block &block = countRegion.front();
74+
auto yieldOp = cast<IREE::Codegen::YieldOp>(block.getTerminator());
75+
builder.setInsertionPoint(yieldOp);
76+
IREE::HAL::ReturnOp::create(builder, yieldOp.getLoc(),
77+
yieldOp.getOperands());
78+
yieldOp.erase();
9979
}
10080

10181
// Set workgroup_size and subgroup_size on the export.

compiler/src/iree/compiler/Codegen/Common/test/propagate_dispatch_config.mlir

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ hal.executable private @basic_exe {
1919
}
2020
iree_codegen.dispatch_config @matmul
2121
workgroup_size = [64, 16, 1] subgroup_size = 64 {
22-
^bb0(%w0: index, %w1: index):
22+
^bb0(%device: !hal.device, %w0: index, %w1: index):
2323
%c1 = arith.constant 1 : index
2424
%0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 128)>()[%w0]
2525
iree_codegen.yield %0, %w1, %c1 : index, index, index
@@ -72,13 +72,13 @@ hal.executable private @specialized_exe {
7272
}
7373
iree_codegen.dispatch_config @matmul
7474
workgroup_size = [64, 16, 1] subgroup_size = 64 {
75-
^bb0(%w0: index):
75+
^bb0(%device: !hal.device, %w0: index):
7676
%c1 = arith.constant 1 : index
7777
iree_codegen.yield %w0, %c1, %c1 : index, index, index
7878
}
7979
iree_codegen.dispatch_config @matmul_0
8080
workgroup_size = [256, 1, 1] subgroup_size = 64 {
81-
^bb0(%w0: index):
81+
^bb0(%device: !hal.device, %w0: index):
8282
%c1 = arith.constant 1 : index
8383
%0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%w0]
8484
iree_codegen.yield %0, %c1, %c1 : index, index, index
@@ -119,7 +119,7 @@ hal.executable private @no_subgroup_exe {
119119
}
120120
iree_codegen.dispatch_config @entry
121121
workgroup_size = [1024, 1, 1] {
122-
^bb0(%w0: index):
122+
^bb0(%device: !hal.device, %w0: index):
123123
%c1 = arith.constant 1 : index
124124
iree_codegen.yield %w0, %c1, %c1 : index, index, index
125125
}
@@ -174,10 +174,10 @@ hal.executable private @arity_mismatch_exe {
174174
func.func @entry() {
175175
return
176176
}
177-
// expected-error @+1 {{workload arity mismatch}}
177+
// expected-error @+1 {{block argument mismatch}}
178178
iree_codegen.dispatch_config @entry
179179
workgroup_size = [64, 1, 1] {
180-
^bb0(%w0: index, %w1: index, %w2: index):
180+
^bb0(%device: !hal.device, %w0: index, %w1: index, %w2: index):
181181
%c1 = arith.constant 1 : index
182182
iree_codegen.yield %w0, %c1, %c1 : index, index, index
183183
}

0 commit comments

Comments
 (0)