[TF FE] Refactor translation for Concat and ConcatV2 operations and add tests (#13542)

* [TF FE] Refactor translation for Concat and ConcatV2 operations and add tests

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* [TF FE] Correct the layer tests for passing

* Apply code-review feedback: remove redundant function

* Fix a name of function to create Concat model in the layer test

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2022-10-20 17:01:31 +03:00 committed by GitHub
parent 88ffc23e3a
commit 22f3bee0f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 71 additions and 78 deletions

View File

@ -15,36 +15,50 @@ namespace ov {
namespace frontend { namespace frontend {
namespace tensorflow { namespace tensorflow {
namespace op { namespace op {
OutputVector translate_concat_op(const NodeContext& node) { OutputVector translate_concat_op(const NodeContext& node) {
size_t axis_idx, concat_idx_start, concat_idx_stop; // The difference between Concat and ConcatV2 is that
if (node.get_op_type() == "ConcatV2") { // axis is the first input for Concat
axis_idx = node.get_input_size() - 1; // and is the last input to ConcatV2
concat_idx_start = 0; default_op_checks(node, 2, {"Concat", "ConcatV2"});
concat_idx_stop = node.get_input_size() - 1; auto input_size = node.get_input_size();
} else if (node.get_op_type() == "Concat") {
axis_idx = 0; int64_t axis;
concat_idx_start = 1; OutputVector inputs;
concat_idx_stop = node.get_input_size();
if (node.get_op_type() == "Concat") {
std::vector<int64_t> axis_vector;
get_const_input(node, 0, &axis_vector);
TENSORFLOW_OP_VALIDATION(
node,
axis_vector.size() == 1,
"Input model is incorrect: axis input for Concat operation must have exactly one element.");
axis = axis_vector[0];
for (size_t input_idx = 1; input_idx < input_size; ++input_idx) {
inputs.push_back(node.get_input(input_idx));
}
} else if (node.get_op_type() == "ConcatV2") {
std::vector<int64_t> axis_vector;
get_const_input(node, input_size - 1, &axis_vector);
TENSORFLOW_OP_VALIDATION(
node,
axis_vector.size() == 1,
"Input model is incorrect: axis input for Concat operation must have exactly one element.");
axis = axis_vector[0];
for (size_t input_idx = 0; input_idx < input_size - 1; ++input_idx) {
inputs.push_back(node.get_input(input_idx));
}
} else { } else {
TENSORFLOW_OP_VALIDATION(node, false, "Incorrect operation type."); TENSORFLOW_OP_VALIDATION(node,
false,
"Internal TensorFlow Frontend error: incorrect operation type is passed to "
"translate_concat_op function.");
} }
std::vector<int64_t> tf_concat_axis_vec; auto concat = make_shared<Concat>(inputs, axis);
get_const_input(node, axis_idx, &tf_concat_axis_vec); set_node_name(node.get_name(), concat);
int64_t concat_axis = tf_concat_axis_vec[0]; return {concat};
OutputVector ng_args;
for (size_t i = concat_idx_start; i < concat_idx_stop; i++) {
Output<Node> ng_arg = node.get_input(static_cast<int>(i));
ng_args.push_back(ng_arg);
}
auto res = make_shared<Concat>(ng_args, size_t(concat_axis));
set_node_name(node.get_name(), res);
return res->outputs();
} }
} // namespace op } // namespace op
} // namespace tensorflow } // namespace tensorflow
} // namespace frontend } // namespace frontend
} // namespace ov } // namespace ov

View File

@ -4,94 +4,72 @@
import pytest import pytest
from common.tf_layer_test_class import CommonTFLayerTest from common.tf_layer_test_class import CommonTFLayerTest
from common.utils.tf_utils import permute_nchw_to_nhwc
class TestConcat(CommonTFLayerTest): class TestConcat(CommonTFLayerTest):
def create_concat_net(self, shape, axis, ir_version, use_new_frontend): def create_concat_net(self, input_shapes, axis, is_v2, ir_version, use_new_frontend):
""" # tf.concat is equivalent to tf.raw_ops.ConcatV2
Tensorflow net IR net # only tf.concat accepts one input
Input->Concat => Input->Concat
"""
import tensorflow as tf import tensorflow as tf
tf.compat.v1.reset_default_graph() tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess: with tf.compat.v1.Session() as sess:
ax = axis placeholders = []
for ind, input_shape in enumerate(input_shapes):
tf_x_shape = shape.copy() placeholders.append(tf.compat.v1.placeholder(tf.float32, input_shape, 'input_{}'.format(ind)))
if len(input_shapes) == 1:
tf_x_shape = permute_nchw_to_nhwc(tf_x_shape, use_new_frontend) tf.concat(values=placeholders, axis=axis, name='concat')
elif is_v2:
# TODO: add concat with const inputs to check fusing (as in ONNX) tf.raw_ops.ConcatV2(values=placeholders, axis=axis, name='concat')
else:
x = tf.compat.v1.placeholder(tf.float32, tf_x_shape, 'Input') tf.raw_ops.Concat(values=placeholders, concat_dim=axis, name='concat')
y = tf.compat.v1.placeholder(tf.float32, tf_x_shape, 'Input') # Input_1 in graph_def
concat = tf.concat([x, y], axis=ax, name='Operation')
concat_shape = concat.shape.as_list()
tf.compat.v1.global_variables_initializer() tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def tf_net = sess.graph_def
#
# Create reference IR net
# Please, specify 'type': 'Input' for input node
# Moreover, do not forget to validate ALL layer attributes!!!
#
ref_net = None ref_net = None
return tf_net, ref_net return tf_net, ref_net
# TODO: create tests for concat with 1 input and multiple inputs test_data_1D = [dict(input_shapes=[[1], [2]], axis=0, is_v2=False),
dict(input_shapes=[[1], [3]], axis=-1, is_v2=True)]
test_data_1D = [dict(shape=[1], axis=0),
dict(shape=[1], axis=-1)]
@pytest.mark.parametrize("params", test_data_1D) @pytest.mark.parametrize("params", test_data_1D)
@pytest.mark.nightly @pytest.mark.nightly
def test_concat_1D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend, def test_concat_1D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
use_old_api): use_old_api):
self._test(*self.create_concat_net(**params, ir_version=ir_version, self._test(*self.create_concat_net(**params, ir_version=ir_version,
use_new_frontend=use_new_frontend), use_new_frontend=use_new_frontend),
ie_device, precision, ir_version, temp_dir=temp_dir, ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api) use_new_frontend=use_new_frontend, use_old_api=use_old_api)
test_data_2D = [dict(shape=[1, 224], axis=0), test_data_2D = [dict(input_shapes=[[3, 4]], axis=0, is_v2=True),
dict(shape=[1, 224], axis=-1)] dict(input_shapes=[[1, 4], [1, 2]], axis=-1, is_v2=True)]
@pytest.mark.parametrize("params", test_data_2D) @pytest.mark.parametrize("params", test_data_2D)
@pytest.mark.nightly @pytest.mark.nightly
def test_concat_2D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend, def test_concat_2D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
use_old_api): use_old_api):
self._test(*self.create_concat_net(**params, ir_version=ir_version, self._test(*self.create_concat_net(**params, ir_version=ir_version,
use_new_frontend=use_new_frontend), use_new_frontend=use_new_frontend),
ie_device, precision, ir_version, temp_dir=temp_dir, ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api) use_new_frontend=use_new_frontend, use_old_api=use_old_api)
test_data_3D = [ test_data_3D = [
dict(shape=[1, 3, 224], axis=0), dict(input_shapes=[[1, 3, 5], [2, 3, 5]], axis=0, is_v2=False),
pytest.param(dict(shape=[1, 3, 224], axis=-1), marks=pytest.mark.precommit_tf_fe), dict(input_shapes=[[1, 3, 2], [1, 3, 5]], axis=-1, is_v2=True),
dict(shape=[1, 3, 224], axis=2)] dict(input_shapes=[[1, 3, 1], [1, 3, 4], [1, 3, 3]], axis=2, is_v2=True)]
@pytest.mark.parametrize("params", test_data_3D) @pytest.mark.parametrize("params", test_data_3D)
@pytest.mark.nightly @pytest.mark.nightly
@pytest.mark.precommit_tf_fe
def test_concat_3D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend, def test_concat_3D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
use_old_api): use_old_api):
self._test(*self.create_concat_net(**params, ir_version=ir_version, self._test(*self.create_concat_net(**params, ir_version=ir_version,
use_new_frontend=use_new_frontend), use_new_frontend=use_new_frontend),
ie_device, precision, ir_version, temp_dir=temp_dir, ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api) use_new_frontend=use_new_frontend, use_old_api=use_old_api)
test_data_4D = [dict(shape=[1, 3, 100, 224], axis=0), test_data_4D = [dict(input_shapes=[[1, 3, 5, 7], [3, 3, 5, 7], [2, 3, 5, 7]], axis=0, is_v2=False),
dict(shape=[1, 3, 100, 224], axis=-1), dict(input_shapes=[[1, 3, 5, 5], [1, 3, 5, 7]], axis=-1, is_v2=True),
dict(shape=[1, 3, 100, 224], axis=2)] dict(input_shapes=[[1, 3, 5, 7], [1, 3, 3, 7]], axis=2, is_v2=False)]
@pytest.mark.parametrize("params", test_data_4D) @pytest.mark.parametrize("params", test_data_4D)
@pytest.mark.nightly @pytest.mark.nightly
@ -99,19 +77,20 @@ class TestConcat(CommonTFLayerTest):
def test_concat_4D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend, def test_concat_4D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
use_old_api): use_old_api):
self._test(*self.create_concat_net(**params, ir_version=ir_version, self._test(*self.create_concat_net(**params, ir_version=ir_version,
use_new_frontend=use_new_frontend), use_new_frontend=use_new_frontend),
ie_device, precision, ir_version, temp_dir=temp_dir, ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api) use_new_frontend=use_new_frontend, use_old_api=use_old_api)
test_data_5D = [dict(shape=[1, 3, 50, 100, 224], axis=0), test_data_5D = [dict(input_shapes=[[1, 3, 5, 7, 8], [2, 3, 5, 7, 8]], axis=0, is_v2=True),
dict(shape=[1, 3, 50, 100, 224], axis=-1), dict(input_shapes=[[1, 3, 5, 7, 2], [1, 3, 5, 7, 3], [1, 3, 5, 7, 2], [1, 3, 5, 7, 4]],
dict(shape=[1, 3, 50, 100, 224], axis=2)] axis=-1, is_v2=True),
dict(input_shapes=[[1, 3, 5, 7, 8], [1, 3, 2, 7, 8]], axis=2, is_v2=False)]
@pytest.mark.parametrize("params", test_data_5D) @pytest.mark.parametrize("params", test_data_5D)
@pytest.mark.nightly @pytest.mark.nightly
def test_concat_5D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend, def test_concat_5D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
use_old_api): use_old_api):
self._test(*self.create_concat_net(**params, ir_version=ir_version, self._test(*self.create_concat_net(**params, ir_version=ir_version,
use_new_frontend=use_new_frontend), use_new_frontend=use_new_frontend),
ie_device, precision, ir_version, temp_dir=temp_dir, ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api) use_new_frontend=use_new_frontend, use_old_api=use_old_api)