diff --git a/src/common/transformations/include/transformations/common_optimizations/enable_shapeof_constant_folding.hpp b/src/common/transformations/include/transformations/common_optimizations/enable_shapeof_constant_folding.hpp new file mode 100644 index 00000000000..1cfb90a77bd --- /dev/null +++ b/src/common/transformations/include/transformations/common_optimizations/enable_shapeof_constant_folding.hpp @@ -0,0 +1,25 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +/** + * @ingroup ie_transformation_common_api + * @brief This transformation enables constfoldability for ShapeOf nodes that was + * disabled by DisableShapeOfConstantFolding. + */ +class TRANSFORMATIONS_API EnableShapeOfConstantFolding : public MatcherPass { +public: + OPENVINO_RTTI("EnableShapeOfConstantFolding", "0"); + EnableShapeOfConstantFolding(); +}; + +} // namespace pass +} // namespace ov diff --git a/src/common/transformations/src/transformations/common_optimizations/enable_shapeof_constant_folding.cpp b/src/common/transformations/src/transformations/common_optimizations/enable_shapeof_constant_folding.cpp new file mode 100644 index 00000000000..3db1d28f7fc --- /dev/null +++ b/src/common/transformations/src/transformations/common_optimizations/enable_shapeof_constant_folding.cpp @@ -0,0 +1,24 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/common_optimizations/enable_shapeof_constant_folding.hpp" + +#include "openvino/op/util/shape_of_base.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/rt_info/disable_constant_folding.hpp" + +ov::pass::EnableShapeOfConstantFolding::EnableShapeOfConstantFolding() { + auto shape_of = pattern::wrap_type([=](const Output& output) { + const auto& shape = output.get_partial_shape(); + return shape.is_dynamic() || shape_size(shape.get_shape()) != 1; + }); + + matcher_pass_callback callback = [=](pattern::Matcher& m) { + enable_constant_folding(m.get_match_root()); + return true; + }; + + auto m = std::make_shared(shape_of, "EnableShapeOfConstantFolding"); + this->register_matcher(m, callback); +} diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_general.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_general.cpp index ccbef0501f0..4b4a0835a9d 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_general.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_general.cpp @@ -10,6 +10,7 @@ #include "openvino/pass/manager.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "transformations/common_optimizations/disable_shapeof_constant_folding.hpp" +#include "transformations/common_optimizations/enable_shapeof_constant_folding.hpp" #include "transformations/transpose_sinking/ts_binary.hpp" #include "transformations/transpose_sinking/ts_concat.hpp" #include "transformations/transpose_sinking/ts_data_movement.hpp" @@ -77,6 +78,7 @@ bool TSGeneral::run_on_model(const std::shared_ptr& f) { manager.register_pass(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); manager.run_passes(f); } diff --git a/src/common/transformations/tests/common_optimizations/enable_shapeof_constant_folding.cpp b/src/common/transformations/tests/common_optimizations/enable_shapeof_constant_folding.cpp new file mode 100644 index 00000000000..a771ca7332c --- /dev/null +++ b/src/common/transformations/tests/common_optimizations/enable_shapeof_constant_folding.cpp @@ -0,0 +1,60 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/common_optimizations/enable_shapeof_constant_folding.hpp" + +#include + +#include "common_test_utils/ov_test_utils.hpp" +#include "openvino/core/model.hpp" +#include "openvino/op/abs.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/pass/constant_folding.hpp" +#include "openvino/pass/manager.hpp" + +using namespace testing; +using namespace ov; + +TEST_F(TransformationTestsF, EnableShapeOfV0ConstantFolding) { + { + auto data = std::make_shared(element::f32, Shape{1, 4, 10, 10}); + auto shape_of = std::make_shared(data); + pass::disable_constant_folding(shape_of); + auto abs = std::make_shared(shape_of); + auto reshape = std::make_shared(data, abs, false); + model = std::make_shared(reshape, ParameterVector{data}); + + manager.register_pass(); + manager.register_pass(); + } + + { + auto data = std::make_shared(element::f32, Shape{1, 4, 10, 10}); + auto shape = op::v0::Constant::create(element::i64, Shape{4}, {1, 4, 10, 10}); + auto reshape = std::make_shared(data, shape, false); + model_ref = std::make_shared(reshape, ParameterVector{data}); + } +} + +TEST_F(TransformationTestsF, EnableShapeOfV3ConstantFolding) { + { + auto data = std::make_shared(element::f32, Shape{1, 4, 10, 10}); + auto shape_of = std::make_shared(data); + pass::disable_constant_folding(shape_of); + auto abs = std::make_shared(shape_of); + auto reshape = std::make_shared(data, abs, false); + model = std::make_shared(reshape, ParameterVector{data}); + + manager.register_pass(); + manager.register_pass(); + } + + { + auto data = std::make_shared(element::f32, Shape{1, 4, 10, 10}); + auto shape = op::v0::Constant::create(element::i64, Shape{4}, {1, 4, 10, 10}); + auto reshape = std::make_shared(data, shape, false); + model_ref = std::make_shared(reshape, ParameterVector{data}); + } +} diff --git a/src/frontends/tensorflow/tests/compilation.cpp b/src/frontends/tensorflow/tests/compilation.cpp index 29bb77c7805..fbb6ad94430 100644 --- a/src/frontends/tensorflow/tests/compilation.cpp +++ b/src/frontends/tensorflow/tests/compilation.cpp @@ -46,3 +46,24 @@ TEST_F(CompileModelsTests, ModelWithSplitConvConcat) { })); } } + +TEST_F(CompileModelsTests, ModelWithShapeOf) { + auto model = convert_model("shapeof_slice_abs/shapeof_slice_abs.pbtxt"); + ov::Core core; + ov::CompiledModel compiled_model = core.compile_model(model, "CPU"); + const auto runtime_model = compiled_model.get_runtime_model(); + auto get_layer_type = [](const std::shared_ptr& node) { + return node->get_rt_info().at(ExecGraphInfoSerialization::LAYER_TYPE).as(); + }; + const auto ops = runtime_model->get_ops(); + // one Input, one Eltwise and one Output + EXPECT_EQ(3, ops.size()); + // ShapeOf is folded + EXPECT_EQ(0, std::count_if(ops.begin(), ops.end(), [&](const std::shared_ptr& node) { + return get_layer_type(node) == "ShapeOf"; + })); + // Slice is eliminated + EXPECT_EQ(0, std::count_if(ops.begin(), ops.end(), [&](const std::shared_ptr& node) { + return get_layer_type(node) == "StridedSlice"; + })); +} diff --git a/src/frontends/tensorflow/tests/test_models/models_pbtxt/shapeof_slice_abs.pbtxt b/src/frontends/tensorflow/tests/test_models/models_pbtxt/shapeof_slice_abs.pbtxt new file mode 100644 index 00000000000..0811447825e --- /dev/null +++ b/src/frontends/tensorflow/tests/test_models/models_pbtxt/shapeof_slice_abs.pbtxt @@ -0,0 +1,149 @@ +node { + name: "input" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + dim { + size: 4 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + dim { + size: 4 + } + } + } + } +} +node { + name: "start" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\000\000\000\000\000\000\000\000" + } + } + } +} +node { + name: "stop" + input: "input" + op: "Shape" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } +} +node { + name: "stride" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\001\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "axes" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\000\000\000\000\001\000\000\000" + } + } + } +} +node { + name: "slice" + op: "Slice" + input: "input" + input: "start" + input: "stop" + input: "stride" + input: "axes" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "abs" + op: "Abs" + input: "slice" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +library { +}