[nGraph] [MO] Add ngraph transformations for PRelu fusing (#10209)

* MO: add PRelu fusing pattern with Sub

* PRelu fusing

* Apply review comments

* Code style
This commit is contained in:
Nadezhda Ageeva 2022-03-15 18:57:59 +03:00 committed by GitHub
parent 93722fe101
commit 3c721b0f03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 417 additions and 0 deletions

View File

@ -0,0 +1,115 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <ngraph/pass/graph_rewrite.hpp>
#include <transformations_visibility.hpp>
#include <utility>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API PReluFusion;
class TRANSFORMATIONS_API PReluFusionNegativeAdd;
class TRANSFORMATIONS_API PReluFusionNegativeSub;
class TRANSFORMATIONS_API PReluFusionMultiplyAdd;
class TRANSFORMATIONS_API PReluFusionMultiplySub;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief PReluFusionNegativeAdd transformation replaces a sub-graph
* Op
* / \
* Relu Negative
* | |
* | Relu
* | |
* | Negative
* | |
* | Multiply
* \ /
* Add
*/
class ngraph::pass::PReluFusionNegativeAdd : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
PReluFusionNegativeAdd();
};
/**
* @ingroup ie_transformation_common_api
* @brief PReluFusionNegativeSub transformation replaces a sub-graph
* Op
* / \
* Relu Negative
* | |
* | Relu
* | |
* | Multiply
* \ /
* Sub
*/
class ngraph::pass::PReluFusionNegativeSub : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
PReluFusionNegativeSub();
};
/**
* @ingroup ie_transformation_common_api
* @brief PReluFusionMultiplyAdd transformation replaces a sub-graph
* Op
* / \
* Relu Multiply (-1)
* | |
* | Relu
* | |
* | Multiply
* \ /
* Add
*/
class ngraph::pass::PReluFusionMultiplyAdd : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
PReluFusionMultiplyAdd();
};
/**
* @ingroup ie_transformation_common_api
* @brief PReluFusionMultiplySub transformation replaces a sub-graph
* Op
* / \
* Relu Multiply (-1)
* | |
* | Relu
* | |
* | Multiply
* \ /
* Sub
*/
class ngraph::pass::PReluFusionMultiplySub : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
PReluFusionMultiplySub();
};
/**
* @ingroup ie_transformation_common_api
* @brief PReluFusion transformation replaces various sub-graphs with a PRelu op.
*/
class ngraph::pass::PReluFusion : public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
PReluFusion() {
add_matcher<ngraph::pass::PReluFusionNegativeAdd>();
add_matcher<ngraph::pass::PReluFusionNegativeSub>();
add_matcher<ngraph::pass::PReluFusionMultiplyAdd>();
add_matcher<ngraph::pass::PReluFusionMultiplySub>();
}
};

View File

@ -35,6 +35,7 @@
#include <transformations/common_optimizations/normalize_l2_fusion.hpp>
#include <transformations/common_optimizations/optimize_strided_slice.hpp>
#include <transformations/common_optimizations/pad_fusion.hpp>
#include <transformations/common_optimizations/prelu_fusion.hpp>
#include <transformations/common_optimizations/random_uniform_fusion.hpp>
#include <transformations/common_optimizations/remove_concat_zero_dim_input.hpp>
#include <transformations/common_optimizations/remove_filtering_boxes_by_size.hpp>
@ -153,6 +154,7 @@ bool ngraph::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph
common_fusions->add_matcher<ngraph::pass::TransposeToReshape>();
common_fusions->add_matcher<ngraph::pass::ReshapeSequenceFusion>(m_use_shapes);
common_fusions->add_matcher<ngraph::pass::MatMulConstTransposesExtraction>();
common_fusions->add_matcher<ngraph::pass::PReluFusion>();
common_fusions->set_name("ngraph::pass::CommonFusions");
manager.register_pass<ngraph::pass::BinarizeWeights>();

View File

@ -0,0 +1,163 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#define _USE_MATH_DEFINES
#include "transformations/common_optimizations/prelu_fusion.hpp"
#include <memory>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include "itt.hpp"
#include "transformations/utils/utils.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::PReluFusion, "PReluFusion", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::PReluFusionNegativeAdd, "PReluFusionNegativeAdd", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::PReluFusionNegativeSub, "PReluFusionNegativeSub", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::PReluFusionMultiplyAdd, "PReluFusionMultiplyAdd", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::PReluFusionMultiplySub, "PReluFusionMultiplySub", 0);
ngraph::pass::PReluFusionNegativeAdd::PReluFusionNegativeAdd() {
MATCHER_SCOPE(PReluFusionNegativeAdd);
auto input = ngraph::pattern::any_input();
auto relu_pos = ngraph::pattern::wrap_type<ngraph::opset8::Relu>({input});
auto neg1 = ngraph::pattern::wrap_type<ngraph::opset8::Negative>({input});
auto relu_neg = ngraph::pattern::wrap_type<ngraph::opset8::Relu>({neg1});
auto neg2 = ngraph::pattern::wrap_type<ngraph::opset8::Negative>({relu_neg});
auto mul_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto mul = ngraph::pattern::wrap_type<ngraph::opset8::Multiply>({neg2, mul_constant});
auto add = ngraph::pattern::wrap_type<ngraph::opset8::Add>({relu_pos, mul});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto input_output = pattern_to_output.at(input);
auto slope_output = pattern_to_output.at(mul_constant);
auto add_node = pattern_to_output.at(add).get_node_shared_ptr();
auto prelu = std::make_shared<ngraph::opset8::PRelu>(input_output, slope_output);
prelu->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::NodeVector copy_from = {pattern_to_output.at(relu_pos).get_node_shared_ptr(),
pattern_to_output.at(neg1).get_node_shared_ptr(),
pattern_to_output.at(relu_neg).get_node_shared_ptr(),
pattern_to_output.at(neg2).get_node_shared_ptr(),
pattern_to_output.at(mul).get_node_shared_ptr(),
pattern_to_output.at(add).get_node_shared_ptr()};
ngraph::copy_runtime_info(copy_from, prelu);
ngraph::replace_node(add_node, prelu);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(add, matcher_name);
register_matcher(m, callback);
}
ngraph::pass::PReluFusionNegativeSub::PReluFusionNegativeSub() {
MATCHER_SCOPE(PReluFusionNegativeSub);
auto input = ngraph::pattern::any_input();
auto relu_pos = ngraph::pattern::wrap_type<ngraph::opset8::Relu>({input});
auto neg1 = ngraph::pattern::wrap_type<ngraph::opset8::Negative>({input});
auto relu_neg = ngraph::pattern::wrap_type<ngraph::opset8::Relu>({neg1});
auto mul_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto mul = ngraph::pattern::wrap_type<ngraph::opset8::Multiply>({relu_neg, mul_constant});
auto sub = ngraph::pattern::wrap_type<ngraph::opset8::Subtract>({relu_pos, mul});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto input_output = pattern_to_output.at(input);
auto slope_output = pattern_to_output.at(mul_constant);
auto sub_node = pattern_to_output.at(sub).get_node_shared_ptr();
auto prelu = std::make_shared<ngraph::opset8::PRelu>(input_output, slope_output);
prelu->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::NodeVector copy_from = {pattern_to_output.at(relu_pos).get_node_shared_ptr(),
pattern_to_output.at(neg1).get_node_shared_ptr(),
pattern_to_output.at(relu_neg).get_node_shared_ptr(),
pattern_to_output.at(mul).get_node_shared_ptr(),
pattern_to_output.at(sub).get_node_shared_ptr()};
ngraph::copy_runtime_info(copy_from, prelu);
ngraph::replace_node(sub_node, prelu);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(sub, matcher_name);
register_matcher(m, callback);
}
static std::function<bool(ngraph::Output<ngraph::Node>)> constant_value(const float target_value) {
return [=](const ngraph::Output<ngraph::Node>& output) -> bool {
auto node = std::dynamic_pointer_cast<ngraph::opset8::Constant>(output.get_node_shared_ptr());
if (!node) {
return false;
}
float value;
if (!ngraph::op::util::get_single_value(node, value)) {
return false;
}
return value == target_value;
};
}
ngraph::pass::PReluFusionMultiplyAdd::PReluFusionMultiplyAdd() {
MATCHER_SCOPE(PReluFusionMultiplyAdd);
auto input = ngraph::pattern::any_input();
auto relu_pos = ngraph::pattern::wrap_type<ngraph::opset8::Relu>({input});
auto mul_neg_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>(constant_value(-1.0));
auto mul_neg = ngraph::pattern::wrap_type<ngraph::opset8::Multiply>({input, mul_neg_constant});
auto relu_neg = ngraph::pattern::wrap_type<ngraph::opset8::Relu>({mul_neg});
auto mul_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto mul = ngraph::pattern::wrap_type<ngraph::opset8::Multiply>({relu_neg, mul_constant});
auto add = ngraph::pattern::wrap_type<ngraph::opset8::Add>({relu_pos, mul});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto input_output = pattern_to_output.at(input);
auto slope_output = pattern_to_output.at(mul_constant);
auto add_node = pattern_to_output.at(add).get_node_shared_ptr();
auto negative = ngraph::op::util::make_try_fold<ngraph::opset8::Negative>(slope_output);
auto prelu = std::make_shared<ngraph::opset8::PRelu>(input_output, negative);
prelu->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::NodeVector copy_from = {pattern_to_output.at(relu_pos).get_node_shared_ptr(),
pattern_to_output.at(mul_neg).get_node_shared_ptr(),
pattern_to_output.at(relu_neg).get_node_shared_ptr(),
pattern_to_output.at(mul).get_node_shared_ptr(),
pattern_to_output.at(add).get_node_shared_ptr()};
ngraph::copy_runtime_info(copy_from, {prelu, negative});
ngraph::replace_node(add_node, prelu);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(add, matcher_name);
register_matcher(m, callback);
}
ngraph::pass::PReluFusionMultiplySub::PReluFusionMultiplySub() {
MATCHER_SCOPE(PReluFusionMultiplySub);
auto input = ngraph::pattern::any_input();
auto relu_pos = ngraph::pattern::wrap_type<ngraph::opset8::Relu>({input});
auto mul_neg_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>(constant_value(-1.0));
auto mul_neg = ngraph::pattern::wrap_type<ngraph::opset8::Multiply>({input, mul_neg_constant});
auto relu_neg = ngraph::pattern::wrap_type<ngraph::opset8::Relu>({mul_neg});
auto mul_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto mul = ngraph::pattern::wrap_type<ngraph::opset8::Multiply>({relu_neg, mul_constant});
auto sub = ngraph::pattern::wrap_type<ngraph::opset8::Subtract>({relu_pos, mul});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto input_output = pattern_to_output.at(input);
auto slope_output = pattern_to_output.at(mul_constant);
auto sub_node = pattern_to_output.at(sub).get_node_shared_ptr();
auto prelu = std::make_shared<ngraph::opset8::PRelu>(input_output, slope_output);
prelu->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::NodeVector copy_from = {pattern_to_output.at(relu_pos).get_node_shared_ptr(),
pattern_to_output.at(mul_neg).get_node_shared_ptr(),
pattern_to_output.at(relu_neg).get_node_shared_ptr(),
pattern_to_output.at(mul).get_node_shared_ptr(),
pattern_to_output.at(sub).get_node_shared_ptr()};
ngraph::copy_runtime_info(copy_from, prelu);
ngraph::replace_node(sub_node, prelu);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(sub, matcher_name);
register_matcher(m, callback);
}

View File

@ -0,0 +1,137 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#define _USE_MATH_DEFINES
#include <gtest/gtest.h>
#include <math.h>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/common_optimizations/prelu_fusion.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ngraph;
TEST_F(TransformationTestsF, PReluFusionNegativeAdd) {
{
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 128});
auto relu_pos = std::make_shared<ngraph::opset8::Relu>(data);
auto neg = std::make_shared<ngraph::opset8::Negative>(data);
auto relu_neg = std::make_shared<ngraph::opset8::Relu>(neg);
auto neg2 = std::make_shared<ngraph::opset8::Negative>(relu_neg);
auto mul_const = opset8::Constant::create(element::f32, Shape{1}, {0.001});
auto mul = std::make_shared<ngraph::opset8::Multiply>(neg2, mul_const);
auto add = std::make_shared<ngraph::opset8::Add>(relu_pos, mul);
function = std::make_shared<Function>(NodeVector{add}, ParameterVector{data});
manager.register_pass<pass::PReluFusion>();
}
{
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 128});
auto prelu_const = opset8::Constant::create(element::f32, Shape{1}, {0.001});
auto prelu = std::make_shared<opset8::PRelu>(data, prelu_const);
function_ref =
std::make_shared<Function>(NodeVector{prelu}, ParameterVector{data});
}
}
TEST_F(TransformationTestsF, PReluFusionNegativeSub) {
{
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 128});
auto relu_pos = std::make_shared<ngraph::opset8::Relu>(data);
auto neg = std::make_shared<ngraph::opset8::Negative>(data);
auto relu_neg = std::make_shared<ngraph::opset8::Relu>(neg);
auto mul_const = opset8::Constant::create(element::f32, Shape{1}, {0.001});
auto mul = std::make_shared<ngraph::opset8::Multiply>(relu_neg, mul_const);
auto sub = std::make_shared<ngraph::opset8::Subtract>(relu_pos, mul);
function = std::make_shared<Function>(NodeVector{sub}, ParameterVector{data});
manager.register_pass<pass::PReluFusion>();
}
{
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 128});
auto prelu_const = opset8::Constant::create(element::f32, Shape{1}, {0.001});
auto prelu = std::make_shared<opset8::PRelu>(data, prelu_const);
function_ref =
std::make_shared<Function>(NodeVector{prelu}, ParameterVector{data});
}
}
TEST_F(TransformationTestsF, PReluFusionMultiplyAdd) {
{
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 128});
auto relu_pos = std::make_shared<ngraph::opset8::Relu>(data);
auto mul_neg_const = opset8::Constant::create(element::f32, Shape{1}, {-1.0});
auto mul_neg = std::make_shared<ngraph::opset8::Multiply>(data, mul_neg_const);
auto relu_neg = std::make_shared<ngraph::opset8::Relu>(mul_neg);
auto mul_const = opset8::Constant::create(element::f32, Shape{1}, {-0.001});
auto mul = std::make_shared<ngraph::opset8::Multiply>(relu_neg, mul_const);
auto add = std::make_shared<ngraph::opset8::Add>(relu_pos, mul);
function = std::make_shared<Function>(NodeVector{add}, ParameterVector{data});
manager.register_pass<pass::PReluFusion>();
}
{
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 128});
auto prelu_const = opset8::Constant::create(element::f32, Shape{1}, {0.001});
auto prelu = std::make_shared<opset8::PRelu>(data, prelu_const);
function_ref =
std::make_shared<Function>(NodeVector{prelu}, ParameterVector{data});
}
}
TEST_F(TransformationTestsF, PReluFusionMultiplySub) {
{
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 128});
auto relu_pos = std::make_shared<ngraph::opset8::Relu>(data);
auto mul_neg_const = opset8::Constant::create(element::f32, Shape{1}, {-1.0});
auto mul_neg = std::make_shared<ngraph::opset8::Multiply>(data, mul_neg_const);
auto relu_neg = std::make_shared<ngraph::opset8::Relu>(mul_neg);
auto mul_const = opset8::Constant::create(element::f32, Shape{1}, {0.001});
auto mul = std::make_shared<ngraph::opset8::Multiply>(relu_neg, mul_const);
auto sub = std::make_shared<ngraph::opset8::Subtract>(relu_pos, mul);
function = std::make_shared<Function>(NodeVector{sub}, ParameterVector{data});
manager.register_pass<pass::PReluFusion>();
}
{
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 128});
auto prelu_const = opset8::Constant::create(element::f32, Shape{1}, {0.001});
auto prelu = std::make_shared<opset8::PRelu>(data, prelu_const);
function_ref =
std::make_shared<Function>(NodeVector{prelu}, ParameterVector{data});
}
}
TEST_F(TransformationTestsF, PReluFusionFail) {
{
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 128});
auto relu_pos = std::make_shared<ngraph::opset8::Relu>(data);
auto mul_neg_const = opset8::Constant::create(element::f32, Shape{1}, {2.0});
auto mul_neg = std::make_shared<ngraph::opset8::Multiply>(data, mul_neg_const);
auto relu_neg = std::make_shared<ngraph::opset8::Relu>(mul_neg);
auto mul_const = opset8::Constant::create(element::f32, Shape{1}, {0.001});
auto mul = std::make_shared<ngraph::opset8::Multiply>(relu_neg, mul_const);
auto sub = std::make_shared<ngraph::opset8::Subtract>(relu_pos, mul);
function = std::make_shared<Function>(NodeVector{sub}, ParameterVector{data});
manager.register_pass<pass::PReluFusion>();
}
function_ref = ngraph::clone_function(*function);
}