diff --git a/inference-engine/src/offline_transformations/include/disable_shapeof_constant_folding.hpp b/inference-engine/src/offline_transformations/include/disable_shapeof_constant_folding.hpp new file mode 100644 index 00000000000..678b41af0ef --- /dev/null +++ b/inference-engine/src/offline_transformations/include/disable_shapeof_constant_folding.hpp @@ -0,0 +1,24 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include + +namespace ngraph { +namespace pass { + +class DisableShapeOfConstantFolding; + +} // namespace pass +} // namespace ngraph + + +class ngraph::pass::DisableShapeOfConstantFolding: public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + DisableShapeOfConstantFolding(); +}; diff --git a/inference-engine/src/offline_transformations/src/disable_shapeof_constant_folding.cpp b/inference-engine/src/offline_transformations/src/disable_shapeof_constant_folding.cpp new file mode 100644 index 00000000000..456ba721647 --- /dev/null +++ b/inference-engine/src/offline_transformations/src/disable_shapeof_constant_folding.cpp @@ -0,0 +1,32 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "disable_shapeof_constant_folding.hpp" + +NGRAPH_RTTI_DEFINITION(ngraph::pass::DisableShapeOfConstantFolding, "DisableShapeOfConstantFolding", 0); + +ngraph::pass::DisableShapeOfConstantFolding::DisableShapeOfConstantFolding() { + auto shape_of = pattern::wrap_type([=](const Output & output) { + const auto & shape = output.get_partial_shape(); + return shape.is_dynamic() || shape_size(shape.get_shape()) != 1; + }); + + ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) { + disable_constant_folding(m.get_match_root()); + return true; + }; + + auto m = std::make_shared(shape_of, "DisableShapeOfConstantFolding"); + this->register_matcher(m, callback); +} diff --git a/inference-engine/src/offline_transformations/src/moc_transformations.cpp b/inference-engine/src/offline_transformations/src/moc_transformations.cpp index 1e6c865353e..1a23f72e607 100644 --- a/inference-engine/src/offline_transformations/src/moc_transformations.cpp +++ b/inference-engine/src/offline_transformations/src/moc_transformations.cpp @@ -5,8 +5,10 @@ #include #include "moc_transformations.hpp" +#include "disable_shapeof_constant_folding.hpp" #include +#include #include #include #include @@ -32,6 +34,7 @@ #include #include #include +#include NGRAPH_RTTI_DEFINITION(ngraph::pass::MOCTransformations, "MOCTransformations", 0); @@ -48,6 +51,10 @@ bool ngraph::pass::MOCTransformations::run_on_function(std::shared_ptr(); + manager.register_pass( + element::TypeVector{ ngraph::element::i8, ngraph::element::u8, ngraph::element::i4, ngraph::element::u4 }); + manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(); manager.register_pass(); diff --git a/inference-engine/src/transformations/include/transformations/low_precision/disable_convert_constant_folding_on_const_path.hpp b/inference-engine/src/transformations/include/transformations/low_precision/disable_convert_constant_folding_on_const_path.hpp index 79ad6e3e882..f5405daa6d6 100644 --- a/inference-engine/src/transformations/include/transformations/low_precision/disable_convert_constant_folding_on_const_path.hpp +++ b/inference-engine/src/transformations/include/transformations/low_precision/disable_convert_constant_folding_on_const_path.hpp @@ -22,5 +22,5 @@ class ngraph::pass::DisableConvertConstantFoldingOnConstPath : public ngraph::pa public: NGRAPH_RTTI_DECLARATION; DisableConvertConstantFoldingOnConstPath( - const std::vector& inputPrecisions = {}); + const element::TypeVector & inputPrecisions = {}); }; diff --git a/inference-engine/src/transformations/include/transformations/rt_info/disable_constant_folding.hpp b/inference-engine/src/transformations/include/transformations/rt_info/disable_constant_folding.hpp new file mode 100644 index 00000000000..1e04ce22dcc --- /dev/null +++ b/inference-engine/src/transformations/include/transformations/rt_info/disable_constant_folding.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include + +#include +#include +#include + + +namespace ngraph { + +/** + * @ingroup ie_runtime_attr_api + * @brief DisableConstantFolding disable ConstantFolding for given operation + */ +class TRANSFORMATIONS_API DisableConstantFolding { +public: + DisableConstantFolding() = default; +}; + +extern template class TRANSFORMATIONS_API VariantImpl; + +template<> +class TRANSFORMATIONS_API VariantWrapper : public VariantImpl { +public: + static constexpr VariantTypeInfo type_info{"DISABLED_CONSTANT_FOLDING", 0}; + + const VariantTypeInfo &get_type_info() const override { + return type_info; + } + + VariantWrapper(const value_type &value) : VariantImpl(value) {} + + bool is_copyable() const override { return false; } +}; + +TRANSFORMATIONS_API void disable_constant_folding(const std::shared_ptr& node); +} // namespace ngraph diff --git a/inference-engine/src/transformations/src/transformations/low_precision/disable_convert_constant_folding_on_const_path.cpp b/inference-engine/src/transformations/src/transformations/low_precision/disable_convert_constant_folding_on_const_path.cpp index d5c30e73e4f..44d05860c4f 100644 --- a/inference-engine/src/transformations/src/transformations/low_precision/disable_convert_constant_folding_on_const_path.cpp +++ b/inference-engine/src/transformations/src/transformations/low_precision/disable_convert_constant_folding_on_const_path.cpp @@ -20,7 +20,7 @@ using namespace ngraph; NGRAPH_RTTI_DEFINITION(ngraph::pass::DisableConvertConstantFoldingOnConstPath, "DisableConvertConstantFoldingOnConstPath", 0); ngraph::pass::DisableConvertConstantFoldingOnConstPath::DisableConvertConstantFoldingOnConstPath( - const std::vector& inputPrecisions) { + const element::TypeVector & inputPrecisions) { auto matcherData = ngraph::pattern::any_input(); auto matcherConvert = ngraph::pattern::wrap_type({ matcherData }, pattern::consumers_count(1)); diff --git a/inference-engine/src/transformations/src/transformations/rt_info/disable_constant_folding.cpp b/inference-engine/src/transformations/src/transformations/rt_info/disable_constant_folding.cpp new file mode 100644 index 00000000000..791102ed1f4 --- /dev/null +++ b/inference-engine/src/transformations/src/transformations/rt_info/disable_constant_folding.cpp @@ -0,0 +1,14 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/rt_info/disable_constant_folding.hpp" + +template class ngraph::VariantImpl; + +constexpr ngraph::VariantTypeInfo ngraph::VariantWrapper::type_info; + +void ngraph::disable_constant_folding(const std::shared_ptr& node) { + auto & rt_info = node->get_rt_info(); + rt_info[VariantWrapper::type_info.name] = make_variant({}); +} \ No newline at end of file diff --git a/inference-engine/tests/functional/inference_engine/transformations/disable_shapeof_constant_folding_tests.cpp b/inference-engine/tests/functional/inference_engine/transformations/disable_shapeof_constant_folding_tests.cpp new file mode 100644 index 00000000000..2e526e4c72d --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/transformations/disable_shapeof_constant_folding_tests.cpp @@ -0,0 +1,78 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" + + +using namespace testing; +using namespace ngraph; + +TEST(TransformationTests, DisableShapeOfConstantFolding) { + std::shared_ptr f, f_ref; + { + auto data = std::make_shared(element::f32, Shape{1, 4, 10, 10}); + auto shape_of = std::make_shared(data); + auto abs = std::make_shared(shape_of); + auto reshape = std::make_shared(data, abs, false); + f = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + } + + { + auto data = std::make_shared(element::f32, Shape{1, 4, 10, 10}); + auto shape_of = std::make_shared(data); + auto abs = std::make_shared(shape_of); + auto reshape = std::make_shared(data, abs, false); + f_ref = std::make_shared(NodeVector{reshape}, ParameterVector{data}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, ShapeOfShapeOfConstantFolding) { + std::shared_ptr f, f_ref; + { + auto data = std::make_shared(element::i64, Shape{1, 4, 10, 10}); + auto shape_of = std::make_shared(data); + auto reshape = std::make_shared(data, shape_of, false); + auto rank = std::make_shared(shape_of); + auto mul = std::make_shared(reshape, rank); + f = std::make_shared(NodeVector{mul}, ParameterVector{data}); + + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + } + + { + auto data = std::make_shared(element::i64, Shape{1, 4, 10, 10}); + auto shape_of = std::make_shared(data); + auto reshape = std::make_shared(data, shape_of, false); + auto mul = std::make_shared(reshape, opset6::Constant::create(element::i64, Shape{1}, {4})); + f_ref = std::make_shared(NodeVector{mul}, ParameterVector{data}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} \ No newline at end of file diff --git a/ngraph/core/include/ngraph/validation_util.hpp b/ngraph/core/include/ngraph/validation_util.hpp index 7bb64867d79..60a245890d6 100644 --- a/ngraph/core/include/ngraph/validation_util.hpp +++ b/ngraph/core/include/ngraph/validation_util.hpp @@ -329,6 +329,8 @@ namespace ngraph /// that all the HostTensorPtrs are not equal to nullptr NGRAPH_API bool validate_host_tensor_vector(const HostTensorVector& v, const size_t& size); + NGRAPH_API bool could_propagate(const Output& output, std::vector& order); + namespace opset1 { /// diff --git a/ngraph/core/include/ngraph/variant.hpp b/ngraph/core/include/ngraph/variant.hpp index 2ddad182d38..5fb7b99d875 100644 --- a/ngraph/core/include/ngraph/variant.hpp +++ b/ngraph/core/include/ngraph/variant.hpp @@ -24,6 +24,7 @@ namespace ngraph virtual std::shared_ptr init(const std::shared_ptr& node); virtual std::shared_ptr merge(const ngraph::NodeVector& nodes); + virtual bool is_copyable() const; virtual std::string to_string() { return ""; } }; diff --git a/ngraph/core/src/pass/constant_folding.cpp b/ngraph/core/src/pass/constant_folding.cpp index 824df2ac26b..edb7f7980f6 100644 --- a/ngraph/core/src/pass/constant_folding.cpp +++ b/ngraph/core/src/pass/constant_folding.cpp @@ -6,6 +6,7 @@ #include #include "ngraph/op/util/sub_graph_base.hpp" #include "ngraph/rt_info.hpp" +#include "ngraph/validation_util.hpp" using namespace std; using namespace ngraph; @@ -101,7 +102,23 @@ bool ngraph::pass::ConstantFolding::pre_calculated_values_folding( for (auto& input_value : curr_node->input_values()) { - if (input_value.get_tensor().has_and_set_bound()) + // Check that ConstantFolding is not disabled on this path + std::vector order; + auto status = could_propagate(input_value, order); + if (status) + { + for (const auto& node : order) + { + const auto& rt_info = node->get_rt_info(); + if (rt_info.count("DISABLED_CONSTANT_FOLDING")) + { + status = false; + break; + } + } + } + + if (status && input_value.get_tensor().has_and_set_bound()) { auto input_node = input_value.get_node_shared_ptr(); auto replacement = diff --git a/ngraph/core/src/rt_info.cpp b/ngraph/core/src/rt_info.cpp index c444be5d531..1409162d7e6 100644 --- a/ngraph/core/src/rt_info.cpp +++ b/ngraph/core/src/rt_info.cpp @@ -8,46 +8,47 @@ ngraph::Node::RTMap mergeRuntimeInfo(const ngraph::NodeVector& nodes) { - ngraph::Node::RTMap mergedInfo; - for (auto& node : nodes) + std::unordered_map>> attrs; + for (const auto& node : nodes) { - for (auto& item : node->get_rt_info()) + for (const auto& item : node->get_rt_info()) { - mergedInfo[item.first] = item.second; - } - } - - ngraph::Node::RTMap newInfo; - for (auto& item : mergedInfo) - { - size_t attributes_count = 0; - for (auto& node : nodes) - { - const auto& rt_info = node->get_rt_info(); - if (rt_info.count(item.first)) + if (item.second->is_copyable()) { - attributes_count++; + attrs[item.first].push_back(item.second); } } + } - if (attributes_count == 1) + ngraph::Node::RTMap merged_attrs; + for (auto& item : attrs) + { + auto attr = *item.second.begin(); + if (item.second.size() == 1) { - newInfo[item.first] = item.second; + merged_attrs[item.first] = attr; } - else if (auto merge_attr = item.second->merge(nodes)) + else if (auto merge_attr = attr->merge(nodes)) { - newInfo[item.first] = merge_attr; + merged_attrs[item.first] = merge_attr; } } - return newInfo; + return merged_attrs; } void ngraph::copy_runtime_info(std::shared_ptr from, std::shared_ptr to) { - auto& rtInfoFrom = from->get_rt_info(); - auto& rtInfoTo = to->get_rt_info(); - rtInfoTo = rtInfoFrom; + auto& attrs = to->get_rt_info(); + attrs.clear(); + + for (const auto& item : from->get_rt_info()) + { + if (item.second->is_copyable()) + { + attrs[item.first] = item.second; + } + } } void ngraph::copy_runtime_info(std::shared_ptr from, ngraph::NodeVector to) diff --git a/ngraph/core/src/validation_util.cpp b/ngraph/core/src/validation_util.cpp index 79b9b1588fa..ff7f3d75ccc 100644 --- a/ngraph/core/src/validation_util.cpp +++ b/ngraph/core/src/validation_util.cpp @@ -1306,7 +1306,7 @@ void ngraph::evaluate_nodes(std::map& value_map, } } -bool could_propagate(const Output& output, std::vector& order) +bool ngraph::could_propagate(const Output& output, std::vector& order) { bool status = true; @@ -1367,7 +1367,7 @@ void propagate_rt_info(Node* node, const Output& final_port) auto& rt_info = consumer->get_rt_info(); for (const auto& it : orig_rt_info) { - if (rt_info.find(it.first) == rt_info.end()) + if (rt_info.find(it.first) == rt_info.end() && it.second->is_copyable()) rt_info[it.first] = it.second; } } diff --git a/ngraph/core/src/variant.cpp b/ngraph/core/src/variant.cpp index a4b780c41e8..43500cb555a 100644 --- a/ngraph/core/src/variant.cpp +++ b/ngraph/core/src/variant.cpp @@ -22,5 +22,10 @@ std::shared_ptr Variant::merge(const ngraph::NodeVector& nodes) return nullptr; } +bool Variant::is_copyable() const +{ + return true; +} + template class ngraph::VariantImpl; template class ngraph::VariantImpl;