parent
4fba88d29a
commit
2f07b98251
@ -17,6 +17,7 @@ of all possible parameters can be found in the default_quantization_spec.json */
|
||||
|
||||
"engine": {
|
||||
"type": "simplified",
|
||||
"layout": "NCHW", // Layout of input data. Supported ["NCHW", "NHWC", "CHW", "CWH"] layout
|
||||
"data_source": "PATH_TO_SOURCE" // You can specify path to directory with images. Also you can
|
||||
// specify template for file names to filter images to load.
|
||||
// Templates are unix style (This option valid only in simplified mode)
|
||||
|
@ -26,6 +26,7 @@ def create_data_loader(config, model):
|
||||
if tuple(in_node.shape) != (1, 3):
|
||||
data_loader = ImageLoader(config)
|
||||
data_loader.shape = in_node.shape
|
||||
data_loader.get_layout(in_node)
|
||||
return data_loader
|
||||
|
||||
if data_loader is None:
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
from cv2 import imread, IMREAD_GRAYSCALE
|
||||
|
||||
from openvino.runtime import Layout, Dimension # pylint: disable=E0611,E0401
|
||||
from ..api.data_loader import DataLoader
|
||||
from ..data_loaders.utils import prepare_image, collect_img_files
|
||||
|
||||
@ -14,6 +15,7 @@ class ImageLoader(DataLoader):
|
||||
|
||||
self._img_files = collect_img_files(config.data_source)
|
||||
self._shape = None
|
||||
self._layout = config.get('layout', None)
|
||||
self._crop_central_fraction = config.get('central_fraction', None)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
@ -37,4 +39,29 @@ class ImageLoader(DataLoader):
|
||||
if image is None:
|
||||
raise Exception('Can not read the image: {}'.format(img_path))
|
||||
|
||||
return prepare_image(image, self.shape[-2:], self._crop_central_fraction)
|
||||
return prepare_image(image, self._layout, self.shape[-2:], self._crop_central_fraction)
|
||||
|
||||
def get_layout(self, input_node):
|
||||
if self._layout is not None:
|
||||
if 'C' not in self._layout or 'H' not in self._layout or 'W' not in self._layout:
|
||||
raise ValueError('Unexpected {} layout'.format(self._layout))
|
||||
self._layout = Layout(self._layout)
|
||||
return
|
||||
|
||||
layout_from_ir = input_node.graph.graph.get('layout', None)
|
||||
if layout_from_ir is not None:
|
||||
self._layout = Layout(layout_from_ir)
|
||||
return
|
||||
|
||||
image_colors_dim = (Dimension(3), Dimension(1))
|
||||
num_dims = len(self._shape)
|
||||
if num_dims == 4:
|
||||
if self._shape[1] in image_colors_dim:
|
||||
self._layout = Layout("NCHW")
|
||||
elif self._shape[3] in image_colors_dim:
|
||||
self._layout = Layout("NHWC")
|
||||
elif num_dims == 3:
|
||||
if self._shape[0] in image_colors_dim:
|
||||
self._layout = Layout("CHW")
|
||||
elif self._shape[2] in image_colors_dim:
|
||||
self._layout = Layout("HWC")
|
||||
|
@ -9,6 +9,7 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import cv2 as cv
|
||||
|
||||
from openvino.runtime import Layout # pylint: disable=E0611,E0401
|
||||
from openvino.tools.pot.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@ -34,12 +35,11 @@ def crop(image, central_fraction):
|
||||
return image[start_height:start_height + dst_height, start_width:start_width + dst_width]
|
||||
|
||||
|
||||
def prepare_image(image, dst_shape, central_fraction=None):
|
||||
|
||||
def prepare_image(image, layout, dst_shape, central_fraction=None):
|
||||
if central_fraction:
|
||||
image = crop(image, central_fraction)
|
||||
|
||||
if image.shape[-1] in [3, 1]:
|
||||
if layout == Layout('NCHW') or layout == Layout('CHW'):
|
||||
image = cv.resize(image, dst_shape[::-1])
|
||||
return image.transpose(2, 0, 1)
|
||||
|
||||
|
@ -44,3 +44,29 @@ def test_check_image(tmp_path, models, model_name, model_framework):
|
||||
num_images_in_dir = len(os.listdir(path_image_data))
|
||||
|
||||
assert num_images_from_data_loader == num_images_in_dir
|
||||
|
||||
|
||||
TEST_MODELS_LAYOUT = [('mobilenet-v2-pytorch', 'pytorch', 'NCHW', (3, 224, 224)),
|
||||
('mobilenet-v2-pytorch', 'pytorch', 'NHWC', (224, 224, 3)),
|
||||
('mobilenet-v2-pytorch', 'pytorch', None, (3, 224, 224)),
|
||||
('mobilenet-v1-1.0-224-tf', 'tf', None, (224, 224, 3))]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'model_name, model_framework, layout, reference_shape', TEST_MODELS,
|
||||
ids=['{}_{}'.format(m[0], m[1]) for m in TEST_MODELS])
|
||||
def test_check_layout(tmp_path, models, model_name, model_framework, layout, reference_shape):
|
||||
test_dir = Path(__file__).parent
|
||||
path_image_data = os.path.join(test_dir, "data/image_data")
|
||||
|
||||
engine_config = Dict({"device": "CPU",
|
||||
"type": "simplified",
|
||||
"layout": layout,
|
||||
"data_source": path_image_data})
|
||||
model = models.get(model_name, model_framework, tmp_path)
|
||||
model = load_model(model.model_params)
|
||||
|
||||
data_loader = create_data_loader(engine_config, model)
|
||||
image = data_loader.item()
|
||||
|
||||
assert image.shape == reference_shape
|
||||
|
Loading…
Reference in New Issue
Block a user