Revise selu (#5513)
* remove FusedOp from selu operator * add type_prop tests for selu op * add vistors test for selu op * update CMakeLists and fix wrong path to visitors test for selu * fix style * add backend tests for selu op * refactor validate_and_infer_types function for selu op * refactor type_prop tests for selu to catch particular errors Co-authored-by: jdanieck <jozef.daniecki@intel.com>
This commit is contained in:
parent
089f76cc7f
commit
458435ad75
@ -6,10 +6,6 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/fused_op.hpp"
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
@ -17,12 +13,12 @@ namespace ngraph
|
||||
namespace v0
|
||||
{
|
||||
/// \brief Performs a SELU activation function on all elements of the input node
|
||||
class NGRAPH_API Selu : public ngraph::op::util::FusedOp
|
||||
class NGRAPH_API Selu : public ngraph::op::Op
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"Selu", 0};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
Selu();
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
Selu() = default;
|
||||
/// \brief Constructs a Selu node.
|
||||
///
|
||||
/// \param data - Node producing the input tensor
|
||||
@ -31,9 +27,10 @@ namespace ngraph
|
||||
Selu(const Output<Node>& data,
|
||||
const Output<Node>& alpha,
|
||||
const Output<Node>& lambda);
|
||||
virtual void pre_validate_and_infer_types() override;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
virtual OutputVector decompose_op() const override;
|
||||
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
@ -42,5 +39,3 @@ namespace ngraph
|
||||
using v0::Selu;
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
|
@ -5,65 +5,53 @@
|
||||
#include "ngraph/op/selu.hpp"
|
||||
#include "itt.hpp"
|
||||
|
||||
#include "ngraph/op/add.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/exp.hpp"
|
||||
#include "ngraph/op/maximum.hpp"
|
||||
#include "ngraph/op/minimum.hpp"
|
||||
#include "ngraph/op/multiply.hpp"
|
||||
#include "ngraph/op/subtract.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
constexpr NodeTypeInfo op::v0::Selu::type_info;
|
||||
|
||||
op::v0::Selu::Selu()
|
||||
: FusedOp()
|
||||
{
|
||||
}
|
||||
NGRAPH_RTTI_DEFINITION(op::v0::Selu, "Selu", 0);
|
||||
|
||||
op::v0::Selu::Selu(const Output<Node>& data, const Output<Node>& alpha, const Output<Node>& lambda)
|
||||
: FusedOp({data, alpha, lambda})
|
||||
: Op({data, alpha, lambda})
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void ngraph::op::v0::Selu::pre_validate_and_infer_types()
|
||||
void op::v0::Selu::validate_and_infer_types()
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
|
||||
NGRAPH_OP_SCOPE(v0_Selu_validate_and_infer_types);
|
||||
auto data_et = get_input_element_type(0);
|
||||
auto alpha_et = get_input_element_type(1);
|
||||
auto lambda_et = get_input_element_type(2);
|
||||
auto result_et = element::dynamic;
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
element::Type::merge(result_et, result_et, data_et) &&
|
||||
element::Type::merge(result_et, result_et, alpha_et) &&
|
||||
element::Type::merge(result_et, result_et, lambda_et),
|
||||
"Input element types do not match : ",
|
||||
data_et,
|
||||
" and ",
|
||||
alpha_et,
|
||||
" and ",
|
||||
lambda_et);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
result_et.is_dynamic() || result_et.is_real(),
|
||||
"Input element types must be floating-point. Got: ",
|
||||
result_et);
|
||||
|
||||
set_output_type(0, result_et, get_input_partial_shape(0));
|
||||
}
|
||||
|
||||
bool ngraph::op::v0::Selu::visit_attributes(AttributeVisitor& visitor)
|
||||
bool op::v0::Selu::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v0_Selu_visit_attributes);
|
||||
return true;
|
||||
}
|
||||
|
||||
OutputVector op::v0::Selu::decompose_op() const
|
||||
{
|
||||
const auto data = input_value(0);
|
||||
const auto alpha = input_value(1);
|
||||
const auto lambda = input_value(2);
|
||||
const auto zero_node = op::Constant::create(data.get_element_type(), Shape{1}, {0});
|
||||
|
||||
// lambda * ((max(data, 0) + (alpha * exp(min(data, 0)) - alpha))
|
||||
return {std::make_shared<op::v1::Multiply>(
|
||||
lambda,
|
||||
std::make_shared<op::v1::Add>(
|
||||
std::make_shared<op::v1::Maximum>(data, zero_node),
|
||||
std::make_shared<op::v1::Subtract>(
|
||||
std::make_shared<op::v1::Multiply>(
|
||||
alpha,
|
||||
std::make_shared<op::Exp>(std::make_shared<op::v1::Minimum>(data, zero_node))),
|
||||
alpha)))};
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v0::Selu::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v0_Selu_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<v0::Selu>(new_args.at(0), new_args.at(1), new_args.at(2));
|
||||
return make_shared<op::v0::Selu>(new_args.at(0), new_args.at(1), new_args.at(2));
|
||||
}
|
||||
|
@ -196,6 +196,7 @@ set(SRC
|
||||
type_prop/scatter_nd_update.cpp
|
||||
type_prop/scatter_update.cpp
|
||||
type_prop/select.cpp
|
||||
type_prop/selu.cpp
|
||||
type_prop/shape_of.cpp
|
||||
type_prop/shuffle_channels.cpp
|
||||
type_prop/softmax.cpp
|
||||
@ -270,6 +271,7 @@ set(SRC
|
||||
visitors/op/reverse_sequence.cpp
|
||||
visitors/op/rnn_cell.cpp
|
||||
visitors/op/roi_pooling.cpp
|
||||
visitors/op/selu.cpp
|
||||
visitors/op/shuffle_channels.cpp
|
||||
visitors/op/softmax.cpp
|
||||
visitors/op/space_to_depth.cpp
|
||||
@ -435,6 +437,7 @@ set(MULTI_TEST_SRC
|
||||
backend/round.in.cpp
|
||||
backend/scatter_nd_update.in.cpp
|
||||
backend/select.in.cpp
|
||||
backend/selu.in.cpp
|
||||
backend/shape_of.in.cpp
|
||||
backend/sigmoid.in.cpp
|
||||
backend/sign.in.cpp
|
||||
|
99
ngraph/test/backend/selu.in.cpp
Normal file
99
ngraph/test/backend/selu.in.cpp
Normal file
@ -0,0 +1,99 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/engine/test_engines.hpp"
|
||||
#include "util/test_case.hpp"
|
||||
#include "util/test_control.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
static string s_manifest = "${MANIFEST}";
|
||||
using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, selu_2Dfprop)
|
||||
{
|
||||
Shape rt_shape{2};
|
||||
Shape c_shape{1};
|
||||
element::Type et = element::f32;
|
||||
|
||||
auto input = make_shared<op::Parameter>(et, rt_shape);
|
||||
auto alpha = op::Constant::create(et, c_shape, {1.67326324});
|
||||
auto lambda = op::Constant::create(et, c_shape, {1.05070098});
|
||||
auto selu = make_shared<op::v0::Selu>(input, alpha, lambda);
|
||||
auto f = make_shared<Function>(selu, ParameterVector{input});
|
||||
|
||||
vector<float> input_data{-1, 3};
|
||||
vector<float> expected_out{-1.1113307, 3.152103};
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(f);
|
||||
test_case.add_input<float>(rt_shape, input_data);
|
||||
test_case.add_expected_output(rt_shape, expected_out);
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, selu_4Dfprop)
|
||||
{
|
||||
Shape in_shape{4};
|
||||
Shape c_shape{1};
|
||||
element::Type et = element::f32;
|
||||
|
||||
auto input = make_shared<op::Parameter>(et, in_shape);
|
||||
auto alpha = op::Constant::create(et, c_shape, {1.67326324});
|
||||
auto lambda = op::Constant::create(et, c_shape, {1.05070098});
|
||||
auto selu = make_shared<op::v0::Selu>(input, alpha, lambda);
|
||||
auto f = make_shared<Function>(selu, ParameterVector{input});
|
||||
|
||||
vector<float> in_vec{-1.0, 0.0, 1.0, 2.0};
|
||||
vector<float> out_vec{-1.1113307, 0., 1.050701, 2.101402};
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(f);
|
||||
test_case.add_input<float>(in_shape, in_vec);
|
||||
test_case.add_expected_output<float>(in_shape, out_vec);
|
||||
test_case.run_with_tolerance_as_fp(1e-4f);
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, selu_1Dfprop)
|
||||
{
|
||||
Shape in_shape{1};
|
||||
Shape c_shape{1};
|
||||
element::Type et = element::f32;
|
||||
|
||||
auto input = make_shared<op::Parameter>(et, in_shape);
|
||||
auto alpha = op::Constant::create(et, c_shape, {1.67326324});
|
||||
auto lambda = op::Constant::create(et, c_shape, {1.05070098});
|
||||
auto selu = make_shared<op::v0::Selu>(input, alpha, lambda);
|
||||
auto f = make_shared<Function>(selu, ParameterVector{input});
|
||||
|
||||
vector<float> in_vec{112.0};
|
||||
vector<float> out_vec{117.67851};
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(f);
|
||||
test_case.add_input<float>(in_shape, in_vec);
|
||||
test_case.add_expected_output<float>(in_shape, out_vec);
|
||||
test_case.run_with_tolerance_as_fp(1e-4f);
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, selu_3Dfprop_negative)
|
||||
{
|
||||
Shape in_shape{3};
|
||||
Shape c_shape{1};
|
||||
element::Type et = element::f32;
|
||||
|
||||
auto input = make_shared<op::Parameter>(et, in_shape);
|
||||
auto alpha = op::Constant::create(et, c_shape, {1.67326324});
|
||||
auto lambda = op::Constant::create(et, c_shape, {1.05070098});
|
||||
auto selu = make_shared<op::v0::Selu>(input, alpha, lambda);
|
||||
auto f = make_shared<Function>(selu, ParameterVector{input});
|
||||
|
||||
vector<float> in_vec{-3.0, -12.5, -7.0};
|
||||
vector<float> out_vec{-1.6705687, -1.7580928, -1.7564961};
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(f);
|
||||
test_case.add_input<float>(in_shape, in_vec);
|
||||
test_case.add_expected_output<float>(in_shape, out_vec);
|
||||
test_case.run_with_tolerance_as_fp(1e-4f);
|
||||
}
|
171
ngraph/test/type_prop/selu.cpp
Normal file
171
ngraph/test/type_prop/selu.cpp
Normal file
@ -0,0 +1,171 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
TEST(type_prop, selu_basic_inference_f32_3D)
|
||||
{
|
||||
const auto param = make_shared<op::Parameter>(element::f32, Shape{1, 32, 32});
|
||||
const auto alpha = make_shared<op::Parameter>(element::f32, Shape{1});
|
||||
const auto lambda = make_shared<op::Parameter>(element::f32, Shape{1});
|
||||
const auto selu = make_shared<op::Selu>(param, alpha, lambda);
|
||||
|
||||
ASSERT_EQ(selu->get_element_type(), element::f32);
|
||||
ASSERT_EQ(selu->get_shape(), (Shape{1, 32, 32}));
|
||||
}
|
||||
|
||||
TEST(type_prop, selu_basic_inference_f16_3D)
|
||||
{
|
||||
const auto param = make_shared<op::Parameter>(element::f16, Shape{1, 32, 32});
|
||||
const auto alpha = make_shared<op::Parameter>(element::f16, Shape{1});
|
||||
const auto lambda = make_shared<op::Parameter>(element::f16, Shape{1});
|
||||
const auto selu = make_shared<op::Selu>(param, alpha, lambda);
|
||||
|
||||
ASSERT_EQ(selu->get_element_type(), element::f16);
|
||||
ASSERT_EQ(selu->get_shape(), (Shape{1, 32, 32}));
|
||||
}
|
||||
|
||||
TEST(type_prop, selu_basic_inference_f32_5D)
|
||||
{
|
||||
const auto param = make_shared<op::Parameter>(element::f32, Shape{12, 135, 221, 31, 15});
|
||||
const auto alpha = make_shared<op::Parameter>(element::f32, Shape{1});
|
||||
const auto lambda = make_shared<op::Parameter>(element::f32, Shape{1});
|
||||
const auto selu = make_shared<op::Selu>(param, alpha, lambda);
|
||||
|
||||
ASSERT_EQ(selu->get_element_type(), element::f32);
|
||||
ASSERT_EQ(selu->get_shape(), (Shape{12, 135, 221, 31, 15}));
|
||||
}
|
||||
|
||||
TEST(type_prop, selu_basic_inference_f16_5D)
|
||||
{
|
||||
const auto param = make_shared<op::Parameter>(element::f16, Shape{12, 135, 221, 31, 15});
|
||||
const auto alpha = make_shared<op::Parameter>(element::f16, Shape{1});
|
||||
const auto lambda = make_shared<op::Parameter>(element::f16, Shape{1});
|
||||
const auto selu = make_shared<op::Selu>(param, alpha, lambda);
|
||||
|
||||
ASSERT_EQ(selu->get_element_type(), element::f16);
|
||||
ASSERT_EQ(selu->get_shape(), (Shape{12, 135, 221, 31, 15}));
|
||||
}
|
||||
|
||||
TEST(type_prop, selu_incompatible_input_type_boolean)
|
||||
{
|
||||
// Invalid data input element type
|
||||
try
|
||||
{
|
||||
auto data = make_shared<op::Parameter>(element::boolean, Shape{1, 2, 3, 4});
|
||||
const auto alpha = make_shared<op::Parameter>(element::boolean, Shape{1});
|
||||
const auto lambda = make_shared<op::Parameter>(element::boolean, Shape{1});
|
||||
auto selu = make_shared<op::Selu>(data, alpha, lambda);
|
||||
// Data input expected to be of numeric type
|
||||
FAIL() << "Invalid input type not detected";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input element types must be floating-point"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Input type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, selu_incompatible_input_type_i32)
|
||||
{
|
||||
// Invalid data input element type
|
||||
try
|
||||
{
|
||||
auto data = make_shared<op::Parameter>(element::i32, Shape{1, 2, 3, 4});
|
||||
const auto alpha = make_shared<op::Parameter>(element::i32, Shape{1});
|
||||
const auto lambda = make_shared<op::Parameter>(element::i32, Shape{1});
|
||||
auto selu = make_shared<op::Selu>(data, alpha, lambda);
|
||||
// Data input expected to be of numeric type
|
||||
FAIL() << "Invalid input type not detected";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input element types must be floating-point"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Input type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, selu_incompatible_input_type_u16)
|
||||
{
|
||||
// Invalid data input element type
|
||||
try
|
||||
{
|
||||
auto data = make_shared<op::Parameter>(element::u16, Shape{1, 2, 3, 4});
|
||||
const auto alpha = make_shared<op::Parameter>(element::u16, Shape{1});
|
||||
const auto lambda = make_shared<op::Parameter>(element::u16, Shape{1});
|
||||
auto selu = make_shared<op::Selu>(data, alpha, lambda);
|
||||
// Data input expected to be of numeric type
|
||||
FAIL() << "Invalid input type not detected";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input element types must be floating-point"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Input type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, selu_incompatible_input_types)
|
||||
{
|
||||
// Invalid data input element type
|
||||
try
|
||||
{
|
||||
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
|
||||
const auto alpha = make_shared<op::Parameter>(element::f32, Shape{1});
|
||||
const auto lambda = make_shared<op::Parameter>(element::u16, Shape{1});
|
||||
auto selu = make_shared<op::Selu>(data, alpha, lambda);
|
||||
// Data input expected to be of numeric type
|
||||
FAIL() << "Inavlid input types not detected";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Input element types do not match"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Input type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, selu_dynamic_rank_input_shape_2D)
|
||||
{
|
||||
const PartialShape param_shape{Dimension::dynamic(), 10};
|
||||
const auto param = std::make_shared<op::Parameter>(element::f32, param_shape);
|
||||
const auto alpha = make_shared<op::Parameter>(element::f32, Shape{2, 1});
|
||||
const auto lambda = make_shared<op::Parameter>(element::f32, Shape{1});
|
||||
const auto op = std::make_shared<op::Selu>(param, alpha, lambda);
|
||||
ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(PartialShape{Dimension(), 10}));
|
||||
}
|
||||
|
||||
TEST(type_prop, selu_dynamic_rank_input_shape_3D)
|
||||
{
|
||||
const PartialShape param_shape{100, Dimension::dynamic(), 58};
|
||||
const auto param = std::make_shared<op::Parameter>(element::f32, param_shape);
|
||||
const auto alpha = make_shared<op::Parameter>(element::f32, Shape{1});
|
||||
const auto lambda = make_shared<op::Parameter>(element::f32, Shape{1});
|
||||
const auto op = std::make_shared<op::Selu>(param, alpha, lambda);
|
||||
ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(PartialShape{100, Dimension(), 58}));
|
||||
}
|
||||
|
||||
TEST(type_prop, selu_dynamic_rank_input_shape_full)
|
||||
{
|
||||
const auto param = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto alpha = make_shared<op::Parameter>(element::f32, Shape{1});
|
||||
const auto lambda = make_shared<op::Parameter>(element::f32, Shape{1});
|
||||
const auto op = std::make_shared<op::Selu>(param, alpha, lambda);
|
||||
ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
|
||||
}
|
30
ngraph/test/visitors/op/selu.cpp
Normal file
30
ngraph/test/visitors/op/selu.cpp
Normal file
@ -0,0 +1,30 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
|
||||
#include "util/visitor.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
using ngraph::test::NodeBuilder;
|
||||
|
||||
TEST(attributes, selu_op)
|
||||
{
|
||||
NodeBuilder::get_ops().register_factory<opset1::Selu>();
|
||||
const auto data_input = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
|
||||
const auto alpha = make_shared<op::Parameter>(element::f32, Shape{1});
|
||||
const auto lambda = make_shared<op::Parameter>(element::f32, Shape{1});
|
||||
|
||||
const auto op = make_shared<opset1::Selu>(data_input, alpha, lambda);
|
||||
|
||||
NodeBuilder builder(op);
|
||||
const auto expected_attr_count = 0;
|
||||
|
||||
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
|
||||
}
|
Loading…
Reference in New Issue
Block a user