Enable HSwish - ngraph op, fusion/decomposition and reference implementation (#1770)
* Add HSwish operator to nGraph * Add HSwishFusion transformation * Update check_constant function * Add reference implementation for HSwish * Enable reference implementation in HSwish evaluate * Add op_eval test * HSwish fusion transformation test * Add HSwishFusionWithoutRelu transformation * Add more hswish fusion tests * Register HSwishFusion pass in common_optimizations * Update HSwish reference implementation * Add HSwishFusion with Relu and Multiply * Add HSwishDecomposition transformation pass * Add HSwishDecomposition test * Add HSwish op to ngraph python API * Update HSwish fusion transformations * Remove HSwishFusion from common optimizations * Update hswish python API * Add bf16 to evaluate hswish * Update hswish python API * Move hswish reference implementation * UnaryElementwiseArithmetic inheritance * Enable HSwish callback for clDNN * Register HSwishDecomposition pass in ConvertOpSet1ToLegacy * Enable HSwishFusion pass in common optimizations * Use NGRAPH_RTTI_DECLARATION * Moved python hswish test to the test_ops_unary
This commit is contained in:
@@ -91,6 +91,7 @@ InferenceEngine::ICNNNetwork::Ptr clDNNEngine::CloneAndTransformNetwork(const In
|
||||
std::dynamic_pointer_cast<const ::ngraph::opset3::ShuffleChannels>(node) ||
|
||||
std::dynamic_pointer_cast<const ::ngraph::opset2::BatchToSpace>(node) ||
|
||||
std::dynamic_pointer_cast<const ::ngraph::opset2::SpaceToBatch>(node) ||
|
||||
std::dynamic_pointer_cast<const ::ngraph::opset4::HSwish>(node) ||
|
||||
std::dynamic_pointer_cast<const ::ngraph::opset4::ReduceL1>(node) ||
|
||||
std::dynamic_pointer_cast<const ::ngraph::opset4::ReduceL2>(node);
|
||||
};
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once

#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API HSwishDecomposition;

}  // namespace pass
}  // namespace ngraph

/**
 * @ingroup ie_transformation_common_api
 * @brief HSwishDecomposition transformation decomposes HSwish(x) into the
 *        sub-graph x * (min(Relu(x + 3), 6)) * const(1/6).
 */
class ngraph::pass::HSwishDecomposition: public ngraph::pass::MatcherPass {
public:
    HSwishDecomposition();
};
|
||||
@@ -0,0 +1,63 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once

#include <memory>
#include <utility>

#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API HSwishFusion;
class TRANSFORMATIONS_API HSwishFusionWithReluDiv;
class TRANSFORMATIONS_API HSwishFusionWithReluMul;
class TRANSFORMATIONS_API HSwishFusionWithoutRelu;

}  // namespace pass
}  // namespace ngraph

/**
 * @ingroup ie_transformation_common_api
 * @brief HSwishFusion transformation replaces various sub-graphs with a HSwish op.
 */
class ngraph::pass::HSwishFusion: public ngraph::pass::GraphRewrite {
public:
    // Aggregates all HSwish fusion patterns into a single GraphRewrite pass.
    HSwishFusion() {
        add_matcher<ngraph::pass::HSwishFusionWithReluDiv>();
        add_matcher<ngraph::pass::HSwishFusionWithReluMul>();
        add_matcher<ngraph::pass::HSwishFusionWithoutRelu>();
    }
};

/**
 * @ingroup ie_transformation_common_api
 * @brief HSwishFusion transformation replaces the sub-graph
 *        (x * min(Relu(x + 3), 6)) / 6 with a HSwish op.
 */
class ngraph::pass::HSwishFusionWithReluDiv: public ngraph::pass::MatcherPass {
public:
    HSwishFusionWithReluDiv();
};

/**
 * @ingroup ie_transformation_common_api
 * @brief HSwishFusion transformation replaces the sub-graph
 *        (x * min(Relu(x + 3), 6)) * const(1/6) with a HSwish op.
 */
class ngraph::pass::HSwishFusionWithReluMul: public ngraph::pass::MatcherPass {
public:
    HSwishFusionWithReluMul();
};

/**
 * @ingroup ie_transformation_common_api
 * @brief HSwishFusion transformation replaces the sub-graph
 *        x * (min(max(x + 3, 0), 6) / 6) with a HSwish op.
 */
class ngraph::pass::HSwishFusionWithoutRelu: public ngraph::pass::MatcherPass {
public:
    HSwishFusionWithoutRelu();
};
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "transformations/itt.hpp"
|
||||
#include "transformations/mish_fusion.hpp"
|
||||
#include "transformations/swish_fusion.hpp"
|
||||
#include "transformations/hswish_fusion.hpp"
|
||||
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/pass/nop_elimination.hpp>
|
||||
@@ -37,6 +38,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
|
||||
manager.register_pass<ngraph::pass::MishFusion>();
|
||||
manager.register_pass<ngraph::pass::SwishFusion>();
|
||||
manager.register_pass<ngraph::pass::HSwishFusion>();
|
||||
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution>();
|
||||
|
||||
manager.set_callback(m_transformation_callback);
|
||||
|
||||
@@ -47,6 +47,7 @@
|
||||
#include <transformations/convert_opset1_to_legacy/convert_hard_sigmoid_to_hard_sigmoid_ie.hpp>
|
||||
#include <transformations/lin_op_sequence_fusoin.hpp>
|
||||
#include <transformations/common_optimizations/conv_mul_fusion.hpp>
|
||||
#include <transformations/hswish_decomposition.hpp>
|
||||
#include <transformations/reduce_l1_decomposition.hpp>
|
||||
#include <transformations/reduce_l2_decomposition.hpp>
|
||||
|
||||
@@ -71,6 +72,10 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
|
||||
manager.register_pass<ngraph::pass::ReduceL1Decomposition>();
|
||||
manager.register_pass<ngraph::pass::ReduceL2Decomposition>();
|
||||
|
||||
    // HSwishDecomposition produces Minimum, Relu and Multiply operations,
    // so it must be executed before the decomposition/conversion passes registered below.
|
||||
manager.register_pass<ngraph::pass::HSwishDecomposition>();
|
||||
|
||||
// List if Decomposition and Conversion transformations that can be
|
||||
// applied simultaneously in a single graph traversal
|
||||
auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/hswish_decomposition.hpp"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
ngraph::pass::HSwishDecomposition::HSwishDecomposition() {
    // Decomposes HSwish(x) into the sub-graph x * (min(Relu(x + 3), 6)) * const(1/6).
    auto hswish_pattern = ngraph::pattern::wrap_type<opset4::HSwish>();

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
        auto &pattern_map = m.get_pattern_value_map();
        auto hswish = pattern_map.at(hswish_pattern).get_node_shared_ptr();

        // Honour the plugin callback: skip nodes the plugin wants to keep as HSwish.
        if (m_transformation_callback(hswish)) {
            return false;
        }

        // All inserted constants use the element type of the HSwish input.
        auto elem_type = hswish->input_value(0).get_element_type();
        auto const_three = ngraph::opset4::Constant::create(elem_type, ngraph::Shape{}, {3.0});
        auto shifted = std::make_shared<ngraph::opset4::Add>(hswish->input_value(0), const_three);
        auto rectified = std::make_shared<ngraph::opset4::Relu>(shifted);
        auto const_six = ngraph::opset4::Constant::create(elem_type, ngraph::Shape{}, {6.0});
        auto clamped = std::make_shared<ngraph::opset4::Minimum>(rectified, const_six);
        auto scaled_input = std::make_shared<ngraph::opset4::Multiply>(hswish->input_value(0), clamped);
        auto const_one_sixth = ngraph::opset4::Constant::create(elem_type, ngraph::Shape{}, {(1.0/6.0)}); // const(1/6)
        auto result = std::make_shared<ngraph::opset4::Multiply>(scaled_input, const_one_sixth);

        // Preserve the original node's name and runtime info on the new sub-graph.
        result->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info(hswish,
                                  {const_three, shifted, rectified, const_six, clamped, scaled_input, const_one_sixth, result});
        ngraph::replace_node(m.get_match_root(), result);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(hswish_pattern, "HSwishDecomposition");
    register_matcher(m, callback);
}
|
||||
@@ -0,0 +1,180 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/hswish_fusion.hpp"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
// Helper: returns true iff `constant` is a non-null, single-element f32/f16
// Constant whose value equals `value` within `epsilon`.
// Marked `static`: this is a file-local helper, and external linkage would risk
// ODR clashes with identically named helpers in other fusion translation units.
static bool check_constant_value(const std::shared_ptr<ngraph::opset4::Constant>& constant,
                                 const float value,
                                 float epsilon = std::numeric_limits<float>::epsilon()) {
    if (!constant) {
        return false;
    }
    // Only floating-point (f32/f16) constants appear in the HSwish patterns.
    if (constant->get_element_type() != ngraph::element::f32 && constant->get_element_type() != ngraph::element::f16) {
        return false;
    }
    const auto data = constant->cast_vector<float>();
    // Must be a scalar (single element) matching the expected value.
    return data.size() == 1 && std::fabs(data[0] - value) <= epsilon;
}
|
||||
|
||||
ngraph::pass::HSwishFusionWithReluDiv::HSwishFusionWithReluDiv() {
    // Replaces the sub-graph (x * min(Relu(x + 3), 6)) / 6 with a HSwish op.
    auto input = ngraph::pattern::any_input();
    auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
    auto relu = std::make_shared<ngraph::opset4::Relu>(add);
    auto min_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
    auto mul = std::make_shared<ngraph::opset4::Multiply>(input, min);
    auto div_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto div = std::make_shared<ngraph::opset4::Divide>(mul, div_constant);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
        auto &pattern_to_output = m.get_pattern_value_map();
        auto x_output = pattern_to_output.at(input);

        // The fusion is only valid when the constants match the HSwish formula:
        // add 3, clamp at 6, divide by 6.
        auto add_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
        auto min_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
        auto div_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(div_constant).get_node_shared_ptr());

        bool valid_constant_values = check_constant_value(add_const_value, 3.0)
                                        && check_constant_value(min_const_value, 6.0)
                                        && check_constant_value(div_const_value, 6.0);

        if (!valid_constant_values) {
            return false;
        }

        auto hswish = std::make_shared<ngraph::opset4::HSwish>(x_output);

        // Preserve the matched root's name and the runtime info of the whole
        // replaced sub-graph so the substitution stays traceable.
        hswish->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(add).get_node_shared_ptr(),
                                    pattern_to_output.at(relu).get_node_shared_ptr(),
                                    pattern_to_output.at(min_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(min).get_node_shared_ptr(),
                                    pattern_to_output.at(mul).get_node_shared_ptr(),
                                    pattern_to_output.at(div_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(div).get_node_shared_ptr(),
                                  },
                                  hswish);
        ngraph::replace_node(m.get_match_root(), hswish);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(div, "HSwishWithReluDivFusion");
    register_matcher(m, callback);
}
|
||||
|
||||
ngraph::pass::HSwishFusionWithReluMul::HSwishFusionWithReluMul() {
    // Replaces the sub-graph (x * min(Relu(x + 3), 6)) * const(1/6) with a HSwish op.
    auto input = ngraph::pattern::any_input();
    auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
    auto relu = std::make_shared<ngraph::opset4::Relu>(add);
    auto min_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
    auto mul_first = std::make_shared<ngraph::opset4::Multiply>(input, min);
    auto mul_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto mul_second = std::make_shared<ngraph::opset4::Multiply>(mul_first, mul_constant);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
        auto &pattern_to_output = m.get_pattern_value_map();
        auto x_output = pattern_to_output.at(input);

        auto add_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
        auto min_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
        auto mul_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(mul_constant).get_node_shared_ptr());

        // 1/6 cannot be represented exactly in f16/f32, so compare the final
        // multiplier with a relaxed epsilon instead of the default one.
        bool valid_constant_values = check_constant_value(add_const_value, 3.0)
                                        && check_constant_value(min_const_value, 6.0)
                                        && check_constant_value(mul_const_value, (1.0/6.0), 0.0001);

        if (!valid_constant_values) {
            return false;
        }

        auto hswish = std::make_shared<ngraph::opset4::HSwish>(x_output);

        // Preserve the matched root's name and the replaced sub-graph's runtime info.
        hswish->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(add).get_node_shared_ptr(),
                                    pattern_to_output.at(relu).get_node_shared_ptr(),
                                    pattern_to_output.at(min_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(min).get_node_shared_ptr(),
                                    pattern_to_output.at(mul_first).get_node_shared_ptr(),
                                    pattern_to_output.at(mul_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(mul_second).get_node_shared_ptr()
                                  },
                                  hswish);
        ngraph::replace_node(m.get_match_root(), hswish);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(mul_second, "HSwishWithReluMulFusion");
    register_matcher(m, callback);
}
|
||||
|
||||
|
||||
ngraph::pass::HSwishFusionWithoutRelu::HSwishFusionWithoutRelu() {
    // Replaces the sub-graph x * (min(max(x + 3, 0), 6) / 6) with a HSwish op.
    auto input = ngraph::pattern::any_input();
    auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
    auto max_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto max = std::make_shared<ngraph::opset4::Maximum>(add, max_constant);
    auto min_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto min = std::make_shared<ngraph::opset4::Minimum>(max, min_constant);
    auto div_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
    auto mul = std::make_shared<ngraph::opset4::Multiply>(input, div);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
        auto &pattern_to_output = m.get_pattern_value_map();
        auto x_output = pattern_to_output.at(input);

        // Here Relu is expressed as max(x + 3, 0), so four constants are checked.
        auto add_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
        auto max_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(max_constant).get_node_shared_ptr());
        auto min_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(min_constant).get_node_shared_ptr());
        auto div_const_value = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(div_constant).get_node_shared_ptr());

        bool valid_constant_values = check_constant_value(add_const_value, 3.0)
                                        && check_constant_value(max_const_value, 0.0)
                                        && check_constant_value(min_const_value, 6.0)
                                        && check_constant_value(div_const_value, 6.0);

        if (!valid_constant_values) {
            return false;
        }

        auto hswish = std::make_shared<ngraph::opset4::HSwish>(x_output);

        // Preserve the matched root's name and the replaced sub-graph's runtime info.
        hswish->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info({ pattern_to_output.at(add_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(add).get_node_shared_ptr(),
                                    pattern_to_output.at(max_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(max).get_node_shared_ptr(),
                                    pattern_to_output.at(min_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(min).get_node_shared_ptr(),
                                    pattern_to_output.at(div_constant).get_node_shared_ptr(),
                                    pattern_to_output.at(div).get_node_shared_ptr(),
                                    pattern_to_output.at(mul).get_node_shared_ptr()
                                  },
                                  hswish);
        ngraph::replace_node(m.get_match_root(), hswish);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "HSwishWithoutReluFusion");
    register_matcher(m, callback);
}
|
||||
@@ -0,0 +1,52 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/hswish_decomposition.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
// Checks that HSwishDecomposition replaces a single HSwish node with the
// equivalent x * (min(Relu(x + 3), 6)) * const(1/6) sub-graph.
TEST(TransformationTests, HSwishDecompositionTest) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        // NOTE(review): the transformed graph uses an f32 input while the
        // reference graph below uses f16 — presumably compare_functions does
        // not compare element types; confirm the mismatch is intentional.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
        auto hswish = std::make_shared<ngraph::opset4::HSwish>(input);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{hswish}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::HSwishDecomposition>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference graph: the expected decomposition result.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
        auto mul_first = std::make_shared<ngraph::opset4::Multiply>(input, min);
        auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.1666666716});
        auto mul_second = std::make_shared<ngraph::opset4::Multiply>(mul_first, mul_constant);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
|
||||
@@ -0,0 +1,274 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/hswish_fusion.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
// Checks that the (x * min(Relu(x + 3), 6)) / 6 pattern (f16) fuses into HSwish.
TEST(TransformationTests, HSwishFusionWithReluDivF16) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, min);
        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
        auto div = std::make_shared<ngraph::opset4::Divide>(mul, div_constant);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::HSwishFusionWithReluDiv>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference graph: a single HSwish node.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto hswish = std::make_shared<ngraph::opset4::HSwish>(input);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hswish}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
|
||||
|
||||
// Same ReluDiv fusion check as above, but with an f32 scalar input.
TEST(TransformationTests, HSwishFusionWithReluDivF32) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{});
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {3.0});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {6.0});
        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, min);
        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{}, {6.0});
        auto div = std::make_shared<ngraph::opset4::Divide>(mul, div_constant);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::HSwishFusionWithReluDiv>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference graph: a single HSwish node.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{});
        auto hswish = std::make_shared<ngraph::opset4::HSwish>(input);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hswish}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
|
||||
|
||||
// Checks that the (x * min(Relu(x + 3), 6)) * const(~1/6) pattern fuses into HSwish.
TEST(TransformationTests, HSwishFusionWithReluMul) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
        auto mul_first = std::make_shared<ngraph::opset4::Multiply>(input, min);
        // 0.1666666716 is 1/6 rounded to float; the pass accepts it within its epsilon.
        auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.1666666716});
        auto mul_second = std::make_shared<ngraph::opset4::Multiply>(mul_first, mul_constant);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::HSwishFusionWithReluMul>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference graph: a single HSwish node.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto hswish = std::make_shared<ngraph::opset4::HSwish>(input);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hswish}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
|
||||
|
||||
// Checks that the Relu-free pattern x * (min(max(x + 3, 0), 6) / 6) fuses into HSwish.
TEST(TransformationTests, HSwishFusionWithoutRelu) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto max_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.0});
        auto max = std::make_shared<ngraph::opset4::Maximum>(add, max_constant);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
        auto min = std::make_shared<ngraph::opset4::Minimum>(max, min_constant);
        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
        auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, div);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::HSwishFusionWithoutRelu>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference graph: a single HSwish node.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto hswish = std::make_shared<ngraph::opset4::HSwish>(input);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{hswish}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
|
||||
|
||||
// Negative test: the multiplier 0.167 is outside the pass's epsilon for 1/6,
// so the graph must stay unchanged.
TEST(TransformationTests, HSwishFusionWithReluMulWrongConstValue) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
        auto mul_first = std::make_shared<ngraph::opset4::Multiply>(input, min);
        auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.167});
        auto mul_second = std::make_shared<ngraph::opset4::Multiply>(mul_first, mul_constant);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::HSwishFusionWithReluMul>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference graph: identical to the input graph (no fusion expected).
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.0});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.0});
        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
        auto mul_first = std::make_shared<ngraph::opset4::Multiply>(input, min);
        auto mul_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.167});
        auto mul_second = std::make_shared<ngraph::opset4::Multiply>(mul_first, mul_constant);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul_second}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
|
||||
|
||||
// Negative test: all three constants deviate from the HSwish formula
// (3.01, 6.002, and a divisor of 0.0), so the graph must stay unchanged.
TEST(TransformationTests, HSwishFusionWithReluDivWrongConstValue) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{});
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.01});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.002});
        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, min);
        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.0});
        auto div = std::make_shared<ngraph::opset4::Divide>(mul, div_constant);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::HSwishFusionWithReluDiv>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference graph: identical to the input graph (no fusion expected).
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{});
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.01});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto relu = std::make_shared<ngraph::opset4::Relu>(add);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.002});
        auto min = std::make_shared<ngraph::opset4::Minimum>(relu, min_constant);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, min);
        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.0});
        auto div = std::make_shared<ngraph::opset4::Divide>(mul, div_constant);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
|
||||
|
||||
// Negative test: every constant in the min/max/div chain is slightly off the
// canonical HSwish values (3.11 vs 3.0, 0.22 vs 0.0, 6.01 vs 6.0, 6.002 vs 6.0),
// so HSwishFusionWithoutRelu must not rewrite the graph.
TEST(TransformationTests, HSwishFusionWithoutReluWrongConstValue) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        // Build x * (min(max(x + 3.11, 0.22), 6.01) / 6.002) -- close to, but
        // not exactly, the fusable pattern.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto max_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.22});
        auto max = std::make_shared<ngraph::opset4::Maximum>(add, max_constant);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.01});
        auto min = std::make_shared<ngraph::opset4::Minimum>(max, min_constant);
        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.002});
        auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, div);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::HSwishFusionWithoutRelu>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference graph: identical to the input graph, since no fusion is
        // expected to have happened.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto add_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {3.11});
        auto add = std::make_shared<ngraph::opset4::Add>(input, add_constant);
        auto max_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {0.22});
        auto max = std::make_shared<ngraph::opset4::Maximum>(add, max_constant);
        auto min_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.01});
        auto min = std::make_shared<ngraph::opset4::Minimum>(max, min_constant);
        auto div_constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {6.002});
        auto div = std::make_shared<ngraph::opset4::Divide>(min, div_constant);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, div);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
|
||||
53
ngraph/core/include/ngraph/op/hswish.hpp
Normal file
53
ngraph/core/include/ngraph/op/hswish.hpp
Normal file
@@ -0,0 +1,53 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
||||
|
||||
namespace ngraph
{
    namespace op
    {
        namespace v4
        {
            /// \brief A HSwish Activation Function
            /// f(x) = x * min(max(x + 3, 0), 6) / 6 or
            /// f(x) = x * min(ReLU(x + 3), 6) / 6
            ///
            class NGRAPH_API HSwish : public ngraph::op::util::UnaryElementwiseArithmetic
            {
            public:
                NGRAPH_RTTI_DECLARATION;
                HSwish() = default;

                /// \brief Constructs a HSwish (hard version of Swish) operation.
                ///
                /// \param data Input tensor
                HSwish(const Output<Node>& arg);

                /// \brief HSwish has no attributes; visiting is a no-op.
                bool visit_attributes(AttributeVisitor& visitor) override;

                /// \brief Creates a copy of this op connected to the given inputs.
                virtual std::shared_ptr<Node>
                    clone_with_new_inputs(const OutputVector& new_args) const override;
                /// \brief Computes HSwish on host tensors (constant folding /
                ///        CPU fallback path).
                bool evaluate(const HostTensorVector& outputs,
                              const HostTensorVector& inputs) const override;
            };
        }
    }
}
|
||||
@@ -75,6 +75,7 @@
|
||||
#include "ngraph/op/group_conv.hpp"
|
||||
#include "ngraph/op/gru_cell.hpp"
|
||||
#include "ngraph/op/hard_sigmoid.hpp"
|
||||
#include "ngraph/op/hswish.hpp"
|
||||
#include "ngraph/op/interpolate.hpp"
|
||||
#include "ngraph/op/less.hpp"
|
||||
#include "ngraph/op/less_eq.hpp"
|
||||
|
||||
@@ -156,6 +156,7 @@ NGRAPH_OP(Acosh, ngraph::op::v3)
|
||||
NGRAPH_OP(Asinh, ngraph::op::v3)
|
||||
NGRAPH_OP(Atanh, ngraph::op::v3)
|
||||
NGRAPH_OP(CTCLoss, ngraph::op::v4)
|
||||
NGRAPH_OP(HSwish, ngraph::op::v4)
|
||||
NGRAPH_OP(NonMaxSuppression, ngraph::op::v4)
|
||||
NGRAPH_OP(Mish, ngraph::op::v4)
|
||||
NGRAPH_OP(ReduceL1, ngraph::op::v4)
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once

// <algorithm> is required for std::min/std::max used below; previously the
// header relied on a transitive include, which is fragile.
#include <algorithm>
#include <cmath>
#include <cstddef>

namespace ngraph
{
    namespace runtime
    {
        namespace reference
        {
            /// \brief Reference implementation of the HSwish activation:
            ///        f(x) = x * min(max(x + 3, 0), 6) / 6
            ///
            /// \param arg   Pointer to the input buffer.
            /// \param out   Pointer to the output buffer (same length as input).
            /// \param count Number of elements to process.
            template <typename T>
            void hswish(const T* arg, T* out, size_t count)
            {
                for (size_t i = 0; i < count; i++)
                {
                    // ReLU6(x + 3) scaled by 1/6; the explicit <T> template
                    // arguments force the float literals to the element type.
                    out[i] = arg[i] * std::min<T>(std::max<T>(arg[i] + 3.0f, 0.0f), 6.0f) / 6.0f;
                }
            }
        }
    }
}
|
||||
78
ngraph/core/src/op/hswish.cpp
Normal file
78
ngraph/core/src/op/hswish.cpp
Normal file
@@ -0,0 +1,78 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "ngraph/op/hswish.hpp"
|
||||
#include "ngraph/attribute_visitor.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "ngraph/runtime/reference/hswish.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v4::HSwish, "HSwish", 4);
|
||||
|
||||
// Constructs a HSwish op with a single data input; output element type and
// shape are inferred from 'arg' by the UnaryElementwiseArithmetic base class.
op::v4::HSwish::HSwish(const Output<Node>& arg)
    : UnaryElementwiseArithmetic(arg)
{
    constructor_validate_and_infer_types();
}
|
||||
|
||||
// HSwish has no attributes to (de)serialize, so there is nothing to visit.
bool op::v4::HSwish::visit_attributes(AttributeVisitor& visitor)
{
    return true;
}
|
||||
|
||||
// Creates a new HSwish node wired to 'new_args' (used by graph copy/transform
// machinery; the clone carries no state beyond its single input).
shared_ptr<Node> op::v4::HSwish::clone_with_new_inputs(const OutputVector& new_args) const
{
    // Guard against callers supplying the wrong number of inputs (HSwish
    // takes exactly one), matching the convention used by other nGraph ops.
    check_new_args_count(this, new_args);
    return make_shared<op::v4::HSwish>(new_args.at(0));
}
|
||||
|
||||
namespace
{
    // Typed kernel: runs the reference hswish implementation for element
    // type ET over 'count' elements of 'arg' into 'out'.
    template <element::Type_t ET>
    inline bool evaluate(const HostTensorPtr& arg, const HostTensorPtr& out, const size_t count)
    {
        using T = typename element_type_traits<ET>::value_type;

        runtime::reference::hswish<T>(arg->get_data_ptr<ET>(), out->get_data_ptr<ET>(), count);
        return true;
    }

    // Dispatches on the input element type (bf16/f16/f32); returns false for
    // any other type so the caller can fall back or report failure.
    bool evaluate_hswish(const HostTensorPtr& arg, const HostTensorPtr& out, const size_t count)
    {
        bool rc = true;
        // Propagate the input's element type and shape to the output tensor
        // before writing into it.
        out->set_unary(arg);

        // NOTE: TYPE_CASE expands to a case label plus a call to the typed
        // evaluate<> above; each case must be followed by its own break.
        switch (arg->get_element_type())
        {
            TYPE_CASE(bf16)(arg, out, count);
            break;
            TYPE_CASE(f16)(arg, out, count);
            break;
            TYPE_CASE(f32)(arg, out, count);
            break;
        default: rc = false; break;
        }
        return rc;
    }
}
|
||||
|
||||
bool op::v4::HSwish::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
|
||||
{
|
||||
return evaluate_hswish(inputs[0], outputs[0], shape_size(get_output_shape(0)));
|
||||
}
|
||||
@@ -84,6 +84,7 @@ from ngraph.opset4 import group_convolution
|
||||
from ngraph.opset4 import group_convolution_backprop_data
|
||||
from ngraph.opset4 import gru_cell
|
||||
from ngraph.opset4 import hard_sigmoid
|
||||
from ngraph.opset4 import hswish
|
||||
from ngraph.opset4 import interpolate
|
||||
from ngraph.opset4 import less
|
||||
from ngraph.opset4 import less_equal
|
||||
|
||||
@@ -72,6 +72,7 @@ from ngraph.opset1.ops import group_convolution
|
||||
from ngraph.opset1.ops import group_convolution_backprop_data
|
||||
from ngraph.opset3.ops import gru_cell
|
||||
from ngraph.opset1.ops import hard_sigmoid
|
||||
from ngraph.opset4.ops import hswish
|
||||
from ngraph.opset1.ops import interpolate
|
||||
from ngraph.opset1.ops import less
|
||||
from ngraph.opset1.ops import less_equal
|
||||
|
||||
@@ -149,6 +149,16 @@ def mish(data: NodeInput, name: Optional[str] = None,) -> Node:
|
||||
return _get_node_factory_opset4().create("Mish", as_nodes(data), {})
|
||||
|
||||
|
||||
@nameable_op
def hswish(data: NodeInput, name: Optional[str] = None,) -> Node:
    """Return a node which performs HSwish (hard version of Swish).

    HSwish(x) = x * min(max(x + 3, 0), 6) / 6

    :param data: Tensor with input data floating point type.
    :param name: Optional output node name.
    :return: The new node which performs HSwish
    """
    return _get_node_factory_opset4().create("HSwish", as_nodes(data), {})
|
||||
|
||||
|
||||
@nameable_op
|
||||
def swish(
|
||||
data: NodeInput,
|
||||
|
||||
@@ -17,6 +17,7 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
import ngraph as ng
|
||||
from ngraph.impl import Shape, Type
|
||||
from tests.test_ngraph.util import run_op_node, run_op_numeric_data
|
||||
from tests import xfail_issue_35929, xfail_issue_34323
|
||||
|
||||
@@ -148,3 +149,14 @@ def test_erf():
|
||||
|
||||
result = run_op_numeric_data(input_tensor, ng.erf)
|
||||
assert np.allclose(result, expected)
|
||||
|
||||
|
||||
def test_hswish():
    # HSwish is unary and elementwise, so the output must preserve the
    # input's shape ([3, 10]) and element type (f32).
    float_dtype = np.float32
    data = ng.parameter(Shape([3, 10]), dtype=float_dtype, name="data")

    node = ng.hswish(data)
    assert node.get_type_name() == "HSwish"
    assert node.get_output_size() == 1
    assert list(node.get_output_shape(0)) == [3, 10]
    assert node.get_output_element_type(0) == Type.f32
||||
|
||||
@@ -70,6 +70,7 @@ set(SRC
|
||||
node_input_output.cpp
|
||||
nop_elimination.cpp
|
||||
op.cpp
|
||||
op_eval/hswish.cpp
|
||||
op_eval/matmul.cpp
|
||||
op_eval/mish.cpp
|
||||
op_eval/non_zero.cpp
|
||||
@@ -127,6 +128,7 @@ set(SRC
|
||||
type_prop/group_convolution_backprop_data.cpp
|
||||
type_prop/gru_cell.cpp
|
||||
type_prop/hard_sigmoid.cpp
|
||||
type_prop/hswish.cpp
|
||||
type_prop/lrn.cpp
|
||||
type_prop/lstm_cell.cpp
|
||||
type_prop/lstm_sequence.cpp
|
||||
|
||||
48
ngraph/test/op_eval/hswish.cpp
Normal file
48
ngraph/test/op_eval/hswish.cpp
Normal file
@@ -0,0 +1,48 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ngraph/op/hswish.hpp"
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
#include "runtime/backend.hpp"
|
||||
#include "util/test_tools.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
// Evaluates HSwish on a concrete f32 tensor and checks the results against
// hand-computed values of x * min(max(x + 3, 0), 6) / 6.
TEST(op_eval, hswish)
{
    auto p = make_shared<op::Parameter>(element::f32, Shape{3});
    // Renamed from 'swish': the node under test is HSwish, not Swish.
    auto hswish_op = make_shared<op::v4::HSwish>(p);
    auto fun = make_shared<Function>(OutputVector{hswish_op}, ParameterVector{p});

    std::vector<float> inputs{-0.5f, 0.0f, 0.5f};
    std::vector<float> expected_result{-0.208333f, 0.0f, 0.29166667f};

    auto result = make_shared<HostTensor>();
    ASSERT_TRUE(
        fun->evaluate({result}, {make_host_tensor<element::Type_t::f32>(Shape{3}, inputs)}));
    EXPECT_EQ(result->get_element_type(), element::f32);
    EXPECT_EQ(result->get_shape(), Shape{3});
    auto result_data = read_vector<float>(result);
    // size_t index avoids the signed/unsigned comparison of 'auto i = 0'
    // against vector::size().
    for (size_t i = 0; i < inputs.size(); i++)
        EXPECT_NEAR(result_data[i], expected_result[i], 0.000001);
}
|
||||
54
ngraph/test/type_prop/hswish.cpp
Normal file
54
ngraph/test/type_prop/hswish.cpp
Normal file
@@ -0,0 +1,54 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
// Static shape: HSwish must propagate the input's element type and shape
// unchanged.
TEST(type_prop, hswish)
{
    auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
    auto hswish_func = make_shared<op::v4::HSwish>(data);
    EXPECT_EQ(hswish_func->get_element_type(), element::f32);
    EXPECT_EQ(hswish_func->get_shape(), data->get_output_shape(0));
}
|
||||
|
||||
// Partially dynamic shapes: the output partial shape must mirror the input's,
// both for a known rank with dynamic dimensions and for a fully dynamic rank.
TEST(type_prop, hswish_partial)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
    auto hswish_func = make_shared<op::v4::HSwish>(data);
    EXPECT_EQ(hswish_func->get_element_type(), element::f32);
    ASSERT_TRUE(
        hswish_func->get_output_partial_shape(0).same_scheme(data->get_output_partial_shape(0)));

    // rank unknown
    auto hswish_partial = make_shared<op::v4::HSwish>(
        make_shared<op::Parameter>(element::f32, PartialShape::dynamic()));
    ASSERT_TRUE(hswish_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
|
||||
|
||||
// Static rank with dynamic dimensions: shape scheme is preserved and the
// output rank stays static.
TEST(type_prop, hswish_partial_static_rank)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
    auto hswish_func = make_shared<op::v4::HSwish>(data);
    EXPECT_EQ(hswish_func->get_element_type(), element::f32);
    ASSERT_TRUE(
        hswish_func->get_output_partial_shape(0).same_scheme(data->get_output_partial_shape(0)));
    ASSERT_TRUE(hswish_func->get_output_partial_shape(0).rank().is_static());
}
|
||||
Reference in New Issue
Block a user