Add PRelu fusion (#16617)

This commit is contained in:
Tomasz Jankowski 2023-03-29 13:32:57 +02:00 committed by GitHub
parent 591c3e61c5
commit f7e898893d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 83 additions and 2 deletions

View File

@ -18,6 +18,7 @@ class TRANSFORMATIONS_API PReluFusionNegativeSub;
class TRANSFORMATIONS_API PReluFusionMultiplyAdd; class TRANSFORMATIONS_API PReluFusionMultiplyAdd;
class TRANSFORMATIONS_API PReluFusionMultiplySub; class TRANSFORMATIONS_API PReluFusionMultiplySub;
class TRANSFORMATIONS_API PReluFusionAbsSubMulMulAdd; class TRANSFORMATIONS_API PReluFusionAbsSubMulMulAdd;
class TRANSFORMATIONS_API PReluFusionNegReluMulAdd;
} // namespace pass } // namespace pass
} // namespace ov } // namespace ov
@ -103,11 +104,11 @@ public:
/** /**
* @ingroup ie_transformation_common_api * @ingroup ie_transformation_common_api
* @brief PReluFusionAbsSubMulMulAdd transformation replaces a sub-graph * @brief PReluFusionAbsSubMulMulAdd transformation replaces a sub-graph
* Op * Op
* / | \ * / | \
* Relu | Abs * Relu | Abs
* | \ | * | \ |
* | Sub * | Subtract
* | | * | |
* | Multiply * | Multiply
* | | * | |
@ -121,6 +122,25 @@ public:
PReluFusionAbsSubMulMulAdd(); PReluFusionAbsSubMulMulAdd();
}; };
/**
* @ingroup ie_transformation_common_api
* @brief PReluFusionNegReluMulAdd transformation replaces a sub-graph
* Op
* / \
* Relu Negative
* | |
* | Relu
* | |
* | Multiply
* \ /
* Add
*/
class ov::pass::PReluFusionNegReluMulAdd : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("PReluFusionNegReluMulAdd", "0");
PReluFusionNegReluMulAdd();
};
/** /**
* @ingroup ie_transformation_common_api * @ingroup ie_transformation_common_api
* @brief PReluFusion transformation replaces various sub-graphs with a PRelu op. * @brief PReluFusion transformation replaces various sub-graphs with a PRelu op.
@ -134,5 +154,6 @@ public:
add_matcher<ov::pass::PReluFusionMultiplyAdd>(); add_matcher<ov::pass::PReluFusionMultiplyAdd>();
add_matcher<ov::pass::PReluFusionMultiplySub>(); add_matcher<ov::pass::PReluFusionMultiplySub>();
add_matcher<ov::pass::PReluFusionAbsSubMulMulAdd>(); add_matcher<ov::pass::PReluFusionAbsSubMulMulAdd>();
add_matcher<ov::pass::PReluFusionNegReluMulAdd>();
} }
}; };

View File

@ -201,3 +201,38 @@ ov::pass::PReluFusionAbsSubMulMulAdd::PReluFusionAbsSubMulMulAdd() {
auto m = make_shared<pattern::Matcher>(add, matcher_name); auto m = make_shared<pattern::Matcher>(add, matcher_name);
register_matcher(m, callback); register_matcher(m, callback);
} }
ov::pass::PReluFusionNegReluMulAdd::PReluFusionNegReluMulAdd() {
MATCHER_SCOPE(PReluFusionNegReluMulAdd);
using namespace std;
using namespace ov;
using namespace ov::opset10;
const auto input = pass::pattern::any_input();
const auto relu_pos = pattern::wrap_type<Relu>({input});
const auto neg1 = pattern::wrap_type<Negative>({input});
const auto relu_neg = pattern::wrap_type<Relu>({neg1});
const auto mul_constant = pattern::wrap_type<Constant>();
const auto mul = pattern::wrap_type<Multiply>({relu_neg, mul_constant});
const auto add = pattern::wrap_type<Add>({relu_pos, mul});
matcher_pass_callback callback = [=](pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
const auto input_output = pattern_to_output.at(input);
const auto add_node = pattern_to_output.at(add).get_node_shared_ptr();
const auto slope = op::util::make_try_fold<Negative>(pattern_to_output.at(mul_constant));
const auto prelu = make_shared<PRelu>(input_output, slope);
prelu->set_friendly_name(m.get_match_root()->get_friendly_name());
NodeVector copy_from = {pattern_to_output.at(relu_pos).get_node_shared_ptr(),
pattern_to_output.at(neg1).get_node_shared_ptr(),
pattern_to_output.at(relu_neg).get_node_shared_ptr(),
pattern_to_output.at(mul).get_node_shared_ptr(),
pattern_to_output.at(add).get_node_shared_ptr()};
copy_runtime_info(copy_from, prelu);
replace_node(add_node, prelu);
return true;
};
auto matcher = make_shared<pattern::Matcher>(add, matcher_name);
register_matcher(matcher, callback);
}

View File

@ -156,4 +156,29 @@ TEST_F(TransformationTestsF, PReluFusionAbsSubMulMulAdd) {
const auto prelu = make_shared<PRelu>(data, prelu_const); const auto prelu = make_shared<PRelu>(data, prelu_const);
function_ref = make_shared<Function>(NodeVector{prelu}, ParameterVector{data}); function_ref = make_shared<Function>(NodeVector{prelu}, ParameterVector{data});
} }
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, PReluFusionNegReluMulAdd) {
using namespace std;
using namespace ov::opset10;
{
const auto data = make_shared<Parameter>(element::f32, Shape{2, 12});
const auto relu_pos = make_shared<Relu>(data);
const auto neg = make_shared<Negative>(data);
const auto relu_neg = make_shared<Relu>(neg);
const auto mul_const = Constant::create(element::f32, Shape{1}, {0.235});
const auto mul = make_shared<Multiply>(relu_neg, mul_const);
const auto add = make_shared<Add>(relu_pos, mul);
function = make_shared<Function>(NodeVector{add}, ParameterVector{data});
manager.register_pass<ov::pass::PReluFusion>();
}
{
const auto data = make_shared<Parameter>(element::f32, Shape{2, 12});
const auto prelu_const = Constant::create(element::f32, Shape{1}, {-0.235});
const auto prelu = make_shared<PRelu>(data, prelu_const);
function_ref = make_shared<Function>(NodeVector{prelu}, ParameterVector{data});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
} }