[MO] ConvolutionWithGroupResolver update to enable TF DepthwiseConv2dNative with in_channels=1 (#5528)

* changed permutation attribute in conv extractor

* changed conv get_group parameter

* implemented a transformation

* updated BOM

* specified transformation for in_channels 1

* added unittest and comment string

* updated convolution normalizer to convert depthwise convolution with group=1 to group convolution

* renamed function

* updated IR reader

* conversations resolving

* condition change
This commit is contained in:
Yegor Kruglov 2021-06-07 20:22:26 +03:00 committed by GitHub
parent 6a42b47c2f
commit 4d9fe14ec6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 123 additions and 62 deletions

View File

@ -8,10 +8,43 @@ from extensions.back.ReverseInputChannels import ApplyReverseChannels
from mo.back.replacement import BackReplacementPattern
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
from mo.graph.graph import Graph
from mo.graph.graph import Graph, Node
from mo.ops.const import Const
from mo.ops.reshape import Reshape
from mo.ops.strided_slice import StridedSlice
from mo.utils.error import Error
def resolve_convolution_with_group(node: Node, group: int, ir_version: str):
input_shape = node.in_port(0).data.get_shape()
assert len(input_shape) in [3, 4, 5]
weights_shape = node.in_port(1).data.get_shape()
assert weights_shape is not None
assert len(weights_shape) in [3, 4, 5]
assert weights_shape[0] % group == 0
assert int64_array(node.output).ndim == 0
if ir_version == 'V7':
if weights_shape[0] == node.output:
# weights are already is in [G*O I X Y] format
return
new_shape = int64_array([node.output, -1, *weights_shape[2:]])
elif ir_version == 'V10':
I = input_shape[1]
new_shape = int64_array([group, node.output / group, I / group, *weights_shape[2:]])
assert np.prod(weights_shape) == np.prod(new_shape), \
'Initial weights shape {}, grouped weights shape {}'.format(weights_shape, new_shape)
del node['group']
node['type'] = 'GroupConvolution'
else:
raise Error("Unknown IR version: {}".format(ir_version))
reshape = create_op_node_with_second_input(node.graph, Reshape, int64_array(new_shape),
{'override_output_shape': True})
node.in_port(1).get_connection().insert_node(reshape)
class ConvolutionNormalizer(BackReplacementPattern):
@ -37,33 +70,12 @@ class V7ConvolutionWithGroupsResolver(BackReplacementPattern):
"""
enabled = False
@staticmethod
def pattern():
return dict(
nodes=[
('node', dict(type='Convolution', group=lambda g: g is not None and g != 1))
],
edges=[]
)
def replace_pattern(self, graph: Graph, match: dict):
node = match['node']
group = node.group
assert group > 1
weights_shape = node.in_port(1).data.get_shape()
assert weights_shape is not None
assert weights_shape[0] % group == 0
if weights_shape[0] == node.output:
# weights are already is in [G*O I X Y] format
return
new_shape = int64_array([node.output, -1, *weights_shape[2:]])
reshape = create_op_node_with_second_input(graph, Reshape, int64_array(new_shape),
{'override_output_shape': True})
node.in_port(1).get_connection().insert_node(reshape)
def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(type='Convolution'):
group = node.soft_get('group', None)
if group is not None:
if group != 1 or node.soft_get('op') == 'DepthwiseConv2dNative':
resolve_convolution_with_group(node, group, ir_version='V7')
class V10ConvolutionWithGroupsResolver(BackReplacementPattern):
@ -73,38 +85,12 @@ class V10ConvolutionWithGroupsResolver(BackReplacementPattern):
"""
enabled = False
@staticmethod
def pattern():
return dict(
nodes=[
('node', dict(type='Convolution', group=lambda g: g is not None and g != 1))
],
edges=[]
)
def replace_pattern(self, graph: Graph, match: dict):
node = match['node']
group = node.group
assert group > 1
weights_shape = node.in_port(1).data.get_shape()
assert weights_shape is not None
assert weights_shape[0] % group == 0
I = node.in_port(0).data.get_shape()[1]
new_shape = int64_array([group, node.output / group, I / group, *weights_shape[2:]])
assert np.prod(weights_shape) == np.prod(new_shape), \
'Initial weights shape {}, grouped weights shape {}'.format(weights_shape, new_shape)
del node['group']
node['type'] = 'GroupConvolution'
reshape = create_op_node_with_second_input(graph, Reshape, int64_array(new_shape),
{'override_output_shape': True})
node.in_port(1).get_connection().insert_node(reshape)
def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(type='Convolution'):
group = node.soft_get('group', None)
if group is not None:
if group != 1 or node.soft_get('op') == 'DepthwiseConv2dNative':
resolve_convolution_with_group(node, group, ir_version='V10')
class ConvolutionWithGroupsResolver(BackReplacementPattern):

View File

@ -195,6 +195,13 @@ def groupconv_to_conv(op: Node):
'Weight shape and calculated shape mismatch in GroupConv node {}.'.format(op.name)
# we need to set this attrs for correct shape infer as convolution
op['group'] = group
# The only way GroupConvolution with 'group' = 1 appears in IR is by converting from TF DepthwiseConv2dNative.
# In this case we need to specify 'op' parameter for the
# extensions.back.ConvolutionNormalizer.ConvolutionWithGroupsResolver to work properly.
# Otherwise there will be 'Convolution' instead 'GroupConvolution' in restored IR, since 'GroupConvolution' is
# extended as node with 'type' = 'Convolution' by IR reader
if group == 1:
op['op'] = 'DepthwiseConv2dNative'
op.type = 'Convolution'

View File

@ -102,7 +102,7 @@ class TestPullReshapeThroughFQ(unittest.TestCase):
class TestV7ConvolutionWithGroupsResolver(unittest.TestCase):
def test_v7_group_convolution_resolver(self):
nodes = {
**regular_op_with_shaped_data('input', None, {'type': 'Parameter'}),
**regular_op_with_shaped_data('input', [1, 3, 224, 224], {'type': 'Parameter'}),
**valued_const_with_data('weights', np.ones([3, 8, 7, 7])),
@ -133,7 +133,7 @@ class TestV7ConvolutionWithGroupsResolver(unittest.TestCase):
def test_v7_group_convolution_resolver_weight_are_in_the_right_layout(self):
nodes = {
**regular_op_with_shaped_data('input', None, {'type': 'Parameter'}),
**regular_op_with_shaped_data('input', [1, 3, 224, 224], {'type': 'Parameter'}),
**valued_const_with_data('weights', np.ones([24, 1, 7, 7])),
**regular_op_with_shaped_data('convolution', None, {'type': 'Convolution', 'group': 3, 'output': 24}),
**result(),
@ -149,6 +149,38 @@ class TestV7ConvolutionWithGroupsResolver(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_v7_group_convolution_resolver_depthwise_conv2d(self):
nodes = {
**regular_op_with_shaped_data('input', [1, 1, 224, 224], {'type': 'Parameter'}),
**valued_const_with_data('weights', np.ones([1, 8, 7, 7])),
**valued_const_with_data('dim', int64_array([8, -1, 7, 7])),
**regular_op_with_empty_data('reshape', {'type': 'Reshape'}),
**regular_op_with_shaped_data('convolution', None, {'type': 'Convolution', 'group': 1, 'output': 8,
'op': 'DepthwiseConv2dNative'}),
**result(),
}
graph = build_graph(nodes, [
*connect('input', '0:convolution'),
*connect('weights', '1:convolution'),
*connect('convolution', 'output'),
], nodes_with_edges_only=True)
V7ConvolutionWithGroupsResolver().find_and_replace_pattern(graph)
graph_ref = build_graph(nodes, [
*connect('input', '0:convolution'),
*connect('weights', '0:reshape'),
*connect('dim', '1:reshape'),
*connect('reshape', '1:convolution'),
*connect('convolution', 'output'),
], nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True)
self.assertTrue(flag, resp)
class TestV10ConvolutionWithGroupsResolver(unittest.TestCase):
def test_v10_group_convolution_resolver(self):
@ -185,3 +217,39 @@ class TestV10ConvolutionWithGroupsResolver(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_v10_group_convolution_resolver_depthwise_conv2d(self):
nodes = {
**regular_op_with_shaped_data('input', [1, 1, 224, 224], {'type': 'Parameter'}),
**valued_const_with_data('weights', np.ones([1, 8, 7, 7])),
**valued_const_with_data('dim', int64_array([1, 8, 1, 7, 7])),
**regular_op_with_empty_data('reshape', {'type': 'Reshape'}),
**regular_op_with_shaped_data('convolution', None, {'type': 'Convolution', 'group': 1, 'output': 8,
'op': 'DepthwiseConv2dNative'}),
**result(),
}
graph = build_graph(nodes, [
*connect('input', '0:convolution'),
*connect('weights', '1:convolution'),
*connect('convolution', 'output'),
], nodes_with_edges_only=True)
V10ConvolutionWithGroupsResolver().find_and_replace_pattern(graph)
nodes['convolution']['type'] = 'GroupConvolution'
del nodes['convolution']['group']
graph_ref = build_graph(nodes, [
*connect('input', '0:convolution'),
*connect('weights', '0:reshape'),
*connect('dim', '1:reshape'),
*connect('reshape', '1:convolution'),
*connect('convolution', 'output'),
], nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True)
self.assertTrue(flag, resp)