Skip to content

Commit e6593fd

Browse files
authored
feat: Add object detection (iOS) (#49)
## Description This PR introduces a new native and TypeScript API: object detection. For now, it only supports the SSDLiteLarge320 model from torchvision. Because operations such as non-maximum suppression (NMS) are hard or impossible to export with the model, they are implemented in the library's native code. ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [x] iOS - [ ] Android ### Testing instructions <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. --> ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [ ] I have performed a self-review of my code - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [ ] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
1 parent aa0f9c7 commit e6593fd

13 files changed

+476
-5
lines changed

ios/RnExecutorch/ObjectDetection.h

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#import <RnExecutorchSpec/RnExecutorchSpec.h>
2+
3+
/// Native object-detection module. The class derives from NSObject and adopts
/// the NativeObjectDetectionSpec protocol; it declares no additional public
/// members in this header.
@interface ObjectDetection : NSObject <NativeObjectDetectionSpec>
4+
5+
@end

ios/RnExecutorch/ObjectDetection.mm

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#import "ObjectDetection.h"
2+
#import "models/object_detection/SSDLiteLargeModel.hpp"
3+
#import <ExecutorchLib/ETModel.h>
4+
#import <React/RCTBridgeModule.h>
5+
#import "utils/ImageProcessor.h"
6+
7+
/// Implementation of the object-detection native module. It owns a single
/// SSDLiteLargeModel instance and exposes `loadModule` and `forward` to
/// JavaScript through promise resolve/reject blocks.
@implementation ObjectDetection {
  // Model wrapper created by loadModule:resolve:reject:. It is nil until that
  // method has been called at least once.
  SSDLiteLargeModel *model;
}

RCT_EXPORT_MODULE()

/// Creates a fresh SSDLiteLargeModel and asks it to load the model located at
/// `modelSource`.
///
/// @param modelSource String that is converted to an NSURL and handed to the
///                    model's loader.
/// @param resolve     Called with the loader's `errorCode` value when the
///                    loader reports success.
/// @param reject      Called with code "init_module_error" when the loader
///                    reports failure; the message is the numeric error code
///                    rendered as a decimal string.
- (void)loadModule:(NSString *)modelSource
           resolve:(RCTPromiseResolveBlock)resolve
            reject:(RCTPromiseRejectBlock)reject {
  model = [[SSDLiteLargeModel alloc] init];
  [model loadModel:[NSURL URLWithString:modelSource]
        completion:^(BOOL success, NSNumber *errorCode) {
          if (success) {
            resolve(errorCode);
            return;
          }

          // The domain names this module (it previously carried the
          // style-transfer module's domain, which mislabelled the error).
          NSError *error = [NSError
              errorWithDomain:@"ObjectDetectionErrorDomain"
                         code:[errorCode intValue]
                     userInfo:@{
                       NSLocalizedDescriptionKey : [NSString
                           stringWithFormat:@"%ld", (long)[errorCode longValue]]
                     }];

          reject(@"init_module_error", error.localizedDescription, error);
        }];
}

/// Reads the image referenced by `input`, runs the loaded model on it and
/// resolves the promise with the model's result array.
///
/// @param input   Image reference passed unchanged to
///                +[ImageProcessor readImage:].
/// @param resolve Called with the array returned by -runModel:.
/// @param reject  Called with code "forward_error" and the exception reason
///                when an NSException is raised while reading or running.
- (void)forward:(NSString *)input
        resolve:(RCTPromiseResolveBlock)resolve
         reject:(RCTPromiseRejectBlock)reject {
  @try {
    cv::Mat image = [ImageProcessor readImage:input];
    NSArray *result = [model runModel:image];
    resolve(result);
  } @catch (NSException *exception) {
    reject(@"forward_error", [NSString stringWithFormat:@"%@", exception.reason],
           nil);
  }
}

/// Returns the JSI TurboModule wrapper for this module, constructed from the
/// supplied init parameters.
- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
    (const facebook::react::ObjCTurboModule::InitParams &)params {
  return std::make_shared<facebook::react::NativeObjectDetectionSpecJSI>(
      params);
}

@end
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#import "../BaseModel.h"
2+
#import <UIKit/UIKit.h>
3+
#include <opencv2/opencv.hpp>
4+
5+
/// Object-detection model wrapper built on BaseModel.
@interface SSDLiteLargeModel : BaseModel

/// Runs the full pipeline (preprocess, forward pass, postprocess) on `input`
/// and returns one dictionary per retained detection.
- (NSArray *)runModel:(cv::Mat)input;

/// Resizes `input` to the model's input size and converts it to an NSArray.
- (NSArray *)preprocess:(cv::Mat)input;

/// Converts raw model output into detection dictionaries.
///
/// The declaration matches the selector the implementation actually defines
/// and calls; the earlier one-argument `postprocess:` had no implementation.
///
/// @param input       Raw forward output; elements 0, 1 and 2 are read as the
///                    box coordinates, the scores and the labels.
/// @param widthRatio  Factor applied to every x coordinate.
/// @param heightRatio Factor applied to every y coordinate.
- (NSArray *)postprocess:(NSArray *)input
              widthRatio:(float)widthRatio
             heightRatio:(float)heightRatio;

@end
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#include "SSDLiteLargeModel.hpp"
2+
#include "../../utils/ObjectDetectionUtils.hpp"
3+
#include "ImageProcessor.h"
4+
#include <vector>
5+
6+
// Overlap above which a lower-scored box of the same label is dropped by NMS.
constexpr float iouThreshold = 0.55;
// Detections scoring below this value are discarded before NMS.
constexpr float detectionThreshold = 0.7;
// Width and height, in pixels, that images are resized to before inference.
constexpr int modelInputWidth = 320;
constexpr int modelInputHeight = 320;
10+
11+
@implementation SSDLiteLargeModel

/// Resizes the image to the model's input size and converts it to an NSArray
/// via +[ImageProcessor matToNSArray:].
- (NSArray *)preprocess:(cv::Mat)input {
  const cv::Size targetSize(modelInputWidth, modelInputHeight);
  cv::resize(input, input, targetSize);
  return [ImageProcessor matToNSArray:input];
}

/// Turns raw model output into an array of detection dictionaries.
///
/// Elements 0, 1 and 2 of `input` are read as the flat box coordinates (four
/// values per detection), the scores and the labels. Detections scoring below
/// `detectionThreshold` are skipped, the remaining boxes are scaled by the
/// given ratios, and the survivors of `nms` are converted to dictionaries.
- (NSArray *)postprocess:(NSArray *)input
              widthRatio:(float)widthRatio
             heightRatio:(float)heightRatio {
  NSArray *boxCoordinates = input[0];
  NSArray *scores = input[1];
  NSArray *labels = input[2];

  std::vector<Detection> candidates;
  for (NSUInteger i = 0; i < scores.count; ++i) {
    const float score = [scores[i] floatValue];
    const float label = [labels[i] floatValue];
    if (score < detectionThreshold) {
      continue;
    }

    // Each detection owns four consecutive coordinates: x1, y1, x2, y2.
    const NSUInteger base = i * 4;
    const float left = [boxCoordinates[base] floatValue] * widthRatio;
    const float top = [boxCoordinates[base + 1] floatValue] * heightRatio;
    const float right = [boxCoordinates[base + 2] floatValue] * widthRatio;
    const float bottom = [boxCoordinates[base + 3] floatValue] * heightRatio;

    candidates.push_back(Detection{left, top, right, bottom, label, score});
  }

  const std::vector<Detection> kept = nms(candidates, iouThreshold);

  NSMutableArray *dictionaries = [NSMutableArray arrayWithCapacity:kept.size()];
  for (const Detection &detection : kept) {
    [dictionaries addObject:detectionToNSDictionary(detection)];
  }
  return dictionaries;
}

/// Runs preprocessing, the forward pass and postprocessing on `input`. Box
/// coordinates are scaled by the ratio between the original image size and the
/// model input size.
- (NSArray *)runModel:(cv::Mat)input {
  // Capture the original dimensions before the image is resized.
  const cv::Size originalSize = input.size();
  const float widthRatio = originalSize.width / (float)modelInputWidth;
  const float heightRatio = originalSize.height / (float)modelInputHeight;

  NSArray *modelInput = [self preprocess:input];
  NSArray *rawOutput = [self forward:modelInput];
  return [self postprocess:rawOutput
                widthRatio:widthRatio
               heightRatio:heightRatio];
}

@end

ios/RnExecutorch/utils/Constants.mm

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#include "Constants.h"
2+
3+
// Lookup table from integer class id (1 through 91) to an upper-case label
// string.
// NOTE(review): ids 39 and 40 both map to "BASEBALL", and entries such as
// "PARKING" (14), "EYE" (30) and "SPORTS" (37) look like shortened category
// names -- confirm this is intentional and matches the label strings the
// JavaScript side expects before changing any of them.
const std::unordered_map<int, std::string> cocoLabelsMap = {
4+
{1, "PERSON"}, {2, "BICYCLE"}, {3, "CAR"},
5+
{4, "MOTORCYCLE"}, {5, "AIRPLANE"}, {6, "BUS"},
6+
{7, "TRAIN"}, {8, "TRUCK"}, {9, "BOAT"},
7+
{10, "TRAFFIC_LIGHT"}, {11, "FIRE_HYDRANT"}, {12, "STREET_SIGN"},
8+
{13, "STOP_SIGN"}, {14, "PARKING"}, {15, "BENCH"},
9+
{16, "BIRD"}, {17, "CAT"}, {18, "DOG"},
10+
{19, "HORSE"}, {20, "SHEEP"}, {21, "COW"},
11+
{22, "ELEPHANT"}, {23, "BEAR"}, {24, "ZEBRA"},
12+
{25, "GIRAFFE"}, {26, "HAT"}, {27, "BACKPACK"},
13+
{28, "UMBRELLA"}, {29, "SHOE"}, {30, "EYE"},
14+
{31, "HANDBAG"}, {32, "TIE"}, {33, "SUITCASE"},
15+
{34, "FRISBEE"}, {35, "SKIS"}, {36, "SNOWBOARD"},
16+
{37, "SPORTS"}, {38, "KITE"}, {39, "BASEBALL"},
17+
{40, "BASEBALL"}, {41, "SKATEBOARD"}, {42, "SURFBOARD"},
18+
{43, "TENNIS_RACKET"}, {44, "BOTTLE"}, {45, "PLATE"},
19+
{46, "WINE_GLASS"}, {47, "CUP"}, {48, "FORK"},
20+
{49, "KNIFE"}, {50, "SPOON"}, {51, "BOWL"},
21+
{52, "BANANA"}, {53, "APPLE"}, {54, "SANDWICH"},
22+
{55, "ORANGE"}, {56, "BROCCOLI"}, {57, "CARROT"},
23+
{58, "HOT_DOG"}, {59, "PIZZA"}, {60, "DONUT"},
24+
{61, "CAKE"}, {62, "CHAIR"}, {63, "COUCH"},
25+
{64, "POTTED_PLANT"}, {65, "BED"}, {66, "MIRROR"},
26+
{67, "DINING_TABLE"}, {68, "WINDOW"}, {69, "DESK"},
27+
{70, "TOILET"}, {71, "DOOR"}, {72, "TV"},
28+
{73, "LAPTOP"}, {74, "MOUSE"}, {75, "REMOTE"},
29+
{76, "KEYBOARD"}, {77, "CELL_PHONE"}, {78, "MICROWAVE"},
30+
{79, "OVEN"}, {80, "TOASTER"}, {81, "SINK"},
31+
{82, "REFRIGERATOR"}, {83, "BLENDER"}, {84, "BOOK"},
32+
{85, "CLOCK"}, {86, "VASE"}, {87, "SCISSORS"},
33+
{88, "TEDDY_BEAR"}, {89, "HAIR_DRIER"}, {90, "TOOTHBRUSH"},
34+
{91, "HAIR_BRUSH"}};

ios/RnExecutorch/utils/ImageProcessor.mm

+4-4
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ + (NSArray *)matToNSArray:(const cv::Mat &)mat {
1414
int row = i / mat.cols;
1515
int col = i % mat.cols;
1616
cv::Vec3b pixel = mat.at<cv::Vec3b>(row, col);
17-
floatArray[i] = @(pixel[2] / 255.0f);
18-
floatArray[pixelCount + i] = @(pixel[1] / 255.0f);
17+
floatArray[0 * pixelCount + i] = @(pixel[2] / 255.0f);
18+
floatArray[1 * pixelCount + i] = @(pixel[1] / 255.0f);
1919
floatArray[2 * pixelCount + i] = @(pixel[0] / 255.0f);
2020
}
2121

@@ -31,8 +31,8 @@ + (NSArray *)matToNSArray:(const cv::Mat &)mat {
3131
int col = i % width;
3232
float r = 0, g = 0, b = 0;
3333

34-
r = [[array objectAtIndex: i] floatValue];
35-
g = [[array objectAtIndex: pixelCount + i] floatValue];
34+
r = [[array objectAtIndex: 0 * pixelCount + i] floatValue];
35+
g = [[array objectAtIndex: 1 * pixelCount + i] floatValue];
3636
b = [[array objectAtIndex: 2 * pixelCount + i] floatValue];
3737

3838
cv::Vec3b color((uchar)(b * 255), (uchar)(g * 255), (uchar)(r * 255));
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#ifndef ObjectDetectionUtils_hpp
2+
#define ObjectDetectionUtils_hpp
3+
4+
#import <Foundation/Foundation.h>
5+
#include <stdio.h>
6+
#include <vector>
7+
8+
// One detected object: a box given by two corner points, a class id and a
// confidence score. Field order matters because instances are created with
// positional aggregate initialization.
struct Detection {
  float x1, y1; // first corner
  float x2, y2; // second corner
  float label;  // class id stored as a float
  float score;  // confidence reported for this detection
};
16+
17+
// Returns the label string registered for the integer part of `label`, or
// "unknown" when the id has no entry in the label table.
NSString *floatLabelToNSString(float label);
18+
// Packs a detection into a dictionary with the keys "bbox" (itself a
// dictionary with "x1", "y1", "x2", "y2"), "label" and "score".
NSDictionary *detectionToNSDictionary(const Detection &detection);
19+
// Intersection-over-union of the boxes of `a` and `b`.
float iou(const Detection &a, const Detection &b);
20+
// Per-label non-maximum suppression: within each label, boxes are visited by
// descending score and a box is dropped when its IoU with an already kept box
// exceeds `iouThreshold`.
std::vector<Detection> nms(std::vector<Detection> detections,
21+
float iouThreshold);
22+
23+
#endif /* ObjectDetectionUtils_hpp */
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#include "ObjectDetectionUtils.hpp"
2+
#include "Constants.h"
3+
#include <map>
4+
#include <vector>
5+
6+
// Maps a float class id to its label string. The id is truncated to an int and
// looked up in cocoLabelsMap; ids without an entry yield "unknown".
NSString *floatLabelToNSString(float label) {
  const auto entry = cocoLabelsMap.find(static_cast<int>(label));
  if (entry == cocoLabelsMap.end()) {
    return @"unknown";
  }
  return [NSString stringWithUTF8String:entry->second.c_str()];
}
15+
16+
// Converts a Detection into a dictionary: "bbox" holds the four corner
// coordinates, "label" the label string and "score" the confidence.
NSDictionary *detectionToNSDictionary(const Detection &detection) {
  NSDictionary *boundingBox = @{
    @"x1" : @(detection.x1),
    @"y1" : @(detection.y1),
    @"x2" : @(detection.x2),
    @"y2" : @(detection.y2),
  };
  NSString *labelName = floatLabelToNSString(detection.label);
  NSNumber *confidence = @(detection.score);

  return @{
    @"bbox" : boundingBox,
    @"label" : labelName,
    @"score" : confidence
  };
}
28+
29+
float iou(const Detection &a, const Detection &b) {
30+
float x1 = std::max(a.x1, b.x1);
31+
float y1 = std::max(a.y1, b.y1);
32+
float x2 = std::min(a.x2, b.x2);
33+
float y2 = std::min(a.y2, b.y2);
34+
35+
float intersectionArea = std::max(0.0f, x2 - x1) * std::max(0.0f, y2 - y1);
36+
float areaA = (a.x2 - a.x1) * (a.y2 - a.y1);
37+
float areaB = (b.x2 - b.x1) * (b.y2 - b.y1);
38+
float unionArea = areaA + areaB - intersectionArea;
39+
40+
return intersectionArea / unionArea;
41+
};
42+
43+
std::vector<Detection> nms(std::vector<Detection> detections,
44+
float iouThreshold) {
45+
if (detections.empty()) {
46+
return {};
47+
}
48+
49+
// Sort by label, then by score
50+
std::sort(detections.begin(), detections.end(),
51+
[](const Detection &a, const Detection &b) {
52+
if (a.label == b.label) {
53+
return a.score > b.score;
54+
}
55+
return a.label < b.label;
56+
});
57+
58+
std::vector<Detection> result;
59+
// Apply NMS for each label
60+
for (size_t i = 0; i < detections.size();) {
61+
float currentLabel = detections[i].label;
62+
63+
std::vector<Detection> labelDetections;
64+
while (i < detections.size() && detections[i].label == currentLabel) {
65+
labelDetections.push_back(detections[i]);
66+
++i;
67+
}
68+
69+
std::vector<Detection> filteredLabelDetections;
70+
while (!labelDetections.empty()) {
71+
Detection current = labelDetections.front();
72+
filteredLabelDetections.push_back(current);
73+
labelDetections.erase(
74+
std::remove_if(labelDetections.begin(), labelDetections.end(),
75+
[&](const Detection &other) {
76+
return iou(current, other) > iouThreshold;
77+
}),
78+
labelDetections.end());
79+
}
80+
result.insert(result.end(), filteredLabelDetections.begin(),
81+
filteredLabelDetections.end());
82+
}
83+
return result;
84+
}

src/index.tsx

+2
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@ export * from './ETModule';
22
export * from './LLM';
33
export * from './StyleTransfer';
44
export * from './constants/modelUrls';
5+
export * from './models/object_detection/ObjectDetection';
6+
export * from './models/object_detection/types';
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import { useEffect, useState } from 'react';
2+
import { Image } from 'react-native';
3+
import { ETError, getError } from '../../Error';
4+
import { ObjectDetection } from '../../native/RnExecutorchModules';
5+
import { Detection } from './types';
6+
7+
/** Arguments accepted by the useObjectDetection hook. */
interface Props {
8+
// Either a string passed on as the model path, or a number that is resolved
// to a URI through Image.resolveAssetSource.
modelSource: string | number;
9+
}
10+
11+
/** Value returned by the useObjectDetection hook. */
interface ObjectDetectionModule {
12+
// Error text produced when loading the model failed, otherwise null.
error: string | null;
13+
// True once the native module has finished loading the model.
isModelReady: boolean;
14+
// True while a forward call is in flight.
isModelGenerating: boolean;
15+
// Runs detection on the given input string and resolves with the detections.
forward: (input: string) => Promise<Detection[]>;
16+
}
17+
18+
/**
 * React hook that loads an object-detection model through the native
 * ObjectDetection module and exposes a `forward` function to run it.
 *
 * The model is (re)loaded whenever `modelSource` changes. A numeric
 * `modelSource` is first resolved to a URI with `Image.resolveAssetSource`.
 *
 * @param modelSource Model location, or a numeric asset reference.
 * @returns The load error (if any), readiness and in-flight flags, and
 *          `forward`, which throws when the model is not ready or another
 *          call is still running.
 */
export const useObjectDetection = ({
  modelSource,
}: Props): ObjectDetectionModule => {
  const [error, setError] = useState<null | string>(null);
  const [isModelReady, setIsModelReady] = useState(false);
  const [isModelGenerating, setIsModelGenerating] = useState(false);

  useEffect(() => {
    // Set by the cleanup below when modelSource changes or the component
    // unmounts, so a load that finishes late cannot overwrite newer state.
    let isStale = false;

    const loadModel = async () => {
      let path = modelSource;
      if (typeof modelSource === 'number') {
        path = Image.resolveAssetSource(modelSource).uri;
      }

      try {
        // Clear the error left behind by a previous failed load.
        setError(null);
        setIsModelReady(false);
        await ObjectDetection.loadModule(path);
        if (!isStale) {
          setIsModelReady(true);
        }
      } catch (e) {
        if (!isStale) {
          setError(getError(e));
        }
      }
    };

    loadModel();

    return () => {
      isStale = true;
    };
  }, [modelSource]);

  const forward = async (input: string) => {
    if (!isModelReady) {
      throw new Error(getError(ETError.ModuleNotLoaded));
    }
    if (isModelGenerating) {
      throw new Error(getError(ETError.ModelGenerating));
    }
    try {
      setIsModelGenerating(true);
      const output = await ObjectDetection.forward(input);
      return output;
    } catch (e) {
      throw new Error(getError(e));
    } finally {
      // Always release the in-flight flag, whether the call succeeded or threw.
      setIsModelGenerating(false);
    }
  };

  return { error, isModelReady, isModelGenerating, forward };
};

0 commit comments

Comments
 (0)