Enable SubractFusion and DivideFusion in MOC (#7949)

* Keep changes

* Enabled DivideFusion and ConvertDivideWithConstant in MOC

* Enable SubtractFusion in MOC; Remove eltwise fusion from MO

* Temporary disable fusions

* Temporary disable ConvertDivide folding

* Update ConvertDivide

* Update remove filtering boxes pass execution
This commit is contained in:
Gleb Kazantaev 2021-11-11 12:20:45 +03:00 committed by GitHub
parent 869408075c
commit abc554513f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 503 additions and 153 deletions

View File

@ -0,0 +1,28 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <ngraph/pass/graph_rewrite.hpp>
#include <transformations_visibility.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API DivideFusion;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief DivideFusion transformation replaces a sub-graph
* Pow(y, -1) * x or x * Pow(y, -1) with Divide(x,y)
*/
class ngraph::pass::DivideFusion : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
DivideFusion();
};

View File

@ -14,11 +14,18 @@
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API FuseFilteringBoxesBySize;
class TRANSFORMATIONS_API RemoveFilteringBoxesBySize;
} // namespace pass
} // namespace ngraph
class ngraph::pass::FuseFilteringBoxesBySize: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
FuseFilteringBoxesBySize();
};
class ngraph::pass::RemoveFilteringBoxesBySize: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;

View File

@ -0,0 +1,28 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <ngraph/pass/graph_rewrite.hpp>
#include <transformations_visibility.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API SubtractFusion;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief SubtractFusion transformation replaces a sub-graph
* Mul(y, -1) + x or x + Mul(y, -1) with Subtract(x,y)
*/
class ngraph::pass::SubtractFusion : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
SubtractFusion();
};

View File

@ -15,6 +15,7 @@ namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API ConvertDivide;
class TRANSFORMATIONS_API ConvertDivideWithConstant;
} // namespace pass
} // namespace ngraph
@ -24,3 +25,9 @@ public:
NGRAPH_RTTI_DECLARATION;
ConvertDivide();
};
class ngraph::pass::ConvertDivideWithConstant: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertDivideWithConstant();
};

View File

@ -0,0 +1,47 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/divide_fusion.hpp"
#include <memory>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include "itt.hpp"
#include "transformations/utils/utils.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::DivideFusion, "DivideFusion", 0);
ngraph::pass::DivideFusion::DivideFusion() {
MATCHER_SCOPE(DivideFusion);
auto p_pow_input = pattern::any_input();
auto p_pow_const = pattern::wrap_type<opset8::Constant>();
auto p_pow = pattern::wrap_type<opset8::Power>({p_pow_input, p_pow_const});
auto p_mul_input = pattern::any_input();
auto p_mul = ngraph::pattern::wrap_type<opset8::Multiply>({p_mul_input, p_pow});
matcher_pass_callback callback = [=](pattern::Matcher &m) {
const auto & pattern_to_output = m.get_pattern_value_map();
const auto & minuend_input = pattern_to_output.at(p_mul_input);
const auto & subtrahend_input = pattern_to_output.at(p_pow_input);
const auto & mul = pattern_to_output.at(p_mul).get_node_shared_ptr();
const auto & pow = pattern_to_output.at(p_pow).get_node_shared_ptr();
const auto & minus_one = pattern_to_output.at(p_pow_const).get_node_shared_ptr();
auto minus_one_const = std::dynamic_pointer_cast<opset8::Constant>(minus_one);
if (!minus_one_const || !op::util::has_constant_value<float>(minus_one_const, -1.)) {
return false;
}
auto div = register_new_node<opset8::Divide>(minuend_input, subtrahend_input);
div->set_friendly_name(mul->get_friendly_name());
copy_runtime_info({mul, pow}, div);
replace_node(mul, div);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(p_mul, matcher_name);
register_matcher(m, callback);
}

View File

@ -48,7 +48,10 @@
#include <transformations/common_optimizations/transpose_to_reshape.hpp>
#include <transformations/common_optimizations/batch_to_space_fusion.hpp>
#include <transformations/common_optimizations/mul_conv_fusion.hpp>
#include "transformations/common_optimizations/split_concat_pair_to_interpolate_fusion.hpp"
#include <transformations/common_optimizations/split_concat_pair_to_interpolate_fusion.hpp>
#include <transformations/op_conversions/convert_divide.hpp>
#include <transformations/common_optimizations/divide_fusion.hpp>
#include <transformations/common_optimizations/subtract_fusion.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::MOCTransformations, "MOCTransformations", 0);
@ -76,7 +79,10 @@ bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr<ngraph::F
}
manager.register_pass<ngraph::pass::DisableRandomUniformConstantFolding>();
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>();
// FusedFilteringBoxesBySize transformation has the complex pattern
// which can be affected by further transformations. So we have to
// execute it at the beginning of the pipeline.
manager.register_pass<ngraph::pass::FuseFilteringBoxesBySize>();
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
manager.register_pass<ngraph::pass::SimplifyShapeOfSubGraph>();
if (!m_use_shapes) {
@ -122,6 +128,8 @@ bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr<ngraph::F
common_fusions->add_matcher<ngraph::pass::LeakyReluFusion>();
common_fusions->add_matcher<ngraph::pass::RandomUniformFusion>();
common_fusions->add_matcher<ngraph::pass::SplitConcatPairToInterpolateFusion>(m_use_shapes);
common_fusions->add_matcher<ngraph::pass::DivideFusion>();
common_fusions->add_matcher<ngraph::pass::SubtractFusion>();
common_fusions->set_name("ngraph::pass::CommonFusions");
manager.register_pass<ngraph::pass::BinarizeWeights>();
@ -129,6 +137,7 @@ bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr<ngraph::F
auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>();
decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
decomp->add_matcher<ngraph::pass::ConvertDivideWithConstant>();
manager.register_pass<ngraph::pass::LinOpSequenceFusion>();

View File

@ -8,11 +8,19 @@
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pass/manager.hpp>
#include "transformations/common_optimizations/remove_filtering_boxes_by_size.hpp"
#include "transformations/common_optimizations/subtract_fusion.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::FuseFilteringBoxesBySize, "FuseFilteringBoxesBySize", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::RemoveFilteringBoxesBySize, "RemoveFilteringBoxesBySize", 0);
ngraph::pass::FuseFilteringBoxesBySize::FuseFilteringBoxesBySize() {
add_matcher<SubtractFusion>();
add_matcher<RemoveFilteringBoxesBySize>();
}
ngraph::pass::RemoveFilteringBoxesBySize::RemoveFilteringBoxesBySize() {
MATCHER_SCOPE(RemoveFilteringBoxesBySize);
// variadic split
@ -85,9 +93,9 @@ ngraph::pass::RemoveFilteringBoxesBySize::RemoveFilteringBoxesBySize() {
auto start = opset3::Constant::create(element::i64, Shape{}, std::vector<int64_t >({0}));
auto step = opset3::Constant::create(element::i64, Shape{}, std::vector<int64_t >({1}));
auto pattern_map = m.get_pattern_map();
const auto & pattern_map = m.get_pattern_map();
auto input = pattern_map[data];
auto input = pattern_map.at(data);
auto output = m.get_match_root();
auto input_shape = std::make_shared<ngraph::opset3::ShapeOf>(input);

View File

@ -0,0 +1,60 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/subtract_fusion.hpp"
#include <memory>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include "itt.hpp"
#include "transformations/utils/utils.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::SubtractFusion, "SubtractFusion", 0);
ngraph::pass::SubtractFusion::SubtractFusion() {
MATCHER_SCOPE(SubtractFusion);
auto p_input = pattern::any_input();
auto p_mul_const = pattern::wrap_type<opset8::Constant>();
auto p_mul = pattern::wrap_type<opset8::Multiply>({p_input, p_mul_const});
auto p_neg = pattern::wrap_type<opset8::Negative>({p_input});
auto p_mul_or_neg = std::make_shared<pattern::op::Or>(OutputVector({p_mul, p_neg}));
auto p_add_input = pattern::any_input();
auto p_add = ngraph::pattern::wrap_type<opset8::Add>({p_add_input, p_mul_or_neg});
matcher_pass_callback callback = [=](pattern::Matcher &m) {
const auto & pattern_to_output = m.get_pattern_value_map();
const auto & minuend_input = pattern_to_output.at(p_add_input);
const auto & subtrahend_input = pattern_to_output.at(p_input);
const auto & add = pattern_to_output.at(p_add).get_node_shared_ptr();
NodeVector nodes_to_replace{add};
if (pattern_to_output.count(p_mul_const)) {
auto minus_one_const = std::dynamic_pointer_cast<opset8::Constant>(pattern_to_output.at(p_mul_const).get_node_shared_ptr());
if (!op::util::has_constant_value<float>(minus_one_const, -1.)) {
return false;
}
nodes_to_replace.emplace_back(pattern_to_output.at(p_mul).get_node_shared_ptr());
} else {
nodes_to_replace.emplace_back(pattern_to_output.at(p_neg).get_node_shared_ptr());
}
auto sub = register_new_node<opset8::Subtract>(minuend_input, subtrahend_input);
sub->set_friendly_name(add->get_friendly_name());
copy_runtime_info(nodes_to_replace, sub);
replace_node(add, sub);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(p_add, matcher_name);
register_matcher(m, callback);
}

View File

@ -11,31 +11,65 @@
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/validation_util.hpp>
#include <ngraph/log.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertDivide, "ConvertDivide", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertDivideWithConstant, "ConvertDivideWithConstant", 0);
namespace {
bool convert_divide(std::shared_ptr<ngraph::Node> node) {
auto div = std::dynamic_pointer_cast<ngraph::opset1::Divide>(node);
// We can not apply this transformation in case with integer input data type
if (!div || div->get_input_element_type(0).is_integral()
|| div->get_input_element_type(1).is_integral()) {
return false;
}
ngraph::Output<ngraph::Node> pow = std::make_shared<ngraph::opset1::Power>(div->input_value(1),
ngraph::op::Constant::create(div->get_input_element_type(1), ngraph::Shape{}, {-1}));
if (std::dynamic_pointer_cast<ngraph::op::Constant>(div->get_input_node_shared_ptr(1))) {
if (auto const_pow = ngraph::get_constant_from_source(pow)) {
pow = const_pow;
} else {
NGRAPH_DEBUG << "ConvertDivide has failed due to unsupported evaluate type in " << pow.get_node();
return false;
}
} else {
ngraph::copy_runtime_info(div, pow.get_node_shared_ptr());
}
auto mul = std::make_shared<ngraph::opset1::Multiply>(div->input(0).get_source_output(), pow);
mul->set_friendly_name(div->get_friendly_name());
ngraph::copy_runtime_info(div, mul);
ngraph::replace_node(div, mul);
return true;
}
} // namespace
ngraph::pass::ConvertDivide::ConvertDivide() {
MATCHER_SCOPE(ConvertDivide);
auto div = ngraph::pattern::wrap_type<ngraph::opset1::Divide>();
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto div = std::dynamic_pointer_cast<ngraph::opset1::Divide> (m.get_match_root());
// We can not apply this transformation in case with integer input data type
if (!div || div->input(0).get_element_type().is_integral()) {
return false;
}
auto pow = std::make_shared<ngraph::opset1::Power>(div->input(1).get_source_output(),
op::Constant::create(div->get_input_element_type(1), Shape{}, {-1}));
auto mul = std::make_shared<ngraph::opset1::Multiply>(div->input(0).get_source_output(), pow);
mul->set_friendly_name(div->get_friendly_name());
ngraph::copy_runtime_info(div, {pow, mul});
ngraph::replace_node(div, mul);
return true;
return convert_divide(m.get_match_root());
};
auto m = std::make_shared<ngraph::pattern::Matcher>(div, matcher_name);
this->register_matcher(m, callback);
}
ngraph::pass::ConvertDivideWithConstant::ConvertDivideWithConstant() {
MATCHER_SCOPE(ConvertDivideWithConstant);
auto div = ngraph::pattern::wrap_type<ngraph::opset1::Divide>(
{pattern::any_input(), pattern::wrap_type<op::Constant>()});
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
return convert_divide(m.get_match_root());
};
auto m = std::make_shared<ngraph::pattern::Matcher>(div, matcher_name);
this->register_matcher(m, callback);
}

View File

@ -32,10 +32,8 @@ TEST_F(TransformationTestsF, ConvertDivide) {
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto divide_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.5});
auto pow = std::make_shared<ngraph::opset1::Power>(divide_constant,
ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, {-1}));
auto mul = std::make_shared<ngraph::opset1::Multiply>(data, pow);
auto divide_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1. / 1.5});
auto mul = std::make_shared<ngraph::opset1::Multiply>(data, divide_constant);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{data});
}
@ -63,11 +61,11 @@ TEST_F(TransformationTestsF, ConvertDivideNegative) {
TEST_F(TransformationTestsF, ConvertDivideScalar) {
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
auto divide_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, {1.5});
auto divide = std::make_shared<ngraph::opset1::Divide>(data, divide_constant);
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
auto divide = std::make_shared<ngraph::opset1::Divide>(data1, data2);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data});
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data1, data2});
NGRAPH_CHECK(divide->get_output_partial_shape(0).rank().get_length() == 0);
@ -76,13 +74,51 @@ TEST_F(TransformationTestsF, ConvertDivideScalar) {
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
auto divide_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, {1.5});
auto pow = std::make_shared<ngraph::opset1::Power>(divide_constant,
auto pow_input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
auto pow = std::make_shared<ngraph::opset1::Power>(pow_input,
ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, {-1}));
auto mul = std::make_shared<ngraph::opset1::Multiply>(data, pow);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{data});
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{data, pow_input});
NGRAPH_CHECK(mul->get_output_partial_shape(0).rank().get_length() == 0);
}
}
TEST_F(TransformationTestsF, ConvertDivideWithConstantPositive) {
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
auto divide_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, {1.5});
auto divide = std::make_shared<ngraph::opset1::Divide>(data, divide_constant);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data});
manager.register_pass<ngraph::pass::ConvertDivideWithConstant>();
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
auto divide_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, {1. / 1.5});
auto mul = std::make_shared<ngraph::opset1::Multiply>(data, divide_constant);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{data});
}
}
TEST_F(TransformationTestsF, ConvertDivideWithConstantNegative) {
{
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
auto divide = std::make_shared<ngraph::opset1::Divide>(data1, data2);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data1, data2});
manager.register_pass<ngraph::pass::ConvertDivideWithConstant>();
}
{
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
auto divide = std::make_shared<ngraph::opset1::Divide>(data1, data2);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data1, data2});
}
}

View File

@ -0,0 +1,88 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <queue>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/common_optimizations/divide_fusion.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST(TransformationTests, DivideFusion) {
std::shared_ptr<ngraph::Function> f, f_ref;
{
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto pow_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {-1});
auto pow = std::make_shared<ngraph::opset1::Power>(data2, pow_constant);
auto mul = std::make_shared<ngraph::opset1::Multiply>(data1, pow);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{data1, data2});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::DivideFusion>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto divide = std::make_shared<ngraph::opset1::Divide>(data1, data2);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data1, data2});
}
const auto res = FunctionsComparator::with_default()
.enable(FunctionsComparator::CONST_VALUES)
.enable(FunctionsComparator::ATTRIBUTES)
.compare(f, f_ref);
ASSERT_TRUE(res.valid) << res.message;
}
TEST(TransformationTests, DivideFusionNegative) {
std::shared_ptr<ngraph::Function> f, f_ref;
{
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto pow_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {-1.01});
auto pow = std::make_shared<ngraph::opset1::Power>(data2, pow_constant);
auto mul = std::make_shared<ngraph::opset1::Multiply>(data1, pow);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{data1, data2});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::DivideFusion>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto pow_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {-1.01});
auto pow = std::make_shared<ngraph::opset1::Power>(data2, pow_constant);
auto mul = std::make_shared<ngraph::opset1::Multiply>(data1, pow);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{data1, data2});
}
const auto res = FunctionsComparator::with_default()
.enable(FunctionsComparator::CONST_VALUES)
.enable(FunctionsComparator::ATTRIBUTES)
.compare(f, f_ref);
ASSERT_TRUE(res.valid) << res.message;
}

View File

@ -0,0 +1,121 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <queue>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/common_optimizations/divide_fusion.hpp>
#include <transformations/common_optimizations/subtract_fusion.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST(TransformationTests, SubtractFusionMultiply) {
std::shared_ptr<ngraph::Function> f, f_ref;
{
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto mul_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {-1});
auto mul = std::make_shared<ngraph::opset1::Multiply>(data2, mul_constant);
auto add = std::make_shared<ngraph::opset1::Add>(data1, mul);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{data1, data2});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::SubtractFusion>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto divide = std::make_shared<ngraph::opset1::Subtract>(data1, data2);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data1, data2});
}
const auto res = FunctionsComparator::with_default()
.enable(FunctionsComparator::CONST_VALUES)
.enable(FunctionsComparator::ATTRIBUTES)
.compare(f, f_ref);
ASSERT_TRUE(res.valid) << res.message;
}
TEST(TransformationTests, SubtractFusionMultiplyNegative) {
std::shared_ptr<ngraph::Function> f, f_ref;
{
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto mul_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {-1.01});
auto mul = std::make_shared<ngraph::opset1::Multiply>(data2, mul_constant);
auto add = std::make_shared<ngraph::opset1::Add>(data1, mul);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{data1, data2});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::SubtractFusion>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto mul_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {-1.01});
auto mul = std::make_shared<ngraph::opset1::Multiply>(data2, mul_constant);
auto add = std::make_shared<ngraph::opset1::Add>(data1, mul);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{data1, data2});
}
const auto res = FunctionsComparator::with_default()
.enable(FunctionsComparator::CONST_VALUES)
.enable(FunctionsComparator::ATTRIBUTES)
.compare(f, f_ref);
ASSERT_TRUE(res.valid) << res.message;
}
TEST(TransformationTests, SubtractFusionNeg) {
std::shared_ptr<ngraph::Function> f, f_ref;
{
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto neg = std::make_shared<ngraph::opset1::Negative>(data2);
auto add = std::make_shared<ngraph::opset1::Add>(neg, data1);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{data1, data2});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::SubtractFusion>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto divide = std::make_shared<ngraph::opset1::Subtract>(data1, data2);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data1, data2});
}
const auto res = FunctionsComparator::with_default()
.enable(FunctionsComparator::CONST_VALUES)
.enable(FunctionsComparator::ATTRIBUTES)
.compare(f, f_ref);
ASSERT_TRUE(res.valid) << res.message;
}

View File

@ -22,7 +22,6 @@ extensions/back/CutMemory.py
extensions/back/EnableConstantStridedSlice.py
extensions/back/FakeOutputResolver.py
extensions/back/ForceStrictPrecision.py
extensions/back/fuse_sub_div_min.py
extensions/back/FuseTransposesSequence.py
extensions/back/GatherNormalizer.py
extensions/back/insert_compatibility_l2normalization.py

View File

@ -1,122 +0,0 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from extensions.ops.elementwise import Sub, Div, Negative
from mo.back.replacement import BackReplacementPattern
from mo.graph.graph import Node, Graph
class Negate(BackReplacementPattern):
enabled = True
force_clean_up = True
@staticmethod
def pattern():
return dict(
nodes=[
('mul', {'type': 'Multiply'})
],
edges=[],
)
@staticmethod
def replace_pattern(graph: Graph, match: [str, Node]):
mul = match['mul']
name = mul.soft_get('name', mul.id)
mul_POS_port = None
if mul.in_port(0).data.get_value() is not None and np.all(mul.in_port(0).data.get_value() == -1):
mul_POS_port = mul.in_port(1)
if mul.in_port(1).data.get_value() is not None and np.all(mul.in_port(1).data.get_value() == -1):
mul_POS_port = mul.in_port(0)
if mul_POS_port is None:
return
negative = Negative(graph, {'name': name + '/Negate'}).create_node()
mul.out_port(0).get_connection().set_source(negative.out_port(0))
mul_POS_port.get_connection().set_destination(negative.in_port(0))
class EltwisesToSubtract(BackReplacementPattern):
enabled = True
force_clean_up = True
def run_after(self):
return [Negate]
@staticmethod
def pattern():
return dict(
nodes=[
('neg', {'type': 'Negative'}),
('neg_d', {}),
('add', {'type': 'Add'})
],
edges=[
('neg', 'neg_d'),
('neg_d', 'add'),
],
)
@staticmethod
def replace_pattern(graph: Graph, match: [str, Node]):
neg = match['neg']
add = match['add']
name = add.soft_get('name', add.id)
minuend_port = add.in_port(0).get_source() \
if add.in_port(1).get_source().node.id == neg.id else add.in_port(1).get_source()
subtrahned_port = neg.in_port(0).get_source()
sub = Sub(graph, {'name': name + '/sub'}).create_node()
add.out_port(0).get_connection().set_source(sub.out_port(0))
minuend_port.connect(sub.in_port(0))
subtrahned_port.connect(sub.in_port(1))
class EltwisesToDiv(BackReplacementPattern):
enabled = True
force_clean_up = True
@staticmethod
def pattern():
return dict(
nodes=[
('const', {'type': 'Const'}),
('const_d', {'value': lambda val: val is not None and np.all(val == -1)}),
('inv', {'type': 'Pow'}),
('inv_d', {}),
('mul', {'type': 'Multiply'})
],
edges=[
('const', 'const_d'),
('const_d', 'inv', {'in': 1}),
('inv', 'inv_d'),
('inv_d', 'mul'),
],
)
@staticmethod
def replace_pattern(graph: Graph, match: [str, Node]):
pow = match['inv']
mul = match['mul']
const = match['const']
name = mul.soft_get('name', mul.id)
devidend_port = mul.in_port(0).get_source() if mul.in_port(1).get_source().node.id == pow.id else mul.in_port(
1).get_source()
divider_port = pow.in_port(0).get_source() if pow.in_port(1).get_source().node.id == const.id else pow.in_port(
1).get_source()
div = Div(graph, {'name': name + '/div'}).create_node()
mul.out_port(0).get_connection().set_source(div.out_port(0))
devidend_port.connect(div.in_port(0))
divider_port.connect(div.in_port(1))