Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add image segmentation for Android #134

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package com.swmansion.rnexecutorch

import android.util.Log
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReadableArray
import com.swmansion.rnexecutorch.models.imagesegmentation.ImageSegmentationModel
import com.swmansion.rnexecutorch.utils.ETError
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.android.OpenCVLoader

class ImageSegmentation(
reactContext: ReactApplicationContext,
) : NativeImageSegmentationSpec(reactContext) {
private lateinit var model: ImageSegmentationModel

companion object {
const val NAME = "ImageSegmentation"

init {
if (!OpenCVLoader.initLocal()) {
Log.d("rn_executorch", "OpenCV not loaded")
} else {
Log.d("rn_executorch", "OpenCV loaded")
}
}
}

override fun loadModule(
modelSource: String,
promise: Promise,
) {
try {
model = ImageSegmentationModel(reactApplicationContext)
model.loadModel(modelSource)
promise.resolve(0)
} catch (e: Exception) {
promise.reject(e.message!!, ETError.InvalidModelSource.toString())
}
}

override fun forward(
input: String,
classesOfInterest: ReadableArray,
resize: Boolean,
promise: Promise,
) {
try {
val output =
model.runModel(Triple(ImageProcessor.readImage(input), classesOfInterest, resize))
promise.resolve(output)
} catch (e: Exception) {
promise.reject(e.message!!, e.message)
}
}

override fun getName(): String = NAME
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class RnExecutorchPackage : TurboReactPackage() {
OCR(reactContext)
} else if (name == VerticalOCR.NAME) {
VerticalOCR(reactContext)
} else if (name == ImageSegmentation.NAME) {
ImageSegmentation(reactContext)
} else {
null
}
Expand Down Expand Up @@ -115,6 +117,13 @@ class RnExecutorchPackage : TurboReactPackage() {
false, // isCxxModule
true,
)

moduleInfos[ImageSegmentation.NAME] = ReactModuleInfo(
ImageSegmentation.NAME, ImageSegmentation.NAME, false, // canOverrideExistingModule
false, // needsEagerInit
false, // isCxxModule
true
)
moduleInfos
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package com.swmansion.rnexecutorch
import android.util.Log
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.models.StyleTransferModel
import com.swmansion.rnexecutorch.models.styletransfer.StyleTransferModel
import com.swmansion.rnexecutorch.utils.ETError
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.android.OpenCVLoader
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.swmansion.rnexecutorch.models.classification
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.models.BaseModel
import com.swmansion.rnexecutorch.utils.ImageProcessor
import com.swmansion.rnexecutorch.utils.softmax
import org.opencv.core.Mat
import org.opencv.core.Size
import org.opencv.imgproc.Imgproc
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.swmansion.rnexecutorch.models.imagesegmentation

val deeplabv3_resnet50_labels: Array<String> =
arrayOf(
"BACKGROUND",
"AEROPLANE",
"BICYCLE",
"BIRD",
"BOAT",
"BOTTLE",
"BUS",
"CAR",
"CAT",
"CHAIR",
"COW",
"DININGTABLE",
"DOG",
"HORSE",
"MOTORBIKE",
"PERSON",
"POTTEDPLANT",
"SHEEP",
"SOFA",
"TRAIN",
"TVMONITOR",
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package com.swmansion.rnexecutorch.models.imagesegmentation

import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReadableArray
import com.facebook.react.bridge.WritableMap
import com.swmansion.rnexecutorch.models.BaseModel
import com.swmansion.rnexecutorch.utils.ArrayUtils
import com.swmansion.rnexecutorch.utils.ImageProcessor
import com.swmansion.rnexecutorch.utils.softmax
import org.opencv.core.CvType
import org.opencv.core.Mat
import org.opencv.core.Size
import org.opencv.imgproc.Imgproc
import org.pytorch.executorch.EValue

class ImageSegmentationModel(
reactApplicationContext: ReactApplicationContext,
) : BaseModel<Triple<Mat, ReadableArray, Boolean>, WritableMap>(reactApplicationContext) {
private lateinit var originalSize: Size

private fun getModelImageSize(): Size {
val inputShape = module.getInputShape(0)
val width = inputShape[inputShape.lastIndex]
val height = inputShape[inputShape.lastIndex - 1]

return Size(height.toDouble(), width.toDouble())
}

fun preprocess(input: Mat): EValue {
originalSize = input.size()
Imgproc.resize(input, input, getModelImageSize())
return ImageProcessor.matToEValue(input, module.getInputShape(0))
}

private fun extractResults(
result: FloatArray,
numLabels: Int,
resize: Boolean,
): List<FloatArray> {
val modelSize = getModelImageSize()
val numModelPixels = (modelSize.height * modelSize.width).toInt()

val extractedLabelScores = mutableListOf<FloatArray>()

for (label in 0..<numLabels) {
// Calls to OpenCV via JNI are very slow so we do as much as we can
// with pure Kotlin
val range = IntRange(label * numModelPixels, (label + 1) * numModelPixels - 1)
val pixelBuffer = result.slice(range).toFloatArray()

if (resize) {
// Rescale the image with OpenCV
val mat = Mat(modelSize, CvType.CV_32F)
mat.put(0, 0, pixelBuffer)
val resizedMat = Mat()
Imgproc.resize(mat, resizedMat, originalSize)
val resizedBuffer = FloatArray((originalSize.height * originalSize.width).toInt())
resizedMat.get(0, 0, resizedBuffer)
extractedLabelScores.add(resizedBuffer)
} else {
extractedLabelScores.add(pixelBuffer)
}
}
return extractedLabelScores
}

private fun adjustScoresPerPixel(
labelScores: List<FloatArray>,
numLabels: Int,
outputSize: Size,
): IntArray {
val numPixels = (outputSize.height * outputSize.width).toInt()
val argMax = IntArray(numPixels)
for (pixel in 0..<numPixels) {
val scores = mutableListOf<Float>()
for (buffer in labelScores) {
scores.add(buffer[pixel])
}
val adjustedScores = softmax(scores.toTypedArray())
for (label in 0..<numLabels) {
labelScores[label][pixel] = adjustedScores[label]
}

val maxIndex = scores.withIndex().maxBy { it.value }.index
argMax[pixel] = maxIndex
}

return argMax
}

fun postprocess(
output: Array<EValue>,
classesOfInterest: ReadableArray,
resize: Boolean,
): WritableMap {
val outputData = output[0].toTensor().dataAsFloatArray
val modelSize = getModelImageSize()
val numLabels = deeplabv3_resnet50_labels.size

require(outputData.count() == (numLabels * modelSize.height * modelSize.width).toInt()) { "Model generated unexpected output size." }

val outputSize = if (resize) originalSize else modelSize

val extractedResults = extractResults(outputData, numLabels, resize)

val argMax = adjustScoresPerPixel(extractedResults, numLabels, outputSize)

val labelSet = mutableSetOf<String>()
// Filter by the label set when base class changed
for (i in 0..<classesOfInterest.size()) {
labelSet.add(classesOfInterest.getString(i))
}

val res = Arguments.createMap()

for (label in 0..<numLabels) {
if (labelSet.contains(deeplabv3_resnet50_labels[label])) {
res.putArray(
deeplabv3_resnet50_labels[label],
ArrayUtils.createReadableArrayFromFloatArray(extractedResults[label]),
)
}
}

res.putArray(
"ARGMAX",
ArrayUtils.createReadableArrayFromIntArray(argMax),
)

return res
}

override fun runModel(input: Triple<Mat, ReadableArray, Boolean>): WritableMap {
val modelInput = preprocess(input.first)
val modelOutput = forward(modelInput)
return postprocess(modelOutput, input.second, input.third)
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package com.swmansion.rnexecutorch.models
package com.swmansion.rnexecutorch.models.styletransfer

import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.core.Mat
import org.opencv.core.Size
import org.opencv.imgproc.Imgproc
import org.pytorch.executorch.EValue
import com.swmansion.rnexecutorch.models.BaseModel

class StyleTransferModel(
reactApplicationContext: ReactApplicationContext,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.swmansion.rnexecutorch.models.classification
package com.swmansion.rnexecutorch.utils

fun softmax(x: Array<Float>): Array<Float> {
val max = x.maxOrNull()!!
Expand Down
5 changes: 5 additions & 0 deletions ios/RnExecutorch/ImageSegmentation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#import <RnExecutorchSpec/RnExecutorchSpec.h>

@interface ImageSegmentation : NSObject <NativeImageSegmentationSpec>

@end
63 changes: 63 additions & 0 deletions ios/RnExecutorch/ImageSegmentation.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#import "ImageSegmentation.h"
#import "models/image_segmentation/ImageSegmentationModel.h"
#import "models/BaseModel.h"
#import "utils/ETError.h"
#import <ExecutorchLib/ETModel.h>
#import <React/RCTBridgeModule.h>
#import <opencv2/opencv.hpp>
#import "ImageProcessor.h"

@implementation ImageSegmentation {
ImageSegmentationModel *model;
}

RCT_EXPORT_MODULE()

- (void)loadModule:(NSString *)modelSource
resolve:(RCTPromiseResolveBlock)resolve
reject:(RCTPromiseRejectBlock)reject {

model = [[ImageSegmentationModel alloc] init];
[model
loadModel:[NSURL URLWithString:modelSource]
completion:^(BOOL success, NSNumber *errorCode) {
if (success) {
resolve(errorCode);
return;
}

reject(@"init_module_error",
[NSString stringWithFormat:@"%ld", (long)[errorCode longValue]],
nil);
return;
}];
}

- (void)forward:(NSString *)input
classesOfInterest:(NSArray *)classesOfInterest
resize:(BOOL)resize
resolve:(RCTPromiseResolveBlock)resolve
reject:(RCTPromiseRejectBlock)reject {

@try {
cv::Mat image = [ImageProcessor readImage:input];
NSDictionary *result = [model runModel:image
returnClasses:classesOfInterest
resize:resize];

resolve(result);
return;
} @catch (NSException *exception) {
NSLog(@"An exception occurred: %@, %@", exception.name, exception.reason);
reject(@"forward_error",
[NSString stringWithFormat:@"%@", exception.reason], nil);
return;
}
}

- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
(const facebook::react::ObjCTurboModule::InitParams &)params {
return std::make_shared<facebook::react::NativeImageSegmentationSpecJSI>(params);
}

@end
2 changes: 1 addition & 1 deletion ios/RnExecutorch/StyleTransfer.mm
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#import "StyleTransfer.h"
#import "ImageProcessor.h"
#import "models/BaseModel.h"
#import "models/StyleTransferModel.h"
#import "models/style_transfer/StyleTransferModel.h"
#import "utils/ETError.h"
#import <ExecutorchLib/ETModel.h>
#import <React/RCTBridgeModule.h>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#import "ClassificationModel.h"
#import "../../utils/ImageProcessor.h"
#import "../../utils/Numerical.h"
#import "Constants.h"
#import "Utils.h"
#import "opencv2/opencv.hpp"

@implementation ClassificationModel
Expand Down
5 changes: 5 additions & 0 deletions ios/RnExecutorch/models/image_segmentation/Constants.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#import <string>
#import <vector>


extern const std::vector<std::string> deeplabv3_resnet50_labels;
10 changes: 10 additions & 0 deletions ios/RnExecutorch/models/image_segmentation/Constants.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#import "Constants.h"
#import <string>
#import <vector>

const std::vector<std::string> deeplabv3_resnet50_labels = {
"BACKGROUND", "AEROPLANE", "BICYCLE", "BIRD", "BOAT",
"BOTTLE", "BUS", "CAR", "CAT", "CHAIR", "COW", "DININGTABLE",
"DOG", "HORSE", "MOTORBIKE", "PERSON", "POTTEDPLANT", "SHEEP",
"SOFA", "TRAIN", "TVMONITOR"
};
Loading