Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "iree/compiler/Codegen/Common/Transforms.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/MLIRContext.h"
Expand All @@ -24,6 +25,17 @@ namespace {

using iree_compiler::IREE::LinalgExt::IREELinalgExtDialect;

/// Rewrite pattern that converts a named linalg op of type `OpTy` into its
/// equivalent `linalg.generic` form via `linalg::generalizeNamedOp`.
template <typename OpTy>
struct GeneralizeNamedOp : OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Named linalg ops implement the LinalgOp interface; hand the interface
    // form to the upstream generalization helper.
    auto namedOp = cast<linalg::LinalgOp>(op.getOperation());
    return linalg::generalizeNamedOp(rewriter, namedOp);
  }
};

/// Pattern to set a lowering configuration on an IGEMM convolution. Searches
/// for a contraction with a linalg_ext.im2col producer, and calls the configFn
/// to set the configuration.
Expand Down Expand Up @@ -108,6 +120,21 @@ convertToIGEMMAndSetConfig(FunctionOpInterface funcOp,
}
}

// Materialize implicit broadcasts in element-wise consumer ops. Consumer
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Materialize implicit broadcasts in element-wise consumer ops. Consumer
// Materialize implicit broadcasts in elementwise consumer ops. Consumer

// generics with non-identity indexing maps (e.g., per-row bias with map
// (d0,d1,d2,d3) -> (d1)) cannot be folded through by the reshape
// propagation patterns below. Materialize the broadcasts explicitly to
// turn consumers into pure element-wise ops with identity maps.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// turn consumers into pure element-wise ops with identity maps.
// turn consumers into pure elementwise ops with identity maps.

{
RewritePatternSet materializeBroadcastPatterns(context);
linalg::populateDecomposeProjectedPermutationPatterns(
materializeBroadcastPatterns);
if (failed(applyPatternsGreedily(
funcOp, std::move(materializeBroadcastPatterns)))) {
return failure();
}
}

// The im2col transformation collapses some of the dimensions of the
// convolution operands. Try to push the reshape ops towards the boundaries
// of the function and fold with interface tensor ops.
Expand Down Expand Up @@ -155,6 +182,24 @@ convertToIGEMMAndSetConfig(FunctionOpInterface funcOp,
return failure();
}
}
// Re-fuse the materialized broadcasts back into their element-wise
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Re-fuse the materialized broadcasts back into their element-wise
// Re-fuse the materialized broadcasts back into their elementwise

// consumers. The decomposition above created explicit linalg.broadcast
// and linalg.transpose ops so reshape propagation could fold through
// identity-map generics. Now that reshapes have been pushed to the
// boundaries, generalize those named ops to generics and fuse them back
// into their consumers to produce compact element-wise generics.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// into their consumers to produce compact element-wise generics.
// into their consumers to produce compact elementwise generics.

{
RewritePatternSet fusionPatterns(context);
// Generalize only broadcast/transpose to generics so elementwise
// fusion can fold them into their consumers.
fusionPatterns.add<GeneralizeNamedOp<linalg::BroadcastOp>,
GeneralizeNamedOp<linalg::TransposeOp>>(context);
linalg::populateElementwiseOpsFusionPatterns(
fusionPatterns, [](OpOperand *) { return true; });
if (failed(applyPatternsGreedily(funcOp, std::move(fusionPatterns)))) {
return failure();
}
}
return success();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,79 @@ func.func public @no_conv_contraction(%arg0: tensor<128x128xf32>, %arg1: tensor<
// CHECK-NOT: iree_linalg_ext.im2col
// CHECK: linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]

// -----

// Test that without a conv, the pass decomposes and re-fuses the broadcast
// so the generic is unchanged.

#map_id = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map_bcast = affine_map<(d0, d1, d2) -> (d2)>
// Adds a 16-element bias along the innermost dimension (d2) of a 1x196x16
// input. The bias operand uses the non-identity indexing map
// (d0, d1, d2) -> (d2), i.e. an implicit broadcast over d0 and d1.
func.func @elementwise_broadcast_roundtrip(
%arg0: tensor<1x196x16xf32>,
%bias: tensor<16xf32>) -> tensor<1x196x16xf32> {
%empty = tensor.empty() : tensor<1x196x16xf32>
%result = linalg.generic {
indexing_maps = [#map_id, #map_bcast, #map_id],
iterator_types = ["parallel", "parallel", "parallel"]
} ins(%arg0, %bias : tensor<1x196x16xf32>, tensor<16xf32>)
outs(%empty : tensor<1x196x16xf32>) {
^bb0(%in: f32, %b: f32, %out: f32):
// Elementwise add of the (broadcast) bias to the input.
%add = arith.addf %in, %b : f32
linalg.yield %add : f32
} -> tensor<1x196x16xf32>
return %result : tensor<1x196x16xf32>
}
// CHECK-LABEL: func.func @elementwise_broadcast_roundtrip
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<1x196x16xf32>, tensor<16xf32>)
// CHECK-NOT: linalg.broadcast
// CHECK: return %[[RES]]

// -----

// Test that an expand_shape before a broadcasted element-wise consumer
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Test that an expand_shape before a broadcasted element-wise consumer
// Test that an expand_shape before a broadcasted elementwise consumer

// propagates through the generic and into the store_to_buffer. The broadcast
// decomposition turns the non-identity map into identity maps so the
// expand_shape can fold through, then the broadcast is fused back.

#map_id_4d = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map_bcast_4d = affine_map<(d0, d1, d2, d3) -> (d2)>
// Expands a 1x196x16 input to 1x14x14x16, adds a 14-element bias indexed by
// d2 only (map (d0, d1, d2, d3) -> (d2), an implicit broadcast over the
// other dims), and stores the 4-D result into a pre-allocated buffer.
func.func @expand_shape_propagation_with_broadcast(
%arg0: tensor<1x196x16xf32>,
%bias: tensor<14xf32>,
%arg2: memref<1x14x14x16xf32>) {
// Split the middle 196 dim into 14x14 before the elementwise consumer.
%expanded = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
%empty = tensor.empty() : tensor<1x14x14x16xf32>
%add = linalg.generic {
indexing_maps = [#map_id_4d, #map_bcast_4d, #map_id_4d],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
} ins(%expanded, %bias : tensor<1x14x14x16xf32>, tensor<14xf32>)
outs(%empty : tensor<1x14x14x16xf32>) {
^bb0(%in: f32, %b: f32, %out: f32):
// Elementwise add of the (broadcast) bias to the expanded input.
%sum = arith.addf %in, %b : f32
linalg.yield %sum : f32
} -> tensor<1x14x14x16xf32>
iree_codegen.store_to_buffer %add, %arg2
: tensor<1x14x14x16xf32> into memref<1x14x14x16xf32>
return
}
// The expand_shape propagates through the generic to the store boundary.
// The output buffer gets a memref.collapse_shape and the store operates in
// the collapsed 3D shape.
// CHECK-LABEL: func.func @expand_shape_propagation_with_broadcast
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<1x196x16xf32>
// CHECK-SAME: %[[OUTPUT_BUF:[a-zA-Z0-9]+]]: memref<1x14x14x16xf32>
// CHECK: %[[COLLAPSED_OUT:.+]] = memref.collapse_shape %[[OUTPUT_BUF]]
// CHECK-SAME: memref<1x14x14x16xf32> into memref<1x196x16xf32>
// CHECK-NOT: tensor.expand_shape
// CHECK: %[[BCAST:.+]] = linalg.generic
// CHECK-SAME: ins(%{{.*}} : tensor<14xf32>)
// CHECK-SAME: outs(%{{.*}} : tensor<1x14x14x16xf32>)
// CHECK: %[[COLLAPSED_BCAST:.+]] = tensor.collapse_shape %[[BCAST]]
// CHECK-SAME: tensor<1x14x14x16xf32> into tensor<1x196x16xf32>
// CHECK: %[[ADD:.+]] = linalg.generic
// CHECK-SAME: ins(%[[INPUT]], %[[COLLAPSED_BCAST]] : tensor<1x196x16xf32>, tensor<1x196x16xf32>)
// CHECK: iree_codegen.store_to_buffer %[[ADD]], %[[COLLAPSED_OUT]]
// CHECK-SAME: tensor<1x196x16xf32> into memref<1x196x16xf32>
Loading