Fix StridedSlice constant folding when disabled attribute set. (#17679)
* Do not fold StridedSlice when: - On begin or end there is ShapeOf with disabled constant folding. - StridedSlice op has disabled constant folding. * Copy rt info to folded StridedSlice
This commit is contained in:
parent
fa428a12e6
commit
e09b1a9fa2
@ -18,8 +18,10 @@
|
|||||||
#include "ngraph/slice_plan.hpp"
|
#include "ngraph/slice_plan.hpp"
|
||||||
#include "ngraph/type/element_type_traits.hpp"
|
#include "ngraph/type/element_type_traits.hpp"
|
||||||
#include "ngraph/util.hpp"
|
#include "ngraph/util.hpp"
|
||||||
|
#include "openvino/core/rt_info.hpp"
|
||||||
#include "openvino/core/validation_util.hpp"
|
#include "openvino/core/validation_util.hpp"
|
||||||
#include "openvino/op/util/precision_sensitive_attribute.hpp"
|
#include "openvino/op/util/precision_sensitive_attribute.hpp"
|
||||||
|
#include "openvino/pass/constant_folding.hpp"
|
||||||
#include "strided_slice_shape_inference.hpp"
|
#include "strided_slice_shape_inference.hpp"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
@ -296,7 +298,7 @@ bool op::v1::StridedSlice::evaluate_label(TensorLabelVector& output_labels) cons
|
|||||||
|
|
||||||
bool op::v1::StridedSlice::constant_fold(OutputVector& output_values, const OutputVector& inputs_values) {
|
bool op::v1::StridedSlice::constant_fold(OutputVector& output_values, const OutputVector& inputs_values) {
|
||||||
auto is_folded = Node::constant_fold(output_values, inputs_values);
|
auto is_folded = Node::constant_fold(output_values, inputs_values);
|
||||||
if (!is_folded) {
|
if (!is_const_fold_disabled() && !is_folded) {
|
||||||
// If all ignored mask are set for all begin or end then replace this input by dummy constant
|
// If all ignored mask are set for all begin or end then replace this input by dummy constant
|
||||||
// to avoid return false from `could_propagate` during bound evaluation (value of const will be ignored).
|
// to avoid return false from `could_propagate` during bound evaluation (value of const will be ignored).
|
||||||
auto get_indices_input = [&inputs_values](size_t port, const std::vector<int64_t>& mask) -> Output<Node> {
|
auto get_indices_input = [&inputs_values](size_t port, const std::vector<int64_t>& mask) -> Output<Node> {
|
||||||
@ -325,11 +327,21 @@ bool op::v1::StridedSlice::constant_fold(OutputVector& output_values, const Outp
|
|||||||
? clone_with_new_inputs(OutputVector{inputs_values[0], begin, end, inputs_values[3]})->output(0)
|
? clone_with_new_inputs(OutputVector{inputs_values[0], begin, end, inputs_values[3]})->output(0)
|
||||||
: this->output(0);
|
: this->output(0);
|
||||||
|
|
||||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
std::vector<Node*> nodes;
|
||||||
if (const auto c = ov::get_constant_from_source(output)) {
|
// Check if bounds can be evaluated and none of output nodes have disabled constant folding.
|
||||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
if (ov::could_propagate(output, nodes) && std::none_of(nodes.begin(), nodes.end(), [](const Node* n) {
|
||||||
output_values[0] = c;
|
return ov::pass::constant_folding_is_disabled(n);
|
||||||
is_folded = true;
|
})) {
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||||
|
if (const auto c = ov::get_constant_from_source(output)) {
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||||
|
output_values[0] = c;
|
||||||
|
auto output_ptr = output_values[0].get_node_shared_ptr();
|
||||||
|
for (const auto& n : nodes) {
|
||||||
|
copy_runtime_info(n->shared_from_this(), output_ptr);
|
||||||
|
}
|
||||||
|
is_folded = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return is_folded;
|
return is_folded;
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
#include "ngraph/opsets/opset5.hpp"
|
#include "ngraph/opsets/opset5.hpp"
|
||||||
#include "ngraph/pass/manager.hpp"
|
#include "ngraph/pass/manager.hpp"
|
||||||
#include "openvino/opsets/opset11.hpp"
|
#include "openvino/opsets/opset11.hpp"
|
||||||
|
#include "transformations/common_optimizations/disable_shapeof_constant_folding.hpp"
|
||||||
#include "util/all_close_f.hpp"
|
#include "util/all_close_f.hpp"
|
||||||
#include "util/test_tools.hpp"
|
#include "util/test_tools.hpp"
|
||||||
|
|
||||||
@ -2424,6 +2425,84 @@ TEST(constant_folding, strided_slice_not_ignored_dynamic_begin_from_shape_of) {
|
|||||||
ASSERT_EQ(count_ops_of_type<op::Constant>(model), 2);
|
ASSERT_EQ(count_ops_of_type<op::Constant>(model), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(constant_folding, strided_slice_can_be_folded_but_is_blocked_by_shape_of_which_got_folding_disabled) {
|
||||||
|
const auto constant =
|
||||||
|
make_shared<op::Constant>(element::i32,
|
||||||
|
Shape{1, 1, 2, 4, 2},
|
||||||
|
std::vector<int>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||||
|
constant->set_friendly_name("constant");
|
||||||
|
|
||||||
|
const auto begin_shape = PartialShape{0, -1, 0, 0, 0};
|
||||||
|
const auto p_begin = std::make_shared<op::Parameter>(element::i64, begin_shape);
|
||||||
|
const auto shape_of_begin = std::make_shared<op::ShapeOf>(p_begin);
|
||||||
|
shape_of_begin->set_friendly_name("begin");
|
||||||
|
|
||||||
|
const auto end_shape = PartialShape{-1, 512, 2, 2, 16};
|
||||||
|
const auto p_end = std::make_shared<op::Parameter>(element::i64, end_shape);
|
||||||
|
const auto shape_of_end = std::make_shared<op::ShapeOf>(p_end);
|
||||||
|
shape_of_end->set_friendly_name("end");
|
||||||
|
|
||||||
|
const auto stride = op::Constant::create(element::i64, {5}, {1, 1, 1, 1, 1});
|
||||||
|
stride->set_friendly_name("stride");
|
||||||
|
|
||||||
|
const auto slice = make_shared<op::v1::StridedSlice>(constant,
|
||||||
|
shape_of_begin,
|
||||||
|
shape_of_end,
|
||||||
|
stride,
|
||||||
|
std::vector<int64_t>{0, 1, 0, 0, 0},
|
||||||
|
std::vector<int64_t>{1, 1, 0, 0, 1});
|
||||||
|
slice->set_friendly_name("test");
|
||||||
|
|
||||||
|
auto model = make_shared<ov::Model>(slice, ParameterVector{p_begin, p_end});
|
||||||
|
|
||||||
|
pass::Manager pass_manager;
|
||||||
|
pass_manager.register_pass<ov::pass::InitNodeInfo>();
|
||||||
|
pass_manager.register_pass<ov::pass::DisableShapeOfConstantFolding>();
|
||||||
|
pass_manager.register_pass<pass::ConstantFolding>();
|
||||||
|
pass_manager.run_passes(model);
|
||||||
|
|
||||||
|
ASSERT_EQ(count_ops_of_type<op::v1::StridedSlice>(model), 1);
|
||||||
|
ASSERT_EQ(count_ops_of_type<op::Constant>(model), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(constant_folding, strided_slice_is_foldable_but_got_set_disable_constant_fold) {
|
||||||
|
const auto constant =
|
||||||
|
make_shared<op::Constant>(element::i32,
|
||||||
|
Shape{1, 1, 2, 4, 2},
|
||||||
|
std::vector<int>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||||
|
constant->set_friendly_name("constant");
|
||||||
|
|
||||||
|
const auto begin_shape = PartialShape{0, -1, 0, 0, 0};
|
||||||
|
const auto p_begin = std::make_shared<op::Parameter>(element::i64, begin_shape);
|
||||||
|
const auto shape_of_begin = std::make_shared<op::ShapeOf>(p_begin);
|
||||||
|
shape_of_begin->set_friendly_name("begin");
|
||||||
|
|
||||||
|
const auto end_shape = PartialShape{-1, 512, 2, 2, 16};
|
||||||
|
const auto p_end = std::make_shared<op::Parameter>(element::i64, end_shape);
|
||||||
|
const auto shape_of_end = std::make_shared<op::ShapeOf>(p_end);
|
||||||
|
shape_of_end->set_friendly_name("end");
|
||||||
|
|
||||||
|
const auto stride = op::Constant::create(element::i64, {5}, {1, 1, 1, 1, 1});
|
||||||
|
stride->set_friendly_name("stride");
|
||||||
|
|
||||||
|
const auto slice = make_shared<op::v1::StridedSlice>(constant,
|
||||||
|
shape_of_begin,
|
||||||
|
shape_of_end,
|
||||||
|
stride,
|
||||||
|
std::vector<int64_t>{0, 1, 0, 0, 0},
|
||||||
|
std::vector<int64_t>{1, 1, 0, 0, 1});
|
||||||
|
slice->set_friendly_name("test");
|
||||||
|
|
||||||
|
auto model = make_shared<ov::Model>(slice, ParameterVector{p_begin, p_end});
|
||||||
|
|
||||||
|
ov::disable_constant_folding(slice);
|
||||||
|
|
||||||
|
run_constant_folding(model);
|
||||||
|
|
||||||
|
ASSERT_EQ(count_ops_of_type<op::v1::StridedSlice>(model), 1);
|
||||||
|
ASSERT_EQ(count_ops_of_type<op::Constant>(model), 2);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(constant_folding, constant_dyn_reshape) {
|
TEST(constant_folding, constant_dyn_reshape) {
|
||||||
Shape shape_in{2, 4};
|
Shape shape_in{2, 4};
|
||||||
vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
|
vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
|
||||||
|
Loading…
Reference in New Issue
Block a user