Skip to content

Commit 240af5b

Browse files
authored
@jakmro/classification android (#55)
## Description Image classification for android ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [ ] iOS - [x] Android ### Checklist - [x] I have performed a self-review of my code - [x] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [x] My changes generate no new warnings
1 parent 8d0089a commit 240af5b

File tree

7 files changed

+1140
-10
lines changed

7 files changed

+1140
-10
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package com.swmansion.rnexecutorch
2+
3+
import android.util.Log
4+
import com.facebook.react.bridge.Promise
5+
import com.facebook.react.bridge.ReactApplicationContext
6+
import com.swmansion.rnexecutorch.models.classification.ClassificationModel
7+
import com.swmansion.rnexecutorch.utils.ETError
8+
import com.swmansion.rnexecutorch.utils.ImageProcessor
9+
import org.opencv.android.OpenCVLoader
10+
import com.facebook.react.bridge.Arguments
11+
import com.facebook.react.bridge.WritableMap
12+
13+
class Classification(reactContext: ReactApplicationContext) :
14+
NativeClassificationSpec(reactContext) {
15+
16+
private lateinit var classificationModel: ClassificationModel
17+
18+
companion object {
19+
const val NAME = "Classification"
20+
init {
21+
if(!OpenCVLoader.initLocal()){
22+
Log.d("rn_executorch", "OpenCV not loaded")
23+
} else {
24+
Log.d("rn_executorch", "OpenCV loaded")
25+
}
26+
}
27+
}
28+
29+
override fun loadModule(modelSource: String, promise: Promise) {
30+
try {
31+
classificationModel = ClassificationModel(reactApplicationContext)
32+
classificationModel.loadModel(modelSource)
33+
promise.resolve(0)
34+
} catch (e: Exception) {
35+
promise.reject(e.message!!, ETError.InvalidModelPath.toString())
36+
}
37+
}
38+
39+
override fun forward(input: String, promise: Promise) {
40+
try {
41+
val image = ImageProcessor.readImage(input)
42+
val output = classificationModel.runModel(image)
43+
44+
val writableMap: WritableMap = Arguments.createMap()
45+
46+
for ((key, value) in output) {
47+
writableMap.putDouble(key, value.toDouble())
48+
}
49+
50+
promise.resolve(writableMap)
51+
}catch(e: Exception){
52+
promise.reject(e.message!!, e.message)
53+
}
54+
}
55+
56+
override fun getName(): String {
57+
return NAME
58+
}
59+
}

android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt

+13-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ class RnExecutorchPackage : TurboReactPackage() {
1919
ETModule(reactContext)
2020
} else if (name == StyleTransfer.NAME) {
2121
StyleTransfer(reactContext)
22-
} else {
22+
} else if (name == Classification.NAME) {
23+
Classification(reactContext)
24+
}
25+
else {
2326
null
2427
}
2528

@@ -51,6 +54,15 @@ class RnExecutorchPackage : TurboReactPackage() {
5154
false, // isCxxModule
5255
true
5356
)
57+
58+
moduleInfos[Classification.NAME] = ReactModuleInfo(
59+
Classification.NAME,
60+
Classification.NAME,
61+
false, // canOverrideExistingModule
62+
false, // needsEagerInit
63+
false, // isCxxModule
64+
true
65+
)
5466
moduleInfos
5567
}
5668
}

android/src/main/java/com/swmansion/rnexecutorch/models/BaseModel.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ abstract class BaseModel<Input, Output>(val context: Context) {
3838

3939
abstract fun runModel(input: Input): Output
4040

41-
protected abstract fun preprocess(input: Input): Input
41+
protected abstract fun preprocess(input: Input): EValue
4242

43-
protected abstract fun postprocess(input: Tensor): Output
43+
protected abstract fun postprocess(output: Array<EValue>): Output
4444
}

android/src/main/java/com/swmansion/rnexecutorch/models/StyleTransferModel.kt

+9-7
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import org.opencv.core.Mat
66
import org.opencv.core.Size
77
import org.opencv.imgproc.Imgproc
88
import org.pytorch.executorch.Tensor
9+
import org.pytorch.executorch.EValue
910

1011

1112
class StyleTransferModel(reactApplicationContext: ReactApplicationContext) : BaseModel<Mat, Mat>(reactApplicationContext) {
@@ -19,22 +20,23 @@ class StyleTransferModel(reactApplicationContext: ReactApplicationContext) : Bas
1920
return Size(height.toDouble(), width.toDouble())
2021
}
2122

22-
override fun preprocess(input: Mat): Mat {
23+
override fun preprocess(input: Mat): EValue {
2324
originalSize = input.size()
2425
Imgproc.resize(input, input, getModelImageSize())
25-
return input
26+
return ImageProcessor.matToEValue(input, module.getInputShape(0))
2627
}
2728

28-
override fun postprocess(input: Tensor): Mat {
29+
override fun postprocess(output: Array<EValue>): Mat {
30+
val tensor = output[0].toTensor()
2931
val modelShape = getModelImageSize()
30-
val result = ImageProcessor.EValueToMat(input.dataAsFloatArray, modelShape.width.toInt(), modelShape.height.toInt())
32+
val result = ImageProcessor.EValueToMat(tensor.dataAsFloatArray, modelShape.width.toInt(), modelShape.height.toInt())
3133
Imgproc.resize(result, result, originalSize)
3234
return result
3335
}
3436

3537
override fun runModel(input: Mat): Mat {
36-
val inputTensor = ImageProcessor.matToEValue(preprocess(input), module.getInputShape(0))
37-
val outputTensor = forward(inputTensor)
38-
return postprocess(outputTensor[0].toTensor())
38+
val modelInput = preprocess(input)
39+
val modelOutput = forward(modelInput)
40+
return postprocess(modelOutput)
3941
}
4042
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package com.swmansion.rnexecutorch.models.classification
2+
3+
import com.facebook.react.bridge.ReactApplicationContext
4+
import com.swmansion.rnexecutorch.utils.ImageProcessor
5+
import org.opencv.core.Mat
6+
import org.opencv.core.Size
7+
import org.opencv.imgproc.Imgproc
8+
import org.pytorch.executorch.Tensor
9+
import org.pytorch.executorch.EValue
10+
import com.swmansion.rnexecutorch.models.BaseModel
11+
12+
13+
class ClassificationModel(reactApplicationContext: ReactApplicationContext) : BaseModel<Mat, Map<String, Float>>(reactApplicationContext) {
14+
private fun getModelImageSize(): Size {
15+
val inputShape = module.getInputShape(0)
16+
val width = inputShape[inputShape.lastIndex]
17+
val height = inputShape[inputShape.lastIndex - 1]
18+
19+
return Size(height.toDouble(), width.toDouble())
20+
}
21+
22+
override fun preprocess(input: Mat): EValue {
23+
Imgproc.resize(input, input, getModelImageSize())
24+
return ImageProcessor.matToEValue(input, module.getInputShape(0))
25+
}
26+
27+
override fun postprocess(output: Array<EValue>): Map<String, Float> {
28+
val tensor = output[0].toTensor()
29+
val probabilities = softmax(tensor.dataAsFloatArray.toTypedArray())
30+
31+
val result = mutableMapOf<String, Float>()
32+
33+
for (i in probabilities.indices) {
34+
result[imagenet1k_v1_labels[i]] = probabilities[i]
35+
}
36+
37+
return result
38+
}
39+
40+
override fun runModel(input: Mat): Map<String, Float> {
41+
val modelInput = preprocess(input)
42+
val modelOutput = forward(modelInput)
43+
return postprocess(modelOutput)
44+
}
45+
}

0 commit comments

Comments
 (0)