Skip to content

Commit 985b8a0

Browse files
authored
feat: Make third-party/ExecuTorchLib's forward() accept multiple inputs (#83)
## Description This PR allows for passing and receiving multiple inputs from the native ET module. ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [x] iOS - [ ] Android ### Testing instructions <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. --> ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [ ] I have performed a self-review of my code - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [ ] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
1 parent 44c7eb8 commit 985b8a0

File tree

17 files changed

+462
-279
lines changed

17 files changed

+462
-279
lines changed

android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt

+28-13
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@ import com.facebook.react.bridge.ReadableArray
77
import com.swmansion.rnexecutorch.utils.ArrayUtils
88
import com.swmansion.rnexecutorch.utils.ETError
99
import com.swmansion.rnexecutorch.utils.TensorUtils
10+
import org.pytorch.executorch.EValue
1011
import org.pytorch.executorch.Module
1112
import java.net.URL
1213

1314
class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(reactContext) {
1415
private lateinit var module: Module
15-
16+
private var reactApplicationContext = reactContext;
1617
override fun getName(): String {
1718
return NAME
1819
}
@@ -33,26 +34,40 @@ class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(react
3334
}
3435

3536
override fun forward(
36-
input: ReadableArray,
37-
shape: ReadableArray,
38-
inputType: Double,
37+
inputs: ReadableArray,
38+
shapes: ReadableArray,
39+
inputTypes: ReadableArray,
3940
promise: Promise
4041
) {
42+
val inputEValues = ArrayList<EValue>()
4143
try {
42-
val executorchInput =
43-
TensorUtils.getExecutorchInput(input, ArrayUtils.createLongArray(shape), inputType.toInt())
44+
for (i in 0 until inputs.size()) {
45+
val currentInput = inputs.getArray(i)
46+
?: throw Exception(ETError.InvalidArgument.code.toString())
47+
val currentShape = shapes.getArray(i)
48+
?: throw Exception(ETError.InvalidArgument.code.toString())
49+
val currentInputType = inputTypes.getInt(i)
4450

45-
val result = module.forward(executorchInput)
46-
val resultArray = Arguments.createArray()
51+
val currentEValue = TensorUtils.getExecutorchInput(
52+
currentInput,
53+
ArrayUtils.createLongArray(currentShape),
54+
currentInputType
55+
)
4756

48-
for (evalue in result) {
49-
resultArray.pushArray(ArrayUtils.createReadableArray(evalue.toTensor()))
57+
inputEValues.add(currentEValue)
5058
}
5159

52-
promise.resolve(resultArray)
53-
return
60+
val forwardOutputs = module.forward(*inputEValues.toTypedArray());
61+
val outputArray = Arguments.createArray()
62+
63+
for (output in forwardOutputs) {
64+
val arr = ArrayUtils.createReadableArrayFromTensor(output.toTensor())
65+
outputArray.pushArray(arr)
66+
}
67+
promise.resolve(outputArray)
68+
5469
} catch (e: IllegalArgumentException) {
55-
//The error is thrown when transformation to Tensor fails
70+
// The error is thrown when transformation to Tensor fails
5671
promise.reject("Forward Failed Execution", ETError.InvalidArgument.code.toString())
5772
return
5873
} catch (e: Exception) {

android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt

+18-48
Original file line numberDiff line numberDiff line change
@@ -7,82 +7,52 @@ import org.pytorch.executorch.Tensor
77

88
class ArrayUtils {
99
companion object {
10-
fun createByteArray(input: ReadableArray): ByteArray {
11-
val byteArray = ByteArray(input.size())
12-
for (i in 0 until input.size()) {
13-
byteArray[i] = input.getInt(i).toByte()
14-
}
15-
return byteArray
10+
private inline fun <reified T> createTypedArrayFromReadableArray(input: ReadableArray, transform: (ReadableArray, Int) -> T): Array<T> {
11+
return Array(input.size()) { index -> transform(input, index) }
1612
}
1713

14+
fun createByteArray(input: ReadableArray): ByteArray {
15+
return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toByte() }.toByteArray()
16+
}
1817
fun createIntArray(input: ReadableArray): IntArray {
19-
val intArray = IntArray(input.size())
20-
for (i in 0 until input.size()) {
21-
intArray[i] = input.getInt(i)
22-
}
23-
return intArray
18+
return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index) }.toIntArray()
2419
}
2520

2621
fun createFloatArray(input: ReadableArray): FloatArray {
27-
val floatArray = FloatArray(input.size())
28-
for (i in 0 until input.size()) {
29-
floatArray[i] = input.getDouble(i).toFloat()
30-
}
31-
return floatArray
22+
return createTypedArrayFromReadableArray(input) { array, index -> array.getDouble(index).toFloat() }.toFloatArray()
3223
}
3324

3425
fun createLongArray(input: ReadableArray): LongArray {
35-
val longArray = LongArray(input.size())
36-
for (i in 0 until input.size()) {
37-
longArray[i] = input.getInt(i).toLong()
38-
}
39-
return longArray
26+
return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toLong() }.toLongArray()
4027
}
4128

4229
fun createDoubleArray(input: ReadableArray): DoubleArray {
43-
val doubleArray = DoubleArray(input.size())
44-
for (i in 0 until input.size()) {
45-
doubleArray[i] = input.getDouble(i)
46-
}
47-
return doubleArray
30+
return createTypedArrayFromReadableArray(input) { array, index -> array.getDouble(index) }.toDoubleArray()
4831
}
49-
50-
fun createReadableArray(result: Tensor): ReadableArray {
32+
fun createReadableArrayFromTensor(result: Tensor): ReadableArray {
5133
val resultArray = Arguments.createArray()
34+
5235
when (result.dtype()) {
5336
DType.UINT8 -> {
54-
val byteArray = result.dataAsByteArray
55-
for (i in byteArray) {
56-
resultArray.pushInt(i.toInt())
57-
}
37+
result.dataAsByteArray.forEach { resultArray.pushInt(it.toInt()) }
5838
}
5939

6040
DType.INT32 -> {
61-
val intArray = result.dataAsIntArray
62-
for (i in intArray) {
63-
resultArray.pushInt(i)
64-
}
41+
result.dataAsIntArray.forEach { resultArray.pushInt(it) }
6542
}
6643

6744
DType.FLOAT -> {
68-
val longArray = result.dataAsFloatArray
69-
for (i in longArray) {
70-
resultArray.pushDouble(i.toDouble())
71-
}
45+
result.dataAsFloatArray.forEach { resultArray.pushDouble(it.toDouble()) }
7246
}
7347

7448
DType.DOUBLE -> {
75-
val floatArray = result.dataAsDoubleArray
76-
for (i in floatArray) {
77-
resultArray.pushDouble(i)
78-
}
49+
result.dataAsDoubleArray.forEach { resultArray.pushDouble(it) }
7950
}
8051

8152
DType.INT64 -> {
82-
val doubleArray = result.dataAsLongArray
83-
for (i in doubleArray) {
84-
resultArray.pushLong(i)
85-
}
53+
// TODO: Do something to handle or deprecate long dtype
54+
// https://github.com/facebook/react-native/issues/12506
55+
result.dataAsLongArray.forEach { resultArray.pushInt(it.toInt()) }
8656
}
8757

8858
else -> {

android/src/main/java/com/swmansion/rnexecutorch/utils/TensorUtils.kt

+5-9
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,23 @@ class TensorUtils {
1212
fun getExecutorchInput(input: ReadableArray, shape: LongArray, type: Int): EValue {
1313
try {
1414
when (type) {
15-
0 -> {
15+
1 -> {
1616
val inputTensor = Tensor.fromBlob(ArrayUtils.createByteArray(input), shape)
1717
return EValue.from(inputTensor)
1818
}
19-
20-
1 -> {
19+
3 -> {
2120
val inputTensor = Tensor.fromBlob(ArrayUtils.createIntArray(input), shape)
2221
return EValue.from(inputTensor)
2322
}
24-
25-
2 -> {
23+
4 -> {
2624
val inputTensor = Tensor.fromBlob(ArrayUtils.createLongArray(input), shape)
2725
return EValue.from(inputTensor)
2826
}
29-
30-
3 -> {
27+
6 -> {
3128
val inputTensor = Tensor.fromBlob(ArrayUtils.createFloatArray(input), shape)
3229
return EValue.from(inputTensor)
3330
}
34-
35-
4 -> {
31+
7 -> {
3632
val inputTensor = Tensor.fromBlob(ArrayUtils.createDoubleArray(input), shape)
3733
return EValue.from(inputTensor)
3834
}

ios/RnExecutorch/ETModule.mm

+13-9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#import "ETModule.h"
22
#import <ExecutorchLib/ETModel.h>
3+
#include <Foundation/Foundation.h>
34
#import <React/RCTBridgeModule.h>
45
#include <string>
56

@@ -36,20 +37,23 @@ - (void)loadModule:(NSString *)modelSource
3637
resolve(result);
3738
}
3839

39-
- (void)forward:(NSArray *)input
40-
shape:(NSArray *)shape
41-
inputType:(double)inputType
40+
- (void)forward:(NSArray *)inputs
41+
shapes:(NSArray *)shapes
42+
inputTypes:(NSArray *)inputTypes
4243
resolve:(RCTPromiseResolveBlock)resolve
4344
reject:(RCTPromiseRejectBlock)reject {
4445
@try {
45-
NSArray *result = [module forward:input
46-
shape:shape
47-
inputType:[NSNumber numberWithInt:inputType]];
46+
NSArray *result = [module forward:inputs
47+
shapes:shapes
48+
inputTypes:inputTypes];
4849
resolve(result);
4950
} @catch (NSException *exception) {
50-
NSLog(@"An exception occurred: %@, %@", exception.name, exception.reason);
51-
reject(@"result_error", [NSString stringWithFormat:@"%@", exception.reason],
52-
nil);
51+
NSLog(@"An exception occurred in forward: %@, %@", exception.name,
52+
exception.reason);
53+
reject(
54+
@"forward_error",
55+
[NSString stringWithFormat:@"An error occurred: %@", exception.reason],
56+
nil);
5357
}
5458
}
5559

ios/RnExecutorch/models/BaseModel.h

+5
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
}
99

1010
- (NSArray *)forward:(NSArray *)input;
11+
12+
- (NSArray *)forward:(NSArray *)inputs
13+
shapes:(NSArray *)shapes
14+
inputTypes:(NSArray *)inputTypes;
15+
1116
- (void)loadModel:(NSURL *)modelURL
1217
completion:(void (^)(BOOL success, NSNumber *code))completion;
1318

ios/RnExecutorch/models/BaseModel.mm

+17-4
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,29 @@
44
@implementation BaseModel
55

66
- (NSArray *)forward:(NSArray *)input {
7-
NSArray *result = [module forward:input
8-
shape:[module getInputShape:@0]
9-
inputType:[module getInputType:@0]];
7+
NSMutableArray *shapes = [NSMutableArray new];
8+
NSMutableArray *inputTypes = [NSMutableArray new];
9+
NSNumber *numberOfInputs = [module getNumberOfInputs];
10+
11+
for (NSUInteger i = 0; i < [numberOfInputs intValue]; i++) {
12+
[shapes addObject:[module getInputShape:[NSNumber numberWithInt:i]]];
13+
[inputTypes addObject:[module getInputType:[NSNumber numberWithInt:i]]];
14+
}
15+
16+
NSArray *result = [module forward:@[input] shapes:shapes inputTypes:inputTypes];
17+
return result;
18+
}
19+
20+
- (NSArray *)forward:(NSArray *)inputs
21+
shapes:(NSArray *)shapes
22+
inputTypes:(NSArray *)inputTypes {
23+
NSArray *result = [module forward:inputs shapes:shapes inputTypes:inputTypes];
1024
return result;
1125
}
1226

1327
- (void)loadModel:(NSURL *)modelURL
1428
completion:(void (^)(BOOL success, NSNumber *code))completion {
1529
module = [[ETModel alloc] init];
16-
1730
NSNumber *result = [self->module loadModel:modelURL.path];
1831
if ([result intValue] != 0) {
1932
completion(NO, result);

lefthook.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ pre-commit:
66
run: npx eslint {staged_files}
77
types:
88
glob: '*.{js,ts, jsx, tsx}'
9-
run: npx tsc
9+
run: npx tsc --noEmit

src/hooks/general/useExecutorchModule.ts

+4-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ export const useExecutorchModule = ({
1515
isReady: boolean;
1616
isGenerating: boolean;
1717
downloadProgress: number;
18-
forward: (input: ETInput, shape: number[]) => Promise<number[][]>;
18+
forward: (
19+
input: ETInput | ETInput[],
20+
shape: number[] | number[][]
21+
) => Promise<number[][]>;
1922
loadMethod: (methodName: string) => Promise<void>;
2023
loadForward: () => Promise<void>;
2124
} => {

src/hooks/useModule.ts

+45-8
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
import { useEffect, useState } from 'react';
22
import { fetchResource } from '../utils/fetchResource';
33
import { ETError, getError } from '../Error';
4-
import { ETInput, Module, getTypeIdentifier } from '../types/common';
4+
import { ETInput, Module } from '../types/common';
5+
import { _ETModule } from '../native/RnExecutorchModules';
6+
7+
export const getTypeIdentifier = (input: ETInput): number => {
8+
if (input instanceof Int8Array) return 1;
9+
if (input instanceof Int32Array) return 3;
10+
if (input instanceof BigInt64Array) return 4;
11+
if (input instanceof Float32Array) return 6;
12+
if (input instanceof Float64Array) return 7;
13+
return -1;
14+
};
515

616
interface Props {
717
modelSource: string | number;
@@ -13,7 +23,10 @@ interface _Module {
1323
isReady: boolean;
1424
isGenerating: boolean;
1525
downloadProgress: number;
16-
forwardETInput: (input: ETInput, shape: number[]) => Promise<any>;
26+
forwardETInput: (
27+
input: ETInput[] | ETInput,
28+
shape: number[][] | number[]
29+
) => ReturnType<_ETModule['forward']>;
1730
forwardImage: (input: string) => Promise<any>;
1831
}
1932

@@ -59,23 +72,47 @@ export const useModule = ({ modelSource, module }: Props): _Module => {
5972
}
6073
};
6174

62-
const forwardETInput = async (input: ETInput, shape: number[]) => {
75+
const forwardETInput = async (
76+
input: ETInput[] | ETInput,
77+
shape: number[][] | number[]
78+
) => {
6379
if (!isReady) {
6480
throw new Error(getError(ETError.ModuleNotLoaded));
6581
}
6682
if (isGenerating) {
6783
throw new Error(getError(ETError.ModelGenerating));
6884
}
6985

70-
const inputType = getTypeIdentifier(input);
71-
if (inputType === -1) {
72-
throw new Error(getError(ETError.InvalidArgument));
86+
// Since the native module expects an array of inputs and an array of shapes,
87+
// if the user provides a single ETInput, we want to "unsqueeze" the array so
88+
// the data is properly processed on the native side
89+
if (!Array.isArray(input)) {
90+
input = [input];
91+
}
92+
93+
if (!Array.isArray(shape[0])) {
94+
shape = [shape] as number[][];
95+
}
96+
97+
let inputTypeIdentifiers: any[] = [];
98+
let modelInputs: any[] = [];
99+
100+
for (let idx = 0; idx < input.length; idx++) {
101+
let currentInputTypeIdentifier = getTypeIdentifier(input[idx] as ETInput);
102+
if (currentInputTypeIdentifier === -1) {
103+
throw new Error(getError(ETError.InvalidArgument));
104+
}
105+
inputTypeIdentifiers.push(currentInputTypeIdentifier);
106+
modelInputs.push([...(input[idx] as ETInput)]);
73107
}
74108

75109
try {
76-
const numberArray = [...input];
77110
setIsGenerating(true);
78-
const output = await module.forward(numberArray, shape, inputType);
111+
const output = await module.forward(
112+
modelInputs,
113+
shape,
114+
inputTypeIdentifiers
115+
);
79116
setIsGenerating(false);
80117
return output;
81118
} catch (e) {

0 commit comments

Comments
 (0)