From b544308616224615a1c614ec89bf04e3f1e537b2 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Mon, 6 Feb 2023 19:05:39 +0400 Subject: [PATCH] [TF FE] Refactor Unpack and add layer test (#15519) * [TF FE] Refactor Unpack and add layer test Signed-off-by: Kazantsev, Roman * Update tests/layer_tests/tensorflow_tests/test_tf_Unpack.py --------- Signed-off-by: Kazantsev, Roman --- .../tensorflow_common/src/op/unpack.cpp | 21 +++---- .../tensorflow_tests/test_tf_Unpack.py | 55 +++++++++++++++++++ 2 files changed, 64 insertions(+), 12 deletions(-) create mode 100644 tests/layer_tests/tensorflow_tests/test_tf_Unpack.py diff --git a/src/frontends/tensorflow_common/src/op/unpack.cpp b/src/frontends/tensorflow_common/src/op/unpack.cpp index 53d6581b7b8..1896083993c 100644 --- a/src/frontends/tensorflow_common/src/op/unpack.cpp +++ b/src/frontends/tensorflow_common/src/op/unpack.cpp @@ -14,23 +14,20 @@ namespace tensorflow { namespace op { OutputVector translate_unpack_op(const NodeContext& node) { - TENSORFLOW_OP_VALIDATION(node, node.get_input_size() > 0, "Unpack must have at least one input."); - auto input = node.get_input(0); + default_op_checks(node, 1, {"Unpack", "UNPACK"}); + auto value = node.get_input(0); auto axis = node.get_attribute("axis", 0); auto num = node.get_attribute("num"); auto axis_const = make_shared(element::i64, Shape{}, axis); - auto split = make_shared(input, axis_const, num); - OutputVector res; - int idx = 0; - for (auto out : split->outputs()) { - auto squeezed_res = make_shared(out, axis_const); - squeezed_res->set_friendly_name(node.get_name() + "/squeeze_" + to_string(idx)); - set_out_name(node.get_name() + ":" + std::to_string(idx), squeezed_res->output(0)); - ++idx; - res.push_back(squeezed_res); + auto split = make_shared(value, axis_const, num); + OutputVector unpack_outputs; + for (int output_ind = 0; output_ind < num; ++output_ind) { + auto unpack_output = make_shared(split->output(output_ind), axis_const); + set_out_name(node.get_name() + ":" + to_string(output_ind), unpack_output); + unpack_outputs.push_back(unpack_output); } - return res; + return unpack_outputs; } } // namespace op } // namespace tensorflow diff --git a/tests/layer_tests/tensorflow_tests/test_tf_Unpack.py b/tests/layer_tests/tensorflow_tests/test_tf_Unpack.py new file mode 100644 index 00000000000..04db4ad0786 --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_Unpack.py @@ -0,0 +1,55 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import tensorflow as tf +from common.tf_layer_test_class import CommonTFLayerTest + + +class TestUnpack(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + assert 'x' in inputs_info, "Test error: inputs_info must contain `x`" + x_shape = inputs_info['x'] + inputs_data = {} + inputs_data['x'] = np.random.randint(-10, 10, x_shape).astype(self.input_type) + return inputs_data + + def create_unpack_net(self, input_shape, num, axis, input_type): + self.input_type = input_type + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + type_map = { + np.float32: tf.float32, + np.int32: tf.int32, + } + assert input_type in type_map, "Test error: need to update type_map" + tf_type = type_map[input_type] + x = tf.compat.v1.placeholder(tf_type, input_shape, 'x') + if axis is not None: + unpack = tf.raw_ops.Unpack(value=x, num=num, axis=axis) + else: + unpack = tf.raw_ops.Unpack(value=x, num=num) + for ind in range(num): + tf.identity(unpack[ind], name="output_" + str(ind)) + tf.compat.v1.global_variables_initializer() + + tf_net = sess.graph_def + + return tf_net, None + + test_data_basic = [ + dict(input_shape=[2, 3, 4], num=3, axis=1, input_type=np.float32), + dict(input_shape=[3, 4], num=3, axis=None, input_type=np.int32), + dict(input_shape=[4, 2, 3], num=2, axis=-2, input_type=np.float32), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_unpack_basic(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_unpack_net(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api)