Revise selu (#5513)

* remove FusedOp from selu operator

* add type_prop tests for selu op

* add visitors test for selu op

* update CMakeLists and fix wrong path to visitors test for selu

* fix style

* add backend tests for selu op

* refactor validate_and_infer_types function for selu op

* refactor type_prop tests for selu to catch particular errors

Co-authored-by: jdanieck <jozef.daniecki@intel.com>
This commit is contained in:
Bartek Szmelczynski 2021-06-11 17:01:52 +02:00 committed by GitHub
parent 089f76cc7f
commit 458435ad75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 338 additions and 52 deletions

View File

@ -6,10 +6,6 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
namespace ngraph namespace ngraph
{ {
namespace op namespace op
@ -17,12 +13,12 @@ namespace ngraph
namespace v0 namespace v0
{ {
/// \brief Performs a SELU activation function on all elements of the input node /// \brief Performs a SELU activation function on all elements of the input node
class NGRAPH_API Selu : public ngraph::op::util::FusedOp class NGRAPH_API Selu : public ngraph::op::Op
{ {
public: public:
static constexpr NodeTypeInfo type_info{"Selu", 0}; NGRAPH_RTTI_DECLARATION;
const NodeTypeInfo& get_type_info() const override { return type_info; }
Selu(); Selu() = default;
/// \brief Constructs a Selu node. /// \brief Constructs a Selu node.
/// ///
/// \param data - Node producing the input tensor /// \param data - Node producing the input tensor
@ -31,9 +27,10 @@ namespace ngraph
Selu(const Output<Node>& data, Selu(const Output<Node>& data,
const Output<Node>& alpha, const Output<Node>& alpha,
const Output<Node>& lambda); const Output<Node>& lambda);
virtual void pre_validate_and_infer_types() override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override; bool visit_attributes(AttributeVisitor& visitor) override;
virtual OutputVector decompose_op() const override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override; clone_with_new_inputs(const OutputVector& new_args) const override;
@ -42,5 +39,3 @@ namespace ngraph
using v0::Selu; using v0::Selu;
} // namespace op } // namespace op
} // namespace ngraph } // namespace ngraph
NGRAPH_SUPPRESS_DEPRECATED_END

View File

@ -5,65 +5,53 @@
#include "ngraph/op/selu.hpp" #include "ngraph/op/selu.hpp"
#include "itt.hpp" #include "itt.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
NGRAPH_SUPPRESS_DEPRECATED_START NGRAPH_RTTI_DEFINITION(op::v0::Selu, "Selu", 0);
constexpr NodeTypeInfo op::v0::Selu::type_info;
op::v0::Selu::Selu()
: FusedOp()
{
}
op::v0::Selu::Selu(const Output<Node>& data, const Output<Node>& alpha, const Output<Node>& lambda) op::v0::Selu::Selu(const Output<Node>& data, const Output<Node>& alpha, const Output<Node>& lambda)
: FusedOp({data, alpha, lambda}) : Op({data, alpha, lambda})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
void ngraph::op::v0::Selu::pre_validate_and_infer_types() void op::v0::Selu::validate_and_infer_types()
{ {
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); NGRAPH_OP_SCOPE(v0_Selu_validate_and_infer_types);
auto data_et = get_input_element_type(0);
auto alpha_et = get_input_element_type(1);
auto lambda_et = get_input_element_type(2);
auto result_et = element::dynamic;
NODE_VALIDATION_CHECK(this,
element::Type::merge(result_et, result_et, data_et) &&
element::Type::merge(result_et, result_et, alpha_et) &&
element::Type::merge(result_et, result_et, lambda_et),
"Input element types do not match : ",
data_et,
" and ",
alpha_et,
" and ",
lambda_et);
NODE_VALIDATION_CHECK(this,
result_et.is_dynamic() || result_et.is_real(),
"Input element types must be floating-point. Got: ",
result_et);
set_output_type(0, result_et, get_input_partial_shape(0));
} }
bool ngraph::op::v0::Selu::visit_attributes(AttributeVisitor& visitor) bool op::v0::Selu::visit_attributes(AttributeVisitor& visitor)
{ {
NGRAPH_OP_SCOPE(v0_Selu_visit_attributes); NGRAPH_OP_SCOPE(v0_Selu_visit_attributes);
return true; return true;
} }
OutputVector op::v0::Selu::decompose_op() const
{
const auto data = input_value(0);
const auto alpha = input_value(1);
const auto lambda = input_value(2);
const auto zero_node = op::Constant::create(data.get_element_type(), Shape{1}, {0});
// lambda * ((max(data, 0) + (alpha * exp(min(data, 0)) - alpha))
return {std::make_shared<op::v1::Multiply>(
lambda,
std::make_shared<op::v1::Add>(
std::make_shared<op::v1::Maximum>(data, zero_node),
std::make_shared<op::v1::Subtract>(
std::make_shared<op::v1::Multiply>(
alpha,
std::make_shared<op::Exp>(std::make_shared<op::v1::Minimum>(data, zero_node))),
alpha)))};
}
shared_ptr<Node> op::v0::Selu::clone_with_new_inputs(const OutputVector& new_args) const shared_ptr<Node> op::v0::Selu::clone_with_new_inputs(const OutputVector& new_args) const
{ {
NGRAPH_OP_SCOPE(v0_Selu_clone_with_new_inputs); NGRAPH_OP_SCOPE(v0_Selu_clone_with_new_inputs);
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<v0::Selu>(new_args.at(0), new_args.at(1), new_args.at(2)); return make_shared<op::v0::Selu>(new_args.at(0), new_args.at(1), new_args.at(2));
} }

View File

@ -196,6 +196,7 @@ set(SRC
type_prop/scatter_nd_update.cpp type_prop/scatter_nd_update.cpp
type_prop/scatter_update.cpp type_prop/scatter_update.cpp
type_prop/select.cpp type_prop/select.cpp
type_prop/selu.cpp
type_prop/shape_of.cpp type_prop/shape_of.cpp
type_prop/shuffle_channels.cpp type_prop/shuffle_channels.cpp
type_prop/softmax.cpp type_prop/softmax.cpp
@ -270,6 +271,7 @@ set(SRC
visitors/op/reverse_sequence.cpp visitors/op/reverse_sequence.cpp
visitors/op/rnn_cell.cpp visitors/op/rnn_cell.cpp
visitors/op/roi_pooling.cpp visitors/op/roi_pooling.cpp
visitors/op/selu.cpp
visitors/op/shuffle_channels.cpp visitors/op/shuffle_channels.cpp
visitors/op/softmax.cpp visitors/op/softmax.cpp
visitors/op/space_to_depth.cpp visitors/op/space_to_depth.cpp
@ -435,6 +437,7 @@ set(MULTI_TEST_SRC
backend/round.in.cpp backend/round.in.cpp
backend/scatter_nd_update.in.cpp backend/scatter_nd_update.in.cpp
backend/select.in.cpp backend/select.in.cpp
backend/selu.in.cpp
backend/shape_of.in.cpp backend/shape_of.in.cpp
backend/sigmoid.in.cpp backend/sigmoid.in.cpp
backend/sign.in.cpp backend/sign.in.cpp

View File

@ -0,0 +1,99 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/engine/test_engines.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"
using namespace std;
using namespace ngraph;
static string s_manifest = "${MANIFEST}";
using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});
NGRAPH_TEST(${BACKEND_NAME}, selu_2Dfprop)
{
    // SELU forward pass on a 2-element tensor mixing a negative and a
    // positive value:
    //   selu(x) = lambda * (max(x, 0) + alpha * (exp(min(x, 0)) - 1))
    Shape rt_shape{2};
    Shape c_shape{1};
    element::Type et = element::f32;
    auto input = make_shared<op::Parameter>(et, rt_shape);
    // Canonical SELU coefficients.
    auto alpha = op::Constant::create(et, c_shape, {1.67326324});
    auto lambda = op::Constant::create(et, c_shape, {1.05070098});
    auto selu = make_shared<op::v0::Selu>(input, alpha, lambda);
    auto f = make_shared<Function>(selu, ParameterVector{input});
    vector<float> input_data{-1, 3};
    vector<float> expected_out{-1.1113307, 3.152103};
    auto test_case = test::TestCase<TestEngine>(f);
    test_case.add_input<float>(rt_shape, input_data);
    // Explicit <float> and an fp tolerance for consistency with the other
    // selu backend tests; exact float equality on computed outputs is fragile.
    test_case.add_expected_output<float>(rt_shape, expected_out);
    test_case.run_with_tolerance_as_fp(1e-4f);
}
NGRAPH_TEST(${BACKEND_NAME}, selu_4Dfprop)
{
    // SELU forward pass over a 4-element input covering negative, zero
    // and positive values.
    const Shape in_shape{4};
    const Shape c_shape{1};
    const element::Type et = element::f32;
    const auto data = make_shared<op::Parameter>(et, in_shape);
    const auto alpha = op::Constant::create(et, c_shape, {1.67326324});
    const auto lambda = op::Constant::create(et, c_shape, {1.05070098});
    const auto selu_node = make_shared<op::v0::Selu>(data, alpha, lambda);
    const auto function = make_shared<Function>(selu_node, ParameterVector{data});
    const vector<float> in_vec{-1.0, 0.0, 1.0, 2.0};
    const vector<float> out_vec{-1.1113307, 0., 1.050701, 2.101402};
    auto test_case = test::TestCase<TestEngine>(function);
    test_case.add_input<float>(in_shape, in_vec);
    test_case.add_expected_output<float>(in_shape, out_vec);
    test_case.run_with_tolerance_as_fp(1e-4f);
}
NGRAPH_TEST(${BACKEND_NAME}, selu_1Dfprop)
{
    // Single positive element: SELU reduces to lambda * x on this branch.
    const Shape in_shape{1};
    const Shape c_shape{1};
    const element::Type et = element::f32;
    const auto data = make_shared<op::Parameter>(et, in_shape);
    const auto alpha = op::Constant::create(et, c_shape, {1.67326324});
    const auto lambda = op::Constant::create(et, c_shape, {1.05070098});
    const auto selu_node = make_shared<op::v0::Selu>(data, alpha, lambda);
    const auto function = make_shared<Function>(selu_node, ParameterVector{data});
    const vector<float> in_vec{112.0};
    const vector<float> out_vec{117.67851};
    auto test_case = test::TestCase<TestEngine>(function);
    test_case.add_input<float>(in_shape, in_vec);
    test_case.add_expected_output<float>(in_shape, out_vec);
    test_case.run_with_tolerance_as_fp(1e-4f);
}
NGRAPH_TEST(${BACKEND_NAME}, selu_3Dfprop_negative)
{
    // All-negative inputs exercise the saturating branch of SELU,
    // lambda * alpha * (exp(x) - 1), which approaches
    // -lambda * alpha (~ -1.7581) as x -> -inf.
    Shape in_shape{3};
    Shape c_shape{1};
    element::Type et = element::f32;
    auto input = make_shared<op::Parameter>(et, in_shape);
    // Canonical SELU coefficients.
    auto alpha = op::Constant::create(et, c_shape, {1.67326324});
    auto lambda = op::Constant::create(et, c_shape, {1.05070098});
    auto selu = make_shared<op::v0::Selu>(input, alpha, lambda);
    auto f = make_shared<Function>(selu, ParameterVector{input});
    vector<float> in_vec{-3.0, -12.5, -7.0};
    vector<float> out_vec{-1.6705687, -1.7580928, -1.7564961};
    auto test_case = test::TestCase<TestEngine>(f);
    test_case.add_input<float>(in_shape, in_vec);
    test_case.add_expected_output<float>(in_shape, out_vec);
    test_case.run_with_tolerance_as_fp(1e-4f);
}

View File

@ -0,0 +1,171 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(type_prop, selu_basic_inference_f32_3D)
{
    // Element type and shape of a static f32 3D input must be
    // propagated unchanged to the output.
    const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 32, 32});
    const auto alpha = make_shared<op::Parameter>(element::f32, Shape{1});
    const auto lambda = make_shared<op::Parameter>(element::f32, Shape{1});
    const auto selu_node = make_shared<op::Selu>(data, alpha, lambda);
    ASSERT_EQ(selu_node->get_element_type(), element::f32);
    ASSERT_EQ(selu_node->get_shape(), (Shape{1, 32, 32}));
}
TEST(type_prop, selu_basic_inference_f16_3D)
{
    // Same propagation check as the f32 variant, but for f16 inputs.
    const auto data = make_shared<op::Parameter>(element::f16, Shape{1, 32, 32});
    const auto alpha = make_shared<op::Parameter>(element::f16, Shape{1});
    const auto lambda = make_shared<op::Parameter>(element::f16, Shape{1});
    const auto selu_node = make_shared<op::Selu>(data, alpha, lambda);
    ASSERT_EQ(selu_node->get_element_type(), element::f16);
    ASSERT_EQ(selu_node->get_shape(), (Shape{1, 32, 32}));
}
TEST(type_prop, selu_basic_inference_f32_5D)
{
    // 5D f32 input: output element type and shape must match the data input.
    const auto param = make_shared<op::Parameter>(element::f32, Shape{12, 135, 221, 31, 15});
    const auto alpha = make_shared<op::Parameter>(element::f32, Shape{1});
    const auto lambda = make_shared<op::Parameter>(element::f32, Shape{1});
    const auto selu = make_shared<op::Selu>(param, alpha, lambda);
    ASSERT_EQ(selu->get_element_type(), element::f32);
    ASSERT_EQ(selu->get_shape(), (Shape{12, 135, 221, 31, 15}));
}
TEST(type_prop, selu_basic_inference_f16_5D)
{
    // 5D f16 input: output element type and shape must match the data input.
    const auto param = make_shared<op::Parameter>(element::f16, Shape{12, 135, 221, 31, 15});
    const auto alpha = make_shared<op::Parameter>(element::f16, Shape{1});
    const auto lambda = make_shared<op::Parameter>(element::f16, Shape{1});
    const auto selu = make_shared<op::Selu>(param, alpha, lambda);
    ASSERT_EQ(selu->get_element_type(), element::f16);
    ASSERT_EQ(selu->get_shape(), (Shape{12, 135, 221, 31, 15}));
}
TEST(type_prop, selu_incompatible_input_type_boolean)
{
    // Invalid data input element type: boolean is not floating-point.
    try
    {
        auto data = make_shared<op::Parameter>(element::boolean, Shape{1, 2, 3, 4});
        const auto alpha = make_shared<op::Parameter>(element::boolean, Shape{1});
        const auto lambda = make_shared<op::Parameter>(element::boolean, Shape{1});
        auto selu = make_shared<op::Selu>(data, alpha, lambda);
        // Inputs are expected to be of a floating-point type
        FAIL() << "Invalid input type not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Input element types must be floating-point"));
    }
    catch (...)
    {
        // Any exception other than NodeValidationFailure means the check
        // fired for the wrong reason (or something else broke).
        FAIL() << "Input type check failed for unexpected reason";
    }
}
TEST(type_prop, selu_incompatible_input_type_i32)
{
    // Invalid data input element type: i32 is integral, not floating-point.
    try
    {
        auto data = make_shared<op::Parameter>(element::i32, Shape{1, 2, 3, 4});
        const auto alpha = make_shared<op::Parameter>(element::i32, Shape{1});
        const auto lambda = make_shared<op::Parameter>(element::i32, Shape{1});
        auto selu = make_shared<op::Selu>(data, alpha, lambda);
        // Inputs are expected to be of a floating-point type
        FAIL() << "Invalid input type not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Input element types must be floating-point"));
    }
    catch (...)
    {
        // Any exception other than NodeValidationFailure means the check
        // fired for the wrong reason (or something else broke).
        FAIL() << "Input type check failed for unexpected reason";
    }
}
TEST(type_prop, selu_incompatible_input_type_u16)
{
    // Invalid data input element type: u16 is integral, not floating-point.
    try
    {
        auto data = make_shared<op::Parameter>(element::u16, Shape{1, 2, 3, 4});
        const auto alpha = make_shared<op::Parameter>(element::u16, Shape{1});
        const auto lambda = make_shared<op::Parameter>(element::u16, Shape{1});
        auto selu = make_shared<op::Selu>(data, alpha, lambda);
        // Inputs are expected to be of a floating-point type
        FAIL() << "Invalid input type not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Input element types must be floating-point"));
    }
    catch (...)
    {
        // Any exception other than NodeValidationFailure means the check
        // fired for the wrong reason (or something else broke).
        FAIL() << "Input type check failed for unexpected reason";
    }
}
TEST(type_prop, selu_incompatible_input_types)
{
    // Mismatched element types across the data/alpha/lambda inputs
    // (f32 vs u16) must be rejected by type merging.
    try
    {
        auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
        const auto alpha = make_shared<op::Parameter>(element::f32, Shape{1});
        const auto lambda = make_shared<op::Parameter>(element::u16, Shape{1});
        auto selu = make_shared<op::Selu>(data, alpha, lambda);
        // All three inputs are expected to share one element type
        FAIL() << "Invalid input types not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Input element types do not match"));
    }
    catch (...)
    {
        // Any exception other than NodeValidationFailure means the check
        // fired for the wrong reason (or something else broke).
        FAIL() << "Input type check failed for unexpected reason";
    }
}
TEST(type_prop, selu_dynamic_rank_input_shape_2D)
{
    // Static rank with one dynamic dimension: the partial shape must be
    // propagated to the output as-is.
    // NOTE(review): alpha uses Shape{2, 1} here rather than the Shape{1}
    // used by the other tests — presumably exercising a broadcastable
    // coefficient; confirm this is intentional.
    const PartialShape param_shape{Dimension::dynamic(), 10};
    const auto param = std::make_shared<op::Parameter>(element::f32, param_shape);
    const auto alpha = make_shared<op::Parameter>(element::f32, Shape{2, 1});
    const auto lambda = make_shared<op::Parameter>(element::f32, Shape{1});
    const auto op = std::make_shared<op::Selu>(param, alpha, lambda);
    ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(PartialShape{Dimension(), 10}));
}
TEST(type_prop, selu_dynamic_rank_input_shape_3D)
{
    // A dynamic middle dimension must survive shape propagation untouched.
    const PartialShape data_shape{100, Dimension::dynamic(), 58};
    const auto data = make_shared<op::Parameter>(element::f32, data_shape);
    const auto alpha = make_shared<op::Parameter>(element::f32, Shape{1});
    const auto lambda = make_shared<op::Parameter>(element::f32, Shape{1});
    const auto selu_node = make_shared<op::Selu>(data, alpha, lambda);
    ASSERT_TRUE(
        selu_node->get_output_partial_shape(0).same_scheme(PartialShape{100, Dimension(), 58}));
}
TEST(type_prop, selu_dynamic_rank_input_shape_full)
{
    // A fully dynamic input shape must yield a fully dynamic output shape.
    const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    const auto alpha = make_shared<op::Parameter>(element::f32, Shape{1});
    const auto lambda = make_shared<op::Parameter>(element::f32, Shape{1});
    const auto selu_node = make_shared<op::Selu>(data, alpha, lambda);
    ASSERT_TRUE(selu_node->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}

View File

@ -0,0 +1,30 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/opsets/opset1.hpp"
#include "util/visitor.hpp"
using namespace std;
using namespace ngraph;
using ngraph::test::NodeBuilder;
TEST(attributes, selu_op)
{
    // Selu has no attributes (alpha/lambda are inputs, not attributes),
    // so the attribute visitor should record zero entries.
    NodeBuilder::get_ops().register_factory<opset1::Selu>();
    const auto data_input = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
    const auto alpha = make_shared<op::Parameter>(element::f32, Shape{1});
    const auto lambda = make_shared<op::Parameter>(element::f32, Shape{1});
    const auto op = make_shared<opset1::Selu>(data_input, alpha, lambda);
    NodeBuilder builder(op);
    const auto expected_attr_count = 0;
    EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
}