@@ -19,18 +19,21 @@ import android.graphics.Bitmap
1919import android.util.Log
2020import com.google.android.gms.tasks.Task
2121import com.google.android.gms.tasks.TaskCompletionSource
22- import org.tensorflow.lite.Interpreter
22+ import com.google.android.gms.tflite.java.TfLite
23+ import org.tensorflow.lite.InterpreterApi
24+ import org.tensorflow.lite.TensorFlowLite
2325import java.io.FileInputStream
2426import java.io.IOException
2527import java.nio.ByteBuffer
2628import java.nio.ByteOrder
2729import java.nio.channels.FileChannel
2830import java.util.concurrent.ExecutorService
2931import java.util.concurrent.Executors
32+ import androidx.core.graphics.scale
3033
3134class DigitClassifier (private val context : Context ) {
3235
33- private var interpreter: Interpreter ? = null
36+ // private var interpreter: Interpreter? = null
3437
3538 var isInitialized = false
3639 private set
@@ -42,34 +45,41 @@ class DigitClassifier(private val context: Context) {
4245 private var inputImageHeight: Int = 0 // will be inferred from TF Lite model.
4346 private var modelInputSize: Int = 0 // will be inferred from TF Lite model.
4447
45- fun initialize (): Task <Void ?> {
46- val task = TaskCompletionSource <Void ?>()
47- executorService.execute {
48- try {
49- initializeInterpreter()
50- task.setResult(null )
51- } catch (e: IOException ) {
52- task.setException(e)
53- }
54- }
55- return task.task
56- }
48+ private val initializeTask: Task <Void > by lazy { TfLite .initialize(context) }
49+ private var interpreter: InterpreterApi ? = null
5750
58- @Throws(IOException ::class )
59- private fun initializeInterpreter () {
51+ fun initialize (cb : (Boolean ) -> Unit ) {
6052 val assetManager = context.assets
6153 val model = loadModelFile(assetManager, " mnist.tflite" )
62- val interpreter = Interpreter (model)
63- // Read input shape from model file
64- val inputShape = interpreter.getInputTensor(0 ).shape()
65- inputImageWidth = inputShape[1 ]
66- inputImageHeight = inputShape[2 ]
67- modelInputSize = FLOAT_TYPE_SIZE * inputImageWidth * inputImageHeight * PIXEL_SIZE
68-
69- // Finish interpreter initialization
70- this .interpreter = interpreter
71- isInitialized = true
72- Log .d(TAG , " Initialized TFLite interpreter." )
54+
55+ initializeTask.addOnSuccessListener {
56+ val interpreterOption =
57+ InterpreterApi .Options ().setRuntime(InterpreterApi .Options .TfLiteRuntime .FROM_SYSTEM_ONLY )
58+ interpreter = InterpreterApi .create(model, interpreterOption)
59+
60+ Log .d(TAG , " ver ${TensorFlowLite .schemaVersion(InterpreterApi .Options .TfLiteRuntime .FROM_SYSTEM_ONLY )} " )
61+ Log .d(TAG , " ver ${TensorFlowLite .runtimeVersion(InterpreterApi .Options .TfLiteRuntime .FROM_SYSTEM_ONLY )} " )
62+ // Read input shape from model file
63+ interpreter?.let {
64+ val inputShape = it.getInputTensor(0 ).shape()
65+ inputImageWidth = inputShape[1 ]
66+ inputImageHeight = inputShape[2 ]
67+ modelInputSize = FLOAT_TYPE_SIZE * inputImageWidth * inputImageHeight * PIXEL_SIZE
68+
69+
70+ isInitialized = true
71+ Log .d(TAG , " Initialized TFLite interpreter." )
72+ cb(true )
73+ } ? : run {
74+ Log .d(TAG , " Initialized TFLite fail." )
75+ }
76+
77+ }.addOnFailureListener { e ->
78+ cb(false )
79+ Log .e(TAG , " Cannot initialize interpreter" , e)
80+ }
81+
82+
7383 }
7484
7585 @Throws(IOException ::class )
@@ -87,9 +97,7 @@ class DigitClassifier(private val context: Context) {
8797
8898
8999 // Preprocessing: resize the input image to match the model input shape.
90- val resizedImage = Bitmap .createScaledBitmap(
91- bitmap, inputImageWidth, inputImageHeight, true
92- )
100+ val resizedImage = bitmap.scale(inputImageWidth, inputImageHeight)
93101 val byteBuffer = convertBitmapToByteBuffer(resizedImage)
94102 // Define an array to store the model output.
95103 val output = Array (1 ) { FloatArray (OUTPUT_CLASSES_COUNT ) }
@@ -99,7 +107,7 @@ class DigitClassifier(private val context: Context) {
99107 // Post-processing: find the digit that has the highest probability
100108 // and return it a human-readable string.
101109 val result = output[0 ]
102- val maxIndex = result.indices.maxBy { result[it] } ? : - 1
110+ val maxIndex = result.indices.maxBy { result[it] }
103111 val resultString = " Prediction Result: %d\n Confidence: %2f" .format(maxIndex, result[maxIndex])
104112
105113 return resultString
0 commit comments