don't fold ShapeOf in TransposeSinking (#20742)

* add to DisableShapeOfConstantFolding flag; set it in ts_gather; add unit test

* fix ts_general; use check_shape flag in EnableShapeOfConstantFolding

* set EnableShapeOfConstantFolding constructor explicit

---------

Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com>
This commit is contained in:
Evgeny Kotov 2023-11-20 11:57:19 +01:00 committed by GitHub
parent 7931938ea4
commit ecd1972c24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 44 additions and 8 deletions

View File

@ -20,5 +20,5 @@ class TRANSFORMATIONS_API DisableShapeOfConstantFolding;
class ov::pass::DisableShapeOfConstantFolding : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("DisableShapeOfConstantFolding", "0");
DisableShapeOfConstantFolding();
explicit DisableShapeOfConstantFolding(bool check_shape = true);
};

View File

@ -18,7 +18,7 @@ namespace pass {
class TRANSFORMATIONS_API EnableShapeOfConstantFolding : public MatcherPass {
public:
OPENVINO_RTTI("EnableShapeOfConstantFolding", "0");
EnableShapeOfConstantFolding();
explicit EnableShapeOfConstantFolding(bool check_shape = true);
};
} // namespace pass

View File

@ -10,9 +10,11 @@
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/disable_constant_folding.hpp"
ov::pass::DisableShapeOfConstantFolding::DisableShapeOfConstantFolding() {
ov::pass::DisableShapeOfConstantFolding::DisableShapeOfConstantFolding(bool check_shape) {
auto shape_of = pattern::wrap_type<ov::op::v0::ShapeOf, ov::op::v3::ShapeOf>([=](const Output<Node>& output) {
const auto& shape = output.get_partial_shape();
if (!check_shape)
return true;
return shape.is_dynamic() || shape_size(shape.get_shape()) != 1;
});

View File

@ -8,9 +8,11 @@
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/disable_constant_folding.hpp"
ov::pass::EnableShapeOfConstantFolding::EnableShapeOfConstantFolding() {
ov::pass::EnableShapeOfConstantFolding::EnableShapeOfConstantFolding(bool check_shape) {
auto shape_of = pattern::wrap_type<op::util::ShapeOfBase>([=](const Output<Node>& output) {
const auto& shape = output.get_partial_shape();
if (!check_shape)
return true;
return shape.is_dynamic() || shape_size(shape.get_shape()) != 1;
});

View File

@ -73,7 +73,7 @@ bool TSGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
RUN_ON_FUNCTION_SCOPE(TSGeneral);
{
Manager manager(get_pass_config());
manager.register_pass<DisableShapeOfConstantFolding>();
manager.register_pass<DisableShapeOfConstantFolding>(/* check_shape */ false);
manager.register_pass<TSGeneralForward>();
manager.register_pass<ConstantFolding>();
manager.run_passes(f);
@ -81,11 +81,11 @@ bool TSGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
{
Manager manager(get_pass_config());
manager.register_pass<DisableShapeOfConstantFolding>();
manager.register_pass<DisableShapeOfConstantFolding>(/* check_shape */ false);
manager.register_pass<TSGeneralBackward>();
manager.register_pass<ConstantFolding>();
manager.register_pass<TSResetNoSinkingAttribute>();
manager.register_pass<EnableShapeOfConstantFolding>();
manager.register_pass<EnableShapeOfConstantFolding>(/* check_shape */ false);
manager.run_passes(f);
}

View File

@ -501,6 +501,38 @@ TEST(TransformationTests, TSGeneralBackwardCheckFriendlyAndTensorNamesForMultipl
EXPECT_NE(actual_names.find(name), actual_names.end());
}
}
TEST_F(TransformationTestsF, TSGeneralDisableShapeOf) {
using namespace transpose_sinking::testing::general;
ov::Shape input_shape = {1};
ov::element::Type input_type = ov::element::f32;
{
auto X = std::make_shared<Parameter>(input_type, input_shape);
auto ng_order = std::make_shared<Constant>(ov::element::i64, ov::Shape{1}, ov::Shape{0});
auto transpose = std::make_shared<Transpose>(X, ng_order);
auto shape_of = std::make_shared<ShapeOf>(transpose);
model = std::make_shared<ov::Model>(shape_of, ov::ParameterVector{X});
}
{
auto X = std::make_shared<Parameter>(input_type, input_shape);
auto shape_of = std::make_shared<ShapeOf>(X);
auto indices = std::make_shared<Constant>(ov::element::i64, ov::Shape{1}, ov::Shape{0});
auto axes = std::make_shared<Constant>(ov::element::i32, ov::Shape{1}, ov::Shape{0});
auto gather = std::make_shared<Gather>(shape_of, indices, axes);
model = std::make_shared<ov::Model>(gather, ov::ParameterVector{X});
}
manager.register_pass<TSGeneral>();
}
} // namespace general
} // namespace testing
} // namespace transpose_sinking
} // namespace transpose_sinking