14
14
"""A wrapper for TensorFlow Lite image classification models."""
15
15
16
16
import dataclasses
17
+ import json
17
18
import platform
18
19
from typing import List
19
- import zipfile
20
20
21
21
import cv2
22
22
import numpy as np
23
+ from tflite_support import metadata
23
24
24
25
# pylint: disable=g-import-not-at-top
25
26
try :
@@ -76,11 +77,6 @@ def edgetpu_lib_name():
76
77
class ImageClassifier (object ):
77
78
"""A wrapper class for a TFLite image classification model."""
78
79
79
- _mean = 127
80
- """Default mean normalization parameter for float model."""
81
- _std = 128
82
- """Default std normalization parameter for float model."""
83
-
84
80
def __init__ (
85
81
self ,
86
82
model_path : str ,
@@ -96,21 +92,27 @@ def __init__(
96
92
ValueError: If the TFLite model is invalid.
97
93
OSError: If the current OS isn't supported by EdgeTPU.
98
94
"""
95
+ # Load metadata from model.
96
+ displayer = metadata .MetadataDisplayer .with_model_file (model_path )
97
+
98
+ # Save model metadata for preprocessing later.
99
+ model_metadata = json .loads (displayer .get_metadata_json ())
100
+ process_units = model_metadata ['subgraph_metadata' ][0 ][
101
+ 'input_tensor_metadata' ][0 ]['process_units' ]
102
+ mean = 127.5
103
+ std = 127.5
104
+ for option in process_units :
105
+ if option ['options_type' ] == 'NormalizationOptions' :
106
+ mean = option ['options' ]['mean' ][0 ]
107
+ std = option ['options' ]['std' ][0 ]
108
+ self ._mean = mean
109
+ self ._std = std
110
+
99
111
# Load label list from metadata.
100
- try :
101
- with zipfile .ZipFile (model_path ) as model_with_metadata :
102
- if not model_with_metadata .namelist ():
103
- raise ValueError ('Invalid TFLite model: no label file found.' )
104
-
105
- file_name = model_with_metadata .namelist ()[0 ]
106
- with model_with_metadata .open (file_name ) as label_file :
107
- label_list = label_file .read ().splitlines ()
108
- self ._labels_list = [label .decode ('ascii' ) for label in label_list ]
109
- except zipfile .BadZipFile :
110
- print (
111
- 'ERROR: Please use models trained with Model Maker or downloaded from TensorFlow Hub.'
112
- )
113
- raise ValueError ('Invalid TFLite model: no metadata found.' )
112
+ file_name = displayer .get_packed_associated_file_list ()[0 ]
113
+ label_map_file = displayer .get_associated_file_buffer (file_name ).decode ()
114
+ label_list = list (filter (len , label_map_file .splitlines ()))
115
+ self ._label_list = label_list
114
116
115
117
# Initialize TFLite model.
116
118
if options .enable_edgetpu :
@@ -191,7 +193,7 @@ def _postprocess(self, output_tensor: np.ndarray) -> List[Category]:
191
193
range (len (output_tensor )), key = lambda k : output_tensor [k ], reverse = True )
192
194
193
195
categories = [
194
- Category (label = self ._labels_list [idx ], score = output_tensor [idx ])
196
+ Category (label = self ._label_list [idx ], score = output_tensor [idx ])
195
197
for idx in prob_descending
196
198
]
197
199
0 commit comments