Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/Compiler/OMCompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace fs = std::filesystem;
namespace onnx_mlir {

void OMCompile::compile(const std::string &modelPath, const std::string &flags,
const std::string &logFilename) {
const std::string &compilerPath, const std::string &logFilename) {
// Initialize state.
successfullyCompiled = false;
outputFilename = {};
Expand All @@ -49,11 +49,18 @@ void OMCompile::compile(const std::string &modelPath, const std::string &flags,
inputFilename + "\"");
}
// Determine the onnx-mlir executable path.
std::string compilerFilename;
if (compilerPath.empty()) {
// Default value relying on PATH to locate the binary.
#ifdef _WIN32
std::string compilerFilename = "onnx-mlir.exe";
compilerFilename = "onnx-mlir.exe";
#else
std::string compilerFilename = "onnx-mlir";
compilerFilename = "onnx-mlir";
#endif
} else {
// Use user provided path to binary, including "onnx-mlir"
compilerFilename = compilerPath;
}
// Execute onnx-mlir command with arguments.
int status;
try {
Expand Down
8 changes: 6 additions & 2 deletions src/Compiler/OMCompile.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,21 @@ class OMCompile {
* blocks until compilation completes or fails.
*
* @param modelPath Path to the input model file (.onnx, .mlir, or .onnxtext).
* Can include a directory path. If empty, the flags
* parameter must contain the input filename.
* Can include a directory path. If empty, the flags parameter must contain
* the input filename.
* @param flags Compilation flags as a single string (e.g., "-O3 -o output").
* Supports quoted strings for paths with spaces.
* @param compilerPath Optional path to the compiler binary, including the
binary name. If empty (default), the standard onnx-mlir binary will be used
at its standard location.
* @param logFilename Optional path to a file where compilation logs will be
* written. If empty, logs go to stdout/stderr.
*
* @throws OMCompileException if compilation fails for any reason
* (invalid input, compiler errors, missing dependencies, etc.)
*/
void compile(const std::string &modelPath, const std::string &flags,
const std::string &compilerPath = {},
const std::string &logFilename = {});

/**
Expand Down
8 changes: 4 additions & 4 deletions src/Runtime/python/PyExecutionSession.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@ PYBIND11_MODULE(PyRuntimeC, m) {
"Returns:\n"
" str: Human-readable description of the model's input signature.\n\n"
"Example:\n"
" >>> session = OMExecutionSession('model.so')\n"
" >>> session = OMExecutionSession('mnist.so')\n"
" >>> print(session.input_signature())\n"
" # Output: \"Input 0: name='input', shape=[1,3,224,224], type=float32\"")
" # Output: input signature in json [{\"type\" : \"f32\", \"dims\" : [1,1,28,28], \"name\" : \"image\"}")
.def("output_signature",
&onnx_mlir::PyExecutionSession::pyOutputSignature,
"Get the output signature of the model.\n\n"
Expand All @@ -157,9 +157,9 @@ PYBIND11_MODULE(PyRuntimeC, m) {
"Returns:\n"
" str: Human-readable description of the model's output signature.\n\n"
"Example:\n"
" >>> session = OMExecutionSession('model.so')\n"
" >>> session = OMExecutionSession('mnist.so')\n"
" >>> print(session.output_signature())\n"
" # Output: \"Output 0: name='output', shape=[1,1000], type=float32\"")
" # Output: output signature in json [{\"type\" : \"f32\", \"dims\" : [1,10], \"name\" : \"prediction\"}")
.def("print_instrumentation",
&onnx_mlir::PyExecutionSession::pyPrintInstrumentation,
"Print instrumentation data from the model execution.\n\n"
Expand Down
16 changes: 9 additions & 7 deletions src/Runtime/python/PyOMCompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ namespace onnx_mlir {
// Constructor

PyOMCompile::PyOMCompile(std::string modelPath, std::string flags,
const std::string &logFilename, bool reuseCompiledModel)
: OMcompile() /* constructor without compilation */, modelPath(modelPath),
flags(flags) {
const std::string &compilerPath, const std::string &logFilename,
bool reuseCompiledModel)
: OMcompile() /* constructor without compilation */ {

// See if we can reuse a compilation (no check on model or flag
// equivalencies).
Expand All @@ -48,7 +48,7 @@ PyOMCompile::PyOMCompile(std::string modelPath, std::string flags,
// Must compile?
if (!reuseCompiledModel) {
try {
OMcompile.compile(modelPath, flags, logFilename);
OMcompile.compile(modelPath, flags, compilerPath, logFilename);
} catch (const onnx_mlir::OMCompileException &error) {
std::string errorMessage = error.what();
std::cerr << errorMessage << std::endl;
Expand All @@ -61,11 +61,13 @@ PyOMCompile::PyOMCompile(std::string modelPath, std::string flags,
// Custom getters

std::string PyOMCompile::pyGetOutputFilename() {
return onnx_mlir::OMCompile::getOutputFilename(modelPath, flags);
return OMcompile.getOutputFilename();
}

std::string PyOMCompile::pyGetModelTag() {
return onnx_mlir::OMCompile::getModelTag(flags);
std::string PyOMCompile::pyGetOutputConstantFilename() {
return OMcompile.getOutputConstantFilename();
}

std::string PyOMCompile::pyGetModelTag() { return OMcompile.getModelTag(); }

} // namespace onnx_mlir
25 changes: 17 additions & 8 deletions src/Runtime/python/PyOMCompile.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,14 @@ namespace onnx_mlir {
class PyOMCompile {
public:
PyOMCompile(std::string modelPath, std::string flags,
const std::string &logFilename = {}, bool reuseCompiledModel = true);
const std::string &compilerPath = {}, const std::string &logFilename = {},
bool reuseCompiledModel = true);
std::string pyGetOutputFilename();
std::string pyGetOutputConstantFilename();
std::string pyGetModelTag();

private:
onnx_mlir::OMCompile OMcompile; // To compile a model.
std::string modelPath;
std::string flags;
};

} // namespace onnx_mlir
Expand All @@ -58,10 +57,11 @@ PYBIND11_MODULE(PyOMCompileC, m) {
" >>> compiler = OMCompile('model.onnx', '-O3 -o output')\n"
" >>> output_file = compiler.get_output_file_name()\n"
" >>> print(f'Compiled to: {output_file}')")
.def(py::init<const std::string &, const std::string &,
.def(py::init<const std::string &, const std::string &, const std::string &,
const std::string &, const bool>(),
py::arg("input_model_path"),
py::arg("flags"),
py::arg("compiler_path") = "",
py::arg("log_file_name") = "",
py::arg("reuse_compiled_model") = false,
"Compile an ONNX model.\n\n"
Expand All @@ -71,6 +71,9 @@ PYBIND11_MODULE(PyOMCompileC, m) {
" flags (str): Compilation flags as a single string.\n"
" Examples: '-O3', '-O3 -o output_name', '--EmitLib'.\n"
" All onnx-mlir command-line options are supported.\n"
" compiler_path (str, optional): Path to onnx-mlir compiler binary,\n"
" namely path plus binary name. If empty (default), use onnx-mlir\n"
" at its default location.\n"
" log_file_name (str, optional): Path to log file for compilation output.\n"
" If empty (default), output goes to stdout/stderr.\n"
" reuse_compiled_model (bool, optional): If True, reuse existing compiled\n"
Expand All @@ -95,18 +98,22 @@ PYBIND11_MODULE(PyOMCompileC, m) {
"determined by the input model name and compilation flags (especially\n"
"the '-o' flag if provided).\n\n"
"Returns:\n"
" str: Full path to the compiled model output file.\n\n"
" str: Full path to the compiled model output file.\n"
"Raises:\n"
" RuntimeError: If the compilation failed\n\n"
"Example:\n"
" >>> compiler = OMCompile('mnist.onnx', '-O3 -o mnist_opt')\n"
" >>> output = compiler.get_output_file_name()\n"
" >>> print(output) # e.g., '/home/me/mnist_opt.so' on Linux")
.def("get_output_constant_file_name",
&onnx_mlir::PyOMCompile::pyGetOutputFilename,
&onnx_mlir::PyOMCompile::pyGetOutputConstantFilename,
"Get the output filename of the compiled model constant file, if any.\n\n"
"If the compiler did generate a data constant file, return its\n"
"absolute path; otherwise, return an emtpy string.\n\n"
"Returns:\n"
" str: Full path to the constant file of the compiled model.\n\n"
" str: Full path to the constant file of the compiled model.\n"
"Raises:\n"
" RuntimeError: If the compilation failed\n\n"
"Example:\n"
" >>> compiler = OMCompile('mnist.onnx', '-O3 -o mnist_opt')\n"
" >>> output = compiler.get_output_constant_file_name()\n"
Expand All @@ -118,7 +125,9 @@ PYBIND11_MODULE(PyOMCompileC, m) {
"compilation flags. This can be used for model identification and\n"
"caching purposes.\n\n"
"Returns:\n"
" str: Model tag string.\n\n"
" str: Model tag string.\n"
"Raises:\n"
" RuntimeError: If the compilation failed\n\n"
"Example:\n"
" >>> compiler = OMCompile('model.onnx', '-O3 --tag=key_model')\n"
" >>> tag = compiler.get_model_tag()\n"
Expand Down
Loading