[TF FE] Support Group Convolutions (#15130)
* [TF FE] Support Group Convolutions
  Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
* Split cases of regular Convolution and GroupConvolution operations
  Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
1d5fa360d4
commit
5043797b1c
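For context on the change below: TensorFlow's Conv2D acts as a grouped convolution whenever the input channel count is an integer multiple of the filter's input-channel dimension (HWIO layout), and the number of groups is that ratio. A minimal sketch of this rule, using a hypothetical helper that is not part of the commit:

```python
# Hypothetical helper illustrating the rule the translator below relies on:
# groups = input channels (NHWC) / filter in-channels (HWIO).
def infer_num_groups(input_shape_nhwc, filter_shape_hwio):
    c_in = input_shape_nhwc[-1]
    filter_c_in = filter_shape_hwio[2]
    assert c_in % filter_c_in == 0, "input channels must be divisible by filter in-channels"
    return c_in // filter_c_in

print(infer_num_groups([1, 224, 224, 3], [4, 4, 3, 2]))   # 1 -> regular Convolution
print(infer_num_groups([2, 224, 224, 4], [4, 4, 1, 12]))  # 4 -> GroupConvolution
```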
@@ -4,6 +4,8 @@
#include "utils.hpp"

#include <limits>

#include "openvino/opsets/opset10.hpp"
#include "openvino/opsets/opset8.hpp"
#include "openvino_conversions.hpp"
@@ -182,8 +184,56 @@ ov::OutputVector ov::frontend::tensorflow::translate_convolution_op(const ov::fr
    ov::AxisVector permutation_3d = {4, 3, 0, 1, 2};
    filter = ov::frontend::tensorflow::make_transpose(filter, spatial_dims_num == 2 ? permutation_2d : permutation_3d);

    ov::Output<ov::Node> conv =
        std::make_shared<Convolution>(input, filter, strides, pads_begin, pads_end, dilations, auto_pad);
    bool input_channels_static = false;
    int64_t num_groups = 1;
    auto input_shape = input.get_partial_shape();
    auto filter_shape = filter.get_partial_shape();
    if (input_shape.rank().is_static() && filter_shape.rank().is_static()) {
        auto input_rank = static_cast<size_t>(input_shape.rank().get_length());
        auto filter_rank = static_cast<size_t>(filter_shape.rank().get_length());
        TENSORFLOW_OP_VALIDATION(node, input_rank == (spatial_dims_num + 2), "Internal error: incorrect input rank.");
        TENSORFLOW_OP_VALIDATION(node, filter_rank == input_rank, "Internal error: incorrect filter rank.");
        auto input_channels_size = input_shape[1];
        auto filter_channels_size = filter_shape[1];
        if (input_channels_size.is_static() && filter_channels_size.is_static()) {
            // we assume that the input channel size will not change if it is already static;
            // this makes it simpler to differentiate the Convolution and GroupConvolution cases
            num_groups = input_channels_size.get_length() / filter_channels_size.get_length();
            TENSORFLOW_OP_VALIDATION(node,
                                     num_groups >= 1,
                                     "Internal error: number of groups for Convolutional operation is not positive.");
            input_channels_static = true;
        }
    }

    ov::Output<ov::Node> conv;
    if (input_channels_static && num_groups == 1) {
        // regular convolutional operation
        // we assume that the input channel size will not change if it is already static
        conv = std::make_shared<Convolution>(input, filter, strides, pads_begin, pads_end, dilations, auto_pad);
    } else {
        // grouped convolutional operation
        // compute the input channels from the input and the filter
        // and the number of groups required to split the filter
        auto input_shape = make_shared<ShapeOf>(input, element::i32);
        auto filter_shape = make_shared<ShapeOf>(filter, element::i32);
        auto zero_const = make_shared<Constant>(element::i32, Shape{1}, 0);
        auto one_const = make_shared<Constant>(element::i32, Shape{1}, 1);
        auto two_const = make_shared<Constant>(element::i32, Shape{1}, 2);
        auto input_cin = make_shared<Slice>(input_shape, one_const, two_const, one_const);
        auto filter_cin = make_shared<Slice>(filter_shape, one_const, two_const, one_const);
        auto num_groups = make_shared<Divide>(input_cin, filter_cin);

        // reshape the filter based on the number of groups information
        auto int_max_const = make_shared<Constant>(element::i32, Shape{1}, std::numeric_limits<int>::max());
        auto filter_cout = make_shared<Slice>(filter_shape, zero_const, one_const, one_const);
        auto filter_new_cout = make_shared<Divide>(filter_cout, num_groups);
        auto shape_cin_xy = make_shared<Slice>(filter_shape, one_const, int_max_const, one_const);
        auto filter_new_shape = make_shared<Concat>(OutputVector{num_groups, filter_new_cout, shape_cin_xy}, 0);
        auto new_filter = make_shared<Reshape>(filter, filter_new_shape, false);
        conv =
            std::make_shared<GroupConvolution>(input, new_filter, strides, pads_begin, pads_end, dilations, auto_pad);
    }

    ov::frontend::tensorflow::convert_nchw_to_nhwc(is_nhwc, conv, ov::Rank(spatial_dims_num + 2));
    ov::frontend::tensorflow::set_node_name(node.get_name(), conv.get_node_shared_ptr());
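The dynamic branch above can be pictured with a rough NumPy analogue (illustrative only, with shapes borrowed from the grouped test case added below): after the HWIO-to-OIHW transpose, the filter is reshaped so that GroupConvolution receives weights in [GROUPS, C_OUT/GROUPS, C_IN/GROUPS, kH, kW] layout.

```python
import numpy as np

# Filter already transposed to OIHW: [C_OUT, C_IN/GROUPS, kH, kW] (assumed shapes).
c_in = 4                                                        # input channels (NCHW dim 1)
filter_oihw = np.random.rand(12, 1, 4, 4).astype(np.float32)

groups = c_in // filter_oihw.shape[1]                           # 4 // 1 = 4
c_out_per_group = filter_oihw.shape[0] // groups                # 12 // 4 = 3
# Mirrors the Concat([num_groups, filter_new_cout, shape_cin_xy]) + Reshape above.
group_filter = filter_oihw.reshape(groups, c_out_per_group, *filter_oihw.shape[1:])
print(group_filter.shape)                                       # (4, 3, 1, 4, 4)
```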
@@ -1,12 +1,10 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest

from common.tf_layer_test_class import CommonTFLayerTest

import logging

# Testing operation Conv2D
# Documentation: https://www.tensorflow.org/api_docs/python/tf/raw_ops/Conv2D
@@ -18,7 +16,8 @@ class TestConv2D(CommonTFLayerTest):
    # input_padding - should be a string, defines padding algorithm
    # ir_version - common parameter
    # use_new_frontend - common parameter
    def create_conv2d_placeholder_const_net(self, input_shape, input_filter, input_strides, input_padding, dilations, ir_version, use_new_frontend):
    def create_conv2d_placeholder_const_net(self, input_shape, input_filter, input_strides, input_padding, dilations,
                                            ir_version, use_new_frontend):
        """
        Tensorflow net IR net
@@ -31,10 +30,10 @@ class TestConv2D(CommonTFLayerTest):
        import tensorflow as tf

        if dilations is None:
            dilations = [1, 1, 1, 1] #default value regarding Documentation
            dilations = [1, 1, 1, 1]  # default value regarding Documentation

        # Batch Height Width Channel
        expl_paddings = [0, 0, 1, 1, 1, 1, 0, 0]

        if input_padding == 'EXPLICIT' and use_new_frontend == False:
            pytest.xfail(reason="100300")
@@ -47,9 +46,11 @@ class TestConv2D(CommonTFLayerTest):
            tf_filter = tf.compat.v1.placeholder(tf.float32, input_filter, "InputFilter")

            if input_padding != 'EXPLICIT':
                tf.raw_ops.Conv2D(input = tf_input, filter = tf_filter, strides = input_strides, padding = input_padding, dilations = dilations)
                tf.raw_ops.Conv2D(input=tf_input, filter=tf_filter, strides=input_strides, padding=input_padding,
                                  dilations=dilations)
            else:
                tf.raw_ops.Conv2D(input = tf_input, filter = tf_filter, strides = input_strides, padding = input_padding, explicit_paddings=expl_paddings, dilations = dilations)
                tf.raw_ops.Conv2D(input=tf_input, filter=tf_filter, strides=input_strides, padding=input_padding,
                                  explicit_paddings=expl_paddings, dilations=dilations)

            tf.compat.v1.global_variables_initializer()
            tf_net = sess.graph_def
@@ -67,8 +68,14 @@ class TestConv2D(CommonTFLayerTest):
        dict(input_shape=[1, 10, 10, 4], input_filter=[2, 2, 4, 2], input_strides=[1, 1, 1, 1], dilations=None),
        dict(input_shape=[1, 16, 16, 3], input_filter=[2, 2, 3, 3], input_strides=[1, 2, 2, 1], dilations=[1, 2, 2, 1]),
        pytest.param(
            dict(input_shape=[1, 224, 224, 3], input_filter=[4, 4, 3, 2], input_strides=[1, 2, 2, 1], dilations=[1, 2, 2, 1]),
            dict(input_shape=[1, 224, 224, 3], input_filter=[4, 4, 3, 2], input_strides=[1, 2, 2, 1],
                 dilations=[1, 2, 2, 1]),
            marks=pytest.mark.precommit_tf_fe),
        # with four groups
        pytest.param(
            dict(input_shape=[2, 224, 224, 4], input_filter=[4, 4, 1, 12], input_strides=[1, 2, 2, 1],
                 dilations=[1, 2, 2, 1]),
            marks=pytest.mark.precommit_tf_fe)
    ]

    @pytest.mark.parametrize("params", test_data)
@@ -77,6 +84,6 @@ class TestConv2D(CommonTFLayerTest):
    def test_conv2d_placeholder_const(self, params, padding, ie_device, precision, ir_version, temp_dir,
                                      use_new_frontend, use_old_api):
        self._test(*self.create_conv2d_placeholder_const_net(**params, input_padding=padding, ir_version=ir_version,
                                                             use_new_frontend=use_new_frontend),
                   ie_device, precision, ir_version, input_padding=padding, temp_dir=temp_dir,
                   use_new_frontend=use_new_frontend, use_old_api=use_old_api)
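As a side note, not part of the commit: the new grouped case can be reproduced as a standalone GraphDef roughly the way the test builds it. The sketch below uses the shapes from the "with four groups" entry and assumes a TensorFlow 2.x install; it only constructs the graph and does not execute the kernel.

```python
import tensorflow as tf

# Standalone sketch of the new grouped case: C_in = 4, filter in-channels = 1
# -> four groups, 12 output channels. Graph construction only, no execution.
with tf.Graph().as_default():
    tf_input = tf.compat.v1.placeholder(tf.float32, [2, 224, 224, 4], "InputShape")
    tf_filter = tf.compat.v1.placeholder(tf.float32, [4, 4, 1, 12], "InputFilter")
    tf.raw_ops.Conv2D(input=tf_input, filter=tf_filter, strides=[1, 2, 2, 1],
                      padding='SAME', dilations=[1, 2, 2, 1])
    graph_def = tf.compat.v1.get_default_graph().as_graph_def()
# graph_def is the kind of GraphDef the TF frontend converts; with the change
# above this Conv2D should be translated to GroupConvolution.
```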
@@ -1,12 +1,10 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest

from common.tf_layer_test_class import CommonTFLayerTest

import logging

# Testing operation Conv3D
# Documentation: https://www.tensorflow.org/api_docs/python/tf/raw_ops/Conv3D
@@ -18,7 +16,8 @@ class TestConv3D(CommonTFLayerTest):
    # input_padding - should be a string, defines padding algorithm
    # ir_version - common parameter
    # use_new_frontend - common parameter
    def create_conv3d_placeholder_const_net(self, input_shape, input_filter, input_strides, input_padding, dilations, ir_version, use_new_frontend):
    def create_conv3d_placeholder_const_net(self, input_shape, input_filter, input_strides, input_padding, dilations,
                                            ir_version, use_new_frontend):
        """
        Tensorflow net IR net
@@ -31,7 +30,7 @@ class TestConv3D(CommonTFLayerTest):
        import tensorflow as tf

        if dilations is None:
            dilations = [1, 1, 1, 1, 1] #default value regarding Documentation
            dilations = [1, 1, 1, 1, 1]  # default value regarding Documentation
        else:
            pytest.skip('Dilations != 1 isn\'t supported on CPU by Tensorflow')

@@ -42,7 +41,8 @@ class TestConv3D(CommonTFLayerTest):
            tf_input = tf.compat.v1.placeholder(tf.float32, input_shape, "InputShape")
            tf_filter = tf.compat.v1.placeholder(tf.float32, input_filter, "InputFilter")

            tf.raw_ops.Conv3D(input = tf_input, filter = tf_filter, strides = input_strides, padding = input_padding, dilations = dilations)
            tf.raw_ops.Conv3D(input=tf_input, filter=tf_filter, strides=input_strides, padding=input_padding,
                              dilations=dilations)

            tf.compat.v1.global_variables_initializer()
            tf_net = sess.graph_def
@@ -52,15 +52,23 @@ class TestConv3D(CommonTFLayerTest):
        return tf_net, ref_net

    test_data = [
        dict(input_shape=[1, 10, 10, 10, 1], input_filter=[3, 3, 3, 1, 1], input_strides=[1, 1, 1, 1, 1], dilations=None),
        dict(input_shape=[1, 10, 10, 10, 2], input_filter=[3, 3, 3, 2, 2], input_strides=[1, 1, 1, 1, 1], dilations=None),
        dict(input_shape=[1, 10, 10, 10, 2], input_filter=[3, 3, 3, 2, 1], input_strides=[1, 1, 1, 1, 1], dilations=None),
        dict(input_shape=[1, 16, 16, 16, 3], input_filter=[4, 2, 4, 3, 3], input_strides=[1, 2, 2, 2, 1], dilations=None),
        dict(input_shape=[1, 10, 10, 20, 3], input_filter=[2, 4, 2, 3, 3], input_strides=[1, 2, 2, 2, 1], dilations=None),
        dict(input_shape=[1, 10, 10, 10, 4], input_filter=[3, 3, 3, 4, 2], input_strides=[1, 1, 1, 1, 1], dilations=None),
        dict(input_shape=[1, 10, 10, 20, 3], input_filter=[2, 4, 2, 3, 3], input_strides=[1, 2, 2, 2, 1], dilations=[1, 2, 2, 2, 1]),
        dict(input_shape=[1, 10, 10, 10, 1], input_filter=[3, 3, 3, 1, 1], input_strides=[1, 1, 1, 1, 1],
             dilations=None),
        dict(input_shape=[1, 10, 10, 10, 2], input_filter=[3, 3, 3, 2, 2], input_strides=[1, 1, 1, 1, 1],
             dilations=None),
        dict(input_shape=[1, 10, 10, 10, 2], input_filter=[3, 3, 3, 2, 1], input_strides=[1, 1, 1, 1, 1],
             dilations=None),
        dict(input_shape=[1, 16, 16, 16, 3], input_filter=[4, 2, 4, 3, 3], input_strides=[1, 2, 2, 2, 1],
             dilations=None),
        dict(input_shape=[1, 10, 10, 20, 3], input_filter=[2, 4, 2, 3, 3], input_strides=[1, 2, 2, 2, 1],
             dilations=None),
        dict(input_shape=[1, 10, 10, 10, 4], input_filter=[3, 3, 3, 4, 2], input_strides=[1, 1, 1, 1, 1],
             dilations=None),
        dict(input_shape=[1, 10, 10, 20, 3], input_filter=[2, 4, 2, 3, 3], input_strides=[1, 2, 2, 2, 1],
             dilations=[1, 2, 2, 2, 1]),
        pytest.param(
            dict(input_shape=[1, 224, 224, 224, 3], input_filter=[1, 2, 3, 3, 2], input_strides=[1, 2, 2, 2, 1], dilations=None),
            dict(input_shape=[1, 224, 224, 224, 3], input_filter=[1, 2, 3, 3, 2], input_strides=[1, 2, 2, 2, 1],
                 dilations=None),
            marks=pytest.mark.precommit_tf_fe),
    ]

@@ -70,6 +78,6 @@ class TestConv3D(CommonTFLayerTest):
    def test_conv3d_placeholder_const(self, params, padding, ie_device, precision, ir_version, temp_dir,
                                      use_new_frontend, use_old_api):
        self._test(*self.create_conv3d_placeholder_const_net(**params, input_padding=padding, ir_version=ir_version,
                                                             use_new_frontend=use_new_frontend),
                   ie_device, precision, ir_version, input_padding=padding, temp_dir=temp_dir,
                   use_new_frontend=use_new_frontend, use_old_api=use_old_api)
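The same grouping rule carries over to the translator's 3D path (Conv3D uses a DHWIO filter, so the in-channel dimension sits at index 3). A small hypothetical check, not part of the commit, applied to the test data above:

```python
# Hypothetical shape check for the Conv3D path: groups = C_in / filter in-channels.
def infer_num_groups_3d(input_shape_ndhwc, filter_shape_dhwio):
    c_in = input_shape_ndhwc[-1]
    filter_c_in = filter_shape_dhwio[3]
    assert c_in % filter_c_in == 0
    return c_in // filter_c_in

# Every Conv3D case in test_data above keeps groups == 1, e.g.:
print(infer_num_groups_3d([1, 10, 10, 10, 4], [3, 3, 3, 4, 2]))  # 1 -> regular Convolution
```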