-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathObjectDetection.ts
63 lines (55 loc) · 1.66 KB
/
ObjectDetection.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import { useEffect, useState } from 'react';
import { Image } from 'react-native';
import { ETError, getError } from '../../Error';
import { ObjectDetection } from '../../native/RnExecutorchModules';
import { Detection } from './types';
/**
 * Arguments accepted by `useObjectDetection`.
 */
interface Props {
  /**
   * Where the model comes from. A `string` is handed to the native module
   * as-is; a `number` is treated as a bundled asset id and resolved to a URI
   * with `Image.resolveAssetSource` before loading.
   */
  modelSource: string | number;
}
/**
 * Value returned by `useObjectDetection`: loading/inference status flags plus
 * the function used to run the model.
 */
interface ObjectDetectionModule {
  /** `true` once the native module has finished loading the model. */
  isModelReady: boolean;

  /** `true` while a `forward` call is in flight. */
  isModelGenerating: boolean;

  /**
   * Message produced when loading the model fails; `null` when no failure
   * has been recorded.
   */
  error: string | null;

  /**
   * Runs the model on `input` and resolves with the detections it reports.
   * Rejects when the model is not loaded yet, when another run is still in
   * progress, or when the native call itself fails.
   */
  forward: (input: string) => Promise<Detection[]>;
}
/**
 * React hook that loads an object-detection model through the native
 * `ObjectDetection` module and exposes a `forward` function to run it.
 *
 * The model is (re)loaded whenever `modelSource` changes. Each load starts by
 * clearing `error` and `isModelReady`, so the returned flags always describe
 * the most recently requested source rather than a previous one.
 *
 * @param props.modelSource - Model location as a string, or a numeric asset
 *   id that is resolved to a URI with `Image.resolveAssetSource`.
 * @returns The current `error`, `isModelReady` and `isModelGenerating` state
 *   together with `forward`.
 */
export const useObjectDetection = ({
  modelSource,
}: Props): ObjectDetectionModule => {
  const [error, setError] = useState<null | string>(null);
  const [isModelReady, setIsModelReady] = useState(false);
  const [isModelGenerating, setIsModelGenerating] = useState(false);

  useEffect(() => {
    // Flipped by the cleanup below so that a load which settles after
    // `modelSource` changed (or after unmount) cannot overwrite the state
    // that belongs to a newer load.
    let isStale = false;

    const loadModel = async () => {
      // Drop whatever the previous source left behind; otherwise an old
      // failure would still be reported after a later successful load.
      setError(null);
      setIsModelReady(false);

      try {
        // Resolved inside the `try` so that a failure while resolving the
        // asset is reported through `error` instead of becoming an unhandled
        // promise rejection.
        const path =
          typeof modelSource === 'number'
            ? Image.resolveAssetSource(modelSource).uri
            : modelSource;

        await ObjectDetection.loadModule(path);
        if (!isStale) {
          setIsModelReady(true);
        }
      } catch (e) {
        if (!isStale) {
          setError(getError(e));
        }
      }
    };

    loadModel();

    return () => {
      isStale = true;
    };
  }, [modelSource]);

  /**
   * Runs the loaded model on `input`.
   *
   * @throws Error with the `ETError.ModuleNotLoaded` message when the model
   *   has not finished loading.
   * @throws Error with the `ETError.ModelGenerating` message when a previous
   *   run is still in progress.
   * @throws Error wrapping the `getError` message of any failure raised by
   *   the native `forward` call.
   */
  const forward = async (input: string) => {
    if (!isModelReady) {
      throw new Error(getError(ETError.ModuleNotLoaded));
    }
    // NOTE(review): this guard reads render-time state, so two calls issued
    // before the next render both pass it — consider a ref if strict
    // exclusion is required.
    if (isModelGenerating) {
      throw new Error(getError(ETError.ModelGenerating));
    }

    try {
      setIsModelGenerating(true);
      const output = await ObjectDetection.forward(input);
      return output;
    } catch (e) {
      throw new Error(getError(e));
    } finally {
      // Always release the busy flag, whether the run succeeded or threw.
      setIsModelGenerating(false);
    }
  };

  return { error, isModelReady, isModelGenerating, forward };
};