Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 94 additions & 11 deletions compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#define DEBUG_TYPE "iree-codegen-vector-layout-analysis"

Expand All @@ -30,6 +33,10 @@ using namespace IREE::VectorExt;
/// analysis cost; most values see 1-2 candidates in practice.
static constexpr int kMaxCandidatesPerValue = 4;

/// Maximum length of chains of cheap-to-compute operations that get duplicated
/// for layout conflict resolution.
static constexpr unsigned kMaxChainLength = 8;

//===----------------------------------------------------------------------===//
// Layout Analysis
//
Expand Down Expand Up @@ -464,6 +471,70 @@ void LayoutAnalysis::fixupOp(Operation *op) {
}
}

/// Returns true if the operation is a duplicatable leaf: trivially cheap to
/// recompute and has no operands that need cloning.
static bool isDuplicatableLeaf(Operation *op) {
return op->hasTrait<OpTrait::ConstantLike>() ||
isa<vector::StepOp, vector::CreateMaskOp, vector::ConstantMaskOp>(op);
}

/// Returns true if the operation is a cheap single-result op that can be
/// cloned as part of a duplicatable chain. These ops must be pure and have
/// exactly one result.
static bool isCheapToClone(Operation *op) {
if (isDuplicatableLeaf(op)) {
return true;
}
return isPure(op) &&
(isa<vector::BroadcastOp, vector::TransposeOp, vector::ShapeCastOp>(
op) ||
OpTrait::hasElementwiseMappableTraits(op));
}

/// Collect a chain of ops that can be cloned together. Starting from `op`,
/// walk backward through single-result, cheap-to-clone ops until we reach
/// duplicatable leaves, constants, or non-vector operands. Returns true if
/// the entire chain is safe to clone. Shared intermediates (with multiple
/// uses) are allowed because all ops in the chain are cheap to duplicate.
static bool collectDuplicatableChain(Operation *op,
SmallVectorImpl<Operation *> &chain) {
// The chain is built bottom-up (from consumer toward producers).
Block *block = op->getBlock();
std::queue<Operation *> worklist;
llvm::SmallPtrSet<Operation *, 8> visited;
worklist.push(op);
while (!worklist.empty()) {
Operation *current = worklist.front();
worklist.pop();
if (!visited.insert(current).second) {
// Operation was already visited.
continue;
}
if (!isCheapToClone(current)) {
return false;
}
chain.push_back(current);
if (chain.size() > kMaxChainLength) {
return false;
}
if (isDuplicatableLeaf(current)) {
continue;
}
for (Value operand : current->getOperands()) {
// Non-vector operands (scalars, indices) don't need cloning.
if (!isa<VectorType>(operand.getType())) {
continue;
}
Operation *defOp = operand.getDefiningOp();
if (!defOp || defOp->getBlock() != block) {
return false;
}
worklist.push(defOp);
}
}
return true;
}

/// Assign a layout to an operand, cloning cheap ops or inserting conversions
/// on conflict.
void LayoutAnalysis::setLayoutOrClone(OpOperand *val,
Expand All @@ -489,17 +560,29 @@ void LayoutAnalysis::setLayoutOrClone(OpOperand *val,
// Different layout -- clone cheap ops or insert to_layout conversion.
OpBuilder b(val->getOwner());
if (Operation *defOp = val->get().getDefiningOp()) {
// Clone constant-like and duplicatable ops per use site.
bool isConstantLike = defOp->hasTrait<OpTrait::ConstantLike>();
bool isDuplicatable =
isa<vector::StepOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
defOp);
if (isConstantLike || isDuplicatable) {
b.setInsertionPoint(defOp);
Operation *cloned = b.clone(*defOp);
val->set(cloned->getResult(0));
resolved[cloned->getResult(0)] = layout;
return;
// Try to clone a chain of cheap ops rooted at duplicatable leaves.
if (isCheapToClone(defOp)) {
SmallVector<Operation *> chain;
if (collectDuplicatableChain(defOp, chain)) {
// Sort so cloning visits producers before consumers.
computeTopologicalSorting(chain);
IRMapping mapping;
b.setInsertionPoint(chain.front());
for (Operation *op : chain) {
b.clone(*op, mapping);
}
Value cloned = mapping.lookup(val->get());
val->set(cloned);
resolved[cloned] = layout;
// Propagate layouts through the cloned chain. The cloned ops are
// not visited by the outer fixupRegion walk (which collects ops
// upfront), so we must fix them up here. Walk in reverse program
// order so that result layouts propagate to operands.
for (Operation *op : llvm::reverse(chain)) {
fixupOp(mapping.lookup(op->getResult(0)).getDefiningOp());
}
return;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ iree_lit_test_suite(
"type_propagation.mlir",
"unroll_annotated_loops.mlir",
"vector_layout_analysis.mlir",
"vector_layout_analysis_chain_cloning.mlir",
"vectorize_memref_copy.mlir",
"vectorize_tensor_pad.mlir",
"verify_smt_constraints.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ iree_lit_test_suite(
"type_propagation.mlir"
"unroll_annotated_loops.mlir"
"vector_layout_analysis.mlir"
"vector_layout_analysis_chain_cloning.mlir"
"vectorize_memref_copy.mlir"
"vectorize_tensor_pad.mlir"
"verify_smt_constraints.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-test-vector-layout-analysis))" --split-input-file %s | FileCheck %s

#layoutA = #iree_vector_ext.nested_layout<
subgroup_tile = [1, 1],
batch_tile = [1, 1],
outer_tile = [1, 1],
thread_tile = [1, 1],
element_tile = [16, 64],

subgroup_strides = [0, 0],
thread_strides = [0, 0]
>

#layoutB = #iree_vector_ext.nested_layout<
subgroup_tile = [1, 1],
batch_tile = [2, 2],
outer_tile = [1, 1],
thread_tile = [1, 1],
element_tile = [8, 32],

subgroup_strides = [0, 0],
thread_strides = [0, 0]
>

// CHECK-LABEL: @clone_mask_chain
// CHECK: %[[STEP_A:.+]] = vector.step
// CHECK: %[[LIMIT_A:.+]] = vector.broadcast %{{.+}} : index to vector<64xindex>
// CHECK: %[[CMPI_A:.+]] = arith.cmpi slt, %[[STEP_A]], %[[LIMIT_A]]
// CHECK: %[[MASK_A:.+]] = vector.broadcast %[[CMPI_A]] : vector<64xi1> to vector<16x64xi1>
// CHECK: %[[STEP_B:.+]] = vector.step
// CHECK: %[[LIMIT_B:.+]] = vector.broadcast %{{.+}} : index to vector<64xindex>
// CHECK: %[[CMPI_B:.+]] = arith.cmpi slt, %[[STEP_B]], %[[LIMIT_B]]
// CHECK: %[[MASK_B:.+]] = vector.broadcast %[[CMPI_B]] : vector<64xi1> to vector<16x64xi1>
// CHECK-NOT: iree_vector_ext.to_layout {{.*}}xi1
// CHECK: arith.select %[[MASK_A]]
// CHECK: arith.select %[[MASK_B]]
func.func @clone_mask_chain(%a: vector<16x64xf16>, %b: vector<16x64xf16>, %n: index) -> (vector<16x64xf16>, vector<16x64xf16>) {
%cst = arith.constant dense<0.0> : vector<16x64xf16>
%step = vector.step : vector<64xindex>
%limit = vector.broadcast %n : index to vector<64xindex>
%mask_1d = arith.cmpi slt, %step, %limit : vector<64xindex>
%mask = vector.broadcast %mask_1d : vector<64xi1> to vector<16x64xi1>
%al = iree_vector_ext.to_layout %a to layout(#layoutA) : vector<16x64xf16>
%bl = iree_vector_ext.to_layout %b to layout(#layoutB) : vector<16x64xf16>
%sa = arith.select %mask, %al, %cst : vector<16x64xi1>, vector<16x64xf16>
%sb = arith.select %mask, %bl, %cst : vector<16x64xi1>, vector<16x64xf16>
func.return %sa, %sb : vector<16x64xf16>, vector<16x64xf16>
}

// -----

#layoutC = #iree_vector_ext.nested_layout<
subgroup_tile = [1, 1, 1],
batch_tile = [1, 8, 1],
outer_tile = [1, 1, 1],
thread_tile = [1, 8, 8],
element_tile = [1, 1, 8],

subgroup_strides = [0, 0, 0],
thread_strides = [0, 8, 1]
>

#layoutD = #iree_vector_ext.nested_layout<
subgroup_tile = [1, 1, 1],
batch_tile = [1, 1, 4],
outer_tile = [1, 1, 1],
thread_tile = [1, 4, 16],
element_tile = [1, 4, 1],

subgroup_strides = [0, 0, 0],
thread_strides = [0, 16, 1]
>

// CHECK-LABEL: @clone_mask_chain_shared_intermediate
// CHECK: %[[CMPI_A:.+]] = arith.cmpi
// CHECK: %[[CMPI_B:.+]] = arith.cmpi
// CHECK: %[[MASK_A:.+]] = vector.broadcast %[[CMPI_A]] : vector<64xi1> to vector<1x64x64xi1>
// CHECK: %[[MASK_B:.+]] = vector.broadcast %[[CMPI_B]] : vector<64xi1> to vector<1x16x64xi1>
// CHECK-NOT: iree_vector_ext.to_layout {{.*}}xi1
// CHECK: arith.select %[[MASK_A]]
// CHECK: arith.select %[[MASK_B]]
func.func @clone_mask_chain_shared_intermediate(
%a: vector<1x64x64xf16>, %b: vector<1x16x64xf32>, %n: index)
-> (vector<1x64x64xf16>, vector<1x16x64xf32>) {
%cst_f16 = arith.constant dense<0.0> : vector<1x64x64xf16>
%cst_f32 = arith.constant dense<0.0> : vector<1x16x64xf32>
%step = vector.step : vector<64xindex>
%limit = vector.broadcast %n : index to vector<64xindex>
%mask_1d = arith.cmpi slt, %step, %limit : vector<64xindex>
%mask_big = vector.broadcast %mask_1d : vector<64xi1> to vector<1x64x64xi1>
%mask_small = vector.broadcast %mask_1d : vector<64xi1> to vector<1x16x64xi1>
%al = iree_vector_ext.to_layout %a to layout(#layoutC) : vector<1x64x64xf16>
%bl = iree_vector_ext.to_layout %b to layout(#layoutD) : vector<1x16x64xf32>
%sa = arith.select %mask_big, %al, %cst_f16 : vector<1x64x64xi1>, vector<1x64x64xf16>
%sb = arith.select %mask_small, %bl, %cst_f32 : vector<1x16x64xi1>, vector<1x16x64xf32>
func.return %sa, %sb : vector<1x64x64xf16>, vector<1x16x64xf32>
}

// -----

// Negative test, chain cloning must stop when the chain reaches a non-cheap op
// (vector.contract in this case). The analysis should insert a `to_layout`
// conversion instead.

#layoutE = #iree_vector_ext.nested_layout<
subgroup_tile = [1, 1],
batch_tile = [1, 1],
outer_tile = [1, 1],
thread_tile = [1, 1],
element_tile = [16, 64],

subgroup_strides = [0, 0],
thread_strides = [0, 0]
>

#layoutF = #iree_vector_ext.nested_layout<
subgroup_tile = [1, 1],
batch_tile = [2, 2],
outer_tile = [1, 1],
thread_tile = [1, 1],
element_tile = [8, 32],

subgroup_strides = [0, 0],
thread_strides = [0, 0]
>

// CHECK-LABEL: @no_clone_non_cheap_producer
// CHECK: %[[NEG:.+]] = arith.negf
// CHECK: %[[CONV:.+]] = iree_vector_ext.to_layout %[[NEG]] to layout({{.+}})
// CHECK: %[[A:.+]] = iree_vector_ext.to_layout %[[CONV]] to layout({{.+}})
// CHECK: %[[B:.+]] = iree_vector_ext.to_layout %[[NEG]] to layout({{.+}})
// CHECK: return %[[A]], %[[B]]
func.func @no_clone_non_cheap_producer(
%lhs: vector<16x32xf16>, %rhs: vector<32x64xf16>, %acc: vector<16x64xf16>)
-> (vector<16x64xf16>, vector<16x64xf16>) {
%contract = vector.contract {
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>}
%lhs, %rhs, %acc : vector<16x32xf16>, vector<32x64xf16> into vector<16x64xf16>
%neg = arith.negf %contract : vector<16x64xf16>
%a = iree_vector_ext.to_layout %neg to layout(#layoutE) : vector<16x64xf16>
%b = iree_vector_ext.to_layout %neg to layout(#layoutF) : vector<16x64xf16>
func.return %a, %b : vector<16x64xf16>, vector<16x64xf16>
}
Loading