Skip to content

Commit 209a8ec

Browse files
committed
Add optional resize for Android segmentation
1 parent d780a87 commit 209a8ec

File tree

2 files changed

+100
-77
lines changed

2 files changed

+100
-77
lines changed

android/src/main/java/com/swmansion/rnexecutorch/ImageSegmentation.kt

+19-17
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,34 @@ package com.swmansion.rnexecutorch
22

33
import android.util.Log
44
import com.facebook.react.bridge.Promise
5-
import com.facebook.react.bridge.WritableMap
6-
import com.facebook.react.bridge.ReadableArray
75
import com.facebook.react.bridge.ReactApplicationContext
8-
import com.swmansion.rnexecutorch.utils.ETError
6+
import com.facebook.react.bridge.ReadableArray
97
import com.swmansion.rnexecutorch.models.imagesegmentation.ImageSegmentationModel
8+
import com.swmansion.rnexecutorch.utils.ETError
109
import com.swmansion.rnexecutorch.utils.ImageProcessor
1110
import org.opencv.android.OpenCVLoader
1211

13-
class ImageSegmentation(reactContext: ReactApplicationContext) :
14-
NativeImageSegmentationSpec(reactContext) {
15-
12+
class ImageSegmentation(
13+
reactContext: ReactApplicationContext,
14+
) : NativeImageSegmentationSpec(reactContext) {
1615
private lateinit var model: ImageSegmentationModel
1716

1817
companion object {
1918
const val NAME = "ImageSegmentation"
2019

2120
init {
22-
if(!OpenCVLoader.initLocal()){
21+
if (!OpenCVLoader.initLocal()) {
2322
Log.d("rn_executorch", "OpenCV not loaded")
2423
} else {
2524
Log.d("rn_executorch", "OpenCV loaded")
2625
}
2726
}
2827
}
2928

30-
override fun loadModule(modelSource: String, promise: Promise) {
29+
override fun loadModule(
30+
modelSource: String,
31+
promise: Promise,
32+
) {
3133
try {
3234
model = ImageSegmentationModel(reactApplicationContext)
3335
model.loadModel(modelSource)
@@ -37,20 +39,20 @@ class ImageSegmentation(reactContext: ReactApplicationContext) :
3739
}
3840
}
3941

40-
override fun forward(input: String,
41-
classesOfInterest: ReadableArray,
42-
resize:Boolean,
43-
promise: Promise) {
42+
override fun forward(
43+
input: String,
44+
classesOfInterest: ReadableArray,
45+
resize: Boolean,
46+
promise: Promise,
47+
) {
4448
try {
4549
val output =
46-
model.runModel(Pair(ImageProcessor.readImage(input), classesOfInterest))
50+
model.runModel(Triple(ImageProcessor.readImage(input), classesOfInterest, resize))
4751
promise.resolve(output)
48-
}catch(e: Exception){
52+
} catch (e: Exception) {
4953
promise.reject(e.message!!, e.message)
5054
}
5155
}
5256

53-
override fun getName(): String {
54-
return NAME
55-
}
57+
override fun getName(): String = NAME
5658
}
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
package com.swmansion.rnexecutorch.models.imagesegmentation
22

3-
import com.facebook.react.bridge.Arguments;
3+
import com.facebook.react.bridge.Arguments
4+
import com.facebook.react.bridge.ReactApplicationContext
45
import com.facebook.react.bridge.ReadableArray
56
import com.facebook.react.bridge.WritableMap
6-
import com.facebook.react.bridge.ReactApplicationContext
7+
import com.swmansion.rnexecutorch.models.BaseModel
8+
import com.swmansion.rnexecutorch.utils.ArrayUtils
79
import com.swmansion.rnexecutorch.utils.ImageProcessor
810
import com.swmansion.rnexecutorch.utils.softmax
9-
import org.opencv.core.Mat
1011
import org.opencv.core.CvType
12+
import org.opencv.core.Mat
1113
import org.opencv.core.Size
1214
import org.opencv.imgproc.Imgproc
13-
import org.pytorch.executorch.Tensor
1415
import org.pytorch.executorch.EValue
15-
import com.swmansion.rnexecutorch.models.BaseModel
16-
import com.swmansion.rnexecutorch.utils.ArrayUtils
1716

18-
class ImageSegmentationModel(reactApplicationContext: ReactApplicationContext)
19-
: BaseModel <Pair<Mat, ReadableArray>, WritableMap>(reactApplicationContext) {
17+
class ImageSegmentationModel(
18+
reactApplicationContext: ReactApplicationContext,
19+
) : BaseModel<Triple<Mat, ReadableArray, Boolean>, WritableMap>(reactApplicationContext) {
2020
private lateinit var originalSize: Size
2121

2222
private fun getModelImageSize(): Size {
@@ -27,44 +27,53 @@ class ImageSegmentationModel(reactApplicationContext: ReactApplicationContext)
2727
return Size(height.toDouble(), width.toDouble())
2828
}
2929

30-
override fun preprocess(input: Pair<Mat, ReadableArray>): EValue {
31-
originalSize = input.first.size()
32-
Imgproc.resize(input.first, input.first, getModelImageSize())
33-
return ImageProcessor.matToEValue(input.first, module.getInputShape(0))
30+
fun preprocess(input: Mat): EValue {
31+
originalSize = input.size()
32+
Imgproc.resize(input, input, getModelImageSize())
33+
return ImageProcessor.matToEValue(input, module.getInputShape(0))
3434
}
3535

36-
private fun rescaleResults(result: Array<Float>, numLabels: Int)
37-
: List<Mat> {
38-
val modelShape = getModelImageSize()
39-
val numModelPixels = (modelShape.height * modelShape.width).toInt()
36+
private fun extractResults(
37+
result: Array<Float>,
38+
numLabels: Int,
39+
resize: Boolean,
40+
): List<Mat> {
41+
val modelSize = getModelImageSize()
42+
val numModelPixels = (modelSize.height * modelSize.width).toInt()
4043

41-
val resizedLabelScores = mutableListOf<Mat>()
44+
val extractedLabelScores = mutableListOf<Mat>()
4245

4346
for (label in 0..<numLabels) {
44-
val mat = Mat(modelShape, CvType.CV_32F)
47+
val mat = Mat(modelSize, CvType.CV_32F)
4548

4649
for (pixel in 0..<numModelPixels) {
47-
val row = pixel / modelShape.width.toInt()
48-
val col = pixel % modelShape.width.toInt()
50+
val row = pixel / modelSize.width.toInt()
51+
val col = pixel % modelSize.width.toInt()
4952
val v = floatArrayOf(result[label * numModelPixels + pixel])
5053
mat.put(row, col, v)
5154
}
5255

53-
val resizedMat = Mat()
54-
Imgproc.resize(mat, resizedMat, originalSize)
55-
resizedLabelScores.add(resizedMat)
56+
if (resize) {
57+
val resizedMat = Mat()
58+
Imgproc.resize(mat, resizedMat, originalSize)
59+
extractedLabelScores.add(resizedMat)
60+
} else {
61+
extractedLabelScores.add(mat)
62+
}
5663
}
57-
return resizedLabelScores;
64+
return extractedLabelScores
5865
}
5966

60-
private fun adjustScoresPerPixel(labelScores: List<Mat>, numLabels: Int)
61-
: Mat {
62-
val argMax = Mat(originalSize, CvType.CV_32S)
63-
val numOriginalPixels = (originalSize.height * originalSize.width).toInt()
64-
android.util.Log.d("ETTT", "adjustScoresPerPixel: start")
65-
for (pixel in 0..<numOriginalPixels) {
66-
val row = pixel / originalSize.width.toInt()
67-
val col = pixel % originalSize.width.toInt()
67+
private fun adjustScoresPerPixel(
68+
labelScores: List<Mat>,
69+
numLabels: Int,
70+
outputSize: Size,
71+
): Mat {
72+
val argMax = Mat(outputSize, CvType.CV_32S)
73+
val numPixels = (outputSize.height * outputSize.width).toInt()
74+
for (pixel in 0..<numPixels) {
75+
val row = pixel / outputSize.width.toInt()
76+
val col = pixel % outputSize.width.toInt()
6877
val scores = mutableListOf<Float>()
6978
for (mat in labelScores) {
7079
scores.add(mat.get(row, col)[0].toFloat())
@@ -74,59 +83,71 @@ class ImageSegmentationModel(reactApplicationContext: ReactApplicationContext)
7483
labelScores[label].put(row, col, floatArrayOf(adjustedScores[label]))
7584
}
7685

77-
val maxIndex = scores.withIndex().maxBy{it.value}.index
86+
val maxIndex = scores.withIndex().maxBy { it.value }.index
7887
argMax.put(row, col, intArrayOf(maxIndex))
7988
}
8089

8190
return argMax
8291
}
8392

84-
override fun postprocess(output: Array<EValue>)
85-
: WritableMap {
93+
fun postprocess(
94+
output: Array<EValue>,
95+
classesOfInterest: ReadableArray,
96+
resize: Boolean,
97+
): WritableMap {
8698
val output = output[0].toTensor().dataAsFloatArray.toTypedArray()
87-
val modelShape = getModelImageSize()
88-
val numLabels = deeplabv3_resnet50_labels.size;
89-
val numOriginalPixels = (originalSize.height * originalSize.width).toInt()
99+
val modelSize = getModelImageSize()
100+
val numLabels = deeplabv3_resnet50_labels.size
90101

91-
require(output.count() == (numLabels * modelShape.height * modelShape.width).toInt())
92-
{"Model generated unexpected output size."}
102+
require(output.count() == (numLabels * modelSize.height * modelSize.width).toInt()) { "Model generated unexpected output size." }
93103

94-
val rescaledResults = rescaleResults(output, numLabels)
104+
val outputSize = if (resize) originalSize else modelSize
105+
val numOutputPixels = (outputSize.height * outputSize.width).toInt()
95106

96-
val argMax = adjustScoresPerPixel(rescaledResults, numLabels)
107+
val extractedResults = extractResults(output, numLabels, resize)
97108

98-
// val labelSet = mutableSetOf<String>()
109+
val argMax = adjustScoresPerPixel(extractedResults, numLabels, outputSize)
110+
111+
val labelSet = mutableSetOf<String>()
99112
// Filter by the label set when base class changed
113+
for (i in 0..<classesOfInterest.size()) {
114+
labelSet.add(classesOfInterest.getString(i))
115+
}
100116

101117
val res = Arguments.createMap()
102118

103119
for (label in 0..<numLabels) {
104-
val buffer = FloatArray(numOriginalPixels)
105-
for (pixel in 0..<numOriginalPixels) {
106-
val row = pixel / originalSize.width.toInt()
107-
val col = pixel % originalSize.width.toInt()
108-
buffer[pixel] = rescaledResults[label].get(row, col)[0].toFloat()
120+
if (labelSet.contains(deeplabv3_resnet50_labels[label])) {
121+
val buffer = FloatArray(numOutputPixels)
122+
for (pixel in 0..<numOutputPixels) {
123+
val row = pixel / outputSize.width.toInt()
124+
val col = pixel % outputSize.width.toInt()
125+
buffer[pixel] = extractedResults[label].get(row, col)[0].toFloat()
126+
}
127+
res.putArray(
128+
deeplabv3_resnet50_labels[label],
129+
ArrayUtils.createReadableArrayFromFloatArray(buffer),
130+
)
109131
}
110-
res.putArray(deeplabv3_resnet50_labels[label],
111-
ArrayUtils.createReadableArrayFromFloatArray(buffer))
112132
}
113133

114-
val argMaxBuffer = IntArray(numOriginalPixels)
115-
for (pixel in 0..<numOriginalPixels) {
116-
val row = pixel / originalSize.width.toInt()
117-
val col = pixel % originalSize.width.toInt()
134+
val argMaxBuffer = IntArray(numOutputPixels)
135+
for (pixel in 0..<numOutputPixels) {
136+
val row = pixel / outputSize.width.toInt()
137+
val col = pixel % outputSize.width.toInt()
118138
argMaxBuffer[pixel] = argMax.get(row, col)[0].toInt()
119139
}
120-
res.putArray("argmax",
121-
ArrayUtils.createReadableArrayFromIntArray(argMaxBuffer))
140+
res.putArray(
141+
"argmax",
142+
ArrayUtils.createReadableArrayFromIntArray(argMaxBuffer),
143+
)
122144

123145
return res
124146
}
125147

126-
override fun runModel(input: Pair<Mat, ReadableArray>)
127-
: WritableMap {
128-
val modelInput = preprocess(input)
148+
override fun runModel(input: Triple<Mat, ReadableArray, Boolean>): WritableMap {
149+
val modelInput = preprocess(input.first)
129150
val modelOutput = forward(modelInput)
130-
return postprocess(modelOutput)
151+
return postprocess(modelOutput, input.second, input.third)
131152
}
132153
}

0 commit comments

Comments
 (0)