[POT] Update IEEngine for SW API support (#10304)
* Update IEEngine for SW API support * Change Engine for GNA sample * Change stacks into reshape
This commit is contained in:
@@ -8,7 +8,7 @@ from openvino.tools.pot import load_model, save_model, create_pipeline
|
||||
from openvino.tools.pot.utils.logger import init_logger
|
||||
from openvino.tools.pot.api.samples.utils.argument_parser import get_common_argparser
|
||||
from openvino.tools.pot.api.samples.speech.data_loader import ArkDataLoader
|
||||
from openvino.tools.pot.api.samples.speech.utils import ArkEngine
|
||||
from openvino.tools.pot.engines.simplified_engine import SimplifiedEngine
|
||||
|
||||
|
||||
def parse_args():
|
||||
@@ -111,7 +111,7 @@ def get_configs(args):
|
||||
def optimize_model(args):
|
||||
model_config, engine_config, dataset_config, algorithms = get_configs(args)
|
||||
data_loader = ArkDataLoader(dataset_config)
|
||||
engine = ArkEngine(config=engine_config, data_loader=data_loader)
|
||||
engine = SimplifiedEngine(config=engine_config, data_loader=data_loader)
|
||||
pipeline = create_pipeline(algorithms, engine)
|
||||
|
||||
model = load_model(model_config, target_device='GNA')
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
# Copyright (C) 2020-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
from openvino.tools.pot.engines.simplified_engine import SimplifiedEngine
|
||||
|
||||
|
||||
class ArkEngine(SimplifiedEngine):
|
||||
def _fill_input(self, model, image_batch):
|
||||
if 'input_names' in self.data_loader.config:
|
||||
model_inputs = {n.get_node().friendly_name: n for n in model.inputs}
|
||||
feed_dict = {}
|
||||
for input_name in self.data_loader.config['input_names']:
|
||||
input_blob = model_inputs[input_name]
|
||||
input_blob_name = self._get_input_any_name(input_blob)
|
||||
input_blob_shape = list(input_blob.shape)
|
||||
feed_dict[input_blob_name] = np.reshape(image_batch[0][input_name], input_blob_shape)
|
||||
return feed_dict
|
||||
raise Exception('input_names is not provided!')
|
||||
@@ -226,14 +226,21 @@ class IEEngine(Engine):
|
||||
:param model: IENetwork instance
|
||||
:param image_batch: list of ndarray images or list with a dictionary of inputs mapping
|
||||
"""
|
||||
if isinstance(image_batch[0], dict):
|
||||
return image_batch[0]
|
||||
|
||||
input_info = model.inputs
|
||||
|
||||
if isinstance(image_batch[0], dict):
|
||||
feed_dict = {}
|
||||
input_blobs = {get_clean_name(in_node.get_node().friendly_name): in_node for in_node in input_info}
|
||||
for input_name in image_batch[0].keys():
|
||||
input_blob = input_blobs[input_name]
|
||||
input_blob_name = self._get_input_any_name(input_blob)
|
||||
feed_dict[input_blob_name] = np.reshape(image_batch[0][input_name], input_blob.shape)
|
||||
return feed_dict
|
||||
|
||||
if len(input_info) == 1:
|
||||
input_blob = next(iter(input_info))
|
||||
input_blob_name = self._get_input_any_name(input_blob)
|
||||
image_batch = {input_blob_name: np.stack(image_batch, axis=0)}
|
||||
image_batch = {input_blob_name: np.reshape(image_batch, input_blob.shape)}
|
||||
if Shape(image_batch[input_blob_name].shape) != input_info[0].shape:
|
||||
raise ValueError(f"Incompatible input shapes. "
|
||||
f"Cannot infer {Shape(image_batch[input_blob_name].shape)} into {input_info[0].shape}."
|
||||
@@ -253,7 +260,7 @@ class IEEngine(Engine):
|
||||
lambda x: x.get_any_name() != image_info_name, input_info)))
|
||||
image_tensor_name = image_tensor_node.get_any_name()
|
||||
|
||||
image_tensor = (image_tensor_name, np.stack(image_batch, axis=0))
|
||||
image_tensor = (image_tensor_name, np.reshape(image_batch, input_blob.shape))
|
||||
if Shape(image_tensor[1].shape) != image_tensor_node.shape:
|
||||
raise ValueError(f"Incompatible input shapes. "
|
||||
f"Cannot infer {Shape(image_tensor[1].shape)} into {image_tensor_node.shape}."
|
||||
|
||||
Reference in New Issue
Block a user