TransposeSinking: support Pad operation (#15137)
* Add Transpose sinking for Pad op, tests, refactoring * Update GeneralTransposeSinking transformation * resolve review comments * resolve review comment
This commit is contained in:
parent
0ba9f14e60
commit
0ade00488e
@ -0,0 +1,30 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API TransposeSinkingPadForward;
|
||||
class TRANSFORMATIONS_API TransposeSinkingPadBackward;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
class ov::pass::TransposeSinkingPadForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingPadForward", "0");
|
||||
TransposeSinkingPadForward();
|
||||
};
|
||||
|
||||
class ov::pass::TransposeSinkingPadBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TransposeSinkingPadBackward", "0");
|
||||
TransposeSinkingPadBackward();
|
||||
};
|
@ -8,7 +8,6 @@
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <utility>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "openvino/pass/pattern/op/label.hpp"
|
||||
@ -32,7 +31,7 @@ struct TransposeInputsInfo {
|
||||
* @brief Finds node first input that is a transpose operation and returns filled TransposeInputsInfo
|
||||
* for it
|
||||
*/
|
||||
TransposeInputsInfo GetFirstTransposeInput(std::shared_ptr<ov::Node>);
|
||||
TransposeInputsInfo GetFirstTransposeInput(const std::shared_ptr<ov::Node>&);
|
||||
|
||||
/**
|
||||
* @brief Checks if @arg has any input node that is a transpose operation
|
||||
@ -54,41 +53,44 @@ void SwapOutputNames(ov::Output<ov::Node>, ov::Output<ov::Node>);
|
||||
/**
|
||||
* @brief Swaps @args friendly names
|
||||
*/
|
||||
void SwapFriendlyNames(std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>);
|
||||
void SwapFriendlyNames(const std::shared_ptr<ov::Node>&, const std::shared_ptr<ov::Node>&);
|
||||
|
||||
/**
|
||||
* @brief Swaps @args output tensor names and friendly names
|
||||
*/
|
||||
void SwapNames(std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>);
|
||||
void SwapNames(const std::shared_ptr<ov::Node>&, const std::shared_ptr<ov::Node>&);
|
||||
|
||||
namespace sink_forward {
|
||||
/**
|
||||
* @brief Inserts reversed transposed on @args main_node inputs. Removes input transpose specified in @arg
|
||||
* transpose_input_info
|
||||
*/
|
||||
void UpdateInputTransposes(std::shared_ptr<ov::Node> main_node, const TransposeInputsInfo& transpose_input_info);
|
||||
void UpdateInputTransposes(const std::shared_ptr<ov::Node>& main_node, const TransposeInputsInfo& transpose_input_info);
|
||||
|
||||
/**
|
||||
* @brief Removes @arg input node
|
||||
*/
|
||||
void RemoveInputNode(std::shared_ptr<ov::Node>, size_t input_idx);
|
||||
void RemoveInputNode(const std::shared_ptr<ov::Node>&, size_t input_idx);
|
||||
|
||||
/**
|
||||
* @brief Inserts transposes on each main_node output with the order specified in @arg transpose_input_info
|
||||
*/
|
||||
ov::NodeVector InsertOutputTransposes(std::shared_ptr<ov::Node> main_node,
|
||||
ov::NodeVector InsertOutputTransposes(const std::shared_ptr<ov::Node>& main_node,
|
||||
const TransposeInputsInfo& transpose_input_info);
|
||||
} // namespace sink_forward
|
||||
|
||||
namespace sink_backward {
|
||||
/**
|
||||
* @brief Inserts transposes on each input of @arg main_node with the order specified in @arg transpose_const
|
||||
* @brief Inserts transposes on inputs of @arg main_node specified by @arg input_indexes
|
||||
* with the order specified in @arg transpose_const. If @arg input_indexes is empty, then it inserts
|
||||
* transposes for all inputs.
|
||||
*/
|
||||
ov::NodeVector InsertTransposeBeforeNode(std::shared_ptr<ov::Node> main_node,
|
||||
std::shared_ptr<ov::opset9::Constant> transpose_const);
|
||||
ov::NodeVector InsertTransposeBeforeNode(const std::shared_ptr<ov::Node>& main_node,
|
||||
const std::shared_ptr<ov::opset9::Constant>& transpose_const,
|
||||
std::vector<int> input_indexes = {});
|
||||
} // namespace sink_backward
|
||||
|
||||
void UpdateForwardSinkingAbility(std::shared_ptr<ov::Node>);
|
||||
void UpdateForwardSinkingAbility(const std::shared_ptr<ov::Node>&);
|
||||
|
||||
/**
|
||||
* @brief Checks if @arg has consumers that all are the same transpose operation. If no consumers at all
|
||||
@ -99,6 +101,12 @@ bool HasSameOutputTransposeNodes(const ov::Output<ov::Node>&);
|
||||
/**
|
||||
* Removes all direct node consumers that have one output
|
||||
*/
|
||||
void RemoveSingleOutputConsumers(std::shared_ptr<ov::Node>);
|
||||
void RemoveSingleOutputConsumers(const std::shared_ptr<ov::Node>&);
|
||||
|
||||
/**
|
||||
* Changes the order of values in @arg input according to @arg transpose_axis_order along @arg axis
|
||||
*/
|
||||
ov::Output<ov::Node> ChangeValuesOrder(const ov::Output<ov::Node>& input,
|
||||
const ov::AxisVector& transpose_axis_order,
|
||||
const std::shared_ptr<ov::opset9::Constant>& axis);
|
||||
} // namespace transpose_sinking
|
||||
|
@ -14,6 +14,7 @@
|
||||
#include "transformations/common_optimizations/transpose_sinking.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_pad.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
@ -24,6 +25,7 @@ ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {
|
||||
add_matcher<ov::pass::TransposeSinkingBinaryForward>();
|
||||
add_matcher<ov::pass::TransposeSinkingConcatForward>();
|
||||
add_matcher<ov::pass::TransposeSinkingSplitForward>();
|
||||
add_matcher<ov::pass::TransposeSinkingPadForward>();
|
||||
add_matcher<ngraph::pass::TransposeFuse>();
|
||||
}
|
||||
|
||||
@ -33,6 +35,7 @@ ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() {
|
||||
add_matcher<ov::pass::TransposeSinkingBinaryBackward>();
|
||||
add_matcher<ov::pass::TransposeSinkingConcatBackward>();
|
||||
add_matcher<ov::pass::TransposeSinkingSplitBackward>();
|
||||
add_matcher<ov::pass::TransposeSinkingPadBackward>();
|
||||
add_matcher<ngraph::pass::TransposeFuse>();
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,105 @@
|
||||
#include "transformations/common_optimizations/transpose_sinking_pad.hpp"
|
||||
|
||||
#include <openvino/pass/pattern/op/or.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace transpose_sinking;
|
||||
|
||||
ov::pass::TransposeSinkingPadForward::TransposeSinkingPadForward() {
|
||||
MATCHER_SCOPE(TransposeSinkingPadForward);
|
||||
auto const_label = wrap_type<Constant>();
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), const_label});
|
||||
auto main_node_label = wrap_type<Pad>({transpose_label, any_input(), any_input(), any_input()});
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_node = m.get_pattern_map();
|
||||
|
||||
auto& main_node = pattern_to_node.at(main_node_label);
|
||||
auto transpose = std::dynamic_pointer_cast<Transpose>(pattern_to_node.at(transpose_label));
|
||||
if (!transpose) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto transpose_const = as_type_ptr<Constant>(pattern_to_node.at(const_label));
|
||||
if (!transpose_const) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// remove Transpose on 1st input:
|
||||
auto transpose_parent = main_node->input_value(0).get_node()->input_value(0);
|
||||
main_node->input(0).replace_source_output(transpose_parent);
|
||||
|
||||
// change the order of values for PadBegin and PadEng inputs
|
||||
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
|
||||
|
||||
main_node->input(1).replace_source_output(
|
||||
ChangeValuesOrder(main_node->input_value(1), transpose_axis_order, axis));
|
||||
main_node->input(2).replace_source_output(
|
||||
ChangeValuesOrder(main_node->input_value(2), transpose_axis_order, axis));
|
||||
|
||||
// insert Transpose for Pad output
|
||||
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
ov::pass::TransposeSinkingPadBackward::TransposeSinkingPadBackward() {
|
||||
MATCHER_SCOPE(TransposeSinkingPadBackward);
|
||||
|
||||
auto main_node_label = wrap_type<Pad>([](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
|
||||
});
|
||||
|
||||
auto transpose_const_label = wrap_type<Constant>();
|
||||
|
||||
auto transpose_label =
|
||||
wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && is_sinking_node(output);
|
||||
});
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
||||
|
||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
|
||||
transpose_const,
|
||||
/* input_indexes= */ {0})) {
|
||||
register_new_node(new_node);
|
||||
}
|
||||
|
||||
// remove output transposes
|
||||
RemoveSingleOutputConsumers(main_node);
|
||||
|
||||
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
|
||||
|
||||
main_node->input(1).replace_source_output(
|
||||
ChangeValuesOrder(main_node->input_value(1), transpose_axis_order, axis));
|
||||
main_node->input(2).replace_source_output(
|
||||
ChangeValuesOrder(main_node->input_value(2), transpose_axis_order, axis));
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
@ -20,7 +20,22 @@ using namespace ov::opset9;
|
||||
|
||||
using NodePtr = std::shared_ptr<Node>;
|
||||
|
||||
TransposeInputsInfo GetFirstTransposeInput(NodePtr node) {
|
||||
Output<Node> ChangeValuesOrder(const Output<Node>& input,
|
||||
const AxisVector& transpose_axis_order,
|
||||
const std::shared_ptr<Constant>& axis) {
|
||||
auto rank = transpose_axis_order.size();
|
||||
auto split_pad = std::make_shared<Split>(input, axis, rank);
|
||||
auto split_outputs = split_pad->outputs();
|
||||
OutputVector new_order(split_outputs.size());
|
||||
for (size_t i = 0; i < rank; ++i) {
|
||||
new_order[i] = split_outputs[transpose_axis_order[i]];
|
||||
}
|
||||
auto concat_pad = std::make_shared<Concat>(new_order, 0);
|
||||
copy_runtime_info(input.get_node_shared_ptr(), {split_pad, concat_pad});
|
||||
return concat_pad;
|
||||
}
|
||||
|
||||
TransposeInputsInfo GetFirstTransposeInput(const NodePtr& node) {
|
||||
for (size_t input_idx = 0; input_idx < node->get_input_size(); ++input_idx) {
|
||||
NodePtr input_node = node->get_input_node_shared_ptr(input_idx);
|
||||
auto transpose_node = as_type_ptr<Transpose>(input_node);
|
||||
@ -38,7 +53,7 @@ TransposeInputsInfo GetFirstTransposeInput(NodePtr node) {
|
||||
}
|
||||
}
|
||||
|
||||
return TransposeInputsInfo();
|
||||
return {};
|
||||
}
|
||||
|
||||
bool IfNodeHasTransposeInputs(const Output<Node>& output) {
|
||||
@ -60,20 +75,20 @@ void SwapOutputNames(Output<Node> output1, Output<Node> output2) {
|
||||
output1.set_names(node2_output_names);
|
||||
}
|
||||
|
||||
void SwapFriendlyNames(NodePtr node1, NodePtr node2) {
|
||||
void SwapFriendlyNames(const NodePtr& node1, const NodePtr& node2) {
|
||||
const std::string node2_name = node2->get_friendly_name();
|
||||
node2->set_friendly_name(node1->get_friendly_name());
|
||||
node1->set_friendly_name(node2_name);
|
||||
}
|
||||
|
||||
void SwapNames(NodePtr node1, NodePtr node2) {
|
||||
void SwapNames(const NodePtr& node1, const NodePtr& node2) {
|
||||
SwapFriendlyNames(node1, node2);
|
||||
SwapOutputNames(node1->output(0), node2->output(0));
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
bool HasDynamicRankInput(NodePtr node) {
|
||||
bool HasDynamicRankInput(const NodePtr& node) {
|
||||
for (auto& input_node : node->input_values()) {
|
||||
const ov::Rank output_rank = input_node.get_partial_shape().rank();
|
||||
if (output_rank.is_dynamic())
|
||||
@ -95,7 +110,7 @@ ov::Rank::value_type GetMaxInputRank(const NodePtr& node) {
|
||||
return max_input_rank;
|
||||
}
|
||||
|
||||
NodePtr InsertUnsqueeze(Output<Node> node, size_t n_dims) {
|
||||
NodePtr InsertUnsqueeze(const Output<Node>& node, size_t n_dims) {
|
||||
std::vector<size_t> dims(n_dims);
|
||||
std::iota(dims.begin(), dims.end(), 0);
|
||||
auto unsqueeze_const = std::make_shared<Constant>(ov::element::i64, Shape{dims.size()}, dims);
|
||||
@ -115,7 +130,7 @@ ov::Output<ov::Node> FixInputNodeRank(ov::Output<ov::Node> input_node, ov::Rank:
|
||||
|
||||
namespace sink_forward {
|
||||
|
||||
void UpdateInputTransposes(NodePtr main_node, const TransposeInputsInfo& transpose_input_info) {
|
||||
void UpdateInputTransposes(const NodePtr& main_node, const TransposeInputsInfo& transpose_input_info) {
|
||||
if (transpose_input_info.isEmpty() || HasDynamicRankInput(main_node))
|
||||
return;
|
||||
|
||||
@ -148,7 +163,7 @@ void UpdateInputTransposes(NodePtr main_node, const TransposeInputsInfo& transpo
|
||||
}
|
||||
}
|
||||
|
||||
void RemoveInputNode(NodePtr main_node, size_t input_idx) {
|
||||
void RemoveInputNode(const NodePtr& main_node, size_t input_idx) {
|
||||
auto input_node = main_node->input_value(input_idx);
|
||||
if (input_node.get_node()->get_input_size() < (input_idx + 1))
|
||||
return;
|
||||
@ -156,7 +171,7 @@ void RemoveInputNode(NodePtr main_node, size_t input_idx) {
|
||||
main_node->input(input_idx).replace_source_output(parent_node);
|
||||
}
|
||||
|
||||
NodeVector InsertOutputTransposes(NodePtr main_node, const TransposeInputsInfo& transpose_input_info) {
|
||||
NodeVector InsertOutputTransposes(const NodePtr& main_node, const TransposeInputsInfo& transpose_input_info) {
|
||||
if (transpose_input_info.isEmpty())
|
||||
return {};
|
||||
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||
@ -195,7 +210,13 @@ NodeVector InsertOutputTransposes(NodePtr main_node, const TransposeInputsInfo&
|
||||
|
||||
namespace sink_backward {
|
||||
|
||||
NodeVector InsertTransposeBeforeNode(NodePtr main_node, std::shared_ptr<Constant> transpose_const) {
|
||||
NodeVector InsertTransposeBeforeNode(const NodePtr& main_node,
|
||||
const std::shared_ptr<Constant>& transpose_const,
|
||||
std::vector<int> input_indexes) {
|
||||
if (input_indexes.empty()) {
|
||||
input_indexes.resize(main_node->get_input_size());
|
||||
std::iota(input_indexes.begin(), input_indexes.end(), 0);
|
||||
}
|
||||
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||
const auto transpose_element_type = transpose_const->get_element_type();
|
||||
|
||||
@ -208,7 +229,7 @@ NodeVector InsertTransposeBeforeNode(NodePtr main_node, std::shared_ptr<Constant
|
||||
if (max_input_rank < 0)
|
||||
return {};
|
||||
|
||||
for (size_t i = 0; i < main_node->get_input_size(); ++i) {
|
||||
for (const auto& i : input_indexes) {
|
||||
auto input_node = FixInputNodeRank(main_node->input_value(i), max_input_rank);
|
||||
|
||||
auto new_transpose_const = std::make_shared<Constant>(transpose_element_type,
|
||||
@ -250,7 +271,7 @@ bool CanPropagateForwardThrough(Node* node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool CanPropagateForward(NodePtr node) {
|
||||
bool CanPropagateForward(const NodePtr& node) {
|
||||
for (size_t i = 0; i < node->get_output_size(); ++i) {
|
||||
for (auto& consumer_input : node->output(i).get_target_inputs()) {
|
||||
if (!CanPropagateForwardThrough(consumer_input.get_node()))
|
||||
@ -263,7 +284,7 @@ bool CanPropagateForward(NodePtr node) {
|
||||
|
||||
} // namespace
|
||||
|
||||
void UpdateForwardSinkingAbility(NodePtr node) {
|
||||
void UpdateForwardSinkingAbility(const NodePtr& node) {
|
||||
if (!CanPropagateForward(node))
|
||||
mark_as_no_sinking_node(node);
|
||||
}
|
||||
@ -282,7 +303,7 @@ std::shared_ptr<Constant> GetTransposeConstant(Node* node) {
|
||||
return constant_node;
|
||||
}
|
||||
|
||||
Node* FindFirstConsumer(NodePtr node) {
|
||||
Node* FindFirstConsumer(const NodePtr& node) {
|
||||
for (size_t output_idx = 0; output_idx < node->get_output_size(); ++output_idx) {
|
||||
auto inputs = node->get_output_target_inputs(output_idx);
|
||||
if (inputs.empty())
|
||||
@ -292,7 +313,7 @@ Node* FindFirstConsumer(NodePtr node) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool HasSameOutputTransposeNodes(NodePtr main_node) {
|
||||
bool HasSameOutputTransposeNodes(const NodePtr& main_node) {
|
||||
AxisVector first_transpose_axis_order;
|
||||
{
|
||||
Node* first_consumer = FindFirstConsumer(main_node);
|
||||
@ -329,7 +350,7 @@ bool HasSameOutputTransposeNodes(const Output<Node>& output) {
|
||||
return HasSameOutputTransposeNodes(output.get_node_shared_ptr());
|
||||
}
|
||||
|
||||
void RemoveSingleOutputConsumers(NodePtr node) {
|
||||
void RemoveSingleOutputConsumers(const NodePtr& node) {
|
||||
for (size_t output_idx = 0; output_idx < node->get_output_size(); ++output_idx) {
|
||||
for (auto& input : node->get_output_target_inputs(output_idx)) {
|
||||
Node* consumer = input.get_node();
|
||||
|
@ -0,0 +1,313 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <functional>
|
||||
#include <openvino/frontend/manager.hpp>
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <openvino/pass/constant_folding.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking_pad.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking_utils.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
using namespace opset10;
|
||||
|
||||
namespace transpose_sinking_pad {
|
||||
namespace {
|
||||
|
||||
using NodePtr = shared_ptr<Node>;
|
||||
using ModelPtr = shared_ptr<Model>;
|
||||
|
||||
class IPassFactory {
|
||||
public:
|
||||
IPassFactory(const string& type_name) : type_name_(type_name) {}
|
||||
|
||||
virtual ~IPassFactory() = default;
|
||||
|
||||
virtual void registerPass(pass::Manager& pass_manager) const = 0;
|
||||
|
||||
const string& getTypeName() const {
|
||||
return type_name_;
|
||||
}
|
||||
|
||||
private:
|
||||
const string type_name_;
|
||||
};
|
||||
|
||||
using PassFactoryPtr = shared_ptr<IPassFactory>;
|
||||
|
||||
template <typename PassT>
|
||||
class PassFactory : public IPassFactory {
|
||||
public:
|
||||
PassFactory(const string& type_name) : IPassFactory(type_name) {}
|
||||
|
||||
void registerPass(pass::Manager& pass_manager) const override {
|
||||
pass_manager.register_pass<PassT>();
|
||||
pass_manager.register_pass<ov::pass::ConstantFolding>();
|
||||
}
|
||||
};
|
||||
|
||||
#define CREATE_PASS_FACTORY(pass_name) make_shared<PassFactory<pass::pass_name>>(#pass_name)
|
||||
|
||||
vector<int64_t> TransposePadValues(const vector<int64_t>& pads, const vector<size_t>& order) {
|
||||
vector<int64_t> new_pads(pads.size());
|
||||
for (size_t i = 0; i < pads.size(); ++i) {
|
||||
new_pads[i] = pads[order[i]];
|
||||
}
|
||||
return new_pads;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace forward {
|
||||
namespace single_consumer {
|
||||
|
||||
shared_ptr<Model> CreateFunction(size_t num_pad_ops, element::Type input_type) {
|
||||
const Shape input_shape{96, 32, 55, 55};
|
||||
|
||||
auto X = make_shared<Parameter>(input_type, input_shape);
|
||||
|
||||
auto order = make_shared<Constant>(element::i64, Shape{4}, Shape{0, 3, 1, 2});
|
||||
auto transpose = make_shared<Transpose>(X, order);
|
||||
|
||||
OutputVector outputs;
|
||||
Output<Node> in_op = transpose->output(0);
|
||||
auto pad_value = make_shared<Constant>(input_type, Shape{}, 0);
|
||||
for (size_t i = 0; i < num_pad_ops; ++i) {
|
||||
auto pad_begin_const = make_shared<Constant>(element::i64, Shape{4}, vector<int64_t>{0, 1, 2, 3});
|
||||
auto pad_end_const = make_shared<Constant>(element::i64, Shape{4}, vector<int64_t>{0, 1, 2, 3});
|
||||
auto pad = make_shared<Pad>(in_op, pad_begin_const, pad_end_const, pad_value, ov::op::PadMode::CONSTANT);
|
||||
outputs.push_back((pad->output(0)));
|
||||
in_op = pad;
|
||||
}
|
||||
outputs.push_back(in_op);
|
||||
|
||||
return make_shared<Model>(outputs, ParameterVector{X});
|
||||
}
|
||||
|
||||
shared_ptr<Model> CreateReferenceFunction(size_t num_pad_ops, element::Type input_type) {
|
||||
const Shape input_shape{96, 32, 55, 55};
|
||||
|
||||
auto X = make_shared<Parameter>(input_type, input_shape);
|
||||
|
||||
OutputVector outputs;
|
||||
Output<Node> in_op = X->output(0);
|
||||
vector<int64_t> pads{0, 1, 2, 3};
|
||||
auto transpose_pad_values = [&](const vector<size_t>& order) {
|
||||
vector<int64_t> new_pads(pads.size());
|
||||
for (size_t i = 0; i < pads.size(); ++i) {
|
||||
new_pads[i] = pads[order[i]];
|
||||
}
|
||||
return new_pads;
|
||||
};
|
||||
auto axis = make_shared<Constant>(element::i64, Shape{}, 0);
|
||||
auto pad_value = make_shared<Constant>(input_type, Shape{}, 0);
|
||||
for (size_t i = 0; i < num_pad_ops; ++i) {
|
||||
vector<size_t> order_val = {0, 3, 1, 2};
|
||||
auto pad_begin_const = make_shared<Constant>(element::i64, Shape{4}, transpose_pad_values(order_val));
|
||||
auto pad_end_const = make_shared<Constant>(element::i64, Shape{4}, transpose_pad_values(order_val));
|
||||
auto pad = make_shared<Pad>(in_op, pad_begin_const, pad_end_const, pad_value, ov::op::PadMode::CONSTANT);
|
||||
|
||||
auto order = make_shared<Constant>(element::i64, Shape{4}, Shape{order_val});
|
||||
auto transpose = make_shared<Transpose>(pad->output(0), order);
|
||||
outputs.push_back(transpose);
|
||||
in_op = pad;
|
||||
}
|
||||
|
||||
auto order = make_shared<Constant>(element::i64, Shape{4}, Shape{0, 3, 1, 2});
|
||||
auto transpose = make_shared<Transpose>(in_op, order);
|
||||
outputs.push_back(transpose);
|
||||
|
||||
auto ref = make_shared<Model>(outputs, ParameterVector{X});
|
||||
ov::pass::Manager ps_manager;
|
||||
ps_manager.run_passes(ref);
|
||||
return ref;
|
||||
}
|
||||
|
||||
} // namespace single_consumer
|
||||
} // namespace forward
|
||||
|
||||
namespace backward {
|
||||
namespace single_consumer {
|
||||
|
||||
shared_ptr<Model> CreateFunction(size_t num_pad_ops, element::Type input_type) {
|
||||
const Shape input_shape{96, 32, 55, 55};
|
||||
|
||||
auto X = make_shared<Parameter>(input_type, input_shape);
|
||||
|
||||
OutputVector outputs;
|
||||
Output<Node> in_op = X->output(0);
|
||||
auto pad_value = make_shared<Constant>(input_type, Shape{}, 0);
|
||||
for (size_t i = 0; i < num_pad_ops; ++i) {
|
||||
auto pad_begin_const = make_shared<Constant>(element::i64, Shape{4}, vector<int64_t>{0, 1, 2, 3});
|
||||
auto pad_end_const = make_shared<Constant>(element::i64, Shape{4}, vector<int64_t>{0, 1, 2, 3});
|
||||
auto pad = make_shared<Pad>(in_op, pad_begin_const, pad_end_const, pad_value, ov::op::PadMode::CONSTANT);
|
||||
in_op = pad;
|
||||
}
|
||||
auto order = make_shared<Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 1, 2});
|
||||
auto transpose = make_shared<Transpose>(in_op, order);
|
||||
auto relu = make_shared<Relu>(transpose);
|
||||
outputs.push_back(relu);
|
||||
return make_shared<Model>(outputs, ParameterVector{X});
|
||||
}
|
||||
|
||||
shared_ptr<Model> CreateReferenceFunction(size_t num_pad_ops, element::Type input_type) {
|
||||
const Shape input_shape{96, 32, 55, 55};
|
||||
|
||||
auto X = make_shared<Parameter>(input_type, input_shape);
|
||||
vector<size_t> order_val = {0, 3, 1, 2};
|
||||
auto order = make_shared<Constant>(element::i64, Shape{4}, order_val);
|
||||
auto transpose = make_shared<Transpose>(X, order);
|
||||
|
||||
OutputVector outputs;
|
||||
Output<Node> in_op = transpose->output(0);
|
||||
vector<int64_t> pads{0, 1, 2, 3};
|
||||
auto axis = make_shared<Constant>(element::i64, Shape{}, 0);
|
||||
auto pad_value = make_shared<Constant>(input_type, Shape{}, 0);
|
||||
for (size_t i = 0; i < num_pad_ops; ++i) {
|
||||
auto pad_begin_const = make_shared<Constant>(element::i64, Shape{4}, TransposePadValues(pads, order_val));
|
||||
auto pad_end_const = make_shared<Constant>(element::i64, Shape{4}, TransposePadValues(pads, order_val));
|
||||
auto pad = make_shared<Pad>(in_op, pad_begin_const, pad_end_const, pad_value, ov::op::PadMode::CONSTANT);
|
||||
in_op = pad;
|
||||
}
|
||||
auto relu = make_shared<Relu>(in_op);
|
||||
outputs.push_back(relu);
|
||||
auto ref = make_shared<Model>(outputs, ParameterVector{X});
|
||||
return ref;
|
||||
}
|
||||
|
||||
} // namespace single_consumer
|
||||
|
||||
namespace output_transpose_mult_transposes {
|
||||
shared_ptr<Model> CreateFunction(size_t num_pad_ops, element::Type input_type) {
|
||||
const Shape input_shape{96, 32, 55, 55};
|
||||
|
||||
auto X = make_shared<Parameter>(input_type, input_shape);
|
||||
|
||||
OutputVector outputs;
|
||||
Output<Node> in_op = X->output(0);
|
||||
auto pad_value = make_shared<Constant>(input_type, Shape{}, 0);
|
||||
for (size_t i = 0; i < num_pad_ops; ++i) {
|
||||
auto pad_begin_const = make_shared<Constant>(element::i64, Shape{4}, vector<int64_t>{0, 1, 2, 3});
|
||||
auto pad_end_const = make_shared<Constant>(element::i64, Shape{4}, vector<int64_t>{0, 1, 2, 3});
|
||||
auto pad = make_shared<Pad>(in_op, pad_begin_const, pad_end_const, pad_value, ov::op::PadMode::CONSTANT);
|
||||
in_op = pad;
|
||||
}
|
||||
auto order = make_shared<Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 1, 2});
|
||||
auto transpose_1 = make_shared<Transpose>(in_op, order);
|
||||
auto relu_1 = make_shared<Relu>(transpose_1);
|
||||
outputs.push_back(relu_1);
|
||||
|
||||
auto transpose_2 = make_shared<Transpose>(in_op, order);
|
||||
auto relu_2 = make_shared<Relu>(transpose_2);
|
||||
outputs.push_back(relu_2);
|
||||
return make_shared<Model>(outputs, ParameterVector{X});
|
||||
}
|
||||
|
||||
shared_ptr<Model> CreateReferenceFunction(size_t num_pad_ops, element::Type input_type) {
|
||||
const Shape input_shape{96, 32, 55, 55};
|
||||
|
||||
auto X = make_shared<Parameter>(input_type, input_shape);
|
||||
vector<size_t> order_val = {0, 3, 1, 2};
|
||||
auto order = make_shared<Constant>(element::i64, Shape{4}, order_val);
|
||||
auto transpose = make_shared<Transpose>(X, order);
|
||||
|
||||
OutputVector outputs;
|
||||
Output<Node> in_op = transpose->output(0);
|
||||
vector<int64_t> pads{0, 1, 2, 3};
|
||||
|
||||
auto axis = make_shared<Constant>(element::i64, Shape{}, 0);
|
||||
auto pad_value = make_shared<Constant>(input_type, Shape{}, 0);
|
||||
for (size_t i = 0; i < num_pad_ops; ++i) {
|
||||
auto pad_begin_const = make_shared<Constant>(element::i64, Shape{4}, TransposePadValues(pads, order_val));
|
||||
auto pad_end_const = make_shared<Constant>(element::i64, Shape{4}, TransposePadValues(pads, order_val));
|
||||
auto pad = make_shared<Pad>(in_op, pad_begin_const, pad_end_const, pad_value, ov::op::PadMode::CONSTANT);
|
||||
in_op = pad;
|
||||
}
|
||||
auto relu_1 = make_shared<Relu>(in_op);
|
||||
auto relu_2 = make_shared<Relu>(in_op);
|
||||
outputs.push_back(relu_1);
|
||||
outputs.push_back(relu_2);
|
||||
auto ref = make_shared<Model>(outputs, ParameterVector{X});
|
||||
return ref;
|
||||
}
|
||||
} // namespace output_transpose_mult_transposes
|
||||
} // namespace backward
|
||||
|
||||
using CreateGraphPadF = function<shared_ptr<Model>(size_t num_pad_ops, element::Type input_type)>;
|
||||
|
||||
using TestPadParams = tuple<PassFactoryPtr,
|
||||
size_t, /* num_pad_ops */
|
||||
CreateGraphPadF, /* model_factory */
|
||||
CreateGraphPadF, /* reference_model_factory */
|
||||
element::Type> /* input type */;
|
||||
|
||||
class TransposeSinkingPadTestFixture : public ::testing::WithParamInterface<TestPadParams>,
|
||||
public TransformationTestsF {
|
||||
public:
|
||||
static string get_test_name(const testing::TestParamInfo<TestPadParams>& obj) {
|
||||
PassFactoryPtr pass_factory;
|
||||
size_t num_pad_ops;
|
||||
CreateGraphPadF model_factory;
|
||||
CreateGraphPadF reference_model_factory;
|
||||
element::Type input_type;
|
||||
|
||||
tie(pass_factory, num_pad_ops, model_factory, reference_model_factory, input_type) = obj.param;
|
||||
|
||||
ostringstream test_name;
|
||||
test_name << "pass_factory=" << pass_factory->getTypeName() << "_";
|
||||
test_name << "num_pad_ops=" << num_pad_ops << "_";
|
||||
test_name << "input_type=" << input_type;
|
||||
|
||||
return test_name.str();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(TransposeSinkingPadTestFixture, CompareFunctions) {
|
||||
PassFactoryPtr pass_factory;
|
||||
size_t num_pad_ops;
|
||||
CreateGraphPadF model_factory;
|
||||
CreateGraphPadF reference_model_factory;
|
||||
element::Type input_type;
|
||||
tie(pass_factory, num_pad_ops, model_factory, reference_model_factory, input_type) = this->GetParam();
|
||||
|
||||
model = model_factory(num_pad_ops, input_type);
|
||||
model_ref = reference_model_factory(num_pad_ops, input_type);
|
||||
pass_factory->registerPass(manager);
|
||||
}
|
||||
|
||||
std::vector<size_t> pad_operations_numbers = {1, 10};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingPadForwardSingleConsumerTestSuite,
|
||||
TransposeSinkingPadTestFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingPadForward)),
|
||||
::testing::ValuesIn(pad_operations_numbers),
|
||||
::testing::Values(forward::single_consumer::CreateFunction),
|
||||
::testing::Values(forward::single_consumer::CreateReferenceFunction),
|
||||
::testing::Values(element::f32)),
|
||||
TransposeSinkingPadTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingPadBackwardSingleConsumerTestSuite,
|
||||
TransposeSinkingPadTestFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingPadBackward)),
|
||||
::testing::ValuesIn(pad_operations_numbers),
|
||||
::testing::Values(backward::single_consumer::CreateFunction),
|
||||
::testing::Values(backward::single_consumer::CreateReferenceFunction),
|
||||
::testing::Values(element::f32)),
|
||||
TransposeSinkingPadTestFixture::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeSinkingPadBackwardSingleConsumerMultiTransposesTestSuite,
|
||||
TransposeSinkingPadTestFixture,
|
||||
::testing::Combine(::testing::Values(CREATE_PASS_FACTORY(TransposeSinkingPadBackward)),
|
||||
::testing::ValuesIn(pad_operations_numbers),
|
||||
::testing::Values(backward::output_transpose_mult_transposes::CreateFunction),
|
||||
::testing::Values(backward::output_transpose_mult_transposes::CreateReferenceFunction),
|
||||
::testing::Values(element::f32)),
|
||||
TransposeSinkingPadTestFixture::get_test_name);
|
||||
} // namespace transpose_sinking_pad
|
Loading…
Reference in New Issue
Block a user