Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
Expand Down Expand Up @@ -136,42 +137,47 @@ static bool hasStreamCopyOps(scf::ForOp forOp) {
return hasGlobalRead && hasSharedWrite;
}

/// Trace through view-like ops to find the root allocation.
/// Trace through view-like ops and swizzle hints to find the root allocation.
static memref::AllocOp traceToAllocation(Value base) {
  // Walk the use-def chain backwards: each step peels one view-like op or one
  // swizzle hint until the underlying memref.alloc (or a dead end) is reached.
  Value current = base;
  while (current) {
    Operation *def = current.getDefiningOp();
    if (!def) {
      // Block argument or otherwise not produced by an op; no allocation here.
      return nullptr;
    }
    if (auto alloc = dyn_cast<memref::AllocOp>(def)) {
      return alloc;
    }
    if (auto viewOp = dyn_cast<ViewLikeOpInterface>(def)) {
      current = viewOp.getViewSource();
      continue;
    }
    if (auto hint = dyn_cast<IREE::Codegen::SwizzleHintOp>(def)) {
      current = hint.getOperand();
      continue;
    }
    // Chain ends at an op we do not know how to look through.
    return nullptr;
  }
  return nullptr;
}

/// Collect all view-like ops that need to be cloned inside the loop.
/// Returns ops in topological order (dependencies first).
/// Collect all view-like ops and swizzle hints that need to be cloned inside
/// the loop. Returns ops in topological order (dependencies first).
/// Returns failure if any use escapes the target loop.
static FailureOr<SmallVector<Operation *>>
collectViewOpsToClone(memref::AllocOp alloc, scf::ForOp forOp) {
SetVector<Operation *> viewOpsToClone;
collectOpsToClone(memref::AllocOp alloc, scf::ForOp forOp) {
SetVector<Operation *> opsToClone;
SmallVector<Value> worklist;

worklist.push_back(alloc.getResult());

// Collect all view-like ops outside the loop reachable from the allocation.
while (!worklist.empty()) {
Value val = worklist.pop_back_val();
for (Operation *user : val.getUsers()) {
if (forOp->isAncestor(user)) {
continue;
}
if (auto viewOp = dyn_cast<ViewLikeOpInterface>(user)) {
if (viewOpsToClone.insert(user)) {
if (opsToClone.insert(user)) {
worklist.push_back(viewOp.getViewDest());
}
} else if (auto hint = dyn_cast<IREE::Codegen::SwizzleHintOp>(user)) {
if (opsToClone.insert(user)) {
worklist.push_back(hint.getResult());
}
}
}
}
Expand All @@ -181,14 +187,14 @@ collectViewOpsToClone(memref::AllocOp alloc, scf::ForOp forOp) {
if (forOp->isAncestor(user)) {
continue;
}
if (viewOpsToClone.contains(user)) {
if (opsToClone.contains(user)) {
continue;
}
// Dealloc should not block view-op cloning.
// Dealloc should not block cloning.
if (isa<memref::DeallocOp>(user)) {
continue;
}
LDBG() << "Cannot clone view ops: found use outside loop: " << *user;
LDBG() << "Cannot clone ops: found use outside loop: " << *user;
return failure();
}
return success();
Expand All @@ -198,14 +204,21 @@ collectViewOpsToClone(memref::AllocOp alloc, scf::ForOp forOp) {
return failure();
}

for (Operation *op : viewOpsToClone) {
auto viewOp = cast<ViewLikeOpInterface>(op);
if (failed(validateUses(viewOp.getViewDest()))) {
for (Operation *op : opsToClone) {
Value dest;
if (auto viewOp = dyn_cast<ViewLikeOpInterface>(op)) {
dest = viewOp.getViewDest();
} else if (auto hint = dyn_cast<IREE::Codegen::SwizzleHintOp>(op)) {
dest = hint.getResult();
} else {
return failure();
}
if (failed(validateUses(dest))) {
return failure();
}
}

SmallVector<Operation *> result(viewOpsToClone.begin(), viewOpsToClone.end());
SmallVector<Operation *> result(opsToClone.begin(), opsToClone.end());

// Sort in topological order - ops must come after their dependencies
llvm::stable_sort(
Expand All @@ -214,23 +227,23 @@ collectViewOpsToClone(memref::AllocOp alloc, scf::ForOp forOp) {
return result;
}

/// Clone view-like operations inside the loop body.
/// This is necessary for multi-buffering to work when view ops are defined
/// Clone view-like ops and swizzle hints inside the loop body.
/// This is necessary for multi-buffering to work when these ops are defined
/// outside the target loop but used inside it.
static LogicalResult cloneViewOpsInsideLoop(memref::AllocOp alloc,
scf::ForOp forOp) {
auto viewOpsOr = collectViewOpsToClone(alloc, forOp);
if (failed(viewOpsOr)) {
static LogicalResult cloneOpsInsideLoop(memref::AllocOp alloc,
scf::ForOp forOp) {
auto opsOr = collectOpsToClone(alloc, forOp);
if (failed(opsOr)) {
return failure();
}

SmallVector<Operation *> &viewOps = *viewOpsOr;
if (viewOps.empty()) {
SmallVector<Operation *> &ops = *opsOr;
if (ops.empty()) {
return success();
}

LDBG() << "Cloning " << viewOps.size()
<< " view ops inside loop for allocation: " << *alloc;
LDBG() << "Cloning " << ops.size()
<< " ops inside loop for allocation: " << *alloc;

// Create clones at the beginning of the loop body
Block *loopBody = forOp.getBody();
Expand All @@ -239,7 +252,7 @@ static LogicalResult cloneViewOpsInsideLoop(memref::AllocOp alloc,

IRMapping mapping;
SmallVector<Operation *> opsToErase;
for (Operation *op : viewOps) {
for (Operation *op : ops) {
Operation *clone = builder.clone(*op, mapping);
LDBG() << " Cloned: " << *op << " -> " << *clone;

Expand All @@ -265,6 +278,103 @@ static LogicalResult cloneViewOpsInsideLoop(memref::AllocOp alloc,
return success();
}

/// memref::multiBuffer propagates type changes through a set of known view-like
/// ops (subview, expand_shape, etc.). SwizzleHintOp is not in that set, so fix
/// up the result type of a single hint and its downstream ExpandShapeOp chain.
static void propagateTypeFromMultiBuffer(IREE::Codegen::SwizzleHintOp hint) {
  // Multi-buffering may have rewritten the hint's operand with a new memref
  // type (e.g. new layout); mirror that onto the hint's result so the hint
  // stays type-consistent with its input.
  if (hint.getOperand().getType() != hint.getResult().getType()) {
    hint.getResult().setType(hint.getOperand().getType());
  }
  // Propagate the layout change through the chain of ExpandShapeOps
  // downstream of the hint.
  SmallVector<Value> worklist = {hint.getResult()};
  while (!worklist.empty()) {
    Value current = worklist.pop_back_val();
    for (OpOperand &use : current.getUses()) {
      // Only expand_shape users participate in the chain; any other user is
      // left untouched.
      auto expandOp = dyn_cast<memref::ExpandShapeOp>(use.getOwner());
      if (!expandOp) {
        continue;
      }
      auto srcType = cast<MemRefType>(expandOp.getSrc().getType());
      MemRefType resultType = expandOp.getResultType();
      // Layouts already agree: this expand_shape needs no fixup, and its
      // downstream ops were consistent to begin with.
      if (srcType.getLayout() == resultType.getLayout()) {
        continue;
      }
      // Recompute the expanded result type from the (already-updated) source
      // type so the layout change is reflected in the expansion.
      FailureOr<MemRefType> newResultType =
          memref::ExpandShapeOp::computeExpandedType(
              srcType, resultType.getShape(),
              expandOp.getReassociationIndices());
      // Best effort: if a new expanded type cannot be computed, skip this op
      // rather than mutate it into an invalid state.
      if (failed(newResultType)) {
        continue;
      }
      expandOp.getResult().setType(*newResultType);
      // The updated result may feed further expand_shape ops; keep walking.
      worklist.push_back(expandOp.getResult());
    }
  }
}

/// After pipelining, the write path retains swizzle_hint but the read path
/// does not. Clone swizzle_hint onto read-side iter_args and loop results.
///
/// For every yield operand of the form expand_shape(swizzle_hint(...)), this
/// inserts a collapse_shape -> swizzle_hint -> expand_shape sandwich in front
/// of (a) the corresponding iter_arg inside the loop body and (b) the
/// corresponding loop result after the loop, so reads observe the same
/// swizzled layout that writes used.
static void cloneSwizzleHint(scf::ForOp forOp) {
  Block *body = forOp.getBody();
  auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
  int numOperands = yieldOp.getNumOperands();

  OpBuilder builder(forOp.getContext());

  // Reverse order because the pipeliner appends iter_args oldest-to-newest.
  // The newest slot corresponds to the write path.
  for (int idx = numOperands - 1; idx >= 0; --idx) {
    Value yieldVal = yieldOp.getOperand(idx);

    // Check if the yield operand traces through expand_shape -> swizzle_hint.
    auto expandOp = yieldVal.getDefiningOp<memref::ExpandShapeOp>();
    if (!expandOp) {
      continue;
    }
    auto hintOp =
        expandOp.getSrc().getDefiningOp<IREE::Codegen::SwizzleHintOp>();
    if (!hintOp) {
      continue;
    }

    BlockArgument iterArg = forOp.getRegionIterArg(idx);
    auto iterArgType = cast<MemRefType>(iterArg.getType());
    SmallVector<ReassociationIndices> reassoc =
        expandOp.getReassociationIndices();

    // The hint operates on the flattened memref, so derive the collapsed
    // (flat) type of the iter_arg using the same reassociation as the
    // write-side expand_shape.
    FailureOr<MemRefType> flatType =
        memref::CollapseShapeOp::computeCollapsedType(iterArgType, reassoc);
    if (failed(flatType)) {
      continue;
    }

    // Reuse the write-side swizzle pattern and location for the clones.
    auto swizzleAttr = hintOp.getSwizzle();
    Location loc = hintOp.getLoc();

    LDBG() << "Cloning swizzle_hint onto iter_arg #" << idx << ": " << iterArg;

    // Insert collapse_shape -> swizzle_hint -> expand_shape.
    auto insertSwizzleHint = [&](Value value) {
      auto collapse = memref::CollapseShapeOp::create(builder, loc, *flatType,
                                                      value, reassoc);
      auto hint = IREE::Codegen::SwizzleHintOp::create(
          builder, loc, collapse.getResult(), swizzleAttr);
      auto expand = memref::ExpandShapeOp::create(builder, loc, iterArgType,
                                                  hint.getResult(), reassoc);
      // Redirect all existing users of `value` to the swizzled view, except
      // the collapse we just created (it must keep reading the raw value).
      value.replaceAllUsesExcept(expand.getResult(), collapse.getOperation());
    };

    // Clone for reads inside the loop body.
    builder.setInsertionPointToStart(body);
    insertSwizzleHint(iterArg);

    // Clone for reads in the epilogue.
    builder.setInsertionPointAfter(forOp);
    insertSwizzleHint(forOp.getResult(idx));
  }
}

/// Multi-buffer LDS allocations used by gather_to_lds operations.
/// This enables double-buffering for pipelined async copies.
static LogicalResult multiBufferLDSAllocations(scf::ForOp forOp,
Expand All @@ -289,7 +399,7 @@ static LogicalResult multiBufferLDSAllocations(scf::ForOp forOp,

// First, clone view ops inside the loop for each allocation
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: no more view.

for (memref::AllocOp alloc : sharedAllocs) {
if (failed(cloneViewOpsInsideLoop(alloc, forOp))) {
if (failed(cloneOpsInsideLoop(alloc, forOp))) {
LDBG() << "Failed to clone view ops for: " << *alloc;
return failure();
}
Expand All @@ -307,6 +417,9 @@ static LogicalResult multiBufferLDSAllocations(scf::ForOp forOp,
<< " buffers at " << loc;
}

// Fix up types for swizzle hints after multi-buffering.
forOp->walk(propagateTypeFromMultiBuffer);

return success();
}

Expand Down Expand Up @@ -1316,6 +1429,9 @@ FailureOr<scf::ForOp> prefetchSharedMemoryCopy(RewriterBase &rewriter,
// Insert barriers using the appropriate strategy for each mode.
insertPipelineBarriers(rewriter, newForOp, mode);

// If swizzle_hint was applied, fix it by cloning onto the read-side.
cloneSwizzleHint(newForOp);

// For async copy mode, convert gather_to_lds to async and insert explicit
// async markers (asyncmark + wait.asyncmark). This replaces the backend's
// alias-analysis-based vmcnt insertion with precise explicit synchronization,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -699,3 +699,52 @@ func.func @gather_to_lds_nested_loop_async(
}
return
}

// -----

// Test that swizzle_hint is cloned onto the read-side iter_arg and loop result
// after pipelining.

// CHECK-LABEL: @prefetch_gather_to_lds_with_swizzle
func.func @prefetch_gather_to_lds_with_swizzle(
    %global: memref<128x8xf32>,
    %output: memref<8xf32>) {
  %cst = arith.constant dense<0.000000e+00> : vector<1xf32>
  %cst_0 = arith.constant 0.000000e+00 : f32
  %c128 = arith.constant 128 : index
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index

  // The 1x8 workgroup allocation is expected to be double-buffered to 2x1x8.
  // CHECK: memref.alloc() : memref<2x1x8xf32, #gpu.address_space<workgroup>>
  %alloc = memref.alloc() : memref<1x8xf32, #gpu.address_space<workgroup>>
  // The swizzle hint is applied on the flattened (collapsed) memref and the
  // result is expanded back for use by gather_to_lds / transfer_read.
  %collapsed = memref.collapse_shape %alloc [[0, 1]] : memref<1x8xf32, #gpu.address_space<workgroup>> into memref<8xf32, #gpu.address_space<workgroup>>
  %swizzled = iree_codegen.swizzle_hint %collapsed [#iree_codegen.xor_shuffle<128, 8>] : memref<8xf32, #gpu.address_space<workgroup>>
  %expanded = memref.expand_shape %swizzled [[0, 1]] output_shape [1, 8] : memref<8xf32, #gpu.address_space<workgroup>> into memref<1x8xf32, #gpu.address_space<workgroup>>

  // Prologue: write stage with swizzle
  // CHECK: iree_codegen.swizzle_hint {{.*}}[#iree_codegen.xor_shuffle<128, 8>]
  // CHECK: amdgpu.gather_to_lds async
  // CHECK: rocdl.asyncmark
  // CHECK: scf.for
  %result = scf.for %k = %c0 to %c128 step %c1 iter_args(%acc = %cst) -> (vector<1xf32>) {
    amdgpu.gather_to_lds %global[%k, %c0], %expanded[%c0, %c0] : vector<1xf32>, memref<128x8xf32>, memref<1x8xf32, #gpu.address_space<workgroup>>
    %val = vector.transfer_read %expanded[%c0, %c0], %cst_0 : memref<1x8xf32, #gpu.address_space<workgroup>>, vector<1xf32>
    %sum = arith.addf %val, %acc : vector<1xf32>
    scf.yield %sum : vector<1xf32>
  }

  // Read-side iter_arg: swizzle_hint cloned onto iter_arg for correct reads
  // CHECK: iree_codegen.swizzle_hint {{.*}}[#iree_codegen.xor_shuffle<128, 8>]
  // Write-side: swizzle_hint on new write buffer
  // CHECK: iree_codegen.swizzle_hint {{.*}}[#iree_codegen.xor_shuffle<128, 8>]
  // CHECK: amdgpu.gather_to_lds async
  // CHECK: vector.transfer_read
  // CHECK: scf.yield

  // Epilogue: swizzle_hint cloned onto loop result
  // CHECK: iree_codegen.swizzle_hint {{.*}}[#iree_codegen.xor_shuffle<128, 8>]
  // CHECK: vector.transfer_read

  vector.transfer_write %result, %output[%c0] {in_bounds = [true]} : vector<1xf32>, memref<8xf32>
  return
}
Loading