Files
openvino/model-optimizer/mo/utils/broadcasting.py
Pavel Esir 9149d899b3 [MO] Fix: add PermuteInputs to shape broadcasting (#2419)
* fix: add PermuteInputs to shape broadcasting

* fix type declaration typo
2020-10-05 12:48:57 +03:00

141 lines
6.2 KiB
Python

"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging as log
import numpy as np
from mo.front.common.partial_infer.utils import int64_array
def make_equal_rank(shape_1: np.array, shape_2: np.array):
    """
    Left-pad the shape of lower rank with 1s so that both shapes have the same rank.
    When the ranks are already equal both shapes are returned untouched.
    :param shape_1: first shape
    :param shape_2: second shape
    :return: tuple (shape_1, shape_2) where the shorter shape has been prepended with 1s
    """
    rank_gap = len(shape_2) - len(shape_1)
    if rank_gap > 0:
        # shape_1 is the shorter one: insert all the missing leading 1s in one go
        shape_1 = np.insert(shape_1, 0, [1] * rank_gap)
    elif rank_gap < 0:
        # shape_2 is the shorter one
        shape_2 = np.insert(shape_2, 0, [1] * -rank_gap)
    return shape_1, shape_2
def uni_directional_shape_broadcasting(input_shape: np.array, target_shape: np.array):
    """
    Check that input_shape can be broadcasted to target_shape in one direction only (numpy semantic):
    the input may be stretched to match the target, but the target is never changed.
    :param input_shape: input shape to broadcast
    :param target_shape: target shape
    :return: the target shape if broadcasting is possible, None otherwise
    """
    padded_input = input_shape.copy()

    # uni-directional broadcasting can only keep or increase the rank of the input
    if len(input_shape) > len(target_shape):
        log.debug('The shape "{}" cannot be broadcasted to "{}"'.format(input_shape, target_shape))
        return None

    # align ranks by prepending 1s to the input shape
    padded_input, target_shape = make_equal_rank(padded_input, target_shape)

    # every input dimension must either match the target dimension or be equal to 1
    mismatch_found = any(src_dim != dst_dim and src_dim != 1 for src_dim, dst_dim in zip(padded_input, target_shape))
    if mismatch_found:
        log.debug('The shape "{}" cannot be broadcasted to "{}"'.format(input_shape, target_shape))
        return None
    return target_shape
def bi_directional_shape_broadcasting(input_shape_1: np.array, input_shape_2: np.array):
    """
    Bi-directional broadcasting of two shapes following the numpy semantic.
    The shapes are aligned by prepending 1s to the shorter one; two dimensions are compatible when they are
    equal or when one of them is 1.
    :param input_shape_1: first shape to broadcast
    :param input_shape_2: second shape to broadcast
    :return: broadcasted shape or None if broadcasting cannot be performed
    """
    shape_1 = input_shape_1.copy()
    shape_2 = input_shape_2.copy()
    shape_1, shape_2 = make_equal_rank(shape_1, shape_2)

    for left, right in zip(shape_1, shape_2):
        if left != right and left != 1 and right != 1:
            log.debug('The shape "{}" cannot be broadcasted to "{}"'.format(input_shape_1, input_shape_2))
            return None

    shape_1 = np.asarray(shape_1)
    shape_2 = np.asarray(shape_2)
    # For compatible dimensions the result is the one that is not 1. An element-wise maximum is not used because
    # it yields 1 for the (0, 1) pair, while numpy broadcasting of an empty dimension with 1 gives 0.
    return np.where(shape_1 == 1, shape_2, shape_1)
def explicit_shape_broadcasting(input_shape: np.array, target_shape: np.array, axes_mapping: np.array) -> [np.array, np.array]:
    """
    Explicit shape broadcasting of input tensor. Function only asserts that values are correct and normalizes axes.
    Resulting shape is equal to target_shape.
    :param input_shape: shape of the input to broadcast
    :param target_shape: target shape
    :param axes_mapping: a list of axis indices, each index maps an axis from the input to an axis in the output;
                         negative indices are counted from the end of target_shape
    :return: broadcasted shape (a copy of target_shape) and the axes mapping normalized to non-negative values
    :raises AssertionError: if the mapping size differs from the input rank, if a normalized axis is out of range
                            of target_shape, if the normalized mapping is not sorted, or if a mapped dimension of
                            target_shape differs from the corresponding input dimension
    """
    assert len(axes_mapping) == len(input_shape), "size of axes_mapping does not match to rank of input"

    target_rank = len(target_shape)
    # Normalize negative axes first: the order of raw values says nothing about the order of the axes they denote
    # (e.g. [1, -1] is a valid ascending mapping for rank 4, while [-1, 0] is not).
    axes_mapping = np.array([axis + target_rank if axis < 0 else axis for axis in axes_mapping], dtype=np.int64)
    assert np.all(np.diff(axes_mapping) >= 0), "axes_mapping is not sorted"

    res = target_shape.copy()
    for i, axis in enumerate(axes_mapping):
        assert 0 <= axis < len(res), "axis value from axes_mapping exceeds rank of target_shape"
        assert res[axis] == input_shape[i], "specified mapping axis in target_shape differs from axis in input_shape"
    return res, axes_mapping
def uni_directional_broadcasting(input_value: np.array, target_shape: np.array):
    """
    Broadcast the input tensor to the target shape in one direction only (numpy semantic).
    :param input_value: input value to broadcast
    :param target_shape: target shape
    :return: input value broadcasted to target_shape
    """
    value_shape = int64_array(input_value.shape)
    broadcasted_shape = uni_directional_shape_broadcasting(value_shape, target_shape)
    # the shape helper returns None when the input cannot be stretched to the target
    assert broadcasted_shape is not None, \
        'The tensor of shape "{}" cannot be uni-directionally broadcasted to shape "{}"'.format(input_value.shape,
                                                                                                target_shape)
    return np.broadcast_to(input_value, target_shape)
def bi_directional_broadcasting(input_value: np.array, second_shape: np.array):
    """
    Bi-directional broadcasting of input tensor to target shape following the numpy semantic
    :param input_value: input value to broadcast
    :param second_shape: second tensor shape
    :return: broadcasted value, a new array with the dtype of input_value
    """
    assert bi_directional_shape_broadcasting(int64_array(input_value.shape), second_shape) is not None, \
        'The tensor of shape "{}" cannot be bi-directionally broadcasted to shape "{}"'.format(input_value.shape,
                                                                                               second_shape)
    # The ones are created with the dtype of the input so that the multiplication stays in that dtype.
    # Multiplying by float64 ones and casting back silently corrupts integers which do not fit into the
    # 53-bit float64 mantissa.
    ones = np.ones(second_shape, dtype=input_value.dtype)
    return np.array(input_value * ones, dtype=input_value.dtype)
def explicit_broadcasting(input_value: np.array, target_shape: np.array, axes_mapping: np.array) -> np.array:
    """
    Explicit broadcasting of input tensor: axes of the input are placed to the positions of the output given in
    axes_mapping, every other output axis is added to the input as a dimension of size 1 and then broadcasted.
    The shape of the result is the shape returned by explicit_shape_broadcasting.
    :param input_value: input value to broadcast
    :param target_shape: target shape
    :param axes_mapping: a list of axis indices, each index maps an axis from the input_value to axis in the output
    :return: broadcasted value
    """
    res_shape, normalized_axes_mapping = explicit_shape_broadcasting(input_value.shape, target_shape, axes_mapping)

    # output axes which have no counterpart in the input must be added as dimensions of size 1
    mapped_axes = set(normalized_axes_mapping)
    unmapped_axes = [axis for axis in np.arange(len(target_shape)) if axis not in mapped_axes]

    input_expanded = np.expand_dims(input_value.copy(), axis=unmapped_axes)
    return np.broadcast_to(input_expanded, res_shape)