From 2f07b982517137031f530b6940baf7b4eb745c6b Mon Sep 17 00:00:00 2001 From: Indira Salyahova Date: Wed, 15 Dec 2021 12:12:54 +0300 Subject: [PATCH] [POT] Support layout in pot (#9060) * support layout pot * pylint --- .../pot/configs/simplified_mode_template.json | 1 + .../tools/pot/data_loaders/creator.py | 1 + .../tools/pot/data_loaders/image_loader.py | 29 ++++++++++++++++++- .../openvino/tools/pot/data_loaders/utils.py | 6 ++-- tools/pot/tests/test_image_loading.py | 26 +++++++++++++++++ 5 files changed, 59 insertions(+), 4 deletions(-) diff --git a/tools/pot/configs/simplified_mode_template.json b/tools/pot/configs/simplified_mode_template.json index 52db5686cd4..2df4aa1e898 100644 --- a/tools/pot/configs/simplified_mode_template.json +++ b/tools/pot/configs/simplified_mode_template.json @@ -17,6 +17,7 @@ of all possible parameters can be found in the default_quantization_spec.json */ "engine": { "type": "simplified", + "layout": "NCHW", // Layout of input data. Supported ["NCHW", "NHWC", "CHW", "CWH"] layout "data_source": "PATH_TO_SOURCE" // You can specify path to directory with images. Also you can // specify template for file names to filter images to load. // Templates are unix style (This option valid only in simplified mode) diff --git a/tools/pot/openvino/tools/pot/data_loaders/creator.py b/tools/pot/openvino/tools/pot/data_loaders/creator.py index 14e76e92f00..f4cd1e05fa9 100644 --- a/tools/pot/openvino/tools/pot/data_loaders/creator.py +++ b/tools/pot/openvino/tools/pot/data_loaders/creator.py @@ -26,6 +26,7 @@ def create_data_loader(config, model): if tuple(in_node.shape) != (1, 3): data_loader = ImageLoader(config) data_loader.shape = in_node.shape + data_loader.get_layout(in_node) return data_loader if data_loader is None: diff --git a/tools/pot/openvino/tools/pot/data_loaders/image_loader.py b/tools/pot/openvino/tools/pot/data_loaders/image_loader.py index 4ba603555e6..d81a5586d4c 100644 --- a/tools/pot/openvino/tools/pot/data_loaders/image_loader.py +++ b/tools/pot/openvino/tools/pot/data_loaders/image_loader.py @@ -3,6 +3,7 @@ from cv2 import imread, IMREAD_GRAYSCALE +from openvino.runtime import Layout, Dimension # pylint: disable=E0611,E0401 from ..api.data_loader import DataLoader from ..data_loaders.utils import prepare_image, collect_img_files @@ -14,6 +15,7 @@ class ImageLoader(DataLoader): self._img_files = collect_img_files(config.data_source) self._shape = None + self._layout = config.get('layout', None) self._crop_central_fraction = config.get('central_fraction', None) def __getitem__(self, idx): @@ -37,4 +39,29 @@ class ImageLoader(DataLoader): if image is None: raise Exception('Can not read the image: {}'.format(img_path)) - return prepare_image(image, self.shape[-2:], self._crop_central_fraction) + return prepare_image(image, self._layout, self.shape[-2:], self._crop_central_fraction) + + def get_layout(self, input_node): + if self._layout is not None: + if 'C' not in self._layout or 'H' not in self._layout or 'W' not in self._layout: + raise ValueError('Unexpected {} layout'.format(self._layout)) + self._layout = Layout(self._layout) + return + + layout_from_ir = input_node.graph.graph.get('layout', None) + if layout_from_ir is not None: + self._layout = Layout(layout_from_ir) + return + + image_colors_dim = (Dimension(3), Dimension(1)) + num_dims = len(self._shape) + if num_dims == 4: + if self._shape[1] in image_colors_dim: + self._layout = Layout("NCHW") + elif self._shape[3] in image_colors_dim: + self._layout = Layout("NHWC") + elif num_dims == 3: + if self._shape[0] in image_colors_dim: + self._layout = Layout("CHW") + elif self._shape[2] in image_colors_dim: + self._layout = Layout("HWC") diff --git a/tools/pot/openvino/tools/pot/data_loaders/utils.py b/tools/pot/openvino/tools/pot/data_loaders/utils.py index d60d5b4d1ff..fde14d66ba2 100644 --- a/tools/pot/openvino/tools/pot/data_loaders/utils.py +++ b/tools/pot/openvino/tools/pot/data_loaders/utils.py @@ -9,6 +9,7 @@ from pathlib import Path import numpy as np import cv2 as cv +from openvino.runtime import Layout # pylint: disable=E0611,E0401 from openvino.tools.pot.utils.logger import get_logger logger = get_logger(__name__) @@ -34,12 +35,11 @@ def crop(image, central_fraction): return image[start_height:start_height + dst_height, start_width:start_width + dst_width] -def prepare_image(image, dst_shape, central_fraction=None): - +def prepare_image(image, layout, dst_shape, central_fraction=None): if central_fraction: image = crop(image, central_fraction) - if image.shape[-1] in [3, 1]: + if layout == Layout('NCHW') or layout == Layout('CHW'): image = cv.resize(image, dst_shape[::-1]) return image.transpose(2, 0, 1) diff --git a/tools/pot/tests/test_image_loading.py b/tools/pot/tests/test_image_loading.py index 0836e3025ff..ff82d73c3d6 100644 --- a/tools/pot/tests/test_image_loading.py +++ b/tools/pot/tests/test_image_loading.py @@ -44,3 +44,29 @@ def test_check_image(tmp_path, models, model_name, model_framework): num_images_in_dir = len(os.listdir(path_image_data)) assert num_images_from_data_loader == num_images_in_dir + + +TEST_MODELS_LAYOUT = [('mobilenet-v2-pytorch', 'pytorch', 'NCHW', (3, 224, 224)), + ('mobilenet-v2-pytorch', 'pytorch', 'NHWC', (224, 224, 3)), + ('mobilenet-v2-pytorch', 'pytorch', None, (3, 224, 224)), + ('mobilenet-v1-1.0-224-tf', 'tf', None, (224, 224, 3))] + + +@pytest.mark.parametrize( + 'model_name, model_framework, layout, reference_shape', TEST_MODELS, + ids=['{}_{}'.format(m[0], m[1]) for m in TEST_MODELS]) +def test_check_layout(tmp_path, models, model_name, model_framework, layout, reference_shape): + test_dir = Path(__file__).parent + path_image_data = os.path.join(test_dir, "data/image_data") + + engine_config = Dict({"device": "CPU", + "type": "simplified", + "layout": layout, + "data_source": path_image_data}) + model = models.get(model_name, model_framework, tmp_path) + model = load_model(model.model_params) + + data_loader = create_data_loader(engine_config, model) + image = data_loader.item() + + assert image.shape == reference_shape