Enable TransposeSinking transformation in Tensorflow FrontEnd (#15410)
* Enable TransposeSinking in MOC * replace TransposeSinking in TF Frontend * fix TS for concat op * Fix TS for Binary/Concat ops: broadcast transposed input * Fix transpose sinking for Pad op * fix pad tests * fix dynamic case for Concat/Binary ops * codestyle * fix TransposeSinking for Split/Concat ops * fix split * Resolve review comments
This commit is contained in:
parent
9bac41b466
commit
a0a73d443e
@ -65,7 +65,7 @@ namespace sink_forward {
|
||||
* @brief Inserts reversed transposed on @args main_node inputs. Removes input transpose specified in @arg
|
||||
* transpose_input_info
|
||||
*/
|
||||
void UpdateInputTransposes(const std::shared_ptr<ov::Node>& main_node, const TransposeInputsInfo& transpose_input_info);
|
||||
bool UpdateInputTransposes(const std::shared_ptr<ov::Node>& main_node, const TransposeInputsInfo& transpose_input_info);
|
||||
|
||||
/**
|
||||
* @brief Removes @arg input node
|
||||
|
@ -67,7 +67,6 @@
|
||||
#include <transformations/common_optimizations/subtract_fusion.hpp>
|
||||
#include <transformations/common_optimizations/swish_fusion.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking_general.hpp>
|
||||
#include <transformations/common_optimizations/transpose_to_reshape.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/low_precision/mark_dequantization_subgraph.hpp>
|
||||
|
@ -33,7 +33,11 @@ ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
|
||||
|
||||
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
||||
|
||||
sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
|
||||
// todo: support dynamic rank case
|
||||
bool updated = sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
|
||||
if (!updated) {
|
||||
return false;
|
||||
}
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
|
@ -32,18 +32,26 @@ ov::pass::TransposeSinkingConcatForward::TransposeSinkingConcatForward() {
|
||||
auto main_node = main_node_output.get_node_shared_ptr();
|
||||
|
||||
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
||||
auto concat_node = as_type_ptr<Concat>(main_node);
|
||||
auto concat_axis = concat_node->get_concatenation_axis();
|
||||
if (concat_axis < 0) {
|
||||
return false;
|
||||
}
|
||||
// todo: support dyn rank case
|
||||
bool updated = sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
|
||||
if (!updated) {
|
||||
return false;
|
||||
}
|
||||
|
||||
sink_forward::UpdateInputTransposes(main_node, transpose_input_info);
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
|
||||
auto concat_node = as_type_ptr<Concat>(main_node);
|
||||
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||
const int64_t transposed_concat_axis = transpose_axis_order[concat_node->get_axis()];
|
||||
concat_node->set_concatenation_axis(transposed_concat_axis);
|
||||
|
||||
const int64_t transposed_concat_axis = transpose_axis_order[concat_axis];
|
||||
concat_node->set_axis(transposed_concat_axis);
|
||||
concat_node->set_concatenation_axis(-1);
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -70,7 +78,11 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
|
||||
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
||||
|
||||
auto concat_node = as_type_ptr<Concat>(main_node);
|
||||
auto concat_axis = concat_node->get_concatenation_axis();
|
||||
if (concat_axis < 0) {
|
||||
return false;
|
||||
}
|
||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, transpose_const)) {
|
||||
register_new_node(new_node);
|
||||
}
|
||||
@ -79,13 +91,11 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
|
||||
RemoveSingleOutputConsumers(main_node);
|
||||
|
||||
SwapNames(transpose, main_node);
|
||||
|
||||
auto concat_node = as_type_ptr<Concat>(main_node);
|
||||
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
|
||||
const int64_t transposed_concat_axis = reversed_traspose_axis_order[concat_node->get_axis()];
|
||||
concat_node->set_concatenation_axis(transposed_concat_axis);
|
||||
|
||||
const int64_t transposed_concat_axis = reversed_traspose_axis_order[concat_axis];
|
||||
concat_node->set_axis(transposed_concat_axis);
|
||||
concat_node->set_concatenation_axis(-1);
|
||||
return true;
|
||||
};
|
||||
|
||||
|
@ -41,12 +41,13 @@ ov::pass::TransposeSinkingPadForward::TransposeSinkingPadForward() {
|
||||
|
||||
// change the order of values for PadBegin and PadEng inputs
|
||||
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||
const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order);
|
||||
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
|
||||
|
||||
main_node->input(1).replace_source_output(
|
||||
ChangeValuesOrder(main_node->input_value(1), transpose_axis_order, axis));
|
||||
ChangeValuesOrder(main_node->input_value(1), reversed_transpose_order, axis));
|
||||
main_node->input(2).replace_source_output(
|
||||
ChangeValuesOrder(main_node->input_value(2), transpose_axis_order, axis));
|
||||
ChangeValuesOrder(main_node->input_value(2), reversed_transpose_order, axis));
|
||||
|
||||
// insert Transpose for Pad output
|
||||
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
|
||||
|
@ -49,7 +49,7 @@ OutputTranspose GetOutputTransposes(NodePtr node) {
|
||||
}
|
||||
}
|
||||
|
||||
return OutputTranspose();
|
||||
return {};
|
||||
}
|
||||
|
||||
template <typename NodeT>
|
||||
@ -76,6 +76,22 @@ bool IsSplitSinked(const Output<Node>& output) {
|
||||
return HasInputSplitAndTransposeSiblings(output) && is_sinking_node(output);
|
||||
}
|
||||
|
||||
bool GetSplitAxis(const std::shared_ptr<Constant>& split_axis, const ov::Rank& rank, int64_t& axis) {
|
||||
auto split_axis_val = split_axis->cast_vector<int64_t>();
|
||||
if (split_axis_val.empty()) {
|
||||
return false;
|
||||
}
|
||||
axis = split_axis_val[0];
|
||||
if (axis < 0) {
|
||||
if (rank.is_static()) {
|
||||
const auto rank_val = rank.get_length();
|
||||
axis += rank_val;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/*
|
||||
@ -116,6 +132,14 @@ ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
|
||||
|
||||
NodePtr split = FindInputNode<Split>(transpose_label_node);
|
||||
auto split_axis_constant = as_type_ptr<Constant>(split->input_value(1).get_node_shared_ptr());
|
||||
if (!split_axis_constant) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int64_t split_axis;
|
||||
if (!GetSplitAxis(split_axis_constant, split->input_value(0).get_partial_shape().rank(), split_axis)) {
|
||||
return false;
|
||||
}
|
||||
OutputTranspose output_transpose = GetOutputTransposes(split);
|
||||
|
||||
const auto transpose_axis_order = output_transpose.transpose_const->get_axis_vector_val();
|
||||
@ -123,7 +147,6 @@ ov::pass::TransposeSinkingSplitBackward::TransposeSinkingSplitBackward() {
|
||||
|
||||
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
|
||||
|
||||
const size_t split_axis = split_axis_constant->get_axis_vector_val()[0];
|
||||
const size_t reversed_transposed_split_axis = reversed_traspose_axis_order[split_axis];
|
||||
|
||||
// insert transpose before split
|
||||
@ -171,7 +194,16 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
|
||||
|
||||
auto& main_node_output = pattern_to_output.at(main_node_label);
|
||||
auto main_node = main_node_output.get_node_shared_ptr();
|
||||
auto split = as_type_ptr<Split>(main_node);
|
||||
auto split_axis_constant = as_type_ptr<Constant>(split->input_value(1).get_node_shared_ptr());
|
||||
if (!split_axis_constant) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int64_t split_axis;
|
||||
if (!GetSplitAxis(split_axis_constant, split->input_value(0).get_partial_shape().rank(), split_axis)) {
|
||||
return false;
|
||||
}
|
||||
TransposeInputsInfo transpose_input_info = GetFirstTransposeInput(main_node);
|
||||
|
||||
sink_forward::RemoveInputNode(main_node, /* input_idx */ 0);
|
||||
@ -181,13 +213,10 @@ ov::pass::TransposeSinkingSplitForward::TransposeSinkingSplitForward() {
|
||||
}
|
||||
|
||||
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||
auto split_node = as_type_ptr<Split>(main_node);
|
||||
auto split_axis_constant = as_type_ptr<Constant>(split_node->input_value(1).get_node_shared_ptr());
|
||||
const size_t split_axis = split_axis_constant->get_axis_vector_val()[0];
|
||||
const size_t transposed_split_axis = transpose_axis_order[split_axis];
|
||||
auto new_split_axis_const =
|
||||
std::make_shared<Constant>(split_axis_constant->get_element_type(), Shape{}, transposed_split_axis);
|
||||
split_node->input(1).replace_source_output(new_split_axis_const);
|
||||
split->input(1).replace_source_output(new_split_axis_const);
|
||||
copy_runtime_info({split_axis_constant, transpose_input_info.transpose, transpose_input_info.transpose_const},
|
||||
new_split_axis_const);
|
||||
|
||||
|
@ -120,7 +120,11 @@ NodePtr InsertUnsqueeze(const Output<Node>& node, size_t n_dims) {
|
||||
}
|
||||
|
||||
ov::Output<ov::Node> FixInputNodeRank(ov::Output<ov::Node> input_node, ov::Rank::value_type required_rank) {
|
||||
const ov::Rank::value_type output_rank = input_node.get_partial_shape().rank().get_length();
|
||||
auto rank = input_node.get_partial_shape().rank();
|
||||
if (rank.is_dynamic()) {
|
||||
return input_node;
|
||||
}
|
||||
const auto output_rank = rank.get_length();
|
||||
if (output_rank >= required_rank)
|
||||
return input_node;
|
||||
return InsertUnsqueeze(input_node, required_rank - output_rank)->output(0);
|
||||
@ -129,31 +133,54 @@ ov::Output<ov::Node> FixInputNodeRank(ov::Output<ov::Node> input_node, ov::Rank:
|
||||
} // namespace
|
||||
|
||||
namespace sink_forward {
|
||||
AxisVector AlignTransposeOrder(const Output<Node>& output, const TransposeInputsInfo& transpose_input_info) {
|
||||
if (transpose_input_info.isEmpty()) {
|
||||
return {};
|
||||
}
|
||||
auto num_of_val = static_cast<int64_t>(shape_size(transpose_input_info.transpose_const->get_shape()));
|
||||
const auto rank = output.get_partial_shape().rank();
|
||||
const auto rank_val = rank.get_length();
|
||||
AxisVector new_transpose_order;
|
||||
if (rank_val > num_of_val) {
|
||||
const auto diff = rank_val - num_of_val;
|
||||
new_transpose_order.resize(rank_val);
|
||||
std::iota(new_transpose_order.begin(), new_transpose_order.end(), 0);
|
||||
auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||
for (int64_t i = diff; i < rank_val; ++i) {
|
||||
new_transpose_order[i] = transpose_axis_order[i - diff] + diff;
|
||||
}
|
||||
} else {
|
||||
new_transpose_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||
}
|
||||
return new_transpose_order;
|
||||
}
|
||||
|
||||
void UpdateInputTransposes(const NodePtr& main_node, const TransposeInputsInfo& transpose_input_info) {
|
||||
bool UpdateInputTransposes(const NodePtr& main_node, const TransposeInputsInfo& transpose_input_info) {
|
||||
if (transpose_input_info.isEmpty() || HasDynamicRankInput(main_node))
|
||||
return;
|
||||
return false;
|
||||
|
||||
const auto max_input_rank = GetMaxInputRank(main_node);
|
||||
if (max_input_rank < 0)
|
||||
return;
|
||||
return false;
|
||||
|
||||
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
|
||||
const size_t tranpose_input_index = transpose_input_info.input_idx;
|
||||
const size_t transpose_input_index = transpose_input_info.input_idx;
|
||||
const auto transpose_element_type = transpose_input_info.transpose_const->get_element_type();
|
||||
|
||||
for (size_t i = 0; i < main_node->get_input_size(); ++i) {
|
||||
auto input_node = main_node->input_value(i);
|
||||
if (i == tranpose_input_index) {
|
||||
if (i == transpose_input_index) {
|
||||
auto transpose_parent = input_node.get_node()->input_value(0);
|
||||
main_node->input(i).replace_source_output(transpose_parent);
|
||||
} else {
|
||||
input_node = FixInputNodeRank(input_node, max_input_rank);
|
||||
|
||||
auto transpose_order = AlignTransposeOrder(input_node, transpose_input_info);
|
||||
if (transpose_order.empty()) {
|
||||
return false;
|
||||
}
|
||||
const auto reversed_transpose_axis_order = ReverseTransposeOrder(transpose_order);
|
||||
auto new_transpose_const = std::make_shared<Constant>(transpose_element_type,
|
||||
Shape{reversed_traspose_axis_order.size()},
|
||||
reversed_traspose_axis_order);
|
||||
Shape{reversed_transpose_axis_order.size()},
|
||||
reversed_transpose_axis_order);
|
||||
auto new_transpose = std::make_shared<Transpose>(input_node, new_transpose_const);
|
||||
|
||||
main_node->input(i).replace_source_output(new_transpose->output(0));
|
||||
@ -161,6 +188,7 @@ void UpdateInputTransposes(const NodePtr& main_node, const TransposeInputsInfo&
|
||||
copy_runtime_info(input_node.get_node_shared_ptr(), {new_transpose, new_transpose_const});
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void RemoveInputNode(const NodePtr& main_node, size_t input_idx) {
|
||||
@ -174,20 +202,17 @@ void RemoveInputNode(const NodePtr& main_node, size_t input_idx) {
|
||||
NodeVector InsertOutputTransposes(const NodePtr& main_node, const TransposeInputsInfo& transpose_input_info) {
|
||||
if (transpose_input_info.isEmpty())
|
||||
return {};
|
||||
const auto transpose_axis_order = transpose_input_info.transpose_const->get_axis_vector_val();
|
||||
const auto reversed_traspose_axis_order = ReverseTransposeOrder(transpose_axis_order);
|
||||
|
||||
auto new_transpose_order = AlignTransposeOrder(main_node->output(0), transpose_input_info);
|
||||
const auto transpose_element_type = transpose_input_info.transpose_const->get_element_type();
|
||||
|
||||
NodeVector new_nodes;
|
||||
|
||||
for (size_t i = 0; i < main_node->get_output_size(); ++i) {
|
||||
auto new_transpose_const =
|
||||
std::make_shared<Constant>(transpose_element_type, Shape{new_transpose_order.size()}, new_transpose_order);
|
||||
auto main_node_consumers = main_node->output(i).get_target_inputs();
|
||||
|
||||
auto new_transpose_const = std::make_shared<Constant>(transpose_element_type,
|
||||
Shape{transpose_axis_order.size()},
|
||||
transpose_axis_order);
|
||||
auto new_transpose = std::make_shared<Transpose>(main_node->output(i), new_transpose_const);
|
||||
|
||||
for (auto& consumer : main_node_consumers) {
|
||||
consumer.replace_source_output(new_transpose);
|
||||
}
|
||||
|
@ -73,15 +73,15 @@ shared_ptr<Model> CreateFunction(size_t num_pad_ops, element::Type input_type) {
|
||||
auto X = make_shared<Parameter>(input_type, input_shape);
|
||||
|
||||
auto order = make_shared<Constant>(element::i64, Shape{4}, Shape{0, 3, 1, 2});
|
||||
auto transpose = make_shared<Transpose>(X, order);
|
||||
auto transpose = make_shared<Transpose>(X, order); // 96 55 32 55
|
||||
|
||||
OutputVector outputs;
|
||||
Output<Node> in_op = transpose->output(0);
|
||||
auto pad_value = make_shared<Constant>(input_type, Shape{}, 0);
|
||||
for (size_t i = 0; i < num_pad_ops; ++i) {
|
||||
auto pad_begin_const = make_shared<Constant>(element::i64, Shape{4}, vector<int64_t>{0, 1, 2, 3});
|
||||
auto pad_end_const = make_shared<Constant>(element::i64, Shape{4}, vector<int64_t>{0, 1, 2, 3});
|
||||
auto pad = make_shared<Pad>(in_op, pad_begin_const, pad_end_const, pad_value, ov::op::PadMode::CONSTANT);
|
||||
auto pad_begin_const = make_shared<Constant>(element::i64, Shape{4}, vector<int64_t>{95, 54, 31, 53});
|
||||
auto pad_end_const = make_shared<Constant>(element::i64, Shape{4}, vector<int64_t>{95, 54, 31, 53});
|
||||
auto pad = make_shared<Pad>(in_op, pad_begin_const, pad_end_const, pad_value, ov::op::PadMode::REFLECT);
|
||||
outputs.push_back((pad->output(0)));
|
||||
in_op = pad;
|
||||
}
|
||||
@ -97,7 +97,7 @@ shared_ptr<Model> CreateReferenceFunction(size_t num_pad_ops, element::Type inpu
|
||||
|
||||
OutputVector outputs;
|
||||
Output<Node> in_op = X->output(0);
|
||||
vector<int64_t> pads{0, 1, 2, 3};
|
||||
vector<int64_t> pads{95, 54, 31, 53};
|
||||
auto transpose_pad_values = [&](const vector<size_t>& order) {
|
||||
vector<int64_t> new_pads(pads.size());
|
||||
for (size_t i = 0; i < pads.size(); ++i) {
|
||||
@ -107,10 +107,10 @@ shared_ptr<Model> CreateReferenceFunction(size_t num_pad_ops, element::Type inpu
|
||||
};
|
||||
auto axis = make_shared<Constant>(element::i64, Shape{}, 0);
|
||||
auto pad_value = make_shared<Constant>(input_type, Shape{}, 0);
|
||||
vector<size_t> order_val = {0, 3, 1, 2};
|
||||
for (size_t i = 0; i < num_pad_ops; ++i) {
|
||||
vector<size_t> order_val = {0, 3, 1, 2};
|
||||
auto pad_begin_const = make_shared<Constant>(element::i64, Shape{4}, transpose_pad_values(order_val));
|
||||
auto pad_end_const = make_shared<Constant>(element::i64, Shape{4}, transpose_pad_values(order_val));
|
||||
auto pad_begin_const = make_shared<Constant>(element::i64, Shape{4}, transpose_pad_values({0, 2, 3, 1}));
|
||||
auto pad_end_const = make_shared<Constant>(element::i64, Shape{4}, transpose_pad_values({0, 2, 3, 1}));
|
||||
auto pad = make_shared<Pad>(in_op, pad_begin_const, pad_end_const, pad_value, ov::op::PadMode::CONSTANT);
|
||||
|
||||
auto order = make_shared<Constant>(element::i64, Shape{4}, Shape{order_val});
|
||||
@ -119,7 +119,7 @@ shared_ptr<Model> CreateReferenceFunction(size_t num_pad_ops, element::Type inpu
|
||||
in_op = pad;
|
||||
}
|
||||
|
||||
auto order = make_shared<Constant>(element::i64, Shape{4}, Shape{0, 3, 1, 2});
|
||||
auto order = make_shared<Constant>(element::i64, Shape{4}, order_val);
|
||||
auto transpose = make_shared<Transpose>(in_op, order);
|
||||
outputs.push_back(transpose);
|
||||
|
||||
|
@ -19,6 +19,7 @@
|
||||
#include "so_extension.hpp"
|
||||
#include "tf_framework_node.hpp"
|
||||
#include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_general.hpp"
|
||||
#include "translate_session.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
@ -251,8 +252,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& function) const {
|
||||
manager.register_pass<pass::BlockLSTMReplacer>();
|
||||
manager.register_pass<pass::GRUBlockCellReplacer>();
|
||||
|
||||
// TODO: reimplement TransposeSinking that does not corrupt filters for Convolution
|
||||
manager.register_pass<ov::frontend::tensorflow::pass::TransposeSinking>();
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
|
||||
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
|
||||
manager.run_passes(function);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user