Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"

#include "llvm/Support/DebugLog.h"
Expand Down Expand Up @@ -157,6 +158,76 @@ class TileSizes {
return result;
}

/// Transform tile sizes from the source (unpacked, rank-N) space of a pack
/// op into its destination (packed, rank N+K) space: divide every packed
/// dimension by its inner tile size, apply `outerDimsPerm`, then append the
/// inner tile sizes as new trailing dimensions. Returns an empty TileSizes
/// when the mapping is not representable (a packed dim is uninitialized or
/// overdefined, or is not divisible by its inner tile size).
TileSizes mapPackSourceToDest(ArrayRef<int64_t> innerDimsPos,
                              ArrayRef<int64_t> staticInnerTiles,
                              ArrayRef<int64_t> outerDimsPerm) const {
  if (empty()) {
    return {};
  }
  TileSizes mapped = *this;
  for (auto [pos, tile] : llvm::zip_equal(innerDimsPos, staticInnerTiles)) {
    int64_t &outerSize = mapped.dims[pos];
    // Sentinels cannot be divided; bail out on them and on any size that
    // the inner tile does not evenly divide.
    if (outerSize == kUninitialized || outerSize == kOverdefined ||
        outerSize % tile != 0) {
      return {};
    }
    outerSize /= tile;
  }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector(mapped.dims, outerDimsPerm);
  }
  llvm::append_range(mapped.dims, staticInnerTiles);
  return mapped;
}

/// Return a higher-rank copy of these tile sizes with `suffix` appended.
TileSizes extend(ArrayRef<int64_t> suffix) const {
  SmallVector<int64_t> extended(dims.begin(), dims.end());
  extended.insert(extended.end(), suffix.begin(), suffix.end());
  return TileSizes(extended);
}

/// Transform tile sizes from the destination (packed, rank N+K) space of a
/// pack op back into its source (unpacked, rank-N) space: keep only the
/// outer dims, undo `outerDimsPerm`, then multiply every packed dimension
/// by its inner tile size. Uninitialized and overdefined dims are left
/// untouched — unlike the division in mapPackSourceToDest, multiplication
/// never needs to bail out on the sentinels, so they are simply preserved.
TileSizes mapPackDestToSource(unsigned sourceRank,
                              ArrayRef<int64_t> innerDimsPos,
                              ArrayRef<int64_t> staticInnerTiles,
                              ArrayRef<int64_t> outerDimsPerm) const {
  if (empty()) {
    return {};
  }
  assert(sourceRank <= rank() && "sourceRank exceeds dest tile size rank");
  // Truncate to the outer dimensions of the packed space.
  TileSizes mapped(sourceRank);
  llvm::copy(ArrayRef<int64_t>(dims).take_front(sourceRank),
             mapped.dims.begin());
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector(mapped.dims,
                             invertPermutationVector(outerDimsPerm));
  }
  for (auto [pos, tile] : llvm::zip_equal(innerDimsPos, staticInnerTiles)) {
    int64_t &sz = mapped.dims[pos];
    if (sz != kUninitialized && sz != kOverdefined) {
      sz *= tile;
    }
  }
  return mapped;
}

/// Lattice join: per-dimension merge. Uninitialized is identity.
static TileSizes join(const TileSizes &lhs, const TileSizes &rhs) {
TileSizes result = lhs;
Expand Down Expand Up @@ -304,6 +375,66 @@ class TileSizeForwardAnalysis
return success();
}

// InnerTiledOp: propagate through indexing maps (outer dims only).
if (auto innerTiledOp = dyn_cast<IREE::Codegen::InnerTiledOp>(op)) {
SmallVector<AffineMap> indexingMaps = innerTiledOp.getIndexingMapsArray();
unsigned numLoops = indexingMaps[0].getNumDims();
TileSizes iterTileSizes(numLoops);

// Merge tile sizes from all operands into iteration space.
// mapToIterationSpace reads only the first getNumResults() elements
// from the operand TileSizes, naturally skipping inner dims.
for (OpOperand &operand : op->getOpOperands()) {
TileSizes opTileSizes = getTileSizesFor(
operand.get(), operands[operand.getOperandNumber()]);
AffineMap map = indexingMaps[operand.getOperandNumber()];
iterTileSizes.merge(opTileSizes.mapToIterationSpace(map));
}

// Propagate to results. Results correspond to outputs.
unsigned numInputs = innerTiledOp.getNumInputs();
for (unsigned i = 0; i < innerTiledOp.getNumOutputs(); ++i) {
AffineMap map = indexingMaps[numInputs + i];
TileSizes outerTileSizes = iterTileSizes.mapFromIterationSpace(map);
if (outerTileSizes.empty()) {
continue;
}
// Extend with static inner dims to match full operand rank.
ArrayRef<int64_t> innerShape =
innerTiledOp.getOperandInnerShape(numInputs + i);
TileSizes fullTileSizes = outerTileSizes.extend(innerShape);
propagateIfChanged(results[i], results[i]->join(fullTileSizes));
}
return success();
}

// Pack ops: map source tile sizes to dest space.
if (auto packOp = dyn_cast<linalg::PackOp>(op)) {
if (llvm::any_of(packOp.getStaticInnerTiles(), ShapedType::isDynamic)) {
return success();
}
TileSizes srcTileSizes = getTileSizesFor(packOp.getSource(), operands[0]);
TileSizes destTileSizes = srcTileSizes.mapPackSourceToDest(
packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
packOp.getOuterDimsPerm());
propagateIfChanged(results[0], results[0]->join(destTileSizes));
return success();
}

// Unpack ops: map source (packed) tile sizes to dest (unpacked) space.
if (auto unpackOp = dyn_cast<linalg::UnPackOp>(op)) {
if (llvm::any_of(unpackOp.getStaticInnerTiles(), ShapedType::isDynamic)) {
return success();
}
TileSizes srcTileSizes =
getTileSizesFor(unpackOp.getSource(), operands[0]);
TileSizes destTileSizes = srcTileSizes.mapPackDestToSource(
unpackOp.getDestRank(), unpackOp.getInnerDimsPos(),
unpackOp.getStaticInnerTiles(), unpackOp.getOuterDimsPerm());
propagateIfChanged(results[0], results[0]->join(destTileSizes));
return success();
}

return success();
}
};
Expand Down Expand Up @@ -365,6 +496,78 @@ class TileSizeBackwardAnalysis
return success();
}

// InnerTiledOp: propagate backward through indexing maps.
if (auto innerTiledOp = dyn_cast<IREE::Codegen::InnerTiledOp>(op)) {
SmallVector<AffineMap> indexingMaps = innerTiledOp.getIndexingMapsArray();
unsigned numLoops = indexingMaps[0].getNumDims();
unsigned numInputs = innerTiledOp.getNumInputs();
TileSizes iterTileSizes(numLoops);

// Gather from results (full-rank lattice, mapToIterationSpace skips
// inner dims naturally).
for (auto [result, resultLattice] :
llvm::zip_equal(op->getResults(), results)) {
unsigned resultIdx = cast<OpResult>(result).getResultNumber();
AffineMap map = indexingMaps[numInputs + resultIdx];
TileSizes tileSizes = getTileSizesFor(result, resultLattice);
iterTileSizes.merge(tileSizes.mapToIterationSpace(map));
}

// Gather from operands.
for (OpOperand &operand : op->getOpOperands()) {
TileSizes opTileSizes = getTileSizesFor(
operand.get(), operands[operand.getOperandNumber()]);
AffineMap map = indexingMaps[operand.getOperandNumber()];
iterTileSizes.merge(opTileSizes.mapToIterationSpace(map));
}

// Propagate back to each operand. Extend outer-rank result with inner
// dims to match full operand rank.
for (OpOperand &operand : op->getOpOperands()) {
unsigned idx = operand.getOperandNumber();
AffineMap map = indexingMaps[idx];
TileSizes outerTileSizes = iterTileSizes.mapFromIterationSpace(map);
if (outerTileSizes.empty()) {
continue;
}
ArrayRef<int64_t> innerShape = innerTiledOp.getOperandInnerShape(idx);
TileSizes fullTileSizes = outerTileSizes.extend(innerShape);
TileSizeLattice *opLattice = operands[idx];
propagateIfChanged(opLattice, opLattice->meet(fullTileSizes));
}
return success();
}

// Pack ops: result tile sizes → source tile sizes (backward).
if (auto packOp = dyn_cast<linalg::PackOp>(op)) {
if (llvm::any_of(packOp.getStaticInnerTiles(), ShapedType::isDynamic)) {
return success();
}
TileSizes resultTileSizes =
getTileSizesFor(packOp->getResult(0), results[0]);
TileSizes srcTileSizes = resultTileSizes.mapPackDestToSource(
packOp.getSourceRank(), packOp.getInnerDimsPos(),
packOp.getStaticInnerTiles(), packOp.getOuterDimsPerm());
TileSizeLattice *srcLattice = operands[0];
propagateIfChanged(srcLattice, srcLattice->meet(srcTileSizes));
return success();
}

// Unpack ops: result tile sizes → source tile sizes (backward).
if (auto unpackOp = dyn_cast<linalg::UnPackOp>(op)) {
if (llvm::any_of(unpackOp.getStaticInnerTiles(), ShapedType::isDynamic)) {
return success();
}
TileSizes resultTileSizes =
getTileSizesFor(unpackOp->getResult(0), results[0]);
TileSizes srcTileSizes = resultTileSizes.mapPackSourceToDest(
unpackOp.getInnerDimsPos(), unpackOp.getStaticInnerTiles(),
unpackOp.getOuterDimsPerm());
TileSizeLattice *srcLattice = operands[0];
propagateIfChanged(srcLattice, srcLattice->meet(srcTileSizes));
return success();
}

return success();
}

Expand All @@ -383,20 +586,20 @@ class TileSizeBackwardAnalysis
// Result querying
//===----------------------------------------------------------------------===//

/// Gather tile sizes into the iteration space of a linalg op by looking up each
/// Gather tile sizes into the iteration space of an op by looking up each
/// operand's lattice state in the solver.
static TileSizes getIterationSpaceTileSizes(linalg::LinalgOp linalgOp,
static TileSizes getIterationSpaceTileSizes(Operation *op, unsigned numLoops,
ArrayRef<AffineMap> indexingMaps,
const DataFlowSolver &solver) {
unsigned numLoops = linalgOp.getNumLoops();
TileSizes iterTileSizes(numLoops);
for (OpOperand &operand : linalgOp->getOpOperands()) {
for (OpOperand &operand : op->getOpOperands()) {
Value val = operand.get();
const TileSizeLattice *lattice = solver.lookupState<TileSizeLattice>(val);
TileSizes tileSize = getTileSizesFor(val, lattice);
if (tileSize.empty()) {
continue;
}
AffineMap map = linalgOp.getMatchingIndexingMap(&operand);
AffineMap map = indexingMaps[operand.getOperandNumber()];
assert(map.getNumDims() == numLoops);
iterTileSizes.merge(tileSize.mapToIterationSpace(map));
}
Expand Down Expand Up @@ -426,22 +629,61 @@ class MaterializeVectorTileSizesPass final
return signalPassFailure();
}

funcOp->walk([&](linalg::LinalgOp linalgOp) {
TileSizes tileSizes = getIterationSpaceTileSizes(linalgOp, solver);
if (tileSizes.isOverdefined()) {
linalgOp.emitOpError()
<< "tile size analysis did not determine a valid tile size";
auto materialize =
[&](Operation *op, TileSizes tileSizes) {
if (tileSizes.isOverdefined()) {
op->emitOpError()
<< "tile size analysis did not determine a valid tile size";
return;
}
if (!tileSizes.isDefined()) {
LDBG() << "Analysis did not determine tile size for " << *op;
return;
}
op->setAttr(
kVectorTileSizesAttrName,
DenseI64ArrayAttr::get(op->getContext(), tileSizes.getDims()));
};

funcOp->walk([&](Operation *op) {
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
TileSizes tileSizes = getIterationSpaceTileSizes(
op, linalgOp.getNumLoops(), indexingMaps, solver);
assert(!tileSizes.isDefined() ||
tileSizes.rank() == linalgOp.getNumLoops());
materialize(op, tileSizes);
return;
}
if (!tileSizes.isDefined()) {
LDBG() << "Analysis did not determine tile size for " << *linalgOp;

if (auto innerTiledOp = dyn_cast<IREE::Codegen::InnerTiledOp>(op)) {
SmallVector<AffineMap> indexingMaps =
innerTiledOp.getIndexingMapsArray();
unsigned numLoops = indexingMaps[0].getNumDims();
TileSizes tileSizes =
getIterationSpaceTileSizes(op, numLoops, indexingMaps, solver);
materialize(op, tileSizes);
return;
}
assert(tileSizes.rank() == linalgOp.getNumLoops());

linalgOp->setAttr(
kVectorTileSizesAttrName,
DenseI64ArrayAttr::get(linalgOp->getContext(), tileSizes.getDims()));
// linalg.pack and linalg.unpack have an unpacked (rank N) and a packed
// (rank N + K) domain. linalg.pack converts from the unpacked domain to
// the packed domain, linalg.unpack works the other way round.
// Vectorization of the operations expects vector sizes in the packed
// domain. After analysis, these are available on operand of linalg.pack
// and the result of linalg.unpack, respectively.
Value packedVal;
if (auto packOp = dyn_cast<linalg::PackOp>(op)) {
packedVal = packOp.getResult();
} else if (auto unpackOp = dyn_cast<linalg::UnPackOp>(op)) {
packedVal = unpackOp.getSource();
} else {
return;
}
const TileSizeLattice *lattice =
solver.lookupState<TileSizeLattice>(packedVal);
TileSizes tileSizes = getTileSizesFor(packedVal, lattice);
materialize(op, tileSizes);
});
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -873,3 +873,45 @@ func.func @scaled_tensor_multi_mma(%arg0: tensor<3x5x1x32xf4E2M1FN>, %arg1: tens
// CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS]], %[[RHS]], %[[LHS_SCALE]], %[[RHS_SCALE]]) outs(%[[ACC]])
// CHECK-SAME: : vector<3x5x1x32xf4E2M1FN>, vector<5x1x7x32xf8E4M3FN>, vector<3x5x1xf8E8M0FNU>, vector<5x7x1xf8E8M0FNU> into vector<3x7x4xf32>
// CHECK: vector.transfer_write %[[MMA]], %arg4[%c0, %c0, %c0] {{.*}} : vector<3x7x4xf32>, tensor<3x7x4xf32>

// -----

// Masked vectorization of inner_tiled with dynamic outer dimensions.
// Vector sizes come from the iree_codegen.vector_tile_sizes attribute.
// The LHS has shape <?x?x4xf16> (outer dims dynamic), iteration space is
// [i=2, j=5, k=3], so reads should be masked.
#contraction_accesses = [
affine_map<(i, j, k) -> (i, k)>,
affine_map<(i, j, k) -> (k, j)>,
affine_map<(i, j, k) -> (i, j)>
]
// Contraction over (i, j, k) with dynamic outer dims and a static inner dim
// of 4; iree_codegen.vector_tile_sizes = (2, 5, 3) supplies the
// iteration-space vector sizes (presumably used to size masked vectorization
// of the dynamic dims — see the CHECK expectations below).
func.func @masked_tensor_multi_mma(%lhs: tensor<?x?x4xf16>, %rhs: tensor<?x?x4xf16>, %acc: tensor<?x?x4xf32>) -> tensor<?x?x4xf32> {
%0 = iree_codegen.inner_tiled ins(%lhs, %rhs) outs(%acc) {
iree_codegen.vector_tile_sizes = array<i64: 2, 5, 3>,
indexing_maps = #contraction_accesses,
iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
semantics = #iree_gpu.mma_semantics<distributed = true, opaque = false>
} : tensor<?x?x4xf16>, tensor<?x?x4xf16> into tensor<?x?x4xf32>
return %0 : tensor<?x?x4xf32>
}

// CHECK-LABEL: func @masked_tensor_multi_mma

// With vectorSizes, reads use create_mask for dynamic dims.
// LHS: outer (i=2, k=3) + inner (4) → vector<2x3x4xf16>
// RHS: outer (k=3, j=5) + inner (4) → vector<3x5x4xf16>
// ACC: outer (i=2, j=5) + inner (4) → vector<2x5x4xf32>
// CHECK-DAG: %[[CSTF16:.+]] = arith.constant 0.000000e+00 : f16
// CHECK-DAG: %[[CSTF32:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[LHS_MASK:.+]] = vector.create_mask {{.*}} : vector<2x3x4xi1>
// CHECK: %[[LHS:.+]] = vector.transfer_read %arg0{{.*}}, %[[CSTF16]], %[[LHS_MASK]]{{.*}} : tensor<?x?x4xf16>, vector<2x3x4xf16>
// CHECK: %[[RHS_MASK:.+]] = vector.create_mask {{.*}} : vector<3x5x4xi1>
// CHECK: %[[RHS:.+]] = vector.transfer_read %arg1{{.*}}, %[[CSTF16]], %[[RHS_MASK]]{{.*}} : tensor<?x?x4xf16>, vector<3x5x4xf16>
// CHECK: %[[ACC_MASK:.+]] = vector.create_mask {{.*}} : vector<2x5x4xi1>
// CHECK: %[[ACC:.+]] = vector.transfer_read %arg2{{.*}}, %[[CSTF32]], %[[ACC_MASK]]{{.*}} : tensor<?x?x4xf32>, vector<2x5x4xf32>
// CHECK: %[[MMA:.+]] = iree_codegen.inner_tiled ins(%[[LHS]], %[[RHS]]) outs(%[[ACC]])
// CHECK-SAME: : vector<2x3x4xf16>, vector<3x5x4xf16> into vector<2x5x4xf32>
// CHECK: %[[WRITE_MASK:.+]] = vector.create_mask {{.*}} : vector<2x5x4xi1>
// CHECK: vector.transfer_write %[[MMA]], %arg2{{.*}}, %[[WRITE_MASK]] {in_bounds = [false, false, true]}
// CHECK-SAME: : vector<2x5x4xf32>, tensor<?x?x4xf32>
Loading
Loading