Transpose sinking on passes for Concat, Split and binary eltwise operations (#13718)
* initial * clang cleanup fixes * remove TrasposeAxis function; cleanup namespaces * fix TransposeInputsInfo spell * one_input_transpose spell * cleanup speel * spell * decompose forward sinking * decompose backward sink * use NodeVector * clang cleanup * decomposite transformations into different files * decompose unit tests * clang cleanup * azure build fixes * code review fixes * clang cleanup fixes
This commit is contained in:
parent
738d7bb09f
commit
0846bdb67e
@ -0,0 +1,30 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API TransposeSinkingBinaryElementwiseForward;
|
||||
class TRANSFORMATIONS_API TransposeSinkingBinaryElementwiseBackward;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
class ov::pass::TransposeSinkingBinaryElementwiseForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryElementwiseForward", "0");
|
||||
TransposeSinkingBinaryElementwiseForward();
|
||||
};
|
||||
|
||||
class ov::pass::TransposeSinkingBinaryElementwiseBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryElementwiseBackward", "0");
|
||||
TransposeSinkingBinaryElementwiseBackward();
|
||||
};
|
@ -0,0 +1,30 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API TransposeSinkingConcatForward;
|
||||
class TRANSFORMATIONS_API TransposeSinkingConcatBackward;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
class ov::pass::TransposeSinkingConcatForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingConcatForward", "0");
|
||||
TransposeSinkingConcatForward();
|
||||
};
|
||||
|
||||
class ov::pass::TransposeSinkingConcatBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingConcatBackward", "0");
|
||||
TransposeSinkingConcatBackward();
|
||||
};
|
@ -0,0 +1,30 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API TransposeSinkingSplitBackward;
|
||||
class TRANSFORMATIONS_API TransposeSinkingSplitForward;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
class ov::pass::TransposeSinkingSplitForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingSplitForward", "0");
|
||||
TransposeSinkingSplitForward();
|
||||
};
|
||||
|
||||
class ov::pass::TransposeSinkingSplitBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingSplitBackward", "0");
|
||||
TransposeSinkingSplitBackward();
|
||||
};
|
@ -0,0 +1,50 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <openvino/pass/pattern/op/or.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <utility>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "openvino/pass/pattern/op/label.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "openvino/util/log.hpp"
|
||||
|
||||
namespace transpose_sinking {
|
||||
|
||||
struct TransposeInputsInfo {
|
||||
std::shared_ptr<ov::opset9::Transpose> transpose;
|
||||
std::shared_ptr<ov::opset9::Constant> transpose_const;
|
||||
size_t input_idx;
|
||||
|
||||
bool isEmpty() const {
|
||||
return !transpose || !transpose_const;
|
||||
}
|
||||
};
|
||||
|
||||
TransposeInputsInfo GetFirstTransposeInput(std::shared_ptr<ov::Node> node);
|
||||
bool IfNodeHasTransposeInputs(const ov::Output<ov::Node>& output);
|
||||
ov::AxisVector ReverseTransposeOrder(const ov::AxisVector& axis_order);
|
||||
void SwapOutputNames(ov::Output<ov::Node> output1, ov::Output<ov::Node> output2);
|
||||
void SwapFriendlyNames(std::shared_ptr<ov::Node> node1, std::shared_ptr<ov::Node> node2);
|
||||
void SwapNames(std::shared_ptr<ov::Node> node1, std::shared_ptr<ov::Node> node2);
|
||||
|
||||
namespace sink_forward {
|
||||
// insert input reversed transposes, remove first input tranpose
|
||||
void UpdateInputTransposes(std::shared_ptr<ov::Node> main_node, TransposeInputsInfo& transpose_input_info);
|
||||
void RemoveZeroInputNode(std::shared_ptr<ov::Node> main_node);
|
||||
ov::NodeVector InsertOutputTransposes(std::shared_ptr<ov::Node> main_node, TransposeInputsInfo& transpose_input_info);
|
||||
} // namespace sink_forward
|
||||
|
||||
namespace sink_backward {
|
||||
ov::NodeVector InsertTransposeBeforeNode(std::shared_ptr<ov::Node> main_node,
|
||||
std::shared_ptr<ov::opset9::Constant> transpose_const);
|
||||
} // namespace sink_backward
|
||||
|
||||
} // namespace transpose_sinking
|
@ -0,0 +1,74 @@
|
||||
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
|
||||
|
||||
#include <openvino/pass/pattern/op/or.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <utility>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "openvino/pass/pattern/op/label.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "openvino/util/log.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov;
|
||||
using namespace ov::opset9;
|
||||
using namespace transpose_sinking;
|
||||
|
||||
ov::pass::TransposeSinkingBinaryElementwiseForward::TransposeSinkingBinaryElementwiseForward() {
|
||||
MATCHER_SCOPE(TransposeSinkingBinaryElementwiseForward);
|
||||
|
||||
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic>(IfNodeHasTransposeInputs);
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
|
||||
auto& main_node_output = pattern_to_output.at(main_node_label);
|
||||
auto main_node = main_node_output.get_node_shared_ptr();
|
||||
|
||||
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
||||
|
||||
sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
ov::pass::TransposeSinkingBinaryElementwiseBackward::TransposeSinkingBinaryElementwiseBackward() {
|
||||
MATCHER_SCOPE(TransposeSinkingBinaryElementwiseBackward);
|
||||
|
||||
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic>(consumers_count(1));
|
||||
|
||||
auto transpose_const_label = wrap_type<Constant>(consumers_count(1));
|
||||
auto transpose_label = wrap_type<Transpose>({main_node_label, transpose_const_label}, consumers_count(1));
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
||||
|
||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) {
|
||||
register_new_node(new_node);
|
||||
}
|
||||
|
||||
// remove transpose after main node
|
||||
transpose->output(0).replace(main_node);
|
||||
|
||||
SwapNames(transpose, main_node);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
@ -0,0 +1,86 @@
|
||||
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
|
||||
|
||||
#include <openvino/opsets/opset9.hpp>
|
||||
#include <openvino/pass/pattern/op/or.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <utility>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "openvino/pass/pattern/op/label.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "openvino/util/log.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov;
|
||||
using namespace ov::opset9;
|
||||
using namespace transpose_sinking;
|
||||
|
||||
ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() {
|
||||
MATCHER_SCOPE(TransposeSinkingConcatForward);
|
||||
|
||||
auto main_node_label = wrap_type<Concat>(IfNodeHasTransposeInputs);
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
|
||||
auto& main_node_output = pattern_to_output.at(main_node_label);
|
||||
auto main_node = main_node_output.get_node_shared_ptr();
|
||||
|
||||
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
||||
|
||||
sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
}
|
||||
|
||||
auto concat_node = as_type_ptr<Concat>(main_node);
|
||||
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||
const int64_t transposed_concat_axis = transpose_axis_order[concat_node->get_axis()];
|
||||
concat_node->set_concatenation_axis(transposed_concat_axis);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
|
||||
MATCHER_SCOPE(TransposeSinkingConcatBackward);
|
||||
|
||||
auto main_node_label = wrap_type<Concat>(consumers_count(1));
|
||||
|
||||
auto transpose_const_label = wrap_type<Constant>(consumers_count(1));
|
||||
auto transpose_label = wrap_type<Transpose>({main_node_label, transpose_const_label}, consumers_count(1));
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
||||
|
||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) {
|
||||
register_new_node(new_node);
|
||||
}
|
||||
|
||||
// remove transpose after main node
|
||||
transpose->output(0).replace(main_node);
|
||||
|
||||
SwapNames(transpose, main_node);
|
||||
|
||||
auto concat_node = as_type_ptr<Concat>(main_node);
|
||||
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
|
||||
const int64_t transposed_concat_axis = reversed_traspose_axis_order[concat_node->get_axis()];
|
||||
concat_node->set_concatenation_axis(transposed_concat_axis);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
@ -0,0 +1,200 @@
|
||||
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
|
||||
|
||||
#include <openvino/opsets/opset9.hpp>
|
||||
#include <openvino/pass/pattern/op/or.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <utility>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "openvino/pass/pattern/op/label.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "openvino/util/log.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov;
|
||||
using namespace ov::opset9;
|
||||
using namespace transpose_sinking;
|
||||
|
||||
namespace {
|
||||
|
||||
using NodePtr = std::shared_ptr<Node>;
|
||||
|
||||
struct OutputTranspose {
|
||||
OutputTranspose() : transpose(nullptr), transpose_const(nullptr) {}
|
||||
Transpose* transpose;
|
||||
Constant* transpose_const;
|
||||
};
|
||||
|
||||
OutputTranspose GetOutputTransposes(NodePtr node) {
|
||||
for (size_t output_idx = 0; output_idx < node->get_output_size(); ++output_idx) {
|
||||
for (auto& input : node->get_output_target_inputs(output_idx)) {
|
||||
auto transpose_node = dynamic_cast<Transpose*>(input.get_node());
|
||||
if (!transpose_node)
|
||||
continue;
|
||||
auto constant_node = dynamic_cast<Constant*>(transpose_node->input_value(1).get_node());
|
||||
if (!constant_node)
|
||||
continue;
|
||||
{
|
||||
OutputTranspose output_transpose;
|
||||
output_transpose.transpose = transpose_node;
|
||||
output_transpose.transpose_const = constant_node;
|
||||
|
||||
return output_transpose;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return OutputTranspose();
|
||||
}
|
||||
|
||||
NodePtr FindSplitInput(Node* node) {
|
||||
for (size_t input_idx = 0; input_idx < node->get_input_size(); ++input_idx) {
|
||||
NodePtr input_node = node->get_input_node_shared_ptr(input_idx);
|
||||
auto split_node = as_type_ptr<Split>(input_node);
|
||||
if (split_node)
|
||||
return split_node;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
std::shared_ptr<Constant> GetTransposeConstant(Input<Node> input) {
|
||||
auto transpose_node = dynamic_cast<Transpose*>(input.get_node());
|
||||
if (!transpose_node)
|
||||
return {};
|
||||
|
||||
auto constant_node = as_type_ptr<Constant>(transpose_node->input_value(1).get_node_shared_ptr());
|
||||
if (!constant_node)
|
||||
return {};
|
||||
|
||||
return constant_node;
|
||||
}
|
||||
|
||||
bool HasInputSplitAndTransposeSiblings(const Output<Node>& output) {
|
||||
NodePtr split_node = FindSplitInput(output.get_node());
|
||||
if (!split_node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
AxisVector first_transpose_axis_order;
|
||||
// get first transpose axis
|
||||
{
|
||||
auto constant_node = GetTransposeConstant(*(split_node->get_output_target_inputs(0).begin()));
|
||||
if (!constant_node)
|
||||
return false;
|
||||
first_transpose_axis_order = constant_node->get_axis_vector_val();
|
||||
}
|
||||
|
||||
for (size_t output_idx = 1; output_idx < split_node->get_output_size(); ++output_idx) {
|
||||
for (auto& input : split_node->get_output_target_inputs(output_idx)) {
|
||||
auto constant_node = GetTransposeConstant(input);
|
||||
if (!constant_node)
|
||||
return false;
|
||||
|
||||
AxisVector transpose_axis_order = constant_node->get_axis_vector_val();
|
||||
if (transpose_axis_order.size() != first_transpose_axis_order.size())
|
||||
return false;
|
||||
if (!std::equal(transpose_axis_order.begin(),
|
||||
transpose_axis_order.end(),
|
||||
first_transpose_axis_order.begin()))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
|
||||
MATCHER_SCOPE(TransposeSinkingSplitBackward);
|
||||
|
||||
auto transpose_const_label = wrap_type<Constant>(consumers_count(1));
|
||||
auto transpose_label =
|
||||
wrap_type<Transpose>({any_input(), transpose_const_label}, HasInputSplitAndTransposeSiblings);
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto transpose_label_node = pattern_to_output.at(transpose_label).get_node();
|
||||
|
||||
NodePtr split = FindSplitInput(transpose_label_node);
|
||||
auto split_axis_constant = as_type_ptr<Constant>(split->input_value(1).get_node_shared_ptr());
|
||||
OutputTranspose output_transpose = GetOutputTransposes(split);
|
||||
|
||||
const auto transpose_axis_order = output_transpose.transpose_const->get_axis_vector_val();
|
||||
const auto transpose_element_type = output_transpose.transpose_const->get_element_type();
|
||||
|
||||
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
|
||||
|
||||
const size_t split_axis = split_axis_constant->get_axis_vector_val()[0];
|
||||
const size_t reversed_transposed_split_axis = reversed_traspose_axis_order[split_axis];
|
||||
|
||||
// insert transpose before split
|
||||
{
|
||||
auto input_node = split->input_value(0);
|
||||
auto new_transpose_const = std::make_shared<Constant>(transpose_element_type,
|
||||
Shape{transpose_axis_order.size()},
|
||||
transpose_axis_order);
|
||||
auto new_transpose = std::make_shared<Transpose>(input_node, new_transpose_const);
|
||||
|
||||
split->input(0).replace_source_output(new_transpose->output(0));
|
||||
|
||||
copy_runtime_info(input_node.get_node_shared_ptr(), {new_transpose, new_transpose_const});
|
||||
|
||||
register_new_node(new_transpose);
|
||||
}
|
||||
|
||||
// update split axis
|
||||
auto new_split_axis_const = std::make_shared<Constant>(split_axis_constant->get_element_type(),
|
||||
Shape{},
|
||||
reversed_transposed_split_axis);
|
||||
split->input(1).replace_source_output(new_split_axis_const);
|
||||
|
||||
// remove split output transposes
|
||||
for (size_t output_idx = 0; output_idx < split->get_output_size(); ++output_idx) {
|
||||
for (auto& input : split->get_output_target_inputs(output_idx)) {
|
||||
input.get_node()->output(0).replace(split->output(output_idx));
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
|
||||
MATCHER_SCOPE(TransposeSinkingSplitForward);
|
||||
|
||||
auto main_node_label = wrap_type<Split>(IfNodeHasTransposeInputs);
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
|
||||
auto& main_node_output = pattern_to_output.at(main_node_label);
|
||||
auto main_node = main_node_output.get_node_shared_ptr();
|
||||
|
||||
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
||||
|
||||
sink_forward::RemoveZeroInputNode(main_node);
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
}
|
||||
|
||||
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||
auto split_node = as_type_ptr<Split>(main_node);
|
||||
auto split_axis_constant = as_type_ptr<Constant>(split_node->input_value(1).get_node_shared_ptr());
|
||||
const size_t split_axis = split_axis_constant->get_axis_vector_val()[0];
|
||||
const size_t transposed_split_axis = transpose_axis_order[split_axis];
|
||||
auto new_split_axis_const =
|
||||
std::make_shared<Constant>(split_axis_constant->get_element_type(), Shape{}, transposed_split_axis);
|
||||
split_node->input(1).replace_source_output(new_split_axis_const);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
@ -0,0 +1,172 @@
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
|
||||
#include <openvino/pass/pattern/op/or.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <utility>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "openvino/pass/pattern/op/label.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "openvino/util/log.hpp"
|
||||
|
||||
namespace transpose_sinking {
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset9;
|
||||
|
||||
using NodePtr = std::shared_ptr<Node>;
|
||||
|
||||
TransposeInputsInfo GetFirstTransposeInput(NodePtr node) {
|
||||
for (size_t input_idx = 0; input_idx < node->get_input_size(); ++input_idx) {
|
||||
NodePtr input_node = node->get_input_node_shared_ptr(input_idx);
|
||||
auto transpose_node = as_type_ptr<Transpose>(input_node);
|
||||
if (!transpose_node)
|
||||
continue;
|
||||
auto constant_node = as_type_ptr<Constant>(transpose_node->input_value(1).get_node_shared_ptr());
|
||||
if (!constant_node)
|
||||
continue;
|
||||
{
|
||||
TransposeInputsInfo input_info;
|
||||
input_info.transpose = transpose_node;
|
||||
input_info.transpose_const = constant_node;
|
||||
input_info.input_idx = input_idx;
|
||||
return input_info;
|
||||
}
|
||||
}
|
||||
|
||||
return TransposeInputsInfo();
|
||||
}
|
||||
|
||||
bool IfNodeHasTransposeInputs(const Output<Node>& output) {
|
||||
TransposeInputsInfo inputs_info = GetFirstTransposeInput(output.get_node_shared_ptr());
|
||||
return !inputs_info.isEmpty();
|
||||
}
|
||||
|
||||
AxisVector ReverseTransposeOrder(const AxisVector& axis_order) {
|
||||
AxisVector out(axis_order.size());
|
||||
for (size_t i = 0; i < axis_order.size(); i++) {
|
||||
out.at(axis_order[i]) = i;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
void SwapOutputNames(Output<Node> output1, Output<Node> output2) {
|
||||
const auto node2_output_names = output2.get_names();
|
||||
output2.set_names(output1.get_names());
|
||||
output1.set_names(node2_output_names);
|
||||
}
|
||||
|
||||
void SwapFriendlyNames(NodePtr node1, NodePtr node2) {
|
||||
const std::string node2_name = node2->get_friendly_name();
|
||||
node2->set_friendly_name(node1->get_friendly_name());
|
||||
node1->set_friendly_name(node2_name);
|
||||
}
|
||||
|
||||
void SwapNames(NodePtr node1, NodePtr node2) {
|
||||
SwapFriendlyNames(node1, node2);
|
||||
SwapOutputNames(node1->output(0), node2->output(0));
|
||||
}
|
||||
|
||||
namespace sink_forward {
|
||||
|
||||
// insert input reversed transposes, remove first input tranpose
|
||||
void UpdateInputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_input_info) {
|
||||
if (transpose_input_info.isEmpty())
|
||||
return;
|
||||
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
|
||||
const size_t tranpose_input_index = transpose_input_info.input_idx;
|
||||
const auto transpose_element_type = transpose_input_info.transpose_const->get_element_type();
|
||||
|
||||
for (size_t i = 0; i < main_node->get_input_size(); ++i) {
|
||||
auto input_node = main_node->input_value(i);
|
||||
if (i == tranpose_input_index) {
|
||||
auto transpose_parent = input_node.get_node()->input_value(0);
|
||||
main_node->input(i).replace_source_output(transpose_parent);
|
||||
} else {
|
||||
auto new_transpose_const = std::make_shared<Constant>(transpose_element_type,
|
||||
Shape{reversed_traspose_axis_order.size()},
|
||||
reversed_traspose_axis_order);
|
||||
auto new_transpose = std::make_shared<Transpose>(input_node, new_transpose_const);
|
||||
|
||||
main_node->input(i).replace_source_output(new_transpose->output(0));
|
||||
|
||||
copy_runtime_info(input_node.get_node_shared_ptr(), {new_transpose, new_transpose_const});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void RemoveZeroInputNode(NodePtr main_node) {
|
||||
auto input_node = main_node->input_value(0);
|
||||
if (input_node.get_node()->get_input_size() < 1)
|
||||
return;
|
||||
auto parent_node = input_node.get_node()->input_value(0);
|
||||
main_node->input(0).replace_source_output(parent_node);
|
||||
}
|
||||
|
||||
NodeVector InsertOutputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_input_info) {
|
||||
if (transpose_input_info.isEmpty())
|
||||
return {};
|
||||
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
|
||||
const auto transpose_element_type = transpose_input_info.transpose_const->get_element_type();
|
||||
|
||||
NodeVector new_nodes;
|
||||
|
||||
for (size_t i = 0; i < main_node->get_output_size(); ++i) {
|
||||
auto main_node_consumers = main_node->output(i).get_target_inputs();
|
||||
|
||||
auto new_transpose_const = std::make_shared<Constant>(transpose_element_type,
|
||||
Shape{transpose_axis_order.size()},
|
||||
transpose_axis_order);
|
||||
auto new_transpose = std::make_shared<Transpose>(main_node->output(i), new_transpose_const);
|
||||
|
||||
for (auto& consumer : main_node_consumers) {
|
||||
consumer.replace_source_output(new_transpose);
|
||||
}
|
||||
|
||||
copy_runtime_info(main_node, {new_transpose, new_transpose_const});
|
||||
SwapOutputNames(main_node->output(i), new_transpose->output(0));
|
||||
|
||||
if (main_node->get_output_size() > 1)
|
||||
new_transpose->set_friendly_name(main_node->get_friendly_name() + "." + std::to_string(i));
|
||||
else
|
||||
SwapFriendlyNames(new_transpose, main_node);
|
||||
|
||||
new_nodes.push_back(new_transpose);
|
||||
}
|
||||
|
||||
return new_nodes;
|
||||
}
|
||||
|
||||
} // namespace sink_forward
|
||||
|
||||
namespace sink_backward {
|
||||
NodeVector InsertTransposeBeforeNode(NodePtr main_node, std::shared_ptr<Constant> transpose_const) {
|
||||
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||
const auto transpose_element_type = transpose_const->get_element_type();
|
||||
|
||||
NodeVector new_nodes;
|
||||
|
||||
for (size_t i = 0; i < main_node->get_input_size(); ++i) {
|
||||
auto input_node = main_node->input_value(i);
|
||||
auto new_transpose_const = std::make_shared<Constant>(transpose_element_type,
|
||||
Shape{transpose_axis_order.size()},
|
||||
transpose_axis_order);
|
||||
auto new_transpose = std::make_shared<Transpose>(input_node, new_transpose_const);
|
||||
|
||||
main_node->input(i).replace_source_output(new_transpose->output(0));
|
||||
|
||||
copy_runtime_info(input_node.get_node_shared_ptr(), {new_transpose, new_transpose_const});
|
||||
|
||||
new_nodes.push_back(new_transpose);
|
||||
}
|
||||
|
||||
return new_nodes;
|
||||
}
|
||||
} // namespace sink_backward
|
||||
|
||||
} // namespace transpose_sinking
|
@ -0,0 +1,358 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <transformations/common_optimizations/transpose_sinking_binary.hpp>
|
||||
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <openvino/frontend/manager.hpp>
|
||||
#include <openvino/opsets/opset9.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using NodePtr = std::shared_ptr<ov::Node>;
|
||||
using ModelPtr = std::shared_ptr<ov::Model>;
|
||||
using Output = ov::Output<ov::Node>;
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
class IBinaryFactory {
|
||||
public:
|
||||
IBinaryFactory() = default;
|
||||
virtual ~IBinaryFactory() = default;
|
||||
virtual NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const = 0;
|
||||
};
|
||||
|
||||
using BinaryFactoryPtr = std::shared_ptr<IBinaryFactory>;
|
||||
|
||||
template <typename BinaryT>
|
||||
class BinaryFactory : public IBinaryFactory {
|
||||
public:
|
||||
BinaryFactory() = default;
|
||||
NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const override {
|
||||
return std::make_shared<BinaryT>(parent_left_node, parent_right_node);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename BinaryT>
|
||||
BinaryFactoryPtr CreateBinaryFactory() {
|
||||
return std::make_shared<BinaryFactory<BinaryT>>();
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
class IPassFactory {
|
||||
public:
|
||||
IPassFactory() = default;
|
||||
virtual ~IPassFactory() = default;
|
||||
virtual void registerPass(ov::pass::Manager& pass_manager) const = 0;
|
||||
};
|
||||
|
||||
using PassFactoryPtr = std::shared_ptr<IPassFactory>;
|
||||
|
||||
template <typename PassT>
|
||||
class PassFactory : public IPassFactory {
|
||||
public:
|
||||
void registerPass(ov::pass::Manager& pass_manager) const override {
|
||||
pass_manager.register_pass<PassT>();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename PassT>
|
||||
PassFactoryPtr CreatePassFactory() {
|
||||
return std::make_shared<PassFactory<PassT>>();
|
||||
}
|
||||
|
||||
std::vector<BinaryFactoryPtr> binary_factories = {
|
||||
CreateBinaryFactory<ov::opset9::Add>(),
|
||||
CreateBinaryFactory<ov::opset9::Divide>(),
|
||||
CreateBinaryFactory<ov::opset9::Maximum>(),
|
||||
CreateBinaryFactory<ov::opset9::Minimum>(),
|
||||
CreateBinaryFactory<ov::opset9::Mod>(),
|
||||
CreateBinaryFactory<ov::opset9::Multiply>(),
|
||||
CreateBinaryFactory<ov::opset9::Power>(),
|
||||
CreateBinaryFactory<ov::opset9::SquaredDifference>(),
|
||||
CreateBinaryFactory<ov::opset9::Subtract>()
|
||||
};
|
||||
|
||||
std::vector<size_t> binary_operations_numbers = {1, 10};
|
||||
|
||||
std::vector<size_t> binary_transpose_input_indexes = {0, 1};
|
||||
|
||||
} // namespace
|
||||
|
||||
|
||||
namespace binary {
|
||||
namespace single_consumer {
|
||||
namespace forward {
|
||||
namespace one_input_transpose {
|
||||
|
||||
std::shared_ptr<ov::Model> CreateFunction(BinaryFactoryPtr binary_factory,
|
||||
size_t num_binary_ops,
|
||||
ov::element::Type input_type,
|
||||
size_t binary_transpose_input_idx) {
|
||||
const ov::Shape input_shape{1, 96, 55, 55};
|
||||
const ov::Shape const_shape{1, 55, 55, 96};
|
||||
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
||||
|
||||
NodePtr in_op = transpose0;
|
||||
for (size_t i = 0; i < num_binary_ops; ++i) {
|
||||
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, const_shape, ov::Shape{1});
|
||||
if (!binary_transpose_input_idx)
|
||||
in_op = binary_factory->create(in_op, in_constant);
|
||||
else
|
||||
in_op = binary_factory->create(in_constant, in_op);
|
||||
}
|
||||
|
||||
return std::make_shared<ov::Model>(ov::OutputVector{in_op}, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> CreateReferenceFunction(BinaryFactoryPtr binary_factory,
|
||||
size_t num_binary_ops,
|
||||
ov::element::Type input_type,
|
||||
size_t binary_transpose_input_idx) {
|
||||
const ov::Shape input_shape{1, 96, 55, 55};
|
||||
const ov::Shape const_shape{1, 55, 55, 96};
|
||||
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
NodePtr in_op = X;
|
||||
for (size_t i = 0; i < num_binary_ops; ++i) {
|
||||
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, const_shape, ov::Shape{1});
|
||||
|
||||
auto transpose_reversed_const =
|
||||
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto transpose_reversed = std::make_shared<ov::opset9::Transpose>(in_constant, transpose_reversed_const);
|
||||
|
||||
if (!binary_transpose_input_idx)
|
||||
in_op = binary_factory->create(in_op, transpose_reversed);
|
||||
else
|
||||
in_op = binary_factory->create(transpose_reversed, in_op);
|
||||
}
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||
|
||||
return std::make_shared<ov::Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
} // namespace one_input_transpose
|
||||
|
||||
namespace double_transpose {
|
||||
std::shared_ptr<ov::Model> CreateFunction(BinaryFactoryPtr binary_factory,
|
||||
size_t num_binary_ops,
|
||||
ov::element::Type input_type) {
|
||||
const ov::Shape input_shape{1, 96, 55, 55};
|
||||
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
||||
|
||||
NodePtr in_op = transpose0;
|
||||
for (size_t i = 0; i < num_binary_ops; ++i) {
|
||||
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
|
||||
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order1);
|
||||
|
||||
in_op = binary_factory->create(in_op, transpose1);
|
||||
}
|
||||
|
||||
return std::make_shared<ov::Model>(ov::OutputVector{in_op}, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> CreateReferenceFunction(BinaryFactoryPtr binary_factory,
|
||||
size_t num_binary_ops,
|
||||
ov::element::Type input_type) {
|
||||
const ov::Shape input_shape{1, 96, 55, 55};
|
||||
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
NodePtr in_op = X;
|
||||
for (size_t i = 0; i < num_binary_ops; ++i) {
|
||||
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
|
||||
|
||||
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order1);
|
||||
|
||||
auto transpose_reversed_const =
|
||||
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto transpose_reversed = std::make_shared<ov::opset9::Transpose>(transpose1, transpose_reversed_const);
|
||||
|
||||
in_op = binary_factory->create(in_op, transpose_reversed);
|
||||
}
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||
|
||||
return std::make_shared<ov::Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
} // namespace double_transpose
|
||||
} // namespace forward
|
||||
|
||||
namespace backward {
|
||||
namespace one_input_transpose {
|
||||
std::shared_ptr<ov::Model> CreateFunction(BinaryFactoryPtr binary_factory,
|
||||
size_t num_binary_ops,
|
||||
ov::element::Type input_type,
|
||||
size_t binary_transpose_input_idx) {
|
||||
const ov::Shape input_shape{1, 96, 55, 55};
|
||||
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
NodePtr in_op = X;
|
||||
for (size_t i = 0; i < num_binary_ops; ++i) {
|
||||
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
|
||||
if (!binary_transpose_input_idx)
|
||||
in_op = binary_factory->create(in_op, in_constant);
|
||||
else
|
||||
in_op = binary_factory->create(in_constant, in_op);
|
||||
}
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||
|
||||
return std::make_shared<ov::Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> CreateReferenceFunction(BinaryFactoryPtr binary_factory,
|
||||
size_t num_binary_ops,
|
||||
ov::element::Type input_type,
|
||||
size_t binary_transpose_input_idx) {
|
||||
const ov::Shape input_shape{1, 96, 55, 55};
|
||||
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
||||
|
||||
NodePtr in_op = transpose0;
|
||||
for (size_t i = 0; i < num_binary_ops; ++i) {
|
||||
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
|
||||
|
||||
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order);
|
||||
|
||||
if (!binary_transpose_input_idx)
|
||||
in_op = binary_factory->create(in_op, transpose);
|
||||
else
|
||||
in_op = binary_factory->create(transpose, in_op);
|
||||
}
|
||||
|
||||
return std::make_shared<ov::Model>(ov::OutputVector{in_op}, ov::ParameterVector{X});
|
||||
}
|
||||
} // namespace one_input_transpose
|
||||
} // namespace backward
|
||||
} // namespace single_consumer
|
||||
} // namespace binary
|
||||
|
||||
using CreateGraphBinaryF = std::function<std::shared_ptr<ov::Model>(BinaryFactoryPtr unary_factory,
|
||||
size_t num_binary_ops,
|
||||
ov::element::Type input_type,
|
||||
size_t binary_transpose_input_idx)>;
|
||||
|
||||
using TestBinaryParams = std::tuple<BinaryFactoryPtr,
|
||||
PassFactoryPtr,
|
||||
size_t, /* num_binary_ops */
|
||||
CreateGraphBinaryF, /* model_factory */
|
||||
CreateGraphBinaryF, /* reference_model_factory */
|
||||
ov::element::Type, /* input type */
|
||||
size_t>; /* binary_transpose_input_idx */
|
||||
|
||||
class TransposeSinkingBinaryTestFixture : public ::testing::WithParamInterface<TestBinaryParams>,
|
||||
public TransformationTestsF {};
|
||||
|
||||
|
||||
TEST_P(TransposeSinkingBinaryTestFixture, CompareFunctions) {
|
||||
BinaryFactoryPtr unary_factory;
|
||||
PassFactoryPtr pass_factory;
|
||||
size_t num_binary_ops;
|
||||
CreateGraphBinaryF model_factory;
|
||||
CreateGraphBinaryF reference_model_factory;
|
||||
ov::element::Type input_type;
|
||||
size_t binary_transpose_input_idx;
|
||||
std::tie(unary_factory,
|
||||
pass_factory,
|
||||
num_binary_ops,
|
||||
model_factory,
|
||||
reference_model_factory,
|
||||
input_type,
|
||||
binary_transpose_input_idx) = this->GetParam();
|
||||
|
||||
model = model_factory(unary_factory, num_binary_ops, input_type, binary_transpose_input_idx);
|
||||
model_ref = reference_model_factory(unary_factory, num_binary_ops, input_type, binary_transpose_input_idx);
|
||||
pass_factory->registerPass(manager);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryForwardTestSuite, TransposeSinkingBinaryTestFixture,
|
||||
::testing::Combine(::testing::ValuesIn(binary_factories),
|
||||
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingBinaryElementwiseForward>()),
|
||||
::testing::ValuesIn(binary_operations_numbers),
|
||||
::testing::Values(binary::single_consumer::forward::one_input_transpose::CreateFunction),
|
||||
::testing::Values(binary::single_consumer::forward::one_input_transpose::CreateReferenceFunction),
|
||||
::testing::Values(ov::element::f32),
|
||||
::testing::ValuesIn(binary_transpose_input_indexes)));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingBinaryBackwardTestSuite,
|
||||
TransposeSinkingBinaryTestFixture,
|
||||
::testing::Combine(::testing::ValuesIn(binary_factories),
|
||||
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingBinaryElementwiseBackward>()),
|
||||
::testing::ValuesIn(binary_operations_numbers),
|
||||
::testing::Values(binary::single_consumer::backward::one_input_transpose::CreateFunction),
|
||||
::testing::Values(binary::single_consumer::backward::one_input_transpose::CreateReferenceFunction),
|
||||
::testing::Values(ov::element::f32),
|
||||
::testing::ValuesIn(binary_transpose_input_indexes)));
|
||||
|
||||
// --------------------------------------------------------------------------------------
|
||||
|
||||
using CreateGraphBinaryTwoTransposeInputsF = std::function<
|
||||
std::shared_ptr<ov::Model>(BinaryFactoryPtr unary_factory, size_t num_binary_ops, ov::element::Type input_type)>;
|
||||
|
||||
using TestBinaryTwoTransposeInputsParams = std::tuple<BinaryFactoryPtr,
|
||||
PassFactoryPtr,
|
||||
size_t, /* num_binary_ops */
|
||||
CreateGraphBinaryTwoTransposeInputsF, /* model_factory */
|
||||
CreateGraphBinaryTwoTransposeInputsF, /* reference_model_factory */
|
||||
ov::element::Type>; /* input type */
|
||||
|
||||
class TransposeSinkingBinaryTwoTransposeInputsTestFixture
|
||||
: public ::testing::WithParamInterface<TestBinaryTwoTransposeInputsParams>,
|
||||
public TransformationTestsF {};
|
||||
|
||||
TEST_P(TransposeSinkingBinaryTwoTransposeInputsTestFixture, CompareFunctions) {
|
||||
BinaryFactoryPtr unary_factory;
|
||||
PassFactoryPtr pass_factory;
|
||||
size_t num_binary_ops;
|
||||
CreateGraphBinaryTwoTransposeInputsF model_factory;
|
||||
CreateGraphBinaryTwoTransposeInputsF reference_model_factory;
|
||||
ov::element::Type input_type;
|
||||
|
||||
std::tie(unary_factory, pass_factory, num_binary_ops, model_factory, reference_model_factory, input_type) =
|
||||
this->GetParam();
|
||||
|
||||
model = model_factory(unary_factory, num_binary_ops, input_type);
|
||||
model_ref = reference_model_factory(unary_factory, num_binary_ops, input_type);
|
||||
pass_factory->registerPass(manager);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingBinaryTwoTransposeInputsForwardTestSuite,
|
||||
TransposeSinkingBinaryTwoTransposeInputsTestFixture,
|
||||
::testing::Combine(::testing::ValuesIn(binary_factories),
|
||||
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingBinaryElementwiseForward>()),
|
||||
::testing::ValuesIn(binary_operations_numbers),
|
||||
::testing::Values(binary::single_consumer::forward::double_transpose::CreateFunction),
|
||||
::testing::Values(binary::single_consumer::forward::double_transpose::CreateReferenceFunction),
|
||||
::testing::Values(ov::element::f32)));
|
@ -0,0 +1,400 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <transformations/common_optimizations/transpose_sinking_concat.hpp>
|
||||
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <openvino/frontend/manager.hpp>
|
||||
#include <openvino/opsets/opset9.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using NodePtr = std::shared_ptr<ov::Node>;
|
||||
using ModelPtr = std::shared_ptr<ov::Model>;
|
||||
using Output = ov::Output<ov::Node>;
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
class IBinaryFactory {
|
||||
public:
|
||||
IBinaryFactory() = default;
|
||||
virtual ~IBinaryFactory() = default;
|
||||
virtual NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const = 0;
|
||||
};
|
||||
|
||||
using BinaryFactoryPtr = std::shared_ptr<IBinaryFactory>;
|
||||
|
||||
template <typename BinaryT>
|
||||
class BinaryFactory : public IBinaryFactory {
|
||||
public:
|
||||
BinaryFactory() = default;
|
||||
NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const override {
|
||||
return std::make_shared<BinaryT>(parent_left_node, parent_right_node);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename BinaryT>
|
||||
BinaryFactoryPtr CreateBinaryFactory() {
|
||||
return std::make_shared<BinaryFactory<BinaryT>>();
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
class IPassFactory {
|
||||
public:
|
||||
IPassFactory() = default;
|
||||
virtual ~IPassFactory() = default;
|
||||
virtual void registerPass(ov::pass::Manager& pass_manager) const = 0;
|
||||
};
|
||||
|
||||
using PassFactoryPtr = std::shared_ptr<IPassFactory>;
|
||||
|
||||
template <typename PassT>
|
||||
class PassFactory : public IPassFactory {
|
||||
public:
|
||||
void registerPass(ov::pass::Manager& pass_manager) const override {
|
||||
pass_manager.register_pass<PassT>();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename PassT>
|
||||
PassFactoryPtr CreatePassFactory() {
|
||||
return std::make_shared<PassFactory<PassT>>();
|
||||
}
|
||||
|
||||
std::vector<BinaryFactoryPtr> binary_factories = {
|
||||
CreateBinaryFactory<ov::opset9::Add>(),
|
||||
CreateBinaryFactory<ov::opset9::Divide>(),
|
||||
CreateBinaryFactory<ov::opset9::Maximum>(),
|
||||
CreateBinaryFactory<ov::opset9::Minimum>(),
|
||||
CreateBinaryFactory<ov::opset9::Mod>(),
|
||||
CreateBinaryFactory<ov::opset9::Multiply>(),
|
||||
CreateBinaryFactory<ov::opset9::Power>(),
|
||||
CreateBinaryFactory<ov::opset9::SquaredDifference>(),
|
||||
CreateBinaryFactory<ov::opset9::Subtract>()
|
||||
};
|
||||
|
||||
std::vector<size_t> binary_operations_numbers = {1, 10};
|
||||
|
||||
std::vector<size_t> binary_transpose_input_indexes = {0, 1};
|
||||
|
||||
} // namespace
|
||||
|
||||
|
||||
using CreateGraphConcatF = std::function< std::shared_ptr<ov::Model> (size_t num_concat_ops,
|
||||
ov::element::Type input_type,
|
||||
size_t concat_transpose_input_idx,
|
||||
size_t num_concat_inputs) >;
|
||||
|
||||
using TestConcatParams = std::tuple<PassFactoryPtr,
|
||||
size_t, /* num_concat_ops */
|
||||
CreateGraphConcatF, /* model_factory */
|
||||
CreateGraphConcatF, /* reference_model_factory */
|
||||
ov::element::Type, /* input type */
|
||||
size_t, /* concat_transpose_input_idx */
|
||||
size_t>; /* num_concat_inputs */
|
||||
|
||||
class TransposeSinkingConcatTestFixture: public ::testing::WithParamInterface<TestConcatParams>,
|
||||
public TransformationTestsF {};
|
||||
|
||||
namespace {
|
||||
|
||||
std::vector<size_t> concat_operations_numbers = {1, 10};
|
||||
|
||||
std::vector<size_t> concat_transpose_input_indexes = {0, 2};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace single_consumer {
|
||||
namespace forward {
|
||||
namespace one_input_transpose {
|
||||
|
||||
std::shared_ptr<ov::Model> CreateFunction(size_t num_concat_ops,
|
||||
ov::element::Type input_type,
|
||||
size_t concat_transpose_input_idx,
|
||||
size_t num_concat_inputs) {
|
||||
const ov::Shape input_shape{1, 96, 55, 55};
|
||||
const ov::Shape const_shape{1, 55, 55, 96};
|
||||
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
||||
|
||||
NodePtr in_op = transpose0;
|
||||
for (size_t i = 0; i < num_concat_ops; ++i) {
|
||||
ov::OutputVector concat_inputs;
|
||||
for (size_t j = 0; j < num_concat_inputs; ++j) {
|
||||
if (j == concat_transpose_input_idx)
|
||||
concat_inputs.push_back(in_op);
|
||||
else
|
||||
concat_inputs.push_back(std::make_shared<ov::opset9::Constant>(input_type, const_shape, ov::Shape{1}));
|
||||
}
|
||||
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 1);
|
||||
}
|
||||
|
||||
return std::make_shared<ov::Model>(ov::OutputVector{in_op}, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> CreateReferenceFunction(size_t num_concat_ops,
|
||||
ov::element::Type input_type,
|
||||
size_t concat_transpose_input_idx,
|
||||
size_t num_concat_inputs) {
|
||||
const ov::Shape input_shape{1, 96, 55, 55};
|
||||
const ov::Shape const_shape{1, 55, 55, 96};
|
||||
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
NodePtr in_op = X;
|
||||
for (size_t i = 0; i < num_concat_ops; ++i) {
|
||||
ov::OutputVector concat_inputs;
|
||||
for (size_t j = 0; j < num_concat_inputs; ++j) {
|
||||
if (j == concat_transpose_input_idx) {
|
||||
concat_inputs.push_back(in_op);
|
||||
} else {
|
||||
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, const_shape, ov::Shape{1});
|
||||
|
||||
auto transpose_reversed_const =
|
||||
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto transpose_reversed =
|
||||
std::make_shared<ov::opset9::Transpose>(in_constant, transpose_reversed_const);
|
||||
|
||||
concat_inputs.push_back(transpose_reversed);
|
||||
}
|
||||
}
|
||||
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 2);
|
||||
}
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||
|
||||
return std::make_shared<ov::Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
} // namespace one_input_transpose
|
||||
|
||||
namespace double_transpose {
|
||||
|
||||
std::shared_ptr<ov::Model> CreateFunction(size_t num_concat_ops,
|
||||
ov::element::Type input_type,
|
||||
size_t num_concat_inputs) {
|
||||
const ov::Shape input_shape{1, 96, 55, 55};
|
||||
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
||||
|
||||
NodePtr in_op = transpose0;
|
||||
for (size_t i = 0; i < num_concat_ops; ++i) {
|
||||
ov::OutputVector concat_inputs;
|
||||
concat_inputs.push_back(in_op);
|
||||
for (size_t j = 1; j < num_concat_inputs; ++j) {
|
||||
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
|
||||
auto ng_order1 =
|
||||
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order1);
|
||||
concat_inputs.push_back(transpose1);
|
||||
}
|
||||
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 1);
|
||||
}
|
||||
|
||||
return std::make_shared<ov::Model>(ov::OutputVector{in_op}, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> CreateReferenceFunction(size_t num_concat_ops,
|
||||
ov::element::Type input_type,
|
||||
size_t num_concat_inputs) {
|
||||
const ov::Shape input_shape{1, 96, 55, 55};
|
||||
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
NodePtr in_op = X;
|
||||
for (size_t i = 0; i < num_concat_ops; ++i) {
|
||||
ov::OutputVector concat_inputs;
|
||||
|
||||
concat_inputs.push_back(in_op);
|
||||
|
||||
for (size_t j = 1; j < num_concat_inputs; ++j) {
|
||||
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
|
||||
|
||||
auto ng_order1 =
|
||||
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order1);
|
||||
|
||||
auto transpose_reversed_const =
|
||||
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto transpose_reversed = std::make_shared<ov::opset9::Transpose>(transpose1, transpose_reversed_const);
|
||||
|
||||
concat_inputs.push_back(transpose_reversed);
|
||||
}
|
||||
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 2);
|
||||
}
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||
|
||||
return std::make_shared<ov::Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
} // namespace double_transpose
|
||||
|
||||
} // namespace forward
|
||||
|
||||
namespace backward {
|
||||
|
||||
std::shared_ptr<ov::Model> CreateFunction(size_t num_concat_ops,
|
||||
ov::element::Type input_type,
|
||||
size_t concat_transpose_input_idx,
|
||||
size_t num_concat_inputs) {
|
||||
const ov::Shape input_shape{1, 96, 55, 55};
|
||||
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
NodePtr in_op = X;
|
||||
for (size_t i = 0; i < num_concat_ops; ++i) {
|
||||
ov::OutputVector concat_inputs;
|
||||
for (size_t j = 0; j < num_concat_inputs; ++j) {
|
||||
if (j == concat_transpose_input_idx)
|
||||
concat_inputs.push_back(in_op);
|
||||
else
|
||||
concat_inputs.push_back(std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1}));
|
||||
}
|
||||
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 1);
|
||||
}
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||
|
||||
return std::make_shared<ov::Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> CreateReferenceFunction(size_t num_concat_ops,
|
||||
ov::element::Type input_type,
|
||||
size_t concat_transpose_input_idx,
|
||||
size_t num_concat_inputs) {
|
||||
const ov::Shape input_shape{1, 96, 55, 55};
|
||||
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
||||
|
||||
NodePtr in_op = transpose0;
|
||||
for (size_t i = 0; i < num_concat_ops; ++i) {
|
||||
ov::OutputVector concat_inputs;
|
||||
for (size_t j = 0; j < num_concat_inputs; ++j) {
|
||||
if (j == concat_transpose_input_idx) {
|
||||
concat_inputs.push_back(in_op);
|
||||
} else {
|
||||
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
|
||||
|
||||
auto transpose_reversed_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose_reversed = std::make_shared<ov::opset9::Transpose>(in_constant, transpose_reversed_const);
|
||||
|
||||
concat_inputs.push_back(transpose_reversed);
|
||||
}
|
||||
}
|
||||
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 3);
|
||||
}
|
||||
|
||||
return std::make_shared<ov::Model>(ov::OutputVector{in_op}, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
} // namespace backward
|
||||
} // namespace single_consumer
|
||||
|
||||
|
||||
|
||||
TEST_P(TransposeSinkingConcatTestFixture, CompareFunctions) {
|
||||
PassFactoryPtr pass_factory;
|
||||
size_t num_concat_ops;
|
||||
CreateGraphConcatF model_factory;
|
||||
CreateGraphConcatF reference_model_factory;
|
||||
ov::element::Type input_type;
|
||||
size_t concat_transpose_input_idx;
|
||||
size_t num_concat_inputs;
|
||||
std::tie(pass_factory,
|
||||
num_concat_ops,
|
||||
model_factory,
|
||||
reference_model_factory,
|
||||
input_type,
|
||||
concat_transpose_input_idx,
|
||||
num_concat_inputs) = this->GetParam();
|
||||
|
||||
model = model_factory(num_concat_ops, input_type, concat_transpose_input_idx, num_concat_inputs);
|
||||
model_ref = reference_model_factory(num_concat_ops, input_type, concat_transpose_input_idx, num_concat_inputs);
|
||||
pass_factory->registerPass(manager);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatForwardTestSuite, TransposeSinkingConcatTestFixture,
|
||||
::testing::Combine(::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingConcatForward>()),
|
||||
::testing::ValuesIn(concat_operations_numbers),
|
||||
::testing::Values(single_consumer::forward::one_input_transpose::CreateFunction),
|
||||
::testing::Values(single_consumer::forward::one_input_transpose::CreateReferenceFunction),
|
||||
::testing::Values(ov::element::f32),
|
||||
::testing::ValuesIn(concat_transpose_input_indexes),
|
||||
::testing::Values(5)));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatBackwardTestSuite, TransposeSinkingConcatTestFixture,
|
||||
::testing::Combine(::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingConcatBackward>()),
|
||||
::testing::ValuesIn(concat_operations_numbers),
|
||||
::testing::Values(single_consumer::backward::CreateFunction),
|
||||
::testing::Values(single_consumer::backward::CreateReferenceFunction),
|
||||
::testing::Values(ov::element::f32),
|
||||
::testing::ValuesIn(concat_transpose_input_indexes),
|
||||
::testing::Values(5)));
|
||||
|
||||
|
||||
// --------------------------------------------------------------------------------------
|
||||
|
||||
using CreateGraphConcatAllTransposesInputF = std::function<std::shared_ptr<ov::Model>(size_t num_concat_ops,
|
||||
ov::element::Type input_type,
|
||||
size_t num_concat_inputs)>;
|
||||
|
||||
using TestConcatAllTransposesInputParams = std::tuple<PassFactoryPtr,
|
||||
size_t, /* num_concat_ops */
|
||||
CreateGraphConcatAllTransposesInputF, /* model_factory */
|
||||
CreateGraphConcatAllTransposesInputF, /* reference_model_factory */
|
||||
ov::element::Type, /* input type */
|
||||
size_t>; /* num_concat_inputs */
|
||||
|
||||
class TransposeSinkingConcatAllTransposesInputTestFixture
|
||||
: public ::testing::WithParamInterface<TestConcatAllTransposesInputParams>,
|
||||
public TransformationTestsF {};
|
||||
|
||||
TEST_P(TransposeSinkingConcatAllTransposesInputTestFixture, CompareFunctions) {
|
||||
PassFactoryPtr pass_factory;
|
||||
size_t num_concat_ops;
|
||||
CreateGraphConcatAllTransposesInputF model_factory;
|
||||
CreateGraphConcatAllTransposesInputF reference_model_factory;
|
||||
ov::element::Type input_type;
|
||||
size_t num_concat_inputs;
|
||||
std::tie(pass_factory,
|
||||
num_concat_ops,
|
||||
model_factory,
|
||||
reference_model_factory,
|
||||
input_type,
|
||||
num_concat_inputs) = this->GetParam();
|
||||
|
||||
model = model_factory(num_concat_ops, input_type, num_concat_inputs);
|
||||
model_ref = reference_model_factory(num_concat_ops, input_type, num_concat_inputs);
|
||||
pass_factory->registerPass(manager);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingConcatForwardAllTransposesTestSuite,
|
||||
TransposeSinkingConcatAllTransposesInputTestFixture,
|
||||
::testing::Combine(::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingConcatForward>()),
|
||||
::testing::ValuesIn(concat_operations_numbers),
|
||||
::testing::Values(single_consumer::forward::double_transpose::CreateFunction),
|
||||
::testing::Values(single_consumer::forward::double_transpose::CreateReferenceFunction),
|
||||
::testing::Values(ov::element::f32),
|
||||
::testing::Values(5)));
|
@ -0,0 +1,552 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <transformations/common_optimizations/transpose_sinking_split.hpp>
|
||||
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <openvino/frontend/manager.hpp>
|
||||
#include <openvino/opsets/opset9.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using NodePtr = std::shared_ptr<ov::Node>;
|
||||
using ModelPtr = std::shared_ptr<ov::Model>;
|
||||
using Output = ov::Output<ov::Node>;
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
class IBinaryFactory {
|
||||
public:
|
||||
IBinaryFactory() = default;
|
||||
virtual ~IBinaryFactory() = default;
|
||||
virtual NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const = 0;
|
||||
};
|
||||
|
||||
using BinaryFactoryPtr = std::shared_ptr<IBinaryFactory>;
|
||||
|
||||
template <typename BinaryT>
|
||||
class BinaryFactory : public IBinaryFactory {
|
||||
public:
|
||||
BinaryFactory() = default;
|
||||
NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const override {
|
||||
return std::make_shared<BinaryT>(parent_left_node, parent_right_node);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename BinaryT>
|
||||
BinaryFactoryPtr CreateBinaryFactory() {
|
||||
return std::make_shared<BinaryFactory<BinaryT>>();
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
class IPassFactory {
|
||||
public:
|
||||
IPassFactory() = default;
|
||||
virtual ~IPassFactory() = default;
|
||||
virtual void registerPass(ov::pass::Manager& pass_manager) const = 0;
|
||||
};
|
||||
|
||||
using PassFactoryPtr = std::shared_ptr<IPassFactory>;
|
||||
|
||||
template <typename PassT>
|
||||
class PassFactory : public IPassFactory {
|
||||
public:
|
||||
void registerPass(ov::pass::Manager& pass_manager) const override {
|
||||
pass_manager.register_pass<PassT>();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename PassT>
|
||||
PassFactoryPtr CreatePassFactory() {
|
||||
return std::make_shared<PassFactory<PassT>>();
|
||||
}
|
||||
|
||||
std::vector<BinaryFactoryPtr> binary_factories = {
|
||||
CreateBinaryFactory<ov::opset9::Add>(),
|
||||
CreateBinaryFactory<ov::opset9::Divide>(),
|
||||
CreateBinaryFactory<ov::opset9::Maximum>(),
|
||||
CreateBinaryFactory<ov::opset9::Minimum>(),
|
||||
CreateBinaryFactory<ov::opset9::Mod>(),
|
||||
CreateBinaryFactory<ov::opset9::Multiply>(),
|
||||
CreateBinaryFactory<ov::opset9::Power>(),
|
||||
CreateBinaryFactory<ov::opset9::SquaredDifference>(),
|
||||
CreateBinaryFactory<ov::opset9::Subtract>()
|
||||
};
|
||||
|
||||
std::vector<size_t> binary_operations_numbers = {1, 10};
|
||||
|
||||
std::vector<size_t> binary_transpose_input_indexes = {0, 1};
|
||||
|
||||
} // namespace
|
||||
|
||||
// --------------------------------------------------------------------------------------
|
||||
|
||||
using CreateGraphSplitForwardF = std::function< std::shared_ptr<ov::Model> (size_t num_split_ops,
|
||||
size_t num_split_outputs,
|
||||
ov::element::Type input_type)>;
|
||||
|
||||
using TestSplitForwardParams = std::tuple<PassFactoryPtr,
|
||||
size_t, /* num_split_ops */
|
||||
size_t, /* num_split_outputs */
|
||||
CreateGraphSplitForwardF, /* model_factory */
|
||||
CreateGraphSplitForwardF, /* reference_model_factory */
|
||||
ov::element::Type> /* input type */;
|
||||
|
||||
class TransposeSinkingSplitForwardTestFixture: public ::testing::WithParamInterface<TestSplitForwardParams>,
|
||||
public TransformationTestsF {};
|
||||
|
||||
namespace {
|
||||
|
||||
std::vector<size_t> split_operations_numbers = {1, 10};
|
||||
|
||||
std::vector<size_t> split_outputs_numbers = {2, 3};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace split {
|
||||
namespace forward {
|
||||
std::shared_ptr<ov::Model> CreateFunction(size_t num_split_ops,
|
||||
size_t num_split_outputs,
|
||||
ov::element::Type input_type) {
|
||||
const ov::Shape input_shape{96, static_cast<size_t>(std::pow(num_split_outputs, num_split_ops + 1)), 55, 55};
|
||||
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
||||
|
||||
ov::OutputVector outputs;
|
||||
Output in_op = transpose0->output(0);
|
||||
for (size_t i = 0; i < num_split_ops; ++i) {
|
||||
auto split_axis_const = std::make_shared<ov::opset9::Constant>(ov::element::u64,
|
||||
ov::Shape{},
|
||||
2);
|
||||
auto split = std::make_shared<ov::opset9::Split>(in_op, split_axis_const, num_split_outputs);
|
||||
for (size_t num_output = 0; num_output < num_split_outputs - 1; ++num_output) {
|
||||
outputs.push_back(split->output(num_output));
|
||||
}
|
||||
in_op = split->output(num_split_outputs - 1);
|
||||
}
|
||||
outputs.push_back(in_op);
|
||||
|
||||
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> CreateReferenceFunction(size_t num_split_ops,
|
||||
size_t num_split_outputs,
|
||||
ov::element::Type input_type) {
|
||||
const ov::Shape input_shape{96, static_cast<size_t>(std::pow(num_split_outputs, num_split_ops + 1)), 55, 55};
|
||||
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
ov::OutputVector outputs;
|
||||
Output in_op = X->output(0);
|
||||
for (size_t i = 0; i < num_split_ops; ++i) {
|
||||
auto split_axis_const = std::make_shared<ov::opset9::Constant>(ov::element::u64,
|
||||
ov::Shape{},
|
||||
1);
|
||||
auto split = std::make_shared<ov::opset9::Split>(in_op, split_axis_const, num_split_outputs);
|
||||
for (size_t num_output = 0; num_output < num_split_outputs - 1; ++num_output) {
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(split->output(num_output), ng_order0);
|
||||
outputs.push_back(transpose0);
|
||||
}
|
||||
in_op = split->output(num_split_outputs - 1);
|
||||
}
|
||||
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||
outputs.push_back(transpose0);
|
||||
|
||||
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
} // namespace forward
|
||||
} // namespace split
|
||||
|
||||
|
||||
TEST_P(TransposeSinkingSplitForwardTestFixture, CompareFunctions) {
|
||||
PassFactoryPtr pass_factory;
|
||||
size_t num_split_ops;
|
||||
size_t num_split_outputs;
|
||||
CreateGraphSplitForwardF model_factory;
|
||||
CreateGraphSplitForwardF reference_model_factory;
|
||||
ov::element::Type input_type;
|
||||
std::tie(pass_factory,
|
||||
num_split_ops,
|
||||
num_split_outputs,
|
||||
model_factory,
|
||||
reference_model_factory,
|
||||
input_type) = this->GetParam();
|
||||
|
||||
model = model_factory(num_split_ops, num_split_outputs, input_type);
|
||||
model_ref = reference_model_factory(num_split_ops, num_split_outputs, input_type);
|
||||
pass_factory->registerPass(manager);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingSplitForwardTestSuite,
|
||||
TransposeSinkingSplitForwardTestFixture,
|
||||
::testing::Combine(::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingSplitForward>()),
|
||||
::testing::ValuesIn(split_operations_numbers),
|
||||
::testing::ValuesIn(split_outputs_numbers),
|
||||
::testing::Values(split::forward::CreateFunction),
|
||||
::testing::Values(split::forward::CreateReferenceFunction),
|
||||
::testing::Values(ov::element::f32)));
|
||||
|
||||
|
||||
// --------------------------------------------------------------------------------------
|
||||
|
||||
using CreateGraphSplitBackwardF = std::function< std::shared_ptr<ov::Model> (size_t split_tree_depth,
|
||||
size_t num_split_outputs,
|
||||
ov::element::Type input_type)>;
|
||||
|
||||
using TestSplitBackwardParams = std::tuple<PassFactoryPtr,
|
||||
size_t, /* split_tree_depth */
|
||||
size_t, /* num_split_outputs */
|
||||
CreateGraphSplitBackwardF, /* model_factory */
|
||||
CreateGraphSplitBackwardF, /* reference_model_factory */
|
||||
ov::element::Type> /* input type */;
|
||||
|
||||
class TransposeSinkingSplitBackwardTestFixture: public ::testing::WithParamInterface<TestSplitBackwardParams>,
|
||||
public TransformationTestsF {};
|
||||
|
||||
namespace {
|
||||
std::vector<size_t> split_tree_depth_nums = {1, 3};
|
||||
} // namespace
|
||||
|
||||
// --------------------------------------------------------------------------------------
|
||||
|
||||
namespace split {
|
||||
namespace backward {
|
||||
|
||||
class SplitFactory {
|
||||
public:
|
||||
SplitFactory(size_t axis, size_t n_outputs, ov::element::Type elem_type) :
|
||||
_axis(axis), _n_outputs(n_outputs), _elem_type(elem_type) {}
|
||||
NodePtr create(Output parent) const {
|
||||
auto split_axis_const = std::make_shared<ov::opset9::Constant>(_elem_type,
|
||||
ov::Shape{},
|
||||
_axis);
|
||||
return std::make_shared<ov::opset9::Split>(parent, split_axis_const, _n_outputs);
|
||||
}
|
||||
private:
|
||||
const size_t _axis;
|
||||
const size_t _n_outputs;
|
||||
const ov::element::Type _elem_type;
|
||||
};
|
||||
|
||||
void CreateSplitTree(size_t max_depth, size_t depth, Output parent, const SplitFactory & split_factory, ov::OutputVector & leaves) {
|
||||
if (depth == max_depth) {
|
||||
leaves.push_back(parent);
|
||||
return;
|
||||
}
|
||||
|
||||
auto split = split_factory.create(parent);
|
||||
|
||||
for (size_t output_idx = 0; output_idx < split->get_output_size(); ++output_idx) {
|
||||
CreateSplitTree(max_depth, depth + 1, split->output(output_idx), split_factory, leaves);
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> CreateFunction(size_t split_tree_depth,
|
||||
size_t num_split_outputs,
|
||||
ov::element::Type input_type) {
|
||||
const size_t split_input_dim_value = static_cast<size_t>(std::pow(num_split_outputs, split_tree_depth + 1));
|
||||
const ov::Shape input_shape{96, split_input_dim_value, 55, 55};
|
||||
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
ov::OutputVector split_tree_leaves;
|
||||
{
|
||||
SplitFactory split_factory(/* axis */ 1, num_split_outputs, /* elem_type */ ov::element::u64);
|
||||
CreateSplitTree(split_tree_depth, /* depth */ 0, X->output(0), split_factory, split_tree_leaves);
|
||||
}
|
||||
|
||||
ov::OutputVector outputs;
|
||||
for (auto& split_tree_leaf : split_tree_leaves) {
|
||||
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||
|
||||
const size_t split_dim_current_value = static_cast<size_t>(split_input_dim_value / std::pow(num_split_outputs, split_tree_depth));
|
||||
auto reshape_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{3}, ov::Shape{96, 55, split_dim_current_value * 55});
|
||||
auto reshape = std::make_shared<ov::opset9::Reshape>(transpose, reshape_const, false);
|
||||
outputs.push_back(reshape);
|
||||
}
|
||||
|
||||
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> CreateReferenceFunction(size_t split_tree_depth,
|
||||
size_t num_split_outputs,
|
||||
ov::element::Type input_type) {
|
||||
const size_t split_input_dim_value = static_cast<size_t>(std::pow(num_split_outputs, split_tree_depth + 1));
|
||||
const ov::Shape input_shape{96, split_input_dim_value, 55, 55};
|
||||
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto transpose = std::make_shared<ov::opset9::Transpose>(X, ng_order);
|
||||
|
||||
ov::OutputVector split_tree_leaves;
|
||||
{
|
||||
SplitFactory split_factory(/* axis */ 2, num_split_outputs, /* elem_type */ ov::element::u64);
|
||||
CreateSplitTree(split_tree_depth, /* depth */ 0, transpose->output(0), split_factory, split_tree_leaves);
|
||||
}
|
||||
|
||||
ov::OutputVector outputs;
|
||||
for (auto& split_tree_leaf : split_tree_leaves) {
|
||||
const size_t split_dim_current_value = static_cast<size_t>(split_input_dim_value / std::pow(num_split_outputs, split_tree_depth));
|
||||
auto reshape_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{3}, ov::Shape{96, 55, split_dim_current_value * 55});
|
||||
auto reshape = std::make_shared<ov::opset9::Reshape>(split_tree_leaf, reshape_const, false);
|
||||
outputs.push_back(reshape);
|
||||
}
|
||||
|
||||
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
} // namespace backward
|
||||
} // namespace split
|
||||
|
||||
|
||||
|
||||
TEST_P(TransposeSinkingSplitBackwardTestFixture, CompareFunctions) {
|
||||
PassFactoryPtr pass_factory;
|
||||
size_t split_tree_depth;
|
||||
size_t num_split_outputs;
|
||||
CreateGraphSplitBackwardF model_factory;
|
||||
CreateGraphSplitBackwardF reference_model_factory;
|
||||
ov::element::Type input_type;
|
||||
std::tie(pass_factory,
|
||||
split_tree_depth,
|
||||
num_split_outputs,
|
||||
model_factory,
|
||||
reference_model_factory,
|
||||
input_type) = this->GetParam();
|
||||
|
||||
model = model_factory(split_tree_depth, num_split_outputs, input_type);
|
||||
model_ref = reference_model_factory(split_tree_depth, num_split_outputs, input_type);
|
||||
pass_factory->registerPass(manager);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingSplitBackwardTestSuite,
|
||||
TransposeSinkingSplitBackwardTestFixture,
|
||||
::testing::Combine(::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingSplitBackward>()),
|
||||
::testing::ValuesIn(split_tree_depth_nums),
|
||||
::testing::ValuesIn(split_outputs_numbers),
|
||||
::testing::Values(split::backward::CreateFunction),
|
||||
::testing::Values(split::backward::CreateReferenceFunction),
|
||||
::testing::Values(ov::element::f32)));
|
||||
|
||||
|
||||
using TransposeInsertF = std::function< ov::OutputVector (const ov::OutputVector& split_tree_leaves)>;
|
||||
|
||||
using CreateGraphSplitBackwardRestrictF = std::function< std::shared_ptr<ov::Model> (size_t split_tree_depth,
|
||||
size_t num_split_outputs,
|
||||
ov::element::Type input_type,
|
||||
TransposeInsertF tranpose_insert_function)>;
|
||||
|
||||
using TestSplitBackwardRestrictParams = std::tuple<PassFactoryPtr,
|
||||
size_t, /* split_tree_depth */
|
||||
size_t, /* num_split_outputs */
|
||||
CreateGraphSplitBackwardRestrictF, /* model_factory */
|
||||
ov::element::Type, /* input type */
|
||||
TransposeInsertF>; /* insert transpose function */
|
||||
|
||||
class TransposeSinkingSplitBackwardRestrictTestFixture: public ::testing::WithParamInterface<TestSplitBackwardRestrictParams>,
|
||||
public TransformationTestsF {};
|
||||
|
||||
TEST_P(TransposeSinkingSplitBackwardRestrictTestFixture, CompareFunctions) {
|
||||
PassFactoryPtr pass_factory;
|
||||
size_t split_tree_depth;
|
||||
size_t num_split_outputs;
|
||||
CreateGraphSplitBackwardRestrictF model_factory;
|
||||
ov::element::Type input_type;
|
||||
TransposeInsertF tranpose_insert_function;
|
||||
std::tie(pass_factory,
|
||||
split_tree_depth,
|
||||
num_split_outputs,
|
||||
model_factory,
|
||||
input_type,
|
||||
tranpose_insert_function) = this->GetParam();
|
||||
|
||||
model = model_factory(split_tree_depth, num_split_outputs, input_type, tranpose_insert_function);
|
||||
model_ref = model->clone();
|
||||
pass_factory->registerPass(manager);
|
||||
}
|
||||
namespace split {
|
||||
namespace backward {
|
||||
namespace restrictions {
|
||||
|
||||
std::shared_ptr<ov::Model> CreateFunction(size_t split_tree_depth,
|
||||
size_t num_split_outputs,
|
||||
ov::element::Type input_type,
|
||||
TransposeInsertF transpose_insert_func) {
|
||||
const size_t split_input_dim_value = static_cast<size_t>(std::pow(num_split_outputs, split_tree_depth + 1));
|
||||
const ov::Shape input_shape{96, split_input_dim_value, 55, 55};
|
||||
|
||||
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||
|
||||
ov::OutputVector split_tree_leaves;
|
||||
{
|
||||
SplitFactory split_factory(/* axis */ 1, num_split_outputs, /* elem_type */ ov::element::u64);
|
||||
CreateSplitTree(split_tree_depth, /* depth */ 0, X->output(0), split_factory, split_tree_leaves);
|
||||
}
|
||||
|
||||
ov::OutputVector outputs;
|
||||
for (auto& split_tree_leaf : transpose_insert_func(split_tree_leaves)) {
|
||||
const size_t split_dim_current_value = static_cast<size_t>(split_input_dim_value / std::pow(num_split_outputs, split_tree_depth));
|
||||
auto reshape_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{3}, ov::Shape{96, 55, split_dim_current_value * 55});
|
||||
auto reshape = std::make_shared<ov::opset9::Reshape>(split_tree_leaf, reshape_const, false);
|
||||
outputs.push_back(reshape);
|
||||
}
|
||||
|
||||
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
|
||||
}
|
||||
|
||||
ov::OutputVector OnlyFirstTranspose(const ov::OutputVector& split_tree_leaves) {
|
||||
ov::OutputVector outputs;
|
||||
{
|
||||
auto& split_tree_leaf = split_tree_leaves.front();
|
||||
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||
outputs.push_back(transpose);
|
||||
}
|
||||
|
||||
for (size_t leaf_idx = 1; leaf_idx < split_tree_leaves.size(); ++leaf_idx) {
|
||||
outputs.push_back(split_tree_leaves[leaf_idx]);
|
||||
}
|
||||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
ov::OutputVector OnlyLastTranspose(const ov::OutputVector& split_tree_leaves) {
|
||||
ov::OutputVector outputs;
|
||||
{
|
||||
auto& split_tree_leaf = split_tree_leaves.back();
|
||||
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||
outputs.push_back(transpose);
|
||||
}
|
||||
|
||||
for (size_t leaf_idx = 0; leaf_idx < split_tree_leaves.size() - 1; ++leaf_idx) {
|
||||
outputs.push_back(split_tree_leaves[leaf_idx]);
|
||||
}
|
||||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
ov::OutputVector OnlyMiddleTranspose(const ov::OutputVector& split_tree_leaves) {
|
||||
ov::OutputVector outputs;
|
||||
size_t middle_idx = split_tree_leaves.size() / 2;
|
||||
if (split_tree_leaves.size() % 2)
|
||||
++middle_idx;
|
||||
for (size_t leaf_idx = 0; leaf_idx < split_tree_leaves.size() - 1; ++leaf_idx) {
|
||||
if (leaf_idx == middle_idx) {
|
||||
auto& split_tree_leaf = split_tree_leaves[leaf_idx];
|
||||
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||
outputs.push_back(transpose);
|
||||
} else {
|
||||
outputs.push_back(split_tree_leaves[leaf_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
ov::OutputVector FirstAnotherTranspose(const ov::OutputVector& split_tree_leaves) {
|
||||
ov::OutputVector outputs;
|
||||
{
|
||||
auto& split_tree_leaf = split_tree_leaves.front();
|
||||
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||
outputs.push_back(transpose);
|
||||
}
|
||||
|
||||
for (size_t leaf_idx = 1; leaf_idx < split_tree_leaves.size(); ++leaf_idx) {
|
||||
auto& split_tree_leaf = split_tree_leaves[leaf_idx];
|
||||
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||
outputs.push_back(transpose);
|
||||
}
|
||||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
ov::OutputVector LastAnotherTranspose(const ov::OutputVector& split_tree_leaves) {
|
||||
ov::OutputVector outputs;
|
||||
{
|
||||
auto& split_tree_leaf = split_tree_leaves.back();
|
||||
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||
outputs.push_back(transpose);
|
||||
}
|
||||
|
||||
for (size_t leaf_idx = 0; leaf_idx < split_tree_leaves.size() - 1; ++leaf_idx) {
|
||||
auto& split_tree_leaf = split_tree_leaves[leaf_idx];
|
||||
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||
outputs.push_back(transpose);
|
||||
}
|
||||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
ov::OutputVector MiddleAnotherTranspose(const ov::OutputVector& split_tree_leaves) {
|
||||
ov::OutputVector outputs;
|
||||
size_t middle_idx = split_tree_leaves.size() / 2;
|
||||
if (split_tree_leaves.size() % 2)
|
||||
++middle_idx;
|
||||
for (size_t leaf_idx = 0; leaf_idx < split_tree_leaves.size(); ++leaf_idx) {
|
||||
auto& split_tree_leaf = split_tree_leaves[leaf_idx];
|
||||
if (leaf_idx == middle_idx) {
|
||||
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||
outputs.push_back(transpose);
|
||||
} else {
|
||||
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||
outputs.push_back(transpose);
|
||||
}
|
||||
}
|
||||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
} // namespace restrictions
|
||||
} // namespace backward
|
||||
} // namespace split
|
||||
|
||||
namespace {
|
||||
|
||||
std::vector<TransposeInsertF> insertTransposeFactories = {
|
||||
split::backward::restrictions::OnlyFirstTranspose,
|
||||
split::backward::restrictions::OnlyLastTranspose,
|
||||
split::backward::restrictions::OnlyMiddleTranspose,
|
||||
split::backward::restrictions::FirstAnotherTranspose,
|
||||
split::backward::restrictions::LastAnotherTranspose,
|
||||
split::backward::restrictions::MiddleAnotherTranspose
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingSplitBackwardRestrictTestSuite,
|
||||
TransposeSinkingSplitBackwardRestrictTestFixture,
|
||||
::testing::Combine(::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingSplitBackward>()),
|
||||
::testing::Values(1),
|
||||
::testing::Values(5),
|
||||
::testing::Values(split::backward::restrictions::CreateFunction),
|
||||
::testing::Values(ov::element::f32),
|
||||
::testing::ValuesIn(insertTransposeFactories)));
|
Loading…
Reference in New Issue
Block a user