[PyOV] Add torchvision to OpenVINO preprocessing converter (#17934)
* Some progress
* refactoring
* Refactor tests, run black
* Refactor, flake
* Minor change
* Add support for num_output_channels
* WIP
* Add dependency
* Almost done
* Fix flake
* Add MO convert
* Minor changes
* Add interpolation mode tests
* Add requirements for preprocessing
* Fix linter
* Update tests
* Fix type error
* Introduce typing
* Rename module
* Code review
* Update src/bindings/python/src/openvino/preprocess/torchvision/torchvision_preprocessing.py (Co-authored-by: Anastasia Kuporosova <anastasia.kuporosova@intel.com>)
* Update src/bindings/python/src/openvino/preprocess/torchvision/preprocess_converter.py (Co-authored-by: Anastasia Kuporosova <anastasia.kuporosova@intel.com>)
* CR changes
* Fix mypy
* Minor change
* Minor change
* Fix flake
* typing change
* Update types
* Add tools to pythonpath
* Add MO reqs
* Minor change
* Fix PT FE issue
* bugfix
* Use absolute path
* Debug
* Debug req path
* Change MO to OVC
* Enable PT FE building
* Debug
* Skip some tests on ARM
* Add ticket numbers
* Update test val
* Change pytest version
* Modify GH workflow
* Update cmake
* Update cmake
* fix cmake
* Skip tests on ARM
* Fix after OVC MO changes
* Limit torch requirements
* Raise allowed diff in test
* Debug - remove new bindings
* Debug - remove interpolation modes
* Debug - change test
* Cleanup
* Minor change
* Disable on ARM and Py<38
* Update ARM marker
* limit torchvision for ARM

---------

Co-authored-by: Anastasia Kuporosova <anastasia.kuporosova@intel.com>
Co-authored-by: gklodkox <gracjanx.klodkowski@intel.com>
Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
This commit is contained in:
parent 7be660e551
commit 62fa09a181
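Before the diff, a hedged usage sketch of the converter this commit introduces (not part of the diff itself): the "model.xml" path and the example image size are illustrative, and the transform values mirror the new tests added below.

# Minimal sketch: embed a torchvision pipeline into an OpenVINO model with the
# converter added by this commit. "model.xml" and the 300x300 example are assumptions.
import torchvision.transforms as transforms
from PIL import Image

from openvino.runtime import Core
from openvino.preprocess.torchvision import PreprocessConverter

pipeline = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

core = Core()
model = core.read_model("model.xml")  # hypothetical IR file

# input_example only traces the shape/layout of the data fed to the transform;
# it is not the model input itself.
example = Image.new("RGB", (300, 300))

model = PreprocessConverter.from_torchvision(
    model=model, transform=pipeline, input_example=example,
)
compiled = core.compile_model(model, "CPU")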
@@ -178,6 +178,8 @@ jobs:
       python3 -m pip install -r $(REPO_DIR)/src/frontends/onnx/tests/requirements.txt
       # For running TensorFlow frontend unit tests
       python3 -m pip install -r $(REPO_DIR)/src/frontends/tensorflow/tests/requirements.txt
+      # For running torchvision -> OpenVINO preprocess converter
+      python3 -m pip install -r $(REPO_DIR)/src/bindings/python/src/openvino/preprocess/torchvision/requirements.txt
       # For MO unit tests
       python3 -m pip install -r $(REPO_DIR)/tools/mo/requirements_mxnet.txt
       python3 -m pip install -r $(REPO_DIR)/tools/mo/requirements_caffe.txt
@@ -62,7 +62,7 @@ RUN cmake .. \
     -DENABLE_PROFILING_ITT=OFF \
     -DENABLE_SAMPLES=OFF \
     -DENABLE_OV_PADDLE_FRONTEND=OFF \
-    -DENABLE_OV_PYTORCH_FRONTEND=OFF \
+    -DENABLE_OV_PYTORCH_FRONTEND=ON \
     -DENABLE_OV_TF_FRONTEND=OFF \
     -DENABLE_OPENVINO_DEBUG=OFF \
     -DCMAKE_INSTALL_PREFIX=/openvino/dist
@@ -72,5 +72,5 @@ RUN ninja install
 WORKDIR /openvino/src/bindings/python
 ENV OpenVINO_DIR=/openvino/dist/runtime/cmake
 ENV LD_LIBRARY_PATH=/openvino/dist/runtime/lib/intel64:/openvino/dist/runtime/3rdparty/tbb/lib
-ENV PYTHONPATH=/openvino/bin/intel64/${BUILD_TYPE}/python:${PYTHONPATH}
+ENV PYTHONPATH=/openvino/bin/intel64/${BUILD_TYPE}/python:/openvino/tools/mo:${PYTHONPATH}
 CMD tox
.github/workflows/linux.yml
@@ -531,6 +531,9 @@ jobs:
           # For running Paddle frontend unit tests
           python3 -m pip install -r ${{ env.OPENVINO_REPO }}/src/frontends/paddle/tests/requirements.txt
+
+          # For torchvision to OpenVINO preprocessing converter
+          python3 -m pip install -r ${{ env.OPENVINO_REPO }}/src/bindings/python/src/openvino/preprocess/torchvision/requirements.txt

       - name: Install MO dependencies
         run: |
           python3 -m pip install -r ${{ env.OPENVINO_REPO }}/tools/mo/requirements_mxnet.txt
@@ -4,7 +4,7 @@ numpy>=1.16.6,<1.26  # Python bindings, frontends
 # pytest
 pytest>=5.0,<7.4
 pytest-dependency==0.5.1
-pytest-html==3.1.1
+pytest-html==3.2.0
 pytest-timeout==2.1.0

 # Python bindings
@@ -19,4 +19,8 @@ paddlepaddle==2.4.2
 tensorflow>=1.15.5,<2.13.0
 six~=1.16.0
 protobuf>=3.18.1,<4.0.0
 onnx==1.13.1
+
+# torchvision > OpenVINO preprocessing converter
+pillow>=9.0
+torch>=1.13
@@ -40,3 +40,7 @@ tox
 types-pkg_resources
 wheel
 singledispatchmethod
+torch
+torchvision; platform_machine == 'arm64' and python_version >= '3.8'
+torchvision; platform_machine != 'arm64'
+pillow
@@ -7,6 +7,7 @@ skip_install=True
 deps =
     -rrequirements.txt
     -rrequirements_test.txt
+    -r /openvino/tools/mo/requirements.txt  # for torchvision -> OV preprocess converter
     -r /openvino/src/frontends/onnx/tests/requirements.txt
 setenv =
     OV_BACKEND = {env:OV_BACKEND:"CPU"}
@@ -43,6 +44,7 @@ deps = -rrequirements.txt
     # D107 - Missing docstring in __init__
     # D412 - No blank lines allowed between a section header and its content
     # F401 - module imported but unused
+    # N801 - class name '...' should use CapWords convention
     # N803 - argument name '...' should be lowercase
     # T001 - print found
     # W503 - line break before binary operator (prefer line breaks before op, not after)
@@ -66,6 +68,7 @@ per-file-ignores =
     src/openvino/runtime/*/ops.py: VNE001,VNE003
     tests_compatibility/test_ngraph/*: C101,C812,C815,C816,C819,CCE001,D212,E800,ECE001,N400,N802,N806,P101,P103,PT001,PT005,PT006,PT011,PT019,PT023,RST201,S001,VNE002
     src/compatibility/ngraph/*: C101,C812,C819,CCE001,E800,N806,P101,RST201,RST202,RST203,RST206,VNE001,VNE003
+    src/openvino/preprocess/torchvision/*: N801, VNE001
     *__init__.py: F401

 [pydocstyle]
@@ -0,0 +1,15 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""
Package: openvino
Torchvision to OpenVINO preprocess converter.
"""

# flake8: noqa

from openvino._pyopenvino import get_version as _get_version

__version__ = _get_version()

from .preprocess_converter import PreprocessConverter
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, Any, Union
import logging

import openvino.runtime as ov


class PreprocessConverter():
    def __init__(self, model: ov.Model):
        self._model = model

    @staticmethod
    def from_torchvision(model: ov.Model, transform: Callable, input_example: Any,
                         input_name: Union[str, None] = None) -> ov.Model:
        """Embed torchvision preprocessing in an OpenVINO model.

        Arguments:
            model (ov.Model):
                OpenVINO model to embed the preprocessing into.
            transform (Callable):
                torchvision transform to convert
            input_example (torch.Tensor or np.ndarray or PIL.Image):
                Example of input data for transform to trace its structure.
                Don't confuse with the model input.
            input_name (str, optional):
                Name of the current model's input node to connect with preprocessing.
                Not needed if the model has one input.

        Returns:
            ov.Model: OpenVINO Model object with embedded preprocessing
        Example:
            >>> model = PreprocessConverter.from_torchvision(model, transform, input_example, input_name="input")
        """
        try:
            import PIL
            import torch
            from torchvision import transforms
            from .torchvision_preprocessing import _from_torchvision
            return _from_torchvision(model, transform, input_example, input_name)
        except ImportError as e:
            raise ImportError(f"Please install torch, torchvision and pillow packages:\n{e}")
        except Exception as e:
            logging.error(f"Unexpected error: {e}")
            raise e
@@ -0,0 +1,5 @@
-c ../../../../constraints.txt
torch
torchvision; platform_machine == 'arm64' and python_version >= '3.8'
torchvision; platform_machine != 'arm64'
pillow
@@ -0,0 +1,325 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# mypy: disable-error-code="no-redef"

import numbers
import logging
import copy
import numpy as np
from typing import List, Dict
from abc import ABCMeta, abstractmethod
from typing import Callable, Any, Union, Tuple
from typing import Sequence as SequenceType
from collections.abc import Sequence
from PIL import Image

import torch
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode

import openvino.runtime as ov
import openvino.runtime.opset11 as ops
from openvino.runtime import Layout, Type
from openvino.runtime.utils.decorators import custom_preprocess_function
from openvino.preprocess import PrePostProcessor, ResizeAlgorithm, ColorFormat


TORCHTYPE_TO_OVTYPE = {
    float: ov.Type.f32,
    int: ov.Type.i32,
    bool: ov.Type.boolean,
    torch.float16: ov.Type.f16,
    torch.float32: ov.Type.f32,
    torch.float64: ov.Type.f64,
    torch.uint8: ov.Type.u8,
    torch.int8: ov.Type.i8,
    torch.int32: ov.Type.i32,
    torch.int64: ov.Type.i64,
    torch.bool: ov.Type.boolean,
    torch.DoubleTensor: ov.Type.f64,
    torch.FloatTensor: ov.Type.f32,
    torch.IntTensor: ov.Type.i32,
    torch.LongTensor: ov.Type.i64,
    torch.BoolTensor: ov.Type.boolean,
}


def _setup_size(size: Any, error_msg: str) -> SequenceType[int]:
    # TODO: refactor into @singledispatch once Python 3.7 support is dropped
    if isinstance(size, numbers.Number):
        return int(size), int(size)  # type: ignore
    if isinstance(size, Sequence):
        if len(size) == 1:
            return size[0], size[0]
        elif len(size) == 2:
            return size
    raise ValueError(error_msg)


def _NHWC_to_NCHW(input_shape: List) -> List:  # noqa N802
    new_shape = copy.deepcopy(input_shape)
    new_shape[1] = input_shape[3]
    new_shape[2] = input_shape[1]
    new_shape[3] = input_shape[2]
    return new_shape


def _to_list(transform: Callable) -> List:
    # TODO: refactor into @singledispatch once Python 3.7 support is dropped
    if isinstance(transform, torch.nn.Sequential):
        return list(transform)
    elif isinstance(transform, transforms.Compose):
        return transform.transforms
    else:
        raise TypeError(f"Unsupported transform type: {type(transform)}")


def _get_shape_layout_from_data(input_example: Union[torch.Tensor, np.ndarray, Image.Image]) -> Tuple[List, Layout]:
    if isinstance(input_example, (torch.Tensor, np.ndarray, Image.Image)):  # PyTorch, OpenCV, numpy, PILLOW
        shape = list(np.array(input_example, copy=False).shape)
        layout = Layout("NCHW") if isinstance(input_example, torch.Tensor) else Layout("NHWC")
    else:
        raise TypeError(f"Unsupported input type: {type(input_example)}")

    if len(shape) == 3:
        shape = [1] + shape
    elif len(shape) != 4:
        raise ValueError(f"Unsupported number of input dimensions: {len(shape)}")

    return shape, layout


class TransformConverterBase(metaclass=ABCMeta):

    def __init__(self, **kwargs: Any) -> None:  # noqa B027
        pass

    @abstractmethod
    def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None:
        pass


class TransformConverterFactory:

    registry: Dict[str, Callable] = {}

    @classmethod
    def register(cls: Callable, target_type: Union[Callable, None] = None) -> Callable:
        def inner_wrapper(wrapped_class: TransformConverterBase) -> Callable:
            registered_name = wrapped_class.__name__ if target_type is None else target_type.__name__
            if registered_name in cls.registry:
                logging.warning(f"Executor {registered_name} already exists. {wrapped_class.__name__} will replace it.")
            cls.registry[registered_name] = wrapped_class
            return wrapped_class  # type: ignore

        return inner_wrapper

    @classmethod
    def convert(cls: Callable, converter_type: Callable, *args: Any, **kwargs: Any) -> Callable:
        transform_name = converter_type.__name__
        if transform_name not in cls.registry:
            raise ValueError(f"{transform_name} is not supported.")

        converter = cls.registry[transform_name]()
        return converter.convert(*args, **kwargs)


@TransformConverterFactory.register(transforms.Normalize)
class _(TransformConverterBase):
    def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None:
        if transform.inplace:
            raise ValueError("Inplace Normalization is not supported.")
        ppp.input(input_idx).preprocess().mean(transform.mean).scale(transform.std)


@TransformConverterFactory.register(transforms.ConvertImageDtype)
class _(TransformConverterBase):
    def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None:
        ppp.input(input_idx).preprocess().convert_element_type(TORCHTYPE_TO_OVTYPE[transform.dtype])


@TransformConverterFactory.register(transforms.Grayscale)
class _(TransformConverterBase):
    def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None:
        input_shape = meta["input_shape"]
        layout = meta["layout"]

        input_shape[layout.get_index_by_name("C")] = 1

        ppp.input(input_idx).preprocess().convert_color(ColorFormat.GRAY)
        if transform.num_output_channels != 1:
            input_shape[layout.get_index_by_name("C")] = transform.num_output_channels

            @custom_preprocess_function
            def broadcast_node(output: ov.Output) -> Callable:
                return ops.broadcast(
                    data=output,
                    target_shape=input_shape,
                )
            ppp.input(input_idx).preprocess().custom(broadcast_node)

        meta["input_shape"] = input_shape


@TransformConverterFactory.register(transforms.Pad)
class _(TransformConverterBase):
    def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None:
        image_dimensions = list(meta["image_dimensions"])
        layout = meta["layout"]
        torch_padding = transform.padding
        pad_mode = transform.padding_mode

        if pad_mode == "constant":
            if isinstance(transform.fill, tuple):
                raise ValueError("Different fill values for R, G, B channels are not supported.")

        pads_begin = [0 for _ in meta["input_shape"]]
        pads_end = [0 for _ in meta["input_shape"]]

        # padding equal on all sides
        if isinstance(torch_padding, int):
            image_dimensions[0] += 2 * torch_padding
            image_dimensions[1] += 2 * torch_padding

            pads_begin[layout.get_index_by_name("H")] = torch_padding
            pads_begin[layout.get_index_by_name("W")] = torch_padding
            pads_end[layout.get_index_by_name("H")] = torch_padding
            pads_end[layout.get_index_by_name("W")] = torch_padding

        # padding different in horizontal and vertical axis
        elif len(torch_padding) == 2:
            image_dimensions[0] += sum(torch_padding)
            image_dimensions[1] += sum(torch_padding)

            pads_begin[layout.get_index_by_name("H")] = torch_padding[1]
            pads_begin[layout.get_index_by_name("W")] = torch_padding[0]
            pads_end[layout.get_index_by_name("H")] = torch_padding[1]
            pads_end[layout.get_index_by_name("W")] = torch_padding[0]

        # padding different on top, bottom, left and right of image
        else:
            image_dimensions[0] += torch_padding[1] + torch_padding[3]
            image_dimensions[1] += torch_padding[0] + torch_padding[2]

            pads_begin[layout.get_index_by_name("H")] = torch_padding[1]
            pads_begin[layout.get_index_by_name("W")] = torch_padding[0]
            pads_end[layout.get_index_by_name("H")] = torch_padding[3]
            pads_end[layout.get_index_by_name("W")] = torch_padding[2]

        @custom_preprocess_function
        def pad_node(output: ov.Output) -> Callable:
            return ops.pad(
                output,
                pad_mode=pad_mode,
                pads_begin=pads_begin,
                pads_end=pads_end,
                arg_pad_value=np.array(transform.fill, dtype=np.uint8) if pad_mode == "constant" else None,
            )

        ppp.input(input_idx).preprocess().custom(pad_node)
        meta["image_dimensions"] = tuple(image_dimensions)


@TransformConverterFactory.register(transforms.ToTensor)
class _(TransformConverterBase):
    def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None:
        input_shape = meta["input_shape"]
        layout = meta["layout"]

        ppp.input(input_idx).tensor().set_element_type(Type.u8).set_layout(Layout("NHWC")).set_color_format(ColorFormat.RGB)  # noqa ECE001

        if layout == Layout("NHWC"):
            input_shape = _NHWC_to_NCHW(input_shape)
            layout = Layout("NCHW")
            ppp.input(input_idx).preprocess().convert_layout(layout)
        ppp.input(input_idx).preprocess().convert_element_type(Type.f32)
        ppp.input(input_idx).preprocess().scale(255.0)

        meta["input_shape"] = input_shape
        meta["layout"] = layout


@TransformConverterFactory.register(transforms.CenterCrop)
class _(TransformConverterBase):
    def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None:
        input_shape = meta["input_shape"]
        source_size = meta["image_dimensions"]
        target_size = _setup_size(transform.size, "Incorrect size type for CenterCrop operation")

        if target_size[0] > source_size[0] or target_size[1] > source_size[1]:
            raise ValueError(f"CenterCrop size={target_size} is greater than source_size={source_size}")

        bottom_left = []
        bottom_left.append(int((source_size[0] - target_size[0]) / 2))
        bottom_left.append(int((source_size[1] - target_size[1]) / 2))

        top_right = []
        top_right.append(min(bottom_left[0] + target_size[0], source_size[0] - 1))
        top_right.append(min(bottom_left[1] + target_size[1], source_size[1] - 1))

        bottom_left = [0] * len(input_shape[:-2]) + bottom_left if meta["layout"] == Layout("NCHW") else [0] + bottom_left + [0]  # noqa ECE001
        top_right = input_shape[:-2] + top_right if meta["layout"] == Layout("NCHW") else input_shape[:1] + top_right + input_shape[-1:]

        ppp.input(input_idx).preprocess().crop(bottom_left, top_right)
        meta["image_dimensions"] = (target_size[-2], target_size[-1])


@TransformConverterFactory.register(transforms.Resize)
class _(TransformConverterBase):
    def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None:
        resize_mode_map = {
            InterpolationMode.NEAREST: ResizeAlgorithm.RESIZE_NEAREST,
        }
        if transform.max_size:
            raise ValueError("Resize with max_size is not supported")
        if transform.interpolation is not InterpolationMode.NEAREST:
            raise ValueError("Only InterpolationMode.NEAREST is supported.")

        h, w = _setup_size(transform.size, "Incorrect size type for Resize operation")

        ppp.input(input_idx).tensor().set_layout(Layout("NCHW"))

        input_shape = meta["input_shape"]

        input_shape[meta["layout"].get_index_by_name("H")] = -1
        input_shape[meta["layout"].get_index_by_name("W")] = -1

        ppp.input(input_idx).tensor().set_shape(input_shape)
        ppp.input(input_idx).preprocess().resize(resize_mode_map[transform.interpolation], h, w)
        meta["input_shape"] = input_shape
        meta["image_dimensions"] = (h, w)


def _from_torchvision(model: ov.Model, transform: Callable, input_example: Any, input_name: Union[str, None] = None) -> ov.Model:

    if input_name is not None:
        input_idx = next((i for i, p in enumerate(model.get_parameters()) if p.get_friendly_name() == input_name), None)
    else:
        if len(model.get_parameters()) == 1:
            input_idx = 0
        else:
            raise ValueError("Model contains multiple inputs. Please specify the name of the input to which preprocessing is added.")

    if input_idx is None:
        raise ValueError(f"Input with name {input_name} is not found")

    input_shape, layout = _get_shape_layout_from_data(input_example)

    ppp = PrePostProcessor(model)
    ppp.input(input_idx).tensor().set_layout(layout)
    ppp.input(input_idx).tensor().set_shape(input_shape)

    image_dimensions = [input_shape[layout.get_index_by_name("H")], input_shape[layout.get_index_by_name("W")]]
    global_meta = {
        "input_shape": input_shape,
        "image_dimensions": image_dimensions,
        "layout": layout,
    }

    for tm in _to_list(transform):
        TransformConverterFactory.convert(type(tm), input_idx, ppp, tm, global_meta)

    updated_model = ppp.build()
    return updated_model
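Because dispatch goes through the TransformConverterFactory registry above, support for another torchvision transform can in principle be added by registering a further TransformConverterBase subclass. A hedged sketch, not part of this commit, for transforms.PILToTensor (which keeps uint8 values, so the 1/255 scaling done by the ToTensor converter is skipped); the import path assumes the installed package layout introduced here.

# Sketch: register a converter for a transform this commit does not cover.
from typing import Callable, Dict

import torchvision.transforms as transforms

from openvino.preprocess import PrePostProcessor, ColorFormat
from openvino.runtime import Layout, Type
from openvino.preprocess.torchvision.torchvision_preprocessing import (
    TransformConverterBase,
    TransformConverterFactory,
    _NHWC_to_NCHW,
)


@TransformConverterFactory.register(transforms.PILToTensor)
class _(TransformConverterBase):
    def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None:
        input_shape = meta["input_shape"]
        layout = meta["layout"]

        # Accept u8 NHWC RGB input, mirroring the ToTensor converter, but keep u8 values.
        ppp.input(input_idx).tensor().set_element_type(Type.u8).set_layout(Layout("NHWC")).set_color_format(ColorFormat.RGB)

        if layout == Layout("NHWC"):
            input_shape = _NHWC_to_NCHW(input_shape)
            layout = Layout("NCHW")
            ppp.input(input_idx).preprocess().convert_layout(layout)

        meta["input_shape"] = input_shape
        meta["layout"] = layout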
@@ -116,7 +116,8 @@ if(OpenVINO_SOURCE_DIR OR OpenVINODeveloperPackage_FOUND)
             COMPONENT ${OV_CPACK_COMP_PYTHON_OPENVINO}_${pyversion}
             ${OV_CPACK_COMP_PYTHON_OPENVINO_EXCLUDE_ALL}
             USE_SOURCE_PERMISSIONS
-            PATTERN "test_utils" EXCLUDE)
+            PATTERN "test_utils" EXCLUDE
+            PATTERN "torchvision/requirements.txt" EXCLUDE)

     install(TARGETS ${PROJECT_NAME}
             DESTINATION ${OV_CPACK_PYTHONDIR}/openvino
@@ -127,6 +128,11 @@ if(OpenVINO_SOURCE_DIR OR OpenVINODeveloperPackage_FOUND)
             DESTINATION ${OV_CPACK_PYTHONDIR}
             COMPONENT ${OV_CPACK_COMP_OPENVINO_DEV_REQ_FILES}
             ${OV_CPACK_COMP_OPENVINO_DEV_REQ_FILES_EXCLUDE_ALL})

+    install(FILES ${OpenVINOPython_SOURCE_DIR}/src/openvino/preprocess/torchvision/requirements.txt
+            DESTINATION ${OV_CPACK_PYTHONDIR}
+            COMPONENT ${OV_CPACK_COMP_OPENVINO_DEV_REQ_FILES}/src/openvino/preprocess/torchvision
+            ${OV_CPACK_COMP_OPENVINO_DEV_REQ_FILES_EXCLUDE_ALL})
+
     install(DIRECTORY ${OpenVINOPython_SOURCE_DIR}/tests
             DESTINATION tests/${PROJECT_NAME}
@@ -0,0 +1,241 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import copy
import pytest
import platform
from PIL import Image

import torch
import torch.nn.functional as f
import torchvision.transforms as transforms

from openvino.runtime import Core, Tensor
from openvino.tools.mo import convert_model

from openvino.preprocess.torchvision import PreprocessConverter


class Convnet(torch.nn.Module):
    def __init__(self, input_channels):
        super(Convnet, self).__init__()
        self.conv1 = torch.nn.Conv2d(input_channels, 6, 5)
        self.conv2 = torch.nn.Conv2d(6, 16, 3)

    def forward(self, data):
        data = f.max_pool2d(f.relu(self.conv1(data)), 2)
        data = f.max_pool2d(f.relu(self.conv2(data)), 2)
        return data


def _infer_pipelines(test_input, preprocess_pipeline, input_channels=3):
    torch_model = Convnet(input_channels)
    example_input = Tensor(np.expand_dims(test_input, axis=0).astype(np.float32))
    ov_model = convert_model(torch_model, example_input=example_input)
    core = Core()

    ov_model = PreprocessConverter.from_torchvision(
        model=ov_model, transform=preprocess_pipeline, input_example=Image.fromarray(test_input.astype("uint8"), "RGB"),
    )
    ov_model = core.compile_model(ov_model, "CPU")

    # Torch results
    torch_input = copy.deepcopy(test_input)
    test_image = Image.fromarray(torch_input.astype("uint8"), "RGB")
    transformed_input = preprocess_pipeline(test_image)
    transformed_input = torch.unsqueeze(transformed_input, dim=0)
    with torch.no_grad():
        torch_result = torch_model(transformed_input).numpy()

    # OpenVINO results
    ov_input = test_input
    ov_input = np.expand_dims(ov_input, axis=0)
    output = ov_model.output(0)
    ov_result = ov_model(ov_input)[output]

    return torch_result, ov_result


def test_normalize():
    test_input = np.random.randint(255, size=(224, 224, 3), dtype=np.uint8)
    preprocess_pipeline = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline)
    assert np.max(np.absolute(torch_result - ov_result)) < 4e-05


@pytest.mark.parametrize(
    ("interpolation", "tolerance"),
    [
        (transforms.InterpolationMode.NEAREST, 4e-05),
    ],
)
def test_resize(interpolation, tolerance):
    if platform.machine() in ["arm", "armv7l", "aarch64", "arm64", "ARM64"]:
        pytest.skip("Ticket: 114816")
    test_input = np.random.randint(255, size=(220, 220, 3), dtype=np.uint8)
    preprocess_pipeline = transforms.Compose(
        [
            transforms.Resize(224, interpolation=interpolation),
            transforms.ToTensor(),
        ],
    )
    torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline)
    assert np.max(np.absolute(torch_result - ov_result)) < tolerance


def test_convertimagedtype():
    test_input = np.random.randint(255, size=(224, 224, 3), dtype=np.uint8)
    preprocess_pipeline = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.ConvertImageDtype(torch.float16),
            transforms.ConvertImageDtype(torch.float32),
        ],
    )
    torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline)
    assert np.max(np.absolute(torch_result - ov_result)) < 2e-04


@pytest.mark.parametrize(
    ("test_input", "preprocess_pipeline"),
    [
        (
            np.random.randint(255, size=(220, 220, 3), dtype=np.uint8),
            transforms.Compose(
                [
                    transforms.Pad((2)),
                    transforms.ToTensor(),
                ],
            ),
        ),
        (
            np.random.randint(255, size=(218, 220, 3), dtype=np.uint8),
            transforms.Compose(
                [
                    transforms.Pad((2, 3)),
                    transforms.ToTensor(),
                ],
            ),
        ),
        (
            np.random.randint(255, size=(216, 218, 3), dtype=np.uint8),
            transforms.Compose(
                [
                    transforms.Pad((2, 3, 4, 5)),
                    transforms.ToTensor(),
                ],
            ),
        ),
        (
            np.random.randint(255, size=(216, 218, 3), dtype=np.uint8),
            transforms.Compose(
                [
                    transforms.Pad((2, 3, 4, 5), fill=3),
                    transforms.ToTensor(),
                ],
            ),
        ),
        (
            np.random.randint(255, size=(218, 220, 3), dtype=np.uint8),
            transforms.Compose(
                [
                    transforms.Pad((2, 3), padding_mode="edge"),
                    transforms.ToTensor(),
                ],
            ),
        ),
        (
            np.random.randint(255, size=(218, 220, 3), dtype=np.uint8),
            transforms.Compose(
                [
                    transforms.Pad((2, 3), padding_mode="reflect"),
                    transforms.ToTensor(),
                ],
            ),
        ),
        (
            np.random.randint(255, size=(218, 220, 3), dtype=np.uint8),
            transforms.Compose(
                [
                    transforms.Pad((2, 3), padding_mode="symmetric"),
                    transforms.ToTensor(),
                ],
            ),
        ),
    ],
)
def test_pad(test_input, preprocess_pipeline):
    torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline)
    assert np.max(np.absolute(torch_result - ov_result)) < 4e-05


def test_centercrop():
    test_input = np.random.randint(255, size=(260, 260, 3), dtype=np.uint8)
    preprocess_pipeline = transforms.Compose(
        [
            transforms.CenterCrop((224)),
            transforms.ToTensor(),
        ],
    )
    torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline)
    assert np.max(np.absolute(torch_result - ov_result)) < 4e-05


def test_grayscale():
    test_input = np.random.randint(255, size=(224, 224, 3), dtype=np.uint8)
    preprocess_pipeline = transforms.Compose([transforms.ToTensor(), transforms.Grayscale()])
    torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline, input_channels=1)
    assert np.max(np.absolute(torch_result - ov_result)) < 2e-04


def test_grayscale_num_output_channels():
    test_input = np.random.randint(255, size=(224, 224, 3), dtype=np.uint8)
    preprocess_pipeline = transforms.Compose([transforms.ToTensor(), transforms.Grayscale(3)])
    torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline)
    assert np.max(np.absolute(torch_result - ov_result)) < 2e-04


def test_pipeline_1():
    test_input = np.random.randint(255, size=(260, 260, 3), dtype=np.uint8)
    preprocess_pipeline = transforms.Compose(
        [
            transforms.Resize(256, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.CenterCrop((216, 218)),
            transforms.Pad((2, 3, 4, 5), fill=3),
            transforms.ToTensor(),
            transforms.ConvertImageDtype(torch.float32),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ],
    )
    torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline)
    assert np.max(np.absolute(torch_result - ov_result)) < 4e-05


def test_pipeline_2():
    if platform.machine() in ["arm", "armv7l", "aarch64", "arm64", "ARM64"]:
        pytest.skip("Ticket: 114816")
    test_input = np.random.randint(255, size=(224, 224, 3), dtype=np.uint8)
    preprocess_pipeline = transforms.Compose(
        [
            transforms.Resize(256, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.481, 0.457, 0.408), (0.268, 0.261, 0.275)),
        ],
    )
    torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline)
    assert np.max(np.absolute(torch_result - ov_result)) < 5e-03


def test_pipeline_3():
    test_input = np.random.randint(255, size=(260, 260, 3), dtype=np.uint8)
    preprocess_pipeline = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.ConvertImageDtype(torch.float),
        ],
    )
    torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline)
    assert np.max(np.absolute(torch_result - ov_result)) < 2e-03