Revise SoftPlus operation reference implementation 37559 (#5938)
* change threshold to 20 instead of max limit of data type * add invalid input type test case * add invalid input data check * add input type dynamic check * add backend test case * add more clarity on spec and align with real implementation * Revert "[CPU] Fix for CoreThreadingTestsWithIterations tests (#5892)" This reverts commit115aa143ef
. * Revert "Revert "[CPU] Fix for CoreThreadingTestsWithIterations tests (#5892)"" This reverts commit95afa50d94
. * Revert "change threshold to 20 instead of max limit of data type" This reverts commit91af825056
. * Revert "add more clarity on spec and align with real implementation" This reverts commita3b232a8fb
. * add visitor attribute test case * Revert "add visitor attribute test case" This reverts commit610728f1ab
. * add attribute test case * revise the attribute visitor test per parametrized visitor API PR: #6181
This commit is contained in:
parent
77912ca06e
commit
5ce5f9e0c8
@ -29,6 +29,13 @@ bool op::v4::SoftPlus::visit_attributes(AttributeVisitor& visitor)
|
|||||||
void op::v4::SoftPlus::validate_and_infer_types()
|
void op::v4::SoftPlus::validate_and_infer_types()
|
||||||
{
|
{
|
||||||
NGRAPH_OP_SCOPE(v4_SoftPlus_validate_and_infer_types);
|
NGRAPH_OP_SCOPE(v4_SoftPlus_validate_and_infer_types);
|
||||||
|
const element::Type& input_et = get_input_element_type(0);
|
||||||
|
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
input_et.is_dynamic() || input_et.is_real(),
|
||||||
|
"Input element type must be float. Got: ",
|
||||||
|
input_et);
|
||||||
|
|
||||||
set_output_size(1);
|
set_output_size(1);
|
||||||
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
|
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
|
||||||
}
|
}
|
||||||
|
@ -287,6 +287,7 @@ set(SRC
|
|||||||
visitors/op/selu.cpp
|
visitors/op/selu.cpp
|
||||||
visitors/op/shuffle_channels.cpp
|
visitors/op/shuffle_channels.cpp
|
||||||
visitors/op/softmax.cpp
|
visitors/op/softmax.cpp
|
||||||
|
visitors/op/softplus.cpp
|
||||||
visitors/op/space_to_batch.cpp
|
visitors/op/space_to_batch.cpp
|
||||||
visitors/op/space_to_depth.cpp
|
visitors/op/space_to_depth.cpp
|
||||||
visitors/op/split.cpp
|
visitors/op/split.cpp
|
||||||
@ -463,6 +464,7 @@ set(MULTI_TEST_SRC
|
|||||||
backend/sin.in.cpp
|
backend/sin.in.cpp
|
||||||
backend/sinh.in.cpp
|
backend/sinh.in.cpp
|
||||||
backend/softmax.in.cpp
|
backend/softmax.in.cpp
|
||||||
|
backend/softplus.in.cpp
|
||||||
backend/space_to_batch.in.cpp
|
backend/space_to_batch.in.cpp
|
||||||
backend/split.in.cpp
|
backend/split.in.cpp
|
||||||
backend/sqrt.in.cpp
|
backend/sqrt.in.cpp
|
||||||
|
45
ngraph/test/backend/softplus.in.cpp
Normal file
45
ngraph/test/backend/softplus.in.cpp
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cinttypes>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <random>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
#ifdef ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
|
||||||
|
#define DEFAULT_FLOAT_TOLERANCE_BITS ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
|
||||||
|
#define DEFAULT_DOUBLE_TOLERANCE_BITS ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
|
||||||
|
#endif
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
#include "ngraph/ngraph.hpp"
|
||||||
|
#include "util/engine/test_engines.hpp"
|
||||||
|
#include "util/test_case.hpp"
|
||||||
|
#include "util/test_control.hpp"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace ngraph;
|
||||||
|
|
||||||
|
static string s_manifest = "${MANIFEST}";
|
||||||
|
using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, softplus)
|
||||||
|
{
|
||||||
|
auto A = make_shared<op::Parameter>(element::f32, Shape{4});
|
||||||
|
auto softplus = make_shared<op::v4::SoftPlus>(A);
|
||||||
|
auto function = make_shared<Function>(NodeVector{softplus}, ParameterVector{A});
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
test_case.add_input(vector<float>{-1.0, 0.0, 1.0, 20.0});
|
||||||
|
test_case.add_expected_output(
|
||||||
|
vector<float>{0.31326166, 0.69314718, 1.3132616, 20.0});
|
||||||
|
test_case.run();
|
||||||
|
}
|
@ -40,3 +40,23 @@ TEST(type_prop, softplus_partial_static_rank)
|
|||||||
(PartialShape{1, Dimension::dynamic(), 6})));
|
(PartialShape{1, Dimension::dynamic(), 6})));
|
||||||
ASSERT_TRUE(softplus_func->get_output_partial_shape(0).rank().is_static());
|
ASSERT_TRUE(softplus_func->get_output_partial_shape(0).rank().is_static());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, softplus_invalid_element_type)
|
||||||
|
{
|
||||||
|
auto data = make_shared<op::Parameter>(element::i32, Shape{2, 2});
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
auto softplus = make_shared<op::v4::SoftPlus>(data);
|
||||||
|
// Input element type is boolean
|
||||||
|
FAIL() << "Invalid int element type for input not detected";
|
||||||
|
}
|
||||||
|
catch (const NodeValidationFailure& error)
|
||||||
|
{
|
||||||
|
EXPECT_HAS_SUBSTRING(error.what(), "Input element type must be float");
|
||||||
|
}
|
||||||
|
catch (...)
|
||||||
|
{
|
||||||
|
FAIL() << "Numeric element type node validation check failed for unexpected reason";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
12
ngraph/test/visitors/op/softplus.cpp
Normal file
12
ngraph/test/visitors/op/softplus.cpp
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "unary_ops.hpp"
|
||||||
|
|
||||||
|
using Types = ::testing::Types<UnaryOperatorType<ngraph::op::v4::SoftPlus, element::f32>>;
|
||||||
|
|
||||||
|
INSTANTIATE_TYPED_TEST_CASE_P(visitor_without_atrribute,
|
||||||
|
UnaryOperatorVisitor,
|
||||||
|
Types,
|
||||||
|
UnaryOperatorTypeName);
|
Loading…
Reference in New Issue
Block a user