Skip to content

Commit 899e440

Browse files
committed
Add image segmentation model logic
1 parent 60ae07b commit 899e440

File tree

6 files changed

+191
-0
lines changed

6 files changed

+191
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package com.swmansion.rnexecutorch
2+
3+
import android.util.Log
4+
import com.facebook.react.bridge.Promise
5+
import com.facebook.react.bridge.ReadableArray
6+
import com.facebook.react.bridge.ReactApplicationContext
7+
import com.swmansion.rnexecutorch.utils.ETError
8+
import com.swmansion.rnexecutorch.models.imagesegmentation.ImageSegmentationModel
9+
import com.swmansion.rnexecutorch.utils.ImageProcessor
10+
import org.opencv.android.OpenCVLoader
11+
12+
class ImageSegmentation(reactContext: ReactApplicationContext) :
13+
NativeImageSegmentationSpec(reactContext) {
14+
15+
private lateinit var model: ImageSegmentationModel
16+
17+
companion object {
18+
const val NAME = "ImageSegmentation"
19+
20+
init {
21+
if(!OpenCVLoader.initLocal()){
22+
Log.d("rn_executorch", "OpenCV not loaded")
23+
} else {
24+
Log.d("rn_executorch", "OpenCV loaded")
25+
}
26+
}
27+
}
28+
29+
override fun loadModule(modelSource: String, promise: Promise) {
30+
try {
31+
model = ImageSegmentationModel(reactApplicationContext)
32+
model.loadModel(modelSource)
33+
promise.resolve(0)
34+
} catch (e: Exception) {
35+
promise.reject(e.message!!, ETError.InvalidModelSource.toString())
36+
}
37+
}
38+
39+
override fun forward(input: String, classesOfInterest: ReadableArray, promise: Promise) {
40+
try {
41+
val output =
42+
model.runModel(Pair(ImageProcessor.readImage(input), classesOfInterest))
43+
promise.resolve(output)
44+
}catch(e: Exception){
45+
promise.reject(e.message!!, e.message)
46+
}
47+
}
48+
49+
override fun getName(): String {
50+
return NAME
51+
}
52+
}

android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt

+9
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ class RnExecutorchPackage : TurboReactPackage() {
3030
OCR(reactContext)
3131
} else if (name == VerticalOCR.NAME) {
3232
VerticalOCR(reactContext)
33+
} else if (name == ImageSegmentation.NAME) {
34+
ImageSegmentation(reactContext)
3335
} else {
3436
null
3537
}
@@ -115,6 +117,13 @@ class RnExecutorchPackage : TurboReactPackage() {
115117
false, // isCxxModule
116118
true,
117119
)
120+
121+
moduleInfos[ImageSegmentation.NAME] = ReactModuleInfo(
122+
ImageSegmentation.NAME, ImageSegmentation.NAME, false, // canOverrideExistingModule
123+
false, // needsEagerInit
124+
false, // isCxxModule
125+
true
126+
)
118127
moduleInfos
119128
}
120129
}

android/src/main/java/com/swmansion/rnexecutorch/models/classification/ClassificationModel.kt

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package com.swmansion.rnexecutorch.models.classification
33
import com.facebook.react.bridge.ReactApplicationContext
44
import com.swmansion.rnexecutorch.models.BaseModel
55
import com.swmansion.rnexecutorch.utils.ImageProcessor
6+
import com.swmansion.rnexecutorch.utils.softmax
67
import org.opencv.core.Mat
78
import org.opencv.core.Size
89
import org.opencv.imgproc.Imgproc
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package com.swmansion.rnexecutorch.models.imagesegmentation
2+
3+
val deeplabv3_resnet50_labels: Array<String> = arrayOf(
4+
"background", "aeroplane", "bicycle", "bird", "boat",
5+
"bottle", "bus", "car", "cat", "chair", "cow", "diningtable",
6+
"dog", "horse", "motorbike", "person", "pottedplant", "sheep",
7+
"sofa", "train", "tvmonitor"
8+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package com.swmansion.rnexecutorch.models.imagesegmentation
2+
3+
import com.facebook.react.bridge.ReadableArray
4+
import com.facebook.react.bridge.ReactApplicationContext
5+
import com.swmansion.rnexecutorch.utils.ImageProcessor
6+
import com.swmansion.rnexecutorch.utils.softmax
7+
import org.opencv.core.Mat
8+
import org.opencv.core.CvType
9+
import org.opencv.core.Size
10+
import org.opencv.imgproc.Imgproc
11+
import org.pytorch.executorch.Tensor
12+
import org.pytorch.executorch.EValue
13+
import com.swmansion.rnexecutorch.models.BaseModel
14+
15+
class ImageSegmentationModel(reactApplicationContext: ReactApplicationContext)
16+
: BaseModel <Pair<Mat, ReadableArray>, Map<String, List<Any>>>(reactApplicationContext) {
17+
private lateinit var originalSize: Size
18+
19+
private fun getModelImageSize(): Size {
20+
val inputShape = module.getInputShape(0)
21+
val width = inputShape[inputShape.lastIndex]
22+
val height = inputShape[inputShape.lastIndex - 1]
23+
24+
return Size(height.toDouble(), width.toDouble())
25+
}
26+
27+
override fun preprocess(input: Pair<Mat, ReadableArray>): EValue {
28+
originalSize = input.first.size()
29+
Imgproc.resize(input.first, input.first, getModelImageSize())
30+
return ImageProcessor.matToEValue(input.first, module.getInputShape(0))
31+
}
32+
33+
private fun rescaleResults(result: Array<Float>, numLabels: Int)
34+
: List<Mat> {
35+
val modelShape = getModelImageSize()
36+
val numModelPixels = (modelShape.height * modelShape.width).toInt()
37+
38+
val resizedLabelScores = mutableListOf<Mat>()
39+
40+
for (label in 0..<numLabels) {
41+
val mat = Mat(modelShape, CvType.CV_32F)
42+
43+
for (pixel in 0..<numModelPixels) {
44+
val row = pixel / modelShape.width.toInt()
45+
val col = pixel % modelShape.width.toInt()
46+
val v = floatArrayOf(result[label * numModelPixels + pixel])
47+
mat.put(row, col, v)
48+
}
49+
50+
val resizedMat = Mat()
51+
Imgproc.resize(mat, resizedMat, originalSize)
52+
resizedLabelScores.add(resizedMat)
53+
}
54+
return resizedLabelScores;
55+
}
56+
57+
private fun adjustScoresPerPixel(labelScores: List<Mat>, numLabels: Int)
58+
: Mat {
59+
val argMax = Mat(originalSize, CvType.CV_32S)
60+
val numOriginalPixels = (originalSize.height * originalSize.width).toInt()
61+
for (pixel in 0..<numOriginalPixels) {
62+
val row = pixel / originalSize.width.toInt()
63+
val col = pixel % originalSize.height.toInt()
64+
val scores = mutableListOf<Float>()
65+
for (mat in labelScores) {
66+
val v = FloatArray(1)
67+
mat.get(row, col, v)
68+
scores.add(v[0])
69+
}
70+
71+
val adjustedScores = softmax(scores.toTypedArray())
72+
73+
for (label in 0..<numLabels) {
74+
labelScores[label].put(row, col, FloatArray(1){adjustedScores[label]})
75+
}
76+
77+
val maxIndex = scores.withIndex().maxBy{it.value}.index
78+
argMax.put(row, col, IntArray(1){maxIndex})
79+
}
80+
81+
return argMax
82+
}
83+
84+
override fun postprocess(output: Array<EValue>): Map<String, List<Any>> {
85+
val output = output[0].toTensor().dataAsFloatArray.toTypedArray()
86+
val modelShape = getModelImageSize()
87+
val numLabels = deeplabv3_resnet50_labels.size;
88+
val numOriginalPixels = (originalSize.height * originalSize.width).toInt()
89+
90+
require(output.count() == (numLabels * modelShape.height * modelShape.width).toInt())
91+
{"Model generated unexpected output size."}
92+
93+
val rescaledResults = rescaleResults(output, numLabels)
94+
95+
val argMax = adjustScoresPerPixel(rescaledResults, numLabels)
96+
97+
// val labelSet = mutableSetOf<String>()
98+
// Filter by the label set when base class changed
99+
100+
val res = mutableMapOf<String, List<Any>>()
101+
102+
for (label in 0..<numLabels) {
103+
val buffer = FloatArray(numOriginalPixels)
104+
rescaledResults[label].get(0, 0, buffer)
105+
res[deeplabv3_resnet50_labels[label]] = buffer.toList()
106+
}
107+
108+
val argMaxBuffer = IntArray(numOriginalPixels)
109+
argMax.get(0, 0, argMaxBuffer)
110+
res["argmax"] = argMaxBuffer.toList()
111+
112+
return res
113+
}
114+
115+
override fun runModel(input: Pair<Mat, ReadableArray>): Map<String, List<Any>> {
116+
val modelInput = preprocess(input)
117+
val modelOutput = forward(modelInput)
118+
return postprocess(modelOutput)
119+
}
120+
}

android/src/main/java/com/swmansion/rnexecutorch/models/styleTransfer/StyleTransferModel.kt

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import org.opencv.core.Mat
66
import org.opencv.core.Size
77
import org.opencv.imgproc.Imgproc
88
import org.pytorch.executorch.EValue
9+
import com.swmansion.rnexecutorch.models.BaseModel
910

1011
class StyleTransferModel(
1112
reactApplicationContext: ReactApplicationContext,

0 commit comments

Comments
 (0)