Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ verifyFlatContiguousSwizzleHintOp(IREE::Codegen::SwizzleHintOp hintOp) {
auto memrefType = cast<MemRefType>(hintOp.getOperand().getType());
// Swizzle hints require flat (rank 1) memrefs.
// For rank 1, allow dynamic memrefs or static contiguous row-major memrefs.
if ((memrefType.getRank() != 1 || !memrefType.getLayout().isIdentity()) ||
if (memrefType.getRank() != 1 ||
(memrefType.hasStaticShape() &&
!memref::isStaticShapeAndContiguousRowMajor(memrefType))) {
hintOp.emitError()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -913,8 +913,33 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
SmallVector<Attribute> promotionArray;
auto defaultConfigAttr = IREE::GPU::DerivedThreadConfigAttr::get(context);
Attribute useGlobalDma = IREE::GPU::UseGlobalLoadDMAAttr::get(context);
if (useDirectLoad && !scaled) {
promotionArray = {useGlobalDma, useGlobalDma};
if (!scaled && useDirectLoad) {
Attribute lhsAttr = useGlobalDma;
Attribute rhsAttr = useGlobalDma;
// Apply XOR swizzle for BF16 DMA operands whose reduction dim is
// innermost (contiguous reads) to avoid LDS bank conflicts.
SmallVector<Type> elemTypes;
kind.getElementTypes(elemTypes);
bool isBF16 = !elemTypes.empty() && elemTypes[0].isBF16();
if (isBF16) {
if (!transposedLhs) {
FailureOr<Attribute> lhsSwizzleAttr =
getXorShuffleAttr(context, useGlobalDma, target, kind,
schedule->kTileSizes, kMMAOperandLhs);
if (succeeded(lhsSwizzleAttr)) {
lhsAttr = *lhsSwizzleAttr;
}
}
if (transposedRhs) {
FailureOr<Attribute> rhsSwizzleAttr =
getXorShuffleAttr(context, useGlobalDma, target, kind,
schedule->kTileSizes, kMMAOperandRhs);
if (succeeded(rhsSwizzleAttr)) {
rhsAttr = *rhsSwizzleAttr;
}
}
}
promotionArray = {lhsAttr, rhsAttr};
}
SmallVector<int64_t> promotionList = {0, 1};
if (scaled) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,38 @@ func.func @scaled_matmul_accumulate(

// -----

// BF16 matmul with direct-load DMA gets LHS XOR swizzle for bank conflict
// avoidance. Only LHS is swizzled (reduction dim is innermost for A[M,K]).
func.func @matmul_bf16(
%arg0: tensor<4096x4096xbf16>,
%arg1: tensor<4096x4096xbf16>,
%arg2: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> {
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<4096x4096xbf16>, tensor<4096x4096xbf16>)
outs(%arg2 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
return %0 : tensor<4096x4096xf32>
}
// CHECK-LABEL: func.func @matmul_bf16
// CHECK: lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_BF16>

// CHECK-REMARKS: [Analysis] SharedMemoryUsage
// CHECK-REMARKS-SAME: Category:deduceMMASchedule
// CHECK-REMARKS-SAME: Remark=16384

// CHECK-REMARKS-DIRECT-LOAD-2: [Analysis] SharedMemoryUsage
// CHECK-REMARKS-DIRECT-LOAD-2-SAME: Category:deduceMMASchedule
// CHECK-REMARKS-DIRECT-LOAD-2-SAME: Remark=32768

// CHECK-REMARKS-DIRECT-LOAD-3: [Analysis] SharedMemoryUsage
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Category:deduceMMASchedule
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Remark=49152

// CHECK-DIRECT-LOAD-LABEL: func.func @matmul_bf16
// CHECK-DIRECT-LOAD: linalg.matmul {lowering_config = #iree_gpu.lowering_config
// CHECK-DIRECT-LOAD-SAME: promotion_types = [#iree_gpu.swizzle_operand<copy_config = #iree_gpu.use_global_load_dma, swizzle = #iree_codegen.xor_shuffle<128, 8>>, #iree_gpu.use_global_load_dma]

// -----

// Very large f16 matmul — compute-bound, so picks 32x32x16 (higher compute per
// instruction, lower VGPR pressure than 16x16x32).
func.func @matmul_f16_compute_bound(
Expand Down
11 changes: 10 additions & 1 deletion compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,16 @@ static FailureOr<XorShuffleParams> getXorShuffleParamsForGfx950(
return XorShuffleParams({/*rowElems=*/256,
/*accessElems=*/32});
default:
// TODO(muzasyed): Add more intrinsics for gfx950.
return failure();
}
}
if (auto mma = dyn_cast<IREE::GPU::MMAAttr>(intrinsic)) {
switch (mma.getIntrinsic()) {
case IREE::GPU::MMAIntrinsic::MFMA_F32_16x16x32_BF16:
case IREE::GPU::MMAIntrinsic::MFMA_F32_32x32x16_BF16:
return XorShuffleParams({/*rowElems=*/128,
/*accessElems=*/8});
default:
return failure();
}
}
Expand Down
Loading