Enable ConstantFolding inside MO Backend (#6916)

This commit is contained in:
Gleb Kazantaev 2021-08-07 12:40:46 +03:00 committed by GitHub
parent 5d781afe73
commit b11a2220b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 254 additions and 29 deletions

View File

@ -0,0 +1,24 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
// Forward declaration; the class body follows below under its fully qualified name.
class DisableShapeOfConstantFolding;
} // namespace pass
} // namespace ngraph
/**
 * @brief Matcher pass that marks ShapeOf operations so that ConstantFolding leaves them in the graph.
 *
 * The implementation matches opset2/opset3 ShapeOf nodes whose output shape is dynamic or whose
 * output holds a number of elements other than one, and calls disable_constant_folding() on each
 * matched node. A ShapeOf whose static output has exactly one element is not matched and therefore
 * stays foldable.
 */
class ngraph::pass::DisableShapeOfConstantFolding: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
// Builds the ShapeOf pattern and registers the matcher together with its callback.
DisableShapeOfConstantFolding();
};

View File

@ -0,0 +1,32 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <memory>
#include <ostream>
#include <fstream>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/variant.hpp>
#include <transformations/rt_info/disable_constant_folding.hpp>
#include "disable_shapeof_constant_folding.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::DisableShapeOfConstantFolding, "DisableShapeOfConstantFolding", 0);

/// Registers a matcher that tags selected ShapeOf nodes with the DisableConstantFolding
/// runtime attribute so that a subsequent ConstantFolding pass does not replace them.
ngraph::pass::DisableShapeOfConstantFolding::DisableShapeOfConstantFolding() {
    // Accept a ShapeOf output when its shape is dynamic, or when it is static but
    // does not hold exactly one element.
    const auto must_stay_in_graph = [](const Output<Node>& candidate) -> bool {
        const auto& out_pshape = candidate.get_partial_shape();
        if (out_pshape.is_dynamic()) {
            return true;
        }
        return shape_size(out_pshape.get_shape()) != 1;
    };
    auto shape_of_pattern = pattern::wrap_type<opset2::ShapeOf, opset3::ShapeOf>(must_stay_in_graph);

    // Every match gets the attribute attached to its root node.
    ngraph::matcher_pass_callback on_match = [](pattern::Matcher& matcher) {
        disable_constant_folding(matcher.get_match_root());
        return true;
    };

    auto matcher = std::make_shared<ngraph::pattern::Matcher>(shape_of_pattern, "DisableShapeOfConstantFolding");
    this->register_matcher(matcher, on_match);
}

View File

@ -5,8 +5,10 @@
#include <memory> #include <memory>
#include "moc_transformations.hpp" #include "moc_transformations.hpp"
#include "disable_shapeof_constant_folding.hpp"
#include <ngraph/pass/manager.hpp> #include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <transformations/init_node_info.hpp> #include <transformations/init_node_info.hpp>
#include <transformations/common_optimizations/gelu_fusion.hpp> #include <transformations/common_optimizations/gelu_fusion.hpp>
#include <transformations/common_optimizations/softplus_fusion.hpp> #include <transformations/common_optimizations/softplus_fusion.hpp>
@ -32,6 +34,7 @@
#include <transformations/common_optimizations/lin_op_sequence_fusion.hpp> #include <transformations/common_optimizations/lin_op_sequence_fusion.hpp>
#include <transformations/common_optimizations/conv_mul_fusion.hpp> #include <transformations/common_optimizations/conv_mul_fusion.hpp>
#include <transformations/common_optimizations/nop_elimination.hpp> #include <transformations/common_optimizations/nop_elimination.hpp>
#include <transformations/low_precision/disable_convert_constant_folding_on_const_path.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::MOCTransformations, "MOCTransformations", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::MOCTransformations, "MOCTransformations", 0);
@ -48,6 +51,10 @@ bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr<ngraph::F
ngraph::pass::Manager manager(get_pass_config()); ngraph::pass::Manager manager(get_pass_config());
manager.register_pass<ngraph::pass::InitNodeInfo>(); manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::DisableConvertConstantFoldingOnConstPath>(
element::TypeVector{ ngraph::element::i8, ngraph::element::u8, ngraph::element::i4, ngraph::element::u4 });
manager.register_pass<ngraph::pass::DisableShapeOfConstantFolding>();
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>(); manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>();
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>(); manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
manager.register_pass<ngraph::pass::SimplifyShapeOfSubGraph>(); manager.register_pass<ngraph::pass::SimplifyShapeOfSubGraph>();

View File

@ -22,5 +22,5 @@ class ngraph::pass::DisableConvertConstantFoldingOnConstPath : public ngraph::pa
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
DisableConvertConstantFoldingOnConstPath( DisableConvertConstantFoldingOnConstPath(
const std::vector<ngraph::element::Type>& inputPrecisions = {}); const element::TypeVector & inputPrecisions = {});
}; };

View File

@ -0,0 +1,44 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <assert.h>
#include <functional>
#include <memory>
#include <string>
#include <set>
#include <ngraph/node.hpp>
#include <ngraph/variant.hpp>
#include <transformations_visibility.hpp>
namespace ngraph {
/**
* @ingroup ie_runtime_attr_api
* @brief DisableConstantFolding is a runtime attribute that disables ConstantFolding for the given operation.
*
* The class is an empty marker: it stores no data, so only its presence in a node's
* runtime info carries meaning.
*/
class TRANSFORMATIONS_API DisableConstantFolding {
public:
DisableConstantFolding() = default;
};
// Explicit instantiation declaration: the instantiation itself is expected to be
// provided once by a translation unit of the transformations library.
extern template class TRANSFORMATIONS_API VariantImpl<DisableConstantFolding>;
/**
 * @brief Variant wrapper that lets DisableConstantFolding be stored in a node's runtime info.
 */
template<>
class TRANSFORMATIONS_API VariantWrapper<DisableConstantFolding> : public VariantImpl<DisableConstantFolding> {
public:
// The name doubles as the runtime-info key under which the attribute is stored.
static constexpr VariantTypeInfo type_info{"DISABLED_CONSTANT_FOLDING", 0};
const VariantTypeInfo &get_type_info() const override {
return type_info;
}
VariantWrapper(const value_type &value) : VariantImpl<value_type>(value) {}
// Reports the attribute as non-copyable, so runtime-info helpers that consult
// is_copyable() do not transfer it from one node to another.
bool is_copyable() const override { return false; }
};
/**
 * @brief Attaches the DisableConstantFolding attribute to the runtime info of the given operation.
 * @param node Operation to mark; it is dereferenced, so it must not be null.
 */
TRANSFORMATIONS_API void disable_constant_folding(const std::shared_ptr<Node>& node);
} // namespace ngraph

View File

@ -20,7 +20,7 @@ using namespace ngraph;
NGRAPH_RTTI_DEFINITION(ngraph::pass::DisableConvertConstantFoldingOnConstPath, "DisableConvertConstantFoldingOnConstPath", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::DisableConvertConstantFoldingOnConstPath, "DisableConvertConstantFoldingOnConstPath", 0);
ngraph::pass::DisableConvertConstantFoldingOnConstPath::DisableConvertConstantFoldingOnConstPath( ngraph::pass::DisableConvertConstantFoldingOnConstPath::DisableConvertConstantFoldingOnConstPath(
const std::vector<ngraph::element::Type>& inputPrecisions) { const element::TypeVector & inputPrecisions) {
auto matcherData = ngraph::pattern::any_input(); auto matcherData = ngraph::pattern::any_input();
auto matcherConvert = ngraph::pattern::wrap_type<opset3::Convert>({ matcherData }, pattern::consumers_count(1)); auto matcherConvert = ngraph::pattern::wrap_type<opset3::Convert>({ matcherData }, pattern::consumers_count(1));

View File

@ -0,0 +1,14 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/rt_info/disable_constant_folding.hpp"
// Explicit instantiation definition matching the `extern template` declaration in the header.
template class ngraph::VariantImpl<ngraph::DisableConstantFolding>;
// Out-of-class definition of the static constexpr member declared in the header
// (required for ODR-use when building with a standard older than C++17).
constexpr ngraph::VariantTypeInfo ngraph::VariantWrapper<ngraph::DisableConstantFolding>::type_info;
/// Stores a DisableConstantFolding attribute in the runtime info of `node`, keyed by the
/// attribute's type name. An entry already present under that key is overwritten.
void ngraph::disable_constant_folding(const std::shared_ptr<Node>& node) {
    auto& node_attrs = node->get_rt_info();
    const char* const attr_key = VariantWrapper<DisableConstantFolding>::type_info.name;
    node_attrs[attr_key] = make_variant<DisableConstantFolding>({});
}

View File

@ -0,0 +1,78 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <queue>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <disable_shapeof_constant_folding.hpp>
#include <transformations/serialize.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ngraph;
// A ShapeOf applied to a static 4D input must survive ConstantFolding once
// DisableShapeOfConstantFolding has run, so the function is expected to stay unchanged.
TEST(TransformationTests, DisableShapeOfConstantFolding) {
    // Parameter -> ShapeOf -> Abs -> Reshape(Parameter, Abs)
    const auto build_graph = []() -> std::shared_ptr<Function> {
        auto input = std::make_shared<opset6::Parameter>(element::f32, Shape{1, 4, 10, 10});
        auto input_shape = std::make_shared<opset6::ShapeOf>(input);
        auto abs_shape = std::make_shared<opset6::Abs>(input_shape);
        auto reshaped = std::make_shared<opset6::Reshape>(input, abs_shape, false);
        return std::make_shared<Function>(NodeVector{reshaped}, ParameterVector{input});
    };

    // Transformed function: built first, then run through the passes.
    std::shared_ptr<Function> f = build_graph();
    {
        pass::Manager manager;
        manager.register_pass<pass::DisableShapeOfConstantFolding>();
        manager.register_pass<pass::ConstantFolding>();
        manager.run_passes(f);
    }

    // Reference function: the same graph, left untouched.
    std::shared_ptr<Function> f_ref = build_graph();

    const auto comparison = compare_functions(f, f_ref);
    ASSERT_TRUE(comparison.first) << comparison.second;
}
// ShapeOf(ShapeOf(x)) yields a single element (the rank), so it is not protected by
// DisableShapeOfConstantFolding and is expected to fold into a constant, while the
// inner ShapeOf feeding Reshape remains in the graph.
TEST(TransformationTests, ShapeOfShapeOfConstantFolding) {
    std::shared_ptr<Function> f;
    std::shared_ptr<Function> f_ref;

    // Function under test.
    {
        auto input = std::make_shared<opset6::Parameter>(element::i64, Shape{1, 4, 10, 10});
        auto input_shape = std::make_shared<opset6::ShapeOf>(input);
        auto reshaped = std::make_shared<opset6::Reshape>(input, input_shape, false);
        auto input_rank = std::make_shared<opset6::ShapeOf>(input_shape);
        auto product = std::make_shared<opset6::Multiply>(reshaped, input_rank);
        f = std::make_shared<Function>(NodeVector{product}, ParameterVector{input});

        pass::Manager manager;
        manager.register_pass<pass::DisableShapeOfConstantFolding>();
        manager.register_pass<pass::ConstantFolding>();
        manager.run_passes(f);
    }

    // Reference: the rank computation is replaced by the constant {4}.
    {
        auto input = std::make_shared<opset6::Parameter>(element::i64, Shape{1, 4, 10, 10});
        auto input_shape = std::make_shared<opset6::ShapeOf>(input);
        auto reshaped = std::make_shared<opset6::Reshape>(input, input_shape, false);
        auto product = std::make_shared<opset6::Multiply>(reshaped, opset6::Constant::create(element::i64, Shape{1}, {4}));
        f_ref = std::make_shared<Function>(NodeVector{product}, ParameterVector{input});
    }

    const auto comparison = compare_functions(f, f_ref);
    ASSERT_TRUE(comparison.first) << comparison.second;
}

View File

@ -329,6 +329,8 @@ namespace ngraph
/// that all the HostTensorPtrs are not equal to nullptr /// that all the HostTensorPtrs are not equal to nullptr
NGRAPH_API bool validate_host_tensor_vector(const HostTensorVector& v, const size_t& size); NGRAPH_API bool validate_host_tensor_vector(const HostTensorVector& v, const size_t& size);
NGRAPH_API bool could_propagate(const Output<Node>& output, std::vector<Node*>& order);
namespace opset1 namespace opset1
{ {
/// ///

View File

@ -24,6 +24,7 @@ namespace ngraph
virtual std::shared_ptr<ngraph::Variant> init(const std::shared_ptr<ngraph::Node>& node); virtual std::shared_ptr<ngraph::Variant> init(const std::shared_ptr<ngraph::Node>& node);
virtual std::shared_ptr<ngraph::Variant> merge(const ngraph::NodeVector& nodes); virtual std::shared_ptr<ngraph::Variant> merge(const ngraph::NodeVector& nodes);
virtual bool is_copyable() const;
virtual std::string to_string() { return ""; } virtual std::string to_string() { return ""; }
}; };

View File

@ -6,6 +6,7 @@
#include <ngraph/op/constant.hpp> #include <ngraph/op/constant.hpp>
#include "ngraph/op/util/sub_graph_base.hpp" #include "ngraph/op/util/sub_graph_base.hpp"
#include "ngraph/rt_info.hpp" #include "ngraph/rt_info.hpp"
#include "ngraph/validation_util.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
@ -101,7 +102,23 @@ bool ngraph::pass::ConstantFolding::pre_calculated_values_folding(
for (auto& input_value : curr_node->input_values()) for (auto& input_value : curr_node->input_values())
{ {
if (input_value.get_tensor().has_and_set_bound()) // Check that ConstantFolding is not disabled on this path
std::vector<Node*> order;
auto status = could_propagate(input_value, order);
if (status)
{
for (const auto& node : order)
{
const auto& rt_info = node->get_rt_info();
if (rt_info.count("DISABLED_CONSTANT_FOLDING"))
{
status = false;
break;
}
}
}
if (status && input_value.get_tensor().has_and_set_bound())
{ {
auto input_node = input_value.get_node_shared_ptr(); auto input_node = input_value.get_node_shared_ptr();
auto replacement = auto replacement =

View File

@ -8,46 +8,47 @@
ngraph::Node::RTMap mergeRuntimeInfo(const ngraph::NodeVector& nodes) ngraph::Node::RTMap mergeRuntimeInfo(const ngraph::NodeVector& nodes)
{ {
ngraph::Node::RTMap mergedInfo; std::unordered_map<std::string, std::vector<std::shared_ptr<ngraph::Variant>>> attrs;
for (auto& node : nodes) for (const auto& node : nodes)
{ {
for (auto& item : node->get_rt_info()) for (const auto& item : node->get_rt_info())
{ {
mergedInfo[item.first] = item.second; if (item.second->is_copyable())
{
attrs[item.first].push_back(item.second);
}
} }
} }
ngraph::Node::RTMap newInfo; ngraph::Node::RTMap merged_attrs;
for (auto& item : mergedInfo) for (auto& item : attrs)
{ {
size_t attributes_count = 0; auto attr = *item.second.begin();
for (auto& node : nodes) if (item.second.size() == 1)
{ {
const auto& rt_info = node->get_rt_info(); merged_attrs[item.first] = attr;
if (rt_info.count(item.first)) }
else if (auto merge_attr = attr->merge(nodes))
{ {
attributes_count++; merged_attrs[item.first] = merge_attr;
} }
} }
if (attributes_count == 1) return merged_attrs;
{
newInfo[item.first] = item.second;
}
else if (auto merge_attr = item.second->merge(nodes))
{
newInfo[item.first] = merge_attr;
}
}
return newInfo;
} }
void ngraph::copy_runtime_info(std::shared_ptr<ngraph::Node> from, std::shared_ptr<ngraph::Node> to) void ngraph::copy_runtime_info(std::shared_ptr<ngraph::Node> from, std::shared_ptr<ngraph::Node> to)
{ {
auto& rtInfoFrom = from->get_rt_info(); auto& attrs = to->get_rt_info();
auto& rtInfoTo = to->get_rt_info(); attrs.clear();
rtInfoTo = rtInfoFrom;
for (const auto& item : from->get_rt_info())
{
if (item.second->is_copyable())
{
attrs[item.first] = item.second;
}
}
} }
void ngraph::copy_runtime_info(std::shared_ptr<ngraph::Node> from, ngraph::NodeVector to) void ngraph::copy_runtime_info(std::shared_ptr<ngraph::Node> from, ngraph::NodeVector to)

View File

@ -1306,7 +1306,7 @@ void ngraph::evaluate_nodes(std::map<RawNodeOutput, HostTensorPtr>& value_map,
} }
} }
bool could_propagate(const Output<Node>& output, std::vector<Node*>& order) bool ngraph::could_propagate(const Output<Node>& output, std::vector<Node*>& order)
{ {
bool status = true; bool status = true;
@ -1367,7 +1367,7 @@ void propagate_rt_info(Node* node, const Output<Node>& final_port)
auto& rt_info = consumer->get_rt_info(); auto& rt_info = consumer->get_rt_info();
for (const auto& it : orig_rt_info) for (const auto& it : orig_rt_info)
{ {
if (rt_info.find(it.first) == rt_info.end()) if (rt_info.find(it.first) == rt_info.end() && it.second->is_copyable())
rt_info[it.first] = it.second; rt_info[it.first] = it.second;
} }
} }

View File

@ -22,5 +22,10 @@ std::shared_ptr<ngraph::Variant> Variant::merge(const ngraph::NodeVector& nodes)
return nullptr; return nullptr;
} }
bool Variant::is_copyable() const
{
return true;
}
template class ngraph::VariantImpl<std::string>; template class ngraph::VariantImpl<std::string>;
template class ngraph::VariantImpl<int64_t>; template class ngraph::VariantImpl<int64_t>;