
Commit 7484a31

@jakmro/classification ios (#54)
## Description

Image classification for iOS

### Type of change

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Documentation update (improves or adds clarity to existing documentation)

### Tested on

- [x] iOS
- [ ] Android

### Checklist

- [x] I have performed a self-review of my code
- [x] I have commented my code, particularly in hard-to-understand areas
- [ ] I have updated the documentation accordingly
- [x] My changes generate no new warnings
1 parent b3c07ad commit 7484a31

21 files changed: +1412 -23 lines
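For orientation before the per-file diffs: the JavaScript-facing API added here is the `useClassification` hook consumed by the new ClassificationScreen further down. A minimal usage sketch, assuming only the hook surface visible in this diff (`useClassification({ modulePath })`, `isModelReady`, `forward(imageUri)` resolving to a label-to-score map); the image URI argument is a placeholder:

```ts
// Minimal sketch based on the ClassificationScreen added in this commit.
import { useClassification } from 'react-native-executorch';
import { efficientnet_v2_s } from './models/classification';

export function useImageScores() {
  const model = useClassification({ modulePath: efficientnet_v2_s });

  // Resolves to a label -> score map once the model has finished loading.
  const classify = async (imageUri: string) => {
    if (!model.isModelReady) return undefined;
    return model.forward(imageUri);
  };

  return { isModelReady: model.isModelReady, classify };
}
```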

examples/computer-vision/App.tsx

+11 -8

@@ -6,11 +6,12 @@ import { useState } from 'react';
 import { StyleTransferScreen } from './screens/StyleTransferScreen';
 import { SafeAreaProvider, SafeAreaView } from 'react-native-safe-area-context';
 import { View, StyleSheet } from 'react-native';
+import { ClassificationScreen } from './screens/ClassificationScreen';

 enum ModelType {
   STYLE_TRANSFER,
   OBJECT_DETECTION,
-  IMAGE_CLASSIFICATION,
+  CLASSIFICATION,
   SEMANTIC_SEGMENTATION,
 }

@@ -36,8 +37,10 @@ export default function App() {
       );
     case ModelType.OBJECT_DETECTION:
       return <></>;
-    case ModelType.IMAGE_CLASSIFICATION:
-      return <></>;
+    case ModelType.CLASSIFICATION:
+      return (
+        <ClassificationScreen imageUri={imageUri} setImageUri={setImageUri} />
+      );
     case ModelType.SEMANTIC_SEGMENTATION:
       return <></>;
     default:

@@ -57,17 +60,17 @@ export default function App() {
           dataSource={[
             'Style Transfer',
             'Object Detection',
-            'Image Classification',
+            'Classification',
             'Semantic Segmentation',
           ]}
           onValueChange={(_, selectedIndex) => {
             handleModeChange(selectedIndex);
           }}
-          wrapperHeight={135}
+          wrapperHeight={100}
          highlightColor={ColorPalette.primary}
          wrapperBackground="#fff"
          highlightBorderWidth={3}
-          itemHeight={60}
+          itemHeight={40}
          activeItemTextStyle={styles.activeScrollItem}
        />
      </View>

@@ -85,15 +88,15 @@ const styles = StyleSheet.create({
   },
   topContainer: {
     marginTop: 5,
-    height: 150,
+    height: 145,
     width: '100%',
     alignItems: 'center',
     justifyContent: 'center',
     marginBottom: 16,
   },
   wheelPickerContainer: {
     width: '100%',
-    height: 135,
+    height: 100,
   },
   activeScrollItem: {
     color: ColorPalette.primary,

examples/computer-vision/ios/Podfile.lock

+6 -2

@@ -42,6 +42,7 @@ PODS:
   - hermes-engine (0.76.3):
     - hermes-engine/Pre-built (= 0.76.3)
   - hermes-engine/Pre-built (0.76.3)
+  - opencv-rne (0.1.0)
   - RCT-Folly (2024.01.01.00):
     - boost
     - DoubleConversion

@@ -1277,10 +1278,11 @@ PODS:
     - ReactCommon/turbomodule/bridging
     - ReactCommon/turbomodule/core
     - Yoga
-  - react-native-executorch (0.1.524):
+  - react-native-executorch (0.1.100):
     - DoubleConversion
     - glog
     - hermes-engine
+    - opencv-rne (~> 0.1.0)
     - RCT-Folly (= 2024.01.01.00)
     - RCTRequired
     - RCTTypeSafety

@@ -1866,6 +1868,7 @@ DEPENDENCIES:

 SPEC REPOS:
   trunk:
+    - opencv-rne
     - SocketRocket

 EXTERNAL SOURCES:

@@ -2035,6 +2038,7 @@ SPEC CHECKSUMS:
   fmt: 10c6e61f4be25dc963c36bd73fc7b1705fe975be
   glog: 08b301085f15bcbb6ff8632a8ebaf239aae04e6a
   hermes-engine: 0555a84ea495e8e3b4bde71b597cd87fbb382888
+  opencv-rne: 63e933ae2373fc91351f9a348dc46c3f523c2d3f
   RCT-Folly: bf5c0376ffe4dd2cf438dcf86db385df9fdce648
   RCTDeprecation: 2c5e1000b04ab70b53956aa498bf7442c3c6e497
   RCTRequired: 5f785a001cf68a551c5f5040fb4c415672dbb481

@@ -2064,7 +2068,7 @@ SPEC CHECKSUMS:
   React-logger: 26155dc23db5c9038794db915f80bd2044512c2e
   React-Mapbuffer: ad1ba0205205a16dbff11b8ade6d1b3959451658
   React-microtasksnativemodule: e771eb9eb6ace5884ee40a293a0e14a9d7a4343c
-  react-native-executorch: 9782e20c5bb4ddf7836af8779f887223228833e9
+  react-native-executorch: a30dd907f470d5c4f8135e2ba1fa2a3bb65ea47a
   react-native-image-picker: bfb56e2b39dc63abfcc6de44ee239c6633f47d66
   react-native-safe-area-context: d6406c2adbd41b2e09ab1c386781dc1c81a90919
   React-nativeconfig: aeed6e2a8ac02b2df54476afcc7c663416c12bf7

examples/computer-vision/ios/computervision.xcodeproj/project.pbxproj

+2 -2

@@ -367,7 +367,7 @@
         );
         OTHER_SWIFT_FLAGS = "$(inherited) -D EXPO_CONFIGURATION_DEBUG";
         PRODUCT_BUNDLE_IDENTIFIER = com.anonymous.computervision;
-        PRODUCT_NAME = "computervision";
+        PRODUCT_NAME = computervision;
         SWIFT_OBJC_BRIDGING_HEADER = "computervision/computervision-Bridging-Header.h";
         SWIFT_OPTIMIZATION_LEVEL = "-Onone";
         SWIFT_VERSION = 5.0;

@@ -399,7 +399,7 @@
         );
         OTHER_SWIFT_FLAGS = "$(inherited) -D EXPO_CONFIGURATION_RELEASE";
         PRODUCT_BUNDLE_IDENTIFIER = com.anonymous.computervision;
-        PRODUCT_NAME = "computervision";
+        PRODUCT_NAME = computervision;
         SWIFT_OBJC_BRIDGING_HEADER = "computervision/computervision-Bridging-Header.h";
         SWIFT_VERSION = 5.0;
         TARGETED_DEVICE_FAMILY = "1,2";
examples/computer-vision/models/classification.ts (new file)

+6

@@ -0,0 +1,6 @@
+import { Platform } from 'react-native';
+
+export const efficientnet_v2_s =
+  Platform.OS === 'ios'
+    ? require('../assets/classification/ios/efficientnet_v2_s_coreml_all.pte')
+    : require('../assets/classification/android/efficientnet_v2_s_xnnpack.pte');
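Judging by the asset names, the iOS export is delegated to Core ML (`coreml_all`) while the Android export uses XNNPACK; the `Platform.OS` check above selects the matching `.pte` for the running platform.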
examples/computer-vision/screens/ClassificationScreen.tsx (new file)

+126

@@ -0,0 +1,126 @@
+import { useState } from 'react';
+import Spinner from 'react-native-loading-spinner-overlay';
+import { BottomBar } from '../components/BottomBar';
+import { efficientnet_v2_s } from '../models/classification';
+import { getImageUri } from '../utils';
+import { useClassification } from 'react-native-executorch';
+import { View, StyleSheet, Image, Text, ScrollView } from 'react-native';
+
+export const ClassificationScreen = ({
+  imageUri,
+  setImageUri,
+}: {
+  imageUri: string;
+  setImageUri: (imageUri: string) => void;
+}) => {
+  const [results, setResults] = useState<{ label: string; score: number }[]>(
+    []
+  );
+
+  const model = useClassification({
+    modulePath: efficientnet_v2_s,
+  });
+
+  const handleCameraPress = async (isCamera: boolean) => {
+    const uri = await getImageUri(isCamera);
+    if (typeof uri === 'string') {
+      setImageUri(uri as string);
+      setResults([]);
+    }
+  };
+
+  const runForward = async () => {
+    if (imageUri) {
+      try {
+        const output = await model.forward(imageUri);
+        const top10 = Object.entries(output)
+          .sort(([, a], [, b]) => b - a)
+          .slice(0, 10)
+          .map(([label, score]) => ({ label, score }));
+        setResults(top10);
+      } catch (e) {
+        console.error(e);
+      }
+    }
+  };
+
+  if (!model.isModelReady) {
+    return (
+      <Spinner
+        visible={!model.isModelReady}
+        textContent={`Loading the model...`}
+      />
+    );
+  }
+
+  return (
+    <>
+      <View style={styles.imageContainer}>
+        <Image
+          style={styles.image}
+          resizeMode="contain"
+          source={
+            imageUri
+              ? { uri: imageUri }
+              : require('../assets/icons/executorch_logo.png')
+          }
+        />
+        {results.length > 0 && (
+          <View style={styles.results}>
+            <Text style={styles.resultHeader}>Results Top 10</Text>
+            <ScrollView style={styles.resultsList}>
+              {results.map(({ label, score }) => (
+                <View key={label} style={styles.resultRecord}>
+                  <Text style={styles.resultLabel}>{label}</Text>
+                  <Text>{score.toFixed(3)}</Text>
+                </View>
+              ))}
+            </ScrollView>
+          </View>
+        )}
+      </View>
+      <BottomBar
+        handleCameraPress={handleCameraPress}
+        runForward={runForward}
+      />
+    </>
+  );
+};
+
+const styles = StyleSheet.create({
+  imageContainer: {
+    flex: 6,
+    width: '100%',
+    padding: 16,
+  },
+  image: {
+    flex: 2,
+    borderRadius: 8,
+    width: '100%',
+  },
+  results: {
+    flex: 1,
+    alignItems: 'center',
+    justifyContent: 'center',
+    gap: 4,
+    padding: 4,
+  },
+  resultHeader: {
+    fontSize: 18,
+    color: 'navy',
+  },
+  resultsList: {
+    flex: 1,
+  },
+  resultRecord: {
+    flexDirection: 'row',
+    width: '100%',
+    justifyContent: 'space-between',
+    padding: 8,
+    borderBottomWidth: 1,
+  },
+  resultLabel: {
+    flex: 1,
+    marginRight: 4,
+  },
+});
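The `runForward` handler above turns the label-to-score map returned by `forward` into a sorted top-10 list. A self-contained illustration of that transform, using made-up labels and scores rather than real EfficientNet output:

```ts
// Same top-k selection as in runForward above; the values are sample data only.
const output: Record<string, number> = {
  'tabby cat': 0.62,
  'tiger cat': 0.21,
  'Egyptian cat': 0.09,
  lynx: 0.03,
};

const topK = (scores: Record<string, number>, k: number) =>
  Object.entries(scores)
    .sort(([, a], [, b]) => b - a)
    .slice(0, k)
    .map(([label, score]) => ({ label, score }));

console.log(topK(output, 3));
// -> [ { label: 'tabby cat', score: 0.62 },
//      { label: 'tiger cat', score: 0.21 },
//      { label: 'Egyptian cat', score: 0.09 } ]
```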

examples/computer-vision/yarn.lock

+5 -5

@@ -3352,7 +3352,7 @@ __metadata:
     metro-config: ^0.81.0
     react: 18.3.1
     react-native: 0.76.3
-    react-native-executorch: /Users/norbertklockiewicz/Desktop/work/react-native-executorch/react-native-executorch-0.1.524.tgz
+    react-native-executorch: ../../react-native-executorch-0.1.100.tgz
     react-native-image-picker: ^7.2.2
     react-native-loading-spinner-overlay: ^3.0.1
     react-native-reanimated: ^3.16.3

@@ -6799,13 +6799,13 @@
   languageName: node
   linkType: hard

-"react-native-executorch@file:/Users/norbertklockiewicz/Desktop/work/react-native-executorch/react-native-executorch-0.1.524.tgz::locator=computer-vision%40workspace%3A.":
-  version: 0.1.524
-  resolution: "react-native-executorch@file:/Users/norbertklockiewicz/Desktop/work/react-native-executorch/react-native-executorch-0.1.524.tgz::locator=computer-vision%40workspace%3A."
+"react-native-executorch@file:../../react-native-executorch-0.1.100.tgz::locator=computer-vision%40workspace%3A.":
+  version: 0.1.100
+  resolution: "react-native-executorch@file:../../react-native-executorch-0.1.100.tgz::locator=computer-vision%40workspace%3A."
   peerDependencies:
     react: "*"
     react-native: "*"
-  checksum: 4f67dbd81711997e5f890b2d7c7b025777d6bcf71a335c7efa56da3fa59a6d00a4915c8cb63e4362d931885adde4072fad63bfbabb93bc36edf116e4fa98f5b9
+  checksum: f258452e2050df59e150938f6482ef8eee5fbd4ef4fc4073a920293ca87d543daddf76c560701d0c2626e6677d964b446dad8e670e978ea4f80d0a1bd17dfa03
   languageName: node
   linkType: hard
ios/RnExecutorch/Classification.h

+5

@@ -0,0 +1,5 @@
+#import <RnExecutorchSpec/RnExecutorchSpec.h>
+
+@interface Classification : NSObject <NativeClassificationSpec>
+
+@end

ios/RnExecutorch/Classification.mm

+56

@@ -0,0 +1,56 @@
+#import "Classification.h"
+#import "utils/Fetcher.h"
+#import "models/BaseModel.h"
+#import "utils/ETError.h"
+#import "ImageProcessor.h"
+#import <ExecutorchLib/ETModel.h>
+#import <React/RCTBridgeModule.h>
+#import "models/classification/ClassificationModel.h"
+#import "opencv2/opencv.hpp"
+
+@implementation Classification {
+  ClassificationModel* model;
+}
+
+RCT_EXPORT_MODULE()
+
+- (void)loadModule:(NSString *)modelSource
+           resolve:(RCTPromiseResolveBlock)resolve
+            reject:(RCTPromiseRejectBlock)reject {
+  model = [[ClassificationModel alloc] init];
+  [model loadModel: [NSURL URLWithString:modelSource] completion:^(BOOL success, NSNumber *errorCode){
+    if(success){
+      resolve(errorCode);
+      return;
+    }
+
+    reject(@"init_module_error", [NSString
+        stringWithFormat:@"%ld", (long)[errorCode longValue]], nil);
+    return;
+  }];
+}
+
+- (void)forward:(NSString *)input
+        resolve:(RCTPromiseResolveBlock)resolve
+         reject:(RCTPromiseRejectBlock)reject {
+  @try {
+    cv::Mat image = [ImageProcessor readImage:input];
+    NSDictionary *result = [model runModel:image];
+
+    resolve(result);
+    return;
+  } @catch (NSException *exception) {
+    NSLog(@"An exception occurred: %@, %@", exception.name, exception.reason);
+    reject(@"forward_error", [NSString stringWithFormat:@"%@", exception.reason],
+           nil);
+    return;
+  }
+}
+
+
+- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
+    (const facebook::react::ObjCTurboModule::InitParams &)params {
+  return std::make_shared<facebook::react::NativeClassificationSpecJSI>(params);
+}
+
+@end
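The module exposes two promise-based methods, `loadModule` and `forward`. The codegen spec behind `NativeClassificationSpec` is not part of this diff; a hypothetical TypeScript sketch of what it might look like, with method names and return shapes inferred from the Objective-C signatures above:

```ts
// Hypothetical spec sketch only; the actual Native spec file is not shown in this commit.
import type { TurboModule } from 'react-native';
import { TurboModuleRegistry } from 'react-native';

export interface Spec extends TurboModule {
  // Resolves once the model is loaded; rejects with 'init_module_error' on failure.
  loadModule(modelSource: string): Promise<number>;
  // Resolves with a label -> score dictionary; rejects with 'forward_error' on failure.
  forward(input: string): Promise<{ [label: string]: number }>;
}

export default TurboModuleRegistry.get<Spec>('Classification');
```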

ios/RnExecutorch/models/StyleTransferModel.h

+1 -2

@@ -1,6 +1,5 @@
-#import <UIKit/UIKit.h>
 #import "BaseModel.h"
-#import <opencv2/opencv.hpp>
+#import "opencv2/opencv.hpp"

 @interface StyleTransferModel : BaseModel

ios/RnExecutorch/models/classification/ClassificationModel.h (new file)

+10

@@ -0,0 +1,10 @@
+#import "BaseModel.h"
+#import "opencv2/opencv.hpp"
+
+@interface ClassificationModel : BaseModel
+
+- (NSArray *)preprocess:(cv::Mat &)input;
+- (NSDictionary *)runModel:(cv::Mat &)input;
+- (NSDictionary *)postprocess:(NSArray *)output;
+
+@end
