diff --git a/demos/common/python/openvino/model_zoo/model_api/models/open_pose.py b/demos/common/python/openvino/model_zoo/model_api/models/open_pose.py
index 677ba2ced9d..b51d08bc475 100644
--- a/demos/common/python/openvino/model_zoo/model_api/models/open_pose.py
+++ b/demos/common/python/openvino/model_zoo/model_api/models/open_pose.py
@@ -20,8 +20,8 @@
     from numpy.core.umath import clip
 except ImportError:
     from numpy import clip
 
-import openvino.runtime.opset8 as opset8
+from ..adapters import OpenvinoAdapter
 
 from .image_model import ImageModel
 from .types import NumericalValue
@@ -31,43 +31,44 @@ class OpenPose(ImageModel):
 
     def __init__(self, model_adapter, configuration=None, preload=False):
         super().__init__(model_adapter, configuration, preload=False)
-        self.pooled_heatmaps_blob_name = 'pooled_heatmaps'
-        self.heatmaps_blob_name = 'heatmaps'
-        self.pafs_blob_name = 'pafs'
+        self._check_io_number(1, 2)
 
-        function = self.model_adapter.model
-        paf = function.get_output_op(0)
-        paf_shape = paf.output(0).get_shape()
-        heatmap = function.get_output_op(1)
+        if isinstance(model_adapter, OpenvinoAdapter):
+            import openvino.runtime.opset8 as opset8
+            self.is_openvino_adapter = True
+        else:
+            self.is_openvino_adapter = False
+
+        self.heatmaps_blob_name, self.pafs_blob_name = self.outputs
+        heatmap_shape = self.outputs[self.heatmaps_blob_name].shape
+        paf_shape = self.outputs[self.pafs_blob_name].shape
 
-        heatmap_shape = heatmap.output(0).get_shape()
         if len(paf_shape) != 4 and len(heatmap_shape) != 4:
             self.raise_error('OpenPose outputs must be 4-dimensional')
         if paf_shape[2] != heatmap_shape[2] and paf_shape[3] != heatmap_shape[3]:
             self.raise_error('Last two dimensions of OpenPose outputs must match')
         if paf_shape[1] * 2 == heatmap_shape[1]:
-            paf, heatmap = heatmap, paf
+            self.heatmaps_blob_name, self.pafs_blob_name = self.pafs_blob_name, self.heatmaps_blob_name
         elif paf_shape[1] != heatmap_shape[1] * 2:
             self.raise_error('Size of second dimension of OpenPose of one output must be two times larger then size '
                              'of second dimension of another output')
 
-        paf = paf.inputs()[0].get_source_output().get_node()
-        paf.get_output_tensor(0).set_names({self.pafs_blob_name})
-        heatmap = heatmap.inputs()[0].get_source_output().get_node()
-
-        heatmap.get_output_tensor(0).set_names({self.heatmaps_blob_name})
-
         # Add keypoints NMS to the network.
         # Heuristic NMS kernel size adjustment depending on the feature maps upsampling ratio.
-        p = int(np.round(6 / 7 * self.upsample_ratio))
-        k = 2 * p + 1
-        pooled_heatmap = opset8.max_pool(heatmap, kernel_shape=(k, k), dilations=(1, 1), pads_begin=(p, p), pads_end=(p, p),
-                                         strides=(1, 1), name=self.pooled_heatmaps_blob_name)
-        pooled_heatmap.output(0).get_tensor().set_names({self.pooled_heatmaps_blob_name})
-        self.model_adapter.model.add_outputs([pooled_heatmap.output(0)])
+        self.p = int(np.round(6 / 7 * self.upsample_ratio))
+        self.k = 2 * self.p + 1
+
+        if self.is_openvino_adapter:
+            self.pooled_heatmaps_blob_name = 'pooled_heatmaps'
+            heatmap = self.model_adapter.model.get_output_op(1)
+            heatmap = heatmap.inputs()[0].get_source_output().get_node()
 
-        self.inputs = self.model_adapter.get_input_layers()
-        self.outputs = self.model_adapter.get_output_layers()
+            pooled_heatmap = opset8.max_pool(
+                heatmap, kernel_shape=(self.k, self.k), dilations=(1, 1), pads_begin=(self.p, self.p),
+                pads_end=(self.p, self.p), strides=(1, 1), name=self.pooled_heatmaps_blob_name)
+
+            pooled_heatmap.output(0).get_tensor().set_names({self.pooled_heatmaps_blob_name})
+            self.model_adapter.model.add_outputs([pooled_heatmap.output(0)])
 
         self.output_scale = self.inputs[self.image_blob_name].shape[-2] / self.outputs[self.heatmaps_blob_name].shape[-2]
 
@@ -99,6 +100,15 @@ def parameters(cls):
         })
         return parameters
 
+    def max_pooling(self, array):
+        shapes = array.shape + (self.k, self.k)
+        array = np.pad(array, [(0, 0), (self.p, self.p), (self.p, self.p)], mode='constant')
+        strides = (array.strides[0], array.strides[1], array.strides[2],
+                   array.strides[1], array.strides[2])
+        strided = np.lib.stride_tricks.as_strided(array, shapes, strides)
+        pooled_array = strided.max(axis=(3, 4))
+        return np.expand_dims(pooled_array, axis=0)
+
     @staticmethod
     def heatmap_nms(heatmaps, pooled_heatmaps):
         return heatmaps * (heatmaps == pooled_heatmaps)
@@ -128,7 +138,12 @@ def preprocess(self, inputs):
     def postprocess(self, outputs, meta):
         heatmaps = outputs[self.heatmaps_blob_name]
         pafs = outputs[self.pafs_blob_name]
-        pooled_heatmaps = outputs[self.pooled_heatmaps_blob_name]
+
+        if self.is_openvino_adapter:
+            pooled_heatmaps = outputs[self.pooled_heatmaps_blob_name]
+        else:
+            pooled_heatmaps = self.max_pooling(heatmaps.squeeze())
+
         nms_heatmaps = self.heatmap_nms(heatmaps, pooled_heatmaps)
         poses, scores = self.decoder(heatmaps, nms_heatmaps, pafs)
         # Rescale poses to the original image.
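Note on the change above: the graph-side `opset8.max_pool` keypoint NMS can only be inserted when the model is available as a local OpenVINO graph, so for other adapters (such as OVMS, where only inference results come back over the wire) the patch reproduces the same pooling on the client in `OpenPose.max_pooling`. A minimal standalone sketch of that strided-window trick; the function name and shapes here are illustrative, not part of the patch:

```python
import numpy as np

def max_pool_same(heatmaps, kernel=7):
    """'Same'-size max pooling over the last two axes of a (C, H, W) array."""
    p = kernel // 2  # pad so the output keeps the input's spatial size
    c, h, w = heatmaps.shape
    padded = np.pad(heatmaps, [(0, 0), (p, p), (p, p)], mode='constant')
    # A (C, H, W, k, k) zero-copy view: one k x k window per output position.
    shape = (c, h, w, kernel, kernel)
    strides = (padded.strides[0], padded.strides[1], padded.strides[2],
               padded.strides[1], padded.strides[2])
    windows = np.lib.stride_tricks.as_strided(padded, shape, strides)
    return windows.max(axis=(3, 4))

# Keypoint NMS then keeps only the local maxima, as heatmap_nms does:
heatmaps = np.random.rand(19, 32, 57).astype(np.float32)
nms = heatmaps * (heatmaps == max_pool_same(heatmaps))
```

Because the strided array is only a view, no window data is copied; the padding keeps the pooled map the same spatial size as the input, which is what makes the equality comparison in `heatmap_nms` valid element-wise.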
diff --git a/demos/human_pose_estimation_demo/python/README.md b/demos/human_pose_estimation_demo/python/README.md
index ab06c64b933..d534beda5bd 100644
--- a/demos/human_pose_estimation_demo/python/README.md
+++ b/demos/human_pose_estimation_demo/python/README.md
@@ -52,10 +52,13 @@ omz_converter --list models.lst
 Running the application with the `-h` option yields the following usage message:
 
 ```
-usage: human_pose_estimation_demo.py [-h] -m MODEL -at {ae,hrnet,openpose} -i
-                                     INPUT [--loop] [-o OUTPUT]
+usage: human_pose_estimation_demo.py [-h] -m MODEL -at
+                                     {ae,higherhrnet,openpose}
+                                     [--adapter {openvino,ovms}] -i INPUT
+                                     [--loop] [-o OUTPUT]
                                      [-limit OUTPUT_LIMIT] [-d DEVICE]
                                      [-t PROB_THRESHOLD] [--tsize TSIZE]
+                                     [--layout LAYOUT]
                                      [-nireq NUM_INFER_REQUESTS]
                                      [-nstreams NUM_STREAMS]
                                      [-nthreads NUM_THREADS] [-no_show]
@@ -68,6 +71,9 @@ Options:
                         Required. Path to an .xml file with a trained model.
   -at {ae,higherhrnet,openpose}, --architecture_type {ae,higherhrnet,openpose}
                         Required. Specify model' architecture type.
+  --adapter {openvino,ovms}
+                        Optional. Specify the model adapter. Default is
+                        openvino.
   -i INPUT, --input INPUT
                         Required. An input to process. The input must be a
                         single image, a folder of images, video file or camera
@@ -98,6 +104,8 @@ Common model options:
                         resized to a predefined height, which is the target
                         size in this case. For Associative Embedding-like nets
                         target size is the length of a short first image side.
+  --layout LAYOUT       Optional. Model inputs layouts. Ex. NCHW or
+                        input0:NCHW,input1:NC in case of more than one input.
 
 Inference options:
   -nireq NUM_INFER_REQUESTS, --num_infer_requests NUM_INFER_REQUESTS
@@ -148,6 +156,21 @@ To avoid disk space overrun in case of continuous input stream, like camera, you
 
 >**NOTE**: Windows\* systems may not have the Motion JPEG codec installed by default. If this is the case, you can download OpenCV FFMPEG back end using the PowerShell script provided with the OpenVINO™ install package and located at `<INSTALL_DIR>/opencv/ffmpeg-download.ps1`. The script should be run with administrative privileges if OpenVINO™ is installed in a system protected folder (this is a typical case). Alternatively, you can save results as images.
 
+## Running with OpenVINO Model Server
+
+You can also run this demo with a model served by [OpenVINO Model Server](https://github.com/openvinotoolkit/model_server). Refer to [`OVMSAdapter`](../../common/python/openvino/model_zoo/model_api/adapters/ovms_adapter.md) to learn about running demos with OVMS.
+
+Example command:
+
+```sh
+python3 human_pose_estimation_demo.py \
+    -d CPU \
+    -i 0 \
+    -m localhost:9000/models/human_pose_estimation \
+    -at ae \
+    --adapter ovms
+```
+
 ## Demo Output
 
 The demo uses OpenCV to display the resulting frame with estimated poses.
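Before `--adapter ovms` can be used, the model has to be served. One possible way to bring up OpenVINO Model Server for the command above, assuming the converted model sits in a local directory that follows OVMS's numbered-version layout; all host paths here are illustrative:

```sh
# Illustrative only: serve the model over gRPC on port 9000 so that
# "-m localhost:9000/models/human_pose_estimation" resolves.
# OVMS expects <model_path>/<version>/ subdirectories, e.g. .../1/model.xml.
docker run -d --rm -p 9000:9000 \
    -v /opt/models/human_pose_estimation:/models/human_pose_estimation \
    openvino/model_server:latest \
    --model_name human_pose_estimation \
    --model_path /models/human_pose_estimation \
    --port 9000
```

The model name in the `-m` target must match the `--model_name` the server was started with.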
diff --git a/demos/human_pose_estimation_demo/python/human_pose_estimation_demo.py b/demos/human_pose_estimation_demo/python/human_pose_estimation_demo.py
index beccd55bd4c..e982cb4c9fe 100755
--- a/demos/human_pose_estimation_demo/python/human_pose_estimation_demo.py
+++ b/demos/human_pose_estimation_demo/python/human_pose_estimation_demo.py
@@ -29,7 +29,7 @@
 from openvino.model_zoo.model_api.models import ImageModel, OutputTransform
 from openvino.model_zoo.model_api.performance_metrics import PerformanceMetrics
 from openvino.model_zoo.model_api.pipelines import get_user_config, AsyncPipeline
-from openvino.model_zoo.model_api.adapters import create_core, OpenvinoAdapter
+from openvino.model_zoo.model_api.adapters import create_core, OpenvinoAdapter, OVMSAdapter
 
 import monitors
 from images_capture import open_images_capture
@@ -51,6 +51,8 @@ def build_argparser():
                       required=True, type=Path)
     args.add_argument('-at', '--architecture_type', help='Required. Specify model\' architecture type.',
                       type=str, required=True, choices=('ae', 'higherhrnet', 'openpose'))
+    args.add_argument('--adapter', help='Optional. Specify the model adapter. Default is openvino.',
+                      default='openvino', type=str, choices=('openvino', 'ovms'))
     args.add_argument('-i', '--input', required=True,
                       help='Required. An input to process. The input must be a single image, '
                            'a folder of images, video file or camera id.')
@@ -170,8 +172,11 @@ def main():
     video_writer = cv2.VideoWriter()
 
     plugin_config = get_user_config(args.device, args.num_streams, args.num_threads)
-    model_adapter = OpenvinoAdapter(create_core(), args.model, device=args.device, plugin_config=plugin_config,
-                                    max_num_requests=args.num_infer_requests, model_parameters = {'input_layouts': args.layout})
+    if args.adapter == 'openvino':
+        model_adapter = OpenvinoAdapter(create_core(), args.model, device=args.device, plugin_config=plugin_config,
+                                        max_num_requests=args.num_infer_requests, model_parameters = {'input_layouts': args.layout})
+    elif args.adapter == 'ovms':
+        model_adapter = OVMSAdapter(args.model)
 
     start_time = perf_counter()
     frame = cap.read()
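A closing note on the `main()` change: both adapter classes expose the same model-adapter interface that `AsyncPipeline` and the model wrappers consume, so the backend choice stays confined to this single construction site. A minimal sketch of the same dispatch as a helper; `make_adapter` is a hypothetical name, and only constructors that already appear in this patch are used, with the remaining `OpenvinoAdapter` keyword arguments left at their defaults:

```python
from openvino.model_zoo.model_api.adapters import create_core, OpenvinoAdapter, OVMSAdapter

def make_adapter(adapter, model):
    # Hypothetical helper, not part of the patch.
    if adapter == 'openvino':
        # Local inference: `model` is a path to an .xml file.
        return OpenvinoAdapter(create_core(), model, device='CPU')
    # Remote inference: `model` is a "<host>:<port>/models/<model_name>"
    # target, as in the README example above.
    return OVMSAdapter(model)

# e.g. make_adapter('ovms', 'localhost:9000/models/human_pose_estimation')
```

Everything downstream of adapter construction, including `OpenPose.postprocess`, is untouched by the backend choice; only the client-side pooling fallback in open_pose.py depends on which adapter produced the outputs.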