Transpose sinking binary/split/concat mult consumers case + PRelu (#14670)

* binary: move test classes into namespaces

* add unsing namespacees ov and ov::opset9

* ov::Model -> Model

* use naming for unit tests

* fix unary -> binary

* fix identation

* add tests binary forward consumers > 1

* add binary output consumers > 1 unit tests

* add binary input consumers > 1 unit tests

* binary backward add unit test binary has multiple consumers + fix

* move code to common function CloneNodeWithoutConsumers

* add unit test on backward binary propagation multi consumers

* reorganize binary unit tests

* add backward_input_node_consumers test

* add test output transpose mult consumers

* cleanup

* cleanup unit tests for split

* add forward::mult_consumers::input_node_consumers tests

* add forward::mult_consumers::input_transpose_consumers tests

* add forward::mult_consumers::output_consumers tests

* code indentation fix

* fix unit test split backward input node consumers

* added mult_output_consumers split unit test

* add mult_split_consumers split unit test

* cancat tests using namespaces

* concat add unit tests

* move repeated code into functions

* clang format fixes

* rename TransposeSinkingBinaryElementwise -> TransposeSinkingBinary

* add PRelu

* clang cleanup

* fix prelu tests

* fix using

* clang fixes

* fix NodeVector argument

* fix function descriptions

* fix const ref arg functions

* clang fixes

* prevent backward sinking if not all output nodes are transposes

* fix RemoveSingleOutputConsumers checking output size

* cleanup

* clang cleanup

* disable unary backward with not same transposes

* clang fixes

* remove unneeded functions CloneNodeWithoutConsumers GetNodeIds

* remove unneeded GetTransposeConstant(Input<Node> input) as duplicate of GetTransposeConstant(Node* node)

* cleanup

* add output_transpose_mult_transposes test

* add backward::output_transpose_mult_transposes test

* add unit tests for unary backward multiple transposes; fix problems by adding new transformation

* fix bug TransposeSinkingUnaryBackwardMultiConsumers consumers_more_than
This commit is contained in:
Evgeny Kotov 2023-01-12 15:42:46 +01:00 committed by GitHub
parent 9388560aec
commit 78995e9ac2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 2953 additions and 1008 deletions

View File

@ -11,20 +11,20 @@
namespace ov { namespace ov {
namespace pass { namespace pass {
class TRANSFORMATIONS_API TransposeSinkingBinaryElementwiseForward; class TRANSFORMATIONS_API TransposeSinkingBinaryForward;
class TRANSFORMATIONS_API TransposeSinkingBinaryElementwiseBackward; class TRANSFORMATIONS_API TransposeSinkingBinaryBackward;
} // namespace pass } // namespace pass
} // namespace ov } // namespace ov
class ov::pass::TransposeSinkingBinaryElementwiseForward : public ov::pass::MatcherPass { class ov::pass::TransposeSinkingBinaryForward : public ov::pass::MatcherPass {
public: public:
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryElementwiseForward", "0"); OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryForward", "0");
TransposeSinkingBinaryElementwiseForward(); TransposeSinkingBinaryForward();
}; };
class ov::pass::TransposeSinkingBinaryElementwiseBackward : public ov::pass::MatcherPass { class ov::pass::TransposeSinkingBinaryBackward : public ov::pass::MatcherPass {
public: public:
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryElementwiseBackward", "0"); OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryBackward", "0");
TransposeSinkingBinaryElementwiseBackward(); TransposeSinkingBinaryBackward();
}; };

View File

@ -11,6 +11,8 @@ namespace ov {
namespace pass { namespace pass {
class TRANSFORMATIONS_API TransposeSinkingUnaryForward; class TRANSFORMATIONS_API TransposeSinkingUnaryForward;
class TRANSFORMATIONS_API TransposeSinkingUnaryBackwardSingleConsumer;
class TRANSFORMATIONS_API TransposeSinkingUnaryBackwardMultiConsumers;
class TRANSFORMATIONS_API TransposeSinkingUnaryBackward; class TRANSFORMATIONS_API TransposeSinkingUnaryBackward;
} // namespace pass } // namespace pass
@ -22,7 +24,19 @@ public:
TransposeSinkingUnaryForward(); TransposeSinkingUnaryForward();
}; };
class ov::pass::TransposeSinkingUnaryBackward : public ov::pass::MatcherPass { class ov::pass::TransposeSinkingUnaryBackwardSingleConsumer : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TransposeSinkingUnaryBackwardSingleConsumer", "0");
TransposeSinkingUnaryBackwardSingleConsumer();
};
class ov::pass::TransposeSinkingUnaryBackwardMultiConsumers : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TransposeSinkingUnaryBackwardMultiConsumers", "0");
TransposeSinkingUnaryBackwardMultiConsumers();
};
class ov::pass::TransposeSinkingUnaryBackward : public ov::pass::GraphRewrite {
public: public:
OPENVINO_RTTI("TransposeSinkingUnaryBackward", "0"); OPENVINO_RTTI("TransposeSinkingUnaryBackward", "0");
TransposeSinkingUnaryBackward(); TransposeSinkingUnaryBackward();

View File

@ -28,25 +28,77 @@ struct TransposeInputsInfo {
} }
}; };
TransposeInputsInfo GetFirstTransposeInput(std::shared_ptr<ov::Node> node); /**
bool IfNodeHasTransposeInputs(const ov::Output<ov::Node>& output); * @brief Finds node first input that is a transpose operation and returns filled TransposeInputsInfo
ov::AxisVector ReverseTransposeOrder(const ov::AxisVector& axis_order); * for it
void SwapOutputNames(ov::Output<ov::Node> output1, ov::Output<ov::Node> output2); */
void SwapFriendlyNames(std::shared_ptr<ov::Node> node1, std::shared_ptr<ov::Node> node2); TransposeInputsInfo GetFirstTransposeInput(std::shared_ptr<ov::Node>);
void SwapNames(std::shared_ptr<ov::Node> node1, std::shared_ptr<ov::Node> node2);
/**
* @brief Checks if @arg has any input node that is a transpose operation
*/
bool IfNodeHasTransposeInputs(const ov::Output<ov::Node>&);
/**
* @brief Reverses order of transpose operation. Do it in a such way that if we had couple following one after
* another transposes (one would be reversed version of another) we will have no transpose as a result of that
* couple of transposes.
*/
ov::AxisVector ReverseTransposeOrder(const ov::AxisVector&);
/**
* @brief Swaps @args output tensor names
*/
void SwapOutputNames(ov::Output<ov::Node>, ov::Output<ov::Node>);
/**
* @brief Swaps @args friendly names
*/
void SwapFriendlyNames(std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>);
/**
* @brief Swaps @args output tensor names and friendly names
*/
void SwapNames(std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>);
namespace sink_forward { namespace sink_forward {
// insert input reversed transposes, remove first input tranpose /**
void UpdateInputTransposes(std::shared_ptr<ov::Node> main_node, TransposeInputsInfo& transpose_input_info); * @brief Inserts reversed transposed on @args main_node inputs. Removes input transpose specified in @arg
void RemoveZeroInputNode(std::shared_ptr<ov::Node> main_node); * transpose_input_info
ov::NodeVector InsertOutputTransposes(std::shared_ptr<ov::Node> main_node, TransposeInputsInfo& transpose_input_info); */
void UpdateInputTransposes(std::shared_ptr<ov::Node> main_node, const TransposeInputsInfo& transpose_input_info);
/**
* @brief Removes @arg input node
*/
void RemoveInputNode(std::shared_ptr<ov::Node>, size_t input_idx);
/**
* @brief Inserts transposes on each main_node output with the order specified in @arg transpose_input_info
*/
ov::NodeVector InsertOutputTransposes(std::shared_ptr<ov::Node> main_node,
const TransposeInputsInfo& transpose_input_info);
} // namespace sink_forward } // namespace sink_forward
namespace sink_backward { namespace sink_backward {
/**
* @brief Inserts transposes on each input of @arg main_node with the order specified in @arg transpose_const
*/
ov::NodeVector InsertTransposeBeforeNode(std::shared_ptr<ov::Node> main_node, ov::NodeVector InsertTransposeBeforeNode(std::shared_ptr<ov::Node> main_node,
std::shared_ptr<ov::opset9::Constant> transpose_const); std::shared_ptr<ov::opset9::Constant> transpose_const);
} // namespace sink_backward } // namespace sink_backward
void UpdateForwardSinkingAbility(std::shared_ptr<ov::Node>); void UpdateForwardSinkingAbility(std::shared_ptr<ov::Node>);
/**
* @brief Checks if @arg has consumers that all are the same transpose operation. If no consumers at all
* returns false.
*/
bool HasSameOutputTransposeNodes(const ov::Output<ov::Node>&);
/**
* Removes all direct node consumers that have one output
*/
void RemoveSingleOutputConsumers(std::shared_ptr<ov::Node>);
} // namespace transpose_sinking } // namespace transpose_sinking

View File

@ -20,10 +20,10 @@ using namespace ov;
using namespace ov::opset9; using namespace ov::opset9;
using namespace transpose_sinking; using namespace transpose_sinking;
ov::pass::TransposeSinkingBinaryElementwiseForward::TransposeSinkingBinaryElementwiseForward() { ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
MATCHER_SCOPE(TransposeSinkingBinaryElementwiseForward); MATCHER_SCOPE(TransposeSinkingBinaryForward);
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic>(IfNodeHasTransposeInputs); auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic, PRelu>(IfNodeHasTransposeInputs);
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map(); const auto& pattern_to_output = m.get_pattern_value_map();
@ -46,18 +46,19 @@ ov::pass::TransposeSinkingBinaryElementwiseForward::TransposeSinkingBinaryElemen
register_matcher(m, matcher_pass_callback); register_matcher(m, matcher_pass_callback);
} }
ov::pass::TransposeSinkingBinaryElementwiseBackward::TransposeSinkingBinaryElementwiseBackward() { ov::pass::TransposeSinkingBinaryBackward::TransposeSinkingBinaryBackward() {
MATCHER_SCOPE(TransposeSinkingBinaryElementwiseBackward); MATCHER_SCOPE(TransposeSinkingBinaryBackward);
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic>([](const Output<Node>& output) -> bool { auto main_node_label =
return consumers_count(1)(output) && has_static_rank()(output); wrap_type<op::util::BinaryElementwiseArithmetic, PRelu>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
}); });
auto transpose_const_label = wrap_type<Constant>(consumers_count(1)); auto transpose_const_label = wrap_type<Constant>();
auto transpose_label = auto transpose_label =
wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool { wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
return consumers_count(1)(output) && has_static_rank()(output) && is_sinking_node(output); return has_static_rank()(output) && is_sinking_node(output);
}); });
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
@ -70,8 +71,8 @@ ov::pass::TransposeSinkingBinaryElementwiseBackward::TransposeSinkingBinaryEleme
register_new_node(new_node); register_new_node(new_node);
} }
// remove transpose after main node // remove output transposes
transpose->output(0).replace(main_node); RemoveSingleOutputConsumers(main_node);
SwapNames(transpose, main_node); SwapNames(transpose, main_node);

View File

@ -55,14 +55,14 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
MATCHER_SCOPE(TransposeSinkingConcatBackward); MATCHER_SCOPE(TransposeSinkingConcatBackward);
auto main_node_label = wrap_type<Concat>([](const Output<Node>& output) -> bool { auto main_node_label = wrap_type<Concat>([](const Output<Node>& output) -> bool {
return consumers_count(1)(output) && has_static_rank()(output); return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
}); });
auto transpose_const_label = wrap_type<Constant>(consumers_count(1)); auto transpose_const_label = wrap_type<Constant>();
auto transpose_label = auto transpose_label =
wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool { wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
return consumers_count(1)(output) && has_static_rank()(output) && is_sinking_node(output); return has_static_rank()(output) && is_sinking_node(output);
}); });
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
@ -75,8 +75,8 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
register_new_node(new_node); register_new_node(new_node);
} }
// remove transpose after main node // remove output transposes
transpose->output(0).replace(main_node); RemoveSingleOutputConsumers(main_node);
SwapNames(transpose, main_node); SwapNames(transpose, main_node);

View File

@ -21,7 +21,7 @@
ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() { ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {
MATCHER_SCOPE(TransposeSinkingGeneralForward); MATCHER_SCOPE(TransposeSinkingGeneralForward);
add_matcher<ov::pass::TransposeSinkingUnaryForward>(); add_matcher<ov::pass::TransposeSinkingUnaryForward>();
add_matcher<ov::pass::TransposeSinkingBinaryElementwiseForward>(); add_matcher<ov::pass::TransposeSinkingBinaryForward>();
add_matcher<ov::pass::TransposeSinkingConcatForward>(); add_matcher<ov::pass::TransposeSinkingConcatForward>();
add_matcher<ov::pass::TransposeSinkingSplitForward>(); add_matcher<ov::pass::TransposeSinkingSplitForward>();
add_matcher<ngraph::pass::TransposeFuse>(); add_matcher<ngraph::pass::TransposeFuse>();
@ -30,7 +30,7 @@ ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {
ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() { ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() {
MATCHER_SCOPE(TransposeSinkingGeneralBackward); MATCHER_SCOPE(TransposeSinkingGeneralBackward);
add_matcher<ov::pass::TransposeSinkingUnaryBackward>(); add_matcher<ov::pass::TransposeSinkingUnaryBackward>();
add_matcher<ov::pass::TransposeSinkingBinaryElementwiseBackward>(); add_matcher<ov::pass::TransposeSinkingBinaryBackward>();
add_matcher<ov::pass::TransposeSinkingConcatBackward>(); add_matcher<ov::pass::TransposeSinkingConcatBackward>();
add_matcher<ov::pass::TransposeSinkingSplitBackward>(); add_matcher<ov::pass::TransposeSinkingSplitBackward>();
add_matcher<ngraph::pass::TransposeFuse>(); add_matcher<ngraph::pass::TransposeFuse>();

View File

@ -52,78 +52,69 @@ OutputTranspose GetOutputTransposes(NodePtr node) {
return OutputTranspose(); return OutputTranspose();
} }
NodePtr FindSplitInput(Node* node) { template <typename NodeT>
std::shared_ptr<ov::Node> FindInputNode(ov::Node* node) {
for (size_t input_idx = 0; input_idx < node->get_input_size(); ++input_idx) { for (size_t input_idx = 0; input_idx < node->get_input_size(); ++input_idx) {
NodePtr input_node = node->get_input_node_shared_ptr(input_idx); std::shared_ptr<ov::Node> input_node = node->get_input_node_shared_ptr(input_idx);
auto split_node = as_type_ptr<Split>(input_node); auto target_node = ov::as_type_ptr<NodeT>(input_node);
if (split_node) if (target_node)
return split_node; return target_node;
} }
return {}; return {};
} }
std::shared_ptr<Constant> GetTransposeConstant(Input<Node> input) {
auto transpose_node = dynamic_cast<Transpose*>(input.get_node());
if (!transpose_node)
return {};
if (!is_sinking_node(input.get_node()))
return {};
auto constant_node = as_type_ptr<Constant>(transpose_node->input_value(1).get_node_shared_ptr());
if (!constant_node)
return {};
return constant_node;
}
bool HasInputSplitAndTransposeSiblings(const Output<Node>& output) { bool HasInputSplitAndTransposeSiblings(const Output<Node>& output) {
NodePtr split_node = FindSplitInput(output.get_node()); NodePtr main_node = FindInputNode<Split>(output.get_node());
if (!split_node) { if (!main_node) {
return false; return false;
} }
AxisVector first_transpose_axis_order; return HasSameOutputTransposeNodes(main_node);
// get first transpose axis
{
auto constant_node = GetTransposeConstant(*(split_node->get_output_target_inputs(0).begin()));
if (!constant_node)
return false;
first_transpose_axis_order = constant_node->get_axis_vector_val();
}
for (size_t output_idx = 1; output_idx < split_node->get_output_size(); ++output_idx) {
for (auto& input : split_node->get_output_target_inputs(output_idx)) {
auto constant_node = GetTransposeConstant(input);
if (!constant_node)
return false;
AxisVector transpose_axis_order = constant_node->get_axis_vector_val();
if (transpose_axis_order.size() != first_transpose_axis_order.size())
return false;
if (!std::equal(transpose_axis_order.begin(),
transpose_axis_order.end(),
first_transpose_axis_order.begin()))
return false;
}
}
return true;
} }
bool IsSplitSinked(const Output<Node>& output) {
return HasInputSplitAndTransposeSiblings(output) && is_sinking_node(output);
}
} // namespace } // namespace
/*
* We follow Transpose operations rather than Split. We cannot create matcher pattern
* for Split with Transpose outputs since Split can have different number of outputs.
* We just can:
* - specify Split as searched node and check if it has transpose outputs
* - specify Transpose as searched node and check if it has Split input
* Transformations are called on each found node in sorted order from the start to end
* of the network. When we proceed Split backward sinking we move input transpose
* to the input of the Split operation.
* Consider case Split (1) -> Split (2) -> Transpose
* If specify Split as main searched node after first transformation work we will have
* Split (1) -> Transpose -> Split(2)
* Matcher pass will not call TransposeSinkingSplitBackward since
* - matcher pattern has no Transpose label
* - Split (1) has already been proceeded
* Adding Split(2) into the working queue as register_new_node(split)
* cannot help us. We just can try to find all input Split operations and add them with
* register_new_node(). Implemented way is simpler.
*
* We sink Transpose through Split operation in a backward way only if all the output
* nodes are the same Transpose. We can:
* - clone Split with all outputs except Transpose
* causes perfomance problems
* - add reversed Transpose operations on all outputs except sinking Transpose
* nothing to do with new added output Transposes
*/
ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() { ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
MATCHER_SCOPE(TransposeSinkingSplitBackward); MATCHER_SCOPE(TransposeSinkingSplitBackward);
auto transpose_const_label = wrap_type<Constant>(consumers_count(1)); auto transpose_const_label = wrap_type<Constant>();
auto transpose_label = auto transpose_label = wrap_type<Transpose>({any_input(), transpose_const_label}, IsSplitSinked);
wrap_type<Transpose>({any_input(), transpose_const_label}, HasInputSplitAndTransposeSiblings);
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map(); const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose_label_node = pattern_to_output.at(transpose_label).get_node(); auto transpose_label_node = pattern_to_output.at(transpose_label).get_node();
NodePtr split = FindSplitInput(transpose_label_node); NodePtr split = FindInputNode<Split>(transpose_label_node);
auto split_axis_constant = as_type_ptr<Constant>(split->input_value(1).get_node_shared_ptr()); auto split_axis_constant = as_type_ptr<Constant>(split->input_value(1).get_node_shared_ptr());
OutputTranspose output_transpose = GetOutputTransposes(split); OutputTranspose output_transpose = GetOutputTransposes(split);
@ -161,11 +152,8 @@ ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
new_split_axis_const); new_split_axis_const);
// remove split output transposes // remove split output transposes
for (size_t output_idx = 0; output_idx < split->get_output_size(); ++output_idx) { RemoveSingleOutputConsumers(split);
for (auto& input : split->get_output_target_inputs(output_idx)) {
input.get_node()->output(0).replace(split->output(output_idx));
}
}
return true; return true;
}; };
@ -186,7 +174,7 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node); TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
sink_forward::RemoveZeroInputNode(main_node); sink_forward::RemoveInputNode(main_node, /* input_idx */ 0);
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node); register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node); transpose_sinking::UpdateForwardSinkingAbility(new_node);

View File

@ -10,6 +10,10 @@
#include "transformations/rt_info/transpose_sinking_attr.hpp" #include "transformations/rt_info/transpose_sinking_attr.hpp"
using namespace ov; using namespace ov;
using namespace ov::opset9;
using namespace ov::pass::pattern;
using namespace ov::op::util;
using namespace transpose_sinking;
namespace { namespace {
@ -90,16 +94,11 @@ NodePair Swap(NodePtr first_node, NodePtr second_node) {
ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() { ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
MATCHER_SCOPE(TransposeSinkingUnaryForward); MATCHER_SCOPE(TransposeSinkingUnaryForward);
auto transpose_label = ov::pass::pattern::wrap_type<ov::opset9::Transpose>( auto transpose_label = wrap_type<Transpose>({any_input(), any_input()});
{ov::pass::pattern::any_input(), ov::pass::pattern::any_input()}); auto unary_label =
auto unary_label = ov::pass::pattern::wrap_type<ov::op::util::UnaryElementwiseArithmetic, wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert>({transpose_label});
ov::opset9::Clamp,
ov::opset9::Elu,
ov::opset9::SoftPlus,
ov::opset9::LogicalNot,
ov::opset9::Convert>({transpose_label});
ov::matcher_pass_callback matcher_pass_callback = [=](ov::pass::pattern::Matcher& m) { ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map(); const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr(); auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
@ -109,12 +108,12 @@ ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
register_new_node(new_nodes.first); register_new_node(new_nodes.first);
register_new_node(new_nodes.second); register_new_node(new_nodes.second);
transpose_sinking::UpdateForwardSinkingAbility(new_nodes.second); UpdateForwardSinkingAbility(new_nodes.second);
return true; return true;
}; };
auto m = std::make_shared<ov::pass::pattern::Matcher>(unary_label, "ov::pass::TransposeSinkingUnaryForward"); auto m = std::make_shared<Matcher>(unary_label, "ov::pass::TransposeSinkingUnaryForward");
register_matcher(m, matcher_pass_callback); register_matcher(m, matcher_pass_callback);
} }
@ -124,21 +123,16 @@ bool IfSinkingEnabled(const Output<Node>& output) {
} }
} // namespace } // namespace
ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() { ov::pass::TransposeSinkingUnaryBackwardSingleConsumer::TransposeSinkingUnaryBackwardSingleConsumer() {
MATCHER_SCOPE(TransposeSinkingUnaryBackward); MATCHER_SCOPE(TransposeSinkingUnaryBackwardSingleConsumer);
auto unary_label = ov::pass::pattern::wrap_type<ov::op::util::UnaryElementwiseArithmetic, auto unary_label =
ov::opset9::Clamp, wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert>({any_input()},
ov::opset9::Elu, consumers_count(1));
ov::opset9::SoftPlus,
ov::opset9::LogicalNot,
ov::opset9::Convert>({ov::pass::pattern::any_input()});
auto transpose_label = auto transpose_label = wrap_type<Transpose>({unary_label, any_input()}, IfSinkingEnabled);
ov::pass::pattern::wrap_type<ov::opset9::Transpose>({unary_label, ov::pass::pattern::any_input()},
IfSinkingEnabled);
ov::matcher_pass_callback matcher_pass_callback = [=](ov::pass::pattern::Matcher& m) { ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map(); const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr(); auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr(); auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
@ -151,6 +145,55 @@ ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() {
return true; return true;
}; };
auto m = std::make_shared<ov::pass::pattern::Matcher>(transpose_label, "ov::pass::TransposeSinkingUnaryBackward"); auto m = std::make_shared<Matcher>(transpose_label, "ov::pass::TransposeSinkingUnaryBackwardSingleConsumer");
register_matcher(m, matcher_pass_callback); register_matcher(m, matcher_pass_callback);
} }
namespace {
std::function<bool(Output<Node>)> consumers_more_than(size_t n) {
return [=](Output<Node> output) -> bool {
return output.get_target_inputs().size() > n;
};
}
} // namespace
ov::pass::TransposeSinkingUnaryBackwardMultiConsumers::TransposeSinkingUnaryBackwardMultiConsumers() {
MATCHER_SCOPE(TransposeSinkingUnaryBackwardMultiConsumers);
auto unary_restrictions = [](const Output<Node>& output) -> bool {
return consumers_more_than(1)(output) && HasSameOutputTransposeNodes(output);
};
auto unary_label =
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert>({any_input()},
unary_restrictions);
auto transpose_const_label = wrap_type<Constant>();
auto transpose_label = wrap_type<Transpose>({unary_label, transpose_const_label}, IfSinkingEnabled);
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary, transpose_const)) {
register_new_node(new_node);
}
// remove output transposes
RemoveSingleOutputConsumers(unary);
return true;
};
auto m = std::make_shared<Matcher>(transpose_label, "ov::pass::TransposeSinkingUnaryBackwardMultiConsumers");
register_matcher(m, matcher_pass_callback);
}
ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() {
MATCHER_SCOPE(TransposeSinkingUnaryBackward);
add_matcher<ov::pass::TransposeSinkingUnaryBackwardSingleConsumer>();
add_matcher<ov::pass::TransposeSinkingUnaryBackwardMultiConsumers>();
}

View File

@ -115,8 +115,7 @@ ov::Output<ov::Node> FixInputNodeRank(ov::Output<ov::Node> input_node, ov::Rank:
namespace sink_forward { namespace sink_forward {
// insert input reversed transposes, remove first input tranpose void UpdateInputTransposes(NodePtr main_node, const TransposeInputsInfo& transpose_input_info) {
void UpdateInputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_input_info) {
if (transpose_input_info.isEmpty() || HasDynamicRankInput(main_node)) if (transpose_input_info.isEmpty() || HasDynamicRankInput(main_node))
return; return;
@ -149,15 +148,15 @@ void UpdateInputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_inp
} }
} }
void RemoveZeroInputNode(NodePtr main_node) { void RemoveInputNode(NodePtr main_node, size_t input_idx) {
auto input_node = main_node->input_value(0); auto input_node = main_node->input_value(input_idx);
if (input_node.get_node()->get_input_size() < 1) if (input_node.get_node()->get_input_size() < (input_idx + 1))
return; return;
auto parent_node = input_node.get_node()->input_value(0); auto parent_node = input_node.get_node()->input_value(input_idx);
main_node->input(0).replace_source_output(parent_node); main_node->input(input_idx).replace_source_output(parent_node);
} }
NodeVector InsertOutputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_input_info) { NodeVector InsertOutputTransposes(NodePtr main_node, const TransposeInputsInfo& transpose_input_info) {
if (transpose_input_info.isEmpty()) if (transpose_input_info.isEmpty())
return {}; return {};
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val(); const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
@ -246,6 +245,7 @@ bool CanPropagateForwardThrough(Node* node) {
CHECK_TRANSPOSE_SINKING_SUPPORTED(Concat, node); CHECK_TRANSPOSE_SINKING_SUPPORTED(Concat, node);
CHECK_TRANSPOSE_SINKING_SUPPORTED(Split, node); CHECK_TRANSPOSE_SINKING_SUPPORTED(Split, node);
CHECK_TRANSPOSE_SINKING_SUPPORTED(Transpose, node); CHECK_TRANSPOSE_SINKING_SUPPORTED(Transpose, node);
CHECK_TRANSPOSE_SINKING_SUPPORTED(PRelu, node);
return false; return false;
} }
@ -268,4 +268,76 @@ void UpdateForwardSinkingAbility(NodePtr node) {
mark_as_no_sinking_node(node); mark_as_no_sinking_node(node);
} }
namespace {
std::shared_ptr<Constant> GetTransposeConstant(Node* node) {
auto transpose_node = dynamic_cast<Transpose*>(node);
if (!transpose_node)
return {};
auto constant_node = as_type_ptr<Constant>(transpose_node->input_value(1).get_node_shared_ptr());
if (!constant_node)
return {};
return constant_node;
}
Node* FindFirstConsumer(NodePtr node) {
for (size_t output_idx = 0; output_idx < node->get_output_size(); ++output_idx) {
auto inputs = node->get_output_target_inputs(output_idx);
if (inputs.empty())
continue;
return inputs.begin()->get_node();
}
return nullptr;
}
bool HasSameOutputTransposeNodes(NodePtr main_node) {
AxisVector first_transpose_axis_order;
{
Node* first_consumer = FindFirstConsumer(main_node);
if (!first_consumer)
return false;
auto constant_node = GetTransposeConstant(first_consumer);
if (!constant_node)
return false;
first_transpose_axis_order = constant_node->get_axis_vector_val();
}
for (size_t output_idx = 0; output_idx < main_node->get_output_size(); ++output_idx) {
for (auto& input : main_node->get_output_target_inputs(output_idx)) {
auto constant_node = GetTransposeConstant(input.get_node());
if (!constant_node)
return false;
AxisVector transpose_axis_order = constant_node->get_axis_vector_val();
if (transpose_axis_order.size() != first_transpose_axis_order.size())
return false;
if (!std::equal(transpose_axis_order.begin(),
transpose_axis_order.end(),
first_transpose_axis_order.begin()))
return false;
}
}
return true;
}
} // namespace
bool HasSameOutputTransposeNodes(const Output<Node>& output) {
return HasSameOutputTransposeNodes(output.get_node_shared_ptr());
}
void RemoveSingleOutputConsumers(NodePtr node) {
for (size_t output_idx = 0; output_idx < node->get_output_size(); ++output_idx) {
for (auto& input : node->get_output_target_inputs(output_idx)) {
Node* consumer = input.get_node();
if (consumer->get_output_size() != 1)
continue;
consumer->output(0).replace(node->output(output_idx));
}
}
}
} // namespace transpose_sinking } // namespace transpose_sinking

View File

@ -12,13 +12,37 @@
#include "common_test_utils/ngraph_test_utils.hpp" #include "common_test_utils/ngraph_test_utils.hpp"
#include "gtest/gtest.h" #include "gtest/gtest.h"
using namespace ov;
using namespace ov::opset9;
namespace {
std::string to_string(const Shape& shape) {
std::ostringstream result;
result << "{";
for (size_t idx = 0; idx < shape.size(); ++idx) {
if (idx)
result << ",";
result << shape[idx];
}
result << "}";
return result.str();
}
} // namespace
using NodePtr = std::shared_ptr<ov::Node>; using NodePtr = std::shared_ptr<ov::Node>;
class IUnaryFactory { class IUnaryFactory {
public: public:
IUnaryFactory() = default; IUnaryFactory(const std::string& type_name) : type_name_(type_name) {}
virtual ~IUnaryFactory() = default; virtual ~IUnaryFactory() = default;
virtual NodePtr create(NodePtr parent_node) const = 0; virtual NodePtr create(NodePtr parent_node) const = 0;
const std::string& getTypeName() const {
return type_name_;
}
private:
const std::string type_name_;
}; };
using UnaryFactoryPtr = std::shared_ptr<IUnaryFactory>; using UnaryFactoryPtr = std::shared_ptr<IUnaryFactory>;
@ -26,39 +50,45 @@ using UnaryFactoryPtr = std::shared_ptr<IUnaryFactory>;
template <typename UnaryT> template <typename UnaryT>
class UnaryFactory : public IUnaryFactory { class UnaryFactory : public IUnaryFactory {
public: public:
UnaryFactory() = default; UnaryFactory(const std::string& type_name) : IUnaryFactory(type_name) {}
NodePtr create(NodePtr parent_node) const override { NodePtr create(NodePtr parent_node) const override {
return std::make_shared<UnaryT>(parent_node); return std::make_shared<UnaryT>(parent_node);
} }
}; };
template <> template <>
NodePtr UnaryFactory<ov::opset9::Elu>::create(NodePtr parent_node) const { NodePtr UnaryFactory<Elu>::create(NodePtr parent_node) const {
return std::make_shared<ov::opset9::Elu>(parent_node, 0.1); return std::make_shared<Elu>(parent_node, 0.1);
} }
template <> template <>
NodePtr UnaryFactory<ov::opset9::Clamp>::create(NodePtr parent_node) const { NodePtr UnaryFactory<Clamp>::create(NodePtr parent_node) const {
return std::make_shared<ov::opset9::Clamp>(parent_node, 0.1, 0.2); return std::make_shared<Clamp>(parent_node, 0.1, 0.2);
} }
template <> template <>
NodePtr UnaryFactory<ov::opset9::Convert>::create(NodePtr parent_node) const { NodePtr UnaryFactory<Convert>::create(NodePtr parent_node) const {
return std::make_shared<ov::opset9::Convert>(parent_node, ov::element::f64); return std::make_shared<Convert>(parent_node, element::f64);
} }
template <typename UnaryT> template <typename UnaryT>
UnaryFactoryPtr CreateUnaryFactory() { UnaryFactoryPtr CreateUnaryFactory(const std::string& type_name) {
return std::make_shared<UnaryFactory<UnaryT>>(); return std::make_shared<UnaryFactory<UnaryT>>(type_name);
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
class IPassFactory { class IPassFactory {
public: public:
IPassFactory() = default; IPassFactory(const std::string& type_name) : type_name_(type_name) {}
virtual ~IPassFactory() = default; virtual ~IPassFactory() = default;
virtual void registerPass(ov::pass::Manager& pass_manager) const = 0; virtual void registerPass(ov::pass::Manager& pass_manager) const = 0;
const std::string& getTypeName() const {
return type_name_;
}
private:
const std::string type_name_;
}; };
using PassFactoryPtr = std::shared_ptr<IPassFactory>; using PassFactoryPtr = std::shared_ptr<IPassFactory>;
@ -66,15 +96,16 @@ using PassFactoryPtr = std::shared_ptr<IPassFactory>;
template <typename PassT> template <typename PassT>
class PassFactory : public IPassFactory { class PassFactory : public IPassFactory {
public: public:
PassFactory(const std::string& type_name) : IPassFactory(type_name) {}
void registerPass(ov::pass::Manager& pass_manager) const override { void registerPass(ov::pass::Manager& pass_manager) const override {
pass_manager.register_pass<PassT>(); pass_manager.register_pass<PassT>();
} }
}; };
template <typename PassT> #define CREATE_PASS_FACTORY(pass_name) std::make_shared<PassFactory<ov::pass::pass_name>>(#pass_name)
PassFactoryPtr CreatePassFactory() {
return std::make_shared<PassFactory<PassT>>(); #undef CREATE_UNARY_FACTORY
} #define CREATE_UNARY_FACTORY(type_name) CreateUnaryFactory<type_name>(#type_name)
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
@ -82,19 +113,46 @@ using FloatPtr = std::unique_ptr<float[]>;
using CreateGraphF = std::function<std::shared_ptr<ov::Model>(UnaryFactoryPtr unary_factory, using CreateGraphF = std::function<std::shared_ptr<ov::Model>(UnaryFactoryPtr unary_factory,
size_t num_unary_ops, size_t num_unary_ops,
const ov::Shape& input_shape, const Shape& input_shape,
ov::element::Type input_type)>; element::Type input_type)>;
using TestParams = std::tuple<UnaryFactoryPtr, using TestParams = std::tuple<UnaryFactoryPtr,
PassFactoryPtr, PassFactoryPtr,
size_t, /* num_unary_ops */ size_t, /* num_unary_ops */
CreateGraphF, /* model_factory */ CreateGraphF, /* model_factory */
CreateGraphF, /* reference_model_factory */ CreateGraphF, /* reference_model_factory */
ov::Shape, /* input shape */ Shape, /* input shape */
ov::element::Type>; /* input type */ element::Type>; /* input type */
class TransposeSinkingUnaryTestFixture : public ::testing::WithParamInterface<TestParams>, class TransposeSinkingUnaryTestFixture : public ::testing::WithParamInterface<TestParams>, public TransformationTestsF {
public TransformationTestsF {}; public:
static std::string get_test_name(const testing::TestParamInfo<TestParams>& obj) {
UnaryFactoryPtr unary_factory;
PassFactoryPtr pass_factory;
size_t num_unary_ops;
CreateGraphF model_factory;
CreateGraphF reference_model_factory;
Shape input_shape;
element::Type input_type;
std::tie(unary_factory,
pass_factory,
num_unary_ops,
model_factory,
reference_model_factory,
input_shape,
input_type) = obj.param;
std::ostringstream test_name;
test_name << "unaryFactory=" << unary_factory->getTypeName() << "/";
test_name << "numUnaryOps=" << num_unary_ops << "/";
test_name << "inputShape=" << to_string(input_shape) << "/";
test_name << "unaryFactory=" << unary_factory->getTypeName() << "/";
test_name << "passFactory=" << pass_factory->getTypeName() << "/";
test_name << "inputType=" << input_type;
return test_name.str();
}
};
namespace { namespace {
@ -105,12 +163,12 @@ std::string GetFinalNodeName(std::shared_ptr<ov::Model> model, int index = 0) {
std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory, std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory,
size_t num_unary_ops, size_t num_unary_ops,
const ov::Shape& input_shape, const Shape& input_shape,
ov::element::Type input_type) { element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape); auto X = std::make_shared<Parameter>(input_type, input_shape);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0); auto transpose0 = std::make_shared<Transpose>(X, ng_order0);
NodePtr in_op = transpose0; NodePtr in_op = transpose0;
for (size_t i = 0; i < num_unary_ops; ++i) { for (size_t i = 0; i < num_unary_ops; ++i) {
@ -122,25 +180,25 @@ std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_f
std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory, std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory,
size_t num_unary_ops, size_t num_unary_ops,
const ov::Shape& input_shape, const Shape& input_shape,
ov::element::Type input_type) { element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape); auto X = std::make_shared<Parameter>(input_type, input_shape);
NodePtr in_op = X; NodePtr in_op = X;
for (size_t i = 0; i < num_unary_ops; ++i) { for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = unary_factory->create(in_op); in_op = unary_factory->create(in_op);
} }
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0); auto transpose0 = std::make_shared<Transpose>(in_op, ng_order0);
return std::make_shared<ov::Model>(transpose0, ov::ParameterVector{X}); return std::make_shared<ov::Model>(transpose0, ov::ParameterVector{X});
} }
static NodePtr CreateReshape(NodePtr parent_node, const ov::Shape& input_shape) { static NodePtr CreateReshape(NodePtr parent_node, const Shape& input_shape) {
const size_t mul = std::accumulate(input_shape.begin(), input_shape.end(), (size_t)1, std::multiplies<size_t>()); const size_t mul = std::accumulate(input_shape.begin(), input_shape.end(), (size_t)1, std::multiplies<size_t>());
auto reshape_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{1}, ov::Shape{mul}); auto reshape_const = std::make_shared<Constant>(element::u64, Shape{1}, Shape{mul});
return std::make_shared<ov::opset9::Reshape>(parent_node, reshape_const, false); return std::make_shared<Reshape>(parent_node, reshape_const, false);
} }
namespace mult_consumers_last_node { namespace mult_consumers_last_node {
@ -148,17 +206,17 @@ namespace with_reshape {
std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory, std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory,
size_t num_unary_ops, size_t num_unary_ops,
const ov::Shape& input_shape, const Shape& input_shape,
ov::element::Type input_type) { element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape); auto X = std::make_shared<Parameter>(input_type, input_shape);
NodePtr in_op = X; NodePtr in_op = X;
for (size_t i = 0; i < num_unary_ops; ++i) { for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = unary_factory->create(in_op); in_op = unary_factory->create(in_op);
} }
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0); auto transpose0 = std::make_shared<Transpose>(in_op, ng_order0);
auto reshape1 = CreateReshape(transpose0, input_shape); auto reshape1 = CreateReshape(transpose0, input_shape);
auto reshape2 = CreateReshape(transpose0, input_shape); auto reshape2 = CreateReshape(transpose0, input_shape);
@ -168,12 +226,12 @@ std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_fa
std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory, std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory,
size_t num_unary_ops, size_t num_unary_ops,
const ov::Shape& input_shape, const Shape& input_shape,
ov::element::Type input_type) { element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape); auto X = std::make_shared<Parameter>(input_type, input_shape);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0); auto transpose0 = std::make_shared<Transpose>(X, ng_order0);
NodePtr in_op = transpose0; NodePtr in_op = transpose0;
for (size_t i = 0; i < num_unary_ops; ++i) { for (size_t i = 0; i < num_unary_ops; ++i) {
@ -191,44 +249,44 @@ namespace with_eltwise {
std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory, std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory,
size_t num_unary_ops, size_t num_unary_ops,
const ov::Shape& input_shape, const Shape& input_shape,
ov::element::Type input_type) { element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape); auto X = std::make_shared<Parameter>(input_type, input_shape);
NodePtr in_op = X; NodePtr in_op = X;
for (size_t i = 0; i < num_unary_ops; ++i) { for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = unary_factory->create(in_op); in_op = unary_factory->create(in_op);
} }
auto sinh = std::make_shared<ov::opset9::Sinh>(in_op); auto sinh = std::make_shared<Sinh>(in_op);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(sinh, ng_order0); auto transpose0 = std::make_shared<Transpose>(sinh, ng_order0);
auto cosh = std::make_shared<ov::opset9::Cosh>(in_op); auto cosh = std::make_shared<Cosh>(in_op);
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); auto ng_order1 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose1 = std::make_shared<ov::opset9::Transpose>(cosh, ng_order1); auto transpose1 = std::make_shared<Transpose>(cosh, ng_order1);
return std::make_shared<ov::Model>(ov::OutputVector{transpose0, transpose1}, ov::ParameterVector{X}); return std::make_shared<ov::Model>(ov::OutputVector{transpose0, transpose1}, ov::ParameterVector{X});
} }
std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory, std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory,
size_t num_unary_ops, size_t num_unary_ops,
const ov::Shape& input_shape, const Shape& input_shape,
ov::element::Type input_type) { element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape); auto X = std::make_shared<Parameter>(input_type, input_shape);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0); auto transpose0 = std::make_shared<Transpose>(X, ng_order0);
NodePtr in_op = transpose0; NodePtr in_op = transpose0;
for (size_t i = 0; i < num_unary_ops; ++i) { for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = unary_factory->create(in_op); in_op = unary_factory->create(in_op);
} }
auto sinh = std::make_shared<ov::opset9::Sinh>(in_op); auto sinh = std::make_shared<Sinh>(in_op);
auto cosh = std::make_shared<ov::opset9::Cosh>(in_op); auto cosh = std::make_shared<Cosh>(in_op);
return std::make_shared<ov::Model>(ov::OutputVector{sinh, cosh}, ov::ParameterVector{X}); return std::make_shared<ov::Model>(ov::OutputVector{sinh, cosh}, ov::ParameterVector{X});
} }
@ -241,66 +299,87 @@ namespace backward {
std::shared_ptr<ov::Model> CreateFunction(UnaryFactoryPtr unary_factory, std::shared_ptr<ov::Model> CreateFunction(UnaryFactoryPtr unary_factory,
size_t num_unary_ops, size_t num_unary_ops,
const ov::Shape& input_shape, const Shape& input_shape,
ov::element::Type input_type) { element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape); auto X = std::make_shared<Parameter>(input_type, input_shape);
ov::OutputVector outputs; ov::OutputVector outputs;
NodePtr in_op = X; NodePtr in_op = X;
for (size_t i = 0; i < num_unary_ops; ++i) { for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = unary_factory->create(in_op); in_op = unary_factory->create(in_op);
auto cosh = std::make_shared<ov::opset9::Cosh>(in_op); auto cosh = std::make_shared<Cosh>(in_op);
outputs.push_back(cosh); outputs.push_back(cosh);
} }
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0); auto transpose0 = std::make_shared<Transpose>(in_op, ng_order0);
outputs.push_back(transpose0); outputs.push_back(transpose0);
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X}); return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
} }
std::shared_ptr<ov::Model> CreateReferenceFunction(UnaryFactoryPtr unary_factory,
size_t num_unary_ops,
const ov::Shape& input_shape,
ov::element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
ov::OutputVector outputs;
NodePtr in_op = X;
for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = unary_factory->create(in_op);
auto cosh = std::make_shared<ov::opset9::Cosh>(in_op);
outputs.push_back(cosh);
}
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
in_op = transpose0;
for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = unary_factory->create(in_op);
}
outputs.push_back(in_op);
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
}
} // namespace backward } // namespace backward
namespace backward_mult_transposes {
std::shared_ptr<ov::Model> CreateFunction(UnaryFactoryPtr unary_factory,
size_t num_unary_ops,
const Shape& input_shape,
element::Type input_type) {
auto X = std::make_shared<Parameter>(input_type, input_shape);
NodePtr in_op = X;
for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = unary_factory->create(in_op);
}
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(in_op, ng_order0);
auto tanh0 = std::make_shared<Tanh>(transpose0);
auto ng_order1 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose1 = std::make_shared<Transpose>(in_op, ng_order1);
auto tanh1 = std::make_shared<Tanh>(transpose1);
return std::make_shared<ov::Model>(ov::OutputVector{tanh0, tanh1}, ov::ParameterVector{X});
}
std::shared_ptr<ov::Model> CreateReferenceFunction(UnaryFactoryPtr unary_factory,
size_t num_unary_ops,
const Shape& input_shape,
element::Type input_type) {
auto X = std::make_shared<Parameter>(input_type, input_shape);
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(X, ng_order0);
NodePtr in_op = transpose0;
for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = unary_factory->create(in_op);
}
auto tanh0 = std::make_shared<Tanh>(in_op);
auto tanh1 = std::make_shared<Tanh>(in_op);
return std::make_shared<ov::Model>(ov::OutputVector{tanh0, tanh1}, ov::ParameterVector{X});
}
} // namespace backward_mult_transposes
namespace forward { namespace forward {
std::shared_ptr<ov::Model> CreateFunction(UnaryFactoryPtr unary_factory, std::shared_ptr<ov::Model> CreateFunction(UnaryFactoryPtr unary_factory,
size_t num_unary_ops, size_t num_unary_ops,
const ov::Shape& input_shape, const Shape& input_shape,
ov::element::Type input_type) { element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape); auto X = std::make_shared<Parameter>(input_type, input_shape);
auto sinh = std::make_shared<ov::opset9::Sinh>(X); auto sinh = std::make_shared<Sinh>(X);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(sinh, ng_order0); auto transpose0 = std::make_shared<Transpose>(sinh, ng_order0);
auto reshape = CreateReshape(transpose0, input_shape); auto reshape = CreateReshape(transpose0, input_shape);
@ -314,14 +393,14 @@ std::shared_ptr<ov::Model> CreateFunction(UnaryFactoryPtr unary_factory,
std::shared_ptr<ov::Model> CreateReferenceFunction(UnaryFactoryPtr unary_factory, std::shared_ptr<ov::Model> CreateReferenceFunction(UnaryFactoryPtr unary_factory,
size_t num_unary_ops, size_t num_unary_ops,
const ov::Shape& input_shape, const Shape& input_shape,
ov::element::Type input_type) { element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape); auto X = std::make_shared<Parameter>(input_type, input_shape);
auto sinh = std::make_shared<ov::opset9::Sinh>(X); auto sinh = std::make_shared<Sinh>(X);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(sinh, ng_order0); auto transpose0 = std::make_shared<Transpose>(sinh, ng_order0);
auto reshape = CreateReshape(transpose0, input_shape); auto reshape = CreateReshape(transpose0, input_shape);
NodePtr in_op = sinh; NodePtr in_op = sinh;
@ -329,8 +408,8 @@ std::shared_ptr<ov::Model> CreateReferenceFunction(UnaryFactoryPtr unary_factory
in_op = unary_factory->create(in_op); in_op = unary_factory->create(in_op);
} }
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1}); auto ng_order1 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order1); auto transpose1 = std::make_shared<Transpose>(in_op, ng_order1);
return std::make_shared<ov::Model>(ov::OutputVector{transpose1, reshape}, ov::ParameterVector{X}); return std::make_shared<ov::Model>(ov::OutputVector{transpose1, reshape}, ov::ParameterVector{X});
} }
@ -339,21 +418,16 @@ std::shared_ptr<ov::Model> CreateReferenceFunction(UnaryFactoryPtr unary_factory
} // namespace mult_consumers_first_node } // namespace mult_consumers_first_node
std::vector<UnaryFactoryPtr> unary_factories = { std::vector<UnaryFactoryPtr> unary_factories = {
CreateUnaryFactory<ov::opset9::Clamp>(), CreateUnaryFactory<ov::opset9::Elu>(), CREATE_UNARY_FACTORY(Clamp), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(SoftPlus),
CreateUnaryFactory<ov::opset9::SoftPlus>(), CreateUnaryFactory<ov::opset9::LogicalNot>(), CREATE_UNARY_FACTORY(LogicalNot), CREATE_UNARY_FACTORY(Convert), CREATE_UNARY_FACTORY(Abs),
CreateUnaryFactory<ov::opset9::Convert>(), CreateUnaryFactory<ov::opset9::Abs>(), CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh),
CreateUnaryFactory<ov::opset9::Acos>(), CreateUnaryFactory<ov::opset9::Asin>(), CREATE_UNARY_FACTORY(Atan), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Cos),
CreateUnaryFactory<ov::opset9::Asinh>(), CreateUnaryFactory<ov::opset9::Atan>(), CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Exp),
CreateUnaryFactory<ov::opset9::Ceiling>(), CreateUnaryFactory<ov::opset9::Cos>(), CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid), CREATE_UNARY_FACTORY(HSwish),
CreateUnaryFactory<ov::opset9::Cosh>(), CreateUnaryFactory<ov::opset9::Erf>(), CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu),
CreateUnaryFactory<ov::opset9::Exp>(), CreateUnaryFactory<ov::opset9::Gelu>(), CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin),
CreateUnaryFactory<ov::opset9::HSigmoid>(), CreateUnaryFactory<ov::opset9::HSwish>(), CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftSign), CREATE_UNARY_FACTORY(Sqrt),
CreateUnaryFactory<ov::opset9::Log>(), CreateUnaryFactory<ov::opset9::Negative>(), CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh)};
CreateUnaryFactory<ov::opset9::Relu>(), CreateUnaryFactory<ov::opset9::Sigmoid>(),
CreateUnaryFactory<ov::opset9::Sign>(), CreateUnaryFactory<ov::opset9::Sin>(),
CreateUnaryFactory<ov::opset9::Sinh>(), CreateUnaryFactory<ov::opset9::SoftSign>(),
CreateUnaryFactory<ov::opset9::Sqrt>(), CreateUnaryFactory<ov::opset9::Tan>(),
CreateUnaryFactory<ov::opset9::Tanh>()};
std::vector<size_t> unary_operations_numbers = {1, 10}; std::vector<size_t> unary_operations_numbers = {1, 10};
@ -365,8 +439,8 @@ TEST_P(TransposeSinkingUnaryTestFixture, CompareFunctions) {
size_t num_unary_ops; size_t num_unary_ops;
CreateGraphF model_factory; CreateGraphF model_factory;
CreateGraphF reference_model_factory; CreateGraphF reference_model_factory;
ov::Shape input_shape; Shape input_shape;
ov::element::Type input_type; element::Type input_type;
std::tie(unary_factory, std::tie(unary_factory,
pass_factory, pass_factory,
num_unary_ops, num_unary_ops,
@ -380,90 +454,95 @@ TEST_P(TransposeSinkingUnaryTestFixture, CompareFunctions) {
pass_factory->registerPass(manager); pass_factory->registerPass(manager);
} }
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardTestSuite,
TransposeSinkingUnaryForwardTestSuite,
TransposeSinkingUnaryTestFixture, TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories), ::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryForward>()), ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
::testing::ValuesIn(unary_operations_numbers), ::testing::ValuesIn(unary_operations_numbers),
::testing::Values(CreateFunctionTransposeBefore), ::testing::Values(CreateFunctionTransposeBefore),
::testing::Values(CreateFunctionTransposeAfter), ::testing::Values(CreateFunctionTransposeAfter),
::testing::Values(ov::Shape{1, 96, 55, 55}), ::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(ov::element::f32))); ::testing::Values(element::f32)),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardTestSuite,
TransposeSinkingUnaryBackwardTestSuite,
TransposeSinkingUnaryTestFixture, TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories), ::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryBackward>()), ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
::testing::ValuesIn(unary_operations_numbers), ::testing::ValuesIn(unary_operations_numbers),
::testing::Values(CreateFunctionTransposeAfter), ::testing::Values(CreateFunctionTransposeAfter),
::testing::Values(CreateFunctionTransposeBefore), ::testing::Values(CreateFunctionTransposeBefore),
::testing::Values(ov::Shape{1, 96, 55, 55}), ::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(ov::element::f32))); ::testing::Values(element::f32)),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeReshape, TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeReshape,
TransposeSinkingUnaryTestFixture, TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories), ::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryForward>()), ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
::testing::ValuesIn(unary_operations_numbers), ::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore), ::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore),
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter), ::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter),
::testing::Values(ov::Shape{1, 96, 55, 55}), ::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(ov::element::f32))); ::testing::Values(element::f32)),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryBackwardMultConsumersTestSuiteLastNodeReshape, TransposeSinkingUnaryBackwardMultConsumersTestSuiteLastNodeReshape,
TransposeSinkingUnaryTestFixture, TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories), ::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryBackward>()), ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
::testing::ValuesIn(unary_operations_numbers), ::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter), ::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter),
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore), ::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore),
::testing::Values(ov::Shape{1, 96, 55, 55}), ::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(ov::element::f32))); ::testing::Values(element::f32)),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeEltwise, TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeEltwise,
TransposeSinkingUnaryTestFixture, TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories), ::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryForward>()), ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
::testing::ValuesIn(unary_operations_numbers), ::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore), ::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore),
::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter), ::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter),
::testing::Values(ov::Shape{1, 96, 55, 55}), ::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(ov::element::f32))); ::testing::Values(element::f32)),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryBackwardMultConsumersTestSuiteLastNodeEltwise,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryBackward>()),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter),
::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore),
::testing::Values(ov::Shape{1, 96, 55, 55}),
::testing::Values(ov::element::f32)));
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryForwardMultConsumersTestSuiteFirstNode, TransposeSinkingUnaryForwardMultConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture, TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories), ::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryForward>()), ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
::testing::ValuesIn(unary_operations_numbers), ::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_first_node::forward::CreateFunction), ::testing::Values(mult_consumers_first_node::forward::CreateFunction),
::testing::Values(mult_consumers_first_node::forward::CreateReferenceFunction), ::testing::Values(mult_consumers_first_node::forward::CreateReferenceFunction),
::testing::Values(ov::Shape{1, 96, 55, 55}), ::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(ov::element::f32))); ::testing::Values(element::f32)),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultConsumersTestSuiteFirstNode,
TransposeSinkingUnaryBackwardMultConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture, TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories), ::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryBackward>()), ::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
::testing::ValuesIn(unary_operations_numbers), ::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_first_node::backward::CreateFunction), ::testing::Values(mult_consumers_first_node::backward::CreateFunction),
::testing::Values(mult_consumers_first_node::backward::CreateReferenceFunction), ::testing::Values(mult_consumers_first_node::backward::CreateFunction),
::testing::Values(ov::Shape{1, 96, 55, 55}), ::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(ov::element::f32))); ::testing::Values(element::f32)),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryBackwardMultTransposeConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_first_node::backward_mult_transposes::CreateFunction),
::testing::Values(mult_consumers_first_node::backward_mult_transposes::CreateReferenceFunction),
::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(element::f32)),
TransposeSinkingUnaryTestFixture::get_test_name);