[TF FE][GitHub issue] Support Selu operation and add test (#19528)

* [TF FE] Support Selu operation and add test

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Fix layer test

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2023-09-01 13:09:58 +04:00 committed by GitHub
parent 9715ccd992
commit 2cf8f2bc1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 97 additions and 109 deletions

View File

@ -73,6 +73,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"Mish", CreatorFunction(translate_unary_op<opset8::Mish>)},
{"Neg", CreatorFunction(translate_unary_op<opset8::Negative>)},
{"Relu", CreatorFunction(translate_unary_op<opset8::Relu>)},
{"Selu", CreatorFunction(translate_selu_op)},
{"Sigmoid", CreatorFunction(translate_unary_op<opset8::Sigmoid>)},
{"Sin", CreatorFunction(translate_unary_op<opset8::Sin>)},
{"Sinh", CreatorFunction(translate_unary_op<opset8::Sinh>)},

View File

@ -28,6 +28,7 @@ namespace op {
OutputVector op(const ov::frontend::NodeContext& node)
OP_T_CONVERTER(translate_unary_op);
OP_CONVERTER(translate_selu_op);
OP_T_CONVERTER(translate_binary_op);
OP_T_CONVERTER(translate_direct_reduce_op);

View File

@ -3,12 +3,44 @@
//
#include "common_op_table.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/opsets/opset8.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/op/abs.hpp"
#include "openvino/op/acos.hpp"
#include "openvino/op/acosh.hpp"
#include "openvino/op/asin.hpp"
#include "openvino/op/asinh.hpp"
#include "openvino/op/atan.hpp"
#include "openvino/op/atanh.hpp"
#include "openvino/op/ceiling.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/cos.hpp"
#include "openvino/op/cosh.hpp"
#include "openvino/op/erf.hpp"
#include "openvino/op/exp.hpp"
#include "openvino/op/floor.hpp"
#include "openvino/op/hswish.hpp"
#include "openvino/op/is_finite.hpp"
#include "openvino/op/is_inf.hpp"
#include "openvino/op/is_nan.hpp"
#include "openvino/op/log.hpp"
#include "openvino/op/logical_not.hpp"
#include "openvino/op/mish.hpp"
#include "openvino/op/negative.hpp"
#include "openvino/op/relu.hpp"
#include "openvino/op/selu.hpp"
#include "openvino/op/sigmoid.hpp"
#include "openvino/op/sign.hpp"
#include "openvino/op/sin.hpp"
#include "openvino/op/sinh.hpp"
#include "openvino/op/softplus.hpp"
#include "openvino/op/softsign.hpp"
#include "openvino/op/swish.hpp"
#include "openvino/op/tan.hpp"
#include "openvino/op/tanh.hpp"
#include "utils.hpp"
using namespace std;
using namespace ov::opset8;
using namespace ov::frontend::tensorflow;
using namespace ov::op;
namespace ov {
namespace frontend {
@ -16,9 +48,9 @@ namespace tensorflow {
namespace op {
OutputVector translate_unary_op(const NodeContext& op,
const std::function<shared_ptr<Node>(Output<Node>)>& create_unary_op) {
auto ng_input = op.get_input(0);
auto res = create_unary_op(ng_input);
const function<shared_ptr<Node>(Output<Node>)>& create_unary_op) {
auto input = op.get_input(0);
auto res = create_unary_op(input);
set_node_name(op.get_name(), res);
return {res};
}
@ -30,37 +62,49 @@ OutputVector translate_unary_op(const NodeContext& node) {
});
}
template OutputVector translate_unary_op<Abs>(const NodeContext& node);
template OutputVector translate_unary_op<Acos>(const NodeContext& node);
template OutputVector translate_unary_op<Acosh>(const NodeContext& node);
template OutputVector translate_unary_op<Asin>(const NodeContext& node);
template OutputVector translate_unary_op<Asinh>(const NodeContext& node);
template OutputVector translate_unary_op<Atan>(const NodeContext& node);
template OutputVector translate_unary_op<Atanh>(const NodeContext& node);
template OutputVector translate_unary_op<Ceiling>(const NodeContext& node);
template OutputVector translate_unary_op<Cos>(const NodeContext& node);
template OutputVector translate_unary_op<Cosh>(const NodeContext& node);
template OutputVector translate_unary_op<Erf>(const NodeContext& node);
template OutputVector translate_unary_op<Exp>(const NodeContext& node);
template OutputVector translate_unary_op<Floor>(const NodeContext& node);
template OutputVector translate_unary_op<HSwish>(const NodeContext& node);
template OutputVector translate_unary_op<opset10::IsFinite>(const NodeContext& node);
template OutputVector translate_unary_op<opset10::IsInf>(const NodeContext& node);
template OutputVector translate_unary_op<opset10::IsNaN>(const NodeContext& node);
template OutputVector translate_unary_op<Log>(const NodeContext& node);
template OutputVector translate_unary_op<LogicalNot>(const NodeContext& node);
template OutputVector translate_unary_op<Mish>(const NodeContext& node);
template OutputVector translate_unary_op<Negative>(const NodeContext& node);
template OutputVector translate_unary_op<Relu>(const NodeContext& node);
template OutputVector translate_unary_op<Sigmoid>(const NodeContext& node);
template OutputVector translate_unary_op<Sin>(const NodeContext& node);
template OutputVector translate_unary_op<Sinh>(const NodeContext& node);
template OutputVector translate_unary_op<Sign>(const NodeContext& node);
template OutputVector translate_unary_op<SoftPlus>(const NodeContext& node);
template OutputVector translate_unary_op<Tan>(const NodeContext& node);
template OutputVector translate_unary_op<Tanh>(const NodeContext& node);
template OutputVector translate_unary_op<opset9::SoftSign>(const NodeContext& node);
template OutputVector translate_unary_op<Swish>(const NodeContext& node);
template OutputVector translate_unary_op<v0::Abs>(const NodeContext& node);
template OutputVector translate_unary_op<v0::Acos>(const NodeContext& node);
template OutputVector translate_unary_op<v3::Acosh>(const NodeContext& node);
template OutputVector translate_unary_op<v0::Asin>(const NodeContext& node);
template OutputVector translate_unary_op<v3::Asinh>(const NodeContext& node);
template OutputVector translate_unary_op<v0::Atan>(const NodeContext& node);
template OutputVector translate_unary_op<v3::Atanh>(const NodeContext& node);
template OutputVector translate_unary_op<v0::Ceiling>(const NodeContext& node);
template OutputVector translate_unary_op<v0::Cos>(const NodeContext& node);
template OutputVector translate_unary_op<v0::Cosh>(const NodeContext& node);
template OutputVector translate_unary_op<v0::Erf>(const NodeContext& node);
template OutputVector translate_unary_op<v0::Exp>(const NodeContext& node);
template OutputVector translate_unary_op<v0::Floor>(const NodeContext& node);
template OutputVector translate_unary_op<v4::HSwish>(const NodeContext& node);
template OutputVector translate_unary_op<v10::IsFinite>(const NodeContext& node);
template OutputVector translate_unary_op<v10::IsInf>(const NodeContext& node);
template OutputVector translate_unary_op<v10::IsNaN>(const NodeContext& node);
template OutputVector translate_unary_op<v0::Log>(const NodeContext& node);
template OutputVector translate_unary_op<v1::LogicalNot>(const NodeContext& node);
template OutputVector translate_unary_op<v4::Mish>(const NodeContext& node);
template OutputVector translate_unary_op<v0::Negative>(const NodeContext& node);
template OutputVector translate_unary_op<v0::Relu>(const NodeContext& node);
template OutputVector translate_unary_op<v0::Sigmoid>(const NodeContext& node);
template OutputVector translate_unary_op<v0::Sin>(const NodeContext& node);
template OutputVector translate_unary_op<v0::Sinh>(const NodeContext& node);
template OutputVector translate_unary_op<v0::Sign>(const NodeContext& node);
template OutputVector translate_unary_op<v4::SoftPlus>(const NodeContext& node);
template OutputVector translate_unary_op<v0::Tan>(const NodeContext& node);
template OutputVector translate_unary_op<v0::Tanh>(const NodeContext& node);
template OutputVector translate_unary_op<v9::SoftSign>(const NodeContext& node);
template OutputVector translate_unary_op<v4::Swish>(const NodeContext& node);
OutputVector translate_selu_op(const NodeContext& node) {
default_op_checks(node, 1, {"Selu"});
auto features = node.get_input(0);
// create pre-defined constants
auto alpha = create_same_type_const<float>(features, {1.67326324f}, Shape{1});
auto scale = create_same_type_const<float>(features, {1.05070098f}, Shape{1});
auto selu = make_shared<v0::Selu>(features, alpha, scale);
set_node_name(node.get_name(), selu);
return {selu};
}
} // namespace op
} // namespace tensorflow

View File

@ -1,22 +1,20 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import sys
import numpy as np
import pytest
import sys
from common.layer_test_class import check_ir_version
from common.tf_layer_test_class import CommonTFLayerTest
from common.utils.tf_utils import permute_nchw_to_nhwc
from unit_tests.utils.graph import build_graph
class TestUnaryOps(CommonTFLayerTest):
current_op_type = None
def _prepare_input(self, inputs_dict):
non_negative = ['Sqrt', 'Log']
narrow_borders = ["Sinh", "Cosh", "Tanh", "Exp"]
narrow_borders = ["Sinh", "Cosh", "Tanh", "Exp", "Selu"]
within_one = ['Asin', 'Acos', 'Atanh']
from_one = ['Acosh']
@ -48,13 +46,6 @@ class TestUnaryOps(CommonTFLayerTest):
return inputs_dict
def create_net_with_mish(self, shape, ir_version, use_new_frontend):
"""
TODO: Move functionality to `create_net_with_unary_op()` once tensorflow_addons
supports Python 3.11
Tensorflow net IR net
Input->mish => Input->mish
"""
import tensorflow as tf
import tensorflow_addons as tfa
@ -70,41 +61,11 @@ class TestUnaryOps(CommonTFLayerTest):
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
#
# Create reference IR net
# Please, specify 'type': 'Input' for input node
# Moreover, do not forget to validate ALL layer attributes!!!
#
ref_net = None
if check_ir_version(10, None, ir_version) and not use_new_frontend:
nodes_attributes = {
'input': {'kind': 'op', 'type': 'Parameter'},
'input_data': {'shape': shape, 'kind': 'data'},
'testing_op': {'kind': 'op', 'type': 'Mish'},
'testing_data': {'shape': shape, 'kind': 'data'},
'result': {'kind': 'op', 'type': 'Result'}
}
ref_net = build_graph(nodes_attributes,
[('input', 'input_data'),
('input_data', 'testing_op'),
('testing_op', 'testing_data'),
('testing_data', 'result')
])
return tf_net, ref_net
def create_net_with_unary_op(self, shape, ir_version, op_type, use_new_frontend):
"""
TODO: Move functionality of `create_net_with_mish()` here once tensorflow_addons
supports Python 3.11
Tensorflow net IR net
Input->UnaryOp => Input->UnaryOp
"""
import tensorflow as tf
self.current_op_type = op_type
@ -127,6 +88,7 @@ class TestUnaryOps(CommonTFLayerTest):
'LogicalNot': tf.math.logical_not,
# 'Mish': tfa.activations.mish, # temporarily moved to `create_net_with_mish()`
'Negative': tf.math.negative,
'Selu': tf.nn.selu,
'Sigmoid': tf.nn.sigmoid,
'Sign': tf.math.sign,
'Sin': tf.math.sin,
@ -158,30 +120,7 @@ class TestUnaryOps(CommonTFLayerTest):
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
#
# Create reference IR net
# Please, specify 'type': 'Input' for input node
# Moreover, do not forget to validate ALL layer attributes!!!
#
ref_net = None
if check_ir_version(10, None, ir_version) and not use_new_frontend:
nodes_attributes = {
'input': {'kind': 'op', 'type': 'Parameter'},
'input_data': {'shape': shape, 'kind': 'data'},
'testing_op': {'kind': 'op', 'type': self.current_op_type},
'testing_data': {'shape': shape, 'kind': 'data'},
'result': {'kind': 'op', 'type': 'Result'}
}
ref_net = build_graph(nodes_attributes,
[('input', 'input_data'),
('input_data', 'testing_op'),
('testing_op', 'testing_data'),
('testing_data', 'result')
])
return tf_net, ref_net
test_data_precommit = [dict(shape=[4, 6, 8, 10, 12])]
@ -224,7 +163,8 @@ class TestUnaryOps(CommonTFLayerTest):
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
@pytest.mark.xfail(sys.version_info > (3, 10), reason="tensorflow_addons package is not available for Python 3.11 and higher")
@pytest.mark.xfail(sys.version_info > (3, 10),
reason="tensorflow_addons package is not available for Python 3.11 and higher")
@pytest.mark.parametrize("params", test_data_precommit)
@pytest.mark.precommit
def test_unary_op_mish_precommit(self, params, ie_device, precision, ir_version, temp_dir,
@ -271,6 +211,7 @@ class TestUnaryOps(CommonTFLayerTest):
'Asinh',
'Square',
'Erf',
'Selu'
])
@pytest.mark.nightly
def test_unary_op(self, params, ie_device, precision, ir_version, temp_dir, op_type,
@ -282,7 +223,8 @@ class TestUnaryOps(CommonTFLayerTest):
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
@pytest.mark.xfail(sys.version_info > (3, 10), reason="tensorflow_addons package is not available for Python 3.11 and higher")
@pytest.mark.xfail(sys.version_info > (3, 10),
reason="tensorflow_addons package is not available for Python 3.11 and higher")
@pytest.mark.parametrize("params", test_data)
@pytest.mark.nightly
def test_unary_op_mish(self, params, ie_device, precision, ir_version, temp_dir, op_type,