[MO] Add support for dynamic cases in legacy Broadcast shape infer (#15546)

* [MO] Add support for dynamic cases in legacy Broadcast shape infer

* Update broadcast_test.py

* Update broadcast.py

* Update broadcast_test.py

* Update broadcast_test.py
This commit is contained in:
Maxim Vafin 2023-02-08 13:53:43 +01:00 committed by GitHub
parent 70177dbfb3
commit de4f9c8131
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 11 deletions

View File

@ -49,16 +49,21 @@ class Broadcast(Op):
target_shape_shape = node.in_port(1).data.get_shape()
target_shape = node.in_port(1).data.get_value()
assert node.has_and_set('mode'), 'Broadcasting mode is not defined for node "{}"'.format(node_name)
# Dynamic target shape is possible to infer only if shape of target shape is static and 1D
if target_shape is None and len(target_shape_shape) == 1 and (len(input_shape) <= 1 or node.mode == 'explicit'):
assert is_fully_defined(target_shape_shape)
new_shape = undefined_shape_of_rank(target_shape_shape.item(0))
node.out_port(0).data.set_shape(new_shape)
return
assert target_shape is not None, 'Output shape is not defined for node "{}"'.format(node_name)
PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')
# Dynamic target shape is possible to infer only if shape of target shape is static
if target_shape is None:
assert len(target_shape_shape) == 1, 'Shape of target_shape must be [1] for node "{}"'.format(node_name)
assert is_fully_defined(target_shape_shape), 'Output shape is not defined for node "{}"'.format(node_name)
new_shape = undefined_shape_of_rank(target_shape_shape.item(0))
node.out_port(0).data.set_shape(new_shape)
if node.mode == 'explicit':
assert node.is_in_port_connected(
2), 'Axes mapping must be specified for Broadcast(mode="explicit"). Node: `{}`'.format(node_name)
PermuteInputs().set_input_permutation(node.in_node(2), node, 'output:0', 'axis')
return
if input_value is not None and not node.has_and_set('stop_value_propagation') and \
is_fully_defined(target_shape):
if node.mode == 'numpy':

View File

@ -78,11 +78,15 @@ class BroadcastTest(unittest.TestCase):
self.assertTrue(np.array_equal(broadcast_node.out_node().shape, np.array(target_shape)))
@generate(*[
([1], [3], 'numpy', undefined_shape_of_rank(3)),
([1], [3], 'explicit', undefined_shape_of_rank(3)),
([1, 2], [3], 'numpy', None, True),
([1], [3], [0], 'explicit', undefined_shape_of_rank(3)),
([1], [3], None, 'numpy', undefined_shape_of_rank(3)),
([1], [3], None, 'bidirectional', undefined_shape_of_rank(3)),
([1, 7], [4], [1, 2], 'explicit', undefined_shape_of_rank(4)),
([1, 2], [3], None, 'numpy', undefined_shape_of_rank(3)),
([1, 1], [2], None, 'bidirectional', undefined_shape_of_rank(2)),
([1, 1], [2, 1], None, 'numpy', None, True),
])
def test_broadcast_dynamic(self, data, target_shape_shape, mode='numpy', ref_out_shape=None, test_raising=False):
def test_broadcast_dynamic(self, data, target_shape_shape, axes_mapping=None, mode='numpy', ref_out_shape=None, test_raising=False):
nodes = {
**shaped_data('data', int64_array(data)),
**shaped_data('target_shape', int64_array(target_shape_shape)),
@ -93,6 +97,10 @@ class BroadcastTest(unittest.TestCase):
('target_shape', 'broadcast'),
('broadcast', 'broadcast_d')]
if axes_mapping is not None:
nodes.update(**valued_const_with_data('axes_mapping', int64_array(axes_mapping)))
edges.append(('axes_mapping', 'axes_mapping_d'))
edges.append(('axes_mapping_d', 'broadcast'))
graph = build_graph(nodes, edges)
broadcast_node = Node(graph, 'broadcast')