Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"

#include "llvm/Support/DebugLog.h"
Expand Down Expand Up @@ -157,6 +158,70 @@ class TileSizes {
return result;
}

/// Translate tile sizes from the source (rank-N) space of a linalg.pack to
/// its destination (rank N+K) space: each packed dimension is ceil-divided
/// by its inner tile size, the optional outer_dims_perm is applied to the
/// outer dims, and the inner tile sizes themselves are appended as the new
/// trailing dimensions.
TileSizes mapPackSourceToDest(ArrayRef<int64_t> innerDimsPos,
                              ArrayRef<int64_t> staticInnerTiles,
                              ArrayRef<int64_t> outerDimsPerm) const {
  if (empty()) {
    return {};
  }
  TileSizes mapped = *this;
  for (auto [pos, tile] : llvm::zip_equal(innerDimsPos, staticInnerTiles)) {
    int64_t &dim = mapped.dims[pos];
    // Uninitialized and overdefined lattice values pass through unchanged.
    if (dim != kUninitialized && dim != kOverdefined) {
      dim = llvm::divideCeilSigned(dim, tile);
    }
  }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector(mapped.dims, outerDimsPerm);
  }
  llvm::append_range(mapped.dims, staticInnerTiles);
  return mapped;
}

/// Translate tile sizes from the destination (rank N+K) space of a
/// linalg.pack back to its source (rank-N) space: keep only the leading
/// outer dims, undo outer_dims_perm via its inverse, and multiply each
/// packed dimension by its inner tile size. Uninitialized and overdefined
/// lattice values pass through unchanged.
TileSizes mapPackDestToSource(unsigned sourceRank,
                              ArrayRef<int64_t> innerDimsPos,
                              ArrayRef<int64_t> staticInnerTiles,
                              ArrayRef<int64_t> outerDimsPerm) const {
  if (empty()) {
    return {};
  }
  assert(sourceRank <= rank() && "sourceRank exceeds dest tile size rank");
  // Truncate to the outer (source-rank) dimensions.
  TileSizes mapped(sourceRank);
  llvm::copy(ArrayRef<int64_t>(dims).take_front(sourceRank),
             mapped.dims.begin());
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector(mapped.dims,
                             invertPermutationVector(outerDimsPerm));
  }
  for (auto [pos, tile] : llvm::zip_equal(innerDimsPos, staticInnerTiles)) {
    int64_t &dim = mapped.dims[pos];
    if (dim != kUninitialized && dim != kOverdefined) {
      dim *= tile;
    }
  }
  return mapped;
}

/// Return a copy of these tile sizes extended with the dimensions in
/// `suffix`, yielding a higher-rank TileSizes.
TileSizes append(ArrayRef<int64_t> suffix) const {
  SmallVector<int64_t> extended(dims.begin(), dims.end());
  llvm::append_range(extended, suffix);
  return TileSizes(extended);
}

/// Lattice join: per-dimension merge. Uninitialized is identity.
static TileSizes join(const TileSizes &lhs, const TileSizes &rhs) {
TileSizes result = lhs;
Expand Down Expand Up @@ -304,6 +369,66 @@ class TileSizeForwardAnalysis
return success();
}

// InnerTiledOp: propagate through indexing maps (outer dims only).
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we exclude inner dims? We will need it when people switch CPU data-tiling to inner_tiles and that is required by scalable vectors and future RVV work.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment isn't phrased very well. We don't exclude the inner dims (we append them before propagating, see below), but they don't participate in the propagation through indexing maps (because they don't need to). I've clarified the comment. We also don't materialize the inner dims in the attribute on inner_tiled, because it would duplicate the information.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've clarified the comment.

I don't see the update to the comment. Did you perhaps miss pushing the change?

if (auto innerTiledOp = dyn_cast<IREE::Codegen::InnerTiledOp>(op)) {
SmallVector<AffineMap> indexingMaps = innerTiledOp.getIndexingMapsArray();
unsigned numLoops = indexingMaps[0].getNumDims();
TileSizes iterTileSizes(numLoops);

// Merge tile sizes from all operands into iteration space.
// mapToIterationSpace reads only the first getNumResults() elements
// from the operand TileSizes, naturally skipping inner dims.
for (OpOperand &operand : op->getOpOperands()) {
TileSizes opTileSizes = getTileSizesFor(
operand.get(), operands[operand.getOperandNumber()]);
AffineMap map = indexingMaps[operand.getOperandNumber()];
iterTileSizes.merge(opTileSizes.mapToIterationSpace(map));
}

// Propagate to results. Results correspond to outputs.
unsigned numInputs = innerTiledOp.getNumInputs();
for (unsigned i = 0; i < innerTiledOp.getNumOutputs(); ++i) {
AffineMap map = indexingMaps[numInputs + i];
TileSizes outerTileSizes = iterTileSizes.mapFromIterationSpace(map);
if (outerTileSizes.empty()) {
continue;
}
// Extend with static inner dims to match full operand rank.
ArrayRef<int64_t> innerShape =
innerTiledOp.getOperandInnerShape(numInputs + i);
TileSizes fullTileSizes = outerTileSizes.append(innerShape);
propagateIfChanged(results[i], results[i]->join(fullTileSizes));
}
return success();
}

// Pack ops: map source tile sizes to dest space.
if (auto packOp = dyn_cast<linalg::PackOp>(op)) {
if (ShapedType::isDynamicShape(packOp.getStaticInnerTiles())) {
return success();
}
TileSizes srcTileSizes = getTileSizesFor(packOp.getSource(), operands[0]);
TileSizes destTileSizes = srcTileSizes.mapPackSourceToDest(
packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
packOp.getOuterDimsPerm());
propagateIfChanged(results[0], results[0]->join(destTileSizes));
return success();
}

// Unpack ops: map source (packed) tile sizes to dest (unpacked) space.
if (auto unpackOp = dyn_cast<linalg::UnPackOp>(op)) {
if (ShapedType::isDynamicShape(unpackOp.getStaticInnerTiles())) {
return success();
}
TileSizes srcTileSizes =
getTileSizesFor(unpackOp.getSource(), operands[0]);
TileSizes destTileSizes = srcTileSizes.mapPackDestToSource(
unpackOp.getDestRank(), unpackOp.getInnerDimsPos(),
unpackOp.getStaticInnerTiles(), unpackOp.getOuterDimsPerm());
propagateIfChanged(results[0], results[0]->join(destTileSizes));
return success();
}

return success();
}
};
Expand Down Expand Up @@ -365,6 +490,83 @@ class TileSizeBackwardAnalysis
return success();
}

// InnerTiledOp: propagate backward through indexing maps.
if (auto innerTiledOp = dyn_cast<IREE::Codegen::InnerTiledOp>(op)) {
SmallVector<AffineMap> indexingMaps = innerTiledOp.getIndexingMapsArray();
unsigned numLoops = indexingMaps[0].getNumDims();
unsigned numInputs = innerTiledOp.getNumInputs();
TileSizes iterTileSizes(numLoops);

// Gather from results (full-rank lattice, mapToIterationSpace skips
// inner dims naturally).
for (auto [result, resultLattice] :
llvm::zip_equal(op->getResults(), results)) {
unsigned resultIdx = cast<OpResult>(result).getResultNumber();
AffineMap map = indexingMaps[numInputs + resultIdx];
TileSizes tileSizes = getTileSizesFor(result, resultLattice);
iterTileSizes.merge(tileSizes.mapToIterationSpace(map));
}

// Gather from operands.
for (OpOperand &operand : op->getOpOperands()) {
TileSizes opTileSizes = getTileSizesFor(
operand.get(), operands[operand.getOperandNumber()]);
AffineMap map = indexingMaps[operand.getOperandNumber()];
iterTileSizes.merge(opTileSizes.mapToIterationSpace(map));
}

// Propagate back to each operand. Extend outer-rank result with inner
// dims to match full operand rank.
for (OpOperand &operand : op->getOpOperands()) {
unsigned idx = operand.getOperandNumber();
AffineMap map = indexingMaps[idx];
TileSizes outerTileSizes = iterTileSizes.mapFromIterationSpace(map);
if (outerTileSizes.empty()) {
continue;
}
ArrayRef<int64_t> innerShape = innerTiledOp.getOperandInnerShape(idx);
TileSizes fullTileSizes = outerTileSizes.append(innerShape);
TileSizeLattice *opLattice = operands[idx];
propagateIfChanged(opLattice, opLattice->meet(fullTileSizes));
}
return success();
}

// Pack ops: result tile sizes → source tile sizes (backward).
if (auto packOp = dyn_cast<linalg::PackOp>(op)) {
if (ShapedType::isDynamicShape(packOp.getStaticInnerTiles())) {
return success();
}
if (packOp.getPaddingValue()) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have a test for padding value?

// We do not backward propagate for pack with padding, as it would
// potentially propagate too large tile sizes.
return success();
}
TileSizes resultTileSizes =
getTileSizesFor(packOp->getResult(0), results[0]);
TileSizes srcTileSizes = resultTileSizes.mapPackDestToSource(
packOp.getSourceRank(), packOp.getInnerDimsPos(),
packOp.getStaticInnerTiles(), packOp.getOuterDimsPerm());
TileSizeLattice *srcLattice = operands[0];
propagateIfChanged(srcLattice, srcLattice->meet(srcTileSizes));
return success();
}

// Unpack ops: result tile sizes → source tile sizes (backward).
if (auto unpackOp = dyn_cast<linalg::UnPackOp>(op)) {
if (ShapedType::isDynamicShape(unpackOp.getStaticInnerTiles())) {
return success();
}
TileSizes resultTileSizes =
getTileSizesFor(unpackOp->getResult(0), results[0]);
TileSizes srcTileSizes = resultTileSizes.mapPackSourceToDest(
unpackOp.getInnerDimsPos(), unpackOp.getStaticInnerTiles(),
unpackOp.getOuterDimsPerm());
TileSizeLattice *srcLattice = operands[0];
propagateIfChanged(srcLattice, srcLattice->meet(srcTileSizes));
return success();
}

return success();
}

Expand All @@ -383,20 +585,20 @@ class TileSizeBackwardAnalysis
// Result querying
//===----------------------------------------------------------------------===//

/// Gather tile sizes into the iteration space of a linalg op by looking up each
/// Gather tile sizes into the iteration space of an op by looking up each
/// operand's lattice state in the solver.
static TileSizes getIterationSpaceTileSizes(linalg::LinalgOp linalgOp,
static TileSizes getIterationSpaceTileSizes(Operation *op, unsigned numLoops,
ArrayRef<AffineMap> indexingMaps,
const DataFlowSolver &solver) {
Comment on lines +590 to 592
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel that you want to pass in TilingInterface op which has the additional information.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, TilingInterface doesn't give us access to the indexing maps, so we won't gain much here.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I missed that. I was going to suggest IndexingMapOpInterface, but then I found that InnerTiledOp hasn't implemented that interface yet, so this is okay for now.

unsigned numLoops = linalgOp.getNumLoops();
TileSizes iterTileSizes(numLoops);
for (OpOperand &operand : linalgOp->getOpOperands()) {
for (OpOperand &operand : op->getOpOperands()) {
Value val = operand.get();
const TileSizeLattice *lattice = solver.lookupState<TileSizeLattice>(val);
TileSizes tileSize = getTileSizesFor(val, lattice);
if (tileSize.empty()) {
continue;
}
AffineMap map = linalgOp.getMatchingIndexingMap(&operand);
AffineMap map = indexingMaps[operand.getOperandNumber()];
assert(map.getNumDims() == numLoops);
iterTileSizes.merge(tileSize.mapToIterationSpace(map));
}
Expand Down Expand Up @@ -426,23 +628,63 @@ class MaterializeVectorTileSizesPass final
return signalPassFailure();
}

funcOp->walk([&](linalg::LinalgOp linalgOp) {
TileSizes tileSizes = getIterationSpaceTileSizes(linalgOp, solver);
auto materialize = [](Operation *op, TileSizes tileSizes) -> LogicalResult {
if (tileSizes.isOverdefined()) {
linalgOp.emitOpError()
op->emitOpError()
<< "tile size analysis did not determine a valid tile size";
return;
return failure();
}
if (!tileSizes.isDefined()) {
LDBG() << "Analysis did not determine tile size for " << *linalgOp;
return;
LDBG() << "Analysis did not determine tile size for " << *op;
return success();
}
assert(tileSizes.rank() == linalgOp.getNumLoops());

linalgOp->setAttr(
op->setAttr(
kVectorTileSizesAttrName,
DenseI64ArrayAttr::get(linalgOp->getContext(), tileSizes.getDims()));
DenseI64ArrayAttr::get(op->getContext(), tileSizes.getDims()));
return success();
};

auto result = funcOp->walk([&](Operation *op) -> WalkResult {
if (isa<linalg::PackOp>(op) || isa<linalg::UnPackOp>(op)) {
// linalg.pack and linalg.unpack have an unpacked (rank N) and a packed
// (rank N + K) domain. linalg.pack converts from the unpacked domain to
// the packed domain, linalg.unpack works the other way round.
// Vectorization of the operations expects vector sizes in the packed
// domain. After analysis, these are available on operand of linalg.pack
// and the result of linalg.unpack, respectively.
Value packedVal;
if (auto packOp = dyn_cast<linalg::PackOp>(op)) {
packedVal = packOp.getResult();
} else if (auto unpackOp = dyn_cast<linalg::UnPackOp>(op)) {
packedVal = unpackOp.getSource();
}
const TileSizeLattice *lattice =
solver.lookupState<TileSizeLattice>(packedVal);
TileSizes tileSizes = getTileSizesFor(packedVal, lattice);
return WalkResult(materialize(op, tileSizes));
}
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
TileSizes tileSizes = getIterationSpaceTileSizes(
op, linalgOp.getNumLoops(), indexingMaps, solver);
assert(!tileSizes.isDefined() ||
tileSizes.rank() == linalgOp.getNumLoops());
return WalkResult(materialize(op, tileSizes));
}

if (auto innerTiledOp = dyn_cast<IREE::Codegen::InnerTiledOp>(op)) {
SmallVector<AffineMap> indexingMaps =
innerTiledOp.getIndexingMapsArray();
unsigned numLoops = indexingMaps[0].getNumDims();
TileSizes tileSizes =
getIterationSpaceTileSizes(op, numLoops, indexingMaps, solver);
return WalkResult(materialize(op, tileSizes));
}
return WalkResult::advance();
});
if (result.wasInterrupted()) {
signalPassFailure();
}
}
};

Expand Down
Loading
Loading