-
Notifications
You must be signed in to change notification settings - Fork 873
[Codegen] Support more operations in tile size analysis #23971
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |
|
|
||
| #include "iree/compiler/Codegen/Common/Passes.h" | ||
| #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" | ||
| #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h" | ||
| #include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h" | ||
|
|
||
| #include "llvm/Support/DebugLog.h" | ||
|
|
@@ -157,6 +158,70 @@ class TileSizes { | |
| return result; | ||
| } | ||
|
|
||
| /// Map tile sizes from pack source space (rank N) to pack dest space | ||
| /// (rank N+K). Divides packed dims by inner tile sizes, applies | ||
| /// outer_dims_perm, and appends inner tile sizes as new dimensions. | ||
| TileSizes mapPackSourceToDest(ArrayRef<int64_t> innerDimsPos, | ||
| ArrayRef<int64_t> staticInnerTiles, | ||
| ArrayRef<int64_t> outerDimsPerm) const { | ||
| if (empty()) { | ||
| return {}; | ||
| } | ||
| TileSizes result = *this; | ||
| for (auto [dimPos, tileSize] : | ||
| llvm::zip_equal(innerDimsPos, staticInnerTiles)) { | ||
| if (result.dims[dimPos] == kUninitialized || | ||
| result.dims[dimPos] == kOverdefined) { | ||
| continue; | ||
| } | ||
| result.dims[dimPos] = | ||
| llvm::divideCeilSigned(result.dims[dimPos], tileSize); | ||
| } | ||
| if (!outerDimsPerm.empty()) { | ||
| applyPermutationToVector(result.dims, outerDimsPerm); | ||
| } | ||
|
Comment on lines
+180
to
+182
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need a test for |
||
| llvm::append_range(result.dims, staticInnerTiles); | ||
| return result; | ||
| } | ||
|
|
||
| /// Map tile sizes from pack dest space (rank N+K) back to source space | ||
| /// (rank N). Truncates to outer dims, applies inverse permutation, and | ||
| /// multiplies packed dims by inner tile sizes. | ||
| TileSizes mapPackDestToSource(unsigned sourceRank, | ||
| ArrayRef<int64_t> innerDimsPos, | ||
| ArrayRef<int64_t> staticInnerTiles, | ||
| ArrayRef<int64_t> outerDimsPerm) const { | ||
| if (empty()) { | ||
| return {}; | ||
| } | ||
| assert(sourceRank <= rank() && "sourceRank exceeds dest tile size rank"); | ||
| TileSizes result(sourceRank); | ||
| for (unsigned i = 0; i < sourceRank; ++i) { | ||
| result.dims[i] = dims[i]; | ||
| } | ||
| if (!outerDimsPerm.empty()) { | ||
| applyPermutationToVector(result.dims, | ||
| invertPermutationVector(outerDimsPerm)); | ||
| } | ||
| for (auto [dimPos, tileSize] : | ||
| llvm::zip_equal(innerDimsPos, staticInnerTiles)) { | ||
| if (result.dims[dimPos] == kUninitialized || | ||
| result.dims[dimPos] == kOverdefined) { | ||
| // Uninitialized and overdefined are preserved. | ||
| continue; | ||
| } | ||
| result.dims[dimPos] *= tileSize; | ||
| } | ||
| return result; | ||
| } | ||
|
|
||
| /// Append dimensions from `suffix` to produce a higher-rank TileSizes. | ||
| TileSizes append(ArrayRef<int64_t> suffix) const { | ||
| SmallVector<int64_t> fullDims(dims); | ||
| fullDims.append(suffix.begin(), suffix.end()); | ||
| return TileSizes(fullDims); | ||
| } | ||
|
|
||
| /// Lattice join: per-dimension merge. Uninitialized is identity. | ||
| static TileSizes join(const TileSizes &lhs, const TileSizes &rhs) { | ||
| TileSizes result = lhs; | ||
|
|
@@ -304,6 +369,66 @@ class TileSizeForwardAnalysis | |
| return success(); | ||
| } | ||
|
|
||
| // InnerTiledOp: propagate through indexing maps (outer dims only). | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we exclude inner dims? We will need it when people switch CPU data-tiling to inner_tiles and that is required by scalable vectors and future RVV work.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The comment isn't phrased very well. We don't exclude the inner dims (we append them before propagating, see below), but they don't participate in the propagation through indexing maps (because they don't need to). I've clarified the comment. We also don't materialize the inner dims in the attribute on
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I don't see the update on the comment. Do you miss the change? |
||
| if (auto innerTiledOp = dyn_cast<IREE::Codegen::InnerTiledOp>(op)) { | ||
| SmallVector<AffineMap> indexingMaps = innerTiledOp.getIndexingMapsArray(); | ||
| unsigned numLoops = indexingMaps[0].getNumDims(); | ||
| TileSizes iterTileSizes(numLoops); | ||
|
|
||
| // Merge tile sizes from all operands into iteration space. | ||
| // mapToIterationSpace reads only the first getNumResults() elements | ||
| // from the operand TileSizes, naturally skipping inner dims. | ||
| for (OpOperand &operand : op->getOpOperands()) { | ||
| TileSizes opTileSizes = getTileSizesFor( | ||
| operand.get(), operands[operand.getOperandNumber()]); | ||
| AffineMap map = indexingMaps[operand.getOperandNumber()]; | ||
| iterTileSizes.merge(opTileSizes.mapToIterationSpace(map)); | ||
| } | ||
|
|
||
| // Propagate to results. Results correspond to outputs. | ||
| unsigned numInputs = innerTiledOp.getNumInputs(); | ||
| for (unsigned i = 0; i < innerTiledOp.getNumOutputs(); ++i) { | ||
| AffineMap map = indexingMaps[numInputs + i]; | ||
| TileSizes outerTileSizes = iterTileSizes.mapFromIterationSpace(map); | ||
| if (outerTileSizes.empty()) { | ||
| continue; | ||
| } | ||
| // Extend with static inner dims to match full operand rank. | ||
| ArrayRef<int64_t> innerShape = | ||
| innerTiledOp.getOperandInnerShape(numInputs + i); | ||
| TileSizes fullTileSizes = outerTileSizes.append(innerShape); | ||
| propagateIfChanged(results[i], results[i]->join(fullTileSizes)); | ||
| } | ||
| return success(); | ||
| } | ||
|
|
||
| // Pack ops: map source tile sizes to dest space. | ||
| if (auto packOp = dyn_cast<linalg::PackOp>(op)) { | ||
| if (ShapedType::isDynamicShape(packOp.getStaticInnerTiles())) { | ||
| return success(); | ||
| } | ||
| TileSizes srcTileSizes = getTileSizesFor(packOp.getSource(), operands[0]); | ||
| TileSizes destTileSizes = srcTileSizes.mapPackSourceToDest( | ||
| packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(), | ||
| packOp.getOuterDimsPerm()); | ||
| propagateIfChanged(results[0], results[0]->join(destTileSizes)); | ||
| return success(); | ||
| } | ||
|
|
||
| // Unpack ops: map source (packed) tile sizes to dest (unpacked) space. | ||
| if (auto unpackOp = dyn_cast<linalg::UnPackOp>(op)) { | ||
| if (ShapedType::isDynamicShape(unpackOp.getStaticInnerTiles())) { | ||
| return success(); | ||
| } | ||
| TileSizes srcTileSizes = | ||
| getTileSizesFor(unpackOp.getSource(), operands[0]); | ||
| TileSizes destTileSizes = srcTileSizes.mapPackDestToSource( | ||
| unpackOp.getDestRank(), unpackOp.getInnerDimsPos(), | ||
| unpackOp.getStaticInnerTiles(), unpackOp.getOuterDimsPerm()); | ||
| propagateIfChanged(results[0], results[0]->join(destTileSizes)); | ||
| return success(); | ||
| } | ||
|
|
||
| return success(); | ||
| } | ||
| }; | ||
|
|
@@ -365,6 +490,83 @@ class TileSizeBackwardAnalysis | |
| return success(); | ||
| } | ||
|
|
||
| // InnerTiledOp: propagate backward through indexing maps. | ||
| if (auto innerTiledOp = dyn_cast<IREE::Codegen::InnerTiledOp>(op)) { | ||
| SmallVector<AffineMap> indexingMaps = innerTiledOp.getIndexingMapsArray(); | ||
| unsigned numLoops = indexingMaps[0].getNumDims(); | ||
| unsigned numInputs = innerTiledOp.getNumInputs(); | ||
| TileSizes iterTileSizes(numLoops); | ||
|
|
||
| // Gather from results (full-rank lattice, mapToIterationSpace skips | ||
| // inner dims naturally). | ||
| for (auto [result, resultLattice] : | ||
| llvm::zip_equal(op->getResults(), results)) { | ||
| unsigned resultIdx = cast<OpResult>(result).getResultNumber(); | ||
| AffineMap map = indexingMaps[numInputs + resultIdx]; | ||
| TileSizes tileSizes = getTileSizesFor(result, resultLattice); | ||
| iterTileSizes.merge(tileSizes.mapToIterationSpace(map)); | ||
| } | ||
|
|
||
| // Gather from operands. | ||
| for (OpOperand &operand : op->getOpOperands()) { | ||
| TileSizes opTileSizes = getTileSizesFor( | ||
| operand.get(), operands[operand.getOperandNumber()]); | ||
| AffineMap map = indexingMaps[operand.getOperandNumber()]; | ||
| iterTileSizes.merge(opTileSizes.mapToIterationSpace(map)); | ||
| } | ||
|
|
||
| // Propagate back to each operand. Extend outer-rank result with inner | ||
| // dims to match full operand rank. | ||
| for (OpOperand &operand : op->getOpOperands()) { | ||
| unsigned idx = operand.getOperandNumber(); | ||
| AffineMap map = indexingMaps[idx]; | ||
| TileSizes outerTileSizes = iterTileSizes.mapFromIterationSpace(map); | ||
| if (outerTileSizes.empty()) { | ||
| continue; | ||
| } | ||
| ArrayRef<int64_t> innerShape = innerTiledOp.getOperandInnerShape(idx); | ||
| TileSizes fullTileSizes = outerTileSizes.append(innerShape); | ||
| TileSizeLattice *opLattice = operands[idx]; | ||
| propagateIfChanged(opLattice, opLattice->meet(fullTileSizes)); | ||
| } | ||
| return success(); | ||
| } | ||
|
|
||
| // Pack ops: result tile sizes → source tile sizes (backward). | ||
| if (auto packOp = dyn_cast<linalg::PackOp>(op)) { | ||
| if (ShapedType::isDynamicShape(packOp.getStaticInnerTiles())) { | ||
| return success(); | ||
| } | ||
| if (packOp.getPaddingValue()) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we have a test for padding value? |
||
| // We do not backward propagate for pack with padding, as it would | ||
| // potentially propagate too large tile sizes. | ||
| return success(); | ||
| } | ||
| TileSizes resultTileSizes = | ||
| getTileSizesFor(packOp->getResult(0), results[0]); | ||
| TileSizes srcTileSizes = resultTileSizes.mapPackDestToSource( | ||
| packOp.getSourceRank(), packOp.getInnerDimsPos(), | ||
| packOp.getStaticInnerTiles(), packOp.getOuterDimsPerm()); | ||
| TileSizeLattice *srcLattice = operands[0]; | ||
| propagateIfChanged(srcLattice, srcLattice->meet(srcTileSizes)); | ||
| return success(); | ||
| } | ||
|
|
||
| // Unpack ops: result tile sizes → source tile sizes (backward). | ||
| if (auto unpackOp = dyn_cast<linalg::UnPackOp>(op)) { | ||
| if (ShapedType::isDynamicShape(unpackOp.getStaticInnerTiles())) { | ||
| return success(); | ||
| } | ||
| TileSizes resultTileSizes = | ||
| getTileSizesFor(unpackOp->getResult(0), results[0]); | ||
| TileSizes srcTileSizes = resultTileSizes.mapPackSourceToDest( | ||
| unpackOp.getInnerDimsPos(), unpackOp.getStaticInnerTiles(), | ||
| unpackOp.getOuterDimsPerm()); | ||
| TileSizeLattice *srcLattice = operands[0]; | ||
| propagateIfChanged(srcLattice, srcLattice->meet(srcTileSizes)); | ||
| return success(); | ||
| } | ||
|
|
||
| return success(); | ||
| } | ||
|
|
||
|
|
@@ -383,20 +585,20 @@ class TileSizeBackwardAnalysis | |
| // Result querying | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| /// Gather tile sizes into the iteration space of a linalg op by looking up each | ||
| /// Gather tile sizes into the iteration space of an op by looking up each | ||
| /// operand's lattice state in the solver. | ||
| static TileSizes getIterationSpaceTileSizes(linalg::LinalgOp linalgOp, | ||
| static TileSizes getIterationSpaceTileSizes(Operation *op, unsigned numLoops, | ||
| ArrayRef<AffineMap> indexingMaps, | ||
| const DataFlowSolver &solver) { | ||
|
Comment on lines
+590
to
592
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel that you want to pass in
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IIUC,
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, I missed that. I was going to suggest |
||
| unsigned numLoops = linalgOp.getNumLoops(); | ||
| TileSizes iterTileSizes(numLoops); | ||
| for (OpOperand &operand : linalgOp->getOpOperands()) { | ||
| for (OpOperand &operand : op->getOpOperands()) { | ||
| Value val = operand.get(); | ||
| const TileSizeLattice *lattice = solver.lookupState<TileSizeLattice>(val); | ||
| TileSizes tileSize = getTileSizesFor(val, lattice); | ||
| if (tileSize.empty()) { | ||
| continue; | ||
| } | ||
| AffineMap map = linalgOp.getMatchingIndexingMap(&operand); | ||
| AffineMap map = indexingMaps[operand.getOperandNumber()]; | ||
| assert(map.getNumDims() == numLoops); | ||
| iterTileSizes.merge(tileSize.mapToIterationSpace(map)); | ||
| } | ||
|
|
@@ -426,23 +628,63 @@ class MaterializeVectorTileSizesPass final | |
| return signalPassFailure(); | ||
| } | ||
|
|
||
| funcOp->walk([&](linalg::LinalgOp linalgOp) { | ||
| TileSizes tileSizes = getIterationSpaceTileSizes(linalgOp, solver); | ||
| auto materialize = [](Operation *op, TileSizes tileSizes) -> LogicalResult { | ||
| if (tileSizes.isOverdefined()) { | ||
| linalgOp.emitOpError() | ||
| op->emitOpError() | ||
| << "tile size analysis did not determine a valid tile size"; | ||
| return; | ||
| return failure(); | ||
| } | ||
| if (!tileSizes.isDefined()) { | ||
| LDBG() << "Analysis did not determine tile size for " << *linalgOp; | ||
| return; | ||
| LDBG() << "Analysis did not determine tile size for " << *op; | ||
| return success(); | ||
| } | ||
| assert(tileSizes.rank() == linalgOp.getNumLoops()); | ||
|
|
||
| linalgOp->setAttr( | ||
| op->setAttr( | ||
| kVectorTileSizesAttrName, | ||
| DenseI64ArrayAttr::get(linalgOp->getContext(), tileSizes.getDims())); | ||
| DenseI64ArrayAttr::get(op->getContext(), tileSizes.getDims())); | ||
| return success(); | ||
| }; | ||
|
|
||
| auto result = funcOp->walk([&](Operation *op) -> WalkResult { | ||
| if (isa<linalg::PackOp>(op) || isa<linalg::UnPackOp>(op)) { | ||
| // linalg.pack and linalg.unpack have an unpacked (rank N) and a packed | ||
| // (rank N + K) domain. linalg.pack converts from the unpacked domain to | ||
| // the packed domain, linalg.unpack works the other way round. | ||
| // Vectorization of the operations expects vector sizes in the packed | ||
| // domain. After analysis, these are available on operand of linalg.pack | ||
| // and the result of linalg.unpack, respectively. | ||
| Value packedVal; | ||
| if (auto packOp = dyn_cast<linalg::PackOp>(op)) { | ||
| packedVal = packOp.getResult(); | ||
| } else if (auto unpackOp = dyn_cast<linalg::UnPackOp>(op)) { | ||
| packedVal = unpackOp.getSource(); | ||
| } | ||
| const TileSizeLattice *lattice = | ||
| solver.lookupState<TileSizeLattice>(packedVal); | ||
| TileSizes tileSizes = getTileSizesFor(packedVal, lattice); | ||
| return WalkResult(materialize(op, tileSizes)); | ||
| } | ||
| if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) { | ||
| SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray(); | ||
| TileSizes tileSizes = getIterationSpaceTileSizes( | ||
| op, linalgOp.getNumLoops(), indexingMaps, solver); | ||
| assert(!tileSizes.isDefined() || | ||
| tileSizes.rank() == linalgOp.getNumLoops()); | ||
| return WalkResult(materialize(op, tileSizes)); | ||
| } | ||
|
|
||
| if (auto innerTiledOp = dyn_cast<IREE::Codegen::InnerTiledOp>(op)) { | ||
| SmallVector<AffineMap> indexingMaps = | ||
| innerTiledOp.getIndexingMapsArray(); | ||
| unsigned numLoops = indexingMaps[0].getNumDims(); | ||
| TileSizes tileSizes = | ||
| getIterationSpaceTileSizes(op, numLoops, indexingMaps, solver); | ||
| return WalkResult(materialize(op, tileSizes)); | ||
| } | ||
| return WalkResult::advance(); | ||
| }); | ||
| if (result.wasInterrupted()) { | ||
| signalPassFailure(); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.