[MO] ConvolutionWithGroupResolver update to enable TF DepthwiseConv2dNative with in_channels=1 (#5528)
* changed permutation attribute in conv extractor * changed conv get_group parameter * implemented a transformation * updated BOM * specified transformation for in_channels 1 * added unittest and comment string * updated convolution normalizer to convert depthwise convolution with group=1 to group convolution * renamed function * updated IR reader * conversations resolving * condition change
This commit is contained in:
parent
6a42b47c2f
commit
4d9fe14ec6
@ -8,10 +8,43 @@ from extensions.back.ReverseInputChannels import ApplyReverseChannels
|
||||
from mo.back.replacement import BackReplacementPattern
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.reshape import Reshape
|
||||
from mo.ops.strided_slice import StridedSlice
|
||||
from mo.utils.error import Error
|
||||
|
||||
|
||||
def resolve_convolution_with_group(node: Node, group: int, ir_version: str):
|
||||
input_shape = node.in_port(0).data.get_shape()
|
||||
assert len(input_shape) in [3, 4, 5]
|
||||
|
||||
weights_shape = node.in_port(1).data.get_shape()
|
||||
assert weights_shape is not None
|
||||
assert len(weights_shape) in [3, 4, 5]
|
||||
assert weights_shape[0] % group == 0
|
||||
|
||||
assert int64_array(node.output).ndim == 0
|
||||
if ir_version == 'V7':
|
||||
if weights_shape[0] == node.output:
|
||||
# weights are already is in [G*O I X Y] format
|
||||
return
|
||||
new_shape = int64_array([node.output, -1, *weights_shape[2:]])
|
||||
|
||||
elif ir_version == 'V10':
|
||||
I = input_shape[1]
|
||||
new_shape = int64_array([group, node.output / group, I / group, *weights_shape[2:]])
|
||||
assert np.prod(weights_shape) == np.prod(new_shape), \
|
||||
'Initial weights shape {}, grouped weights shape {}'.format(weights_shape, new_shape)
|
||||
del node['group']
|
||||
node['type'] = 'GroupConvolution'
|
||||
else:
|
||||
raise Error("Unknown IR version: {}".format(ir_version))
|
||||
|
||||
reshape = create_op_node_with_second_input(node.graph, Reshape, int64_array(new_shape),
|
||||
{'override_output_shape': True})
|
||||
|
||||
node.in_port(1).get_connection().insert_node(reshape)
|
||||
|
||||
|
||||
class ConvolutionNormalizer(BackReplacementPattern):
|
||||
@ -37,33 +70,12 @@ class V7ConvolutionWithGroupsResolver(BackReplacementPattern):
|
||||
"""
|
||||
enabled = False
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('node', dict(type='Convolution', group=lambda g: g is not None and g != 1))
|
||||
],
|
||||
edges=[]
|
||||
)
|
||||
|
||||
def replace_pattern(self, graph: Graph, match: dict):
|
||||
node = match['node']
|
||||
|
||||
group = node.group
|
||||
assert group > 1
|
||||
|
||||
weights_shape = node.in_port(1).data.get_shape()
|
||||
assert weights_shape is not None
|
||||
assert weights_shape[0] % group == 0
|
||||
|
||||
if weights_shape[0] == node.output:
|
||||
# weights are already is in [G*O I X Y] format
|
||||
return
|
||||
|
||||
new_shape = int64_array([node.output, -1, *weights_shape[2:]])
|
||||
reshape = create_op_node_with_second_input(graph, Reshape, int64_array(new_shape),
|
||||
{'override_output_shape': True})
|
||||
node.in_port(1).get_connection().insert_node(reshape)
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for node in graph.get_op_nodes(type='Convolution'):
|
||||
group = node.soft_get('group', None)
|
||||
if group is not None:
|
||||
if group != 1 or node.soft_get('op') == 'DepthwiseConv2dNative':
|
||||
resolve_convolution_with_group(node, group, ir_version='V7')
|
||||
|
||||
|
||||
class V10ConvolutionWithGroupsResolver(BackReplacementPattern):
|
||||
@ -73,38 +85,12 @@ class V10ConvolutionWithGroupsResolver(BackReplacementPattern):
|
||||
"""
|
||||
enabled = False
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('node', dict(type='Convolution', group=lambda g: g is not None and g != 1))
|
||||
],
|
||||
edges=[]
|
||||
)
|
||||
|
||||
def replace_pattern(self, graph: Graph, match: dict):
|
||||
node = match['node']
|
||||
|
||||
group = node.group
|
||||
assert group > 1
|
||||
|
||||
weights_shape = node.in_port(1).data.get_shape()
|
||||
assert weights_shape is not None
|
||||
assert weights_shape[0] % group == 0
|
||||
I = node.in_port(0).data.get_shape()[1]
|
||||
|
||||
new_shape = int64_array([group, node.output / group, I / group, *weights_shape[2:]])
|
||||
|
||||
assert np.prod(weights_shape) == np.prod(new_shape), \
|
||||
'Initial weights shape {}, grouped weights shape {}'.format(weights_shape, new_shape)
|
||||
|
||||
del node['group']
|
||||
node['type'] = 'GroupConvolution'
|
||||
|
||||
reshape = create_op_node_with_second_input(graph, Reshape, int64_array(new_shape),
|
||||
{'override_output_shape': True})
|
||||
|
||||
node.in_port(1).get_connection().insert_node(reshape)
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for node in graph.get_op_nodes(type='Convolution'):
|
||||
group = node.soft_get('group', None)
|
||||
if group is not None:
|
||||
if group != 1 or node.soft_get('op') == 'DepthwiseConv2dNative':
|
||||
resolve_convolution_with_group(node, group, ir_version='V10')
|
||||
|
||||
|
||||
class ConvolutionWithGroupsResolver(BackReplacementPattern):
|
||||
|
@ -195,6 +195,13 @@ def groupconv_to_conv(op: Node):
|
||||
'Weight shape and calculated shape mismatch in GroupConv node {}.'.format(op.name)
|
||||
# we need to set this attrs for correct shape infer as convolution
|
||||
op['group'] = group
|
||||
# The only way GroupConvolution with 'group' = 1 appears in IR is by converting from TF DepthwiseConv2dNative.
|
||||
# In this case we need to specify 'op' parameter for the
|
||||
# extensions.back.ConvolutionNormalizer.ConvolutionWithGroupsResolver to work properly.
|
||||
# Otherwise there will be 'Convolution' instead 'GroupConvolution' in restored IR, since 'GroupConvolution' is
|
||||
# extended as node with 'type' = 'Convolution' by IR reader
|
||||
if group == 1:
|
||||
op['op'] = 'DepthwiseConv2dNative'
|
||||
op.type = 'Convolution'
|
||||
|
||||
|
||||
|
@ -102,7 +102,7 @@ class TestPullReshapeThroughFQ(unittest.TestCase):
|
||||
class TestV7ConvolutionWithGroupsResolver(unittest.TestCase):
|
||||
def test_v7_group_convolution_resolver(self):
|
||||
nodes = {
|
||||
**regular_op_with_shaped_data('input', None, {'type': 'Parameter'}),
|
||||
**regular_op_with_shaped_data('input', [1, 3, 224, 224], {'type': 'Parameter'}),
|
||||
|
||||
**valued_const_with_data('weights', np.ones([3, 8, 7, 7])),
|
||||
|
||||
@ -133,7 +133,7 @@ class TestV7ConvolutionWithGroupsResolver(unittest.TestCase):
|
||||
|
||||
def test_v7_group_convolution_resolver_weight_are_in_the_right_layout(self):
|
||||
nodes = {
|
||||
**regular_op_with_shaped_data('input', None, {'type': 'Parameter'}),
|
||||
**regular_op_with_shaped_data('input', [1, 3, 224, 224], {'type': 'Parameter'}),
|
||||
**valued_const_with_data('weights', np.ones([24, 1, 7, 7])),
|
||||
**regular_op_with_shaped_data('convolution', None, {'type': 'Convolution', 'group': 3, 'output': 24}),
|
||||
**result(),
|
||||
@ -149,6 +149,38 @@ class TestV7ConvolutionWithGroupsResolver(unittest.TestCase):
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_v7_group_convolution_resolver_depthwise_conv2d(self):
|
||||
nodes = {
|
||||
**regular_op_with_shaped_data('input', [1, 1, 224, 224], {'type': 'Parameter'}),
|
||||
|
||||
**valued_const_with_data('weights', np.ones([1, 8, 7, 7])),
|
||||
|
||||
**valued_const_with_data('dim', int64_array([8, -1, 7, 7])),
|
||||
**regular_op_with_empty_data('reshape', {'type': 'Reshape'}),
|
||||
|
||||
**regular_op_with_shaped_data('convolution', None, {'type': 'Convolution', 'group': 1, 'output': 8,
|
||||
'op': 'DepthwiseConv2dNative'}),
|
||||
|
||||
**result(),
|
||||
}
|
||||
graph = build_graph(nodes, [
|
||||
*connect('input', '0:convolution'),
|
||||
*connect('weights', '1:convolution'),
|
||||
*connect('convolution', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
V7ConvolutionWithGroupsResolver().find_and_replace_pattern(graph)
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('input', '0:convolution'),
|
||||
*connect('weights', '0:reshape'),
|
||||
*connect('dim', '1:reshape'),
|
||||
*connect('reshape', '1:convolution'),
|
||||
*connect('convolution', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
|
||||
class TestV10ConvolutionWithGroupsResolver(unittest.TestCase):
|
||||
def test_v10_group_convolution_resolver(self):
|
||||
@ -185,3 +217,39 @@ class TestV10ConvolutionWithGroupsResolver(unittest.TestCase):
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_v10_group_convolution_resolver_depthwise_conv2d(self):
|
||||
nodes = {
|
||||
**regular_op_with_shaped_data('input', [1, 1, 224, 224], {'type': 'Parameter'}),
|
||||
|
||||
**valued_const_with_data('weights', np.ones([1, 8, 7, 7])),
|
||||
|
||||
**valued_const_with_data('dim', int64_array([1, 8, 1, 7, 7])),
|
||||
**regular_op_with_empty_data('reshape', {'type': 'Reshape'}),
|
||||
|
||||
**regular_op_with_shaped_data('convolution', None, {'type': 'Convolution', 'group': 1, 'output': 8,
|
||||
'op': 'DepthwiseConv2dNative'}),
|
||||
|
||||
**result(),
|
||||
}
|
||||
graph = build_graph(nodes, [
|
||||
*connect('input', '0:convolution'),
|
||||
*connect('weights', '1:convolution'),
|
||||
*connect('convolution', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
V10ConvolutionWithGroupsResolver().find_and_replace_pattern(graph)
|
||||
|
||||
nodes['convolution']['type'] = 'GroupConvolution'
|
||||
del nodes['convolution']['group']
|
||||
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('input', '0:convolution'),
|
||||
*connect('weights', '0:reshape'),
|
||||
*connect('dim', '1:reshape'),
|
||||
*connect('reshape', '1:convolution'),
|
||||
*connect('convolution', 'output'),
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
Loading…
Reference in New Issue
Block a user