[TF FE][GitHub issue] Support Selu operation and add test (#19528)
* [TF FE] Support Selu operation and add test
* Fix layer test

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Parent: 9715ccd992
Commit: 2cf8f2bc1f
@@ -73,6 +73,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"Mish", CreatorFunction(translate_unary_op<opset8::Mish>)},
         {"Neg", CreatorFunction(translate_unary_op<opset8::Negative>)},
         {"Relu", CreatorFunction(translate_unary_op<opset8::Relu>)},
+        {"Selu", CreatorFunction(translate_selu_op)},
         {"Sigmoid", CreatorFunction(translate_unary_op<opset8::Sigmoid>)},
         {"Sin", CreatorFunction(translate_unary_op<opset8::Sin>)},
         {"Sinh", CreatorFunction(translate_unary_op<opset8::Sinh>)},

@@ -28,6 +28,7 @@ namespace op {
 OutputVector op(const ov::frontend::NodeContext& node)
 
 OP_T_CONVERTER(translate_unary_op);
+OP_CONVERTER(translate_selu_op);
 OP_T_CONVERTER(translate_binary_op);
 OP_T_CONVERTER(translate_direct_reduce_op);
 

@@ -3,12 +3,44 @@
 //
 
 #include "common_op_table.hpp"
-#include "openvino/opsets/opset10.hpp"
-#include "openvino/opsets/opset8.hpp"
-#include "openvino/opsets/opset9.hpp"
+#include "openvino/op/abs.hpp"
+#include "openvino/op/acos.hpp"
+#include "openvino/op/acosh.hpp"
+#include "openvino/op/asin.hpp"
+#include "openvino/op/asinh.hpp"
+#include "openvino/op/atan.hpp"
+#include "openvino/op/atanh.hpp"
+#include "openvino/op/ceiling.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/cos.hpp"
+#include "openvino/op/cosh.hpp"
+#include "openvino/op/erf.hpp"
+#include "openvino/op/exp.hpp"
+#include "openvino/op/floor.hpp"
+#include "openvino/op/hswish.hpp"
+#include "openvino/op/is_finite.hpp"
+#include "openvino/op/is_inf.hpp"
+#include "openvino/op/is_nan.hpp"
+#include "openvino/op/log.hpp"
+#include "openvino/op/logical_not.hpp"
+#include "openvino/op/mish.hpp"
+#include "openvino/op/negative.hpp"
+#include "openvino/op/relu.hpp"
+#include "openvino/op/selu.hpp"
+#include "openvino/op/sigmoid.hpp"
+#include "openvino/op/sign.hpp"
+#include "openvino/op/sin.hpp"
+#include "openvino/op/sinh.hpp"
+#include "openvino/op/softplus.hpp"
+#include "openvino/op/softsign.hpp"
+#include "openvino/op/swish.hpp"
+#include "openvino/op/tan.hpp"
+#include "openvino/op/tanh.hpp"
+#include "utils.hpp"
 
 using namespace std;
-using namespace ov::opset8;
+using namespace ov::frontend::tensorflow;
+using namespace ov::op;
 
 namespace ov {
 namespace frontend {

@@ -16,9 +48,9 @@ namespace tensorflow {
 namespace op {
 
 OutputVector translate_unary_op(const NodeContext& op,
-                                const std::function<shared_ptr<Node>(Output<Node>)>& create_unary_op) {
-    auto ng_input = op.get_input(0);
-    auto res = create_unary_op(ng_input);
+                                const function<shared_ptr<Node>(Output<Node>)>& create_unary_op) {
+    auto input = op.get_input(0);
+    auto res = create_unary_op(input);
     set_node_name(op.get_name(), res);
     return {res};
 }

@@ -30,37 +62,49 @@ OutputVector translate_unary_op(const NodeContext& node) {
     });
 }
 
-template OutputVector translate_unary_op<Abs>(const NodeContext& node);
-template OutputVector translate_unary_op<Acos>(const NodeContext& node);
-template OutputVector translate_unary_op<Acosh>(const NodeContext& node);
-template OutputVector translate_unary_op<Asin>(const NodeContext& node);
-template OutputVector translate_unary_op<Asinh>(const NodeContext& node);
-template OutputVector translate_unary_op<Atan>(const NodeContext& node);
-template OutputVector translate_unary_op<Atanh>(const NodeContext& node);
-template OutputVector translate_unary_op<Ceiling>(const NodeContext& node);
-template OutputVector translate_unary_op<Cos>(const NodeContext& node);
-template OutputVector translate_unary_op<Cosh>(const NodeContext& node);
-template OutputVector translate_unary_op<Erf>(const NodeContext& node);
-template OutputVector translate_unary_op<Exp>(const NodeContext& node);
-template OutputVector translate_unary_op<Floor>(const NodeContext& node);
-template OutputVector translate_unary_op<HSwish>(const NodeContext& node);
-template OutputVector translate_unary_op<opset10::IsFinite>(const NodeContext& node);
-template OutputVector translate_unary_op<opset10::IsInf>(const NodeContext& node);
-template OutputVector translate_unary_op<opset10::IsNaN>(const NodeContext& node);
-template OutputVector translate_unary_op<Log>(const NodeContext& node);
-template OutputVector translate_unary_op<LogicalNot>(const NodeContext& node);
-template OutputVector translate_unary_op<Mish>(const NodeContext& node);
-template OutputVector translate_unary_op<Negative>(const NodeContext& node);
-template OutputVector translate_unary_op<Relu>(const NodeContext& node);
-template OutputVector translate_unary_op<Sigmoid>(const NodeContext& node);
-template OutputVector translate_unary_op<Sin>(const NodeContext& node);
-template OutputVector translate_unary_op<Sinh>(const NodeContext& node);
-template OutputVector translate_unary_op<Sign>(const NodeContext& node);
-template OutputVector translate_unary_op<SoftPlus>(const NodeContext& node);
-template OutputVector translate_unary_op<Tan>(const NodeContext& node);
-template OutputVector translate_unary_op<Tanh>(const NodeContext& node);
-template OutputVector translate_unary_op<opset9::SoftSign>(const NodeContext& node);
-template OutputVector translate_unary_op<Swish>(const NodeContext& node);
+template OutputVector translate_unary_op<v0::Abs>(const NodeContext& node);
+template OutputVector translate_unary_op<v0::Acos>(const NodeContext& node);
+template OutputVector translate_unary_op<v3::Acosh>(const NodeContext& node);
+template OutputVector translate_unary_op<v0::Asin>(const NodeContext& node);
+template OutputVector translate_unary_op<v3::Asinh>(const NodeContext& node);
+template OutputVector translate_unary_op<v0::Atan>(const NodeContext& node);
+template OutputVector translate_unary_op<v3::Atanh>(const NodeContext& node);
+template OutputVector translate_unary_op<v0::Ceiling>(const NodeContext& node);
+template OutputVector translate_unary_op<v0::Cos>(const NodeContext& node);
+template OutputVector translate_unary_op<v0::Cosh>(const NodeContext& node);
+template OutputVector translate_unary_op<v0::Erf>(const NodeContext& node);
+template OutputVector translate_unary_op<v0::Exp>(const NodeContext& node);
+template OutputVector translate_unary_op<v0::Floor>(const NodeContext& node);
+template OutputVector translate_unary_op<v4::HSwish>(const NodeContext& node);
+template OutputVector translate_unary_op<v10::IsFinite>(const NodeContext& node);
+template OutputVector translate_unary_op<v10::IsInf>(const NodeContext& node);
+template OutputVector translate_unary_op<v10::IsNaN>(const NodeContext& node);
+template OutputVector translate_unary_op<v0::Log>(const NodeContext& node);
+template OutputVector translate_unary_op<v1::LogicalNot>(const NodeContext& node);
+template OutputVector translate_unary_op<v4::Mish>(const NodeContext& node);
+template OutputVector translate_unary_op<v0::Negative>(const NodeContext& node);
+template OutputVector translate_unary_op<v0::Relu>(const NodeContext& node);
+template OutputVector translate_unary_op<v0::Sigmoid>(const NodeContext& node);
+template OutputVector translate_unary_op<v0::Sin>(const NodeContext& node);
+template OutputVector translate_unary_op<v0::Sinh>(const NodeContext& node);
+template OutputVector translate_unary_op<v0::Sign>(const NodeContext& node);
+template OutputVector translate_unary_op<v4::SoftPlus>(const NodeContext& node);
+template OutputVector translate_unary_op<v0::Tan>(const NodeContext& node);
+template OutputVector translate_unary_op<v0::Tanh>(const NodeContext& node);
+template OutputVector translate_unary_op<v9::SoftSign>(const NodeContext& node);
+template OutputVector translate_unary_op<v4::Swish>(const NodeContext& node);
 
+OutputVector translate_selu_op(const NodeContext& node) {
+    default_op_checks(node, 1, {"Selu"});
+    auto features = node.get_input(0);
+
+    // create pre-defined constants
+    auto alpha = create_same_type_const<float>(features, {1.67326324f}, Shape{1});
+    auto scale = create_same_type_const<float>(features, {1.05070098f}, Shape{1});
+    auto selu = make_shared<v0::Selu>(features, alpha, scale);
+    set_node_name(node.get_name(), selu);
+    return {selu};
+}
+
 } // namespace op
 } // namespace tensorflow

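For reference, the alpha and scale constants used by translate_selu_op above correspond to the standard SELU definition (background only, not part of this change):

\[
\mathrm{SELU}(x) = \lambda \begin{cases} x, & x > 0 \\ \alpha\,(e^{x} - 1), & x \le 0 \end{cases},
\qquad \alpha \approx 1.67326324,\ \lambda \approx 1.05070098.
\]
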
@@ -1,22 +1,20 @@
 # Copyright (C) 2018-2023 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
+import sys
+
 import numpy as np
 import pytest
-import sys
-from common.layer_test_class import check_ir_version
 from common.tf_layer_test_class import CommonTFLayerTest
 from common.utils.tf_utils import permute_nchw_to_nhwc
 
-from unit_tests.utils.graph import build_graph
-
 
 class TestUnaryOps(CommonTFLayerTest):
     current_op_type = None
 
     def _prepare_input(self, inputs_dict):
         non_negative = ['Sqrt', 'Log']
-        narrow_borders = ["Sinh", "Cosh", "Tanh", "Exp"]
+        narrow_borders = ["Sinh", "Cosh", "Tanh", "Exp", "Selu"]
         within_one = ['Asin', 'Acos', 'Atanh']
         from_one = ['Acosh']
 

@@ -48,13 +46,6 @@ class TestUnaryOps(CommonTFLayerTest):
         return inputs_dict
 
     def create_net_with_mish(self, shape, ir_version, use_new_frontend):
-        """
-        TODO: Move functionality to `create_net_with_unary_op()` once tensorflow_addons
-        supports Python 3.11
-        Tensorflow net                 IR net
-
-        Input->mish    =>     Input->mish
-        """
         import tensorflow as tf
         import tensorflow_addons as tfa
 

@@ -70,41 +61,11 @@ class TestUnaryOps(CommonTFLayerTest):
             tf.compat.v1.global_variables_initializer()
             tf_net = sess.graph_def
 
-        #
-        #   Create reference IR net
-        #   Please, specify 'type': 'Input' for input node
-        #   Moreover, do not forget to validate ALL layer attributes!!!
-        #
-
         ref_net = None
 
-        if check_ir_version(10, None, ir_version) and not use_new_frontend:
-            nodes_attributes = {
-                'input': {'kind': 'op', 'type': 'Parameter'},
-                'input_data': {'shape': shape, 'kind': 'data'},
-                'testing_op': {'kind': 'op', 'type': 'Mish'},
-                'testing_data': {'shape': shape, 'kind': 'data'},
-                'result': {'kind': 'op', 'type': 'Result'}
-            }
-
-            ref_net = build_graph(nodes_attributes,
-                                  [('input', 'input_data'),
-                                   ('input_data', 'testing_op'),
-                                   ('testing_op', 'testing_data'),
-                                   ('testing_data', 'result')
-                                   ])
-
         return tf_net, ref_net
 
     def create_net_with_unary_op(self, shape, ir_version, op_type, use_new_frontend):
-        """
-        TODO: Move functionality of `create_net_with_mish()` here once tensorflow_addons
-        supports Python 3.11
-        Tensorflow net                 IR net
-
-        Input->UnaryOp    =>     Input->UnaryOp
-
-        """
         import tensorflow as tf
 
         self.current_op_type = op_type

@@ -125,8 +86,9 @@ class TestUnaryOps(CommonTFLayerTest):
             'Floor': tf.math.floor,
             'Log': tf.math.log,
             'LogicalNot': tf.math.logical_not,
-            #'Mish': tfa.activations.mish,  # temporarily moved to `create_net_with_mish()`
+            # 'Mish': tfa.activations.mish,  # temporarily moved to `create_net_with_mish()`
             'Negative': tf.math.negative,
+            'Selu': tf.nn.selu,
             'Sigmoid': tf.nn.sigmoid,
             'Sign': tf.math.sign,
             'Sin': tf.math.sin,

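A rough illustration of the graph this new 'Selu' mapping is exercised with (a minimal sketch assuming the TF 2.x compat.v1 API, mirroring the session/graph_def pattern visible in the surrounding context; it is not the actual test harness):

```python
import tensorflow as tf

# Minimal sketch: a TF1-style frozen graph with a single Selu node, the kind of
# graph the unary-op layer test hands to the TensorFlow frontend for conversion.
tf.compat.v1.disable_eager_execution()
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
    x = tf.compat.v1.placeholder(tf.float32, [4, 6, 8, 10, 12], 'Input')
    tf.nn.selu(x, name='Selu')
    tf.compat.v1.global_variables_initializer()
    tf_net = sess.graph_def  # the GraphDef that would be converted and compared
```
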
@@ -158,30 +120,7 @@ class TestUnaryOps(CommonTFLayerTest):
             tf.compat.v1.global_variables_initializer()
             tf_net = sess.graph_def
 
-        #
-        #   Create reference IR net
-        #   Please, specify 'type': 'Input' for input node
-        #   Moreover, do not forget to validate ALL layer attributes!!!
-        #
-
         ref_net = None
-
-        if check_ir_version(10, None, ir_version) and not use_new_frontend:
-            nodes_attributes = {
-                'input': {'kind': 'op', 'type': 'Parameter'},
-                'input_data': {'shape': shape, 'kind': 'data'},
-                'testing_op': {'kind': 'op', 'type': self.current_op_type},
-                'testing_data': {'shape': shape, 'kind': 'data'},
-                'result': {'kind': 'op', 'type': 'Result'}
-            }
-
-            ref_net = build_graph(nodes_attributes,
-                                  [('input', 'input_data'),
-                                   ('input_data', 'testing_op'),
-                                   ('testing_op', 'testing_data'),
-                                   ('testing_data', 'result')
-                                   ])
-
         return tf_net, ref_net
 
     test_data_precommit = [dict(shape=[4, 6, 8, 10, 12])]

@@ -223,12 +162,13 @@ class TestUnaryOps(CommonTFLayerTest):
                                    use_new_frontend=use_new_frontend),
                    ie_device, precision, ir_version, temp_dir=temp_dir,
                    use_new_frontend=use_new_frontend, use_old_api=use_old_api)
 
-    @pytest.mark.xfail(sys.version_info > (3, 10), reason="tensorflow_addons package is not available for Python 3.11 and higher")
+    @pytest.mark.xfail(sys.version_info > (3, 10),
+                       reason="tensorflow_addons package is not available for Python 3.11 and higher")
     @pytest.mark.parametrize("params", test_data_precommit)
     @pytest.mark.precommit
     def test_unary_op_mish_precommit(self, params, ie_device, precision, ir_version, temp_dir,
                                      use_new_frontend, use_old_api):
         """
         TODO: Move to `test_unary_op_precommit()` once tensorflow_addons package is available for Python 3.11
         """

@@ -271,6 +211,7 @@ class TestUnaryOps(CommonTFLayerTest):
                                          'Asinh',
                                          'Square',
                                          'Erf',
+                                         'Selu'
                                          ])
     @pytest.mark.nightly
     def test_unary_op(self, params, ie_device, precision, ir_version, temp_dir, op_type,

|
|||||||
use_new_frontend=use_new_frontend),
|
use_new_frontend=use_new_frontend),
|
||||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
||||||
|
|
||||||
@pytest.mark.xfail(sys.version_info > (3, 10), reason="tensorflow_addons package is not available for Python 3.11 and higher")
|
@pytest.mark.xfail(sys.version_info > (3, 10),
|
||||||
|
reason="tensorflow_addons package is not available for Python 3.11 and higher")
|
||||||
@pytest.mark.parametrize("params", test_data)
|
@pytest.mark.parametrize("params", test_data)
|
||||||
@pytest.mark.nightly
|
@pytest.mark.nightly
|
||||||
def test_unary_op_mish(self, params, ie_device, precision, ir_version, temp_dir, op_type,
|
def test_unary_op_mish(self, params, ie_device, precision, ir_version, temp_dir, op_type,
|
||||||