[POT] Fix get layout from model (#10018)

* fix: layout pot

* layout

* fix: layout

* pylint

* add logger

* Update image_loader.py

* pylint

* repeat layout in data free

* resolve conflicts

* sample

* resolve comments
This commit is contained in:
Indira Salyahova 2022-02-04 11:46:54 +03:00 committed by GitHub
parent ed6bb8ab2d
commit da02951d67
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 36 additions and 20 deletions

View File

@ -65,7 +65,7 @@ class ImageNetDataLoader(DataLoader):
"""
image = imread(os.path.join(self.config.data_source, index))
image = self._preprocess(image)
return image.transpose(2, 0, 1)
return image
def _preprocess(self, image):
""" Does preprocessing of an image according to the preprocessing config.

View File

@ -32,8 +32,6 @@ def create_data_loader(config, model):
elif config.type == 'data_free':
if not config.shape:
config.shape = in_node.shape
if not config.layout:
config.layout = in_node.graph.graph.get('layout', None)
data_loader = SyntheticImageLoader(config)
return data_loader

View File

@ -4,8 +4,12 @@
from cv2 import imread, IMREAD_GRAYSCALE
from openvino.runtime import Layout, Dimension # pylint: disable=E0611,E0401
from openvino.tools.mo.utils.cli_parser import get_layout_values
from ..api.data_loader import DataLoader
from ..data_loaders.utils import prepare_image, collect_img_files
from ..utils.logger import get_logger
logger = get_logger(__name__)
class ImageLoader(DataLoader):
@ -43,7 +47,8 @@ class ImageLoader(DataLoader):
if image is None:
raise Exception('Can not read the image: {}'.format(img_path))
return prepare_image(image, self._layout, (self.shape[H], self.shape[W]), self._crop_central_fraction)
return prepare_image(image, self._layout, (self.shape[H], self.shape[W]),
self._crop_central_fraction, self._shape[C] == 1)
def get_layout(self, input_node=None):
if self._layout is not None:
@ -54,9 +59,12 @@ class ImageLoader(DataLoader):
self._layout = Layout(self._layout)
return
if input_node:
layout_from_ir = input_node.graph.graph.get('layout', None)
if input_node and hasattr(input_node.graph, 'meta_data') \
and input_node.graph.meta_data.get('layout', None) not in [None, '()']:
layout_from_ir = get_layout_values(input_node.graph.meta_data.get('layout', None))
if layout_from_ir is not None:
layout_from_ir = layout_from_ir[next(iter(layout_from_ir))].get('source_layout', None)
# SyntheticImageLoader uses only H,W,C dimensions
if self._shape is not None and 'N' in layout_from_ir and len(self._shape) == 3:
layout_from_ir = layout_from_ir[1:]
self._layout = Layout(layout_from_ir)
@ -74,3 +82,4 @@ class ImageLoader(DataLoader):
self._layout = Layout("CHW")
elif self._shape[2] in image_colors_dim:
self._layout = Layout("HWC")
logger.info(f'Layout value is set {self._layout}')

View File

@ -35,15 +35,18 @@ def crop(image, central_fraction):
return image[start_height:start_height + dst_height, start_width:start_width + dst_width]
def prepare_image(image, layout, dst_shape, central_fraction=None):
def prepare_image(image, layout, dst_shape, central_fraction=None, grayscale=False):
if central_fraction:
image = crop(image, central_fraction)
image = cv.resize(image, dst_shape[::-1])
if grayscale:
image = np.expand_dims(image, 2)
if layout == Layout('NCHW') or layout == Layout('CHW'):
image = cv.resize(image, dst_shape[::-1])
return image.transpose(2, 0, 1)
return cv.resize(image, dst_shape[::-1])
return image
def collect_img_files(data_source):

View File

@ -7,7 +7,7 @@ from time import time
import copy
import numpy as np
from openvino.runtime import Core, AsyncInferQueue # pylint: disable=E0611,E0401
from openvino.runtime import Core, AsyncInferQueue, Shape # pylint: disable=E0611,E0401
from .utils import append_stats, process_accumulated_stats, \
restore_original_node_names, align_stat_names_with_results, \
@ -233,7 +233,12 @@ class IEEngine(Engine):
if len(input_info) == 1:
input_blob = next(iter(input_info))
input_blob_name = self._get_input_any_name(input_blob)
return {input_blob_name: np.stack(image_batch, axis=0)}
image_batch = {input_blob_name: np.stack(image_batch, axis=0)}
if Shape(image_batch[input_blob_name].shape) != input_info[0].shape:
raise ValueError(f"Incompatible input shapes. "
f"Cannot infer {Shape(image_batch[input_blob_name].shape)} into {input_info[0].shape}."
f"Try to specify the layout of the model.")
return image_batch
if len(input_info) == 2:
image_info_nodes = list(filter(
@ -249,6 +254,10 @@ class IEEngine(Engine):
image_tensor_name = image_tensor_node.get_any_name()
image_tensor = (image_tensor_name, np.stack(image_batch, axis=0))
if Shape(image_tensor[1].shape) != image_tensor_node.shape:
raise ValueError(f"Incompatible input shapes. "
f"Cannot infer {Shape(image_tensor[1].shape)} into {image_tensor_node.shape}."
f"Try to specify the layout of the model.")
ch, height, width = image_batch[0].shape
image_info = (image_info_name,

View File

@ -47,10 +47,10 @@ def test_check_image(tmp_path, models, model_name, model_framework):
TEST_MODELS_LAYOUT = [
#('mobilenet-v2-pytorch', 'pytorch', 'NCHW', (3, 224, 224)),
#('mobilenet-v2-pytorch', 'pytorch', 'NHWC', (224, 224, 3)),
#('mobilenet-v2-pytorch', 'pytorch', None, (3, 224, 224)),
#('mobilenet-v1-1.0-224-tf', 'tf', None, (224, 224, 3))
('mobilenet-v2-pytorch', 'pytorch', 'NCHW', (3, 224, 224)),
('mobilenet-v1-1.0-224-tf', 'tf', 'NHWC', (224, 224, 3)),
('mobilenet-v2-pytorch', 'pytorch', None, (3, 224, 224)),
('mobilenet-v1-1.0-224-tf', 'tf', None, (224, 224, 3))
]
@ -69,6 +69,6 @@ def test_check_layout(tmp_path, models, model_name, model_framework, layout, ref
model = load_model(model.model_params)
data_loader = create_data_loader(engine_config, model)
image = data_loader.item()
image = next(iter(data_loader))
assert image.shape == reference_shape

View File

@ -114,8 +114,7 @@ def test_compression(_params, tmp_path, models):
TEST_SAMPLE_MODELS = [
# This test is not able to run due to NHWC shape that is not supported
# ('mobilenet-v2-1.0-224', 'tf', 'DefaultQuantization', 'performance', {'accuracy@top1': 0.71})
('mobilenet-v2-1.0-224', 'tf', 'DefaultQuantization', 'performance', {'accuracy@top1': 0.716})
]
@ -272,8 +271,6 @@ TEST_MULTIPLE_OUT_PORTS = [('multiple_out_ports_net', 'tf')]
'model_name, model_framework', TEST_MULTIPLE_OUT_PORTS,
ids=['{}_{}'.format(m[0], m[1]) for m in TEST_MULTIPLE_OUT_PORTS])
def test_multiport_outputs_model(tmp_path, models, model_name, model_framework):
# This test is not able to run due to NHWC shape that is not supported
pytest.skip()
test_dir = Path(__file__).parent
# one image as dataset
data_source = (test_dir / 'data/image_data/').as_posix()