Enable HSwish - ngraph op, fusion/decomposition and reference implementation (#1770)

* Add HSwish operator to nGraph

* Add HSwishFusion transformation

* Update check_constant function

* Add reference implementation for HSwish

* Enable reference implementation in HSwish evaluate

* Add op_eval test

* HSwish fusion transformation test

* Add HSwishFusionWithoutRelu transformation

* Add more hswish fusion tests

* Register HSwishFusion pass in common_optimizations

* Update HSwish reference implementation

* Add HSwishFusion with Relu and Multiply

* Add HSwishDecomposition transformation pass

* Add HSwishDecomposition test

* Add HSwish op to ngraph python API

* Update HSwish fusion transformations

* Remove HSwishFusion from common optimizations

* Update hswish python API

* Add bf16 to evaluate hswish

* Update hswish python API

* Move hswish reference implementation

* UnaryElementwiseArithmetic inheritance

* Enable HSwish callback for clDNN

* Register HSwishDecomposition pass in ConvertOpSet1ToLegacy

* Enable HSwishFusion pass in common optimizations

* Use NGRAPH_RTTI_DECLARATION

* Moved python hswish test to the test_ops_unary
This commit is contained in:
Katarzyna Mitrus
2020-08-19 07:04:00 +02:00
committed by GitHub
parent 6b04eca3c2
commit ceb8a25c94
21 changed files with 945 additions and 0 deletions

View File

@@ -91,6 +91,7 @@ InferenceEngine::ICNNNetwork::Ptr clDNNEngine::CloneAndTransformNetwork(const In
std::dynamic_pointer_cast<const ::ngraph::opset3::ShuffleChannels>(node) ||
std::dynamic_pointer_cast<const ::ngraph::opset2::BatchToSpace>(node) ||
std::dynamic_pointer_cast<const ::ngraph::opset2::SpaceToBatch>(node) ||
std::dynamic_pointer_cast<const ::ngraph::opset4::HSwish>(node) ||
std::dynamic_pointer_cast<const ::ngraph::opset4::ReduceL1>(node) ||
std::dynamic_pointer_cast<const ::ngraph::opset4::ReduceL2>(node);
};

View File

@@ -0,0 +1,25 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API HSwishDecomposition;
} // namespace pass
} // namespace ngraph
/**
 * @ingroup ie_transformation_common_api
 * @brief HSwishDecomposition transformation decomposes HSwish into the
 * sub-graph x * min(Relu(x + 3), 6) * const(1/6).
 */
class ngraph::pass::HSwishDecomposition: public ngraph::pass::MatcherPass {
public:
    HSwishDecomposition();
};

View File

@@ -0,0 +1,63 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <utility>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API HSwishFusion;
class TRANSFORMATIONS_API HSwishFusionWithReluDiv;
class TRANSFORMATIONS_API HSwishFusionWithReluMul;
class TRANSFORMATIONS_API HSwishFusionWithoutRelu;
} // namespace pass
} // namespace ngraph
/**
 * @ingroup ie_transformation_common_api
 * @brief HSwishFusion transformation replaces various sub-graphs with a HSwish op.
 * Container pass: groups all HSwish fusion patterns so a single
 * register_pass<HSwishFusion>() call enables all of them.
 */
class ngraph::pass::HSwishFusion: public ngraph::pass::GraphRewrite {
public:
    HSwishFusion() {
        // Each matcher recognizes one equivalent formulation of HSwish.
        add_matcher<ngraph::pass::HSwishFusionWithReluDiv>();
        add_matcher<ngraph::pass::HSwishFusionWithReluMul>();
        add_matcher<ngraph::pass::HSwishFusionWithoutRelu>();
    }
};
/**
 * @ingroup ie_transformation_common_api
 * @brief HSwishFusion transformation replaces a sub-graph (x * min(Relu(x + 3), 6)) / 6 with a HSwish op.
 */
class ngraph::pass::HSwishFusionWithReluDiv: public ngraph::pass::MatcherPass {
public:
    HSwishFusionWithReluDiv();
};
/**
 * @ingroup ie_transformation_common_api
 * @brief HSwishFusion transformation replaces a sub-graph x * min(Relu(x + 3), 6) * const(1/6) with a HSwish op.
 */
class ngraph::pass::HSwishFusionWithReluMul: public ngraph::pass::MatcherPass {
public:
    HSwishFusionWithReluMul();
};
/**
 * @ingroup ie_transformation_common_api
 * @brief HSwishFusion transformation replaces a sub-graph x * (min(max(x + 3, 0), 6) / 6) with a HSwish op.
 */
class ngraph::pass::HSwishFusionWithoutRelu: public ngraph::pass::MatcherPass {
public:
    HSwishFusionWithoutRelu();
};

View File

@@ -14,6 +14,7 @@
#include "transformations/itt.hpp"
#include "transformations/mish_fusion.hpp"
#include "transformations/swish_fusion.hpp"
#include "transformations/hswish_fusion.hpp"
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/nop_elimination.hpp>
@@ -37,6 +38,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
manager.register_pass<ngraph::pass::MishFusion>();
manager.register_pass<ngraph::pass::SwishFusion>();
manager.register_pass<ngraph::pass::HSwishFusion>();
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution>();
manager.set_callback(m_transformation_callback);

View File

@@ -47,6 +47,7 @@
#include <transformations/convert_opset1_to_legacy/convert_hard_sigmoid_to_hard_sigmoid_ie.hpp>
#include <transformations/lin_op_sequence_fusoin.hpp>
#include <transformations/common_optimizations/conv_mul_fusion.hpp>
#include <transformations/hswish_decomposition.hpp>
#include <transformations/reduce_l1_decomposition.hpp>
#include <transformations/reduce_l2_decomposition.hpp>
@@ -71,6 +72,10 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
manager.register_pass<ngraph::pass::ReduceL1Decomposition>();
manager.register_pass<ngraph::pass::ReduceL2Decomposition>();
    // HSwishDecomposition produces Minimum, Relu and Multiply operations,
    // so it must be executed before the conversion passes that handle those ops
manager.register_pass<ngraph::pass::HSwishDecomposition>();
    // List of Decomposition and Conversion transformations that can be
// applied simultaneously in a single graph traversal
auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>();

View File

@@ -0,0 +1,44 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/hswish_decomposition.hpp"
#include <memory>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
ngraph::pass::HSwishDecomposition::HSwishDecomposition() {
    // Decomposes HSwish(x) into the sub-graph x * min(Relu(x + 3), 6) * const(1/6)
    auto hswish_pattern = ngraph::pattern::wrap_type<opset4::HSwish>();

    ngraph::matcher_pass_callback matcher_callback = [=](ngraph::pattern::Matcher& matcher) {
        const auto& pattern_map = matcher.get_pattern_value_map();
        auto hswish_op = pattern_map.at(hswish_pattern).get_node_shared_ptr();

        // Respect the plugin callback: skip nodes the device prefers to keep as HSwish.
        if (m_transformation_callback(hswish_op)) {
            return false;
        }

        const auto input_value = hswish_op->input_value(0);
        const auto input_type = input_value.get_element_type();

        // x + 3
        auto add_const = ngraph::opset4::Constant::create(input_type, ngraph::Shape{}, {3.0});
        auto add_op = std::make_shared<ngraph::opset4::Add>(input_value, add_const);
        // Relu(x + 3)
        auto relu_op = std::make_shared<ngraph::opset4::Relu>(add_op);
        // min(Relu(x + 3), 6)
        auto min_const = ngraph::opset4::Constant::create(input_type, ngraph::Shape{}, {6.0});
        auto min_op = std::make_shared<ngraph::opset4::Minimum>(relu_op, min_const);
        // x * min(Relu(x + 3), 6)
        auto mul_op = std::make_shared<ngraph::opset4::Multiply>(input_value, min_op);
        // scale by const(1/6)
        auto scale_const = ngraph::opset4::Constant::create(input_type, ngraph::Shape{}, {(1.0/6.0)});
        auto result_op = std::make_shared<ngraph::opset4::Multiply>(mul_op, scale_const);

        // Preserve the original node's name and runtime info on the new sub-graph.
        result_op->set_friendly_name(matcher.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info(hswish_op,
                                  {add_const, add_op, relu_op, min_const, min_op, mul_op, scale_const, result_op});
        ngraph::replace_node(matcher.get_match_root(), result_op);
        return true;
    };

    auto matcher = std::make_shared<ngraph::pattern::Matcher>(hswish_pattern, "HSwishDecomposition");
    register_matcher(matcher, matcher_callback);
}

View File

@@ -0,0 +1,180 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/hswish_fusion.hpp"
#include <memory>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
// Checks that `constant` is a scalar f16/f32 constant whose value is within
// `epsilon` of `value`. Returns false for null constants, non-floating-point
// element types, and non-scalar shapes.
//
// `static` gives this helper internal linkage so it cannot clash at link time
// with identically named helpers defined in other transformation sources (ODR).
static bool check_constant_value(const std::shared_ptr<ngraph::opset4::Constant>& constant,
                                 const float value,
                                 float epsilon = std::numeric_limits<float>::epsilon()) {
    if (!constant) {
        return false;
    }
    // Only floating-point constants are expected in HSwish-like sub-graphs.
    if (constant->get_element_type() != ngraph::element::f32 && constant->get_element_type() != ngraph::element::f16) {
        return false;
    }
    const auto data = constant->cast_vector<float>();
    // Must be a single (scalar) value close enough to the expected one.
    return data.size() == 1 && std::fabs(data[0] - value) <= epsilon;
}
ngraph::pass::HSwishFusionWithReluDiv::HSwishFusionWithReluDiv() {
    // Replaces a sub-graph (x * min(Relu(x + 3), 6)) / 6 with a HSwish op.
    // Build the pattern bottom-up: the same `input` node is used twice, so the
    // matcher enforces that both branches start from the same producer.
    auto input = ngraph::pattern::any_input();
    auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
    auto relu = std::make_shared<ngraph::opset4::Relu>(add);
    auto min_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
    auto mul = std::make_shared<ngraph::opset4::Multiply>(input, min);
    auto div_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto div = std::make_shared<ngraph::opset4::Divide>(mul, div_constant);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
        auto &pattern_to_output = m.get_pattern_value_map();
        auto x_output = pattern_to_output.at(input);

        auto add_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
        auto min_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
        auto div_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(div_constant).get_node_shared_ptr());

        // The sub-graph is HSwish only for the exact constants 3, 6 and 6.
        bool valid_constant_values = check_constant_value(add_const_value, 3.0)
                                        && check_constant_value(min_const_value, 6.0)
                                        && check_constant_value(div_const_value, 6.0);
        if (!valid_constant_values) {
            return false;
        }

        auto hswish = std::make_shared<ngraph::opset4::HSwish>(x_output);

        // Keep the match root's name and carry runtime info from every fused node.
        hswish->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(add).get_node_shared_ptr(),
                                    pattern_to_output.at(relu).get_node_shared_ptr(),
                                    pattern_to_output.at(min_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(min).get_node_shared_ptr(),
                                    pattern_to_output.at(mul).get_node_shared_ptr(),
                                    pattern_to_output.at(div_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(div).get_node_shared_ptr(),
                                  },
                                  hswish);
        ngraph::replace_node(m.get_match_root(), hswish);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(div, "HSwishWithReluDivFusion");
    register_matcher(m, callback);
}
ngraph::pass::HSwishFusionWithReluMul::HSwishFusionWithReluMul() {
    // Replaces a sub-graph x * min(Relu(x + 3), 6) * const(1/6) with a HSwish op.
    auto input = ngraph::pattern::any_input();
    auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
    auto relu = std::make_shared<ngraph::opset4::Relu>(add);
    auto min_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
    auto mul_first = std::make_shared<ngraph::opset4::Multiply>(input, min);
    auto mul_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto mul_second = std::make_shared<ngraph::opset4::Multiply>(mul_first, mul_constant);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
        auto &pattern_to_output = m.get_pattern_value_map();
        auto x_output = pattern_to_output.at(input);

        auto add_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
        auto min_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
        auto mul_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(mul_constant).get_node_shared_ptr());

        // 1/6 cannot be represented exactly in float, so the scale constant is
        // compared with a relaxed tolerance (1e-4) instead of machine epsilon.
        bool valid_constant_values = check_constant_value(add_const_value, 3.0)
                                        && check_constant_value(min_const_value, 6.0)
                                        && check_constant_value(mul_const_value, (1.0/6.0), 0.0001);
        if (!valid_constant_values) {
            return false;
        }

        auto hswish = std::make_shared<ngraph::opset4::HSwish>(x_output);

        // Keep the match root's name and carry runtime info from every fused node.
        hswish->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(add).get_node_shared_ptr(),
                                    pattern_to_output.at(relu).get_node_shared_ptr(),
                                    pattern_to_output.at(min_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(min).get_node_shared_ptr(),
                                    pattern_to_output.at(mul_first).get_node_shared_ptr(),
                                    pattern_to_output.at(mul_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(mul_second).get_node_shared_ptr()
                                  },
                                  hswish);
        ngraph::replace_node(m.get_match_root(), hswish);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(mul_second, "HSwishWithReluMulFusion");
    register_matcher(m, callback);
}
ngraph::pass::HSwishFusionWithoutRelu::HSwishFusionWithoutRelu() {
    // Replaces a sub-graph x * (min(max(x + 3, 0), 6) / 6) with a HSwish op.
    // Same as the Relu-based variants, but the clamp is spelled max(..., 0).
    auto input = ngraph::pattern::any_input();
    auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
    auto max_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto max = std::make_shared<ngraph::opset4::Maximum>(add, max_constant);
    auto min_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto min = std::make_shared<ngraph::opset4::Minimum>(max, min_constant);
    auto div_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
    auto mul = std::make_shared<ngraph::opset4::Multiply>(input, div);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
        auto &pattern_to_output = m.get_pattern_value_map();
        auto x_output = pattern_to_output.at(input);

        auto add_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
        auto max_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(max_constant).get_node_shared_ptr());
        auto min_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
        auto div_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(div_constant).get_node_shared_ptr());

        // The sub-graph is HSwish only for the exact constants 3, 0, 6 and 6.
        bool valid_constant_values = check_constant_value(add_const_value, 3.0)
                                        && check_constant_value(max_const_value, 0.0)
                                        && check_constant_value(min_const_value, 6.0)
                                        && check_constant_value(div_const_value, 6.0);
        if (!valid_constant_values) {
            return false;
        }

        auto hswish = std::make_shared<ngraph::opset4::HSwish>(x_output);

        // Keep the match root's name and carry runtime info from every fused node.
        hswish->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(add).get_node_shared_ptr(),
                                    pattern_to_output.at(max_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(max).get_node_shared_ptr(),
                                    pattern_to_output.at(min_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(min).get_node_shared_ptr(),
                                    pattern_to_output.at(div_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(div).get_node_shared_ptr(),
                                    pattern_to_output.at(mul).get_node_shared_ptr()
                                  },
                                  hswish);
        ngraph::replace_node(m.get_match_root(), hswish);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "HSwishWithoutReluFusion");
    register_matcher(m, callback);
}

View File

@@ -0,0 +1,52 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/hswish_decomposition.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
// Checks that HSwishDecomposition lowers a single HSwish node into the
// x * min(Relu(x + 3), 6) * const(1/6) sub-graph.
TEST(TransformationTests, HSwishDecompositionTest) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
        auto hswish = std::make_shared<ngraph::opset4::HSwish>(input);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{hswish}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::HSwishDecomposition>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // The reference graph must use the same element type (f32) as the
        // transformed function above: the decomposition creates its constants
        // with the input's element type. (It was f16 before, which only passed
        // because compare_functions did not check element types.)
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {3.0});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {6.0});
        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
        auto mul_first = std::make_shared<ngraph::opset4::Multiply>(input, min);
        // 0.1666666716 is the nearest float to 1/6 — must match what the pass creates.
        auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {0.1666666716});
        auto mul_second = std::make_shared<ngraph::opset4::Multiply>(mul_first, mul_constant);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}

View File

@@ -0,0 +1,274 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/hswish_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
// (x * min(Relu(x + 3), 6)) / 6 built in f16 must be fused into a single HSwish op.
TEST(TransformationTests, HSwishFusionWithReluDivF16) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, min);
        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
        auto div = std::make_shared<ngraph::opset4::Divide>(mul, div_constant);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::HSwishFusionWithReluDiv>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference: the whole sub-graph collapses to one HSwish node.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto hswish = std::make_shared<ngraph::opset4::HSwish>(input);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hswish}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// Same fusion as above, but in f32 and with a static scalar shape.
TEST(TransformationTests, HSwishFusionWithReluDivF32) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{});
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {3.0});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {6.0});
        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, min);
        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {6.0});
        auto div = std::make_shared<ngraph::opset4::Divide>(mul, div_constant);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::HSwishFusionWithReluDiv>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference: the whole sub-graph collapses to one HSwish node.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{});
        auto hswish = std::make_shared<ngraph::opset4::HSwish>(input);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hswish}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// x * min(Relu(x + 3), 6) * const(~1/6) must be fused into a single HSwish op.
TEST(TransformationTests, HSwishFusionWithReluMul) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
        auto mul_first = std::make_shared<ngraph::opset4::Multiply>(input, min);
        // 0.1666666716 == float(1/6); within the pass's 1e-4 tolerance.
        auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.1666666716});
        auto mul_second = std::make_shared<ngraph::opset4::Multiply>(mul_first, mul_constant);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::HSwishFusionWithReluMul>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference: the whole sub-graph collapses to one HSwish node.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto hswish = std::make_shared<ngraph::opset4::HSwish>(input);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hswish}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// x * (min(max(x + 3, 0), 6) / 6) — clamp written with Maximum instead of Relu —
// must be fused into a single HSwish op.
TEST(TransformationTests, HSwishFusionWithoutRelu) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto max_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.0});
        auto max = std::make_shared<ngraph::opset4::Maximum>(add, max_constant);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
        auto min = std::make_shared<ngraph::opset4::Minimum>(max, min_constant);
        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
        auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, div);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::HSwishFusionWithoutRelu>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference: the whole sub-graph collapses to one HSwish node.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto hswish = std::make_shared<ngraph::opset4::HSwish>(input);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hswish}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// Negative case: the scale constant 0.167 differs from 1/6 by more than the
// pass's 1e-4 tolerance, so NO fusion must happen (f_ref is the same graph).
TEST(TransformationTests, HSwishFusionWithReluMulWrongConstValue) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
        auto mul_first = std::make_shared<ngraph::opset4::Multiply>(input, min);
        auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.167});
        auto mul_second = std::make_shared<ngraph::opset4::Multiply>(mul_first, mul_constant);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::HSwishFusionWithReluMul>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference graph is identical to the input graph — nothing was fused.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
        auto mul_first = std::make_shared<ngraph::opset4::Multiply>(input, min);
        auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.167});
        auto mul_second = std::make_shared<ngraph::opset4::Multiply>(mul_first, mul_constant);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// Negative case: constants 3.01 / 6.002 / 0.0 do not match 3 / 6 / 6,
// so NO fusion must happen (f_ref is the same graph).
TEST(TransformationTests, HSwishFusionWithReluDivWrongConstValue) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{});
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.01});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.002});
        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, min);
        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.0});
        auto div = std::make_shared<ngraph::opset4::Divide>(mul, div_constant);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::HSwishFusionWithReluDiv>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference graph is identical to the input graph — nothing was fused.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{});
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.01});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.002});
        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, min);
        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.0});
        auto div = std::make_shared<ngraph::opset4::Divide>(mul, div_constant);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// Negative case: constants 3.11 / 0.22 / 6.01 / 6.002 do not match 3 / 0 / 6 / 6,
// so NO fusion must happen (f_ref is the same graph).
TEST(TransformationTests, HSwishFusionWithoutReluWrongConstValue) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto max_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.22});
        auto max = std::make_shared<ngraph::opset4::Maximum>(add, max_constant);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.01});
        auto min = std::make_shared<ngraph::opset4::Minimum>(max, min_constant);
        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.002});
        auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, div);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::HSwishFusionWithoutRelu>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference graph is identical to the input graph — nothing was fused.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto max_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.22});
        auto max = std::make_shared<ngraph::opset4::Maximum>(add, max_constant);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.01});
        auto min = std::make_shared<ngraph::opset4::Minimum>(max, min_constant);
        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.002});
        auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, div);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}

View File

@@ -0,0 +1,53 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
namespace ngraph
{
namespace op
{
namespace v4
{
/// \brief A HSwish Activation Function
/// f(x) = x * min(max(x + 3, 0), 6) / 6 or
/// f(x) = x * min(ReLU(x + 3), 6) / 6
///
/// Unary elementwise op (opset4): output shape and element type follow the
/// single input tensor. evaluate() provides a reference implementation for
/// host tensors (bf16/f16/f32 per the dispatch in hswish.cpp).
class NGRAPH_API HSwish : public ngraph::op::util::UnaryElementwiseArithmetic
{
public:
NGRAPH_RTTI_DECLARATION;
HSwish() = default;
/// \brief Constructs a HSwish (hard version of Swish) operation.
///
/// \param data Input tensor
HSwish(const Output<Node>& arg);
/// HSwish carries no attributes; visiting always succeeds.
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
/// Constant-folding / reference evaluation entry point.
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
};
}
}
}

View File

@@ -75,6 +75,7 @@
#include "ngraph/op/group_conv.hpp"
#include "ngraph/op/gru_cell.hpp"
#include "ngraph/op/hard_sigmoid.hpp"
#include "ngraph/op/hswish.hpp"
#include "ngraph/op/interpolate.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"

View File

@@ -156,6 +156,7 @@ NGRAPH_OP(Acosh, ngraph::op::v3)
NGRAPH_OP(Asinh, ngraph::op::v3)
NGRAPH_OP(Atanh, ngraph::op::v3)
NGRAPH_OP(CTCLoss, ngraph::op::v4)
NGRAPH_OP(HSwish, ngraph::op::v4)
NGRAPH_OP(NonMaxSuppression, ngraph::op::v4)
NGRAPH_OP(Mish, ngraph::op::v4)
NGRAPH_OP(ReduceL1, ngraph::op::v4)

View File

@@ -0,0 +1,38 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <algorithm>
#include <cmath>
#include <cstddef>
namespace ngraph
{
namespace runtime
{
namespace reference
{
/// \brief Reference implementation of the HSwish activation:
///        f(x) = x * min(max(x + 3, 0), 6) / 6
///
/// \param arg   Pointer to the input buffer.
/// \param out   Pointer to the output buffer (same element count as arg).
/// \param count Number of elements to process.
///
/// Note: std::min/std::max require <algorithm>; previously this header relied
/// on a transitive include.
template <typename T>
void hswish(const T* arg, T* out, size_t count)
{
    for (size_t i = 0; i < count; i++)
    {
        // Hard-sigmoid gate clamped to [0, 6], then scaled by 1/6.
        out[i] = arg[i] * std::min<T>(std::max<T>(arg[i] + 3.0f, 0.0f), 6.0f) / 6.0f;
    }
}
}
}
}

View File

@@ -0,0 +1,78 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/hswish.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/hswish.hpp"
using namespace std;
using namespace ngraph;
NGRAPH_RTTI_DEFINITION(op::v4::HSwish, "HSwish", 4);
// HSwish is unary elementwise: the base class propagates the input's element
// type and shape, finalized by constructor_validate_and_infer_types().
op::v4::HSwish::HSwish(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
constructor_validate_and_infer_types();
}
// HSwish has no attributes to serialize, so visiting trivially succeeds.
bool op::v4::HSwish::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
// Creates a copy of this op wired to the given inputs.
shared_ptr<Node> op::v4::HSwish::clone_with_new_inputs(const OutputVector& new_args) const
{
    // Standard nGraph clone contract: reject a wrong number of new inputs
    // instead of silently ignoring extras / throwing out_of_range.
    check_new_args_count(this, new_args);
    return make_shared<op::v4::HSwish>(new_args.at(0));
}
namespace
{
template <element::Type_t ET>
inline bool evaluate(const HostTensorPtr& arg, const HostTensorPtr& out, const size_t count)
{
using T = typename element_type_traits<ET>::value_type;
runtime::reference::hswish<T>(arg->get_data_ptr<ET>(), out->get_data_ptr<ET>(), count);
return true;
}
bool evaluate_hswish(const HostTensorPtr& arg, const HostTensorPtr& out, const size_t count)
{
bool rc = true;
out->set_unary(arg);
switch (arg->get_element_type())
{
TYPE_CASE(bf16)(arg, out, count);
break;
TYPE_CASE(f16)(arg, out, count);
break;
TYPE_CASE(f32)(arg, out, count);
break;
default: rc = false; break;
}
return rc;
}
}
// Reference evaluation over host tensors (used for constant folding).
bool op::v4::HSwish::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
{
    // Use the runtime input tensor's concrete shape, not the statically
    // inferred output shape: get_output_shape(0) throws when the output
    // partial shape is dynamic, whereas the host tensor always carries a
    // fully static shape at evaluation time.
    return evaluate_hswish(inputs[0], outputs[0], shape_size(inputs[0]->get_shape()));
}

View File

@@ -84,6 +84,7 @@ from ngraph.opset4 import group_convolution
from ngraph.opset4 import group_convolution_backprop_data
from ngraph.opset4 import gru_cell
from ngraph.opset4 import hard_sigmoid
from ngraph.opset4 import hswish
from ngraph.opset4 import interpolate
from ngraph.opset4 import less
from ngraph.opset4 import less_equal

View File

@@ -72,6 +72,7 @@ from ngraph.opset1.ops import group_convolution
from ngraph.opset1.ops import group_convolution_backprop_data
from ngraph.opset3.ops import gru_cell
from ngraph.opset1.ops import hard_sigmoid
from ngraph.opset4.ops import hswish
from ngraph.opset1.ops import interpolate
from ngraph.opset1.ops import less
from ngraph.opset1.ops import less_equal

View File

@@ -149,6 +149,16 @@ def mish(data: NodeInput, name: Optional[str] = None,) -> Node:
return _get_node_factory_opset4().create("Mish", as_nodes(data), {})
@nameable_op
def hswish(data: NodeInput, name: Optional[str] = None,) -> Node:
    """Return a node which performs HSwish (hard version of Swish).

    HSwish(x) = x * min(max(x + 3, 0), 6) / 6

    :param data: Tensor with input data floating point type.
    :param name: Optional name for the output node.
    :return: The new node which performs HSwish
    """
    return _get_node_factory_opset4().create("HSwish", as_nodes(data), {})
@nameable_op
def swish(
data: NodeInput,

View File

@@ -17,6 +17,7 @@ import numpy as np
import pytest
import ngraph as ng
from ngraph.impl import Shape, Type
from tests.test_ngraph.util import run_op_node, run_op_numeric_data
from tests import xfail_issue_35929, xfail_issue_34323
@@ -148,3 +149,14 @@ def test_erf():
result = run_op_numeric_data(input_tensor, ng.erf)
assert np.allclose(result, expected)
def test_hswish():
    """HSwish node creation preserves input shape and f32 element type."""
    dtype = np.float32
    param = ng.parameter(Shape([3, 10]), dtype=dtype, name="data")

    hswish_node = ng.hswish(param)

    assert hswish_node.get_type_name() == "HSwish"
    assert hswish_node.get_output_size() == 1
    assert list(hswish_node.get_output_shape(0)) == [3, 10]
    assert hswish_node.get_output_element_type(0) == Type.f32

View File

@@ -70,6 +70,7 @@ set(SRC
node_input_output.cpp
nop_elimination.cpp
op.cpp
op_eval/hswish.cpp
op_eval/matmul.cpp
op_eval/mish.cpp
op_eval/non_zero.cpp
@@ -127,6 +128,7 @@ set(SRC
type_prop/group_convolution_backprop_data.cpp
type_prop/gru_cell.cpp
type_prop/hard_sigmoid.cpp
type_prop/hswish.cpp
type_prop/lrn.cpp
type_prop/lstm_cell.cpp
type_prop/lstm_sequence.cpp

View File

@@ -0,0 +1,48 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/op/hswish.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/validation_util.hpp"
#include "runtime/backend.hpp"
#include "util/test_tools.hpp"
using namespace std;
using namespace ngraph;
// Checks the HSwish reference evaluation end-to-end through Function::evaluate:
// HSwish(x) = x * min(max(x + 3, 0), 6) / 6.
TEST(op_eval, hswish)
{
    auto p = make_shared<op::Parameter>(element::f32, Shape{3});
    auto swish = make_shared<op::v4::HSwish>(p);
    auto fun = make_shared<Function>(OutputVector{swish}, ParameterVector{p});

    std::vector<float> inputs{-0.5f, 0.0f, 0.5f};
    std::vector<float> expected_result{-0.208333f, 0.0f, 0.29166667f};

    auto result = make_shared<HostTensor>();
    ASSERT_TRUE(
        fun->evaluate({result}, {make_host_tensor<element::Type_t::f32>(Shape{3}, inputs)}));
    EXPECT_EQ(result->get_element_type(), element::f32);
    EXPECT_EQ(result->get_shape(), Shape{3});
    auto result_data = read_vector<float>(result);
    // size_t index avoids the signed/unsigned comparison against size().
    for (size_t i = 0; i < inputs.size(); i++)
        EXPECT_NEAR(result_data[i], expected_result[i], 0.000001);
}

View File

@@ -0,0 +1,54 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
// Static-shape input: HSwish propagates both element type and shape unchanged.
TEST(type_prop, hswish)
{
    auto input = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
    auto hswish_op = make_shared<op::v4::HSwish>(input);

    EXPECT_EQ(hswish_op->get_element_type(), element::f32);
    EXPECT_EQ(hswish_op->get_shape(), input->get_output_shape(0));
}
// Partially dynamic input shape is propagated with the same scheme; a fully
// dynamic rank stays fully dynamic.
TEST(type_prop, hswish_partial)
{
    auto input = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
    auto hswish_op = make_shared<op::v4::HSwish>(input);

    EXPECT_EQ(hswish_op->get_element_type(), element::f32);
    ASSERT_TRUE(
        hswish_op->get_output_partial_shape(0).same_scheme(input->get_output_partial_shape(0)));

    // rank unknown
    auto dynamic_op = make_shared<op::v4::HSwish>(
        make_shared<op::Parameter>(element::f32, PartialShape::dynamic()));
    ASSERT_TRUE(dynamic_op->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
// With a static rank but dynamic dimensions, the output keeps the same scheme
// and the rank remains static.
TEST(type_prop, hswish_partial_static_rank)
{
    auto input = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
    auto hswish_op = make_shared<op::v4::HSwish>(input);

    EXPECT_EQ(hswish_op->get_element_type(), element::f32);
    ASSERT_TRUE(
        hswish_op->get_output_partial_shape(0).same_scheme(input->get_output_partial_shape(0)));
    ASSERT_TRUE(hswish_op->get_output_partial_shape(0).rank().is_static());
}