Skip to content

Commit 3d03e08

Browse files
committed
Add label enum to segmentation I/O
1 parent f0b12ea commit 3d03e08

File tree

6 files changed

+57
-46
lines changed

6 files changed

+57
-46
lines changed

ios/RnExecutorch/models/image_segmentation/ImageSegmentationModel.mm

+10-13
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#import "ImageSegmentationModel.h"
22
#import <unordered_set>
3-
#import "../../utils/ImageProcessor.h"
3+
#import <algorithm>
4+
#import <vector>
5+
#i\port "../../utils/ImageProcessor.h"
46
#import "../../utils/Numerical.h"
57
#import "../../utils/Conversions.h"
68
#import "opencv2/opencv.hpp"
@@ -57,7 +59,7 @@ - (NSArray *)preprocess:(cv::Mat &)input {
5759
return resizedLabelScores;
5860
}
5961

60-
void adjustScoresPerPixel(std::vector<cv::Mat>& labelScores, cv::Mat& maxArg,
62+
void adjustScoresPerPixel(std::vector<cv::Mat>& labelScores, cv::Mat& argMax,
6163
cv::Size originalSize, std::size_t numLabels) {
6264
std::size_t numOriginalPixels = originalSize.height * originalSize.width;
6365
for (std::size_t pixel = 0; pixel < numOriginalPixels; ++pixel) {
@@ -71,17 +73,12 @@ void adjustScoresPerPixel(std::vector<cv::Mat>& labelScores, cv::Mat& maxArg,
7173

7274
std::vector<double> adjustedScores = softmax(scores);
7375

74-
std::size_t maxArgIndex = 0;
75-
double maxArgVal = 0;
7676
for (std::size_t label = 0; label < numLabels; ++label) {
7777
labelScores[label].at<double>(row, col) = adjustedScores[label];
78-
if (adjustedScores[label] > maxArgVal) {
79-
maxArgIndex = label;
80-
maxArgVal = adjustedScores[label];
81-
}
8278
}
8379

84-
maxArg.at<int>(row, col) = maxArgIndex;
80+
auto maxIt = std::max_element(scores.begin(), scores.end());
81+
argMax.at<int>(row, col) = std::distance(scores.begin(), maxIt);
8582
}
8683
}
8784

@@ -98,10 +95,10 @@ - (NSDictionary *)postprocess:(NSArray *)output
9895
std::vector<cv::Mat> resizedLabelScores =
9996
rescaleResults(output, numLabels, modelImageSize, originalSize);
10097

101-
cv::Mat maxArg = cv::Mat(originalSize, CV_32S);
98+
cv::Mat argMax = cv::Mat(originalSize, CV_32S);
10299

103-
// For each pixel apply softmax across all the labels and calculate the maxArg
104-
adjustScoresPerPixel(resizedLabelScores, maxArg, originalSize, numLabels);
100+
// For each pixel apply softmax across all the labels and calculate the argMax
101+
adjustScoresPerPixel(resizedLabelScores, argMax, originalSize, numLabels);
105102

106103
std::unordered_set<std::string> labelSet;
107104

@@ -120,7 +117,7 @@ - (NSDictionary *)postprocess:(NSArray *)output
120117
}
121118
}
122119

123-
result[@"argmax"] = simpleMatToNSArray<int>(maxArg);
120+
result[@"argmax"] = simpleMatToNSArray<int>(argMax);
124121

125122
return result;
126123
}

src/constants/image_segmentation/image_segmentation.ts

-23
This file was deleted.

src/hooks/computer_vision/useImageSegmentation.ts

+18-5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import { useState } from 'react';
22
import { _ImageSegmentationModule } from '../../native/RnExecutorchModules';
33
import { ETError, getError } from '../../Error';
44
import { useModule } from '../useModule';
5+
import { DeeplabLabel } from '../../types/image_segmentation';
56

67
interface Props {
78
modelSource: string | number;
@@ -16,8 +17,8 @@ export const useImageSegmentation = ({
1617
downloadProgress: number;
1718
forward: (
1819
input: string,
19-
classesOfInterest?: string[]
20-
) => Promise<{ [category: string]: number[] }>;
20+
classesOfInterest?: DeeplabLabel[]
21+
) => Promise<{ [key in DeeplabLabel]?: number[] }>;
2122
} => {
2223
const [module, _] = useState(() => new _ImageSegmentationModule());
2324
const [isGenerating, setIsGenerating] = useState(false);
@@ -26,7 +27,7 @@ export const useImageSegmentation = ({
2627
module,
2728
});
2829

29-
const forward = async (input: string, classesOfInterest?: string[]) => {
30+
const forward = async (input: string, classesOfInterest?: DeeplabLabel[]) => {
3031
if (!isReady) {
3132
throw new Error(getError(ETError.ModuleNotLoaded));
3233
}
@@ -36,8 +37,20 @@ export const useImageSegmentation = ({
3637

3738
try {
3839
setIsGenerating(true);
39-
const output = await module.forward(input, classesOfInterest || []);
40-
return output;
40+
const stringDict = await module.forward(
41+
input,
42+
(classesOfInterest || []).map((label) => DeeplabLabel[label])
43+
);
44+
45+
let enumDict: { [key in DeeplabLabel]?: number[] } = {};
46+
47+
for (const key in stringDict) {
48+
if (key in DeeplabLabel) {
49+
const enumKey = DeeplabLabel[key as keyof typeof DeeplabLabel];
50+
enumDict[enumKey] = stringDict[key];
51+
}
52+
}
53+
return enumDict;
4154
} catch (e) {
4255
throw new Error(getError(e));
4356
} finally {

src/index.tsx

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ export * from './utils/listDownloadedResources';
2828
// types
2929
export * from './types/object_detection';
3030
export * from './types/ocr';
31+
export * from './types/image_segmentation';
3132

3233
// constants
3334
export * from './constants/modelUrls';

src/modules/computer_vision/ImageSegmentationModule.ts

+4-5
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@ import { getError } from '../../Error';
55
export class ImageSegmentationModule extends BaseModule {
66
static module = new _ImageSegmentationModule();
77

8-
static async forward(input: string, classesOfInteres?: string[]) {
8+
static async forward(input: string, classesOfInterest: string[]) {
99
try {
10-
return await (this.module.forward(
11-
input,
12-
classesOfInteres || []
13-
) as ReturnType<_ImageSegmentationModule['forward']>);
10+
return await (this.module.forward(input, classesOfInterest) as ReturnType<
11+
_ImageSegmentationModule['forward']
12+
>);
1413
} catch (e) {
1514
throw new Error(getError(e));
1615
}

src/types/image_segmentation.ts

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
export enum DeeplabLabel {
2+
background,
3+
aeroplane,
4+
bicycle,
5+
bird,
6+
boat,
7+
bottle,
8+
bus,
9+
car,
10+
cat,
11+
chair,
12+
cow,
13+
diningtable,
14+
dog,
15+
horse,
16+
motorbike,
17+
person,
18+
pottedplant,
19+
sheep,
20+
sofa,
21+
train,
22+
tvmonitor,
23+
argmax, // Additional label not present in the model
24+
}

0 commit comments

Comments
 (0)