From 473c944e6ec3247583f2ea52bfe1997a35beabf7 Mon Sep 17 00:00:00 2001 From: Vladimir Gavrilov Date: Tue, 2 Mar 2021 14:09:42 +0300 Subject: [PATCH] Failed MO shape infer after SplitConcatPairToInterpolate transformation, when Concat has more than one producer (#4502) * Fixes in the MO transformation SplitConcatPairToInterpolate. * Small fix. * Small fix. * Added test for the case when inputs of Concat are two Splits. * Added docstring to the function get_concat_after_split. * Some fixes. * Small fix. --- .../middle/SplitConcatPairToInterpolate.py | 27 +++++- .../SplitConcatPairToInterpolate_test.py | 88 +++++++++++++++++++ 2 files changed, 113 insertions(+), 2 deletions(-) diff --git a/model-optimizer/extensions/middle/SplitConcatPairToInterpolate.py b/model-optimizer/extensions/middle/SplitConcatPairToInterpolate.py index 01d0353be0a..c037f8808be 100644 --- a/model-optimizer/extensions/middle/SplitConcatPairToInterpolate.py +++ b/model-optimizer/extensions/middle/SplitConcatPairToInterpolate.py @@ -1,5 +1,5 @@ """ - Copyright (c) 2020 Intel Corporation + Copyright (c) 2018-2021 Intel Corporation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,6 +32,19 @@ from mo.ops.strided_slice import StridedSlice def get_concat_after_split(split: Node) -> Optional[Node]: + """ + This function gets consumers of the 'split' node, checks that the following conditions are fulfilled: + 1) 'split' node has only one consumer; + 2) for any output port of 'split', number of corresponding input ports of the consumer is the same; + 3) for any output port 'i' of the 'split', corresponding input ports of the consumer are + [i * m, ..., i * m + (m - 1)], where 'm' is the same for all 'i'; + 4) the consumer operation is 'Concat'; + 5) 'split' is a unique producer for this 'Concat'; + and, if all these conditions are fulfilled, returns the above mentioned 'Concat' node. Otherwise, if some of these + conditions is false, this functions returns None. + :param split: Split node + :return: Concat node, if all conditions are fulfilled, or None otherwise + """ # If number of output nodes of 'split' is not equal to 1, then the transformation is not applicable. split_outputs = [d.node for _, p in split.out_ports().items() for d in p.get_connection().get_destinations()] names_of_split_outputs = set([n.name for n in split_outputs]) @@ -53,7 +66,17 @@ def get_concat_after_split(split: Node) -> Optional[Node]: dest = split.out_port(0).get_destinations()[0].node # The transformation is applicable, only if next node is Concat. - return dest if dest.soft_get('type') == 'Concat' else None + if dest.soft_get('type') != 'Concat': + return + + # The transformation is applicable, only if Split is a unique producer for Concat. + dest_inputs = [p.get_source().node for p in dest.in_ports().values() if not p.disconnected()] + names_of_concat_inputs = set([n.soft_get('name', n.id) for n in dest_inputs]) + expected_number_of_unique_inputs = 1 if dest.has_valid('axis') else 2 + if len(names_of_concat_inputs) != expected_number_of_unique_inputs: + return + + return dest def get_interpolate_pattern(split: Node) -> dict: diff --git a/model-optimizer/extensions/middle/SplitConcatPairToInterpolate_test.py b/model-optimizer/extensions/middle/SplitConcatPairToInterpolate_test.py index 72281f1f552..1843fece189 100644 --- a/model-optimizer/extensions/middle/SplitConcatPairToInterpolate_test.py +++ b/model-optimizer/extensions/middle/SplitConcatPairToInterpolate_test.py @@ -503,6 +503,81 @@ ref_graph_node_attrs_for_3d_spatial_case_2 = { } +graph_node_attrs_when_there_are_two_splits_one_concat = { + 'placeholder1': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, + 'placeholder1_data': { + 'value': None, + 'shape': int64_array([1, 13, 13, 3, 2]), + 'kind': 'data', + 'data_type': None + }, + 'placeholder2': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, + 'placeholder2_data': { + 'value': None, + 'shape': int64_array([1, 13, 13, 3, 2]), + 'kind': 'data', + 'data_type': None + }, + 'split1': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 2}, + 'split1_axis_const': { + 'kind': 'op', + 'value': np.array(4, dtype=np.int64), + 'op': 'Const', + 'type': 'Const' + }, + 'split1_axis_const_data': { + 'value': np.array(4, dtype=np.int64), + 'shape': np.array(4, dtype=np.int64).shape, + 'kind': 'data' + }, + 'split2': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 2}, + 'split2_axis_const': { + 'kind': 'op', + 'value': np.array(4, dtype=np.int64), + 'op': 'Const', + 'type': 'Const' + }, + 'split2_axis_const_data': { + 'value': np.array(4, dtype=np.int64), + 'shape': np.array(4, dtype=np.int64).shape, + 'kind': 'data' + }, + 'split1_data_0': {'value': None, 'shape': int64_array([1, 13, 13, 3, 1]), 'kind': 'data'}, + 'split1_data_1': {'value': None, 'shape': int64_array([1, 13, 13, 3, 1]), 'kind': 'data'}, + 'split2_data_0': {'value': None, 'shape': int64_array([1, 13, 13, 3, 1]), 'kind': 'data'}, + 'split2_data_1': {'value': None, 'shape': int64_array([1, 13, 13, 3, 1]), 'kind': 'data'}, + 'concat': {'type': 'Concat', 'kind': 'op', 'axis': 4}, + 'concat_data': {'value': None, 'shape': int64_array([1, 13, 13, 3, 4]), 'kind': 'data'}, + 'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'}, + 'abs_data': {'value': None, 'shape': int64_array([1, 13, 13, 3, 4]), 'kind': 'data'}, + 'output': {'kind': 'op', 'op': 'Result'}, +} + + +graph_edges_when_there_are_two_splits_one_concat = [ + ('placeholder1', 'placeholder1_data'), + ('placeholder2', 'placeholder2_data'), + ('placeholder1_data', 'split1', {'in': 0}), + ('split1_axis_const', 'split1_axis_const_data'), + ('split1_axis_const_data', 'split1', {'in': 1}), + ('split1', 'split1_data_0', {'out': 0}), + ('split1', 'split1_data_1', {'out': 1}), + ('placeholder2_data', 'split2', {'in': 0}), + ('split2_axis_const', 'split2_axis_const_data'), + ('split2_axis_const_data', 'split2', {'in': 1}), + ('split2', 'split2_data_0', {'out': 0}), + ('split2', 'split2_data_1', {'out': 1}), + ('split1_data_0', 'concat', {'in': 0}), + ('split1_data_1', 'concat', {'in': 1}), + ('split2_data_0', 'concat', {'in': 2}), + ('split2_data_1', 'concat', {'in': 3}), + ('concat', 'concat_data'), + ('concat_data', 'abs'), + ('abs', 'abs_data'), + ('abs_data', 'output') +] + + class SplitConcatPairToInterpolateTest(unittest.TestCase): def test_spatial_2d_split_concat_1(self): graph = build_graph( @@ -601,3 +676,16 @@ class SplitConcatPairToInterpolateTest(unittest.TestCase): SplitConcatPairToInterpolate().find_and_replace_pattern(graph) (flag, resp) = compare_graphs(graph, ref_graph, 'output') self.assertTrue(flag, resp) + + def test_two_splits_one_concat(self): + graph = build_graph( + nodes_attrs=graph_node_attrs_when_there_are_two_splits_one_concat, + edges=graph_edges_when_there_are_two_splits_one_concat + ) + ref_graph = build_graph( + nodes_attrs=graph_node_attrs_when_there_are_two_splits_one_concat, + edges=graph_edges_when_there_are_two_splits_one_concat + ) + SplitConcatPairToInterpolate().find_and_replace_pattern(graph) + (flag, resp) = compare_graphs(graph, ref_graph, 'output') + self.assertTrue(flag, resp)