Skip to content

Commit fbe39bc

Browse files
authored
simplifier scalar tensor shape to scalar i32 (#1310)
* simplifier scalar tensor shape to scalar i32 * fix unit tests * update
1 parent 8f21582 commit fbe39bc

File tree

2 files changed

+111
-3
lines changed

2 files changed

+111
-3
lines changed

tao_compiler/mlir/disc/transforms/disc_algebraic_simplifier.cc

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
// See the License for the specific language governing permissions and
1010
// limitations under the License.
1111

#include <stack>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
@@ -423,7 +425,6 @@ struct BroadCastInDimOfReshapeOpCanonicalizationPattern
423425
return success();
424426
}
425427
};
426-
427428
// Simplifier extract and from-element op pattern, an example as following:
428429
// %0 = tensor.extract %arg0[] : tensor<f32>
429430
// %1 = tensor.from_elements %0 : tensor<1xf32>
@@ -439,6 +440,8 @@ struct SimplifierFromElementsPattern
439440
auto loc = op->getLoc();
440441
Value input = op->getOperand(0);
441442
Value result = op->getResult(0);
443+
// only support scalar tensor
444+
if (op->getNumOperands() != 1) return failure();
442445
auto extractOp = input.getDefiningOp<tensor::ExtractOp>();
443446
if (!extractOp) return failure();
444447

@@ -530,6 +533,93 @@ struct IndexCastSimplifierPattern
530533
return failure();
531534
}
532535
};
536+
// Simplify get_dimension_size pattern. An examples as following:
537+
// Case 1):
538+
// %2 = "mhlo.get_dimension_size"(%1)
539+
// %3 = "tensor.extract" %2[] -> i32
540+
// %from_elements = tensor.from_elements %3, ...
541+
// Convert to:
542+
// %2 = "tensor.dim"(%1, %cst0) -> i32
543+
// %from_elements = tensor.from_elements %2, ...
544+
//
545+
// Case 2):
546+
// %2 = "mhlo.get_dimension_size"(%1)
547+
// %3 = mhlo.mul %2, %4
548+
// %4 = "tensor.extract" %3[] -> i32
549+
// %from_elements = tensor.from_elements %4, ...
550+
// Convert to:
551+
// %2 = "tensor.dim"(%1, %cst0) -> i32
552+
// %3 = arith.mul %2, %4
553+
// %from_elements = tensor.from_elements %3, ...
554+
555+
struct SimplifierGetDimensionSizePattern
556+
: public OpRewritePattern<mhlo::GetDimensionSizeOp> {
557+
using OpRewritePattern<mhlo::GetDimensionSizeOp>::OpRewritePattern;
558+
LogicalResult matchAndRewrite(mhlo::GetDimensionSizeOp getDimOp,
559+
PatternRewriter& rewriter) const override {
560+
auto loc = getDimOp->getLoc();
561+
Value tensor = getDimOp->getOperand(0);
562+
auto dim = getDimOp.getDimension();
563+
auto elemTy = getDimOp.getResult()
564+
.getType()
565+
.cast<RankedTensorType>()
566+
.getElementType();
567+
568+
SmallVector<Operation*, 4> ops;
569+
std::stack<Operation*> stack;
570+
for (auto user : getDimOp->getUsers()) {
571+
if (isa<tensor::ExtractOp, func::ReturnOp, mhlo::ReshapeOp,
572+
mhlo::DynamicBroadcastInDimOp>(user))
573+
continue;
574+
stack.push(user);
575+
}
576+
while (!stack.empty()) {
577+
auto user = stack.top();
578+
stack.pop();
579+
ops.push_back(user);
580+
for (auto op : user->getUsers()) {
581+
if (isa<tensor::ExtractOp, func::ReturnOp, mhlo::ReshapeOp,
582+
mhlo::DynamicBroadcastInDimOp>(op))
583+
continue;
584+
stack.push(op);
585+
}
586+
}
587+
for (auto op : ops) {
588+
auto loc = op->getLoc();
589+
rewriter.setInsertionPoint(op);
590+
auto v1 = rewriter.create<tensor::ExtractOp>(loc, op->getOperand(0));
591+
auto v2 = rewriter.create<tensor::ExtractOp>(loc, op->getOperand(1));
592+
Value newOpValue;
593+
if (isa<mhlo::MulOp>(op)) {
594+
newOpValue = rewriter.create<arith::MulIOp>(loc, v1, v2).getResult();
595+
} else if (isa<mhlo::AddOp>(op)) {
596+
newOpValue = rewriter.create<arith::AddIOp>(loc, v1, v2).getResult();
597+
} else if (isa<mhlo::SubtractOp>(op)) {
598+
newOpValue = rewriter.create<arith::SubIOp>(loc, v1, v2).getResult();
599+
} else if (isa<mhlo::DivOp>(op)) {
600+
newOpValue = rewriter.create<arith::DivSIOp>(loc, v1, v2).getResult();
601+
} else {
602+
return failure();
603+
}
604+
auto result = rewriter.create<tensor::FromElementsOp>(
605+
loc, getDimOp.getResult().getType().cast<RankedTensorType>(),
606+
newOpValue);
607+
op->replaceAllUsesWith(result);
608+
}
609+
rewriter.setInsertionPoint(getDimOp);
610+
auto dimValue =
611+
rewriter.create<tensor::DimOp>(loc, tensor, dim).getResult();
612+
auto castValue = rewriter.create<arith::IndexCastOp>(loc, elemTy, dimValue);
613+
auto dimValueTensor =
614+
rewriter
615+
.create<tensor::FromElementsOp>(
616+
loc, getDimOp.getResult().getType().cast<RankedTensorType>(),
617+
ValueRange{castValue})
618+
.getResult();
619+
getDimOp.replaceAllUsesWith(dimValueTensor);
620+
return success();
621+
}
622+
};
533623

534624
// Consant folding the broadcasted constant, for patterns like:
535625
// %0 = mhlo.constant // Scalar or splat constant
@@ -627,7 +717,8 @@ void populateDiscAlgebraicSimplifierPatterns(RewritePatternSet& patterns) {
627717
IdentityBroadCastInDimOpCanonicalizationPattern<mhlo::DynamicBroadcastInDimOp>,
628718
SimplifierFromElementsPattern,
629719
TrunciSimplifierPattern,
630-
IndexCastSimplifierPattern
720+
IndexCastSimplifierPattern,
721+
SimplifierGetDimensionSizePattern
631722
>(patterns.getContext());
632723
if (isMemIntensiveOptExperimentalEnabled()) {
633724
// Will be enabled by default after a set of robustness testing.

tao_compiler/mlir/disc/transforms/tests/disc-algebraic-simplifier.mlir

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,4 +256,21 @@ func.func @select_simp(%arg0: tensor<16xf16>) -> (tensor<20xf16>, tensor<20xf16>
256256
%10 = mhlo.constant dense<false> : tensor<20xi1>
257257
%11 = "mhlo.select"(%10, %7, %3): (tensor<20xi1>, tensor<20xf16>, tensor<20xf16>) -> tensor<20xf16>
258258
return %8, %3 : tensor<20xf16>, tensor<20xf16>
259-
}
259+
}
260+
261+
// -----
262+
263+
// Regression test for SimplifierGetDimensionSizePattern: a scalar-tensor
// shape computation (mhlo.get_dimension_size -> mhlo.multiply ->
// tensor.extract) is lowered to scalar i32 arithmetic
// (tensor.dim + arith.index_cast + arith.muli), as per the CHECK lines.
// CHECK-LABEL: @main
func.func @main(%arg0: tensor<?x10xf32>, %arg1: tensor<10xf32>) -> tensor<?x10xf32> {
  %c10_i32 = arith.constant 10 : i32
  %c_0 = mhlo.constant dense<4> : tensor<i32>
  // CHECK: %dim = tensor.dim %arg0, %c0 : tensor<?x10xf32>
  // CHECK: %0 = arith.index_cast %dim : index to i32
  %2 = "mhlo.get_dimension_size"(%arg0) {dimension = 0 : i64} : (tensor<?x10xf32>) -> tensor<i32>
  // CHECK: %1 = arith.muli %0, %c4_i32 : i32
  %3 = mhlo.multiply %2, %c_0 : tensor<i32>
  %extracted = tensor.extract %3[] : tensor<i32>
  %from_elements = tensor.from_elements %extracted, %c10_i32 : tensor<2xi32>
  %4 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %from_elements) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<10xf32>, tensor<2xi32>) -> tensor<?x10xf32>
  return %4 : tensor<?x10xf32>
}

0 commit comments

Comments
 (0)