diff --git a/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp b/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp index 42b75e76d177..27df302eea5c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp @@ -164,7 +164,7 @@ verifyFlatContiguousSwizzleHintOp(IREE::Codegen::SwizzleHintOp hintOp) { auto memrefType = cast<MemRefType>(hintOp.getOperand().getType()); // Swizzle hints require flat (rank 1) memrefs. // For rank 1, allow dynamic memrefs or static contiguous row-major memrefs. - if ((memrefType.getRank() != 1 || !memrefType.getLayout().isIdentity()) || + if (memrefType.getRank() != 1 || (memrefType.hasStaticShape() && !memref::isStaticShapeAndContiguousRowMajor(memrefType))) { hintOp.emitError() diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp index 9d174e0b8448..4b6648b25798 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp @@ -913,8 +913,33 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize( SmallVector<Attribute> promotionArray; auto defaultConfigAttr = IREE::GPU::DerivedThreadConfigAttr::get(context); Attribute useGlobalDma = IREE::GPU::UseGlobalLoadDMAAttr::get(context); - if (useDirectLoad && !scaled) { - promotionArray = {useGlobalDma, useGlobalDma}; + if (!scaled && useDirectLoad) { + Attribute lhsAttr = useGlobalDma; + Attribute rhsAttr = useGlobalDma; + // Apply XOR swizzle for BF16 DMA operands whose reduction dim is + // innermost (contiguous reads) to avoid LDS bank conflicts. 
+ SmallVector<Type> elemTypes; + kind.getElementTypes(elemTypes); + bool isBF16 = !elemTypes.empty() && elemTypes[0].isBF16(); + if (isBF16) { + if (!transposedLhs) { + FailureOr<Attribute> lhsSwizzleAttr = + getXorShuffleAttr(context, useGlobalDma, target, kind, + schedule->kTileSizes, kMMAOperandLhs); + if (succeeded(lhsSwizzleAttr)) { + lhsAttr = *lhsSwizzleAttr; + } + } + if (transposedRhs) { + FailureOr<Attribute> rhsSwizzleAttr = + getXorShuffleAttr(context, useGlobalDma, target, kind, + schedule->kTileSizes, kMMAOperandRhs); + if (succeeded(rhsSwizzleAttr)) { + rhsAttr = *rhsSwizzleAttr; + } + } + } + promotionArray = {lhsAttr, rhsAttr}; } SmallVector<int64_t> promotionList = {0, 1}; if (scaled) { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir index 2d61c7a484be..d40052583460 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir @@ -386,6 +386,38 @@ func.func @scaled_matmul_accumulate( // ----- +// BF16 matmul with direct-load DMA gets LHS XOR swizzle for bank conflict +// avoidance. Only LHS is swizzled (reduction dim is innermost for A[M,K]). 
+func.func @matmul_bf16( + %arg0: tensor<4096x4096xbf16>, + %arg1: tensor<4096x4096xbf16>, + %arg2: tensor<4096x4096xf32>) -> tensor<4096x4096xf32> { + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<4096x4096xbf16>, tensor<4096x4096xbf16>) + outs(%arg2 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32> + return %0 : tensor<4096x4096xf32> +} +// CHECK-LABEL: func.func @matmul_bf16 +// CHECK: lowering_config = #iree_gpu.lowering_config +// CHECK-SAME: mma_kind = #iree_gpu.mma_layout + +// CHECK-REMARKS: [Analysis] SharedMemoryUsage +// CHECK-REMARKS-SAME: Category:deduceMMASchedule +// CHECK-REMARKS-SAME: Remark=16384 + +// CHECK-REMARKS-DIRECT-LOAD-2: [Analysis] SharedMemoryUsage +// CHECK-REMARKS-DIRECT-LOAD-2-SAME: Category:deduceMMASchedule +// CHECK-REMARKS-DIRECT-LOAD-2-SAME: Remark=32768 + +// CHECK-REMARKS-DIRECT-LOAD-3: [Analysis] SharedMemoryUsage +// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Category:deduceMMASchedule +// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Remark=49152 + +// CHECK-DIRECT-LOAD-LABEL: func.func @matmul_bf16 +// CHECK-DIRECT-LOAD: linalg.matmul {lowering_config = #iree_gpu.lowering_config +// CHECK-DIRECT-LOAD-SAME: promotion_types = [#iree_gpu.swizzle_operand, #iree_gpu.use_global_load_dma] + +// ----- + // Very large f16 matmul — compute-bound, so picks 32x32x16 (higher compute per // instruction, lower VGPR pressure than 16x16x32). func.func @matmul_f16_compute_bound( diff --git a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp index 84dd7deb099d..7b9efb70e12f 100644 --- a/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp @@ -810,7 +810,16 @@ static FailureOr<XorShuffleParams> getXorShuffleParamsForGfx950( return XorShuffleParams({/*rowElems=*/256, /*accessElems=*/32}); default: - // TODO(muzasyed): Add more intrinsics for gfx950. 
+ return failure(); + } + if (auto mma = dyn_cast<IREE::GPU::MMAAttr>(intrinsic)) { + switch (mma.getIntrinsic()) { + case IREE::GPU::MMAIntrinsic::MFMA_F32_16x16x32_BF16: + case IREE::GPU::MMAIntrinsic::MFMA_F32_32x32x16_BF16: + return XorShuffleParams({/*rowElems=*/128, + /*accessElems=*/8}); + default: return failure(); } }