-
Notifications
You must be signed in to change notification settings. Fork 871
[Codegen] Materialize implicit broadcasts for iGEMM consumer fusions #23954
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -7,6 +7,7 @@ | |||||
| #include "iree/compiler/Codegen/Common/Transforms.h" | ||||||
| #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" | ||||||
| #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" | ||||||
| #include "mlir/Dialect/Linalg/IR/Linalg.h" | ||||||
| #include "mlir/Dialect/Linalg/Transforms/Transforms.h" | ||||||
| #include "mlir/Dialect/Tensor/IR/Tensor.h" | ||||||
| #include "mlir/IR/MLIRContext.h" | ||||||
|
|
@@ -24,6 +25,17 @@ namespace { | |||||
|
|
||||||
| using iree_compiler::IREE::LinalgExt::IREELinalgExtDialect; | ||||||
|
|
||||||
| /// Generalize a specific named op to a linalg.generic. | ||||||
| template <typename OpTy> | ||||||
| struct GeneralizeNamedOp : OpRewritePattern<OpTy> { | ||||||
| using OpRewritePattern<OpTy>::OpRewritePattern; | ||||||
| LogicalResult matchAndRewrite(OpTy op, | ||||||
| PatternRewriter &rewriter) const override { | ||||||
| return linalg::generalizeNamedOp(rewriter, | ||||||
| cast<linalg::LinalgOp>(op.getOperation())); | ||||||
| } | ||||||
| }; | ||||||
|
|
||||||
| /// Pattern to set a lowering configuration on an IGEMM convolution. Searches | ||||||
| /// for a contraction with a linalg_ext.im2col producer, and calls the configFn | ||||||
| /// to set the configuration. | ||||||
|
|
@@ -108,6 +120,21 @@ convertToIGEMMAndSetConfig(FunctionOpInterface funcOp, | |||||
| } | ||||||
| } | ||||||
|
|
||||||
| // Materialize implicit broadcasts in element-wise consumer ops. Consumer | ||||||
| // generics with non-identity indexing maps (e.g., per-row bias with map | ||||||
| // (d0,d1,d2,d3) -> (d1)) cannot be folded through by the reshape | ||||||
| // propagation patterns below. Materialize the broadcasts explicitly to | ||||||
| // turn consumers into pure element-wise ops with identity maps. | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| { | ||||||
| RewritePatternSet materializeBroadcastPatterns(context); | ||||||
| linalg::populateDecomposeProjectedPermutationPatterns( | ||||||
| materializeBroadcastPatterns); | ||||||
| if (failed(applyPatternsGreedily( | ||||||
| funcOp, std::move(materializeBroadcastPatterns)))) { | ||||||
| return failure(); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| // The im2col transformation collapses some of the dimensions of the | ||||||
| // convolution operands. Try to push the reshape ops towards the boundaries | ||||||
| // of the function and fold with interface tensor ops. | ||||||
|
|
@@ -155,6 +182,24 @@ convertToIGEMMAndSetConfig(FunctionOpInterface funcOp, | |||||
| return failure(); | ||||||
| } | ||||||
| } | ||||||
| // Re-fuse the materialized broadcasts back into their element-wise | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| // consumers. The decomposition above created explicit linalg.broadcast | ||||||
| // and linalg.transpose ops so reshape propagation could fold through | ||||||
| // identity-map generics. Now that reshapes have been pushed to the | ||||||
| // boundaries, generalize those named ops to generics and fuse them back | ||||||
| // into their consumers to produce compact element-wise generics. | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| { | ||||||
| RewritePatternSet fusionPatterns(context); | ||||||
| // Generalize only broadcast/transpose to generics so elementwise | ||||||
| // fusion can fold them into their consumers. | ||||||
| fusionPatterns.add<GeneralizeNamedOp<linalg::BroadcastOp>, | ||||||
| GeneralizeNamedOp<linalg::TransposeOp>>(context); | ||||||
| linalg::populateElementwiseOpsFusionPatterns( | ||||||
| fusionPatterns, [](OpOperand *) { return true; }); | ||||||
| if (failed(applyPatternsGreedily(funcOp, std::move(fusionPatterns)))) { | ||||||
| return failure(); | ||||||
| } | ||||||
| } | ||||||
| return success(); | ||||||
| } | ||||||
|
|
||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -147,3 +147,79 @@ func.func public @no_conv_contraction(%arg0: tensor<128x128xf32>, %arg1: tensor< | |||||
| // CHECK-NOT: iree_linalg_ext.im2col | ||||||
| // CHECK: linalg.generic | ||||||
| // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] | ||||||
|
|
||||||
| // ----- | ||||||
|
|
||||||
| // Test that without a conv, the pass decomposes and re-fuses the broadcast | ||||||
| // so the generic is unchanged. | ||||||
|
|
||||||
| #map_id = affine_map<(d0, d1, d2) -> (d0, d1, d2)> | ||||||
| #map_bcast = affine_map<(d0, d1, d2) -> (d2)> | ||||||
| func.func @elementwise_broadcast_roundtrip( | ||||||
| %arg0: tensor<1x196x16xf32>, | ||||||
| %bias: tensor<16xf32>) -> tensor<1x196x16xf32> { | ||||||
| %empty = tensor.empty() : tensor<1x196x16xf32> | ||||||
| %result = linalg.generic { | ||||||
| indexing_maps = [#map_id, #map_bcast, #map_id], | ||||||
| iterator_types = ["parallel", "parallel", "parallel"] | ||||||
| } ins(%arg0, %bias : tensor<1x196x16xf32>, tensor<16xf32>) | ||||||
| outs(%empty : tensor<1x196x16xf32>) { | ||||||
| ^bb0(%in: f32, %b: f32, %out: f32): | ||||||
| %add = arith.addf %in, %b : f32 | ||||||
| linalg.yield %add : f32 | ||||||
| } -> tensor<1x196x16xf32> | ||||||
| return %result : tensor<1x196x16xf32> | ||||||
| } | ||||||
| // CHECK-LABEL: func.func @elementwise_broadcast_roundtrip | ||||||
| // CHECK: %[[RES:.+]] = linalg.generic | ||||||
| // CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<1x196x16xf32>, tensor<16xf32>) | ||||||
| // CHECK-NOT: linalg.broadcast | ||||||
| // CHECK: return %[[RES]] | ||||||
|
|
||||||
| // ----- | ||||||
|
|
||||||
| // Test that an expand_shape before a broadcasted element-wise consumer | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| // propagates through the generic and into the store_to_buffer. The broadcast | ||||||
| // decomposition turns the non-identity map into identity maps so the | ||||||
| // expand_shape can fold through, then the broadcast is fused back. | ||||||
|
|
||||||
| #map_id_4d = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> | ||||||
| #map_bcast_4d = affine_map<(d0, d1, d2, d3) -> (d2)> | ||||||
| func.func @expand_shape_propagation_with_broadcast( | ||||||
| %arg0: tensor<1x196x16xf32>, | ||||||
| %bias: tensor<14xf32>, | ||||||
| %arg2: memref<1x14x14x16xf32>) { | ||||||
| %expanded = tensor.expand_shape %arg0 [[0], [1, 2], [3]] | ||||||
| output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32> | ||||||
| %empty = tensor.empty() : tensor<1x14x14x16xf32> | ||||||
| %add = linalg.generic { | ||||||
| indexing_maps = [#map_id_4d, #map_bcast_4d, #map_id_4d], | ||||||
| iterator_types = ["parallel", "parallel", "parallel", "parallel"] | ||||||
| } ins(%expanded, %bias : tensor<1x14x14x16xf32>, tensor<14xf32>) | ||||||
| outs(%empty : tensor<1x14x14x16xf32>) { | ||||||
| ^bb0(%in: f32, %b: f32, %out: f32): | ||||||
| %sum = arith.addf %in, %b : f32 | ||||||
| linalg.yield %sum : f32 | ||||||
| } -> tensor<1x14x14x16xf32> | ||||||
| iree_codegen.store_to_buffer %add, %arg2 | ||||||
| : tensor<1x14x14x16xf32> into memref<1x14x14x16xf32> | ||||||
| return | ||||||
| } | ||||||
| // The expand_shape propagates through the generic to the store boundary. | ||||||
| // The output buffer gets a memref.collapse_shape and the store operates in | ||||||
| // the collapsed 3D shape. | ||||||
| // CHECK-LABEL: func.func @expand_shape_propagation_with_broadcast | ||||||
| // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<1x196x16xf32> | ||||||
| // CHECK-SAME: %[[OUTPUT_BUF:[a-zA-Z0-9]+]]: memref<1x14x14x16xf32> | ||||||
| // CHECK: %[[COLLAPSED_OUT:.+]] = memref.collapse_shape %[[OUTPUT_BUF]] | ||||||
| // CHECK-SAME: memref<1x14x14x16xf32> into memref<1x196x16xf32> | ||||||
| // CHECK-NOT: tensor.expand_shape | ||||||
| // CHECK: %[[BCAST:.+]] = linalg.generic | ||||||
| // CHECK-SAME: ins(%{{.*}} : tensor<14xf32>) | ||||||
| // CHECK-SAME: outs(%{{.*}} : tensor<1x14x14x16xf32>) | ||||||
| // CHECK: %[[COLLAPSED_BCAST:.+]] = tensor.collapse_shape %[[BCAST]] | ||||||
| // CHECK-SAME: tensor<1x14x14x16xf32> into tensor<1x196x16xf32> | ||||||
| // CHECK: %[[ADD:.+]] = linalg.generic | ||||||
| // CHECK-SAME: ins(%[[INPUT]], %[[COLLAPSED_BCAST]] : tensor<1x196x16xf32>, tensor<1x196x16xf32>) | ||||||
| // CHECK: iree_codegen.store_to_buffer %[[ADD]], %[[COLLAPSED_OUT]] | ||||||
| // CHECK-SAME: tensor<1x196x16xf32> into memref<1x196x16xf32> | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.