diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp
index 077748420b30..9cd21fc2ea5f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp
@@ -12,6 +12,8 @@
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
+#include "iree/compiler/Utils/Permutation.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -155,6 +157,10 @@ static bool sourceIsFromFatRawBuffer(Value source) {
       source = pad.getSource();
       continue;
     }
+    if (auto collapseOp = source.getDefiningOp<tensor::CollapseShapeOp>()) {
+      source = collapseOp.getSrc();
+      continue;
+    }
     break;
   }
 
@@ -846,6 +852,388 @@ struct ConvertGatherToCoalescedDMA
   }
 };
 
+/// Check if an im2col op is viable for conversion to gather + DMA.
+/// Validates v1 constraints: identity perms, single K window,
+/// channel-aligned k_off, DMA-aligned contiguous size, static shapes.
+static bool isIm2colDMAConvertible(IREE::LinalgExt::Im2colOp im2colOp) {
+  auto funcOp = im2colOp->getParentOfType<FunctionOpInterface>();
+  if (!funcOp) {
+    return false;
+  }
+
+  IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
+  if (!target || !targetSupportsGlobalLoadDMA(target)) {
+    return false;
+  }
+
+  // Note: we do NOT check sourceIsFromFatRawBuffer here because at this
+  // pipeline stage (before bufferization), the im2col input comes from
+  // dispatch.tensor.load, not from LoadFromBufferOp/fat_raw_buffer.
+  // The fat_raw_buffer cast happens during bufferization. The gather DMA
+  // conversion (ConvertGatherToCoalescedDMA) checks the source later when
+  // it matters.
+
+  // v1: identity output_perm and input_k_perm only.
+  if (!isIdentityPermutation(im2colOp.getOutputPerm()) ||
+      !isIdentityPermutation(im2colOp.getInputKPerm())) {
+    return false;
+  }
+
+  // getVectorizableDim enforces willBeContiguousSlice (single-window K_tile).
+  OpBuilder b(im2colOp);
+  Location loc = im2colOp.getLoc();
+  std::optional<int64_t> vecDim = im2colOp.getVectorizableDim(b, loc);
+  if (!vecDim.has_value()) {
+    return false;
+  }
+
+  // v1: all output shapes must be static.
+  auto outputType = cast<ShapedType>(im2colOp.getOutputType());
+  if (!outputType.hasStaticShape()) {
+    return false;
+  }
+
+  int64_t contiguousSize = outputType.getShape()[*vecDim];
+
+  // v1: k_off must be channel-aligned (k_off % C == 0).
+  auto inputType = cast<ShapedType>(im2colOp.getInputType());
+  ArrayRef<int64_t> kPos = im2colOp.getKPos();
+  int64_t cDim = kPos.back();
+  int64_t C = inputType.getShape()[cDim];
+  if (ShapedType::isDynamic(C)) {
+    return false;
+  }
+
+  SmallVector<OpFoldResult> mixedOffsets = im2colOp.getMixedOffsets();
+  int64_t numBatchDims = im2colOp.getBatchPos().size();
+  int64_t numMDims = im2colOp.getNumMOutputDims();
+  int64_t kCanonicalIdx = numBatchDims + numMDims;
+
+  if (kCanonicalIdx < static_cast<int64_t>(mixedOffsets.size())) {
+    OpFoldResult kOff = mixedOffsets[kCanonicalIdx];
+    if (auto constVal = getConstantIntValue(kOff)) {
+      if (*constVal % C != 0) {
+        return false;
+      }
+    } else {
+      // Dynamic k_off: if contiguousSize <= C, chooseDimToVectorize already
+      // validated alignment. Otherwise reject.
+      if (contiguousSize > C) {
+        return false;
+      }
+    }
+  }
+
+  // DMA alignment check.
+  return getDMAAlignedSubgroupSize(funcOp, outputType.getElementType(),
+                                   contiguousSize)
+      .has_value();
+}
+
+/// Build a 1D tensor where each element is the linearized
+/// spatial offset in the collapsed source for that batch position.
+///
+/// For batch position i:
+///   (b, m) = delinearize(i, [batch_tile, M_tile])
+///   (oh, ow, ...) = delinearize(m_off + m, output_sizes_M)
+///   (kh, kw, ...)
= delinearize(k_off / C, window_sizes) +/// spatial[j] = m_coord[j] * stride[j] + window[j] * dilation[j] +/// n = batch_off + b +/// lin = n * dim[0] * dim[1] * ... + spatial[0] * dim[1] * ... + ... +static Value buildIm2colIndexTensor(PatternRewriter &rewriter, Location loc, + IREE::LinalgExt::Im2colOp im2colOp, + int64_t batchSize) { + using namespace IREE::LinalgExt; + + auto inputType = cast(im2colOp.getInputType()); + auto outputType = cast(im2colOp.getOutputType()); + int64_t inputRank = inputType.getRank(); + ArrayRef inputShape = inputType.getShape(); + + ArrayRef strides = im2colOp.getStrides(); + ArrayRef dilations = im2colOp.getDilations(); + ArrayRef batchPos = im2colOp.getBatchPos(); + ArrayRef mPos = im2colOp.getMPos(); + ArrayRef kPos = im2colOp.getKPos(); + + SmallVector mixedOffsets = im2colOp.getMixedOffsets(); + SmallVector> mixedOutputSizes = + im2colOp.getMixedOutputSizes(); + + int64_t numBatchDims = batchPos.size(); + int64_t numMDims = im2colOp.getNumMOutputDims(); + + SmallVector batchOutputDims = im2colOp.getBatchOutputDims(); + SmallVector mOutputDims = im2colOp.getMOutputDims(); + ArrayRef outputShape = outputType.getShape(); + int64_t batchTile = 1; + for (int64_t d : batchOutputDims) { + batchTile *= outputShape[d]; + } + int64_t mTile = 1; + for (int64_t d : mOutputDims) { + mTile *= outputShape[d]; + } + + // Create tensor.empty for the index tensor. + // Use index type so that after bufferization + DMA lowering, the loaded + // index values are directly usable as gather_to_lds source indices. + Type indexType = rewriter.getIndexType(); + Value emptyTensor = tensor::EmptyOp::create( + rewriter, loc, ArrayRef{batchSize}, indexType); + + // Build linalg.generic with a single parallel iterator. 
+ AffineMap outputMap = rewriter.getMultiDimIdentityMap(1); + SmallVector iterTypes = {utils::IteratorType::parallel}; + + auto genericOp = linalg::GenericOp::create( + rewriter, loc, emptyTensor.getType(), /*inputs=*/ValueRange{}, + /*outputs=*/ValueRange{emptyTensor}, + /*indexingMaps=*/ArrayRef{outputMap}, iterTypes, + [&](OpBuilder &b, Location nestedLoc, ValueRange args) { + // Get the flat iteration index. + Value idx = linalg::IndexOp::create(b, nestedLoc, 0); + + // Delinearize idx into (batchIdx, mIdx). + SmallVector batchMBasis = {b.getIndexAttr(batchTile), + b.getIndexAttr(mTile)}; + auto delinBM = affine::AffineDelinearizeIndexOp::create( + b, nestedLoc, idx, batchMBasis, /*hasOuterBound=*/true); + Value batchIdx = delinBM.getResult(0); + Value mIdx = delinBM.getResult(1); + + // Compute batch offset: n = batch_off + batchIdx. + // For each batch dim, delinearize the batch index using output_sizes. + // With identity output_perm, canonical batch dims map directly. + SmallVector batchCoords; + if (numBatchDims == 1) { + OpFoldResult batchOff = mixedOffsets[0]; + Value batchOffVal = + getValueOrCreateConstantIndexOp(b, nestedLoc, batchOff); + Value n = arith::AddIOp::create(b, nestedLoc, batchOffVal, batchIdx); + batchCoords.push_back(n); + } else { + // Multiple batch dims: delinearize. + SmallVector batchBasis; + for (int64_t i = 0; i < numBatchDims; ++i) { + batchBasis.append(mixedOutputSizes[i].begin(), + mixedOutputSizes[i].end()); + } + auto delinBatch = affine::AffineDelinearizeIndexOp::create( + b, nestedLoc, batchIdx, batchBasis, /*hasOuterBound=*/true); + for (int64_t i = 0; i < numBatchDims; ++i) { + Value coord = delinBatch.getResult(i); + OpFoldResult off = mixedOffsets[i]; + Value offVal = getValueOrCreateConstantIndexOp(b, nestedLoc, off); + batchCoords.push_back( + arith::AddIOp::create(b, nestedLoc, offVal, coord)); + } + } + + // Delinearize M index using M output_sizes. + // m_pos + m_off for each spatial dim. 
+ SmallVector mCoords; + { + // Collect all M output_sizes into a flat basis. + SmallVector mBasis; + for (int64_t i = 0; i < numMDims; ++i) { + int64_t canonIdx = numBatchDims + i; + mBasis.append(mixedOutputSizes[canonIdx].begin(), + mixedOutputSizes[canonIdx].end()); + } + // For each M output dim, add its offset then delinearize using + // its output_sizes to get spatial coordinates. + if (numMDims == 1) { + int64_t canonIdx = numBatchDims; + OpFoldResult mOff = mixedOffsets[canonIdx]; + Value mOffVal = getValueOrCreateConstantIndexOp(b, nestedLoc, mOff); + Value mPos = arith::AddIOp::create(b, nestedLoc, mOffVal, mIdx); + const SmallVector &innerSizes = + mixedOutputSizes[canonIdx]; + if (innerSizes.size() == 1) { + mCoords.push_back(mPos); + } else { + auto delinM = affine::AffineDelinearizeIndexOp::create( + b, nestedLoc, mPos, innerSizes, /*hasOuterBound=*/true); + for (unsigned j = 0; j < innerSizes.size(); ++j) { + mCoords.push_back(delinM.getResult(j)); + } + } + } else { + // Multiple M output dims. Delinearize mIdx into per-dim sizes. 
+ SmallVector mDimSizes; + for (int64_t d : mOutputDims) { + mDimSizes.push_back(b.getIndexAttr(outputShape[d])); + } + auto delinMDims = affine::AffineDelinearizeIndexOp::create( + b, nestedLoc, mIdx, mDimSizes, /*hasOuterBound=*/true); + for (int64_t i = 0; i < numMDims; ++i) { + int64_t canonIdx = numBatchDims + i; + OpFoldResult mOff = mixedOffsets[canonIdx]; + Value mOffVal = + getValueOrCreateConstantIndexOp(b, nestedLoc, mOff); + Value mDimIdx = delinMDims.getResult(i); + Value mPosVal = + arith::AddIOp::create(b, nestedLoc, mOffVal, mDimIdx); + const SmallVector &innerSizes = + mixedOutputSizes[canonIdx]; + if (innerSizes.size() == 1) { + mCoords.push_back(mPosVal); + } else { + auto delinM = affine::AffineDelinearizeIndexOp::create( + b, nestedLoc, mPosVal, innerSizes, + /*hasOuterBound=*/true); + for (unsigned j = 0; j < innerSizes.size(); ++j) { + mCoords.push_back(delinM.getResult(j)); + } + } + } + } + } + + // Compute window offsets from k_off. + // k_off / C gives the linearized window index, which we delinearize + // using the kernel_size (window sizes for each spatial dim). + SmallVector windowCoords; + { + int64_t kCanonIdx = numBatchDims + numMDims; + OpFoldResult kOff = mixedOffsets[kCanonIdx]; + Value kOffVal = getValueOrCreateConstantIndexOp(b, nestedLoc, kOff); + + // Get C = innermost k_pos channel size. + int64_t C = inputShape[kPos.back()]; + Value cVal = arith::ConstantIndexOp::create(b, nestedLoc, C); + Value windowIdx = arith::DivUIOp::create(b, nestedLoc, kOffVal, cVal); + + // Delinearize window index using kernel_size. + SmallVector kernelSize = im2colOp.getMixedKernelSize(); + if (kernelSize.size() == 1) { + windowCoords.push_back(windowIdx); + } else { + auto delinWin = affine::AffineDelinearizeIndexOp::create( + b, nestedLoc, windowIdx, kernelSize, + /*hasOuterBound=*/true); + for (unsigned j = 0; j < kernelSize.size(); ++j) { + windowCoords.push_back(delinWin.getResult(j)); + } + } + } + + // Compute spatial coordinates. 
+ // spatial[j] = mCoords[j] * strides[j] + windowCoords[j] * + // dilations[j] + SmallVector spatialCoords; + AffineExpr d0, d1; + bindDims(b.getContext(), d0, d1); + for (unsigned j = 0; j < mPos.size(); ++j) { + auto map = + AffineMap::get(2, 0, {d0 * strides[j] + d1 * dilations[j]}); + Value spatial = affine::makeComposedAffineApply( + b, nestedLoc, map, {mCoords[j], windowCoords[j]}); + spatialCoords.push_back(spatial); + } + + // Build the full input coordinate vector, then linearize. + // Input layout: dimensions at batchPos get batch coords, + // dimensions at mPos get spatial coords, + // dimensions at kPos are handled by the gather's + // contiguous slice (not part of the index). + // We linearize all dims except the last (channel) dim. + SmallVector inputCoords(inputRank); + int batchCoordIdx = 0; + int spatialCoordIdx = 0; + SetVector batchPosSet(batchPos.begin(), batchPos.end()); + SetVector mPosSet(mPos.begin(), mPos.end()); + for (int64_t i = 0; i < inputRank; ++i) { + if (batchPosSet.contains(i)) { + inputCoords[i] = batchCoords[batchCoordIdx++]; + } else if (mPosSet.contains(i)) { + inputCoords[i] = spatialCoords[spatialCoordIdx++]; + } else { + // K (channel) dims — set to 0 for linearization; the gather + // reads the contiguous slice along these dims. + inputCoords[i] = arith::ConstantIndexOp::create(b, nestedLoc, 0); + } + } + + // Linearize all dims except the last (contiguous channel dim). + // lin = coords[0] * (shape[1]*...*shape[R-2]) + // + coords[1] * (shape[2]*...*shape[R-2]) + // + ... 
+ coords[R-2] + SmallVector outerCoords(inputCoords.begin(), + inputCoords.begin() + inputRank - 1); + SmallVector outerBasis; + for (int64_t i = 0; i < inputRank - 1; ++i) { + outerBasis.push_back(b.getIndexAttr(inputShape[i])); + } + + Value linIdx = affine::AffineLinearizeIndexOp::create( + b, nestedLoc, outerCoords, outerBasis, /*disjoint=*/false); + + linalg::YieldOp::create(b, nestedLoc, linIdx); + }); + + return genericOp.getResult(0); +} + +/// Convert im2col to gather for DMA. Collapses the conv input, computes +/// a linearized index tensor, creates a gather with dimension_map=[0], +/// and reshapes the result back. +struct ConvertIm2colToGather : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IREE::LinalgExt::Im2colOp im2colOp, + PatternRewriter &rewriter) const override { + auto dmaConfig = + getLoweringConfig(im2colOp); + if (!dmaConfig) { + return failure(); + } + Location loc = im2colOp.getLoc(); + + auto outputType = cast(im2colOp.getOutputType()); + ArrayRef outputShape = outputType.getShape(); + int64_t outputRank = outputType.getRank(); + + // batch_size = product of all dims except the last (K_tile). + int64_t batchSize = ShapedType::getNumElements(outputShape.drop_back()); + + // 1. Collapse source to 2D: [[0..rank-2], [rank-1]]. + Value input = im2colOp.getInput(); + auto inputType = cast(input.getType()); + int64_t inputRank = inputType.getRank(); + SmallVector srcReassoc = { + llvm::to_vector(llvm::seq(0, inputRank - 1)), {inputRank - 1}}; + Value collapsed = + tensor::CollapseShapeOp::create(rewriter, loc, input, srcReassoc); + + // 2. Compute index tensor. + Value indices = buildIm2colIndexTensor(rewriter, loc, im2colOp, batchSize); + + // 3. Reshape im2col output to [batch_size, C_per_window]. + // Build reassociation: [[0..outputRank-2], [outputRank-1]]. 
+ SmallVector outputReassoc = { + llvm::to_vector(llvm::seq(0, outputRank - 1)), + {outputRank - 1}}; + Value output = im2colOp.getOutput(); + Value reshapedOutput = + tensor::CollapseShapeOp::create(rewriter, loc, output, outputReassoc); + + // 4. Create gather with dimension_map = [0]. + auto gatherOp = IREE::LinalgExt::GatherOp::create( + rewriter, loc, reshapedOutput.getType(), collapsed, indices, + reshapedOutput, rewriter.getDenseI64ArrayAttr({0})); + setLoweringConfig(gatherOp, dmaConfig); + + // 5. Reshape gather result back to original output shape. + Value result = tensor::ExpandShapeOp::create( + rewriter, loc, outputType, gatherOp.getResult(0), outputReassoc); + + rewriter.replaceOp(im2colOp, result); + return success(); + } +}; + struct GPUConvertToCoalescedDMAPass final : impl::GPUConvertToCoalescedDMAPassBase { using GPUConvertToCoalescedDMAPassBase::GPUConvertToCoalescedDMAPassBase; @@ -895,13 +1283,33 @@ struct GPUConvertToCoalescedDMAPass final } } - // Preprocessing: apply subgroup-level tiling. + // Im2col pre-check: individually downgrade non-convertible im2cols. + funcOp->walk([&](IREE::LinalgExt::Im2colOp im2colOp) { + if (getLoweringConfig(im2colOp)) { + if (!isIm2colDMAConvertible(im2colOp)) { + setLoweringConfig(im2colOp, + IREE::GPU::DerivedThreadConfigAttr::get(context)); + } + } + }); + + // Phase 0: convert im2col -> gather. + // This produces new GatherOps that Phase 1 (subgroup tiling) will + // pick up. + { + RewritePatternSet im2colPatterns(context); + im2colPatterns.add(context); + if (failed(applyPatternsGreedily(funcOp, std::move(im2colPatterns)))) { + return signalPassFailure(); + } + } + + // Phase 1: subgroup tiling — also tiles new gather ops from Phase 0. if (failed(applySubgroupTiling(funcOp))) { return signalPassFailure(); } - // Only tile and convert ops within forall ops with warp mapping. - // Also handle tensor.pad fusion cases that don't have warp mapping. + // Phase 2: gather/copy -> DMA. 
RewritePatternSet patterns(context); patterns.add(context); patterns.add(context); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_convert_to_coalesced_dma.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_convert_to_coalesced_dma.mlir index c4be39aa0571..fe575cc2f347 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_convert_to_coalesced_dma.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_convert_to_coalesced_dma.mlir @@ -826,3 +826,154 @@ func.func @copy_swizzle_hint_linearized(%source: tensor<128x16xf32>) -> tensor<1 return %result : tensor<128x16xf32> } + +// ----- + +// Test: im2col → gather DMA conversion (happy path). +// NHWC layout, 3×3 kernel, stride 1, dilation 1, C=512 on gfx950. +// The im2col should be converted to: collapse_shape + linalg.generic (index +// computation) + gather → then the gather gets converted to coalesced DMA. + +#gpu_target_im2col = #iree_gpu.target> +#exec_target_im2col = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree_codegen.target_info = #gpu_target_im2col}> +#translation_im2col = #iree_codegen.translation_info workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options}> + +// CHECK-LABEL: func.func @im2col_to_gather_dma +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<1x16x16x512xf16> +// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]+]]: tensor<1x196x512xf16> +func.func @im2col_to_gather_dma(%input: tensor<1x16x16x512xf16>, %output: tensor<1x196x512xf16>) -> tensor<1x196x512xf16> + attributes {hal.executable.target = #exec_target_im2col, translation_info = #translation_im2col} { + %result = iree_linalg_ext.im2col + {lowering_config = #iree_gpu.use_global_load_dma} + strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] + offsets = [0, 0, 0] + output_sizes = [[1], [14, 14], [3, 3, 512]] + batch_pos = [0] m_pos = [1, 2] k_pos = [3] + input_k_perm = [0, 1, 2] output_perm = [0, 1, 2] + ins(%input : tensor<1x16x16x512xf16>) + 
outs(%output : tensor<1x196x512xf16>) -> tensor<1x196x512xf16> + + // Step 1: Collapse input [1,16,16,512] → [256,512] (flatten spatial dims). + // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[INPUT]] {{\[}}[0, 1, 2], [3]{{\]}} + // CHECK-SAME: tensor<1x16x16x512xf16> into tensor<256x512xf16> + + // Step 2: Compute linearized spatial indices via linalg.generic. + // Each of the 196 output positions (14×14) maps to a row in the 256-row + // collapsed source via: linearize(delinearize(i, [14,14]), [16,16]). + // CHECK: %[[INDICES:.+]] = linalg.generic + // CHECK: %[[IDX:.+]] = linalg.index 0 + // CHECK: affine.delinearize_index %[[IDX]] into (14, 14) + // CHECK: affine.linearize_index + // CHECK: linalg.yield + // CHECK: } -> tensor<196xindex> + + // Step 3: Collapse output [1,196,512] → [196,512]. + // CHECK: %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[OUTPUT]] {{\[}}[0, 1], [2]{{\]}} + + // Step 4: Warp-level forall distributes 196 batch positions across warps. + // 256 threads / 64 subgroup_size = 4 warps. 196 / 4 = 49 per warp. + // CHECK: scf.forall (%[[WIV0:.+]], %[[WIV1:.+]]) = (0, 0) to (196, 512) step (49, 512) + // CHECK-SAME: shared_outs(%[[WINIT:.+]] = %[[COLLAPSED_OUT]]) + + // Step 5: Slice indices for this warp's batch positions. + // CHECK: %[[WARP_INDICES:.+]] = tensor.extract_slice %[[INDICES]][%[[WIV0]]] [49] [1] + + // Step 6: Lane-level forall (64 lanes) + coalesced gather DMA. + // Each lane reads elementsPerLane contiguous f16 from the collapsed source. + // CHECK: scf.forall (%[[LANE:.+]]) in (64) + // CHECK: scf.forall.in_parallel { + // CHECK: iree_gpu.coalesced_gather_dma %[[COLLAPSED]][%[[WARP_INDICES]]] + // CHECK-SAME: into %{{.+}} lane(%[[LANE]]) + // CHECK-SAME: tensor<256x512xf16>, tensor<49xindex>, tensor<49x512xf16>, index + // CHECK: } {mapping = [#iree_gpu.lane_id<0>]} + + // CHECK: } {mapping = [#gpu.warp, #gpu.warp]} + + // Step 7: Expand result back to [1,196,512]. 
+ // CHECK: tensor.expand_shape %{{.+}} {{\[}}[0, 1], [2]{{\]}} + + // No im2col or gather should remain. + // CHECK-NOT: iree_linalg_ext.im2col + // CHECK-NOT: iree_linalg_ext.gather + + return %result : tensor<1x196x512xf16> +} + +// ----- + +// Negative test: im2col NOT converted when K_tile is too small for DMA +// alignment. With f16, dma_sizes=[32,128], subgroup_size=64: +// min_elements_per_transfer = 64 * (32/16) = 128. K_tile=4 is not aligned. +// The im2col should be downgraded to derived_thread_config. + +#gpu_target_im2col_small_k = #iree_gpu.target> +#exec_target_im2col_small_k = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree_codegen.target_info = #gpu_target_im2col_small_k}> +#translation_im2col_small_k = #iree_codegen.translation_info workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options}> + +// CHECK-LABEL: func.func @im2col_small_k_no_dma +func.func @im2col_small_k_no_dma(%input: tensor<1x6x6x4xf16>, %output: tensor<1x16x4xf16>) -> tensor<1x16x4xf16> + attributes {hal.executable.target = #exec_target_im2col_small_k, translation_info = #translation_im2col_small_k} { + %result = iree_linalg_ext.im2col + {lowering_config = #iree_gpu.use_global_load_dma} + strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] + offsets = [0, 0, 0] + output_sizes = [[1], [4, 4], [3, 3, 4]] + batch_pos = [0] m_pos = [1, 2] k_pos = [3] + input_k_perm = [0, 1, 2] output_perm = [0, 1, 2] + ins(%input : tensor<1x6x6x4xf16>) + outs(%output : tensor<1x16x4xf16>) -> tensor<1x16x4xf16> + + // K_tile=4 is too small. Im2col remains with derived_thread_config. + // CHECK: iree_linalg_ext.im2col + // CHECK-SAME: lowering_config = #iree_gpu.derived_thread_config + // CHECK-NOT: iree_gpu.coalesced_gather_dma + + return %result : tensor<1x16x4xf16> +} + +// ----- + +// Negative test: im2col NOT converted on non-gfx950 target (gfx942). +// gfx942 does not support global load DMA (no dma_sizes field). 
+ +#gpu_target_im2col_nogfx950 = #iree_gpu.target> +#exec_target_im2col_nogfx950 = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree_codegen.target_info = #gpu_target_im2col_nogfx950}> +#translation_im2col_nogfx950 = #iree_codegen.translation_info workgroup_size = [256, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options}> + +// CHECK-LABEL: func.func @im2col_nogfx950_no_dma +func.func @im2col_nogfx950_no_dma(%input: tensor<1x16x16x512xf16>, %output: tensor<1x196x512xf16>) -> tensor<1x196x512xf16> + attributes {hal.executable.target = #exec_target_im2col_nogfx950, translation_info = #translation_im2col_nogfx950} { + %result = iree_linalg_ext.im2col + {lowering_config = #iree_gpu.use_global_load_dma} + strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3] + offsets = [0, 0, 0] + output_sizes = [[1], [14, 14], [3, 3, 512]] + batch_pos = [0] m_pos = [1, 2] k_pos = [3] + input_k_perm = [0, 1, 2] output_perm = [0, 1, 2] + ins(%input : tensor<1x16x16x512xf16>) + outs(%output : tensor<1x196x512xf16>) -> tensor<1x196x512xf16> + + // Non-gfx950 target. Im2col remains with derived_thread_config. + // CHECK: iree_linalg_ext.im2col + // CHECK-SAME: lowering_config = #iree_gpu.derived_thread_config + // CHECK-NOT: iree_gpu.coalesced_gather_dma + + return %result : tensor<1x196x512xf16> +} diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir index 7b8fab4b2d69..31b0cf7fac86 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir @@ -419,11 +419,10 @@ func.func @im2col_producer_dma_downgraded_to_derived( return %mm : tensor<2x32x256xf32> } -// Im2col gets derived_thread_config (not use_global_load_dma) because Im2col -// has no DMA lowering path. 
The non-Im2col operand still gets use_global_load_dma. +// Im2col now gets use_global_load_dma because Im2col has a DMA lowering path. // CHECK-LABEL: func.func @im2col_producer_dma_downgraded_to_derived // CHECK: %[[PA:.+]] = iree_linalg_ext.im2col -// CHECK-SAME: lowering_config = #iree_gpu.derived_thread_config +// CHECK-SAME: lowering_config = #iree_gpu.use_global_load_dma // CHECK: %[[PB:.+]] = linalg.copy // CHECK-SAME: lowering_config = #iree_gpu.use_global_load_dma // CHECK: linalg.batch_matmul {{.*}} ins(%[[PA]], %[[PB]] diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/PromotionImpls.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/PromotionImpls.cpp index 17ec04490fa6..cef17a2f9cd4 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/PromotionImpls.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/PromotionImpls.cpp @@ -79,11 +79,16 @@ static std::optional promotionImpl(OpBuilder &builder, setLoweringConfig(producer, attr); return operand.get(); } - // Im2colOp has no DMA conversion path in GPUConvertToCoalescedDMA, so - // always use derived_thread_config regardless of the requested attr. + // If the promotion attr requests DMA, pass it through to im2col. + // GPUConvertToCoalescedDMA will convert im2col → gather → DMA. + // Otherwise, fall back to derived_thread_config. 
   if (isa<IREE::LinalgExt::Im2colOp>(producer.getOperation())) {
-    setLoweringConfig(producer,
-                      DerivedThreadConfigAttr::get(producer->getContext()));
+    if (isa<UseGlobalLoadDMAAttr>(attr)) {
+      setLoweringConfig(producer, attr);
+    } else {
+      setLoweringConfig(producer,
+                        DerivedThreadConfigAttr::get(producer->getContext()));
+    }
     return operand.get();
   }
 }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
index 34cdc5a96144..ba82324d0d09 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
@@ -39,6 +39,7 @@ iree_lit_test_suite(
             "pipeline_elementwise_f8ocp.mlir",
             "pipeline_igemm_tile_and_fuse.mlir",
             "pipeline_igemm_tile_and_fuse_gfx950.mlir",
+            "pipeline_im2col_dma_gfx950.mlir",
             "pipeline_lower_to_llvmgpu.mlir",
             "pipeline_scaled_truncation_gfx950.mlir",
             "pipeline_tile_and_fuse.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
index 93834918b83e..479de4ae6d4b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
@@ -34,6 +34,7 @@ iree_lit_test_suite(
     "pipeline_elementwise_f8ocp.mlir"
     "pipeline_igemm_tile_and_fuse.mlir"
     "pipeline_igemm_tile_and_fuse_gfx950.mlir"
+    "pipeline_im2col_dma_gfx950.mlir"
    "pipeline_lower_to_llvmgpu.mlir"
    "pipeline_scaled_truncation_gfx950.mlir"
    "pipeline_tile_and_fuse.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_im2col_dma_gfx950.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_im2col_dma_gfx950.mlir
new file mode 100644
index 000000000000..8a7f480f8b15
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_im2col_dma_gfx950.mlir
@@ -0,0 +1,63 @@
+// RUN: iree-opt --split-input-file
--iree-gpu-test-target=gfx950 \ +// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-lower-executable-target{for-rocdl=true})))))" %s | FileCheck %s + +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> +#translation = #iree_codegen.translation_info + workgroup_size = [512, 1, 1] + subgroup_size = 64, + { + gpu_pipeline_options = #iree_gpu.pipeline_options< + prefetch_num_stages = 2, + no_reduce_shared_memory_bank_conflicts = true, + use_igemm_convolution = true> + }> +#config = #iree_gpu.lowering_config<{ + mma_kind = #iree_gpu.mma_layout, + promote_operands = [0, 1], + promotion_types = [#iree_gpu.use_global_load_dma, #iree_gpu.use_global_load_dma], + reduction = [0, 0, 0, 0, 4], + subgroup = [1, 2, 1, 4, 0], + workgroup = [1, 4, 32, 128, 0] +}> +hal.executable private @conv_im2col_dma { + hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) { + hal.executable.export public @conv_im2col_dma ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) { + %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice() + hal.return %x, %y, %z : index, index, index + } + builtin.module { + func.func @conv_im2col_dma() attributes {translation_info = #translation} { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor> + %3 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 34, 34, 1280], strides = [1, 1, 1, 1] : 
!iree_tensor_ext.dispatch.tensor> -> tensor<2x34x34x1280xf16> + %4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 1280, 1280], strides = [1, 1, 1, 1] : !iree_tensor_ext.dispatch.tensor> -> tensor<3x3x1280x1280xf16> + %5 = tensor.empty() : tensor<2x32x32x1280xf32> + %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<2x32x32x1280xf32>) -> tensor<2x32x32x1280xf32> + %7 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>, lowering_config = #config} ins(%3, %4 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs(%6 : tensor<2x32x32x1280xf32>) -> tensor<2x32x32x1280xf32> + iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [2, 32, 32, 1280], strides = [1, 1, 1, 1] : tensor<2x32x32x1280xf32> -> !iree_tensor_ext.dispatch.tensor> + return + } + } + } +} + +// Verify im2col DMA path: conv is lowered through im2col -> gather -> +// coalesced_gather_dma -> amdgpu.gather_to_lds. The gather_to_lds ops +// read from fat_raw_buffer (global) into workgroup memory (LDS). 
+
+// CHECK-LABEL: func @conv_im2col_dma
+// CHECK: scf.forall
+// CHECK: scf.for {{.*}} iter_args
+// CHECK: amdgpu.gather_to_lds {{.*}}#amdgpu.address_space<fat_raw_buffer>{{.*}}#gpu.address_space<workgroup>
+// CHECK: amdgpu.gather_to_lds {{.*}}#amdgpu.address_space<fat_raw_buffer>{{.*}}#gpu.address_space<workgroup>
+// CHECK: gpu.barrier
+// CHECK: amdgpu.mfma 16x16x32
+// CHECK: scf.yield
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
index 97892cacf4ec..c7fd2f1ff2ed 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp
@@ -693,6 +693,16 @@ chooseDimToVectorize(OpBuilder &b, Location loc, Im2colOp im2colOp,
   return std::nullopt;
 }
 
+std::optional<int64_t> Im2colOp::getVectorizableDim(OpBuilder &b,
+                                                    Location loc) {
+  SmallVector<Range> iterDomain(getIterationDomain(b));
+  SmallVector<OpFoldResult> inputSizes =
+      tensor::getMixedSizes(b, loc, getInput());
+  SmallVector<OpFoldResult> mixedOffsets = getMixedOffsets();
+  return chooseDimToVectorize(b, loc, *this, iterDomain, inputSizes,
+                              mixedOffsets);
+}
+
 /// Decomposition implementation for iree_linalg_ext.im2col op.
 /// The im2col op is decomposed into serial loops of `insert->extract->copy`.
 /// The decomposition supports leaving either the `batch` or `K` dimension
@@ -747,8 +757,8 @@ FailureOr<SmallVector<Value>> Im2colOp::decomposeOperation(OpBuilder &b) {
   SmallVector<Range> iterationDomain(getIterationDomain(b));
   SmallVector<OpFoldResult> inputSizes =
       tensor::getMixedSizes(b, loc, getInput());
-  std::optional<int64_t> maybeOutputDimToVectorize = chooseDimToVectorize(
-      b, loc, *this, iterationDomain, inputSizes, mixedOffsets);
+  std::optional<int64_t> maybeOutputDimToVectorize =
+      getVectorizableDim(b, loc);
 
   OpFoldResult innerInputTileSize;
   if (maybeOutputDimToVectorize.has_value()) {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 8dabaa75caf4..92cf6e206e64 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -1439,6 +1439,12 @@ def IREELinalgExt_Im2colOp : IREELinalgExt_Op<"im2col",
     void setMixedOffsets(SmallVector<OpFoldResult> offsets);
     void setMixedOutputSizes(ArrayRef<SmallVector<OpFoldResult>> outputSizes);
 
+    // Returns the output dimension index that maps to a contiguous slice of
+    // the input's innermost dimension. Returns std::nullopt if no such dim
+    // exists. This wraps the chooseDimToVectorize logic used by
+    // decomposeOperation.
+    std::optional<int64_t> getVectorizableDim(OpBuilder &b, Location loc);
+
     // Method to implement for specifying output range for
     // DestinationStyleOpInterface
     MutableOperandRange getDpsInitsMutable() {
diff --git a/tests/e2e/rocm_specific/im2col_dma_conv.mlir b/tests/e2e/rocm_specific/im2col_dma_conv.mlir
new file mode 100644
index 000000000000..59033fde7858
--- /dev/null
+++ b/tests/e2e/rocm_specific/im2col_dma_conv.mlir
@@ -0,0 +1,42 @@
+// Test conv2d using im2col + DMA path on gfx950+.
+// +// Compile: +// iree-compile \ +// --iree-hal-target-backends=rocm \ +// --iree-rocm-target=gfx950 \ +// --iree-codegen-llvmgpu-use-igemm=true \ +// --iree-llvmgpu-use-direct-load=true \ +// im2col_dma_conv.mlir -o im2col_dma_conv.vmfb +// +// Run: +// iree-check-module --device=hip --module=im2col_dma_conv.vmfb +// +// Dump IR (for debugging): +// iree-compile \ +// --iree-hal-target-backends=rocm \ +// --iree-rocm-target=gfx950 \ +// --iree-codegen-llvmgpu-use-igemm=true \ +// --iree-llvmgpu-use-direct-load=true \ +// --mlir-print-ir-after-all \ +// im2col_dma_conv.mlir -o im2col_dma_conv.vmfb 2> im2col_dma_ir_dump.mlir + +!input_type = tensor<1x10x10x512xf16> +!filter_type = tensor<3x3x512x512xf16> +!output_type = tensor<1x8x8x512xf32> + +func.func @im2col_dma_conv() { + %input = util.unfoldable_constant dense<1.0> : !input_type + %filter = util.unfoldable_constant dense<1.0> : !filter_type + %cst = arith.constant 0.000000e+00 : f32 + %empty = tensor.empty() : !output_type + %fill = linalg.fill ins(%cst : f32) outs(%empty : !output_type) -> !output_type + %result = linalg.conv_2d_nhwc_hwcf { + dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64> + } ins(%input, %filter : !input_type, !filter_type) + outs(%fill : !output_type) -> !output_type + // Each output element = sum over 3*3*512 products of 1*1 = 4608. + check.expect_almost_eq_const( + %result, dense<4608.0> : !output_type) : !output_type + return +}