diff --git a/Package.resolved b/Package.resolved index 2c72e0702..40b558dc2 100644 --- a/Package.resolved +++ b/Package.resolved @@ -9,15 +9,6 @@ "version" : "5.10.2" } }, - { - "identity" : "bitbytedata", - "kind" : "remoteSourceControl", - "location" : "https://github.com/tsolomko/BitByteData", - "state" : { - "revision" : "cdcdc5177ad536cfb11b95c620f926a81014b7fe", - "version" : "2.0.4" - } - }, { "identity" : "devicekit", "kind" : "remoteSourceControl", @@ -54,15 +45,6 @@ "version" : "8.58.0" } }, - { - "identity" : "swcompression", - "kind" : "remoteSourceControl", - "location" : "https://github.com/tsolomko/SWCompression.git", - "state" : { - "revision" : "390e0b0af8dd19a600005a242a89e570ff482e09", - "version" : "4.8.6" - } - }, { "identity" : "swift-argument-parser", "kind" : "remoteSourceControl", @@ -125,15 +107,6 @@ "revision" : "664e1b5a65296cd957dfdf262cd120ca88f3b24b", "version" : "0.15.0" } - }, - { - "identity" : "zipfoundation", - "kind" : "remoteSourceControl", - "location" : "https://github.com/weichsel/ZIPFoundation.git", - "state" : { - "revision" : "22787ffb59de99e5dc1fbfe80b19c97a904ad48d", - "version" : "0.9.20" - } } ], "version" : 2 diff --git a/Package.swift b/Package.swift index 9caf39474..4ff4d47f0 100644 --- a/Package.swift +++ b/Package.swift @@ -86,9 +86,7 @@ let package = Package( .package(url: "https://github.com/apple/swift-crypto.git", from: "3.0.0"), .package(url: "https://github.com/Alamofire/Alamofire.git", from: "5.9.0"), .package(url: "https://github.com/JohnSundell/Files.git", from: "4.3.0"), - .package(url: "https://github.com/weichsel/ZIPFoundation.git", from: "0.9.0"), .package(url: "https://github.com/devicekit/DeviceKit.git", from: "5.6.0"), - .package(url: "https://github.com/tsolomko/SWCompression.git", from: "4.8.0"), .package(url: "https://github.com/getsentry/sentry-cocoa", from: "8.40.0"), // ml-stable-diffusion for CoreML-based image generation .package(url: "https://github.com/apple/ml-stable-diffusion.git", from: "1.1.0"), @@ -139,9 +137,7 @@ let package = Package( .product(name: "Crypto", package: "swift-crypto"), .product(name: "Alamofire", package: "Alamofire"), .product(name: "Files", package: "Files"), - .product(name: "ZIPFoundation", package: "ZIPFoundation"), .product(name: "DeviceKit", package: "DeviceKit"), - .product(name: "SWCompression", package: "SWCompression"), .product(name: "Sentry", package: "sentry-cocoa"), .product(name: "StableDiffusion", package: "ml-stable-diffusion"), "CRACommons", diff --git a/Playground/README.md b/Playground/README.md index 82db7b056..8e1f575aa 100644 --- a/Playground/README.md +++ b/Playground/README.md @@ -42,7 +42,7 @@ A complete on-device voice AI pipeline for Linux (Raspberry Pi 5, x86_64, ARM64) - **Large Language Model** — Qwen2.5 0.5B Q4 via llama.cpp (fully local) - **Text-to-Speech** — Piper Lessac Medium neural TTS -**Requirements:** Linux (ALSA), x86_64 or ARM64, CMake 3.16+, C++17 +**Requirements:** Linux (ALSA), x86_64 or ARM64, CMake 3.16+, C++20 ## swift-starter-app @@ -94,4 +94,4 @@ A hybrid voice assistant that combines on-device AI inference with cloud LLM rea - **Barge-in Support** — Wake word during TTS playback cancels speech and re-listens - **Waiting Chime** — Earcon feedback while waiting for cloud response -**Requirements:** Linux (ALSA), x86_64 or ARM64, CMake 3.16+, C++17 +**Requirements:** Linux (ALSA), x86_64 or ARM64, CMake 3.16+, C++20 diff --git a/Playground/linux-voice-assistant/CMakeLists.txt b/Playground/linux-voice-assistant/CMakeLists.txt index cb7d01b81..1b96dcd10 100644 --- a/Playground/linux-voice-assistant/CMakeLists.txt +++ b/Playground/linux-voice-assistant/CMakeLists.txt @@ -31,7 +31,7 @@ project(linux-voice-assistant DESCRIPTION "Linux Voice Assistant using RunAnywhere Commons" ) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) # ============================================================================= diff --git a/Playground/linux-voice-assistant/README.md b/Playground/linux-voice-assistant/README.md index 43dd197b2..01ffd5f18 100644 --- a/Playground/linux-voice-assistant/README.md +++ b/Playground/linux-voice-assistant/README.md @@ -56,7 +56,7 @@ linux-voice-assistant/ - Linux (Raspberry Pi 5, Ubuntu, Debian, etc.) - CMake 3.16+ -- C++17 compiler (g++ or clang++) +- C++20 compiler (g++ or clang++) - ALSA development headers: `sudo apt install libasound2-dev` ### Build and Run diff --git a/Playground/openclaw-hybrid-assistant/CMakeLists.txt b/Playground/openclaw-hybrid-assistant/CMakeLists.txt index b55bacdbc..2fdb5efbf 100644 --- a/Playground/openclaw-hybrid-assistant/CMakeLists.txt +++ b/Playground/openclaw-hybrid-assistant/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.16) project(openclaw-hybrid-assistant VERSION 0.1.0 LANGUAGES CXX) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) # ============================================================================= diff --git a/examples/flutter/RunAnywhereAI/android/app/src/main/java/io/flutter/plugins/GeneratedPluginRegistrant.java b/examples/flutter/RunAnywhereAI/android/app/src/main/java/io/flutter/plugins/GeneratedPluginRegistrant.java index 5baa6f45c..bbc04e83d 100644 --- a/examples/flutter/RunAnywhereAI/android/app/src/main/java/io/flutter/plugins/GeneratedPluginRegistrant.java +++ b/examples/flutter/RunAnywhereAI/android/app/src/main/java/io/flutter/plugins/GeneratedPluginRegistrant.java @@ -20,11 +20,21 @@ public static void registerWith(@NonNull FlutterEngine flutterEngine) { } catch (Exception e) { Log.e(TAG, "Error registering plugin audioplayers_android, xyz.luan.audioplayers.AudioplayersPlugin", e); } + try { + flutterEngine.getPlugins().add(new io.flutter.plugins.camerax.CameraAndroidCameraxPlugin()); + } catch (Exception e) { + Log.e(TAG, "Error registering plugin camera_android_camerax, io.flutter.plugins.camerax.CameraAndroidCameraxPlugin", e); + } try { flutterEngine.getPlugins().add(new dev.fluttercommunity.plus.device_info.DeviceInfoPlusPlugin()); } catch (Exception e) { Log.e(TAG, "Error registering plugin device_info_plus, dev.fluttercommunity.plus.device_info.DeviceInfoPlusPlugin", e); } + try { + flutterEngine.getPlugins().add(new io.flutter.plugins.flutter_plugin_android_lifecycle.FlutterAndroidLifecyclePlugin()); + } catch (Exception e) { + Log.e(TAG, "Error registering plugin flutter_plugin_android_lifecycle, io.flutter.plugins.flutter_plugin_android_lifecycle.FlutterAndroidLifecyclePlugin", e); + } try { flutterEngine.getPlugins().add(new com.it_nomads.fluttersecurestorage.FlutterSecureStoragePlugin()); } catch (Exception e) { @@ -35,6 +45,11 @@ public static void registerWith(@NonNull FlutterEngine flutterEngine) { } catch (Exception e) { Log.e(TAG, "Error registering plugin flutter_tts, com.tundralabs.fluttertts.FlutterTtsPlugin", e); } + try { + flutterEngine.getPlugins().add(new io.flutter.plugins.imagepicker.ImagePickerPlugin()); + } catch (Exception e) { + Log.e(TAG, "Error registering plugin image_picker_android, io.flutter.plugins.imagepicker.ImagePickerPlugin", e); + } try { flutterEngine.getPlugins().add(new dev.fluttercommunity.plus.packageinfo.PackageInfoPlugin()); } catch (Exception e) { diff --git a/examples/flutter/RunAnywhereAI/ios/Runner/GeneratedPluginRegistrant.m b/examples/flutter/RunAnywhereAI/ios/Runner/GeneratedPluginRegistrant.m index 500cb4c19..08477fd00 100644 --- a/examples/flutter/RunAnywhereAI/ios/Runner/GeneratedPluginRegistrant.m +++ b/examples/flutter/RunAnywhereAI/ios/Runner/GeneratedPluginRegistrant.m @@ -12,6 +12,12 @@ @import audioplayers_darwin; #endif +#if __has_include() +#import +#else +@import camera_avfoundation; +#endif + #if __has_include() #import #else @@ -30,6 +36,12 @@ @import flutter_tts; #endif +#if __has_include() +#import +#else +@import image_picker_ios; +#endif + #if __has_include() #import #else @@ -94,9 +106,11 @@ @implementation GeneratedPluginRegistrant + (void)registerWithRegistry:(NSObject*)registry { [AudioplayersDarwinPlugin registerWithRegistrar:[registry registrarForPlugin:@"AudioplayersDarwinPlugin"]]; + [CameraPlugin registerWithRegistrar:[registry registrarForPlugin:@"CameraPlugin"]]; [FPPDeviceInfoPlusPlugin registerWithRegistrar:[registry registrarForPlugin:@"FPPDeviceInfoPlusPlugin"]]; [FlutterSecureStoragePlugin registerWithRegistrar:[registry registrarForPlugin:@"FlutterSecureStoragePlugin"]]; [FlutterTtsPlugin registerWithRegistrar:[registry registrarForPlugin:@"FlutterTtsPlugin"]]; + [FLTImagePickerPlugin registerWithRegistrar:[registry registrarForPlugin:@"FLTImagePickerPlugin"]]; [FPPPackageInfoPlusPlugin registerWithRegistrar:[registry registrarForPlugin:@"FPPPackageInfoPlusPlugin"]]; [PathProviderPlugin registerWithRegistrar:[registry registrarForPlugin:@"PathProviderPlugin"]]; [PermissionHandlerPlugin registerWithRegistrar:[registry registrarForPlugin:@"PermissionHandlerPlugin"]]; diff --git a/examples/ios/RunAnywhereAI/RunAnywhereAI.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/examples/ios/RunAnywhereAI/RunAnywhereAI.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index 2e9eca2fc..3288057ed 100644 --- a/examples/ios/RunAnywhereAI/RunAnywhereAI.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/examples/ios/RunAnywhereAI/RunAnywhereAI.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -1,5 +1,5 @@ { - "originHash" : "3460a3a28d41ad2e683d917ec52c81d360e2e01615debdddb4f13b7b027f04df", + "originHash" : "070e7557fa52110bb9abeaaa337e8dc0e5977347ec91cbfb35b730d5956bf217", "pins" : [ { "identity" : "alamofire", @@ -10,15 +10,6 @@ "version" : "5.11.1" } }, - { - "identity" : "bitbytedata", - "kind" : "remoteSourceControl", - "location" : "https://github.com/tsolomko/BitByteData", - "state" : { - "revision" : "cdcdc5177ad536cfb11b95c620f926a81014b7fe", - "version" : "2.0.4" - } - }, { "identity" : "devicekit", "kind" : "remoteSourceControl", @@ -55,15 +46,6 @@ "version" : "8.58.0" } }, - { - "identity" : "swcompression", - "kind" : "remoteSourceControl", - "location" : "https://github.com/tsolomko/SWCompression.git", - "state" : { - "revision" : "390e0b0af8dd19a600005a242a89e570ff482e09", - "version" : "4.8.6" - } - }, { "identity" : "swift-argument-parser", "kind" : "remoteSourceControl", @@ -126,15 +108,6 @@ "revision" : "664e1b5a65296cd957dfdf262cd120ca88f3b24b", "version" : "0.15.0" } - }, - { - "identity" : "zipfoundation", - "kind" : "remoteSourceControl", - "location" : "https://github.com/weichsel/ZIPFoundation.git", - "state" : { - "revision" : "22787ffb59de99e5dc1fbfe80b19c97a904ad48d", - "version" : "0.9.20" - } } ], "version" : 3 diff --git a/examples/react-native/RunAnywhereAI/package-lock.json b/examples/react-native/RunAnywhereAI/package-lock.json index fba3fbdd2..76ad4b315 100644 --- a/examples/react-native/RunAnywhereAI/package-lock.json +++ b/examples/react-native/RunAnywhereAI/package-lock.json @@ -17,6 +17,7 @@ "@runanywhere/core": "file:../../../sdk/runanywhere-react-native/packages/core", "@runanywhere/llamacpp": "file:../../../sdk/runanywhere-react-native/packages/llamacpp", "@runanywhere/onnx": "file:../../../sdk/runanywhere-react-native/packages/onnx", + "@runanywhere/rag": "file:../../../sdk/runanywhere-react-native/packages/rag", "react": "19.2.0", "react-native": "0.83.1", "react-native-document-picker": "^9.3.1", @@ -76,8 +77,7 @@ "react-native-blob-util": ">=0.19.0", "react-native-device-info": ">=11.0.0", "react-native-fs": ">=2.20.0", - "react-native-nitro-modules": ">=0.31.3", - "react-native-zip-archive": ">=6.1.0" + "react-native-nitro-modules": ">=0.31.3" }, "peerDependenciesMeta": { "react-native-blob-util": { @@ -88,9 +88,6 @@ }, "react-native-fs": { "optional": true - }, - "react-native-zip-archive": { - "optional": true } } }, diff --git a/examples/react-native/RunAnywhereAI/yarn.lock b/examples/react-native/RunAnywhereAI/yarn.lock index 31f1866cb..394b429b3 100644 --- a/examples/react-native/RunAnywhereAI/yarn.lock +++ b/examples/react-native/RunAnywhereAI/yarn.lock @@ -1479,6 +1479,10 @@ version "0.18.1" resolved "file:../../../sdk/runanywhere-react-native/packages/onnx" +"@runanywhere/rag@file:../../../sdk/runanywhere-react-native/packages/rag": + version "0.1.0" + resolved "file:../../../sdk/runanywhere-react-native/packages/rag" + "@sideway/address@^4.1.5": version "4.1.5" resolved "https://registry.npmjs.org/@sideway/address/-/address-4.1.5.tgz" @@ -5120,13 +5124,6 @@ react-is@^19.1.0: resolved "https://registry.npmjs.org/react-is/-/react-is-19.2.4.tgz" integrity sha512-W+EWGn2v0ApPKgKKCy/7s7WHXkboGcsrXE+2joLyVxkbyVQfO3MUEaUQDHoSmb8TFFrSKYa9mw64WZHNHSDzYA== -react-native-document-picker@^9.3.1: - version "9.3.1" - resolved "https://registry.npmjs.org/react-native-document-picker/-/react-native-document-picker-9.3.1.tgz" - integrity sha512-Vcofv9wfB0j67zawFjfq9WQPMMzXxOZL9kBmvWDpjVuEcVK73ndRmlXHlkeFl5ZHVsv4Zb6oZYhqm9u5omJOPA== - dependencies: - invariant "^2.2.4" - react-native-fs@^2.20.0: version "2.20.0" resolved "https://registry.npmjs.org/react-native-fs/-/react-native-fs-2.20.0.tgz" @@ -5184,12 +5181,7 @@ react-native-vector-icons@^10.3.0: prop-types "^15.7.2" yargs "^16.1.1" -react-native-vision-camera@^4.7.3: - version "4.7.3" - resolved "https://registry.npmjs.org/react-native-vision-camera/-/react-native-vision-camera-4.7.3.tgz" - integrity sha512-g1/neOyjSqn1kaAa2FxI/qp5KzNvPcF0bnQw6NntfbxH6tm0+8WFZszlgb5OV+iYlB6lFUztCbDtyz5IpL47OA== - -react-native@*, "react-native@^0.0.0-0 || >=0.65 <1.0", "react-native@>= 0.61.5", react-native@>=0.70.0, react-native@0.83.1: +react-native@*, "react-native@^0.0.0-0 || >=0.65 <1.0", react-native@>=0.70.0, react-native@0.83.1: version "0.83.1" resolved "https://registry.npmjs.org/react-native/-/react-native-0.83.1.tgz" integrity sha512-mL1q5HPq5cWseVhWRLl+Fwvi5z1UO+3vGOpjr+sHFwcUletPRZ5Kv+d0tUfqHmvi73/53NjlQqX1Pyn4GguUfA== @@ -5235,7 +5227,7 @@ react-refresh@^0.14.0: resolved "https://registry.npmjs.org/react-refresh/-/react-refresh-0.14.2.tgz" integrity sha512-jCvmsr+1IUSMUyzOkRcvnVbX3ZYC6g9TDrDbFuFmRDq7PD4yaGbLKNQL6k2jnArV8hjYxh7hVhAZB6s9HDGpZA== -react@*, "react@^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0", react@^19.2.0, "react@>= 16.9.0", "react@>= 18.2.0", react@>=16.8, react@>=17.0.0, react@>=18.0.0, react@>=18.1.0, react@19.2.0: +react@*, "react@^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0", react@^19.2.0, "react@>= 18.2.0", react@>=16.8, react@>=17.0.0, react@>=18.0.0, react@>=18.1.0, react@19.2.0: version "19.2.0" resolved "https://registry.npmjs.org/react/-/react-19.2.0.tgz" integrity sha512-tmbWg6W31tQLeB5cdIBOicJDJRR2KzXsV7uSK9iNfLWQ5bIZfxuPEHp7M8wiHyHnn0DD1i7w3Zmin0FtkrwoCQ== diff --git a/sdk/runanywhere-commons/.clang-format b/sdk/runanywhere-commons/.clang-format index a9242e92c..7fb0c3f61 100644 --- a/sdk/runanywhere-commons/.clang-format +++ b/sdk/runanywhere-commons/.clang-format @@ -90,7 +90,7 @@ SpacesInSquareBrackets: false # C/C++ specific Language: Cpp -Standard: c++17 +Standard: c++20 # Penalties (for line breaking decisions) PenaltyBreakBeforeFirstCallParameter: 19 diff --git a/sdk/runanywhere-commons/CLAUDE.md b/sdk/runanywhere-commons/CLAUDE.md index 119d22ed3..7ddd9bbfa 100644 --- a/sdk/runanywhere-commons/CLAUDE.md +++ b/sdk/runanywhere-commons/CLAUDE.md @@ -11,7 +11,7 @@ ## C++ Specific Rules -- C++17 standard required +- C++20 standard required - Google C++ Style Guide with project customizations (see `.clang-format`) - Run `./scripts/lint-cpp.sh` before committing - Use `./scripts/lint-cpp.sh --fix` to auto-fix formatting issues diff --git a/sdk/runanywhere-commons/CMakeLists.txt b/sdk/runanywhere-commons/CMakeLists.txt index 6a89c2878..bb0ee984c 100644 --- a/sdk/runanywhere-commons/CMakeLists.txt +++ b/sdk/runanywhere-commons/CMakeLists.txt @@ -46,7 +46,7 @@ option(RAC_BUILD_SERVER "Build OpenAI-compatible HTTP server (runanywhere-server # C++ CONFIGURATION # ============================================================================= -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_POSITION_INDEPENDENT_CODE ON) @@ -119,6 +119,92 @@ FetchContent_MakeAvailable(nlohmann_json) # (nlohmann_json_SOURCE_DIR is set by FetchContent_MakeAvailable) include_directories(SYSTEM ${nlohmann_json_SOURCE_DIR}/include) +# libarchive - streaming archive extraction (ZIP, TAR.GZ, TAR.BZ2) +# Used for native model archive extraction across all platforms +if(NOT DEFINED LIBARCHIVE_VERSION) + set(LIBARCHIVE_VERSION "3.8.1") +endif() + +# ----------------------------------------------------------------------------- +# BZip2: Bundle from source for cross-compilation targets +# Android NDK and Emscripten don't ship libbz2. macOS/iOS have it in the SDK. +# We try system first; if not found, build from source so libarchive gets it. +# ----------------------------------------------------------------------------- +find_package(BZip2 QUIET) +if(NOT BZIP2_FOUND) + message(STATUS "System BZip2 not found — bundling from source for cross-compilation...") + FetchContent_Declare( + bzip2_src + URL https://sourceware.org/pub/bzip2/bzip2-1.0.8.tar.gz + URL_HASH SHA256=ab5a03176ee106d3f0fa90e381da478ddae405918153cca248e682cd0c4a2269 + ) + FetchContent_MakeAvailable(bzip2_src) + + add_library(bz2_bundled STATIC + ${bzip2_src_SOURCE_DIR}/blocksort.c + ${bzip2_src_SOURCE_DIR}/huffman.c + ${bzip2_src_SOURCE_DIR}/crctable.c + ${bzip2_src_SOURCE_DIR}/randtable.c + ${bzip2_src_SOURCE_DIR}/compress.c + ${bzip2_src_SOURCE_DIR}/decompress.c + ${bzip2_src_SOURCE_DIR}/bzlib.c + ) + target_include_directories(bz2_bundled PUBLIC ${bzip2_src_SOURCE_DIR}) + set_target_properties(bz2_bundled PROPERTIES POSITION_INDEPENDENT_CODE ON) + + # Set cache variables so libarchive's find_package(BZip2) picks up our build + set(BZIP2_INCLUDE_DIR "${bzip2_src_SOURCE_DIR}" CACHE PATH "" FORCE) + set(BZIP2_LIBRARIES bz2_bundled CACHE STRING "" FORCE) + set(BZIP2_FOUND TRUE CACHE BOOL "" FORCE) + message(STATUS "Bundled BZip2 ready (v1.0.8)") +else() + message(STATUS "Using system BZip2: ${BZIP2_LIBRARIES}") +endif() + +FetchContent_Declare( + libarchive + GIT_REPOSITORY https://github.com/libarchive/libarchive.git + GIT_TAG v${LIBARCHIVE_VERSION} + GIT_SHALLOW TRUE +) +# Disable everything except the static library and the formats we need +set(ENABLE_MBEDTLS OFF CACHE BOOL "" FORCE) +set(ENABLE_NETTLE OFF CACHE BOOL "" FORCE) +set(ENABLE_OPENSSL OFF CACHE BOOL "" FORCE) +set(ENABLE_LIBB2 OFF CACHE BOOL "" FORCE) +set(ENABLE_LZ4 OFF CACHE BOOL "" FORCE) +set(ENABLE_LZO OFF CACHE BOOL "" FORCE) +set(ENABLE_LZMA OFF CACHE BOOL "" FORCE) # tar.xz not currently used by any model +set(ENABLE_ZSTD OFF CACHE BOOL "" FORCE) +set(ENABLE_ZLIB ON CACHE BOOL "" FORCE) # Needed for tar.gz and zip +set(ENABLE_BZip2 ON CACHE BOOL "" FORCE) # Needed for tar.bz2 (k2-fsa models) +set(ENABLE_LIBXML2 OFF CACHE BOOL "" FORCE) +set(ENABLE_EXPAT OFF CACHE BOOL "" FORCE) +set(ENABLE_PCREPOSIX OFF CACHE BOOL "" FORCE) +set(ENABLE_PCRE2POSIX OFF CACHE BOOL "" FORCE) +set(ENABLE_LIBGCC OFF CACHE BOOL "" FORCE) +set(ENABLE_CNG OFF CACHE BOOL "" FORCE) +set(ENABLE_TAR OFF CACHE BOOL "" FORCE) # Don't build bsdtar binary +set(ENABLE_CPIO OFF CACHE BOOL "" FORCE) # Don't build bsdcpio binary +set(ENABLE_CAT OFF CACHE BOOL "" FORCE) # Don't build bsdcat binary +set(ENABLE_UNZIP OFF CACHE BOOL "" FORCE) # Don't build bsdunzip binary +set(ENABLE_TEST OFF CACHE BOOL "" FORCE) +set(ENABLE_INSTALL OFF CACHE BOOL "" FORCE) +set(ENABLE_ACL OFF CACHE BOOL "" FORCE) +set(ENABLE_XATTR OFF CACHE BOOL "" FORCE) +set(ENABLE_ICONV OFF CACHE BOOL "" FORCE) +# Save and restore BUILD_SHARED_LIBS since libarchive respects it +set(_SAVED_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) +set(BUILD_SHARED_LIBS OFF) +FetchContent_MakeAvailable(libarchive) +set(BUILD_SHARED_LIBS ${_SAVED_BUILD_SHARED_LIBS}) + +if(TARGET bz2_bundled AND TARGET archive_static) + target_link_libraries(archive_static bz2_bundled) +endif() + +message(STATUS "libarchive ready (v${LIBARCHIVE_VERSION})") + # ============================================================================= # SERVER DEPENDENCIES (FetchContent) # ============================================================================= @@ -205,6 +291,7 @@ set(RAC_INFRASTRUCTURE_SOURCES src/infrastructure/registry/module_registry.cpp src/infrastructure/registry/service_registry.cpp src/infrastructure/download/download_manager.cpp + src/infrastructure/download/download_orchestrator.cpp src/infrastructure/model_management/model_registry.cpp src/infrastructure/model_management/lora_registry.cpp src/infrastructure/model_management/model_types.cpp @@ -213,6 +300,7 @@ set(RAC_INFRASTRUCTURE_SOURCES src/infrastructure/model_management/model_assignment.cpp src/infrastructure/model_management/model_compatibility.cpp src/infrastructure/storage/storage_analyzer.cpp + src/infrastructure/file_management/file_manager.cpp src/infrastructure/network/environment.cpp src/infrastructure/network/endpoints.cpp src/infrastructure/network/api_types.cpp @@ -223,6 +311,7 @@ set(RAC_INFRASTRUCTURE_SOURCES src/infrastructure/telemetry/telemetry_json.cpp src/infrastructure/telemetry/telemetry_manager.cpp src/infrastructure/device/rac_device_manager.cpp + src/infrastructure/extraction/rac_extraction.cpp ) # Feature sources - LLM, STT, TTS, VAD, Wake Word, VLM, Diffusion (iOS/Apple only) @@ -330,6 +419,10 @@ if(RAC_BUILD_SHARED) ) endif() +# libarchive - native archive extraction +target_link_libraries(rac_commons PRIVATE archive_static) +target_include_directories(rac_commons PRIVATE ${libarchive_SOURCE_DIR}/libarchive ${libarchive_BINARY_DIR}) + # Platform-specific linking if(APPLE) target_link_libraries(rac_commons PUBLIC @@ -348,7 +441,7 @@ if(RAC_PLATFORM_ANDROID) endif() set_target_properties(rac_commons PROPERTIES - CXX_STANDARD 17 + CXX_STANDARD 20 CXX_STANDARD_REQUIRED ON CXX_EXTENSIONS OFF ) diff --git a/sdk/runanywhere-commons/README.md b/sdk/runanywhere-commons/README.md index 6de4272f7..1b604f592 100644 --- a/sdk/runanywhere-commons/README.md +++ b/sdk/runanywhere-commons/README.md @@ -32,7 +32,7 @@ RunAnywhere Commons is the shared C++ layer that sits between platform SDKs (Swi ### Design Principles -- **C++ Core, C API Surface** - C++17 internally, pure C API for FFI compatibility +- **C++ Core, C API Surface** - C++20 internally, pure C API for FFI compatibility - **Vtable-Based Polymorphism** - No C++ virtual inheritance at API boundaries - **Priority-Based Dispatch** - Service providers register with priority; first capable handler wins - **Lazy Initialization** - Services created on-demand, not at startup @@ -173,7 +173,7 @@ printf("Transcription: %s\n", result.text); ### Prerequisites - **CMake** 3.22 or higher -- **C++17** compatible compiler (Clang, GCC) +- **C++20** compatible compiler (Clang, GCC, MSVC) - **Platform-specific**: Xcode 15+ (iOS/macOS), Android NDK r25+ (Android) ### Quick Start diff --git a/sdk/runanywhere-commons/VERSIONS b/sdk/runanywhere-commons/VERSIONS index b848f2936..47c030882 100644 --- a/sdk/runanywhere-commons/VERSIONS +++ b/sdk/runanywhere-commons/VERSIONS @@ -81,6 +81,11 @@ LLAMACPP_VERSION=b8201 # ============================================================================= NLOHMANN_JSON_VERSION=3.11.3 +# ============================================================================= +# libarchive (archive extraction - ZIP, TAR.GZ, TAR.BZ2, TAR.XZ) +# ============================================================================= +LIBARCHIVE_VERSION=3.8.1 + # ============================================================================= # RAC Commons Version (for remote builds/CI) # ============================================================================= diff --git a/sdk/runanywhere-commons/exports/RACommons.exports b/sdk/runanywhere-commons/exports/RACommons.exports index 703aa6424..5ea2cf5c8 100644 --- a/sdk/runanywhere-commons/exports/RACommons.exports +++ b/sdk/runanywhere-commons/exports/RACommons.exports @@ -44,6 +44,8 @@ _rac_artifact_infer_from_url _rac_artifact_requires_download _rac_artifact_requires_extraction _rac_extract_archive +_rac_extract_archive_native +_rac_detect_archive_type # Component Types _rac_capability_resource_type_raw_value @@ -61,6 +63,8 @@ _rac_download_manager_get_active_tasks _rac_download_manager_get_progress _rac_download_manager_is_healthy _rac_download_manager_mark_complete +_rac_download_manager_mark_extraction_complete +_rac_download_manager_mark_extraction_failed _rac_download_manager_mark_failed _rac_download_manager_pause_all _rac_download_manager_resume_all @@ -70,9 +74,29 @@ _rac_download_stage_display_name _rac_download_stage_progress_range _rac_download_task_free _rac_download_task_ids_free +# Download Orchestrator +_rac_download_orchestrate +_rac_download_orchestrate_multi +_rac_download_compute_destination +_rac_download_requires_extraction +_rac_find_model_path_after_extraction _rac_http_download _rac_http_download_cancel +# File Manager +_rac_file_manager_cache_size +_rac_file_manager_calculate_dir_size +_rac_file_manager_check_storage +_rac_file_manager_clear_cache +_rac_file_manager_clear_directory +_rac_file_manager_clear_temp +_rac_file_manager_create_directory_structure +_rac_file_manager_create_model_folder +_rac_file_manager_delete_model +_rac_file_manager_get_storage_info +_rac_file_manager_model_folder_exists +_rac_file_manager_models_storage_used + # Events _rac_event_category_name _rac_event_publish diff --git a/sdk/runanywhere-commons/include/rac/core/capabilities/rac_lifecycle.h b/sdk/runanywhere-commons/include/rac/core/capabilities/rac_lifecycle.h index 273f45e50..75456e7e7 100644 --- a/sdk/runanywhere-commons/include/rac/core/capabilities/rac_lifecycle.h +++ b/sdk/runanywhere-commons/include/rac/core/capabilities/rac_lifecycle.h @@ -234,6 +234,25 @@ RAC_API rac_handle_t rac_lifecycle_get_service(rac_handle_t handle); */ RAC_API rac_result_t rac_lifecycle_require_service(rac_handle_t handle, rac_handle_t* out_service); +/** + * @brief Acquire (pin) the current service, preventing unload while held. + * + * Increments an internal refcount. The caller MUST call rac_lifecycle_release_service() + * when done. Unload/destroy will block until all acquired references are released. + * + * @param handle Lifecycle manager handle + * @param out_service Output: Service handle (pinned) + * @return RAC_SUCCESS or RAC_ERROR_NOT_INITIALIZED if not loaded + */ +RAC_API rac_result_t rac_lifecycle_acquire_service(rac_handle_t handle, rac_handle_t* out_service); + +/** + * @brief Release a previously acquired service reference. + * + * @param handle Lifecycle manager handle + */ +RAC_API void rac_lifecycle_release_service(rac_handle_t handle); + /** * @brief Track an operation error * diff --git a/sdk/runanywhere-commons/include/rac/core/rac_logger.h b/sdk/runanywhere-commons/include/rac/core/rac_logger.h index 3df0f86c3..2ccd69083 100644 --- a/sdk/runanywhere-commons/include/rac/core/rac_logger.h +++ b/sdk/runanywhere-commons/include/rac/core/rac_logger.h @@ -228,63 +228,85 @@ RAC_API void rac_logger_logv(rac_log_level_t level, const char* category, #endif // --- Level-specific logging macros with automatic source location --- +// Each macro checks the current min level BEFORE constructing metadata +// or calling the log function. This avoids function call overhead, metadata +// struct construction, and vsnprintf formatting for filtered messages. +// rac_logger_get_min_level() is an atomic read (no mutex). + +#define RAC_LOG_TRACE(category, ...) \ + do { \ + if (RAC_LOG_TRACE >= rac_logger_get_min_level()) { \ + rac_log_metadata_t _meta = RAC_LOG_META_HERE(); \ + rac_logger_logf(RAC_LOG_TRACE, category, &_meta, __VA_ARGS__); \ + } \ + } while (0) -#define RAC_LOG_TRACE(category, ...) \ - do { \ - rac_log_metadata_t _meta = RAC_LOG_META_HERE(); \ - rac_logger_logf(RAC_LOG_TRACE, category, &_meta, __VA_ARGS__); \ +#define RAC_LOG_DEBUG(category, ...) \ + do { \ + if (RAC_LOG_DEBUG >= rac_logger_get_min_level()) { \ + rac_log_metadata_t _meta = RAC_LOG_META_HERE(); \ + rac_logger_logf(RAC_LOG_DEBUG, category, &_meta, __VA_ARGS__); \ + } \ } while (0) -#define RAC_LOG_DEBUG(category, ...) \ +#define RAC_LOG_INFO(category, ...) \ do { \ - rac_log_metadata_t _meta = RAC_LOG_META_HERE(); \ - rac_logger_logf(RAC_LOG_DEBUG, category, &_meta, __VA_ARGS__); \ + if (RAC_LOG_INFO >= rac_logger_get_min_level()) { \ + rac_log_metadata_t _meta = RAC_LOG_META_HERE(); \ + rac_logger_logf(RAC_LOG_INFO, category, &_meta, __VA_ARGS__); \ + } \ } while (0) -#define RAC_LOG_INFO(category, ...) \ - do { \ - rac_log_metadata_t _meta = RAC_LOG_META_HERE(); \ - rac_logger_logf(RAC_LOG_INFO, category, &_meta, __VA_ARGS__); \ +#define RAC_LOG_WARNING(category, ...) \ + do { \ + if (RAC_LOG_WARNING >= rac_logger_get_min_level()) { \ + rac_log_metadata_t _meta = RAC_LOG_META_HERE(); \ + rac_logger_logf(RAC_LOG_WARNING, category, &_meta, __VA_ARGS__); \ + } \ } while (0) -#define RAC_LOG_WARNING(category, ...) \ - do { \ - rac_log_metadata_t _meta = RAC_LOG_META_HERE(); \ - rac_logger_logf(RAC_LOG_WARNING, category, &_meta, __VA_ARGS__); \ +#define RAC_LOG_ERROR(category, ...) \ + do { \ + if (RAC_LOG_ERROR >= rac_logger_get_min_level()) { \ + rac_log_metadata_t _meta = RAC_LOG_META_HERE(); \ + rac_logger_logf(RAC_LOG_ERROR, category, &_meta, __VA_ARGS__); \ + } \ } while (0) -#define RAC_LOG_ERROR(category, ...) \ - do { \ - rac_log_metadata_t _meta = RAC_LOG_META_HERE(); \ - rac_logger_logf(RAC_LOG_ERROR, category, &_meta, __VA_ARGS__); \ - } while (0) - -#define RAC_LOG_FATAL(category, ...) \ - do { \ - rac_log_metadata_t _meta = RAC_LOG_META_HERE(); \ - rac_logger_logf(RAC_LOG_FATAL, category, &_meta, __VA_ARGS__); \ +#define RAC_LOG_FATAL(category, ...) \ + do { \ + if (RAC_LOG_FATAL >= rac_logger_get_min_level()) { \ + rac_log_metadata_t _meta = RAC_LOG_META_HERE(); \ + rac_logger_logf(RAC_LOG_FATAL, category, &_meta, __VA_ARGS__); \ + } \ } while (0) // --- Error logging with code --- -#define RAC_LOG_ERROR_CODE(category, code, ...) \ - do { \ - rac_log_metadata_t _meta = RAC_LOG_META_ERROR(code, NULL); \ - rac_logger_logf(RAC_LOG_ERROR, category, &_meta, __VA_ARGS__); \ +#define RAC_LOG_ERROR_CODE(category, code, ...) \ + do { \ + if (RAC_LOG_ERROR >= rac_logger_get_min_level()) { \ + rac_log_metadata_t _meta = RAC_LOG_META_ERROR(code, NULL); \ + rac_logger_logf(RAC_LOG_ERROR, category, &_meta, __VA_ARGS__); \ + } \ } while (0) // --- Model context logging --- -#define RAC_LOG_MODEL_INFO(category, model_id, framework, ...) \ - do { \ - rac_log_metadata_t _meta = RAC_LOG_META_MODEL(model_id, framework); \ - rac_logger_logf(RAC_LOG_INFO, category, &_meta, __VA_ARGS__); \ +#define RAC_LOG_MODEL_INFO(category, model_id, framework, ...) \ + do { \ + if (RAC_LOG_INFO >= rac_logger_get_min_level()) { \ + rac_log_metadata_t _meta = RAC_LOG_META_MODEL(model_id, framework); \ + rac_logger_logf(RAC_LOG_INFO, category, &_meta, __VA_ARGS__); \ + } \ } while (0) -#define RAC_LOG_MODEL_ERROR(category, model_id, framework, ...) \ - do { \ - rac_log_metadata_t _meta = RAC_LOG_META_MODEL(model_id, framework); \ - rac_logger_logf(RAC_LOG_ERROR, category, &_meta, __VA_ARGS__); \ +#define RAC_LOG_MODEL_ERROR(category, model_id, framework, ...) \ + do { \ + if (RAC_LOG_ERROR >= rac_logger_get_min_level()) { \ + rac_log_metadata_t _meta = RAC_LOG_META_MODEL(model_id, framework); \ + rac_logger_logf(RAC_LOG_ERROR, category, &_meta, __VA_ARGS__); \ + } \ } while (0) // ============================================================================= @@ -341,74 +363,83 @@ namespace rac { class Logger { public: explicit Logger(const char* category) : category_(category) {} - explicit Logger(const std::string& category) : category_(category.c_str()) {} + explicit Logger(const std::string& category) : category_(category) {} void trace(const char* format, ...) const { + if (RAC_LOG_TRACE < rac_logger_get_min_level()) return; va_list args; va_start(args, format); - rac_logger_logv(RAC_LOG_TRACE, category_, nullptr, format, args); + rac_logger_logv(RAC_LOG_TRACE, category_.c_str(), nullptr, format, args); va_end(args); } void debug(const char* format, ...) const { + if (RAC_LOG_DEBUG < rac_logger_get_min_level()) return; va_list args; va_start(args, format); - rac_logger_logv(RAC_LOG_DEBUG, category_, nullptr, format, args); + rac_logger_logv(RAC_LOG_DEBUG, category_.c_str(), nullptr, format, args); va_end(args); } void info(const char* format, ...) const { + if (RAC_LOG_INFO < rac_logger_get_min_level()) return; va_list args; va_start(args, format); - rac_logger_logv(RAC_LOG_INFO, category_, nullptr, format, args); + rac_logger_logv(RAC_LOG_INFO, category_.c_str(), nullptr, format, args); va_end(args); } void warning(const char* format, ...) const { + if (RAC_LOG_WARNING < rac_logger_get_min_level()) return; va_list args; va_start(args, format); - rac_logger_logv(RAC_LOG_WARNING, category_, nullptr, format, args); + rac_logger_logv(RAC_LOG_WARNING, category_.c_str(), nullptr, format, args); va_end(args); } void error(const char* format, ...) const { + if (RAC_LOG_ERROR < rac_logger_get_min_level()) return; va_list args; va_start(args, format); - rac_logger_logv(RAC_LOG_ERROR, category_, nullptr, format, args); + rac_logger_logv(RAC_LOG_ERROR, category_.c_str(), nullptr, format, args); va_end(args); } void error(int32_t code, const char* format, ...) const { + if (RAC_LOG_ERROR < rac_logger_get_min_level()) return; rac_log_metadata_t meta = RAC_LOG_METADATA_EMPTY; meta.error_code = code; va_list args; va_start(args, format); - rac_logger_logv(RAC_LOG_ERROR, category_, &meta, format, args); + rac_logger_logv(RAC_LOG_ERROR, category_.c_str(), &meta, format, args); va_end(args); } void fatal(const char* format, ...) const { + // Fatal is always logged — no early exit va_list args; va_start(args, format); - rac_logger_logv(RAC_LOG_FATAL, category_, nullptr, format, args); + rac_logger_logv(RAC_LOG_FATAL, category_.c_str(), nullptr, format, args); va_end(args); } // Log with model context void modelInfo(const char* model_id, const char* framework, const char* format, ...) const { + if (RAC_LOG_INFO < rac_logger_get_min_level()) return; rac_log_metadata_t meta = RAC_LOG_METADATA_EMPTY; meta.model_id = model_id; meta.framework = framework; va_list args; va_start(args, format); - rac_logger_logv(RAC_LOG_INFO, category_, &meta, format, args); + rac_logger_logv(RAC_LOG_INFO, category_.c_str(), &meta, format, args); va_end(args); } void modelError(const char* model_id, const char* framework, int32_t code, const char* format, ...) const { + if (RAC_LOG_ERROR < rac_logger_get_min_level()) return; rac_log_metadata_t meta = RAC_LOG_METADATA_EMPTY; meta.model_id = model_id; meta.framework = framework; @@ -416,12 +447,12 @@ class Logger { va_list args; va_start(args, format); - rac_logger_logv(RAC_LOG_ERROR, category_, &meta, format, args); + rac_logger_logv(RAC_LOG_ERROR, category_.c_str(), &meta, format, args); va_end(args); } private: - const char* category_; + std::string category_; }; // Predefined loggers for common categories diff --git a/sdk/runanywhere-commons/include/rac/infrastructure/download/rac_download.h b/sdk/runanywhere-commons/include/rac/infrastructure/download/rac_download.h index fc9bdffef..9f092f1e6 100644 --- a/sdk/runanywhere-commons/include/rac/infrastructure/download/rac_download.h +++ b/sdk/runanywhere-commons/include/rac/infrastructure/download/rac_download.h @@ -392,6 +392,39 @@ RAC_API rac_result_t rac_download_manager_mark_failed(rac_download_manager_handl const char* task_id, rac_result_t error_code, const char* error_message); +// ============================================================================= +// EXTRACTION COMPLETION API +// ============================================================================= + +/** + * @brief Mark extraction as completed for a download task. + * + * Called after archive extraction succeeds. Transitions the task + * from EXTRACTING to COMPLETED state. + * + * @param handle Manager handle + * @param task_id Task ID + * @param extracted_path Path to the extracted model directory + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_download_manager_mark_extraction_complete( + rac_download_manager_handle_t handle, const char* task_id, const char* extracted_path); + +/** + * @brief Mark extraction as failed for a download task. + * + * Called if archive extraction fails. + * + * @param handle Manager handle + * @param task_id Task ID + * @param error_code Extraction error code + * @param error_message Error description (can be NULL) + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_download_manager_mark_extraction_failed( + rac_download_manager_handle_t handle, const char* task_id, rac_result_t error_code, + const char* error_message); + // ============================================================================= // MEMORY MANAGEMENT // ============================================================================= diff --git a/sdk/runanywhere-commons/include/rac/infrastructure/download/rac_download_orchestrator.h b/sdk/runanywhere-commons/include/rac/infrastructure/download/rac_download_orchestrator.h new file mode 100644 index 000000000..ce957d376 --- /dev/null +++ b/sdk/runanywhere-commons/include/rac/infrastructure/download/rac_download_orchestrator.h @@ -0,0 +1,166 @@ +/** + * @file rac_download_orchestrator.h + * @brief Download Orchestrator - High-Level Model Download Lifecycle Management + * + * Consolidates download business logic from all platform SDKs into C++. + * Handles the full download lifecycle: path resolution, extraction detection, + * HTTP download (via platform adapter), post-download extraction, model path + * finding, registry updates, and archive cleanup. + * + * HTTP transport remains platform-specific via rac_platform_adapter_t.http_download. + * This layer handles ALL orchestration logic so each SDK reduces to: + * 1. Register http_download callback + * 2. Call rac_download_orchestrate() + * 3. Wrap result in SDK types + * + * Depends on: + * - rac_download.h (download manager state machine, progress tracking) + * - rac_platform_adapter.h (http_download callback for HTTP transport) + * - rac_extraction.h (rac_extract_archive_native for archive extraction) + * - rac_model_paths.h (destination path resolution) + * - rac_model_types.h (model types, archive types, frameworks) + */ + +#ifndef RAC_DOWNLOAD_ORCHESTRATOR_H +#define RAC_DOWNLOAD_ORCHESTRATOR_H + +#include "rac/core/rac_error.h" +#include "rac/core/rac_types.h" +#include "rac/infrastructure/download/rac_download.h" +#include "rac/infrastructure/model_management/rac_model_types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// ============================================================================= +// DOWNLOAD ORCHESTRATION - Full Lifecycle Model Download +// ============================================================================= + +/** + * @brief Orchestrate a single-file model download with full lifecycle management. + * + * This is the main entry point for downloading a model. It handles: + * 1. Compute destination folder via rac_model_paths_get_model_folder() + * 2. Detect if extraction is needed via rac_archive_type_from_path() + * 3. Download to temp path if extraction needed, else download to model folder + * 4. Invoke platform http_download via rac_http_download() + * 5. On HTTP completion: extract if needed, find model path, cleanup archive + * 6. Update download manager state (DOWNLOADING → EXTRACTING → COMPLETED) + * 7. Invoke user callbacks with final model path + * + * @param dm_handle Download manager handle (for state tracking) + * @param model_id Model identifier (used for folder naming and registry) + * @param download_url URL to download from + * @param framework Inference framework (determines storage directory) + * @param format Model format (determines file extension and path finding) + * @param archive_structure Archive structure hint (used for post-extraction path finding) + * @param progress_callback Progress updates across all stages (can be NULL) + * @param complete_callback Called when entire lifecycle completes or fails + * @param user_data User context passed to callbacks + * @param out_task_id Output: Task ID for tracking/cancellation (owned, free with rac_free) + * @return RAC_SUCCESS if download started, error code if failed to start + */ +RAC_API rac_result_t rac_download_orchestrate( + rac_download_manager_handle_t dm_handle, const char* model_id, const char* download_url, + rac_inference_framework_t framework, rac_model_format_t format, + rac_archive_structure_t archive_structure, + rac_download_progress_callback_fn progress_callback, + rac_download_complete_callback_fn complete_callback, void* user_data, char** out_task_id); + +/** + * @brief Orchestrate a multi-file model download (e.g., VLM with companion files). + * + * Downloads multiple files sequentially into the same model folder. + * Progress is distributed across all files proportionally. + * Extraction is applied to each file individually if needed. + * + * @param dm_handle Download manager handle (for state tracking) + * @param model_id Model identifier + * @param files Array of file descriptors (relative_path, destination_path, is_required) + * @param file_count Number of files to download + * @param base_download_url Base URL — file relative_path is appended to this + * @param framework Inference framework + * @param format Model format + * @param progress_callback Progress updates across all files and stages (can be NULL) + * @param complete_callback Called when all files complete or any required file fails + * @param user_data User context passed to callbacks + * @param out_task_id Output: Task ID for tracking/cancellation (owned, free with rac_free) + * @return RAC_SUCCESS if download started, error code if failed to start + */ +RAC_API rac_result_t rac_download_orchestrate_multi( + rac_download_manager_handle_t dm_handle, const char* model_id, + const rac_model_file_descriptor_t* files, size_t file_count, const char* base_download_url, + rac_inference_framework_t framework, rac_model_format_t format, + rac_download_progress_callback_fn progress_callback, + rac_download_complete_callback_fn complete_callback, void* user_data, char** out_task_id); + +// ============================================================================= +// POST-EXTRACTION MODEL PATH FINDING +// ============================================================================= + +/** + * @brief Find the actual model path after extraction. + * + * Consolidates duplicated Swift/Kotlin logic for scanning extracted directories: + * - Finds .gguf, .onnx, .ort, .bin files + * - Handles nested directories (e.g., sherpa-onnx archives with subdirectory) + * - Handles single-file-nested pattern (model file inside one subdirectory) + * - Returns the directory itself for directory-based models (ONNX) + * + * Uses POSIX opendir/readdir for cross-platform compatibility (iOS/Android/Linux/macOS). + * + * @param extracted_dir Directory where archive was extracted + * @param structure Archive structure hint (SINGLE_FILE_NESTED, NESTED_DIRECTORY, etc.) + * @param framework Inference framework (used to determine if directory-based) + * @param format Model format (used to determine expected file extensions) + * @param out_path Output buffer for the found model path + * @param path_size Size of output buffer + * @return RAC_SUCCESS if model path found, RAC_ERROR_NOT_FOUND if no model file found + */ +RAC_API rac_result_t rac_find_model_path_after_extraction( + const char* extracted_dir, rac_archive_structure_t structure, + rac_inference_framework_t framework, rac_model_format_t format, char* out_path, + size_t path_size); + +// ============================================================================= +// UTILITY FUNCTIONS +// ============================================================================= + +/** + * @brief Compute the download destination path for a model. + * + * If extraction is needed: returns a temp path in the downloads directory. + * If no extraction: returns the final model folder path. + * + * @param model_id Model identifier + * @param download_url URL to download (used for archive detection and extension) + * @param framework Inference framework + * @param format Model format + * @param out_path Output buffer for destination path + * @param path_size Size of output buffer + * @param out_needs_extraction Output: RAC_TRUE if download needs extraction + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_download_compute_destination(const char* model_id, + const char* download_url, + rac_inference_framework_t framework, + rac_model_format_t format, char* out_path, + size_t path_size, + rac_bool_t* out_needs_extraction); + +/** + * @brief Check if a download URL requires extraction. + * + * Convenience wrapper around rac_archive_type_from_path(). + * + * @param download_url URL to check + * @return RAC_TRUE if URL points to an archive that needs extraction + */ +RAC_API rac_bool_t rac_download_requires_extraction(const char* download_url); + +#ifdef __cplusplus +} +#endif + +#endif /* RAC_DOWNLOAD_ORCHESTRATOR_H */ diff --git a/sdk/runanywhere-commons/include/rac/infrastructure/extraction/rac_extraction.h b/sdk/runanywhere-commons/include/rac/infrastructure/extraction/rac_extraction.h new file mode 100644 index 000000000..fd9b0ea45 --- /dev/null +++ b/sdk/runanywhere-commons/include/rac/infrastructure/extraction/rac_extraction.h @@ -0,0 +1,149 @@ +/** + * @file rac_extraction.h + * @brief RunAnywhere Commons - Native Archive Extraction + * + * Native archive extraction using libarchive. + * Supports ZIP, TAR.GZ, TAR.BZ2, TAR.XZ with streaming extraction + * (constant memory usage regardless of archive size). + * + * Security features: + * - Zip-slip protection (path traversal prevention) + * - macOS resource fork skipping (._files, __MACOSX/) + * - Symbolic link safety (contained within destination) + * - Archive type auto-detection via magic bytes + */ + +#ifndef RAC_EXTRACTION_H +#define RAC_EXTRACTION_H + +#include "rac/core/rac_error.h" +#include "rac/core/rac_types.h" +#include "rac/infrastructure/model_management/rac_model_types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// ============================================================================= +// EXTRACTION OPTIONS +// ============================================================================= + +/** + * @brief Options for archive extraction. + */ +typedef struct rac_extraction_options { + /** Skip macOS resource forks (._ files, __MACOSX/ directories). + * Default: RAC_TRUE */ + rac_bool_t skip_macos_resources; + + /** Skip symbolic links entirely. + * Default: RAC_FALSE (symlinks are created if safe) */ + rac_bool_t skip_symlinks; + + /** Archive type hint. RAC_ARCHIVE_TYPE_NONE = auto-detect from magic bytes. + * Default: RAC_ARCHIVE_TYPE_NONE */ + rac_archive_type_t archive_type_hint; +} rac_extraction_options_t; + +/** + * @brief Default extraction options. + */ +#ifdef __cplusplus +inline constexpr rac_extraction_options_t RAC_EXTRACTION_OPTIONS_DEFAULT = { + RAC_TRUE, /* skip_macos_resources */ + RAC_FALSE, /* skip_symlinks */ + RAC_ARCHIVE_TYPE_NONE /* archive_type_hint */ +}; +#else +static const rac_extraction_options_t RAC_EXTRACTION_OPTIONS_DEFAULT = { + RAC_TRUE, /* skip_macos_resources */ + RAC_FALSE, /* skip_symlinks */ + RAC_ARCHIVE_TYPE_NONE /* archive_type_hint */ +}; +#endif + +// ============================================================================= +// EXTRACTION RESULT +// ============================================================================= + +/** + * @brief Result of an extraction operation. + */ +typedef struct rac_extraction_result { + /** Number of files extracted */ + int32_t files_extracted; + + /** Number of directories created */ + int32_t directories_created; + + /** Total bytes written to disk */ + int64_t bytes_extracted; + + /** Number of entries skipped (resource forks, unsafe paths) */ + int32_t entries_skipped; +} rac_extraction_result_t; + +// ============================================================================= +// EXTRACTION PROGRESS CALLBACK +// ============================================================================= + +/** + * @brief Progress callback for extraction. + * + * @param files_extracted Number of files extracted so far + * @param total_files Total files in archive (0 if unknown for streaming formats) + * @param bytes_extracted Bytes written to disk so far + * @param user_data User-provided context + */ +typedef void (*rac_extraction_progress_fn)(int32_t files_extracted, int32_t total_files, + int64_t bytes_extracted, void* user_data); + +// ============================================================================= +// EXTRACTION API +// ============================================================================= + +/** + * @brief Extract an archive using native libarchive. + * + * Performs streaming extraction with constant memory usage. + * Auto-detects archive format from magic bytes if archive_type_hint + * is RAC_ARCHIVE_TYPE_NONE. + * + * @param archive_path Path to the archive file + * @param destination_dir Directory to extract into (created if needed) + * @param options Extraction options (NULL for defaults) + * @param progress_callback Progress callback (can be NULL) + * @param user_data Context for progress callback + * @param out_result Output: extraction statistics (can be NULL) + * @return RAC_SUCCESS on success, error code on failure + * + * Error codes: + * - RAC_ERROR_EXTRACTION_FAILED: General extraction error + * - RAC_ERROR_UNSUPPORTED_ARCHIVE: Unrecognized archive format + * - RAC_ERROR_FILE_NOT_FOUND: Archive file does not exist + * - RAC_ERROR_NULL_POINTER: archive_path or destination_dir is NULL + */ +RAC_API rac_result_t rac_extract_archive_native(const char* archive_path, + const char* destination_dir, + const rac_extraction_options_t* options, + rac_extraction_progress_fn progress_callback, + void* user_data, + rac_extraction_result_t* out_result); + +/** + * @brief Detect archive type from file magic bytes. + * + * Reads the first few bytes of the file to determine the archive format. + * More reliable than file extension detection. + * + * @param file_path Path to the file + * @param out_type Output: detected archive type + * @return RAC_TRUE if archive type detected, RAC_FALSE otherwise + */ +RAC_API rac_bool_t rac_detect_archive_type(const char* file_path, rac_archive_type_t* out_type); + +#ifdef __cplusplus +} +#endif + +#endif /* RAC_EXTRACTION_H */ diff --git a/sdk/runanywhere-commons/include/rac/infrastructure/file_management/rac_file_manager.h b/sdk/runanywhere-commons/include/rac/infrastructure/file_management/rac_file_manager.h new file mode 100644 index 000000000..d77430c3f --- /dev/null +++ b/sdk/runanywhere-commons/include/rac/infrastructure/file_management/rac_file_manager.h @@ -0,0 +1,359 @@ +/** + * @file rac_file_manager.h + * @brief File Manager - Centralized File Management Business Logic + * + * Consolidates common file management operations that were duplicated + * across SDKs (Swift, Kotlin, Flutter, React Native): + * - Directory size calculation (recursive traversal) + * - Directory structure creation (Models/Cache/Temp/Downloads) + * - Cache and temp cleanup + * - Model folder management (create, delete, check existence) + * - Storage availability checking + * + * Platform-specific file I/O is provided via callbacks (rac_file_callbacks_t). + * C++ handles all business logic; SDKs only provide thin I/O implementations. + * + * Uses rac_model_paths for path computation. + */ + +#ifndef RAC_FILE_MANAGER_H +#define RAC_FILE_MANAGER_H + +#include +#include + +#include "rac/core/rac_error.h" +#include "rac/core/rac_types.h" +#include "rac/infrastructure/model_management/rac_model_types.h" +#include "rac/infrastructure/storage/rac_storage_analyzer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// ============================================================================= +// PLATFORM I/O CALLBACKS +// ============================================================================= + +/** + * @brief Platform-specific file I/O callbacks. + * + * SDKs implement these thin wrappers around native file operations. + * C++ business logic calls these for all file system access. + */ +typedef struct { + /** + * Create a directory (optionally recursive). + * @param path Directory path to create + * @param recursive If non-zero, create intermediate directories + * @param user_data Platform context + * @return RAC_SUCCESS or error code + */ + rac_result_t (*create_directory)(const char* path, int recursive, void* user_data); + + /** + * Delete a file or directory. + * @param path Path to delete + * @param recursive If non-zero, delete directory contents recursively + * @param user_data Platform context + * @return RAC_SUCCESS or error code + */ + rac_result_t (*delete_path)(const char* path, int recursive, void* user_data); + + /** + * List directory contents (entry names only, not full paths). + * @param path Directory path + * @param out_entries Output: Array of entry name strings (allocated by callback) + * @param out_count Output: Number of entries + * @param user_data Platform context + * @return RAC_SUCCESS or error code + */ + rac_result_t (*list_directory)(const char* path, char*** out_entries, size_t* out_count, + void* user_data); + + /** + * Free directory entries returned by list_directory. + * @param entries Array of entry names + * @param count Number of entries + * @param user_data Platform context + */ + void (*free_entries)(char** entries, size_t count, void* user_data); + + /** + * Check if a path exists. + * @param path Path to check + * @param out_is_directory Output: non-zero if path is a directory (can be NULL) + * @param user_data Platform context + * @return RAC_TRUE if exists, RAC_FALSE otherwise + */ + rac_bool_t (*path_exists)(const char* path, rac_bool_t* out_is_directory, void* user_data); + + /** + * Get file size in bytes. + * @param path File path + * @param user_data Platform context + * @return File size in bytes, or -1 on error + */ + int64_t (*get_file_size)(const char* path, void* user_data); + + /** + * Get available disk space in bytes. + * @param user_data Platform context + * @return Available space in bytes, or -1 on error + */ + int64_t (*get_available_space)(void* user_data); + + /** + * Get total disk space in bytes. + * @param user_data Platform context + * @return Total space in bytes, or -1 on error + */ + int64_t (*get_total_space)(void* user_data); + + /** Platform-specific context passed to all callbacks */ + void* user_data; +} rac_file_callbacks_t; + +// ============================================================================= +// DATA STRUCTURES +// ============================================================================= + +/** + * @brief Combined storage information. + * + * Aggregates device storage, app storage (models/cache/temp), and + * computed totals. Replaces per-SDK storage info structs. + */ +typedef struct { + /** Total device storage in bytes */ + int64_t device_total; + /** Free device storage in bytes */ + int64_t device_free; + /** Total models directory size in bytes */ + int64_t models_size; + /** Cache directory size in bytes */ + int64_t cache_size; + /** Temp directory size in bytes */ + int64_t temp_size; + /** Total app storage (models + cache + temp) */ + int64_t total_app_size; +} rac_file_manager_storage_info_t; + +// ============================================================================= +// DIRECTORY STRUCTURE +// ============================================================================= + +/** + * @brief Create the standard directory structure under the base directory. + * + * Creates: Models/, Cache/, Temp/, Downloads/ under {base_dir}/RunAnywhere/ + * Uses rac_model_paths for path computation. + * + * Replaces: + * - Swift: SimplifiedFileManager.createDirectoryStructure() + * - Kotlin: SharedFileSystem directory creation + * - Flutter: SimplifiedFileManager._createDirectoryStructure() + * + * @param cb Platform I/O callbacks + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_create_directory_structure(const rac_file_callbacks_t* cb); + +// ============================================================================= +// MODEL FOLDER MANAGEMENT +// ============================================================================= + +/** + * @brief Create a model folder and return its path. + * + * Creates: {base_dir}/RunAnywhere/Models/{framework}/{modelId}/ + * + * @param cb Platform I/O callbacks + * @param model_id Model identifier + * @param framework Inference framework + * @param out_path Output buffer for the created folder path + * @param path_size Size of output buffer + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_create_model_folder(const rac_file_callbacks_t* cb, + const char* model_id, + rac_inference_framework_t framework, + char* out_path, size_t path_size); + +/** + * @brief Check if a model folder exists and optionally if it has contents. + * + * @param cb Platform I/O callbacks + * @param model_id Model identifier + * @param framework Inference framework + * @param out_exists Output: RAC_TRUE if folder exists + * @param out_has_contents Output: RAC_TRUE if folder has files (can be NULL) + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_model_folder_exists(const rac_file_callbacks_t* cb, + const char* model_id, + rac_inference_framework_t framework, + rac_bool_t* out_exists, + rac_bool_t* out_has_contents); + +/** + * @brief Delete a model folder recursively. + * + * Replaces: + * - Swift: SimplifiedFileManager.deleteModel(modelId:framework:) + * - Flutter: SimplifiedFileManager.deleteModelFolder() + * + * @param cb Platform I/O callbacks + * @param model_id Model identifier + * @param framework Inference framework + * @return RAC_SUCCESS, or RAC_ERROR_FILE_NOT_FOUND if folder doesn't exist + */ +RAC_API rac_result_t rac_file_manager_delete_model(const rac_file_callbacks_t* cb, + const char* model_id, + rac_inference_framework_t framework); + +// ============================================================================= +// DIRECTORY SIZE CALCULATION +// ============================================================================= + +/** + * @brief Calculate directory size recursively. + * + * Traverses the directory tree using callbacks, summing file sizes. + * This is the core duplicated logic across all SDKs. + * + * Replaces: + * - Swift: SimplifiedFileManager.calculateDirectorySize(at:) + * - Kotlin: calculateDirectorySize(directory:) + * - Flutter: SimplifiedFileManager.calculateModelsSize() + * - RN: FileSystem.getDirectorySize() + * + * @param cb Platform I/O callbacks + * @param path Directory path to measure + * @param out_size Output: Total size in bytes + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_calculate_dir_size(const rac_file_callbacks_t* cb, + const char* path, int64_t* out_size); + +/** + * @brief Get total models directory storage used. + * + * Convenience wrapper: calculates size of the models directory. + * + * @param cb Platform I/O callbacks + * @param out_size Output: Total models size in bytes + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_models_storage_used(const rac_file_callbacks_t* cb, + int64_t* out_size); + +// ============================================================================= +// CACHE & TEMP MANAGEMENT +// ============================================================================= + +/** + * @brief Clear the cache directory. + * + * Deletes all files and subdirectories in the cache directory, + * then recreates the empty cache directory. + * + * Replaces: + * - Swift: SimplifiedFileManager.clearCache() + * - Kotlin: RunAnywhere.clearCache() + * - Flutter: SimplifiedFileManager.clearCache() + * + * @param cb Platform I/O callbacks + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_clear_cache(const rac_file_callbacks_t* cb); + +/** + * @brief Clear the temp directory. + * + * Deletes all files and subdirectories in the temp directory, + * then recreates the empty temp directory. + * + * Replaces: + * - Swift: SimplifiedFileManager.cleanTempFiles() + * - Flutter: SimplifiedFileManager.clearTemp() + * + * @param cb Platform I/O callbacks + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_clear_temp(const rac_file_callbacks_t* cb); + +/** + * @brief Get the cache directory size. + * + * @param cb Platform I/O callbacks + * @param out_size Output: Cache size in bytes + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_cache_size(const rac_file_callbacks_t* cb, + int64_t* out_size); + +// ============================================================================= +// STORAGE INFO +// ============================================================================= + +/** + * @brief Get combined storage information. + * + * Calculates device storage, models size, cache size, and temp size + * in a single call. + * + * Replaces: + * - Swift: SimplifiedFileManager.getDeviceStorageInfo() + getAvailableSpace() + * - Kotlin: RunAnywhere.storageInfo() + * - Flutter: SimplifiedFileManager.getDeviceStorageInfo() + * + * @param cb Platform I/O callbacks + * @param out_info Output: Storage information + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_get_storage_info(const rac_file_callbacks_t* cb, + rac_file_manager_storage_info_t* out_info); + +/** + * @brief Check storage availability for a download. + * + * Checks if enough space is available and warns if remaining + * space would be below 1GB after the operation. + * + * Replaces: + * - Kotlin: RunAnywhere.checkStorageAvailability(requiredBytes:) + * - Swift: storage availability logic in download flow + * + * @param cb Platform I/O callbacks + * @param required_bytes Space needed in bytes + * @param out_availability Output: Availability result (uses rac_storage_availability_t + * from rac_storage_analyzer.h) + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_check_storage( + const rac_file_callbacks_t* cb, int64_t required_bytes, + rac_storage_availability_t* out_availability); + +// ============================================================================= +// DIRECTORY CLEARING (INTERNAL HELPER) +// ============================================================================= + +/** + * @brief Clear all contents of a directory (delete + recreate). + * + * Useful for clearing any directory. Used internally by + * rac_file_manager_clear_cache() and rac_file_manager_clear_temp(). + * + * @param cb Platform I/O callbacks + * @param path Directory path to clear + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_clear_directory(const rac_file_callbacks_t* cb, + const char* path); + +#ifdef __cplusplus +} +#endif + +#endif /* RAC_FILE_MANAGER_H */ diff --git a/sdk/runanywhere-commons/src/backends/llamacpp/CMakeLists.txt b/sdk/runanywhere-commons/src/backends/llamacpp/CMakeLists.txt index fcca88adb..92f934451 100644 --- a/sdk/runanywhere-commons/src/backends/llamacpp/CMakeLists.txt +++ b/sdk/runanywhere-commons/src/backends/llamacpp/CMakeLists.txt @@ -205,7 +205,7 @@ target_link_libraries(rac_backend_llamacpp PUBLIC ) set_target_properties(rac_backend_llamacpp PROPERTIES - CXX_STANDARD 17 + CXX_STANDARD 20 CXX_STANDARD_REQUIRED ON CXX_EXTENSIONS OFF ) diff --git a/sdk/runanywhere-commons/src/backends/llamacpp/jni/rac_backend_llamacpp_jni.cpp b/sdk/runanywhere-commons/src/backends/llamacpp/jni/rac_backend_llamacpp_jni.cpp index 2c3986d00..56626c4aa 100644 --- a/sdk/runanywhere-commons/src/backends/llamacpp/jni/rac_backend_llamacpp_jni.cpp +++ b/sdk/runanywhere-commons/src/backends/llamacpp/jni/rac_backend_llamacpp_jni.cpp @@ -15,25 +15,19 @@ #include #include -#ifdef __ANDROID__ -#include -#define TAG "RACLlamaCPPJNI" -#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__) -#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__) -#define LOGw(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__) -#else -#include -#define LOGi(...) fprintf(stdout, "[INFO] " __VA_ARGS__); fprintf(stdout, "\n") -#define LOGe(...) fprintf(stderr, "[ERROR] " __VA_ARGS__); fprintf(stderr, "\n") -#define LOGw(...) fprintf(stdout, "[WARN] " __VA_ARGS__); fprintf(stdout, "\n") -#endif - // Include LlamaCPP backend header (direct API) #include "rac_llm_llamacpp.h" // Include commons for registration and service lookup #include "rac/core/rac_core.h" #include "rac/core/rac_error.h" +#include "rac/core/rac_logger.h" + +// Route JNI logging through unified RAC_LOG_* system +static const char* LOG_TAG = "JNI.LlamaCpp"; +#define LOGi(...) RAC_LOG_INFO(LOG_TAG, __VA_ARGS__) +#define LOGe(...) RAC_LOG_ERROR(LOG_TAG, __VA_ARGS__) +#define LOGw(...) RAC_LOG_WARNING(LOG_TAG, __VA_ARGS__) // Forward declaration for registration functions extern "C" rac_result_t rac_backend_llamacpp_register(void); diff --git a/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.cpp b/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.cpp index 76218a7ae..646503868 100644 --- a/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.cpp +++ b/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.cpp @@ -10,9 +10,40 @@ #include "rac/core/rac_logger.h" -// Use the RAC logging system -#define LOGI(...) RAC_LOG_INFO("LLM.LlamaCpp", __VA_ARGS__) -#define LOGE(...) RAC_LOG_ERROR("LLM.LlamaCpp", __VA_ARGS__) +// ============================================================================= +// NAMED CONSTANTS +// ============================================================================= + +namespace { + +// Thread configuration +static constexpr int kMinThreads = 1; +static constexpr int kMaxThreads = 8; +static constexpr int kReservedCores = 2; +static constexpr int kDefaultThreads = 4; + +// GPU layer limiting for large models on mobile devices +static constexpr int kLargeModelGpuLayers = 24; + +// Model size thresholds (billions of parameters) +static constexpr double kLargeModelThresholdB = 7.0; +static constexpr double kMediumModelThresholdB = 3.0; +static constexpr double kSmallModelThresholdB = 1.0; + +// Adaptive context sizes per model tier +static constexpr int kLargeModelContextSize = 2048; +static constexpr int kMediumModelContextSize = 4096; +static constexpr int kSmallModelContextSize = 2048; + +// Generation parameters +static constexpr int kReservedEosTokens = 4; // Tokens reserved for EOS at end of context +static constexpr int kRepeatPenaltyWindow = 64; // Last-N tokens for repetition penalty + +// Buffer sizes +static constexpr size_t kChatTemplateBufSize = 2048; +static constexpr size_t kFormattedPromptBufSize = 256 * 1024; + +} // namespace namespace runanywhere { @@ -77,19 +108,19 @@ static void llama_log_callback(ggml_log_level level, const char* fmt, void* data // ============================================================================= LlamaCppBackend::LlamaCppBackend() { - LOGI("LlamaCppBackend created"); + RAC_LOG_INFO("LLM.LlamaCpp","LlamaCppBackend created"); } LlamaCppBackend::~LlamaCppBackend() { cleanup(); - LOGI("LlamaCppBackend destroyed"); + RAC_LOG_INFO("LLM.LlamaCpp","LlamaCppBackend destroyed"); } bool LlamaCppBackend::initialize(const nlohmann::json& config) { std::lock_guard lock(mutex_); if (initialized_) { - LOGI("LlamaCppBackend already initialized"); + RAC_LOG_INFO("LLM.LlamaCpp","LlamaCppBackend already initialized"); return true; } @@ -104,13 +135,14 @@ bool LlamaCppBackend::initialize(const nlohmann::json& config) { if (num_threads_ <= 0) { #ifdef _SC_NPROCESSORS_ONLN - num_threads_ = std::max(1, std::min(8, (int)sysconf(_SC_NPROCESSORS_ONLN) - 2)); + num_threads_ = std::max(kMinThreads, std::min(kMaxThreads, + static_cast(sysconf(_SC_NPROCESSORS_ONLN)) - kReservedCores)); #else - num_threads_ = 4; + num_threads_ = kDefaultThreads; #endif } - LOGI("LlamaCppBackend initialized with %d threads", num_threads_); + RAC_LOG_INFO("LLM.LlamaCpp","LlamaCppBackend initialized with %d threads", num_threads_); create_text_generation(); @@ -133,7 +165,7 @@ void LlamaCppBackend::cleanup() { llama_backend_free(); initialized_ = false; - LOGI("LlamaCppBackend cleaned up"); + RAC_LOG_INFO("LLM.LlamaCpp","LlamaCppBackend cleaned up"); } DeviceType LlamaCppBackend::get_device_type() const { @@ -154,7 +186,7 @@ size_t LlamaCppBackend::get_memory_usage() const { void LlamaCppBackend::create_text_generation() { text_gen_ = std::make_unique(this); - LOGI("Created text generation component"); + RAC_LOG_INFO("LLM.LlamaCpp","Created text generation component"); } // ============================================================================= @@ -162,12 +194,12 @@ void LlamaCppBackend::create_text_generation() { // ============================================================================= LlamaCppTextGeneration::LlamaCppTextGeneration(LlamaCppBackend* backend) : backend_(backend) { - LOGI("LlamaCppTextGeneration created"); + RAC_LOG_INFO("LLM.LlamaCpp","LlamaCppTextGeneration created"); } LlamaCppTextGeneration::~LlamaCppTextGeneration() { unload_model(); - LOGI("LlamaCppTextGeneration destroyed"); + RAC_LOG_INFO("LLM.LlamaCpp","LlamaCppTextGeneration destroyed"); } bool LlamaCppTextGeneration::is_ready() const { @@ -179,11 +211,11 @@ bool LlamaCppTextGeneration::load_model(const std::string& model_path, std::lock_guard lock(mutex_); if (model_loaded_) { - LOGI("Unloading previous model before loading new one"); + RAC_LOG_INFO("LLM.LlamaCpp","Unloading previous model before loading new one"); unload_model_internal(); } - LOGI("Loading model from: %s", model_path.c_str()); + RAC_LOG_INFO("LLM.LlamaCpp","Loading model from: %s", model_path.c_str()); int user_context_size = 0; if (config.contains("context_size")) { @@ -241,9 +273,9 @@ bool LlamaCppTextGeneration::load_model(const std::string& model_path, bool is_large_model = false; if (config.contains("expected_params_billions")) { double expected_params = config["expected_params_billions"].get(); - is_large_model = (expected_params >= 7.0); + is_large_model = (expected_params >= kLargeModelThresholdB); if (is_large_model) { - LOGI("Large model detected from config (%.1fB expected params)", expected_params); + RAC_LOG_INFO("LLM.LlamaCpp","Large model detected from config (%.1fB expected params)", expected_params); } } @@ -259,59 +291,56 @@ bool LlamaCppTextGeneration::load_model(const std::string& model_path, if (is_large_model) { // For 7B+ models on mobile: limit GPU layers to prevent OOM // Most 7B models have 32 layers, offload ~24 to GPU, rest to CPU - gpu_layers = 24; - LOGI("Large model detected, limiting GPU layers to %d to prevent OOM", gpu_layers); + gpu_layers = kLargeModelGpuLayers; + RAC_LOG_INFO("LLM.LlamaCpp","Large model detected, limiting GPU layers to %d to prevent OOM", gpu_layers); } // Allow user override via config if (config.contains("gpu_layers")) { gpu_layers = config["gpu_layers"].get(); - LOGI("Using user-provided GPU layers: %d", gpu_layers); + RAC_LOG_INFO("LLM.LlamaCpp","Using user-provided GPU layers: %d", gpu_layers); } model_params.n_gpu_layers = gpu_layers; - LOGI("Loading model with n_gpu_layers=%d", gpu_layers); + RAC_LOG_INFO("LLM.LlamaCpp","Loading model with n_gpu_layers=%d", gpu_layers); model_ = llama_model_load_from_file(model_path.c_str(), model_params); if (!model_) { - LOGE("Failed to load model from: %s", model_path.c_str()); + RAC_LOG_ERROR("LLM.LlamaCpp","Failed to load model from: %s", model_path.c_str()); return false; } int model_train_ctx = llama_model_n_ctx_train(model_); - LOGI("Model training context size: %d", model_train_ctx); + RAC_LOG_INFO("LLM.LlamaCpp","Model training context size: %d", model_train_ctx); // Get model parameter count to determine appropriate context size // Large models (7B+) need smaller context on mobile to fit in memory uint64_t n_params = llama_model_n_params(model_); double params_billions = static_cast(n_params) / 1e9; - LOGI("Model parameters: %.2fB", params_billions); + RAC_LOG_INFO("LLM.LlamaCpp","Model parameters: %.2fB", params_billions); // Post-load verification: warn if actual param count differs from filename heuristic - bool actual_is_large = (params_billions >= 7.0); + bool actual_is_large = (params_billions >= kLargeModelThresholdB); if (actual_is_large && !is_large_model) { - LOGI("WARNING: Model has %.1fB params but filename didn't indicate large model. " + RAC_LOG_INFO("LLM.LlamaCpp","WARNING: Model has %.1fB params but filename didn't indicate large model. " "Consider using gpu_layers config for optimal performance.", params_billions); } else if (!actual_is_large && is_large_model) { - LOGI("NOTE: Filename suggested large model but actual params are %.1fB. " + RAC_LOG_INFO("LLM.LlamaCpp","NOTE: Filename suggested large model but actual params are %.1fB. " "GPU layer limiting may be conservative.", params_billions); } // Adaptive context size based on model size for mobile devices int adaptive_max_context; - if (params_billions >= 7.0) { - // 7B+ models: use 2048 context to fit in ~6GB GPU memory - adaptive_max_context = 2048; - LOGI("Large model detected (%.1fB params), limiting context to %d for memory", params_billions, adaptive_max_context); - } else if (params_billions >= 3.0) { - // 3-7B models: use 4096 context - adaptive_max_context = 4096; - LOGI("Medium model detected (%.1fB params), limiting context to %d", params_billions, adaptive_max_context); - } else if (params_billions >= 1.0) { - // 1-3B models: use 2048 context (higher values OOM on mobile, especially with LoRA) - adaptive_max_context = 2048; - LOGI("Small-medium model detected (%.1fB params), limiting context to %d", params_billions, adaptive_max_context); + if (params_billions >= kLargeModelThresholdB) { + adaptive_max_context = kLargeModelContextSize; + RAC_LOG_INFO("LLM.LlamaCpp","Large model detected (%.1fB params), limiting context to %d for memory", params_billions, adaptive_max_context); + } else if (params_billions >= kMediumModelThresholdB) { + adaptive_max_context = kMediumModelContextSize; + RAC_LOG_INFO("LLM.LlamaCpp","Medium model detected (%.1fB params), limiting context to %d", params_billions, adaptive_max_context); + } else if (params_billions >= kSmallModelThresholdB) { + adaptive_max_context = kSmallModelContextSize; + RAC_LOG_INFO("LLM.LlamaCpp","Small-medium model detected (%.1fB params), limiting context to %d", params_billions, adaptive_max_context); } else { // Tiny models (<1B): can use larger context adaptive_max_context = max_default_context_; @@ -319,11 +348,11 @@ bool LlamaCppTextGeneration::load_model(const std::string& model_path, if (user_context_size > 0) { context_size_ = std::min(user_context_size, model_train_ctx); - LOGI("Using user-provided context size: %d (requested: %d, model max: %d)", context_size_, + RAC_LOG_INFO("LLM.LlamaCpp","Using user-provided context size: %d (requested: %d, model max: %d)", context_size_, user_context_size, model_train_ctx); } else { context_size_ = std::min({model_train_ctx, max_default_context_, adaptive_max_context}); - LOGI("Auto-detected context size: %d (model: %d, cap: %d, adaptive: %d)", context_size_, + RAC_LOG_INFO("LLM.LlamaCpp","Auto-detected context size: %d (model: %d, cap: %d, adaptive: %d)", context_size_, model_train_ctx, max_default_context_, adaptive_max_context); } @@ -342,13 +371,13 @@ bool LlamaCppTextGeneration::load_model(const std::string& model_path, ctx_params.n_threads_batch = backend_->get_num_threads(); ctx_params.no_perf = true; - LOGI("Context params: n_ctx=%d, n_batch=%d, n_ubatch=%d", + RAC_LOG_INFO("LLM.LlamaCpp", "Context params: n_ctx=%d, n_batch=%d, n_ubatch=%d", ctx_params.n_ctx, ctx_params.n_batch, ctx_params.n_ubatch); context_ = llama_init_from_model(model_, ctx_params); if (!context_) { - LOGE("Failed to create context"); + RAC_LOG_ERROR("LLM.LlamaCpp","Failed to create context"); llama_model_free(model_); model_ = nullptr; return false; @@ -362,7 +391,7 @@ bool LlamaCppTextGeneration::load_model(const std::string& model_path, llama_sampler_chain_add(sampler_, llama_sampler_init_greedy()); model_loaded_ = true; - LOGI("Model loaded successfully: context_size=%d", context_size_); + RAC_LOG_INFO("LLM.LlamaCpp","Model loaded successfully: context_size=%d", context_size_); return true; } @@ -376,7 +405,7 @@ bool LlamaCppTextGeneration::unload_model_internal() { return true; } - LOGI("Unloading model"); + RAC_LOG_INFO("LLM.LlamaCpp","Unloading model"); // Clear LoRA adapters from context before freeing // (adapter memory is freed automatically with the model per llama.cpp API) @@ -404,7 +433,7 @@ bool LlamaCppTextGeneration::unload_model_internal() { model_loaded_ = false; model_path_.clear(); - LOGI("Model unloaded"); + RAC_LOG_INFO("LLM.LlamaCpp","Model unloaded"); return true; } @@ -420,14 +449,14 @@ std::string LlamaCppTextGeneration::build_prompt(const TextGenerationRequest& re messages = request.messages; } else if (!request.prompt.empty()) { messages.push_back({"user", request.prompt}); - LOGI("Converted prompt to user message for chat template"); + RAC_LOG_INFO("LLM.LlamaCpp","Converted prompt to user message for chat template"); } else { - LOGE("No prompt or messages provided"); + RAC_LOG_ERROR("LLM.LlamaCpp","No prompt or messages provided"); return ""; } std::string formatted = apply_chat_template(messages, request.system_prompt, true); - LOGI("Applied chat template, formatted prompt length: %zu", formatted.length()); + RAC_LOG_INFO("LLM.LlamaCpp","Applied chat template, formatted prompt length: %zu", formatted.length()); return formatted; } @@ -452,7 +481,7 @@ std::string LlamaCppTextGeneration::apply_chat_template( } std::string model_template; - model_template.resize(2048); + model_template.resize(kChatTemplateBufSize); int32_t template_len = llama_model_meta_val_str(model_, "tokenizer.chat_template", model_template.data(), model_template.size()); @@ -463,7 +492,7 @@ std::string LlamaCppTextGeneration::apply_chat_template( } std::string formatted; - formatted.resize(1024 * 256); + formatted.resize(kFormattedPromptBufSize); // llama_chat_apply_template may throw C++ exceptions for unsupported Jinja // template features (e.g. certain model chat templates use advanced Jinja syntax @@ -475,15 +504,15 @@ std::string LlamaCppTextGeneration::apply_chat_template( llama_chat_apply_template(tmpl_to_use, chat_messages.data(), chat_messages.size(), add_assistant_token, formatted.data(), formatted.size()); } catch (const std::exception& e) { - LOGE("llama_chat_apply_template threw exception: %s", e.what()); + RAC_LOG_ERROR("LLM.LlamaCpp","llama_chat_apply_template threw exception: %s", e.what()); result = -1; } catch (...) { - LOGE("llama_chat_apply_template threw unknown exception"); + RAC_LOG_ERROR("LLM.LlamaCpp","llama_chat_apply_template threw unknown exception"); result = -1; } if (result < 0) { - LOGI("Chat template failed (result=%d), using simple fallback format", result); + RAC_LOG_INFO("LLM.LlamaCpp","Chat template failed (result=%d), using simple fallback format", result); std::string fallback; for (const auto& msg : chat_messages) { fallback += std::string(msg.role) + ": " + msg.content + "\n"; @@ -500,7 +529,12 @@ std::string LlamaCppTextGeneration::apply_chat_template( result = llama_chat_apply_template(tmpl_to_use, chat_messages.data(), chat_messages.size(), add_assistant_token, formatted.data(), formatted.size()); } catch (...) { - LOGE("llama_chat_apply_template threw exception on retry"); + RAC_LOG_ERROR("LLM.LlamaCpp","llama_chat_apply_template threw exception on retry"); + result = -1; + } + + if (result <= 0) { + RAC_LOG_INFO("LLM.LlamaCpp","Chat template retry failed (result=%d), using simple fallback format", result); std::string fallback; for (const auto& msg : chat_messages) { fallback += std::string(msg.role) + ": " + msg.content + "\n"; @@ -520,7 +554,7 @@ std::string LlamaCppTextGeneration::apply_chat_template( } TextGenerationResult LlamaCppTextGeneration::generate(const TextGenerationRequest& request) { - LOGI("generate() START: max_tokens=%d, temp=%.2f, prompt_len=%zu", + RAC_LOG_INFO("LLM.LlamaCpp","generate() START: max_tokens=%d, temp=%.2f, prompt_len=%zu", request.max_tokens, request.temperature, request.prompt.length()); TextGenerationResult result; @@ -532,7 +566,7 @@ TextGenerationResult LlamaCppTextGeneration::generate(const TextGenerationReques auto start_time = std::chrono::high_resolution_clock::now(); - LOGI("generate(): calling generate_stream..."); + RAC_LOG_INFO("LLM.LlamaCpp","generate(): calling generate_stream..."); bool success = generate_stream( request, [&](const std::string& token) -> bool { @@ -541,7 +575,7 @@ TextGenerationResult LlamaCppTextGeneration::generate(const TextGenerationReques return !cancel_requested_.load(); }, &prompt_tokens); - LOGI("generate(): generate_stream returned success=%d, tokens=%d", success, tokens_generated); + RAC_LOG_INFO("LLM.LlamaCpp","generate(): generate_stream returned success=%d, tokens=%d", success, tokens_generated); auto end_time = std::chrono::high_resolution_clock::now(); auto duration = std::chrono::duration_cast(end_time - start_time); @@ -568,7 +602,7 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques std::lock_guard lock(mutex_); if (!is_ready()) { - LOGE("Model not ready for generation"); + RAC_LOG_ERROR("LLM.LlamaCpp","Model not ready for generation"); return false; } @@ -583,30 +617,30 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques decode_failed_ = false; std::string prompt = build_prompt(request); - LOGI("Generating with prompt length: %zu", prompt.length()); + RAC_LOG_INFO("LLM.LlamaCpp","Generating with prompt length: %zu", prompt.length()); const auto tokens_list = common_tokenize(context_, prompt, true, true); - int n_ctx = llama_n_ctx(context_); - int prompt_tokens = static_cast(tokens_list.size()); + const int n_ctx = llama_n_ctx(context_); + const int prompt_tokens = static_cast(tokens_list.size()); if (out_prompt_tokens) { *out_prompt_tokens = prompt_tokens; } - int available_tokens = n_ctx - prompt_tokens - 4; + const int available_tokens = n_ctx - prompt_tokens - kReservedEosTokens; if (available_tokens <= 0) { - LOGE("Prompt too long: %d tokens, context size: %d", prompt_tokens, n_ctx); + RAC_LOG_ERROR("LLM.LlamaCpp","Prompt too long: %d tokens, context size: %d", prompt_tokens, n_ctx); return false; } - int effective_max_tokens = std::min(request.max_tokens, available_tokens); - LOGI("Generation: prompt_tokens=%d, max_tokens=%d, context=%d", + const int effective_max_tokens = std::min(request.max_tokens, available_tokens); + RAC_LOG_INFO("LLM.LlamaCpp","Generation: prompt_tokens=%d, max_tokens=%d, context=%d", prompt_tokens, effective_max_tokens, n_ctx); const int n_batch = batch_size_ > 0 ? batch_size_ : n_ctx; - LOGI("generate_stream: processing %d prompt tokens in chunks of %d", prompt_tokens, n_batch); + RAC_LOG_INFO("LLM.LlamaCpp", "generate_stream: processing %d prompt tokens in chunks of %d", prompt_tokens, n_batch); llama_batch batch = llama_batch_init(n_batch, 0, 1); for (int chunk_start = 0; chunk_start < prompt_tokens; chunk_start += n_batch) { @@ -620,47 +654,61 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques } if (llama_decode(context_, batch) != 0) { - LOGE("llama_decode failed for prompt chunk [%d..%d)", chunk_start, chunk_end); + RAC_LOG_ERROR("LLM.LlamaCpp", "llama_decode failed for prompt chunk [%d..%d)", chunk_start, chunk_end); llama_batch_free(batch); return false; } } - LOGI("generate_stream: prompt decoded successfully"); + RAC_LOG_INFO("LLM.LlamaCpp", "generate_stream: prompt decoded successfully"); - // Configure sampler with request parameters - if (sampler_) { - llama_sampler_free(sampler_); - } + // Configure sampler with request parameters — skip rebuild if params unchanged + { + const bool params_match = sampler_ && + cached_temperature_ == request.temperature && + cached_top_p_ == request.top_p && + cached_top_k_ == request.top_k && + cached_repetition_penalty_ == request.repetition_penalty; + + if (!params_match) { + if (sampler_) { + llama_sampler_free(sampler_); + } - auto sparams = llama_sampler_chain_default_params(); - sparams.no_perf = true; - sampler_ = llama_sampler_chain_init(sparams); + auto sparams = llama_sampler_chain_default_params(); + sparams.no_perf = true; + sampler_ = llama_sampler_chain_init(sparams); - if (request.temperature > 0.0f) { - // Use default penalties (1.2f repetition) or request params if added later - llama_sampler_chain_add(sampler_, - llama_sampler_init_penalties(64, request.repetition_penalty, 0.0f, 0.0f)); + if (request.temperature > 0.0f) { + llama_sampler_chain_add(sampler_, + llama_sampler_init_penalties(kRepeatPenaltyWindow, request.repetition_penalty, 0.0f, 0.0f)); - if (request.top_k > 0) { - llama_sampler_chain_add(sampler_, llama_sampler_init_top_k(request.top_k)); - } + if (request.top_k > 0) { + llama_sampler_chain_add(sampler_, llama_sampler_init_top_k(request.top_k)); + } - llama_sampler_chain_add(sampler_, llama_sampler_init_top_p(request.top_p, 1)); - llama_sampler_chain_add(sampler_, llama_sampler_init_temp(request.temperature)); - llama_sampler_chain_add(sampler_, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); - } else { - llama_sampler_chain_add(sampler_, llama_sampler_init_greedy()); + llama_sampler_chain_add(sampler_, llama_sampler_init_top_p(request.top_p, 1)); + llama_sampler_chain_add(sampler_, llama_sampler_init_temp(request.temperature)); + llama_sampler_chain_add(sampler_, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); + } else { + llama_sampler_chain_add(sampler_, llama_sampler_init_greedy()); + } + + cached_temperature_ = request.temperature; + cached_top_p_ = request.top_p; + cached_top_k_ = request.top_k; + cached_repetition_penalty_ = request.repetition_penalty; + } } // Log generation parameters - LOGI("[PARAMS] LLM generate_stream (per-request options): temperature=%.4f, top_p=%.4f, top_k=%d, " + RAC_LOG_INFO("LLM.LlamaCpp","[PARAMS] LLM generate_stream (per-request options): temperature=%.4f, top_p=%.4f, top_k=%d, " "max_tokens=%d (effective=%d), repetition_penalty=%.4f, " "system_prompt_len=%zu", request.temperature, request.top_p, request.top_k, request.max_tokens, effective_max_tokens, request.repetition_penalty, request.system_prompt.length()); - const auto vocab = llama_model_get_vocab(model_); + const auto* const vocab = llama_model_get_vocab(model_); static const std::vector STOP_SEQUENCES = { "<|im_end|>", "<|eot_id|>", "", "<|end|>", "<|endoftext|>", @@ -679,6 +727,9 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques std::string partial_utf8_buffer; partial_utf8_buffer.reserve(8); + // Persist UTF-8 scanner across iterations to avoid re-scanning partial bytes + Utf8State utf8_scanner; + int n_cur = batch.n_tokens; int tokens_generated = 0; bool stop_sequence_hit = false; @@ -689,20 +740,21 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques llama_sampler_accept(sampler_, new_token_id); if (llama_vocab_is_eog(vocab, new_token_id)) { - LOGI("End of generation token received"); + RAC_LOG_INFO("LLM.LlamaCpp","End of generation token received"); break; } const std::string new_token_chars = common_token_to_piece(context_, new_token_id); + // Only scan newly appended bytes — scanner state persists from prior iterations + const size_t scan_start = partial_utf8_buffer.size(); partial_utf8_buffer.append(new_token_chars); - Utf8State scanner_state; size_t valid_upto = 0; - for (size_t i = 0; i < partial_utf8_buffer.size(); ++i) { - scanner_state.process(static_cast(partial_utf8_buffer[i])); - if (scanner_state.state == 0) { + for (size_t i = scan_start; i < partial_utf8_buffer.size(); ++i) { + utf8_scanner.process(static_cast(partial_utf8_buffer[i])); + if (utf8_scanner.state == 0) { valid_upto = i + 1; } } @@ -723,7 +775,7 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques } if (found_stop_pos != std::string::npos) { - LOGI("Stop sequence detected"); + RAC_LOG_INFO("LLM.LlamaCpp","Stop sequence detected"); stop_sequence_hit = true; if (found_stop_pos > 0) { if (!callback(stop_window.substr(0, found_stop_pos))) { @@ -736,7 +788,7 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques if (stop_window.size() > MAX_STOP_LEN) { size_t safe_len = stop_window.size() - MAX_STOP_LEN; if (!callback(stop_window.substr(0, safe_len))) { - LOGI("Generation cancelled by callback"); + RAC_LOG_INFO("LLM.LlamaCpp","Generation cancelled by callback"); cancel_requested_.store(true); break; } @@ -751,12 +803,17 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques tokens_generated++; if (llama_decode(context_, batch) != 0) { - LOGE("llama_decode failed during generation"); + RAC_LOG_ERROR("LLM.LlamaCpp","llama_decode failed during generation"); decode_failed_ = true; break; } } + // Flush any remaining partial UTF-8 bytes (e.g. trailing multi-byte char at end of generation) + if (!cancel_requested_.load() && !stop_sequence_hit && !partial_utf8_buffer.empty()) { + stop_window.append(partial_utf8_buffer); + } + if (!cancel_requested_.load() && !stop_sequence_hit && !stop_window.empty()) { callback(stop_window); } @@ -767,13 +824,13 @@ bool LlamaCppTextGeneration::generate_stream(const TextGenerationRequest& reques llama_batch_free(batch); - LOGI("Generation complete: %d tokens", tokens_generated); + RAC_LOG_INFO("LLM.LlamaCpp","Generation complete: %d tokens", tokens_generated); return !cancel_requested_.load(); } void LlamaCppTextGeneration::cancel() { cancel_requested_.store(true); - LOGI("Generation cancel requested"); + RAC_LOG_INFO("LLM.LlamaCpp","Generation cancel requested"); } @@ -781,7 +838,7 @@ bool LlamaCppTextGeneration::inject_system_prompt(const std::string& prompt) { std::lock_guard lock(mutex_); if (!is_ready()) { - LOGE("inject_system_prompt: model not ready"); + RAC_LOG_ERROR("LLM.LlamaCpp","inject_system_prompt: model not ready"); return false; } @@ -794,13 +851,13 @@ bool LlamaCppTextGeneration::inject_system_prompt(const std::string& prompt) { const int n_tokens = static_cast(tokens.size()); if (n_tokens <= 0) { - LOGE("inject_system_prompt: tokenization produced no tokens"); + RAC_LOG_ERROR("LLM.LlamaCpp","inject_system_prompt: tokenization produced no tokens"); return false; } const int n_ctx = llama_n_ctx(context_); if (n_tokens >= n_ctx) { - LOGE("inject_system_prompt: prompt too long (%d tokens, ctx=%d)", n_tokens, n_ctx); + RAC_LOG_ERROR("LLM.LlamaCpp","inject_system_prompt: prompt too long (%d tokens, ctx=%d)", n_tokens, n_ctx); return false; } @@ -816,14 +873,14 @@ bool LlamaCppTextGeneration::inject_system_prompt(const std::string& prompt) { } if (llama_decode(context_, batch) != 0) { - LOGE("inject_system_prompt: llama_decode failed at chunk [%d..%d)", chunk_start, chunk_end); + RAC_LOG_ERROR("LLM.LlamaCpp","inject_system_prompt: llama_decode failed at chunk [%d..%d)", chunk_start, chunk_end); llama_batch_free(batch); return false; } } llama_batch_free(batch); - LOGI("inject_system_prompt: injected %d tokens into KV cache", n_tokens); + RAC_LOG_INFO("LLM.LlamaCpp","inject_system_prompt: injected %d tokens into KV cache", n_tokens); return true; } @@ -831,7 +888,7 @@ bool LlamaCppTextGeneration::append_context(const std::string& text) { std::lock_guard lock(mutex_); if (!is_ready()) { - LOGE("append_context: model not ready"); + RAC_LOG_ERROR("LLM.LlamaCpp","append_context: model not ready"); return false; } @@ -847,7 +904,7 @@ bool LlamaCppTextGeneration::append_context(const std::string& text) { const int n_ctx = llama_n_ctx(context_); if (start_pos + n_tokens >= n_ctx) { - LOGE("append_context: context full (pos=%d, tokens=%d, ctx=%d)", start_pos, n_tokens, n_ctx); + RAC_LOG_ERROR("LLM.LlamaCpp","append_context: context full (pos=%d, tokens=%d, ctx=%d)", start_pos, n_tokens, n_ctx); return false; } @@ -863,14 +920,14 @@ bool LlamaCppTextGeneration::append_context(const std::string& text) { } if (llama_decode(context_, batch) != 0) { - LOGE("append_context: llama_decode failed at chunk [%d..%d)", chunk_start, chunk_end); + RAC_LOG_ERROR("LLM.LlamaCpp","append_context: llama_decode failed at chunk [%d..%d)", chunk_start, chunk_end); llama_batch_free(batch); return false; } } llama_batch_free(batch); - LOGI("append_context: appended %d tokens at pos %d", n_tokens, start_pos); + RAC_LOG_INFO("LLM.LlamaCpp","append_context: appended %d tokens at pos %d", n_tokens, start_pos); return true; } @@ -881,7 +938,7 @@ TextGenerationResult LlamaCppTextGeneration::generate_from_context(const TextGen result.finish_reason = "error"; if (!is_ready()) { - LOGE("generate_from_context: model not ready"); + RAC_LOG_ERROR("LLM.LlamaCpp","generate_from_context: model not ready"); return result; } @@ -894,7 +951,7 @@ TextGenerationResult LlamaCppTextGeneration::generate_from_context(const TextGen const int n_prompt = static_cast(tokens.size()); if (n_prompt <= 0) { - LOGE("generate_from_context: failed to tokenize prompt"); + RAC_LOG_ERROR("LLM.LlamaCpp","generate_from_context: failed to tokenize prompt"); return result; } @@ -905,13 +962,13 @@ TextGenerationResult LlamaCppTextGeneration::generate_from_context(const TextGen const int available_tokens = n_ctx - static_cast(current_pos) - n_prompt - 4; if (available_tokens <= 0) { - LOGE("generate_from_context: no space for generation (pos=%d, prompt=%d, ctx=%d)", + RAC_LOG_ERROR("LLM.LlamaCpp","generate_from_context: no space for generation (pos=%d, prompt=%d, ctx=%d)", static_cast(current_pos), n_prompt, n_ctx); return result; } const int effective_max_tokens = std::min(request.max_tokens, available_tokens); - LOGI("generate_from_context: pos=%d, prompt_tokens=%d, max_tokens=%d", + RAC_LOG_INFO("LLM.LlamaCpp","generate_from_context: pos=%d, prompt_tokens=%d, max_tokens=%d", static_cast(current_pos), n_prompt, effective_max_tokens); const int n_batch_lim = batch_size_ > 0 ? batch_size_ : n_ctx; @@ -928,7 +985,7 @@ TextGenerationResult LlamaCppTextGeneration::generate_from_context(const TextGen } if (llama_decode(context_, batch) != 0) { - LOGE("generate_from_context: llama_decode failed at chunk [%d..%d)", chunk_start, chunk_end); + RAC_LOG_ERROR("LLM.LlamaCpp","generate_from_context: llama_decode failed at chunk [%d..%d)", chunk_start, chunk_end); llama_batch_free(batch); return result; } @@ -1054,7 +1111,7 @@ TextGenerationResult LlamaCppTextGeneration::generate_from_context(const TextGen tokens_generated++; if (llama_decode(context_, batch) != 0) { - LOGE("generate_from_context: llama_decode failed during generation"); + RAC_LOG_ERROR("LLM.LlamaCpp","generate_from_context: llama_decode failed during generation"); decode_failed_ = true; break; } @@ -1080,7 +1137,7 @@ TextGenerationResult LlamaCppTextGeneration::generate_from_context(const TextGen result.finish_reason = tokens_generated >= effective_max_tokens ? "length" : "stop"; } - LOGI("generate_from_context: complete, tokens=%d, reason=%s", + RAC_LOG_INFO("LLM.LlamaCpp","generate_from_context: complete, tokens=%d, reason=%s", tokens_generated, result.finish_reason.c_str()); return result; } @@ -1093,7 +1150,7 @@ void LlamaCppTextGeneration::clear_context() { if (mem) { llama_memory_clear(mem, true); } - LOGI("clear_context: KV cache cleared"); + RAC_LOG_INFO("LLM.LlamaCpp","clear_context: KV cache cleared"); } } @@ -1124,7 +1181,7 @@ nlohmann::json LlamaCppTextGeneration::get_model_info() const { // ============================================================================= bool LlamaCppTextGeneration::recreate_context() { - LOGI("Recreating context to accommodate LoRA adapters"); + RAC_LOG_INFO("LLM.LlamaCpp","Recreating context to accommodate LoRA adapters"); // Free existing sampler and context if (sampler_) { @@ -1148,17 +1205,23 @@ bool LlamaCppTextGeneration::recreate_context() { context_ = llama_init_from_model(model_, ctx_params); if (!context_) { - LOGE("Failed to recreate context after LoRA adapter load"); + RAC_LOG_ERROR("LLM.LlamaCpp","Failed to recreate context after LoRA adapter load"); return false; } - // Rebuild sampler chain + // Rebuild sampler chain (greedy placeholder — real sampler built on first generate_stream) auto sparams = llama_sampler_chain_default_params(); sparams.no_perf = true; sampler_ = llama_sampler_chain_init(sparams); llama_sampler_chain_add(sampler_, llama_sampler_init_greedy()); - LOGI("Context recreated successfully"); + // Invalidate cached params so the next generate_stream() rebuilds with real params + cached_temperature_ = -1.0f; + cached_top_p_ = -1.0f; + cached_top_k_ = -1; + cached_repetition_penalty_ = -1.0f; + + RAC_LOG_INFO("LLM.LlamaCpp","Context recreated successfully"); return true; } @@ -1180,7 +1243,7 @@ bool LlamaCppTextGeneration::apply_lora_adapters() { int32_t result = llama_set_adapters_lora(context_, adapters.data(), adapters.size(), scales.data()); if (result != 0) { - LOGE("Failed to apply LoRA adapters (error=%d)", result); + RAC_LOG_ERROR("LLM.LlamaCpp","Failed to apply LoRA adapters (error=%d)", result); for (auto& entry : lora_adapters_) { entry.applied = false; } @@ -1189,7 +1252,7 @@ bool LlamaCppTextGeneration::apply_lora_adapters() { for (auto& entry : lora_adapters_) { entry.applied = true; - LOGI("Applied LoRA adapter: %s (scale=%.2f)", entry.path.c_str(), entry.scale); + RAC_LOG_INFO("LLM.LlamaCpp","Applied LoRA adapter: %s (scale=%.2f)", entry.path.c_str(), entry.scale); } return true; } @@ -1198,24 +1261,24 @@ bool LlamaCppTextGeneration::load_lora_adapter(const std::string& adapter_path, std::lock_guard lock(mutex_); if (!model_loaded_ || !model_) { - LOGE("Cannot load LoRA adapter: model not loaded"); + RAC_LOG_ERROR("LLM.LlamaCpp","Cannot load LoRA adapter: model not loaded"); return false; } // Check if adapter already loaded for (const auto& entry : lora_adapters_) { if (entry.path == adapter_path) { - LOGE("LoRA adapter already loaded: %s", adapter_path.c_str()); + RAC_LOG_ERROR("LLM.LlamaCpp","LoRA adapter already loaded: %s", adapter_path.c_str()); return false; } } - LOGI("Loading LoRA adapter: %s (scale=%.2f)", adapter_path.c_str(), scale); + RAC_LOG_INFO("LLM.LlamaCpp","Loading LoRA adapter: %s (scale=%.2f)", adapter_path.c_str(), scale); // Load adapter against model llama_adapter_lora* adapter = llama_adapter_lora_init(model_, adapter_path.c_str()); if (!adapter) { - LOGE("Failed to load LoRA adapter from: %s", adapter_path.c_str()); + RAC_LOG_ERROR("LLM.LlamaCpp","Failed to load LoRA adapter from: %s", adapter_path.c_str()); return false; } @@ -1243,7 +1306,7 @@ bool LlamaCppTextGeneration::load_lora_adapter(const std::string& adapter_path, // Clear KV cache after adapter changes llama_memory_clear(llama_get_memory(context_), true); - LOGI("LoRA adapter loaded and applied: %s (%zu total adapters)", + RAC_LOG_INFO("LLM.LlamaCpp","LoRA adapter loaded and applied: %s (%zu total adapters)", adapter_path.c_str(), lora_adapters_.size()); return true; } @@ -1252,7 +1315,7 @@ bool LlamaCppTextGeneration::remove_lora_adapter(const std::string& adapter_path std::lock_guard lock(mutex_); if (!model_loaded_ || !context_) { - LOGE("Cannot remove LoRA adapter: model not loaded"); + RAC_LOG_ERROR("LLM.LlamaCpp","Cannot remove LoRA adapter: model not loaded"); return false; } @@ -1260,7 +1323,7 @@ bool LlamaCppTextGeneration::remove_lora_adapter(const std::string& adapter_path [&adapter_path](const LoraAdapterEntry& e) { return e.path == adapter_path; }); if (it == lora_adapters_.end()) { - LOGE("LoRA adapter not found: %s", adapter_path.c_str()); + RAC_LOG_ERROR("LLM.LlamaCpp","LoRA adapter not found: %s", adapter_path.c_str()); return false; } @@ -1268,14 +1331,14 @@ bool LlamaCppTextGeneration::remove_lora_adapter(const std::string& adapter_path // Re-apply remaining adapters (or clear if none left) if (!apply_lora_adapters()) { - LOGE("Failed to re-apply remaining LoRA adapters after removal"); + RAC_LOG_ERROR("LLM.LlamaCpp","Failed to re-apply remaining LoRA adapters after removal"); return false; } // Clear KV cache after adapter changes llama_memory_clear(llama_get_memory(context_), true); - LOGI("LoRA adapter removed: %s (%zu remaining)", adapter_path.c_str(), lora_adapters_.size()); + RAC_LOG_INFO("LLM.LlamaCpp","LoRA adapter removed: %s (%zu remaining)", adapter_path.c_str(), lora_adapters_.size()); return true; } @@ -1292,7 +1355,7 @@ void LlamaCppTextGeneration::clear_lora_adapters() { } lora_adapters_.clear(); - LOGI("All LoRA adapters cleared"); + RAC_LOG_INFO("LLM.LlamaCpp","All LoRA adapters cleared"); } nlohmann::json LlamaCppTextGeneration::get_lora_info() const { diff --git a/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.h b/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.h index 364baa668..3fa5bc1a1 100644 --- a/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.h +++ b/sdk/runanywhere-commons/src/backends/llamacpp/llamacpp_backend.h @@ -56,6 +56,11 @@ struct TextGenerationResult { std::string finish_reason; // "stop", "length", "cancelled" }; +// Verify request struct size — allocated per generate() call. +// If this fires, review whether new fields should be passed by reference instead. +static_assert(sizeof(TextGenerationRequest) <= 256, + "TextGenerationRequest grew — consider passing by reference or reducing members"); + // Streaming callback: receives token, returns false to cancel using TextStreamCallback = std::function; @@ -182,6 +187,12 @@ class LlamaCppTextGeneration { llama_context* context_ = nullptr; llama_sampler* sampler_ = nullptr; + // Cached sampler parameters — skip rebuild when unchanged + float cached_temperature_ = -1.0f; + float cached_top_p_ = -1.0f; + int cached_top_k_ = -1; + float cached_repetition_penalty_ = -1.0f; + bool model_loaded_ = false; std::atomic cancel_requested_{false}; std::atomic decode_failed_{false}; diff --git a/sdk/runanywhere-commons/src/backends/llamacpp/rac_llm_llamacpp.cpp b/sdk/runanywhere-commons/src/backends/llamacpp/rac_llm_llamacpp.cpp index 85982adf0..8ff7a45eb 100644 --- a/sdk/runanywhere-commons/src/backends/llamacpp/rac_llm_llamacpp.cpp +++ b/sdk/runanywhere-commons/src/backends/llamacpp/rac_llm_llamacpp.cpp @@ -19,8 +19,6 @@ #include "rac/core/rac_logger.h" #include "rac/infrastructure/events/rac_events.h" -// Use the RAC logging system -#define LOGI(...) RAC_LOG_INFO("LLM.LlamaCpp.C-API", __VA_ARGS__) // ============================================================================= // INTERNAL HANDLE STRUCTURE @@ -140,7 +138,7 @@ rac_bool_t rac_llm_llamacpp_is_model_loaded(rac_handle_t handle) { rac_result_t rac_llm_llamacpp_generate(rac_handle_t handle, const char* prompt, const rac_llm_options_t* options, rac_llm_result_t* out_result) { - RAC_LOG_INFO("LLM.LlamaCpp", "rac_llm_llamacpp_generate: START handle=%p", handle); + RAC_LOG_DEBUG("LLM.LlamaCpp", "rac_llm_llamacpp_generate: START handle=%p", handle); if (handle == nullptr || prompt == nullptr || out_result == nullptr) { RAC_LOG_ERROR("LLM.LlamaCpp", "rac_llm_llamacpp_generate: NULL pointer! handle=%p, prompt=%p, out_result=%p", @@ -148,9 +146,7 @@ rac_result_t rac_llm_llamacpp_generate(rac_handle_t handle, const char* prompt, return RAC_ERROR_NULL_POINTER; } - RAC_LOG_INFO("LLM.LlamaCpp", "rac_llm_llamacpp_generate: casting handle..."); auto* h = static_cast(handle); - RAC_LOG_INFO("LLM.LlamaCpp", "rac_llm_llamacpp_generate: handle cast ok, text_gen=%p", (void*)h->text_gen); if (!h->text_gen) { RAC_LOG_ERROR("LLM.LlamaCpp", "rac_llm_llamacpp_generate: text_gen is null!"); @@ -158,7 +154,7 @@ rac_result_t rac_llm_llamacpp_generate(rac_handle_t handle, const char* prompt, } // Build request from RAC options - RAC_LOG_INFO("LLM.LlamaCpp", "rac_llm_llamacpp_generate: building request, prompt_len=%zu", strlen(prompt)); + RAC_LOG_DEBUG("LLM.LlamaCpp", "rac_llm_llamacpp_generate: building request, prompt_len=%zu", strlen(prompt)); runanywhere::TextGenerationRequest request; request.prompt = prompt; if (options != nullptr) { @@ -178,12 +174,12 @@ rac_result_t rac_llm_llamacpp_generate(rac_handle_t handle, const char* prompt, } } } - LOGI("[PARAMS] LLM C-API (from caller options): max_tokens=%d, temperature=%.4f, " + RAC_LOG_INFO("LLM.LlamaCpp.C-API","[PARAMS] LLM C-API (from caller options): max_tokens=%d, temperature=%.4f, " "top_p=%.4f, system_prompt=%s", request.max_tokens, request.temperature, request.top_p, request.system_prompt.empty() ? "(none)" : "(set)"); } else { - LOGI("[PARAMS] LLM C-API (using struct defaults): max_tokens=%d, temperature=%.4f, " + RAC_LOG_INFO("LLM.LlamaCpp.C-API","[PARAMS] LLM C-API (using struct defaults): max_tokens=%d, temperature=%.4f, " "top_p=%.4f, system_prompt=(none)", request.max_tokens, request.temperature, request.top_p); } @@ -194,7 +190,6 @@ rac_result_t rac_llm_llamacpp_generate(rac_handle_t handle, const char* prompt, // templates that use unsupported features. Without this catch, the exception // propagates through the extern "C" boundary causing undefined behavior in WASM // (Emscripten returns the exception pointer as the function return value). - RAC_LOG_INFO("LLM.LlamaCpp", "rac_llm_llamacpp_generate: calling text_gen->generate()..."); runanywhere::TextGenerationResult result; try { result = h->text_gen->generate(request); @@ -205,7 +200,7 @@ rac_result_t rac_llm_llamacpp_generate(rac_handle_t handle, const char* prompt, rac_error_set_details("Unknown C++ exception during LLM generation"); return RAC_ERROR_INFERENCE_FAILED; } - RAC_LOG_INFO("LLM.LlamaCpp", "rac_llm_llamacpp_generate: generate() returned, tokens=%d", result.tokens_generated); + RAC_LOG_DEBUG("LLM.LlamaCpp", "rac_llm_llamacpp_generate: generate() returned, tokens=%d", result.tokens_generated); // finish_reason is std::string; TODO: migrate to enum if TextGenerationResult gains one if (result.finish_reason == "error") { @@ -215,7 +210,14 @@ rac_result_t rac_llm_llamacpp_generate(rac_handle_t handle, const char* prompt, } // Fill RAC result struct - out_result->text = result.text.empty() ? nullptr : strdup(result.text.c_str()); + if (!result.text.empty()) { + out_result->text = strdup(result.text.c_str()); + if (!out_result->text) { + return RAC_ERROR_OUT_OF_MEMORY; + } + } else { + out_result->text = nullptr; + } out_result->completion_tokens = result.tokens_generated; out_result->prompt_tokens = result.prompt_tokens; out_result->total_tokens = result.prompt_tokens + result.tokens_generated; @@ -262,12 +264,12 @@ rac_result_t rac_llm_llamacpp_generate_stream(rac_handle_t handle, const char* p } } } - LOGI("[PARAMS] LLM C-API (from caller options): max_tokens=%d, temperature=%.4f, " + RAC_LOG_INFO("LLM.LlamaCpp.C-API","[PARAMS] LLM C-API (from caller options): max_tokens=%d, temperature=%.4f, " "top_p=%.4f, system_prompt=%s", request.max_tokens, request.temperature, request.top_p, request.system_prompt.empty() ? "(none)" : "(set)"); } else { - LOGI("[PARAMS] LLM C-API (using struct defaults): max_tokens=%d, temperature=%.4f, " + RAC_LOG_INFO("LLM.LlamaCpp.C-API","[PARAMS] LLM C-API (using struct defaults): max_tokens=%d, temperature=%.4f, " "top_p=%.4f, system_prompt=(none)", request.max_tokens, request.temperature, request.top_p); } @@ -325,6 +327,9 @@ rac_result_t rac_llm_llamacpp_get_model_info(rac_handle_t handle, char** out_jso std::string json_str = info.dump(); *out_json = strdup(json_str.c_str()); + if (!*out_json) { + return RAC_ERROR_OUT_OF_MEMORY; + } return RAC_SUCCESS; } @@ -398,6 +403,9 @@ rac_result_t rac_llm_llamacpp_get_lora_info(rac_handle_t handle, char** out_json auto info = h->text_gen->get_lora_info(); std::string json_str = info.dump(); *out_json = strdup(json_str.c_str()); + if (!*out_json) { + return RAC_ERROR_OUT_OF_MEMORY; + } return RAC_SUCCESS; } diff --git a/sdk/runanywhere-commons/src/backends/llamacpp/rac_vlm_llamacpp.cpp b/sdk/runanywhere-commons/src/backends/llamacpp/rac_vlm_llamacpp.cpp index c716b44c0..c2bd98599 100644 --- a/sdk/runanywhere-commons/src/backends/llamacpp/rac_vlm_llamacpp.cpp +++ b/sdk/runanywhere-commons/src/backends/llamacpp/rac_vlm_llamacpp.cpp @@ -34,6 +34,14 @@ static const char* LOG_CAT = "VLM.LlamaCPP"; +// ============================================================================= +// NAMED CONSTANTS +// ============================================================================= + +static constexpr int kDefaultMaxContextSize = 4096; +static constexpr int kDefaultBatchSize = 512; +static constexpr int kDefaultMaxTokens = 2048; + // ============================================================================= // INTERNAL BACKEND STATE // ============================================================================= @@ -73,6 +81,10 @@ struct LlamaCppVLMBackend { // Detected model type for chat template VLMModelType model_type = static_cast(0); // Unknown + // Cached sampler parameters to avoid unnecessary rebuilds + float cached_temperature = -1.0f; + float cached_top_p = -1.0f; + // Thread safety mutable std::mutex mutex; }; @@ -80,12 +92,12 @@ struct LlamaCppVLMBackend { /** * Get number of CPU threads to use. */ -int get_num_threads(int config_threads) { +int get_num_threads(const int config_threads) { if (config_threads > 0) return config_threads; // Auto-detect based on hardware - int threads = std::thread::hardware_concurrency(); + int threads = static_cast(std::thread::hardware_concurrency()); if (threads <= 0) threads = 4; if (threads > 8) @@ -259,59 +271,6 @@ std::string format_vlm_prompt_with_template(llama_model* model, const std::strin return formatted; } -/** - * Legacy format function for backward compatibility. - * Uses model type detection for manual template selection. - */ -std::string format_vlm_prompt(VLMModelType model_type, const std::string& user_prompt, - const char* image_marker, bool has_image) { - std::string formatted; - - // Build user content with image marker - std::string user_content; - if (has_image) { - user_content = std::string(image_marker) + user_prompt; - } else { - user_content = user_prompt; - } - - switch (model_type) { - case VLMModelType::SmolVLM: - // SmolVLM format: <|im_start|>User: content \nAssistant: - formatted = "<|im_start|>User: "; - formatted += user_content; - formatted += " \nAssistant:"; - break; - - case VLMModelType::Qwen2VL: - // Qwen2-VL chatml format - formatted = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"; - formatted += "<|im_start|>user\n"; - formatted += user_content; - formatted += "<|im_end|>\n<|im_start|>assistant\n"; - break; - - case VLMModelType::LLaVA: - // LLaVA/Vicuna format - formatted = "USER: "; - formatted += user_content; - formatted += "\nASSISTANT:"; - break; - - case VLMModelType::Generic: - default: - // Generic chatml format - formatted = "<|im_start|>user\n"; - formatted += user_content; - formatted += "<|im_end|>\n<|im_start|>assistant\n"; - break; - } - - RAC_LOG_DEBUG(LOG_CAT, "Formatted prompt (%d chars): %.100s...", - (int)formatted.length(), formatted.c_str()); - return formatted; -} - /** * Get the image marker string. * When mtmd is available, uses the default marker from mtmd. @@ -327,15 +286,10 @@ const char* get_image_marker() { /** * Configure the sampler chain with the given generation parameters. - * Rebuilds the sampler to apply per-request temperature, top_p, etc. + * Only rebuilds the sampler when parameters actually change, avoiding + * unnecessary heap allocations on every inference call. */ void configure_sampler(LlamaCppVLMBackend* backend, const rac_vlm_options_t* options) { - // Free existing sampler - if (backend->sampler) { - llama_sampler_free(backend->sampler); - backend->sampler = nullptr; - } - // Determine parameters from options or use defaults float temperature = 0.7f; float top_p = 0.9f; @@ -349,27 +303,49 @@ void configure_sampler(LlamaCppVLMBackend* backend, const rac_vlm_options_t* opt } } + // Skip rebuild if params haven't changed and sampler already exists + if (backend->sampler && + backend->cached_temperature == temperature && + backend->cached_top_p == top_p) { + return; + } + + // Free existing sampler + if (backend->sampler) { + llama_sampler_free(backend->sampler); + backend->sampler = nullptr; + } + // Build new sampler chain. // Order follows llama.cpp common_sampler_init: penalties → DRY → top_p → min_p → temp → dist. // Penalties and DRY must be applied to raw logits before temperature softens them. llama_sampler_chain_params sampler_params = llama_sampler_chain_default_params(); + sampler_params.no_perf = true; // Disable perf tracking (consistent with LLM backend) backend->sampler = llama_sampler_chain_init(sampler_params); - // Token-level repetition penalty + frequency/presence penalties - llama_sampler_chain_add(backend->sampler, llama_sampler_init_penalties(256, 1.3f, 0.1f, 0.1f)); + if (temperature > 0.0f) { + // Token-level repetition penalty + frequency/presence penalties + llama_sampler_chain_add(backend->sampler, llama_sampler_init_penalties(256, 1.3f, 0.1f, 0.1f)); - // DRY sampler: catches n-gram (sequence) repetition like "gó gó gó" where individual - // tokens may alternate. Multiplier=0.8, base=1.75, allowed_length=2, last_n=256. - const llama_vocab* vocab = llama_model_get_vocab(backend->model); - static const char* dry_breakers[] = { "\n", ":", "\"", "*" }; - llama_sampler_chain_add(backend->sampler, llama_sampler_init_dry( - vocab, llama_model_n_ctx_train(backend->model), - 0.8f, 1.75f, 2, 256, dry_breakers, 4)); + // DRY sampler: catches n-gram (sequence) repetition like "gó gó gó" where individual + // tokens may alternate. Multiplier=0.8, base=1.75, allowed_length=2, last_n=256. + const llama_vocab* vocab = llama_model_get_vocab(backend->model); + static const char* dry_breakers[] = { "\n", ":", "\"", "*" }; + llama_sampler_chain_add(backend->sampler, llama_sampler_init_dry( + vocab, llama_model_n_ctx_train(backend->model), + 0.8f, 1.75f, 2, 256, dry_breakers, 4)); + + llama_sampler_chain_add(backend->sampler, llama_sampler_init_top_p(top_p, 1)); + llama_sampler_chain_add(backend->sampler, llama_sampler_init_min_p(0.1f, 1)); + llama_sampler_chain_add(backend->sampler, llama_sampler_init_temp(temperature)); + llama_sampler_chain_add(backend->sampler, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); + } else { + llama_sampler_chain_add(backend->sampler, llama_sampler_init_greedy()); + } - llama_sampler_chain_add(backend->sampler, llama_sampler_init_top_p(top_p, 1)); - llama_sampler_chain_add(backend->sampler, llama_sampler_init_min_p(0.1f, 1)); - llama_sampler_chain_add(backend->sampler, llama_sampler_init_temp(temperature)); - llama_sampler_chain_add(backend->sampler, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); + // Cache the params for next comparison + backend->cached_temperature = temperature; + backend->cached_top_p = top_p; RAC_LOG_INFO(LOG_CAT, "[v3] Sampler: temp=%.2f top_p=%.2f repeat=1.3 freq=0.1 pres=0.1 DRY=0.8 min_p=0.1 + repeat_guard=4", temperature, top_p); @@ -391,6 +367,144 @@ static VLMModelType resolve_effective_model_type(VLMModelType detected, return detected; } +/** + * Prepare the VLM context for generation: reset state, configure sampler, + * build prompt, load image (if provided), tokenize, and evaluate. + * After success, the backend is ready for token sampling (n_past is set). + * + * Shared between rac_vlm_llamacpp_process() and rac_vlm_llamacpp_process_stream() + * to eliminate code duplication (~100 lines of identical prompt prep logic). + */ +rac_result_t prepare_vlm_context(LlamaCppVLMBackend* backend, + const rac_vlm_image_t* image, + const char* prompt, + const rac_vlm_options_t* options) { + backend->cancel_requested = false; + configure_sampler(backend, options); + + // Clear KV cache before each new request + llama_memory_t mem = llama_get_memory(backend->ctx); + if (mem) { + llama_memory_clear(mem, true); + } + backend->n_past = 0; + + // Resolve effective model type: options override > auto-detected at load time + VLMModelType effective_model_type = resolve_effective_model_type(backend->model_type, options); + const char* system_prompt = (options && options->system_prompt) ? options->system_prompt : nullptr; + + // Build prompt with image handling + std::string full_prompt; + bool has_image = false; + const char* image_marker = get_image_marker(); + +#ifdef RAC_VLM_USE_MTMD + mtmd_bitmap* bitmap = nullptr; + + if (image && backend->mtmd_ctx) { + if (image->format == RAC_VLM_IMAGE_FORMAT_FILE_PATH && image->file_path) { + bitmap = mtmd_helper_bitmap_init_from_file(backend->mtmd_ctx, image->file_path); + } else if (image->format == RAC_VLM_IMAGE_FORMAT_RGB_PIXELS && image->pixel_data) { + bitmap = mtmd_bitmap_init(image->width, image->height, image->pixel_data); + } else if (image->format == RAC_VLM_IMAGE_FORMAT_BASE64 && image->base64_data) { + RAC_LOG_WARNING(LOG_CAT, "Base64 image format not yet supported, using text-only"); + } + + has_image = (bitmap != nullptr); + if (!has_image && image->format != RAC_VLM_IMAGE_FORMAT_BASE64) { + RAC_LOG_ERROR(LOG_CAT, "Failed to load image"); + return RAC_ERROR_INVALID_INPUT; + } + } + + full_prompt = format_vlm_prompt_with_template(backend->model, prompt, image_marker, has_image, + system_prompt, effective_model_type); + + RAC_LOG_INFO(LOG_CAT, "[v3-prep] Prompt ready (chars=%d, img=%d, type=%d)", + (int)full_prompt.length(), has_image ? 1 : 0, (int)effective_model_type); + + // Tokenize and evaluate with MTMD if image present + if (backend->mtmd_ctx && bitmap) { + mtmd_input_chunks* chunks = mtmd_input_chunks_init(); + + mtmd_input_text text; + text.text = full_prompt.c_str(); + text.add_special = true; + text.parse_special = true; + + const mtmd_bitmap* bitmaps[] = { bitmap }; + int32_t tokenize_result = mtmd_tokenize(backend->mtmd_ctx, chunks, &text, bitmaps, 1); + + if (tokenize_result != 0) { + RAC_LOG_ERROR(LOG_CAT, "Failed to tokenize prompt with image: %d", tokenize_result); + mtmd_bitmap_free(bitmap); + mtmd_input_chunks_free(chunks); + return RAC_ERROR_PROCESSING_FAILED; + } + + llama_pos new_n_past = 0; + int32_t eval_result = mtmd_helper_eval_chunks( + backend->mtmd_ctx, backend->ctx, chunks, + 0, 0, + backend->config.batch_size > 0 ? backend->config.batch_size : kDefaultBatchSize, + true, &new_n_past + ); + + mtmd_bitmap_free(bitmap); + mtmd_input_chunks_free(chunks); + + if (eval_result != 0) { + RAC_LOG_ERROR(LOG_CAT, "Failed to evaluate chunks: %d", eval_result); + return RAC_ERROR_PROCESSING_FAILED; + } + + backend->n_past = new_n_past; + } else +#endif + { + // Text-only mode - still apply chat template for consistent formatting + full_prompt = format_vlm_prompt_with_template(backend->model, prompt, image_marker, false, + system_prompt, effective_model_type); + + const llama_vocab* vocab = llama_model_get_vocab(backend->model); + std::vector tokens(full_prompt.size() + 16); + int n_tokens = llama_tokenize(vocab, full_prompt.c_str(), full_prompt.size(), + tokens.data(), tokens.size(), true, true); + if (n_tokens < 0) { + tokens.resize(-n_tokens); + n_tokens = llama_tokenize(vocab, full_prompt.c_str(), full_prompt.size(), + tokens.data(), tokens.size(), true, true); + } + tokens.resize(n_tokens); + + llama_batch batch = llama_batch_init(n_tokens, 0, 1); + for (int i = 0; i < n_tokens; i++) { + batch.token[i] = tokens[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = (i == n_tokens - 1); + } + batch.n_tokens = n_tokens; + + if (llama_decode(backend->ctx, batch) != 0) { + llama_batch_free(batch); + RAC_LOG_ERROR(LOG_CAT, "Failed to decode prompt"); + return RAC_ERROR_PROCESSING_FAILED; + } + + llama_batch_free(batch); + backend->n_past = n_tokens; + } + + return RAC_SUCCESS; +} + +// Verify backend struct size hasn't grown unexpectedly (catches accidental +// large member additions that might hurt cache locality). +static_assert(sizeof(LlamaCppVLMBackend) <= 512, + "LlamaCppVLMBackend grew unexpectedly — review member layout"); + } // namespace // ============================================================================= @@ -502,7 +616,7 @@ rac_result_t rac_vlm_llamacpp_load_model(rac_handle_t handle, const char* model_ int ctx_size = backend->config.context_size; if (ctx_size <= 0) { ctx_size = llama_model_n_ctx_train(backend->model); - if (ctx_size > 4096) ctx_size = 4096; // Cap for mobile + if (ctx_size > kDefaultMaxContextSize) ctx_size = kDefaultMaxContextSize; // Cap for mobile } backend->context_size = ctx_size; @@ -510,7 +624,7 @@ rac_result_t rac_vlm_llamacpp_load_model(rac_handle_t handle, const char* model_ int n_threads = get_num_threads(backend->config.num_threads); llama_context_params ctx_params = llama_context_default_params(); ctx_params.n_ctx = ctx_size; - ctx_params.n_batch = backend->config.batch_size > 0 ? backend->config.batch_size : 512; + ctx_params.n_batch = backend->config.batch_size > 0 ? backend->config.batch_size : kDefaultBatchSize; ctx_params.n_threads = n_threads; ctx_params.n_threads_batch = n_threads; @@ -631,144 +745,20 @@ rac_result_t rac_vlm_llamacpp_process(rac_handle_t handle, const rac_vlm_image_t return RAC_ERROR_MODEL_NOT_LOADED; } - backend->cancel_requested = false; - - // Reconfigure sampler with per-request options (temperature, top_p) - configure_sampler(backend, options); - - // Clear KV cache (memory) before each new request to avoid position conflicts - llama_memory_t mem = llama_get_memory(backend->ctx); - if (mem) { - llama_memory_clear(mem, true); + // Shared context preparation: reset, configure sampler, build prompt, evaluate + rac_result_t prep_result = prepare_vlm_context(backend, image, prompt, options); + if (prep_result != RAC_SUCCESS) { + return prep_result; } - backend->n_past = 0; - - // Resolve effective model type: options override > auto-detected at load time - VLMModelType effective_model_type = resolve_effective_model_type(backend->model_type, options); - - const char* system_prompt = (options && options->system_prompt) ? options->system_prompt : nullptr; - - // Build the prompt with proper chat template formatting - std::string full_prompt; - bool has_image = false; - const char* image_marker = get_image_marker(); -#ifdef RAC_VLM_USE_MTMD - mtmd_bitmap* bitmap = nullptr; - - if (image && backend->mtmd_ctx) { - // Load image based on format - if (image->format == RAC_VLM_IMAGE_FORMAT_FILE_PATH && image->file_path) { - bitmap = mtmd_helper_bitmap_init_from_file(backend->mtmd_ctx, image->file_path); - } else if (image->format == RAC_VLM_IMAGE_FORMAT_RGB_PIXELS && image->pixel_data) { - bitmap = mtmd_bitmap_init(image->width, image->height, image->pixel_data); - } else if (image->format == RAC_VLM_IMAGE_FORMAT_BASE64 && image->base64_data) { - // Decode base64 first - // For now, skip base64 - would need base64 decoder - RAC_LOG_WARNING(LOG_CAT, "Base64 image format not yet supported, using text-only"); - } - - has_image = (bitmap != nullptr); - if (!has_image && image->format != RAC_VLM_IMAGE_FORMAT_BASE64) { - RAC_LOG_ERROR(LOG_CAT, "Failed to load image"); - return RAC_ERROR_INVALID_INPUT; - } - } - - // Format prompt using model's built-in chat template - full_prompt = format_vlm_prompt_with_template(backend->model, prompt, image_marker, has_image, - system_prompt, effective_model_type); - - RAC_LOG_INFO(LOG_CAT, "[v3-process] Prompt ready (chars=%d, img=%d, type=%d)", - (int)full_prompt.length(), has_image ? 1 : 0, (int)effective_model_type); - - // Tokenize and evaluate - if (backend->mtmd_ctx && bitmap) { - mtmd_input_chunks* chunks = mtmd_input_chunks_init(); - - mtmd_input_text text; - text.text = full_prompt.c_str(); - text.add_special = true; - text.parse_special = true; - - const mtmd_bitmap* bitmaps[] = { bitmap }; - int32_t tokenize_result = mtmd_tokenize(backend->mtmd_ctx, chunks, &text, bitmaps, 1); - - if (tokenize_result != 0) { - RAC_LOG_ERROR(LOG_CAT, "Failed to tokenize prompt with image: %d", tokenize_result); - mtmd_bitmap_free(bitmap); - mtmd_input_chunks_free(chunks); - return RAC_ERROR_PROCESSING_FAILED; - } - - // Evaluate chunks - llama_pos new_n_past = 0; - int32_t eval_result = mtmd_helper_eval_chunks( - backend->mtmd_ctx, - backend->ctx, - chunks, - 0, // n_past - 0, // seq_id - backend->config.batch_size > 0 ? backend->config.batch_size : 512, - true, // logits_last - &new_n_past - ); - - mtmd_bitmap_free(bitmap); - mtmd_input_chunks_free(chunks); - - if (eval_result != 0) { - RAC_LOG_ERROR(LOG_CAT, "Failed to evaluate chunks: %d", eval_result); - return RAC_ERROR_PROCESSING_FAILED; - } - - backend->n_past = new_n_past; - } else -#endif - { - // Text-only mode - still apply chat template for consistent formatting - full_prompt = format_vlm_prompt_with_template(backend->model, prompt, image_marker, false, - system_prompt, effective_model_type); - - const llama_vocab* vocab = llama_model_get_vocab(backend->model); - std::vector tokens(full_prompt.size() + 16); - int n_tokens = llama_tokenize(vocab, full_prompt.c_str(), full_prompt.size(), - tokens.data(), tokens.size(), true, true); - if (n_tokens < 0) { - tokens.resize(-n_tokens); - n_tokens = llama_tokenize(vocab, full_prompt.c_str(), full_prompt.size(), - tokens.data(), tokens.size(), true, true); - } - tokens.resize(n_tokens); - - // Create batch and decode - llama_batch batch = llama_batch_init(n_tokens, 0, 1); - for (int i = 0; i < n_tokens; i++) { - batch.token[i] = tokens[i]; - batch.pos[i] = i; - batch.n_seq_id[i] = 1; - batch.seq_id[i][0] = 0; - batch.logits[i] = (i == n_tokens - 1); - } - batch.n_tokens = n_tokens; - - if (llama_decode(backend->ctx, batch) != 0) { - llama_batch_free(batch); - RAC_LOG_ERROR(LOG_CAT, "Failed to decode prompt"); - return RAC_ERROR_PROCESSING_FAILED; - } - - llama_batch_free(batch); - backend->n_past = n_tokens; - } - - // Generate response - int max_tokens = (options && options->max_tokens > 0) ? options->max_tokens : 2048; + // Generate response (batch mode — accumulate all tokens) + const int max_tokens = (options && options->max_tokens > 0) ? options->max_tokens : kDefaultMaxTokens; std::string response; + response.reserve(kDefaultMaxTokens); // Typical VLM responses are a few hundred tokens int tokens_generated = 0; llama_batch batch = llama_batch_init(1, 0, 1); - const llama_vocab* vocab = llama_model_get_vocab(backend->model); + const llama_vocab* const vocab = llama_model_get_vocab(backend->model); // Runtime repetition guard: track last token and consecutive repeat count. // If the same token appears too many times in a row, the model is stuck and @@ -837,7 +827,7 @@ rac_result_t rac_vlm_llamacpp_process(rac_handle_t handle, const rac_vlm_image_t prev_token = token; char buf[256]; - int len = llama_token_to_piece(vocab, token, buf, sizeof(buf), 0, true); + int len = llama_token_to_piece(vocab, token, buf, sizeof(buf) - 1, 0, true); if (len > 0) { response.append(buf, len); } @@ -860,6 +850,10 @@ rac_result_t rac_vlm_llamacpp_process(rac_handle_t handle, const rac_vlm_image_t // Fill result out_result->text = strdup(response.c_str()); + if (!out_result->text) { + RAC_LOG_ERROR(LOG_CAT, "Failed to allocate result text"); + return RAC_ERROR_OUT_OF_MEMORY; + } out_result->completion_tokens = tokens_generated; out_result->prompt_tokens = backend->n_past - tokens_generated; out_result->total_tokens = backend->n_past; @@ -884,136 +878,17 @@ rac_result_t rac_vlm_llamacpp_process_stream(rac_handle_t handle, const rac_vlm_ return RAC_ERROR_MODEL_NOT_LOADED; } - backend->cancel_requested = false; - - // Reconfigure sampler with per-request options (temperature, top_p) - configure_sampler(backend, options); - - // Clear KV cache (memory) before each new request to avoid position conflicts - llama_memory_t mem = llama_get_memory(backend->ctx); - if (mem) { - llama_memory_clear(mem, true); - } - backend->n_past = 0; - RAC_LOG_DEBUG(LOG_CAT, "Cleared KV cache for new request"); - - // Resolve effective model type: options override > auto-detected at load time - VLMModelType effective_model_type = resolve_effective_model_type(backend->model_type, options); - - const char* system_prompt = (options && options->system_prompt) ? options->system_prompt : nullptr; - - // Build the prompt with proper chat template formatting - std::string full_prompt; - bool has_image = false; - const char* image_marker = get_image_marker(); - -#ifdef RAC_VLM_USE_MTMD - mtmd_bitmap* bitmap = nullptr; - - if (image && backend->mtmd_ctx) { - // Load image based on format - if (image->format == RAC_VLM_IMAGE_FORMAT_FILE_PATH && image->file_path) { - bitmap = mtmd_helper_bitmap_init_from_file(backend->mtmd_ctx, image->file_path); - } else if (image->format == RAC_VLM_IMAGE_FORMAT_RGB_PIXELS && image->pixel_data) { - bitmap = mtmd_bitmap_init(image->width, image->height, image->pixel_data); - } - - has_image = (bitmap != nullptr); - if (!has_image) { - RAC_LOG_WARNING(LOG_CAT, "Failed to load image, using text-only"); - } - } - - // Format prompt using model's built-in chat template (streaming path) - full_prompt = format_vlm_prompt_with_template(backend->model, prompt, image_marker, has_image, - system_prompt, effective_model_type); - - RAC_LOG_INFO(LOG_CAT, "[v3-stream] Prompt ready (chars=%d, img=%d, type=%d)", - (int)full_prompt.length(), has_image ? 1 : 0, (int)effective_model_type); - - // Tokenize and evaluate - if (backend->mtmd_ctx && bitmap) { - mtmd_input_chunks* chunks = mtmd_input_chunks_init(); - - mtmd_input_text text; - text.text = full_prompt.c_str(); - text.add_special = true; - text.parse_special = true; - - const mtmd_bitmap* bitmaps[] = { bitmap }; - int32_t tokenize_result = mtmd_tokenize(backend->mtmd_ctx, chunks, &text, bitmaps, 1); - - if (tokenize_result != 0) { - RAC_LOG_ERROR(LOG_CAT, "Failed to tokenize prompt with image: %d", tokenize_result); - mtmd_bitmap_free(bitmap); - mtmd_input_chunks_free(chunks); - return RAC_ERROR_PROCESSING_FAILED; - } - - // Evaluate chunks - llama_pos new_n_past = 0; - int32_t eval_result = mtmd_helper_eval_chunks( - backend->mtmd_ctx, - backend->ctx, - chunks, - 0, // n_past - 0, // seq_id - backend->config.batch_size > 0 ? backend->config.batch_size : 512, - true, // logits_last - &new_n_past - ); - - mtmd_bitmap_free(bitmap); - mtmd_input_chunks_free(chunks); - - if (eval_result != 0) { - RAC_LOG_ERROR(LOG_CAT, "Failed to evaluate chunks: %d", eval_result); - return RAC_ERROR_PROCESSING_FAILED; - } - - backend->n_past = new_n_past; - } else -#endif - { - // Text-only mode - still apply chat template for consistent formatting - full_prompt = format_vlm_prompt_with_template(backend->model, prompt, image_marker, false, - system_prompt, effective_model_type); - - const llama_vocab* vocab = llama_model_get_vocab(backend->model); - std::vector tokens(full_prompt.size() + 16); - int n_tokens = llama_tokenize(vocab, full_prompt.c_str(), full_prompt.size(), - tokens.data(), tokens.size(), true, true); - if (n_tokens < 0) { - tokens.resize(-n_tokens); - n_tokens = llama_tokenize(vocab, full_prompt.c_str(), full_prompt.size(), - tokens.data(), tokens.size(), true, true); - } - tokens.resize(n_tokens); - - llama_batch batch = llama_batch_init(n_tokens, 0, 1); - for (int i = 0; i < n_tokens; i++) { - batch.token[i] = tokens[i]; - batch.pos[i] = i; - batch.n_seq_id[i] = 1; - batch.seq_id[i][0] = 0; - batch.logits[i] = (i == n_tokens - 1); - } - batch.n_tokens = n_tokens; - - if (llama_decode(backend->ctx, batch) != 0) { - llama_batch_free(batch); - return RAC_ERROR_PROCESSING_FAILED; - } - - llama_batch_free(batch); - backend->n_past = n_tokens; + // Shared context preparation: reset, configure sampler, build prompt, evaluate + rac_result_t prep_result = prepare_vlm_context(backend, image, prompt, options); + if (prep_result != RAC_SUCCESS) { + return prep_result; } - // Generate response with streaming - int max_tokens = (options && options->max_tokens > 0) ? options->max_tokens : 2048; + // Generate response (streaming mode — callback per token) + const int max_tokens = (options && options->max_tokens > 0) ? options->max_tokens : kDefaultMaxTokens; llama_batch batch = llama_batch_init(1, 0, 1); - const llama_vocab* vocab = llama_model_get_vocab(backend->model); + const llama_vocab* const vocab = llama_model_get_vocab(backend->model); // Runtime repetition guard (same as non-streaming path) llama_token prev_token = -1; @@ -1043,7 +918,7 @@ rac_result_t rac_vlm_llamacpp_process_stream(rac_handle_t handle, const rac_vlm_ } char buf[256]; - int len = llama_token_to_piece(vocab, token, buf, sizeof(buf), 0, true); + int len = llama_token_to_piece(vocab, token, buf, sizeof(buf) - 1, 0, true); if (len > 0) { buf[len] = '\0'; if (callback(buf, is_eog ? RAC_TRUE : RAC_FALSE, user_data) == RAC_FALSE) { @@ -1104,6 +979,9 @@ rac_result_t rac_vlm_llamacpp_get_model_info(rac_handle_t handle, char** out_jso ); *out_json = strdup(buffer); + if (!*out_json) { + return RAC_ERROR_OUT_OF_MEMORY; + } return RAC_SUCCESS; } diff --git a/sdk/runanywhere-commons/src/backends/onnx/CMakeLists.txt b/sdk/runanywhere-commons/src/backends/onnx/CMakeLists.txt index 2b62d6fc4..d5aeca964 100644 --- a/sdk/runanywhere-commons/src/backends/onnx/CMakeLists.txt +++ b/sdk/runanywhere-commons/src/backends/onnx/CMakeLists.txt @@ -234,7 +234,7 @@ else() endif() set_target_properties(rac_backend_onnx PROPERTIES - CXX_STANDARD 17 + CXX_STANDARD 20 CXX_STANDARD_REQUIRED ON CXX_EXTENSIONS OFF ) diff --git a/sdk/runanywhere-commons/src/backends/onnx/jni/rac_backend_onnx_jni.cpp b/sdk/runanywhere-commons/src/backends/onnx/jni/rac_backend_onnx_jni.cpp index fc3b997be..2712b68c3 100644 --- a/sdk/runanywhere-commons/src/backends/onnx/jni/rac_backend_onnx_jni.cpp +++ b/sdk/runanywhere-commons/src/backends/onnx/jni/rac_backend_onnx_jni.cpp @@ -12,25 +12,19 @@ #include #include -#ifdef __ANDROID__ -#include -#define TAG "RACOnnxJNI" -#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__) -#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__) -#define LOGw(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__) -#else -#include -#define LOGi(...) fprintf(stdout, "[INFO] " __VA_ARGS__); fprintf(stdout, "\n") -#define LOGe(...) fprintf(stderr, "[ERROR] " __VA_ARGS__); fprintf(stderr, "\n") -#define LOGw(...) fprintf(stdout, "[WARN] " __VA_ARGS__); fprintf(stdout, "\n") -#endif - #include "rac_stt_onnx.h" #include "rac_tts_onnx.h" #include "rac_vad_onnx.h" #include "rac/core/rac_core.h" #include "rac/core/rac_error.h" +#include "rac/core/rac_logger.h" + +// Route JNI logging through unified RAC_LOG_* system +static const char* LOG_TAG = "JNI.ONNX"; +#define LOGi(...) RAC_LOG_INFO(LOG_TAG, __VA_ARGS__) +#define LOGe(...) RAC_LOG_ERROR(LOG_TAG, __VA_ARGS__) +#define LOGw(...) RAC_LOG_WARNING(LOG_TAG, __VA_ARGS__) // Forward declaration extern "C" rac_result_t rac_backend_onnx_register(void); diff --git a/sdk/runanywhere-commons/src/backends/onnx/onnx_backend.cpp b/sdk/runanywhere-commons/src/backends/onnx/onnx_backend.cpp index 72ece0f28..c8f3310a1 100644 --- a/sdk/runanywhere-commons/src/backends/onnx/onnx_backend.cpp +++ b/sdk/runanywhere-commons/src/backends/onnx/onnx_backend.cpp @@ -24,11 +24,8 @@ #include "rac/core/rac_logger.h" #if SHERPA_ONNX_AVAILABLE -extern "C" { - int espeak_Initialize(int output, int buflength, const char *path, int options); - int espeak_SetVoiceByName(const char *name); -} -#define ESPEAK_AUDIO_OUTPUT_SYNCHRONOUS 0x0003 +// espeak-ng reinitialization is handled by destroying and recreating +// the SherpaOnnxOfflineTts instance via the sherpa-onnx C API. #endif namespace runanywhere { @@ -966,26 +963,15 @@ bool ONNXTTS::load_model(const std::string& model_path, TTSModelType model_type, return false; } - sherpa_tts_ = new_tts; - - // Force espeak-ng to use THIS model's data_dir. - // Sherpa-ONNX uses std::once_flag for espeak_Initialize, so only the first - // model loaded gets its data_dir registered. Re-calling espeak_Initialize - // directly resets the internal path_home to the current model's directory. - if (!espeak_data_dir_.empty()) { - int reinit = espeak_Initialize(ESPEAK_AUDIO_OUTPUT_SYNCHRONOUS, 0, espeak_data_dir_.c_str(), 0); - RAC_LOG_INFO("ONNX.TTS", "espeak_Initialize override: result=%d (expected 22050), data_dir=%s", - reinit, espeak_data_dir_.c_str()); - - if (reinit == 22050) { - int voice_test = espeak_SetVoiceByName("en-us"); - RAC_LOG_INFO("ONNX.TTS", "espeak_SetVoiceByName('en-us') test: result=%d (0=success)", voice_test); - int voice_test_gb = espeak_SetVoiceByName("en-gb"); - RAC_LOG_INFO("ONNX.TTS", "espeak_SetVoiceByName('en-gb') test: result=%d (0=success)", voice_test_gb); - } else { - RAC_LOG_ERROR("ONNX.TTS", "espeak_Initialize override FAILED with code %d", reinit); - } + // Workaround for sherpa-onnx std::once_flag issue: espeak_Initialize is + // only called internally on the very first SherpaOnnxCreateOfflineTts call. + // When switching TTS models with different data_dir, destroy and recreate + // the instance so the config (including data_dir) is applied correctly. + if (sherpa_tts_ && sherpa_tts_ != new_tts) { + SherpaOnnxDestroyOfflineTts(sherpa_tts_); + sherpa_tts_ = nullptr; } + sherpa_tts_ = new_tts; sample_rate_ = SherpaOnnxOfflineTtsSampleRate(sherpa_tts_); int num_speakers = SherpaOnnxOfflineTtsNumSpeakers(sherpa_tts_); @@ -1220,11 +1206,13 @@ bool ONNXVAD::unload_model() { } bool ONNXVAD::configure_vad(const VADConfig& config) { + std::lock_guard lock(mutex_); config_ = config; return true; } VADResult ONNXVAD::process(const std::vector& audio_samples, int sample_rate) { + std::lock_guard lock(mutex_); VADResult result; #if SHERPA_ONNX_AVAILABLE @@ -1232,18 +1220,24 @@ VADResult ONNXVAD::process(const std::vector& audio_samples, int sample_r return result; } - const int32_t window_size = 512; // Silero native window size + static constexpr int32_t SILERO_WINDOW_SIZE = 512; // Append incoming audio to the pending buffer. - // Audio capture may deliver chunks smaller than window_size (e.g. 256 samples), + // Audio capture may deliver chunks smaller than SILERO_WINDOW_SIZE (e.g. 256 samples), // but Silero VAD requires exactly 512 samples per call. pending_samples_.insert(pending_samples_.end(), audio_samples.begin(), audio_samples.end()); - // Feed complete window_size chunks to Silero VAD - while (pending_samples_.size() >= static_cast(window_size)) { + // Feed complete SILERO_WINDOW_SIZE chunks to Silero VAD. + // Use offset tracking instead of repeated front-erase (O(n) per erase). + size_t consumed = 0; + while (consumed + SILERO_WINDOW_SIZE <= pending_samples_.size()) { SherpaOnnxVoiceActivityDetectorAcceptWaveform( - sherpa_vad_, pending_samples_.data(), window_size); - pending_samples_.erase(pending_samples_.begin(), pending_samples_.begin() + window_size); + sherpa_vad_, pending_samples_.data() + consumed, SILERO_WINDOW_SIZE); + consumed += SILERO_WINDOW_SIZE; + } + if (consumed > 0) { + pending_samples_.erase(pending_samples_.begin(), + pending_samples_.begin() + static_cast(consumed)); } // Check if speech is currently detected in the latest frame @@ -1280,6 +1274,7 @@ VADResult ONNXVAD::feed_audio(const std::string& stream_id, const std::vector lock(mutex_); #if SHERPA_ONNX_AVAILABLE if (sherpa_vad_) { SherpaOnnxVoiceActivityDetectorReset(sherpa_vad_); @@ -1289,6 +1284,7 @@ void ONNXVAD::reset() { } VADConfig ONNXVAD::get_vad_config() const { + std::lock_guard lock(mutex_); return config_; } diff --git a/sdk/runanywhere-commons/src/backends/onnx/onnx_backend.h b/sdk/runanywhere-commons/src/backends/onnx/onnx_backend.h index d0c3ba1a8..44ec9524f 100644 --- a/sdk/runanywhere-commons/src/backends/onnx/onnx_backend.h +++ b/sdk/runanywhere-commons/src/backends/onnx/onnx_backend.h @@ -221,7 +221,7 @@ class ONNXBackendNew { bool initialize_ort(); void create_capabilities(); - bool initialized_ = false; + std::atomic initialized_{false}; const OrtApi* ort_api_ = nullptr; OrtEnv* ort_env_ = nullptr; nlohmann::json config_; @@ -276,7 +276,7 @@ class ONNXSTT { void* sherpa_recognizer_ = nullptr; #endif STTModelType model_type_ = STTModelType::WHISPER; - bool model_loaded_ = false; + std::atomic model_loaded_{false}; std::atomic cancel_requested_{false}; std::unordered_map streams_; int stream_counter_ = 0; @@ -321,7 +321,7 @@ class ONNXTTS { void* sherpa_tts_ = nullptr; #endif TTSModelType model_type_ = TTSModelType::PIPER; - bool model_loaded_ = false; + std::atomic model_loaded_{false}; std::atomic cancel_requested_{false}; std::atomic active_synthesis_count_{0}; std::vector voices_; @@ -366,7 +366,7 @@ class ONNXVAD { #endif std::string model_path_; VADConfig config_; - bool model_loaded_ = false; + std::atomic model_loaded_{false}; mutable std::mutex mutex_; // Internal buffer to accumulate audio until we have a full Silero window (512 samples). diff --git a/sdk/runanywhere-commons/src/backends/onnx/rac_backend_onnx_register.cpp b/sdk/runanywhere-commons/src/backends/onnx/rac_backend_onnx_register.cpp index dd19fb71c..b86bc6d34 100644 --- a/sdk/runanywhere-commons/src/backends/onnx/rac_backend_onnx_register.cpp +++ b/sdk/runanywhere-commons/src/backends/onnx/rac_backend_onnx_register.cpp @@ -173,6 +173,7 @@ static rac_result_t onnx_tts_vtable_synthesize_stream(void* impl, const char* te if (status == RAC_SUCCESS && callback) { callback(result.audio_data, result.audio_size, user_data); } + rac_tts_result_free(&result); return status; } @@ -437,6 +438,9 @@ rac_result_t onnx_download_post_process(const rac_model_download_config_t* confi out_result->was_extracted = (config->archive_type != RAC_ARCHIVE_TYPE_NONE) ? RAC_TRUE : RAC_FALSE; out_result->final_path = strdup(downloaded_path); + if (!out_result->final_path) { + return RAC_ERROR_OUT_OF_MEMORY; + } out_result->file_count = 1; return RAC_SUCCESS; diff --git a/sdk/runanywhere-commons/src/backends/onnx/rac_onnx.cpp b/sdk/runanywhere-commons/src/backends/onnx/rac_onnx.cpp index ea5fe974b..232525c0d 100644 --- a/sdk/runanywhere-commons/src/backends/onnx/rac_onnx.cpp +++ b/sdk/runanywhere-commons/src/backends/onnx/rac_onnx.cpp @@ -137,8 +137,16 @@ rac_result_t rac_stt_onnx_transcribe(rac_handle_t handle, const float* audio_sam auto result = h->stt->transcribe(request); out_result->text = result.text.empty() ? nullptr : strdup(result.text.c_str()); + if (!result.text.empty() && !out_result->text) { + return RAC_ERROR_OUT_OF_MEMORY; + } out_result->detected_language = result.detected_language.empty() ? nullptr : strdup(result.detected_language.c_str()); + if (!result.detected_language.empty() && !out_result->detected_language) { + free(out_result->text); + out_result->text = nullptr; + return RAC_ERROR_OUT_OF_MEMORY; + } out_result->words = nullptr; out_result->num_words = 0; out_result->confidence = 1.0f; @@ -173,7 +181,11 @@ rac_result_t rac_stt_onnx_create_stream(rac_handle_t handle, rac_handle_t* out_s return RAC_ERROR_BACKEND_INIT_FAILED; } - *out_stream = static_cast(strdup(stream_id.c_str())); + char* stream_copy = strdup(stream_id.c_str()); + if (!stream_copy) { + return RAC_ERROR_OUT_OF_MEMORY; + } + *out_stream = static_cast(stream_copy); return RAC_SUCCESS; } @@ -213,6 +225,9 @@ rac_result_t rac_stt_onnx_decode_stream(rac_handle_t handle, rac_handle_t stream auto result = h->stt->decode(stream_id); *out_text = strdup(result.text.c_str()); + if (!*out_text) { + return RAC_ERROR_OUT_OF_MEMORY; + } return RAC_SUCCESS; } @@ -377,16 +392,31 @@ rac_result_t rac_tts_onnx_get_voices(rac_handle_t handle, char*** out_voices, si } auto voices = h->tts->get_voices(); - *out_count = voices.size(); if (voices.empty()) { *out_voices = nullptr; + *out_count = 0; return RAC_SUCCESS; } *out_voices = static_cast(malloc(voices.size() * sizeof(char*))); + if (!*out_voices) { + *out_count = 0; + return RAC_ERROR_OUT_OF_MEMORY; + } + + *out_count = voices.size(); for (size_t i = 0; i < voices.size(); i++) { (*out_voices)[i] = strdup(voices[i].id.c_str()); + if (!(*out_voices)[i]) { + for (size_t j = 0; j < i; j++) { + free((*out_voices)[j]); + } + free(*out_voices); + *out_voices = nullptr; + *out_count = 0; + return RAC_ERROR_OUT_OF_MEMORY; + } } return RAC_SUCCESS; diff --git a/sdk/runanywhere-commons/src/backends/whispercpp/CMakeLists.txt b/sdk/runanywhere-commons/src/backends/whispercpp/CMakeLists.txt index 7b82d517a..84a5dcee3 100644 --- a/sdk/runanywhere-commons/src/backends/whispercpp/CMakeLists.txt +++ b/sdk/runanywhere-commons/src/backends/whispercpp/CMakeLists.txt @@ -122,7 +122,7 @@ target_link_libraries(rac_backend_whispercpp PUBLIC nlohmann_json::nlohmann_json ) -target_compile_features(rac_backend_whispercpp PUBLIC cxx_std_17) +target_compile_features(rac_backend_whispercpp PUBLIC cxx_std_20) # ============================================================================= # Platform-specific configuration diff --git a/sdk/runanywhere-commons/src/backends/whispercpp/jni/rac_backend_whispercpp_jni.cpp b/sdk/runanywhere-commons/src/backends/whispercpp/jni/rac_backend_whispercpp_jni.cpp index d735c9225..0b48a038d 100644 --- a/sdk/runanywhere-commons/src/backends/whispercpp/jni/rac_backend_whispercpp_jni.cpp +++ b/sdk/runanywhere-commons/src/backends/whispercpp/jni/rac_backend_whispercpp_jni.cpp @@ -12,23 +12,17 @@ #include #include -#ifdef __ANDROID__ -#include -#define TAG "RACWhisperCPPJNI" -#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__) -#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__) -#define LOGw(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__) -#else -#include -#define LOGi(...) fprintf(stdout, "[INFO] " __VA_ARGS__); fprintf(stdout, "\n") -#define LOGe(...) fprintf(stderr, "[ERROR] " __VA_ARGS__); fprintf(stderr, "\n") -#define LOGw(...) fprintf(stdout, "[WARN] " __VA_ARGS__); fprintf(stdout, "\n") -#endif - #include "rac_stt_whispercpp.h" #include "rac/core/rac_core.h" #include "rac/core/rac_error.h" +#include "rac/core/rac_logger.h" + +// Route JNI logging through unified RAC_LOG_* system +static const char* LOG_TAG = "JNI.WhisperCpp"; +#define LOGi(...) RAC_LOG_INFO(LOG_TAG, __VA_ARGS__) +#define LOGe(...) RAC_LOG_ERROR(LOG_TAG, __VA_ARGS__) +#define LOGw(...) RAC_LOG_WARNING(LOG_TAG, __VA_ARGS__) extern "C" { diff --git a/sdk/runanywhere-commons/src/backends/whispercpp/rac_backend_whispercpp_register.cpp b/sdk/runanywhere-commons/src/backends/whispercpp/rac_backend_whispercpp_register.cpp index 8a6c62837..c0f17302e 100644 --- a/sdk/runanywhere-commons/src/backends/whispercpp/rac_backend_whispercpp_register.cpp +++ b/sdk/runanywhere-commons/src/backends/whispercpp/rac_backend_whispercpp_register.cpp @@ -51,6 +51,9 @@ static rac_result_t whispercpp_stt_vtable_transcribe(void* impl, const void* aud size_t audio_size, const rac_stt_options_t* options, rac_stt_result_t* out_result) { + if (!audio_data || audio_size == 0 || !out_result) { + return RAC_ERROR_INVALID_ARGUMENT; + } std::vector float_samples = convert_int16_to_float32(audio_data, audio_size); return rac_stt_whispercpp_transcribe(impl, float_samples.data(), float_samples.size(), options, out_result); @@ -71,6 +74,7 @@ static rac_result_t whispercpp_stt_vtable_transcribe_stream(void* impl, const vo if (status == RAC_SUCCESS && callback && result.text) { callback(result.text, RAC_TRUE, user_data); } + rac_stt_result_free(&result); return status; } diff --git a/sdk/runanywhere-commons/src/backends/whispercpp/rac_stt_whispercpp.cpp b/sdk/runanywhere-commons/src/backends/whispercpp/rac_stt_whispercpp.cpp index 6fa394d81..a5a16555f 100644 --- a/sdk/runanywhere-commons/src/backends/whispercpp/rac_stt_whispercpp.cpp +++ b/sdk/runanywhere-commons/src/backends/whispercpp/rac_stt_whispercpp.cpp @@ -121,8 +121,16 @@ rac_result_t rac_stt_whispercpp_transcribe(rac_handle_t handle, const float* aud // Fill output out_result->text = result.text.empty() ? nullptr : strdup(result.text.c_str()); + if (!result.text.empty() && !out_result->text) { + return RAC_ERROR_OUT_OF_MEMORY; + } out_result->detected_language = result.detected_language.empty() ? nullptr : strdup(result.detected_language.c_str()); + if (!result.detected_language.empty() && !out_result->detected_language) { + free(out_result->text); + out_result->text = nullptr; + return RAC_ERROR_OUT_OF_MEMORY; + } out_result->confidence = result.confidence; out_result->processing_time_ms = result.inference_time_ms; @@ -130,12 +138,23 @@ rac_result_t rac_stt_whispercpp_transcribe(rac_handle_t handle, const float* aud out_result->words = nullptr; out_result->num_words = 0; if (!result.word_timings.empty()) { - out_result->num_words = result.word_timings.size(); out_result->words = static_cast(malloc(result.word_timings.size() * sizeof(rac_stt_word_t))); if (out_result->words) { + out_result->num_words = result.word_timings.size(); for (size_t i = 0; i < result.word_timings.size(); i++) { out_result->words[i].text = strdup(result.word_timings[i].word.c_str()); + if (!out_result->words[i].text) { + // Clean up already-allocated word texts + for (size_t j = 0; j < i; j++) { + free(out_result->words[j].text); + } + free(out_result->words); + out_result->words = nullptr; + out_result->num_words = 0; + // Continue without word timings rather than failing entirely + break; + } out_result->words[i].start_ms = static_cast(result.word_timings[i].start_time_ms); out_result->words[i].end_ms = @@ -163,6 +182,9 @@ rac_result_t rac_stt_whispercpp_get_language(rac_handle_t handle, char** out_lan } *out_language = strdup(h->detected_language.c_str()); + if (!*out_language) { + return RAC_ERROR_OUT_OF_MEMORY; + } return RAC_SUCCESS; } diff --git a/sdk/runanywhere-commons/src/backends/whispercpp/whispercpp_backend.cpp b/sdk/runanywhere-commons/src/backends/whispercpp/whispercpp_backend.cpp index f51b3f803..a93e4a997 100644 --- a/sdk/runanywhere-commons/src/backends/whispercpp/whispercpp_backend.cpp +++ b/sdk/runanywhere-commons/src/backends/whispercpp/whispercpp_backend.cpp @@ -14,10 +14,6 @@ #include "rac/core/rac_logger.h" -// Use the RAC logging system -#define LOGI(...) RAC_LOG_INFO("STT.WhisperCpp", __VA_ARGS__) -#define LOGE(...) RAC_LOG_ERROR("STT.WhisperCpp", __VA_ARGS__) -#define LOGW(...) RAC_LOG_WARNING("STT.WhisperCpp", __VA_ARGS__) // Whisper sample rate constant #ifndef WHISPER_SAMPLE_RATE @@ -31,19 +27,19 @@ namespace runanywhere { // ============================================================================= WhisperCppBackend::WhisperCppBackend() { - LOGI("WhisperCppBackend created"); + RAC_LOG_INFO("STT.WhisperCpp","WhisperCppBackend created"); } WhisperCppBackend::~WhisperCppBackend() { cleanup(); - LOGI("WhisperCppBackend destroyed"); + RAC_LOG_INFO("STT.WhisperCpp","WhisperCppBackend destroyed"); } bool WhisperCppBackend::initialize(const nlohmann::json& config) { std::lock_guard lock(mutex_); if (initialized_) { - LOGI("WhisperCppBackend already initialized"); + RAC_LOG_INFO("STT.WhisperCpp","WhisperCppBackend already initialized"); return true; } @@ -65,7 +61,7 @@ bool WhisperCppBackend::initialize(const nlohmann::json& config) { use_gpu_ = config["use_gpu"].get(); } - LOGI("WhisperCppBackend initialized with %d threads, GPU: %s", num_threads_, + RAC_LOG_INFO("STT.WhisperCpp","WhisperCppBackend initialized with %d threads, GPU: %s", num_threads_, use_gpu_ ? "enabled" : "disabled"); create_stt(); @@ -87,12 +83,12 @@ void WhisperCppBackend::cleanup() { stt_.reset(); initialized_ = false; - LOGI("WhisperCppBackend cleaned up"); + RAC_LOG_INFO("STT.WhisperCpp","WhisperCppBackend cleaned up"); } void WhisperCppBackend::create_stt() { stt_ = std::make_unique(this); - LOGI("Created STT component"); + RAC_LOG_INFO("STT.WhisperCpp","Created STT component"); } DeviceType WhisperCppBackend::get_device_type() const { @@ -114,7 +110,7 @@ size_t WhisperCppBackend::get_memory_usage() const { // ============================================================================= WhisperCppSTT::WhisperCppSTT(WhisperCppBackend* backend) : backend_(backend) { - LOGI("WhisperCppSTT created"); + RAC_LOG_INFO("STT.WhisperCpp","WhisperCppSTT created"); } WhisperCppSTT::~WhisperCppSTT() { @@ -127,7 +123,7 @@ WhisperCppSTT::~WhisperCppSTT() { } streams_.clear(); - LOGI("WhisperCppSTT destroyed"); + RAC_LOG_INFO("STT.WhisperCpp","WhisperCppSTT destroyed"); } bool WhisperCppSTT::is_ready() const { @@ -139,13 +135,13 @@ bool WhisperCppSTT::load_model(const std::string& model_path, STTModelType model std::lock_guard lock(mutex_); if (model_loaded_ && ctx_) { - LOGI("Unloading previous model"); + RAC_LOG_INFO("STT.WhisperCpp","Unloading previous model"); whisper_free(ctx_); ctx_ = nullptr; model_loaded_ = false; } - LOGI("Loading whisper model from: %s", model_path.c_str()); + RAC_LOG_INFO("STT.WhisperCpp","Loading whisper model from: %s", model_path.c_str()); whisper_context_params cparams = whisper_context_default_params(); cparams.use_gpu = backend_->is_gpu_enabled(); @@ -162,7 +158,7 @@ bool WhisperCppSTT::load_model(const std::string& model_path, STTModelType model ctx_ = whisper_init_from_file_with_params(model_path.c_str(), cparams); if (!ctx_) { - LOGE("Failed to load whisper model from: %s", model_path.c_str()); + RAC_LOG_ERROR("STT.WhisperCpp","Failed to load whisper model from: %s", model_path.c_str()); return false; } @@ -170,7 +166,7 @@ bool WhisperCppSTT::load_model(const std::string& model_path, STTModelType model model_config_ = config; model_loaded_ = true; - LOGI("Whisper model loaded successfully. Multilingual: %s", + RAC_LOG_INFO("STT.WhisperCpp","Whisper model loaded successfully. Multilingual: %s", whisper_is_multilingual(ctx_) ? "yes" : "no"); return true; @@ -199,7 +195,7 @@ bool WhisperCppSTT::unload_model() { model_loaded_ = false; model_path_.clear(); - LOGI("Whisper model unloaded"); + RAC_LOG_INFO("STT.WhisperCpp","Whisper model unloaded"); return true; } @@ -214,7 +210,7 @@ STTResult WhisperCppSTT::transcribe(const STTRequest& request) { result.is_final = true; if (!model_loaded_ || !ctx_) { - LOGE("Model not loaded"); + RAC_LOG_ERROR("STT.WhisperCpp","Model not loaded"); return result; } @@ -265,7 +261,7 @@ STTResult WhisperCppSTT::transcribe_internal(const std::vector& audio, int ret = whisper_full(ctx_, wparams, audio.data(), static_cast(audio.size())); if (ret != 0) { - LOGE("whisper_full failed with code: %d", ret); + RAC_LOG_ERROR("STT.WhisperCpp","whisper_full failed with code: %d", ret); return result; } @@ -332,7 +328,7 @@ STTResult WhisperCppSTT::transcribe_internal(const std::vector& audio, result.confidence = total_conf / static_cast(result.segments.size()); } - LOGI("Transcription complete: %d segments, %.0fms inference, lang=%s", n_segments, + RAC_LOG_INFO("STT.WhisperCpp","Transcription complete: %d segments, %.0fms inference, lang=%s", n_segments, result.inference_time_ms, result.detected_language.empty() ? "unknown" : result.detected_language.c_str()); @@ -353,7 +349,7 @@ std::string WhisperCppSTT::create_stream(const nlohmann::json& config) { std::lock_guard lock(mutex_); if (!model_loaded_ || !ctx_) { - LOGE("Cannot create stream: model not loaded"); + RAC_LOG_ERROR("STT.WhisperCpp","Cannot create stream: model not loaded"); return ""; } @@ -363,7 +359,7 @@ std::string WhisperCppSTT::create_stream(const nlohmann::json& config) { state->state = whisper_init_state(ctx_); if (!state->state) { - LOGE("Failed to create whisper state for stream"); + RAC_LOG_ERROR("STT.WhisperCpp","Failed to create whisper state for stream"); return ""; } @@ -377,7 +373,7 @@ std::string WhisperCppSTT::create_stream(const nlohmann::json& config) { streams_[stream_id] = std::move(state); - LOGI("Created stream: %s", stream_id.c_str()); + RAC_LOG_INFO("STT.WhisperCpp","Created stream: %s", stream_id.c_str()); return stream_id; } @@ -387,7 +383,7 @@ bool WhisperCppSTT::feed_audio(const std::string& stream_id, const std::vector(stream_state->audio_buffer.size())); if (ret != 0) { - LOGE("whisper_full_with_state failed: %d", ret); + RAC_LOG_ERROR("STT.WhisperCpp","whisper_full_with_state failed: %d", ret); return result; } @@ -504,7 +500,7 @@ void WhisperCppSTT::input_finished(const std::string& stream_id) { auto it = streams_.find(stream_id); if (it != streams_.end()) { it->second->input_finished = true; - LOGI("Input finished for stream: %s", stream_id.c_str()); + RAC_LOG_INFO("STT.WhisperCpp","Input finished for stream: %s", stream_id.c_str()); } } @@ -515,7 +511,7 @@ void WhisperCppSTT::reset_stream(const std::string& stream_id) { if (it != streams_.end()) { it->second->audio_buffer.clear(); it->second->input_finished = false; - LOGI("Reset stream: %s", stream_id.c_str()); + RAC_LOG_INFO("STT.WhisperCpp","Reset stream: %s", stream_id.c_str()); } } @@ -528,13 +524,13 @@ void WhisperCppSTT::destroy_stream(const std::string& stream_id) { whisper_free_state(it->second->state); } streams_.erase(it); - LOGI("Destroyed stream: %s", stream_id.c_str()); + RAC_LOG_INFO("STT.WhisperCpp","Destroyed stream: %s", stream_id.c_str()); } } void WhisperCppSTT::cancel() { cancel_requested_.store(true); - LOGI("Cancellation requested"); + RAC_LOG_INFO("STT.WhisperCpp","Cancellation requested"); } std::vector WhisperCppSTT::get_supported_languages() const { @@ -613,7 +609,7 @@ std::vector WhisperCppSTT::resample_to_16khz(const std::vector& sa pos += step; } - LOGI("Resampled audio from %d Hz to %d Hz (%zu -> %zu samples)", source_rate, + RAC_LOG_INFO("STT.WhisperCpp","Resampled audio from %d Hz to %d Hz (%zu -> %zu samples)", source_rate, WHISPER_SAMPLE_RATE, samples.size(), output_size); return output; diff --git a/sdk/runanywhere-commons/src/backends/whisperkit_coreml/CMakeLists.txt b/sdk/runanywhere-commons/src/backends/whisperkit_coreml/CMakeLists.txt index bdb28d0a0..6dfc10aef 100644 --- a/sdk/runanywhere-commons/src/backends/whisperkit_coreml/CMakeLists.txt +++ b/sdk/runanywhere-commons/src/backends/whisperkit_coreml/CMakeLists.txt @@ -32,7 +32,7 @@ target_link_libraries(rac_backend_whisperkit_coreml PUBLIC rac_commons ) -target_compile_features(rac_backend_whisperkit_coreml PUBLIC cxx_std_17) +target_compile_features(rac_backend_whisperkit_coreml PUBLIC cxx_std_20) # ============================================================================= # Summary diff --git a/sdk/runanywhere-commons/src/backends/whisperkit_coreml/rac_backend_whisperkit_coreml_register.cpp b/sdk/runanywhere-commons/src/backends/whisperkit_coreml/rac_backend_whisperkit_coreml_register.cpp index 6f20bf598..cb9d7ad4e 100644 --- a/sdk/runanywhere-commons/src/backends/whisperkit_coreml/rac_backend_whisperkit_coreml_register.cpp +++ b/sdk/runanywhere-commons/src/backends/whisperkit_coreml/rac_backend_whisperkit_coreml_register.cpp @@ -58,6 +58,7 @@ static rac_result_t whisperkit_coreml_stt_vtable_transcribe_stream( if (status == RAC_SUCCESS && callback && result.text) { callback(result.text, RAC_TRUE, user_data); } + rac_stt_result_free(&result); return status; } diff --git a/sdk/runanywhere-commons/src/core/capabilities/lifecycle_manager.cpp b/sdk/runanywhere-commons/src/core/capabilities/lifecycle_manager.cpp index 21e15af86..60b11bdd5 100644 --- a/sdk/runanywhere-commons/src/core/capabilities/lifecycle_manager.cpp +++ b/sdk/runanywhere-commons/src/core/capabilities/lifecycle_manager.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include "rac/core/capabilities/rac_lifecycle.h" @@ -46,7 +47,7 @@ struct LifecycleManager { std::string current_model_id{}; // Model identifier for telemetry (e.g., "sherpa-onnx-whisper-tiny.en") std::string current_model_name{}; // Human-readable name (e.g., "Sherpa Whisper Tiny (ONNX)") - rac_handle_t current_service{nullptr}; + std::atomic current_service{nullptr}; // Metrics (mirrors Swift's ManagedLifecycle metrics) int32_t load_count{0}; @@ -59,6 +60,10 @@ struct LifecycleManager { // Thread safety std::mutex mutex{}; + // Service pinning (acquire/release) — prevents unload while service is in use + std::atomic service_refcount{0}; + std::condition_variable service_cv{}; + LifecycleManager() { // Set start time (mirrors Swift's startTime = Date()) auto now = std::chrono::system_clock::now(); @@ -134,7 +139,10 @@ rac_result_t rac_lifecycle_create(const rac_lifecycle_config_t* config, return RAC_ERROR_NULL_POINTER; } - auto* mgr = new LifecycleManager(); + auto* mgr = new (std::nothrow) LifecycleManager(); + if (!mgr) { + return RAC_ERROR_OUT_OF_MEMORY; + } mgr->resource_type = config->resource_type; mgr->logger_category = config->logger_category ? config->logger_category : "Lifecycle"; mgr->user_data = config->user_data; @@ -166,10 +174,10 @@ rac_result_t rac_lifecycle_load(rac_handle_t handle, const char* model_path, con // Check if already loaded with same path - skip duplicate events // Mirrors Swift: if await lifecycle.currentResourceId == modelId if (mgr->state.load() == RAC_LIFECYCLE_STATE_LOADED && mgr->current_model_path == model_path && - mgr->current_service != nullptr) { + mgr->current_service.load() != nullptr) { // Mirrors Swift: logger.info("Model already loaded, skipping duplicate load") RAC_LOG_INFO(mgr->logger_category.c_str(), "Model already loaded, skipping duplicate load"); - *out_service = mgr->current_service; + *out_service = mgr->current_service.load(); return RAC_SUCCESS; } @@ -192,7 +200,7 @@ rac_result_t rac_lifecycle_load(rac_handle_t handle, const char* model_path, con mgr->current_model_path = model_path; mgr->current_model_id = model_id; // Model identifier for telemetry mgr->current_model_name = model_name; // Human-readable name for telemetry - mgr->current_service = service; + mgr->current_service.store(service); mgr->state.store(RAC_LIFECYCLE_STATE_LOADED); // Track load completed (mirrors Swift: trackEvent(type: .loadCompleted)) @@ -221,22 +229,58 @@ rac_result_t rac_lifecycle_load(rac_handle_t handle, const char* model_path, con return result; } +rac_result_t rac_lifecycle_acquire_service(rac_handle_t handle, rac_handle_t* out_service) { + if (handle == nullptr || out_service == nullptr) { + return RAC_ERROR_NULL_POINTER; + } + + auto* mgr = static_cast(handle); + std::lock_guard lock(mgr->mutex); + + if (mgr->state.load() != RAC_LIFECYCLE_STATE_LOADED || mgr->current_service.load() == nullptr) { + return RAC_ERROR_NOT_INITIALIZED; + } + + mgr->service_refcount.fetch_add(1); + *out_service = mgr->current_service.load(); + return RAC_SUCCESS; +} + +void rac_lifecycle_release_service(rac_handle_t handle) { + if (handle == nullptr) { + return; + } + + auto* mgr = static_cast(handle); + int prev = mgr->service_refcount.fetch_sub(1); + if (prev <= 1) { + mgr->service_cv.notify_all(); + } +} + rac_result_t rac_lifecycle_unload(rac_handle_t handle) { if (handle == nullptr) { return RAC_ERROR_NULL_POINTER; } auto* mgr = static_cast(handle); - std::lock_guard lock(mgr->mutex); + std::unique_lock lock(mgr->mutex); + + // Wait for all acquired services to be released + mgr->service_cv.wait(lock, [mgr] { return mgr->service_refcount.load() == 0; }); // Mirrors Swift: if let modelId = await lifecycle.currentResourceId if (!mgr->current_model_id.empty()) { RAC_LOG_INFO(mgr->logger_category.c_str(), "Unloading model: %s", mgr->current_model_id.c_str()); - // Destroy service if callback provided - if (mgr->destroy_fn != nullptr && mgr->current_service != nullptr) { - mgr->destroy_fn(mgr->current_service, mgr->user_data); + // Destroy service if callback provided. + // Store nullptr BEFORE calling destroy_fn so that concurrent cancel() + // reads via get_service() see nullptr rather than a dangling pointer. + rac_handle_t svc = mgr->current_service.load(); + mgr->current_service.store(nullptr); + if (mgr->destroy_fn != nullptr && svc != nullptr) { + mgr->destroy_fn(svc, mgr->user_data); } // Track unload event (mirrors Swift: trackEvent(type: .unloaded)) @@ -249,7 +293,7 @@ rac_result_t rac_lifecycle_unload(rac_handle_t handle) { mgr->current_model_path.clear(); mgr->current_model_id.clear(); mgr->current_model_name.clear(); - mgr->current_service = nullptr; + mgr->current_service.store(nullptr); mgr->state.store(RAC_LIFECYCLE_STATE_IDLE); return RAC_SUCCESS; @@ -261,15 +305,20 @@ rac_result_t rac_lifecycle_reset(rac_handle_t handle) { } auto* mgr = static_cast(handle); - std::lock_guard lock(mgr->mutex); + std::unique_lock lock(mgr->mutex); + + // Wait for all acquired services to be released + mgr->service_cv.wait(lock, [mgr] { return mgr->service_refcount.load() == 0; }); // Track unload if currently loaded (mirrors Swift reset()) if (!mgr->current_model_id.empty()) { track_lifecycle_event(mgr, "unloaded", mgr->current_model_id.c_str(), 0.0, RAC_SUCCESS); - // Destroy service if callback provided - if (mgr->destroy_fn != nullptr && mgr->current_service != nullptr) { - mgr->destroy_fn(mgr->current_service, mgr->user_data); + // Store nullptr BEFORE calling destroy_fn (same reasoning as unload) + rac_handle_t svc = mgr->current_service.load(); + mgr->current_service.store(nullptr); + if (mgr->destroy_fn != nullptr && svc != nullptr) { + mgr->destroy_fn(svc, mgr->user_data); } } @@ -277,7 +326,7 @@ rac_result_t rac_lifecycle_reset(rac_handle_t handle) { mgr->current_model_path.clear(); mgr->current_model_id.clear(); mgr->current_model_name.clear(); - mgr->current_service = nullptr; + mgr->current_service.store(nullptr); mgr->state.store(RAC_LIFECYCLE_STATE_IDLE); return RAC_SUCCESS; @@ -335,7 +384,10 @@ rac_handle_t rac_lifecycle_get_service(rac_handle_t handle) { } auto* mgr = static_cast(handle); - return mgr->current_service; + // Atomic load — safe to call without the lifecycle mutex. + // This is used by cancel() paths which intentionally skip locking + // to avoid deadlock with streaming operations. + return mgr->current_service.load(std::memory_order_acquire); } rac_result_t rac_lifecycle_require_service(rac_handle_t handle, rac_handle_t* out_service) { @@ -345,12 +397,12 @@ rac_result_t rac_lifecycle_require_service(rac_handle_t handle, rac_handle_t* ou auto* mgr = static_cast(handle); - if (mgr->state.load() != RAC_LIFECYCLE_STATE_LOADED || mgr->current_service == nullptr) { + if (mgr->state.load() != RAC_LIFECYCLE_STATE_LOADED || mgr->current_service.load() == nullptr) { rac_error_set_details("Service not loaded - call load() first"); return RAC_ERROR_NOT_INITIALIZED; } - *out_service = mgr->current_service; + *out_service = mgr->current_service.load(); return RAC_SUCCESS; } @@ -436,6 +488,10 @@ const char* rac_resource_type_name(rac_resource_type_t type) { return "vadModel"; case RAC_RESOURCE_TYPE_DIARIZATION_MODEL: return "diarizationModel"; + case RAC_RESOURCE_TYPE_VLM_MODEL: + return "vlmModel"; + case RAC_RESOURCE_TYPE_DIFFUSION_MODEL: + return "diffusionModel"; default: return "unknown"; } diff --git a/sdk/runanywhere-commons/src/core/rac_audio_utils.cpp b/sdk/runanywhere-commons/src/core/rac_audio_utils.cpp index 32823564a..52cad6a29 100644 --- a/sdk/runanywhere-commons/src/core/rac_audio_utils.cpp +++ b/sdk/runanywhere-commons/src/core/rac_audio_utils.cpp @@ -122,10 +122,15 @@ rac_result_t rac_audio_float32_to_wav(const void* pcm_data, size_t pcm_size, int return RAC_ERROR_INVALID_ARGUMENT; } - const size_t num_samples = pcm_size / 4; + const size_t num_samples = pcm_size / sizeof(float); + + // Guard against WAV header overflow: data_size field is uint32_t (max ~4GB) + if (num_samples > UINT32_MAX / sizeof(int16_t)) { + return RAC_ERROR_INVALID_ARGUMENT; + } // Int16 data size (2 bytes per sample) - const uint32_t int16_data_size = static_cast(num_samples * 2); + const uint32_t int16_data_size = static_cast(num_samples * sizeof(int16_t)); // Total WAV file size const size_t wav_size = WAV_HEADER_SIZE + int16_data_size; @@ -140,14 +145,18 @@ rac_result_t rac_audio_float32_to_wav(const void* pcm_data, size_t pcm_size, int build_wav_header(wav_data, sample_rate, int16_data_size); // Convert Float32 to Int16 - const float* float_samples = static_cast(pcm_data); - int16_t* int16_samples = reinterpret_cast(wav_data + WAV_HEADER_SIZE); + // Use __restrict to guarantee no aliasing, enabling auto-vectorization. + // The loop body is kept simple (multiply + clamp) so compilers can emit + // NEON (ARM) or SSE/AVX (x86) vector instructions with -O2. + const float* __restrict float_samples = static_cast(pcm_data); + int16_t* __restrict int16_samples = reinterpret_cast(wav_data + WAV_HEADER_SIZE); for (size_t i = 0; i < num_samples; ++i) { - // Clamp to [-1.0, 1.0] range - float sample = std::max(-1.0f, std::min(1.0f, float_samples[i])); - // Scale to Int16 range [-32768, 32767] - int16_samples[i] = static_cast(sample * 32767.0f); + // Multiply first, then clamp to Int16 range in one step. + // Avoids two separate clamp operations and is auto-vectorizable. + // std::clamp produces branchless vmin/vmax on ARM NEON, minps/maxps on SSE. + const float scaled = std::clamp(float_samples[i] * 32767.0f, -32768.0f, 32767.0f); + int16_samples[i] = static_cast(scaled); } *out_wav_data = wav_data; @@ -172,6 +181,12 @@ rac_result_t rac_audio_int16_to_wav(const void* pcm_data, size_t pcm_size, int32 return RAC_ERROR_INVALID_ARGUMENT; } + // Guard against WAV header overflow: the RIFF chunk-size field (data_size + 36) + // is uint32_t, so data_size must leave room for the 36-byte header overhead. + if (pcm_size > static_cast(UINT32_MAX) - (WAV_HEADER_SIZE - 8)) { + return RAC_ERROR_INVALID_ARGUMENT; + } + const uint32_t data_size = static_cast(pcm_size); // Total WAV file size diff --git a/sdk/runanywhere-commons/src/core/rac_core.cpp b/sdk/runanywhere-commons/src/core/rac_core.cpp index 764bf39b5..7c91c5ece 100644 --- a/sdk/runanywhere-commons/src/core/rac_core.cpp +++ b/sdk/runanywhere-commons/src/core/rac_core.cpp @@ -194,6 +194,10 @@ rac_result_t rac_http_download(const char* url, const char* destination_path, rac_http_progress_callback_fn progress_callback, rac_http_complete_callback_fn complete_callback, void* callback_user_data, char** out_task_id) { + if (url == nullptr || destination_path == nullptr) { + return RAC_ERROR_NULL_POINTER; + } + if (s_platform_adapter == nullptr) { return RAC_ERROR_ADAPTER_NOT_SET; } @@ -208,6 +212,10 @@ rac_result_t rac_http_download(const char* url, const char* destination_path, } rac_result_t rac_http_download_cancel(const char* task_id) { + if (task_id == nullptr) { + return RAC_ERROR_NULL_POINTER; + } + if (s_platform_adapter == nullptr) { return RAC_ERROR_ADAPTER_NOT_SET; } @@ -223,19 +231,39 @@ rac_result_t rac_http_download_cancel(const char* task_id) { // ARCHIVE EXTRACTION CONVENIENCE FUNCTION // ============================================================================= +#include "rac/infrastructure/extraction/rac_extraction.h" + rac_result_t rac_extract_archive(const char* archive_path, const char* destination_dir, rac_extract_progress_callback_fn progress_callback, void* callback_user_data) { - if (s_platform_adapter == nullptr) { - return RAC_ERROR_ADAPTER_NOT_SET; + if (archive_path == nullptr || destination_dir == nullptr) { + return RAC_ERROR_NULL_POINTER; } - if (s_platform_adapter->extract_archive == nullptr) { - return RAC_ERROR_NOT_SUPPORTED; + // Bridge the old callback signature to the new one + struct bridge_ctx { + rac_extract_progress_callback_fn callback; + void* user_data; + }; + bridge_ctx ctx = {progress_callback, callback_user_data}; + + rac_extraction_progress_fn bridged_callback = nullptr; + void* bridged_user_data = nullptr; + if (progress_callback) { + bridged_callback = [](int32_t files_extracted, int32_t total_files, + int64_t /* bytes_extracted */, void* ud) { + auto* c = static_cast(ud); + if (c->callback) { + c->callback(files_extracted, total_files, c->user_data); + } + }; + bridged_user_data = &ctx; } - return s_platform_adapter->extract_archive(archive_path, destination_dir, progress_callback, - callback_user_data, s_platform_adapter->user_data); + // Use native libarchive extraction + return rac_extract_archive_native(archive_path, destination_dir, nullptr /* default options */, + bridged_callback, bridged_user_data, + nullptr /* no result output */); } // ============================================================================= diff --git a/sdk/runanywhere-commons/src/core/rac_error.cpp b/sdk/runanywhere-commons/src/core/rac_error.cpp index d0e178686..61b4d0cfa 100644 --- a/sdk/runanywhere-commons/src/core/rac_error.cpp +++ b/sdk/runanywhere-commons/src/core/rac_error.cpp @@ -41,6 +41,8 @@ const char* rac_error_message(rac_result_t error_code) { return "API key is invalid or missing"; case RAC_ERROR_ENVIRONMENT_MISMATCH: return "Environment mismatch"; + case RAC_ERROR_INVALID_PARAMETER: + return "Invalid parameter value"; // ================================================================= // MODEL ERRORS (-110 to -129) diff --git a/sdk/runanywhere-commons/src/core/rac_logger.cpp b/sdk/runanywhere-commons/src/core/rac_logger.cpp index 98c70af40..fcdb70908 100644 --- a/sdk/runanywhere-commons/src/core/rac_logger.cpp +++ b/sdk/runanywhere-commons/src/core/rac_logger.cpp @@ -8,6 +8,7 @@ #include "rac/core/rac_logger.h" +#include #include #include #include @@ -22,12 +23,14 @@ namespace { // Logger configuration +// min_level is atomic so log-level checks can skip the mutex entirely. +// stderr_fallback/stderr_always are also atomic for lock-free reads in the hot path. struct LoggerState { - rac_log_level_t min_level = RAC_LOG_INFO; - rac_bool_t stderr_fallback = RAC_TRUE; - rac_bool_t stderr_always = RAC_TRUE; // Always log to stderr (safe during static init) - rac_bool_t initialized = RAC_FALSE; - std::mutex mutex; + std::atomic min_level{RAC_LOG_INFO}; + std::atomic stderr_fallback{RAC_TRUE}; + std::atomic stderr_always{RAC_TRUE}; + std::atomic initialized{RAC_FALSE}; + std::mutex mutex; // Only used for write operations (init/shutdown/set) }; LoggerState& state() { @@ -61,7 +64,14 @@ const char* filename_from_path(const char* path) { return nullptr; const char* last_slash = strrchr(path, '/'); const char* last_backslash = strrchr(path, '\\'); - const char* last_sep = last_slash > last_backslash ? last_slash : last_backslash; + // Pick the later separator. Avoid comparing two pointers from unrelated + // arrays (UB when one is nullptr): explicitly handle the null cases. + const char* last_sep; + if (last_slash && last_backslash) { + last_sep = last_slash > last_backslash ? last_slash : last_backslash; + } else { + last_sep = last_slash ? last_slash : last_backslash; + } return last_sep ? last_sep + 1 : path; } @@ -134,10 +144,10 @@ void format_message_with_metadata(char* buffer, size_t buffer_size, const char* // Fallback to stderr void log_to_stderr(rac_log_level_t level, const char* category, const char* message, const rac_log_metadata_t* metadata) { - const char* level_str = level_to_string(level); + const char* const level_str = level_to_string(level); // Determine output stream - FILE* stream = (level >= RAC_LOG_ERROR) ? stderr : stdout; + FILE* const stream = (level >= RAC_LOG_ERROR) ? stderr : stdout; // Print base message fprintf(stream, "[RAC][%s][%s] %s", level_str, category, message); @@ -177,35 +187,29 @@ void log_to_stderr(rac_log_level_t level, const char* category, const char* mess extern "C" { rac_result_t rac_logger_init(rac_log_level_t min_level) { - std::lock_guard lock(state().mutex); - state().min_level = min_level; - state().initialized = RAC_TRUE; + state().min_level.store(min_level, std::memory_order_relaxed); + state().initialized.store(RAC_TRUE, std::memory_order_release); return RAC_SUCCESS; } void rac_logger_shutdown(void) { - std::lock_guard lock(state().mutex); - state().initialized = RAC_FALSE; + state().initialized.store(RAC_FALSE, std::memory_order_release); } void rac_logger_set_min_level(rac_log_level_t level) { - std::lock_guard lock(state().mutex); - state().min_level = level; + state().min_level.store(level, std::memory_order_relaxed); } rac_log_level_t rac_logger_get_min_level(void) { - std::lock_guard lock(state().mutex); - return state().min_level; + return state().min_level.load(std::memory_order_relaxed); } void rac_logger_set_stderr_fallback(rac_bool_t enabled) { - std::lock_guard lock(state().mutex); - state().stderr_fallback = enabled; + state().stderr_fallback.store(enabled, std::memory_order_relaxed); } void rac_logger_set_stderr_always(rac_bool_t enabled) { - std::lock_guard lock(state().mutex); - state().stderr_always = enabled; + state().stderr_always.store(enabled, std::memory_order_relaxed); } void rac_logger_log(rac_log_level_t level, const char* category, const char* message, @@ -215,18 +219,12 @@ void rac_logger_log(rac_log_level_t level, const char* category, const char* mes if (!category) category = "RAC"; - // Get state configuration (with lock) - rac_log_level_t min_level; - rac_bool_t stderr_always; - rac_bool_t stderr_fallback; - { - std::lock_guard lock(state().mutex); - min_level = state().min_level; - stderr_always = state().stderr_always; - stderr_fallback = state().stderr_fallback; - } + // Atomic reads — no mutex needed for the hot-path level check + const rac_log_level_t min_level = state().min_level.load(std::memory_order_relaxed); + const rac_bool_t stderr_always = state().stderr_always.load(std::memory_order_relaxed); + const rac_bool_t stderr_fallback = state().stderr_fallback.load(std::memory_order_relaxed); - // Check min level + // Check min level (early out before any formatting work) if (level < min_level) return; @@ -254,6 +252,10 @@ void rac_logger_logf(rac_log_level_t level, const char* category, if (!format) return; + // Early level check: skip vsnprintf entirely if this message will be filtered + if (level < state().min_level.load(std::memory_order_relaxed)) + return; + va_list args; va_start(args, format); rac_logger_logv(level, category, metadata, format, args); @@ -265,7 +267,11 @@ void rac_logger_logv(rac_log_level_t level, const char* category, if (!format) return; - // Format the message + // Early level check: skip vsnprintf entirely if this message will be filtered + if (level < state().min_level.load(std::memory_order_relaxed)) + return; + + // Format the message (only reached when level passes) char buffer[2048]; vsnprintf(buffer, sizeof(buffer), format, args); diff --git a/sdk/runanywhere-commons/src/core/rac_structured_error.cpp b/sdk/runanywhere-commons/src/core/rac_structured_error.cpp index ad79ac6be..67d34c510 100644 --- a/sdk/runanywhere-commons/src/core/rac_structured_error.cpp +++ b/sdk/runanywhere-commons/src/core/rac_structured_error.cpp @@ -124,7 +124,13 @@ void rac_error_set_source(rac_error_t* error, const char* file, int32_t line, if (file) { const char* last_slash = strrchr(file, '/'); const char* last_backslash = strrchr(file, '\\'); - const char* last_sep = (last_slash > last_backslash) ? last_slash : last_backslash; + // Avoid UB: relational comparison of unrelated pointers when one is NULL + const char* last_sep; + if (last_slash && last_backslash) { + last_sep = last_slash > last_backslash ? last_slash : last_backslash; + } else { + last_sep = last_slash ? last_slash : last_backslash; + } const char* filename = last_sep ? last_sep + 1 : file; safe_strcpy(error->source_file, sizeof(error->source_file), filename); } @@ -602,71 +608,119 @@ char* rac_error_to_json(const rac_error_t* error) { if (!error) return nullptr; - // Allocate buffer for JSON - size_t buffer_size = 4096; + // Buffer large enough for all rac_error_t fields at max capacity + constexpr size_t buffer_size = 8192; char* json = static_cast(malloc(buffer_size)); if (!json) return nullptr; int pos = 0; + + // Clamp pos after snprintf to prevent buffer overrun + // (snprintf returns count that WOULD be written, which can exceed available space) + auto clamp = [&]() { + if (pos >= static_cast(buffer_size)) + pos = static_cast(buffer_size) - 1; + }; + + // Write a JSON-escaped string into the buffer + auto write_escaped = [&](const char* str) { + for (const char* p = str; *p != '\0' && pos < static_cast(buffer_size) - 2; p++) { + auto c = static_cast(*p); + if (c == '"' || c == '\\') { + json[pos++] = '\\'; + if (pos < static_cast(buffer_size) - 1) + json[pos++] = static_cast(c); + } else if (c == '\n') { + json[pos++] = '\\'; + if (pos < static_cast(buffer_size) - 1) + json[pos++] = 'n'; + } else if (c == '\r') { + json[pos++] = '\\'; + if (pos < static_cast(buffer_size) - 1) + json[pos++] = 'r'; + } else if (c == '\t') { + json[pos++] = '\\'; + if (pos < static_cast(buffer_size) - 1) + json[pos++] = 't'; + } else if (c < 0x20) { + continue; // Skip other control characters + } else { + json[pos++] = static_cast(c); + } + } + }; + pos += snprintf(json + pos, buffer_size - pos, "{"); + clamp(); // NOLINTNEXTLINE(modernize-raw-string-literal) pos += snprintf(json + pos, buffer_size - pos, "\"code\":%d,", error->code); + clamp(); // NOLINTNEXTLINE(modernize-raw-string-literal) pos += snprintf(json + pos, buffer_size - pos, "\"code_name\":\"%s\",", rac_error_code_name(error->code)); + clamp(); // NOLINTNEXTLINE(modernize-raw-string-literal) pos += snprintf(json + pos, buffer_size - pos, "\"category\":\"%s\",", rac_error_category_name(error->category)); + clamp(); - // Escape message for JSON + // Escape message for JSON (handles ", \, \n, \r, \t, control chars) // NOLINTNEXTLINE(modernize-raw-string-literal) pos += snprintf(json + pos, buffer_size - pos, "\"message\":\""); - for (const char* p = error->message; *p != '\0' && pos < (int)buffer_size - 10; p++) { - if (*p == '"' || *p == '\\') { - json[pos++] = '\\'; - } - json[pos++] = *p; - } + clamp(); + write_escaped(error->message); // NOLINTNEXTLINE(modernize-raw-string-literal) pos += snprintf(json + pos, buffer_size - pos, "\","); + clamp(); // NOLINTNEXTLINE(modernize-raw-string-literal) pos += snprintf(json + pos, buffer_size - pos, "\"timestamp_ms\":%lld,", static_cast(error->timestamp_ms)); + clamp(); // Source location if (error->source_file[0] != '\0') { // NOLINTNEXTLINE(modernize-raw-string-literal) pos += snprintf(json + pos, buffer_size - pos, "\"source_file\":\"%s\",\"source_line\":%d,", error->source_file, error->source_line); + clamp(); } if (error->source_function[0] != '\0') { // NOLINTNEXTLINE(modernize-raw-string-literal) pos += snprintf(json + pos, buffer_size - pos, "\"source_function\":\"%s\",", error->source_function); + clamp(); } // Model context if (error->model_id[0] != '\0') { // NOLINTNEXTLINE(modernize-raw-string-literal) pos += snprintf(json + pos, buffer_size - pos, "\"model_id\":\"%s\",", error->model_id); + clamp(); } if (error->framework[0] != '\0') { // NOLINTNEXTLINE(modernize-raw-string-literal) pos += snprintf(json + pos, buffer_size - pos, "\"framework\":\"%s\",", error->framework); + clamp(); } if (error->session_id[0] != '\0') { // NOLINTNEXTLINE(modernize-raw-string-literal) pos += snprintf(json + pos, buffer_size - pos, "\"session_id\":\"%s\",", error->session_id); + clamp(); } - // Underlying error + // Underlying error — escape the message if (error->underlying_code != 0) { - pos += snprintf( - json + pos, buffer_size - pos, - "\"underlying_code\":%d,\"underlying_message\":\"%s\",", // NOLINT(modernize-raw-string-literal) - error->underlying_code, error->underlying_message); + // NOLINTNEXTLINE(modernize-raw-string-literal) + pos += snprintf(json + pos, buffer_size - pos, + "\"underlying_code\":%d,\"underlying_message\":\"", + error->underlying_code); + clamp(); + write_escaped(error->underlying_message); + // NOLINTNEXTLINE(modernize-raw-string-literal) + pos += snprintf(json + pos, buffer_size - pos, "\","); + clamp(); } // Stack trace @@ -674,6 +728,7 @@ char* rac_error_to_json(const rac_error_t* error) { // NOLINTNEXTLINE(modernize-raw-string-literal) pos += snprintf(json + pos, buffer_size - pos, "\"stack_frame_count\":%d,", error->stack_frame_count); + clamp(); } // Custom metadata @@ -681,22 +736,26 @@ char* rac_error_to_json(const rac_error_t* error) { // NOLINTNEXTLINE(modernize-raw-string-literal) pos += snprintf(json + pos, buffer_size - pos, "\"%s\":\"%s\",", error->custom_key1, error->custom_value1); + clamp(); } if (error->custom_key2[0] != '\0' && error->custom_value2[0] != '\0') { // NOLINTNEXTLINE(modernize-raw-string-literal) pos += snprintf(json + pos, buffer_size - pos, "\"%s\":\"%s\",", error->custom_key2, error->custom_value2); + clamp(); } if (error->custom_key3[0] != '\0' && error->custom_value3[0] != '\0') { // NOLINTNEXTLINE(modernize-raw-string-literal) pos += snprintf(json + pos, buffer_size - pos, "\"%s\":\"%s\",", error->custom_key3, error->custom_value3); + clamp(); } // Remove trailing comma and close - if (json[pos - 1] == ',') + if (pos > 0 && json[pos - 1] == ',') pos--; - json[pos++] = '}'; + if (pos < static_cast(buffer_size) - 1) + json[pos++] = '}'; json[pos] = '\0'; return json; @@ -712,16 +771,31 @@ int32_t rac_error_get_telemetry_properties(const rac_error_t* error, char** out_ // Error code out_keys[count] = strdup("error_code"); out_values[count] = strdup(rac_error_code_name(error->code)); + if (!out_keys[count] || !out_values[count]) { + free(out_keys[count]); + free(out_values[count]); + return count; + } count++; // Error category out_keys[count] = strdup("error_category"); out_values[count] = strdup(rac_error_category_name(error->category)); + if (!out_keys[count] || !out_values[count]) { + free(out_keys[count]); + free(out_values[count]); + return count; + } count++; // Error message out_keys[count] = strdup("error_message"); out_values[count] = strdup(error->message); + if (!out_keys[count] || !out_values[count]) { + free(out_keys[count]); + free(out_values[count]); + return count; + } count++; return count; @@ -746,42 +820,56 @@ char* rac_error_to_debug_string(const rac_error_t* error) { if (!error) return nullptr; - size_t size = 2048; + constexpr size_t size = 2048; char* str = static_cast(malloc(size)); if (!str) return nullptr; int pos = 0; + // Clamp pos after snprintf to prevent buffer overrun + auto clamp = [&]() { + if (pos >= static_cast(size)) + pos = static_cast(size) - 1; + }; + pos += snprintf(str + pos, size - pos, "SDKError[%s.%s]: %s", rac_error_category_name(error->category), rac_error_code_name(error->code), error->message); + clamp(); if (error->underlying_code != 0) { pos += snprintf(str + pos, size - pos, "\n Caused by: %s (%d)", error->underlying_message, error->underlying_code); + clamp(); } if (error->source_file[0] != '\0') { pos += snprintf(str + pos, size - pos, "\n At: %s:%d in %s", error->source_file, error->source_line, error->source_function); + clamp(); } if (error->model_id[0] != '\0') { pos += snprintf(str + pos, size - pos, "\n Model: %s (%s)", error->model_id, error->framework); + clamp(); } if (error->stack_frame_count > 0) { pos += snprintf(str + pos, size - pos, "\n Stack trace (%d frames):", error->stack_frame_count); - for (int i = 0; i < error->stack_frame_count && i < 5 && pos < (int)size - 100; i++) { + clamp(); + for (int i = 0; i < error->stack_frame_count && i < 5 && pos < static_cast(size) - 100; + i++) { if (error->stack_frames[i].function != nullptr) { pos += snprintf( str + pos, size - pos, "\n %s at %s:%d", error->stack_frames[i].function, error->stack_frames[i].file != nullptr ? error->stack_frames[i].file : "?", error->stack_frames[i].line); + clamp(); } else if (error->stack_frames[i].address != nullptr) { pos += snprintf(str + pos, size - pos, "\n %p", error->stack_frames[i].address); + clamp(); } } } diff --git a/sdk/runanywhere-commons/src/features/diffusion/diffusion_component.cpp b/sdk/runanywhere-commons/src/features/diffusion/diffusion_component.cpp index fe11166d3..5e49b8856 100644 --- a/sdk/runanywhere-commons/src/features/diffusion/diffusion_component.cpp +++ b/sdk/runanywhere-commons/src/features/diffusion/diffusion_component.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include "rac/core/capabilities/rac_lifecycle.h" @@ -115,11 +116,10 @@ static rac_diffusion_options_t merge_diffusion_options( * Generate a unique ID for generation tracking. */ static std::string generate_unique_id() { - auto now = std::chrono::high_resolution_clock::now(); - auto epoch = now.time_since_epoch(); - auto ns = std::chrono::duration_cast(epoch).count(); - char buffer[64]; - snprintf(buffer, sizeof(buffer), "diffusion_%lld", static_cast(ns)); + static thread_local std::mt19937 gen(std::random_device{}()); + std::uniform_int_distribution dis; + char buffer[32]; + snprintf(buffer, sizeof(buffer), "diff_%08x%08x", dis(gen), dis(gen)); return std::string(buffer); } @@ -386,29 +386,27 @@ extern "C" rac_result_t rac_diffusion_component_generate(rac_handle_t handle, return RAC_ERROR_INVALID_ARGUMENT; auto* component = reinterpret_cast(handle); - std::lock_guard lock(component->mtx); - - // Reset cancellation flag - component->cancel_requested = false; - - // Generate unique ID for this generation - std::string generation_id = generate_unique_id(); - // Get model ID and name from lifecycle manager - const char* model_id = rac_lifecycle_get_model_id(component->lifecycle); - const char* model_name = rac_lifecycle_get_model_name(component->lifecycle); - - // Get service from lifecycle manager + // Acquire lock only for state reads; release before long-running generation rac_handle_t service = nullptr; - rac_result_t result = rac_lifecycle_require_service(component->lifecycle, &service); - if (result != RAC_SUCCESS) { - RAC_LOG_ERROR("Diffusion.Component", "No model loaded - cannot generate"); - return result; - } + rac_diffusion_options_t effective_options; + { + std::lock_guard lock(component->mtx); + + // Reset cancellation flag (also atomic, but set under lock for consistency) + component->cancel_requested = false; + + // Pin service via acquire to prevent unload during generation + rac_result_t result = rac_lifecycle_acquire_service(component->lifecycle, &service); + if (result != RAC_SUCCESS) { + RAC_LOG_ERROR("Diffusion.Component", "No model loaded - cannot generate"); + return result; + } - // Merge user options over component defaults - rac_diffusion_options_t effective_options = merge_diffusion_options( - component->default_options, options); + // Merge user options over component defaults + effective_options = merge_diffusion_options(component->default_options, options); + } + // Lock released — safe to do long-running generation RAC_LOG_INFO("Diffusion.Component", "Starting generation: %dx%d, %d steps, guidance=%.1f, scheduler=%d", @@ -417,8 +415,11 @@ extern "C" rac_result_t rac_diffusion_component_generate(rac_handle_t handle, auto start_time = std::chrono::steady_clock::now(); - // Perform generation - result = rac_diffusion_generate(service, &effective_options, out_result); + // Perform generation outside lock + rac_result_t result = rac_diffusion_generate(service, &effective_options, out_result); + + // Release pinned service in all exit paths + rac_lifecycle_release_service(component->lifecycle); if (result != RAC_SUCCESS) { RAC_LOG_ERROR("Diffusion.Component", "Generation failed: %d", result); @@ -483,25 +484,30 @@ extern "C" rac_result_t rac_diffusion_component_generate_with_callbacks( return RAC_ERROR_INVALID_ARGUMENT; auto* component = reinterpret_cast(handle); - std::lock_guard lock(component->mtx); - - // Reset cancellation flag - component->cancel_requested = false; - // Get service from lifecycle manager + // Acquire lock only for state reads; release before long-running generation rac_handle_t service = nullptr; - rac_result_t result = rac_lifecycle_require_service(component->lifecycle, &service); - if (result != RAC_SUCCESS) { - RAC_LOG_ERROR("Diffusion.Component", "No model loaded - cannot generate"); - if (error_callback) { - error_callback(result, "No model loaded", user_data); + rac_diffusion_options_t effective_options; + { + std::lock_guard lock(component->mtx); + + // Reset cancellation flag + component->cancel_requested = false; + + // Pin service via acquire to prevent unload during generation + rac_result_t result = rac_lifecycle_acquire_service(component->lifecycle, &service); + if (result != RAC_SUCCESS) { + RAC_LOG_ERROR("Diffusion.Component", "No model loaded - cannot generate"); + if (error_callback) { + error_callback(result, "No model loaded", user_data); + } + return result; } - return result; - } - // Merge user options over component defaults - rac_diffusion_options_t effective_options = merge_diffusion_options( - component->default_options, options); + // Merge user options over component defaults + effective_options = merge_diffusion_options(component->default_options, options); + } + // Lock released — safe to do long-running generation RAC_LOG_INFO("Diffusion.Component", "Starting generation with callbacks: %dx%d, %d steps, stride=%d", @@ -518,10 +524,13 @@ extern "C" rac_result_t rac_diffusion_component_generate_with_callbacks( ctx.start_time = std::chrono::steady_clock::now(); ctx.generation_id = generate_unique_id(); - // Perform generation with progress + // Perform generation with progress (outside lock) rac_diffusion_result_t gen_result = {}; - result = rac_diffusion_generate_with_progress(service, &effective_options, - diffusion_progress_wrapper, &ctx, &gen_result); + rac_result_t result = rac_diffusion_generate_with_progress(service, &effective_options, + diffusion_progress_wrapper, &ctx, &gen_result); + + // Release pinned service in all exit paths + rac_lifecycle_release_service(component->lifecycle); if (result != RAC_SUCCESS) { RAC_LOG_ERROR("Diffusion.Component", "Generation failed: %d", result); diff --git a/sdk/runanywhere-commons/src/features/llm/llm_analytics.cpp b/sdk/runanywhere-commons/src/features/llm/llm_analytics.cpp index 7a4d84a0a..ab27bde16 100644 --- a/sdk/runanywhere-commons/src/features/llm/llm_analytics.cpp +++ b/sdk/runanywhere-commons/src/features/llm/llm_analytics.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -46,9 +47,8 @@ int64_t get_current_time_ms() { } std::string generate_uuid() { - static std::random_device rd; - static std::mt19937 gen(rd()); - static std::uniform_int_distribution<> dis(0, 15); + static thread_local std::mt19937 gen(std::random_device{}()); + static thread_local std::uniform_int_distribution<> dis(0, 15); std::stringstream ss; ss << std::hex; @@ -120,13 +120,12 @@ rac_result_t rac_llm_analytics_create(rac_llm_analytics_handle_t* out_handle) { return RAC_ERROR_INVALID_PARAMETER; } - try { - *out_handle = new rac_llm_analytics_s(); - log_info("LLM.Analytics", "LLM analytics service created"); - return RAC_SUCCESS; - } catch (...) { + *out_handle = new (std::nothrow) rac_llm_analytics_s(); + if (!*out_handle) { return RAC_ERROR_OUT_OF_MEMORY; } + log_info("LLM.Analytics", "LLM analytics service created"); + return RAC_SUCCESS; } void rac_llm_analytics_destroy(rac_llm_analytics_handle_t handle) { @@ -171,7 +170,7 @@ rac_result_t rac_llm_analytics_start_generation(rac_llm_analytics_handle_t handl if (!*out_generation_id) { return RAC_ERROR_OUT_OF_MEMORY; } - strcpy(*out_generation_id, id.c_str()); + memcpy(*out_generation_id, id.c_str(), id.size() + 1); log_debug("LLM.Analytics", "Non-streaming generation started: %s", id.c_str()); return RAC_SUCCESS; @@ -210,7 +209,7 @@ rac_result_t rac_llm_analytics_start_streaming_generation( if (!*out_generation_id) { return RAC_ERROR_OUT_OF_MEMORY; } - strcpy(*out_generation_id, id.c_str()); + memcpy(*out_generation_id, id.c_str(), id.size() + 1); log_debug("LLM.Analytics", "Streaming generation started: %s", id.c_str()); return RAC_SUCCESS; diff --git a/sdk/runanywhere-commons/src/features/llm/llm_component.cpp b/sdk/runanywhere-commons/src/features/llm/llm_component.cpp index d09f739ea..1c8a4da4d 100644 --- a/sdk/runanywhere-commons/src/features/llm/llm_component.cpp +++ b/sdk/runanywhere-commons/src/features/llm/llm_component.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include "rac/core/capabilities/rac_lifecycle.h" @@ -76,11 +77,10 @@ static int32_t estimate_tokens(const char* text) { * Generate a unique ID for generation tracking. */ static std::string generate_unique_id() { - auto now = std::chrono::high_resolution_clock::now(); - auto epoch = now.time_since_epoch(); - auto ns = std::chrono::duration_cast(epoch).count(); - char buffer[64]; - snprintf(buffer, sizeof(buffer), "gen_%lld", static_cast(ns)); + static thread_local std::mt19937 gen(std::random_device{}()); + std::uniform_int_distribution dis; + char buffer[32]; + snprintf(buffer, sizeof(buffer), "gen_%08x%08x", dis(gen), dis(gen)); return std::string(buffer); } @@ -223,9 +223,14 @@ extern "C" void rac_llm_component_destroy(rac_handle_t handle) { auto* component = reinterpret_cast(handle); - // Destroy lifecycle manager (will cleanup service if loaded) - if (component->lifecycle) { - rac_lifecycle_destroy(component->lifecycle); + // Acquire component mutex to serialize against in-flight operations. + // lifecycle_destroy -> unload will block until any acquired services are released. + { + std::lock_guard lock(component->mtx); + if (component->lifecycle) { + rac_lifecycle_destroy(component->lifecycle); + component->lifecycle = nullptr; + } } log_info("LLM.Component", "LLM component destroyed"); @@ -667,6 +672,8 @@ extern "C" rac_result_t rac_llm_component_generate_stream( ctx.temperature = effective_options->temperature; ctx.max_tokens = effective_options->max_tokens; ctx.token_count = 0; + // Pre-allocate to avoid repeated reallocations during streaming + ctx.full_text.reserve(2048); // Perform streaming generation result = rac_llm_generate_stream(service, prompt, effective_options, llm_stream_token_callback, @@ -701,6 +708,13 @@ extern "C" rac_result_t rac_llm_component_generate_stream( rac_llm_result_t final_result = {}; final_result.text = strdup(ctx.full_text.c_str()); + if (!final_result.text) { + log_error("LLM.Component", "Failed to allocate result text"); + if (error_callback) { + error_callback(RAC_ERROR_OUT_OF_MEMORY, "Failed to allocate result text", user_data); + } + return RAC_ERROR_OUT_OF_MEMORY; + } final_result.prompt_tokens = ctx.prompt_tokens; final_result.completion_tokens = estimate_tokens(ctx.full_text.c_str()); final_result.total_tokens = final_result.prompt_tokens + final_result.completion_tokens; @@ -761,11 +775,15 @@ extern "C" rac_result_t rac_llm_component_cancel(rac_handle_t handle) { return RAC_ERROR_INVALID_HANDLE; auto* component = reinterpret_cast(handle); - std::lock_guard lock(component->mtx); - rac_handle_t service = rac_lifecycle_get_service(component->lifecycle); - if (service) { + // Use acquire/release to pin the service for the duration of the cancel call, + // preventing use-after-free if destroy races with cancel. + // Do NOT acquire component->mtx — generate_stream() holds it during streaming. + rac_handle_t service = nullptr; + rac_result_t acq = rac_lifecycle_acquire_service(component->lifecycle, &service); + if (acq == RAC_SUCCESS && service) { rac_llm_cancel(service); + rac_lifecycle_release_service(component->lifecycle); } log_info("LLM.Component", "Generation cancellation requested"); diff --git a/sdk/runanywhere-commons/src/features/llm/streaming_metrics.cpp b/sdk/runanywhere-commons/src/features/llm/streaming_metrics.cpp index 22a7b6d59..1f3dbd5bf 100644 --- a/sdk/runanywhere-commons/src/features/llm/streaming_metrics.cpp +++ b/sdk/runanywhere-commons/src/features/llm/streaming_metrics.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include "rac/core/rac_logger.h" @@ -105,7 +106,10 @@ rac_result_t rac_streaming_metrics_create(const char* model_id, const char* gene return RAC_ERROR_INVALID_ARGUMENT; } - rac_streaming_metrics_collector* collector = new rac_streaming_metrics_collector(); + rac_streaming_metrics_collector* collector = new (std::nothrow) rac_streaming_metrics_collector(); + if (!collector) { + return RAC_ERROR_OUT_OF_MEMORY; + } collector->model_id = model_id; collector->generation_id = generation_id; collector->prompt_length = prompt_length; @@ -296,7 +300,10 @@ rac_result_t rac_generation_analytics_create(rac_generation_analytics_handle_t* return RAC_ERROR_INVALID_ARGUMENT; } - rac_generation_analytics* service = new rac_generation_analytics(); + rac_generation_analytics* service = new (std::nothrow) rac_generation_analytics(); + if (!service) { + return RAC_ERROR_OUT_OF_MEMORY; + } RAC_LOG_INFO("GenerationAnalytics", "Service created"); diff --git a/sdk/runanywhere-commons/src/features/llm/tool_calling.cpp b/sdk/runanywhere-commons/src/features/llm/tool_calling.cpp index 4212d4abb..ac8ef8825 100644 --- a/sdk/runanywhere-commons/src/features/llm/tool_calling.cpp +++ b/sdk/runanywhere-commons/src/features/llm/tool_calling.cpp @@ -285,9 +285,10 @@ static bool extract_json_string(const char* str, size_t pos, size_t len, char** if (ch == '"') { // End of string *out_value = static_cast(malloc(result.size() + 1)); - if (*out_value) { - memcpy(*out_value, result.c_str(), result.size() + 1); + if (!*out_value) { + return false; } + memcpy(*out_value, result.c_str(), result.size() + 1); *out_end_pos = i + 1; return true; } diff --git a/sdk/runanywhere-commons/src/features/platform/rac_backend_platform_register.cpp b/sdk/runanywhere-commons/src/features/platform/rac_backend_platform_register.cpp index 0e999bda9..cb93eae3f 100644 --- a/sdk/runanywhere-commons/src/features/platform/rac_backend_platform_register.cpp +++ b/sdk/runanywhere-commons/src/features/platform/rac_backend_platform_register.cpp @@ -799,11 +799,24 @@ void register_coreml_diffusion_entry() { rac_model_info_t model = {}; model.id = strdup("coreml-diffusion"); model.name = strdup("CoreML Diffusion"); + model.local_path = strdup("builtin://coreml-diffusion"); + model.description = strdup( + "Platform's Stable Diffusion implementation using Core ML. " + "Provides text-to-image, image-to-image, and inpainting capabilities."); + + if (!model.id || !model.name || !model.local_path || !model.description) { + RAC_LOG_ERROR(LOG_CAT, "OOM registering coreml-diffusion model"); + free(model.id); + free(model.name); + free(model.local_path); + free(model.description); + return; + } + model.category = RAC_MODEL_CATEGORY_IMAGE_GENERATION; model.format = RAC_MODEL_FORMAT_COREML; model.framework = RAC_FRAMEWORK_COREML; model.download_url = nullptr; - model.local_path = strdup("builtin://coreml-diffusion"); model.artifact_info.kind = RAC_ARTIFACT_KIND_BUILT_IN; model.download_size = 0; model.memory_required = 4000000000; // ~4GB for SD 1.5 @@ -811,9 +824,6 @@ void register_coreml_diffusion_entry() { model.supports_thinking = RAC_FALSE; model.tags = nullptr; model.tag_count = 0; - model.description = strdup( - "Platform's Stable Diffusion implementation using Core ML. " - "Provides text-to-image, image-to-image, and inpainting capabilities."); model.source = RAC_MODEL_SOURCE_LOCAL; rac_result_t result = rac_model_registry_save(registry, &model); @@ -837,11 +847,24 @@ void register_foundation_models_entry() { rac_model_info_t model = {}; model.id = strdup("foundation-models-default"); model.name = strdup("Platform LLM"); + model.local_path = strdup("builtin://foundation-models"); + model.description = strdup( + "Platform's built-in language model. " + "Uses the device's native AI capabilities when available."); + + if (!model.id || !model.name || !model.local_path || !model.description) { + RAC_LOG_ERROR(LOG_CAT, "OOM registering foundation-models-default model"); + free(model.id); + free(model.name); + free(model.local_path); + free(model.description); + return; + } + model.category = RAC_MODEL_CATEGORY_LANGUAGE; model.format = RAC_MODEL_FORMAT_UNKNOWN; model.framework = RAC_FRAMEWORK_FOUNDATION_MODELS; model.download_url = nullptr; - model.local_path = strdup("builtin://foundation-models"); model.artifact_info.kind = RAC_ARTIFACT_KIND_BUILT_IN; model.download_size = 0; model.memory_required = 0; @@ -849,9 +872,6 @@ void register_foundation_models_entry() { model.supports_thinking = RAC_FALSE; model.tags = nullptr; model.tag_count = 0; - model.description = strdup( - "Platform's built-in language model. " - "Uses the device's native AI capabilities when available."); model.source = RAC_MODEL_SOURCE_LOCAL; rac_result_t result = rac_model_registry_save(registry, &model); @@ -874,11 +894,22 @@ void register_system_tts_entry() { rac_model_info_t model = {}; model.id = strdup("system-tts"); model.name = strdup("Platform TTS"); + model.local_path = strdup("builtin://system-tts"); + model.description = strdup("Platform's built-in Text-to-Speech using native synthesis."); + + if (!model.id || !model.name || !model.local_path || !model.description) { + RAC_LOG_ERROR(LOG_CAT, "OOM registering system-tts model"); + free(model.id); + free(model.name); + free(model.local_path); + free(model.description); + return; + } + model.category = RAC_MODEL_CATEGORY_SPEECH_SYNTHESIS; model.format = RAC_MODEL_FORMAT_UNKNOWN; model.framework = RAC_FRAMEWORK_SYSTEM_TTS; model.download_url = nullptr; - model.local_path = strdup("builtin://system-tts"); model.artifact_info.kind = RAC_ARTIFACT_KIND_BUILT_IN; model.download_size = 0; model.memory_required = 0; @@ -886,7 +917,6 @@ void register_system_tts_entry() { model.supports_thinking = RAC_FALSE; model.tags = nullptr; model.tag_count = 0; - model.description = strdup("Platform's built-in Text-to-Speech using native synthesis."); model.source = RAC_MODEL_SOURCE_LOCAL; rac_result_t result = rac_model_registry_save(registry, &model); diff --git a/sdk/runanywhere-commons/src/features/rag/CMakeLists.txt b/sdk/runanywhere-commons/src/features/rag/CMakeLists.txt index 60d4fc10e..1e943c26a 100644 --- a/sdk/runanywhere-commons/src/features/rag/CMakeLists.txt +++ b/sdk/runanywhere-commons/src/features/rag/CMakeLists.txt @@ -95,7 +95,7 @@ target_link_libraries(rac_backend_rag PUBLIC target_compile_definitions(rac_backend_rag PRIVATE RAC_RAG_BUILDING) set_target_properties(rac_backend_rag PROPERTIES - CXX_STANDARD 17 + CXX_STANDARD 20 CXX_STANDARD_REQUIRED ON CXX_EXTENSIONS OFF ) diff --git a/sdk/runanywhere-commons/src/features/rag/jni/rac_rag_jni.cpp b/sdk/runanywhere-commons/src/features/rag/jni/rac_rag_jni.cpp index a19c035ee..2fe9124f2 100644 --- a/sdk/runanywhere-commons/src/features/rag/jni/rac_rag_jni.cpp +++ b/sdk/runanywhere-commons/src/features/rag/jni/rac_rag_jni.cpp @@ -13,20 +13,15 @@ #include #include -#ifdef __ANDROID__ -#include -#define TAG "RACRagJNI" -#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__) -#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__) -#define LOGw(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__) -#else -#define LOGi(...) fprintf(stdout, "[INFO] " __VA_ARGS__); fprintf(stdout, "\n") -#define LOGe(...) fprintf(stderr, "[ERROR] " __VA_ARGS__); fprintf(stderr, "\n") -#define LOGw(...) fprintf(stdout, "[WARN] " __VA_ARGS__); fprintf(stdout, "\n") -#endif - #include "rac/core/rac_core.h" #include "rac/core/rac_error.h" +#include "rac/core/rac_logger.h" + +// Route JNI logging through unified RAC_LOG_* system +static const char* LOG_TAG = "JNI.RAG"; +#define LOGi(...) RAC_LOG_INFO(LOG_TAG, __VA_ARGS__) +#define LOGe(...) RAC_LOG_ERROR(LOG_TAG, __VA_ARGS__) +#define LOGw(...) RAC_LOG_WARNING(LOG_TAG, __VA_ARGS__) #include "rac/features/rag/rac_rag_pipeline.h" // Forward declarations diff --git a/sdk/runanywhere-commons/src/features/rag/onnx_embedding_provider.cpp b/sdk/runanywhere-commons/src/features/rag/onnx_embedding_provider.cpp index a57b3e10a..85ac2475c 100644 --- a/sdk/runanywhere-commons/src/features/rag/onnx_embedding_provider.cpp +++ b/sdk/runanywhere-commons/src/features/rag/onnx_embedding_provider.cpp @@ -121,8 +121,9 @@ class SimpleTokenizer { std::vector create_attention_mask(const std::vector& token_ids) { std::vector mask; + mask.reserve(token_ids.size()); for (auto id : token_ids) { - mask.push_back(id != 0 ? 1 : 0); // 1 for real tokens, 0 for padding + mask.push_back(id != pad_id_ ? 1 : 0); } return mask; } @@ -479,13 +480,13 @@ class ONNXEmbeddingProvider::Impl { try { config_ = nlohmann::json::parse(config_json); } catch (const std::exception& e) { - LOGE("Failed to parse config JSON: %s", e.what()); + RAC_LOG_ERROR(LOG_TAG,"Failed to parse config JSON: %s", e.what()); } } // Initialize ONNX Runtime if (!initialize_onnx_runtime()) { - LOGE("Failed to initialize ONNX Runtime"); + RAC_LOG_ERROR(LOG_TAG,"Failed to initialize ONNX Runtime"); return; } @@ -542,14 +543,15 @@ class ONNXEmbeddingProvider::Impl { vocab_path.c_str(), search_dir.string().c_str()); return; } + } if (!tokenizer_.load_vocab(vocab_path)) { - LOGE("Failed to load tokenizer vocab: %s", vocab_path.c_str()); + RAC_LOG_ERROR(LOG_TAG,"Failed to load tokenizer vocab: %s", vocab_path.c_str()); return; } - LOGI("Loaded tokenizer vocab: %s", vocab_path.c_str()); + RAC_LOG_INFO(LOG_TAG,"Loaded tokenizer vocab: %s", vocab_path.c_str()); std::string resolved_model_path = model_path; if (std::filesystem::is_directory(resolved_model_path)) { @@ -559,12 +561,13 @@ class ONNXEmbeddingProvider::Impl { // Load model if (!load_model(resolved_model_path)) { LOGE("Failed to load model: %s", resolved_model_path.c_str()); + return; } ready_ = true; - LOGI("ONNX embedding provider initialized: %s", model_path.c_str()); - LOGI(" Hidden dimension: %zu", embedding_dim_); + RAC_LOG_INFO(LOG_TAG,"ONNX embedding provider initialized: %s", model_path.c_str()); + RAC_LOG_INFO(LOG_TAG," Hidden dimension: %zu", embedding_dim_); } ~Impl() { @@ -575,6 +578,7 @@ class ONNXEmbeddingProvider::Impl { if (!ready_) { LOGE("Embedding provider not ready"); return {}; + } std::lock_guard lock(embed_mutex_); @@ -612,8 +616,10 @@ class ONNXEmbeddingProvider::Impl { input_ids_guard.ptr() )); if (status_guard.is_error()) { + LOGE("CreateTensorWithDataAsOrtValue (input_ids) failed: %s", status_guard.error_message()); return {}; + } // Create attention_mask tensor @@ -627,8 +633,10 @@ class ONNXEmbeddingProvider::Impl { attention_mask_guard.ptr() )); if (status_guard.is_error()) { + LOGE("CreateTensorWithDataAsOrtValue (attention_mask) failed: %s", status_guard.error_message()); return {}; + } // Create token_type_ids tensor @@ -642,8 +650,10 @@ class ONNXEmbeddingProvider::Impl { token_type_ids_guard.ptr() )); if (status_guard.is_error()) { + LOGE("CreateTensorWithDataAsOrtValue (token_type_ids) failed: %s", status_guard.error_message()); return {}; + } // 3. Run inference @@ -665,8 +675,10 @@ class ONNXEmbeddingProvider::Impl { )); if (status_guard.is_error()) { + LOGE("ONNX inference failed: %s", status_guard.error_message()); return {}; + } // Transfer ownership to guard for automatic cleanup @@ -678,13 +690,17 @@ class ONNXEmbeddingProvider::Impl { output_status_guard.reset(ort_api_->GetTensorMutableData(output_guard.get(), (void**)&output_data)); if (output_status_guard.is_error()) { + LOGE("Failed to get output tensor data: %s", output_status_guard.error_message()); return {}; + } if (output_data == nullptr) { + LOGE("Output tensor data pointer is null"); return {}; + } OrtTensorTypeAndShapeInfo* shape_info = nullptr; @@ -700,7 +716,7 @@ class ONNXEmbeddingProvider::Impl { ort_api_->GetDimensions(shape_info, dims.data(), dim_count); actual_hidden_dim = static_cast(dims[2]); if (actual_hidden_dim != embedding_dim_) { - LOGI("Model hidden dim %zu differs from configured %zu, using actual", + RAC_LOG_INFO(LOG_TAG,"Model hidden dim %zu differs from configured %zu, using actual", actual_hidden_dim, embedding_dim_); embedding_dim_ = actual_hidden_dim; } @@ -719,12 +735,13 @@ class ONNXEmbeddingProvider::Impl { normalize_vector(pooled); // All resources automatically cleaned up by RAII guards - LOGI("Generated embedding: dim=%zu, norm=1.0", pooled.size()); + RAC_LOG_INFO(LOG_TAG,"Generated embedding: dim=%zu, norm=1.0", pooled.size()); return pooled; } catch (const std::exception& e) { LOGE("Embedding generation failed: %s", e.what()); return {}; + } } @@ -944,7 +961,7 @@ class ONNXEmbeddingProvider::Impl { const char* ort_version = ort_api_base ? ort_api_base->GetVersionString() : "unknown"; ort_api_ = ort_api_base ? ort_api_base->GetApi(ORT_API_VERSION) : nullptr; if (!ort_api_) { - LOGE("Failed to get ONNX Runtime API (ORT_API_VERSION=%d, runtime=%s)", ORT_API_VERSION, ort_version); + RAC_LOG_ERROR(LOG_TAG,"Failed to get ONNX Runtime API (ORT_API_VERSION=%d, runtime=%s)", ORT_API_VERSION, ort_version); return false; } @@ -952,7 +969,7 @@ class ONNXEmbeddingProvider::Impl { OrtStatus* status = ort_api_->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "RAGEmbedding", &ort_env_); if (status != nullptr) { const char* error_msg = ort_api_->GetErrorMessage(status); - LOGE("Failed to create ORT environment: %s", error_msg); + RAC_LOG_ERROR(LOG_TAG,"Failed to create ORT environment: %s", error_msg); ort_api_->ReleaseStatus(status); return false; } @@ -967,25 +984,25 @@ class ONNXEmbeddingProvider::Impl { status_guard.reset(ort_api_->CreateSessionOptions(options_guard.ptr())); if (status_guard.is_error()) { - LOGE("Failed to create session options: %s", status_guard.error_message()); + RAC_LOG_ERROR(LOG_TAG,"Failed to create session options: %s", status_guard.error_message()); return false; } if (options_guard.get() == nullptr) { - LOGE("Session options is null after creation"); + RAC_LOG_ERROR(LOG_TAG,"Session options is null after creation"); return false; } // Configure session options with error checking status_guard.reset(ort_api_->SetIntraOpNumThreads(options_guard.get(), 4)); if (status_guard.is_error()) { - LOGE("Failed to set intra-op threads: %s", status_guard.error_message()); + RAC_LOG_ERROR(LOG_TAG,"Failed to set intra-op threads: %s", status_guard.error_message()); return false; } status_guard.reset(ort_api_->SetSessionGraphOptimizationLevel(options_guard.get(), ORT_ENABLE_ALL)); if (status_guard.is_error()) { - LOGE("Failed to set graph optimization level: %s", status_guard.error_message()); + RAC_LOG_ERROR(LOG_TAG,"Failed to set graph optimization level: %s", status_guard.error_message()); return false; } @@ -999,7 +1016,7 @@ class ONNXEmbeddingProvider::Impl { // options_guard automatically releases session options on scope exit if (status_guard.is_error()) { - LOGE("Failed to load model: %s", status_guard.error_message()); + RAC_LOG_ERROR(LOG_TAG,"Failed to load model: %s", status_guard.error_message()); return false; } diff --git a/sdk/runanywhere-commons/src/features/rag/vector_store_usearch.cpp b/sdk/runanywhere-commons/src/features/rag/vector_store_usearch.cpp index 7840c68f9..70d6a0358 100644 --- a/sdk/runanywhere-commons/src/features/rag/vector_store_usearch.cpp +++ b/sdk/runanywhere-commons/src/features/rag/vector_store_usearch.cpp @@ -38,6 +38,7 @@ #define LOGW(...) RAC_LOG_WARNING(LOG_TAG, __VA_ARGS__) #define LOGE(...) RAC_LOG_ERROR(LOG_TAG, __VA_ARGS__) + namespace runanywhere { namespace rag { @@ -66,7 +67,7 @@ class VectorStoreUSearch::Impl { // Create index auto result = index_dense_t::make(metric, usearch_config); if (!result) { - LOGE("Failed to create USearch index: %s", result.error.what()); + RAC_LOG_ERROR(LOG_TAG,"Failed to create USearch index: %s", result.error.what()); throw std::runtime_error("Failed to create USearch index"); } index_ = std::move(result.index); @@ -81,14 +82,14 @@ class VectorStoreUSearch::Impl { std::lock_guard lock(mutex_); if (chunk.embedding.size() != config_.dimension) { - LOGE("Invalid embedding dimension: %zu (expected %zu)", + RAC_LOG_ERROR(LOG_TAG,"Invalid embedding dimension: %zu (expected %zu)", chunk.embedding.size(), config_.dimension); return false; } // Check for duplicate ID if (id_to_key_.find(chunk.id) != id_to_key_.end()) { - LOGE("Duplicate chunk ID: %s", chunk.id.c_str()); + RAC_LOG_ERROR(LOG_TAG,"Duplicate chunk ID: %s", chunk.id.c_str()); return false; } @@ -98,7 +99,7 @@ class VectorStoreUSearch::Impl { // Add to USearch index auto add_result = index_.add(key, chunk.embedding.data()); if (!add_result) { - LOGE("Failed to add chunk to index: %s", add_result.error.what()); + RAC_LOG_ERROR(LOG_TAG,"Failed to add chunk to index: %s", add_result.error.what()); return false; } @@ -118,13 +119,13 @@ class VectorStoreUSearch::Impl { for (const auto& chunk : chunks) { if (chunk.embedding.size() != config_.dimension) { - LOGE("Invalid embedding dimension in batch"); + RAC_LOG_ERROR(LOG_TAG,"Invalid embedding dimension in batch"); continue; } // Check for duplicate ID if (id_to_key_.find(chunk.id) != id_to_key_.end()) { - LOGE("Duplicate chunk ID in batch: %s", chunk.id.c_str()); + RAC_LOG_ERROR(LOG_TAG,"Duplicate chunk ID in batch: %s", chunk.id.c_str()); continue; } @@ -132,7 +133,7 @@ class VectorStoreUSearch::Impl { std::size_t key = next_key_++; auto add_result = index_.add(key, chunk.embedding.data()); if (!add_result) { - LOGE("Failed to add chunk to batch: %s", add_result.error.what()); + RAC_LOG_ERROR(LOG_TAG,"Failed to add chunk to batch: %s", add_result.error.what()); continue; } // Store metadata @@ -155,7 +156,7 @@ class VectorStoreUSearch::Impl { std::lock_guard lock(mutex_); if (query_embedding.size() != config_.dimension) { - LOGE("Invalid query embedding dimension"); + RAC_LOG_ERROR(LOG_TAG,"Invalid query embedding dimension"); return {}; } @@ -166,7 +167,7 @@ class VectorStoreUSearch::Impl { // Search for the closest K matches auto matches = index_.search(query_embedding.data(), top_k); - LOGI("USearch returned %zu matches from %zu total vectors", + RAC_LOG_INFO(LOG_TAG,"USearch returned %zu matches from %zu total vectors", matches.size(), index_.size()); float effective_threshold = threshold; @@ -185,18 +186,13 @@ class VectorStoreUSearch::Impl { // USearch cosine distance is 1 - cosine_similarity float similarity = 1.0f - distance; - LOGI("Match %zu: key=%zu, distance=%.4f, similarity=%.4f, effective_threshold=%.4f", - i, key, distance, similarity, effective_threshold); - - // Use our capped threshold for filtering if (similarity < effective_threshold) { - LOGI(" Skipping: similarity %.4f < effective_threshold %.4f", similarity, effective_threshold); continue; } auto it = chunks_.find(key); if (it == chunks_.end()) { - LOGE("Chunk key %zu not found in metadata map", key); + RAC_LOG_ERROR(LOG_TAG,"Chunk key %zu not found in metadata map", key); continue; } @@ -240,7 +236,7 @@ class VectorStoreUSearch::Impl { std::size_t key = it->second; auto remove_result = index_.remove(key); if (!remove_result) { - LOGE("Failed to remove chunk from index: %s", remove_result.error.what()); + RAC_LOG_ERROR(LOG_TAG,"Failed to remove chunk from index: %s", remove_result.error.what()); return false; } chunks_.erase(key); @@ -255,7 +251,7 @@ class VectorStoreUSearch::Impl { chunks_.clear(); id_to_key_.clear(); next_key_ = 0; // Reset counter - LOGI("Cleared vector store"); + RAC_LOG_INFO(LOG_TAG,"Cleared vector store"); } size_t size() const { @@ -281,13 +277,13 @@ class VectorStoreUSearch::Impl { return stats; } - bool save(const std::string& path) { + bool save(const std::string& path) const { std::lock_guard lock(mutex_); // Save USearch index auto save_result = index_.save(path.c_str()); if (!save_result) { - LOGE("Failed to save USearch index: %s", save_result.error.what()); + RAC_LOG_ERROR(LOG_TAG,"Failed to save USearch index: %s", save_result.error.what()); return false; } @@ -308,13 +304,13 @@ class VectorStoreUSearch::Impl { std::string metadata_path = path + ".metadata.json"; std::ofstream metadata_file(metadata_path); if (!metadata_file) { - LOGE("Failed to open metadata file: %s", metadata_path.c_str()); + RAC_LOG_ERROR(LOG_TAG,"Failed to open metadata file: %s", metadata_path.c_str()); return false; } metadata_file << metadata.dump(); metadata_file.close(); - LOGI("Saved index and metadata to %s", path.c_str()); + RAC_LOG_INFO(LOG_TAG,"Saved index and metadata to %s", path.c_str()); return true; } @@ -324,7 +320,7 @@ class VectorStoreUSearch::Impl { // Load USearch index auto load_result = index_.load(path.c_str()); if (!load_result) { - LOGE("Failed to load USearch index: %s", load_result.error.what()); + RAC_LOG_ERROR(LOG_TAG,"Failed to load USearch index: %s", load_result.error.what()); return false; } @@ -332,7 +328,7 @@ class VectorStoreUSearch::Impl { std::string metadata_path = path + ".metadata.json"; std::ifstream metadata_file(metadata_path); if (!metadata_file) { - LOGE("Failed to open metadata file: %s", metadata_path.c_str()); + RAC_LOG_ERROR(LOG_TAG,"Failed to open metadata file: %s", metadata_path.c_str()); return false; } @@ -357,19 +353,21 @@ class VectorStoreUSearch::Impl { } chunk.metadata = chunk_json.at("metadata"); + std::string chunk_id = chunk.id; new_chunks[key] = std::move(chunk); - new_id_to_key[new_chunks[key].id] = key; + new_id_to_key[chunk_id] = key; } next_key_ = parsed_next_key; chunks_ = std::move(new_chunks); id_to_key_ = std::move(new_id_to_key); } catch (const std::exception& e) { - LOGE("Failed to parse metadata JSON: %s", e.what()); + RAC_LOG_ERROR(LOG_TAG,"Failed to parse metadata JSON: %s", e.what()); + index_.clear(); // Revert to consistent empty state return false; } - LOGI("Loaded index and metadata from %s (next_key=%zu, chunks=%zu)", + RAC_LOG_INFO(LOG_TAG,"Loaded index and metadata from %s (next_key=%zu, chunks=%zu)", path.c_str(), next_key_, chunks_.size()); return true; } @@ -409,10 +407,10 @@ std::vector VectorStoreUSearch::search( try { return impl_->search(query_embedding, top_k, threshold); } catch (const std::exception& e) { - LOGE("search() exception: %s", e.what()); + RAC_LOG_ERROR(LOG_TAG,"search() exception: %s", e.what()); return {}; } catch (...) { - LOGE("search() unknown exception"); + RAC_LOG_ERROR(LOG_TAG,"search() unknown exception"); return {}; } } diff --git a/sdk/runanywhere-commons/src/features/stt/stt_analytics.cpp b/sdk/runanywhere-commons/src/features/stt/stt_analytics.cpp index ec377061d..61319a28b 100644 --- a/sdk/runanywhere-commons/src/features/stt/stt_analytics.cpp +++ b/sdk/runanywhere-commons/src/features/stt/stt_analytics.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -42,9 +43,8 @@ int64_t get_current_time_ms() { } std::string generate_uuid() { - static std::random_device rd; - static std::mt19937 gen(rd()); - static std::uniform_int_distribution<> dis(0, 15); + static thread_local std::mt19937 gen(std::random_device{}()); + static thread_local std::uniform_int_distribution<> dis(0, 15); std::stringstream ss; ss << std::hex; @@ -110,13 +110,12 @@ rac_result_t rac_stt_analytics_create(rac_stt_analytics_handle_t* out_handle) { return RAC_ERROR_INVALID_PARAMETER; } - try { - *out_handle = new rac_stt_analytics_s(); - log_info("STT.Analytics", "STT analytics service created"); - return RAC_SUCCESS; - } catch (...) { + *out_handle = new (std::nothrow) rac_stt_analytics_s(); + if (!*out_handle) { return RAC_ERROR_OUT_OF_MEMORY; } + log_info("STT.Analytics", "STT analytics service created"); + return RAC_SUCCESS; } void rac_stt_analytics_destroy(rac_stt_analytics_handle_t handle) { @@ -156,7 +155,7 @@ rac_result_t rac_stt_analytics_start_transcription(rac_stt_analytics_handle_t ha if (!*out_transcription_id) { return RAC_ERROR_OUT_OF_MEMORY; } - strcpy(*out_transcription_id, id.c_str()); + memcpy(*out_transcription_id, id.c_str(), id.size() + 1); log_debug("STT.Analytics", "Transcription started: %s, model: %s, audio: %.1fms, %d bytes", id.c_str(), model_id, audio_length_ms, audio_size_bytes); diff --git a/sdk/runanywhere-commons/src/features/stt/stt_component.cpp b/sdk/runanywhere-commons/src/features/stt/stt_component.cpp index 558643971..19acec9ac 100644 --- a/sdk/runanywhere-commons/src/features/stt/stt_component.cpp +++ b/sdk/runanywhere-commons/src/features/stt/stt_component.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include "rac/core/capabilities/rac_lifecycle.h" @@ -63,11 +64,10 @@ struct rac_stt_component { * Generate a unique ID for transcription tracking. */ static std::string generate_unique_id() { - auto now = std::chrono::high_resolution_clock::now(); - auto epoch = now.time_since_epoch(); - auto ns = std::chrono::duration_cast(epoch).count(); - char buffer[64]; - snprintf(buffer, sizeof(buffer), "trans_%lld", static_cast(ns)); + static thread_local std::mt19937 gen(std::random_device{}()); + std::uniform_int_distribution dis; + char buffer[32]; + snprintf(buffer, sizeof(buffer), "trans_%08x%08x", dis(gen), dis(gen)); return std::string(buffer); } @@ -323,48 +323,52 @@ extern "C" rac_result_t rac_stt_component_transcribe(rac_handle_t handle, const return RAC_ERROR_INVALID_ARGUMENT; auto* component = reinterpret_cast(handle); - std::lock_guard lock(component->mtx); - // Generate unique ID for this transcription + // Acquire lock only for state reads, release before long-running transcription std::string transcription_id = generate_unique_id(); - const char* model_id = rac_lifecycle_get_model_id(component->lifecycle); - const char* model_name = rac_lifecycle_get_model_name(component->lifecycle); + rac_handle_t service = nullptr; + rac_stt_options_t local_options; + rac_inference_framework_t framework; + int32_t sample_rate = 0; + const char* model_id = nullptr; + const char* model_name = nullptr; - // Debug: Log if model_id is null - if (!model_id) { - log_warning( - "STT.Component", - "rac_lifecycle_get_model_id returned null - model_id may not be set in telemetry"); - } else { - log_debug("STT.Component", "STT transcription using model_id: %s", model_id); + { + std::lock_guard lock(component->mtx); + + model_id = rac_lifecycle_get_model_id(component->lifecycle); + model_name = rac_lifecycle_get_model_name(component->lifecycle); + framework = component->actual_framework; + sample_rate = component->config.sample_rate; + + // Copy effective options to local so we can release the lock + local_options = options ? *options : component->default_options; + + rac_result_t result = rac_lifecycle_require_service(component->lifecycle, &service); + if (result != RAC_SUCCESS) { + log_error("STT.Component", "No model loaded - cannot transcribe"); + + // Emit transcription failed event + rac_analytics_event_data_t event = {}; + event.type = RAC_EVENT_STT_TRANSCRIPTION_FAILED; + event.data.stt_transcription = RAC_ANALYTICS_STT_TRANSCRIPTION_DEFAULT; + event.data.stt_transcription.transcription_id = transcription_id.c_str(); + event.data.stt_transcription.model_id = model_id; + event.data.stt_transcription.model_name = model_name; + event.data.stt_transcription.error_code = result; + event.data.stt_transcription.error_message = "No model loaded"; + rac_analytics_event_emit(RAC_EVENT_STT_TRANSCRIPTION_FAILED, &event); + + return result; + } } + // Lock released — safe to do long-running transcription // Estimate audio length (assuming 16kHz mono 16-bit audio) double audio_length_ms = (audio_size / 2.0 / 16000.0) * 1000.0; - rac_handle_t service = nullptr; - rac_result_t result = rac_lifecycle_require_service(component->lifecycle, &service); - if (result != RAC_SUCCESS) { - log_error("STT.Component", "No model loaded - cannot transcribe"); - - // Emit transcription failed event - rac_analytics_event_data_t event = {}; - event.type = RAC_EVENT_STT_TRANSCRIPTION_FAILED; - event.data.stt_transcription = RAC_ANALYTICS_STT_TRANSCRIPTION_DEFAULT; - event.data.stt_transcription.transcription_id = transcription_id.c_str(); - event.data.stt_transcription.model_id = model_id; - event.data.stt_transcription.model_name = model_name; - event.data.stt_transcription.error_code = result; - event.data.stt_transcription.error_message = "No model loaded"; - rac_analytics_event_emit(RAC_EVENT_STT_TRANSCRIPTION_FAILED, &event); - - return result; - } - log_info("STT.Component", "Transcribing audio"); - const rac_stt_options_t* effective_options = options ? options : &component->default_options; - // Emit transcription started event { rac_analytics_event_data_t event = {}; @@ -375,16 +379,16 @@ extern "C" rac_result_t rac_stt_component_transcribe(rac_handle_t handle, const event.data.stt_transcription.model_name = model_name; event.data.stt_transcription.audio_length_ms = audio_length_ms; event.data.stt_transcription.audio_size_bytes = static_cast(audio_size); - event.data.stt_transcription.language = effective_options->language; + event.data.stt_transcription.language = local_options.language; event.data.stt_transcription.is_streaming = RAC_FALSE; - event.data.stt_transcription.sample_rate = component->config.sample_rate; - event.data.stt_transcription.framework = component->actual_framework; + event.data.stt_transcription.sample_rate = sample_rate; + event.data.stt_transcription.framework = framework; rac_analytics_event_emit(RAC_EVENT_STT_TRANSCRIPTION_STARTED, &event); } auto start_time = std::chrono::steady_clock::now(); - result = rac_stt_transcribe(service, audio_data, audio_size, effective_options, out_result); + rac_result_t result = rac_stt_transcribe(service, audio_data, audio_size, &local_options, out_result); if (result != RAC_SUCCESS) { log_error("STT.Component", "Transcription failed"); @@ -434,9 +438,9 @@ extern "C" rac_result_t rac_stt_component_transcribe(rac_handle_t handle, const event.data.stt_transcription.audio_size_bytes = static_cast(audio_size); event.data.stt_transcription.word_count = word_count; event.data.stt_transcription.real_time_factor = real_time_factor; - event.data.stt_transcription.language = effective_options->language; - event.data.stt_transcription.sample_rate = component->config.sample_rate; - event.data.stt_transcription.framework = component->actual_framework; + event.data.stt_transcription.language = local_options.language; + event.data.stt_transcription.sample_rate = sample_rate; + event.data.stt_transcription.framework = framework; event.data.stt_transcription.error_code = RAC_SUCCESS; rac_analytics_event_emit(RAC_EVENT_STT_TRANSCRIPTION_COMPLETED, &event); } diff --git a/sdk/runanywhere-commons/src/features/tts/tts_analytics.cpp b/sdk/runanywhere-commons/src/features/tts/tts_analytics.cpp index 276f01bbc..9519202cc 100644 --- a/sdk/runanywhere-commons/src/features/tts/tts_analytics.cpp +++ b/sdk/runanywhere-commons/src/features/tts/tts_analytics.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -39,9 +40,8 @@ int64_t get_current_time_ms() { } std::string generate_uuid() { - static std::random_device rd; - static std::mt19937 gen(rd()); - static std::uniform_int_distribution<> dis(0, 15); + static thread_local std::mt19937 gen(std::random_device{}()); + static thread_local std::uniform_int_distribution<> dis(0, 15); std::stringstream ss; ss << std::hex; @@ -109,13 +109,12 @@ rac_result_t rac_tts_analytics_create(rac_tts_analytics_handle_t* out_handle) { return RAC_ERROR_INVALID_PARAMETER; } - try { - *out_handle = new rac_tts_analytics_s(); - log_info("TTS.Analytics", "TTS analytics service created"); - return RAC_SUCCESS; - } catch (...) { + *out_handle = new (std::nothrow) rac_tts_analytics_s(); + if (!*out_handle) { return RAC_ERROR_OUT_OF_MEMORY; } + log_info("TTS.Analytics", "TTS analytics service created"); + return RAC_SUCCESS; } void rac_tts_analytics_destroy(rac_tts_analytics_handle_t handle) { @@ -151,7 +150,7 @@ rac_result_t rac_tts_analytics_start_synthesis(rac_tts_analytics_handle_t handle if (!*out_synthesis_id) { return RAC_ERROR_OUT_OF_MEMORY; } - strcpy(*out_synthesis_id, id.c_str()); + memcpy(*out_synthesis_id, id.c_str(), id.size() + 1); log_debug("TTS.Analytics", "Synthesis started: %s, voice: %s, %d characters", id.c_str(), voice, character_count); diff --git a/sdk/runanywhere-commons/src/features/tts/tts_component.cpp b/sdk/runanywhere-commons/src/features/tts/tts_component.cpp index 420236539..d8eed22d3 100644 --- a/sdk/runanywhere-commons/src/features/tts/tts_component.cpp +++ b/sdk/runanywhere-commons/src/features/tts/tts_component.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include "rac/core/capabilities/rac_lifecycle.h" @@ -50,13 +51,15 @@ struct rac_tts_component { // Generate a simple UUID v4-like string for event tracking static std::string generate_uuid_v4() { + static thread_local std::mt19937 gen(std::random_device{}()); + static thread_local std::uniform_int_distribution<> dis(0, 15); static const char* hex = "0123456789abcdef"; std::string uuid = "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx"; for (size_t i = 0; i < uuid.size(); i++) { if (uuid[i] == 'x') { - uuid[i] = hex[std::rand() % 16]; + uuid[i] = hex[dis(gen)]; } else if (uuid[i] == 'y') { - uuid[i] = hex[(std::rand() % 4) + 8]; + uuid[i] = hex[(dis(gen) % 4) + 8]; } } return uuid; @@ -318,20 +321,42 @@ extern "C" rac_result_t rac_tts_component_synthesize(rac_handle_t handle, const return RAC_ERROR_INVALID_ARGUMENT; auto* component = reinterpret_cast(handle); - std::lock_guard lock(component->mtx); - // Generate synthesis ID for event tracking + // Acquire lock only for state reads, release before long-running synthesis std::string synthesis_id = generate_uuid_v4(); - const char* voice_id = rac_lifecycle_get_model_id(component->lifecycle); - const char* voice_name = rac_lifecycle_get_model_name(component->lifecycle); + rac_handle_t service = nullptr; + rac_tts_options_t local_options; + rac_inference_framework_t framework; + const char* voice_id = nullptr; + const char* voice_name = nullptr; - // Debug: Log if voice_id is null - if (!voice_id) { - log_warning("TTS.Component", - "rac_lifecycle_get_model_id returned null - voice may not be set in telemetry"); - } else { - log_debug("TTS.Component", "TTS synthesis using voice_id: %s", voice_id); + { + std::lock_guard lock(component->mtx); + + voice_id = rac_lifecycle_get_model_id(component->lifecycle); + voice_name = rac_lifecycle_get_model_name(component->lifecycle); + framework = component->actual_framework; + + // Copy effective options to local so we can release the lock + local_options = options ? *options : component->default_options; + + rac_result_t result = rac_lifecycle_require_service(component->lifecycle, &service); + if (result != RAC_SUCCESS) { + log_error("TTS.Component", "No voice loaded - cannot synthesize"); + // Emit SYNTHESIS_FAILED event + rac_analytics_event_data_t event_data; + event_data.data.tts_synthesis = RAC_ANALYTICS_TTS_SYNTHESIS_DEFAULT; + event_data.data.tts_synthesis.synthesis_id = synthesis_id.c_str(); + event_data.data.tts_synthesis.model_id = voice_id; + event_data.data.tts_synthesis.model_name = voice_name; + event_data.data.tts_synthesis.framework = framework; + event_data.data.tts_synthesis.error_code = result; + event_data.data.tts_synthesis.error_message = "No voice loaded"; + rac_analytics_event_emit(RAC_EVENT_TTS_SYNTHESIS_FAILED, &event_data); + return result; + } } + // Lock released — safe to do long-running synthesis // Emit SYNTHESIS_STARTED event { @@ -341,34 +366,15 @@ extern "C" rac_result_t rac_tts_component_synthesize(rac_handle_t handle, const event_data.data.tts_synthesis.model_id = voice_id; event_data.data.tts_synthesis.model_name = voice_name; event_data.data.tts_synthesis.character_count = static_cast(std::strlen(text)); - event_data.data.tts_synthesis.framework = component->actual_framework; + event_data.data.tts_synthesis.framework = framework; rac_analytics_event_emit(RAC_EVENT_TTS_SYNTHESIS_STARTED, &event_data); } - rac_handle_t service = nullptr; - rac_result_t result = rac_lifecycle_require_service(component->lifecycle, &service); - if (result != RAC_SUCCESS) { - log_error("TTS.Component", "No voice loaded - cannot synthesize"); - // Emit SYNTHESIS_FAILED event - rac_analytics_event_data_t event_data; - event_data.data.tts_synthesis = RAC_ANALYTICS_TTS_SYNTHESIS_DEFAULT; - event_data.data.tts_synthesis.synthesis_id = synthesis_id.c_str(); - event_data.data.tts_synthesis.model_id = voice_id; - event_data.data.tts_synthesis.model_name = voice_name; - event_data.data.tts_synthesis.framework = component->actual_framework; - event_data.data.tts_synthesis.error_code = result; - event_data.data.tts_synthesis.error_message = "No voice loaded"; - rac_analytics_event_emit(RAC_EVENT_TTS_SYNTHESIS_FAILED, &event_data); - return result; - } - log_info("TTS.Component", "Synthesizing text"); - const rac_tts_options_t* effective_options = options ? options : &component->default_options; - auto start_time = std::chrono::steady_clock::now(); - result = rac_tts_synthesize(service, text, effective_options, out_result); + rac_result_t result = rac_tts_synthesize(service, text, &local_options, out_result); auto end_time = std::chrono::steady_clock::now(); auto duration = std::chrono::duration_cast(end_time - start_time); @@ -384,7 +390,7 @@ extern "C" rac_result_t rac_tts_component_synthesize(rac_handle_t handle, const event_data.data.tts_synthesis.model_name = voice_name; event_data.data.tts_synthesis.processing_duration_ms = static_cast(duration.count()); - event_data.data.tts_synthesis.framework = component->actual_framework; + event_data.data.tts_synthesis.framework = framework; event_data.data.tts_synthesis.error_code = result; event_data.data.tts_synthesis.error_message = "Synthesis failed"; rac_analytics_event_emit(RAC_EVENT_TTS_SYNTHESIS_FAILED, &event_data); @@ -414,7 +420,7 @@ extern "C" rac_result_t rac_tts_component_synthesize(rac_handle_t handle, const event_data.data.tts_synthesis.processing_duration_ms = processing_ms; event_data.data.tts_synthesis.characters_per_second = chars_per_sec; event_data.data.tts_synthesis.sample_rate = static_cast(out_result->sample_rate); - event_data.data.tts_synthesis.framework = component->actual_framework; + event_data.data.tts_synthesis.framework = framework; rac_analytics_event_emit(RAC_EVENT_TTS_SYNTHESIS_COMPLETED, &event_data); } @@ -433,14 +439,44 @@ extern "C" rac_result_t rac_tts_component_synthesize_stream(rac_handle_t handle, return RAC_ERROR_INVALID_ARGUMENT; auto* component = reinterpret_cast(handle); - std::lock_guard lock(component->mtx); - // Generate synthesis ID for event tracking + // Acquire lock only for state reads, release before long-running synthesis std::string synthesis_id = generate_uuid_v4(); - const char* voice_id = rac_lifecycle_get_model_id(component->lifecycle); - const char* voice_name = rac_lifecycle_get_model_name(component->lifecycle); + rac_handle_t service = nullptr; + rac_tts_options_t local_options; + rac_inference_framework_t framework; + const char* voice_id = nullptr; + const char* voice_name = nullptr; int32_t char_count = static_cast(std::strlen(text)); + { + std::lock_guard lock(component->mtx); + + voice_id = rac_lifecycle_get_model_id(component->lifecycle); + voice_name = rac_lifecycle_get_model_name(component->lifecycle); + framework = component->actual_framework; + + // Copy effective options to local so we can release the lock + local_options = options ? *options : component->default_options; + + rac_result_t result = rac_lifecycle_require_service(component->lifecycle, &service); + if (result != RAC_SUCCESS) { + log_error("TTS.Component", "No voice loaded - cannot synthesize stream"); + // Emit SYNTHESIS_FAILED event + rac_analytics_event_data_t event_data; + event_data.data.tts_synthesis = RAC_ANALYTICS_TTS_SYNTHESIS_DEFAULT; + event_data.data.tts_synthesis.synthesis_id = synthesis_id.c_str(); + event_data.data.tts_synthesis.model_id = voice_id; + event_data.data.tts_synthesis.model_name = voice_name; + event_data.data.tts_synthesis.framework = framework; + event_data.data.tts_synthesis.error_code = result; + event_data.data.tts_synthesis.error_message = "No voice loaded"; + rac_analytics_event_emit(RAC_EVENT_TTS_SYNTHESIS_FAILED, &event_data); + return result; + } + } + // Lock released — safe to do long-running synthesis + // Emit SYNTHESIS_STARTED event { rac_analytics_event_data_t event_data; @@ -449,34 +485,15 @@ extern "C" rac_result_t rac_tts_component_synthesize_stream(rac_handle_t handle, event_data.data.tts_synthesis.model_id = voice_id; event_data.data.tts_synthesis.model_name = voice_name; event_data.data.tts_synthesis.character_count = char_count; - event_data.data.tts_synthesis.framework = component->actual_framework; + event_data.data.tts_synthesis.framework = framework; rac_analytics_event_emit(RAC_EVENT_TTS_SYNTHESIS_STARTED, &event_data); } - rac_handle_t service = nullptr; - rac_result_t result = rac_lifecycle_require_service(component->lifecycle, &service); - if (result != RAC_SUCCESS) { - log_error("TTS.Component", "No voice loaded - cannot synthesize stream"); - // Emit SYNTHESIS_FAILED event - rac_analytics_event_data_t event_data; - event_data.data.tts_synthesis = RAC_ANALYTICS_TTS_SYNTHESIS_DEFAULT; - event_data.data.tts_synthesis.synthesis_id = synthesis_id.c_str(); - event_data.data.tts_synthesis.model_id = voice_id; - event_data.data.tts_synthesis.model_name = voice_name; - event_data.data.tts_synthesis.framework = component->actual_framework; - event_data.data.tts_synthesis.error_code = result; - event_data.data.tts_synthesis.error_message = "No voice loaded"; - rac_analytics_event_emit(RAC_EVENT_TTS_SYNTHESIS_FAILED, &event_data); - return result; - } - log_info("TTS.Component", "Starting streaming synthesis"); - const rac_tts_options_t* effective_options = options ? options : &component->default_options; - auto start_time = std::chrono::steady_clock::now(); - result = rac_tts_synthesize_stream(service, text, effective_options, callback, user_data); + rac_result_t result = rac_tts_synthesize_stream(service, text, &local_options, callback, user_data); auto end_time = std::chrono::steady_clock::now(); auto duration = std::chrono::duration_cast(end_time - start_time); @@ -492,7 +509,7 @@ extern "C" rac_result_t rac_tts_component_synthesize_stream(rac_handle_t handle, event_data.data.tts_synthesis.model_name = voice_name; event_data.data.tts_synthesis.processing_duration_ms = static_cast(duration.count()); - event_data.data.tts_synthesis.framework = component->actual_framework; + event_data.data.tts_synthesis.framework = framework; event_data.data.tts_synthesis.error_code = result; event_data.data.tts_synthesis.error_message = "Streaming synthesis failed"; rac_analytics_event_emit(RAC_EVENT_TTS_SYNTHESIS_FAILED, &event_data); @@ -509,7 +526,7 @@ extern "C" rac_result_t rac_tts_component_synthesize_stream(rac_handle_t handle, event_data.data.tts_synthesis.character_count = char_count; event_data.data.tts_synthesis.processing_duration_ms = processing_ms; event_data.data.tts_synthesis.characters_per_second = chars_per_sec; - event_data.data.tts_synthesis.framework = component->actual_framework; + event_data.data.tts_synthesis.framework = framework; rac_analytics_event_emit(RAC_EVENT_TTS_SYNTHESIS_COMPLETED, &event_data); } diff --git a/sdk/runanywhere-commons/src/features/vad/energy_vad.cpp b/sdk/runanywhere-commons/src/features/vad/energy_vad.cpp index 0d856afbe..d5f6c9364 100644 --- a/sdk/runanywhere-commons/src/features/vad/energy_vad.cpp +++ b/sdk/runanywhere-commons/src/features/vad/energy_vad.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -25,67 +26,130 @@ // INTERNAL STRUCTURE - Mirrors Swift's SimpleEnergyVADService properties // ============================================================================= +// Cache line size for alignment (64 bytes on most modern CPUs) +static constexpr size_t CACHE_LINE_SIZE = 64; + struct rac_energy_vad { - - // Hot data -> accessed frequently - bool is_active; - bool is_currently_speaking; - bool is_paused; - bool is_tts_active; + // === Group 1: Hot processing data (read/written every frame) === + // Kept together on their own cache line(s) for spatial locality + + float energy_threshold; + float energy_threshold_sq; // energy_threshold², for sqrt-free comparison in hot path + float base_energy_threshold; int32_t consecutive_silent_frames; int32_t consecutive_voice_frames; - float energy_threshold; - float base_energy_threshold; + bool is_active; + bool is_currently_speaking; + bool is_paused; + bool is_tts_active; int32_t voice_start_threshold; int32_t voice_end_threshold; int32_t tts_voice_start_threshold; int32_t tts_voice_end_threshold; - size_t ring_buffer_write_index; + // === Group 2: Debug ring buffer (written every frame, separate cache line) === + alignas(CACHE_LINE_SIZE) size_t ring_buffer_write_index; size_t ring_buffer_count; + std::vector recent_energy_values; + int32_t max_recent_values; + int32_t debug_frame_count; - // Cold data -> accessed less frequently - - int32_t sample_rate; + // === Group 3: Cold config data (set once at init, read-only in hot path) === + alignas(CACHE_LINE_SIZE) int32_t sample_rate; int32_t frame_length_samples; float tts_threshold_multiplier; float calibration_multiplier; + // === Group 4: Calibration state (only active during calibration phase) === bool is_calibrating; float ambient_noise_level; int32_t calibration_frame_count; int32_t calibration_frames_needed; std::vector calibration_samples; - std::vector recent_energy_values; - int32_t max_recent_values; - int32_t debug_frame_count; - - rac_speech_activity_callback_fn speech_callback; + // === Group 5: Callbacks (read in hot path, written rarely) === + alignas(CACHE_LINE_SIZE) rac_speech_activity_callback_fn speech_callback; void* speech_user_data; rac_audio_buffer_callback_fn audio_callback; void* audio_user_data; - std::mutex mutex; + // === Group 6: Mutex (separate cache line to avoid false sharing) === + alignas(CACHE_LINE_SIZE) std::mutex mutex; }; +// Verify struct layout hasn't regressed. rac_energy_vad is split into +// cache-line-aligned groups (see alignas(64) above). If someone adds fields, +// this assert fires as a reminder to check the layout. +// On most platforms: ~448-512 bytes (varies due to std::mutex/vector sizes). +// The assert is generous; tighten it when profiling reveals a regression. +static_assert(sizeof(rac_energy_vad) <= 1024, + "rac_energy_vad grew unexpectedly — check cache-line alignment groups"); + // ============================================================================= // HELPER FUNCTIONS - Mirrors Swift's private methods // ============================================================================= /** - * Update voice activity state with hysteresis - * Mirrors Swift's updateVoiceActivityState(hasVoice:) + * Update threshold_sq whenever energy_threshold changes. + * This pre-computes the squared threshold so the hot-path + * comparison can use mean-square energy vs threshold² (no sqrt). + */ +static inline void update_threshold_sq(rac_energy_vad* vad) { + vad->energy_threshold_sq = vad->energy_threshold * vad->energy_threshold; +} + +/** + * Compute mean-square energy (sum_of_squares / n) WITHOUT the final sqrt. + * Used in the hot path to avoid a per-frame sqrt; the caller compares + * the result against energy_threshold_sq instead. + * + * This is the same math as rac_energy_vad_calculate_rms() minus the sqrt. */ -static void update_voice_activity_state(rac_energy_vad* vad, bool has_voice) { +static float calculate_mean_square(const float* __restrict audio_data, size_t sample_count) { + if (sample_count == 0 || audio_data == nullptr) { + return 0.0f; + } + + float s0 = 0.0f, s1 = 0.0f, s2 = 0.0f, s3 = 0.0f; + size_t i = 0; + + for (; i + 3 < sample_count; i += 4) { + const float a = audio_data[i]; + const float b = audio_data[i + 1]; + const float c = audio_data[i + 2]; + const float d = audio_data[i + 3]; + s0 += a * a; + s1 += b * b; + s2 += c * c; + s3 += d * d; + } + + float sum_squares = (s0 + s1) + (s2 + s3); + + for (; i < sample_count; ++i) { + const float x = audio_data[i]; + sum_squares += x * x; + } + return sum_squares / static_cast(sample_count); +} + +/** + * Update voice activity state with hysteresis. + * Mirrors Swift's updateVoiceActivityState(hasVoice:). + * + * Returns the pending speech event to fire AFTER releasing the mutex: + * -1 = no event, RAC_SPEECH_ACTIVITY_STARTED (0), RAC_SPEECH_ACTIVITY_ENDED (1). + * The caller is responsible for invoking the callback outside the lock. + */ +static int update_voice_activity_state(rac_energy_vad* vad, const bool has_voice) { // Use different thresholds based on TTS state (mirrors Swift logic) - int32_t start_threshold = + const int32_t start_threshold = vad->is_tts_active ? vad->tts_voice_start_threshold : vad->voice_start_threshold; - int32_t end_threshold = + const int32_t end_threshold = vad->is_tts_active ? vad->tts_voice_end_threshold : vad->voice_end_threshold; if (has_voice) { @@ -98,16 +162,17 @@ static void update_voice_activity_state(rac_energy_vad* vad, bool has_voice) { if (vad->is_tts_active) { RAC_LOG_WARNING("EnergyVAD", "Voice detected during TTS playback - likely feedback! Ignoring."); - return; + // Reset counter to prevent instant re-trigger once TTS ends. + // Without this, consecutive_voice_frames keeps accumulating + // across TTS duration and immediately exceeds the start threshold + // on the first voiced frame after TTS finishes. + vad->consecutive_voice_frames = 0; + return -1; } vad->is_currently_speaking = true; RAC_LOG_INFO("EnergyVAD", "VAD: SPEECH STARTED"); - - // Fire callback - if (vad->speech_callback) { - vad->speech_callback(RAC_SPEECH_ACTIVITY_STARTED, vad->speech_user_data); - } + return RAC_SPEECH_ACTIVITY_STARTED; } } else { vad->consecutive_silent_frames++; @@ -117,20 +182,18 @@ static void update_voice_activity_state(rac_energy_vad* vad, bool has_voice) { if (vad->is_currently_speaking && vad->consecutive_silent_frames >= end_threshold) { vad->is_currently_speaking = false; RAC_LOG_INFO("EnergyVAD", "VAD: SPEECH ENDED"); - - // Fire callback - if (vad->speech_callback) { - vad->speech_callback(RAC_SPEECH_ACTIVITY_ENDED, vad->speech_user_data); - } + return RAC_SPEECH_ACTIVITY_ENDED; } } + + return -1; } /** * Handle a frame during calibration * Mirrors Swift's handleCalibrationFrame(energy:) */ -static void handle_calibration_frame(rac_energy_vad* vad, float energy) { +static void handle_calibration_frame(rac_energy_vad* vad, const float energy) { if (!vad->is_calibrating) { return; } @@ -149,16 +212,16 @@ static void handle_calibration_frame(rac_energy_vad* vad, float energy) { std::vector sorted_samples = vad->calibration_samples; std::sort(sorted_samples.begin(), sorted_samples.end()); - size_t count = sorted_samples.size(); - float percentile_90 = + const size_t count = sorted_samples.size(); + const float percentile_90 = sorted_samples[std::min(count - 1, static_cast(count * 0.90f))]; // Use 90th percentile as ambient noise level (mirrors Swift) vad->ambient_noise_level = percentile_90; // Calculate dynamic threshold (mirrors Swift logic) - float minimum_threshold = std::max(vad->ambient_noise_level * 2.0f, RAC_VAD_MIN_THRESHOLD); - float calculated_threshold = vad->ambient_noise_level * vad->calibration_multiplier; + const float minimum_threshold = std::max(vad->ambient_noise_level * 2.0f, RAC_VAD_MIN_THRESHOLD); + const float calculated_threshold = vad->ambient_noise_level * vad->calibration_multiplier; // Apply threshold with sensible bounds vad->energy_threshold = std::max(calculated_threshold, minimum_threshold); @@ -170,6 +233,8 @@ static void handle_calibration_frame(rac_energy_vad* vad, float energy) { "Calibration detected high ambient noise. Capping threshold."); } + update_threshold_sq(vad); + RAC_LOG_INFO("EnergyVAD", "VAD Calibration Complete"); vad->is_calibrating = false; @@ -182,7 +247,7 @@ static void handle_calibration_frame(rac_energy_vad* vad, float energy) { * Mirrors Swift's updateDebugStatistics(energy:) * Optimised to use ring buffer */ -static void update_debug_statistics(rac_energy_vad* vad, float energy) { +static void update_debug_statistics(rac_energy_vad* vad, const float energy) { if (vad->recent_energy_values.empty()) { return; } @@ -211,13 +276,17 @@ rac_result_t rac_energy_vad_create(const rac_energy_vad_config_t* config, const rac_energy_vad_config_t* cfg = config ? config : &RAC_ENERGY_VAD_CONFIG_DEFAULT; - rac_energy_vad* vad = new rac_energy_vad(); + rac_energy_vad* vad = new (std::nothrow) rac_energy_vad(); + if (!vad) { + return RAC_ERROR_OUT_OF_MEMORY; + } // Initialize from config (mirrors Swift init) vad->sample_rate = cfg->sample_rate; vad->frame_length_samples = static_cast(cfg->frame_length * static_cast(cfg->sample_rate)); vad->energy_threshold = cfg->energy_threshold; + vad->energy_threshold_sq = cfg->energy_threshold * cfg->energy_threshold; vad->base_energy_threshold = cfg->energy_threshold; vad->calibration_multiplier = RAC_VAD_DEFAULT_CALIBRATION_MULTIPLIER; vad->tts_threshold_multiplier = RAC_VAD_DEFAULT_TTS_THRESHOLD_MULTIPLIER; @@ -320,28 +389,40 @@ rac_result_t rac_energy_vad_stop(rac_energy_vad_handle_t handle) { return RAC_ERROR_INVALID_ARGUMENT; } - std::lock_guard lock(handle->mutex); + // Deferred callback (invoked outside lock to prevent re-entrant deadlock) + rac_speech_activity_callback_fn deferred_cb = nullptr; + void* deferred_data = nullptr; - // Mirrors Swift's stop() - if (!handle->is_active) { - return RAC_SUCCESS; - } + { + std::lock_guard lock(handle->mutex); + + // Mirrors Swift's stop() + if (!handle->is_active) { + return RAC_SUCCESS; + } - // If currently speaking, send end event - if (handle->is_currently_speaking) { - handle->is_currently_speaking = false; - RAC_LOG_INFO("EnergyVAD", "VAD: SPEECH ENDED (stopped)"); + // If currently speaking, send end event + if (handle->is_currently_speaking) { + handle->is_currently_speaking = false; + RAC_LOG_INFO("EnergyVAD", "VAD: SPEECH ENDED (stopped)"); - if (handle->speech_callback) { - handle->speech_callback(RAC_SPEECH_ACTIVITY_ENDED, handle->speech_user_data); + if (handle->speech_callback) { + deferred_cb = handle->speech_callback; + deferred_data = handle->speech_user_data; + } } + + handle->is_active = false; + handle->consecutive_silent_frames = 0; + handle->consecutive_voice_frames = 0; + + RAC_LOG_INFO("EnergyVAD", "SimpleEnergyVADService stopped"); } - handle->is_active = false; - handle->consecutive_silent_frames = 0; - handle->consecutive_voice_frames = 0; + if (deferred_cb) { + deferred_cb(RAC_SPEECH_ACTIVITY_ENDED, deferred_data); + } - RAC_LOG_INFO("EnergyVAD", "SimpleEnergyVADService stopped"); return RAC_SUCCESS; } @@ -361,60 +442,104 @@ rac_result_t rac_energy_vad_reset(rac_energy_vad_handle_t handle) { return RAC_SUCCESS; } -rac_result_t rac_energy_vad_process_audio(rac_energy_vad_handle_t handle, const float* audio_data, +rac_result_t rac_energy_vad_process_audio(rac_energy_vad_handle_t handle, const float* __restrict audio_data, size_t sample_count, rac_bool_t* out_has_voice) { if (!handle || !audio_data || sample_count == 0) { return RAC_ERROR_INVALID_ARGUMENT; } - std::lock_guard lock(handle->mutex); - - // Mirrors Swift's processAudioData(_:) - if (!handle->is_active) { - if (out_has_voice) - *out_has_voice = RAC_FALSE; - return RAC_SUCCESS; + // --- Phase 1: Read shared flags under lock (minimal critical section) --- + bool is_active; + bool is_tts_active; + bool is_paused; + { + std::lock_guard lock(handle->mutex); + is_active = handle->is_active; + is_tts_active = handle->is_tts_active; + is_paused = handle->is_paused; } - // Complete audio blocking during TTS (mirrors Swift) - if (handle->is_tts_active) { + // Early-out checks (no lock needed) + if (!is_active || is_tts_active || is_paused) { if (out_has_voice) *out_has_voice = RAC_FALSE; return RAC_SUCCESS; } - if (handle->is_paused) { - if (out_has_voice) - *out_has_voice = RAC_FALSE; - return RAC_SUCCESS; - } + // --- Phase 2: Pure math — no shared state, no lock needed --- + // Compute mean-square energy (no sqrt). The hot-path voice detection + // compares mean_sq > threshold² instead of sqrt(mean_sq) > threshold, + // saving ~15 cycles/frame on ARM. + const float mean_sq = calculate_mean_square(audio_data, sample_count); + + // Deferred callback data (invoked outside lock to prevent re-entrant deadlock). + // If a callback re-enters any rac_energy_vad_* function on the same thread, + // std::mutex (non-recursive) would deadlock without this deferral pattern. + int pending_speech_event = -1; + rac_speech_activity_callback_fn deferred_speech_cb = nullptr; + void* deferred_speech_data = nullptr; + rac_audio_buffer_callback_fn deferred_audio_cb = nullptr; + void* deferred_audio_data = nullptr; + + // --- Phase 3: Update shared state under lock (minimal critical section) --- + { + std::lock_guard lock(handle->mutex); + + // Re-check flags that may have changed between Phase 1 and Phase 3. + // If stop()/pause()/notify_tts_start() ran in the gap, they already + // handled state transitions (including SPEECH_ENDED callbacks). + // Processing stale data here would cause duplicate callbacks. + if (!handle->is_active || handle->is_tts_active || handle->is_paused) { + if (out_has_voice) + *out_has_voice = RAC_FALSE; + return RAC_SUCCESS; + } - // Calculate energy using RMS - float energy = rac_energy_vad_calculate_rms(audio_data, sample_count); + // Handle calibration if active — needs RMS (with sqrt) for + // correct threshold calculation. Calibration is infrequent + // (~20 frames at startup), so the sqrt cost is acceptable. + if (handle->is_calibrating) { + const float energy_rms = std::sqrt(mean_sq); + update_debug_statistics(handle, mean_sq); + handle_calibration_frame(handle, energy_rms); + if (out_has_voice) + *out_has_voice = RAC_FALSE; + return RAC_SUCCESS; + } - // Update debug statistics - update_debug_statistics(handle, energy); + // Normal operation: store mean-square in debug ring buffer. + // get_statistics() converts back to RMS when reading. + update_debug_statistics(handle, mean_sq); - // Handle calibration if active (mirrors Swift) - if (handle->is_calibrating) { - handle_calibration_frame(handle, energy); - if (out_has_voice) - *out_has_voice = RAC_FALSE; - return RAC_SUCCESS; - } + // Compare in squared domain — no sqrt needed. + // Re-read threshold_sq under lock in case it changed (TTS notification). + const bool has_voice = mean_sq > handle->energy_threshold_sq; - bool has_voice = energy > handle->energy_threshold; + // Update state (mirrors Swift's updateVoiceActivityState) + pending_speech_event = update_voice_activity_state(handle, has_voice); - // Update state (mirrors Swift's updateVoiceActivityState) - update_voice_activity_state(handle, has_voice); + // Copy callback pointers for deferred invocation outside the lock + if (pending_speech_event >= 0 && handle->speech_callback) { + deferred_speech_cb = handle->speech_callback; + deferred_speech_data = handle->speech_user_data; + } + if (handle->audio_callback) { + deferred_audio_cb = handle->audio_callback; + deferred_audio_data = handle->audio_user_data; + } - // Call audio buffer callback if provided - if (handle->audio_callback) { - handle->audio_callback(audio_data, sample_count * sizeof(float), handle->audio_user_data); + if (out_has_voice) { + *out_has_voice = has_voice ? RAC_TRUE : RAC_FALSE; + } } - if (out_has_voice) { - *out_has_voice = has_voice ? RAC_TRUE : RAC_FALSE; + // --- Phase 4: Invoke callbacks outside the lock --- + if (deferred_speech_cb) { + deferred_speech_cb(static_cast(pending_speech_event), + deferred_speech_data); + } + if (deferred_audio_cb) { + deferred_audio_cb(audio_data, sample_count * sizeof(float), deferred_audio_data); } return RAC_SUCCESS; @@ -454,30 +579,41 @@ rac_result_t rac_energy_vad_pause(rac_energy_vad_handle_t handle) { return RAC_ERROR_INVALID_ARGUMENT; } - std::lock_guard lock(handle->mutex); + // Deferred callback (invoked outside lock to prevent re-entrant deadlock) + rac_speech_activity_callback_fn deferred_cb = nullptr; + void* deferred_data = nullptr; - // Mirrors Swift's pause() - if (handle->is_paused) { - return RAC_SUCCESS; - } + { + std::lock_guard lock(handle->mutex); - handle->is_paused = true; - RAC_LOG_INFO("EnergyVAD", "VAD paused"); + // Mirrors Swift's pause() + if (handle->is_paused) { + return RAC_SUCCESS; + } + + handle->is_paused = true; + RAC_LOG_INFO("EnergyVAD", "VAD paused"); - // If currently speaking, send end event - if (handle->is_currently_speaking) { - handle->is_currently_speaking = false; - if (handle->speech_callback) { - handle->speech_callback(RAC_SPEECH_ACTIVITY_ENDED, handle->speech_user_data); + // If currently speaking, send end event + if (handle->is_currently_speaking) { + handle->is_currently_speaking = false; + if (handle->speech_callback) { + deferred_cb = handle->speech_callback; + deferred_data = handle->speech_user_data; + } } + + // Clear recent energy values (Reset Ring Buffer) + handle->ring_buffer_count = 0; + handle->ring_buffer_write_index = 0; + // No need to zero out vector, just reset indices + handle->consecutive_silent_frames = 0; + handle->consecutive_voice_frames = 0; } - // Clear recent energy values (Reset Ring Buffer) - handle->ring_buffer_count = 0; - handle->ring_buffer_write_index = 0; - // No need to zero out vector, just reset indices - handle->consecutive_silent_frames = 0; - handle->consecutive_voice_frames = 0; + if (deferred_cb) { + deferred_cb(RAC_SPEECH_ACTIVITY_ENDED, deferred_data); + } return RAC_SUCCESS; } @@ -556,31 +692,43 @@ rac_result_t rac_energy_vad_notify_tts_start(rac_energy_vad_handle_t handle) { return RAC_ERROR_INVALID_ARGUMENT; } - std::lock_guard lock(handle->mutex); + // Deferred callback (invoked outside lock to prevent re-entrant deadlock) + rac_speech_activity_callback_fn deferred_cb = nullptr; + void* deferred_data = nullptr; - // Mirrors Swift's notifyTTSWillStart() - handle->is_tts_active = true; + { + std::lock_guard lock(handle->mutex); - // Save base threshold - handle->base_energy_threshold = handle->energy_threshold; + // Mirrors Swift's notifyTTSWillStart() + handle->is_tts_active = true; - // Increase threshold significantly to prevent TTS audio from triggering VAD - float new_threshold = handle->energy_threshold * handle->tts_threshold_multiplier; - handle->energy_threshold = std::min(new_threshold, 0.1f); + // Save base threshold + handle->base_energy_threshold = handle->energy_threshold; - RAC_LOG_INFO("EnergyVAD", "TTS starting - VAD blocked and threshold increased"); + // Increase threshold significantly to prevent TTS audio from triggering VAD + float new_threshold = handle->energy_threshold * handle->tts_threshold_multiplier; + handle->energy_threshold = std::min(new_threshold, 0.1f); + update_threshold_sq(handle); - // End any current speech detection - if (handle->is_currently_speaking) { - handle->is_currently_speaking = false; - if (handle->speech_callback) { - handle->speech_callback(RAC_SPEECH_ACTIVITY_ENDED, handle->speech_user_data); + RAC_LOG_INFO("EnergyVAD", "TTS starting - VAD blocked and threshold increased"); + + // End any current speech detection + if (handle->is_currently_speaking) { + handle->is_currently_speaking = false; + if (handle->speech_callback) { + deferred_cb = handle->speech_callback; + deferred_data = handle->speech_user_data; + } } + + // Reset counters + handle->consecutive_silent_frames = 0; + handle->consecutive_voice_frames = 0; } - // Reset counters - handle->consecutive_silent_frames = 0; - handle->consecutive_voice_frames = 0; + if (deferred_cb) { + deferred_cb(RAC_SPEECH_ACTIVITY_ENDED, deferred_data); + } return RAC_SUCCESS; } @@ -597,6 +745,7 @@ rac_result_t rac_energy_vad_notify_tts_finish(rac_energy_vad_handle_t handle) { // Immediately restore threshold handle->energy_threshold = handle->base_energy_threshold; + update_threshold_sq(handle); RAC_LOG_INFO("EnergyVAD", "TTS finished - VAD threshold restored"); @@ -657,6 +806,7 @@ rac_result_t rac_energy_vad_set_threshold(rac_energy_vad_handle_t handle, float std::lock_guard lock(handle->mutex); handle->energy_threshold = threshold; handle->base_energy_threshold = threshold; + update_threshold_sq(handle); return RAC_SUCCESS; } @@ -670,20 +820,23 @@ rac_result_t rac_energy_vad_get_statistics(rac_energy_vad_handle_t handle, std::lock_guard lock(handle->mutex); // Mirrors Swift's getStatistics() + // Note: the ring buffer stores mean-square values (not RMS) to avoid + // per-frame sqrt in process_audio(). We convert back to RMS here since + // get_statistics() is called infrequently (on demand, not per frame). float recent_avg = 0.0f; float recent_max = 0.0f; float current = 0.0f; size_t count = handle->ring_buffer_count; if (count > 0) { - + size_t last_idx = (handle->ring_buffer_write_index == 0) ? (handle->recent_energy_values.size() - 1) : (handle->ring_buffer_write_index - 1); - current = handle->recent_energy_values[last_idx]; + current = std::sqrt(handle->recent_energy_values[last_idx]); for (size_t i = 0; i < count; ++i) { - float val = handle->recent_energy_values[i]; + float val = std::sqrt(handle->recent_energy_values[i]); recent_avg += val; recent_max = std::max(recent_max, val); } diff --git a/sdk/runanywhere-commons/src/features/vad/vad_analytics.cpp b/sdk/runanywhere-commons/src/features/vad/vad_analytics.cpp index 27ad90327..4ec9a294b 100644 --- a/sdk/runanywhere-commons/src/features/vad/vad_analytics.cpp +++ b/sdk/runanywhere-commons/src/features/vad/vad_analytics.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include "rac/core/rac_logger.h" #include "rac/features/vad/rac_vad_analytics.h" @@ -71,12 +72,15 @@ rac_result_t rac_vad_analytics_create(rac_vad_analytics_handle_t* out_handle) { } try { - *out_handle = new rac_vad_analytics_s(); - log_info("VAD.Analytics", "VAD analytics service created"); - return RAC_SUCCESS; + *out_handle = new (std::nothrow) rac_vad_analytics_s(); } catch (...) { + *out_handle = nullptr; + } + if (!*out_handle) { return RAC_ERROR_OUT_OF_MEMORY; } + log_info("VAD.Analytics", "VAD analytics service created"); + return RAC_SUCCESS; } void rac_vad_analytics_destroy(rac_vad_analytics_handle_t handle) { diff --git a/sdk/runanywhere-commons/src/features/vad/vad_component.cpp b/sdk/runanywhere-commons/src/features/vad/vad_component.cpp index c5a449b3d..54ea2e3df 100644 --- a/sdk/runanywhere-commons/src/features/vad/vad_component.cpp +++ b/sdk/runanywhere-commons/src/features/vad/vad_component.cpp @@ -9,6 +9,7 @@ * Do NOT add features not present in the Swift code. */ +#include #include #include #include @@ -42,8 +43,8 @@ struct rac_vad_component { rac_vad_audio_callback_fn audio_callback; void* audio_user_data; - /** Initialization state */ - bool is_initialized; + /** Initialization state (atomic for lock-free query from callbacks) */ + std::atomic is_initialized; /** Mutex for thread safety */ std::mutex mtx; @@ -181,7 +182,7 @@ extern "C" rac_bool_t rac_vad_component_is_initialized(rac_handle_t handle) { return RAC_FALSE; auto* component = reinterpret_cast(handle); - return component->is_initialized ? RAC_TRUE : RAC_FALSE; + return component->is_initialized.load(std::memory_order_acquire) ? RAC_TRUE : RAC_FALSE; } extern "C" rac_result_t rac_vad_component_initialize(rac_handle_t handle) { @@ -478,12 +479,8 @@ extern "C" rac_lifecycle_state_t rac_vad_component_get_state(rac_handle_t handle return RAC_LIFECYCLE_STATE_IDLE; auto* component = reinterpret_cast(handle); - - if (component->is_initialized) { - return RAC_LIFECYCLE_STATE_LOADED; - } - - return RAC_LIFECYCLE_STATE_IDLE; + return component->is_initialized.load(std::memory_order_acquire) ? RAC_LIFECYCLE_STATE_LOADED + : RAC_LIFECYCLE_STATE_IDLE; } extern "C" rac_result_t rac_vad_component_get_metrics(rac_handle_t handle, @@ -497,6 +494,7 @@ extern "C" rac_result_t rac_vad_component_get_metrics(rac_handle_t handle, memset(out_metrics, 0, sizeof(rac_lifecycle_metrics_t)); auto* component = reinterpret_cast(handle); + std::lock_guard lock(component->mtx); if (component->is_initialized) { out_metrics->total_loads = 1; out_metrics->successful_loads = 1; diff --git a/sdk/runanywhere-commons/src/features/vlm/vlm_component.cpp b/sdk/runanywhere-commons/src/features/vlm/vlm_component.cpp index 95b8426b9..2bc804233 100644 --- a/sdk/runanywhere-commons/src/features/vlm/vlm_component.cpp +++ b/sdk/runanywhere-commons/src/features/vlm/vlm_component.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -74,8 +75,8 @@ struct rac_vlm_component { static int32_t estimate_tokens(const char* text) { if (!text) return 1; - size_t len = strlen(text); - int32_t tokens = static_cast((len + 3) / 4); + const size_t len = strlen(text); + const int32_t tokens = static_cast((len + 3) / 4); return tokens > 0 ? tokens : 1; } @@ -83,11 +84,10 @@ static int32_t estimate_tokens(const char* text) { * Generate a unique ID for generation tracking. */ static std::string generate_unique_id() { - auto now = std::chrono::high_resolution_clock::now(); - auto epoch = now.time_since_epoch(); - auto ns = std::chrono::duration_cast(epoch).count(); - char buffer[64]; - snprintf(buffer, sizeof(buffer), "vlm_gen_%lld", static_cast(ns)); + static thread_local std::mt19937 gen(std::random_device{}()); + std::uniform_int_distribution dis; + char buffer[32]; + snprintf(buffer, sizeof(buffer), "vlm_%08x%08x", dis(gen), dis(gen)); return std::string(buffer); } @@ -111,21 +111,22 @@ static const char* vlm_strip_special_tokens(const char* token, char* buf, size_t size_t out = 0; size_t i = 0; - size_t len = strlen(token); - while (i < len && out < buf_size - 1) { - if (token[i] == '<' && i + 1 < len && token[i + 1] == '|') { + // Use null-terminator checks instead of strlen() to avoid the upfront O(n) scan. + // Tokens are typically short (1-4 chars), but this avoids redundant work. + while (token[i] != '\0' && out < buf_size - 1) { + if (token[i] == '<' && token[i + 1] == '|') { // Scan ahead for closing |> size_t end = i + 2; - while (end < len) { - if (token[end] == '|' && end + 1 < len && token[end + 1] == '>') { + while (token[end] != '\0') { + if (token[end] == '|' && token[end + 1] == '>') { // Found <|...|> — skip the entire special token i = end + 2; break; } end++; } - if (end >= len) { + if (token[end] == '\0') { // No closing |> found — copy the '<' literally buf[out++] = token[i++]; } @@ -169,8 +170,8 @@ extern "C" rac_result_t rac_vlm_resolve_model_files(const char* model_dir, char* struct dirent* entry; while ((entry = readdir(dir)) != nullptr) { - const char* name = entry->d_name; - size_t name_len = strlen(name); + const char* const name = entry->d_name; + const size_t name_len = strlen(name); // Must end with .gguf (case-insensitive) if (name_len < 5) continue; @@ -655,6 +656,11 @@ extern "C" rac_result_t rac_vlm_component_process_stream( ctx.prompt_tokens = estimate_tokens(prompt); ctx.token_count = 0; + // Pre-allocate string capacity to avoid repeated reallocations during streaming. + // Typical VLM responses are a few hundred tokens (~2KB text). + ctx.full_text.reserve(2048); + ctx.cleaned_text.reserve(2048); + // Perform streaming generation result = rac_vlm_process_stream(service, image, prompt, effective_options, vlm_stream_token_callback, &ctx); @@ -679,6 +685,13 @@ extern "C" rac_result_t rac_vlm_component_process_stream( // Fall back to full_text if no cleaned tokens were produced. const std::string& result_text = ctx.cleaned_text.empty() ? ctx.full_text : ctx.cleaned_text; final_result.text = strdup(result_text.c_str()); + if (!final_result.text) { + RAC_LOG_ERROR(LOG_CAT, "Failed to allocate result text"); + if (error_callback) { + error_callback(RAC_ERROR_OUT_OF_MEMORY, "Failed to allocate result text", user_data); + } + return RAC_ERROR_OUT_OF_MEMORY; + } final_result.prompt_tokens = ctx.prompt_tokens; final_result.completion_tokens = estimate_tokens(result_text.c_str()); final_result.total_tokens = final_result.prompt_tokens + final_result.completion_tokens; @@ -714,8 +727,11 @@ extern "C" rac_result_t rac_vlm_component_cancel(rac_handle_t handle) { return RAC_ERROR_INVALID_HANDLE; auto* component = reinterpret_cast(handle); - std::lock_guard lock(component->mtx); + // Do NOT acquire component->mtx here. process_stream holds the mutex for + // the entire streaming duration, so locking here would deadlock until + // generation finishes — defeating the purpose of cancel. + // rac_vlm_cancel only sets an atomic bool, so it is safe without the lock. rac_handle_t service = rac_lifecycle_get_service(component->lifecycle); if (service) { rac_vlm_cancel(service); diff --git a/sdk/runanywhere-commons/src/features/voice_agent/voice_agent.cpp b/sdk/runanywhere-commons/src/features/voice_agent/voice_agent.cpp index d9ff77127..1e059ffd8 100644 --- a/sdk/runanywhere-commons/src/features/voice_agent/voice_agent.cpp +++ b/sdk/runanywhere-commons/src/features/voice_agent/voice_agent.cpp @@ -8,9 +8,12 @@ * CRITICAL: This is a direct port of Swift implementation - do NOT add custom logic! */ +#include #include #include #include +#include +#include #include "rac/core/rac_analytics_events.h" #include "rac/core/rac_audio_utils.h" @@ -43,24 +46,27 @@ void emit_voice_agent_all_ready(); // ============================================================================= struct rac_voice_agent { - // State - bool is_configured; + // State — atomic so is_ready() checks don't need the mutex + std::atomic is_configured{false}; + + // Shutdown barrier — prevents use-after-free on destroy + std::atomic is_shutting_down{false}; + std::atomic in_flight{0}; // Whether we own the component handles (and should destroy them) bool owns_components; - // Composed component handles + // Composed component handles (set at creation, immutable after) rac_handle_t llm_handle; rac_handle_t stt_handle; rac_handle_t tts_handle; rac_handle_t vad_handle; - // Thread safety + // Thread safety — protects mutable operations (load, process, cleanup) std::mutex mutex; rac_voice_agent() - : is_configured(false), - owns_components(false), + : owns_components(false), llm_handle(nullptr), stt_handle(nullptr), tts_handle(nullptr), @@ -150,7 +156,10 @@ rac_result_t rac_voice_agent_create_standalone(rac_voice_agent_handle_t* out_han RAC_LOG_INFO("VoiceAgent", "Creating standalone voice agent"); - rac_voice_agent* agent = new rac_voice_agent(); + rac_voice_agent* agent = new (std::nothrow) rac_voice_agent(); + if (!agent) { + return RAC_ERROR_OUT_OF_MEMORY; + } agent->owns_components = true; // Create LLM component @@ -212,7 +221,10 @@ rac_result_t rac_voice_agent_create(rac_handle_t llm_component_handle, return RAC_ERROR_INVALID_ARGUMENT; } - rac_voice_agent* agent = new rac_voice_agent(); + rac_voice_agent* agent = new (std::nothrow) rac_voice_agent(); + if (!agent) { + return RAC_ERROR_OUT_OF_MEMORY; + } agent->owns_components = false; // External handles, don't destroy them agent->llm_handle = llm_component_handle; agent->stt_handle = stt_component_handle; @@ -230,19 +242,32 @@ void rac_voice_agent_destroy(rac_voice_agent_handle_t handle) { return; } - // If we own the components, destroy them - if (handle->owns_components) { - RAC_LOG_DEBUG("VoiceAgent", "Destroying owned component handles"); - if (handle->vad_handle) - rac_vad_component_destroy(handle->vad_handle); - if (handle->tts_handle) - rac_tts_component_destroy(handle->tts_handle); - if (handle->stt_handle) - rac_stt_component_destroy(handle->stt_handle); - if (handle->llm_handle) - rac_llm_component_destroy(handle->llm_handle); + // Signal shutdown and wait for all in-flight operations (including lock-free ones) + handle->is_shutting_down.store(true, std::memory_order_release); + handle->is_configured.store(false, std::memory_order_release); + + // Spin-wait until all in-flight operations complete + while (handle->in_flight.load(std::memory_order_acquire) > 0) { + std::this_thread::yield(); + } + + { + std::lock_guard lock(handle->mutex); + + if (handle->owns_components) { + RAC_LOG_DEBUG("VoiceAgent", "Destroying owned component handles"); + if (handle->vad_handle) + rac_vad_component_destroy(handle->vad_handle); + if (handle->tts_handle) + rac_tts_component_destroy(handle->tts_handle); + if (handle->stt_handle) + rac_stt_component_destroy(handle->stt_handle); + if (handle->llm_handle) + rac_llm_component_destroy(handle->llm_handle); + } } + // All threads that held/waited on mutex have now exited delete handle; RAC_LOG_DEBUG("VoiceAgent", "Voice agent destroyed"); } @@ -458,7 +483,7 @@ rac_result_t rac_voice_agent_initialize(rac_voice_agent_handle_t handle, // Step 5: Verify all components ready (mirrors Swift's verifyAllComponentsReady) // Note: In the C API, we trust initialization succeeded - handle->is_configured = true; + handle->is_configured.store(true, std::memory_order_release); RAC_LOG_INFO("VoiceAgent", "Voice Agent initialized successfully"); return RAC_SUCCESS; @@ -483,7 +508,7 @@ rac_result_t rac_voice_agent_initialize_with_loaded_models(rac_voice_agent_handl // Note: In C API, we trust that components are already initialized // The Swift version checks isModelLoaded properties - handle->is_configured = true; + handle->is_configured.store(true, std::memory_order_release); RAC_LOG_INFO("VoiceAgent", "Voice Agent initialized with pre-loaded models"); return RAC_SUCCESS; @@ -506,7 +531,7 @@ rac_result_t rac_voice_agent_cleanup(rac_voice_agent_handle_t handle) { rac_vad_component_stop(handle->vad_handle); rac_vad_component_reset(handle->vad_handle); - handle->is_configured = false; + handle->is_configured.store(false, std::memory_order_release); return RAC_SUCCESS; } @@ -516,8 +541,8 @@ rac_result_t rac_voice_agent_is_ready(rac_voice_agent_handle_t handle, rac_bool_ return RAC_ERROR_INVALID_ARGUMENT; } - std::lock_guard lock(handle->mutex); - *out_is_ready = handle->is_configured ? RAC_TRUE : RAC_FALSE; + // Atomic read — no mutex needed for a simple state check + *out_is_ready = handle->is_configured.load(std::memory_order_acquire) ? RAC_TRUE : RAC_FALSE; return RAC_SUCCESS; } @@ -533,16 +558,16 @@ rac_result_t rac_voice_agent_process_voice_turn(rac_voice_agent_handle_t handle, return RAC_ERROR_INVALID_ARGUMENT; } + // Hold lock for the entire pipeline to prevent TOCTOU races. + // is_ready() uses the atomic is_configured (no mutex needed), and + // detect_speech() doesn't use the mutex, so this doesn't block them. std::lock_guard lock(handle->mutex); - // Mirrors Swift's guard isConfigured - if (!handle->is_configured) { + if (!handle->is_configured.load(std::memory_order_acquire)) { RAC_LOG_ERROR("VoiceAgent", "Voice Agent is not initialized"); return RAC_ERROR_NOT_INITIALIZED; } - // Defensive validation: Verify all components are in LOADED state before processing - // This catches issues like dangling handles or improperly initialized components rac_result_t validation_result = validate_all_components_ready(handle); if (validation_result != RAC_SUCCESS) { RAC_LOG_ERROR("VoiceAgent", "Component validation failed - cannot process"); @@ -554,13 +579,12 @@ rac_result_t rac_voice_agent_process_voice_turn(rac_voice_agent_handle_t handle, // Initialize result memset(out_result, 0, sizeof(rac_voice_agent_result_t)); - // Step 1: Transcribe audio (mirrors Swift's Step 1) + // Step 1: Transcribe audio RAC_LOG_DEBUG("VoiceAgent", "Step 1: Transcribing audio"); - rac_stt_result_t stt_result = {}; - rac_result_t result = rac_stt_component_transcribe(handle->stt_handle, audio_data, audio_size, - nullptr, // default options - &stt_result); + rac_result_t result; + result = rac_stt_component_transcribe(handle->stt_handle, audio_data, audio_size, + nullptr, &stt_result); if (result != RAC_SUCCESS) { RAC_LOG_ERROR("VoiceAgent", "STT transcription failed"); @@ -570,19 +594,16 @@ rac_result_t rac_voice_agent_process_voice_turn(rac_voice_agent_handle_t handle, if (!stt_result.text || strlen(stt_result.text) == 0) { RAC_LOG_WARNING("VoiceAgent", "Empty transcription, skipping processing"); rac_stt_result_free(&stt_result); - // Return invalid state to indicate empty input (mirrors Swift's emptyInput error) return RAC_ERROR_INVALID_STATE; } RAC_LOG_INFO("VoiceAgent", "Transcription completed"); - // Step 2: Generate LLM response (mirrors Swift's Step 2) + // Step 2: Generate LLM response RAC_LOG_DEBUG("VoiceAgent", "Step 2: Generating LLM response"); - rac_llm_result_t llm_result = {}; result = rac_llm_component_generate(handle->llm_handle, stt_result.text, - nullptr, // default options - &llm_result); + nullptr, &llm_result); if (result != RAC_SUCCESS) { RAC_LOG_ERROR("VoiceAgent", "LLM generation failed"); @@ -592,13 +613,11 @@ rac_result_t rac_voice_agent_process_voice_turn(rac_voice_agent_handle_t handle, RAC_LOG_INFO("VoiceAgent", "LLM response generated"); - // Step 3: Synthesize speech (mirrors Swift's Step 3) + // Step 3: Synthesize speech RAC_LOG_DEBUG("VoiceAgent", "Step 3: Synthesizing speech"); - rac_tts_result_t tts_result = {}; result = rac_tts_component_synthesize(handle->tts_handle, llm_result.text, - nullptr, // default options - &tts_result); + nullptr, &tts_result); if (result != RAC_SUCCESS) { RAC_LOG_ERROR("VoiceAgent", "TTS synthesis failed"); @@ -607,9 +626,8 @@ rac_result_t rac_voice_agent_process_voice_turn(rac_voice_agent_handle_t handle, return result; } - // Step 4: Convert Float32 PCM to WAV format for playback - // Platform TTS (e.g. System TTS) plays audio directly and returns no PCM data. - // Only convert when actual audio data is returned (e.g. Piper/ONNX TTS). + + // Step 4: Convert Float32 PCM to WAV format — no lock needed (pure computation) void* wav_data = nullptr; size_t wav_size = 0; @@ -639,7 +657,7 @@ rac_result_t rac_voice_agent_process_voice_turn(rac_voice_agent_handle_t handle, out_result->synthesized_audio = wav_data; out_result->synthesized_audio_size = wav_size; - // Free intermediate results (tts_result audio data is no longer needed since we have WAV) + // Free intermediate results rac_stt_result_free(&stt_result); rac_llm_result_free(&llm_result); rac_tts_result_free(&tts_result); @@ -657,9 +675,10 @@ rac_result_t rac_voice_agent_process_stream(rac_voice_agent_handle_t handle, con return RAC_ERROR_INVALID_ARGUMENT; } + // Hold lock for the entire pipeline to prevent TOCTOU races. std::lock_guard lock(handle->mutex); - if (!handle->is_configured) { + if (!handle->is_configured.load(std::memory_order_acquire)) { rac_voice_agent_event_t error_event = {}; error_event.type = RAC_VOICE_AGENT_EVENT_ERROR; error_event.data.error_code = RAC_ERROR_NOT_INITIALIZED; @@ -667,7 +686,6 @@ rac_result_t rac_voice_agent_process_stream(rac_voice_agent_handle_t handle, con return RAC_ERROR_NOT_INITIALIZED; } - // Defensive validation: Verify all components are in LOADED state before processing rac_result_t validation_result = validate_all_components_ready(handle); if (validation_result != RAC_SUCCESS) { RAC_LOG_ERROR("VoiceAgent", "Component validation failed - cannot process stream"); @@ -680,8 +698,9 @@ rac_result_t rac_voice_agent_process_stream(rac_voice_agent_handle_t handle, con // Step 1: Transcribe rac_stt_result_t stt_result = {}; - rac_result_t result = rac_stt_component_transcribe(handle->stt_handle, audio_data, audio_size, - nullptr, &stt_result); + rac_result_t result; + result = rac_stt_component_transcribe(handle->stt_handle, audio_data, audio_size, + nullptr, &stt_result); if (result != RAC_SUCCESS) { rac_voice_agent_event_t error_event = {}; @@ -718,8 +737,7 @@ rac_result_t rac_voice_agent_process_stream(rac_voice_agent_handle_t handle, con // Step 3: Synthesize rac_tts_result_t tts_result = {}; - result = - rac_tts_component_synthesize(handle->tts_handle, llm_result.text, nullptr, &tts_result); + result = rac_tts_component_synthesize(handle->tts_handle, llm_result.text, nullptr, &tts_result); if (result != RAC_SUCCESS) { rac_stt_result_free(&stt_result); @@ -730,9 +748,7 @@ rac_result_t rac_voice_agent_process_stream(rac_voice_agent_handle_t handle, con callback(&error_event, user_data); return result; } - - // Step 4: Convert Float32 PCM to WAV format for playback - // Platform TTS plays audio directly and returns no PCM data — skip conversion. + // Step 4: Convert Float32 PCM to WAV — no lock needed (pure computation) void* wav_data = nullptr; size_t wav_size = 0; @@ -754,24 +770,40 @@ rac_result_t rac_voice_agent_process_stream(rac_voice_agent_handle_t handle, con } } - // Emit audio synthesized event (with WAV data, or empty for platform TTS) + + // Emit audio synthesized event rac_voice_agent_event_t audio_event = {}; audio_event.type = RAC_VOICE_AGENT_EVENT_AUDIO_SYNTHESIZED; audio_event.data.audio.audio_data = wav_data; audio_event.data.audio.audio_size = wav_size; callback(&audio_event, user_data); + // Copy wav_data for the processed event so each callback gets independent memory + void* wav_copy = nullptr; + if (wav_data && wav_size > 0) { + wav_copy = malloc(wav_size); + if (wav_copy) { + memcpy(wav_copy, wav_data, wav_size); + } + } + // Emit final processed event rac_voice_agent_event_t processed_event = {}; processed_event.type = RAC_VOICE_AGENT_EVENT_PROCESSED; processed_event.data.result.speech_detected = RAC_TRUE; processed_event.data.result.transcription = rac_strdup(stt_result.text); processed_event.data.result.response = rac_strdup(llm_result.text); - processed_event.data.result.synthesized_audio = wav_data; - processed_event.data.result.synthesized_audio_size = wav_size; + processed_event.data.result.synthesized_audio = wav_copy; + processed_event.data.result.synthesized_audio_size = wav_copy ? wav_size : 0; callback(&processed_event, user_data); - // Free intermediate results (WAV data ownership transferred to processed_event) + // Free event-owned allocations (callback has consumed the data) + free(processed_event.data.result.transcription); + free(processed_event.data.result.response); + free(wav_copy); + free(wav_data); + + // Free intermediate results rac_stt_result_free(&stt_result); rac_llm_result_free(&llm_result); rac_tts_result_free(&tts_result); @@ -791,7 +823,7 @@ rac_result_t rac_voice_agent_transcribe(rac_voice_agent_handle_t handle, const v std::lock_guard lock(handle->mutex); - if (!handle->is_configured) { + if (!handle->is_configured.load(std::memory_order_acquire)) { return RAC_ERROR_NOT_INITIALIZED; } @@ -817,7 +849,7 @@ rac_result_t rac_voice_agent_generate_response(rac_voice_agent_handle_t handle, std::lock_guard lock(handle->mutex); - if (!handle->is_configured) { + if (!handle->is_configured.load(std::memory_order_acquire)) { return RAC_ERROR_NOT_INITIALIZED; } @@ -843,7 +875,7 @@ rac_result_t rac_voice_agent_synthesize_speech(rac_voice_agent_handle_t handle, std::lock_guard lock(handle->mutex); - if (!handle->is_configured) { + if (!handle->is_configured.load(std::memory_order_acquire)) { return RAC_ERROR_NOT_INITIALIZED; } @@ -887,10 +919,23 @@ rac_result_t rac_voice_agent_detect_speech(rac_voice_agent_handle_t handle, cons return RAC_ERROR_INVALID_ARGUMENT; } + // Check shutdown barrier (this is a lock-free path) + if (handle->is_shutting_down.load(std::memory_order_acquire)) { + return RAC_ERROR_INVALID_STATE; + } + handle->in_flight.fetch_add(1, std::memory_order_acq_rel); + + // Re-check after incrementing to avoid TOCTOU with destroy + if (handle->is_shutting_down.load(std::memory_order_acquire)) { + handle->in_flight.fetch_sub(1, std::memory_order_acq_rel); + return RAC_ERROR_INVALID_STATE; + } + // VAD doesn't require is_configured (mirrors Swift) rac_result_t result = rac_vad_component_process(handle->vad_handle, samples, sample_count, out_speech_detected); + handle->in_flight.fetch_sub(1, std::memory_order_acq_rel); return result; } diff --git a/sdk/runanywhere-commons/src/features/wakeword/wakeword_service.cpp b/sdk/runanywhere-commons/src/features/wakeword/wakeword_service.cpp index fbf362281..469ba9895 100644 --- a/sdk/runanywhere-commons/src/features/wakeword/wakeword_service.cpp +++ b/sdk/runanywhere-commons/src/features/wakeword/wakeword_service.cpp @@ -130,7 +130,13 @@ RAC_API rac_result_t rac_wakeword_initialize(rac_handle_t handle, // Calculate samples per frame service->samples_per_frame = - (service->config.sample_rate * service->config.frame_length_ms) / 1000; + (static_cast(service->config.sample_rate) * service->config.frame_length_ms) / 1000; + + if (service->samples_per_frame == 0) { + RAC_LOG_ERROR("WakeWord", "Invalid config: samples_per_frame is 0 (sample_rate=%d, frame_length_ms=%d)", + service->config.sample_rate, service->config.frame_length_ms); + return RAC_ERROR_INVALID_ARGUMENT; + } // Reserve audio buffer service->audio_buffer.reserve(service->samples_per_frame * 2); @@ -503,6 +509,9 @@ RAC_API rac_result_t rac_wakeword_process(rac_handle_t handle, void* det_ud = nullptr; rac_wakeword_event_t event = {}; bool should_invoke_detection = false; + // Local copies to outlive lock release (c_str() would dangle) + std::string event_keyword_name; + std::string event_model_id; if (detected && keyword_index >= 0) { int64_t now = get_timestamp_ms(); @@ -521,9 +530,11 @@ RAC_API rac_result_t rac_wakeword_process(rac_handle_t handle, if (service->detection_callback && keyword_index < (int32_t)service->models.size()) { det_cb = service->detection_callback; det_ud = service->detection_user_data; + event_keyword_name = service->models[keyword_index].wake_word; + event_model_id = service->models[keyword_index].model_id; event.keyword_index = keyword_index; - event.keyword_name = service->models[keyword_index].wake_word.c_str(); - event.model_id = service->models[keyword_index].model_id.c_str(); + event.keyword_name = event_keyword_name.c_str(); + event.model_id = event_model_id.c_str(); event.confidence = confidence; event.timestamp_ms = now - service->stream_start_time; event.duration_ms = service->config.frame_length_ms; diff --git a/sdk/runanywhere-commons/src/infrastructure/download/download_manager.cpp b/sdk/runanywhere-commons/src/infrastructure/download/download_manager.cpp index 3f58fc307..9065e7032 100644 --- a/sdk/runanywhere-commons/src/infrastructure/download/download_manager.cpp +++ b/sdk/runanywhere-commons/src/infrastructure/download/download_manager.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -120,7 +121,10 @@ rac_result_t rac_download_manager_create(const rac_download_config_t* config, return RAC_ERROR_INVALID_ARGUMENT; } - rac_download_manager* mgr = new rac_download_manager(); + rac_download_manager* mgr = new (std::nothrow) rac_download_manager(); + if (!mgr) { + return RAC_ERROR_OUT_OF_MEMORY; + } // Initialize config if (config) { @@ -482,6 +486,69 @@ rac_result_t rac_download_manager_mark_failed(rac_download_manager_handle_t hand return RAC_SUCCESS; } +// ============================================================================= +// PUBLIC API - EXTRACTION COMPLETION +// ============================================================================= + +rac_result_t rac_download_manager_mark_extraction_complete(rac_download_manager_handle_t handle, + const char* task_id, + const char* extracted_path) { + if (!handle || !task_id || !extracted_path) { + return RAC_ERROR_INVALID_ARGUMENT; + } + + std::lock_guard lock(handle->mutex); + + auto it = handle->tasks.find(task_id); + if (it == handle->tasks.end()) { + return RAC_ERROR_NOT_FOUND; + } + + download_task_internal& task = it->second; + + task.progress.state = RAC_DOWNLOAD_STATE_COMPLETED; + task.progress.stage = RAC_DOWNLOAD_STAGE_COMPLETED; + task.progress.stage_progress = 1.0; + task.progress.overall_progress = 1.0; + notify_progress(task); + notify_complete(task, RAC_SUCCESS, extracted_path); + + RAC_LOG_INFO("DownloadManager", "Extraction completed for task: %s", task_id); + + return RAC_SUCCESS; +} + +rac_result_t rac_download_manager_mark_extraction_failed(rac_download_manager_handle_t handle, + const char* task_id, + rac_result_t error_code, + const char* error_message) { + if (!handle || !task_id) { + return RAC_ERROR_INVALID_ARGUMENT; + } + + std::lock_guard lock(handle->mutex); + + auto it = handle->tasks.find(task_id); + if (it == handle->tasks.end()) { + return RAC_ERROR_NOT_FOUND; + } + + download_task_internal& task = it->second; + + task.progress.state = RAC_DOWNLOAD_STATE_FAILED; + task.progress.error_code = error_code; + if (error_message) { + task.error_message = error_message; + task.progress.error_message = task.error_message.c_str(); + } + notify_progress(task); + notify_complete(task, error_code, nullptr); + + RAC_LOG_ERROR("DownloadManager", "Extraction failed for task: %s", task_id); + + return RAC_SUCCESS; +} + // ============================================================================= // PUBLIC API - STAGE INFO // ============================================================================= diff --git a/sdk/runanywhere-commons/src/infrastructure/download/download_orchestrator.cpp b/sdk/runanywhere-commons/src/infrastructure/download/download_orchestrator.cpp new file mode 100644 index 000000000..b8cbebb63 --- /dev/null +++ b/sdk/runanywhere-commons/src/infrastructure/download/download_orchestrator.cpp @@ -0,0 +1,807 @@ +/** + * @file download_orchestrator.cpp + * @brief Download Orchestrator - High-Level Model Download Lifecycle Management + * + * Consolidates download business logic from Swift/Kotlin/RN/Flutter SDKs into C++. + * Each SDK now only provides the HTTP transport callback and calls rac_download_orchestrate(). + * + * Full lifecycle: + * 1. Compute destination path (temp if extraction needed, final if not) + * 2. Start HTTP download via platform adapter (rac_http_download) + * 3. On HTTP completion: + * a. If extraction needed → rac_extract_archive_native → find model path → cleanup archive + * b. Update download manager state + * 4. Invoke user's complete_callback with final model path + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "rac/core/rac_logger.h" +#include "rac/core/rac_platform_adapter.h" +#include "rac/infrastructure/download/rac_download.h" +#include "rac/infrastructure/download/rac_download_orchestrator.h" +#include "rac/infrastructure/extraction/rac_extraction.h" +#include "rac/infrastructure/model_management/rac_model_paths.h" +#include "rac/infrastructure/model_management/rac_model_types.h" + +static const char* LOG_TAG = "DownloadOrchestrator"; + +// ============================================================================= +// INTERNAL HELPERS +// ============================================================================= + +/** + * Get file extension from a URL/path string (without dot). + * Handles compound extensions like .tar.gz, .tar.bz2, .tar.xz. + */ +static std::string get_file_extension(const char* url) { + if (!url) return ""; + + std::string path(url); + + // Strip query string and fragment + auto query_pos = path.find('?'); + if (query_pos != std::string::npos) path = path.substr(0, query_pos); + auto frag_pos = path.find('#'); + if (frag_pos != std::string::npos) path = path.substr(0, frag_pos); + + // Find the last path component + auto slash_pos = path.rfind('/'); + std::string filename = (slash_pos != std::string::npos) ? path.substr(slash_pos + 1) : path; + + // Check for compound extensions first + if (filename.length() > 7) { + std::string lower = filename; + for (auto& c : lower) c = static_cast(tolower(c)); + + if (lower.rfind(".tar.gz") == lower.length() - 7) return "tar.gz"; + if (lower.rfind(".tar.bz2") == lower.length() - 8) return "tar.bz2"; + if (lower.rfind(".tar.xz") == lower.length() - 7) return "tar.xz"; + if (lower.rfind(".tgz") == lower.length() - 4) return "tar.gz"; + if (lower.rfind(".tbz2") == lower.length() - 5) return "tar.bz2"; + if (lower.rfind(".txz") == lower.length() - 4) return "tar.xz"; + } + + // Simple extension + auto dot_pos = filename.rfind('.'); + if (dot_pos != std::string::npos && dot_pos < filename.length() - 1) { + return filename.substr(dot_pos + 1); + } + + return ""; +} + +/** + * Get the filename (without extension) from a URL. + */ +static std::string get_filename_stem(const char* url) { + if (!url) return ""; + + std::string path(url); + auto query_pos = path.find('?'); + if (query_pos != std::string::npos) path = path.substr(0, query_pos); + + auto slash_pos = path.rfind('/'); + std::string filename = (slash_pos != std::string::npos) ? path.substr(slash_pos + 1) : path; + + // Strip compound extensions + std::string lower = filename; + for (auto& c : lower) c = static_cast(tolower(c)); + + const char* compound_exts[] = {".tar.gz", ".tar.bz2", ".tar.xz", ".tgz", ".tbz2", ".txz"}; + for (const auto& ext : compound_exts) { + size_t ext_len = strlen(ext); + if (lower.length() > ext_len && lower.rfind(ext) == lower.length() - ext_len) { + return filename.substr(0, filename.length() - ext_len); + } + } + + // Strip simple extension + auto dot_pos = filename.rfind('.'); + if (dot_pos != std::string::npos) { + return filename.substr(0, dot_pos); + } + + return filename; +} + +/** + * Check if a file extension is a known model extension. + */ +static bool is_model_extension(const char* ext) { + if (!ext) return false; + // Compare case-insensitively + std::string lower(ext); + for (auto& c : lower) c = static_cast(tolower(c)); + + return lower == "gguf" || lower == "onnx" || lower == "ort" || lower == "bin" || + lower == "mlmodelc" || lower == "mlpackage"; +} + +/** + * Check if a directory exists. + */ +static bool dir_exists(const char* path) { + struct stat st; + return stat(path, &st) == 0 && S_ISDIR(st.st_mode); +} + +/** + * Create directories recursively (like mkdir -p). + */ +static bool mkdir_p(const char* path) { + if (dir_exists(path)) return true; + + std::string s(path); + std::string::size_type pos = 0; + + while ((pos = s.find('/', pos + 1)) != std::string::npos) { + std::string sub = s.substr(0, pos); + if (!sub.empty()) { + mkdir(sub.c_str(), 0755); + } + } + return mkdir(s.c_str(), 0755) == 0 || dir_exists(path); +} + +/** + * Delete a file. + */ +static void delete_file(const char* path) { + if (path) { + remove(path); + } +} + +// ============================================================================= +// POST-EXTRACTION MODEL PATH FINDING (ported from Swift ExtractionService) +// ============================================================================= + +/** + * Find a single model file in a directory, searching recursively up to max_depth levels. + * Ported from Swift's ExtractionService.findSingleModelFile(). + */ +static bool find_single_model_file(const char* directory, int depth, int max_depth, char* out_path, + size_t path_size) { + if (depth >= max_depth) return false; + + DIR* dir = opendir(directory); + if (!dir) return false; + + struct dirent* entry; + std::string found_model; + std::vector subdirs; + + while ((entry = readdir(dir)) != nullptr) { + // Skip . and .. + if (strcmp(entry->d_name, ".") == 0 || strcmp(entry->d_name, "..") == 0) continue; + // Skip hidden files and macOS resource forks + if (entry->d_name[0] == '.') continue; + + std::string full_path = std::string(directory) + "/" + entry->d_name; + + struct stat st; + if (stat(full_path.c_str(), &st) != 0) continue; + + if (S_ISREG(st.st_mode)) { + // Check if this is a model file + const char* dot = strrchr(entry->d_name, '.'); + if (dot && is_model_extension(dot + 1)) { + found_model = full_path; + break; // Found it + } + } else if (S_ISDIR(st.st_mode)) { + subdirs.push_back(full_path); + } + } + closedir(dir); + + if (!found_model.empty()) { + snprintf(out_path, path_size, "%s", found_model.c_str()); + return true; + } + + // Recursively check subdirectories + for (const auto& subdir : subdirs) { + if (find_single_model_file(subdir.c_str(), depth + 1, max_depth, out_path, path_size)) { + return true; + } + } + + return false; +} + +/** + * Find the nested directory (single visible subdirectory) in an extracted archive. + * Ported from Swift's ExtractionService.findNestedDirectory(). + * + * Common pattern: archive contains one subdirectory with all the files. + * e.g., sherpa-onnx archives extract to: extractedDir/vits-xxx/ + */ +static std::string find_nested_directory(const char* extracted_dir) { + DIR* dir = opendir(extracted_dir); + if (!dir) return extracted_dir; + + struct dirent* entry; + std::vector visible_dirs; + + while ((entry = readdir(dir)) != nullptr) { + if (strcmp(entry->d_name, ".") == 0 || strcmp(entry->d_name, "..") == 0) continue; + // Skip hidden files and macOS resource forks + if (entry->d_name[0] == '.') continue; + if (strncmp(entry->d_name, "._", 2) == 0) continue; + + std::string full_path = std::string(extracted_dir) + "/" + entry->d_name; + + struct stat st; + if (stat(full_path.c_str(), &st) == 0 && S_ISDIR(st.st_mode)) { + visible_dirs.push_back(full_path); + } + } + closedir(dir); + + // If there's exactly one visible subdirectory, return it + if (visible_dirs.size() == 1) { + return visible_dirs[0]; + } + + if (visible_dirs.size() > 1) { + RAC_LOG_WARNING(LOG_TAG, + "find_nested_directory: found %zu subdirectories in '%s', " + "falling back to root (expected exactly 1)", + visible_dirs.size(), extracted_dir); + } + + return extracted_dir; +} + +// ============================================================================= +// ORCHESTRATION CONTEXT (passed through HTTP callbacks) +// ============================================================================= + +struct orchestrate_context { + // Download manager handle + rac_download_manager_handle_t dm_handle; + + // Model info + std::string model_id; + std::string download_url; + rac_inference_framework_t framework; + rac_model_format_t format; + rac_archive_structure_t archive_structure; + + // Paths + std::string download_dest_path; // Where HTTP downloads to + std::string model_folder_path; // Final model folder + bool needs_extraction; + + // Task tracking + std::string task_id; + + // User callbacks + rac_download_progress_callback_fn user_progress_callback; + rac_download_complete_callback_fn user_complete_callback; + void* user_data; +}; + +/** + * Prevent double-free of orchestrate_context when async callbacks race with error paths. + * + * The context is wrapped in a shared_ptr stored in a shared_ctx_holder. + * The holder is passed as raw void* to C callbacks. + * Both the caller and the callback own a reference via the shared_ptr, + * ensuring the context outlives all users. + */ +struct shared_ctx_holder { + std::shared_ptr ctx; +}; + +/** + * HTTP progress callback — forwards to download manager which recalculates overall progress. + */ +static void orchestrate_http_progress(int64_t bytes_downloaded, int64_t total_bytes, + void* callback_user_data) { + auto* holder = static_cast(callback_user_data); + if (!holder || !holder->ctx || !holder->ctx->dm_handle) return; + + auto& ctx = holder->ctx; + rac_download_manager_update_progress(ctx->dm_handle, ctx->task_id.c_str(), bytes_downloaded, + total_bytes); +} + +/** + * HTTP completion callback — handles post-download extraction and cleanup. + * Deletes the holder (releasing its shared_ptr reference) when done. + */ +static void orchestrate_http_complete(rac_result_t result, const char* downloaded_path, + void* callback_user_data) { + auto* holder = static_cast(callback_user_data); + if (!holder || !holder->ctx) { + delete holder; + return; + } + + // Take ownership — holder is deleted at every exit path below + auto ctx = holder->ctx; + delete holder; + + if (result != RAC_SUCCESS) { + // HTTP download failed + RAC_LOG_ERROR(LOG_TAG, "HTTP download failed for model: %s", ctx->model_id.c_str()); + rac_download_manager_mark_failed(ctx->dm_handle, ctx->task_id.c_str(), result, + "HTTP download failed"); + + if (ctx->user_complete_callback) { + ctx->user_complete_callback(ctx->task_id.c_str(), result, nullptr, ctx->user_data); + } + return; + } + + std::string final_path; + + if (ctx->needs_extraction) { + // Mark download as complete (transitions to EXTRACTING state) + rac_download_manager_mark_complete(ctx->dm_handle, ctx->task_id.c_str(), + downloaded_path ? downloaded_path + : ctx->download_dest_path.c_str()); + + RAC_LOG_INFO(LOG_TAG, "Starting extraction for model: %s", ctx->model_id.c_str()); + + // Extract archive using native libarchive + rac_extraction_result_t extraction_result = {}; + rac_result_t extract_result = rac_extract_archive_native( + downloaded_path ? downloaded_path : ctx->download_dest_path.c_str(), + ctx->model_folder_path.c_str(), nullptr /* default options */, nullptr /* no progress */, + nullptr /* no user data */, &extraction_result); + + if (extract_result != RAC_SUCCESS) { + RAC_LOG_ERROR(LOG_TAG, "Extraction failed for model: %s", ctx->model_id.c_str()); + rac_download_manager_mark_extraction_failed(ctx->dm_handle, ctx->task_id.c_str(), + extract_result, "Archive extraction failed"); + + if (ctx->user_complete_callback) { + ctx->user_complete_callback(ctx->task_id.c_str(), extract_result, nullptr, + ctx->user_data); + } + + // Cleanup temp archive + delete_file(ctx->download_dest_path.c_str()); + return; + } + + RAC_LOG_INFO(LOG_TAG, "Extraction complete: %d files, %lld bytes", + extraction_result.files_extracted, extraction_result.bytes_extracted); + + // Find the actual model path after extraction + char model_path[4096]; + rac_result_t find_result = rac_find_model_path_after_extraction( + ctx->model_folder_path.c_str(), ctx->archive_structure, ctx->framework, ctx->format, + model_path, sizeof(model_path)); + + if (find_result == RAC_SUCCESS) { + final_path = model_path; + } else { + // Fallback to model folder itself + final_path = ctx->model_folder_path; + RAC_LOG_WARNING( + LOG_TAG, + "Could not find specific model file after extraction, using folder: %s", + final_path.c_str()); + } + + // Cleanup temp archive file + delete_file(ctx->download_dest_path.c_str()); + + // Mark extraction complete + rac_download_manager_mark_extraction_complete(ctx->dm_handle, ctx->task_id.c_str(), + final_path.c_str()); + } else { + // No extraction needed — file downloaded directly to model folder + final_path = + downloaded_path ? std::string(downloaded_path) : ctx->download_dest_path; + + rac_download_manager_mark_complete(ctx->dm_handle, ctx->task_id.c_str(), + final_path.c_str()); + } + + RAC_LOG_INFO(LOG_TAG, "Download orchestration complete for model: %s → %s", + ctx->model_id.c_str(), final_path.c_str()); + + // Invoke user callback + if (ctx->user_complete_callback) { + ctx->user_complete_callback(ctx->task_id.c_str(), RAC_SUCCESS, final_path.c_str(), + ctx->user_data); + } +} + +// ============================================================================= +// PUBLIC API — DOWNLOAD ORCHESTRATION +// ============================================================================= + +rac_result_t rac_download_orchestrate(rac_download_manager_handle_t dm_handle, + const char* model_id, const char* download_url, + rac_inference_framework_t framework, + rac_model_format_t format, + rac_archive_structure_t archive_structure, + rac_download_progress_callback_fn progress_callback, + rac_download_complete_callback_fn complete_callback, + void* user_data, char** out_task_id) { + if (!dm_handle || !model_id || !download_url || !out_task_id) { + return RAC_ERROR_INVALID_ARGUMENT; + } + + // 1. Compute model folder path + char model_folder[4096]; + rac_result_t path_result = + rac_model_paths_get_model_folder(model_id, framework, model_folder, sizeof(model_folder)); + if (path_result != RAC_SUCCESS) { + RAC_LOG_ERROR(LOG_TAG, "Failed to compute model folder path for: %s", model_id); + return path_result; + } + + // Ensure model folder exists + mkdir_p(model_folder); + + // 2. Determine if extraction is needed + rac_archive_type_t archive_type; + bool needs_extraction = rac_archive_type_from_path(download_url, &archive_type) == RAC_TRUE; + + // 3. Compute download destination + std::string download_dest; + if (needs_extraction) { + // Download to temp path — will be extracted to model folder + char downloads_dir[4096]; + rac_result_t dl_result = + rac_model_paths_get_downloads_directory(downloads_dir, sizeof(downloads_dir)); + if (dl_result != RAC_SUCCESS) { + RAC_LOG_ERROR(LOG_TAG, "Failed to get downloads directory"); + return dl_result; + } + mkdir_p(downloads_dir); + + std::string ext = get_file_extension(download_url); + std::string stem = get_filename_stem(download_url); + if (stem.empty()) stem = model_id; + + download_dest = + std::string(downloads_dir) + "/" + stem + (ext.empty() ? "" : "." + ext); + } else { + // Download directly to model folder + std::string ext = get_file_extension(download_url); + std::string stem = get_filename_stem(download_url); + if (stem.empty()) stem = model_id; + + download_dest = + std::string(model_folder) + "/" + stem + (ext.empty() ? "" : "." + ext); + } + + // 4. Register with download manager (creates task tracking state) + char* task_id = nullptr; + rac_result_t start_result = rac_download_manager_start( + dm_handle, model_id, download_url, download_dest.c_str(), + needs_extraction ? RAC_TRUE : RAC_FALSE, progress_callback, nullptr /* we handle complete */, + user_data, &task_id); + + if (start_result != RAC_SUCCESS) { + RAC_LOG_ERROR(LOG_TAG, "Failed to register download task for: %s", model_id); + return start_result; + } + + // 5. Create orchestration context for callbacks (shared_ptr for safe async lifetime) + auto ctx = std::make_shared(); + ctx->dm_handle = dm_handle; + ctx->model_id = model_id; + ctx->download_url = download_url; + ctx->framework = framework; + ctx->format = format; + ctx->archive_structure = archive_structure; + ctx->download_dest_path = download_dest; + ctx->model_folder_path = model_folder; + ctx->needs_extraction = needs_extraction; + ctx->task_id = task_id; + ctx->user_progress_callback = progress_callback; + ctx->user_complete_callback = complete_callback; + ctx->user_data = user_data; + + // Wrap in holder for C callback void* — callback takes ownership and deletes holder + auto* holder = new shared_ctx_holder{ctx}; + + // 6. Start HTTP download via platform adapter + char* http_task_id = nullptr; + rac_result_t http_result = + rac_http_download(download_url, download_dest.c_str(), orchestrate_http_progress, + orchestrate_http_complete, holder, &http_task_id); + + if (http_result != RAC_SUCCESS) { + RAC_LOG_ERROR(LOG_TAG, "Failed to start HTTP download for: %s", model_id); + rac_download_manager_mark_failed(dm_handle, task_id, http_result, + "Failed to start HTTP download"); + delete holder; // Safe — ctx shared_ptr ref still alive until scope exit + rac_free(task_id); + return http_result; + } + + if (http_task_id) { + rac_free(http_task_id); // We track via download manager task_id instead + } + + *out_task_id = task_id; + + RAC_LOG_INFO(LOG_TAG, "Download orchestration started: model=%s, extraction=%s", model_id, + needs_extraction ? "yes" : "no"); + + return RAC_SUCCESS; +} + +rac_result_t rac_download_orchestrate_multi( + rac_download_manager_handle_t dm_handle, const char* model_id, + const rac_model_file_descriptor_t* files, size_t file_count, const char* base_download_url, + rac_inference_framework_t framework, rac_model_format_t format, + rac_download_progress_callback_fn progress_callback, + rac_download_complete_callback_fn complete_callback, void* user_data, char** out_task_id) { + if (!dm_handle || !model_id || !files || file_count == 0 || !base_download_url || + !out_task_id) { + return RAC_ERROR_INVALID_ARGUMENT; + } + + // Compute model folder + char model_folder[4096]; + rac_result_t path_result = + rac_model_paths_get_model_folder(model_id, framework, model_folder, sizeof(model_folder)); + if (path_result != RAC_SUCCESS) { + return path_result; + } + mkdir_p(model_folder); + + // Register a single task for the multi-file download + std::string composite_url = std::string(base_download_url) + " [" + + std::to_string(file_count) + " files]"; + char* task_id = nullptr; + rac_result_t start_result = rac_download_manager_start( + dm_handle, model_id, composite_url.c_str(), model_folder, RAC_FALSE /* no extraction */, + progress_callback, complete_callback, user_data, &task_id); + + if (start_result != RAC_SUCCESS) { + return start_result; + } + + // Shared state for async completion barrier across all file downloads. + // Each launched download increments pending; its callback decrements and notifies. + // After the loop we wait until all in-flight downloads have reported back. + struct multi_download_barrier { + std::mutex mtx; + std::condition_variable cv; + int pending{0}; + bool any_required_failed{false}; + }; + auto barrier = std::make_shared(); + + // Per-file context passed through the C callback void*. + struct multi_file_holder { + std::shared_ptr barrier; + bool is_required; + }; + + bool launch_failed = false; + for (size_t i = 0; i < file_count; ++i) { + const rac_model_file_descriptor_t& file = files[i]; + + // Build full download URL + std::string file_url = std::string(base_download_url); + if (!file_url.empty() && file_url.back() != '/') file_url += "/"; + file_url += file.relative_path; + + // Build destination path + std::string dest_path = std::string(model_folder); + if (file.destination_path && file.destination_path[0] != '\0') { + dest_path += "/" + std::string(file.destination_path); + } else { + dest_path += "/" + std::string(file.relative_path); + } + + // Ensure parent directory exists + auto last_slash = dest_path.rfind('/'); + if (last_slash != std::string::npos) { + mkdir_p(dest_path.substr(0, last_slash).c_str()); + } + + // Update download manager with file-level progress + int64_t fake_downloaded = static_cast( + static_cast(i) / static_cast(file_count) * 100); + rac_download_manager_update_progress(dm_handle, task_id, fake_downloaded, 100); + + // Increment pending count *before* launching so the barrier is always ahead of callbacks + { + std::lock_guard lk(barrier->mtx); + barrier->pending++; + } + + auto* file_holder = new multi_file_holder{barrier, file.is_required == RAC_TRUE}; + + auto file_complete = [](rac_result_t result, const char* /*path*/, void* ud) { + auto* holder = static_cast(ud); + if (!holder) return; + + auto b = holder->barrier; + bool required = holder->is_required; + delete holder; + + std::lock_guard lk(b->mtx); + if (result != RAC_SUCCESS && required) { + b->any_required_failed = true; + } + b->pending--; + b->cv.notify_all(); + }; + + char* http_task_id = nullptr; + rac_result_t http_result = rac_http_download( + file_url.c_str(), dest_path.c_str(), nullptr /* no per-file progress */, file_complete, + file_holder, &http_task_id); + + if (http_task_id) rac_free(http_task_id); + + if (http_result != RAC_SUCCESS) { + // Download never started — callback won't fire, so clean up manually + delete file_holder; + { + std::lock_guard lk(barrier->mtx); + barrier->pending--; // undo the pre-increment + } + + if (file.is_required == RAC_TRUE) { + RAC_LOG_ERROR(LOG_TAG, "Required file download failed to start: %s", + file.relative_path); + launch_failed = true; + break; + } + RAC_LOG_WARNING(LOG_TAG, "Optional file download failed to start: %s", + file.relative_path); + continue; + } + + // Download started — async callback owns file_holder + } + + // Wait for all in-flight downloads to complete before reporting final status + { + std::unique_lock lk(barrier->mtx); + barrier->cv.wait(lk, [&barrier] { return barrier->pending == 0; }); + } + + bool any_failed = launch_failed || barrier->any_required_failed; + + if (any_failed) { + rac_download_manager_mark_failed(dm_handle, task_id, RAC_ERROR_DOWNLOAD_FAILED, + "One or more required files failed to download"); + *out_task_id = task_id; + return RAC_ERROR_DOWNLOAD_FAILED; + } else { + // Update final progress + rac_download_manager_update_progress(dm_handle, task_id, 100, 100); + rac_download_manager_mark_complete(dm_handle, task_id, model_folder); + } + + *out_task_id = task_id; + return RAC_SUCCESS; +} + +// ============================================================================= +// PUBLIC API — POST-EXTRACTION MODEL PATH FINDING +// ============================================================================= + +rac_result_t rac_find_model_path_after_extraction(const char* extracted_dir, + rac_archive_structure_t structure, + rac_inference_framework_t framework, + rac_model_format_t format, char* out_path, + size_t path_size) { + if (!extracted_dir || !out_path || path_size == 0) { + return RAC_ERROR_INVALID_ARGUMENT; + } + + // For directory-based frameworks (ONNX), the directory itself is the model path + if (rac_framework_uses_directory_based_models(framework) == RAC_TRUE) { + // Check for nested directory pattern + std::string nested = find_nested_directory(extracted_dir); + snprintf(out_path, path_size, "%s", nested.c_str()); + return RAC_SUCCESS; + } + + // Handle based on archive structure + switch (structure) { + case RAC_ARCHIVE_STRUCTURE_SINGLE_FILE_NESTED: { + // Look for a single model file, possibly in a subdirectory (up to 2 levels deep) + if (find_single_model_file(extracted_dir, 0, 2, out_path, path_size)) { + return RAC_SUCCESS; + } + // Fallback: return extracted dir + snprintf(out_path, path_size, "%s", extracted_dir); + return RAC_SUCCESS; + } + + case RAC_ARCHIVE_STRUCTURE_NESTED_DIRECTORY: { + // Common pattern: archive contains one subdirectory with all the files + std::string nested = find_nested_directory(extracted_dir); + snprintf(out_path, path_size, "%s", nested.c_str()); + return RAC_SUCCESS; + } + + case RAC_ARCHIVE_STRUCTURE_DIRECTORY_BASED: + case RAC_ARCHIVE_STRUCTURE_UNKNOWN: + default: { + // Try to find a model file first + if (find_single_model_file(extracted_dir, 0, 2, out_path, path_size)) { + return RAC_SUCCESS; + } + // Check for nested directory + std::string nested = find_nested_directory(extracted_dir); + snprintf(out_path, path_size, "%s", nested.c_str()); + return RAC_SUCCESS; + } + } +} + +// ============================================================================= +// PUBLIC API — UTILITY FUNCTIONS +// ============================================================================= + +rac_result_t rac_download_compute_destination(const char* model_id, const char* download_url, + rac_inference_framework_t framework, + rac_model_format_t format, char* out_path, + size_t path_size, + rac_bool_t* out_needs_extraction) { + if (!model_id || !download_url || !out_path || path_size == 0 || !out_needs_extraction) { + return RAC_ERROR_INVALID_ARGUMENT; + } + + // Check if extraction is needed + rac_archive_type_t archive_type; + bool needs_extraction = rac_archive_type_from_path(download_url, &archive_type) == RAC_TRUE; + *out_needs_extraction = needs_extraction ? RAC_TRUE : RAC_FALSE; + + if (needs_extraction) { + // Temp path in downloads directory + char downloads_dir[4096]; + rac_result_t result = + rac_model_paths_get_downloads_directory(downloads_dir, sizeof(downloads_dir)); + if (result != RAC_SUCCESS) return result; + + std::string ext = get_file_extension(download_url); + std::string stem = get_filename_stem(download_url); + if (stem.empty()) stem = model_id; + + snprintf(out_path, path_size, "%s/%s%s%s", downloads_dir, stem.c_str(), + ext.empty() ? "" : ".", ext.empty() ? "" : ext.c_str()); + } else { + // Direct to model folder + char model_folder[4096]; + rac_result_t result = + rac_model_paths_get_model_folder(model_id, framework, model_folder, sizeof(model_folder)); + if (result != RAC_SUCCESS) return result; + + std::string ext = get_file_extension(download_url); + std::string stem = get_filename_stem(download_url); + if (stem.empty()) stem = model_id; + + snprintf(out_path, path_size, "%s/%s%s%s", model_folder, stem.c_str(), + ext.empty() ? "" : ".", ext.empty() ? "" : ext.c_str()); + } + + return RAC_SUCCESS; +} + +rac_bool_t rac_download_requires_extraction(const char* download_url) { + if (!download_url) return RAC_FALSE; + + rac_archive_type_t type; + return rac_archive_type_from_path(download_url, &type); +} diff --git a/sdk/runanywhere-commons/src/infrastructure/events/event_publisher.cpp b/sdk/runanywhere-commons/src/infrastructure/events/event_publisher.cpp index 1c831dfa2..f31bb1573 100644 --- a/sdk/runanywhere-commons/src/infrastructure/events/event_publisher.cpp +++ b/sdk/runanywhere-commons/src/infrastructure/events/event_publisher.cpp @@ -30,6 +30,7 @@ struct Subscription { uint64_t id; rac_event_callback_fn callback; void* user_data; + std::shared_ptr> alive; }; std::mutex g_event_mutex; @@ -41,9 +42,6 @@ std::unordered_map> g_subscripti // All-events subscriptions std::vector g_all_subscriptions; -// Sentinel category for "all events" -const rac_event_category_t CATEGORY_ALL_SENTINEL = static_cast(-1); - uint64_t current_time_ms() { using namespace std::chrono; return duration_cast(system_clock::now().time_since_epoch()).count(); @@ -80,6 +78,7 @@ uint64_t rac_event_subscribe(rac_event_category_t category, rac_event_callback_f sub.id = g_next_subscription_id.fetch_add(1); sub.callback = callback; sub.user_data = user_data; + sub.alive = std::make_shared>(true); g_subscriptions[category].push_back(sub); @@ -97,6 +96,7 @@ uint64_t rac_event_subscribe_all(rac_event_callback_fn callback, void* user_data sub.id = g_next_subscription_id.fetch_add(1); sub.callback = callback; sub.user_data = user_data; + sub.alive = std::make_shared>(true); g_all_subscriptions.push_back(sub); @@ -112,8 +112,12 @@ void rac_event_unsubscribe(uint64_t subscription_id) { auto remove_from = [subscription_id](std::vector& subs) { auto it = - std::remove_if(subs.begin(), subs.end(), [subscription_id](const Subscription& s) { - return s.id == subscription_id; + std::remove_if(subs.begin(), subs.end(), [subscription_id](Subscription& s) { + if (s.id == subscription_id) { + s.alive->store(false); + return true; + } + return false; }); if (it != subs.end()) { subs.erase(it, subs.end()); @@ -146,19 +150,33 @@ rac_result_t rac_event_publish(const rac_event_t* event) { event_copy.timestamp_ms = static_cast(current_time_ms()); } - std::lock_guard lock(g_event_mutex); + // Copy subscriber lists under lock, then invoke callbacks without lock + // to avoid deadlock if a callback subscribes/unsubscribes/publishes. + std::vector category_subs; + std::vector all_subs; + + { + std::lock_guard lock(g_event_mutex); + + auto it = g_subscriptions.find(event_copy.category); + if (it != g_subscriptions.end()) { + category_subs = it->second; + } + all_subs = g_all_subscriptions; + } - // Notify category-specific subscribers - auto it = g_subscriptions.find(event_copy.category); - if (it != g_subscriptions.end()) { - for (const auto& sub : it->second) { + // Notify category-specific subscribers (skip if unsubscribed after snapshot) + for (const auto& sub : category_subs) { + if (sub.alive->load()) { sub.callback(&event_copy, sub.user_data); } } - // Notify all-events subscribers - for (const auto& sub : g_all_subscriptions) { - sub.callback(&event_copy, sub.user_data); + // Notify all-events subscribers (skip if unsubscribed after snapshot) + for (const auto& sub : all_subs) { + if (sub.alive->load()) { + sub.callback(&event_copy, sub.user_data); + } } return RAC_SUCCESS; diff --git a/sdk/runanywhere-commons/src/infrastructure/extraction/rac_extraction.cpp b/sdk/runanywhere-commons/src/infrastructure/extraction/rac_extraction.cpp new file mode 100644 index 000000000..daed9a0dd --- /dev/null +++ b/sdk/runanywhere-commons/src/infrastructure/extraction/rac_extraction.cpp @@ -0,0 +1,386 @@ +/** + * @file rac_extraction.cpp + * @brief Native archive extraction implementation using libarchive. + * + * Streaming extraction with constant memory usage regardless of archive size. + * Supports ZIP, TAR.GZ, TAR.BZ2, TAR.XZ with auto-detection via magic bytes. + */ + +#include "rac/infrastructure/extraction/rac_extraction.h" + +#include +#include + +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include // for _mkdir +#endif + +#include "rac/core/rac_logger.h" + +static const char* kLogTag = "Extraction"; + +// ============================================================================= +// INTERNAL HELPERS +// ============================================================================= + +/** + * Security: Check for path traversal (zip-slip attack). + * Rejects absolute paths and paths containing ".." components. + */ +static bool is_path_safe(const char* pathname) { + if (!pathname || pathname[0] == '\0') return false; + + // Reject absolute paths (Unix) + if (pathname[0] == '/') return false; + + // Reject Windows UNC paths (\\server\share) + if (pathname[0] == '\\' && pathname[1] == '\\') return false; + + // Reject Windows drive letters (C:, D:, etc.) + if (((pathname[0] >= 'A' && pathname[0] <= 'Z') || + (pathname[0] >= 'a' && pathname[0] <= 'z')) && + pathname[1] == ':') { + return false; + } + + // Normalize and check for ".." components (handle both / and \ separators) + const char* p = pathname; + while (*p) { + if (p[0] == '.' && p[1] == '.') { + bool at_start = (p == pathname || *(p - 1) == '/' || *(p - 1) == '\\'); + bool at_end = (p[2] == '/' || p[2] == '\\' || p[2] == '\0'); + if (at_start && at_end) { + return false; + } + } + p++; + } + return true; +} + +/** + * Check if an entry should be skipped (macOS resource forks, etc.). + */ +static bool should_skip_entry(const char* pathname, rac_bool_t skip_macos) { + if (!pathname || pathname[0] == '\0') return true; + + if (skip_macos) { + // Skip __MACOSX/ directory and its contents + if (strstr(pathname, "__MACOSX") != nullptr) return true; + + // Skip ._ resource fork files + const char* basename = strrchr(pathname, '/'); + basename = basename ? basename + 1 : pathname; + if (basename[0] == '.' && basename[1] == '_') return true; + } + return false; +} + +/** + * Create a directory and all intermediate directories. + * Equivalent to `mkdir -p`. + */ +static rac_result_t create_directories(const std::string& path) { + if (path.empty()) return RAC_SUCCESS; + + std::string current; + for (size_t i = 0; i < path.size(); i++) { + current += path[i]; + if (path[i] == '/' || i == path.size() - 1) { + if (current == "/") continue; +#ifdef _WIN32 + int ret = _mkdir(current.c_str()); +#else + int ret = mkdir(current.c_str(), 0755); +#endif + if (ret != 0 && errno != EEXIST) { + // Check if it already exists as a directory + struct stat st; + if (stat(current.c_str(), &st) != 0 || !S_ISDIR(st.st_mode)) { + return RAC_ERROR_EXTRACTION_FAILED; + } + } + } + } + return RAC_SUCCESS; +} + +/** + * Ensure trailing slash on directory path. + */ +static std::string ensure_trailing_slash(const std::string& path) { + if (path.empty() || path.back() == '/') return path; + return path + '/'; +} + +// ============================================================================= +// PUBLIC API - rac_extract_archive_native +// ============================================================================= + +rac_result_t rac_extract_archive_native(const char* archive_path, const char* destination_dir, + const rac_extraction_options_t* options, + rac_extraction_progress_fn progress_callback, + void* user_data, rac_extraction_result_t* out_result) { + if (!archive_path || !destination_dir) { + return RAC_ERROR_NULL_POINTER; + } + + // Check archive file exists + struct stat archive_stat; + if (stat(archive_path, &archive_stat) != 0) { + RAC_LOG_ERROR(kLogTag, "Archive file not found: %s", archive_path); + return RAC_ERROR_FILE_NOT_FOUND; + } + + // Use defaults if no options provided + rac_extraction_options_t opts = + options ? *options : RAC_EXTRACTION_OPTIONS_DEFAULT; + + // Create destination directory + rac_result_t dir_result = create_directories(destination_dir); + if (RAC_FAILED(dir_result)) { + RAC_LOG_ERROR(kLogTag, "Failed to create destination directory: %s", destination_dir); + return RAC_ERROR_EXTRACTION_FAILED; + } + + std::string dest_dir = ensure_trailing_slash(destination_dir); + + RAC_LOG_INFO(kLogTag, "Extracting archive: %s -> %s", archive_path, destination_dir); + + // Open archive for reading (streaming) + struct archive* a = archive_read_new(); + if (!a) { + RAC_LOG_ERROR(kLogTag, "Failed to allocate archive reader"); + return RAC_ERROR_EXTRACTION_FAILED; + } + + // Enable all supported formats and filters for auto-detection + archive_read_support_format_all(a); + archive_read_support_filter_all(a); + + // Open the archive file with 10KB block size (streaming) + int r = archive_read_open_filename(a, archive_path, 10240); + if (r != ARCHIVE_OK) { + const char* err = archive_error_string(a); + RAC_LOG_ERROR(kLogTag, "Failed to open archive: %s (%s)", archive_path, + err ? err : "unknown error"); + archive_read_free(a); + return RAC_ERROR_UNSUPPORTED_ARCHIVE; + } + + // Prepare disk writer for extraction + struct archive* ext = archive_write_disk_new(); + if (!ext) { + RAC_LOG_ERROR(kLogTag, "Failed to allocate disk writer"); + archive_read_free(a); + return RAC_ERROR_EXTRACTION_FAILED; + } + + // Set extraction flags: preserve timestamps and permissions + int flags = ARCHIVE_EXTRACT_TIME | ARCHIVE_EXTRACT_PERM; + archive_write_disk_set_options(ext, flags); + archive_write_disk_set_standard_lookup(ext); + + // Extract entries (streaming loop) + rac_extraction_result_t result = {0, 0, 0, 0}; + struct archive_entry* entry; + rac_result_t status = RAC_SUCCESS; + + while (true) { + r = archive_read_next_header(a, &entry); + if (r == ARCHIVE_EOF) break; + + if (r != ARCHIVE_OK && r != ARCHIVE_WARN) { + const char* err = archive_error_string(a); + RAC_LOG_ERROR(kLogTag, "Error reading archive entry: %s", err ? err : "unknown"); + status = RAC_ERROR_EXTRACTION_FAILED; + break; + } + + const char* pathname = archive_entry_pathname(entry); + if (!pathname) { + archive_read_data_skip(a); + continue; + } + + // Security: zip-slip protection + if (!is_path_safe(pathname)) { + RAC_LOG_WARNING(kLogTag, "Skipping unsafe path: %s", pathname); + result.entries_skipped++; + archive_read_data_skip(a); + continue; + } + + // Skip macOS resource forks + if (should_skip_entry(pathname, opts.skip_macos_resources)) { + result.entries_skipped++; + archive_read_data_skip(a); + continue; + } + + // Handle symbolic links + unsigned int entry_type = archive_entry_filetype(entry); + if (entry_type == AE_IFLNK) { + if (opts.skip_symlinks) { + result.entries_skipped++; + archive_read_data_skip(a); + continue; + } + // Safety: reject symlinks pointing outside destination + const char* link_target = archive_entry_symlink(entry); + if (link_target && (link_target[0] == '/' || strstr(link_target, "..") != nullptr)) { + RAC_LOG_WARNING(kLogTag, "Skipping unsafe symlink: %s -> %s", pathname, + link_target); + result.entries_skipped++; + archive_read_data_skip(a); + continue; + } + } + + // Rewrite path to be under destination directory + std::string full_path = dest_dir + pathname; + archive_entry_set_pathname(entry, full_path.c_str()); + + // Also rewrite hardlink paths if present (with safety check) + const char* hardlink = archive_entry_hardlink(entry); + if (hardlink && hardlink[0] != '\0') { + if (!is_path_safe(hardlink)) { + RAC_LOG_WARNING(kLogTag, "Skipping unsafe hardlink target: %s -> %s", pathname, + hardlink); + result.entries_skipped++; + archive_read_data_skip(a); + continue; + } + std::string full_hardlink = dest_dir + hardlink; + archive_entry_set_hardlink(entry, full_hardlink.c_str()); + } + + // Write entry header (creates file/directory on disk) + r = archive_write_header(ext, entry); + if (r != ARCHIVE_OK) { + const char* err = archive_error_string(ext); + RAC_LOG_WARNING(kLogTag, "Failed to write header for: %s (%s)", pathname, + err ? err : "unknown"); + archive_read_data_skip(a); + continue; + } + + // Copy file data (streaming, constant memory) + if (archive_entry_size(entry) > 0 && entry_type == AE_IFREG) { + const void* buff; + size_t size; + la_int64_t offset; + + bool data_error = false; + while (true) { + r = archive_read_data_block(a, &buff, &size, &offset); + if (r == ARCHIVE_EOF) break; + if (r != ARCHIVE_OK) { + const char* err = archive_error_string(a); + RAC_LOG_ERROR(kLogTag, "Error reading data for: %s (%s)", pathname, + err ? err : "unknown"); + data_error = true; + break; + } + r = archive_write_data_block(ext, buff, size, offset); + if (r != ARCHIVE_OK) { + const char* err = archive_error_string(ext); + RAC_LOG_ERROR(kLogTag, "Error writing data for: %s (%s)", pathname, + err ? err : "unknown"); + data_error = true; + break; + } + result.bytes_extracted += static_cast(size); + } + if (data_error) { + status = RAC_ERROR_EXTRACTION_FAILED; + break; + } + } + + // Finish entry (sets permissions, timestamps) + archive_write_finish_entry(ext); + + // Track statistics + if (entry_type == AE_IFDIR) { + result.directories_created++; + } else if (entry_type == AE_IFREG) { + result.files_extracted++; + } + + // Progress callback + if (progress_callback) { + progress_callback(result.files_extracted, 0 /* total unknown in streaming */, + result.bytes_extracted, user_data); + } + } + + // Cleanup + archive_read_free(a); + archive_write_free(ext); + + // Output result + if (out_result) { + *out_result = result; + } + + if (RAC_SUCCEEDED(status)) { + RAC_LOG_INFO(kLogTag, "Extraction complete: %d files, %d dirs, %lld bytes, %d skipped", + result.files_extracted, result.directories_created, + static_cast(result.bytes_extracted), result.entries_skipped); + } + + return status; +} + +// ============================================================================= +// PUBLIC API - rac_detect_archive_type +// ============================================================================= + +rac_bool_t rac_detect_archive_type(const char* file_path, rac_archive_type_t* out_type) { + if (!file_path || !out_type) return RAC_FALSE; + + FILE* f = fopen(file_path, "rb"); + if (!f) return RAC_FALSE; + + unsigned char magic[6] = {0}; + size_t bytes_read = fread(magic, 1, sizeof(magic), f); + fclose(f); + + if (bytes_read < 2) return RAC_FALSE; + + // ZIP: PK\x03\x04 + if (bytes_read >= 4 && magic[0] == 0x50 && magic[1] == 0x4B && magic[2] == 0x03 && + magic[3] == 0x04) { + *out_type = RAC_ARCHIVE_TYPE_ZIP; + return RAC_TRUE; + } + + // GZIP (tar.gz): \x1f\x8b + if (magic[0] == 0x1F && magic[1] == 0x8B) { + *out_type = RAC_ARCHIVE_TYPE_TAR_GZ; + return RAC_TRUE; + } + + // BZIP2 (tar.bz2): BZh + if (bytes_read >= 3 && magic[0] == 0x42 && magic[1] == 0x5A && magic[2] == 0x68) { + *out_type = RAC_ARCHIVE_TYPE_TAR_BZ2; + return RAC_TRUE; + } + + // XZ (tar.xz): \xFD7zXZ\x00 + if (bytes_read >= 6 && magic[0] == 0xFD && magic[1] == 0x37 && magic[2] == 0x7A && + magic[3] == 0x58 && magic[4] == 0x5A && magic[5] == 0x00) { + *out_type = RAC_ARCHIVE_TYPE_TAR_XZ; + return RAC_TRUE; + } + + return RAC_FALSE; +} diff --git a/sdk/runanywhere-commons/src/infrastructure/file_management/file_manager.cpp b/sdk/runanywhere-commons/src/infrastructure/file_management/file_manager.cpp new file mode 100644 index 000000000..6e1e0f621 --- /dev/null +++ b/sdk/runanywhere-commons/src/infrastructure/file_management/file_manager.cpp @@ -0,0 +1,536 @@ +/** + * @file file_manager.cpp + * @brief File Manager - Centralized File Management Business Logic + * + * Consolidates duplicated file management logic from Swift, Kotlin, Flutter, and RN SDKs. + * All file I/O is performed via platform callbacks (rac_file_callbacks_t). + * Business logic (recursive traversal, path computation, threshold checks) lives here. + */ + +#include "rac/infrastructure/file_management/rac_file_manager.h" + +#include +#include + +#include "rac/core/rac_logger.h" +#include "rac/infrastructure/model_management/rac_model_paths.h" + +// ============================================================================= +// INTERNAL HELPERS +// ============================================================================= + +static const char* LOG_CATEGORY = "FileManager"; + +/** Storage warning threshold: 1 GB */ +static const int64_t STORAGE_WARNING_THRESHOLD = 1024LL * 1024LL * 1024LL; + +/** + * Validate that required callbacks are present. + */ +static rac_result_t validate_callbacks(const rac_file_callbacks_t* cb) { + if (cb == nullptr) { + return RAC_ERROR_NULL_POINTER; + } + if (cb->create_directory == nullptr || cb->delete_path == nullptr || + cb->list_directory == nullptr || cb->free_entries == nullptr || + cb->path_exists == nullptr || cb->get_file_size == nullptr) { + return RAC_ERROR_INVALID_ARGUMENT; + } + return RAC_SUCCESS; +} + +/** + * Build a full path by joining parent and child with '/'. + */ +static std::string join_path(const char* parent, const char* child) { + std::string result(parent); + if (!result.empty() && result.back() != '/') { + result += '/'; + } + result += child; + return result; +} + +/** + * Recursive directory size calculation. + * This is the core logic duplicated across all SDKs. + * + * Algorithm (identical in Swift/Kotlin/Flutter/RN): + * 1. List directory entries + * 2. For each entry: if directory → recurse, else → add file size + */ +static rac_result_t calculate_dir_size_recursive(const rac_file_callbacks_t* cb, const char* path, + int64_t* out_size) { + char** entries = nullptr; + size_t count = 0; + + rac_result_t result = cb->list_directory(path, &entries, &count, cb->user_data); + if (RAC_FAILED(result)) { + // Directory doesn't exist or can't be listed — treat as 0 size + *out_size = 0; + return RAC_SUCCESS; + } + + int64_t total = 0; + + for (size_t i = 0; i < count; i++) { + // Skip . and .. + if (std::strcmp(entries[i], ".") == 0 || std::strcmp(entries[i], "..") == 0) { + continue; + } + + std::string entry_path = join_path(path, entries[i]); + rac_bool_t is_directory = RAC_FALSE; + rac_bool_t exists = cb->path_exists(entry_path.c_str(), &is_directory, cb->user_data); + + if (exists == RAC_TRUE) { + if (is_directory == RAC_TRUE) { + int64_t sub_size = 0; + calculate_dir_size_recursive(cb, entry_path.c_str(), &sub_size); + total += sub_size; + } else { + int64_t file_size = cb->get_file_size(entry_path.c_str(), cb->user_data); + if (file_size > 0) { + total += file_size; + } + } + } + } + + cb->free_entries(entries, count, cb->user_data); + *out_size = total; + return RAC_SUCCESS; +} + +/** + * Clear a directory: delete all contents, then recreate the empty directory. + */ +static rac_result_t clear_directory_impl(const rac_file_callbacks_t* cb, const char* path) { + rac_bool_t is_dir = RAC_FALSE; + rac_bool_t exists = cb->path_exists(path, &is_dir, cb->user_data); + + if (exists == RAC_TRUE && is_dir == RAC_TRUE) { + // Delete the directory and all contents + rac_result_t result = cb->delete_path(path, 1 /* recursive */, cb->user_data); + if (RAC_FAILED(result)) { + return result; + } + } + + // Recreate the empty directory + return cb->create_directory(path, 1 /* recursive */, cb->user_data); +} + +// ============================================================================= +// DIRECTORY STRUCTURE +// ============================================================================= + +rac_result_t rac_file_manager_create_directory_structure(const rac_file_callbacks_t* cb) { + rac_result_t result = validate_callbacks(cb); + if (RAC_FAILED(result)) { + return result; + } + + // Get paths from rac_model_paths + char models_dir[1024]; + char cache_dir[1024]; + char temp_dir[1024]; + char downloads_dir[1024]; + + result = rac_model_paths_get_models_directory(models_dir, sizeof(models_dir)); + if (RAC_FAILED(result)) { + RAC_LOG_ERROR(LOG_CATEGORY, "Failed to get models directory path"); + return result; + } + + result = rac_model_paths_get_cache_directory(cache_dir, sizeof(cache_dir)); + if (RAC_FAILED(result)) { + RAC_LOG_ERROR(LOG_CATEGORY, "Failed to get cache directory path"); + return result; + } + + result = rac_model_paths_get_temp_directory(temp_dir, sizeof(temp_dir)); + if (RAC_FAILED(result)) { + RAC_LOG_ERROR(LOG_CATEGORY, "Failed to get temp directory path"); + return result; + } + + result = rac_model_paths_get_downloads_directory(downloads_dir, sizeof(downloads_dir)); + if (RAC_FAILED(result)) { + RAC_LOG_ERROR(LOG_CATEGORY, "Failed to get downloads directory path"); + return result; + } + + // Create each directory + const char* dirs[] = {models_dir, cache_dir, temp_dir, downloads_dir}; + for (const char* dir : dirs) { + result = cb->create_directory(dir, 1 /* recursive */, cb->user_data); + if (RAC_FAILED(result)) { + RAC_LOG_ERROR(LOG_CATEGORY, "Failed to create directory"); + return result; + } + } + + RAC_LOG_INFO(LOG_CATEGORY, "Directory structure created successfully"); + return RAC_SUCCESS; +} + +// ============================================================================= +// MODEL FOLDER MANAGEMENT +// ============================================================================= + +rac_result_t rac_file_manager_create_model_folder(const rac_file_callbacks_t* cb, + const char* model_id, + rac_inference_framework_t framework, + char* out_path, size_t path_size) { + if (model_id == nullptr || out_path == nullptr || path_size == 0) { + return RAC_ERROR_NULL_POINTER; + } + + rac_result_t result = validate_callbacks(cb); + if (RAC_FAILED(result)) { + return result; + } + + // Get model folder path from rac_model_paths + result = rac_model_paths_get_model_folder(model_id, framework, out_path, path_size); + if (RAC_FAILED(result)) { + return result; + } + + // Create the directory + result = cb->create_directory(out_path, 1 /* recursive */, cb->user_data); + if (RAC_FAILED(result)) { + RAC_LOG_ERROR(LOG_CATEGORY, "Failed to create model folder"); + return result; + } + + return RAC_SUCCESS; +} + +rac_result_t rac_file_manager_model_folder_exists(const rac_file_callbacks_t* cb, + const char* model_id, + rac_inference_framework_t framework, + rac_bool_t* out_exists, + rac_bool_t* out_has_contents) { + if (model_id == nullptr || out_exists == nullptr) { + return RAC_ERROR_NULL_POINTER; + } + + rac_result_t result = validate_callbacks(cb); + if (RAC_FAILED(result)) { + return result; + } + + // Get model folder path + char folder_path[1024]; + result = rac_model_paths_get_model_folder(model_id, framework, folder_path, sizeof(folder_path)); + if (RAC_FAILED(result)) { + *out_exists = RAC_FALSE; + if (out_has_contents != nullptr) { + *out_has_contents = RAC_FALSE; + } + return RAC_SUCCESS; + } + + // Check existence + rac_bool_t is_directory = RAC_FALSE; + rac_bool_t exists = cb->path_exists(folder_path, &is_directory, cb->user_data); + + *out_exists = (exists == RAC_TRUE && is_directory == RAC_TRUE) ? RAC_TRUE : RAC_FALSE; + + // Check contents if requested + if (out_has_contents != nullptr) { + *out_has_contents = RAC_FALSE; + if (*out_exists == RAC_TRUE) { + char** entries = nullptr; + size_t count = 0; + result = cb->list_directory(folder_path, &entries, &count, cb->user_data); + if (RAC_SUCCEEDED(result)) { + // Count non-dot entries + for (size_t i = 0; i < count; i++) { + if (std::strcmp(entries[i], ".") != 0 && std::strcmp(entries[i], "..") != 0) { + *out_has_contents = RAC_TRUE; + break; + } + } + cb->free_entries(entries, count, cb->user_data); + } + } + } + + return RAC_SUCCESS; +} + +rac_result_t rac_file_manager_delete_model(const rac_file_callbacks_t* cb, const char* model_id, + rac_inference_framework_t framework) { + if (model_id == nullptr) { + return RAC_ERROR_NULL_POINTER; + } + + rac_result_t result = validate_callbacks(cb); + if (RAC_FAILED(result)) { + return result; + } + + // Get model folder path + char folder_path[1024]; + result = rac_model_paths_get_model_folder(model_id, framework, folder_path, sizeof(folder_path)); + if (RAC_FAILED(result)) { + return result; + } + + // Check if it exists + rac_bool_t is_directory = RAC_FALSE; + rac_bool_t exists = cb->path_exists(folder_path, &is_directory, cb->user_data); + + if (exists != RAC_TRUE) { + return RAC_ERROR_FILE_NOT_FOUND; + } + + // Delete recursively + result = cb->delete_path(folder_path, 1 /* recursive */, cb->user_data); + if (RAC_FAILED(result)) { + RAC_LOG_ERROR(LOG_CATEGORY, "Failed to delete model folder"); + return result; + } + + RAC_LOG_INFO(LOG_CATEGORY, "Deleted model folder"); + return RAC_SUCCESS; +} + +// ============================================================================= +// DIRECTORY SIZE CALCULATION +// ============================================================================= + +rac_result_t rac_file_manager_calculate_dir_size(const rac_file_callbacks_t* cb, const char* path, + int64_t* out_size) { + if (path == nullptr || out_size == nullptr) { + return RAC_ERROR_NULL_POINTER; + } + + rac_result_t result = validate_callbacks(cb); + if (RAC_FAILED(result)) { + return result; + } + + // Check if path exists + rac_bool_t is_directory = RAC_FALSE; + rac_bool_t exists = cb->path_exists(path, &is_directory, cb->user_data); + + if (exists != RAC_TRUE) { + *out_size = 0; + return RAC_SUCCESS; + } + + if (is_directory != RAC_TRUE) { + // Single file + *out_size = cb->get_file_size(path, cb->user_data); + if (*out_size < 0) { + *out_size = 0; + } + return RAC_SUCCESS; + } + + return calculate_dir_size_recursive(cb, path, out_size); +} + +rac_result_t rac_file_manager_models_storage_used(const rac_file_callbacks_t* cb, + int64_t* out_size) { + if (out_size == nullptr) { + return RAC_ERROR_NULL_POINTER; + } + + rac_result_t result = validate_callbacks(cb); + if (RAC_FAILED(result)) { + return result; + } + + char models_dir[1024]; + result = rac_model_paths_get_models_directory(models_dir, sizeof(models_dir)); + if (RAC_FAILED(result)) { + *out_size = 0; + return result; + } + + return rac_file_manager_calculate_dir_size(cb, models_dir, out_size); +} + +// ============================================================================= +// CACHE & TEMP MANAGEMENT +// ============================================================================= + +rac_result_t rac_file_manager_clear_cache(const rac_file_callbacks_t* cb) { + rac_result_t result = validate_callbacks(cb); + if (RAC_FAILED(result)) { + return result; + } + + char cache_dir[1024]; + result = rac_model_paths_get_cache_directory(cache_dir, sizeof(cache_dir)); + if (RAC_FAILED(result)) { + return result; + } + + result = clear_directory_impl(cb, cache_dir); + if (RAC_SUCCEEDED(result)) { + RAC_LOG_INFO(LOG_CATEGORY, "Cache cleared"); + } + return result; +} + +rac_result_t rac_file_manager_clear_temp(const rac_file_callbacks_t* cb) { + rac_result_t result = validate_callbacks(cb); + if (RAC_FAILED(result)) { + return result; + } + + char temp_dir[1024]; + result = rac_model_paths_get_temp_directory(temp_dir, sizeof(temp_dir)); + if (RAC_FAILED(result)) { + return result; + } + + result = clear_directory_impl(cb, temp_dir); + if (RAC_SUCCEEDED(result)) { + RAC_LOG_INFO(LOG_CATEGORY, "Temp directory cleared"); + } + return result; +} + +rac_result_t rac_file_manager_cache_size(const rac_file_callbacks_t* cb, int64_t* out_size) { + if (out_size == nullptr) { + return RAC_ERROR_NULL_POINTER; + } + + rac_result_t result = validate_callbacks(cb); + if (RAC_FAILED(result)) { + return result; + } + + char cache_dir[1024]; + result = rac_model_paths_get_cache_directory(cache_dir, sizeof(cache_dir)); + if (RAC_FAILED(result)) { + *out_size = 0; + return result; + } + + return rac_file_manager_calculate_dir_size(cb, cache_dir, out_size); +} + +// ============================================================================= +// STORAGE INFO +// ============================================================================= + +rac_result_t rac_file_manager_get_storage_info(const rac_file_callbacks_t* cb, + rac_file_manager_storage_info_t* out_info) { + if (out_info == nullptr) { + return RAC_ERROR_NULL_POINTER; + } + + rac_result_t result = validate_callbacks(cb); + if (RAC_FAILED(result)) { + return result; + } + + // Zero-initialize + std::memset(out_info, 0, sizeof(rac_file_manager_storage_info_t)); + + // Device storage (requires get_available_space and get_total_space) + if (cb->get_available_space != nullptr) { + out_info->device_free = cb->get_available_space(cb->user_data); + } + if (cb->get_total_space != nullptr) { + out_info->device_total = cb->get_total_space(cb->user_data); + } + + // Models size + char models_dir[1024]; + if (RAC_SUCCEEDED(rac_model_paths_get_models_directory(models_dir, sizeof(models_dir)))) { + rac_file_manager_calculate_dir_size(cb, models_dir, &out_info->models_size); + } + + // Cache size + char cache_dir[1024]; + if (RAC_SUCCEEDED(rac_model_paths_get_cache_directory(cache_dir, sizeof(cache_dir)))) { + rac_file_manager_calculate_dir_size(cb, cache_dir, &out_info->cache_size); + } + + // Temp size + char temp_dir[1024]; + if (RAC_SUCCEEDED(rac_model_paths_get_temp_directory(temp_dir, sizeof(temp_dir)))) { + rac_file_manager_calculate_dir_size(cb, temp_dir, &out_info->temp_size); + } + + // Total app size + out_info->total_app_size = out_info->models_size + out_info->cache_size + out_info->temp_size; + + return RAC_SUCCESS; +} + +rac_result_t rac_file_manager_check_storage(const rac_file_callbacks_t* cb, + int64_t required_bytes, + rac_storage_availability_t* out_availability) { + if (out_availability == nullptr) { + return RAC_ERROR_NULL_POINTER; + } + + rac_result_t result = validate_callbacks(cb); + if (RAC_FAILED(result)) { + return result; + } + + // Zero-initialize + std::memset(out_availability, 0, sizeof(rac_storage_availability_t)); + + // Get available space + int64_t available_space = 0; + if (cb->get_available_space != nullptr) { + available_space = cb->get_available_space(cb->user_data); + } + + out_availability->required_space = required_bytes; + out_availability->available_space = available_space; + + // Check availability + if (available_space >= required_bytes) { + out_availability->is_available = RAC_TRUE; + + // Check for warning (less than 1GB remaining after operation) + int64_t remaining = available_space - required_bytes; + if (remaining < STORAGE_WARNING_THRESHOLD) { + out_availability->has_warning = RAC_TRUE; + out_availability->recommendation = + rac_strdup("Low storage warning: less than 1 GB will remain after this download. " + "Consider freeing space by removing unused models."); + } else { + out_availability->has_warning = RAC_FALSE; + out_availability->recommendation = nullptr; + } + } else { + out_availability->is_available = RAC_FALSE; + out_availability->has_warning = RAC_TRUE; + out_availability->recommendation = + rac_strdup("Insufficient storage space for this download. " + "Please free space by removing unused models or clearing the cache."); + } + + return RAC_SUCCESS; +} + +// ============================================================================= +// DIRECTORY CLEARING (PUBLIC HELPER) +// ============================================================================= + +rac_result_t rac_file_manager_clear_directory(const rac_file_callbacks_t* cb, const char* path) { + if (path == nullptr) { + return RAC_ERROR_NULL_POINTER; + } + + rac_result_t result = validate_callbacks(cb); + if (RAC_FAILED(result)) { + return result; + } + + return clear_directory_impl(cb, path); +} diff --git a/sdk/runanywhere-commons/src/infrastructure/model_management/lora_registry.cpp b/sdk/runanywhere-commons/src/infrastructure/model_management/lora_registry.cpp index 8c5f562ac..d7e25ddc9 100644 --- a/sdk/runanywhere-commons/src/infrastructure/model_management/lora_registry.cpp +++ b/sdk/runanywhere-commons/src/infrastructure/model_management/lora_registry.cpp @@ -96,8 +96,11 @@ rac_result_t rac_lora_registry_create(rac_lora_registry_handle_t* out_handle) { void rac_lora_registry_destroy(rac_lora_registry_handle_t handle) { if (!handle) return; - for (auto& pair : handle->entries) { free_lora_entry(pair.second); } - handle->entries.clear(); + { + std::lock_guard lock(handle->mutex); + for (auto& pair : handle->entries) { free_lora_entry(pair.second); } + handle->entries.clear(); + } delete handle; RAC_LOG_DEBUG("LoraRegistry", "LoRA registry destroyed"); } diff --git a/sdk/runanywhere-commons/src/infrastructure/model_management/model_assignment.cpp b/sdk/runanywhere-commons/src/infrastructure/model_management/model_assignment.cpp index ad575705a..ea693675f 100644 --- a/sdk/runanywhere-commons/src/infrastructure/model_management/model_assignment.cpp +++ b/sdk/runanywhere-commons/src/infrastructure/model_management/model_assignment.cpp @@ -186,6 +186,11 @@ static std::vector parse_models_json(const char* json_str, si model->id = strdup(id.c_str()); model->name = strdup(name.c_str()); + if (!model->id || !model->name) { + rac_model_info_free(model); + pos = obj_end; + continue; + } model->download_url = download_url.empty() ? nullptr : strdup(download_url.c_str()); model->description = description.empty() ? nullptr : strdup(description.c_str()); model->download_size = size; diff --git a/sdk/runanywhere-commons/src/infrastructure/model_management/model_paths.cpp b/sdk/runanywhere-commons/src/infrastructure/model_management/model_paths.cpp index 4c7a081a7..1d53d88da 100644 --- a/sdk/runanywhere-commons/src/infrastructure/model_management/model_paths.cpp +++ b/sdk/runanywhere-commons/src/infrastructure/model_management/model_paths.cpp @@ -47,8 +47,15 @@ rac_result_t rac_model_paths_set_base_dir(const char* base_dir) { } const char* rac_model_paths_get_base_dir(void) { + // Use thread_local copy to avoid returning c_str() that dangles after mutex release. + // Valid until the next call from the same thread. + static thread_local std::string tl_base_dir; std::lock_guard lock(g_paths_mutex); - return g_base_dir.empty() ? nullptr : g_base_dir.c_str(); + if (g_base_dir.empty()) { + return nullptr; + } + tl_base_dir = g_base_dir; + return tl_base_dir.c_str(); } // ============================================================================= diff --git a/sdk/runanywhere-commons/src/infrastructure/model_management/model_registry.cpp b/sdk/runanywhere-commons/src/infrastructure/model_management/model_registry.cpp index d866f4f4a..529f97a29 100644 --- a/sdk/runanywhere-commons/src/infrastructure/model_management/model_registry.cpp +++ b/sdk/runanywhere-commons/src/infrastructure/model_management/model_registry.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -128,7 +129,10 @@ rac_result_t rac_model_registry_create(rac_model_registry_handle_t* out_handle) return RAC_ERROR_INVALID_ARGUMENT; } - rac_model_registry* registry = new rac_model_registry(); + rac_model_registry* registry = new (std::nothrow) rac_model_registry(); + if (!registry) { + return RAC_ERROR_OUT_OF_MEMORY; + } RAC_LOG_INFO("ModelRegistry", "Model registry created"); diff --git a/sdk/runanywhere-commons/src/infrastructure/model_management/model_strategy.cpp b/sdk/runanywhere-commons/src/infrastructure/model_management/model_strategy.cpp index 01c633699..3e5355b07 100644 --- a/sdk/runanywhere-commons/src/infrastructure/model_management/model_strategy.cpp +++ b/sdk/runanywhere-commons/src/infrastructure/model_management/model_strategy.cpp @@ -217,7 +217,7 @@ rac_result_t rac_model_strategy_get_download_dest(rac_inference_framework_t fram if (len >= path_size) { return RAC_ERROR_BUFFER_TOO_SMALL; } - strcpy(out_path, config->destination_folder); + memcpy(out_path, config->destination_folder, len + 1); return RAC_SUCCESS; } return RAC_ERROR_INVALID_PARAMETER; @@ -238,6 +238,9 @@ rac_result_t rac_model_strategy_post_process(rac_inference_framework_t framework if (!strategy || !strategy->post_process) { // No custom strategy - set basic result out_result->final_path = strdup(downloaded_path); + if (!out_result->final_path) { + return RAC_ERROR_OUT_OF_MEMORY; + } out_result->downloaded_size = 0; // Unknown out_result->was_extracted = RAC_FALSE; out_result->file_count = 1; diff --git a/sdk/runanywhere-commons/src/infrastructure/network/api_types.cpp b/sdk/runanywhere-commons/src/infrastructure/network/api_types.cpp index a9f0ab271..99692aa97 100644 --- a/sdk/runanywhere-commons/src/infrastructure/network/api_types.cpp +++ b/sdk/runanywhere-commons/src/infrastructure/network/api_types.cpp @@ -530,8 +530,13 @@ char* rac_telemetry_batch_to_json(const rac_telemetry_batch_t* batch) { if (!batch) return nullptr; - // Estimate size needed - size_t buf_size = 1024 + (batch->event_count * 8192); + // Estimate size needed (with overflow check) + static constexpr size_t kPerEventEstimate = 8192; + static constexpr size_t kBaseEstimate = 1024; + if (batch->event_count > (SIZE_MAX - kBaseEstimate) / kPerEventEstimate) { + return nullptr; + } + size_t buf_size = kBaseEstimate + (batch->event_count * kPerEventEstimate); char* buf = (char*)malloc(buf_size); if (!buf) return nullptr; diff --git a/sdk/runanywhere-commons/src/infrastructure/network/auth_manager.cpp b/sdk/runanywhere-commons/src/infrastructure/network/auth_manager.cpp index 03243008a..015b6cda3 100644 --- a/sdk/runanywhere-commons/src/infrastructure/network/auth_manager.cpp +++ b/sdk/runanywhere-commons/src/infrastructure/network/auth_manager.cpp @@ -153,14 +153,25 @@ static int update_auth_state_from_response(const rac_auth_response_t* response) return -1; } + // Pre-allocate required strings before modifying state + char* new_access = str_dup(response->access_token); + char* new_refresh = str_dup(response->refresh_token); + if (!new_access || !new_refresh) { + free(new_access); + free(new_refresh); + return -1; + } + // Free old strings free_auth_state_strings(); - // Copy new values - g_auth_state.access_token = str_dup(response->access_token); - g_auth_state.refresh_token = str_dup(response->refresh_token); + // Assign pre-allocated required values + g_auth_state.access_token = new_access; + g_auth_state.refresh_token = new_refresh; + + // Copy optional values (NULL is acceptable) g_auth_state.device_id = str_dup(response->device_id); - g_auth_state.user_id = str_dup(response->user_id); // Can be NULL + g_auth_state.user_id = str_dup(response->user_id); g_auth_state.organization_id = str_dup(response->organization_id); // Calculate expiry timestamp @@ -287,6 +298,9 @@ int rac_auth_load_stored_tokens(void) { g_auth_state.token_expires_at = 0; } + // Clear sensitive data from stack buffer + memset(buffer, 0, sizeof(buffer)); + return 0; } diff --git a/sdk/runanywhere-commons/src/infrastructure/network/http_client.cpp b/sdk/runanywhere-commons/src/infrastructure/network/http_client.cpp index c77718ac1..bbd3a7a9d 100644 --- a/sdk/runanywhere-commons/src/infrastructure/network/http_client.cpp +++ b/sdk/runanywhere-commons/src/infrastructure/network/http_client.cpp @@ -73,6 +73,10 @@ rac_http_request_t* rac_http_request_create(rac_http_method_t method, const char request->method = method; request->url = str_dup(url); + if (url && !request->url) { + free(request); + return nullptr; + } request->timeout_ms = 30000; // Default 30s timeout return request; @@ -100,8 +104,15 @@ void rac_http_request_add_header(rac_http_request_t* request, const char* key, c return; request->headers = new_headers; - request->headers[request->header_count].key = str_dup(key); - request->headers[request->header_count].value = str_dup(value); + char* dup_key = str_dup(key); + char* dup_value = str_dup(value); + if (!dup_key || !dup_value) { + free(dup_key); + free(dup_value); + return; + } + request->headers[request->header_count].key = dup_key; + request->headers[request->header_count].value = dup_value; request->header_count = new_count; } diff --git a/sdk/runanywhere-commons/src/infrastructure/storage/storage_analyzer.cpp b/sdk/runanywhere-commons/src/infrastructure/storage/storage_analyzer.cpp index a620af0f2..bebad9a27 100644 --- a/sdk/runanywhere-commons/src/infrastructure/storage/storage_analyzer.cpp +++ b/sdk/runanywhere-commons/src/infrastructure/storage/storage_analyzer.cpp @@ -119,6 +119,12 @@ rac_result_t rac_storage_analyzer_analyze(rac_storage_analyzer_handle_t handle, // Copy model info metrics->model_id = model->id ? strdup(model->id) : nullptr; metrics->model_name = model->name ? strdup(model->name) : nullptr; + if ((model->id && !metrics->model_id) || (model->name && !metrics->model_name)) { + free(const_cast(metrics->model_id)); + free(const_cast(metrics->model_name)); + memset(metrics, 0, sizeof(rac_model_storage_metrics_t)); + continue; + } metrics->framework = model->framework; metrics->format = model->format; metrics->artifact_info = model->artifact_info; @@ -179,6 +185,13 @@ rac_result_t rac_storage_analyzer_get_model_metrics(rac_storage_analyzer_handle_ // Copy model info out_metrics->model_id = model->id ? strdup(model->id) : nullptr; out_metrics->model_name = model->name ? strdup(model->name) : nullptr; + if ((model->id && !out_metrics->model_id) || (model->name && !out_metrics->model_name)) { + free(const_cast(out_metrics->model_id)); + free(const_cast(out_metrics->model_name)); + memset(out_metrics, 0, sizeof(rac_model_storage_metrics_t)); + rac_model_info_free(model); + return RAC_ERROR_OUT_OF_MEMORY; + } out_metrics->framework = model->framework; out_metrics->format = model->format; out_metrics->artifact_info = model->artifact_info; @@ -230,15 +243,21 @@ rac_result_t rac_storage_analyzer_check_available(rac_storage_analyzer_handle_t out_availability->is_available = available > required ? RAC_TRUE : RAC_FALSE; out_availability->has_warning = available < required * 2 ? RAC_TRUE : RAC_FALSE; - // Generate recommendation message + // Generate recommendation message (NULL recommendation is acceptable on OOM) if (out_availability->is_available == RAC_FALSE) { int64_t shortfall = required - available; // Simple message - platform can format with locale-specific formatter char msg[256]; snprintf(msg, sizeof(msg), "Need %lld more bytes of space.", (long long)shortfall); out_availability->recommendation = strdup(msg); + if (!out_availability->recommendation) { + return RAC_ERROR_OUT_OF_MEMORY; + } } else if (out_availability->has_warning == RAC_TRUE) { out_availability->recommendation = strdup("Storage space is getting low."); + if (!out_availability->recommendation) { + return RAC_ERROR_OUT_OF_MEMORY; + } } return RAC_SUCCESS; diff --git a/sdk/runanywhere-commons/src/infrastructure/telemetry/telemetry_manager.cpp b/sdk/runanywhere-commons/src/infrastructure/telemetry/telemetry_manager.cpp index 2a98737ff..f8d64ef40 100644 --- a/sdk/runanywhere-commons/src/infrastructure/telemetry/telemetry_manager.cpp +++ b/sdk/runanywhere-commons/src/infrastructure/telemetry/telemetry_manager.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -61,36 +62,19 @@ int64_t get_current_timestamp_ms() { return std::chrono::duration_cast(duration).count(); } -// Thread-safe seeding flag -std::once_flag rand_seed_flag; - -// Ensure random number generator is seeded exactly once (thread-safe) -void ensure_rand_seeded() { - std::call_once(rand_seed_flag, []() { - // Seed with combination of time and memory address for better entropy - auto now = std::chrono::high_resolution_clock::now(); - auto nanos = - std::chrono::duration_cast(now.time_since_epoch()).count(); - unsigned int seed = - static_cast(nanos ^ reinterpret_cast(&rand_seed_flag)); - srand(seed); - }); -} - -// Generate UUID +// Generate UUID using thread-safe RNG std::string generate_uuid() { - // Ensure random number generator is seeded - ensure_rand_seeded(); + static thread_local std::mt19937 gen(std::random_device{}()); + static thread_local std::uniform_int_distribution<> dis(0, 15); - // Simple UUID generation (not RFC4122 compliant, but sufficient for event IDs) static const char hex[] = "0123456789abcdef"; std::string uuid = "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx"; for (char& c : uuid) { if (c == 'x') { - c = hex[rand() % 16]; + c = hex[dis(gen)]; } else if (c == 'y') { - c = hex[(rand() % 4) + 8]; // 8, 9, a, or b + c = hex[(dis(gen) % 4) + 8]; // 8, 9, a, or b } } diff --git a/sdk/runanywhere-commons/src/jni/CMakeLists.txt b/sdk/runanywhere-commons/src/jni/CMakeLists.txt index 5dbb6efec..7b2be025c 100644 --- a/sdk/runanywhere-commons/src/jni/CMakeLists.txt +++ b/sdk/runanywhere-commons/src/jni/CMakeLists.txt @@ -19,7 +19,7 @@ cmake_minimum_required(VERSION 3.14) project(runanywhere_commons_jni) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) # nlohmann_json is provided by the parent CMakeLists.txt diff --git a/sdk/runanywhere-commons/src/jni/runanywhere_commons_jni.cpp b/sdk/runanywhere-commons/src/jni/runanywhere_commons_jni.cpp index 554f86965..c347aee70 100644 --- a/sdk/runanywhere-commons/src/jni/runanywhere_commons_jni.cpp +++ b/sdk/runanywhere-commons/src/jni/runanywhere_commons_jni.cpp @@ -19,9 +19,14 @@ #include #include #include +#include #include #include +#ifdef __ANDROID__ +#include +#endif + // Include runanywhere-commons C API headers #include "rac/core/rac_analytics_events.h" #include "rac/core/rac_audio_utils.h" @@ -44,34 +49,21 @@ #include "rac/infrastructure/telemetry/rac_telemetry_manager.h" #include "rac/infrastructure/telemetry/rac_telemetry_types.h" #include "rac/features/llm/rac_tool_calling.h" +#include "rac/infrastructure/download/rac_download_orchestrator.h" +#include "rac/infrastructure/extraction/rac_extraction.h" +#include "rac/infrastructure/file_management/rac_file_manager.h" // NOTE: Backend headers are NOT included here. // Backend registration is handled by their respective JNI libraries: // - backends/llamacpp/src/jni/rac_backend_llamacpp_jni.cpp // - backends/onnx/src/jni/rac_backend_onnx_jni.cpp -#ifdef __ANDROID__ -#include -#define TAG "RACCommonsJNI" -#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__) -#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__) -#define LOGw(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__) -#define LOGd(...) __android_log_print(ANDROID_LOG_DEBUG, TAG, __VA_ARGS__) -#else -#include -#define LOGi(...) \ - fprintf(stdout, "[INFO] " __VA_ARGS__); \ - fprintf(stdout, "\n") -#define LOGe(...) \ - fprintf(stderr, "[ERROR] " __VA_ARGS__); \ - fprintf(stderr, "\n") -#define LOGw(...) \ - fprintf(stdout, "[WARN] " __VA_ARGS__); \ - fprintf(stdout, "\n") -#define LOGd(...) \ - fprintf(stdout, "[DEBUG] " __VA_ARGS__); \ - fprintf(stdout, "\n") -#endif +// Route JNI logging through unified RAC_LOG_* system +static const char* JNI_LOG_TAG = "JNI.Commons"; +#define LOGi(...) RAC_LOG_INFO(JNI_LOG_TAG, __VA_ARGS__) +#define LOGe(...) RAC_LOG_ERROR(JNI_LOG_TAG, __VA_ARGS__) +#define LOGw(...) RAC_LOG_WARNING(JNI_LOG_TAG, __VA_ARGS__) +#define LOGd(...) RAC_LOG_DEBUG(JNI_LOG_TAG, __VA_ARGS__) // ============================================================================= // Global State for Platform Adapter JNI Callbacks @@ -139,6 +131,8 @@ static std::string getCString(JNIEnv* env, jstring str) { if (str == nullptr) return ""; const char* chars = env->GetStringUTFChars(str, nullptr); + if (chars == nullptr) + return ""; std::string result(chars); env->ReleaseStringUTFChars(str, chars); return result; @@ -162,8 +156,14 @@ static void jni_log_callback(rac_log_level_t level, const char* tag, const char* void* user_data) { JNIEnv* env = getJNIEnv(); if (env == nullptr || g_platform_adapter == nullptr || g_method_log == nullptr) { - // Fallback to native logging - LOGd("[%s] %s", tag ? tag : "RAC", message ? message : ""); + // Fallback to direct native logging (NOT through RAC_LOG_* to avoid recursion, + // since this function IS the platform adapter's log callback) +#ifdef __ANDROID__ + __android_log_print(ANDROID_LOG_DEBUG, "RACCommonsJNI", "[%s] %s", + tag ? tag : "RAC", message ? message : ""); +#else + fprintf(stdout, "[DEBUG] [%s] %s\n", tag ? tag : "RAC", message ? message : ""); +#endif return; } @@ -208,8 +208,13 @@ static rac_result_t jni_file_read_callback(const char* path, void** out_data, si } jsize len = env->GetArrayLength(result); - *out_size = static_cast(len); *out_data = malloc(len); + if (*out_data == nullptr) { + *out_size = 0; + env->DeleteLocalRef(result); + return RAC_ERROR_OUT_OF_MEMORY; + } + *out_size = static_cast(len); env->GetByteArrayRegion(result, 0, len, reinterpret_cast(*out_data)); env->DeleteLocalRef(result); @@ -266,10 +271,18 @@ static rac_result_t jni_secure_get_callback(const char* key, char** out_value, v } const char* chars = env->GetStringUTFChars(result, nullptr); + if (!chars) { + env->DeleteLocalRef(result); + *out_value = nullptr; + return RAC_ERROR_INTERNAL; + } *out_value = strdup(chars); env->ReleaseStringUTFChars(result, chars); env->DeleteLocalRef(result); + if (!*out_value) { + return RAC_ERROR_OUT_OF_MEMORY; + } return RAC_SUCCESS; } @@ -605,6 +618,7 @@ Java_com_runanywhere_sdk_native_bridge_RunAnywhereBridge_racLlmComponentGenerate } LOGw("racLlmComponentGenerate: result.text is null"); + rac_llm_result_free(&result); return env->NewStringUTF("{\"text\":\"\",\"completion_tokens\":0}"); } @@ -1237,6 +1251,19 @@ Java_com_runanywhere_sdk_native_bridge_RunAnywhereBridge_racLoraRegistryRegister entry.description = desc_str ? strdup(desc_str) : nullptr; entry.download_url = url_str ? strdup(url_str) : nullptr; entry.filename = file_str ? strdup(file_str) : nullptr; + + // Check mandatory field allocation + if (id_str && !entry.id) { + free(entry.name); free(entry.description); + free(entry.download_url); free(entry.filename); + if (id_str) env->ReleaseStringUTFChars(id, id_str); + if (name_str) env->ReleaseStringUTFChars(name, name_str); + if (desc_str) env->ReleaseStringUTFChars(description, desc_str); + if (url_str) env->ReleaseStringUTFChars(downloadUrl, url_str); + if (file_str) env->ReleaseStringUTFChars(filename, file_str); + return RAC_ERROR_OUT_OF_MEMORY; + } + entry.file_size = fileSize; entry.default_scale = defaultScale; @@ -1449,6 +1476,7 @@ Java_com_runanywhere_sdk_native_bridge_RunAnywhereBridge_racSttComponentTranscri if (status != RAC_SUCCESS) { LOGe("STT transcribe failed with status: %d", status); + rac_stt_result_free(&result); return nullptr; } @@ -1604,6 +1632,7 @@ Java_com_runanywhere_sdk_native_bridge_RunAnywhereBridge_racTtsComponentSynthesi textStr.c_str(), &options, &result); if (status != RAC_SUCCESS || result.audio_data == nullptr) { + rac_tts_result_free(&result); return nullptr; } @@ -1884,15 +1913,19 @@ static rac_model_info_t* javaModelInfoToC(JNIEnv* env, jobject modelInfo) { jstring jId = (jstring)env->GetObjectField(modelInfo, idField); if (jId) { const char* str = env->GetStringUTFChars(jId, nullptr); - model->id = strdup(str); - env->ReleaseStringUTFChars(jId, str); + if (str) { + model->id = strdup(str); + env->ReleaseStringUTFChars(jId, str); + } } jstring jName = (jstring)env->GetObjectField(modelInfo, nameField); if (jName) { const char* str = env->GetStringUTFChars(jName, nullptr); - model->name = strdup(str); - env->ReleaseStringUTFChars(jName, str); + if (str) { + model->name = strdup(str); + env->ReleaseStringUTFChars(jName, str); + } } model->category = static_cast(env->GetIntField(modelInfo, categoryField)); @@ -1903,15 +1936,19 @@ static rac_model_info_t* javaModelInfoToC(JNIEnv* env, jobject modelInfo) { jstring jDownloadUrl = (jstring)env->GetObjectField(modelInfo, downloadUrlField); if (jDownloadUrl) { const char* str = env->GetStringUTFChars(jDownloadUrl, nullptr); - model->download_url = strdup(str); - env->ReleaseStringUTFChars(jDownloadUrl, str); + if (str) { + model->download_url = strdup(str); + env->ReleaseStringUTFChars(jDownloadUrl, str); + } } jstring jLocalPath = (jstring)env->GetObjectField(modelInfo, localPathField); if (jLocalPath) { const char* str = env->GetStringUTFChars(jLocalPath, nullptr); - model->local_path = strdup(str); - env->ReleaseStringUTFChars(jLocalPath, str); + if (str) { + model->local_path = strdup(str); + env->ReleaseStringUTFChars(jLocalPath, str); + } } model->download_size = env->GetLongField(modelInfo, downloadSizeField); @@ -1922,8 +1959,16 @@ static rac_model_info_t* javaModelInfoToC(JNIEnv* env, jobject modelInfo) { jstring jDesc = (jstring)env->GetObjectField(modelInfo, descriptionField); if (jDesc) { const char* str = env->GetStringUTFChars(jDesc, nullptr); - model->description = strdup(str); - env->ReleaseStringUTFChars(jDesc, str); + if (str) { + model->description = strdup(str); + env->ReleaseStringUTFChars(jDesc, str); + } + } + + // Verify mandatory field allocation (id is required for all model operations) + if (jId && !model->id) { + rac_model_info_free(model); + return nullptr; } return model; @@ -2023,6 +2068,13 @@ Java_com_runanywhere_sdk_native_bridge_RunAnywhereBridge_racModelRegistrySave( if (desc_str) env->ReleaseStringUTFChars(description, desc_str); + // Check mandatory field allocation + if ((id_str && !model->id) || (name_str && !model->name)) { + LOGe("OOM: failed to allocate mandatory model fields"); + rac_model_info_free(model); + return RAC_ERROR_OUT_OF_MEMORY; + } + LOGi("Saving model to C++ registry: %s (framework=%d)", model->id, framework); rac_result_t result = rac_model_registry_save(registry, model); @@ -2256,7 +2308,12 @@ static rac_result_t model_assignment_http_get_callback(const char* endpoint, out_response->result = RAC_SUCCESS; out_response->status_code = 200; out_response->response_body = strdup(response_str); - out_response->response_length = strlen(response_str); + if (!out_response->response_body) { + out_response->result = RAC_ERROR_OUT_OF_MEMORY; + result = RAC_ERROR_OUT_OF_MEMORY; + } else { + out_response->response_length = strlen(response_str); + } } } env->ReleaseStringUTFChars(jResponse, response_str); @@ -2681,6 +2738,10 @@ static const char* jni_device_get_id(void* user_data) { if (jResult) { const char* str = env->GetStringUTFChars(jResult, nullptr); + if (str == nullptr) { + env->DeleteLocalRef(jResult); + return ""; + } // Lock mutex to protect g_cached_device_id from concurrent access std::lock_guard lock(g_device_jni_state.mtx); @@ -3707,7 +3768,10 @@ static void fillVlmImage(rac_vlm_image_t& image, case RAC_VLM_IMAGE_FORMAT_RGB_PIXELS: if (imageData != nullptr) { jsize len = env->GetArrayLength(imageData); - auto* buf = new uint8_t[len]; + auto* buf = new (std::nothrow) uint8_t[len]; + if (!buf) { + return; + } env->GetByteArrayRegion(imageData, 0, len, reinterpret_cast(buf)); image.pixel_data = buf; image.data_size = static_cast(len); @@ -3875,6 +3939,7 @@ Java_com_runanywhere_sdk_native_bridge_RunAnywhereBridge_racVlmComponentProcess( if (status != RAC_SUCCESS) { LOGe("racVlmComponentProcess failed with status=%d", status); + rac_vlm_result_free(&result); return nullptr; } @@ -4135,6 +4200,399 @@ Java_com_runanywhere_sdk_native_bridge_RunAnywhereBridge_racVlmComponentGetMetri return env->NewStringUTF(json.c_str()); } +// ============================================================================= +// ARCHIVE EXTRACTION +// ============================================================================= + +JNIEXPORT jint JNICALL +Java_com_runanywhere_sdk_native_1bridge_RunAnywhereBridge_nativeExtractArchive(JNIEnv* env, + jobject /* thiz */, + jstring jArchivePath, + jstring jDestDir) { + std::string archivePath = getCString(env, jArchivePath); + std::string destDir = getCString(env, jDestDir); + + LOGi("Extracting archive: %s -> %s", archivePath.c_str(), destDir.c_str()); + + rac_extraction_result_t result = {}; + rac_result_t status = + rac_extract_archive_native(archivePath.c_str(), destDir.c_str(), nullptr /* default options */, + nullptr /* no progress */, nullptr, &result); + + if (RAC_SUCCEEDED(status)) { + LOGi("Extraction complete: %d files, %lld bytes", result.files_extracted, + static_cast(result.bytes_extracted)); + } else { + LOGe("Extraction failed with code: %d", status); + } + + return static_cast(status); +} + +JNIEXPORT jint JNICALL +Java_com_runanywhere_sdk_native_1bridge_RunAnywhereBridge_nativeDetectArchiveType( + JNIEnv* env, jobject /* thiz */, jstring jFilePath) { + std::string filePath = getCString(env, jFilePath); + rac_archive_type_t type = RAC_ARCHIVE_TYPE_NONE; + rac_detect_archive_type(filePath.c_str(), &type); + return static_cast(type); +} + +// ============================================================================= +// DOWNLOAD ORCHESTRATOR +// ============================================================================= + +JNIEXPORT jstring JNICALL +Java_com_runanywhere_sdk_native_1bridge_RunAnywhereBridge_nativeFindModelPathAfterExtraction( + JNIEnv* env, jobject /* thiz */, jstring jExtractedDir, jint jStructure, jint jFramework, + jint jFormat) { + std::string extractedDir = getCString(env, jExtractedDir); + + char outPath[4096]; + rac_result_t result = rac_find_model_path_after_extraction( + extractedDir.c_str(), static_cast(jStructure), + static_cast(jFramework), + static_cast(jFormat), outPath, sizeof(outPath)); + + if (RAC_SUCCEEDED(result)) { + return env->NewStringUTF(outPath); + } + return env->NewStringUTF(extractedDir.c_str()); +} + +JNIEXPORT jboolean JNICALL +Java_com_runanywhere_sdk_native_1bridge_RunAnywhereBridge_nativeDownloadRequiresExtraction( + JNIEnv* env, jobject /* thiz */, jstring jUrl) { + std::string url = getCString(env, jUrl); + return static_cast(rac_download_requires_extraction(url.c_str()) == RAC_TRUE); +} + +JNIEXPORT jstring JNICALL +Java_com_runanywhere_sdk_native_1bridge_RunAnywhereBridge_nativeComputeDownloadDestination( + JNIEnv* env, jobject /* thiz */, jstring jModelId, jstring jDownloadUrl, jint jFramework, + jint jFormat) { + std::string modelId = getCString(env, jModelId); + std::string downloadUrl = getCString(env, jDownloadUrl); + + char outPath[4096]; + rac_bool_t needsExtraction = RAC_FALSE; + rac_result_t result = rac_download_compute_destination( + modelId.c_str(), downloadUrl.c_str(), + static_cast(jFramework), + static_cast(jFormat), outPath, sizeof(outPath), &needsExtraction); + + if (RAC_SUCCEEDED(result)) { + return env->NewStringUTF(outPath); + } + return nullptr; +} + +// ============================================================================= +// File Manager JNI Wrappers +// ============================================================================= + +// Global reference for file callbacks object +static jobject g_file_callbacks_obj = nullptr; +static jmethodID g_fc_create_directory = nullptr; +static jmethodID g_fc_delete_path = nullptr; +static jmethodID g_fc_list_directory = nullptr; +static jmethodID g_fc_path_exists = nullptr; +static jmethodID g_fc_is_directory = nullptr; +static jmethodID g_fc_get_file_size = nullptr; +static jmethodID g_fc_get_available_space = nullptr; +static jmethodID g_fc_get_total_space = nullptr; + +// JNI file callback implementations +static rac_result_t jni_fc_create_directory(const char* path, int recursive, void* user_data) { + JNIEnv* env = getJNIEnv(); + if (env == nullptr || g_file_callbacks_obj == nullptr) return RAC_ERROR_NOT_INITIALIZED; + jstring jPath = env->NewStringUTF(path); + jint result = env->CallIntMethod(g_file_callbacks_obj, g_fc_create_directory, jPath, + static_cast(recursive != 0)); + env->DeleteLocalRef(jPath); + return static_cast(result); +} + +static rac_result_t jni_fc_delete_path(const char* path, int recursive, void* user_data) { + JNIEnv* env = getJNIEnv(); + if (env == nullptr || g_file_callbacks_obj == nullptr) return RAC_ERROR_NOT_INITIALIZED; + jstring jPath = env->NewStringUTF(path); + jint result = env->CallIntMethod(g_file_callbacks_obj, g_fc_delete_path, jPath, + static_cast(recursive != 0)); + env->DeleteLocalRef(jPath); + return static_cast(result); +} + +static rac_result_t jni_fc_list_directory(const char* path, char*** out_entries, size_t* out_count, + void* user_data) { + JNIEnv* env = getJNIEnv(); + if (env == nullptr || g_file_callbacks_obj == nullptr) return RAC_ERROR_NOT_INITIALIZED; + + jstring jPath = env->NewStringUTF(path); + jobjectArray jEntries = static_cast( + env->CallObjectMethod(g_file_callbacks_obj, g_fc_list_directory, jPath)); + env->DeleteLocalRef(jPath); + + if (jEntries == nullptr) { + *out_entries = nullptr; + *out_count = 0; + return RAC_ERROR_FILE_NOT_FOUND; + } + + jsize count = env->GetArrayLength(jEntries); + auto** entries = static_cast(std::malloc(count * sizeof(char*))); + if (entries == nullptr) { + env->DeleteLocalRef(jEntries); + return RAC_ERROR_OUT_OF_MEMORY; + } + + for (jsize i = 0; i < count; i++) { + auto jEntry = static_cast(env->GetObjectArrayElement(jEntries, i)); + const char* entryChars = env->GetStringUTFChars(jEntry, nullptr); + entries[i] = strdup(entryChars); + env->ReleaseStringUTFChars(jEntry, entryChars); + env->DeleteLocalRef(jEntry); + } + + env->DeleteLocalRef(jEntries); + *out_entries = entries; + *out_count = static_cast(count); + return RAC_SUCCESS; +} + +static void jni_fc_free_entries(char** entries, size_t count, void* user_data) { + if (entries == nullptr) return; + for (size_t i = 0; i < count; i++) { + std::free(entries[i]); + } + std::free(entries); +} + +static rac_bool_t jni_fc_path_exists(const char* path, rac_bool_t* out_is_directory, + void* user_data) { + JNIEnv* env = getJNIEnv(); + if (env == nullptr || g_file_callbacks_obj == nullptr) return RAC_FALSE; + + jstring jPath = env->NewStringUTF(path); + jboolean exists = env->CallBooleanMethod(g_file_callbacks_obj, g_fc_path_exists, jPath); + + if (out_is_directory != nullptr && exists) { + jboolean isDir = env->CallBooleanMethod(g_file_callbacks_obj, g_fc_is_directory, jPath); + *out_is_directory = isDir ? RAC_TRUE : RAC_FALSE; + } + + env->DeleteLocalRef(jPath); + return exists ? RAC_TRUE : RAC_FALSE; +} + +static int64_t jni_fc_get_file_size(const char* path, void* user_data) { + JNIEnv* env = getJNIEnv(); + if (env == nullptr || g_file_callbacks_obj == nullptr) return -1; + jstring jPath = env->NewStringUTF(path); + jlong size = env->CallLongMethod(g_file_callbacks_obj, g_fc_get_file_size, jPath); + env->DeleteLocalRef(jPath); + return static_cast(size); +} + +static int64_t jni_fc_get_available_space(void* user_data) { + JNIEnv* env = getJNIEnv(); + if (env == nullptr || g_file_callbacks_obj == nullptr) return -1; + return static_cast( + env->CallLongMethod(g_file_callbacks_obj, g_fc_get_available_space)); +} + +static int64_t jni_fc_get_total_space(void* user_data) { + JNIEnv* env = getJNIEnv(); + if (env == nullptr || g_file_callbacks_obj == nullptr) return -1; + return static_cast(env->CallLongMethod(g_file_callbacks_obj, g_fc_get_total_space)); +} + +/** + * Build rac_file_callbacks_t from registered JNI callbacks. + */ +static rac_file_callbacks_t build_jni_file_callbacks() { + rac_file_callbacks_t cb = {}; + cb.create_directory = jni_fc_create_directory; + cb.delete_path = jni_fc_delete_path; + cb.list_directory = jni_fc_list_directory; + cb.free_entries = jni_fc_free_entries; + cb.path_exists = jni_fc_path_exists; + cb.get_file_size = jni_fc_get_file_size; + cb.get_available_space = jni_fc_get_available_space; + cb.get_total_space = jni_fc_get_total_space; + cb.user_data = nullptr; + return cb; +} + +// Register file callbacks object from Kotlin +JNIEXPORT jint JNICALL +Java_com_runanywhere_sdk_native_1bridge_RunAnywhereBridge_nativeFileManagerRegisterCallbacks( + JNIEnv* env, jobject /* thiz */, jobject callbacksObj) { + if (callbacksObj == nullptr) return RAC_ERROR_NULL_POINTER; + + // Store global reference + if (g_file_callbacks_obj != nullptr) { + env->DeleteGlobalRef(g_file_callbacks_obj); + } + g_file_callbacks_obj = env->NewGlobalRef(callbacksObj); + + // Cache method IDs + jclass cls = env->GetObjectClass(callbacksObj); + g_fc_create_directory = env->GetMethodID(cls, "createDirectory", "(Ljava/lang/String;Z)I"); + g_fc_delete_path = env->GetMethodID(cls, "deletePath", "(Ljava/lang/String;Z)I"); + g_fc_list_directory = + env->GetMethodID(cls, "listDirectory", "(Ljava/lang/String;)[Ljava/lang/String;"); + g_fc_path_exists = env->GetMethodID(cls, "pathExists", "(Ljava/lang/String;)Z"); + g_fc_is_directory = env->GetMethodID(cls, "isDirectory", "(Ljava/lang/String;)Z"); + g_fc_get_file_size = env->GetMethodID(cls, "getFileSize", "(Ljava/lang/String;)J"); + g_fc_get_available_space = env->GetMethodID(cls, "getAvailableSpace", "()J"); + g_fc_get_total_space = env->GetMethodID(cls, "getTotalSpace", "()J"); + env->DeleteLocalRef(cls); + + LOGi("File manager callbacks registered"); + return RAC_SUCCESS; +} + +// Create directory structure +JNIEXPORT jint JNICALL +Java_com_runanywhere_sdk_native_1bridge_RunAnywhereBridge_nativeFileManagerCreateDirectoryStructure( + JNIEnv* env, jobject /* thiz */) { + rac_file_callbacks_t cb = build_jni_file_callbacks(); + return static_cast(rac_file_manager_create_directory_structure(&cb)); +} + +// Calculate directory size +JNIEXPORT jlong JNICALL +Java_com_runanywhere_sdk_native_1bridge_RunAnywhereBridge_nativeFileManagerCalculateDirSize( + JNIEnv* env, jobject /* thiz */, jstring jPath) { + std::string path = getCString(env, jPath); + rac_file_callbacks_t cb = build_jni_file_callbacks(); + int64_t size = 0; + rac_result_t result = rac_file_manager_calculate_dir_size(&cb, path.c_str(), &size); + return RAC_SUCCEEDED(result) ? static_cast(size) : 0L; +} + +// Models storage used +JNIEXPORT jlong JNICALL +Java_com_runanywhere_sdk_native_1bridge_RunAnywhereBridge_nativeFileManagerModelsStorageUsed( + JNIEnv* env, jobject /* thiz */) { + rac_file_callbacks_t cb = build_jni_file_callbacks(); + int64_t size = 0; + rac_result_t result = rac_file_manager_models_storage_used(&cb, &size); + return RAC_SUCCEEDED(result) ? static_cast(size) : 0L; +} + +// Clear cache +JNIEXPORT jint JNICALL +Java_com_runanywhere_sdk_native_1bridge_RunAnywhereBridge_nativeFileManagerClearCache( + JNIEnv* env, jobject /* thiz */) { + rac_file_callbacks_t cb = build_jni_file_callbacks(); + return static_cast(rac_file_manager_clear_cache(&cb)); +} + +// Clear temp +JNIEXPORT jint JNICALL +Java_com_runanywhere_sdk_native_1bridge_RunAnywhereBridge_nativeFileManagerClearTemp( + JNIEnv* env, jobject /* thiz */) { + rac_file_callbacks_t cb = build_jni_file_callbacks(); + return static_cast(rac_file_manager_clear_temp(&cb)); +} + +// Cache size +JNIEXPORT jlong JNICALL +Java_com_runanywhere_sdk_native_1bridge_RunAnywhereBridge_nativeFileManagerCacheSize( + JNIEnv* env, jobject /* thiz */) { + rac_file_callbacks_t cb = build_jni_file_callbacks(); + int64_t size = 0; + rac_result_t result = rac_file_manager_cache_size(&cb, &size); + return RAC_SUCCEEDED(result) ? static_cast(size) : 0L; +} + +// Delete model +JNIEXPORT jint JNICALL +Java_com_runanywhere_sdk_native_1bridge_RunAnywhereBridge_nativeFileManagerDeleteModel( + JNIEnv* env, jobject /* thiz */, jstring jModelId, jint jFramework) { + std::string modelId = getCString(env, jModelId); + rac_file_callbacks_t cb = build_jni_file_callbacks(); + return static_cast(rac_file_manager_delete_model( + &cb, modelId.c_str(), static_cast(jFramework))); +} + +// Create model folder +JNIEXPORT jstring JNICALL +Java_com_runanywhere_sdk_native_1bridge_RunAnywhereBridge_nativeFileManagerCreateModelFolder( + JNIEnv* env, jobject /* thiz */, jstring jModelId, jint jFramework) { + std::string modelId = getCString(env, jModelId); + rac_file_callbacks_t cb = build_jni_file_callbacks(); + char outPath[4096]; + rac_result_t result = rac_file_manager_create_model_folder( + &cb, modelId.c_str(), static_cast(jFramework), outPath, + sizeof(outPath)); + if (RAC_SUCCEEDED(result)) { + return env->NewStringUTF(outPath); + } + return nullptr; +} + +// Model folder exists +JNIEXPORT jboolean JNICALL +Java_com_runanywhere_sdk_native_1bridge_RunAnywhereBridge_nativeFileManagerModelFolderExists( + JNIEnv* env, jobject /* thiz */, jstring jModelId, jint jFramework) { + std::string modelId = getCString(env, jModelId); + rac_file_callbacks_t cb = build_jni_file_callbacks(); + rac_bool_t exists = RAC_FALSE; + rac_file_manager_model_folder_exists(&cb, modelId.c_str(), + static_cast(jFramework), + &exists, nullptr); + return static_cast(exists == RAC_TRUE); +} + +// Check storage availability - returns JSON string with result +JNIEXPORT jstring JNICALL +Java_com_runanywhere_sdk_native_1bridge_RunAnywhereBridge_nativeFileManagerCheckStorage( + JNIEnv* env, jobject /* thiz */, jlong requiredBytes) { + rac_file_callbacks_t cb = build_jni_file_callbacks(); + rac_storage_availability_t availability = {}; + rac_result_t result = + rac_file_manager_check_storage(&cb, static_cast(requiredBytes), &availability); + + if (RAC_FAILED(result)) return nullptr; + + nlohmann::json j; + j["isAvailable"] = availability.is_available == RAC_TRUE; + j["requiredSpace"] = availability.required_space; + j["availableSpace"] = availability.available_space; + j["hasWarning"] = availability.has_warning == RAC_TRUE; + j["recommendation"] = availability.recommendation != nullptr ? availability.recommendation : ""; + + rac_storage_availability_free(&availability); + + std::string jsonStr = j.dump(); + return env->NewStringUTF(jsonStr.c_str()); +} + +// Get storage info - returns JSON string with result +JNIEXPORT jstring JNICALL +Java_com_runanywhere_sdk_native_1bridge_RunAnywhereBridge_nativeFileManagerGetStorageInfo( + JNIEnv* env, jobject /* thiz */) { + rac_file_callbacks_t cb = build_jni_file_callbacks(); + rac_file_manager_storage_info_t info = {}; + rac_result_t result = rac_file_manager_get_storage_info(&cb, &info); + + if (RAC_FAILED(result)) return nullptr; + + nlohmann::json j; + j["deviceTotal"] = info.device_total; + j["deviceFree"] = info.device_free; + j["modelsSize"] = info.models_size; + j["cacheSize"] = info.cache_size; + j["tempSize"] = info.temp_size; + j["totalAppSize"] = info.total_app_size; + + std::string jsonStr = j.dump(); + return env->NewStringUTF(jsonStr.c_str()); +} + } // extern "C" // ============================================================================= diff --git a/sdk/runanywhere-commons/src/server/CMakeLists.txt b/sdk/runanywhere-commons/src/server/CMakeLists.txt index 7946f7c83..909528f6f 100644 --- a/sdk/runanywhere-commons/src/server/CMakeLists.txt +++ b/sdk/runanywhere-commons/src/server/CMakeLists.txt @@ -66,7 +66,7 @@ if(UNIX AND NOT APPLE) endif() # Compiler options -target_compile_features(rac_server PUBLIC cxx_std_17) +target_compile_features(rac_server PUBLIC cxx_std_20) if(CMAKE_CXX_COMPILER_ID MATCHES "Clang|GNU") target_compile_options(rac_server PRIVATE -Wall -Wextra -Wpedantic) diff --git a/sdk/runanywhere-commons/src/server/http_server.cpp b/sdk/runanywhere-commons/src/server/http_server.cpp index 5645d2bf6..a6fef5b99 100644 --- a/sdk/runanywhere-commons/src/server/http_server.cpp +++ b/sdk/runanywhere-commons/src/server/http_server.cpp @@ -33,16 +33,8 @@ int64_t getCurrentTimestamp() { } std::string extractModelIdFromPath(const std::string& path) { - std::filesystem::path fsPath(path); - std::string filename = fsPath.stem().string(); - - // Remove common extensions like .gguf - if (fsPath.extension() == ".gguf") { - // Already handled by stem() - } - - // Clean up the name - return filename; + const std::filesystem::path fsPath(path); + return fsPath.stem().string(); } // ============================================================================= @@ -63,60 +55,67 @@ HttpServer::~HttpServer() { } rac_result_t HttpServer::start(const rac_server_config_t& config) { - std::lock_guard lock(mutex_); + static constexpr int SERVER_START_POLL_ITERATIONS = 100; + static constexpr int SERVER_START_POLL_MS = 100; - if (running_) { - return RAC_ERROR_SERVER_ALREADY_RUNNING; - } + { + std::lock_guard lock(mutex_); - // Validate config - if (!config.model_path) { - RAC_LOG_ERROR("Server", "model_path is required"); - return RAC_ERROR_INVALID_ARGUMENT; - } + if (running_) { + return RAC_ERROR_SERVER_ALREADY_RUNNING; + } - // Check if model file exists - if (!std::filesystem::exists(config.model_path)) { - RAC_LOG_ERROR("Server", "Model file not found: %s", config.model_path); - return RAC_ERROR_SERVER_MODEL_NOT_FOUND; - } + // Validate config + if (!config.model_path) { + RAC_LOG_ERROR("Server", "model_path is required"); + return RAC_ERROR_INVALID_ARGUMENT; + } - // Copy configuration - config_ = config; - host_ = config.host ? config.host : "127.0.0.1"; - modelPath_ = config.model_path; - modelId_ = config.model_id ? config.model_id : extractModelIdFromPath(modelPath_); + // Check if model file exists (use error_code overload to avoid exceptions) + std::error_code ec; + if (!std::filesystem::exists(config.model_path, ec) || ec) { + RAC_LOG_ERROR("Server", "Model file not found: %s", config.model_path); + return RAC_ERROR_SERVER_MODEL_NOT_FOUND; + } - // Load the model - rac_result_t rc = loadModel(modelPath_); - if (RAC_FAILED(rc)) { - return rc; - } + // Copy configuration + config_ = config; + host_ = config.host ? config.host : "127.0.0.1"; + modelPath_ = config.model_path; + modelId_ = config.model_id ? config.model_id : extractModelIdFromPath(modelPath_); - // Create HTTP server - server_ = std::make_unique(); + // Load the model + rac_result_t rc = loadModel(modelPath_); + if (RAC_FAILED(rc)) { + return rc; + } - // Setup CORS if enabled - if (config.enable_cors == RAC_TRUE) { - setupCors(); - } + // Create HTTP server + server_ = std::make_unique(); - // Setup routes - setupRoutes(); + // Setup CORS if enabled + if (config.enable_cors == RAC_TRUE) { + setupCors(); + } + + // Setup routes + setupRoutes(); - // Reset state - shouldStop_ = false; - activeRequests_ = 0; - totalRequests_ = 0; - totalTokensGenerated_ = 0; - startTime_ = std::chrono::steady_clock::now(); + // Reset state + shouldStop_ = false; + activeRequests_ = 0; + totalRequests_ = 0; + totalTokensGenerated_ = 0; + startTime_ = std::chrono::steady_clock::now(); - // Start server thread - serverThread_ = std::thread(&HttpServer::serverThread, this); + // Start server thread + serverThread_ = std::thread(&HttpServer::serverThread, this); + } + // Lock released — running_ and shouldStop_ are atomic, safe to poll without lock // Wait for server to be ready (with timeout) - for (int i = 0; i < 100; ++i) { // 10 second timeout - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + for (int i = 0; i < SERVER_START_POLL_ITERATIONS; ++i) { + std::this_thread::sleep_for(std::chrono::milliseconds(SERVER_START_POLL_MS)); if (running_) { RAC_LOG_INFO("Server", "RunAnywhere Server started on http://%s:%d", host_.c_str(), config_.port); @@ -125,12 +124,15 @@ rac_result_t HttpServer::start(const rac_server_config_t& config) { } } - // Timeout - something went wrong + // Timeout - clean up (shouldStop_ is atomic) shouldStop_ = true; if (serverThread_.joinable()) { serverThread_.join(); } + + std::lock_guard lock(mutex_); unloadModel(); + server_.reset(); RAC_LOG_ERROR("Server", "Failed to start server"); return RAC_ERROR_SERVER_BIND_FAILED; @@ -172,10 +174,16 @@ bool HttpServer::isRunning() const { void HttpServer::getStatus(rac_server_status_t& status) const { std::lock_guard lock(mutex_); + // Use thread_local copies so c_str() pointers remain valid after lock release + thread_local std::string tl_host; + thread_local std::string tl_model_id; + tl_host = host_; + tl_model_id = modelId_; + status.is_running = running_ ? RAC_TRUE : RAC_FALSE; - status.host = host_.c_str(); + status.host = tl_host.c_str(); status.port = config_.port; - status.model_id = modelId_.c_str(); + status.model_id = tl_model_id.c_str(); status.active_requests = activeRequests_; status.total_requests = totalRequests_; status.total_tokens_generated = totalTokensGenerated_; @@ -197,11 +205,13 @@ int HttpServer::wait() { } void HttpServer::setRequestCallback(rac_server_request_callback_fn callback, void* userData) { + std::lock_guard lock(callback_mutex_); requestCallback_ = callback; requestCallbackUserData_ = userData; } void HttpServer::setErrorCallback(rac_server_error_callback_fn callback, void* userData) { + std::lock_guard lock(callback_mutex_); errorCallback_ = callback; errorCallbackUserData_ = userData; } @@ -213,8 +223,11 @@ void HttpServer::setupRoutes() { // GET /v1/models server_->Get("/v1/models", [this, handler](const httplib::Request& req, httplib::Response& res) { totalRequests_++; - if (requestCallback_) { - requestCallback_("GET", "/v1/models", requestCallbackUserData_); + { + std::lock_guard lock(callback_mutex_); + if (requestCallback_) { + requestCallback_("GET", "/v1/models", requestCallbackUserData_); + } } handler->handleModels(req, res); }); @@ -224,16 +237,22 @@ void HttpServer::setupRoutes() { totalRequests_++; activeRequests_++; - if (requestCallback_) { - requestCallback_("POST", "/v1/chat/completions", requestCallbackUserData_); + { + std::lock_guard lock(callback_mutex_); + if (requestCallback_) { + requestCallback_("POST", "/v1/chat/completions", requestCallbackUserData_); + } } try { handler->handleChatCompletions(req, res); } catch (const std::exception& e) { RAC_LOG_ERROR("Server", "Error handling chat completions: %s", e.what()); - if (errorCallback_) { - errorCallback_("/v1/chat/completions", RAC_ERROR_UNKNOWN, e.what(), errorCallbackUserData_); + { + std::lock_guard lock(callback_mutex_); + if (errorCallback_) { + errorCallback_("/v1/chat/completions", RAC_ERROR_UNKNOWN, e.what(), errorCallbackUserData_); + } } res.status = 500; res.set_content("{\"error\": {\"message\": \"Internal server error\"}}", "application/json"); @@ -354,27 +373,56 @@ RAC_API rac_result_t rac_server_start(const rac_server_config_t* config) { if (!config) { return RAC_ERROR_INVALID_ARGUMENT; } - return rac::server::HttpServer::instance().start(*config); + try { + return rac::server::HttpServer::instance().start(*config); + } catch (const std::exception& e) { + RAC_LOG_ERROR("Server", "Failed to start: %s", e.what()); + return RAC_ERROR_INTERNAL; + } catch (...) { + return RAC_ERROR_INTERNAL; + } } RAC_API rac_result_t rac_server_stop(void) { - return rac::server::HttpServer::instance().stop(); + try { + return rac::server::HttpServer::instance().stop(); + } catch (const std::exception& e) { + RAC_LOG_ERROR("Server", "Failed to stop: %s", e.what()); + return RAC_ERROR_INTERNAL; + } catch (...) { + return RAC_ERROR_INTERNAL; + } } RAC_API rac_bool_t rac_server_is_running(void) { - return rac::server::HttpServer::instance().isRunning() ? RAC_TRUE : RAC_FALSE; + try { + return rac::server::HttpServer::instance().isRunning() ? RAC_TRUE : RAC_FALSE; + } catch (...) { + return RAC_FALSE; + } } RAC_API rac_result_t rac_server_get_status(rac_server_status_t* status) { if (!status) { return RAC_ERROR_INVALID_ARGUMENT; } - rac::server::HttpServer::instance().getStatus(*status); - return RAC_SUCCESS; + try { + rac::server::HttpServer::instance().getStatus(*status); + return RAC_SUCCESS; + } catch (const std::exception& e) { + RAC_LOG_ERROR("Server", "Failed to get status: %s", e.what()); + return RAC_ERROR_INTERNAL; + } catch (...) { + return RAC_ERROR_INTERNAL; + } } RAC_API int rac_server_wait(void) { - return rac::server::HttpServer::instance().wait(); + try { + return rac::server::HttpServer::instance().wait(); + } catch (...) { + return -1; + } } RAC_API void rac_server_set_request_callback(rac_server_request_callback_fn callback, diff --git a/sdk/runanywhere-commons/src/server/http_server.h b/sdk/runanywhere-commons/src/server/http_server.h index 019854caa..c5372d77c 100644 --- a/sdk/runanywhere-commons/src/server/http_server.h +++ b/sdk/runanywhere-commons/src/server/http_server.h @@ -124,6 +124,7 @@ class HttpServer { std::atomic running_{false}; std::atomic shouldStop_{false}; mutable std::mutex mutex_; + mutable std::mutex callback_mutex_; // Protects callback pointers // Configuration (copied on start) rac_server_config_t config_; diff --git a/sdk/runanywhere-commons/src/server/openai_handler.cpp b/sdk/runanywhere-commons/src/server/openai_handler.cpp index 472f3057d..d7044ca7e 100644 --- a/sdk/runanywhere-commons/src/server/openai_handler.cpp +++ b/sdk/runanywhere-commons/src/server/openai_handler.cpp @@ -151,6 +151,7 @@ void OpenAIHandler::processNonStreaming(const httplib::Request& /*req*/, RAC_LOG_INFO("Server", "processNonStreaming: rac_llm_llamacpp_generate returned rc=%d", rc); if (RAC_FAILED(rc)) { + rac_llm_result_free(&result); sendError(res, 500, "Generation failed", "server_error"); return; } diff --git a/sdk/runanywhere-commons/src/utils/rac_image_utils.cpp b/sdk/runanywhere-commons/src/utils/rac_image_utils.cpp index 3e06a2675..b536ad3ed 100644 --- a/sdk/runanywhere-commons/src/utils/rac_image_utils.cpp +++ b/sdk/runanywhere-commons/src/utils/rac_image_utils.cpp @@ -117,8 +117,8 @@ std::vector base64_decode(const char* data, size_t len) { */ void bilinear_resize(const uint8_t* src, int src_w, int src_h, uint8_t* dst, int dst_w, int dst_h, int channels) { - float x_ratio = static_cast(src_w - 1) / static_cast(dst_w - 1); - float y_ratio = static_cast(src_h - 1) / static_cast(dst_h - 1); + float x_ratio = (dst_w > 1) ? static_cast(src_w - 1) / static_cast(dst_w - 1) : 0.0f; + float y_ratio = (dst_h > 1) ? static_cast(src_h - 1) / static_cast(dst_h - 1) : 0.0f; for (int y = 0; y < dst_h; y++) { for (int x = 0; x < dst_w; x++) { @@ -313,6 +313,10 @@ rac_result_t rac_image_normalize(const rac_image_data_t* image, const float* mea return RAC_ERROR_NULL_POINTER; } + if (image->channels < 1 || image->channels > 3) { + return RAC_ERROR_INVALID_PARAMETER; + } + memset(out_float, 0, sizeof(rac_image_float_t)); // Default mean and std (ImageNet-style normalization) @@ -486,6 +490,12 @@ void rac_image_calc_resize(int32_t width, int32_t height, int32_t max_size, int3 if (!out_width || !out_height) return; + if (width <= 0 || height <= 0) { + *out_width = 1; + *out_height = 1; + return; + } + if (width <= max_size && height <= max_size) { *out_width = width; *out_height = height; diff --git a/sdk/runanywhere-commons/tests/CMakeLists.txt b/sdk/runanywhere-commons/tests/CMakeLists.txt index 067ee4369..6ddd981ba 100644 --- a/sdk/runanywhere-commons/tests/CMakeLists.txt +++ b/sdk/runanywhere-commons/tests/CMakeLists.txt @@ -11,9 +11,32 @@ target_include_directories(test_core PRIVATE ${CMAKE_SOURCE_DIR}/include ) target_link_libraries(test_core PRIVATE rac_commons) -target_compile_features(test_core PRIVATE cxx_std_17) +target_compile_features(test_core PRIVATE cxx_std_20) add_test(NAME core_tests COMMAND test_core --run-all) +# --- test_extraction: Always built (no backend dependency) --- +# Tests rac_extract_archive_native() and rac_detect_archive_type() from libarchive. +add_executable(test_extraction test_extraction.cpp) +target_include_directories(test_extraction PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/include +) +target_link_libraries(test_extraction PRIVATE rac_commons) +target_compile_features(test_extraction PRIVATE cxx_std_17) +add_test(NAME extraction_tests COMMAND test_extraction --run-all) + +# --- test_download_orchestrator: Always built (no backend dependency) --- +# Tests rac_find_model_path_after_extraction(), rac_download_compute_destination(), +# and rac_download_requires_extraction() from rac_download_orchestrator.h. +add_executable(test_download_orchestrator test_download_orchestrator.cpp) +target_include_directories(test_download_orchestrator PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_SOURCE_DIR}/include +) +target_link_libraries(test_download_orchestrator PRIVATE rac_commons) +target_compile_features(test_download_orchestrator PRIVATE cxx_std_17) +add_test(NAME download_orchestrator_tests COMMAND test_download_orchestrator --run-all) + # --- ONNX backend tests (VAD, STT, TTS, WakeWord) --- if(RAC_BACKEND_ONNX AND TARGET rac_backend_onnx) # VAD test @@ -23,7 +46,7 @@ if(RAC_BACKEND_ONNX AND TARGET rac_backend_onnx) ${CMAKE_SOURCE_DIR}/include ) target_link_libraries(test_vad PRIVATE rac_commons rac_backend_onnx) - target_compile_features(test_vad PRIVATE cxx_std_17) + target_compile_features(test_vad PRIVATE cxx_std_20) add_test(NAME vad_tests COMMAND test_vad --run-all) # STT test @@ -33,7 +56,7 @@ if(RAC_BACKEND_ONNX AND TARGET rac_backend_onnx) ${CMAKE_SOURCE_DIR}/include ) target_link_libraries(test_stt PRIVATE rac_commons rac_backend_onnx) - target_compile_features(test_stt PRIVATE cxx_std_17) + target_compile_features(test_stt PRIVATE cxx_std_20) add_test(NAME stt_tests COMMAND test_stt --run-all) # TTS test @@ -43,7 +66,7 @@ if(RAC_BACKEND_ONNX AND TARGET rac_backend_onnx) ${CMAKE_SOURCE_DIR}/include ) target_link_libraries(test_tts PRIVATE rac_commons rac_backend_onnx) - target_compile_features(test_tts PRIVATE cxx_std_17) + target_compile_features(test_tts PRIVATE cxx_std_20) add_test(NAME tts_tests COMMAND test_tts --run-all) # WakeWord test @@ -53,7 +76,7 @@ if(RAC_BACKEND_ONNX AND TARGET rac_backend_onnx) ${CMAKE_SOURCE_DIR}/include ) target_link_libraries(test_wakeword PRIVATE rac_commons rac_backend_onnx) - target_compile_features(test_wakeword PRIVATE cxx_std_17) + target_compile_features(test_wakeword PRIVATE cxx_std_20) add_test(NAME wakeword_tests COMMAND test_wakeword --run-all) endif() @@ -65,7 +88,7 @@ if(RAC_BACKEND_LLAMACPP AND TARGET rac_backend_llamacpp) ${CMAKE_SOURCE_DIR}/include ) target_link_libraries(test_llm PRIVATE rac_commons rac_backend_llamacpp) - target_compile_features(test_llm PRIVATE cxx_std_17) + target_compile_features(test_llm PRIVATE cxx_std_20) add_test(NAME llm_tests COMMAND test_llm --run-all) endif() @@ -79,7 +102,7 @@ if(RAC_BACKEND_ONNX AND RAC_BACKEND_LLAMACPP ) target_link_libraries(test_voice_agent PRIVATE rac_commons rac_backend_onnx rac_backend_llamacpp) - target_compile_features(test_voice_agent PRIVATE cxx_std_17) + target_compile_features(test_voice_agent PRIVATE cxx_std_20) add_test(NAME voice_agent_tests COMMAND test_voice_agent --run-all) endif() @@ -137,7 +160,7 @@ if(RAC_BACKEND_RAG) Threads::Threads GTest::gtest_main ) - target_compile_features(rac_rag_backend_thread_safety_test PRIVATE cxx_std_17) + target_compile_features(rac_rag_backend_thread_safety_test PRIVATE cxx_std_20) include(GoogleTest) gtest_discover_tests(rac_rag_backend_thread_safety_test DISCOVERY_MODE PRE_TEST @@ -157,7 +180,7 @@ if(RAC_BACKEND_RAG) Threads::Threads GTest::gtest_main ) - target_compile_features(rac_chunker_test PRIVATE cxx_std_17) + target_compile_features(rac_chunker_test PRIVATE cxx_std_20) gtest_discover_tests(rac_chunker_test DISCOVERY_MODE PRE_TEST ) @@ -176,7 +199,7 @@ if(RAC_BACKEND_RAG) Threads::Threads GTest::gtest_main ) - target_compile_features(rac_simple_tokenizer_test PRIVATE cxx_std_17) + target_compile_features(rac_simple_tokenizer_test PRIVATE cxx_std_20) gtest_discover_tests(rac_simple_tokenizer_test DISCOVERY_MODE PRE_TEST ) diff --git a/sdk/runanywhere-commons/tests/test_download_orchestrator.cpp b/sdk/runanywhere-commons/tests/test_download_orchestrator.cpp new file mode 100644 index 000000000..614802595 --- /dev/null +++ b/sdk/runanywhere-commons/tests/test_download_orchestrator.cpp @@ -0,0 +1,456 @@ +/** + * @file test_download_orchestrator.cpp + * @brief Unit tests for download orchestrator utilities. + * + * Tests rac_find_model_path_after_extraction(), rac_download_compute_destination(), + * and rac_download_requires_extraction() from rac_download_orchestrator.h. + * + * No ML backend or platform adapter needed — these are pure utility functions. + */ + +#include "test_common.h" + +#include "rac/infrastructure/download/rac_download_orchestrator.h" +#include "rac/infrastructure/model_management/rac_model_paths.h" +#include "rac/infrastructure/model_management/rac_model_types.h" + +#include +#include +#include +#include +#include +#include +#include + +// ============================================================================= +// Test helpers +// ============================================================================= + +/** Create a unique temporary directory for test artifacts. */ +static std::string create_temp_dir(const std::string& suffix) { + char tmpl[256]; + snprintf(tmpl, sizeof(tmpl), "/tmp/rac_dl_test_%s_XXXXXX", suffix.c_str()); + char* result = mkdtemp(tmpl); + if (!result) { + std::cerr << "Failed to create temp dir: " << tmpl << "\n"; + return ""; + } + return std::string(result); +} + +/** Recursively remove a directory. */ +static void remove_dir(const std::string& path) { + std::string cmd = "rm -rf \"" + path + "\""; + system(cmd.c_str()); +} + +/** Create a directory (like mkdir -p). */ +static void mkdir_p(const std::string& path) { + std::string cmd = "mkdir -p \"" + path + "\""; + system(cmd.c_str()); +} + +/** Write a dummy file. */ +static void write_dummy_file(const std::string& path, const std::string& content = "model data") { + std::ofstream f(path, std::ios::binary); + f << content; +} + +// ============================================================================= +// Tests: rac_download_requires_extraction +// ============================================================================= + +static TestResult test_requires_extraction_tar_gz() { + TestResult r; + r.test_name = "requires_extraction_tar_gz"; + + ASSERT_TRUE(rac_download_requires_extraction("https://example.com/model.tar.gz") == RAC_TRUE, + ".tar.gz should require extraction"); + ASSERT_TRUE(rac_download_requires_extraction("https://example.com/model.tgz") == RAC_TRUE, + ".tgz should require extraction"); + + r.passed = true; + return r; +} + +static TestResult test_requires_extraction_tar_bz2() { + TestResult r; + r.test_name = "requires_extraction_tar_bz2"; + + ASSERT_TRUE(rac_download_requires_extraction("https://example.com/model.tar.bz2") == RAC_TRUE, + ".tar.bz2 should require extraction"); + ASSERT_TRUE(rac_download_requires_extraction("https://example.com/model.tbz2") == RAC_TRUE, + ".tbz2 should require extraction"); + + r.passed = true; + return r; +} + +static TestResult test_requires_extraction_zip() { + TestResult r; + r.test_name = "requires_extraction_zip"; + + ASSERT_TRUE(rac_download_requires_extraction("https://example.com/model.zip") == RAC_TRUE, + ".zip should require extraction"); + + r.passed = true; + return r; +} + +static TestResult test_requires_extraction_no_archive() { + TestResult r; + r.test_name = "requires_extraction_no_archive"; + + ASSERT_TRUE(rac_download_requires_extraction("https://example.com/model.gguf") == RAC_FALSE, + ".gguf should NOT require extraction"); + ASSERT_TRUE(rac_download_requires_extraction("https://example.com/model.onnx") == RAC_FALSE, + ".onnx should NOT require extraction"); + ASSERT_TRUE(rac_download_requires_extraction("https://example.com/model.bin") == RAC_FALSE, + ".bin should NOT require extraction"); + ASSERT_TRUE(rac_download_requires_extraction(nullptr) == RAC_FALSE, + "NULL URL should NOT require extraction"); + + r.passed = true; + return r; +} + +static TestResult test_requires_extraction_url_with_query() { + TestResult r; + r.test_name = "requires_extraction_url_with_query"; + + ASSERT_TRUE( + rac_download_requires_extraction("https://example.com/model.tar.gz?token=abc") == RAC_TRUE, + ".tar.gz with query string should require extraction"); + ASSERT_TRUE( + rac_download_requires_extraction("https://example.com/model.gguf?v=2") == RAC_FALSE, + ".gguf with query string should NOT require extraction"); + + r.passed = true; + return r; +} + +// ============================================================================= +// Tests: rac_find_model_path_after_extraction +// ============================================================================= + +static TestResult test_find_model_single_gguf() { + TestResult r; + r.test_name = "find_model_single_gguf"; + + std::string dir = create_temp_dir("gguf"); + ASSERT_TRUE(!dir.empty(), "Failed to create temp dir"); + + // Create a single .gguf file at root + write_dummy_file(dir + "/llama-7b.gguf"); + + char out_path[4096]; + rac_result_t result = rac_find_model_path_after_extraction( + dir.c_str(), RAC_ARCHIVE_STRUCTURE_SINGLE_FILE_NESTED, RAC_FRAMEWORK_LLAMACPP, + RAC_MODEL_FORMAT_GGUF, out_path, sizeof(out_path)); + + ASSERT_TRUE(result == RAC_SUCCESS, "Should find model path"); + + std::string found(out_path); + ASSERT_TRUE(found.find("llama-7b.gguf") != std::string::npos, + "Should find the .gguf file"); + + remove_dir(dir); + r.passed = true; + return r; +} + +static TestResult test_find_model_nested_gguf() { + TestResult r; + r.test_name = "find_model_nested_gguf"; + + std::string dir = create_temp_dir("nested_gguf"); + ASSERT_TRUE(!dir.empty(), "Failed to create temp dir"); + + // Create a .gguf file nested one level deep (common archive pattern) + mkdir_p(dir + "/model-folder"); + write_dummy_file(dir + "/model-folder/model.gguf"); + + char out_path[4096]; + rac_result_t result = rac_find_model_path_after_extraction( + dir.c_str(), RAC_ARCHIVE_STRUCTURE_SINGLE_FILE_NESTED, RAC_FRAMEWORK_LLAMACPP, + RAC_MODEL_FORMAT_GGUF, out_path, sizeof(out_path)); + + ASSERT_TRUE(result == RAC_SUCCESS, "Should find nested model path"); + + std::string found(out_path); + ASSERT_TRUE(found.find("model.gguf") != std::string::npos, + "Should find the nested .gguf file"); + + remove_dir(dir); + r.passed = true; + return r; +} + +static TestResult test_find_model_nested_directory() { + TestResult r; + r.test_name = "find_model_nested_directory"; + + std::string dir = create_temp_dir("nested_dir"); + ASSERT_TRUE(!dir.empty(), "Failed to create temp dir"); + + // Sherpa-ONNX pattern: archive extracts to a single subdirectory + mkdir_p(dir + "/vits-piper-en_US-libritts_r-medium"); + write_dummy_file(dir + "/vits-piper-en_US-libritts_r-medium/model.onnx"); + write_dummy_file(dir + "/vits-piper-en_US-libritts_r-medium/tokens.txt"); + write_dummy_file(dir + "/vits-piper-en_US-libritts_r-medium/lexicon.txt"); + + char out_path[4096]; + rac_result_t result = rac_find_model_path_after_extraction( + dir.c_str(), RAC_ARCHIVE_STRUCTURE_NESTED_DIRECTORY, RAC_FRAMEWORK_ONNX, + RAC_MODEL_FORMAT_ONNX, out_path, sizeof(out_path)); + + ASSERT_TRUE(result == RAC_SUCCESS, "Should find nested directory"); + + std::string found(out_path); + ASSERT_TRUE(found.find("vits-piper-en_US-libritts_r-medium") != std::string::npos, + "Should return the nested subdirectory path"); + + remove_dir(dir); + r.passed = true; + return r; +} + +static TestResult test_find_model_directory_based_onnx() { + TestResult r; + r.test_name = "find_model_directory_based_onnx"; + + std::string dir = create_temp_dir("onnx_dir"); + ASSERT_TRUE(!dir.empty(), "Failed to create temp dir"); + + // ONNX directory-based model: multiple files at root + write_dummy_file(dir + "/encoder.onnx"); + write_dummy_file(dir + "/decoder.onnx"); + write_dummy_file(dir + "/tokens.txt"); + + char out_path[4096]; + rac_result_t result = rac_find_model_path_after_extraction( + dir.c_str(), RAC_ARCHIVE_STRUCTURE_DIRECTORY_BASED, RAC_FRAMEWORK_ONNX, + RAC_MODEL_FORMAT_ONNX, out_path, sizeof(out_path)); + + ASSERT_TRUE(result == RAC_SUCCESS, "Should succeed for directory-based model"); + + // For ONNX directory-based, should return the directory itself + std::string found(out_path); + ASSERT_TRUE(found == dir, "Should return the extraction directory for directory-based ONNX"); + + remove_dir(dir); + r.passed = true; + return r; +} + +static TestResult test_find_model_skips_hidden_files() { + TestResult r; + r.test_name = "find_model_skips_hidden_files"; + + std::string dir = create_temp_dir("hidden"); + ASSERT_TRUE(!dir.empty(), "Failed to create temp dir"); + + // Create macOS resource fork and hidden files (should be ignored) + write_dummy_file(dir + "/._model.gguf"); + mkdir_p(dir + "/.DS_Store"); + mkdir_p(dir + "/__MACOSX"); + // Real model in subdirectory + mkdir_p(dir + "/model-dir"); + write_dummy_file(dir + "/model-dir/real-model.gguf"); + + char out_path[4096]; + rac_result_t result = rac_find_model_path_after_extraction( + dir.c_str(), RAC_ARCHIVE_STRUCTURE_SINGLE_FILE_NESTED, RAC_FRAMEWORK_LLAMACPP, + RAC_MODEL_FORMAT_GGUF, out_path, sizeof(out_path)); + + ASSERT_TRUE(result == RAC_SUCCESS, "Should find real model file"); + + std::string found(out_path); + ASSERT_TRUE(found.find("real-model.gguf") != std::string::npos, + "Should find the real model, not hidden files"); + ASSERT_TRUE(found.find("._model") == std::string::npos, + "Should NOT match macOS resource fork files"); + + remove_dir(dir); + r.passed = true; + return r; +} + +static TestResult test_find_model_unknown_structure() { + TestResult r; + r.test_name = "find_model_unknown_structure"; + + std::string dir = create_temp_dir("unknown"); + ASSERT_TRUE(!dir.empty(), "Failed to create temp dir"); + + // Single .bin file at root + write_dummy_file(dir + "/model.bin"); + + char out_path[4096]; + rac_result_t result = rac_find_model_path_after_extraction( + dir.c_str(), RAC_ARCHIVE_STRUCTURE_UNKNOWN, RAC_FRAMEWORK_LLAMACPP, RAC_MODEL_FORMAT_BIN, + out_path, sizeof(out_path)); + + ASSERT_TRUE(result == RAC_SUCCESS, "Should find model with unknown structure"); + + std::string found(out_path); + ASSERT_TRUE(found.find("model.bin") != std::string::npos, + "Should find the .bin model file"); + + remove_dir(dir); + r.passed = true; + return r; +} + +static TestResult test_find_model_empty_dir() { + TestResult r; + r.test_name = "find_model_empty_dir"; + + std::string dir = create_temp_dir("empty"); + ASSERT_TRUE(!dir.empty(), "Failed to create temp dir"); + + char out_path[4096]; + rac_result_t result = rac_find_model_path_after_extraction( + dir.c_str(), RAC_ARCHIVE_STRUCTURE_SINGLE_FILE_NESTED, RAC_FRAMEWORK_LLAMACPP, + RAC_MODEL_FORMAT_GGUF, out_path, sizeof(out_path)); + + // Should still succeed (returns the directory itself as fallback) + ASSERT_TRUE(result == RAC_SUCCESS, "Should succeed even for empty dir"); + + remove_dir(dir); + r.passed = true; + return r; +} + +static TestResult test_find_model_null_args() { + TestResult r; + r.test_name = "find_model_null_args"; + + char out_path[4096]; + + ASSERT_TRUE(rac_find_model_path_after_extraction(nullptr, RAC_ARCHIVE_STRUCTURE_UNKNOWN, + RAC_FRAMEWORK_LLAMACPP, RAC_MODEL_FORMAT_GGUF, + out_path, sizeof(out_path)) == + RAC_ERROR_INVALID_ARGUMENT, + "NULL extracted_dir should return INVALID_ARGUMENT"); + + ASSERT_TRUE(rac_find_model_path_after_extraction("/tmp", RAC_ARCHIVE_STRUCTURE_UNKNOWN, + RAC_FRAMEWORK_LLAMACPP, RAC_MODEL_FORMAT_GGUF, + nullptr, 0) == RAC_ERROR_INVALID_ARGUMENT, + "NULL out_path should return INVALID_ARGUMENT"); + + r.passed = true; + return r; +} + +// ============================================================================= +// Tests: rac_download_compute_destination +// ============================================================================= + +static TestResult test_compute_destination_needs_base_dir() { + TestResult r; + r.test_name = "compute_destination_needs_base_dir"; + + // Set up base dir for model paths + std::string base_dir = create_temp_dir("base"); + ASSERT_TRUE(!base_dir.empty(), "Failed to create temp dir"); + + rac_model_paths_set_base_dir(base_dir.c_str()); + + char out_path[4096]; + rac_bool_t needs_extraction = RAC_FALSE; + + rac_result_t result = rac_download_compute_destination( + "test-model", "https://example.com/model.gguf", RAC_FRAMEWORK_LLAMACPP, + RAC_MODEL_FORMAT_GGUF, out_path, sizeof(out_path), &needs_extraction); + + ASSERT_TRUE(result == RAC_SUCCESS, "Should compute destination successfully"); + ASSERT_TRUE(needs_extraction == RAC_FALSE, ".gguf should not need extraction"); + + std::string path(out_path); + ASSERT_TRUE(path.find("model.gguf") != std::string::npos, + "Should contain the filename"); + + remove_dir(base_dir); + r.passed = true; + return r; +} + +static TestResult test_compute_destination_archive() { + TestResult r; + r.test_name = "compute_destination_archive"; + + std::string base_dir = create_temp_dir("base_archive"); + ASSERT_TRUE(!base_dir.empty(), "Failed to create temp dir"); + + rac_model_paths_set_base_dir(base_dir.c_str()); + + char out_path[4096]; + rac_bool_t needs_extraction = RAC_FALSE; + + rac_result_t result = rac_download_compute_destination( + "sherpa-model", "https://example.com/sherpa-model.tar.bz2", RAC_FRAMEWORK_ONNX, + RAC_MODEL_FORMAT_ONNX, out_path, sizeof(out_path), &needs_extraction); + + ASSERT_TRUE(result == RAC_SUCCESS, "Should compute destination for archive"); + ASSERT_TRUE(needs_extraction == RAC_TRUE, ".tar.bz2 should need extraction"); + + std::string path(out_path); + ASSERT_TRUE(path.find("Downloads") != std::string::npos || path.find("download") != std::string::npos, + "Archive should download to downloads/temp directory"); + ASSERT_TRUE(path.find(".tar.bz2") != std::string::npos, + "Should preserve archive extension"); + + remove_dir(base_dir); + r.passed = true; + return r; +} + +static TestResult test_compute_destination_null_args() { + TestResult r; + r.test_name = "compute_destination_null_args"; + + char out_path[4096]; + rac_bool_t needs_extraction = RAC_FALSE; + + ASSERT_TRUE(rac_download_compute_destination(nullptr, "url", RAC_FRAMEWORK_LLAMACPP, + RAC_MODEL_FORMAT_GGUF, out_path, + sizeof(out_path), + &needs_extraction) == RAC_ERROR_INVALID_ARGUMENT, + "NULL model_id should return INVALID_ARGUMENT"); + + r.passed = true; + return r; +} + +// ============================================================================= +// Test runner +// ============================================================================= + +int main(int argc, char** argv) { + TestSuite suite("download_orchestrator"); + + // rac_download_requires_extraction + suite.add("requires_extraction_tar_gz", test_requires_extraction_tar_gz); + suite.add("requires_extraction_tar_bz2", test_requires_extraction_tar_bz2); + suite.add("requires_extraction_zip", test_requires_extraction_zip); + suite.add("requires_extraction_no_archive", test_requires_extraction_no_archive); + suite.add("requires_extraction_url_with_query", test_requires_extraction_url_with_query); + + // rac_find_model_path_after_extraction + suite.add("find_model_single_gguf", test_find_model_single_gguf); + suite.add("find_model_nested_gguf", test_find_model_nested_gguf); + suite.add("find_model_nested_directory", test_find_model_nested_directory); + suite.add("find_model_directory_based_onnx", test_find_model_directory_based_onnx); + suite.add("find_model_skips_hidden_files", test_find_model_skips_hidden_files); + suite.add("find_model_unknown_structure", test_find_model_unknown_structure); + suite.add("find_model_empty_dir", test_find_model_empty_dir); + suite.add("find_model_null_args", test_find_model_null_args); + + // rac_download_compute_destination + suite.add("compute_destination_needs_base_dir", test_compute_destination_needs_base_dir); + suite.add("compute_destination_archive", test_compute_destination_archive); + suite.add("compute_destination_null_args", test_compute_destination_null_args); + + return suite.run(argc, argv); +} diff --git a/sdk/runanywhere-commons/tests/test_extraction.cpp b/sdk/runanywhere-commons/tests/test_extraction.cpp new file mode 100644 index 000000000..bc62fc87a --- /dev/null +++ b/sdk/runanywhere-commons/tests/test_extraction.cpp @@ -0,0 +1,808 @@ +/** + * @file test_extraction.cpp + * @brief Integration tests for native archive extraction (libarchive). + * + * Tests rac_extract_archive_native() and rac_detect_archive_type() from + * rac_extraction.h. No ML backend dependency — only links rac_commons. + * + * Uses system `tar` and `zip` commands to create test archives on macOS/Linux. + */ + +#include "test_common.h" +#include "test_config.h" + +#include "rac/infrastructure/extraction/rac_extraction.h" +#include "rac/infrastructure/model_management/rac_model_types.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +// No platform adapter or rac_init() needed — extraction APIs are standalone. + +// ============================================================================= +// Test helpers +// ============================================================================= + +static std::string g_test_dir; + +/** Create a unique temporary directory for test artifacts. */ +static std::string create_temp_dir(const std::string& suffix) { + char tmpl[256]; + snprintf(tmpl, sizeof(tmpl), "/tmp/rac_test_%s_XXXXXX", suffix.c_str()); + char* result = mkdtemp(tmpl); + if (!result) { + std::cerr << "Failed to create temp dir: " << tmpl << "\n"; + return ""; + } + return std::string(result); +} + +/** Recursively remove a directory. */ +static void remove_dir(const std::string& path) { + std::string cmd = "rm -rf \"" + path + "\""; + system(cmd.c_str()); +} + +/** Check if a file exists. */ +static bool file_exists(const std::string& path) { + struct stat st; + return stat(path.c_str(), &st) == 0; +} + +/** Read entire file contents. */ +static std::string read_file_contents(const std::string& path) { + std::ifstream f(path, std::ios::binary); + if (!f.is_open()) return ""; + return std::string((std::istreambuf_iterator(f)), + std::istreambuf_iterator()); +} + +/** Write bytes to a file. */ +static bool write_file(const std::string& path, const void* data, size_t size) { + std::ofstream f(path, std::ios::binary); + if (!f.is_open()) return false; + f.write(static_cast(data), static_cast(size)); + return f.good(); +} + +/** Write string to a file. */ +static bool write_file(const std::string& path, const std::string& content) { + return write_file(path, content.data(), content.size()); +} + +/** Check if tar command is available. */ +static bool has_tar() { + return system("tar --version > /dev/null 2>&1") == 0; +} + +/** Check if zip command is available. */ +static bool has_zip() { + return system("zip --version > /dev/null 2>&1") == 0; +} + +/** + * Create a tar.gz archive containing test files. + * Returns path to the created archive, or empty string on failure. + */ +static std::string create_test_tar_gz(const std::string& base_dir) { + std::string content_dir = base_dir + "/content"; + std::string sub_dir = content_dir + "/subdir"; + mkdir(content_dir.c_str(), 0755); + mkdir(sub_dir.c_str(), 0755); + + write_file(content_dir + "/hello.txt", "Hello, World!\n"); + write_file(content_dir + "/data.bin", std::string(256, '\x42')); + write_file(sub_dir + "/nested.txt", "Nested file content\n"); + + std::string archive_path = base_dir + "/test.tar.gz"; + std::string cmd = "tar czf \"" + archive_path + "\" -C \"" + base_dir + "\" content"; + if (system(cmd.c_str()) != 0) return ""; + return archive_path; +} + +/** + * Create a ZIP archive containing test files. + * Returns path to the created archive, or empty string on failure. + */ +static std::string create_test_zip(const std::string& base_dir) { + std::string content_dir = base_dir + "/zipcontent"; + std::string sub_dir = content_dir + "/subdir"; + mkdir(content_dir.c_str(), 0755); + mkdir(sub_dir.c_str(), 0755); + + write_file(content_dir + "/readme.txt", "ZIP test file\n"); + write_file(content_dir + "/binary.dat", std::string(128, '\xAB')); + write_file(sub_dir + "/deep.txt", "Deep nested\n"); + + std::string archive_path = base_dir + "/test.zip"; + std::string cmd = "cd \"" + base_dir + "\" && zip -r \"" + archive_path + "\" zipcontent > /dev/null 2>&1"; + if (system(cmd.c_str()) != 0) return ""; + return archive_path; +} + +// ============================================================================= +// Test: null pointer handling +// ============================================================================= + +static TestResult test_null_pointer() { + rac_result_t rc = rac_extract_archive_native(nullptr, "/tmp", nullptr, nullptr, nullptr, nullptr); + ASSERT_EQ(rc, RAC_ERROR_NULL_POINTER, "NULL archive_path should return RAC_ERROR_NULL_POINTER"); + + rc = rac_extract_archive_native("/tmp/test.tar.gz", nullptr, nullptr, nullptr, nullptr, nullptr); + ASSERT_EQ(rc, RAC_ERROR_NULL_POINTER, "NULL destination_dir should return RAC_ERROR_NULL_POINTER"); + + rc = rac_extract_archive_native(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); + ASSERT_EQ(rc, RAC_ERROR_NULL_POINTER, "Both NULL should return RAC_ERROR_NULL_POINTER"); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: file not found +// ============================================================================= + +static TestResult test_file_not_found() { + rac_result_t rc = rac_extract_archive_native( + "/nonexistent/path/archive.tar.gz", "/tmp/dest", + nullptr, nullptr, nullptr, nullptr); + ASSERT_EQ(rc, RAC_ERROR_FILE_NOT_FOUND, + "Non-existent archive should return RAC_ERROR_FILE_NOT_FOUND"); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: detect archive type - null handling +// ============================================================================= + +static TestResult test_detect_null() { + rac_archive_type_t type; + ASSERT_EQ(rac_detect_archive_type(nullptr, &type), RAC_FALSE, + "NULL file_path should return RAC_FALSE"); + ASSERT_EQ(rac_detect_archive_type("/tmp/test.bin", nullptr), RAC_FALSE, + "NULL out_type should return RAC_FALSE"); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: detect archive type - non-existent file +// ============================================================================= + +static TestResult test_detect_nonexistent() { + rac_archive_type_t type; + ASSERT_EQ(rac_detect_archive_type("/nonexistent/file.bin", &type), RAC_FALSE, + "Non-existent file should return RAC_FALSE"); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: detect ZIP magic bytes +// ============================================================================= + +static TestResult test_detect_zip() { + std::string path = g_test_dir + "/magic_zip.bin"; + unsigned char zip_magic[] = {0x50, 0x4B, 0x03, 0x04, 0x00, 0x00}; + write_file(path, zip_magic, sizeof(zip_magic)); + + rac_archive_type_t type; + ASSERT_EQ(rac_detect_archive_type(path.c_str(), &type), RAC_TRUE, + "ZIP magic bytes should be detected"); + ASSERT_EQ(type, RAC_ARCHIVE_TYPE_ZIP, "Type should be RAC_ARCHIVE_TYPE_ZIP"); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: detect GZIP magic bytes +// ============================================================================= + +static TestResult test_detect_gzip() { + std::string path = g_test_dir + "/magic_gzip.bin"; + unsigned char gz_magic[] = {0x1F, 0x8B, 0x08, 0x00}; + write_file(path, gz_magic, sizeof(gz_magic)); + + rac_archive_type_t type; + ASSERT_EQ(rac_detect_archive_type(path.c_str(), &type), RAC_TRUE, + "GZIP magic bytes should be detected"); + ASSERT_EQ(type, RAC_ARCHIVE_TYPE_TAR_GZ, "Type should be RAC_ARCHIVE_TYPE_TAR_GZ"); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: detect BZIP2 magic bytes +// ============================================================================= + +static TestResult test_detect_bzip2() { + std::string path = g_test_dir + "/magic_bz2.bin"; + unsigned char bz2_magic[] = {0x42, 0x5A, 0x68, 0x39}; // "BZh9" + write_file(path, bz2_magic, sizeof(bz2_magic)); + + rac_archive_type_t type; + ASSERT_EQ(rac_detect_archive_type(path.c_str(), &type), RAC_TRUE, + "BZIP2 magic bytes should be detected"); + ASSERT_EQ(type, RAC_ARCHIVE_TYPE_TAR_BZ2, "Type should be RAC_ARCHIVE_TYPE_TAR_BZ2"); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: detect XZ magic bytes +// ============================================================================= + +static TestResult test_detect_xz() { + std::string path = g_test_dir + "/magic_xz.bin"; + unsigned char xz_magic[] = {0xFD, 0x37, 0x7A, 0x58, 0x5A, 0x00}; + write_file(path, xz_magic, sizeof(xz_magic)); + + rac_archive_type_t type; + ASSERT_EQ(rac_detect_archive_type(path.c_str(), &type), RAC_TRUE, + "XZ magic bytes should be detected"); + ASSERT_EQ(type, RAC_ARCHIVE_TYPE_TAR_XZ, "Type should be RAC_ARCHIVE_TYPE_TAR_XZ"); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: detect unknown format +// ============================================================================= + +static TestResult test_detect_unknown() { + std::string path = g_test_dir + "/magic_unknown.bin"; + unsigned char random_bytes[] = {0xDE, 0xAD, 0xBE, 0xEF, 0xCA, 0xFE}; + write_file(path, random_bytes, sizeof(random_bytes)); + + rac_archive_type_t type; + ASSERT_EQ(rac_detect_archive_type(path.c_str(), &type), RAC_FALSE, + "Unknown magic bytes should return RAC_FALSE"); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: detect empty file +// ============================================================================= + +static TestResult test_detect_empty_file() { + std::string path = g_test_dir + "/empty.bin"; + write_file(path, "", 0); + + rac_archive_type_t type; + ASSERT_EQ(rac_detect_archive_type(path.c_str(), &type), RAC_FALSE, + "Empty file should return RAC_FALSE"); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: extract tar.gz archive +// ============================================================================= + +static TestResult test_extract_tar_gz() { + if (!has_tar()) { + TestResult r; + r.passed = true; + r.details = "SKIPPED (tar not available)"; + return r; + } + + std::string archive_dir = create_temp_dir("tgz_src"); + std::string dest_dir = create_temp_dir("tgz_dest"); + ASSERT_TRUE(!archive_dir.empty(), "Should create archive source dir"); + ASSERT_TRUE(!dest_dir.empty(), "Should create dest dir"); + + std::string archive_path = create_test_tar_gz(archive_dir); + ASSERT_TRUE(!archive_path.empty(), "Should create tar.gz archive"); + ASSERT_TRUE(file_exists(archive_path), "Archive file should exist"); + + // Verify detection + rac_archive_type_t type; + ASSERT_EQ(rac_detect_archive_type(archive_path.c_str(), &type), RAC_TRUE, + "Should detect tar.gz"); + ASSERT_EQ(type, RAC_ARCHIVE_TYPE_TAR_GZ, "Should be TAR_GZ"); + + // Extract + rac_extraction_result_t result = {}; + rac_result_t rc = rac_extract_archive_native( + archive_path.c_str(), dest_dir.c_str(), + nullptr, nullptr, nullptr, &result); + ASSERT_EQ(rc, RAC_SUCCESS, "Extraction should succeed"); + + // Verify extracted files + ASSERT_TRUE(result.files_extracted >= 3, "Should extract at least 3 files"); + ASSERT_TRUE(result.directories_created >= 1, "Should create at least 1 directory"); + ASSERT_TRUE(result.bytes_extracted > 0, "Should extract some bytes"); + + // Verify file contents + std::string hello_content = read_file_contents(dest_dir + "/content/hello.txt"); + ASSERT_TRUE(hello_content == "Hello, World!\n", + "hello.txt content should match"); + + std::string nested_content = read_file_contents(dest_dir + "/content/subdir/nested.txt"); + ASSERT_TRUE(nested_content == "Nested file content\n", + "nested.txt content should match"); + + std::string data_content = read_file_contents(dest_dir + "/content/data.bin"); + ASSERT_EQ(static_cast(data_content.size()), 256, "data.bin should be 256 bytes"); + ASSERT_TRUE(data_content[0] == '\x42', "data.bin content should be 0x42"); + + // Cleanup + remove_dir(archive_dir); + remove_dir(dest_dir); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: extract ZIP archive +// ============================================================================= + +static TestResult test_extract_zip() { + if (!has_zip()) { + TestResult r; + r.passed = true; + r.details = "SKIPPED (zip not available)"; + return r; + } + + std::string archive_dir = create_temp_dir("zip_src"); + std::string dest_dir = create_temp_dir("zip_dest"); + ASSERT_TRUE(!archive_dir.empty(), "Should create archive source dir"); + ASSERT_TRUE(!dest_dir.empty(), "Should create dest dir"); + + std::string archive_path = create_test_zip(archive_dir); + ASSERT_TRUE(!archive_path.empty(), "Should create ZIP archive"); + ASSERT_TRUE(file_exists(archive_path), "Archive file should exist"); + + // Verify detection + rac_archive_type_t type; + ASSERT_EQ(rac_detect_archive_type(archive_path.c_str(), &type), RAC_TRUE, + "Should detect ZIP"); + ASSERT_EQ(type, RAC_ARCHIVE_TYPE_ZIP, "Should be ZIP"); + + // Extract + rac_extraction_result_t result = {}; + rac_result_t rc = rac_extract_archive_native( + archive_path.c_str(), dest_dir.c_str(), + nullptr, nullptr, nullptr, &result); + ASSERT_EQ(rc, RAC_SUCCESS, "ZIP extraction should succeed"); + + // Verify extracted files + ASSERT_TRUE(result.files_extracted >= 3, "Should extract at least 3 files"); + ASSERT_TRUE(result.bytes_extracted > 0, "Should extract some bytes"); + + // Verify file contents + std::string readme_content = read_file_contents(dest_dir + "/zipcontent/readme.txt"); + ASSERT_TRUE(readme_content == "ZIP test file\n", + "readme.txt content should match"); + + std::string deep_content = read_file_contents(dest_dir + "/zipcontent/subdir/deep.txt"); + ASSERT_TRUE(deep_content == "Deep nested\n", + "deep.txt content should match"); + + // Cleanup + remove_dir(archive_dir); + remove_dir(dest_dir); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: progress callback is invoked +// ============================================================================= + +struct ProgressData { + int callback_count; + int32_t last_files_extracted; + int64_t last_bytes_extracted; +}; + +static void test_progress_callback(int32_t files_extracted, int32_t /*total_files*/, + int64_t bytes_extracted, void* user_data) { + auto* data = static_cast(user_data); + data->callback_count++; + data->last_files_extracted = files_extracted; + data->last_bytes_extracted = bytes_extracted; +} + +static TestResult test_progress_callback_invoked() { + if (!has_tar()) { + TestResult r; + r.passed = true; + r.details = "SKIPPED (tar not available)"; + return r; + } + + std::string archive_dir = create_temp_dir("prog_src"); + std::string dest_dir = create_temp_dir("prog_dest"); + ASSERT_TRUE(!archive_dir.empty() && !dest_dir.empty(), "Should create dirs"); + + std::string archive_path = create_test_tar_gz(archive_dir); + ASSERT_TRUE(!archive_path.empty(), "Should create archive"); + + ProgressData progress = {0, 0, 0}; + rac_result_t rc = rac_extract_archive_native( + archive_path.c_str(), dest_dir.c_str(), + nullptr, test_progress_callback, &progress, nullptr); + ASSERT_EQ(rc, RAC_SUCCESS, "Extraction with progress should succeed"); + ASSERT_TRUE(progress.callback_count > 0, + "Progress callback should be invoked at least once"); + ASSERT_TRUE(progress.last_files_extracted > 0, + "Last files_extracted should be > 0"); + ASSERT_TRUE(progress.last_bytes_extracted > 0, + "Last bytes_extracted should be > 0"); + + remove_dir(archive_dir); + remove_dir(dest_dir); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: extraction result statistics +// ============================================================================= + +static TestResult test_extraction_result_stats() { + if (!has_tar()) { + TestResult r; + r.passed = true; + r.details = "SKIPPED (tar not available)"; + return r; + } + + std::string archive_dir = create_temp_dir("stats_src"); + std::string dest_dir = create_temp_dir("stats_dest"); + ASSERT_TRUE(!archive_dir.empty() && !dest_dir.empty(), "Should create dirs"); + + std::string archive_path = create_test_tar_gz(archive_dir); + ASSERT_TRUE(!archive_path.empty(), "Should create archive"); + + rac_extraction_result_t result = {}; + rac_result_t rc = rac_extract_archive_native( + archive_path.c_str(), dest_dir.c_str(), + nullptr, nullptr, nullptr, &result); + ASSERT_EQ(rc, RAC_SUCCESS, "Extraction should succeed"); + + // We created 3 files (hello.txt, data.bin, nested.txt) + ASSERT_EQ(result.files_extracted, 3, + "Should extract exactly 3 files"); + // We created 2 directories (content, content/subdir) + ASSERT_TRUE(result.directories_created >= 1, + "Should create at least 1 directory"); + // hello.txt(14) + data.bin(256) + nested.txt(20) = 290 bytes + ASSERT_TRUE(result.bytes_extracted >= 290, + "bytes_extracted should account for all file data"); + // No entries should be skipped (no macOS resource forks, no unsafe paths) + ASSERT_EQ(result.entries_skipped, 0, + "No entries should be skipped"); + + remove_dir(archive_dir); + remove_dir(dest_dir); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: unsupported archive format +// ============================================================================= + +static TestResult test_unsupported_format() { + std::string path = g_test_dir + "/not_an_archive.dat"; + // Write random data that isn't a valid archive + std::string garbage(1024, '\xAB'); + write_file(path, garbage); + + std::string dest_dir = create_temp_dir("unsup_dest"); + ASSERT_TRUE(!dest_dir.empty(), "Should create dest dir"); + + rac_result_t rc = rac_extract_archive_native( + path.c_str(), dest_dir.c_str(), + nullptr, nullptr, nullptr, nullptr); + ASSERT_EQ(rc, RAC_ERROR_UNSUPPORTED_ARCHIVE, + "Invalid archive should return RAC_ERROR_UNSUPPORTED_ARCHIVE"); + + remove_dir(dest_dir); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: extraction creates destination directory +// ============================================================================= + +static TestResult test_creates_dest_dir() { + if (!has_tar()) { + TestResult r; + r.passed = true; + r.details = "SKIPPED (tar not available)"; + return r; + } + + std::string archive_dir = create_temp_dir("mkdir_src"); + ASSERT_TRUE(!archive_dir.empty(), "Should create archive source dir"); + + std::string archive_path = create_test_tar_gz(archive_dir); + ASSERT_TRUE(!archive_path.empty(), "Should create archive"); + + // Destination directory that doesn't exist yet + std::string dest_dir = g_test_dir + "/new_nested/extraction/output"; + ASSERT_TRUE(!file_exists(dest_dir), "Dest dir should not exist yet"); + + rac_result_t rc = rac_extract_archive_native( + archive_path.c_str(), dest_dir.c_str(), + nullptr, nullptr, nullptr, nullptr); + ASSERT_EQ(rc, RAC_SUCCESS, "Extraction should create destination and succeed"); + ASSERT_TRUE(file_exists(dest_dir), "Destination dir should now exist"); + ASSERT_TRUE(file_exists(dest_dir + "/content/hello.txt"), + "Extracted file should exist"); + + remove_dir(archive_dir); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: default options (skip macOS resources) +// ============================================================================= + +static TestResult test_default_options_skip_macos() { + if (!has_tar()) { + TestResult r; + r.passed = true; + r.details = "SKIPPED (tar not available)"; + return r; + } + + // Create content with macOS resource fork files + std::string archive_dir = create_temp_dir("macos_src"); + std::string content_dir = archive_dir + "/macos_content"; + std::string macosx_dir = content_dir + "/__MACOSX"; + mkdir(content_dir.c_str(), 0755); + mkdir(macosx_dir.c_str(), 0755); + + write_file(content_dir + "/real_file.txt", "real content\n"); + write_file(content_dir + "/._resource_fork", "resource fork\n"); + write_file(macosx_dir + "/metadata.plist", "macos metadata\n"); + + std::string archive_path = archive_dir + "/macos_test.tar.gz"; + std::string cmd = "tar czf \"" + archive_path + "\" -C \"" + archive_dir + "\" macos_content"; + ASSERT_TRUE(system(cmd.c_str()) == 0, "Should create tar.gz with macOS entries"); + + std::string dest_dir = create_temp_dir("macos_dest"); + ASSERT_TRUE(!dest_dir.empty(), "Should create dest dir"); + + rac_extraction_result_t result = {}; + rac_result_t rc = rac_extract_archive_native( + archive_path.c_str(), dest_dir.c_str(), + nullptr, nullptr, nullptr, &result); + ASSERT_EQ(rc, RAC_SUCCESS, "Extraction should succeed"); + + // real_file.txt should be extracted + ASSERT_TRUE(file_exists(dest_dir + "/macos_content/real_file.txt"), + "Real file should be extracted"); + + // macOS resource forks should be skipped + ASSERT_TRUE(result.entries_skipped > 0, + "Should skip macOS resource entries"); + ASSERT_TRUE(!file_exists(dest_dir + "/macos_content/__MACOSX/metadata.plist"), + "__MACOSX directory contents should be skipped"); + ASSERT_TRUE(!file_exists(dest_dir + "/macos_content/._resource_fork"), + "._ resource fork files should be skipped"); + + remove_dir(archive_dir); + remove_dir(dest_dir); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: extraction with custom options (don't skip macOS resources) +// ============================================================================= + +static TestResult test_custom_options_keep_macos() { + if (!has_tar()) { + TestResult r; + r.passed = true; + r.details = "SKIPPED (tar not available)"; + return r; + } + + std::string archive_dir = create_temp_dir("keepmac_src"); + std::string content_dir = archive_dir + "/keep_content"; + std::string macosx_dir = content_dir + "/__MACOSX"; + mkdir(content_dir.c_str(), 0755); + mkdir(macosx_dir.c_str(), 0755); + + write_file(content_dir + "/file.txt", "content\n"); + write_file(macosx_dir + "/meta.plist", "metadata\n"); + + std::string archive_path = archive_dir + "/keep_macos.tar.gz"; + std::string cmd = "tar czf \"" + archive_path + "\" -C \"" + archive_dir + "\" keep_content"; + ASSERT_TRUE(system(cmd.c_str()) == 0, "Should create tar.gz"); + + std::string dest_dir = create_temp_dir("keepmac_dest"); + ASSERT_TRUE(!dest_dir.empty(), "Should create dest dir"); + + // Don't skip macOS resources + rac_extraction_options_t opts = {}; + opts.skip_macos_resources = RAC_FALSE; + opts.skip_symlinks = RAC_FALSE; + opts.archive_type_hint = RAC_ARCHIVE_TYPE_NONE; + + rac_extraction_result_t result = {}; + rac_result_t rc = rac_extract_archive_native( + archive_path.c_str(), dest_dir.c_str(), + &opts, nullptr, nullptr, &result); + ASSERT_EQ(rc, RAC_SUCCESS, "Extraction should succeed"); + + // Both files should be extracted (no skipping) + ASSERT_TRUE(file_exists(dest_dir + "/keep_content/file.txt"), + "file.txt should be extracted"); + ASSERT_TRUE(file_exists(dest_dir + "/keep_content/__MACOSX/meta.plist"), + "__MACOSX content should be extracted when skip_macos_resources=FALSE"); + + remove_dir(archive_dir); + remove_dir(dest_dir); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: detect archive type from real tar.gz +// ============================================================================= + +static TestResult test_detect_real_tar_gz() { + if (!has_tar()) { + TestResult r; + r.passed = true; + r.details = "SKIPPED (tar not available)"; + return r; + } + + std::string archive_dir = create_temp_dir("detect_src"); + std::string archive_path = create_test_tar_gz(archive_dir); + ASSERT_TRUE(!archive_path.empty(), "Should create archive"); + + rac_archive_type_t type; + ASSERT_EQ(rac_detect_archive_type(archive_path.c_str(), &type), RAC_TRUE, + "Should detect real tar.gz archive"); + ASSERT_EQ(type, RAC_ARCHIVE_TYPE_TAR_GZ, "Should be TAR_GZ"); + + remove_dir(archive_dir); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: detect archive type from real ZIP +// ============================================================================= + +static TestResult test_detect_real_zip() { + if (!has_zip()) { + TestResult r; + r.passed = true; + r.details = "SKIPPED (zip not available)"; + return r; + } + + std::string archive_dir = create_temp_dir("detectzip_src"); + std::string archive_path = create_test_zip(archive_dir); + ASSERT_TRUE(!archive_path.empty(), "Should create archive"); + + rac_archive_type_t type; + ASSERT_EQ(rac_detect_archive_type(archive_path.c_str(), &type), RAC_TRUE, + "Should detect real ZIP archive"); + ASSERT_EQ(type, RAC_ARCHIVE_TYPE_ZIP, "Should be ZIP"); + + remove_dir(archive_dir); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: archive_type_extension helper +// ============================================================================= + +static TestResult test_archive_type_extension() { + ASSERT_TRUE(std::strcmp(rac_archive_type_extension(RAC_ARCHIVE_TYPE_ZIP), "zip") == 0, + "ZIP extension should be 'zip'"); + ASSERT_TRUE(std::strcmp(rac_archive_type_extension(RAC_ARCHIVE_TYPE_TAR_GZ), "tar.gz") == 0, + "TAR_GZ extension should be 'tar.gz'"); + ASSERT_TRUE(std::strcmp(rac_archive_type_extension(RAC_ARCHIVE_TYPE_TAR_BZ2), "tar.bz2") == 0, + "TAR_BZ2 extension should be 'tar.bz2'"); + ASSERT_TRUE(std::strcmp(rac_archive_type_extension(RAC_ARCHIVE_TYPE_TAR_XZ), "tar.xz") == 0, + "TAR_XZ extension should be 'tar.xz'"); + + return TEST_PASS(); +} + +// ============================================================================= +// Test: archive_type_from_path helper +// ============================================================================= + +static TestResult test_archive_type_from_path() { + rac_archive_type_t type; + + ASSERT_EQ(rac_archive_type_from_path("model.tar.gz", &type), RAC_TRUE, + "Should detect tar.gz from path"); + ASSERT_EQ(type, RAC_ARCHIVE_TYPE_TAR_GZ, "Should be TAR_GZ"); + + ASSERT_EQ(rac_archive_type_from_path("model.tar.bz2", &type), RAC_TRUE, + "Should detect tar.bz2 from path"); + ASSERT_EQ(type, RAC_ARCHIVE_TYPE_TAR_BZ2, "Should be TAR_BZ2"); + + ASSERT_EQ(rac_archive_type_from_path("model.zip", &type), RAC_TRUE, + "Should detect zip from path"); + ASSERT_EQ(type, RAC_ARCHIVE_TYPE_ZIP, "Should be ZIP"); + + ASSERT_EQ(rac_archive_type_from_path("model.tar.xz", &type), RAC_TRUE, + "Should detect tar.xz from path"); + ASSERT_EQ(type, RAC_ARCHIVE_TYPE_TAR_XZ, "Should be TAR_XZ"); + + ASSERT_EQ(rac_archive_type_from_path("model.gguf", &type), RAC_FALSE, + "Should not detect archive from .gguf"); + + return TEST_PASS(); +} + +// ============================================================================= +// Main: register tests and dispatch via CLI args +// ============================================================================= + +int main(int argc, char** argv) { + // Create shared temp directory for all tests + g_test_dir = create_temp_dir("extraction"); + if (g_test_dir.empty()) { + std::cerr << "FATAL: Cannot create temp directory\n"; + return 1; + } + + TestSuite suite("extraction"); + + // Null/error handling + suite.add("null_pointer", test_null_pointer); + suite.add("file_not_found", test_file_not_found); + suite.add("unsupported_format", test_unsupported_format); + + // Archive type detection (magic bytes) + suite.add("detect_null", test_detect_null); + suite.add("detect_nonexistent", test_detect_nonexistent); + suite.add("detect_zip", test_detect_zip); + suite.add("detect_gzip", test_detect_gzip); + suite.add("detect_bzip2", test_detect_bzip2); + suite.add("detect_xz", test_detect_xz); + suite.add("detect_unknown", test_detect_unknown); + suite.add("detect_empty_file", test_detect_empty_file); + suite.add("detect_real_tar_gz", test_detect_real_tar_gz); + suite.add("detect_real_zip", test_detect_real_zip); + + // Type helper functions + suite.add("archive_type_extension", test_archive_type_extension); + suite.add("archive_type_from_path", test_archive_type_from_path); + + // Extraction + suite.add("extract_tar_gz", test_extract_tar_gz); + suite.add("extract_zip", test_extract_zip); + suite.add("progress_callback", test_progress_callback_invoked); + suite.add("extraction_result_stats", test_extraction_result_stats); + suite.add("creates_dest_dir", test_creates_dest_dir); + + // Options + suite.add("default_options_skip_macos", test_default_options_skip_macos); + suite.add("custom_options_keep_macos", test_custom_options_keep_macos); + + int result = suite.run(argc, argv); + + // Cleanup shared temp directory + remove_dir(g_test_dir); + + return result; +} diff --git a/sdk/runanywhere-commons/tools/CMakeLists.txt b/sdk/runanywhere-commons/tools/CMakeLists.txt index 52f624821..2f250f449 100644 --- a/sdk/runanywhere-commons/tools/CMakeLists.txt +++ b/sdk/runanywhere-commons/tools/CMakeLists.txt @@ -38,7 +38,7 @@ find_package(Threads REQUIRED) target_link_libraries(runanywhere-server PRIVATE Threads::Threads) # Compiler features -target_compile_features(runanywhere-server PRIVATE cxx_std_17) +target_compile_features(runanywhere-server PRIVATE cxx_std_20) if(CMAKE_CXX_COMPILER_ID MATCHES "Clang|GNU") target_compile_options(runanywhere-server PRIVATE -Wall -Wextra) diff --git a/sdk/runanywhere-flutter/packages/runanywhere/lib/infrastructure/download/download_service.dart b/sdk/runanywhere-flutter/packages/runanywhere/lib/infrastructure/download/download_service.dart index a28143b4d..032ed6e2e 100644 --- a/sdk/runanywhere-flutter/packages/runanywhere/lib/infrastructure/download/download_service.dart +++ b/sdk/runanywhere-flutter/packages/runanywhere/lib/infrastructure/download/download_service.dart @@ -1,12 +1,16 @@ import 'dart:async'; +import 'dart:ffi'; import 'dart:io'; -import 'package:archive/archive.dart'; +import 'package:ffi/ffi.dart'; import 'package:http/http.dart' as http; import 'package:path/path.dart' as p; import 'package:runanywhere/core/types/model_types.dart'; import 'package:runanywhere/foundation/logging/sdk_logger.dart'; +import 'package:runanywhere/native/dart_bridge_download.dart'; import 'package:runanywhere/native/dart_bridge_model_paths.dart'; +import 'package:runanywhere/native/platform_loader.dart'; +import 'package:runanywhere/native/type_conversions/model_types_cpp_bridge.dart'; import 'package:runanywhere/public/events/event_bus.dart'; import 'package:runanywhere/public/events/sdk_event.dart'; import 'package:runanywhere/public/runanywhere.dart'; @@ -263,7 +267,8 @@ class ModelDownloadService { final extractedPath = await _extractArchive( downloadPath, destDir.path, - model.artifactType, + framework: model.framework, + format: model.format, ); finalModelPath = extractedPath; @@ -323,68 +328,60 @@ class ModelDownloadService { return Directory(modelPath); } - /// Extract an archive to the destination + /// Extract an archive to the destination using native C++ (libarchive). + /// Supports ZIP, TAR.GZ, TAR.BZ2, TAR.XZ with auto-detection. + /// Post-extraction model path finding is delegated to C++. Future _extractArchive( String archivePath, - String destDir, - ModelArtifactType artifactType, - ) async { + String destDir, { + required InferenceFramework framework, + required ModelFormat format, + }) async { _logger.info('Extracting archive: $archivePath'); - final archiveFile = File(archivePath); - final bytes = await archiveFile.readAsBytes(); - - Archive? archive; - - // Determine archive type - if (archivePath.endsWith('.tar.gz') || archivePath.endsWith('.tgz')) { - final gzDecoded = GZipDecoder().decodeBytes(bytes); - archive = TarDecoder().decodeBytes(gzDecoded); - } else if (archivePath.endsWith('.tar.bz2') || - archivePath.endsWith('.tbz2')) { - final bz2Decoded = BZip2Decoder().decodeBytes(bytes); - archive = TarDecoder().decodeBytes(bz2Decoded); - } else if (archivePath.endsWith('.zip')) { - archive = ZipDecoder().decodeBytes(bytes); - } else if (archivePath.endsWith('.tar')) { - archive = TarDecoder().decodeBytes(bytes); - } else { - _logger.warning('Unknown archive format: $archivePath'); - return archivePath; - } + final lib = PlatformLoader.loadCommons(); + final extractFn = lib.lookupFunction< + Int32 Function(Pointer, Pointer, Pointer, + Pointer, Pointer, Pointer), + int Function(Pointer, Pointer, Pointer, + Pointer, Pointer, Pointer)>( + 'rac_extract_archive_native', + ); - // Extract files - String? rootDir; - for (final file in archive) { - final filePath = p.join(destDir, file.name); - - if (file.isFile) { - final outFile = File(filePath); - await outFile.create(recursive: true); - await outFile.writeAsBytes(file.content as List); - _logger.debug('Extracted: ${file.name}'); - - // Track root directory - final parts = file.name.split('/'); - if (parts.isNotEmpty && rootDir == null) { - rootDir = parts.first; - } - } else { - await Directory(filePath).create(recursive: true); + final archivePathPtr = archivePath.toNativeUtf8(allocator: calloc); + final destPathPtr = destDir.toNativeUtf8(allocator: calloc); + + try { + final result = extractFn( + archivePathPtr, + destPathPtr, + nullptr, + nullptr, + nullptr, + nullptr, + ); + + if (result != 0) { + _logger.error('Native extraction failed with code: $result'); + throw Exception('Native extraction failed with code: $result'); } + } finally { + calloc.free(archivePathPtr); + calloc.free(destPathPtr); } _logger.info('Extraction complete: $destDir'); - // Return the model directory (could be a nested directory) - if (rootDir != null) { - final nestedPath = p.join(destDir, rootDir); - if (await Directory(nestedPath).exists()) { - return nestedPath; - } - } + // Use C++ to find the actual model path after extraction + // (handles nested directories, model file scanning, etc.) + final modelPath = DartBridgeDownload.findModelPathAfterExtraction( + extractedDir: destDir, + structure: 99, // RAC_ARCHIVE_STRUCTURE_UNKNOWN - auto-detect + framework: framework.toC(), + format: format.toC(), + ); - return destDir; + return modelPath ?? destDir; } /// Update model's local path after download diff --git a/sdk/runanywhere-flutter/packages/runanywhere/lib/infrastructure/file_management/services/simplified_file_manager.dart b/sdk/runanywhere-flutter/packages/runanywhere/lib/infrastructure/file_management/services/simplified_file_manager.dart index 424649a47..897786ca5 100644 --- a/sdk/runanywhere-flutter/packages/runanywhere/lib/infrastructure/file_management/services/simplified_file_manager.dart +++ b/sdk/runanywhere-flutter/packages/runanywhere/lib/infrastructure/file_management/services/simplified_file_manager.dart @@ -11,6 +11,7 @@ import 'package:path/path.dart' as path; import 'package:path_provider/path_provider.dart'; import 'package:runanywhere/core/types/storage_types.dart'; import 'package:runanywhere/foundation/logging/sdk_logger.dart'; +import 'package:runanywhere/native/dart_bridge_file_manager.dart'; /// File manager for RunAnywhere SDK /// Matches iOS SimplifiedFileManager from Infrastructure/FileManagement/Services/SimplifiedFileManager.swift @@ -161,19 +162,10 @@ class SimplifiedFileManager { } } - /// Calculate total size of all models + /// Calculate total size of all models (C++ recursive traversal) Future calculateModelsSize() async { _ensureInitialized(); - final modelsDir = Directory(path.join(_baseDirectory!.path, 'Models')); - if (!await modelsDir.exists()) return 0; - - int totalSize = 0; - await for (final entity in modelsDir.list(recursive: true)) { - if (entity is File) { - totalSize += await entity.length(); - } - } - return totalSize; + return DartBridgeFileManager.modelsStorageUsed(); } /// Get device storage info @@ -187,36 +179,18 @@ class SimplifiedFileManager { ); } - /// Clear all cache + /// Clear all cache (C++ handles delete + recreate) Future clearCache() async { _ensureInitialized(); - final cacheDir = Directory(path.join(_baseDirectory!.path, 'Cache')); - if (await cacheDir.exists()) { - await for (final entity in cacheDir.list()) { - if (entity is File) { - await entity.delete(); - } else if (entity is Directory) { - await entity.delete(recursive: true); - } - } - _logger.info('Cache cleared'); - } + DartBridgeFileManager.clearCache(); + _logger.info('Cache cleared'); } - /// Clear all temporary files + /// Clear all temporary files (C++ handles delete + recreate) Future clearTemp() async { _ensureInitialized(); - final tempDir = Directory(path.join(_baseDirectory!.path, 'Temp')); - if (await tempDir.exists()) { - await for (final entity in tempDir.list()) { - if (entity is File) { - await entity.delete(); - } else if (entity is Directory) { - await entity.delete(recursive: true); - } - } - _logger.info('Temp directory cleared'); - } + DartBridgeFileManager.clearTemp(); + _logger.info('Temp directory cleared'); } void _ensureInitialized() { diff --git a/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge.dart b/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge.dart index a73fc6f9e..eb76fb7bb 100644 --- a/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge.dart +++ b/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge.dart @@ -13,6 +13,7 @@ import 'package:runanywhere/native/dart_bridge_download.dart'; import 'package:runanywhere/native/dart_bridge_environment.dart' show RacSdkConfigStruct; import 'package:runanywhere/native/dart_bridge_events.dart'; +import 'package:runanywhere/native/dart_bridge_file_manager.dart'; import 'package:runanywhere/native/dart_bridge_http.dart'; import 'package:runanywhere/native/dart_bridge_llm.dart'; import 'package:runanywhere/native/dart_bridge_model_assignment.dart'; @@ -28,6 +29,7 @@ import 'package:runanywhere/native/dart_bridge_tts.dart'; import 'package:runanywhere/native/dart_bridge_vad.dart'; import 'package:runanywhere/native/dart_bridge_vlm.dart'; import 'package:runanywhere/native/dart_bridge_voice_agent.dart'; +import 'package:runanywhere/native/dart_bridge_lora.dart'; import 'package:runanywhere/native/dart_bridge_rag.dart'; import 'package:runanywhere/native/platform_loader.dart'; import 'package:runanywhere/public/configuration/sdk_environment.dart'; @@ -143,6 +145,10 @@ class DartBridge { DartBridgeDevice.registerCallbacks(); _logger.debug('Device callbacks registered'); + // Step 8: Register file manager I/O callbacks + DartBridgeFileManager.register(); + _logger.debug('File manager callbacks registered'); + _isInitialized = true; _logger.info('Phase 1 initialization complete'); } @@ -312,6 +318,13 @@ class DartBridge { /// RAG pipeline bridge static DartBridgeRAG get rag => DartBridgeRAG.shared; + /// LoRA adapter bridge + static DartBridgeLora get lora => DartBridgeLora.shared; + + /// LoRA registry bridge + static DartBridgeLoraRegistry get loraRegistry => + DartBridgeLoraRegistry.shared; + // ------------------------------------------------------------------------- // Private Helpers // ------------------------------------------------------------------------- diff --git a/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_download.dart b/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_download.dart index d7028d7a3..ef25aad95 100644 --- a/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_download.dart +++ b/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_download.dart @@ -114,6 +114,120 @@ class DartBridgeDownload { /// Get active download count int get activeDownloadCount => _activeTasks.length; + + // =========================================================================== + // Download Orchestrator Utilities (from rac_download_orchestrator.h) + // =========================================================================== + + /// Find the actual model path after extraction. + /// + /// Consolidates duplicated Dart logic for scanning extracted directories. + /// Uses C++ `rac_find_model_path_after_extraction()` which handles: + /// - Finding .gguf, .onnx, .ort, .bin files + /// - Nested directories (e.g., sherpa-onnx archives) + /// - Single-file-nested pattern + /// - Directory-based models (ONNX) + /// + /// [structure]: C++ archive structure constant (99 = unknown/auto-detect) + /// [framework]: C++ inference framework constant (from RacInferenceFramework) + /// [format]: C++ model format constant (from RacModelFormat) + static String? findModelPathAfterExtraction({ + required String extractedDir, + required int structure, + required int framework, + required int format, + }) { + try { + final lib = PlatformLoader.loadCommons(); + final fn = lib.lookupFunction< + Int32 Function( + Pointer, Int32, Int32, Int32, Pointer, IntPtr), + int Function(Pointer, int, int, int, Pointer, + int)>('rac_find_model_path_after_extraction'); + + final dirPtr = extractedDir.toNativeUtf8(); + final outPath = calloc(4096); + + try { + final result = fn( + dirPtr, structure, framework, format, outPath.cast(), 4096); + if (result != RacResultCode.success) return null; + return outPath.cast().toDartString(); + } finally { + calloc.free(dirPtr); + calloc.free(outPath); + } + } catch (e) { + _logger.debug('rac_find_model_path_after_extraction not available: $e'); + return null; + } + } + + /// Check if a download URL requires extraction. + /// + /// Wraps C++ `rac_download_requires_extraction()` which checks URL suffix + /// for archive extensions (.tar.gz, .tar.bz2, .tar.xz, .zip). + static bool downloadRequiresExtraction(String url) { + try { + final lib = PlatformLoader.loadCommons(); + final fn = lib.lookupFunction), + int Function(Pointer)>('rac_download_requires_extraction'); + + final urlPtr = url.toNativeUtf8(); + try { + return fn(urlPtr) == RAC_TRUE; + } finally { + calloc.free(urlPtr); + } + } catch (e) { + _logger.debug('rac_download_requires_extraction not available: $e'); + return false; + } + } + + /// Compute the download destination path for a model. + /// + /// Wraps C++ `rac_download_compute_destination()`. + /// Returns the destination path and whether extraction is needed, + /// or null if the computation fails. + static ({String path, bool needsExtraction})? computeDownloadDestination({ + required String modelId, + required String downloadUrl, + required int framework, + required int format, + }) { + try { + final lib = PlatformLoader.loadCommons(); + final fn = lib.lookupFunction< + Int32 Function(Pointer, Pointer, Int32, Int32, + Pointer, IntPtr, Pointer), + int Function(Pointer, Pointer, int, int, Pointer, + int, Pointer)>('rac_download_compute_destination'); + + final modelIdPtr = modelId.toNativeUtf8(); + final urlPtr = downloadUrl.toNativeUtf8(); + final outPath = calloc(4096); + final outNeedsExtraction = calloc(); + + try { + final result = fn(modelIdPtr, urlPtr, framework, format, + outPath.cast(), 4096, outNeedsExtraction); + if (result != RacResultCode.success) return null; + return ( + path: outPath.cast().toDartString(), + needsExtraction: outNeedsExtraction.value == RAC_TRUE, + ); + } finally { + calloc.free(modelIdPtr); + calloc.free(urlPtr); + calloc.free(outPath); + calloc.free(outNeedsExtraction); + } + } catch (e) { + _logger.debug('rac_download_compute_destination not available: $e'); + return null; + } + } } class _DownloadTask { diff --git a/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_file_manager.dart b/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_file_manager.dart new file mode 100644 index 000000000..61de72b05 --- /dev/null +++ b/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_file_manager.dart @@ -0,0 +1,472 @@ +// ignore_for_file: avoid_classes_with_only_static_members + +import 'dart:ffi'; +import 'dart:io'; + +import 'package:ffi/ffi.dart'; +import 'package:runanywhere/foundation/logging/sdk_logger.dart'; +import 'package:runanywhere/native/ffi_types.dart'; +import 'package:runanywhere/native/platform_loader.dart'; + +// ============================================================================= +// Exception Return Constants (must be compile-time constants for FFI) +// ============================================================================= + +const int _errorDirectoryCreationFailed = -189; +const int _errorDeleteFailed = -187; +const int _errorFileNotFound = -183; +const int _falseReturn = 0; +const int _negativeReturn = -1; + +// ============================================================================= +// File Manager Bridge +// ============================================================================= + +/// File manager bridge to C++ rac_file_manager. +/// C++ owns business logic; Dart provides thin I/O callbacks. +/// Matches iOS CppBridge+FileManager.swift / Kotlin CppBridgeFileManager.kt. +class DartBridgeFileManager { + DartBridgeFileManager._(); + + static final _logger = SDKLogger('DartBridge.FileManager'); + static final DartBridgeFileManager instance = DartBridgeFileManager._(); + + static bool _isRegistered = false; + static Pointer? _callbacksPtr; + + /// Register file manager callbacks. Call during SDK init. + static void register() { + if (_isRegistered) return; + + _callbacksPtr = calloc(); + final cb = _callbacksPtr!; + + cb.ref.createDirectory = + Pointer.fromFunction( + _createDirectoryCallback, _errorDirectoryCreationFailed); + cb.ref.deletePath = Pointer.fromFunction( + _deletePathCallback, _errorDeleteFailed); + cb.ref.listDirectory = + Pointer.fromFunction( + _listDirectoryCallback, _errorFileNotFound); + cb.ref.freeEntries = Pointer.fromFunction( + _freeEntriesCallback); + cb.ref.pathExists = Pointer.fromFunction( + _pathExistsCallback, _falseReturn); + cb.ref.getFileSize = Pointer.fromFunction( + _getFileSizeCallback, _negativeReturn); + cb.ref.getAvailableSpace = + Pointer.fromFunction( + _getAvailableSpaceCallback, 0); + cb.ref.getTotalSpace = Pointer.fromFunction( + _getTotalSpaceCallback, 0); + cb.ref.userData = nullptr; + + _isRegistered = true; + _logger.debug('File manager callbacks registered'); + } + + /// Cleanup + static void unregister() { + if (_callbacksPtr != null) { + calloc.free(_callbacksPtr!); + _callbacksPtr = null; + } + _isRegistered = false; + } + + // ========================================================================= + // Public API + // ========================================================================= + + /// Create directory structure (Models, Cache, Temp, Downloads). + static bool createDirectoryStructure() { + final lib = _lib(); + if (lib == null || _callbacksPtr == null) return false; + final fn = lib.lookupFunction), + int Function(Pointer)>( + 'rac_file_manager_create_directory_structure'); + return fn(_callbacksPtr!) == RacResultCode.success; + } + + /// Calculate directory size recursively. + static int calculateDirectorySize(String path) { + final lib = _lib(); + if (lib == null || _callbacksPtr == null) return 0; + final fn = lib.lookupFunction< + Int32 Function( + Pointer, Pointer, Pointer), + int Function(Pointer, Pointer, + Pointer)>('rac_file_manager_calculate_dir_size'); + + final pathPtr = path.toNativeUtf8(); + final sizePtr = calloc(); + try { + fn(_callbacksPtr!, pathPtr, sizePtr); + return sizePtr.value; + } finally { + calloc.free(pathPtr); + calloc.free(sizePtr); + } + } + + /// Get total models storage used. + static int modelsStorageUsed() { + final lib = _lib(); + if (lib == null || _callbacksPtr == null) return 0; + final fn = lib.lookupFunction< + Int32 Function(Pointer, Pointer), + int Function( + Pointer, Pointer)>( + 'rac_file_manager_models_storage_used'); + + final sizePtr = calloc(); + try { + fn(_callbacksPtr!, sizePtr); + return sizePtr.value; + } finally { + calloc.free(sizePtr); + } + } + + /// Clear cache directory. + static bool clearCache() { + final lib = _lib(); + if (lib == null || _callbacksPtr == null) return false; + final fn = lib.lookupFunction), + int Function(Pointer)>( + 'rac_file_manager_clear_cache'); + return fn(_callbacksPtr!) == RacResultCode.success; + } + + /// Clear temp directory. + static bool clearTemp() { + final lib = _lib(); + if (lib == null || _callbacksPtr == null) return false; + final fn = lib.lookupFunction), + int Function(Pointer)>( + 'rac_file_manager_clear_temp'); + return fn(_callbacksPtr!) == RacResultCode.success; + } + + /// Get cache size. + static int cacheSize() { + final lib = _lib(); + if (lib == null || _callbacksPtr == null) return 0; + final fn = lib.lookupFunction< + Int32 Function(Pointer, Pointer), + int Function( + Pointer, Pointer)>( + 'rac_file_manager_cache_size'); + + final sizePtr = calloc(); + try { + fn(_callbacksPtr!, sizePtr); + return sizePtr.value; + } finally { + calloc.free(sizePtr); + } + } + + /// Create a model folder and return its path. + static String? createModelFolder(String modelId, int framework) { + final lib = _lib(); + if (lib == null || _callbacksPtr == null) return null; + final fn = lib.lookupFunction< + Int32 Function(Pointer, Pointer, Int32, + Pointer, Size), + int Function(Pointer, Pointer, int, + Pointer, int)>('rac_file_manager_create_model_folder'); + + final modelIdPtr = modelId.toNativeUtf8(); + const bufSize = 1024; + final outPath = calloc(bufSize).cast(); + try { + final result = fn(_callbacksPtr!, modelIdPtr, framework, outPath, bufSize); + if (result != RacResultCode.success) return null; + return outPath.toDartString(); + } finally { + calloc.free(modelIdPtr); + calloc.free(outPath); + } + } + + /// Check if a model folder exists. + static bool modelFolderExists(String modelId, int framework) { + final lib = _lib(); + if (lib == null || _callbacksPtr == null) return false; + final fn = lib.lookupFunction< + Int32 Function(Pointer, Pointer, Int32, + Pointer, Pointer), + int Function(Pointer, Pointer, int, + Pointer, Pointer)>( + 'rac_file_manager_model_folder_exists'); + + final modelIdPtr = modelId.toNativeUtf8(); + final existsPtr = calloc(); + try { + fn(_callbacksPtr!, modelIdPtr, framework, existsPtr, nullptr); + return existsPtr.value == RAC_TRUE; + } finally { + calloc.free(modelIdPtr); + calloc.free(existsPtr); + } + } + + /// Get combined storage information. + static NativeStorageInfo? getStorageInfo() { + final lib = _lib(); + if (lib == null || _callbacksPtr == null) return null; + final fn = lib.lookupFunction< + Int32 Function( + Pointer, Pointer), + int Function(Pointer, + Pointer)>( + 'rac_file_manager_get_storage_info'); + + final infoPtr = calloc(); + try { + final result = fn(_callbacksPtr!, infoPtr); + if (result != RacResultCode.success) return null; + return NativeStorageInfo( + deviceTotal: infoPtr.ref.deviceTotal, + deviceFree: infoPtr.ref.deviceFree, + modelsSize: infoPtr.ref.modelsSize, + cacheSize: infoPtr.ref.cacheSize, + tempSize: infoPtr.ref.tempSize, + totalAppSize: infoPtr.ref.totalAppSize, + ); + } finally { + calloc.free(infoPtr); + } + } + + /// Check storage availability via C++ rac_file_manager_check_storage. + /// Returns full availability result including warnings and recommendations. + static NativeStorageAvailability? checkStorageAvailability(int requiredBytes) { + final lib = _lib(); + if (lib == null || _callbacksPtr == null) return null; + final fn = lib.lookupFunction< + Int32 Function(Pointer, Int64, + Pointer), + int Function(Pointer, int, + Pointer)>( + 'rac_file_manager_check_storage'); + + final availPtr = calloc(); + try { + final result = fn(_callbacksPtr!, requiredBytes, availPtr); + if (result != RacResultCode.success) return null; + final rec = availPtr.ref.recommendation; + return NativeStorageAvailability( + isAvailable: availPtr.ref.isAvailable == RAC_TRUE, + requiredSpace: availPtr.ref.requiredSpace, + availableSpace: availPtr.ref.availableSpace, + hasWarning: availPtr.ref.hasWarning == RAC_TRUE, + recommendation: rec != nullptr ? rec.toDartString() : null, + ); + } finally { + calloc.free(availPtr); + } + } + + /// Check storage availability for a given number of bytes. + /// Convenience wrapper that returns a simple bool. + static bool checkStorage(int requiredBytes) { + final result = checkStorageAvailability(requiredBytes); + if (result == null) return true; // Default to available if check fails + return result.isAvailable; + } + + /// Delete a model folder. + static bool deleteModel(String modelId, int framework) { + final lib = _lib(); + if (lib == null || _callbacksPtr == null) return false; + final fn = lib.lookupFunction< + Int32 Function( + Pointer, Pointer, Int32), + int Function(Pointer, Pointer, + int)>('rac_file_manager_delete_model'); + + final modelIdPtr = modelId.toNativeUtf8(); + try { + return fn(_callbacksPtr!, modelIdPtr, framework) == + RacResultCode.success; + } finally { + calloc.free(modelIdPtr); + } + } + + // ========================================================================= + // Private helpers + // ========================================================================= + + static DynamicLibrary? _lib() { + try { + return PlatformLoader.loadCommons(); + } catch (e) { + _logger.debug('Native library not available: $e'); + return null; + } + } +} + +// ============================================================================= +// Storage Info Data Class +// ============================================================================= + +/// Storage availability result from C++ rac_file_manager_check_storage. +class NativeStorageAvailability { + final bool isAvailable; + final int requiredSpace; + final int availableSpace; + final bool hasWarning; + final String? recommendation; + + const NativeStorageAvailability({ + required this.isAvailable, + required this.requiredSpace, + required this.availableSpace, + required this.hasWarning, + this.recommendation, + }); +} + +/// Combined storage information from C++ file manager. +class NativeStorageInfo { + final int deviceTotal; + final int deviceFree; + final int modelsSize; + final int cacheSize; + final int tempSize; + final int totalAppSize; + + const NativeStorageInfo({ + required this.deviceTotal, + required this.deviceFree, + required this.modelsSize, + required this.cacheSize, + required this.tempSize, + required this.totalAppSize, + }); +} + +// ============================================================================= +// C Callbacks (Platform I/O) +// ============================================================================= + +int _createDirectoryCallback( + Pointer path, int recursive, Pointer userData) { + try { + final dir = Directory(path.toDartString()); + if (recursive != 0) { + dir.createSync(recursive: true); + } else { + dir.createSync(); + } + return RacResultCode.success; + } catch (_) { + return _errorDirectoryCreationFailed; + } +} + +int _deletePathCallback( + Pointer path, int recursive, Pointer userData) { + try { + final pathStr = path.toDartString(); + final type = FileSystemEntity.typeSync(pathStr); + if (type == FileSystemEntityType.notFound) return RacResultCode.success; + + if (type == FileSystemEntityType.directory) { + Directory(pathStr).deleteSync(recursive: recursive != 0); + } else { + File(pathStr).deleteSync(); + } + return RacResultCode.success; + } catch (_) { + return _errorDeleteFailed; + } +} + +int _listDirectoryCallback( + Pointer path, + Pointer>> outEntries, + Pointer outCount, + Pointer userData, +) { + try { + final dir = Directory(path.toDartString()); + if (!dir.existsSync()) { + outEntries.value = nullptr; + outCount.value = 0; + return _errorFileNotFound; + } + + final contents = dir.listSync(); + final count = contents.length; + + final entries = calloc>(count); + for (var i = 0; i < count; i++) { + final name = contents[i].uri.pathSegments.lastWhere((s) => s.isNotEmpty); + entries[i] = name.toNativeUtf8(); + } + + outEntries.value = entries; + outCount.value = count; + return RacResultCode.success; + } catch (_) { + outEntries.value = nullptr; + outCount.value = 0; + return _errorFileNotFound; + } +} + +void _freeEntriesCallback( + Pointer> entries, int count, Pointer userData) { + if (entries == nullptr) return; + for (var i = 0; i < count; i++) { + if (entries[i] != nullptr) { + calloc.free(entries[i]); + } + } + calloc.free(entries); +} + +int _pathExistsCallback( + Pointer path, Pointer outIsDirectory, Pointer userData) { + try { + final pathStr = path.toDartString(); + final type = FileSystemEntity.typeSync(pathStr); + if (type == FileSystemEntityType.notFound) return RAC_FALSE; + + if (outIsDirectory != nullptr) { + outIsDirectory.value = + type == FileSystemEntityType.directory ? RAC_TRUE : RAC_FALSE; + } + return RAC_TRUE; + } catch (_) { + return RAC_FALSE; + } +} + +int _getFileSizeCallback(Pointer path, Pointer userData) { + try { + final file = File(path.toDartString()); + if (file.existsSync()) { + return file.lengthSync(); + } + return -1; + } catch (_) { + return -1; + } +} + +int _getAvailableSpaceCallback(Pointer userData) { + // Dart doesn't have a direct API for disk space. + // Return 0 to indicate unknown (C++ will handle gracefully). + return 0; +} + +int _getTotalSpaceCallback(Pointer userData) { + return 0; +} diff --git a/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_lora.dart b/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_lora.dart new file mode 100644 index 000000000..78a293790 --- /dev/null +++ b/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_lora.dart @@ -0,0 +1,500 @@ +/// DartBridge+LoRA +/// +/// LoRA adapter bridge - manages C++ LoRA operations via FFI. +/// Mirrors Swift's CppBridge+LLM.swift LoRA section and +/// CppBridge+LoraRegistry.swift. +/// +/// Two classes: +/// - [DartBridgeLora] - Runtime LoRA operations (load/remove/clear/info) +/// - [DartBridgeLoraRegistry] - Catalog registry (register/query adapters) +library dart_bridge_lora; + +import 'dart:convert'; +import 'dart:ffi'; + +import 'package:ffi/ffi.dart'; + +import 'package:runanywhere/foundation/logging/sdk_logger.dart'; +import 'package:runanywhere/native/dart_bridge_llm.dart'; +import 'package:runanywhere/native/ffi_types.dart'; +import 'package:runanywhere/native/platform_loader.dart'; +import 'package:runanywhere/public/types/lora_types.dart'; + +// ============================================================================= +// FFI Struct: rac_lora_entry_t +// ============================================================================= + +/// Matches C struct rac_lora_entry_t from rac_lora_registry.h. +/// Field order MUST match the C struct exactly. +base class RacLoraEntryCStruct extends Struct { + // char* id + external Pointer id; + + // char* name + external Pointer name; + + // char* description + external Pointer description; + + // char* download_url + external Pointer downloadUrl; + + // char* filename + external Pointer filename; + + // char** compatible_model_ids + external Pointer> compatibleModelIds; + + // size_t compatible_model_count + @IntPtr() + external int compatibleModelCount; + + // int64_t file_size + @Int64() + external int fileSize; + + // float default_scale + @Float() + external double defaultScale; +} + +// ============================================================================= +// LoRA Runtime Operations (via LLM Component) +// ============================================================================= + +/// LoRA adapter bridge for runtime operations. +/// +/// Uses the LLM component handle - LoRA ops are on the LLM component in C++. +/// Matches Swift CppBridge.LLM LoRA methods. +class DartBridgeLora { + // MARK: - Singleton + + static final DartBridgeLora shared = DartBridgeLora._(); + + DartBridgeLora._(); + + final _logger = SDKLogger('DartBridge.LoRA'); + + // MARK: - LoRA Adapter Management + + /// Load a LoRA adapter and apply it to the current model. + /// + /// Context is recreated internally and KV cache is cleared. + /// Throws on failure. + void loadAdapter(String adapterPath, double scale) { + final handle = DartBridgeLLM.shared.getHandle(); + + final pathPtr = adapterPath.toNativeUtf8(); + try { + final lib = PlatformLoader.loadCommons(); + final fn = lib.lookupFunction< + Int32 Function(RacHandle, Pointer, Float), + int Function(RacHandle, Pointer, double)>( + 'rac_llm_component_load_lora', + ); + + final result = fn(handle, pathPtr, scale); + if (result != RAC_SUCCESS) { + throw StateError( + 'Failed to load LoRA adapter: ${RacResultCode.getMessage(result)}', + ); + } + _logger.info('LoRA adapter loaded: $adapterPath (scale=$scale)'); + } finally { + calloc.free(pathPtr); + } + } + + /// Remove a specific LoRA adapter by path. + /// + /// KV cache is cleared automatically. + /// Throws on failure. + void removeAdapter(String adapterPath) { + final handle = DartBridgeLLM.shared.getHandle(); + + final pathPtr = adapterPath.toNativeUtf8(); + try { + final lib = PlatformLoader.loadCommons(); + final fn = lib.lookupFunction< + Int32 Function(RacHandle, Pointer), + int Function(RacHandle, Pointer)>( + 'rac_llm_component_remove_lora', + ); + + final result = fn(handle, pathPtr); + if (result != RAC_SUCCESS) { + throw StateError( + 'Failed to remove LoRA adapter: ${RacResultCode.getMessage(result)}', + ); + } + _logger.info('LoRA adapter removed: $adapterPath'); + } finally { + calloc.free(pathPtr); + } + } + + /// Remove all LoRA adapters. + /// + /// KV cache is cleared automatically. + /// Throws on failure. + void clearAdapters() { + final handle = DartBridgeLLM.shared.getHandle(); + + final lib = PlatformLoader.loadCommons(); + final fn = lib.lookupFunction< + Int32 Function(RacHandle), + int Function(RacHandle)>( + 'rac_llm_component_clear_lora', + ); + + final result = fn(handle); + if (result != RAC_SUCCESS) { + throw StateError( + 'Failed to clear LoRA adapters: ${RacResultCode.getMessage(result)}', + ); + } + _logger.info('All LoRA adapters cleared'); + } + + /// Get info about currently loaded LoRA adapters. + /// + /// Returns a list parsed from JSON: [{"path":"...", "scale":1.0, "applied":true}] + List getLoadedAdapters() { + final handle = DartBridgeLLM.shared.getHandle(); + + final outJsonPtr = calloc>(); + try { + final lib = PlatformLoader.loadCommons(); + final fn = lib.lookupFunction< + Int32 Function(RacHandle, Pointer>), + int Function(RacHandle, Pointer>)>( + 'rac_llm_component_get_lora_info', + ); + + final result = fn(handle, outJsonPtr); + if (result != RAC_SUCCESS) { + _logger.error('Failed to get LoRA info: $result'); + return []; + } + + final jsonPtr = outJsonPtr.value; + if (jsonPtr == nullptr) return []; + + final jsonStr = jsonPtr.toDartString(); + + // Free the C-allocated JSON string + final freeFn = lib.lookupFunction< + Void Function(Pointer), + void Function(Pointer)>('rac_free'); + freeFn(jsonPtr.cast()); + + return _parseAdapterInfoJson(jsonStr); + } finally { + calloc.free(outJsonPtr); + } + } + + /// Check if the current backend supports LoRA adapters. + LoraCompatibilityResult checkCompatibility(String loraPath) { + final handle = DartBridgeLLM.shared.getHandle(); + + final pathPtr = loraPath.toNativeUtf8(); + final outErrorPtr = calloc>(); + try { + final lib = PlatformLoader.loadCommons(); + final fn = lib.lookupFunction< + Int32 Function(RacHandle, Pointer, Pointer>), + int Function(RacHandle, Pointer, Pointer>)>( + 'rac_llm_component_check_lora_compat', + ); + + final result = fn(handle, pathPtr, outErrorPtr); + + if (result == RAC_SUCCESS) { + return const LoraCompatibilityResult(isCompatible: true); + } + + // Read error message + String? errorMsg; + final errorPtr = outErrorPtr.value; + if (errorPtr != nullptr) { + errorMsg = errorPtr.toDartString(); + // Free the C-allocated error string + final freeFn = lib.lookupFunction< + Void Function(Pointer), + void Function(Pointer)>('rac_free'); + freeFn(errorPtr.cast()); + } + + return LoraCompatibilityResult( + isCompatible: false, + error: errorMsg, + ); + } finally { + calloc.free(pathPtr); + calloc.free(outErrorPtr); + } + } + + // MARK: - Private Helpers + + List _parseAdapterInfoJson(String jsonStr) { + try { + final list = jsonDecode(jsonStr) as List; + return list.map((item) { + final map = item as Map; + return LoRAAdapterInfo( + path: (map['path'] as String?) ?? '', + scale: ((map['scale'] as num?) ?? 1.0).toDouble(), + applied: (map['applied'] as bool?) ?? false, + ); + }).toList(); + } catch (e) { + _logger.error('Failed to parse LoRA info JSON: $e'); + return []; + } + } +} + +// ============================================================================= +// LoRA Registry (Catalog Operations) +// ============================================================================= + +/// LoRA adapter registry bridge for catalog operations. +/// +/// Uses the global C++ registry singleton via rac_register_lora / rac_get_lora_for_model. +/// Matches Swift CppBridge.LoraRegistry. +class DartBridgeLoraRegistry { + // MARK: - Singleton + + static final DartBridgeLoraRegistry shared = DartBridgeLoraRegistry._(); + + DartBridgeLoraRegistry._(); + + final _logger = SDKLogger('DartBridge.LoRA.Registry'); + + // MARK: - Registry Operations + + /// Register a LoRA adapter in the global registry. + /// + /// Entry is deep-copied internally by C++. + /// Throws on failure. + void register(LoraAdapterCatalogEntry entry) { + final lib = PlatformLoader.loadCommons(); + + final strdupFn = lib.lookupFunction< + Pointer Function(Pointer), + Pointer Function(Pointer)>('rac_strdup'); + + final registerFn = lib.lookupFunction< + Int32 Function(Pointer), + int Function(Pointer)>('rac_register_lora'); + + // Allocate C struct on Dart heap + final entryPtr = calloc(); + + // Temporary Dart strings for conversion + final idDart = entry.id.toNativeUtf8(); + final nameDart = entry.name.toNativeUtf8(); + final descDart = entry.description.toNativeUtf8(); + final urlDart = entry.downloadUrl.toNativeUtf8(); + final filenameDart = entry.filename.toNativeUtf8(); + + // Allocate compatible model IDs array + final compatCount = entry.compatibleModelIds.length; + final compatArrayPtr = calloc>(compatCount); + final compatDartPtrs = >[]; + + try { + // Fill string fields using strdup (C heap allocation) + entryPtr.ref.id = strdupFn(idDart); + entryPtr.ref.name = strdupFn(nameDart); + entryPtr.ref.description = strdupFn(descDart); + entryPtr.ref.downloadUrl = strdupFn(urlDart); + entryPtr.ref.filename = strdupFn(filenameDart); + + // Fill compatible model IDs + for (int i = 0; i < compatCount; i++) { + final dartPtr = entry.compatibleModelIds[i].toNativeUtf8(); + compatDartPtrs.add(dartPtr); + compatArrayPtr[i] = strdupFn(dartPtr); + } + entryPtr.ref.compatibleModelIds = compatArrayPtr; + entryPtr.ref.compatibleModelCount = compatCount; + entryPtr.ref.fileSize = entry.fileSize; + entryPtr.ref.defaultScale = entry.defaultScale; + + final result = registerFn(entryPtr); + if (result != RAC_SUCCESS) { + throw StateError( + 'Failed to register LoRA adapter "${entry.id}": ${RacResultCode.getMessage(result)}', + ); + } + _logger.info('LoRA adapter registered: ${entry.id}'); + } finally { + // Free Dart-allocated temporary strings + calloc.free(idDart); + calloc.free(nameDart); + calloc.free(descDart); + calloc.free(urlDart); + calloc.free(filenameDart); + for (final ptr in compatDartPtrs) { + calloc.free(ptr); + } + + // Free the C struct fields (strdup'd strings) via rac_lora_entry_free + // But we used calloc for the struct itself, so we need to free the + // strdup'd strings individually. C deep-copied on register, so the + // strdup'd pointers in the struct need to be freed. + final cFreeFn = lib.lookupFunction< + Void Function(Pointer), + void Function(Pointer)>('free'); + + // Free strdup'd strings inside the struct + if (entryPtr.ref.id != nullptr) cFreeFn(entryPtr.ref.id.cast()); + if (entryPtr.ref.name != nullptr) cFreeFn(entryPtr.ref.name.cast()); + if (entryPtr.ref.description != nullptr) { + cFreeFn(entryPtr.ref.description.cast()); + } + if (entryPtr.ref.downloadUrl != nullptr) { + cFreeFn(entryPtr.ref.downloadUrl.cast()); + } + if (entryPtr.ref.filename != nullptr) { + cFreeFn(entryPtr.ref.filename.cast()); + } + // Free strdup'd compatible model IDs + for (int i = 0; i < compatCount; i++) { + if (compatArrayPtr[i] != nullptr) { + cFreeFn(compatArrayPtr[i].cast()); + } + } + calloc.free(compatArrayPtr); + calloc.free(entryPtr); + } + } + + /// Get all registered LoRA adapters compatible with a model. + List getForModel(String modelId) { + final lib = PlatformLoader.loadCommons(); + + final modelIdPtr = modelId.toNativeUtf8(); + final outEntriesPtr = calloc>>(); + final outCountPtr = calloc(); + + try { + final fn = lib.lookupFunction< + Int32 Function(Pointer, + Pointer>>, Pointer), + int Function( + Pointer, + Pointer>>, + Pointer)>('rac_get_lora_for_model'); + + final result = fn(modelIdPtr, outEntriesPtr, outCountPtr); + if (result != RAC_SUCCESS) { + _logger.error('Failed to get LoRA adapters for model $modelId'); + return []; + } + + return _readEntryArray(lib, outEntriesPtr.value, outCountPtr.value); + } finally { + calloc.free(modelIdPtr); + calloc.free(outEntriesPtr); + calloc.free(outCountPtr); + } + } + + /// Get all registered LoRA adapters. + List getAll() { + final lib = PlatformLoader.loadCommons(); + + // Use the registry handle to call get_all + final getRegistryFn = lib.lookupFunction< + Pointer Function(), + Pointer Function()>('rac_get_lora_registry'); + + final registry = getRegistryFn(); + if (registry == nullptr) return []; + + final outEntriesPtr = calloc>>(); + final outCountPtr = calloc(); + + try { + final fn = lib.lookupFunction< + Int32 Function(Pointer, + Pointer>>, Pointer), + int Function( + Pointer, + Pointer>>, + Pointer)>('rac_lora_registry_get_all'); + + final result = fn(registry, outEntriesPtr, outCountPtr); + if (result != RAC_SUCCESS) { + _logger.error('Failed to get all LoRA adapters'); + return []; + } + + return _readEntryArray(lib, outEntriesPtr.value, outCountPtr.value); + } finally { + calloc.free(outEntriesPtr); + calloc.free(outCountPtr); + } + } + + // MARK: - Private Helpers + + /// Read an array of rac_lora_entry_t* pointers and convert to Dart. + List _readEntryArray( + DynamicLibrary lib, + Pointer> entriesPtr, + int count, + ) { + if (entriesPtr == nullptr || count <= 0) return []; + + final freeFn = lib.lookupFunction< + Void Function(Pointer>, IntPtr), + void Function( + Pointer>, int)>('rac_lora_entry_array_free'); + + try { + final results = []; + for (int i = 0; i < count; i++) { + final entryPtr = entriesPtr[i]; + if (entryPtr == nullptr) continue; + + final entry = entryPtr.ref; + + // Read compatible model IDs + final compatIds = []; + if (entry.compatibleModelIds != nullptr) { + for (int j = 0; j < entry.compatibleModelCount; j++) { + final idPtr = entry.compatibleModelIds[j]; + if (idPtr != nullptr) { + compatIds.add(idPtr.toDartString()); + } + } + } + + results.add(LoraAdapterCatalogEntry( + id: entry.id != nullptr ? entry.id.toDartString() : '', + name: entry.name != nullptr ? entry.name.toDartString() : '', + description: entry.description != nullptr + ? entry.description.toDartString() + : '', + downloadUrl: entry.downloadUrl != nullptr + ? entry.downloadUrl.toDartString() + : '', + filename: entry.filename != nullptr + ? entry.filename.toDartString() + : '', + compatibleModelIds: compatIds, + fileSize: entry.fileSize, + defaultScale: entry.defaultScale, + )); + } + return results; + } finally { + freeFn(entriesPtr, count); + } + } +} diff --git a/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_model_paths.dart b/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_model_paths.dart index 5bba19572..775f6643b 100644 --- a/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_model_paths.dart +++ b/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_model_paths.dart @@ -6,8 +6,10 @@ import 'package:path_provider/path_provider.dart'; import 'package:runanywhere/core/types/model_types.dart'; import 'package:runanywhere/foundation/logging/sdk_logger.dart'; +import 'package:runanywhere/native/dart_bridge_download.dart'; import 'package:runanywhere/native/ffi_types.dart'; import 'package:runanywhere/native/platform_loader.dart'; +import 'package:runanywhere/native/type_conversions/model_types_cpp_bridge.dart'; /// Model path utilities bridge. /// Wraps C++ rac_model_paths.h functions. @@ -160,102 +162,23 @@ class DartBridgeModelPaths { } // MARK: - Model File Resolution - // Matches Swift: resolveModelFilePath(for:) /// Resolve the actual model file path for loading. - /// For single-file models (LlamaCpp), finds the actual .gguf file. - /// For directory-based models (ONNX), returns the folder. + /// Delegates to C++ rac_find_model_path_after_extraction() which handles + /// all model types: ONNX directories, LlamaCpp .gguf files, nested structures. Future resolveModelFilePath(ModelInfo model) async { final modelFolder = getModelFolder(model.id, model.framework); if (modelFolder == null) return null; - // For ONNX models (directory-based), find the model directory - if (model.framework == InferenceFramework.onnx) { - return _resolveONNXModelPath(modelFolder, model.id); - } - - // For single-file models (LlamaCpp), find the actual file - return _resolveSingleFileModelPath(modelFolder, model); - } - - /// Resolve ONNX model directory path - String _resolveONNXModelPath(String modelFolder, String modelId) { - // Check if there's a nested folder with the model name - final nestedFolder = '$modelFolder/$modelId'; - if (Directory(nestedFolder).existsSync()) { - if (_hasONNXModelFiles(nestedFolder)) { - _logger.info('Found ONNX model at nested path: $nestedFolder'); - return nestedFolder; - } - } - - // Check if model files exist directly in the folder - if (_hasONNXModelFiles(modelFolder)) { - _logger.info('Found ONNX model at folder: $modelFolder'); - return modelFolder; - } + // Use C++ to find the actual model path (handles all frameworks/formats) + final resolved = DartBridgeDownload.findModelPathAfterExtraction( + extractedDir: modelFolder, + structure: 99, // RAC_ARCHIVE_STRUCTURE_UNKNOWN - auto-detect + framework: _frameworkToCValue(model.framework), + format: model.format.toC(), + ); - // Scan for any subdirectory with model files - final dir = Directory(modelFolder); - if (dir.existsSync()) { - for (final entity in dir.listSync()) { - if (entity is Directory && _hasONNXModelFiles(entity.path)) { - _logger.info('Found ONNX model in subdirectory: ${entity.path}'); - return entity.path; - } - } - } - - // Fallback - _logger.warning('No ONNX model files found, using: $modelFolder'); - return modelFolder; - } - - /// Check if directory contains ONNX model files - bool _hasONNXModelFiles(String directory) { - final dir = Directory(directory); - if (!dir.existsSync()) return false; - - try { - return dir.listSync().any((entity) { - if (entity is! File) return false; - final name = entity.path.toLowerCase(); - return name.endsWith('.onnx') || - name.endsWith('.ort') || - name.contains('encoder') || - name.contains('decoder') || - name.contains('tokens'); - }); - } catch (e) { - return false; - } - } - - /// Resolve single-file model path (LlamaCpp .gguf files) - String? _resolveSingleFileModelPath(String modelFolder, ModelInfo model) { - final dir = Directory(modelFolder); - if (!dir.existsSync()) { - _logger.warning('Model folder does not exist: $modelFolder'); - return null; - } - - // Find the model file - try { - for (final entity in dir.listSync()) { - if (entity is File) { - final name = entity.path.toLowerCase(); - if (name.endsWith('.gguf') || name.endsWith('.bin')) { - _logger.info('Found model file: ${entity.path}'); - return entity.path; - } - } - } - } catch (e) { - _logger.warning('Error scanning model folder: $e'); - } - - _logger.warning('No model file found in: $modelFolder'); - return null; + return resolved ?? modelFolder; } // MARK: - Path Analysis diff --git a/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_model_registry.dart b/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_model_registry.dart index 14e454cdf..16882720a 100644 --- a/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_model_registry.dart +++ b/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_model_registry.dart @@ -847,9 +847,9 @@ int _listDirectoryCallback( void _freeEntriesCallback( Pointer> entries, int count, Pointer userData) { for (var i = 0; i < count; i++) { - if (entries[i] != nullptr) calloc.free(entries[i]); + if (entries[i] != nullptr) malloc.free(entries[i]); } - calloc.free(entries); + malloc.free(entries); } int _isDirectoryCallback(Pointer path, Pointer userData) { @@ -973,6 +973,10 @@ base class RacModelInfoCStruct extends Struct { @Int32() external int supportsThinking; + // rac_bool_t supports_lora (int32_t) + @Int32() + external int supportsLora; + // char** tags external Pointer> tags; diff --git a/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_rag.dart b/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_rag.dart index a90f20a42..52b5bc7ac 100644 --- a/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_rag.dart +++ b/sdk/runanywhere-flutter/packages/runanywhere/lib/native/dart_bridge_rag.dart @@ -14,113 +14,7 @@ import 'package:ffi/ffi.dart'; import 'package:runanywhere/foundation/logging/sdk_logger.dart'; import 'package:runanywhere/native/ffi_types.dart'; import 'package:runanywhere/native/platform_loader.dart'; - -// ============================================================================= -// RAG Types (mirrors Swift RAGTypes.swift / Kotlin RAGTypes.kt) -// ============================================================================= - -/// Configuration for creating a RAG pipeline. -class RAGConfiguration { - /// Path to the ONNX embedding model - final String embeddingModelPath; - - /// Path to the GGUF LLM model - final String llmModelPath; - - /// Embedding vector dimension (default: 384 for all-MiniLM-L6-v2) - final int embeddingDimension; - - /// Number of top chunks to retrieve per query - final int topK; - - /// Minimum cosine similarity threshold 0.0-1.0 - final double similarityThreshold; - - /// Maximum tokens for context sent to the LLM - final int maxContextTokens; - - /// Tokens per chunk when splitting documents - final int chunkSize; - - /// Overlap tokens between consecutive chunks - final int chunkOverlap; - - /// Prompt template with {context} and {query} placeholders - final String? promptTemplate; - - /// Optional configuration JSON for the embedding model - final String? embeddingConfigJson; - - /// Optional configuration JSON for the LLM model - final String? llmConfigJson; - - const RAGConfiguration({ - required this.embeddingModelPath, - required this.llmModelPath, - this.embeddingDimension = 384, - this.topK = 10, - this.similarityThreshold = 0.15, - this.maxContextTokens = 2048, - this.chunkSize = 512, - this.chunkOverlap = 50, - this.promptTemplate, - this.embeddingConfigJson, - this.llmConfigJson, - }); -} - -/// Options for querying the RAG pipeline. -class RAGQueryOptions { - final String question; - final String? systemPrompt; - final int maxTokens; - final double temperature; - final double topP; - final int topK; - - const RAGQueryOptions({ - required this.question, - this.systemPrompt, - this.maxTokens = 512, - this.temperature = 0.7, - this.topP = 0.9, - this.topK = 40, - }); -} - -/// A single retrieved document chunk. -class RAGSearchResult { - final String chunkId; - final String text; - final double similarityScore; - final String? metadataJson; - - const RAGSearchResult({ - required this.chunkId, - required this.text, - required this.similarityScore, - this.metadataJson, - }); -} - -/// Result of a RAG query. -class RAGResult { - final String answer; - final List retrievedChunks; - final String? contextUsed; - final double retrievalTimeMs; - final double generationTimeMs; - final double totalTimeMs; - - const RAGResult({ - required this.answer, - required this.retrievedChunks, - this.contextUsed, - required this.retrievalTimeMs, - required this.generationTimeMs, - required this.totalTimeMs, - }); -} +import 'package:runanywhere/public/types/rag_types.dart'; // ============================================================================= // FFI Struct for rac_rag_config_t (legacy standalone config) @@ -276,11 +170,11 @@ class DartBridgeRAG { cConfig.ref.promptTemplate = config.promptTemplate != null ? config.promptTemplate!.toNativeUtf8() : nullptr; - cConfig.ref.embeddingConfigJson = config.embeddingConfigJson != null - ? config.embeddingConfigJson!.toNativeUtf8() + cConfig.ref.embeddingConfigJson = config.embeddingConfigJSON != null + ? config.embeddingConfigJSON!.toNativeUtf8() : nullptr; - cConfig.ref.llmConfigJson = config.llmConfigJson != null - ? config.llmConfigJson!.toNativeUtf8() + cConfig.ref.llmConfigJson = config.llmConfigJSON != null + ? config.llmConfigJSON!.toNativeUtf8() : nullptr; final result = fn(cConfig, outPipeline); @@ -410,7 +304,7 @@ class DartBridgeRAG { chunkId: c.chunkId != nullptr ? c.chunkId.toDartString() : '', text: c.text != nullptr ? c.text.toDartString() : '', similarityScore: c.similarityScore, - metadataJson: + metadataJSON: c.metadataJson != nullptr ? c.metadataJson.toDartString() : null, )); } diff --git a/sdk/runanywhere-flutter/packages/runanywhere/lib/native/ffi_types.dart b/sdk/runanywhere-flutter/packages/runanywhere/lib/native/ffi_types.dart index 35e9e8df5..c81a6f9b3 100644 --- a/sdk/runanywhere-flutter/packages/runanywhere/lib/native/ffi_types.dart +++ b/sdk/runanywhere-flutter/packages/runanywhere/lib/native/ffi_types.dart @@ -1405,6 +1405,84 @@ typedef RacBackendRagRegisterDart = int Function(); typedef RacBackendRagUnregisterNative = Int32 Function(); typedef RacBackendRagUnregisterDart = int Function(); +// File Manager Types (from rac_file_manager.h) +// ============================================================================= + +/// Callback: create_directory(path, recursive, user_data) -> rac_result_t +typedef RacFmCreateDirectoryNative = Int32 Function( + Pointer, Int32, Pointer); + +/// Callback: delete_path(path, recursive, user_data) -> rac_result_t +typedef RacFmDeletePathNative = Int32 Function( + Pointer, Int32, Pointer); + +/// Callback: list_directory(path, out_entries, out_count, user_data) -> rac_result_t +typedef RacFmListDirectoryNative = Int32 Function( + Pointer, + Pointer>>, + Pointer, + Pointer); + +/// Callback: free_entries(entries, count, user_data) +typedef RacFmFreeEntriesNative = Void Function( + Pointer>, Size, Pointer); + +/// Callback: path_exists(path, out_is_directory, user_data) -> rac_bool_t +typedef RacFmPathExistsNative = Int32 Function( + Pointer, Pointer, Pointer); + +/// Callback: get_file_size(path, user_data) -> int64_t +typedef RacFmGetFileSizeNative = Int64 Function(Pointer, Pointer); + +/// Callback: get_available_space(user_data) -> int64_t +typedef RacFmGetAvailableSpaceNative = Int64 Function(Pointer); + +/// Callback: get_total_space(user_data) -> int64_t +typedef RacFmGetTotalSpaceNative = Int64 Function(Pointer); + +/// File callbacks struct matching rac_file_callbacks_t +final class RacFileCallbacksStruct extends Struct { + external Pointer> createDirectory; + external Pointer> deletePath; + external Pointer> listDirectory; + external Pointer> freeEntries; + external Pointer> pathExists; + external Pointer> getFileSize; + external Pointer> + getAvailableSpace; + external Pointer> getTotalSpace; + external Pointer userData; +} + +/// Storage info struct matching rac_file_manager_storage_info_t +final class RacFileManagerStorageInfoStruct extends Struct { + @Int64() + external int deviceTotal; + @Int64() + external int deviceFree; + @Int64() + external int modelsSize; + @Int64() + external int cacheSize; + @Int64() + external int tempSize; + @Int64() + external int totalAppSize; +} + +/// Storage availability struct matching rac_storage_availability_t +final class RacStorageAvailabilityStruct extends Struct { + @Int32() + external int isAvailable; + @Int64() + external int requiredSpace; + @Int64() + external int availableSpace; + @Int32() + external int hasWarning; + external Pointer recommendation; +} + // ============================================================================= // Backward Compatibility Aliases // ============================================================================= diff --git a/sdk/runanywhere-flutter/packages/runanywhere/lib/public/extensions/rag_module.dart b/sdk/runanywhere-flutter/packages/runanywhere/lib/public/extensions/rag_module.dart index adadb227b..3e631bae2 100644 --- a/sdk/runanywhere-flutter/packages/runanywhere/lib/public/extensions/rag_module.dart +++ b/sdk/runanywhere-flutter/packages/runanywhere/lib/public/extensions/rag_module.dart @@ -30,10 +30,8 @@ library rag_module; import 'package:runanywhere/core/module/runanywhere_module.dart'; import 'package:runanywhere/core/types/model_types.dart'; import 'package:runanywhere/core/types/sdk_component.dart'; -import 'package:runanywhere/foundation/error_types/sdk_error.dart'; import 'package:runanywhere/foundation/logging/sdk_logger.dart'; import 'package:runanywhere/native/dart_bridge_rag.dart'; -import 'package:runanywhere/native/ffi_types.dart'; /// RAG module for Retrieval-Augmented Generation. /// @@ -95,18 +93,7 @@ class RAGModule implements RunAnywhereModule { _logger.info('Registering RAG backend with C++ registry...'); try { - final result = DartBridgeRAG.registerBackend(); - _logger.info( - 'rac_backend_rag_register() returned: $result (${RacResultCode.getMessage(result)})', - ); - - if (result != RacResultCode.success && - result != RacResultCode.errorModuleAlreadyRegistered) { - _logger.error('RAG backend registration FAILED with code: $result'); - throw SDKError.frameworkNotAvailable( - 'RAG backend registration failed with code: $result (${RacResultCode.getMessage(result)})', - ); - } + DartBridgeRAG.shared.register(); _isRegistered = true; _logger.info('RAG backend registered successfully'); @@ -122,7 +109,6 @@ class RAGModule implements RunAnywhereModule { return; } - DartBridgeRAG.unregisterBackend(); _isRegistered = false; _logger.info('RAG backend unregistered'); } diff --git a/sdk/runanywhere-flutter/packages/runanywhere/lib/public/extensions/runanywhere_lora.dart b/sdk/runanywhere-flutter/packages/runanywhere/lib/public/extensions/runanywhere_lora.dart new file mode 100644 index 000000000..dfae3db78 --- /dev/null +++ b/sdk/runanywhere-flutter/packages/runanywhere/lib/public/extensions/runanywhere_lora.dart @@ -0,0 +1,106 @@ +/// RunAnywhere + LoRA +/// +/// Public API for LoRA (Low-Rank Adaptation) adapter operations. +/// Mirrors Swift's RunAnywhere+LoRA.swift and Kotlin's RunAnywhere+LoRA.kt. +/// +/// Provides: +/// - Runtime operations: load, remove, clear, query adapters +/// - Catalog operations: register, query adapter metadata +/// - Compatibility checking +library runanywhere_lora; + +import 'package:runanywhere/native/dart_bridge_lora.dart'; +import 'package:runanywhere/public/runanywhere.dart'; +import 'package:runanywhere/public/types/lora_types.dart'; + +/// Extension providing static LoRA methods on RunAnywhere. +/// +/// Usage: +/// ```dart +/// // Load a LoRA adapter +/// RunAnywhereLoRA.loadLoraAdapter(LoRAAdapterConfig(path: '/path/to/adapter.gguf')); +/// +/// // Check loaded adapters +/// final adapters = RunAnywhereLoRA.getLoadedLoraAdapters(); +/// +/// // Remove all adapters +/// RunAnywhereLoRA.clearLoraAdapters(); +/// ``` +extension RunAnywhereLoRA on RunAnywhere { + // MARK: - Runtime Operations + + /// Load and apply a LoRA adapter to the current model. + /// + /// Context is recreated internally and KV cache is cleared. + /// Throws if SDK not initialized or load fails. + static void loadLoraAdapter(LoRAAdapterConfig config) { + if (!RunAnywhere.isSDKInitialized) { + throw StateError('SDK not initialized'); + } + DartBridgeLora.shared.loadAdapter(config.path, config.scale); + } + + /// Remove a specific LoRA adapter by path. + /// + /// Throws if SDK not initialized or adapter not found. + static void removeLoraAdapter(String path) { + if (!RunAnywhere.isSDKInitialized) { + throw StateError('SDK not initialized'); + } + DartBridgeLora.shared.removeAdapter(path); + } + + /// Remove all LoRA adapters. + /// + /// Throws if SDK not initialized. + static void clearLoraAdapters() { + if (!RunAnywhere.isSDKInitialized) { + throw StateError('SDK not initialized'); + } + DartBridgeLora.shared.clearAdapters(); + } + + /// Get info about currently loaded LoRA adapters. + /// + /// Returns empty list if SDK not initialized or no adapters loaded. + static List getLoadedLoraAdapters() { + if (!RunAnywhere.isSDKInitialized) return []; + return DartBridgeLora.shared.getLoadedAdapters(); + } + + /// Check if the current backend supports LoRA for the given adapter path. + static LoraCompatibilityResult checkLoraCompatibility(String loraPath) { + if (!RunAnywhere.isSDKInitialized) { + return const LoraCompatibilityResult( + isCompatible: false, + error: 'SDK not initialized', + ); + } + return DartBridgeLora.shared.checkCompatibility(loraPath); + } + + // MARK: - Catalog Operations + + /// Register a LoRA adapter in the global registry. + /// + /// Entry is deep-copied internally by C++. + /// Throws if SDK not initialized or registration fails. + static void registerLoraAdapter(LoraAdapterCatalogEntry entry) { + if (!RunAnywhere.isSDKInitialized) { + throw StateError('SDK not initialized'); + } + DartBridgeLoraRegistry.shared.register(entry); + } + + /// Get all registered LoRA adapters compatible with a model. + static List loraAdaptersForModel(String modelId) { + if (!RunAnywhere.isSDKInitialized) return []; + return DartBridgeLoraRegistry.shared.getForModel(modelId); + } + + /// Get all registered LoRA adapters. + static List allRegisteredLoraAdapters() { + if (!RunAnywhere.isSDKInitialized) return []; + return DartBridgeLoraRegistry.shared.getAll(); + } +} diff --git a/sdk/runanywhere-flutter/packages/runanywhere/lib/public/extensions/runanywhere_rag.dart b/sdk/runanywhere-flutter/packages/runanywhere/lib/public/extensions/runanywhere_rag.dart index 115e30587..9379cbdb8 100644 --- a/sdk/runanywhere-flutter/packages/runanywhere/lib/public/extensions/runanywhere_rag.dart +++ b/sdk/runanywhere-flutter/packages/runanywhere/lib/public/extensions/runanywhere_rag.dart @@ -7,13 +7,9 @@ /// with initialization guards, event publishing, and typed error conversion. library runanywhere_rag; -import 'dart:ffi'; - -import 'package:ffi/ffi.dart'; - import 'package:runanywhere/foundation/error_types/sdk_error.dart'; -import 'package:runanywhere/native/dart_bridge_rag.dart'; -import 'package:runanywhere/native/ffi_types.dart'; +import 'package:runanywhere/native/dart_bridge_rag.dart' + hide RAGConfiguration, RAGQueryOptions, RAGSearchResult, RAGResult; import 'package:runanywhere/public/events/event_bus.dart'; import 'package:runanywhere/public/events/sdk_event.dart'; import 'package:runanywhere/public/extensions/rag_module.dart'; @@ -41,8 +37,8 @@ extension RunAnywhereRAG on RunAnywhere { /// Create the RAG pipeline with the given configuration. /// - /// Marshals [config] to a native [RacRagConfigStruct] via FFI, calls - /// [DartBridgeRAG.createPipeline], then publishes [SDKRAGEvent.pipelineCreated]. + /// Passes [config] to [DartBridgeRAG.createPipeline] which handles + /// all FFI marshaling internally, then publishes [SDKRAGEvent.pipelineCreated]. /// /// Throws [SDKError.notInitialized] if SDK is not initialized. /// Throws [SDKError.invalidState] if pipeline creation fails. @@ -57,39 +53,13 @@ extension RunAnywhereRAG on RunAnywhere { ); } - final embeddingModelPathPtr = config.embeddingModelPath.toNativeUtf8(); - final llmModelPathPtr = config.llmModelPath.toNativeUtf8(); - final promptTemplatePtr = config.promptTemplate?.toNativeUtf8(); - final embeddingConfigJsonPtr = config.embeddingConfigJSON?.toNativeUtf8(); - final llmConfigJsonPtr = config.llmConfigJSON?.toNativeUtf8(); - final configPtr = calloc(); - try { - configPtr.ref.embeddingModelPath = embeddingModelPathPtr; - configPtr.ref.llmModelPath = llmModelPathPtr; - configPtr.ref.embeddingDimension = config.embeddingDimension; - configPtr.ref.topK = config.topK; - configPtr.ref.similarityThreshold = config.similarityThreshold; - configPtr.ref.maxContextTokens = config.maxContextTokens; - configPtr.ref.chunkSize = config.chunkSize; - configPtr.ref.chunkOverlap = config.chunkOverlap; - configPtr.ref.promptTemplate = promptTemplatePtr ?? nullptr; - configPtr.ref.embeddingConfigJson = embeddingConfigJsonPtr ?? nullptr; - configPtr.ref.llmConfigJson = llmConfigJsonPtr ?? nullptr; - - DartBridgeRAG.shared.createPipeline(config: configPtr); + DartBridgeRAG.shared.createPipeline(config); EventBus.shared.publish(SDKRAGEvent.pipelineCreated()); } catch (e) { EventBus.shared.publish(SDKRAGEvent.error(message: e.toString())); throw SDKError.invalidState('RAG pipeline creation failed: $e'); - } finally { - calloc.free(embeddingModelPathPtr); - calloc.free(llmModelPathPtr); - if (promptTemplatePtr != null) calloc.free(promptTemplatePtr); - if (embeddingConfigJsonPtr != null) calloc.free(embeddingConfigJsonPtr); - if (llmConfigJsonPtr != null) calloc.free(llmConfigJsonPtr); - calloc.free(configPtr); } } @@ -103,7 +73,7 @@ extension RunAnywhereRAG on RunAnywhere { throw SDKError.notInitialized(); } - DartBridgeRAG.shared.destroy(); + DartBridgeRAG.shared.destroyPipeline(); EventBus.shared.publish(SDKRAGEvent.pipelineDestroyed()); } @@ -132,7 +102,7 @@ extension RunAnywhereRAG on RunAnywhere { final stopwatch = Stopwatch()..start(); try { - DartBridgeRAG.shared.addDocument(text, metadataJSON: metadataJSON); + DartBridgeRAG.shared.addDocument(text, metadataJson: metadataJSON); stopwatch.stop(); @@ -210,16 +180,23 @@ extension RunAnywhereRAG on RunAnywhere { ); try { - final bridgeResult = DartBridgeRAG.shared.query( - question, - systemPrompt: options?.systemPrompt, - maxTokens: options?.maxTokens ?? 512, - temperature: options?.temperature ?? 0.7, - topP: options?.topP ?? 0.9, - topK: options?.topK ?? 40, - ); - - final result = RAGResult.fromBridge(bridgeResult); + final queryOptions = options ?? + RAGQueryOptions(question: question); + + // If caller provided options but with a different question field, + // create a new options with the positional question. + final effectiveOptions = queryOptions.question == question + ? queryOptions + : RAGQueryOptions( + question: question, + systemPrompt: queryOptions.systemPrompt, + maxTokens: queryOptions.maxTokens, + temperature: queryOptions.temperature, + topP: queryOptions.topP, + topK: queryOptions.topK, + ); + + final result = DartBridgeRAG.shared.query(effectiveOptions); EventBus.shared.publish( SDKRAGEvent.queryComplete( diff --git a/sdk/runanywhere-flutter/packages/runanywhere/lib/public/extensions/runanywhere_storage.dart b/sdk/runanywhere-flutter/packages/runanywhere/lib/public/extensions/runanywhere_storage.dart index ec1151f69..9df2dcb3c 100644 --- a/sdk/runanywhere-flutter/packages/runanywhere/lib/public/extensions/runanywhere_storage.dart +++ b/sdk/runanywhere-flutter/packages/runanywhere/lib/public/extensions/runanywhere_storage.dart @@ -4,10 +4,9 @@ /// Mirrors Swift's RunAnywhere+Storage.swift. library runanywhere_storage; -import 'dart:io'; - import 'package:path_provider/path_provider.dart'; import 'package:runanywhere/infrastructure/download/download_service.dart'; +import 'package:runanywhere/native/dart_bridge_file_manager.dart'; import 'package:runanywhere/native/dart_bridge_storage.dart'; import 'package:runanywhere/public/events/event_bus.dart'; import 'package:runanywhere/public/events/sdk_event.dart'; @@ -22,20 +21,14 @@ extension RunAnywhereStorage on RunAnywhere { /// Check if storage is available for a model download /// /// Returns true if sufficient storage is available for the given model size. + /// Delegates to C++ file manager for storage checks. static Future checkStorageAvailable({ required int modelSize, double safetyMargin = 0.1, }) async { try { - final directory = await getApplicationDocumentsDirectory(); final requiredWithMargin = (modelSize * (1 + safetyMargin)).toInt(); - - // Get directory size as a proxy for available space check - final dirSize = await _getDirectorySize(directory); - - // If the SDK directory is larger than the model size, - // we assume storage is available (simplified check) - return dirSize > requiredWithMargin; + return DartBridgeFileManager.checkStorage(requiredWithMargin); } catch (_) { // Default to available if check fails return true; @@ -86,18 +79,4 @@ extension RunAnywhereStorage on RunAnywhere { return ModelDownloadService.shared.downloadModel(modelId); } - /// Helper to get directory size - static Future _getDirectorySize(Directory directory) async { - int size = 0; - try { - await for (final entity in directory.list(recursive: true)) { - if (entity is File) { - size += await entity.length(); - } - } - } catch (_) { - // Ignore errors - } - return size; - } } diff --git a/sdk/runanywhere-flutter/packages/runanywhere/lib/public/runanywhere.dart b/sdk/runanywhere-flutter/packages/runanywhere/lib/public/runanywhere.dart index ca43c6a3b..e28523880 100644 --- a/sdk/runanywhere-flutter/packages/runanywhere/lib/public/runanywhere.dart +++ b/sdk/runanywhere-flutter/packages/runanywhere/lib/public/runanywhere.dart @@ -17,6 +17,7 @@ import 'package:runanywhere/foundation/logging/sdk_logger.dart'; import 'package:runanywhere/infrastructure/download/download_service.dart'; import 'package:runanywhere/native/dart_bridge.dart'; import 'package:runanywhere/native/dart_bridge_auth.dart'; +import 'package:runanywhere/native/dart_bridge_file_manager.dart'; import 'package:runanywhere/native/dart_bridge_device.dart'; import 'package:runanywhere/native/dart_bridge_model_paths.dart'; import 'package:runanywhere/native/dart_bridge_model_registry.dart' @@ -2218,27 +2219,9 @@ class RunAnywhere { } } - /// Calculate directory size recursively. + /// Calculate directory size — delegates to C++ file manager. static Future _getDirectorySize(String path) async { - try { - final dir = Directory(path); - if (!await dir.exists()) return 0; - - int totalSize = 0; - await for (final entity - in dir.list(recursive: true, followLinks: false)) { - if (entity is File) { - try { - totalSize += await entity.length(); - } catch (_) { - // Skip files we can't read - } - } - } - return totalSize; - } catch (e) { - return 0; - } + return DartBridgeFileManager.calculateDirectorySize(path); } /// Get downloaded models with their file sizes. diff --git a/sdk/runanywhere-flutter/packages/runanywhere/lib/public/types/lora_types.dart b/sdk/runanywhere-flutter/packages/runanywhere/lib/public/types/lora_types.dart new file mode 100644 index 000000000..3c45509f5 --- /dev/null +++ b/sdk/runanywhere-flutter/packages/runanywhere/lib/public/types/lora_types.dart @@ -0,0 +1,89 @@ +/// LoRA Types +/// +/// Data types for LoRA (Low-Rank Adaptation) adapter operations. +/// Mirrors Swift's LLMTypes.swift LoRA types and Kotlin's RunAnywhere+LoRA.kt. +library lora_types; + +/// Configuration for loading a LoRA adapter. +class LoRAAdapterConfig { + /// Path to the LoRA adapter GGUF file. + final String path; + + /// Scale factor for the adapter (0.0-1.0+, default 1.0). + final double scale; + + LoRAAdapterConfig({ + required this.path, + this.scale = 1.0, + }) : assert(path.isNotEmpty, 'LoRA adapter path cannot be empty'); +} + +/// Info about a currently loaded LoRA adapter. +class LoRAAdapterInfo { + /// File path where adapter was loaded from. + final String path; + + /// LoRA scale factor. + final double scale; + + /// Whether adapter is currently applied to the context. + final bool applied; + + const LoRAAdapterInfo({ + required this.path, + required this.scale, + required this.applied, + }); +} + +/// Catalog entry for a LoRA adapter (metadata for registry). +class LoraAdapterCatalogEntry { + /// Unique adapter identifier. + final String id; + + /// Human-readable display name. + final String name; + + /// Short description of what this adapter does. + final String description; + + /// Direct download URL (.gguf file). + final String downloadUrl; + + /// Filename to save as on disk. + final String filename; + + /// Explicit list of compatible base model IDs. + final List compatibleModelIds; + + /// File size in bytes (0 if unknown). + final int fileSize; + + /// Recommended LoRA scale (e.g. 0.3). + final double defaultScale; + + const LoraAdapterCatalogEntry({ + required this.id, + required this.name, + required this.description, + required this.downloadUrl, + required this.filename, + required this.compatibleModelIds, + this.fileSize = 0, + this.defaultScale = 1.0, + }); +} + +/// Result of a LoRA compatibility check. +class LoraCompatibilityResult { + /// Whether the current backend supports LoRA for the given adapter. + final bool isCompatible; + + /// Error message if incompatible, null if compatible. + final String? error; + + const LoraCompatibilityResult({ + required this.isCompatible, + this.error, + }); +} diff --git a/sdk/runanywhere-flutter/packages/runanywhere/lib/public/types/rag_types.dart b/sdk/runanywhere-flutter/packages/runanywhere/lib/public/types/rag_types.dart index 58d7381b9..5d2d8880c 100644 --- a/sdk/runanywhere-flutter/packages/runanywhere/lib/public/types/rag_types.dart +++ b/sdk/runanywhere-flutter/packages/runanywhere/lib/public/types/rag_types.dart @@ -4,8 +4,6 @@ /// Mirrors iOS RAGTypes.swift adapted for Flutter/Dart. library rag_types; -import 'package:runanywhere/native/dart_bridge_rag.dart'; - // MARK: - RAGConfiguration /// Configuration for the RAG pipeline. @@ -135,18 +133,6 @@ class RAGSearchResult { this.metadataJSON, }); - /// Create from a bridge search result. - /// - /// Converts empty metadataJson strings to null. - factory RAGSearchResult.fromBridge(RAGBridgeSearchResult bridge) { - return RAGSearchResult( - chunkId: bridge.chunkId, - text: bridge.text, - similarityScore: bridge.similarityScore, - metadataJSON: bridge.metadataJson.isEmpty ? null : bridge.metadataJson, - ); - } - @override String toString() { return 'RAGSearchResult(chunkId: $chunkId, score: $similarityScore)'; @@ -188,24 +174,6 @@ class RAGResult { required this.totalTimeMs, }); - /// Create from a bridge result. - /// - /// Converts the bridge's [RAGBridgeResult] to public [RAGResult], - /// mapping each chunk through [RAGSearchResult.fromBridge]. - /// Empty contextUsed strings are converted to null. - factory RAGResult.fromBridge(RAGBridgeResult bridge) { - return RAGResult( - answer: bridge.answer, - retrievedChunks: bridge.retrievedChunks - .map(RAGSearchResult.fromBridge) - .toList(growable: false), - contextUsed: bridge.contextUsed.isEmpty ? null : bridge.contextUsed, - retrievalTimeMs: bridge.retrievalTimeMs, - generationTimeMs: bridge.generationTimeMs, - totalTimeMs: bridge.totalTimeMs, - ); - } - @override String toString() { final preview = answer.length > 50 ? answer.substring(0, 50) : answer; diff --git a/sdk/runanywhere-flutter/packages/runanywhere/lib/public/types/types.dart b/sdk/runanywhere-flutter/packages/runanywhere/lib/public/types/types.dart index 42fa6f24a..a7a2cde66 100644 --- a/sdk/runanywhere-flutter/packages/runanywhere/lib/public/types/types.dart +++ b/sdk/runanywhere-flutter/packages/runanywhere/lib/public/types/types.dart @@ -12,4 +12,5 @@ export 'structured_output_types.dart'; export 'tool_calling_types.dart'; export 'rag_types.dart'; export 'vlm_types.dart'; +export 'lora_types.dart'; export 'voice_agent_types.dart'; diff --git a/sdk/runanywhere-flutter/packages/runanywhere/lib/runanywhere.dart b/sdk/runanywhere-flutter/packages/runanywhere/lib/runanywhere.dart index 6f3504ae2..c6b464689 100644 --- a/sdk/runanywhere-flutter/packages/runanywhere/lib/runanywhere.dart +++ b/sdk/runanywhere-flutter/packages/runanywhere/lib/runanywhere.dart @@ -26,9 +26,8 @@ export 'public/events/event_bus.dart'; export 'public/events/sdk_event.dart'; export 'public/extensions/runanywhere_frameworks.dart'; export 'public/extensions/runanywhere_logging.dart'; +export 'public/extensions/runanywhere_lora.dart'; export 'public/extensions/runanywhere_storage.dart'; -export 'native/dart_bridge_rag.dart' - show RAGConfiguration, RAGQueryOptions, RAGSearchResult, RAGResult; export 'public/runanywhere.dart'; export 'public/runanywhere_tool_calling.dart'; export 'public/types/tool_calling_types.dart'; diff --git a/sdk/runanywhere-flutter/packages/runanywhere/pubspec.yaml b/sdk/runanywhere-flutter/packages/runanywhere/pubspec.yaml index a2f8cfc15..e3821196c 100644 --- a/sdk/runanywhere-flutter/packages/runanywhere/pubspec.yaml +++ b/sdk/runanywhere-flutter/packages/runanywhere/pubspec.yaml @@ -36,8 +36,6 @@ dependencies: collection: ^1.18.0 json_annotation: ^4.9.0 path: ^1.9.0 - # Archive extraction (tar.bz2, zip) - archive: ^3.6.1 # TTS fallback (system TTS) flutter_tts: ^3.8.0 # Audio recording for voice sessions diff --git a/sdk/runanywhere-flutter/packages/runanywhere_onnx/lib/onnx_download_strategy.dart b/sdk/runanywhere-flutter/packages/runanywhere_onnx/lib/onnx_download_strategy.dart index 91d595c2c..14c0736d8 100644 --- a/sdk/runanywhere-flutter/packages/runanywhere_onnx/lib/onnx_download_strategy.dart +++ b/sdk/runanywhere-flutter/packages/runanywhere_onnx/lib/onnx_download_strategy.dart @@ -1,16 +1,37 @@ import 'dart:async'; +import 'dart:ffi'; import 'dart:io'; -import 'package:archive/archive.dart'; +import 'package:ffi/ffi.dart'; import 'package:http/http.dart' as http; import 'package:runanywhere/foundation/error_types/sdk_error.dart'; import 'package:runanywhere/foundation/logging/sdk_logger.dart'; +import 'package:runanywhere/native/dart_bridge_download.dart'; +import 'package:runanywhere/native/platform_loader.dart'; +import 'package:runanywhere/native/type_conversions/model_types_cpp_bridge.dart'; + +/// FFI typedef for rac_extract_archive_native +typedef _RacExtractNative = Int32 Function( + Pointer archivePath, + Pointer destinationDir, + Pointer options, + Pointer progressCallback, + Pointer userData, + Pointer outResult, +); +typedef _RacExtractDart = int Function( + Pointer archivePath, + Pointer destinationDir, + Pointer options, + Pointer progressCallback, + Pointer userData, + Pointer outResult, +); /// ONNX download strategy for handling .onnx files and .tar.bz2 archives /// Matches iOS ONNXDownloadStrategy pattern /// -/// Uses pure Dart archive extraction (via `archive` package) for cross-platform support. -/// This works on both iOS and Android without requiring native libarchive. +/// Uses native C++ extraction via libarchive for all archive formats. class OnnxDownloadStrategy { final SDKLogger logger = SDKLogger('OnnxDownloadStrategy'); @@ -146,18 +167,16 @@ class OnnxDownloadStrategy { logger.info('Archive downloaded, extracting to: ${modelFolder.path}'); - // Extract the archive using pure Dart + // Extract the archive using native C++ (libarchive) try { - await _extractTarBz2( + await _extractNative( archivePath: archivePath.path, destinationPath: modelFolder.path, - onProgress: (extractProgress) { - // Map extraction progress to 50% - 95% of overall progress - final overallProgress = 0.5 + (extractProgress * 0.45); - progressHandler?.call(overallProgress); - }, ); + // Report extraction progress complete + progressHandler?.call(0.95); + logger.info('Archive extracted successfully to: ${modelFolder.path}'); } catch (e) { logger.error('Archive extraction failed: $e'); @@ -175,23 +194,18 @@ class OnnxDownloadStrategy { logger.warning('Failed to delete archive file: $e'); } - // Find the extracted model directory - // Sherpa-ONNX archives typically extract to a subdirectory with the model name - final contents = await modelFolder.list().toList(); - logger.debug( - 'Extracted contents: ${contents.map((e) => e.path.split('/').last).join(", ")}'); - - // If there's a single subdirectory, the actual model files are in there - var modelDir = modelFolder; - if (contents.length == 1 && contents.first is Directory) { - final subdir = contents.first as Directory; - final subdirStat = await subdir.stat(); - if (subdirStat.type == FileSystemEntityType.directory) { - // Model files are in the subdirectory - modelDir = subdir; - logger.info( - 'Model files are in subdirectory: ${subdir.path.split('/').last}'); - } + // Use C++ to find the actual model path after extraction + // (handles nested directories, model file scanning for sherpa-onnx archives) + final foundPath = DartBridgeDownload.findModelPathAfterExtraction( + extractedDir: modelFolder.path, + structure: 99, // RAC_ARCHIVE_STRUCTURE_UNKNOWN - auto-detect + framework: RacInferenceFramework.onnx, + format: RacModelFormat.onnx, + ); + final modelDir = foundPath != null ? Directory(foundPath) : modelFolder; + if (foundPath != null && foundPath != modelFolder.path) { + logger.info( + 'Model files found at: ${foundPath.split('/').last}'); } // Report completion (100%) @@ -201,33 +215,39 @@ class OnnxDownloadStrategy { return modelDir.uri; } - /// Extract tar.bz2 archive using pure Dart - Future _extractTarBz2({ + /// Extract archive using native C++ (libarchive) via FFI. + /// Supports ZIP, TAR.GZ, TAR.BZ2, TAR.XZ with auto-detection. + Future _extractNative({ required String archivePath, required String destinationPath, - void Function(double progress)? onProgress, }) async { - final archiveFile = File(archivePath); - final bytes = await archiveFile.readAsBytes(); - - // Decompress bz2 - final decompressed = BZip2Decoder().decodeBytes(bytes); + final lib = PlatformLoader.loadCommons(); + final extractFn = lib.lookupFunction<_RacExtractNative, _RacExtractDart>( + 'rac_extract_archive_native', + ); - // Extract tar - final archive = TarDecoder().decodeBytes(decompressed); - final totalFiles = archive.files.length; + final archivePathPtr = archivePath.toNativeUtf8(); + final destPathPtr = destinationPath.toNativeUtf8(); - for (var i = 0; i < archive.files.length; i++) { - final file = archive.files[i]; - final filename = file.name; + try { + final result = extractFn( + archivePathPtr, + destPathPtr, + nullptr, // default options + nullptr, // no progress callback + nullptr, // no user data + nullptr, // no result output + ); - if (file.isFile) { - final outputFile = File('$destinationPath/$filename'); - await outputFile.parent.create(recursive: true); - await outputFile.writeAsBytes(file.content as List); + if (result != 0) { + throw SDKError.downloadFailed( + archivePath, + 'Native extraction failed with code: $result', + ); } - - onProgress?.call((i + 1) / totalFiles); + } finally { + calloc.free(archivePathPtr); + calloc.free(destPathPtr); } } diff --git a/sdk/runanywhere-flutter/packages/runanywhere_onnx/pubspec.yaml b/sdk/runanywhere-flutter/packages/runanywhere_onnx/pubspec.yaml index 028c7d43e..21f9128ab 100644 --- a/sdk/runanywhere-flutter/packages/runanywhere_onnx/pubspec.yaml +++ b/sdk/runanywhere-flutter/packages/runanywhere_onnx/pubspec.yaml @@ -23,8 +23,6 @@ dependencies: ffi: ^2.1.0 # HTTP for download strategy http: ^1.2.1 - # Archive extraction for model downloads - archive: ^3.4.9 dev_dependencies: flutter_test: diff --git a/sdk/runanywhere-kotlin/build.gradle.kts b/sdk/runanywhere-kotlin/build.gradle.kts index ed4409ad0..74bbd0b9c 100644 --- a/sdk/runanywhere-kotlin/build.gradle.kts +++ b/sdk/runanywhere-kotlin/build.gradle.kts @@ -157,7 +157,7 @@ kotlin { implementation(libs.okhttp.logging) implementation(libs.gson) implementation(libs.commons.io) - implementation(libs.commons.compress) + implementation(libs.ktor.client.okhttp) // Error tracking - Sentry (matches iOS SDK SentryDestination) implementation(libs.sentry) diff --git a/sdk/runanywhere-kotlin/modules/runanywhere-core-onnx/build.gradle.kts b/sdk/runanywhere-kotlin/modules/runanywhere-core-onnx/build.gradle.kts index 532d66046..c4cd91ab8 100644 --- a/sdk/runanywhere-kotlin/modules/runanywhere-core-onnx/build.gradle.kts +++ b/sdk/runanywhere-kotlin/modules/runanywhere-core-onnx/build.gradle.kts @@ -103,9 +103,6 @@ kotlin { val jvmAndroidMain by creating { dependsOn(commonMain) - dependencies { - implementation("org.apache.commons:commons-compress:1.26.0") - } } val jvmMain by getting { diff --git a/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/foundation/bridge/CppBridge.kt b/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/foundation/bridge/CppBridge.kt index d273c86ab..9ed2310db 100644 --- a/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/foundation/bridge/CppBridge.kt +++ b/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/foundation/bridge/CppBridge.kt @@ -14,6 +14,7 @@ import com.runanywhere.sdk.foundation.SDKLogger import com.runanywhere.sdk.foundation.bridge.extensions.CppBridgeAuth import com.runanywhere.sdk.foundation.bridge.extensions.CppBridgeDevice import com.runanywhere.sdk.foundation.bridge.extensions.CppBridgeEvents +import com.runanywhere.sdk.foundation.bridge.extensions.CppBridgeFileManager import com.runanywhere.sdk.foundation.bridge.extensions.CppBridgeModelAssignment import com.runanywhere.sdk.foundation.bridge.extensions.CppBridgePlatform import com.runanywhere.sdk.foundation.bridge.extensions.CppBridgePlatformAdapter @@ -207,6 +208,9 @@ object CppBridge { logger.warn("Telemetry handle not available, analytics events will not be tracked") } + // Register file manager I/O callbacks for C++ file management + CppBridgeFileManager.register() + _isInitialized = true // Emit SDK init completed event with duration diff --git a/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/foundation/bridge/extensions/CppBridgeFileManager.kt b/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/foundation/bridge/extensions/CppBridgeFileManager.kt new file mode 100644 index 000000000..398809437 --- /dev/null +++ b/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/foundation/bridge/extensions/CppBridgeFileManager.kt @@ -0,0 +1,170 @@ +/* + * Copyright 2026 RunAnywhere SDK + * SPDX-License-Identifier: Apache-2.0 + * + * FileManager extension for CppBridge. + * C++ owns business logic (recursive dir size, cache clearing, storage checks). + * Kotlin provides thin I/O callbacks (create dir, delete, list, stat). + * + * Follows iOS CppBridge+FileManager.swift architecture. + */ + +package com.runanywhere.sdk.foundation.bridge.extensions + +import com.runanywhere.sdk.native.bridge.RunAnywhereBridge +import java.io.File + +/** + * File manager bridge to C++ rac_file_manager. + * + * C++ handles: recursive dir size, directory structure, cache clearing, storage checks. + * Kotlin provides: thin I/O callbacks (create dir, delete, list, stat, file size). + */ +object CppBridgeFileManager { + + @Volatile + private var isRegistered: Boolean = false + private val lock = Any() + + // ======================================================================== + // REGISTRATION + // ======================================================================== + + /** + * Register the file I/O callbacks with C++ core. + * Must be called during SDK initialization after native library is loaded. + */ + fun register() { + synchronized(lock) { + if (isRegistered) return + RunAnywhereBridge.nativeFileManagerRegisterCallbacks(FileCallbackProvider) + isRegistered = true + } + } + + // ======================================================================== + // PUBLIC API + // ======================================================================== + + /** Create directory structure (Models, Cache, Temp, Downloads). */ + fun createDirectoryStructure(): Boolean { + return RunAnywhereBridge.nativeFileManagerCreateDirectoryStructure() == RunAnywhereBridge.RAC_SUCCESS + } + + /** Calculate directory size recursively (C++ logic, Kotlin I/O). */ + fun calculateDirectorySize(path: String): Long { + return RunAnywhereBridge.nativeFileManagerCalculateDirSize(path) + } + + /** Get total models storage used. */ + fun modelsStorageUsed(): Long { + return RunAnywhereBridge.nativeFileManagerModelsStorageUsed() + } + + /** Clear cache directory. */ + fun clearCache(): Boolean { + return RunAnywhereBridge.nativeFileManagerClearCache() == RunAnywhereBridge.RAC_SUCCESS + } + + /** Clear temp directory. */ + fun clearTemp(): Boolean { + return RunAnywhereBridge.nativeFileManagerClearTemp() == RunAnywhereBridge.RAC_SUCCESS + } + + /** Get cache size. */ + fun cacheSize(): Long { + return RunAnywhereBridge.nativeFileManagerCacheSize() + } + + /** Delete a model folder. */ + fun deleteModel(modelId: String, framework: Int): Boolean { + return RunAnywhereBridge.nativeFileManagerDeleteModel(modelId, framework) == RunAnywhereBridge.RAC_SUCCESS + } + + /** Create model folder and return path. */ + fun createModelFolder(modelId: String, framework: Int): String? { + return RunAnywhereBridge.nativeFileManagerCreateModelFolder(modelId, framework) + } + + /** Check if model folder exists. */ + fun modelFolderExists(modelId: String, framework: Int): Boolean { + return RunAnywhereBridge.nativeFileManagerModelFolderExists(modelId, framework) + } + + /** Get storage info as JSON. */ + fun getStorageInfoJson(): String? { + return RunAnywhereBridge.nativeFileManagerGetStorageInfo() + } + + /** Check storage availability as JSON. */ + fun checkStorageJson(requiredBytes: Long): String? { + return RunAnywhereBridge.nativeFileManagerCheckStorage(requiredBytes) + } + + // ======================================================================== + // PLATFORM I/O CALLBACK PROVIDER + // ======================================================================== + + /** + * Provides platform file I/O methods called by C++ via JNI. + * Method signatures must match JNI expectations exactly. + */ + private object FileCallbackProvider { + + @Suppress("unused") // Called from JNI + fun createDirectory(path: String, recursive: Boolean): Int { + return try { + val dir = File(path) + val success = if (recursive) dir.mkdirs() else dir.mkdir() + if (success || dir.exists()) 0 else -180 // RAC_ERROR_DIRECTORY_CREATION_FAILED + } catch (_: Exception) { + -180 + } + } + + @Suppress("unused") // Called from JNI + fun deletePath(path: String, recursive: Boolean): Int { + return try { + val file = File(path) + if (!file.exists()) return 0 + val success = if (recursive) file.deleteRecursively() else file.delete() + if (success) 0 else -182 // RAC_ERROR_DELETE_FAILED + } catch (_: Exception) { + -182 + } + } + + @Suppress("unused") // Called from JNI + fun listDirectory(path: String): Array? { + return File(path).list() + } + + @Suppress("unused") // Called from JNI + fun pathExists(path: String): Boolean { + return File(path).exists() + } + + @Suppress("unused") // Called from JNI + fun isDirectory(path: String): Boolean { + return File(path).isDirectory + } + + @Suppress("unused") // Called from JNI + fun getFileSize(path: String): Long { + val file = File(path) + return if (file.isFile) file.length() else -1L + } + + @Suppress("unused") // Called from JNI + fun getAvailableSpace(): Long { + val baseDir = File(CppBridgeModelPaths.getBaseDirectory()) + return baseDir.freeSpace + } + + @Suppress("unused") // Called from JNI + fun getTotalSpace(): Long { + val baseDir = File(CppBridgeModelPaths.getBaseDirectory()) + return baseDir.totalSpace + } + } +} diff --git a/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/foundation/bridge/extensions/CppBridgeModelRegistry.kt b/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/foundation/bridge/extensions/CppBridgeModelRegistry.kt index 74596ac8f..f0649739f 100644 --- a/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/foundation/bridge/extensions/CppBridgeModelRegistry.kt +++ b/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/foundation/bridge/extensions/CppBridgeModelRegistry.kt @@ -52,8 +52,8 @@ object CppBridgeModelRegistry { const val STT = ModelCategory.SPEECH_RECOGNITION const val TTS = ModelCategory.SPEECH_SYNTHESIS const val VAD = ModelCategory.AUDIO - const val EMBEDDING = 99 - const val UNKNOWN = 99 + const val EMBEDDING = ModelCategory.EMBEDDING + const val UNKNOWN = -1 /** * Get display name for a model type. diff --git a/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/native/bridge/RunAnywhereBridge.kt b/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/native/bridge/RunAnywhereBridge.kt index 260e74a95..0010d7c3b 100644 --- a/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/native/bridge/RunAnywhereBridge.kt +++ b/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/native/bridge/RunAnywhereBridge.kt @@ -465,6 +465,55 @@ object RunAnywhereBridge { @JvmStatic external fun racHttpDownloadReportComplete(taskId: String, result: Int, downloadedPath: String?): Int + // ======================================================================== + // ARCHIVE EXTRACTION (rac_extraction.h) + // ======================================================================== + + /** Extract an archive (ZIP, TAR.GZ, TAR.BZ2, TAR.XZ) to destination directory. + * Returns RAC_SUCCESS (0) on success, negative error code on failure. */ + @JvmStatic + external fun nativeExtractArchive(archivePath: String, destinationDir: String): Int + + /** Detect archive type from magic bytes. Returns rac_archive_type_t enum value, or -1 on failure. */ + @JvmStatic + external fun nativeDetectArchiveType(filePath: String): Int + + // ======================================================================== + // DOWNLOAD ORCHESTRATOR (rac_download_orchestrator.h) + // ======================================================================== + + /** Find model path after extraction. Returns the actual model file/directory path. + * Uses C++ rac_find_model_path_after_extraction() — consolidated from platform-specific logic. + * @param extractedDir Directory where archive was extracted + * @param structure Archive structure hint (rac_archive_structure_t enum ordinal) + * @param framework Inference framework (rac_inference_framework_t enum ordinal) + * @param format Model format (rac_model_format_t enum ordinal) + * @return The found model path, or extractedDir as fallback */ + @JvmStatic + external fun nativeFindModelPathAfterExtraction( + extractedDir: String, + structure: Int, + framework: Int, + format: Int, + ): String + + /** Check if a download URL requires extraction. + * Uses C++ rac_download_requires_extraction() — handles .tar.gz, .tar.bz2, .zip, etc. + * @return true if URL points to an archive */ + @JvmStatic + external fun nativeDownloadRequiresExtraction(url: String): Boolean + + /** Compute download destination path. + * Uses C++ rac_download_compute_destination(). + * @return Destination path, or null on failure */ + @JvmStatic + external fun nativeComputeDownloadDestination( + modelId: String, + downloadUrl: String, + framework: Int, + format: Int, + ): String? + // ======================================================================== // BACKEND REGISTRATION // ======================================================================== @@ -1098,6 +1147,60 @@ object RunAnywhereBridge { @JvmStatic external fun racToolCallNormalizeJson(jsonStr: String): String? + // ======================================================================== + // FILE MANAGER (rac_file_manager.h) + // ======================================================================== + + /** + * Register file manager callbacks object. + * The callback object must implement: + * - createDirectory(path: String, recursive: Boolean): Int + * - deletePath(path: String, recursive: Boolean): Int + * - listDirectory(path: String): Array? + * - pathExists(path: String): Boolean + * - isDirectory(path: String): Boolean + * - getFileSize(path: String): Long + * - getAvailableSpace(): Long + * - getTotalSpace(): Long + */ + @JvmStatic + external fun nativeFileManagerRegisterCallbacks(callbacksObj: Any): Int + + @JvmStatic + external fun nativeFileManagerCreateDirectoryStructure(): Int + + @JvmStatic + external fun nativeFileManagerCalculateDirSize(path: String): Long + + @JvmStatic + external fun nativeFileManagerModelsStorageUsed(): Long + + @JvmStatic + external fun nativeFileManagerClearCache(): Int + + @JvmStatic + external fun nativeFileManagerClearTemp(): Int + + @JvmStatic + external fun nativeFileManagerCacheSize(): Long + + @JvmStatic + external fun nativeFileManagerDeleteModel(modelId: String, framework: Int): Int + + @JvmStatic + external fun nativeFileManagerCreateModelFolder(modelId: String, framework: Int): String? + + @JvmStatic + external fun nativeFileManagerModelFolderExists(modelId: String, framework: Int): Boolean + + /** Returns JSON: {isAvailable, requiredSpace, availableSpace, hasWarning, recommendation} */ + @JvmStatic + external fun nativeFileManagerCheckStorage(requiredBytes: Long): String? + + /** Returns JSON: {deviceTotal, deviceFree, modelsSize, cacheSize, tempSize, totalAppSize} */ + @JvmStatic + external fun nativeFileManagerGetStorageInfo(): String? + // ======================================================================== // CONSTANTS // ======================================================================== diff --git a/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/public/extensions/RunAnywhere+ModelManagement.jvmAndroid.kt b/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/public/extensions/RunAnywhere+ModelManagement.jvmAndroid.kt index 15661c1cd..8c7225ff1 100644 --- a/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/public/extensions/RunAnywhere+ModelManagement.jvmAndroid.kt +++ b/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/public/extensions/RunAnywhere+ModelManagement.jvmAndroid.kt @@ -30,16 +30,11 @@ import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.callbackFlow import kotlinx.coroutines.flow.flow import kotlinx.coroutines.withContext -import org.apache.commons.compress.archivers.tar.TarArchiveInputStream -import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream -import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream -import java.io.BufferedInputStream +import com.runanywhere.sdk.native.bridge.RunAnywhereBridge import java.io.File -import java.io.FileInputStream import java.io.FileOutputStream import java.io.IOException import java.net.HttpURLConnection -import java.util.zip.ZipInputStream // MARK: - Multi-File Model Companion Storage @@ -245,10 +240,10 @@ actual suspend fun RunAnywhere.models(category: ModelCategory): List } val type = when (category) { - ModelCategory.LANGUAGE -> CppBridgeModelRegistry.ModelType.LLM - ModelCategory.SPEECH_RECOGNITION -> CppBridgeModelRegistry.ModelType.STT - ModelCategory.SPEECH_SYNTHESIS -> CppBridgeModelRegistry.ModelType.TTS - ModelCategory.AUDIO -> CppBridgeModelRegistry.ModelType.VAD + ModelCategory.LANGUAGE -> CppBridgeModelRegistry.ModelCategory.LANGUAGE + ModelCategory.SPEECH_RECOGNITION -> CppBridgeModelRegistry.ModelCategory.SPEECH_RECOGNITION + ModelCategory.SPEECH_SYNTHESIS -> CppBridgeModelRegistry.ModelCategory.SPEECH_SYNTHESIS + ModelCategory.AUDIO -> CppBridgeModelRegistry.ModelCategory.AUDIO ModelCategory.VISION -> CppBridgeModelRegistry.ModelCategory.VISION ModelCategory.IMAGE_GENERATION -> CppBridgeModelRegistry.ModelCategory.IMAGE_GENERATION ModelCategory.MULTIMODAL -> CppBridgeModelRegistry.ModelCategory.MULTIMODAL @@ -710,8 +705,25 @@ actual fun RunAnywhere.downloadModel(modelId: String): Flow { ), ) - // Pass the URL to determine archive type (file may be saved without extension) - val extractedPath = extractArchive(downloadedFile, modelId, modelType, downloadUrl, downloadLogger) + // Extract and find model path using C++ orchestrator utilities + val racFramework = when (modelInfo.framework) { + InferenceFramework.LLAMA_CPP -> CppBridgeModelRegistry.Framework.LLAMACPP + InferenceFramework.ONNX -> CppBridgeModelRegistry.Framework.ONNX + InferenceFramework.FOUNDATION_MODELS -> CppBridgeModelRegistry.Framework.FOUNDATION_MODELS + InferenceFramework.SYSTEM_TTS -> CppBridgeModelRegistry.Framework.SYSTEM_TTS + InferenceFramework.FLUID_AUDIO -> CppBridgeModelRegistry.Framework.FLUID_AUDIO + InferenceFramework.BUILT_IN -> CppBridgeModelRegistry.Framework.BUILTIN + InferenceFramework.NONE -> CppBridgeModelRegistry.Framework.NONE + InferenceFramework.UNKNOWN -> CppBridgeModelRegistry.Framework.UNKNOWN + } + val racFormat = when (modelInfo.format) { + ModelFormat.GGUF -> CppBridgeModelRegistry.ModelFormat.GGUF + ModelFormat.ONNX -> CppBridgeModelRegistry.ModelFormat.ONNX + ModelFormat.ORT -> CppBridgeModelRegistry.ModelFormat.ORT + ModelFormat.BIN -> CppBridgeModelRegistry.ModelFormat.BIN + ModelFormat.UNKNOWN -> CppBridgeModelRegistry.ModelFormat.UNKNOWN + } + val extractedPath = extractArchive(downloadedFile, modelId, racFramework, racFormat, downloadLogger) downloadLogger.info("Extraction complete: $extractedPath") extractedPath } else { @@ -759,46 +771,36 @@ actual fun RunAnywhere.downloadModel(modelId: String): Flow { /** * Check if URL requires extraction (is an archive). - * Supports: .tar.gz, .tgz, .tar.bz2, .tbz2, .zip + * Delegates to C++ rac_download_requires_extraction() for consistent behavior across all SDKs. */ private fun requiresExtraction(url: String): Boolean { - val lowercaseUrl = url.lowercase() - return lowercaseUrl.endsWith(".tar.gz") || - lowercaseUrl.endsWith(".tgz") || - lowercaseUrl.endsWith(".tar.bz2") || - lowercaseUrl.endsWith(".tbz2") || - lowercaseUrl.endsWith(".zip") + return RunAnywhereBridge.nativeDownloadRequiresExtraction(url) } /** - * Extract an archive to the model directory. - * - * Supports: - * - .tar.gz / .tgz → Uses Apache Commons Compress - * - .tar.bz2 / .tbz2 → Uses Apache Commons Compress - * - .zip → Uses java.util.zip + * Extract an archive to the model directory using native C++ extraction (libarchive). * + * Supports all formats via auto-detection: ZIP, TAR.GZ, TAR.BZ2, TAR.XZ. * Archives typically contain a root folder (e.g., sherpa-onnx-whisper-tiny.en/) * so we extract to the parent directory and the archive structure creates the model folder. * + * Post-extraction model path finding uses C++ rac_find_model_path_after_extraction() + * for consistent behavior across all SDKs. + * * @param archiveFile The downloaded archive file (may not have extension in filename) * @param modelId The model ID - * @param modelType The model type - * @param originalUrl The original download URL (used to determine archive type) + * @param racFramework C++ framework constant (CppBridgeModelRegistry.Framework.*) + * @param racFormat C++ format constant (CppBridgeModelRegistry.ModelFormat.*) * @param logger Logger for debug output */ -@Suppress("UNUSED_PARAMETER") private suspend fun extractArchive( archiveFile: File, modelId: String, - modelType: Int, // Reserved for future type-specific extraction logic - originalUrl: String, + racFramework: Int, + racFormat: Int, logger: SDKLogger, ): String = withContext(Dispatchers.IO) { - // Extract to parent directory - the archive typically contains a root folder - // e.g., archive contains: sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx - // So we extract to /models/stt/ and get /models/stt/sherpa-onnx-whisper-tiny.en/ val parentDir = archiveFile.parentFile if (parentDir == null || !parentDir.exists()) { throw SDKError.download("Cannot determine extraction directory for: ${archiveFile.absolutePath}") @@ -806,197 +808,51 @@ private suspend fun extractArchive( logger.info("Extracting to parent: ${parentDir.absolutePath}") logger.debug("Archive file: ${archiveFile.absolutePath}") - logger.debug("Original URL: $originalUrl") - - // Use the URL to determine archive type (file may be saved without extension) - val lowercaseUrl = originalUrl.lowercase() - // IMPORTANT: The archive file name might conflict with the folder inside the archive - // (e.g., file "sherpa-onnx-whisper-tiny.en" and archive contains folder "sherpa-onnx-whisper-tiny.en/") - // We need to rename/move the archive before extracting to avoid ENOTDIR errors + // Rename archive to temp to avoid name conflicts with extracted contents val tempArchiveFile = File(parentDir, "${archiveFile.name}.tmp_archive") try { if (!archiveFile.renameTo(tempArchiveFile)) { - // If rename fails, copy and delete archiveFile.copyTo(tempArchiveFile, overwrite = true) archiveFile.delete() } - logger.debug("Moved archive to temp: ${tempArchiveFile.absolutePath}") } catch (e: Exception) { - logger.error("Failed to move archive to temp location: ${e.message}") throw SDKError.download("Failed to prepare archive for extraction: ${e.message}") } try { - when { - lowercaseUrl.endsWith(".tar.gz") || lowercaseUrl.endsWith(".tgz") -> { - logger.info("Extracting tar.gz archive...") - extractTarGz(tempArchiveFile, parentDir, logger) - } - lowercaseUrl.endsWith(".tar.bz2") || lowercaseUrl.endsWith(".tbz2") -> { - logger.info("Extracting tar.bz2 archive...") - extractTarBz2(tempArchiveFile, parentDir, logger) - } - lowercaseUrl.endsWith(".zip") -> { - logger.info("Extracting zip archive...") - extractZip(tempArchiveFile, parentDir, logger) - } - else -> { - logger.warn("Unknown archive type for URL: $originalUrl") - // Restore the original file - tempArchiveFile.renameTo(archiveFile) - return@withContext archiveFile.absolutePath - } + // Use native C++ extraction (libarchive) — auto-detects format from magic bytes + val result = RunAnywhereBridge.nativeExtractArchive( + tempArchiveFile.absolutePath, + parentDir.absolutePath, + ) + if (result != 0) { + throw SDKError.download("Native extraction failed with code: $result") } + logger.info("Native extraction completed successfully") } finally { - // Always clean up the temp archive file try { if (tempArchiveFile.exists()) { tempArchiveFile.delete() - logger.debug("Cleaned up temp archive: ${tempArchiveFile.absolutePath}") } } catch (e: Exception) { logger.warn("Failed to clean up temp archive: ${e.message}") } } - // Find the extracted model directory - // The archive should have created a folder with the model ID name - val expectedModelDir = File(parentDir, modelId) - val finalPath = - if (expectedModelDir.exists() && expectedModelDir.isDirectory) { - expectedModelDir.absolutePath - } else { - // Fallback: look for any new directory created - parentDir - .listFiles() - ?.firstOrNull { - it.isDirectory && it.name.contains(modelId.substringBefore("-")) - }?.absolutePath ?: parentDir.absolutePath - } + // Find the extracted model path using C++ rac_find_model_path_after_extraction() + // Uses UNKNOWN structure to let C++ scan for model files and nested directories + val finalPath = RunAnywhereBridge.nativeFindModelPathAfterExtraction( + parentDir.absolutePath, + 99, // RAC_ARCHIVE_STRUCTURE_UNKNOWN — let C++ auto-detect + racFramework, + racFormat, + ) logger.info("Model extracted to: $finalPath") finalPath } -/** - * Extract a .tar.gz archive. - */ -private fun extractTarGz(archiveFile: File, destDir: File, logger: SDKLogger) { - logger.debug("Extracting tar.gz: ${archiveFile.absolutePath}") - - FileInputStream(archiveFile).use { fis -> - BufferedInputStream(fis).use { bis -> - GzipCompressorInputStream(bis).use { gzis -> - TarArchiveInputStream(gzis).use { tais -> - var entry = tais.nextEntry - var fileCount = 0 - - while (entry != null) { - val destFile = File(destDir, entry.name) - - // Security check - prevent path traversal - if (!destFile.canonicalPath.startsWith(destDir.canonicalPath)) { - throw SecurityException("Tar entry outside destination: ${entry.name}") - } - - if (entry.isDirectory) { - destFile.mkdirs() - } else { - destFile.parentFile?.mkdirs() - FileOutputStream(destFile).use { fos -> - tais.copyTo(fos) - } - fileCount++ - } - - entry = tais.nextEntry - } - - logger.info("Extracted $fileCount files from tar.gz") - } - } - } - } -} - -/** - * Extract a .tar.bz2 archive. - */ -private fun extractTarBz2(archiveFile: File, destDir: File, logger: SDKLogger) { - logger.debug("Extracting tar.bz2: ${archiveFile.absolutePath}") - - FileInputStream(archiveFile).use { fis -> - BufferedInputStream(fis).use { bis -> - BZip2CompressorInputStream(bis).use { bzis -> - TarArchiveInputStream(bzis).use { tais -> - var entry = tais.nextEntry - var fileCount = 0 - - while (entry != null) { - val destFile = File(destDir, entry.name) - - // Security check - prevent path traversal - if (!destFile.canonicalPath.startsWith(destDir.canonicalPath)) { - throw SecurityException("Tar entry outside destination: ${entry.name}") - } - - if (entry.isDirectory) { - destFile.mkdirs() - } else { - destFile.parentFile?.mkdirs() - FileOutputStream(destFile).use { fos -> - tais.copyTo(fos) - } - fileCount++ - } - - entry = tais.nextEntry - } - - logger.info("Extracted $fileCount files from tar.bz2") - } - } - } - } -} - -/** - * Extract a .zip archive. - */ -private fun extractZip(archiveFile: File, destDir: File, logger: SDKLogger) { - logger.debug("Extracting zip: ${archiveFile.absolutePath}") - - ZipInputStream(FileInputStream(archiveFile)).use { zis -> - var entry = zis.nextEntry - var fileCount = 0 - - while (entry != null) { - val destFile = File(destDir, entry.name) - - // Security check - prevent path traversal - if (!destFile.canonicalPath.startsWith(destDir.canonicalPath)) { - throw SecurityException("Zip entry outside destination: ${entry.name}") - } - - if (entry.isDirectory) { - destFile.mkdirs() - } else { - destFile.parentFile?.mkdirs() - FileOutputStream(destFile).use { fos -> - zis.copyTo(fos) - } - fileCount++ - } - - zis.closeEntry() - entry = zis.nextEntry - } - - logger.info("Extracted $fileCount files from zip") - } -} - // MARK: - Embedding Model Direct HTTP Download /** diff --git a/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/public/extensions/RunAnywhere+Storage.jvmAndroid.kt b/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/public/extensions/RunAnywhere+Storage.jvmAndroid.kt index 622ac9f39..16bc709fa 100644 --- a/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/public/extensions/RunAnywhere+Storage.jvmAndroid.kt +++ b/sdk/runanywhere-kotlin/src/jvmAndroidMain/kotlin/com/runanywhere/sdk/public/extensions/RunAnywhere+Storage.jvmAndroid.kt @@ -9,6 +9,7 @@ package com.runanywhere.sdk.public.extensions import com.runanywhere.sdk.core.types.InferenceFramework import com.runanywhere.sdk.foundation.SDKLogger +import com.runanywhere.sdk.foundation.bridge.extensions.CppBridgeFileManager import com.runanywhere.sdk.foundation.bridge.extensions.CppBridgeModelPaths import com.runanywhere.sdk.foundation.bridge.extensions.CppBridgeModelRegistry import com.runanywhere.sdk.foundation.bridge.extensions.CppBridgeStorage @@ -41,10 +42,10 @@ actual suspend fun RunAnywhere.storageInfo(): StorageInfo { val modelsDir = File(baseDir, "models") val appSupportDir = File(baseDir, "data") - // Calculate directory sizes - val cacheSize = calculateDirectorySize(cacheDir) - val modelsSize = calculateDirectorySize(modelsDir) - val appSupportSize = calculateDirectorySize(appSupportDir) + // Calculate directory sizes via C++ (recursive traversal in C++, Kotlin provides I/O callbacks) + val cacheSize = CppBridgeFileManager.calculateDirectorySize(cacheDir.absolutePath) + val modelsSize = CppBridgeFileManager.calculateDirectorySize(modelsDir.absolutePath) + val appSupportSize = CppBridgeFileManager.calculateDirectorySize(appSupportDir.absolutePath) val appStorage = AppStorageInfo( @@ -85,26 +86,21 @@ actual suspend fun RunAnywhere.checkStorageAvailability(requiredBytes: Long): St throw SDKError.notInitialized("SDK not initialized") } + // Delegate to C++ for storage check (1GB warning threshold logic is in C++) + val json = CppBridgeFileManager.checkStorageJson(requiredBytes) + if (json != null) { + return parseStorageAvailabilityJson(json, requiredBytes) + } + + // Fallback if C++ call fails val baseDir = File(CppBridgeModelPaths.getBaseDirectory()) val availableSpace = baseDir.freeSpace - val isAvailable = availableSpace >= requiredBytes - - // Check if we're getting low on space (less than 1GB after this operation) - val hasWarning = isAvailable && (availableSpace - requiredBytes) < 1024L * 1024 * 1024 - - val recommendation = - when { - !isAvailable -> "Not enough storage space. Required: ${formatBytes(requiredBytes)}, Available: ${formatBytes(availableSpace)}" - hasWarning -> "Storage is running low. Consider clearing cache or removing unused models." - else -> null - } - return StorageAvailability( - isAvailable = isAvailable, + isAvailable = availableSpace >= requiredBytes, requiredSpace = requiredBytes, availableSpace = availableSpace, - hasWarning = hasWarning, - recommendation = recommendation, + hasWarning = false, + recommendation = null, ) } @@ -113,8 +109,7 @@ actual suspend fun RunAnywhere.cacheSize(): Long { throw SDKError.notInitialized("SDK not initialized") } - val cacheDir = File(CppBridgeModelPaths.getBaseDirectory(), "cache") - return calculateDirectorySize(cacheDir) + return CppBridgeFileManager.cacheSize() } actual suspend fun RunAnywhere.clearCache() { @@ -127,12 +122,8 @@ actual suspend fun RunAnywhere.clearCache() { // Clear the storage cache namespace CppBridgeStorage.clear(CppBridgeStorage.StorageNamespace.INFERENCE_CACHE, CppBridgeStorage.StorageType.CACHE) - // Also clear the file cache directory - val cacheDir = File(CppBridgeModelPaths.getBaseDirectory(), "cache") - if (cacheDir.exists()) { - cacheDir.deleteRecursively() - cacheDir.mkdirs() - } + // Clear the file cache directory via C++ + CppBridgeFileManager.clearCache() storageLogger.info("Cache cleared") } @@ -151,25 +142,13 @@ actual suspend fun RunAnywhere.modelStorageUsed(): Long { throw SDKError.notInitialized("SDK not initialized") } - val modelsDir = File(CppBridgeModelPaths.getBaseDirectory(), "models") - return calculateDirectorySize(modelsDir) + return CppBridgeFileManager.modelsStorageUsed() } -// Helper function to calculate directory size recursively +// Delegate to C++ for recursive directory size calculation private fun calculateDirectorySize(directory: File): Long { if (!directory.exists()) return 0L - if (directory.isFile) return directory.length() - - var size = 0L - directory.listFiles()?.forEach { file -> - size += - if (file.isDirectory) { - calculateDirectorySize(file) - } else { - file.length() - } - } - return size + return CppBridgeFileManager.calculateDirectorySize(directory.absolutePath) } // Helper function to format bytes as human-readable string @@ -183,6 +162,29 @@ private fun formatBytes(bytes: Long): String { return "%.2f GB".format(gb) } +/** + * Parse storage availability JSON from C++ rac_file_manager_check_storage. + */ +private fun parseStorageAvailabilityJson(json: String, requiredBytes: Long): StorageAvailability { + // Simple JSON parsing without external library + val isAvailable = json.contains("\"isAvailable\":true") + val hasWarning = json.contains("\"hasWarning\":true") + + val availableSpace = Regex("\"availableSpace\":(\\d+)").find(json) + ?.groupValues?.get(1)?.toLongOrNull() ?: 0L + + val recommendation = Regex("\"recommendation\":\"([^\"]*)\"").find(json) + ?.groupValues?.get(1)?.takeIf { it.isNotEmpty() } + + return StorageAvailability( + isAvailable = isAvailable, + requiredSpace = requiredBytes, + availableSpace = availableSpace, + hasWarning = hasWarning, + recommendation = recommendation, + ) +} + /** * Convert a CppBridgeModelRegistry.ModelInfo to ModelStorageMetrics. * Calculates actual size on disk from the model's local path. diff --git a/sdk/runanywhere-react-native/packages/core/android/CMakeLists.txt b/sdk/runanywhere-react-native/packages/core/android/CMakeLists.txt index 7b56bf009..b2c328ff3 100644 --- a/sdk/runanywhere-react-native/packages/core/android/CMakeLists.txt +++ b/sdk/runanywhere-react-native/packages/core/android/CMakeLists.txt @@ -115,6 +115,8 @@ include_directories( "${RAC_INCLUDE_DIR}/rac/infrastructure/events" "${RAC_INCLUDE_DIR}/rac/infrastructure/model_management" "${RAC_INCLUDE_DIR}/rac/infrastructure/network" + "${RAC_INCLUDE_DIR}/rac/infrastructure/extraction" + "${RAC_INCLUDE_DIR}/rac/infrastructure/file_management" "${RAC_INCLUDE_DIR}/rac/infrastructure/storage" "${RAC_INCLUDE_DIR}/rac/infrastructure/telemetry" ) diff --git a/sdk/runanywhere-react-native/packages/core/android/build.gradle b/sdk/runanywhere-react-native/packages/core/android/build.gradle index b45008300..dec765451 100644 --- a/sdk/runanywhere-react-native/packages/core/android/build.gradle +++ b/sdk/runanywhere-react-native/packages/core/android/build.gradle @@ -429,9 +429,6 @@ dependencies { implementation "com.facebook.react:react-android" implementation project(":react-native-nitro-modules") - // Apache Commons Compress for tar.gz archive extraction - implementation "org.apache.commons:commons-compress:1.26.0" - // AndroidX Security for EncryptedSharedPreferences (device identity persistence) implementation "androidx.security:security-crypto:1.1.0-alpha06" } diff --git a/sdk/runanywhere-react-native/packages/core/android/consumer-rules.pro b/sdk/runanywhere-react-native/packages/core/android/consumer-rules.pro index 753e07c62..c28c2b2b0 100644 --- a/sdk/runanywhere-react-native/packages/core/android/consumer-rules.pro +++ b/sdk/runanywhere-react-native/packages/core/android/consumer-rules.pro @@ -1,5 +1 @@ -# Keep ArchiveUtility for JNI access --keep class com.margelo.nitro.runanywhere.ArchiveUtility { *; } --keepclassmembers class com.margelo.nitro.runanywhere.ArchiveUtility { - public static *** extract(java.lang.String, java.lang.String); -} +# No ProGuard rules needed - archive extraction is handled by native C++ (libarchive) diff --git a/sdk/runanywhere-react-native/packages/core/android/src/main/cpp/cpp-adapter.cpp b/sdk/runanywhere-react-native/packages/core/android/src/main/cpp/cpp-adapter.cpp index b0f894a27..efd0f68ca 100644 --- a/sdk/runanywhere-react-native/packages/core/android/src/main/cpp/cpp-adapter.cpp +++ b/sdk/runanywhere-react-native/packages/core/android/src/main/cpp/cpp-adapter.cpp @@ -4,7 +4,7 @@ #include "runanywherecoreOnLoad.hpp" #include "PlatformDownloadBridge.h" -#define LOG_TAG "ArchiveJNI" +#define LOG_TAG "RunAnywhereJNI" #define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__) #define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__) @@ -12,11 +12,6 @@ // NOT static - needs to be accessible from InitBridge.cpp for secure storage JavaVM* g_javaVM = nullptr; -// Cache class and method references at JNI_OnLoad time -// This is necessary because FindClass from native threads uses the system class loader -static jclass g_archiveUtilityClass = nullptr; -static jmethodID g_extractMethod = nullptr; - // PlatformAdapterBridge class and methods for secure storage (used by InitBridge.cpp) // NOT static - needs to be accessible from InitBridge.cpp jclass g_platformAdapterBridgeClass = nullptr; @@ -44,44 +39,12 @@ jfieldID g_httpResponse_statusCodeField = nullptr; jfieldID g_httpResponse_responseBodyField = nullptr; jfieldID g_httpResponse_errorMessageField = nullptr; -// Forward declaration -extern "C" bool ArchiveUtility_extractAndroid(const char* archivePath, const char* destinationPath); - JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { g_javaVM = vm; // Get JNIEnv to cache class references JNIEnv* env = nullptr; if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) == JNI_OK && env != nullptr) { - // Find and cache the ArchiveUtility class - jclass localClass = env->FindClass("com/margelo/nitro/runanywhere/ArchiveUtility"); - if (localClass != nullptr) { - // Create a global reference so it persists across JNI calls - g_archiveUtilityClass = (jclass)env->NewGlobalRef(localClass); - env->DeleteLocalRef(localClass); - - // Cache the extract method - g_extractMethod = env->GetStaticMethodID( - g_archiveUtilityClass, - "extract", - "(Ljava/lang/String;Ljava/lang/String;)Z" - ); - - if (g_extractMethod != nullptr) { - LOGI("ArchiveUtility class and method cached successfully"); - } else { - LOGE("Failed to find extract method in ArchiveUtility"); - if (env->ExceptionCheck()) { - env->ExceptionClear(); - } - } - } else { - LOGE("Failed to find ArchiveUtility class at JNI_OnLoad"); - if (env->ExceptionCheck()) { - env->ExceptionClear(); - } - } - // Find and cache the PlatformAdapterBridge class (for secure storage) jclass platformClass = env->FindClass("com/margelo/nitro/runanywhere/PlatformAdapterBridge"); if (platformClass != nullptr) { @@ -219,63 +182,6 @@ static void logAndClearException(JNIEnv* env, const char* context) { } } -/** - * Call Kotlin ArchiveUtility.extract() via JNI - * Uses cached class and method references from JNI_OnLoad - */ -extern "C" bool ArchiveUtility_extractAndroid(const char* archivePath, const char* destinationPath) { - LOGI("Starting extraction: %s -> %s", archivePath, destinationPath); - - // Check if class and method were cached - if (g_archiveUtilityClass == nullptr || g_extractMethod == nullptr) { - LOGE("ArchiveUtility class or method not cached. JNI_OnLoad may have failed."); - return false; - } - - JNIEnv* env = getJNIEnv(); - if (env == nullptr) { - LOGE("Failed to get JNIEnv"); - return false; - } - - LOGI("Using cached ArchiveUtility class and method"); - - // Create Java strings - jstring jArchivePath = env->NewStringUTF(archivePath); - jstring jDestinationPath = env->NewStringUTF(destinationPath); - - if (jArchivePath == nullptr || jDestinationPath == nullptr) { - LOGE("Failed to create Java strings"); - if (jArchivePath) env->DeleteLocalRef(jArchivePath); - if (jDestinationPath) env->DeleteLocalRef(jDestinationPath); - return false; - } - - // Call the method using cached references - LOGI("Calling ArchiveUtility.extract()..."); - jboolean result = env->CallStaticBooleanMethod( - g_archiveUtilityClass, - g_extractMethod, - jArchivePath, - jDestinationPath - ); - - // Check for exceptions - if (env->ExceptionCheck()) { - LOGE("Exception during extraction"); - logAndClearException(env, "extract"); - result = JNI_FALSE; - } else { - LOGI("Extraction returned: %s", result ? "true" : "false"); - } - - // Cleanup local references - env->DeleteLocalRef(jArchivePath); - env->DeleteLocalRef(jDestinationPath); - - return result == JNI_TRUE; -} - // ============================================================================= // HTTP Download Callback Reporting (from Kotlin to C++) // ============================================================================= diff --git a/sdk/runanywhere-react-native/packages/core/android/src/main/java/com/margelo/nitro/runanywhere/ArchiveUtility.kt b/sdk/runanywhere-react-native/packages/core/android/src/main/java/com/margelo/nitro/runanywhere/ArchiveUtility.kt deleted file mode 100644 index 6dd64acad..000000000 --- a/sdk/runanywhere-react-native/packages/core/android/src/main/java/com/margelo/nitro/runanywhere/ArchiveUtility.kt +++ /dev/null @@ -1,308 +0,0 @@ -/** - * ArchiveUtility.kt - * - * Native archive extraction utility for Android. - * Uses Apache Commons Compress for robust tar.gz extraction (streaming, memory-efficient). - * Uses Java's native ZipInputStream for zip extraction. - * - * Mirrors the implementation from: - * sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/Download/Utilities/ArchiveUtility.swift - * - * Supports: tar.gz, zip - * Note: All models should use tar.gz from RunanywhereAI/sherpa-onnx fork for best performance - */ - -package com.margelo.nitro.runanywhere - -import org.apache.commons.compress.archivers.tar.TarArchiveEntry -import org.apache.commons.compress.archivers.tar.TarArchiveInputStream -import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream -import java.io.BufferedInputStream -import java.io.File -import java.io.FileInputStream -import java.io.FileOutputStream -import java.util.zip.ZipInputStream - -/** - * Utility for handling archive extraction on Android - */ -object ArchiveUtility { - private val logger = SDKLogger.archive - - /** - * Extract an archive to a destination directory - * @param archivePath Path to the archive file - * @param destinationPath Destination directory path - * @return true if extraction succeeded - */ - @JvmStatic - fun extract(archivePath: String, destinationPath: String): Boolean { - logger.info("extract() called: $archivePath -> $destinationPath") - return try { - extractArchive(archivePath, destinationPath) - logger.info("extract() succeeded") - true - } catch (e: Exception) { - logger.logError(e, "Extraction failed") - false - } - } - - /** - * Extract an archive to a destination directory (throwing version) - */ - @Throws(Exception::class) - fun extractArchive( - archivePath: String, - destinationPath: String, - progressHandler: ((Double) -> Unit)? = null - ) { - val archiveFile = File(archivePath) - val destinationDir = File(destinationPath) - - if (!archiveFile.exists()) { - throw Exception("Archive not found: $archivePath") - } - - // Detect archive type by magic bytes (more reliable than file extension) - val archiveType = detectArchiveTypeByMagicBytes(archiveFile) - logger.info("Detected archive type: $archiveType for: $archivePath") - - when (archiveType) { - ArchiveType.GZIP -> { - extractTarGz(archiveFile, destinationDir, progressHandler) - } - ArchiveType.ZIP -> { - extractZip(archiveFile, destinationDir, progressHandler) - } - ArchiveType.BZIP2 -> { - throw Exception("tar.bz2 not supported. Use tar.gz from RunanywhereAI/sherpa-onnx fork.") - } - ArchiveType.XZ -> { - throw Exception("tar.xz not supported. Use tar.gz from RunanywhereAI/sherpa-onnx fork.") - } - ArchiveType.UNKNOWN -> { - // Fallback to file extension check - val lowercased = archivePath.lowercase() - when { - lowercased.endsWith(".tar.gz") || lowercased.endsWith(".tgz") -> { - extractTarGz(archiveFile, destinationDir, progressHandler) - } - lowercased.endsWith(".zip") -> { - extractZip(archiveFile, destinationDir, progressHandler) - } - else -> { - throw Exception("Unknown archive format: $archivePath") - } - } - } - } - } - - /** - * Archive type detected by magic bytes - */ - private enum class ArchiveType { - GZIP, ZIP, BZIP2, XZ, UNKNOWN - } - - /** - * Detect archive type by reading magic bytes from file header - */ - private fun detectArchiveTypeByMagicBytes(file: File): ArchiveType { - return try { - FileInputStream(file).use { fis -> - val header = ByteArray(6) - val bytesRead = fis.read(header) - if (bytesRead < 2) return ArchiveType.UNKNOWN - - // Check for gzip: 0x1f 0x8b - if (header[0] == 0x1f.toByte() && header[1] == 0x8b.toByte()) { - return ArchiveType.GZIP - } - - // Check for zip: 0x50 0x4b 0x03 0x04 ("PK\x03\x04") - if (bytesRead >= 4 && - header[0] == 0x50.toByte() && header[1] == 0x4b.toByte() && - header[2] == 0x03.toByte() && header[3] == 0x04.toByte()) { - return ArchiveType.ZIP - } - - // Check for bzip2: 0x42 0x5a ("BZ") - if (header[0] == 0x42.toByte() && header[1] == 0x5a.toByte()) { - return ArchiveType.BZIP2 - } - - // Check for xz: 0xfd 0x37 0x7a 0x58 0x5a 0x00 - if (bytesRead >= 6 && - header[0] == 0xfd.toByte() && header[1] == 0x37.toByte() && - header[2] == 0x7a.toByte() && header[3] == 0x58.toByte() && - header[4] == 0x5a.toByte() && header[5] == 0x00.toByte()) { - return ArchiveType.XZ - } - - ArchiveType.UNKNOWN - } - } catch (e: Exception) { - logger.error("Failed to detect archive type: ${e.message}") - ArchiveType.UNKNOWN - } - } - - // MARK: - tar.gz Extraction - - /** - * Extract a tar.gz archive using Apache Commons Compress (streaming, memory-efficient) - * This approach doesn't load the entire file into memory. - */ - private fun extractTarGz( - sourceFile: File, - destinationDir: File, - progressHandler: ((Double) -> Unit)? - ) { - val startTime = System.currentTimeMillis() - logger.info("Extracting tar.gz: ${sourceFile.name} (size: ${formatBytes(sourceFile.length())})") - progressHandler?.invoke(0.0) - - destinationDir.mkdirs() - var fileCount = 0 - val totalSize = sourceFile.length() - var bytesRead = 0L - - try { - // Use Apache Commons Compress for streaming tar.gz extraction - FileInputStream(sourceFile).use { fis -> - BufferedInputStream(fis).use { bis -> - GzipCompressorInputStream(bis).use { gzis -> - TarArchiveInputStream(gzis).use { tarIn -> - var entry: TarArchiveEntry? = tarIn.nextTarEntry - while (entry != null) { - val name = entry.name - - // Skip macOS resource forks and empty names - if (name.isEmpty() || name.startsWith("._") || name.startsWith("./._")) { - entry = tarIn.nextTarEntry - continue - } - - val outputFile = File(destinationDir, name) - - // Security check - prevent zip slip attack - val destDirPath = destinationDir.canonicalPath - val outputFilePath = outputFile.canonicalPath - if (!outputFilePath.startsWith(destDirPath + File.separator) && - outputFilePath != destDirPath) { - logger.warning("Skipping entry outside destination: $name") - entry = tarIn.nextTarEntry - continue - } - - if (entry.isDirectory) { - outputFile.mkdirs() - } else { - // Create parent directories - outputFile.parentFile?.mkdirs() - - // Extract file - FileOutputStream(outputFile).use { fos -> - val buffer = ByteArray(8192) - var len: Int - while (tarIn.read(buffer).also { len = it } != -1) { - fos.write(buffer, 0, len) - bytesRead += len - } - } - fileCount++ - - // Log progress for large files - if (fileCount % 10 == 0) { - logger.debug("Extracted $fileCount files...") - } - } - - // Report progress (estimate based on compressed bytes) - val progress = (bytesRead.toDouble() / (totalSize * 3)).coerceAtMost(0.95) - progressHandler?.invoke(progress) - - entry = tarIn.nextTarEntry - } - } - } - } - } - - val totalTime = System.currentTimeMillis() - startTime - logger.info("Extracted $fileCount files in ${totalTime}ms") - progressHandler?.invoke(1.0) - } catch (e: Exception) { - logger.logError(e, "tar.gz extraction failed") - throw e - } - } - - // MARK: - ZIP Extraction - - /** - * Extract a zip archive using Java's native ZipInputStream - */ - private fun extractZip( - sourceFile: File, - destinationDir: File, - progressHandler: ((Double) -> Unit)? - ) { - logger.info("Extracting zip: ${sourceFile.name}") - progressHandler?.invoke(0.0) - - destinationDir.mkdirs() - - var fileCount = 0 - ZipInputStream(BufferedInputStream(FileInputStream(sourceFile))).use { zis -> - var entry = zis.nextEntry - while (entry != null) { - val fileName = entry.name - val newFile = File(destinationDir, fileName) - - // Security check - prevent zip slip attack - val destDirPath = destinationDir.canonicalPath - val newFilePath = newFile.canonicalPath - if (!newFilePath.startsWith(destDirPath + File.separator)) { - throw Exception("Entry is outside of the target dir: $fileName") - } - - if (entry.isDirectory) { - newFile.mkdirs() - } else { - // Create parent directories - newFile.parentFile?.mkdirs() - - // Write file - FileOutputStream(newFile).use { fos -> - val buffer = ByteArray(8192) - var len: Int - while (zis.read(buffer).also { len = it } != -1) { - fos.write(buffer, 0, len) - } - } - fileCount++ - } - - zis.closeEntry() - entry = zis.nextEntry - } - } - - logger.info("Extracted $fileCount files from zip") - progressHandler?.invoke(1.0) - } - - // MARK: - Helpers - - private fun formatBytes(bytes: Long): String { - return when { - bytes < 1024 -> "$bytes B" - bytes < 1024 * 1024 -> String.format("%.1f KB", bytes / 1024.0) - bytes < 1024 * 1024 * 1024 -> String.format("%.1f MB", bytes / (1024.0 * 1024)) - else -> String.format("%.2f GB", bytes / (1024.0 * 1024 * 1024)) - } - } -} diff --git a/sdk/runanywhere-react-native/packages/core/cpp/HybridRunAnywhereCore.cpp b/sdk/runanywhere-react-native/packages/core/cpp/HybridRunAnywhereCore.cpp index 12506abf1..01e0d3b4b 100644 --- a/sdk/runanywhere-react-native/packages/core/cpp/HybridRunAnywhereCore.cpp +++ b/sdk/runanywhere-react-native/packages/core/cpp/HybridRunAnywhereCore.cpp @@ -38,6 +38,7 @@ #include "bridges/TelemetryBridge.hpp" #include "bridges/ToolCallingBridge.hpp" #include "bridges/RAGBridge.hpp" +#include "bridges/FileManagerBridge.hpp" // RACommons C API headers for capability methods // These are backend-agnostic - they work with any registered backend @@ -54,6 +55,7 @@ #include "rac_voice_agent.h" #include "rac_types.h" #include "rac_model_assignment.h" +#include "rac_extraction.h" #include #include @@ -340,6 +342,7 @@ HybridRunAnywhereCore::~HybridRunAnywhereCore() { // across instances and should persist for the SDK lifetime) EventBridge::shared().unregisterFromEvents(); DownloadBridge::shared().shutdown(); + FileManagerBridge::shared().shutdown(); StorageBridge::shared().shutdown(); ModelRegistryBridge::shared().shutdown(); // Note: InitBridge and TelemetryBridge are not shutdown in destructor @@ -406,6 +409,10 @@ std::shared_ptr> HybridRunAnywhereCore::initialize( // Continue - not fatal } + // 4b. Initialize file manager bridge (POSIX-based I/O for C++ business logic) + FileManagerBridge::shared().initialize(); + FileManagerBridge::shared().createDirectoryStructure(); + // 5. Initialize download manager result = DownloadBridge::shared().initialize(); if (result != RAC_SUCCESS) { @@ -521,6 +528,7 @@ std::shared_ptr> HybridRunAnywhereCore::destroy() { TelemetryBridge::shared().shutdown(); // Flush and destroy telemetry first EventBridge::shared().unregisterFromEvents(); DownloadBridge::shared().shutdown(); + FileManagerBridge::shared().shutdown(); StorageBridge::shared().shutdown(); ModelRegistryBridge::shared().shutdown(); InitBridge::shared().shutdown(); @@ -1185,19 +1193,23 @@ std::shared_ptr> HybridRunAnywhereCore::getDownloadProgress std::shared_ptr> HybridRunAnywhereCore::getStorageInfo() { return Promise::async([]() { + // Use FileManagerBridge for accurate storage info via C++ recursive traversal + auto fmInfo = FileManagerBridge::shared().getStorageInfo(); + + // Also get model count from registry auto registryHandle = ModelRegistryBridge::shared().getHandle(); - auto info = StorageBridge::shared().analyzeStorage(registryHandle); + auto storageInfo = StorageBridge::shared().analyzeStorage(registryHandle); return buildJsonObject({ - {"totalDeviceSpace", std::to_string(info.deviceStorage.totalSpace)}, - {"freeDeviceSpace", std::to_string(info.deviceStorage.freeSpace)}, - {"usedDeviceSpace", std::to_string(info.deviceStorage.usedSpace)}, - {"documentsSize", std::to_string(info.appStorage.documentsSize)}, - {"cacheSize", std::to_string(info.appStorage.cacheSize)}, - {"appSupportSize", std::to_string(info.appStorage.appSupportSize)}, - {"totalAppSize", std::to_string(info.appStorage.totalSize)}, - {"totalModelsSize", std::to_string(info.totalModelsSize)}, - {"modelCount", std::to_string(info.models.size())} + {"totalDeviceSpace", std::to_string(fmInfo.device_total)}, + {"freeDeviceSpace", std::to_string(fmInfo.device_free)}, + {"usedDeviceSpace", std::to_string(fmInfo.device_total - fmInfo.device_free)}, + {"documentsSize", std::to_string(fmInfo.models_size)}, + {"cacheSize", std::to_string(fmInfo.cache_size)}, + {"appSupportSize", std::to_string(fmInfo.temp_size)}, + {"totalAppSize", std::to_string(fmInfo.total_app_size)}, + {"totalModelsSize", std::to_string(fmInfo.models_size)}, + {"modelCount", std::to_string(storageInfo.models.size())} }); }); } @@ -1209,6 +1221,10 @@ std::shared_ptr> HybridRunAnywhereCore::clearCache() { // Clear the model assignment cache (in-memory cache for model assignments) rac_model_assignment_clear_cache(); + // Clear file cache and temp directories via C++ file manager + FileManagerBridge::shared().clearCache(); + FileManagerBridge::shared().clearTemp(); + LOGI("Cache cleared successfully"); return true; }); @@ -1218,7 +1234,19 @@ std::shared_ptr> HybridRunAnywhereCore::deleteModel( const std::string& modelId) { return Promise::async([modelId]() { LOGI("Deleting model: %s", modelId.c_str()); + + // Get framework from registry before removing, so we can delete files + auto modelInfo = ModelRegistryBridge::shared().getModel(modelId); + int framework = modelInfo ? static_cast(modelInfo->framework) : -1; + + // Remove from registry rac_result_t result = ModelRegistryBridge::shared().removeModel(modelId); + + // Delete files from disk + if (framework >= 0) { + FileManagerBridge::shared().deleteModel(modelId, framework); + } + return result == RAC_SUCCESS; }); } @@ -1310,47 +1338,29 @@ std::shared_ptr> HybridRunAnywhereCore::getLastError() { return Promise::async([this]() { return lastError_; }); } -// Forward declaration for platform-specific archive extraction -#if defined(__APPLE__) -extern "C" bool ArchiveUtility_extract(const char* archivePath, const char* destinationPath); -#elif defined(__ANDROID__) -// On Android, we'll call the Kotlin ArchiveUtility via JNI in a separate helper -extern "C" bool ArchiveUtility_extractAndroid(const char* archivePath, const char* destinationPath); -#endif - std::shared_ptr> HybridRunAnywhereCore::extractArchive( const std::string& archivePath, const std::string& destPath) { return Promise::async([this, archivePath, destPath]() { LOGI("extractArchive: %s -> %s", archivePath.c_str(), destPath.c_str()); -#if defined(__APPLE__) - // iOS: Call Swift ArchiveUtility - bool success = ArchiveUtility_extract(archivePath.c_str(), destPath.c_str()); - if (success) { - LOGI("iOS archive extraction succeeded"); - return true; - } else { - LOGE("iOS archive extraction failed"); - setLastError("Archive extraction failed"); - return false; - } -#elif defined(__ANDROID__) - // Android: Call Kotlin ArchiveUtility via JNI - bool success = ArchiveUtility_extractAndroid(archivePath.c_str(), destPath.c_str()); - if (success) { - LOGI("Android archive extraction succeeded"); + // Use native C++ extraction (libarchive) — works on all platforms + rac_result_t result = rac_extract_archive_native( + archivePath.c_str(), destPath.c_str(), + nullptr, // default options + nullptr, // no progress callback + nullptr, // no user data + nullptr // no result output + ); + + if (result == RAC_SUCCESS) { + LOGI("Native archive extraction succeeded"); return true; } else { - LOGE("Android archive extraction failed"); + LOGE("Native archive extraction failed with code: %d", result); setLastError("Archive extraction failed"); return false; } -#else - LOGW("Archive extraction not supported on this platform"); - setLastError("Archive extraction not supported"); - return false; -#endif }); } diff --git a/sdk/runanywhere-react-native/packages/core/cpp/bridges/FileManagerBridge.cpp b/sdk/runanywhere-react-native/packages/core/cpp/bridges/FileManagerBridge.cpp new file mode 100644 index 000000000..2fbd3b0c3 --- /dev/null +++ b/sdk/runanywhere-react-native/packages/core/cpp/bridges/FileManagerBridge.cpp @@ -0,0 +1,289 @@ +/** + * @file FileManagerBridge.cpp + * @brief C++ bridge for file manager operations. + * + * POSIX-based rac_file_callbacks_t implementation. + * Works on both iOS and Android (both are POSIX-compliant). + */ + +#include "FileManagerBridge.hpp" + +#include +#include +#include +#include +#include +#include +#include + +// Platform-specific logging +#if defined(ANDROID) || defined(__ANDROID__) +#include +#define LOG_TAG "FileManagerBridge" +#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__) +#define LOGD(...) __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, __VA_ARGS__) +#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__) +#else +#include +#define LOGI(...) printf("[FileManagerBridge] "); printf(__VA_ARGS__); printf("\n") +#define LOGD(...) printf("[FileManagerBridge DEBUG] "); printf(__VA_ARGS__); printf("\n") +#define LOGE(...) printf("[FileManagerBridge ERROR] "); printf(__VA_ARGS__); printf("\n") +#endif + +namespace runanywhere { +namespace bridges { + +// ============================================================================= +// POSIX Callback Implementations +// ============================================================================= + +static rac_result_t posixCreateDirectory(const char* path, int recursive, void* /*userData*/) { + if (!path) return RAC_ERROR_NULL_POINTER; + + if (recursive) { + // Create intermediate directories + std::string pathStr(path); + size_t pos = 0; + while ((pos = pathStr.find('/', pos + 1)) != std::string::npos) { + std::string subPath = pathStr.substr(0, pos); + mkdir(subPath.c_str(), 0755); + } + } + + if (mkdir(path, 0755) == 0 || errno == EEXIST) { + return RAC_SUCCESS; + } + return RAC_ERROR_FILE_IO; +} + +static rac_result_t posixDeletePath(const char* path, int recursive, void* /*userData*/) { + if (!path) return RAC_ERROR_NULL_POINTER; + + struct stat st; + if (lstat(path, &st) != 0) { + return RAC_SUCCESS; // Already gone + } + + // Handle symlinks: remove the link itself, don't follow it + if (S_ISLNK(st.st_mode)) { + return unlink(path) == 0 ? RAC_SUCCESS : RAC_ERROR_FILE_IO; + } + + if (S_ISDIR(st.st_mode)) { + if (recursive) { + // Recursively delete directory contents + DIR* dir = opendir(path); + if (!dir) return RAC_ERROR_FILE_IO; + + struct dirent* entry; + while ((entry = readdir(dir)) != nullptr) { + if (strcmp(entry->d_name, ".") == 0 || strcmp(entry->d_name, "..") == 0) { + continue; + } + std::string childPath = std::string(path) + "/" + entry->d_name; + posixDeletePath(childPath.c_str(), 1, nullptr); + } + closedir(dir); + } + return (rmdir(path) == 0) ? RAC_SUCCESS : RAC_ERROR_FILE_IO; + } else { + return (unlink(path) == 0) ? RAC_SUCCESS : RAC_ERROR_FILE_IO; + } +} + +static rac_result_t posixListDirectory(const char* path, char*** outEntries, + size_t* outCount, void* /*userData*/) { + if (!path || !outEntries || !outCount) return RAC_ERROR_NULL_POINTER; + + *outEntries = nullptr; + *outCount = 0; + + DIR* dir = opendir(path); + if (!dir) return RAC_ERROR_FILE_NOT_FOUND; + + // First pass: count entries + size_t count = 0; + struct dirent* entry; + while ((entry = readdir(dir)) != nullptr) { + if (strcmp(entry->d_name, ".") == 0 || strcmp(entry->d_name, "..") == 0) { + continue; + } + count++; + } + + if (count == 0) { + closedir(dir); + return RAC_SUCCESS; + } + + // Allocate array + char** entries = static_cast(malloc(count * sizeof(char*))); + if (!entries) { + closedir(dir); + return RAC_ERROR_FILE_IO; + } + + // Second pass: fill entries + rewinddir(dir); + size_t i = 0; + while ((entry = readdir(dir)) != nullptr && i < count) { + if (strcmp(entry->d_name, ".") == 0 || strcmp(entry->d_name, "..") == 0) { + continue; + } + entries[i] = strdup(entry->d_name); + i++; + } + closedir(dir); + + *outEntries = entries; + *outCount = i; + return RAC_SUCCESS; +} + +static void posixFreeEntries(char** entries, size_t count, void* /*userData*/) { + if (!entries) return; + for (size_t i = 0; i < count; i++) { + free(entries[i]); + } + free(entries); +} + +static rac_bool_t posixPathExists(const char* path, rac_bool_t* outIsDirectory, + void* /*userData*/) { + if (!path) return RAC_FALSE; + + struct stat st; + if (stat(path, &st) != 0) return RAC_FALSE; + + if (outIsDirectory) { + *outIsDirectory = S_ISDIR(st.st_mode) ? RAC_TRUE : RAC_FALSE; + } + return RAC_TRUE; +} + +static int64_t posixGetFileSize(const char* path, void* /*userData*/) { + if (!path) return -1; + + struct stat st; + if (stat(path, &st) != 0) return -1; + return static_cast(st.st_size); +} + +static int64_t posixGetAvailableSpace(void* /*userData*/) { + // Use root "/" on iOS, "/data" on Android +#if defined(ANDROID) || defined(__ANDROID__) + const char* mountPoint = "/data"; +#else + const char* mountPoint = "/"; +#endif + + struct statvfs vfs; + if (statvfs(mountPoint, &vfs) != 0) return 0; + return static_cast(vfs.f_bavail) * static_cast(vfs.f_frsize); +} + +static int64_t posixGetTotalSpace(void* /*userData*/) { +#if defined(ANDROID) || defined(__ANDROID__) + const char* mountPoint = "/data"; +#else + const char* mountPoint = "/"; +#endif + + struct statvfs vfs; + if (statvfs(mountPoint, &vfs) != 0) return 0; + return static_cast(vfs.f_blocks) * static_cast(vfs.f_frsize); +} + +// ============================================================================= +// FileManagerBridge Implementation +// ============================================================================= + +FileManagerBridge& FileManagerBridge::shared() { + static FileManagerBridge instance; + return instance; +} + +void FileManagerBridge::initialize() { + if (isInitialized_) { + LOGD("File manager bridge already initialized"); + return; + } + + // Set up POSIX-based callbacks + memset(&callbacks_, 0, sizeof(callbacks_)); + callbacks_.create_directory = posixCreateDirectory; + callbacks_.delete_path = posixDeletePath; + callbacks_.list_directory = posixListDirectory; + callbacks_.free_entries = posixFreeEntries; + callbacks_.path_exists = posixPathExists; + callbacks_.get_file_size = posixGetFileSize; + callbacks_.get_available_space = posixGetAvailableSpace; + callbacks_.get_total_space = posixGetTotalSpace; + callbacks_.user_data = nullptr; + + isInitialized_ = true; + LOGI("File manager bridge initialized with POSIX callbacks"); +} + +void FileManagerBridge::shutdown() { + isInitialized_ = false; + memset(&callbacks_, 0, sizeof(callbacks_)); + LOGI("File manager bridge shutdown"); +} + +bool FileManagerBridge::createDirectoryStructure() { + if (!isInitialized_) return false; + return rac_file_manager_create_directory_structure(&callbacks_) == RAC_SUCCESS; +} + +int64_t FileManagerBridge::calculateDirectorySize(const std::string& path) { + if (!isInitialized_) return 0; + + int64_t size = 0; + rac_result_t result = rac_file_manager_calculate_dir_size(&callbacks_, path.c_str(), &size); + return (result == RAC_SUCCESS) ? size : 0; +} + +int64_t FileManagerBridge::modelsStorageUsed() { + if (!isInitialized_) return 0; + + int64_t size = 0; + rac_result_t result = rac_file_manager_models_storage_used(&callbacks_, &size); + return (result == RAC_SUCCESS) ? size : 0; +} + +bool FileManagerBridge::clearCache() { + if (!isInitialized_) return false; + return rac_file_manager_clear_cache(&callbacks_) == RAC_SUCCESS; +} + +bool FileManagerBridge::clearTemp() { + if (!isInitialized_) return false; + return rac_file_manager_clear_temp(&callbacks_) == RAC_SUCCESS; +} + +int64_t FileManagerBridge::cacheSize() { + if (!isInitialized_) return 0; + + int64_t size = 0; + rac_result_t result = rac_file_manager_cache_size(&callbacks_, &size); + return (result == RAC_SUCCESS) ? size : 0; +} + +bool FileManagerBridge::deleteModel(const std::string& modelId, int framework) { + if (!isInitialized_) return false; + return rac_file_manager_delete_model( + &callbacks_, modelId.c_str(), + static_cast(framework)) == RAC_SUCCESS; +} + +rac_file_manager_storage_info_t FileManagerBridge::getStorageInfo() { + rac_file_manager_storage_info_t info = {}; + if (!isInitialized_) return info; + + rac_file_manager_get_storage_info(&callbacks_, &info); + return info; +} + +} // namespace bridges +} // namespace runanywhere diff --git a/sdk/runanywhere-react-native/packages/core/cpp/bridges/FileManagerBridge.hpp b/sdk/runanywhere-react-native/packages/core/cpp/bridges/FileManagerBridge.hpp new file mode 100644 index 000000000..3533246f7 --- /dev/null +++ b/sdk/runanywhere-react-native/packages/core/cpp/bridges/FileManagerBridge.hpp @@ -0,0 +1,113 @@ +/** + * @file FileManagerBridge.hpp + * @brief C++ bridge for file manager operations via rac_file_manager_* API. + * + * Uses POSIX I/O directly (works on both iOS and Android) to implement + * rac_file_callbacks_t. C++ handles all business logic (recursive traversal, + * cache clearing, storage info); no JS callbacks needed. + * + * Reference: sdk/runanywhere-commons/include/rac/infrastructure/file_management/rac_file_manager.h + */ + +#pragma once + +#include +#include + +#include "rac_types.h" +#include "rac_file_manager.h" + +namespace runanywhere { +namespace bridges { + +/** + * FileManagerBridge - File management via rac_file_manager_* API + * + * Provides POSIX-based rac_file_callbacks_t implementation so C++ + * handles all recursion and business logic. SDKs only call high-level methods. + */ +class FileManagerBridge { +public: + /** + * Get shared instance + */ + static FileManagerBridge& shared(); + + /** + * Initialize the file manager bridge. + * Sets up POSIX-based file callbacks. + */ + void initialize(); + + /** + * Shutdown and cleanup + */ + void shutdown(); + + /** + * Check if initialized + */ + bool isInitialized() const { return isInitialized_; } + + /** + * Get the file callbacks struct (for use by StorageBridge if needed) + */ + const rac_file_callbacks_t* getCallbacks() const { return &callbacks_; } + + // ========================================================================= + // Public API (wraps rac_file_manager_* functions) + // ========================================================================= + + /** + * Create standard directory structure (Models/Cache/Temp/Downloads) + */ + bool createDirectoryStructure(); + + /** + * Calculate directory size recursively (in C++) + * Replaces FileSystem.ts getDirectorySize() and calculateExtractionStats() + */ + int64_t calculateDirectorySize(const std::string& path); + + /** + * Get total models storage used + */ + int64_t modelsStorageUsed(); + + /** + * Clear cache directory (delete + recreate) + */ + bool clearCache(); + + /** + * Clear temp directory (delete + recreate) + */ + bool clearTemp(); + + /** + * Get cache directory size + */ + int64_t cacheSize(); + + /** + * Delete a model folder + */ + bool deleteModel(const std::string& modelId, int framework); + + /** + * Get combined storage info (device + app) + */ + rac_file_manager_storage_info_t getStorageInfo(); + +private: + FileManagerBridge() = default; + ~FileManagerBridge() = default; + FileManagerBridge(const FileManagerBridge&) = delete; + FileManagerBridge& operator=(const FileManagerBridge&) = delete; + + bool isInitialized_ = false; + rac_file_callbacks_t callbacks_{}; +}; + +} // namespace bridges +} // namespace runanywhere diff --git a/sdk/runanywhere-react-native/packages/core/ios/ArchiveUtility.swift b/sdk/runanywhere-react-native/packages/core/ios/ArchiveUtility.swift deleted file mode 100644 index cf890a1c9..000000000 --- a/sdk/runanywhere-react-native/packages/core/ios/ArchiveUtility.swift +++ /dev/null @@ -1,557 +0,0 @@ -/** - * ArchiveUtility.swift - * - * Native archive extraction utility for React Native. - * Uses Apple's native Compression framework for gzip decompression (fast) - * and pure Swift tar extraction. - * - * Mirrors the implementation from: - * sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/Download/Utilities/ArchiveUtility.swift - * - * Supports: tar.gz, zip - * Note: All models should use tar.gz from RunanywhereAI/sherpa-onnx fork for best performance - */ - -import Compression -import Foundation - -/// Archive extraction errors -public enum ArchiveError: Error, LocalizedError { - case invalidArchive(String) - case decompressionFailed(String) - case extractionFailed(String) - case unsupportedFormat(String) - case fileNotFound(String) - - public var errorDescription: String? { - switch self { - case .invalidArchive(let msg): return "Invalid archive: \(msg)" - case .decompressionFailed(let msg): return "Decompression failed: \(msg)" - case .extractionFailed(let msg): return "Extraction failed: \(msg)" - case .unsupportedFormat(let msg): return "Unsupported format: \(msg)" - case .fileNotFound(let msg): return "File not found: \(msg)" - } - } -} - -/// Utility for handling archive extraction -@objc public final class ArchiveUtility: NSObject { - - // MARK: - Public API - - /// Extract an archive to a destination directory - /// - Parameters: - /// - archivePath: Path to the archive file - /// - destinationPath: Destination directory path - /// - Returns: true if extraction succeeded - @objc public static func extract( - archivePath: String, - to destinationPath: String - ) -> Bool { - do { - try extractArchive(archivePath: archivePath, to: destinationPath) - return true - } catch { - SDKLogger.archive.logError(error, additionalInfo: "Extraction failed") - return false - } - } - - /// Extract an archive to a destination directory (throwing version) - public static func extractArchive( - archivePath: String, - to destinationPath: String, - progressHandler: ((Double) -> Void)? = nil - ) throws { - let archiveURL = URL(fileURLWithPath: archivePath) - let destinationURL = URL(fileURLWithPath: destinationPath) - - // Ensure archive exists - guard FileManager.default.fileExists(atPath: archivePath) else { - throw ArchiveError.fileNotFound("Archive not found: \(archivePath)") - } - - // Detect archive type by magic bytes (more reliable than file extension) - let archiveType = try detectArchiveTypeByMagicBytes(archivePath) - SDKLogger.archive.info("Detected archive type: \(archiveType) for: \(archivePath)") - - switch archiveType { - case .gzip: - try extractTarGz(from: archiveURL, to: destinationURL, progressHandler: progressHandler) - case .zip: - try extractZip(from: archiveURL, to: destinationURL, progressHandler: progressHandler) - case .bzip2: - throw ArchiveError.unsupportedFormat("tar.bz2 not supported. Use tar.gz from RunanywhereAI/sherpa-onnx fork.") - case .xz: - throw ArchiveError.unsupportedFormat("tar.xz not supported. Use tar.gz from RunanywhereAI/sherpa-onnx fork.") - case .unknown: - // Fallback to file extension check - let lowercased = archivePath.lowercased() - if lowercased.hasSuffix(".tar.gz") || lowercased.hasSuffix(".tgz") { - try extractTarGz(from: archiveURL, to: destinationURL, progressHandler: progressHandler) - } else if lowercased.hasSuffix(".zip") { - try extractZip(from: archiveURL, to: destinationURL, progressHandler: progressHandler) - } else { - throw ArchiveError.unsupportedFormat("Unknown archive format: \(archivePath)") - } - } - } - - /// Archive type detected by magic bytes - private enum DetectedArchiveType { - case gzip - case zip - case bzip2 - case xz - case unknown - } - - /// Detect archive type by reading magic bytes from file header - private static func detectArchiveTypeByMagicBytes(_ path: String) throws -> DetectedArchiveType { - guard let fileHandle = FileHandle(forReadingAtPath: path) else { - throw ArchiveError.fileNotFound("Cannot open file: \(path)") - } - defer { try? fileHandle.close() } - - // Read first 6 bytes for magic number detection - guard let headerData = try? fileHandle.read(upToCount: 6), headerData.count >= 2 else { - return .unknown - } - - // Check for gzip: 0x1f 0x8b - if headerData[0] == 0x1f && headerData[1] == 0x8b { - return .gzip - } - - // Check for zip: 0x50 0x4b 0x03 0x04 ("PK\x03\x04") - if headerData.count >= 4 && - headerData[0] == 0x50 && headerData[1] == 0x4b && - headerData[2] == 0x03 && headerData[3] == 0x04 { - return .zip - } - - // Check for bzip2: 0x42 0x5a ("BZ") - if headerData[0] == 0x42 && headerData[1] == 0x5a { - return .bzip2 - } - - // Check for xz: 0xfd 0x37 0x7a 0x58 0x5a 0x00 - if headerData.count >= 6 && - headerData[0] == 0xfd && headerData[1] == 0x37 && - headerData[2] == 0x7a && headerData[3] == 0x58 && - headerData[4] == 0x5a && headerData[5] == 0x00 { - return .xz - } - - return .unknown - } - - // MARK: - tar.gz Extraction (Native Compression Framework) - - /// Extract a tar.gz archive using streaming decompression to keep memory constant. - /// Decompresses gzip to a temporary tar file on disk, then extracts tar entries. - private static func extractTarGz( - from sourceURL: URL, - to destinationURL: URL, - progressHandler: ((Double) -> Void)? - ) throws { - let overallStart = Date() - SDKLogger.archive.info("Extracting tar.gz: \(sourceURL.lastPathComponent)") - progressHandler?(0.0) - - // Step 1: Stream-decompress gzip to a temporary tar file on disk. - // This avoids holding both compressed + decompressed data in memory simultaneously. - let tempTarURL = destinationURL.appendingPathExtension("tar.tmp") - defer { try? FileManager.default.removeItem(at: tempTarURL) } - - let decompressStart = Date() - SDKLogger.archive.info("Starting streaming gzip decompression...") - try decompressGzipToFile(from: sourceURL, to: tempTarURL) - let decompressTime = Date().timeIntervalSince(decompressStart) - let tarFileSize = (try? FileManager.default.attributesOfItem(atPath: tempTarURL.path)[.size] as? Int) ?? 0 - SDKLogger.archive.info("Decompressed to \(formatBytes(tarFileSize)) in \(String(format: "%.2f", decompressTime))s") - progressHandler?(0.3) - - // Step 2: Read tar data and extract entries - let extractStart = Date() - SDKLogger.archive.info("Extracting tar data...") - let tarData = try Data(contentsOf: tempTarURL) - try extractTarData(tarData, to: destinationURL, progressHandler: { progress in - progressHandler?(0.3 + progress * 0.7) - }) - let extractTime = Date().timeIntervalSince(extractStart) - SDKLogger.archive.info("Tar extract completed in \(String(format: "%.2f", extractTime))s") - - let totalTime = Date().timeIntervalSince(overallStart) - SDKLogger.archive.info("Total extraction time: \(String(format: "%.2f", totalTime))s") - progressHandler?(1.0) - } - - /// Stream-decompress a gzip file to an output file using compression_stream_process. - /// Peak memory usage is ~512 KB (input + output buffers) regardless of file size. - private static func decompressGzipToFile(from sourceURL: URL, to destinationURL: URL) throws { - guard let inputHandle = FileHandle(forReadingAtPath: sourceURL.path) else { - throw ArchiveError.fileNotFound("Cannot open: \(sourceURL.path)") - } - defer { try? inputHandle.close() } - - // Parse gzip header to find where the deflate stream begins - let headerOffset = try parseGzipHeader(from: inputHandle) - inputHandle.seek(toFileOffset: UInt64(headerOffset)) - - // The deflate stream ends 8 bytes before EOF (CRC32 + ISIZE trailer) - let attrs = try FileManager.default.attributesOfItem(atPath: sourceURL.path) - let fileSize = (attrs[.size] as? UInt64) ?? 0 - guard fileSize > UInt64(headerOffset) + 8 else { - throw ArchiveError.invalidArchive("Gzip file too small for valid deflate stream") - } - let deflateEndOffset = fileSize - 8 - - // Create output file - FileManager.default.createFile(atPath: destinationURL.path, contents: nil) - guard let outputHandle = FileHandle(forWritingAtPath: destinationURL.path) else { - throw ArchiveError.extractionFailed("Cannot create temp file: \(destinationURL.path)") - } - defer { try? outputHandle.close() } - - // Initialize streaming decompression - var stream = compression_stream() - guard compression_stream_init(&stream, COMPRESSION_STREAM_DECODE, COMPRESSION_ZLIB) == COMPRESSION_STATUS_OK else { - throw ArchiveError.decompressionFailed("Failed to initialize decompression stream") - } - defer { compression_stream_destroy(&stream) } - - let chunkSize = 256 * 1024 // 256 KB - let inputBuffer = UnsafeMutablePointer.allocate(capacity: chunkSize) - let outputBuffer = UnsafeMutablePointer.allocate(capacity: chunkSize) - defer { - inputBuffer.deallocate() - outputBuffer.deallocate() - } - - stream.src_size = 0 - var finished = false - - while !finished { - // Feed more input when the previous chunk is consumed - if stream.src_size == 0 { - let currentOffset = inputHandle.offsetInFile - if currentOffset >= deflateEndOffset { - finished = true - } else { - let bytesToRead = min(UInt64(chunkSize), deflateEndOffset - currentOffset) - let chunk = inputHandle.readData(ofLength: Int(bytesToRead)) - if chunk.isEmpty { - finished = true - } else { - chunk.copyBytes(to: inputBuffer, count: chunk.count) - stream.src_ptr = UnsafePointer(inputBuffer) - stream.src_size = chunk.count - } - } - } - - stream.dst_ptr = outputBuffer - stream.dst_size = chunkSize - - let flags: Int32 = finished ? Int32(COMPRESSION_STREAM_FINALIZE.rawValue) : 0 - let status = compression_stream_process(&stream, flags) - - let bytesProduced = chunkSize - stream.dst_size - if bytesProduced > 0 { - outputHandle.write(Data(bytes: outputBuffer, count: bytesProduced)) - } - - switch status { - case COMPRESSION_STATUS_OK: - continue - case COMPRESSION_STATUS_END: - finished = true - case COMPRESSION_STATUS_ERROR: - throw ArchiveError.decompressionFailed("Streaming decompression error") - default: - throw ArchiveError.decompressionFailed("Unexpected compression status: \(status)") - } - } - } - - /// Parse the gzip header and return the byte offset where the deflate stream begins. - private static func parseGzipHeader(from handle: FileHandle) throws -> Int { - handle.seek(toFileOffset: 0) - guard let header = try? handle.read(upToCount: 10), header.count >= 10 else { - throw ArchiveError.invalidArchive("Gzip data too short") - } - guard header[0] == 0x1f && header[1] == 0x8b else { - throw ArchiveError.invalidArchive("Invalid gzip magic number") - } - guard header[2] == 8 else { - throw ArchiveError.invalidArchive("Unsupported gzip compression method") - } - - let flags = header[3] - var offset = 10 - - if (flags & 0x04) != 0 { // FEXTRA - handle.seek(toFileOffset: UInt64(offset)) - guard let extraLenData = try? handle.read(upToCount: 2), extraLenData.count >= 2 else { - throw ArchiveError.invalidArchive("Truncated gzip header (FEXTRA)") - } - let extraLen = Int(extraLenData[0]) | (Int(extraLenData[1]) << 8) - offset += 2 + extraLen - } - - if (flags & 0x08) != 0 { // FNAME - handle.seek(toFileOffset: UInt64(offset)) - while true { - guard let byte = try? handle.read(upToCount: 1), byte.count == 1 else { break } - offset += 1 - if byte[0] == 0 { break } - } - } - - if (flags & 0x10) != 0 { // FCOMMENT - handle.seek(toFileOffset: UInt64(offset)) - while true { - guard let byte = try? handle.read(upToCount: 1), byte.count == 1 else { break } - offset += 1 - if byte[0] == 0 { break } - } - } - - if (flags & 0x02) != 0 { // FHCRC - offset += 2 - } - - return offset - } - - // MARK: - ZIP Extraction (Pure Swift using Foundation) - - private static func extractZip( - from sourceURL: URL, - to destinationURL: URL, - progressHandler: ((Double) -> Void)? - ) throws { - SDKLogger.archive.info("Extracting zip: \(sourceURL.lastPathComponent)") - progressHandler?(0.0) - - // Create destination directory - try FileManager.default.createDirectory(at: destinationURL, withIntermediateDirectories: true) - - // Read zip file - guard let archive = try? Data(contentsOf: sourceURL) else { - throw ArchiveError.fileNotFound("Cannot read zip file: \(sourceURL.path)") - } - - // Parse and extract ZIP using pure Swift - var offset = 0 - var fileCount = 0 - let totalSize = archive.count - - while offset < archive.count - 4 { - // Check for local file header signature (0x04034b50 = PK\x03\x04) - let sig0 = archive[offset] - let sig1 = archive[offset + 1] - let sig2 = archive[offset + 2] - let sig3 = archive[offset + 3] - - if sig0 == 0x50 && sig1 == 0x4b && sig2 == 0x03 && sig3 == 0x04 { - // Local file header - let compressionMethod = UInt16(archive[offset + 8]) | (UInt16(archive[offset + 9]) << 8) - let compressedSize = UInt32(archive[offset + 18]) | - (UInt32(archive[offset + 19]) << 8) | - (UInt32(archive[offset + 20]) << 16) | - (UInt32(archive[offset + 21]) << 24) - let uncompressedSize = UInt32(archive[offset + 22]) | - (UInt32(archive[offset + 23]) << 8) | - (UInt32(archive[offset + 24]) << 16) | - (UInt32(archive[offset + 25]) << 24) - let fileNameLength = UInt16(archive[offset + 26]) | (UInt16(archive[offset + 27]) << 8) - let extraFieldLength = UInt16(archive[offset + 28]) | (UInt16(archive[offset + 29]) << 8) - - let headerEnd = offset + 30 - let fileNameData = archive.subdata(in: headerEnd..<(headerEnd + Int(fileNameLength))) - let fileName = String(data: fileNameData, encoding: .utf8) ?? "" - - let dataStart = headerEnd + Int(fileNameLength) + Int(extraFieldLength) - let dataEnd = dataStart + Int(compressedSize) - - let filePath = destinationURL.appendingPathComponent(fileName) - - if fileName.hasSuffix("/") { - // Directory - try FileManager.default.createDirectory(at: filePath, withIntermediateDirectories: true) - } else if !fileName.isEmpty && !fileName.hasPrefix("__MACOSX") { - // File - try FileManager.default.createDirectory(at: filePath.deletingLastPathComponent(), withIntermediateDirectories: true) - - if compressionMethod == 0 { - // Stored (no compression) - let fileData = archive.subdata(in: dataStart.. Data? { - var destinationBufferSize = max(uncompressedSize, data.count * 4) - var decompressedData = Data(count: destinationBufferSize) - - let decompressedSize = data.withUnsafeBytes { (srcPtr: UnsafeRawBufferPointer) -> Int in - guard let sourceAddress = srcPtr.baseAddress else { return 0 } - - return decompressedData.withUnsafeMutableBytes { (destPtr: UnsafeMutableRawBufferPointer) -> Int in - guard let destAddress = destPtr.baseAddress else { return 0 } - - // Use COMPRESSION_ZLIB for raw deflate - return compression_decode_buffer( - destAddress.assumingMemoryBound(to: UInt8.self), - destinationBufferSize, - sourceAddress.assumingMemoryBound(to: UInt8.self), - data.count, - nil, - COMPRESSION_ZLIB - ) - } - } - - guard decompressedSize > 0 else { return nil } - decompressedData.count = decompressedSize - return decompressedData - } - - // MARK: - TAR Extraction (Pure Swift) - - /// Extract tar data to destination directory - private static func extractTarData( - _ tarData: Data, - to destinationURL: URL, - progressHandler: ((Double) -> Void)? - ) throws { - // Create destination directory - try FileManager.default.createDirectory(at: destinationURL, withIntermediateDirectories: true) - - var offset = 0 - let totalSize = tarData.count - var fileCount = 0 - - while offset + 512 <= tarData.count { - // Read tar header (512 bytes) - let headerData = tarData.subdata(in: offset..<(offset + 512)) - - // Check for end of archive (two consecutive zero blocks) - if headerData.allSatisfy({ $0 == 0 }) { - break - } - - // Parse header - let nameData = headerData.subdata(in: 0..<100) - let sizeData = headerData.subdata(in: 124..<136) - let typeFlag = headerData[156] - let prefixData = headerData.subdata(in: 345..<500) - - // Get file name - let name = extractNullTerminatedString(from: nameData) - let prefix = extractNullTerminatedString(from: prefixData) - let fullName = prefix.isEmpty ? name : "\(prefix)/\(name)" - - // Skip if name is empty or is macOS resource fork - guard !fullName.isEmpty, !fullName.hasPrefix("._") else { - offset += 512 - continue - } - - // Parse file size (octal) - let sizeString = extractNullTerminatedString(from: sizeData).trimmingCharacters(in: .whitespaces) - let fileSize = Int(sizeString, radix: 8) ?? 0 - - offset += 512 // Move past header - - let filePath = destinationURL.appendingPathComponent(fullName) - - // Handle different entry types - if typeFlag == 0x35 || (typeFlag == 0x30 && fullName.hasSuffix("/")) { // Directory - try FileManager.default.createDirectory(at: filePath, withIntermediateDirectories: true) - } else if typeFlag == 0x30 || typeFlag == 0 { // Regular file - // Ensure parent directory exists - try FileManager.default.createDirectory(at: filePath.deletingLastPathComponent(), withIntermediateDirectories: true) - - // Extract file data - if fileSize > 0 && offset + fileSize <= tarData.count { - let fileData = tarData.subdata(in: offset..<(offset + fileSize)) - try fileData.write(to: filePath) - } else { - // Create empty file - FileManager.default.createFile(atPath: filePath.path, contents: nil) - } - fileCount += 1 - } else if typeFlag == 0x32 { // Symbolic link - let linkName = extractNullTerminatedString(from: headerData.subdata(in: 157..<257)) - if !linkName.isEmpty { - try FileManager.default.createDirectory(at: filePath.deletingLastPathComponent(), withIntermediateDirectories: true) - try? FileManager.default.createSymbolicLink(atPath: filePath.path, withDestinationPath: linkName) - } - } - - // Move to next entry (file data + padding to 512-byte boundary) - offset += fileSize - let padding = (512 - (fileSize % 512)) % 512 - offset += padding - - // Report progress - progressHandler?(Double(offset) / Double(totalSize)) - } - - SDKLogger.archive.info("Extracted \(fileCount) files") - } - - // MARK: - Helpers - - private static func extractNullTerminatedString(from data: Data) -> String { - if let nullIndex = data.firstIndex(of: 0) { - return String(data: data.subdata(in: 0.. String { - if bytes < 1024 { - return "\(bytes) B" - } else if bytes < 1024 * 1024 { - return String(format: "%.1f KB", Double(bytes) / 1024) - } else if bytes < 1024 * 1024 * 1024 { - return String(format: "%.1f MB", Double(bytes) / (1024 * 1024)) - } else { - return String(format: "%.2f GB", Double(bytes) / (1024 * 1024 * 1024)) - } - } -} diff --git a/sdk/runanywhere-react-native/packages/core/ios/ArchiveUtilityBridge.m b/sdk/runanywhere-react-native/packages/core/ios/ArchiveUtilityBridge.m deleted file mode 100644 index 7955617db..000000000 --- a/sdk/runanywhere-react-native/packages/core/ios/ArchiveUtilityBridge.m +++ /dev/null @@ -1,52 +0,0 @@ -/** - * ArchiveUtilityBridge.m - * - * C bridge to call Swift ArchiveUtility from C++. - * This bridge is necessary because C++ cannot directly call Swift code. - */ - -#import -#import "RNSDKLoggerBridge.h" - -static NSString * const kLogCategory = @"ArchiveBridge"; - -// Import the generated Swift header from the pod -#if __has_include() -#import -#elif __has_include("RunAnywhereCore-Swift.h") -#import "RunAnywhereCore-Swift.h" -#else -// Forward declare the Swift class if header not found -@interface ArchiveUtility : NSObject -+ (BOOL)extractWithArchivePath:(NSString * _Nonnull)archivePath to:(NSString * _Nonnull)destinationPath; -@end -#endif - -/** - * Extract an archive to a destination directory - * Called from C++ HybridRunAnywhereCore::extractArchive - */ -bool ArchiveUtility_extract(const char* archivePath, const char* destinationPath) { - @autoreleasepool { - if (archivePath == NULL || destinationPath == NULL) { - RN_LOG_ERROR(kLogCategory, @"Invalid null path"); - return false; - } - - NSString* archivePathStr = [NSString stringWithUTF8String:archivePath]; - NSString* destinationPathStr = [NSString stringWithUTF8String:destinationPath]; - - if (archivePathStr == nil || destinationPathStr == nil) { - RN_LOG_ERROR(kLogCategory, @"Failed to create NSString from path"); - return false; - } - - @try { - BOOL result = [ArchiveUtility extractWithArchivePath:archivePathStr to:destinationPathStr]; - return result; - } @catch (NSException *exception) { - RN_LOG_ERROR(kLogCategory, @"Exception: %@", exception); - return false; - } - } -} diff --git a/sdk/runanywhere-react-native/packages/core/package.json b/sdk/runanywhere-react-native/packages/core/package.json index e7b2f10e9..3a920ba44 100644 --- a/sdk/runanywhere-react-native/packages/core/package.json +++ b/sdk/runanywhere-react-native/packages/core/package.json @@ -52,8 +52,7 @@ "react-native-blob-util": ">=0.19.0", "react-native-device-info": ">=11.0.0", "react-native-fs": ">=2.20.0", - "react-native-nitro-modules": ">=0.31.3", - "react-native-zip-archive": ">=6.1.0" + "react-native-nitro-modules": ">=0.31.3" }, "peerDependenciesMeta": { "react-native-blob-util": { @@ -64,9 +63,6 @@ }, "react-native-fs": { "optional": true - }, - "react-native-zip-archive": { - "optional": true } }, "devDependencies": { diff --git a/sdk/runanywhere-react-native/packages/core/src/Public/Extensions/RunAnywhere+Storage.ts b/sdk/runanywhere-react-native/packages/core/src/Public/Extensions/RunAnywhere+Storage.ts index d027065b6..4c8ca7d76 100644 --- a/sdk/runanywhere-react-native/packages/core/src/Public/Extensions/RunAnywhere+Storage.ts +++ b/sdk/runanywhere-react-native/packages/core/src/Public/Extensions/RunAnywhere+Storage.ts @@ -2,7 +2,7 @@ * RunAnywhere+Storage.ts * * Storage management extension. - * Uses react-native-fs via FileSystem service. + * Delegates to C++ via native module for storage info (C++ handles recursive traversal). * * Reference: sdk/runanywhere-swift/Sources/RunAnywhere/Public/Extensions/Storage/RunAnywhere+Storage.swift */ @@ -10,6 +10,7 @@ import { ModelRegistry } from '../../services/ModelRegistry'; import { FileSystem } from '../../services/FileSystem'; import { SDKLogger } from '../../Foundation/Logging/Logger/SDKLogger'; +import { requireNativeModule, isNativeModuleAvailable } from '../../native'; const logger = new SDKLogger('RunAnywhere.Storage'); @@ -67,7 +68,8 @@ export async function getModelsDirectory(): Promise { /** * Get storage information - * Returns structure matching Swift's StorageInfo + * Delegates to C++ FileManagerBridge for recursive directory traversal. + * Returns structure matching Swift's StorageInfo. */ export async function getStorageInfo(): Promise { const emptyResult: StorageInfo = { @@ -78,75 +80,46 @@ export async function getStorageInfo(): Promise { totalModelsSize: 0, }; - if (!FileSystem.isAvailable()) { - return emptyResult; - } - try { - const freeSpace = await FileSystem.getAvailableDiskSpace(); - const totalSpace = await FileSystem.getTotalDiskSpace(); - const usedSpace = totalSpace - freeSpace; - - // Get models directory size - let modelsSize = 0; - let modelCount = 0; - try { - const modelsDir = FileSystem.getModelsDirectory(); - const exists = await FileSystem.directoryExists(modelsDir); - if (exists) { - modelsSize = await FileSystem.getDirectorySize(modelsDir); - const files = await FileSystem.listDirectory(modelsDir); - modelCount = files.length; - } - } catch { - // Models directory may not exist yet + // Use native module (C++ FileManagerBridge handles recursive traversal) + if (isNativeModuleAvailable()) { + const native = requireNativeModule(); + const json = await native.getStorageInfo(); + const info = JSON.parse(json); + + const totalDeviceSpace = parseInt(info.totalDeviceSpace || '0', 10); + const freeDeviceSpace = parseInt(info.freeDeviceSpace || '0', 10); + const usedDeviceSpace = parseInt(info.usedDeviceSpace || '0', 10); + const documentsSize = parseInt(info.documentsSize || '0', 10); + const cacheSz = parseInt(info.cacheSize || '0', 10); + const appSupportSize = parseInt(info.appSupportSize || '0', 10); + const totalAppSize = parseInt(info.totalAppSize || '0', 10); + const totalModelsSize = parseInt(info.totalModelsSize || '0', 10); + const modelCount = parseInt(info.modelCount || '0', 10); + + return { + deviceStorage: { + totalSpace: totalDeviceSpace, + freeSpace: freeDeviceSpace, + usedSpace: usedDeviceSpace, + }, + appStorage: { + documentsSize, + cacheSize: cacheSz, + appSupportSize, + totalSize: totalAppSize, + }, + modelStorage: { + totalSize: totalModelsSize, + modelCount, + }, + cacheSize: cacheSz, + totalModelsSize, + }; } - // Get cache size - let cacheSize = 0; - try { - const cacheDir = FileSystem.getCacheDirectory(); - const exists = await FileSystem.directoryExists(cacheDir); - if (exists) { - cacheSize = await FileSystem.getDirectorySize(cacheDir); - } - } catch { - // Cache directory may not exist - } - - // Get app documents size (RunAnywhere directory) - let documentsSize = 0; - try { - const docsDir = FileSystem.getRunAnywhereDirectory(); - const exists = await FileSystem.directoryExists(docsDir); - if (exists) { - documentsSize = await FileSystem.getDirectorySize(docsDir); - } - } catch { - // Documents directory may not exist - } - - const totalAppSize = documentsSize + cacheSize; - - return { - deviceStorage: { - totalSpace, - freeSpace, - usedSpace, - }, - appStorage: { - documentsSize, - cacheSize, - appSupportSize: 0, - totalSize: totalAppSize, - }, - modelStorage: { - totalSize: modelsSize, - modelCount, - }, - cacheSize, - totalModelsSize: modelsSize, - }; + // Native module is required for storage info (C++ handles recursive traversal) + return emptyResult; } catch (error) { logger.warning('Failed to get storage info:', { error }); return emptyResult; @@ -155,8 +128,21 @@ export async function getStorageInfo(): Promise { /** * Clear cache + * Delegates to C++ FileManagerBridge for file cache/temp clearing. */ export async function clearCache(): Promise { + // Clear in-memory model registry cache ModelRegistry.reset(); + + // Clear file caches via native module (C++ handles directory clearing) + if (isNativeModuleAvailable()) { + try { + const native = requireNativeModule(); + await native.clearCache(); + } catch (error) { + logger.warning('Failed to clear native cache:', { error }); + } + } + logger.info('Cache cleared'); } diff --git a/sdk/runanywhere-react-native/packages/core/src/services/FileSystem.ts b/sdk/runanywhere-react-native/packages/core/src/services/FileSystem.ts index 5fe867b65..d81134401 100644 --- a/sdk/runanywhere-react-native/packages/core/src/services/FileSystem.ts +++ b/sdk/runanywhere-react-native/packages/core/src/services/FileSystem.ts @@ -103,20 +103,6 @@ try { logger.warning('react-native-fs not installed, file operations will be limited'); } -// Try to import react-native-zip-archive -let ZipArchive: { - unzip: (source: string, target: string) => Promise; - unzipWithPassword: (source: string, target: string, password: string) => Promise; - unzipAssets: (assetPath: string, target: string) => Promise; - subscribe: (callback: (event: { progress: number; filePath: string }) => void) => { remove: () => void }; -} | null = null; -try { - // eslint-disable-next-line @typescript-eslint/no-require-imports - ZipArchive = require('react-native-zip-archive'); -} catch { - logger.warning('react-native-zip-archive not installed, archive extraction will be limited'); -} - // Constants matching Swift SDK path structure const RUN_ANYWHERE_DIR = 'RunAnywhere'; const MODELS_DIR = 'Models'; @@ -161,65 +147,20 @@ export interface DownloadProgress { } /** - * Archive types supported for extraction - * Matches Swift SDK's ArchiveType enum - */ -export enum ArchiveType { - Zip = 'zip', - TarBz2 = 'tar.bz2', - TarGz = 'tar.gz', - TarXz = 'tar.xz', -} - -/** - * Describes the internal structure of an archive after extraction - * Matches Swift SDK's ArchiveStructure enum - */ -export enum ArchiveStructure { - SingleFileNested = 'singleFileNested', - DirectoryBased = 'directoryBased', - NestedDirectory = 'nestedDirectory', - Unknown = 'unknown', -} - -/** - * Model artifact type - describes how a model is packaged - * Matches Swift SDK's ModelArtifactType enum - */ -export type ModelArtifactType = - | { type: 'singleFile' } - | { type: 'archive'; archiveType: ArchiveType; structure: ArchiveStructure } - | { type: 'multiFile'; files: string[] } - | { type: 'custom'; strategyId: string } - | { type: 'builtIn' }; - -/** - * Extraction result - */ -export interface ExtractionResult { - modelPath: string; - extractedSize: number; - fileCount: number; -} - -/** - * Infer archive type from URL + * Check if a URL points to an archive that needs extraction. + * C++ handles actual format detection via rac_detect_archive_type(). */ -function inferArchiveType(url: string): ArchiveType | null { +function isArchiveUrl(url: string): boolean { const lowercased = url.toLowerCase(); - if (lowercased.includes('.tar.bz2') || lowercased.includes('.tbz2')) { - return ArchiveType.TarBz2; - } - if (lowercased.includes('.tar.gz') || lowercased.includes('.tgz')) { - return ArchiveType.TarGz; - } - if (lowercased.includes('.tar.xz') || lowercased.includes('.txz')) { - return ArchiveType.TarXz; - } - if (lowercased.includes('.zip')) { - return ArchiveType.Zip; - } - return null; + return ( + lowercased.includes('.tar.bz2') || + lowercased.includes('.tbz2') || + lowercased.includes('.tar.gz') || + lowercased.includes('.tgz') || + lowercased.includes('.tar.xz') || + lowercased.includes('.txz') || + lowercased.includes('.zip') + ); } /** @@ -326,21 +267,7 @@ export const FileSystem = { return `${folder}/${baseId}${ext}`; } - // For ONNX, check if the model is in a nested directory structure - if (RNFS) { - try { - const exists = await RNFS.exists(folder); - if (exists) { - // Find the actual model path (handles nested directory structures) - const modelPath = await this.findModelPathAfterExtraction(folder); - return modelPath; - } - } catch { - // Fall through to return the default folder - } - } - - // Directory-based model (ONNX) + // Directory-based model (ONNX) — return the folder return folder; }, @@ -364,10 +291,8 @@ export const FileSystem = { if (files.length === 0) return false; if (fw === 'ONNX') { - // For ONNX, we need to check if there are actual model files (not just an archive) - // ONNX models should have .onnx files after extraction - const hasOnnxFiles = await this.hasModelFiles(folder); - return hasOnnxFiles; + // For ONNX, directory has contents — model is present + return true; } return true; @@ -376,35 +301,6 @@ export const FileSystem = { } }, - /** - * Recursively check if a folder contains model files - */ - async hasModelFiles(folder: string): Promise { - if (!RNFS) return false; - - try { - const contents = await RNFS.readDir(folder); - - for (const item of contents) { - if (item.isFile()) { - const name = item.name.toLowerCase(); - // Check for actual model files, not archive files - if (name.endsWith('.onnx') || name.endsWith('.bin') || name.endsWith('.txt')) { - return true; - } - } else if (item.isDirectory()) { - // Check nested directories - const hasFiles = await this.hasModelFiles(item.path); - if (hasFiles) return true; - } - } - - return false; - } catch { - return false; - } - }, - /** * Create directory if it doesn't exist */ @@ -446,8 +342,8 @@ export const FileSystem = { // Determine destination path let destPath: string; - const archiveType = inferArchiveType(url); -if (fw === 'LlamaCpp' && archiveType === null) { + const needsExtraction = isArchiveUrl(url); + if (fw === 'LlamaCpp' && !needsExtraction) { // Single GGUF/BIN file (not an archive) const ext = modelId.includes('.gguf') || url.includes('.gguf') @@ -456,7 +352,7 @@ if (fw === 'LlamaCpp' && archiveType === null) { ? '.bin' : '.gguf'; destPath = `${folder}/${baseId}${ext}`; - } else if (fw === 'ONNX' && archiveType === null) { + } else if (fw === 'ONNX' && !needsExtraction) { // ONNX single-file model (.onnx) const ext = modelId.includes('.onnx') || url.includes('.onnx') ? '.onnx' : ''; destPath = `${folder}/${baseId}${ext}`; @@ -472,7 +368,7 @@ if (fw === 'LlamaCpp' && archiveType === null) { // Check if already exists const exists = await RNFS.exists(destPath); - if (exists && (fw === 'LlamaCpp' || (fw === 'ONNX' && archiveType === null))) { + if (exists && (fw === 'LlamaCpp' || (fw === 'ONNX' && !needsExtraction))) { logger.info(`Model already exists: ${destPath}`); return destPath; } @@ -511,25 +407,18 @@ if (fw === 'LlamaCpp' && archiveType === null) { logger.info(`Download completed: ${result.bytesWritten} bytes`); -// For archives (ONNX or LlamaCpp VLM), extract to final location - if (archiveType !== null) { - logger.info(`Extracting ${archiveType} archive for ${fw}...`); + // For archives (ONNX or LlamaCpp VLM), extract to final location + if (needsExtraction) { + logger.info(`Extracting archive for ${fw}...`); try { - const extractionResult = await this.extractArchive(destPath, folder, archiveType); - logger.info(`Extraction completed: ${extractionResult.fileCount} files, ${extractionResult.extractedSize} bytes`); + const modelPath = await this.extractArchive(destPath, folder); + logger.info(`Extraction completed, model at: ${modelPath}`); // Clean up the temporary archive file await RNFS.unlink(destPath); - // For LlamaCpp VLM, find the .gguf file in extracted folder - if (fw === 'LlamaCpp') { - destPath = await this.findGGUFInDirectory(extractionResult.modelPath); - logger.info(`Found GGUF model at: ${destPath}`); - } else { - // For ONNX, return the extracted folder path - destPath = extractionResult.modelPath; - } + destPath = modelPath; } catch (extractError) { logger.error(`Archive extraction failed: ${extractError}`); // Clean up temp file on failure @@ -546,167 +435,41 @@ if (fw === 'LlamaCpp' && archiveType === null) { }, /** - * Extract an archive to a destination folder - * Uses native extraction via the core module (iOS: ArchiveUtility, Android: native extraction) + * Extract an archive to a destination folder. + * Uses native C++ extraction via libarchive (auto-detects format). + * Returns the extracted model path. */ async extractArchive( archivePath: string, destinationFolder: string, - archiveType: ArchiveType, - onProgress?: (progress: number) => void - ): Promise { + ): Promise { if (!RNFS) { throw new Error('react-native-fs not installed'); } logger.info(`Extracting archive: ${archivePath}`); - logger.info(`Archive type: ${archiveType}`); logger.info(`Destination: ${destinationFolder}`); // Ensure destination exists await this.ensureDirectory(destinationFolder); - // Try native extraction first (supports tar.gz, tar.bz2, zip) - try { - const native = getNativeModule(); - if (!native) { - throw new Error('Native module not available'); - } - - logger.info('Using native archive extraction...'); - const success = await native.extractArchive(archivePath, destinationFolder); - - if (!success) { - throw new Error('Native extraction returned false'); - } - - logger.info('Native extraction completed successfully'); - } catch (nativeError) { - logger.warning(`Native extraction failed: ${nativeError}, trying fallback...`); - - // Fallback to react-native-zip-archive for ZIP files only - if (archiveType === ArchiveType.Zip && ZipArchive) { - logger.info('Falling back to react-native-zip-archive for ZIP...'); - - let subscription: { remove: () => void } | null = null; - if (onProgress) { - subscription = ZipArchive.subscribe(({ progress }) => { - onProgress(progress); - }); - } - - try { - await ZipArchive.unzip(archivePath, destinationFolder); - } finally { - if (subscription) { - subscription.remove(); - } - } - } else if (archiveType === ArchiveType.TarGz || archiveType === ArchiveType.TarBz2) { - // No fallback for tar archives - native is required - throw new Error( - `Archive extraction failed for ${archiveType}. Native extraction is required for tar archives. Error: ${nativeError}` - ); - } else { - throw new Error(`Archive extraction failed: ${nativeError}`); - } + // Use native C++ extraction (libarchive) — auto-detects all formats + const native = getNativeModule(); + if (!native) { + throw new Error('Native module not available'); } - // After extraction, find the actual model path - // ONNX models are typically nested in a directory with the same name - const modelPath = await this.findModelPathAfterExtraction(destinationFolder); + const success = await native.extractArchive(archivePath, destinationFolder); - // Calculate extraction stats - const stats = await this.calculateExtractionStats(destinationFolder); - - return { - modelPath, - extractedSize: stats.totalSize, - fileCount: stats.fileCount, - }; - }, - - /** - * Find the actual model path after extraction - * Handles nested directory structures common in ONNX archives - */ - async findModelPathAfterExtraction(extractedFolder: string): Promise { - if (!RNFS) { - return extractedFolder; + if (!success) { + throw new Error('Native extraction failed'); } - try { - const contents = await RNFS.readDir(extractedFolder); - - // If the directory contains .onnx files or other model files (tokens.txt, espeak-ng-data), - // return the DIRECTORY path — the C++ backend scans it internally for all needed files - // (encoder.onnx, decoder.onnx, tokens.txt, espeak-ng-data/, vocab.txt, etc.). - // This matches the iOS SDK which always passes directory paths for ONNX models. - const hasModelFiles = contents.some( - item => item.isFile() && ( - item.name.toLowerCase().endsWith('.onnx') || - item.name === 'tokens.txt' || - item.name === 'vocab.txt' - ) - ); - if (hasModelFiles) { - logger.info(`Found model files in directory: ${extractedFolder}`); - return extractedFolder; - } - - // If there's exactly one directory and no model files, it's a nested archive structure - const directories = contents.filter(item => item.isDirectory()); - const files = contents.filter(item => item.isFile()); - - if (directories.length === 1 && files.length === 0) { - const nestedDir = directories[0]; - logger.info(`Found nested directory structure: ${nestedDir.name}`); - return this.findModelPathAfterExtraction(nestedDir.path); - } + logger.info('Native extraction completed successfully'); - return extractedFolder; - } catch (error) { - logger.error(`Error finding model path: ${error}`); - return extractedFolder; - } - }, - - /** - * Find GGUF file in extracted directory (for VLM models) - * Recursively searches for the main model .gguf file - */ - async findGGUFInDirectory(directory: string): Promise { - if (!RNFS) { - throw new Error('react-native-fs not available'); - } - - try { - const contents = await RNFS.readDir(directory); - - // Look for .gguf files (not mmproj) - for (const item of contents) { - if (item.isFile() && item.name.endsWith('.gguf') && !item.name.includes('mmproj')) { - logger.info(`Found main GGUF model: ${item.name}`); - return item.path; - } - } - - // If not found, check nested directories - for (const item of contents) { - if (item.isDirectory()) { - try { - return await this.findGGUFInDirectory(item.path); - } catch { - // Continue searching other directories - } - } - } - - throw new Error(`No GGUF model file found in ${directory}`); - } catch (error) { - logger.error(`Error finding GGUF file: ${error}`); - throw error; - } + // C++ extraction handles format detection and path finding. + // Return the destination folder as the model path. + return destinationFolder; }, /** @@ -739,112 +502,6 @@ if (fw === 'LlamaCpp' && archiveType === null) { } }, - /** - * Calculate extraction statistics - */ - async calculateExtractionStats(folder: string): Promise<{ totalSize: number; fileCount: number }> { - if (!RNFS) { - return { totalSize: 0, fileCount: 0 }; - } - - let totalSize = 0; - let fileCount = 0; - - const processDir = async (dir: string) => { - try { - const contents = await RNFS!.readDir(dir); - for (const item of contents) { - if (item.isFile()) { - totalSize += item.size; - fileCount++; - } else if (item.isDirectory()) { - await processDir(item.path); - } - } - } catch { - // Ignore errors - } - }; - - await processDir(folder); - return { totalSize, fileCount }; - }, - - /** - * Download a multi-file model. - * All files are placed in the same directory: Models/{framework}/{modelId}/ - * Returns the folder path (not a file path), matching the Swift SDK behavior. - */ - async downloadMultiFileModel( - modelId: string, - files: ModelFileDescriptor[], - onProgress?: (progress: DownloadProgress) => void, - framework?: string - ): Promise { - if (!RNFS) { - throw new Error('react-native-fs not installed'); - } - - const fw = framework || 'ONNX'; - const baseId = getBaseModelId(modelId); - const folder = `${this.getFrameworkDirectory(fw)}/${baseId}`; - - await this.ensureDirectory(this.getRunAnywhereDirectory()); - await this.ensureDirectory(this.getModelsDirectory()); - await this.ensureDirectory(this.getFrameworkDirectory(fw)); - await this.ensureDirectory(folder); - - logger.info(`Downloading multi-file model: ${modelId} (${files.length} files)`); - - let totalBytesWritten = 0; - let totalContentLength = 0; - - for (let i = 0; i < files.length; i++) { - const fileDesc = files[i]; - const destPath = `${folder}/${fileDesc.filename}`; - - const exists = await RNFS.exists(destPath); - if (exists) { - logger.info(`File already exists, skipping: ${fileDesc.filename}`); - continue; - } - - logger.info(`Downloading file ${i + 1}/${files.length}: ${fileDesc.filename}`); - - const downloadResult = RNFS.downloadFile({ - fromUrl: fileDesc.url, - toFile: destPath, - background: true, - progressDivider: 1, - begin: (res) => { - totalContentLength += res.contentLength; - logger.info(`File download started: ${fileDesc.filename} (${res.contentLength} bytes)`); - }, - progress: (res) => { - if (onProgress && totalContentLength > 0) { - onProgress({ - bytesWritten: totalBytesWritten + res.bytesWritten, - contentLength: totalContentLength, - progress: (totalBytesWritten + res.bytesWritten) / totalContentLength, - }); - } - }, - }); - - const result = await downloadResult.promise; - - if (result.statusCode !== 200) { - throw new Error(`Download failed for ${fileDesc.filename}: status ${result.statusCode}`); - } - - totalBytesWritten += result.bytesWritten; - logger.info(`File downloaded: ${fileDesc.filename} (${result.bytesWritten} bytes)`); - } - - logger.info(`Multi-file model download complete: ${folder}`); - return folder; - }, - /** * Delete a model */ @@ -938,33 +595,6 @@ if (fw === 'LlamaCpp' && archiveType === null) { } }, - /** - * Get the size of a directory in bytes (recursive) - */ - async getDirectorySize(dirPath: string): Promise { - if (!RNFS) return 0; - - try { - const exists = await RNFS.exists(dirPath); - if (!exists) return 0; - - let totalSize = 0; - const contents = await RNFS.readDir(dirPath); - - for (const item of contents) { - if (item.isDirectory()) { - totalSize += await this.getDirectorySize(item.path); - } else { - totalSize += item.size || 0; - } - } - - return totalSize; - } catch { - return 0; - } - }, - /** * Get the cache directory path */ diff --git a/sdk/runanywhere-react-native/packages/core/src/services/index.ts b/sdk/runanywhere-react-native/packages/core/src/services/index.ts index 5a0ba2f29..d6b0b7b2f 100644 --- a/sdk/runanywhere-react-native/packages/core/src/services/index.ts +++ b/sdk/runanywhere-react-native/packages/core/src/services/index.ts @@ -15,12 +15,8 @@ export { export { FileSystem, MultiFileModelCache, - ArchiveType, - ArchiveStructure, - type ModelArtifactType, type ModelFileDescriptor, type DownloadProgress as FSDownloadProgress, - type ExtractionResult, } from './FileSystem'; // Download Service - Native-based download (delegates to native commons) diff --git a/sdk/runanywhere-react-native/yarn.lock b/sdk/runanywhere-react-native/yarn.lock index da55f455f..1fd71438e 100644 --- a/sdk/runanywhere-react-native/yarn.lock +++ b/sdk/runanywhere-react-native/yarn.lock @@ -1622,7 +1622,6 @@ __metadata: react-native-device-info: ">=11.0.0" react-native-fs: ">=2.20.0" react-native-nitro-modules: ">=0.31.3" - react-native-zip-archive: ">=6.1.0" peerDependenciesMeta: react-native-blob-util: optional: true @@ -1630,8 +1629,6 @@ __metadata: optional: true react-native-fs: optional: true - react-native-zip-archive: - optional: true languageName: unknown linkType: soft diff --git a/sdk/runanywhere-swift/Sources/RunAnywhere/CRACommons/include/CRACommons.h b/sdk/runanywhere-swift/Sources/RunAnywhere/CRACommons/include/CRACommons.h index 2f9c83dee..b1439d93c 100644 --- a/sdk/runanywhere-swift/Sources/RunAnywhere/CRACommons/include/CRACommons.h +++ b/sdk/runanywhere-swift/Sources/RunAnywhere/CRACommons/include/CRACommons.h @@ -103,6 +103,7 @@ // Download management #include "rac_download.h" +#include "rac_download_orchestrator.h" // Model management #include "rac_model_types.h" @@ -115,6 +116,9 @@ // Storage #include "rac_storage_analyzer.h" +// File Management +#include "rac_file_manager.h" + // Device #include "rac_device_manager.h" diff --git a/sdk/runanywhere-swift/Sources/RunAnywhere/CRACommons/include/rac_download_orchestrator.h b/sdk/runanywhere-swift/Sources/RunAnywhere/CRACommons/include/rac_download_orchestrator.h new file mode 100644 index 000000000..1d24d58a4 --- /dev/null +++ b/sdk/runanywhere-swift/Sources/RunAnywhere/CRACommons/include/rac_download_orchestrator.h @@ -0,0 +1,166 @@ +/** + * @file rac_download_orchestrator.h + * @brief Download Orchestrator - High-Level Model Download Lifecycle Management + * + * Consolidates download business logic from all platform SDKs into C++. + * Handles the full download lifecycle: path resolution, extraction detection, + * HTTP download (via platform adapter), post-download extraction, model path + * finding, registry updates, and archive cleanup. + * + * HTTP transport remains platform-specific via rac_platform_adapter_t.http_download. + * This layer handles ALL orchestration logic so each SDK reduces to: + * 1. Register http_download callback + * 2. Call rac_download_orchestrate() + * 3. Wrap result in SDK types + * + * Depends on: + * - rac_download.h (download manager state machine, progress tracking) + * - rac_platform_adapter.h (http_download callback for HTTP transport) + * - rac_extraction.h (rac_extract_archive_native for archive extraction) + * - rac_model_paths.h (destination path resolution) + * - rac_model_types.h (model types, archive types, frameworks) + */ + +#ifndef RAC_DOWNLOAD_ORCHESTRATOR_H +#define RAC_DOWNLOAD_ORCHESTRATOR_H + +#include "rac_error.h" +#include "rac_types.h" +#include "rac_download.h" +#include "rac_model_types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// ============================================================================= +// DOWNLOAD ORCHESTRATION - Full Lifecycle Model Download +// ============================================================================= + +/** + * @brief Orchestrate a single-file model download with full lifecycle management. + * + * This is the main entry point for downloading a model. It handles: + * 1. Compute destination folder via rac_model_paths_get_model_folder() + * 2. Detect if extraction is needed via rac_archive_type_from_path() + * 3. Download to temp path if extraction needed, else download to model folder + * 4. Invoke platform http_download via rac_http_download() + * 5. On HTTP completion: extract if needed, find model path, cleanup archive + * 6. Update download manager state (DOWNLOADING → EXTRACTING → COMPLETED) + * 7. Invoke user callbacks with final model path + * + * @param dm_handle Download manager handle (for state tracking) + * @param model_id Model identifier (used for folder naming and registry) + * @param download_url URL to download from + * @param framework Inference framework (determines storage directory) + * @param format Model format (determines file extension and path finding) + * @param archive_structure Archive structure hint (used for post-extraction path finding) + * @param progress_callback Progress updates across all stages (can be NULL) + * @param complete_callback Called when entire lifecycle completes or fails + * @param user_data User context passed to callbacks + * @param out_task_id Output: Task ID for tracking/cancellation (owned, free with rac_free) + * @return RAC_SUCCESS if download started, error code if failed to start + */ +RAC_API rac_result_t rac_download_orchestrate( + rac_download_manager_handle_t dm_handle, const char* model_id, const char* download_url, + rac_inference_framework_t framework, rac_model_format_t format, + rac_archive_structure_t archive_structure, + rac_download_progress_callback_fn progress_callback, + rac_download_complete_callback_fn complete_callback, void* user_data, char** out_task_id); + +/** + * @brief Orchestrate a multi-file model download (e.g., VLM with companion files). + * + * Downloads multiple files sequentially into the same model folder. + * Progress is distributed across all files proportionally. + * Extraction is applied to each file individually if needed. + * + * @param dm_handle Download manager handle (for state tracking) + * @param model_id Model identifier + * @param files Array of file descriptors (relative_path, destination_path, is_required) + * @param file_count Number of files to download + * @param base_download_url Base URL — file relative_path is appended to this + * @param framework Inference framework + * @param format Model format + * @param progress_callback Progress updates across all files and stages (can be NULL) + * @param complete_callback Called when all files complete or any required file fails + * @param user_data User context passed to callbacks + * @param out_task_id Output: Task ID for tracking/cancellation (owned, free with rac_free) + * @return RAC_SUCCESS if download started, error code if failed to start + */ +RAC_API rac_result_t rac_download_orchestrate_multi( + rac_download_manager_handle_t dm_handle, const char* model_id, + const rac_model_file_descriptor_t* files, size_t file_count, const char* base_download_url, + rac_inference_framework_t framework, rac_model_format_t format, + rac_download_progress_callback_fn progress_callback, + rac_download_complete_callback_fn complete_callback, void* user_data, char** out_task_id); + +// ============================================================================= +// POST-EXTRACTION MODEL PATH FINDING +// ============================================================================= + +/** + * @brief Find the actual model path after extraction. + * + * Consolidates duplicated Swift/Kotlin logic for scanning extracted directories: + * - Finds .gguf, .onnx, .ort, .bin files + * - Handles nested directories (e.g., sherpa-onnx archives with subdirectory) + * - Handles single-file-nested pattern (model file inside one subdirectory) + * - Returns the directory itself for directory-based models (ONNX) + * + * Uses POSIX opendir/readdir for cross-platform compatibility (iOS/Android/Linux/macOS). + * + * @param extracted_dir Directory where archive was extracted + * @param structure Archive structure hint (SINGLE_FILE_NESTED, NESTED_DIRECTORY, etc.) + * @param framework Inference framework (used to determine if directory-based) + * @param format Model format (used to determine expected file extensions) + * @param out_path Output buffer for the found model path + * @param path_size Size of output buffer + * @return RAC_SUCCESS if model path found, RAC_ERROR_NOT_FOUND if no model file found + */ +RAC_API rac_result_t rac_find_model_path_after_extraction( + const char* extracted_dir, rac_archive_structure_t structure, + rac_inference_framework_t framework, rac_model_format_t format, char* out_path, + size_t path_size); + +// ============================================================================= +// UTILITY FUNCTIONS +// ============================================================================= + +/** + * @brief Compute the download destination path for a model. + * + * If extraction is needed: returns a temp path in the downloads directory. + * If no extraction: returns the final model folder path. + * + * @param model_id Model identifier + * @param download_url URL to download (used for archive detection and extension) + * @param framework Inference framework + * @param format Model format + * @param out_path Output buffer for destination path + * @param path_size Size of output buffer + * @param out_needs_extraction Output: RAC_TRUE if download needs extraction + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_download_compute_destination(const char* model_id, + const char* download_url, + rac_inference_framework_t framework, + rac_model_format_t format, char* out_path, + size_t path_size, + rac_bool_t* out_needs_extraction); + +/** + * @brief Check if a download URL requires extraction. + * + * Convenience wrapper around rac_archive_type_from_path(). + * + * @param download_url URL to check + * @return RAC_TRUE if URL points to an archive that needs extraction + */ +RAC_API rac_bool_t rac_download_requires_extraction(const char* download_url); + +#ifdef __cplusplus +} +#endif + +#endif /* RAC_DOWNLOAD_ORCHESTRATOR_H */ diff --git a/sdk/runanywhere-swift/Sources/RunAnywhere/CRACommons/include/rac_file_manager.h b/sdk/runanywhere-swift/Sources/RunAnywhere/CRACommons/include/rac_file_manager.h new file mode 100644 index 000000000..5712e3b6b --- /dev/null +++ b/sdk/runanywhere-swift/Sources/RunAnywhere/CRACommons/include/rac_file_manager.h @@ -0,0 +1,305 @@ +/** + * @file rac_file_manager.h + * @brief File Manager - Centralized File Management Business Logic + * + * Consolidates common file management operations that were duplicated + * across SDKs (Swift, Kotlin, Flutter, React Native): + * - Directory size calculation (recursive traversal) + * - Directory structure creation (Models/Cache/Temp/Downloads) + * - Cache and temp cleanup + * - Model folder management (create, delete, check existence) + * - Storage availability checking + * + * Platform-specific file I/O is provided via callbacks (rac_file_callbacks_t). + * C++ handles all business logic; SDKs only provide thin I/O implementations. + * + * Uses rac_model_paths for path computation. + */ + +#ifndef RAC_FILE_MANAGER_H +#define RAC_FILE_MANAGER_H + +#include +#include + +#include "rac_error.h" +#include "rac_types.h" +#include "rac_model_types.h" +#include "rac_storage_analyzer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// ============================================================================= +// PLATFORM I/O CALLBACKS +// ============================================================================= + +/** + * @brief Platform-specific file I/O callbacks. + * + * SDKs implement these thin wrappers around native file operations. + * C++ business logic calls these for all file system access. + */ +typedef struct { + /** + * Create a directory (optionally recursive). + * @param path Directory path to create + * @param recursive If non-zero, create intermediate directories + * @param user_data Platform context + * @return RAC_SUCCESS or error code + */ + rac_result_t (*create_directory)(const char* path, int recursive, void* user_data); + + /** + * Delete a file or directory. + * @param path Path to delete + * @param recursive If non-zero, delete directory contents recursively + * @param user_data Platform context + * @return RAC_SUCCESS or error code + */ + rac_result_t (*delete_path)(const char* path, int recursive, void* user_data); + + /** + * List directory contents (entry names only, not full paths). + * @param path Directory path + * @param out_entries Output: Array of entry name strings (allocated by callback) + * @param out_count Output: Number of entries + * @param user_data Platform context + * @return RAC_SUCCESS or error code + */ + rac_result_t (*list_directory)(const char* path, char*** out_entries, size_t* out_count, + void* user_data); + + /** + * Free directory entries returned by list_directory. + * @param entries Array of entry names + * @param count Number of entries + * @param user_data Platform context + */ + void (*free_entries)(char** entries, size_t count, void* user_data); + + /** + * Check if a path exists. + * @param path Path to check + * @param out_is_directory Output: non-zero if path is a directory (can be NULL) + * @param user_data Platform context + * @return RAC_TRUE if exists, RAC_FALSE otherwise + */ + rac_bool_t (*path_exists)(const char* path, rac_bool_t* out_is_directory, void* user_data); + + /** + * Get file size in bytes. + * @param path File path + * @param user_data Platform context + * @return File size in bytes, or -1 on error + */ + int64_t (*get_file_size)(const char* path, void* user_data); + + /** + * Get available disk space in bytes. + * @param user_data Platform context + * @return Available space in bytes, or -1 on error + */ + int64_t (*get_available_space)(void* user_data); + + /** + * Get total disk space in bytes. + * @param user_data Platform context + * @return Total space in bytes, or -1 on error + */ + int64_t (*get_total_space)(void* user_data); + + /** Platform-specific context passed to all callbacks */ + void* user_data; +} rac_file_callbacks_t; + +// ============================================================================= +// DATA STRUCTURES +// ============================================================================= + +/** + * @brief Combined storage information. + * + * Aggregates device storage, app storage (models/cache/temp), and + * computed totals. Replaces per-SDK storage info structs. + */ +typedef struct { + /** Total device storage in bytes */ + int64_t device_total; + /** Free device storage in bytes */ + int64_t device_free; + /** Total models directory size in bytes */ + int64_t models_size; + /** Cache directory size in bytes */ + int64_t cache_size; + /** Temp directory size in bytes */ + int64_t temp_size; + /** Total app storage (models + cache + temp) */ + int64_t total_app_size; +} rac_file_manager_storage_info_t; + +// ============================================================================= +// DIRECTORY STRUCTURE +// ============================================================================= + +/** + * @brief Create the standard directory structure under the base directory. + * + * Creates: Models/, Cache/, Temp/, Downloads/ under {base_dir}/RunAnywhere/ + * Uses rac_model_paths for path computation. + * + * @param cb Platform I/O callbacks + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_create_directory_structure(const rac_file_callbacks_t* cb); + +// ============================================================================= +// MODEL FOLDER MANAGEMENT +// ============================================================================= + +/** + * @brief Create a model folder and return its path. + * + * Creates: {base_dir}/RunAnywhere/Models/{framework}/{modelId}/ + * + * @param cb Platform I/O callbacks + * @param model_id Model identifier + * @param framework Inference framework + * @param out_path Output buffer for the created folder path + * @param path_size Size of output buffer + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_create_model_folder(const rac_file_callbacks_t* cb, + const char* model_id, + rac_inference_framework_t framework, + char* out_path, size_t path_size); + +/** + * @brief Check if a model folder exists and optionally if it has contents. + * + * @param cb Platform I/O callbacks + * @param model_id Model identifier + * @param framework Inference framework + * @param out_exists Output: RAC_TRUE if folder exists + * @param out_has_contents Output: RAC_TRUE if folder has files (can be NULL) + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_model_folder_exists(const rac_file_callbacks_t* cb, + const char* model_id, + rac_inference_framework_t framework, + rac_bool_t* out_exists, + rac_bool_t* out_has_contents); + +/** + * @brief Delete a model folder recursively. + * + * @param cb Platform I/O callbacks + * @param model_id Model identifier + * @param framework Inference framework + * @return RAC_SUCCESS, or RAC_ERROR_FILE_NOT_FOUND if folder doesn't exist + */ +RAC_API rac_result_t rac_file_manager_delete_model(const rac_file_callbacks_t* cb, + const char* model_id, + rac_inference_framework_t framework); + +// ============================================================================= +// DIRECTORY SIZE CALCULATION +// ============================================================================= + +/** + * @brief Calculate directory size recursively. + * + * @param cb Platform I/O callbacks + * @param path Directory path to measure + * @param out_size Output: Total size in bytes + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_calculate_dir_size(const rac_file_callbacks_t* cb, + const char* path, int64_t* out_size); + +/** + * @brief Get total models directory storage used. + * + * @param cb Platform I/O callbacks + * @param out_size Output: Total models size in bytes + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_models_storage_used(const rac_file_callbacks_t* cb, + int64_t* out_size); + +// ============================================================================= +// CACHE & TEMP MANAGEMENT +// ============================================================================= + +/** + * @brief Clear the cache directory. + * + * @param cb Platform I/O callbacks + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_clear_cache(const rac_file_callbacks_t* cb); + +/** + * @brief Clear the temp directory. + * + * @param cb Platform I/O callbacks + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_clear_temp(const rac_file_callbacks_t* cb); + +/** + * @brief Get the cache directory size. + * + * @param cb Platform I/O callbacks + * @param out_size Output: Cache size in bytes + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_cache_size(const rac_file_callbacks_t* cb, + int64_t* out_size); + +// ============================================================================= +// STORAGE INFO +// ============================================================================= + +/** + * @brief Get combined storage information. + * + * @param cb Platform I/O callbacks + * @param out_info Output: Storage information + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_get_storage_info(const rac_file_callbacks_t* cb, + rac_file_manager_storage_info_t* out_info); + +/** + * @brief Check storage availability for a download. + * + * @param cb Platform I/O callbacks + * @param required_bytes Space needed in bytes + * @param out_availability Output: Availability result + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_check_storage( + const rac_file_callbacks_t* cb, int64_t required_bytes, + rac_storage_availability_t* out_availability); + +// ============================================================================= +// DIRECTORY CLEARING (INTERNAL HELPER) +// ============================================================================= + +/** + * @brief Clear all contents of a directory (delete + recreate). + * + * @param cb Platform I/O callbacks + * @param path Directory path to clear + * @return RAC_SUCCESS or error code + */ +RAC_API rac_result_t rac_file_manager_clear_directory(const rac_file_callbacks_t* cb, + const char* path); + +#ifdef __cplusplus +} +#endif + +#endif /* RAC_FILE_MANAGER_H */ diff --git a/sdk/runanywhere-swift/Sources/RunAnywhere/Foundation/Bridge/Extensions/CppBridge+Download.swift b/sdk/runanywhere-swift/Sources/RunAnywhere/Foundation/Bridge/Extensions/CppBridge+Download.swift index 46b3a7b96..3d639cda1 100644 --- a/sdk/runanywhere-swift/Sources/RunAnywhere/Foundation/Bridge/Extensions/CppBridge+Download.swift +++ b/sdk/runanywhere-swift/Sources/RunAnywhere/Foundation/Bridge/Extensions/CppBridge+Download.swift @@ -19,6 +19,71 @@ extension CppBridge { /// Shared download manager instance public static let shared = Download() + // MARK: - Download Orchestrator Utilities (stateless, nonisolated) + + /// Find model path after archive extraction using C++ rac_find_model_path_after_extraction(). + /// Consolidates Swift findModelPath/findNestedDirectory/findSingleModelFile into one C++ call. + public static func findModelPathAfterExtraction( + extractedDir: URL, + structure: ArchiveStructure, + framework: InferenceFramework, + format: ModelFormat + ) -> URL? { + var outPath = [CChar](repeating: 0, count: 4096) + + let result = extractedDir.path.withCString { dir in + rac_find_model_path_after_extraction( + dir, + structure.toC(), + framework.toC(), + format.toC(), + &outPath, + outPath.count + ) + } + + guard result == RAC_SUCCESS else { return nil } + return URL(fileURLWithPath: String(cString: outPath)) + } + + /// Check if a download URL requires extraction. + /// Uses C++ rac_download_requires_extraction() — convenience wrapper around rac_archive_type_from_path(). + public static func downloadRequiresExtraction(url: URL) -> Bool { + return url.absoluteString.withCString { urlStr in + rac_download_requires_extraction(urlStr) == RAC_TRUE + } + } + + /// Compute download destination path using C++ rac_download_compute_destination(). + /// Returns the path and whether extraction is needed, or nil on failure. + public static func computeDownloadDestination( + modelId: String, + downloadURL: URL, + framework: InferenceFramework, + format: ModelFormat + ) -> (path: URL, needsExtraction: Bool)? { + var outPath = [CChar](repeating: 0, count: 4096) + var needsExtraction: rac_bool_t = RAC_FALSE + + let result = modelId.withCString { mid in + downloadURL.absoluteString.withCString { urlStr in + rac_download_compute_destination( + mid, urlStr, + framework.toC(), + format.toC(), + &outPath, outPath.count, + &needsExtraction + ) + } + } + + guard result == RAC_SUCCESS else { return nil } + return ( + path: URL(fileURLWithPath: String(cString: outPath)), + needsExtraction: needsExtraction == RAC_TRUE + ) + } + private var handle: rac_download_manager_handle_t? private let logger = SDKLogger(category: "CppBridge.Download") diff --git a/sdk/runanywhere-swift/Sources/RunAnywhere/Foundation/Bridge/Extensions/CppBridge+FileManager.swift b/sdk/runanywhere-swift/Sources/RunAnywhere/Foundation/Bridge/Extensions/CppBridge+FileManager.swift new file mode 100644 index 000000000..b0397fe45 --- /dev/null +++ b/sdk/runanywhere-swift/Sources/RunAnywhere/Foundation/Bridge/Extensions/CppBridge+FileManager.swift @@ -0,0 +1,265 @@ +// +// CppBridge+FileManager.swift +// RunAnywhere SDK +// +// File manager bridge - C++ owns business logic, Swift provides file I/O callbacks. +// Consolidates duplicated file management logic from SimplifiedFileManager. +// + +import CRACommons +import Foundation + +// MARK: - File Manager Bridge + +extension CppBridge { + + /// File manager bridge to C++ rac_file_manager + /// C++ handles: recursive dir size, directory structure, cache clearing, storage checks + /// Swift provides: thin I/O callbacks (create dir, delete, list, stat, file size) + public enum FileManager { + + private static let logger = SDKLogger(category: "CppBridge.FileManager") + + // MARK: - Callbacks Construction + + /// Build rac_file_callbacks_t with Swift I/O implementations + static func makeCallbacks() -> rac_file_callbacks_t { + var cb = rac_file_callbacks_t() + cb.create_directory = fmCreateDirectoryCallback + cb.delete_path = fmDeletePathCallback + cb.list_directory = fmListDirectoryCallback + cb.free_entries = fmFreeEntriesCallback + cb.path_exists = fmPathExistsCallback + cb.get_file_size = fmGetFileSizeCallback + cb.get_available_space = fmGetAvailableSpaceCallback + cb.get_total_space = fmGetTotalSpaceCallback + cb.user_data = nil + return cb + } + + // MARK: - Public API + + /// Create directory structure (Models, Cache, Temp, Downloads) + public static func createDirectoryStructure() -> Bool { + var cb = makeCallbacks() + let result = rac_file_manager_create_directory_structure(&cb) + return result == RAC_SUCCESS + } + + /// Calculate directory size recursively (C++ logic, Swift I/O) + public static func calculateDirectorySize(at url: URL) -> Int64 { + var cb = makeCallbacks() + var size: Int64 = 0 + url.path.withCString { pathPtr in + _ = rac_file_manager_calculate_dir_size(&cb, pathPtr, &size) + } + return size + } + + /// Get total models storage used + public static func modelsStorageUsed() -> Int64 { + var cb = makeCallbacks() + var size: Int64 = 0 + _ = rac_file_manager_models_storage_used(&cb, &size) + return size + } + + /// Clear cache directory + public static func clearCache() -> Bool { + var cb = makeCallbacks() + return rac_file_manager_clear_cache(&cb) == RAC_SUCCESS + } + + /// Clear temp directory + public static func clearTemp() -> Bool { + var cb = makeCallbacks() + return rac_file_manager_clear_temp(&cb) == RAC_SUCCESS + } + + /// Get cache size + public static func cacheSize() -> Int64 { + var cb = makeCallbacks() + var size: Int64 = 0 + _ = rac_file_manager_cache_size(&cb, &size) + return size + } + + /// Delete a model folder + public static func deleteModel(modelId: String, framework: InferenceFramework) -> Bool { + var cb = makeCallbacks() + return modelId.withCString { mid in + rac_file_manager_delete_model(&cb, mid, framework.toCFramework()) == RAC_SUCCESS + } + } + + /// Check if model folder exists + public static func modelFolderExists(modelId: String, framework: InferenceFramework) -> Bool { + var cb = makeCallbacks() + var exists: rac_bool_t = RAC_FALSE + modelId.withCString { mid in + _ = rac_file_manager_model_folder_exists(&cb, mid, framework.toCFramework(), &exists, nil) + } + return exists == RAC_TRUE + } + + /// Check if model folder exists and has contents + public static func modelFolderHasContents(modelId: String, framework: InferenceFramework) -> Bool { + var cb = makeCallbacks() + var exists: rac_bool_t = RAC_FALSE + var hasContents: rac_bool_t = RAC_FALSE + modelId.withCString { mid in + _ = rac_file_manager_model_folder_exists(&cb, mid, framework.toCFramework(), &exists, &hasContents) + } + return exists == RAC_TRUE && hasContents == RAC_TRUE + } + + /// Get combined storage info + public static func getStorageInfo() -> rac_file_manager_storage_info_t { + var cb = makeCallbacks() + var info = rac_file_manager_storage_info_t() + _ = rac_file_manager_get_storage_info(&cb, &info) + return info + } + + /// Check storage availability + public static func checkStorage(requiredBytes: Int64) -> rac_storage_availability_t { + var cb = makeCallbacks() + var availability = rac_storage_availability_t() + _ = rac_file_manager_check_storage(&cb, requiredBytes, &availability) + return availability + } + } +} + +// MARK: - C Callbacks (Platform-Specific I/O) + +/// Create a directory, optionally with intermediate directories +private func fmCreateDirectoryCallback( + path: UnsafePointer?, + recursive: Int32, + userData _: UnsafeMutableRawPointer? +) -> rac_result_t { + guard let path = path else { return RAC_ERROR_NULL_POINTER } + let url = URL(fileURLWithPath: String(cString: path)) + do { + try Foundation.FileManager.default.createDirectory( + at: url, + withIntermediateDirectories: recursive != 0, + attributes: nil + ) + return RAC_SUCCESS + } catch { + return RAC_ERROR_DIRECTORY_CREATION_FAILED + } +} + +/// Delete a file or directory +private func fmDeletePathCallback( + path: UnsafePointer?, + recursive _: Int32, + userData _: UnsafeMutableRawPointer? +) -> rac_result_t { + guard let path = path else { return RAC_ERROR_NULL_POINTER } + let pathStr = String(cString: path) + do { + if Foundation.FileManager.default.fileExists(atPath: pathStr) { + try Foundation.FileManager.default.removeItem(atPath: pathStr) + } + return RAC_SUCCESS + } catch { + return RAC_ERROR_DELETE_FAILED + } +} + +/// List directory contents (entry names only) +private func fmListDirectoryCallback( + path: UnsafePointer?, + outEntries: UnsafeMutablePointer?>?>?, + outCount: UnsafeMutablePointer?, + userData _: UnsafeMutableRawPointer? +) -> rac_result_t { + guard let path = path, let outEntries = outEntries, let outCount = outCount else { + return RAC_ERROR_NULL_POINTER + } + + let pathStr = String(cString: path) + guard let contents = try? Foundation.FileManager.default.contentsOfDirectory(atPath: pathStr) else { + outEntries.pointee = nil + outCount.pointee = 0 + return RAC_ERROR_FILE_NOT_FOUND + } + + let count = contents.count + let entries = UnsafeMutablePointer?>.allocate(capacity: count) + + for (i, name) in contents.enumerated() { + entries[i] = strdup(name) + } + + outEntries.pointee = entries + outCount.pointee = count + return RAC_SUCCESS +} + +/// Free directory entries +private func fmFreeEntriesCallback( + entries: UnsafeMutablePointer?>?, + count: Int, + userData _: UnsafeMutableRawPointer? +) { + guard let entries = entries else { return } + for i in 0..?, + outIsDirectory: UnsafeMutablePointer?, + userData _: UnsafeMutableRawPointer? +) -> rac_bool_t { + guard let path = path else { return RAC_FALSE } + let pathStr = String(cString: path) + var isDir: ObjCBool = false + let exists = Foundation.FileManager.default.fileExists(atPath: pathStr, isDirectory: &isDir) + outIsDirectory?.pointee = isDir.boolValue ? RAC_TRUE : RAC_FALSE + return exists ? RAC_TRUE : RAC_FALSE +} + +/// Get file size +private func fmGetFileSizeCallback( + path: UnsafePointer?, + userData _: UnsafeMutableRawPointer? +) -> Int64 { + guard let path = path else { return -1 } + let pathStr = String(cString: path) + guard let attrs = try? Foundation.FileManager.default.attributesOfItem(atPath: pathStr), + let size = attrs[.size] as? Int64 else { + return -1 + } + return size +} + +/// Get available disk space +private func fmGetAvailableSpaceCallback(userData _: UnsafeMutableRawPointer?) -> Int64 { + do { + let attrs = try Foundation.FileManager.default.attributesOfFileSystem(forPath: NSHomeDirectory()) + return (attrs[.systemFreeSize] as? Int64) ?? 0 + } catch { + return 0 + } +} + +/// Get total disk space +private func fmGetTotalSpaceCallback(userData _: UnsafeMutableRawPointer?) -> Int64 { + do { + let attrs = try Foundation.FileManager.default.attributesOfFileSystem(forPath: NSHomeDirectory()) + return (attrs[.systemSize] as? Int64) ?? 0 + } catch { + return 0 + } +} diff --git a/sdk/runanywhere-swift/Sources/RunAnywhere/Foundation/Bridge/Extensions/CppBridge+ModelRegistry.swift b/sdk/runanywhere-swift/Sources/RunAnywhere/Foundation/Bridge/Extensions/CppBridge+ModelRegistry.swift index f739b3bf0..05a396fd0 100644 --- a/sdk/runanywhere-swift/Sources/RunAnywhere/Foundation/Bridge/Extensions/CppBridge+ModelRegistry.swift +++ b/sdk/runanywhere-swift/Sources/RunAnywhere/Foundation/Bridge/Extensions/CppBridge+ModelRegistry.swift @@ -243,7 +243,7 @@ extension CppBridge { guard let path = path else { return RAC_ERROR_INVALID_ARGUMENT } let url = URL(fileURLWithPath: String(cString: path)) - let fm = FileManager.default + let fm = Foundation.FileManager.default guard let contents = try? fm.contentsOfDirectory(at: url, includingPropertiesForKeys: nil) else { outEntries?.pointee = nil @@ -286,7 +286,7 @@ extension CppBridge { guard let path = path else { return RAC_FALSE } let pathStr = String(cString: path) var isDir: ObjCBool = false - if FileManager.default.fileExists(atPath: pathStr, isDirectory: &isDir) { + if Foundation.FileManager.default.fileExists(atPath: pathStr, isDirectory: &isDir) { return isDir.boolValue ? RAC_TRUE : RAC_FALSE } return RAC_FALSE @@ -296,7 +296,7 @@ extension CppBridge { callbacks.path_exists = { path, _ -> rac_bool_t in guard let path = path else { return RAC_FALSE } let pathStr = String(cString: path) - return FileManager.default.fileExists(atPath: pathStr) ? RAC_TRUE : RAC_FALSE + return Foundation.FileManager.default.fileExists(atPath: pathStr) ? RAC_TRUE : RAC_FALSE } // Is model file callback - checks for known model extensions diff --git a/sdk/runanywhere-swift/Sources/RunAnywhere/Foundation/Bridge/Extensions/CppBridge+Storage.swift b/sdk/runanywhere-swift/Sources/RunAnywhere/Foundation/Bridge/Extensions/CppBridge+Storage.swift index ab60fdbd4..21a0e5e0f 100644 --- a/sdk/runanywhere-swift/Sources/RunAnywhere/Foundation/Bridge/Extensions/CppBridge+Storage.swift +++ b/sdk/runanywhere-swift/Sources/RunAnywhere/Foundation/Bridge/Extensions/CppBridge+Storage.swift @@ -173,14 +173,14 @@ extension CppBridge { // MARK: - C Callbacks (Platform-Specific File Operations) -/// Calculate directory size using FileManager +/// Calculate directory size — delegates to C++ file manager (single source of truth) private func storageCalculateDirSizeCallback( path: UnsafePointer?, userData _: UnsafeMutableRawPointer? ) -> Int64 { guard let path = path else { return 0 } let url = URL(fileURLWithPath: String(cString: path)) - return calculateDirectorySize(at: url) + return CppBridge.FileManager.calculateDirectorySize(at: url) } /// Get file size @@ -230,43 +230,6 @@ private func storageGetTotalSpaceCallback(userData _: UnsafeMutableRawPointer?) } } -/// Calculate directory size (recursive) -private func calculateDirectorySize(at url: URL) -> Int64 { - let fm = FileManager.default - - // Check if it's a directory - var isDirectory: ObjCBool = false - if fm.fileExists(atPath: url.path, isDirectory: &isDirectory) { - if !isDirectory.boolValue { - // It's a file - if let attrs = try? fm.attributesOfItem(atPath: url.path), - let fileSize = attrs[.size] as? Int64 { - return fileSize - } else { - return 0 - } - } - } - - // It's a directory - guard let enumerator = fm.enumerator( - at: url, - includingPropertiesForKeys: [.fileSizeKey, .isRegularFileKey], - options: [.skipsHiddenFiles] - ) else { - return 0 - } - - var totalSize: Int64 = 0 - for case let fileURL as URL in enumerator { - if let values = try? fileURL.resourceValues(forKeys: [.fileSizeKey, .isRegularFileKey]), - values.isRegularFile == true { - totalSize += Int64(values.fileSize ?? 0) - } - } - return totalSize -} - // MARK: - Swift Type Conversions extension StorageInfo { diff --git a/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/Download/Services/AlamofireDownloadService+Execution.swift b/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/Download/Services/AlamofireDownloadService+Execution.swift index 218d4097d..a65b3efa1 100644 --- a/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/Download/Services/AlamofireDownloadService+Execution.swift +++ b/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/Download/Services/AlamofireDownloadService+Execution.swift @@ -74,7 +74,7 @@ extension AlamofireDownloadService { } .validate() - activeDownloadRequests[taskId] = downloadRequest + storeDownloadRequest(downloadRequest, forKey: taskId) return try await withCheckedThrowingContinuation { continuation in downloadRequest.response { response in @@ -103,46 +103,43 @@ extension AlamofireDownloadService { } } - /// Perform extraction for archive models (platform-specific - uses SWCompression) + /// Perform extraction for archive models (uses native C++ libarchive via rac_extract_archive) + /// Archive type auto-detection and post-extraction model path finding are handled by C++. func performExtraction( archiveURL: URL, destinationFolder: URL, model: ModelInfo, progressContinuation: AsyncStream.Continuation ) async throws -> URL { - // Determine archive type from model artifact type or infer from archive URL/download URL - let archiveType: ArchiveType + // Use artifact type directly — C++ extraction auto-detects archive format from file contents. + // If model doesn't have an explicit archive artifact type, construct one with .unknown structure. let artifactTypeForExtraction: ModelArtifactType - - if case .archive(let type, let structure, let expectedFiles) = model.artifactType { - archiveType = type + if case .archive = model.artifactType { artifactTypeForExtraction = model.artifactType - } else if let inferredType = ArchiveType.from(url: archiveURL) { - // Infer from downloaded archive file path - archiveType = inferredType - artifactTypeForExtraction = .archive(inferredType, structure: .unknown, expectedFiles: .none) - logger.info("Inferred archive type from file path: \(inferredType.rawValue)") - } else if let originalDownloadURL = model.downloadURL, - let inferredType = ArchiveType.from(url: originalDownloadURL) { - // Infer from original download URL - archiveType = inferredType - artifactTypeForExtraction = .archive(inferredType, structure: .unknown, expectedFiles: .none) - logger.info("Inferred archive type from download URL: \(inferredType.rawValue)") } else { - throw SDKError.download(.extractionFailed, "Could not determine archive type for model: \(model.id)") + // C++ rac_extract_archive_native() auto-detects archive format, so archive type here + // is only used for the structure hint passed to post-extraction path finding. + artifactTypeForExtraction = .archive(.zip, structure: .unknown, expectedFiles: .none) } let extractionStartTime = Date() // Track extraction started via C++ event system + // Archive type detection is now in C++ — use artifact type if known, otherwise "unknown" + let archiveTypeString: String + if case .archive(let type, _, _) = model.artifactType { + archiveTypeString = type.rawValue + } else { + archiveTypeString = "unknown" + } CppBridge.Events.emitExtractionStarted( modelId: model.id, - archiveType: archiveType.rawValue + archiveType: archiveTypeString ) logger.info("Starting extraction", metadata: [ "modelId": model.id, - "archiveType": archiveType.rawValue, + "archiveType": archiveTypeString, "archiveURL": archiveURL.path, "destination": destinationFolder.path ]) @@ -156,6 +153,8 @@ extension AlamofireDownloadService { archiveURL: archiveURL, to: destinationFolder, artifactType: artifactTypeForExtraction, + framework: model.framework, + format: model.format, progressHandler: { progress in // Track extraction progress (via C++ for routing to EventBus/telemetry) if progress - lastReportedExtractionProgress >= 0.1 { diff --git a/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/Download/Services/AlamofireDownloadService.swift b/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/Download/Services/AlamofireDownloadService.swift index f9f7183a4..3e78130ff 100644 --- a/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/Download/Services/AlamofireDownloadService.swift +++ b/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/Download/Services/AlamofireDownloadService.swift @@ -4,7 +4,7 @@ import Foundation /// Download service using Alamofire for HTTP and C++ bridge for orchestration /// C++ handles: task tracking, progress calculation, retry logic -/// Swift handles: HTTP transport via Alamofire, extraction via SWCompression +/// Swift handles: HTTP transport via Alamofire, extraction via native C++ libarchive public class AlamofireDownloadService: @unchecked Sendable { // MARK: - Shared Instance @@ -15,19 +15,20 @@ public class AlamofireDownloadService: @unchecked Sendable { // MARK: - Properties let session: Session - var activeDownloadRequests: [String: DownloadRequest] = [:] + private var activeDownloadRequests: [String: DownloadRequest] = [:] + private let requestsQueue = DispatchQueue(label: "com.runanywhere.download.requests") let logger = SDKLogger(category: "AlamofireDownloadService") // MARK: - Services /// Extraction service for handling archive extraction - let extractionService: ModelExtractionServiceProtocol + let extractionService: ExtractionServiceProtocol // MARK: - Initialization public init( configuration: DownloadConfiguration = DownloadConfiguration(), - extractionService: ModelExtractionServiceProtocol = DefaultModelExtractionService() + extractionService: ExtractionServiceProtocol = DefaultExtractionService() ) { self.extractionService = extractionService @@ -67,9 +68,14 @@ public class AlamofireDownloadService: @unchecked Sendable { } public func cancelDownload(taskId: String) { - if let downloadRequest = activeDownloadRequests[taskId] { - downloadRequest.cancel() + let downloadRequest: DownloadRequest? = requestsQueue.sync { + guard let request = activeDownloadRequests[taskId] else { return nil } activeDownloadRequests.removeValue(forKey: taskId) + return request + } + + if let downloadRequest { + downloadRequest.cancel() // Notify C++ bridge Task { @@ -85,7 +91,7 @@ public class AlamofireDownloadService: @unchecked Sendable { /// Pause all active downloads public func pauseAll() { - activeDownloadRequests.values.forEach { $0.suspend() } + requestsQueue.sync { activeDownloadRequests.values }.forEach { $0.suspend() } Task { try? await CppBridge.Download.shared.pauseAll() } @@ -94,7 +100,7 @@ public class AlamofireDownloadService: @unchecked Sendable { /// Resume all paused downloads public func resumeAll() { - activeDownloadRequests.values.forEach { $0.resume() } + requestsQueue.sync { activeDownloadRequests.values }.forEach { $0.resume() } Task { try? await CppBridge.Download.shared.resumeAll() } @@ -134,13 +140,13 @@ public class AlamofireDownloadService: @unchecked Sendable { let (progressStream, progressContinuation) = AsyncStream.makeStream() // Determine if we need extraction - // First check artifact type, then infer from URL if not explicitly set + // First check artifact type, then infer from URL via C++ rac_download_requires_extraction() var requiresExtraction = model.artifactType.requiresExtraction // If artifact type doesn't require extraction, check if URL indicates an archive - // This is a safeguard for models registered without explicit artifact type - if !requiresExtraction, let archiveType = ArchiveType.from(url: downloadURL) { - logger.info("URL indicates archive type (\(archiveType.rawValue)) but artifact type doesn't require extraction. Inferring extraction needed.") + // Uses C++ archive detection (handles .tar.gz, .tar.bz2, .zip, etc.) + if !requiresExtraction, CppBridge.Download.downloadRequiresExtraction(url: downloadURL) { + logger.info("URL indicates archive but artifact type doesn't require extraction. Inferring extraction needed.") requiresExtraction = true } @@ -167,7 +173,7 @@ public class AlamofireDownloadService: @unchecked Sendable { result: Task { defer { progressContinuation.finish() - self.activeDownloadRequests.removeValue(forKey: taskId) + self.requestsQueue.sync { self.activeDownloadRequests.removeValue(forKey: taskId) } } do { @@ -219,7 +225,7 @@ public class AlamofireDownloadService: @unchecked Sendable { result: Task { defer { progressContinuation.finish() - self.activeDownloadRequests.removeValue(forKey: taskId) + self.requestsQueue.sync { self.activeDownloadRequests.removeValue(forKey: taskId) } } do { @@ -331,42 +337,26 @@ public class AlamofireDownloadService: @unchecked Sendable { return finalModelPath } - /// Determine the download destination based on extraction requirements + /// Determine the download destination using C++ path utilities. + /// C++ handles archive detection, temp path generation, and model folder resolution. private func determineDownloadDestination( for model: ModelInfo, modelFolderURL: URL, requiresExtraction: Bool ) -> URL { - if requiresExtraction { - // Download to temp location for archives - // Get archive extension - use the one from artifact type or infer from URL - let archiveExt = getArchiveExtensionFromModelOrURL(model) - - // Note: URL.appendingPathExtension doesn't work well with multi-part extensions like "tar.gz" - // So we construct the filename with extension directly - let filename = "\(model.id)_\(UUID().uuidString).\(archiveExt)" - return FileManager.default.temporaryDirectory.appendingPathComponent(filename) - } else { - // Download directly to model folder - return modelFolderURL.appendingPathComponent("\(model.id).\(model.format.rawValue)") - } - } - - /// Get archive extension from model's artifact type or infer from download URL - private func getArchiveExtensionFromModelOrURL(_ model: ModelInfo) -> String { - // First try to get from artifact type - if case .archive(let archiveType, _, _) = model.artifactType { - return archiveType.fileExtension - } - - // If not an explicit archive type, try to infer from download URL - if let url = model.downloadURL, - let archiveType = ArchiveType.from(url: url) { - return archiveType.fileExtension + // Try C++ destination computation first + if let downloadURL = model.downloadURL, + let result = CppBridge.Download.computeDownloadDestination( + modelId: model.id, + downloadURL: downloadURL, + framework: model.framework, + format: model.format + ) { + return result.path } - // Default to archive (unknown type) - return "archive" + // Fallback: download directly to model folder + return modelFolderURL.appendingPathComponent("\(model.id).\(model.format.rawValue)") } /// Log download start information @@ -437,6 +427,16 @@ public class AlamofireDownloadService: @unchecked Sendable { ]) } + // MARK: - Thread-Safe Request Management + + func storeDownloadRequest(_ request: DownloadRequest, forKey key: String) { + requestsQueue.sync { activeDownloadRequests[key] = request } + } + + func removeDownloadRequest(forKey key: String) { + requestsQueue.sync { _ = activeDownloadRequests.removeValue(forKey: key) } + } + // MARK: - Helper Methods func mapAlamofireError(_ error: AFError) -> SDKError { diff --git a/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/Download/Services/ExtractionService.swift b/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/Download/Services/ExtractionService.swift index 707c65088..a5d1e1db1 100644 --- a/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/Download/Services/ExtractionService.swift +++ b/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/Download/Services/ExtractionService.swift @@ -3,10 +3,11 @@ // RunAnywhere SDK // // Centralized service for extracting model archives. -// Uses pure Swift extraction via SWCompression (no native C library dependency). +// Uses native C++ extraction via libarchive (rac_extract_archive). // Located in Download as it's part of the download post-processing pipeline. // +import CRACommons import Foundation // MARK: - Extraction Result @@ -42,20 +43,45 @@ public protocol ExtractionServiceProtocol: Sendable { /// - archiveURL: URL to the downloaded archive /// - destinationURL: Directory to extract to /// - artifactType: The model's artifact type (determines extraction method) + /// - framework: Inference framework (used for post-extraction model path finding) + /// - format: Model format (used for post-extraction model path finding) /// - progressHandler: Optional callback for extraction progress (0.0 to 1.0) /// - Returns: Result containing the path to the extracted model func extract( archiveURL: URL, to destinationURL: URL, artifactType: ModelArtifactType, + framework: InferenceFramework, + format: ModelFormat, progressHandler: ((Double) -> Void)? ) async throws -> ExtractionResult } +// MARK: - Protocol Extension for Backward Compatibility + +extension ExtractionServiceProtocol { + /// Convenience method without framework/format (defaults to .unknown) + func extract( + archiveURL: URL, + to destinationURL: URL, + artifactType: ModelArtifactType, + progressHandler: ((Double) -> Void)? + ) async throws -> ExtractionResult { + return try await extract( + archiveURL: archiveURL, + to: destinationURL, + artifactType: artifactType, + framework: .unknown, + format: .unknown, + progressHandler: progressHandler + ) + } +} + // MARK: - Default Extraction Service /// Default implementation of the model extraction service -/// Uses pure Swift extraction via SWCompression for all archive types +/// Uses native C++ extraction via libarchive for all archive types public final class DefaultExtractionService: ExtractionServiceProtocol, @unchecked Sendable { private let logger = SDKLogger(category: "ExtractionService") @@ -65,18 +91,19 @@ public final class DefaultExtractionService: ExtractionServiceProtocol, @uncheck archiveURL: URL, to destinationURL: URL, artifactType: ModelArtifactType, + framework: InferenceFramework, + format: ModelFormat, progressHandler: ((Double) -> Void)? ) async throws -> ExtractionResult { let startTime = Date() - guard case .archive(let archiveType, let structure, _) = artifactType else { + guard case .archive(_, let structure, _) = artifactType else { throw SDKError.download(.extractionFailed, "Artifact type does not require extraction") } logger.info("Starting extraction", metadata: [ "archiveURL": archiveURL.path, - "destination": destinationURL.path, - "archiveType": archiveType.rawValue + "destination": destinationURL.path ]) // Ensure destination exists @@ -85,30 +112,35 @@ public final class DefaultExtractionService: ExtractionServiceProtocol, @uncheck // Report starting progressHandler?(0.0) - // Perform extraction based on archive type using pure Swift (SWCompression) - switch archiveType { - case .zip: - try ArchiveUtility.extractZipArchive(from: archiveURL, to: destinationURL, progressHandler: progressHandler) - case .tarBz2: - try ArchiveUtility.extractTarBz2Archive(from: archiveURL, to: destinationURL, progressHandler: progressHandler) - case .tarGz: - try ArchiveUtility.extractTarGzArchive(from: archiveURL, to: destinationURL, progressHandler: progressHandler) - case .tarXz: - try ArchiveUtility.extractTarXzArchive(from: archiveURL, to: destinationURL, progressHandler: progressHandler) + // Use native C++ extraction (libarchive) — auto-detects format from file contents + let result = rac_extract_archive( + archiveURL.path, + destinationURL.path, + nil, // no progress callback needed (we report 0.0 and 1.0) + nil // no user data + ) + + guard result == RAC_SUCCESS else { + throw SDKError.download(.extractionFailed, "Native extraction failed with code: \(result)") } - // Find the actual model path based on structure - let modelPath = findModelPath(in: destinationURL, structure: structure) + // Find the actual model path using C++ rac_find_model_path_after_extraction() + // This consolidates the previously duplicated findModelPath/findNestedDirectory/findSingleModelFile logic + let modelPath = CppBridge.Download.findModelPathAfterExtraction( + extractedDir: destinationURL, + structure: structure, + framework: framework, + format: format + ) ?? destinationURL - // Calculate extracted size and file count - let (extractedSize, fileCount) = calculateExtractionStats(at: destinationURL) + // Calculate extracted size using C++ file manager (single source of truth) + let extractedSize = CppBridge.FileManager.calculateDirectorySize(at: destinationURL) let duration = Date().timeIntervalSince(startTime) logger.info("Extraction completed", metadata: [ "modelPath": modelPath.path, "extractedSize": extractedSize, - "fileCount": fileCount, "durationSeconds": duration ]) @@ -117,114 +149,9 @@ public final class DefaultExtractionService: ExtractionServiceProtocol, @uncheck return ExtractionResult( modelPath: modelPath, extractedSize: extractedSize, - fileCount: fileCount, + fileCount: 0, durationSeconds: duration ) } - // MARK: - Helper Methods - - /// Find the actual model path based on archive structure - private func findModelPath(in extractedDir: URL, structure: ArchiveStructure) -> URL { - switch structure { - case .singleFileNested: - // Look for a single model file, possibly in a subdirectory - return findSingleModelFile(in: extractedDir) ?? extractedDir - - case .nestedDirectory: - // Common pattern: archive contains one subdirectory with all the files - // e.g., sherpa-onnx archives extract to: extractedDir/vits-xxx/ - return findNestedDirectory(in: extractedDir) - - case .directoryBased, .unknown: - // Return the extraction directory itself - return extractedDir - } - } - - /// Find nested directory (for archives that extract to a subdirectory) - private func findNestedDirectory(in extractedDir: URL) -> URL { - let fm = FileManager.default - - guard let contents = try? fm.contentsOfDirectory(at: extractedDir, includingPropertiesForKeys: [.isDirectoryKey]) else { - return extractedDir - } - - // Filter out hidden files and macOS resource forks - let visibleContents = contents.filter { - !$0.lastPathComponent.hasPrefix(".") && !$0.lastPathComponent.hasPrefix("._") - } - - // If there's a single visible subdirectory, return it - if visibleContents.count == 1, let first = visibleContents.first { - var isDir: ObjCBool = false - if fm.fileExists(atPath: first.path, isDirectory: &isDir), isDir.boolValue { - return first - } - } - - return extractedDir - } - - /// Find a single model file in a directory (recursive, up to 2 levels) - private func findSingleModelFile(in directory: URL, depth: Int = 0) -> URL? { - guard depth < 2 else { return nil } - - let fm = FileManager.default - guard let contents = try? fm.contentsOfDirectory(at: directory, includingPropertiesForKeys: [.isDirectoryKey]) else { - return nil - } - - // Known model file extensions - let modelExtensions = Set(["gguf", "onnx", "ort", "bin"]) - - // Look for model files at this level - for item in contents where modelExtensions.contains(item.pathExtension.lowercased()) { - return item - } - - // Recursively check subdirectories - for item in contents { - var isDir: ObjCBool = false - if fm.fileExists(atPath: item.path, isDirectory: &isDir), isDir.boolValue { - if let found = findSingleModelFile(in: item, depth: depth + 1) { - return found - } - } - } - - return nil - } - - /// Calculate size and file count for extracted content - private func calculateExtractionStats(at url: URL) -> (Int64, Int) { - let fm = FileManager.default - guard let enumerator = fm.enumerator( - at: url, - includingPropertiesForKeys: [.fileSizeKey, .isRegularFileKey], - options: [] - ) else { - return (0, 0) - } - - var totalSize: Int64 = 0 - var fileCount = 0 - - for case let fileURL as URL in enumerator { - if let values = try? fileURL.resourceValues(forKeys: [.fileSizeKey, .isRegularFileKey]) { - if values.isRegularFile == true { - fileCount += 1 - totalSize += Int64(values.fileSize ?? 0) - } - } - } - - return (totalSize, fileCount) - } } - -// MARK: - Type Aliases for backward compatibility - -/// Type alias for backward compatibility -public typealias ModelExtractionServiceProtocol = ExtractionServiceProtocol -public typealias DefaultModelExtractionService = DefaultExtractionService diff --git a/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/Download/Utilities/ArchiveUtility.swift b/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/Download/Utilities/ArchiveUtility.swift deleted file mode 100644 index 3123541de..000000000 --- a/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/Download/Utilities/ArchiveUtility.swift +++ /dev/null @@ -1,546 +0,0 @@ -import Compression -import Foundation -import SWCompression -import ZIPFoundation - -/// Utility for handling archive operations -/// Uses Apple's native Compression framework for gzip (fast) and SWCompression for bzip2/xz (pure Swift) -/// Works on all Apple platforms (iOS, macOS, tvOS, watchOS) -public final class ArchiveUtility { - - private static let logger = SDKLogger(category: "ArchiveUtility") - - private init() {} - - // MARK: - Public Extraction Methods - - /// Extract a tar.bz2 archive to a destination directory - /// Uses SWCompression for pure Swift bzip2 decompression (slower - Apple doesn't support bzip2 natively) - /// - Parameters: - /// - sourceURL: The URL of the tar.bz2 file to extract - /// - destinationURL: The destination directory URL - /// - progressHandler: Optional callback for extraction progress (0.0 to 1.0) - /// - Throws: SDKError if extraction fails - public static func extractTarBz2Archive( - from sourceURL: URL, - to destinationURL: URL, - progressHandler: ((Double) -> Void)? = nil - ) throws { - let overallStart = Date() - logger.info("🗜️ [EXTRACTION START] tar.bz2 archive: \(sourceURL.lastPathComponent)") - logger.warning("⚠️ bzip2 uses pure Swift decompression (slower than native gzip)") - progressHandler?(0.0) - - // Step 1: Read compressed data - let readStart = Date() - let compressedData = try Data(contentsOf: sourceURL) - let readTime = Date().timeIntervalSince(readStart) - logger.info("📖 [READ] \(formatBytes(compressedData.count)) in \(String(format: "%.2f", readTime))s") - progressHandler?(0.05) - - // Step 2: Decompress bzip2 using pure Swift (no native support from Apple) - let decompressStart = Date() - logger.info("🐢 [DECOMPRESS] Starting pure Swift bzip2 decompression (this may take a while)...") - let tarData: Data - do { - tarData = try BZip2.decompress(data: compressedData) - } catch { - logger.error("BZip2 decompression failed: \(error)") - throw SDKError.download(.extractionFailed, "BZip2 decompression failed: \(error.localizedDescription)", underlying: error) - } - let decompressTime = Date().timeIntervalSince(decompressStart) - logger.info("✅ [DECOMPRESS] \(formatBytes(compressedData.count)) → \(formatBytes(tarData.count)) in \(String(format: "%.2f", decompressTime))s") - progressHandler?(0.4) - - // Step 3: Extract tar archive - let extractStart = Date() - logger.info("📦 [TAR EXTRACT] Extracting files...") - try extractTarData(tarData, to: destinationURL, progressHandler: { progress in - progressHandler?(0.4 + progress * 0.6) - }) - let extractTime = Date().timeIntervalSince(extractStart) - logger.info("✅ [TAR EXTRACT] Completed in \(String(format: "%.2f", extractTime))s") - - let totalTime = Date().timeIntervalSince(overallStart) - let timingInfo = """ - read: \(String(format: "%.2f", readTime))s, \ - decompress: \(String(format: "%.2f", decompressTime))s, \ - extract: \(String(format: "%.2f", extractTime))s - """ - logger.info("🎉 [EXTRACTION COMPLETE] Total: \(String(format: "%.2f", totalTime))s (\(timingInfo))") - progressHandler?(1.0) - } - - /// Extract a tar.gz archive to a destination directory - /// Uses Apple's native Compression framework for fast gzip decompression - /// - Parameters: - /// - sourceURL: The URL of the tar.gz file to extract - /// - destinationURL: The destination directory URL - /// - progressHandler: Optional callback for extraction progress (0.0 to 1.0) - /// - Throws: SDKError if extraction fails - public static func extractTarGzArchive( - from sourceURL: URL, - to destinationURL: URL, - progressHandler: ((Double) -> Void)? = nil - ) throws { - let overallStart = Date() - logger.info("🗜️ [EXTRACTION START] tar.gz archive: \(sourceURL.lastPathComponent)") - progressHandler?(0.0) - - // Step 1: Read compressed data - let readStart = Date() - let compressedData = try Data(contentsOf: sourceURL) - let readTime = Date().timeIntervalSince(readStart) - logger.info("📖 [READ] \(formatBytes(compressedData.count)) in \(String(format: "%.2f", readTime))s") - progressHandler?(0.05) - - // Step 2: Decompress gzip using NATIVE Compression framework (10-20x faster than pure Swift) - let decompressStart = Date() - logger.info("⚡ [DECOMPRESS] Starting native gzip decompression...") - let tarData: Data - do { - tarData = try decompressGzipNative(compressedData) - } catch { - logger.error("Native gzip decompression failed: \(error), falling back to pure Swift") - // Fallback to SWCompression if native fails - do { - tarData = try GzipArchive.unarchive(archive: compressedData) - } catch { - logger.error("Gzip decompression failed: \(error)") - throw SDKError.download(.extractionFailed, "Gzip decompression failed: \(error.localizedDescription)", underlying: error) - } - } - let decompressTime = Date().timeIntervalSince(decompressStart) - logger.info("✅ [DECOMPRESS] \(formatBytes(compressedData.count)) → \(formatBytes(tarData.count)) in \(String(format: "%.2f", decompressTime))s") - progressHandler?(0.3) - - // Step 3: Extract tar archive - let extractStart = Date() - logger.info("📦 [TAR EXTRACT] Extracting files...") - try extractTarData(tarData, to: destinationURL, progressHandler: { progress in - progressHandler?(0.3 + progress * 0.7) - }) - let extractTime = Date().timeIntervalSince(extractStart) - logger.info("✅ [TAR EXTRACT] Completed in \(String(format: "%.2f", extractTime))s") - - let totalTime = Date().timeIntervalSince(overallStart) - let gzTimingInfo = """ - read: \(String(format: "%.2f", readTime))s, \ - decompress: \(String(format: "%.2f", decompressTime))s, \ - extract: \(String(format: "%.2f", extractTime))s - """ - logger.info("🎉 [EXTRACTION COMPLETE] Total: \(String(format: "%.2f", totalTime))s (\(gzTimingInfo))") - progressHandler?(1.0) - } - - /// Decompress gzip data using Apple's native Compression framework - /// Uses streaming decompression (compression_stream_process) to avoid huge pre-allocations - private static func decompressGzipNative(_ compressedData: Data) throws -> Data { - guard compressedData.count >= 10 else { - throw SDKError.download(.extractionFailed, "Invalid gzip data: too short") - } - - guard compressedData[0] == 0x1f && compressedData[1] == 0x8b else { - throw SDKError.download(.extractionFailed, "Invalid gzip magic number") - } - - guard compressedData[2] == 8 else { - throw SDKError.download(.extractionFailed, "Unsupported gzip compression method") - } - - let flags = compressedData[3] - var headerSize = 10 - - if flags & 0x04 != 0 { // FEXTRA - guard compressedData.count >= headerSize + 2 else { - throw SDKError.download(.extractionFailed, "Invalid gzip extra field") - } - let extraLen = Int(compressedData[headerSize]) | (Int(compressedData[headerSize + 1]) << 8) - headerSize += 2 + extraLen - } - - if flags & 0x08 != 0 { // FNAME - while headerSize < compressedData.count && compressedData[headerSize] != 0 { - headerSize += 1 - } - headerSize += 1 - } - - if flags & 0x10 != 0 { // FCOMMENT - while headerSize < compressedData.count && compressedData[headerSize] != 0 { - headerSize += 1 - } - headerSize += 1 - } - - if flags & 0x02 != 0 { // FHCRC - headerSize += 2 - } - - guard compressedData.count > headerSize + 8 else { - throw SDKError.download(.extractionFailed, "Invalid gzip structure") - } - - let deflateStart = headerSize - let deflateEnd = compressedData.count - 8 - - return try decompressDeflateStreaming(compressedData, range: deflateStart..) throws -> Data { - var stream = compression_stream() - guard compression_stream_init(&stream, COMPRESSION_STREAM_DECODE, COMPRESSION_ZLIB) == COMPRESSION_STATUS_OK else { - throw SDKError.download(.extractionFailed, "Failed to initialize decompression stream") - } - defer { compression_stream_destroy(&stream) } - - let outputChunkSize = 256 * 1024 // 256 KB - let outputBuffer = UnsafeMutablePointer.allocate(capacity: outputChunkSize) - defer { outputBuffer.deallocate() } - - var result = Data() - let deflateSize = range.count - result.reserveCapacity(min(deflateSize * 2, 1024 * 1024 * 1024)) - - try data.withUnsafeBytes { rawBuffer in - guard let base = rawBuffer.baseAddress else { - throw SDKError.download(.extractionFailed, "Cannot access compressed data buffer") - } - let srcBase = base.advanced(by: range.lowerBound).assumingMemoryBound(to: UInt8.self) - - stream.src_ptr = srcBase - stream.src_size = deflateSize - - var status: compression_status - repeat { - stream.dst_ptr = outputBuffer - stream.dst_size = outputChunkSize - - status = compression_stream_process(&stream, COMPRESSION_STREAM_FINALIZE) - - let bytesProduced = outputChunkSize - stream.dst_size - if bytesProduced > 0 { - result.append(outputBuffer, count: bytesProduced) - } - } while status == COMPRESSION_STATUS_OK - - guard status == COMPRESSION_STATUS_END else { - throw SDKError.download(.extractionFailed, "Streaming decompression failed (status \(status))") - } - } - - return result - } - - /// Extract a tar.xz archive to a destination directory - /// Uses SWCompression for pure Swift LZMA/XZ decompression and tar extraction - /// - Parameters: - /// - sourceURL: The URL of the tar.xz file to extract - /// - destinationURL: The destination directory URL - /// - progressHandler: Optional callback for extraction progress (0.0 to 1.0) - /// - Throws: SDKError if extraction fails - public static func extractTarXzArchive( - from sourceURL: URL, - to destinationURL: URL, - progressHandler: ((Double) -> Void)? = nil - ) throws { - logger.info("Extracting tar.xz archive: \(sourceURL.lastPathComponent)") - progressHandler?(0.0) - - // Read compressed data - let compressedData = try Data(contentsOf: sourceURL) - logger.debug("Read \(formatBytes(compressedData.count)) from archive") - progressHandler?(0.1) - - // Step 1: Decompress XZ using SWCompression - logger.debug("Decompressing XZ...") - let tarData: Data - do { - tarData = try XZArchive.unarchive(archive: compressedData) - } catch { - logger.error("XZ decompression failed: \(error)") - throw SDKError.download(.extractionFailed, "XZ decompression failed: \(error.localizedDescription)", underlying: error) - } - logger.debug("Decompressed to \(formatBytes(tarData.count)) of tar data") - progressHandler?(0.4) - - // Step 2: Extract tar archive using SWCompression - try extractTarData(tarData, to: destinationURL, progressHandler: { progress in - // Map tar extraction progress (0.4 to 1.0) - progressHandler?(0.4 + progress * 0.6) - }) - - logger.info("tar.xz extraction completed to: \(destinationURL.lastPathComponent)") - progressHandler?(1.0) - } - - /// Extract a zip archive to a destination directory - /// Uses ZIPFoundation for zip extraction - /// - Parameters: - /// - sourceURL: The URL of the zip file to extract - /// - destinationURL: The destination directory URL - /// - progressHandler: Optional callback for extraction progress (0.0 to 1.0) - /// - Throws: SDKError if extraction fails - public static func extractZipArchive( - from sourceURL: URL, - to destinationURL: URL, - progressHandler: ((Double) -> Void)? = nil - ) throws { - logger.info("Extracting zip archive: \(sourceURL.lastPathComponent)") - progressHandler?(0.0) - - do { - let fileManager = FileManager.default - - // Clean up any existing partial extraction to avoid "file already exists" errors - // This handles cases where a previous extraction was interrupted - if fileManager.fileExists(atPath: destinationURL.path) { - logger.info("Removing existing destination directory for clean extraction: \(destinationURL.lastPathComponent)") - try fileManager.removeItem(at: destinationURL) - } - - // Ensure destination directory exists - try fileManager.createDirectory( - at: destinationURL, - withIntermediateDirectories: true, - attributes: nil - ) - - // Use ZIPFoundation to extract - try fileManager.unzipItem( - at: sourceURL, - to: destinationURL, - skipCRC32: true, - progress: nil, - pathEncoding: .utf8 - ) - - logger.info("zip extraction completed to: \(destinationURL.lastPathComponent)") - progressHandler?(1.0) - } catch { - logger.error("Zip extraction failed: \(error)") - throw SDKError.download(.extractionFailed, "Failed to extract zip archive: \(error.localizedDescription)", underlying: error) - } - } - - /// Extract any supported archive format based on file extension - /// - Parameters: - /// - sourceURL: The archive file URL - /// - destinationURL: The destination directory URL - /// - progressHandler: Optional callback for extraction progress (0.0 to 1.0) - /// - Throws: SDKError if extraction fails or format is unsupported - public static func extractArchive( - from sourceURL: URL, - to destinationURL: URL, - progressHandler: ((Double) -> Void)? = nil - ) throws { - let archiveType = detectArchiveType(from: sourceURL) - - switch archiveType { - case .tarBz2: - try extractTarBz2Archive(from: sourceURL, to: destinationURL, progressHandler: progressHandler) - case .tarGz: - try extractTarGzArchive(from: sourceURL, to: destinationURL, progressHandler: progressHandler) - case .tarXz: - try extractTarXzArchive(from: sourceURL, to: destinationURL, progressHandler: progressHandler) - case .zip: - try extractZipArchive(from: sourceURL, to: destinationURL, progressHandler: progressHandler) - case .unknown: - throw SDKError.download(.unsupportedArchive, "Unsupported archive format: \(sourceURL.pathExtension)") - } - } - - // MARK: - Archive Type Detection - - /// Supported archive types - public enum ArchiveFormat { - case tarBz2 - case tarGz - case tarXz - case zip - case unknown - } - - /// Detect archive type from URL - public static func detectArchiveType(from url: URL) -> ArchiveFormat { - let path = url.path.lowercased() - - if path.hasSuffix(".tar.bz2") || path.hasSuffix(".tbz2") || path.hasSuffix(".tbz") { - return .tarBz2 - } else if path.hasSuffix(".tar.gz") || path.hasSuffix(".tgz") { - return .tarGz - } else if path.hasSuffix(".tar.xz") || path.hasSuffix(".txz") { - return .tarXz - } else if path.hasSuffix(".zip") { - return .zip - } - - return .unknown - } - - /// Check if a URL points to a tar.bz2 archive - public static func isTarBz2Archive(_ url: URL) -> Bool { - detectArchiveType(from: url) == .tarBz2 - } - - /// Check if a URL points to a tar.gz archive - public static func isTarGzArchive(_ url: URL) -> Bool { - detectArchiveType(from: url) == .tarGz - } - - /// Check if a URL points to a zip archive - public static func isZipArchive(_ url: URL) -> Bool { - detectArchiveType(from: url) == .zip - } - - /// Check if a URL points to any supported archive format - public static func isSupportedArchive(_ url: URL) -> Bool { - detectArchiveType(from: url) != .unknown - } - - // MARK: - Zip Creation - - /// Create a zip archive from a source directory - /// - Parameters: - /// - sourceURL: The source directory URL - /// - destinationURL: The destination zip file URL - /// - Throws: SDKError if compression fails - public static func createZipArchive( - from sourceURL: URL, - to destinationURL: URL - ) throws { - do { - try FileManager.default.zipItem( - at: sourceURL, - to: destinationURL, - shouldKeepParent: false, - compressionMethod: .deflate, - progress: nil - ) - logger.info("Created zip archive at: \(destinationURL.lastPathComponent)") - } catch { - logger.error("Failed to create zip archive: \(error)") - throw SDKError.download(.extractionFailed, "Failed to create archive: \(error.localizedDescription)", underlying: error) - } - } - - // MARK: - Private Helpers - - /// Extract tar data to destination directory using SWCompression - private static func extractTarData( - _ tarData: Data, - to destinationURL: URL, - progressHandler: ((Double) -> Void)? = nil - ) throws { - // Step 1: Parse tar entries - let parseStart = Date() - logger.info(" 📋 [TAR PARSE] Parsing tar entries from \(formatBytes(tarData.count))...") - - // Ensure destination directory exists - try FileManager.default.createDirectory(at: destinationURL, withIntermediateDirectories: true) - - // Parse tar entries using SWCompression - let entries: [TarEntry] - do { - entries = try TarContainer.open(container: tarData) - } catch { - logger.error("Tar parsing failed: \(error)") - throw SDKError.download(.extractionFailed, "Tar parsing failed: \(error.localizedDescription)", underlying: error) - } - let parseTime = Date().timeIntervalSince(parseStart) - logger.info(" ✅ [TAR PARSE] Found \(entries.count) entries in \(String(format: "%.2f", parseTime))s") - - // Step 2: Write files to disk - let writeStart = Date() - logger.info(" 💾 [FILE WRITE] Writing files to disk...") - - var extractedCount = 0 - var extractedFiles = 0 - var extractedDirs = 0 - var totalBytesWritten: Int64 = 0 - - for entry in entries { - let entryPath = entry.info.name - - // Skip empty names or entries starting with ._ (macOS resource forks) - guard !entryPath.isEmpty, !entryPath.hasPrefix("._") else { - continue - } - - let fullPath = destinationURL.appendingPathComponent(entryPath) - - switch entry.info.type { - case .directory: - try FileManager.default.createDirectory(at: fullPath, withIntermediateDirectories: true) - extractedDirs += 1 - - case .regular: - // Create parent directory if needed - let parentDir = fullPath.deletingLastPathComponent() - try FileManager.default.createDirectory(at: parentDir, withIntermediateDirectories: true) - - // Write file data - if let data = entry.data { - try data.write(to: fullPath) - extractedFiles += 1 - totalBytesWritten += Int64(data.count) - } - - case .symbolicLink: - // Handle symbolic links if needed - let linkName = entry.info.linkName - if !linkName.isEmpty { - let parentDir = fullPath.deletingLastPathComponent() - try FileManager.default.createDirectory(at: parentDir, withIntermediateDirectories: true) - try? FileManager.default.createSymbolicLink(atPath: fullPath.path, withDestinationPath: linkName) - } - - default: - // Skip other types (block devices, character devices, etc.) - break - } - - extractedCount += 1 - progressHandler?(Double(extractedCount) / Double(entries.count)) - } - - let writeTime = Date().timeIntervalSince(writeStart) - let bytesStr = formatBytes(Int(totalBytesWritten)) - let timeStr = String(format: "%.2f", writeTime) - logger.info(" ✅ [FILE WRITE] Wrote \(extractedFiles) files (\(bytesStr)) and \(extractedDirs) dirs in \(timeStr)s") - } - - /// Format bytes for logging - private static func formatBytes(_ bytes: Int) -> String { - if bytes < 1024 { - return "\(bytes) B" - } else if bytes < 1024 * 1024 { - return String(format: "%.1f KB", Double(bytes) / 1024) - } else if bytes < 1024 * 1024 * 1024 { - return String(format: "%.1f MB", Double(bytes) / (1024 * 1024)) - } else { - return String(format: "%.2f GB", Double(bytes) / (1024 * 1024 * 1024)) - } - } -} - -// MARK: - FileManager Extension for Archive Operations - -public extension FileManager { - - /// Extract any supported archive format - /// - Parameters: - /// - sourceURL: The archive file URL - /// - destinationURL: The destination directory URL - /// - progressHandler: Optional callback for extraction progress (0.0 to 1.0) - /// - Throws: SDKError if extraction fails or format is unsupported - func extractArchive( - from sourceURL: URL, - to destinationURL: URL, - progressHandler: ((Double) -> Void)? = nil - ) throws { - try ArchiveUtility.extractArchive(from: sourceURL, to: destinationURL, progressHandler: progressHandler) - } -} diff --git a/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/FileManagement/Services/SimplifiedFileManager.swift b/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/FileManagement/Services/SimplifiedFileManager.swift index 8db6138ac..dda6f56de 100644 --- a/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/FileManagement/Services/SimplifiedFileManager.swift +++ b/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/FileManagement/Services/SimplifiedFileManager.swift @@ -43,10 +43,9 @@ public class SimplifiedFileManager { } private func createDirectoryStructure() throws { - _ = try baseFolder.createSubfolderIfNeeded(withName: "Models") - _ = try baseFolder.createSubfolderIfNeeded(withName: "Cache") - _ = try baseFolder.createSubfolderIfNeeded(withName: "Temp") - _ = try baseFolder.createSubfolderIfNeeded(withName: "Downloads") + guard CppBridge.FileManager.createDirectoryStructure() else { + throw SDKError.fileManagement(.directoryCreationFailed, "Failed to create directory structure via C++ bridge") + } } // MARK: - Model Folder Access @@ -59,10 +58,7 @@ public class SimplifiedFileManager { /// Check if a model folder exists and contains files public func modelFolderExists(modelId: String, framework: InferenceFramework) -> Bool { - guard let folderURL = try? CppBridge.ModelPaths.getModelFolder(modelId: modelId, framework: framework) else { - return false - } - return folderExistsAndHasContents(at: folderURL) + return CppBridge.FileManager.modelFolderHasContents(modelId: modelId, framework: framework) } /// Get the model folder URL (without creating it) @@ -72,14 +68,10 @@ public class SimplifiedFileManager { /// Delete a model folder and all its contents public func deleteModel(modelId: String, framework: InferenceFramework) throws { - let folderURL = try CppBridge.ModelPaths.getModelFolder(modelId: modelId, framework: framework) - - if FileManager.default.fileExists(atPath: folderURL.path) { - try FileManager.default.removeItem(at: folderURL) - logger.info("Deleted model: \(modelId) from \(framework.rawValue)") - } else { - logger.info("\(modelId) does NOT exist in \(framework.rawValue)") + guard CppBridge.FileManager.deleteModel(modelId: modelId, framework: framework) else { + throw SDKError.fileManagement(.deleteFailed, "Failed to delete model: \(modelId)") } + logger.info("Deleted model: \(modelId) from \(framework.rawValue)") } // MARK: - Model Discovery @@ -128,15 +120,7 @@ public class SimplifiedFileManager { /// Check if a specific model is downloaded @MainActor public func isModelDownloaded(modelId: String, framework: InferenceFramework) -> Bool { - // Check if the folder exists and has contents - guard let folderURL = try? CppBridge.ModelPaths.getModelFolder(modelId: modelId, framework: framework), - folderExistsAndHasContents(at: folderURL) else { - return false - } - - // Folder exists with contents - model is downloaded - // Module-specific validation can be done by the service when loading - return true + return CppBridge.FileManager.modelFolderHasContents(modelId: modelId, framework: framework) } // MARK: - Download Management @@ -165,9 +149,8 @@ public class SimplifiedFileManager { } public func clearCache() throws { - let cacheFolder = try baseFolder.subfolder(named: "Cache") - for file in cacheFolder.files { - try file.delete() + guard CppBridge.FileManager.clearCache() else { + throw SDKError.fileManagement(.deleteFailed, "Failed to clear cache") } logger.info("Cleared cache") } @@ -175,45 +158,16 @@ public class SimplifiedFileManager { // MARK: - Temp Files public func cleanTempFiles() throws { - let tempFolder = try baseFolder.subfolder(named: "Temp") - for file in tempFolder.files { - try file.delete() + guard CppBridge.FileManager.clearTemp() else { + throw SDKError.fileManagement(.deleteFailed, "Failed to clean temp files") } logger.info("Cleaned temp files") } // MARK: - Storage Info - public func getAvailableSpace() -> Int64 { - do { - let values = try URL(fileURLWithPath: baseFolder.path).resourceValues(forKeys: [.volumeAvailableCapacityForImportantUsageKey]) - return values.volumeAvailableCapacityForImportantUsage ?? 0 - } catch { - return 0 - } - } - - public func getDeviceStorageInfo() -> DeviceStorageInfo { - do { - let attributes = try FileManager.default.attributesOfFileSystem(forPath: NSHomeDirectory()) - let totalSpace = (attributes[.systemSize] as? Int64) ?? 0 - let freeSpace = (attributes[.systemFreeSize] as? Int64) ?? 0 - return DeviceStorageInfo(totalSpace: totalSpace, freeSpace: freeSpace, usedSpace: totalSpace - freeSpace) - } catch { - return DeviceStorageInfo(totalSpace: 0, freeSpace: 0, usedSpace: 0) - } - } - public func calculateDirectorySize(at url: URL) -> Int64 { - var totalSize: Int64 = 0 - if let enumerator = FileManager.default.enumerator(at: url, includingPropertiesForKeys: [.fileSizeKey], options: []) { - for case let fileURL as URL in enumerator { - if let fileSize = try? fileURL.resourceValues(forKeys: [.fileSizeKey]).fileSize { - totalSize += Int64(fileSize) - } - } - } - return totalSize + return CppBridge.FileManager.calculateDirectorySize(at: url) } public func getBaseDirectoryURL() -> URL { diff --git a/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/FileManagement/Utilities/FileOperationsUtilities.swift b/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/FileManagement/Utilities/FileOperationsUtilities.swift index 84125cc32..d8dab97d7 100644 --- a/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/FileManagement/Utilities/FileOperationsUtilities.swift +++ b/sdk/runanywhere-swift/Sources/RunAnywhere/Infrastructure/FileManagement/Utilities/FileOperationsUtilities.swift @@ -149,23 +149,6 @@ public struct FileOperationsUtilities { try FileManager.default.createDirectory(at: url, withIntermediateDirectories: withIntermediateDirectories, attributes: nil) } - /// Calculate the total size of a directory including all subdirectories - /// - Parameter url: The directory URL - /// - Returns: Total size in bytes - public static func calculateDirectorySize(at url: URL) -> Int64 { - var totalSize: Int64 = 0 - - if let enumerator = enumerateDirectory(at: url, includingPropertiesForKeys: [.fileSizeKey]) { - for case let fileURL as URL in enumerator { - if let fileSize = try? fileURL.resourceValues(forKeys: [.fileSizeKey]).fileSize { - totalSize += Int64(fileSize) - } - } - } - - return totalSize - } - // MARK: - File/Directory Removal /// Remove a file or directory at the specified URL diff --git a/sdk/runanywhere-web/packages/core/src/Infrastructure/ArchiveUtility.ts b/sdk/runanywhere-web/packages/core/src/Infrastructure/ArchiveUtility.ts index 51e9c46d8..509a812d3 100644 --- a/sdk/runanywhere-web/packages/core/src/Infrastructure/ArchiveUtility.ts +++ b/sdk/runanywhere-web/packages/core/src/Infrastructure/ArchiveUtility.ts @@ -1,9 +1,21 @@ /** - * Archive Utility - Tar.gz extraction for model archives + * Archive Utility - Tar.gz extraction for model archives (Web-specific) * - * Provides browser-native tar.gz extraction using DecompressionStream (gzip) - * and a minimal tar parser. This matches the Swift SDK approach where Piper TTS - * models are distributed as .tar.gz archives bundling model files + espeak-ng-data. + * Web-platform equivalent of `rac_extract_archive_native()` (libarchive) used + * by all native SDKs (Swift, Kotlin, React Native, Flutter). + * + * On native platforms, extraction goes through the shared C++ libarchive + * implementation which operates on filesystem paths. The Web SDK cannot use + * that path because: + * 1. The ONNX provider (sherpa-onnx.wasm) and RACommons (racommons-llamacpp.wasm) + * are separate WASM modules with isolated virtual filesystems and memory spaces. + * 2. `rac_extract_archive_native` operates on file paths — using it would require + * writing the archive to the llamacpp WASM FS, extracting, reading files back + * to JS, then writing them to the sherpa WASM FS (double copy, extra WASM load). + * 3. The browser already provides native gzip decompression via DecompressionStream. + * + * This implementation uses browser-native DecompressionStream (gzip) and a minimal + * tar parser to extract archives in-memory (Uint8Array → TarEntry[]). */ // --------------------------------------------------------------------------- diff --git a/sdk/runanywhere-web/packages/llamacpp/src/Foundation/PlatformAdapter.ts b/sdk/runanywhere-web/packages/llamacpp/src/Foundation/PlatformAdapter.ts index 9cef5e825..bae522196 100644 --- a/sdk/runanywhere-web/packages/llamacpp/src/Foundation/PlatformAdapter.ts +++ b/sdk/runanywhere-web/packages/llamacpp/src/Foundation/PlatformAdapter.ts @@ -33,7 +33,6 @@ interface RegisteredCallbacks { nowMs: number; getMemoryInfo: number; httpDownload: number; - extractArchive: number; } /** @@ -79,7 +78,6 @@ export class PlatformAdapter { nowMs: this.registerNowMs(m), getMemoryInfo: this.registerGetMemoryInfo(m), httpDownload: this.registerHttpDownload(m), - extractArchive: this.registerExtractArchive(m), }; // Write function pointers into the struct. @@ -103,7 +101,8 @@ export class PlatformAdapter { m.setValue(this.adapterPtr + offset, this.callbacks.httpDownload, '*'); offset += PTR_SIZE; // http_download_cancel: optional, set to 0 (null) m.setValue(this.adapterPtr + offset, 0, '*'); offset += PTR_SIZE; - m.setValue(this.adapterPtr + offset, this.callbacks.extractArchive, '*'); offset += PTR_SIZE; + // extract_archive: no-op (native libarchive compiled into WASM, bypasses platform adapter) + m.setValue(this.adapterPtr + offset, 0, '*'); offset += PTR_SIZE; // user_data: set to 0 (null) m.setValue(this.adapterPtr + offset, 0, '*'); @@ -372,19 +371,6 @@ export class PlatformAdapter { ); } - /** - * extract_archive: rac_result_t (*)(const char* archive_path, const char* dest_dir, - * progress_cb, void* cb_user_data, void* user_data) - * Note: 5 params in C - */ - private registerExtractArchive(m: LlamaCppModule): number { - return m.addFunction((_archivePtr: number, _destPtr: number, _progressCb: number, _cbUserData: number, _userData: number): number => { - // Archive extraction not yet implemented for WASM - logger.warning('Archive extraction not yet implemented for WASM'); - return -180; - }, 'iiiiii'); - } - // ----------------------------------------------------------------------- // Helpers // ----------------------------------------------------------------------- diff --git a/sdk/runanywhere-web/wasm/CMakeLists.txt b/sdk/runanywhere-web/wasm/CMakeLists.txt index b86aca4e3..1ef586acd 100644 --- a/sdk/runanywhere-web/wasm/CMakeLists.txt +++ b/sdk/runanywhere-web/wasm/CMakeLists.txt @@ -35,7 +35,7 @@ option(RAC_WASM_WEBGPU "Enable WebGPU acceleration for llama.cpp (requires JSPI- # C++ CONFIGURATION # ============================================================================= -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF)