From 4d9fe14ec62641e7ab480d08cd70be0151aecd62 Mon Sep 17 00:00:00 2001 From: Yegor Kruglov Date: Mon, 7 Jun 2021 20:22:26 +0300 Subject: [PATCH] [MO] ConvolutionWithGroupResolver update to enable TF DepthwiseConv2dNative with in_channels=1 (#5528) * changed permutation attribute in conv extractor * changed conv get_group parameter * implemented a transformation * updated BOM * specified transformation for in_channels 1 * added unittest and comment string * updated convolution normalizer to convert depthwise convolution with group=1 to group convolution * renamed function * updated IR reader * conversations resolving * condition change --- .../extensions/back/ConvolutionNormalizer.py | 106 ++++++++---------- .../mo/utils/ir_reader/layer_to_class.py | 7 ++ .../back/ConvolutionNormalizer_test.py | 72 +++++++++++- 3 files changed, 123 insertions(+), 62 deletions(-) diff --git a/model-optimizer/extensions/back/ConvolutionNormalizer.py b/model-optimizer/extensions/back/ConvolutionNormalizer.py index 04da65631c5..0abfbec0e20 100644 --- a/model-optimizer/extensions/back/ConvolutionNormalizer.py +++ b/model-optimizer/extensions/back/ConvolutionNormalizer.py @@ -8,10 +8,43 @@ from extensions.back.ReverseInputChannels import ApplyReverseChannels from mo.back.replacement import BackReplacementPattern from mo.front.common.partial_infer.utils import int64_array from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs -from mo.graph.graph import Graph +from mo.graph.graph import Graph, Node from mo.ops.const import Const from mo.ops.reshape import Reshape from mo.ops.strided_slice import StridedSlice +from mo.utils.error import Error + + +def resolve_convolution_with_group(node: Node, group: int, ir_version: str): + input_shape = node.in_port(0).data.get_shape() + assert len(input_shape) in [3, 4, 5] + + weights_shape = node.in_port(1).data.get_shape() + assert weights_shape is not None + assert len(weights_shape) in [3, 4, 5] + assert weights_shape[0] % group == 0 + + assert int64_array(node.output).ndim == 0 + if ir_version == 'V7': + if weights_shape[0] == node.output: + # weights are already is in [G*O I X Y] format + return + new_shape = int64_array([node.output, -1, *weights_shape[2:]]) + + elif ir_version == 'V10': + I = input_shape[1] + new_shape = int64_array([group, node.output / group, I / group, *weights_shape[2:]]) + assert np.prod(weights_shape) == np.prod(new_shape), \ + 'Initial weights shape {}, grouped weights shape {}'.format(weights_shape, new_shape) + del node['group'] + node['type'] = 'GroupConvolution' + else: + raise Error("Unknown IR version: {}".format(ir_version)) + + reshape = create_op_node_with_second_input(node.graph, Reshape, int64_array(new_shape), + {'override_output_shape': True}) + + node.in_port(1).get_connection().insert_node(reshape) class ConvolutionNormalizer(BackReplacementPattern): @@ -37,33 +70,12 @@ class V7ConvolutionWithGroupsResolver(BackReplacementPattern): """ enabled = False - @staticmethod - def pattern(): - return dict( - nodes=[ - ('node', dict(type='Convolution', group=lambda g: g is not None and g != 1)) - ], - edges=[] - ) - - def replace_pattern(self, graph: Graph, match: dict): - node = match['node'] - - group = node.group - assert group > 1 - - weights_shape = node.in_port(1).data.get_shape() - assert weights_shape is not None - assert weights_shape[0] % group == 0 - - if weights_shape[0] == node.output: - # weights are already is in [G*O I X Y] format - return - - new_shape = int64_array([node.output, -1, *weights_shape[2:]]) - reshape = create_op_node_with_second_input(graph, Reshape, int64_array(new_shape), - {'override_output_shape': True}) - node.in_port(1).get_connection().insert_node(reshape) + def find_and_replace_pattern(self, graph: Graph): + for node in graph.get_op_nodes(type='Convolution'): + group = node.soft_get('group', None) + if group is not None: + if group != 1 or node.soft_get('op') == 'DepthwiseConv2dNative': + resolve_convolution_with_group(node, group, ir_version='V7') class V10ConvolutionWithGroupsResolver(BackReplacementPattern): @@ -73,38 +85,12 @@ class V10ConvolutionWithGroupsResolver(BackReplacementPattern): """ enabled = False - @staticmethod - def pattern(): - return dict( - nodes=[ - ('node', dict(type='Convolution', group=lambda g: g is not None and g != 1)) - ], - edges=[] - ) - - def replace_pattern(self, graph: Graph, match: dict): - node = match['node'] - - group = node.group - assert group > 1 - - weights_shape = node.in_port(1).data.get_shape() - assert weights_shape is not None - assert weights_shape[0] % group == 0 - I = node.in_port(0).data.get_shape()[1] - - new_shape = int64_array([group, node.output / group, I / group, *weights_shape[2:]]) - - assert np.prod(weights_shape) == np.prod(new_shape), \ - 'Initial weights shape {}, grouped weights shape {}'.format(weights_shape, new_shape) - - del node['group'] - node['type'] = 'GroupConvolution' - - reshape = create_op_node_with_second_input(graph, Reshape, int64_array(new_shape), - {'override_output_shape': True}) - - node.in_port(1).get_connection().insert_node(reshape) + def find_and_replace_pattern(self, graph: Graph): + for node in graph.get_op_nodes(type='Convolution'): + group = node.soft_get('group', None) + if group is not None: + if group != 1 or node.soft_get('op') == 'DepthwiseConv2dNative': + resolve_convolution_with_group(node, group, ir_version='V10') class ConvolutionWithGroupsResolver(BackReplacementPattern): diff --git a/model-optimizer/mo/utils/ir_reader/layer_to_class.py b/model-optimizer/mo/utils/ir_reader/layer_to_class.py index d1afed6918d..f9f5fd9fb1d 100644 --- a/model-optimizer/mo/utils/ir_reader/layer_to_class.py +++ b/model-optimizer/mo/utils/ir_reader/layer_to_class.py @@ -195,6 +195,13 @@ def groupconv_to_conv(op: Node): 'Weight shape and calculated shape mismatch in GroupConv node {}.'.format(op.name) # we need to set this attrs for correct shape infer as convolution op['group'] = group + # The only way GroupConvolution with 'group' = 1 appears in IR is by converting from TF DepthwiseConv2dNative. + # In this case we need to specify 'op' parameter for the + # extensions.back.ConvolutionNormalizer.ConvolutionWithGroupsResolver to work properly. + # Otherwise there will be 'Convolution' instead 'GroupConvolution' in restored IR, since 'GroupConvolution' is + # extended as node with 'type' = 'Convolution' by IR reader + if group == 1: + op['op'] = 'DepthwiseConv2dNative' op.type = 'Convolution' diff --git a/model-optimizer/unit_tests/extensions/back/ConvolutionNormalizer_test.py b/model-optimizer/unit_tests/extensions/back/ConvolutionNormalizer_test.py index d7a2f4809fc..47a7fc9ac28 100644 --- a/model-optimizer/unit_tests/extensions/back/ConvolutionNormalizer_test.py +++ b/model-optimizer/unit_tests/extensions/back/ConvolutionNormalizer_test.py @@ -102,7 +102,7 @@ class TestPullReshapeThroughFQ(unittest.TestCase): class TestV7ConvolutionWithGroupsResolver(unittest.TestCase): def test_v7_group_convolution_resolver(self): nodes = { - **regular_op_with_shaped_data('input', None, {'type': 'Parameter'}), + **regular_op_with_shaped_data('input', [1, 3, 224, 224], {'type': 'Parameter'}), **valued_const_with_data('weights', np.ones([3, 8, 7, 7])), @@ -133,7 +133,7 @@ class TestV7ConvolutionWithGroupsResolver(unittest.TestCase): def test_v7_group_convolution_resolver_weight_are_in_the_right_layout(self): nodes = { - **regular_op_with_shaped_data('input', None, {'type': 'Parameter'}), + **regular_op_with_shaped_data('input', [1, 3, 224, 224], {'type': 'Parameter'}), **valued_const_with_data('weights', np.ones([24, 1, 7, 7])), **regular_op_with_shaped_data('convolution', None, {'type': 'Convolution', 'group': 3, 'output': 24}), **result(), @@ -149,6 +149,38 @@ class TestV7ConvolutionWithGroupsResolver(unittest.TestCase): (flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True) self.assertTrue(flag, resp) + def test_v7_group_convolution_resolver_depthwise_conv2d(self): + nodes = { + **regular_op_with_shaped_data('input', [1, 1, 224, 224], {'type': 'Parameter'}), + + **valued_const_with_data('weights', np.ones([1, 8, 7, 7])), + + **valued_const_with_data('dim', int64_array([8, -1, 7, 7])), + **regular_op_with_empty_data('reshape', {'type': 'Reshape'}), + + **regular_op_with_shaped_data('convolution', None, {'type': 'Convolution', 'group': 1, 'output': 8, + 'op': 'DepthwiseConv2dNative'}), + + **result(), + } + graph = build_graph(nodes, [ + *connect('input', '0:convolution'), + *connect('weights', '1:convolution'), + *connect('convolution', 'output'), + ], nodes_with_edges_only=True) + + V7ConvolutionWithGroupsResolver().find_and_replace_pattern(graph) + graph_ref = build_graph(nodes, [ + *connect('input', '0:convolution'), + *connect('weights', '0:reshape'), + *connect('dim', '1:reshape'), + *connect('reshape', '1:convolution'), + *connect('convolution', 'output'), + ], nodes_with_edges_only=True) + + (flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True) + self.assertTrue(flag, resp) + class TestV10ConvolutionWithGroupsResolver(unittest.TestCase): def test_v10_group_convolution_resolver(self): @@ -185,3 +217,39 @@ class TestV10ConvolutionWithGroupsResolver(unittest.TestCase): (flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True) self.assertTrue(flag, resp) + + def test_v10_group_convolution_resolver_depthwise_conv2d(self): + nodes = { + **regular_op_with_shaped_data('input', [1, 1, 224, 224], {'type': 'Parameter'}), + + **valued_const_with_data('weights', np.ones([1, 8, 7, 7])), + + **valued_const_with_data('dim', int64_array([1, 8, 1, 7, 7])), + **regular_op_with_empty_data('reshape', {'type': 'Reshape'}), + + **regular_op_with_shaped_data('convolution', None, {'type': 'Convolution', 'group': 1, 'output': 8, + 'op': 'DepthwiseConv2dNative'}), + + **result(), + } + graph = build_graph(nodes, [ + *connect('input', '0:convolution'), + *connect('weights', '1:convolution'), + *connect('convolution', 'output'), + ], nodes_with_edges_only=True) + + V10ConvolutionWithGroupsResolver().find_and_replace_pattern(graph) + + nodes['convolution']['type'] = 'GroupConvolution' + del nodes['convolution']['group'] + + graph_ref = build_graph(nodes, [ + *connect('input', '0:convolution'), + *connect('weights', '0:reshape'), + *connect('dim', '1:reshape'), + *connect('reshape', '1:convolution'), + *connect('convolution', 'output'), + ], nodes_with_edges_only=True) + + (flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True) + self.assertTrue(flag, resp)