Skip to content

Commit d874dda

Browse files
authored
Merge pull request #19 from qingfeng19491001/fix/tflite-init-no-gms
fix(ai): fallback to bundled TFLite runtime on non-GMS devices
2 parents 126f3dd + 409fbd0 commit d874dda

File tree

5 files changed

+80
-38
lines changed

5 files changed

+80
-38
lines changed

gradle/libs.versions.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ circleimageview = "3.1.0"
8484
simpleMvp = "1.0.2"
8585
dtoast = "1.1.5"
8686
preferenceKtx = "1.2.1"
87+
tensorflowLite = "2.14.0"
88+
tensorflowLiteSupport = "0.4.4"
8789

8890
[libraries]
8991
androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" }
@@ -175,6 +177,8 @@ circleimageview = { module = "de.hdodenhof:circleimageview", version.ref = "circ
175177
jaredrummler-simple-mvp = { module = "com.jaredrummler:simple-mvp", version.ref = "simpleMvp" }
176178
dtoast = { module = "com.github.Dovar66:DToast", version.ref = "dtoast" }
177179
androidx-preference-ktx = { module = "androidx.preference:preference-ktx", version.ref = "preferenceKtx" }
180+
tensorflow-lite = { module = "org.tensorflow:tensorflow-lite", version.ref = "tensorflowLite" }
181+
tensorflow-lite-support = { module = "org.tensorflow:tensorflow-lite-support", version.ref = "tensorflowLiteSupport" }
178182

179183
[plugins]
180184
androidApplication = { id = "com.android.application", version.ref = "agp" }

subs/ai/build.gradle.kts

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ android {
3030
}
3131

3232
compileOptions {
33-
sourceCompatibility = JavaVersion.VERSION_21
34-
targetCompatibility = JavaVersion.VERSION_21
33+
sourceCompatibility = JavaVersion.VERSION_17
34+
targetCompatibility = JavaVersion.VERSION_17
3535
}
3636
kotlinOptions {
37-
jvmTarget = "21"
37+
jvmTarget = "17"
3838
}
3939
androidResources {
4040
noCompress("tflite")
@@ -58,4 +58,8 @@ dependencies {
5858

5959
implementation(libs.play.services.tflite.java)
6060
implementation(libs.play.services.tflite.support)
61+
62+
implementation(libs.tensorflow.lite)
63+
implementation(libs.tensorflow.lite.support)
64+
6165
}

subs/ai/src/main/java/com/engineer/ai/GanActivity.kt

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,13 @@ class GanActivity : AppCompatActivity() {
6060
}
6161

6262
private fun genBitmap() {
63-
TensorFlowLiteHelper.init(this) {
64-
val interpreterApi = TensorFlowLiteHelper.createInterpreterApi(this, "dcgan.tflite")
65-
interpreterApi?.let {
63+
TensorFlowLiteHelper.init(this) { playServicesOk ->
64+
val interpreterApi = TensorFlowLiteHelper.createInterpreterApi(
65+
context = this,
66+
modelName = "dcgan.tflite",
67+
preferPlayServices = playServicesOk
68+
)
69+
interpreterApi.let {
6670
Log.d(TAG, interpreterApi.getInputTensor(0).shape().contentToString())
6771
Log.d(TAG, interpreterApi.getOutputTensor(0).shape().contentToString())
6872

subs/ai/src/main/java/com/engineer/ai/util/DigitClassifier.kt

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,35 @@ class DigitClassifier(private val context: Context) {
4040
private var interpreter: InterpreterApi? = null
4141

4242
fun initialize(cb: (Boolean) -> Unit) {
43-
TensorFlowLiteHelper.init(context) {
44-
cb(it)
45-
if (it) {
46-
interpreter = TensorFlowLiteHelper.createInterpreterApi(context, "mnist.tflite")
47-
// Read input shape from model file
48-
interpreter?.let { inter ->
49-
val inputShape = inter.getInputTensor(0).shape()
50-
Log.d(TAG, "input shape = ${inputShape.contentToString()}")
51-
Log.d(TAG, "elem shape = ${inter.getInputTensor(0).numElements()}")
52-
Log.d(TAG, "output shape = ${inter.getOutputTensor(0).shape().contentToString()}")
53-
inputImageWidth = inputShape[1]
54-
inputImageHeight = inputShape[2]
55-
modelInputSize = FLOAT_TYPE_SIZE * inputImageWidth * inputImageHeight * PIXEL_SIZE
56-
isInitialized = true
43+
TensorFlowLiteHelper.init(context) { playServicesOk ->
44+
try {
45+
interpreter = TensorFlowLiteHelper.createInterpreterApi(
46+
context = context,
47+
modelName = "mnist.tflite",
48+
preferPlayServices = playServicesOk
49+
)
50+
51+
val inter = interpreter
52+
if (inter == null) {
53+
isInitialized = false
54+
cb(false)
55+
return@init
5756
}
57+
58+
val inputShape = inter.getInputTensor(0).shape()
59+
Log.d(TAG, "input shape = ${inputShape.contentToString()}")
60+
Log.d(TAG, "elem shape = ${inter.getInputTensor(0).numElements()}")
61+
Log.d(TAG, "output shape = ${inter.getOutputTensor(0).shape().contentToString()}")
62+
63+
inputImageWidth = inputShape[1]
64+
inputImageHeight = inputShape[2]
65+
modelInputSize = FLOAT_TYPE_SIZE * inputImageWidth * inputImageHeight * PIXEL_SIZE
66+
isInitialized = true
67+
cb(true)
68+
} catch (t: Throwable) {
69+
Log.e(TAG, "Failed to initialize DigitClassifier.", t)
70+
isInitialized = false
71+
cb(false)
5872
}
5973
}
6074
}

subs/ai/src/main/java/com/engineer/ai/util/TensorFlowLiteHelper.kt

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,45 @@ import java.nio.channels.FileChannel
1616

1717
object TensorFlowLiteHelper {
1818
private const val TAG = "TensorFlowLiteHelper"
19-
private lateinit var initializeTask: Task<Void>
20-
private var interpreter: InterpreterApi? = null
2119

20+
/**
21+
* Try to init Google Play Services TFLite (dynamite module).
22+
*
23+
* This will fail on devices without Google Play Services (e.g. many CN ROMs).
24+
* We treat failure as non-fatal and fall back to bundled TFLite runtime.
25+
*/
2226
fun init(context: Context, cb: (Boolean) -> Unit) {
23-
initializeTask = TfLite.initialize(context)
24-
initializeTask.addOnSuccessListener {
25-
Log.d(TAG, "Initialized TFLite interpreter.")
26-
Log.d(TAG, "ver ${TensorFlowLite.schemaVersion(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)}")
27-
Log.d(TAG, "ver ${TensorFlowLite.runtimeVersion(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)}")
28-
cb(true)
29-
}.addOnFailureListener {
30-
Log.d(TAG, "Initialized TFLite fail")
31-
cb(false)
32-
Log.e(TAG, "error ", it)
33-
}
27+
TfLite.initialize(context)
28+
.addOnSuccessListener {
29+
Log.d(TAG, "Initialized Play Services TFLite.")
30+
try {
31+
Log.d(
32+
TAG,
33+
"schema=${TensorFlowLite.schemaVersion(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)} runtime=${TensorFlowLite.runtimeVersion(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)}"
34+
)
35+
} catch (t: Throwable) {
36+
Log.w(TAG, "Unable to query system-only TFLite version.", t)
37+
}
38+
cb(true)
39+
}
40+
.addOnFailureListener { e ->
41+
Log.w(TAG, "Play Services TFLite init failed; will fall back to bundled runtime.", e)
42+
cb(false)
43+
}
3444
}
3545

36-
fun createInterpreterApi(context: Context, modelName: String): InterpreterApi? {
46+
fun createInterpreterApi(context: Context, modelName: String, preferPlayServices: Boolean): InterpreterApi {
3747
val model = loadModelFile(context.assets, modelName)
38-
val interpreterOption =
39-
InterpreterApi.Options().setRuntime(InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY)
40-
interpreter = InterpreterApi.create(model, interpreterOption)
41-
return interpreter
48+
49+
val runtime = if (preferPlayServices) {
50+
InterpreterApi.Options.TfLiteRuntime.FROM_SYSTEM_ONLY
51+
} else {
52+
// Bundled runtime provided by org.tensorflow:tensorflow-lite
53+
InterpreterApi.Options.TfLiteRuntime.FROM_APPLICATION_ONLY
54+
}
55+
56+
val options = InterpreterApi.Options().setRuntime(runtime)
57+
return InterpreterApi.create(model, options)
4258
}
4359

4460
@Throws(IOException::class)

0 commit comments

Comments
 (0)