Merged
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,7 @@
## 1.4.2
* Remove unnecessary warnings in CMake for the Linux build
* Improve README and troubleshooting documentation

## 1.4.1
* Support string tensors in all platforms
* Reinforce structure and behavior consistency between Linux and Windows implementations
47 changes: 47 additions & 0 deletions README.md
@@ -117,6 +117,53 @@ Clone [this repository](https://github.com/masicai/flutter-onnxruntime-examples)

<sup>1</sup>: Execution Providers (EPs) are hardware-accelerated interfaces for AI inference (e.g., CPU, GPU, NPU, TPU).

## 📋 Required development setup

### Android

The Android build requires a `proguard-rules.pro` file in your Android project at `android/app/` with the following content:
```
-keep class ai.onnxruntime.** { *; }
```
Alternatively, create it by running the following command from your terminal:

```bash
echo "-keep class ai.onnxruntime.** { *; }" > android/app/proguard-rules.pro
```

Refer to [troubleshooting.md](doc/troubleshooting.md) for more information.

### iOS

ONNX Runtime requires a minimum deployment target of iOS 16 and static linkage.

In `ios/Podfile`, set the following lines:
```ruby
platform :ios, '16.0'

# existing code ...

use_frameworks! :linkage => :static

# existing code ...
```

### macOS

The macOS build requires a minimum deployment target of macOS 14.

* In `macos/Podfile`, set the following line:
```ruby
platform :osx, '14.0'
```

* Change "Minimum Deployments" to 14.0 in Xcode. In your terminal, from the `macos` directory:
```bash
open Runner.xcworkspace
```
In `Runner` -> `General`, change `Minimum Deployments` to `14.0`.


## 🛠️ Troubleshooting

For troubleshooting, see the [troubleshooting.md](doc/troubleshooting.md) file.
@@ -478,7 +478,7 @@ class FlutterOnnxruntimePlugin : FlutterPlugin, MethodCallHandler {
throw e
}
} catch (e: OrtException) {
result.error("ORT_ERROR", e.message, e.stackTraceToString())
result.error("INFERENCE_ERROR", e.message, e.stackTraceToString())
} catch (e: Exception) {
result.error("PLUGIN_ERROR", e.message, e.stackTraceToString())
}
20 changes: 15 additions & 5 deletions doc/troubleshooting.md
@@ -16,6 +16,10 @@ Common issues and their solutions.

## iOS
* Target minimum version: iOS 16
* Open `ios/Podfile` and set the minimum version to 16.0:
```
platform :ios, '16.0'
```
* "The 'Pods-Runner' target has transitive dependencies that include statically linked binaries: (onnxruntime-objc and onnxruntime-c)". In `Podfile` change:
```
target 'Runner' do
@@ -26,16 +30,22 @@ Common issues and their solutions.

## macOS
* Target minimum version: macOS 14
* Open `macos/Podfile` and set the minimum version to 14.0:
```
platform :osx, '14.0'
```
* "error: compiling for macOS 10.14, but module 'flutter_onnxruntime' has a minimum deployment target of macOS 14.0".
* In the terminal, cd to the `macos` directory and open the project in Xcode:
```
open Runner.xcworkspace
```
* In `Runner` -> `General`, change `Minimum Deployments` to `14.0`.
* "The 'Pods-Runner' target has transitive dependencies that include statically linked binaries: (onnxruntime-objc and onnxruntime-c)". In `Podfile` change:
```
target 'Runner' do
use_frameworks! :linkage => :static
```
* "error: compiling for macOS 10.14, but module 'flutter_onnxruntime' has a minimum deployment target of macOS 14.0". In terminal, cd to the `macos` directory and run the XCode to open the project:
```
open Runner.xcworkspace
```
Then change the "Minimum Deployments" to 14.0.


## Linux
* When running with ONNX Runtime 1.21.0, you may see reference counting warnings related to FlValue objects. These don't prevent the app from running but may be addressed in future updates.
50 changes: 50 additions & 0 deletions example/integration_test/all_tests.dart
@@ -795,6 +795,56 @@ void main() {
await tensorA.dispose();
await tensorB.dispose();
});

testWidgets('Invalid input name test', (WidgetTester tester) async {
// Create tensors with correct shapes but using wrong input name
final tensorA = await OrtValue.fromList([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1, 2, 3]);
final tensorB = await OrtValue.fromList([2.0, 2.0, 2.0, 2.0, 2.0, 2.0], [1, 3, 2]);

// Use wrong input name (X instead of A)
// Expect to throw an exception with code "INFERENCE_ERROR"
expect(
() async => await session.run({'X': tensorA, 'B': tensorB}),
throwsA(isA<PlatformException>().having((e) => e.code, 'code', "INFERENCE_ERROR")),
);

// Clean up
await tensorA.dispose();
await tensorB.dispose();
});

testWidgets('Random input order test', (WidgetTester tester) async {
// Create tensors with correct shapes
final tensorA = await OrtValue.fromList([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [1, 2, 3]);
final tensorB = await OrtValue.fromList([2.0, 2.0, 2.0, 2.0, 2.0, 2.0], [1, 3, 2]);

// Run inference with inputs in normal order
final outputsNormal = await session.run({'A': tensorA, 'B': tensorB});
final outputNormal = outputsNormal['C'];

// Run inference with inputs in reverse order
final outputsReversed = await session.run({'B': tensorB, 'A': tensorA});
final outputReversed = outputsReversed['C'];

// Verify both outputs are the same
expect(outputReversed!.dataType, outputNormal!.dataType);
expect(outputReversed.shape, outputNormal.shape);

final outputDataNormal = await outputNormal.asFlattenedList();
final outputDataReversed = await outputReversed.asFlattenedList();

expect(outputDataNormal.length, outputDataReversed.length);
for (int i = 0; i < outputDataNormal.length; i++) {
expect(outputDataReversed[i], outputDataNormal[i]);
}

// Clean up
await tensorA.dispose();
await tensorB.dispose();
await outputNormal.dispose();
await outputReversed.dispose();
});
});

group('INT32 model test', () {
2 changes: 1 addition & 1 deletion example/pubspec.lock
@@ -97,7 +97,7 @@ packages:
path: ".."
relative: true
source: path
version: "1.4.0"
version: "1.4.2"
flutter_test:
dependency: "direct dev"
description: flutter
2 changes: 1 addition & 1 deletion ios/Classes/FlutterOnnxruntimePlugin.swift
@@ -269,7 +269,7 @@ public class FlutterOnnxruntimePlugin: NSObject, FlutterPlugin {
} catch let error as FlutterError {
result(error)
} catch {
result(FlutterError(code: "RUN_INFERENCE_ERROR", message: error.localizedDescription, details: nil))
result(FlutterError(code: "INFERENCE_ERROR", message: error.localizedDescription, details: nil))
}
}

2 changes: 1 addition & 1 deletion lib/web/flutter_onnxruntime_web_plugin.dart
@@ -333,7 +333,7 @@ class FlutterOnnxruntimeWebPlugin extends FlutterOnnxruntimePlatform {
if (e is PlatformException) {
rethrow;
}
throw PlatformException(code: "PLUGIN_ERROR", message: "Failed to run inference: $e", details: null);
throw PlatformException(code: "INFERENCE_ERROR", message: "Failed to run inference: $e", details: null);
}
}

15 changes: 10 additions & 5 deletions linux/src/flutter_onnxruntime_plugin.cc
@@ -379,11 +379,12 @@ static FlMethodResponse *run_inference(FlutterOnnxruntimePlugin *self, FlValue *
}

try {
std::vector<std::string> input_names = self->session_manager->getInputNames(session_id);
// Get expected input names from the session
std::vector<std::string> output_names = self->session_manager->getOutputNames(session_id);

// Prepare input tensors
// Prepare input tensors and input names
std::vector<Ort::Value> input_tensors;
std::vector<std::string> input_names;

// Iterate through each input
size_t num_inputs = fl_value_get_length(inputs_value);
@@ -395,6 +396,9 @@
continue;
}

// Extract the input name from the map key
std::string input_name = fl_value_get_string(key);

FlValue *tensor_id_map = fl_value_lookup_string(value, "valueId");

if (tensor_id_map == nullptr || fl_value_get_type(tensor_id_map) != FL_VALUE_TYPE_STRING) {
@@ -410,6 +414,7 @@
// Use the tensor manager to clone the tensor
Ort::Value new_tensor = self->tensor_manager->cloneTensor(tensor_id);
input_tensors.push_back(std::move(new_tensor));
input_names.push_back(input_name);
} catch (const std::exception &e) {
g_warning("Failed to clone tensor %s: %s", tensor_id.c_str(), e.what());
// Continue with the next tensor
@@ -443,10 +448,10 @@ static FlMethodResponse *run_inference(FlutterOnnxruntimePlugin *self, FlValue *
}
}

// Run inference using SessionManager
// Run inference using SessionManager with user-provided input names
std::vector<Ort::Value> output_tensors;
if (!input_tensors.empty()) {
output_tensors = self->session_manager->runInference(session_id, input_tensors, &run_options);
output_tensors = self->session_manager->runInference(session_id, input_tensors, input_names, &run_options);
}

// Process outputs
@@ -481,7 +486,7 @@ static FlMethodResponse *run_inference(FlutterOnnxruntimePlugin *self, FlValue *
}
return FL_METHOD_RESPONSE(fl_method_success_response_new(outputs_map));
} catch (const Ort::Exception &e) {
return FL_METHOD_RESPONSE(fl_method_error_response_new("INFERENCE_FAILED", e.what(), nullptr));
return FL_METHOD_RESPONSE(fl_method_error_response_new("INFERENCE_ERROR", e.what(), nullptr));
} catch (const std::exception &e) {
return FL_METHOD_RESPONSE(fl_method_error_response_new("PLUGIN_ERROR", e.what(), nullptr));
}
9 changes: 7 additions & 2 deletions linux/src/session_manager.cc
@@ -296,9 +296,10 @@ std::vector<TensorInfo> SessionManager::getOutputInfo(const std::string &session
return info_list;
}

// Run inference
// Run inference with provided input names
std::vector<Ort::Value> SessionManager::runInference(const std::string &session_id,
const std::vector<Ort::Value> &input_tensors,
const std::vector<std::string> &input_names,
Ort::RunOptions *run_options) {

std::lock_guard<std::mutex> lock(mutex_);
@@ -318,9 +319,13 @@ std::vector<Ort::Value> SessionManager::runInference(const std::string &session_
throw Ort::Exception("No input tensors provided", ORT_INVALID_ARGUMENT);
}

if (input_names.size() != input_tensors.size()) {
throw Ort::Exception("Number of input names must match number of input tensors", ORT_INVALID_ARGUMENT);
}

// Prepare input names
std::vector<const char *> input_names_char;
for (const auto &name : it->second.input_names) {
for (const auto &name : input_names) {
input_names_char.push_back(name.c_str());
}

3 changes: 2 additions & 1 deletion linux/src/session_manager.h
@@ -71,8 +71,9 @@ class SessionManager {
// Get output tensor info for a session
std::vector<TensorInfo> getOutputInfo(const std::string &session_id);

// Run inference with a session
// Run inference with a session using provided input names
std::vector<Ort::Value> runInference(const std::string &session_id, const std::vector<Ort::Value> &input_tensors,
const std::vector<std::string> &input_names,
Ort::RunOptions *run_options = nullptr);

// Helper method to get element type string
2 changes: 1 addition & 1 deletion macos/Classes/FlutterOnnxruntimePlugin.swift
@@ -270,7 +270,7 @@ public class FlutterOnnxruntimePlugin: NSObject, FlutterPlugin {
} catch let error as FlutterError {
result(error)
} catch {
result(FlutterError(code: "RUN_INFERENCE_ERROR", message: error.localizedDescription, details: nil))
result(FlutterError(code: "INFERENCE_ERROR", message: error.localizedDescription, details: nil))
}
}

2 changes: 1 addition & 1 deletion pubspec.yaml
@@ -1,6 +1,6 @@
name: flutter_onnxruntime
description: "A lightweight plugin that provides native wrappers for running ONNX Runtime on multiple platforms"
version: 1.4.1
version: 1.4.2
homepage: https://github.com/masicai/flutter_onnxruntime
repository: https://github.com/masicai/flutter_onnxruntime

2 changes: 1 addition & 1 deletion windows/CMakeLists.txt
@@ -42,7 +42,7 @@ if(USE_SYSTEM_ONNXRUNTIME)
message(STATUS "Found ONNX Runtime: ${ONNXRUNTIME_LIBRARY}")
else()
set(ONNXRUNTIME_FOUND FALSE)
message(WARNING "System ONNX Runtime not found. Falling back to downloaded version.")
message(STATUS "System ONNX Runtime not found. Falling back to downloaded version.")
set(USE_SYSTEM_ONNXRUNTIME OFF)
endif()
endif()
15 changes: 9 additions & 6 deletions windows/flutter_onnxruntime_plugin.cpp
@@ -694,12 +694,11 @@ void FlutterOnnxruntimePlugin::HandleRunInference(
}
}

// Get input and output names
std::vector<std::string> input_names = impl_->sessionManager_->getInputNames(session_id);
std::vector<std::string> output_names = impl_->sessionManager_->getOutputNames(session_id);

// Prepare input tensors
// Prepare input tensors and input names
std::vector<Ort::Value> input_tensors;
std::vector<std::string> input_names;

// Iterate through each input
for (const auto &input_pair : inputs_map) {
@@ -708,6 +707,9 @@
continue;
}

// Extract the input name from the map key
std::string input_name = std::get<std::string>(input_pair.first);

const auto &input_value_map = std::get<flutter::EncodableMap>(input_pair.second);
auto tensor_id_it = input_value_map.find(flutter::EncodableValue("valueId"));

@@ -725,6 +727,7 @@
Ort::Value cloned_tensor = impl_->tensorManager_->cloneTensor(tensor_id);
if (cloned_tensor) {
input_tensors.push_back(std::move(cloned_tensor));
input_names.push_back(input_name);
}
} catch (const std::exception &e) {
// Log the error but continue with the next tensor
@@ -733,10 +736,10 @@
}
}

// Run inference using SessionManager
// Run inference using SessionManager with input names
std::vector<Ort::Value> output_tensors;
if (!input_tensors.empty()) {
output_tensors = impl_->sessionManager_->runInference(session_id, input_tensors, &run_options);
output_tensors = impl_->sessionManager_->runInference(session_id, input_tensors, input_names, &run_options);
}

// Process outputs
@@ -774,7 +777,7 @@

result->Success(flutter::EncodableValue(outputs_map));
} catch (const Ort::Exception &e) {
result->Error("INFERENCE_FAILED", e.what(), nullptr);
result->Error("INFERENCE_ERROR", e.what(), nullptr);
} catch (const std::exception &e) {
result->Error("PLUGIN_ERROR", e.what(), nullptr);
} catch (...) {
9 changes: 7 additions & 2 deletions windows/src/session_manager.cc
@@ -296,9 +296,10 @@ std::vector<TensorInfo> SessionManager::getOutputInfo(const std::string &session
return info_list;
}

// Run inference
// Run inference with provided input names
std::vector<Ort::Value> SessionManager::runInference(const std::string &session_id,
const std::vector<Ort::Value> &input_tensors,
const std::vector<std::string> &input_names,
Ort::RunOptions *run_options) {

std::lock_guard<std::mutex> lock(mutex_);
@@ -318,9 +319,13 @@ std::vector<Ort::Value> SessionManager::runInference(const std::string &session_
throw Ort::Exception("No input tensors provided", ORT_INVALID_ARGUMENT);
}

if (input_names.size() != input_tensors.size()) {
throw Ort::Exception("Number of input names must match number of input tensors", ORT_INVALID_ARGUMENT);
}

// Prepare input names
std::vector<const char *> input_names_char;
for (const auto &name : it->second.input_names) {
for (const auto &name : input_names) {
input_names_char.push_back(name.c_str());
}
