44// See https://llvm.org/LICENSE.txt for license information.
55// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
7+ #include " iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
78#include " iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
89#include " iree/compiler/Codegen/Utils/GPUUtils.h"
910#include " iree/compiler/Codegen/Utils/Utils.h"
@@ -136,42 +137,47 @@ static bool hasStreamCopyOps(scf::ForOp forOp) {
136137 return hasGlobalRead && hasSharedWrite;
137138}
138139
139- // / Trace through view-like ops to find the root allocation.
140+ // / Trace through view-like ops and swizzle hints to find the root allocation.
140141static memref::AllocOp traceToAllocation (Value base) {
141142 while (base) {
142143 if (auto alloc = base.getDefiningOp <memref::AllocOp>()) {
143144 return alloc;
144145 }
145146 if (auto viewOp = base.getDefiningOp <ViewLikeOpInterface>()) {
146147 base = viewOp.getViewSource ();
148+ } else if (auto hint = base.getDefiningOp <IREE::Codegen::SwizzleHintOp>()) {
149+ base = hint.getOperand ();
147150 } else {
148151 break ;
149152 }
150153 }
151154 return nullptr ;
152155}
153156
154- // / Collect all view-like ops that need to be cloned inside the loop.
155- // / Returns ops in topological order (dependencies first).
157+ // / Collect all view-like ops and swizzle hints that need to be cloned inside
158+ // / the loop. Returns ops in topological order (dependencies first).
156159// / Returns failure if any use escapes the target loop.
157160static FailureOr<SmallVector<Operation *>>
158- collectViewOpsToClone (memref::AllocOp alloc, scf::ForOp forOp) {
159- SetVector<Operation *> viewOpsToClone ;
161+ collectOpsToClone (memref::AllocOp alloc, scf::ForOp forOp) {
162+ SetVector<Operation *> opsToClone ;
160163 SmallVector<Value> worklist;
161164
162165 worklist.push_back (alloc.getResult ());
163166
164- // Collect all view-like ops outside the loop reachable from the allocation.
165167 while (!worklist.empty ()) {
166168 Value val = worklist.pop_back_val ();
167169 for (Operation *user : val.getUsers ()) {
168170 if (forOp->isAncestor (user)) {
169171 continue ;
170172 }
171173 if (auto viewOp = dyn_cast<ViewLikeOpInterface>(user)) {
172- if (viewOpsToClone .insert (user)) {
174+ if (opsToClone .insert (user)) {
173175 worklist.push_back (viewOp.getViewDest ());
174176 }
177+ } else if (auto hint = dyn_cast<IREE::Codegen::SwizzleHintOp>(user)) {
178+ if (opsToClone.insert (user)) {
179+ worklist.push_back (hint.getResult ());
180+ }
175181 }
176182 }
177183 }
@@ -181,14 +187,14 @@ collectViewOpsToClone(memref::AllocOp alloc, scf::ForOp forOp) {
181187 if (forOp->isAncestor (user)) {
182188 continue ;
183189 }
184- if (viewOpsToClone .contains (user)) {
190+ if (opsToClone .contains (user)) {
185191 continue ;
186192 }
187- // Dealloc should not block view-op cloning.
193+ // Dealloc should not block cloning.
188194 if (isa<memref::DeallocOp>(user)) {
189195 continue ;
190196 }
191- LDBG () << " Cannot clone view ops: found use outside loop: " << *user;
197+ LDBG () << " Cannot clone ops: found use outside loop: " << *user;
192198 return failure ();
193199 }
194200 return success ();
@@ -198,14 +204,21 @@ collectViewOpsToClone(memref::AllocOp alloc, scf::ForOp forOp) {
198204 return failure ();
199205 }
200206
201- for (Operation *op : viewOpsToClone) {
202- auto viewOp = cast<ViewLikeOpInterface>(op);
203- if (failed (validateUses (viewOp.getViewDest ()))) {
207+ for (Operation *op : opsToClone) {
208+ Value dest;
209+ if (auto viewOp = dyn_cast<ViewLikeOpInterface>(op)) {
210+ dest = viewOp.getViewDest ();
211+ } else if (auto hint = dyn_cast<IREE::Codegen::SwizzleHintOp>(op)) {
212+ dest = hint.getResult ();
213+ } else {
214+ return failure ();
215+ }
216+ if (failed (validateUses (dest))) {
204217 return failure ();
205218 }
206219 }
207220
208- SmallVector<Operation *> result (viewOpsToClone .begin (), viewOpsToClone .end ());
221+ SmallVector<Operation *> result (opsToClone .begin (), opsToClone .end ());
209222
210223 // Sort in topological order - ops must come after their dependencies
211224 llvm::stable_sort (
@@ -214,23 +227,23 @@ collectViewOpsToClone(memref::AllocOp alloc, scf::ForOp forOp) {
214227 return result;
215228}
216229
217- // / Clone view-like operations inside the loop body.
218- // / This is necessary for multi-buffering to work when view ops are defined
230+ // / Clone view-like ops and swizzle hints inside the loop body.
231+ // / This is necessary for multi-buffering to work when these ops are defined
219232// / outside the target loop but used inside it.
220- static LogicalResult cloneViewOpsInsideLoop (memref::AllocOp alloc,
221- scf::ForOp forOp) {
222- auto viewOpsOr = collectViewOpsToClone (alloc, forOp);
223- if (failed (viewOpsOr )) {
233+ static LogicalResult cloneOpsInsideLoop (memref::AllocOp alloc,
234+ scf::ForOp forOp) {
235+ auto opsOr = collectOpsToClone (alloc, forOp);
236+ if (failed (opsOr )) {
224237 return failure ();
225238 }
226239
227- SmallVector<Operation *> &viewOps = *viewOpsOr ;
228- if (viewOps .empty ()) {
240+ SmallVector<Operation *> &ops = *opsOr ;
241+ if (ops .empty ()) {
229242 return success ();
230243 }
231244
232- LDBG () << " Cloning " << viewOps .size ()
233- << " view ops inside loop for allocation: " << *alloc;
245+ LDBG () << " Cloning " << ops .size ()
246+ << " ops inside loop for allocation: " << *alloc;
234247
235248 // Create clones at the beginning of the loop body
236249 Block *loopBody = forOp.getBody ();
@@ -239,7 +252,7 @@ static LogicalResult cloneViewOpsInsideLoop(memref::AllocOp alloc,
239252
240253 IRMapping mapping;
241254 SmallVector<Operation *> opsToErase;
242- for (Operation *op : viewOps ) {
255+ for (Operation *op : ops ) {
243256 Operation *clone = builder.clone (*op, mapping);
244257 LDBG () << " Cloned: " << *op << " -> " << *clone;
245258
@@ -265,6 +278,103 @@ static LogicalResult cloneViewOpsInsideLoop(memref::AllocOp alloc,
265278 return success ();
266279}
267280
281+ // / memref::multiBuffer propagates type changes through a set of known view-like
282+ // / ops (subview, expand_shape, etc.). SwizzleHintOp is not in that set, so fix
283+ // / up the result type of a single hint and its downstream ExpandShapeOp chain.
284+ static void propagateTypeFromMultiBuffer (IREE::Codegen::SwizzleHintOp hint) {
285+ if (hint.getOperand ().getType () != hint.getResult ().getType ()) {
286+ hint.getResult ().setType (hint.getOperand ().getType ());
287+ }
288+ // Propagate the layout change through the chain of ExpandShapeOps
289+ // downstream of the hint.
290+ SmallVector<Value> worklist = {hint.getResult ()};
291+ while (!worklist.empty ()) {
292+ Value current = worklist.pop_back_val ();
293+ for (OpOperand &use : current.getUses ()) {
294+ auto expandOp = dyn_cast<memref::ExpandShapeOp>(use.getOwner ());
295+ if (!expandOp) {
296+ continue ;
297+ }
298+ auto srcType = cast<MemRefType>(expandOp.getSrc ().getType ());
299+ MemRefType resultType = expandOp.getResultType ();
300+ if (srcType.getLayout () == resultType.getLayout ()) {
301+ continue ;
302+ }
303+ FailureOr<MemRefType> newResultType =
304+ memref::ExpandShapeOp::computeExpandedType (
305+ srcType, resultType.getShape (),
306+ expandOp.getReassociationIndices ());
307+ if (failed (newResultType)) {
308+ continue ;
309+ }
310+ expandOp.getResult ().setType (*newResultType);
311+ worklist.push_back (expandOp.getResult ());
312+ }
313+ }
314+ }
315+
316+ // / After pipelining, the write path retains swizzle_hint but the read path
317+ // / does not. Clone swizzle_hint onto read-side iter_args and loop results.
318+ static void cloneSwizzleHint (scf::ForOp forOp) {
319+ Block *body = forOp.getBody ();
320+ auto yieldOp = cast<scf::YieldOp>(body->getTerminator ());
321+ int numOperands = yieldOp.getNumOperands ();
322+
323+ OpBuilder builder (forOp.getContext ());
324+
325+ // Reverse order because the pipeliner appends iter_args oldest-to-newest.
326+ // The newest slot corresponds to the write path.
327+ for (int idx = numOperands - 1 ; idx >= 0 ; --idx) {
328+ Value yieldVal = yieldOp.getOperand (idx);
329+
330+ // Check if the yield operand traces through expand_shape -> swizzle_hint.
331+ auto expandOp = yieldVal.getDefiningOp <memref::ExpandShapeOp>();
332+ if (!expandOp) {
333+ continue ;
334+ }
335+ auto hintOp =
336+ expandOp.getSrc ().getDefiningOp <IREE::Codegen::SwizzleHintOp>();
337+ if (!hintOp) {
338+ continue ;
339+ }
340+
341+ BlockArgument iterArg = forOp.getRegionIterArg (idx);
342+ auto iterArgType = cast<MemRefType>(iterArg.getType ());
343+ SmallVector<ReassociationIndices> reassoc =
344+ expandOp.getReassociationIndices ();
345+
346+ FailureOr<MemRefType> flatType =
347+ memref::CollapseShapeOp::computeCollapsedType (iterArgType, reassoc);
348+ if (failed (flatType)) {
349+ continue ;
350+ }
351+
352+ auto swizzleAttr = hintOp.getSwizzle ();
353+ Location loc = hintOp.getLoc ();
354+
355+ LDBG () << " Cloning swizzle_hint onto iter_arg #" << idx << " : " << iterArg;
356+
357+ // Insert collapse_shape -> swizzle_hint -> expand_shape.
358+ auto insertSwizzleHint = [&](Value value) {
359+ auto collapse = memref::CollapseShapeOp::create (builder, loc, *flatType,
360+ value, reassoc);
361+ auto hint = IREE::Codegen::SwizzleHintOp::create (
362+ builder, loc, collapse.getResult (), swizzleAttr);
363+ auto expand = memref::ExpandShapeOp::create (builder, loc, iterArgType,
364+ hint.getResult (), reassoc);
365+ value.replaceAllUsesExcept (expand.getResult (), collapse.getOperation ());
366+ };
367+
368+ // Clone for reads inside the loop body.
369+ builder.setInsertionPointToStart (body);
370+ insertSwizzleHint (iterArg);
371+
372+ // Clone for reads in the epilogue.
373+ builder.setInsertionPointAfter (forOp);
374+ insertSwizzleHint (forOp.getResult (idx));
375+ }
376+ }
377+
268378// / Multi-buffer LDS allocations used by gather_to_lds operations.
269379// / This enables double-buffering for pipelined async copies.
270380static LogicalResult multiBufferLDSAllocations (scf::ForOp forOp,
@@ -289,7 +399,7 @@ static LogicalResult multiBufferLDSAllocations(scf::ForOp forOp,
289399
290400 // First, clone view ops inside the loop for each allocation
291401 for (memref::AllocOp alloc : sharedAllocs) {
292- if (failed (cloneViewOpsInsideLoop (alloc, forOp))) {
402+ if (failed (cloneOpsInsideLoop (alloc, forOp))) {
293403 LDBG () << " Failed to clone view ops for: " << *alloc;
294404 return failure ();
295405 }
@@ -307,6 +417,9 @@ static LogicalResult multiBufferLDSAllocations(scf::ForOp forOp,
307417 << " buffers at " << loc;
308418 }
309419
420+ // Fix up types for swizzle hints after multi-buffering.
421+ forOp->walk (propagateTypeFromMultiBuffer);
422+
310423 return success ();
311424}
312425
@@ -1316,6 +1429,9 @@ FailureOr<scf::ForOp> prefetchSharedMemoryCopy(RewriterBase &rewriter,
13161429 // Insert barriers using the appropriate strategy for each mode.
13171430 insertPipelineBarriers (rewriter, newForOp, mode);
13181431
1432+ // If swizzle_hint was applied, fix it by cloning onto the read-side.
1433+ cloneSwizzleHint (newForOp);
1434+
13191435 // For async copy mode, convert gather_to_lds to async and insert explicit
13201436 // async markers (asyncmark + wait.asyncmark). This replaces the backend's
13211437 // alias-analysis-based vmcnt insertion with precise explicit synchronization,
0 commit comments