Skip to content

Commit 2520747

Browse files
NorbertKlockiewiczchmjkb
authored andcommitted
feat: style transfer with openCV(ios) (#48)
## Description This PR introduces implementation for style transfer module which is using opencv under the hood, it allows to load images from external sources, device storage and base64. ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [x] iOS - [ ] Android ### Testing instructions <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. --> ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [x] I have performed a self-review of my code - [ ] I have commented my code, particularly in hard-to-understand areas - [x] I have updated the documentation accordingly - [x] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
1 parent 0cc0f26 commit 2520747

File tree

28 files changed

+450
-436
lines changed

28 files changed

+450
-436
lines changed

examples/computer-vision/screens/StyleTransferScreen.tsx

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ export const StyleTransferScreen = ({
3434
}
3535
};
3636

37-
if (model.isModelLoading) {
37+
if (!model.isModelReady) {
3838
return (
3939
<Spinner
40-
visible={model.isModelLoading}
40+
visible={!model.isModelReady}
4141
textContent={`Loading the model...`}
4242
/>
4343
);
Binary file not shown.

ios/RnExecutorch.xcodeproj/project.pbxproj

+19
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
objectVersion = 77;
77
objects = {
88

9+
/* Begin PBXBuildFile section */
10+
55D6EA8C2D0987D2009BA408 /* ExecutorchLib.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 55D6EA8B2D0987D2009BA408 /* ExecutorchLib.xcframework */; };
11+
55D6EA8E2D0987DF009BA408 /* opencv2.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 55D6EA8D2D0987DF009BA408 /* opencv2.xcframework */; };
12+
/* End PBXBuildFile section */
13+
914
/* Begin PBXCopyFilesBuildPhase section */
1015
550986872CEF541900FECBB8 /* CopyFiles */ = {
1116
isa = PBXCopyFilesBuildPhase;
@@ -20,6 +25,8 @@
2025

2126
/* Begin PBXFileReference section */
2227
550986892CEF541900FECBB8 /* libRnExecutorch.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = libRnExecutorch.a; sourceTree = BUILT_PRODUCTS_DIR; };
28+
55D6EA8B2D0987D2009BA408 /* ExecutorchLib.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; path = ExecutorchLib.xcframework; sourceTree = "<group>"; };
29+
55D6EA8D2D0987DF009BA408 /* opencv2.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = opencv2.xcframework; path = ../../../opencv2.xcframework; sourceTree = "<group>"; };
2330
/* End PBXFileReference section */
2431

2532
/* Begin PBXFileSystemSynchronizedGroupBuildPhaseMembershipExceptionSet section */
@@ -48,6 +55,8 @@
4855
isa = PBXFrameworksBuildPhase;
4956
buildActionMask = 2147483647;
5057
files = (
58+
55D6EA8E2D0987DF009BA408 /* opencv2.xcframework in Frameworks */,
59+
55D6EA8C2D0987D2009BA408 /* ExecutorchLib.xcframework in Frameworks */,
5160
);
5261
runOnlyForDeploymentPostprocessing = 0;
5362
};
@@ -58,6 +67,7 @@
5867
isa = PBXGroup;
5968
children = (
6069
5509868B2CEF541900FECBB8 /* RnExecutorch */,
70+
55D6EA8A2D0987D2009BA408 /* Frameworks */,
6171
5509868A2CEF541900FECBB8 /* Products */,
6272
);
6373
sourceTree = "<group>";
@@ -70,6 +80,15 @@
7080
name = Products;
7181
sourceTree = "<group>";
7282
};
83+
55D6EA8A2D0987D2009BA408 /* Frameworks */ = {
84+
isa = PBXGroup;
85+
children = (
86+
55D6EA8D2D0987DF009BA408 /* opencv2.xcframework */,
87+
55D6EA8B2D0987D2009BA408 /* ExecutorchLib.xcframework */,
88+
);
89+
name = Frameworks;
90+
sourceTree = "<group>";
91+
};
7392
/* End PBXGroup section */
7493

7594
/* Begin PBXNativeTarget section */

ios/RnExecutorch/ETModule.mm

+23-23
Original file line numberDiff line numberDiff line change
@@ -16,30 +16,30 @@ - (void)loadModule:(NSString *)modelSource
1616
if (!module) {
1717
module = [[ETModel alloc] init];
1818
}
19-
19+
2020
[Fetcher fetchResource:[NSURL URLWithString:modelSource]
2121
resourceType:ResourceType::MODEL
2222
completionHandler:^(NSString *filePath, NSError *error) {
23-
if (error) {
24-
reject(@"init_module_error", @"-1", nil);
25-
return;
26-
}
27-
28-
NSNumber *result = [self->module loadModel:filePath];
29-
if ([result isEqualToNumber:@(0)]) {
30-
resolve(result);
31-
} else {
32-
NSError *error = [NSError
33-
errorWithDomain:@"ETModuleErrorDomain"
34-
code:[result intValue]
35-
userInfo:@{
36-
NSLocalizedDescriptionKey : [NSString
37-
stringWithFormat:@"%ld", (long)[result longValue]]
38-
}];
39-
40-
reject(@"init_module_error", error.localizedDescription, error);
41-
}
42-
}];
23+
if (error) {
24+
reject(@"init_module_error", @"-1", nil);
25+
return;
26+
}
27+
28+
NSNumber *result = [self->module loadModel:filePath];
29+
if ([result isEqualToNumber:@(0)]) {
30+
resolve(result);
31+
} else {
32+
NSError *error = [NSError
33+
errorWithDomain:@"ETModuleErrorDomain"
34+
code:[result intValue]
35+
userInfo:@{
36+
NSLocalizedDescriptionKey : [NSString
37+
stringWithFormat:@"%ld", (long)[result longValue]]
38+
}];
39+
40+
reject(@"init_module_error", error.localizedDescription, error);
41+
}
42+
}];
4343
}
4444

4545
- (void)forward:(NSArray *)input
@@ -48,7 +48,7 @@ - (void)forward:(NSArray *)input
4848
resolve:(RCTPromiseResolveBlock)resolve
4949
reject:(RCTPromiseRejectBlock)reject {
5050
@try {
51-
NSArray *result = [module forward:input shape:shape inputType:[NSNumber numberWithInt:inputType]];
51+
NSArray *result = [module forward:input shape:shape inputType:[NSNumber numberWithInt:inputType]];
5252
resolve(result);
5353
} @catch (NSException *exception) {
5454
NSLog(@"An exception occurred: %@, %@", exception.name, exception.reason);
@@ -69,7 +69,7 @@ - (void)loadMethod:(NSString *)methodName
6969
}
7070

7171
- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
72-
(const facebook::react::ObjCTurboModule::InitParams &)params {
72+
(const facebook::react::ObjCTurboModule::InitParams &)params {
7373
return std::make_shared<facebook::react::NativeETModuleSpecJSI>(params);
7474
}
7575

ios/RnExecutorch/LLM.mm

+61-61
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#import "LLM.h"
22
#import <ExecutorchLib/LLaMARunner.h>
3-
#import "utils/ConversationManager.h"
4-
#import "utils/Constants.h"
3+
#import "utils/llms/ConversationManager.h"
4+
#import "utils/llms/Constants.h"
55
#import "utils/Fetcher.h"
66
#import "utils/LargeFileFetcher.h"
77
#import <UIKit/UIKit.h>
@@ -47,77 +47,77 @@ - (void)onResult:(NSString *)token prompt:(NSString *)prompt {
4747

4848
- (void)updateDownloadProgress:(NSNumber *)progress {
4949
dispatch_async(dispatch_get_main_queue(), ^{
50-
[self emitOnDownloadProgress:progress];
50+
[self emitOnDownloadProgress:progress];
5151
});
5252
}
5353

5454
- (void)loadLLM:(NSString *)modelSource tokenizerSource:(NSString *)tokenizerSource systemPrompt:(NSString *)systemPrompt contextWindowLength:(double)contextWindowLength resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject {
55-
NSURL *modelURL = [NSURL URLWithString:modelSource];
56-
NSURL *tokenizerURL = [NSURL URLWithString:tokenizerSource];
57-
58-
if(self->runner || isFetching){
59-
reject(@"model_already_loaded", @"Model and tokenizer already loaded", nil);
55+
NSURL *modelURL = [NSURL URLWithString:modelSource];
56+
NSURL *tokenizerURL = [NSURL URLWithString:tokenizerSource];
57+
58+
if(self->runner || isFetching){
59+
reject(@"model_already_loaded", @"Model and tokenizer already loaded", nil);
60+
return;
61+
}
62+
63+
isFetching = YES;
64+
[Fetcher fetchResource:tokenizerURL resourceType:ResourceType::TOKENIZER completionHandler:^(NSString *tokenizerFilePath, NSError *error) {
65+
if(error){
66+
reject(@"download_error", error.localizedDescription, nil);
6067
return;
6168
}
69+
LargeFileFetcher *modelFetcher = [[LargeFileFetcher alloc] init];
70+
modelFetcher.onProgress = ^(NSNumber *progress) {
71+
[self updateDownloadProgress:progress];
72+
};
6273

63-
isFetching = YES;
64-
[Fetcher fetchResource:tokenizerURL resourceType:ResourceType::TOKENIZER completionHandler:^(NSString *tokenizerFilePath, NSError *error) {
65-
if(error){
66-
reject(@"download_error", error.localizedDescription, nil);
67-
return;
68-
}
69-
LargeFileFetcher *modelFetcher = [[LargeFileFetcher alloc] init];
70-
modelFetcher.onProgress = ^(NSNumber *progress) {
71-
[self updateDownloadProgress:progress];
72-
};
73-
74-
modelFetcher.onFailure = ^(NSError *error){
75-
reject(@"download_error", error.localizedDescription, nil);
76-
return;
77-
};
78-
79-
modelFetcher.onFinish = ^(NSString *modelFilePath) {
80-
self->runner = [[LLaMARunner alloc] initWithModelPath:modelFilePath tokenizerPath:tokenizerFilePath];
81-
NSUInteger contextWindowLengthUInt = (NSUInteger)round(contextWindowLength);
82-
83-
self->conversationManager = [[ConversationManager alloc] initWithNumMessagesContextWindow: contextWindowLengthUInt systemPrompt: systemPrompt];
84-
self->isFetching = NO;
85-
resolve(@"Model and tokenizer loaded successfully");
86-
return;
87-
};
74+
modelFetcher.onFailure = ^(NSError *error){
75+
reject(@"download_error", error.localizedDescription, nil);
76+
return;
77+
};
78+
79+
modelFetcher.onFinish = ^(NSString *modelFilePath) {
80+
self->runner = [[LLaMARunner alloc] initWithModelPath:modelFilePath tokenizerPath:tokenizerFilePath];
81+
NSUInteger contextWindowLengthUInt = (NSUInteger)round(contextWindowLength);
8882

89-
[modelFetcher startDownloadingFileFromURL:modelURL];
90-
}];
83+
self->conversationManager = [[ConversationManager alloc] initWithNumMessagesContextWindow: contextWindowLengthUInt systemPrompt: systemPrompt];
84+
self->isFetching = NO;
85+
resolve(@"Model and tokenizer loaded successfully");
86+
return;
87+
};
88+
89+
[modelFetcher startDownloadingFileFromURL:modelURL];
90+
}];
9191
}
9292

9393

9494
- (void) runInference:(NSString *)input resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject {
95-
[conversationManager addResponse:input senderRole:ChatRole::USER];
96-
NSString *prompt = [conversationManager getConversation];
97-
98-
dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
99-
NSError *error = nil;
100-
[self->runner generate:prompt withTokenCallback:^(NSString *token) {
101-
[self onResult:token prompt:prompt];
102-
} error:&error];
103-
104-
// make sure to add eot token once generation is done
105-
if (![self->tempLlamaResponse hasSuffix:END_OF_TEXT_TOKEN_NS]) {
106-
[self onResult:END_OF_TEXT_TOKEN_NS prompt:prompt];
107-
}
108-
109-
if (self->tempLlamaResponse) {
110-
[self->conversationManager addResponse:self->tempLlamaResponse senderRole:ChatRole::ASSISTANT];
111-
self->tempLlamaResponse = [NSMutableString string];
112-
}
113-
114-
if (error) {
115-
reject(@"error_in_generation", error.localizedDescription, nil);
116-
return;
117-
}
118-
resolve(@"Inference completed successfully");
119-
return;
120-
});
95+
[conversationManager addResponse:input senderRole:ChatRole::USER];
96+
NSString *prompt = [conversationManager getConversation];
97+
98+
dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
99+
NSError *error = nil;
100+
[self->runner generate:prompt withTokenCallback:^(NSString *token) {
101+
[self onResult:token prompt:prompt];
102+
} error:&error];
103+
104+
// make sure to add eot token once generation is done
105+
if (![self->tempLlamaResponse hasSuffix:END_OF_TEXT_TOKEN_NS]) {
106+
[self onResult:END_OF_TEXT_TOKEN_NS prompt:prompt];
107+
}
108+
109+
if (self->tempLlamaResponse) {
110+
[self->conversationManager addResponse:self->tempLlamaResponse senderRole:ChatRole::ASSISTANT];
111+
self->tempLlamaResponse = [NSMutableString string];
112+
}
113+
114+
if (error) {
115+
reject(@"error_in_generation", error.localizedDescription, nil);
116+
return;
117+
}
118+
resolve(@"Inference completed successfully");
119+
return;
120+
});
121121
}
122122

123123

ios/RnExecutorch/StyleTransfer.mm

+18-34
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
#import "StyleTransfer.h"
22
#import "utils/Fetcher.h"
33
#import "models/BaseModel.h"
4+
#import "utils/ETError.h"
5+
#import "ImageProcessor.h"
46
#import <ExecutorchLib/ETModel.h>
57
#import <React/RCTBridgeModule.h>
68
#import "models/StyleTransferModel.h"
7-
#include <string>
9+
#import <opencv2/opencv.hpp>
810

911
@implementation StyleTransfer {
1012
StyleTransferModel* model;
@@ -22,52 +24,34 @@ - (void)loadModule:(NSString *)modelSource
2224
return;
2325
}
2426

25-
NSError *error = [NSError
26-
errorWithDomain:@"StyleTransferErrorDomain"
27-
code:[errorCode intValue]
28-
userInfo:@{
29-
NSLocalizedDescriptionKey : [NSString
30-
stringWithFormat:@"%ld", (long)[errorCode longValue]]
31-
}];
32-
33-
reject(@"init_module_error", error.localizedDescription, error);
27+
reject(@"init_module_error", [NSString
28+
stringWithFormat:@"%ld", (long)[errorCode longValue]], nil);
3429
return;
3530
}];
3631
}
3732

38-
- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
39-
(const facebook::react::ObjCTurboModule::InitParams &)params {
40-
return std::make_shared<facebook::react::NativeStyleTransferSpecJSI>(params);
41-
}
42-
4333
- (void)forward:(NSString *)input
4434
resolve:(RCTPromiseResolveBlock)resolve
4535
reject:(RCTPromiseRejectBlock)reject {
4636
@try {
47-
NSURL *url = [NSURL URLWithString:input];
48-
NSData *data = [NSData dataWithContentsOfURL:url];
49-
if (!data) {
50-
reject(@"img_loading_error", @"Unable to load image data", nil);
51-
return;
52-
}
53-
UIImage *inputImage = [UIImage imageWithData:data];
54-
55-
UIImage* result = [model runModel:inputImage];
56-
57-
// save img to tmp dir, return URI
58-
NSString *outputPath = [NSTemporaryDirectory() stringByAppendingPathComponent:[@"test" stringByAppendingString:@".png"]];
59-
if ([UIImagePNGRepresentation(result) writeToFile:outputPath atomically:YES]) {
60-
NSURL *fileURL = [NSURL fileURLWithPath:outputPath];
61-
resolve([fileURL absoluteString]);
62-
} else {
63-
reject(@"img_write_error", @"Failed to write processed image to file", nil);
64-
}
37+
cv::Mat image = [ImageProcessor readImage:input];
38+
cv::Mat resultImage = [model runModel:image];
6539

40+
NSString* tempFilePath = [ImageProcessor saveToTempFile:resultImage];
41+
resolve(tempFilePath);
42+
return;
6643
} @catch (NSException *exception) {
6744
NSLog(@"An exception occurred: %@, %@", exception.name, exception.reason);
68-
reject(@"result_error", [NSString stringWithFormat:@"%@", exception.reason],
45+
reject(@"forward_error", [NSString stringWithFormat:@"%@", exception.reason],
6946
nil);
47+
return;
7048
}
7149
}
7250

51+
52+
- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
53+
(const facebook::react::ObjCTurboModule::InitParams &)params {
54+
return std::make_shared<facebook::react::NativeStyleTransferSpecJSI>(params);
55+
}
56+
7357
@end

ios/RnExecutorch/models/BaseModel.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
#import <Foundation/Foundation.h>
22
#import <UIKit/UIKit.h>
3-
#import "ETModel.h"
3+
#import "ExecutorchLib/ETModel.h"
44

55
@interface BaseModel : NSObject
66
{
77
@protected
8-
ETModel *module;
8+
ETModel *module;
99
}
1010

11-
- (NSArray *)forward:(NSArray *)input shape:(NSArray *)shape inputType:(NSNumber *)inputType error:(NSError **)error;
11+
- (NSArray *)forward:(NSArray *)input;
1212
- (void)loadModel:(NSURL *)modelURL completion:(void (^)(BOOL success, NSNumber *code))completion;
1313

1414
@end

0 commit comments

Comments
 (0)