Skip to content

Commit 510258b

Browse files
committed
Fix transform test
Signed-off-by: hanhanW <hanhan0912@gmail.com>
1 parent 5538cad commit 510258b

File tree

4 files changed

+47
-63
lines changed

4 files changed

+47
-63
lines changed

compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
1313
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
1414
#include "iree/compiler/Codegen/Common/Transforms.h"
15+
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
1516
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
1617
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
1718
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
@@ -673,11 +674,42 @@ transform_dialect::PopulateWorkgroupCountRegionUsingNumThreadsSliceOp::
673674
}
674675

675676
auto funcOp = forAllOp->getParentOfType<mlir::FunctionOpInterface>();
676-
if (failed(
677-
lowerWorkgroupCountFromSliceOp(rewriter, funcOp, workgroupCount))) {
677+
678+
// Resolve the workgroup_count_from_slice in the dispatch_config op.
679+
// The dispatch_config is the codegen-level representation of the workgroup
680+
// count — codegen passes no longer touch the hal.executable.export directly.
681+
// PropagateDispatchConfig copies the result to the export later.
682+
auto moduleOp = funcOp->getParentOfType<ModuleOp>();
683+
if (!moduleOp) {
684+
return mlir::emitDefiniteFailure(state.getTopLevel(),
685+
"could not find parent module");
686+
}
687+
IREE::Codegen::DispatchConfigOp configOp;
688+
for (auto op : moduleOp.getOps<IREE::Codegen::DispatchConfigOp>()) {
689+
if (op.getFunctionRef() == funcOp.getName()) {
690+
configOp = op;
691+
break;
692+
}
693+
}
694+
if (!configOp) {
678695
return mlir::emitDefiniteFailure(state.getTopLevel(),
679-
"failed to lower workgroup count region");
696+
"could not find dispatch_config for '")
697+
<< funcOp.getName() << "'";
680698
}
699+
IREE::TensorExt::DispatchWorkgroupCountFromSliceOp fromSliceOp;
700+
configOp.getBody().walk(
701+
[&](IREE::TensorExt::DispatchWorkgroupCountFromSliceOp fs) {
702+
fromSliceOp = fs;
703+
return WalkResult::interrupt();
704+
});
705+
if (fromSliceOp) {
706+
if (failed(lowerWorkgroupCountFromSliceOp(rewriter, fromSliceOp, funcOp,
707+
workgroupCount))) {
708+
return mlir::emitDefiniteFailure(
709+
state.getTopLevel(), "failed to lower dispatch_config count region");
710+
}
711+
}
712+
681713
return DiagnosedSilenceableFailure::success();
682714
}
683715

compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -500,15 +500,13 @@ def PopulateWorkgroupCountRegionUsingNumThreadsSliceOp :
500500
TransformOpInterface,
501501
ReportTrackingListenerFailuresOpTrait]> {
502502
let description = [{
503-
Populate the workgroup_count region on the `hal.executable.export` op.
504-
505-
The default dispatch region formation expects that the workgroup count
506-
be computed from within the dispatch by using a program slice. The utility
507-
method `lowerWorkgroupCountFromSliceOp` handles populating the
508-
workgroup count region given the values in the dispatch that represent the
509-
number of workgroups. This transform op calls the underlying function using
510-
the `num_threads` value from the `scf.for_all` op that distributes the work
511-
to different workgroups.
503+
Populate the `iree_codegen.dispatch_config` workgroup count region.
504+
505+
Finds the `dispatch_config` op for the parent function and resolves the
506+
`workgroup_count_from_slice` placeholder in its body using the
507+
`num_threads` values from the given `scf.forall` op.
508+
`PropagateDispatchConfig` later copies the result to the
509+
`hal.executable.export` count region.
512510
}];
513511

514512
let arguments = (ins TransformHandleTypeInterface:$for_all_op);

compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -521,35 +521,6 @@ LogicalResult lowerWorkgroupCountFromSliceOp(
521521
return success();
522522
}
523523

524-
LogicalResult lowerWorkgroupCountFromSliceOp(
525-
RewriterBase &rewriter, mlir::FunctionOpInterface entryPointFn,
526-
ArrayRef<OpFoldResult> workgroupCount, int maxWorkgroupParallelDims) {
527-
std::optional<IREE::HAL::ExecutableExportOp> exportOp =
528-
getEntryPoint(entryPointFn);
529-
if (!exportOp) {
530-
return success();
531-
}
532-
Block *body = exportOp->getWorkgroupCountBody();
533-
if (!body) {
534-
return success();
535-
}
536-
auto countOps =
537-
body->getOps<IREE::TensorExt::DispatchWorkgroupCountFromSliceOp>();
538-
if (countOps.empty()) {
539-
// If there are no `flow.dispatch.workgroup_count_default` operations
540-
// do nothing.
541-
return success();
542-
}
543-
if (!llvm::hasSingleElement(countOps)) {
544-
return exportOp->emitOpError(
545-
"unexpected multiple flow.dispatch.workgroup_count_default operations "
546-
"in body");
547-
}
548-
return lowerWorkgroupCountFromSliceOp(rewriter, *countOps.begin(),
549-
entryPointFn, workgroupCount,
550-
maxWorkgroupParallelDims);
551-
}
552-
553524
LogicalResult createWorkgroupCountHint(RewriterBase &rewriter, Location loc,
554525
ArrayRef<OpFoldResult> workgroupCount,
555526
int maxWorkgroupParallelDims,

compiler/src/iree/compiler/Codegen/Transforms/Transforms.h

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -145,35 +145,18 @@ FailureOr<SmallVector<OpFoldResult>> materializeWorkgroupCountComputation(
145145
ArrayRef<OpFoldResult> workgroupCount, ValueRange workloadVals);
146146

147147
/// Lower the workgroup count region for the default code-generation path in
148-
/// IREE. Given the list `workgroupCount` (fastest varying dimension innermost)
149-
/// as computed within the `entryPointFn`, clones a backward slice of the
150-
/// computation starting at these values and ending with
151-
/// `flow.dispatch.constant_ordinal` into the workgroup count region on the
152-
/// `hal.executable.export` op corresponding to the `entryPointFn`. Also removes
153-
/// the `flow.dispatch.constant_ordinal` operations from within the
154-
/// `entryPointFn`. Expects the workgroup count region of the corresponding
155-
/// `hal.executable.export` to contain the
156-
/// `flow.dispatch.workgroup_count_slice` operation as a placeholder for the
157-
/// computation to compute the number of workgroups. In absence of this
158-
/// operation, this method does nothing assuming that the workgroup count
159-
/// computation has already been resolved.
148+
/// Replaces a `workgroup_count_from_slice` placeholder with a materialized
149+
/// workgroup count computation. The `workgroupCount` list (fastest varying
150+
/// dimension innermost) provides the desired counts. If there are more
151+
/// dimensions than `maxWorkgroupParallelDims`, excess dimensions are folded
152+
/// into the last parallel dimension. Remaining dimensions are padded with 1.
160153
LogicalResult lowerWorkgroupCountFromSliceOp(
161154
RewriterBase &rewriter,
162155
IREE::TensorExt::DispatchWorkgroupCountFromSliceOp workgroupCountOp,
163156
mlir::FunctionOpInterface entryPointFn,
164157
ArrayRef<OpFoldResult> workgroupCount,
165158
int maxWorkgroupParallelDims = kNumMaxParallelDims);
166159

167-
/// Wrapper around `lowerWorkgroupCountFromSliceOp` method that
168-
/// takes the `iree_tensor_ext.dispatch.workgroup_count_from_slice` op
169-
/// as an argument. Looks up the `hal.executable.export` operation
170-
/// and finds the `iree_tensor_ext.dispatch.workgroup_count_from_slice` op to
171-
/// lower.
172-
LogicalResult lowerWorkgroupCountFromSliceOp(
173-
RewriterBase &rewriter, mlir::FunctionOpInterface entryPointFn,
174-
ArrayRef<OpFoldResult> workgroupCount,
175-
int maxWorkgroupParallelDims = kNumMaxParallelDims);
176-
177160
/// Creates an `iree_codegen.workgroup_count_hint` op at the current insertion
178161
/// point with the provided operands. If there are more operands provided than
179162
/// |maxWorkgroupParallelDims| the outermost sizes are linearized into the

0 commit comments

Comments
 (0)