Skip to content

Commit ea18f7f

Browse files
[DispatchCreation] Use patterns to bubble up expand shape across collapse shapes. (#20648)
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
1 parent f6a5d72 commit ea18f7f

File tree

3 files changed

+16
-7
lines changed

3 files changed

+16
-7
lines changed

compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@
2929

3030
#define DEBUG_TYPE "iree-dispatch-creation-bubble-up-expand-shapes"
3131

32+
static llvm::cl::opt<bool> clPropagateCollapseAcrossExpands(
33+
"iree-dispatch-creation-propagate-collapse-across-expands",
34+
llvm::cl::desc("Enables change to propagate collapse shapes across expand "
35+
"shapes. This flag is meant as a stop-gap solution before "
36+
"making this default due to codegen issues."),
37+
llvm::cl::init(false));
3238
namespace mlir::iree_compiler::DispatchCreation {
3339

3440
#define GEN_PASS_DEF_BUBBLEUPEXPANDSHAPESPASS
@@ -212,6 +218,10 @@ void BubbleUpExpandShapesPass::runOnOperation() {
212218
memref::populateResolveRankedShapedTypeResultDimsPatterns(
213219
bubbleExpandShapePatterns);
214220

221+
if (clPropagateCollapseAcrossExpands) {
222+
tensor::populateBubbleUpExpandShapePatterns(bubbleExpandShapePatterns);
223+
}
224+
215225
GreedyRewriteConfig rewriteConfig;
216226
rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit;
217227
if (failed(applyPatternsGreedily(getOperation(),

compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-bubble-up-expand-shapes, canonicalize, cse, canonicalize))" %s | FileCheck %s
1+
// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-bubble-up-expand-shapes, canonicalize, cse, canonicalize))" --iree-dispatch-creation-propagate-collapse-across-expands=true %s | FileCheck %s
22

33
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
44
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
@@ -584,8 +584,7 @@ util.func @scatter_collapse_original_partial(%arg0: tensor<?x1x32x8x128xf16>, %a
584584
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
585585
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
586586
// CHECK-DAG: %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]] {{.*}} tensor<?x1x32x8x128xf16> into tensor<?x1x2x16x4x2x64x2xf16>
587-
// TODO(IanWood1): fix this so the collapse folds with the expand
588-
// CHECK-DAG: %[[ORIGINAL:.+]] = tensor.expand_shape {{.*}} tensor<?x32x8x128xf16> into tensor<?x2x16x4x2x64x2xf16>
587+
// CHECK-DAG: %[[ORIGINAL:.+]] = tensor.collapse_shape %[[ARG2]] {{.*}} tensor<5x?x2x16x4x2x64x2xf16> into tensor<?x2x16x4x2x64x2xf16>
589588
// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
590589
// CHECK-SAME: ins(%[[UPDATES]], %[[ARG1]]
591590
// CHECK-SAME: outs(%[[ORIGINAL]]

compiler/src/iree/compiler/Preprocessing/Common/test/attr_based_pipeline.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: iree-opt --iree-preprocessing-attr-based-pipeline --mlir-print-local-scope --split-input-file --verify-diagnostics %s | FileCheck %s
1+
// RUN: iree-opt --iree-preprocessing-attr-based-pipeline --mlir-print-local-scope --split-input-file --verify-diagnostics --iree-dispatch-creation-propagate-collapse-across-expands=true %s | FileCheck %s
22

33
func.func @single_dispatch_dropunitdims(%lhs : tensor<1x26x18x288xbf16>, %rhs : tensor<288x288x3x3xbf16>, %outs : tensor<1x288x26x18xbf16>,
44
%outs2 : tensor<1x288x24x16xf32>) -> tensor<1x288x24x16xf32> attributes {
@@ -12,9 +12,9 @@ func.func @single_dispatch_dropunitdims(%lhs : tensor<1x26x18x288xbf16>, %rhs :
1212
// CHECK-LABEL: @single_dispatch_dropunitdims
1313
// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<1x26x18x288xbf16>
1414
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
15-
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]]
16-
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]]
17-
// CHECK: %[[CONV:.+]] = linalg.generic {{.*}} ins(%[[EXPAND]]
15+
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
16+
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]]
17+
// CHECK: %[[CONV:.+]] = linalg.generic {{.*}} ins(%[[COLLAPSE]]
1818
// CHECK: flow.return %[[CONV]]
1919
// CHECK: return %[[DISPATCH]]
2020

0 commit comments

Comments
 (0)