
Commit 69802ee

Authored by Mateusz Kopciński (mkopcins)
chore: refactor ts cv hooks (#64)
## Description

### Type of change

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Documentation update (improves or adds clarity to existing documentation)

### Tested on

- [ ] iOS
- [x] Android

### Checklist

- [x] I have performed a self-review of my code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [x] I have updated the documentation accordingly
- [x] My changes generate no new warnings

---

Co-authored-by: Mateusz Kopcinski <[email protected]>
Co-authored-by: Mateusz Kopciński <[email protected]>
1 parent 6251d27

21 files changed · +317 −300 lines

android/src/main/java/com/swmansion/rnexecutorch/ObjectDetection.kt

+1 −1

```diff
@@ -35,7 +35,7 @@ class ObjectDetection(reactContext: ReactApplicationContext) :
       ssdLiteLarge.loadModel(modelSource)
       promise.resolve(0)
     } catch (e: Exception) {
-      promise.reject(e.message!!, ETError.InvalidModelPath.toString())
+      promise.reject(e.message!!, ETError.InvalidModelSource.toString())
     }
   }
 
```

docs/docs/guides/running-llms.md

+3 −3

````diff
@@ -23,7 +23,7 @@ const llama = useLLM({
 });
 ```
 
-The code snippet above fetches the model from the specified URL, loads it into memory, and returns an object with various methods and properties for controlling the model. You can monitor the loading progress by checking the `llama.downloadProgress` and `llama.isModelReady` property, and if anything goes wrong, the `llama.error` property will contain the error message.
+The code snippet above fetches the model from the specified URL, loads it into memory, and returns an object with various methods and properties for controlling the model. You can monitor the loading progress by checking the `llama.downloadProgress` and `llama.isReady` property, and if anything goes wrong, the `llama.error` property will contain the error message.
 
 :::danger[Danger]
 Lower-end devices might not be able to fit LLMs into memory. We recommend using quantized models to reduce the memory footprint.
@@ -50,9 +50,9 @@ Given computational constraints, our architecture is designed to support only on
 | `generate`          | `(input: string) => Promise<void>` | Function to start generating a response with the given input string. |
 | `response`          | `string`                           | State of the generated response. This field is updated with each token generated by the model |
 | `error`             | <code>string &#124; null</code>    | Contains the error message if the model failed to load |
-| `isModelGenerating` | `boolean`                          | Indicates whether the model is currently generating a response |
+| `isGenerating`      | `boolean`                          | Indicates whether the model is currently generating a response |
 | `interrupt`         | `() => void`                       | Function to interrupt the current inference |
-| `isModelReady`      | `boolean`                          | Indicates whether the model is ready |
+| `isReady`           | `boolean`                          | Indicates whether the model is ready |
 | `downloadProgress`  | `number`                           | Represents the download progress as a value between 0 and 1, indicating the extent of the model file retrieval. |
 
 ### Loading models
````
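These renames change how apps gate rendering on model state. Below is a minimal sketch of a component written against the renamed API; the model source URL is a hypothetical placeholder, and the options are trimmed to `modelSource` for brevity (the real hook may require additional options, such as a tokenizer source).

```tsx
import React from 'react';
import { Text } from 'react-native';
import { useLLM } from 'react-native-executorch';

// Hypothetical model source, for illustration only.
const MODEL_SOURCE = 'https://example.com/llama-3.2-1b.pte';

export default function LlamaStatus() {
  const llama = useLLM({ modelSource: MODEL_SOURCE });

  // `isReady` replaces `isModelReady`; `downloadProgress` is unchanged.
  if (!llama.isReady) {
    return <Text>Loading {(llama.downloadProgress * 100).toFixed(0)}%</Text>;
  }

  // `isGenerating` replaces `isModelGenerating`; `response` streams tokens.
  return <Text>{llama.isGenerating ? llama.response : 'Ready'}</Text>;
}
```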

examples/computer-vision/screens/ClassificationScreen.tsx

+2 −5

```diff
@@ -44,12 +44,9 @@ export const ClassificationScreen = ({
     }
   };
 
-  if (!model.isModelReady) {
+  if (!model.isReady) {
     return (
-      <Spinner
-        visible={!model.isModelReady}
-        textContent={`Loading the model...`}
-      />
+      <Spinner visible={!model.isReady} textContent={`Loading the model...`} />
     );
   }
 
```

examples/computer-vision/screens/ObjectDetectionScreen.tsx

+2 −2

```diff
@@ -52,10 +52,10 @@ export const ObjectDetectionScreen = ({
     }
   };
 
-  if (!ssdLite.isModelReady) {
+  if (!ssdLite.isReady) {
     return (
       <Spinner
-        visible={!ssdLite.isModelReady}
+        visible={!ssdLite.isReady}
         textContent={`Loading the model...`}
       />
     );
```

examples/computer-vision/screens/StyleTransferScreen.tsx

+2 −5

```diff
@@ -37,12 +37,9 @@ export const StyleTransferScreen = ({
     }
   };
 
-  if (!model.isModelReady) {
+  if (!model.isReady) {
     return (
-      <Spinner
-        visible={!model.isModelReady}
-        textContent={`Loading the model...`}
-      />
+      <Spinner visible={!model.isReady} textContent={`Loading the model...`} />
     );
   }
 
```

examples/computer-vision/yarn.lock

+5 −5

```diff
@@ -3352,7 +3352,7 @@ __metadata:
     metro-config: ^0.81.0
     react: 18.3.1
     react-native: 0.76.3
-    react-native-executorch: ../../react-native-executorch-0.1.100.tgz
+    react-native-executorch: ^0.1.3
     react-native-image-picker: ^7.2.2
     react-native-loading-spinner-overlay: ^3.0.1
     react-native-reanimated: ^3.16.3
@@ -6799,13 +6799,13 @@
   languageName: node
   linkType: hard
 
-"react-native-executorch@file:../../react-native-executorch-0.1.100.tgz::locator=computer-vision%40workspace%3A.":
-  version: 0.1.100
-  resolution: "react-native-executorch@file:../../react-native-executorch-0.1.100.tgz::locator=computer-vision%40workspace%3A."
+"react-native-executorch@npm:^0.1.3":
+  version: 0.1.3
+  resolution: "react-native-executorch@npm:0.1.3"
   peerDependencies:
     react: "*"
     react-native: "*"
-  checksum: f258452e2050df59e150938f6482ef8eee5fbd4ef4fc4073a920293ca87d543daddf76c560701d0c2626e6677d964b446dad8e670e978ea4f80d0a1bd17dfa03
+  checksum: b49f8ca9ae8a7de4a7f2263887537626859507c7d60af47360515b405c7778b48c4c067074e7ce7857782a6737cf47cf5dadada03ae9319a6bf577f8490f431d
   languageName: node
   linkType: hard
 
```

examples/llama/components/Messages.tsx

+3 −3

```diff
@@ -9,13 +9,13 @@ import MessageItem from './MessageItem';
 interface MessagesComponentProps {
   chatHistory: Array<MessageType>;
   llmResponse: string;
-  isModelGenerating: boolean;
+  isGenerating: boolean;
 }
 
 export default function Messages({
   chatHistory,
   llmResponse,
-  isModelGenerating,
+  isGenerating,
 }: MessagesComponentProps) {
   const scrollViewRef = useRef<ScrollView>(null);
 
@@ -29,7 +29,7 @@ export default function Messages({
       {chatHistory.map((message, index) => (
         <MessageItem key={index} message={message} />
       ))}
-      {isModelGenerating && (
+      {isGenerating && (
         <View style={styles.aiMessage}>
           <View style={styles.aiMessageIconContainer}>
             <LlamaIcon width={24} height={24} />
```

examples/llama/screens/ChatScreen.tsx

+7 −7

```diff
@@ -31,10 +31,10 @@ export default function ChatScreen() {
   });
   const textInputRef = useRef<TextInput>(null);
   useEffect(() => {
-    if (llama.response && !llama.isModelGenerating) {
+    if (llama.response && !llama.isGenerating) {
       appendToMessageHistory(llama.response, 'ai');
     }
-  }, [llama.response, llama.isModelGenerating]);
+  }, [llama.response, llama.isGenerating]);
 
   const appendToMessageHistory = (input: string, role: SenderType) => {
     setChatHistory((prevHistory) => [
@@ -54,9 +54,9 @@ export default function ChatScreen() {
     }
   };
 
-  return !llama.isModelReady ? (
+  return !llama.isReady ? (
     <Spinner
-      visible={!llama.isModelReady}
+      visible={!llama.isReady}
       textContent={`Loading the model ${(llama.downloadProgress * 100).toFixed(0)} %`}
     />
   ) : (
@@ -76,7 +76,7 @@ export default function ChatScreen() {
           <Messages
             chatHistory={chatHistory}
             llmResponse={llama.response}
-            isModelGenerating={llama.isModelGenerating}
+            isGenerating={llama.isGenerating}
           />
         </View>
       ) : (
@@ -108,13 +108,13 @@ export default function ChatScreen() {
             <TouchableOpacity
               style={styles.sendChatTouchable}
               onPress={async () =>
-                !llama.isModelGenerating && (await sendMessage())
+                !llama.isGenerating && (await sendMessage())
               }
             >
               <SendIcon height={24} width={24} padding={4} margin={8} />
             </TouchableOpacity>
           )}
-          {llama.isModelGenerating && (
+          {llama.isGenerating && (
             <TouchableOpacity
               style={styles.sendChatTouchable}
               onPress={llama.interrupt}
```

src/ETModule.ts

+22 −66

```diff
@@ -1,18 +1,8 @@
-import { useEffect, useState } from 'react';
-import { Image } from 'react-native';
-import { ETModule } from './native/RnExecutorchModules';
-import { ETError, getError } from './Error';
-import { ETInput, ExecutorchModule } from './types';
-
-const getTypeIdentifier = (arr: ETInput): number => {
-  if (arr instanceof Int8Array) return 0;
-  if (arr instanceof Int32Array) return 1;
-  if (arr instanceof BigInt64Array) return 2;
-  if (arr instanceof Float32Array) return 3;
-  if (arr instanceof Float64Array) return 4;
-
-  return -1;
-};
+import { useState } from 'react';
+import { _ETModule } from './native/RnExecutorchModules';
+import { getError } from './Error';
+import { ExecutorchModule } from './types/common';
+import { useModule } from './useModule';
 
 interface Props {
   modelSource: string | number;
@@ -21,54 +11,20 @@ interface Props {
 export const useExecutorchModule = ({
   modelSource,
 }: Props): ExecutorchModule => {
-  const [error, setError] = useState<string | null>(null);
-  const [isModelLoading, setIsModelLoading] = useState(true);
-  const [isModelGenerating, setIsModelGenerating] = useState(false);
-
-  useEffect(() => {
-    const loadModel = async () => {
-      let path = modelSource;
-      if (typeof modelSource === 'number') {
-        path = Image.resolveAssetSource(modelSource).uri;
-      }
-
-      try {
-        setIsModelLoading(true);
-        await ETModule.loadModule(path);
-        setIsModelLoading(false);
-      } catch (e: unknown) {
-        setError(getError(e));
-        setIsModelLoading(false);
-      }
-    };
-    loadModel();
-  }, [modelSource]);
-
-  const forward = async (input: ETInput, shape: number[]) => {
-    if (isModelLoading) {
-      throw new Error(getError(ETError.ModuleNotLoaded));
-    }
-
-    const inputType = getTypeIdentifier(input);
-    if (inputType === -1) {
-      throw new Error(getError(ETError.InvalidArgument));
-    }
-
-    try {
-      const numberArray = [...input];
-      setIsModelGenerating(true);
-      const output = await ETModule.forward(numberArray, shape, inputType);
-      setIsModelGenerating(false);
-      return output;
-    } catch (e) {
-      setIsModelGenerating(false);
-      throw new Error(getError(e));
-    }
-  };
+  const [module] = useState(() => new _ETModule());
+  const {
+    error,
+    isReady,
+    isGenerating,
+    forwardETInput: forward,
+  } = useModule({
+    modelSource,
+    module,
+  });
 
   const loadMethod = async (methodName: string) => {
     try {
-      await ETModule.loadMethod(methodName);
+      await module.loadMethod(methodName);
     } catch (e) {
       throw new Error(getError(e));
     }
@@ -79,11 +35,11 @@ export const useExecutorchModule = ({
   };
 
   return {
-    error: error,
-    isModelLoading: isModelLoading,
-    isModelGenerating: isModelGenerating,
-    forward: forward,
-    loadMethod: loadMethod,
-    loadForward: loadForward,
+    error,
+    isReady,
+    isGenerating,
+    forward,
+    loadMethod,
+    loadForward,
   };
 };
```
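The hook's inline loading `useEffect`, `getTypeIdentifier`, and forward plumbing move into the shared `useModule` hook, while the public surface gains the renamed flags. A rough usage sketch follows; the model source and tensor shape are illustrative, and it assumes `useExecutorchModule` is re-exported from the package root.

```ts
import { useExecutorchModule } from 'react-native-executorch';

export function useImageModel() {
  const model = useExecutorchModule({
    modelSource: 'https://example.com/model.pte', // hypothetical source
  });

  // `forward` still takes a typed array plus its shape, as in the removed
  // implementation; the shape below is only an example.
  const run = async (pixels: Float32Array) => {
    if (!model.isReady) return null; // `isReady` replaces the inverted `isModelLoading`
    return model.forward(pixels, [1, 3, 224, 224]);
  };

  return { run, isGenerating: model.isGenerating };
}
```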

src/LLM.ts

+14 −12

```diff
@@ -1,6 +1,6 @@
 import { useCallback, useEffect, useRef, useState } from 'react';
 import { EventSubscription, Image } from 'react-native';
-import { ResourceSource, Model } from './types';
+import { ResourceSource, Model } from './types/common';
 import {
   DEFAULT_CONTEXT_WINDOW_LENGTH,
   DEFAULT_SYSTEM_PROMPT,
@@ -24,8 +24,8 @@ export const useLLM = ({
   contextWindowLength?: number;
 }): Model => {
   const [error, setError] = useState<string | null>(null);
-  const [isModelReady, setIsModelReady] = useState(false);
-  const [isModelGenerating, setIsModelGenerating] = useState(false);
+  const [isReady, setIsReady] = useState(false);
+  const [isGenerating, setIsGenerating] = useState(false);
   const [response, setResponse] = useState('');
   const [downloadProgress, setDownloadProgress] = useState(0);
   const downloadProgressListener = useRef<null | EventSubscription>(null);
@@ -65,7 +65,7 @@ export const useLLM = ({
         contextWindowLength
       );
 
-      setIsModelReady(true);
+      setIsReady(true);
 
       tokenGeneratedListener.current = LLM.onToken(
         (data: string | undefined) => {
@@ -75,13 +75,13 @@ export const useLLM = ({
           if (data !== EOT_TOKEN) {
             setResponse((prevResponse) => prevResponse + data);
           } else {
-            setIsModelGenerating(false);
+            setIsGenerating(false);
           }
         }
       );
     } catch (err) {
       const message = (err as Error).message;
-      setIsModelReady(false);
+      setIsReady(false);
       setError(message);
     }
   };
@@ -99,7 +99,7 @@ export const useLLM = ({
 
   const generate = useCallback(
     async (input: string): Promise<void> => {
-      if (!isModelReady) {
+      if (!isReady) {
        throw new Error('Model is still loading');
       }
       if (error) {
@@ -108,21 +108,23 @@ export const useLLM = ({
 
       try {
         setResponse('');
-        setIsModelGenerating(true);
+        setIsGenerating(true);
         await LLM.runInference(input);
       } catch (err) {
-        setIsModelGenerating(false);
+        setIsGenerating(false);
         throw new Error((err as Error).message);
       }
     },
-    [isModelReady, error]
+    [isReady, error]
   );
 
   return {
     generate,
     error,
-    isModelReady,
-    isModelGenerating,
+    isReady,
+    isGenerating,
+    isModelReady: isReady,
+    isModelGenerating: isGenerating,
     response,
     downloadProgress,
     interrupt,
```
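Note that `useLLM` now returns both spellings: the legacy `isModelReady` and `isModelGenerating` keys are kept as aliases of the new state, which is what lets the PR describe this as a non-breaking bug fix. A small sketch (hypothetical model source, options trimmed):

```ts
import { useLLM } from 'react-native-executorch';

export function useLlamaFlags() {
  const llama = useLLM({ modelSource: 'https://example.com/llama.pte' }); // hypothetical

  // Old and new names read the same state, so pre-rename call sites keep working.
  return {
    ready: llama.isReady,           // always equals llama.isModelReady
    generating: llama.isGenerating, // always equals llama.isModelGenerating
  };
}
```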
