Add EnableShapeOfConstantFolding transformation (#19880)
* Add EnableShapeOfConstantFolding transformation Transpose sinking (that is used in TF frontend) disables ShapeOf constant folding which prevents some optimizations further in the pipeline. This patch introduces EnableShapeOfConstantFolding that removes DisableConstantFolding from ShapeOf nodes. Ticket: CVS-118890 * add description * review comments * headers
This commit is contained in:
parent
2bbfe7b44d
commit
5384fe43df
@ -0,0 +1,25 @@
|
|||||||
|
// Copyright (C) 2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/pass/graph_rewrite.hpp"
|
||||||
|
#include "transformations_visibility.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @ingroup ie_transformation_common_api
|
||||||
|
* @brief This transformation enables constfoldability for ShapeOf nodes that was
|
||||||
|
* disabled by DisableShapeOfConstantFolding.
|
||||||
|
*/
|
||||||
|
class TRANSFORMATIONS_API EnableShapeOfConstantFolding : public MatcherPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI("EnableShapeOfConstantFolding", "0");
|
||||||
|
EnableShapeOfConstantFolding();
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
@ -0,0 +1,24 @@
|
|||||||
|
// Copyright (C) 2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "transformations/common_optimizations/enable_shapeof_constant_folding.hpp"
|
||||||
|
|
||||||
|
#include "openvino/op/util/shape_of_base.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||||
|
#include "transformations/rt_info/disable_constant_folding.hpp"
|
||||||
|
|
||||||
|
ov::pass::EnableShapeOfConstantFolding::EnableShapeOfConstantFolding() {
|
||||||
|
auto shape_of = pattern::wrap_type<op::util::ShapeOfBase>([=](const Output<Node>& output) {
|
||||||
|
const auto& shape = output.get_partial_shape();
|
||||||
|
return shape.is_dynamic() || shape_size(shape.get_shape()) != 1;
|
||||||
|
});
|
||||||
|
|
||||||
|
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||||
|
enable_constant_folding(m.get_match_root());
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<pattern::Matcher>(shape_of, "EnableShapeOfConstantFolding");
|
||||||
|
this->register_matcher(m, callback);
|
||||||
|
}
|
@ -10,6 +10,7 @@
|
|||||||
#include "openvino/pass/manager.hpp"
|
#include "openvino/pass/manager.hpp"
|
||||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||||
#include "transformations/common_optimizations/disable_shapeof_constant_folding.hpp"
|
#include "transformations/common_optimizations/disable_shapeof_constant_folding.hpp"
|
||||||
|
#include "transformations/common_optimizations/enable_shapeof_constant_folding.hpp"
|
||||||
#include "transformations/transpose_sinking/ts_binary.hpp"
|
#include "transformations/transpose_sinking/ts_binary.hpp"
|
||||||
#include "transformations/transpose_sinking/ts_concat.hpp"
|
#include "transformations/transpose_sinking/ts_concat.hpp"
|
||||||
#include "transformations/transpose_sinking/ts_data_movement.hpp"
|
#include "transformations/transpose_sinking/ts_data_movement.hpp"
|
||||||
@ -77,6 +78,7 @@ bool TSGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
|||||||
manager.register_pass<TSGeneralBackward>();
|
manager.register_pass<TSGeneralBackward>();
|
||||||
manager.register_pass<ConstantFolding>();
|
manager.register_pass<ConstantFolding>();
|
||||||
manager.register_pass<TSResetNoSinkingAttribute>();
|
manager.register_pass<TSResetNoSinkingAttribute>();
|
||||||
|
manager.register_pass<EnableShapeOfConstantFolding>();
|
||||||
manager.run_passes(f);
|
manager.run_passes(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,60 @@
|
|||||||
|
// Copyright (C) 2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "transformations/common_optimizations/enable_shapeof_constant_folding.hpp"
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "common_test_utils/ov_test_utils.hpp"
|
||||||
|
#include "openvino/core/model.hpp"
|
||||||
|
#include "openvino/op/abs.hpp"
|
||||||
|
#include "openvino/op/reshape.hpp"
|
||||||
|
#include "openvino/op/shape_of.hpp"
|
||||||
|
#include "openvino/pass/constant_folding.hpp"
|
||||||
|
#include "openvino/pass/manager.hpp"
|
||||||
|
|
||||||
|
using namespace testing;
|
||||||
|
using namespace ov;
|
||||||
|
|
||||||
|
TEST_F(TransformationTestsF, EnableShapeOfV0ConstantFolding) {
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, 4, 10, 10});
|
||||||
|
auto shape_of = std::make_shared<op::v0::ShapeOf>(data);
|
||||||
|
pass::disable_constant_folding(shape_of);
|
||||||
|
auto abs = std::make_shared<op::v0::Abs>(shape_of);
|
||||||
|
auto reshape = std::make_shared<op::v1::Reshape>(data, abs, false);
|
||||||
|
model = std::make_shared<Model>(reshape, ParameterVector{data});
|
||||||
|
|
||||||
|
manager.register_pass<pass::EnableShapeOfConstantFolding>();
|
||||||
|
manager.register_pass<pass::ConstantFolding>();
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, 4, 10, 10});
|
||||||
|
auto shape = op::v0::Constant::create(element::i64, Shape{4}, {1, 4, 10, 10});
|
||||||
|
auto reshape = std::make_shared<op::v1::Reshape>(data, shape, false);
|
||||||
|
model_ref = std::make_shared<Model>(reshape, ParameterVector{data});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TransformationTestsF, EnableShapeOfV3ConstantFolding) {
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, 4, 10, 10});
|
||||||
|
auto shape_of = std::make_shared<op::v3::ShapeOf>(data);
|
||||||
|
pass::disable_constant_folding(shape_of);
|
||||||
|
auto abs = std::make_shared<op::v0::Abs>(shape_of);
|
||||||
|
auto reshape = std::make_shared<op::v1::Reshape>(data, abs, false);
|
||||||
|
model = std::make_shared<Model>(reshape, ParameterVector{data});
|
||||||
|
|
||||||
|
manager.register_pass<pass::EnableShapeOfConstantFolding>();
|
||||||
|
manager.register_pass<pass::ConstantFolding>();
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, 4, 10, 10});
|
||||||
|
auto shape = op::v0::Constant::create(element::i64, Shape{4}, {1, 4, 10, 10});
|
||||||
|
auto reshape = std::make_shared<op::v1::Reshape>(data, shape, false);
|
||||||
|
model_ref = std::make_shared<Model>(reshape, ParameterVector{data});
|
||||||
|
}
|
||||||
|
}
|
@ -46,3 +46,24 @@ TEST_F(CompileModelsTests, ModelWithSplitConvConcat) {
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(CompileModelsTests, ModelWithShapeOf) {
|
||||||
|
auto model = convert_model("shapeof_slice_abs/shapeof_slice_abs.pbtxt");
|
||||||
|
ov::Core core;
|
||||||
|
ov::CompiledModel compiled_model = core.compile_model(model, "CPU");
|
||||||
|
const auto runtime_model = compiled_model.get_runtime_model();
|
||||||
|
auto get_layer_type = [](const std::shared_ptr<ov::Node>& node) {
|
||||||
|
return node->get_rt_info().at(ExecGraphInfoSerialization::LAYER_TYPE).as<std::string>();
|
||||||
|
};
|
||||||
|
const auto ops = runtime_model->get_ops();
|
||||||
|
// one Input, one Eltwise and one Output
|
||||||
|
EXPECT_EQ(3, ops.size());
|
||||||
|
// ShapeOf is folded
|
||||||
|
EXPECT_EQ(0, std::count_if(ops.begin(), ops.end(), [&](const std::shared_ptr<ov::Node>& node) {
|
||||||
|
return get_layer_type(node) == "ShapeOf";
|
||||||
|
}));
|
||||||
|
// Slice is eliminated
|
||||||
|
EXPECT_EQ(0, std::count_if(ops.begin(), ops.end(), [&](const std::shared_ptr<ov::Node>& node) {
|
||||||
|
return get_layer_type(node) == "StridedSlice";
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
@ -0,0 +1,149 @@
|
|||||||
|
node {
|
||||||
|
name: "input"
|
||||||
|
op: "Placeholder"
|
||||||
|
attr {
|
||||||
|
key: "_output_shapes"
|
||||||
|
value {
|
||||||
|
list {
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
size: 10
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
size: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "shape"
|
||||||
|
value {
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
size: 10
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
size: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "start"
|
||||||
|
op: "Const"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_INT32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value {
|
||||||
|
tensor {
|
||||||
|
dtype: DT_INT32
|
||||||
|
tensor_shape {
|
||||||
|
dim {
|
||||||
|
size: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tensor_content: "\000\000\000\000\000\000\000\000"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "stop"
|
||||||
|
input: "input"
|
||||||
|
op: "Shape"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_INT32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "stride"
|
||||||
|
op: "Const"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_INT32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value {
|
||||||
|
tensor {
|
||||||
|
dtype: DT_INT32
|
||||||
|
tensor_shape {
|
||||||
|
dim {
|
||||||
|
size: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tensor_content: "\001\000\000\000\001\000\000\000"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "axes"
|
||||||
|
op: "Const"
|
||||||
|
attr {
|
||||||
|
key: "dtype"
|
||||||
|
value {
|
||||||
|
type: DT_INT32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value {
|
||||||
|
tensor {
|
||||||
|
dtype: DT_INT32
|
||||||
|
tensor_shape {
|
||||||
|
dim {
|
||||||
|
size: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tensor_content: "\000\000\000\000\001\000\000\000"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "slice"
|
||||||
|
op: "Slice"
|
||||||
|
input: "input"
|
||||||
|
input: "start"
|
||||||
|
input: "stop"
|
||||||
|
input: "stride"
|
||||||
|
input: "axes"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "abs"
|
||||||
|
op: "Abs"
|
||||||
|
input: "slice"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
library {
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user