[MO] Fix: add PermuteInputs to shape broadcasting (#2419)

* fix: add PermuteInputs to shape broadcasting

* fix type declaration typo
This commit is contained in:
Pavel Esir 2020-10-05 12:48:57 +03:00 committed by GitHub
parent 0879938250
commit 9149d899b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 1 deletions

View File

@ -87,6 +87,7 @@ class Broadcast(Op):
axes_mapping = node.in_port(2).data.get_value()
assert axes_mapping is not None, 'Broadcast(mode="explicit") with dynamic axes_mapping input ' \
'is not supported. Node: `{}`'.format(node_name)
PermuteInputs().set_input_permutation(node.in_node(2), node, 'output:0', 'axis')
axes_mapping = node.in_port(2).data.get_value()
new_shape,_ = explicit_shape_broadcasting(input_shape, target_shape, axes_mapping)
node.out_port(0).data.set_shape(new_shape)

View File

@ -79,7 +79,7 @@ def bi_directional_shape_broadcasting(input_shape_1: np.array, input_shape_2: np
return np.maximum(shape_1, shape_2)
def explicit_shape_broadcasting(input_shape: np.array, target_shape: np.array, axes_mapping: np.array) -> np.array:
def explicit_shape_broadcasting(input_shape: np.array, target_shape: np.array, axes_mapping: np.array) -> [np.array, np.array]:
"""
Explicit shape broadcasting of input tensor. Function only asserts that values are correct and normalizes axes.
Resulting shape is equal to target_shape.