diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/ROCDLPrefetchSharedMemoryCopy.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/ROCDLPrefetchSharedMemoryCopy.cpp
index 9f7a5f418091..c312a18bba2c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/ROCDLPrefetchSharedMemoryCopy.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/ROCDLPrefetchSharedMemoryCopy.cpp
@@ -4,6 +4,7 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
 #include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
@@ -136,7 +137,7 @@ static bool hasStreamCopyOps(scf::ForOp forOp) {
   return hasGlobalRead && hasSharedWrite;
 }
 
-/// Trace through view-like ops to find the root allocation.
+/// Trace through view-like ops and swizzle hints to find the root allocation.
 static memref::AllocOp traceToAllocation(Value base) {
   while (base) {
     if (auto alloc = base.getDefiningOp<memref::AllocOp>()) {
@@ -144,6 +145,8 @@
     }
     if (auto viewOp = base.getDefiningOp<ViewLikeOpInterface>()) {
       base = viewOp.getViewSource();
+    } else if (auto hint = base.getDefiningOp<IREE::Codegen::SwizzleHintOp>()) {
+      base = hint.getOperand();
     } else {
       break;
     }
@@ -151,17 +154,16 @@
   }
   return nullptr;
 }
 
-/// Collect all view-like ops that need to be cloned inside the loop.
-/// Returns ops in topological order (dependencies first).
+/// Collect all view-like ops and swizzle hints that need to be cloned inside
+/// the loop. Returns ops in topological order (dependencies first).
 /// Returns failure if any use escapes the target loop.
 static FailureOr<SmallVector<Operation *>>
-collectViewOpsToClone(memref::AllocOp alloc, scf::ForOp forOp) {
-  SetVector<Operation *> viewOpsToClone;
+collectOpsToClone(memref::AllocOp alloc, scf::ForOp forOp) {
+  SetVector<Operation *> opsToClone;
   SmallVector<Value> worklist;
   worklist.push_back(alloc.getResult());
 
-  // Collect all view-like ops outside the loop reachable from the allocation.
   while (!worklist.empty()) {
     Value val = worklist.pop_back_val();
     for (Operation *user : val.getUsers()) {
@@ -169,9 +171,13 @@ collectViewOpsToClone(memref::AllocOp alloc, scf::ForOp forOp) {
         continue;
       }
       if (auto viewOp = dyn_cast<ViewLikeOpInterface>(user)) {
-        if (viewOpsToClone.insert(user)) {
+        if (opsToClone.insert(user)) {
           worklist.push_back(viewOp.getViewDest());
         }
+      } else if (auto hint = dyn_cast<IREE::Codegen::SwizzleHintOp>(user)) {
+        if (opsToClone.insert(user)) {
+          worklist.push_back(hint.getResult());
+        }
       }
     }
   }
@@ -181,14 +187,14 @@
     if (forOp->isAncestor(user)) {
       continue;
     }
-    if (viewOpsToClone.contains(user)) {
+    if (opsToClone.contains(user)) {
      continue;
    }
-    // Dealloc should not block view-op cloning.
+    // Dealloc should not block cloning.
     if (isa<memref::DeallocOp>(user)) {
       continue;
     }
-    LDBG() << "Cannot clone view ops: found use outside loop: " << *user;
+    LDBG() << "Cannot clone ops: found use outside loop: " << *user;
     return failure();
   }
   return success();
 }
@@ -198,14 +204,21 @@
     return failure();
   }
 
-  for (Operation *op : viewOpsToClone) {
-    auto viewOp = cast<ViewLikeOpInterface>(op);
-    if (failed(validateUses(viewOp.getViewDest()))) {
+  for (Operation *op : opsToClone) {
+    Value dest;
+    if (auto viewOp = dyn_cast<ViewLikeOpInterface>(op)) {
+      dest = viewOp.getViewDest();
+    } else if (auto hint = dyn_cast<IREE::Codegen::SwizzleHintOp>(op)) {
+      dest = hint.getResult();
+    } else {
+      return failure();
+    }
+    if (failed(validateUses(dest))) {
       return failure();
     }
   }
 
-  SmallVector<Operation *> result(viewOpsToClone.begin(), viewOpsToClone.end());
+  SmallVector<Operation *> result(opsToClone.begin(), opsToClone.end());
 
   // Sort in topological order - ops must come after their dependencies
   llvm::stable_sort(
@@ -214,23 +227,23 @@
   return result;
 }
 
-/// Clone view-like operations inside the loop body.
-/// This is necessary for multi-buffering to work when view ops are defined
+/// Clone view-like ops and swizzle hints inside the loop body.
+/// This is necessary for multi-buffering to work when these ops are defined
 /// outside the target loop but used inside it.
-static LogicalResult cloneViewOpsInsideLoop(memref::AllocOp alloc,
-                                            scf::ForOp forOp) {
-  auto viewOpsOr = collectViewOpsToClone(alloc, forOp);
-  if (failed(viewOpsOr)) {
+static LogicalResult cloneOpsInsideLoop(memref::AllocOp alloc,
+                                        scf::ForOp forOp) {
+  auto opsOr = collectOpsToClone(alloc, forOp);
+  if (failed(opsOr)) {
     return failure();
   }
-  SmallVector<Operation *> &viewOps = *viewOpsOr;
-  if (viewOps.empty()) {
+  SmallVector<Operation *> &ops = *opsOr;
+  if (ops.empty()) {
     return success();
   }
-  LDBG() << "Cloning " << viewOps.size()
-         << " view ops inside loop for allocation: " << *alloc;
+  LDBG() << "Cloning " << ops.size()
+         << " ops inside loop for allocation: " << *alloc;
 
   // Create clones at the beginning of the loop body
   Block *loopBody = forOp.getBody();
@@ -239,7 +252,7 @@ static LogicalResult cloneViewOpsInsideLoop(memref::AllocOp alloc,
   IRMapping mapping;
   SmallVector<Operation *> opsToErase;
 
-  for (Operation *op : viewOps) {
+  for (Operation *op : ops) {
     Operation *clone = builder.clone(*op, mapping);
     LDBG() << "  Cloned: " << *op << " -> " << *clone;
@@ -265,6 +278,103 @@ static LogicalResult cloneViewOpsInsideLoop(memref::AllocOp alloc,
   return success();
 }
 
+/// memref::multiBuffer propagates type changes through a set of known view-like
+/// ops (subview, expand_shape, etc.). SwizzleHintOp is not in that set, so fix
+/// up the result type of a single hint and its downstream ExpandShapeOp chain.
+static void propagateTypeFromMultiBuffer(IREE::Codegen::SwizzleHintOp hint) {
+  if (hint.getOperand().getType() != hint.getResult().getType()) {
+    hint.getResult().setType(hint.getOperand().getType());
+  }
+  // Propagate the layout change through the chain of ExpandShapeOps
+  // downstream of the hint.
+  SmallVector<Value> worklist = {hint.getResult()};
+  while (!worklist.empty()) {
+    Value current = worklist.pop_back_val();
+    for (OpOperand &use : current.getUses()) {
+      auto expandOp = dyn_cast<memref::ExpandShapeOp>(use.getOwner());
+      if (!expandOp) {
+        continue;
+      }
+      auto srcType = cast<MemRefType>(expandOp.getSrc().getType());
+      MemRefType resultType = expandOp.getResultType();
+      if (srcType.getLayout() == resultType.getLayout()) {
+        continue;
+      }
+      FailureOr<MemRefType> newResultType =
+          memref::ExpandShapeOp::computeExpandedType(
+              srcType, resultType.getShape(),
+              expandOp.getReassociationIndices());
+      if (failed(newResultType)) {
+        continue;
+      }
+      expandOp.getResult().setType(*newResultType);
+      worklist.push_back(expandOp.getResult());
+    }
+  }
+}
+
+/// After pipelining, the write path retains swizzle_hint but the read path
+/// does not. Clone swizzle_hint onto read-side iter_args and loop results.
+static void cloneSwizzleHint(scf::ForOp forOp) {
+  Block *body = forOp.getBody();
+  auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
+  int numOperands = yieldOp.getNumOperands();
+
+  OpBuilder builder(forOp.getContext());
+
+  // Reverse order because the pipeliner appends iter_args oldest-to-newest.
+  // The newest slot corresponds to the write path.
+  for (int idx = numOperands - 1; idx >= 0; --idx) {
+    Value yieldVal = yieldOp.getOperand(idx);
+
+    // Check if the yield operand traces through expand_shape -> swizzle_hint.
+    auto expandOp = yieldVal.getDefiningOp<memref::ExpandShapeOp>();
+    if (!expandOp) {
+      continue;
+    }
+    auto hintOp =
+        expandOp.getSrc().getDefiningOp<IREE::Codegen::SwizzleHintOp>();
+    if (!hintOp) {
+      continue;
+    }
+
+    BlockArgument iterArg = forOp.getRegionIterArg(idx);
+    auto iterArgType = cast<MemRefType>(iterArg.getType());
+    SmallVector<ReassociationIndices> reassoc =
+        expandOp.getReassociationIndices();
+
+    FailureOr<MemRefType> flatType =
+        memref::CollapseShapeOp::computeCollapsedType(iterArgType, reassoc);
+    if (failed(flatType)) {
+      continue;
+    }
+
+    auto swizzleAttr = hintOp.getSwizzle();
+    Location loc = hintOp.getLoc();
+
+    LDBG() << "Cloning swizzle_hint onto iter_arg #" << idx << ": " << iterArg;
+
+    // Insert collapse_shape -> swizzle_hint -> expand_shape.
+    auto insertSwizzleHint = [&](Value value) {
+      auto collapse = memref::CollapseShapeOp::create(builder, loc, *flatType,
+                                                      value, reassoc);
+      auto hint = IREE::Codegen::SwizzleHintOp::create(
+          builder, loc, collapse.getResult(), swizzleAttr);
+      auto expand = memref::ExpandShapeOp::create(builder, loc, iterArgType,
+                                                  hint.getResult(), reassoc);
+      value.replaceAllUsesExcept(expand.getResult(), collapse.getOperation());
+    };
+
+    // Clone for reads inside the loop body.
+    builder.setInsertionPointToStart(body);
+    insertSwizzleHint(iterArg);
+
+    // Clone for reads in the epilogue.
+    builder.setInsertionPointAfter(forOp);
+    insertSwizzleHint(forOp.getResult(idx));
+  }
+}
+
 /// Multi-buffer LDS allocations used by gather_to_lds operations.
 /// This enables double-buffering for pipelined async copies.
 static LogicalResult multiBufferLDSAllocations(scf::ForOp forOp,
@@ -289,7 +399,7 @@ static LogicalResult multiBufferLDSAllocations(scf::ForOp forOp,
   // First, clone view ops inside the loop for each allocation
   for (memref::AllocOp alloc : sharedAllocs) {
-    if (failed(cloneViewOpsInsideLoop(alloc, forOp))) {
+    if (failed(cloneOpsInsideLoop(alloc, forOp))) {
       LDBG() << "Failed to clone view ops for: " << *alloc;
       return failure();
     }
@@ -307,6 +417,9 @@ static LogicalResult multiBufferLDSAllocations(scf::ForOp forOp,
              << " buffers at " << loc;
   }
 
+  // Fix up types for swizzle hints after multi-buffering.
+  forOp->walk(propagateTypeFromMultiBuffer);
+
   return success();
 }
 
@@ -1316,6 +1429,9 @@ FailureOr<scf::ForOp> prefetchSharedMemoryCopy(RewriterBase &rewriter,
   // Insert barriers using the appropriate strategy for each mode.
   insertPipelineBarriers(rewriter, newForOp, mode);
 
+  // If swizzle_hint was applied, fix it by cloning onto the read-side.
+  cloneSwizzleHint(newForOp);
+
   // For async copy mode, convert gather_to_lds to async and insert explicit
   // async markers (asyncmark + wait.asyncmark). This replaces the backend's
   // alias-analysis-based vmcnt insertion with precise explicit synchronization,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/prefetch_shared_memory.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/prefetch_shared_memory.mlir
index 12b10239a2b0..194c25cd2e64 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/prefetch_shared_memory.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/prefetch_shared_memory.mlir
@@ -699,3 +699,52 @@ func.func @gather_to_lds_nested_loop_async(
   }
   return
 }
+
+// -----
+
+// Test that swizzle_hint is cloned onto the read-side iter_arg and loop result
+// after pipelining.
+
+// CHECK-LABEL: @prefetch_gather_to_lds_with_swizzle
+func.func @prefetch_gather_to_lds_with_swizzle(
+    %global: memref<128x8xf32>,
+    %output: memref<8xf32>) {
+  %cst = arith.constant dense<0.000000e+00> : vector<1xf32>
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+
+  // CHECK: memref.alloc() : memref<2x1x8xf32, #gpu.address_space<workgroup>>
+  %alloc = memref.alloc() : memref<1x8xf32, #gpu.address_space<workgroup>>
+  %collapsed = memref.collapse_shape %alloc [[0, 1]] : memref<1x8xf32, #gpu.address_space<workgroup>> into memref<8xf32, #gpu.address_space<workgroup>>
+  %swizzled = iree_codegen.swizzle_hint %collapsed [#iree_codegen.xor_shuffle<128, 8>] : memref<8xf32, #gpu.address_space<workgroup>>
+  %expanded = memref.expand_shape %swizzled [[0, 1]] output_shape [1, 8] : memref<8xf32, #gpu.address_space<workgroup>> into memref<1x8xf32, #gpu.address_space<workgroup>>
+
+  // Prologue: write stage with swizzle
+  // CHECK: iree_codegen.swizzle_hint {{.*}}[#iree_codegen.xor_shuffle<128, 8>]
+  // CHECK: amdgpu.gather_to_lds async
+  // CHECK: rocdl.asyncmark
+  // CHECK: scf.for
+  %result = scf.for %k = %c0 to %c128 step %c1 iter_args(%acc = %cst) -> (vector<1xf32>) {
+    amdgpu.gather_to_lds %global[%k, %c0], %expanded[%c0, %c0] : vector<1xf32>, memref<128x8xf32>, memref<1x8xf32, #gpu.address_space<workgroup>>
+    %val = vector.transfer_read %expanded[%c0, %c0], %cst_0 : memref<1x8xf32, #gpu.address_space<workgroup>>, vector<1xf32>
+    %sum = arith.addf %val, %acc : vector<1xf32>
+    scf.yield %sum : vector<1xf32>
+  }
+
+  // Read-side iter_arg: swizzle_hint cloned onto iter_arg for correct reads
+  // CHECK: iree_codegen.swizzle_hint {{.*}}[#iree_codegen.xor_shuffle<128, 8>]
+  // Write-side: swizzle_hint on new write buffer
+  // CHECK: iree_codegen.swizzle_hint {{.*}}[#iree_codegen.xor_shuffle<128, 8>]
+  // CHECK: amdgpu.gather_to_lds async
+  // CHECK: vector.transfer_read
+  // CHECK: scf.yield
+
+  // Epilogue: swizzle_hint cloned onto loop result
+  // CHECK: iree_codegen.swizzle_hint {{.*}}[#iree_codegen.xor_shuffle<128, 8>]
+  // CHECK: vector.transfer_read
+
+  vector.transfer_write %result, %output[%c0] {in_bounds = [true]} : vector<1xf32>, memref<8xf32>
+  return
+}