Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Image segmentation for ios #113

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#import "ImageSegmentationModel.h"
#import <unordered_set>
#import "../../utils/ImageProcessor.h"
#import <algorithm>
#import <vector>
#import "../../utils/ImageProcessor.h"
#import "../../utils/Numerical.h"
#import "../../utils/Conversions.h"
#import "opencv2/opencv.hpp"
Expand Down Expand Up @@ -57,7 +59,7 @@ - (NSArray *)preprocess:(cv::Mat &)input {
return resizedLabelScores;
}

void adjustScoresPerPixel(std::vector<cv::Mat>& labelScores, cv::Mat& maxArg,
void adjustScoresPerPixel(std::vector<cv::Mat>& labelScores, cv::Mat& argMax,
cv::Size originalSize, std::size_t numLabels) {
std::size_t numOriginalPixels = originalSize.height * originalSize.width;
for (std::size_t pixel = 0; pixel < numOriginalPixels; ++pixel) {
Expand All @@ -71,17 +73,12 @@ void adjustScoresPerPixel(std::vector<cv::Mat>& labelScores, cv::Mat& maxArg,

std::vector<double> adjustedScores = softmax(scores);

std::size_t maxArgIndex = 0;
double maxArgVal = 0;
for (std::size_t label = 0; label < numLabels; ++label) {
labelScores[label].at<double>(row, col) = adjustedScores[label];
if (adjustedScores[label] > maxArgVal) {
maxArgIndex = label;
maxArgVal = adjustedScores[label];
}
}

maxArg.at<int>(row, col) = maxArgIndex;
auto maxIt = std::max_element(scores.begin(), scores.end());
argMax.at<int>(row, col) = std::distance(scores.begin(), maxIt);
}
}

Expand All @@ -98,10 +95,10 @@ - (NSDictionary *)postprocess:(NSArray *)output
std::vector<cv::Mat> resizedLabelScores =
rescaleResults(output, numLabels, modelImageSize, originalSize);

cv::Mat maxArg = cv::Mat(originalSize, CV_32S);
cv::Mat argMax = cv::Mat(originalSize, CV_32S);

// For each pixel apply softmax across all the labels and calculate the maxArg
adjustScoresPerPixel(resizedLabelScores, maxArg, originalSize, numLabels);
// For each pixel apply softmax across all the labels and calculate the argMax
adjustScoresPerPixel(resizedLabelScores, argMax, originalSize, numLabels);

std::unordered_set<std::string> labelSet;

Expand All @@ -120,7 +117,7 @@ - (NSDictionary *)postprocess:(NSArray *)output
}
}

result[@"argmax"] = simpleMatToNSArray<int>(maxArg);
result[@"argmax"] = simpleMatToNSArray<int>(argMax);

return result;
}
Expand Down
23 changes: 0 additions & 23 deletions src/constants/image_segmentation/image_segmentation.ts

This file was deleted.

23 changes: 18 additions & 5 deletions src/hooks/computer_vision/useImageSegmentation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { useState } from 'react';
import { _ImageSegmentationModule } from '../../native/RnExecutorchModules';
import { ETError, getError } from '../../Error';
import { useModule } from '../useModule';
import { DeeplabLabel } from '../../types/image_segmentation';

interface Props {
modelSource: string | number;
Expand All @@ -16,8 +17,8 @@ export const useImageSegmentation = ({
downloadProgress: number;
forward: (
input: string,
classesOfInterest?: string[]
) => Promise<{ [category: string]: number[] }>;
classesOfInterest?: DeeplabLabel[]
Copy link
Contributor

@chmjkb chmjkb Mar 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering if we should use this or `keyof typeof DeeplabLabel`. Using `keyof typeof` would make it consistent, since it is used in object detection, but I think I like this one better. WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hadn't noticed the use in object detection, but I would still make the case for this. In object detection the enum is used as a return value, so we're just interested in the string content, whereas here the enum is supplied by the user (in both cases, either as the filter for classes in a list, or as the key to query results in the dict), so we care more about correctness. I think that autocompletion for the enum values when writing is worth more than consistency in this case.

) => Promise<{ [key in DeeplabLabel]?: number[] }>;
} => {
const [module, _] = useState(() => new _ImageSegmentationModule());
const [isGenerating, setIsGenerating] = useState(false);
Expand All @@ -26,7 +27,7 @@ export const useImageSegmentation = ({
module,
});

const forward = async (input: string, classesOfInterest?: string[]) => {
const forward = async (input: string, classesOfInterest?: DeeplabLabel[]) => {
if (!isReady) {
throw new Error(getError(ETError.ModuleNotLoaded));
}
Expand All @@ -36,8 +37,20 @@ export const useImageSegmentation = ({

try {
setIsGenerating(true);
const output = await module.forward(input, classesOfInterest || []);
return output;
const stringDict = await module.forward(
input,
(classesOfInterest || []).map((label) => DeeplabLabel[label])
);

let enumDict: { [key in DeeplabLabel]?: number[] } = {};

for (const key in stringDict) {
if (key in DeeplabLabel) {
const enumKey = DeeplabLabel[key as keyof typeof DeeplabLabel];
enumDict[enumKey] = stringDict[key];
}
}
return enumDict;
} catch (e) {
throw new Error(getError(e));
} finally {
Expand Down
1 change: 1 addition & 0 deletions src/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ export * from './utils/listDownloadedResources';
// types
export * from './types/object_detection';
export * from './types/ocr';
export * from './types/image_segmentation';

// constants
export * from './constants/modelUrls';
9 changes: 4 additions & 5 deletions src/modules/computer_vision/ImageSegmentationModule.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ import { getError } from '../../Error';
export class ImageSegmentationModule extends BaseModule {
static module = new _ImageSegmentationModule();

static async forward(input: string, classesOfInteres?: string[]) {
static async forward(input: string, classesOfInterest: string[]) {
try {
return await (this.module.forward(
input,
classesOfInteres || []
) as ReturnType<_ImageSegmentationModule['forward']>);
return await (this.module.forward(input, classesOfInterest) as ReturnType<
_ImageSegmentationModule['forward']
>);
} catch (e) {
throw new Error(getError(e));
}
Expand Down
24 changes: 24 additions & 0 deletions src/types/image_segmentation.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/**
 * Labels accepted and returned by the image segmentation API.
 *
 * This is a numeric enum: members are numbered consecutively from 0 in
 * declaration order, so `DeeplabLabel[value]` gives back the member's name
 * as a string and `DeeplabLabel[name]` gives back its numeric value.
 */
export enum DeeplabLabel {
  background = 0,
  aeroplane = 1,
  bicycle = 2,
  bird = 3,
  boat = 4,
  bottle = 5,
  bus = 6,
  car = 7,
  cat = 8,
  chair = 9,
  cow = 10,
  diningtable = 11,
  dog = 12,
  horse = 13,
  motorbike = 14,
  person = 15,
  pottedplant = 16,
  sheep = 17,
  sofa = 18,
  train = 19,
  tvmonitor = 20,
  // Additional label not present in the model
  argmax = 21,
}
Loading