Add PRelu fusion (#16617)
This commit is contained in:
parent
591c3e61c5
commit
f7e898893d
@ -18,6 +18,7 @@ class TRANSFORMATIONS_API PReluFusionNegativeSub;
|
||||
class TRANSFORMATIONS_API PReluFusionMultiplyAdd;
|
||||
class TRANSFORMATIONS_API PReluFusionMultiplySub;
|
||||
class TRANSFORMATIONS_API PReluFusionAbsSubMulMulAdd;
|
||||
class TRANSFORMATIONS_API PReluFusionNegReluMulAdd;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
@ -103,11 +104,11 @@ public:
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief PReluFusionAbsSubMulMulAdd transformation replaces a sub-graph
|
||||
* Op
|
||||
* Op
|
||||
* / | \
|
||||
* Relu | Abs
|
||||
* | \ |
|
||||
* | Sub
|
||||
* | Subtract
|
||||
* | |
|
||||
* | Multiply
|
||||
* | |
|
||||
@ -121,6 +122,25 @@ public:
|
||||
PReluFusionAbsSubMulMulAdd();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief PReluFusionNegReluMulAdd transformation replaces a sub-graph
|
||||
* Op
|
||||
* / \
|
||||
* Relu Negative
|
||||
* | |
|
||||
* | Relu
|
||||
* | |
|
||||
* | Multiply
|
||||
* \ /
|
||||
* Add
|
||||
*/
|
||||
class ov::pass::PReluFusionNegReluMulAdd : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("PReluFusionNegReluMulAdd", "0");
|
||||
PReluFusionNegReluMulAdd();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief PReluFusion transformation replaces various sub-graphs with a PRelu op.
|
||||
@ -134,5 +154,6 @@ public:
|
||||
add_matcher<ov::pass::PReluFusionMultiplyAdd>();
|
||||
add_matcher<ov::pass::PReluFusionMultiplySub>();
|
||||
add_matcher<ov::pass::PReluFusionAbsSubMulMulAdd>();
|
||||
add_matcher<ov::pass::PReluFusionNegReluMulAdd>();
|
||||
}
|
||||
};
|
||||
|
@ -201,3 +201,38 @@ ov::pass::PReluFusionAbsSubMulMulAdd::PReluFusionAbsSubMulMulAdd() {
|
||||
auto m = make_shared<pattern::Matcher>(add, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
ov::pass::PReluFusionNegReluMulAdd::PReluFusionNegReluMulAdd() {
|
||||
MATCHER_SCOPE(PReluFusionNegReluMulAdd);
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
|
||||
const auto input = pass::pattern::any_input();
|
||||
const auto relu_pos = pattern::wrap_type<Relu>({input});
|
||||
const auto neg1 = pattern::wrap_type<Negative>({input});
|
||||
const auto relu_neg = pattern::wrap_type<Relu>({neg1});
|
||||
const auto mul_constant = pattern::wrap_type<Constant>();
|
||||
const auto mul = pattern::wrap_type<Multiply>({relu_neg, mul_constant});
|
||||
const auto add = pattern::wrap_type<Add>({relu_pos, mul});
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
const auto input_output = pattern_to_output.at(input);
|
||||
const auto add_node = pattern_to_output.at(add).get_node_shared_ptr();
|
||||
const auto slope = op::util::make_try_fold<Negative>(pattern_to_output.at(mul_constant));
|
||||
const auto prelu = make_shared<PRelu>(input_output, slope);
|
||||
prelu->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
NodeVector copy_from = {pattern_to_output.at(relu_pos).get_node_shared_ptr(),
|
||||
pattern_to_output.at(neg1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(relu_neg).get_node_shared_ptr(),
|
||||
pattern_to_output.at(mul).get_node_shared_ptr(),
|
||||
pattern_to_output.at(add).get_node_shared_ptr()};
|
||||
copy_runtime_info(copy_from, prelu);
|
||||
replace_node(add_node, prelu);
|
||||
return true;
|
||||
};
|
||||
auto matcher = make_shared<pattern::Matcher>(add, matcher_name);
|
||||
register_matcher(matcher, callback);
|
||||
}
|
||||
|
@ -156,4 +156,29 @@ TEST_F(TransformationTestsF, PReluFusionAbsSubMulMulAdd) {
|
||||
const auto prelu = make_shared<PRelu>(data, prelu_const);
|
||||
function_ref = make_shared<Function>(NodeVector{prelu}, ParameterVector{data});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, PReluFusionNegReluMulAdd) {
|
||||
using namespace std;
|
||||
using namespace ov::opset10;
|
||||
{
|
||||
const auto data = make_shared<Parameter>(element::f32, Shape{2, 12});
|
||||
const auto relu_pos = make_shared<Relu>(data);
|
||||
const auto neg = make_shared<Negative>(data);
|
||||
const auto relu_neg = make_shared<Relu>(neg);
|
||||
const auto mul_const = Constant::create(element::f32, Shape{1}, {0.235});
|
||||
const auto mul = make_shared<Multiply>(relu_neg, mul_const);
|
||||
const auto add = make_shared<Add>(relu_pos, mul);
|
||||
function = make_shared<Function>(NodeVector{add}, ParameterVector{data});
|
||||
|
||||
manager.register_pass<ov::pass::PReluFusion>();
|
||||
}
|
||||
{
|
||||
const auto data = make_shared<Parameter>(element::f32, Shape{2, 12});
|
||||
const auto prelu_const = Constant::create(element::f32, Shape{1}, {-0.235});
|
||||
const auto prelu = make_shared<PRelu>(data, prelu_const);
|
||||
function_ref = make_shared<Function>(NodeVector{prelu}, ParameterVector{data});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user