[PyOV] Add torchvision to OpenVINO preprocessing converter (#17934)

* Some progress

* refactoring

* Refactor tests, run black

* Refactor, flake

* Minor change

* Add support for num_output_channels

* WIP

* Add dependency

* Almost done

* Fix flake

* Add MO convert

* Minor changes

* Add interpolation mode tests

* Add requirements for preprocessing

* Fix linter

* Update tests

* Fix type error

* Introduce typing

* Rename module

* Code review

* Update src/bindings/python/src/openvino/preprocess/torchvision/torchvision_preprocessing.py

Co-authored-by: Anastasia Kuporosova <anastasia.kuporosova@intel.com>

* Update src/bindings/python/src/openvino/preprocess/torchvision/preprocess_converter.py

Co-authored-by: Anastasia Kuporosova <anastasia.kuporosova@intel.com>

* CR changes

* Fix mypy

* Minor change

* Minor change

* Fix flake

* typing change

* Update types

* Add tools to pythonpath

* Add MO reqs

* Minor change

* Fix PT FE issue

* bugfix

* Use absolute path

* Debug

* Debug req path

* Change MO to OVC

* Enable PT FE building

* Debug

* Skip some tests on ARM

* Add ticket numbers

* Update test val

* Change pytest version

* Modify GH workflow

* Update cmake

* Update cmake

* fix cmake

* Skip tests on ARM

* Fix after OVC MO changes

* Limit torch requirements

* Raise allowed diff in test

* Debug - remove new bindings

* Debug - remove interpolation modes

* Debug - change test

* Cleanup

* Minor change

* Disable on ARM and Py<38

* Update ARM marker

* limit torchvision for ARM

---------

Co-authored-by: Anastasia Kuporosova <anastasia.kuporosova@intel.com>
Co-authored-by: gklodkox <gracjanx.klodkowski@intel.com>
Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
Przemyslaw Wysocki 2023-08-02 14:44:30 +02:00 committed by GitHub
parent 7be660e551
commit 62fa09a181
12 changed files with 660 additions and 5 deletions

View File

@@ -178,6 +178,8 @@ jobs:
 python3 -m pip install -r $(REPO_DIR)/src/frontends/onnx/tests/requirements.txt
 # For running TensorFlow frontend unit tests
 python3 -m pip install -r $(REPO_DIR)/src/frontends/tensorflow/tests/requirements.txt
+# For running torchvision -> OpenVINO preprocess converter
+python3 -m pip install -r $(REPO_DIR)/src/bindings/python/src/openvino/preprocess/torchvision/requirements.txt
 # For MO unit tests
 python3 -m pip install -r $(REPO_DIR)/tools/mo/requirements_mxnet.txt
 python3 -m pip install -r $(REPO_DIR)/tools/mo/requirements_caffe.txt

View File

@@ -62,7 +62,7 @@ RUN cmake .. \
 -DENABLE_PROFILING_ITT=OFF \
 -DENABLE_SAMPLES=OFF \
 -DENABLE_OV_PADDLE_FRONTEND=OFF \
--DENABLE_OV_PYTORCH_FRONTEND=OFF \
+-DENABLE_OV_PYTORCH_FRONTEND=ON \
 -DENABLE_OV_TF_FRONTEND=OFF \
 -DENABLE_OPENVINO_DEBUG=OFF \
 -DCMAKE_INSTALL_PREFIX=/openvino/dist
@@ -72,5 +72,5 @@ RUN ninja install
 WORKDIR /openvino/src/bindings/python
 ENV OpenVINO_DIR=/openvino/dist/runtime/cmake
 ENV LD_LIBRARY_PATH=/openvino/dist/runtime/lib/intel64:/openvino/dist/runtime/3rdparty/tbb/lib
-ENV PYTHONPATH=/openvino/bin/intel64/${BUILD_TYPE}/python:${PYTHONPATH}
+ENV PYTHONPATH=/openvino/bin/intel64/${BUILD_TYPE}/python:/openvino/tools/mo:${PYTHONPATH}
 CMD tox

View File

@@ -531,6 +531,9 @@ jobs:
 # For running Paddle frontend unit tests
 python3 -m pip install -r ${{ env.OPENVINO_REPO }}/src/frontends/paddle/tests/requirements.txt
+# For torchvision to OpenVINO preprocessing converter
+python3 -m pip install -r ${{ env.OPENVINO_REPO }}/src/bindings/python/src/openvino/preprocess/torchvision/requirements.txt
 - name: Install MO dependencies
 run: |
 python3 -m pip install -r ${{ env.OPENVINO_REPO }}/tools/mo/requirements_mxnet.txt

View File

@@ -4,7 +4,7 @@ numpy>=1.16.6,<1.26 # Python bindings, frontends
 # pytest
 pytest>=5.0,<7.4
 pytest-dependency==0.5.1
-pytest-html==3.1.1
+pytest-html==3.2.0
 pytest-timeout==2.1.0
 # Python bindings
@@ -19,4 +19,8 @@ paddlepaddle==2.4.2
 tensorflow>=1.15.5,<2.13.0
 six~=1.16.0
 protobuf>=3.18.1,<4.0.0
 onnx==1.13.1
+# torchvision > OpenVINO preprocessing converter
+pillow>=9.0
+torch>=1.13

View File

@@ -40,3 +40,7 @@ tox
 types-pkg_resources
 wheel
 singledispatchmethod
+torch
+torchvision; platform_machine == 'arm64' and python_version >= '3.8'
+torchvision; platform_machine != 'arm64'
+pillow

View File

@@ -7,6 +7,7 @@ skip_install=True
 deps =
 -rrequirements.txt
 -rrequirements_test.txt
+-r /openvino/tools/mo/requirements.txt # for torchvision -> OV preprocess converter
 -r /openvino/src/frontends/onnx/tests/requirements.txt
 setenv =
 OV_BACKEND = {env:OV_BACKEND:"CPU"}
@@ -43,6 +44,7 @@ deps = -rrequirements.txt
 # D107 - Missing docstring in __init__
 # D412 - No blank lines allowed between a section header and its content
 # F401 - module imported but unused
+# N801 - class name '...' should use CapWords convention
 # N803 - argument name '...' should be lowercase
 # T001 - print found
 # W503 - line break before binary operator (prefer line breaks before op, not after)
@@ -66,6 +68,7 @@ per-file-ignores =
 src/openvino/runtime/*/ops.py: VNE001,VNE003
 tests_compatibility/test_ngraph/*: C101,C812,C815,C816,C819,CCE001,D212,E800,ECE001,N400,N802,N806,P101,P103,PT001,PT005,PT006,PT011,PT019,PT023,RST201,S001,VNE002
 src/compatibility/ngraph/*: C101,C812,C819,CCE001,E800,N806,P101,RST201,RST202,RST203,RST206,VNE001,VNE003
+src/openvino/preprocess/torchvision/*: N801, VNE001
 *__init__.py: F401
 [pydocstyle]

View File

@@ -0,0 +1,15 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""
Package: openvino
Torchvision to OpenVINO preprocess converter.
"""
# flake8: noqa
from openvino._pyopenvino import get_version as _get_version
__version__ = _get_version()
from .preprocess_converter import PreprocessConverter

View File

@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, Any, Union
import logging
import openvino.runtime as ov
class PreprocessConverter():
def __init__(self, model: ov.Model):
self._model = model
@staticmethod
def from_torchvision(model: ov.Model, transform: Callable, input_example: Any,
input_name: Union[str, None] = None) -> ov.Model:
"""Embed torchvision preprocessing in an OpenVINO model.
Arguments:
model (ov.Model):
OpenVINO model into which the preprocessing will be embedded.
transform (Callable):
torchvision transform to convert
input_example (torch.Tensor or np.ndarray or PIL.Image):
Example of input data for transform to trace its structure.
Don't confuse with the model input.
input_name (str, optional):
Name of the current model's input node to connect with preprocessing.
Not needed if the model has one input.
Returns:
ov.Model: OpenVINO Model object with embedded preprocessing
Example:
>>> model = PreprocessConverter.from_torchvision(model, transform, input_example, "input")
"""
try:
import PIL
import torch
from torchvision import transforms
from .torchvision_preprocessing import _from_torchvision
return _from_torchvision(model, transform, input_example, input_name)
except ImportError as e:
raise ImportError(f"Please install torch, torchvision and pillow packages:\n{e}")
except Exception as e:
logging.error(f"Unexpected error: {e}")
raise e
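
A minimal usage sketch for the converter above. The model file name, transform values and image sizes are illustrative placeholders, not part of this change; only the from_torchvision() call reflects the API added here.

# Illustrative sketch: "model.xml" and the transform pipeline are placeholders.
import numpy as np
from PIL import Image
from torchvision import transforms
from openvino.runtime import Core
from openvino.preprocess.torchvision import PreprocessConverter

core = Core()
model = core.read_model("model.xml")  # hypothetical single-input image model

transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# input_example only describes the data the transform receives (shape, layout, type);
# it is not the model input itself.
example = Image.fromarray(np.zeros((260, 260, 3), dtype=np.uint8), "RGB")

model = PreprocessConverter.from_torchvision(model=model, transform=transform, input_example=example)
compiled = core.compile_model(model, "CPU")  # now accepts raw uint8 NHWC images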

View File

@@ -0,0 +1,5 @@
-c ../../../../constraints.txt
torch
torchvision; platform_machine == 'arm64' and python_version >= '3.8'
torchvision; platform_machine != 'arm64'
pillow

View File

@@ -0,0 +1,325 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# mypy: disable-error-code="no-redef"
import numbers
import logging
import copy
import numpy as np
from typing import List, Dict
from abc import ABCMeta, abstractmethod
from typing import Callable, Any, Union, Tuple
from typing import Sequence as SequenceType
from collections.abc import Sequence
from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode
import openvino.runtime as ov
import openvino.runtime.opset11 as ops
from openvino.runtime import Layout, Type
from openvino.runtime.utils.decorators import custom_preprocess_function
from openvino.preprocess import PrePostProcessor, ResizeAlgorithm, ColorFormat
TORCHTYPE_TO_OVTYPE = {
float: ov.Type.f32,
int: ov.Type.i32,
bool: ov.Type.boolean,
torch.float16: ov.Type.f16,
torch.float32: ov.Type.f32,
torch.float64: ov.Type.f64,
torch.uint8: ov.Type.u8,
torch.int8: ov.Type.i8,
torch.int32: ov.Type.i32,
torch.int64: ov.Type.i64,
torch.bool: ov.Type.boolean,
torch.DoubleTensor: ov.Type.f64,
torch.FloatTensor: ov.Type.f32,
torch.IntTensor: ov.Type.i32,
torch.LongTensor: ov.Type.i64,
torch.BoolTensor: ov.Type.boolean,
}
def _setup_size(size: Any, error_msg: str) -> SequenceType[int]:
# TODO: refactor into @singledispatch once Python 3.7 support is dropped
if isinstance(size, numbers.Number):
return int(size), int(size) # type: ignore
if isinstance(size, Sequence):
if len(size) == 1:
return size[0], size[0]
elif len(size) == 2:
return size
raise ValueError(error_msg)
def _NHWC_to_NCHW(input_shape: List) -> List: # noqa N802
new_shape = copy.deepcopy(input_shape)
new_shape[1] = input_shape[3]
new_shape[2] = input_shape[1]
new_shape[3] = input_shape[2]
return new_shape
def _to_list(transform: Callable) -> List:
# TODO: refactor into @singledispatch once Python 3.7 support is dropped
if isinstance(transform, torch.nn.Sequential):
return list(transform)
elif isinstance(transform, transforms.Compose):
return transform.transforms
else:
raise TypeError(f"Unsupported transform type: {type(transform)}")
def _get_shape_layout_from_data(input_example: Union[torch.Tensor, np.ndarray, Image.Image]) -> Tuple[List, Layout]:
if isinstance(input_example, (torch.Tensor, np.ndarray, Image.Image)): # PyTorch, OpenCV, numpy, PILLOW
shape = list(np.array(input_example, copy=False).shape)
layout = Layout("NCHW") if isinstance(input_example, torch.Tensor) else Layout("NHWC")
else:
raise TypeError(f"Unsupported input type: {type(input_example)}")
if len(shape) == 3:
shape = [1] + shape
elif len(shape) != 4:
raise ValueError(f"Unsupported number of input dimensions: {len(shape)}")
return shape, layout
class TransformConverterBase(metaclass=ABCMeta):
def __init__(self, **kwargs: Any) -> None: # noqa B027
pass
@abstractmethod
def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None:
pass
class TransformConverterFactory:
registry: Dict[str, Callable] = {}
@classmethod
def register(cls: Callable, target_type: Union[Callable, None] = None) -> Callable:
def inner_wrapper(wrapped_class: TransformConverterBase) -> Callable:
registered_name = wrapped_class.__name__ if target_type is None else target_type.__name__
if registered_name in cls.registry:
logging.warning(f"Executor {registered_name} already exists. {wrapped_class.__name__} will replace it.")
cls.registry[registered_name] = wrapped_class
return wrapped_class # type: ignore
return inner_wrapper
@classmethod
def convert(cls: Callable, converter_type: Callable, *args: Any, **kwargs: Any) -> Callable:
transform_name = converter_type.__name__
if transform_name not in cls.registry:
raise ValueError(f"{transform_name} is not supported.")
converter = cls.registry[transform_name]()
return converter.convert(*args, **kwargs)
@TransformConverterFactory.register(transforms.Normalize)
class _(TransformConverterBase):
def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None:
if transform.inplace:
raise ValueError("Inplace Normaliziation is not supported.")
ppp.input(input_idx).preprocess().mean(transform.mean).scale(transform.std)
@TransformConverterFactory.register(transforms.ConvertImageDtype)
class _(TransformConverterBase):
def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None:
ppp.input(input_idx).preprocess().convert_element_type(TORCHTYPE_TO_OVTYPE[transform.dtype])
@TransformConverterFactory.register(transforms.Grayscale)
class _(TransformConverterBase):
def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None:
input_shape = meta["input_shape"]
layout = meta["layout"]
input_shape[layout.get_index_by_name("C")] = 1
ppp.input(input_idx).preprocess().convert_color(ColorFormat.GRAY)
if transform.num_output_channels != 1:
input_shape[layout.get_index_by_name("C")] = transform.num_output_channels
@custom_preprocess_function
def broadcast_node(output: ov.Output) -> Callable:
return ops.broadcast(
data=output,
target_shape=input_shape,
)
ppp.input(input_idx).preprocess().custom(broadcast_node)
meta["input_shape"] = input_shape
@TransformConverterFactory.register(transforms.Pad)
class _(TransformConverterBase):
def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None:
image_dimensions = list(meta["image_dimensions"])
layout = meta["layout"]
torch_padding = transform.padding
pad_mode = transform.padding_mode
if pad_mode == "constant":
if isinstance(transform.fill, tuple):
raise ValueError("Different fill values for R, G, B channels are not supported.")
pads_begin = [0 for _ in meta["input_shape"]]
pads_end = [0 for _ in meta["input_shape"]]
# padding equal on all sides
if isinstance(torch_padding, int):
image_dimensions[0] += 2 * torch_padding
image_dimensions[1] += 2 * torch_padding
pads_begin[layout.get_index_by_name("H")] = torch_padding
pads_begin[layout.get_index_by_name("W")] = torch_padding
pads_end[layout.get_index_by_name("H")] = torch_padding
pads_end[layout.get_index_by_name("W")] = torch_padding
# padding different in horizontal and vertical axis
elif len(torch_padding) == 2:
image_dimensions[0] += sum(torch_padding)
image_dimensions[1] += sum(torch_padding)
pads_begin[layout.get_index_by_name("H")] = torch_padding[1]
pads_begin[layout.get_index_by_name("W")] = torch_padding[0]
pads_end[layout.get_index_by_name("H")] = torch_padding[1]
pads_end[layout.get_index_by_name("W")] = torch_padding[0]
# padding different on top, bottom, left and right of image
else:
image_dimensions[0] += torch_padding[1] + torch_padding[3]
image_dimensions[1] += torch_padding[0] + torch_padding[2]
pads_begin[layout.get_index_by_name("H")] = torch_padding[1]
pads_begin[layout.get_index_by_name("W")] = torch_padding[0]
pads_end[layout.get_index_by_name("H")] = torch_padding[3]
pads_end[layout.get_index_by_name("W")] = torch_padding[2]
@custom_preprocess_function
def pad_node(output: ov.Output) -> Callable:
return ops.pad(
output,
pad_mode=pad_mode,
pads_begin=pads_begin,
pads_end=pads_end,
arg_pad_value=np.array(transform.fill, dtype=np.uint8) if pad_mode == "constant" else None,
)
ppp.input(input_idx).preprocess().custom(pad_node)
meta["image_dimensions"] = tuple(image_dimensions)
@TransformConverterFactory.register(transforms.ToTensor)
class _(TransformConverterBase):
def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None:
input_shape = meta["input_shape"]
layout = meta["layout"]
ppp.input(input_idx).tensor().set_element_type(Type.u8).set_layout(Layout("NHWC")).set_color_format(ColorFormat.RGB) # noqa ECE001
if layout == Layout("NHWC"):
input_shape = _NHWC_to_NCHW(input_shape)
layout = Layout("NCHW")
ppp.input(input_idx).preprocess().convert_layout(layout)
ppp.input(input_idx).preprocess().convert_element_type(Type.f32)
ppp.input(input_idx).preprocess().scale(255.0)
meta["input_shape"] = input_shape
meta["layout"] = layout
@TransformConverterFactory.register(transforms.CenterCrop)
class _(TransformConverterBase):
def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None:
input_shape = meta["input_shape"]
source_size = meta["image_dimensions"]
target_size = _setup_size(transform.size, "Incorrect size type for CenterCrop operation")
if target_size[0] > source_size[0] or target_size[1] > source_size[1]:
ValueError(f"CenterCrop size={target_size} is greater than source_size={source_size}")
bottom_left = []
bottom_left.append(int((source_size[0] - target_size[0]) / 2))
bottom_left.append(int((source_size[1] - target_size[1]) / 2))
top_right = []
top_right.append(min(bottom_left[0] + target_size[0], source_size[0] - 1))
top_right.append(min(bottom_left[1] + target_size[1], source_size[1] - 1))
bottom_left = [0] * len(input_shape[:-2]) + bottom_left if meta["layout"] == Layout("NCHW") else [0] + bottom_left + [0] # noqa ECE001
top_right = input_shape[:-2] + top_right if meta["layout"] == Layout("NCHW") else input_shape[:1] + top_right + input_shape[-1:]
ppp.input(input_idx).preprocess().crop(bottom_left, top_right)
meta["image_dimensions"] = (target_size[-2], target_size[-1])
@TransformConverterFactory.register(transforms.Resize)
class _(TransformConverterBase):
def convert(self, input_idx: int, ppp: PrePostProcessor, transform: Callable, meta: Dict) -> None:
resize_mode_map = {
InterpolationMode.NEAREST: ResizeAlgorithm.RESIZE_NEAREST,
}
if transform.max_size:
raise ValueError("Resize with max_size if not supported")
if transform.interpolation is not InterpolationMode.NEAREST:
raise ValueError("Only InterpolationMode.NEAREST is supported.")
h, w = _setup_size(transform.size, "Incorrect size type for Resize operation")
ppp.input(input_idx).tensor().set_layout(Layout("NCHW"))
input_shape = meta["input_shape"]
input_shape[meta["layout"].get_index_by_name("H")] = -1
input_shape[meta["layout"].get_index_by_name("W")] = -1
ppp.input(input_idx).tensor().set_shape(input_shape)
ppp.input(input_idx).preprocess().resize(resize_mode_map[transform.interpolation], h, w)
meta["input_shape"] = input_shape
meta["image_dimensions"] = (h, w)
def _from_torchvision(model: ov.Model, transform: Callable, input_example: Any, input_name: Union[str, None] = None) -> ov.Model:
if input_name is not None:
input_idx = next((i for i, p in enumerate(model.get_parameters()) if p.get_friendly_name() == input_name), None)
else:
if len(model.get_parameters()) == 1:
input_idx = 0
else:
raise ValueError("Model contains multiple inputs. Please specify the name of the input to which prepocessing is added.")
if input_idx is None:
raise ValueError(f"Input with name {input_name} is not found")
input_shape, layout = _get_shape_layout_from_data(input_example)
ppp = PrePostProcessor(model)
ppp.input(input_idx).tensor().set_layout(layout)
ppp.input(input_idx).tensor().set_shape(input_shape)
image_dimensions = [input_shape[layout.get_index_by_name("H")], input_shape[layout.get_index_by_name("W")]]
global_meta = {
"input_shape": input_shape,
"image_dimensions": image_dimensions,
"layout": layout,
}
for tm in _to_list(transform):
TransformConverterFactory.convert(type(tm), input_idx, ppp, tm, global_meta)
updated_model = ppp.build()
return updated_model
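
The transform coverage above is driven by the TransformConverterFactory registry. The sketch below shows how one more transform could be plugged in; PILToTensor support is not part of this change, the internal import path is an assumption, and the body simply mirrors the ToTensor converter without the 1/255 scaling.

# Hypothetical extension sketch, not part of this PR.
import torchvision.transforms as transforms
from openvino.runtime import Layout, Type
from openvino.preprocess import PrePostProcessor, ColorFormat
# Assumed import path for the internals defined above:
from openvino.preprocess.torchvision.torchvision_preprocessing import (
    TransformConverterBase, TransformConverterFactory, _NHWC_to_NCHW,
)

@TransformConverterFactory.register(transforms.PILToTensor)
class _(TransformConverterBase):
    def convert(self, input_idx: int, ppp: PrePostProcessor, transform, meta: dict) -> None:
        # Declare the input as a u8 NHWC RGB image and convert it to NCHW,
        # keeping the pipeline metadata in sync (no scaling, unlike ToTensor).
        ppp.input(input_idx).tensor().set_element_type(Type.u8).set_layout(Layout("NHWC")).set_color_format(ColorFormat.RGB)
        ppp.input(input_idx).preprocess().convert_layout(Layout("NCHW"))
        if meta["layout"] == Layout("NHWC"):
            meta["input_shape"] = _NHWC_to_NCHW(meta["input_shape"])
            meta["layout"] = Layout("NCHW")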

View File

@@ -116,7 +116,8 @@ if(OpenVINO_SOURCE_DIR OR OpenVINODeveloperPackage_FOUND)
 COMPONENT ${OV_CPACK_COMP_PYTHON_OPENVINO}_${pyversion}
 ${OV_CPACK_COMP_PYTHON_OPENVINO_EXCLUDE_ALL}
 USE_SOURCE_PERMISSIONS
-PATTERN "test_utils" EXCLUDE)
+PATTERN "test_utils" EXCLUDE
+PATTERN "torchvision/requirements.txt" EXCLUDE)
 install(TARGETS ${PROJECT_NAME}
 DESTINATION ${OV_CPACK_PYTHONDIR}/openvino
@@ -127,6 +128,11 @@ if(OpenVINO_SOURCE_DIR OR OpenVINODeveloperPackage_FOUND)
 DESTINATION ${OV_CPACK_PYTHONDIR}
 COMPONENT ${OV_CPACK_COMP_OPENVINO_DEV_REQ_FILES}
 ${OV_CPACK_COMP_OPENVINO_DEV_REQ_FILES_EXCLUDE_ALL})
+install(FILES ${OpenVINOPython_SOURCE_DIR}/src/openvino/preprocess/torchvision/requirements.txt
+DESTINATION ${OV_CPACK_PYTHONDIR}
+COMPONENT ${OV_CPACK_COMP_OPENVINO_DEV_REQ_FILES}/src/openvino/preprocess/torchvision
+${OV_CPACK_COMP_OPENVINO_DEV_REQ_FILES_EXCLUDE_ALL})
 install(DIRECTORY ${OpenVINOPython_SOURCE_DIR}/tests
 DESTINATION tests/${PROJECT_NAME}

View File

@@ -0,0 +1,241 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import copy
import pytest
import platform
from PIL import Image
import torch
import torch.nn.functional as f
import torchvision.transforms as transforms
from openvino.runtime import Core, Tensor
from openvino.tools.mo import convert_model
from openvino.preprocess.torchvision import PreprocessConverter
class Convnet(torch.nn.Module):
def __init__(self, input_channels):
super(Convnet, self).__init__()
self.conv1 = torch.nn.Conv2d(input_channels, 6, 5)
self.conv2 = torch.nn.Conv2d(6, 16, 3)
def forward(self, data):
data = f.max_pool2d(f.relu(self.conv1(data)), 2)
data = f.max_pool2d(f.relu(self.conv2(data)), 2)
return data
def _infer_pipelines(test_input, preprocess_pipeline, input_channels=3):
torch_model = Convnet(input_channels)
example_input = Tensor(np.expand_dims(test_input, axis=0).astype(np.float32))
ov_model = convert_model(torch_model, example_input=example_input)
core = Core()
ov_model = PreprocessConverter.from_torchvision(
model=ov_model, transform=preprocess_pipeline, input_example=Image.fromarray(test_input.astype("uint8"), "RGB"),
)
ov_model = core.compile_model(ov_model, "CPU")
# Torch results
torch_input = copy.deepcopy(test_input)
test_image = Image.fromarray(torch_input.astype("uint8"), "RGB")
transformed_input = preprocess_pipeline(test_image)
transformed_input = torch.unsqueeze(transformed_input, dim=0)
with torch.no_grad():
torch_result = torch_model(transformed_input).numpy()
# OpenVINO results
ov_input = test_input
ov_input = np.expand_dims(ov_input, axis=0)
output = ov_model.output(0)
ov_result = ov_model(ov_input)[output]
return torch_result, ov_result
def test_normalize():
test_input = np.random.randint(255, size=(224, 224, 3), dtype=np.uint8)
preprocess_pipeline = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline)
assert np.max(np.absolute(torch_result - ov_result)) < 4e-05
@pytest.mark.parametrize(
("interpolation", "tolerance"),
[
(transforms.InterpolationMode.NEAREST, 4e-05),
],
)
def test_resize(interpolation, tolerance):
if platform.machine() in ["arm", "armv7l", "aarch64", "arm64", "ARM64"]:
pytest.skip("Ticket: 114816")
test_input = np.random.randint(255, size=(220, 220, 3), dtype=np.uint8)
preprocess_pipeline = transforms.Compose(
[
transforms.Resize(224, interpolation=interpolation),
transforms.ToTensor(),
],
)
torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline)
assert np.max(np.absolute(torch_result - ov_result)) < tolerance
def test_convertimagedtype():
test_input = np.random.randint(255, size=(224, 224, 3), dtype=np.uint8)
preprocess_pipeline = transforms.Compose(
[
transforms.ToTensor(),
transforms.ConvertImageDtype(torch.float16),
transforms.ConvertImageDtype(torch.float32),
],
)
torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline)
assert np.max(np.absolute(torch_result - ov_result)) < 2e-04
@pytest.mark.parametrize(
("test_input", "preprocess_pipeline"),
[
(
np.random.randint(255, size=(220, 220, 3), dtype=np.uint8),
transforms.Compose(
[
transforms.Pad((2)),
transforms.ToTensor(),
],
),
),
(
np.random.randint(255, size=(218, 220, 3), dtype=np.uint8),
transforms.Compose(
[
transforms.Pad((2, 3)),
transforms.ToTensor(),
],
),
),
(
np.random.randint(255, size=(216, 218, 3), dtype=np.uint8),
transforms.Compose(
[
transforms.Pad((2, 3, 4, 5)),
transforms.ToTensor(),
],
),
),
(
np.random.randint(255, size=(216, 218, 3), dtype=np.uint8),
transforms.Compose(
[
transforms.Pad((2, 3, 4, 5), fill=3),
transforms.ToTensor(),
],
),
),
(
np.random.randint(255, size=(218, 220, 3), dtype=np.uint8),
transforms.Compose(
[
transforms.Pad((2, 3), padding_mode="edge"),
transforms.ToTensor(),
],
),
),
(
np.random.randint(255, size=(218, 220, 3), dtype=np.uint8),
transforms.Compose(
[
transforms.Pad((2, 3), padding_mode="reflect"),
transforms.ToTensor(),
],
),
),
(
np.random.randint(255, size=(218, 220, 3), dtype=np.uint8),
transforms.Compose(
[
transforms.Pad((2, 3), padding_mode="symmetric"),
transforms.ToTensor(),
],
),
),
],
)
def test_pad(test_input, preprocess_pipeline):
torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline)
assert np.max(np.absolute(torch_result - ov_result)) < 4e-05
def test_centercrop():
test_input = np.random.randint(255, size=(260, 260, 3), dtype=np.uint8)
preprocess_pipeline = transforms.Compose(
[
transforms.CenterCrop((224)),
transforms.ToTensor(),
],
)
torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline)
assert np.max(np.absolute(torch_result - ov_result)) < 4e-05
def test_grayscale():
test_input = np.random.randint(255, size=(224, 224, 3), dtype=np.uint8)
preprocess_pipeline = transforms.Compose([transforms.ToTensor(), transforms.Grayscale()])
torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline, input_channels=1)
assert np.max(np.absolute(torch_result - ov_result)) < 2e-04
def test_grayscale_num_output_channels():
test_input = np.random.randint(255, size=(224, 224, 3), dtype=np.uint8)
preprocess_pipeline = transforms.Compose([transforms.ToTensor(), transforms.Grayscale(3)])
torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline)
assert np.max(np.absolute(torch_result - ov_result)) < 2e-04
def test_pipeline_1():
test_input = np.random.randint(255, size=(260, 260, 3), dtype=np.uint8)
preprocess_pipeline = transforms.Compose(
[
transforms.Resize(256, interpolation=transforms.InterpolationMode.NEAREST),
transforms.CenterCrop((216, 218)),
transforms.Pad((2, 3, 4, 5), fill=3),
transforms.ToTensor(),
transforms.ConvertImageDtype(torch.float32),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
],
)
torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline)
assert np.max(np.absolute(torch_result - ov_result)) < 4e-05
def test_pipeline_2():
if platform.machine() in ["arm", "armv7l", "aarch64", "arm64", "ARM64"]:
pytest.skip("Ticket: 114816")
test_input = np.random.randint(255, size=(224, 224, 3), dtype=np.uint8)
preprocess_pipeline = transforms.Compose(
[
transforms.Resize(256, interpolation=transforms.InterpolationMode.NEAREST),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize((0.481, 0.457, 0.408), (0.268, 0.261, 0.275)),
],
)
torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline)
assert np.max(np.absolute(torch_result - ov_result)) < 5e-03
def test_pipeline_3():
test_input = np.random.randint(255, size=(260, 260, 3), dtype=np.uint8)
preprocess_pipeline = transforms.Compose(
[
transforms.ToTensor(),
transforms.ConvertImageDtype(torch.float),
],
)
torch_result, ov_result = _infer_pipelines(test_input, preprocess_pipeline)
assert np.max(np.absolute(torch_result - ov_result)) < 2e-03
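
A condensed, self-contained variant of the _infer_pipelines helper above, usable as a quick local sanity check. It reuses the Convnet class defined in this test file, assumes a CPU device is available, and copies the tolerance from test_normalize.

# Condensed parity check mirroring _infer_pipelines (assumes the CPU plugin is available).
import numpy as np
import torch
from PIL import Image
import torchvision.transforms as transforms
from openvino.runtime import Core, Tensor
from openvino.tools.mo import convert_model
from openvino.preprocess.torchvision import PreprocessConverter

data = np.random.randint(255, size=(224, 224, 3), dtype=np.uint8)
pipeline = transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

torch_model = Convnet(3)  # test model defined above
ov_model = convert_model(torch_model, example_input=Tensor(np.expand_dims(data, 0).astype(np.float32)))
ov_model = PreprocessConverter.from_torchvision(model=ov_model, transform=pipeline,
                                                input_example=Image.fromarray(data, "RGB"))
compiled = Core().compile_model(ov_model, "CPU")

with torch.no_grad():
    torch_out = torch_model(pipeline(Image.fromarray(data, "RGB")).unsqueeze(0)).numpy()
ov_out = compiled(np.expand_dims(data, 0))[compiled.output(0)]
assert np.max(np.abs(torch_out - ov_out)) < 4e-05  # same tolerance as test_normalize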