don't fold ShapeOf in TransposeSinking (#20742)
* add to DisableShapeOfConstantFolding flag; set it in ts_gather; add unit test * fix ts_general; use check_shape flag in EnableShapeOfConstantFolding * set EnableShapeOfConstantFolding constructor explicit --------- Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com>
This commit is contained in:
parent
7931938ea4
commit
ecd1972c24
@ -20,5 +20,5 @@ class TRANSFORMATIONS_API DisableShapeOfConstantFolding;
|
||||
class ov::pass::DisableShapeOfConstantFolding : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("DisableShapeOfConstantFolding", "0");
|
||||
DisableShapeOfConstantFolding();
|
||||
explicit DisableShapeOfConstantFolding(bool check_shape = true);
|
||||
};
|
||||
|
@ -18,7 +18,7 @@ namespace pass {
|
||||
class TRANSFORMATIONS_API EnableShapeOfConstantFolding : public MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("EnableShapeOfConstantFolding", "0");
|
||||
EnableShapeOfConstantFolding();
|
||||
explicit EnableShapeOfConstantFolding(bool check_shape = true);
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
|
@ -10,9 +10,11 @@
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/rt_info/disable_constant_folding.hpp"
|
||||
|
||||
ov::pass::DisableShapeOfConstantFolding::DisableShapeOfConstantFolding() {
|
||||
ov::pass::DisableShapeOfConstantFolding::DisableShapeOfConstantFolding(bool check_shape) {
|
||||
auto shape_of = pattern::wrap_type<ov::op::v0::ShapeOf, ov::op::v3::ShapeOf>([=](const Output<Node>& output) {
|
||||
const auto& shape = output.get_partial_shape();
|
||||
if (!check_shape)
|
||||
return true;
|
||||
return shape.is_dynamic() || shape_size(shape.get_shape()) != 1;
|
||||
});
|
||||
|
||||
|
@ -8,9 +8,11 @@
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/rt_info/disable_constant_folding.hpp"
|
||||
|
||||
ov::pass::EnableShapeOfConstantFolding::EnableShapeOfConstantFolding() {
|
||||
ov::pass::EnableShapeOfConstantFolding::EnableShapeOfConstantFolding(bool check_shape) {
|
||||
auto shape_of = pattern::wrap_type<op::util::ShapeOfBase>([=](const Output<Node>& output) {
|
||||
const auto& shape = output.get_partial_shape();
|
||||
if (!check_shape)
|
||||
return true;
|
||||
return shape.is_dynamic() || shape_size(shape.get_shape()) != 1;
|
||||
});
|
||||
|
||||
|
@ -73,7 +73,7 @@ bool TSGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
||||
RUN_ON_FUNCTION_SCOPE(TSGeneral);
|
||||
{
|
||||
Manager manager(get_pass_config());
|
||||
manager.register_pass<DisableShapeOfConstantFolding>();
|
||||
manager.register_pass<DisableShapeOfConstantFolding>(/* check_shape */ false);
|
||||
manager.register_pass<TSGeneralForward>();
|
||||
manager.register_pass<ConstantFolding>();
|
||||
manager.run_passes(f);
|
||||
@ -81,11 +81,11 @@ bool TSGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
||||
|
||||
{
|
||||
Manager manager(get_pass_config());
|
||||
manager.register_pass<DisableShapeOfConstantFolding>();
|
||||
manager.register_pass<DisableShapeOfConstantFolding>(/* check_shape */ false);
|
||||
manager.register_pass<TSGeneralBackward>();
|
||||
manager.register_pass<ConstantFolding>();
|
||||
manager.register_pass<TSResetNoSinkingAttribute>();
|
||||
manager.register_pass<EnableShapeOfConstantFolding>();
|
||||
manager.register_pass<EnableShapeOfConstantFolding>(/* check_shape */ false);
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
|
@ -501,6 +501,38 @@ TEST(TransformationTests, TSGeneralBackwardCheckFriendlyAndTensorNamesForMultipl
|
||||
EXPECT_NE(actual_names.find(name), actual_names.end());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, TSGeneralDisableShapeOf) {
|
||||
using namespace transpose_sinking::testing::general;
|
||||
ov::Shape input_shape = {1};
|
||||
ov::element::Type input_type = ov::element::f32;
|
||||
|
||||
{
|
||||
auto X = std::make_shared<Parameter>(input_type, input_shape);
|
||||
|
||||
auto ng_order = std::make_shared<Constant>(ov::element::i64, ov::Shape{1}, ov::Shape{0});
|
||||
auto transpose = std::make_shared<Transpose>(X, ng_order);
|
||||
|
||||
auto shape_of = std::make_shared<ShapeOf>(transpose);
|
||||
|
||||
model = std::make_shared<ov::Model>(shape_of, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
{
|
||||
auto X = std::make_shared<Parameter>(input_type, input_shape);
|
||||
|
||||
auto shape_of = std::make_shared<ShapeOf>(X);
|
||||
|
||||
auto indices = std::make_shared<Constant>(ov::element::i64, ov::Shape{1}, ov::Shape{0});
|
||||
auto axes = std::make_shared<Constant>(ov::element::i32, ov::Shape{1}, ov::Shape{0});
|
||||
auto gather = std::make_shared<Gather>(shape_of, indices, axes);
|
||||
|
||||
model = std::make_shared<ov::Model>(gather, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
manager.register_pass<TSGeneral>();
|
||||
}
|
||||
|
||||
} // namespace general
|
||||
} // namespace testing
|
||||
} // namespace transpose_sinking
|
||||
} // namespace transpose_sinking
|
||||
|
Loading…
Reference in New Issue
Block a user