1
1
package com.swmansion.rnexecutorch.models.imagesegmentation
2
2
3
- import com.facebook.react.bridge.Arguments;
3
+ import com.facebook.react.bridge.Arguments
4
+ import com.facebook.react.bridge.ReactApplicationContext
4
5
import com.facebook.react.bridge.ReadableArray
5
6
import com.facebook.react.bridge.WritableMap
6
- import com.facebook.react.bridge.ReactApplicationContext
7
+ import com.swmansion.rnexecutorch.models.BaseModel
8
+ import com.swmansion.rnexecutorch.utils.ArrayUtils
7
9
import com.swmansion.rnexecutorch.utils.ImageProcessor
8
10
import com.swmansion.rnexecutorch.utils.softmax
9
- import org.opencv.core.Mat
10
11
import org.opencv.core.CvType
12
+ import org.opencv.core.Mat
11
13
import org.opencv.core.Size
12
14
import org.opencv.imgproc.Imgproc
13
- import org.pytorch.executorch.Tensor
14
15
import org.pytorch.executorch.EValue
15
- import com.swmansion.rnexecutorch.models.BaseModel
16
- import com.swmansion.rnexecutorch.utils.ArrayUtils
17
16
18
- class ImageSegmentationModel (reactApplicationContext : ReactApplicationContext )
19
- : BaseModel <Pair <Mat , ReadableArray >, WritableMap > (reactApplicationContext) {
17
+ class ImageSegmentationModel (
18
+ reactApplicationContext : ReactApplicationContext ,
19
+ ) : BaseModel<Triple<Mat, ReadableArray, Boolean>, WritableMap>(reactApplicationContext) {
20
20
private lateinit var originalSize: Size
21
21
22
22
private fun getModelImageSize (): Size {
@@ -27,44 +27,53 @@ class ImageSegmentationModel(reactApplicationContext: ReactApplicationContext)
27
27
return Size (height.toDouble(), width.toDouble())
28
28
}
29
29
30
/**
 * Prepares an image for inference.
 *
 * Stores the incoming image's size in [originalSize], resizes [input] in place to the size
 * returned by [getModelImageSize], and converts it to an [EValue] using the shape of the
 * module's first input.
 *
 * @param input image to feed to the model; it is overwritten with the resized image.
 * @return the value to pass to the model's forward call.
 */
fun preprocess(input: Mat): EValue {
    // Remember the caller's resolution before the Mat is overwritten by the resize below.
    originalSize = input.size()

    val targetSize = getModelImageSize()
    Imgproc.resize(input, input, targetSize)

    val inputShape = module.getInputShape(0)
    return ImageProcessor.matToEValue(input, inputShape)
}
35
35
36
/**
 * Splits the flat model output into one score matrix per label.
 *
 * [result] is read as [numLabels] consecutive planes, each holding one value per model pixel
 * in row-major order (plane `label` starts at index `label * numModelPixels`).
 *
 * @param result flat model output; must hold at least `numLabels * numModelPixels` values.
 * @param numLabels number of label planes to extract.
 * @param resize when true every plane is resized to [originalSize]; otherwise planes keep the
 *   model's own size.
 * @return one `CV_32F` [Mat] per label, in label order.
 */
private fun extractResults(
    result: Array<Float>,
    numLabels: Int,
    resize: Boolean,
): List<Mat> {
    val modelSize = getModelImageSize()
    val numModelPixels = (modelSize.height * modelSize.width).toInt()

    return List(numLabels) { label ->
        val offset = label * numModelPixels
        val plane = FloatArray(numModelPixels) { pixel -> result[offset + pixel] }

        val mat = Mat(modelSize, CvType.CV_32F)
        if (plane.isNotEmpty()) {
            // A single bulk copy fills the Mat row by row, replacing one JNI call per pixel.
            mat.put(0, 0, plane)
        }

        if (resize) {
            val resizedMat = Mat()
            Imgproc.resize(mat, resizedMat, originalSize)
            // The model-sized plane is no longer needed; free its native buffer right away.
            mat.release()
            resizedMat
        } else {
            mat
        }
    }
}
59
66
60
- private fun adjustScoresPerPixel (labelScores : List <Mat >, numLabels : Int )
61
- : Mat {
62
- val argMax = Mat (originalSize, CvType .CV_32S )
63
- val numOriginalPixels = (originalSize.height * originalSize.width).toInt()
64
- android.util.Log .d(" ETTT" , " adjustScoresPerPixel: start" )
65
- for (pixel in 0 .. < numOriginalPixels) {
66
- val row = pixel / originalSize.width.toInt()
67
- val col = pixel % originalSize.width.toInt()
67
+ private fun adjustScoresPerPixel (
68
+ labelScores : List <Mat >,
69
+ numLabels : Int ,
70
+ outputSize : Size ,
71
+ ): Mat {
72
+ val argMax = Mat (outputSize, CvType .CV_32S )
73
+ val numPixels = (outputSize.height * outputSize.width).toInt()
74
+ for (pixel in 0 .. < numPixels) {
75
+ val row = pixel / outputSize.width.toInt()
76
+ val col = pixel % outputSize.width.toInt()
68
77
val scores = mutableListOf<Float >()
69
78
for (mat in labelScores) {
70
79
scores.add(mat.get(row, col)[0 ].toFloat())
@@ -74,59 +83,71 @@ class ImageSegmentationModel(reactApplicationContext: ReactApplicationContext)
74
83
labelScores[label].put(row, col, floatArrayOf(adjustedScores[label]))
75
84
}
76
85
77
- val maxIndex = scores.withIndex().maxBy{ it.value}.index
86
+ val maxIndex = scores.withIndex().maxBy { it.value }.index
78
87
argMax.put(row, col, intArrayOf(maxIndex))
79
88
}
80
89
81
90
return argMax
82
91
}
83
92
84
/**
 * Turns the raw model output into the map returned to JavaScript.
 *
 * The first output tensor is split into one score plane per label, the planes are passed
 * through [adjustScoresPerPixel], and the result map receives:
 *  - one float array per label whose name appears in [classesOfInterest], keyed by label name;
 *  - an `"argmax"` int array with the value [adjustScoresPerPixel] stored for every pixel.
 *
 * All arrays are row-major and cover [originalSize] when [resize] is true, otherwise the
 * model's own image size.
 *
 * @param output values returned by the model; only the first element is read.
 * @param classesOfInterest names of the labels whose score arrays should be returned.
 * @param resize whether results are scaled back to the size of the original image.
 * @return map with the requested per-label arrays and the `"argmax"` array.
 * @throws IllegalArgumentException if [output] is empty or its first tensor does not hold
 *   exactly one value per label per model pixel.
 */
fun postprocess(
    output: Array<EValue>,
    classesOfInterest: ReadableArray,
    resize: Boolean,
): WritableMap {
    require(output.isNotEmpty()) { "Model produced no output." }

    val scores = output[0].toTensor().dataAsFloatArray.toTypedArray()
    val modelSize = getModelImageSize()
    val numLabels = deeplabv3_resnet50_labels.size

    require(scores.size == (numLabels * modelSize.height * modelSize.width).toInt()) {
        "Model generated unexpected output size."
    }

    val outputSize = if (resize) originalSize else modelSize
    val numOutputPixels = (outputSize.height * outputSize.width).toInt()

    val extractedResults = extractResults(scores, numLabels, resize)

    // adjustScoresPerPixel writes the adjusted scores back into extractedResults, so the
    // per-label planes must only be read after this call.
    val argMax = adjustScoresPerPixel(extractedResults, numLabels, outputSize)

    // Names of the labels the caller asked for.
    val labelSet = mutableSetOf<String>()
    for (i in 0..<classesOfInterest.size()) {
        labelSet.add(classesOfInterest.getString(i))
    }

    val res = Arguments.createMap()

    for (label in 0..<numLabels) {
        val labelName = deeplabv3_resnet50_labels[label]
        if (!labelSet.contains(labelName)) continue

        val buffer = FloatArray(numOutputPixels)
        if (numOutputPixels > 0) {
            // Bulk read of the whole plane (row-major) instead of one JNI call per pixel.
            extractedResults[label].get(0, 0, buffer)
        }
        res.putArray(
            labelName,
            ArrayUtils.createReadableArrayFromFloatArray(buffer),
        )
    }

    val argMaxBuffer = IntArray(numOutputPixels)
    if (numOutputPixels > 0) {
        argMax.get(0, 0, argMaxBuffer)
    }
    res.putArray(
        "argmax",
        ArrayUtils.createReadableArrayFromIntArray(argMaxBuffer),
    )

    // Everything has been copied into the result map; free the native buffers now.
    extractedResults.forEach { it.release() }
    argMax.release()

    return res
}
125
147
126
/**
 * Runs the full segmentation pipeline on one request.
 *
 * @param input triple of (image, names of the classes to return, whether to resize the
 *   results), passed on to [preprocess] and [postprocess] respectively.
 * @return the map produced by [postprocess].
 */
override fun runModel(input: Triple<Mat, ReadableArray, Boolean>): WritableMap {
    val (image, classesOfInterest, resize) = input

    val modelInput = preprocess(image)
    val modelOutput = forward(modelInput)

    return postprocess(modelOutput, classesOfInterest, resize)
}
132
153
}
0 commit comments