From 0846bdb67e47ee453af24996764798cc238be49e Mon Sep 17 00:00:00 2001 From: Evgeny Kotov Date: Mon, 21 Nov 2022 13:24:26 +0100 Subject: [PATCH] Transpose sinking on passes for Concat, Split and binary eltwise operations (#13718) * initial * clang cleanup fixes * remove TrasposeAxis function; cleanup namespaces * fix TransposeInputsInfo spell * one_input_transpose spell * cleanup speel * spell * decompose forward sinking * decompose backward sink * use NodeVector * clang cleanup * decomposite transformations into different files * decompose unit tests * clang cleanup * azure build fixes * code review fixes * clang cleanup fixes --- .../transpose_sinking_binary.hpp | 30 + .../transpose_sinking_concat.hpp | 30 + .../transpose_sinking_split.hpp | 30 + .../transpose_sinking_utils.hpp | 50 ++ .../transpose_sinking_binary.cpp | 74 +++ .../transpose_sinking_concat.cpp | 86 +++ .../transpose_sinking_split.cpp | 200 +++++++ .../transpose_sinking_utils.cpp | 172 ++++++ .../transpose_sinking_binary_test.cpp | 358 ++++++++++++ .../transpose_sinking_concat_test.cpp | 400 +++++++++++++ .../transpose_sinking_split_test.cpp | 552 ++++++++++++++++++ 11 files changed, 1982 insertions(+) create mode 100644 src/common/transformations/include/transformations/common_optimizations/transpose_sinking_binary.hpp create mode 100644 src/common/transformations/include/transformations/common_optimizations/transpose_sinking_concat.hpp create mode 100644 src/common/transformations/include/transformations/common_optimizations/transpose_sinking_split.hpp create mode 100644 src/common/transformations/include/transformations/common_optimizations/transpose_sinking_utils.hpp create mode 100644 src/common/transformations/src/transformations/common_optimizations/transpose_sinking_binary.cpp create mode 100644 src/common/transformations/src/transformations/common_optimizations/transpose_sinking_concat.cpp create mode 100644 src/common/transformations/src/transformations/common_optimizations/transpose_sinking_split.cpp create mode 100644 src/common/transformations/src/transformations/common_optimizations/transpose_sinking_utils.cpp create mode 100644 src/tests/functional/inference_engine/transformations/common_optimizations/transpose_sinking_binary_test.cpp create mode 100644 src/tests/functional/inference_engine/transformations/common_optimizations/transpose_sinking_concat_test.cpp create mode 100644 src/tests/functional/inference_engine/transformations/common_optimizations/transpose_sinking_split_test.cpp diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_binary.hpp b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_binary.hpp new file mode 100644 index 00000000000..663139d8068 --- /dev/null +++ b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_binary.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +class TRANSFORMATIONS_API TransposeSinkingBinaryElementwiseForward; +class TRANSFORMATIONS_API TransposeSinkingBinaryElementwiseBackward; + +} // namespace pass +} // namespace ov + +class ov::pass::TransposeSinkingBinaryElementwiseForward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryElementwiseForward", "0"); + TransposeSinkingBinaryElementwiseForward(); +}; + +class ov::pass::TransposeSinkingBinaryElementwiseBackward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryElementwiseBackward", "0"); + TransposeSinkingBinaryElementwiseBackward(); +}; diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_concat.hpp b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_concat.hpp new file mode 100644 index 00000000000..709c2b7b72c --- /dev/null +++ b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_concat.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +class TRANSFORMATIONS_API TransposeSinkingConcatForward; +class TRANSFORMATIONS_API TransposeSinkingConcatBackward; + +} // namespace pass +} // namespace ov + +class ov::pass::TransposeSinkingConcatForward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TransposeSinkingConcatForward", "0"); + TransposeSinkingConcatForward(); +}; + +class ov::pass::TransposeSinkingConcatBackward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TransposeSinkingConcatBackward", "0"); + TransposeSinkingConcatBackward(); +}; diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_split.hpp b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_split.hpp new file mode 100644 index 00000000000..8ae3d6de607 --- /dev/null +++ b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_split.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +class TRANSFORMATIONS_API TransposeSinkingSplitBackward; +class TRANSFORMATIONS_API TransposeSinkingSplitForward; + +} // namespace pass +} // namespace ov + +class ov::pass::TransposeSinkingSplitForward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TransposeSinkingSplitForward", "0"); + TransposeSinkingSplitForward(); +}; + +class ov::pass::TransposeSinkingSplitBackward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TransposeSinkingSplitBackward", "0"); + TransposeSinkingSplitBackward(); +}; diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_utils.hpp b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_utils.hpp new file mode 100644 index 00000000000..6ced0ceac88 --- /dev/null +++ b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_utils.hpp @@ -0,0 +1,50 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +#include "itt.hpp" +#include "openvino/op/util/op_types.hpp" +#include "openvino/opsets/opset9.hpp" +#include "openvino/pass/pattern/op/label.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "openvino/util/common_util.hpp" +#include "openvino/util/log.hpp" + +namespace transpose_sinking { + +struct TransposeInputsInfo { + std::shared_ptr transpose; + std::shared_ptr transpose_const; + size_t input_idx; + + bool isEmpty() const { + return !transpose || !transpose_const; + } +}; + +TransposeInputsInfo GetFirstTransposeInput(std::shared_ptr node); +bool IfNodeHasTransposeInputs(const ov::Output& output); +ov::AxisVector ReverseTransposeOrder(const ov::AxisVector& axis_order); +void SwapOutputNames(ov::Output output1, ov::Output output2); +void SwapFriendlyNames(std::shared_ptr node1, std::shared_ptr node2); +void SwapNames(std::shared_ptr node1, std::shared_ptr node2); + +namespace sink_forward { +// insert input reversed transposes, remove first input tranpose +void UpdateInputTransposes(std::shared_ptr main_node, TransposeInputsInfo& transpose_input_info); +void RemoveZeroInputNode(std::shared_ptr main_node); +ov::NodeVector InsertOutputTransposes(std::shared_ptr main_node, TransposeInputsInfo& transpose_input_info); +} // namespace sink_forward + +namespace sink_backward { +ov::NodeVector InsertTransposeBeforeNode(std::shared_ptr main_node, + std::shared_ptr transpose_const); +} // namespace sink_backward + +} // namespace transpose_sinking diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_binary.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_binary.cpp new file mode 100644 index 00000000000..6326f5fb9f4 --- /dev/null +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_binary.cpp @@ -0,0 +1,74 @@ +#include "transformations/common_optimizations/transpose_sinking_binary.hpp" + +#include +#include +#include + +#include "itt.hpp" +#include "openvino/op/util/op_types.hpp" +#include "openvino/opsets/opset9.hpp" +#include "openvino/pass/pattern/op/label.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "openvino/util/common_util.hpp" +#include "openvino/util/log.hpp" +#include "transformations/common_optimizations/transpose_sinking_utils.hpp" + +using namespace ov::pass::pattern; +using namespace ov; +using namespace ov::opset9; +using namespace transpose_sinking; + +ov::pass::TransposeSinkingBinaryElementwiseForward::TransposeSinkingBinaryElementwiseForward() { + MATCHER_SCOPE(TransposeSinkingBinaryElementwiseForward); + + auto main_node_label = wrap_type(IfNodeHasTransposeInputs); + + matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { + const auto& pattern_to_output = m.get_pattern_value_map(); + + auto& main_node_output = pattern_to_output.at(main_node_label); + auto main_node = main_node_output.get_node_shared_ptr(); + + TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node); + + sink_forward::UpdateInputTransposes(main_node, transpose_input_info); + for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { + register_new_node(new_node); + } + + return true; + }; + + auto m = std::make_shared(main_node_label, matcher_name); + register_matcher(m, matcher_pass_callback); +} + +ov::pass::TransposeSinkingBinaryElementwiseBackward::TransposeSinkingBinaryElementwiseBackward() { + MATCHER_SCOPE(TransposeSinkingBinaryElementwiseBackward); + + auto main_node_label = wrap_type(consumers_count(1)); + + auto transpose_const_label = wrap_type(consumers_count(1)); + auto transpose_label = wrap_type({main_node_label, transpose_const_label}, consumers_count(1)); + + matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { + const auto& pattern_to_output = m.get_pattern_value_map(); + auto transpose_const = as_type_ptr(pattern_to_output.at(transpose_const_label).get_node_shared_ptr()); + auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); + auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr(); + + for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) { + register_new_node(new_node); + } + + // remove transpose after main node + transpose->output(0).replace(main_node); + + SwapNames(transpose, main_node); + + return true; + }; + + auto m = std::make_shared(transpose_label, matcher_name); + register_matcher(m, matcher_pass_callback); +} diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_concat.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_concat.cpp new file mode 100644 index 00000000000..39d255af05f --- /dev/null +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_concat.cpp @@ -0,0 +1,86 @@ +#include "transformations/common_optimizations/transpose_sinking_concat.hpp" + +#include +#include +#include +#include + +#include "itt.hpp" +#include "openvino/op/util/op_types.hpp" +#include "openvino/opsets/opset9.hpp" +#include "openvino/pass/pattern/op/label.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "openvino/util/common_util.hpp" +#include "openvino/util/log.hpp" +#include "transformations/common_optimizations/transpose_sinking_utils.hpp" + +using namespace ov::pass::pattern; +using namespace ov; +using namespace ov::opset9; +using namespace transpose_sinking; + +ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() { + MATCHER_SCOPE(TransposeSinkingConcatForward); + + auto main_node_label = wrap_type(IfNodeHasTransposeInputs); + + matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { + const auto& pattern_to_output = m.get_pattern_value_map(); + + auto& main_node_output = pattern_to_output.at(main_node_label); + auto main_node = main_node_output.get_node_shared_ptr(); + + TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node); + + sink_forward::UpdateInputTransposes(main_node, transpose_input_info); + for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { + register_new_node(new_node); + } + + auto concat_node = as_type_ptr(main_node); + const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val(); + const int64_t transposed_concat_axis = transpose_axis_order[concat_node->get_axis()]; + concat_node->set_concatenation_axis(transposed_concat_axis); + + return true; + }; + + auto m = std::make_shared(main_node_label, matcher_name); + register_matcher(m, matcher_pass_callback); +} + +ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() { + MATCHER_SCOPE(TransposeSinkingConcatBackward); + + auto main_node_label = wrap_type(consumers_count(1)); + + auto transpose_const_label = wrap_type(consumers_count(1)); + auto transpose_label = wrap_type({main_node_label, transpose_const_label}, consumers_count(1)); + + matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { + const auto& pattern_to_output = m.get_pattern_value_map(); + auto transpose_const = as_type_ptr(pattern_to_output.at(transpose_const_label).get_node_shared_ptr()); + auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); + auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr(); + + for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) { + register_new_node(new_node); + } + + // remove transpose after main node + transpose->output(0).replace(main_node); + + SwapNames(transpose, main_node); + + auto concat_node = as_type_ptr(main_node); + const auto transpose_axis_order = transpose_const->get_axis_vector_val(); + const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order); + const int64_t transposed_concat_axis = reversed_traspose_axis_order[concat_node->get_axis()]; + concat_node->set_concatenation_axis(transposed_concat_axis); + + return true; + }; + + auto m = std::make_shared(transpose_label, matcher_name); + register_matcher(m, matcher_pass_callback); +} diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_split.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_split.cpp new file mode 100644 index 00000000000..8da19612a59 --- /dev/null +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_split.cpp @@ -0,0 +1,200 @@ +#include "transformations/common_optimizations/transpose_sinking_split.hpp" + +#include +#include +#include +#include + +#include "itt.hpp" +#include "openvino/op/util/op_types.hpp" +#include "openvino/opsets/opset9.hpp" +#include "openvino/pass/pattern/op/label.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "openvino/util/common_util.hpp" +#include "openvino/util/log.hpp" +#include "transformations/common_optimizations/transpose_sinking_utils.hpp" + +using namespace ov::pass::pattern; +using namespace ov; +using namespace ov::opset9; +using namespace transpose_sinking; + +namespace { + +using NodePtr = std::shared_ptr; + +struct OutputTranspose { + OutputTranspose() : transpose(nullptr), transpose_const(nullptr) {} + Transpose* transpose; + Constant* transpose_const; +}; + +OutputTranspose GetOutputTransposes(NodePtr node) { + for (size_t output_idx = 0; output_idx < node->get_output_size(); ++output_idx) { + for (auto& input : node->get_output_target_inputs(output_idx)) { + auto transpose_node = dynamic_cast(input.get_node()); + if (!transpose_node) + continue; + auto constant_node = dynamic_cast(transpose_node->input_value(1).get_node()); + if (!constant_node) + continue; + { + OutputTranspose output_transpose; + output_transpose.transpose = transpose_node; + output_transpose.transpose_const = constant_node; + + return output_transpose; + } + } + } + + return OutputTranspose(); +} + +NodePtr FindSplitInput(Node* node) { + for (size_t input_idx = 0; input_idx < node->get_input_size(); ++input_idx) { + NodePtr input_node = node->get_input_node_shared_ptr(input_idx); + auto split_node = as_type_ptr(input_node); + if (split_node) + return split_node; + } + return {}; +} + +std::shared_ptr GetTransposeConstant(Input input) { + auto transpose_node = dynamic_cast(input.get_node()); + if (!transpose_node) + return {}; + + auto constant_node = as_type_ptr(transpose_node->input_value(1).get_node_shared_ptr()); + if (!constant_node) + return {}; + + return constant_node; +} + +bool HasInputSplitAndTransposeSiblings(const Output& output) { + NodePtr split_node = FindSplitInput(output.get_node()); + if (!split_node) { + return false; + } + + AxisVector first_transpose_axis_order; + // get first transpose axis + { + auto constant_node = GetTransposeConstant(*(split_node->get_output_target_inputs(0).begin())); + if (!constant_node) + return false; + first_transpose_axis_order = constant_node->get_axis_vector_val(); + } + + for (size_t output_idx = 1; output_idx < split_node->get_output_size(); ++output_idx) { + for (auto& input : split_node->get_output_target_inputs(output_idx)) { + auto constant_node = GetTransposeConstant(input); + if (!constant_node) + return false; + + AxisVector transpose_axis_order = constant_node->get_axis_vector_val(); + if (transpose_axis_order.size() != first_transpose_axis_order.size()) + return false; + if (!std::equal(transpose_axis_order.begin(), + transpose_axis_order.end(), + first_transpose_axis_order.begin())) + return false; + } + } + + return true; +} +} // namespace + +ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() { + MATCHER_SCOPE(TransposeSinkingSplitBackward); + + auto transpose_const_label = wrap_type(consumers_count(1)); + auto transpose_label = + wrap_type({any_input(), transpose_const_label}, HasInputSplitAndTransposeSiblings); + + matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { + const auto& pattern_to_output = m.get_pattern_value_map(); + auto transpose_label_node = pattern_to_output.at(transpose_label).get_node(); + + NodePtr split = FindSplitInput(transpose_label_node); + auto split_axis_constant = as_type_ptr(split->input_value(1).get_node_shared_ptr()); + OutputTranspose output_transpose = GetOutputTransposes(split); + + const auto transpose_axis_order = output_transpose.transpose_const->get_axis_vector_val(); + const auto transpose_element_type = output_transpose.transpose_const->get_element_type(); + + const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order); + + const size_t split_axis = split_axis_constant->get_axis_vector_val()[0]; + const size_t reversed_transposed_split_axis = reversed_traspose_axis_order[split_axis]; + + // insert transpose before split + { + auto input_node = split->input_value(0); + auto new_transpose_const = std::make_shared(transpose_element_type, + Shape{transpose_axis_order.size()}, + transpose_axis_order); + auto new_transpose = std::make_shared(input_node, new_transpose_const); + + split->input(0).replace_source_output(new_transpose->output(0)); + + copy_runtime_info(input_node.get_node_shared_ptr(), {new_transpose, new_transpose_const}); + + register_new_node(new_transpose); + } + + // update split axis + auto new_split_axis_const = std::make_shared(split_axis_constant->get_element_type(), + Shape{}, + reversed_transposed_split_axis); + split->input(1).replace_source_output(new_split_axis_const); + + // remove split output transposes + for (size_t output_idx = 0; output_idx < split->get_output_size(); ++output_idx) { + for (auto& input : split->get_output_target_inputs(output_idx)) { + input.get_node()->output(0).replace(split->output(output_idx)); + } + } + return true; + }; + + auto m = std::make_shared(transpose_label, matcher_name); + register_matcher(m, matcher_pass_callback); +} + +ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() { + MATCHER_SCOPE(TransposeSinkingSplitForward); + + auto main_node_label = wrap_type(IfNodeHasTransposeInputs); + + matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { + const auto& pattern_to_output = m.get_pattern_value_map(); + + auto& main_node_output = pattern_to_output.at(main_node_label); + auto main_node = main_node_output.get_node_shared_ptr(); + + TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node); + + sink_forward::RemoveZeroInputNode(main_node); + for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { + register_new_node(new_node); + } + + const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val(); + auto split_node = as_type_ptr(main_node); + auto split_axis_constant = as_type_ptr(split_node->input_value(1).get_node_shared_ptr()); + const size_t split_axis = split_axis_constant->get_axis_vector_val()[0]; + const size_t transposed_split_axis = transpose_axis_order[split_axis]; + auto new_split_axis_const = + std::make_shared(split_axis_constant->get_element_type(), Shape{}, transposed_split_axis); + split_node->input(1).replace_source_output(new_split_axis_const); + + return true; + }; + + auto m = std::make_shared(main_node_label, matcher_name); + register_matcher(m, matcher_pass_callback); +} diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_utils.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_utils.cpp new file mode 100644 index 00000000000..4d7579ce3d6 --- /dev/null +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_utils.cpp @@ -0,0 +1,172 @@ +#include "transformations/common_optimizations/transpose_sinking_utils.hpp" + +#include +#include +#include + +#include "itt.hpp" +#include "openvino/op/util/op_types.hpp" +#include "openvino/opsets/opset9.hpp" +#include "openvino/pass/pattern/op/label.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "openvino/util/common_util.hpp" +#include "openvino/util/log.hpp" + +namespace transpose_sinking { + +using namespace ov; +using namespace ov::opset9; + +using NodePtr = std::shared_ptr; + +TransposeInputsInfo GetFirstTransposeInput(NodePtr node) { + for (size_t input_idx = 0; input_idx < node->get_input_size(); ++input_idx) { + NodePtr input_node = node->get_input_node_shared_ptr(input_idx); + auto transpose_node = as_type_ptr(input_node); + if (!transpose_node) + continue; + auto constant_node = as_type_ptr(transpose_node->input_value(1).get_node_shared_ptr()); + if (!constant_node) + continue; + { + TransposeInputsInfo input_info; + input_info.transpose = transpose_node; + input_info.transpose_const = constant_node; + input_info.input_idx = input_idx; + return input_info; + } + } + + return TransposeInputsInfo(); +} + +bool IfNodeHasTransposeInputs(const Output& output) { + TransposeInputsInfo inputs_info = GetFirstTransposeInput(output.get_node_shared_ptr()); + return !inputs_info.isEmpty(); +} + +AxisVector ReverseTransposeOrder(const AxisVector& axis_order) { + AxisVector out(axis_order.size()); + for (size_t i = 0; i < axis_order.size(); i++) { + out.at(axis_order[i]) = i; + } + return out; +} + +void SwapOutputNames(Output output1, Output output2) { + const auto node2_output_names = output2.get_names(); + output2.set_names(output1.get_names()); + output1.set_names(node2_output_names); +} + +void SwapFriendlyNames(NodePtr node1, NodePtr node2) { + const std::string node2_name = node2->get_friendly_name(); + node2->set_friendly_name(node1->get_friendly_name()); + node1->set_friendly_name(node2_name); +} + +void SwapNames(NodePtr node1, NodePtr node2) { + SwapFriendlyNames(node1, node2); + SwapOutputNames(node1->output(0), node2->output(0)); +} + +namespace sink_forward { + +// insert input reversed transposes, remove first input tranpose +void UpdateInputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_input_info) { + if (transpose_input_info.isEmpty()) + return; + const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val(); + const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order); + const size_t tranpose_input_index = transpose_input_info.input_idx; + const auto transpose_element_type = transpose_input_info.transpose_const->get_element_type(); + + for (size_t i = 0; i < main_node->get_input_size(); ++i) { + auto input_node = main_node->input_value(i); + if (i == tranpose_input_index) { + auto transpose_parent = input_node.get_node()->input_value(0); + main_node->input(i).replace_source_output(transpose_parent); + } else { + auto new_transpose_const = std::make_shared(transpose_element_type, + Shape{reversed_traspose_axis_order.size()}, + reversed_traspose_axis_order); + auto new_transpose = std::make_shared(input_node, new_transpose_const); + + main_node->input(i).replace_source_output(new_transpose->output(0)); + + copy_runtime_info(input_node.get_node_shared_ptr(), {new_transpose, new_transpose_const}); + } + } +} + +void RemoveZeroInputNode(NodePtr main_node) { + auto input_node = main_node->input_value(0); + if (input_node.get_node()->get_input_size() < 1) + return; + auto parent_node = input_node.get_node()->input_value(0); + main_node->input(0).replace_source_output(parent_node); +} + +NodeVector InsertOutputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_input_info) { + if (transpose_input_info.isEmpty()) + return {}; + const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val(); + const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order); + const auto transpose_element_type = transpose_input_info.transpose_const->get_element_type(); + + NodeVector new_nodes; + + for (size_t i = 0; i < main_node->get_output_size(); ++i) { + auto main_node_consumers = main_node->output(i).get_target_inputs(); + + auto new_transpose_const = std::make_shared(transpose_element_type, + Shape{transpose_axis_order.size()}, + transpose_axis_order); + auto new_transpose = std::make_shared(main_node->output(i), new_transpose_const); + + for (auto& consumer : main_node_consumers) { + consumer.replace_source_output(new_transpose); + } + + copy_runtime_info(main_node, {new_transpose, new_transpose_const}); + SwapOutputNames(main_node->output(i), new_transpose->output(0)); + + if (main_node->get_output_size() > 1) + new_transpose->set_friendly_name(main_node->get_friendly_name() + "." + std::to_string(i)); + else + SwapFriendlyNames(new_transpose, main_node); + + new_nodes.push_back(new_transpose); + } + + return new_nodes; +} + +} // namespace sink_forward + +namespace sink_backward { +NodeVector InsertTransposeBeforeNode(NodePtr main_node, std::shared_ptr transpose_const) { + const auto transpose_axis_order = transpose_const->get_axis_vector_val(); + const auto transpose_element_type = transpose_const->get_element_type(); + + NodeVector new_nodes; + + for (size_t i = 0; i < main_node->get_input_size(); ++i) { + auto input_node = main_node->input_value(i); + auto new_transpose_const = std::make_shared(transpose_element_type, + Shape{transpose_axis_order.size()}, + transpose_axis_order); + auto new_transpose = std::make_shared(input_node, new_transpose_const); + + main_node->input(i).replace_source_output(new_transpose->output(0)); + + copy_runtime_info(input_node.get_node_shared_ptr(), {new_transpose, new_transpose_const}); + + new_nodes.push_back(new_transpose); + } + + return new_nodes; +} +} // namespace sink_backward + +} // namespace transpose_sinking diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/transpose_sinking_binary_test.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/transpose_sinking_binary_test.cpp new file mode 100644 index 00000000000..37041f60d0f --- /dev/null +++ b/src/tests/functional/inference_engine/transformations/common_optimizations/transpose_sinking_binary_test.cpp @@ -0,0 +1,358 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include "common_test_utils/ngraph_test_utils.hpp" + +#include + +#include "gtest/gtest.h" + +namespace { + +using NodePtr = std::shared_ptr; +using ModelPtr = std::shared_ptr; +using Output = ov::Output; + +// ---------------------------------------------------------------------------- + +class IBinaryFactory { +public: + IBinaryFactory() = default; + virtual ~IBinaryFactory() = default; + virtual NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const = 0; +}; + +using BinaryFactoryPtr = std::shared_ptr; + +template +class BinaryFactory : public IBinaryFactory { +public: + BinaryFactory() = default; + NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const override { + return std::make_shared(parent_left_node, parent_right_node); + } +}; + +template +BinaryFactoryPtr CreateBinaryFactory() { + return std::make_shared>(); +} + +// ---------------------------------------------------------------------------- + +class IPassFactory { +public: + IPassFactory() = default; + virtual ~IPassFactory() = default; + virtual void registerPass(ov::pass::Manager& pass_manager) const = 0; +}; + +using PassFactoryPtr = std::shared_ptr; + +template +class PassFactory : public IPassFactory { +public: + void registerPass(ov::pass::Manager& pass_manager) const override { + pass_manager.register_pass(); + } +}; + +template +PassFactoryPtr CreatePassFactory() { + return std::make_shared>(); +} + +std::vector binary_factories = { + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory() +}; + +std::vector binary_operations_numbers = {1, 10}; + +std::vector binary_transpose_input_indexes = {0, 1}; + +} // namespace + + +namespace binary { +namespace single_consumer { +namespace forward { +namespace one_input_transpose { + +std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, + size_t num_binary_ops, + ov::element::Type input_type, + size_t binary_transpose_input_idx) { + const ov::Shape input_shape{1, 96, 55, 55}; + const ov::Shape const_shape{1, 55, 55, 96}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + NodePtr in_op = transpose0; + for (size_t i = 0; i < num_binary_ops; ++i) { + auto in_constant = std::make_shared(input_type, const_shape, ov::Shape{1}); + if (!binary_transpose_input_idx) + in_op = binary_factory->create(in_op, in_constant); + else + in_op = binary_factory->create(in_constant, in_op); + } + + return std::make_shared(ov::OutputVector{in_op}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_factory, + size_t num_binary_ops, + ov::element::Type input_type, + size_t binary_transpose_input_idx) { + const ov::Shape input_shape{1, 96, 55, 55}; + const ov::Shape const_shape{1, 55, 55, 96}; + + auto X = std::make_shared(input_type, input_shape); + + NodePtr in_op = X; + for (size_t i = 0; i < num_binary_ops; ++i) { + auto in_constant = std::make_shared(input_type, const_shape, ov::Shape{1}); + + auto transpose_reversed_const = + std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose_reversed = std::make_shared(in_constant, transpose_reversed_const); + + if (!binary_transpose_input_idx) + in_op = binary_factory->create(in_op, transpose_reversed); + else + in_op = binary_factory->create(transpose_reversed, in_op); + } + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); + + return std::make_shared(ov::OutputVector{transpose0}, ov::ParameterVector{X}); +} + +} // namespace one_input_transpose + +namespace double_transpose { +std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, + size_t num_binary_ops, + ov::element::Type input_type) { + const ov::Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + NodePtr in_op = transpose0; + for (size_t i = 0; i < num_binary_ops; ++i) { + auto in_constant = std::make_shared(input_type, input_shape, ov::Shape{1}); + auto ng_order1 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose1 = std::make_shared(in_constant, ng_order1); + + in_op = binary_factory->create(in_op, transpose1); + } + + return std::make_shared(ov::OutputVector{in_op}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_factory, + size_t num_binary_ops, + ov::element::Type input_type) { + const ov::Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + NodePtr in_op = X; + for (size_t i = 0; i < num_binary_ops; ++i) { + auto in_constant = std::make_shared(input_type, input_shape, ov::Shape{1}); + + auto ng_order1 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose1 = std::make_shared(in_constant, ng_order1); + + auto transpose_reversed_const = + std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose_reversed = std::make_shared(transpose1, transpose_reversed_const); + + in_op = binary_factory->create(in_op, transpose_reversed); + } + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); + + return std::make_shared(ov::OutputVector{transpose0}, ov::ParameterVector{X}); +} + +} // namespace double_transpose +} // namespace forward + +namespace backward { +namespace one_input_transpose { +std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, + size_t num_binary_ops, + ov::element::Type input_type, + size_t binary_transpose_input_idx) { + const ov::Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + NodePtr in_op = X; + for (size_t i = 0; i < num_binary_ops; ++i) { + auto in_constant = std::make_shared(input_type, input_shape, ov::Shape{1}); + if (!binary_transpose_input_idx) + in_op = binary_factory->create(in_op, in_constant); + else + in_op = binary_factory->create(in_constant, in_op); + } + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); + + return std::make_shared(ov::OutputVector{transpose0}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_factory, + size_t num_binary_ops, + ov::element::Type input_type, + size_t binary_transpose_input_idx) { + const ov::Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + NodePtr in_op = transpose0; + for (size_t i = 0; i < num_binary_ops; ++i) { + auto in_constant = std::make_shared(input_type, input_shape, ov::Shape{1}); + + auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose = std::make_shared(in_constant, ng_order); + + if (!binary_transpose_input_idx) + in_op = binary_factory->create(in_op, transpose); + else + in_op = binary_factory->create(transpose, in_op); + } + + return std::make_shared(ov::OutputVector{in_op}, ov::ParameterVector{X}); +} +} // namespace one_input_transpose +} // namespace backward +} // namespace single_consumer +} // namespace binary + +using CreateGraphBinaryF = std::function(BinaryFactoryPtr unary_factory, + size_t num_binary_ops, + ov::element::Type input_type, + size_t binary_transpose_input_idx)>; + +using TestBinaryParams = std::tuple; /* binary_transpose_input_idx */ + +class TransposeSinkingBinaryTestFixture : public ::testing::WithParamInterface, + public TransformationTestsF {}; + + +TEST_P(TransposeSinkingBinaryTestFixture, CompareFunctions) { + BinaryFactoryPtr unary_factory; + PassFactoryPtr pass_factory; + size_t num_binary_ops; + CreateGraphBinaryF model_factory; + CreateGraphBinaryF reference_model_factory; + ov::element::Type input_type; + size_t binary_transpose_input_idx; + std::tie(unary_factory, + pass_factory, + num_binary_ops, + model_factory, + reference_model_factory, + input_type, + binary_transpose_input_idx) = this->GetParam(); + + model = model_factory(unary_factory, num_binary_ops, input_type, binary_transpose_input_idx); + model_ref = reference_model_factory(unary_factory, num_binary_ops, input_type, binary_transpose_input_idx); + pass_factory->registerPass(manager); +} + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryForwardTestSuite, TransposeSinkingBinaryTestFixture, + ::testing::Combine(::testing::ValuesIn(binary_factories), + ::testing::Values(CreatePassFactory()), + ::testing::ValuesIn(binary_operations_numbers), + ::testing::Values(binary::single_consumer::forward::one_input_transpose::CreateFunction), + ::testing::Values(binary::single_consumer::forward::one_input_transpose::CreateReferenceFunction), + ::testing::Values(ov::element::f32), + ::testing::ValuesIn(binary_transpose_input_indexes))); + +INSTANTIATE_TEST_SUITE_P( + TransposeSinkingBinaryBackwardTestSuite, + TransposeSinkingBinaryTestFixture, + ::testing::Combine(::testing::ValuesIn(binary_factories), + ::testing::Values(CreatePassFactory()), + ::testing::ValuesIn(binary_operations_numbers), + ::testing::Values(binary::single_consumer::backward::one_input_transpose::CreateFunction), + ::testing::Values(binary::single_consumer::backward::one_input_transpose::CreateReferenceFunction), + ::testing::Values(ov::element::f32), + ::testing::ValuesIn(binary_transpose_input_indexes))); + +// -------------------------------------------------------------------------------------- + +using CreateGraphBinaryTwoTransposeInputsF = std::function< + std::shared_ptr(BinaryFactoryPtr unary_factory, size_t num_binary_ops, ov::element::Type input_type)>; + +using TestBinaryTwoTransposeInputsParams = std::tuple; /* input type */ + +class TransposeSinkingBinaryTwoTransposeInputsTestFixture + : public ::testing::WithParamInterface, + public TransformationTestsF {}; + +TEST_P(TransposeSinkingBinaryTwoTransposeInputsTestFixture, CompareFunctions) { + BinaryFactoryPtr unary_factory; + PassFactoryPtr pass_factory; + size_t num_binary_ops; + CreateGraphBinaryTwoTransposeInputsF model_factory; + CreateGraphBinaryTwoTransposeInputsF reference_model_factory; + ov::element::Type input_type; + + std::tie(unary_factory, pass_factory, num_binary_ops, model_factory, reference_model_factory, input_type) = + this->GetParam(); + + model = model_factory(unary_factory, num_binary_ops, input_type); + model_ref = reference_model_factory(unary_factory, num_binary_ops, input_type); + pass_factory->registerPass(manager); +} + +INSTANTIATE_TEST_SUITE_P( + TransposeSinkingBinaryTwoTransposeInputsForwardTestSuite, + TransposeSinkingBinaryTwoTransposeInputsTestFixture, + ::testing::Combine(::testing::ValuesIn(binary_factories), + ::testing::Values(CreatePassFactory()), + ::testing::ValuesIn(binary_operations_numbers), + ::testing::Values(binary::single_consumer::forward::double_transpose::CreateFunction), + ::testing::Values(binary::single_consumer::forward::double_transpose::CreateReferenceFunction), + ::testing::Values(ov::element::f32))); diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/transpose_sinking_concat_test.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/transpose_sinking_concat_test.cpp new file mode 100644 index 00000000000..f4ca73e5c85 --- /dev/null +++ b/src/tests/functional/inference_engine/transformations/common_optimizations/transpose_sinking_concat_test.cpp @@ -0,0 +1,400 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include "common_test_utils/ngraph_test_utils.hpp" + +#include + +#include "gtest/gtest.h" + +namespace { + +using NodePtr = std::shared_ptr; +using ModelPtr = std::shared_ptr; +using Output = ov::Output; + +// ---------------------------------------------------------------------------- + +class IBinaryFactory { +public: + IBinaryFactory() = default; + virtual ~IBinaryFactory() = default; + virtual NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const = 0; +}; + +using BinaryFactoryPtr = std::shared_ptr; + +template +class BinaryFactory : public IBinaryFactory { +public: + BinaryFactory() = default; + NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const override { + return std::make_shared(parent_left_node, parent_right_node); + } +}; + +template +BinaryFactoryPtr CreateBinaryFactory() { + return std::make_shared>(); +} + +// ---------------------------------------------------------------------------- + +class IPassFactory { +public: + IPassFactory() = default; + virtual ~IPassFactory() = default; + virtual void registerPass(ov::pass::Manager& pass_manager) const = 0; +}; + +using PassFactoryPtr = std::shared_ptr; + +template +class PassFactory : public IPassFactory { +public: + void registerPass(ov::pass::Manager& pass_manager) const override { + pass_manager.register_pass(); + } +}; + +template +PassFactoryPtr CreatePassFactory() { + return std::make_shared>(); +} + +std::vector binary_factories = { + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory() +}; + +std::vector binary_operations_numbers = {1, 10}; + +std::vector binary_transpose_input_indexes = {0, 1}; + +} // namespace + + +using CreateGraphConcatF = std::function< std::shared_ptr (size_t num_concat_ops, + ov::element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) >; + +using TestConcatParams = std::tuple; /* num_concat_inputs */ + +class TransposeSinkingConcatTestFixture: public ::testing::WithParamInterface, + public TransformationTestsF {}; + +namespace { + +std::vector concat_operations_numbers = {1, 10}; + +std::vector concat_transpose_input_indexes = {0, 2}; + +} // namespace + +namespace single_consumer { +namespace forward { +namespace one_input_transpose { + +std::shared_ptr CreateFunction(size_t num_concat_ops, + ov::element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const ov::Shape input_shape{1, 96, 55, 55}; + const ov::Shape const_shape{1, 55, 55, 96}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + NodePtr in_op = transpose0; + for (size_t i = 0; i < num_concat_ops; ++i) { + ov::OutputVector concat_inputs; + for (size_t j = 0; j < num_concat_inputs; ++j) { + if (j == concat_transpose_input_idx) + concat_inputs.push_back(in_op); + else + concat_inputs.push_back(std::make_shared(input_type, const_shape, ov::Shape{1})); + } + in_op = std::make_shared(concat_inputs, 1); + } + + return std::make_shared(ov::OutputVector{in_op}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(size_t num_concat_ops, + ov::element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const ov::Shape input_shape{1, 96, 55, 55}; + const ov::Shape const_shape{1, 55, 55, 96}; + + auto X = std::make_shared(input_type, input_shape); + + NodePtr in_op = X; + for (size_t i = 0; i < num_concat_ops; ++i) { + ov::OutputVector concat_inputs; + for (size_t j = 0; j < num_concat_inputs; ++j) { + if (j == concat_transpose_input_idx) { + concat_inputs.push_back(in_op); + } else { + auto in_constant = std::make_shared(input_type, const_shape, ov::Shape{1}); + + auto transpose_reversed_const = + std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose_reversed = + std::make_shared(in_constant, transpose_reversed_const); + + concat_inputs.push_back(transpose_reversed); + } + } + in_op = std::make_shared(concat_inputs, 2); + } + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); + + return std::make_shared(ov::OutputVector{transpose0}, ov::ParameterVector{X}); +} + +} // namespace one_input_transpose + +namespace double_transpose { + +std::shared_ptr CreateFunction(size_t num_concat_ops, + ov::element::Type input_type, + size_t num_concat_inputs) { + const ov::Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + NodePtr in_op = transpose0; + for (size_t i = 0; i < num_concat_ops; ++i) { + ov::OutputVector concat_inputs; + concat_inputs.push_back(in_op); + for (size_t j = 1; j < num_concat_inputs; ++j) { + auto in_constant = std::make_shared(input_type, input_shape, ov::Shape{1}); + auto ng_order1 = + std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose1 = std::make_shared(in_constant, ng_order1); + concat_inputs.push_back(transpose1); + } + in_op = std::make_shared(concat_inputs, 1); + } + + return std::make_shared(ov::OutputVector{in_op}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(size_t num_concat_ops, + ov::element::Type input_type, + size_t num_concat_inputs) { + const ov::Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + NodePtr in_op = X; + for (size_t i = 0; i < num_concat_ops; ++i) { + ov::OutputVector concat_inputs; + + concat_inputs.push_back(in_op); + + for (size_t j = 1; j < num_concat_inputs; ++j) { + auto in_constant = std::make_shared(input_type, input_shape, ov::Shape{1}); + + auto ng_order1 = + std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose1 = std::make_shared(in_constant, ng_order1); + + auto transpose_reversed_const = + std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose_reversed = std::make_shared(transpose1, transpose_reversed_const); + + concat_inputs.push_back(transpose_reversed); + } + in_op = std::make_shared(concat_inputs, 2); + } + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); + + return std::make_shared(ov::OutputVector{transpose0}, ov::ParameterVector{X}); +} + +} // namespace double_transpose + +} // namespace forward + +namespace backward { + +std::shared_ptr CreateFunction(size_t num_concat_ops, + ov::element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const ov::Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + NodePtr in_op = X; + for (size_t i = 0; i < num_concat_ops; ++i) { + ov::OutputVector concat_inputs; + for (size_t j = 0; j < num_concat_inputs; ++j) { + if (j == concat_transpose_input_idx) + concat_inputs.push_back(in_op); + else + concat_inputs.push_back(std::make_shared(input_type, input_shape, ov::Shape{1})); + } + in_op = std::make_shared(concat_inputs, 1); + } + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); + + return std::make_shared(ov::OutputVector{transpose0}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(size_t num_concat_ops, + ov::element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const ov::Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + NodePtr in_op = transpose0; + for (size_t i = 0; i < num_concat_ops; ++i) { + ov::OutputVector concat_inputs; + for (size_t j = 0; j < num_concat_inputs; ++j) { + if (j == concat_transpose_input_idx) { + concat_inputs.push_back(in_op); + } else { + auto in_constant = std::make_shared(input_type, input_shape, ov::Shape{1}); + + auto transpose_reversed_const = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose_reversed = std::make_shared(in_constant, transpose_reversed_const); + + concat_inputs.push_back(transpose_reversed); + } + } + in_op = std::make_shared(concat_inputs, 3); + } + + return std::make_shared(ov::OutputVector{in_op}, ov::ParameterVector{X}); +} + +} // namespace backward +} // namespace single_consumer + + + +TEST_P(TransposeSinkingConcatTestFixture, CompareFunctions) { + PassFactoryPtr pass_factory; + size_t num_concat_ops; + CreateGraphConcatF model_factory; + CreateGraphConcatF reference_model_factory; + ov::element::Type input_type; + size_t concat_transpose_input_idx; + size_t num_concat_inputs; + std::tie(pass_factory, + num_concat_ops, + model_factory, + reference_model_factory, + input_type, + concat_transpose_input_idx, + num_concat_inputs) = this->GetParam(); + + model = model_factory(num_concat_ops, input_type, concat_transpose_input_idx, num_concat_inputs); + model_ref = reference_model_factory(num_concat_ops, input_type, concat_transpose_input_idx, num_concat_inputs); + pass_factory->registerPass(manager); +} + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatForwardTestSuite, TransposeSinkingConcatTestFixture, + ::testing::Combine(::testing::Values(CreatePassFactory()), + ::testing::ValuesIn(concat_operations_numbers), + ::testing::Values(single_consumer::forward::one_input_transpose::CreateFunction), + ::testing::Values(single_consumer::forward::one_input_transpose::CreateReferenceFunction), + ::testing::Values(ov::element::f32), + ::testing::ValuesIn(concat_transpose_input_indexes), + ::testing::Values(5))); + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatBackwardTestSuite, TransposeSinkingConcatTestFixture, + ::testing::Combine(::testing::Values(CreatePassFactory()), + ::testing::ValuesIn(concat_operations_numbers), + ::testing::Values(single_consumer::backward::CreateFunction), + ::testing::Values(single_consumer::backward::CreateReferenceFunction), + ::testing::Values(ov::element::f32), + ::testing::ValuesIn(concat_transpose_input_indexes), + ::testing::Values(5))); + + +// -------------------------------------------------------------------------------------- + +using CreateGraphConcatAllTransposesInputF = std::function(size_t num_concat_ops, + ov::element::Type input_type, + size_t num_concat_inputs)>; + +using TestConcatAllTransposesInputParams = std::tuple; /* num_concat_inputs */ + +class TransposeSinkingConcatAllTransposesInputTestFixture + : public ::testing::WithParamInterface, + public TransformationTestsF {}; + +TEST_P(TransposeSinkingConcatAllTransposesInputTestFixture, CompareFunctions) { + PassFactoryPtr pass_factory; + size_t num_concat_ops; + CreateGraphConcatAllTransposesInputF model_factory; + CreateGraphConcatAllTransposesInputF reference_model_factory; + ov::element::Type input_type; + size_t num_concat_inputs; + std::tie(pass_factory, + num_concat_ops, + model_factory, + reference_model_factory, + input_type, + num_concat_inputs) = this->GetParam(); + + model = model_factory(num_concat_ops, input_type, num_concat_inputs); + model_ref = reference_model_factory(num_concat_ops, input_type, num_concat_inputs); + pass_factory->registerPass(manager); +} + +INSTANTIATE_TEST_SUITE_P( + TransposeSinkingConcatForwardAllTransposesTestSuite, + TransposeSinkingConcatAllTransposesInputTestFixture, + ::testing::Combine(::testing::Values(CreatePassFactory()), + ::testing::ValuesIn(concat_operations_numbers), + ::testing::Values(single_consumer::forward::double_transpose::CreateFunction), + ::testing::Values(single_consumer::forward::double_transpose::CreateReferenceFunction), + ::testing::Values(ov::element::f32), + ::testing::Values(5))); diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/transpose_sinking_split_test.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/transpose_sinking_split_test.cpp new file mode 100644 index 00000000000..1e35c3a6a99 --- /dev/null +++ b/src/tests/functional/inference_engine/transformations/common_optimizations/transpose_sinking_split_test.cpp @@ -0,0 +1,552 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include "common_test_utils/ngraph_test_utils.hpp" + +#include + +#include "gtest/gtest.h" + +namespace { + +using NodePtr = std::shared_ptr; +using ModelPtr = std::shared_ptr; +using Output = ov::Output; + +// ---------------------------------------------------------------------------- + +class IBinaryFactory { +public: + IBinaryFactory() = default; + virtual ~IBinaryFactory() = default; + virtual NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const = 0; +}; + +using BinaryFactoryPtr = std::shared_ptr; + +template +class BinaryFactory : public IBinaryFactory { +public: + BinaryFactory() = default; + NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const override { + return std::make_shared(parent_left_node, parent_right_node); + } +}; + +template +BinaryFactoryPtr CreateBinaryFactory() { + return std::make_shared>(); +} + +// ---------------------------------------------------------------------------- + +class IPassFactory { +public: + IPassFactory() = default; + virtual ~IPassFactory() = default; + virtual void registerPass(ov::pass::Manager& pass_manager) const = 0; +}; + +using PassFactoryPtr = std::shared_ptr; + +template +class PassFactory : public IPassFactory { +public: + void registerPass(ov::pass::Manager& pass_manager) const override { + pass_manager.register_pass(); + } +}; + +template +PassFactoryPtr CreatePassFactory() { + return std::make_shared>(); +} + +std::vector binary_factories = { + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory(), + CreateBinaryFactory() +}; + +std::vector binary_operations_numbers = {1, 10}; + +std::vector binary_transpose_input_indexes = {0, 1}; + +} // namespace + +// -------------------------------------------------------------------------------------- + +using CreateGraphSplitForwardF = std::function< std::shared_ptr (size_t num_split_ops, + size_t num_split_outputs, + ov::element::Type input_type)>; + +using TestSplitForwardParams = std::tuple /* input type */; + +class TransposeSinkingSplitForwardTestFixture: public ::testing::WithParamInterface, + public TransformationTestsF {}; + +namespace { + +std::vector split_operations_numbers = {1, 10}; + +std::vector split_outputs_numbers = {2, 3}; + +} // namespace + +namespace split { +namespace forward { +std::shared_ptr CreateFunction(size_t num_split_ops, + size_t num_split_outputs, + ov::element::Type input_type) { + const ov::Shape input_shape{96, static_cast(std::pow(num_split_outputs, num_split_ops + 1)), 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose0 = std::make_shared(X, ng_order0); + + ov::OutputVector outputs; + Output in_op = transpose0->output(0); + for (size_t i = 0; i < num_split_ops; ++i) { + auto split_axis_const = std::make_shared(ov::element::u64, + ov::Shape{}, + 2); + auto split = std::make_shared(in_op, split_axis_const, num_split_outputs); + for (size_t num_output = 0; num_output < num_split_outputs - 1; ++num_output) { + outputs.push_back(split->output(num_output)); + } + in_op = split->output(num_split_outputs - 1); + } + outputs.push_back(in_op); + + return std::make_shared(outputs, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(size_t num_split_ops, + size_t num_split_outputs, + ov::element::Type input_type) { + const ov::Shape input_shape{96, static_cast(std::pow(num_split_outputs, num_split_ops + 1)), 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + ov::OutputVector outputs; + Output in_op = X->output(0); + for (size_t i = 0; i < num_split_ops; ++i) { + auto split_axis_const = std::make_shared(ov::element::u64, + ov::Shape{}, + 1); + auto split = std::make_shared(in_op, split_axis_const, num_split_outputs); + for (size_t num_output = 0; num_output < num_split_outputs - 1; ++num_output) { + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose0 = std::make_shared(split->output(num_output), ng_order0); + outputs.push_back(transpose0); + } + in_op = split->output(num_split_outputs - 1); + } + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose0 = std::make_shared(in_op, ng_order0); + outputs.push_back(transpose0); + + return std::make_shared(outputs, ov::ParameterVector{X}); +} + +} // namespace forward +} // namespace split + + +TEST_P(TransposeSinkingSplitForwardTestFixture, CompareFunctions) { + PassFactoryPtr pass_factory; + size_t num_split_ops; + size_t num_split_outputs; + CreateGraphSplitForwardF model_factory; + CreateGraphSplitForwardF reference_model_factory; + ov::element::Type input_type; + std::tie(pass_factory, + num_split_ops, + num_split_outputs, + model_factory, + reference_model_factory, + input_type) = this->GetParam(); + + model = model_factory(num_split_ops, num_split_outputs, input_type); + model_ref = reference_model_factory(num_split_ops, num_split_outputs, input_type); + pass_factory->registerPass(manager); +} + +INSTANTIATE_TEST_SUITE_P( + TransposeSinkingSplitForwardTestSuite, + TransposeSinkingSplitForwardTestFixture, + ::testing::Combine(::testing::Values(CreatePassFactory()), + ::testing::ValuesIn(split_operations_numbers), + ::testing::ValuesIn(split_outputs_numbers), + ::testing::Values(split::forward::CreateFunction), + ::testing::Values(split::forward::CreateReferenceFunction), + ::testing::Values(ov::element::f32))); + + +// -------------------------------------------------------------------------------------- + +using CreateGraphSplitBackwardF = std::function< std::shared_ptr (size_t split_tree_depth, + size_t num_split_outputs, + ov::element::Type input_type)>; + +using TestSplitBackwardParams = std::tuple /* input type */; + +class TransposeSinkingSplitBackwardTestFixture: public ::testing::WithParamInterface, + public TransformationTestsF {}; + +namespace { +std::vector split_tree_depth_nums = {1, 3}; +} // namespace + +// -------------------------------------------------------------------------------------- + +namespace split { +namespace backward { + +class SplitFactory { +public: + SplitFactory(size_t axis, size_t n_outputs, ov::element::Type elem_type) : + _axis(axis), _n_outputs(n_outputs), _elem_type(elem_type) {} + NodePtr create(Output parent) const { + auto split_axis_const = std::make_shared(_elem_type, + ov::Shape{}, + _axis); + return std::make_shared(parent, split_axis_const, _n_outputs); + } +private: + const size_t _axis; + const size_t _n_outputs; + const ov::element::Type _elem_type; +}; + +void CreateSplitTree(size_t max_depth, size_t depth, Output parent, const SplitFactory & split_factory, ov::OutputVector & leaves) { + if (depth == max_depth) { + leaves.push_back(parent); + return; + } + + auto split = split_factory.create(parent); + + for (size_t output_idx = 0; output_idx < split->get_output_size(); ++output_idx) { + CreateSplitTree(max_depth, depth + 1, split->output(output_idx), split_factory, leaves); + } +} + +std::shared_ptr CreateFunction(size_t split_tree_depth, + size_t num_split_outputs, + ov::element::Type input_type) { + const size_t split_input_dim_value = static_cast(std::pow(num_split_outputs, split_tree_depth + 1)); + const ov::Shape input_shape{96, split_input_dim_value, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + ov::OutputVector split_tree_leaves; + { + SplitFactory split_factory(/* axis */ 1, num_split_outputs, /* elem_type */ ov::element::u64); + CreateSplitTree(split_tree_depth, /* depth */ 0, X->output(0), split_factory, split_tree_leaves); + } + + ov::OutputVector outputs; + for (auto& split_tree_leaf : split_tree_leaves) { + auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); + + const size_t split_dim_current_value = static_cast(split_input_dim_value / std::pow(num_split_outputs, split_tree_depth)); + auto reshape_const = std::make_shared(ov::element::u64, ov::Shape{3}, ov::Shape{96, 55, split_dim_current_value * 55}); + auto reshape = std::make_shared(transpose, reshape_const, false); + outputs.push_back(reshape); + } + + return std::make_shared(outputs, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(size_t split_tree_depth, + size_t num_split_outputs, + ov::element::Type input_type) { + const size_t split_input_dim_value = static_cast(std::pow(num_split_outputs, split_tree_depth + 1)); + const ov::Shape input_shape{96, split_input_dim_value, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose = std::make_shared(X, ng_order); + + ov::OutputVector split_tree_leaves; + { + SplitFactory split_factory(/* axis */ 2, num_split_outputs, /* elem_type */ ov::element::u64); + CreateSplitTree(split_tree_depth, /* depth */ 0, transpose->output(0), split_factory, split_tree_leaves); + } + + ov::OutputVector outputs; + for (auto& split_tree_leaf : split_tree_leaves) { + const size_t split_dim_current_value = static_cast(split_input_dim_value / std::pow(num_split_outputs, split_tree_depth)); + auto reshape_const = std::make_shared(ov::element::u64, ov::Shape{3}, ov::Shape{96, 55, split_dim_current_value * 55}); + auto reshape = std::make_shared(split_tree_leaf, reshape_const, false); + outputs.push_back(reshape); + } + + return std::make_shared(outputs, ov::ParameterVector{X}); +} + +} // namespace backward +} // namespace split + + + +TEST_P(TransposeSinkingSplitBackwardTestFixture, CompareFunctions) { + PassFactoryPtr pass_factory; + size_t split_tree_depth; + size_t num_split_outputs; + CreateGraphSplitBackwardF model_factory; + CreateGraphSplitBackwardF reference_model_factory; + ov::element::Type input_type; + std::tie(pass_factory, + split_tree_depth, + num_split_outputs, + model_factory, + reference_model_factory, + input_type) = this->GetParam(); + + model = model_factory(split_tree_depth, num_split_outputs, input_type); + model_ref = reference_model_factory(split_tree_depth, num_split_outputs, input_type); + pass_factory->registerPass(manager); +} + +INSTANTIATE_TEST_SUITE_P( + TransposeSinkingSplitBackwardTestSuite, + TransposeSinkingSplitBackwardTestFixture, + ::testing::Combine(::testing::Values(CreatePassFactory()), + ::testing::ValuesIn(split_tree_depth_nums), + ::testing::ValuesIn(split_outputs_numbers), + ::testing::Values(split::backward::CreateFunction), + ::testing::Values(split::backward::CreateReferenceFunction), + ::testing::Values(ov::element::f32))); + + +using TransposeInsertF = std::function< ov::OutputVector (const ov::OutputVector& split_tree_leaves)>; + +using CreateGraphSplitBackwardRestrictF = std::function< std::shared_ptr (size_t split_tree_depth, + size_t num_split_outputs, + ov::element::Type input_type, + TransposeInsertF tranpose_insert_function)>; + +using TestSplitBackwardRestrictParams = std::tuple; /* insert transpose function */ + +class TransposeSinkingSplitBackwardRestrictTestFixture: public ::testing::WithParamInterface, + public TransformationTestsF {}; + +TEST_P(TransposeSinkingSplitBackwardRestrictTestFixture, CompareFunctions) { + PassFactoryPtr pass_factory; + size_t split_tree_depth; + size_t num_split_outputs; + CreateGraphSplitBackwardRestrictF model_factory; + ov::element::Type input_type; + TransposeInsertF tranpose_insert_function; + std::tie(pass_factory, + split_tree_depth, + num_split_outputs, + model_factory, + input_type, + tranpose_insert_function) = this->GetParam(); + + model = model_factory(split_tree_depth, num_split_outputs, input_type, tranpose_insert_function); + model_ref = model->clone(); + pass_factory->registerPass(manager); +} +namespace split { +namespace backward { +namespace restrictions { + +std::shared_ptr CreateFunction(size_t split_tree_depth, + size_t num_split_outputs, + ov::element::Type input_type, + TransposeInsertF transpose_insert_func) { + const size_t split_input_dim_value = static_cast(std::pow(num_split_outputs, split_tree_depth + 1)); + const ov::Shape input_shape{96, split_input_dim_value, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + ov::OutputVector split_tree_leaves; + { + SplitFactory split_factory(/* axis */ 1, num_split_outputs, /* elem_type */ ov::element::u64); + CreateSplitTree(split_tree_depth, /* depth */ 0, X->output(0), split_factory, split_tree_leaves); + } + + ov::OutputVector outputs; + for (auto& split_tree_leaf : transpose_insert_func(split_tree_leaves)) { + const size_t split_dim_current_value = static_cast(split_input_dim_value / std::pow(num_split_outputs, split_tree_depth)); + auto reshape_const = std::make_shared(ov::element::u64, ov::Shape{3}, ov::Shape{96, 55, split_dim_current_value * 55}); + auto reshape = std::make_shared(split_tree_leaf, reshape_const, false); + outputs.push_back(reshape); + } + + return std::make_shared(outputs, ov::ParameterVector{X}); +} + +ov::OutputVector OnlyFirstTranspose(const ov::OutputVector& split_tree_leaves) { + ov::OutputVector outputs; + { + auto& split_tree_leaf = split_tree_leaves.front(); + auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); + outputs.push_back(transpose); + } + + for (size_t leaf_idx = 1; leaf_idx < split_tree_leaves.size(); ++leaf_idx) { + outputs.push_back(split_tree_leaves[leaf_idx]); + } + + return outputs; +} + +ov::OutputVector OnlyLastTranspose(const ov::OutputVector& split_tree_leaves) { + ov::OutputVector outputs; + { + auto& split_tree_leaf = split_tree_leaves.back(); + auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); + outputs.push_back(transpose); + } + + for (size_t leaf_idx = 0; leaf_idx < split_tree_leaves.size() - 1; ++leaf_idx) { + outputs.push_back(split_tree_leaves[leaf_idx]); + } + + return outputs; +} + +ov::OutputVector OnlyMiddleTranspose(const ov::OutputVector& split_tree_leaves) { + ov::OutputVector outputs; + size_t middle_idx = split_tree_leaves.size() / 2; + if (split_tree_leaves.size() % 2) + ++middle_idx; + for (size_t leaf_idx = 0; leaf_idx < split_tree_leaves.size() - 1; ++leaf_idx) { + if (leaf_idx == middle_idx) { + auto& split_tree_leaf = split_tree_leaves[leaf_idx]; + auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); + outputs.push_back(transpose); + } else { + outputs.push_back(split_tree_leaves[leaf_idx]); + } + } + + return outputs; +} + +ov::OutputVector FirstAnotherTranspose(const ov::OutputVector& split_tree_leaves) { + ov::OutputVector outputs; + { + auto& split_tree_leaf = split_tree_leaves.front(); + auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); + outputs.push_back(transpose); + } + + for (size_t leaf_idx = 1; leaf_idx < split_tree_leaves.size(); ++leaf_idx) { + auto& split_tree_leaf = split_tree_leaves[leaf_idx]; + auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); + outputs.push_back(transpose); + } + + return outputs; +} + +ov::OutputVector LastAnotherTranspose(const ov::OutputVector& split_tree_leaves) { + ov::OutputVector outputs; + { + auto& split_tree_leaf = split_tree_leaves.back(); + auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); + outputs.push_back(transpose); + } + + for (size_t leaf_idx = 0; leaf_idx < split_tree_leaves.size() - 1; ++leaf_idx) { + auto& split_tree_leaf = split_tree_leaves[leaf_idx]; + auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); + outputs.push_back(transpose); + } + + return outputs; +} + +ov::OutputVector MiddleAnotherTranspose(const ov::OutputVector& split_tree_leaves) { + ov::OutputVector outputs; + size_t middle_idx = split_tree_leaves.size() / 2; + if (split_tree_leaves.size() % 2) + ++middle_idx; + for (size_t leaf_idx = 0; leaf_idx < split_tree_leaves.size(); ++leaf_idx) { + auto& split_tree_leaf = split_tree_leaves[leaf_idx]; + if (leaf_idx == middle_idx) { + auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); + outputs.push_back(transpose); + } else { + auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); + outputs.push_back(transpose); + } + } + + return outputs; +} + +} // namespace restrictions +} // namespace backward +} // namespace split + +namespace { + +std::vector insertTransposeFactories = { + split::backward::restrictions::OnlyFirstTranspose, + split::backward::restrictions::OnlyLastTranspose, + split::backward::restrictions::OnlyMiddleTranspose, + split::backward::restrictions::FirstAnotherTranspose, + split::backward::restrictions::LastAnotherTranspose, + split::backward::restrictions::MiddleAnotherTranspose +}; + +} // namespace + + +INSTANTIATE_TEST_SUITE_P( + TransposeSinkingSplitBackwardRestrictTestSuite, + TransposeSinkingSplitBackwardRestrictTestFixture, + ::testing::Combine(::testing::Values(CreatePassFactory()), + ::testing::Values(1), + ::testing::Values(5), + ::testing::Values(split::backward::restrictions::CreateFunction), + ::testing::Values(ov::element::f32), + ::testing::ValuesIn(insertTransposeFactories)));