From 62fa09a18148aefff104f7f4c1536045ad03300f Mon Sep 17 00:00:00 2001 From: Przemyslaw Wysocki Date: Wed, 2 Aug 2023 14:44:30 +0200 Subject: [PATCH] [PyOV] Add torchvision to OpenVINO preprocessing converter (#17934) * Some progress * refactoring * Refactor tests, run black * Refactor, flake * Minor change * Add support for num_output_channels * WIP * Add dependency * Almost done * Fix flake * Add MO convert * Minor changes * Add interpolation mode tests * Add requirements for preprocessing * Fix linter * Update tests * Fix type error * Introduce typing * Rename module * Code review * Update src/bindings/python/src/openvino/preprocess/torchvision/torchvision_preprocessing.py Co-authored-by: Anastasia Kuporosova * Update src/bindings/python/src/openvino/preprocess/torchvision/preprocess_converter.py Co-authored-by: Anastasia Kuporosova * CR changes * Fix mypy * Minor change * Minor change * Fix flake * typing change * Update types * Add tools to pythonpath * Add MO reqs * Minor change * Fix PT FE issue * bugfix * Use absolute path * Debug * Debug req path * Change MO to OVC * Enable PT FE building * Debug * Skip some tests on ARM * ADd ticket numbers * Update test val * Change pytest version * Modify GH workflow * Update cmake * Update cmake * fix cmake * Skip tests on ARM * Fix after OVC MO changes * Limit torch requirements * Raise allowed diff in test * Debug - remove new bindings * Debug - remove interpolation modes * Debug - change test * Cleanup * Minor change * Disable on ARM and Py<38 * Update ARM marker * limit torchvision for ARM --------- Co-authored-by: Anastasia Kuporosova Co-authored-by: gklodkox Co-authored-by: Michal Lukaszewski --- .ci/azure/linux.yml | 2 + .ci/openvino-onnx/Dockerfile | 4 +- .github/workflows/linux.yml | 3 + src/bindings/python/constraints.txt | 8 +- src/bindings/python/requirements_test.txt | 4 + src/bindings/python/setup.cfg | 3 + .../preprocess/torchvision/__init__.py | 15 + .../torchvision/preprocess_converter.py | 47 +++ .../preprocess/torchvision/requirements.txt | 5 + .../torchvision/torchvision_preprocessing.py | 325 ++++++++++++++++++ .../python/src/pyopenvino/CMakeLists.txt | 8 +- .../test_preprocessor.py | 241 +++++++++++++ 12 files changed, 660 insertions(+), 5 deletions(-) create mode 100644 src/bindings/python/src/openvino/preprocess/torchvision/__init__.py create mode 100644 src/bindings/python/src/openvino/preprocess/torchvision/preprocess_converter.py create mode 100644 src/bindings/python/src/openvino/preprocess/torchvision/requirements.txt create mode 100644 
src/bindings/python/src/openvino/preprocess/torchvision/torchvision_preprocessing.py create mode 100644 src/bindings/python/tests/test_torchvision_to_ov/test_preprocessor.py diff --git a/.ci/azure/linux.yml b/.ci/azure/linux.yml index dbb8a30a421..f3f2d724fdc 100644 --- a/.ci/azure/linux.yml +++ b/.ci/azure/linux.yml @@ -178,6 +178,8 @@ jobs: python3 -m pip install -r $(REPO_DIR)/src/frontends/onnx/tests/requirements.txt # For running TensorFlow frontend unit tests python3 -m pip install -r $(REPO_DIR)/src/frontends/tensorflow/tests/requirements.txt + # For running torchvision -> OpenVINO preprocess converter + python3 -m pip install -r $(REPO_DIR)/src/bindings/python/src/openvino/preprocess/torchvision/requirements.txt # For MO unit tests python3 -m pip install -r $(REPO_DIR)/tools/mo/requirements_mxnet.txt python3 -m pip install -r $(REPO_DIR)/tools/mo/requirements_caffe.txt diff --git a/.ci/openvino-onnx/Dockerfile b/.ci/openvino-onnx/Dockerfile index 22197a79f38..d5439a3c5e1 100644 --- a/.ci/openvino-onnx/Dockerfile +++ b/.ci/openvino-onnx/Dockerfile @@ -62,7 +62,7 @@ RUN cmake .. \ -DENABLE_PROFILING_ITT=OFF \ -DENABLE_SAMPLES=OFF \ -DENABLE_OV_PADDLE_FRONTEND=OFF \ - -DENABLE_OV_PYTORCH_FRONTEND=OFF \ + -DENABLE_OV_PYTORCH_FRONTEND=ON \ -DENABLE_OV_TF_FRONTEND=OFF \ -DENABLE_OPENVINO_DEBUG=OFF \ -DCMAKE_INSTALL_PREFIX=/openvino/dist @@ -72,5 +72,5 @@ RUN ninja install WORKDIR /openvino/src/bindings/python ENV OpenVINO_DIR=/openvino/dist/runtime/cmake ENV LD_LIBRARY_PATH=/openvino/dist/runtime/lib/intel64:/openvino/dist/runtime/3rdparty/tbb/lib -ENV PYTHONPATH=/openvino/bin/intel64/${BUILD_TYPE}/python:${PYTHONPATH} +ENV PYTHONPATH=/openvino/bin/intel64/${BUILD_TYPE}/python:/openvino/tools/mo:${PYTHONPATH} CMD tox diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index c56da45d333..cc16c105950 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -531,6 +531,9 @@ jobs: # For running Paddle frontend unit tests python3 -m pip install -r ${{ env.OPENVINO_REPO }}/src/frontends/paddle/tests/requirements.txt + # For torchvision to OpenVINO preprocessing converter + python3 -m pip install -r ${{ env.OPENVINO_REPO }}/src/bindings/python/src/openvino/preprocess/torchvision/requirements.txt + - name: Install MO dependencies run: | python3 -m pip install -r ${{ env.OPENVINO_REPO }}/tools/mo/requirements_mxnet.txt diff --git a/src/bindings/python/constraints.txt b/src/bindings/python/constraints.txt index 133a18789cc..57a9308981f 100644 --- a/src/bindings/python/constraints.txt +++ b/src/bindings/python/constraints.txt @@ -4,7 +4,7 @@ numpy>=1.16.6,<1.26 # Python bindings, frontends # pytest pytest>=5.0,<7.4 pytest-dependency==0.5.1 -pytest-html==3.1.1 +pytest-html==3.2.0 pytest-timeout==2.1.0 # Python bindings @@ -19,4 +19,8 @@ paddlepaddle==2.4.2 tensorflow>=1.15.5,<2.13.0 six~=1.16.0 protobuf>=3.18.1,<4.0.0 -onnx==1.13.1 \ No newline at end of file +onnx==1.13.1 + +# torchvision > OpenVINO preprocessing converter +pillow>=9.0 +torch>=1.13 diff --git a/src/bindings/python/requirements_test.txt b/src/bindings/python/requirements_test.txt index 3594e48d5f8..d2d92b04dcb 100644 --- a/src/bindings/python/requirements_test.txt +++ b/src/bindings/python/requirements_test.txt @@ -40,3 +40,7 @@ tox types-pkg_resources wheel singledispatchmethod +torch +torchvision; platform_machine == 'arm64' and python_version >= '3.8' +torchvision; platform_machine != 'arm64' +pillow diff --git a/src/bindings/python/setup.cfg b/src/bindings/python/setup.cfg index 
2289fb5f335..083c8e1de85 100644 --- a/src/bindings/python/setup.cfg +++ b/src/bindings/python/setup.cfg @@ -7,6 +7,7 @@ skip_install=True deps = -rrequirements.txt -rrequirements_test.txt + -r /openvino/tools/mo/requirements.txt # for torchvision -> OV preprocess converter -r /openvino/src/frontends/onnx/tests/requirements.txt setenv = OV_BACKEND = {env:OV_BACKEND:"CPU"} @@ -43,6 +44,7 @@ deps = -rrequirements.txt # D107 - Missing docstring in __init__ # D412 - No blank lines allowed between a section header and its content # F401 - module imported but unused +# N801 - class name '...' should use CapWords convention # N803 - argument name '...' should be lowercase # T001 - print found # W503 - line break before binary operator (prefer line breaks before op, not after) @@ -66,6 +68,7 @@ per-file-ignores = src/openvino/runtime/*/ops.py: VNE001,VNE003 tests_compatibility/test_ngraph/*: C101,C812,C815,C816,C819,CCE001,D212,E800,ECE001,N400,N802,N806,P101,P103,PT001,PT005,PT006,PT011,PT019,PT023,RST201,S001,VNE002 src/compatibility/ngraph/*: C101,C812,C819,CCE001,E800,N806,P101,RST201,RST202,RST203,RST206,VNE001,VNE003 + src/openvino/preprocess/torchvision/*: N801, VNE001 *__init__.py: F401 [pydocstyle] diff --git a/src/bindings/python/src/openvino/preprocess/torchvision/__init__.py b/src/bindings/python/src/openvino/preprocess/torchvision/__init__.py new file mode 100644 index 00000000000..e66ed1fb94f --- /dev/null +++ b/src/bindings/python/src/openvino/preprocess/torchvision/__init__.py @@ -0,0 +1,15 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Package: openvino +Torchvision to OpenVINO preprocess converter. +""" + +# flake8: noqa + +from openvino._pyopenvino import get_version as _get_version + +__version__ = _get_version() + +from .preprocess_converter import PreprocessConverter diff --git a/src/bindings/python/src/openvino/preprocess/torchvision/preprocess_converter.py b/src/bindings/python/src/openvino/preprocess/torchvision/preprocess_converter.py new file mode 100644 index 00000000000..ff0f89fe366 --- /dev/null +++ b/src/bindings/python/src/openvino/preprocess/torchvision/preprocess_converter.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, Any, Union +import logging + +import openvino.runtime as ov + + +class PreprocessConverter(): + def __init__(self, model: ov.Model): + self._model = model + + @staticmethod + def from_torchvision(model: ov.Model, transform: Callable, input_example: Any, + input_name: Union[str, None] = None) -> ov.Model: + """Embed torchvision preprocessing in an OpenVINO model. + + Arguments: + model (ov.Model): + OpenVINO model into which the preprocessing is embedded + transform (Callable): + torchvision transform to convert + input_example (torch.Tensor or np.ndarray or PIL.Image): + Example of input data for transform to trace its structure. + Don't confuse it with the model input. + input_name (str, optional): + Name of the current model's input node to connect with preprocessing. + Not needed if the model has one input. 
+ + Returns: + ov.Model: OpenVINO Model object with embedded preprocessing + Example: + >>> model = PreprocessConverter.from_torchvision(model, transform, input_example) + """ + try: + import PIL + import torch + from torchvision import transforms + from .torchvision_preprocessing import _from_torchvision + return _from_torchvision(model, transform, input_example, input_name) + except ImportError as e: + raise ImportError(f"Please install torch, torchvision and pillow packages:\n{e}") + except Exception as e: + logging.error(f"Unexpected error: {e}") + raise e diff --git a/src/bindings/python/src/openvino/preprocess/torchvision/requirements.txt b/src/bindings/python/src/openvino/preprocess/torchvision/requirements.txt new file mode 100644 index 00000000000..feda592b307 --- /dev/null +++ b/src/bindings/python/src/openvino/preprocess/torchvision/requirements.txt @@ -0,0 +1,5 @@ +-c ../../../../constraints.txt +torch +torchvision; platform_machine == 'arm64' and python_version >= '3.8' +torchvision; platform_machine != 'arm64' +pillow \ No newline at end of file diff --git a/src/bindings/python/src/openvino/preprocess/torchvision/torchvision_preprocessing.py b/src/bindings/python/src/openvino/preprocess/torchvision/torchvision_preprocessing.py new file mode 100644 index 00000000000..60fef4d22b9 --- /dev/null +++ b/src/bindings/python/src/openvino/preprocess/torchvision/torchvision_preprocessing.py @@ -0,0 +1,325 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# mypy: disable-error-code="no-redef" + +import numbers +import logging +import copy +import numpy as np +from typing import List, Dict +from abc import ABCMeta, abstractmethod +from typing import Callable, Any, Union, Tuple +from typing import Sequence as SequenceType +from collections.abc import Sequence +from PIL import Image + +import torch +import torchvision.transforms as transforms +from torchvision.transforms import InterpolationMode + +import openvino.runtime as ov +import openvino.runtime.opset11 as ops +from openvino.runtime import Layout, Type +from openvino.runtime.utils.decorators import custom_preprocess_function +from openvino.preprocess import PrePostProcessor, ResizeAlgorithm, ColorFormat + + +TORCHTYPE_TO_OVTYPE = { + float: ov.Type.f32, + int: ov.Type.i32, + bool: ov.Type.boolean, + torch.float16: ov.Type.f16, + torch.float32: ov.Type.f32, + torch.float64: ov.Type.f64, + torch.uint8: ov.Type.u8, + torch.int8: ov.Type.i8, + torch.int32: ov.Type.i32, + torch.int64: ov.Type.i64, + torch.bool: ov.Type.boolean, + torch.DoubleTensor: ov.Type.f64, + torch.FloatTensor: ov.Type.f32, + torch.IntTensor: ov.Type.i32, + torch.LongTensor: ov.Type.i64, + torch.BoolTensor: ov.Type.boolean, +} + + +def _setup_size(size: Any, error_msg: str) -> SequenceType[int]: + # TODO: refactor into @singledispatch once Python 3.7 support is dropped + if isinstance(size, numbers.Number): + return int(size), int(size) # type: ignore + if isinstance(size, Sequence): + if len(size) == 1: + return size[0], size[0] + elif len(size) == 2: + return size + raise ValueError(error_msg) + + +def _NHWC_to_NCHW(input_shape: List) -> List: # noqa N802 + new_shape = copy.deepcopy(input_shape) + new_shape[1] = input_shape[3] + new_shape[2] = input_shape[1] + new_shape[3] = input_shape[2] + return new_shape + + +def _to_list(transform: Callable) -> List: + # TODO: refactor into @singledispatch once Python 3.7 support is dropped + if isinstance(transform, torch.nn.Sequential): + return 
list(transform) + elif isinstance(transform, transforms.Compose): + return transform.transforms + else: + raise TypeError(f"Unsupported transform type: {type(transform)}") + + +def _get_shape_layout_from_data(input_example: Union[torch.Tensor, np.ndarray, Image.Image]) -> Tuple[List, Layout]: + if isinstance(input_example, (torch.Tensor, np.ndarray, Image.Image)): # PyTorch, OpenCV, numpy, PILLOW + shape = list(np.array(input_example, copy=False).shape) + layout = Layout("NCHW") if isinstance(input_example, torch.Tensor) else Layout("NHWC") + else: + raise TypeError(f"Unsupported input type: {type(input_example)}") + + if len(shape) == 3: + shape = [1] + shape + elif len(shape) != 4: + raise ValueError(f"Unsupported number of input dimensions: {len(shape)}") + + return shape, layout + + +class TransformConverterBase(metaclass=ABCMeta): + + def __init__(self, **kwargs: Any) -> None: # noqa B027 + pass + + @abstractmethod + def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None: + pass + + +class TransformConverterFactory: + + registry: Dict[str, Callable] = {} + + @classmethod + def register(cls: Callable, target_type: Union[Callable, None] = None) -> Callable: + def inner_wrapper(wrapped_class: TransformConverterBase) -> Callable: + registered_name = wrapped_class.__name__ if target_type is None else target_type.__name__ + if registered_name in cls.registry: + logging.warning(f"Executor {registered_name} already exists. {wrapped_class.__name__} will replace it.") + cls.registry[registered_name] = wrapped_class + return wrapped_class # type: ignore + + return inner_wrapper + + @classmethod + def convert(cls: Callable, converter_type: Callable, *args: Any, **kwargs: Any) -> Callable: + transform_name = converter_type.__name__ + if transform_name not in cls.registry: + raise ValueError(f"{transform_name} is not supported.") + + converter = cls.registry[transform_name]() + return converter.convert(*args, **kwargs) + + +@TransformConverterFactory.register(transforms.Normalize) +class _(TransformConverterBase): + def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None: + if transform.inplace: + raise ValueError("Inplace Normalization is not supported.") + ppp.input(input_idx).preprocess().mean(transform.mean).scale(transform.std) + + +@TransformConverterFactory.register(transforms.ConvertImageDtype) +class _(TransformConverterBase): + def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None: + ppp.input(input_idx).preprocess().convert_element_type(TORCHTYPE_TO_OVTYPE[transform.dtype]) + + +@TransformConverterFactory.register(transforms.Grayscale) +class _(TransformConverterBase): + def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None: + input_shape = meta["input_shape"] + layout = meta["layout"] + + input_shape[layout.get_index_by_name("C")] = 1 + + ppp.input(input_idx).preprocess().convert_color(ColorFormat.GRAY) + if transform.num_output_channels != 1: + input_shape[layout.get_index_by_name("C")] = transform.num_output_channels + + @custom_preprocess_function + def broadcast_node(output: ov.Output) -> Callable: + return ops.broadcast( + data=output, + target_shape=input_shape, + ) + ppp.input(input_idx).preprocess().custom(broadcast_node) + + meta["input_shape"] = input_shape + + +@TransformConverterFactory.register(transforms.Pad) +class _(TransformConverterBase): + def convert(self, input_idx: int, ppp: PrePostProcessor, 
transform: Callable, meta: Dict) -> None: + image_dimensions = list(meta["image_dimensions"]) + layout = meta["layout"] + torch_padding = transform.padding + pad_mode = transform.padding_mode + + if pad_mode == "constant": + if isinstance(transform.fill, tuple): + raise ValueError("Different fill values for R, G, B channels are not supported.") + + pads_begin = [0 for _ in meta["input_shape"]] + pads_end = [0 for _ in meta["input_shape"]] + + # padding equal on all sides + if isinstance(torch_padding, int): + image_dimensions[0] += 2 * torch_padding + image_dimensions[1] += 2 * torch_padding + + pads_begin[layout.get_index_by_name("H")] = torch_padding + pads_begin[layout.get_index_by_name("W")] = torch_padding + pads_end[layout.get_index_by_name("H")] = torch_padding + pads_end[layout.get_index_by_name("W")] = torch_padding + + # padding different in horizontal and vertical axis + elif len(torch_padding) == 2: + image_dimensions[0] += sum(torch_padding) + image_dimensions[1] += sum(torch_padding) + + pads_begin[layout.get_index_by_name("H")] = torch_padding[1] + pads_begin[layout.get_index_by_name("W")] = torch_padding[0] + pads_end[layout.get_index_by_name("H")] = torch_padding[1] + pads_end[layout.get_index_by_name("W")] = torch_padding[0] + + # padding different on top, bottom, left and right of image + else: + image_dimensions[0] += torch_padding[1] + torch_padding[3] + image_dimensions[1] += torch_padding[0] + torch_padding[2] + + pads_begin[layout.get_index_by_name("H")] = torch_padding[1] + pads_begin[layout.get_index_by_name("W")] = torch_padding[0] + pads_end[layout.get_index_by_name("H")] = torch_padding[3] + pads_end[layout.get_index_by_name("W")] = torch_padding[2] + + @custom_preprocess_function + def pad_node(output: ov.Output) -> Callable: + return ops.pad( + output, + pad_mode=pad_mode, + pads_begin=pads_begin, + pads_end=pads_end, + arg_pad_value=np.array(transform.fill, dtype=np.uint8) if pad_mode == "constant" else None, + ) + + ppp.input(input_idx).preprocess().custom(pad_node) + meta["image_dimensions"] = tuple(image_dimensions) + + +@TransformConverterFactory.register(transforms.ToTensor) +class _(TransformConverterBase): + def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None: + input_shape = meta["input_shape"] + layout = meta["layout"] + + ppp.input(input_idx).tensor().set_element_type(Type.u8).set_layout(Layout("NHWC")).set_color_format(ColorFormat.RGB) # noqa ECE001 + + if layout == Layout("NHWC"): + input_shape = _NHWC_to_NCHW(input_shape) + layout = Layout("NCHW") + ppp.input(input_idx).preprocess().convert_layout(layout) + ppp.input(input_idx).preprocess().convert_element_type(Type.f32) + ppp.input(input_idx).preprocess().scale(255.0) + + meta["input_shape"] = input_shape + meta["layout"] = layout + + +@TransformConverterFactory.register(transforms.CenterCrop) +class _(TransformConverterBase): + def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None: + input_shape = meta["input_shape"] + source_size = meta["image_dimensions"] + target_size = _setup_size(transform.size, "Incorrect size type for CenterCrop operation") + + if target_size[0] > source_size[0] or target_size[1] > source_size[1]: + raise ValueError(f"CenterCrop size={target_size} is greater than source_size={source_size}") + + bottom_left = [] + bottom_left.append(int((source_size[0] - target_size[0]) / 2)) + bottom_left.append(int((source_size[1] - target_size[1]) / 2)) + + top_right = [] + 
top_right.append(min(bottom_left[0] + target_size[0], source_size[0] - 1)) + top_right.append(min(bottom_left[1] + target_size[1], source_size[1] - 1)) + + bottom_left = [0] * len(input_shape[:-2]) + bottom_left if meta["layout"] == Layout("NCHW") else [0] + bottom_left + [0] # noqa ECE001 + top_right = input_shape[:-2] + top_right if meta["layout"] == Layout("NCHW") else input_shape[:1] + top_right + input_shape[-1:] + + ppp.input(input_idx).preprocess().crop(bottom_left, top_right) + meta["image_dimensions"] = (target_size[-2], target_size[-1]) + + +@TransformConverterFactory.register(transforms.Resize) +class _(TransformConverterBase): + def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None: + resize_mode_map = { + InterpolationMode.NEAREST: ResizeAlgorithm.RESIZE_NEAREST, + } + if transform.max_size: + raise ValueError("Resize with max_size is not supported") + if transform.interpolation is not InterpolationMode.NEAREST: + raise ValueError("Only InterpolationMode.NEAREST is supported.") + + h, w = _setup_size(transform.size, "Incorrect size type for Resize operation") + + ppp.input(input_idx).tensor().set_layout(Layout("NCHW")) + + input_shape = meta["input_shape"] + + input_shape[meta["layout"].get_index_by_name("H")] = -1 + input_shape[meta["layout"].get_index_by_name("W")] = -1 + + ppp.input(input_idx).tensor().set_shape(input_shape) + ppp.input(input_idx).preprocess().resize(resize_mode_map[transform.interpolation], h, w) + meta["input_shape"] = input_shape + meta["image_dimensions"] = (h, w) + + +def _from_torchvision(model: ov.Model, transform: Callable, input_example: Any, input_name: Union[str, None] = None) -> ov.Model: + + if input_name is not None: + input_idx = next((i for i, p in enumerate(model.get_parameters()) if p.get_friendly_name() == input_name), None) + else: + if len(model.get_parameters()) == 1: + input_idx = 0 + else: + raise ValueError("Model contains multiple inputs. 
Please specify the name of the input to which preprocessing is added.") + + if input_idx is None: + raise ValueError(f"Input with name {input_name} is not found") + + input_shape, layout = _get_shape_layout_from_data(input_example) + + ppp = PrePostProcessor(model) + ppp.input(input_idx).tensor().set_layout(layout) + ppp.input(input_idx).tensor().set_shape(input_shape) + + image_dimensions = [input_shape[layout.get_index_by_name("H")], input_shape[layout.get_index_by_name("W")]] + global_meta = { + "input_shape": input_shape, + "image_dimensions": image_dimensions, + "layout": layout, + } + + for tm in _to_list(transform): + TransformConverterFactory.convert(type(tm), input_idx, ppp, tm, global_meta) + + updated_model = ppp.build() + return updated_model diff --git a/src/bindings/python/src/pyopenvino/CMakeLists.txt b/src/bindings/python/src/pyopenvino/CMakeLists.txt index cf1ba99d190..1ce708f8427 100644 --- a/src/bindings/python/src/pyopenvino/CMakeLists.txt +++ b/src/bindings/python/src/pyopenvino/CMakeLists.txt @@ -116,7 +116,8 @@ if(OpenVINO_SOURCE_DIR OR OpenVINODeveloperPackage_FOUND) COMPONENT ${OV_CPACK_COMP_PYTHON_OPENVINO}_${pyversion} ${OV_CPACK_COMP_PYTHON_OPENVINO_EXCLUDE_ALL} USE_SOURCE_PERMISSIONS - PATTERN "test_utils" EXCLUDE) + PATTERN "test_utils" EXCLUDE + PATTERN "torchvision/requirements.txt" EXCLUDE) install(TARGETS ${PROJECT_NAME} DESTINATION ${OV_CPACK_PYTHONDIR}/openvino @@ -127,6 +128,11 @@ if(OpenVINO_SOURCE_DIR OR OpenVINODeveloperPackage_FOUND) DESTINATION ${OV_CPACK_PYTHONDIR} COMPONENT ${OV_CPACK_COMP_OPENVINO_DEV_REQ_FILES} ${OV_CPACK_COMP_OPENVINO_DEV_REQ_FILES_EXCLUDE_ALL}) + + install(FILES ${OpenVINOPython_SOURCE_DIR}/src/openvino/preprocess/torchvision/requirements.txt + DESTINATION ${OV_CPACK_PYTHONDIR} + COMPONENT ${OV_CPACK_COMP_OPENVINO_DEV_REQ_FILES}/src/openvino/preprocess/torchvision + ${OV_CPACK_COMP_OPENVINO_DEV_REQ_FILES_EXCLUDE_ALL}) install(DIRECTORY ${OpenVINOPython_SOURCE_DIR}/tests DESTINATION tests/${PROJECT_NAME} diff --git a/src/bindings/python/tests/test_torchvision_to_ov/test_preprocessor.py b/src/bindings/python/tests/test_torchvision_to_ov/test_preprocessor.py new file mode 100644 index 00000000000..f72c1f284b3 --- /dev/null +++ b/src/bindings/python/tests/test_torchvision_to_ov/test_preprocessor.py @@ -0,0 +1,241 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import copy +import pytest +import platform +from PIL import Image + +import torch +import torch.nn.functional as f +import torchvision.transforms as transforms + +from openvino.runtime import Core, Tensor +from openvino.tools.mo import convert_model + +from openvino.preprocess.torchvision import PreprocessConverter + + +class Convnet(torch.nn.Module): + def __init__(self, input_channels): + super(Convnet, self).__init__() + self.conv1 = torch.nn.Conv2d(input_channels, 6, 5) + self.conv2 = torch.nn.Conv2d(6, 16, 3) + + def forward(self, data): + data = f.max_pool2d(f.relu(self.conv1(data)), 2) + data = f.max_pool2d(f.relu(self.conv2(data)), 2) + return data + + +def _infer_pipelines(test_input, preprocess_pipeline, input_channels=3): + torch_model = Convnet(input_channels) + example_input = Tensor(np.expand_dims(test_input, axis=0).astype(np.float32)) + ov_model = convert_model(torch_model, example_input=example_input) + core = Core() + + ov_model = PreprocessConverter.from_torchvision( + model=ov_model, transform=preprocess_pipeline, 
input_example=Image.fromarray(test_input.astype("uint8"), "RGB"), + ) + ov_model = core.compile_model(ov_model, "CPU") + + # Torch results + torch_input = copy.deepcopy(test_input) + test_image = Image.fromarray(torch_input.astype("uint8"), "RGB") + transformed_input = preprocess_pipeline(test_image) + transformed_input = torch.unsqueeze(transformed_input, dim=0) + with torch.no_grad(): + torch_result = torch_model(transformed_input).numpy() + + # OpenVINO results + ov_input = test_input + ov_input = np.expand_dims(ov_input, axis=0) + output = ov_model.output(0) + ov_result = ov_model(ov_input)[output] + + return torch_result, ov_result + + +def test_normalize(): + test_input = np.random.randint(255, size=(224, 224, 3), dtype=np.uint8) + preprocess_pipeline = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) + torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline) + assert np.max(np.absolute(torch_result - ov_result)) < 4e-05 + + +@pytest.mark.parametrize( + ("interpolation", "tolerance"), + [ + (transforms.InterpolationMode.NEAREST, 4e-05), + ], +) +def test_resize(interpolation, tolerance): + if platform.machine() in ["arm", "armv7l", "aarch64", "arm64", "ARM64"]: + pytest.skip("Ticket: 114816") + test_input = np.random.randint(255, size=(220, 220, 3), dtype=np.uint8) + preprocess_pipeline = transforms.Compose( + [ + transforms.Resize(224, interpolation=interpolation), + transforms.ToTensor(), + ], + ) + torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline) + assert np.max(np.absolute(torch_result - ov_result)) < tolerance + + +def test_convertimagedtype(): + test_input = np.random.randint(255, size=(224, 224, 3), dtype=np.uint8) + preprocess_pipeline = transforms.Compose( + [ + transforms.ToTensor(), + transforms.ConvertImageDtype(torch.float16), + transforms.ConvertImageDtype(torch.float32), + ], + ) + torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline) + assert np.max(np.absolute(torch_result - ov_result)) < 2e-04 + + +@pytest.mark.parametrize( + ("test_input", "preprocess_pipeline"), + [ + ( + np.random.randint(255, size=(220, 220, 3), dtype=np.uint8), + transforms.Compose( + [ + transforms.Pad((2)), + transforms.ToTensor(), + ], + ), + ), + ( + np.random.randint(255, size=(218, 220, 3), dtype=np.uint8), + transforms.Compose( + [ + transforms.Pad((2, 3)), + transforms.ToTensor(), + ], + ), + ), + ( + np.random.randint(255, size=(216, 218, 3), dtype=np.uint8), + transforms.Compose( + [ + transforms.Pad((2, 3, 4, 5)), + transforms.ToTensor(), + ], + ), + ), + ( + np.random.randint(255, size=(216, 218, 3), dtype=np.uint8), + transforms.Compose( + [ + transforms.Pad((2, 3, 4, 5), fill=3), + transforms.ToTensor(), + ], + ), + ), + ( + np.random.randint(255, size=(218, 220, 3), dtype=np.uint8), + transforms.Compose( + [ + transforms.Pad((2, 3), padding_mode="edge"), + transforms.ToTensor(), + ], + ), + ), + ( + np.random.randint(255, size=(218, 220, 3), dtype=np.uint8), + transforms.Compose( + [ + transforms.Pad((2, 3), padding_mode="reflect"), + transforms.ToTensor(), + ], + ), + ), + ( + np.random.randint(255, size=(218, 220, 3), dtype=np.uint8), + transforms.Compose( + [ + transforms.Pad((2, 3), padding_mode="symmetric"), + transforms.ToTensor(), + ], + ), + ), + ], +) +def test_pad(test_input, preprocess_pipeline): + torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline) + assert np.max(np.absolute(torch_result - ov_result)) < 
4e-05 + + +def test_centercrop(): + test_input = np.random.randint(255, size=(260, 260, 3), dtype=np.uint8) + preprocess_pipeline = transforms.Compose( + [ + transforms.CenterCrop((224)), + transforms.ToTensor(), + ], + ) + torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline) + assert np.max(np.absolute(torch_result - ov_result)) < 4e-05 + + +def test_grayscale(): + test_input = np.random.randint(255, size=(224, 224, 3), dtype=np.uint8) + preprocess_pipeline = transforms.Compose([transforms.ToTensor(), transforms.Grayscale()]) + torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline, input_channels=1) + assert np.max(np.absolute(torch_result - ov_result)) < 2e-04 + + +def test_grayscale_num_output_channels(): + test_input = np.random.randint(255, size=(224, 224, 3), dtype=np.uint8) + preprocess_pipeline = transforms.Compose([transforms.ToTensor(), transforms.Grayscale(3)]) + torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline) + assert np.max(np.absolute(torch_result - ov_result)) < 2e-04 + + +def test_pipeline_1(): + test_input = np.random.randint(255, size=(260, 260, 3), dtype=np.uint8) + preprocess_pipeline = transforms.Compose( + [ + transforms.Resize(256, interpolation=transforms.InterpolationMode.NEAREST), + transforms.CenterCrop((216, 218)), + transforms.Pad((2, 3, 4, 5), fill=3), + transforms.ToTensor(), + transforms.ConvertImageDtype(torch.float32), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ], + ) + torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline) + assert np.max(np.absolute(torch_result - ov_result)) < 4e-05 + + +def test_pipeline_2(): + if platform.machine() in ["arm", "armv7l", "aarch64", "arm64", "ARM64"]: + pytest.skip("Ticket: 114816") + test_input = np.random.randint(255, size=(224, 224, 3), dtype=np.uint8) + preprocess_pipeline = transforms.Compose( + [ + transforms.Resize(256, interpolation=transforms.InterpolationMode.NEAREST), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize((0.481, 0.457, 0.408), (0.268, 0.261, 0.275)), + ], + ) + torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline) + assert np.max(np.absolute(torch_result - ov_result)) < 5e-03 + + +def test_pipeline_3(): + test_input = np.random.randint(255, size=(260, 260, 3), dtype=np.uint8) + preprocess_pipeline = transforms.Compose( + [ + transforms.ToTensor(), + transforms.ConvertImageDtype(torch.float), + ], + ) + torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline) + assert np.max(np.absolute(torch_result - ov_result)) < 2e-03
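
For reference, a minimal end-to-end usage sketch of the converter added by this patch, assembled from the API and the tests above and assuming torch, torchvision, and pillow are installed; the toy Convnet module, the 260x260 RGB input, and the chosen transforms are illustrative placeholders borrowed from the tests, not part of the change itself.

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from openvino.runtime import Core, Tensor
from openvino.tools.mo import convert_model
from openvino.preprocess.torchvision import PreprocessConverter


class Convnet(torch.nn.Module):
    # Toy model mirroring the Convnet used in the tests above.
    def __init__(self, input_channels=3):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(input_channels, 6, 5)
        self.conv2 = torch.nn.Conv2d(6, 16, 3)

    def forward(self, data):
        data = torch.nn.functional.max_pool2d(torch.nn.functional.relu(self.conv1(data)), 2)
        return torch.nn.functional.max_pool2d(torch.nn.functional.relu(self.conv2(data)), 2)


# A torchvision pipeline built only from transforms the converter supports.
pipeline = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Convert the PyTorch model to OpenVINO, then embed the preprocessing pipeline.
image = np.random.randint(255, size=(260, 260, 3), dtype=np.uint8)  # placeholder HWC uint8 input
example_input = Tensor(np.expand_dims(image, axis=0).astype(np.float32))
ov_model = convert_model(Convnet(), example_input=example_input)
ov_model = PreprocessConverter.from_torchvision(
    model=ov_model,
    transform=pipeline,
    input_example=Image.fromarray(image, "RGB"),
)

# After embedding, the compiled model consumes the raw NHWC uint8 image directly.
compiled = Core().compile_model(ov_model, "CPU")
result = compiled(np.expand_dims(image, axis=0))[compiled.output(0)]
print(result.shape)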