Refactored with single dispatch generic function implementation (#19958)
* Refactored with single dispatch generic function implementation * Resolved mypy linting warnings * Update src/bindings/python/src/openvino/preprocess/torchvision/torchvision_preprocessing.py * Update src/bindings/python/src/openvino/preprocess/torchvision/torchvision_preprocessing.py --------- Co-authored-by: Przemyslaw Wysocki <przemyslaw.wysocki@intel.com>
This commit is contained in:
parent
0dd54c8a0e
commit
39f6cbf259
@ -10,6 +10,7 @@ import copy
|
||||
import numpy as np
|
||||
from typing import List, Dict
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from functools import singledispatch
|
||||
from typing import Callable, Any, Union, Tuple
|
||||
from typing import Sequence as SequenceType
|
||||
from collections.abc import Sequence
|
||||
@ -46,15 +47,22 @@ TORCHTYPE_TO_OVTYPE = {
|
||||
}
|
||||
|
||||
|
||||
@singledispatch
|
||||
def _setup_size(size: Any, error_msg: str) -> SequenceType[int]:
|
||||
# TODO: refactor into @singledispatch once Python 3.7 support is dropped
|
||||
if isinstance(size, numbers.Number):
|
||||
return int(size), int(size) # type: ignore
|
||||
if isinstance(size, Sequence):
|
||||
if len(size) == 1:
|
||||
return size[0], size[0]
|
||||
elif len(size) == 2:
|
||||
return size
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
@_setup_size.register
|
||||
def _setup_size_number(size: numbers.Number, error_msg: str) -> SequenceType[int]:
|
||||
return int(size), int(size) # type: ignore
|
||||
|
||||
|
||||
@_setup_size.register
|
||||
def _setup_size_sequence(size: Sequence, error_msg: str) -> SequenceType[int]:
|
||||
if len(size) == 1:
|
||||
return size[0], size[0]
|
||||
elif len(size) == 2:
|
||||
return size
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
@ -66,14 +74,19 @@ def _NHWC_to_NCHW(input_shape: List) -> List: # noqa N802
|
||||
return new_shape
|
||||
|
||||
|
||||
@singledispatch
|
||||
def _to_list(transform: Callable) -> List:
|
||||
# TODO: refactor into @singledispatch once Python 3.7 support is dropped
|
||||
if isinstance(transform, torch.nn.Sequential):
|
||||
return list(transform)
|
||||
elif isinstance(transform, transforms.Compose):
|
||||
return transform.transforms
|
||||
else:
|
||||
raise TypeError(f"Unsupported transform type: {type(transform)}")
|
||||
raise TypeError(f"Unsupported transform type: {type(transform)}")
|
||||
|
||||
|
||||
@_to_list.register
|
||||
def _to_list_torch_sequential(transform: torch.nn.Sequential) -> List:
|
||||
return list(transform)
|
||||
|
||||
|
||||
@_to_list.register
|
||||
def _to_list_transforms_compose(transform: transforms.Compose) -> List:
|
||||
return transform.transforms
|
||||
|
||||
|
||||
def _get_shape_layout_from_data(input_example: Union[torch.Tensor, np.ndarray, Image.Image]) -> Tuple[List, Layout]:
|
||||
|
Loading…
Reference in New Issue
Block a user