[ MO ] Fixed layout interpretation for 4/5D tensors calculated from ShapeOfs (#1634)

This commit is contained in:
Evgenya Stepyreva
2020-08-11 09:34:04 +03:00
committed by GitHub
parent 2b474c8a47
commit 2d2a6dbfd8
2 changed files with 38 additions and 17 deletions

View File

@@ -62,13 +62,6 @@ class Broadcast(Op):
assert target_shape is not None, 'Output shape is not defined for node "{}"'.format(node_name)
assert node.has_and_set('mode'), 'Broadcasting mode is not defined for node "{}"'.format(node_name)
if node.mode == 'numpy':
node.out_port(0).data.set_shape(uni_directional_shape_broadcasting(input_shape, target_shape))
elif node.mode == 'bidirectional':
node.out_port(0).data.set_shape(bi_directional_shape_broadcasting(input_shape, target_shape))
else:
raise Error('The node "{}" has unsupported mode "{}"'.format(node_name, node.mode))
PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')
if input_value is not None and not node.has_and_set('stop_value_propagation'):
@@ -76,3 +69,12 @@ class Broadcast(Op):
node.out_port(0).data.set_value(uni_directional_broadcasting(input_value, target_shape))
elif node.mode == 'bidirectional':
node.out_port(0).data.set_value(bi_directional_broadcasting(input_value, target_shape))
else:
raise Error('The node "{}" has unsupported mode "{}"'.format(node_name, node.mode))
else:
if node.mode == 'numpy':
node.out_port(0).data.set_shape(uni_directional_shape_broadcasting(input_shape, target_shape))
elif node.mode == 'bidirectional':
node.out_port(0).data.set_shape(bi_directional_shape_broadcasting(input_shape, target_shape))
else:
raise Error('The node "{}" has unsupported mode "{}"'.format(node_name, node.mode))