[POT] Support layout in pot (#9060)

* support layout pot

* pylint
This commit is contained in:
Indira Salyahova 2021-12-15 12:12:54 +03:00 committed by GitHub
parent 4fba88d29a
commit 2f07b98251
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 59 additions and 4 deletions

View File

@ -17,6 +17,7 @@ of all possible parameters can be found in the default_quantization_spec.json */
"engine": {
"type": "simplified",
"layout": "NCHW", // Layout of input data. Supported ["NCHW", "NHWC", "CHW", "CWH"] layout
"data_source": "PATH_TO_SOURCE" // You can specify path to directory with images. Also you can
// specify template for file names to filter images to load.
// Templates are unix style (This option valid only in simplified mode)

View File

@ -26,6 +26,7 @@ def create_data_loader(config, model):
if tuple(in_node.shape) != (1, 3):
data_loader = ImageLoader(config)
data_loader.shape = in_node.shape
data_loader.get_layout(in_node)
return data_loader
if data_loader is None:

View File

@ -3,6 +3,7 @@
from cv2 import imread, IMREAD_GRAYSCALE
from openvino.runtime import Layout, Dimension # pylint: disable=E0611,E0401
from ..api.data_loader import DataLoader
from ..data_loaders.utils import prepare_image, collect_img_files
@ -14,6 +15,7 @@ class ImageLoader(DataLoader):
self._img_files = collect_img_files(config.data_source)
self._shape = None
self._layout = config.get('layout', None)
self._crop_central_fraction = config.get('central_fraction', None)
def __getitem__(self, idx):
@ -37,4 +39,29 @@ class ImageLoader(DataLoader):
if image is None:
raise Exception('Can not read the image: {}'.format(img_path))
return prepare_image(image, self.shape[-2:], self._crop_central_fraction)
return prepare_image(image, self._layout, self.shape[-2:], self._crop_central_fraction)
def get_layout(self, input_node):
if self._layout is not None:
if 'C' not in self._layout or 'H' not in self._layout or 'W' not in self._layout:
raise ValueError('Unexpected {} layout'.format(self._layout))
self._layout = Layout(self._layout)
return
layout_from_ir = input_node.graph.graph.get('layout', None)
if layout_from_ir is not None:
self._layout = Layout(layout_from_ir)
return
image_colors_dim = (Dimension(3), Dimension(1))
num_dims = len(self._shape)
if num_dims == 4:
if self._shape[1] in image_colors_dim:
self._layout = Layout("NCHW")
elif self._shape[3] in image_colors_dim:
self._layout = Layout("NHWC")
elif num_dims == 3:
if self._shape[0] in image_colors_dim:
self._layout = Layout("CHW")
elif self._shape[2] in image_colors_dim:
self._layout = Layout("HWC")

View File

@ -9,6 +9,7 @@ from pathlib import Path
import numpy as np
import cv2 as cv
from openvino.runtime import Layout # pylint: disable=E0611,E0401
from openvino.tools.pot.utils.logger import get_logger
logger = get_logger(__name__)
@ -34,12 +35,11 @@ def crop(image, central_fraction):
return image[start_height:start_height + dst_height, start_width:start_width + dst_width]
def prepare_image(image, dst_shape, central_fraction=None):
def prepare_image(image, layout, dst_shape, central_fraction=None):
if central_fraction:
image = crop(image, central_fraction)
if image.shape[-1] in [3, 1]:
if layout == Layout('NCHW') or layout == Layout('CHW'):
image = cv.resize(image, dst_shape[::-1])
return image.transpose(2, 0, 1)

View File

@ -44,3 +44,29 @@ def test_check_image(tmp_path, models, model_name, model_framework):
num_images_in_dir = len(os.listdir(path_image_data))
assert num_images_from_data_loader == num_images_in_dir
TEST_MODELS_LAYOUT = [('mobilenet-v2-pytorch', 'pytorch', 'NCHW', (3, 224, 224)),
('mobilenet-v2-pytorch', 'pytorch', 'NHWC', (224, 224, 3)),
('mobilenet-v2-pytorch', 'pytorch', None, (3, 224, 224)),
('mobilenet-v1-1.0-224-tf', 'tf', None, (224, 224, 3))]
@pytest.mark.parametrize(
'model_name, model_framework, layout, reference_shape', TEST_MODELS,
ids=['{}_{}'.format(m[0], m[1]) for m in TEST_MODELS])
def test_check_layout(tmp_path, models, model_name, model_framework, layout, reference_shape):
test_dir = Path(__file__).parent
path_image_data = os.path.join(test_dir, "data/image_data")
engine_config = Dict({"device": "CPU",
"type": "simplified",
"layout": layout,
"data_source": path_image_data})
model = models.get(model_name, model_framework, tmp_path)
model = load_model(model.model_params)
data_loader = create_data_loader(engine_config, model)
image = data_loader.item()
assert image.shape == reference_shape