Skip to content

Commit 81cfcaa

Browse files
authored
[Dispatch Creation] Handle linalg.fill in collapse dimensions (#20863)
Directly handle collapsing fill operations instead of relying on folding to clean up reshapes. This prevents reshapes from getting stuck in dispatches when the `linalg.fill` consumes a value produced by an operation that can't fold with the tensor reshape (see `collapse_fill_of_arg`).

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
1 parent c6eecfd commit 81cfcaa

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ static SmallVector<ReassociationIndices> getCollapsibleLoops(Operation *op) {
168168

169169
/// Returns true if the given op is collapsable.
170170
static bool isEligibleForCollapse(Operation *op) {
171-
if (isa<IREE::LinalgExt::AttentionOp>(op)) {
171+
if (isa<IREE::LinalgExt::AttentionOp, linalg::FillOp>(op)) {
172172
return true;
173173
}
174174

@@ -964,7 +964,7 @@ collapseDimensionsForDispatch(IRRewriter &rewriter,
964964
using ResultsType = FailureOr<SmallVector<Value>>;
965965
auto maybeReplacements =
966966
llvm::TypeSwitch<Operation *, ResultsType>(opToCollapse)
967-
.Case<linalg::GenericOp>(
967+
.Case<linalg::LinalgOp>(
968968
[&, &info = info](auto genericOp) -> ResultsType {
969969
FailureOr<linalg::CollapseResult> maybeReplacements =
970970
mlir::linalg::collapseOpIterationDims(

compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,3 +863,51 @@ util.func public @multi_reduction(%arg0 : tensor<32x16x16384xf32>, %arg1 : tenso
863863
// CHECK: %[[GEN2:.+]] = linalg.generic
864864
// CHECK-SAME: ins(%[[GEN1]] : tensor<32xf32>)
865865
// CHECK: flow.return %[[GEN2]]
866+
867+
// -----
868+
869+
util.func public @collapse_single_fill(%arg0: tensor<11x470x725x224xf32>) -> tensor<11x470x725x224xf32> {
870+
%0 = flow.dispatch.region -> (tensor<11x470x725x224xf32>) {
871+
%cst = arith.constant 0.000000e+00 : f32
872+
%1 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<11x470x725x224xf32>) -> tensor<11x470x725x224xf32>
873+
flow.return %1 : tensor<11x470x725x224xf32>
874+
}
875+
util.return %0 : tensor<11x470x725x224xf32>
876+
}
877+
// CHECK-LABEL: util.func public @collapse_single_fill
878+
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
879+
// CHECK-DAG: %[[COLLAPSE0:.+]] = tensor.collapse_shape %[[ARG0]]
880+
// CHECK: flow.dispatch.region
881+
// CHECK: %[[FILL:.+]] = linalg.fill
882+
// CHECK-SAME: outs(%[[COLLAPSE0]] : tensor<839608000xf32>)
883+
// CHECK: flow.return %[[FILL]]
884+
885+
// -----
886+
887+
util.func public @collapse_fill_of_arg(%arg0: tensor<224x32xf32>, %arg1: tensor<11x470x725x224xf32>, %arg2: tensor<11x470x725x32xf32>) -> tensor<11x470x725x224xf32> {
888+
%0 = flow.dispatch.region -> (tensor<11x470x725x224xf32>) {
889+
%cst = arith.constant 0.000000e+00 : f32
890+
%1 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<11x470x725x224xf32>) -> tensor<11x470x725x224xf32>
891+
%2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg2, %arg0 : tensor<11x470x725x32xf32>, tensor<224x32xf32>) outs(%1 : tensor<11x470x725x224xf32>) {
892+
^bb0(%in: f32, %in_0: f32, %out: f32):
893+
%3 = arith.mulf %in, %in_0 : f32
894+
%4 = arith.addf %out, %3 : f32
895+
linalg.yield %4 : f32
896+
} -> tensor<11x470x725x224xf32>
897+
flow.return %2 : tensor<11x470x725x224xf32>
898+
}
899+
util.return %0 : tensor<11x470x725x224xf32>
900+
}
901+
// CHECK-LABEL: util.func public @collapse_fill_of_arg
902+
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
903+
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]
904+
// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]
905+
// CHECK-DAG: %[[COLLAPSE1:.+]] = tensor.collapse_shape %[[ARG1]]
906+
// CHECK-DAG: %[[COLLAPSE2:.+]] = tensor.collapse_shape %[[ARG2]]
907+
// CHECK: flow.dispatch.region
908+
// CHECK: %[[FILL:.+]] = linalg.fill
909+
// CHECK-SAME: outs(%[[COLLAPSE1]] : tensor<3748250x224xf32>)
910+
// CHECK: %[[GEN0:.+]] = linalg.generic
911+
// CHECK-SAME: ins(%[[COLLAPSE2]], %[[ARG0]] : tensor<3748250x32xf32>, tensor<224x32xf32>)
912+
// CHECK-SAME: outs(%[[FILL]] : tensor<3748250x224xf32>)
913+
// CHECK: flow.return %[[GEN0]]

0 commit comments

Comments (0)