Failed MO shape infer after SplitConcatPairToInterpolate transformation, when Concat has more than one producer (#4502)
* Fixes in the MO transformation SplitConcatPairToInterpolate. * Small fix. * Small fix. * Added test for the case when inputs of Concat are two Splits. * Added docstring to the function get_concat_after_split. * Some fixes. * Small fix.
This commit is contained in:
parent
4946d2e62c
commit
473c944e6e
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Copyright (c) 2020 Intel Corporation
|
Copyright (c) 2018-2021 Intel Corporation
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -32,6 +32,19 @@ from mo.ops.strided_slice import StridedSlice
|
|||||||
|
|
||||||
|
|
||||||
def get_concat_after_split(split: Node) -> Optional[Node]:
|
def get_concat_after_split(split: Node) -> Optional[Node]:
|
||||||
|
"""
|
||||||
|
This function gets consumers of the 'split' node, checks that the following conditions are fulfilled:
|
||||||
|
1) 'split' node has only one consumer;
|
||||||
|
2) for any output port of 'split', number of corresponding input ports of the consumer is the same;
|
||||||
|
3) for any output port 'i' of the 'split', corresponding input ports of the consumer are
|
||||||
|
[i * m, ..., i * m + (m - 1)], where 'm' is the same for all 'i';
|
||||||
|
4) the consumer operation is 'Concat';
|
||||||
|
5) 'split' is a unique producer for this 'Concat';
|
||||||
|
and, if all these conditions are fulfilled, returns the above mentioned 'Concat' node. Otherwise, if some of these
|
||||||
|
conditions is false, this functions returns None.
|
||||||
|
:param split: Split node
|
||||||
|
:return: Concat node, if all conditions are fulfilled, or None otherwise
|
||||||
|
"""
|
||||||
# If number of output nodes of 'split' is not equal to 1, then the transformation is not applicable.
|
# If number of output nodes of 'split' is not equal to 1, then the transformation is not applicable.
|
||||||
split_outputs = [d.node for _, p in split.out_ports().items() for d in p.get_connection().get_destinations()]
|
split_outputs = [d.node for _, p in split.out_ports().items() for d in p.get_connection().get_destinations()]
|
||||||
names_of_split_outputs = set([n.name for n in split_outputs])
|
names_of_split_outputs = set([n.name for n in split_outputs])
|
||||||
@ -53,7 +66,17 @@ def get_concat_after_split(split: Node) -> Optional[Node]:
|
|||||||
|
|
||||||
dest = split.out_port(0).get_destinations()[0].node
|
dest = split.out_port(0).get_destinations()[0].node
|
||||||
# The transformation is applicable, only if next node is Concat.
|
# The transformation is applicable, only if next node is Concat.
|
||||||
return dest if dest.soft_get('type') == 'Concat' else None
|
if dest.soft_get('type') != 'Concat':
|
||||||
|
return
|
||||||
|
|
||||||
|
# The transformation is applicable, only if Split is a unique producer for Concat.
|
||||||
|
dest_inputs = [p.get_source().node for p in dest.in_ports().values() if not p.disconnected()]
|
||||||
|
names_of_concat_inputs = set([n.soft_get('name', n.id) for n in dest_inputs])
|
||||||
|
expected_number_of_unique_inputs = 1 if dest.has_valid('axis') else 2
|
||||||
|
if len(names_of_concat_inputs) != expected_number_of_unique_inputs:
|
||||||
|
return
|
||||||
|
|
||||||
|
return dest
|
||||||
|
|
||||||
|
|
||||||
def get_interpolate_pattern(split: Node) -> dict:
|
def get_interpolate_pattern(split: Node) -> dict:
|
||||||
|
@ -503,6 +503,81 @@ ref_graph_node_attrs_for_3d_spatial_case_2 = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
graph_node_attrs_when_there_are_two_splits_one_concat = {
|
||||||
|
'placeholder1': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||||
|
'placeholder1_data': {
|
||||||
|
'value': None,
|
||||||
|
'shape': int64_array([1, 13, 13, 3, 2]),
|
||||||
|
'kind': 'data',
|
||||||
|
'data_type': None
|
||||||
|
},
|
||||||
|
'placeholder2': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||||
|
'placeholder2_data': {
|
||||||
|
'value': None,
|
||||||
|
'shape': int64_array([1, 13, 13, 3, 2]),
|
||||||
|
'kind': 'data',
|
||||||
|
'data_type': None
|
||||||
|
},
|
||||||
|
'split1': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 2},
|
||||||
|
'split1_axis_const': {
|
||||||
|
'kind': 'op',
|
||||||
|
'value': np.array(4, dtype=np.int64),
|
||||||
|
'op': 'Const',
|
||||||
|
'type': 'Const'
|
||||||
|
},
|
||||||
|
'split1_axis_const_data': {
|
||||||
|
'value': np.array(4, dtype=np.int64),
|
||||||
|
'shape': np.array(4, dtype=np.int64).shape,
|
||||||
|
'kind': 'data'
|
||||||
|
},
|
||||||
|
'split2': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 2},
|
||||||
|
'split2_axis_const': {
|
||||||
|
'kind': 'op',
|
||||||
|
'value': np.array(4, dtype=np.int64),
|
||||||
|
'op': 'Const',
|
||||||
|
'type': 'Const'
|
||||||
|
},
|
||||||
|
'split2_axis_const_data': {
|
||||||
|
'value': np.array(4, dtype=np.int64),
|
||||||
|
'shape': np.array(4, dtype=np.int64).shape,
|
||||||
|
'kind': 'data'
|
||||||
|
},
|
||||||
|
'split1_data_0': {'value': None, 'shape': int64_array([1, 13, 13, 3, 1]), 'kind': 'data'},
|
||||||
|
'split1_data_1': {'value': None, 'shape': int64_array([1, 13, 13, 3, 1]), 'kind': 'data'},
|
||||||
|
'split2_data_0': {'value': None, 'shape': int64_array([1, 13, 13, 3, 1]), 'kind': 'data'},
|
||||||
|
'split2_data_1': {'value': None, 'shape': int64_array([1, 13, 13, 3, 1]), 'kind': 'data'},
|
||||||
|
'concat': {'type': 'Concat', 'kind': 'op', 'axis': 4},
|
||||||
|
'concat_data': {'value': None, 'shape': int64_array([1, 13, 13, 3, 4]), 'kind': 'data'},
|
||||||
|
'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
|
||||||
|
'abs_data': {'value': None, 'shape': int64_array([1, 13, 13, 3, 4]), 'kind': 'data'},
|
||||||
|
'output': {'kind': 'op', 'op': 'Result'},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
graph_edges_when_there_are_two_splits_one_concat = [
|
||||||
|
('placeholder1', 'placeholder1_data'),
|
||||||
|
('placeholder2', 'placeholder2_data'),
|
||||||
|
('placeholder1_data', 'split1', {'in': 0}),
|
||||||
|
('split1_axis_const', 'split1_axis_const_data'),
|
||||||
|
('split1_axis_const_data', 'split1', {'in': 1}),
|
||||||
|
('split1', 'split1_data_0', {'out': 0}),
|
||||||
|
('split1', 'split1_data_1', {'out': 1}),
|
||||||
|
('placeholder2_data', 'split2', {'in': 0}),
|
||||||
|
('split2_axis_const', 'split2_axis_const_data'),
|
||||||
|
('split2_axis_const_data', 'split2', {'in': 1}),
|
||||||
|
('split2', 'split2_data_0', {'out': 0}),
|
||||||
|
('split2', 'split2_data_1', {'out': 1}),
|
||||||
|
('split1_data_0', 'concat', {'in': 0}),
|
||||||
|
('split1_data_1', 'concat', {'in': 1}),
|
||||||
|
('split2_data_0', 'concat', {'in': 2}),
|
||||||
|
('split2_data_1', 'concat', {'in': 3}),
|
||||||
|
('concat', 'concat_data'),
|
||||||
|
('concat_data', 'abs'),
|
||||||
|
('abs', 'abs_data'),
|
||||||
|
('abs_data', 'output')
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class SplitConcatPairToInterpolateTest(unittest.TestCase):
|
class SplitConcatPairToInterpolateTest(unittest.TestCase):
|
||||||
def test_spatial_2d_split_concat_1(self):
|
def test_spatial_2d_split_concat_1(self):
|
||||||
graph = build_graph(
|
graph = build_graph(
|
||||||
@ -601,3 +676,16 @@ class SplitConcatPairToInterpolateTest(unittest.TestCase):
|
|||||||
SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
|
SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
|
||||||
(flag, resp) = compare_graphs(graph, ref_graph, 'output')
|
(flag, resp) = compare_graphs(graph, ref_graph, 'output')
|
||||||
self.assertTrue(flag, resp)
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
|
def test_two_splits_one_concat(self):
|
||||||
|
graph = build_graph(
|
||||||
|
nodes_attrs=graph_node_attrs_when_there_are_two_splits_one_concat,
|
||||||
|
edges=graph_edges_when_there_are_two_splits_one_concat
|
||||||
|
)
|
||||||
|
ref_graph = build_graph(
|
||||||
|
nodes_attrs=graph_node_attrs_when_there_are_two_splits_one_concat,
|
||||||
|
edges=graph_edges_when_there_are_two_splits_one_concat
|
||||||
|
)
|
||||||
|
SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
|
||||||
|
(flag, resp) = compare_graphs(graph, ref_graph, 'output')
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
Loading…
Reference in New Issue
Block a user