diff --git a/src/frontends/tensorflow_common/src/op/split.cpp b/src/frontends/tensorflow_common/src/op/split.cpp index fd1dafd580a..f3795e35837 100644 --- a/src/frontends/tensorflow_common/src/op/split.cpp +++ b/src/frontends/tensorflow_common/src/op/split.cpp @@ -14,25 +14,25 @@ namespace tensorflow { namespace op { OutputVector translate_split_op(const NodeContext& node) { - TENSORFLOW_OP_VALIDATION(node, node.get_input_size() > 1, "Split must have at least two inputs."); + default_op_checks(node, 2, {"Split", "SPLIT"}); auto axis = node.get_input(0); - auto input = node.get_input(1); + auto value = node.get_input(1); auto num_split = node.get_attribute("num_split"); - auto res = make_shared(input, axis, num_split); - set_node_name(node.get_name(), res); - return res->outputs(); + auto split = make_shared(value, axis, num_split); + set_node_name(node.get_name(), split); + return split->outputs(); } OutputVector translate_split_v_op(const NodeContext& node) { - TENSORFLOW_OP_VALIDATION(node, node.get_input_size() > 2, "Split must have at least three inputs."); - auto input = node.get_input(0); - auto split_lengths = node.get_input(1); + default_op_checks(node, 3, {"SplitV", "SPLIT_V"}); + auto value = node.get_input(0); + auto size_splits = node.get_input(1); auto axis = node.get_input(2); - auto res = make_shared(input, axis, split_lengths); - set_node_name(node.get_name(), res); - return res->outputs(); + auto splitv = make_shared(value, axis, size_splits); + set_node_name(node.get_name(), splitv); + return splitv->outputs(); } } // namespace op } // namespace tensorflow diff --git a/tests/layer_tests/tensorflow_tests/test_tf_Split.py b/tests/layer_tests/tensorflow_tests/test_tf_Split.py new file mode 100644 index 00000000000..c221bb896c3 --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_Split.py @@ -0,0 +1,37 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import tensorflow as tf +from common.tf_layer_test_class import CommonTFLayerTest + + +class TestSplit(CommonTFLayerTest): + def create_split_net(self, value_shape, axis_value, num_split): + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + axis = tf.constant(axis_value, dtype=tf.int32) + value = tf.compat.v1.placeholder(tf.float32, value_shape, 'value') + split = tf.raw_ops.Split(axis=axis, value=value, num_split=num_split) + for output_ind in range(num_split): + tf.identity(split[output_ind], name="split_" + str(output_ind)) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_data_basic = [ + dict(value_shape=[6], axis_value=0, num_split=2), + dict(value_shape=[2, 1, 6], axis_value=2, num_split=3), + dict(value_shape=[4, 3, 2, 7], axis_value=-4, num_split=4), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_split_basic(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_split_net(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) diff --git a/tests/layer_tests/tensorflow_tests/test_tf_SplitV.py b/tests/layer_tests/tensorflow_tests/test_tf_SplitV.py new file mode 100644 index 00000000000..8d3328fe839 --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_SplitV.py @@ -0,0 +1,40 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import tensorflow as tf +from common.tf_layer_test_class import CommonTFLayerTest + + +class TestSplitV(CommonTFLayerTest): + def create_splitv_net(self, value_shape, size_splits_values, axis_value): + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + axis = tf.constant(axis_value, dtype=tf.int32) + size_splits = tf.constant(size_splits_values, dtype=tf.int32) + value = tf.compat.v1.placeholder(tf.float32, value_shape, 'value') + num_split = len(size_splits_values) + splitv = tf.raw_ops.SplitV(value=value, size_splits=size_splits, axis=axis, num_split=num_split) + for output_ind in range(num_split): + if size_splits_values[output_ind] != 0: + tf.identity(splitv[output_ind], name="split_" + str(output_ind)) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_data_basic = [ + dict(value_shape=[3], size_splits_values=[0, 2, 0, 1], axis_value=0), + dict(value_shape=[2, 3, 9], size_splits_values=[1, 2, 3, -1, 1], axis_value=2), + dict(value_shape=[3, 9, 5, 4], size_splits_values=[1, 2, 0, -1, 2, 0], axis_value=-3), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_split_basic(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_splitv_net(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api)