[TF FE][GitHub issue] Support Selu operation and add test (#19528)
* [TF FE] Support Selu operation and add test Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Fix layer test --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
9715ccd992
commit
2cf8f2bc1f
@ -73,6 +73,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"Mish", CreatorFunction(translate_unary_op<opset8::Mish>)},
|
||||
{"Neg", CreatorFunction(translate_unary_op<opset8::Negative>)},
|
||||
{"Relu", CreatorFunction(translate_unary_op<opset8::Relu>)},
|
||||
{"Selu", CreatorFunction(translate_selu_op)},
|
||||
{"Sigmoid", CreatorFunction(translate_unary_op<opset8::Sigmoid>)},
|
||||
{"Sin", CreatorFunction(translate_unary_op<opset8::Sin>)},
|
||||
{"Sinh", CreatorFunction(translate_unary_op<opset8::Sinh>)},
|
||||
|
@ -28,6 +28,7 @@ namespace op {
|
||||
OutputVector op(const ov::frontend::NodeContext& node)
|
||||
|
||||
OP_T_CONVERTER(translate_unary_op);
|
||||
OP_CONVERTER(translate_selu_op);
|
||||
OP_T_CONVERTER(translate_binary_op);
|
||||
OP_T_CONVERTER(translate_direct_reduce_op);
|
||||
|
||||
|
@ -3,12 +3,44 @@
|
||||
//
|
||||
|
||||
#include "common_op_table.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/opsets/opset8.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "openvino/op/abs.hpp"
|
||||
#include "openvino/op/acos.hpp"
|
||||
#include "openvino/op/acosh.hpp"
|
||||
#include "openvino/op/asin.hpp"
|
||||
#include "openvino/op/asinh.hpp"
|
||||
#include "openvino/op/atan.hpp"
|
||||
#include "openvino/op/atanh.hpp"
|
||||
#include "openvino/op/ceiling.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/cos.hpp"
|
||||
#include "openvino/op/cosh.hpp"
|
||||
#include "openvino/op/erf.hpp"
|
||||
#include "openvino/op/exp.hpp"
|
||||
#include "openvino/op/floor.hpp"
|
||||
#include "openvino/op/hswish.hpp"
|
||||
#include "openvino/op/is_finite.hpp"
|
||||
#include "openvino/op/is_inf.hpp"
|
||||
#include "openvino/op/is_nan.hpp"
|
||||
#include "openvino/op/log.hpp"
|
||||
#include "openvino/op/logical_not.hpp"
|
||||
#include "openvino/op/mish.hpp"
|
||||
#include "openvino/op/negative.hpp"
|
||||
#include "openvino/op/relu.hpp"
|
||||
#include "openvino/op/selu.hpp"
|
||||
#include "openvino/op/sigmoid.hpp"
|
||||
#include "openvino/op/sign.hpp"
|
||||
#include "openvino/op/sin.hpp"
|
||||
#include "openvino/op/sinh.hpp"
|
||||
#include "openvino/op/softplus.hpp"
|
||||
#include "openvino/op/softsign.hpp"
|
||||
#include "openvino/op/swish.hpp"
|
||||
#include "openvino/op/tan.hpp"
|
||||
#include "openvino/op/tanh.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov::opset8;
|
||||
using namespace ov::frontend::tensorflow;
|
||||
using namespace ov::op;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
@ -16,9 +48,9 @@ namespace tensorflow {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_unary_op(const NodeContext& op,
|
||||
const std::function<shared_ptr<Node>(Output<Node>)>& create_unary_op) {
|
||||
auto ng_input = op.get_input(0);
|
||||
auto res = create_unary_op(ng_input);
|
||||
const function<shared_ptr<Node>(Output<Node>)>& create_unary_op) {
|
||||
auto input = op.get_input(0);
|
||||
auto res = create_unary_op(input);
|
||||
set_node_name(op.get_name(), res);
|
||||
return {res};
|
||||
}
|
||||
@ -30,37 +62,49 @@ OutputVector translate_unary_op(const NodeContext& node) {
|
||||
});
|
||||
}
|
||||
|
||||
template OutputVector translate_unary_op<Abs>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Acos>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Acosh>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Asin>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Asinh>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Atan>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Atanh>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Ceiling>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Cos>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Cosh>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Erf>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Exp>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Floor>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<HSwish>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<opset10::IsFinite>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<opset10::IsInf>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<opset10::IsNaN>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Log>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<LogicalNot>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Mish>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Negative>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Relu>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Sigmoid>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Sin>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Sinh>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Sign>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<SoftPlus>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Tan>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Tanh>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<opset9::SoftSign>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<Swish>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v0::Abs>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v0::Acos>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v3::Acosh>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v0::Asin>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v3::Asinh>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v0::Atan>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v3::Atanh>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v0::Ceiling>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v0::Cos>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v0::Cosh>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v0::Erf>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v0::Exp>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v0::Floor>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v4::HSwish>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v10::IsFinite>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v10::IsInf>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v10::IsNaN>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v0::Log>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v1::LogicalNot>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v4::Mish>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v0::Negative>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v0::Relu>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v0::Sigmoid>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v0::Sin>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v0::Sinh>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v0::Sign>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v4::SoftPlus>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v0::Tan>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v0::Tanh>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v9::SoftSign>(const NodeContext& node);
|
||||
template OutputVector translate_unary_op<v4::Swish>(const NodeContext& node);
|
||||
|
||||
OutputVector translate_selu_op(const NodeContext& node) {
|
||||
default_op_checks(node, 1, {"Selu"});
|
||||
auto features = node.get_input(0);
|
||||
|
||||
// create pre-defined constants
|
||||
auto alpha = create_same_type_const<float>(features, {1.67326324f}, Shape{1});
|
||||
auto scale = create_same_type_const<float>(features, {1.05070098f}, Shape{1});
|
||||
auto selu = make_shared<v0::Selu>(features, alpha, scale);
|
||||
set_node_name(node.get_name(), selu);
|
||||
return {selu};
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
|
@ -1,22 +1,20 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import sys
|
||||
from common.layer_test_class import check_ir_version
|
||||
from common.tf_layer_test_class import CommonTFLayerTest
|
||||
from common.utils.tf_utils import permute_nchw_to_nhwc
|
||||
|
||||
from unit_tests.utils.graph import build_graph
|
||||
|
||||
|
||||
class TestUnaryOps(CommonTFLayerTest):
|
||||
current_op_type = None
|
||||
|
||||
def _prepare_input(self, inputs_dict):
|
||||
non_negative = ['Sqrt', 'Log']
|
||||
narrow_borders = ["Sinh", "Cosh", "Tanh", "Exp"]
|
||||
narrow_borders = ["Sinh", "Cosh", "Tanh", "Exp", "Selu"]
|
||||
within_one = ['Asin', 'Acos', 'Atanh']
|
||||
from_one = ['Acosh']
|
||||
|
||||
@ -48,13 +46,6 @@ class TestUnaryOps(CommonTFLayerTest):
|
||||
return inputs_dict
|
||||
|
||||
def create_net_with_mish(self, shape, ir_version, use_new_frontend):
|
||||
"""
|
||||
TODO: Move functionality to `create_net_with_unary_op()` once tensorflow_addons
|
||||
supports Python 3.11
|
||||
Tensorflow net IR net
|
||||
|
||||
Input->mish => Input->mish
|
||||
"""
|
||||
import tensorflow as tf
|
||||
import tensorflow_addons as tfa
|
||||
|
||||
@ -70,41 +61,11 @@ class TestUnaryOps(CommonTFLayerTest):
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
tf_net = sess.graph_def
|
||||
|
||||
#
|
||||
# Create reference IR net
|
||||
# Please, specify 'type': 'Input' for input node
|
||||
# Moreover, do not forget to validate ALL layer attributes!!!
|
||||
#
|
||||
|
||||
ref_net = None
|
||||
|
||||
if check_ir_version(10, None, ir_version) and not use_new_frontend:
|
||||
nodes_attributes = {
|
||||
'input': {'kind': 'op', 'type': 'Parameter'},
|
||||
'input_data': {'shape': shape, 'kind': 'data'},
|
||||
'testing_op': {'kind': 'op', 'type': 'Mish'},
|
||||
'testing_data': {'shape': shape, 'kind': 'data'},
|
||||
'result': {'kind': 'op', 'type': 'Result'}
|
||||
}
|
||||
|
||||
ref_net = build_graph(nodes_attributes,
|
||||
[('input', 'input_data'),
|
||||
('input_data', 'testing_op'),
|
||||
('testing_op', 'testing_data'),
|
||||
('testing_data', 'result')
|
||||
])
|
||||
|
||||
return tf_net, ref_net
|
||||
|
||||
def create_net_with_unary_op(self, shape, ir_version, op_type, use_new_frontend):
|
||||
"""
|
||||
TODO: Move functionality of `create_net_with_mish()` here once tensorflow_addons
|
||||
supports Python 3.11
|
||||
Tensorflow net IR net
|
||||
|
||||
Input->UnaryOp => Input->UnaryOp
|
||||
|
||||
"""
|
||||
import tensorflow as tf
|
||||
|
||||
self.current_op_type = op_type
|
||||
@ -125,8 +86,9 @@ class TestUnaryOps(CommonTFLayerTest):
|
||||
'Floor': tf.math.floor,
|
||||
'Log': tf.math.log,
|
||||
'LogicalNot': tf.math.logical_not,
|
||||
#'Mish': tfa.activations.mish, # temporarily moved to `create_net_with_mish()`
|
||||
# 'Mish': tfa.activations.mish, # temporarily moved to `create_net_with_mish()`
|
||||
'Negative': tf.math.negative,
|
||||
'Selu': tf.nn.selu,
|
||||
'Sigmoid': tf.nn.sigmoid,
|
||||
'Sign': tf.math.sign,
|
||||
'Sin': tf.math.sin,
|
||||
@ -158,30 +120,7 @@ class TestUnaryOps(CommonTFLayerTest):
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
tf_net = sess.graph_def
|
||||
|
||||
#
|
||||
# Create reference IR net
|
||||
# Please, specify 'type': 'Input' for input node
|
||||
# Moreover, do not forget to validate ALL layer attributes!!!
|
||||
#
|
||||
|
||||
ref_net = None
|
||||
|
||||
if check_ir_version(10, None, ir_version) and not use_new_frontend:
|
||||
nodes_attributes = {
|
||||
'input': {'kind': 'op', 'type': 'Parameter'},
|
||||
'input_data': {'shape': shape, 'kind': 'data'},
|
||||
'testing_op': {'kind': 'op', 'type': self.current_op_type},
|
||||
'testing_data': {'shape': shape, 'kind': 'data'},
|
||||
'result': {'kind': 'op', 'type': 'Result'}
|
||||
}
|
||||
|
||||
ref_net = build_graph(nodes_attributes,
|
||||
[('input', 'input_data'),
|
||||
('input_data', 'testing_op'),
|
||||
('testing_op', 'testing_data'),
|
||||
('testing_data', 'result')
|
||||
])
|
||||
|
||||
return tf_net, ref_net
|
||||
|
||||
test_data_precommit = [dict(shape=[4, 6, 8, 10, 12])]
|
||||
@ -224,7 +163,8 @@ class TestUnaryOps(CommonTFLayerTest):
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
||||
|
||||
@pytest.mark.xfail(sys.version_info > (3, 10), reason="tensorflow_addons package is not available for Python 3.11 and higher")
|
||||
@pytest.mark.xfail(sys.version_info > (3, 10),
|
||||
reason="tensorflow_addons package is not available for Python 3.11 and higher")
|
||||
@pytest.mark.parametrize("params", test_data_precommit)
|
||||
@pytest.mark.precommit
|
||||
def test_unary_op_mish_precommit(self, params, ie_device, precision, ir_version, temp_dir,
|
||||
@ -271,6 +211,7 @@ class TestUnaryOps(CommonTFLayerTest):
|
||||
'Asinh',
|
||||
'Square',
|
||||
'Erf',
|
||||
'Selu'
|
||||
])
|
||||
@pytest.mark.nightly
|
||||
def test_unary_op(self, params, ie_device, precision, ir_version, temp_dir, op_type,
|
||||
@ -282,7 +223,8 @@ class TestUnaryOps(CommonTFLayerTest):
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
||||
|
||||
@pytest.mark.xfail(sys.version_info > (3, 10), reason="tensorflow_addons package is not available for Python 3.11 and higher")
|
||||
@pytest.mark.xfail(sys.version_info > (3, 10),
|
||||
reason="tensorflow_addons package is not available for Python 3.11 and higher")
|
||||
@pytest.mark.parametrize("params", test_data)
|
||||
@pytest.mark.nightly
|
||||
def test_unary_op_mish(self, params, ie_device, precision, ir_version, temp_dir, op_type,
|
||||
|
Loading…
Reference in New Issue
Block a user