Skip to content

Commit aa969ce

Browse files
committed
add initial cost-model bindings
1 parent 1a98ba4 commit aa969ce

File tree

14 files changed

+436
-38
lines changed

14 files changed

+436
-38
lines changed

.github/workflows/build-cinnamon.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ configure() {
9898
-G Ninja
9999
-DCMAKE_BUILD_TYPE=RelWithDebInfo
100100
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON
101+
-DLLVM_ENABLE_EH=ON
102+
-DLLVM_ENABLE_RTTI=ON
101103
)
102104

103105
if ((${#DEP_OPTS[@]})); then

.github/workflows/build-llvm.sh

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -81,21 +81,20 @@ fi
8181
status "Configuring LLVM (Ninja; always run to catch changes)"
8282
print_and_run cmake -S llvm -B build -G Ninja \
8383
-Wno-dev \
84-
-DLLVM_ENABLE_PROJECTS="$LLVM_PROJECTS" \
85-
-DLLVM_TARGETS_TO_BUILD="$LLVM_TARGETS_TO_BUILD" \
86-
-DLLVM_EXPERIMENTAL_TARGETS_TO_BUILD="$LLVM_EXPERIMENTAL_TARGETS" \
87-
-DLLVM_ENABLE_ASSERTIONS=ON \
88-
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
89-
-DLLVM_BUILD_TOOLS=ON \
9084
-DCMAKE_BUILD_TYPE=Release \
9185
-DBUILD_SHARED_LIBS=ON \
92-
-DLLVM_INCLUDE_TESTS=OFF \
86+
-DLLVM_BUILD_TOOLS=ON \
87+
-DLLVM_CCACHE_BUILD=ON \
88+
-DLLVM_ENABLE_PROJECTS="$LLVM_PROJECTS" \
89+
-DLLVM_ENABLE_ASSERTIONS=ON \
90+
-DLLVM_ENABLE_EH=ON \
91+
-DLLVM_ENABLE_RTTI=ON \
92+
-DLLVM_EXPERIMENTAL_TARGETS_TO_BUILD="$LLVM_EXPERIMENTAL_TARGETS" \
9393
-DLLVM_INCLUDE_BENCHMARKS=OFF \
94+
-DLLVM_INCLUDE_TESTS=OFF \
9495
-DLLVM_OPTIMIZED_TABLEGEN=ON \
95-
-DLLVM_CCACHE_BUILD=ON \
96-
-DLLVM_PARALLEL_COMPILE_JOBS=4 \
97-
-DLLVM_PARALLEL_LINK_JOBS=1 \
98-
-DLLVM_PARALLEL_TABLEGEN_JOBS=4 \
96+
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
97+
-DLLVM_TARGETS_TO_BUILD="$LLVM_TARGETS_TO_BUILD" \
9998
"${EXTRA_CMAKE_OPTS[@]}"
10099

101100
# Save config hash so we can detect future changes

.github/workflows/build-torch.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ if [[ $checkout_and_build_torch_mlir -eq 1 ]]; then
7676
$dependency_paths \
7777
-Wno-dev \
7878
-DCMAKE_BUILD_TYPE=Release \
79+
-DLLVM_ENABLE_EH=ON \
80+
-DLLVM_ENABLE_RTTI=ON \
7981
-DTORCH_MLIR_OUT_OF_TREE_BUILD=ON \
8082
-DTORCH_MLIR_ENABLE_STABLEHLO=OFF \
8183
-U CMAKE_EXE_LINKER_FLAGS -U CMAKE_SHARED_LINKER_FLAGS \

.gitignore

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
.cache
2-
/.vscode/
3-
/.idea/
42
.directory
5-
/.venv/
6-
/third-party/ALPINE/
7-
/third-party/llvm
8-
/third-party/torch-mlir/
9-
/third-party/upmem*/
10-
/.env
11-
/testbench/gen
3+
.env
4+
.idea/
5+
.venv/
6+
.vscode/
7+
third-party
8+
testbench/gen
129

1310

1411
# Created by https://www.toptal.com/developers/gitignore/api/cmake
@@ -40,5 +37,3 @@ sandbox
4037
python/cinnamon/_resources
4138
.env
4239
**/__pycache__
43-
44-

.gitmodules

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[submodule "third-party/pybind11"]
2+
path = third-party/pybind11
3+
url = https://github.com/pybind/pybind11
4+
branch = stable

CMakeLists.txt

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ if (TORCH_MLIR_DIR)
3434
get_property(CINM_DEPENDENCY_LIBS GLOBAL PROPERTY CINM_DEPENDENCY_LIBS)
3535
list(APPEND CINM_DEPENDENCY_LIBS ${TORCH_MLIR_LIBS})
3636
set_property(GLOBAL PROPERTY CINM_DEPENDENCY_LIBS ${CINM_DEPENDENCY_LIBS})
37-
37+
3838
else()
3939
message(WARNING "TORCH_MLIR_DIR not set, torch frontend wont be available.")
4040
endif()
@@ -111,11 +111,11 @@ if (TORCH_MLIR_DIR)
111111
add_library(Torch::TorchMLIRTorchConversionDialect STATIC IMPORTED GLOBAL)
112112
set_property(TARGET Torch::TorchMLIRTorchConversionDialect PROPERTY
113113
IMPORTED_LOCATION ${TORCH_MLIR_DIR}/lib/libTorchMLIRTorchConversionDialect.a)
114-
114+
115115
add_library(Torch::All INTERFACE IMPORTED)
116116
set_property(TARGET Torch::All PROPERTY
117-
INTERFACE_LINK_LIBRARIES Torch::TorchMLIRTorchDialect
118-
Torch::TorchMLIRTorchUtils
117+
INTERFACE_LINK_LIBRARIES Torch::TorchMLIRTorchDialect
118+
Torch::TorchMLIRTorchUtils
119119
Torch::TorchMLIRTorchConversionDialect)
120120
target_include_directories(Torch::All INTERFACE ${TORCH_MLIR_INCLUDE_DIR})
121121

@@ -135,6 +135,11 @@ else()
135135
endif()
136136

137137

138+
# Add pybind11
139+
set(PYBIND11_FINDPYTHON ON)
140+
add_subdirectory("third-party/pybind11")
141+
142+
138143
add_subdirectory(include)
139144
add_subdirectory(lib)
140145
add_subdirectory(test)

cost_model_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import random
2+
3+
from xdsl.context import Context
4+
from xdsl.parser import Parser
5+
6+
name = "cost_model_test"
7+
passes = "cinm-tiling"
8+
operations = {
9+
"cinm.compute"
10+
}
11+
12+
ctx = Context(allow_unregistered=True)
13+
14+
# def run(op: str, elementType: str, operand_dimensions: list[list[int]], location: str) -> float :
15+
# print(op, elementType, operand_dimensions, location)
16+
# return random.uniform(0.0, 10.0)
17+
18+
def run(ir: str, location: str) -> float :
19+
parser = Parser(ctx, ir)
20+
compute_op = parser.parse_operation()
21+
print(ir)
22+
23+
print(parser.forward_ssa_references)
24+
25+
for op in compute_op.walk():
26+
op.name = op.get_attr_or_prop("op_name__").data
27+
28+
cinm_ops = [op for op in compute_op.walk()]
29+
return len(cinm_ops) + random.uniform(0.0, 10.0)

cost_model_test2.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import random
2+
3+
from xdsl.context import Context
4+
from xdsl.parser import Parser
5+
6+
name = "cost_model_test_2"
7+
passes = ""
8+
operations = {
9+
"cinm.compute"
10+
}
11+
12+
ctx = Context(allow_unregistered=True)
13+
14+
# def run(op: str, elementType: str, operand_dimensions: list[list[int]], location: str) -> float :
15+
# print(op, elementType, operand_dimensions, location)
16+
# return random.uniform(0.0, 10.0)
17+
18+
def run(ir: str, location: str) -> float :
19+
parser = Parser(ctx, ir)
20+
compute_op = parser.parse_operation()
21+
print(ir)
22+
23+
print(parser.forward_ssa_references)
24+
25+
for op in compute_op.walk():
26+
op.name = op.get_attr_or_prop("op_name__").data
27+
28+
cinm_ops = [op for op in compute_op.walk()]
29+
return len(cinm_ops) + random.uniform(0.0, 10.0)

include/cinm-mlir/Dialect/Cinm/IR/CinmAttributes.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,19 @@ def Cinm_RoundingMode : I32EnumAttr<
104104
let cppNamespace = "::mlir::cinm";
105105
}
106106

107+
//===----------------------------------------------------------------------===//
108+
// Cost-Model
109+
//===----------------------------------------------------------------------===//
110+
111+
def Cinm_CostModelData : Cinm_Attr<"CostModelData"> {
112+
let parameters = (ins
113+
"StringAttr": $name,
114+
"FloatAttr": $cost
115+
);
116+
117+
let mnemonic = "cost_model_data";
118+
let assemblyFormat = "`<` struct($name, $cost) `>`";
119+
}
120+
121+
107122
#endif

include/cinm-mlir/Dialect/Cinm/IR/CinmOps.td

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def Cinm_ElementwiseOp
6060
];
6161
}
6262

63-
class Cinm_GemmlikeOp<string name, string cppClassName> : Cinm_Op<name,
63+
class Cinm_GemmlikeOp<string name, string cppClassName> : Cinm_Op<name,
6464
[
6565
Pure,
6666
InferTensorTypeAdaptor,
@@ -78,12 +78,12 @@ class Cinm_GemmlikeOp<string name, string cppClassName> : Cinm_Op<name,
7878
);
7979

8080
let results = (outs Optional<AnyRankedTensor>:$result);
81-
let assemblyFormat =
81+
let assemblyFormat =
8282
"$lhs `,` $rhs (`plus` $bias^)? (`into` $out^)? attr-dict `:` type($lhs) `,` type($rhs) (`plus` type($bias)^)? (`into` type($out)^):(`->` type($result))?";
8383

8484
let skipDefaultBuilders = 1;
8585
let builders = [
86-
OpBuilder<(ins "Value":$left, "Value":$right, "Value":$bias, "Value":$out),
86+
OpBuilder<(ins "Value":$left, "Value":$right, "Value":$bias, "Value":$out),
8787
"buildGemmLikeOp<" # cppClassName # ">($_builder, $_state, left, right, bias, out);">,
8888
OpBuilder<(ins "Value":$left, "Value":$right, "Value":$bias),
8989
"build($_builder, $_state, left, right, bias, Value{});">,
@@ -480,4 +480,20 @@ def Cinm_DequantizeOp : Cinm_Op<"op.dequantize", [Pure]> {
480480
let hasCustomAssemblyFormat = true;
481481
}
482482

483+
484+
//===----------------------------------------------------------------------===//
485+
// SelectOp
486+
//===----------------------------------------------------------------------===//
487+
488+
def Cinm_SelectOp : Cinm_Op<"select", [Pure]> {
489+
let summary = "Select one of multiple child-regions depending on the cost-model estimates";
490+
let description = [{}];
491+
492+
let results = (outs Variadic<AnyType>:$results);
493+
let regions = (region VariadicRegion<AnyRegion>:$regions);
494+
495+
let assemblyFormat = "attr-dict `from` $regions (`->` type($results)^)?";
496+
}
497+
498+
483499
#endif

0 commit comments

Comments
 (0)