Enable SubractFusion and DivideFusion in MOC (#7949)
* Keep changes * Enabled DivideFusion and ConvertDivideWithConstant in MOC * Enable SubtractFusion in MOC; Remove eltwise fusion from MO * Temporary disable fusions * Temporary disable ConvertDivide folding * Update ConvertDivide * Update remove filtering boxes pass execution
This commit is contained in:
parent
869408075c
commit
abc554513f
@ -0,0 +1,28 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API DivideFusion;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief DivideFusion transformation replaces a sub-graph
|
||||
* Pow(y, -1) * x or x * Pow(y, -1) with Divide(x,y)
|
||||
*/
|
||||
class ngraph::pass::DivideFusion : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
DivideFusion();
|
||||
};
|
@ -14,11 +14,18 @@
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API FuseFilteringBoxesBySize;
|
||||
class TRANSFORMATIONS_API RemoveFilteringBoxesBySize;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::FuseFilteringBoxesBySize: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
FuseFilteringBoxesBySize();
|
||||
};
|
||||
|
||||
class ngraph::pass::RemoveFilteringBoxesBySize: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
@ -0,0 +1,28 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API SubtractFusion;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief SubtractFusion transformation replaces a sub-graph
|
||||
* Mul(y, -1) + x or x + Mul(y, -1) with Subtract(x,y)
|
||||
*/
|
||||
class ngraph::pass::SubtractFusion : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
SubtractFusion();
|
||||
};
|
@ -15,6 +15,7 @@ namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API ConvertDivide;
|
||||
class TRANSFORMATIONS_API ConvertDivideWithConstant;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
@ -24,3 +25,9 @@ public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertDivide();
|
||||
};
|
||||
|
||||
class ngraph::pass::ConvertDivideWithConstant: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertDivideWithConstant();
|
||||
};
|
||||
|
@ -0,0 +1,47 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/divide_fusion.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::DivideFusion, "DivideFusion", 0);
|
||||
|
||||
ngraph::pass::DivideFusion::DivideFusion() {
|
||||
MATCHER_SCOPE(DivideFusion);
|
||||
auto p_pow_input = pattern::any_input();
|
||||
auto p_pow_const = pattern::wrap_type<opset8::Constant>();
|
||||
auto p_pow = pattern::wrap_type<opset8::Power>({p_pow_input, p_pow_const});
|
||||
auto p_mul_input = pattern::any_input();
|
||||
auto p_mul = ngraph::pattern::wrap_type<opset8::Multiply>({p_mul_input, p_pow});
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher &m) {
|
||||
const auto & pattern_to_output = m.get_pattern_value_map();
|
||||
const auto & minuend_input = pattern_to_output.at(p_mul_input);
|
||||
const auto & subtrahend_input = pattern_to_output.at(p_pow_input);
|
||||
const auto & mul = pattern_to_output.at(p_mul).get_node_shared_ptr();
|
||||
const auto & pow = pattern_to_output.at(p_pow).get_node_shared_ptr();
|
||||
const auto & minus_one = pattern_to_output.at(p_pow_const).get_node_shared_ptr();
|
||||
|
||||
auto minus_one_const = std::dynamic_pointer_cast<opset8::Constant>(minus_one);
|
||||
if (!minus_one_const || !op::util::has_constant_value<float>(minus_one_const, -1.)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto div = register_new_node<opset8::Divide>(minuend_input, subtrahend_input);
|
||||
div->set_friendly_name(mul->get_friendly_name());
|
||||
copy_runtime_info({mul, pow}, div);
|
||||
replace_node(mul, div);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(p_mul, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
@ -48,7 +48,10 @@
|
||||
#include <transformations/common_optimizations/transpose_to_reshape.hpp>
|
||||
#include <transformations/common_optimizations/batch_to_space_fusion.hpp>
|
||||
#include <transformations/common_optimizations/mul_conv_fusion.hpp>
|
||||
#include "transformations/common_optimizations/split_concat_pair_to_interpolate_fusion.hpp"
|
||||
#include <transformations/common_optimizations/split_concat_pair_to_interpolate_fusion.hpp>
|
||||
#include <transformations/op_conversions/convert_divide.hpp>
|
||||
#include <transformations/common_optimizations/divide_fusion.hpp>
|
||||
#include <transformations/common_optimizations/subtract_fusion.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::MOCTransformations, "MOCTransformations", 0);
|
||||
|
||||
@ -76,7 +79,10 @@ bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr<ngraph::F
|
||||
}
|
||||
manager.register_pass<ngraph::pass::DisableRandomUniformConstantFolding>();
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>();
|
||||
// FusedFilteringBoxesBySize transformation has the complex pattern
|
||||
// which can be affected by further transformations. So we have to
|
||||
// execute it at the beginning of the pipeline.
|
||||
manager.register_pass<ngraph::pass::FuseFilteringBoxesBySize>();
|
||||
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
|
||||
manager.register_pass<ngraph::pass::SimplifyShapeOfSubGraph>();
|
||||
if (!m_use_shapes) {
|
||||
@ -122,6 +128,8 @@ bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr<ngraph::F
|
||||
common_fusions->add_matcher<ngraph::pass::LeakyReluFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::RandomUniformFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::SplitConcatPairToInterpolateFusion>(m_use_shapes);
|
||||
common_fusions->add_matcher<ngraph::pass::DivideFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::SubtractFusion>();
|
||||
common_fusions->set_name("ngraph::pass::CommonFusions");
|
||||
|
||||
manager.register_pass<ngraph::pass::BinarizeWeights>();
|
||||
@ -129,6 +137,7 @@ bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr<ngraph::F
|
||||
|
||||
auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertDivideWithConstant>();
|
||||
|
||||
manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
|
||||
|
||||
|
@ -8,11 +8,19 @@
|
||||
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
|
||||
#include "transformations/common_optimizations/remove_filtering_boxes_by_size.hpp"
|
||||
#include "transformations/common_optimizations/subtract_fusion.hpp"
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::FuseFilteringBoxesBySize, "FuseFilteringBoxesBySize", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::RemoveFilteringBoxesBySize, "RemoveFilteringBoxesBySize", 0);
|
||||
|
||||
ngraph::pass::FuseFilteringBoxesBySize::FuseFilteringBoxesBySize() {
|
||||
add_matcher<SubtractFusion>();
|
||||
add_matcher<RemoveFilteringBoxesBySize>();
|
||||
}
|
||||
|
||||
ngraph::pass::RemoveFilteringBoxesBySize::RemoveFilteringBoxesBySize() {
|
||||
MATCHER_SCOPE(RemoveFilteringBoxesBySize);
|
||||
// variadic split
|
||||
@ -85,9 +93,9 @@ ngraph::pass::RemoveFilteringBoxesBySize::RemoveFilteringBoxesBySize() {
|
||||
auto start = opset3::Constant::create(element::i64, Shape{}, std::vector<int64_t >({0}));
|
||||
auto step = opset3::Constant::create(element::i64, Shape{}, std::vector<int64_t >({1}));
|
||||
|
||||
auto pattern_map = m.get_pattern_map();
|
||||
const auto & pattern_map = m.get_pattern_map();
|
||||
|
||||
auto input = pattern_map[data];
|
||||
auto input = pattern_map.at(data);
|
||||
auto output = m.get_match_root();
|
||||
|
||||
auto input_shape = std::make_shared<ngraph::opset3::ShapeOf>(input);
|
||||
|
@ -0,0 +1,60 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/subtract_fusion.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/pattern/op/or.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::SubtractFusion, "SubtractFusion", 0);
|
||||
|
||||
ngraph::pass::SubtractFusion::SubtractFusion() {
|
||||
MATCHER_SCOPE(SubtractFusion);
|
||||
auto p_input = pattern::any_input();
|
||||
|
||||
auto p_mul_const = pattern::wrap_type<opset8::Constant>();
|
||||
auto p_mul = pattern::wrap_type<opset8::Multiply>({p_input, p_mul_const});
|
||||
|
||||
auto p_neg = pattern::wrap_type<opset8::Negative>({p_input});
|
||||
|
||||
auto p_mul_or_neg = std::make_shared<pattern::op::Or>(OutputVector({p_mul, p_neg}));
|
||||
|
||||
auto p_add_input = pattern::any_input();
|
||||
auto p_add = ngraph::pattern::wrap_type<opset8::Add>({p_add_input, p_mul_or_neg});
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher &m) {
|
||||
const auto & pattern_to_output = m.get_pattern_value_map();
|
||||
const auto & minuend_input = pattern_to_output.at(p_add_input);
|
||||
const auto & subtrahend_input = pattern_to_output.at(p_input);
|
||||
|
||||
const auto & add = pattern_to_output.at(p_add).get_node_shared_ptr();
|
||||
|
||||
NodeVector nodes_to_replace{add};
|
||||
|
||||
if (pattern_to_output.count(p_mul_const)) {
|
||||
auto minus_one_const = std::dynamic_pointer_cast<opset8::Constant>(pattern_to_output.at(p_mul_const).get_node_shared_ptr());
|
||||
if (!op::util::has_constant_value<float>(minus_one_const, -1.)) {
|
||||
return false;
|
||||
}
|
||||
nodes_to_replace.emplace_back(pattern_to_output.at(p_mul).get_node_shared_ptr());
|
||||
} else {
|
||||
nodes_to_replace.emplace_back(pattern_to_output.at(p_neg).get_node_shared_ptr());
|
||||
}
|
||||
|
||||
auto sub = register_new_node<opset8::Subtract>(minuend_input, subtrahend_input);
|
||||
sub->set_friendly_name(add->get_friendly_name());
|
||||
copy_runtime_info(nodes_to_replace, sub);
|
||||
replace_node(add, sub);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(p_add, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
@ -11,31 +11,65 @@
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/validation_util.hpp>
|
||||
#include <ngraph/log.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertDivide, "ConvertDivide", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertDivideWithConstant, "ConvertDivideWithConstant", 0);
|
||||
|
||||
namespace {
|
||||
bool convert_divide(std::shared_ptr<ngraph::Node> node) {
|
||||
auto div = std::dynamic_pointer_cast<ngraph::opset1::Divide>(node);
|
||||
// We can not apply this transformation in case with integer input data type
|
||||
if (!div || div->get_input_element_type(0).is_integral()
|
||||
|| div->get_input_element_type(1).is_integral()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ngraph::Output<ngraph::Node> pow = std::make_shared<ngraph::opset1::Power>(div->input_value(1),
|
||||
ngraph::op::Constant::create(div->get_input_element_type(1), ngraph::Shape{}, {-1}));
|
||||
|
||||
if (std::dynamic_pointer_cast<ngraph::op::Constant>(div->get_input_node_shared_ptr(1))) {
|
||||
if (auto const_pow = ngraph::get_constant_from_source(pow)) {
|
||||
pow = const_pow;
|
||||
} else {
|
||||
NGRAPH_DEBUG << "ConvertDivide has failed due to unsupported evaluate type in " << pow.get_node();
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
ngraph::copy_runtime_info(div, pow.get_node_shared_ptr());
|
||||
}
|
||||
|
||||
auto mul = std::make_shared<ngraph::opset1::Multiply>(div->input(0).get_source_output(), pow);
|
||||
|
||||
mul->set_friendly_name(div->get_friendly_name());
|
||||
ngraph::copy_runtime_info(div, mul);
|
||||
ngraph::replace_node(div, mul);
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
ngraph::pass::ConvertDivide::ConvertDivide() {
|
||||
MATCHER_SCOPE(ConvertDivide);
|
||||
auto div = ngraph::pattern::wrap_type<ngraph::opset1::Divide>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
|
||||
auto div = std::dynamic_pointer_cast<ngraph::opset1::Divide> (m.get_match_root());
|
||||
// We can not apply this transformation in case with integer input data type
|
||||
if (!div || div->input(0).get_element_type().is_integral()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto pow = std::make_shared<ngraph::opset1::Power>(div->input(1).get_source_output(),
|
||||
op::Constant::create(div->get_input_element_type(1), Shape{}, {-1}));
|
||||
|
||||
auto mul = std::make_shared<ngraph::opset1::Multiply>(div->input(0).get_source_output(), pow);
|
||||
|
||||
mul->set_friendly_name(div->get_friendly_name());
|
||||
ngraph::copy_runtime_info(div, {pow, mul});
|
||||
ngraph::replace_node(div, mul);
|
||||
return true;
|
||||
return convert_divide(m.get_match_root());
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(div, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
ngraph::pass::ConvertDivideWithConstant::ConvertDivideWithConstant() {
|
||||
MATCHER_SCOPE(ConvertDivideWithConstant);
|
||||
auto div = ngraph::pattern::wrap_type<ngraph::opset1::Divide>(
|
||||
{pattern::any_input(), pattern::wrap_type<op::Constant>()});
|
||||
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
|
||||
return convert_divide(m.get_match_root());
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(div, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
@ -32,10 +32,8 @@ TEST_F(TransformationTestsF, ConvertDivide) {
|
||||
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto divide_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.5});
|
||||
auto pow = std::make_shared<ngraph::opset1::Power>(divide_constant,
|
||||
ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, {-1}));
|
||||
auto mul = std::make_shared<ngraph::opset1::Multiply>(data, pow);
|
||||
auto divide_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1. / 1.5});
|
||||
auto mul = std::make_shared<ngraph::opset1::Multiply>(data, divide_constant);
|
||||
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{data});
|
||||
}
|
||||
@ -63,11 +61,11 @@ TEST_F(TransformationTestsF, ConvertDivideNegative) {
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertDivideScalar) {
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
|
||||
auto divide_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, {1.5});
|
||||
auto divide = std::make_shared<ngraph::opset1::Divide>(data, divide_constant);
|
||||
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
|
||||
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
|
||||
auto divide = std::make_shared<ngraph::opset1::Divide>(data1, data2);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data});
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data1, data2});
|
||||
|
||||
NGRAPH_CHECK(divide->get_output_partial_shape(0).rank().get_length() == 0);
|
||||
|
||||
@ -76,13 +74,51 @@ TEST_F(TransformationTestsF, ConvertDivideScalar) {
|
||||
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
|
||||
auto divide_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, {1.5});
|
||||
auto pow = std::make_shared<ngraph::opset1::Power>(divide_constant,
|
||||
auto pow_input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
|
||||
auto pow = std::make_shared<ngraph::opset1::Power>(pow_input,
|
||||
ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, {-1}));
|
||||
auto mul = std::make_shared<ngraph::opset1::Multiply>(data, pow);
|
||||
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{data});
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{data, pow_input});
|
||||
|
||||
NGRAPH_CHECK(mul->get_output_partial_shape(0).rank().get_length() == 0);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertDivideWithConstantPositive) {
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
|
||||
auto divide_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, {1.5});
|
||||
auto divide = std::make_shared<ngraph::opset1::Divide>(data, divide_constant);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data});
|
||||
manager.register_pass<ngraph::pass::ConvertDivideWithConstant>();
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
|
||||
auto divide_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, {1. / 1.5});
|
||||
auto mul = std::make_shared<ngraph::opset1::Multiply>(data, divide_constant);
|
||||
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertDivideWithConstantNegative) {
|
||||
{
|
||||
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
|
||||
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
|
||||
auto divide = std::make_shared<ngraph::opset1::Divide>(data1, data2);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data1, data2});
|
||||
manager.register_pass<ngraph::pass::ConvertDivideWithConstant>();
|
||||
}
|
||||
|
||||
{
|
||||
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
|
||||
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{});
|
||||
auto divide = std::make_shared<ngraph::opset1::Divide>(data1, data2);
|
||||
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data1, data2});
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,88 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/common_optimizations/divide_fusion.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
TEST(TransformationTests, DivideFusion) {
|
||||
std::shared_ptr<ngraph::Function> f, f_ref;
|
||||
{
|
||||
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto pow_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {-1});
|
||||
auto pow = std::make_shared<ngraph::opset1::Power>(data2, pow_constant);
|
||||
auto mul = std::make_shared<ngraph::opset1::Multiply>(data1, pow);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{data1, data2});
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ngraph::pass::DivideFusion>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto divide = std::make_shared<ngraph::opset1::Divide>(data1, data2);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data1, data2});
|
||||
}
|
||||
|
||||
const auto res = FunctionsComparator::with_default()
|
||||
.enable(FunctionsComparator::CONST_VALUES)
|
||||
.enable(FunctionsComparator::ATTRIBUTES)
|
||||
.compare(f, f_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, DivideFusionNegative) {
|
||||
std::shared_ptr<ngraph::Function> f, f_ref;
|
||||
{
|
||||
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto pow_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {-1.01});
|
||||
auto pow = std::make_shared<ngraph::opset1::Power>(data2, pow_constant);
|
||||
auto mul = std::make_shared<ngraph::opset1::Multiply>(data1, pow);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{data1, data2});
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ngraph::pass::DivideFusion>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto pow_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {-1.01});
|
||||
auto pow = std::make_shared<ngraph::opset1::Power>(data2, pow_constant);
|
||||
auto mul = std::make_shared<ngraph::opset1::Multiply>(data1, pow);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{data1, data2});
|
||||
}
|
||||
|
||||
const auto res = FunctionsComparator::with_default()
|
||||
.enable(FunctionsComparator::CONST_VALUES)
|
||||
.enable(FunctionsComparator::ATTRIBUTES)
|
||||
.compare(f, f_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
@ -0,0 +1,121 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/common_optimizations/divide_fusion.hpp>
|
||||
#include <transformations/common_optimizations/subtract_fusion.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
TEST(TransformationTests, SubtractFusionMultiply) {
|
||||
std::shared_ptr<ngraph::Function> f, f_ref;
|
||||
{
|
||||
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto mul_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {-1});
|
||||
auto mul = std::make_shared<ngraph::opset1::Multiply>(data2, mul_constant);
|
||||
auto add = std::make_shared<ngraph::opset1::Add>(data1, mul);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{data1, data2});
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ngraph::pass::SubtractFusion>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto divide = std::make_shared<ngraph::opset1::Subtract>(data1, data2);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data1, data2});
|
||||
}
|
||||
|
||||
const auto res = FunctionsComparator::with_default()
|
||||
.enable(FunctionsComparator::CONST_VALUES)
|
||||
.enable(FunctionsComparator::ATTRIBUTES)
|
||||
.compare(f, f_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, SubtractFusionMultiplyNegative) {
|
||||
std::shared_ptr<ngraph::Function> f, f_ref;
|
||||
{
|
||||
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto mul_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {-1.01});
|
||||
auto mul = std::make_shared<ngraph::opset1::Multiply>(data2, mul_constant);
|
||||
auto add = std::make_shared<ngraph::opset1::Add>(data1, mul);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{data1, data2});
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ngraph::pass::SubtractFusion>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto mul_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {-1.01});
|
||||
auto mul = std::make_shared<ngraph::opset1::Multiply>(data2, mul_constant);
|
||||
auto add = std::make_shared<ngraph::opset1::Add>(data1, mul);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{data1, data2});
|
||||
}
|
||||
|
||||
const auto res = FunctionsComparator::with_default()
|
||||
.enable(FunctionsComparator::CONST_VALUES)
|
||||
.enable(FunctionsComparator::ATTRIBUTES)
|
||||
.compare(f, f_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, SubtractFusionNeg) {
|
||||
std::shared_ptr<ngraph::Function> f, f_ref;
|
||||
{
|
||||
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto neg = std::make_shared<ngraph::opset1::Negative>(data2);
|
||||
auto add = std::make_shared<ngraph::opset1::Add>(neg, data1);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{data1, data2});
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ngraph::pass::SubtractFusion>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto data1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto data2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto divide = std::make_shared<ngraph::opset1::Subtract>(data1, data2);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data1, data2});
|
||||
}
|
||||
|
||||
const auto res = FunctionsComparator::with_default()
|
||||
.enable(FunctionsComparator::CONST_VALUES)
|
||||
.enable(FunctionsComparator::ATTRIBUTES)
|
||||
.compare(f, f_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
@ -22,7 +22,6 @@ extensions/back/CutMemory.py
|
||||
extensions/back/EnableConstantStridedSlice.py
|
||||
extensions/back/FakeOutputResolver.py
|
||||
extensions/back/ForceStrictPrecision.py
|
||||
extensions/back/fuse_sub_div_min.py
|
||||
extensions/back/FuseTransposesSequence.py
|
||||
extensions/back/GatherNormalizer.py
|
||||
extensions/back/insert_compatibility_l2normalization.py
|
||||
|
@ -1,122 +0,0 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.elementwise import Sub, Div, Negative
|
||||
from mo.back.replacement import BackReplacementPattern
|
||||
from mo.graph.graph import Node, Graph
|
||||
|
||||
|
||||
class Negate(BackReplacementPattern):
|
||||
enabled = True
|
||||
force_clean_up = True
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('mul', {'type': 'Multiply'})
|
||||
],
|
||||
edges=[],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: [str, Node]):
|
||||
mul = match['mul']
|
||||
name = mul.soft_get('name', mul.id)
|
||||
|
||||
mul_POS_port = None
|
||||
if mul.in_port(0).data.get_value() is not None and np.all(mul.in_port(0).data.get_value() == -1):
|
||||
mul_POS_port = mul.in_port(1)
|
||||
if mul.in_port(1).data.get_value() is not None and np.all(mul.in_port(1).data.get_value() == -1):
|
||||
mul_POS_port = mul.in_port(0)
|
||||
|
||||
if mul_POS_port is None:
|
||||
return
|
||||
|
||||
negative = Negative(graph, {'name': name + '/Negate'}).create_node()
|
||||
|
||||
mul.out_port(0).get_connection().set_source(negative.out_port(0))
|
||||
mul_POS_port.get_connection().set_destination(negative.in_port(0))
|
||||
|
||||
|
||||
class EltwisesToSubtract(BackReplacementPattern):
|
||||
enabled = True
|
||||
force_clean_up = True
|
||||
|
||||
def run_after(self):
|
||||
return [Negate]
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('neg', {'type': 'Negative'}),
|
||||
('neg_d', {}),
|
||||
('add', {'type': 'Add'})
|
||||
],
|
||||
edges=[
|
||||
('neg', 'neg_d'),
|
||||
('neg_d', 'add'),
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: [str, Node]):
|
||||
neg = match['neg']
|
||||
add = match['add']
|
||||
|
||||
name = add.soft_get('name', add.id)
|
||||
|
||||
minuend_port = add.in_port(0).get_source() \
|
||||
if add.in_port(1).get_source().node.id == neg.id else add.in_port(1).get_source()
|
||||
subtrahned_port = neg.in_port(0).get_source()
|
||||
|
||||
sub = Sub(graph, {'name': name + '/sub'}).create_node()
|
||||
|
||||
add.out_port(0).get_connection().set_source(sub.out_port(0))
|
||||
minuend_port.connect(sub.in_port(0))
|
||||
subtrahned_port.connect(sub.in_port(1))
|
||||
|
||||
|
||||
class EltwisesToDiv(BackReplacementPattern):
|
||||
enabled = True
|
||||
force_clean_up = True
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('const', {'type': 'Const'}),
|
||||
('const_d', {'value': lambda val: val is not None and np.all(val == -1)}),
|
||||
('inv', {'type': 'Pow'}),
|
||||
('inv_d', {}),
|
||||
('mul', {'type': 'Multiply'})
|
||||
],
|
||||
edges=[
|
||||
('const', 'const_d'),
|
||||
('const_d', 'inv', {'in': 1}),
|
||||
('inv', 'inv_d'),
|
||||
('inv_d', 'mul'),
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: [str, Node]):
|
||||
pow = match['inv']
|
||||
mul = match['mul']
|
||||
const = match['const']
|
||||
|
||||
name = mul.soft_get('name', mul.id)
|
||||
|
||||
devidend_port = mul.in_port(0).get_source() if mul.in_port(1).get_source().node.id == pow.id else mul.in_port(
|
||||
1).get_source()
|
||||
divider_port = pow.in_port(0).get_source() if pow.in_port(1).get_source().node.id == const.id else pow.in_port(
|
||||
1).get_source()
|
||||
|
||||
div = Div(graph, {'name': name + '/div'}).create_node()
|
||||
|
||||
mul.out_port(0).get_connection().set_source(div.out_port(0))
|
||||
devidend_port.connect(div.in_port(0))
|
||||
divider_port.connect(div.in_port(1))
|
Loading…
Reference in New Issue
Block a user