diff --git a/android/src/main/java/com/swmansion/rnexecutorch/ImageSegmentation.kt b/android/src/main/java/com/swmansion/rnexecutorch/ImageSegmentation.kt
new file mode 100644
index 00000000..c18fa8ed
--- /dev/null
+++ b/android/src/main/java/com/swmansion/rnexecutorch/ImageSegmentation.kt
@@ -0,0 +1,58 @@
+package com.swmansion.rnexecutorch
+
+import android.util.Log
+import com.facebook.react.bridge.Promise
+import com.facebook.react.bridge.ReactApplicationContext
+import com.facebook.react.bridge.ReadableArray
+import com.swmansion.rnexecutorch.models.imagesegmentation.ImageSegmentationModel
+import com.swmansion.rnexecutorch.utils.ETError
+import com.swmansion.rnexecutorch.utils.ImageProcessor
+import org.opencv.android.OpenCVLoader
+
+class ImageSegmentation(
+  reactContext: ReactApplicationContext,
+) : NativeImageSegmentationSpec(reactContext) {
+  private lateinit var model: ImageSegmentationModel
+
+  companion object {
+    const val NAME = "ImageSegmentation"
+
+    init {
+      if (!OpenCVLoader.initLocal()) {
+        Log.d("rn_executorch", "OpenCV not loaded")
+      } else {
+        Log.d("rn_executorch", "OpenCV loaded")
+      }
+    }
+  }
+
+  override fun loadModule(
+    modelSource: String,
+    promise: Promise,
+  ) {
+    try {
+      model = ImageSegmentationModel(reactApplicationContext)
+      model.loadModel(modelSource)
+      promise.resolve(0)
+    } catch (e: Exception) {
+      promise.reject(e.message!!, ETError.InvalidModelSource.toString())
+    }
+  }
+
+  override fun forward(
+    input: String,
+    classesOfInterest: ReadableArray,
+    resize: Boolean,
+    promise: Promise,
+  ) {
+    try {
+      val output =
+        model.runModel(Triple(ImageProcessor.readImage(input), classesOfInterest, resize))
+      promise.resolve(output)
+    } catch (e: Exception) {
+      promise.reject(e.message!!, e.message)
+    }
+  }
+
+  override fun getName(): String = NAME
+}
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt
index c88e3870..3c78d4d7 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt
@@ -30,6 +30,8 @@ class RnExecutorchPackage : TurboReactPackage() {
       OCR(reactContext)
     } else if (name == VerticalOCR.NAME) {
       VerticalOCR(reactContext)
+    } else if (name == ImageSegmentation.NAME) {
+      ImageSegmentation(reactContext)
     } else {
       null
     }
@@ -115,6 +117,13 @@
         false, // isCxxModule
         true,
       )
+
+      moduleInfos[ImageSegmentation.NAME] = ReactModuleInfo(
+        ImageSegmentation.NAME, ImageSegmentation.NAME, false, // canOverrideExistingModule
+        false, // needsEagerInit
+        false, // isCxxModule
+        true
+      )
     moduleInfos
   }
 }
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/StyleTransfer.kt b/android/src/main/java/com/swmansion/rnexecutorch/StyleTransfer.kt
index 54132b88..224794e1 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/StyleTransfer.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/StyleTransfer.kt
@@ -3,7 +3,7 @@ package com.swmansion.rnexecutorch
 import android.util.Log
 import com.facebook.react.bridge.Promise
 import com.facebook.react.bridge.ReactApplicationContext
-import com.swmansion.rnexecutorch.models.StyleTransferModel
+import com.swmansion.rnexecutorch.models.styletransfer.StyleTransferModel
 import com.swmansion.rnexecutorch.utils.ETError
 import com.swmansion.rnexecutorch.utils.ImageProcessor
 import org.opencv.android.OpenCVLoader
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/classification/ClassificationModel.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/classification/ClassificationModel.kt
index b60b0998..776f9a53 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/models/classification/ClassificationModel.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/classification/ClassificationModel.kt
@@ -3,6 +3,7 @@ package com.swmansion.rnexecutorch.models.classification
 import com.facebook.react.bridge.ReactApplicationContext
 import com.swmansion.rnexecutorch.models.BaseModel
 import com.swmansion.rnexecutorch.utils.ImageProcessor
+import com.swmansion.rnexecutorch.utils.softmax
 import org.opencv.core.Mat
 import org.opencv.core.Size
 import org.opencv.imgproc.Imgproc
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/Constants.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/Constants.kt
new file mode 100644
index 00000000..7ba7fcb5
--- /dev/null
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/Constants.kt
@@ -0,0 +1,26 @@
+package com.swmansion.rnexecutorch.models.imagesegmentation
+
+val deeplabv3_resnet50_labels: Array<String> =
+  arrayOf(
+    "BACKGROUND",
+    "AEROPLANE",
+    "BICYCLE",
+    "BIRD",
+    "BOAT",
+    "BOTTLE",
+    "BUS",
+    "CAR",
+    "CAT",
+    "CHAIR",
+    "COW",
+    "DININGTABLE",
+    "DOG",
+    "HORSE",
+    "MOTORBIKE",
+    "PERSON",
+    "POTTEDPLANT",
+    "SHEEP",
+    "SOFA",
+    "TRAIN",
+    "TVMONITOR",
+  )
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/ImageSegmentationModel.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/ImageSegmentationModel.kt
new file mode 100644
index 00000000..36c1594b
--- /dev/null
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/imageSegmentation/ImageSegmentationModel.kt
@@ -0,0 +1,139 @@
+package com.swmansion.rnexecutorch.models.imagesegmentation
+
+import com.facebook.react.bridge.Arguments
+import com.facebook.react.bridge.ReactApplicationContext
+import com.facebook.react.bridge.ReadableArray
+import com.facebook.react.bridge.WritableMap
+import com.swmansion.rnexecutorch.models.BaseModel
+import com.swmansion.rnexecutorch.utils.ArrayUtils
+import com.swmansion.rnexecutorch.utils.ImageProcessor
+import com.swmansion.rnexecutorch.utils.softmax
+import org.opencv.core.CvType
+import org.opencv.core.Mat
+import org.opencv.core.Size
+import org.opencv.imgproc.Imgproc
+import org.pytorch.executorch.EValue
+
+class ImageSegmentationModel(
+  reactApplicationContext: ReactApplicationContext,
+) : BaseModel<Triple<Mat, ReadableArray, Boolean>, WritableMap>(reactApplicationContext) {
+  private lateinit var originalSize: Size
+
+  private fun getModelImageSize(): Size {
+    val inputShape = module.getInputShape(0)
+    val width = inputShape[inputShape.lastIndex]
+    val height = inputShape[inputShape.lastIndex - 1]
+
+    return Size(height.toDouble(), width.toDouble())
+  }
+
+  fun preprocess(input: Mat): EValue {
+    originalSize = input.size()
+    Imgproc.resize(input, input, getModelImageSize())
+    return ImageProcessor.matToEValue(input, module.getInputShape(0))
+  }
+
+  private fun extractResults(
+    result: FloatArray,
+    numLabels: Int,
+    resize: Boolean,
+  ): List<FloatArray> {
+    val modelSize = getModelImageSize()
+    val numModelPixels = (modelSize.height * modelSize.width).toInt()
+
+    val extractedLabelScores = mutableListOf<FloatArray>()
+
+    for (label in 0..<numLabels) {
+      // Copy this label's plane out of the flat output buffer
+      val scores = FloatArray(numModelPixels) { pixel -> result[label * numModelPixels + pixel] }
+      if (resize) {
+        // Rescale the per-label score plane back to the original image size
+        val mat = Mat(modelSize, CvType.CV_32F)
+        mat.put(0, 0, scores)
+        Imgproc.resize(mat, mat, originalSize)
+        val resizedScores = FloatArray((originalSize.height * originalSize.width).toInt())
+        mat.get(0, 0, resizedScores)
+        extractedLabelScores.add(resizedScores)
+      } else {
+        extractedLabelScores.add(scores)
+      }
+    }
+    return extractedLabelScores
+  }
+
+  private fun adjustScoresPerPixel(
+    labelScores: List<FloatArray>,
+    numLabels: Int,
+    outputSize: Size,
+  ): IntArray {
+    val numPixels = (outputSize.height * outputSize.width).toInt()
+    val argMax = IntArray(numPixels)
+    for (pixel in 0..<numPixels) {
+      val scores = mutableListOf<Float>()
+      for (buffer in labelScores) {
+        scores.add(buffer[pixel])
+      }
+      val adjustedScores = softmax(scores.toTypedArray())
+      for (label in 0..<numLabels) {
+        labelScores[label][pixel] = adjustedScores[label]
+      }
+      // Pick the label with the highest score for this pixel
+      var maxIndex = 0
+      for (label in 1..<numLabels) {
+        if (labelScores[label][pixel] > labelScores[maxIndex][pixel]) {
+          maxIndex = label
+        }
+      }
+      argMax[pixel] = maxIndex
+    }
+    return argMax
+  }
+
+  private fun postprocess(
+    output: Array<EValue>,
+    classesOfInterest: ReadableArray,
+    resize: Boolean,
+  ): WritableMap {
+    val outputData = output[0].toTensor().dataAsFloatArray
+    val modelSize = getModelImageSize()
+    val numLabels = deeplabv3_resnet50_labels.size
+
+    require(outputData.count() == (numLabels * modelSize.height * modelSize.width).toInt()) { "Model generated unexpected output size." }
+
+    val outputSize = if (resize) originalSize else modelSize
+
+    val extractedResults = extractResults(outputData, numLabels, resize)
+
+    val argMax = adjustScoresPerPixel(extractedResults, numLabels, outputSize)
+
+    val labelSet = mutableSetOf<String>()
+    // Filter by the label set when base class changed
+    for (i in 0..<classesOfInterest.size()) {
+      classesOfInterest.getString(i)?.let { labelSet.add(it) }
+    }
+
+    val result = Arguments.createMap()
+    for (label in 0..<numLabels) {
+      if (labelSet.contains(deeplabv3_resnet50_labels[label])) {
+        val scoresArray = Arguments.createArray()
+        for (score in extractedResults[label]) {
+          scoresArray.pushDouble(score.toDouble())
+        }
+        result.putArray(deeplabv3_resnet50_labels[label], scoresArray)
+      }
+    }
+    val argMaxArray = Arguments.createArray()
+    for (maxLabel in argMax) {
+      argMaxArray.pushInt(maxLabel)
+    }
+    result.putArray("ARGMAX", argMaxArray)
+
+    return result
+  }
+
+  override fun runModel(input: Triple<Mat, ReadableArray, Boolean>): WritableMap {
+    val modelInput = preprocess(input.first)
+    val modelOutput = forward(modelInput)
+    return postprocess(modelOutput, input.second, input.third)
+  }
+}
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/StyleTransferModel.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/styleTransfer/StyleTransferModel.kt
similarity index 92%
rename from android/src/main/java/com/swmansion/rnexecutorch/models/StyleTransferModel.kt
rename to android/src/main/java/com/swmansion/rnexecutorch/models/styleTransfer/StyleTransferModel.kt
index 72d3bc6d..4019015d 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/models/StyleTransferModel.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/styleTransfer/StyleTransferModel.kt
@@ -1,4 +1,4 @@
-package com.swmansion.rnexecutorch.models
+package com.swmansion.rnexecutorch.models.styletransfer
 
 import com.facebook.react.bridge.ReactApplicationContext
 import com.swmansion.rnexecutorch.utils.ImageProcessor
@@ -6,6 +6,7 @@ import org.opencv.core.Mat
 import org.opencv.core.Size
 import org.opencv.imgproc.Imgproc
 import org.pytorch.executorch.EValue
+import com.swmansion.rnexecutorch.models.BaseModel
 
 class StyleTransferModel(
   reactApplicationContext: ReactApplicationContext,
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/classification/Utils.kt b/android/src/main/java/com/swmansion/rnexecutorch/utils/Numerical.kt
similarity index 77%
rename from android/src/main/java/com/swmansion/rnexecutorch/models/classification/Utils.kt
rename to android/src/main/java/com/swmansion/rnexecutorch/utils/Numerical.kt
index e919950a..603699e3 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/models/classification/Utils.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/utils/Numerical.kt
@@ -1,4 +1,4 @@
-package com.swmansion.rnexecutorch.models.classification
+package com.swmansion.rnexecutorch.utils
 
 fun softmax(x: Array<Float>): Array<Float> {
   val max = x.maxOrNull()!!
diff --git a/ios/RnExecutorch/ImageSegmentation.h b/ios/RnExecutorch/ImageSegmentation.h
new file mode 100644
index 00000000..59ed56a4
--- /dev/null
+++ b/ios/RnExecutorch/ImageSegmentation.h
@@ -0,0 +1,5 @@
+#import <RnExecutorchSpec/RnExecutorchSpec.h>
+
+@interface ImageSegmentation : NSObject <NativeImageSegmentationSpec>
+
+@end
\ No newline at end of file
diff --git a/ios/RnExecutorch/ImageSegmentation.mm b/ios/RnExecutorch/ImageSegmentation.mm
new file mode 100644
index 00000000..19cbe664
--- /dev/null
+++ b/ios/RnExecutorch/ImageSegmentation.mm
@@ -0,0 +1,63 @@
+#import "ImageSegmentation.h"
+#import "models/image_segmentation/ImageSegmentationModel.h"
+#import "models/BaseModel.h"
+#import "utils/ETError.h"
+#import <ExecutorchLib/ETModel.h>
+#import <React/RCTBridgeModule.h>
+#import <opencv2/opencv.hpp>
+#import "ImageProcessor.h"
+
+@implementation ImageSegmentation {
+  ImageSegmentationModel *model;
+}
+
+RCT_EXPORT_MODULE()
+
+- (void)loadModule:(NSString *)modelSource
+           resolve:(RCTPromiseResolveBlock)resolve
+            reject:(RCTPromiseRejectBlock)reject {
+
+  model = [[ImageSegmentationModel alloc] init];
+  [model
+      loadModel:[NSURL URLWithString:modelSource]
+     completion:^(BOOL success, NSNumber *errorCode) {
+       if (success) {
+         resolve(errorCode);
+         return;
+       }
+
+       reject(@"init_module_error",
+              [NSString stringWithFormat:@"%ld", (long)[errorCode longValue]],
+              nil);
+       return;
+     }];
+}
+
+- (void)forward:(NSString *)input
+    classesOfInterest:(NSArray *)classesOfInterest
+               resize:(BOOL)resize
+              resolve:(RCTPromiseResolveBlock)resolve
+               reject:(RCTPromiseRejectBlock)reject {
+
+  @try {
+    cv::Mat image = [ImageProcessor readImage:input];
+    NSDictionary *result = [model runModel:image
+                             returnClasses:classesOfInterest
+                                    resize:resize];
+
+    resolve(result);
+    return;
+  } @catch (NSException *exception) {
+    NSLog(@"An exception occurred: %@, %@", exception.name, exception.reason);
+    reject(@"forward_error",
+           [NSString stringWithFormat:@"%@", exception.reason], nil);
+    return;
+  }
+}
+
+- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
+    (const facebook::react::ObjCTurboModule::InitParams &)params {
+  return std::make_shared<facebook::react::NativeImageSegmentationSpecJSI>(params);
+}
+
+@end
diff --git a/ios/RnExecutorch/StyleTransfer.mm b/ios/RnExecutorch/StyleTransfer.mm
index 08e8d4a3..52930cd4 100644
--- a/ios/RnExecutorch/StyleTransfer.mm
+++ b/ios/RnExecutorch/StyleTransfer.mm
@@ -1,7 +1,7 @@
 #import "StyleTransfer.h"
 #import "ImageProcessor.h"
 #import "models/BaseModel.h"
-#import "models/StyleTransferModel.h"
+#import "models/style_transfer/StyleTransferModel.h"
 #import "utils/ETError.h"
 #import <ExecutorchLib/ETModel.h>
 #import <React/RCTBridgeModule.h>
diff --git a/ios/RnExecutorch/models/classification/ClassificationModel.mm b/ios/RnExecutorch/models/classification/ClassificationModel.mm
index 8e7973e2..0306e67c 100644
--- a/ios/RnExecutorch/models/classification/ClassificationModel.mm
+++ b/ios/RnExecutorch/models/classification/ClassificationModel.mm
@@ -1,7 +1,7 @@
 #import "ClassificationModel.h"
 #import "../../utils/ImageProcessor.h"
+#import "../../utils/Numerical.h"
 #import "Constants.h"
-#import "Utils.h"
 #import "opencv2/opencv.hpp"
 
 @implementation ClassificationModel
diff --git a/ios/RnExecutorch/models/image_segmentation/Constants.h b/ios/RnExecutorch/models/image_segmentation/Constants.h
new file mode 100644
index 00000000..889556d7
--- /dev/null
+++ b/ios/RnExecutorch/models/image_segmentation/Constants.h
@@ -0,0 +1,5 @@
+#import <string>
+#import <vector>
+
+
+extern const std::vector<std::string> deeplabv3_resnet50_labels;
diff --git a/ios/RnExecutorch/models/image_segmentation/Constants.mm b/ios/RnExecutorch/models/image_segmentation/Constants.mm
new file mode 100644
index 00000000..84ce9ea6
--- /dev/null
+++ b/ios/RnExecutorch/models/image_segmentation/Constants.mm
@@ -0,0 +1,10 @@
+#import "Constants.h"
+#import <string>
+#import <vector>
+
+const std::vector<std::string> deeplabv3_resnet50_labels = {
+  "BACKGROUND", "AEROPLANE", "BICYCLE", "BIRD", "BOAT",
+  "BOTTLE", "BUS", "CAR", "CAT", "CHAIR", "COW", "DININGTABLE",
+  "DOG", "HORSE", "MOTORBIKE", "PERSON", "POTTEDPLANT", "SHEEP",
+  "SOFA", "TRAIN", "TVMONITOR"
+};
\ No newline at end of file
diff --git a/ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.h b/ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.h
new file mode 100644
index 00000000..a58733a1
--- /dev/null
+++ b/ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.h
@@ -0,0 +1,10 @@
+#import "../BaseModel.h"
+#import "opencv2/opencv.hpp"
+
+@interface ImageSegmentationModel : BaseModel
+- (cv::Size)getModelImageSize;
+- (NSDictionary *)runModel:(cv::Mat &)input
+             returnClasses:(NSArray *)classesOfInterest
+                    resize:(BOOL)resize;
+
+@end
\ No newline at end of file
diff --git a/ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.mm b/ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.mm
new file mode 100644
index 00000000..951687c5
--- /dev/null
+++ b/ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.mm
@@ -0,0 +1,147 @@
+#import "ImageSegmentationModel.h"
+#import <string>
+#import <unordered_set>
+#import <vector>
+#import "../../utils/ImageProcessor.h"
+#import "../../utils/Numerical.h"
+#import "../../utils/Conversions.h"
+#import "opencv2/opencv.hpp"
+#import "Constants.h"
+
+@interface ImageSegmentationModel ()
+ - (NSArray *)preprocess:(cv::Mat &)input;
+ - (NSDictionary *)postprocess:(NSArray *)output
+                returnClasses:(NSArray *)classesOfInterest
+                       resize:(BOOL)resize;
+@end
+
+@implementation ImageSegmentationModel {
+  cv::Size originalSize;
+}
+
+- (cv::Size)getModelImageSize {
+  NSArray *inputShape = [module getInputShape:@0];
+  NSNumber *widthNumber = inputShape.lastObject;
+  NSNumber *heightNumber = inputShape[inputShape.count - 2];
+
+  int height = [heightNumber intValue];
+  int width = [widthNumber intValue];
+
+  return cv::Size(height, width);
+}
+
+- (NSArray *)preprocess:(cv::Mat &)input {
+  originalSize = cv::Size(input.cols, input.rows);
+
+  cv::Size modelImageSize = [self getModelImageSize];
+  cv::Mat output;
+  cv::resize(input, output, modelImageSize);
+
+  NSArray *modelInput = [ImageProcessor matToNSArray:output];
+  return modelInput;
+}
+
+std::vector<cv::Mat> extractResults(NSArray *result, std::size_t numLabels,
+                                    cv::Size modelImageSize, cv::Size originalSize, BOOL resize) {
+  std::size_t numModelPixels = modelImageSize.height * modelImageSize.width;
+
+  std::vector<cv::Mat> resizedLabelScores(numLabels);
+  for (std::size_t label = 0; label < numLabels; ++label) {
+    cv::Mat labelMat = cv::Mat(modelImageSize, CV_64F);
+
+    for(std::size_t pixel = 0; pixel < numModelPixels; ++pixel){
+      int row = pixel / modelImageSize.width;
+      int col = pixel % modelImageSize.width;
+      labelMat.at<double>(row, col) = [result[label * numModelPixels + pixel] doubleValue];
+    }
+
+    if (resize) {
+      cv::resize(labelMat, resizedLabelScores[label], originalSize);
+    }
+    else {
+      resizedLabelScores[label] = std::move(labelMat);
+    }
+  }
+  return resizedLabelScores;
+}
+
+void adjustScoresPerPixel(std::vector<cv::Mat>& labelScores, cv::Mat& argMax,
+                          cv::Size outputSize, std::size_t numLabels) {
+  std::size_t numOutputPixels = outputSize.height * outputSize.width;
+  for (std::size_t pixel = 0; pixel < numOutputPixels; ++pixel) {
+    int row = pixel / outputSize.width;
+    int col = pixel % outputSize.width;
+    std::vector<double> scores;
+    scores.reserve(numLabels);
+    for (const auto& mat : labelScores) {
+      scores.push_back(mat.at<double>(row, col));
+    }
+
+    std::vector<double> adjustedScores = softmax(scores);
+
+    for (std::size_t label = 0; label < numLabels; ++label) {
+      labelScores[label].at<double>(row, col) = adjustedScores[label];
+    }
+
+    auto maxIt = std::max_element(scores.begin(), scores.end());
+    argMax.at<int>(row, col) = std::distance(scores.begin(), maxIt);
+  }
+}
+
+- (NSDictionary *)postprocess:(NSArray *)output
+                returnClasses:(NSArray *)classesOfInterest
+                       resize:(BOOL)resize {
+  cv::Size modelImageSize = [self getModelImageSize];
+
+  std::size_t numLabels = deeplabv3_resnet50_labels.size();
+
+  NSAssert((std::size_t)output.count == numLabels * modelImageSize.height * modelImageSize.width,
+           @"Model generated unexpected output size.");
+
+  // For each label extract its matrix,
+  // and rescale it to the original size if `resize`
+  std::vector<cv::Mat> resizedLabelScores =
+      extractResults(output, numLabels, modelImageSize, originalSize, resize);
+
+  cv::Size outputSize = resize ? originalSize : modelImageSize;
+  cv::Mat argMax = cv::Mat(outputSize, CV_32S);
+
+  // For each pixel apply softmax across all the labels and calculate the argMax
+  adjustScoresPerPixel(resizedLabelScores, argMax, outputSize, numLabels);
+
+  std::unordered_set<std::string> labelSet;
+
+  for (id label in classesOfInterest) {
+    labelSet.insert(std::string([label UTF8String]));
+  }
+
+  NSMutableDictionary *result = [NSMutableDictionary dictionary];
+
+  // Convert to NSArray and populate the final dictionary
+  for (std::size_t label = 0; label < numLabels; ++label) {
+    if (labelSet.contains(deeplabv3_resnet50_labels[label])){
+      NSString *labelString = @(deeplabv3_resnet50_labels[label].c_str());
+      NSArray *arr = simpleMatToNSArray<double>(resizedLabelScores[label]);
+      result[labelString] = arr;
+    }
+  }
+
+  result[@"ARGMAX"] = simpleMatToNSArray<int>(argMax);
+
+  return result;
+}
+
+- (NSDictionary *)runModel:(cv::Mat &)input
+             returnClasses:(NSArray *)classesOfInterest
+                    resize:(BOOL)resize {
+  NSArray *modelInput = [self preprocess:input];
+  NSArray *result = [self forward:modelInput];
+
+  NSDictionary *output = [self postprocess:result[0]
+                             returnClasses:classesOfInterest
+                                    resize:resize];
+
+  return output;
+}
+
+@end
diff --git a/ios/RnExecutorch/models/StyleTransferModel.h b/ios/RnExecutorch/models/style_transfer/StyleTransferModel.h
similarity index 90%
rename from ios/RnExecutorch/models/StyleTransferModel.h
rename to ios/RnExecutorch/models/style_transfer/StyleTransferModel.h
index 1fd91d7b..20cdf6dd 100644
--- a/ios/RnExecutorch/models/StyleTransferModel.h
+++ b/ios/RnExecutorch/models/style_transfer/StyleTransferModel.h
@@ -1,4 +1,4 @@
-#import "BaseModel.h"
+#import "../BaseModel.h"
 #import "opencv2/opencv.hpp"
 
 @interface StyleTransferModel : BaseModel
diff --git a/ios/RnExecutorch/models/StyleTransferModel.mm b/ios/RnExecutorch/models/style_transfer/StyleTransferModel.mm
similarity index 97%
rename from ios/RnExecutorch/models/StyleTransferModel.mm
rename to ios/RnExecutorch/models/style_transfer/StyleTransferModel.mm
index 6051e24b..6a351431 100644
--- a/ios/RnExecutorch/models/StyleTransferModel.mm
+++ b/ios/RnExecutorch/models/style_transfer/StyleTransferModel.mm
@@ -1,5 +1,5 @@
 #import "StyleTransferModel.h"
-#import "../utils/ImageProcessor.h"
+#import "../../utils/ImageProcessor.h"
 #import "opencv2/opencv.hpp"
 
 @implementation StyleTransferModel {
diff --git a/ios/RnExecutorch/utils/Conversions.h b/ios/RnExecutorch/utils/Conversions.h
new file mode 100644
index 00000000..a83ec5fb
--- /dev/null
+++ b/ios/RnExecutorch/utils/Conversions.h
@@ -0,0 +1,15 @@
+#import "opencv2/opencv.hpp"
+
+// Convert a matrix containing a single value per cell to an NSArray
+template <typename T>
+NSArray* simpleMatToNSArray(const cv::Mat& mat) {
+  std::size_t numPixels = mat.rows * mat.cols;
+  NSMutableArray *arr = [[NSMutableArray alloc] initWithCapacity:numPixels];
+
+  for (std::size_t x = 0; x < mat.rows; ++x) {
+    for (std::size_t y = 0; y < mat.cols; ++y) {
+      arr[x * mat.cols + y] = @(mat.at<T>(x, y));
+    }
+  }
+  return arr;
+}
diff --git a/ios/RnExecutorch/models/classification/Utils.h b/ios/RnExecutorch/utils/Numerical.h
similarity index 100%
rename from ios/RnExecutorch/models/classification/Utils.h
rename to ios/RnExecutorch/utils/Numerical.h
diff --git a/ios/RnExecutorch/models/classification/Utils.mm b/ios/RnExecutorch/utils/Numerical.mm
similarity index 100%
rename from ios/RnExecutorch/models/classification/Utils.mm
rename to ios/RnExecutorch/utils/Numerical.mm
diff --git a/src/hooks/computer_vision/useImageSegmentation.ts b/src/hooks/computer_vision/useImageSegmentation.ts
new file mode 100644
index 00000000..4e562d6b
--- /dev/null
+++ b/src/hooks/computer_vision/useImageSegmentation.ts
@@ -0,0 +1,68 @@
+import { useState } from 'react';
+import { _ImageSegmentationModule } from '../../native/RnExecutorchModules';
+import { ETError, getError } from '../../Error';
+import { useModule } from '../useModule';
+import { DeeplabLabel } from '../../types/image_segmentation';
+
+interface Props {
+  modelSource: string | number;
+}
+
+export const useImageSegmentation = ({
+  modelSource,
+}: Props): {
+  error: string | null;
+  isReady: boolean;
+  isGenerating: boolean;
+  downloadProgress: number;
+  forward: (
+    input: string,
+    classesOfInterest?: DeeplabLabel[],
+    resize?: boolean
+  ) => Promise<{ [key in DeeplabLabel]?: number[] }>;
+} => {
+  const [module, _] = useState(() => new _ImageSegmentationModule());
+  const [isGenerating, setIsGenerating] = useState(false);
+  const { error, isReady, downloadProgress } = useModule({
+    modelSource,
+    module,
+  });
+
+  const forward = async (
+    input: string,
+    classesOfInterest?: DeeplabLabel[],
+    resize?: boolean
+  ) => {
+    if (!isReady) {
+      throw new Error(getError(ETError.ModuleNotLoaded));
+    }
+    if (isGenerating) {
+      throw new Error(getError(ETError.ModelGenerating));
+    }
+
+    try {
+      setIsGenerating(true);
+      const stringDict = await module.forward(
+        input,
+        (classesOfInterest || []).map((label) => DeeplabLabel[label]),
+        resize || false
+      );
+
+      let enumDict: { [key in DeeplabLabel]?: number[] } = {};
+
+      for (const key in stringDict) {
+        if (key in DeeplabLabel) {
+          const enumKey = DeeplabLabel[key as keyof typeof DeeplabLabel];
+          enumDict[enumKey] = stringDict[key];
+        }
+      }
+      return enumDict;
+    } catch (e) {
+      throw new Error(getError(e));
+    } finally {
+      setIsGenerating(false);
+    }
+  };
+
+  return { error, isReady, isGenerating, downloadProgress, forward };
+};
diff --git a/src/index.tsx b/src/index.tsx
index 7ae7a7ad..c4ae2f55 100644
--- a/src/index.tsx
+++ b/src/index.tsx
@@ -2,6 +2,7 @@ export * from './hooks/computer_vision/useClassification';
 export * from './hooks/computer_vision/useObjectDetection';
 export * from './hooks/computer_vision/useStyleTransfer';
+export * from './hooks/computer_vision/useImageSegmentation';
 export * from './hooks/computer_vision/useOCR';
 export * from './hooks/computer_vision/useVerticalOCR';
 
@@ -14,6 +15,7 @@ export * from './hooks/general/useExecutorchModule';
 export * from './modules/computer_vision/ClassificationModule';
 export * from './modules/computer_vision/ObjectDetectionModule';
 export * from './modules/computer_vision/StyleTransferModule';
+export * from './modules/computer_vision/ImageSegmentationModule';
 export * from './modules/computer_vision/OCRModule';
 export * from './modules/computer_vision/VerticalOCRModule';
 
@@ -28,6 +30,7 @@ export * from './utils/listDownloadedResources';
 // types
 export * from './types/object_detection';
 export * from './types/ocr';
+export * from './types/image_segmentation';
 
 // constants
 export * from './constants/modelUrls';
diff --git a/src/modules/BaseModule.ts b/src/modules/BaseModule.ts
index e977836f..56cf2e3d 100644
--- a/src/modules/BaseModule.ts
+++ b/src/modules/BaseModule.ts
@@ -1,4 +1,5 @@
 import {
+  _ImageSegmentationModule,
   _StyleTransferModule,
   _ObjectDetectionModule,
   _ClassificationModule,
@@ -10,6 +11,7 @@ import { getError } from '../Error';
 
 export class BaseModule {
   static module:
+    | _ImageSegmentationModule
     | _StyleTransferModule
     | _ObjectDetectionModule
     | _ClassificationModule
diff --git a/src/modules/computer_vision/ImageSegmentationModule.ts b/src/modules/computer_vision/ImageSegmentationModule.ts
new file mode 100644
index 00000000..1d078c1c
--- /dev/null
+++ b/src/modules/computer_vision/ImageSegmentationModule.ts
@@ -0,0 +1,23 @@
+import { BaseModule } from '../BaseModule';
+import { _ImageSegmentationModule } from '../../native/RnExecutorchModules';
+import { getError } from '../../Error';
+
+export class ImageSegmentationModule extends BaseModule {
+  static module = new _ImageSegmentationModule();
+
+  static async forward(
+    input: string,
+    classesOfInterest: string[],
+    resize: boolean
+  ) {
+    try {
+      return await (this.module.forward(
+        input,
+        classesOfInterest,
+        resize
+      ) as ReturnType<_ImageSegmentationModule['forward']>);
+    } catch (e) {
+      throw new Error(getError(e));
+    }
+  }
+}
diff --git a/src/native/NativeImageSegmentation.ts b/src/native/NativeImageSegmentation.ts
new file mode 100644
index 00000000..c66c8743
--- /dev/null
+++ b/src/native/NativeImageSegmentation.ts
@@ -0,0 +1,14 @@
+import type { TurboModule } from 'react-native';
+import { TurboModuleRegistry } from 'react-native';
+
+export interface Spec extends TurboModule {
+  loadModule(modelSource: string): Promise<number>;
+
+  forward(
+    input: string,
+    classesOfInterest: string[],
+    resize: boolean
+  ): Promise<{ [category: string]: number[] }>;
+}
+
+export default TurboModuleRegistry.get<Spec>('ImageSegmentation');
diff --git a/src/native/RnExecutorchModules.ts b/src/native/RnExecutorchModules.ts
index b1edcf52..62ebd309 100644
--- a/src/native/RnExecutorchModules.ts
+++ b/src/native/RnExecutorchModules.ts
@@ -2,6 +2,7 @@ import { Platform } from 'react-native';
 import { Spec as ClassificationInterface } from './NativeClassification';
 import { Spec as ObjectDetectionInterface } from './NativeObjectDetection';
 import { Spec as StyleTransferInterface } from './NativeStyleTransfer';
+import { Spec as ImageSegmentationInterface } from './NativeImageSegmentation';
 import { Spec as ETModuleInterface } from './NativeETModule';
 import { Spec as OCRInterface } from './NativeOCR';
 import { Spec as VerticalOCRInterface } from './NativeVerticalOCR';
@@ -51,6 +52,19 @@ const Classification = ClassificationSpec
       }
     );
 
+const ImageSegmentationSpec = require('./NativeImageSegmentation').default;
+
+const ImageSegmentation = ImageSegmentationSpec
+  ? ImageSegmentationSpec
+  : new Proxy(
+      {},
+      {
+        get() {
+          throw new Error(LINKING_ERROR);
+        },
+      }
+    );
+
 const ObjectDetectionSpec = require('./NativeObjectDetection').default;
 
 const ObjectDetection = ObjectDetectionSpec
@@ -116,6 +130,21 @@ const VerticalOCR = VerticalOCRSpec
       }
     );
 
+class _ImageSegmentationModule {
+  async forward(
+    input: string,
+    classesOfInterest: string[],
+    resize: boolean
+  ): ReturnType<ImageSegmentationInterface['forward']> {
+    return await ImageSegmentation.forward(input, classesOfInterest, resize);
+  }
+  async loadModule(
+    modelSource: string | number
+  ): ReturnType<ImageSegmentationInterface['loadModule']> {
+    return await ImageSegmentation.loadModule(modelSource);
+  }
+}
+
 class _ObjectDetectionModule {
   async forward(
     input: string
@@ -239,12 +268,14 @@ export {
   Classification,
   ObjectDetection,
   StyleTransfer,
+  ImageSegmentation,
   SpeechToText,
   OCR,
   VerticalOCR,
   _ETModule,
   _ClassificationModule,
   _StyleTransferModule,
+  _ImageSegmentationModule,
   _ObjectDetectionModule,
   _SpeechToTextModule,
   _OCRModule,
diff --git a/src/types/image_segmentation.ts b/src/types/image_segmentation.ts
new file mode 100644
index 00000000..bc7d254d
--- /dev/null
+++ b/src/types/image_segmentation.ts
@@ -0,0 +1,24 @@
+export enum DeeplabLabel {
+  BACKGROUND,
+  AEROPLANE,
+  BICYCLE,
+  BIRD,
+  BOAT,
+  BOTTLE,
+  BUS,
+  CAR,
+  CAT,
+  CHAIR,
+  COW,
+  DININGTABLE,
+  DOG,
+  HORSE,
+  MOTORBIKE,
+  PERSON,
+  POTTEDPLANT,
+  SHEEP,
+  SOFA,
+  TRAIN,
+  TVMONITOR,
+  ARGMAX, // Additional label not present in the model
+}
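
Usage illustration (not part of the patch): a minimal TypeScript sketch of how the hook added above could be consumed from application code. The hook name, the DeeplabLabel values, and the shape of the returned dictionary come from the files in this diff; the package import name and the model URL below are assumptions for illustration only.

// Hypothetical usage sketch; the import path and model URL are placeholders.
import { useImageSegmentation, DeeplabLabel } from 'react-native-executorch';

function useCatMask(imageUri: string) {
  const { isReady, forward } = useImageSegmentation({
    // Assumption: a DeepLabV3 ResNet50 .pte export hosted at an illustrative URL.
    modelSource: 'https://example.com/deeplabv3_resnet50.pte',
  });

  const segment = async () => {
    if (!isReady) return null;
    // Ask only for the CAT scores and rescale the output back to the
    // original image resolution (resize = true).
    const result = await forward(imageUri, [DeeplabLabel.CAT], true);
    // ARGMAX holds the winning label index for every pixel;
    // CAT holds the per-pixel softmax score for that class.
    return {
      argMax: result[DeeplabLabel.ARGMAX],
      catScores: result[DeeplabLabel.CAT],
    };
  };

  return segment;
}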