Skip to content

Commit 5ef4c4d

Browse files
committed
Add a way to filter what segmentation classes are returned
1 parent 6e02df8 commit 5ef4c4d

File tree

8 files changed

+91
-34
lines changed

8 files changed

+91
-34
lines changed

ios/RnExecutorch/ImageSegmentation.mm

+2-1
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,13 @@ - (void)loadModule:(NSString *)modelSource
3434
}
3535

3636
- (void)forward:(NSString *)input
37+
classesOfInterest:(NSArray *)classesOfInterest
3738
resolve:(RCTPromiseResolveBlock)resolve
3839
reject:(RCTPromiseRejectBlock)reject {
3940

4041
@try {
4142
cv::Mat image = [ImageProcessor readImage:input];
42-
NSDictionary *result= [model runModel:image];
43+
NSDictionary *result = [model runModel:image returnClasses:classesOfInterest];
4344

4445
resolve(result);
4546
return;

ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
@interface ImageSegmentationModel : BaseModel
55
- (cv::Size)getModelImageSize;
6-
- (NSDictionary *)runModel:(cv::Mat &)input;
6+
- (NSDictionary *)runModel:(cv::Mat &)input
7+
returnClasses:(NSArray *)classesOfInterest;
78

89
@end
910

ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.mm

+20-8
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
#import "ImageSegmentationModel.h"
2+
#import <unordered_set>
23
#import "../../utils/ImageProcessor.h"
34
#import "../../utils/Numerical.h"
45
#import "opencv2/opencv.hpp"
56
#import "Constants.h"
67

78
@interface ImageSegmentationModel ()
89
- (NSArray *)preprocess:(cv::Mat &)input;
9-
- (NSDictionary *)postprocess:(NSArray *)output;
10+
- (NSDictionary *)postprocess:(NSArray *)output
11+
returnClasses:(NSArray *)classesOfInterest;
1012
@end
1113

1214
@implementation ImageSegmentationModel {
@@ -35,7 +37,8 @@ - (NSArray *)preprocess:(cv::Mat &)input {
3537
return modelInput;
3638
}
3739

38-
- (NSDictionary *)postprocess:(NSArray *)output {
40+
- (NSDictionary *)postprocess:(NSArray *)output
41+
returnClasses:(NSArray *)classesOfInterest{
3942
cv::Size modelImageSize = [self getModelImageSize];
4043

4144
std::size_t numLabels = deeplabv3_resnet50_labels.size();
@@ -87,25 +90,34 @@ - (NSDictionary *)postprocess:(NSArray *)output {
8790
maxArg.at<int>(row, col) = maxArgIndex;
8891
}
8992

93+
std::unordered_set<std::string> labelSet;
94+
95+
for (id label in classesOfInterest) {
96+
labelSet.insert(std::string([label UTF8String]));
97+
}
98+
9099
NSMutableDictionary *result = [NSMutableDictionary dictionary];
91-
100+
92101
// Convert to NSArray and populate the final dictionary
93102
for (std::size_t label = 0; label < numLabels; ++label) {
94-
NSString *labelString = @(deeplabv3_resnet50_labels[label].c_str());
95-
NSMutableArray *arr = matToNSArray<double>(resizedLabelScores[label]);
96-
result[labelString] = arr;
103+
if (labelSet.contains(deeplabv3_resnet50_labels[label])){
104+
NSString *labelString = @(deeplabv3_resnet50_labels[label].c_str());
105+
NSArray *arr = matToNSArray<double>(resizedLabelScores[label]);
106+
result[labelString] = arr;
107+
}
97108
}
98109

99110
result[@"argmax"] = matToNSArray<int>(maxArg);
100111

101112
return result;
102113
}
103114

104-
- (NSDictionary *)runModel:(cv::Mat &)input {
115+
/// Runs the whole segmentation pipeline for a single image.
///
/// The image is preprocessed, pushed through the model's forward pass, and the
/// first forward output is handed to postprocessing together with
/// `classesOfInterest`, which is passed along unchanged.
///
/// @param input The image to segment.
/// @param classesOfInterest Labels forwarded to the postprocessing step.
/// @return The dictionary produced by postprocessing.
- (NSDictionary *)runModel:(cv::Mat &)input
             returnClasses:(NSArray *)classesOfInterest {
  NSArray *forwardOutputs = [self forward:[self preprocess:input]];
  return [self postprocess:forwardOutputs[0] returnClasses:classesOfInterest];
}
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
import { useState } from 'react';
1+
import { useState, useEffect } from 'react';
22
import { _ImageSegmentationModule } from '../../native/RnExecutorchModules';
3-
import { useModule } from '../useModule';
3+
import { fetchResource } from '../../utils/fetchResource';
4+
import { ETError, getError } from '../../Error';
45

56
interface Props {
67
modelSource: string | number;
@@ -13,19 +14,52 @@ export const useImageSegmentation = ({
1314
isReady: boolean;
1415
isGenerating: boolean;
1516
downloadProgress: number;
16-
forward: (input: string) => Promise<{ [category: string]: number[] }>;
17+
forward: (
18+
input: string,
19+
classesOfInterest?: string[]
20+
) => Promise<{ [category: string]: number[] }>;
1721
} => {
1822
const [module, _] = useState(() => new _ImageSegmentationModule());
19-
const {
20-
error,
21-
isReady,
22-
isGenerating,
23-
downloadProgress,
24-
forwardImage: forward,
25-
} = useModule({
26-
modelSource,
27-
module,
28-
});
23+
const [error, setError] = useState<null | string>(null);
24+
const [isReady, setIsReady] = useState(false);
25+
const [downloadProgress, setDownloadProgress] = useState(0);
26+
const [isGenerating, setIsGenerating] = useState(false);
27+
28+
useEffect(() => {
29+
const loadModel = async () => {
30+
if (!modelSource) return;
31+
32+
try {
33+
setIsReady(false);
34+
const fileUri = await fetchResource(modelSource, setDownloadProgress);
35+
await module.loadModule(fileUri);
36+
setIsReady(true);
37+
} catch (e) {
38+
setError(getError(e));
39+
}
40+
};
41+
42+
loadModel();
43+
}, [modelSource, module]);
44+
45+
/**
 * Runs the native segmentation module on `input`.
 *
 * Throws when the model has not finished loading or when another call is
 * still in flight. `classesOfInterest` falls back to an empty array when it
 * is omitted. Any failure from the native call is re-thrown as an `Error`
 * whose message comes from `getError`.
 */
const forward = async (input: string, classesOfInterest?: string[]) => {
  if (!isReady) throw new Error(getError(ETError.ModuleNotLoaded));
  if (isGenerating) throw new Error(getError(ETError.ModelGenerating));

  try {
    setIsGenerating(true);
    return await module.forward(input, classesOfInterest || []);
  } catch (cause) {
    throw new Error(getError(cause));
  } finally {
    // Always release the in-flight flag, whether the call succeeded or not.
    setIsGenerating(false);
  }
};
2963

3064
return { error, isReady, isGenerating, downloadProgress, forward };
3165
};

src/modules/computer_vision/BaseCVModule.ts

-2
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,13 @@ import {
33
_StyleTransferModule,
44
_ObjectDetectionModule,
55
_ClassificationModule,
6-
_ImageSegmentationModule,
76
} from '../../native/RnExecutorchModules';
87
import { getError } from '../../Error';
98

109
export class BaseCVModule extends BaseModule {
1110
static module:
1211
| _StyleTransferModule
1312
| _ObjectDetectionModule
14-
| _ImageSegmentationModule
1513
| _ClassificationModule;
1614

1715
static async forward(input: string) {
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
1-
import { BaseCVModule } from './BaseCVModule';
1+
import { BaseModule } from '../BaseModule';
22
import { _ImageSegmentationModule } from '../../native/RnExecutorchModules';
3+
import { getError } from '../../Error';
34

4-
export class ImageSegmentationModule extends BaseCVModule {
5+
/**
 * Static wrapper around the native image segmentation module.
 */
export class ImageSegmentationModule extends BaseModule {
  static module = new _ImageSegmentationModule();

  /**
   * Runs segmentation on `input`.
   *
   * @param input Image passed straight through to the native module.
   * @param classesOfInterest Optional class labels forwarded to the native
   *   module; an empty array is sent when the argument is omitted.
   * @returns Whatever the native `forward` call resolves with.
   * @throws Error whose message is produced by `getError` when the native
   *   call fails.
   */
  static async forward(input: string, classesOfInterest?: string[]) {
    try {
      return await (this.module.forward(
        input,
        classesOfInterest || []
      ) as ReturnType<_ImageSegmentationModule['forward']>);
    } catch (e) {
      throw new Error(getError(e));
    }
  }
}

src/native/NativeImageSegmentation.ts

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ import { TurboModuleRegistry } from 'react-native';
44
export interface Spec extends TurboModule {
55
loadModule(modelSource: string): Promise<number>;
66

7-
forward(input: string): Promise<{ [category: string]: number[] }>;
7+
forward(
8+
input: string,
9+
classesOfInterest: string[]
10+
): Promise<{ [category: string]: number[] }>;
811
}
912

1013
export default TurboModuleRegistry.get<Spec>('ImageSegmentation');

src/native/RnExecutorchModules.ts

+3-2
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,10 @@ const OCR = OCRSpec
117117

118118
class _ImageSegmentationModule {
119119
/**
 * Passes `input` and `classesOfInterest` to the native `ImageSegmentation`
 * module and resolves with its result.
 */
async forward(
  input: string,
  classesOfInterest: string[]
): ReturnType<ImageSegmentationInterface['forward']> {
  return await ImageSegmentation.forward(input, classesOfInterest);
}
124125
async loadModule(
125126
modelSource: string | number

0 commit comments

Comments
 (0)