Enable ConstantFolding inside MO Backend (#6916)
This commit is contained in:
parent
5d781afe73
commit
b11a2220b0
@ -0,0 +1,24 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class DisableShapeOfConstantFolding;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
|
||||
/**
 * @brief Matcher pass that attaches the DisableConstantFolding runtime
 * attribute to matched ShapeOf operations (opset2 and opset3 variants).
 *
 * Only ShapeOf nodes whose output shape is dynamic, or whose static output
 * shape does not hold exactly one element, are matched.
 */
class ngraph::pass::DisableShapeOfConstantFolding : public ngraph::pass::MatcherPass {
public:
    NGRAPH_RTTI_DECLARATION;

    /// Builds the ShapeOf pattern and registers the matcher with its callback.
    DisableShapeOfConstantFolding();
};
|
@ -0,0 +1,32 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <memory>
|
||||
#include <ostream>
|
||||
#include <fstream>
|
||||
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/opsets/opset2.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/variant.hpp>
|
||||
#include <transformations/rt_info/disable_constant_folding.hpp>
|
||||
|
||||
#include "disable_shapeof_constant_folding.hpp"
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::DisableShapeOfConstantFolding, "DisableShapeOfConstantFolding", 0);
|
||||
|
||||
ngraph::pass::DisableShapeOfConstantFolding::DisableShapeOfConstantFolding() {
    // Accept a ShapeOf output only when its shape is dynamic or, if static,
    // when it does not contain exactly one element.
    const auto output_filter = [](const Output<Node>& port) {
        const auto& pshape = port.get_partial_shape();
        if (pshape.is_dynamic()) {
            return true;
        }
        return shape_size(pshape.get_shape()) != 1;
    };
    auto shape_of_pattern = pattern::wrap_type<opset2::ShapeOf, opset3::ShapeOf>(output_filter);

    // Tag the matched node; nothing needs to be captured from the constructor scope.
    ngraph::matcher_pass_callback on_match = [](pattern::Matcher& matcher) {
        disable_constant_folding(matcher.get_match_root());
        return true;
    };

    auto shape_of_matcher =
        std::make_shared<ngraph::pattern::Matcher>(shape_of_pattern, "DisableShapeOfConstantFolding");
    this->register_matcher(shape_of_matcher, on_match);
}
|
@ -5,8 +5,10 @@
|
||||
#include <memory>
|
||||
|
||||
#include "moc_transformations.hpp"
|
||||
#include "disable_shapeof_constant_folding.hpp"
|
||||
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/common_optimizations/gelu_fusion.hpp>
|
||||
#include <transformations/common_optimizations/softplus_fusion.hpp>
|
||||
@ -32,6 +34,7 @@
|
||||
#include <transformations/common_optimizations/lin_op_sequence_fusion.hpp>
|
||||
#include <transformations/common_optimizations/conv_mul_fusion.hpp>
|
||||
#include <transformations/common_optimizations/nop_elimination.hpp>
|
||||
#include <transformations/low_precision/disable_convert_constant_folding_on_const_path.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::MOCTransformations, "MOCTransformations", 0);
|
||||
|
||||
@ -48,6 +51,10 @@ bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr<ngraph::F
|
||||
ngraph::pass::Manager manager(get_pass_config());
|
||||
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::DisableConvertConstantFoldingOnConstPath>(
|
||||
element::TypeVector{ ngraph::element::i8, ngraph::element::u8, ngraph::element::i4, ngraph::element::u4 });
|
||||
manager.register_pass<ngraph::pass::DisableShapeOfConstantFolding>();
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>();
|
||||
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
|
||||
manager.register_pass<ngraph::pass::SimplifyShapeOfSubGraph>();
|
||||
|
@ -22,5 +22,5 @@ class ngraph::pass::DisableConvertConstantFoldingOnConstPath : public ngraph::pa
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
DisableConvertConstantFoldingOnConstPath(
|
||||
const std::vector<ngraph::element::Type>& inputPrecisions = {});
|
||||
const element::TypeVector & inputPrecisions = {});
|
||||
};
|
||||
|
@ -0,0 +1,44 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <assert.h>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <set>
|
||||
|
||||
#include <ngraph/node.hpp>
|
||||
#include <ngraph/variant.hpp>
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
|
||||
namespace ngraph {
|
||||
|
||||
/**
|
||||
* @ingroup ie_runtime_attr_api
|
||||
* @brief DisableConstantFolding disable ConstantFolding for given operation
|
||||
*/
|
||||
class TRANSFORMATIONS_API DisableConstantFolding {
|
||||
public:
|
||||
DisableConstantFolding() = default;
|
||||
};
|
||||
|
||||
extern template class TRANSFORMATIONS_API VariantImpl<DisableConstantFolding>;
|
||||
|
||||
template<>
|
||||
class TRANSFORMATIONS_API VariantWrapper<DisableConstantFolding> : public VariantImpl<DisableConstantFolding> {
|
||||
public:
|
||||
static constexpr VariantTypeInfo type_info{"DISABLED_CONSTANT_FOLDING", 0};
|
||||
|
||||
const VariantTypeInfo &get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
|
||||
VariantWrapper(const value_type &value) : VariantImpl<value_type>(value) {}
|
||||
|
||||
bool is_copyable() const override { return false; }
|
||||
};
|
||||
|
||||
TRANSFORMATIONS_API void disable_constant_folding(const std::shared_ptr<Node>& node);
|
||||
} // namespace ngraph
|
@ -20,7 +20,7 @@ using namespace ngraph;
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::DisableConvertConstantFoldingOnConstPath, "DisableConvertConstantFoldingOnConstPath", 0);
|
||||
|
||||
ngraph::pass::DisableConvertConstantFoldingOnConstPath::DisableConvertConstantFoldingOnConstPath(
|
||||
const std::vector<ngraph::element::Type>& inputPrecisions) {
|
||||
const element::TypeVector & inputPrecisions) {
|
||||
auto matcherData = ngraph::pattern::any_input();
|
||||
auto matcherConvert = ngraph::pattern::wrap_type<opset3::Convert>({ matcherData }, pattern::consumers_count(1));
|
||||
|
||||
|
@ -0,0 +1,14 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/rt_info/disable_constant_folding.hpp"
|
||||
|
||||
template class ngraph::VariantImpl<ngraph::DisableConstantFolding>;
|
||||
|
||||
constexpr ngraph::VariantTypeInfo ngraph::VariantWrapper<ngraph::DisableConstantFolding>::type_info;
|
||||
|
||||
void ngraph::disable_constant_folding(const std::shared_ptr<Node>& node) {
|
||||
auto & rt_info = node->get_rt_info();
|
||||
rt_info[VariantWrapper<DisableConstantFolding>::type_info.name] = make_variant<DisableConstantFolding>({});
|
||||
}
|
@ -0,0 +1,78 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
|
||||
#include <disable_shapeof_constant_folding.hpp>
|
||||
#include <transformations/serialize.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph;
|
||||
|
||||
// With the disabling pass registered before ConstantFolding, the
// Parameter -> ShapeOf -> Abs -> Reshape graph is expected to stay unchanged.
TEST(TransformationTests, DisableShapeOfConstantFolding) {
    // Both the tested and the reference function share the same topology.
    const auto make_graph = [] {
        auto input = std::make_shared<opset6::Parameter>(element::f32, Shape{1, 4, 10, 10});
        auto input_shape = std::make_shared<opset6::ShapeOf>(input);
        auto abs_shape = std::make_shared<opset6::Abs>(input_shape);
        auto reshaped = std::make_shared<opset6::Reshape>(input, abs_shape, false);
        return std::make_shared<Function>(NodeVector{reshaped}, ParameterVector{input});
    };

    std::shared_ptr<Function> f = make_graph();
    {
        pass::Manager manager;
        manager.register_pass<pass::DisableShapeOfConstantFolding>();
        manager.register_pass<pass::ConstantFolding>();
        manager.run_passes(f);
    }

    std::shared_ptr<Function> f_ref = make_graph();

    const auto comparison = compare_functions(f, f_ref);
    ASSERT_TRUE(comparison.first) << comparison.second;
}
|
||||
|
||||
// ShapeOf applied to another ShapeOf yields a single-element tensor, so it is
// expected to be folded into a constant while the inner ShapeOf is preserved.
TEST(TransformationTests, ShapeOfShapeOfConstantFolding) {
    std::shared_ptr<Function> f;
    std::shared_ptr<Function> f_ref;

    // Function under test: Multiply(Reshape(data, ShapeOf(data)), ShapeOf(ShapeOf(data))).
    {
        auto input = std::make_shared<opset6::Parameter>(element::i64, Shape{1, 4, 10, 10});
        auto input_shape = std::make_shared<opset6::ShapeOf>(input);
        auto reshaped = std::make_shared<opset6::Reshape>(input, input_shape, false);
        auto input_rank = std::make_shared<opset6::ShapeOf>(input_shape);
        auto product = std::make_shared<opset6::Multiply>(reshaped, input_rank);
        f = std::make_shared<Function>(NodeVector{product}, ParameterVector{input});

        pass::Manager manager;
        manager.register_pass<pass::DisableShapeOfConstantFolding>();
        manager.register_pass<pass::ConstantFolding>();
        manager.run_passes(f);
    }

    // Reference: the rank branch is replaced by the constant {4}.
    {
        auto input = std::make_shared<opset6::Parameter>(element::i64, Shape{1, 4, 10, 10});
        auto input_shape = std::make_shared<opset6::ShapeOf>(input);
        auto reshaped = std::make_shared<opset6::Reshape>(input, input_shape, false);
        auto rank_const = opset6::Constant::create(element::i64, Shape{1}, {4});
        auto product = std::make_shared<opset6::Multiply>(reshaped, rank_const);
        f_ref = std::make_shared<Function>(NodeVector{product}, ParameterVector{input});
    }

    const auto comparison = compare_functions(f, f_ref);
    ASSERT_TRUE(comparison.first) << comparison.second;
}
|
@ -329,6 +329,8 @@ namespace ngraph
|
||||
/// that all the HostTensorPtrs are not equal to nullptr
|
||||
NGRAPH_API bool validate_host_tensor_vector(const HostTensorVector& v, const size_t& size);
|
||||
|
||||
NGRAPH_API bool could_propagate(const Output<Node>& output, std::vector<Node*>& order);
|
||||
|
||||
namespace opset1
|
||||
{
|
||||
///
|
||||
|
@ -24,6 +24,7 @@ namespace ngraph
|
||||
|
||||
virtual std::shared_ptr<ngraph::Variant> init(const std::shared_ptr<ngraph::Node>& node);
|
||||
virtual std::shared_ptr<ngraph::Variant> merge(const ngraph::NodeVector& nodes);
|
||||
virtual bool is_copyable() const;
|
||||
virtual std::string to_string() { return ""; }
|
||||
};
|
||||
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include <ngraph/op/constant.hpp>
|
||||
#include "ngraph/op/util/sub_graph_base.hpp"
|
||||
#include "ngraph/rt_info.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -101,7 +102,23 @@ bool ngraph::pass::ConstantFolding::pre_calculated_values_folding(
|
||||
|
||||
for (auto& input_value : curr_node->input_values())
|
||||
{
|
||||
if (input_value.get_tensor().has_and_set_bound())
|
||||
// Check that ConstantFolding is not disabled on this path
|
||||
std::vector<Node*> order;
|
||||
auto status = could_propagate(input_value, order);
|
||||
if (status)
|
||||
{
|
||||
for (const auto& node : order)
|
||||
{
|
||||
const auto& rt_info = node->get_rt_info();
|
||||
if (rt_info.count("DISABLED_CONSTANT_FOLDING"))
|
||||
{
|
||||
status = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (status && input_value.get_tensor().has_and_set_bound())
|
||||
{
|
||||
auto input_node = input_value.get_node_shared_ptr();
|
||||
auto replacement =
|
||||
|
@ -8,46 +8,47 @@
|
||||
|
||||
ngraph::Node::RTMap mergeRuntimeInfo(const ngraph::NodeVector& nodes)
|
||||
{
|
||||
ngraph::Node::RTMap mergedInfo;
|
||||
for (auto& node : nodes)
|
||||
std::unordered_map<std::string, std::vector<std::shared_ptr<ngraph::Variant>>> attrs;
|
||||
for (const auto& node : nodes)
|
||||
{
|
||||
for (auto& item : node->get_rt_info())
|
||||
for (const auto& item : node->get_rt_info())
|
||||
{
|
||||
mergedInfo[item.first] = item.second;
|
||||
if (item.second->is_copyable())
|
||||
{
|
||||
attrs[item.first].push_back(item.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ngraph::Node::RTMap newInfo;
|
||||
for (auto& item : mergedInfo)
|
||||
ngraph::Node::RTMap merged_attrs;
|
||||
for (auto& item : attrs)
|
||||
{
|
||||
size_t attributes_count = 0;
|
||||
for (auto& node : nodes)
|
||||
auto attr = *item.second.begin();
|
||||
if (item.second.size() == 1)
|
||||
{
|
||||
const auto& rt_info = node->get_rt_info();
|
||||
if (rt_info.count(item.first))
|
||||
merged_attrs[item.first] = attr;
|
||||
}
|
||||
else if (auto merge_attr = attr->merge(nodes))
|
||||
{
|
||||
attributes_count++;
|
||||
merged_attrs[item.first] = merge_attr;
|
||||
}
|
||||
}
|
||||
|
||||
if (attributes_count == 1)
|
||||
{
|
||||
newInfo[item.first] = item.second;
|
||||
}
|
||||
else if (auto merge_attr = item.second->merge(nodes))
|
||||
{
|
||||
newInfo[item.first] = merge_attr;
|
||||
}
|
||||
}
|
||||
|
||||
return newInfo;
|
||||
return merged_attrs;
|
||||
}
|
||||
|
||||
// Replaces the runtime info of `to` with the copyable attributes of `from`.
//
// Attributes whose is_copyable() returns false are not transferred, and any
// attributes previously stored on `to` are discarded.
void ngraph::copy_runtime_info(std::shared_ptr<ngraph::Node> from, std::shared_ptr<ngraph::Node> to)
{
    // Build the filtered map before touching the destination: `from` and `to`
    // may refer to the same node, in which case clearing the destination first
    // would also empty the source and silently drop every attribute.
    ngraph::Node::RTMap copyable_attrs;
    for (const auto& item : from->get_rt_info())
    {
        if (item.second->is_copyable())
        {
            copyable_attrs[item.first] = item.second;
        }
    }

    // Single assignment instead of "copy everything, clear, copy again".
    to->get_rt_info() = copyable_attrs;
}
|
||||
|
||||
void ngraph::copy_runtime_info(std::shared_ptr<ngraph::Node> from, ngraph::NodeVector to)
|
||||
|
@ -1306,7 +1306,7 @@ void ngraph::evaluate_nodes(std::map<RawNodeOutput, HostTensorPtr>& value_map,
|
||||
}
|
||||
}
|
||||
|
||||
bool could_propagate(const Output<Node>& output, std::vector<Node*>& order)
|
||||
bool ngraph::could_propagate(const Output<Node>& output, std::vector<Node*>& order)
|
||||
{
|
||||
bool status = true;
|
||||
|
||||
@ -1367,7 +1367,7 @@ void propagate_rt_info(Node* node, const Output<Node>& final_port)
|
||||
auto& rt_info = consumer->get_rt_info();
|
||||
for (const auto& it : orig_rt_info)
|
||||
{
|
||||
if (rt_info.find(it.first) == rt_info.end())
|
||||
if (rt_info.find(it.first) == rt_info.end() && it.second->is_copyable())
|
||||
rt_info[it.first] = it.second;
|
||||
}
|
||||
}
|
||||
|
@ -22,5 +22,10 @@ std::shared_ptr<ngraph::Variant> Variant::merge(const ngraph::NodeVector& nodes)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool Variant::is_copyable() const
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
template class ngraph::VariantImpl<std::string>;
|
||||
template class ngraph::VariantImpl<int64_t>;
|
||||
|
Loading…
Reference in New Issue
Block a user