Transformations for hsigmoid op (#2531)

* Add hsigmoid op

* Add tests for hsigmoid

* Add fusion hsigmoid

* Add unit tests for fuse hsigmoid

* Add python api for hsigmoid. Update opset 5

* Update opset5 file

* Add hsigmoid decomposition transformation

* fix

* Move transformations for hsigmoid

* Hot fix

* Fix unit tests

* fix unit tests

* Fix unit test

* Fix code style

* Reverse changes

* Add includes for hsigmoid transformations

* Enable in cldnn

* Refactoring hsigmoid fusion

* Move hsigmoid transforms patterns to cpp file

* Reverse hsigmoid fusion refactoring

* Fix according to code review

* Refactoring transformation

* Hot fix
This commit is contained in:
iliya mironov 2020-10-23 12:35:56 +03:00 committed by GitHub
parent 85b06835aa
commit 0a59be6f1e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 810 additions and 1 deletions

View File

@ -24,6 +24,7 @@
#include <ngraph/opsets/opset2.hpp> #include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp> #include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset4.hpp> #include <ngraph/opsets/opset4.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/pass/manager.hpp> #include <ngraph/pass/manager.hpp>
#include <generic_ie.hpp> #include <generic_ie.hpp>
#include <transformations/control_flow/unroll_tensor_iterator.hpp> #include <transformations/control_flow/unroll_tensor_iterator.hpp>
@ -117,6 +118,7 @@ InferenceEngine::ICNNNetwork::Ptr clDNNEngine::CloneAndTransformNetwork(const In
std::dynamic_pointer_cast<const ::ngraph::opset2::BatchToSpace>(node) || std::dynamic_pointer_cast<const ::ngraph::opset2::BatchToSpace>(node) ||
std::dynamic_pointer_cast<const ::ngraph::opset2::SpaceToBatch>(node) || std::dynamic_pointer_cast<const ::ngraph::opset2::SpaceToBatch>(node) ||
std::dynamic_pointer_cast<const ::ngraph::opset3::ExtractImagePatches>(node) || std::dynamic_pointer_cast<const ::ngraph::opset3::ExtractImagePatches>(node) ||
std::dynamic_pointer_cast<const ::ngraph::opset5::HSigmoid>(node) ||
std::dynamic_pointer_cast<const ::ngraph::opset4::HSwish>(node) || std::dynamic_pointer_cast<const ::ngraph::opset4::HSwish>(node) ||
std::dynamic_pointer_cast<const ::ngraph::opset4::ReduceL1>(node) || std::dynamic_pointer_cast<const ::ngraph::opset4::ReduceL1>(node) ||
std::dynamic_pointer_cast<const ::ngraph::opset4::ReduceL2>(node) || std::dynamic_pointer_cast<const ::ngraph::opset4::ReduceL2>(node) ||

View File

@ -0,0 +1,79 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <utility>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API HSigmoidFusion;
class TRANSFORMATIONS_API HSigmoidFusionWithReluDiv;
class TRANSFORMATIONS_API HSigmoidFusionWithReluMul;
class TRANSFORMATIONS_API HSigmoidFusionWithoutRelu;
class TRANSFORMATIONS_API HSigmoidFusionWithClamp;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief HSigmoidFusion transformation replaces various sub-graphs with a HSigmoid op.
*/
class ngraph::pass::HSigmoidFusion: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
HSigmoidFusion() {
add_matcher<ngraph::pass::HSigmoidFusionWithReluDiv>();
add_matcher<ngraph::pass::HSigmoidFusionWithReluMul>();
add_matcher<ngraph::pass::HSigmoidFusionWithoutRelu>();
add_matcher<ngraph::pass::HSigmoidFusionWithClamp>();
}
};
/**
* @ingroup ie_transformation_common_api
* @brief HSigmoidFusion transformation replaces a sub-graph (x * (min(Relu(x + 3), 6))) / 6 with a HSigmoid op.
*/
class ngraph::pass::HSigmoidFusionWithReluDiv: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
HSigmoidFusionWithReluDiv();
};
/**
* @ingroup ie_transformation_common_api
* @brief HSigmoidFusion transformation replaces a sub-graph (x * (min(Relu(x + 3), 6)) * const(1/6) with a HSigmoid op.
*/
class ngraph::pass::HSigmoidFusionWithReluMul: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
HSigmoidFusionWithReluMul();
};
/**
* @ingroup ie_transformation_common_api
* @brief HSigmoidFusion transformation replaces a sub-graph x * (min(max(x + 3, 0), 6) / 6) with a HSigmoid op.
*/
class ngraph::pass::HSigmoidFusionWithoutRelu: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
HSigmoidFusionWithoutRelu();
};
/**
* @ingroup ie_transformation_common_api
* @brief HSigmoidFusion transformation replaces a sub-graph x * (Clamp(x + 3, 0, 6) * const(1/6)) with a HSigmoid op.
*/
class ngraph::pass::HSigmoidFusionWithClamp: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
HSigmoidFusionWithClamp();
};

View File

@ -0,0 +1,26 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API HSigmoidDecomposition;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief HSigmoidDecomposition transformation into sub-graph (min(Relu(x + 3), 6) * const(1/6).
*/
class ngraph::pass::HSigmoidDecomposition: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
HSigmoidDecomposition();
};

View File

@ -22,6 +22,7 @@
#include "transformations/common_optimizations/pull_transpose_through_fq.hpp" #include "transformations/common_optimizations/pull_transpose_through_fq.hpp"
#include "transformations/common_optimizations/lin_op_sequence_fusion.hpp" #include "transformations/common_optimizations/lin_op_sequence_fusion.hpp"
#include "transformations/common_optimizations/remove_filtering_boxes_by_size.hpp" #include "transformations/common_optimizations/remove_filtering_boxes_by_size.hpp"
#include "transformations/common_optimizations/hsigmoid_fusion.hpp"
#include "transformations/common_optimizations/hswish_fusion.hpp" #include "transformations/common_optimizations/hswish_fusion.hpp"
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp" #include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp" #include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
@ -41,6 +42,7 @@
#include "transformations/op_conversions/reduce_l1_decomposition.hpp" #include "transformations/op_conversions/reduce_l1_decomposition.hpp"
#include "transformations/op_conversions/reduce_l2_decomposition.hpp" #include "transformations/op_conversions/reduce_l2_decomposition.hpp"
#include "transformations/op_conversions/hswish_decomposition.hpp" #include "transformations/op_conversions/hswish_decomposition.hpp"
#include "transformations/op_conversions/hsigmoid_decomposition.hpp"
#include "transformations/op_conversions/log_softmax_decomposition.hpp" #include "transformations/op_conversions/log_softmax_decomposition.hpp"
#include <ngraph/pass/manager.hpp> #include <ngraph/pass/manager.hpp>
@ -68,6 +70,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<ngraph::pass::SoftPlusFusion>(); manager.register_pass<ngraph::pass::SoftPlusFusion>();
manager.register_pass<ngraph::pass::SoftPlusToMishFusion>(); manager.register_pass<ngraph::pass::SoftPlusToMishFusion>();
manager.register_pass<ngraph::pass::SwishFusion>(); manager.register_pass<ngraph::pass::SwishFusion>();
manager.register_pass<ngraph::pass::HSigmoidFusion>();
manager.register_pass<ngraph::pass::HSwishFusion>(); manager.register_pass<ngraph::pass::HSwishFusion>();
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>(); manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
manager.register_pass<ngraph::pass::NormalizeL2Fusion>(); manager.register_pass<ngraph::pass::NormalizeL2Fusion>();
@ -78,6 +81,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>(); auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>();
decomp->add_matcher<ngraph::pass::ReduceL1Decomposition>(); decomp->add_matcher<ngraph::pass::ReduceL1Decomposition>();
decomp->add_matcher<ngraph::pass::ReduceL2Decomposition>(); decomp->add_matcher<ngraph::pass::ReduceL2Decomposition>();
decomp->add_matcher<ngraph::pass::HSigmoidDecomposition>();
decomp->add_matcher<ngraph::pass::HSwishDecomposition>(); decomp->add_matcher<ngraph::pass::HSwishDecomposition>();
decomp->add_matcher<ngraph::pass::LogSoftmaxDecomposition>(); decomp->add_matcher<ngraph::pass::LogSoftmaxDecomposition>();
decomp->add_matcher<ngraph::pass::ConvertReduceMeanToPooling>(); decomp->add_matcher<ngraph::pass::ConvertReduceMeanToPooling>();

View File

@ -0,0 +1,198 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/hsigmoid_fusion.hpp"
#include "transformations/utils/utils.hpp"
#include <memory>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::HSigmoidFusion, "HSigmoidFusion", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::HSigmoidFusionWithReluDiv, "HSigmoidFusionWithReluDiv", 0);
ngraph::pass::HSigmoidFusionWithReluDiv::HSigmoidFusionWithReluDiv() {
// Replaces a sub-graph ((min(Relu(x + 3), 6)) / 6 with a HSigmoid op.
auto input = ngraph::pattern::any_input();
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto relu = std::make_shared<ngraph::opset4::Relu>(add);
auto min_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
auto div_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
auto &pattern_to_output = m.get_pattern_value_map();
auto x_output = pattern_to_output.at(input);
auto add_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
auto min_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
auto div_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(div_constant).get_node_shared_ptr());
bool valid_constant_values = op::util::has_constant_value<float>(add_const_value, 3.0)
&& op::util::has_constant_value<float>(min_const_value, 6.0)
&& op::util::has_constant_value<float>(div_const_value, 6.0);
if (!valid_constant_values) {
return false;
}
auto hsigmoid = std::make_shared<ngraph::opset5::HSigmoid>(x_output);
hsigmoid->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({ pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(relu).get_node_shared_ptr(),
pattern_to_output.at(min).get_node_shared_ptr(),
pattern_to_output.at(div).get_node_shared_ptr(),
},
hsigmoid);
ngraph::replace_node(m.get_match_root(), hsigmoid);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(div, "HSigmoidWithReluDivFusion");
register_matcher(m, callback);
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::HSigmoidFusionWithReluMul, "HSigmoidFusionWithReluMul", 0);
ngraph::pass::HSigmoidFusionWithReluMul::HSigmoidFusionWithReluMul() {
// Replaces a sub-graph ((min(Relu(x + 3), 6)) * const(1/6) with a HSigmoid op.
auto input = ngraph::pattern::any_input();
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto relu = std::make_shared<ngraph::opset4::Relu>(add);
auto min_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
//auto mul_first = std::make_shared<ngraph::opset4::Multiply>(input, min);
auto mul_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto mul_second = std::make_shared<ngraph::opset4::Multiply>(min, mul_constant);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
auto &pattern_to_output = m.get_pattern_value_map();
auto x_output = pattern_to_output.at(input);
auto add_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
auto min_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
auto mul_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(mul_constant).get_node_shared_ptr());
bool valid_constant_values = op::util::has_constant_value<float>(add_const_value, 3.0f)
&& op::util::has_constant_value<float>(min_const_value, 6.0f)
&& op::util::has_constant_value<float>(mul_const_value, (1.0f/6.0f), 0.0001f);
if (!valid_constant_values) {
return false;
}
auto hsigmoid = std::make_shared<ngraph::opset5::HSigmoid>(x_output);
hsigmoid->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({ pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(relu).get_node_shared_ptr(),
pattern_to_output.at(min).get_node_shared_ptr(),
pattern_to_output.at(mul_second).get_node_shared_ptr()
},
hsigmoid);
ngraph::replace_node(m.get_match_root(), hsigmoid);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(mul_second, "HSigmoidWithReluMulFusion");
register_matcher(m, callback);
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::HSigmoidFusionWithoutRelu, "HSigmoidFusionWithoutRelu", 0);
ngraph::pass::HSigmoidFusionWithoutRelu::HSigmoidFusionWithoutRelu() {
// Replaces a sub-graph (min(max(x + 3, 0), 6) / 6) with a HSigmoid op.
auto input = ngraph::pattern::any_input();
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto max_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto max = std::make_shared<ngraph::opset4::Maximum>(add, max_constant);
auto min_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto min = std::make_shared<ngraph::opset4::Minimum>(max, min_constant);
auto div_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
auto mul = std::make_shared<ngraph::opset4::Multiply>(input, div);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
auto &pattern_to_output = m.get_pattern_value_map();
auto x_output = pattern_to_output.at(input);
auto add_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
auto max_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(max_constant).get_node_shared_ptr());
auto min_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
auto div_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(div_constant).get_node_shared_ptr());
bool valid_constant_values = op::util::has_constant_value<float>(add_const_value, 3.0f)
&& op::util::has_constant_value<float>(max_const_value, 0.0f)
&& op::util::has_constant_value<float>(min_const_value, 6.0f)
&& op::util::has_constant_value<float>(div_const_value, 6.0f);
if (!valid_constant_values) {
return false;
}
auto hsigmoid = std::make_shared<ngraph::opset5::HSigmoid>(x_output);
hsigmoid->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({ pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(max).get_node_shared_ptr(),
pattern_to_output.at(min).get_node_shared_ptr(),
pattern_to_output.at(div).get_node_shared_ptr()
},
hsigmoid);
ngraph::replace_node(m.get_match_root(), hsigmoid);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(div, "HSigmoidWithoutReluFusion");
register_matcher(m, callback);
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::HSigmoidFusionWithClamp, "HSigmoidFusionWithClamp", 0);
ngraph::pass::HSigmoidFusionWithClamp::HSigmoidFusionWithClamp() {
// Replaces a sub-graph (Clamp(x + 3, 0, 6) * const(1/6)) with a HSigmoid op.
auto input = ngraph::pattern::any_input();
auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto clamp = std::make_shared<ngraph::op::v0::Clamp>(add, 0.0f, 6.0f);
auto mul_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto mul_first = std::make_shared<ngraph::opset4::Multiply>(clamp, mul_constant);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
auto &pattern_to_output = m.get_pattern_value_map();
auto x_output = pattern_to_output.at(input);
auto add_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
auto mul_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(mul_constant).get_node_shared_ptr());
bool valid_constant_values = op::util::has_constant_value(add_const_value, 3.0)
&& op::util::has_constant_value(mul_const_value, (1.0/6.0), 0.0001);
if (!valid_constant_values) {
return false;
}
auto hsigmoid = std::make_shared<ngraph::opset5::HSigmoid>(x_output);
hsigmoid->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({ pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(clamp).get_node_shared_ptr(),
pattern_to_output.at(mul_first).get_node_shared_ptr()
},
hsigmoid);
ngraph::replace_node(m.get_match_root(), hsigmoid);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(mul_first, "HSigmoidWithClampFusion");
register_matcher(m, callback);
}

View File

@ -0,0 +1,44 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/op_conversions/hsigmoid_decomposition.hpp"
#include <memory>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::HSigmoidDecomposition, "HSigmoidDecomposition", 0);
ngraph::pass::HSigmoidDecomposition::HSigmoidDecomposition() {
// Decomposes HSigmoid(x) op into sub-graph (min(Relu(x + 3), 6) * const(1/6)
auto hsigmoid = ngraph::pattern::wrap_type<opset5::HSigmoid>();
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
auto &pattern_to_output = m.get_pattern_value_map();
auto hsigmoid_node = pattern_to_output.at(hsigmoid).get_node_shared_ptr();
if (m_transformation_callback(hsigmoid_node)) {
return false;
}
auto input_type = hsigmoid_node->input_value(0).get_element_type();
auto add_constant = ngraph::opset5::Constant::create(input_type, ngraph::Shape{}, {3.0});
auto add = std::make_shared<ngraph::opset5::Add>(hsigmoid_node->input_value(0), add_constant);
auto relu = std::make_shared<ngraph::opset5::Relu>(add);
auto min_constant = ngraph::opset5::Constant::create(input_type, ngraph::Shape{}, {6.0});
auto min = register_new_node<ngraph::opset5::Minimum>(relu, min_constant);
auto mul = std::make_shared<ngraph::opset5::Multiply>(hsigmoid_node->input_value(0), min);
mul->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info(hsigmoid_node,
{add_constant, add, relu, min_constant, min, mul});
ngraph::replace_node(m.get_match_root(), mul);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(hsigmoid, "HSigmoidDecomposition");
register_matcher(m, callback);
}

View File

@ -0,0 +1,50 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/op_conversions/hsigmoid_decomposition.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST(TransformationTests, HSigmoidDecompositionTest) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
auto hsigmoid = std::make_shared<ngraph::opset5::HSigmoid>(input);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{hsigmoid}, ngraph::ParameterVector{input});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::HSigmoidDecomposition>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {3.0});
auto add = std::make_shared<ngraph::opset5::Add>(input, add_constant);
auto relu = std::make_shared<ngraph::opset5::Relu>(add);
auto min_constant = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {6.0});
auto min = std::make_shared<ngraph::opset5::Minimum>(relu, min_constant);
auto mul = std::make_shared<ngraph::opset5::Multiply>(input, min);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}

View File

@ -0,0 +1,328 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/common_optimizations/hsigmoid_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST(TransformationTests, HSigmoidFusionWithReluDivF16) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto relu = std::make_shared<ngraph::opset4::Relu>(add);
auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::HSigmoidFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto hsigmoid = std::make_shared<ngraph::opset5::HSigmoid>(input);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hsigmoid}, ngraph::ParameterVector{input});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, HSigmoidFusionWithReluDivF32) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{});
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {3.0});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto relu = std::make_shared<ngraph::opset4::Relu>(add);
auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {6.0});
auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {6.0});
auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::HSigmoidFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{});
auto hsigmoid = std::make_shared<ngraph::opset5::HSigmoid>(input);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hsigmoid}, ngraph::ParameterVector{input});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, HSigmoidFusionWithReluMul) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto relu = std::make_shared<ngraph::opset4::Relu>(add);
auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.1666666716});
auto mul_second = std::make_shared<ngraph::opset4::Multiply>(min, mul_constant);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::HSigmoidFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto hsigmoid = std::make_shared<ngraph::opset5::HSigmoid>(input);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hsigmoid}, ngraph::ParameterVector{input});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, HSigmoidFusionWithoutRelu) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto max_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.0});
auto max = std::make_shared<ngraph::opset4::Maximum>(add, max_constant);
auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
auto min = std::make_shared<ngraph::opset4::Minimum>(max, min_constant);
auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::HSigmoidFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto hsigmoid = std::make_shared<ngraph::opset5::HSigmoid>(input);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hsigmoid}, ngraph::ParameterVector{input});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, HSigmoidFusionWithClamp) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto clamp = std::make_shared<ngraph::op::v0::Clamp>(add, 0.0f, 6.0f);
auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.0 / 6.0});
auto mul_first = std::make_shared<ngraph::opset4::Multiply>(clamp, mul_constant);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_first}, ngraph::ParameterVector{input});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::HSigmoidFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto hsigmoid = std::make_shared<ngraph::opset5::HSigmoid>(input);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hsigmoid}, ngraph::ParameterVector{input});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, HSigmoidFusionWithReluMulWrongConstValue) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto relu = std::make_shared<ngraph::opset4::Relu>(add);
auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.167});
auto mul_second = std::make_shared<ngraph::opset4::Multiply>(min, mul_constant);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::HSigmoidFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto relu = std::make_shared<ngraph::opset4::Relu>(add);
auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.167});
auto mul_second = std::make_shared<ngraph::opset4::Multiply>(min, mul_constant);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, HSigmoidFusionWithReluDivWrongConstValue) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{});
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.01});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto relu = std::make_shared<ngraph::opset4::Relu>(add);
auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.002});
auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.0});
auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::HSigmoidFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{});
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.01});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto relu = std::make_shared<ngraph::opset4::Relu>(add);
auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.002});
auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.0});
auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, HSigmoidFusionWithoutReluWrongConstValue) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto max_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.22});
auto max = std::make_shared<ngraph::opset4::Maximum>(add, max_constant);
auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.01});
auto min = std::make_shared<ngraph::opset4::Minimum>(max, min_constant);
auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.002});
auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::HSigmoidFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto max_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.22});
auto max = std::make_shared<ngraph::opset4::Maximum>(add, max_constant);
auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.01});
auto min = std::make_shared<ngraph::opset4::Minimum>(max, min_constant);
auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.002});
auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, HSigmoidFusionWithClampWrongConstValue) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto clamp = std::make_shared<ngraph::op::v0::Clamp>(add, 0.11f, 6.02f);
auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.98 / 6.15});
auto mul_first = std::make_shared<ngraph::opset4::Multiply>(clamp, mul_constant);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_first}, ngraph::ParameterVector{input});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::HSigmoidFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
auto clamp = std::make_shared<ngraph::op::v0::Clamp>(add, 0.11f, 6.02f);
auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.98 / 6.15});
auto mul_first = std::make_shared<ngraph::opset4::Multiply>(clamp, mul_constant);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_first}, ngraph::ParameterVector{input});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}

View File

@ -36,7 +36,7 @@ namespace ngraph
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
HSigmoid() = default; HSigmoid() = default;
/// \brief Constructs a HSigmoid (hard version of Swish) operation. /// \brief Constructs a HSigmoid operation.
/// ///
/// \param data Input tensor /// \param data Input tensor
HSigmoid(const Output<Node>& arg); HSigmoid(const Output<Node>& arg);

View File

@ -86,6 +86,7 @@ from ngraph.opset5 import group_convolution
from ngraph.opset5 import group_convolution_backprop_data from ngraph.opset5 import group_convolution_backprop_data
from ngraph.opset5 import gru_cell from ngraph.opset5 import gru_cell
from ngraph.opset5 import hard_sigmoid from ngraph.opset5 import hard_sigmoid
from ngraph.opset5 import hsigmoid
from ngraph.opset5 import hswish from ngraph.opset5 import hswish
from ngraph.opset5 import interpolate from ngraph.opset5 import interpolate
from ngraph.opset5 import less from ngraph.opset5 import less

View File

@ -73,6 +73,7 @@ from ngraph.opset1.ops import group_convolution
from ngraph.opset1.ops import group_convolution_backprop_data from ngraph.opset1.ops import group_convolution_backprop_data
from ngraph.opset3.ops import gru_cell from ngraph.opset3.ops import gru_cell
from ngraph.opset1.ops import hard_sigmoid from ngraph.opset1.ops import hard_sigmoid
from ngraph.opset5.ops import hsigmoid
from ngraph.opset4.ops import hswish from ngraph.opset4.ops import hswish
from ngraph.opset1.ops import interpolate from ngraph.opset1.ops import interpolate
from ngraph.opset1.ops import less from ngraph.opset1.ops import less

View File

@ -130,3 +130,13 @@ def round(data: NodeInput, mode: str = "half_to_even", name: Optional[str] = Non
:return: The new node with Round operation applied on each element. :return: The new node with Round operation applied on each element.
""" """
return _get_node_factory_opset5().create("Round", as_nodes(data), {"mode": mode.upper()}) return _get_node_factory_opset5().create("Round", as_nodes(data), {"mode": mode.upper()})
@nameable_op
def hsigmoid(data: NodeInput, name: Optional[str] = None,) -> Node:
"""Return a node which performs HSigmoid.
:param data: Tensor with input data floating point type.
:return: The new node which performs HSigmoid
"""
return _get_node_factory_opset5().create("HSigmoid", as_nodes(data), {})

View File

@ -179,3 +179,14 @@ def test_round_away():
# result = run_op_node([input_tensor], ng.round, "HALF_AWAY_FROM_ZERO") # result = run_op_node([input_tensor], ng.round, "HALF_AWAY_FROM_ZERO")
# assert np.allclose(result, expected) # assert np.allclose(result, expected)
def test_hsigmoid():
float_dtype = np.float32
data = ng.parameter(Shape([3, 10]), dtype=float_dtype, name="data")
node = ng.hsigmoid(data)
assert node.get_type_name() == "HSigmoid"
assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == [3, 10]
assert node.get_output_element_type(0) == Type.f32

View File

@ -132,6 +132,7 @@ set(SRC
type_prop/gru_cell.cpp type_prop/gru_cell.cpp
type_prop/gru_sequence.cpp type_prop/gru_sequence.cpp
type_prop/hard_sigmoid.cpp type_prop/hard_sigmoid.cpp
type_prop/hsigmoid.cpp
type_prop/hswish.cpp type_prop/hswish.cpp
type_prop/interpolate.cpp type_prop/interpolate.cpp
type_prop/lrn.cpp type_prop/lrn.cpp

View File

@ -0,0 +1,54 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(type_prop, hsigmoid)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
auto hsigmoid_func = make_shared<op::v5::HSigmoid>(data);
EXPECT_EQ(hsigmoid_func->get_element_type(), element::f32);
EXPECT_EQ(hsigmoid_func->get_shape(), data->get_output_shape(0));
}
TEST(type_prop, hsigmoid_partial)
{
auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
auto hsigmoid_func = make_shared<op::v5::HSigmoid>(data);
EXPECT_EQ(hsigmoid_func->get_element_type(), element::f32);
ASSERT_TRUE(
hsigmoid_func->get_output_partial_shape(0).same_scheme(data->get_output_partial_shape(0)));
// rank unknown
auto hsigmoid_partial = make_shared<op::v5::HSigmoid>(
make_shared<op::Parameter>(element::f32, PartialShape::dynamic()));
ASSERT_TRUE(hsigmoid_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, hsigmoid_partial_static_rank)
{
auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
auto hsigmoid_func = make_shared<op::v5::HSigmoid>(data);
EXPECT_EQ(hsigmoid_func->get_element_type(), element::f32);
ASSERT_TRUE(
hsigmoid_func->get_output_partial_shape(0).same_scheme(data->get_output_partial_shape(0)));
ASSERT_TRUE(hsigmoid_func->get_output_partial_shape(0).rank().is_static());
}