Add EnableShapeOfConstantFolding transformation (#19880)

* Add EnableShapeOfConstantFolding transformation

Transpose sinking (used in the TF frontend) disables ShapeOf constant folding,
which prevents some optimizations later in the pipeline.
This patch introduces EnableShapeOfConstantFolding, a pass that removes DisableConstantFolding
from ShapeOf nodes so they can be constant-folded again.

Ticket: CVS-118890

* add description

* review comments

* headers
Mateusz Tabaka 2023-09-26 15:53:23 +02:00 committed by GitHub
parent 2bbfe7b44d
commit 5384fe43df
6 changed files with 281 additions and 0 deletions

@@ -0,0 +1,25 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

/**
 * @ingroup ie_transformation_common_api
 * @brief This transformation enables constant folding for ShapeOf nodes where it was
 * disabled by DisableShapeOfConstantFolding.
 */
class TRANSFORMATIONS_API EnableShapeOfConstantFolding : public MatcherPass {
public:
    OPENVINO_RTTI("EnableShapeOfConstantFolding", "0");
    EnableShapeOfConstantFolding();
};

} // namespace pass
} // namespace ov
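
A minimal usage sketch of how the pass is expected to be combined with ConstantFolding so that ShapeOf nodes whose folding was previously disabled get folded. The helper name fold_shape_of_nodes and the standalone Manager are illustrative assumptions, not part of this patch; the registered passes are the ones shown in the diff.

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/enable_shapeof_constant_folding.hpp"

// Illustrative helper (hypothetical name): re-enable constant folding on eligible
// ShapeOf nodes, then fold them and anything downstream that becomes constant.
void fold_shape_of_nodes(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::EnableShapeOfConstantFolding>();
    manager.register_pass<ov::pass::ConstantFolding>();
    manager.run_passes(model);
}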

@@ -0,0 +1,24 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/enable_shapeof_constant_folding.hpp"

#include "openvino/op/util/shape_of_base.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/disable_constant_folding.hpp"

ov::pass::EnableShapeOfConstantFolding::EnableShapeOfConstantFolding() {
    // Match any ShapeOf whose output is dynamic or holds more than one element.
    auto shape_of = pattern::wrap_type<op::util::ShapeOfBase>([=](const Output<Node>& output) {
        const auto& shape = output.get_partial_shape();
        return shape.is_dynamic() || shape_size(shape.get_shape()) != 1;
    });

    matcher_pass_callback callback = [=](pattern::Matcher& m) {
        // Drop DisableConstantFolding so subsequent ConstantFolding passes can fold the node.
        enable_constant_folding(m.get_match_root());
        return true;
    };

    auto m = std::make_shared<pattern::Matcher>(shape_of, "EnableShapeOfConstantFolding");
    this->register_matcher(m, callback);
}
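
For illustration only (these nodes are hypothetical, not taken from the patch): the pattern predicate above accepts ShapeOf nodes whose output is dynamic or has more than one element, so a ShapeOf over a rank-2 input is matched, while a ShapeOf over a rank-1 input, whose output holds a single element, is left untouched.

#include <memory>

#include "openvino/op/parameter.hpp"
#include "openvino/op/shape_of.hpp"

int main() {
    // Output shape {2}: two elements, so the predicate matches this ShapeOf.
    auto matched = std::make_shared<ov::op::v3::ShapeOf>(
        std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 3}));

    // Output shape {1}: a single element, so this ShapeOf is skipped.
    auto skipped = std::make_shared<ov::op::v3::ShapeOf>(
        std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{5}));
    return 0;
}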

@@ -10,6 +10,7 @@
#include "openvino/pass/manager.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/common_optimizations/disable_shapeof_constant_folding.hpp"
#include "transformations/common_optimizations/enable_shapeof_constant_folding.hpp"
#include "transformations/transpose_sinking/ts_binary.hpp"
#include "transformations/transpose_sinking/ts_concat.hpp"
#include "transformations/transpose_sinking/ts_data_movement.hpp"

@@ -77,6 +78,7 @@ bool TSGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
    manager.register_pass<TSGeneralBackward>();
    manager.register_pass<ConstantFolding>();
    manager.register_pass<TSResetNoSinkingAttribute>();
    manager.register_pass<EnableShapeOfConstantFolding>();
    manager.run_passes(f);
}

@@ -0,0 +1,60 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/enable_shapeof_constant_folding.hpp"

#include <gtest/gtest.h>

#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/core/model.hpp"
#include "openvino/op/abs.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "openvino/pass/manager.hpp"

using namespace testing;
using namespace ov;

TEST_F(TransformationTestsF, EnableShapeOfV0ConstantFolding) {
    {
        auto data = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, 4, 10, 10});
        auto shape_of = std::make_shared<op::v0::ShapeOf>(data);
        pass::disable_constant_folding(shape_of);
        auto abs = std::make_shared<op::v0::Abs>(shape_of);
        auto reshape = std::make_shared<op::v1::Reshape>(data, abs, false);
        model = std::make_shared<Model>(reshape, ParameterVector{data});

        manager.register_pass<pass::EnableShapeOfConstantFolding>();
        manager.register_pass<pass::ConstantFolding>();
    }
    {
        auto data = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, 4, 10, 10});
        auto shape = op::v0::Constant::create(element::i64, Shape{4}, {1, 4, 10, 10});
        auto reshape = std::make_shared<op::v1::Reshape>(data, shape, false);
        model_ref = std::make_shared<Model>(reshape, ParameterVector{data});
    }
}

TEST_F(TransformationTestsF, EnableShapeOfV3ConstantFolding) {
    {
        auto data = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, 4, 10, 10});
        auto shape_of = std::make_shared<op::v3::ShapeOf>(data);
        pass::disable_constant_folding(shape_of);
        auto abs = std::make_shared<op::v0::Abs>(shape_of);
        auto reshape = std::make_shared<op::v1::Reshape>(data, abs, false);
        model = std::make_shared<Model>(reshape, ParameterVector{data});

        manager.register_pass<pass::EnableShapeOfConstantFolding>();
        manager.register_pass<pass::ConstantFolding>();
    }
    {
        auto data = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, 4, 10, 10});
        auto shape = op::v0::Constant::create(element::i64, Shape{4}, {1, 4, 10, 10});
        auto reshape = std::make_shared<op::v1::Reshape>(data, shape, false);
        model_ref = std::make_shared<Model>(reshape, ParameterVector{data});
    }
}

@@ -46,3 +46,24 @@ TEST_F(CompileModelsTests, ModelWithSplitConvConcat) {
    }));
}
}

TEST_F(CompileModelsTests, ModelWithShapeOf) {
    auto model = convert_model("shapeof_slice_abs/shapeof_slice_abs.pbtxt");
    ov::Core core;
    ov::CompiledModel compiled_model = core.compile_model(model, "CPU");
    const auto runtime_model = compiled_model.get_runtime_model();
    auto get_layer_type = [](const std::shared_ptr<ov::Node>& node) {
        return node->get_rt_info().at(ExecGraphInfoSerialization::LAYER_TYPE).as<std::string>();
    };
    const auto ops = runtime_model->get_ops();
    // one Input, one Eltwise and one Output
    EXPECT_EQ(3, ops.size());
    // ShapeOf is folded
    EXPECT_EQ(0, std::count_if(ops.begin(), ops.end(), [&](const std::shared_ptr<ov::Node>& node) {
        return get_layer_type(node) == "ShapeOf";
    }));
    // Slice is eliminated
    EXPECT_EQ(0, std::count_if(ops.begin(), ops.end(), [&](const std::shared_ptr<ov::Node>& node) {
        return get_layer_type(node) == "StridedSlice";
    }));
}

@@ -0,0 +1,149 @@
node {
  name: "input"
  op: "Placeholder"
  attr {
    key: "_output_shapes"
    value {
      list {
        shape {
          dim {
            size: 10
          }
          dim {
            size: 4
          }
        }
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 10
        }
        dim {
          size: 4
        }
      }
    }
  }
}
node {
  name: "start"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 2
          }
        }
        tensor_content: "\000\000\000\000\000\000\000\000"
      }
    }
  }
}
node {
  name: "stop"
  input: "input"
  op: "Shape"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "stride"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 2
          }
        }
        tensor_content: "\001\000\000\000\001\000\000\000"
      }
    }
  }
}
node {
  name: "axes"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 2
          }
        }
        tensor_content: "\000\000\000\000\001\000\000\000"
      }
    }
  }
}
node {
  name: "slice"
  op: "Slice"
  input: "input"
  input: "start"
  input: "stop"
  input: "stride"
  input: "axes"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "abs"
  op: "Abs"
  input: "slice"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
library {
}