diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_binary.hpp b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_binary.hpp index 663139d8068..d7759f8c567 100644 --- a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_binary.hpp +++ b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_binary.hpp @@ -11,20 +11,20 @@ namespace ov { namespace pass { -class TRANSFORMATIONS_API TransposeSinkingBinaryElementwiseForward; -class TRANSFORMATIONS_API TransposeSinkingBinaryElementwiseBackward; +class TRANSFORMATIONS_API TransposeSinkingBinaryForward; +class TRANSFORMATIONS_API TransposeSinkingBinaryBackward; } // namespace pass } // namespace ov -class ov::pass::TransposeSinkingBinaryElementwiseForward : public ov::pass::MatcherPass { +class ov::pass::TransposeSinkingBinaryForward : public ov::pass::MatcherPass { public: - OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryElementwiseForward", "0"); - TransposeSinkingBinaryElementwiseForward(); + OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryForward", "0"); + TransposeSinkingBinaryForward(); }; -class ov::pass::TransposeSinkingBinaryElementwiseBackward : public ov::pass::MatcherPass { +class ov::pass::TransposeSinkingBinaryBackward : public ov::pass::MatcherPass { public: - OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryElementwiseBackward", "0"); - TransposeSinkingBinaryElementwiseBackward(); + OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryBackward", "0"); + TransposeSinkingBinaryBackward(); }; diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_unary.hpp b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_unary.hpp index 7ca37216e75..4f3ee0a701f 100644 --- a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_unary.hpp +++ b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_unary.hpp @@ -11,6 +11,8 @@ namespace ov { namespace pass { class TRANSFORMATIONS_API TransposeSinkingUnaryForward; +class TRANSFORMATIONS_API TransposeSinkingUnaryBackwardSingleConsumer; +class TRANSFORMATIONS_API TransposeSinkingUnaryBackwardMultiConsumers; class TRANSFORMATIONS_API TransposeSinkingUnaryBackward; } // namespace pass @@ -22,7 +24,19 @@ public: TransposeSinkingUnaryForward(); }; -class ov::pass::TransposeSinkingUnaryBackward : public ov::pass::MatcherPass { +class ov::pass::TransposeSinkingUnaryBackwardSingleConsumer : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("TransposeSinkingUnaryBackwardSingleConsumer", "0"); + TransposeSinkingUnaryBackwardSingleConsumer(); +}; + +class ov::pass::TransposeSinkingUnaryBackwardMultiConsumers : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("TransposeSinkingUnaryBackwardMultiConsumers", "0"); + TransposeSinkingUnaryBackwardMultiConsumers(); +}; + +class ov::pass::TransposeSinkingUnaryBackward : public ov::pass::GraphRewrite { public: OPENVINO_RTTI("TransposeSinkingUnaryBackward", "0"); TransposeSinkingUnaryBackward(); diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_utils.hpp b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_utils.hpp index 7c4cc038dbf..7e28e2816bf 100644 --- a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_utils.hpp +++ b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_utils.hpp @@ -28,25 +28,77 @@ struct TransposeInputsInfo { } }; -TransposeInputsInfo GetFirstTransposeInput(std::shared_ptr node); -bool IfNodeHasTransposeInputs(const ov::Output& output); -ov::AxisVector ReverseTransposeOrder(const ov::AxisVector& axis_order); -void SwapOutputNames(ov::Output output1, ov::Output output2); -void SwapFriendlyNames(std::shared_ptr node1, std::shared_ptr node2); -void SwapNames(std::shared_ptr node1, std::shared_ptr node2); +/** + * @brief Finds node first input that is a transpose operation and returns filled TransposeInputsInfo + * for it + */ +TransposeInputsInfo GetFirstTransposeInput(std::shared_ptr); + +/** + * @brief Checks if @arg has any input node that is a transpose operation + */ +bool IfNodeHasTransposeInputs(const ov::Output&); + +/** + * @brief Reverses order of transpose operation. Do it in a such way that if we had couple following one after + * another transposes (one would be reversed version of another) we will have no transpose as a result of that + * couple of transposes. + */ +ov::AxisVector ReverseTransposeOrder(const ov::AxisVector&); + +/** + * @brief Swaps @args output tensor names + */ +void SwapOutputNames(ov::Output, ov::Output); + +/** + * @brief Swaps @args friendly names + */ +void SwapFriendlyNames(std::shared_ptr, std::shared_ptr); + +/** + * @brief Swaps @args output tensor names and friendly names + */ +void SwapNames(std::shared_ptr, std::shared_ptr); namespace sink_forward { -// insert input reversed transposes, remove first input tranpose -void UpdateInputTransposes(std::shared_ptr main_node, TransposeInputsInfo& transpose_input_info); -void RemoveZeroInputNode(std::shared_ptr main_node); -ov::NodeVector InsertOutputTransposes(std::shared_ptr main_node, TransposeInputsInfo& transpose_input_info); +/** + * @brief Inserts reversed transposed on @args main_node inputs. Removes input transpose specified in @arg + * transpose_input_info + */ +void UpdateInputTransposes(std::shared_ptr main_node, const TransposeInputsInfo& transpose_input_info); + +/** + * @brief Removes @arg input node + */ +void RemoveInputNode(std::shared_ptr, size_t input_idx); + +/** + * @brief Inserts transposes on each main_node output with the order specified in @arg transpose_input_info + */ +ov::NodeVector InsertOutputTransposes(std::shared_ptr main_node, + const TransposeInputsInfo& transpose_input_info); } // namespace sink_forward namespace sink_backward { +/** + * @brief Inserts transposes on each input of @arg main_node with the order specified in @arg transpose_const + */ ov::NodeVector InsertTransposeBeforeNode(std::shared_ptr main_node, std::shared_ptr transpose_const); } // namespace sink_backward void UpdateForwardSinkingAbility(std::shared_ptr); +/** + * @brief Checks if @arg has consumers that all are the same transpose operation. If no consumers at all + * returns false. + */ +bool HasSameOutputTransposeNodes(const ov::Output&); + +/** + * Removes all direct node consumers that have one output + */ +void RemoveSingleOutputConsumers(std::shared_ptr); + } // namespace transpose_sinking diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_binary.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_binary.cpp index 4d9eb9c93a5..b034a4ad80e 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_binary.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_binary.cpp @@ -20,10 +20,10 @@ using namespace ov; using namespace ov::opset9; using namespace transpose_sinking; -ov::pass::TransposeSinkingBinaryElementwiseForward::TransposeSinkingBinaryElementwiseForward() { - MATCHER_SCOPE(TransposeSinkingBinaryElementwiseForward); +ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() { + MATCHER_SCOPE(TransposeSinkingBinaryForward); - auto main_node_label = wrap_type(IfNodeHasTransposeInputs); + auto main_node_label = wrap_type(IfNodeHasTransposeInputs); matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { const auto& pattern_to_output = m.get_pattern_value_map(); @@ -46,18 +46,19 @@ ov::pass::TransposeSinkingBinaryElementwiseForward::TransposeSinkingBinaryElemen register_matcher(m, matcher_pass_callback); } -ov::pass::TransposeSinkingBinaryElementwiseBackward::TransposeSinkingBinaryElementwiseBackward() { - MATCHER_SCOPE(TransposeSinkingBinaryElementwiseBackward); +ov::pass::TransposeSinkingBinaryBackward::TransposeSinkingBinaryBackward() { + MATCHER_SCOPE(TransposeSinkingBinaryBackward); - auto main_node_label = wrap_type([](const Output& output) -> bool { - return consumers_count(1)(output) && has_static_rank()(output); - }); + auto main_node_label = + wrap_type([](const Output& output) -> bool { + return has_static_rank()(output) && HasSameOutputTransposeNodes(output); + }); - auto transpose_const_label = wrap_type(consumers_count(1)); + auto transpose_const_label = wrap_type(); auto transpose_label = wrap_type({main_node_label, transpose_const_label}, [](const Output& output) -> bool { - return consumers_count(1)(output) && has_static_rank()(output) && is_sinking_node(output); + return has_static_rank()(output) && is_sinking_node(output); }); matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { @@ -70,8 +71,8 @@ ov::pass::TransposeSinkingBinaryElementwiseBackward::TransposeSinkingBinaryEleme register_new_node(new_node); } - // remove transpose after main node - transpose->output(0).replace(main_node); + // remove output transposes + RemoveSingleOutputConsumers(main_node); SwapNames(transpose, main_node); diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_concat.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_concat.cpp index 49200b8dc78..46c599d5c22 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_concat.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_concat.cpp @@ -55,14 +55,14 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() { MATCHER_SCOPE(TransposeSinkingConcatBackward); auto main_node_label = wrap_type([](const Output& output) -> bool { - return consumers_count(1)(output) && has_static_rank()(output); + return has_static_rank()(output) && HasSameOutputTransposeNodes(output); }); - auto transpose_const_label = wrap_type(consumers_count(1)); + auto transpose_const_label = wrap_type(); auto transpose_label = wrap_type({main_node_label, transpose_const_label}, [](const Output& output) -> bool { - return consumers_count(1)(output) && has_static_rank()(output) && is_sinking_node(output); + return has_static_rank()(output) && is_sinking_node(output); }); matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { @@ -75,8 +75,8 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() { register_new_node(new_node); } - // remove transpose after main node - transpose->output(0).replace(main_node); + // remove output transposes + RemoveSingleOutputConsumers(main_node); SwapNames(transpose, main_node); diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_general.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_general.cpp index bab7e211f6e..5de1a9d3be3 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_general.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_general.cpp @@ -21,7 +21,7 @@ ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() { MATCHER_SCOPE(TransposeSinkingGeneralForward); add_matcher(); - add_matcher(); + add_matcher(); add_matcher(); add_matcher(); add_matcher(); @@ -30,7 +30,7 @@ ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() { ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() { MATCHER_SCOPE(TransposeSinkingGeneralBackward); add_matcher(); - add_matcher(); + add_matcher(); add_matcher(); add_matcher(); add_matcher(); diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_split.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_split.cpp index b623cf6b830..2fac58b7565 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_split.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_split.cpp @@ -52,78 +52,69 @@ OutputTranspose GetOutputTransposes(NodePtr node) { return OutputTranspose(); } -NodePtr FindSplitInput(Node* node) { +template +std::shared_ptr FindInputNode(ov::Node* node) { for (size_t input_idx = 0; input_idx < node->get_input_size(); ++input_idx) { - NodePtr input_node = node->get_input_node_shared_ptr(input_idx); - auto split_node = as_type_ptr(input_node); - if (split_node) - return split_node; + std::shared_ptr input_node = node->get_input_node_shared_ptr(input_idx); + auto target_node = ov::as_type_ptr(input_node); + if (target_node) + return target_node; } return {}; } -std::shared_ptr GetTransposeConstant(Input input) { - auto transpose_node = dynamic_cast(input.get_node()); - if (!transpose_node) - return {}; - - if (!is_sinking_node(input.get_node())) - return {}; - - auto constant_node = as_type_ptr(transpose_node->input_value(1).get_node_shared_ptr()); - if (!constant_node) - return {}; - - return constant_node; -} - bool HasInputSplitAndTransposeSiblings(const Output& output) { - NodePtr split_node = FindSplitInput(output.get_node()); - if (!split_node) { + NodePtr main_node = FindInputNode(output.get_node()); + if (!main_node) { return false; } - AxisVector first_transpose_axis_order; - // get first transpose axis - { - auto constant_node = GetTransposeConstant(*(split_node->get_output_target_inputs(0).begin())); - if (!constant_node) - return false; - first_transpose_axis_order = constant_node->get_axis_vector_val(); - } - - for (size_t output_idx = 1; output_idx < split_node->get_output_size(); ++output_idx) { - for (auto& input : split_node->get_output_target_inputs(output_idx)) { - auto constant_node = GetTransposeConstant(input); - if (!constant_node) - return false; - - AxisVector transpose_axis_order = constant_node->get_axis_vector_val(); - if (transpose_axis_order.size() != first_transpose_axis_order.size()) - return false; - if (!std::equal(transpose_axis_order.begin(), - transpose_axis_order.end(), - first_transpose_axis_order.begin())) - return false; - } - } - - return true; + return HasSameOutputTransposeNodes(main_node); } + +bool IsSplitSinked(const Output& output) { + return HasInputSplitAndTransposeSiblings(output) && is_sinking_node(output); +} + } // namespace +/* + * We follow Transpose operations rather than Split. We cannot create matcher pattern + * for Split with Transpose outputs since Split can have different number of outputs. + * We just can: + * - specify Split as searched node and check if it has transpose outputs + * - specify Transpose as searched node and check if it has Split input + * Transformations are called on each found node in sorted order from the start to end + * of the network. When we proceed Split backward sinking we move input transpose + * to the input of the Split operation. + * Consider case Split (1) -> Split (2) -> Transpose + * If specify Split as main searched node after first transformation work we will have + * Split (1) -> Transpose -> Split(2) + * Matcher pass will not call TransposeSinkingSplitBackward since + * - matcher pattern has no Transpose label + * - Split (1) has already been proceeded + * Adding Split(2) into the working queue as register_new_node(split) + * cannot help us. We just can try to find all input Split operations and add them with + * register_new_node(). Implemented way is simpler. + * + * We sink Transpose through Split operation in a backward way only if all the output + * nodes are the same Transpose. We can: + * - clone Split with all outputs except Transpose + * causes perfomance problems + * - add reversed Transpose operations on all outputs except sinking Transpose + * nothing to do with new added output Transposes + */ ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() { MATCHER_SCOPE(TransposeSinkingSplitBackward); - auto transpose_const_label = wrap_type(consumers_count(1)); - auto transpose_label = - wrap_type({any_input(), transpose_const_label}, HasInputSplitAndTransposeSiblings); + auto transpose_const_label = wrap_type(); + auto transpose_label = wrap_type({any_input(), transpose_const_label}, IsSplitSinked); matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { const auto& pattern_to_output = m.get_pattern_value_map(); auto transpose_label_node = pattern_to_output.at(transpose_label).get_node(); - NodePtr split = FindSplitInput(transpose_label_node); + NodePtr split = FindInputNode(transpose_label_node); auto split_axis_constant = as_type_ptr(split->input_value(1).get_node_shared_ptr()); OutputTranspose output_transpose = GetOutputTransposes(split); @@ -161,11 +152,8 @@ ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() { new_split_axis_const); // remove split output transposes - for (size_t output_idx = 0; output_idx < split->get_output_size(); ++output_idx) { - for (auto& input : split->get_output_target_inputs(output_idx)) { - input.get_node()->output(0).replace(split->output(output_idx)); - } - } + RemoveSingleOutputConsumers(split); + return true; }; @@ -186,7 +174,7 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() { TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node); - sink_forward::RemoveZeroInputNode(main_node); + sink_forward::RemoveInputNode(main_node, /* input_idx */ 0); for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { register_new_node(new_node); transpose_sinking::UpdateForwardSinkingAbility(new_node); diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_unary.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_unary.cpp index 0eebe9be309..9a164c3642c 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_unary.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_unary.cpp @@ -10,6 +10,10 @@ #include "transformations/rt_info/transpose_sinking_attr.hpp" using namespace ov; +using namespace ov::opset9; +using namespace ov::pass::pattern; +using namespace ov::op::util; +using namespace transpose_sinking; namespace { @@ -90,16 +94,11 @@ NodePair Swap(NodePtr first_node, NodePtr second_node) { ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() { MATCHER_SCOPE(TransposeSinkingUnaryForward); - auto transpose_label = ov::pass::pattern::wrap_type( - {ov::pass::pattern::any_input(), ov::pass::pattern::any_input()}); - auto unary_label = ov::pass::pattern::wrap_type({transpose_label}); + auto transpose_label = wrap_type({any_input(), any_input()}); + auto unary_label = + wrap_type({transpose_label}); - ov::matcher_pass_callback matcher_pass_callback = [=](ov::pass::pattern::Matcher& m) { + ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { const auto& pattern_to_output = m.get_pattern_value_map(); auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr(); @@ -109,12 +108,12 @@ ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() { register_new_node(new_nodes.first); register_new_node(new_nodes.second); - transpose_sinking::UpdateForwardSinkingAbility(new_nodes.second); + UpdateForwardSinkingAbility(new_nodes.second); return true; }; - auto m = std::make_shared(unary_label, "ov::pass::TransposeSinkingUnaryForward"); + auto m = std::make_shared(unary_label, "ov::pass::TransposeSinkingUnaryForward"); register_matcher(m, matcher_pass_callback); } @@ -124,21 +123,16 @@ bool IfSinkingEnabled(const Output& output) { } } // namespace -ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() { - MATCHER_SCOPE(TransposeSinkingUnaryBackward); +ov::pass::TransposeSinkingUnaryBackwardSingleConsumer::TransposeSinkingUnaryBackwardSingleConsumer() { + MATCHER_SCOPE(TransposeSinkingUnaryBackwardSingleConsumer); - auto unary_label = ov::pass::pattern::wrap_type({ov::pass::pattern::any_input()}); + auto unary_label = + wrap_type({any_input()}, + consumers_count(1)); - auto transpose_label = - ov::pass::pattern::wrap_type({unary_label, ov::pass::pattern::any_input()}, - IfSinkingEnabled); + auto transpose_label = wrap_type({unary_label, any_input()}, IfSinkingEnabled); - ov::matcher_pass_callback matcher_pass_callback = [=](ov::pass::pattern::Matcher& m) { + ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { const auto& pattern_to_output = m.get_pattern_value_map(); auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr(); @@ -151,6 +145,55 @@ ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() { return true; }; - auto m = std::make_shared(transpose_label, "ov::pass::TransposeSinkingUnaryBackward"); + auto m = std::make_shared(transpose_label, "ov::pass::TransposeSinkingUnaryBackwardSingleConsumer"); register_matcher(m, matcher_pass_callback); } + +namespace { +std::function)> consumers_more_than(size_t n) { + return [=](Output output) -> bool { + return output.get_target_inputs().size() > n; + }; +} +} // namespace + +ov::pass::TransposeSinkingUnaryBackwardMultiConsumers::TransposeSinkingUnaryBackwardMultiConsumers() { + MATCHER_SCOPE(TransposeSinkingUnaryBackwardMultiConsumers); + + auto unary_restrictions = [](const Output& output) -> bool { + return consumers_more_than(1)(output) && HasSameOutputTransposeNodes(output); + }; + + auto unary_label = + wrap_type({any_input()}, + unary_restrictions); + + auto transpose_const_label = wrap_type(); + + auto transpose_label = wrap_type({unary_label, transpose_const_label}, IfSinkingEnabled); + + ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { + const auto& pattern_to_output = m.get_pattern_value_map(); + auto transpose_const = as_type_ptr(pattern_to_output.at(transpose_const_label).get_node_shared_ptr()); + auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); + auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr(); + + for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary, transpose_const)) { + register_new_node(new_node); + } + + // remove output transposes + RemoveSingleOutputConsumers(unary); + + return true; + }; + + auto m = std::make_shared(transpose_label, "ov::pass::TransposeSinkingUnaryBackwardMultiConsumers"); + register_matcher(m, matcher_pass_callback); +} + +ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() { + MATCHER_SCOPE(TransposeSinkingUnaryBackward); + add_matcher(); + add_matcher(); +} diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_utils.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_utils.cpp index 334ba387066..74974996824 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_utils.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_utils.cpp @@ -115,8 +115,7 @@ ov::Output FixInputNodeRank(ov::Output input_node, ov::Rank: namespace sink_forward { -// insert input reversed transposes, remove first input tranpose -void UpdateInputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_input_info) { +void UpdateInputTransposes(NodePtr main_node, const TransposeInputsInfo& transpose_input_info) { if (transpose_input_info.isEmpty() || HasDynamicRankInput(main_node)) return; @@ -149,15 +148,15 @@ void UpdateInputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_inp } } -void RemoveZeroInputNode(NodePtr main_node) { - auto input_node = main_node->input_value(0); - if (input_node.get_node()->get_input_size() < 1) +void RemoveInputNode(NodePtr main_node, size_t input_idx) { + auto input_node = main_node->input_value(input_idx); + if (input_node.get_node()->get_input_size() < (input_idx + 1)) return; - auto parent_node = input_node.get_node()->input_value(0); - main_node->input(0).replace_source_output(parent_node); + auto parent_node = input_node.get_node()->input_value(input_idx); + main_node->input(input_idx).replace_source_output(parent_node); } -NodeVector InsertOutputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_input_info) { +NodeVector InsertOutputTransposes(NodePtr main_node, const TransposeInputsInfo& transpose_input_info) { if (transpose_input_info.isEmpty()) return {}; const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val(); @@ -246,6 +245,7 @@ bool CanPropagateForwardThrough(Node* node) { CHECK_TRANSPOSE_SINKING_SUPPORTED(Concat, node); CHECK_TRANSPOSE_SINKING_SUPPORTED(Split, node); CHECK_TRANSPOSE_SINKING_SUPPORTED(Transpose, node); + CHECK_TRANSPOSE_SINKING_SUPPORTED(PRelu, node); return false; } @@ -268,4 +268,76 @@ void UpdateForwardSinkingAbility(NodePtr node) { mark_as_no_sinking_node(node); } +namespace { + +std::shared_ptr GetTransposeConstant(Node* node) { + auto transpose_node = dynamic_cast(node); + if (!transpose_node) + return {}; + + auto constant_node = as_type_ptr(transpose_node->input_value(1).get_node_shared_ptr()); + if (!constant_node) + return {}; + + return constant_node; +} + +Node* FindFirstConsumer(NodePtr node) { + for (size_t output_idx = 0; output_idx < node->get_output_size(); ++output_idx) { + auto inputs = node->get_output_target_inputs(output_idx); + if (inputs.empty()) + continue; + return inputs.begin()->get_node(); + } + return nullptr; +} + +bool HasSameOutputTransposeNodes(NodePtr main_node) { + AxisVector first_transpose_axis_order; + { + Node* first_consumer = FindFirstConsumer(main_node); + if (!first_consumer) + return false; + auto constant_node = GetTransposeConstant(first_consumer); + if (!constant_node) + return false; + first_transpose_axis_order = constant_node->get_axis_vector_val(); + } + + for (size_t output_idx = 0; output_idx < main_node->get_output_size(); ++output_idx) { + for (auto& input : main_node->get_output_target_inputs(output_idx)) { + auto constant_node = GetTransposeConstant(input.get_node()); + if (!constant_node) + return false; + + AxisVector transpose_axis_order = constant_node->get_axis_vector_val(); + if (transpose_axis_order.size() != first_transpose_axis_order.size()) + return false; + if (!std::equal(transpose_axis_order.begin(), + transpose_axis_order.end(), + first_transpose_axis_order.begin())) + return false; + } + } + + return true; +} + +} // namespace + +bool HasSameOutputTransposeNodes(const Output& output) { + return HasSameOutputTransposeNodes(output.get_node_shared_ptr()); +} + +void RemoveSingleOutputConsumers(NodePtr node) { + for (size_t output_idx = 0; output_idx < node->get_output_size(); ++output_idx) { + for (auto& input : node->get_output_target_inputs(output_idx)) { + Node* consumer = input.get_node(); + if (consumer->get_output_size() != 1) + continue; + consumer->output(0).replace(node->output(output_idx)); + } + } +} + } // namespace transpose_sinking diff --git a/src/common/transformations/tests/common_optimizations/transpose_sinking_binary_test.cpp b/src/common/transformations/tests/common_optimizations/transpose_sinking_binary_test.cpp index a04c81062da..e65e892655e 100644 --- a/src/common/transformations/tests/common_optimizations/transpose_sinking_binary_test.cpp +++ b/src/common/transformations/tests/common_optimizations/transpose_sinking_binary_test.cpp @@ -12,14 +12,19 @@ #include "common_test_utils/ngraph_test_utils.hpp" #include "gtest/gtest.h" +using namespace ov; +using namespace ov::opset9; + +namespace transpose_sinking_binary_eltwise { + namespace { using NodePtr = std::shared_ptr; -using ModelPtr = std::shared_ptr; +using ModelPtr = std::shared_ptr; using Output = ov::Output; namespace { -std::string to_string(const ov::Shape& shape) { +std::string to_string(const Shape& shape) { std::ostringstream result; result << "{"; for (size_t idx = 0; idx < shape.size(); ++idx) { @@ -92,7 +97,23 @@ public: #define CREATE_PASS_FACTORY(pass_name) std::make_shared>(#pass_name) #undef CREATE_BINARY_FACTORY -#define CREATE_BINARY_FACTORY(type_name) CreateBinaryFactory(#type_name) +#define CREATE_BINARY_FACTORY(type_name) CreateBinaryFactory(#type_name) + +/* + * binary operations without PRelu + * PRelu input(1) is special constant input that is important for some tests. Specially for the + * Unsqueeze insertion + */ +std::vector binary_elementwise_factories = {CREATE_BINARY_FACTORY(Add), + CREATE_BINARY_FACTORY(Divide), + CREATE_BINARY_FACTORY(Maximum), + CREATE_BINARY_FACTORY(Minimum), + CREATE_BINARY_FACTORY(Mod), + CREATE_BINARY_FACTORY(Multiply), + CREATE_BINARY_FACTORY(Power), + CREATE_BINARY_FACTORY(SquaredDifference), + CREATE_BINARY_FACTORY(Subtract)}; + std::vector binary_factories = {CREATE_BINARY_FACTORY(Add), CREATE_BINARY_FACTORY(Divide), CREATE_BINARY_FACTORY(Maximum), @@ -101,8 +122,8 @@ std::vector binary_factories = {CREATE_BINARY_FACTORY(Add), CREATE_BINARY_FACTORY(Multiply), CREATE_BINARY_FACTORY(Power), CREATE_BINARY_FACTORY(SquaredDifference), - CREATE_BINARY_FACTORY(Subtract)}; -#undef CREATE_BINARY_FACTORY + CREATE_BINARY_FACTORY(Subtract), + CREATE_BINARY_FACTORY(PRelu)}; std::vector binary_operations_numbers = {1, 10}; @@ -110,51 +131,49 @@ std::vector binary_transpose_input_indexes = {0, 1}; } // namespace -namespace binary { namespace single_consumer { namespace forward { namespace one_input_transpose { -std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, - size_t num_binary_ops, - ov::element::Type input_type, - size_t binary_transpose_input_idx) { - const ov::Shape input_shape{1, 96, 55, 55}; - const ov::Shape const_shape{1, 55, 55, 96}; +std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, + size_t num_binary_ops, + element::Type input_type, + size_t binary_transpose_input_idx) { + const Shape input_shape{1, 96, 55, 55}; + const Shape const_shape{1, 55, 55, 96}; - auto X = std::make_shared(input_type, input_shape); + auto X = std::make_shared(input_type, input_shape); - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(X, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); NodePtr in_op = transpose0; for (size_t i = 0; i < num_binary_ops; ++i) { - auto in_constant = std::make_shared(input_type, const_shape, ov::Shape{1}); + auto in_constant = std::make_shared(input_type, const_shape, Shape{1}); if (!binary_transpose_input_idx) in_op = binary_factory->create(in_op, in_constant); else in_op = binary_factory->create(in_constant, in_op); } - return std::make_shared(ov::OutputVector{in_op}, ov::ParameterVector{X}); + return std::make_shared(ov::OutputVector{in_op}, ov::ParameterVector{X}); } -std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_factory, - size_t num_binary_ops, - ov::element::Type input_type, - size_t binary_transpose_input_idx) { - const ov::Shape input_shape{1, 96, 55, 55}; - const ov::Shape const_shape{1, 55, 55, 96}; +std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_factory, + size_t num_binary_ops, + element::Type input_type, + size_t binary_transpose_input_idx) { + const Shape input_shape{1, 96, 55, 55}; + const Shape const_shape{1, 55, 55, 96}; - auto X = std::make_shared(input_type, input_shape); + auto X = std::make_shared(input_type, input_shape); NodePtr in_op = X; for (size_t i = 0; i < num_binary_ops; ++i) { - auto in_constant = std::make_shared(input_type, const_shape, ov::Shape{1}); + auto in_constant = std::make_shared(input_type, const_shape, Shape{1}); - auto transpose_reversed_const = - std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); - auto transpose_reversed = std::make_shared(in_constant, transpose_reversed_const); + auto transpose_reversed_const = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose_reversed = std::make_shared(in_constant, transpose_reversed_const); if (!binary_transpose_input_idx) in_op = binary_factory->create(in_op, transpose_reversed); @@ -162,109 +181,172 @@ std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_facto in_op = binary_factory->create(transpose_reversed, in_op); } - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(in_op, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); - return std::make_shared(ov::OutputVector{transpose0}, ov::ParameterVector{X}); + return std::make_shared(ov::OutputVector{transpose0}, ov::ParameterVector{X}); } } // namespace one_input_transpose namespace double_transpose { -std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, - size_t num_binary_ops, - ov::element::Type input_type) { - const ov::Shape input_shape{1, 96, 55, 55}; +std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, + size_t num_binary_ops, + element::Type input_type) { + const Shape input_shape{1, 96, 55, 55}; - auto X = std::make_shared(input_type, input_shape); + auto X = std::make_shared(input_type, input_shape); - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(X, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); NodePtr in_op = transpose0; for (size_t i = 0; i < num_binary_ops; ++i) { - auto in_constant = std::make_shared(input_type, input_shape, ov::Shape{1}); - auto ng_order1 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose1 = std::make_shared(in_constant, ng_order1); + auto in_constant = std::make_shared(input_type, input_shape, Shape{1}); + auto ng_order1 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose1 = std::make_shared(in_constant, ng_order1); in_op = binary_factory->create(in_op, transpose1); } - return std::make_shared(ov::OutputVector{in_op}, ov::ParameterVector{X}); + return std::make_shared(ov::OutputVector{in_op}, ov::ParameterVector{X}); } -std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_factory, - size_t num_binary_ops, - ov::element::Type input_type) { - const ov::Shape input_shape{1, 96, 55, 55}; +std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_factory, + size_t num_binary_ops, + element::Type input_type) { + const Shape input_shape{1, 96, 55, 55}; - auto X = std::make_shared(input_type, input_shape); + auto X = std::make_shared(input_type, input_shape); NodePtr in_op = X; for (size_t i = 0; i < num_binary_ops; ++i) { - auto in_constant = std::make_shared(input_type, input_shape, ov::Shape{1}); + auto in_constant = std::make_shared(input_type, input_shape, Shape{1}); - auto ng_order1 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose1 = std::make_shared(in_constant, ng_order1); + auto ng_order1 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose1 = std::make_shared(in_constant, ng_order1); - auto transpose_reversed_const = - std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); - auto transpose_reversed = std::make_shared(transpose1, transpose_reversed_const); + auto transpose_reversed_const = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose_reversed = std::make_shared(transpose1, transpose_reversed_const); in_op = binary_factory->create(in_op, transpose_reversed); } - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(in_op, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); - return std::make_shared(ov::OutputVector{transpose0}, ov::ParameterVector{X}); + return std::make_shared(ov::OutputVector{transpose0}, ov::ParameterVector{X}); } +using CreateGraphBinaryTwoTransposeInputsF = std::function< + std::shared_ptr(BinaryFactoryPtr binary_factory, size_t num_binary_ops, element::Type input_type)>; + +using TestBinaryTwoTransposeInputsParams = + std::tuple; /* input type */ + +class TransposeSinkingBinaryTwoTransposeInputsTestFixture + : public ::testing::WithParamInterface, + public TransformationTestsF { +public: + static std::string get_test_name(const testing::TestParamInfo& obj) { + BinaryFactoryPtr binary_factory; + PassFactoryPtr pass_factory; + size_t num_binary_ops; + CreateGraphBinaryTwoTransposeInputsF model_factory; + CreateGraphBinaryTwoTransposeInputsF reference_model_factory; + element::Type input_type; + + std::tie(binary_factory, pass_factory, num_binary_ops, model_factory, reference_model_factory, input_type) = + obj.param; + + std::ostringstream test_name; + test_name << "binaryFactory=" << binary_factory->getTypeName() << "/"; + test_name << "passFactory=" << pass_factory->getTypeName() << "/"; + test_name << "numBinaryOps=" << num_binary_ops << "/"; + test_name << "inputType=" << input_type; + + return test_name.str(); + } +}; + +TEST_P(TransposeSinkingBinaryTwoTransposeInputsTestFixture, CompareFunctions) { + BinaryFactoryPtr binary_factory; + PassFactoryPtr pass_factory; + size_t num_binary_ops; + CreateGraphBinaryTwoTransposeInputsF model_factory; + CreateGraphBinaryTwoTransposeInputsF reference_model_factory; + element::Type input_type; + + std::tie(binary_factory, pass_factory, num_binary_ops, model_factory, reference_model_factory, input_type) = + this->GetParam(); + + model = model_factory(binary_factory, num_binary_ops, input_type); + model_ref = reference_model_factory(binary_factory, num_binary_ops, input_type); + pass_factory->registerPass(manager); +} + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryTwoTransposeInputsForwardTestSuite, + TransposeSinkingBinaryTwoTransposeInputsTestFixture, + ::testing::Combine(::testing::ValuesIn(binary_factories), + ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)), + ::testing::ValuesIn(binary_operations_numbers), + ::testing::Values(CreateFunction), + ::testing::Values(CreateReferenceFunction), + ::testing::Values(element::f32)), + TransposeSinkingBinaryTwoTransposeInputsTestFixture::get_test_name); + } // namespace double_transpose } // namespace forward namespace backward { namespace one_input_transpose { -std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, - size_t num_binary_ops, - ov::element::Type input_type, - size_t binary_transpose_input_idx) { - const ov::Shape input_shape{1, 96, 55, 55}; +std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, + size_t num_binary_ops, + element::Type input_type, + size_t binary_transpose_input_idx) { + const Shape input_shape{1, 96, 55, 55}; - auto X = std::make_shared(input_type, input_shape); + auto X = std::make_shared(input_type, input_shape); NodePtr in_op = X; for (size_t i = 0; i < num_binary_ops; ++i) { - auto in_constant = std::make_shared(input_type, input_shape, ov::Shape{1}); + auto in_constant = std::make_shared(input_type, input_shape, Shape{1}); if (!binary_transpose_input_idx) in_op = binary_factory->create(in_op, in_constant); else in_op = binary_factory->create(in_constant, in_op); } - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(in_op, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); - return std::make_shared(ov::OutputVector{transpose0}, ov::ParameterVector{X}); + return std::make_shared(ov::OutputVector{transpose0}, ov::ParameterVector{X}); } -std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_factory, - size_t num_binary_ops, - ov::element::Type input_type, - size_t binary_transpose_input_idx) { - const ov::Shape input_shape{1, 96, 55, 55}; +std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_factory, + size_t num_binary_ops, + element::Type input_type, + size_t binary_transpose_input_idx) { + const Shape input_shape{1, 96, 55, 55}; - auto X = std::make_shared(input_type, input_shape); + auto X = std::make_shared(input_type, input_shape); - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(X, ng_order0); + auto tanh = std::make_shared(X); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); NodePtr in_op = transpose0; for (size_t i = 0; i < num_binary_ops; ++i) { - auto in_constant = std::make_shared(input_type, input_shape, ov::Shape{1}); + auto in_constant = std::make_shared(input_type, input_shape, Shape{1}); - auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose = std::make_shared(in_constant, ng_order); + auto ng_order = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose = std::make_shared(in_constant, ng_order); if (!binary_transpose_input_idx) in_op = binary_factory->create(in_op, transpose); @@ -272,24 +354,20 @@ std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_facto in_op = binary_factory->create(transpose, in_op); } - return std::make_shared(ov::OutputVector{in_op}, ov::ParameterVector{X}); + return std::make_shared(ov::OutputVector{in_op}, ov::ParameterVector{X}); } -} // namespace one_input_transpose -} // namespace backward -} // namespace single_consumer -} // namespace binary -using CreateGraphBinaryF = std::function(BinaryFactoryPtr unary_factory, - size_t num_binary_ops, - ov::element::Type input_type, - size_t binary_transpose_input_idx)>; +using CreateGraphBinaryF = std::function(BinaryFactoryPtr binary_factory, + size_t num_binary_ops, + element::Type input_type, + size_t binary_transpose_input_idx)>; using TestBinaryParams = std::tuple; /* binary_transpose_input_idx */ class TransposeSinkingBinaryTestFixture : public ::testing::WithParamInterface, @@ -301,8 +379,9 @@ public: size_t num_binary_ops; CreateGraphBinaryF model_factory; CreateGraphBinaryF reference_model_factory; - ov::element::Type input_type; + element::Type input_type; size_t binary_transpose_input_idx; + std::tie(binary_factory, pass_factory, num_binary_ops, @@ -312,11 +391,11 @@ public: binary_transpose_input_idx) = obj.param; std::ostringstream test_name; - test_name << "binary_factory=" << binary_factory->getTypeName() << "_"; - test_name << "pass_factory=" << pass_factory->getTypeName() << "_"; - test_name << "num_binary_ops=" << num_binary_ops << "_"; - test_name << "input_type=" << input_type << "_"; - test_name << "binary_transpose_input_idx=" << binary_transpose_input_idx; + test_name << "binaryFactory=" << binary_factory->getTypeName() << "/"; + test_name << "passFactory=" << pass_factory->getTypeName() << "/"; + test_name << "numBinaryOps=" << num_binary_ops << "/"; + test_name << "inputType=" << input_type << "/"; + test_name << "binaryTransposeInputIdx=" << binary_transpose_input_idx; return test_name.str(); } @@ -328,7 +407,7 @@ TEST_P(TransposeSinkingBinaryTestFixture, CompareFunctions) { size_t num_binary_ops; CreateGraphBinaryF model_factory; CreateGraphBinaryF reference_model_factory; - ov::element::Type input_type; + element::Type input_type; size_t binary_transpose_input_idx; std::tie(binary_factory, pass_factory, @@ -346,44 +425,42 @@ TEST_P(TransposeSinkingBinaryTestFixture, CompareFunctions) { INSTANTIATE_TEST_SUITE_P( TransposeSinkingBinaryForwardTestSuite, TransposeSinkingBinaryTestFixture, - ::testing::Combine( - ::testing::ValuesIn(binary_factories), - ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryElementwiseForward)), - ::testing::ValuesIn(binary_operations_numbers), - ::testing::Values(binary::single_consumer::forward::one_input_transpose::CreateFunction), - ::testing::Values(binary::single_consumer::forward::one_input_transpose::CreateReferenceFunction), - ::testing::Values(ov::element::f32), - ::testing::ValuesIn(binary_transpose_input_indexes)), + ::testing::Combine(::testing::ValuesIn(binary_factories), + ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)), + ::testing::ValuesIn(binary_operations_numbers), + ::testing::Values(single_consumer::forward::one_input_transpose::CreateFunction), + ::testing::Values(single_consumer::forward::one_input_transpose::CreateReferenceFunction), + ::testing::Values(element::f32), + ::testing::ValuesIn(binary_transpose_input_indexes)), TransposeSinkingBinaryTestFixture::get_test_name); INSTANTIATE_TEST_SUITE_P( TransposeSinkingBinaryBackwardTestSuite, TransposeSinkingBinaryTestFixture, - ::testing::Combine( - ::testing::ValuesIn(binary_factories), - ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryElementwiseBackward)), - ::testing::ValuesIn(binary_operations_numbers), - ::testing::Values(binary::single_consumer::backward::one_input_transpose::CreateFunction), - ::testing::Values(binary::single_consumer::backward::one_input_transpose::CreateReferenceFunction), - ::testing::Values(ov::element::f32), - ::testing::ValuesIn(binary_transpose_input_indexes)), + ::testing::Combine(::testing::ValuesIn(binary_factories), + ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)), + ::testing::ValuesIn(binary_operations_numbers), + ::testing::Values(single_consumer::backward::one_input_transpose::CreateFunction), + ::testing::Values(single_consumer::backward::one_input_transpose::CreateReferenceFunction), + ::testing::Values(element::f32), + ::testing::ValuesIn(binary_transpose_input_indexes)), TransposeSinkingBinaryTestFixture::get_test_name); // -------------------------------------------------------------------------------------- -using CreateGraphBinaryIncompatShapesF = std::function(BinaryFactoryPtr unary_factory, - ov::element::Type input_type, - ov::Shape input_shape, - ov::Shape constant_shape, - size_t binary_transpose_input_idx)>; +using CreateGraphBinaryIncompatShapesF = std::function(BinaryFactoryPtr unary_factory, + element::Type input_type, + Shape input_shape, + Shape constant_shape, + size_t binary_transpose_input_idx)>; using TestBinaryIncompatShapesParams = std::tuple; /* binary_transpose_input_idx */ class TransposeSinkingBinaryIncompatShapesTestFixture @@ -393,11 +470,11 @@ public: static std::string get_test_name(const testing::TestParamInfo& obj) { BinaryFactoryPtr binary_factory; PassFactoryPtr pass_factory; - ov::Shape input_shape; - ov::Shape constant_shape; + Shape input_shape; + Shape constant_shape; CreateGraphBinaryIncompatShapesF model_factory; CreateGraphBinaryIncompatShapesF reference_model_factory; - ov::element::Type input_type; + element::Type input_type; size_t binary_transpose_input_idx; std::tie(binary_factory, pass_factory, @@ -409,12 +486,12 @@ public: binary_transpose_input_idx) = obj.param; std::ostringstream test_name; - test_name << "binary_factory=" << binary_factory->getTypeName() << "_"; - test_name << "pass_factory=" << pass_factory->getTypeName() << "_"; - test_name << "input_shape=" << to_string(input_shape) << "_"; - test_name << "constant_shape=" << to_string(constant_shape) << "_"; - test_name << "input_type=" << input_type << "_"; - test_name << "binary_transpose_input_idx=" << binary_transpose_input_idx; + test_name << "binaryFactory=" << binary_factory->getTypeName() << "/"; + test_name << "passFactory=" << pass_factory->getTypeName() << "/"; + test_name << "inputShape=" << to_string(input_shape) << "/"; + test_name << "constantShape=" << to_string(constant_shape) << "/"; + test_name << "inputType=" << input_type << "/"; + test_name << "binaryTransposeInputIdx=" << binary_transpose_input_idx; return test_name.str(); } @@ -423,11 +500,11 @@ public: TEST_P(TransposeSinkingBinaryIncompatShapesTestFixture, CompareFunctions) { BinaryFactoryPtr binary_factory; PassFactoryPtr pass_factory; - ov::Shape input_shape; - ov::Shape constant_shape; + Shape input_shape; + Shape constant_shape; CreateGraphBinaryIncompatShapesF model_factory; CreateGraphBinaryIncompatShapesF reference_model_factory; - ov::element::Type input_type; + element::Type input_type; size_t binary_transpose_input_idx; std::tie(binary_factory, pass_factory, @@ -449,14 +526,14 @@ namespace single_consumer { namespace backward { namespace incompat_shapes { -std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, - ov::element::Type input_type, - ov::Shape input_shape, - ov::Shape constant_shape, - size_t binary_transpose_input_idx) { - auto X = std::make_shared(input_type, input_shape); +std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, + element::Type input_type, + Shape input_shape, + Shape constant_shape, + size_t binary_transpose_input_idx) { + auto X = std::make_shared(input_type, input_shape); - auto in_constant = std::make_shared(input_type, constant_shape, ov::Shape{1}); + auto in_constant = std::make_shared(input_type, constant_shape, Shape{1}); NodePtr binary_op; if (!binary_transpose_input_idx) @@ -464,31 +541,31 @@ std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, else binary_op = binary_factory->create(in_constant, X); - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(binary_op, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(binary_op, ng_order0); - return std::make_shared(ov::OutputVector{transpose0}, ov::ParameterVector{X}); + return std::make_shared(ov::OutputVector{transpose0}, ov::ParameterVector{X}); } -std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_factory, - ov::element::Type input_type, - ov::Shape input_shape, - ov::Shape constant_shape, - size_t binary_transpose_input_idx) { - auto X = std::make_shared(input_type, input_shape); +std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_factory, + element::Type input_type, + Shape input_shape, + Shape constant_shape, + size_t binary_transpose_input_idx) { + auto X = std::make_shared(input_type, input_shape); - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(X, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); - auto in_constant = std::make_shared(input_type, constant_shape, ov::Shape{1}); + auto in_constant = std::make_shared(input_type, constant_shape, Shape{1}); std::vector dims(input_shape.size() - constant_shape.size()); std::iota(dims.begin(), dims.end(), 0); - auto unsqueeze_const = std::make_shared(ov::element::i64, ov::Shape{dims.size()}, dims); - auto unsqeeze = std::make_shared(in_constant, unsqueeze_const); + auto unsqueeze_const = std::make_shared(element::i64, Shape{dims.size()}, dims); + auto unsqeeze = std::make_shared(in_constant, unsqueeze_const); - auto ng_order1 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose1 = std::make_shared(unsqeeze, ng_order1); + auto ng_order1 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose1 = std::make_shared(unsqeeze, ng_order1); NodePtr binary_op; if (!binary_transpose_input_idx) @@ -496,10 +573,10 @@ std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_facto else binary_op = binary_factory->create(transpose1, transpose0); - return std::make_shared(ov::OutputVector{binary_op}, ov::ParameterVector{X}); + return std::make_shared(ov::OutputVector{binary_op}, ov::ParameterVector{X}); } -std::vector constant_shapes = {ov::Shape{96, 55, 55}, ov::Shape{1}}; +std::vector constant_shapes = {Shape{96, 55, 55}, Shape{1}}; } // namespace incompat_shapes } // namespace backward @@ -507,17 +584,17 @@ std::vector constant_shapes = {ov::Shape{96, 55, 55}, ov::Shape{1}}; namespace forward { namespace incompat_shapes { -std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, - ov::element::Type input_type, - ov::Shape input_shape, - ov::Shape constant_shape, - size_t binary_transpose_input_idx) { - auto X = std::make_shared(input_type, input_shape); +std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, + element::Type input_type, + Shape input_shape, + Shape constant_shape, + size_t binary_transpose_input_idx) { + auto X = std::make_shared(input_type, input_shape); - auto in_constant = std::make_shared(input_type, constant_shape, ov::Shape{1}); + auto in_constant = std::make_shared(input_type, constant_shape, Shape{1}); - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(X, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); NodePtr binary_op; if (!binary_transpose_input_idx) @@ -525,25 +602,25 @@ std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, else binary_op = binary_factory->create(in_constant, transpose0); - return std::make_shared(ov::OutputVector{binary_op}, ov::ParameterVector{X}); + return std::make_shared(ov::OutputVector{binary_op}, ov::ParameterVector{X}); } -std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_factory, - ov::element::Type input_type, - ov::Shape input_shape, - ov::Shape constant_shape, - size_t binary_transpose_input_idx) { - auto X = std::make_shared(input_type, input_shape); +std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_factory, + element::Type input_type, + Shape input_shape, + Shape constant_shape, + size_t binary_transpose_input_idx) { + auto X = std::make_shared(input_type, input_shape); - auto in_constant = std::make_shared(input_type, constant_shape, ov::Shape{1}); + auto in_constant = std::make_shared(input_type, constant_shape, Shape{1}); std::vector dims(input_shape.size() - constant_shape.size()); std::iota(dims.begin(), dims.end(), 0); - auto unsqueeze_const = std::make_shared(ov::element::i64, ov::Shape{dims.size()}, dims); - auto unsqeeze = std::make_shared(in_constant, unsqueeze_const); + auto unsqueeze_const = std::make_shared(element::i64, Shape{dims.size()}, dims); + auto unsqeeze = std::make_shared(in_constant, unsqueeze_const); - auto ng_order1 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); - auto transpose1 = std::make_shared(unsqeeze, ng_order1); + auto ng_order1 = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose1 = std::make_shared(unsqeeze, ng_order1); NodePtr binary_op; if (!binary_transpose_input_idx) @@ -551,13 +628,13 @@ std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_facto else binary_op = binary_factory->create(transpose1, X); - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(binary_op, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(binary_op, ng_order0); - return std::make_shared(ov::OutputVector{transpose0}, ov::ParameterVector{X}); + return std::make_shared(ov::OutputVector{transpose0}, ov::ParameterVector{X}); } -std::vector constant_shapes = {ov::Shape{55, 55, 96}, ov::Shape{1}}; +std::vector constant_shapes = {Shape{55, 55, 96}, Shape{1}}; } // namespace incompat_shapes } // namespace forward @@ -568,25 +645,638 @@ std::vector constant_shapes = {ov::Shape{55, 55, 96}, ov::Shape{1}}; INSTANTIATE_TEST_SUITE_P( TransposeSinkingBinaryIncompatShapesBackwardTestSuite, TransposeSinkingBinaryIncompatShapesTestFixture, - ::testing::Combine(::testing::ValuesIn(binary_factories), - ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryElementwiseBackward)), - ::testing::Values(ov::Shape{1, 96, 55, 55}), + ::testing::Combine(::testing::ValuesIn(binary_elementwise_factories), + ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)), + ::testing::Values(Shape{1, 96, 55, 55}), ::testing::ValuesIn(binary::single_consumer::backward::incompat_shapes::constant_shapes), ::testing::Values(binary::single_consumer::backward::incompat_shapes::CreateFunction), ::testing::Values(binary::single_consumer::backward::incompat_shapes::CreateReferenceFunction), - ::testing::Values(ov::element::f32), + ::testing::Values(element::f32), ::testing::ValuesIn(binary_transpose_input_indexes)), TransposeSinkingBinaryIncompatShapesTestFixture::get_test_name); INSTANTIATE_TEST_SUITE_P( TransposeSinkingBinaryIncompatShapesForwardTestSuite, TransposeSinkingBinaryIncompatShapesTestFixture, - ::testing::Combine(::testing::ValuesIn(binary_factories), - ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryElementwiseForward)), - ::testing::Values(ov::Shape{1, 96, 55, 55}), + ::testing::Combine(::testing::ValuesIn(binary_elementwise_factories), + ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)), + ::testing::Values(Shape{1, 96, 55, 55}), ::testing::ValuesIn(binary::single_consumer::forward::incompat_shapes::constant_shapes), ::testing::Values(binary::single_consumer::forward::incompat_shapes::CreateFunction), ::testing::Values(binary::single_consumer::forward::incompat_shapes::CreateReferenceFunction), - ::testing::Values(ov::element::f32), + ::testing::Values(element::f32), ::testing::ValuesIn(binary_transpose_input_indexes)), TransposeSinkingBinaryIncompatShapesTestFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P( + TransposeSinkingPReluIncompatShapesBackwardTestSuite, + TransposeSinkingBinaryIncompatShapesTestFixture, + ::testing::Combine(::testing::Values(CREATE_BINARY_FACTORY(PRelu)), + ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)), + ::testing::Values(Shape{1, 3, 16, 16}), + ::testing::ValuesIn(std::vector{Shape{3}}), + ::testing::Values(binary::single_consumer::backward::incompat_shapes::CreateFunction), + ::testing::Values(binary::single_consumer::backward::incompat_shapes::CreateReferenceFunction), + ::testing::Values(element::f32), + ::testing::Values(0)), + TransposeSinkingBinaryIncompatShapesTestFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P( + TransposeSinkingPReluIncompatShapesForwardTestSuite, + TransposeSinkingBinaryIncompatShapesTestFixture, + ::testing::Combine(::testing::Values(CREATE_BINARY_FACTORY(PRelu)), + ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)), + ::testing::Values(Shape{1, 3, 16, 16}), + ::testing::ValuesIn(std::vector{Shape{3}}), + ::testing::Values(binary::single_consumer::forward::incompat_shapes::CreateFunction), + ::testing::Values(binary::single_consumer::forward::incompat_shapes::CreateReferenceFunction), + ::testing::Values(element::f32), + ::testing::Values(0)), + TransposeSinkingBinaryIncompatShapesTestFixture::get_test_name); + +} // namespace one_input_transpose +} // namespace backward +} // namespace single_consumer + +namespace mult_consumers { +namespace forward { +namespace input_transpose_consumers { + +std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, + element::Type input_type, + size_t binary_transpose_input_idx) { + const Shape input_shape{1, 96, 55, 55}; + const Shape const_shape{1, 55, 55, 96}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + auto tanh = std::make_shared(transpose0); + + NodePtr binary; + auto in_constant = std::make_shared(input_type, const_shape, Shape{1}); + if (!binary_transpose_input_idx) + binary = binary_factory->create(transpose0, in_constant); + else + binary = binary_factory->create(in_constant, transpose0); + + return std::make_shared(ov::OutputVector{binary, tanh}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_factory, + element::Type input_type, + size_t binary_transpose_input_idx) { + const Shape input_shape{1, 96, 55, 55}; + const Shape const_shape{1, 55, 55, 96}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + auto tanh = std::make_shared(transpose0); + + NodePtr binary; + auto in_constant = std::make_shared(input_type, const_shape, Shape{1}); + + auto transpose_reversed_const = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose_reversed = std::make_shared(in_constant, transpose_reversed_const); + + if (!binary_transpose_input_idx) + binary = binary_factory->create(X, transpose_reversed); + else + binary = binary_factory->create(transpose_reversed, X); + + auto ng_order1 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose1 = std::make_shared(binary, ng_order1); + + return std::make_shared(ov::OutputVector{transpose1, tanh}, ov::ParameterVector{X}); +} + +} // namespace input_transpose_consumers + +namespace output_consumers { + +namespace one_binary { + +std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, + element::Type input_type, + size_t binary_transpose_input_idx) { + const Shape input_shape{1, 96, 55, 55}; + const Shape const_shape{1, 55, 55, 96}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + NodePtr binary; + auto in_constant = std::make_shared(input_type, const_shape, Shape{1}); + if (!binary_transpose_input_idx) + binary = binary_factory->create(transpose0, in_constant); + else + binary = binary_factory->create(in_constant, transpose0); + + auto tanh1 = std::make_shared(binary); + auto tanh2 = std::make_shared(binary); + + return std::make_shared(ov::OutputVector{tanh1, tanh2}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_factory, + element::Type input_type, + size_t binary_transpose_input_idx) { + const Shape input_shape{1, 96, 55, 55}; + const Shape const_shape{1, 55, 55, 96}; + + auto X = std::make_shared(input_type, input_shape); + + NodePtr binary; + auto in_constant = std::make_shared(input_type, const_shape, Shape{1}); + + auto transpose_reversed_const = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose_reversed = std::make_shared(in_constant, transpose_reversed_const); + + if (!binary_transpose_input_idx) + binary = binary_factory->create(X, transpose_reversed); + else + binary = binary_factory->create(transpose_reversed, X); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(binary, ng_order0); + + auto tanh1 = std::make_shared(transpose0); + auto tanh2 = std::make_shared(transpose0); + + return std::make_shared(ov::OutputVector{tanh1, tanh2}, ov::ParameterVector{X}); +} + +} // namespace one_binary + +} // namespace output_consumers + +namespace input_node_consumers { + +std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, + element::Type input_type, + size_t binary_transpose_input_idx) { + const Shape input_shape{1, 96, 55, 55}; + const Shape const_shape{1, 55, 55, 96}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + NodePtr binary; + auto in_constant = std::make_shared(input_type, const_shape, Shape{1}); + if (!binary_transpose_input_idx) + binary = binary_factory->create(transpose0, in_constant); + else + binary = binary_factory->create(in_constant, transpose0); + + auto tanh = std::make_shared(X); + + return std::make_shared(ov::OutputVector{binary, tanh}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_factory, + element::Type input_type, + size_t binary_transpose_input_idx) { + const Shape input_shape{1, 96, 55, 55}; + const Shape const_shape{1, 55, 55, 96}; + + auto X = std::make_shared(input_type, input_shape); + + auto tanh = std::make_shared(X); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + NodePtr binary; + auto in_constant = std::make_shared(input_type, const_shape, Shape{1}); + + auto transpose_reversed_const = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose_reversed = std::make_shared(in_constant, transpose_reversed_const); + + if (!binary_transpose_input_idx) + binary = binary_factory->create(X, transpose_reversed); + else + binary = binary_factory->create(transpose_reversed, X); + + auto ng_order1 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose1 = std::make_shared(binary, ng_order1); + + return std::make_shared(ov::OutputVector{transpose1, tanh}, ov::ParameterVector{X}); +} + +} // namespace input_node_consumers + +} // namespace forward + +namespace backward { + +namespace output_consumers { + +namespace one_binary { + +std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, + element::Type input_type, + size_t binary_transpose_input_idx) { + const Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto tanh0 = std::make_shared(X); + + NodePtr binary; + auto in_constant = std::make_shared(input_type, input_shape, Shape{1}); + if (!binary_transpose_input_idx) + binary = binary_factory->create(tanh0, in_constant); + else + binary = binary_factory->create(in_constant, tanh0); + + auto tanh = std::make_shared(binary); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(binary, ng_order0); + + return std::make_shared(ov::OutputVector{transpose0, tanh}, ov::ParameterVector{X}); +} + +} // namespace one_binary + +namespace multiple_binaries { + +std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, + element::Type input_type, + size_t binary_transpose_input_idx) { + const Shape input_shape{1, 96, 55, 55}; + const size_t n_binaries = 10; + + auto X = std::make_shared(input_type, input_shape); + + auto tanh0 = std::make_shared(X); + + NodePtr in_op = tanh0; + for (size_t i = 0; i < n_binaries; ++i) { + auto in_constant = std::make_shared(input_type, input_shape, Shape{1}); + if (!binary_transpose_input_idx) + in_op = binary_factory->create(in_op, in_constant); + else + in_op = binary_factory->create(in_constant, in_op); + } + + auto tanh = std::make_shared(in_op); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); + + return std::make_shared(ov::OutputVector{transpose0, tanh}, ov::ParameterVector{X}); +} + +} // namespace multiple_binaries + +} // namespace output_consumers + +namespace input_node_consumers { + +std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, + element::Type input_type, + size_t binary_transpose_input_idx) { + const Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto tanh0 = std::make_shared(X); + + NodePtr binary; + auto in_constant = std::make_shared(input_type, input_shape, Shape{1}); + if (!binary_transpose_input_idx) + binary = binary_factory->create(tanh0, in_constant); + else + binary = binary_factory->create(in_constant, tanh0); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(binary, ng_order0); + + auto tanh1 = std::make_shared(tanh0); + + return std::make_shared(ov::OutputVector{transpose0, tanh1}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_factory, + element::Type input_type, + size_t binary_transpose_input_idx) { + const Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto tanh0 = std::make_shared(X); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(tanh0, ng_order0); + + auto in_constant = std::make_shared(input_type, input_shape, Shape{1}); + + auto ng_order = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose = std::make_shared(in_constant, ng_order); + + NodePtr binary; + if (!binary_transpose_input_idx) + binary = binary_factory->create(transpose0, transpose); + else + binary = binary_factory->create(transpose, transpose0); + + auto tanh1 = std::make_shared(tanh0); + + return std::make_shared(ov::OutputVector{binary, tanh1}, ov::ParameterVector{X}); +} + +} // namespace input_node_consumers + +namespace output_transpose_mult_consumers { + +std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, + element::Type input_type, + size_t binary_transpose_input_idx) { + const Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + NodePtr binary; + auto in_constant = std::make_shared(input_type, input_shape, Shape{1}); + if (!binary_transpose_input_idx) + binary = binary_factory->create(X, in_constant); + else + binary = binary_factory->create(in_constant, X); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(binary, ng_order0); + + auto tanh0 = std::make_shared(transpose0); + auto tanh1 = std::make_shared(transpose0); + + return std::make_shared(ov::OutputVector{tanh0, tanh1}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_factory, + element::Type input_type, + size_t binary_transpose_input_idx) { + const Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + auto in_constant = std::make_shared(input_type, input_shape, Shape{1}); + + auto ng_order = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose = std::make_shared(in_constant, ng_order); + + NodePtr binary; + if (!binary_transpose_input_idx) + binary = binary_factory->create(transpose0, transpose); + else + binary = binary_factory->create(transpose, transpose0); + + auto tanh0 = std::make_shared(binary); + auto tanh1 = std::make_shared(binary); + + return std::make_shared(ov::OutputVector{tanh0, tanh1}, ov::ParameterVector{X}); +} + +} // namespace output_transpose_mult_consumers + +namespace output_transpose_mult_transposes { + +std::shared_ptr CreateFunction(BinaryFactoryPtr binary_factory, + element::Type input_type, + size_t binary_transpose_input_idx) { + const Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + NodePtr binary; + auto in_constant = std::make_shared(input_type, input_shape, Shape{1}); + if (!binary_transpose_input_idx) + binary = binary_factory->create(X, in_constant); + else + binary = binary_factory->create(in_constant, X); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(binary, ng_order0); + + auto tanh0 = std::make_shared(transpose0); + + auto ng_order1 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose1 = std::make_shared(binary, ng_order1); + + auto tanh1 = std::make_shared(transpose1); + + return std::make_shared(ov::OutputVector{tanh0, tanh1}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(BinaryFactoryPtr binary_factory, + element::Type input_type, + size_t binary_transpose_input_idx) { + const Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + auto in_constant = std::make_shared(input_type, input_shape, Shape{1}); + + auto ng_order = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose = std::make_shared(in_constant, ng_order); + + NodePtr binary; + if (!binary_transpose_input_idx) + binary = binary_factory->create(transpose0, transpose); + else + binary = binary_factory->create(transpose, transpose0); + + auto tanh0 = std::make_shared(binary); + auto tanh1 = std::make_shared(binary); + + return std::make_shared(ov::OutputVector{tanh0, tanh1}, ov::ParameterVector{X}); +} + +} // namespace output_transpose_mult_transposes + +} // namespace backward + +using CreateGraphF = std::function(BinaryFactoryPtr binary_factory, + element::Type input_type, + size_t binary_transpose_input_idx)>; + +struct CreateGraphFunctionDesc { + CreateGraphFunctionDesc() = default; + CreateGraphFunctionDesc(CreateGraphF a_model_factory, + CreateGraphF a_reference_model_factory, + std::string a_subtest_name) + : model_factory(a_model_factory), + reference_model_factory(a_reference_model_factory), + subtest_name(a_subtest_name) {} + CreateGraphF model_factory; + CreateGraphF reference_model_factory; + std::string subtest_name; +}; + +using TestBinaryParams = std::tuple; /*binary_transpose_input_idx*/ + +class TransposeBinaryMultiSinkingFixture : public ::testing::WithParamInterface, + public TransformationTestsF { +public: + static std::string get_test_name(const testing::TestParamInfo& obj) { + BinaryFactoryPtr binary_factory; + PassFactoryPtr pass_factory; + CreateGraphFunctionDesc function_desc; + element::Type input_type; + size_t binary_transpose_input_idx; + + std::tie(binary_factory, pass_factory, function_desc, input_type, binary_transpose_input_idx) = obj.param; + + std::ostringstream test_name; + test_name << "binaryFactory=" << binary_factory->getTypeName() << "/"; + test_name << "passFactory=" << pass_factory->getTypeName() << "/"; + test_name << function_desc.subtest_name << "/"; + test_name << "inputType=" << input_type << "/"; + test_name << "binaryTransposeInputIdx=" << binary_transpose_input_idx; + + return test_name.str(); + } +}; + +TEST_P(TransposeBinaryMultiSinkingFixture, CompareFunctions) { + BinaryFactoryPtr binary_factory; + PassFactoryPtr pass_factory; + CreateGraphFunctionDesc function_desc; + element::Type input_type; + size_t binary_transpose_input_idx; + + std::tie(binary_factory, pass_factory, function_desc, input_type, binary_transpose_input_idx) = this->GetParam(); + + model = function_desc.model_factory(binary_factory, input_type, binary_transpose_input_idx); + model_ref = function_desc.reference_model_factory(binary_factory, input_type, binary_transpose_input_idx); + pass_factory->registerPass(manager); +} + +#define SUBTEST(nmspace, subtest_name) \ + CreateGraphFunctionDesc(nmspace::CreateFunction, nmspace::CreateReferenceFunction, subtest_name) + +std::vector forward_subtests = { + SUBTEST(forward::input_transpose_consumers, "forwardInputTransposeConsumers"), + SUBTEST(forward::output_consumers::one_binary, "forwardOutputConsumers"), + SUBTEST(forward::input_node_consumers, "forwardInputNodeConsumers")}; + +std::vector backward_subtests = { + SUBTEST(backward::input_node_consumers, "backwardInputNodeConsumers"), + SUBTEST(backward::output_transpose_mult_consumers, "backwardOutputTransposeMultConsumers"), + SUBTEST(backward::output_transpose_mult_transposes, "outputTransposeMultTransposes")}; + +#undef SUBTEST + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryForwardMultiConsumersTestSuite, + TransposeBinaryMultiSinkingFixture, + ::testing::Combine(::testing::ValuesIn(binary_factories), + ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryForward)), + ::testing::ValuesIn(forward_subtests), + ::testing::Values(element::f32), + ::testing::ValuesIn(binary_transpose_input_indexes)), + TransposeBinaryMultiSinkingFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryBackwardMultiConsumersTestSuite, + TransposeBinaryMultiSinkingFixture, + ::testing::Combine(::testing::ValuesIn(binary_factories), + ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)), + ::testing::ValuesIn(backward_subtests), + ::testing::Values(element::f32), + ::testing::ValuesIn(binary_transpose_input_indexes)), + TransposeBinaryMultiSinkingFixture::get_test_name); + +namespace no_sinking { + +struct CreateGraphFunctionDesc { + CreateGraphFunctionDesc() = default; + CreateGraphFunctionDesc(CreateGraphF a_model_factory, std::string a_subtest_name) + : model_factory(a_model_factory), + subtest_name(a_subtest_name) {} + CreateGraphF model_factory; + std::string subtest_name; +}; + +using TestBinaryParams = std::tuple; /*binary_transpose_input_idx*/ + +class TransposeBinaryMultiSinkingBinaryMultiConsumersFixture : public ::testing::WithParamInterface, + public TransformationTestsF { +public: + static std::string get_test_name(const testing::TestParamInfo& obj) { + BinaryFactoryPtr binary_factory; + PassFactoryPtr pass_factory; + CreateGraphFunctionDesc function_desc; + element::Type input_type; + size_t binary_transpose_input_idx; + + std::tie(binary_factory, pass_factory, function_desc, input_type, binary_transpose_input_idx) = obj.param; + + std::ostringstream test_name; + test_name << "binaryFactory=" << binary_factory->getTypeName() << "/"; + test_name << "passFactory=" << pass_factory->getTypeName() << "/"; + test_name << function_desc.subtest_name << "/"; + test_name << "inputType=" << input_type << "/"; + test_name << "binaryTransposeInputIdx=" << binary_transpose_input_idx; + + return test_name.str(); + } +}; + +TEST_P(TransposeBinaryMultiSinkingBinaryMultiConsumersFixture, CompareFunctions) { + BinaryFactoryPtr binary_factory; + PassFactoryPtr pass_factory; + CreateGraphFunctionDesc function_desc; + element::Type input_type; + size_t binary_transpose_input_idx; + + std::tie(binary_factory, pass_factory, function_desc, input_type, binary_transpose_input_idx) = this->GetParam(); + + model = function_desc.model_factory(binary_factory, input_type, binary_transpose_input_idx); + model_ref = model->clone(); + pass_factory->registerPass(manager); +} + +#define SUBTEST(nmspace, subtest_name) CreateGraphFunctionDesc(nmspace::CreateFunction, subtest_name) + +std::vector backward_subtests_binary_consumers = { + SUBTEST(backward::output_consumers::one_binary, "backwardOutputConsumersOneBinary"), + SUBTEST(backward::output_consumers::multiple_binaries, "backwardOutputConsumersMultipleBinaries"), +}; +#undef SUBTEST + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryBackwardBinaryMultiConsumersTestSuite, + TransposeBinaryMultiSinkingBinaryMultiConsumersFixture, + ::testing::Combine(::testing::ValuesIn(binary_factories), + ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward)), + ::testing::ValuesIn(backward_subtests_binary_consumers), + ::testing::Values(element::f32), + ::testing::ValuesIn(binary_transpose_input_indexes)), + TransposeBinaryMultiSinkingBinaryMultiConsumersFixture::get_test_name); + +} // namespace no_sinking + +} // namespace mult_consumers + +} // namespace transpose_sinking_binary_eltwise diff --git a/src/common/transformations/tests/common_optimizations/transpose_sinking_concat_test.cpp b/src/common/transformations/tests/common_optimizations/transpose_sinking_concat_test.cpp index 940215fd62e..9feb14011a9 100644 --- a/src/common/transformations/tests/common_optimizations/transpose_sinking_concat_test.cpp +++ b/src/common/transformations/tests/common_optimizations/transpose_sinking_concat_test.cpp @@ -12,44 +12,25 @@ #include "common_test_utils/ngraph_test_utils.hpp" #include "gtest/gtest.h" +using namespace ov; +using namespace ov::opset9; + namespace { using NodePtr = std::shared_ptr; -using ModelPtr = std::shared_ptr; -using Output = ov::Output; - -// ---------------------------------------------------------------------------- - -class IBinaryFactory { -public: - IBinaryFactory() = default; - virtual ~IBinaryFactory() = default; - virtual NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const = 0; -}; - -using BinaryFactoryPtr = std::shared_ptr; - -template -class BinaryFactory : public IBinaryFactory { -public: - BinaryFactory() = default; - NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const override { - return std::make_shared(parent_left_node, parent_right_node); - } -}; - -template -BinaryFactoryPtr CreateBinaryFactory() { - return std::make_shared>(); -} - -// ---------------------------------------------------------------------------- +using ModelPtr = std::shared_ptr; class IPassFactory { public: - IPassFactory() = default; + IPassFactory(const std::string& type_name) : type_name_(type_name) {} virtual ~IPassFactory() = default; virtual void registerPass(ov::pass::Manager& pass_manager) const = 0; + const std::string& getTypeName() const { + return type_name_; + } + +private: + const std::string type_name_; }; using PassFactoryPtr = std::shared_ptr; @@ -57,186 +38,207 @@ using PassFactoryPtr = std::shared_ptr; template class PassFactory : public IPassFactory { public: + PassFactory(const std::string& type_name) : IPassFactory(type_name) {} void registerPass(ov::pass::Manager& pass_manager) const override { pass_manager.register_pass(); } }; -template -PassFactoryPtr CreatePassFactory() { - return std::make_shared>(); -} - -std::vector binary_factories = {CreateBinaryFactory(), - CreateBinaryFactory(), - CreateBinaryFactory(), - CreateBinaryFactory(), - CreateBinaryFactory(), - CreateBinaryFactory(), - CreateBinaryFactory(), - CreateBinaryFactory(), - CreateBinaryFactory()}; - -std::vector binary_operations_numbers = {1, 10}; - -std::vector binary_transpose_input_indexes = {0, 1}; +#define CREATE_PASS_FACTORY(pass_name) std::make_shared>(#pass_name) } // namespace -using CreateGraphConcatF = std::function(size_t num_concat_ops, - ov::element::Type input_type, - size_t concat_transpose_input_idx, - size_t num_concat_inputs)>; - -using TestConcatParams = std::tuple; /* num_concat_inputs */ - -class TransposeSinkingConcatTestFixture : public ::testing::WithParamInterface, - public TransformationTestsF {}; - namespace { std::vector concat_operations_numbers = {1, 10}; std::vector concat_transpose_input_indexes = {0, 2}; +NodePtr CreateConcatChain(NodePtr input_node, + size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs, + const Shape& const_shape, + int64_t axis) { + NodePtr in_op = input_node; + for (size_t i = 0; i < num_concat_ops; ++i) { + OutputVector concat_inputs; + for (size_t j = 0; j < num_concat_inputs; ++j) { + if (j == concat_transpose_input_idx) + concat_inputs.push_back(in_op); + else + concat_inputs.push_back(std::make_shared(input_type, const_shape, Shape{1})); + } + in_op = std::make_shared(concat_inputs, axis); + } + + return in_op; +} + +NodePtr CreateConcatTransposedChain(NodePtr input_node, + size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs, + const Shape& const_shape, + int64_t axis, + const Shape& transpose_order) { + NodePtr in_op = input_node; + for (size_t i = 0; i < num_concat_ops; ++i) { + OutputVector concat_inputs; + for (size_t j = 0; j < num_concat_inputs; ++j) { + if (j == concat_transpose_input_idx) { + concat_inputs.push_back(in_op); + } else { + auto in_constant = std::make_shared(input_type, const_shape, Shape{1}); + + auto transpose_const = + std::make_shared(element::u64, Shape{transpose_order.size()}, transpose_order); + auto transpose = std::make_shared(in_constant, transpose_const); + + concat_inputs.push_back(transpose); + } + } + in_op = std::make_shared(concat_inputs, axis); + } + + return in_op; +} + +NodePtr CreateConcatDoubleTransposedChain(NodePtr input_node, + size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs, + const Shape& const_shape, + int64_t axis, + const Shape& transpose1_order, + const Shape& transpose2_order) { + NodePtr in_op = input_node; + for (size_t i = 0; i < num_concat_ops; ++i) { + OutputVector concat_inputs; + for (size_t j = 0; j < num_concat_inputs; ++j) { + if (j == concat_transpose_input_idx) { + concat_inputs.push_back(in_op); + } else { + auto in_constant = std::make_shared(input_type, const_shape, Shape{1}); + + auto ng_order1 = + std::make_shared(element::u64, Shape{transpose1_order.size()}, transpose1_order); + auto transpose1 = std::make_shared(in_constant, ng_order1); + + auto ng_order2 = + std::make_shared(element::u64, Shape{transpose2_order.size()}, transpose2_order); + auto transpose2 = std::make_shared(transpose1, ng_order2); + + concat_inputs.push_back(transpose2); + } + } + in_op = std::make_shared(concat_inputs, axis); + } + + return in_op; +} + } // namespace namespace single_consumer { namespace forward { namespace one_input_transpose { -std::shared_ptr CreateFunction(size_t num_concat_ops, - ov::element::Type input_type, - size_t concat_transpose_input_idx, - size_t num_concat_inputs) { - const ov::Shape input_shape{1, 96, 55, 55}; - const ov::Shape const_shape{1, 55, 55, 96}; +std::shared_ptr CreateFunction(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; + const Shape const_shape{1, 55, 55, 96}; - auto X = std::make_shared(input_type, input_shape); + auto X = std::make_shared(input_type, input_shape); - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(X, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); - NodePtr in_op = transpose0; - for (size_t i = 0; i < num_concat_ops; ++i) { - ov::OutputVector concat_inputs; - for (size_t j = 0; j < num_concat_inputs; ++j) { - if (j == concat_transpose_input_idx) - concat_inputs.push_back(in_op); - else - concat_inputs.push_back(std::make_shared(input_type, const_shape, ov::Shape{1})); - } - in_op = std::make_shared(concat_inputs, 1); - } + auto concat = CreateConcatChain(transpose0, + num_concat_ops, + input_type, + concat_transpose_input_idx, + num_concat_inputs, + const_shape, + /* axis */ 1); - return std::make_shared(ov::OutputVector{in_op}, ov::ParameterVector{X}); + return std::make_shared(OutputVector{concat}, ParameterVector{X}); } -std::shared_ptr CreateReferenceFunction(size_t num_concat_ops, - ov::element::Type input_type, - size_t concat_transpose_input_idx, - size_t num_concat_inputs) { - const ov::Shape input_shape{1, 96, 55, 55}; - const ov::Shape const_shape{1, 55, 55, 96}; +std::shared_ptr CreateReferenceFunction(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; + const Shape const_shape{1, 55, 55, 96}; - auto X = std::make_shared(input_type, input_shape); + auto X = std::make_shared(input_type, input_shape); - NodePtr in_op = X; - for (size_t i = 0; i < num_concat_ops; ++i) { - ov::OutputVector concat_inputs; - for (size_t j = 0; j < num_concat_inputs; ++j) { - if (j == concat_transpose_input_idx) { - concat_inputs.push_back(in_op); - } else { - auto in_constant = std::make_shared(input_type, const_shape, ov::Shape{1}); + auto concat = CreateConcatTransposedChain(X, + num_concat_ops, + input_type, + concat_transpose_input_idx, + num_concat_inputs, + const_shape, + /* axis */ 2, + /* transpose order */ Shape{0, 3, 1, 2}); - auto transpose_reversed_const = - std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); - auto transpose_reversed = - std::make_shared(in_constant, transpose_reversed_const); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(concat, ng_order0); - concat_inputs.push_back(transpose_reversed); - } - } - in_op = std::make_shared(concat_inputs, 2); - } - - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(in_op, ng_order0); - - return std::make_shared(ov::OutputVector{transpose0}, ov::ParameterVector{X}); + return std::make_shared(OutputVector{transpose0}, ParameterVector{X}); } } // namespace one_input_transpose namespace double_transpose { -std::shared_ptr CreateFunction(size_t num_concat_ops, - ov::element::Type input_type, - size_t num_concat_inputs) { - const ov::Shape input_shape{1, 96, 55, 55}; +std::shared_ptr CreateFunction(size_t num_concat_ops, element::Type input_type, size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; - auto X = std::make_shared(input_type, input_shape); + auto X = std::make_shared(input_type, input_shape); - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(X, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); - NodePtr in_op = transpose0; - for (size_t i = 0; i < num_concat_ops; ++i) { - ov::OutputVector concat_inputs; - concat_inputs.push_back(in_op); - for (size_t j = 1; j < num_concat_inputs; ++j) { - auto in_constant = std::make_shared(input_type, input_shape, ov::Shape{1}); - auto ng_order1 = - std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose1 = std::make_shared(in_constant, ng_order1); - concat_inputs.push_back(transpose1); - } - in_op = std::make_shared(concat_inputs, 1); - } + auto concat = CreateConcatTransposedChain(transpose0, + num_concat_ops, + input_type, + /* concat_transpose_input_idx */ 0, + num_concat_inputs, + input_shape, + /* axis */ 1, + /* transpose order */ Shape{0, 2, 3, 1}); - return std::make_shared(ov::OutputVector{in_op}, ov::ParameterVector{X}); + return std::make_shared(OutputVector{concat}, ParameterVector{X}); } -std::shared_ptr CreateReferenceFunction(size_t num_concat_ops, - ov::element::Type input_type, - size_t num_concat_inputs) { - const ov::Shape input_shape{1, 96, 55, 55}; +std::shared_ptr CreateReferenceFunction(size_t num_concat_ops, + element::Type input_type, + size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; - auto X = std::make_shared(input_type, input_shape); + auto X = std::make_shared(input_type, input_shape); - NodePtr in_op = X; - for (size_t i = 0; i < num_concat_ops; ++i) { - ov::OutputVector concat_inputs; + auto concat = CreateConcatDoubleTransposedChain(X, + num_concat_ops, + input_type, + /* concat_transpose_input_idx */ 0, + num_concat_inputs, + input_shape, + /* axis */ 2, + /* transpose1 order */ Shape{0, 2, 3, 1}, + /* transpose2 order */ Shape{0, 3, 1, 2}); - concat_inputs.push_back(in_op); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(concat, ng_order0); - for (size_t j = 1; j < num_concat_inputs; ++j) { - auto in_constant = std::make_shared(input_type, input_shape, ov::Shape{1}); - - auto ng_order1 = - std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose1 = std::make_shared(in_constant, ng_order1); - - auto transpose_reversed_const = - std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); - auto transpose_reversed = std::make_shared(transpose1, transpose_reversed_const); - - concat_inputs.push_back(transpose_reversed); - } - in_op = std::make_shared(concat_inputs, 2); - } - - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(in_op, ng_order0); - - return std::make_shared(ov::OutputVector{transpose0}, ov::ParameterVector{X}); + return std::make_shared(OutputVector{transpose0}, ParameterVector{X}); } } // namespace double_transpose @@ -245,75 +247,104 @@ std::shared_ptr CreateReferenceFunction(size_t num_concat_ops, namespace backward { -std::shared_ptr CreateFunction(size_t num_concat_ops, - ov::element::Type input_type, - size_t concat_transpose_input_idx, - size_t num_concat_inputs) { - const ov::Shape input_shape{1, 96, 55, 55}; +std::shared_ptr CreateFunction(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; - auto X = std::make_shared(input_type, input_shape); + auto X = std::make_shared(input_type, input_shape); - NodePtr in_op = X; - for (size_t i = 0; i < num_concat_ops; ++i) { - ov::OutputVector concat_inputs; - for (size_t j = 0; j < num_concat_inputs; ++j) { - if (j == concat_transpose_input_idx) - concat_inputs.push_back(in_op); - else - concat_inputs.push_back(std::make_shared(input_type, input_shape, ov::Shape{1})); - } - in_op = std::make_shared(concat_inputs, 1); - } + auto concat = CreateConcatChain(X, + num_concat_ops, + input_type, + concat_transpose_input_idx, + num_concat_inputs, + input_shape, + /* axis */ 1); - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(in_op, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(concat, ng_order0); - return std::make_shared(ov::OutputVector{transpose0}, ov::ParameterVector{X}); + return std::make_shared(OutputVector{transpose0}, ParameterVector{X}); } -std::shared_ptr CreateReferenceFunction(size_t num_concat_ops, - ov::element::Type input_type, - size_t concat_transpose_input_idx, - size_t num_concat_inputs) { - const ov::Shape input_shape{1, 96, 55, 55}; +std::shared_ptr CreateReferenceFunction(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; - auto X = std::make_shared(input_type, input_shape); + auto X = std::make_shared(input_type, input_shape); - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(X, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); - NodePtr in_op = transpose0; - for (size_t i = 0; i < num_concat_ops; ++i) { - ov::OutputVector concat_inputs; - for (size_t j = 0; j < num_concat_inputs; ++j) { - if (j == concat_transpose_input_idx) { - concat_inputs.push_back(in_op); - } else { - auto in_constant = std::make_shared(input_type, input_shape, ov::Shape{1}); + auto concat = CreateConcatTransposedChain(transpose0, + num_concat_ops, + input_type, + concat_transpose_input_idx, + num_concat_inputs, + input_shape, + /* axis */ 3, + /* transpose order */ Shape{0, 2, 3, 1}); - auto transpose_reversed_const = - std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose_reversed = - std::make_shared(in_constant, transpose_reversed_const); - - concat_inputs.push_back(transpose_reversed); - } - } - in_op = std::make_shared(concat_inputs, 3); - } - - return std::make_shared(ov::OutputVector{in_op}, ov::ParameterVector{X}); + return std::make_shared(OutputVector{concat}, ParameterVector{X}); } } // namespace backward } // namespace single_consumer +using CreateGraphConcatF = std::function(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs)>; + +using TestConcatParams = std::tuple; /* num_concat_inputs */ + +class TransposeSinkingConcatTestFixture : public ::testing::WithParamInterface, + public TransformationTestsF { +public: + static std::string get_test_name(const testing::TestParamInfo& obj) { + PassFactoryPtr pass_factory; + size_t num_concat_ops; + CreateGraphConcatF model_factory; + CreateGraphConcatF reference_model_factory; + element::Type input_type; + size_t concat_transpose_input_idx; + size_t num_concat_inputs; + + std::tie(pass_factory, + num_concat_ops, + model_factory, + reference_model_factory, + input_type, + concat_transpose_input_idx, + num_concat_inputs) = obj.param; + + std::ostringstream test_name; + test_name << "passFactory=" << pass_factory->getTypeName() << "/"; + test_name << "numConcatOps=" << num_concat_ops << "/"; + test_name << "concatTransposeInputIdx=" << concat_transpose_input_idx << "/"; + test_name << "numConcatInputs=" << num_concat_inputs << "/"; + test_name << "inputType=" << input_type; + + return test_name.str(); + } +}; + TEST_P(TransposeSinkingConcatTestFixture, CompareFunctions) { PassFactoryPtr pass_factory; size_t num_concat_ops; CreateGraphConcatF model_factory; CreateGraphConcatF reference_model_factory; - ov::element::Type input_type; + element::Type input_type; size_t concat_transpose_input_idx; size_t num_concat_inputs; std::tie(pass_factory, @@ -332,49 +363,72 @@ TEST_P(TransposeSinkingConcatTestFixture, CompareFunctions) { INSTANTIATE_TEST_SUITE_P( TransposeSinkingConcatForwardTestSuite, TransposeSinkingConcatTestFixture, - ::testing::Combine(::testing::Values(CreatePassFactory()), + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatForward)), ::testing::ValuesIn(concat_operations_numbers), ::testing::Values(single_consumer::forward::one_input_transpose::CreateFunction), ::testing::Values(single_consumer::forward::one_input_transpose::CreateReferenceFunction), - ::testing::Values(ov::element::f32), + ::testing::Values(element::f32), ::testing::ValuesIn(concat_transpose_input_indexes), - ::testing::Values(5))); + ::testing::Values(5)), + TransposeSinkingConcatTestFixture::get_test_name); -INSTANTIATE_TEST_SUITE_P( - TransposeSinkingConcatBackwardTestSuite, - TransposeSinkingConcatTestFixture, - ::testing::Combine(::testing::Values(CreatePassFactory()), - ::testing::ValuesIn(concat_operations_numbers), - ::testing::Values(single_consumer::backward::CreateFunction), - ::testing::Values(single_consumer::backward::CreateReferenceFunction), - ::testing::Values(ov::element::f32), - ::testing::ValuesIn(concat_transpose_input_indexes), - ::testing::Values(5))); +INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatBackwardTestSuite, + TransposeSinkingConcatTestFixture, + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatBackward)), + ::testing::ValuesIn(concat_operations_numbers), + ::testing::Values(single_consumer::backward::CreateFunction), + ::testing::Values(single_consumer::backward::CreateReferenceFunction), + ::testing::Values(element::f32), + ::testing::ValuesIn(concat_transpose_input_indexes), + ::testing::Values(5)), + TransposeSinkingConcatTestFixture::get_test_name); // -------------------------------------------------------------------------------------- -using CreateGraphConcatAllTransposesInputF = std::function< - std::shared_ptr(size_t num_concat_ops, ov::element::Type input_type, size_t num_concat_inputs)>; +using CreateGraphConcatAllTransposesInputF = + std::function(size_t num_concat_ops, element::Type input_type, size_t num_concat_inputs)>; using TestConcatAllTransposesInputParams = std::tuple; /* num_concat_inputs */ class TransposeSinkingConcatAllTransposesInputTestFixture : public ::testing::WithParamInterface, - public TransformationTestsF {}; + public TransformationTestsF { +public: + static std::string get_test_name(const testing::TestParamInfo& obj) { + PassFactoryPtr pass_factory; + size_t num_concat_ops; + CreateGraphConcatAllTransposesInputF model_factory; + CreateGraphConcatAllTransposesInputF reference_model_factory; + element::Type input_type; + size_t num_concat_inputs; + + std::tie(pass_factory, num_concat_ops, model_factory, reference_model_factory, input_type, num_concat_inputs) = + obj.param; + + std::ostringstream test_name; + test_name << "passFactory=" << pass_factory->getTypeName() << "/"; + test_name << "numConcatOps=" << num_concat_ops << "/"; + test_name << "numConcatInputs=" << num_concat_inputs << "/"; + test_name << "inputType=" << input_type; + + return test_name.str(); + } +}; TEST_P(TransposeSinkingConcatAllTransposesInputTestFixture, CompareFunctions) { PassFactoryPtr pass_factory; size_t num_concat_ops; CreateGraphConcatAllTransposesInputF model_factory; CreateGraphConcatAllTransposesInputF reference_model_factory; - ov::element::Type input_type; + element::Type input_type; size_t num_concat_inputs; + std::tie(pass_factory, num_concat_ops, model_factory, reference_model_factory, input_type, num_concat_inputs) = this->GetParam(); @@ -386,9 +440,635 @@ TEST_P(TransposeSinkingConcatAllTransposesInputTestFixture, CompareFunctions) { INSTANTIATE_TEST_SUITE_P( TransposeSinkingConcatForwardAllTransposesTestSuite, TransposeSinkingConcatAllTransposesInputTestFixture, - ::testing::Combine(::testing::Values(CreatePassFactory()), + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatForward)), ::testing::ValuesIn(concat_operations_numbers), ::testing::Values(single_consumer::forward::double_transpose::CreateFunction), ::testing::Values(single_consumer::forward::double_transpose::CreateReferenceFunction), - ::testing::Values(ov::element::f32), - ::testing::Values(5))); + ::testing::Values(element::f32), + ::testing::Values(5)), + TransposeSinkingConcatAllTransposesInputTestFixture::get_test_name); + +// -------------------------------------------------------------------------------------- + +namespace mult_consumers { +namespace forward { +namespace input_transpose_consumers { + +std::shared_ptr CreateFunction(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; + const Shape const_shape{1, 55, 55, 96}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + auto tanh = std::make_shared(transpose0); + + auto concat = CreateConcatChain(transpose0, + num_concat_ops, + input_type, + concat_transpose_input_idx, + num_concat_inputs, + const_shape, + /* axis */ 1); + + return std::make_shared(OutputVector{concat, tanh}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; + const Shape const_shape{1, 55, 55, 96}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + auto tanh = std::make_shared(transpose0); + + NodePtr binary; + auto in_constant = std::make_shared(input_type, const_shape, Shape{1}); + + auto transpose_reversed_const = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose_reversed = std::make_shared(in_constant, transpose_reversed_const); + + auto concat = CreateConcatTransposedChain(X, + num_concat_ops, + input_type, + concat_transpose_input_idx, + num_concat_inputs, + const_shape, + /* axis */ 2, + /* transpose order */ Shape{0, 3, 1, 2}); + + auto ng_order1 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose1 = std::make_shared(concat, ng_order1); + + return std::make_shared(ov::OutputVector{transpose1, tanh}, ov::ParameterVector{X}); +} + +} // namespace input_transpose_consumers + +namespace output_consumers { + +std::shared_ptr CreateFunction(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; + const Shape const_shape{1, 55, 55, 96}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + auto concat = CreateConcatChain(transpose0, + num_concat_ops, + input_type, + concat_transpose_input_idx, + num_concat_inputs, + const_shape, + /* axis */ 1); + + auto tanh1 = std::make_shared(concat); + auto tanh2 = std::make_shared(concat); + + return std::make_shared(ov::OutputVector{tanh1, tanh2}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; + const Shape const_shape{1, 55, 55, 96}; + + auto X = std::make_shared(input_type, input_shape); + + NodePtr binary; + auto in_constant = std::make_shared(input_type, const_shape, Shape{1}); + + auto transpose_reversed_const = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose_reversed = std::make_shared(in_constant, transpose_reversed_const); + + auto concat = CreateConcatTransposedChain(X, + num_concat_ops, + input_type, + concat_transpose_input_idx, + num_concat_inputs, + const_shape, + /* axis */ 2, + /* transpose order */ Shape{0, 3, 1, 2}); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(concat, ng_order0); + + auto tanh1 = std::make_shared(transpose0); + auto tanh2 = std::make_shared(transpose0); + + return std::make_shared(ov::OutputVector{tanh1, tanh2}, ov::ParameterVector{X}); +} + +} // namespace output_consumers + +namespace input_node_consumers { + +std::shared_ptr CreateFunction(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; + const Shape const_shape{1, 55, 55, 96}; + + auto X = std::make_shared(input_type, input_shape); + + auto tanh = std::make_shared(X); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + auto concat = CreateConcatChain(transpose0, + num_concat_ops, + input_type, + concat_transpose_input_idx, + num_concat_inputs, + const_shape, + /* axis */ 1); + + return std::make_shared(ov::OutputVector{concat, tanh}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; + const Shape const_shape{1, 55, 55, 96}; + + auto X = std::make_shared(input_type, input_shape); + + auto tanh = std::make_shared(X); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + NodePtr binary; + auto in_constant = std::make_shared(input_type, const_shape, Shape{1}); + + auto concat = CreateConcatTransposedChain(X, + num_concat_ops, + input_type, + concat_transpose_input_idx, + num_concat_inputs, + const_shape, + /* axis */ 2, + /* transpose order */ Shape{0, 3, 1, 2}); + + auto ng_order1 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose1 = std::make_shared(concat, ng_order1); + + return std::make_shared(ov::OutputVector{transpose1, tanh}, ov::ParameterVector{X}); +} + +} // namespace input_node_consumers + +} // namespace forward + +namespace backward { + +namespace output_consumers { + +namespace one_binary { +std::shared_ptr CreateFunction(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto tanh0 = std::make_shared(X); + + auto concat = CreateConcatChain(tanh0, + num_concat_ops, + input_type, + concat_transpose_input_idx, + num_concat_inputs, + input_shape, + /* axis */ 1); + + auto tanh = std::make_shared(concat); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(concat, ng_order0); + + return std::make_shared(ov::OutputVector{transpose0, tanh}, ov::ParameterVector{X}); +} + +} // namespace one_binary + +namespace multiple_binaries { + +std::shared_ptr CreateFunction(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto tanh0 = std::make_shared(X); + + auto concat = CreateConcatChain(tanh0, + num_concat_ops, + input_type, + concat_transpose_input_idx, + num_concat_inputs, + input_shape, + /* axis */ 1); + + auto tanh = std::make_shared(concat); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(concat, ng_order0); + + return std::make_shared(ov::OutputVector{transpose0, tanh}, ov::ParameterVector{X}); +} + +} // namespace multiple_binaries + +} // namespace output_consumers + +namespace input_node_consumers { + +std::shared_ptr CreateFunction(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto tanh0 = std::make_shared(X); + + auto concat = CreateConcatChain(tanh0, + num_concat_ops, + input_type, + concat_transpose_input_idx, + num_concat_inputs, + input_shape, + /* axis */ 1); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(concat, ng_order0); + + auto tanh1 = std::make_shared(tanh0); + + return std::make_shared(ov::OutputVector{transpose0, tanh1}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto tanh0 = std::make_shared(X); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(tanh0, ng_order0); + + auto concat = CreateConcatTransposedChain(transpose0, + num_concat_ops, + input_type, + concat_transpose_input_idx, + num_concat_inputs, + input_shape, + /* axis */ 3, + /* transpose order */ Shape{0, 2, 3, 1}); + + auto tanh1 = std::make_shared(tanh0); + + return std::make_shared(ov::OutputVector{concat, tanh1}, ov::ParameterVector{X}); +} + +} // namespace input_node_consumers + +namespace output_transpose_mult_consumers { + +std::shared_ptr CreateFunction(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto concat = CreateConcatChain(X, + num_concat_ops, + input_type, + concat_transpose_input_idx, + num_concat_inputs, + input_shape, + /* axis */ 1); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(concat, ng_order0); + + auto tanh0 = std::make_shared(transpose0); + auto tanh1 = std::make_shared(transpose0); + + return std::make_shared(ov::OutputVector{tanh0, tanh1}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + auto concat = CreateConcatTransposedChain(transpose0, + num_concat_ops, + input_type, + concat_transpose_input_idx, + num_concat_inputs, + input_shape, + /* axis */ 3, + /* transpose order */ Shape{0, 2, 3, 1}); + + auto tanh0 = std::make_shared(concat); + auto tanh1 = std::make_shared(concat); + + return std::make_shared(ov::OutputVector{tanh0, tanh1}, ov::ParameterVector{X}); +} + +} // namespace output_transpose_mult_consumers + +namespace output_transpose_mult_transposes { + +std::shared_ptr CreateFunction(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto concat = CreateConcatChain(X, + num_concat_ops, + input_type, + concat_transpose_input_idx, + num_concat_inputs, + input_shape, + /* axis */ 1); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(concat, ng_order0); + + auto tanh0 = std::make_shared(transpose0); + + auto ng_order1 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose1 = std::make_shared(concat, ng_order1); + + auto tanh1 = std::make_shared(transpose1); + + return std::make_shared(ov::OutputVector{tanh0, tanh1}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs) { + const Shape input_shape{1, 96, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + auto concat = CreateConcatTransposedChain(transpose0, + num_concat_ops, + input_type, + concat_transpose_input_idx, + num_concat_inputs, + input_shape, + /* axis */ 3, + /* transpose order */ Shape{0, 2, 3, 1}); + + auto tanh0 = std::make_shared(concat); + auto tanh1 = std::make_shared(concat); + + return std::make_shared(ov::OutputVector{tanh0, tanh1}, ov::ParameterVector{X}); +} + +} // namespace output_transpose_mult_transposes + +} // namespace backward + +using CreateGraphF = std::function(size_t num_concat_ops, + element::Type input_type, + size_t concat_transpose_input_idx, + size_t num_concat_inputs)>; + +struct CreateGraphFunctionDesc { + CreateGraphFunctionDesc() = default; + CreateGraphFunctionDesc(CreateGraphF a_model_factory, CreateGraphF a_ref_model_factory, std::string a_subtest_name) + : model_factory(a_model_factory), + reference_model_factory(a_ref_model_factory), + subtest_name(a_subtest_name) {} + CreateGraphF model_factory; + CreateGraphF reference_model_factory; + std::string subtest_name; +}; + +using TestConcatParams = std::tuple; /* num_concat_inputs */ + +class TransposeConcatMultiSinkingFixture : public ::testing::WithParamInterface, + public TransformationTestsF { +public: + static std::string get_test_name(const testing::TestParamInfo& obj) { + PassFactoryPtr pass_factory; + size_t num_concat_ops; + CreateGraphFunctionDesc function_desc; + element::Type input_type; + size_t concat_transpose_input_idx; + size_t num_concat_inputs; + + std::tie(pass_factory, + num_concat_ops, + function_desc, + input_type, + concat_transpose_input_idx, + num_concat_inputs) = obj.param; + + std::ostringstream test_name; + test_name << "passFactory=" << pass_factory->getTypeName() << "/"; + test_name << "functionDesc=" << function_desc.subtest_name << "/"; + test_name << "numConcatOps=" << num_concat_ops << "/"; + test_name << "concatTransposeInputIdx=" << concat_transpose_input_idx << "/"; + test_name << "numConcatInputs=" << num_concat_inputs << "/"; + test_name << "inputType=" << input_type; + + return test_name.str(); + } +}; + +TEST_P(TransposeConcatMultiSinkingFixture, CompareFunctions) { + PassFactoryPtr pass_factory; + size_t num_concat_ops; + CreateGraphFunctionDesc function_desc; + element::Type input_type; + size_t concat_transpose_input_idx; + size_t num_concat_inputs; + + std::tie(pass_factory, num_concat_ops, function_desc, input_type, concat_transpose_input_idx, num_concat_inputs) = + this->GetParam(); + + model = function_desc.model_factory(num_concat_ops, input_type, concat_transpose_input_idx, num_concat_inputs); + model_ref = function_desc.reference_model_factory(num_concat_ops, + input_type, + concat_transpose_input_idx, + num_concat_inputs); + pass_factory->registerPass(manager); +} + +#define SUBTEST(nmspace, subtest_name) \ + CreateGraphFunctionDesc(nmspace::CreateFunction, nmspace::CreateReferenceFunction, subtest_name) + +std::vector forward_subtests = { + SUBTEST(forward::input_transpose_consumers, "forwardInputTransposeConsumers"), + SUBTEST(forward::output_consumers, "forwardOutputConsumers"), + SUBTEST(forward::input_node_consumers, "forwardInputNodeConsumers")}; + +std::vector backward_subtests = { + SUBTEST(backward::input_node_consumers, "backwardInputNodeConsumers"), + SUBTEST(backward::output_transpose_mult_consumers, "backwardOutputTransposeMultConsumers"), + SUBTEST(backward::output_transpose_mult_transposes, "outputTransposeMultTransposes")}; + +#undef SUBTEST + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatForwardMultiConsumersTestSuite, + TransposeConcatMultiSinkingFixture, + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatForward)), + ::testing::ValuesIn(concat_operations_numbers), + ::testing::ValuesIn(forward_subtests), + ::testing::Values(element::f32), + ::testing::ValuesIn(concat_transpose_input_indexes), + ::testing::Values(5)), + TransposeConcatMultiSinkingFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatBackwardMultiConsumersTestSuite, + TransposeConcatMultiSinkingFixture, + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatBackward)), + ::testing::ValuesIn(concat_operations_numbers), + ::testing::ValuesIn(backward_subtests), + ::testing::Values(element::f32), + ::testing::ValuesIn(concat_transpose_input_indexes), + ::testing::Values(5)), + TransposeConcatMultiSinkingFixture::get_test_name); + +namespace no_sinking { + +struct CreateGraphFunctionNoSinkingDesc { + CreateGraphFunctionNoSinkingDesc() = default; + CreateGraphFunctionNoSinkingDesc(CreateGraphF a_model_factory, std::string a_subtest_name) + : model_factory(a_model_factory), + subtest_name(a_subtest_name) {} + CreateGraphF model_factory; + std::string subtest_name; +}; + +using TestConcatParams = std::tuple; /* num_concat_inputs */ + +class TransposeConcatMultiSinkingConcatConsumersFixture : public ::testing::WithParamInterface, + public TransformationTestsF { +public: + static std::string get_test_name(const testing::TestParamInfo& obj) { + PassFactoryPtr pass_factory; + size_t num_concat_ops; + CreateGraphFunctionNoSinkingDesc function_desc; + element::Type input_type; + size_t concat_transpose_input_idx; + size_t num_concat_inputs; + + std::tie(pass_factory, + num_concat_ops, + function_desc, + input_type, + concat_transpose_input_idx, + num_concat_inputs) = obj.param; + + std::ostringstream test_name; + test_name << "passFactory=" << pass_factory->getTypeName() << "/"; + test_name << "functionDesc=" << function_desc.subtest_name << "/"; + test_name << "numConcatOps=" << num_concat_ops << "/"; + test_name << "concatTransposeInputIdx=" << concat_transpose_input_idx << "/"; + test_name << "numConcatInputs=" << num_concat_inputs << "/"; + test_name << "inputType=" << input_type; + + return test_name.str(); + } +}; + +TEST_P(TransposeConcatMultiSinkingConcatConsumersFixture, CompareFunctions) { + PassFactoryPtr pass_factory; + size_t num_concat_ops; + CreateGraphFunctionNoSinkingDesc function_desc; + element::Type input_type; + size_t concat_transpose_input_idx; + size_t num_concat_inputs; + + std::tie(pass_factory, num_concat_ops, function_desc, input_type, concat_transpose_input_idx, num_concat_inputs) = + this->GetParam(); + + model = function_desc.model_factory(num_concat_ops, input_type, concat_transpose_input_idx, num_concat_inputs); + model_ref = model->clone(); + pass_factory->registerPass(manager); +} + +#define SUBTEST(nmspace, subtest_name) CreateGraphFunctionNoSinkingDesc(nmspace::CreateFunction, subtest_name) + +std::vector backward_subtests_no_sinking = { + SUBTEST(backward::output_consumers::one_binary, "backwardOutputConsumersOneBinary"), + SUBTEST(backward::output_consumers::multiple_binaries, "backwardOutputConsumersMultipleBinaries")}; + +#undef SUBTEST + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatBackwardMultiConsumersTestSuite, + TransposeConcatMultiSinkingConcatConsumersFixture, + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingConcatBackward)), + ::testing::ValuesIn(concat_operations_numbers), + ::testing::ValuesIn(backward_subtests_no_sinking), + ::testing::Values(element::f32), + ::testing::ValuesIn(concat_transpose_input_indexes), + ::testing::Values(5)), + TransposeConcatMultiSinkingConcatConsumersFixture::get_test_name); + +} // namespace no_sinking + +} // namespace mult_consumers diff --git a/src/common/transformations/tests/common_optimizations/transpose_sinking_split_test.cpp b/src/common/transformations/tests/common_optimizations/transpose_sinking_split_test.cpp index 2ba1af4724e..27fa8d0a6fa 100644 --- a/src/common/transformations/tests/common_optimizations/transpose_sinking_split_test.cpp +++ b/src/common/transformations/tests/common_optimizations/transpose_sinking_split_test.cpp @@ -12,44 +12,27 @@ #include "common_test_utils/ngraph_test_utils.hpp" #include "gtest/gtest.h" +using namespace ov; +using namespace ov::opset9; + +namespace transpose_sinking_split { + namespace { using NodePtr = std::shared_ptr; -using ModelPtr = std::shared_ptr; -using Output = ov::Output; - -// ---------------------------------------------------------------------------- - -class IBinaryFactory { -public: - IBinaryFactory() = default; - virtual ~IBinaryFactory() = default; - virtual NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const = 0; -}; - -using BinaryFactoryPtr = std::shared_ptr; - -template -class BinaryFactory : public IBinaryFactory { -public: - BinaryFactory() = default; - NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const override { - return std::make_shared(parent_left_node, parent_right_node); - } -}; - -template -BinaryFactoryPtr CreateBinaryFactory() { - return std::make_shared>(); -} - -// ---------------------------------------------------------------------------- +using ModelPtr = std::shared_ptr; class IPassFactory { public: - IPassFactory() = default; + IPassFactory(const std::string& type_name) : type_name_(type_name) {} virtual ~IPassFactory() = default; virtual void registerPass(ov::pass::Manager& pass_manager) const = 0; + const std::string& getTypeName() const { + return type_name_; + } + +private: + const std::string type_name_; }; using PassFactoryPtr = std::shared_ptr; @@ -57,72 +40,37 @@ using PassFactoryPtr = std::shared_ptr; template class PassFactory : public IPassFactory { public: + PassFactory(const std::string& type_name) : IPassFactory(type_name) {} void registerPass(ov::pass::Manager& pass_manager) const override { pass_manager.register_pass(); } }; -template -PassFactoryPtr CreatePassFactory() { - return std::make_shared>(); -} - -std::vector binary_factories = {CreateBinaryFactory(), - CreateBinaryFactory(), - CreateBinaryFactory(), - CreateBinaryFactory(), - CreateBinaryFactory(), - CreateBinaryFactory(), - CreateBinaryFactory(), - CreateBinaryFactory(), - CreateBinaryFactory()}; - -std::vector binary_operations_numbers = {1, 10}; - -std::vector binary_transpose_input_indexes = {0, 1}; +#define CREATE_PASS_FACTORY(pass_name) std::make_shared>(#pass_name) } // namespace -// -------------------------------------------------------------------------------------- - -using CreateGraphSplitForwardF = std::function< - std::shared_ptr(size_t num_split_ops, size_t num_split_outputs, ov::element::Type input_type)>; - -using TestSplitForwardParams = std::tuple /* input type */; - -class TransposeSinkingSplitForwardTestFixture : public ::testing::WithParamInterface, - public TransformationTestsF {}; - -namespace { - +std::vector split_tree_depth_nums = {1, 3}; std::vector split_operations_numbers = {1, 10}; - std::vector split_outputs_numbers = {2, 3}; -} // namespace - -namespace split { namespace forward { -std::shared_ptr CreateFunction(size_t num_split_ops, - size_t num_split_outputs, - ov::element::Type input_type) { - const ov::Shape input_shape{96, static_cast(std::pow(num_split_outputs, num_split_ops + 1)), 55, 55}; - auto X = std::make_shared(input_type, input_shape); +namespace single_consumer { - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); - auto transpose0 = std::make_shared(X, ng_order0); +std::shared_ptr CreateFunction(size_t num_split_ops, size_t num_split_outputs, element::Type input_type) { + const Shape input_shape{96, static_cast(std::pow(num_split_outputs, num_split_ops + 1)), 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose0 = std::make_shared(X, ng_order0); ov::OutputVector outputs; - Output in_op = transpose0->output(0); + ov::Output in_op = transpose0->output(0); for (size_t i = 0; i < num_split_ops; ++i) { - auto split_axis_const = std::make_shared(ov::element::u64, ov::Shape{}, 2); - auto split = std::make_shared(in_op, split_axis_const, num_split_outputs); + auto split_axis_const = std::make_shared(element::u64, Shape{}, 2); + auto split = std::make_shared(in_op, split_axis_const, num_split_outputs); for (size_t num_output = 0; num_output < num_split_outputs - 1; ++num_output) { outputs.push_back(split->output(num_output)); } @@ -133,106 +81,254 @@ std::shared_ptr CreateFunction(size_t num_split_ops, return std::make_shared(outputs, ov::ParameterVector{X}); } -std::shared_ptr CreateReferenceFunction(size_t num_split_ops, - size_t num_split_outputs, - ov::element::Type input_type) { - const ov::Shape input_shape{96, static_cast(std::pow(num_split_outputs, num_split_ops + 1)), 55, 55}; +std::shared_ptr CreateReferenceFunction(size_t num_split_ops, + size_t num_split_outputs, + element::Type input_type) { + const Shape input_shape{96, static_cast(std::pow(num_split_outputs, num_split_ops + 1)), 55, 55}; - auto X = std::make_shared(input_type, input_shape); + auto X = std::make_shared(input_type, input_shape); ov::OutputVector outputs; - Output in_op = X->output(0); + ov::Output in_op = X->output(0); for (size_t i = 0; i < num_split_ops; ++i) { - auto split_axis_const = std::make_shared(ov::element::u64, ov::Shape{}, 1); - auto split = std::make_shared(in_op, split_axis_const, num_split_outputs); + auto split_axis_const = std::make_shared(element::u64, Shape{}, 1); + auto split = std::make_shared(in_op, split_axis_const, num_split_outputs); for (size_t num_output = 0; num_output < num_split_outputs - 1; ++num_output) { - auto ng_order0 = - std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); - auto transpose0 = std::make_shared(split->output(num_output), ng_order0); + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose0 = std::make_shared(split->output(num_output), ng_order0); outputs.push_back(transpose0); } in_op = split->output(num_split_outputs - 1); } - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); - auto transpose0 = std::make_shared(in_op, ng_order0); + auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); + auto transpose0 = std::make_shared(in_op, ng_order0); outputs.push_back(transpose0); return std::make_shared(outputs, ov::ParameterVector{X}); } -} // namespace forward -} // namespace split +} // namespace single_consumer -TEST_P(TransposeSinkingSplitForwardTestFixture, CompareFunctions) { - PassFactoryPtr pass_factory; - size_t num_split_ops; - size_t num_split_outputs; - CreateGraphSplitForwardF model_factory; - CreateGraphSplitForwardF reference_model_factory; - ov::element::Type input_type; - std::tie(pass_factory, num_split_ops, num_split_outputs, model_factory, reference_model_factory, input_type) = - this->GetParam(); +namespace mult_consumers { - model = model_factory(num_split_ops, num_split_outputs, input_type); - model_ref = reference_model_factory(num_split_ops, num_split_outputs, input_type); - pass_factory->registerPass(manager); +namespace input_node_consumers { + +std::shared_ptr CreateFunction(size_t num_split_ops, size_t num_split_outputs, element::Type input_type) { + const Shape input_shape{96, static_cast(std::pow(num_split_outputs, num_split_ops + 1)), 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto tanh = std::make_shared(X); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose0 = std::make_shared(tanh, ng_order0); + + ov::OutputVector outputs; + auto in_op = transpose0->output(0); + for (size_t i = 0; i < num_split_ops; ++i) { + auto split_axis_const = std::make_shared(element::u64, Shape{}, 2); + auto split = std::make_shared(in_op, split_axis_const, num_split_outputs); + for (size_t num_output = 0; num_output < num_split_outputs - 1; ++num_output) { + outputs.push_back(split->output(num_output)); + } + in_op = split->output(num_split_outputs - 1); + } + outputs.push_back(in_op); + + auto tanh1 = std::make_shared(tanh); + outputs.push_back(tanh1); + + return std::make_shared(outputs, ov::ParameterVector{X}); } -INSTANTIATE_TEST_SUITE_P( - TransposeSinkingSplitForwardTestSuite, - TransposeSinkingSplitForwardTestFixture, - ::testing::Combine(::testing::Values(CreatePassFactory()), - ::testing::ValuesIn(split_operations_numbers), - ::testing::ValuesIn(split_outputs_numbers), - ::testing::Values(split::forward::CreateFunction), - ::testing::Values(split::forward::CreateReferenceFunction), - ::testing::Values(ov::element::f32))); +std::shared_ptr CreateReferenceFunction(size_t num_split_ops, + size_t num_split_outputs, + element::Type input_type) { + const Shape input_shape{96, static_cast(std::pow(num_split_outputs, num_split_ops + 1)), 55, 55}; -// -------------------------------------------------------------------------------------- + auto X = std::make_shared(input_type, input_shape); -using CreateGraphSplitBackwardF = std::function< - std::shared_ptr(size_t split_tree_depth, size_t num_split_outputs, ov::element::Type input_type)>; + auto tanh = std::make_shared(X); -using TestSplitBackwardParams = std::tuple /* input type */; + ov::OutputVector outputs; + auto in_op = tanh->output(0); + for (size_t i = 0; i < num_split_ops; ++i) { + auto split_axis_const = std::make_shared(element::u64, Shape{}, 1); + auto split = std::make_shared(in_op, split_axis_const, num_split_outputs); + for (size_t num_output = 0; num_output < num_split_outputs - 1; ++num_output) { + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose0 = std::make_shared(split->output(num_output), ng_order0); + outputs.push_back(transpose0); + } + in_op = split->output(num_split_outputs - 1); + } -class TransposeSinkingSplitBackwardTestFixture : public ::testing::WithParamInterface, - public TransformationTestsF {}; + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose0 = std::make_shared(in_op, ng_order0); + outputs.push_back(transpose0); -namespace { -std::vector split_tree_depth_nums = {1, 3}; -} // namespace + auto tanh1 = std::make_shared(tanh); + outputs.push_back(tanh1); -// -------------------------------------------------------------------------------------- + return std::make_shared(outputs, ov::ParameterVector{X}); +} + +} // namespace input_node_consumers + +namespace input_transpose_consumers { + +std::shared_ptr CreateFunction(size_t num_split_ops, size_t num_split_outputs, element::Type input_type) { + const Shape input_shape{96, static_cast(std::pow(num_split_outputs, num_split_ops + 1)), 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose0 = std::make_shared(X, ng_order0); + + ov::OutputVector outputs; + auto in_op = transpose0->output(0); + for (size_t i = 0; i < num_split_ops; ++i) { + auto split_axis_const = std::make_shared(element::u64, Shape{}, 2); + auto split = std::make_shared(in_op, split_axis_const, num_split_outputs); + for (size_t num_output = 0; num_output < num_split_outputs - 1; ++num_output) { + outputs.push_back(split->output(num_output)); + } + in_op = split->output(num_split_outputs - 1); + } + outputs.push_back(in_op); + + auto tanh = std::make_shared(transpose0); + outputs.push_back(tanh); + + return std::make_shared(outputs, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(size_t num_split_ops, + size_t num_split_outputs, + element::Type input_type) { + const Shape input_shape{96, static_cast(std::pow(num_split_outputs, num_split_ops + 1)), 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + ov::OutputVector outputs; + auto in_op = X->output(0); + for (size_t i = 0; i < num_split_ops; ++i) { + auto split_axis_const = std::make_shared(element::u64, Shape{}, 1); + auto split = std::make_shared(in_op, split_axis_const, num_split_outputs); + for (size_t num_output = 0; num_output < num_split_outputs - 1; ++num_output) { + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose0 = std::make_shared(split->output(num_output), ng_order0); + outputs.push_back(transpose0); + } + in_op = split->output(num_split_outputs - 1); + } + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose0 = std::make_shared(in_op, ng_order0); + outputs.push_back(transpose0); + + auto ng_order1 = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose1 = std::make_shared(X, ng_order1); + + auto tanh = std::make_shared(transpose1); + outputs.push_back(tanh); + + return std::make_shared(outputs, ov::ParameterVector{X}); +} + +} // namespace input_transpose_consumers + +namespace output_consumers { + +std::shared_ptr CreateFunction(size_t num_split_ops, size_t num_split_outputs, element::Type input_type) { + const Shape input_shape{96, static_cast(std::pow(num_split_outputs, num_split_ops + 1)), 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose0 = std::make_shared(X, ng_order0); + + ov::OutputVector outputs; + auto in_op = transpose0->output(0); + for (size_t i = 0; i < num_split_ops; ++i) { + auto split_axis_const = std::make_shared(element::u64, Shape{}, 2); + auto split = std::make_shared(in_op, split_axis_const, num_split_outputs); + for (size_t num_output = 0; num_output < num_split_outputs - 1; ++num_output) { + outputs.push_back(split->output(num_output)); + } + in_op = split->output(num_split_outputs - 1); + } + outputs.push_back(in_op); + + auto tanh = std::make_shared(in_op); + auto tanh1 = std::make_shared(in_op); + outputs.push_back(tanh); + outputs.push_back(tanh1); + + return std::make_shared(outputs, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(size_t num_split_ops, + size_t num_split_outputs, + element::Type input_type) { + const Shape input_shape{96, static_cast(std::pow(num_split_outputs, num_split_ops + 1)), 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + ov::OutputVector outputs; + auto in_op = X->output(0); + for (size_t i = 0; i < num_split_ops; ++i) { + auto split_axis_const = std::make_shared(element::u64, Shape{}, 1); + auto split = std::make_shared(in_op, split_axis_const, num_split_outputs); + for (size_t num_output = 0; num_output < num_split_outputs - 1; ++num_output) { + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose0 = std::make_shared(split->output(num_output), ng_order0); + outputs.push_back(transpose0); + } + in_op = split->output(num_split_outputs - 1); + } + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose0 = std::make_shared(in_op, ng_order0); + outputs.push_back(transpose0); + + auto tanh = std::make_shared(transpose0); + auto tanh1 = std::make_shared(transpose0); + outputs.push_back(tanh); + outputs.push_back(tanh1); + + return std::make_shared(outputs, ov::ParameterVector{X}); +} + +} // namespace output_consumers + +} // namespace mult_consumers + +} // namespace forward -namespace split { namespace backward { class SplitFactory { public: - SplitFactory(size_t axis, size_t n_outputs, ov::element::Type elem_type) + SplitFactory(size_t axis, size_t n_outputs, element::Type elem_type) : _axis(axis), _n_outputs(n_outputs), _elem_type(elem_type) {} - NodePtr create(Output parent) const { - auto split_axis_const = std::make_shared(_elem_type, ov::Shape{}, _axis); - return std::make_shared(parent, split_axis_const, _n_outputs); + NodePtr create(ov::Output parent) const { + auto split_axis_const = std::make_shared(_elem_type, Shape{}, _axis); + return std::make_shared(parent, split_axis_const, _n_outputs); } private: const size_t _axis; const size_t _n_outputs; - const ov::element::Type _elem_type; + const element::Type _elem_type; }; void CreateSplitTree(size_t max_depth, size_t depth, - Output parent, + ov::Output parent, const SplitFactory& split_factory, ov::OutputVector& leaves) { if (depth == max_depth) { @@ -247,51 +343,57 @@ void CreateSplitTree(size_t max_depth, } } -std::shared_ptr CreateFunction(size_t split_tree_depth, - size_t num_split_outputs, - ov::element::Type input_type) { - const size_t split_input_dim_value = static_cast(std::pow(num_split_outputs, split_tree_depth + 1)); - const ov::Shape input_shape{96, split_input_dim_value, 55, 55}; +namespace single_consumer { - auto X = std::make_shared(input_type, input_shape); +std::shared_ptr CreateFunction(size_t split_tree_depth, size_t num_split_outputs, element::Type input_type) { + const size_t split_input_dim_value = static_cast(std::pow(num_split_outputs, split_tree_depth + 1)); + const Shape input_shape{96, split_input_dim_value, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto tanh = std::make_shared(X); ov::OutputVector split_tree_leaves; { - SplitFactory split_factory(/* axis */ 1, num_split_outputs, /* elem_type */ ov::element::u64); - CreateSplitTree(split_tree_depth, /* depth */ 0, X->output(0), split_factory, split_tree_leaves); + SplitFactory split_factory(/* axis */ 1, num_split_outputs, /* elem_type */ element::u64); + CreateSplitTree(split_tree_depth, /* depth */ 0, tanh->output(0), split_factory, split_tree_leaves); } ov::OutputVector outputs; for (auto& split_tree_leaf : split_tree_leaves) { - auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); - auto transpose = std::make_shared(split_tree_leaf, ng_order); + auto ng_order = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); const size_t split_dim_current_value = static_cast(split_input_dim_value / std::pow(num_split_outputs, split_tree_depth)); - auto reshape_const = std::make_shared(ov::element::u64, - ov::Shape{3}, - ov::Shape{96, 55, split_dim_current_value * 55}); - auto reshape = std::make_shared(transpose, reshape_const, false); + auto reshape_const = + std::make_shared(element::u64, Shape{3}, Shape{96, 55, split_dim_current_value * 55}); + auto reshape = std::make_shared(transpose, reshape_const, false); outputs.push_back(reshape); } - return std::make_shared(outputs, ov::ParameterVector{X}); + auto tanh1 = std::make_shared(tanh); + outputs.push_back(tanh1); + + return std::make_shared(outputs, ov::ParameterVector{X}); } -std::shared_ptr CreateReferenceFunction(size_t split_tree_depth, - size_t num_split_outputs, - ov::element::Type input_type) { +std::shared_ptr CreateReferenceFunction(size_t split_tree_depth, + size_t num_split_outputs, + element::Type input_type) { const size_t split_input_dim_value = static_cast(std::pow(num_split_outputs, split_tree_depth + 1)); - const ov::Shape input_shape{96, split_input_dim_value, 55, 55}; + const Shape input_shape{96, split_input_dim_value, 55, 55}; - auto X = std::make_shared(input_type, input_shape); + auto X = std::make_shared(input_type, input_shape); - auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); - auto transpose = std::make_shared(X, ng_order); + auto tanh = std::make_shared(X); + + auto ng_order = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose = std::make_shared(tanh, ng_order); ov::OutputVector split_tree_leaves; { - SplitFactory split_factory(/* axis */ 2, num_split_outputs, /* elem_type */ ov::element::u64); + SplitFactory split_factory(/* axis */ 2, num_split_outputs, /* elem_type */ element::u64); CreateSplitTree(split_tree_depth, /* depth */ 0, transpose->output(0), split_factory, split_tree_leaves); } @@ -299,93 +401,252 @@ std::shared_ptr CreateReferenceFunction(size_t split_tree_depth, for (auto& split_tree_leaf : split_tree_leaves) { const size_t split_dim_current_value = static_cast(split_input_dim_value / std::pow(num_split_outputs, split_tree_depth)); - auto reshape_const = std::make_shared(ov::element::u64, - ov::Shape{3}, - ov::Shape{96, 55, split_dim_current_value * 55}); - auto reshape = std::make_shared(split_tree_leaf, reshape_const, false); + auto reshape_const = + std::make_shared(element::u64, Shape{3}, Shape{96, 55, split_dim_current_value * 55}); + auto reshape = std::make_shared(split_tree_leaf, reshape_const, false); outputs.push_back(reshape); } - return std::make_shared(outputs, ov::ParameterVector{X}); + auto tanh1 = std::make_shared(tanh); + outputs.push_back(tanh1); + + return std::make_shared(outputs, ov::ParameterVector{X}); } -} // namespace backward -} // namespace split +} // namespace single_consumer -TEST_P(TransposeSinkingSplitBackwardTestFixture, CompareFunctions) { - PassFactoryPtr pass_factory; - size_t split_tree_depth; - size_t num_split_outputs; - CreateGraphSplitBackwardF model_factory; - CreateGraphSplitBackwardF reference_model_factory; - ov::element::Type input_type; - std::tie(pass_factory, split_tree_depth, num_split_outputs, model_factory, reference_model_factory, input_type) = - this->GetParam(); +namespace mult_output_consumers { - model = model_factory(split_tree_depth, num_split_outputs, input_type); - model_ref = reference_model_factory(split_tree_depth, num_split_outputs, input_type); - pass_factory->registerPass(manager); -} - -INSTANTIATE_TEST_SUITE_P( - TransposeSinkingSplitBackwardTestSuite, - TransposeSinkingSplitBackwardTestFixture, - ::testing::Combine(::testing::Values(CreatePassFactory()), - ::testing::ValuesIn(split_tree_depth_nums), - ::testing::ValuesIn(split_outputs_numbers), - ::testing::Values(split::backward::CreateFunction), - ::testing::Values(split::backward::CreateReferenceFunction), - ::testing::Values(ov::element::f32))); - -using TransposeInsertF = std::function; - -using CreateGraphSplitBackwardRestrictF = - std::function(size_t split_tree_depth, - size_t num_split_outputs, - ov::element::Type input_type, - TransposeInsertF tranpose_insert_function)>; - -using TestSplitBackwardRestrictParams = std::tuple; /* insert transpose function */ - -class TransposeSinkingSplitBackwardRestrictTestFixture - : public ::testing::WithParamInterface, - public TransformationTestsF {}; - -TEST_P(TransposeSinkingSplitBackwardRestrictTestFixture, CompareFunctions) { - PassFactoryPtr pass_factory; - size_t split_tree_depth; - size_t num_split_outputs; - CreateGraphSplitBackwardRestrictF model_factory; - ov::element::Type input_type; - TransposeInsertF tranpose_insert_function; - std::tie(pass_factory, split_tree_depth, num_split_outputs, model_factory, input_type, tranpose_insert_function) = - this->GetParam(); - - model = model_factory(split_tree_depth, num_split_outputs, input_type, tranpose_insert_function); - model_ref = model->clone(); - pass_factory->registerPass(manager); -} -namespace split { -namespace backward { -namespace restrictions { - -std::shared_ptr CreateFunction(size_t split_tree_depth, - size_t num_split_outputs, - ov::element::Type input_type, - TransposeInsertF transpose_insert_func) { +std::shared_ptr CreateFunction(size_t split_tree_depth, size_t num_split_outputs, element::Type input_type) { const size_t split_input_dim_value = static_cast(std::pow(num_split_outputs, split_tree_depth + 1)); - const ov::Shape input_shape{96, split_input_dim_value, 55, 55}; + const Shape input_shape{96, split_input_dim_value, 55, 55}; - auto X = std::make_shared(input_type, input_shape); + auto X = std::make_shared(input_type, input_shape); ov::OutputVector split_tree_leaves; { - SplitFactory split_factory(/* axis */ 1, num_split_outputs, /* elem_type */ ov::element::u64); + SplitFactory split_factory(/* axis */ 1, num_split_outputs, /* elem_type */ element::u64); + CreateSplitTree(split_tree_depth, /* depth */ 0, X->output(0), split_factory, split_tree_leaves); + } + + ov::OutputVector outputs; + for (auto& split_tree_leaf : split_tree_leaves) { + auto ng_order = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); + + auto tanh0 = std::make_shared(transpose); + auto tanh1 = std::make_shared(transpose); + + outputs.push_back(tanh0); + outputs.push_back(tanh1); + } + + return std::make_shared(outputs, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(size_t split_tree_depth, + size_t num_split_outputs, + element::Type input_type) { + const size_t split_input_dim_value = static_cast(std::pow(num_split_outputs, split_tree_depth + 1)); + const Shape input_shape{96, split_input_dim_value, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + auto ng_order = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose = std::make_shared(X, ng_order); + + ov::OutputVector split_tree_leaves; + { + SplitFactory split_factory(/* axis */ 2, num_split_outputs, /* elem_type */ element::u64); + CreateSplitTree(split_tree_depth, /* depth */ 0, transpose->output(0), split_factory, split_tree_leaves); + } + + ov::OutputVector outputs; + for (auto& split_tree_leaf : split_tree_leaves) { + auto tanh0 = std::make_shared(split_tree_leaf); + auto tanh1 = std::make_shared(split_tree_leaf); + + outputs.push_back(tanh0); + outputs.push_back(tanh1); + } + + return std::make_shared(outputs, ov::ParameterVector{X}); +} + +} // namespace mult_output_consumers + +namespace mult_split_consumers { + +std::shared_ptr CreateFunction(size_t split_tree_depth, size_t num_split_outputs, element::Type input_type) { + const size_t split_input_dim_value = static_cast(std::pow(num_split_outputs, split_tree_depth + 1)); + const Shape input_shape{96, split_input_dim_value, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + ov::OutputVector split_tree_leaves; + { + SplitFactory split_factory(/* axis */ 1, num_split_outputs, /* elem_type */ element::u64); + CreateSplitTree(split_tree_depth, /* depth */ 0, X->output(0), split_factory, split_tree_leaves); + } + + ov::OutputVector outputs; + for (auto& split_tree_leaf : split_tree_leaves) { + auto ng_order = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); + + auto tanh0 = std::make_shared(split_tree_leaf); + auto tanh1 = std::make_shared(split_tree_leaf); + + outputs.push_back(tanh0); + outputs.push_back(tanh1); + } + + return std::make_shared(outputs, ov::ParameterVector{X}); +} + +} // namespace mult_split_consumers + +} // namespace backward + +using CreateGraphSplitF = + std::function(size_t num_split_ops, size_t num_split_outputs, element::Type input_type)>; + +using TestSplitParams = std::tuple /* input type */; + +class TransposeSinkingSplitTestFixture : public ::testing::WithParamInterface, + public TransformationTestsF { +public: + static std::string get_test_name(const testing::TestParamInfo& obj) { + PassFactoryPtr pass_factory; + size_t num_split_ops; + size_t num_split_outputs; + CreateGraphSplitF model_factory; + CreateGraphSplitF reference_model_factory; + element::Type input_type; + + std::tie(pass_factory, num_split_ops, num_split_outputs, model_factory, reference_model_factory, input_type) = + obj.param; + + std::ostringstream test_name; + test_name << "pass_factory=" << pass_factory->getTypeName() << "_"; + test_name << "num_split_ops=" << num_split_ops << "_"; + test_name << "num_split_outputs=" << num_split_outputs << "_"; + test_name << "input_type=" << input_type; + + return test_name.str(); + } +}; + +TEST_P(TransposeSinkingSplitTestFixture, CompareFunctions) { + PassFactoryPtr pass_factory; + size_t num_split_ops; + size_t num_split_outputs; + CreateGraphSplitF model_factory; + CreateGraphSplitF reference_model_factory; + element::Type input_type; + std::tie(pass_factory, num_split_ops, num_split_outputs, model_factory, reference_model_factory, input_type) = + this->GetParam(); + + model = model_factory(num_split_ops, num_split_outputs, input_type); + model_ref = reference_model_factory(num_split_ops, num_split_outputs, input_type); + pass_factory->registerPass(manager); +} + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitForwardSingleConsumerTestSuite, + TransposeSinkingSplitTestFixture, + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitForward)), + ::testing::ValuesIn(split_operations_numbers), + ::testing::ValuesIn(split_outputs_numbers), + ::testing::Values(forward::single_consumer::CreateFunction), + ::testing::Values(forward::single_consumer::CreateReferenceFunction), + ::testing::Values(element::f32)), + TransposeSinkingSplitTestFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P( + TransposeSinkingSplitForwardMultInputNodeConsumersTestSuite, + TransposeSinkingSplitTestFixture, + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitForward)), + ::testing::ValuesIn(split_operations_numbers), + ::testing::ValuesIn(split_outputs_numbers), + ::testing::Values(forward::mult_consumers::input_node_consumers::CreateFunction), + ::testing::Values(forward::mult_consumers::input_node_consumers::CreateReferenceFunction), + ::testing::Values(element::f32)), + TransposeSinkingSplitTestFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P( + TransposeSinkingSplitForwardMultInputTransposeConsumersTestSuite, + TransposeSinkingSplitTestFixture, + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitForward)), + ::testing::ValuesIn(split_operations_numbers), + ::testing::ValuesIn(split_outputs_numbers), + ::testing::Values(forward::mult_consumers::input_transpose_consumers::CreateFunction), + ::testing::Values(forward::mult_consumers::input_transpose_consumers::CreateReferenceFunction), + ::testing::Values(element::f32)), + TransposeSinkingSplitTestFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P( + TransposeSinkingSplitForwardMultOutputConsumersTestSuite, + TransposeSinkingSplitTestFixture, + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitForward)), + ::testing::ValuesIn(split_operations_numbers), + ::testing::ValuesIn(split_outputs_numbers), + ::testing::Values(forward::mult_consumers::output_consumers::CreateFunction), + ::testing::Values(forward::mult_consumers::output_consumers::CreateReferenceFunction), + ::testing::Values(element::f32)), + TransposeSinkingSplitTestFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardTestSuite, + TransposeSinkingSplitTestFixture, + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitBackward)), + ::testing::ValuesIn(split_tree_depth_nums), + ::testing::ValuesIn(split_outputs_numbers), + ::testing::Values(backward::single_consumer::CreateFunction), + ::testing::Values(backward::single_consumer::CreateReferenceFunction), + ::testing::Values(element::f32)), + TransposeSinkingSplitTestFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardMultOutputConsumersTestSuite, + TransposeSinkingSplitTestFixture, + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitBackward)), + ::testing::ValuesIn(split_tree_depth_nums), + ::testing::ValuesIn(split_outputs_numbers), + ::testing::Values(backward::mult_output_consumers::CreateFunction), + ::testing::Values(backward::mult_output_consumers::CreateReferenceFunction), + ::testing::Values(element::f32)), + TransposeSinkingSplitTestFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardMultSplitConsumersTestSuite, + TransposeSinkingSplitTestFixture, + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitBackward)), + ::testing::ValuesIn(split_tree_depth_nums), + ::testing::ValuesIn(split_outputs_numbers), + ::testing::Values(backward::mult_split_consumers::CreateFunction), + ::testing::Values(backward::mult_split_consumers::CreateFunction), + ::testing::Values(element::f32)), + TransposeSinkingSplitTestFixture::get_test_name); + +namespace backward { +namespace restrictions { + +using TransposeInsertF = std::function; + +std::shared_ptr CreateFunction(size_t split_tree_depth, + size_t num_split_outputs, + element::Type input_type, + TransposeInsertF transpose_insert_func) { + const size_t split_input_dim_value = static_cast(std::pow(num_split_outputs, split_tree_depth + 1)); + const Shape input_shape{96, split_input_dim_value, 55, 55}; + + auto X = std::make_shared(input_type, input_shape); + + ov::OutputVector split_tree_leaves; + { + SplitFactory split_factory(/* axis */ 1, num_split_outputs, /* elem_type */ element::u64); CreateSplitTree(split_tree_depth, /* depth */ 0, X->output(0), split_factory, split_tree_leaves); } @@ -393,22 +654,21 @@ std::shared_ptr CreateFunction(size_t split_tree_depth, for (auto& split_tree_leaf : transpose_insert_func(split_tree_leaves)) { const size_t split_dim_current_value = static_cast(split_input_dim_value / std::pow(num_split_outputs, split_tree_depth)); - auto reshape_const = std::make_shared(ov::element::u64, - ov::Shape{3}, - ov::Shape{96, 55, split_dim_current_value * 55}); - auto reshape = std::make_shared(split_tree_leaf, reshape_const, false); + auto reshape_const = + std::make_shared(element::u64, Shape{3}, Shape{96, 55, split_dim_current_value * 55}); + auto reshape = std::make_shared(split_tree_leaf, reshape_const, false); outputs.push_back(reshape); } - return std::make_shared(outputs, ov::ParameterVector{X}); + return std::make_shared(outputs, ov::ParameterVector{X}); } ov::OutputVector OnlyFirstTranspose(const ov::OutputVector& split_tree_leaves) { ov::OutputVector outputs; { auto& split_tree_leaf = split_tree_leaves.front(); - auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); - auto transpose = std::make_shared(split_tree_leaf, ng_order); + auto ng_order = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); outputs.push_back(transpose); } @@ -423,8 +683,8 @@ ov::OutputVector OnlyLastTranspose(const ov::OutputVector& split_tree_leaves) { ov::OutputVector outputs; { auto& split_tree_leaf = split_tree_leaves.back(); - auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); - auto transpose = std::make_shared(split_tree_leaf, ng_order); + auto ng_order = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); outputs.push_back(transpose); } @@ -443,9 +703,8 @@ ov::OutputVector OnlyMiddleTranspose(const ov::OutputVector& split_tree_leaves) for (size_t leaf_idx = 0; leaf_idx < split_tree_leaves.size() - 1; ++leaf_idx) { if (leaf_idx == middle_idx) { auto& split_tree_leaf = split_tree_leaves[leaf_idx]; - auto ng_order = - std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); - auto transpose = std::make_shared(split_tree_leaf, ng_order); + auto ng_order = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); outputs.push_back(transpose); } else { outputs.push_back(split_tree_leaves[leaf_idx]); @@ -459,15 +718,15 @@ ov::OutputVector FirstAnotherTranspose(const ov::OutputVector& split_tree_leaves ov::OutputVector outputs; { auto& split_tree_leaf = split_tree_leaves.front(); - auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose = std::make_shared(split_tree_leaf, ng_order); + auto ng_order = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); outputs.push_back(transpose); } for (size_t leaf_idx = 1; leaf_idx < split_tree_leaves.size(); ++leaf_idx) { auto& split_tree_leaf = split_tree_leaves[leaf_idx]; - auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); - auto transpose = std::make_shared(split_tree_leaf, ng_order); + auto ng_order = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); outputs.push_back(transpose); } @@ -478,15 +737,15 @@ ov::OutputVector LastAnotherTranspose(const ov::OutputVector& split_tree_leaves) ov::OutputVector outputs; { auto& split_tree_leaf = split_tree_leaves.back(); - auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose = std::make_shared(split_tree_leaf, ng_order); + auto ng_order = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); outputs.push_back(transpose); } for (size_t leaf_idx = 0; leaf_idx < split_tree_leaves.size() - 1; ++leaf_idx) { auto& split_tree_leaf = split_tree_leaves[leaf_idx]; - auto ng_order = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); - auto transpose = std::make_shared(split_tree_leaf, ng_order); + auto ng_order = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); outputs.push_back(transpose); } @@ -501,14 +760,12 @@ ov::OutputVector MiddleAnotherTranspose(const ov::OutputVector& split_tree_leave for (size_t leaf_idx = 0; leaf_idx < split_tree_leaves.size(); ++leaf_idx) { auto& split_tree_leaf = split_tree_leaves[leaf_idx]; if (leaf_idx == middle_idx) { - auto ng_order = - std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose = std::make_shared(split_tree_leaf, ng_order); + auto ng_order = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); outputs.push_back(transpose); } else { - auto ng_order = - std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2}); - auto transpose = std::make_shared(split_tree_leaf, ng_order); + auto ng_order = std::make_shared(element::u64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose = std::make_shared(split_tree_leaf, ng_order); outputs.push_back(transpose); } } @@ -516,27 +773,96 @@ ov::OutputVector MiddleAnotherTranspose(const ov::OutputVector& split_tree_leave return outputs; } +struct TransposeInsertFuncDesc { + TransposeInsertFuncDesc() = default; + TransposeInsertFuncDesc(TransposeInsertF a_function, std::string a_name) : function(a_function), name(a_name) {} + + TransposeInsertF function; + std::string name; +}; + +using CreateGraphSplitBackwardRestrictF = + std::function(size_t split_tree_depth, + size_t num_split_outputs, + element::Type input_type, + TransposeInsertF tranpose_insert_function)>; + +using TestSplitBackwardRestrictParams = std::tuple; /* insert transpose function */ + +class TransposeSinkingSplitBackwardRestrictTestFixture + : public ::testing::WithParamInterface, + public TransformationTestsF { +public: + static std::string get_test_name(const testing::TestParamInfo& obj) { + PassFactoryPtr pass_factory; + size_t split_tree_depth; + size_t num_split_outputs; + CreateGraphSplitBackwardRestrictF model_factory; + element::Type input_type; + TransposeInsertFuncDesc tranpose_insert_function; + + std::tie(pass_factory, + split_tree_depth, + num_split_outputs, + model_factory, + input_type, + tranpose_insert_function) = obj.param; + + std::ostringstream test_name; + test_name << "pass_factory=" << pass_factory->getTypeName() << "_"; + test_name << "split_tree_depth=" << split_tree_depth << "_"; + test_name << "num_split_outputs=" << num_split_outputs << "_"; + test_name << "tranpose_insert_function=" << tranpose_insert_function.name << "_"; + test_name << "input_type=" << input_type; + + return test_name.str(); + } +}; + +TEST_P(TransposeSinkingSplitBackwardRestrictTestFixture, CompareFunctions) { + PassFactoryPtr pass_factory; + size_t split_tree_depth; + size_t num_split_outputs; + CreateGraphSplitBackwardRestrictF model_factory; + element::Type input_type; + TransposeInsertFuncDesc tranpose_insert_function; + + std::tie(pass_factory, split_tree_depth, num_split_outputs, model_factory, input_type, tranpose_insert_function) = + this->GetParam(); + + model = model_factory(split_tree_depth, num_split_outputs, input_type, tranpose_insert_function.function); + model_ref = model->clone(); + pass_factory->registerPass(manager); +} + +#define FUNC(name) TransposeInsertFuncDesc(backward::restrictions::name, #name) + +std::vector insertTransposeFactories = {FUNC(OnlyFirstTranspose), + FUNC(OnlyLastTranspose), + FUNC(OnlyMiddleTranspose), + FUNC(FirstAnotherTranspose), + FUNC(LastAnotherTranspose), + FUNC(MiddleAnotherTranspose)}; + +#undef FUNC + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardRestrictTestSuite, + TransposeSinkingSplitBackwardRestrictTestFixture, + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingSplitBackward)), + ::testing::Values(1), + ::testing::Values(5), + ::testing::Values(backward::restrictions::CreateFunction), + ::testing::Values(element::f32), + ::testing::ValuesIn(insertTransposeFactories)), + TransposeSinkingSplitBackwardRestrictTestFixture::get_test_name); + } // namespace restrictions + } // namespace backward -} // namespace split -namespace { - -std::vector insertTransposeFactories = {split::backward::restrictions::OnlyFirstTranspose, - split::backward::restrictions::OnlyLastTranspose, - split::backward::restrictions::OnlyMiddleTranspose, - split::backward::restrictions::FirstAnotherTranspose, - split::backward::restrictions::LastAnotherTranspose, - split::backward::restrictions::MiddleAnotherTranspose}; - -} // namespace - -INSTANTIATE_TEST_SUITE_P( - TransposeSinkingSplitBackwardRestrictTestSuite, - TransposeSinkingSplitBackwardRestrictTestFixture, - ::testing::Combine(::testing::Values(CreatePassFactory()), - ::testing::Values(1), - ::testing::Values(5), - ::testing::Values(split::backward::restrictions::CreateFunction), - ::testing::Values(ov::element::f32), - ::testing::ValuesIn(insertTransposeFactories))); +} // namespace transpose_sinking_split diff --git a/src/common/transformations/tests/common_optimizations/transpose_sinking_unary_test.cpp b/src/common/transformations/tests/common_optimizations/transpose_sinking_unary_test.cpp index 1c88c7b0d5d..fb16d71f436 100644 --- a/src/common/transformations/tests/common_optimizations/transpose_sinking_unary_test.cpp +++ b/src/common/transformations/tests/common_optimizations/transpose_sinking_unary_test.cpp @@ -12,13 +12,37 @@ #include "common_test_utils/ngraph_test_utils.hpp" #include "gtest/gtest.h" +using namespace ov; +using namespace ov::opset9; + +namespace { +std::string to_string(const Shape& shape) { + std::ostringstream result; + result << "{"; + for (size_t idx = 0; idx < shape.size(); ++idx) { + if (idx) + result << ","; + result << shape[idx]; + } + result << "}"; + return result.str(); +} +} // namespace + using NodePtr = std::shared_ptr; class IUnaryFactory { public: - IUnaryFactory() = default; + IUnaryFactory(const std::string& type_name) : type_name_(type_name) {} virtual ~IUnaryFactory() = default; virtual NodePtr create(NodePtr parent_node) const = 0; + + const std::string& getTypeName() const { + return type_name_; + } + +private: + const std::string type_name_; }; using UnaryFactoryPtr = std::shared_ptr; @@ -26,39 +50,45 @@ using UnaryFactoryPtr = std::shared_ptr; template class UnaryFactory : public IUnaryFactory { public: - UnaryFactory() = default; + UnaryFactory(const std::string& type_name) : IUnaryFactory(type_name) {} NodePtr create(NodePtr parent_node) const override { return std::make_shared(parent_node); } }; template <> -NodePtr UnaryFactory::create(NodePtr parent_node) const { - return std::make_shared(parent_node, 0.1); +NodePtr UnaryFactory::create(NodePtr parent_node) const { + return std::make_shared(parent_node, 0.1); } template <> -NodePtr UnaryFactory::create(NodePtr parent_node) const { - return std::make_shared(parent_node, 0.1, 0.2); +NodePtr UnaryFactory::create(NodePtr parent_node) const { + return std::make_shared(parent_node, 0.1, 0.2); } template <> -NodePtr UnaryFactory::create(NodePtr parent_node) const { - return std::make_shared(parent_node, ov::element::f64); +NodePtr UnaryFactory::create(NodePtr parent_node) const { + return std::make_shared(parent_node, element::f64); } template -UnaryFactoryPtr CreateUnaryFactory() { - return std::make_shared>(); +UnaryFactoryPtr CreateUnaryFactory(const std::string& type_name) { + return std::make_shared>(type_name); } // ---------------------------------------------------------------------------- class IPassFactory { public: - IPassFactory() = default; + IPassFactory(const std::string& type_name) : type_name_(type_name) {} virtual ~IPassFactory() = default; virtual void registerPass(ov::pass::Manager& pass_manager) const = 0; + const std::string& getTypeName() const { + return type_name_; + } + +private: + const std::string type_name_; }; using PassFactoryPtr = std::shared_ptr; @@ -66,15 +96,16 @@ using PassFactoryPtr = std::shared_ptr; template class PassFactory : public IPassFactory { public: + PassFactory(const std::string& type_name) : IPassFactory(type_name) {} void registerPass(ov::pass::Manager& pass_manager) const override { pass_manager.register_pass(); } }; -template -PassFactoryPtr CreatePassFactory() { - return std::make_shared>(); -} +#define CREATE_PASS_FACTORY(pass_name) std::make_shared>(#pass_name) + +#undef CREATE_UNARY_FACTORY +#define CREATE_UNARY_FACTORY(type_name) CreateUnaryFactory(#type_name) // ---------------------------------------------------------------------------- @@ -82,19 +113,46 @@ using FloatPtr = std::unique_ptr; using CreateGraphF = std::function(UnaryFactoryPtr unary_factory, size_t num_unary_ops, - const ov::Shape& input_shape, - ov::element::Type input_type)>; + const Shape& input_shape, + element::Type input_type)>; using TestParams = std::tuple; /* input type */ + size_t, /* num_unary_ops */ + CreateGraphF, /* model_factory */ + CreateGraphF, /* reference_model_factory */ + Shape, /* input shape */ + element::Type>; /* input type */ -class TransposeSinkingUnaryTestFixture : public ::testing::WithParamInterface, - public TransformationTestsF {}; +class TransposeSinkingUnaryTestFixture : public ::testing::WithParamInterface, public TransformationTestsF { +public: + static std::string get_test_name(const testing::TestParamInfo& obj) { + UnaryFactoryPtr unary_factory; + PassFactoryPtr pass_factory; + size_t num_unary_ops; + CreateGraphF model_factory; + CreateGraphF reference_model_factory; + Shape input_shape; + element::Type input_type; + std::tie(unary_factory, + pass_factory, + num_unary_ops, + model_factory, + reference_model_factory, + input_shape, + input_type) = obj.param; + + std::ostringstream test_name; + test_name << "unaryFactory=" << unary_factory->getTypeName() << "/"; + test_name << "numUnaryOps=" << num_unary_ops << "/"; + test_name << "inputShape=" << to_string(input_shape) << "/"; + test_name << "unaryFactory=" << unary_factory->getTypeName() << "/"; + test_name << "passFactory=" << pass_factory->getTypeName() << "/"; + test_name << "inputType=" << input_type; + + return test_name.str(); + } +}; namespace { @@ -105,12 +163,12 @@ std::string GetFinalNodeName(std::shared_ptr model, int index = 0) { std::shared_ptr CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory, size_t num_unary_ops, - const ov::Shape& input_shape, - ov::element::Type input_type) { - auto X = std::make_shared(input_type, input_shape); + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(X, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); NodePtr in_op = transpose0; for (size_t i = 0; i < num_unary_ops; ++i) { @@ -122,25 +180,25 @@ std::shared_ptr CreateFunctionTransposeBefore(UnaryFactoryPtr unary_f std::shared_ptr CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory, size_t num_unary_ops, - const ov::Shape& input_shape, - ov::element::Type input_type) { - auto X = std::make_shared(input_type, input_shape); + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); NodePtr in_op = X; for (size_t i = 0; i < num_unary_ops; ++i) { in_op = unary_factory->create(in_op); } - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(in_op, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); return std::make_shared(transpose0, ov::ParameterVector{X}); } -static NodePtr CreateReshape(NodePtr parent_node, const ov::Shape& input_shape) { +static NodePtr CreateReshape(NodePtr parent_node, const Shape& input_shape) { const size_t mul = std::accumulate(input_shape.begin(), input_shape.end(), (size_t)1, std::multiplies()); - auto reshape_const = std::make_shared(ov::element::u64, ov::Shape{1}, ov::Shape{mul}); - return std::make_shared(parent_node, reshape_const, false); + auto reshape_const = std::make_shared(element::u64, Shape{1}, Shape{mul}); + return std::make_shared(parent_node, reshape_const, false); } namespace mult_consumers_last_node { @@ -148,17 +206,17 @@ namespace with_reshape { std::shared_ptr CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory, size_t num_unary_ops, - const ov::Shape& input_shape, - ov::element::Type input_type) { - auto X = std::make_shared(input_type, input_shape); + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); NodePtr in_op = X; for (size_t i = 0; i < num_unary_ops; ++i) { in_op = unary_factory->create(in_op); } - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(in_op, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); auto reshape1 = CreateReshape(transpose0, input_shape); auto reshape2 = CreateReshape(transpose0, input_shape); @@ -168,12 +226,12 @@ std::shared_ptr CreateFunctionTransposeAfter(UnaryFactoryPtr unary_fa std::shared_ptr CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory, size_t num_unary_ops, - const ov::Shape& input_shape, - ov::element::Type input_type) { - auto X = std::make_shared(input_type, input_shape); + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(X, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); NodePtr in_op = transpose0; for (size_t i = 0; i < num_unary_ops; ++i) { @@ -191,44 +249,44 @@ namespace with_eltwise { std::shared_ptr CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory, size_t num_unary_ops, - const ov::Shape& input_shape, - ov::element::Type input_type) { - auto X = std::make_shared(input_type, input_shape); + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); NodePtr in_op = X; for (size_t i = 0; i < num_unary_ops; ++i) { in_op = unary_factory->create(in_op); } - auto sinh = std::make_shared(in_op); + auto sinh = std::make_shared(in_op); - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(sinh, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(sinh, ng_order0); - auto cosh = std::make_shared(in_op); + auto cosh = std::make_shared(in_op); - auto ng_order1 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose1 = std::make_shared(cosh, ng_order1); + auto ng_order1 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose1 = std::make_shared(cosh, ng_order1); return std::make_shared(ov::OutputVector{transpose0, transpose1}, ov::ParameterVector{X}); } std::shared_ptr CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory, size_t num_unary_ops, - const ov::Shape& input_shape, - ov::element::Type input_type) { - auto X = std::make_shared(input_type, input_shape); + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(X, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); NodePtr in_op = transpose0; for (size_t i = 0; i < num_unary_ops; ++i) { in_op = unary_factory->create(in_op); } - auto sinh = std::make_shared(in_op); - auto cosh = std::make_shared(in_op); + auto sinh = std::make_shared(in_op); + auto cosh = std::make_shared(in_op); return std::make_shared(ov::OutputVector{sinh, cosh}, ov::ParameterVector{X}); } @@ -241,66 +299,87 @@ namespace backward { std::shared_ptr CreateFunction(UnaryFactoryPtr unary_factory, size_t num_unary_ops, - const ov::Shape& input_shape, - ov::element::Type input_type) { - auto X = std::make_shared(input_type, input_shape); + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); ov::OutputVector outputs; NodePtr in_op = X; for (size_t i = 0; i < num_unary_ops; ++i) { in_op = unary_factory->create(in_op); - auto cosh = std::make_shared(in_op); + auto cosh = std::make_shared(in_op); outputs.push_back(cosh); } - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(in_op, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); outputs.push_back(transpose0); return std::make_shared(outputs, ov::ParameterVector{X}); } -std::shared_ptr CreateReferenceFunction(UnaryFactoryPtr unary_factory, - size_t num_unary_ops, - const ov::Shape& input_shape, - ov::element::Type input_type) { - auto X = std::make_shared(input_type, input_shape); - ov::OutputVector outputs; - - NodePtr in_op = X; - for (size_t i = 0; i < num_unary_ops; ++i) { - in_op = unary_factory->create(in_op); - auto cosh = std::make_shared(in_op); - outputs.push_back(cosh); - } - - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(X, ng_order0); - - in_op = transpose0; - for (size_t i = 0; i < num_unary_ops; ++i) { - in_op = unary_factory->create(in_op); - } - - outputs.push_back(in_op); - - return std::make_shared(outputs, ov::ParameterVector{X}); -} - } // namespace backward +namespace backward_mult_transposes { + +std::shared_ptr CreateFunction(UnaryFactoryPtr unary_factory, + size_t num_unary_ops, + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); + + NodePtr in_op = X; + for (size_t i = 0; i < num_unary_ops; ++i) { + in_op = unary_factory->create(in_op); + } + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(in_op, ng_order0); + + auto tanh0 = std::make_shared(transpose0); + + auto ng_order1 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose1 = std::make_shared(in_op, ng_order1); + + auto tanh1 = std::make_shared(transpose1); + + return std::make_shared(ov::OutputVector{tanh0, tanh1}, ov::ParameterVector{X}); +} + +std::shared_ptr CreateReferenceFunction(UnaryFactoryPtr unary_factory, + size_t num_unary_ops, + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); + + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(X, ng_order0); + + NodePtr in_op = transpose0; + for (size_t i = 0; i < num_unary_ops; ++i) { + in_op = unary_factory->create(in_op); + } + + auto tanh0 = std::make_shared(in_op); + auto tanh1 = std::make_shared(in_op); + + return std::make_shared(ov::OutputVector{tanh0, tanh1}, ov::ParameterVector{X}); +} + +} // namespace backward_mult_transposes + namespace forward { std::shared_ptr CreateFunction(UnaryFactoryPtr unary_factory, size_t num_unary_ops, - const ov::Shape& input_shape, - ov::element::Type input_type) { - auto X = std::make_shared(input_type, input_shape); + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); - auto sinh = std::make_shared(X); + auto sinh = std::make_shared(X); - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(sinh, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(sinh, ng_order0); auto reshape = CreateReshape(transpose0, input_shape); @@ -314,14 +393,14 @@ std::shared_ptr CreateFunction(UnaryFactoryPtr unary_factory, std::shared_ptr CreateReferenceFunction(UnaryFactoryPtr unary_factory, size_t num_unary_ops, - const ov::Shape& input_shape, - ov::element::Type input_type) { - auto X = std::make_shared(input_type, input_shape); + const Shape& input_shape, + element::Type input_type) { + auto X = std::make_shared(input_type, input_shape); - auto sinh = std::make_shared(X); + auto sinh = std::make_shared(X); - auto ng_order0 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose0 = std::make_shared(sinh, ng_order0); + auto ng_order0 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose0 = std::make_shared(sinh, ng_order0); auto reshape = CreateReshape(transpose0, input_shape); NodePtr in_op = sinh; @@ -329,8 +408,8 @@ std::shared_ptr CreateReferenceFunction(UnaryFactoryPtr unary_factory in_op = unary_factory->create(in_op); } - auto ng_order1 = std::make_shared(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); - auto transpose1 = std::make_shared(in_op, ng_order1); + auto ng_order1 = std::make_shared(element::u64, Shape{4}, Shape{0, 2, 3, 1}); + auto transpose1 = std::make_shared(in_op, ng_order1); return std::make_shared(ov::OutputVector{transpose1, reshape}, ov::ParameterVector{X}); } @@ -339,21 +418,16 @@ std::shared_ptr CreateReferenceFunction(UnaryFactoryPtr unary_factory } // namespace mult_consumers_first_node std::vector unary_factories = { - CreateUnaryFactory(), CreateUnaryFactory(), - CreateUnaryFactory(), CreateUnaryFactory(), - CreateUnaryFactory(), CreateUnaryFactory(), - CreateUnaryFactory(), CreateUnaryFactory(), - CreateUnaryFactory(), CreateUnaryFactory(), - CreateUnaryFactory(), CreateUnaryFactory(), - CreateUnaryFactory(), CreateUnaryFactory(), - CreateUnaryFactory(), CreateUnaryFactory(), - CreateUnaryFactory(), CreateUnaryFactory(), - CreateUnaryFactory(), CreateUnaryFactory(), - CreateUnaryFactory(), CreateUnaryFactory(), - CreateUnaryFactory(), CreateUnaryFactory(), - CreateUnaryFactory(), CreateUnaryFactory(), - CreateUnaryFactory(), CreateUnaryFactory(), - CreateUnaryFactory()}; + CREATE_UNARY_FACTORY(Clamp), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(SoftPlus), + CREATE_UNARY_FACTORY(LogicalNot), CREATE_UNARY_FACTORY(Convert), CREATE_UNARY_FACTORY(Abs), + CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh), + CREATE_UNARY_FACTORY(Atan), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Cos), + CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Exp), + CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid), CREATE_UNARY_FACTORY(HSwish), + CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu), + CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin), + CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftSign), CREATE_UNARY_FACTORY(Sqrt), + CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh)}; std::vector unary_operations_numbers = {1, 10}; @@ -365,8 +439,8 @@ TEST_P(TransposeSinkingUnaryTestFixture, CompareFunctions) { size_t num_unary_ops; CreateGraphF model_factory; CreateGraphF reference_model_factory; - ov::Shape input_shape; - ov::element::Type input_type; + Shape input_shape; + element::Type input_type; std::tie(unary_factory, pass_factory, num_unary_ops, @@ -380,90 +454,95 @@ TEST_P(TransposeSinkingUnaryTestFixture, CompareFunctions) { pass_factory->registerPass(manager); } -INSTANTIATE_TEST_SUITE_P( - TransposeSinkingUnaryForwardTestSuite, - TransposeSinkingUnaryTestFixture, - ::testing::Combine(::testing::ValuesIn(unary_factories), - ::testing::Values(CreatePassFactory()), - ::testing::ValuesIn(unary_operations_numbers), - ::testing::Values(CreateFunctionTransposeBefore), - ::testing::Values(CreateFunctionTransposeAfter), - ::testing::Values(ov::Shape{1, 96, 55, 55}), - ::testing::Values(ov::element::f32))); +INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardTestSuite, + TransposeSinkingUnaryTestFixture, + ::testing::Combine(::testing::ValuesIn(unary_factories), + ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)), + ::testing::ValuesIn(unary_operations_numbers), + ::testing::Values(CreateFunctionTransposeBefore), + ::testing::Values(CreateFunctionTransposeAfter), + ::testing::Values(Shape{1, 96, 55, 55}), + ::testing::Values(element::f32)), + TransposeSinkingUnaryTestFixture::get_test_name); -INSTANTIATE_TEST_SUITE_P( - TransposeSinkingUnaryBackwardTestSuite, - TransposeSinkingUnaryTestFixture, - ::testing::Combine(::testing::ValuesIn(unary_factories), - ::testing::Values(CreatePassFactory()), - ::testing::ValuesIn(unary_operations_numbers), - ::testing::Values(CreateFunctionTransposeAfter), - ::testing::Values(CreateFunctionTransposeBefore), - ::testing::Values(ov::Shape{1, 96, 55, 55}), - ::testing::Values(ov::element::f32))); +INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardTestSuite, + TransposeSinkingUnaryTestFixture, + ::testing::Combine(::testing::ValuesIn(unary_factories), + ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)), + ::testing::ValuesIn(unary_operations_numbers), + ::testing::Values(CreateFunctionTransposeAfter), + ::testing::Values(CreateFunctionTransposeBefore), + ::testing::Values(Shape{1, 96, 55, 55}), + ::testing::Values(element::f32)), + TransposeSinkingUnaryTestFixture::get_test_name); INSTANTIATE_TEST_SUITE_P( TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeReshape, TransposeSinkingUnaryTestFixture, ::testing::Combine(::testing::ValuesIn(unary_factories), - ::testing::Values(CreatePassFactory()), + ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)), ::testing::ValuesIn(unary_operations_numbers), ::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore), ::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter), - ::testing::Values(ov::Shape{1, 96, 55, 55}), - ::testing::Values(ov::element::f32))); + ::testing::Values(Shape{1, 96, 55, 55}), + ::testing::Values(element::f32)), + TransposeSinkingUnaryTestFixture::get_test_name); INSTANTIATE_TEST_SUITE_P( TransposeSinkingUnaryBackwardMultConsumersTestSuiteLastNodeReshape, TransposeSinkingUnaryTestFixture, ::testing::Combine(::testing::ValuesIn(unary_factories), - ::testing::Values(CreatePassFactory()), + ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)), ::testing::ValuesIn(unary_operations_numbers), ::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter), ::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore), - ::testing::Values(ov::Shape{1, 96, 55, 55}), - ::testing::Values(ov::element::f32))); + ::testing::Values(Shape{1, 96, 55, 55}), + ::testing::Values(element::f32)), + TransposeSinkingUnaryTestFixture::get_test_name); INSTANTIATE_TEST_SUITE_P( TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeEltwise, TransposeSinkingUnaryTestFixture, ::testing::Combine(::testing::ValuesIn(unary_factories), - ::testing::Values(CreatePassFactory()), + ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)), ::testing::ValuesIn(unary_operations_numbers), ::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore), ::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter), - ::testing::Values(ov::Shape{1, 96, 55, 55}), - ::testing::Values(ov::element::f32))); - -INSTANTIATE_TEST_SUITE_P( - TransposeSinkingUnaryBackwardMultConsumersTestSuiteLastNodeEltwise, - TransposeSinkingUnaryTestFixture, - ::testing::Combine(::testing::ValuesIn(unary_factories), - ::testing::Values(CreatePassFactory()), - ::testing::ValuesIn(unary_operations_numbers), - ::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter), - ::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore), - ::testing::Values(ov::Shape{1, 96, 55, 55}), - ::testing::Values(ov::element::f32))); + ::testing::Values(Shape{1, 96, 55, 55}), + ::testing::Values(element::f32)), + TransposeSinkingUnaryTestFixture::get_test_name); INSTANTIATE_TEST_SUITE_P( TransposeSinkingUnaryForwardMultConsumersTestSuiteFirstNode, TransposeSinkingUnaryTestFixture, ::testing::Combine(::testing::ValuesIn(unary_factories), - ::testing::Values(CreatePassFactory()), + ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)), ::testing::ValuesIn(unary_operations_numbers), ::testing::Values(mult_consumers_first_node::forward::CreateFunction), ::testing::Values(mult_consumers_first_node::forward::CreateReferenceFunction), - ::testing::Values(ov::Shape{1, 96, 55, 55}), - ::testing::Values(ov::element::f32))); + ::testing::Values(Shape{1, 96, 55, 55}), + ::testing::Values(element::f32)), + TransposeSinkingUnaryTestFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultConsumersTestSuiteFirstNode, + TransposeSinkingUnaryTestFixture, + ::testing::Combine(::testing::ValuesIn(unary_factories), + ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)), + ::testing::ValuesIn(unary_operations_numbers), + ::testing::Values(mult_consumers_first_node::backward::CreateFunction), + ::testing::Values(mult_consumers_first_node::backward::CreateFunction), + ::testing::Values(Shape{1, 96, 55, 55}), + ::testing::Values(element::f32)), + TransposeSinkingUnaryTestFixture::get_test_name); INSTANTIATE_TEST_SUITE_P( - TransposeSinkingUnaryBackwardMultConsumersTestSuiteFirstNode, + TransposeSinkingUnaryBackwardMultTransposeConsumersTestSuiteFirstNode, TransposeSinkingUnaryTestFixture, ::testing::Combine(::testing::ValuesIn(unary_factories), - ::testing::Values(CreatePassFactory()), + ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)), ::testing::ValuesIn(unary_operations_numbers), - ::testing::Values(mult_consumers_first_node::backward::CreateFunction), - ::testing::Values(mult_consumers_first_node::backward::CreateReferenceFunction), - ::testing::Values(ov::Shape{1, 96, 55, 55}), - ::testing::Values(ov::element::f32))); + ::testing::Values(mult_consumers_first_node::backward_mult_transposes::CreateFunction), + ::testing::Values(mult_consumers_first_node::backward_mult_transposes::CreateReferenceFunction), + ::testing::Values(Shape{1, 96, 55, 55}), + ::testing::Values(element::f32)), + TransposeSinkingUnaryTestFixture::get_test_name);