[TF FE][TF Hub] Support Xlogy operation (#20467)
* [TF FE][TF Hub] Support Xlogy operation * fix * fix * fix * fix * Update tests/layer_tests/tensorflow_tests/test_tf_Xlogy.py * Update tests/layer_tests/tensorflow_tests/test_tf_Xlogy.py --------- Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
@@ -286,6 +286,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"While", CreatorFunction(translate_while_op)},
|
||||
{"Where", CreatorFunction(translate_where_op)},
|
||||
{"Xdivy", CreatorFunction(translate_x_div_y_op)},
|
||||
{"Xlogy", CreatorFunction(translate_xlogy_op)},
|
||||
{"ZerosLike", CreatorFunction(translate_zeros_like_op)},
|
||||
|
||||
// Translators for SavedModel and MetaGraph
|
||||
|
||||
@@ -149,6 +149,7 @@ OP_CONVERTER(translate_unravel_index_op);
|
||||
OP_CONVERTER(translate_unsorted_segment_sum_op);
|
||||
OP_CONVERTER(translate_where_op);
|
||||
OP_CONVERTER(translate_x_div_y_op);
|
||||
OP_CONVERTER(translate_xlogy_op);
|
||||
OP_CONVERTER(translate_zeros_like_op);
|
||||
|
||||
// Translators for internal operations
|
||||
|
||||
43
src/frontends/tensorflow_common/src/op/xlogy.cpp
Normal file
43
src/frontends/tensorflow_common/src/op/xlogy.cpp
Normal file
@@ -0,0 +1,43 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "common_op_table.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/convert_like.hpp"
|
||||
#include "openvino/op/equal.hpp"
|
||||
#include "openvino/op/log.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "openvino/op/select.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov::opset10;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace tensorflow {
|
||||
namespace op {
|
||||
OutputVector translate_xlogy_op(const NodeContext& node) {
|
||||
default_op_checks(node, 2, {"Xlogy"});
|
||||
auto x = node.get_input(0);
|
||||
auto y = node.get_input(1);
|
||||
|
||||
// prepare auxiliary zero constant of the same type as the input
|
||||
auto zero = create_same_type_const_scalar<int32_t>(x, 0);
|
||||
|
||||
// compute a mask to identify where x is equal to 0
|
||||
auto is_zero = make_shared<Equal>(x, zero);
|
||||
|
||||
// compute x * log(y) elementwise
|
||||
auto xlog_y = make_shared<Multiply>(x, make_shared<Log>(y));
|
||||
|
||||
// create the output tensor using Select to handle the x == 0 condition
|
||||
auto result = make_shared<Select>(is_zero, zero, xlog_y);
|
||||
|
||||
set_node_name(node.get_name(), result);
|
||||
return result->outputs();
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
||||
49
tests/layer_tests/tensorflow_tests/test_tf_Xlogy.py
Normal file
49
tests/layer_tests/tensorflow_tests/test_tf_Xlogy.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import tensorflow as tf
|
||||
from common.tf_layer_test_class import CommonTFLayerTest
|
||||
|
||||
|
||||
class TestXlogy(CommonTFLayerTest):
|
||||
def _prepare_input(self, inputs_info):
|
||||
assert 'x' in inputs_info
|
||||
assert 'y' in inputs_info
|
||||
x_shape = inputs_info['x']
|
||||
y_shape = inputs_info['y']
|
||||
inputs_data = {}
|
||||
# x = [-3 ,3] y = [1, 2]
|
||||
# generate x in way to have zeros
|
||||
inputs_data['x'] = (6 * np.random.random(size=x_shape).astype(np.float32) - 3) * \
|
||||
np.random.randint(2, size=x_shape).astype(np.float32)
|
||||
inputs_data['y'] = np.random.random(size=y_shape).astype(np.float32) + 1
|
||||
return inputs_data
|
||||
|
||||
def create_xlogy_net(self, input_shape, input_type):
|
||||
self.input_type = input_type
|
||||
tf.compat.v1.reset_default_graph()
|
||||
# Create the graph and model
|
||||
with tf.compat.v1.Session() as sess:
|
||||
x = tf.compat.v1.placeholder(input_type, input_shape, 'x')
|
||||
y = tf.compat.v1.placeholder(input_type, input_shape, 'y')
|
||||
tf.raw_ops.Xlogy(x=x, y=y)
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
tf_net = sess.graph_def
|
||||
|
||||
return tf_net, None
|
||||
|
||||
test_data_basic = [
|
||||
dict(input_shape=[10, 20], input_type=np.float32),
|
||||
dict(input_shape=[2, 3, 4], input_type=np.float32),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data_basic)
|
||||
@pytest.mark.precommit_tf_fe
|
||||
@pytest.mark.nightly
|
||||
def test_xlogy_basic(self, params, ie_device, precision, ir_version, temp_dir,
|
||||
use_new_frontend, use_old_api):
|
||||
self._test(*self.create_xlogy_net(**params),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
||||
Reference in New Issue
Block a user