Added MVN fusion for case with constants inside (#4961)
This commit is contained in:
parent
61e9d020d4
commit
621e36ee79
@ -16,7 +16,9 @@
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API MVNFusion;
|
||||
class TRANSFORMATIONS_API MVNFusion;
|
||||
class TRANSFORMATIONS_API MVNFusionWithoutConstants;
|
||||
class TRANSFORMATIONS_API MVNFusionWithConstantsInside;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
@ -26,8 +28,32 @@ namespace pass {
|
||||
* @brief MVNFusion transformation replaces group of
|
||||
* operations: (x - ReduceMean(x, axes)) / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps) to MVN op.
|
||||
*/
|
||||
class ngraph::pass::MVNFusion : public ngraph::pass::MatcherPass {
|
||||
class ngraph::pass::MVNFusionWithoutConstants : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
MVNFusion();
|
||||
MVNFusionWithoutConstants();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
 * @brief MVNFusionWithConstantsInside transformation replaces group of
|
||||
 * operations: gamma * (x - ReduceMean(x, axes)) / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps) + beta to MVN op followed by Multiply (gamma) and Add (beta).
|
||||
*/
|
||||
class ngraph::pass::MVNFusionWithConstantsInside : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
// Defined out of line: builds the pattern graph and registers the matcher callback.
MVNFusionWithConstantsInside();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief MVNFusion transformation replaces various sub-graphs with a MVN op.
|
||||
*/
|
||||
class ngraph::pass::MVNFusion: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
// Aggregates the individual MVN fusion matchers, so registering MVNFusion
// alone enables both patterns added below.
MVNFusion() {
|
||||
// Plain decomposition without scale/shift constants inside the sub-graph.
add_matcher<ngraph::pass::MVNFusionWithoutConstants>();
|
||||
// Decomposition with the gamma/beta constants folded inside the sub-graph.
add_matcher<ngraph::pass::MVNFusionWithConstantsInside>();
|
||||
}
|
||||
};
|
||||
|
@ -27,8 +27,10 @@ std::function<bool(ngraph::Output<ngraph::Node>)> value_is_equal_to(const std::v
|
||||
};
|
||||
}
|
||||
|
||||
ngraph::pass::MVNFusion::MVNFusion() {
|
||||
MATCHER_SCOPE(MVNFusion);
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::MVNFusionWithoutConstants, "MVNFusionWithoutConstants", 0);
|
||||
|
||||
ngraph::pass::MVNFusionWithoutConstants::MVNFusionWithoutConstants() {
|
||||
MATCHER_SCOPE(MVNFusionWithoutConstants);
|
||||
// Detect MVN decomposition pattern:
|
||||
// (x - ReduceMean(x, axes)) / (Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2)) + eps)
|
||||
auto x = pattern::any_input();
|
||||
@ -188,3 +190,113 @@ ngraph::pass::MVNFusion::MVNFusion() {
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(powerMulOrDiv, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::MVNFusionWithConstantsInside, "MVNFusionWithConstantsInside", 0);
|
||||
|
||||
ngraph::pass::MVNFusionWithConstantsInside::MVNFusionWithConstantsInside() {
    MATCHER_SCOPE(MVNFusionWithConstantsInside);
    // Detect the MVN decomposition in which gamma and beta are folded inside the sub-graph:
    //   x * (gamma * P) + (beta - ReduceMean(x, axes) * (gamma * P)),
    //   where P = (ReduceMean((x - ReduceMean(x, axes)) ^ 2, axes) + eps) ^ -0.5
    // which is algebraically equal to
    //   gamma * (x - ReduceMean(x, axes)) / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps) + beta
    // and replace it with MVN(x, axes) * gamma + beta.
    auto x = pattern::any_input();

    // (x - ReduceMean(x, axes))^2
    //     `------mean1-------'
    auto mean1_axes = pattern::wrap_type<opset6::Constant>();
    auto mean1 = pattern::wrap_type<opset6::ReduceMean>({ x, mean1_axes });

    // (x - ReduceMean(x, axes))^2
    // `-squared_difference------'
    auto squared_difference = pattern::wrap_type<opset6::SquaredDifference>({ x, mean1 });

    // 1 / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
    //          `---mean2--------------------------------'
    auto mean2_axes = pattern::wrap_type<opset6::Constant>();
    auto mean2 = pattern::wrap_type<opset6::ReduceMean>({ squared_difference, mean2_axes });

    // 1 / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
    //          `------------------------------------------add--'
    auto eps = pattern::wrap_type<opset6::Constant>();
    auto add_eps = pattern::wrap_type<opset6::Add>({ mean2, eps });

    // 1 / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
    // `-power-------------------------------------------------'
    // The reciprocal square root is expressed as Power(..., -0.5).
    auto const_0_5 = pattern::wrap_type<opset6::Constant>(value_is_equal_to<float>({-0.5}));
    auto power = pattern::wrap_type<opset6::Power>({ add_eps, const_0_5 });

    // gamma / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
    // `---mul1----------------------------------------------------'
    auto gamma = pattern::wrap_type<opset6::Constant>();
    auto mul1 = pattern::wrap_type<opset6::Multiply>({ power, gamma });

    // x * gamma / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
    // `---mul2--------------------------------------------------------'
    auto mul2 = pattern::wrap_type<opset6::Multiply>({ x, mul1 });

    // ReduceMean(x, axes) * gamma / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
    // `-------------------mul3----------------------------------------------------------'
    auto mul3 = pattern::wrap_type<opset6::Multiply>({ mul1, mean1 });

    // beta - ReduceMean(x, axes) * gamma / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps)
    // `---sub-----------------------------------------------------------------------------------'
    auto beta = pattern::wrap_type<opset6::Constant>();
    auto sub = pattern::wrap_type<opset6::Subtract>({ beta, mul3 });

    // Final Add
    // x * gamma / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps) +
    // beta - ReduceMean(x, axes) * gamma / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps) =
    // gamma * (x - ReduceMean(x, axes)) / Sqrt(ReduceMean((x - ReduceMean(x, axes)) ^ 2) + eps) + beta
    auto add = pattern::wrap_type<opset6::Add>({ mul2, sub });

    ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_value_map();
        auto x_output = pattern_to_output.at(x);

        // Fetches the node matched by a pattern node as a Constant (nullptr if it is not one).
        const auto matched_constant = [&pattern_to_output](const std::shared_ptr<ngraph::Node>& pattern_node) {
            return std::dynamic_pointer_cast<ngraph::opset6::Constant>(
                pattern_to_output.at(pattern_node).get_node_shared_ptr());
        };

        auto const_0_5_node = matched_constant(const_0_5);
        auto const_gamma_node = matched_constant(gamma);
        auto const_beta_node = matched_constant(beta);
        auto const_eps_node = matched_constant(eps);
        // Every constant is used below, so bail out if any cast failed.
        if (!const_0_5_node || !const_gamma_node || !const_beta_node || !const_eps_node) {
            return false;
        }

        // The exponent must be exactly -0.5 and eps must be a single scalar value,
        // because MVN takes eps as a plain float attribute.
        float eps_value;
        bool valid_constant_values = op::util::has_constant_value<float>(const_0_5_node, -0.5) &&
                                     op::util::get_single_value(const_eps_node, eps_value);
        if (!valid_constant_values) {
            return false;
        }

        auto axes_1_node = matched_constant(mean1_axes);
        auto axes_2_node = matched_constant(mean2_axes);
        if (!axes_1_node || !axes_2_node) {
            return false;
        }

        // Both reductions must run over the same axes, otherwise this is not an MVN.
        auto axes_1_value = axes_1_node->cast_vector<int64_t>();
        auto axes_2_value = axes_2_node->cast_vector<int64_t>();
        if (axes_1_value != axes_2_value) {
            return false;
        }

        // eps is added before the square root in the matched pattern, hence INSIDE_SQRT.
        auto mvn = std::make_shared<ngraph::opset6::MVN>(x_output, axes_1_node, true, eps_value, op::MVNEpsMode::INSIDE_SQRT);
        auto mul_gamma = std::make_shared<ngraph::opset6::Multiply>(mvn, const_gamma_node);
        auto add_beta = std::make_shared<ngraph::opset6::Add>(mul_gamma, const_beta_node);

        // Propagate runtime info from every replaced operation, including mean2.
        ngraph::copy_runtime_info({ pattern_to_output.at(mean1).get_node_shared_ptr(),
                                    pattern_to_output.at(squared_difference).get_node_shared_ptr(),
                                    pattern_to_output.at(mean2).get_node_shared_ptr(),
                                    pattern_to_output.at(add_eps).get_node_shared_ptr(),
                                    pattern_to_output.at(power).get_node_shared_ptr(),
                                    pattern_to_output.at(mul1).get_node_shared_ptr(),
                                    pattern_to_output.at(mul2).get_node_shared_ptr(),
                                    pattern_to_output.at(mul3).get_node_shared_ptr(),
                                    pattern_to_output.at(sub).get_node_shared_ptr(),
                                    pattern_to_output.at(add).get_node_shared_ptr() },
                                  { mvn, const_gamma_node, mul_gamma, const_beta_node, add_beta });
        // Keep the name of the replaced root so consumers that look it up still find it.
        add_beta->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::replace_node(m.get_match_root(), add_beta);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(add, matcher_name);
    register_matcher(m, matcher_pass_callback);
}
|
||||
|
@ -419,3 +419,49 @@ TEST(TransformationTests, MVNFusionTestAltDivInsideSqrt) {
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, MVNFusionTestWithParametersInside) {
    // Checks that the decomposition with gamma/beta folded inside the sub-graph
    // is fused into MVN -> Multiply(gamma) -> Add(beta).
    std::shared_ptr<ngraph::Function> f(nullptr);
    std::shared_ptr<ngraph::Function> f_ref(nullptr);

    // Function under test: the decomposed sub-graph, transformed by MVNFusion.
    {
        const auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224 });
        const auto first_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 1 }, { 2 });
        const auto first_mean = std::make_shared<ngraph::opset6::ReduceMean>(data, first_axes, true);
        const auto sq_diff = std::make_shared<ngraph::opset6::SquaredDifference>(data, first_mean);
        const auto second_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 1 }, { 2 });
        const auto second_mean = std::make_shared<ngraph::opset6::ReduceMean>(sq_diff, second_axes, true);
        const auto epsilon = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1e-9 });
        const auto variance_plus_eps = std::make_shared<ngraph::opset6::Add>(second_mean, epsilon);
        const auto minus_half = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -0.5 });
        const auto inv_sqrt = std::make_shared<ngraph::opset6::Power>(variance_plus_eps, minus_half);
        const auto scale = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1 });
        const auto scaled_inv_sqrt = std::make_shared<ngraph::opset6::Multiply>(inv_sqrt, scale);
        const auto scaled_data = std::make_shared<ngraph::opset6::Multiply>(data, scaled_inv_sqrt);
        const auto scaled_mean = std::make_shared<ngraph::opset6::Multiply>(scaled_inv_sqrt, first_mean);
        const auto shift = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
        const auto shift_minus_mean = std::make_shared<ngraph::opset6::Subtract>(shift, scaled_mean);
        const auto result = std::make_shared<ngraph::opset6::Add>(scaled_data, shift_minus_mean);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ result }, ngraph::ParameterVector{ data });

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<ngraph::pass::MVNFusion>();
        pass_manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference function: the expected fused form.
    {
        const auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 224 });
        const auto reduction_axes = ngraph::opset6::Constant::create(ngraph::element::i32, ngraph::Shape{ 1 }, { 2 });
        const auto fused_mvn = std::make_shared<ngraph::opset6::MVN>(data, reduction_axes, true, 1e-9, ngraph::op::MVNEpsMode::INSIDE_SQRT);
        const auto scale = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { 1 });
        const auto scaled = std::make_shared<ngraph::opset6::Multiply>(fused_mvn, scale);
        const auto shift = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{}, { -1 });
        const auto result = std::make_shared<ngraph::opset6::Add>(scaled, shift);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ result }, ngraph::ParameterVector{ data });
    }

    const auto comparison = compare_functions(f, f_ref);
    ASSERT_TRUE(comparison.first) << comparison.second;
}
|
||||
|
Loading…
Reference in New Issue
Block a user