Enable TransposeSinking transformation in Tensorflow FrontEnd (#15410)

* Enable TransposeSinking in MOC

* replace TransposeSinking in TF Frontend

* fix TS for concat op

* Fix TS for Binary/Concat ops: broadcast transposed input

* Fix transpose sinking for Pad op

* fix pad tests

* fix dynamic case for Concat/Binary ops

* codestyle

* fix TransposeSinking for Split/Concat ops

* fix split

* Resolve review comments
This commit is contained in:
Ivan Tikhonov 2023-02-07 23:21:34 +04:00 committed by GitHub
parent 9bac41b466
commit a0a73d443e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 119 additions and 51 deletions

View File

@ -65,7 +65,7 @@ namespace sink_forward {
* @brief Inserts reversed transposes on @arg main_node inputs. Removes the input transpose specified in @arg
* transpose_input_info
*/
void UpdateInputTransposes(const std::shared_ptr<ov::Node>& main_node, const TransposeInputsInfo& transpose_input_info);
bool UpdateInputTransposes(const std::shared_ptr<ov::Node>& main_node, const TransposeInputsInfo& transpose_input_info);
/**
* @brief Removes @arg input node

View File

@ -67,7 +67,6 @@
#include <transformations/common_optimizations/subtract_fusion.hpp>
#include <transformations/common_optimizations/swish_fusion.hpp>
#include <transformations/common_optimizations/transpose_sinking.hpp>
#include <transformations/common_optimizations/transpose_sinking_general.hpp>
#include <transformations/common_optimizations/transpose_to_reshape.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/low_precision/mark_dequantization_subgraph.hpp>

View File

@ -33,7 +33,11 @@ ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
// todo: support dynamic rank case
bool updated = sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
if (!updated) {
return false;
}
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);

View File

@ -32,18 +32,26 @@ ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() {
auto main_node = main_node_output.get_node_shared_ptr();
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
auto concat_node = as_type_ptr<Concat>(main_node);
auto concat_axis = concat_node->get_concatenation_axis();
if (concat_axis < 0) {
return false;
}
// todo: support dyn rank case
bool updated = sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
if (!updated) {
return false;
}
sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
transpose_sinking::UpdateForwardSinkingAbility(new_node);
}
auto concat_node = as_type_ptr<Concat>(main_node);
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
const int64_t transposed_concat_axis = transpose_axis_order[concat_node->get_axis()];
concat_node->set_concatenation_axis(transposed_concat_axis);
const int64_t transposed_concat_axis = transpose_axis_order[concat_axis];
concat_node->set_axis(transposed_concat_axis);
concat_node->set_concatenation_axis(-1);
return true;
};
@ -70,7 +78,11 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
auto concat_node = as_type_ptr<Concat>(main_node);
auto concat_axis = concat_node->get_concatenation_axis();
if (concat_axis < 0) {
return false;
}
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) {
register_new_node(new_node);
}
@ -79,13 +91,11 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
RemoveSingleOutputConsumers(main_node);
SwapNames(transpose, main_node);
auto concat_node = as_type_ptr<Concat>(main_node);
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
const int64_t transposed_concat_axis = reversed_traspose_axis_order[concat_node->get_axis()];
concat_node->set_concatenation_axis(transposed_concat_axis);
const int64_t transposed_concat_axis = reversed_traspose_axis_order[concat_axis];
concat_node->set_axis(transposed_concat_axis);
concat_node->set_concatenation_axis(-1);
return true;
};

View File

@ -41,12 +41,13 @@ ov::pass::TransposeSinkingPadForward::TransposeSinkingPadForward() {
// change the order of values for the PadBegin and PadEnd inputs
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order);
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
main_node->input(1).replace_source_output(
ChangeValuesOrder(main_node->input_value(1), transpose_axis_order, axis));
ChangeValuesOrder(main_node->input_value(1), reversed_transpose_order, axis));
main_node->input(2).replace_source_output(
ChangeValuesOrder(main_node->input_value(2), transpose_axis_order, axis));
ChangeValuesOrder(main_node->input_value(2), reversed_transpose_order, axis));
// insert Transpose for Pad output
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};

View File

@ -49,7 +49,7 @@ OutputTranspose GetOutputTransposes(NodePtr node) {
}
}
return OutputTranspose();
return {};
}
template <typename NodeT>
@ -76,6 +76,22 @@ bool IsSplitSinked(const Output<Node>& output) {
return HasInputSplitAndTransposeSiblings(output) && is_sinking_node(output);
}
/**
 * @brief Reads the split axis out of a Constant node and normalizes it to a
 * non-negative index.
 *
 * @param split_axis  Constant holding the axis value (Split's second input).
 * @param rank        Rank of the data input; needed only to normalize a
 *                    negative axis.
 * @param axis        [out] Receives the normalized axis on success.
 * @return true on success; false when the constant is empty, or when the axis
 *         is negative and the rank is dynamic (normalization impossible).
 */
bool GetSplitAxis(const std::shared_ptr<Constant>& split_axis, const ov::Rank& rank, int64_t& axis) {
auto split_axis_val = split_axis->cast_vector<int64_t>();
// Empty constant: nothing to read.
if (split_axis_val.empty()) {
return false;
}
axis = split_axis_val[0];
// A negative axis counts from the end; turning it into a positive index
// requires a static rank.
if (axis < 0) {
if (rank.is_static()) {
const auto rank_val = rank.get_length();
axis += rank_val;
} else {
return false;
}
}
return true;
}
} // namespace
/*
@ -116,6 +132,14 @@ ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
NodePtr split = FindInputNode<Split>(transpose_label_node);
auto split_axis_constant = as_type_ptr<Constant>(split->input_value(1).get_node_shared_ptr());
if (!split_axis_constant) {
return false;
}
int64_t split_axis;
if (!GetSplitAxis(split_axis_constant, split->input_value(0).get_partial_shape().rank(), split_axis)) {
return false;
}
OutputTranspose output_transpose = GetOutputTransposes(split);
const auto transpose_axis_order = output_transpose.transpose_const->get_axis_vector_val();
@ -123,7 +147,6 @@ ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
const size_t split_axis = split_axis_constant->get_axis_vector_val()[0];
const size_t reversed_transposed_split_axis = reversed_traspose_axis_order[split_axis];
// insert transpose before split
@ -171,7 +194,16 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
auto& main_node_output = pattern_to_output.at(main_node_label);
auto main_node = main_node_output.get_node_shared_ptr();
auto split = as_type_ptr<Split>(main_node);
auto split_axis_constant = as_type_ptr<Constant>(split->input_value(1).get_node_shared_ptr());
if (!split_axis_constant) {
return false;
}
int64_t split_axis;
if (!GetSplitAxis(split_axis_constant, split->input_value(0).get_partial_shape().rank(), split_axis)) {
return false;
}
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
sink_forward::RemoveInputNode(main_node, /* input_idx */ 0);
@ -181,13 +213,10 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
}
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
auto split_node = as_type_ptr<Split>(main_node);
auto split_axis_constant = as_type_ptr<Constant>(split_node->input_value(1).get_node_shared_ptr());
const size_t split_axis = split_axis_constant->get_axis_vector_val()[0];
const size_t transposed_split_axis = transpose_axis_order[split_axis];
auto new_split_axis_const =
std::make_shared<Constant>(split_axis_constant->get_element_type(), Shape{}, transposed_split_axis);
split_node->input(1).replace_source_output(new_split_axis_const);
split->input(1).replace_source_output(new_split_axis_const);
copy_runtime_info({split_axis_constant, transpose_input_info.transpose, transpose_input_info.transpose_const},
new_split_axis_const);

View File

@ -120,7 +120,11 @@ NodePtr InsertUnsqueeze(const Output<Node>& node, size_t n_dims) {
}
ov::Output<ov::Node> FixInputNodeRank(ov::Output<ov::Node> input_node, ov::Rank::value_type required_rank) {
const ov::Rank::value_type output_rank = input_node.get_partial_shape().rank().get_length();
auto rank = input_node.get_partial_shape().rank();
if (rank.is_dynamic()) {
return input_node;
}
const auto output_rank = rank.get_length();
if (output_rank >= required_rank)
return input_node;
return InsertUnsqueeze(input_node, required_rank - output_rank)->output(0);
@ -129,31 +133,54 @@ ov::Output<ov::Node> FixInputNodeRank(ov::Output<ov::Node> input_node, ov::Rank:
} // namespace
namespace sink_forward {
/**
 * @brief Pads a transpose axis order up to the rank of @arg output.
 *
 * If the output rank exceeds the number of values in the transpose constant
 * (the output was implicitly broadcast to a higher rank), the order is
 * extended on the left with identity axes [0..diff) and the original order is
 * shifted right by diff. Otherwise the original order is returned unchanged.
 *
 * @param output               Output whose rank the order must match.
 * @param transpose_input_info Transpose description; empty info yields {}.
 * @return The aligned axis order, or an empty vector for empty input info.
 */
AxisVector AlignTransposeOrder(const Output<Node>& output, const TransposeInputsInfo& transpose_input_info) {
if (transpose_input_info.isEmpty()) {
return {};
}
auto num_of_val = static_cast<int64_t>(shape_size(transpose_input_info.transpose_const->get_shape()));
const auto rank = output.get_partial_shape().rank();
// NOTE(review): get_length() requires a static rank — presumably callers
// guarantee this (e.g. via HasDynamicRankInput checks); confirm for all call sites.
const auto rank_val = rank.get_length();
AxisVector new_transpose_order;
if (rank_val > num_of_val) {
// Output rank is higher than the transpose order length: prepend identity
// axes and shift the original permutation by the rank difference.
const auto diff = rank_val - num_of_val;
new_transpose_order.resize(rank_val);
std::iota(new_transpose_order.begin(), new_transpose_order.end(), 0);
auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
for (int64_t i = diff; i < rank_val; ++i) {
new_transpose_order[i] = transpose_axis_order[i - diff] + diff;
}
} else {
// Ranks already match (or order is longer): use the stored order as-is.
new_transpose_order = transpose_input_info.transpose_const->get_axis_vector_val();
}
return new_transpose_order;
}
void UpdateInputTransposes(const NodePtr& main_node, const TransposeInputsInfo& transpose_input_info) {
bool UpdateInputTransposes(const NodePtr& main_node, const TransposeInputsInfo& transpose_input_info) {
if (transpose_input_info.isEmpty() || HasDynamicRankInput(main_node))
return;
return false;
const auto max_input_rank = GetMaxInputRank(main_node);
if (max_input_rank < 0)
return;
return false;
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
const size_t tranpose_input_index = transpose_input_info.input_idx;
const size_t transpose_input_index = transpose_input_info.input_idx;
const auto transpose_element_type = transpose_input_info.transpose_const->get_element_type();
for (size_t i = 0; i < main_node->get_input_size(); ++i) {
auto input_node = main_node->input_value(i);
if (i == tranpose_input_index) {
if (i == transpose_input_index) {
auto transpose_parent = input_node.get_node()->input_value(0);
main_node->input(i).replace_source_output(transpose_parent);
} else {
input_node = FixInputNodeRank(input_node, max_input_rank);
auto transpose_order = AlignTransposeOrder(input_node, transpose_input_info);
if (transpose_order.empty()) {
return false;
}
const auto reversed_transpose_axis_order = ReverseTransposeOrder(transpose_order);
auto new_transpose_const = std::make_shared<Constant>(transpose_element_type,
Shape{reversed_traspose_axis_order.size()},
reversed_traspose_axis_order);
Shape{reversed_transpose_axis_order.size()},
reversed_transpose_axis_order);
auto new_transpose = std::make_shared<Transpose>(input_node, new_transpose_const);
main_node->input(i).replace_source_output(new_transpose->output(0));
@ -161,6 +188,7 @@ void UpdateInputTransposes(const NodePtr& main_node, const TransposeInputsInfo&
copy_runtime_info(input_node.get_node_shared_ptr(), {new_transpose, new_transpose_const});
}
}
return true;
}
void RemoveInputNode(const NodePtr& main_node, size_t input_idx) {
@ -174,20 +202,17 @@ void RemoveInputNode(const NodePtr& main_node, size_t input_idx) {
NodeVector InsertOutputTransposes(const NodePtr& main_node, const TransposeInputsInfo& transpose_input_info) {
if (transpose_input_info.isEmpty())
return {};
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
auto new_transpose_order = AlignTransposeOrder(main_node->output(0), transpose_input_info);
const auto transpose_element_type = transpose_input_info.transpose_const->get_element_type();
NodeVector new_nodes;
for (size_t i = 0; i < main_node->get_output_size(); ++i) {
auto new_transpose_const =
std::make_shared<Constant>(transpose_element_type, Shape{new_transpose_order.size()}, new_transpose_order);
auto main_node_consumers = main_node->output(i).get_target_inputs();
auto new_transpose_const = std::make_shared<Constant>(transpose_element_type,
Shape{transpose_axis_order.size()},
transpose_axis_order);
auto new_transpose = std::make_shared<Transpose>(main_node->output(i), new_transpose_const);
for (auto& consumer : main_node_consumers) {
consumer.replace_source_output(new_transpose);
}

View File

@ -73,15 +73,15 @@ shared_ptr<Model> CreateFunction(size_t num_pad_ops, element::Type input_type) {
auto X = make_shared<Parameter>(input_type, input_shape);
auto order = make_shared<Constant>(element::i64, Shape{4}, Shape{0, 3, 1, 2});
auto transpose = make_shared<Transpose>(X, order);
auto transpose = make_shared<Transpose>(X, order); // 96 55 32 55
OutputVector outputs;
Output<Node> in_op = transpose->output(0);
auto pad_value = make_shared<Constant>(input_type, Shape{}, 0);
for (size_t i = 0; i < num_pad_ops; ++i) {
auto pad_begin_const = make_shared<Constant>(element::i64, Shape{4}, vector<int64_t>{0, 1, 2, 3});
auto pad_end_const = make_shared<Constant>(element::i64, Shape{4}, vector<int64_t>{0, 1, 2, 3});
auto pad = make_shared<Pad>(in_op, pad_begin_const, pad_end_const, pad_value, ov::op::PadMode::CONSTANT);
auto pad_begin_const = make_shared<Constant>(element::i64, Shape{4}, vector<int64_t>{95, 54, 31, 53});
auto pad_end_const = make_shared<Constant>(element::i64, Shape{4}, vector<int64_t>{95, 54, 31, 53});
auto pad = make_shared<Pad>(in_op, pad_begin_const, pad_end_const, pad_value, ov::op::PadMode::REFLECT);
outputs.push_back((pad->output(0)));
in_op = pad;
}
@ -97,7 +97,7 @@ shared_ptr<Model> CreateReferenceFunction(size_t num_pad_ops, element::Type inpu
OutputVector outputs;
Output<Node> in_op = X->output(0);
vector<int64_t> pads{0, 1, 2, 3};
vector<int64_t> pads{95, 54, 31, 53};
auto transpose_pad_values = [&](const vector<size_t>& order) {
vector<int64_t> new_pads(pads.size());
for (size_t i = 0; i < pads.size(); ++i) {
@ -107,10 +107,10 @@ shared_ptr<Model> CreateReferenceFunction(size_t num_pad_ops, element::Type inpu
};
auto axis = make_shared<Constant>(element::i64, Shape{}, 0);
auto pad_value = make_shared<Constant>(input_type, Shape{}, 0);
vector<size_t> order_val = {0, 3, 1, 2};
for (size_t i = 0; i < num_pad_ops; ++i) {
vector<size_t> order_val = {0, 3, 1, 2};
auto pad_begin_const = make_shared<Constant>(element::i64, Shape{4}, transpose_pad_values(order_val));
auto pad_end_const = make_shared<Constant>(element::i64, Shape{4}, transpose_pad_values(order_val));
auto pad_begin_const = make_shared<Constant>(element::i64, Shape{4}, transpose_pad_values({0, 2, 3, 1}));
auto pad_end_const = make_shared<Constant>(element::i64, Shape{4}, transpose_pad_values({0, 2, 3, 1}));
auto pad = make_shared<Pad>(in_op, pad_begin_const, pad_end_const, pad_value, ov::op::PadMode::CONSTANT);
auto order = make_shared<Constant>(element::i64, Shape{4}, Shape{order_val});
@ -119,7 +119,7 @@ shared_ptr<Model> CreateReferenceFunction(size_t num_pad_ops, element::Type inpu
in_op = pad;
}
auto order = make_shared<Constant>(element::i64, Shape{4}, Shape{0, 3, 1, 2});
auto order = make_shared<Constant>(element::i64, Shape{4}, order_val);
auto transpose = make_shared<Transpose>(in_op, order);
outputs.push_back(transpose);

View File

@ -19,6 +19,7 @@
#include "so_extension.hpp"
#include "tf_framework_node.hpp"
#include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp"
#include "transformations/common_optimizations/transpose_sinking_general.hpp"
#include "translate_session.hpp"
#include "utils.hpp"
@ -251,8 +252,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& function) const {
manager.register_pass<pass::BlockLSTMReplacer>();
manager.register_pass<pass::GRUBlockCellReplacer>();
// TODO: reimplement TransposeSinking that does not corrupt filters for Convolution
manager.register_pass<ov::frontend::tensorflow::pass::TransposeSinking>();
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
manager.run_passes(function);
}