[TF FE] Refactor translation for Concat and ConcatV2 operations and add tests (#13542)

* [TF FE] Refactor translation for Concat and ConcatV2 operations and add tests

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* [TF FE] Correct the layer tests for passing

* Apply code-review feedback: remove redundant function

* Fix a name of function to create Concat model in the layer test

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2022-10-20 17:01:31 +03:00 committed by GitHub
parent 88ffc23e3a
commit 22f3bee0f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 71 additions and 78 deletions

View File

@ -15,34 +15,48 @@ namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_concat_op(const NodeContext& node) {
size_t axis_idx, concat_idx_start, concat_idx_stop;
if (node.get_op_type() == "ConcatV2") {
axis_idx = node.get_input_size() - 1;
concat_idx_start = 0;
concat_idx_stop = node.get_input_size() - 1;
} else if (node.get_op_type() == "Concat") {
axis_idx = 0;
concat_idx_start = 1;
concat_idx_stop = node.get_input_size();
// The difference between Concat and ConcatV2 is that
// axis is the first input for Concat
// and is the last input to ConcatV2
default_op_checks(node, 2, {"Concat", "ConcatV2"});
auto input_size = node.get_input_size();
int64_t axis;
OutputVector inputs;
if (node.get_op_type() == "Concat") {
std::vector<int64_t> axis_vector;
get_const_input(node, 0, &axis_vector);
TENSORFLOW_OP_VALIDATION(
node,
axis_vector.size() == 1,
"Input model is incorrect: axis input for Concat operation must have exactly one element.");
axis = axis_vector[0];
for (size_t input_idx = 1; input_idx < input_size; ++input_idx) {
inputs.push_back(node.get_input(input_idx));
}
} else if (node.get_op_type() == "ConcatV2") {
std::vector<int64_t> axis_vector;
get_const_input(node, input_size - 1, &axis_vector);
TENSORFLOW_OP_VALIDATION(
node,
axis_vector.size() == 1,
"Input model is incorrect: axis input for Concat operation must have exactly one element.");
axis = axis_vector[0];
for (size_t input_idx = 0; input_idx < input_size - 1; ++input_idx) {
inputs.push_back(node.get_input(input_idx));
}
} else {
TENSORFLOW_OP_VALIDATION(node, false, "Incorrect operation type.");
TENSORFLOW_OP_VALIDATION(node,
false,
"Internal TensorFlow Frontend error: incorrect operation type is passed to "
"translate_concat_op function.");
}
std::vector<int64_t> tf_concat_axis_vec;
get_const_input(node, axis_idx, &tf_concat_axis_vec);
int64_t concat_axis = tf_concat_axis_vec[0];
OutputVector ng_args;
for (size_t i = concat_idx_start; i < concat_idx_stop; i++) {
Output<Node> ng_arg = node.get_input(static_cast<int>(i));
ng_args.push_back(ng_arg);
}
auto res = make_shared<Concat>(ng_args, size_t(concat_axis));
set_node_name(node.get_name(), res);
return res->outputs();
auto concat = make_shared<Concat>(inputs, axis);
set_node_name(node.get_name(), concat);
return {concat};
}
} // namespace op
} // namespace tensorflow

View File

@ -4,55 +4,32 @@
import pytest
from common.tf_layer_test_class import CommonTFLayerTest
from common.utils.tf_utils import permute_nchw_to_nhwc
class TestConcat(CommonTFLayerTest):
def create_concat_net(self, shape, axis, ir_version, use_new_frontend):
"""
Tensorflow net IR net
Input->Concat => Input->Concat
"""
def create_concat_net(self, input_shapes, axis, is_v2, ir_version, use_new_frontend):
# tf.concat is equivalent to tf.raw_ops.ConcatV2
# only tf.concat accepts one input
import tensorflow as tf
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
ax = axis
tf_x_shape = shape.copy()
tf_x_shape = permute_nchw_to_nhwc(tf_x_shape, use_new_frontend)
# TODO: add concat with const inputs to check fusing (as in ONNX)
x = tf.compat.v1.placeholder(tf.float32, tf_x_shape, 'Input')
y = tf.compat.v1.placeholder(tf.float32, tf_x_shape, 'Input') # Input_1 in graph_def
concat = tf.concat([x, y], axis=ax, name='Operation')
concat_shape = concat.shape.as_list()
placeholders = []
for ind, input_shape in enumerate(input_shapes):
placeholders.append(tf.compat.v1.placeholder(tf.float32, input_shape, 'input_{}'.format(ind)))
if len(input_shapes) == 1:
tf.concat(values=placeholders, axis=axis, name='concat')
elif is_v2:
tf.raw_ops.ConcatV2(values=placeholders, axis=axis, name='concat')
else:
tf.raw_ops.Concat(values=placeholders, concat_dim=axis, name='concat')
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
#
# Create reference IR net
# Please, specify 'type': 'Input' for input node
# Moreover, do not forget to validate ALL layer attributes!!!
#
ref_net = None
return tf_net, ref_net
# TODO: create tests for concat with 1 input and multiple inputs
test_data_1D = [dict(shape=[1], axis=0),
dict(shape=[1], axis=-1)]
test_data_1D = [dict(input_shapes=[[1], [2]], axis=0, is_v2=False),
dict(input_shapes=[[1], [3]], axis=-1, is_v2=True)]
@pytest.mark.parametrize("params", test_data_1D)
@pytest.mark.nightly
@ -63,8 +40,8 @@ class TestConcat(CommonTFLayerTest):
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
test_data_2D = [dict(shape=[1, 224], axis=0),
dict(shape=[1, 224], axis=-1)]
test_data_2D = [dict(input_shapes=[[3, 4]], axis=0, is_v2=True),
dict(input_shapes=[[1, 4], [1, 2]], axis=-1, is_v2=True)]
@pytest.mark.parametrize("params", test_data_2D)
@pytest.mark.nightly
@ -76,12 +53,13 @@ class TestConcat(CommonTFLayerTest):
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
test_data_3D = [
dict(shape=[1, 3, 224], axis=0),
pytest.param(dict(shape=[1, 3, 224], axis=-1), marks=pytest.mark.precommit_tf_fe),
dict(shape=[1, 3, 224], axis=2)]
dict(input_shapes=[[1, 3, 5], [2, 3, 5]], axis=0, is_v2=False),
dict(input_shapes=[[1, 3, 2], [1, 3, 5]], axis=-1, is_v2=True),
dict(input_shapes=[[1, 3, 1], [1, 3, 4], [1, 3, 3]], axis=2, is_v2=True)]
@pytest.mark.parametrize("params", test_data_3D)
@pytest.mark.nightly
@pytest.mark.precommit_tf_fe
def test_concat_3D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend,
use_old_api):
self._test(*self.create_concat_net(**params, ir_version=ir_version,
@ -89,9 +67,9 @@ class TestConcat(CommonTFLayerTest):
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
test_data_4D = [dict(shape=[1, 3, 100, 224], axis=0),
dict(shape=[1, 3, 100, 224], axis=-1),
dict(shape=[1, 3, 100, 224], axis=2)]
test_data_4D = [dict(input_shapes=[[1, 3, 5, 7], [3, 3, 5, 7], [2, 3, 5, 7]], axis=0, is_v2=False),
dict(input_shapes=[[1, 3, 5, 5], [1, 3, 5, 7]], axis=-1, is_v2=True),
dict(input_shapes=[[1, 3, 5, 7], [1, 3, 3, 7]], axis=2, is_v2=False)]
@pytest.mark.parametrize("params", test_data_4D)
@pytest.mark.nightly
@ -103,9 +81,10 @@ class TestConcat(CommonTFLayerTest):
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
test_data_5D = [dict(shape=[1, 3, 50, 100, 224], axis=0),
dict(shape=[1, 3, 50, 100, 224], axis=-1),
dict(shape=[1, 3, 50, 100, 224], axis=2)]
test_data_5D = [dict(input_shapes=[[1, 3, 5, 7, 8], [2, 3, 5, 7, 8]], axis=0, is_v2=True),
dict(input_shapes=[[1, 3, 5, 7, 2], [1, 3, 5, 7, 3], [1, 3, 5, 7, 2], [1, 3, 5, 7, 4]],
axis=-1, is_v2=True),
dict(input_shapes=[[1, 3, 5, 7, 8], [1, 3, 2, 7, 8]], axis=2, is_v2=False)]
@pytest.mark.parametrize("params", test_data_5D)
@pytest.mark.nightly