[nGraph] [MO] Add ngraph transformations for PRelu fusing (#10209)
* MO: add PRelu fusing pattern with Sub * PRelu fusing * Apply review comments * Code style
This commit is contained in:
parent
93722fe101
commit
3c721b0f03
@ -0,0 +1,115 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
#include <transformations_visibility.hpp>
|
||||
#include <utility>
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API PReluFusion;
|
||||
class TRANSFORMATIONS_API PReluFusionNegativeAdd;
|
||||
class TRANSFORMATIONS_API PReluFusionNegativeSub;
|
||||
class TRANSFORMATIONS_API PReluFusionMultiplyAdd;
|
||||
class TRANSFORMATIONS_API PReluFusionMultiplySub;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief PReluFusionNegativeAdd transformation replaces a sub-graph
|
||||
* Op
|
||||
* / \
|
||||
* Relu Negative
|
||||
* | |
|
||||
* | Relu
|
||||
* | |
|
||||
* | Negative
|
||||
* | |
|
||||
* | Multiply
|
||||
* \ /
|
||||
* Add
|
||||
*/
|
||||
class ngraph::pass::PReluFusionNegativeAdd : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
PReluFusionNegativeAdd();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief PReluFusionNegativeSub transformation replaces a sub-graph
|
||||
* Op
|
||||
* / \
|
||||
* Relu Negative
|
||||
* | |
|
||||
* | Relu
|
||||
* | |
|
||||
* | Multiply
|
||||
* \ /
|
||||
* Sub
|
||||
*/
|
||||
class ngraph::pass::PReluFusionNegativeSub : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
PReluFusionNegativeSub();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief PReluFusionMultiplyAdd transformation replaces a sub-graph
|
||||
* Op
|
||||
* / \
|
||||
* Relu Multiply (-1)
|
||||
* | |
|
||||
* | Relu
|
||||
* | |
|
||||
* | Multiply
|
||||
* \ /
|
||||
* Add
|
||||
*/
|
||||
class ngraph::pass::PReluFusionMultiplyAdd : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
PReluFusionMultiplyAdd();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief PReluFusionMultiplySub transformation replaces a sub-graph
|
||||
* Op
|
||||
* / \
|
||||
* Relu Multiply (-1)
|
||||
* | |
|
||||
* | Relu
|
||||
* | |
|
||||
* | Multiply
|
||||
* \ /
|
||||
* Sub
|
||||
*/
|
||||
class ngraph::pass::PReluFusionMultiplySub : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
PReluFusionMultiplySub();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief PReluFusion transformation replaces various sub-graphs with a PRelu op.
|
||||
*/
|
||||
class ngraph::pass::PReluFusion : public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
PReluFusion() {
|
||||
add_matcher<ngraph::pass::PReluFusionNegativeAdd>();
|
||||
add_matcher<ngraph::pass::PReluFusionNegativeSub>();
|
||||
add_matcher<ngraph::pass::PReluFusionMultiplyAdd>();
|
||||
add_matcher<ngraph::pass::PReluFusionMultiplySub>();
|
||||
}
|
||||
};
|
@ -35,6 +35,7 @@
|
||||
#include <transformations/common_optimizations/normalize_l2_fusion.hpp>
|
||||
#include <transformations/common_optimizations/optimize_strided_slice.hpp>
|
||||
#include <transformations/common_optimizations/pad_fusion.hpp>
|
||||
#include <transformations/common_optimizations/prelu_fusion.hpp>
|
||||
#include <transformations/common_optimizations/random_uniform_fusion.hpp>
|
||||
#include <transformations/common_optimizations/remove_concat_zero_dim_input.hpp>
|
||||
#include <transformations/common_optimizations/remove_filtering_boxes_by_size.hpp>
|
||||
@ -153,6 +154,7 @@ bool ngraph::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph
|
||||
common_fusions->add_matcher<ngraph::pass::TransposeToReshape>();
|
||||
common_fusions->add_matcher<ngraph::pass::ReshapeSequenceFusion>(m_use_shapes);
|
||||
common_fusions->add_matcher<ngraph::pass::MatMulConstTransposesExtraction>();
|
||||
common_fusions->add_matcher<ngraph::pass::PReluFusion>();
|
||||
common_fusions->set_name("ngraph::pass::CommonFusions");
|
||||
|
||||
manager.register_pass<ngraph::pass::BinarizeWeights>();
|
||||
|
@ -0,0 +1,163 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#define _USE_MATH_DEFINES
|
||||
|
||||
#include "transformations/common_optimizations/prelu_fusion.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/pattern/op/or.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::PReluFusion, "PReluFusion", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::PReluFusionNegativeAdd, "PReluFusionNegativeAdd", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::PReluFusionNegativeSub, "PReluFusionNegativeSub", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::PReluFusionMultiplyAdd, "PReluFusionMultiplyAdd", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::PReluFusionMultiplySub, "PReluFusionMultiplySub", 0);
|
||||
|
||||
ngraph::pass::PReluFusionNegativeAdd::PReluFusionNegativeAdd() {
|
||||
MATCHER_SCOPE(PReluFusionNegativeAdd);
|
||||
auto input = ngraph::pattern::any_input();
|
||||
auto relu_pos = ngraph::pattern::wrap_type<ngraph::opset8::Relu>({input});
|
||||
auto neg1 = ngraph::pattern::wrap_type<ngraph::opset8::Negative>({input});
|
||||
auto relu_neg = ngraph::pattern::wrap_type<ngraph::opset8::Relu>({neg1});
|
||||
auto neg2 = ngraph::pattern::wrap_type<ngraph::opset8::Negative>({relu_neg});
|
||||
auto mul_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||
auto mul = ngraph::pattern::wrap_type<ngraph::opset8::Multiply>({neg2, mul_constant});
|
||||
auto add = ngraph::pattern::wrap_type<ngraph::opset8::Add>({relu_pos, mul});
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto input_output = pattern_to_output.at(input);
|
||||
auto slope_output = pattern_to_output.at(mul_constant);
|
||||
auto add_node = pattern_to_output.at(add).get_node_shared_ptr();
|
||||
auto prelu = std::make_shared<ngraph::opset8::PRelu>(input_output, slope_output);
|
||||
prelu->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::NodeVector copy_from = {pattern_to_output.at(relu_pos).get_node_shared_ptr(),
|
||||
pattern_to_output.at(neg1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(relu_neg).get_node_shared_ptr(),
|
||||
pattern_to_output.at(neg2).get_node_shared_ptr(),
|
||||
pattern_to_output.at(mul).get_node_shared_ptr(),
|
||||
pattern_to_output.at(add).get_node_shared_ptr()};
|
||||
ngraph::copy_runtime_info(copy_from, prelu);
|
||||
ngraph::replace_node(add_node, prelu);
|
||||
return true;
|
||||
};
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(add, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
ngraph::pass::PReluFusionNegativeSub::PReluFusionNegativeSub() {
|
||||
MATCHER_SCOPE(PReluFusionNegativeSub);
|
||||
auto input = ngraph::pattern::any_input();
|
||||
auto relu_pos = ngraph::pattern::wrap_type<ngraph::opset8::Relu>({input});
|
||||
auto neg1 = ngraph::pattern::wrap_type<ngraph::opset8::Negative>({input});
|
||||
auto relu_neg = ngraph::pattern::wrap_type<ngraph::opset8::Relu>({neg1});
|
||||
auto mul_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||
auto mul = ngraph::pattern::wrap_type<ngraph::opset8::Multiply>({relu_neg, mul_constant});
|
||||
auto sub = ngraph::pattern::wrap_type<ngraph::opset8::Subtract>({relu_pos, mul});
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto input_output = pattern_to_output.at(input);
|
||||
auto slope_output = pattern_to_output.at(mul_constant);
|
||||
auto sub_node = pattern_to_output.at(sub).get_node_shared_ptr();
|
||||
auto prelu = std::make_shared<ngraph::opset8::PRelu>(input_output, slope_output);
|
||||
prelu->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::NodeVector copy_from = {pattern_to_output.at(relu_pos).get_node_shared_ptr(),
|
||||
pattern_to_output.at(neg1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(relu_neg).get_node_shared_ptr(),
|
||||
pattern_to_output.at(mul).get_node_shared_ptr(),
|
||||
pattern_to_output.at(sub).get_node_shared_ptr()};
|
||||
ngraph::copy_runtime_info(copy_from, prelu);
|
||||
ngraph::replace_node(sub_node, prelu);
|
||||
return true;
|
||||
};
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(sub, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
static std::function<bool(ngraph::Output<ngraph::Node>)> constant_value(const float target_value) {
|
||||
return [=](const ngraph::Output<ngraph::Node>& output) -> bool {
|
||||
auto node = std::dynamic_pointer_cast<ngraph::opset8::Constant>(output.get_node_shared_ptr());
|
||||
if (!node) {
|
||||
return false;
|
||||
}
|
||||
float value;
|
||||
if (!ngraph::op::util::get_single_value(node, value)) {
|
||||
return false;
|
||||
}
|
||||
return value == target_value;
|
||||
};
|
||||
}
|
||||
|
||||
ngraph::pass::PReluFusionMultiplyAdd::PReluFusionMultiplyAdd() {
|
||||
MATCHER_SCOPE(PReluFusionMultiplyAdd);
|
||||
auto input = ngraph::pattern::any_input();
|
||||
auto relu_pos = ngraph::pattern::wrap_type<ngraph::opset8::Relu>({input});
|
||||
auto mul_neg_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>(constant_value(-1.0));
|
||||
auto mul_neg = ngraph::pattern::wrap_type<ngraph::opset8::Multiply>({input, mul_neg_constant});
|
||||
auto relu_neg = ngraph::pattern::wrap_type<ngraph::opset8::Relu>({mul_neg});
|
||||
auto mul_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||
auto mul = ngraph::pattern::wrap_type<ngraph::opset8::Multiply>({relu_neg, mul_constant});
|
||||
auto add = ngraph::pattern::wrap_type<ngraph::opset8::Add>({relu_pos, mul});
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto input_output = pattern_to_output.at(input);
|
||||
auto slope_output = pattern_to_output.at(mul_constant);
|
||||
auto add_node = pattern_to_output.at(add).get_node_shared_ptr();
|
||||
auto negative = ngraph::op::util::make_try_fold<ngraph::opset8::Negative>(slope_output);
|
||||
auto prelu = std::make_shared<ngraph::opset8::PRelu>(input_output, negative);
|
||||
|
||||
prelu->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::NodeVector copy_from = {pattern_to_output.at(relu_pos).get_node_shared_ptr(),
|
||||
pattern_to_output.at(mul_neg).get_node_shared_ptr(),
|
||||
pattern_to_output.at(relu_neg).get_node_shared_ptr(),
|
||||
pattern_to_output.at(mul).get_node_shared_ptr(),
|
||||
pattern_to_output.at(add).get_node_shared_ptr()};
|
||||
ngraph::copy_runtime_info(copy_from, {prelu, negative});
|
||||
ngraph::replace_node(add_node, prelu);
|
||||
return true;
|
||||
};
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(add, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
ngraph::pass::PReluFusionMultiplySub::PReluFusionMultiplySub() {
|
||||
MATCHER_SCOPE(PReluFusionMultiplySub);
|
||||
auto input = ngraph::pattern::any_input();
|
||||
auto relu_pos = ngraph::pattern::wrap_type<ngraph::opset8::Relu>({input});
|
||||
auto mul_neg_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>(constant_value(-1.0));
|
||||
auto mul_neg = ngraph::pattern::wrap_type<ngraph::opset8::Multiply>({input, mul_neg_constant});
|
||||
auto relu_neg = ngraph::pattern::wrap_type<ngraph::opset8::Relu>({mul_neg});
|
||||
auto mul_constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||
auto mul = ngraph::pattern::wrap_type<ngraph::opset8::Multiply>({relu_neg, mul_constant});
|
||||
auto sub = ngraph::pattern::wrap_type<ngraph::opset8::Subtract>({relu_pos, mul});
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto input_output = pattern_to_output.at(input);
|
||||
auto slope_output = pattern_to_output.at(mul_constant);
|
||||
auto sub_node = pattern_to_output.at(sub).get_node_shared_ptr();
|
||||
auto prelu = std::make_shared<ngraph::opset8::PRelu>(input_output, slope_output);
|
||||
|
||||
prelu->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::NodeVector copy_from = {pattern_to_output.at(relu_pos).get_node_shared_ptr(),
|
||||
pattern_to_output.at(mul_neg).get_node_shared_ptr(),
|
||||
pattern_to_output.at(relu_neg).get_node_shared_ptr(),
|
||||
pattern_to_output.at(mul).get_node_shared_ptr(),
|
||||
pattern_to_output.at(sub).get_node_shared_ptr()};
|
||||
ngraph::copy_runtime_info(copy_from, prelu);
|
||||
ngraph::replace_node(sub_node, prelu);
|
||||
return true;
|
||||
};
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(sub, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
@ -0,0 +1,137 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#define _USE_MATH_DEFINES
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <math.h>
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/common_optimizations/prelu_fusion.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph;
|
||||
|
||||
TEST_F(TransformationTestsF, PReluFusionNegativeAdd) {
|
||||
{
|
||||
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 128});
|
||||
auto relu_pos = std::make_shared<ngraph::opset8::Relu>(data);
|
||||
auto neg = std::make_shared<ngraph::opset8::Negative>(data);
|
||||
auto relu_neg = std::make_shared<ngraph::opset8::Relu>(neg);
|
||||
auto neg2 = std::make_shared<ngraph::opset8::Negative>(relu_neg);
|
||||
auto mul_const = opset8::Constant::create(element::f32, Shape{1}, {0.001});
|
||||
auto mul = std::make_shared<ngraph::opset8::Multiply>(neg2, mul_const);
|
||||
auto add = std::make_shared<ngraph::opset8::Add>(relu_pos, mul);
|
||||
|
||||
function = std::make_shared<Function>(NodeVector{add}, ParameterVector{data});
|
||||
|
||||
manager.register_pass<pass::PReluFusion>();
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 128});
|
||||
auto prelu_const = opset8::Constant::create(element::f32, Shape{1}, {0.001});
|
||||
auto prelu = std::make_shared<opset8::PRelu>(data, prelu_const);
|
||||
function_ref =
|
||||
std::make_shared<Function>(NodeVector{prelu}, ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, PReluFusionNegativeSub) {
|
||||
{
|
||||
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 128});
|
||||
auto relu_pos = std::make_shared<ngraph::opset8::Relu>(data);
|
||||
auto neg = std::make_shared<ngraph::opset8::Negative>(data);
|
||||
auto relu_neg = std::make_shared<ngraph::opset8::Relu>(neg);
|
||||
auto mul_const = opset8::Constant::create(element::f32, Shape{1}, {0.001});
|
||||
auto mul = std::make_shared<ngraph::opset8::Multiply>(relu_neg, mul_const);
|
||||
auto sub = std::make_shared<ngraph::opset8::Subtract>(relu_pos, mul);
|
||||
|
||||
function = std::make_shared<Function>(NodeVector{sub}, ParameterVector{data});
|
||||
|
||||
manager.register_pass<pass::PReluFusion>();
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 128});
|
||||
auto prelu_const = opset8::Constant::create(element::f32, Shape{1}, {0.001});
|
||||
auto prelu = std::make_shared<opset8::PRelu>(data, prelu_const);
|
||||
function_ref =
|
||||
std::make_shared<Function>(NodeVector{prelu}, ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, PReluFusionMultiplyAdd) {
|
||||
{
|
||||
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 128});
|
||||
auto relu_pos = std::make_shared<ngraph::opset8::Relu>(data);
|
||||
auto mul_neg_const = opset8::Constant::create(element::f32, Shape{1}, {-1.0});
|
||||
auto mul_neg = std::make_shared<ngraph::opset8::Multiply>(data, mul_neg_const);
|
||||
auto relu_neg = std::make_shared<ngraph::opset8::Relu>(mul_neg);
|
||||
auto mul_const = opset8::Constant::create(element::f32, Shape{1}, {-0.001});
|
||||
auto mul = std::make_shared<ngraph::opset8::Multiply>(relu_neg, mul_const);
|
||||
auto add = std::make_shared<ngraph::opset8::Add>(relu_pos, mul);
|
||||
|
||||
function = std::make_shared<Function>(NodeVector{add}, ParameterVector{data});
|
||||
|
||||
manager.register_pass<pass::PReluFusion>();
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 128});
|
||||
auto prelu_const = opset8::Constant::create(element::f32, Shape{1}, {0.001});
|
||||
auto prelu = std::make_shared<opset8::PRelu>(data, prelu_const);
|
||||
function_ref =
|
||||
std::make_shared<Function>(NodeVector{prelu}, ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, PReluFusionMultiplySub) {
|
||||
{
|
||||
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 128});
|
||||
auto relu_pos = std::make_shared<ngraph::opset8::Relu>(data);
|
||||
auto mul_neg_const = opset8::Constant::create(element::f32, Shape{1}, {-1.0});
|
||||
auto mul_neg = std::make_shared<ngraph::opset8::Multiply>(data, mul_neg_const);
|
||||
auto relu_neg = std::make_shared<ngraph::opset8::Relu>(mul_neg);
|
||||
auto mul_const = opset8::Constant::create(element::f32, Shape{1}, {0.001});
|
||||
auto mul = std::make_shared<ngraph::opset8::Multiply>(relu_neg, mul_const);
|
||||
auto sub = std::make_shared<ngraph::opset8::Subtract>(relu_pos, mul);
|
||||
|
||||
function = std::make_shared<Function>(NodeVector{sub}, ParameterVector{data});
|
||||
|
||||
manager.register_pass<pass::PReluFusion>();
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 128});
|
||||
auto prelu_const = opset8::Constant::create(element::f32, Shape{1}, {0.001});
|
||||
auto prelu = std::make_shared<opset8::PRelu>(data, prelu_const);
|
||||
function_ref =
|
||||
std::make_shared<Function>(NodeVector{prelu}, ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, PReluFusionFail) {
|
||||
{
|
||||
auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 128});
|
||||
auto relu_pos = std::make_shared<ngraph::opset8::Relu>(data);
|
||||
auto mul_neg_const = opset8::Constant::create(element::f32, Shape{1}, {2.0});
|
||||
auto mul_neg = std::make_shared<ngraph::opset8::Multiply>(data, mul_neg_const);
|
||||
auto relu_neg = std::make_shared<ngraph::opset8::Relu>(mul_neg);
|
||||
auto mul_const = opset8::Constant::create(element::f32, Shape{1}, {0.001});
|
||||
auto mul = std::make_shared<ngraph::opset8::Multiply>(relu_neg, mul_const);
|
||||
auto sub = std::make_shared<ngraph::opset8::Subtract>(relu_pos, mul);
|
||||
|
||||
function = std::make_shared<Function>(NodeVector{sub}, ParameterVector{data});
|
||||
|
||||
manager.register_pass<pass::PReluFusion>();
|
||||
}
|
||||
|
||||
function_ref = ngraph::clone_function(*function);
|
||||
}
|
Loading…
Reference in New Issue
Block a user