[ MO ] Fixed layout interpretation for 4/5D tensors calculated from ShapeOfs (#1634)
This commit is contained in:
committed by
GitHub
parent
2b474c8a47
commit
2d2a6dbfd8
@@ -62,13 +62,6 @@ class Broadcast(Op):
|
||||
assert target_shape is not None, 'Output shape is not defined for node "{}"'.format(node_name)
|
||||
assert node.has_and_set('mode'), 'Broadcasting mode is not defined for node "{}"'.format(node_name)
|
||||
|
||||
if node.mode == 'numpy':
|
||||
node.out_port(0).data.set_shape(uni_directional_shape_broadcasting(input_shape, target_shape))
|
||||
elif node.mode == 'bidirectional':
|
||||
node.out_port(0).data.set_shape(bi_directional_shape_broadcasting(input_shape, target_shape))
|
||||
else:
|
||||
raise Error('The node "{}" has unsupported mode "{}"'.format(node_name, node.mode))
|
||||
|
||||
PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')
|
||||
|
||||
if input_value is not None and not node.has_and_set('stop_value_propagation'):
|
||||
@@ -76,3 +69,12 @@ class Broadcast(Op):
|
||||
node.out_port(0).data.set_value(uni_directional_broadcasting(input_value, target_shape))
|
||||
elif node.mode == 'bidirectional':
|
||||
node.out_port(0).data.set_value(bi_directional_broadcasting(input_value, target_shape))
|
||||
else:
|
||||
raise Error('The node "{}" has unsupported mode "{}"'.format(node_name, node.mode))
|
||||
else:
|
||||
if node.mode == 'numpy':
|
||||
node.out_port(0).data.set_shape(uni_directional_shape_broadcasting(input_shape, target_shape))
|
||||
elif node.mode == 'bidirectional':
|
||||
node.out_port(0).data.set_shape(bi_directional_shape_broadcasting(input_shape, target_shape))
|
||||
else:
|
||||
raise Error('The node "{}" has unsupported mode "{}"'.format(node_name, node.mode))
|
||||
|
||||
Reference in New Issue
Block a user