From 22d7bc70d9c7de443c0a52eef1238d42877b766b Mon Sep 17 00:00:00 2001 From: Evgeny Kotov Date: Tue, 13 Dec 2022 17:08:06 +0100 Subject: [PATCH] Transpose sinking general (#13943) * initial * build fixes + couple of simple unit tests * remove old transpose_sinking_binary berfore merge with PR branch * initial * clang cleanup fixes * remove TrasposeAxis function; cleanup namespaces * fix TransposeInputsInfo spell * one_input_transpose spell * cleanup speel * spell * decompose forward sinking * decompose backward sink * use NodeVector * clang cleanup * decomposite transformations into different files * decompose unit tests * clang cleanup * fix transformation names in general transformation * fix ngraph::pass::TransposeFuse use element type the same as in fusing nodes * add checkout sinking ability function; check sinking ability for unary operations; add unit test on general transformation * sinking check for binary; unit tests; fixes * add check to concat transformation; unit test * add check to split tranformation * azure build fixes * add general test * cleanup tests using common class * clang cleanup * add transpose sinkig to moc * remove comment * fix after rebase * clang fixes * fix after rebase * code review fixes * fix after rebase * add RUN_ON_FUNCTION_SCOPE to general transformation * fixes after rebase * move tests to new directory * cleanup * use ov::RuntimeAttribute * move NoTransposeSinkingAttr to files * fix namespace * fix names --- .../transpose_sinking_general.hpp | 36 ++ .../transpose_sinking_utils.hpp | 2 + .../rt_info/transpose_sinking_attr.hpp | 32 ++ .../moc_transformations.cpp | 1 + .../transpose_sinking.cpp | 9 +- .../transpose_sinking_binary.cpp | 13 +- .../transpose_sinking_concat.cpp | 10 +- .../transpose_sinking_general.cpp | 56 +++ .../transpose_sinking_split.cpp | 5 + .../transpose_sinking_unary.cpp | 15 +- .../transpose_sinking_utils.cpp | 41 ++ .../rt_info/transpose_sinking_attr.cpp | 28 ++ .../transpose_sinking_general_test.cpp | 392 ++++++++++++++++++ 13 files changed, 635 insertions(+), 5 deletions(-) create mode 100644 src/common/transformations/include/transformations/common_optimizations/transpose_sinking_general.hpp create mode 100644 src/common/transformations/include/transformations/rt_info/transpose_sinking_attr.hpp create mode 100644 src/common/transformations/src/transformations/common_optimizations/transpose_sinking_general.cpp create mode 100644 src/common/transformations/src/transformations/rt_info/transpose_sinking_attr.cpp create mode 100644 src/common/transformations/tests/common_optimizations/transpose_sinking_general_test.cpp diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_general.hpp b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_general.hpp new file mode 100644 index 00000000000..dbd783a8aa2 --- /dev/null +++ b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_general.hpp @@ -0,0 +1,36 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +class TRANSFORMATIONS_API TransposeSinkingGeneralForward; +class TRANSFORMATIONS_API TransposeSinkingGeneralBackward; +class TRANSFORMATIONS_API TransposeSinkingGeneral; + +} // namespace pass +} // namespace ov + +class ov::pass::TransposeSinkingGeneralForward : public ov::pass::GraphRewrite { +public: + OPENVINO_RTTI("TransposeSinkingGeneralForward", "0"); + TransposeSinkingGeneralForward(); +}; + +class ov::pass::TransposeSinkingGeneralBackward : public ov::pass::GraphRewrite { +public: + OPENVINO_RTTI("TransposeSinkingGeneralBackward", "0"); + TransposeSinkingGeneralBackward(); +}; + +class ov::pass::TransposeSinkingGeneral : public ov::pass::ModelPass { +public: + OPENVINO_RTTI("TransposeSinkingGeneral", "0"); + bool run_on_model(const std::shared_ptr& m) override; +}; diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_utils.hpp b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_utils.hpp index 6ced0ceac88..7c4cc038dbf 100644 --- a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_utils.hpp +++ b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_utils.hpp @@ -47,4 +47,6 @@ ov::NodeVector InsertTransposeBeforeNode(std::shared_ptr main_node, std::shared_ptr transpose_const); } // namespace sink_backward +void UpdateForwardSinkingAbility(std::shared_ptr); + } // namespace transpose_sinking diff --git a/src/common/transformations/include/transformations/rt_info/transpose_sinking_attr.hpp b/src/common/transformations/include/transformations/rt_info/transpose_sinking_attr.hpp new file mode 100644 index 00000000000..7e79b59e2a7 --- /dev/null +++ b/src/common/transformations/include/transformations/rt_info/transpose_sinking_attr.hpp @@ -0,0 +1,32 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/node.hpp" +#include "openvino/core/runtime_attribute.hpp" +#include "transformations_visibility.hpp" + +namespace ov { + +TRANSFORMATIONS_API void mark_as_no_sinking_node(const std::shared_ptr& node); + +TRANSFORMATIONS_API bool is_sinking_node(const std::shared_ptr& node); +TRANSFORMATIONS_API bool is_sinking_node(const Node* node); + +/** + * @ingroup ie_runtime_attr_api + * @brief NoTransposeSinkingAttr class represents runtime info attribute that marks transpose + * operation should not be moved be backward sinking propagation. + */ +class TRANSFORMATIONS_API NoTransposeSinkingAttr : public RuntimeAttribute { +public: + OPENVINO_RTTI("no_transpose_sinking", "0"); + + bool is_copyable() const override { + return false; + } +}; + +} // namespace ov diff --git a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp index 77c23cfeb44..92b1859599a 100644 --- a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp @@ -66,6 +66,7 @@ #include #include #include +#include #include #include #include diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking.cpp index 5f9f057f57e..03479cbad03 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking.cpp @@ -14,6 +14,7 @@ #include #include "itt.hpp" +#include "transformations/common_optimizations/transpose_sinking_utils.hpp" #include "transformations/utils/utils.hpp" using namespace ov; @@ -294,15 +295,21 @@ ov::pass::TransposeFuse::TransposeFuse() { is_ordered = false; } + auto transpose_order_type = transpose1_order->get_element_type(); + if (transpose_order_type != transpose2_order->get_element_type()) + transpose_order_type = element::i64; + if (is_ordered) { return ngraph::replace_output_update_name(transpose2->output(0), input); } else { - auto new_order = opset7::Constant::create(element::i64, {order2.size()}, order2); + auto new_order = opset7::Constant::create(transpose_order_type, {order2.size()}, order2); auto new_transpose = register_new_node(input, new_order); new_transpose->set_friendly_name(m.get_match_root()->get_friendly_name()); ngraph::copy_runtime_info({transpose1, transpose2}, new_transpose); ngraph::replace_node(m.get_match_root(), new_transpose); + + transpose_sinking::UpdateForwardSinkingAbility(new_transpose); } return true; diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_binary.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_binary.cpp index 6326f5fb9f4..6441e1adb53 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_binary.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_binary.cpp @@ -1,5 +1,6 @@ #include "transformations/common_optimizations/transpose_sinking_binary.hpp" +#include #include #include #include @@ -12,6 +13,7 @@ #include "openvino/util/common_util.hpp" #include "openvino/util/log.hpp" #include "transformations/common_optimizations/transpose_sinking_utils.hpp" +#include "transformations/rt_info/transpose_sinking_attr.hpp" using namespace ov::pass::pattern; using namespace ov; @@ -34,6 +36,7 @@ ov::pass::TransposeSinkingBinaryElementwiseForward::TransposeSinkingBinaryElemen sink_forward::UpdateInputTransposes(main_node, transpose_input_info); for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { register_new_node(new_node); + transpose_sinking::UpdateForwardSinkingAbility(new_node); } return true; @@ -43,13 +46,19 @@ ov::pass::TransposeSinkingBinaryElementwiseForward::TransposeSinkingBinaryElemen register_matcher(m, matcher_pass_callback); } -ov::pass::TransposeSinkingBinaryElementwiseBackward::TransposeSinkingBinaryElementwiseBackward() { +pass::TransposeSinkingBinaryElementwiseBackward::TransposeSinkingBinaryElementwiseBackward() { MATCHER_SCOPE(TransposeSinkingBinaryElementwiseBackward); auto main_node_label = wrap_type(consumers_count(1)); auto transpose_const_label = wrap_type(consumers_count(1)); - auto transpose_label = wrap_type({main_node_label, transpose_const_label}, consumers_count(1)); + + auto IfSinkingEnabled = [](const Output& output) -> bool { + static auto consumers_check = consumers_count(1); + return consumers_check(output) && is_sinking_node(output.get_node_shared_ptr()); + }; + + auto transpose_label = wrap_type({main_node_label, transpose_const_label}, IfSinkingEnabled); matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { const auto& pattern_to_output = m.get_pattern_value_map(); diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_concat.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_concat.cpp index 39d255af05f..c6abd1d9c31 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_concat.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_concat.cpp @@ -13,6 +13,7 @@ #include "openvino/util/common_util.hpp" #include "openvino/util/log.hpp" #include "transformations/common_optimizations/transpose_sinking_utils.hpp" +#include "transformations/rt_info/transpose_sinking_attr.hpp" using namespace ov::pass::pattern; using namespace ov; @@ -35,6 +36,7 @@ ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() { sink_forward::UpdateInputTransposes(main_node, transpose_input_info); for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { register_new_node(new_node); + transpose_sinking::UpdateForwardSinkingAbility(new_node); } auto concat_node = as_type_ptr(main_node); @@ -55,7 +57,13 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() { auto main_node_label = wrap_type(consumers_count(1)); auto transpose_const_label = wrap_type(consumers_count(1)); - auto transpose_label = wrap_type({main_node_label, transpose_const_label}, consumers_count(1)); + + auto IfSinkingEnabled = [](const Output& output) -> bool { + static auto consumers_check = consumers_count(1); + return consumers_check(output) && is_sinking_node(output.get_node_shared_ptr()); + }; + + auto transpose_label = wrap_type({main_node_label, transpose_const_label}, IfSinkingEnabled); matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { const auto& pattern_to_output = m.get_pattern_value_map(); diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_general.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_general.cpp new file mode 100644 index 00000000000..bab7e211f6e --- /dev/null +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_general.cpp @@ -0,0 +1,56 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/common_optimizations/transpose_sinking_general.hpp" + +#include +#include +#include +#include +#include + +#include "itt.hpp" +#include "transformations/common_optimizations/transpose_sinking.hpp" +#include "transformations/common_optimizations/transpose_sinking_binary.hpp" +#include "transformations/common_optimizations/transpose_sinking_concat.hpp" +#include "transformations/common_optimizations/transpose_sinking_split.hpp" +#include "transformations/common_optimizations/transpose_sinking_unary.hpp" +#include "transformations/utils/utils.hpp" + +ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() { + MATCHER_SCOPE(TransposeSinkingGeneralForward); + add_matcher(); + add_matcher(); + add_matcher(); + add_matcher(); + add_matcher(); +} + +ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() { + MATCHER_SCOPE(TransposeSinkingGeneralBackward); + add_matcher(); + add_matcher(); + add_matcher(); + add_matcher(); + add_matcher(); +} + +bool ov::pass::TransposeSinkingGeneral::run_on_model(const std::shared_ptr& f) { + RUN_ON_FUNCTION_SCOPE(TransposeSinkingGeneral); + { + ngraph::pass::Manager manager(get_pass_config()); + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + } + + { + ngraph::pass::Manager manager(get_pass_config()); + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + } + + return false; +} diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_split.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_split.cpp index 8da19612a59..1171bd0aeb4 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_split.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_split.cpp @@ -13,6 +13,7 @@ #include "openvino/util/common_util.hpp" #include "openvino/util/log.hpp" #include "transformations/common_optimizations/transpose_sinking_utils.hpp" +#include "transformations/rt_info/transpose_sinking_attr.hpp" using namespace ov::pass::pattern; using namespace ov; @@ -66,6 +67,9 @@ std::shared_ptr GetTransposeConstant(Input input) { if (!transpose_node) return {}; + if (!is_sinking_node(input.get_node())) + return {}; + auto constant_node = as_type_ptr(transpose_node->input_value(1).get_node_shared_ptr()); if (!constant_node) return {}; @@ -181,6 +185,7 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() { sink_forward::RemoveZeroInputNode(main_node); for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { register_new_node(new_node); + transpose_sinking::UpdateForwardSinkingAbility(new_node); } const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val(); diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_unary.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_unary.cpp index 85b3deb22ea..0eebe9be309 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_unary.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_unary.cpp @@ -6,6 +6,10 @@ #include "itt.hpp" #include "openvino/opsets/opset9.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/common_optimizations/transpose_sinking_utils.hpp" +#include "transformations/rt_info/transpose_sinking_attr.hpp" + +using namespace ov; namespace { @@ -105,6 +109,8 @@ ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() { register_new_node(new_nodes.first); register_new_node(new_nodes.second); + transpose_sinking::UpdateForwardSinkingAbility(new_nodes.second); + return true; }; @@ -112,6 +118,12 @@ ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() { register_matcher(m, matcher_pass_callback); } +namespace { +bool IfSinkingEnabled(const Output& output) { + return is_sinking_node(output.get_node_shared_ptr()); +} +} // namespace + ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() { MATCHER_SCOPE(TransposeSinkingUnaryBackward); @@ -123,7 +135,8 @@ ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() { ov::opset9::Convert>({ov::pass::pattern::any_input()}); auto transpose_label = - ov::pass::pattern::wrap_type({unary_label, ov::pass::pattern::any_input()}); + ov::pass::pattern::wrap_type({unary_label, ov::pass::pattern::any_input()}, + IfSinkingEnabled); ov::matcher_pass_callback matcher_pass_callback = [=](ov::pass::pattern::Matcher& m) { const auto& pattern_to_output = m.get_pattern_value_map(); diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_utils.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_utils.cpp index 4d7579ce3d6..69e3b9cecd8 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_utils.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_utils.cpp @@ -11,6 +11,7 @@ #include "openvino/pass/pattern/op/wrap_type.hpp" #include "openvino/util/common_util.hpp" #include "openvino/util/log.hpp" +#include "transformations/rt_info/transpose_sinking_attr.hpp" namespace transpose_sinking { @@ -169,4 +170,44 @@ NodeVector InsertTransposeBeforeNode(NodePtr main_node, std::shared_ptr(node)) { \ + return true; \ + } + +namespace { + +bool CanPropagateForwardThrough(Node* node) { + CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::util::UnaryElementwiseArithmetic, node); + CHECK_TRANSPOSE_SINKING_SUPPORTED(Clamp, node); + CHECK_TRANSPOSE_SINKING_SUPPORTED(Elu, node); + CHECK_TRANSPOSE_SINKING_SUPPORTED(SoftPlus, node); + CHECK_TRANSPOSE_SINKING_SUPPORTED(LogicalNot, node); + CHECK_TRANSPOSE_SINKING_SUPPORTED(Convert, node); + CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::util::BinaryElementwiseArithmetic, node); + CHECK_TRANSPOSE_SINKING_SUPPORTED(Concat, node); + CHECK_TRANSPOSE_SINKING_SUPPORTED(Split, node); + CHECK_TRANSPOSE_SINKING_SUPPORTED(Transpose, node); + + return false; +} + +bool CanPropagateForward(NodePtr node) { + for (size_t i = 0; i < node->get_output_size(); ++i) { + for (auto& consumer_input : node->output(i).get_target_inputs()) { + if (!CanPropagateForwardThrough(consumer_input.get_node())) + return false; + } + } + + return true; +} + +} // namespace + +void UpdateForwardSinkingAbility(NodePtr node) { + if (!CanPropagateForward(node)) + mark_as_no_sinking_node(node); +} + } // namespace transpose_sinking diff --git a/src/common/transformations/src/transformations/rt_info/transpose_sinking_attr.cpp b/src/common/transformations/src/transformations/rt_info/transpose_sinking_attr.cpp new file mode 100644 index 00000000000..587b2676e93 --- /dev/null +++ b/src/common/transformations/src/transformations/rt_info/transpose_sinking_attr.cpp @@ -0,0 +1,28 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/rt_info/transpose_sinking_attr.hpp" + +using namespace ov; + +void ov::mark_as_no_sinking_node(const std::shared_ptr& node) { + auto& rt_info = node->get_rt_info(); + rt_info[NoTransposeSinkingAttr::get_type_info_static()] = NoTransposeSinkingAttr(); +} + +namespace { +template +bool is_sinking_node_private(NodePtr node) { + const auto& rt_info = node->get_rt_info(); + return rt_info.find(NoTransposeSinkingAttr::get_type_info_static()) == rt_info.end(); +} +} // namespace + +bool ov::is_sinking_node(const std::shared_ptr& node) { + return is_sinking_node_private(node); +} + +bool ov::is_sinking_node(const Node* node) { + return is_sinking_node_private(node); +} diff --git a/src/common/transformations/tests/common_optimizations/transpose_sinking_general_test.cpp b/src/common/transformations/tests/common_optimizations/transpose_sinking_general_test.cpp new file mode 100644 index 00000000000..1fd26afd508 --- /dev/null +++ b/src/common/transformations/tests/common_optimizations/transpose_sinking_general_test.cpp @@ -0,0 +1,392 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include + +#include + +#include +#include "common_test_utils/ngraph_test_utils.hpp" + +#include + +#include "gtest/gtest.h" + +using namespace testing; + +using NodePtr = std::shared_ptr; + +TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesForward) { + ov::Shape input_shape = {1, 96, 55, 55}; + ov::element::Type input_type = ov::element::f32; + size_t num_unary_ops = 10; + + { + auto X = std::make_shared(input_type, input_shape); + + NodePtr in_op = X; + for (size_t i = 0; i < num_unary_ops; ++i) { + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); + + auto unary = std::make_shared(transpose0); + + auto ng_order1 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + in_op = std::make_shared(unary, ng_order1); + } + + function = std::make_shared(in_op, ov::ParameterVector{X}); + } + + { + auto X = std::make_shared(input_type, input_shape); + + NodePtr in_op = X; + for (size_t i = 0; i < num_unary_ops; ++i) { + in_op = std::make_shared(in_op); + } + + function_ref = std::make_shared(in_op, ov::ParameterVector{X}); + } + + manager.register_pass(); +} + +TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesBackward) { + ov::Shape input_shape = {1, 96, 55, 55}; + ov::element::Type input_type = ov::element::f32; + size_t num_unary_ops = 10; + + { + auto X = std::make_shared(input_type, input_shape); + + NodePtr in_op = X; + for (size_t i = 0; i < num_unary_ops; ++i) { + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); + + auto unary = std::make_shared(transpose0); + + auto ng_order1 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + in_op = std::make_shared(unary, ng_order1); + } + + function = std::make_shared(in_op, ov::ParameterVector{X}); + } + + { + auto X = std::make_shared(input_type, input_shape); + + NodePtr in_op = X; + for (size_t i = 0; i < num_unary_ops; ++i) { + in_op = std::make_shared(in_op); + } + + function_ref = std::make_shared(in_op, ov::ParameterVector{X}); + } + + manager.register_pass(); +} + +TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesGeneral) { + ov::Shape input_shape = {1, 96, 55, 55}; + ov::element::Type input_type = ov::element::f32; + size_t num_unary_ops = 10; + + { + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + NodePtr in_op = transpose0; + for (size_t i = 0; i < num_unary_ops; ++i) { + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); + + auto unary = std::make_shared(transpose0); + + auto ng_order1 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + in_op = std::make_shared(unary, ng_order1); + } + + function = std::make_shared(in_op, ov::ParameterVector{X}); + } + + { + auto X = std::make_shared(input_type, input_shape); + + NodePtr in_op = X; + for (size_t i = 0; i < num_unary_ops; ++i) { + in_op = std::make_shared(in_op); + } + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); + + function_ref = std::make_shared(transpose0, ov::ParameterVector{X}); + } + + manager.register_pass(); +} + +TEST_F(TransformationTestsF, TransposeSinkingGeneralTestBinaryGeneral) { + ov::Shape input_shape = {1, 96, 55, 55}; + ov::element::Type input_type = ov::element::f32; + size_t num_binary_ops = 10; + + { + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + NodePtr in_op = transpose0; + for (size_t i = 0; i < num_binary_ops; ++i) { + auto in_constant = std::make_shared(input_type, input_shape, ov::Shape{1}); + auto ng_order1 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose1 = std::make_shared(in_constant, ng_order1); + + in_op = std::make_shared(in_op, transpose1); + } + + function = std::make_shared(in_op, ov::ParameterVector{X}); + } + + { + auto X = std::make_shared(input_type, input_shape); + + NodePtr in_op = X; + for (size_t i = 0; i < num_binary_ops; ++i) { + auto in_constant = std::make_shared(input_type, input_shape, ov::Shape{1}); + in_op = std::make_shared(in_op, in_constant); + } + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); + + function_ref = std::make_shared(transpose0, ov::ParameterVector{X}); + } + + manager.register_pass(); +} + +TEST_F(TransformationTestsF, TransposeSinkingGeneralTestConcatGeneral) { + ov::Shape input_shape = {1, 96, 55, 55}; + ov::element::Type input_type = ov::element::f32; + const size_t num_concat_ops = 3; + const size_t num_concat_inputs = 2; + + { + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + NodePtr in_op = transpose0; + for (size_t i = 0; i < num_concat_ops; ++i) { + ov::OutputVector concat_inputs; + concat_inputs.push_back(in_op); + for (size_t j = 1; j < num_concat_inputs; ++j) { + auto in_constant = std::make_shared(input_type, input_shape, ov::Shape{1}); + auto ng_order1 = + std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose1 = std::make_shared(in_constant, ng_order1); + concat_inputs.push_back(transpose1); + } + in_op = std::make_shared(concat_inputs, 1); + } + + function = std::make_shared(in_op, ov::ParameterVector{X}); + } + + { + auto X = std::make_shared(input_type, input_shape); + + NodePtr in_op = X; + for (size_t i = 0; i < num_concat_ops; ++i) { + ov::OutputVector concat_inputs; + + concat_inputs.push_back(in_op); + + for (size_t j = 1; j < num_concat_inputs; ++j) { + auto in_constant = std::make_shared(input_type, input_shape, ov::Shape{1}); + concat_inputs.push_back(in_constant); + } + in_op = std::make_shared(concat_inputs, 2); + } + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); + + function_ref = std::make_shared(transpose0, ov::ParameterVector{X}); + } + + manager.register_pass(); +} + +// ---------------------------------------------------------------------------------------------------------------------- + +class IFactory { +public: + virtual ~IFactory() = default; + virtual NodePtr create(const ov::OutputVector & parent) = 0; + + virtual size_t getNumInputs() const = 0; + virtual size_t getNumOuputs() const = 0; +}; + +using FactoryPtr = std::shared_ptr; + +class UnaryFactory : public IFactory { +public: + NodePtr create(const ov::OutputVector & parent) override { + return std::make_shared(parent.front()); + } + + static FactoryPtr createFactory() { + return std::make_shared(); + } + + size_t getNumInputs() const override { return 1; } + size_t getNumOuputs() const override { return 1; } +}; + +class BinaryFactory : public IFactory { +public: + NodePtr create(const ov::OutputVector & parent) override { + return std::make_shared(parent[0], parent[1]); + } + + static FactoryPtr createFactory() { + return std::make_shared(); + } + + size_t getNumInputs() const override { return 2; } + size_t getNumOuputs() const override { return 1; } +}; + +class SplitFactory : public IFactory { +public: + SplitFactory(size_t axis) : axis_(axis) {} + NodePtr create(const ov::OutputVector & parent) override { + auto split_axis_const = std::make_shared(ov::element::u64, + ov::Shape{}, + axis_); + return std::make_shared(parent.front(), split_axis_const, 2); + } + + static FactoryPtr createFactory(size_t axis) { + return std::make_shared(axis); + } + + size_t getNumInputs() const override { return 1; } + size_t getNumOuputs() const override { return 2; } +private: + const size_t axis_; +}; + +class ConcatFactory : public IFactory { +public: + ConcatFactory(size_t axis) : axis_(axis) {} + NodePtr create(const ov::OutputVector & parent) override { + return std::make_shared(parent, axis_); + } + + static FactoryPtr createFactory(size_t axis) { + return std::make_shared(axis); + } + + size_t getNumInputs() const override { return 2; } + size_t getNumOuputs() const override { return 1; } +private: + const size_t axis_; +}; + +/* + Each node pair should be started with input size = 1 node and finished with node output size = 1 + Insert Split/Concat to fullfill that. +*/ +NodePtr CreateNodePair(FactoryPtr factory_first, FactoryPtr factory_second, NodePtr parent, size_t split_axis, size_t concat_axis) { + NodePtr input = parent; + if (factory_first->getNumInputs() != 1) { + input = SplitFactory(split_axis).create(input->outputs()); + } + + input = factory_first->create(input->outputs()); + if (factory_first->getNumOuputs() < factory_second->getNumInputs()) { + input = SplitFactory(split_axis).create(input->outputs()); + } else if (factory_first->getNumOuputs() > factory_second->getNumInputs()) { + input = ConcatFactory(concat_axis).create(input->outputs()); + } + + auto output = factory_second->create(input->outputs()); + if (output->get_output_size() > 1) { + output = ConcatFactory(concat_axis).create(output->outputs()); + } + + return output; +} + +NodePtr MakeAllNodesSubgraph(NodePtr parent, size_t split_axis, size_t concat_axis) { + std::vector factories = { UnaryFactory::createFactory(), + SplitFactory::createFactory(split_axis), + ConcatFactory::createFactory(concat_axis) }; + NodePtr in_op = parent; + for (int i = 0; i < factories.size(); ++i) { + for (int j = 0; j < factories.size(); ++j) { + in_op = CreateNodePair(factories[i], factories[j], in_op, split_axis, concat_axis); + } + } + + return in_op; +} + +TEST_F(TransformationTestsF, TransposeSinkingGeneralTestMultipleTypes) { + ov::Shape input_shape = {1, 96, 40, 55}; + ov::element::Type input_type = ov::element::f32; + + { + auto X = std::make_shared(input_type, input_shape); + + auto node0 = MakeAllNodesSubgraph(X, 1, 1); + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(node0, ng_order0); + + auto reshape_const = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96}); + auto reshape = std::make_shared(transpose0, reshape_const, false); + + auto ng_order1 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose1 = std::make_shared(reshape, ng_order1); + + auto node1 = MakeAllNodesSubgraph(transpose1, 1, 1); + + function = std::make_shared(node1, ov::ParameterVector{X}); + } + + { + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + auto node0 = MakeAllNodesSubgraph(transpose0, 3, 3); + + auto reshape_const = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96}); + auto reshape = std::make_shared(node0, reshape_const, false); + + auto node1 = MakeAllNodesSubgraph(reshape, 3, 3); + + auto ng_order1 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose1 = std::make_shared(node1, ng_order1); + + function_ref = std::make_shared(transpose1, ov::ParameterVector{X}); + } + + manager.register_pass(); +}