-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy path useImageSegmentation.ts
68 lines (61 loc) · 1.8 KB
/
useImageSegmentation.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import { useState } from 'react';
import { _ImageSegmentationModule } from '../../native/RnExecutorchModules';
import { ETError, getError } from '../../Error';
import { useModule } from '../useModule';
import { DeeplabLabel } from '../../types/image_segmentation';
/** Arguments accepted by `useImageSegmentation`. */
type Props = {
  /** Model location, forwarded unchanged to `useModule`. */
  modelSource: string | number;
};
/**
 * React hook that loads an image-segmentation model through `useModule` and
 * exposes a guarded `forward` function for running inference on it.
 *
 * @param modelSource - Model location, passed unchanged to `useModule`.
 * @returns The loading state reported by `useModule` (`error`, `isReady`,
 *   `downloadProgress`), whether an inference call is currently running
 *   (`isGenerating`), and the `forward` function.
 */
export const useImageSegmentation = ({
  modelSource,
}: Props): {
  error: string | null;
  isReady: boolean;
  isGenerating: boolean;
  downloadProgress: number;
  forward: (
    input: string,
    classesOfInterest?: DeeplabLabel[],
    resize?: boolean
  ) => Promise<{ [key in DeeplabLabel]?: number[] }>;
} => {
  // Lazy initializer: the native module is constructed once per mounted component.
  const [module] = useState(() => new _ImageSegmentationModule());
  const [isGenerating, setIsGenerating] = useState(false);
  // Synchronous in-flight flag. `isGenerating` is only a render-time snapshot,
  // so two `forward` calls issued before React re-renders would both read
  // `false` and both start inference. This stable mutable box (same lazy
  // `useState` pattern as `module`) is updated immediately and closes that gap;
  // `isGenerating` is still maintained for consumers that render from it.
  const [inFlight] = useState(() => ({ current: false }));
  const { error, isReady, downloadProgress } = useModule({
    modelSource,
    module,
  });

  /**
   * Runs segmentation on `input` via the native module.
   *
   * @param input - Passed as the first argument to the native `forward`.
   * @param classesOfInterest - Labels whose names are sent to the native side;
   *   defaults to an empty list.
   * @param resize - Passed through to the native `forward`; defaults to `false`.
   * @returns The native result re-keyed from label names to `DeeplabLabel`
   *   values; entries whose key is not a `DeeplabLabel` member are dropped.
   * @throws Error built from `ETError.ModuleNotLoaded` if the model is not ready.
   * @throws Error built from `ETError.ModelGenerating` if a previous call has
   *   not finished yet.
   * @throws Error built with `getError` from any failure in the native call.
   */
  const forward = async (
    input: string,
    classesOfInterest?: DeeplabLabel[],
    resize?: boolean
  ): Promise<{ [key in DeeplabLabel]?: number[] }> => {
    if (!isReady) {
      throw new Error(getError(ETError.ModuleNotLoaded));
    }
    if (inFlight.current) {
      throw new Error(getError(ETError.ModelGenerating));
    }
    // Claim the slot before the first await so a concurrent call is rejected.
    inFlight.current = true;
    setIsGenerating(true);
    try {
      // `DeeplabLabel[label]` reverse-maps each enum value to its member name,
      // which is what is handed to the native side.
      const stringDict = await module.forward(
        input,
        (classesOfInterest ?? []).map((label) => DeeplabLabel[label]),
        resize ?? false
      );
      const enumDict: { [key in DeeplabLabel]?: number[] } = {};
      for (const [name, mask] of Object.entries(stringDict)) {
        // Own-property check: `in` would also accept inherited members such as
        // `toString`, which are not labels.
        if (Object.prototype.hasOwnProperty.call(DeeplabLabel, name)) {
          const enumKey = DeeplabLabel[name as keyof typeof DeeplabLabel];
          enumDict[enumKey] = mask;
        }
      }
      return enumDict;
    } catch (e) {
      throw new Error(getError(e));
    } finally {
      inFlight.current = false;
      setIsGenerating(false);
    }
  };

  return { error, isReady, isGenerating, downloadProgress, forward };
};