[MO] Fix: add PermuteInputs to shape broadcasting (#2419)
* fix: add PermuteInputs to shape broadcasting * fix type declaration typo
This commit is contained in:
parent
0879938250
commit
9149d899b3
@ -87,6 +87,7 @@ class Broadcast(Op):
|
||||
axes_mapping = node.in_port(2).data.get_value()
|
||||
assert axes_mapping is not None, 'Broadcast(mode="explicit") with dynamic axes_mapping input ' \
|
||||
'is not supported. Node: `{}`'.format(node_name)
|
||||
PermuteInputs().set_input_permutation(node.in_node(2), node, 'output:0', 'axis')
|
||||
axes_mapping = node.in_port(2).data.get_value()
|
||||
new_shape,_ = explicit_shape_broadcasting(input_shape, target_shape, axes_mapping)
|
||||
node.out_port(0).data.set_shape(new_shape)
|
||||
|
@ -79,7 +79,7 @@ def bi_directional_shape_broadcasting(input_shape_1: np.array, input_shape_2: np
|
||||
return np.maximum(shape_1, shape_2)
|
||||
|
||||
|
||||
def explicit_shape_broadcasting(input_shape: np.array, target_shape: np.array, axes_mapping: np.array) -> np.array:
|
||||
def explicit_shape_broadcasting(input_shape: np.array, target_shape: np.array, axes_mapping: np.array) -> [np.array, np.array]:
|
||||
"""
|
||||
Explicit shape broadcasting of input tensor. Function only asserts that values are correct and normalizes axes.
|
||||
Resulting shape is equal to target_shape.
|
||||
|
Loading…
Reference in New Issue
Block a user