[TF FE] Support Group Convolutions (#15130)

* [TF FE] Support Group Convolutions

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Split cases of regular Convolution and GroupConvolution operations

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2023-01-17 16:15:19 +04:00 committed by GitHub
parent 1d5fa360d4
commit 5043797b1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 90 additions and 25 deletions

View File

@ -4,6 +4,8 @@
#include "utils.hpp"
#include <limits>
#include "openvino/opsets/opset10.hpp"
#include "openvino/opsets/opset8.hpp"
#include "openvino_conversions.hpp"
@ -182,8 +184,56 @@ ov::OutputVector ov::frontend::tensorflow::translate_convolution_op(const ov::fr
ov::AxisVector permutation_3d = {4, 3, 0, 1, 2};
filter = ov::frontend::tensorflow::make_transpose(filter, spatial_dims_num == 2 ? permutation_2d : permutation_3d);
ov::Output<ov::Node> conv =
std::make_shared<Convolution>(input, filter, strides, pads_begin, pads_end, dilations, auto_pad);
bool input_channels_static = false;
int64_t num_groups = 1;
auto input_shape = input.get_partial_shape();
auto filter_shape = filter.get_partial_shape();
if (input_shape.rank().is_static() && filter_shape.rank().is_static()) {
auto input_rank = static_cast<size_t>(input_shape.rank().get_length());
auto filter_rank = static_cast<size_t>(filter_shape.rank().get_length());
TENSORFLOW_OP_VALIDATION(node, input_rank == (spatial_dims_num + 2), "Internal error: incorrect input rank.");
TENSORFLOW_OP_VALIDATION(node, filter_rank == input_rank, "Internal error: incorrect filter rank.");
auto input_channels_size = input_shape[1];
auto filter_channels_size = filter_shape[1];
if (input_channels_size.is_static() && filter_channels_size.is_static()) {
// we assume that input channel size will not be changed if they are already static
// this will simplify us to differentiate Convolution and GroupConvolution cases
num_groups = input_channels_size.get_length() / filter_channels_size.get_length();
TENSORFLOW_OP_VALIDATION(node,
num_groups >= 1,
"Internal error: number of groups for Convolutional operation is not positive.");
input_channels_static = true;
}
}
ov::Output<ov::Node> conv;
if (input_channels_static && num_groups == 1) {
// regular convolutional operation
// we assume that input channel size will not be changed if they are already static
conv = std::make_shared<Convolution>(input, filter, strides, pads_begin, pads_end, dilations, auto_pad);
} else {
// grouped convolutional operation
// compute input channels given from the input and the filter
// and number of groups required to split the filter
auto input_shape = make_shared<ShapeOf>(input, element::i32);
auto filter_shape = make_shared<ShapeOf>(filter, element::i32);
auto zero_const = make_shared<Constant>(element::i32, Shape{1}, 0);
auto one_const = make_shared<Constant>(element::i32, Shape{1}, 1);
auto two_const = make_shared<Constant>(element::i32, Shape{1}, 2);
auto input_cin = make_shared<Slice>(input_shape, one_const, two_const, one_const);
auto filter_cin = make_shared<Slice>(filter_shape, one_const, two_const, one_const);
auto num_groups = make_shared<Divide>(input_cin, filter_cin);
// reshape the filter based on the number of groups information
auto int_max_const = make_shared<Constant>(element::i32, Shape{1}, std::numeric_limits<int>::max());
auto filter_cout = make_shared<Slice>(filter_shape, zero_const, one_const, one_const);
auto filter_new_cout = make_shared<Divide>(filter_cout, num_groups);
auto shape_cin_xy = make_shared<Slice>(filter_shape, one_const, int_max_const, one_const);
auto filter_new_shape = make_shared<Concat>(OutputVector{num_groups, filter_new_cout, shape_cin_xy}, 0);
auto new_filter = make_shared<Reshape>(filter, filter_new_shape, false);
conv =
std::make_shared<GroupConvolution>(input, new_filter, strides, pads_begin, pads_end, dilations, auto_pad);
}
ov::frontend::tensorflow::convert_nchw_to_nhwc(is_nhwc, conv, ov::Rank(spatial_dims_num + 2));
ov::frontend::tensorflow::set_node_name(node.get_name(), conv.get_node_shared_ptr());

View File

@ -1,12 +1,10 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
from common.tf_layer_test_class import CommonTFLayerTest
import logging
# Testing operation Conv2D
# Documentation: https://www.tensorflow.org/api_docs/python/tf/raw_ops/Conv2D
@ -18,7 +16,8 @@ class TestConv2D(CommonTFLayerTest):
# input_padding - should be a string, defines padding algorithm
# ir_version - common parameter
# use_new_frontend - common parameter
def create_conv2d_placeholder_const_net(self, input_shape, input_filter, input_strides, input_padding, dilations, ir_version, use_new_frontend):
def create_conv2d_placeholder_const_net(self, input_shape, input_filter, input_strides, input_padding, dilations,
ir_version, use_new_frontend):
"""
Tensorflow net IR net
@ -31,10 +30,10 @@ class TestConv2D(CommonTFLayerTest):
import tensorflow as tf
if dilations is None:
dilations = [1, 1, 1, 1] #default value regarding Documentation
dilations = [1, 1, 1, 1] # default value regarding Documentation
# Batch Height Width Channel
expl_paddings = [0, 0, 1, 1, 1, 1, 0, 0]
expl_paddings = [0, 0, 1, 1, 1, 1, 0, 0]
if input_padding == 'EXPLICIT' and use_new_frontend == False:
pytest.xfail(reason="100300")
@ -47,9 +46,11 @@ class TestConv2D(CommonTFLayerTest):
tf_filter = tf.compat.v1.placeholder(tf.float32, input_filter, "InputFilter")
if input_padding != 'EXPLICIT':
tf.raw_ops.Conv2D(input = tf_input, filter = tf_filter, strides = input_strides, padding = input_padding, dilations = dilations)
tf.raw_ops.Conv2D(input=tf_input, filter=tf_filter, strides=input_strides, padding=input_padding,
dilations=dilations)
else:
tf.raw_ops.Conv2D(input = tf_input, filter = tf_filter, strides = input_strides, padding = input_padding, explicit_paddings=expl_paddings, dilations = dilations)
tf.raw_ops.Conv2D(input=tf_input, filter=tf_filter, strides=input_strides, padding=input_padding,
explicit_paddings=expl_paddings, dilations=dilations)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
@ -67,8 +68,14 @@ class TestConv2D(CommonTFLayerTest):
dict(input_shape=[1, 10, 10, 4], input_filter=[2, 2, 4, 2], input_strides=[1, 1, 1, 1], dilations=None),
dict(input_shape=[1, 16, 16, 3], input_filter=[2, 2, 3, 3], input_strides=[1, 2, 2, 1], dilations=[1, 2, 2, 1]),
pytest.param(
dict(input_shape=[1, 224, 224, 3], input_filter=[4, 4, 3, 2], input_strides=[1, 2, 2, 1], dilations=[1, 2, 2, 1]),
dict(input_shape=[1, 224, 224, 3], input_filter=[4, 4, 3, 2], input_strides=[1, 2, 2, 1],
dilations=[1, 2, 2, 1]),
marks=pytest.mark.precommit_tf_fe),
# with four groups
pytest.param(
dict(input_shape=[2, 224, 224, 4], input_filter=[4, 4, 1, 12], input_strides=[1, 2, 2, 1],
dilations=[1, 2, 2, 1]),
marks=pytest.mark.precommit_tf_fe)
]
@pytest.mark.parametrize("params", test_data)
@ -77,6 +84,6 @@ class TestConv2D(CommonTFLayerTest):
def test_conv2d_placeholder_const(self, params, padding, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(*self.create_conv2d_placeholder_const_net(**params, input_padding=padding, ir_version=ir_version,
use_new_frontend=use_new_frontend),
use_new_frontend=use_new_frontend),
ie_device, precision, ir_version, input_padding=padding, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)

View File

@ -1,12 +1,10 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
from common.tf_layer_test_class import CommonTFLayerTest
import logging
# Testing operation Conv3D
# Documentation: https://www.tensorflow.org/api_docs/python/tf/raw_ops/Conv3D
@ -18,7 +16,8 @@ class TestConv3D(CommonTFLayerTest):
# input_padding - should be a string, defines padding algorithm
# ir_version - common parameter
# use_new_frontend - common parameter
def create_conv3d_placeholder_const_net(self, input_shape, input_filter, input_strides, input_padding, dilations, ir_version, use_new_frontend):
def create_conv3d_placeholder_const_net(self, input_shape, input_filter, input_strides, input_padding, dilations,
ir_version, use_new_frontend):
"""
Tensorflow net IR net
@ -31,7 +30,7 @@ class TestConv3D(CommonTFLayerTest):
import tensorflow as tf
if dilations is None:
dilations = [1, 1, 1, 1, 1] #default value regarding Documentation
dilations = [1, 1, 1, 1, 1] # default value regarding Documentation
else:
pytest.skip('Dilations != 1 isn\' supported on CPU by Tensorflow')
@ -42,7 +41,8 @@ class TestConv3D(CommonTFLayerTest):
tf_input = tf.compat.v1.placeholder(tf.float32, input_shape, "InputShape")
tf_filter = tf.compat.v1.placeholder(tf.float32, input_filter, "InputFilter")
tf.raw_ops.Conv3D(input = tf_input, filter = tf_filter, strides = input_strides, padding = input_padding, dilations = dilations)
tf.raw_ops.Conv3D(input=tf_input, filter=tf_filter, strides=input_strides, padding=input_padding,
dilations=dilations)
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
@ -52,15 +52,23 @@ class TestConv3D(CommonTFLayerTest):
return tf_net, ref_net
test_data = [
dict(input_shape=[1, 10, 10, 10, 1], input_filter=[3, 3, 3, 1, 1], input_strides=[1, 1, 1, 1, 1], dilations=None),
dict(input_shape=[1, 10, 10, 10, 2], input_filter=[3, 3, 3, 2, 2], input_strides=[1, 1, 1, 1, 1], dilations=None),
dict(input_shape=[1, 10, 10, 10, 2], input_filter=[3, 3, 3, 2, 1], input_strides=[1, 1, 1, 1, 1], dilations=None),
dict(input_shape=[1, 16, 16, 16, 3], input_filter=[4, 2, 4, 3, 3], input_strides=[1, 2, 2, 2, 1], dilations=None),
dict(input_shape=[1, 10, 10, 20, 3], input_filter=[2, 4, 2, 3, 3], input_strides=[1, 2, 2, 2, 1], dilations=None),
dict(input_shape=[1, 10, 10, 10, 4], input_filter=[3, 3, 3, 4, 2], input_strides=[1, 1, 1, 1, 1], dilations=None),
dict(input_shape=[1, 10, 10, 20, 3], input_filter=[2, 4, 2, 3, 3], input_strides=[1, 2, 2, 2, 1], dilations=[1, 2, 2, 2, 1]),
dict(input_shape=[1, 10, 10, 10, 1], input_filter=[3, 3, 3, 1, 1], input_strides=[1, 1, 1, 1, 1],
dilations=None),
dict(input_shape=[1, 10, 10, 10, 2], input_filter=[3, 3, 3, 2, 2], input_strides=[1, 1, 1, 1, 1],
dilations=None),
dict(input_shape=[1, 10, 10, 10, 2], input_filter=[3, 3, 3, 2, 1], input_strides=[1, 1, 1, 1, 1],
dilations=None),
dict(input_shape=[1, 16, 16, 16, 3], input_filter=[4, 2, 4, 3, 3], input_strides=[1, 2, 2, 2, 1],
dilations=None),
dict(input_shape=[1, 10, 10, 20, 3], input_filter=[2, 4, 2, 3, 3], input_strides=[1, 2, 2, 2, 1],
dilations=None),
dict(input_shape=[1, 10, 10, 10, 4], input_filter=[3, 3, 3, 4, 2], input_strides=[1, 1, 1, 1, 1],
dilations=None),
dict(input_shape=[1, 10, 10, 20, 3], input_filter=[2, 4, 2, 3, 3], input_strides=[1, 2, 2, 2, 1],
dilations=[1, 2, 2, 2, 1]),
pytest.param(
dict(input_shape=[1, 224, 224, 224, 3], input_filter=[1, 2, 3, 3, 2], input_strides=[1, 2, 2, 2, 1], dilations=None),
dict(input_shape=[1, 224, 224, 224, 3], input_filter=[1, 2, 3, 3, 2], input_strides=[1, 2, 2, 2, 1],
dilations=None),
marks=pytest.mark.precommit_tf_fe),
]
@ -70,6 +78,6 @@ class TestConv3D(CommonTFLayerTest):
def test_conv3d_placeholder_const(self, params, padding, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(*self.create_conv3d_placeholder_const_net(**params, input_padding=padding, ir_version=ir_version,
use_new_frontend=use_new_frontend),
use_new_frontend=use_new_frontend),
ie_device, precision, ir_version, input_padding=padding, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)