Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,16 @@ getPackedSizes(linalg::LinalgOp linalgOp, RewriterBase &rewriter,
FailureOr<linalg::ContractionDimensions> contractionDims =
linalg::inferContractionDims(linalgOp);
if (succeeded(contractionDims)) {
auto [m, n, k] = mmaKind.getMNKShape();
indices = {contractionDims->m, contractionDims->n, contractionDims->k};
dims = {m, n, k};
if (mmaKind.isBlockIntrinsic()) {
auto [b, m, n, k] = mmaKind.getBMNKShape();
indices = {contractionDims->batch, contractionDims->m,
contractionDims->n, contractionDims->k};
dims = {b, m, n, k};
} else {
auto [m, n, k] = mmaKind.getMNKShape();
indices = {contractionDims->m, contractionDims->n, contractionDims->k};
dims = {m, n, k};
}
}
}

Expand Down
283 changes: 265 additions & 18 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ def IREEGPU_MMAAttr : AttrDef<IREEGPU_Dialect, "MMA", [
int64_t getBlockSize() const;

SmallVector<VirtualMMAIntrinsic> getVirtualIntrinsics() const;

bool isBlockIntrinsic() const;
}];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,22 @@ def MFMA_F32_32x32x4_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x4_BF16", 0x1121>;
def MFMA_F64_16x16x4_F64 : I32EnumAttrCase<"MFMA_F64_16x16x4_F64", 0x1100>;

// Introduced in CDNA3.
def MFMA_F32_16x16x16_BF16 : I32EnumAttrCase<"MFMA_F32_16x16x16_BF16", 0x1220>;
def MFMA_F32_32x32x8_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x8_BF16", 0x1221>;
// Block intrinsics (multiple independent MMA operations sharing a subgroup).
def MFMA_F32_4x4x4x16B_F16 : I32EnumAttrCase<"MFMA_F32_4x4x4x16B_F16", 0x1220>;
def MFMA_F32_4x4x4x16B_BF16 : I32EnumAttrCase<"MFMA_F32_4x4x4x16B_BF16", 0x1221>;
def MFMA_F32_16x16x4x4B_F16 : I32EnumAttrCase<"MFMA_F32_16x16x4x4B_F16", 0x1224>;
def MFMA_F32_16x16x4x4B_BF16 : I32EnumAttrCase<"MFMA_F32_16x16x4x4B_BF16", 0x1225>;
def MFMA_F32_32x32x4x2B_F16 : I32EnumAttrCase<"MFMA_F32_32x32x4x2B_F16", 0x1226>;
def MFMA_F32_32x32x4x2B_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x4x2B_BF16", 0x1227>;
def MFMA_I32_16x16x4x4B_I8 : I32EnumAttrCase<"MFMA_I32_16x16x4x4B_I8", 0x12C2>;
def MFMA_I32_32x32x4x2B_I8 : I32EnumAttrCase<"MFMA_I32_32x32x4x2B_I8", 0x12C3>;
def MFMA_I32_4x4x4x16B_I8 : I32EnumAttrCase<"MFMA_I32_4x4x4x16B_I8", 0x12C4>;
def MFMA_F32_4x4x1x16B_F32 : I32EnumAttrCase<"MFMA_F32_4x4x1x16B_F32", 0x1214>;
def MFMA_F32_16x16x1x4B_F32 : I32EnumAttrCase<"MFMA_F32_16x16x1x4B_F32", 0x1215>;
def MFMA_F32_32x32x1x2B_F32 : I32EnumAttrCase<"MFMA_F32_32x32x1x2B_F32", 0x1216>;
def MFMA_F64_4x4x4x4B_F64 : I32EnumAttrCase<"MFMA_F64_4x4x4x4B_F64", 0x1201>;
def MFMA_F32_16x16x16_BF16 : I32EnumAttrCase<"MFMA_F32_16x16x16_BF16", 0x1222>;
def MFMA_F32_32x32x8_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x8_BF16", 0x1223>;
def MFMA_F32_16x16x32_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ", 0x1230>;
def MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ", 0x1231>;
def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 0x1232>;
Expand Down Expand Up @@ -278,6 +292,19 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32EnumAttr<"MMAIntrinsic",
MFMA_F64_16x16x4_F64,

// Introduced in CDNA3.
MFMA_F32_4x4x4x16B_F16,
MFMA_F32_4x4x4x16B_BF16,
MFMA_F32_16x16x4x4B_F16,
MFMA_F32_16x16x4x4B_BF16,
MFMA_F32_32x32x4x2B_F16,
MFMA_F32_32x32x4x2B_BF16,
MFMA_I32_16x16x4x4B_I8,
MFMA_I32_32x32x4x2B_I8,
MFMA_I32_4x4x4x16B_I8,
MFMA_F32_4x4x1x16B_F32,
MFMA_F32_16x16x1x4B_F32,
MFMA_F32_32x32x1x2B_F32,
MFMA_F64_4x4x4x4B_F64,
MFMA_F32_16x16x16_BF16,
MFMA_F32_32x32x8_BF16,
MFMA_F32_16x16x32_F8E5M2FNUZ,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,35 @@ def IREEGPU_MmaInterfaceAttr
/*methodName=*/"getSubgroupSize",
/*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/[{
Return true if this is a block (batched) intrinsic whose operands have
a leading batch/block dimension, i.e. A: <B, M, K>, B: <B, K, N>,
C: <B, M, N>. Defaults to false for non-block intrinsics.
}],
/*retType=*/"bool",
/*methodName=*/"isBlockIntrinsic",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"return false;"
>,
];

// Methods included to preserve interfaces from before the InnerTileDescAttrInterface
// refactoring.
let extraSharedClassDeclaration = [{
/// Returns the (B, M, N, K) shape of a block MMA operation where A has
/// shape <B, M, K>, B has shape <B, K, N>, and C has shape <B, M, N>.
/// Only valid to call when isBlockIntrinsic() returns true.
::std::tuple<int64_t, int64_t, int64_t, int64_t>
getBMNKShape() const {
::llvm::SmallVector<::mlir::VectorType> preThreadTypes;
$_attr.getUndistributedTileTypes(preThreadTypes);
::llvm::ArrayRef<int64_t> accShape = preThreadTypes[2].getShape();
::llvm::ArrayRef<int64_t> lhsShape = preThreadTypes[0].getShape();
return {accShape[0], accShape[1], accShape[2], lhsShape[2]};
}

/// Returns the shape of the MMA operation:
/// ```
/// C += A * B
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,7 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch,
const WgpDetails *getCDNA4WgpDetails() {
static const MMAIntrinsic cdna4MMAOps[] = {
// Introduced in CDNA4
MMAIntrinsic::MFMA_F32_16x16x32_F16,
MMAIntrinsic::MFMA_F32_32x32x16_F16,
MMAIntrinsic::MFMA_F32_16x16x32_F16, MMAIntrinsic::MFMA_F32_32x32x16_F16,
MMAIntrinsic::MFMA_F32_16x16x32_BF16,
MMAIntrinsic::MFMA_F32_32x32x16_BF16,
MMAIntrinsic::MFMA_F32_16x16x128_F8E5M2,
Expand All @@ -207,11 +206,9 @@ const WgpDetails *getCDNA4WgpDetails() {
MMAIntrinsic::MFMA_F32_32x32x64_F8E5M2_F8E4M3FN,
MMAIntrinsic::MFMA_F32_32x32x64_F8E4M3FN,
MMAIntrinsic::MFMA_F32_32x32x64_F8E4M3FN_F8E5M2,
MMAIntrinsic::MFMA_I32_16x16x64_I8,
MMAIntrinsic::MFMA_I32_32x32x32_I8,
MMAIntrinsic::MFMA_I32_16x16x64_I8, MMAIntrinsic::MFMA_I32_32x32x32_I8,
// Introduced in CDNA3
MMAIntrinsic::MFMA_F32_16x16x16_BF16,
MMAIntrinsic::MFMA_F32_32x32x8_BF16,
MMAIntrinsic::MFMA_F32_16x16x16_BF16, MMAIntrinsic::MFMA_F32_32x32x8_BF16,
// Note: use same instructions as in CDNA3 but different types
MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2,
MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2_F8E4M3FN,
Expand All @@ -221,14 +218,27 @@ const WgpDetails *getCDNA4WgpDetails() {
MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2_F8E4M3FN,
MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FN,
MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FN_F8E5M2,
MMAIntrinsic::MFMA_I32_16x16x32_I8,
MMAIntrinsic::MFMA_I32_32x32x16_I8,
MMAIntrinsic::MFMA_I32_16x16x32_I8, MMAIntrinsic::MFMA_I32_32x32x16_I8,
// Introduced in CDNA2, still present in CDNA3
MMAIntrinsic::MFMA_F64_16x16x4_F64,
// Introduced in CDNA1, still present in CDNA3
MMAIntrinsic::MFMA_F32_16x16x4_F32,
MMAIntrinsic::MFMA_F32_16x16x16_F16,
MMAIntrinsic::MFMA_F32_16x16x4_F32, MMAIntrinsic::MFMA_F32_16x16x16_F16,
MMAIntrinsic::MFMA_F32_32x32x8_F16,
// Block intrinsics - commented out because the heuristic
// does not yet handle them correctly. They can
// still be used via an explicit compilation_info.
// (F16)
// MMAIntrinsic::MFMA_F32_4x4x4x16B_F16,
// MMAIntrinsic::MFMA_F32_16x16x4x4B_F16,
// MMAIntrinsic::MFMA_F32_32x32x4x2B_F16,
// (BF16)
// MMAIntrinsic::MFMA_F32_4x4x4x16B_BF16,
// MMAIntrinsic::MFMA_F32_16x16x4x4B_BF16,
// MMAIntrinsic::MFMA_F32_32x32x4x2B_BF16,
// (I8)
// MMAIntrinsic::MFMA_I32_4x4x4x16B_I8,
// MMAIntrinsic::MFMA_I32_16x16x4x4B_I8,
// MMAIntrinsic::MFMA_I32_32x32x4x2B_I8,
};
static const ScaledMMAIntrinsic cdna4ScaledMMAOps[] = {
// Introduced in CDNA4
Expand Down Expand Up @@ -263,8 +273,7 @@ const WgpDetails *getCDNA3WgpDetails() {
// Note: these operations are listed in order of preference.
static const MMAIntrinsic cdna3MMAOps[] = {
// Introduced in CDNA3
MMAIntrinsic::MFMA_F32_16x16x16_BF16,
MMAIntrinsic::MFMA_F32_32x32x8_BF16,
MMAIntrinsic::MFMA_F32_16x16x16_BF16, MMAIntrinsic::MFMA_F32_32x32x8_BF16,
MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ,
MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ,
MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ,
Expand All @@ -273,14 +282,27 @@ const WgpDetails *getCDNA3WgpDetails() {
MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ,
MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ,
MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ,
MMAIntrinsic::MFMA_I32_16x16x32_I8,
MMAIntrinsic::MFMA_I32_32x32x16_I8,
MMAIntrinsic::MFMA_I32_16x16x32_I8, MMAIntrinsic::MFMA_I32_32x32x16_I8,
// Introduced in CDNA2, still present in CDNA3
MMAIntrinsic::MFMA_F64_16x16x4_F64,
// Introduced in CDNA1, still present in CDNA3
MMAIntrinsic::MFMA_F32_16x16x4_F32,
MMAIntrinsic::MFMA_F32_16x16x16_F16,
MMAIntrinsic::MFMA_F32_16x16x4_F32, MMAIntrinsic::MFMA_F32_16x16x16_F16,
MMAIntrinsic::MFMA_F32_32x32x8_F16,
// Block intrinsics - commented out because the heuristic
// does not yet handle them correctly. They can
// still be used via an explicit compilation_info.
// (F16)
// MMAIntrinsic::MFMA_F32_4x4x4x16B_F16,
// MMAIntrinsic::MFMA_F32_16x16x4x4B_F16,
// MMAIntrinsic::MFMA_F32_32x32x4x2B_F16,
// (BF16)
// MMAIntrinsic::MFMA_F32_4x4x4x16B_BF16,
// MMAIntrinsic::MFMA_F32_16x16x4x4B_BF16,
// MMAIntrinsic::MFMA_F32_32x32x4x2B_BF16,
// (I8)
// MMAIntrinsic::MFMA_I32_4x4x4x16B_I8,
// MMAIntrinsic::MFMA_I32_16x16x4x4B_I8,
// MMAIntrinsic::MFMA_I32_32x32x4x2B_I8,
};
static const int64_t cdna3DMASizes[] = {32};
static const WgpDetails cdna3Wgp = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1346,15 +1346,17 @@ FailureOr<IREE::Codegen::InnerTiledOp> convertContractionToInnerTiledMma(
return failure();
}

bool isBlock = mmaKind.isBlockIntrinsic();
if (isBlock && contractionDims.batch.empty()) {
return failure();
}

MLIRContext *context = rewriter.getContext();

int64_t innerM = contractionDims.m.back();
int64_t innerN = contractionDims.n.back();
int64_t innerK = contractionDims.k.back();

AffineExpr d0, d1, d2;
bindDims(context, d0, d1, d2);
llvm::SmallDenseMap<AffineExpr, AffineExpr> newDims;
AffineExpr mExpr = rewriter.getAffineDimExpr(innerM);
AffineExpr nExpr = rewriter.getAffineDimExpr(innerN);
AffineExpr kExpr = rewriter.getAffineDimExpr(innerK);
Expand All @@ -1381,24 +1383,49 @@ FailureOr<IREE::Codegen::InnerTiledOp> convertContractionToInnerTiledMma(
return permutation;
};

// TODO: Enable batched intrinsics and get the appropriate sub-map here.
SmallVector<int64_t> lhsInnerPerm =
getNormalizedPermutation(lhsMap.getMinorSubMap(2), {mExpr, kExpr});
SmallVector<int64_t> rhsInnerPerm =
getNormalizedPermutation(rhsMap.getMinorSubMap(2), {kExpr, nExpr});
SmallVector<int64_t> accInnerPerm =
getNormalizedPermutation(accMap.getMinorSubMap(2), {mExpr, nExpr});

if (lhsInnerPerm.empty() || rhsInnerPerm.empty() || accInnerPerm.empty()) {
return failure();
}

SmallVector<int64_t> lhsInnerPerm, rhsInnerPerm, accInnerPerm;
SmallVector<int64_t> bounds = linalgOp.getStaticLoopRanges();

auto [intrinsicM, intrinsicN, intrinsicK] = mmaKind.getMNKShape();
if (intrinsicM != bounds[innerM] || intrinsicN != bounds[innerN] ||
intrinsicK != bounds[innerK]) {
return failure();
int64_t numDims = lhsMap.getNumDims();
llvm::SmallDenseSet<int64_t> droppedDims;
int64_t numInnerDims;

if (isBlock) {
int64_t innerB = contractionDims.batch.back();
AffineExpr bExpr = rewriter.getAffineDimExpr(innerB);
lhsInnerPerm = getNormalizedPermutation(lhsMap.getMinorSubMap(3),
{bExpr, mExpr, kExpr});
rhsInnerPerm = getNormalizedPermutation(rhsMap.getMinorSubMap(3),
{bExpr, kExpr, nExpr});
accInnerPerm = getNormalizedPermutation(accMap.getMinorSubMap(3),
{bExpr, mExpr, nExpr});
if (lhsInnerPerm.empty() || rhsInnerPerm.empty() || accInnerPerm.empty()) {
return failure();
}
auto [intrinsicB, intrinsicM, intrinsicN, intrinsicK] =
mmaKind.getBMNKShape();
if (intrinsicB != bounds[innerB] || intrinsicM != bounds[innerM] ||
intrinsicN != bounds[innerN] || intrinsicK != bounds[innerK]) {
return failure();
}
droppedDims = {innerB, innerM, innerN, innerK};
numInnerDims = 4;
} else {
lhsInnerPerm =
getNormalizedPermutation(lhsMap.getMinorSubMap(2), {mExpr, kExpr});
rhsInnerPerm =
getNormalizedPermutation(rhsMap.getMinorSubMap(2), {kExpr, nExpr});
accInnerPerm =
getNormalizedPermutation(accMap.getMinorSubMap(2), {mExpr, nExpr});
if (lhsInnerPerm.empty() || rhsInnerPerm.empty() || accInnerPerm.empty()) {
return failure();
}
auto [intrinsicM, intrinsicN, intrinsicK] = mmaKind.getMNKShape();
if (intrinsicM != bounds[innerM] || intrinsicN != bounds[innerN] ||
intrinsicK != bounds[innerK]) {
return failure();
}
droppedDims = {innerM, innerN, innerK};
numInnerDims = 3;
}

SmallVector<Value> inputs = linalgOp->getOperands();
Expand All @@ -1415,10 +1442,8 @@ FailureOr<IREE::Codegen::InnerTiledOp> convertContractionToInnerTiledMma(

SmallVector<utils::IteratorType> linalgIteratorTypes =
linalgOp.getIteratorTypesArray();
llvm::SmallDenseSet<int64_t> droppedDims = {innerM, innerN, innerK};
llvm::SmallDenseMap<int64_t, int64_t> oldDimsToNewDimsMap;
int64_t currentDim = 0;
int64_t numDims = lhsMap.getNumDims();
SmallVector<utils::IteratorType> iteratorTypes;
for (int64_t dim = 0, e = numDims; dim < e; ++dim) {
if (droppedDims.contains(dim)) {
Expand All @@ -1429,16 +1454,17 @@ FailureOr<IREE::Codegen::InnerTiledOp> convertContractionToInnerTiledMma(
}

AffineMap outerLhsMap =
dropDims(context, numDims - 3, lhsMap, oldDimsToNewDimsMap);
dropDims(context, numDims - numInnerDims, lhsMap, oldDimsToNewDimsMap);
AffineMap outerRhsMap =
dropDims(context, numDims - 3, rhsMap, oldDimsToNewDimsMap);
dropDims(context, numDims - numInnerDims, rhsMap, oldDimsToNewDimsMap);
AffineMap outerAccMap =
dropDims(context, numDims - 3, accMap, oldDimsToNewDimsMap);
dropDims(context, numDims - numInnerDims, accMap, oldDimsToNewDimsMap);

std::optional<SmallVector<SmallVector<int64_t>>> perms =
SmallVector<SmallVector<int64_t>>{lhsInnerPerm, rhsInnerPerm,
accInnerPerm};
SmallVector<int64_t> identityPerm = {0, 1};
SmallVector<int64_t> identityPerm =
isBlock ? SmallVector<int64_t>{0, 1, 2} : SmallVector<int64_t>{0, 1};

if (lhsInnerPerm == identityPerm && rhsInnerPerm == identityPerm &&
accInnerPerm == identityPerm) {
Expand Down
Loading
Loading