|
7 | 7 | #include "iree/compiler/Codegen/Common/Passes.h" |
8 | 8 | #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" |
9 | 9 | #include "iree/compiler/Dialect/HAL/IR/HALOps.h" |
10 | | -#include "llvm/ADT/SmallVectorExtras.h" |
11 | 10 | #include "llvm/Support/DebugLog.h" |
12 | | -#include "mlir/IR/IRMapping.h" |
13 | 11 |
|
14 | 12 | #define DEBUG_TYPE "iree-codegen-propagate-dispatch-config" |
15 | 13 |
|
@@ -43,59 +41,41 @@ void PropagateDispatchConfigPass::runOnOperation() { |
43 | 41 | return; |
44 | 42 | } |
45 | 43 |
|
46 | | - // Build a map from export name to export op. |
47 | | - DenseMap<StringRef, IREE::HAL::ExecutableExportOp> exportMap; |
48 | | - for (auto exportOp : variantOp.getExportOps()) { |
49 | | - exportMap[exportOp.getSymName()] = exportOp; |
50 | | - } |
51 | | - |
| 44 | + SymbolTable symbolTable(variantOp); |
52 | 45 | for (IREE::Codegen::DispatchConfigOp configOp : configOps) { |
53 | 46 | StringRef funcRef = configOp.getFunctionRef(); |
54 | | - |
55 | | - if (!exportMap.contains(funcRef)) { |
56 | | - // No export for this function (e.g. a helper that is not an entry |
57 | | - // point). Erase the dispatch_config and move on. |
| 47 | + auto exportOp = |
| 48 | + symbolTable.lookup<IREE::HAL::ExecutableExportOp>(funcRef); |
| 49 | + if (!exportOp) { |
| 50 | + // No export for this function, so erase the dispatch_config and move on. |
58 | 51 | configOp.erase(); |
59 | 52 | continue; |
60 | 53 | } |
61 | | - IREE::HAL::ExecutableExportOp exportOp = exportMap[funcRef]; |
62 | 54 |
|
63 | | - // Replace the export count region body with the dispatch_config body. |
64 | | - // Map dispatch_config block args to export block args (offset by 1 |
65 | | - // for !hal.device at position 0). |
| 55 | + // Move the dispatch_config region body into the export count region, |
| 56 | + // replacing iree_codegen.yield with hal.return. |
66 | 57 | Region &countRegion = exportOp.getWorkgroupCount(); |
67 | 58 | OpBuilder builder(&getContext()); |
68 | 59 |
|
69 | 60 | if (!countRegion.empty()) { |
70 | 61 | Block &configBlock = configOp.getBody().front(); |
71 | 62 | Block *exportBlock = exportOp.getWorkgroupCountBody(); |
72 | | - unsigned configArity = configBlock.getNumArguments(); |
73 | | - unsigned exportArity = exportBlock->getNumArguments(); |
74 | | - // Export count region has !hal.device as block arg 0, then workloads. |
75 | | - if (configArity + 1 > exportArity) { |
76 | | - configOp.emitError("workload arity mismatch: dispatch_config has ") |
77 | | - << configArity << " args but export count region has " |
78 | | - << exportArity << " (expected >= " << configArity + 1 |
79 | | - << " = config args + !hal.device)"; |
| 63 | + TypeRange configArgTypes = configBlock.getArgumentTypes(); |
| 64 | + TypeRange exportArgTypes = exportBlock->getArgumentTypes(); |
| 65 | + if (configArgTypes != exportArgTypes) { |
| 66 | + configOp.emitError("block argument mismatch: dispatch_config has (") |
| 67 | + << configArgTypes << ") but export count region has (" |
| 68 | + << exportArgTypes << ")"; |
80 | 69 | return signalPassFailure(); |
81 | 70 | } |
82 | | - exportBlock->clear(); |
83 | | - IRMapping mapping; |
84 | | - for (unsigned i = 0; i < configArity; ++i) { |
85 | | - mapping.map(configBlock.getArgument(i), |
86 | | - exportBlock->getArgument(i + 1)); |
87 | | - } |
88 | | - builder.setInsertionPointToEnd(exportBlock); |
89 | | - for (Operation &op : configBlock.without_terminator()) { |
90 | | - builder.clone(op, mapping); |
91 | | - } |
| 71 | + countRegion.takeBody(configOp.getBody()); |
92 | 72 | // Replace iree_codegen.yield with hal.return. |
93 | | - auto yieldOp = cast<IREE::Codegen::YieldOp>(configBlock.getTerminator()); |
94 | | - auto returnValues = |
95 | | - llvm::map_to_vector(yieldOp.getOperands(), [&](Value v) { |
96 | | - return mapping.lookupOrDefault(v); |
97 | | - }); |
98 | | - IREE::HAL::ReturnOp::create(builder, yieldOp.getLoc(), returnValues); |
| 73 | + Block &block = countRegion.front(); |
| 74 | + auto yieldOp = cast<IREE::Codegen::YieldOp>(block.getTerminator()); |
| 75 | + builder.setInsertionPoint(yieldOp); |
| 76 | + IREE::HAL::ReturnOp::create(builder, yieldOp.getLoc(), |
| 77 | + yieldOp.getOperands()); |
| 78 | + yieldOp.erase(); |
99 | 79 | } |
100 | 80 |
|
101 | 81 | // Set workgroup_size and subgroup_size on the export. |
|
0 commit comments