Add AbsSubMul PReLu fusion (#16086)

This commit is contained in:
Tomasz Jankowski
2023-03-06 13:56:22 +01:00
committed by GitHub
parent 4ce35fd851
commit b8348cda2e
3 changed files with 95 additions and 0 deletions

View File

@@ -17,6 +17,7 @@ class TRANSFORMATIONS_API PReluFusionNegativeAdd;
class TRANSFORMATIONS_API PReluFusionNegativeSub;
class TRANSFORMATIONS_API PReluFusionMultiplyAdd;
class TRANSFORMATIONS_API PReluFusionMultiplySub;
class TRANSFORMATIONS_API PReluFusionAbsSubMulMulAdd;
} // namespace pass
} // namespace ov
@@ -99,6 +100,27 @@ public:
PReluFusionMultiplySub();
};
/**
 * @ingroup ie_transformation_common_api
 * @brief PReluFusionAbsSubMulMulAdd transformation replaces a sub-graph
 *            Op
 *          /  |  \
 *       Relu  |  Abs
 *         |   \   |
 *         |    Sub
 *         |     |
 *         |  Multiply
 *         |     |
 *         |  Multiply (0.5)
 *          \   /
 *           Add
 * with a single PRelu operation.
 * Rationale: since x - |x| == 2 * min(x, 0), the graph computes
 * Relu(x) + 0.5 * slope * (x - |x|) == Relu(x) + slope * min(x, 0) == PRelu(x, slope),
 * where slope is the constant of the first Multiply (the 0.5 of the second
 * Multiply cancels the implicit factor of 2).
 */
class ov::pass::PReluFusionAbsSubMulMulAdd : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("PReluFusionAbsSubMulMulAdd", "0");
    PReluFusionAbsSubMulMulAdd();
};
/**
* @ingroup ie_transformation_common_api
* @brief PReluFusion transformation replaces various sub-graphs with a PRelu op.
@@ -111,5 +133,6 @@ public:
add_matcher<ov::pass::PReluFusionNegativeSub>();
add_matcher<ov::pass::PReluFusionMultiplyAdd>();
add_matcher<ov::pass::PReluFusionMultiplySub>();
add_matcher<ov::pass::PReluFusionAbsSubMulMulAdd>();
}
};

View File

@@ -10,6 +10,7 @@
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <openvino/opsets/opset10.hpp>
#include <openvino/opsets/opset8.hpp>
#include "itt.hpp"
@@ -155,3 +156,48 @@ ov::pass::PReluFusionMultiplySub::PReluFusionMultiplySub() {
auto m = std::make_shared<ngraph::pattern::Matcher>(sub, matcher_name);
register_matcher(m, callback);
}
ov::pass::PReluFusionAbsSubMulMulAdd::PReluFusionAbsSubMulMulAdd() {
    MATCHER_SCOPE(PReluFusionAbsSubMulMulAdd);
    using namespace std;
    using namespace ov;
    using namespace ov::opset10;

    // Predicate accepting only a Constant that holds the single scalar value 0.5.
    const auto is_half_constant = [](const Output<Node>& output) {
        const auto constant = dynamic_pointer_cast<Constant>(output.get_node_shared_ptr());
        float value;
        return constant && op::util::get_single_value(constant, value) && value == 0.5f;
    };

    // Pattern:  Add(Relu(x), Multiply(Multiply(Subtract(x, Abs(x)), slope), 0.5))
    // Because x - |x| == 2 * min(x, 0), this equals Relu(x) + slope * min(x, 0),
    // i.e. PRelu(x, slope).
    const auto input_p = pass::pattern::any_input();
    const auto relu_p = pattern::wrap_type<Relu>({input_p});
    const auto abs_p = pattern::wrap_type<Abs>({input_p});
    const auto sub_p = pattern::wrap_type<Subtract>({input_p, abs_p});
    const auto slope_p = pattern::wrap_type<Constant>();
    const auto mul1_p = pattern::wrap_type<Multiply>({sub_p, slope_p});
    const auto half_p = pattern::wrap_type<Constant>(is_half_constant);
    const auto mul2_p = pattern::wrap_type<Multiply>({mul1_p, half_p});
    const auto add_p = pattern::wrap_type<Add>({mul2_p, relu_p});

    matcher_pass_callback callback = [=](pattern::Matcher& m) {
        const auto& match = m.get_pattern_value_map();
        const auto root = match.at(add_p).get_node_shared_ptr();
        // The 0.5 factor cancels the implicit 2 in (x - |x|), so the constant of
        // the first Multiply already is the PRelu slope.
        const auto prelu = make_shared<PRelu>(match.at(input_p), match.at(slope_p));
        prelu->set_friendly_name(m.get_match_root()->get_friendly_name());
        copy_runtime_info(as_node_vector({match.at(relu_p),
                                          match.at(abs_p),
                                          match.at(sub_p),
                                          match.at(mul1_p),
                                          match.at(mul2_p),
                                          match.at(add_p)}),
                          prelu);
        replace_node(root, prelu);
        return true;
    };
    auto matcher = make_shared<pattern::Matcher>(add_p, matcher_name);
    register_matcher(matcher, callback);
}

View File

@@ -11,6 +11,7 @@
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/manager.hpp>
#include <openvino/opsets/opset10.hpp>
#include <transformations/common_optimizations/prelu_fusion.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
@@ -131,3 +132,28 @@ TEST_F(TransformationTestsF, PReluFusionFail) {
function_ref = ngraph::clone_function(*function);
}
TEST_F(TransformationTestsF, PReluFusionAbsSubMulMulAdd) {
    using namespace std;
    using namespace ov::opset10;
    // Input graph: Relu(x) + 0.5 * 0.022 * (x - |x|), which equals PRelu(x, 0.022).
    {
        const auto param = make_shared<Parameter>(element::f32, Shape{1, 128});
        const auto abs = make_shared<Abs>(param);
        const auto diff = make_shared<Subtract>(param, abs);
        const auto slope = Constant::create(element::f32, Shape{1}, {0.022});
        const auto scaled = make_shared<Multiply>(diff, slope);
        const auto half = Constant::create(element::f32, Shape{1}, {0.5});
        const auto halved = make_shared<Multiply>(scaled, half);
        const auto relu = make_shared<Relu>(param);
        const auto sum = make_shared<Add>(relu, halved);
        function = make_shared<Function>(NodeVector{sum}, ParameterVector{param});
        manager.register_pass<ov::pass::PReluFusion>();
    }
    // Reference graph: the fused PRelu with the slope of the first Multiply.
    {
        const auto param = make_shared<Parameter>(element::f32, Shape{1, 128});
        const auto slope = Constant::create(element::f32, Shape{1}, {0.022});
        const auto prelu = make_shared<PRelu>(param, slope);
        function_ref = make_shared<Function>(NodeVector{prelu}, ParameterVector{param});
    }
}