1
1
package com.swmansion.rnexecutorch.models.imagesegmentation
2
2
3
+ import com.facebook.react.bridge.Arguments;
3
4
import com.facebook.react.bridge.ReadableArray
5
+ import com.facebook.react.bridge.WritableMap
4
6
import com.facebook.react.bridge.ReactApplicationContext
5
7
import com.swmansion.rnexecutorch.utils.ImageProcessor
6
8
import com.swmansion.rnexecutorch.utils.softmax
@@ -11,9 +13,10 @@ import org.opencv.imgproc.Imgproc
11
13
import org.pytorch.executorch.Tensor
12
14
import org.pytorch.executorch.EValue
13
15
import com.swmansion.rnexecutorch.models.BaseModel
16
+ import com.swmansion.rnexecutorch.utils.ArrayUtils
14
17
15
18
class ImageSegmentationModel (reactApplicationContext : ReactApplicationContext )
16
- : BaseModel <Pair <Mat , ReadableArray >, Map < String , List < Any >> >(reactApplicationContext) {
19
+ : BaseModel <Pair <Mat , ReadableArray >, WritableMap > (reactApplicationContext) {
17
20
private lateinit var originalSize: Size
18
21
19
22
private fun getModelImageSize (): Size {
@@ -58,30 +61,28 @@ class ImageSegmentationModel(reactApplicationContext: ReactApplicationContext)
58
61
: Mat {
59
62
val argMax = Mat (originalSize, CvType .CV_32S )
60
63
val numOriginalPixels = (originalSize.height * originalSize.width).toInt()
64
+ android.util.Log .d(" ETTT" , " adjustScoresPerPixel: start" )
61
65
for (pixel in 0 .. < numOriginalPixels) {
62
66
val row = pixel / originalSize.width.toInt()
63
- val col = pixel % originalSize.height .toInt()
67
+ val col = pixel % originalSize.width .toInt()
64
68
val scores = mutableListOf<Float >()
65
69
for (mat in labelScores) {
66
- val v = FloatArray (1 )
67
- mat.get(row, col, v)
68
- scores.add(v[0 ])
70
+ scores.add(mat.get(row, col)[0 ].toFloat())
69
71
}
70
-
71
72
val adjustedScores = softmax(scores.toTypedArray())
72
-
73
73
for (label in 0 .. < numLabels) {
74
- labelScores[label].put(row, col, FloatArray ( 1 ){ adjustedScores[label]} )
74
+ labelScores[label].put(row, col, floatArrayOf( adjustedScores[label]) )
75
75
}
76
76
77
77
val maxIndex = scores.withIndex().maxBy{it.value}.index
78
- argMax.put(row, col, IntArray ( 1 ){ maxIndex} )
78
+ argMax.put(row, col, intArrayOf( maxIndex) )
79
79
}
80
80
81
81
return argMax
82
82
}
83
83
84
- override fun postprocess (output : Array <EValue >): Map <String , List <Any >> {
84
+ override fun postprocess (output : Array <EValue >)
85
+ : WritableMap {
85
86
val output = output[0 ].toTensor().dataAsFloatArray.toTypedArray()
86
87
val modelShape = getModelImageSize()
87
88
val numLabels = deeplabv3_resnet50_labels.size;
@@ -93,26 +94,37 @@ class ImageSegmentationModel(reactApplicationContext: ReactApplicationContext)
93
94
val rescaledResults = rescaleResults(output, numLabels)
94
95
95
96
val argMax = adjustScoresPerPixel(rescaledResults, numLabels)
96
-
97
+
97
98
// val labelSet = mutableSetOf<String>()
98
99
// Filter by the label set when base class changed
99
100
100
- val res = mutableMapOf< String , List < Any >> ()
101
-
101
+ val res = Arguments .createMap ()
102
+
102
103
for (label in 0 .. < numLabels) {
103
104
val buffer = FloatArray (numOriginalPixels)
104
- rescaledResults[label].get(0 , 0 , buffer)
105
- res[deeplabv3_resnet50_labels[label]] = buffer.toList()
105
+ for (pixel in 0 .. < numOriginalPixels) {
106
+ val row = pixel / originalSize.width.toInt()
107
+ val col = pixel % originalSize.width.toInt()
108
+ buffer[pixel] = rescaledResults[label].get(row, col)[0 ].toFloat()
109
+ }
110
+ res.putArray(deeplabv3_resnet50_labels[label],
111
+ ArrayUtils .createReadableArrayFromFloatArray(buffer))
106
112
}
107
113
108
114
val argMaxBuffer = IntArray (numOriginalPixels)
109
- argMax.get(0 , 0 , argMaxBuffer)
110
- res[" argmax" ] = argMaxBuffer.toList()
115
+ for (pixel in 0 .. < numOriginalPixels) {
116
+ val row = pixel / originalSize.width.toInt()
117
+ val col = pixel % originalSize.width.toInt()
118
+ argMaxBuffer[pixel] = argMax.get(row, col)[0 ].toInt()
119
+ }
120
+ res.putArray(" argmax" ,
121
+ ArrayUtils .createReadableArrayFromIntArray(argMaxBuffer))
111
122
112
123
return res
113
124
}
114
125
115
- override fun runModel (input : Pair <Mat , ReadableArray >): Map <String , List <Any >> {
126
+ override fun runModel (input : Pair <Mat , ReadableArray >)
127
+ : WritableMap {
116
128
val modelInput = preprocess(input)
117
129
val modelOutput = forward(modelInput)
118
130
return postprocess(modelOutput)
0 commit comments