Enable swish (#1682)
* Draft version of the Swish nGraph operation and fusing transformations for different approaches to express the operation
* Swish fusing transformation refactoring
* Added Swish operation and extractor for TF. Removed unfolding transformation for the operation.
* Added SwishIE. Implemented transformation to convert Swish to SwishIE.
* Code style fixes
* Updated Swish reference implementation. Added tests for shape and value inference
* Fixed code style for Python API
* Fixed unit test
* Apply review comments
* Use matcher_pass_callback
* Make m_alpha attribute protected in the SwishIE operation
* Fixed Swish op PythonAPI test
Parent: 600ad8d180
Commit: 318d38770b
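For reference, the Swish activation added by this commit is Swish(x, beta) = x * sigmoid(beta * x) = x / (1 + exp(-beta * x)), with beta defaulting to 1.0 when the second input is absent. A minimal NumPy sketch of that formula (an illustration only, not part of the commit; the function name is hypothetical):

import numpy as np

def swish_reference(x: np.ndarray, beta: float = 1.0) -> np.ndarray:
    # Swish(x, beta) = x * sigmoid(beta * x) = x / (1 + exp(-beta * x))
    return x / (1.0 + np.exp(-beta * x))

# Matches the expected values used by the op_eval tests later in this commit.
print(swish_reference(np.array([-0.5, 0.0, 0.5], dtype=np.float32)))  # ~[-0.1888, 0.0, 0.3112]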
@@ -496,6 +496,16 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr
    });

    addSpecificCreator({"SwishIE"}, [](const std::shared_ptr<::ngraph::Node>& node,
                                       const std::map<std::string, std::string> params) -> CNNLayerPtr {
        LayerParams attrs = {node->get_friendly_name(), "Swish",
                             details::convertPrecision(node->get_output_element_type(0))};
        auto res = std::make_shared<InferenceEngine::CNNLayer>(attrs);
        res->params = params;
        return res;
    });

    addSpecificCreator({"PriorBox"}, [](const std::shared_ptr<::ngraph::Node>& node,
                                        const std::map<std::string, std::string> params) -> CNNLayerPtr {
        THROW_IE_EXCEPTION << "PriorBox operation has a form that is not supported." << node->get_friendly_name()
@@ -0,0 +1,32 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>

#include <transformations_visibility.hpp>

#include "ngraph/op/op.hpp"

namespace ngraph {
namespace op {
class TRANSFORMATIONS_API SwishIE : public Op {
public:
    static constexpr NodeTypeInfo type_info{"SwishIE", 1};
    const NodeTypeInfo &get_type_info() const override { return type_info; }

    explicit SwishIE(const Output<Node> &input, float alpha = 1.0);

    void validate_and_infer_types() override;
    bool visit_attributes(AttributeVisitor& visitor) override;
    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector &new_args) const override;

    void set_alpha(float alpha);
    float get_alpha() const;
protected:
    float m_alpha;
};
}  // namespace op
}  // namespace ngraph
@@ -0,0 +1,24 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <vector>
#include <memory>
#include <string>

#include <transformations_visibility.hpp>

#include <ngraph/pass/graph_rewrite.hpp>

namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API ConvertSwishToSwishIEMatcher;
}  // namespace pass
}  // namespace ngraph

class ngraph::pass::ConvertSwishToSwishIEMatcher: public ngraph::pass::MatcherPass {
public:
    ConvertSwishToSwishIEMatcher();
};
@@ -0,0 +1,73 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <utility>

#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API SwishFusion;
class TRANSFORMATIONS_API SwishFusionWithSigmoid;
class TRANSFORMATIONS_API SwishFusionWithSigmoidWithBeta;
class TRANSFORMATIONS_API SwishFusionWithBeta;
class TRANSFORMATIONS_API SwishFusionWithoutBeta;

}  // namespace pass
}  // namespace ngraph

/**
 * @ingroup ie_transformation_common_api
 * @brief SwishFusion transformation replaces various sub-graphs with a Swish op.
 */
class ngraph::pass::SwishFusion: public ngraph::pass::GraphRewrite {
public:
    SwishFusion() {
        add_matcher<ngraph::pass::SwishFusionWithSigmoid>();
        add_matcher<ngraph::pass::SwishFusionWithSigmoidWithBeta>();
        add_matcher<ngraph::pass::SwishFusionWithBeta>();
        add_matcher<ngraph::pass::SwishFusionWithoutBeta>();
    }
};

/**
 * @ingroup ie_transformation_common_api
 * @brief SwishFusionWithSigmoid replaces the sub-graph x * Sigmoid(x) with a Swish op.
 */
class ngraph::pass::SwishFusionWithSigmoid: public ngraph::pass::MatcherPass {
public:
    SwishFusionWithSigmoid();
};

/**
 * @ingroup ie_transformation_common_api
 * @brief SwishFusionWithSigmoidWithBeta replaces the sub-graph x * Sigmoid(x * beta) with a Swish op.
 */
class ngraph::pass::SwishFusionWithSigmoidWithBeta: public ngraph::pass::MatcherPass {
public:
    SwishFusionWithSigmoidWithBeta();
};

/**
 * @ingroup ie_transformation_common_api
 * @brief SwishFusionWithBeta replaces the sub-graph x / (1.0 + exp(-x * beta)) with a Swish op.
 */
class ngraph::pass::SwishFusionWithBeta: public ngraph::pass::MatcherPass {
public:
    SwishFusionWithBeta();
};

/**
 * @ingroup ie_transformation_common_api
 * @brief SwishFusionWithoutBeta replaces the sub-graph x / (1.0 + exp(-x)) with a Swish op.
 */
class ngraph::pass::SwishFusionWithoutBeta: public ngraph::pass::MatcherPass {
public:
    SwishFusionWithoutBeta();
};
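The four matcher passes declared above target algebraically equivalent spellings of the same function, which is why each matched sub-graph can be collapsed into a single Swish node. A quick NumPy check of that equivalence (an illustration only, not part of the commit):

import numpy as np

x = np.linspace(-3.0, 3.0, 7)
beta = 0.75
sigmoid = lambda v: 1.0 / (1.0 + np.exp(-v))

swish = x * sigmoid(beta * x)                              # the target Swish(x, beta)
assert np.allclose(x * sigmoid(x * beta), swish)           # SwishFusionWithSigmoidWithBeta pattern
assert np.allclose(x / (1.0 + np.exp(-x * beta)), swish)   # SwishFusionWithBeta pattern
# With beta == 1 the remaining two patterns coincide as well:
assert np.allclose(x * sigmoid(x), x / (1.0 + np.exp(-x)))  # WithSigmoid == WithoutBeta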
@@ -0,0 +1,44 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "ngraph_ops/swish_ie.hpp"

#include <algorithm>
#include <memory>

#include "ngraph/util.hpp"
#include "ngraph/validation_util.hpp"

using namespace std;
using namespace ngraph;

constexpr NodeTypeInfo op::SwishIE::type_info;

op::SwishIE::SwishIE(const Output<Node> & input, const float alpha)
        : Op({input}), m_alpha(alpha) {
    constructor_validate_and_infer_types();
}

std::shared_ptr<Node> op::SwishIE::clone_with_new_inputs(const OutputVector& new_args) const {
    check_new_args_count(this, new_args);
    return make_shared<SwishIE>(new_args.at(0), m_alpha);
}

bool op::SwishIE::visit_attributes(AttributeVisitor& visitor) {
    visitor.on_attribute("alpha", m_alpha);
    return true;
}

void op::SwishIE::validate_and_infer_types() {
    set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}

void op::SwishIE::set_alpha(float alpha) {
    m_alpha = alpha;
}

float op::SwishIE::get_alpha() const {
    return m_alpha;
}
@@ -12,6 +12,7 @@
#include "transformations/init_node_info.hpp"
#include "transformations/itt.hpp"
#include "transformations/mish_fusion.hpp"
#include "transformations/swish_fusion.hpp"

#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/nop_elimination.hpp>
@@ -34,6 +35,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
    manager.register_pass<ngraph::pass::ConvertScatterElementsToScatter>(); // partially depends on CF
    manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
    manager.register_pass<ngraph::pass::MishFusion>();
    manager.register_pass<ngraph::pass::SwishFusion>();

    manager.set_callback(m_transformation_callback);
    manager.run_passes(f);
@@ -32,6 +32,7 @@
#include <transformations/convert_opset1_to_legacy/convert_strided_slice_to_crop.hpp>
#include <transformations/convert_subtract.hpp>
#include <transformations/convert_opset1_to_legacy/convert_selu_to_selu_ie.hpp>
#include <transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.hpp>
#include <transformations/convert_opset1_to_legacy/convert_tile_to_ie_tile.hpp>
#include <transformations/convert_opset1_to_legacy/convert_topk_to_topk_ie.hpp>
#include <transformations/convert_depth_to_space.hpp>
@@ -129,6 +130,7 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
    anchor->add_matcher<ngraph::pass::ConvertPReLUToReLUIE>();
    anchor->add_matcher<ngraph::pass::ConvertGatherToGatherIEMatcher>();
    anchor->add_matcher<ngraph::pass::ConvertSeluToSeluIEMatcher>();
    anchor->add_matcher<ngraph::pass::ConvertSwishToSwishIEMatcher>();
    anchor->add_matcher<ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(f);
    anchor->add_matcher<ngraph::pass::ConvertGatherTreeToGatherTreeIEMatcher>();
    anchor->add_matcher<ngraph::pass::ConvertTopKToTopKIEMatcher>();
@@ -0,0 +1,46 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.hpp"

#include <memory>

#include <ngraph/opsets/opset4.hpp>

#include <ngraph_ops/swish_ie.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

ngraph::pass::ConvertSwishToSwishIEMatcher::ConvertSwishToSwishIEMatcher() {
    auto swish = ngraph::pattern::wrap_type<ngraph::opset4::Swish>();

    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
        auto swish = std::dynamic_pointer_cast<ngraph::opset4::Swish> (m.get_match_root());
        if (!swish) {
            return false;
        }
        float beta_value = 1.0;
        if (swish->input_values().size() == 2) {
            auto beta_node = swish->input_value(1).get_node_shared_ptr();
            auto beta_const = std::dynamic_pointer_cast<ngraph::opset4::Constant>(beta_node);

            if (!beta_const) {
                return false;
            }
            if (!ngraph::op::util::get_single_value(beta_const, beta_value)) {
                return false;
            }
        }

        auto swish_ie = std::make_shared<ngraph::op::SwishIE>(swish->input(0).get_source_output(), beta_value);
        swish_ie->set_friendly_name(swish->get_friendly_name());
        ngraph::copy_runtime_info(swish, swish_ie);
        ngraph::replace_node(swish, swish_ie);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(swish, "ConvertSwishToSwishIE");
    this->register_matcher(m, callback);
}
@@ -0,0 +1,183 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/swish_fusion.hpp"

#include <memory>

#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

bool check_constant_value(const std::shared_ptr<ngraph::opset4::Constant>& constant) {
    if (!constant) {
        return false;
    }
    if (constant->get_element_type() == ngraph::element::f32 || constant->get_element_type() == ngraph::element::f16) {
        auto data = constant->cast_vector<float>();
        if (data.size() != 1 || data[0] != 1.0) {
            return false;
        }
    } else {
        return false;
    }
    return true;
}

bool check_beta_value(const std::shared_ptr<ngraph::opset4::Constant>& constant) {
    // check that the constant for beta contains only one distinct element
    if (!constant) {
        return false;
    }
    if (constant->get_element_type() == ngraph::element::f32 || constant->get_element_type() == ngraph::element::f16) {
        auto data = constant->cast_vector<float>();
        if (!std::equal(data.begin() + 1, data.end(), data.begin())) {
            return false;
        }
    } else {
        return false;
    }
    return true;
}

ngraph::pass::SwishFusionWithSigmoid::SwishFusionWithSigmoid() {
    // replaces the sub-graph x * Sigmoid(x) with a Swish op
    auto input = ngraph::pattern::any_input();
    auto sigmoid = std::make_shared<ngraph::opset4::Sigmoid>(input);
    auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sigmoid);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
        auto &pattern_to_output = m.get_pattern_value_map();
        auto exp_input = pattern_to_output.at(input);

        auto swish = std::make_shared<ngraph::opset4::Swish>(exp_input);

        swish->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info({pattern_to_output.at(sigmoid).get_node_shared_ptr(),
                                   pattern_to_output.at(mul).get_node_shared_ptr()},
                                  swish);
        ngraph::replace_node(m.get_match_root(), swish);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "SwishWithSigmoidFusion");
    register_matcher(m, callback);
}

ngraph::pass::SwishFusionWithSigmoidWithBeta::SwishFusionWithSigmoidWithBeta() {
    // replaces the sub-graph x * Sigmoid(x * beta) with a Swish op
    auto input = ngraph::pattern::any_input();
    auto beta = ngraph::pattern::any_input();
    auto mul_beta = std::make_shared<ngraph::opset4::Multiply>(input, beta);
    auto sigmoid = std::make_shared<ngraph::opset4::Sigmoid>(mul_beta);
    auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sigmoid);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
        auto &pattern_to_output = m.get_pattern_value_map();
        auto exp_input = pattern_to_output.at(input);
        auto beta_input = pattern_to_output.at(beta);

        auto beta_constant = std::dynamic_pointer_cast<ngraph::opset4::Constant>(beta_input.get_node_shared_ptr());
        Output<Node> new_beta;
        if (beta_constant) {
            if (check_beta_value(beta_constant)) {
                new_beta = opset4::Constant::create(beta_input.get_element_type(), Shape{}, {beta_constant->cast_vector<float>()[0]});
            } else {
                return false;
            }
        } else {
            // if the input is not constant and number of elements is not equal to 1 then we cannot perform fusing
            if (beta_input.get_partial_shape().is_dynamic() || ngraph::shape_size(beta_input.get_shape()) != 1) {
                return false;
            }
            new_beta = beta_input;
        }

        auto swish = std::make_shared<ngraph::opset4::Swish>(exp_input, new_beta);

        swish->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info({pattern_to_output.at(sigmoid).get_node_shared_ptr(),
                                   pattern_to_output.at(mul).get_node_shared_ptr()},
                                  swish);
        ngraph::replace_node(m.get_match_root(), swish);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "SwishWithSigmoidWithBetaFusion");
    register_matcher(m, callback);
}

ngraph::pass::SwishFusionWithBeta::SwishFusionWithBeta() {
    // replaces the sub-graph x / (1.0 + exp(-x * beta)) with a Swish op
    auto input = ngraph::pattern::any_input();
    auto beta = ngraph::pattern::any_input();
    auto mul = std::make_shared<ngraph::opset4::Multiply>(input, beta);
    auto neg = std::make_shared<ngraph::opset4::Negative>(mul);
    auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
    auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto add = std::make_shared<ngraph::opset4::Add>(exp, add_constant);
    auto div = std::make_shared<ngraph::opset4::Divide>(input, add);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
        auto &pattern_to_output = m.get_pattern_value_map();
        auto exp_input = pattern_to_output.at(input);

        auto constant = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
        if (!check_constant_value(constant)) {
            return false;
        }

        auto swish = std::make_shared<ngraph::opset4::Swish>(exp_input, pattern_to_output.at(beta));

        swish->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info({pattern_to_output.at(beta).get_node_shared_ptr(),
                                   pattern_to_output.at(mul).get_node_shared_ptr(),
                                   pattern_to_output.at(neg).get_node_shared_ptr(),
                                   pattern_to_output.at(exp).get_node_shared_ptr(),
                                   pattern_to_output.at(add_constant).get_node_shared_ptr(),
                                   pattern_to_output.at(add).get_node_shared_ptr(),
                                   pattern_to_output.at(div).get_node_shared_ptr()},
                                  swish);
        ngraph::replace_node(m.get_match_root(), swish);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(div, "SwishWithBetaFusion");
    register_matcher(m, callback);
}

ngraph::pass::SwishFusionWithoutBeta::SwishFusionWithoutBeta() {
    // replaces the sub-graph x / (1.0 + exp(-x)) with a Swish op
    auto input = ngraph::pattern::any_input();
    auto neg = std::make_shared<ngraph::opset4::Negative>(input);
    auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
    auto add_constant = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto add = std::make_shared<ngraph::opset4::Add>(exp, add_constant);
    auto div = std::make_shared<ngraph::opset4::Divide>(input, add);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
        auto & pattern_to_output = m.get_pattern_value_map();
        auto exp_input = pattern_to_output.at(input);

        auto constant = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(add_constant).get_node_shared_ptr());
        if (!check_constant_value(constant)) {
            return false;
        }

        auto swish = std::make_shared<ngraph::opset4::Swish>(exp_input);

        swish->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info({pattern_to_output.at(neg).get_node_shared_ptr(),
                                   pattern_to_output.at(exp).get_node_shared_ptr(),
                                   pattern_to_output.at(add_constant).get_node_shared_ptr(),
                                   pattern_to_output.at(add).get_node_shared_ptr(),
                                   pattern_to_output.at(div).get_node_shared_ptr()},
                                  swish);
        ngraph::replace_node(m.get_match_root(), swish);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(div, "SwishWithoutBetaFusion");
    register_matcher(m, callback);
}
@@ -6,12 +6,10 @@

#include <string>
#include <memory>
#include <queue>

#include <ngraph/function.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/visualize_tree.hpp>
#include <transformations/mish_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
@@ -0,0 +1,206 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include <string>
#include <memory>

#include <ngraph/function.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/swish_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>

#include "common_test_utils/ngraph_test_utils.hpp"

using namespace testing;

TEST(TransformationTests, SwishFusionWithBeta) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
        auto beta = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{});
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, beta);
        auto neg = std::make_shared<ngraph::opset4::Negative>(mul);
        auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
        auto constant = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.0});
        auto add = std::make_shared<ngraph::opset4::Add>(exp, constant);
        auto div = std::make_shared<ngraph::opset4::Divide>(input, add);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input, beta});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::SwishFusion>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
        auto beta = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{});
        auto swish = std::make_shared<ngraph::opset4::Swish>(input, beta);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input, beta});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}

TEST(TransformationTests, SwishFusionWithoutBeta) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto neg = std::make_shared<ngraph::opset4::Negative>(input);
        auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
        auto constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.0});
        auto add = std::make_shared<ngraph::opset4::Add>(exp, constant);
        auto div = std::make_shared<ngraph::opset4::Divide>(input, add);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::SwishFusion>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto swish = std::make_shared<ngraph::opset4::Swish>(input);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}

TEST(TransformationTests, SwishFusionWithoutBetaNonOneAddConstant) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto neg = std::make_shared<ngraph::opset4::Negative>(input);
        auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
        auto constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.1});
        auto add = std::make_shared<ngraph::opset4::Add>(exp, constant);
        auto div = std::make_shared<ngraph::opset4::Divide>(input, add);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::SwishFusion>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto neg = std::make_shared<ngraph::opset4::Negative>(input);
        auto exp = std::make_shared<ngraph::opset4::Exp>(neg);
        auto constant = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {1.1});
        auto add = std::make_shared<ngraph::opset4::Add>(exp, constant);
        auto div = std::make_shared<ngraph::opset4::Divide>(input, add);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{div}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}

TEST(TransformationTests, SwishFusionWithSigmoid) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto sig = std::make_shared<ngraph::opset4::Sigmoid>(input);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sig);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::SwishFusion>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto swish = std::make_shared<ngraph::opset4::Swish>(input);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}

TEST(TransformationTests, SwishFusionWithSigmoidWithBeta) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto beta = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{});
        auto mul_beta = std::make_shared<ngraph::opset4::Multiply>(input, beta);
        auto sig = std::make_shared<ngraph::opset4::Sigmoid>(mul_beta);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sig);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input, beta});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::SwishFusion>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto beta = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::Shape{});
        auto swish = std::make_shared<ngraph::opset4::Swish>(input, beta);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input, beta});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}

TEST(TransformationTests, SwishFusionWithSigmoidWithBetaConstant) {
    // test where the beta constant has multiple but the same value
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto beta = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{3}, {2.0, 2.0, 2.0});
        auto mul_beta = std::make_shared<ngraph::opset4::Multiply>(input, beta);
        auto sig = std::make_shared<ngraph::opset4::Sigmoid>(mul_beta);
        auto mul = std::make_shared<ngraph::opset4::Multiply>(input, sig);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::SwishFusion>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        auto input = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic(1));
        auto beta = ngraph::opset4::Constant::create(ngraph::element::f16, ngraph::Shape{}, {2.0});
        auto swish = std::make_shared<ngraph::opset4::Swish>(input, beta);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{swish}, ngraph::ParameterVector{input});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
@@ -444,7 +444,7 @@ extensions/front/tf/ssd_toolbox_multihead_detection_output.json
extensions/front/tf/ssd_v2_support.json
extensions/front/tf/SSDToolboxDetectionOutput.py
extensions/front/tf/swap_deconv_inputs.py
extensions/front/tf/swish.py
extensions/front/tf/swish_ext.py
extensions/front/tf/SwitchMergeOptimization.py
extensions/front/tf/TensorArrayExtractors.py
extensions/front/tf/TensorArrayGatherV3.py
@@ -1,37 +0,0 @@
"""
 Copyright (C) 2018-2020 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

from extensions.ops.activation_ops import Sigmoid
from extensions.ops.elementwise import Mul
from mo.front.common.replacement import FrontReplacementOp
from mo.graph.graph import Node, Graph


class Swish(FrontReplacementOp):
    op = "swish_f32"
    enabled = True

    def replace_op(self, graph: Graph, node: Node):
        mul_node = Mul(graph, {'name': node.name + '/mul_'}).create_node()
        sigmoid_node = Sigmoid(graph, {'name': node.name + '/sigmoid_'}).create_node()

        # Connect nodes
        node.in_port(0).get_connection().get_source().connect(mul_node.in_port(0))
        node.in_port(0).get_connection().get_source().connect(sigmoid_node.in_port(0))
        sigmoid_node.out_port(0).connect(mul_node.in_port(1))

        # The "explicit" version of the return value is: [(out_node.id, 0)])
        return [mul_node.id]
model-optimizer/extensions/front/tf/swish_ext.py (new file)
@@ -0,0 +1,29 @@
"""
 Copyright (C) 2018-2020 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
from extensions.ops.activation_ops import Swish
from mo.front.extractor import FrontExtractorOp
from mo.graph.graph import Node


class SwishExtractor(FrontExtractorOp):
    op = 'swish_f32'
    enabled = True

    @classmethod
    def extract(cls, node: Node):
        Swish.update_node_stat(node, {})
        return cls.enabled
@@ -1,57 +0,0 @@
"""
 Copyright (C) 2018-2020 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import unittest

import numpy as np

from extensions.front.tf.swish import Swish
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph

nodes_attributes = {
    'placeholder_1': {'shape': np.array([1, 227, 227, 3]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
    'placeholder_2': {'shape': np.array([1, 227, 227, 3]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
    # swish operation
    'swish': {'kind': 'op', 'op': 'swish_f32'},
    # Test operation
    'last': {'type': None, 'value': None, 'kind': 'op', 'op': None},
    # Add and Mul operations
    'mul': {'type': 'Multiply', 'kind': 'op', 'op': 'Mul'},
    'sigmoid': {'value': None, 'type': 'Sigmoid', 'kind': 'op', 'op': 'Sigmoid'},
}


class TestSwish(unittest.TestCase):
    def test_swish_test_1(self):
        # Test with two different inputs from two placeholders
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'swish'),
                             ('swish', 'last')
                             ], nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'sigmoid', {'out': 0}),
                                 ('placeholder_1', 'mul', {'in': 0, 'out': 0}),
                                 ('sigmoid', 'mul', {'in': 1}),
                                 ('mul', 'last'),
                                 ], nodes_with_edges_only=True)

        graph.stage = 'front'
        Swish().find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
        self.assertTrue(flag, resp)
@@ -198,9 +198,9 @@ class LeakyReLU(Op):

    def __init__(self, graph: Graph, attrs: dict):
        super().__init__(graph, {
            'type': __class__.op,
            'op': __class__.op,
            'infer': __class__.infer,
            'type': self.op,
            'op': self.op,
            'infer': self.infer,
            'in_ports_count': 1,
            'out_ports_count': 1,
        }, attrs)
@@ -265,3 +265,36 @@ class Mish(Activation):
        sp_attrs = {'version': 'opset4'}
        sp_attrs.update(attrs)
        super().__init__(graph, sp_attrs)


class Swish(Op):
    op = 'Swish'

    def __init__(self, graph: Graph, attrs: dict):
        mandatory_props = {
            'op': self.op,
            'type': self.op,
            'version': 'opset4',

            'infer': self.infer,

            'in_ports_count': 2,
            'out_ports_count': 1,
        }
        super().__init__(graph, mandatory_props, attrs)

    @staticmethod
    def infer(node: Node):
        node_name = node.soft_get('name', node.id)
        node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())

        beta = 1.0
        if node.is_in_port_connected(1):
            beta = node.in_port(1).data.get_value()
            if beta is not None:
                assert beta.ndim == 0, 'The "beta" value for node {} must be a scalar'.format(node_name)
                beta = beta.item()

        input_value = node.in_port(0).data.get_value()
        if input_value is not None and beta is not None:
            node.out_port(0).data.set_value(input_value / (1.0 + np.exp(-input_value * beta)))
@@ -150,6 +150,7 @@ from ngraph.opset4 import squared_difference
from ngraph.opset4 import squeeze
from ngraph.opset4 import strided_slice
from ngraph.opset4 import subtract
from ngraph.opset4 import swish
from ngraph.opset4 import tan
from ngraph.opset4 import tanh
from ngraph.opset4 import tensor_iterator
@@ -139,6 +139,7 @@ from ngraph.opset1.ops import squared_difference
from ngraph.opset1.ops import squeeze
from ngraph.opset1.ops import strided_slice
from ngraph.opset1.ops import subtract
from ngraph.opset4.ops import swish
from ngraph.opset1.ops import tan
from ngraph.opset1.ops import tanh
from ngraph.opset1.ops import tensor_iterator
@@ -147,3 +147,19 @@ def mish(data: NodeInput, name: Optional[str] = None,) -> Node:
    :return: The new node which performs Mish
    """
    return _get_node_factory_opset4().create("Mish", as_nodes(data), {})


@nameable_op
def swish(
    data: NodeInput,
    beta: Optional[NodeInput] = None,
    name: Optional[str] = None,
) -> Node:
    """Return a node which performs the Swish activation function: Swish(x, beta=1.0) = x * sigmoid(x * beta).

    :param data: Tensor with input data of floating-point type.
    :return: The new node which performs Swish.
    """
    if beta is None:
        beta = make_constant_node(1.0, np.float32)
    return _get_node_factory_opset4().create("Swish", as_nodes(data, beta), {})
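A minimal usage sketch of the ng.swish factory added above, mirroring the Python tests introduced later in this commit (assumes the ngraph Python package built from this revision is importable):

import numpy as np
import ngraph as ng
from ngraph.impl import Shape

data = ng.parameter(Shape([3, 10]), dtype=np.float32, name="data")

node = ng.swish(data)  # beta omitted: a 1.0 constant is created internally
beta = ng.parameter(Shape([]), dtype=np.float32, name="beta")
node_with_beta = ng.swish(data, beta)

print(node.get_type_name())            # "Swish"
print(list(node.get_output_shape(0)))  # [3, 10]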
ngraph/python/tests/test_ngraph/test_swish.py (new file)
@@ -0,0 +1,41 @@
# ******************************************************************************
# Copyright 2017-2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import numpy as np
import ngraph as ng
from ngraph.impl import Shape, Type


def test_swish_props_with_beta():
    float_dtype = np.float32
    data = ng.parameter(Shape([3, 10]), dtype=float_dtype, name="data")
    beta = ng.parameter(Shape([]), dtype=float_dtype, name="beta")

    node = ng.swish(data, beta)
    assert node.get_type_name() == "Swish"
    assert node.get_output_size() == 1
    assert list(node.get_output_shape(0)) == [3, 10]
    assert node.get_output_element_type(0) == Type.f32


def test_swish_props_without_beta():
    float_dtype = np.float32
    data = ng.parameter(Shape([3, 10]), dtype=float_dtype, name="data")

    node = ng.swish(data)
    assert node.get_type_name() == "Swish"
    assert node.get_output_size() == 1
    assert list(node.get_output_shape(0)) == [3, 10]
    assert node.get_output_element_type(0) == Type.f32
@@ -332,6 +332,8 @@ set (SRC
    op/subtract.hpp
    op/sum.cpp
    op/sum.hpp
    op/swish.cpp
    op/swish.hpp
    op/variadic_split.cpp
    op/variadic_split.hpp
    op/tan.cpp
ngraph/src/ngraph/op/swish.cpp (new file)
@@ -0,0 +1,140 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "ngraph/op/swish.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"

#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/swish.hpp"

using namespace std;
using namespace ngraph;

constexpr NodeTypeInfo op::v4::Swish::type_info;

op::v4::Swish::Swish(const Output<Node>& arg)
    : Op({arg})
{
    constructor_validate_and_infer_types();
}

op::v4::Swish::Swish(const Output<Node>& arg, const Output<Node>& beta)
    : Op({arg, beta})
{
    constructor_validate_and_infer_types();
}

bool op::v4::Swish::visit_attributes(AttributeVisitor& visitor)
{
    return true;
}

void op::v4::Swish::validate_and_infer_types()
{
    auto inputs_count = input_values().size();
    NODE_VALIDATION_CHECK(this,
                          inputs_count == 1 || inputs_count == 2,
                          "Swish must have 1 or 2 inputs, but it has: ",
                          inputs_count);

    if (inputs_count == 2)
    {
        NODE_VALIDATION_CHECK(this,
                              input_value(0).get_element_type() ==
                                  input_value(1).get_element_type(),
                              "Swish inputs must have the same type but they are: ",
                              input_value(0).get_element_type(),
                              " and ",
                              input_value(1).get_element_type());
        if (get_input_partial_shape(1).rank().is_static())
        {
            auto beta_rank = get_input_partial_shape(1).rank().get_length();
            NODE_VALIDATION_CHECK(this,
                                  beta_rank == 0,
                                  "Swish input with beta must be scalar but it has rank: ",
                                  beta_rank);
        }
    }
    set_output_size(1);
    set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}

shared_ptr<Node> op::v4::Swish::clone_with_new_inputs(const OutputVector& new_args) const
{
    if (new_args.size() == 1)
    {
        return make_shared<op::v4::Swish>(new_args.at(0));
    }
    else
    {
        return make_shared<op::v4::Swish>(new_args.at(0), new_args.at(1));
    }
}

namespace
{
    template <element::Type_t ET>
    inline bool evaluate(const HostTensorPtr& arg0,
                         const HostTensorPtr& arg1,
                         const HostTensorPtr& out,
                         const size_t count)
    {
        using T = typename element_type_traits<ET>::value_type;
        if (arg1 != nullptr)
        {
            runtime::reference::swish<T>(
                arg0->get_data_ptr<ET>(), arg1->get_data_ptr<ET>(), out->get_data_ptr<ET>(), count);
        }
        else
        {
            runtime::reference::swish<T>(
                arg0->get_data_ptr<ET>(), nullptr, out->get_data_ptr<ET>(), count);
        }
        return true;
    }

    bool evaluate_swish(const HostTensorPtr& arg0,
                        const HostTensorPtr& arg1,
                        const HostTensorPtr& out,
                        const size_t count)
    {
        bool rc = true;
        out->set_unary(arg0);

        switch (arg0->get_element_type())
        {
            TYPE_CASE(f16)(arg0, arg1, out, count);
            break;
            TYPE_CASE(f32)(arg0, arg1, out, count);
            break;
        default: rc = false; break;
        }
        return rc;
    }
}

bool op::v4::Swish::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
{
    if (inputs.size() == 2)
    {
        return evaluate_swish(inputs[0], inputs[1], outputs[0], shape_size(get_output_shape(0)));
    }
    else
    {
        return evaluate_swish(inputs[0], nullptr, outputs[0], shape_size(get_output_shape(0)));
    }
}
ngraph/src/ngraph/op/swish.hpp (new file)
@@ -0,0 +1,57 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"

namespace ngraph
{
    namespace op
    {
        namespace v4
        {
            /// \brief A Swish Activation Function
            /// f(x) = x / (1.0 + exp(-beta * x)) or
            /// f(x) = x * sigmoid(beta * x)
            ///
            class NGRAPH_API Swish : public ngraph::op::Op
            {
            public:
                static constexpr NodeTypeInfo type_info{"Swish", 4};
                const NodeTypeInfo& get_type_info() const override { return type_info; }
                Swish() = default;

                /// \brief Constructs a Swish operation.
                ///
                /// \param data Input tensor
                /// \param beta Scalar with beta value. If the argument is not specified then use
                /// the default value 1.0
                Swish(const Output<Node>& arg, const Output<Node>& beta);
                explicit Swish(const Output<Node>& arg);

                bool visit_attributes(AttributeVisitor& visitor) override;
                void validate_and_infer_types() override;

                virtual std::shared_ptr<Node>
                    clone_with_new_inputs(const OutputVector& new_args) const override;
                bool evaluate(const HostTensorVector& outputs,
                              const HostTensorVector& inputs) const override;
            };
        }
    }
}
@@ -157,6 +157,7 @@
#include "ngraph/op/strided_slice.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/swish.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/tensor_iterator.hpp"
@@ -155,6 +155,7 @@ NGRAPH_OP(TopK, ngraph::op::v3)
NGRAPH_OP(Acosh, ngraph::op::v3)
NGRAPH_OP(Asinh, ngraph::op::v3)
NGRAPH_OP(Atanh, ngraph::op::v3)
NGRAPH_OP(CTCLoss, ngraph::op::v4)
NGRAPH_OP(NonMaxSuppression, ngraph::op::v4)
NGRAPH_OP(Mish, ngraph::op::v4)
NGRAPH_OP(CTCLoss, ngraph::op::v4)
NGRAPH_OP(Swish, ngraph::op::v4)
ngraph/src/ngraph/runtime/reference/swish.hpp (new file)
@@ -0,0 +1,43 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <cmath>
#include <cstddef>

namespace ngraph
{
    namespace runtime
    {
        namespace reference
        {
            template <typename T>
            void swish(const T* arg, const T* beta, T* out, size_t count)
            {
                T beta_value = static_cast<T>(1.0);
                if (beta != nullptr)
                {
                    beta_value = beta[0];
                }
                for (size_t i = 0; i < count; i++)
                {
                    out[i] = arg[i] / (1.0 + std::exp(-arg[i] * beta_value));
                }
            }
        }
    }
}
@@ -77,6 +77,7 @@ set(SRC
    op_eval/non_zero.cpp
    op_eval/split.cpp
    op_eval/strided_slice.cpp
    op_eval/swish.cpp
    op_is.cpp
    opset1.cpp
    partial_shape.cpp
@@ -165,6 +166,7 @@ set(SRC
    type_prop/squared_difference.cpp
    type_prop/squeeze.cpp
    type_prop/sum.cpp
    type_prop/swish.cpp
    type_prop/reduce_prod.cpp
    type_prop/reduce_sum.cpp
    type_prop/tile.cpp
ngraph/test/op_eval/swish.cpp (new file)
@@ -0,0 +1,90 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <string>
#include <vector>

#include "gtest/gtest.h"

#include "ngraph/op/swish.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/validation_util.hpp"
#include "runtime/backend.hpp"
#include "util/test_tools.hpp"

using namespace std;
using namespace ngraph;

TEST(op_eval, swish_with_beta1)
{
    auto p = make_shared<op::Parameter>(element::f32, Shape{3});
    auto beta = make_shared<op::Parameter>(element::f32, Shape{});
    auto swish = make_shared<op::v4::Swish>(p, beta);
    auto fun = make_shared<Function>(OutputVector{swish}, ParameterVector{p, beta});

    std::vector<float> inputs{-0.5, 0.0, 0.5};
    std::vector<float> expected_result{-0.18877034, 0.0, 0.31122968};

    auto result = make_shared<HostTensor>();
    ASSERT_TRUE(fun->evaluate({result},
                              {make_host_tensor<element::Type_t::f32>(Shape{3}, inputs),
                               make_host_tensor<element::Type_t::f32>(Shape{}, {1.0})}));
    EXPECT_EQ(result->get_element_type(), element::f32);
    EXPECT_EQ(result->get_shape(), Shape{3});
    auto result_data = read_vector<float>(result);
    for (auto i = 0; i < inputs.size(); i++)
        EXPECT_NEAR(result_data[i], expected_result[i], 0.000001);
}

TEST(op_eval, swish_with_beta0_75)
{
    auto p = make_shared<op::Parameter>(element::f32, Shape{3});
    auto beta = make_shared<op::Parameter>(element::f32, Shape{});
    auto swish = make_shared<op::v4::Swish>(p, beta);
    auto fun = make_shared<Function>(OutputVector{swish}, ParameterVector{p, beta});

    std::vector<float> inputs{-0.5, 0.0, 0.5};
    std::vector<float> expected_result{-0.2036667, 0.0, 0.2963333};

    auto result = make_shared<HostTensor>();
    ASSERT_TRUE(fun->evaluate({result},
                              {make_host_tensor<element::Type_t::f32>(Shape{3}, inputs),
                               make_host_tensor<element::Type_t::f32>(Shape{}, {0.75})}));
    EXPECT_EQ(result->get_element_type(), element::f32);
    EXPECT_EQ(result->get_shape(), Shape{3});
    auto result_data = read_vector<float>(result);
    for (auto i = 0; i < inputs.size(); i++)
        EXPECT_NEAR(result_data[i], expected_result[i], 0.000001);
}

TEST(op_eval, swish_without_beta)
{
    auto p = make_shared<op::Parameter>(element::f32, Shape{3});
    auto swish = make_shared<op::v4::Swish>(p);
    auto fun = make_shared<Function>(OutputVector{swish}, ParameterVector{p});

    std::vector<float> inputs{-0.5, 0.0, 0.5};
    std::vector<float> expected_result{-0.18877034, 0.0, 0.31122968};

    auto result = make_shared<HostTensor>();
    ASSERT_TRUE(
        fun->evaluate({result}, {make_host_tensor<element::Type_t::f32>(Shape{3}, inputs)}));
    EXPECT_EQ(result->get_element_type(), element::f32);
    EXPECT_EQ(result->get_shape(), Shape{3});
    auto result_data = read_vector<float>(result);
    for (auto i = 0; i < inputs.size(); i++)
        EXPECT_NEAR(result_data[i], expected_result[i], 0.000001);
}
ngraph/test/type_prop/swish.cpp (new file)
@@ -0,0 +1,95 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"

using namespace std;
using namespace ngraph;

TEST(type_prop, swish)
{
    auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
    auto swish_func = make_shared<op::v4::Swish>(data);
    EXPECT_EQ(swish_func->get_element_type(), element::f32);
    EXPECT_EQ(swish_func->get_shape(), data->get_output_shape(0));
}

TEST(type_prop, swish_partial)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
    auto swish_func = make_shared<op::v4::Swish>(data);
    EXPECT_EQ(swish_func->get_element_type(), element::f32);
    ASSERT_TRUE(
        swish_func->get_output_partial_shape(0).same_scheme(data->get_output_partial_shape(0)));

    // rank unknown
    auto swish_partial = make_shared<op::v4::Swish>(
        make_shared<op::Parameter>(element::f32, PartialShape::dynamic()));
    ASSERT_TRUE(swish_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}

TEST(type_prop, swish_partial_static_rank)
{
    auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
    auto swish_func = make_shared<op::v4::Swish>(data);
    EXPECT_EQ(swish_func->get_element_type(), element::f32);
    ASSERT_TRUE(
        swish_func->get_output_partial_shape(0).same_scheme(data->get_output_partial_shape(0)));
    ASSERT_TRUE(swish_func->get_output_partial_shape(0).rank().is_static());
}

TEST(type_prop, swish_incompatible_types)
{
    auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
    auto beta = make_shared<op::Parameter>(element::f16, Shape{});
    try
    {
        const auto swish_func = make_shared<op::v4::Swish>(data, beta);
        FAIL() << "swish_func node was created with incompatible input data types.";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Swish inputs must have the same type"));
    }
}

TEST(type_prop, swish_beta_not_scalar)
{
    auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
    auto beta = make_shared<op::Parameter>(element::f32, Shape{1});
    try
    {
        const auto swish_func = make_shared<op::v4::Swish>(data, beta);
        FAIL() << "swish_func node was created with a non-scalar beta value.";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Swish input with beta must be scalar"));
    }
}

TEST(type_prop, swish_2_inputs)
{
    auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
    auto beta = make_shared<op::Parameter>(element::f32, Shape{});
    const auto swish_func = make_shared<op::v4::Swish>(data, beta);

    EXPECT_EQ(swish_func->get_element_type(), element::f32);
    ASSERT_TRUE(swish_func->get_output_partial_shape(0).same_scheme(data->get_output_shape(0)));
    ASSERT_TRUE(swish_func->get_output_partial_shape(0).rank().is_static());
}