Disable ConstantFolding for ShapeOf subgraph in TS transformation (#16765)

* Disable ConstantFolding for ShapeOf expressions in TS transformation

* update ModelWithEmptyTensorListAndPushBack: add ShapeOf subgraph
This commit is contained in:
Ivan Tikhonov 2023-04-07 14:50:59 +04:00 committed by GitHub
parent 8b7e6878e8
commit 72952bdc45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 33 additions and 1 deletions

View File

@ -9,6 +9,7 @@
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/common_optimizations/disable_shapeof_constant_folding.hpp"
#include "transformations/transpose_sinking/ts_binary.hpp"
#include "transformations/transpose_sinking/ts_concat.hpp"
#include "transformations/transpose_sinking/ts_data_movement.hpp"
@ -58,6 +59,7 @@ bool TSGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
RUN_ON_FUNCTION_SCOPE(TSGeneral);
{
Manager manager(get_pass_config());
manager.register_pass<DisableShapeOfConstantFolding>();
manager.register_pass<TSGeneralForward>();
manager.register_pass<ConstantFolding>();
manager.run_passes(f);
@ -65,6 +67,7 @@ bool TSGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
{
Manager manager(get_pass_config());
manager.register_pass<DisableShapeOfConstantFolding>();
manager.register_pass<TSGeneralBackward>();
manager.register_pass<ConstantFolding>();
manager.run_passes(f);

View File

@ -412,6 +412,29 @@ TEST_F(TransformationTestsF, TSGeneralTestMultipleTypes) {
manager.register_pass<TSGeneral>();
}
TEST_F(TransformationTestsF, TSGeneralCheckShapeOfConstFoldingDisabled) {
using namespace transpose_sinking::testing::general;
ov::Shape input_shape = {96, 40, 55};
ov::Shape reshape_shape = {1, 96, 40, 55};
ov::element::Type input_type = ov::element::f32;
{
auto X = std::make_shared<Parameter>(input_type, input_shape);
auto Shape = std::make_shared<Parameter>(input_type, reshape_shape);
auto order = std::make_shared<Constant>(ov::element::u64, ov::Shape{3}, ov::Shape{0, 2, 1});
auto transpose = std::make_shared<Transpose>(X, order);
auto shape_of = std::make_shared<ShapeOf>(Shape);
auto reshape = std::make_shared<Reshape>(transpose, shape_of, false);
auto ng_order1 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto transpose1 = std::make_shared<Transpose>(reshape, ng_order1);
function = std::make_shared<ov::Model>(transpose1, ov::ParameterVector{X, Shape});
}
manager.register_pass<TSGeneral>();
}
} // namespace general
} // namespace testing
} // namespace transpose_sinking

View File

@ -406,7 +406,13 @@ TEST_F(TransformationTestsF, ModelWithEmptyTensorListAndPushBack) {
auto x_unsqueeze_flatten = make_shared<Unsqueeze>(x_flatten, zero_const);
auto empty_const = make_shared<Constant>(f32, Shape{0, 30}, vector<float>{});
auto list_push_back = make_shared<Concat>(OutputVector{empty_const, x_unsqueeze_flatten}, 0);
auto recover_item_shape = make_shared<Constant>(i32, Shape{4}, vector<int32_t>{1, 2, 3, 5});
auto list_push_back_shape = make_shared<ShapeOf>(list_push_back, element::i32);
auto start = make_shared<Constant>(i32, Shape{1}, 0);
auto stop = make_shared<Constant>(i32, Shape{1}, 1);
auto step = make_shared<Constant>(i32, Shape{1}, 1);
auto batch = make_shared<Slice>(list_push_back_shape, start, stop, step);
auto shape_without_batch = make_shared<Constant>(i32, Shape{3}, vector<int32_t>{2, 3, 5});
auto recover_item_shape = make_shared<Concat>(OutputVector{batch, shape_without_batch}, 0);
auto recover_item = make_shared<Reshape>(list_push_back, recover_item_shape, false);
model_ref = make_shared<Model>(OutputVector{recover_item}, ParameterVector{x});
}