[MO] Support TensorFlow Grouped Conv2DBackpropInput (#11420)
* [MO] Support TensorFlow Grouped Conv2DBackpropInput Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Correct computation of group number for ConvBackpropInput operation Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Fix get_conv_backprop_groups function Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Add unit-tests for Deconvolution shape inference Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
parent
60521a92c9
commit
9dee25fa79
@ -1,9 +1,11 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import convert_deconv_tf_padding_to_str, int64_array
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import convert_deconv_tf_padding_to_str, int64_array, \
|
||||
dynamic_dimension
|
||||
from openvino.tools.mo.front.extractor import FrontExtractorOp
|
||||
from openvino.tools.mo.front.tf.extractors.utils import tf_data_format_spatial, tf_data_format_channel, tf_data_format_batch, \
|
||||
from openvino.tools.mo.front.tf.extractors.utils import tf_data_format_spatial, tf_data_format_channel, \
|
||||
tf_data_format_batch, \
|
||||
tf_int_list
|
||||
from openvino.tools.mo.ops.deconvolution import Deconvolution
|
||||
from openvino.tools.mo.ops.op import PermuteAttrs
|
||||
@ -17,6 +19,7 @@ class Conv2DBackpropInputFrontExtractor(FrontExtractorOp):
|
||||
def extract(cls, node):
|
||||
attrs = tf_create_attrs(node, 3, 2)
|
||||
attrs.update({'op': cls.op,
|
||||
'get_group': get_conv_backprop_groups,
|
||||
'get_weights_permute': PermuteAttrs.Permutation(perm=int64_array([3, 2, 0, 1]),
|
||||
inv=int64_array([2, 3, 1, 0])),
|
||||
'swap_0_and_2_inputs': True,
|
||||
@ -36,6 +39,7 @@ class Conv3DBackpropInputV2InputFrontExtractor(FrontExtractorOp):
|
||||
def extract(cls, node):
|
||||
attrs = tf_create_attrs(node, 4, 3)
|
||||
attrs.update({'op': cls.op,
|
||||
'get_group': get_conv_backprop_groups,
|
||||
'get_weights_permute': PermuteAttrs.Permutation(perm=int64_array([4, 3, 0, 1, 2]),
|
||||
inv=int64_array([2, 3, 4, 1, 0])),
|
||||
'swap_0_and_2_inputs': True,
|
||||
@ -69,3 +73,20 @@ def tf_create_attrs(node, input_feature_channel, output_feature_channel):
|
||||
'input_feature_channel': input_feature_channel,
|
||||
'output_feature_channel': output_feature_channel,
|
||||
}
|
||||
|
||||
|
||||
def get_conv_backprop_groups(node):
|
||||
# output shape is required input for TensorFlow ConvBackpropInput operation and contains output shape values
|
||||
# in the form [batch_size, output_height, output_width, output_channel], so that
|
||||
# groups number = output_channel // kernel_out_channels, where
|
||||
# kernel shape is given as [kernel_height, kernel_width, kernel_out_channels, in_channels]
|
||||
output_shape = node.in_port(2).data.get_value()
|
||||
kernel_shape = node.in_port(1).data.get_shape()
|
||||
if node.has_and_set('group'):
|
||||
return node.group
|
||||
elif output_shape is not None and kernel_shape is not None \
|
||||
and output_shape[node.channel_dims[0]] is not dynamic_dimension \
|
||||
and kernel_shape[node.output_feature_channel] is not dynamic_dimension:
|
||||
return output_shape[node.channel_dims] // kernel_shape[node.output_feature_channel]
|
||||
else:
|
||||
return 1
|
||||
|
@ -60,6 +60,9 @@ class Deconvolution(Op):
|
||||
if not node.has_valid('dilation'):
|
||||
node['dilation'] = np.full([len(output_shape)], 1, dtype=np.int64)
|
||||
|
||||
if node.has_valid('get_group'):
|
||||
node['group'] = node.get_group(node)
|
||||
|
||||
spatial_dims = node.spatial_dims
|
||||
output_spatial = shape_array(output_shape[spatial_dims])
|
||||
stride_spatial = shape_array(node.stride[spatial_dims])
|
||||
|
90
tools/mo/unit_tests/mo/ops/deconvolution_test.py
Normal file
90
tools/mo/unit_tests/mo/ops/deconvolution_test.py
Normal file
@ -0,0 +1,90 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
||||
from openvino.tools.mo.front.tf.deconv_ext import get_conv_backprop_groups
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from openvino.tools.mo.ops.deconvolution import Deconvolution
|
||||
from unit_tests.utils.graph import build_graph
|
||||
|
||||
nodes_attributes = {'deconv_input': {'value': None, 'kind': 'data'},
|
||||
'deconv_weights': {'value': None, 'kind': 'data'},
|
||||
'deconv_output_shape': {'value': None, 'kind': 'data'},
|
||||
'deconv_node': {'type': 'Deconvolution', 'op': 'Deconvolution', 'kind': 'op'},
|
||||
'deconv_output': {'value': None, 'kind': 'data'},
|
||||
'op_output': {'kind': 'op', 'op': 'Result'}
|
||||
}
|
||||
|
||||
|
||||
def create_deconv_graph(input_shape: int64_array, weights_shape: int64_array, output_shape: int64_array):
|
||||
graph = build_graph(nodes_attributes,
|
||||
[('deconv_input', 'deconv_node'),
|
||||
('deconv_weights', 'deconv_node'),
|
||||
('deconv_output_shape', 'deconv_node'),
|
||||
('deconv_node', 'deconv_output'),
|
||||
('deconv_output', 'op_output')
|
||||
],
|
||||
{'deconv_input': {'shape': input_shape},
|
||||
'deconv_weights': {'shape': weights_shape,
|
||||
'dim_attrs': ['spatial_dims', 'channel_dims', 'batch_dims', 'axis']},
|
||||
'deconv_output_shape': {'value': output_shape},
|
||||
'deconv_node': {'channel_dims': int64_array([1]),
|
||||
'batch_dims': int64_array([0]),
|
||||
'spatial_dims': int64_array([2, 3]),
|
||||
'pad_spatial_shape': int64_array([[0, 0], [0, 0]]),
|
||||
'kernel_spatial': int64_array([4, 4]),
|
||||
'kernel_spatial_idx': int64_array([2, 3]),
|
||||
'input_feature_channel': 0,
|
||||
'output_feature_channel': 1,
|
||||
'auto_pad': 'same_lower',
|
||||
'output_padding': int64_array([0, 0, 1, 1]),
|
||||
'type': 'Deconvolution',
|
||||
'dilation': int64_array([1, 1, 1, 1]),
|
||||
'stride': int64_array([1, 1, 2, 2]),
|
||||
'pad': None,
|
||||
'output': None,
|
||||
'output_shape': None,
|
||||
'get_group': get_conv_backprop_groups},
|
||||
'deconv_output': {'shape': None},
|
||||
})
|
||||
return graph
|
||||
|
||||
|
||||
class TestConvolutionPartialInfer(unittest.TestCase):
|
||||
def test_deconv_infer_one_group(self):
|
||||
graph = create_deconv_graph(int64_array([1, 21, 18, 18]), int64_array([21, 50, 4, 4]),
|
||||
int64_array([1, 50, 35, 35]))
|
||||
|
||||
Deconvolution.infer(Node(graph, 'deconv_node'))
|
||||
res_shape = graph.node['deconv_output']['shape']
|
||||
exp_shape = np.array([1, 50, 35, 35])
|
||||
|
||||
res_group = graph.node['deconv_node']['group']
|
||||
exp_group = int64_array([1])
|
||||
|
||||
self.assertTrue(np.array_equal(exp_shape, res_shape),
|
||||
'values do not match expected: {} and computed: {}'.format(exp_shape, res_shape))
|
||||
|
||||
self.assertTrue(np.array_equal(exp_group, res_group),
|
||||
'group number values do not match expected: {} and computed: {}'.format(exp_group, res_group))
|
||||
|
||||
def test_deconv_infer_several_groups(self):
|
||||
graph = create_deconv_graph(int64_array([1, 21, 18, 18]), int64_array([21, 50, 4, 4]),
|
||||
int64_array([1, 350, 35, 35]))
|
||||
|
||||
Deconvolution.infer(Node(graph, 'deconv_node'))
|
||||
res_shape = graph.node['deconv_output']['shape']
|
||||
exp_shape = np.array([1, 350, 35, 35])
|
||||
|
||||
res_group = graph.node['deconv_node']['group']
|
||||
exp_group = int64_array([7])
|
||||
|
||||
self.assertTrue(np.array_equal(exp_shape, res_shape),
|
||||
'values do not match expected: {} and computed: {}'.format(exp_shape, res_shape))
|
||||
|
||||
self.assertTrue(np.array_equal(exp_group, res_group),
|
||||
'group number values do not match expected: {} and computed: {}'.format(exp_group, res_group))
|
Loading…
Reference in New Issue
Block a user