1010#include " iree/compiler/Codegen/Common/Passes.h"
1111#include " iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
1212#include " iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
13+ #include " iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
14+ #include " iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.h"
15+ #include " iree/compiler/Codegen/Utils/GPUUtils.h"
1316#include " iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
1417#include " iree/compiler/Dialect/Flow/IR/FlowOps.h"
1518#include " iree/compiler/Dialect/Flow/IR/FlowTypes.h"
1619#include " llvm/ADT/STLExtras.h"
1720#include " llvm/ADT/SmallVector.h"
21+ #include " llvm/Support/Debug.h"
1822#include " mlir/Dialect/MemRef/Transforms/Transforms.h"
1923#include " mlir/Dialect/Tensor/IR/Tensor.h"
2024#include " mlir/IR/BuiltinTypes.h"
2327#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
2428#include " mlir/Transforms/Passes.h"
2529
30+ #define DEBUG_TYPE " iree-codegen-materialize-encoding-into-padding"
31+
2632namespace mlir ::iree_compiler {
2733
2834#define GEN_PASS_DEF_MATERIALIZEENCODINGINTOPADDINGPASS
@@ -34,24 +40,36 @@ namespace {
3440
3541// Returns the pad encoding layout, or nullptr if this is not the only layout or
3642// if there's no encoding at all.
37- static PadEncodingLayoutAttr getPadLayout (RankedTensorType type) {
38- auto encoding =
39- dyn_cast_or_null<IREE::Encoding::LayoutAttr>(type.getEncoding ());
40- if (!encoding) {
43+ static PadEncodingLayoutAttr getPadLayout (Attribute layoutAttr,
44+ RankedTensorType type) {
45+ if (!type.getEncoding ()) {
4146 return nullptr ;
4247 }
43- ArrayAttr layouts = encoding.getLayouts ();
44- if (!layouts || layouts.size () != 1 ) {
45- return nullptr ;
48+ auto encoding =
49+ dyn_cast_or_null<IREE::Encoding::LayoutAttr>(type.getEncoding ());
50+ if (encoding) {
51+ ArrayAttr layouts = encoding.getLayouts ();
52+ if (layouts.size () != 1 ) {
53+ return nullptr ;
54+ }
55+ return dyn_cast<PadEncodingLayoutAttr>(*layouts.begin ());
4656 }
47-
48- return dyn_cast<PadEncodingLayoutAttr>(*layouts.begin ());
57+ Attribute resolvedEncoding =
58+ cast<IREE::Encoding::EncodingLayoutResolverAttrInterface>(layoutAttr)
59+ .getLayout (type);
60+ LLVM_DEBUG ({
61+ llvm::dbgs () << " Unresolved type: " << type << " \n " ;
62+ llvm::dbgs () << " layoutAttr: " << layoutAttr << " \n " ;
63+ llvm::dbgs () << " Resolved into: " << resolvedEncoding << " \n " ;
64+ });
65+ return dyn_cast<PadEncodingLayoutAttr>(resolvedEncoding);
4966}
5067
5168// Returns a padded tensor type (without encoding) for tensor types with the pad
5269// encoding layout, or the same type for all other tensors.
53- static RankedTensorType getPaddedType (RankedTensorType type) {
54- PadEncodingLayoutAttr layout = getPadLayout (type);
70+ static RankedTensorType getPaddedType (Attribute layoutAttr,
71+ RankedTensorType type) {
72+ PadEncodingLayoutAttr layout = getPadLayout (layoutAttr, type);
5573 if (!isNonZeroPadding (layout)) {
5674 return type.dropEncoding ();
5775 }
@@ -67,15 +85,11 @@ static RankedTensorType getPaddedType(RankedTensorType type) {
6785 return RankedTensorType::get (newShape, type.getElementType ());
6886}
6987
70- static bool hasNonZeroPadding (RankedTensorType type) {
71- return isNonZeroPadding (getPadLayout (type));
72- }
73-
7488struct MaterializePadEncodingTypeConverter final
7589 : MaterializeEncodingTypeConverter {
76- MaterializePadEncodingTypeConverter (MLIRContext *ctx)
77- : MaterializeEncodingTypeConverter(
78- IREE::Codegen::EncodingNopLayoutAttr::get (ctx) ) {
90+ MaterializePadEncodingTypeConverter (
91+ IREE::Codegen::LayoutAttrInterface layoutAttr)
92+ : MaterializeEncodingTypeConverter(layoutAttr ) {
7993 addConversion ([](RankedTensorType type) -> std::optional<RankedTensorType> {
8094 // The type converter is designed for `pad_encoding_layout` encoding
8195 // attribute. By the definition, the final converted type is the same
@@ -85,18 +99,23 @@ struct MaterializePadEncodingTypeConverter final
8599 addConversion ([&](IREE::Flow::DispatchTensorType dispatchTensorType)
86100 -> IREE::Flow::DispatchTensorType {
87101 auto type = dyn_cast<RankedTensorType>(dispatchTensorType.getBoundType ());
88- if (!type) {
102+ if (!type || !type. getEncoding () ) {
89103 return dispatchTensorType;
90104 }
91105 // The incoming bindings have the padded type, if `pad_encoding_layout` is
92106 // present.
93- if (getPadLayout (type)) {
94- type = getPaddedType (type);
107+ if (getPadLayout (getLayoutAttr (), type)) {
108+ type = getPaddedType (getLayoutAttr (), type);
95109 }
96110 return IREE::Flow::DispatchTensorType::get (dispatchTensorType.getAccess (),
97111 type);
98112 });
99113 }
114+
115+ bool hasNonZeroPadding (RankedTensorType type) const {
116+ PadEncodingLayoutAttr layout = getPadLayout (getLayoutAttr (), type);
117+ return layout && !layout.isIdentityLayout ();
118+ }
100119};
101120
102121// / Pattern to convert `flow.dispatch.tensor.load` operation when
@@ -116,15 +135,15 @@ struct MaterializeFlowDispatchTensorLoadOp final
116135 return rewriter.notifyMatchFailure (loadOp, " unhandled partial loads" );
117136 }
118137
138+ auto &typeConverter =
139+ *getTypeConverter<MaterializePadEncodingTypeConverter>();
119140 IREE::Flow::DispatchTensorType sourceType = loadOp.getSourceType ();
120141 auto boundTensorType = cast<RankedTensorType>(sourceType.getBoundType ());
121- if (!hasNonZeroPadding (boundTensorType)) {
142+ if (!typeConverter. hasNonZeroPadding (boundTensorType)) {
122143 // Let the Nop pattern handle this.
123144 return rewriter.notifyMatchFailure (loadOp, " no padding applied" );
124145 }
125146
126- auto &typeConverter =
127- *getTypeConverter<MaterializePadEncodingTypeConverter>();
128147 auto paddedType =
129148 typeConverter.convertType <RankedTensorType>(boundTensorType);
130149 assert (paddedType != boundTensorType && " Expected conversion with padding" );
@@ -171,15 +190,15 @@ struct MaterializeFlowDispatchTensorStoreOp final
171190 return rewriter.notifyMatchFailure (storeOp, " unhandled partial stores" );
172191 }
173192
193+ auto &typeConverter =
194+ *getTypeConverter<MaterializePadEncodingTypeConverter>();
174195 IREE::Flow::DispatchTensorType targetType = storeOp.getTargetType ();
175196 auto boundTensorType = cast<RankedTensorType>(targetType.getBoundType ());
176- if (!hasNonZeroPadding (boundTensorType)) {
197+ if (!typeConverter. hasNonZeroPadding (boundTensorType)) {
177198 // Let the Nop pattern handle this.
178199 return rewriter.notifyMatchFailure (storeOp, " no padding applied" );
179200 }
180201
181- auto &typeConverter =
182- *getTypeConverter<MaterializePadEncodingTypeConverter>();
183202 IREE::Flow::DispatchTensorType newTargetType =
184203 typeConverter.convertType <IREE::Flow::DispatchTensorType>(targetType);
185204 RankedTensorType paddedType = newTargetType.asRankedTensorType ();
@@ -245,8 +264,9 @@ struct MaterializeEncodingIntoPaddingPass final
245264 : impl::MaterializeEncodingIntoPaddingPassBase<
246265 MaterializeEncodingIntoPaddingPass> {
247266 void getDependentDialects (DialectRegistry ®istry) const override {
248- registry.insert <linalg::LinalgDialect, tensor::TensorDialect,
249- IREE::Codegen::IREECodegenDialect>();
267+ registry.insert <arith::ArithDialect, linalg::LinalgDialect,
268+ tensor::TensorDialect, IREE::Codegen::IREECodegenDialect,
269+ IREE::GPU::IREEGPUDialect>();
250270 }
251271
252272 void runOnOperation () override {
@@ -259,8 +279,43 @@ struct MaterializeEncodingIntoPaddingPass final
259279 return failure ();
260280 };
261281
282+ // Retrieve the config from executable target attribute, if any. Otherwise,
283+ // retrieve the config from CLI GPU target and construct a virtual
284+ // configuration.
285+ auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup (operation);
286+ DictionaryAttr targetConfig;
287+ if (targetAttr) {
288+ targetConfig = targetAttr.getConfiguration ();
289+ } else {
290+ IREE::GPU::TargetAttr gpuTargetAttr = getCLGPUTarget (context);
291+ SmallVector<NamedAttribute> items;
292+ items.emplace_back (
293+ IREE::Encoding::kEncodingResolverAttrName ,
294+ IREE::GPU::getHIPTargetEncodingLayoutAttr (gpuTargetAttr, " pad" ));
295+ targetConfig = DictionaryAttr::get (context, items);
296+ }
297+
298+ // The layoutAttr should come in without any target info attached to it,
299+ // so we need to clone the layout attrs with the configuration so it can
300+ // access the target info during materialization.
301+ //
302+ // Otherwise, fall back to the nop layout.
303+ IREE::Codegen::LayoutAttrInterface layoutAttr;
304+ if (targetConfig &&
305+ targetConfig.contains (IREE::Encoding::kEncodingResolverAttrName )) {
306+ layoutAttr = targetConfig.getAs <IREE::Codegen::LayoutAttrInterface>(
307+ IREE::Encoding::kEncodingResolverAttrName );
308+ auto resolverAttr =
309+ cast<IREE::Encoding::EncodingLayoutResolverAttrInterface>(layoutAttr);
310+ layoutAttr = cast<IREE::Codegen::LayoutAttrInterface>(
311+ resolverAttr.cloneWithSimplifiedConfig (targetConfig));
312+ } else {
313+ layoutAttr = cast<IREE::Codegen::LayoutAttrInterface>(
314+ IREE::Codegen::EncodingNopLayoutAttr::get (context));
315+ }
316+
262317 RewritePatternSet materializeEncodingPattern (context);
263- MaterializePadEncodingTypeConverter typeConverter (context );
318+ MaterializePadEncodingTypeConverter typeConverter (layoutAttr );
264319 MaterializeEncodingConversionTarget target (*context);
265320 populateMaterializeEncodingPatterns (materializeEncodingPattern, target,
266321 typeConverter,
0 commit comments