From de4f9c813138f292601a65e8c9021ef9fa03c562 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Wed, 8 Feb 2023 13:53:43 +0100 Subject: [PATCH] [MO] Add support for dynamic cases in legacy Broadcast shape infer (#15546) * [MO] Add support for dynamic cases in legacy Broadcast shape infer * Update broadcast_test.py * Update broadcast.py * Update broadcast_test.py * Update broadcast_test.py --- tools/mo/openvino/tools/mo/ops/broadcast.py | 19 ++++++++++++------- tools/mo/unit_tests/mo/ops/broadcast_test.py | 16 ++++++++++++---- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/tools/mo/openvino/tools/mo/ops/broadcast.py b/tools/mo/openvino/tools/mo/ops/broadcast.py index 7207267edf5..12a2178d809 100644 --- a/tools/mo/openvino/tools/mo/ops/broadcast.py +++ b/tools/mo/openvino/tools/mo/ops/broadcast.py @@ -49,16 +49,21 @@ class Broadcast(Op): target_shape_shape = node.in_port(1).data.get_shape() target_shape = node.in_port(1).data.get_value() assert node.has_and_set('mode'), 'Broadcasting mode is not defined for node "{}"'.format(node_name) - # Dynamic target shape is possible to infer only if shape of target shape is static and 1D - if target_shape is None and len(target_shape_shape) == 1 and (len(input_shape) <= 1 or node.mode == 'explicit'): - assert is_fully_defined(target_shape_shape) - new_shape = undefined_shape_of_rank(target_shape_shape.item(0)) - node.out_port(0).data.set_shape(new_shape) - return - assert target_shape is not None, 'Output shape is not defined for node "{}"'.format(node_name) PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape') + # Dynamic target shape is possible to infer only if shape of target shape is static + if target_shape is None: + assert len(target_shape_shape) == 1, 'Shape of target_shape must be [1] for node "{}"'.format(node_name) + assert is_fully_defined(target_shape_shape), 'Output shape is not defined for node "{}"'.format(node_name) + new_shape = undefined_shape_of_rank(target_shape_shape.item(0)) + node.out_port(0).data.set_shape(new_shape) + if node.mode == 'explicit': + assert node.is_in_port_connected( + 2), 'Axes mapping must be specified for Broadcast(mode="explicit"). Node: `{}`'.format(node_name) + PermuteInputs().set_input_permutation(node.in_node(2), node, 'output:0', 'axis') + return + if input_value is not None and not node.has_and_set('stop_value_propagation') and \ is_fully_defined(target_shape): if node.mode == 'numpy': diff --git a/tools/mo/unit_tests/mo/ops/broadcast_test.py b/tools/mo/unit_tests/mo/ops/broadcast_test.py index 2be08f79848..7da252317a7 100644 --- a/tools/mo/unit_tests/mo/ops/broadcast_test.py +++ b/tools/mo/unit_tests/mo/ops/broadcast_test.py @@ -78,11 +78,15 @@ class BroadcastTest(unittest.TestCase): self.assertTrue(np.array_equal(broadcast_node.out_node().shape, np.array(target_shape))) @generate(*[ - ([1], [3], 'numpy', undefined_shape_of_rank(3)), - ([1], [3], 'explicit', undefined_shape_of_rank(3)), - ([1, 2], [3], 'numpy', None, True), + ([1], [3], [0], 'explicit', undefined_shape_of_rank(3)), + ([1], [3], None, 'numpy', undefined_shape_of_rank(3)), + ([1], [3], None, 'bidirectional', undefined_shape_of_rank(3)), + ([1, 7], [4], [1, 2], 'explicit', undefined_shape_of_rank(4)), + ([1, 2], [3], None, 'numpy', undefined_shape_of_rank(3)), + ([1, 1], [2], None, 'bidirectional', undefined_shape_of_rank(2)), + ([1, 1], [2, 1], None, 'numpy', None, True), ]) - def test_broadcast_dynamic(self, data, target_shape_shape, mode='numpy', ref_out_shape=None, test_raising=False): + def test_broadcast_dynamic(self, data, target_shape_shape, axes_mapping=None, mode='numpy', ref_out_shape=None, test_raising=False): nodes = { **shaped_data('data', int64_array(data)), **shaped_data('target_shape', int64_array(target_shape_shape)), @@ -93,6 +97,10 @@ class BroadcastTest(unittest.TestCase): ('target_shape', 'broadcast'), ('broadcast', 'broadcast_d')] + if axes_mapping is not None: + nodes.update(**valued_const_with_data('axes_mapping', int64_array(axes_mapping))) + edges.append(('axes_mapping', 'axes_mapping_d')) + edges.append(('axes_mapping_d', 'broadcast')) graph = build_graph(nodes, edges) broadcast_node = Node(graph, 'broadcast')