[MO] ConvolutionWithGroupResolver update to enable TF DepthwiseConv2dNative with in_channels=1 (#5528)
* changed permutation attribute in conv extractor * changed conv get_group parameter * implemented a transformation * updated BOM * specified transformation for in_channels 1 * added unittest and comment string * updated convolution normalizer to convert depthwise convolution with group=1 to group convolution * renamed function * updated IR reader * conversations resolving * condition change
This commit is contained in:
parent
6a42b47c2f
commit
4d9fe14ec6
@ -8,10 +8,43 @@ from extensions.back.ReverseInputChannels import ApplyReverseChannels
|
|||||||
from mo.back.replacement import BackReplacementPattern
|
from mo.back.replacement import BackReplacementPattern
|
||||||
from mo.front.common.partial_infer.utils import int64_array
|
from mo.front.common.partial_infer.utils import int64_array
|
||||||
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
|
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
|
||||||
from mo.graph.graph import Graph
|
from mo.graph.graph import Graph, Node
|
||||||
from mo.ops.const import Const
|
from mo.ops.const import Const
|
||||||
from mo.ops.reshape import Reshape
|
from mo.ops.reshape import Reshape
|
||||||
from mo.ops.strided_slice import StridedSlice
|
from mo.ops.strided_slice import StridedSlice
|
||||||
|
from mo.utils.error import Error
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_convolution_with_group(node: Node, group: int, ir_version: str):
|
||||||
|
input_shape = node.in_port(0).data.get_shape()
|
||||||
|
assert len(input_shape) in [3, 4, 5]
|
||||||
|
|
||||||
|
weights_shape = node.in_port(1).data.get_shape()
|
||||||
|
assert weights_shape is not None
|
||||||
|
assert len(weights_shape) in [3, 4, 5]
|
||||||
|
assert weights_shape[0] % group == 0
|
||||||
|
|
||||||
|
assert int64_array(node.output).ndim == 0
|
||||||
|
if ir_version == 'V7':
|
||||||
|
if weights_shape[0] == node.output:
|
||||||
|
# weights are already is in [G*O I X Y] format
|
||||||
|
return
|
||||||
|
new_shape = int64_array([node.output, -1, *weights_shape[2:]])
|
||||||
|
|
||||||
|
elif ir_version == 'V10':
|
||||||
|
I = input_shape[1]
|
||||||
|
new_shape = int64_array([group, node.output / group, I / group, *weights_shape[2:]])
|
||||||
|
assert np.prod(weights_shape) == np.prod(new_shape), \
|
||||||
|
'Initial weights shape {}, grouped weights shape {}'.format(weights_shape, new_shape)
|
||||||
|
del node['group']
|
||||||
|
node['type'] = 'GroupConvolution'
|
||||||
|
else:
|
||||||
|
raise Error("Unknown IR version: {}".format(ir_version))
|
||||||
|
|
||||||
|
reshape = create_op_node_with_second_input(node.graph, Reshape, int64_array(new_shape),
|
||||||
|
{'override_output_shape': True})
|
||||||
|
|
||||||
|
node.in_port(1).get_connection().insert_node(reshape)
|
||||||
|
|
||||||
|
|
||||||
class ConvolutionNormalizer(BackReplacementPattern):
|
class ConvolutionNormalizer(BackReplacementPattern):
|
||||||
@ -37,33 +70,12 @@ class V7ConvolutionWithGroupsResolver(BackReplacementPattern):
|
|||||||
"""
|
"""
|
||||||
enabled = False
|
enabled = False
|
||||||
|
|
||||||
@staticmethod
|
def find_and_replace_pattern(self, graph: Graph):
|
||||||
def pattern():
|
for node in graph.get_op_nodes(type='Convolution'):
|
||||||
return dict(
|
group = node.soft_get('group', None)
|
||||||
nodes=[
|
if group is not None:
|
||||||
('node', dict(type='Convolution', group=lambda g: g is not None and g != 1))
|
if group != 1 or node.soft_get('op') == 'DepthwiseConv2dNative':
|
||||||
],
|
resolve_convolution_with_group(node, group, ir_version='V7')
|
||||||
edges=[]
|
|
||||||
)
|
|
||||||
|
|
||||||
def replace_pattern(self, graph: Graph, match: dict):
|
|
||||||
node = match['node']
|
|
||||||
|
|
||||||
group = node.group
|
|
||||||
assert group > 1
|
|
||||||
|
|
||||||
weights_shape = node.in_port(1).data.get_shape()
|
|
||||||
assert weights_shape is not None
|
|
||||||
assert weights_shape[0] % group == 0
|
|
||||||
|
|
||||||
if weights_shape[0] == node.output:
|
|
||||||
# weights are already is in [G*O I X Y] format
|
|
||||||
return
|
|
||||||
|
|
||||||
new_shape = int64_array([node.output, -1, *weights_shape[2:]])
|
|
||||||
reshape = create_op_node_with_second_input(graph, Reshape, int64_array(new_shape),
|
|
||||||
{'override_output_shape': True})
|
|
||||||
node.in_port(1).get_connection().insert_node(reshape)
|
|
||||||
|
|
||||||
|
|
||||||
class V10ConvolutionWithGroupsResolver(BackReplacementPattern):
|
class V10ConvolutionWithGroupsResolver(BackReplacementPattern):
|
||||||
@ -73,38 +85,12 @@ class V10ConvolutionWithGroupsResolver(BackReplacementPattern):
|
|||||||
"""
|
"""
|
||||||
enabled = False
|
enabled = False
|
||||||
|
|
||||||
@staticmethod
|
def find_and_replace_pattern(self, graph: Graph):
|
||||||
def pattern():
|
for node in graph.get_op_nodes(type='Convolution'):
|
||||||
return dict(
|
group = node.soft_get('group', None)
|
||||||
nodes=[
|
if group is not None:
|
||||||
('node', dict(type='Convolution', group=lambda g: g is not None and g != 1))
|
if group != 1 or node.soft_get('op') == 'DepthwiseConv2dNative':
|
||||||
],
|
resolve_convolution_with_group(node, group, ir_version='V10')
|
||||||
edges=[]
|
|
||||||
)
|
|
||||||
|
|
||||||
def replace_pattern(self, graph: Graph, match: dict):
|
|
||||||
node = match['node']
|
|
||||||
|
|
||||||
group = node.group
|
|
||||||
assert group > 1
|
|
||||||
|
|
||||||
weights_shape = node.in_port(1).data.get_shape()
|
|
||||||
assert weights_shape is not None
|
|
||||||
assert weights_shape[0] % group == 0
|
|
||||||
I = node.in_port(0).data.get_shape()[1]
|
|
||||||
|
|
||||||
new_shape = int64_array([group, node.output / group, I / group, *weights_shape[2:]])
|
|
||||||
|
|
||||||
assert np.prod(weights_shape) == np.prod(new_shape), \
|
|
||||||
'Initial weights shape {}, grouped weights shape {}'.format(weights_shape, new_shape)
|
|
||||||
|
|
||||||
del node['group']
|
|
||||||
node['type'] = 'GroupConvolution'
|
|
||||||
|
|
||||||
reshape = create_op_node_with_second_input(graph, Reshape, int64_array(new_shape),
|
|
||||||
{'override_output_shape': True})
|
|
||||||
|
|
||||||
node.in_port(1).get_connection().insert_node(reshape)
|
|
||||||
|
|
||||||
|
|
||||||
class ConvolutionWithGroupsResolver(BackReplacementPattern):
|
class ConvolutionWithGroupsResolver(BackReplacementPattern):
|
||||||
|
@ -195,6 +195,13 @@ def groupconv_to_conv(op: Node):
|
|||||||
'Weight shape and calculated shape mismatch in GroupConv node {}.'.format(op.name)
|
'Weight shape and calculated shape mismatch in GroupConv node {}.'.format(op.name)
|
||||||
# we need to set this attrs for correct shape infer as convolution
|
# we need to set this attrs for correct shape infer as convolution
|
||||||
op['group'] = group
|
op['group'] = group
|
||||||
|
# The only way GroupConvolution with 'group' = 1 appears in IR is by converting from TF DepthwiseConv2dNative.
|
||||||
|
# In this case we need to specify 'op' parameter for the
|
||||||
|
# extensions.back.ConvolutionNormalizer.ConvolutionWithGroupsResolver to work properly.
|
||||||
|
# Otherwise there will be 'Convolution' instead 'GroupConvolution' in restored IR, since 'GroupConvolution' is
|
||||||
|
# extended as node with 'type' = 'Convolution' by IR reader
|
||||||
|
if group == 1:
|
||||||
|
op['op'] = 'DepthwiseConv2dNative'
|
||||||
op.type = 'Convolution'
|
op.type = 'Convolution'
|
||||||
|
|
||||||
|
|
||||||
|
@ -102,7 +102,7 @@ class TestPullReshapeThroughFQ(unittest.TestCase):
|
|||||||
class TestV7ConvolutionWithGroupsResolver(unittest.TestCase):
|
class TestV7ConvolutionWithGroupsResolver(unittest.TestCase):
|
||||||
def test_v7_group_convolution_resolver(self):
|
def test_v7_group_convolution_resolver(self):
|
||||||
nodes = {
|
nodes = {
|
||||||
**regular_op_with_shaped_data('input', None, {'type': 'Parameter'}),
|
**regular_op_with_shaped_data('input', [1, 3, 224, 224], {'type': 'Parameter'}),
|
||||||
|
|
||||||
**valued_const_with_data('weights', np.ones([3, 8, 7, 7])),
|
**valued_const_with_data('weights', np.ones([3, 8, 7, 7])),
|
||||||
|
|
||||||
@ -133,7 +133,7 @@ class TestV7ConvolutionWithGroupsResolver(unittest.TestCase):
|
|||||||
|
|
||||||
def test_v7_group_convolution_resolver_weight_are_in_the_right_layout(self):
|
def test_v7_group_convolution_resolver_weight_are_in_the_right_layout(self):
|
||||||
nodes = {
|
nodes = {
|
||||||
**regular_op_with_shaped_data('input', None, {'type': 'Parameter'}),
|
**regular_op_with_shaped_data('input', [1, 3, 224, 224], {'type': 'Parameter'}),
|
||||||
**valued_const_with_data('weights', np.ones([24, 1, 7, 7])),
|
**valued_const_with_data('weights', np.ones([24, 1, 7, 7])),
|
||||||
**regular_op_with_shaped_data('convolution', None, {'type': 'Convolution', 'group': 3, 'output': 24}),
|
**regular_op_with_shaped_data('convolution', None, {'type': 'Convolution', 'group': 3, 'output': 24}),
|
||||||
**result(),
|
**result(),
|
||||||
@ -149,6 +149,38 @@ class TestV7ConvolutionWithGroupsResolver(unittest.TestCase):
|
|||||||
(flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True)
|
(flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True)
|
||||||
self.assertTrue(flag, resp)
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
|
def test_v7_group_convolution_resolver_depthwise_conv2d(self):
|
||||||
|
nodes = {
|
||||||
|
**regular_op_with_shaped_data('input', [1, 1, 224, 224], {'type': 'Parameter'}),
|
||||||
|
|
||||||
|
**valued_const_with_data('weights', np.ones([1, 8, 7, 7])),
|
||||||
|
|
||||||
|
**valued_const_with_data('dim', int64_array([8, -1, 7, 7])),
|
||||||
|
**regular_op_with_empty_data('reshape', {'type': 'Reshape'}),
|
||||||
|
|
||||||
|
**regular_op_with_shaped_data('convolution', None, {'type': 'Convolution', 'group': 1, 'output': 8,
|
||||||
|
'op': 'DepthwiseConv2dNative'}),
|
||||||
|
|
||||||
|
**result(),
|
||||||
|
}
|
||||||
|
graph = build_graph(nodes, [
|
||||||
|
*connect('input', '0:convolution'),
|
||||||
|
*connect('weights', '1:convolution'),
|
||||||
|
*connect('convolution', 'output'),
|
||||||
|
], nodes_with_edges_only=True)
|
||||||
|
|
||||||
|
V7ConvolutionWithGroupsResolver().find_and_replace_pattern(graph)
|
||||||
|
graph_ref = build_graph(nodes, [
|
||||||
|
*connect('input', '0:convolution'),
|
||||||
|
*connect('weights', '0:reshape'),
|
||||||
|
*connect('dim', '1:reshape'),
|
||||||
|
*connect('reshape', '1:convolution'),
|
||||||
|
*connect('convolution', 'output'),
|
||||||
|
], nodes_with_edges_only=True)
|
||||||
|
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True)
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
|
|
||||||
class TestV10ConvolutionWithGroupsResolver(unittest.TestCase):
|
class TestV10ConvolutionWithGroupsResolver(unittest.TestCase):
|
||||||
def test_v10_group_convolution_resolver(self):
|
def test_v10_group_convolution_resolver(self):
|
||||||
@ -185,3 +217,39 @@ class TestV10ConvolutionWithGroupsResolver(unittest.TestCase):
|
|||||||
|
|
||||||
(flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True)
|
(flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True)
|
||||||
self.assertTrue(flag, resp)
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
|
def test_v10_group_convolution_resolver_depthwise_conv2d(self):
|
||||||
|
nodes = {
|
||||||
|
**regular_op_with_shaped_data('input', [1, 1, 224, 224], {'type': 'Parameter'}),
|
||||||
|
|
||||||
|
**valued_const_with_data('weights', np.ones([1, 8, 7, 7])),
|
||||||
|
|
||||||
|
**valued_const_with_data('dim', int64_array([1, 8, 1, 7, 7])),
|
||||||
|
**regular_op_with_empty_data('reshape', {'type': 'Reshape'}),
|
||||||
|
|
||||||
|
**regular_op_with_shaped_data('convolution', None, {'type': 'Convolution', 'group': 1, 'output': 8,
|
||||||
|
'op': 'DepthwiseConv2dNative'}),
|
||||||
|
|
||||||
|
**result(),
|
||||||
|
}
|
||||||
|
graph = build_graph(nodes, [
|
||||||
|
*connect('input', '0:convolution'),
|
||||||
|
*connect('weights', '1:convolution'),
|
||||||
|
*connect('convolution', 'output'),
|
||||||
|
], nodes_with_edges_only=True)
|
||||||
|
|
||||||
|
V10ConvolutionWithGroupsResolver().find_and_replace_pattern(graph)
|
||||||
|
|
||||||
|
nodes['convolution']['type'] = 'GroupConvolution'
|
||||||
|
del nodes['convolution']['group']
|
||||||
|
|
||||||
|
graph_ref = build_graph(nodes, [
|
||||||
|
*connect('input', '0:convolution'),
|
||||||
|
*connect('weights', '0:reshape'),
|
||||||
|
*connect('dim', '1:reshape'),
|
||||||
|
*connect('reshape', '1:convolution'),
|
||||||
|
*connect('convolution', 'output'),
|
||||||
|
], nodes_with_edges_only=True)
|
||||||
|
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True)
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
Loading…
Reference in New Issue
Block a user