diff --git a/src/frontends/tensorflow_common/src/helper_transforms/embedding_segments_feature_fusing.cpp b/src/frontends/tensorflow_common/src/helper_transforms/embedding_segments_feature_fusing.cpp index f4cf1b0c02f..91fa1d4be4d 100644 --- a/src/frontends/tensorflow_common/src/helper_transforms/embedding_segments_feature_fusing.cpp +++ b/src/frontends/tensorflow_common/src/helper_transforms/embedding_segments_feature_fusing.cpp @@ -91,8 +91,14 @@ ov::frontend::tensorflow::pass::EmbeddingSegmentSingleFeatureFusion::EmbeddingSe false); auto tile = make_shared(reshape, pack); - auto zeros_like = make_shared(make_shared(ov::element::f32, Shape{1}, std::vector{0}), - make_shared(sparse_segment_op)); + auto zero_int_const = make_shared(element::i32, Shape{1}, 0); + auto one_int_const = make_shared(element::i32, Shape{1}, 1); + Output shape_of = make_shared(sparse_segment_op, element::i32); + shape_of = make_shared(OutputVector{one_int_const, shape_of}, 0); + + Output zeros_like = + make_shared(make_shared(ov::element::f32, Shape{1}, std::vector{0}), shape_of); + zeros_like = make_shared(zeros_like, zero_int_const); // compute number of dimensions to unsqueeze the condition auto cond_rank = compute_subgraph_scalar_rank(tile, element::i32); diff --git a/src/frontends/tensorflow_common/src/op/zeros_like.cpp b/src/frontends/tensorflow_common/src/op/zeros_like.cpp index 4baf3ec49f0..19f794e7b01 100644 --- a/src/frontends/tensorflow_common/src/op/zeros_like.cpp +++ b/src/frontends/tensorflow_common/src/op/zeros_like.cpp @@ -14,12 +14,25 @@ namespace tensorflow { namespace op { OutputVector translate_zeros_like_op(const NodeContext& node) { + default_op_checks(node, 1, {"ZerosLike", "ZEROS_LIKE"}); auto x = node.get_input(0); - auto shape_of = make_shared(x); - auto zero = make_shared(x.get_element_type(), Shape{1}, 0); - auto res = make_shared(zero, shape_of); - set_node_name(node.get_name(), res); - return res->outputs(); + Output shape_of = make_shared(x, element::i32); + auto zero_const = make_shared(x.get_element_type(), Shape{1}, 0); + + // in case of x to be scalar, we need handle it more specifically + // since Broadcast supports only broadcasting to rank greater 0 + // we have to introduce extra dimension for input scalar case + auto zero_int_const = make_shared(element::i32, Shape{1}, 0); + auto one_int_const = make_shared(element::i32, Shape{1}, 1); + shape_of = make_shared(OutputVector{one_int_const, shape_of}, 0); + + // create a tensor of zeros of shape with extra dimension + Output zeros_like = make_shared(zero_const, shape_of); + // remove extra dimension by squeezing + zeros_like = make_shared(zeros_like, zero_int_const); + + set_node_name(node.get_name(), zeros_like.get_node_shared_ptr()); + return {zeros_like}; } } // namespace op diff --git a/tests/layer_tests/tensorflow_tests/test_tf_ZerosLike.py b/tests/layer_tests/tensorflow_tests/test_tf_ZerosLike.py new file mode 100644 index 00000000000..ad8ba15a383 --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_ZerosLike.py @@ -0,0 +1,35 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import tensorflow as tf +from common.tf_layer_test_class import CommonTFLayerTest + + +class TestZerosLike(CommonTFLayerTest): + def create_zeros_like_net(self, x_shape): + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + x = tf.compat.v1.placeholder(tf.float32, x_shape, 'x') + tf.raw_ops.ZerosLike(x=x) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + return tf_net, None + + test_data_basic = [ + dict(x_shape=[]), + dict(x_shape=[3]), + dict(x_shape=[2, 1, 4]), + dict(x_shape=[2, 4, 3, 1]), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_zeros_like_basic(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_zeros_like_net(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api)