From 22f3bee0f9d003c570d6f005fffe7073f49dc933 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Thu, 20 Oct 2022 17:01:31 +0300 Subject: [PATCH] [TF FE] Refactor translation for Concat and ConcatV2 operations and add tests (#13542) * [TF FE] Refactor translation for Concat and ConcatV2 operations and add tests Signed-off-by: Kazantsev, Roman * [TF FE] Correct the layer tests for passing * Apply code-review feedback: remove redundant function * Fix a name of function to create Concat model in the layer test Signed-off-by: Kazantsev, Roman --- src/frontends/tensorflow/src/op/concat.cpp | 64 ++++++++------ .../tensorflow_tests/test_tf_Concat.py | 85 +++++++------------ 2 files changed, 71 insertions(+), 78 deletions(-) diff --git a/src/frontends/tensorflow/src/op/concat.cpp b/src/frontends/tensorflow/src/op/concat.cpp index 3fb8d9b8147..82f450f3086 100644 --- a/src/frontends/tensorflow/src/op/concat.cpp +++ b/src/frontends/tensorflow/src/op/concat.cpp @@ -15,36 +15,50 @@ namespace ov { namespace frontend { namespace tensorflow { namespace op { - OutputVector translate_concat_op(const NodeContext& node) { - size_t axis_idx, concat_idx_start, concat_idx_stop; - if (node.get_op_type() == "ConcatV2") { - axis_idx = node.get_input_size() - 1; - concat_idx_start = 0; - concat_idx_stop = node.get_input_size() - 1; - } else if (node.get_op_type() == "Concat") { - axis_idx = 0; - concat_idx_start = 1; - concat_idx_stop = node.get_input_size(); + // The difference between Concat and ConcatV2 is that + // axis is the first input for Concat + // and is the last input to ConcatV2 + default_op_checks(node, 2, {"Concat", "ConcatV2"}); + auto input_size = node.get_input_size(); + + int64_t axis; + OutputVector inputs; + + if (node.get_op_type() == "Concat") { + std::vector axis_vector; + get_const_input(node, 0, &axis_vector); + TENSORFLOW_OP_VALIDATION( + node, + axis_vector.size() == 1, + "Input model is incorrect: axis input for Concat operation must have exactly one element."); + axis = axis_vector[0]; + for (size_t input_idx = 1; input_idx < input_size; ++input_idx) { + inputs.push_back(node.get_input(input_idx)); + } + } else if (node.get_op_type() == "ConcatV2") { + std::vector axis_vector; + get_const_input(node, input_size - 1, &axis_vector); + TENSORFLOW_OP_VALIDATION( + node, + axis_vector.size() == 1, + "Input model is incorrect: axis input for Concat operation must have exactly one element."); + axis = axis_vector[0]; + for (size_t input_idx = 0; input_idx < input_size - 1; ++input_idx) { + inputs.push_back(node.get_input(input_idx)); + } } else { - TENSORFLOW_OP_VALIDATION(node, false, "Incorrect operation type."); + TENSORFLOW_OP_VALIDATION(node, + false, + "Internal TensorFlow Frontend error: incorrect operation type is passed to " + "translate_concat_op function."); } - std::vector tf_concat_axis_vec; - get_const_input(node, axis_idx, &tf_concat_axis_vec); - int64_t concat_axis = tf_concat_axis_vec[0]; - - OutputVector ng_args; - for (size_t i = concat_idx_start; i < concat_idx_stop; i++) { - Output ng_arg = node.get_input(static_cast(i)); - ng_args.push_back(ng_arg); - } - - auto res = make_shared(ng_args, size_t(concat_axis)); - set_node_name(node.get_name(), res); - return res->outputs(); + auto concat = make_shared(inputs, axis); + set_node_name(node.get_name(), concat); + return {concat}; } } // namespace op } // namespace tensorflow } // namespace frontend -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/tests/layer_tests/tensorflow_tests/test_tf_Concat.py b/tests/layer_tests/tensorflow_tests/test_tf_Concat.py index 2cdbb50a96a..dc537b43269 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_Concat.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_Concat.py @@ -4,94 +4,72 @@ import pytest from common.tf_layer_test_class import CommonTFLayerTest -from common.utils.tf_utils import permute_nchw_to_nhwc class TestConcat(CommonTFLayerTest): - def create_concat_net(self, shape, axis, ir_version, use_new_frontend): - """ - Tensorflow net IR net - - Input->Concat => Input->Concat - - """ - + def create_concat_net(self, input_shapes, axis, is_v2, ir_version, use_new_frontend): + # tf.concat is equivalent to tf.raw_ops.ConcatV2 + # only tf.concat accepts one input import tensorflow as tf - tf.compat.v1.reset_default_graph() - - # Create the graph and model with tf.compat.v1.Session() as sess: - ax = axis - - tf_x_shape = shape.copy() - - tf_x_shape = permute_nchw_to_nhwc(tf_x_shape, use_new_frontend) - - # TODO: add concat with const inputs to check fusing (as in ONNX) - - x = tf.compat.v1.placeholder(tf.float32, tf_x_shape, 'Input') - y = tf.compat.v1.placeholder(tf.float32, tf_x_shape, 'Input') # Input_1 in graph_def - - concat = tf.concat([x, y], axis=ax, name='Operation') - concat_shape = concat.shape.as_list() - + placeholders = [] + for ind, input_shape in enumerate(input_shapes): + placeholders.append(tf.compat.v1.placeholder(tf.float32, input_shape, 'input_{}'.format(ind))) + if len(input_shapes) == 1: + tf.concat(values=placeholders, axis=axis, name='concat') + elif is_v2: + tf.raw_ops.ConcatV2(values=placeholders, axis=axis, name='concat') + else: + tf.raw_ops.Concat(values=placeholders, concat_dim=axis, name='concat') tf.compat.v1.global_variables_initializer() tf_net = sess.graph_def - # - # Create reference IR net - # Please, specify 'type': 'Input' for input node - # Moreover, do not forget to validate ALL layer attributes!!! - # - ref_net = None - return tf_net, ref_net - # TODO: create tests for concat with 1 input and multiple inputs - - test_data_1D = [dict(shape=[1], axis=0), - dict(shape=[1], axis=-1)] + test_data_1D = [dict(input_shapes=[[1], [2]], axis=0, is_v2=False), + dict(input_shapes=[[1], [3]], axis=-1, is_v2=True)] @pytest.mark.parametrize("params", test_data_1D) @pytest.mark.nightly def test_concat_1D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend, use_old_api): self._test(*self.create_concat_net(**params, ir_version=ir_version, - use_new_frontend=use_new_frontend), + use_new_frontend=use_new_frontend), ie_device, precision, ir_version, temp_dir=temp_dir, use_new_frontend=use_new_frontend, use_old_api=use_old_api) - test_data_2D = [dict(shape=[1, 224], axis=0), - dict(shape=[1, 224], axis=-1)] + test_data_2D = [dict(input_shapes=[[3, 4]], axis=0, is_v2=True), + dict(input_shapes=[[1, 4], [1, 2]], axis=-1, is_v2=True)] @pytest.mark.parametrize("params", test_data_2D) @pytest.mark.nightly def test_concat_2D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend, use_old_api): self._test(*self.create_concat_net(**params, ir_version=ir_version, - use_new_frontend=use_new_frontend), + use_new_frontend=use_new_frontend), ie_device, precision, ir_version, temp_dir=temp_dir, use_new_frontend=use_new_frontend, use_old_api=use_old_api) test_data_3D = [ - dict(shape=[1, 3, 224], axis=0), - pytest.param(dict(shape=[1, 3, 224], axis=-1), marks=pytest.mark.precommit_tf_fe), - dict(shape=[1, 3, 224], axis=2)] + dict(input_shapes=[[1, 3, 5], [2, 3, 5]], axis=0, is_v2=False), + dict(input_shapes=[[1, 3, 2], [1, 3, 5]], axis=-1, is_v2=True), + dict(input_shapes=[[1, 3, 1], [1, 3, 4], [1, 3, 3]], axis=2, is_v2=True)] @pytest.mark.parametrize("params", test_data_3D) @pytest.mark.nightly + @pytest.mark.precommit_tf_fe def test_concat_3D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend, use_old_api): self._test(*self.create_concat_net(**params, ir_version=ir_version, - use_new_frontend=use_new_frontend), + use_new_frontend=use_new_frontend), ie_device, precision, ir_version, temp_dir=temp_dir, use_new_frontend=use_new_frontend, use_old_api=use_old_api) - test_data_4D = [dict(shape=[1, 3, 100, 224], axis=0), - dict(shape=[1, 3, 100, 224], axis=-1), - dict(shape=[1, 3, 100, 224], axis=2)] + test_data_4D = [dict(input_shapes=[[1, 3, 5, 7], [3, 3, 5, 7], [2, 3, 5, 7]], axis=0, is_v2=False), + dict(input_shapes=[[1, 3, 5, 5], [1, 3, 5, 7]], axis=-1, is_v2=True), + dict(input_shapes=[[1, 3, 5, 7], [1, 3, 3, 7]], axis=2, is_v2=False)] @pytest.mark.parametrize("params", test_data_4D) @pytest.mark.nightly @@ -99,19 +77,20 @@ class TestConcat(CommonTFLayerTest): def test_concat_4D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend, use_old_api): self._test(*self.create_concat_net(**params, ir_version=ir_version, - use_new_frontend=use_new_frontend), + use_new_frontend=use_new_frontend), ie_device, precision, ir_version, temp_dir=temp_dir, use_new_frontend=use_new_frontend, use_old_api=use_old_api) - test_data_5D = [dict(shape=[1, 3, 50, 100, 224], axis=0), - dict(shape=[1, 3, 50, 100, 224], axis=-1), - dict(shape=[1, 3, 50, 100, 224], axis=2)] + test_data_5D = [dict(input_shapes=[[1, 3, 5, 7, 8], [2, 3, 5, 7, 8]], axis=0, is_v2=True), + dict(input_shapes=[[1, 3, 5, 7, 2], [1, 3, 5, 7, 3], [1, 3, 5, 7, 2], [1, 3, 5, 7, 4]], + axis=-1, is_v2=True), + dict(input_shapes=[[1, 3, 5, 7, 8], [1, 3, 2, 7, 8]], axis=2, is_v2=False)] @pytest.mark.parametrize("params", test_data_5D) @pytest.mark.nightly def test_concat_5D(self, params, ie_device, precision, ir_version, temp_dir, use_new_frontend, use_old_api): self._test(*self.create_concat_net(**params, ir_version=ir_version, - use_new_frontend=use_new_frontend), + use_new_frontend=use_new_frontend), ie_device, precision, ir_version, temp_dir=temp_dir, use_new_frontend=use_new_frontend, use_old_api=use_old_api)