Failed MO shape infer after SplitConcatPairToInterpolate transformation, when Concat has more than one producer (#4502)
* Fixes in the MO transformation SplitConcatPairToInterpolate. * Small fix. * Small fix. * Added test for the case when inputs of Concat are two Splits. * Added docstring to the function get_concat_after_split. * Some fixes. * Small fix.
This commit is contained in:
parent
4946d2e62c
commit
473c944e6e
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (c) 2020 Intel Corporation
|
||||
Copyright (c) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -32,6 +32,19 @@ from mo.ops.strided_slice import StridedSlice
|
||||
|
||||
|
||||
def get_concat_after_split(split: Node) -> Optional[Node]:
|
||||
"""
|
||||
This function gets consumers of the 'split' node, checks that the following conditions are fulfilled:
|
||||
1) 'split' node has only one consumer;
|
||||
2) for any output port of 'split', number of corresponding input ports of the consumer is the same;
|
||||
3) for any output port 'i' of the 'split', corresponding input ports of the consumer are
|
||||
[i * m, ..., i * m + (m - 1)], where 'm' is the same for all 'i';
|
||||
4) the consumer operation is 'Concat';
|
||||
5) 'split' is a unique producer for this 'Concat';
|
||||
and, if all these conditions are fulfilled, returns the above mentioned 'Concat' node. Otherwise, if some of these
|
||||
conditions is false, this functions returns None.
|
||||
:param split: Split node
|
||||
:return: Concat node, if all conditions are fulfilled, or None otherwise
|
||||
"""
|
||||
# If number of output nodes of 'split' is not equal to 1, then the transformation is not applicable.
|
||||
split_outputs = [d.node for _, p in split.out_ports().items() for d in p.get_connection().get_destinations()]
|
||||
names_of_split_outputs = set([n.name for n in split_outputs])
|
||||
@ -53,7 +66,17 @@ def get_concat_after_split(split: Node) -> Optional[Node]:
|
||||
|
||||
dest = split.out_port(0).get_destinations()[0].node
|
||||
# The transformation is applicable, only if next node is Concat.
|
||||
return dest if dest.soft_get('type') == 'Concat' else None
|
||||
if dest.soft_get('type') != 'Concat':
|
||||
return
|
||||
|
||||
# The transformation is applicable, only if Split is a unique producer for Concat.
|
||||
dest_inputs = [p.get_source().node for p in dest.in_ports().values() if not p.disconnected()]
|
||||
names_of_concat_inputs = set([n.soft_get('name', n.id) for n in dest_inputs])
|
||||
expected_number_of_unique_inputs = 1 if dest.has_valid('axis') else 2
|
||||
if len(names_of_concat_inputs) != expected_number_of_unique_inputs:
|
||||
return
|
||||
|
||||
return dest
|
||||
|
||||
|
||||
def get_interpolate_pattern(split: Node) -> dict:
|
||||
|
@ -503,6 +503,81 @@ ref_graph_node_attrs_for_3d_spatial_case_2 = {
|
||||
}
|
||||
|
||||
|
||||
graph_node_attrs_when_there_are_two_splits_one_concat = {
|
||||
'placeholder1': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder1_data': {
|
||||
'value': None,
|
||||
'shape': int64_array([1, 13, 13, 3, 2]),
|
||||
'kind': 'data',
|
||||
'data_type': None
|
||||
},
|
||||
'placeholder2': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder2_data': {
|
||||
'value': None,
|
||||
'shape': int64_array([1, 13, 13, 3, 2]),
|
||||
'kind': 'data',
|
||||
'data_type': None
|
||||
},
|
||||
'split1': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 2},
|
||||
'split1_axis_const': {
|
||||
'kind': 'op',
|
||||
'value': np.array(4, dtype=np.int64),
|
||||
'op': 'Const',
|
||||
'type': 'Const'
|
||||
},
|
||||
'split1_axis_const_data': {
|
||||
'value': np.array(4, dtype=np.int64),
|
||||
'shape': np.array(4, dtype=np.int64).shape,
|
||||
'kind': 'data'
|
||||
},
|
||||
'split2': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 2},
|
||||
'split2_axis_const': {
|
||||
'kind': 'op',
|
||||
'value': np.array(4, dtype=np.int64),
|
||||
'op': 'Const',
|
||||
'type': 'Const'
|
||||
},
|
||||
'split2_axis_const_data': {
|
||||
'value': np.array(4, dtype=np.int64),
|
||||
'shape': np.array(4, dtype=np.int64).shape,
|
||||
'kind': 'data'
|
||||
},
|
||||
'split1_data_0': {'value': None, 'shape': int64_array([1, 13, 13, 3, 1]), 'kind': 'data'},
|
||||
'split1_data_1': {'value': None, 'shape': int64_array([1, 13, 13, 3, 1]), 'kind': 'data'},
|
||||
'split2_data_0': {'value': None, 'shape': int64_array([1, 13, 13, 3, 1]), 'kind': 'data'},
|
||||
'split2_data_1': {'value': None, 'shape': int64_array([1, 13, 13, 3, 1]), 'kind': 'data'},
|
||||
'concat': {'type': 'Concat', 'kind': 'op', 'axis': 4},
|
||||
'concat_data': {'value': None, 'shape': int64_array([1, 13, 13, 3, 4]), 'kind': 'data'},
|
||||
'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
|
||||
'abs_data': {'value': None, 'shape': int64_array([1, 13, 13, 3, 4]), 'kind': 'data'},
|
||||
'output': {'kind': 'op', 'op': 'Result'},
|
||||
}
|
||||
|
||||
|
||||
graph_edges_when_there_are_two_splits_one_concat = [
|
||||
('placeholder1', 'placeholder1_data'),
|
||||
('placeholder2', 'placeholder2_data'),
|
||||
('placeholder1_data', 'split1', {'in': 0}),
|
||||
('split1_axis_const', 'split1_axis_const_data'),
|
||||
('split1_axis_const_data', 'split1', {'in': 1}),
|
||||
('split1', 'split1_data_0', {'out': 0}),
|
||||
('split1', 'split1_data_1', {'out': 1}),
|
||||
('placeholder2_data', 'split2', {'in': 0}),
|
||||
('split2_axis_const', 'split2_axis_const_data'),
|
||||
('split2_axis_const_data', 'split2', {'in': 1}),
|
||||
('split2', 'split2_data_0', {'out': 0}),
|
||||
('split2', 'split2_data_1', {'out': 1}),
|
||||
('split1_data_0', 'concat', {'in': 0}),
|
||||
('split1_data_1', 'concat', {'in': 1}),
|
||||
('split2_data_0', 'concat', {'in': 2}),
|
||||
('split2_data_1', 'concat', {'in': 3}),
|
||||
('concat', 'concat_data'),
|
||||
('concat_data', 'abs'),
|
||||
('abs', 'abs_data'),
|
||||
('abs_data', 'output')
|
||||
]
|
||||
|
||||
|
||||
class SplitConcatPairToInterpolateTest(unittest.TestCase):
|
||||
def test_spatial_2d_split_concat_1(self):
|
||||
graph = build_graph(
|
||||
@ -601,3 +676,16 @@ class SplitConcatPairToInterpolateTest(unittest.TestCase):
|
||||
SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'output')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_two_splits_one_concat(self):
|
||||
graph = build_graph(
|
||||
nodes_attrs=graph_node_attrs_when_there_are_two_splits_one_concat,
|
||||
edges=graph_edges_when_there_are_two_splits_one_concat
|
||||
)
|
||||
ref_graph = build_graph(
|
||||
nodes_attrs=graph_node_attrs_when_there_are_two_splits_one_concat,
|
||||
edges=graph_edges_when_there_are_two_splits_one_concat
|
||||
)
|
||||
SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'output')
|
||||
self.assertTrue(flag, resp)
|
||||
|
Loading…
Reference in New Issue
Block a user