Skip to content

Commit d780a87

Browse files
committed
Adapt segmentation bridge types
1 parent 899e440 commit d780a87

File tree

2 files changed

+35
-19
lines changed

2 files changed

+35
-19
lines changed

android/src/main/java/com/swmansion/rnexecutorch/ImageSegmentation.kt

+5-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package com.swmansion.rnexecutorch
22

33
import android.util.Log
44
import com.facebook.react.bridge.Promise
5+
import com.facebook.react.bridge.WritableMap
56
import com.facebook.react.bridge.ReadableArray
67
import com.facebook.react.bridge.ReactApplicationContext
78
import com.swmansion.rnexecutorch.utils.ETError
@@ -36,7 +37,10 @@ class ImageSegmentation(reactContext: ReactApplicationContext) :
3637
}
3738
}
3839

39-
override fun forward(input: String, classesOfInterest: ReadableArray, promise: Promise) {
40+
override fun forward(input: String,
41+
classesOfInterest: ReadableArray,
42+
resize:Boolean,
43+
promise: Promise) {
4044
try {
4145
val output =
4246
model.runModel(Pair(ImageProcessor.readImage(input), classesOfInterest))

android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/ImageSegmentationModel.kt

+30-18
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package com.swmansion.rnexecutorch.models.imagesegmentation
22

3+
import com.facebook.react.bridge.Arguments;
34
import com.facebook.react.bridge.ReadableArray
5+
import com.facebook.react.bridge.WritableMap
46
import com.facebook.react.bridge.ReactApplicationContext
57
import com.swmansion.rnexecutorch.utils.ImageProcessor
68
import com.swmansion.rnexecutorch.utils.softmax
@@ -11,9 +13,10 @@ import org.opencv.imgproc.Imgproc
1113
import org.pytorch.executorch.Tensor
1214
import org.pytorch.executorch.EValue
1315
import com.swmansion.rnexecutorch.models.BaseModel
16+
import com.swmansion.rnexecutorch.utils.ArrayUtils
1417

1518
class ImageSegmentationModel(reactApplicationContext: ReactApplicationContext)
16-
: BaseModel <Pair<Mat, ReadableArray>, Map<String, List<Any>>>(reactApplicationContext) {
19+
: BaseModel <Pair<Mat, ReadableArray>, WritableMap>(reactApplicationContext) {
1720
private lateinit var originalSize: Size
1821

1922
private fun getModelImageSize(): Size {
@@ -58,30 +61,28 @@ class ImageSegmentationModel(reactApplicationContext: ReactApplicationContext)
5861
: Mat {
5962
val argMax = Mat(originalSize, CvType.CV_32S)
6063
val numOriginalPixels = (originalSize.height * originalSize.width).toInt()
64+
android.util.Log.d("ETTT", "adjustScoresPerPixel: start")
6165
for (pixel in 0..<numOriginalPixels) {
6266
val row = pixel / originalSize.width.toInt()
63-
val col = pixel % originalSize.height.toInt()
67+
val col = pixel % originalSize.width.toInt()
6468
val scores = mutableListOf<Float>()
6569
for (mat in labelScores) {
66-
val v = FloatArray(1)
67-
mat.get(row, col, v)
68-
scores.add(v[0])
70+
scores.add(mat.get(row, col)[0].toFloat())
6971
}
70-
7172
val adjustedScores = softmax(scores.toTypedArray())
72-
7373
for (label in 0..<numLabels) {
74-
labelScores[label].put(row, col, FloatArray(1){adjustedScores[label]})
74+
labelScores[label].put(row, col, floatArrayOf(adjustedScores[label]))
7575
}
7676

7777
val maxIndex = scores.withIndex().maxBy{it.value}.index
78-
argMax.put(row, col, IntArray(1){maxIndex})
78+
argMax.put(row, col, intArrayOf(maxIndex))
7979
}
8080

8181
return argMax
8282
}
8383

84-
override fun postprocess(output: Array<EValue>): Map<String, List<Any>> {
84+
override fun postprocess(output: Array<EValue>)
85+
: WritableMap {
8586
val output = output[0].toTensor().dataAsFloatArray.toTypedArray()
8687
val modelShape = getModelImageSize()
8788
val numLabels = deeplabv3_resnet50_labels.size;
@@ -93,26 +94,37 @@ class ImageSegmentationModel(reactApplicationContext: ReactApplicationContext)
9394
val rescaledResults = rescaleResults(output, numLabels)
9495

9596
val argMax = adjustScoresPerPixel(rescaledResults, numLabels)
96-
97+
9798
// val labelSet = mutableSetOf<String>()
9899
// Filter by the label set when base class changed
99100

100-
val res = mutableMapOf<String, List<Any>>()
101-
101+
val res = Arguments.createMap()
102+
102103
for (label in 0..<numLabels) {
103104
val buffer = FloatArray(numOriginalPixels)
104-
rescaledResults[label].get(0, 0, buffer)
105-
res[deeplabv3_resnet50_labels[label]] = buffer.toList()
105+
for (pixel in 0..<numOriginalPixels) {
106+
val row = pixel / originalSize.width.toInt()
107+
val col = pixel % originalSize.width.toInt()
108+
buffer[pixel] = rescaledResults[label].get(row, col)[0].toFloat()
109+
}
110+
res.putArray(deeplabv3_resnet50_labels[label],
111+
ArrayUtils.createReadableArrayFromFloatArray(buffer))
106112
}
107113

108114
val argMaxBuffer = IntArray(numOriginalPixels)
109-
argMax.get(0, 0, argMaxBuffer)
110-
res["argmax"] = argMaxBuffer.toList()
115+
for (pixel in 0..<numOriginalPixels) {
116+
val row = pixel / originalSize.width.toInt()
117+
val col = pixel % originalSize.width.toInt()
118+
argMaxBuffer[pixel] = argMax.get(row, col)[0].toInt()
119+
}
120+
res.putArray("argmax",
121+
ArrayUtils.createReadableArrayFromIntArray(argMaxBuffer))
111122

112123
return res
113124
}
114125

115-
override fun runModel(input: Pair<Mat, ReadableArray>): Map<String, List<Any>> {
126+
override fun runModel(input: Pair<Mat, ReadableArray>)
127+
: WritableMap {
116128
val modelInput = preprocess(input)
117129
val modelOutput = forward(modelInput)
118130
return postprocess(modelOutput)

0 commit comments

Comments
 (0)