Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Image segmentation for ios #113

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
5 changes: 5 additions & 0 deletions ios/RnExecutorch/ImageSegmentation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#import <RnExecutorchSpec/RnExecutorchSpec.h>

@interface ImageSegmentation : NSObject <NativeImageSegmentationSpec>

@end
63 changes: 63 additions & 0 deletions ios/RnExecutorch/ImageSegmentation.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#import "ImageSegmentation.h"
#import "models/image_segmentation/ImageSegmentationModel.h"
#import "models/BaseModel.h"
#import "utils/ETError.h"
#import <ExecutorchLib/ETModel.h>
#import <React/RCTBridgeModule.h>
#import <opencv2/opencv.hpp>
#import "ImageProcessor.h"

@implementation ImageSegmentation {
ImageSegmentationModel *model;
}

RCT_EXPORT_MODULE()

- (void)loadModule:(NSString *)modelSource
resolve:(RCTPromiseResolveBlock)resolve
reject:(RCTPromiseRejectBlock)reject {

model = [[ImageSegmentationModel alloc] init];
[model
loadModel:[NSURL URLWithString:modelSource]
completion:^(BOOL success, NSNumber *errorCode) {
if (success) {
resolve(errorCode);
return;
}

reject(@"init_module_error",
[NSString stringWithFormat:@"%ld", (long)[errorCode longValue]],
nil);
return;
}];
}

- (void)forward:(NSString *)input
classesOfInterest:(NSArray *)classesOfInterest
resize:(BOOL)resize
resolve:(RCTPromiseResolveBlock)resolve
reject:(RCTPromiseRejectBlock)reject {

@try {
cv::Mat image = [ImageProcessor readImage:input];
NSDictionary *result = [model runModel:image
returnClasses:classesOfInterest
resize:resize];

resolve(result);
return;
} @catch (NSException *exception) {
NSLog(@"An exception occurred: %@, %@", exception.name, exception.reason);
reject(@"forward_error",
[NSString stringWithFormat:@"%@", exception.reason], nil);
return;
}
}

- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
(const facebook::react::ObjCTurboModule::InitParams &)params {
return std::make_shared<facebook::react::NativeImageSegmentationSpecJSI>(params);
}

@end
2 changes: 1 addition & 1 deletion ios/RnExecutorch/StyleTransfer.mm
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#import "StyleTransfer.h"
#import "ImageProcessor.h"
#import "models/BaseModel.h"
#import "models/StyleTransferModel.h"
#import "models/style_transfer/StyleTransferModel.h"
#import "utils/ETError.h"
#import <ExecutorchLib/ETModel.h>
#import <React/RCTBridgeModule.h>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#import "ClassificationModel.h"
#import "../../utils/ImageProcessor.h"
#import "../../utils/Numerical.h"
#import "Constants.h"
#import "Utils.h"
#import "opencv2/opencv.hpp"

@implementation ClassificationModel
Expand Down
5 changes: 5 additions & 0 deletions ios/RnExecutorch/models/image_segmentation/Constants.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#import <string>
#import <vector>


extern const std::vector<std::string> deeplabv3_resnet50_labels;
10 changes: 10 additions & 0 deletions ios/RnExecutorch/models/image_segmentation/Constants.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#import "Constants.h"
#import <string>
#import <vector>

const std::vector<std::string> deeplabv3_resnet50_labels = {
"BACKGROUND", "AEROPLANE", "BICYCLE", "BIRD", "BOAT",
"BOTTLE", "BUS", "CAR", "CAT", "CHAIR", "COW", "DININGTABLE",
"DOG", "HORSE", "MOTORBIKE", "PERSON", "POTTEDPLANT", "SHEEP",
"SOFA", "TRAIN", "TVMONITOR"
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#import "../BaseModel.h"
#import "opencv2/opencv.hpp"

@interface ImageSegmentationModel : BaseModel
- (cv::Size)getModelImageSize;
- (NSDictionary *)runModel:(cv::Mat &)input
returnClasses:(NSArray *)classesOfInterest
resize:(BOOL)resize;

@end
147 changes: 147 additions & 0 deletions ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
#import "ImageSegmentationModel.h"
#import <unordered_set>
#import <algorithm>
#import <vector>
#import "../../utils/ImageProcessor.h"
#import "../../utils/Numerical.h"
#import "../../utils/Conversions.h"
#import "opencv2/opencv.hpp"
#import "Constants.h"

@interface ImageSegmentationModel ()
- (NSArray *)preprocess:(cv::Mat &)input;
- (NSDictionary *)postprocess:(NSArray *)output
returnClasses:(NSArray *)classesOfInterest
resize:(BOOL)resize;
@end

@implementation ImageSegmentationModel {
cv::Size originalSize;
}

- (cv::Size)getModelImageSize {
NSArray *inputShape = [module getInputShape:@0];
NSNumber *widthNumber = inputShape.lastObject;
NSNumber *heightNumber = inputShape[inputShape.count - 2];

int height = [heightNumber intValue];
int width = [widthNumber intValue];

return cv::Size(height, width);
}

- (NSArray *)preprocess:(cv::Mat &)input {
originalSize = cv::Size(input.cols, input.rows);

cv::Size modelImageSize = [self getModelImageSize];
cv::Mat output;
cv::resize(input, output, modelImageSize);

NSArray *modelInput = [ImageProcessor matToNSArray:output];
return modelInput;
}

std::vector<cv::Mat> extractResults(NSArray *result, std::size_t numLabels,
cv::Size modelImageSize, cv::Size originalSize, BOOL resize) {
std::size_t numModelPixels = modelImageSize.height * modelImageSize.width;

std::vector<cv::Mat> resizedLabelScores(numLabels);
for (std::size_t label = 0; label < numLabels; ++label) {
cv::Mat labelMat = cv::Mat(modelImageSize, CV_64F);

for(std::size_t pixel = 0; pixel < numModelPixels; ++pixel){
int row = pixel / modelImageSize.width;
int col = pixel % modelImageSize.width;
labelMat.at<double>(row, col) = [result[label * numModelPixels + pixel] doubleValue];
}

if (resize) {
cv::resize(labelMat, resizedLabelScores[label], originalSize);
}
else {
resizedLabelScores[label] = std::move(labelMat);
}
}
return resizedLabelScores;
}

void adjustScoresPerPixel(std::vector<cv::Mat>& labelScores, cv::Mat& argMax,
cv::Size outputSize, std::size_t numLabels) {
std::size_t numOutputPixels = outputSize.height * outputSize.width;
for (std::size_t pixel = 0; pixel < numOutputPixels; ++pixel) {
int row = pixel / outputSize.width;
int col = pixel % outputSize.width;
std::vector<double> scores;
scores.reserve(numLabels);
for (const auto& mat : labelScores) {
scores.push_back(mat.at<double>(row, col));
}

std::vector<double> adjustedScores = softmax(scores);

for (std::size_t label = 0; label < numLabels; ++label) {
labelScores[label].at<double>(row, col) = adjustedScores[label];
}

auto maxIt = std::max_element(scores.begin(), scores.end());
argMax.at<int>(row, col) = std::distance(scores.begin(), maxIt);
}
}

- (NSDictionary *)postprocess:(NSArray *)output
returnClasses:(NSArray *)classesOfInterest
resize:(BOOL)resize {
cv::Size modelImageSize = [self getModelImageSize];

std::size_t numLabels = deeplabv3_resnet50_labels.size();

NSAssert((std::size_t)output.count == numLabels * modelImageSize.height * modelImageSize.width,
@"Model generated unexpected output size.");

// For each label extract it's matrix,
// and rescale it to the original size if `resize`
std::vector<cv::Mat> resizedLabelScores =
extractResults(output, numLabels, modelImageSize, originalSize, resize);

cv::Size outputSize = resize ? originalSize : modelImageSize;
cv::Mat argMax = cv::Mat(outputSize, CV_32S);

// For each pixel apply softmax across all the labels and calculate the argMax
adjustScoresPerPixel(resizedLabelScores, argMax, outputSize, numLabels);

std::unordered_set<std::string> labelSet;

for (id label in classesOfInterest) {
labelSet.insert(std::string([label UTF8String]));
}

NSMutableDictionary *result = [NSMutableDictionary dictionary];

// Convert to NSArray and populate the final dictionary
for (std::size_t label = 0; label < numLabels; ++label) {
if (labelSet.contains(deeplabv3_resnet50_labels[label])){
NSString *labelString = @(deeplabv3_resnet50_labels[label].c_str());
NSArray *arr = simpleMatToNSArray<double>(resizedLabelScores[label]);
result[labelString] = arr;
}
}

result[@"ARGMAX"] = simpleMatToNSArray<int>(argMax);

return result;
}

- (NSDictionary *)runModel:(cv::Mat &)input
returnClasses:(NSArray *)classesOfInterest
resize:(BOOL)resize {
NSArray *modelInput = [self preprocess:input];
NSArray *result = [self forward:modelInput];

NSDictionary *output = [self postprocess:result[0]
returnClasses:classesOfInterest
resize:resize];

return output;
}

@end
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#import "BaseModel.h"
#import "../BaseModel.h"
#import "opencv2/opencv.hpp"

@interface StyleTransferModel : BaseModel
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#import "StyleTransferModel.h"
#import "../utils/ImageProcessor.h"
#import "../../utils/ImageProcessor.h"
#import "opencv2/opencv.hpp"

@implementation StyleTransferModel {
Expand Down
15 changes: 15 additions & 0 deletions ios/RnExecutorch/utils/Conversions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#import "opencv2/opencv.hpp"

// Convert a matrix containing a single value per cell to a NSArray
template <typename T>
NSArray* simpleMatToNSArray(const cv::Mat& mat) {
std::size_t numPixels = mat.rows * mat.cols;
NSMutableArray *arr = [[NSMutableArray alloc] initWithCapacity:numPixels];

for (std::size_t x = 0; x < mat.rows; ++x) {
for (std::size_t y = 0; y < mat.cols; ++y) {
arr[x * mat.cols + y] = @(mat.at<T>(x, y));
}
}
return arr;
}
68 changes: 68 additions & 0 deletions src/hooks/computer_vision/useImageSegmentation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import { useState } from 'react';
import { _ImageSegmentationModule } from '../../native/RnExecutorchModules';
import { ETError, getError } from '../../Error';
import { useModule } from '../useModule';
import { DeeplabLabel } from '../../types/image_segmentation';

interface Props {
modelSource: string | number;
}

export const useImageSegmentation = ({
modelSource,
}: Props): {
error: string | null;
isReady: boolean;
isGenerating: boolean;
downloadProgress: number;
forward: (
input: string,
classesOfInterest?: DeeplabLabel[],
resize?: boolean
) => Promise<{ [key in DeeplabLabel]?: number[] }>;
} => {
const [module, _] = useState(() => new _ImageSegmentationModule());
const [isGenerating, setIsGenerating] = useState(false);
const { error, isReady, downloadProgress } = useModule({
modelSource,
module,
});

const forward = async (
input: string,
classesOfInterest?: DeeplabLabel[],
resize?: boolean
) => {
if (!isReady) {
throw new Error(getError(ETError.ModuleNotLoaded));
}
if (isGenerating) {
throw new Error(getError(ETError.ModelGenerating));
}

try {
setIsGenerating(true);
const stringDict = await module.forward(
input,
(classesOfInterest || []).map((label) => DeeplabLabel[label]),
resize || false
);

let enumDict: { [key in DeeplabLabel]?: number[] } = {};

for (const key in stringDict) {
if (key in DeeplabLabel) {
const enumKey = DeeplabLabel[key as keyof typeof DeeplabLabel];
enumDict[enumKey] = stringDict[key];
}
}
return enumDict;
} catch (e) {
throw new Error(getError(e));
} finally {
setIsGenerating(false);
}
};

return { error, isReady, isGenerating, downloadProgress, forward };
};
3 changes: 3 additions & 0 deletions src/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
export * from './hooks/computer_vision/useClassification';
export * from './hooks/computer_vision/useObjectDetection';
export * from './hooks/computer_vision/useStyleTransfer';
export * from './hooks/computer_vision/useImageSegmentation';
export * from './hooks/computer_vision/useOCR';
export * from './hooks/computer_vision/useVerticalOCR';

Expand All @@ -13,6 +14,7 @@ export * from './hooks/general/useExecutorchModule';
export * from './modules/computer_vision/ClassificationModule';
export * from './modules/computer_vision/ObjectDetectionModule';
export * from './modules/computer_vision/StyleTransferModule';
export * from './modules/computer_vision/ImageSegmentationModule';
export * from './modules/computer_vision/OCRModule';
export * from './modules/computer_vision/VerticalOCRModule';

Expand All @@ -26,6 +28,7 @@ export * from './utils/listDownloadedResources';
// types
export * from './types/object_detection';
export * from './types/ocr';
export * from './types/image_segmentation';

// constants
export * from './constants/modelUrls';
Loading