diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_pad.hpp b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_pad.hpp new file mode 100644 index 00000000000..a3c74aee2b3 --- /dev/null +++ b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_pad.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +class TRANSFORMATIONS_API TransposeSinkingPadForward; +class TRANSFORMATIONS_API TransposeSinkingPadBackward; + +} // namespace pass +} // namespace ov + +class ov::pass::TransposeSinkingPadForward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TransposeSinkingPadForward", "0"); + TransposeSinkingPadForward(); +}; + +class ov::pass::TransposeSinkingPadBackward : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::pass::TransposeSinkingPadBackward", "0"); + TransposeSinkingPadBackward(); +}; diff --git a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_utils.hpp b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_utils.hpp index 7e28e2816bf..2486ced2a63 100644 --- a/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_utils.hpp +++ b/src/common/transformations/include/transformations/common_optimizations/transpose_sinking_utils.hpp @@ -8,7 +8,6 @@ #include #include -#include "itt.hpp" #include "openvino/op/util/op_types.hpp" #include "openvino/opsets/opset9.hpp" #include "openvino/pass/pattern/op/label.hpp" @@ -32,7 +31,7 @@ struct TransposeInputsInfo { * @brief Finds node first input that is a transpose operation and returns filled TransposeInputsInfo * for it */ 
-TransposeInputsInfo GetFirstTransposeInput(std::shared_ptr); +TransposeInputsInfo GetFirstTransposeInput(const std::shared_ptr&); /** * @brief Checks if @arg has any input node that is a transpose operation @@ -54,41 +53,44 @@ void SwapOutputNames(ov::Output, ov::Output); /** * @brief Swaps @args friendly names */ -void SwapFriendlyNames(std::shared_ptr, std::shared_ptr); +void SwapFriendlyNames(const std::shared_ptr&, const std::shared_ptr&); /** * @brief Swaps @args output tensor names and friendly names */ -void SwapNames(std::shared_ptr, std::shared_ptr); +void SwapNames(const std::shared_ptr&, const std::shared_ptr&); namespace sink_forward { /** * @brief Inserts reversed transposed on @args main_node inputs. Removes input transpose specified in @arg * transpose_input_info */ -void UpdateInputTransposes(std::shared_ptr main_node, const TransposeInputsInfo& transpose_input_info); +void UpdateInputTransposes(const std::shared_ptr& main_node, const TransposeInputsInfo& transpose_input_info); /** * @brief Removes @arg input node */ -void RemoveInputNode(std::shared_ptr, size_t input_idx); +void RemoveInputNode(const std::shared_ptr&, size_t input_idx); /** * @brief Inserts transposes on each main_node output with the order specified in @arg transpose_input_info */ -ov::NodeVector InsertOutputTransposes(std::shared_ptr main_node, +ov::NodeVector InsertOutputTransposes(const std::shared_ptr& main_node, const TransposeInputsInfo& transpose_input_info); } // namespace sink_forward namespace sink_backward { /** - * @brief Inserts transposes on each input of @arg main_node with the order specified in @arg transpose_const + * @brief Inserts transposes on inputs of @arg main_node specified by @arg input_indexes + * with the order specified in @arg transpose_const. If @arg input_indexes is empty, then it inserts + * transposes for all inputs. 
*/ -ov::NodeVector InsertTransposeBeforeNode(std::shared_ptr main_node, - std::shared_ptr transpose_const); +ov::NodeVector InsertTransposeBeforeNode(const std::shared_ptr& main_node, + const std::shared_ptr& transpose_const, + std::vector input_indexes = {}); } // namespace sink_backward -void UpdateForwardSinkingAbility(std::shared_ptr); +void UpdateForwardSinkingAbility(const std::shared_ptr&); /** * @brief Checks if @arg has consumers that all are the same transpose operation. If no consumers at all @@ -99,6 +101,12 @@ bool HasSameOutputTransposeNodes(const ov::Output&); /** * Removes all direct node consumers that have one output */ -void RemoveSingleOutputConsumers(std::shared_ptr); +void RemoveSingleOutputConsumers(const std::shared_ptr&); +/** + * Changes the order of values in @arg input according to @arg transpose_axis_order along @arg axis + */ +ov::Output ChangeValuesOrder(const ov::Output& input, + const ov::AxisVector& transpose_axis_order, + const std::shared_ptr& axis); } // namespace transpose_sinking diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_general.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_general.cpp index 5de1a9d3be3..343857a0486 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_general.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_general.cpp @@ -14,6 +14,7 @@ #include "transformations/common_optimizations/transpose_sinking.hpp" #include "transformations/common_optimizations/transpose_sinking_binary.hpp" #include "transformations/common_optimizations/transpose_sinking_concat.hpp" +#include "transformations/common_optimizations/transpose_sinking_pad.hpp" #include "transformations/common_optimizations/transpose_sinking_split.hpp" #include "transformations/common_optimizations/transpose_sinking_unary.hpp" #include 
"transformations/utils/utils.hpp" @@ -24,6 +25,7 @@ ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() { add_matcher(); add_matcher(); add_matcher(); + add_matcher(); add_matcher(); } @@ -33,6 +35,7 @@ ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() { add_matcher(); add_matcher(); add_matcher(); + add_matcher(); add_matcher(); } diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_pad.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_pad.cpp new file mode 100644 index 00000000000..1c0af9475c7 --- /dev/null +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_pad.cpp @@ -0,0 +1,105 @@ +#include "transformations/common_optimizations/transpose_sinking_pad.hpp" + +#include + +#include "itt.hpp" +#include "openvino/op/util/op_types.hpp" +#include "openvino/opsets/opset10.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "openvino/util/common_util.hpp" +#include "transformations/common_optimizations/transpose_sinking_utils.hpp" +#include "transformations/rt_info/transpose_sinking_attr.hpp" + +using namespace ov::pass::pattern; +using namespace ov; +using namespace ov::opset10; +using namespace transpose_sinking; + +ov::pass::TransposeSinkingPadForward::TransposeSinkingPadForward() { + MATCHER_SCOPE(TransposeSinkingPadForward); + auto const_label = wrap_type(); + auto transpose_label = wrap_type({any_input(), const_label}); + auto main_node_label = wrap_type({transpose_label, any_input(), any_input(), any_input()}); + + matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { + const auto& pattern_to_node = m.get_pattern_map(); + + auto& main_node = pattern_to_node.at(main_node_label); + auto transpose = std::dynamic_pointer_cast(pattern_to_node.at(transpose_label)); + if (!transpose) { + return false; + } + + auto transpose_const = 
as_type_ptr(pattern_to_node.at(const_label)); +    if (!transpose_const) { +        return false; +    } + +    // remove Transpose on 1st input: +    auto transpose_parent = main_node->input_value(0).get_node()->input_value(0); +    main_node->input(0).replace_source_output(transpose_parent); + +    // change the order of values for PadBegin and PadEnd inputs +    const auto transpose_axis_order = transpose_const->get_axis_vector_val(); +    auto axis = std::make_shared(element::i32, Shape{}, std::vector{0}); + +    main_node->input(1).replace_source_output( +        ChangeValuesOrder(main_node->input_value(1), transpose_axis_order, axis)); +    main_node->input(2).replace_source_output( +        ChangeValuesOrder(main_node->input_value(2), transpose_axis_order, axis)); + +    // insert Transpose for Pad output +    TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0}; +    for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { +        register_new_node(new_node); +        transpose_sinking::UpdateForwardSinkingAbility(new_node); +    } +    return true; +    }; + +    auto m = std::make_shared(main_node_label, matcher_name); +    register_matcher(m, matcher_pass_callback); +} + +ov::pass::TransposeSinkingPadBackward::TransposeSinkingPadBackward() { +    MATCHER_SCOPE(TransposeSinkingPadBackward); + +    auto main_node_label = wrap_type([](const Output& output) -> bool { +        return has_static_rank()(output) && HasSameOutputTransposeNodes(output); +    }); + +    auto transpose_const_label = wrap_type(); + +    auto transpose_label = +        wrap_type({main_node_label, transpose_const_label}, [](const Output& output) -> bool { +            return has_static_rank()(output) && is_sinking_node(output); +        }); + +    matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { +        const auto& pattern_to_output = m.get_pattern_value_map(); +        auto transpose_const = as_type_ptr(pattern_to_output.at(transpose_const_label).get_node_shared_ptr()); +        auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); +        auto 
main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr(); + + for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, + transpose_const, + /* input_indexes= */ {0})) { + register_new_node(new_node); + } + + // remove output transposes + RemoveSingleOutputConsumers(main_node); + + const auto transpose_axis_order = transpose_const->get_axis_vector_val(); + auto axis = std::make_shared(element::i32, Shape{}, std::vector{0}); + + main_node->input(1).replace_source_output( + ChangeValuesOrder(main_node->input_value(1), transpose_axis_order, axis)); + main_node->input(2).replace_source_output( + ChangeValuesOrder(main_node->input_value(2), transpose_axis_order, axis)); + return true; + }; + + auto m = std::make_shared(transpose_label, matcher_name); + register_matcher(m, matcher_pass_callback); +} diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_utils.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_utils.cpp index 74974996824..eb59c443741 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_utils.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_sinking_utils.cpp @@ -20,7 +20,22 @@ using namespace ov::opset9; using NodePtr = std::shared_ptr; -TransposeInputsInfo GetFirstTransposeInput(NodePtr node) { +Output ChangeValuesOrder(const Output& input, + const AxisVector& transpose_axis_order, + const std::shared_ptr& axis) { + auto rank = transpose_axis_order.size(); + auto split_pad = std::make_shared(input, axis, rank); + auto split_outputs = split_pad->outputs(); + OutputVector new_order(split_outputs.size()); + for (size_t i = 0; i < rank; ++i) { + new_order[i] = split_outputs[transpose_axis_order[i]]; + } + auto concat_pad = std::make_shared(new_order, 0); + copy_runtime_info(input.get_node_shared_ptr(), {split_pad, concat_pad}); + return concat_pad; +} + 
+TransposeInputsInfo GetFirstTransposeInput(const NodePtr& node) { for (size_t input_idx = 0; input_idx < node->get_input_size(); ++input_idx) { NodePtr input_node = node->get_input_node_shared_ptr(input_idx); auto transpose_node = as_type_ptr(input_node); @@ -38,7 +53,7 @@ TransposeInputsInfo GetFirstTransposeInput(NodePtr node) { } } - return TransposeInputsInfo(); + return {}; } bool IfNodeHasTransposeInputs(const Output& output) { @@ -60,20 +75,20 @@ void SwapOutputNames(Output output1, Output output2) { output1.set_names(node2_output_names); } -void SwapFriendlyNames(NodePtr node1, NodePtr node2) { +void SwapFriendlyNames(const NodePtr& node1, const NodePtr& node2) { const std::string node2_name = node2->get_friendly_name(); node2->set_friendly_name(node1->get_friendly_name()); node1->set_friendly_name(node2_name); } -void SwapNames(NodePtr node1, NodePtr node2) { +void SwapNames(const NodePtr& node1, const NodePtr& node2) { SwapFriendlyNames(node1, node2); SwapOutputNames(node1->output(0), node2->output(0)); } namespace { -bool HasDynamicRankInput(NodePtr node) { +bool HasDynamicRankInput(const NodePtr& node) { for (auto& input_node : node->input_values()) { const ov::Rank output_rank = input_node.get_partial_shape().rank(); if (output_rank.is_dynamic()) @@ -95,7 +110,7 @@ ov::Rank::value_type GetMaxInputRank(const NodePtr& node) { return max_input_rank; } -NodePtr InsertUnsqueeze(Output node, size_t n_dims) { +NodePtr InsertUnsqueeze(const Output& node, size_t n_dims) { std::vector dims(n_dims); std::iota(dims.begin(), dims.end(), 0); auto unsqueeze_const = std::make_shared(ov::element::i64, Shape{dims.size()}, dims); @@ -115,7 +130,7 @@ ov::Output FixInputNodeRank(ov::Output input_node, ov::Rank: namespace sink_forward { -void UpdateInputTransposes(NodePtr main_node, const TransposeInputsInfo& transpose_input_info) { +void UpdateInputTransposes(const NodePtr& main_node, const TransposeInputsInfo& transpose_input_info) { if (transpose_input_info.isEmpty() || 
HasDynamicRankInput(main_node)) return; @@ -148,7 +163,7 @@ void UpdateInputTransposes(NodePtr main_node, const TransposeInputsInfo& transpo } } -void RemoveInputNode(NodePtr main_node, size_t input_idx) { +void RemoveInputNode(const NodePtr& main_node, size_t input_idx) { auto input_node = main_node->input_value(input_idx); if (input_node.get_node()->get_input_size() < (input_idx + 1)) return; @@ -156,7 +171,7 @@ void RemoveInputNode(NodePtr main_node, size_t input_idx) { main_node->input(input_idx).replace_source_output(parent_node); } -NodeVector InsertOutputTransposes(NodePtr main_node, const TransposeInputsInfo& transpose_input_info) { +NodeVector InsertOutputTransposes(const NodePtr& main_node, const TransposeInputsInfo& transpose_input_info) { if (transpose_input_info.isEmpty()) return {}; const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val(); @@ -195,7 +210,13 @@ NodeVector InsertOutputTransposes(NodePtr main_node, const TransposeInputsInfo& namespace sink_backward { -NodeVector InsertTransposeBeforeNode(NodePtr main_node, std::shared_ptr transpose_const) { +NodeVector InsertTransposeBeforeNode(const NodePtr& main_node, + const std::shared_ptr& transpose_const, + std::vector input_indexes) { + if (input_indexes.empty()) { + input_indexes.resize(main_node->get_input_size()); + std::iota(input_indexes.begin(), input_indexes.end(), 0); + } const auto transpose_axis_order = transpose_const->get_axis_vector_val(); const auto transpose_element_type = transpose_const->get_element_type(); @@ -208,7 +229,7 @@ NodeVector InsertTransposeBeforeNode(NodePtr main_node, std::shared_ptrget_input_size(); ++i) { + for (const auto& i : input_indexes) { auto input_node = FixInputNodeRank(main_node->input_value(i), max_input_rank); auto new_transpose_const = std::make_shared(transpose_element_type, @@ -250,7 +271,7 @@ bool CanPropagateForwardThrough(Node* node) { return false; } -bool CanPropagateForward(NodePtr node) { +bool 
CanPropagateForward(const NodePtr& node) { for (size_t i = 0; i < node->get_output_size(); ++i) { for (auto& consumer_input : node->output(i).get_target_inputs()) { if (!CanPropagateForwardThrough(consumer_input.get_node())) @@ -263,7 +284,7 @@ bool CanPropagateForward(NodePtr node) { } // namespace -void UpdateForwardSinkingAbility(NodePtr node) { +void UpdateForwardSinkingAbility(const NodePtr& node) { if (!CanPropagateForward(node)) mark_as_no_sinking_node(node); } @@ -282,7 +303,7 @@ std::shared_ptr GetTransposeConstant(Node* node) { return constant_node; } -Node* FindFirstConsumer(NodePtr node) { +Node* FindFirstConsumer(const NodePtr& node) { for (size_t output_idx = 0; output_idx < node->get_output_size(); ++output_idx) { auto inputs = node->get_output_target_inputs(output_idx); if (inputs.empty()) @@ -292,7 +313,7 @@ Node* FindFirstConsumer(NodePtr node) { return nullptr; } -bool HasSameOutputTransposeNodes(NodePtr main_node) { +bool HasSameOutputTransposeNodes(const NodePtr& main_node) { AxisVector first_transpose_axis_order; { Node* first_consumer = FindFirstConsumer(main_node); @@ -329,7 +350,7 @@ bool HasSameOutputTransposeNodes(const Output& output) { return HasSameOutputTransposeNodes(output.get_node_shared_ptr()); } -void RemoveSingleOutputConsumers(NodePtr node) { +void RemoveSingleOutputConsumers(const NodePtr& node) { for (size_t output_idx = 0; output_idx < node->get_output_size(); ++output_idx) { for (auto& input : node->get_output_target_inputs(output_idx)) { Node* consumer = input.get_node(); diff --git a/src/common/transformations/tests/common_optimizations/transpose_sinking_pad_test.cpp b/src/common/transformations/tests/common_optimizations/transpose_sinking_pad_test.cpp new file mode 100644 index 00000000000..e66f9008c6f --- /dev/null +++ b/src/common/transformations/tests/common_optimizations/transpose_sinking_pad_test.cpp @@ -0,0 +1,313 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include 
+#include +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" +#include "gtest/gtest.h" + +using namespace std; +using namespace ov; +using namespace opset10; + +namespace transpose_sinking_pad { +namespace { + +using NodePtr = shared_ptr; +using ModelPtr = shared_ptr; + +class IPassFactory { +public: + IPassFactory(const string& type_name) : type_name_(type_name) {} + + virtual ~IPassFactory() = default; + + virtual void registerPass(pass::Manager& pass_manager) const = 0; + + const string& getTypeName() const { + return type_name_; + } + +private: + const string type_name_; +}; + +using PassFactoryPtr = shared_ptr; + +template +class PassFactory : public IPassFactory { +public: + PassFactory(const string& type_name) : IPassFactory(type_name) {} + + void registerPass(pass::Manager& pass_manager) const override { + pass_manager.register_pass(); + pass_manager.register_pass(); + } +}; + +#define CREATE_PASS_FACTORY(pass_name) make_shared>(#pass_name) + +vector TransposePadValues(const vector& pads, const vector& order) { + vector new_pads(pads.size()); + for (size_t i = 0; i < pads.size(); ++i) { + new_pads[i] = pads[order[i]]; + } + return new_pads; +}; +} // namespace + +namespace forward { +namespace single_consumer { + +shared_ptr CreateFunction(size_t num_pad_ops, element::Type input_type) { + const Shape input_shape{96, 32, 55, 55}; + + auto X = make_shared(input_type, input_shape); + + auto order = make_shared(element::i64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose = make_shared(X, order); + + OutputVector outputs; + Output in_op = transpose->output(0); + auto pad_value = make_shared(input_type, Shape{}, 0); + for (size_t i = 0; i < num_pad_ops; ++i) { + auto pad_begin_const = make_shared(element::i64, Shape{4}, vector{0, 1, 2, 3}); + auto pad_end_const = make_shared(element::i64, Shape{4}, vector{0, 1, 2, 3}); + auto pad = make_shared(in_op, pad_begin_const, pad_end_const, pad_value, 
ov::op::PadMode::CONSTANT); + outputs.push_back((pad->output(0))); + in_op = pad; + } + outputs.push_back(in_op); + + return make_shared(outputs, ParameterVector{X}); +} + +shared_ptr CreateReferenceFunction(size_t num_pad_ops, element::Type input_type) { + const Shape input_shape{96, 32, 55, 55}; + + auto X = make_shared(input_type, input_shape); + + OutputVector outputs; + Output in_op = X->output(0); + vector pads{0, 1, 2, 3}; + auto transpose_pad_values = [&](const vector& order) { + vector new_pads(pads.size()); + for (size_t i = 0; i < pads.size(); ++i) { + new_pads[i] = pads[order[i]]; + } + return new_pads; + }; + auto axis = make_shared(element::i64, Shape{}, 0); + auto pad_value = make_shared(input_type, Shape{}, 0); + for (size_t i = 0; i < num_pad_ops; ++i) { + vector order_val = {0, 3, 1, 2}; + auto pad_begin_const = make_shared(element::i64, Shape{4}, transpose_pad_values(order_val)); + auto pad_end_const = make_shared(element::i64, Shape{4}, transpose_pad_values(order_val)); + auto pad = make_shared(in_op, pad_begin_const, pad_end_const, pad_value, ov::op::PadMode::CONSTANT); + + auto order = make_shared(element::i64, Shape{4}, Shape{order_val}); + auto transpose = make_shared(pad->output(0), order); + outputs.push_back(transpose); + in_op = pad; + } + + auto order = make_shared(element::i64, Shape{4}, Shape{0, 3, 1, 2}); + auto transpose = make_shared(in_op, order); + outputs.push_back(transpose); + + auto ref = make_shared(outputs, ParameterVector{X}); + ov::pass::Manager ps_manager; + ps_manager.run_passes(ref); + return ref; +} + +} // namespace single_consumer +} // namespace forward + +namespace backward { +namespace single_consumer { + +shared_ptr CreateFunction(size_t num_pad_ops, element::Type input_type) { + const Shape input_shape{96, 32, 55, 55}; + + auto X = make_shared(input_type, input_shape); + + OutputVector outputs; + Output in_op = X->output(0); + auto pad_value = make_shared(input_type, Shape{}, 0); + for (size_t i = 0; i < 
num_pad_ops; ++i) { + auto pad_begin_const = make_shared(element::i64, Shape{4}, vector{0, 1, 2, 3}); + auto pad_end_const = make_shared(element::i64, Shape{4}, vector{0, 1, 2, 3}); + auto pad = make_shared(in_op, pad_begin_const, pad_end_const, pad_value, ov::op::PadMode::CONSTANT); + in_op = pad; + } + auto order = make_shared(element::i64, Shape{4}, vector{0, 3, 1, 2}); + auto transpose = make_shared(in_op, order); + auto relu = make_shared(transpose); + outputs.push_back(relu); + return make_shared(outputs, ParameterVector{X}); +} + +shared_ptr CreateReferenceFunction(size_t num_pad_ops, element::Type input_type) { + const Shape input_shape{96, 32, 55, 55}; + + auto X = make_shared(input_type, input_shape); + vector order_val = {0, 3, 1, 2}; + auto order = make_shared(element::i64, Shape{4}, order_val); + auto transpose = make_shared(X, order); + + OutputVector outputs; + Output in_op = transpose->output(0); + vector pads{0, 1, 2, 3}; + auto axis = make_shared(element::i64, Shape{}, 0); + auto pad_value = make_shared(input_type, Shape{}, 0); + for (size_t i = 0; i < num_pad_ops; ++i) { + auto pad_begin_const = make_shared(element::i64, Shape{4}, TransposePadValues(pads, order_val)); + auto pad_end_const = make_shared(element::i64, Shape{4}, TransposePadValues(pads, order_val)); + auto pad = make_shared(in_op, pad_begin_const, pad_end_const, pad_value, ov::op::PadMode::CONSTANT); + in_op = pad; + } + auto relu = make_shared(in_op); + outputs.push_back(relu); + auto ref = make_shared(outputs, ParameterVector{X}); + return ref; +} + +} // namespace single_consumer + +namespace output_transpose_mult_transposes { +shared_ptr CreateFunction(size_t num_pad_ops, element::Type input_type) { + const Shape input_shape{96, 32, 55, 55}; + + auto X = make_shared(input_type, input_shape); + + OutputVector outputs; + Output in_op = X->output(0); + auto pad_value = make_shared(input_type, Shape{}, 0); + for (size_t i = 0; i < num_pad_ops; ++i) { + auto pad_begin_const = 
make_shared(element::i64, Shape{4}, vector{0, 1, 2, 3}); + auto pad_end_const = make_shared(element::i64, Shape{4}, vector{0, 1, 2, 3}); + auto pad = make_shared(in_op, pad_begin_const, pad_end_const, pad_value, ov::op::PadMode::CONSTANT); + in_op = pad; + } + auto order = make_shared(element::i64, Shape{4}, vector{0, 3, 1, 2}); + auto transpose_1 = make_shared(in_op, order); + auto relu_1 = make_shared(transpose_1); + outputs.push_back(relu_1); + + auto transpose_2 = make_shared(in_op, order); + auto relu_2 = make_shared(transpose_2); + outputs.push_back(relu_2); + return make_shared(outputs, ParameterVector{X}); +} + +shared_ptr CreateReferenceFunction(size_t num_pad_ops, element::Type input_type) { + const Shape input_shape{96, 32, 55, 55}; + + auto X = make_shared(input_type, input_shape); + vector order_val = {0, 3, 1, 2}; + auto order = make_shared(element::i64, Shape{4}, order_val); + auto transpose = make_shared(X, order); + + OutputVector outputs; + Output in_op = transpose->output(0); + vector pads{0, 1, 2, 3}; + + auto axis = make_shared(element::i64, Shape{}, 0); + auto pad_value = make_shared(input_type, Shape{}, 0); + for (size_t i = 0; i < num_pad_ops; ++i) { + auto pad_begin_const = make_shared(element::i64, Shape{4}, TransposePadValues(pads, order_val)); + auto pad_end_const = make_shared(element::i64, Shape{4}, TransposePadValues(pads, order_val)); + auto pad = make_shared(in_op, pad_begin_const, pad_end_const, pad_value, ov::op::PadMode::CONSTANT); + in_op = pad; + } + auto relu_1 = make_shared(in_op); + auto relu_2 = make_shared(in_op); + outputs.push_back(relu_1); + outputs.push_back(relu_2); + auto ref = make_shared(outputs, ParameterVector{X}); + return ref; +} +} // namespace output_transpose_mult_transposes +} // namespace backward + +using CreateGraphPadF = function(size_t num_pad_ops, element::Type input_type)>; + +using TestPadParams = tuple /* input type */; + +class TransposeSinkingPadTestFixture : public ::testing::WithParamInterface, 
+ public TransformationTestsF { +public: + static string get_test_name(const testing::TestParamInfo& obj) { + PassFactoryPtr pass_factory; + size_t num_pad_ops; + CreateGraphPadF model_factory; + CreateGraphPadF reference_model_factory; + element::Type input_type; + + tie(pass_factory, num_pad_ops, model_factory, reference_model_factory, input_type) = obj.param; + + ostringstream test_name; + test_name << "pass_factory=" << pass_factory->getTypeName() << "_"; + test_name << "num_pad_ops=" << num_pad_ops << "_"; + test_name << "input_type=" << input_type; + + return test_name.str(); + } +}; + +TEST_P(TransposeSinkingPadTestFixture, CompareFunctions) { + PassFactoryPtr pass_factory; + size_t num_pad_ops; + CreateGraphPadF model_factory; + CreateGraphPadF reference_model_factory; + element::Type input_type; + tie(pass_factory, num_pad_ops, model_factory, reference_model_factory, input_type) = this->GetParam(); + + model = model_factory(num_pad_ops, input_type); + model_ref = reference_model_factory(num_pad_ops, input_type); + pass_factory->registerPass(manager); +} + +std::vector pad_operations_numbers = {1, 10}; + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingPadForwardSingleConsumerTestSuite, + TransposeSinkingPadTestFixture, + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingPadForward)), + ::testing::ValuesIn(pad_operations_numbers), + ::testing::Values(forward::single_consumer::CreateFunction), + ::testing::Values(forward::single_consumer::CreateReferenceFunction), + ::testing::Values(element::f32)), + TransposeSinkingPadTestFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingPadBackwardSingleConsumerTestSuite, + TransposeSinkingPadTestFixture, + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingPadBackward)), + ::testing::ValuesIn(pad_operations_numbers), + ::testing::Values(backward::single_consumer::CreateFunction), + ::testing::Values(backward::single_consumer::CreateReferenceFunction), + 
::testing::Values(element::f32)), + TransposeSinkingPadTestFixture::get_test_name); + +INSTANTIATE_TEST_SUITE_P( + TransposeSinkingPadBackwardSingleConsumerMultiTransposesTestSuite, + TransposeSinkingPadTestFixture, + ::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingPadBackward)), + ::testing::ValuesIn(pad_operations_numbers), + ::testing::Values(backward::output_transpose_mult_transposes::CreateFunction), + ::testing::Values(backward::output_transpose_mult_transposes::CreateReferenceFunction), + ::testing::Values(element::f32)), + TransposeSinkingPadTestFixture::get_test_name); +} // namespace transpose_sinking_pad \ No newline at end of file