Enable swish (#1682)

* Draft version of the Swish nGraph operation and fusing transformations for different approaches to express the operation

* Swish fusing transformation refactoring

* Added Swish operation and extractor for TF. Removed unfolding transformation for the operation.

* Added SwishIE. Implemented transformation to convert Swish to SwishIE.

* Code style fixes

* Updated Swish reference implementation. Added tests for shape and value inference


* Fixed code style for Python API

* Fixed unit test

* Apply review comments

* Use matcher_pass_callback

* Make m_alpha attribute protected in the SwishIE operation

* Fixed Swish op PythonAPI test
This commit is contained in:
Evgeny Lazarev 2020-08-10 15:51:21 +03:00 committed by GitHub
parent 600ad8d180
commit 318d38770b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 1179 additions and 101 deletions

View File

@ -496,6 +496,16 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr
});
// Creator mapping the internal SwishIE node to a CNNLayer of type "Swish".
// The serialized node attributes (which include "alpha") arrive via `params`
// and are copied verbatim onto the layer.
addSpecificCreator({"SwishIE"}, [](const std::shared_ptr<::ngraph::Node>& node,
                                   const std::map<std::string, std::string> params) -> CNNLayerPtr {
    LayerParams attrs = {node->get_friendly_name(), "Swish",
                         details::convertPrecision(node->get_output_element_type(0))};
    auto res = std::make_shared<InferenceEngine::CNNLayer>(attrs);
    res->params = params;
    return res;
});
addSpecificCreator({"PriorBox"}, [](const std::shared_ptr<::ngraph::Node>& node,
const std::map<std::string, std::string> params) -> CNNLayerPtr {
THROW_IE_EXCEPTION << "PriorBox operation has a form that is not supported." << node->get_friendly_name()

View File

@ -0,0 +1,32 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <transformations_visibility.hpp>
#include "ngraph/op/op.hpp"
namespace ngraph {
namespace op {
/**
 * @brief Legacy (IE-internal) form of the opset4 Swish activation.
 *
 * Unlike opset4::Swish, which takes beta as an optional second input,
 * SwishIE carries that multiplier as the scalar `alpha` attribute
 * (see ConvertSwishToSwishIEMatcher, which performs the mapping).
 */
class TRANSFORMATIONS_API SwishIE : public Op {
public:
    static constexpr NodeTypeInfo type_info{"SwishIE", 1};
    const NodeTypeInfo &get_type_info() const override { return type_info; }

    /// \param input  Data input of the activation.
    /// \param alpha  Multiplier of the sigmoid argument; defaults to 1.0.
    explicit SwishIE(const Output<Node> &input, float alpha = 1.0);

    void validate_and_infer_types() override;
    bool visit_attributes(AttributeVisitor& visitor) override;
    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector &new_args) const override;

    void set_alpha(float alpha);
    float get_alpha() const;

protected:
    // Scalar multiplier of the sigmoid argument (the Swish "beta").
    float m_alpha;
};
} // namespace op
} // namespace ngraph

View File

@ -0,0 +1,24 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
#include <memory>
#include <string>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API ConvertSwishToSwishIEMatcher;
} // namespace pass
} // namespace ngraph
/**
 * @brief Converts opset4::Swish (optionally carrying a scalar constant beta
 *        as a second input) into the legacy SwishIE operation, which stores
 *        beta as the `alpha` attribute instead of an input.
 */
class ngraph::pass::ConvertSwishToSwishIEMatcher: public ngraph::pass::MatcherPass {
public:
    ConvertSwishToSwishIEMatcher();
};

View File

@ -0,0 +1,73 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <utility>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API SwishFusion;
class TRANSFORMATIONS_API SwishFusionWithSigmoid;
class TRANSFORMATIONS_API SwishFusionWithSigmoidWithBeta;
class TRANSFORMATIONS_API SwishFusionWithBeta;
class TRANSFORMATIONS_API SwishFusionWithoutBeta;
} // namespace pass
} // namespace ngraph
/**
 * @ingroup ie_transformation_common_api
 * @brief SwishFusion transformation replaces various sub-graphs with a Swish op.
 */
class ngraph::pass::SwishFusion: public ngraph::pass::GraphRewrite {
public:
    SwishFusion() {
        // Register all supported expressions of the Swish activation;
        // each matcher recognizes one way it can appear in the source graph.
        add_matcher<ngraph::pass::SwishFusionWithSigmoid>();
        add_matcher<ngraph::pass::SwishFusionWithSigmoidWithBeta>();
        add_matcher<ngraph::pass::SwishFusionWithBeta>();
        add_matcher<ngraph::pass::SwishFusionWithoutBeta>();
    }
};
/**
 * @ingroup ie_transformation_common_api
 * @brief SwishFusionWithSigmoid replaces the sub-graph x * Sigmoid(x) with a Swish op.
 */
class ngraph::pass::SwishFusionWithSigmoid: public ngraph::pass::MatcherPass {
public:
    SwishFusionWithSigmoid();
};
/**
 * @ingroup ie_transformation_common_api
 * @brief SwishFusionWithSigmoidWithBeta replaces the sub-graph x * Sigmoid(x * beta) with a Swish op.
 */
class ngraph::pass::SwishFusionWithSigmoidWithBeta: public ngraph::pass::MatcherPass {
public:
    SwishFusionWithSigmoidWithBeta();
};
/**
 * @ingroup ie_transformation_common_api
 * @brief SwishFusionWithBeta replaces the sub-graph x / (1.0 + exp(-x * beta)) with a Swish op.
 */
class ngraph::pass::SwishFusionWithBeta: public ngraph::pass::MatcherPass {
public:
    SwishFusionWithBeta();
};
/**
 * @ingroup ie_transformation_common_api
 * @brief SwishFusionWithoutBeta replaces the sub-graph x / (1.0 + exp(-x)) with a Swish op.
 */
class ngraph::pass::SwishFusionWithoutBeta: public ngraph::pass::MatcherPass {
public:
    SwishFusionWithoutBeta();
};

View File

@ -0,0 +1,44 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph_ops/swish_ie.hpp"
#include <algorithm>
#include <memory>
#include "ngraph/util.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
// Out-of-line definition of the in-class constexpr static (required by pre-C++17 ODR rules).
constexpr NodeTypeInfo op::SwishIE::type_info;

// SwishIE stores the Swish beta multiplier as the `alpha` attribute instead of
// taking it as a second input (cf. opset4::Swish).
op::SwishIE::SwishIE(const Output<Node> & input, const float alpha)
        : Op({input}), m_alpha(alpha) {
    constructor_validate_and_infer_types();
}
// Clones the node onto the supplied inputs; the alpha attribute is carried over.
std::shared_ptr<Node> op::SwishIE::clone_with_new_inputs(const OutputVector& new_args) const {
    check_new_args_count(this, new_args);
    return make_shared<SwishIE>(new_args.at(0), m_alpha);
}
// Exposes `alpha` to (de)serialization and attribute visitors.
bool op::SwishIE::visit_attributes(AttributeVisitor& visitor) {
    visitor.on_attribute("alpha", m_alpha);
    return true;
}
// Element-wise activation: output has exactly the input's element type and shape.
void op::SwishIE::validate_and_infer_types() {
    set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
// Accessors for the scalar alpha attribute (the Swish beta multiplier).
void op::SwishIE::set_alpha(float alpha)  {
    m_alpha = alpha;
}

float op::SwishIE::get_alpha() const {
    return m_alpha;
}

View File

@ -12,6 +12,7 @@
#include "transformations/init_node_info.hpp"
#include "transformations/itt.hpp"
#include "transformations/mish_fusion.hpp"
#include "transformations/swish_fusion.hpp"
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/nop_elimination.hpp>
@ -34,6 +35,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<ngraph::pass::ConvertScatterElementsToScatter>(); // partially depends on CF
manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
manager.register_pass<ngraph::pass::MishFusion>();
manager.register_pass<ngraph::pass::SwishFusion>();
manager.set_callback(m_transformation_callback);
manager.run_passes(f);

View File

@ -32,6 +32,7 @@
#include <transformations/convert_opset1_to_legacy/convert_strided_slice_to_crop.hpp>
#include <transformations/convert_subtract.hpp>
#include <transformations/convert_opset1_to_legacy/convert_selu_to_selu_ie.hpp>
#include <transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.hpp>
#include <transformations/convert_opset1_to_legacy/convert_tile_to_ie_tile.hpp>
#include <transformations/convert_opset1_to_legacy/convert_topk_to_topk_ie.hpp>
#include <transformations/convert_depth_to_space.hpp>
@ -129,6 +130,7 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
anchor->add_matcher<ngraph::pass::ConvertPReLUToReLUIE>();
anchor->add_matcher<ngraph::pass::ConvertGatherToGatherIEMatcher>();
anchor->add_matcher<ngraph::pass::ConvertSeluToSeluIEMatcher>();
anchor->add_matcher<ngraph::pass::ConvertSwishToSwishIEMatcher>();
anchor->add_matcher<ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(f);
anchor->add_matcher<ngraph::pass::ConvertGatherTreeToGatherTreeIEMatcher>();
anchor->add_matcher<ngraph::pass::ConvertTopKToTopKIEMatcher>();

View File

@ -0,0 +1,46 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.hpp"
#include <memory>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph_ops/swish_ie.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
// Matches any opset4::Swish node and rewrites it as the legacy SwishIE op,
// folding a scalar constant beta input (if present) into the alpha attribute.
ngraph::pass::ConvertSwishToSwishIEMatcher::ConvertSwishToSwishIEMatcher() {
    auto swish_pattern = ngraph::pattern::wrap_type<ngraph::opset4::Swish>();

    ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
        auto swish_node = std::dynamic_pointer_cast<ngraph::opset4::Swish>(m.get_match_root());
        if (!swish_node) {
            return false;
        }

        // Beta defaults to 1.0 when the optional second input is absent.
        float beta_value = 1.0f;
        if (swish_node->input_values().size() == 2) {
            auto beta_const = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
                swish_node->input_value(1).get_node_shared_ptr());
            // Only a constant beta that collapses to a single value can become an attribute.
            if (!beta_const || !ngraph::op::util::get_single_value(beta_const, beta_value)) {
                return false;
            }
        }

        auto swish_ie = std::make_shared<ngraph::op::SwishIE>(swish_node->input_value(0), beta_value);
        swish_ie->set_friendly_name(swish_node->get_friendly_name());
        ngraph::copy_runtime_info(swish_node, swish_ie);
        ngraph::replace_node(swish_node, swish_ie);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(swish_pattern, "ConvertSwishToSwishIE");
    this->register_matcher(m, callback);
}

View File

@ -0,0 +1,183 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/swish_fusion.hpp"
#include <memory>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
// Returns true when `constant` is an f16/f32 constant holding exactly one
// element equal to 1.0 — used to validate the additive constant of the
// x / (1.0 + exp(...)) patterns below.
// Marked `static`: this is a file-local helper with a generic name; giving it
// internal linkage avoids ODR/link clashes with other translation units.
static bool check_constant_value(const std::shared_ptr<ngraph::opset4::Constant>& constant) {
    if (!constant) {
        return false;
    }
    const auto et = constant->get_element_type();
    if (et != ngraph::element::f32 && et != ngraph::element::f16) {
        return false;
    }
    const auto data = constant->cast_vector<float>();
    // Must be a single element with the exact value 1.0.
    return data.size() == 1 && data[0] == 1.0f;
}
// Returns true when `constant` is an f16/f32 constant whose elements all share
// the same value, i.e. it can be collapsed to a scalar beta.
// Marked `static` for internal linkage (file-local helper).
static bool check_beta_value(const std::shared_ptr<ngraph::opset4::Constant>& constant) {
    // check that the constant for beta contains only one distinct element
    if (!constant) {
        return false;
    }
    const auto et = constant->get_element_type();
    if (et != ngraph::element::f32 && et != ngraph::element::f16) {
        return false;
    }
    const auto data = constant->cast_vector<float>();
    if (data.empty()) {
        // BUG FIX: the original unconditionally evaluated data.begin() + 1,
        // which is undefined behaviour for an empty (zero-element) constant.
        return false;
    }
    // All elements must compare equal to the first one.
    return std::equal(data.begin() + 1, data.end(), data.begin());
}
// Fuses the sub-graph x * Sigmoid(x) into a single Swish(x) node.
ngraph::pass::SwishFusionWithSigmoid::SwishFusionWithSigmoid() {
    auto input = ngraph::pattern::any_input();
    auto sigmoid = std::make_shared<ngraph::opset4::Sigmoid>(input);
    auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sigmoid);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
        const auto& matched = m.get_pattern_value_map();
        const auto& data = matched.at(input);

        // Build the fused op on the matched data input.
        auto swish = std::make_shared<ngraph::opset4::Swish>(data);
        swish->set_friendly_name(m.get_match_root()->get_friendly_name());

        // Preserve runtime info of the nodes being folded away.
        ngraph::copy_runtime_info({matched.at(sigmoid).get_node_shared_ptr(),
                                   matched.at(mul).get_node_shared_ptr()},
                                  swish);
        ngraph::replace_node(m.get_match_root(), swish);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "SwishWithSigmoidFusion");
    register_matcher(m, callback);
}
// Fuses the sub-graph x * Sigmoid(x * beta) into Swish(x, beta).
ngraph::pass::SwishFusionWithSigmoidWithBeta::SwishFusionWithSigmoidWithBeta() {
    // replaces a sub-graphs x * Sigmoid(x * beta) with a Swish op.
    auto input = ngraph::pattern::any_input();
    auto beta = ngraph::pattern::any_input();
    auto mul_beta = std::make_shared<ngraph::opset4::Multiply>(input, beta);
    auto sigmoid = std::make_shared<ngraph::opset4::Sigmoid>(mul_beta);
    auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sigmoid);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
        auto &pattern_to_output = m.get_pattern_value_map();
        auto exp_input = pattern_to_output.at(input);
        auto beta_input = pattern_to_output.at(beta);

        auto beta_constant = std::dynamic_pointer_cast<ngraph::opset4::Constant>(beta_input.get_node_shared_ptr());
        Output<Node> new_beta;
        if (beta_constant) {
            // A constant beta is accepted only when every element holds the same
            // value; it is then collapsed to a scalar constant.
            if (check_beta_value(beta_constant)) {
                new_beta = opset4::Constant::create(beta_input.get_element_type(), Shape{}, {beta_constant->cast_vector<float>()[0]});
            } else {
                return false;
            }
        } else {
            // if the input is not constant and number of elements is not equal to 1 then we cannot perform fusing
            if (beta_input.get_partial_shape().is_dynamic() || ngraph::shape_size(beta_input.get_shape()) != 1) {
                return false;
            }
            new_beta = beta_input;
        }

        auto swish = std::make_shared<ngraph::opset4::Swish>(exp_input, new_beta);
        swish->set_friendly_name(m.get_match_root()->get_friendly_name());
        // NOTE(review): rt info of mul_beta (and the beta source) is not
        // propagated to the fused node — confirm this is intentional.
        ngraph::copy_runtime_info({pattern_to_output.at(sigmoid).get_node_shared_ptr(),
                                   pattern_to_output.at(mul).get_node_shared_ptr()},
                                  swish);
        ngraph::replace_node(m.get_match_root(), swish);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "SwishWithSigmoidWithBetaFusion");
    register_matcher(m, callback);
}
// Fuses the sub-graph x / (1.0 + exp(-x * beta)) into Swish(x, beta).
ngraph::pass::SwishFusionWithBeta::SwishFusionWithBeta() {
    // replaces a sub-graphs x / (1.0 + exp(-x * beta)) with a Swish op.
    auto input = ngraph::pattern::any_input();
    auto beta = ngraph::pattern::any_input();
    auto mul = std::make_shared<ngraph::opset4::Multiply>(input, beta);
    auto neg = std::make_shared<ngraph::opset4::Negative>(mul);
    auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
    auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto add = std::make_shared<ngraph::opset4::Add>(exp, add_constant);
    auto div = std::make_shared<ngraph::opset4::Divide>(input, add);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
        auto &pattern_to_output = m.get_pattern_value_map();
        auto exp_input = pattern_to_output.at(input);

        // The additive constant must be exactly 1.0, otherwise the expression
        // is not a Swish.
        auto constant = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
        if (!check_constant_value(constant)) {
            return false;
        }

        // NOTE(review): unlike SwishFusionWithSigmoidWithBeta, beta is forwarded
        // here without a scalar/uniform-value check — confirm multi-element beta
        // cannot reach this pattern.
        auto swish = std::make_shared<ngraph::opset4::Swish>(exp_input, pattern_to_output.at(beta));
        swish->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info({pattern_to_output.at(beta).get_node_shared_ptr(),
                                   pattern_to_output.at(mul).get_node_shared_ptr(),
                                   pattern_to_output.at(neg).get_node_shared_ptr(),
                                   pattern_to_output.at(exp).get_node_shared_ptr(),
                                   pattern_to_output.at(add_constant).get_node_shared_ptr(),
                                   pattern_to_output.at(add).get_node_shared_ptr(),
                                   pattern_to_output.at(div).get_node_shared_ptr()},
                                  swish);
        ngraph::replace_node(m.get_match_root(), swish);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(div, "SwishWithBetaFusion");
    register_matcher(m, callback);
}
// Fuses the sub-graph x / (1.0 + exp(-x)) into a single Swish(x) node.
ngraph::pass::SwishFusionWithoutBeta::SwishFusionWithoutBeta() {
    auto input = ngraph::pattern::any_input();
    auto neg = std::make_shared<ngraph::opset4::Negative>(input);
    auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
    auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto add = std::make_shared<ngraph::opset4::Add>(exp, add_constant);
    auto div = std::make_shared<ngraph::opset4::Divide>(input, add);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
        const auto& matched = m.get_pattern_value_map();

        // The additive constant must be exactly 1.0 for this to be a Swish.
        auto one_const = std::dynamic_pointer_cast<ngraph::opset4::Constant>(matched.at(add_constant).get_node_shared_ptr());
        if (!check_constant_value(one_const)) {
            return false;
        }

        auto swish = std::make_shared<ngraph::opset4::Swish>(matched.at(input));
        swish->set_friendly_name(m.get_match_root()->get_friendly_name());

        // Preserve runtime info of every node being folded away.
        ngraph::copy_runtime_info({matched.at(neg).get_node_shared_ptr(),
                                   matched.at(exp).get_node_shared_ptr(),
                                   matched.at(add_constant).get_node_shared_ptr(),
                                   matched.at(add).get_node_shared_ptr(),
                                   matched.at(div).get_node_shared_ptr()},
                                  swish);
        ngraph::replace_node(m.get_match_root(), swish);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(div, "SwishWithoutBetaFusion");
    register_matcher(m, callback);
}

View File

@ -6,12 +6,10 @@
#include <string>
#include <memory>
#include <queue>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/visualize_tree.hpp>
#include <transformations/mish_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>

View File

@ -0,0 +1,206 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/swish_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
// x / (1.0 + exp(-x * beta)) with a non-constant (Parameter) beta must be
// fused into Swish(x, beta).
TEST(TransformationTests, SwishFusionWithBeta) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        // Build the explicit formula graph and run the fusion pipeline.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
        auto beta = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{});
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, beta);
        auto neg = std::make_shared<ngraph::opset4::Negative>(mul);
        auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
        auto constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.0});
        auto add = std::make_shared<ngraph::opset4::Add>(exp, constant);
        auto div = std::make_shared<ngraph::opset4::Divide>(input, add);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input, beta});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::SwishFusion>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Expected result: a single Swish fed by both parameters.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
        auto beta = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{});
        auto swish = std::make_shared<ngraph::opset4::Swish>(input, beta);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input, beta});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// x / (1.0 + exp(-x)) with a scalar constant 1.0 must be fused into Swish(x).
TEST(TransformationTests, SwishFusionWithoutBeta) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto neg = std::make_shared<ngraph::opset4::Negative>(input);
        auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
        auto constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.0});
        auto add = std::make_shared<ngraph::opset4::Add>(exp, constant);
        auto div = std::make_shared<ngraph::opset4::Divide>(input, add);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::SwishFusion>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Expected result: a single-input Swish (beta defaults to 1.0).
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto swish = std::make_shared<ngraph::opset4::Swish>(input);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// Negative test: the additive constant is 1.1, not 1.0, so the pattern must
// NOT be fused — the transformed function has to stay identical to the input.
TEST(TransformationTests, SwishFusionWithoutBetaNonOneAddConstant) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto neg = std::make_shared<ngraph::opset4::Negative>(input);
        auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
        auto constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.1});
        auto add = std::make_shared<ngraph::opset4::Add>(exp, constant);
        auto div = std::make_shared<ngraph::opset4::Divide>(input, add);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::SwishFusion>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto neg = std::make_shared<ngraph::opset4::Negative>(input);
        auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
        auto constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.1});
        auto add = std::make_shared<ngraph::opset4::Add>(exp, constant);
        auto div = std::make_shared<ngraph::opset4::Divide>(input, add);

        // BUG FIX: the original also reassigned `f` here, discarding the
        // transformed function and comparing the reference graph against a
        // fresh copy of itself — the test could never catch a wrong fusion.
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// x * Sigmoid(x) must be fused into Swish(x).
TEST(TransformationTests, SwishFusionWithSigmoid) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto sig = std::make_shared<ngraph::opset4::Sigmoid>(input);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sig);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::SwishFusion>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Expected result: a single-input Swish.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto swish = std::make_shared<ngraph::opset4::Swish>(input);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// x * Sigmoid(x * beta) with a non-constant scalar beta must be fused into
// Swish(x, beta).
TEST(TransformationTests, SwishFusionWithSigmoidWithBeta) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto beta = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{});
        auto mul_beta = std::make_shared<ngraph::opset4::Multiply>(input, beta);
        auto sig = std::make_shared<ngraph::opset4::Sigmoid>(mul_beta);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sig);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input, beta});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::SwishFusion>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Expected result: Swish fed by both parameters.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto beta = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{});
        auto swish = std::make_shared<ngraph::opset4::Swish>(input, beta);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input, beta});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// A constant beta with multiple elements that all hold the same value must be
// collapsed to a scalar constant and the pattern fused into Swish(x, beta).
TEST(TransformationTests, SwishFusionWithSigmoidWithBetaConstant) {
    // test where the beta constant has multiple elements with the same value
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto beta = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{3}, {2.0, 2.0, 2.0});
        auto mul_beta = std::make_shared<ngraph::opset4::Multiply>(input, beta);
        auto sig = std::make_shared<ngraph::opset4::Sigmoid>(mul_beta);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sig);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::SwishFusion>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Expected result: the {2, 2, 2} constant is collapsed to a scalar 2.0.
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto beta = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {2.0});
        auto swish = std::make_shared<ngraph::opset4::Swish>(input, beta);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}

View File

@ -444,7 +444,7 @@ extensions/front/tf/ssd_toolbox_multihead_detection_output.json
extensions/front/tf/ssd_v2_support.json
extensions/front/tf/SSDToolboxDetectionOutput.py
extensions/front/tf/swap_deconv_inputs.py
extensions/front/tf/swish.py
extensions/front/tf/swish_ext.py
extensions/front/tf/SwitchMergeOptimization.py
extensions/front/tf/TensorArrayExtractors.py
extensions/front/tf/TensorArrayGatherV3.py

View File

@ -1,37 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.activation_ops import Sigmoid
from extensions.ops.elementwise import Mul
from mo.front.common.replacement import FrontReplacementOp
from mo.graph.graph import Node, Graph
class Swish(FrontReplacementOp):
    """Front replacement that unfolds the TF `swish_f32` op into the explicit
    x * Sigmoid(x) sub-graph."""
    op = "swish_f32"
    enabled = True

    def replace_op(self, graph: Graph, node: Node):
        # Create the Mul and Sigmoid nodes that together compute x * sigmoid(x).
        mul_node = Mul(graph, {'name': node.name + '/mul_'}).create_node()
        sigmoid_node = Sigmoid(graph, {'name': node.name + '/sigmoid_'}).create_node()

        # Connect nodes: the single swish input feeds both Mul (port 0) and Sigmoid.
        node.in_port(0).get_connection().get_source().connect(mul_node.in_port(0))
        node.in_port(0).get_connection().get_source().connect(sigmoid_node.in_port(0))
        sigmoid_node.out_port(0).connect(mul_node.in_port(1))

        # The "explicit" version of the return value is: [(out_node.id, 0)])
        return [mul_node.id]

View File

@ -0,0 +1,29 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.activation_ops import Swish
from mo.front.extractor import FrontExtractorOp
from mo.graph.graph import Node
class SwishExtractor(FrontExtractorOp):
    """Extractor mapping the TF `swish_f32` op onto the internal Swish operation."""
    op = 'swish_f32'
    enabled = True

    @classmethod
    def extract(cls, node: Node):
        # No attributes to extract: an empty dict leaves Swish with its defaults.
        Swish.update_node_stat(node, {})
        return cls.enabled

View File

@ -1,57 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from extensions.front.tf.swish import Swish
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph
# Node attribute templates for build_graph used by the tests below.
nodes_attributes = {
    'placeholder_1': {'shape': np.array([1, 227, 227, 3]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
    'placeholder_2': {'shape': np.array([1, 227, 227, 3]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
    # swish operation
    'swish': {'kind': 'op', 'op': 'swish_f32'},
    # Test operation
    'last': {'type': None, 'value': None, 'kind': 'op', 'op': None},
    # Mul and Sigmoid operations used by the reference (unfolded) graph
    'mul': {'type': 'Multiply', 'kind': 'op', 'op': 'Mul'},
    'sigmoid': {'value': None, 'type': 'Sigmoid', 'kind': 'op', 'op': 'Sigmoid'},
}
class TestSwish(unittest.TestCase):
    """Checks that the swish_f32 front replacement produces x * Sigmoid(x)."""

    def test_swish_test_1(self):
        # swish_f32 must be unfolded into Mul(x, Sigmoid(x)) where both
        # consumers are fed from the same source (placeholder_1).
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'swish'),
                             ('swish', 'last')
                             ], nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'sigmoid', {'out': 0}),
                                 ('placeholder_1', 'mul', {'in': 0, 'out': 0}),
                                 ('sigmoid', 'mul', {'in': 1}),
                                 ('mul', 'last'),
                                 ], nodes_with_edges_only=True)

        graph.stage = 'front'
        Swish().find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
        self.assertTrue(flag, resp)

View File

@ -198,9 +198,9 @@ class LeakyReLU(Op):
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'type': __class__.op,
'op': __class__.op,
'infer': __class__.infer,
'type': self.op,
'op': self.op,
'infer': self.infer,
'in_ports_count': 1,
'out_ports_count': 1,
}, attrs)
@ -265,3 +265,36 @@ class Mish(Activation):
sp_attrs = {'version': 'opset4'}
sp_attrs.update(attrs)
super().__init__(graph, sp_attrs)
class Swish(Op):
    """opset4 Swish activation: Swish(x, beta) = x / (1 + exp(-x * beta)).

    Input port 0 is the data tensor; optional input port 1 is a scalar beta
    (defaults to 1.0 when the port is not connected).
    """
    op = 'Swish'

    def __init__(self, graph: Graph, attrs: dict):
        mandatory_props = {
            'op': self.op,
            'type': self.op,
            'version': 'opset4',

            'infer': self.infer,

            'in_ports_count': 2,
            'out_ports_count': 1,
        }
        super().__init__(graph, mandatory_props, attrs)

    @staticmethod
    def infer(node: Node):
        """Shape inference (output shape == input shape) plus constant folding
        of the value when both the data input and beta are known."""
        node_name = node.soft_get('name', node.id)
        node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())

        beta = 1.0  # default beta when the optional second input is absent
        if node.is_in_port_connected(1):
            beta = node.in_port(1).data.get_value()
            if beta is not None:
                assert beta.ndim == 0, 'The "beta" value for node {} must be a scalar'.format(node_name)
                beta = beta.item()

        # BUG FIX: read the data value from input port 0 (the original read
        # port 1 — the beta input — so the computed value was wrong and the
        # access could fail when port 1 was not connected).
        input_value = node.in_port(0).data.get_value()
        if input_value is not None and beta is not None:
            node.out_port(0).data.set_value(input_value / (1.0 + np.exp(-input_value * beta)))

View File

@ -150,6 +150,7 @@ from ngraph.opset4 import squared_difference
from ngraph.opset4 import squeeze
from ngraph.opset4 import strided_slice
from ngraph.opset4 import subtract
from ngraph.opset4 import swish
from ngraph.opset4 import tan
from ngraph.opset4 import tanh
from ngraph.opset4 import tensor_iterator

View File

@ -139,6 +139,7 @@ from ngraph.opset1.ops import squared_difference
from ngraph.opset1.ops import squeeze
from ngraph.opset1.ops import strided_slice
from ngraph.opset1.ops import subtract
from ngraph.opset4.ops import swish
from ngraph.opset1.ops import tan
from ngraph.opset1.ops import tanh
from ngraph.opset1.ops import tensor_iterator

View File

@ -147,3 +147,19 @@ def mish(data: NodeInput, name: Optional[str] = None,) -> Node:
:return: The new node which performs Mish
"""
return _get_node_factory_opset4().create("Mish", as_nodes(data), {})
@nameable_op
def swish(
    data: NodeInput,
    beta: Optional[NodeInput] = None,
    name: Optional[str] = None,
) -> Node:
    """Return a node which performs the Swish activation: Swish(x, beta=1.0) = x * sigmoid(x * beta).

    :param data: Tensor with input data of a floating point type.
    :param beta: Optional scalar multiplier of the sigmoid argument; a constant
                 1.0 is substituted when it is not provided.
    :param name: Optional new name for the output node.
    :return: The new node which performs Swish.
    """
    if beta is None:
        # The factory call below always passes two inputs, so materialize the
        # default beta as a scalar constant.
        beta = make_constant_node(1.0, np.float32)
    return _get_node_factory_opset4().create("Swish", as_nodes(data, beta), {})

View File

@ -0,0 +1,41 @@
# ******************************************************************************
# Copyright 2017-2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import numpy as np
import ngraph as ng
from ngraph.impl import Shape, Type
def test_swish_props_with_beta():
    """Swish built with an explicit scalar beta keeps shape and f32 type."""
    dtype = np.float32
    input_param = ng.parameter(Shape([3, 10]), dtype=dtype, name="data")
    beta_param = ng.parameter(Shape([]), dtype=dtype, name="beta")

    node = ng.swish(input_param, beta_param)

    assert node.get_type_name() == "Swish"
    assert node.get_output_size() == 1
    assert list(node.get_output_shape(0)) == [3, 10]
    assert node.get_output_element_type(0) == Type.f32
def test_swish_props_without_beta():
    """Swish built without beta (default 1.0) keeps shape and f32 type."""
    dtype = np.float32
    input_param = ng.parameter(Shape([3, 10]), dtype=dtype, name="data")

    node = ng.swish(input_param)

    assert node.get_type_name() == "Swish"
    assert node.get_output_size() == 1
    assert list(node.get_output_shape(0)) == [3, 10]
    assert node.get_output_element_type(0) == Type.f32

View File

@ -332,6 +332,8 @@ set (SRC
op/subtract.hpp
op/sum.cpp
op/sum.hpp
op/swish.cpp
op/swish.hpp
op/variadic_split.cpp
op/variadic_split.hpp
op/tan.cpp

View File

@ -0,0 +1,140 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/swish.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/swish.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v4::Swish::type_info;

// Single-input form: Swish(data) with the implicit default beta = 1.0.
op::v4::Swish::Swish(const Output<Node>& arg)
    : Op({arg})
{
    constructor_validate_and_infer_types();
}

// Two-input form: Swish(data, beta) where beta is expected to be a scalar
// (enforced in validate_and_infer_types).
op::v4::Swish::Swish(const Output<Node>& arg, const Output<Node>& beta)
    : Op({arg, beta})
{
    constructor_validate_and_infer_types();
}
// Swish has no serialized attributes: beta travels as an optional second
// input rather than as an op attribute, so there is nothing to visit.
bool op::v4::Swish::visit_attributes(AttributeVisitor& visitor)
{
    return true;
}
// Validates the Swish(data) / Swish(data, beta) input contract and sets the
// single output to mirror the data input's type and shape.
void op::v4::Swish::validate_and_infer_types()
{
    const auto num_inputs = input_values().size();
    NODE_VALIDATION_CHECK(this,
                          num_inputs == 1 || num_inputs == 2,
                          "Swish must have 1 or 2 inputs, but it has: ",
                          num_inputs);

    if (num_inputs == 2)
    {
        // beta must match the data element type...
        const auto& data_et = input_value(0).get_element_type();
        const auto& beta_et = input_value(1).get_element_type();
        NODE_VALIDATION_CHECK(this,
                              data_et == beta_et,
                              "Swish inputs must have the same type but they are: ",
                              data_et,
                              " and ",
                              beta_et);

        // ...and be a scalar, when its rank is known statically.
        const auto& beta_shape = get_input_partial_shape(1);
        if (beta_shape.rank().is_static())
        {
            const auto beta_rank = beta_shape.rank().get_length();
            NODE_VALIDATION_CHECK(this,
                                  beta_rank == 0,
                                  "Swish input with beta must be scalar but it has rank: ",
                                  beta_rank);
        }
    }

    set_output_size(1);
    set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
// Clones the op with the same arity as new_args: one input (data) or two
// (data, beta). Any other argument count is rejected instead of being
// silently truncated, which the original code allowed.
shared_ptr<Node> op::v4::Swish::clone_with_new_inputs(const OutputVector& new_args) const
{
    NODE_VALIDATION_CHECK(this,
                          new_args.size() == 1 || new_args.size() == 2,
                          "Swish must have 1 or 2 inputs, but it has: ",
                          new_args.size());
    if (new_args.size() == 1)
    {
        return make_shared<op::v4::Swish>(new_args.at(0));
    }
    else
    {
        return make_shared<op::v4::Swish>(new_args.at(0), new_args.at(1));
    }
}
namespace
{
    // Runs the reference Swish for one concrete element type ET. arg1 (beta)
    // may be null, in which case the reference implementation falls back to
    // the default beta = 1.0.
    template <element::Type_t ET>
    inline bool evaluate(const HostTensorPtr& arg0,
                         const HostTensorPtr& arg1,
                         const HostTensorPtr& out,
                         const size_t count)
    {
        using T = typename element_type_traits<ET>::value_type;
        if (arg1 != nullptr)
        {
            runtime::reference::swish<T>(
                arg0->get_data_ptr<ET>(), arg1->get_data_ptr<ET>(), out->get_data_ptr<ET>(), count);
        }
        else
        {
            runtime::reference::swish<T>(
                arg0->get_data_ptr<ET>(), nullptr, out->get_data_ptr<ET>(), count);
        }
        return true;
    }

    // Dispatches on the data element type; only f16 and f32 are evaluated,
    // any other type returns false (not handled).
    bool evaluate_swish(const HostTensorPtr& arg0,
                        const HostTensorPtr& arg1,
                        const HostTensorPtr& out,
                        const size_t count)
    {
        bool rc = true;
        // Output takes the element type/shape of the data input.
        out->set_unary(arg0);
        switch (arg0->get_element_type())
        {
            // NOTE(review): TYPE_CASE presumably expands to a case label plus
            // a call to the evaluate<ET> helper above -- confirm against the
            // macro definition in this file's includes.
            TYPE_CASE(f16)(arg0, arg1, out, count);
            break;
            TYPE_CASE(f32)(arg0, arg1, out, count);
            break;
        default: rc = false; break;
        }
        return rc;
    }
}
bool op::v4::Swish::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
{
if (inputs.size() == 2)
{
return evaluate_swish(inputs[0], inputs[1], outputs[0], shape_size(get_output_shape(0)));
}
else
{
return evaluate_swish(inputs[0], nullptr, outputs[0], shape_size(get_output_shape(0)));
}
}

View File

@ -0,0 +1,57 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
    namespace op
    {
        namespace v4
        {
            /// \brief A Swish Activation Function
            /// f(x) =  x / (1.0 + exp(-beta * x)) or
            /// f(x) = x * sigmoid(beta * x)
            ///
            class NGRAPH_API Swish : public ngraph::op::Op
            {
            public:
                static constexpr NodeTypeInfo type_info{"Swish", 4};
                const NodeTypeInfo& get_type_info() const override { return type_info; }
                Swish() = default;

                /// \brief Constructs a Swish operation with an explicit beta input.
                ///
                /// \param arg  Input tensor
                /// \param beta Scalar with beta value. If the argument is not specified then use
                ///        the default value 1.0
                Swish(const Output<Node>& arg, const Output<Node>& beta);
                /// \brief Constructs a Swish operation with the default beta = 1.0.
                explicit Swish(const Output<Node>& arg);

                /// Swish has no attributes: beta is an optional input, not an attribute.
                bool visit_attributes(AttributeVisitor& visitor) override;
                void validate_and_infer_types() override;
                virtual std::shared_ptr<Node>
                    clone_with_new_inputs(const OutputVector& new_args) const override;
                /// Host-tensor evaluation for f16/f32 inputs.
                bool evaluate(const HostTensorVector& outputs,
                              const HostTensorVector& inputs) const override;
            };
        }
    }
}

View File

@ -157,6 +157,7 @@
#include "ngraph/op/strided_slice.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/swish.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/tensor_iterator.hpp"

View File

@ -155,6 +155,7 @@ NGRAPH_OP(TopK, ngraph::op::v3)
NGRAPH_OP(Acosh, ngraph::op::v3)
NGRAPH_OP(Asinh, ngraph::op::v3)
NGRAPH_OP(Atanh, ngraph::op::v3)
// CTCLoss was listed twice in this table; a single registration suffices
// (a duplicate NGRAPH_OP entry re-inserts the same op into the opset).
NGRAPH_OP(CTCLoss, ngraph::op::v4)
NGRAPH_OP(NonMaxSuppression, ngraph::op::v4)
NGRAPH_OP(Mish, ngraph::op::v4)
NGRAPH_OP(Swish, ngraph::op::v4)

View File

@ -0,0 +1,43 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cmath>
#include <cstddef>
namespace ngraph
{
    namespace runtime
    {
        namespace reference
        {
            /// \brief Reference implementation of Swish:
            ///        out[i] = arg[i] / (1 + exp(-beta * arg[i])).
            ///
            /// \param arg   Pointer to `count` input elements.
            /// \param beta  Optional pointer to a scalar beta; when null the
            ///              default beta = 1.0 is used.
            /// \param out   Pointer to `count` output elements.
            /// \param count Number of elements to process.
            template <typename T>
            void swish(const T* arg, const T* beta, T* out, size_t count)
            {
                const T beta_value = (beta != nullptr) ? beta[0] : static_cast<T>(1.0);
                for (size_t i = 0; i < count; i++)
                {
                    // Compute in T: the original `1.0` double literal promoted
                    // the whole expression to double and narrowed it back,
                    // which is wasteful (and lossy to diagnose) for float/f16.
                    out[i] = arg[i] /
                             (static_cast<T>(1.0) + static_cast<T>(std::exp(-arg[i] * beta_value)));
                }
            }
        }
    }
}

View File

@ -77,6 +77,7 @@ set(SRC
op_eval/non_zero.cpp
op_eval/split.cpp
op_eval/strided_slice.cpp
op_eval/swish.cpp
op_is.cpp
opset1.cpp
partial_shape.cpp
@ -165,6 +166,7 @@ set(SRC
type_prop/squared_difference.cpp
type_prop/squeeze.cpp
type_prop/sum.cpp
type_prop/swish.cpp
type_prop/reduce_prod.cpp
type_prop/reduce_sum.cpp
type_prop/tile.cpp

View File

@ -0,0 +1,90 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "ngraph/op/swish.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/validation_util.hpp"
#include "runtime/backend.hpp"
#include "util/test_tools.hpp"
using namespace std;
using namespace ngraph;
// Evaluates Swish(data, beta) with beta = 1.0 and checks the results against
// precomputed reference values.
TEST(op_eval, swish_with_beta1)
{
    auto p = make_shared<op::Parameter>(element::f32, Shape{3});
    auto beta = make_shared<op::Parameter>(element::f32, Shape{});
    auto swish = make_shared<op::v4::Swish>(p, beta);
    auto fun = make_shared<Function>(OutputVector{swish}, ParameterVector{p, beta});

    std::vector<float> inputs{-0.5, 0.0, 0.5};
    std::vector<float> expected_result{-0.18877034, 0.0, 0.31122968};

    auto result = make_shared<HostTensor>();
    ASSERT_TRUE(fun->evaluate({result},
                              {make_host_tensor<element::Type_t::f32>(Shape{3}, inputs),
                               make_host_tensor<element::Type_t::f32>(Shape{}, {1.0})}));
    EXPECT_EQ(result->get_element_type(), element::f32);
    EXPECT_EQ(result->get_shape(), Shape{3});
    auto result_data = read_vector<float>(result);
    // size_t index: `auto i = 0` deduced int and compared signed vs the
    // unsigned inputs.size().
    for (size_t i = 0; i < inputs.size(); i++)
        EXPECT_NEAR(result_data[i], expected_result[i], 0.000001);
}
// Evaluates Swish(data, beta) with a non-default beta = 0.75 and checks the
// results against precomputed reference values.
TEST(op_eval, swish_with_beta0_75)
{
    auto p = make_shared<op::Parameter>(element::f32, Shape{3});
    auto beta = make_shared<op::Parameter>(element::f32, Shape{});
    auto swish = make_shared<op::v4::Swish>(p, beta);
    auto fun = make_shared<Function>(OutputVector{swish}, ParameterVector{p, beta});

    std::vector<float> inputs{-0.5, 0.0, 0.5};
    std::vector<float> expected_result{-0.2036667, 0.0, 0.2963333};

    auto result = make_shared<HostTensor>();
    ASSERT_TRUE(fun->evaluate({result},
                              {make_host_tensor<element::Type_t::f32>(Shape{3}, inputs),
                               make_host_tensor<element::Type_t::f32>(Shape{}, {0.75})}));
    EXPECT_EQ(result->get_element_type(), element::f32);
    EXPECT_EQ(result->get_shape(), Shape{3});
    auto result_data = read_vector<float>(result);
    // size_t index: `auto i = 0` deduced int and compared signed vs the
    // unsigned inputs.size().
    for (size_t i = 0; i < inputs.size(); i++)
        EXPECT_NEAR(result_data[i], expected_result[i], 0.000001);
}
// Evaluates the single-input Swish(data); the implicit beta = 1.0 must give
// the same results as the explicit beta = 1.0 test above.
TEST(op_eval, swish_without_beta)
{
    auto p = make_shared<op::Parameter>(element::f32, Shape{3});
    auto swish = make_shared<op::v4::Swish>(p);
    auto fun = make_shared<Function>(OutputVector{swish}, ParameterVector{p});

    std::vector<float> inputs{-0.5, 0.0, 0.5};
    std::vector<float> expected_result{-0.18877034, 0.0, 0.31122968};

    auto result = make_shared<HostTensor>();
    ASSERT_TRUE(
        fun->evaluate({result}, {make_host_tensor<element::Type_t::f32>(Shape{3}, inputs)}));
    EXPECT_EQ(result->get_element_type(), element::f32);
    EXPECT_EQ(result->get_shape(), Shape{3});
    auto result_data = read_vector<float>(result);
    // size_t index: `auto i = 0` deduced int and compared signed vs the
    // unsigned inputs.size().
    for (size_t i = 0; i < inputs.size(); i++)
        EXPECT_NEAR(result_data[i], expected_result[i], 0.000001);
}

View File

@ -0,0 +1,95 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
// Static shape in -> identical static shape and element type out.
TEST(type_prop, swish)
{
    const auto input = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
    const auto swish_node = make_shared<op::v4::Swish>(input);

    EXPECT_EQ(swish_node->get_element_type(), element::f32);
    EXPECT_EQ(swish_node->get_shape(), input->get_output_shape(0));
}
// Dynamic dimensions and fully-dynamic rank must propagate unchanged.
TEST(type_prop, swish_partial)
{
    const auto input =
        make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
    const auto swish_node = make_shared<op::v4::Swish>(input);

    EXPECT_EQ(swish_node->get_element_type(), element::f32);
    ASSERT_TRUE(
        swish_node->get_output_partial_shape(0).same_scheme(input->get_output_partial_shape(0)));

    // rank unknown
    const auto dynamic_param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    const auto swish_dynamic = make_shared<op::v4::Swish>(dynamic_param);
    ASSERT_TRUE(swish_dynamic->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
// A partially-dynamic shape with known rank keeps that static rank.
TEST(type_prop, swish_partial_static_rank)
{
    const auto input =
        make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
    const auto swish_node = make_shared<op::v4::Swish>(input);

    EXPECT_EQ(swish_node->get_element_type(), element::f32);
    ASSERT_TRUE(
        swish_node->get_output_partial_shape(0).same_scheme(input->get_output_partial_shape(0)));
    ASSERT_TRUE(swish_node->get_output_partial_shape(0).rank().is_static());
}
// Mismatched data/beta element types must be rejected at construction.
TEST(type_prop, swish_incompatible_types)
{
    const auto input = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
    const auto beta = make_shared<op::Parameter>(element::f16, Shape{});
    try
    {
        const auto swish_node = make_shared<op::v4::Swish>(input, beta);
        FAIL() << "swish_func node was created with incompatible input data types.";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Swish inputs must have the same type"));
    }
}
// A rank-1 beta must be rejected: the spec requires a scalar beta.
TEST(type_prop, swish_beta_not_scalar)
{
    auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
    auto beta = make_shared<op::Parameter>(element::f32, Shape{1});
    try
    {
        const auto swish_func = make_shared<op::v4::Swish>(data, beta);
        // Original message was inverted ("with scalar beta value") -- the
        // failure here is that construction succeeded with a NON-scalar beta.
        FAIL() << "swish_func node was created with a non-scalar beta value.";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Swish input with beta must be scalar"));
    }
}
// The two-input form with a scalar beta infers the same type/shape as data.
TEST(type_prop, swish_2_inputs)
{
    const auto input = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
    const auto beta = make_shared<op::Parameter>(element::f32, Shape{});
    const auto swish_node = make_shared<op::v4::Swish>(input, beta);

    EXPECT_EQ(swish_node->get_element_type(), element::f32);
    ASSERT_TRUE(swish_node->get_output_partial_shape(0).same_scheme(input->get_output_shape(0)));
    ASSERT_TRUE(swish_node->get_output_partial_shape(0).rank().is_static());
}