Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Image segmentation for ios #113

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
5 changes: 5 additions & 0 deletions ios/RnExecutorch/ImageSegmentation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#import <RnExecutorchSpec/RnExecutorchSpec.h>

@interface ImageSegmentation : NSObject <NativeImageSegmentationSpec>

@end
60 changes: 60 additions & 0 deletions ios/RnExecutorch/ImageSegmentation.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#import "ImageSegmentation.h"
#import "models/image_segmentation/ImageSegmentationModel.h"
#import "models/BaseModel.h"
#import "utils/ETError.h"
#import <ExecutorchLib/ETModel.h>
#import <React/RCTBridgeModule.h>
#import <opencv2/opencv.hpp>
#import "ImageProcessor.h"

@implementation ImageSegmentation {
ImageSegmentationModel *model;
}

RCT_EXPORT_MODULE()

- (void)loadModule:(NSString *)modelSource
resolve:(RCTPromiseResolveBlock)resolve
reject:(RCTPromiseRejectBlock)reject {

model = [[ImageSegmentationModel alloc] init];
[model
loadModel:[NSURL URLWithString:modelSource]
completion:^(BOOL success, NSNumber *errorCode) {
if (success) {
resolve(errorCode);
return;
}

reject(@"init_module_error",
[NSString stringWithFormat:@"%ld", (long)[errorCode longValue]],
nil);
return;
}];
}

- (void)forward:(NSString *)input
classesOfInterest:(NSArray *)classesOfInterest
resolve:(RCTPromiseResolveBlock)resolve
reject:(RCTPromiseRejectBlock)reject {

@try {
cv::Mat image = [ImageProcessor readImage:input];
NSDictionary *result = [model runModel:image returnClasses:classesOfInterest];

resolve(result);
return;
} @catch (NSException *exception) {
NSLog(@"An exception occurred: %@, %@", exception.name, exception.reason);
reject(@"forward_error",
[NSString stringWithFormat:@"%@", exception.reason], nil);
return;
}
}

- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
(const facebook::react::ObjCTurboModule::InitParams &)params {
return std::make_shared<facebook::react::NativeImageSegmentationSpecJSI>(params);
}

@end
2 changes: 1 addition & 1 deletion ios/RnExecutorch/StyleTransfer.mm
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#import "StyleTransfer.h"
#import "ImageProcessor.h"
#import "models/BaseModel.h"
#import "models/StyleTransferModel.h"
#import "models/style_transfer/StyleTransferModel.h"
#import "utils/ETError.h"
#import <ExecutorchLib/ETModel.h>
#import <React/RCTBridgeModule.h>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#import "ClassificationModel.h"
#import "../../utils/ImageProcessor.h"
#import "../../utils/Numerical.h"
#import "Constants.h"
#import "Utils.h"
#import "opencv2/opencv.hpp"

@implementation ClassificationModel
Expand Down
5 changes: 5 additions & 0 deletions ios/RnExecutorch/models/image_segmentation/Constants.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#import <string>
#import <vector>


extern const std::vector<std::string> deeplabv3_resnet50_labels;
10 changes: 10 additions & 0 deletions ios/RnExecutorch/models/image_segmentation/Constants.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#import "Constants.h"
#import <string>
#import <vector>

const std::vector<std::string> deeplabv3_resnet50_labels = {
"background", "aeroplane", "bicycle", "bird", "boat",
"bottle", "bus", "car", "cat", "chair", "cow", "diningtable",
"dog", "horse", "motorbike", "person", "pottedplant", "sheep",
"sofa", "train", "tvmonitor"
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#import "../BaseModel.h"
#import "opencv2/opencv.hpp"

@interface ImageSegmentationModel : BaseModel
- (cv::Size)getModelImageSize;
- (NSDictionary *)runModel:(cv::Mat &)input
returnClasses:(NSArray *)classesOfInterest;

@end

template <typename T>
NSArray* matToNSArray(const cv::Mat& mat) {
std::size_t numPixels = mat.rows * mat.cols;
NSMutableArray *arr = [[NSMutableArray alloc] initWithCapacity:numPixels];

for (std::size_t x = 0; x < mat.rows; ++x) {
for (std::size_t y = 0; y < mat.cols; ++y) {
arr[x * mat.cols + y] = @(mat.at<T>(x, y));
}
}
return arr;
}
137 changes: 137 additions & 0 deletions ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
#import "ImageSegmentationModel.h"
#import <unordered_set>
#import "../../utils/ImageProcessor.h"
#import "../../utils/Numerical.h"
#import "opencv2/opencv.hpp"
#import "Constants.h"

@interface ImageSegmentationModel ()
- (NSArray *)preprocess:(cv::Mat &)input;
- (NSDictionary *)postprocess:(NSArray *)output
returnClasses:(NSArray *)classesOfInterest;
@end

@implementation ImageSegmentationModel {
cv::Size originalSize;
}

- (cv::Size)getModelImageSize {
NSArray *inputShape = [module getInputShape:@0];
NSNumber *widthNumber = inputShape.lastObject;
NSNumber *heightNumber = inputShape[inputShape.count - 2];

int height = [heightNumber intValue];
int width = [widthNumber intValue];

return cv::Size(height, width);
}

- (NSArray *)preprocess:(cv::Mat &)input {
originalSize = cv::Size(input.cols, input.rows);

cv::Size modelImageSize = [self getModelImageSize];
cv::Mat output;
cv::resize(input, output, modelImageSize);

NSArray *modelInput = [ImageProcessor matToNSArray:output];
return modelInput;
}

std::vector<cv::Mat> rescaleResults(NSArray *result, std::size_t numLabels,
cv::Size modelImageSize, cv::Size originalSize) {
std::size_t numModelPixels = modelImageSize.height * modelImageSize.width;

std::vector<cv::Mat> resizedLabelScores(numLabels);
for (std::size_t label = 0; label < numLabels; ++label) {
cv::Mat labelMat = cv::Mat(modelImageSize, CV_64F);

for(std::size_t pixel = 0; pixel < numModelPixels; ++pixel){
int row = pixel / modelImageSize.width;
int col = pixel % modelImageSize.width;
labelMat.at<double>(row, col) = [result[label * numModelPixels + pixel] doubleValue];
}

cv::resize(labelMat, resizedLabelScores[label], originalSize);
}
return resizedLabelScores;
}

void adjustScoresPerPixel(std::vector<cv::Mat>& labelScores, cv::Mat& maxArg,
cv::Size originalSize, std::size_t numLabels) {
std::size_t numOriginalPixels = originalSize.height * originalSize.width;
for (std::size_t pixel = 0; pixel < numOriginalPixels; ++pixel) {
int row = pixel / originalSize.width;
int col = pixel % originalSize.width;
std::vector<double> scores;
scores.reserve(numLabels);
for (const cv::Mat& mat : labelScores) {
scores.push_back(mat.at<double>(row, col));
}

std::vector<double> adjustedScores = softmax(scores);

std::size_t maxArgIndex = 0;
double maxArgVal = 0;
for (std::size_t label = 0; label < numLabels; ++label) {
labelScores[label].at<double>(row, col) = adjustedScores[label];
if (adjustedScores[label] > maxArgVal) {
maxArgIndex = label;
maxArgVal = adjustedScores[label];
}
}

maxArg.at<int>(row, col) = maxArgIndex;
}
}

- (NSDictionary *)postprocess:(NSArray *)output
returnClasses:(NSArray *)classesOfInterest{
cv::Size modelImageSize = [self getModelImageSize];

std::size_t numLabels = deeplabv3_resnet50_labels.size();

NSAssert((std::size_t)output.count == numLabels * modelImageSize.height * modelImageSize.width,
@"Model generated unexpected output size.");

// For each label extract it's matrix and rescale it to the original size
std::vector<cv::Mat> resizedLabelScores =
rescaleResults(output, numLabels, modelImageSize, originalSize);

cv::Mat maxArg = cv::Mat(originalSize, CV_32S);

// For each pixel apply softmax across all the labels and calculate the maxArg
adjustScoresPerPixel(resizedLabelScores, maxArg, originalSize, numLabels);

std::unordered_set<std::string> labelSet;

for (id label in classesOfInterest) {
labelSet.insert(std::string([label UTF8String]));
}

NSMutableDictionary *result = [NSMutableDictionary dictionary];

// Convert to NSArray and populate the final dictionary
for (std::size_t label = 0; label < numLabels; ++label) {
if (labelSet.contains(deeplabv3_resnet50_labels[label])){
NSString *labelString = @(deeplabv3_resnet50_labels[label].c_str());
NSArray *arr = matToNSArray<double>(resizedLabelScores[label]);
result[labelString] = arr;
}
}

result[@"argmax"] = matToNSArray<int>(maxArg);

return result;
}

- (NSDictionary *)runModel:(cv::Mat &)input
returnClasses:(NSArray *)classesOfInterest {
NSArray *modelInput = [self preprocess:input];
NSArray *result = [self forward:modelInput];

NSDictionary *output = [self postprocess:result[0] returnClasses:classesOfInterest];

return output;
}

@end
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#import "BaseModel.h"
#import "../BaseModel.h"
#import "opencv2/opencv.hpp"

@interface StyleTransferModel : BaseModel
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#import "StyleTransferModel.h"
#import "../utils/ImageProcessor.h"
#import "../../utils/ImageProcessor.h"
#import "opencv2/opencv.hpp"

@implementation StyleTransferModel {
Expand Down
23 changes: 23 additions & 0 deletions src/constants/image_segmentation/image_segmentation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
export const classLabels = new Map<number, string>([
[0, 'background'],
[1, 'aeroplane'],
[2, 'bicycle'],
[3, 'bird'],
[4, 'boat'],
[5, 'bottle'],
[6, 'bus'],
[7, 'car'],
[8, 'cat'],
[9, 'chair'],
[10, 'cow'],
[11, 'diningtable'],
[12, 'dog'],
[13, 'horse'],
[14, 'motorbike'],
[15, 'person'],
[16, 'pottedplant'],
[17, 'sheep'],
[18, 'sofa'],
[19, 'train'],
[20, 'tvmonitor'],
]);
65 changes: 65 additions & 0 deletions src/hooks/computer_vision/useImageSegmentation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import { useState, useEffect } from 'react';
import { _ImageSegmentationModule } from '../../native/RnExecutorchModules';
import { fetchResource } from '../../utils/fetchResource';
import { ETError, getError } from '../../Error';

interface Props {
modelSource: string | number;
}

export const useImageSegmentation = ({
modelSource,
}: Props): {
error: string | null;
isReady: boolean;
isGenerating: boolean;
downloadProgress: number;
forward: (
input: string,
classesOfInterest?: string[]
) => Promise<{ [category: string]: number[] }>;
} => {
const [module, _] = useState(() => new _ImageSegmentationModule());
const [error, setError] = useState<null | string>(null);
const [isReady, setIsReady] = useState(false);
const [downloadProgress, setDownloadProgress] = useState(0);
const [isGenerating, setIsGenerating] = useState(false);

useEffect(() => {
const loadModel = async () => {
if (!modelSource) return;

try {
setIsReady(false);
const fileUri = await fetchResource(modelSource, setDownloadProgress);
await module.loadModule(fileUri);
setIsReady(true);
} catch (e) {
setError(getError(e));
}
};

loadModel();
}, [modelSource, module]);

const forward = async (input: string, classesOfInterest?: string[]) => {
if (!isReady) {
throw new Error(getError(ETError.ModuleNotLoaded));
}
if (isGenerating) {
throw new Error(getError(ETError.ModelGenerating));
}

try {
setIsGenerating(true);
const output = await module.forward(input, classesOfInterest || []);
return output;
} catch (e) {
throw new Error(getError(e));
} finally {
setIsGenerating(false);
}
};

return { error, isReady, isGenerating, downloadProgress, forward };
};
2 changes: 2 additions & 0 deletions src/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
export * from './hooks/computer_vision/useClassification';
export * from './hooks/computer_vision/useObjectDetection';
export * from './hooks/computer_vision/useStyleTransfer';
export * from './hooks/computer_vision/useImageSegmentation';
export * from './hooks/computer_vision/useOCR';
export * from './hooks/computer_vision/useVerticalOCR';

Expand All @@ -13,6 +14,7 @@ export * from './hooks/general/useExecutorchModule';
export * from './modules/computer_vision/ClassificationModule';
export * from './modules/computer_vision/ObjectDetectionModule';
export * from './modules/computer_vision/StyleTransferModule';
export * from './modules/computer_vision/ImageSegmentationModule';
export * from './modules/computer_vision/OCRModule';
export * from './modules/computer_vision/VerticalOCRModule';

Expand Down
2 changes: 2 additions & 0 deletions src/modules/BaseModule.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import {
_ImageSegmentationModule,
_StyleTransferModule,
_ObjectDetectionModule,
_ClassificationModule,
Expand All @@ -10,6 +11,7 @@ import { getError } from '../Error';

export class BaseModule {
static module:
| _ImageSegmentationModule
| _StyleTransferModule
| _ObjectDetectionModule
| _ClassificationModule
Expand Down
Loading
Loading