Transpose sinking binary/split/concat mult consumers case + PRelu (#14670)

* binary: move test classes into namespaces

* add unsing namespacees ov and ov::opset9

* ov::Model -> Model

* use naming for unit tests

* fix unary -> binary

* fix identation

* add tests binary forward consumers > 1

* add binary output consumers > 1 unit tests

* add binary input consumers > 1 unit tests

* binary backward add unit test binary has multiple consumers + fix

* move code to common function CloneNodeWithoutConsumers

* add unit test on backward binary propagation multi consumers

* reorganize binary unit tests

* add backward_input_node_consumers test

* add test output transpose mult consumers

* cleanup

* cleanup unit tests for split

* add forward::mult_consumers::input_node_consumers tests

* add forward::mult_consumers::input_transpose_consumers tests

* add forward::mult_consumers::output_consumers tests

* code indentation fix

* fix unit test split backward input node consumers

* added mult_output_consumers split unit test

* add mult_split_consumers split unit test

* cancat tests using namespaces

* concat add unit tests

* move repeated code into functions

* clang format fixes

* rename TransposeSinkingBinaryElementwise -> TransposeSinkingBinary

* add PRelu

* clang cleanup

* fix prelu tests

* fix using

* clang fixes

* fix NodeVector argument

* fix function descriptions

* fix const ref arg functions

* clang fixes

* prevent backward sinking if not all output nodes are transposes

* fix RemoveSingleOutputConsumers checking output size

* cleanup

* clang cleanup

* disable unary backward with not same transposes

* clang fixes

* remove unneeded functions CloneNodeWithoutConsumers GetNodeIds

* remove unneeded GetTransposeConstant(Input<Node> input) as duplicate of GetTransposeConstant(Node* node)

* cleanup

* add output_transpose_mult_transposes test

* add backward::output_transpose_mult_transposes test

* add unit tests for unary backward multiple transposes; fix problems by adding new transformation

* fix bug TransposeSinkingUnaryBackwardMultiConsumers consumers_more_than
This commit is contained in:
Evgeny Kotov 2023-01-12 15:42:46 +01:00 committed by GitHub
parent 9388560aec
commit 78995e9ac2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 2953 additions and 1008 deletions

View File

@ -11,20 +11,20 @@
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingBinaryElementwiseForward;
class TRANSFORMATIONS_API TransposeSinkingBinaryElementwiseBackward;
class TRANSFORMATIONS_API TransposeSinkingBinaryForward;
class TRANSFORMATIONS_API TransposeSinkingBinaryBackward;
} // namespace pass
} // namespace ov
class ov::pass::TransposeSinkingBinaryElementwiseForward : public ov::pass::MatcherPass {
class ov::pass::TransposeSinkingBinaryForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryElementwiseForward", "0");
TransposeSinkingBinaryElementwiseForward();
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryForward", "0");
TransposeSinkingBinaryForward();
};
class ov::pass::TransposeSinkingBinaryElementwiseBackward : public ov::pass::MatcherPass {
class ov::pass::TransposeSinkingBinaryBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryElementwiseBackward", "0");
TransposeSinkingBinaryElementwiseBackward();
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryBackward", "0");
TransposeSinkingBinaryBackward();
};

View File

@ -11,6 +11,8 @@ namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingUnaryForward;
class TRANSFORMATIONS_API TransposeSinkingUnaryBackwardSingleConsumer;
class TRANSFORMATIONS_API TransposeSinkingUnaryBackwardMultiConsumers;
class TRANSFORMATIONS_API TransposeSinkingUnaryBackward;
} // namespace pass
@ -22,7 +24,19 @@ public:
TransposeSinkingUnaryForward();
};
class ov::pass::TransposeSinkingUnaryBackward : public ov::pass::MatcherPass {
class ov::pass::TransposeSinkingUnaryBackwardSingleConsumer : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TransposeSinkingUnaryBackwardSingleConsumer", "0");
TransposeSinkingUnaryBackwardSingleConsumer();
};
class ov::pass::TransposeSinkingUnaryBackwardMultiConsumers : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TransposeSinkingUnaryBackwardMultiConsumers", "0");
TransposeSinkingUnaryBackwardMultiConsumers();
};
class ov::pass::TransposeSinkingUnaryBackward : public ov::pass::GraphRewrite {
public:
OPENVINO_RTTI("TransposeSinkingUnaryBackward", "0");
TransposeSinkingUnaryBackward();

View File

@ -28,25 +28,77 @@ struct TransposeInputsInfo {
}
};
TransposeInputsInfo GetFirstTransposeInput(std::shared_ptr<ov::Node> node);
bool IfNodeHasTransposeInputs(const ov::Output<ov::Node>& output);
ov::AxisVector ReverseTransposeOrder(const ov::AxisVector& axis_order);
void SwapOutputNames(ov::Output<ov::Node> output1, ov::Output<ov::Node> output2);
void SwapFriendlyNames(std::shared_ptr<ov::Node> node1, std::shared_ptr<ov::Node> node2);
void SwapNames(std::shared_ptr<ov::Node> node1, std::shared_ptr<ov::Node> node2);
/**
* @brief Finds node first input that is a transpose operation and returns filled TransposeInputsInfo
* for it
*/
TransposeInputsInfo GetFirstTransposeInput(std::shared_ptr<ov::Node>);
/**
* @brief Checks if @arg has any input node that is a transpose operation
*/
bool IfNodeHasTransposeInputs(const ov::Output<ov::Node>&);
/**
* @brief Reverses order of transpose operation. Do it in a such way that if we had couple following one after
* another transposes (one would be reversed version of another) we will have no transpose as a result of that
* couple of transposes.
*/
ov::AxisVector ReverseTransposeOrder(const ov::AxisVector&);
/**
* @brief Swaps @args output tensor names
*/
void SwapOutputNames(ov::Output<ov::Node>, ov::Output<ov::Node>);
/**
* @brief Swaps @args friendly names
*/
void SwapFriendlyNames(std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>);
/**
* @brief Swaps @args output tensor names and friendly names
*/
void SwapNames(std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>);
namespace sink_forward {
// insert input reversed transposes, remove first input tranpose
void UpdateInputTransposes(std::shared_ptr<ov::Node> main_node, TransposeInputsInfo& transpose_input_info);
void RemoveZeroInputNode(std::shared_ptr<ov::Node> main_node);
ov::NodeVector InsertOutputTransposes(std::shared_ptr<ov::Node> main_node, TransposeInputsInfo& transpose_input_info);
/**
* @brief Inserts reversed transposed on @args main_node inputs. Removes input transpose specified in @arg
* transpose_input_info
*/
void UpdateInputTransposes(std::shared_ptr<ov::Node> main_node, const TransposeInputsInfo& transpose_input_info);
/**
* @brief Removes @arg input node
*/
void RemoveInputNode(std::shared_ptr<ov::Node>, size_t input_idx);
/**
* @brief Inserts transposes on each main_node output with the order specified in @arg transpose_input_info
*/
ov::NodeVector InsertOutputTransposes(std::shared_ptr<ov::Node> main_node,
const TransposeInputsInfo& transpose_input_info);
} // namespace sink_forward
namespace sink_backward {
/**
* @brief Inserts transposes on each input of @arg main_node with the order specified in @arg transpose_const
*/
ov::NodeVector InsertTransposeBeforeNode(std::shared_ptr<ov::Node> main_node,
std::shared_ptr<ov::opset9::Constant> transpose_const);
} // namespace sink_backward
void UpdateForwardSinkingAbility(std::shared_ptr<ov::Node>);
/**
* @brief Checks if @arg has consumers that all are the same transpose operation. If no consumers at all
* returns false.
*/
bool HasSameOutputTransposeNodes(const ov::Output<ov::Node>&);
/**
* Removes all direct node consumers that have one output
*/
void RemoveSingleOutputConsumers(std::shared_ptr<ov::Node>);
} // namespace transpose_sinking

View File

@ -20,10 +20,10 @@ using namespace ov;
using namespace ov::opset9;
using namespace transpose_sinking;
ov::pass::TransposeSinkingBinaryElementwiseForward::TransposeSinkingBinaryElementwiseForward() {
MATCHER_SCOPE(TransposeSinkingBinaryElementwiseForward);
ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
MATCHER_SCOPE(TransposeSinkingBinaryForward);
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic>(IfNodeHasTransposeInputs);
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic, PRelu>(IfNodeHasTransposeInputs);
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
@ -46,18 +46,19 @@ ov::pass::TransposeSinkingBinaryElementwiseForward::TransposeSinkingBinaryElemen
register_matcher(m, matcher_pass_callback);
}
ov::pass::TransposeSinkingBinaryElementwiseBackward::TransposeSinkingBinaryElementwiseBackward() {
MATCHER_SCOPE(TransposeSinkingBinaryElementwiseBackward);
ov::pass::TransposeSinkingBinaryBackward::TransposeSinkingBinaryBackward() {
MATCHER_SCOPE(TransposeSinkingBinaryBackward);
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic>([](const Output<Node>& output) -> bool {
return consumers_count(1)(output) && has_static_rank()(output);
});
auto main_node_label =
wrap_type<op::util::BinaryElementwiseArithmetic, PRelu>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
});
auto transpose_const_label = wrap_type<Constant>(consumers_count(1));
auto transpose_const_label = wrap_type<Constant>();
auto transpose_label =
wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
return consumers_count(1)(output) && has_static_rank()(output) && is_sinking_node(output);
return has_static_rank()(output) && is_sinking_node(output);
});
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
@ -70,8 +71,8 @@ ov::pass::TransposeSinkingBinaryElementwiseBackward::TransposeSinkingBinaryEleme
register_new_node(new_node);
}
// remove transpose after main node
transpose->output(0).replace(main_node);
// remove output transposes
RemoveSingleOutputConsumers(main_node);
SwapNames(transpose, main_node);

View File

@ -55,14 +55,14 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
MATCHER_SCOPE(TransposeSinkingConcatBackward);
auto main_node_label = wrap_type<Concat>([](const Output<Node>& output) -> bool {
return consumers_count(1)(output) && has_static_rank()(output);
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
});
auto transpose_const_label = wrap_type<Constant>(consumers_count(1));
auto transpose_const_label = wrap_type<Constant>();
auto transpose_label =
wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
return consumers_count(1)(output) && has_static_rank()(output) && is_sinking_node(output);
return has_static_rank()(output) && is_sinking_node(output);
});
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
@ -75,8 +75,8 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
register_new_node(new_node);
}
// remove transpose after main node
transpose->output(0).replace(main_node);
// remove output transposes
RemoveSingleOutputConsumers(main_node);
SwapNames(transpose, main_node);

View File

@ -21,7 +21,7 @@
ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {
MATCHER_SCOPE(TransposeSinkingGeneralForward);
add_matcher<ov::pass::TransposeSinkingUnaryForward>();
add_matcher<ov::pass::TransposeSinkingBinaryElementwiseForward>();
add_matcher<ov::pass::TransposeSinkingBinaryForward>();
add_matcher<ov::pass::TransposeSinkingConcatForward>();
add_matcher<ov::pass::TransposeSinkingSplitForward>();
add_matcher<ngraph::pass::TransposeFuse>();
@ -30,7 +30,7 @@ ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {
ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() {
MATCHER_SCOPE(TransposeSinkingGeneralBackward);
add_matcher<ov::pass::TransposeSinkingUnaryBackward>();
add_matcher<ov::pass::TransposeSinkingBinaryElementwiseBackward>();
add_matcher<ov::pass::TransposeSinkingBinaryBackward>();
add_matcher<ov::pass::TransposeSinkingConcatBackward>();
add_matcher<ov::pass::TransposeSinkingSplitBackward>();
add_matcher<ngraph::pass::TransposeFuse>();

View File

@ -52,78 +52,69 @@ OutputTranspose GetOutputTransposes(NodePtr node) {
return OutputTranspose();
}
NodePtr FindSplitInput(Node* node) {
template <typename NodeT>
std::shared_ptr<ov::Node> FindInputNode(ov::Node* node) {
for (size_t input_idx = 0; input_idx < node->get_input_size(); ++input_idx) {
NodePtr input_node = node->get_input_node_shared_ptr(input_idx);
auto split_node = as_type_ptr<Split>(input_node);
if (split_node)
return split_node;
std::shared_ptr<ov::Node> input_node = node->get_input_node_shared_ptr(input_idx);
auto target_node = ov::as_type_ptr<NodeT>(input_node);
if (target_node)
return target_node;
}
return {};
}
std::shared_ptr<Constant> GetTransposeConstant(Input<Node> input) {
auto transpose_node = dynamic_cast<Transpose*>(input.get_node());
if (!transpose_node)
return {};
if (!is_sinking_node(input.get_node()))
return {};
auto constant_node = as_type_ptr<Constant>(transpose_node->input_value(1).get_node_shared_ptr());
if (!constant_node)
return {};
return constant_node;
}
bool HasInputSplitAndTransposeSiblings(const Output<Node>& output) {
NodePtr split_node = FindSplitInput(output.get_node());
if (!split_node) {
NodePtr main_node = FindInputNode<Split>(output.get_node());
if (!main_node) {
return false;
}
AxisVector first_transpose_axis_order;
// get first transpose axis
{
auto constant_node = GetTransposeConstant(*(split_node->get_output_target_inputs(0).begin()));
if (!constant_node)
return false;
first_transpose_axis_order = constant_node->get_axis_vector_val();
}
for (size_t output_idx = 1; output_idx < split_node->get_output_size(); ++output_idx) {
for (auto& input : split_node->get_output_target_inputs(output_idx)) {
auto constant_node = GetTransposeConstant(input);
if (!constant_node)
return false;
AxisVector transpose_axis_order = constant_node->get_axis_vector_val();
if (transpose_axis_order.size() != first_transpose_axis_order.size())
return false;
if (!std::equal(transpose_axis_order.begin(),
transpose_axis_order.end(),
first_transpose_axis_order.begin()))
return false;
}
}
return true;
return HasSameOutputTransposeNodes(main_node);
}
bool IsSplitSinked(const Output<Node>& output) {
return HasInputSplitAndTransposeSiblings(output) && is_sinking_node(output);
}
} // namespace
/*
* We follow Transpose operations rather than Split. We cannot create matcher pattern
* for Split with Transpose outputs since Split can have different number of outputs.
* We just can:
* - specify Split as searched node and check if it has transpose outputs
* - specify Transpose as searched node and check if it has Split input
* Transformations are called on each found node in sorted order from the start to end
* of the network. When we proceed Split backward sinking we move input transpose
* to the input of the Split operation.
* Consider case Split (1) -> Split (2) -> Transpose
* If specify Split as main searched node after first transformation work we will have
* Split (1) -> Transpose -> Split(2)
* Matcher pass will not call TransposeSinkingSplitBackward since
* - matcher pattern has no Transpose label
* - Split (1) has already been proceeded
* Adding Split(2) into the working queue as register_new_node(split)
* cannot help us. We just can try to find all input Split operations and add them with
* register_new_node(). Implemented way is simpler.
*
* We sink Transpose through Split operation in a backward way only if all the output
* nodes are the same Transpose. We can:
* - clone Split with all outputs except Transpose
* causes perfomance problems
* - add reversed Transpose operations on all outputs except sinking Transpose
* nothing to do with new added output Transposes
*/
ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
MATCHER_SCOPE(TransposeSinkingSplitBackward);
auto transpose_const_label = wrap_type<Constant>(consumers_count(1));
auto transpose_label =
wrap_type<Transpose>({any_input(), transpose_const_label}, HasInputSplitAndTransposeSiblings);
auto transpose_const_label = wrap_type<Constant>();
auto transpose_label = wrap_type<Transpose>({any_input(), transpose_const_label}, IsSplitSinked);
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose_label_node = pattern_to_output.at(transpose_label).get_node();
NodePtr split = FindSplitInput(transpose_label_node);
NodePtr split = FindInputNode<Split>(transpose_label_node);
auto split_axis_constant = as_type_ptr<Constant>(split->input_value(1).get_node_shared_ptr());
OutputTranspose output_transpose = GetOutputTransposes(split);
@ -161,11 +152,8 @@ ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
new_split_axis_const);
// remove split output transposes
for (size_t output_idx = 0; output_idx < split->get_output_size(); ++output_idx) {
for (auto& input : split->get_output_target_inputs(output_idx)) {
input.get_node()->output(0).replace(split->output(output_idx));
}
}
RemoveSingleOutputConsumers(split);
return true;
};
@ -186,7 +174,7 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
sink_forward::RemoveZeroInputNode(main_node);
sink_forward::RemoveInputNode(main_node, /* input_idx */ 0);
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);

View File

@ -10,6 +10,10 @@
#include "transformations/rt_info/transpose_sinking_attr.hpp"
using namespace ov;
using namespace ov::opset9;
using namespace ov::pass::pattern;
using namespace ov::op::util;
using namespace transpose_sinking;
namespace {
@ -90,16 +94,11 @@ NodePair Swap(NodePtr first_node, NodePtr second_node) {
ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
MATCHER_SCOPE(TransposeSinkingUnaryForward);
auto transpose_label = ov::pass::pattern::wrap_type<ov::opset9::Transpose>(
{ov::pass::pattern::any_input(), ov::pass::pattern::any_input()});
auto unary_label = ov::pass::pattern::wrap_type<ov::op::util::UnaryElementwiseArithmetic,
ov::opset9::Clamp,
ov::opset9::Elu,
ov::opset9::SoftPlus,
ov::opset9::LogicalNot,
ov::opset9::Convert>({transpose_label});
auto transpose_label = wrap_type<Transpose>({any_input(), any_input()});
auto unary_label =
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert>({transpose_label});
ov::matcher_pass_callback matcher_pass_callback = [=](ov::pass::pattern::Matcher& m) {
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
@ -109,12 +108,12 @@ ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
register_new_node(new_nodes.first);
register_new_node(new_nodes.second);
transpose_sinking::UpdateForwardSinkingAbility(new_nodes.second);
UpdateForwardSinkingAbility(new_nodes.second);
return true;
};
auto m = std::make_shared<ov::pass::pattern::Matcher>(unary_label, "ov::pass::TransposeSinkingUnaryForward");
auto m = std::make_shared<Matcher>(unary_label, "ov::pass::TransposeSinkingUnaryForward");
register_matcher(m, matcher_pass_callback);
}
@ -124,21 +123,16 @@ bool IfSinkingEnabled(const Output<Node>& output) {
}
} // namespace
ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() {
MATCHER_SCOPE(TransposeSinkingUnaryBackward);
ov::pass::TransposeSinkingUnaryBackwardSingleConsumer::TransposeSinkingUnaryBackwardSingleConsumer() {
MATCHER_SCOPE(TransposeSinkingUnaryBackwardSingleConsumer);
auto unary_label = ov::pass::pattern::wrap_type<ov::op::util::UnaryElementwiseArithmetic,
ov::opset9::Clamp,
ov::opset9::Elu,
ov::opset9::SoftPlus,
ov::opset9::LogicalNot,
ov::opset9::Convert>({ov::pass::pattern::any_input()});
auto unary_label =
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert>({any_input()},
consumers_count(1));
auto transpose_label =
ov::pass::pattern::wrap_type<ov::opset9::Transpose>({unary_label, ov::pass::pattern::any_input()},
IfSinkingEnabled);
auto transpose_label = wrap_type<Transpose>({unary_label, any_input()}, IfSinkingEnabled);
ov::matcher_pass_callback matcher_pass_callback = [=](ov::pass::pattern::Matcher& m) {
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
@ -151,6 +145,55 @@ ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() {
return true;
};
auto m = std::make_shared<ov::pass::pattern::Matcher>(transpose_label, "ov::pass::TransposeSinkingUnaryBackward");
auto m = std::make_shared<Matcher>(transpose_label, "ov::pass::TransposeSinkingUnaryBackwardSingleConsumer");
register_matcher(m, matcher_pass_callback);
}
namespace {
std::function<bool(Output<Node>)> consumers_more_than(size_t n) {
return [=](Output<Node> output) -> bool {
return output.get_target_inputs().size() > n;
};
}
} // namespace
ov::pass::TransposeSinkingUnaryBackwardMultiConsumers::TransposeSinkingUnaryBackwardMultiConsumers() {
MATCHER_SCOPE(TransposeSinkingUnaryBackwardMultiConsumers);
auto unary_restrictions = [](const Output<Node>& output) -> bool {
return consumers_more_than(1)(output) && HasSameOutputTransposeNodes(output);
};
auto unary_label =
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert>({any_input()},
unary_restrictions);
auto transpose_const_label = wrap_type<Constant>();
auto transpose_label = wrap_type<Transpose>({unary_label, transpose_const_label}, IfSinkingEnabled);
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary, transpose_const)) {
register_new_node(new_node);
}
// remove output transposes
RemoveSingleOutputConsumers(unary);
return true;
};
auto m = std::make_shared<Matcher>(transpose_label, "ov::pass::TransposeSinkingUnaryBackwardMultiConsumers");
register_matcher(m, matcher_pass_callback);
}
ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() {
MATCHER_SCOPE(TransposeSinkingUnaryBackward);
add_matcher<ov::pass::TransposeSinkingUnaryBackwardSingleConsumer>();
add_matcher<ov::pass::TransposeSinkingUnaryBackwardMultiConsumers>();
}

View File

@ -115,8 +115,7 @@ ov::Output<ov::Node> FixInputNodeRank(ov::Output<ov::Node> input_node, ov::Rank:
namespace sink_forward {
// insert input reversed transposes, remove first input tranpose
void UpdateInputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_input_info) {
void UpdateInputTransposes(NodePtr main_node, const TransposeInputsInfo& transpose_input_info) {
if (transpose_input_info.isEmpty() || HasDynamicRankInput(main_node))
return;
@ -149,15 +148,15 @@ void UpdateInputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_inp
}
}
void RemoveZeroInputNode(NodePtr main_node) {
auto input_node = main_node->input_value(0);
if (input_node.get_node()->get_input_size() < 1)
void RemoveInputNode(NodePtr main_node, size_t input_idx) {
auto input_node = main_node->input_value(input_idx);
if (input_node.get_node()->get_input_size() < (input_idx + 1))
return;
auto parent_node = input_node.get_node()->input_value(0);
main_node->input(0).replace_source_output(parent_node);
auto parent_node = input_node.get_node()->input_value(input_idx);
main_node->input(input_idx).replace_source_output(parent_node);
}
NodeVector InsertOutputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_input_info) {
NodeVector InsertOutputTransposes(NodePtr main_node, const TransposeInputsInfo& transpose_input_info) {
if (transpose_input_info.isEmpty())
return {};
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
@ -246,6 +245,7 @@ bool CanPropagateForwardThrough(Node* node) {
CHECK_TRANSPOSE_SINKING_SUPPORTED(Concat, node);
CHECK_TRANSPOSE_SINKING_SUPPORTED(Split, node);
CHECK_TRANSPOSE_SINKING_SUPPORTED(Transpose, node);
CHECK_TRANSPOSE_SINKING_SUPPORTED(PRelu, node);
return false;
}
@ -268,4 +268,76 @@ void UpdateForwardSinkingAbility(NodePtr node) {
mark_as_no_sinking_node(node);
}
namespace {
std::shared_ptr<Constant> GetTransposeConstant(Node* node) {
auto transpose_node = dynamic_cast<Transpose*>(node);
if (!transpose_node)
return {};
auto constant_node = as_type_ptr<Constant>(transpose_node->input_value(1).get_node_shared_ptr());
if (!constant_node)
return {};
return constant_node;
}
Node* FindFirstConsumer(NodePtr node) {
for (size_t output_idx = 0; output_idx < node->get_output_size(); ++output_idx) {
auto inputs = node->get_output_target_inputs(output_idx);
if (inputs.empty())
continue;
return inputs.begin()->get_node();
}
return nullptr;
}
bool HasSameOutputTransposeNodes(NodePtr main_node) {
AxisVector first_transpose_axis_order;
{
Node* first_consumer = FindFirstConsumer(main_node);
if (!first_consumer)
return false;
auto constant_node = GetTransposeConstant(first_consumer);
if (!constant_node)
return false;
first_transpose_axis_order = constant_node->get_axis_vector_val();
}
for (size_t output_idx = 0; output_idx < main_node->get_output_size(); ++output_idx) {
for (auto& input : main_node->get_output_target_inputs(output_idx)) {
auto constant_node = GetTransposeConstant(input.get_node());
if (!constant_node)
return false;
AxisVector transpose_axis_order = constant_node->get_axis_vector_val();
if (transpose_axis_order.size() != first_transpose_axis_order.size())
return false;
if (!std::equal(transpose_axis_order.begin(),
transpose_axis_order.end(),
first_transpose_axis_order.begin()))
return false;
}
}
return true;
}
} // namespace
bool HasSameOutputTransposeNodes(const Output<Node>& output) {
return HasSameOutputTransposeNodes(output.get_node_shared_ptr());
}
void RemoveSingleOutputConsumers(NodePtr node) {
for (size_t output_idx = 0; output_idx < node->get_output_size(); ++output_idx) {
for (auto& input : node->get_output_target_inputs(output_idx)) {
Node* consumer = input.get_node();
if (consumer->get_output_size() != 1)
continue;
consumer->output(0).replace(node->output(output_idx));
}
}
}
} // namespace transpose_sinking

View File

@ -12,13 +12,37 @@
#include "common_test_utils/ngraph_test_utils.hpp"
#include "gtest/gtest.h"
using namespace ov;
using namespace ov::opset9;
namespace {
std::string to_string(const Shape& shape) {
std::ostringstream result;
result << "{";
for (size_t idx = 0; idx < shape.size(); ++idx) {
if (idx)
result << ",";
result << shape[idx];
}
result << "}";
return result.str();
}
} // namespace
using NodePtr = std::shared_ptr<ov::Node>;
class IUnaryFactory {
public:
IUnaryFactory() = default;
IUnaryFactory(const std::string& type_name) : type_name_(type_name) {}
virtual ~IUnaryFactory() = default;
virtual NodePtr create(NodePtr parent_node) const = 0;
const std::string& getTypeName() const {
return type_name_;
}
private:
const std::string type_name_;
};
using UnaryFactoryPtr = std::shared_ptr<IUnaryFactory>;
@ -26,39 +50,45 @@ using UnaryFactoryPtr = std::shared_ptr<IUnaryFactory>;
template <typename UnaryT>
class UnaryFactory : public IUnaryFactory {
public:
UnaryFactory() = default;
UnaryFactory(const std::string& type_name) : IUnaryFactory(type_name) {}
NodePtr create(NodePtr parent_node) const override {
return std::make_shared<UnaryT>(parent_node);
}
};
template <>
NodePtr UnaryFactory<ov::opset9::Elu>::create(NodePtr parent_node) const {
return std::make_shared<ov::opset9::Elu>(parent_node, 0.1);
NodePtr UnaryFactory<Elu>::create(NodePtr parent_node) const {
return std::make_shared<Elu>(parent_node, 0.1);
}
template <>
NodePtr UnaryFactory<ov::opset9::Clamp>::create(NodePtr parent_node) const {
return std::make_shared<ov::opset9::Clamp>(parent_node, 0.1, 0.2);
NodePtr UnaryFactory<Clamp>::create(NodePtr parent_node) const {
return std::make_shared<Clamp>(parent_node, 0.1, 0.2);
}
template <>
NodePtr UnaryFactory<ov::opset9::Convert>::create(NodePtr parent_node) const {
return std::make_shared<ov::opset9::Convert>(parent_node, ov::element::f64);
NodePtr UnaryFactory<Convert>::create(NodePtr parent_node) const {
return std::make_shared<Convert>(parent_node, element::f64);
}
template <typename UnaryT>
UnaryFactoryPtr CreateUnaryFactory() {
return std::make_shared<UnaryFactory<UnaryT>>();
UnaryFactoryPtr CreateUnaryFactory(const std::string& type_name) {
return std::make_shared<UnaryFactory<UnaryT>>(type_name);
}
// ----------------------------------------------------------------------------
class IPassFactory {
public:
IPassFactory() = default;
IPassFactory(const std::string& type_name) : type_name_(type_name) {}
virtual ~IPassFactory() = default;
virtual void registerPass(ov::pass::Manager& pass_manager) const = 0;
const std::string& getTypeName() const {
return type_name_;
}
private:
const std::string type_name_;
};
using PassFactoryPtr = std::shared_ptr<IPassFactory>;
@ -66,15 +96,16 @@ using PassFactoryPtr = std::shared_ptr<IPassFactory>;
template <typename PassT>
class PassFactory : public IPassFactory {
public:
PassFactory(const std::string& type_name) : IPassFactory(type_name) {}
void registerPass(ov::pass::Manager& pass_manager) const override {
pass_manager.register_pass<PassT>();
}
};
template <typename PassT>
PassFactoryPtr CreatePassFactory() {
return std::make_shared<PassFactory<PassT>>();
}
#define CREATE_PASS_FACTORY(pass_name) std::make_shared<PassFactory<ov::pass::pass_name>>(#pass_name)
#undef CREATE_UNARY_FACTORY
#define CREATE_UNARY_FACTORY(type_name) CreateUnaryFactory<type_name>(#type_name)
// ----------------------------------------------------------------------------
@ -82,19 +113,46 @@ using FloatPtr = std::unique_ptr<float[]>;
using CreateGraphF = std::function<std::shared_ptr<ov::Model>(UnaryFactoryPtr unary_factory,
size_t num_unary_ops,
const ov::Shape& input_shape,
ov::element::Type input_type)>;
const Shape& input_shape,
element::Type input_type)>;
using TestParams = std::tuple<UnaryFactoryPtr,
PassFactoryPtr,
size_t, /* num_unary_ops */
CreateGraphF, /* model_factory */
CreateGraphF, /* reference_model_factory */
ov::Shape, /* input shape */
ov::element::Type>; /* input type */
size_t, /* num_unary_ops */
CreateGraphF, /* model_factory */
CreateGraphF, /* reference_model_factory */
Shape, /* input shape */
element::Type>; /* input type */
class TransposeSinkingUnaryTestFixture : public ::testing::WithParamInterface<TestParams>,
public TransformationTestsF {};
class TransposeSinkingUnaryTestFixture : public ::testing::WithParamInterface<TestParams>, public TransformationTestsF {
public:
static std::string get_test_name(const testing::TestParamInfo<TestParams>& obj) {
UnaryFactoryPtr unary_factory;
PassFactoryPtr pass_factory;
size_t num_unary_ops;
CreateGraphF model_factory;
CreateGraphF reference_model_factory;
Shape input_shape;
element::Type input_type;
std::tie(unary_factory,
pass_factory,
num_unary_ops,
model_factory,
reference_model_factory,
input_shape,
input_type) = obj.param;
std::ostringstream test_name;
test_name << "unaryFactory=" << unary_factory->getTypeName() << "/";
test_name << "numUnaryOps=" << num_unary_ops << "/";
test_name << "inputShape=" << to_string(input_shape) << "/";
test_name << "unaryFactory=" << unary_factory->getTypeName() << "/";
test_name << "passFactory=" << pass_factory->getTypeName() << "/";
test_name << "inputType=" << input_type;
return test_name.str();
}
};
namespace {
@ -105,12 +163,12 @@ std::string GetFinalNodeName(std::shared_ptr<ov::Model> model, int index = 0) {
std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory,
size_t num_unary_ops,
const ov::Shape& input_shape,
ov::element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
const Shape& input_shape,
element::Type input_type) {
auto X = std::make_shared<Parameter>(input_type, input_shape);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(X, ng_order0);
NodePtr in_op = transpose0;
for (size_t i = 0; i < num_unary_ops; ++i) {
@ -122,25 +180,25 @@ std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_f
std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory,
size_t num_unary_ops,
const ov::Shape& input_shape,
ov::element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
const Shape& input_shape,
element::Type input_type) {
auto X = std::make_shared<Parameter>(input_type, input_shape);
NodePtr in_op = X;
for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = unary_factory->create(in_op);
}
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(in_op, ng_order0);
return std::make_shared<ov::Model>(transpose0, ov::ParameterVector{X});
}
static NodePtr CreateReshape(NodePtr parent_node, const ov::Shape& input_shape) {
static NodePtr CreateReshape(NodePtr parent_node, const Shape& input_shape) {
const size_t mul = std::accumulate(input_shape.begin(), input_shape.end(), (size_t)1, std::multiplies<size_t>());
auto reshape_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{1}, ov::Shape{mul});
return std::make_shared<ov::opset9::Reshape>(parent_node, reshape_const, false);
auto reshape_const = std::make_shared<Constant>(element::u64, Shape{1}, Shape{mul});
return std::make_shared<Reshape>(parent_node, reshape_const, false);
}
namespace mult_consumers_last_node {
@ -148,17 +206,17 @@ namespace with_reshape {
std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory,
size_t num_unary_ops,
const ov::Shape& input_shape,
ov::element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
const Shape& input_shape,
element::Type input_type) {
auto X = std::make_shared<Parameter>(input_type, input_shape);
NodePtr in_op = X;
for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = unary_factory->create(in_op);
}
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(in_op, ng_order0);
auto reshape1 = CreateReshape(transpose0, input_shape);
auto reshape2 = CreateReshape(transpose0, input_shape);
@ -168,12 +226,12 @@ std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_fa
std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory,
size_t num_unary_ops,
const ov::Shape& input_shape,
ov::element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
const Shape& input_shape,
element::Type input_type) {
auto X = std::make_shared<Parameter>(input_type, input_shape);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(X, ng_order0);
NodePtr in_op = transpose0;
for (size_t i = 0; i < num_unary_ops; ++i) {
@ -191,44 +249,44 @@ namespace with_eltwise {
std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory,
size_t num_unary_ops,
const ov::Shape& input_shape,
ov::element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
const Shape& input_shape,
element::Type input_type) {
auto X = std::make_shared<Parameter>(input_type, input_shape);
NodePtr in_op = X;
for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = unary_factory->create(in_op);
}
auto sinh = std::make_shared<ov::opset9::Sinh>(in_op);
auto sinh = std::make_shared<Sinh>(in_op);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(sinh, ng_order0);
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(sinh, ng_order0);
auto cosh = std::make_shared<ov::opset9::Cosh>(in_op);
auto cosh = std::make_shared<Cosh>(in_op);
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose1 = std::make_shared<ov::opset9::Transpose>(cosh, ng_order1);
auto ng_order1 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose1 = std::make_shared<Transpose>(cosh, ng_order1);
return std::make_shared<ov::Model>(ov::OutputVector{transpose0, transpose1}, ov::ParameterVector{X});
}
std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory,
size_t num_unary_ops,
const ov::Shape& input_shape,
ov::element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
const Shape& input_shape,
element::Type input_type) {
auto X = std::make_shared<Parameter>(input_type, input_shape);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(X, ng_order0);
NodePtr in_op = transpose0;
for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = unary_factory->create(in_op);
}
auto sinh = std::make_shared<ov::opset9::Sinh>(in_op);
auto cosh = std::make_shared<ov::opset9::Cosh>(in_op);
auto sinh = std::make_shared<Sinh>(in_op);
auto cosh = std::make_shared<Cosh>(in_op);
return std::make_shared<ov::Model>(ov::OutputVector{sinh, cosh}, ov::ParameterVector{X});
}
@ -241,66 +299,87 @@ namespace backward {
std::shared_ptr<ov::Model> CreateFunction(UnaryFactoryPtr unary_factory,
size_t num_unary_ops,
const ov::Shape& input_shape,
ov::element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
const Shape& input_shape,
element::Type input_type) {
auto X = std::make_shared<Parameter>(input_type, input_shape);
ov::OutputVector outputs;
NodePtr in_op = X;
for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = unary_factory->create(in_op);
auto cosh = std::make_shared<ov::opset9::Cosh>(in_op);
auto cosh = std::make_shared<Cosh>(in_op);
outputs.push_back(cosh);
}
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(in_op, ng_order0);
outputs.push_back(transpose0);
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
}
std::shared_ptr<ov::Model> CreateReferenceFunction(UnaryFactoryPtr unary_factory,
size_t num_unary_ops,
const ov::Shape& input_shape,
ov::element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
ov::OutputVector outputs;
NodePtr in_op = X;
for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = unary_factory->create(in_op);
auto cosh = std::make_shared<ov::opset9::Cosh>(in_op);
outputs.push_back(cosh);
}
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
in_op = transpose0;
for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = unary_factory->create(in_op);
}
outputs.push_back(in_op);
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
}
} // namespace backward
namespace backward_mult_transposes {
std::shared_ptr<ov::Model> CreateFunction(UnaryFactoryPtr unary_factory,
size_t num_unary_ops,
const Shape& input_shape,
element::Type input_type) {
auto X = std::make_shared<Parameter>(input_type, input_shape);
NodePtr in_op = X;
for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = unary_factory->create(in_op);
}
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(in_op, ng_order0);
auto tanh0 = std::make_shared<Tanh>(transpose0);
auto ng_order1 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose1 = std::make_shared<Transpose>(in_op, ng_order1);
auto tanh1 = std::make_shared<Tanh>(transpose1);
return std::make_shared<ov::Model>(ov::OutputVector{tanh0, tanh1}, ov::ParameterVector{X});
}
std::shared_ptr<ov::Model> CreateReferenceFunction(UnaryFactoryPtr unary_factory,
size_t num_unary_ops,
const Shape& input_shape,
element::Type input_type) {
auto X = std::make_shared<Parameter>(input_type, input_shape);
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(X, ng_order0);
NodePtr in_op = transpose0;
for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = unary_factory->create(in_op);
}
auto tanh0 = std::make_shared<Tanh>(in_op);
auto tanh1 = std::make_shared<Tanh>(in_op);
return std::make_shared<ov::Model>(ov::OutputVector{tanh0, tanh1}, ov::ParameterVector{X});
}
} // namespace backward_mult_transposes
namespace forward {
std::shared_ptr<ov::Model> CreateFunction(UnaryFactoryPtr unary_factory,
size_t num_unary_ops,
const ov::Shape& input_shape,
ov::element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
const Shape& input_shape,
element::Type input_type) {
auto X = std::make_shared<Parameter>(input_type, input_shape);
auto sinh = std::make_shared<ov::opset9::Sinh>(X);
auto sinh = std::make_shared<Sinh>(X);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(sinh, ng_order0);
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(sinh, ng_order0);
auto reshape = CreateReshape(transpose0, input_shape);
@ -314,14 +393,14 @@ std::shared_ptr<ov::Model> CreateFunction(UnaryFactoryPtr unary_factory,
std::shared_ptr<ov::Model> CreateReferenceFunction(UnaryFactoryPtr unary_factory,
size_t num_unary_ops,
const ov::Shape& input_shape,
ov::element::Type input_type) {
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
const Shape& input_shape,
element::Type input_type) {
auto X = std::make_shared<Parameter>(input_type, input_shape);
auto sinh = std::make_shared<ov::opset9::Sinh>(X);
auto sinh = std::make_shared<Sinh>(X);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(sinh, ng_order0);
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(sinh, ng_order0);
auto reshape = CreateReshape(transpose0, input_shape);
NodePtr in_op = sinh;
@ -329,8 +408,8 @@ std::shared_ptr<ov::Model> CreateReferenceFunction(UnaryFactoryPtr unary_factory
in_op = unary_factory->create(in_op);
}
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order1);
auto ng_order1 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
auto transpose1 = std::make_shared<Transpose>(in_op, ng_order1);
return std::make_shared<ov::Model>(ov::OutputVector{transpose1, reshape}, ov::ParameterVector{X});
}
@ -339,21 +418,16 @@ std::shared_ptr<ov::Model> CreateReferenceFunction(UnaryFactoryPtr unary_factory
} // namespace mult_consumers_first_node
std::vector<UnaryFactoryPtr> unary_factories = {
CreateUnaryFactory<ov::opset9::Clamp>(), CreateUnaryFactory<ov::opset9::Elu>(),
CreateUnaryFactory<ov::opset9::SoftPlus>(), CreateUnaryFactory<ov::opset9::LogicalNot>(),
CreateUnaryFactory<ov::opset9::Convert>(), CreateUnaryFactory<ov::opset9::Abs>(),
CreateUnaryFactory<ov::opset9::Acos>(), CreateUnaryFactory<ov::opset9::Asin>(),
CreateUnaryFactory<ov::opset9::Asinh>(), CreateUnaryFactory<ov::opset9::Atan>(),
CreateUnaryFactory<ov::opset9::Ceiling>(), CreateUnaryFactory<ov::opset9::Cos>(),
CreateUnaryFactory<ov::opset9::Cosh>(), CreateUnaryFactory<ov::opset9::Erf>(),
CreateUnaryFactory<ov::opset9::Exp>(), CreateUnaryFactory<ov::opset9::Gelu>(),
CreateUnaryFactory<ov::opset9::HSigmoid>(), CreateUnaryFactory<ov::opset9::HSwish>(),
CreateUnaryFactory<ov::opset9::Log>(), CreateUnaryFactory<ov::opset9::Negative>(),
CreateUnaryFactory<ov::opset9::Relu>(), CreateUnaryFactory<ov::opset9::Sigmoid>(),
CreateUnaryFactory<ov::opset9::Sign>(), CreateUnaryFactory<ov::opset9::Sin>(),
CreateUnaryFactory<ov::opset9::Sinh>(), CreateUnaryFactory<ov::opset9::SoftSign>(),
CreateUnaryFactory<ov::opset9::Sqrt>(), CreateUnaryFactory<ov::opset9::Tan>(),
CreateUnaryFactory<ov::opset9::Tanh>()};
CREATE_UNARY_FACTORY(Clamp), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(SoftPlus),
CREATE_UNARY_FACTORY(LogicalNot), CREATE_UNARY_FACTORY(Convert), CREATE_UNARY_FACTORY(Abs),
CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh),
CREATE_UNARY_FACTORY(Atan), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Cos),
CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Exp),
CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid), CREATE_UNARY_FACTORY(HSwish),
CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu),
CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin),
CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftSign), CREATE_UNARY_FACTORY(Sqrt),
CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh)};
std::vector<size_t> unary_operations_numbers = {1, 10};
@ -365,8 +439,8 @@ TEST_P(TransposeSinkingUnaryTestFixture, CompareFunctions) {
size_t num_unary_ops;
CreateGraphF model_factory;
CreateGraphF reference_model_factory;
ov::Shape input_shape;
ov::element::Type input_type;
Shape input_shape;
element::Type input_type;
std::tie(unary_factory,
pass_factory,
num_unary_ops,
@ -380,90 +454,95 @@ TEST_P(TransposeSinkingUnaryTestFixture, CompareFunctions) {
pass_factory->registerPass(manager);
}
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryForwardTestSuite,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryForward>()),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(CreateFunctionTransposeBefore),
::testing::Values(CreateFunctionTransposeAfter),
::testing::Values(ov::Shape{1, 96, 55, 55}),
::testing::Values(ov::element::f32)));
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardTestSuite,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(CreateFunctionTransposeBefore),
::testing::Values(CreateFunctionTransposeAfter),
::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(element::f32)),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryBackwardTestSuite,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryBackward>()),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(CreateFunctionTransposeAfter),
::testing::Values(CreateFunctionTransposeBefore),
::testing::Values(ov::Shape{1, 96, 55, 55}),
::testing::Values(ov::element::f32)));
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardTestSuite,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(CreateFunctionTransposeAfter),
::testing::Values(CreateFunctionTransposeBefore),
::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(element::f32)),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeReshape,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryForward>()),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore),
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter),
::testing::Values(ov::Shape{1, 96, 55, 55}),
::testing::Values(ov::element::f32)));
::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(element::f32)),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryBackwardMultConsumersTestSuiteLastNodeReshape,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryBackward>()),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter),
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore),
::testing::Values(ov::Shape{1, 96, 55, 55}),
::testing::Values(ov::element::f32)));
::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(element::f32)),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeEltwise,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryForward>()),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore),
::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter),
::testing::Values(ov::Shape{1, 96, 55, 55}),
::testing::Values(ov::element::f32)));
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryBackwardMultConsumersTestSuiteLastNodeEltwise,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryBackward>()),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter),
::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore),
::testing::Values(ov::Shape{1, 96, 55, 55}),
::testing::Values(ov::element::f32)));
::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(element::f32)),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryForwardMultConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryForward>()),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_first_node::forward::CreateFunction),
::testing::Values(mult_consumers_first_node::forward::CreateReferenceFunction),
::testing::Values(ov::Shape{1, 96, 55, 55}),
::testing::Values(ov::element::f32)));
::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(element::f32)),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_first_node::backward::CreateFunction),
::testing::Values(mult_consumers_first_node::backward::CreateFunction),
::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(element::f32)),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryBackwardMultConsumersTestSuiteFirstNode,
TransposeSinkingUnaryBackwardMultTransposeConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture,
::testing::Combine(::testing::ValuesIn(unary_factories),
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryBackward>()),
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
::testing::ValuesIn(unary_operations_numbers),
::testing::Values(mult_consumers_first_node::backward::CreateFunction),
::testing::Values(mult_consumers_first_node::backward::CreateReferenceFunction),
::testing::Values(ov::Shape{1, 96, 55, 55}),
::testing::Values(ov::element::f32)));
::testing::Values(mult_consumers_first_node::backward_mult_transposes::CreateFunction),
::testing::Values(mult_consumers_first_node::backward_mult_transposes::CreateReferenceFunction),
::testing::Values(Shape{1, 96, 55, 55}),
::testing::Values(element::f32)),
TransposeSinkingUnaryTestFixture::get_test_name);