diff --git a/ngraph/core/src/op/softplus.cpp b/ngraph/core/src/op/softplus.cpp index f37008220a4..23f356ae74f 100644 --- a/ngraph/core/src/op/softplus.cpp +++ b/ngraph/core/src/op/softplus.cpp @@ -29,6 +29,13 @@ bool op::v4::SoftPlus::visit_attributes(AttributeVisitor& visitor) void op::v4::SoftPlus::validate_and_infer_types() { NGRAPH_OP_SCOPE(v4_SoftPlus_validate_and_infer_types); + const element::Type& input_et = get_input_element_type(0); + + NODE_VALIDATION_CHECK(this, + input_et.is_dynamic() || input_et.is_real(), + "Input element type must be float. Got: ", + input_et); + set_output_size(1); set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); } diff --git a/ngraph/test/CMakeLists.txt b/ngraph/test/CMakeLists.txt index 2b7757f0527..6f4e2bb0db4 100644 --- a/ngraph/test/CMakeLists.txt +++ b/ngraph/test/CMakeLists.txt @@ -287,6 +287,7 @@ set(SRC visitors/op/selu.cpp visitors/op/shuffle_channels.cpp visitors/op/softmax.cpp + visitors/op/softplus.cpp visitors/op/space_to_batch.cpp visitors/op/space_to_depth.cpp visitors/op/split.cpp @@ -463,6 +464,7 @@ set(MULTI_TEST_SRC backend/sin.in.cpp backend/sinh.in.cpp backend/softmax.in.cpp + backend/softplus.in.cpp backend/space_to_batch.in.cpp backend/split.in.cpp backend/sqrt.in.cpp diff --git a/ngraph/test/backend/softplus.in.cpp b/ngraph/test/backend/softplus.in.cpp new file mode 100644 index 00000000000..c1e45a10c23 --- /dev/null +++ b/ngraph/test/backend/softplus.in.cpp @@ -0,0 +1,45 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include + +// clang-format off +#ifdef ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS +#define DEFAULT_FLOAT_TOLERANCE_BITS ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS +#endif + +#ifdef ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS +#define DEFAULT_DOUBLE_TOLERANCE_BITS ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS +#endif +// clang-format on + +#include "gtest/gtest.h" +#include "ngraph/ngraph.hpp" +#include "util/engine/test_engines.hpp" +#include "util/test_case.hpp" +#include "util/test_control.hpp" + +using namespace std; +using namespace ngraph; + +static string s_manifest = "${MANIFEST}"; +using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME}); + +NGRAPH_TEST(${BACKEND_NAME}, softplus) +{ + auto A = make_shared(element::f32, Shape{4}); + auto softplus = make_shared(A); + auto function = make_shared(NodeVector{softplus}, ParameterVector{A}); + + auto test_case = test::TestCase(function); + test_case.add_input(vector{-1.0, 0.0, 1.0, 20.0}); + test_case.add_expected_output( + vector{0.31326166, 0.69314718, 1.3132616, 20.0}); + test_case.run(); +} diff --git a/ngraph/test/type_prop/softplus.cpp b/ngraph/test/type_prop/softplus.cpp index 7d6ddb49144..d0cec9fdd27 100644 --- a/ngraph/test/type_prop/softplus.cpp +++ b/ngraph/test/type_prop/softplus.cpp @@ -40,3 +40,23 @@ TEST(type_prop, softplus_partial_static_rank) (PartialShape{1, Dimension::dynamic(), 6}))); ASSERT_TRUE(softplus_func->get_output_partial_shape(0).rank().is_static()); } + +TEST(type_prop, softplus_invalid_element_type) +{ + auto data = make_shared(element::i32, Shape{2, 2}); + + try + { + auto softplus = make_shared(data); + // Input element type is boolean + FAIL() << "Invalid int element type for input not detected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), "Input element type must be float"); + } + catch (...) + { + FAIL() << "Numeric element type node validation check failed for unexpected reason"; + } +} diff --git a/ngraph/test/visitors/op/softplus.cpp b/ngraph/test/visitors/op/softplus.cpp new file mode 100644 index 00000000000..e5ff77e3a4e --- /dev/null +++ b/ngraph/test/visitors/op/softplus.cpp @@ -0,0 +1,12 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "unary_ops.hpp" + +using Types = ::testing::Types>; + +INSTANTIATE_TYPED_TEST_CASE_P(visitor_without_atrribute, + UnaryOperatorVisitor, + Types, + UnaryOperatorTypeName);