Failed MO shape infer after SplitConcatPairToInterpolate transformation, when Concat has more than one producer (#4502)

* Fixes in the MO transformation SplitConcatPairToInterpolate.

* Small fix.

* Small fix.

* Added test for the case when inputs of Concat are two Splits.

* Added docstring to the function get_concat_after_split.

* Some fixes.

* Small fix.
This commit is contained in:
Vladimir Gavrilov 2021-03-02 14:09:42 +03:00 committed by GitHub
parent 4946d2e62c
commit 473c944e6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 113 additions and 2 deletions

View File

@ -1,5 +1,5 @@
""" """
Copyright (c) 2020 Intel Corporation Copyright (c) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -32,6 +32,19 @@ from mo.ops.strided_slice import StridedSlice
def get_concat_after_split(split: Node) -> Optional[Node]: def get_concat_after_split(split: Node) -> Optional[Node]:
"""
This function gets consumers of the 'split' node, checks that the following conditions are fulfilled:
1) 'split' node has only one consumer;
2) for any output port of 'split', number of corresponding input ports of the consumer is the same;
3) for any output port 'i' of the 'split', corresponding input ports of the consumer are
[i * m, ..., i * m + (m - 1)], where 'm' is the same for all 'i';
4) the consumer operation is 'Concat';
5) 'split' is a unique producer for this 'Concat';
and, if all these conditions are fulfilled, returns the above mentioned 'Concat' node. Otherwise, if some of these
conditions is false, this functions returns None.
:param split: Split node
:return: Concat node, if all conditions are fulfilled, or None otherwise
"""
# If number of output nodes of 'split' is not equal to 1, then the transformation is not applicable. # If number of output nodes of 'split' is not equal to 1, then the transformation is not applicable.
split_outputs = [d.node for _, p in split.out_ports().items() for d in p.get_connection().get_destinations()] split_outputs = [d.node for _, p in split.out_ports().items() for d in p.get_connection().get_destinations()]
names_of_split_outputs = set([n.name for n in split_outputs]) names_of_split_outputs = set([n.name for n in split_outputs])
@ -53,7 +66,17 @@ def get_concat_after_split(split: Node) -> Optional[Node]:
dest = split.out_port(0).get_destinations()[0].node dest = split.out_port(0).get_destinations()[0].node
# The transformation is applicable, only if next node is Concat. # The transformation is applicable, only if next node is Concat.
return dest if dest.soft_get('type') == 'Concat' else None if dest.soft_get('type') != 'Concat':
return
# The transformation is applicable, only if Split is a unique producer for Concat.
dest_inputs = [p.get_source().node for p in dest.in_ports().values() if not p.disconnected()]
names_of_concat_inputs = set([n.soft_get('name', n.id) for n in dest_inputs])
expected_number_of_unique_inputs = 1 if dest.has_valid('axis') else 2
if len(names_of_concat_inputs) != expected_number_of_unique_inputs:
return
return dest
def get_interpolate_pattern(split: Node) -> dict: def get_interpolate_pattern(split: Node) -> dict:

View File

@ -503,6 +503,81 @@ ref_graph_node_attrs_for_3d_spatial_case_2 = {
} }
graph_node_attrs_when_there_are_two_splits_one_concat = {
'placeholder1': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'placeholder1_data': {
'value': None,
'shape': int64_array([1, 13, 13, 3, 2]),
'kind': 'data',
'data_type': None
},
'placeholder2': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'placeholder2_data': {
'value': None,
'shape': int64_array([1, 13, 13, 3, 2]),
'kind': 'data',
'data_type': None
},
'split1': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 2},
'split1_axis_const': {
'kind': 'op',
'value': np.array(4, dtype=np.int64),
'op': 'Const',
'type': 'Const'
},
'split1_axis_const_data': {
'value': np.array(4, dtype=np.int64),
'shape': np.array(4, dtype=np.int64).shape,
'kind': 'data'
},
'split2': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 2},
'split2_axis_const': {
'kind': 'op',
'value': np.array(4, dtype=np.int64),
'op': 'Const',
'type': 'Const'
},
'split2_axis_const_data': {
'value': np.array(4, dtype=np.int64),
'shape': np.array(4, dtype=np.int64).shape,
'kind': 'data'
},
'split1_data_0': {'value': None, 'shape': int64_array([1, 13, 13, 3, 1]), 'kind': 'data'},
'split1_data_1': {'value': None, 'shape': int64_array([1, 13, 13, 3, 1]), 'kind': 'data'},
'split2_data_0': {'value': None, 'shape': int64_array([1, 13, 13, 3, 1]), 'kind': 'data'},
'split2_data_1': {'value': None, 'shape': int64_array([1, 13, 13, 3, 1]), 'kind': 'data'},
'concat': {'type': 'Concat', 'kind': 'op', 'axis': 4},
'concat_data': {'value': None, 'shape': int64_array([1, 13, 13, 3, 4]), 'kind': 'data'},
'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
'abs_data': {'value': None, 'shape': int64_array([1, 13, 13, 3, 4]), 'kind': 'data'},
'output': {'kind': 'op', 'op': 'Result'},
}
graph_edges_when_there_are_two_splits_one_concat = [
('placeholder1', 'placeholder1_data'),
('placeholder2', 'placeholder2_data'),
('placeholder1_data', 'split1', {'in': 0}),
('split1_axis_const', 'split1_axis_const_data'),
('split1_axis_const_data', 'split1', {'in': 1}),
('split1', 'split1_data_0', {'out': 0}),
('split1', 'split1_data_1', {'out': 1}),
('placeholder2_data', 'split2', {'in': 0}),
('split2_axis_const', 'split2_axis_const_data'),
('split2_axis_const_data', 'split2', {'in': 1}),
('split2', 'split2_data_0', {'out': 0}),
('split2', 'split2_data_1', {'out': 1}),
('split1_data_0', 'concat', {'in': 0}),
('split1_data_1', 'concat', {'in': 1}),
('split2_data_0', 'concat', {'in': 2}),
('split2_data_1', 'concat', {'in': 3}),
('concat', 'concat_data'),
('concat_data', 'abs'),
('abs', 'abs_data'),
('abs_data', 'output')
]
class SplitConcatPairToInterpolateTest(unittest.TestCase): class SplitConcatPairToInterpolateTest(unittest.TestCase):
def test_spatial_2d_split_concat_1(self): def test_spatial_2d_split_concat_1(self):
graph = build_graph( graph = build_graph(
@ -601,3 +676,16 @@ class SplitConcatPairToInterpolateTest(unittest.TestCase):
SplitConcatPairToInterpolate().find_and_replace_pattern(graph) SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, ref_graph, 'output') (flag, resp) = compare_graphs(graph, ref_graph, 'output')
self.assertTrue(flag, resp) self.assertTrue(flag, resp)
def test_two_splits_one_concat(self):
graph = build_graph(
nodes_attrs=graph_node_attrs_when_there_are_two_splits_one_concat,
edges=graph_edges_when_there_are_two_splits_one_concat
)
ref_graph = build_graph(
nodes_attrs=graph_node_attrs_when_there_are_two_splits_one_concat,
edges=graph_edges_when_there_are_two_splits_one_concat
)
SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, ref_graph, 'output')
self.assertTrue(flag, resp)