Transpose sinking on passes for Concat, Split and binary eltwise operations (#13718)
* initial * clang cleanup fixes * remove TrasposeAxis function; cleanup namespaces * fix TransposeInputsInfo spell * one_input_transpose spell * cleanup speel * spell * decompose forward sinking * decompose backward sink * use NodeVector * clang cleanup * decomposite transformations into different files * decompose unit tests * clang cleanup * azure build fixes * code review fixes * clang cleanup fixes
This commit is contained in:
parent
738d7bb09f
commit
0846bdb67e
@ -0,0 +1,30 @@
|
|||||||
|
// Copyright (C) 2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/pass/graph_rewrite.hpp"
|
||||||
|
#include "openvino/pass/pass.hpp"
|
||||||
|
#include "transformations_visibility.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
|
||||||
|
class TRANSFORMATIONS_API TransposeSinkingBinaryElementwiseForward;
|
||||||
|
class TRANSFORMATIONS_API TransposeSinkingBinaryElementwiseBackward;
|
||||||
|
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
||||||
|
|
||||||
|
class ov::pass::TransposeSinkingBinaryElementwiseForward : public ov::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryElementwiseForward", "0");
|
||||||
|
TransposeSinkingBinaryElementwiseForward();
|
||||||
|
};
|
||||||
|
|
||||||
|
class ov::pass::TransposeSinkingBinaryElementwiseBackward : public ov::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI("ov::pass::TransposeSinkingBinaryElementwiseBackward", "0");
|
||||||
|
TransposeSinkingBinaryElementwiseBackward();
|
||||||
|
};
|
@ -0,0 +1,30 @@
|
|||||||
|
// Copyright (C) 2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/pass/graph_rewrite.hpp"
|
||||||
|
#include "openvino/pass/pass.hpp"
|
||||||
|
#include "transformations_visibility.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
|
||||||
|
class TRANSFORMATIONS_API TransposeSinkingConcatForward;
|
||||||
|
class TRANSFORMATIONS_API TransposeSinkingConcatBackward;
|
||||||
|
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
||||||
|
|
||||||
|
class ov::pass::TransposeSinkingConcatForward : public ov::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI("ov::pass::TransposeSinkingConcatForward", "0");
|
||||||
|
TransposeSinkingConcatForward();
|
||||||
|
};
|
||||||
|
|
||||||
|
class ov::pass::TransposeSinkingConcatBackward : public ov::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI("ov::pass::TransposeSinkingConcatBackward", "0");
|
||||||
|
TransposeSinkingConcatBackward();
|
||||||
|
};
|
@ -0,0 +1,30 @@
|
|||||||
|
// Copyright (C) 2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/pass/graph_rewrite.hpp"
|
||||||
|
#include "openvino/pass/pass.hpp"
|
||||||
|
#include "transformations_visibility.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
|
||||||
|
class TRANSFORMATIONS_API TransposeSinkingSplitBackward;
|
||||||
|
class TRANSFORMATIONS_API TransposeSinkingSplitForward;
|
||||||
|
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
||||||
|
|
||||||
|
class ov::pass::TransposeSinkingSplitForward : public ov::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI("ov::pass::TransposeSinkingSplitForward", "0");
|
||||||
|
TransposeSinkingSplitForward();
|
||||||
|
};
|
||||||
|
|
||||||
|
class ov::pass::TransposeSinkingSplitBackward : public ov::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI("ov::pass::TransposeSinkingSplitBackward", "0");
|
||||||
|
TransposeSinkingSplitBackward();
|
||||||
|
};
|
@ -0,0 +1,50 @@
|
|||||||
|
// Copyright (C) 2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <openvino/pass/pattern/op/or.hpp>
|
||||||
|
#include <transformations/utils/utils.hpp>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "itt.hpp"
|
||||||
|
#include "openvino/op/util/op_types.hpp"
|
||||||
|
#include "openvino/opsets/opset9.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/label.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||||
|
#include "openvino/util/common_util.hpp"
|
||||||
|
#include "openvino/util/log.hpp"
|
||||||
|
|
||||||
|
namespace transpose_sinking {
|
||||||
|
|
||||||
|
struct TransposeInputsInfo {
|
||||||
|
std::shared_ptr<ov::opset9::Transpose> transpose;
|
||||||
|
std::shared_ptr<ov::opset9::Constant> transpose_const;
|
||||||
|
size_t input_idx;
|
||||||
|
|
||||||
|
bool isEmpty() const {
|
||||||
|
return !transpose || !transpose_const;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TransposeInputsInfo GetFirstTransposeInput(std::shared_ptr<ov::Node> node);
|
||||||
|
bool IfNodeHasTransposeInputs(const ov::Output<ov::Node>& output);
|
||||||
|
ov::AxisVector ReverseTransposeOrder(const ov::AxisVector& axis_order);
|
||||||
|
void SwapOutputNames(ov::Output<ov::Node> output1, ov::Output<ov::Node> output2);
|
||||||
|
void SwapFriendlyNames(std::shared_ptr<ov::Node> node1, std::shared_ptr<ov::Node> node2);
|
||||||
|
void SwapNames(std::shared_ptr<ov::Node> node1, std::shared_ptr<ov::Node> node2);
|
||||||
|
|
||||||
|
namespace sink_forward {
|
||||||
|
// insert input reversed transposes, remove first input tranpose
|
||||||
|
void UpdateInputTransposes(std::shared_ptr<ov::Node> main_node, TransposeInputsInfo& transpose_input_info);
|
||||||
|
void RemoveZeroInputNode(std::shared_ptr<ov::Node> main_node);
|
||||||
|
ov::NodeVector InsertOutputTransposes(std::shared_ptr<ov::Node> main_node, TransposeInputsInfo& transpose_input_info);
|
||||||
|
} // namespace sink_forward
|
||||||
|
|
||||||
|
namespace sink_backward {
|
||||||
|
ov::NodeVector InsertTransposeBeforeNode(std::shared_ptr<ov::Node> main_node,
|
||||||
|
std::shared_ptr<ov::opset9::Constant> transpose_const);
|
||||||
|
} // namespace sink_backward
|
||||||
|
|
||||||
|
} // namespace transpose_sinking
|
@ -0,0 +1,74 @@
|
|||||||
|
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
|
||||||
|
|
||||||
|
#include <openvino/pass/pattern/op/or.hpp>
|
||||||
|
#include <transformations/utils/utils.hpp>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "itt.hpp"
|
||||||
|
#include "openvino/op/util/op_types.hpp"
|
||||||
|
#include "openvino/opsets/opset9.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/label.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||||
|
#include "openvino/util/common_util.hpp"
|
||||||
|
#include "openvino/util/log.hpp"
|
||||||
|
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||||
|
|
||||||
|
using namespace ov::pass::pattern;
|
||||||
|
using namespace ov;
|
||||||
|
using namespace ov::opset9;
|
||||||
|
using namespace transpose_sinking;
|
||||||
|
|
||||||
|
ov::pass::TransposeSinkingBinaryElementwiseForward::TransposeSinkingBinaryElementwiseForward() {
|
||||||
|
MATCHER_SCOPE(TransposeSinkingBinaryElementwiseForward);
|
||||||
|
|
||||||
|
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic>(IfNodeHasTransposeInputs);
|
||||||
|
|
||||||
|
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||||
|
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||||
|
|
||||||
|
auto& main_node_output = pattern_to_output.at(main_node_label);
|
||||||
|
auto main_node = main_node_output.get_node_shared_ptr();
|
||||||
|
|
||||||
|
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
||||||
|
|
||||||
|
sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
|
||||||
|
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||||
|
register_new_node(new_node);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
|
||||||
|
register_matcher(m, matcher_pass_callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
ov::pass::TransposeSinkingBinaryElementwiseBackward::TransposeSinkingBinaryElementwiseBackward() {
|
||||||
|
MATCHER_SCOPE(TransposeSinkingBinaryElementwiseBackward);
|
||||||
|
|
||||||
|
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic>(consumers_count(1));
|
||||||
|
|
||||||
|
auto transpose_const_label = wrap_type<Constant>(consumers_count(1));
|
||||||
|
auto transpose_label = wrap_type<Transpose>({main_node_label, transpose_const_label}, consumers_count(1));
|
||||||
|
|
||||||
|
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||||
|
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||||
|
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||||
|
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||||
|
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
||||||
|
|
||||||
|
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) {
|
||||||
|
register_new_node(new_node);
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove transpose after main node
|
||||||
|
transpose->output(0).replace(main_node);
|
||||||
|
|
||||||
|
SwapNames(transpose, main_node);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
|
||||||
|
register_matcher(m, matcher_pass_callback);
|
||||||
|
}
|
@ -0,0 +1,86 @@
|
|||||||
|
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
|
||||||
|
|
||||||
|
#include <openvino/opsets/opset9.hpp>
|
||||||
|
#include <openvino/pass/pattern/op/or.hpp>
|
||||||
|
#include <transformations/utils/utils.hpp>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "itt.hpp"
|
||||||
|
#include "openvino/op/util/op_types.hpp"
|
||||||
|
#include "openvino/opsets/opset9.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/label.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||||
|
#include "openvino/util/common_util.hpp"
|
||||||
|
#include "openvino/util/log.hpp"
|
||||||
|
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||||
|
|
||||||
|
using namespace ov::pass::pattern;
|
||||||
|
using namespace ov;
|
||||||
|
using namespace ov::opset9;
|
||||||
|
using namespace transpose_sinking;
|
||||||
|
|
||||||
|
ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() {
|
||||||
|
MATCHER_SCOPE(TransposeSinkingConcatForward);
|
||||||
|
|
||||||
|
auto main_node_label = wrap_type<Concat>(IfNodeHasTransposeInputs);
|
||||||
|
|
||||||
|
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||||
|
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||||
|
|
||||||
|
auto& main_node_output = pattern_to_output.at(main_node_label);
|
||||||
|
auto main_node = main_node_output.get_node_shared_ptr();
|
||||||
|
|
||||||
|
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
||||||
|
|
||||||
|
sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
|
||||||
|
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||||
|
register_new_node(new_node);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto concat_node = as_type_ptr<Concat>(main_node);
|
||||||
|
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||||
|
const int64_t transposed_concat_axis = transpose_axis_order[concat_node->get_axis()];
|
||||||
|
concat_node->set_concatenation_axis(transposed_concat_axis);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
|
||||||
|
register_matcher(m, matcher_pass_callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
|
||||||
|
MATCHER_SCOPE(TransposeSinkingConcatBackward);
|
||||||
|
|
||||||
|
auto main_node_label = wrap_type<Concat>(consumers_count(1));
|
||||||
|
|
||||||
|
auto transpose_const_label = wrap_type<Constant>(consumers_count(1));
|
||||||
|
auto transpose_label = wrap_type<Transpose>({main_node_label, transpose_const_label}, consumers_count(1));
|
||||||
|
|
||||||
|
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||||
|
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||||
|
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||||
|
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||||
|
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
||||||
|
|
||||||
|
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) {
|
||||||
|
register_new_node(new_node);
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove transpose after main node
|
||||||
|
transpose->output(0).replace(main_node);
|
||||||
|
|
||||||
|
SwapNames(transpose, main_node);
|
||||||
|
|
||||||
|
auto concat_node = as_type_ptr<Concat>(main_node);
|
||||||
|
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||||
|
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
|
||||||
|
const int64_t transposed_concat_axis = reversed_traspose_axis_order[concat_node->get_axis()];
|
||||||
|
concat_node->set_concatenation_axis(transposed_concat_axis);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
|
||||||
|
register_matcher(m, matcher_pass_callback);
|
||||||
|
}
|
@ -0,0 +1,200 @@
|
|||||||
|
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
|
||||||
|
|
||||||
|
#include <openvino/opsets/opset9.hpp>
|
||||||
|
#include <openvino/pass/pattern/op/or.hpp>
|
||||||
|
#include <transformations/utils/utils.hpp>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "itt.hpp"
|
||||||
|
#include "openvino/op/util/op_types.hpp"
|
||||||
|
#include "openvino/opsets/opset9.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/label.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||||
|
#include "openvino/util/common_util.hpp"
|
||||||
|
#include "openvino/util/log.hpp"
|
||||||
|
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||||
|
|
||||||
|
using namespace ov::pass::pattern;
|
||||||
|
using namespace ov;
|
||||||
|
using namespace ov::opset9;
|
||||||
|
using namespace transpose_sinking;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using NodePtr = std::shared_ptr<Node>;
|
||||||
|
|
||||||
|
struct OutputTranspose {
|
||||||
|
OutputTranspose() : transpose(nullptr), transpose_const(nullptr) {}
|
||||||
|
Transpose* transpose;
|
||||||
|
Constant* transpose_const;
|
||||||
|
};
|
||||||
|
|
||||||
|
OutputTranspose GetOutputTransposes(NodePtr node) {
|
||||||
|
for (size_t output_idx = 0; output_idx < node->get_output_size(); ++output_idx) {
|
||||||
|
for (auto& input : node->get_output_target_inputs(output_idx)) {
|
||||||
|
auto transpose_node = dynamic_cast<Transpose*>(input.get_node());
|
||||||
|
if (!transpose_node)
|
||||||
|
continue;
|
||||||
|
auto constant_node = dynamic_cast<Constant*>(transpose_node->input_value(1).get_node());
|
||||||
|
if (!constant_node)
|
||||||
|
continue;
|
||||||
|
{
|
||||||
|
OutputTranspose output_transpose;
|
||||||
|
output_transpose.transpose = transpose_node;
|
||||||
|
output_transpose.transpose_const = constant_node;
|
||||||
|
|
||||||
|
return output_transpose;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return OutputTranspose();
|
||||||
|
}
|
||||||
|
|
||||||
|
NodePtr FindSplitInput(Node* node) {
|
||||||
|
for (size_t input_idx = 0; input_idx < node->get_input_size(); ++input_idx) {
|
||||||
|
NodePtr input_node = node->get_input_node_shared_ptr(input_idx);
|
||||||
|
auto split_node = as_type_ptr<Split>(input_node);
|
||||||
|
if (split_node)
|
||||||
|
return split_node;
|
||||||
|
}
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<Constant> GetTransposeConstant(Input<Node> input) {
|
||||||
|
auto transpose_node = dynamic_cast<Transpose*>(input.get_node());
|
||||||
|
if (!transpose_node)
|
||||||
|
return {};
|
||||||
|
|
||||||
|
auto constant_node = as_type_ptr<Constant>(transpose_node->input_value(1).get_node_shared_ptr());
|
||||||
|
if (!constant_node)
|
||||||
|
return {};
|
||||||
|
|
||||||
|
return constant_node;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool HasInputSplitAndTransposeSiblings(const Output<Node>& output) {
|
||||||
|
NodePtr split_node = FindSplitInput(output.get_node());
|
||||||
|
if (!split_node) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
AxisVector first_transpose_axis_order;
|
||||||
|
// get first transpose axis
|
||||||
|
{
|
||||||
|
auto constant_node = GetTransposeConstant(*(split_node->get_output_target_inputs(0).begin()));
|
||||||
|
if (!constant_node)
|
||||||
|
return false;
|
||||||
|
first_transpose_axis_order = constant_node->get_axis_vector_val();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t output_idx = 1; output_idx < split_node->get_output_size(); ++output_idx) {
|
||||||
|
for (auto& input : split_node->get_output_target_inputs(output_idx)) {
|
||||||
|
auto constant_node = GetTransposeConstant(input);
|
||||||
|
if (!constant_node)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
AxisVector transpose_axis_order = constant_node->get_axis_vector_val();
|
||||||
|
if (transpose_axis_order.size() != first_transpose_axis_order.size())
|
||||||
|
return false;
|
||||||
|
if (!std::equal(transpose_axis_order.begin(),
|
||||||
|
transpose_axis_order.end(),
|
||||||
|
first_transpose_axis_order.begin()))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
|
||||||
|
MATCHER_SCOPE(TransposeSinkingSplitBackward);
|
||||||
|
|
||||||
|
auto transpose_const_label = wrap_type<Constant>(consumers_count(1));
|
||||||
|
auto transpose_label =
|
||||||
|
wrap_type<Transpose>({any_input(), transpose_const_label}, HasInputSplitAndTransposeSiblings);
|
||||||
|
|
||||||
|
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||||
|
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||||
|
auto transpose_label_node = pattern_to_output.at(transpose_label).get_node();
|
||||||
|
|
||||||
|
NodePtr split = FindSplitInput(transpose_label_node);
|
||||||
|
auto split_axis_constant = as_type_ptr<Constant>(split->input_value(1).get_node_shared_ptr());
|
||||||
|
OutputTranspose output_transpose = GetOutputTransposes(split);
|
||||||
|
|
||||||
|
const auto transpose_axis_order = output_transpose.transpose_const->get_axis_vector_val();
|
||||||
|
const auto transpose_element_type = output_transpose.transpose_const->get_element_type();
|
||||||
|
|
||||||
|
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
|
||||||
|
|
||||||
|
const size_t split_axis = split_axis_constant->get_axis_vector_val()[0];
|
||||||
|
const size_t reversed_transposed_split_axis = reversed_traspose_axis_order[split_axis];
|
||||||
|
|
||||||
|
// insert transpose before split
|
||||||
|
{
|
||||||
|
auto input_node = split->input_value(0);
|
||||||
|
auto new_transpose_const = std::make_shared<Constant>(transpose_element_type,
|
||||||
|
Shape{transpose_axis_order.size()},
|
||||||
|
transpose_axis_order);
|
||||||
|
auto new_transpose = std::make_shared<Transpose>(input_node, new_transpose_const);
|
||||||
|
|
||||||
|
split->input(0).replace_source_output(new_transpose->output(0));
|
||||||
|
|
||||||
|
copy_runtime_info(input_node.get_node_shared_ptr(), {new_transpose, new_transpose_const});
|
||||||
|
|
||||||
|
register_new_node(new_transpose);
|
||||||
|
}
|
||||||
|
|
||||||
|
// update split axis
|
||||||
|
auto new_split_axis_const = std::make_shared<Constant>(split_axis_constant->get_element_type(),
|
||||||
|
Shape{},
|
||||||
|
reversed_transposed_split_axis);
|
||||||
|
split->input(1).replace_source_output(new_split_axis_const);
|
||||||
|
|
||||||
|
// remove split output transposes
|
||||||
|
for (size_t output_idx = 0; output_idx < split->get_output_size(); ++output_idx) {
|
||||||
|
for (auto& input : split->get_output_target_inputs(output_idx)) {
|
||||||
|
input.get_node()->output(0).replace(split->output(output_idx));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
|
||||||
|
register_matcher(m, matcher_pass_callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
|
||||||
|
MATCHER_SCOPE(TransposeSinkingSplitForward);
|
||||||
|
|
||||||
|
auto main_node_label = wrap_type<Split>(IfNodeHasTransposeInputs);
|
||||||
|
|
||||||
|
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||||
|
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||||
|
|
||||||
|
auto& main_node_output = pattern_to_output.at(main_node_label);
|
||||||
|
auto main_node = main_node_output.get_node_shared_ptr();
|
||||||
|
|
||||||
|
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
||||||
|
|
||||||
|
sink_forward::RemoveZeroInputNode(main_node);
|
||||||
|
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||||
|
register_new_node(new_node);
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||||
|
auto split_node = as_type_ptr<Split>(main_node);
|
||||||
|
auto split_axis_constant = as_type_ptr<Constant>(split_node->input_value(1).get_node_shared_ptr());
|
||||||
|
const size_t split_axis = split_axis_constant->get_axis_vector_val()[0];
|
||||||
|
const size_t transposed_split_axis = transpose_axis_order[split_axis];
|
||||||
|
auto new_split_axis_const =
|
||||||
|
std::make_shared<Constant>(split_axis_constant->get_element_type(), Shape{}, transposed_split_axis);
|
||||||
|
split_node->input(1).replace_source_output(new_split_axis_const);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
|
||||||
|
register_matcher(m, matcher_pass_callback);
|
||||||
|
}
|
@ -0,0 +1,172 @@
|
|||||||
|
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||||
|
|
||||||
|
#include <openvino/pass/pattern/op/or.hpp>
|
||||||
|
#include <transformations/utils/utils.hpp>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "itt.hpp"
|
||||||
|
#include "openvino/op/util/op_types.hpp"
|
||||||
|
#include "openvino/opsets/opset9.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/label.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||||
|
#include "openvino/util/common_util.hpp"
|
||||||
|
#include "openvino/util/log.hpp"
|
||||||
|
|
||||||
|
namespace transpose_sinking {
|
||||||
|
|
||||||
|
using namespace ov;
|
||||||
|
using namespace ov::opset9;
|
||||||
|
|
||||||
|
using NodePtr = std::shared_ptr<Node>;
|
||||||
|
|
||||||
|
TransposeInputsInfo GetFirstTransposeInput(NodePtr node) {
|
||||||
|
for (size_t input_idx = 0; input_idx < node->get_input_size(); ++input_idx) {
|
||||||
|
NodePtr input_node = node->get_input_node_shared_ptr(input_idx);
|
||||||
|
auto transpose_node = as_type_ptr<Transpose>(input_node);
|
||||||
|
if (!transpose_node)
|
||||||
|
continue;
|
||||||
|
auto constant_node = as_type_ptr<Constant>(transpose_node->input_value(1).get_node_shared_ptr());
|
||||||
|
if (!constant_node)
|
||||||
|
continue;
|
||||||
|
{
|
||||||
|
TransposeInputsInfo input_info;
|
||||||
|
input_info.transpose = transpose_node;
|
||||||
|
input_info.transpose_const = constant_node;
|
||||||
|
input_info.input_idx = input_idx;
|
||||||
|
return input_info;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return TransposeInputsInfo();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IfNodeHasTransposeInputs(const Output<Node>& output) {
|
||||||
|
TransposeInputsInfo inputs_info = GetFirstTransposeInput(output.get_node_shared_ptr());
|
||||||
|
return !inputs_info.isEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
|
AxisVector ReverseTransposeOrder(const AxisVector& axis_order) {
|
||||||
|
AxisVector out(axis_order.size());
|
||||||
|
for (size_t i = 0; i < axis_order.size(); i++) {
|
||||||
|
out.at(axis_order[i]) = i;
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SwapOutputNames(Output<Node> output1, Output<Node> output2) {
|
||||||
|
const auto node2_output_names = output2.get_names();
|
||||||
|
output2.set_names(output1.get_names());
|
||||||
|
output1.set_names(node2_output_names);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SwapFriendlyNames(NodePtr node1, NodePtr node2) {
|
||||||
|
const std::string node2_name = node2->get_friendly_name();
|
||||||
|
node2->set_friendly_name(node1->get_friendly_name());
|
||||||
|
node1->set_friendly_name(node2_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SwapNames(NodePtr node1, NodePtr node2) {
|
||||||
|
SwapFriendlyNames(node1, node2);
|
||||||
|
SwapOutputNames(node1->output(0), node2->output(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace sink_forward {
|
||||||
|
|
||||||
|
// insert input reversed transposes, remove first input tranpose
|
||||||
|
void UpdateInputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_input_info) {
|
||||||
|
if (transpose_input_info.isEmpty())
|
||||||
|
return;
|
||||||
|
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||||
|
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
|
||||||
|
const size_t tranpose_input_index = transpose_input_info.input_idx;
|
||||||
|
const auto transpose_element_type = transpose_input_info.transpose_const->get_element_type();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < main_node->get_input_size(); ++i) {
|
||||||
|
auto input_node = main_node->input_value(i);
|
||||||
|
if (i == tranpose_input_index) {
|
||||||
|
auto transpose_parent = input_node.get_node()->input_value(0);
|
||||||
|
main_node->input(i).replace_source_output(transpose_parent);
|
||||||
|
} else {
|
||||||
|
auto new_transpose_const = std::make_shared<Constant>(transpose_element_type,
|
||||||
|
Shape{reversed_traspose_axis_order.size()},
|
||||||
|
reversed_traspose_axis_order);
|
||||||
|
auto new_transpose = std::make_shared<Transpose>(input_node, new_transpose_const);
|
||||||
|
|
||||||
|
main_node->input(i).replace_source_output(new_transpose->output(0));
|
||||||
|
|
||||||
|
copy_runtime_info(input_node.get_node_shared_ptr(), {new_transpose, new_transpose_const});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void RemoveZeroInputNode(NodePtr main_node) {
|
||||||
|
auto input_node = main_node->input_value(0);
|
||||||
|
if (input_node.get_node()->get_input_size() < 1)
|
||||||
|
return;
|
||||||
|
auto parent_node = input_node.get_node()->input_value(0);
|
||||||
|
main_node->input(0).replace_source_output(parent_node);
|
||||||
|
}
|
||||||
|
|
||||||
|
NodeVector InsertOutputTransposes(NodePtr main_node, TransposeInputsInfo& transpose_input_info) {
|
||||||
|
if (transpose_input_info.isEmpty())
|
||||||
|
return {};
|
||||||
|
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||||
|
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
|
||||||
|
const auto transpose_element_type = transpose_input_info.transpose_const->get_element_type();
|
||||||
|
|
||||||
|
NodeVector new_nodes;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < main_node->get_output_size(); ++i) {
|
||||||
|
auto main_node_consumers = main_node->output(i).get_target_inputs();
|
||||||
|
|
||||||
|
auto new_transpose_const = std::make_shared<Constant>(transpose_element_type,
|
||||||
|
Shape{transpose_axis_order.size()},
|
||||||
|
transpose_axis_order);
|
||||||
|
auto new_transpose = std::make_shared<Transpose>(main_node->output(i), new_transpose_const);
|
||||||
|
|
||||||
|
for (auto& consumer : main_node_consumers) {
|
||||||
|
consumer.replace_source_output(new_transpose);
|
||||||
|
}
|
||||||
|
|
||||||
|
copy_runtime_info(main_node, {new_transpose, new_transpose_const});
|
||||||
|
SwapOutputNames(main_node->output(i), new_transpose->output(0));
|
||||||
|
|
||||||
|
if (main_node->get_output_size() > 1)
|
||||||
|
new_transpose->set_friendly_name(main_node->get_friendly_name() + "." + std::to_string(i));
|
||||||
|
else
|
||||||
|
SwapFriendlyNames(new_transpose, main_node);
|
||||||
|
|
||||||
|
new_nodes.push_back(new_transpose);
|
||||||
|
}
|
||||||
|
|
||||||
|
return new_nodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sink_forward
|
||||||
|
|
||||||
|
namespace sink_backward {
|
||||||
|
NodeVector InsertTransposeBeforeNode(NodePtr main_node, std::shared_ptr<Constant> transpose_const) {
|
||||||
|
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||||
|
const auto transpose_element_type = transpose_const->get_element_type();
|
||||||
|
|
||||||
|
NodeVector new_nodes;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < main_node->get_input_size(); ++i) {
|
||||||
|
auto input_node = main_node->input_value(i);
|
||||||
|
auto new_transpose_const = std::make_shared<Constant>(transpose_element_type,
|
||||||
|
Shape{transpose_axis_order.size()},
|
||||||
|
transpose_axis_order);
|
||||||
|
auto new_transpose = std::make_shared<Transpose>(input_node, new_transpose_const);
|
||||||
|
|
||||||
|
main_node->input(i).replace_source_output(new_transpose->output(0));
|
||||||
|
|
||||||
|
copy_runtime_info(input_node.get_node_shared_ptr(), {new_transpose, new_transpose_const});
|
||||||
|
|
||||||
|
new_nodes.push_back(new_transpose);
|
||||||
|
}
|
||||||
|
|
||||||
|
return new_nodes;
|
||||||
|
}
|
||||||
|
} // namespace sink_backward
|
||||||
|
|
||||||
|
} // namespace transpose_sinking
|
@ -0,0 +1,358 @@
|
|||||||
|
// Copyright (C) 2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <transformations/common_optimizations/transpose_sinking_binary.hpp>
|
||||||
|
|
||||||
|
#include <transformations/init_node_info.hpp>
|
||||||
|
#include <openvino/frontend/manager.hpp>
|
||||||
|
#include <openvino/opsets/opset9.hpp>
|
||||||
|
#include <openvino/pass/manager.hpp>
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using NodePtr = std::shared_ptr<ov::Node>;
|
||||||
|
using ModelPtr = std::shared_ptr<ov::Model>;
|
||||||
|
using Output = ov::Output<ov::Node>;
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class IBinaryFactory {
|
||||||
|
public:
|
||||||
|
IBinaryFactory() = default;
|
||||||
|
virtual ~IBinaryFactory() = default;
|
||||||
|
virtual NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
using BinaryFactoryPtr = std::shared_ptr<IBinaryFactory>;
|
||||||
|
|
||||||
|
template <typename BinaryT>
|
||||||
|
class BinaryFactory : public IBinaryFactory {
|
||||||
|
public:
|
||||||
|
BinaryFactory() = default;
|
||||||
|
NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const override {
|
||||||
|
return std::make_shared<BinaryT>(parent_left_node, parent_right_node);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename BinaryT>
|
||||||
|
BinaryFactoryPtr CreateBinaryFactory() {
|
||||||
|
return std::make_shared<BinaryFactory<BinaryT>>();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class IPassFactory {
|
||||||
|
public:
|
||||||
|
IPassFactory() = default;
|
||||||
|
virtual ~IPassFactory() = default;
|
||||||
|
virtual void registerPass(ov::pass::Manager& pass_manager) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
using PassFactoryPtr = std::shared_ptr<IPassFactory>;
|
||||||
|
|
||||||
|
template <typename PassT>
|
||||||
|
class PassFactory : public IPassFactory {
|
||||||
|
public:
|
||||||
|
void registerPass(ov::pass::Manager& pass_manager) const override {
|
||||||
|
pass_manager.register_pass<PassT>();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename PassT>
|
||||||
|
PassFactoryPtr CreatePassFactory() {
|
||||||
|
return std::make_shared<PassFactory<PassT>>();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<BinaryFactoryPtr> binary_factories = {
|
||||||
|
CreateBinaryFactory<ov::opset9::Add>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Divide>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Maximum>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Minimum>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Mod>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Multiply>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Power>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::SquaredDifference>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Subtract>()
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<size_t> binary_operations_numbers = {1, 10};
|
||||||
|
|
||||||
|
std::vector<size_t> binary_transpose_input_indexes = {0, 1};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
|
||||||
|
namespace binary {
|
||||||
|
namespace single_consumer {
|
||||||
|
namespace forward {
|
||||||
|
namespace one_input_transpose {
|
||||||
|
|
||||||
|
std::shared_ptr<ov::Model> CreateFunction(BinaryFactoryPtr binary_factory,
|
||||||
|
size_t num_binary_ops,
|
||||||
|
ov::element::Type input_type,
|
||||||
|
size_t binary_transpose_input_idx) {
|
||||||
|
const ov::Shape input_shape{1, 96, 55, 55};
|
||||||
|
const ov::Shape const_shape{1, 55, 55, 96};
|
||||||
|
|
||||||
|
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
|
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
||||||
|
|
||||||
|
NodePtr in_op = transpose0;
|
||||||
|
for (size_t i = 0; i < num_binary_ops; ++i) {
|
||||||
|
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, const_shape, ov::Shape{1});
|
||||||
|
if (!binary_transpose_input_idx)
|
||||||
|
in_op = binary_factory->create(in_op, in_constant);
|
||||||
|
else
|
||||||
|
in_op = binary_factory->create(in_constant, in_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_shared<ov::Model>(ov::OutputVector{in_op}, ov::ParameterVector{X});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ov::Model> CreateReferenceFunction(BinaryFactoryPtr binary_factory,
|
||||||
|
size_t num_binary_ops,
|
||||||
|
ov::element::Type input_type,
|
||||||
|
size_t binary_transpose_input_idx) {
|
||||||
|
const ov::Shape input_shape{1, 96, 55, 55};
|
||||||
|
const ov::Shape const_shape{1, 55, 55, 96};
|
||||||
|
|
||||||
|
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
|
NodePtr in_op = X;
|
||||||
|
for (size_t i = 0; i < num_binary_ops; ++i) {
|
||||||
|
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, const_shape, ov::Shape{1});
|
||||||
|
|
||||||
|
auto transpose_reversed_const =
|
||||||
|
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||||
|
auto transpose_reversed = std::make_shared<ov::opset9::Transpose>(in_constant, transpose_reversed_const);
|
||||||
|
|
||||||
|
if (!binary_transpose_input_idx)
|
||||||
|
in_op = binary_factory->create(in_op, transpose_reversed);
|
||||||
|
else
|
||||||
|
in_op = binary_factory->create(transpose_reversed, in_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||||
|
|
||||||
|
return std::make_shared<ov::Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace one_input_transpose
|
||||||
|
|
||||||
|
namespace double_transpose {
|
||||||
|
std::shared_ptr<ov::Model> CreateFunction(BinaryFactoryPtr binary_factory,
|
||||||
|
size_t num_binary_ops,
|
||||||
|
ov::element::Type input_type) {
|
||||||
|
const ov::Shape input_shape{1, 96, 55, 55};
|
||||||
|
|
||||||
|
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
|
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
||||||
|
|
||||||
|
NodePtr in_op = transpose0;
|
||||||
|
for (size_t i = 0; i < num_binary_ops; ++i) {
|
||||||
|
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
|
||||||
|
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order1);
|
||||||
|
|
||||||
|
in_op = binary_factory->create(in_op, transpose1);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_shared<ov::Model>(ov::OutputVector{in_op}, ov::ParameterVector{X});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ov::Model> CreateReferenceFunction(BinaryFactoryPtr binary_factory,
|
||||||
|
size_t num_binary_ops,
|
||||||
|
ov::element::Type input_type) {
|
||||||
|
const ov::Shape input_shape{1, 96, 55, 55};
|
||||||
|
|
||||||
|
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
|
NodePtr in_op = X;
|
||||||
|
for (size_t i = 0; i < num_binary_ops; ++i) {
|
||||||
|
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
|
||||||
|
|
||||||
|
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order1);
|
||||||
|
|
||||||
|
auto transpose_reversed_const =
|
||||||
|
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||||
|
auto transpose_reversed = std::make_shared<ov::opset9::Transpose>(transpose1, transpose_reversed_const);
|
||||||
|
|
||||||
|
in_op = binary_factory->create(in_op, transpose_reversed);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||||
|
|
||||||
|
return std::make_shared<ov::Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace double_transpose
|
||||||
|
} // namespace forward
|
||||||
|
|
||||||
|
namespace backward {
|
||||||
|
namespace one_input_transpose {
|
||||||
|
std::shared_ptr<ov::Model> CreateFunction(BinaryFactoryPtr binary_factory,
|
||||||
|
size_t num_binary_ops,
|
||||||
|
ov::element::Type input_type,
|
||||||
|
size_t binary_transpose_input_idx) {
|
||||||
|
const ov::Shape input_shape{1, 96, 55, 55};
|
||||||
|
|
||||||
|
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
|
NodePtr in_op = X;
|
||||||
|
for (size_t i = 0; i < num_binary_ops; ++i) {
|
||||||
|
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
|
||||||
|
if (!binary_transpose_input_idx)
|
||||||
|
in_op = binary_factory->create(in_op, in_constant);
|
||||||
|
else
|
||||||
|
in_op = binary_factory->create(in_constant, in_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||||
|
|
||||||
|
return std::make_shared<ov::Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ov::Model> CreateReferenceFunction(BinaryFactoryPtr binary_factory,
|
||||||
|
size_t num_binary_ops,
|
||||||
|
ov::element::Type input_type,
|
||||||
|
size_t binary_transpose_input_idx) {
|
||||||
|
const ov::Shape input_shape{1, 96, 55, 55};
|
||||||
|
|
||||||
|
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
|
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
||||||
|
|
||||||
|
NodePtr in_op = transpose0;
|
||||||
|
for (size_t i = 0; i < num_binary_ops; ++i) {
|
||||||
|
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
|
||||||
|
|
||||||
|
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order);
|
||||||
|
|
||||||
|
if (!binary_transpose_input_idx)
|
||||||
|
in_op = binary_factory->create(in_op, transpose);
|
||||||
|
else
|
||||||
|
in_op = binary_factory->create(transpose, in_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_shared<ov::Model>(ov::OutputVector{in_op}, ov::ParameterVector{X});
|
||||||
|
}
|
||||||
|
} // namespace one_input_transpose
|
||||||
|
} // namespace backward
|
||||||
|
} // namespace single_consumer
|
||||||
|
} // namespace binary
|
||||||
|
|
||||||
|
using CreateGraphBinaryF = std::function<std::shared_ptr<ov::Model>(BinaryFactoryPtr unary_factory,
|
||||||
|
size_t num_binary_ops,
|
||||||
|
ov::element::Type input_type,
|
||||||
|
size_t binary_transpose_input_idx)>;
|
||||||
|
|
||||||
|
using TestBinaryParams = std::tuple<BinaryFactoryPtr,
|
||||||
|
PassFactoryPtr,
|
||||||
|
size_t, /* num_binary_ops */
|
||||||
|
CreateGraphBinaryF, /* model_factory */
|
||||||
|
CreateGraphBinaryF, /* reference_model_factory */
|
||||||
|
ov::element::Type, /* input type */
|
||||||
|
size_t>; /* binary_transpose_input_idx */
|
||||||
|
|
||||||
|
class TransposeSinkingBinaryTestFixture : public ::testing::WithParamInterface<TestBinaryParams>,
|
||||||
|
public TransformationTestsF {};
|
||||||
|
|
||||||
|
|
||||||
|
TEST_P(TransposeSinkingBinaryTestFixture, CompareFunctions) {
|
||||||
|
BinaryFactoryPtr unary_factory;
|
||||||
|
PassFactoryPtr pass_factory;
|
||||||
|
size_t num_binary_ops;
|
||||||
|
CreateGraphBinaryF model_factory;
|
||||||
|
CreateGraphBinaryF reference_model_factory;
|
||||||
|
ov::element::Type input_type;
|
||||||
|
size_t binary_transpose_input_idx;
|
||||||
|
std::tie(unary_factory,
|
||||||
|
pass_factory,
|
||||||
|
num_binary_ops,
|
||||||
|
model_factory,
|
||||||
|
reference_model_factory,
|
||||||
|
input_type,
|
||||||
|
binary_transpose_input_idx) = this->GetParam();
|
||||||
|
|
||||||
|
model = model_factory(unary_factory, num_binary_ops, input_type, binary_transpose_input_idx);
|
||||||
|
model_ref = reference_model_factory(unary_factory, num_binary_ops, input_type, binary_transpose_input_idx);
|
||||||
|
pass_factory->registerPass(manager);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(TransposeSinkingBinaryForwardTestSuite, TransposeSinkingBinaryTestFixture,
|
||||||
|
::testing::Combine(::testing::ValuesIn(binary_factories),
|
||||||
|
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingBinaryElementwiseForward>()),
|
||||||
|
::testing::ValuesIn(binary_operations_numbers),
|
||||||
|
::testing::Values(binary::single_consumer::forward::one_input_transpose::CreateFunction),
|
||||||
|
::testing::Values(binary::single_consumer::forward::one_input_transpose::CreateReferenceFunction),
|
||||||
|
::testing::Values(ov::element::f32),
|
||||||
|
::testing::ValuesIn(binary_transpose_input_indexes)));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
TransposeSinkingBinaryBackwardTestSuite,
|
||||||
|
TransposeSinkingBinaryTestFixture,
|
||||||
|
::testing::Combine(::testing::ValuesIn(binary_factories),
|
||||||
|
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingBinaryElementwiseBackward>()),
|
||||||
|
::testing::ValuesIn(binary_operations_numbers),
|
||||||
|
::testing::Values(binary::single_consumer::backward::one_input_transpose::CreateFunction),
|
||||||
|
::testing::Values(binary::single_consumer::backward::one_input_transpose::CreateReferenceFunction),
|
||||||
|
::testing::Values(ov::element::f32),
|
||||||
|
::testing::ValuesIn(binary_transpose_input_indexes)));
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
using CreateGraphBinaryTwoTransposeInputsF = std::function<
|
||||||
|
std::shared_ptr<ov::Model>(BinaryFactoryPtr unary_factory, size_t num_binary_ops, ov::element::Type input_type)>;
|
||||||
|
|
||||||
|
using TestBinaryTwoTransposeInputsParams = std::tuple<BinaryFactoryPtr,
|
||||||
|
PassFactoryPtr,
|
||||||
|
size_t, /* num_binary_ops */
|
||||||
|
CreateGraphBinaryTwoTransposeInputsF, /* model_factory */
|
||||||
|
CreateGraphBinaryTwoTransposeInputsF, /* reference_model_factory */
|
||||||
|
ov::element::Type>; /* input type */
|
||||||
|
|
||||||
|
class TransposeSinkingBinaryTwoTransposeInputsTestFixture
|
||||||
|
: public ::testing::WithParamInterface<TestBinaryTwoTransposeInputsParams>,
|
||||||
|
public TransformationTestsF {};
|
||||||
|
|
||||||
|
TEST_P(TransposeSinkingBinaryTwoTransposeInputsTestFixture, CompareFunctions) {
|
||||||
|
BinaryFactoryPtr unary_factory;
|
||||||
|
PassFactoryPtr pass_factory;
|
||||||
|
size_t num_binary_ops;
|
||||||
|
CreateGraphBinaryTwoTransposeInputsF model_factory;
|
||||||
|
CreateGraphBinaryTwoTransposeInputsF reference_model_factory;
|
||||||
|
ov::element::Type input_type;
|
||||||
|
|
||||||
|
std::tie(unary_factory, pass_factory, num_binary_ops, model_factory, reference_model_factory, input_type) =
|
||||||
|
this->GetParam();
|
||||||
|
|
||||||
|
model = model_factory(unary_factory, num_binary_ops, input_type);
|
||||||
|
model_ref = reference_model_factory(unary_factory, num_binary_ops, input_type);
|
||||||
|
pass_factory->registerPass(manager);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
TransposeSinkingBinaryTwoTransposeInputsForwardTestSuite,
|
||||||
|
TransposeSinkingBinaryTwoTransposeInputsTestFixture,
|
||||||
|
::testing::Combine(::testing::ValuesIn(binary_factories),
|
||||||
|
::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingBinaryElementwiseForward>()),
|
||||||
|
::testing::ValuesIn(binary_operations_numbers),
|
||||||
|
::testing::Values(binary::single_consumer::forward::double_transpose::CreateFunction),
|
||||||
|
::testing::Values(binary::single_consumer::forward::double_transpose::CreateReferenceFunction),
|
||||||
|
::testing::Values(ov::element::f32)));
|
@ -0,0 +1,400 @@
|
|||||||
|
// Copyright (C) 2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <transformations/common_optimizations/transpose_sinking_concat.hpp>
|
||||||
|
|
||||||
|
#include <transformations/init_node_info.hpp>
|
||||||
|
#include <openvino/frontend/manager.hpp>
|
||||||
|
#include <openvino/opsets/opset9.hpp>
|
||||||
|
#include <openvino/pass/manager.hpp>
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using NodePtr = std::shared_ptr<ov::Node>;
|
||||||
|
using ModelPtr = std::shared_ptr<ov::Model>;
|
||||||
|
using Output = ov::Output<ov::Node>;
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class IBinaryFactory {
|
||||||
|
public:
|
||||||
|
IBinaryFactory() = default;
|
||||||
|
virtual ~IBinaryFactory() = default;
|
||||||
|
virtual NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
using BinaryFactoryPtr = std::shared_ptr<IBinaryFactory>;
|
||||||
|
|
||||||
|
template <typename BinaryT>
|
||||||
|
class BinaryFactory : public IBinaryFactory {
|
||||||
|
public:
|
||||||
|
BinaryFactory() = default;
|
||||||
|
NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const override {
|
||||||
|
return std::make_shared<BinaryT>(parent_left_node, parent_right_node);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename BinaryT>
|
||||||
|
BinaryFactoryPtr CreateBinaryFactory() {
|
||||||
|
return std::make_shared<BinaryFactory<BinaryT>>();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class IPassFactory {
|
||||||
|
public:
|
||||||
|
IPassFactory() = default;
|
||||||
|
virtual ~IPassFactory() = default;
|
||||||
|
virtual void registerPass(ov::pass::Manager& pass_manager) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
using PassFactoryPtr = std::shared_ptr<IPassFactory>;
|
||||||
|
|
||||||
|
template <typename PassT>
|
||||||
|
class PassFactory : public IPassFactory {
|
||||||
|
public:
|
||||||
|
void registerPass(ov::pass::Manager& pass_manager) const override {
|
||||||
|
pass_manager.register_pass<PassT>();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename PassT>
|
||||||
|
PassFactoryPtr CreatePassFactory() {
|
||||||
|
return std::make_shared<PassFactory<PassT>>();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<BinaryFactoryPtr> binary_factories = {
|
||||||
|
CreateBinaryFactory<ov::opset9::Add>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Divide>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Maximum>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Minimum>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Mod>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Multiply>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Power>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::SquaredDifference>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Subtract>()
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<size_t> binary_operations_numbers = {1, 10};
|
||||||
|
|
||||||
|
std::vector<size_t> binary_transpose_input_indexes = {0, 1};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
|
||||||
|
using CreateGraphConcatF = std::function< std::shared_ptr<ov::Model> (size_t num_concat_ops,
|
||||||
|
ov::element::Type input_type,
|
||||||
|
size_t concat_transpose_input_idx,
|
||||||
|
size_t num_concat_inputs) >;
|
||||||
|
|
||||||
|
using TestConcatParams = std::tuple<PassFactoryPtr,
|
||||||
|
size_t, /* num_concat_ops */
|
||||||
|
CreateGraphConcatF, /* model_factory */
|
||||||
|
CreateGraphConcatF, /* reference_model_factory */
|
||||||
|
ov::element::Type, /* input type */
|
||||||
|
size_t, /* concat_transpose_input_idx */
|
||||||
|
size_t>; /* num_concat_inputs */
|
||||||
|
|
||||||
|
class TransposeSinkingConcatTestFixture: public ::testing::WithParamInterface<TestConcatParams>,
|
||||||
|
public TransformationTestsF {};
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
std::vector<size_t> concat_operations_numbers = {1, 10};
|
||||||
|
|
||||||
|
std::vector<size_t> concat_transpose_input_indexes = {0, 2};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace single_consumer {
|
||||||
|
namespace forward {
|
||||||
|
namespace one_input_transpose {
|
||||||
|
|
||||||
|
std::shared_ptr<ov::Model> CreateFunction(size_t num_concat_ops,
|
||||||
|
ov::element::Type input_type,
|
||||||
|
size_t concat_transpose_input_idx,
|
||||||
|
size_t num_concat_inputs) {
|
||||||
|
const ov::Shape input_shape{1, 96, 55, 55};
|
||||||
|
const ov::Shape const_shape{1, 55, 55, 96};
|
||||||
|
|
||||||
|
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
|
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
||||||
|
|
||||||
|
NodePtr in_op = transpose0;
|
||||||
|
for (size_t i = 0; i < num_concat_ops; ++i) {
|
||||||
|
ov::OutputVector concat_inputs;
|
||||||
|
for (size_t j = 0; j < num_concat_inputs; ++j) {
|
||||||
|
if (j == concat_transpose_input_idx)
|
||||||
|
concat_inputs.push_back(in_op);
|
||||||
|
else
|
||||||
|
concat_inputs.push_back(std::make_shared<ov::opset9::Constant>(input_type, const_shape, ov::Shape{1}));
|
||||||
|
}
|
||||||
|
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_shared<ov::Model>(ov::OutputVector{in_op}, ov::ParameterVector{X});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ov::Model> CreateReferenceFunction(size_t num_concat_ops,
|
||||||
|
ov::element::Type input_type,
|
||||||
|
size_t concat_transpose_input_idx,
|
||||||
|
size_t num_concat_inputs) {
|
||||||
|
const ov::Shape input_shape{1, 96, 55, 55};
|
||||||
|
const ov::Shape const_shape{1, 55, 55, 96};
|
||||||
|
|
||||||
|
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
|
NodePtr in_op = X;
|
||||||
|
for (size_t i = 0; i < num_concat_ops; ++i) {
|
||||||
|
ov::OutputVector concat_inputs;
|
||||||
|
for (size_t j = 0; j < num_concat_inputs; ++j) {
|
||||||
|
if (j == concat_transpose_input_idx) {
|
||||||
|
concat_inputs.push_back(in_op);
|
||||||
|
} else {
|
||||||
|
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, const_shape, ov::Shape{1});
|
||||||
|
|
||||||
|
auto transpose_reversed_const =
|
||||||
|
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||||
|
auto transpose_reversed =
|
||||||
|
std::make_shared<ov::opset9::Transpose>(in_constant, transpose_reversed_const);
|
||||||
|
|
||||||
|
concat_inputs.push_back(transpose_reversed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||||
|
|
||||||
|
return std::make_shared<ov::Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace one_input_transpose
|
||||||
|
|
||||||
|
namespace double_transpose {
|
||||||
|
|
||||||
|
std::shared_ptr<ov::Model> CreateFunction(size_t num_concat_ops,
|
||||||
|
ov::element::Type input_type,
|
||||||
|
size_t num_concat_inputs) {
|
||||||
|
const ov::Shape input_shape{1, 96, 55, 55};
|
||||||
|
|
||||||
|
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
|
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
||||||
|
|
||||||
|
NodePtr in_op = transpose0;
|
||||||
|
for (size_t i = 0; i < num_concat_ops; ++i) {
|
||||||
|
ov::OutputVector concat_inputs;
|
||||||
|
concat_inputs.push_back(in_op);
|
||||||
|
for (size_t j = 1; j < num_concat_inputs; ++j) {
|
||||||
|
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
|
||||||
|
auto ng_order1 =
|
||||||
|
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order1);
|
||||||
|
concat_inputs.push_back(transpose1);
|
||||||
|
}
|
||||||
|
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_shared<ov::Model>(ov::OutputVector{in_op}, ov::ParameterVector{X});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ov::Model> CreateReferenceFunction(size_t num_concat_ops,
|
||||||
|
ov::element::Type input_type,
|
||||||
|
size_t num_concat_inputs) {
|
||||||
|
const ov::Shape input_shape{1, 96, 55, 55};
|
||||||
|
|
||||||
|
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
|
NodePtr in_op = X;
|
||||||
|
for (size_t i = 0; i < num_concat_ops; ++i) {
|
||||||
|
ov::OutputVector concat_inputs;
|
||||||
|
|
||||||
|
concat_inputs.push_back(in_op);
|
||||||
|
|
||||||
|
for (size_t j = 1; j < num_concat_inputs; ++j) {
|
||||||
|
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
|
||||||
|
|
||||||
|
auto ng_order1 =
|
||||||
|
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order1);
|
||||||
|
|
||||||
|
auto transpose_reversed_const =
|
||||||
|
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||||
|
auto transpose_reversed = std::make_shared<ov::opset9::Transpose>(transpose1, transpose_reversed_const);
|
||||||
|
|
||||||
|
concat_inputs.push_back(transpose_reversed);
|
||||||
|
}
|
||||||
|
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||||
|
|
||||||
|
return std::make_shared<ov::Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace double_transpose
|
||||||
|
|
||||||
|
} // namespace forward
|
||||||
|
|
||||||
|
namespace backward {
|
||||||
|
|
||||||
|
std::shared_ptr<ov::Model> CreateFunction(size_t num_concat_ops,
|
||||||
|
ov::element::Type input_type,
|
||||||
|
size_t concat_transpose_input_idx,
|
||||||
|
size_t num_concat_inputs) {
|
||||||
|
const ov::Shape input_shape{1, 96, 55, 55};
|
||||||
|
|
||||||
|
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
|
NodePtr in_op = X;
|
||||||
|
for (size_t i = 0; i < num_concat_ops; ++i) {
|
||||||
|
ov::OutputVector concat_inputs;
|
||||||
|
for (size_t j = 0; j < num_concat_inputs; ++j) {
|
||||||
|
if (j == concat_transpose_input_idx)
|
||||||
|
concat_inputs.push_back(in_op);
|
||||||
|
else
|
||||||
|
concat_inputs.push_back(std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1}));
|
||||||
|
}
|
||||||
|
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||||
|
|
||||||
|
return std::make_shared<ov::Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ov::Model> CreateReferenceFunction(size_t num_concat_ops,
|
||||||
|
ov::element::Type input_type,
|
||||||
|
size_t concat_transpose_input_idx,
|
||||||
|
size_t num_concat_inputs) {
|
||||||
|
const ov::Shape input_shape{1, 96, 55, 55};
|
||||||
|
|
||||||
|
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
|
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
||||||
|
|
||||||
|
NodePtr in_op = transpose0;
|
||||||
|
for (size_t i = 0; i < num_concat_ops; ++i) {
|
||||||
|
ov::OutputVector concat_inputs;
|
||||||
|
for (size_t j = 0; j < num_concat_inputs; ++j) {
|
||||||
|
if (j == concat_transpose_input_idx) {
|
||||||
|
concat_inputs.push_back(in_op);
|
||||||
|
} else {
|
||||||
|
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
|
||||||
|
|
||||||
|
auto transpose_reversed_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose_reversed = std::make_shared<ov::opset9::Transpose>(in_constant, transpose_reversed_const);
|
||||||
|
|
||||||
|
concat_inputs.push_back(transpose_reversed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
in_op = std::make_shared<ov::opset9::Concat>(concat_inputs, 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_shared<ov::Model>(ov::OutputVector{in_op}, ov::ParameterVector{X});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace backward
|
||||||
|
} // namespace single_consumer
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
TEST_P(TransposeSinkingConcatTestFixture, CompareFunctions) {
|
||||||
|
PassFactoryPtr pass_factory;
|
||||||
|
size_t num_concat_ops;
|
||||||
|
CreateGraphConcatF model_factory;
|
||||||
|
CreateGraphConcatF reference_model_factory;
|
||||||
|
ov::element::Type input_type;
|
||||||
|
size_t concat_transpose_input_idx;
|
||||||
|
size_t num_concat_inputs;
|
||||||
|
std::tie(pass_factory,
|
||||||
|
num_concat_ops,
|
||||||
|
model_factory,
|
||||||
|
reference_model_factory,
|
||||||
|
input_type,
|
||||||
|
concat_transpose_input_idx,
|
||||||
|
num_concat_inputs) = this->GetParam();
|
||||||
|
|
||||||
|
model = model_factory(num_concat_ops, input_type, concat_transpose_input_idx, num_concat_inputs);
|
||||||
|
model_ref = reference_model_factory(num_concat_ops, input_type, concat_transpose_input_idx, num_concat_inputs);
|
||||||
|
pass_factory->registerPass(manager);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatForwardTestSuite, TransposeSinkingConcatTestFixture,
|
||||||
|
::testing::Combine(::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingConcatForward>()),
|
||||||
|
::testing::ValuesIn(concat_operations_numbers),
|
||||||
|
::testing::Values(single_consumer::forward::one_input_transpose::CreateFunction),
|
||||||
|
::testing::Values(single_consumer::forward::one_input_transpose::CreateReferenceFunction),
|
||||||
|
::testing::Values(ov::element::f32),
|
||||||
|
::testing::ValuesIn(concat_transpose_input_indexes),
|
||||||
|
::testing::Values(5)));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(TransposeSinkingConcatBackwardTestSuite, TransposeSinkingConcatTestFixture,
|
||||||
|
::testing::Combine(::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingConcatBackward>()),
|
||||||
|
::testing::ValuesIn(concat_operations_numbers),
|
||||||
|
::testing::Values(single_consumer::backward::CreateFunction),
|
||||||
|
::testing::Values(single_consumer::backward::CreateReferenceFunction),
|
||||||
|
::testing::Values(ov::element::f32),
|
||||||
|
::testing::ValuesIn(concat_transpose_input_indexes),
|
||||||
|
::testing::Values(5)));
|
||||||
|
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
using CreateGraphConcatAllTransposesInputF = std::function<std::shared_ptr<ov::Model>(size_t num_concat_ops,
|
||||||
|
ov::element::Type input_type,
|
||||||
|
size_t num_concat_inputs)>;
|
||||||
|
|
||||||
|
using TestConcatAllTransposesInputParams = std::tuple<PassFactoryPtr,
|
||||||
|
size_t, /* num_concat_ops */
|
||||||
|
CreateGraphConcatAllTransposesInputF, /* model_factory */
|
||||||
|
CreateGraphConcatAllTransposesInputF, /* reference_model_factory */
|
||||||
|
ov::element::Type, /* input type */
|
||||||
|
size_t>; /* num_concat_inputs */
|
||||||
|
|
||||||
|
class TransposeSinkingConcatAllTransposesInputTestFixture
|
||||||
|
: public ::testing::WithParamInterface<TestConcatAllTransposesInputParams>,
|
||||||
|
public TransformationTestsF {};
|
||||||
|
|
||||||
|
TEST_P(TransposeSinkingConcatAllTransposesInputTestFixture, CompareFunctions) {
|
||||||
|
PassFactoryPtr pass_factory;
|
||||||
|
size_t num_concat_ops;
|
||||||
|
CreateGraphConcatAllTransposesInputF model_factory;
|
||||||
|
CreateGraphConcatAllTransposesInputF reference_model_factory;
|
||||||
|
ov::element::Type input_type;
|
||||||
|
size_t num_concat_inputs;
|
||||||
|
std::tie(pass_factory,
|
||||||
|
num_concat_ops,
|
||||||
|
model_factory,
|
||||||
|
reference_model_factory,
|
||||||
|
input_type,
|
||||||
|
num_concat_inputs) = this->GetParam();
|
||||||
|
|
||||||
|
model = model_factory(num_concat_ops, input_type, num_concat_inputs);
|
||||||
|
model_ref = reference_model_factory(num_concat_ops, input_type, num_concat_inputs);
|
||||||
|
pass_factory->registerPass(manager);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
TransposeSinkingConcatForwardAllTransposesTestSuite,
|
||||||
|
TransposeSinkingConcatAllTransposesInputTestFixture,
|
||||||
|
::testing::Combine(::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingConcatForward>()),
|
||||||
|
::testing::ValuesIn(concat_operations_numbers),
|
||||||
|
::testing::Values(single_consumer::forward::double_transpose::CreateFunction),
|
||||||
|
::testing::Values(single_consumer::forward::double_transpose::CreateReferenceFunction),
|
||||||
|
::testing::Values(ov::element::f32),
|
||||||
|
::testing::Values(5)));
|
@ -0,0 +1,552 @@
|
|||||||
|
// Copyright (C) 2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <transformations/common_optimizations/transpose_sinking_split.hpp>
|
||||||
|
|
||||||
|
#include <transformations/init_node_info.hpp>
|
||||||
|
#include <openvino/frontend/manager.hpp>
|
||||||
|
#include <openvino/opsets/opset9.hpp>
|
||||||
|
#include <openvino/pass/manager.hpp>
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using NodePtr = std::shared_ptr<ov::Node>;
|
||||||
|
using ModelPtr = std::shared_ptr<ov::Model>;
|
||||||
|
using Output = ov::Output<ov::Node>;
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class IBinaryFactory {
|
||||||
|
public:
|
||||||
|
IBinaryFactory() = default;
|
||||||
|
virtual ~IBinaryFactory() = default;
|
||||||
|
virtual NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
using BinaryFactoryPtr = std::shared_ptr<IBinaryFactory>;
|
||||||
|
|
||||||
|
template <typename BinaryT>
|
||||||
|
class BinaryFactory : public IBinaryFactory {
|
||||||
|
public:
|
||||||
|
BinaryFactory() = default;
|
||||||
|
NodePtr create(NodePtr parent_left_node, NodePtr parent_right_node) const override {
|
||||||
|
return std::make_shared<BinaryT>(parent_left_node, parent_right_node);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename BinaryT>
|
||||||
|
BinaryFactoryPtr CreateBinaryFactory() {
|
||||||
|
return std::make_shared<BinaryFactory<BinaryT>>();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class IPassFactory {
|
||||||
|
public:
|
||||||
|
IPassFactory() = default;
|
||||||
|
virtual ~IPassFactory() = default;
|
||||||
|
virtual void registerPass(ov::pass::Manager& pass_manager) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
using PassFactoryPtr = std::shared_ptr<IPassFactory>;
|
||||||
|
|
||||||
|
template <typename PassT>
|
||||||
|
class PassFactory : public IPassFactory {
|
||||||
|
public:
|
||||||
|
void registerPass(ov::pass::Manager& pass_manager) const override {
|
||||||
|
pass_manager.register_pass<PassT>();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename PassT>
|
||||||
|
PassFactoryPtr CreatePassFactory() {
|
||||||
|
return std::make_shared<PassFactory<PassT>>();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<BinaryFactoryPtr> binary_factories = {
|
||||||
|
CreateBinaryFactory<ov::opset9::Add>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Divide>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Maximum>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Minimum>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Mod>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Multiply>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Power>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::SquaredDifference>(),
|
||||||
|
CreateBinaryFactory<ov::opset9::Subtract>()
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<size_t> binary_operations_numbers = {1, 10};
|
||||||
|
|
||||||
|
std::vector<size_t> binary_transpose_input_indexes = {0, 1};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
using CreateGraphSplitForwardF = std::function< std::shared_ptr<ov::Model> (size_t num_split_ops,
|
||||||
|
size_t num_split_outputs,
|
||||||
|
ov::element::Type input_type)>;
|
||||||
|
|
||||||
|
using TestSplitForwardParams = std::tuple<PassFactoryPtr,
|
||||||
|
size_t, /* num_split_ops */
|
||||||
|
size_t, /* num_split_outputs */
|
||||||
|
CreateGraphSplitForwardF, /* model_factory */
|
||||||
|
CreateGraphSplitForwardF, /* reference_model_factory */
|
||||||
|
ov::element::Type> /* input type */;
|
||||||
|
|
||||||
|
class TransposeSinkingSplitForwardTestFixture: public ::testing::WithParamInterface<TestSplitForwardParams>,
|
||||||
|
public TransformationTestsF {};
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
std::vector<size_t> split_operations_numbers = {1, 10};
|
||||||
|
|
||||||
|
std::vector<size_t> split_outputs_numbers = {2, 3};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace split {
|
||||||
|
namespace forward {
|
||||||
|
std::shared_ptr<ov::Model> CreateFunction(size_t num_split_ops,
|
||||||
|
size_t num_split_outputs,
|
||||||
|
ov::element::Type input_type) {
|
||||||
|
const ov::Shape input_shape{96, static_cast<size_t>(std::pow(num_split_outputs, num_split_ops + 1)), 55, 55};
|
||||||
|
|
||||||
|
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
|
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||||
|
auto transpose0 = std::make_shared<ov::opset9::Transpose>(X, ng_order0);
|
||||||
|
|
||||||
|
ov::OutputVector outputs;
|
||||||
|
Output in_op = transpose0->output(0);
|
||||||
|
for (size_t i = 0; i < num_split_ops; ++i) {
|
||||||
|
auto split_axis_const = std::make_shared<ov::opset9::Constant>(ov::element::u64,
|
||||||
|
ov::Shape{},
|
||||||
|
2);
|
||||||
|
auto split = std::make_shared<ov::opset9::Split>(in_op, split_axis_const, num_split_outputs);
|
||||||
|
for (size_t num_output = 0; num_output < num_split_outputs - 1; ++num_output) {
|
||||||
|
outputs.push_back(split->output(num_output));
|
||||||
|
}
|
||||||
|
in_op = split->output(num_split_outputs - 1);
|
||||||
|
}
|
||||||
|
outputs.push_back(in_op);
|
||||||
|
|
||||||
|
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ov::Model> CreateReferenceFunction(size_t num_split_ops,
|
||||||
|
size_t num_split_outputs,
|
||||||
|
ov::element::Type input_type) {
|
||||||
|
const ov::Shape input_shape{96, static_cast<size_t>(std::pow(num_split_outputs, num_split_ops + 1)), 55, 55};
|
||||||
|
|
||||||
|
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
|
ov::OutputVector outputs;
|
||||||
|
Output in_op = X->output(0);
|
||||||
|
for (size_t i = 0; i < num_split_ops; ++i) {
|
||||||
|
auto split_axis_const = std::make_shared<ov::opset9::Constant>(ov::element::u64,
|
||||||
|
ov::Shape{},
|
||||||
|
1);
|
||||||
|
auto split = std::make_shared<ov::opset9::Split>(in_op, split_axis_const, num_split_outputs);
|
||||||
|
for (size_t num_output = 0; num_output < num_split_outputs - 1; ++num_output) {
|
||||||
|
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||||
|
auto transpose0 = std::make_shared<ov::opset9::Transpose>(split->output(num_output), ng_order0);
|
||||||
|
outputs.push_back(transpose0);
|
||||||
|
}
|
||||||
|
in_op = split->output(num_split_outputs - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||||
|
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||||
|
outputs.push_back(transpose0);
|
||||||
|
|
||||||
|
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace forward
|
||||||
|
} // namespace split
|
||||||
|
|
||||||
|
|
||||||
|
TEST_P(TransposeSinkingSplitForwardTestFixture, CompareFunctions) {
|
||||||
|
PassFactoryPtr pass_factory;
|
||||||
|
size_t num_split_ops;
|
||||||
|
size_t num_split_outputs;
|
||||||
|
CreateGraphSplitForwardF model_factory;
|
||||||
|
CreateGraphSplitForwardF reference_model_factory;
|
||||||
|
ov::element::Type input_type;
|
||||||
|
std::tie(pass_factory,
|
||||||
|
num_split_ops,
|
||||||
|
num_split_outputs,
|
||||||
|
model_factory,
|
||||||
|
reference_model_factory,
|
||||||
|
input_type) = this->GetParam();
|
||||||
|
|
||||||
|
model = model_factory(num_split_ops, num_split_outputs, input_type);
|
||||||
|
model_ref = reference_model_factory(num_split_ops, num_split_outputs, input_type);
|
||||||
|
pass_factory->registerPass(manager);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
TransposeSinkingSplitForwardTestSuite,
|
||||||
|
TransposeSinkingSplitForwardTestFixture,
|
||||||
|
::testing::Combine(::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingSplitForward>()),
|
||||||
|
::testing::ValuesIn(split_operations_numbers),
|
||||||
|
::testing::ValuesIn(split_outputs_numbers),
|
||||||
|
::testing::Values(split::forward::CreateFunction),
|
||||||
|
::testing::Values(split::forward::CreateReferenceFunction),
|
||||||
|
::testing::Values(ov::element::f32)));
|
||||||
|
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
using CreateGraphSplitBackwardF = std::function< std::shared_ptr<ov::Model> (size_t split_tree_depth,
|
||||||
|
size_t num_split_outputs,
|
||||||
|
ov::element::Type input_type)>;
|
||||||
|
|
||||||
|
using TestSplitBackwardParams = std::tuple<PassFactoryPtr,
|
||||||
|
size_t, /* split_tree_depth */
|
||||||
|
size_t, /* num_split_outputs */
|
||||||
|
CreateGraphSplitBackwardF, /* model_factory */
|
||||||
|
CreateGraphSplitBackwardF, /* reference_model_factory */
|
||||||
|
ov::element::Type> /* input type */;
|
||||||
|
|
||||||
|
class TransposeSinkingSplitBackwardTestFixture: public ::testing::WithParamInterface<TestSplitBackwardParams>,
|
||||||
|
public TransformationTestsF {};
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
std::vector<size_t> split_tree_depth_nums = {1, 3};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
namespace split {
|
||||||
|
namespace backward {
|
||||||
|
|
||||||
|
class SplitFactory {
|
||||||
|
public:
|
||||||
|
SplitFactory(size_t axis, size_t n_outputs, ov::element::Type elem_type) :
|
||||||
|
_axis(axis), _n_outputs(n_outputs), _elem_type(elem_type) {}
|
||||||
|
NodePtr create(Output parent) const {
|
||||||
|
auto split_axis_const = std::make_shared<ov::opset9::Constant>(_elem_type,
|
||||||
|
ov::Shape{},
|
||||||
|
_axis);
|
||||||
|
return std::make_shared<ov::opset9::Split>(parent, split_axis_const, _n_outputs);
|
||||||
|
}
|
||||||
|
private:
|
||||||
|
const size_t _axis;
|
||||||
|
const size_t _n_outputs;
|
||||||
|
const ov::element::Type _elem_type;
|
||||||
|
};
|
||||||
|
|
||||||
|
void CreateSplitTree(size_t max_depth, size_t depth, Output parent, const SplitFactory & split_factory, ov::OutputVector & leaves) {
|
||||||
|
if (depth == max_depth) {
|
||||||
|
leaves.push_back(parent);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto split = split_factory.create(parent);
|
||||||
|
|
||||||
|
for (size_t output_idx = 0; output_idx < split->get_output_size(); ++output_idx) {
|
||||||
|
CreateSplitTree(max_depth, depth + 1, split->output(output_idx), split_factory, leaves);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ov::Model> CreateFunction(size_t split_tree_depth,
|
||||||
|
size_t num_split_outputs,
|
||||||
|
ov::element::Type input_type) {
|
||||||
|
const size_t split_input_dim_value = static_cast<size_t>(std::pow(num_split_outputs, split_tree_depth + 1));
|
||||||
|
const ov::Shape input_shape{96, split_input_dim_value, 55, 55};
|
||||||
|
|
||||||
|
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
|
ov::OutputVector split_tree_leaves;
|
||||||
|
{
|
||||||
|
SplitFactory split_factory(/* axis */ 1, num_split_outputs, /* elem_type */ ov::element::u64);
|
||||||
|
CreateSplitTree(split_tree_depth, /* depth */ 0, X->output(0), split_factory, split_tree_leaves);
|
||||||
|
}
|
||||||
|
|
||||||
|
ov::OutputVector outputs;
|
||||||
|
for (auto& split_tree_leaf : split_tree_leaves) {
|
||||||
|
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||||
|
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||||
|
|
||||||
|
const size_t split_dim_current_value = static_cast<size_t>(split_input_dim_value / std::pow(num_split_outputs, split_tree_depth));
|
||||||
|
auto reshape_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{3}, ov::Shape{96, 55, split_dim_current_value * 55});
|
||||||
|
auto reshape = std::make_shared<ov::opset9::Reshape>(transpose, reshape_const, false);
|
||||||
|
outputs.push_back(reshape);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ov::Model> CreateReferenceFunction(size_t split_tree_depth,
|
||||||
|
size_t num_split_outputs,
|
||||||
|
ov::element::Type input_type) {
|
||||||
|
const size_t split_input_dim_value = static_cast<size_t>(std::pow(num_split_outputs, split_tree_depth + 1));
|
||||||
|
const ov::Shape input_shape{96, split_input_dim_value, 55, 55};
|
||||||
|
|
||||||
|
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
|
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||||
|
auto transpose = std::make_shared<ov::opset9::Transpose>(X, ng_order);
|
||||||
|
|
||||||
|
ov::OutputVector split_tree_leaves;
|
||||||
|
{
|
||||||
|
SplitFactory split_factory(/* axis */ 2, num_split_outputs, /* elem_type */ ov::element::u64);
|
||||||
|
CreateSplitTree(split_tree_depth, /* depth */ 0, transpose->output(0), split_factory, split_tree_leaves);
|
||||||
|
}
|
||||||
|
|
||||||
|
ov::OutputVector outputs;
|
||||||
|
for (auto& split_tree_leaf : split_tree_leaves) {
|
||||||
|
const size_t split_dim_current_value = static_cast<size_t>(split_input_dim_value / std::pow(num_split_outputs, split_tree_depth));
|
||||||
|
auto reshape_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{3}, ov::Shape{96, 55, split_dim_current_value * 55});
|
||||||
|
auto reshape = std::make_shared<ov::opset9::Reshape>(split_tree_leaf, reshape_const, false);
|
||||||
|
outputs.push_back(reshape);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace backward
|
||||||
|
} // namespace split
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
TEST_P(TransposeSinkingSplitBackwardTestFixture, CompareFunctions) {
|
||||||
|
PassFactoryPtr pass_factory;
|
||||||
|
size_t split_tree_depth;
|
||||||
|
size_t num_split_outputs;
|
||||||
|
CreateGraphSplitBackwardF model_factory;
|
||||||
|
CreateGraphSplitBackwardF reference_model_factory;
|
||||||
|
ov::element::Type input_type;
|
||||||
|
std::tie(pass_factory,
|
||||||
|
split_tree_depth,
|
||||||
|
num_split_outputs,
|
||||||
|
model_factory,
|
||||||
|
reference_model_factory,
|
||||||
|
input_type) = this->GetParam();
|
||||||
|
|
||||||
|
model = model_factory(split_tree_depth, num_split_outputs, input_type);
|
||||||
|
model_ref = reference_model_factory(split_tree_depth, num_split_outputs, input_type);
|
||||||
|
pass_factory->registerPass(manager);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
TransposeSinkingSplitBackwardTestSuite,
|
||||||
|
TransposeSinkingSplitBackwardTestFixture,
|
||||||
|
::testing::Combine(::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingSplitBackward>()),
|
||||||
|
::testing::ValuesIn(split_tree_depth_nums),
|
||||||
|
::testing::ValuesIn(split_outputs_numbers),
|
||||||
|
::testing::Values(split::backward::CreateFunction),
|
||||||
|
::testing::Values(split::backward::CreateReferenceFunction),
|
||||||
|
::testing::Values(ov::element::f32)));
|
||||||
|
|
||||||
|
|
||||||
|
using TransposeInsertF = std::function< ov::OutputVector (const ov::OutputVector& split_tree_leaves)>;
|
||||||
|
|
||||||
|
using CreateGraphSplitBackwardRestrictF = std::function< std::shared_ptr<ov::Model> (size_t split_tree_depth,
|
||||||
|
size_t num_split_outputs,
|
||||||
|
ov::element::Type input_type,
|
||||||
|
TransposeInsertF tranpose_insert_function)>;
|
||||||
|
|
||||||
|
using TestSplitBackwardRestrictParams = std::tuple<PassFactoryPtr,
|
||||||
|
size_t, /* split_tree_depth */
|
||||||
|
size_t, /* num_split_outputs */
|
||||||
|
CreateGraphSplitBackwardRestrictF, /* model_factory */
|
||||||
|
ov::element::Type, /* input type */
|
||||||
|
TransposeInsertF>; /* insert transpose function */
|
||||||
|
|
||||||
|
class TransposeSinkingSplitBackwardRestrictTestFixture: public ::testing::WithParamInterface<TestSplitBackwardRestrictParams>,
|
||||||
|
public TransformationTestsF {};
|
||||||
|
|
||||||
|
TEST_P(TransposeSinkingSplitBackwardRestrictTestFixture, CompareFunctions) {
|
||||||
|
PassFactoryPtr pass_factory;
|
||||||
|
size_t split_tree_depth;
|
||||||
|
size_t num_split_outputs;
|
||||||
|
CreateGraphSplitBackwardRestrictF model_factory;
|
||||||
|
ov::element::Type input_type;
|
||||||
|
TransposeInsertF tranpose_insert_function;
|
||||||
|
std::tie(pass_factory,
|
||||||
|
split_tree_depth,
|
||||||
|
num_split_outputs,
|
||||||
|
model_factory,
|
||||||
|
input_type,
|
||||||
|
tranpose_insert_function) = this->GetParam();
|
||||||
|
|
||||||
|
model = model_factory(split_tree_depth, num_split_outputs, input_type, tranpose_insert_function);
|
||||||
|
model_ref = model->clone();
|
||||||
|
pass_factory->registerPass(manager);
|
||||||
|
}
|
||||||
|
namespace split {
|
||||||
|
namespace backward {
|
||||||
|
namespace restrictions {
|
||||||
|
|
||||||
|
std::shared_ptr<ov::Model> CreateFunction(size_t split_tree_depth,
|
||||||
|
size_t num_split_outputs,
|
||||||
|
ov::element::Type input_type,
|
||||||
|
TransposeInsertF transpose_insert_func) {
|
||||||
|
const size_t split_input_dim_value = static_cast<size_t>(std::pow(num_split_outputs, split_tree_depth + 1));
|
||||||
|
const ov::Shape input_shape{96, split_input_dim_value, 55, 55};
|
||||||
|
|
||||||
|
auto X = std::make_shared<ov::opset9::Parameter>(input_type, input_shape);
|
||||||
|
|
||||||
|
ov::OutputVector split_tree_leaves;
|
||||||
|
{
|
||||||
|
SplitFactory split_factory(/* axis */ 1, num_split_outputs, /* elem_type */ ov::element::u64);
|
||||||
|
CreateSplitTree(split_tree_depth, /* depth */ 0, X->output(0), split_factory, split_tree_leaves);
|
||||||
|
}
|
||||||
|
|
||||||
|
ov::OutputVector outputs;
|
||||||
|
for (auto& split_tree_leaf : transpose_insert_func(split_tree_leaves)) {
|
||||||
|
const size_t split_dim_current_value = static_cast<size_t>(split_input_dim_value / std::pow(num_split_outputs, split_tree_depth));
|
||||||
|
auto reshape_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{3}, ov::Shape{96, 55, split_dim_current_value * 55});
|
||||||
|
auto reshape = std::make_shared<ov::opset9::Reshape>(split_tree_leaf, reshape_const, false);
|
||||||
|
outputs.push_back(reshape);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_shared<ov::Model>(outputs, ov::ParameterVector{X});
|
||||||
|
}
|
||||||
|
|
||||||
|
ov::OutputVector OnlyFirstTranspose(const ov::OutputVector& split_tree_leaves) {
|
||||||
|
ov::OutputVector outputs;
|
||||||
|
{
|
||||||
|
auto& split_tree_leaf = split_tree_leaves.front();
|
||||||
|
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||||
|
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||||
|
outputs.push_back(transpose);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t leaf_idx = 1; leaf_idx < split_tree_leaves.size(); ++leaf_idx) {
|
||||||
|
outputs.push_back(split_tree_leaves[leaf_idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
ov::OutputVector OnlyLastTranspose(const ov::OutputVector& split_tree_leaves) {
|
||||||
|
ov::OutputVector outputs;
|
||||||
|
{
|
||||||
|
auto& split_tree_leaf = split_tree_leaves.back();
|
||||||
|
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||||
|
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||||
|
outputs.push_back(transpose);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t leaf_idx = 0; leaf_idx < split_tree_leaves.size() - 1; ++leaf_idx) {
|
||||||
|
outputs.push_back(split_tree_leaves[leaf_idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
ov::OutputVector OnlyMiddleTranspose(const ov::OutputVector& split_tree_leaves) {
|
||||||
|
ov::OutputVector outputs;
|
||||||
|
size_t middle_idx = split_tree_leaves.size() / 2;
|
||||||
|
if (split_tree_leaves.size() % 2)
|
||||||
|
++middle_idx;
|
||||||
|
for (size_t leaf_idx = 0; leaf_idx < split_tree_leaves.size() - 1; ++leaf_idx) {
|
||||||
|
if (leaf_idx == middle_idx) {
|
||||||
|
auto& split_tree_leaf = split_tree_leaves[leaf_idx];
|
||||||
|
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||||
|
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||||
|
outputs.push_back(transpose);
|
||||||
|
} else {
|
||||||
|
outputs.push_back(split_tree_leaves[leaf_idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
ov::OutputVector FirstAnotherTranspose(const ov::OutputVector& split_tree_leaves) {
|
||||||
|
ov::OutputVector outputs;
|
||||||
|
{
|
||||||
|
auto& split_tree_leaf = split_tree_leaves.front();
|
||||||
|
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||||
|
outputs.push_back(transpose);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t leaf_idx = 1; leaf_idx < split_tree_leaves.size(); ++leaf_idx) {
|
||||||
|
auto& split_tree_leaf = split_tree_leaves[leaf_idx];
|
||||||
|
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||||
|
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||||
|
outputs.push_back(transpose);
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
ov::OutputVector LastAnotherTranspose(const ov::OutputVector& split_tree_leaves) {
|
||||||
|
ov::OutputVector outputs;
|
||||||
|
{
|
||||||
|
auto& split_tree_leaf = split_tree_leaves.back();
|
||||||
|
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||||
|
outputs.push_back(transpose);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t leaf_idx = 0; leaf_idx < split_tree_leaves.size() - 1; ++leaf_idx) {
|
||||||
|
auto& split_tree_leaf = split_tree_leaves[leaf_idx];
|
||||||
|
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||||
|
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||||
|
outputs.push_back(transpose);
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
ov::OutputVector MiddleAnotherTranspose(const ov::OutputVector& split_tree_leaves) {
|
||||||
|
ov::OutputVector outputs;
|
||||||
|
size_t middle_idx = split_tree_leaves.size() / 2;
|
||||||
|
if (split_tree_leaves.size() % 2)
|
||||||
|
++middle_idx;
|
||||||
|
for (size_t leaf_idx = 0; leaf_idx < split_tree_leaves.size(); ++leaf_idx) {
|
||||||
|
auto& split_tree_leaf = split_tree_leaves[leaf_idx];
|
||||||
|
if (leaf_idx == middle_idx) {
|
||||||
|
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||||
|
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||||
|
outputs.push_back(transpose);
|
||||||
|
} else {
|
||||||
|
auto ng_order = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||||
|
auto transpose = std::make_shared<ov::opset9::Transpose>(split_tree_leaf, ng_order);
|
||||||
|
outputs.push_back(transpose);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace restrictions
|
||||||
|
} // namespace backward
|
||||||
|
} // namespace split
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
std::vector<TransposeInsertF> insertTransposeFactories = {
|
||||||
|
split::backward::restrictions::OnlyFirstTranspose,
|
||||||
|
split::backward::restrictions::OnlyLastTranspose,
|
||||||
|
split::backward::restrictions::OnlyMiddleTranspose,
|
||||||
|
split::backward::restrictions::FirstAnotherTranspose,
|
||||||
|
split::backward::restrictions::LastAnotherTranspose,
|
||||||
|
split::backward::restrictions::MiddleAnotherTranspose
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
TransposeSinkingSplitBackwardRestrictTestSuite,
|
||||||
|
TransposeSinkingSplitBackwardRestrictTestFixture,
|
||||||
|
::testing::Combine(::testing::Values(CreatePassFactory<ov::pass::TransposeSinkingSplitBackward>()),
|
||||||
|
::testing::Values(1),
|
||||||
|
::testing::Values(5),
|
||||||
|
::testing::Values(split::backward::restrictions::CreateFunction),
|
||||||
|
::testing::Values(ov::element::f32),
|
||||||
|
::testing::ValuesIn(insertTransposeFactories)));
|
Loading…
Reference in New Issue
Block a user