Enable ConstantFolding inside MO Backend (#6916)
This commit is contained in:
parent
5d781afe73
commit
b11a2220b0
@ -0,0 +1,24 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include <ngraph/pass/graph_rewrite.hpp>
|
||||||
|
|
||||||
|
namespace ngraph {
|
||||||
|
namespace pass {
|
||||||
|
|
||||||
|
class DisableShapeOfConstantFolding;
|
||||||
|
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ngraph
|
||||||
|
|
||||||
|
|
||||||
|
class ngraph::pass::DisableShapeOfConstantFolding: public ngraph::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
DisableShapeOfConstantFolding();
|
||||||
|
};
|
@ -0,0 +1,32 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <ostream>
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||||
|
#include <ngraph/opsets/opset2.hpp>
|
||||||
|
#include <ngraph/opsets/opset3.hpp>
|
||||||
|
#include <ngraph/variant.hpp>
|
||||||
|
#include <transformations/rt_info/disable_constant_folding.hpp>
|
||||||
|
|
||||||
|
#include "disable_shapeof_constant_folding.hpp"
|
||||||
|
|
||||||
|
NGRAPH_RTTI_DEFINITION(ngraph::pass::DisableShapeOfConstantFolding, "DisableShapeOfConstantFolding", 0);
|
||||||
|
|
||||||
|
ngraph::pass::DisableShapeOfConstantFolding::DisableShapeOfConstantFolding() {
|
||||||
|
auto shape_of = pattern::wrap_type<opset2::ShapeOf, opset3::ShapeOf>([=](const Output<Node> & output) {
|
||||||
|
const auto & shape = output.get_partial_shape();
|
||||||
|
return shape.is_dynamic() || shape_size(shape.get_shape()) != 1;
|
||||||
|
});
|
||||||
|
|
||||||
|
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||||
|
disable_constant_folding(m.get_match_root());
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<ngraph::pattern::Matcher>(shape_of, "DisableShapeOfConstantFolding");
|
||||||
|
this->register_matcher(m, callback);
|
||||||
|
}
|
@ -5,8 +5,10 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "moc_transformations.hpp"
|
#include "moc_transformations.hpp"
|
||||||
|
#include "disable_shapeof_constant_folding.hpp"
|
||||||
|
|
||||||
#include <ngraph/pass/manager.hpp>
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
#include <ngraph/pass/constant_folding.hpp>
|
||||||
#include <transformations/init_node_info.hpp>
|
#include <transformations/init_node_info.hpp>
|
||||||
#include <transformations/common_optimizations/gelu_fusion.hpp>
|
#include <transformations/common_optimizations/gelu_fusion.hpp>
|
||||||
#include <transformations/common_optimizations/softplus_fusion.hpp>
|
#include <transformations/common_optimizations/softplus_fusion.hpp>
|
||||||
@ -32,6 +34,7 @@
|
|||||||
#include <transformations/common_optimizations/lin_op_sequence_fusion.hpp>
|
#include <transformations/common_optimizations/lin_op_sequence_fusion.hpp>
|
||||||
#include <transformations/common_optimizations/conv_mul_fusion.hpp>
|
#include <transformations/common_optimizations/conv_mul_fusion.hpp>
|
||||||
#include <transformations/common_optimizations/nop_elimination.hpp>
|
#include <transformations/common_optimizations/nop_elimination.hpp>
|
||||||
|
#include <transformations/low_precision/disable_convert_constant_folding_on_const_path.hpp>
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::MOCTransformations, "MOCTransformations", 0);
|
NGRAPH_RTTI_DEFINITION(ngraph::pass::MOCTransformations, "MOCTransformations", 0);
|
||||||
|
|
||||||
@ -48,6 +51,10 @@ bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr<ngraph::F
|
|||||||
ngraph::pass::Manager manager(get_pass_config());
|
ngraph::pass::Manager manager(get_pass_config());
|
||||||
|
|
||||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<ngraph::pass::DisableConvertConstantFoldingOnConstPath>(
|
||||||
|
element::TypeVector{ ngraph::element::i8, ngraph::element::u8, ngraph::element::i4, ngraph::element::u4 });
|
||||||
|
manager.register_pass<ngraph::pass::DisableShapeOfConstantFolding>();
|
||||||
|
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||||
manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>();
|
manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>();
|
||||||
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
|
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
|
||||||
manager.register_pass<ngraph::pass::SimplifyShapeOfSubGraph>();
|
manager.register_pass<ngraph::pass::SimplifyShapeOfSubGraph>();
|
||||||
|
@ -22,5 +22,5 @@ class ngraph::pass::DisableConvertConstantFoldingOnConstPath : public ngraph::pa
|
|||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
DisableConvertConstantFoldingOnConstPath(
|
DisableConvertConstantFoldingOnConstPath(
|
||||||
const std::vector<ngraph::element::Type>& inputPrecisions = {});
|
const element::TypeVector & inputPrecisions = {});
|
||||||
};
|
};
|
||||||
|
@ -0,0 +1,44 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include <functional>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
|
#include <ngraph/node.hpp>
|
||||||
|
#include <ngraph/variant.hpp>
|
||||||
|
#include <transformations_visibility.hpp>
|
||||||
|
|
||||||
|
|
||||||
|
namespace ngraph {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @ingroup ie_runtime_attr_api
|
||||||
|
* @brief DisableConstantFolding disable ConstantFolding for given operation
|
||||||
|
*/
|
||||||
|
class TRANSFORMATIONS_API DisableConstantFolding {
|
||||||
|
public:
|
||||||
|
DisableConstantFolding() = default;
|
||||||
|
};
|
||||||
|
|
||||||
|
extern template class TRANSFORMATIONS_API VariantImpl<DisableConstantFolding>;
|
||||||
|
|
||||||
|
template<>
|
||||||
|
class TRANSFORMATIONS_API VariantWrapper<DisableConstantFolding> : public VariantImpl<DisableConstantFolding> {
|
||||||
|
public:
|
||||||
|
static constexpr VariantTypeInfo type_info{"DISABLED_CONSTANT_FOLDING", 0};
|
||||||
|
|
||||||
|
const VariantTypeInfo &get_type_info() const override {
|
||||||
|
return type_info;
|
||||||
|
}
|
||||||
|
|
||||||
|
VariantWrapper(const value_type &value) : VariantImpl<value_type>(value) {}
|
||||||
|
|
||||||
|
bool is_copyable() const override { return false; }
|
||||||
|
};
|
||||||
|
|
||||||
|
TRANSFORMATIONS_API void disable_constant_folding(const std::shared_ptr<Node>& node);
|
||||||
|
} // namespace ngraph
|
@ -20,7 +20,7 @@ using namespace ngraph;
|
|||||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::DisableConvertConstantFoldingOnConstPath, "DisableConvertConstantFoldingOnConstPath", 0);
|
NGRAPH_RTTI_DEFINITION(ngraph::pass::DisableConvertConstantFoldingOnConstPath, "DisableConvertConstantFoldingOnConstPath", 0);
|
||||||
|
|
||||||
ngraph::pass::DisableConvertConstantFoldingOnConstPath::DisableConvertConstantFoldingOnConstPath(
|
ngraph::pass::DisableConvertConstantFoldingOnConstPath::DisableConvertConstantFoldingOnConstPath(
|
||||||
const std::vector<ngraph::element::Type>& inputPrecisions) {
|
const element::TypeVector & inputPrecisions) {
|
||||||
auto matcherData = ngraph::pattern::any_input();
|
auto matcherData = ngraph::pattern::any_input();
|
||||||
auto matcherConvert = ngraph::pattern::wrap_type<opset3::Convert>({ matcherData }, pattern::consumers_count(1));
|
auto matcherConvert = ngraph::pattern::wrap_type<opset3::Convert>({ matcherData }, pattern::consumers_count(1));
|
||||||
|
|
||||||
|
@ -0,0 +1,14 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "transformations/rt_info/disable_constant_folding.hpp"
|
||||||
|
|
||||||
|
template class ngraph::VariantImpl<ngraph::DisableConstantFolding>;
|
||||||
|
|
||||||
|
constexpr ngraph::VariantTypeInfo ngraph::VariantWrapper<ngraph::DisableConstantFolding>::type_info;
|
||||||
|
|
||||||
|
void ngraph::disable_constant_folding(const std::shared_ptr<Node>& node) {
|
||||||
|
auto & rt_info = node->get_rt_info();
|
||||||
|
rt_info[VariantWrapper<DisableConstantFolding>::type_info.name] = make_variant<DisableConstantFolding>({});
|
||||||
|
}
|
@ -0,0 +1,78 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
|
#include <ngraph/function.hpp>
|
||||||
|
#include <ngraph/opsets/opset6.hpp>
|
||||||
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
#include <ngraph/pass/constant_folding.hpp>
|
||||||
|
|
||||||
|
#include <disable_shapeof_constant_folding.hpp>
|
||||||
|
#include <transformations/serialize.hpp>
|
||||||
|
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
|
||||||
|
|
||||||
|
using namespace testing;
|
||||||
|
using namespace ngraph;
|
||||||
|
|
||||||
|
TEST(TransformationTests, DisableShapeOfConstantFolding) {
|
||||||
|
std::shared_ptr<Function> f, f_ref;
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset6::Parameter>(element::f32, Shape{1, 4, 10, 10});
|
||||||
|
auto shape_of = std::make_shared<opset6::ShapeOf>(data);
|
||||||
|
auto abs = std::make_shared<opset6::Abs>(shape_of);
|
||||||
|
auto reshape = std::make_shared<opset6::Reshape>(data, abs, false);
|
||||||
|
f = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||||
|
|
||||||
|
pass::Manager m;
|
||||||
|
m.register_pass<pass::DisableShapeOfConstantFolding>();
|
||||||
|
m.register_pass<pass::ConstantFolding>();
|
||||||
|
m.run_passes(f);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset6::Parameter>(element::f32, Shape{1, 4, 10, 10});
|
||||||
|
auto shape_of = std::make_shared<opset6::ShapeOf>(data);
|
||||||
|
auto abs = std::make_shared<opset6::Abs>(shape_of);
|
||||||
|
auto reshape = std::make_shared<opset6::Reshape>(data, abs, false);
|
||||||
|
f_ref = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, ShapeOfShapeOfConstantFolding) {
|
||||||
|
std::shared_ptr<Function> f, f_ref;
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset6::Parameter>(element::i64, Shape{1, 4, 10, 10});
|
||||||
|
auto shape_of = std::make_shared<opset6::ShapeOf>(data);
|
||||||
|
auto reshape = std::make_shared<opset6::Reshape>(data, shape_of, false);
|
||||||
|
auto rank = std::make_shared<opset6::ShapeOf>(shape_of);
|
||||||
|
auto mul = std::make_shared<opset6::Multiply>(reshape, rank);
|
||||||
|
f = std::make_shared<Function>(NodeVector{mul}, ParameterVector{data});
|
||||||
|
|
||||||
|
pass::Manager m;
|
||||||
|
m.register_pass<pass::DisableShapeOfConstantFolding>();
|
||||||
|
m.register_pass<pass::ConstantFolding>();
|
||||||
|
m.run_passes(f);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset6::Parameter>(element::i64, Shape{1, 4, 10, 10});
|
||||||
|
auto shape_of = std::make_shared<opset6::ShapeOf>(data);
|
||||||
|
auto reshape = std::make_shared<opset6::Reshape>(data, shape_of, false);
|
||||||
|
auto mul = std::make_shared<opset6::Multiply>(reshape, opset6::Constant::create(element::i64, Shape{1}, {4}));
|
||||||
|
f_ref = std::make_shared<Function>(NodeVector{mul}, ParameterVector{data});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
@ -329,6 +329,8 @@ namespace ngraph
|
|||||||
/// that all the HostTensorPtrs are not equal to nullptr
|
/// that all the HostTensorPtrs are not equal to nullptr
|
||||||
NGRAPH_API bool validate_host_tensor_vector(const HostTensorVector& v, const size_t& size);
|
NGRAPH_API bool validate_host_tensor_vector(const HostTensorVector& v, const size_t& size);
|
||||||
|
|
||||||
|
NGRAPH_API bool could_propagate(const Output<Node>& output, std::vector<Node*>& order);
|
||||||
|
|
||||||
namespace opset1
|
namespace opset1
|
||||||
{
|
{
|
||||||
///
|
///
|
||||||
|
@ -24,6 +24,7 @@ namespace ngraph
|
|||||||
|
|
||||||
virtual std::shared_ptr<ngraph::Variant> init(const std::shared_ptr<ngraph::Node>& node);
|
virtual std::shared_ptr<ngraph::Variant> init(const std::shared_ptr<ngraph::Node>& node);
|
||||||
virtual std::shared_ptr<ngraph::Variant> merge(const ngraph::NodeVector& nodes);
|
virtual std::shared_ptr<ngraph::Variant> merge(const ngraph::NodeVector& nodes);
|
||||||
|
virtual bool is_copyable() const;
|
||||||
virtual std::string to_string() { return ""; }
|
virtual std::string to_string() { return ""; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
#include <ngraph/op/constant.hpp>
|
#include <ngraph/op/constant.hpp>
|
||||||
#include "ngraph/op/util/sub_graph_base.hpp"
|
#include "ngraph/op/util/sub_graph_base.hpp"
|
||||||
#include "ngraph/rt_info.hpp"
|
#include "ngraph/rt_info.hpp"
|
||||||
|
#include "ngraph/validation_util.hpp"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
using namespace ngraph;
|
||||||
@ -101,7 +102,23 @@ bool ngraph::pass::ConstantFolding::pre_calculated_values_folding(
|
|||||||
|
|
||||||
for (auto& input_value : curr_node->input_values())
|
for (auto& input_value : curr_node->input_values())
|
||||||
{
|
{
|
||||||
if (input_value.get_tensor().has_and_set_bound())
|
// Check that ConstantFolding is not disabled on this path
|
||||||
|
std::vector<Node*> order;
|
||||||
|
auto status = could_propagate(input_value, order);
|
||||||
|
if (status)
|
||||||
|
{
|
||||||
|
for (const auto& node : order)
|
||||||
|
{
|
||||||
|
const auto& rt_info = node->get_rt_info();
|
||||||
|
if (rt_info.count("DISABLED_CONSTANT_FOLDING"))
|
||||||
|
{
|
||||||
|
status = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (status && input_value.get_tensor().has_and_set_bound())
|
||||||
{
|
{
|
||||||
auto input_node = input_value.get_node_shared_ptr();
|
auto input_node = input_value.get_node_shared_ptr();
|
||||||
auto replacement =
|
auto replacement =
|
||||||
|
@ -8,46 +8,47 @@
|
|||||||
|
|
||||||
ngraph::Node::RTMap mergeRuntimeInfo(const ngraph::NodeVector& nodes)
|
ngraph::Node::RTMap mergeRuntimeInfo(const ngraph::NodeVector& nodes)
|
||||||
{
|
{
|
||||||
ngraph::Node::RTMap mergedInfo;
|
std::unordered_map<std::string, std::vector<std::shared_ptr<ngraph::Variant>>> attrs;
|
||||||
for (auto& node : nodes)
|
for (const auto& node : nodes)
|
||||||
{
|
{
|
||||||
for (auto& item : node->get_rt_info())
|
for (const auto& item : node->get_rt_info())
|
||||||
{
|
{
|
||||||
mergedInfo[item.first] = item.second;
|
if (item.second->is_copyable())
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ngraph::Node::RTMap newInfo;
|
|
||||||
for (auto& item : mergedInfo)
|
|
||||||
{
|
|
||||||
size_t attributes_count = 0;
|
|
||||||
for (auto& node : nodes)
|
|
||||||
{
|
|
||||||
const auto& rt_info = node->get_rt_info();
|
|
||||||
if (rt_info.count(item.first))
|
|
||||||
{
|
{
|
||||||
attributes_count++;
|
attrs[item.first].push_back(item.second);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (attributes_count == 1)
|
ngraph::Node::RTMap merged_attrs;
|
||||||
|
for (auto& item : attrs)
|
||||||
|
{
|
||||||
|
auto attr = *item.second.begin();
|
||||||
|
if (item.second.size() == 1)
|
||||||
{
|
{
|
||||||
newInfo[item.first] = item.second;
|
merged_attrs[item.first] = attr;
|
||||||
}
|
}
|
||||||
else if (auto merge_attr = item.second->merge(nodes))
|
else if (auto merge_attr = attr->merge(nodes))
|
||||||
{
|
{
|
||||||
newInfo[item.first] = merge_attr;
|
merged_attrs[item.first] = merge_attr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return newInfo;
|
return merged_attrs;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ngraph::copy_runtime_info(std::shared_ptr<ngraph::Node> from, std::shared_ptr<ngraph::Node> to)
|
void ngraph::copy_runtime_info(std::shared_ptr<ngraph::Node> from, std::shared_ptr<ngraph::Node> to)
|
||||||
{
|
{
|
||||||
auto& rtInfoFrom = from->get_rt_info();
|
auto& attrs = to->get_rt_info();
|
||||||
auto& rtInfoTo = to->get_rt_info();
|
attrs.clear();
|
||||||
rtInfoTo = rtInfoFrom;
|
|
||||||
|
for (const auto& item : from->get_rt_info())
|
||||||
|
{
|
||||||
|
if (item.second->is_copyable())
|
||||||
|
{
|
||||||
|
attrs[item.first] = item.second;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ngraph::copy_runtime_info(std::shared_ptr<ngraph::Node> from, ngraph::NodeVector to)
|
void ngraph::copy_runtime_info(std::shared_ptr<ngraph::Node> from, ngraph::NodeVector to)
|
||||||
|
@ -1306,7 +1306,7 @@ void ngraph::evaluate_nodes(std::map<RawNodeOutput, HostTensorPtr>& value_map,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool could_propagate(const Output<Node>& output, std::vector<Node*>& order)
|
bool ngraph::could_propagate(const Output<Node>& output, std::vector<Node*>& order)
|
||||||
{
|
{
|
||||||
bool status = true;
|
bool status = true;
|
||||||
|
|
||||||
@ -1367,7 +1367,7 @@ void propagate_rt_info(Node* node, const Output<Node>& final_port)
|
|||||||
auto& rt_info = consumer->get_rt_info();
|
auto& rt_info = consumer->get_rt_info();
|
||||||
for (const auto& it : orig_rt_info)
|
for (const auto& it : orig_rt_info)
|
||||||
{
|
{
|
||||||
if (rt_info.find(it.first) == rt_info.end())
|
if (rt_info.find(it.first) == rt_info.end() && it.second->is_copyable())
|
||||||
rt_info[it.first] = it.second;
|
rt_info[it.first] = it.second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -22,5 +22,10 @@ std::shared_ptr<ngraph::Variant> Variant::merge(const ngraph::NodeVector& nodes)
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Variant::is_copyable() const
|
||||||
|
{
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
template class ngraph::VariantImpl<std::string>;
|
template class ngraph::VariantImpl<std::string>;
|
||||||
template class ngraph::VariantImpl<int64_t>;
|
template class ngraph::VariantImpl<int64_t>;
|
||||||
|
Loading…
Reference in New Issue
Block a user