[MO] Add support for dynamic cases in legacy Broadcast shape infer (#15546)
* [MO] Add support for dynamic cases in legacy Broadcast shape infer
* Update broadcast_test.py
* Update broadcast.py
* Update broadcast_test.py
* Update broadcast_test.py
This commit is contained in:
parent
70177dbfb3
commit
de4f9c8131
@@ -49,16 +49,21 @@ class Broadcast(Op):
|
||||
target_shape_shape = node.in_port(1).data.get_shape()
|
||||
target_shape = node.in_port(1).data.get_value()
|
||||
assert node.has_and_set('mode'), 'Broadcasting mode is not defined for node "{}"'.format(node_name)
|
||||
# Dynamic target shape is possible to infer only if shape of target shape is static and 1D
|
||||
if target_shape is None and len(target_shape_shape) == 1 and (len(input_shape) <= 1 or node.mode == 'explicit'):
|
||||
assert is_fully_defined(target_shape_shape)
|
||||
new_shape = undefined_shape_of_rank(target_shape_shape.item(0))
|
||||
node.out_port(0).data.set_shape(new_shape)
|
||||
return
|
||||
assert target_shape is not None, 'Output shape is not defined for node "{}"'.format(node_name)
|
||||
|
||||
PermuteInputs().set_input_permutation(node.in_node(1), node, 'output:0', 'shape')
|
||||
|
||||
# Dynamic target shape is possible to infer only if shape of target shape is static
|
||||
if target_shape is None:
|
||||
assert len(target_shape_shape) == 1, 'Shape of target_shape must be [1] for node "{}"'.format(node_name)
|
||||
assert is_fully_defined(target_shape_shape), 'Output shape is not defined for node "{}"'.format(node_name)
|
||||
new_shape = undefined_shape_of_rank(target_shape_shape.item(0))
|
||||
node.out_port(0).data.set_shape(new_shape)
|
||||
if node.mode == 'explicit':
|
||||
assert node.is_in_port_connected(
|
||||
2), 'Axes mapping must be specified for Broadcast(mode="explicit"). Node: `{}`'.format(node_name)
|
||||
PermuteInputs().set_input_permutation(node.in_node(2), node, 'output:0', 'axis')
|
||||
return
|
||||
|
||||
if input_value is not None and not node.has_and_set('stop_value_propagation') and \
|
||||
is_fully_defined(target_shape):
|
||||
if node.mode == 'numpy':
|
||||
|
@@ -78,11 +78,15 @@ class BroadcastTest(unittest.TestCase):
|
||||
self.assertTrue(np.array_equal(broadcast_node.out_node().shape, np.array(target_shape)))
|
||||
|
||||
@generate(*[
|
||||
([1], [3], 'numpy', undefined_shape_of_rank(3)),
|
||||
([1], [3], 'explicit', undefined_shape_of_rank(3)),
|
||||
([1, 2], [3], 'numpy', None, True),
|
||||
([1], [3], [0], 'explicit', undefined_shape_of_rank(3)),
|
||||
([1], [3], None, 'numpy', undefined_shape_of_rank(3)),
|
||||
([1], [3], None, 'bidirectional', undefined_shape_of_rank(3)),
|
||||
([1, 7], [4], [1, 2], 'explicit', undefined_shape_of_rank(4)),
|
||||
([1, 2], [3], None, 'numpy', undefined_shape_of_rank(3)),
|
||||
([1, 1], [2], None, 'bidirectional', undefined_shape_of_rank(2)),
|
||||
([1, 1], [2, 1], None, 'numpy', None, True),
|
||||
])
|
||||
def test_broadcast_dynamic(self, data, target_shape_shape, mode='numpy', ref_out_shape=None, test_raising=False):
|
||||
def test_broadcast_dynamic(self, data, target_shape_shape, axes_mapping=None, mode='numpy', ref_out_shape=None, test_raising=False):
|
||||
nodes = {
|
||||
**shaped_data('data', int64_array(data)),
|
||||
**shaped_data('target_shape', int64_array(target_shape_shape)),
|
||||
@@ -93,6 +97,10 @@ class BroadcastTest(unittest.TestCase):
|
||||
('target_shape', 'broadcast'),
|
||||
('broadcast', 'broadcast_d')]
|
||||
|
||||
if axes_mapping is not None:
|
||||
nodes.update(**valued_const_with_data('axes_mapping', int64_array(axes_mapping)))
|
||||
edges.append(('axes_mapping', 'axes_mapping_d'))
|
||||
edges.append(('axes_mapping_d', 'broadcast'))
|
||||
graph = build_graph(nodes, edges)
|
||||
|
||||
broadcast_node = Node(graph, 'broadcast')
|
||||
|
Loading…
Reference in New Issue
Block a user