Transpose sinking general (#13943)
* initial * build fixes + couple of simple unit tests * remove old transpose_sinking_binary berfore merge with PR branch * initial * clang cleanup fixes * remove TrasposeAxis function; cleanup namespaces * fix TransposeInputsInfo spell * one_input_transpose spell * cleanup speel * spell * decompose forward sinking * decompose backward sink * use NodeVector * clang cleanup * decomposite transformations into different files * decompose unit tests * clang cleanup * fix transformation names in general transformation * fix ngraph::pass::TransposeFuse use element type the same as in fusing nodes * add checkout sinking ability function; check sinking ability for unary operations; add unit test on general transformation * sinking check for binary; unit tests; fixes * add check to concat transformation; unit test * add check to split tranformation * azure build fixes * add general test * cleanup tests using common class * clang cleanup * add transpose sinkig to moc * remove comment * fix after rebase * clang fixes * fix after rebase * code review fixes * fix after rebase * add RUN_ON_FUNCTION_SCOPE to general transformation * fixes after rebase * move tests to new directory * cleanup * use ov::RuntimeAttribute * move NoTransposeSinkingAttr to files * fix namespace * fix names
This commit is contained in:
parent
d980365680
commit
22d7bc70d9
@ -0,0 +1,36 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API TransposeSinkingGeneralForward;
|
||||
class TRANSFORMATIONS_API TransposeSinkingGeneralBackward;
|
||||
class TRANSFORMATIONS_API TransposeSinkingGeneral;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
class ov::pass::TransposeSinkingGeneralForward : public ov::pass::GraphRewrite {
|
||||
public:
|
||||
OPENVINO_RTTI("TransposeSinkingGeneralForward", "0");
|
||||
TransposeSinkingGeneralForward();
|
||||
};
|
||||
|
||||
class ov::pass::TransposeSinkingGeneralBackward : public ov::pass::GraphRewrite {
|
||||
public:
|
||||
OPENVINO_RTTI("TransposeSinkingGeneralBackward", "0");
|
||||
TransposeSinkingGeneralBackward();
|
||||
};
|
||||
|
||||
class ov::pass::TransposeSinkingGeneral : public ov::pass::ModelPass {
|
||||
public:
|
||||
OPENVINO_RTTI("TransposeSinkingGeneral", "0");
|
||||
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
|
||||
};
|
@ -47,4 +47,6 @@ ov::NodeVector InsertTransposeBeforeNode(std::shared_ptr<ov::Node> main_node,
|
||||
std::shared_ptr<ov::opset9::Constant> transpose_const);
|
||||
} // namespace sink_backward
|
||||
|
||||
void UpdateForwardSinkingAbility(std::shared_ptr<ov::Node>);
|
||||
|
||||
} // namespace transpose_sinking
|
||||
|
@ -0,0 +1,32 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/core/runtime_attribute.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
|
||||
TRANSFORMATIONS_API void mark_as_no_sinking_node(const std::shared_ptr<Node>& node);
|
||||
|
||||
TRANSFORMATIONS_API bool is_sinking_node(const std::shared_ptr<Node>& node);
|
||||
TRANSFORMATIONS_API bool is_sinking_node(const Node* node);
|
||||
|
||||
/**
|
||||
* @ingroup ie_runtime_attr_api
|
||||
* @brief NoTransposeSinkingAttr class represents runtime info attribute that marks transpose
|
||||
* operation should not be moved be backward sinking propagation.
|
||||
*/
|
||||
class TRANSFORMATIONS_API NoTransposeSinkingAttr : public RuntimeAttribute {
|
||||
public:
|
||||
OPENVINO_RTTI("no_transpose_sinking", "0");
|
||||
|
||||
bool is_copyable() const override {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ov
|
@ -66,6 +66,7 @@
|
||||
#include <transformations/common_optimizations/subtract_fusion.hpp>
|
||||
#include <transformations/common_optimizations/swish_fusion.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking_general.hpp>
|
||||
#include <transformations/common_optimizations/transpose_to_reshape.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/low_precision/mark_dequantization_subgraph.hpp>
|
||||
|
@ -14,6 +14,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
@ -294,15 +295,21 @@ ov::pass::TransposeFuse::TransposeFuse() {
|
||||
is_ordered = false;
|
||||
}
|
||||
|
||||
auto transpose_order_type = transpose1_order->get_element_type();
|
||||
if (transpose_order_type != transpose2_order->get_element_type())
|
||||
transpose_order_type = element::i64;
|
||||
|
||||
if (is_ordered) {
|
||||
return ngraph::replace_output_update_name(transpose2->output(0), input);
|
||||
} else {
|
||||
auto new_order = opset7::Constant::create(element::i64, {order2.size()}, order2);
|
||||
auto new_order = opset7::Constant::create(transpose_order_type, {order2.size()}, order2);
|
||||
auto new_transpose = register_new_node<opset7::Transpose>(input, new_order);
|
||||
|
||||
new_transpose->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::copy_runtime_info({transpose1, transpose2}, new_transpose);
|
||||
ngraph::replace_node(m.get_match_root(), new_transpose);
|
||||
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
|
||||
}
|
||||
|
||||
return true;
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
|
||||
|
||||
#include <openvino/opsets/opset9.hpp>
|
||||
#include <openvino/pass/pattern/op/or.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <utility>
|
||||
@ -12,6 +13,7 @@
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "openvino/util/log.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov;
|
||||
@ -34,6 +36,7 @@ ov::pass::TransposeSinkingBinaryElementwiseForward::TransposeSinkingBinaryElemen
|
||||
sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
|
||||
return true;
|
||||
@ -43,13 +46,19 @@ ov::pass::TransposeSinkingBinaryElementwiseForward::TransposeSinkingBinaryElemen
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
ov::pass::TransposeSinkingBinaryElementwiseBackward::TransposeSinkingBinaryElementwiseBackward() {
|
||||
pass::TransposeSinkingBinaryElementwiseBackward::TransposeSinkingBinaryElementwiseBackward() {
|
||||
MATCHER_SCOPE(TransposeSinkingBinaryElementwiseBackward);
|
||||
|
||||
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic>(consumers_count(1));
|
||||
|
||||
auto transpose_const_label = wrap_type<Constant>(consumers_count(1));
|
||||
auto transpose_label = wrap_type<Transpose>({main_node_label, transpose_const_label}, consumers_count(1));
|
||||
|
||||
auto IfSinkingEnabled = [](const Output<Node>& output) -> bool {
|
||||
static auto consumers_check = consumers_count(1);
|
||||
return consumers_check(output) && is_sinking_node(output.get_node_shared_ptr());
|
||||
};
|
||||
|
||||
auto transpose_label = wrap_type<Transpose>({main_node_label, transpose_const_label}, IfSinkingEnabled);
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "openvino/util/log.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov;
|
||||
@ -35,6 +36,7 @@ ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() {
|
||||
sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
|
||||
auto concat_node = as_type_ptr<Concat>(main_node);
|
||||
@ -55,7 +57,13 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
|
||||
auto main_node_label = wrap_type<Concat>(consumers_count(1));
|
||||
|
||||
auto transpose_const_label = wrap_type<Constant>(consumers_count(1));
|
||||
auto transpose_label = wrap_type<Transpose>({main_node_label, transpose_const_label}, consumers_count(1));
|
||||
|
||||
auto IfSinkingEnabled = [](const Output<Node>& output) -> bool {
|
||||
static auto consumers_check = consumers_count(1);
|
||||
return consumers_check(output) && is_sinking_node(output.get_node_shared_ptr());
|
||||
};
|
||||
|
||||
auto transpose_label = wrap_type<Transpose>({main_node_label, transpose_const_label}, IfSinkingEnabled);
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
|
@ -0,0 +1,56 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/transpose_sinking_general.hpp"
|
||||
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {
|
||||
MATCHER_SCOPE(TransposeSinkingGeneralForward);
|
||||
add_matcher<ov::pass::TransposeSinkingUnaryForward>();
|
||||
add_matcher<ov::pass::TransposeSinkingBinaryElementwiseForward>();
|
||||
add_matcher<ov::pass::TransposeSinkingConcatForward>();
|
||||
add_matcher<ov::pass::TransposeSinkingSplitForward>();
|
||||
add_matcher<ngraph::pass::TransposeFuse>();
|
||||
}
|
||||
|
||||
ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() {
|
||||
MATCHER_SCOPE(TransposeSinkingGeneralBackward);
|
||||
add_matcher<ov::pass::TransposeSinkingUnaryBackward>();
|
||||
add_matcher<ov::pass::TransposeSinkingBinaryElementwiseBackward>();
|
||||
add_matcher<ov::pass::TransposeSinkingConcatBackward>();
|
||||
add_matcher<ov::pass::TransposeSinkingSplitBackward>();
|
||||
add_matcher<ngraph::pass::TransposeFuse>();
|
||||
}
|
||||
|
||||
bool ov::pass::TransposeSinkingGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
||||
RUN_ON_FUNCTION_SCOPE(TransposeSinkingGeneral);
|
||||
{
|
||||
ngraph::pass::Manager manager(get_pass_config());
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneralForward>();
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
{
|
||||
ngraph::pass::Manager manager(get_pass_config());
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneralBackward>();
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
@ -13,6 +13,7 @@
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "openvino/util/log.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov;
|
||||
@ -66,6 +67,9 @@ std::shared_ptr<Constant> GetTransposeConstant(Input<Node> input) {
|
||||
if (!transpose_node)
|
||||
return {};
|
||||
|
||||
if (!is_sinking_node(input.get_node()))
|
||||
return {};
|
||||
|
||||
auto constant_node = as_type_ptr<Constant>(transpose_node->input_value(1).get_node_shared_ptr());
|
||||
if (!constant_node)
|
||||
return {};
|
||||
@ -181,6 +185,7 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
|
||||
sink_forward::RemoveZeroInputNode(main_node);
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
|
||||
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||
|
@ -6,6 +6,10 @@
|
||||
#include "itt.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
|
||||
using namespace ov;
|
||||
|
||||
namespace {
|
||||
|
||||
@ -105,6 +109,8 @@ ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
|
||||
register_new_node(new_nodes.first);
|
||||
register_new_node(new_nodes.second);
|
||||
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_nodes.second);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -112,6 +118,12 @@ ov::pass::TransposeSinkingUnaryForward::TransposeSinkingUnaryForward() {
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool IfSinkingEnabled(const Output<Node>& output) {
|
||||
return is_sinking_node(output.get_node_shared_ptr());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() {
|
||||
MATCHER_SCOPE(TransposeSinkingUnaryBackward);
|
||||
|
||||
@ -123,7 +135,8 @@ ov::pass::TransposeSinkingUnaryBackward::TransposeSinkingUnaryBackward() {
|
||||
ov::opset9::Convert>({ov::pass::pattern::any_input()});
|
||||
|
||||
auto transpose_label =
|
||||
ov::pass::pattern::wrap_type<ov::opset9::Transpose>({unary_label, ov::pass::pattern::any_input()});
|
||||
ov::pass::pattern::wrap_type<ov::opset9::Transpose>({unary_label, ov::pass::pattern::any_input()},
|
||||
IfSinkingEnabled);
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](ov::pass::pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "openvino/util/log.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
|
||||
namespace transpose_sinking {
|
||||
|
||||
@ -169,4 +170,44 @@ NodeVector InsertTransposeBeforeNode(NodePtr main_node, std::shared_ptr<Constant
|
||||
}
|
||||
} // namespace sink_backward
|
||||
|
||||
#define CHECK_TRANSPOSE_SINKING_SUPPORTED(TYPE, node) \
|
||||
if (dynamic_cast<TYPE*>(node)) { \
|
||||
return true; \
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
bool CanPropagateForwardThrough(Node* node) {
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::util::UnaryElementwiseArithmetic, node);
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(Clamp, node);
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(Elu, node);
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(SoftPlus, node);
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(LogicalNot, node);
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(Convert, node);
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::util::BinaryElementwiseArithmetic, node);
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(Concat, node);
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(Split, node);
|
||||
CHECK_TRANSPOSE_SINKING_SUPPORTED(Transpose, node);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool CanPropagateForward(NodePtr node) {
|
||||
for (size_t i = 0; i < node->get_output_size(); ++i) {
|
||||
for (auto& consumer_input : node->output(i).get_target_inputs()) {
|
||||
if (!CanPropagateForwardThrough(consumer_input.get_node()))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void UpdateForwardSinkingAbility(NodePtr node) {
|
||||
if (!CanPropagateForward(node))
|
||||
mark_as_no_sinking_node(node);
|
||||
}
|
||||
|
||||
} // namespace transpose_sinking
|
||||
|
@ -0,0 +1,28 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
|
||||
using namespace ov;
|
||||
|
||||
void ov::mark_as_no_sinking_node(const std::shared_ptr<Node>& node) {
|
||||
auto& rt_info = node->get_rt_info();
|
||||
rt_info[NoTransposeSinkingAttr::get_type_info_static()] = NoTransposeSinkingAttr();
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <typename NodePtr>
|
||||
bool is_sinking_node_private(NodePtr node) {
|
||||
const auto& rt_info = node->get_rt_info();
|
||||
return rt_info.find(NoTransposeSinkingAttr::get_type_info_static()) == rt_info.end();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool ov::is_sinking_node(const std::shared_ptr<Node>& node) {
|
||||
return is_sinking_node_private(node);
|
||||
}
|
||||
|
||||
bool ov::is_sinking_node(const Node* node) {
|
||||
return is_sinking_node_private(node);
|
||||
}
|
@ -0,0 +1,392 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <transformations/common_optimizations/transpose_sinking_general.hpp>
|
||||
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <openvino/frontend/manager.hpp>
|
||||
|
||||
#include <openvino/opsets/opset9.hpp>
|
||||
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
using NodePtr = std::shared_ptr<ov::Node>;
|
||||
|
||||
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesForward) {
|
||||
ov::Shape input_shape = {1, 96, 55, 55};
|
||||
ov::element::Type input_type = ov::element::f32;
|
||||
size_t num_unary_ops = 10;
|
||||
|
||||
{
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
NodePtr in_op = X;
|
||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||
|
||||
auto unary = std::make_shared<ov::opset9::Tanh>(transpose0);
|
||||
|
||||
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
in_op = std::make_shared<ov::opset9::Transpose>(unary, ng_order1);
|
||||
}
|
||||
|
||||
function = std::make_shared<ov::Model>(in_op, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
{
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
NodePtr in_op = X;
|
||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||
in_op = std::make_shared<ov::opset9::Tanh>(in_op);
|
||||
}
|
||||
|
||||
function_ref = std::make_shared<ov::Model>(in_op, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneralForward>();
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesBackward) {
|
||||
ov::Shape input_shape = {1, 96, 55, 55};
|
||||
ov::element::Type input_type = ov::element::f32;
|
||||
size_t num_unary_ops = 10;
|
||||
|
||||
{
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
NodePtr in_op = X;
|
||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||
|
||||
auto unary = std::make_shared<ov::opset9::Tanh>(transpose0);
|
||||
|
||||
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
in_op = std::make_shared<ov::opset9::Transpose>(unary, ng_order1);
|
||||
}
|
||||
|
||||
function = std::make_shared<ov::Model>(in_op, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
{
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
NodePtr in_op = X;
|
||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||
in_op = std::make_shared<ov::opset9::Tanh>(in_op);
|
||||
}
|
||||
|
||||
function_ref = std::make_shared<ov::Model>(in_op, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneralBackward>();
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesGeneral) {
|
||||
ov::Shape input_shape = {1, 96, 55, 55};
|
||||
ov::element::Type input_type = ov::element::f32;
|
||||
size_t num_unary_ops = 10;
|
||||
|
||||
{
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
||||
|
||||
NodePtr in_op = transpose0;
|
||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||
|
||||
auto unary = std::make_shared<ov::opset9::Tanh>(transpose0);
|
||||
|
||||
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
in_op = std::make_shared<ov::opset9::Transpose>(unary, ng_order1);
|
||||
}
|
||||
|
||||
function = std::make_shared<ov::Model>(in_op, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
{
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
NodePtr in_op = X;
|
||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||
in_op = std::make_shared<ov::opset9::Tanh>(in_op);
|
||||
}
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||
|
||||
function_ref = std::make_shared<ov::Model>(transpose0, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestBinaryGeneral) {
|
||||
ov::Shape input_shape = {1, 96, 55, 55};
|
||||
ov::element::Type input_type = ov::element::f32;
|
||||
size_t num_binary_ops = 10;
|
||||
|
||||
{
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
||||
|
||||
NodePtr in_op = transpose0;
|
||||
for (size_t i = 0; i < num_binary_ops; ++i) {
|
||||
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
|
||||
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order1);
|
||||
|
||||
in_op = std::make_shared<ov::opset9::Add>(in_op, transpose1);
|
||||
}
|
||||
|
||||
function = std::make_shared<ov::Model>(in_op, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
{
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
NodePtr in_op = X;
|
||||
for (size_t i = 0; i < num_binary_ops; ++i) {
|
||||
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
|
||||
in_op = std::make_shared<ov::opset9::Add>(in_op, in_constant);
|
||||
}
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||
|
||||
function_ref = std::make_shared<ov::Model>(transpose0, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestConcatGeneral) {
|
||||
ov::Shape input_shape = {1, 96, 55, 55};
|
||||
ov::element::Type input_type = ov::element::f32;
|
||||
const size_t num_concat_ops = 3;
|
||||
const size_t num_concat_inputs = 2;
|
||||
|
||||
{
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
||||
|
||||
NodePtr in_op = transpose0;
|
||||
for (size_t i = 0; i < num_concat_ops; ++i) {
|
||||
ov::OutputVector concat_inputs;
|
||||
concat_inputs.push_back(in_op);
|
||||
for (size_t j = 1; j < num_concat_inputs; ++j) {
|
||||
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
|
||||
auto ng_order1 =
|
||||
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order1);
|
||||
concat_inputs.push_back(transpose1);
|
||||
}
|
||||
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 1);
|
||||
}
|
||||
|
||||
function = std::make_shared<ov::Model>(in_op, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
{
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
NodePtr in_op = X;
|
||||
for (size_t i = 0; i < num_concat_ops; ++i) {
|
||||
ov::OutputVector concat_inputs;
|
||||
|
||||
concat_inputs.push_back(in_op);
|
||||
|
||||
for (size_t j = 1; j < num_concat_inputs; ++j) {
|
||||
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
|
||||
concat_inputs.push_back(in_constant);
|
||||
}
|
||||
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 2);
|
||||
}
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||
|
||||
function_ref = std::make_shared<ov::Model>(transpose0, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
class IFactory {
|
||||
public:
|
||||
virtual ~IFactory() = default;
|
||||
virtual NodePtr create(const ov::OutputVector & parent) = 0;
|
||||
|
||||
virtual size_t getNumInputs() const = 0;
|
||||
virtual size_t getNumOuputs() const = 0;
|
||||
};
|
||||
|
||||
using FactoryPtr = std::shared_ptr<IFactory>;
|
||||
|
||||
class UnaryFactory : public IFactory {
|
||||
public:
|
||||
NodePtr create(const ov::OutputVector & parent) override {
|
||||
return std::make_shared<ov::opset9::Sinh>(parent.front());
|
||||
}
|
||||
|
||||
static FactoryPtr createFactory() {
|
||||
return std::make_shared<UnaryFactory>();
|
||||
}
|
||||
|
||||
size_t getNumInputs() const override { return 1; }
|
||||
size_t getNumOuputs() const override { return 1; }
|
||||
};
|
||||
|
||||
class BinaryFactory : public IFactory {
|
||||
public:
|
||||
NodePtr create(const ov::OutputVector & parent) override {
|
||||
return std::make_shared<ov::opset9::Add>(parent[0], parent[1]);
|
||||
}
|
||||
|
||||
static FactoryPtr createFactory() {
|
||||
return std::make_shared<BinaryFactory>();
|
||||
}
|
||||
|
||||
size_t getNumInputs() const override { return 2; }
|
||||
size_t getNumOuputs() const override { return 1; }
|
||||
};
|
||||
|
||||
class SplitFactory : public IFactory {
|
||||
public:
|
||||
SplitFactory(size_t axis) : axis_(axis) {}
|
||||
NodePtr create(const ov::OutputVector & parent) override {
|
||||
auto split_axis_const = std::make_shared<ov::opset9::Constant>(ov::element::u64,
|
||||
ov::Shape{},
|
||||
axis_);
|
||||
return std::make_shared<ov::opset9::Split>(parent.front(), split_axis_const, 2);
|
||||
}
|
||||
|
||||
static FactoryPtr createFactory(size_t axis) {
|
||||
return std::make_shared<SplitFactory>(axis);
|
||||
}
|
||||
|
||||
size_t getNumInputs() const override { return 1; }
|
||||
size_t getNumOuputs() const override { return 2; }
|
||||
private:
|
||||
const size_t axis_;
|
||||
};
|
||||
|
||||
class ConcatFactory : public IFactory {
|
||||
public:
|
||||
ConcatFactory(size_t axis) : axis_(axis) {}
|
||||
NodePtr create(const ov::OutputVector & parent) override {
|
||||
return std::make_shared<ov::opset9::Concat>(parent, axis_);
|
||||
}
|
||||
|
||||
static FactoryPtr createFactory(size_t axis) {
|
||||
return std::make_shared<ConcatFactory>(axis);
|
||||
}
|
||||
|
||||
size_t getNumInputs() const override { return 2; }
|
||||
size_t getNumOuputs() const override { return 1; }
|
||||
private:
|
||||
const size_t axis_;
|
||||
};
|
||||
|
||||
/*
|
||||
Each node pair should be started with input size = 1 node and finished with node output size = 1
|
||||
Insert Split/Concat to fullfill that.
|
||||
*/
|
||||
NodePtr CreateNodePair(FactoryPtr factory_first, FactoryPtr factory_second, NodePtr parent, size_t split_axis, size_t concat_axis) {
|
||||
NodePtr input = parent;
|
||||
if (factory_first->getNumInputs() != 1) {
|
||||
input = SplitFactory(split_axis).create(input->outputs());
|
||||
}
|
||||
|
||||
input = factory_first->create(input->outputs());
|
||||
if (factory_first->getNumOuputs() < factory_second->getNumInputs()) {
|
||||
input = SplitFactory(split_axis).create(input->outputs());
|
||||
} else if (factory_first->getNumOuputs() > factory_second->getNumInputs()) {
|
||||
input = ConcatFactory(concat_axis).create(input->outputs());
|
||||
}
|
||||
|
||||
auto output = factory_second->create(input->outputs());
|
||||
if (output->get_output_size() > 1) {
|
||||
output = ConcatFactory(concat_axis).create(output->outputs());
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
NodePtr MakeAllNodesSubgraph(NodePtr parent, size_t split_axis, size_t concat_axis) {
|
||||
std::vector<FactoryPtr> factories = { UnaryFactory::createFactory(),
|
||||
SplitFactory::createFactory(split_axis),
|
||||
ConcatFactory::createFactory(concat_axis) };
|
||||
NodePtr in_op = parent;
|
||||
for (int i = 0; i < factories.size(); ++i) {
|
||||
for (int j = 0; j < factories.size(); ++j) {
|
||||
in_op = CreateNodePair(factories[i], factories[j], in_op, split_axis, concat_axis);
|
||||
}
|
||||
}
|
||||
|
||||
return in_op;
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, TransposeSinkingGeneralTestMultipleTypes) {
|
||||
ov::Shape input_shape = {1, 96, 40, 55};
|
||||
ov::element::Type input_type = ov::element::f32;
|
||||
|
||||
{
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
auto node0 = MakeAllNodesSubgraph(X, 1, 1);
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(node0, ng_order0);
|
||||
|
||||
auto reshape_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
|
||||
auto reshape = std::make_shared<ov::opset9::Reshape>(transpose0, reshape_const, false);
|
||||
|
||||
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto transpose1 = std::make_shared<ov::opset9::Transpose>(reshape, ng_order1);
|
||||
|
||||
auto node1 = MakeAllNodesSubgraph(transpose1, 1, 1);
|
||||
|
||||
function = std::make_shared<ov::Model>(node1, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
{
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
||||
|
||||
auto node0 = MakeAllNodesSubgraph(transpose0, 3, 3);
|
||||
|
||||
auto reshape_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
|
||||
auto reshape = std::make_shared<ov::opset9::Reshape>(node0, reshape_const, false);
|
||||
|
||||
auto node1 = MakeAllNodesSubgraph(reshape, 3, 3);
|
||||
|
||||
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto transpose1 = std::make_shared<ov::opset9::Transpose>(node1, ng_order1);
|
||||
|
||||
function_ref = std::make_shared<ov::Model>(transpose1, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
|
||||
}
|
Loading…
Reference in New Issue
Block a user