[TF FE] Refactor ZerosLike and add layer test (#15648)
* [TF FE] Refactor ZerosLike and add layer test Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Fix test for Wide and Deep model --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
@@ -91,8 +91,14 @@ ov::frontend::tensorflow::pass::EmbeddingSegmentSingleFeatureFusion::EmbeddingSe
|
||||
false);
|
||||
auto tile = make_shared<Tile>(reshape, pack);
|
||||
|
||||
auto zeros_like = make_shared<Broadcast>(make_shared<Constant>(ov::element::f32, Shape{1}, std::vector<int64_t>{0}),
|
||||
make_shared<ShapeOf>(sparse_segment_op));
|
||||
auto zero_int_const = make_shared<Constant>(element::i32, Shape{1}, 0);
|
||||
auto one_int_const = make_shared<Constant>(element::i32, Shape{1}, 1);
|
||||
Output<Node> shape_of = make_shared<ShapeOf>(sparse_segment_op, element::i32);
|
||||
shape_of = make_shared<Concat>(OutputVector{one_int_const, shape_of}, 0);
|
||||
|
||||
Output<Node> zeros_like =
|
||||
make_shared<Broadcast>(make_shared<Constant>(ov::element::f32, Shape{1}, std::vector<int64_t>{0}), shape_of);
|
||||
zeros_like = make_shared<Squeeze>(zeros_like, zero_int_const);
|
||||
|
||||
// compute number of dimensions to unsqueeze the condition
|
||||
auto cond_rank = compute_subgraph_scalar_rank(tile, element::i32);
|
||||
|
||||
@@ -14,12 +14,25 @@ namespace tensorflow {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_zeros_like_op(const NodeContext& node) {
|
||||
default_op_checks(node, 1, {"ZerosLike", "ZEROS_LIKE"});
|
||||
auto x = node.get_input(0);
|
||||
auto shape_of = make_shared<ShapeOf>(x);
|
||||
auto zero = make_shared<Constant>(x.get_element_type(), Shape{1}, 0);
|
||||
auto res = make_shared<Broadcast>(zero, shape_of);
|
||||
set_node_name(node.get_name(), res);
|
||||
return res->outputs();
|
||||
Output<Node> shape_of = make_shared<ShapeOf>(x, element::i32);
|
||||
auto zero_const = make_shared<Constant>(x.get_element_type(), Shape{1}, 0);
|
||||
|
||||
// in case of x to be scalar, we need handle it more specifically
|
||||
// since Broadcast supports only broadcasting to rank greater 0
|
||||
// we have to introduce extra dimension for input scalar case
|
||||
auto zero_int_const = make_shared<Constant>(element::i32, Shape{1}, 0);
|
||||
auto one_int_const = make_shared<Constant>(element::i32, Shape{1}, 1);
|
||||
shape_of = make_shared<Concat>(OutputVector{one_int_const, shape_of}, 0);
|
||||
|
||||
// create a tensor of zeros of shape with extra dimension
|
||||
Output<Node> zeros_like = make_shared<Broadcast>(zero_const, shape_of);
|
||||
// remove extra dimension by squeezing
|
||||
zeros_like = make_shared<Squeeze>(zeros_like, zero_int_const);
|
||||
|
||||
set_node_name(node.get_name(), zeros_like.get_node_shared_ptr());
|
||||
return {zeros_like};
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
|
||||
35
tests/layer_tests/tensorflow_tests/test_tf_ZerosLike.py
Normal file
35
tests/layer_tests/tensorflow_tests/test_tf_ZerosLike.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import tensorflow as tf
|
||||
from common.tf_layer_test_class import CommonTFLayerTest
|
||||
|
||||
|
||||
class TestZerosLike(CommonTFLayerTest):
|
||||
def create_zeros_like_net(self, x_shape):
|
||||
tf.compat.v1.reset_default_graph()
|
||||
# Create the graph and model
|
||||
with tf.compat.v1.Session() as sess:
|
||||
x = tf.compat.v1.placeholder(tf.float32, x_shape, 'x')
|
||||
tf.raw_ops.ZerosLike(x=x)
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
tf_net = sess.graph_def
|
||||
|
||||
return tf_net, None
|
||||
|
||||
test_data_basic = [
|
||||
dict(x_shape=[]),
|
||||
dict(x_shape=[3]),
|
||||
dict(x_shape=[2, 1, 4]),
|
||||
dict(x_shape=[2, 4, 3, 1]),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data_basic)
|
||||
@pytest.mark.precommit_tf_fe
|
||||
@pytest.mark.nightly
|
||||
def test_zeros_like_basic(self, params, ie_device, precision, ir_version, temp_dir,
|
||||
use_new_frontend, use_old_api):
|
||||
self._test(*self.create_zeros_like_net(**params),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
||||
Reference in New Issue
Block a user