From e09b1a9fa2dfaab677031211c3a7af773d3c75db Mon Sep 17 00:00:00 2001 From: Pawel Raasz Date: Wed, 24 May 2023 09:20:17 +0200 Subject: [PATCH] Fix StridedSlice constant folding when disabled attribute set. (#17679) * Do not fold StridedSlice when: - On begin or end there is ShapeOf with disabled constant folding. - StridedSlice op has disabled constant folding. * Copy rt info to folded StridedSlice --- src/core/src/op/strided_slice.cpp | 24 ++++++--- src/core/tests/constant_folding.cpp | 79 +++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 6 deletions(-) diff --git a/src/core/src/op/strided_slice.cpp b/src/core/src/op/strided_slice.cpp index ef6d3b6dabe..cba7b37a10e 100644 --- a/src/core/src/op/strided_slice.cpp +++ b/src/core/src/op/strided_slice.cpp @@ -18,8 +18,10 @@ #include "ngraph/slice_plan.hpp" #include "ngraph/type/element_type_traits.hpp" #include "ngraph/util.hpp" +#include "openvino/core/rt_info.hpp" #include "openvino/core/validation_util.hpp" #include "openvino/op/util/precision_sensitive_attribute.hpp" +#include "openvino/pass/constant_folding.hpp" #include "strided_slice_shape_inference.hpp" using namespace std; @@ -296,7 +298,7 @@ bool op::v1::StridedSlice::evaluate_label(TensorLabelVector& output_labels) cons bool op::v1::StridedSlice::constant_fold(OutputVector& output_values, const OutputVector& inputs_values) { auto is_folded = Node::constant_fold(output_values, inputs_values); - if (!is_folded) { + if (!is_const_fold_disabled() && !is_folded) { // If all ignored mask are set for all begin or end then replace this input by dummy constant // to avoid return false from `could_propagate` during bound evaluation (value of const will be ignored). auto get_indices_input = [&inputs_values](size_t port, const std::vector& mask) -> Output { @@ -325,11 +327,21 @@ bool op::v1::StridedSlice::constant_fold(OutputVector& output_values, const Outp ? clone_with_new_inputs(OutputVector{inputs_values[0], begin, end, inputs_values[3]})->output(0) : this->output(0); - OPENVINO_SUPPRESS_DEPRECATED_START - if (const auto c = ov::get_constant_from_source(output)) { - OPENVINO_SUPPRESS_DEPRECATED_END - output_values[0] = c; - is_folded = true; + std::vector nodes; + // Check if bounds can be evaluated and none of output nodes have disabled constant folding. + if (ov::could_propagate(output, nodes) && std::none_of(nodes.begin(), nodes.end(), [](const Node* n) { + return ov::pass::constant_folding_is_disabled(n); + })) { + OPENVINO_SUPPRESS_DEPRECATED_START + if (const auto c = ov::get_constant_from_source(output)) { + OPENVINO_SUPPRESS_DEPRECATED_END + output_values[0] = c; + auto output_ptr = output_values[0].get_node_shared_ptr(); + for (const auto& n : nodes) { + copy_runtime_info(n->shared_from_this(), output_ptr); + } + is_folded = true; + } } } return is_folded; diff --git a/src/core/tests/constant_folding.cpp b/src/core/tests/constant_folding.cpp index 71c1fc52ba9..a5bf60f615b 100644 --- a/src/core/tests/constant_folding.cpp +++ b/src/core/tests/constant_folding.cpp @@ -13,6 +13,7 @@ #include "ngraph/opsets/opset5.hpp" #include "ngraph/pass/manager.hpp" #include "openvino/opsets/opset11.hpp" +#include "transformations/common_optimizations/disable_shapeof_constant_folding.hpp" #include "util/all_close_f.hpp" #include "util/test_tools.hpp" @@ -2424,6 +2425,84 @@ TEST(constant_folding, strided_slice_not_ignored_dynamic_begin_from_shape_of) { ASSERT_EQ(count_ops_of_type(model), 2); } +TEST(constant_folding, strided_slice_can_be_folded_but_is_blocked_by_shape_of_which_got_folding_disabled) { + const auto constant = + make_shared(element::i32, + Shape{1, 1, 2, 4, 2}, + std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + constant->set_friendly_name("constant"); + + const auto begin_shape = PartialShape{0, -1, 0, 0, 0}; + const auto p_begin = std::make_shared(element::i64, begin_shape); + const auto shape_of_begin = std::make_shared(p_begin); + shape_of_begin->set_friendly_name("begin"); + + const auto end_shape = PartialShape{-1, 512, 2, 2, 16}; + const auto p_end = std::make_shared(element::i64, end_shape); + const auto shape_of_end = std::make_shared(p_end); + shape_of_end->set_friendly_name("end"); + + const auto stride = op::Constant::create(element::i64, {5}, {1, 1, 1, 1, 1}); + stride->set_friendly_name("stride"); + + const auto slice = make_shared(constant, + shape_of_begin, + shape_of_end, + stride, + std::vector{0, 1, 0, 0, 0}, + std::vector{1, 1, 0, 0, 1}); + slice->set_friendly_name("test"); + + auto model = make_shared(slice, ParameterVector{p_begin, p_end}); + + pass::Manager pass_manager; + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.run_passes(model); + + ASSERT_EQ(count_ops_of_type(model), 1); + ASSERT_EQ(count_ops_of_type(model), 2); +} + +TEST(constant_folding, strided_slice_is_foldable_but_got_set_disable_constant_fold) { + const auto constant = + make_shared(element::i32, + Shape{1, 1, 2, 4, 2}, + std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + constant->set_friendly_name("constant"); + + const auto begin_shape = PartialShape{0, -1, 0, 0, 0}; + const auto p_begin = std::make_shared(element::i64, begin_shape); + const auto shape_of_begin = std::make_shared(p_begin); + shape_of_begin->set_friendly_name("begin"); + + const auto end_shape = PartialShape{-1, 512, 2, 2, 16}; + const auto p_end = std::make_shared(element::i64, end_shape); + const auto shape_of_end = std::make_shared(p_end); + shape_of_end->set_friendly_name("end"); + + const auto stride = op::Constant::create(element::i64, {5}, {1, 1, 1, 1, 1}); + stride->set_friendly_name("stride"); + + const auto slice = make_shared(constant, + shape_of_begin, + shape_of_end, + stride, + std::vector{0, 1, 0, 0, 0}, + std::vector{1, 1, 0, 0, 1}); + slice->set_friendly_name("test"); + + auto model = make_shared(slice, ParameterVector{p_begin, p_end}); + + ov::disable_constant_folding(slice); + + run_constant_folding(model); + + ASSERT_EQ(count_ops_of_type(model), 1); + ASSERT_EQ(count_ops_of_type(model), 2); +} + TEST(constant_folding, constant_dyn_reshape) { Shape shape_in{2, 4}; vector values_in{0, 1, 2, 3, 4, 5, 6, 7};