|
16 | 16 |
|
17 | 17 | import re
|
18 | 18 | from collections import OrderedDict
|
| 19 | +from pathlib import Path |
19 | 20 | import numpy as np
|
20 | 21 | import cv2
|
21 | 22 |
|
22 | 23 | from ..config import PathField, StringField, ConfigError, ListInputsField
|
23 | 24 | from ..logging import print_info
|
24 | 25 | from .launcher import Launcher, LauncherConfigValidator
|
25 |
| -from ..utils import get_or_parse_value |
| 26 | +from ..utils import get_or_parse_value, get_path |
26 | 27 |
|
27 | 28 | DEVICE_REGEX = r'(?P<device>cpu$|gpu|gpu_fp16)?'
|
28 | 29 | BACKEND_REGEX = r'(?P<backend>ocv|ie)?'
|
@@ -63,8 +64,11 @@ class OpenCVLauncher(Launcher):
|
63 | 64 | def parameters(cls):
|
64 | 65 | parameters = super().parameters()
|
65 | 66 | parameters.update({
|
66 |
| - 'model': PathField(description="Path to model file."), |
67 |
| - 'weights': PathField(description="Path to weights file.", optional=True, default='', check_exists=False), |
| 67 | + 'model': PathField(description="Path to model file.", file_or_directory=True), |
| 68 | + 'weights': PathField( |
| 69 | + description="Path to weights file.", optional=True, |
| 70 | + check_exists=False, file_or_directory=True |
| 71 | + ), |
68 | 72 | 'device': StringField(
|
69 | 73 | regex=DEVICE_REGEX, choices=OpenCVLauncher.TARGET_DEVICES.keys(),
|
70 | 74 | description="Device name: {}".format(', '.join(OpenCVLauncher.TARGET_DEVICES.keys()))
|
@@ -100,8 +104,10 @@ def __init__(self, config_entry: dict, *args, **kwargs):
|
100 | 104 | raise ConfigError('{} is not supported device'.format(selected_device))
|
101 | 105 |
|
102 | 106 | if not self._delayed_model_loading:
|
103 |
| - self.model = self.get_value_from_config('model') |
104 |
| - self.weights = self.get_value_from_config('weights') |
| 107 | + self.model, self.weights = self.automatic_model_search(self._model_name, |
| 108 | + self.get_value_from_config('model'), self.get_value_from_config('weights'), |
| 109 | + self.get_value_from_config('_model_type') |
| 110 | + ) |
105 | 111 | self.network = self.create_network(self.model, self.weights)
|
106 | 112 | self._inputs_shapes = self.get_inputs_from_config(self.config)
|
107 | 113 | self.network.setInputsNames(list(self._inputs_shapes.keys()))
|
@@ -130,6 +136,71 @@ def batch(self):
|
130 | 136 | def output_blob(self):
|
131 | 137 | return next(iter(self.output_names))
|
132 | 138 |
|
| 139 | + def automatic_model_search(self, model_name, model_cfg, weights_cfg, model_type=None): |
| 140 | + model_type_ext = { |
| 141 | + 'xml': 'xml', |
| 142 | + 'blob': 'blob', |
| 143 | + 'onnx': 'onnx', |
| 144 | + 'caffe': 'prototxt', |
| 145 | + 'tf': 'pb' |
| 146 | + } |
| 147 | + def get_model_by_suffix(model_name, model_dir, suffix): |
| 148 | + model_list = list(Path(model_dir).glob('{}.{}'.format(model_name, suffix))) |
| 149 | + if not model_list: |
| 150 | + model_list = list(Path(model_dir).glob('*.{}'.format(suffix))) |
| 151 | + if not model_list: |
| 152 | + model_list = list(Path(model_dir).parent.rglob('*.{}'.format(suffix))) |
| 153 | + return model_list |
| 154 | + |
| 155 | + def get_model(): |
| 156 | + model = Path(model_cfg) |
| 157 | + if not model.is_dir(): |
| 158 | + accepted_suffixes = list(model_type_ext.values()) |
| 159 | + if model.suffix[1:] not in accepted_suffixes: |
| 160 | + raise ConfigError('Models with following suffixes are allowed: {}'.format(accepted_suffixes)) |
| 161 | + print_info('Found model {}'.format(model)) |
| 162 | + return model, model.suffix == '.blob' |
| 163 | + model_list = [] |
| 164 | + if model_type is not None: |
| 165 | + model_list = get_model_by_suffix(model_name, model, model_type_ext[model_type]) |
| 166 | + else: |
| 167 | + for ext in model_type_ext.values(): |
| 168 | + model_list = get_model_by_suffix(model_name, model, ext) |
| 169 | + if model_list: |
| 170 | + break |
| 171 | + if not model_list: |
| 172 | + raise ConfigError('suitable model is not found') |
| 173 | + if len(model_list) != 1: |
| 174 | + raise ConfigError('More than one model matched, please specify explicitly') |
| 175 | + model = model_list[0] |
| 176 | + print_info('Found model {}'.format(model)) |
| 177 | + return model, model.suffix == '.blob' |
| 178 | + |
| 179 | + model, is_blob = get_model() |
| 180 | + if is_blob: |
| 181 | + return model, None |
| 182 | + weights = weights_cfg |
| 183 | + if (weights is None or Path(weights).is_dir()) and model.suffix != '.onnx': |
| 184 | + weights_dir = weights or model.parent |
| 185 | + weights_list = [] |
| 186 | + if model.suffix == '.xml': |
| 187 | + weights = Path(weights_dir) / model.name.replace('xml', 'bin') |
| 188 | + else: |
| 189 | + if model.suffix == '.prototxt': |
| 190 | + weights_list = list(Path(weights_dir).glob('*.{}'.format('caffemodel'))) |
| 191 | + if not weights_list: |
| 192 | + raise ConfigError('Suitable weights is not detected') |
| 193 | + if len(weights_list) != 1: |
| 194 | + raise ConfigError('Several suitable weights found, please specify required explicitly') |
| 195 | + weights = weights_list[0] |
| 196 | + if weights is not None: |
| 197 | + accepted_weights_suffixes = ['.bin', '.caffemodel'] |
| 198 | + if weights.suffix not in accepted_weights_suffixes: |
| 199 | + raise ConfigError('Weights with following suffixes are allowed: {}'.format(accepted_weights_suffixes)) |
| 200 | + print_info('Found weights {}'.format(get_path(weights))) |
| 201 | + |
| 202 | + return model, weights |
| 203 | + |
133 | 204 | def predict(self, inputs, metadata=None, **kwargs):
|
134 | 205 | """
|
135 | 206 | Args:
|
|
0 commit comments