Disable ConstantFolding for ShapeOf subgraph in TS transformation (#16765)
* Disable ConstantFolding for ShapeOf expressions in TS transformation * update ModelWithEmptyTensorListAndPushBack: add ShapeOf subgraph
This commit is contained in:
parent
8b7e6878e8
commit
72952bdc45
@ -9,6 +9,7 @@
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/common_optimizations/disable_shapeof_constant_folding.hpp"
|
||||
#include "transformations/transpose_sinking/ts_binary.hpp"
|
||||
#include "transformations/transpose_sinking/ts_concat.hpp"
|
||||
#include "transformations/transpose_sinking/ts_data_movement.hpp"
|
||||
@ -58,6 +59,7 @@ bool TSGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
||||
RUN_ON_FUNCTION_SCOPE(TSGeneral);
|
||||
{
|
||||
Manager manager(get_pass_config());
|
||||
manager.register_pass<DisableShapeOfConstantFolding>();
|
||||
manager.register_pass<TSGeneralForward>();
|
||||
manager.register_pass<ConstantFolding>();
|
||||
manager.run_passes(f);
|
||||
@ -65,6 +67,7 @@ bool TSGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
||||
|
||||
{
|
||||
Manager manager(get_pass_config());
|
||||
manager.register_pass<DisableShapeOfConstantFolding>();
|
||||
manager.register_pass<TSGeneralBackward>();
|
||||
manager.register_pass<ConstantFolding>();
|
||||
manager.run_passes(f);
|
||||
|
@ -412,6 +412,29 @@ TEST_F(TransformationTestsF, TSGeneralTestMultipleTypes) {
|
||||
manager.register_pass<TSGeneral>();
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, TSGeneralCheckShapeOfConstFoldingDisabled) {
|
||||
using namespace transpose_sinking::testing::general;
|
||||
ov::Shape input_shape = {96, 40, 55};
|
||||
ov::Shape reshape_shape = {1, 96, 40, 55};
|
||||
ov::element::Type input_type = ov::element::f32;
|
||||
{
|
||||
auto X = std::make_shared<Parameter>(input_type, input_shape);
|
||||
auto Shape = std::make_shared<Parameter>(input_type, reshape_shape);
|
||||
|
||||
auto order = std::make_shared<Constant>(ov::element::u64, ov::Shape{3}, ov::Shape{0, 2, 1});
|
||||
auto transpose = std::make_shared<Transpose>(X, order);
|
||||
|
||||
auto shape_of = std::make_shared<ShapeOf>(Shape);
|
||||
auto reshape = std::make_shared<Reshape>(transpose, shape_of, false);
|
||||
|
||||
auto ng_order1 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto transpose1 = std::make_shared<Transpose>(reshape, ng_order1);
|
||||
|
||||
function = std::make_shared<ov::Model>(transpose1, ov::ParameterVector{X, Shape});
|
||||
}
|
||||
manager.register_pass<TSGeneral>();
|
||||
}
|
||||
|
||||
} // namespace general
|
||||
} // namespace testing
|
||||
} // namespace transpose_sinking
|
@ -406,7 +406,13 @@ TEST_F(TransformationTestsF, ModelWithEmptyTensorListAndPushBack) {
|
||||
auto x_unsqueeze_flatten = make_shared<Unsqueeze>(x_flatten, zero_const);
|
||||
auto empty_const = make_shared<Constant>(f32, Shape{0, 30}, vector<float>{});
|
||||
auto list_push_back = make_shared<Concat>(OutputVector{empty_const, x_unsqueeze_flatten}, 0);
|
||||
auto recover_item_shape = make_shared<Constant>(i32, Shape{4}, vector<int32_t>{1, 2, 3, 5});
|
||||
auto list_push_back_shape = make_shared<ShapeOf>(list_push_back, element::i32);
|
||||
auto start = make_shared<Constant>(i32, Shape{1}, 0);
|
||||
auto stop = make_shared<Constant>(i32, Shape{1}, 1);
|
||||
auto step = make_shared<Constant>(i32, Shape{1}, 1);
|
||||
auto batch = make_shared<Slice>(list_push_back_shape, start, stop, step);
|
||||
auto shape_without_batch = make_shared<Constant>(i32, Shape{3}, vector<int32_t>{2, 3, 5});
|
||||
auto recover_item_shape = make_shared<Concat>(OutputVector{batch, shape_without_batch}, 0);
|
||||
auto recover_item = make_shared<Reshape>(list_push_back, recover_item_shape, false);
|
||||
model_ref = make_shared<Model>(OutputVector{recover_item}, ParameterVector{x});
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user