Skip to content

Commit e99dc84

Browse files
committed
Remove unnecessary calls to OpenCV in Android
1 parent f23cfe7 commit e99dc84

File tree

1 file changed

+24
-38
lines changed

1 file changed

+24
-38
lines changed

android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/ImageSegmentationModel.kt

+24-38
Original file line numberDiff line numberDiff line change
@@ -34,57 +34,56 @@ class ImageSegmentationModel(
3434
}
3535

3636
private fun extractResults(
37-
result: Array<Float>,
37+
result: FloatArray,
3838
numLabels: Int,
3939
resize: Boolean,
40-
): List<Mat> {
40+
): List<FloatArray> {
4141
val modelSize = getModelImageSize()
4242
val numModelPixels = (modelSize.height * modelSize.width).toInt()
4343

44-
val extractedLabelScores = mutableListOf<Mat>()
44+
val extractedLabelScores = mutableListOf<FloatArray>()
4545

4646
for (label in 0..<numLabels) {
47-
val mat = Mat(modelSize, CvType.CV_32F)
48-
49-
for (pixel in 0..<numModelPixels) {
50-
val row = pixel / modelSize.width.toInt()
51-
val col = pixel % modelSize.width.toInt()
52-
val v = floatArrayOf(result[label * numModelPixels + pixel])
53-
mat.put(row, col, v)
54-
}
47+
// Calls to OpenCV via JNI are very slow so we do as much as we can
48+
// with pure Kotlin
49+
val range = IntRange(label * numModelPixels, (label + 1) * numModelPixels - 1)
50+
val pixelBuffer = result.slice(range).toFloatArray()
5551

5652
if (resize) {
53+
// Rescale the image with OpenCV
54+
val mat = Mat(modelSize, CvType.CV_32F)
55+
mat.put(0, 0, pixelBuffer)
5756
val resizedMat = Mat()
5857
Imgproc.resize(mat, resizedMat, originalSize)
59-
extractedLabelScores.add(resizedMat)
58+
val resizedBuffer = FloatArray((originalSize.height * originalSize.width).toInt())
59+
resizedMat.get(0, 0, resizedBuffer)
60+
extractedLabelScores.add(resizedBuffer)
6061
} else {
61-
extractedLabelScores.add(mat)
62+
extractedLabelScores.add(pixelBuffer)
6263
}
6364
}
6465
return extractedLabelScores
6566
}
6667

6768
private fun adjustScoresPerPixel(
68-
labelScores: List<Mat>,
69+
labelScores: List<FloatArray>,
6970
numLabels: Int,
7071
outputSize: Size,
71-
): Mat {
72-
val argMax = Mat(outputSize, CvType.CV_32S)
72+
): IntArray {
7373
val numPixels = (outputSize.height * outputSize.width).toInt()
74+
val argMax = IntArray(numPixels)
7475
for (pixel in 0..<numPixels) {
75-
val row = pixel / outputSize.width.toInt()
76-
val col = pixel % outputSize.width.toInt()
7776
val scores = mutableListOf<Float>()
78-
for (mat in labelScores) {
79-
scores.add(mat.get(row, col)[0].toFloat())
77+
for (buffer in labelScores) {
78+
scores.add(buffer[pixel])
8079
}
8180
val adjustedScores = softmax(scores.toTypedArray())
8281
for (label in 0..<numLabels) {
83-
labelScores[label].put(row, col, floatArrayOf(adjustedScores[label]))
82+
labelScores[label][pixel] = adjustedScores[label]
8483
}
8584

8685
val maxIndex = scores.withIndex().maxBy { it.value }.index
87-
argMax.put(row, col, intArrayOf(maxIndex))
86+
argMax[pixel] = maxIndex
8887
}
8988

9089
return argMax
@@ -95,14 +94,13 @@ class ImageSegmentationModel(
9594
classesOfInterest: ReadableArray,
9695
resize: Boolean,
9796
): WritableMap {
98-
val output = output[0].toTensor().dataAsFloatArray.toTypedArray()
97+
val output = output[0].toTensor().dataAsFloatArray
9998
val modelSize = getModelImageSize()
10099
val numLabels = deeplabv3_resnet50_labels.size
101100

102101
require(output.count() == (numLabels * modelSize.height * modelSize.width).toInt()) { "Model generated unexpected output size." }
103102

104103
val outputSize = if (resize) originalSize else modelSize
105-
val numOutputPixels = (outputSize.height * outputSize.width).toInt()
106104

107105
val extractedResults = extractResults(output, numLabels, resize)
108106

@@ -118,28 +116,16 @@ class ImageSegmentationModel(
118116

119117
for (label in 0..<numLabels) {
120118
if (labelSet.contains(deeplabv3_resnet50_labels[label])) {
121-
val buffer = FloatArray(numOutputPixels)
122-
for (pixel in 0..<numOutputPixels) {
123-
val row = pixel / outputSize.width.toInt()
124-
val col = pixel % outputSize.width.toInt()
125-
buffer[pixel] = extractedResults[label].get(row, col)[0].toFloat()
126-
}
127119
res.putArray(
128120
deeplabv3_resnet50_labels[label],
129-
ArrayUtils.createReadableArrayFromFloatArray(buffer),
121+
ArrayUtils.createReadableArrayFromFloatArray(extractedResults[label]),
130122
)
131123
}
132124
}
133125

134-
val argMaxBuffer = IntArray(numOutputPixels)
135-
for (pixel in 0..<numOutputPixels) {
136-
val row = pixel / outputSize.width.toInt()
137-
val col = pixel % outputSize.width.toInt()
138-
argMaxBuffer[pixel] = argMax.get(row, col)[0].toInt()
139-
}
140126
res.putArray(
141127
"argmax",
142-
ArrayUtils.createReadableArrayFromIntArray(argMaxBuffer),
128+
ArrayUtils.createReadableArrayFromIntArray(argMax),
143129
)
144130

145131
return res

0 commit comments

Comments
 (0)