Skip to content

Commit 0be2992

Browse files
authored
fix: Change iOS native bindings so they return an array of outputs (#43)
This PR fixes the issue in iOS ExecuTorch bindings where if the model returned multiple output arrays, only the first one was considered. ## Description <!-- Provide a concise and descriptive summary of the changes implemented in this PR. --> ### Type of change - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [x] iOS - [ ] Android ### Testing instructions <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. --> ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [x] I have performed a self-review of my code - [x] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [ ] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
1 parent 15a867e commit 0be2992

File tree

8 files changed

+130
-143
lines changed

8 files changed

+130
-143
lines changed
Binary file not shown.

src/types.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ export interface ExecutorchModule {
2121
error: string | null;
2222
isModelLoading: boolean;
2323
isModelRunning: boolean;
24-
forward: (input: ETInput, shape: number[]) => Promise<number[]>;
24+
forward: (input: ETInput, shape: number[]) => Promise<number[][]>;
2525
loadMethod: (methodName: string) => Promise<void>;
2626
loadForward: () => Promise<void>;
2727
}

third-party/ios/ExecutorchLib/ExecutorchLib.xcodeproj/project.pbxproj

+4-4
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
55EA2C572CB90E7D004315B3 /* Accelerate.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 55EA2C562CB90E7D004315B3 /* Accelerate.framework */; };
4343
55EA2C592CB90E80004315B3 /* CoreML.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 55EA2C582CB90E80004315B3 /* CoreML.framework */; };
4444
55EA2C5B2CB90E85004315B3 /* libsqlite3.tbd in Frameworks */ = {isa = PBXBuildFile; fileRef = 55EA2C5A2CB90E85004315B3 /* libsqlite3.tbd */; };
45-
A851C4062CF9F1B600424E93 /* Utils.mm in Sources */ = {isa = PBXBuildFile; fileRef = A851C4052CF9F1B600424E93 /* Utils.mm */; };
45+
A84198842D02DF29006D4D5E /* InputType.h in Headers */ = {isa = PBXBuildFile; fileRef = A84198832D02DF29006D4D5E /* InputType.h */; };
4646
A851C4072CF9F1B600424E93 /* Utils.hpp in Headers */ = {isa = PBXBuildFile; fileRef = A851C4042CF9F1B600424E93 /* Utils.hpp */; };
4747
/* End PBXBuildFile section */
4848

@@ -83,8 +83,8 @@
8383
55EA2C562CB90E7D004315B3 /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = System/Library/Frameworks/Accelerate.framework; sourceTree = SDKROOT; };
8484
55EA2C582CB90E80004315B3 /* CoreML.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreML.framework; path = System/Library/Frameworks/CoreML.framework; sourceTree = SDKROOT; };
8585
55EA2C5A2CB90E85004315B3 /* libsqlite3.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = libsqlite3.tbd; path = usr/lib/libsqlite3.tbd; sourceTree = SDKROOT; };
86+
A84198832D02DF29006D4D5E /* InputType.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = InputType.h; sourceTree = "<group>"; };
8687
A851C4042CF9F1B600424E93 /* Utils.hpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.h; path = Utils.hpp; sourceTree = "<group>"; };
87-
A851C4052CF9F1B600424E93 /* Utils.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = Utils.mm; sourceTree = "<group>"; };
8888
/* End PBXFileReference section */
8989

9090
/* Begin PBXFrameworksBuildPhase section */
@@ -135,7 +135,7 @@
135135
55EA2C322CB90C7A004315B3 /* sampler */,
136136
55EA2C3E2CB90C7A004315B3 /* tokenizer */,
137137
A851C4042CF9F1B600424E93 /* Utils.hpp */,
138-
A851C4052CF9F1B600424E93 /* Utils.mm */,
138+
A84198832D02DF29006D4D5E /* InputType.h */,
139139
);
140140
path = ExecutorchLib;
141141
sourceTree = "<group>";
@@ -220,6 +220,7 @@
220220
55EA2C542CB90E70004315B3 /* LLaMARunner.h in Headers */,
221221
5576B4B72CEF9709005027B7 /* ETModel.h in Headers */,
222222
55EA2C532CB90C7A004315B3 /* tokenizer.h in Headers */,
223+
A84198842D02DF29006D4D5E /* InputType.h in Headers */,
223224
55EA2C412CB90C7A004315B3 /* stats.h in Headers */,
224225
55EA2C4E2CB90C7A004315B3 /* bpe_tokenizer.h in Headers */,
225226
55EA2C402CB90C7A004315B3 /* runner.h in Headers */,
@@ -308,7 +309,6 @@
308309
buildActionMask = 2147483647;
309310
files = (
310311
55EA2C482CB90C7A004315B3 /* sampler.cpp in Sources */,
311-
A851C4062CF9F1B600424E93 /* Utils.mm in Sources */,
312312
55EA2C3F2CB90C7A004315B3 /* runner.cpp in Sources */,
313313
55EA2C422CB90C7A004315B3 /* text_decoder_runner.cpp in Sources */,
314314
55EA2C4D2CB90C7A004315B3 /* bpe_tokenizer.cpp in Sources */,

third-party/ios/ExecutorchLib/ExecutorchLib/Exported/ETModel.mm

+21-26
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#import "ETModel.h"
22
#include "Utils.hpp"
3+
#include <InputType.h>
34
#include <executorch/extension/module/module.h>
45
#include <executorch/extension/tensor/tensor.h>
56
#include <executorch/runtime/core/error.h>
@@ -34,38 +35,32 @@ - (NSArray *)forward:(NSArray *)input
3435
inputType:(NSNumber *)inputType {
3536
int inputTypeIntValue = [inputType intValue];
3637
std::vector<int> shapes = NSArrayToIntVector(shape);
37-
ssize_t outputNumel = 0;
3838
@try {
3939
switch (inputTypeIntValue) {
40-
case 0: {
41-
// Int8Array
42-
const int8_t *output =
43-
runForwardFromNSArray<int8_t>(input, outputNumel, shapes, _model);
44-
return arrayToNSArray<int8_t>(output, outputNumel);
40+
case InputTypeInt8: {
41+
std::vector<std::span<const int8_t>> output =
42+
runForwardFromNSArray<int8_t>(input, shapes, _model);
43+
return arrayToNSArray<int8_t>(output);
4544
}
46-
case 1: {
47-
// Int32Array
48-
const int32_t *output =
49-
runForwardFromNSArray<int32_t>(input, outputNumel, shapes, _model);
50-
return arrayToNSArray<int32_t>(output, outputNumel);
45+
case InputTypeInt32: {
46+
std::vector<std::span<const int32_t>> output =
47+
runForwardFromNSArray<int32_t>(input, shapes, _model);
48+
return arrayToNSArray<int32_t>(output);
5149
}
52-
case 2: {
53-
// BigInt64Array
54-
const int64_t *output =
55-
runForwardFromNSArray<int64_t>(input, outputNumel, shapes, _model);
56-
return arrayToNSArray<int64_t>(output, outputNumel);
50+
case InputTypeInt64: {
51+
std::vector<std::span<const int64_t>> output =
52+
runForwardFromNSArray<int64_t>(input, shapes, _model);
53+
return arrayToNSArray<int64_t>(output);
5754
}
58-
case 3: {
59-
// Float32Array
60-
const float *output =
61-
runForwardFromNSArray<float>(input, outputNumel, shapes, _model);
62-
return arrayToNSArray<float>(output, outputNumel);
55+
case InputTypeFloat32: {
56+
std::vector<std::span<const float>> output =
57+
runForwardFromNSArray<float>(input, shapes, _model);
58+
return arrayToNSArray<float>(output);
6359
}
64-
case 4: {
65-
// Float64Array
66-
const double *output =
67-
runForwardFromNSArray<double>(input, outputNumel, shapes, _model);
68-
return arrayToNSArray<double>(output, outputNumel);
60+
case InputTypeFloat64: {
61+
std::vector<std::span<const double>> output =
62+
runForwardFromNSArray<double>(input, shapes, _model);
63+
return arrayToNSArray<double>(output);
6964
}
7065
}
7166
} @catch (NSException *exception) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#ifndef InputType_h
2+
#define InputType_h
3+
4+
typedef NS_ENUM(NSInteger, InputType) {
5+
InputTypeInt8 = 0,
6+
InputTypeInt32 = 1,
7+
InputTypeInt64 = 2,
8+
InputTypeFloat32 = 3,
9+
InputTypeFloat64 = 4,
10+
};
11+
12+
#endif /* InputType_h */

third-party/ios/ExecutorchLib/ExecutorchLib/Utils.hpp

+92-7
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,109 @@
77
#include <memory>
88
#include <string>
99
#include <vector>
10+
#include <span>
1011

1112
#ifdef __OBJC__
1213
#import <Foundation/Foundation.h>
1314
#endif
1415

15-
template <typename T> NSArray *arrayToNSArray(const void *array, ssize_t numel);
16+
using namespace ::executorch::extension;
17+
using namespace ::torch::executor;
1618

17-
std::vector<int> NSArrayToIntVector(NSArray *inputArray);
19+
template <typename T> T getValueFromNSNumber(NSNumber *number) {
20+
if constexpr (std::is_same<T, int8_t>::value) {
21+
return static_cast<T>([number charValue]); // `charValue` for 8-bit integers
22+
} else if constexpr (std::is_same<T, int32_t>::value) {
23+
return static_cast<T>([number intValue]); // `intValue` for 32-bit integers
24+
} else if constexpr (std::is_same<T, int64_t>::value) {
25+
return static_cast<T>(
26+
[number longLongValue]); // `longLongValue` for 64-bit integers
27+
} else if constexpr (std::is_same<T, float>::value) {
28+
return static_cast<T>([number floatValue]);
29+
} else if constexpr (std::is_same<T, double>::value) {
30+
return static_cast<T>([number doubleValue]);
31+
}
32+
}
1833

1934
template <typename T>
20-
std::unique_ptr<T[]> NSArrayToTypedArray(NSArray *nsArray);
35+
std::unique_ptr<T[]> NSArrayToTypedArray(NSArray *nsArray) {
36+
size_t arraySize = [nsArray count];
2137

22-
template <typename T> T getValueFromNSNumber(NSNumber *number);
38+
std::unique_ptr<T[]> typedArray(new T[arraySize]);
39+
40+
for (NSUInteger i = 0; i < arraySize; ++i) {
41+
NSNumber *number = [nsArray objectAtIndex:i];
42+
if ([number isKindOfClass:[NSNumber class]]) {
43+
typedArray[i] = getValueFromNSNumber<T>(number);
44+
} else {
45+
typedArray[i] = T();
46+
}
47+
}
48+
return typedArray;
49+
}
50+
51+
template <typename T>
52+
NSArray *arrayToNSArray(const void *array, ssize_t numel) {
53+
const T *typedArray = static_cast<const T *>(array);
54+
NSMutableArray *nsArray = [NSMutableArray arrayWithCapacity:numel];
55+
56+
for (int i = 0; i < numel; ++i) {
57+
[nsArray addObject:@(typedArray[i])];
58+
}
59+
60+
return [nsArray copy];
61+
}
2362

2463
template <typename T>
25-
const T*
26-
runForwardFromNSArray(NSArray *inputArray, ssize_t& numel, std::vector<int> shapes,
27-
std::unique_ptr<executorch::extension::Module> &model);
64+
NSArray *arrayToNSArray(const std::vector<std::span<const T>> &dataPtrVec) {
65+
NSMutableArray *nsArray = [NSMutableArray array];
66+
for (const auto &span : dataPtrVec) {
67+
NSMutableArray *innerArray = [NSMutableArray arrayWithCapacity:span.size()];
68+
for(auto x : span) {
69+
[innerArray addObject:@(x)];
70+
}
71+
[nsArray addObject:[innerArray copy]];
72+
}
73+
return [nsArray copy];
74+
}
75+
76+
std::vector<int> NSArrayToIntVector(NSArray *inputArray) {
77+
std::vector<int> output;
78+
for (NSUInteger i = 0; i < [inputArray count]; ++i) {
79+
NSNumber *number = [inputArray objectAtIndex:i];
80+
if (number) {
81+
output.push_back([number intValue]);
82+
} else {
83+
output.push_back(0);
84+
}
85+
}
86+
return output;
87+
}
88+
89+
template <typename T>
90+
std::vector<std::span<const T>>
91+
runForwardFromNSArray(NSArray *inputArray, std::vector<int> shapes,
92+
std::unique_ptr<Module> &model) {
93+
std::unique_ptr<T[]> inputPtr = NSArrayToTypedArray<T>(inputArray);
94+
95+
TensorPtr inputTensor = from_blob(inputPtr.get(), shapes);
96+
Result result = model->forward(inputTensor);
97+
98+
if (result.ok()) {
99+
std::vector<std::span<const T>> outputVec;
100+
101+
for (const auto &currentResult : *result) {
102+
Tensor currentTensor = currentResult.toTensor();
103+
std::span<const T> currentSpan(currentTensor.const_data_ptr<T>(), currentTensor.numel());
104+
outputVec.push_back(std::move(currentSpan));
105+
}
106+
return outputVec;
107+
}
28108

109+
@throw [NSException
110+
exceptionWithName:@"forward_error"
111+
reason:[NSString stringWithFormat:@"%d", (int)result.error()]
112+
userInfo:nil];
113+
}
29114

30115
#endif // Utils_hpp

third-party/ios/ExecutorchLib/ExecutorchLib/Utils.mm

-105
This file was deleted.

0 commit comments

Comments
 (0)