99// See the License for the specific language governing permissions and
1010// limitations under the License.
1111
12+ #include < stack>
13+
1214#include " llvm/ADT/DenseMap.h"
1315#include " llvm/ADT/StringRef.h"
1416#include " llvm/Support/Debug.h"
@@ -423,7 +425,6 @@ struct BroadCastInDimOfReshapeOpCanonicalizationPattern
423425 return success ();
424426 }
425427};
426-
// Simplify an extract + from_elements op pair, for example:
// %0 = tensor.extract %arg0[] : tensor<f32>
// %1 = tensor.from_elements %0 : tensor<1xf32>
@@ -439,6 +440,8 @@ struct SimplifierFromElementsPattern
439440 auto loc = op->getLoc ();
440441 Value input = op->getOperand (0 );
441442 Value result = op->getResult (0 );
443+ // only support scalar tensor
444+ if (op->getNumOperands () != 1 ) return failure ();
442445 auto extractOp = input.getDefiningOp <tensor::ExtractOp>();
443446 if (!extractOp) return failure ();
444447
@@ -530,6 +533,93 @@ struct IndexCastSimplifierPattern
530533 return failure ();
531534 }
532535};
536+ // Simplify get_dimension_size pattern. An examples as following:
537+ // Case 1):
538+ // %2 = "mhlo.get_dimension_size"(%1)
539+ // %3 = "tensor.extract" %2[] -> i32
540+ // %from_elements = tensor.from_elements %3, ...
541+ // Convert to:
542+ // %2 = "tensor.dim"(%1, %cst0) -> i32
543+ // %from_elements = tensor.from_elements %2, ...
544+ //
545+ // Case 2):
546+ // %2 = "mhlo.get_dimension_size"(%1)
547+ // %3 = mhlo.mul %2, %4
548+ // %4 = "tensor.extract" %3[] -> i32
549+ // %from_elements = tensor.from_elements %4, ...
550+ // Convert to:
551+ // %2 = "tensor.dim"(%1, %cst0) -> i32
552+ // %3 = arith.mul %2, %4
553+ // %from_elements = tensor.from_elements %3, ...
554+
555+ struct SimplifierGetDimensionSizePattern
556+ : public OpRewritePattern<mhlo::GetDimensionSizeOp> {
557+ using OpRewritePattern<mhlo::GetDimensionSizeOp>::OpRewritePattern;
558+ LogicalResult matchAndRewrite (mhlo::GetDimensionSizeOp getDimOp,
559+ PatternRewriter& rewriter) const override {
560+ auto loc = getDimOp->getLoc ();
561+ Value tensor = getDimOp->getOperand (0 );
562+ auto dim = getDimOp.getDimension ();
563+ auto elemTy = getDimOp.getResult ()
564+ .getType ()
565+ .cast <RankedTensorType>()
566+ .getElementType ();
567+
568+ SmallVector<Operation*, 4 > ops;
569+ std::stack<Operation*> stack;
570+ for (auto user : getDimOp->getUsers ()) {
571+ if (isa<tensor::ExtractOp, func::ReturnOp, mhlo::ReshapeOp,
572+ mhlo::DynamicBroadcastInDimOp>(user))
573+ continue ;
574+ stack.push (user);
575+ }
576+ while (!stack.empty ()) {
577+ auto user = stack.top ();
578+ stack.pop ();
579+ ops.push_back (user);
580+ for (auto op : user->getUsers ()) {
581+ if (isa<tensor::ExtractOp, func::ReturnOp, mhlo::ReshapeOp,
582+ mhlo::DynamicBroadcastInDimOp>(op))
583+ continue ;
584+ stack.push (op);
585+ }
586+ }
587+ for (auto op : ops) {
588+ auto loc = op->getLoc ();
589+ rewriter.setInsertionPoint (op);
590+ auto v1 = rewriter.create <tensor::ExtractOp>(loc, op->getOperand (0 ));
591+ auto v2 = rewriter.create <tensor::ExtractOp>(loc, op->getOperand (1 ));
592+ Value newOpValue;
593+ if (isa<mhlo::MulOp>(op)) {
594+ newOpValue = rewriter.create <arith::MulIOp>(loc, v1, v2).getResult ();
595+ } else if (isa<mhlo::AddOp>(op)) {
596+ newOpValue = rewriter.create <arith::AddIOp>(loc, v1, v2).getResult ();
597+ } else if (isa<mhlo::SubtractOp>(op)) {
598+ newOpValue = rewriter.create <arith::SubIOp>(loc, v1, v2).getResult ();
599+ } else if (isa<mhlo::DivOp>(op)) {
600+ newOpValue = rewriter.create <arith::DivSIOp>(loc, v1, v2).getResult ();
601+ } else {
602+ return failure ();
603+ }
604+ auto result = rewriter.create <tensor::FromElementsOp>(
605+ loc, getDimOp.getResult ().getType ().cast <RankedTensorType>(),
606+ newOpValue);
607+ op->replaceAllUsesWith (result);
608+ }
609+ rewriter.setInsertionPoint (getDimOp);
610+ auto dimValue =
611+ rewriter.create <tensor::DimOp>(loc, tensor, dim).getResult ();
612+ auto castValue = rewriter.create <arith::IndexCastOp>(loc, elemTy, dimValue);
613+ auto dimValueTensor =
614+ rewriter
615+ .create <tensor::FromElementsOp>(
616+ loc, getDimOp.getResult ().getType ().cast <RankedTensorType>(),
617+ ValueRange{castValue})
618+ .getResult ();
619+ getDimOp.replaceAllUsesWith (dimValueTensor);
620+ return success ();
621+ }
622+ };
533623
534624// Consant folding the broadcasted constant, for patterns like:
535625// %0 = mhlo.constant // Scalar or splat constant
@@ -627,7 +717,8 @@ void populateDiscAlgebraicSimplifierPatterns(RewritePatternSet& patterns) {
627717 IdentityBroadCastInDimOpCanonicalizationPattern<mhlo::DynamicBroadcastInDimOp>,
628718 SimplifierFromElementsPattern,
629719 TrunciSimplifierPattern,
630- IndexCastSimplifierPattern
720+ IndexCastSimplifierPattern,
721+ SimplifierGetDimensionSizePattern
631722 >(patterns.getContext ());
632723 if (isMemIntensiveOptExperimentalEnabled ()) {
633724 // Will be enabled by default after a set of robustness testing.
0 commit comments