Transpose sinking binary/split/concat mult consumers case + PRelu (#14670)
* binary: move test classes into namespaces * add unsing namespacees ov and ov::opset9 * ov::Model -> Model * use naming for unit tests * fix unary -> binary * fix identation * add tests binary forward consumers > 1 * add binary output consumers > 1 unit tests * add binary input consumers > 1 unit tests * binary backward add unit test binary has multiple consumers + fix * move code to common function CloneNodeWithoutConsumers * add unit test on backward binary propagation multi consumers * reorganize binary unit tests * add backward_input_node_consumers test * add test output transpose mult consumers * cleanup * cleanup unit tests for split * add forward::mult_consumers::input_node_consumers tests * add forward::mult_consumers::input_transpose_consumers tests * add forward::mult_consumers::output_consumers tests * code indentation fix * fix unit test split backward input node consumers * added mult_output_consumers split unit test * add mult_split_consumers split unit test * cancat tests using namespaces * concat add unit tests * move repeated code into functions * clang format fixes * rename TransposeSinkingBinaryElementwise -> TransposeSinkingBinary * add PRelu * clang cleanup * fix prelu tests * fix using * clang fixes * fix NodeVector argument * fix function descriptions * fix const ref arg functions * clang fixes * prevent backward sinking if not all output nodes are transposes * fix RemoveSingleOutputConsumers checking output size * cleanup * clang cleanup * disable unary backward with not same transposes * clang fixes * remove unneeded functions CloneNodeWithoutConsumers GetNodeIds * remove unneeded GetTransposeConstant(Input<Node> input) as duplicate of GetTransposeConstant(Node* node) * cleanup * add output_transpose_mult_transposes test * add backward::output_transpose_mult_transposes test * add unit tests for unary backward multiple transposes; fix problems by adding new transformation * fix bug TransposeSinkingUnaryBackwardMultiConsumers consumers_more_than
This commit is contained in:
parent
9388560aec
commit
78995e9ac2
@ -11,20 +11,20 @@
|
|||||||
namespace ov {
|
namespace ov {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
|
|
||||||
class TRANSFORMATIONS_API TransposeSinkingBinaryElementwiseForward;
|
class TRANSFORMATIONS_API TransposeSinkingBinaryForward;
|
||||||
class TRANSFORMATIONS_API TransposeSinkingBinaryElementwiseBackward;
|
class TRANSFORMATIONS_API TransposeSinkingBinaryBackward;
|
||||||
|
|
||||||
} // namespace pass
|
} // namespace pass
|
||||||
} // namespace ov
|
} // namespace ov
|
||||||
|
|
||||||
class ov::pass::TransposeSinkingBinaryElementwiseForward : public ov::pass::MatcherPass {
|
class ov::pass::TransposeSinkingBinaryForward : public ov::pass::MatcherPass {
|
||||||
public:
|
public:
|
||||||
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryElementwiseForward", "0");
|
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryForward", "0");
|
||||||
TransposeSinkingBinaryElementwiseForward();
|
TransposeSinkingBinaryForward();
|
||||||
};
|
};
|
||||||
|
|
||||||
class ov::pass::TransposeSinkingBinaryElementwiseBackward : public ov::pass::MatcherPass {
|
class ov::pass::TransposeSinkingBinaryBackward : public ov::pass::MatcherPass {
|
||||||
public:
|
public:
|
||||||
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryElementwiseBackward", "0");
|
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryBackward", "0");
|
||||||
TransposeSinkingBinaryElementwiseBackward();
|
TransposeSinkingBinaryBackward();
|
||||||
};
|
};
|
||||||
|
@ -11,6 +11,8 @@ namespace ov {
|
|||||||
namespace pass {
|
namespace pass {
|
||||||
|
|
||||||
class TRANSFORMATIONS_API TransposeSinkingUnaryForward;
|
class TRANSFORMATIONS_API TransposeSinkingUnaryForward;
|
||||||
|
class TRANSFORMATIONS_API TransposeSinkingUnaryBackwardSingleConsumer;
|
||||||
|
class TRANSFORMATIONS_API TransposeSinkingUnaryBackwardMultiConsumers;
|
||||||
class TRANSFORMATIONS_API TransposeSinkingUnaryBackward;
|
class TRANSFORMATIONS_API TransposeSinkingUnaryBackward;
|
||||||
|
|
||||||
} // namespace pass
|
} // namespace pass
|
||||||
@ -22,7 +24,19 @@ public:
|
|||||||
TransposeSinkingUnaryForward();
|
TransposeSinkingUnaryForward();
|
||||||
};
|
};
|
||||||
|
|
||||||
class ov::pass::TransposeSinkingUnaryBackward : public ov::pass::MatcherPass {
|
class ov::pass::TransposeSinkingUnaryBackwardSingleConsumer : public ov::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI("TransposeSinkingUnaryBackwardSingleConsumer", "0");
|
||||||
|
TransposeSinkingUnaryBackwardSingleConsumer();
|
||||||
|
};
|
||||||
|
|
||||||
|
class ov::pass::TransposeSinkingUnaryBackwardMultiConsumers : public ov::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI("TransposeSinkingUnaryBackwardMultiConsumers", "0");
|
||||||
|
TransposeSinkingUnaryBackwardMultiConsumers();
|
||||||
|
};
|
||||||
|
|
||||||
|
class ov::pass::TransposeSinkingUnaryBackward : public ov::pass::GraphRewrite {
|
||||||
public:
|
public:
|
||||||
OPENVINO_RTTI("TransposeSinkingUnaryBackward", "0");
|
OPENVINO_RTTI("TransposeSinkingUnaryBackward", "0");
|
||||||
TransposeSinkingUnaryBackward();
|
TransposeSinkingUnaryBackward();
|
||||||
|
@ -28,25 +28,77 @@ struct TransposeInputsInfo {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TransposeInputsInfo GetFirstTransposeInput(std::shared_ptr<ov::Node> node);
|
/**
|
||||||
bool IfNodeHasTransposeInputs(const ov::Output<ov::Node>& output);
|
* @brief Finds node first input that is a transpose operation and returns filled TransposeInputsInfo
|
||||||
ov::AxisVector ReverseTransposeOrder(const ov::AxisVector& axis_order);
|
* for it
|
||||||
void SwapOutputNames(ov::Output<ov::Node> output1, ov::Output<ov::Node> output2);
|
*/
|
||||||
void SwapFriendlyNames(std::shared_ptr<ov::Node> node1, std::shared_ptr<ov::Node> node2);
|
TransposeInputsInfo GetFirstTransposeInput(std::shared_ptr<ov::Node>);
|
||||||
void SwapNames(std::shared_ptr<ov::Node> node1, std::shared_ptr<ov::Node> node2);
|
|
||||||
|
/**
|
||||||
|
* @brief Checks if @arg has any input node that is a transpose operation
|
||||||
|
*/
|
||||||
|
bool IfNodeHasTransposeInputs(const ov::Output<ov::Node>&);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Reverses order of transpose operation. Do it in a such way that if we had couple following one after
|
||||||
|
* another transposes (one would be reversed version of another) we will have no transpose as a result of that
|
||||||
|
* couple of transposes.
|
||||||
|
*/
|
||||||
|
ov::AxisVector ReverseTransposeOrder(const ov::AxisVector&);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Swaps @args output tensor names
|
||||||
|
*/
|
||||||
|
void SwapOutputNames(ov::Output<ov::Node>, ov::Output<ov::Node>);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Swaps @args friendly names
|
||||||
|
*/
|
||||||
|
void SwapFriendlyNames(std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Swaps @args output tensor names and friendly names
|
||||||
|
*/
|
||||||
|
void SwapNames(std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>);
|
||||||
|
|
||||||
namespace sink_forward {
|
namespace sink_forward {
|
||||||
// insert input reversed transposes, remove first input tranpose
|
/**
|
||||||
void UpdateInputTransposes(std::shared_ptr<ov::Node> main_node, TransposeInputsInfo& transpose_input_info);
|
* @brief Inserts reversed transposed on @args main_node inputs. Removes input transpose specified in @arg
|
||||||
void RemoveZeroInputNode(std::shared_ptr<ov::Node> main_node);
|
* transpose_input_info
|
||||||
ov::NodeVector InsertOutputTransposes(std::shared_ptr<ov::Node> main_node, TransposeInputsInfo& transpose_input_info);
|
*/
|
||||||
|
void UpdateInputTransposes(std::shared_ptr<ov::Node> main_node, const TransposeInputsInfo& transpose_input_info);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Removes @arg input node
|
||||||
|
*/
|
||||||
|
void RemoveInputNode(std::shared_ptr<ov::Node>, size_t input_idx);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Inserts transposes on each main_node output with the order specified in @arg transpose_input_info
|
||||||
|
*/
|
||||||
|
ov::NodeVector InsertOutputTransposes(std::shared_ptr<ov::Node> main_node,
|
||||||
|
const TransposeInputsInfo& transpose_input_info);
|
||||||
} // namespace sink_forward
|
} // namespace sink_forward
|
||||||
|
|
||||||
namespace sink_backward {
|
namespace sink_backward {
|
||||||
|
/**
|
||||||
|
* @brief Inserts transposes on each input of @arg main_node with the order specified in @arg transpose_const
|
||||||
|
*/
|
||||||
ov::NodeVector InsertTransposeBeforeNode(std::shared_ptr<ov::Node> main_node,
|
ov::NodeVector InsertTransposeBeforeNode(std::shared_ptr<ov::Node> main_node,
|
||||||
std::shared_ptr<ov::opset9::Constant> transpose_const);
|
std::shared_ptr<ov::opset9::Constant> transpose_const);
|
||||||
} // namespace sink_backward
|
} // namespace sink_backward
|
||||||
|
|
||||||
void UpdateForwardSinkingAbility(std::shared_ptr<ov::Node>);
|
void UpdateForwardSinkingAbility(std::shared_ptr<ov::Node>);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Checks if @arg has consumers that all are the same transpose operation. If no consumers at all
|
||||||
|
* returns false.
|
||||||
|
*/
|
||||||
|
bool HasSameOutputTransposeNodes(const ov::Output<ov::Node>&);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Removes all direct node consumers that have one output
|
||||||
|
*/
|
||||||
|
void RemoveSingleOutputConsumers(std::shared_ptr<ov::Node>);
|
||||||
|
|
||||||
} // namespace transpose_sinking
|
} // namespace transpose_sinking
|
||||||
|
@ -20,10 +20,10 @@ using namespace ov;
|
|||||||
using namespace ov::opset9;
|
using namespace ov::opset9;
|
||||||
using namespace transpose_sinking;
|
using namespace transpose_sinking;
|
||||||
|
|
||||||
ov::pass::TransposeSinkingBinaryElementwiseForward::TransposeSinkingBinaryElementwiseForward() {
|
ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
|
||||||
MATCHER_SCOPE(TransposeSinkingBinaryElementwiseForward);
|
MATCHER_SCOPE(TransposeSinkingBinaryForward);
|
||||||
|
|
||||||
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic>(IfNodeHasTransposeInputs);
|
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic, PRelu>(IfNodeHasTransposeInputs);
|
||||||
|
|
||||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||||
@ -46,18 +46,19 @@ ov::pass::TransposeSinkingBinaryElementwiseForward::TransposeSinkingBinaryElemen
|
|||||||
register_matcher(m, matcher_pass_callback);
|
register_matcher(m, matcher_pass_callback);
|
||||||
}
|
}
|
||||||
|
|
||||||
ov::pass::TransposeSinkingBinaryElementwiseBackward::TransposeSinkingBinaryElementwiseBackward() {
|
ov::pass::TransposeSinkingBinaryBackward::TransposeSinkingBinaryBackward() {
|
||||||
MATCHER_SCOPE(TransposeSinkingBinaryElementwiseBackward);
|
MATCHER_SCOPE(TransposeSinkingBinaryBackward);
|
||||||
|
|
||||||
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic>([](const Output<Node>& output) -> bool {
|
auto main_node_label =
|
||||||
return consumers_count(1)(output) && has_static_rank()(output);
|
wrap_type<op::util::BinaryElementwiseArithmetic, PRelu>([](const Output<Node>& output) -> bool {
|
||||||
|
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
|
||||||
});
|
});
|
||||||
|
|
||||||
auto transpose_const_label = wrap_type<Constant>(consumers_count(1));
|
auto transpose_const_label = wrap_type<Constant>();
|
||||||
|
|
||||||
auto transpose_label =
|
auto transpose_label =
|
||||||
wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
|
wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
|
||||||
return consumers_count(1)(output) && has_static_rank()(output) && is_sinking_node(output);
|
return has_static_rank()(output) && is_sinking_node(output);
|
||||||
});
|
});
|
||||||
|
|
||||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||||
@ -70,8 +71,8 @@ ov::pass::TransposeSinkingBinaryElementwiseBackward::TransposeSinkingBinaryEleme
|
|||||||
register_new_node(new_node);
|
register_new_node(new_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove transpose after main node
|
// remove output transposes
|
||||||
transpose->output(0).replace(main_node);
|
RemoveSingleOutputConsumers(main_node);
|
||||||
|
|
||||||
SwapNames(transpose, main_node);
|
SwapNames(transpose, main_node);
|
||||||
|
|
||||||
|
@ -55,14 +55,14 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
|
|||||||
MATCHER_SCOPE(TransposeSinkingConcatBackward);
|
MATCHER_SCOPE(TransposeSinkingConcatBackward);
|
||||||
|
|
||||||
auto main_node_label = wrap_type<Concat>([](const Output<Node>& output) -> bool {
|
auto main_node_label = wrap_type<Concat>([](const Output<Node>& output) -> bool {
|
||||||
return consumers_count(1)(output) && has_static_rank()(output);
|
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
|
||||||
});
|
});
|
||||||
|
|
||||||
auto transpose_const_label = wrap_type<Constant>(consumers_count(1));
|
auto transpose_const_label = wrap_type<Constant>();
|
||||||
|
|
||||||
auto transpose_label =
|
auto transpose_label =
|
||||||
wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
|
wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
|
||||||
return consumers_count(1)(output) && has_static_rank()(output) && is_sinking_node(output);
|
return has_static_rank()(output) && is_sinking_node(output);
|
||||||
});
|
});
|
||||||
|
|
||||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||||
@ -75,8 +75,8 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
|
|||||||
register_new_node(new_node);
|
register_new_node(new_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove transpose after main node
|
// remove output transposes
|
||||||
transpose->output(0).replace(main_node);
|
RemoveSingleOutputConsumers(main_node);
|
||||||
|
|
||||||
SwapNames(transpose, main_node);
|
SwapNames(transpose, main_node);
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@
|
|||||||
ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {
|
ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {
|
||||||
MATCHER_SCOPE(TransposeSinkingGeneralForward);
|
MATCHER_SCOPE(TransposeSinkingGeneralForward);
|
||||||
add_matcher<ov::pass::TransposeSinkingUnaryForward>();
|
add_matcher<ov::pass::TransposeSinkingUnaryForward>();
|
||||||
add_matcher<ov::pass::TransposeSinkingBinaryElementwiseForward>();
|
add_matcher<ov::pass::TransposeSinkingBinaryForward>();
|
||||||
add_matcher<ov::pass::TransposeSinkingConcatForward>();
|
add_matcher<ov::pass::TransposeSinkingConcatForward>();
|
||||||
add_matcher<ov::pass::TransposeSinkingSplitForward>();
|
add_matcher<ov::pass::TransposeSinkingSplitForward>();
|
||||||
add_matcher<ngraph::pass::TransposeFuse>();
|
add_matcher<ngraph::pass::TransposeFuse>();
|
||||||
@ -30,7 +30,7 @@ ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {
|
|||||||
ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() {
|
ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() {
|
||||||
MATCHER_SCOPE(TransposeSinkingGeneralBackward);
|
MATCHER_SCOPE(TransposeSinkingGeneralBackward);
|
||||||
add_matcher<ov::pass::TransposeSinkingUnaryBackward>();
|
add_matcher<ov::pass::TransposeSinkingUnaryBackward>();
|
||||||
add_matcher<ov::pass::TransposeSinkingBinaryElementwiseBackward>();
|
add_matcher<ov::pass::TransposeSinkingBinaryBackward>();
|
||||||
add_matcher<ov::pass::TransposeSinkingConcatBackward>();
|
add_matcher<ov::pass::TransposeSinkingConcatBackward>();
|
||||||
add_matcher<ov::pass::TransposeSinkingSplitBackward>();
|
add_matcher<ov::pass::TransposeSinkingSplitBackward>();
|
||||||
add_matcher<ngraph::pass::TransposeFuse>();
|
add_matcher<ngraph::pass::TransposeFuse>();
|
||||||
|
@ -52,78 +52,69 @@ OutputTranspose GetOutputTransposes(NodePtr node) {
|
|||||||
return OutputTranspose();
|
return OutputTranspose();
|
||||||
}
|
}
|
||||||
|
|
||||||
NodePtr FindSplitInput(Node* node) {
|
template <typename NodeT>
|
||||||
|
std::shared_ptr<ov::Node> FindInputNode(ov::Node* node) {
|
||||||
for (size_t input_idx = 0; input_idx < node->get_input_size(); ++input_idx) {
|
for (size_t input_idx = 0; input_idx < node->get_input_size(); ++input_idx) {
|
||||||
NodePtr input_node = node->get_input_node_shared_ptr(input_idx);
|
std::shared_ptr<ov::Node> input_node = node->get_input_node_shared_ptr(input_idx);
|
||||||
auto split_node = as_type_ptr<Split>(input_node);
|
auto target_node = ov::as_type_ptr<NodeT>(input_node);
|
||||||
if (split_node)
|
if (target_node)
|
||||||
return split_node;
|
return target_node;
|
||||||
}
|
}
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Constant> GetTransposeConstant(Input<Node> input) {
|
|
||||||
auto transpose_node = dynamic_cast<Transpose*>(input.get_node());
|
|
||||||
if (!transpose_node)
|
|
||||||
return {};
|
|
||||||
|
|
||||||
if (!is_sinking_node(input.get_node()))
|
|
||||||
return {};
|
|
||||||
|
|
||||||
auto constant_node = as_type_ptr<Constant>(transpose_node->input_value(1).get_node_shared_ptr());
|
|
||||||
if (!constant_node)
|
|
||||||
return {};
|
|
||||||
|
|
||||||
return constant_node;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool HasInputSplitAndTransposeSiblings(const Output<Node>& output) {
|
bool HasInputSplitAndTransposeSiblings(const Output<Node>& output) {
|
||||||
NodePtr split_node = FindSplitInput(output.get_node());
|
NodePtr main_node = FindInputNode<Split>(output.get_node());
|
||||||
if (!split_node) {
|
if (!main_node) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
AxisVector first_transpose_axis_order;
|
return HasSameOutputTransposeNodes(main_node);
|
||||||
// get first transpose axis
|
|
||||||
{
|
|
||||||
auto constant_node = GetTransposeConstant(*(split_node->get_output_target_inputs(0).begin()));
|
|
||||||
if (!constant_node)
|
|
||||||
return false;
|
|
||||||
first_transpose_axis_order = constant_node->get_axis_vector_val();
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t output_idx = 1; output_idx < split_node->get_output_size(); ++output_idx) {
|
|
||||||
for (auto& input : split_node->get_output_target_inputs(output_idx)) {
|
|
||||||
auto constant_node = GetTransposeConstant(input);
|
|
||||||
if (!constant_node)
|
|
||||||
return false;
|
|
||||||
|
|
||||||
AxisVector transpose_axis_order = constant_node->get_axis_vector_val();
|
|
||||||
if (transpose_axis_order.size() != first_transpose_axis_order.size())
|
|
||||||
return false;
|
|
||||||
if (!std::equal(transpose_axis_order.begin(),
|
|
||||||
transpose_axis_order.end(),
|
|
||||||
first_transpose_axis_order.begin()))
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsSplitSinked(const Output<Node>& output) {
|
||||||
|
return HasInputSplitAndTransposeSiblings(output) && is_sinking_node(output);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
/*
|
||||||
|
* We follow Transpose operations rather than Split. We cannot create matcher pattern
|
||||||
|
* for Split with Transpose outputs since Split can have different number of outputs.
|
||||||
|
* We just can:
|
||||||
|
* - specify Split as searched node and check if it has transpose outputs
|
||||||
|
* - specify Transpose as searched node and check if it has Split input
|
||||||
|
* Transformations are called on each found node in sorted order from the start to end
|
||||||
|
* of the network. When we proceed Split backward sinking we move input transpose
|
||||||
|
* to the input of the Split operation.
|
||||||
|
* Consider case Split (1) -> Split (2) -> Transpose
|
||||||
|
* If specify Split as main searched node after first transformation work we will have
|
||||||
|
* Split (1) -> Transpose -> Split(2)
|
||||||
|
* Matcher pass will not call TransposeSinkingSplitBackward since
|
||||||
|
* - matcher pattern has no Transpose label
|
||||||
|
* - Split (1) has already been proceeded
|
||||||
|
* Adding Split(2) into the working queue as register_new_node(split)
|
||||||
|
* cannot help us. We just can try to find all input Split operations and add them with
|
||||||
|
* register_new_node(). Implemented way is simpler.
|
||||||
|
*
|
||||||
|
* We sink Transpose through Split operation in a backward way only if all the output
|
||||||
|
* nodes are the same Transpose. We can:
|
||||||
|
* - clone Split with all outputs except Transpose
|
||||||
|
* causes perfomance problems
|
||||||
|
* - add reversed Transpose operations on all outputs except sinking Transpose
|
||||||
|
* nothing to do with new added output Transposes
|
||||||
|
*/
|
||||||
ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
|
ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
|
||||||
MATCHER_SCOPE(TransposeSinkingSplitBackward);
|
MATCHER_SCOPE(TransposeSinkingSplitBackward);
|
||||||
|
|
||||||
auto transpose_const_label = wrap_type<Constant>(consumers_count(1));
|
auto transpose_const_label = wrap_type<Constant>();
|
||||||
auto transpose_label =
|
auto transpose_label = wrap_type<Transpose>({any_input(), transpose_const_label}, IsSplitSinked);
|
||||||
wrap_type<Transpose>({any_input(), transpose_const_label}, HasInputSplitAndTransposeSiblings);
|
|
||||||
|
|
||||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||||
auto transpose_label_node = pattern_to_output.at(transpose_label).get_node();
|
auto transpose_label_node = pattern_to_output.at(transpose_label).get_node();
|
||||||
|
|
||||||
NodePtr split = FindSplitInput(transpose_label_node);
|
NodePtr split = FindInputNode<Split>(transpose_label_node);
|
||||||
auto split_axis_constant = as_type_ptr<Constant>(split->input_value(1).get_node_shared_ptr());
|
auto split_axis_constant = as_type_ptr<Constant>(split->input_value(1).get_node_shared_ptr());
|
||||||
OutputTranspose output_transpose = GetOutputTransposes(split);
|
OutputTranspose output_transpose = GetOutputTransposes(split);
|
||||||
|
|
||||||
@ -161,11 +152,8 @@ ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
|
|||||||
new_split_axis_const);
|
new_split_axis_const);
|
||||||
|
|
||||||
// remove split output transposes
|
// remove split output transposes
|
||||||
for (size_t output_idx = 0; output_idx < split->get_output_size(); ++output_idx) {
|
RemoveSingleOutputConsumers(split);
|
||||||
for (auto& input : split->get_output_target_inputs(output_idx)) {
|
|
||||||
input.get_node()->output(0).replace(split->output(output_idx));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -186,7 +174,7 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
|
|||||||
|
|
||||||
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
||||||
|
|
||||||
sink_forward::RemoveZeroInputNode(main_node);
|
sink_forward::RemoveInputNode(main_node, /* input_idx */ 0);
|
||||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||||
register_new_node(new_node);
|
register_new_node(new_node);
|
||||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||||
|
@ -10,6 +10,10 @@
|
|||||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||||
|
|
||||||
using namespace ov;
|
using namespace ov;
|
||||||
|
using namespace ov::opset9;
|
||||||
|
using namespace ov::pass::pattern;
|
||||||
|
using namespace ov::op::util;
|
||||||
|
using namespace transpose_sinking;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@ -90,16 +94,11 @@ NodePair Swap(NodePtr first_node, NodePtr second_node) {
|
|||||||
ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
|
ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
|
||||||
MATCHER_SCOPE(TransposeSinkingUnaryForward);
|
MATCHER_SCOPE(TransposeSinkingUnaryForward);
|
||||||
|
|
||||||
auto transpose_label = ov::pass::pattern::wrap_type<ov::opset9::Transpose>(
|
auto transpose_label = wrap_type<Transpose>({any_input(), any_input()});
|
||||||
{ov::pass::pattern::any_input(), ov::pass::pattern::any_input()});
|
auto unary_label =
|
||||||
auto unary_label = ov::pass::pattern::wrap_type<ov::op::util::UnaryElementwiseArithmetic,
|
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert>({transpose_label});
|
||||||
ov::opset9::Clamp,
|
|
||||||
ov::opset9::Elu,
|
|
||||||
ov::opset9::SoftPlus,
|
|
||||||
ov::opset9::LogicalNot,
|
|
||||||
ov::opset9::Convert>({transpose_label});
|
|
||||||
|
|
||||||
ov::matcher_pass_callback matcher_pass_callback = [=](ov::pass::pattern::Matcher& m) {
|
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||||
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
|
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
|
||||||
@ -109,12 +108,12 @@ ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
|
|||||||
register_new_node(new_nodes.first);
|
register_new_node(new_nodes.first);
|
||||||
register_new_node(new_nodes.second);
|
register_new_node(new_nodes.second);
|
||||||
|
|
||||||
transpose_sinking::UpdateForwardSinkingAbility(new_nodes.second);
|
UpdateForwardSinkingAbility(new_nodes.second);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto m = std::make_shared<ov::pass::pattern::Matcher>(unary_label, "ov::pass::TransposeSinkingUnaryForward");
|
auto m = std::make_shared<Matcher>(unary_label, "ov::pass::TransposeSinkingUnaryForward");
|
||||||
register_matcher(m, matcher_pass_callback);
|
register_matcher(m, matcher_pass_callback);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -124,21 +123,16 @@ bool IfSinkingEnabled(const Output<Node>& output) {
|
|||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() {
|
ov::pass::TransposeSinkingUnaryBackwardSingleConsumer::TransposeSinkingUnaryBackwardSingleConsumer() {
|
||||||
MATCHER_SCOPE(TransposeSinkingUnaryBackward);
|
MATCHER_SCOPE(TransposeSinkingUnaryBackwardSingleConsumer);
|
||||||
|
|
||||||
auto unary_label = ov::pass::pattern::wrap_type<ov::op::util::UnaryElementwiseArithmetic,
|
auto unary_label =
|
||||||
ov::opset9::Clamp,
|
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert>({any_input()},
|
||||||
ov::opset9::Elu,
|
consumers_count(1));
|
||||||
ov::opset9::SoftPlus,
|
|
||||||
ov::opset9::LogicalNot,
|
|
||||||
ov::opset9::Convert>({ov::pass::pattern::any_input()});
|
|
||||||
|
|
||||||
auto transpose_label =
|
auto transpose_label = wrap_type<Transpose>({unary_label, any_input()}, IfSinkingEnabled);
|
||||||
ov::pass::pattern::wrap_type<ov::opset9::Transpose>({unary_label, ov::pass::pattern::any_input()},
|
|
||||||
IfSinkingEnabled);
|
|
||||||
|
|
||||||
ov::matcher_pass_callback matcher_pass_callback = [=](ov::pass::pattern::Matcher& m) {
|
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||||
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
|
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
|
||||||
@ -151,6 +145,55 @@ ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() {
|
|||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto m = std::make_shared<ov::pass::pattern::Matcher>(transpose_label, "ov::pass::TransposeSinkingUnaryBackward");
|
auto m = std::make_shared<Matcher>(transpose_label, "ov::pass::TransposeSinkingUnaryBackwardSingleConsumer");
|
||||||
register_matcher(m, matcher_pass_callback);
|
register_matcher(m, matcher_pass_callback);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
std::function<bool(Output<Node>)> consumers_more_than(size_t n) {
|
||||||
|
return [=](Output<Node> output) -> bool {
|
||||||
|
return output.get_target_inputs().size() > n;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
ov::pass::TransposeSinkingUnaryBackwardMultiConsumers::TransposeSinkingUnaryBackwardMultiConsumers() {
|
||||||
|
MATCHER_SCOPE(TransposeSinkingUnaryBackwardMultiConsumers);
|
||||||
|
|
||||||
|
auto unary_restrictions = [](const Output<Node>& output) -> bool {
|
||||||
|
return consumers_more_than(1)(output) && HasSameOutputTransposeNodes(output);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto unary_label =
|
||||||
|
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert>({any_input()},
|
||||||
|
unary_restrictions);
|
||||||
|
|
||||||
|
auto transpose_const_label = wrap_type<Constant>();
|
||||||
|
|
||||||
|
auto transpose_label = wrap_type<Transpose>({unary_label, transpose_const_label}, IfSinkingEnabled);
|
||||||
|
|
||||||
|
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||||
|
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||||
|
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||||
|
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||||
|
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
|
||||||
|
|
||||||
|
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary, transpose_const)) {
|
||||||
|
register_new_node(new_node);
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove output transposes
|
||||||
|
RemoveSingleOutputConsumers(unary);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<Matcher>(transpose_label, "ov::pass::TransposeSinkingUnaryBackwardMultiConsumers");
|
||||||
|
register_matcher(m, matcher_pass_callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() {
|
||||||
|
MATCHER_SCOPE(TransposeSinkingUnaryBackward);
|
||||||
|
add_matcher<ov::pass::TransposeSinkingUnaryBackwardSingleConsumer>();
|
||||||
|
add_matcher<ov::pass::TransposeSinkingUnaryBackwardMultiConsumers>();
|
||||||
|
}
|
||||||
|
@ -115,8 +115,7 @@ ov::Output<ov::Node> FixInputNodeRank(ov::Output<ov::Node> input_node, ov::Rank:
|
|||||||
|
|
||||||
namespace sink_forward {
|
namespace sink_forward {
|
||||||
|
|
||||||
// insert input reversed transposes, remove first input tranpose
|
void UpdateInputTransposes(NodePtr main_node, const TransposeInputsInfo& transpose_input_info) {
|
||||||
void UpdateInputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_input_info) {
|
|
||||||
if (transpose_input_info.isEmpty() || HasDynamicRankInput(main_node))
|
if (transpose_input_info.isEmpty() || HasDynamicRankInput(main_node))
|
||||||
return;
|
return;
|
||||||
|
|
||||||
@ -149,15 +148,15 @@ void UpdateInputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_inp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void RemoveZeroInputNode(NodePtr main_node) {
|
void RemoveInputNode(NodePtr main_node, size_t input_idx) {
|
||||||
auto input_node = main_node->input_value(0);
|
auto input_node = main_node->input_value(input_idx);
|
||||||
if (input_node.get_node()->get_input_size() < 1)
|
if (input_node.get_node()->get_input_size() < (input_idx + 1))
|
||||||
return;
|
return;
|
||||||
auto parent_node = input_node.get_node()->input_value(0);
|
auto parent_node = input_node.get_node()->input_value(input_idx);
|
||||||
main_node->input(0).replace_source_output(parent_node);
|
main_node->input(input_idx).replace_source_output(parent_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeVector InsertOutputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_input_info) {
|
NodeVector InsertOutputTransposes(NodePtr main_node, const TransposeInputsInfo& transpose_input_info) {
|
||||||
if (transpose_input_info.isEmpty())
|
if (transpose_input_info.isEmpty())
|
||||||
return {};
|
return {};
|
||||||
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||||
@ -246,6 +245,7 @@ bool CanPropagateForwardThrough(Node* node) {
|
|||||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(Concat, node);
|
CHECK_TRANSPOSE_SINKING_SUPPORTED(Concat, node);
|
||||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(Split, node);
|
CHECK_TRANSPOSE_SINKING_SUPPORTED(Split, node);
|
||||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(Transpose, node);
|
CHECK_TRANSPOSE_SINKING_SUPPORTED(Transpose, node);
|
||||||
|
CHECK_TRANSPOSE_SINKING_SUPPORTED(PRelu, node);
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -268,4 +268,76 @@ void UpdateForwardSinkingAbility(NodePtr node) {
|
|||||||
mark_as_no_sinking_node(node);
|
mark_as_no_sinking_node(node);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
std::shared_ptr<Constant> GetTransposeConstant(Node* node) {
|
||||||
|
auto transpose_node = dynamic_cast<Transpose*>(node);
|
||||||
|
if (!transpose_node)
|
||||||
|
return {};
|
||||||
|
|
||||||
|
auto constant_node = as_type_ptr<Constant>(transpose_node->input_value(1).get_node_shared_ptr());
|
||||||
|
if (!constant_node)
|
||||||
|
return {};
|
||||||
|
|
||||||
|
return constant_node;
|
||||||
|
}
|
||||||
|
|
||||||
|
Node* FindFirstConsumer(NodePtr node) {
|
||||||
|
for (size_t output_idx = 0; output_idx < node->get_output_size(); ++output_idx) {
|
||||||
|
auto inputs = node->get_output_target_inputs(output_idx);
|
||||||
|
if (inputs.empty())
|
||||||
|
continue;
|
||||||
|
return inputs.begin()->get_node();
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool HasSameOutputTransposeNodes(NodePtr main_node) {
|
||||||
|
AxisVector first_transpose_axis_order;
|
||||||
|
{
|
||||||
|
Node* first_consumer = FindFirstConsumer(main_node);
|
||||||
|
if (!first_consumer)
|
||||||
|
return false;
|
||||||
|
auto constant_node = GetTransposeConstant(first_consumer);
|
||||||
|
if (!constant_node)
|
||||||
|
return false;
|
||||||
|
first_transpose_axis_order = constant_node->get_axis_vector_val();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t output_idx = 0; output_idx < main_node->get_output_size(); ++output_idx) {
|
||||||
|
for (auto& input : main_node->get_output_target_inputs(output_idx)) {
|
||||||
|
auto constant_node = GetTransposeConstant(input.get_node());
|
||||||
|
if (!constant_node)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
AxisVector transpose_axis_order = constant_node->get_axis_vector_val();
|
||||||
|
if (transpose_axis_order.size() != first_transpose_axis_order.size())
|
||||||
|
return false;
|
||||||
|
if (!std::equal(transpose_axis_order.begin(),
|
||||||
|
transpose_axis_order.end(),
|
||||||
|
first_transpose_axis_order.begin()))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool HasSameOutputTransposeNodes(const Output<Node>& output) {
|
||||||
|
return HasSameOutputTransposeNodes(output.get_node_shared_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
|
void RemoveSingleOutputConsumers(NodePtr node) {
|
||||||
|
for (size_t output_idx = 0; output_idx < node->get_output_size(); ++output_idx) {
|
||||||
|
for (auto& input : node->get_output_target_inputs(output_idx)) {
|
||||||
|
Node* consumer = input.get_node();
|
||||||
|
if (consumer->get_output_size() != 1)
|
||||||
|
continue;
|
||||||
|
consumer->output(0).replace(node->output(output_idx));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace transpose_sinking
|
} // namespace transpose_sinking
|
||||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -12,13 +12,37 @@
|
|||||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
|
using namespace ov;
|
||||||
|
using namespace ov::opset9;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
std::string to_string(const Shape& shape) {
|
||||||
|
std::ostringstream result;
|
||||||
|
result << "{";
|
||||||
|
for (size_t idx = 0; idx < shape.size(); ++idx) {
|
||||||
|
if (idx)
|
||||||
|
result << ",";
|
||||||
|
result << shape[idx];
|
||||||
|
}
|
||||||
|
result << "}";
|
||||||
|
return result.str();
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
using NodePtr = std::shared_ptr<ov::Node>;
|
using NodePtr = std::shared_ptr<ov::Node>;
|
||||||
|
|
||||||
class IUnaryFactory {
|
class IUnaryFactory {
|
||||||
public:
|
public:
|
||||||
IUnaryFactory() = default;
|
IUnaryFactory(const std::string& type_name) : type_name_(type_name) {}
|
||||||
virtual ~IUnaryFactory() = default;
|
virtual ~IUnaryFactory() = default;
|
||||||
virtual NodePtr create(NodePtr parent_node) const = 0;
|
virtual NodePtr create(NodePtr parent_node) const = 0;
|
||||||
|
|
||||||
|
const std::string& getTypeName() const {
|
||||||
|
return type_name_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const std::string type_name_;
|
||||||
};
|
};
|
||||||
|
|
||||||
using UnaryFactoryPtr = std::shared_ptr<IUnaryFactory>;
|
using UnaryFactoryPtr = std::shared_ptr<IUnaryFactory>;
|
||||||
@ -26,39 +50,45 @@ using UnaryFactoryPtr = std::shared_ptr<IUnaryFactory>;
|
|||||||
template <typename UnaryT>
|
template <typename UnaryT>
|
||||||
class UnaryFactory : public IUnaryFactory {
|
class UnaryFactory : public IUnaryFactory {
|
||||||
public:
|
public:
|
||||||
UnaryFactory() = default;
|
UnaryFactory(const std::string& type_name) : IUnaryFactory(type_name) {}
|
||||||
NodePtr create(NodePtr parent_node) const override {
|
NodePtr create(NodePtr parent_node) const override {
|
||||||
return std::make_shared<UnaryT>(parent_node);
|
return std::make_shared<UnaryT>(parent_node);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
NodePtr UnaryFactory<ov::opset9::Elu>::create(NodePtr parent_node) const {
|
NodePtr UnaryFactory<Elu>::create(NodePtr parent_node) const {
|
||||||
return std::make_shared<ov::opset9::Elu>(parent_node, 0.1);
|
return std::make_shared<Elu>(parent_node, 0.1);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
NodePtr UnaryFactory<ov::opset9::Clamp>::create(NodePtr parent_node) const {
|
NodePtr UnaryFactory<Clamp>::create(NodePtr parent_node) const {
|
||||||
return std::make_shared<ov::opset9::Clamp>(parent_node, 0.1, 0.2);
|
return std::make_shared<Clamp>(parent_node, 0.1, 0.2);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
NodePtr UnaryFactory<ov::opset9::Convert>::create(NodePtr parent_node) const {
|
NodePtr UnaryFactory<Convert>::create(NodePtr parent_node) const {
|
||||||
return std::make_shared<ov::opset9::Convert>(parent_node, ov::element::f64);
|
return std::make_shared<Convert>(parent_node, element::f64);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename UnaryT>
|
template <typename UnaryT>
|
||||||
UnaryFactoryPtr CreateUnaryFactory() {
|
UnaryFactoryPtr CreateUnaryFactory(const std::string& type_name) {
|
||||||
return std::make_shared<UnaryFactory<UnaryT>>();
|
return std::make_shared<UnaryFactory<UnaryT>>(type_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
class IPassFactory {
|
class IPassFactory {
|
||||||
public:
|
public:
|
||||||
IPassFactory() = default;
|
IPassFactory(const std::string& type_name) : type_name_(type_name) {}
|
||||||
virtual ~IPassFactory() = default;
|
virtual ~IPassFactory() = default;
|
||||||
virtual void registerPass(ov::pass::Manager& pass_manager) const = 0;
|
virtual void registerPass(ov::pass::Manager& pass_manager) const = 0;
|
||||||
|
const std::string& getTypeName() const {
|
||||||
|
return type_name_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const std::string type_name_;
|
||||||
};
|
};
|
||||||
|
|
||||||
using PassFactoryPtr = std::shared_ptr<IPassFactory>;
|
using PassFactoryPtr = std::shared_ptr<IPassFactory>;
|
||||||
@ -66,15 +96,16 @@ using PassFactoryPtr = std::shared_ptr<IPassFactory>;
|
|||||||
template <typename PassT>
|
template <typename PassT>
|
||||||
class PassFactory : public IPassFactory {
|
class PassFactory : public IPassFactory {
|
||||||
public:
|
public:
|
||||||
|
PassFactory(const std::string& type_name) : IPassFactory(type_name) {}
|
||||||
void registerPass(ov::pass::Manager& pass_manager) const override {
|
void registerPass(ov::pass::Manager& pass_manager) const override {
|
||||||
pass_manager.register_pass<PassT>();
|
pass_manager.register_pass<PassT>();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename PassT>
|
#define CREATE_PASS_FACTORY(pass_name) std::make_shared<PassFactory<ov::pass::pass_name>>(#pass_name)
|
||||||
PassFactoryPtr CreatePassFactory() {
|
|
||||||
return std::make_shared<PassFactory<PassT>>();
|
#undef CREATE_UNARY_FACTORY
|
||||||
}
|
#define CREATE_UNARY_FACTORY(type_name) CreateUnaryFactory<type_name>(#type_name)
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
@ -82,19 +113,46 @@ using FloatPtr = std::unique_ptr<float[]>;
|
|||||||
|
|
||||||
using CreateGraphF = std::function<std::shared_ptr<ov::Model>(UnaryFactoryPtr unary_factory,
|
using CreateGraphF = std::function<std::shared_ptr<ov::Model>(UnaryFactoryPtr unary_factory,
|
||||||
size_t num_unary_ops,
|
size_t num_unary_ops,
|
||||||
const ov::Shape& input_shape,
|
const Shape& input_shape,
|
||||||
ov::element::Type input_type)>;
|
element::Type input_type)>;
|
||||||
|
|
||||||
using TestParams = std::tuple<UnaryFactoryPtr,
|
using TestParams = std::tuple<UnaryFactoryPtr,
|
||||||
PassFactoryPtr,
|
PassFactoryPtr,
|
||||||
size_t, /* num_unary_ops */
|
size_t, /* num_unary_ops */
|
||||||
CreateGraphF, /* model_factory */
|
CreateGraphF, /* model_factory */
|
||||||
CreateGraphF, /* reference_model_factory */
|
CreateGraphF, /* reference_model_factory */
|
||||||
ov::Shape, /* input shape */
|
Shape, /* input shape */
|
||||||
ov::element::Type>; /* input type */
|
element::Type>; /* input type */
|
||||||
|
|
||||||
class TransposeSinkingUnaryTestFixture : public ::testing::WithParamInterface<TestParams>,
|
class TransposeSinkingUnaryTestFixture : public ::testing::WithParamInterface<TestParams>, public TransformationTestsF {
|
||||||
public TransformationTestsF {};
|
public:
|
||||||
|
static std::string get_test_name(const testing::TestParamInfo<TestParams>& obj) {
|
||||||
|
UnaryFactoryPtr unary_factory;
|
||||||
|
PassFactoryPtr pass_factory;
|
||||||
|
size_t num_unary_ops;
|
||||||
|
CreateGraphF model_factory;
|
||||||
|
CreateGraphF reference_model_factory;
|
||||||
|
Shape input_shape;
|
||||||
|
element::Type input_type;
|
||||||
|
std::tie(unary_factory,
|
||||||
|
pass_factory,
|
||||||
|
num_unary_ops,
|
||||||
|
model_factory,
|
||||||
|
reference_model_factory,
|
||||||
|
input_shape,
|
||||||
|
input_type) = obj.param;
|
||||||
|
|
||||||
|
std::ostringstream test_name;
|
||||||
|
test_name << "unaryFactory=" << unary_factory->getTypeName() << "/";
|
||||||
|
test_name << "numUnaryOps=" << num_unary_ops << "/";
|
||||||
|
test_name << "inputShape=" << to_string(input_shape) << "/";
|
||||||
|
test_name << "unaryFactory=" << unary_factory->getTypeName() << "/";
|
||||||
|
test_name << "passFactory=" << pass_factory->getTypeName() << "/";
|
||||||
|
test_name << "inputType=" << input_type;
|
||||||
|
|
||||||
|
return test_name.str();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@ -105,12 +163,12 @@ std::string GetFinalNodeName(std::shared_ptr<ov::Model> model, int index = 0) {
|
|||||||
|
|
||||||
std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory,
|
std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory,
|
||||||
size_t num_unary_ops,
|
size_t num_unary_ops,
|
||||||
const ov::Shape& input_shape,
|
const Shape& input_shape,
|
||||||
ov::element::Type input_type) {
|
element::Type input_type) {
|
||||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
auto X = std::make_shared<Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
|
||||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
auto transpose0 = std::make_shared<Transpose>(X, ng_order0);
|
||||||
|
|
||||||
NodePtr in_op = transpose0;
|
NodePtr in_op = transpose0;
|
||||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||||
@ -122,25 +180,25 @@ std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_f
|
|||||||
|
|
||||||
std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory,
|
std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory,
|
||||||
size_t num_unary_ops,
|
size_t num_unary_ops,
|
||||||
const ov::Shape& input_shape,
|
const Shape& input_shape,
|
||||||
ov::element::Type input_type) {
|
element::Type input_type) {
|
||||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
auto X = std::make_shared<Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
NodePtr in_op = X;
|
NodePtr in_op = X;
|
||||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||||
in_op = unary_factory->create(in_op);
|
in_op = unary_factory->create(in_op);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
|
||||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
auto transpose0 = std::make_shared<Transpose>(in_op, ng_order0);
|
||||||
|
|
||||||
return std::make_shared<ov::Model>(transpose0, ov::ParameterVector{X});
|
return std::make_shared<ov::Model>(transpose0, ov::ParameterVector{X});
|
||||||
}
|
}
|
||||||
|
|
||||||
static NodePtr CreateReshape(NodePtr parent_node, const ov::Shape& input_shape) {
|
static NodePtr CreateReshape(NodePtr parent_node, const Shape& input_shape) {
|
||||||
const size_t mul = std::accumulate(input_shape.begin(), input_shape.end(), (size_t)1, std::multiplies<size_t>());
|
const size_t mul = std::accumulate(input_shape.begin(), input_shape.end(), (size_t)1, std::multiplies<size_t>());
|
||||||
auto reshape_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{1}, ov::Shape{mul});
|
auto reshape_const = std::make_shared<Constant>(element::u64, Shape{1}, Shape{mul});
|
||||||
return std::make_shared<ov::opset9::Reshape>(parent_node, reshape_const, false);
|
return std::make_shared<Reshape>(parent_node, reshape_const, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace mult_consumers_last_node {
|
namespace mult_consumers_last_node {
|
||||||
@ -148,17 +206,17 @@ namespace with_reshape {
|
|||||||
|
|
||||||
std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory,
|
std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory,
|
||||||
size_t num_unary_ops,
|
size_t num_unary_ops,
|
||||||
const ov::Shape& input_shape,
|
const Shape& input_shape,
|
||||||
ov::element::Type input_type) {
|
element::Type input_type) {
|
||||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
auto X = std::make_shared<Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
NodePtr in_op = X;
|
NodePtr in_op = X;
|
||||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||||
in_op = unary_factory->create(in_op);
|
in_op = unary_factory->create(in_op);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
|
||||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
auto transpose0 = std::make_shared<Transpose>(in_op, ng_order0);
|
||||||
|
|
||||||
auto reshape1 = CreateReshape(transpose0, input_shape);
|
auto reshape1 = CreateReshape(transpose0, input_shape);
|
||||||
auto reshape2 = CreateReshape(transpose0, input_shape);
|
auto reshape2 = CreateReshape(transpose0, input_shape);
|
||||||
@ -168,12 +226,12 @@ std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_fa
|
|||||||
|
|
||||||
std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory,
|
std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory,
|
||||||
size_t num_unary_ops,
|
size_t num_unary_ops,
|
||||||
const ov::Shape& input_shape,
|
const Shape& input_shape,
|
||||||
ov::element::Type input_type) {
|
element::Type input_type) {
|
||||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
auto X = std::make_shared<Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
|
||||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
auto transpose0 = std::make_shared<Transpose>(X, ng_order0);
|
||||||
|
|
||||||
NodePtr in_op = transpose0;
|
NodePtr in_op = transpose0;
|
||||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||||
@ -191,44 +249,44 @@ namespace with_eltwise {
|
|||||||
|
|
||||||
std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory,
|
std::shared_ptr<ov::Model> CreateFunctionTransposeAfter(UnaryFactoryPtr unary_factory,
|
||||||
size_t num_unary_ops,
|
size_t num_unary_ops,
|
||||||
const ov::Shape& input_shape,
|
const Shape& input_shape,
|
||||||
ov::element::Type input_type) {
|
element::Type input_type) {
|
||||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
auto X = std::make_shared<Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
NodePtr in_op = X;
|
NodePtr in_op = X;
|
||||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||||
in_op = unary_factory->create(in_op);
|
in_op = unary_factory->create(in_op);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto sinh = std::make_shared<ov::opset9::Sinh>(in_op);
|
auto sinh = std::make_shared<Sinh>(in_op);
|
||||||
|
|
||||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
|
||||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(sinh, ng_order0);
|
auto transpose0 = std::make_shared<Transpose>(sinh, ng_order0);
|
||||||
|
|
||||||
auto cosh = std::make_shared<ov::opset9::Cosh>(in_op);
|
auto cosh = std::make_shared<Cosh>(in_op);
|
||||||
|
|
||||||
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
auto ng_order1 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
|
||||||
auto transpose1 = std::make_shared<ov::opset9::Transpose>(cosh, ng_order1);
|
auto transpose1 = std::make_shared<Transpose>(cosh, ng_order1);
|
||||||
|
|
||||||
return std::make_shared<ov::Model>(ov::OutputVector{transpose0, transpose1}, ov::ParameterVector{X});
|
return std::make_shared<ov::Model>(ov::OutputVector{transpose0, transpose1}, ov::ParameterVector{X});
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory,
|
std::shared_ptr<ov::Model> CreateFunctionTransposeBefore(UnaryFactoryPtr unary_factory,
|
||||||
size_t num_unary_ops,
|
size_t num_unary_ops,
|
||||||
const ov::Shape& input_shape,
|
const Shape& input_shape,
|
||||||
ov::element::Type input_type) {
|
element::Type input_type) {
|
||||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
auto X = std::make_shared<Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
|
||||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
auto transpose0 = std::make_shared<Transpose>(X, ng_order0);
|
||||||
|
|
||||||
NodePtr in_op = transpose0;
|
NodePtr in_op = transpose0;
|
||||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||||
in_op = unary_factory->create(in_op);
|
in_op = unary_factory->create(in_op);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto sinh = std::make_shared<ov::opset9::Sinh>(in_op);
|
auto sinh = std::make_shared<Sinh>(in_op);
|
||||||
auto cosh = std::make_shared<ov::opset9::Cosh>(in_op);
|
auto cosh = std::make_shared<Cosh>(in_op);
|
||||||
|
|
||||||
return std::make_shared<ov::Model>(ov::OutputVector{sinh, cosh}, ov::ParameterVector{X});
|
return std::make_shared<ov::Model>(ov::OutputVector{sinh, cosh}, ov::ParameterVector{X});
|
||||||
}
|
}
|
||||||
@ -241,66 +299,87 @@ namespace backward {
|
|||||||
|
|
||||||
std::shared_ptr<ov::Model> CreateFunction(UnaryFactoryPtr unary_factory,
|
std::shared_ptr<ov::Model> CreateFunction(UnaryFactoryPtr unary_factory,
|
||||||
size_t num_unary_ops,
|
size_t num_unary_ops,
|
||||||
const ov::Shape& input_shape,
|
const Shape& input_shape,
|
||||||
ov::element::Type input_type) {
|
element::Type input_type) {
|
||||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
auto X = std::make_shared<Parameter>(input_type, input_shape);
|
||||||
ov::OutputVector outputs;
|
ov::OutputVector outputs;
|
||||||
|
|
||||||
NodePtr in_op = X;
|
NodePtr in_op = X;
|
||||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||||
in_op = unary_factory->create(in_op);
|
in_op = unary_factory->create(in_op);
|
||||||
auto cosh = std::make_shared<ov::opset9::Cosh>(in_op);
|
auto cosh = std::make_shared<Cosh>(in_op);
|
||||||
outputs.push_back(cosh);
|
outputs.push_back(cosh);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
|
||||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
auto transpose0 = std::make_shared<Transpose>(in_op, ng_order0);
|
||||||
outputs.push_back(transpose0);
|
outputs.push_back(transpose0);
|
||||||
|
|
||||||
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
|
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ov::Model> CreateReferenceFunction(UnaryFactoryPtr unary_factory,
|
|
||||||
size_t num_unary_ops,
|
|
||||||
const ov::Shape& input_shape,
|
|
||||||
ov::element::Type input_type) {
|
|
||||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
|
||||||
ov::OutputVector outputs;
|
|
||||||
|
|
||||||
NodePtr in_op = X;
|
|
||||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
|
||||||
in_op = unary_factory->create(in_op);
|
|
||||||
auto cosh = std::make_shared<ov::opset9::Cosh>(in_op);
|
|
||||||
outputs.push_back(cosh);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
|
||||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
|
||||||
|
|
||||||
in_op = transpose0;
|
|
||||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
|
||||||
in_op = unary_factory->create(in_op);
|
|
||||||
}
|
|
||||||
|
|
||||||
outputs.push_back(in_op);
|
|
||||||
|
|
||||||
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace backward
|
} // namespace backward
|
||||||
|
|
||||||
|
namespace backward_mult_transposes {
|
||||||
|
|
||||||
|
std::shared_ptr<ov::Model> CreateFunction(UnaryFactoryPtr unary_factory,
|
||||||
|
size_t num_unary_ops,
|
||||||
|
const Shape& input_shape,
|
||||||
|
element::Type input_type) {
|
||||||
|
auto X = std::make_shared<Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
|
NodePtr in_op = X;
|
||||||
|
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||||
|
in_op = unary_factory->create(in_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
|
||||||
|
auto transpose0 = std::make_shared<Transpose>(in_op, ng_order0);
|
||||||
|
|
||||||
|
auto tanh0 = std::make_shared<Tanh>(transpose0);
|
||||||
|
|
||||||
|
auto ng_order1 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
|
||||||
|
auto transpose1 = std::make_shared<Transpose>(in_op, ng_order1);
|
||||||
|
|
||||||
|
auto tanh1 = std::make_shared<Tanh>(transpose1);
|
||||||
|
|
||||||
|
return std::make_shared<ov::Model>(ov::OutputVector{tanh0, tanh1}, ov::ParameterVector{X});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ov::Model> CreateReferenceFunction(UnaryFactoryPtr unary_factory,
|
||||||
|
size_t num_unary_ops,
|
||||||
|
const Shape& input_shape,
|
||||||
|
element::Type input_type) {
|
||||||
|
auto X = std::make_shared<Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
|
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
|
||||||
|
auto transpose0 = std::make_shared<Transpose>(X, ng_order0);
|
||||||
|
|
||||||
|
NodePtr in_op = transpose0;
|
||||||
|
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||||
|
in_op = unary_factory->create(in_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tanh0 = std::make_shared<Tanh>(in_op);
|
||||||
|
auto tanh1 = std::make_shared<Tanh>(in_op);
|
||||||
|
|
||||||
|
return std::make_shared<ov::Model>(ov::OutputVector{tanh0, tanh1}, ov::ParameterVector{X});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace backward_mult_transposes
|
||||||
|
|
||||||
namespace forward {
|
namespace forward {
|
||||||
|
|
||||||
std::shared_ptr<ov::Model> CreateFunction(UnaryFactoryPtr unary_factory,
|
std::shared_ptr<ov::Model> CreateFunction(UnaryFactoryPtr unary_factory,
|
||||||
size_t num_unary_ops,
|
size_t num_unary_ops,
|
||||||
const ov::Shape& input_shape,
|
const Shape& input_shape,
|
||||||
ov::element::Type input_type) {
|
element::Type input_type) {
|
||||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
auto X = std::make_shared<Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
auto sinh = std::make_shared<ov::opset9::Sinh>(X);
|
auto sinh = std::make_shared<Sinh>(X);
|
||||||
|
|
||||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
|
||||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(sinh, ng_order0);
|
auto transpose0 = std::make_shared<Transpose>(sinh, ng_order0);
|
||||||
|
|
||||||
auto reshape = CreateReshape(transpose0, input_shape);
|
auto reshape = CreateReshape(transpose0, input_shape);
|
||||||
|
|
||||||
@ -314,14 +393,14 @@ std::shared_ptr<ov::Model> CreateFunction(UnaryFactoryPtr unary_factory,
|
|||||||
|
|
||||||
std::shared_ptr<ov::Model> CreateReferenceFunction(UnaryFactoryPtr unary_factory,
|
std::shared_ptr<ov::Model> CreateReferenceFunction(UnaryFactoryPtr unary_factory,
|
||||||
size_t num_unary_ops,
|
size_t num_unary_ops,
|
||||||
const ov::Shape& input_shape,
|
const Shape& input_shape,
|
||||||
ov::element::Type input_type) {
|
element::Type input_type) {
|
||||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
auto X = std::make_shared<Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
auto sinh = std::make_shared<ov::opset9::Sinh>(X);
|
auto sinh = std::make_shared<Sinh>(X);
|
||||||
|
|
||||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
auto ng_order0 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
|
||||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(sinh, ng_order0);
|
auto transpose0 = std::make_shared<Transpose>(sinh, ng_order0);
|
||||||
auto reshape = CreateReshape(transpose0, input_shape);
|
auto reshape = CreateReshape(transpose0, input_shape);
|
||||||
|
|
||||||
NodePtr in_op = sinh;
|
NodePtr in_op = sinh;
|
||||||
@ -329,8 +408,8 @@ std::shared_ptr<ov::Model> CreateReferenceFunction(UnaryFactoryPtr unary_factory
|
|||||||
in_op = unary_factory->create(in_op);
|
in_op = unary_factory->create(in_op);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
auto ng_order1 = std::make_shared<Constant>(element::u64, Shape{4}, Shape{0, 2, 3, 1});
|
||||||
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order1);
|
auto transpose1 = std::make_shared<Transpose>(in_op, ng_order1);
|
||||||
|
|
||||||
return std::make_shared<ov::Model>(ov::OutputVector{transpose1, reshape}, ov::ParameterVector{X});
|
return std::make_shared<ov::Model>(ov::OutputVector{transpose1, reshape}, ov::ParameterVector{X});
|
||||||
}
|
}
|
||||||
@ -339,21 +418,16 @@ std::shared_ptr<ov::Model> CreateReferenceFunction(UnaryFactoryPtr unary_factory
|
|||||||
} // namespace mult_consumers_first_node
|
} // namespace mult_consumers_first_node
|
||||||
|
|
||||||
std::vector<UnaryFactoryPtr> unary_factories = {
|
std::vector<UnaryFactoryPtr> unary_factories = {
|
||||||
CreateUnaryFactory<ov::opset9::Clamp>(), CreateUnaryFactory<ov::opset9::Elu>(),
|
CREATE_UNARY_FACTORY(Clamp), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(SoftPlus),
|
||||||
CreateUnaryFactory<ov::opset9::SoftPlus>(), CreateUnaryFactory<ov::opset9::LogicalNot>(),
|
CREATE_UNARY_FACTORY(LogicalNot), CREATE_UNARY_FACTORY(Convert), CREATE_UNARY_FACTORY(Abs),
|
||||||
CreateUnaryFactory<ov::opset9::Convert>(), CreateUnaryFactory<ov::opset9::Abs>(),
|
CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh),
|
||||||
CreateUnaryFactory<ov::opset9::Acos>(), CreateUnaryFactory<ov::opset9::Asin>(),
|
CREATE_UNARY_FACTORY(Atan), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Cos),
|
||||||
CreateUnaryFactory<ov::opset9::Asinh>(), CreateUnaryFactory<ov::opset9::Atan>(),
|
CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Exp),
|
||||||
CreateUnaryFactory<ov::opset9::Ceiling>(), CreateUnaryFactory<ov::opset9::Cos>(),
|
CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid), CREATE_UNARY_FACTORY(HSwish),
|
||||||
CreateUnaryFactory<ov::opset9::Cosh>(), CreateUnaryFactory<ov::opset9::Erf>(),
|
CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu),
|
||||||
CreateUnaryFactory<ov::opset9::Exp>(), CreateUnaryFactory<ov::opset9::Gelu>(),
|
CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin),
|
||||||
CreateUnaryFactory<ov::opset9::HSigmoid>(), CreateUnaryFactory<ov::opset9::HSwish>(),
|
CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftSign), CREATE_UNARY_FACTORY(Sqrt),
|
||||||
CreateUnaryFactory<ov::opset9::Log>(), CreateUnaryFactory<ov::opset9::Negative>(),
|
CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh)};
|
||||||
CreateUnaryFactory<ov::opset9::Relu>(), CreateUnaryFactory<ov::opset9::Sigmoid>(),
|
|
||||||
CreateUnaryFactory<ov::opset9::Sign>(), CreateUnaryFactory<ov::opset9::Sin>(),
|
|
||||||
CreateUnaryFactory<ov::opset9::Sinh>(), CreateUnaryFactory<ov::opset9::SoftSign>(),
|
|
||||||
CreateUnaryFactory<ov::opset9::Sqrt>(), CreateUnaryFactory<ov::opset9::Tan>(),
|
|
||||||
CreateUnaryFactory<ov::opset9::Tanh>()};
|
|
||||||
|
|
||||||
std::vector<size_t> unary_operations_numbers = {1, 10};
|
std::vector<size_t> unary_operations_numbers = {1, 10};
|
||||||
|
|
||||||
@ -365,8 +439,8 @@ TEST_P(TransposeSinkingUnaryTestFixture, CompareFunctions) {
|
|||||||
size_t num_unary_ops;
|
size_t num_unary_ops;
|
||||||
CreateGraphF model_factory;
|
CreateGraphF model_factory;
|
||||||
CreateGraphF reference_model_factory;
|
CreateGraphF reference_model_factory;
|
||||||
ov::Shape input_shape;
|
Shape input_shape;
|
||||||
ov::element::Type input_type;
|
element::Type input_type;
|
||||||
std::tie(unary_factory,
|
std::tie(unary_factory,
|
||||||
pass_factory,
|
pass_factory,
|
||||||
num_unary_ops,
|
num_unary_ops,
|
||||||
@ -380,90 +454,95 @@ TEST_P(TransposeSinkingUnaryTestFixture, CompareFunctions) {
|
|||||||
pass_factory->registerPass(manager);
|
pass_factory->registerPass(manager);
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardTestSuite,
|
||||||
TransposeSinkingUnaryForwardTestSuite,
|
|
||||||
TransposeSinkingUnaryTestFixture,
|
TransposeSinkingUnaryTestFixture,
|
||||||
::testing::Combine(::testing::ValuesIn(unary_factories),
|
::testing::Combine(::testing::ValuesIn(unary_factories),
|
||||||
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryForward>()),
|
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
|
||||||
::testing::ValuesIn(unary_operations_numbers),
|
::testing::ValuesIn(unary_operations_numbers),
|
||||||
::testing::Values(CreateFunctionTransposeBefore),
|
::testing::Values(CreateFunctionTransposeBefore),
|
||||||
::testing::Values(CreateFunctionTransposeAfter),
|
::testing::Values(CreateFunctionTransposeAfter),
|
||||||
::testing::Values(ov::Shape{1, 96, 55, 55}),
|
::testing::Values(Shape{1, 96, 55, 55}),
|
||||||
::testing::Values(ov::element::f32)));
|
::testing::Values(element::f32)),
|
||||||
|
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardTestSuite,
|
||||||
TransposeSinkingUnaryBackwardTestSuite,
|
|
||||||
TransposeSinkingUnaryTestFixture,
|
TransposeSinkingUnaryTestFixture,
|
||||||
::testing::Combine(::testing::ValuesIn(unary_factories),
|
::testing::Combine(::testing::ValuesIn(unary_factories),
|
||||||
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryBackward>()),
|
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
|
||||||
::testing::ValuesIn(unary_operations_numbers),
|
::testing::ValuesIn(unary_operations_numbers),
|
||||||
::testing::Values(CreateFunctionTransposeAfter),
|
::testing::Values(CreateFunctionTransposeAfter),
|
||||||
::testing::Values(CreateFunctionTransposeBefore),
|
::testing::Values(CreateFunctionTransposeBefore),
|
||||||
::testing::Values(ov::Shape{1, 96, 55, 55}),
|
::testing::Values(Shape{1, 96, 55, 55}),
|
||||||
::testing::Values(ov::element::f32)));
|
::testing::Values(element::f32)),
|
||||||
|
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeReshape,
|
TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeReshape,
|
||||||
TransposeSinkingUnaryTestFixture,
|
TransposeSinkingUnaryTestFixture,
|
||||||
::testing::Combine(::testing::ValuesIn(unary_factories),
|
::testing::Combine(::testing::ValuesIn(unary_factories),
|
||||||
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryForward>()),
|
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
|
||||||
::testing::ValuesIn(unary_operations_numbers),
|
::testing::ValuesIn(unary_operations_numbers),
|
||||||
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore),
|
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore),
|
||||||
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter),
|
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter),
|
||||||
::testing::Values(ov::Shape{1, 96, 55, 55}),
|
::testing::Values(Shape{1, 96, 55, 55}),
|
||||||
::testing::Values(ov::element::f32)));
|
::testing::Values(element::f32)),
|
||||||
|
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
TransposeSinkingUnaryBackwardMultConsumersTestSuiteLastNodeReshape,
|
TransposeSinkingUnaryBackwardMultConsumersTestSuiteLastNodeReshape,
|
||||||
TransposeSinkingUnaryTestFixture,
|
TransposeSinkingUnaryTestFixture,
|
||||||
::testing::Combine(::testing::ValuesIn(unary_factories),
|
::testing::Combine(::testing::ValuesIn(unary_factories),
|
||||||
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryBackward>()),
|
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
|
||||||
::testing::ValuesIn(unary_operations_numbers),
|
::testing::ValuesIn(unary_operations_numbers),
|
||||||
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter),
|
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter),
|
||||||
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore),
|
::testing::Values(mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore),
|
||||||
::testing::Values(ov::Shape{1, 96, 55, 55}),
|
::testing::Values(Shape{1, 96, 55, 55}),
|
||||||
::testing::Values(ov::element::f32)));
|
::testing::Values(element::f32)),
|
||||||
|
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeEltwise,
|
TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeEltwise,
|
||||||
TransposeSinkingUnaryTestFixture,
|
TransposeSinkingUnaryTestFixture,
|
||||||
::testing::Combine(::testing::ValuesIn(unary_factories),
|
::testing::Combine(::testing::ValuesIn(unary_factories),
|
||||||
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryForward>()),
|
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
|
||||||
::testing::ValuesIn(unary_operations_numbers),
|
::testing::ValuesIn(unary_operations_numbers),
|
||||||
::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore),
|
::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore),
|
||||||
::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter),
|
::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter),
|
||||||
::testing::Values(ov::Shape{1, 96, 55, 55}),
|
::testing::Values(Shape{1, 96, 55, 55}),
|
||||||
::testing::Values(ov::element::f32)));
|
::testing::Values(element::f32)),
|
||||||
|
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
|
||||||
TransposeSinkingUnaryBackwardMultConsumersTestSuiteLastNodeEltwise,
|
|
||||||
TransposeSinkingUnaryTestFixture,
|
|
||||||
::testing::Combine(::testing::ValuesIn(unary_factories),
|
|
||||||
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryBackward>()),
|
|
||||||
::testing::ValuesIn(unary_operations_numbers),
|
|
||||||
::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeAfter),
|
|
||||||
::testing::Values(mult_consumers_last_node::with_eltwise::CreateFunctionTransposeBefore),
|
|
||||||
::testing::Values(ov::Shape{1, 96, 55, 55}),
|
|
||||||
::testing::Values(ov::element::f32)));
|
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
TransposeSinkingUnaryForwardMultConsumersTestSuiteFirstNode,
|
TransposeSinkingUnaryForwardMultConsumersTestSuiteFirstNode,
|
||||||
TransposeSinkingUnaryTestFixture,
|
TransposeSinkingUnaryTestFixture,
|
||||||
::testing::Combine(::testing::ValuesIn(unary_factories),
|
::testing::Combine(::testing::ValuesIn(unary_factories),
|
||||||
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryForward>()),
|
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryForward)),
|
||||||
::testing::ValuesIn(unary_operations_numbers),
|
::testing::ValuesIn(unary_operations_numbers),
|
||||||
::testing::Values(mult_consumers_first_node::forward::CreateFunction),
|
::testing::Values(mult_consumers_first_node::forward::CreateFunction),
|
||||||
::testing::Values(mult_consumers_first_node::forward::CreateReferenceFunction),
|
::testing::Values(mult_consumers_first_node::forward::CreateReferenceFunction),
|
||||||
::testing::Values(ov::Shape{1, 96, 55, 55}),
|
::testing::Values(Shape{1, 96, 55, 55}),
|
||||||
::testing::Values(ov::element::f32)));
|
::testing::Values(element::f32)),
|
||||||
|
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultConsumersTestSuiteFirstNode,
|
||||||
TransposeSinkingUnaryBackwardMultConsumersTestSuiteFirstNode,
|
|
||||||
TransposeSinkingUnaryTestFixture,
|
TransposeSinkingUnaryTestFixture,
|
||||||
::testing::Combine(::testing::ValuesIn(unary_factories),
|
::testing::Combine(::testing::ValuesIn(unary_factories),
|
||||||
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingUnaryBackward>()),
|
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
|
||||||
::testing::ValuesIn(unary_operations_numbers),
|
::testing::ValuesIn(unary_operations_numbers),
|
||||||
::testing::Values(mult_consumers_first_node::backward::CreateFunction),
|
::testing::Values(mult_consumers_first_node::backward::CreateFunction),
|
||||||
::testing::Values(mult_consumers_first_node::backward::CreateReferenceFunction),
|
::testing::Values(mult_consumers_first_node::backward::CreateFunction),
|
||||||
::testing::Values(ov::Shape{1, 96, 55, 55}),
|
::testing::Values(Shape{1, 96, 55, 55}),
|
||||||
::testing::Values(ov::element::f32)));
|
::testing::Values(element::f32)),
|
||||||
|
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
TransposeSinkingUnaryBackwardMultTransposeConsumersTestSuiteFirstNode,
|
||||||
|
TransposeSinkingUnaryTestFixture,
|
||||||
|
::testing::Combine(::testing::ValuesIn(unary_factories),
|
||||||
|
::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward)),
|
||||||
|
::testing::ValuesIn(unary_operations_numbers),
|
||||||
|
::testing::Values(mult_consumers_first_node::backward_mult_transposes::CreateFunction),
|
||||||
|
::testing::Values(mult_consumers_first_node::backward_mult_transposes::CreateReferenceFunction),
|
||||||
|
::testing::Values(Shape{1, 96, 55, 55}),
|
||||||
|
::testing::Values(element::f32)),
|
||||||
|
TransposeSinkingUnaryTestFixture::get_test_name);
|
||||||
|
Loading…
Reference in New Issue
Block a user