From 8e43987cd7e85f701fb1d2445cff720f280e6738 Mon Sep 17 00:00:00 2001 From: Nikita Malinin Date: Sat, 12 Feb 2022 18:57:35 +0300 Subject: [PATCH] [POT] Update IEEngine for SW API support (#10304) * Update IEEngine for SW API support * Change Engine for GNA sample * Change stacks into reshape --- .../pot/api/samples/speech/gna_sample.py | 4 ++-- .../tools/pot/api/samples/speech/utils.py | 19 ------------------- .../openvino/tools/pot/engines/ie_engine.py | 17 ++++++++++++----- 3 files changed, 14 insertions(+), 26 deletions(-) delete mode 100644 tools/pot/openvino/tools/pot/api/samples/speech/utils.py diff --git a/tools/pot/openvino/tools/pot/api/samples/speech/gna_sample.py b/tools/pot/openvino/tools/pot/api/samples/speech/gna_sample.py index ecf494ed61b..8dcf59be083 100644 --- a/tools/pot/openvino/tools/pot/api/samples/speech/gna_sample.py +++ b/tools/pot/openvino/tools/pot/api/samples/speech/gna_sample.py @@ -8,7 +8,7 @@ from openvino.tools.pot import load_model, save_model, create_pipeline from openvino.tools.pot.utils.logger import init_logger from openvino.tools.pot.api.samples.utils.argument_parser import get_common_argparser from openvino.tools.pot.api.samples.speech.data_loader import ArkDataLoader -from openvino.tools.pot.api.samples.speech.utils import ArkEngine +from openvino.tools.pot.engines.simplified_engine import SimplifiedEngine def parse_args(): @@ -111,7 +111,7 @@ def get_configs(args): def optimize_model(args): model_config, engine_config, dataset_config, algorithms = get_configs(args) data_loader = ArkDataLoader(dataset_config) - engine = ArkEngine(config=engine_config, data_loader=data_loader) + engine = SimplifiedEngine(config=engine_config, data_loader=data_loader) pipeline = create_pipeline(algorithms, engine) model = load_model(model_config, target_device='GNA') diff --git a/tools/pot/openvino/tools/pot/api/samples/speech/utils.py b/tools/pot/openvino/tools/pot/api/samples/speech/utils.py deleted file mode 100644 index d0f56f68348..00000000000 --- a/tools/pot/openvino/tools/pot/api/samples/speech/utils.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (C) 2020-2022 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -from openvino.tools.pot.engines.simplified_engine import SimplifiedEngine - - -class ArkEngine(SimplifiedEngine): - def _fill_input(self, model, image_batch): - if 'input_names' in self.data_loader.config: - model_inputs = {n.get_node().friendly_name: n for n in model.inputs} - feed_dict = {} - for input_name in self.data_loader.config['input_names']: - input_blob = model_inputs[input_name] - input_blob_name = self._get_input_any_name(input_blob) - input_blob_shape = list(input_blob.shape) - feed_dict[input_blob_name] = np.reshape(image_batch[0][input_name], input_blob_shape) - return feed_dict - raise Exception('input_names is not provided!') diff --git a/tools/pot/openvino/tools/pot/engines/ie_engine.py b/tools/pot/openvino/tools/pot/engines/ie_engine.py index b3c56a5e5cd..a3cb340f3ba 100644 --- a/tools/pot/openvino/tools/pot/engines/ie_engine.py +++ b/tools/pot/openvino/tools/pot/engines/ie_engine.py @@ -226,14 +226,21 @@ class IEEngine(Engine): :param model: IENetwork instance :param image_batch: list of ndarray images or list with a dictionary of inputs mapping """ - if isinstance(image_batch[0], dict): - return image_batch[0] - input_info = model.inputs + + if isinstance(image_batch[0], dict): + feed_dict = {} + input_blobs = {get_clean_name(in_node.get_node().friendly_name): in_node for in_node in input_info} + for input_name in image_batch[0].keys(): + input_blob = input_blobs[input_name] + input_blob_name = self._get_input_any_name(input_blob) + feed_dict[input_blob_name] = np.reshape(image_batch[0][input_name], input_blob.shape) + return feed_dict + if len(input_info) == 1: input_blob = next(iter(input_info)) input_blob_name = self._get_input_any_name(input_blob) - image_batch = {input_blob_name: np.stack(image_batch, axis=0)} + image_batch = {input_blob_name: np.reshape(image_batch, input_blob.shape)} if Shape(image_batch[input_blob_name].shape) != input_info[0].shape: raise ValueError(f"Incompatible input shapes. " f"Cannot infer {Shape(image_batch[input_blob_name].shape)} into {input_info[0].shape}." @@ -253,7 +260,7 @@ class IEEngine(Engine): lambda x: x.get_any_name() != image_info_name, input_info))) image_tensor_name = image_tensor_node.get_any_name() - image_tensor = (image_tensor_name, np.stack(image_batch, axis=0)) + image_tensor = (image_tensor_name, np.reshape(image_batch, input_blob.shape)) if Shape(image_tensor[1].shape) != image_tensor_node.shape: raise ValueError(f"Incompatible input shapes. " f"Cannot infer {Shape(image_tensor[1].shape)} into {image_tensor_node.shape}."