[POT] Update IEEngine for SW API support (#10304)

* Update IEEngine for SW API support

* Change Engine for GNA sample

* Change stacks into reshape
This commit is contained in:
Nikita Malinin
2022-02-12 18:57:35 +03:00
committed by GitHub
parent 976a20cedf
commit 8e43987cd7
3 changed files with 14 additions and 26 deletions

View File

@@ -8,7 +8,7 @@ from openvino.tools.pot import load_model, save_model, create_pipeline
from openvino.tools.pot.utils.logger import init_logger
from openvino.tools.pot.api.samples.utils.argument_parser import get_common_argparser
from openvino.tools.pot.api.samples.speech.data_loader import ArkDataLoader
from openvino.tools.pot.api.samples.speech.utils import ArkEngine
from openvino.tools.pot.engines.simplified_engine import SimplifiedEngine
def parse_args():
@@ -111,7 +111,7 @@ def get_configs(args):
def optimize_model(args):
model_config, engine_config, dataset_config, algorithms = get_configs(args)
data_loader = ArkDataLoader(dataset_config)
engine = ArkEngine(config=engine_config, data_loader=data_loader)
engine = SimplifiedEngine(config=engine_config, data_loader=data_loader)
pipeline = create_pipeline(algorithms, engine)
model = load_model(model_config, target_device='GNA')

View File

@@ -1,19 +0,0 @@
# Copyright (C) 2020-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from openvino.tools.pot.engines.simplified_engine import SimplifiedEngine
class ArkEngine(SimplifiedEngine):
def _fill_input(self, model, image_batch):
if 'input_names' in self.data_loader.config:
model_inputs = {n.get_node().friendly_name: n for n in model.inputs}
feed_dict = {}
for input_name in self.data_loader.config['input_names']:
input_blob = model_inputs[input_name]
input_blob_name = self._get_input_any_name(input_blob)
input_blob_shape = list(input_blob.shape)
feed_dict[input_blob_name] = np.reshape(image_batch[0][input_name], input_blob_shape)
return feed_dict
raise Exception('input_names is not provided!')

View File

@@ -226,14 +226,21 @@ class IEEngine(Engine):
:param model: IENetwork instance
:param image_batch: list of ndarray images or list with a dictionary of inputs mapping
"""
if isinstance(image_batch[0], dict):
return image_batch[0]
input_info = model.inputs
if isinstance(image_batch[0], dict):
feed_dict = {}
input_blobs = {get_clean_name(in_node.get_node().friendly_name): in_node for in_node in input_info}
for input_name in image_batch[0].keys():
input_blob = input_blobs[input_name]
input_blob_name = self._get_input_any_name(input_blob)
feed_dict[input_blob_name] = np.reshape(image_batch[0][input_name], input_blob.shape)
return feed_dict
if len(input_info) == 1:
input_blob = next(iter(input_info))
input_blob_name = self._get_input_any_name(input_blob)
image_batch = {input_blob_name: np.stack(image_batch, axis=0)}
image_batch = {input_blob_name: np.reshape(image_batch, input_blob.shape)}
if Shape(image_batch[input_blob_name].shape) != input_info[0].shape:
raise ValueError(f"Incompatible input shapes. "
f"Cannot infer {Shape(image_batch[input_blob_name].shape)} into {input_info[0].shape}."
@@ -253,7 +260,7 @@ class IEEngine(Engine):
lambda x: x.get_any_name() != image_info_name, input_info)))
image_tensor_name = image_tensor_node.get_any_name()
image_tensor = (image_tensor_name, np.stack(image_batch, axis=0))
image_tensor = (image_tensor_name, np.reshape(image_batch, input_blob.shape))
if Shape(image_tensor[1].shape) != image_tensor_node.shape:
raise ValueError(f"Incompatible input shapes. "
f"Cannot infer {Shape(image_tensor[1].shape)} into {image_tensor_node.shape}."