[MO] Support TensorFlow Grouped Conv2DBackpropInput (#11420)

* [MO] Support TensorFlow Grouped Conv2DBackpropInput

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Correct computation of group number for ConvBackpropInput operation

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix get_conv_backprop_groups function

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Add unit-tests for Deconvolution shape inference

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2022-04-04 12:30:31 +03:00 committed by GitHub
parent 60521a92c9
commit 9dee25fa79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 116 additions and 2 deletions

View File

@ -1,9 +1,11 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.tools.mo.front.common.partial_infer.utils import convert_deconv_tf_padding_to_str, int64_array
from openvino.tools.mo.front.common.partial_infer.utils import convert_deconv_tf_padding_to_str, int64_array, \
dynamic_dimension
from openvino.tools.mo.front.extractor import FrontExtractorOp
from openvino.tools.mo.front.tf.extractors.utils import tf_data_format_spatial, tf_data_format_channel, tf_data_format_batch, \
from openvino.tools.mo.front.tf.extractors.utils import tf_data_format_spatial, tf_data_format_channel, \
tf_data_format_batch, \
tf_int_list
from openvino.tools.mo.ops.deconvolution import Deconvolution
from openvino.tools.mo.ops.op import PermuteAttrs
@ -17,6 +19,7 @@ class Conv2DBackpropInputFrontExtractor(FrontExtractorOp):
def extract(cls, node):
attrs = tf_create_attrs(node, 3, 2)
attrs.update({'op': cls.op,
'get_group': get_conv_backprop_groups,
'get_weights_permute': PermuteAttrs.Permutation(perm=int64_array([3, 2, 0, 1]),
inv=int64_array([2, 3, 1, 0])),
'swap_0_and_2_inputs': True,
@ -36,6 +39,7 @@ class Conv3DBackpropInputV2InputFrontExtractor(FrontExtractorOp):
def extract(cls, node):
attrs = tf_create_attrs(node, 4, 3)
attrs.update({'op': cls.op,
'get_group': get_conv_backprop_groups,
'get_weights_permute': PermuteAttrs.Permutation(perm=int64_array([4, 3, 0, 1, 2]),
inv=int64_array([2, 3, 4, 1, 0])),
'swap_0_and_2_inputs': True,
@ -69,3 +73,20 @@ def tf_create_attrs(node, input_feature_channel, output_feature_channel):
'input_feature_channel': input_feature_channel,
'output_feature_channel': output_feature_channel,
}
def get_conv_backprop_groups(node):
    """Derive the number of groups for a TensorFlow ConvBackpropInput node.

    The requested output shape is a required input of the TensorFlow operation
    (read here as the value on input port 2) and is laid out as
    [batch_size, output_height, output_width, output_channel]. The kernel shape
    (read from input port 1) is laid out as
    [kernel_height, kernel_width, kernel_out_channels, in_channels]. Hence
    groups number = output_channel // kernel_out_channels.

    :param node: operation node; this function reads its input ports 1 and 2,
        ``has_and_set('group')``, ``group``, ``channel_dims`` and
        ``output_feature_channel``.
    :return: ``node.group`` when that attribute is already set; otherwise the
        floor-division result, indexed with the whole ``node.channel_dims`` (so
        it keeps the shape of that index instead of being a plain scalar);
        ``1`` when the output-shape value or the kernel shape is unavailable or
        when either channel entry is ``dynamic_dimension``.
    """
    requested_shape = node.in_port(2).data.get_value()
    weights_shape = node.in_port(1).data.get_shape()

    # An explicitly provided group attribute always wins over the computed one.
    if node.has_and_set('group'):
        return node.group

    # Without both the output-shape value and the kernel shape nothing can be derived.
    if requested_shape is None or weights_shape is None:
        return 1

    # A dynamic channel entry on either side makes the ratio undefined.
    if requested_shape[node.channel_dims[0]] is dynamic_dimension:
        return 1
    if weights_shape[node.output_feature_channel] is dynamic_dimension:
        return 1

    return requested_shape[node.channel_dims] // weights_shape[node.output_feature_channel]

View File

@ -60,6 +60,9 @@ class Deconvolution(Op):
if not node.has_valid('dilation'):
node['dilation'] = np.full([len(output_shape)], 1, dtype=np.int64)
if node.has_valid('get_group'):
node['group'] = node.get_group(node)
spatial_dims = node.spatial_dims
output_spatial = shape_array(output_shape[spatial_dims])
stride_spatial = shape_array(node.stride[spatial_dims])

View File

@ -0,0 +1,90 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
import numpy as np
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
from openvino.tools.mo.front.tf.deconv_ext import get_conv_backprop_groups
from openvino.tools.mo.graph.graph import Node
from openvino.tools.mo.ops.deconvolution import Deconvolution
from unit_tests.utils.graph import build_graph
# Base attributes of every node in the test graph: three data inputs feeding
# one Deconvolution op, its data output, and a terminating Result op.
# Per-test shapes and values are supplied separately when the graph is built.
nodes_attributes = {
    'deconv_input': {
        'value': None,
        'kind': 'data',
    },
    'deconv_weights': {
        'value': None,
        'kind': 'data',
    },
    'deconv_output_shape': {
        'value': None,
        'kind': 'data',
    },
    'deconv_node': {
        'type': 'Deconvolution',
        'op': 'Deconvolution',
        'kind': 'op',
    },
    'deconv_output': {
        'value': None,
        'kind': 'data',
    },
    'op_output': {
        'kind': 'op',
        'op': 'Result',
    },
}
def create_deconv_graph(input_shape: int64_array, weights_shape: int64_array, output_shape: int64_array):
    """Build a small test graph centred on a single ``Deconvolution`` node.

    The graph wires three data nodes (input, weights, requested output shape)
    into ``deconv_node``, whose result flows through ``deconv_output`` into
    ``op_output``.

    :param input_shape: stored as the ``shape`` of the ``deconv_input`` data node
    :param weights_shape: stored as the ``shape`` of the ``deconv_weights`` data node
    :param output_shape: stored as the ``value`` of the ``deconv_output_shape`` data node
    :return: the graph produced by ``build_graph``
    """
    edges = [
        ('deconv_input', 'deconv_node'),
        ('deconv_weights', 'deconv_node'),
        ('deconv_output_shape', 'deconv_node'),
        ('deconv_node', 'deconv_output'),
        ('deconv_output', 'op_output'),
    ]

    # Attributes of the operation itself: batch axis 0, channel axis 1 and
    # spatial axes 2 and 3, with the group count resolved via the callback.
    deconv_attrs = {
        'channel_dims': int64_array([1]),
        'batch_dims': int64_array([0]),
        'spatial_dims': int64_array([2, 3]),
        'pad_spatial_shape': int64_array([[0, 0], [0, 0]]),
        'kernel_spatial': int64_array([4, 4]),
        'kernel_spatial_idx': int64_array([2, 3]),
        'input_feature_channel': 0,
        'output_feature_channel': 1,
        'auto_pad': 'same_lower',
        'output_padding': int64_array([0, 0, 1, 1]),
        'type': 'Deconvolution',
        'dilation': int64_array([1, 1, 1, 1]),
        'stride': int64_array([1, 1, 2, 2]),
        'pad': None,
        'output': None,
        'output_shape': None,
        'get_group': get_conv_backprop_groups,
    }

    # NOTE(review): 'dim_attrs' is attached to the weights data node rather
    # than to the operation node - confirm this is intended.
    weights_attrs = {
        'shape': weights_shape,
        'dim_attrs': ['spatial_dims', 'channel_dims', 'batch_dims', 'axis'],
    }

    attr_overrides = {
        'deconv_input': {'shape': input_shape},
        'deconv_weights': weights_attrs,
        'deconv_output_shape': {'value': output_shape},
        'deconv_node': deconv_attrs,
        'deconv_output': {'shape': None},
    }

    return build_graph(nodes_attributes, edges, attr_overrides)
class TestConvolutionPartialInfer(unittest.TestCase):
    """Checks ``Deconvolution.infer`` on graphs whose node carries a ``get_group`` callback.

    Each case verifies both the shape written to ``deconv_output`` and the
    ``group`` attribute written to ``deconv_node``.
    """

    def _infer_and_compare(self, requested_output, exp_shape, exp_group):
        """Run inference for a fixed input/kernel pair and compare shape and group.

        The input shape is [1, 21, 18, 18] and the kernel shape is [21, 50, 4, 4];
        only the requested output shape varies between the test cases.
        """
        graph = create_deconv_graph(int64_array([1, 21, 18, 18]), int64_array([21, 50, 4, 4]),
                                    requested_output)
        Deconvolution.infer(Node(graph, 'deconv_node'))

        # Read both results before asserting so a missing attribute surfaces first.
        res_shape = graph.node['deconv_output']['shape']
        res_group = graph.node['deconv_node']['group']

        self.assertTrue(
            np.array_equal(exp_shape, res_shape),
            'values do not match expected: {} and computed: {}'.format(exp_shape, res_shape))
        self.assertTrue(
            np.array_equal(exp_group, res_group),
            'group number values do not match expected: {} and computed: {}'.format(exp_group, res_group))

    def test_deconv_infer_one_group(self):
        """50 requested output channels with 50 kernel output channels give one group."""
        self._infer_and_compare(int64_array([1, 50, 35, 35]),
                                np.array([1, 50, 35, 35]),
                                int64_array([1]))

    def test_deconv_infer_several_groups(self):
        """350 requested output channels with 50 kernel output channels give seven groups."""
        self._infer_and_compare(int64_array([1, 350, 35, 35]),
                                np.array([1, 350, 35, 35]),
                                int64_array([7]))