Add AbsSubMul PReLu fusion (#16086)
This commit is contained in:
@@ -17,6 +17,7 @@ class TRANSFORMATIONS_API PReluFusionNegativeAdd;
|
||||
class TRANSFORMATIONS_API PReluFusionNegativeSub;
|
||||
class TRANSFORMATIONS_API PReluFusionMultiplyAdd;
|
||||
class TRANSFORMATIONS_API PReluFusionMultiplySub;
|
||||
class TRANSFORMATIONS_API PReluFusionAbsSubMulMulAdd;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
@@ -99,6 +100,27 @@ public:
|
||||
PReluFusionMultiplySub();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief PReluFusionAbsSubMulMulAdd transformation replaces a sub-graph
|
||||
* Op
|
||||
* / | \
|
||||
* Relu | Abs
|
||||
* | \ |
|
||||
* | Sub
|
||||
* | |
|
||||
* | Multiply
|
||||
* | |
|
||||
* | Multiply (0.5)
|
||||
* \ /
|
||||
* Add
|
||||
*/
|
||||
class ov::pass::PReluFusionAbsSubMulMulAdd : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("PReluFusionAbsSubMulMulAdd", "0");
|
||||
PReluFusionAbsSubMulMulAdd();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief PReluFusion transformation replaces various sub-graphs with a PRelu op.
|
||||
@@ -111,5 +133,6 @@ public:
|
||||
add_matcher<ov::pass::PReluFusionNegativeSub>();
|
||||
add_matcher<ov::pass::PReluFusionMultiplyAdd>();
|
||||
add_matcher<ov::pass::PReluFusionMultiplySub>();
|
||||
add_matcher<ov::pass::PReluFusionAbsSubMulMulAdd>();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <ngraph/pattern/op/or.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <openvino/opsets/opset8.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
@@ -155,3 +156,48 @@ ov::pass::PReluFusionMultiplySub::PReluFusionMultiplySub() {
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(sub, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
ov::pass::PReluFusionAbsSubMulMulAdd::PReluFusionAbsSubMulMulAdd() {
|
||||
MATCHER_SCOPE(PReluFusionAbsSubMulMulAdd);
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
|
||||
const auto equals_half = [](const Output<Node>& node) {
|
||||
float v;
|
||||
const auto constant = dynamic_pointer_cast<Constant>(node.get_node_shared_ptr());
|
||||
return constant && op::util::get_single_value(constant, v) && v == 0.5f;
|
||||
};
|
||||
|
||||
const auto input = pass::pattern::any_input();
|
||||
const auto relu = pattern::wrap_type<Relu>({input});
|
||||
const auto abs = pattern::wrap_type<Abs>({input});
|
||||
const auto sub = pattern::wrap_type<Subtract>({input, abs});
|
||||
const auto mul_1_constant = pattern::wrap_type<Constant>();
|
||||
const auto mul_1 = pattern::wrap_type<Multiply>({sub, mul_1_constant});
|
||||
const auto mul_2_constant = pattern::wrap_type<Constant>(equals_half);
|
||||
const auto mul_2 = pattern::wrap_type<Multiply>({mul_1, mul_2_constant});
|
||||
const auto add = pattern::wrap_type<Add>({mul_2, relu});
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
const auto input_output = pattern_to_output.at(input);
|
||||
const auto add_node = pattern_to_output.at(add).get_node_shared_ptr();
|
||||
const auto slope = pattern_to_output.at(mul_1_constant);
|
||||
const auto prelu = make_shared<PRelu>(input_output, slope);
|
||||
|
||||
prelu->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
const OutputVector copy_from = {pattern_to_output.at(relu),
|
||||
pattern_to_output.at(abs),
|
||||
pattern_to_output.at(sub),
|
||||
pattern_to_output.at(mul_1),
|
||||
pattern_to_output.at(mul_2),
|
||||
pattern_to_output.at(add)};
|
||||
copy_runtime_info(as_node_vector(copy_from), prelu);
|
||||
replace_node(add_node, prelu);
|
||||
return true;
|
||||
};
|
||||
auto m = make_shared<pattern::Matcher>(add, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <transformations/common_optimizations/prelu_fusion.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
@@ -131,3 +132,28 @@ TEST_F(TransformationTestsF, PReluFusionFail) {
|
||||
|
||||
function_ref = ngraph::clone_function(*function);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, PReluFusionAbsSubMulMulAdd) {
|
||||
using namespace std;
|
||||
using namespace ov::opset10;
|
||||
{
|
||||
const auto data = make_shared<Parameter>(element::f32, Shape{1, 128});
|
||||
const auto relu = make_shared<Relu>(data);
|
||||
const auto abs = make_shared<Abs>(data);
|
||||
const auto sub = make_shared<Subtract>(data, abs);
|
||||
const auto mul_1_const = Constant::create(element::f32, Shape{1}, {0.022});
|
||||
const auto mul_1 = make_shared<Multiply>(sub, mul_1_const);
|
||||
const auto mul_2_const = Constant::create(element::f32, Shape{1}, {0.5});
|
||||
const auto mul_2 = make_shared<Multiply>(mul_1, mul_2_const);
|
||||
const auto add = make_shared<Add>(relu, mul_2);
|
||||
function = make_shared<Function>(NodeVector{add}, ParameterVector{data});
|
||||
|
||||
manager.register_pass<ov::pass::PReluFusion>();
|
||||
}
|
||||
{
|
||||
const auto data = make_shared<Parameter>(element::f32, Shape{1, 128});
|
||||
const auto prelu_const = Constant::create(element::f32, Shape{1}, {0.022});
|
||||
const auto prelu = make_shared<PRelu>(data, prelu_const);
|
||||
function_ref = make_shared<Function>(NodeVector{prelu}, ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user