Skip to content

Commit 6602fb9

Browse files
committed
add automatic model search for opencv launcher
1 parent 10f54a1 commit 6602fb9

File tree

2 files changed

+77
-6
lines changed

2 files changed

+77
-6
lines changed

tools/accuracy_checker/openvino/tools/accuracy_checker/config/config_reader.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -925,7 +925,7 @@ def provide_precision_and_layout(launchers, input_precisions, input_layouts):
925925

926926

927927
def provide_model_type(launcher, arguments):
928-
if 'model_type' in arguments:
928+
if 'model_type' in arguments and arguments.model_type is not None:
929929
launcher['_model_type'] = arguments.model_type
930930
if launcher['framework'] in ['dlsdk', 'openvino', 'g-api'] and 'model_is_blob' in arguments:
931931
launcher['_model_is_blob'] = arguments.model_is_blob

tools/accuracy_checker/openvino/tools/accuracy_checker/launcher/opencv_launcher.py

+76-5
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616

1717
import re
1818
from collections import OrderedDict
19+
from pathlib import Path
1920
import numpy as np
2021
import cv2
2122

2223
from ..config import PathField, StringField, ConfigError, ListInputsField
2324
from ..logging import print_info
2425
from .launcher import Launcher, LauncherConfigValidator
25-
from ..utils import get_or_parse_value
26+
from ..utils import get_or_parse_value, get_path
2627

2728
DEVICE_REGEX = r'(?P<device>cpu$|gpu|gpu_fp16)?'
2829
BACKEND_REGEX = r'(?P<backend>ocv|ie)?'
@@ -63,8 +64,11 @@ class OpenCVLauncher(Launcher):
6364
def parameters(cls):
6465
parameters = super().parameters()
6566
parameters.update({
66-
'model': PathField(description="Path to model file."),
67-
'weights': PathField(description="Path to weights file.", optional=True, default='', check_exists=False),
67+
'model': PathField(description="Path to model file.", file_or_directory=True),
68+
'weights': PathField(
69+
description="Path to weights file.", optional=True,
70+
check_exists=False, file_or_directory=True
71+
),
6872
'device': StringField(
6973
regex=DEVICE_REGEX, choices=OpenCVLauncher.TARGET_DEVICES.keys(),
7074
description="Device name: {}".format(', '.join(OpenCVLauncher.TARGET_DEVICES.keys()))
@@ -100,8 +104,10 @@ def __init__(self, config_entry: dict, *args, **kwargs):
100104
raise ConfigError('{} is not supported device'.format(selected_device))
101105

102106
if not self._delayed_model_loading:
103-
self.model = self.get_value_from_config('model')
104-
self.weights = self.get_value_from_config('weights')
107+
self.model, self.weights = self.automatic_model_search(self._model_name,
108+
self.get_value_from_config('model'), self.get_value_from_config('weights'),
109+
self.get_value_from_config('_model_type')
110+
)
105111
self.network = self.create_network(self.model, self.weights)
106112
self._inputs_shapes = self.get_inputs_from_config(self.config)
107113
self.network.setInputsNames(list(self._inputs_shapes.keys()))
@@ -130,6 +136,71 @@ def batch(self):
130136
def output_blob(self):
131137
return next(iter(self.output_names))
132138

139+
def automatic_model_search(self, model_name, model_cfg, weights_cfg, model_type=None):
140+
model_type_ext = {
141+
'xml': 'xml',
142+
'blob': 'blob',
143+
'onnx': 'onnx',
144+
'caffe': 'prototxt',
145+
'tf': 'pb'
146+
}
147+
def get_model_by_suffix(model_name, model_dir, suffix):
148+
model_list = list(Path(model_dir).glob('{}.{}'.format(model_name, suffix)))
149+
if not model_list:
150+
model_list = list(Path(model_dir).glob('*.{}'.format(suffix)))
151+
if not model_list:
152+
model_list = list(Path(model_dir).parent.rglob('*.{}'.format(suffix)))
153+
return model_list
154+
155+
def get_model():
156+
model = Path(model_cfg)
157+
if not model.is_dir():
158+
accepted_suffixes = list(model_type_ext.values())
159+
if model.suffix[1:] not in accepted_suffixes:
160+
raise ConfigError('Models with following suffixes are allowed: {}'.format(accepted_suffixes))
161+
print_info('Found model {}'.format(model))
162+
return model, model.suffix == '.blob'
163+
model_list = []
164+
if model_type is not None:
165+
model_list = get_model_by_suffix(model_name, model, model_type_ext[model_type])
166+
else:
167+
for ext in model_type_ext.values():
168+
model_list = get_model_by_suffix(model_name, model, ext)
169+
if model_list:
170+
break
171+
if not model_list:
172+
raise ConfigError('suitable model is not found')
173+
if len(model_list) != 1:
174+
raise ConfigError('More than one model matched, please specify explicitly')
175+
model = model_list[0]
176+
print_info('Found model {}'.format(model))
177+
return model, model.suffix == '.blob'
178+
179+
model, is_blob = get_model()
180+
if is_blob:
181+
return model, None
182+
weights = weights_cfg
183+
if (weights is None or Path(weights).is_dir()) and model.suffix != '.onnx':
184+
weights_dir = weights or model.parent
185+
weights_list = []
186+
if model.suffix == '.xml':
187+
weights = Path(weights_dir) / model.name.replace('xml', 'bin')
188+
else:
189+
if model.suffix == '.prototxt':
190+
weights_list = list(Path(weights_dir).glob('*.{}'.format('caffemodel')))
191+
if not weights_list:
192+
raise ConfigError('Suitable weights is not detected')
193+
if len(weights_list) != 1:
194+
raise ConfigError('Several suitable weights found, please specify required explicitly')
195+
weights = weights_list[0]
196+
if weights is not None:
197+
accepted_weights_suffixes = ['.bin', '.caffemodel']
198+
if weights.suffix not in accepted_weights_suffixes:
199+
raise ConfigError('Weights with following suffixes are allowed: {}'.format(accepted_weights_suffixes))
200+
print_info('Found weights {}'.format(get_path(weights)))
201+
202+
return model, weights
203+
133204
def predict(self, inputs, metadata=None, **kwargs):
134205
"""
135206
Args:

0 commit comments

Comments
 (0)