Transpose sinking general (#13943)

* initial

* build fixes + couple of simple unit tests

* remove old transpose_sinking_binary berfore merge with PR branch

* initial

* clang cleanup fixes

* remove TrasposeAxis function; cleanup namespaces

* fix TransposeInputsInfo spell

* one_input_transpose spell

* cleanup speel

* spell

* decompose forward sinking

* decompose backward sink

* use NodeVector

* clang cleanup

* decomposite transformations into different files

* decompose unit tests

* clang cleanup

* fix transformation names in general transformation

* fix ngraph::pass::TransposeFuse use element type the same as in fusing nodes

* add checkout sinking ability function; check sinking ability for unary operations; add unit test on general transformation

* sinking check for binary; unit tests; fixes

* add check to concat transformation; unit test

* add check to split tranformation

* azure build fixes

* add general test

* cleanup tests using common class

* clang cleanup

* add transpose sinkig to moc

* remove comment

* fix after rebase

* clang fixes

* fix after rebase

* code review fixes

* fix after rebase

* add RUN_ON_FUNCTION_SCOPE to general transformation

* fixes after rebase

* move tests to new directory

* cleanup

* use ov::RuntimeAttribute

* move NoTransposeSinkingAttr to files

* fix namespace

* fix names
This commit is contained in:
Evgeny Kotov 2022-12-13 17:08:06 +01:00 committed by GitHub
parent d980365680
commit 22d7bc70d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 635 additions and 5 deletions

View File

@ -0,0 +1,36 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingGeneralForward;
class TRANSFORMATIONS_API TransposeSinkingGeneralBackward;
class TRANSFORMATIONS_API TransposeSinkingGeneral;
} // namespace pass
} // namespace ov
class ov::pass::TransposeSinkingGeneralForward : public ov::pass::GraphRewrite {
public:
OPENVINO_RTTI("TransposeSinkingGeneralForward", "0");
TransposeSinkingGeneralForward();
};
class ov::pass::TransposeSinkingGeneralBackward : public ov::pass::GraphRewrite {
public:
OPENVINO_RTTI("TransposeSinkingGeneralBackward", "0");
TransposeSinkingGeneralBackward();
};
class ov::pass::TransposeSinkingGeneral : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("TransposeSinkingGeneral", "0");
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
};

View File

@ -47,4 +47,6 @@ ov::NodeVector InsertTransposeBeforeNode(std::shared_ptr<ov::Node> main_node,
std::shared_ptr<ov::opset9::Constant> transpose_const); std::shared_ptr<ov::opset9::Constant> transpose_const);
} // namespace sink_backward } // namespace sink_backward
void UpdateForwardSinkingAbility(std::shared_ptr<ov::Node>);
} // namespace transpose_sinking } // namespace transpose_sinking

View File

@ -0,0 +1,32 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/node.hpp"
#include "openvino/core/runtime_attribute.hpp"
#include "transformations_visibility.hpp"
namespace ov {
TRANSFORMATIONS_API void mark_as_no_sinking_node(const std::shared_ptr<Node>& node);
TRANSFORMATIONS_API bool is_sinking_node(const std::shared_ptr<Node>& node);
TRANSFORMATIONS_API bool is_sinking_node(const Node* node);
/**
* @ingroup ie_runtime_attr_api
* @brief NoTransposeSinkingAttr class represents runtime info attribute that marks transpose
* operation should not be moved be backward sinking propagation.
*/
class TRANSFORMATIONS_API NoTransposeSinkingAttr : public RuntimeAttribute {
public:
OPENVINO_RTTI("no_transpose_sinking", "0");
bool is_copyable() const override {
return false;
}
};
} // namespace ov

View File

@ -66,6 +66,7 @@
#include <transformations/common_optimizations/subtract_fusion.hpp> #include <transformations/common_optimizations/subtract_fusion.hpp>
#include <transformations/common_optimizations/swish_fusion.hpp> #include <transformations/common_optimizations/swish_fusion.hpp>
#include <transformations/common_optimizations/transpose_sinking.hpp> #include <transformations/common_optimizations/transpose_sinking.hpp>
#include <transformations/common_optimizations/transpose_sinking_general.hpp>
#include <transformations/common_optimizations/transpose_to_reshape.hpp> #include <transformations/common_optimizations/transpose_to_reshape.hpp>
#include <transformations/init_node_info.hpp> #include <transformations/init_node_info.hpp>
#include <transformations/low_precision/mark_dequantization_subgraph.hpp> #include <transformations/low_precision/mark_dequantization_subgraph.hpp>

View File

@ -14,6 +14,7 @@
#include <vector> #include <vector>
#include "itt.hpp" #include "itt.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/utils/utils.hpp" #include "transformations/utils/utils.hpp"
using namespace ov; using namespace ov;
@ -294,15 +295,21 @@ ov::pass::TransposeFuse::TransposeFuse() {
is_ordered = false; is_ordered = false;
} }
auto transpose_order_type = transpose1_order->get_element_type();
if (transpose_order_type != transpose2_order->get_element_type())
transpose_order_type = element::i64;
if (is_ordered) { if (is_ordered) {
return ngraph::replace_output_update_name(transpose2->output(0), input); return ngraph::replace_output_update_name(transpose2->output(0), input);
} else { } else {
auto new_order = opset7::Constant::create(element::i64, {order2.size()}, order2); auto new_order = opset7::Constant::create(transpose_order_type, {order2.size()}, order2);
auto new_transpose = register_new_node<opset7::Transpose>(input, new_order); auto new_transpose = register_new_node<opset7::Transpose>(input, new_order);
new_transpose->set_friendly_name(m.get_match_root()->get_friendly_name()); new_transpose->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({transpose1, transpose2}, new_transpose); ngraph::copy_runtime_info({transpose1, transpose2}, new_transpose);
ngraph::replace_node(m.get_match_root(), new_transpose); ngraph::replace_node(m.get_match_root(), new_transpose);
transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
} }
return true; return true;

View File

@ -1,5 +1,6 @@
#include "transformations/common_optimizations/transpose_sinking_binary.hpp" #include "transformations/common_optimizations/transpose_sinking_binary.hpp"
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/pattern/op/or.hpp> #include <openvino/pass/pattern/op/or.hpp>
#include <transformations/utils/utils.hpp> #include <transformations/utils/utils.hpp>
#include <utility> #include <utility>
@ -12,6 +13,7 @@
#include "openvino/util/common_util.hpp" #include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp" #include "openvino/util/log.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp" #include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
using namespace ov::pass::pattern; using namespace ov::pass::pattern;
using namespace ov; using namespace ov;
@ -34,6 +36,7 @@ ov::pass::TransposeSinkingBinaryElementwiseForward::TransposeSinkingBinaryElemen
sink_forward::UpdateInputTransposes(main_node, transpose_input_info); sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node); register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);
} }
return true; return true;
@ -43,13 +46,19 @@ ov::pass::TransposeSinkingBinaryElementwiseForward::TransposeSinkingBinaryElemen
register_matcher(m, matcher_pass_callback); register_matcher(m, matcher_pass_callback);
} }
ov::pass::TransposeSinkingBinaryElementwiseBackward::TransposeSinkingBinaryElementwiseBackward() { pass::TransposeSinkingBinaryElementwiseBackward::TransposeSinkingBinaryElementwiseBackward() {
MATCHER_SCOPE(TransposeSinkingBinaryElementwiseBackward); MATCHER_SCOPE(TransposeSinkingBinaryElementwiseBackward);
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic>(consumers_count(1)); auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic>(consumers_count(1));
auto transpose_const_label = wrap_type<Constant>(consumers_count(1)); auto transpose_const_label = wrap_type<Constant>(consumers_count(1));
auto transpose_label = wrap_type<Transpose>({main_node_label, transpose_const_label}, consumers_count(1));
auto IfSinkingEnabled = [](const Output<Node>& output) -> bool {
static auto consumers_check = consumers_count(1);
return consumers_check(output) && is_sinking_node(output.get_node_shared_ptr());
};
auto transpose_label = wrap_type<Transpose>({main_node_label, transpose_const_label}, IfSinkingEnabled);
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map(); const auto& pattern_to_output = m.get_pattern_value_map();

View File

@ -13,6 +13,7 @@
#include "openvino/util/common_util.hpp" #include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp" #include "openvino/util/log.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp" #include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
using namespace ov::pass::pattern; using namespace ov::pass::pattern;
using namespace ov; using namespace ov;
@ -35,6 +36,7 @@ ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() {
sink_forward::UpdateInputTransposes(main_node, transpose_input_info); sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node); register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);
} }
auto concat_node = as_type_ptr<Concat>(main_node); auto concat_node = as_type_ptr<Concat>(main_node);
@ -55,7 +57,13 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
auto main_node_label = wrap_type<Concat>(consumers_count(1)); auto main_node_label = wrap_type<Concat>(consumers_count(1));
auto transpose_const_label = wrap_type<Constant>(consumers_count(1)); auto transpose_const_label = wrap_type<Constant>(consumers_count(1));
auto transpose_label = wrap_type<Transpose>({main_node_label, transpose_const_label}, consumers_count(1));
auto IfSinkingEnabled = [](const Output<Node>& output) -> bool {
static auto consumers_check = consumers_count(1);
return consumers_check(output) && is_sinking_node(output.get_node_shared_ptr());
};
auto transpose_label = wrap_type<Transpose>({main_node_label, transpose_const_label}, IfSinkingEnabled);
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) { matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map(); const auto& pattern_to_output = m.get_pattern_value_map();

View File

@ -0,0 +1,56 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_general.hpp"
#include <ngraph/pass/constant_folding.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include "itt.hpp"
#include "transformations/common_optimizations/transpose_sinking.hpp"
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
#include "transformations/utils/utils.hpp"
ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {
MATCHER_SCOPE(TransposeSinkingGeneralForward);
add_matcher<ov::pass::TransposeSinkingUnaryForward>();
add_matcher<ov::pass::TransposeSinkingBinaryElementwiseForward>();
add_matcher<ov::pass::TransposeSinkingConcatForward>();
add_matcher<ov::pass::TransposeSinkingSplitForward>();
add_matcher<ngraph::pass::TransposeFuse>();
}
ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() {
MATCHER_SCOPE(TransposeSinkingGeneralBackward);
add_matcher<ov::pass::TransposeSinkingUnaryBackward>();
add_matcher<ov::pass::TransposeSinkingBinaryElementwiseBackward>();
add_matcher<ov::pass::TransposeSinkingConcatBackward>();
add_matcher<ov::pass::TransposeSinkingSplitBackward>();
add_matcher<ngraph::pass::TransposeFuse>();
}
bool ov::pass::TransposeSinkingGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
RUN_ON_FUNCTION_SCOPE(TransposeSinkingGeneral);
{
ngraph::pass::Manager manager(get_pass_config());
manager.register_pass<ov::pass::TransposeSinkingGeneralForward>();
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.run_passes(f);
}
{
ngraph::pass::Manager manager(get_pass_config());
manager.register_pass<ov::pass::TransposeSinkingGeneralBackward>();
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.run_passes(f);
}
return false;
}

View File

@ -13,6 +13,7 @@
#include "openvino/util/common_util.hpp" #include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp" #include "openvino/util/log.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp" #include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
using namespace ov::pass::pattern; using namespace ov::pass::pattern;
using namespace ov; using namespace ov;
@ -66,6 +67,9 @@ std::shared_ptr<Constant> GetTransposeConstant(Input<Node> input) {
if (!transpose_node) if (!transpose_node)
return {}; return {};
if (!is_sinking_node(input.get_node()))
return {};
auto constant_node = as_type_ptr<Constant>(transpose_node->input_value(1).get_node_shared_ptr()); auto constant_node = as_type_ptr<Constant>(transpose_node->input_value(1).get_node_shared_ptr());
if (!constant_node) if (!constant_node)
return {}; return {};
@ -181,6 +185,7 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
sink_forward::RemoveZeroInputNode(main_node); sink_forward::RemoveZeroInputNode(main_node);
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) { for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node); register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);
} }
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val(); const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();

View File

@ -6,6 +6,10 @@
#include "itt.hpp" #include "itt.hpp"
#include "openvino/opsets/opset9.hpp" #include "openvino/opsets/opset9.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
using namespace ov;
namespace { namespace {
@ -105,6 +109,8 @@ ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
register_new_node(new_nodes.first); register_new_node(new_nodes.first);
register_new_node(new_nodes.second); register_new_node(new_nodes.second);
transpose_sinking::UpdateForwardSinkingAbility(new_nodes.second);
return true; return true;
}; };
@ -112,6 +118,12 @@ ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
register_matcher(m, matcher_pass_callback); register_matcher(m, matcher_pass_callback);
} }
namespace {
bool IfSinkingEnabled(const Output<Node>& output) {
return is_sinking_node(output.get_node_shared_ptr());
}
} // namespace
ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() { ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() {
MATCHER_SCOPE(TransposeSinkingUnaryBackward); MATCHER_SCOPE(TransposeSinkingUnaryBackward);
@ -123,7 +135,8 @@ ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() {
ov::opset9::Convert>({ov::pass::pattern::any_input()}); ov::opset9::Convert>({ov::pass::pattern::any_input()});
auto transpose_label = auto transpose_label =
ov::pass::pattern::wrap_type<ov::opset9::Transpose>({unary_label, ov::pass::pattern::any_input()}); ov::pass::pattern::wrap_type<ov::opset9::Transpose>({unary_label, ov::pass::pattern::any_input()},
IfSinkingEnabled);
ov::matcher_pass_callback matcher_pass_callback = [=](ov::pass::pattern::Matcher& m) { ov::matcher_pass_callback matcher_pass_callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map(); const auto& pattern_to_output = m.get_pattern_value_map();

View File

@ -11,6 +11,7 @@
#include "openvino/pass/pattern/op/wrap_type.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp" #include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp" #include "openvino/util/log.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
namespace transpose_sinking { namespace transpose_sinking {
@ -169,4 +170,44 @@ NodeVector InsertTransposeBeforeNode(NodePtr main_node, std::shared_ptr<Constant
} }
} // namespace sink_backward } // namespace sink_backward
#define CHECK_TRANSPOSE_SINKING_SUPPORTED(TYPE, node) \
if (dynamic_cast<TYPE*>(node)) { \
return true; \
}
namespace {
bool CanPropagateForwardThrough(Node* node) {
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::util::UnaryElementwiseArithmetic, node);
CHECK_TRANSPOSE_SINKING_SUPPORTED(Clamp, node);
CHECK_TRANSPOSE_SINKING_SUPPORTED(Elu, node);
CHECK_TRANSPOSE_SINKING_SUPPORTED(SoftPlus, node);
CHECK_TRANSPOSE_SINKING_SUPPORTED(LogicalNot, node);
CHECK_TRANSPOSE_SINKING_SUPPORTED(Convert, node);
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::util::BinaryElementwiseArithmetic, node);
CHECK_TRANSPOSE_SINKING_SUPPORTED(Concat, node);
CHECK_TRANSPOSE_SINKING_SUPPORTED(Split, node);
CHECK_TRANSPOSE_SINKING_SUPPORTED(Transpose, node);
return false;
}
bool CanPropagateForward(NodePtr node) {
for (size_t i = 0; i < node->get_output_size(); ++i) {
for (auto& consumer_input : node->output(i).get_target_inputs()) {
if (!CanPropagateForwardThrough(consumer_input.get_node()))
return false;
}
}
return true;
}
} // namespace
void UpdateForwardSinkingAbility(NodePtr node) {
if (!CanPropagateForward(node))
mark_as_no_sinking_node(node);
}
} // namespace transpose_sinking } // namespace transpose_sinking

View File

@ -0,0 +1,28 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/rt_info/transpose_sinking_attr.hpp"
using namespace ov;
void ov::mark_as_no_sinking_node(const std::shared_ptr<Node>& node) {
auto& rt_info = node->get_rt_info();
rt_info[NoTransposeSinkingAttr::get_type_info_static()] = NoTransposeSinkingAttr();
}
namespace {
template <typename NodePtr>
bool is_sinking_node_private(NodePtr node) {
const auto& rt_info = node->get_rt_info();
return rt_info.find(NoTransposeSinkingAttr::get_type_info_static()) == rt_info.end();
}
} // namespace
bool ov::is_sinking_node(const std::shared_ptr<Node>& node) {
return is_sinking_node_private(node);
}
bool ov::is_sinking_node(const Node* node) {
return is_sinking_node_private(node);
}

View File

@ -0,0 +1,392 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <transformations/common_optimizations/transpose_sinking_general.hpp>
#include <transformations/init_node_info.hpp>
#include <openvino/frontend/manager.hpp>
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include <functional>
#include "gtest/gtest.h"
using namespace testing;
using NodePtr = std::shared_ptr<ov::Node>;
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesForward) {
ov::Shape input_shape = {1, 96, 55, 55};
ov::element::Type input_type = ov::element::f32;
size_t num_unary_ops = 10;
{
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
NodePtr in_op = X;
for (size_t i = 0; i < num_unary_ops; ++i) {
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
auto unary = std::make_shared<ov::opset9::Tanh>(transpose0);
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
in_op = std::make_shared<ov::opset9::Transpose>(unary, ng_order1);
}
function = std::make_shared<ov::Model>(in_op, ov::ParameterVector{X});
}
{
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
NodePtr in_op = X;
for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = std::make_shared<ov::opset9::Tanh>(in_op);
}
function_ref = std::make_shared<ov::Model>(in_op, ov::ParameterVector{X});
}
manager.register_pass<ov::pass::TransposeSinkingGeneralForward>();
}
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesBackward) {
ov::Shape input_shape = {1, 96, 55, 55};
ov::element::Type input_type = ov::element::f32;
size_t num_unary_ops = 10;
{
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
NodePtr in_op = X;
for (size_t i = 0; i < num_unary_ops; ++i) {
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
auto unary = std::make_shared<ov::opset9::Tanh>(transpose0);
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
in_op = std::make_shared<ov::opset9::Transpose>(unary, ng_order1);
}
function = std::make_shared<ov::Model>(in_op, ov::ParameterVector{X});
}
{
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
NodePtr in_op = X;
for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = std::make_shared<ov::opset9::Tanh>(in_op);
}
function_ref = std::make_shared<ov::Model>(in_op, ov::ParameterVector{X});
}
manager.register_pass<ov::pass::TransposeSinkingGeneralBackward>();
}
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesGeneral) {
ov::Shape input_shape = {1, 96, 55, 55};
ov::element::Type input_type = ov::element::f32;
size_t num_unary_ops = 10;
{
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
NodePtr in_op = transpose0;
for (size_t i = 0; i < num_unary_ops; ++i) {
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
auto unary = std::make_shared<ov::opset9::Tanh>(transpose0);
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
in_op = std::make_shared<ov::opset9::Transpose>(unary, ng_order1);
}
function = std::make_shared<ov::Model>(in_op, ov::ParameterVector{X});
}
{
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
NodePtr in_op = X;
for (size_t i = 0; i < num_unary_ops; ++i) {
in_op = std::make_shared<ov::opset9::Tanh>(in_op);
}
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
function_ref = std::make_shared<ov::Model>(transpose0, ov::ParameterVector{X});
}
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
}
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestBinaryGeneral) {
ov::Shape input_shape = {1, 96, 55, 55};
ov::element::Type input_type = ov::element::f32;
size_t num_binary_ops = 10;
{
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
NodePtr in_op = transpose0;
for (size_t i = 0; i < num_binary_ops; ++i) {
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order1);
in_op = std::make_shared<ov::opset9::Add>(in_op, transpose1);
}
function = std::make_shared<ov::Model>(in_op, ov::ParameterVector{X});
}
{
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
NodePtr in_op = X;
for (size_t i = 0; i < num_binary_ops; ++i) {
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
in_op = std::make_shared<ov::opset9::Add>(in_op, in_constant);
}
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
function_ref = std::make_shared<ov::Model>(transpose0, ov::ParameterVector{X});
}
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
}
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestConcatGeneral) {
ov::Shape input_shape = {1, 96, 55, 55};
ov::element::Type input_type = ov::element::f32;
const size_t num_concat_ops = 3;
const size_t num_concat_inputs = 2;
{
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
NodePtr in_op = transpose0;
for (size_t i = 0; i < num_concat_ops; ++i) {
ov::OutputVector concat_inputs;
concat_inputs.push_back(in_op);
for (size_t j = 1; j < num_concat_inputs; ++j) {
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
auto ng_order1 =
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order1);
concat_inputs.push_back(transpose1);
}
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 1);
}
function = std::make_shared<ov::Model>(in_op, ov::ParameterVector{X});
}
{
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
NodePtr in_op = X;
for (size_t i = 0; i < num_concat_ops; ++i) {
ov::OutputVector concat_inputs;
concat_inputs.push_back(in_op);
for (size_t j = 1; j < num_concat_inputs; ++j) {
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
concat_inputs.push_back(in_constant);
}
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 2);
}
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
function_ref = std::make_shared<ov::Model>(transpose0, ov::ParameterVector{X});
}
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
}
// ----------------------------------------------------------------------------------------------------------------------
class IFactory {
public:
virtual ~IFactory() = default;
virtual NodePtr create(const ov::OutputVector & parent) = 0;
virtual size_t getNumInputs() const = 0;
virtual size_t getNumOuputs() const = 0;
};
using FactoryPtr = std::shared_ptr<IFactory>;
class UnaryFactory : public IFactory {
public:
NodePtr create(const ov::OutputVector & parent) override {
return std::make_shared<ov::opset9::Sinh>(parent.front());
}
static FactoryPtr createFactory() {
return std::make_shared<UnaryFactory>();
}
size_t getNumInputs() const override { return 1; }
size_t getNumOuputs() const override { return 1; }
};
class BinaryFactory : public IFactory {
public:
NodePtr create(const ov::OutputVector & parent) override {
return std::make_shared<ov::opset9::Add>(parent[0], parent[1]);
}
static FactoryPtr createFactory() {
return std::make_shared<BinaryFactory>();
}
size_t getNumInputs() const override { return 2; }
size_t getNumOuputs() const override { return 1; }
};
class SplitFactory : public IFactory {
public:
SplitFactory(size_t axis) : axis_(axis) {}
NodePtr create(const ov::OutputVector & parent) override {
auto split_axis_const = std::make_shared<ov::opset9::Constant>(ov::element::u64,
ov::Shape{},
axis_);
return std::make_shared<ov::opset9::Split>(parent.front(), split_axis_const, 2);
}
static FactoryPtr createFactory(size_t axis) {
return std::make_shared<SplitFactory>(axis);
}
size_t getNumInputs() const override { return 1; }
size_t getNumOuputs() const override { return 2; }
private:
const size_t axis_;
};
class ConcatFactory : public IFactory {
public:
ConcatFactory(size_t axis) : axis_(axis) {}
NodePtr create(const ov::OutputVector & parent) override {
return std::make_shared<ov::opset9::Concat>(parent, axis_);
}
static FactoryPtr createFactory(size_t axis) {
return std::make_shared<ConcatFactory>(axis);
}
size_t getNumInputs() const override { return 2; }
size_t getNumOuputs() const override { return 1; }
private:
const size_t axis_;
};
/*
Each node pair should be started with input size = 1 node and finished with node output size = 1
Insert Split/Concat to fullfill that.
*/
NodePtr CreateNodePair(FactoryPtr factory_first, FactoryPtr factory_second, NodePtr parent, size_t split_axis, size_t concat_axis) {
NodePtr input = parent;
if (factory_first->getNumInputs() != 1) {
input = SplitFactory(split_axis).create(input->outputs());
}
input = factory_first->create(input->outputs());
if (factory_first->getNumOuputs() < factory_second->getNumInputs()) {
input = SplitFactory(split_axis).create(input->outputs());
} else if (factory_first->getNumOuputs() > factory_second->getNumInputs()) {
input = ConcatFactory(concat_axis).create(input->outputs());
}
auto output = factory_second->create(input->outputs());
if (output->get_output_size() > 1) {
output = ConcatFactory(concat_axis).create(output->outputs());
}
return output;
}
NodePtr MakeAllNodesSubgraph(NodePtr parent, size_t split_axis, size_t concat_axis) {
std::vector<FactoryPtr> factories = { UnaryFactory::createFactory(),
SplitFactory::createFactory(split_axis),
ConcatFactory::createFactory(concat_axis) };
NodePtr in_op = parent;
for (int i = 0; i < factories.size(); ++i) {
for (int j = 0; j < factories.size(); ++j) {
in_op = CreateNodePair(factories[i], factories[j], in_op, split_axis, concat_axis);
}
}
return in_op;
}
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestMultipleTypes) {
ov::Shape input_shape = {1, 96, 40, 55};
ov::element::Type input_type = ov::element::f32;
{
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
auto node0 = MakeAllNodesSubgraph(X, 1, 1);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(node0, ng_order0);
auto reshape_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
auto reshape = std::make_shared<ov::opset9::Reshape>(transpose0, reshape_const, false);
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto transpose1 = std::make_shared<ov::opset9::Transpose>(reshape, ng_order1);
auto node1 = MakeAllNodesSubgraph(transpose1, 1, 1);
function = std::make_shared<ov::Model>(node1, ov::ParameterVector{X});
}
{
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
auto node0 = MakeAllNodesSubgraph(transpose0, 3, 3);
auto reshape_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
auto reshape = std::make_shared<ov::opset9::Reshape>(node0, reshape_const, false);
auto node1 = MakeAllNodesSubgraph(reshape, 3, 3);
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto transpose1 = std::make_shared<ov::opset9::Transpose>(node1, ng_order1);
function_ref = std::make_shared<ov::Model>(transpose1, ov::ParameterVector{X});
}
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
}