Transpose sinking on passes for Concat, Split and binary eltwise operations (#13718)

* initial

* clang cleanup fixes

* remove TrasposeAxis function; cleanup namespaces

* fix TransposeInputsInfo spell

* one_input_transpose spell

* cleanup speel

* spell

* decompose forward sinking

* decompose backward sink

* use NodeVector

* clang cleanup

* decomposite transformations into different files

* decompose unit tests

* clang cleanup

* azure build fixes

* code review fixes

* clang cleanup fixes
This commit is contained in:
Evgeny Kotov 2022-11-21 13:24:26 +01:00 committed by GitHub
parent 738d7bb09f
commit 0846bdb67e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 1982 additions and 0 deletions

View File

@ -0,0 +1,30 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingBinaryElementwiseForward;
class TRANSFORMATIONS_API TransposeSinkingBinaryElementwiseBackward;
} // namespace pass
} // namespace ov
class ov::pass::TransposeSinkingBinaryElementwiseForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryElementwiseForward", "0");
TransposeSinkingBinaryElementwiseForward();
};
class ov::pass::TransposeSinkingBinaryElementwiseBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryElementwiseBackward", "0");
TransposeSinkingBinaryElementwiseBackward();
};

View File

@ -0,0 +1,30 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingConcatForward;
class TRANSFORMATIONS_API TransposeSinkingConcatBackward;
} // namespace pass
} // namespace ov
class ov::pass::TransposeSinkingConcatForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingConcatForward", "0");
TransposeSinkingConcatForward();
};
class ov::pass::TransposeSinkingConcatBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingConcatBackward", "0");
TransposeSinkingConcatBackward();
};

View File

@ -0,0 +1,30 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API TransposeSinkingSplitBackward;
class TRANSFORMATIONS_API TransposeSinkingSplitForward;
} // namespace pass
} // namespace ov
class ov::pass::TransposeSinkingSplitForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingSplitForward", "0");
TransposeSinkingSplitForward();
};
class ov::pass::TransposeSinkingSplitBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TransposeSinkingSplitBackward", "0");
TransposeSinkingSplitBackward();
};

View File

@ -0,0 +1,50 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/pass/pattern/op/or.hpp>
#include <transformations/utils/utils.hpp>
#include <utility>
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp"
namespace transpose_sinking {
struct TransposeInputsInfo {
std::shared_ptr<ov::opset9::Transpose> transpose;
std::shared_ptr<ov::opset9::Constant> transpose_const;
size_t input_idx;
bool isEmpty() const {
return !transpose || !transpose_const;
}
};
TransposeInputsInfo GetFirstTransposeInput(std::shared_ptr<ov::Node> node);
bool IfNodeHasTransposeInputs(const ov::Output<ov::Node>& output);
ov::AxisVector ReverseTransposeOrder(const ov::AxisVector& axis_order);
void SwapOutputNames(ov::Output<ov::Node> output1, ov::Output<ov::Node> output2);
void SwapFriendlyNames(std::shared_ptr<ov::Node> node1, std::shared_ptr<ov::Node> node2);
void SwapNames(std::shared_ptr<ov::Node> node1, std::shared_ptr<ov::Node> node2);
namespace sink_forward {
// insert input reversed transposes, remove first input tranpose
void UpdateInputTransposes(std::shared_ptr<ov::Node> main_node, TransposeInputsInfo& transpose_input_info);
void RemoveZeroInputNode(std::shared_ptr<ov::Node> main_node);
ov::NodeVector InsertOutputTransposes(std::shared_ptr<ov::Node> main_node, TransposeInputsInfo& transpose_input_info);
} // namespace sink_forward
namespace sink_backward {
ov::NodeVector InsertTransposeBeforeNode(std::shared_ptr<ov::Node> main_node,
std::shared_ptr<ov::opset9::Constant> transpose_const);
} // namespace sink_backward
} // namespace transpose_sinking

View File

@ -0,0 +1,74 @@
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
#include <openvino/pass/pattern/op/or.hpp>
#include <transformations/utils/utils.hpp>
#include <utility>
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
using namespace ov::pass::pattern;
using namespace ov;
using namespace ov::opset9;
using namespace transpose_sinking;
ov::pass::TransposeSinkingBinaryElementwiseForward::TransposeSinkingBinaryElementwiseForward() {
MATCHER_SCOPE(TransposeSinkingBinaryElementwiseForward);
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic>(IfNodeHasTransposeInputs);
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto& main_node_output = pattern_to_output.at(main_node_label);
auto main_node = main_node_output.get_node_shared_ptr();
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
}
return true;
};
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
ov::pass::TransposeSinkingBinaryElementwiseBackward::TransposeSinkingBinaryElementwiseBackward() {
MATCHER_SCOPE(TransposeSinkingBinaryElementwiseBackward);
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic>(consumers_count(1));
auto transpose_const_label = wrap_type<Constant>(consumers_count(1));
auto transpose_label = wrap_type<Transpose>({main_node_label, transpose_const_label}, consumers_count(1));
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) {
register_new_node(new_node);
}
// remove transpose after main node
transpose->output(0).replace(main_node);
SwapNames(transpose, main_node);
return true;
};
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}

View File

@ -0,0 +1,86 @@
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <transformations/utils/utils.hpp>
#include <utility>
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
using namespace ov::pass::pattern;
using namespace ov;
using namespace ov::opset9;
using namespace transpose_sinking;
ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() {
MATCHER_SCOPE(TransposeSinkingConcatForward);
auto main_node_label = wrap_type<Concat>(IfNodeHasTransposeInputs);
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto& main_node_output = pattern_to_output.at(main_node_label);
auto main_node = main_node_output.get_node_shared_ptr();
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
}
auto concat_node = as_type_ptr<Concat>(main_node);
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
const int64_t transposed_concat_axis = transpose_axis_order[concat_node->get_axis()];
concat_node->set_concatenation_axis(transposed_concat_axis);
return true;
};
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
MATCHER_SCOPE(TransposeSinkingConcatBackward);
auto main_node_label = wrap_type<Concat>(consumers_count(1));
auto transpose_const_label = wrap_type<Constant>(consumers_count(1));
auto transpose_label = wrap_type<Transpose>({main_node_label, transpose_const_label}, consumers_count(1));
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) {
register_new_node(new_node);
}
// remove transpose after main node
transpose->output(0).replace(main_node);
SwapNames(transpose, main_node);
auto concat_node = as_type_ptr<Concat>(main_node);
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
const int64_t transposed_concat_axis = reversed_traspose_axis_order[concat_node->get_axis()];
concat_node->set_concatenation_axis(transposed_concat_axis);
return true;
};
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}

View File

@ -0,0 +1,200 @@
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <transformations/utils/utils.hpp>
#include <utility>
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
using namespace ov::pass::pattern;
using namespace ov;
using namespace ov::opset9;
using namespace transpose_sinking;
namespace {
using NodePtr = std::shared_ptr<Node>;
struct OutputTranspose {
OutputTranspose() : transpose(nullptr), transpose_const(nullptr) {}
Transpose* transpose;
Constant* transpose_const;
};
OutputTranspose GetOutputTransposes(NodePtr node) {
for (size_t output_idx = 0; output_idx < node->get_output_size(); ++output_idx) {
for (auto& input : node->get_output_target_inputs(output_idx)) {
auto transpose_node = dynamic_cast<Transpose*>(input.get_node());
if (!transpose_node)
continue;
auto constant_node = dynamic_cast<Constant*>(transpose_node->input_value(1).get_node());
if (!constant_node)
continue;
{
OutputTranspose output_transpose;
output_transpose.transpose = transpose_node;
output_transpose.transpose_const = constant_node;
return output_transpose;
}
}
}
return OutputTranspose();
}
NodePtr FindSplitInput(Node* node) {
for (size_t input_idx = 0; input_idx < node->get_input_size(); ++input_idx) {
NodePtr input_node = node->get_input_node_shared_ptr(input_idx);
auto split_node = as_type_ptr<Split>(input_node);
if (split_node)
return split_node;
}
return {};
}
std::shared_ptr<Constant> GetTransposeConstant(Input<Node> input) {
auto transpose_node = dynamic_cast<Transpose*>(input.get_node());
if (!transpose_node)
return {};
auto constant_node = as_type_ptr<Constant>(transpose_node->input_value(1).get_node_shared_ptr());
if (!constant_node)
return {};
return constant_node;
}
bool HasInputSplitAndTransposeSiblings(const Output<Node>& output) {
NodePtr split_node = FindSplitInput(output.get_node());
if (!split_node) {
return false;
}
AxisVector first_transpose_axis_order;
// get first transpose axis
{
auto constant_node = GetTransposeConstant(*(split_node->get_output_target_inputs(0).begin()));
if (!constant_node)
return false;
first_transpose_axis_order = constant_node->get_axis_vector_val();
}
for (size_t output_idx = 1; output_idx < split_node->get_output_size(); ++output_idx) {
for (auto& input : split_node->get_output_target_inputs(output_idx)) {
auto constant_node = GetTransposeConstant(input);
if (!constant_node)
return false;
AxisVector transpose_axis_order = constant_node->get_axis_vector_val();
if (transpose_axis_order.size() != first_transpose_axis_order.size())
return false;
if (!std::equal(transpose_axis_order.begin(),
transpose_axis_order.end(),
first_transpose_axis_order.begin()))
return false;
}
}
return true;
}
} // namespace
ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
MATCHER_SCOPE(TransposeSinkingSplitBackward);
auto transpose_const_label = wrap_type<Constant>(consumers_count(1));
auto transpose_label =
wrap_type<Transpose>({any_input(), transpose_const_label}, HasInputSplitAndTransposeSiblings);
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose_label_node = pattern_to_output.at(transpose_label).get_node();
NodePtr split = FindSplitInput(transpose_label_node);
auto split_axis_constant = as_type_ptr<Constant>(split->input_value(1).get_node_shared_ptr());
OutputTranspose output_transpose = GetOutputTransposes(split);
const auto transpose_axis_order = output_transpose.transpose_const->get_axis_vector_val();
const auto transpose_element_type = output_transpose.transpose_const->get_element_type();
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
const size_t split_axis = split_axis_constant->get_axis_vector_val()[0];
const size_t reversed_transposed_split_axis = reversed_traspose_axis_order[split_axis];
// insert transpose before split
{
auto input_node = split->input_value(0);
auto new_transpose_const = std::make_shared<Constant>(transpose_element_type,
Shape{transpose_axis_order.size()},
transpose_axis_order);
auto new_transpose = std::make_shared<Transpose>(input_node, new_transpose_const);
split->input(0).replace_source_output(new_transpose->output(0));
copy_runtime_info(input_node.get_node_shared_ptr(), {new_transpose, new_transpose_const});
register_new_node(new_transpose);
}
// update split axis
auto new_split_axis_const = std::make_shared<Constant>(split_axis_constant->get_element_type(),
Shape{},
reversed_transposed_split_axis);
split->input(1).replace_source_output(new_split_axis_const);
// remove split output transposes
for (size_t output_idx = 0; output_idx < split->get_output_size(); ++output_idx) {
for (auto& input : split->get_output_target_inputs(output_idx)) {
input.get_node()->output(0).replace(split->output(output_idx));
}
}
return true;
};
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
MATCHER_SCOPE(TransposeSinkingSplitForward);
auto main_node_label = wrap_type<Split>(IfNodeHasTransposeInputs);
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto& main_node_output = pattern_to_output.at(main_node_label);
auto main_node = main_node_output.get_node_shared_ptr();
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
sink_forward::RemoveZeroInputNode(main_node);
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
}
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
auto split_node = as_type_ptr<Split>(main_node);
auto split_axis_constant = as_type_ptr<Constant>(split_node->input_value(1).get_node_shared_ptr());
const size_t split_axis = split_axis_constant->get_axis_vector_val()[0];
const size_t transposed_split_axis = transpose_axis_order[split_axis];
auto new_split_axis_const =
std::make_shared<Constant>(split_axis_constant->get_element_type(), Shape{}, transposed_split_axis);
split_node->input(1).replace_source_output(new_split_axis_const);
return true;
};
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}

View File

@ -0,0 +1,172 @@
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include <openvino/pass/pattern/op/or.hpp>
#include <transformations/utils/utils.hpp>
#include <utility>
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp"
namespace transpose_sinking {
using namespace ov;
using namespace ov::opset9;
using NodePtr = std::shared_ptr<Node>;
TransposeInputsInfo GetFirstTransposeInput(NodePtr node) {
for (size_t input_idx = 0; input_idx < node->get_input_size(); ++input_idx) {
NodePtr input_node = node->get_input_node_shared_ptr(input_idx);
auto transpose_node = as_type_ptr<Transpose>(input_node);
if (!transpose_node)
continue;
auto constant_node = as_type_ptr<Constant>(transpose_node->input_value(1).get_node_shared_ptr());
if (!constant_node)
continue;
{
TransposeInputsInfo input_info;
input_info.transpose = transpose_node;
input_info.transpose_const = constant_node;
input_info.input_idx = input_idx;
return input_info;
}
}
return TransposeInputsInfo();
}
bool IfNodeHasTransposeInputs(const Output<Node>& output) {
TransposeInputsInfo inputs_info = GetFirstTransposeInput(output.get_node_shared_ptr());
return !inputs_info.isEmpty();
}
AxisVector ReverseTransposeOrder(const AxisVector& axis_order) {
AxisVector out(axis_order.size());
for (size_t i = 0; i < axis_order.size(); i++) {
out.at(axis_order[i]) = i;
}
return out;
}
void SwapOutputNames(Output<Node> output1, Output<Node> output2) {
const auto node2_output_names = output2.get_names();
output2.set_names(output1.get_names());
output1.set_names(node2_output_names);
}
void SwapFriendlyNames(NodePtr node1, NodePtr node2) {
const std::string node2_name = node2->get_friendly_name();
node2->set_friendly_name(node1->get_friendly_name());
node1->set_friendly_name(node2_name);
}
void SwapNames(NodePtr node1, NodePtr node2) {
SwapFriendlyNames(node1, node2);
SwapOutputNames(node1->output(0), node2->output(0));
}
namespace sink_forward {
// insert input reversed transposes, remove first input tranpose
void UpdateInputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_input_info) {
if (transpose_input_info.isEmpty())
return;
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
const size_t tranpose_input_index = transpose_input_info.input_idx;
const auto transpose_element_type = transpose_input_info.transpose_const->get_element_type();
for (size_t i = 0; i < main_node->get_input_size(); ++i) {
auto input_node = main_node->input_value(i);
if (i == tranpose_input_index) {
auto transpose_parent = input_node.get_node()->input_value(0);
main_node->input(i).replace_source_output(transpose_parent);
} else {
auto new_transpose_const = std::make_shared<Constant>(transpose_element_type,
Shape{reversed_traspose_axis_order.size()},
reversed_traspose_axis_order);
auto new_transpose = std::make_shared<Transpose>(input_node, new_transpose_const);
main_node->input(i).replace_source_output(new_transpose->output(0));
copy_runtime_info(input_node.get_node_shared_ptr(), {new_transpose, new_transpose_const});
}
}
}
void RemoveZeroInputNode(NodePtr main_node) {
auto input_node = main_node->input_value(0);
if (input_node.get_node()->get_input_size() < 1)
return;
auto parent_node = input_node.get_node()->input_value(0);
main_node->input(0).replace_source_output(parent_node);
}
NodeVector InsertOutputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_input_info) {
if (transpose_input_info.isEmpty())
return {};
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
const auto transpose_element_type = transpose_input_info.transpose_const->get_element_type();
NodeVector new_nodes;
for (size_t i = 0; i < main_node->get_output_size(); ++i) {
auto main_node_consumers = main_node->output(i).get_target_inputs();
auto new_transpose_const = std::make_shared<Constant>(transpose_element_type,
Shape{transpose_axis_order.size()},
transpose_axis_order);
auto new_transpose = std::make_shared<Transpose>(main_node->output(i), new_transpose_const);
for (auto& consumer : main_node_consumers) {
consumer.replace_source_output(new_transpose);
}
copy_runtime_info(main_node, {new_transpose, new_transpose_const});
SwapOutputNames(main_node->output(i), new_transpose->output(0));
if (main_node->get_output_size() > 1)
new_transpose->set_friendly_name(main_node->get_friendly_name() + "." + std::to_string(i));
else
SwapFriendlyNames(new_transpose, main_node);
new_nodes.push_back(new_transpose);
}
return new_nodes;
}
} // namespace sink_forward
namespace sink_backward {
NodeVector InsertTransposeBeforeNode(NodePtr main_node, std::shared_ptr<Constant> transpose_const) {
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
const auto transpose_element_type = transpose_const->get_element_type();
NodeVector new_nodes;
for (size_t i = 0; i < main_node->get_input_size(); ++i) {
auto input_node = main_node->input_value(i);
auto new_transpose_const = std::make_shared<Constant>(transpose_element_type,
Shape{transpose_axis_order.size()},
transpose_axis_order);
auto new_transpose = std::make_shared<Transpose>(input_node, new_transpose_const);
main_node->input(i).replace_source_output(new_transpose->output(0));
copy_runtime_info(input_node.get_node_shared_ptr(), {new_transpose, new_transpose_const});
new_nodes.push_back(new_transpose);
}
return new_nodes;
}
} // namespace sink_backward
} // namespace transpose_sinking

View File

@ -0,0 +1,358 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <transformations/common_optimizations/transpose_sinking_binary.hpp>
#include <transformations/init_node_info.hpp>
#include <openvino/frontend/manager.hpp>
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include <functional>
#include "gtest/gtest.h"
namespace {
using NodePtr = std::shared_ptr<ov::Node>;
using ModelPtr = std::shared_ptr<ov::Model>;
using Output = ov::Output<ov::Node>;
// ----------------------------------------------------------------------------
class IBinaryFactory {
public:
IBinaryFactory() = default;
virtual ~IBinaryFactory() = default;
virtual NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const = 0;
};
using BinaryFactoryPtr = std::shared_ptr<IBinaryFactory>;
template <typename BinaryT>
class BinaryFactory : public IBinaryFactory {
public:
BinaryFactory() = default;
NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const override {
return std::make_shared<BinaryT>(parent_left_node, parent_right_node);
}
};
template <typename BinaryT>
BinaryFactoryPtr CreateBinaryFactory() {
return std::make_shared<BinaryFactory<BinaryT>>();
}
// ----------------------------------------------------------------------------
class IPassFactory {
public:
IPassFactory() = default;
virtual ~IPassFactory() = default;
virtual void registerPass(ov::pass::Manager& pass_manager) const = 0;
};
using PassFactoryPtr = std::shared_ptr<IPassFactory>;
template <typename PassT>
class PassFactory : public IPassFactory {
public:
void registerPass(ov::pass::Manager& pass_manager) const override {
pass_manager.register_pass<PassT>();
}
};
template <typename PassT>
PassFactoryPtr CreatePassFactory() {
return std::make_shared<PassFactory<PassT>>();
}
std::vector<BinaryFactoryPtr> binary_factories = {
CreateBinaryFactory<ov::opset9::Add>(),
CreateBinaryFactory<ov::opset9::Divide>(),
CreateBinaryFactory<ov::opset9::Maximum>(),
CreateBinaryFactory<ov::opset9::Minimum>(),
CreateBinaryFactory<ov::opset9::Mod>(),
CreateBinaryFactory<ov::opset9::Multiply>(),
CreateBinaryFactory<ov::opset9::Power>(),
CreateBinaryFactory<ov::opset9::SquaredDifference>(),
CreateBinaryFactory<ov::opset9::Subtract>()
};
std::vector<size_t> binary_operations_numbers = {1, 10};
std::vector<size_t> binary_transpose_input_indexes = {0, 1};
} // namespace
namespace binary {
namespace single_consumer {
namespace forward {
namespace one_input_transpose {
std::shared_ptr<ov::Model> CreateFunction(BinaryFactoryPtr binary_factory,
size_t num_binary_ops,
ov::element::Type input_type,
size_t binary_transpose_input_idx) {
const ov::Shape input_shape{1, 96, 55, 55};
const ov::Shape const_shape{1, 55, 55, 96};
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
NodePtr in_op = transpose0;
for (size_t i = 0; i < num_binary_ops; ++i) {
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, const_shape, ov::Shape{1});
if (!binary_transpose_input_idx)
in_op = binary_factory->create(in_op, in_constant);
else
in_op = binary_factory->create(in_constant, in_op);
}
return std::make_shared<ov::Model>(ov::OutputVector{in_op}, ov::ParameterVector{X});
}
std::shared_ptr<ov::Model> CreateReferenceFunction(BinaryFactoryPtr binary_factory,
size_t num_binary_ops,
ov::element::Type input_type,
size_t binary_transpose_input_idx) {
const ov::Shape input_shape{1, 96, 55, 55};
const ov::Shape const_shape{1, 55, 55, 96};
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
NodePtr in_op = X;
for (size_t i = 0; i < num_binary_ops; ++i) {
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, const_shape, ov::Shape{1});
auto transpose_reversed_const =
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto transpose_reversed = std::make_shared<ov::opset9::Transpose>(in_constant, transpose_reversed_const);
if (!binary_transpose_input_idx)
in_op = binary_factory->create(in_op, transpose_reversed);
else
in_op = binary_factory->create(transpose_reversed, in_op);
}
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
return std::make_shared<ov::Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
}
} // namespace one_input_transpose
namespace double_transpose {
std::shared_ptr<ov::Model> CreateFunction(BinaryFactoryPtr binary_factory,
size_t num_binary_ops,
ov::element::Type input_type) {
const ov::Shape input_shape{1, 96, 55, 55};
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
NodePtr in_op = transpose0;
for (size_t i = 0; i < num_binary_ops; ++i) {
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order1);
in_op = binary_factory->create(in_op, transpose1);
}
return std::make_shared<ov::Model>(ov::OutputVector{in_op}, ov::ParameterVector{X});
}
std::shared_ptr<ov::Model> CreateReferenceFunction(BinaryFactoryPtr binary_factory,
size_t num_binary_ops,
ov::element::Type input_type) {
const ov::Shape input_shape{1, 96, 55, 55};
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
NodePtr in_op = X;
for (size_t i = 0; i < num_binary_ops; ++i) {
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order1);
auto transpose_reversed_const =
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto transpose_reversed = std::make_shared<ov::opset9::Transpose>(transpose1, transpose_reversed_const);
in_op = binary_factory->create(in_op, transpose_reversed);
}
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
return std::make_shared<ov::Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
}
} // namespace double_transpose
} // namespace forward
namespace backward {
namespace one_input_transpose {
std::shared_ptr<ov::Model> CreateFunction(BinaryFactoryPtr binary_factory,
size_t num_binary_ops,
ov::element::Type input_type,
size_t binary_transpose_input_idx) {
const ov::Shape input_shape{1, 96, 55, 55};
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
NodePtr in_op = X;
for (size_t i = 0; i < num_binary_ops; ++i) {
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
if (!binary_transpose_input_idx)
in_op = binary_factory->create(in_op, in_constant);
else
in_op = binary_factory->create(in_constant, in_op);
}
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
return std::make_shared<ov::Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
}
std::shared_ptr<ov::Model> CreateReferenceFunction(BinaryFactoryPtr binary_factory,
size_t num_binary_ops,
ov::element::Type input_type,
size_t binary_transpose_input_idx) {
const ov::Shape input_shape{1, 96, 55, 55};
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
NodePtr in_op = transpose0;
for (size_t i = 0; i < num_binary_ops; ++i) {
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order);
if (!binary_transpose_input_idx)
in_op = binary_factory->create(in_op, transpose);
else
in_op = binary_factory->create(transpose, in_op);
}
return std::make_shared<ov::Model>(ov::OutputVector{in_op}, ov::ParameterVector{X});
}
} // namespace one_input_transpose
} // namespace backward
} // namespace single_consumer
} // namespace binary
using CreateGraphBinaryF = std::function<std::shared_ptr<ov::Model>(BinaryFactoryPtr unary_factory,
size_t num_binary_ops,
ov::element::Type input_type,
size_t binary_transpose_input_idx)>;
using TestBinaryParams = std::tuple<BinaryFactoryPtr,
PassFactoryPtr,
size_t, /* num_binary_ops */
CreateGraphBinaryF, /* model_factory */
CreateGraphBinaryF, /* reference_model_factory */
ov::element::Type, /* input type */
size_t>; /* binary_transpose_input_idx */
class TransposeSinkingBinaryTestFixture : public ::testing::WithParamInterface<TestBinaryParams>,
public TransformationTestsF {};
TEST_P(TransposeSinkingBinaryTestFixture, CompareFunctions) {
BinaryFactoryPtr unary_factory;
PassFactoryPtr pass_factory;
size_t num_binary_ops;
CreateGraphBinaryF model_factory;
CreateGraphBinaryF reference_model_factory;
ov::element::Type input_type;
size_t binary_transpose_input_idx;
std::tie(unary_factory,
pass_factory,
num_binary_ops,
model_factory,
reference_model_factory,
input_type,
binary_transpose_input_idx) = this->GetParam();
model = model_factory(unary_factory, num_binary_ops, input_type, binary_transpose_input_idx);
model_ref = reference_model_factory(unary_factory, num_binary_ops, input_type, binary_transpose_input_idx);
pass_factory->registerPass(manager);
}
INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryForwardTestSuite, TransposeSinkingBinaryTestFixture,
::testing::Combine(::testing::ValuesIn(binary_factories),
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingBinaryElementwiseForward>()),
::testing::ValuesIn(binary_operations_numbers),
::testing::Values(binary::single_consumer::forward::one_input_transpose::CreateFunction),
::testing::Values(binary::single_consumer::forward::one_input_transpose::CreateReferenceFunction),
::testing::Values(ov::element::f32),
::testing::ValuesIn(binary_transpose_input_indexes)));
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingBinaryBackwardTestSuite,
TransposeSinkingBinaryTestFixture,
::testing::Combine(::testing::ValuesIn(binary_factories),
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingBinaryElementwiseBackward>()),
::testing::ValuesIn(binary_operations_numbers),
::testing::Values(binary::single_consumer::backward::one_input_transpose::CreateFunction),
::testing::Values(binary::single_consumer::backward::one_input_transpose::CreateReferenceFunction),
::testing::Values(ov::element::f32),
::testing::ValuesIn(binary_transpose_input_indexes)));
// --------------------------------------------------------------------------------------
using CreateGraphBinaryTwoTransposeInputsF = std::function<
std::shared_ptr<ov::Model>(BinaryFactoryPtr unary_factory, size_t num_binary_ops, ov::element::Type input_type)>;
using TestBinaryTwoTransposeInputsParams = std::tuple<BinaryFactoryPtr,
PassFactoryPtr,
size_t, /* num_binary_ops */
CreateGraphBinaryTwoTransposeInputsF, /* model_factory */
CreateGraphBinaryTwoTransposeInputsF, /* reference_model_factory */
ov::element::Type>; /* input type */
class TransposeSinkingBinaryTwoTransposeInputsTestFixture
: public ::testing::WithParamInterface<TestBinaryTwoTransposeInputsParams>,
public TransformationTestsF {};
TEST_P(TransposeSinkingBinaryTwoTransposeInputsTestFixture, CompareFunctions) {
BinaryFactoryPtr unary_factory;
PassFactoryPtr pass_factory;
size_t num_binary_ops;
CreateGraphBinaryTwoTransposeInputsF model_factory;
CreateGraphBinaryTwoTransposeInputsF reference_model_factory;
ov::element::Type input_type;
std::tie(unary_factory, pass_factory, num_binary_ops, model_factory, reference_model_factory, input_type) =
this->GetParam();
model = model_factory(unary_factory, num_binary_ops, input_type);
model_ref = reference_model_factory(unary_factory, num_binary_ops, input_type);
pass_factory->registerPass(manager);
}
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingBinaryTwoTransposeInputsForwardTestSuite,
TransposeSinkingBinaryTwoTransposeInputsTestFixture,
::testing::Combine(::testing::ValuesIn(binary_factories),
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingBinaryElementwiseForward>()),
::testing::ValuesIn(binary_operations_numbers),
::testing::Values(binary::single_consumer::forward::double_transpose::CreateFunction),
::testing::Values(binary::single_consumer::forward::double_transpose::CreateReferenceFunction),
::testing::Values(ov::element::f32)));

View File

@ -0,0 +1,400 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <transformations/common_optimizations/transpose_sinking_concat.hpp>
#include <transformations/init_node_info.hpp>
#include <openvino/frontend/manager.hpp>
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include <functional>
#include "gtest/gtest.h"
namespace {
using NodePtr = std::shared_ptr<ov::Node>;
using ModelPtr = std::shared_ptr<ov::Model>;
using Output = ov::Output<ov::Node>;
// ----------------------------------------------------------------------------
class IBinaryFactory {
public:
IBinaryFactory() = default;
virtual ~IBinaryFactory() = default;
virtual NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const = 0;
};
using BinaryFactoryPtr = std::shared_ptr<IBinaryFactory>;
template <typename BinaryT>
class BinaryFactory : public IBinaryFactory {
public:
BinaryFactory() = default;
NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const override {
return std::make_shared<BinaryT>(parent_left_node, parent_right_node);
}
};
template <typename BinaryT>
BinaryFactoryPtr CreateBinaryFactory() {
return std::make_shared<BinaryFactory<BinaryT>>();
}
// ----------------------------------------------------------------------------
class IPassFactory {
public:
IPassFactory() = default;
virtual ~IPassFactory() = default;
virtual void registerPass(ov::pass::Manager& pass_manager) const = 0;
};
using PassFactoryPtr = std::shared_ptr<IPassFactory>;
template <typename PassT>
class PassFactory : public IPassFactory {
public:
void registerPass(ov::pass::Manager& pass_manager) const override {
pass_manager.register_pass<PassT>();
}
};
template <typename PassT>
PassFactoryPtr CreatePassFactory() {
return std::make_shared<PassFactory<PassT>>();
}
std::vector<BinaryFactoryPtr> binary_factories = {
CreateBinaryFactory<ov::opset9::Add>(),
CreateBinaryFactory<ov::opset9::Divide>(),
CreateBinaryFactory<ov::opset9::Maximum>(),
CreateBinaryFactory<ov::opset9::Minimum>(),
CreateBinaryFactory<ov::opset9::Mod>(),
CreateBinaryFactory<ov::opset9::Multiply>(),
CreateBinaryFactory<ov::opset9::Power>(),
CreateBinaryFactory<ov::opset9::SquaredDifference>(),
CreateBinaryFactory<ov::opset9::Subtract>()
};
std::vector<size_t> binary_operations_numbers = {1, 10};
std::vector<size_t> binary_transpose_input_indexes = {0, 1};
} // namespace
using CreateGraphConcatF = std::function< std::shared_ptr<ov::Model> (size_t num_concat_ops,
ov::element::Type input_type,
size_t concat_transpose_input_idx,
size_t num_concat_inputs) >;
using TestConcatParams = std::tuple<PassFactoryPtr,
size_t, /* num_concat_ops */
CreateGraphConcatF, /* model_factory */
CreateGraphConcatF, /* reference_model_factory */
ov::element::Type, /* input type */
size_t, /* concat_transpose_input_idx */
size_t>; /* num_concat_inputs */
class TransposeSinkingConcatTestFixture: public ::testing::WithParamInterface<TestConcatParams>,
public TransformationTestsF {};
namespace {
std::vector<size_t> concat_operations_numbers = {1, 10};
std::vector<size_t> concat_transpose_input_indexes = {0, 2};
} // namespace
namespace single_consumer {
namespace forward {
namespace one_input_transpose {
std::shared_ptr<ov::Model> CreateFunction(size_t num_concat_ops,
ov::element::Type input_type,
size_t concat_transpose_input_idx,
size_t num_concat_inputs) {
const ov::Shape input_shape{1, 96, 55, 55};
const ov::Shape const_shape{1, 55, 55, 96};
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
NodePtr in_op = transpose0;
for (size_t i = 0; i < num_concat_ops; ++i) {
ov::OutputVector concat_inputs;
for (size_t j = 0; j < num_concat_inputs; ++j) {
if (j == concat_transpose_input_idx)
concat_inputs.push_back(in_op);
else
concat_inputs.push_back(std::make_shared<ov::opset9::Constant>(input_type, const_shape, ov::Shape{1}));
}
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 1);
}
return std::make_shared<ov::Model>(ov::OutputVector{in_op}, ov::ParameterVector{X});
}
std::shared_ptr<ov::Model> CreateReferenceFunction(size_t num_concat_ops,
ov::element::Type input_type,
size_t concat_transpose_input_idx,
size_t num_concat_inputs) {
const ov::Shape input_shape{1, 96, 55, 55};
const ov::Shape const_shape{1, 55, 55, 96};
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
NodePtr in_op = X;
for (size_t i = 0; i < num_concat_ops; ++i) {
ov::OutputVector concat_inputs;
for (size_t j = 0; j < num_concat_inputs; ++j) {
if (j == concat_transpose_input_idx) {
concat_inputs.push_back(in_op);
} else {
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, const_shape, ov::Shape{1});
auto transpose_reversed_const =
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto transpose_reversed =
std::make_shared<ov::opset9::Transpose>(in_constant, transpose_reversed_const);
concat_inputs.push_back(transpose_reversed);
}
}
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 2);
}
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
return std::make_shared<ov::Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
}
} // namespace one_input_transpose
namespace double_transpose {
std::shared_ptr<ov::Model> CreateFunction(size_t num_concat_ops,
ov::element::Type input_type,
size_t num_concat_inputs) {
const ov::Shape input_shape{1, 96, 55, 55};
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
NodePtr in_op = transpose0;
for (size_t i = 0; i < num_concat_ops; ++i) {
ov::OutputVector concat_inputs;
concat_inputs.push_back(in_op);
for (size_t j = 1; j < num_concat_inputs; ++j) {
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
auto ng_order1 =
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order1);
concat_inputs.push_back(transpose1);
}
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 1);
}
return std::make_shared<ov::Model>(ov::OutputVector{in_op}, ov::ParameterVector{X});
}
std::shared_ptr<ov::Model> CreateReferenceFunction(size_t num_concat_ops,
ov::element::Type input_type,
size_t num_concat_inputs) {
const ov::Shape input_shape{1, 96, 55, 55};
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
NodePtr in_op = X;
for (size_t i = 0; i < num_concat_ops; ++i) {
ov::OutputVector concat_inputs;
concat_inputs.push_back(in_op);
for (size_t j = 1; j < num_concat_inputs; ++j) {
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
auto ng_order1 =
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order1);
auto transpose_reversed_const =
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto transpose_reversed = std::make_shared<ov::opset9::Transpose>(transpose1, transpose_reversed_const);
concat_inputs.push_back(transpose_reversed);
}
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 2);
}
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
return std::make_shared<ov::Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
}
} // namespace double_transpose
} // namespace forward
namespace backward {
std::shared_ptr<ov::Model> CreateFunction(size_t num_concat_ops,
ov::element::Type input_type,
size_t concat_transpose_input_idx,
size_t num_concat_inputs) {
const ov::Shape input_shape{1, 96, 55, 55};
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
NodePtr in_op = X;
for (size_t i = 0; i < num_concat_ops; ++i) {
ov::OutputVector concat_inputs;
for (size_t j = 0; j < num_concat_inputs; ++j) {
if (j == concat_transpose_input_idx)
concat_inputs.push_back(in_op);
else
concat_inputs.push_back(std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1}));
}
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 1);
}
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
return std::make_shared<ov::Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
}
std::shared_ptr<ov::Model> CreateReferenceFunction(size_t num_concat_ops,
ov::element::Type input_type,
size_t concat_transpose_input_idx,
size_t num_concat_inputs) {
const ov::Shape input_shape{1, 96, 55, 55};
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
NodePtr in_op = transpose0;
for (size_t i = 0; i < num_concat_ops; ++i) {
ov::OutputVector concat_inputs;
for (size_t j = 0; j < num_concat_inputs; ++j) {
if (j == concat_transpose_input_idx) {
concat_inputs.push_back(in_op);
} else {
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
auto transpose_reversed_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose_reversed = std::make_shared<ov::opset9::Transpose>(in_constant, transpose_reversed_const);
concat_inputs.push_back(transpose_reversed);
}
}
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 3);
}
return std::make_shared<ov::Model>(ov::OutputVector{in_op}, ov::ParameterVector{X});
}
} // namespace backward
} // namespace single_consumer
TEST_P(TransposeSinkingConcatTestFixture, CompareFunctions) {
PassFactoryPtr pass_factory;
size_t num_concat_ops;
CreateGraphConcatF model_factory;
CreateGraphConcatF reference_model_factory;
ov::element::Type input_type;
size_t concat_transpose_input_idx;
size_t num_concat_inputs;
std::tie(pass_factory,
num_concat_ops,
model_factory,
reference_model_factory,
input_type,
concat_transpose_input_idx,
num_concat_inputs) = this->GetParam();
model = model_factory(num_concat_ops, input_type, concat_transpose_input_idx, num_concat_inputs);
model_ref = reference_model_factory(num_concat_ops, input_type, concat_transpose_input_idx, num_concat_inputs);
pass_factory->registerPass(manager);
}
INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatForwardTestSuite, TransposeSinkingConcatTestFixture,
::testing::Combine(::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingConcatForward>()),
::testing::ValuesIn(concat_operations_numbers),
::testing::Values(single_consumer::forward::one_input_transpose::CreateFunction),
::testing::Values(single_consumer::forward::one_input_transpose::CreateReferenceFunction),
::testing::Values(ov::element::f32),
::testing::ValuesIn(concat_transpose_input_indexes),
::testing::Values(5)));
INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatBackwardTestSuite, TransposeSinkingConcatTestFixture,
::testing::Combine(::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingConcatBackward>()),
::testing::ValuesIn(concat_operations_numbers),
::testing::Values(single_consumer::backward::CreateFunction),
::testing::Values(single_consumer::backward::CreateReferenceFunction),
::testing::Values(ov::element::f32),
::testing::ValuesIn(concat_transpose_input_indexes),
::testing::Values(5)));
// --------------------------------------------------------------------------------------
using CreateGraphConcatAllTransposesInputF = std::function<std::shared_ptr<ov::Model>(size_t num_concat_ops,
ov::element::Type input_type,
size_t num_concat_inputs)>;
using TestConcatAllTransposesInputParams = std::tuple<PassFactoryPtr,
size_t, /* num_concat_ops */
CreateGraphConcatAllTransposesInputF, /* model_factory */
CreateGraphConcatAllTransposesInputF, /* reference_model_factory */
ov::element::Type, /* input type */
size_t>; /* num_concat_inputs */
class TransposeSinkingConcatAllTransposesInputTestFixture
: public ::testing::WithParamInterface<TestConcatAllTransposesInputParams>,
public TransformationTestsF {};
TEST_P(TransposeSinkingConcatAllTransposesInputTestFixture, CompareFunctions) {
PassFactoryPtr pass_factory;
size_t num_concat_ops;
CreateGraphConcatAllTransposesInputF model_factory;
CreateGraphConcatAllTransposesInputF reference_model_factory;
ov::element::Type input_type;
size_t num_concat_inputs;
std::tie(pass_factory,
num_concat_ops,
model_factory,
reference_model_factory,
input_type,
num_concat_inputs) = this->GetParam();
model = model_factory(num_concat_ops, input_type, num_concat_inputs);
model_ref = reference_model_factory(num_concat_ops, input_type, num_concat_inputs);
pass_factory->registerPass(manager);
}
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingConcatForwardAllTransposesTestSuite,
TransposeSinkingConcatAllTransposesInputTestFixture,
::testing::Combine(::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingConcatForward>()),
::testing::ValuesIn(concat_operations_numbers),
::testing::Values(single_consumer::forward::double_transpose::CreateFunction),
::testing::Values(single_consumer::forward::double_transpose::CreateReferenceFunction),
::testing::Values(ov::element::f32),
::testing::Values(5)));

View File

@ -0,0 +1,552 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <transformations/common_optimizations/transpose_sinking_split.hpp>
#include <transformations/init_node_info.hpp>
#include <openvino/frontend/manager.hpp>
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include <functional>
#include "gtest/gtest.h"
namespace {
using NodePtr = std::shared_ptr<ov::Node>;
using ModelPtr = std::shared_ptr<ov::Model>;
using Output = ov::Output<ov::Node>;
// ----------------------------------------------------------------------------
class IBinaryFactory {
public:
IBinaryFactory() = default;
virtual ~IBinaryFactory() = default;
virtual NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const = 0;
};
using BinaryFactoryPtr = std::shared_ptr<IBinaryFactory>;
template <typename BinaryT>
class BinaryFactory : public IBinaryFactory {
public:
BinaryFactory() = default;
NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const override {
return std::make_shared<BinaryT>(parent_left_node, parent_right_node);
}
};
template <typename BinaryT>
BinaryFactoryPtr CreateBinaryFactory() {
return std::make_shared<BinaryFactory<BinaryT>>();
}
// ----------------------------------------------------------------------------
class IPassFactory {
public:
IPassFactory() = default;
virtual ~IPassFactory() = default;
virtual void registerPass(ov::pass::Manager& pass_manager) const = 0;
};
using PassFactoryPtr = std::shared_ptr<IPassFactory>;
template <typename PassT>
class PassFactory : public IPassFactory {
public:
void registerPass(ov::pass::Manager& pass_manager) const override {
pass_manager.register_pass<PassT>();
}
};
template <typename PassT>
PassFactoryPtr CreatePassFactory() {
return std::make_shared<PassFactory<PassT>>();
}
std::vector<BinaryFactoryPtr> binary_factories = {
CreateBinaryFactory<ov::opset9::Add>(),
CreateBinaryFactory<ov::opset9::Divide>(),
CreateBinaryFactory<ov::opset9::Maximum>(),
CreateBinaryFactory<ov::opset9::Minimum>(),
CreateBinaryFactory<ov::opset9::Mod>(),
CreateBinaryFactory<ov::opset9::Multiply>(),
CreateBinaryFactory<ov::opset9::Power>(),
CreateBinaryFactory<ov::opset9::SquaredDifference>(),
CreateBinaryFactory<ov::opset9::Subtract>()
};
std::vector<size_t> binary_operations_numbers = {1, 10};
std::vector<size_t> binary_transpose_input_indexes = {0, 1};
} // namespace
// --------------------------------------------------------------------------------------
using CreateGraphSplitForwardF = std::function< std::shared_ptr<ov::Model> (size_t num_split_ops,
size_t num_split_outputs,
ov::element::Type input_type)>;
using TestSplitForwardParams = std::tuple<PassFactoryPtr,
size_t, /* num_split_ops */
size_t, /* num_split_outputs */
CreateGraphSplitForwardF, /* model_factory */
CreateGraphSplitForwardF, /* reference_model_factory */
ov::element::Type> /* input type */;
class TransposeSinkingSplitForwardTestFixture: public ::testing::WithParamInterface<TestSplitForwardParams>,
public TransformationTestsF {};
namespace {
std::vector<size_t> split_operations_numbers = {1, 10};
std::vector<size_t> split_outputs_numbers = {2, 3};
} // namespace
namespace split {
namespace forward {
std::shared_ptr<ov::Model> CreateFunction(size_t num_split_ops,
size_t num_split_outputs,
ov::element::Type input_type) {
const ov::Shape input_shape{96, static_cast<size_t>(std::pow(num_split_outputs, num_split_ops + 1)), 55, 55};
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
ov::OutputVector outputs;
Output in_op = transpose0->output(0);
for (size_t i = 0; i < num_split_ops; ++i) {
auto split_axis_const = std::make_shared<ov::opset9::Constant>(ov::element::u64,
ov::Shape{},
2);
auto split = std::make_shared<ov::opset9::Split>(in_op, split_axis_const, num_split_outputs);
for (size_t num_output = 0; num_output < num_split_outputs - 1; ++num_output) {
outputs.push_back(split->output(num_output));
}
in_op = split->output(num_split_outputs - 1);
}
outputs.push_back(in_op);
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
}
std::shared_ptr<ov::Model> CreateReferenceFunction(size_t num_split_ops,
size_t num_split_outputs,
ov::element::Type input_type) {
const ov::Shape input_shape{96, static_cast<size_t>(std::pow(num_split_outputs, num_split_ops + 1)), 55, 55};
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
ov::OutputVector outputs;
Output in_op = X->output(0);
for (size_t i = 0; i < num_split_ops; ++i) {
auto split_axis_const = std::make_shared<ov::opset9::Constant>(ov::element::u64,
ov::Shape{},
1);
auto split = std::make_shared<ov::opset9::Split>(in_op, split_axis_const, num_split_outputs);
for (size_t num_output = 0; num_output < num_split_outputs - 1; ++num_output) {
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(split->output(num_output), ng_order0);
outputs.push_back(transpose0);
}
in_op = split->output(num_split_outputs - 1);
}
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
outputs.push_back(transpose0);
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
}
} // namespace forward
} // namespace split
TEST_P(TransposeSinkingSplitForwardTestFixture, CompareFunctions) {
PassFactoryPtr pass_factory;
size_t num_split_ops;
size_t num_split_outputs;
CreateGraphSplitForwardF model_factory;
CreateGraphSplitForwardF reference_model_factory;
ov::element::Type input_type;
std::tie(pass_factory,
num_split_ops,
num_split_outputs,
model_factory,
reference_model_factory,
input_type) = this->GetParam();
model = model_factory(num_split_ops, num_split_outputs, input_type);
model_ref = reference_model_factory(num_split_ops, num_split_outputs, input_type);
pass_factory->registerPass(manager);
}
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingSplitForwardTestSuite,
TransposeSinkingSplitForwardTestFixture,
::testing::Combine(::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingSplitForward>()),
::testing::ValuesIn(split_operations_numbers),
::testing::ValuesIn(split_outputs_numbers),
::testing::Values(split::forward::CreateFunction),
::testing::Values(split::forward::CreateReferenceFunction),
::testing::Values(ov::element::f32)));
// --------------------------------------------------------------------------------------
using CreateGraphSplitBackwardF = std::function< std::shared_ptr<ov::Model> (size_t split_tree_depth,
size_t num_split_outputs,
ov::element::Type input_type)>;
using TestSplitBackwardParams = std::tuple<PassFactoryPtr,
size_t, /* split_tree_depth */
size_t, /* num_split_outputs */
CreateGraphSplitBackwardF, /* model_factory */
CreateGraphSplitBackwardF, /* reference_model_factory */
ov::element::Type> /* input type */;
class TransposeSinkingSplitBackwardTestFixture: public ::testing::WithParamInterface<TestSplitBackwardParams>,
public TransformationTestsF {};
namespace {
std::vector<size_t> split_tree_depth_nums = {1, 3};
} // namespace
// --------------------------------------------------------------------------------------
namespace split {
namespace backward {
class SplitFactory {
public:
SplitFactory(size_t axis, size_t n_outputs, ov::element::Type elem_type) :
_axis(axis), _n_outputs(n_outputs), _elem_type(elem_type) {}
NodePtr create(Output parent) const {
auto split_axis_const = std::make_shared<ov::opset9::Constant>(_elem_type,
ov::Shape{},
_axis);
return std::make_shared<ov::opset9::Split>(parent, split_axis_const, _n_outputs);
}
private:
const size_t _axis;
const size_t _n_outputs;
const ov::element::Type _elem_type;
};
void CreateSplitTree(size_t max_depth, size_t depth, Output parent, const SplitFactory & split_factory, ov::OutputVector & leaves) {
if (depth == max_depth) {
leaves.push_back(parent);
return;
}
auto split = split_factory.create(parent);
for (size_t output_idx = 0; output_idx < split->get_output_size(); ++output_idx) {
CreateSplitTree(max_depth, depth + 1, split->output(output_idx), split_factory, leaves);
}
}
std::shared_ptr<ov::Model> CreateFunction(size_t split_tree_depth,
size_t num_split_outputs,
ov::element::Type input_type) {
const size_t split_input_dim_value = static_cast<size_t>(std::pow(num_split_outputs, split_tree_depth + 1));
const ov::Shape input_shape{96, split_input_dim_value, 55, 55};
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
ov::OutputVector split_tree_leaves;
{
SplitFactory split_factory(/* axis */ 1, num_split_outputs, /* elem_type */ ov::element::u64);
CreateSplitTree(split_tree_depth, /* depth */ 0, X->output(0), split_factory, split_tree_leaves);
}
ov::OutputVector outputs;
for (auto& split_tree_leaf : split_tree_leaves) {
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
const size_t split_dim_current_value = static_cast<size_t>(split_input_dim_value / std::pow(num_split_outputs, split_tree_depth));
auto reshape_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{3}, ov::Shape{96, 55, split_dim_current_value * 55});
auto reshape = std::make_shared<ov::opset9::Reshape>(transpose, reshape_const, false);
outputs.push_back(reshape);
}
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
}
std::shared_ptr<ov::Model> CreateReferenceFunction(size_t split_tree_depth,
size_t num_split_outputs,
ov::element::Type input_type) {
const size_t split_input_dim_value = static_cast<size_t>(std::pow(num_split_outputs, split_tree_depth + 1));
const ov::Shape input_shape{96, split_input_dim_value, 55, 55};
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto transpose = std::make_shared<ov::opset9::Transpose>(X, ng_order);
ov::OutputVector split_tree_leaves;
{
SplitFactory split_factory(/* axis */ 2, num_split_outputs, /* elem_type */ ov::element::u64);
CreateSplitTree(split_tree_depth, /* depth */ 0, transpose->output(0), split_factory, split_tree_leaves);
}
ov::OutputVector outputs;
for (auto& split_tree_leaf : split_tree_leaves) {
const size_t split_dim_current_value = static_cast<size_t>(split_input_dim_value / std::pow(num_split_outputs, split_tree_depth));
auto reshape_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{3}, ov::Shape{96, 55, split_dim_current_value * 55});
auto reshape = std::make_shared<ov::opset9::Reshape>(split_tree_leaf, reshape_const, false);
outputs.push_back(reshape);
}
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
}
} // namespace backward
} // namespace split
TEST_P(TransposeSinkingSplitBackwardTestFixture, CompareFunctions) {
PassFactoryPtr pass_factory;
size_t split_tree_depth;
size_t num_split_outputs;
CreateGraphSplitBackwardF model_factory;
CreateGraphSplitBackwardF reference_model_factory;
ov::element::Type input_type;
std::tie(pass_factory,
split_tree_depth,
num_split_outputs,
model_factory,
reference_model_factory,
input_type) = this->GetParam();
model = model_factory(split_tree_depth, num_split_outputs, input_type);
model_ref = reference_model_factory(split_tree_depth, num_split_outputs, input_type);
pass_factory->registerPass(manager);
}
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingSplitBackwardTestSuite,
TransposeSinkingSplitBackwardTestFixture,
::testing::Combine(::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingSplitBackward>()),
::testing::ValuesIn(split_tree_depth_nums),
::testing::ValuesIn(split_outputs_numbers),
::testing::Values(split::backward::CreateFunction),
::testing::Values(split::backward::CreateReferenceFunction),
::testing::Values(ov::element::f32)));
using TransposeInsertF = std::function< ov::OutputVector (const ov::OutputVector& split_tree_leaves)>;
using CreateGraphSplitBackwardRestrictF = std::function< std::shared_ptr<ov::Model> (size_t split_tree_depth,
size_t num_split_outputs,
ov::element::Type input_type,
TransposeInsertF tranpose_insert_function)>;
using TestSplitBackwardRestrictParams = std::tuple<PassFactoryPtr,
size_t, /* split_tree_depth */
size_t, /* num_split_outputs */
CreateGraphSplitBackwardRestrictF, /* model_factory */
ov::element::Type, /* input type */
TransposeInsertF>; /* insert transpose function */
class TransposeSinkingSplitBackwardRestrictTestFixture: public ::testing::WithParamInterface<TestSplitBackwardRestrictParams>,
public TransformationTestsF {};
TEST_P(TransposeSinkingSplitBackwardRestrictTestFixture, CompareFunctions) {
PassFactoryPtr pass_factory;
size_t split_tree_depth;
size_t num_split_outputs;
CreateGraphSplitBackwardRestrictF model_factory;
ov::element::Type input_type;
TransposeInsertF tranpose_insert_function;
std::tie(pass_factory,
split_tree_depth,
num_split_outputs,
model_factory,
input_type,
tranpose_insert_function) = this->GetParam();
model = model_factory(split_tree_depth, num_split_outputs, input_type, tranpose_insert_function);
model_ref = model->clone();
pass_factory->registerPass(manager);
}
namespace split {
namespace backward {
namespace restrictions {
std::shared_ptr<ov::Model> CreateFunction(size_t split_tree_depth,
size_t num_split_outputs,
ov::element::Type input_type,
TransposeInsertF transpose_insert_func) {
const size_t split_input_dim_value = static_cast<size_t>(std::pow(num_split_outputs, split_tree_depth + 1));
const ov::Shape input_shape{96, split_input_dim_value, 55, 55};
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
ov::OutputVector split_tree_leaves;
{
SplitFactory split_factory(/* axis */ 1, num_split_outputs, /* elem_type */ ov::element::u64);
CreateSplitTree(split_tree_depth, /* depth */ 0, X->output(0), split_factory, split_tree_leaves);
}
ov::OutputVector outputs;
for (auto& split_tree_leaf : transpose_insert_func(split_tree_leaves)) {
const size_t split_dim_current_value = static_cast<size_t>(split_input_dim_value / std::pow(num_split_outputs, split_tree_depth));
auto reshape_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{3}, ov::Shape{96, 55, split_dim_current_value * 55});
auto reshape = std::make_shared<ov::opset9::Reshape>(split_tree_leaf, reshape_const, false);
outputs.push_back(reshape);
}
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
}
ov::OutputVector OnlyFirstTranspose(const ov::OutputVector& split_tree_leaves) {
ov::OutputVector outputs;
{
auto& split_tree_leaf = split_tree_leaves.front();
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
outputs.push_back(transpose);
}
for (size_t leaf_idx = 1; leaf_idx < split_tree_leaves.size(); ++leaf_idx) {
outputs.push_back(split_tree_leaves[leaf_idx]);
}
return outputs;
}
ov::OutputVector OnlyLastTranspose(const ov::OutputVector& split_tree_leaves) {
ov::OutputVector outputs;
{
auto& split_tree_leaf = split_tree_leaves.back();
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
outputs.push_back(transpose);
}
for (size_t leaf_idx = 0; leaf_idx < split_tree_leaves.size() - 1; ++leaf_idx) {
outputs.push_back(split_tree_leaves[leaf_idx]);
}
return outputs;
}
ov::OutputVector OnlyMiddleTranspose(const ov::OutputVector& split_tree_leaves) {
ov::OutputVector outputs;
size_t middle_idx = split_tree_leaves.size() / 2;
if (split_tree_leaves.size() % 2)
++middle_idx;
for (size_t leaf_idx = 0; leaf_idx < split_tree_leaves.size() - 1; ++leaf_idx) {
if (leaf_idx == middle_idx) {
auto& split_tree_leaf = split_tree_leaves[leaf_idx];
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
outputs.push_back(transpose);
} else {
outputs.push_back(split_tree_leaves[leaf_idx]);
}
}
return outputs;
}
ov::OutputVector FirstAnotherTranspose(const ov::OutputVector& split_tree_leaves) {
ov::OutputVector outputs;
{
auto& split_tree_leaf = split_tree_leaves.front();
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
outputs.push_back(transpose);
}
for (size_t leaf_idx = 1; leaf_idx < split_tree_leaves.size(); ++leaf_idx) {
auto& split_tree_leaf = split_tree_leaves[leaf_idx];
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
outputs.push_back(transpose);
}
return outputs;
}
ov::OutputVector LastAnotherTranspose(const ov::OutputVector& split_tree_leaves) {
ov::OutputVector outputs;
{
auto& split_tree_leaf = split_tree_leaves.back();
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
outputs.push_back(transpose);
}
for (size_t leaf_idx = 0; leaf_idx < split_tree_leaves.size() - 1; ++leaf_idx) {
auto& split_tree_leaf = split_tree_leaves[leaf_idx];
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
outputs.push_back(transpose);
}
return outputs;
}
ov::OutputVector MiddleAnotherTranspose(const ov::OutputVector& split_tree_leaves) {
ov::OutputVector outputs;
size_t middle_idx = split_tree_leaves.size() / 2;
if (split_tree_leaves.size() % 2)
++middle_idx;
for (size_t leaf_idx = 0; leaf_idx < split_tree_leaves.size(); ++leaf_idx) {
auto& split_tree_leaf = split_tree_leaves[leaf_idx];
if (leaf_idx == middle_idx) {
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
outputs.push_back(transpose);
} else {
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
outputs.push_back(transpose);
}
}
return outputs;
}
} // namespace restrictions
} // namespace backward
} // namespace split
namespace {
std::vector<TransposeInsertF> insertTransposeFactories = {
split::backward::restrictions::OnlyFirstTranspose,
split::backward::restrictions::OnlyLastTranspose,
split::backward::restrictions::OnlyMiddleTranspose,
split::backward::restrictions::FirstAnotherTranspose,
split::backward::restrictions::LastAnotherTranspose,
split::backward::restrictions::MiddleAnotherTranspose
};
} // namespace
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingSplitBackwardRestrictTestSuite,
TransposeSinkingSplitBackwardRestrictTestFixture,
::testing::Combine(::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingSplitBackward>()),
::testing::Values(1),
::testing::Values(5),
::testing::Values(split::backward::restrictions::CreateFunction),
::testing::Values(ov::element::f32),
::testing::ValuesIn(insertTransposeFactories)));