[TF FE] Fix Transpose Sinking for dynamic shapes (#15302)

* TransposeSinking transformation: add support for dynamic shapes

* fix e2e tests

Co-authored-by: Ivan <ivan.tikhonov@intel.com>
This commit is contained in:
Roman Kazantsev 2023-01-25 21:04:26 +04:00 committed by GitHub
parent bf98d31393
commit 7ee781dfbf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 171 additions and 106 deletions

View File

@ -14,9 +14,7 @@ namespace pass {
class TransposeSinking : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("ov::frontend::tensorflow::pass::TransposeSinking");
TransposeSinking() {
set_property(ov::pass::PassProperty::REQUIRE_STATIC_SHAPE, true);
}
TransposeSinking() = default;
bool run_on_model(const std::shared_ptr<ov::Model>& function) override;
};

View File

@ -42,6 +42,12 @@ static AxisVector get_default_order(size_t rank) {
return default_order;
}
int64_t get_static_rank(const Output<Node>& output) {
auto rank = output.get_partial_shape().rank();
OPENVINO_ASSERT(rank.is_static(), "Dynamic rank is not supported in TransposeSinking transformation.");
return rank.get_length();
}
template <typename T>
static string describe(shared_ptr<Node> node) {
// ensure that it's either a reshape or a transpose
@ -93,7 +99,7 @@ static shared_ptr<Transpose> read_transposemap(TransposeMap& reorders, const Out
}
static shared_ptr<Transpose> combine_transposes(const shared_ptr<Transpose>& t1, const shared_ptr<Transpose>& t2) {
auto default_order = get_default_order(t1->get_shape().size());
auto default_order = get_default_order(get_static_rank(t1));
auto t1_const = as_type_ptr<Constant>(t1->input_value(1).get_node_shared_ptr());
auto t2_const = as_type_ptr<Constant>(t2->input_value(1).get_node_shared_ptr());
@ -112,7 +118,9 @@ static shared_ptr<Transpose> combine_transposes(const shared_ptr<Transpose>& t1,
static void insert_transpose(const shared_ptr<Node>& target, const shared_ptr<Node>& transpose, size_t input_index) {
OPENVINO_DEBUG << "Inserting transpose at input " << target->get_name() << " input index " << input_index;
auto arg = target->input(input_index).get_source_output();
OPENVINO_DEBUG << "Arg shape: " << arg.get_shape();
if (arg.get_partial_shape().is_static()) {
OPENVINO_DEBUG << "Arg shape: " << arg.get_shape();
}
auto new_order = as_type_ptr<Constant>(transpose->input_value(1).get_node_shared_ptr());
auto new_transpose = make_transpose(arg.get_node_shared_ptr(), new_order->get_axis_vector_val());
OPENVINO_DEBUG << "Inserting transpose " << describe<Transpose>(new_transpose) << " at input " << target->get_name()
@ -142,7 +150,7 @@ static void mark_transpose_for_deletion(const shared_ptr<Node>& transpose,
}
static shared_ptr<Transpose> create_default_transpose(const Output<Node>& n) {
auto default_order = get_default_order(n.get_shape().size());
auto default_order = get_default_order(get_static_rank(n));
auto order = std::make_shared<Constant>(element::i64, Shape{default_order.size()}, default_order);
return make_shared<Transpose>(n, order);
}
@ -160,27 +168,32 @@ static void convert_binary_to_default_order(const shared_ptr<Node>& binary,
auto left = input.get_source_output();
auto right_t = read_transposemap(reorders, right);
auto right_const = as_type_ptr<Constant>(right_t->input_value(1).get_node_shared_ptr());
auto perm_to_def = permutation_to_default_order(right_const->get_axis_vector_val());
// if right input is being implicitly broadcasted, insert a reshape
// instead of a transpose
shared_ptr<Node> new_node;
auto left_shape = left.get_shape();
if (left_shape.size() < perm_to_def.size()) {
left_shape.insert(left_shape.begin(), perm_to_def.size() - left_shape.size(), 1);
auto left_rank = get_static_rank(left);
if (left_rank < perm_to_def.size() && left.get_partial_shape().is_static()) {
auto left_shape = left.get_shape();
left_shape.insert(left_shape.begin(), perm_to_def.size() - left_rank, 1);
auto new_shape = apply_permutation(left_shape, perm_to_def);
new_node = make_reshape(left, new_shape);
} else if (left_shape.size() == perm_to_def.size()) {
} else if (left_rank == perm_to_def.size()) {
new_node = make_transpose(left, perm_to_def);
} else {
throw runtime_error("case not supported when converting binary to default order");
}
input.replace_source_output(new_node->output(0));
OPENVINO_DEBUG << "right = " << ov::util::vector_to_string(right.get_shape()) << ", "
<< right.get_node_shared_ptr()->get_name();
if (right.get_partial_shape().is_static()) {
OPENVINO_DEBUG << "right = " << ov::util::vector_to_string(right.get_shape()) << ", "
<< right.get_node_shared_ptr()->get_name();
} else {
OPENVINO_DEBUG << "right = "
<< "dynamic shape, " << right.get_node_shared_ptr()->get_name();
}
// this should now insert transpose on right
mark_transpose_for_deletion(right_t, transposes_to_delete);
write_transposemap(reorders, binary, right_t);
@ -203,14 +216,15 @@ static void materialize_shapes(const shared_ptr<Node>& n,
<< arg.get_node_shared_ptr()->get_name();
mark_transpose_for_deletion(arg_transpose, transposes_to_delete);
auto arg_transpose_order = as_type_ptr<Constant>(arg_transpose->input_value(1).get_node_shared_ptr());
if (arg_transpose_order->get_axis_vector_val() != get_default_order(arg.get_shape().size())) {
if (arg_transpose_order &&
arg_transpose_order->get_axis_vector_val() != get_default_order(get_static_rank(arg))) {
// Insert if arg needs to be transposed.
insert_transpose(n, arg_transpose, i);
}
}
}
static void sink_transpose(const shared_ptr<Transpose>& transpose,
static bool sink_transpose(const shared_ptr<Transpose>& transpose,
TransposeMap& reorders,
set<shared_ptr<Node>>& transposes_to_delete) {
OPENVINO_DEBUG << "Sinking Transpose :" << describe<Transpose>(transpose);
@ -226,18 +240,25 @@ static void sink_transpose(const shared_ptr<Transpose>& transpose,
replace_node(transpose, new_transpose);
mark_transpose_for_deletion(new_transpose, transposes_to_delete);
write_transposemap(reorders, new_transpose, new_transpose);
} else {
// combine_transposes failed
// transpose remains in the graph
OPENVINO_DEBUG << "CombineTranspose has failed. Writing original transpose to the transpose map.";
return false;
}
return true;
}
static void sink_unary(const shared_ptr<Node>& n,
static bool sink_unary(const shared_ptr<Node>& n,
TransposeMap& reorders,
set<shared_ptr<Node>>& /* transposes_to_delete */) {
auto arg_transpose = read_transposemap(reorders, n->input_value(0));
OPENVINO_DEBUG << "Propagating " << describe<Transpose>(arg_transpose) << " for " << n->get_name();
write_transposemap(reorders, n, arg_transpose);
return true;
}
static void sink_binary(const shared_ptr<Node>& binary,
static bool sink_binary(const shared_ptr<Node>& binary,
TransposeMap& reorders,
set<shared_ptr<Node>>& transposes_to_delete) {
auto left = binary->input_value(0);
@ -246,18 +267,25 @@ static void sink_binary(const shared_ptr<Node>& binary,
auto right_t = read_transposemap(reorders, right);
auto left_const = as_type_ptr<Constant>(left_t->input_value(1).get_node_shared_ptr());
auto right_const = as_type_ptr<Constant>(right_t->input_value(1).get_node_shared_ptr());
if (!(left_const && right_const)) {
OPENVINO_DEBUG << "TransposeSinking failed for binary op " << binary->get_name()
<< "2nd inputs to Transposes must be constants.";
return false;
}
auto left_order = left_const->get_axis_vector_val();
auto right_order = right_const->get_axis_vector_val();
auto left_mismatch = left_order != get_default_order(left.get_shape().size());
auto right_mismatch = right_order != get_default_order(right.get_shape().size());
auto left_rank = get_static_rank(left);
auto right_rank = get_static_rank(right);
auto left_mismatch = left_order != get_default_order(left_rank);
auto right_mismatch = right_order != get_default_order(right_rank);
OPENVINO_DEBUG << "Sink binary " << binary->get_name()
<< " left transpose: " << ov::util::vector_to_string(left_order)
<< " left default: " << ov::util::vector_to_string(get_default_order(left.get_shape().size()))
<< " left default: " << ov::util::vector_to_string(get_default_order(left_rank))
<< " right transpose: " << ov::util::vector_to_string(right_order)
<< " right default: " << ov::util::vector_to_string(get_default_order(right.get_shape().size()));
<< " right default: " << ov::util::vector_to_string(get_default_order(right_rank));
if ((left_order.size() == right_order.size() && left_order == right_order) || (!left_mismatch && !right_mismatch)) {
// Propagate the reshape which matches the shape of the binary node
@ -277,95 +305,115 @@ static void sink_binary(const shared_ptr<Node>& binary,
}
}
} catch (const std::exception&) {
throw std::runtime_error("");
return false;
}
}
return true;
}
static void sink_pad(shared_ptr<Pad> n, TransposeMap& reorders, set<shared_ptr<Node>>& /* transposes_to_delete */) {
static bool sink_pad(shared_ptr<Pad> n, TransposeMap& reorders, set<shared_ptr<Node>>& /* transposes_to_delete */) {
auto n_in = n->input_value(0);
auto arg_transpose = read_transposemap(reorders, n_in);
describe<Transpose>(arg_transpose);
auto arg_transpose_order = as_type_ptr<Constant>(arg_transpose->input_value(1).get_node_shared_ptr());
auto order = arg_transpose_order->get_axis_vector_val();
// we need the correct input shape to produce the right output shape
// we are going to create a label of the right input shape,
// so a new pad will have the right shape
auto def_order = permutation_to_default_order(order);
if (arg_transpose->get_output_partial_shape(0).is_static()) {
auto arg_transpose_order = as_type_ptr<Constant>(arg_transpose->input_value(1).get_node_shared_ptr());
auto order = arg_transpose_order->get_axis_vector_val();
// we need the correct input shape to produce the right output shape
// we are going to create a label of the right input shape,
// so a new pad will have the right shape
auto def_order = permutation_to_default_order(order);
auto input_shape = apply_permutation(arg_transpose->get_shape(), def_order);
auto input_shape = apply_permutation(arg_transpose->get_shape(), def_order);
auto dummy_correct_shape =
make_shared<ov::pass::pattern::op::Label>(arg_transpose->get_element_type(), input_shape);
auto dummy_correct_shape =
make_shared<ov::pass::pattern::op::Label>(arg_transpose->get_element_type(), input_shape);
auto pad_begin = apply_permutation(n->get_pads_begin(), def_order);
auto pad_end = apply_permutation(n->get_pads_end(), def_order);
auto pad_begin = apply_permutation(n->get_pads_begin(), def_order);
auto pad_end = apply_permutation(n->get_pads_end(), def_order);
auto new_begin = make_shared<Constant>(element::i64, Shape{pad_begin.size()}, pad_begin);
auto new_end = make_shared<Constant>(element::i64, Shape{pad_end.size()}, pad_end);
auto new_pad = make_shared<Pad>(dummy_correct_shape, new_begin, new_end, n->input_value(3), n->get_pad_mode());
replace_node(dummy_correct_shape, n->input_value(0).get_node_shared_ptr());
OPENVINO_DEBUG << "Replacing " << n->get_name() << " with " << new_pad->get_name();
replace_node(n, new_pad);
auto new_transpose = make_transpose(new_pad, order);
OPENVINO_DEBUG << "Propagating " << describe<Transpose>(new_transpose) << " for " << n->get_name();
write_transposemap(reorders, new_pad, new_transpose);
auto new_begin = make_shared<Constant>(element::i64, Shape{pad_begin.size()}, pad_begin);
auto new_end = make_shared<Constant>(element::i64, Shape{pad_end.size()}, pad_end);
auto new_pad = make_shared<Pad>(dummy_correct_shape, new_begin, new_end, n->input_value(3), n->get_pad_mode());
replace_node(dummy_correct_shape, n->input_value(0).get_node_shared_ptr());
OPENVINO_DEBUG << "Replacing " << n->get_name() << " with " << new_pad->get_name();
replace_node(n, new_pad);
auto new_transpose = make_transpose(new_pad, order);
OPENVINO_DEBUG << "Propagating " << describe<Transpose>(new_transpose) << " for " << n->get_name();
write_transposemap(reorders, new_pad, new_transpose);
} else {
OPENVINO_DEBUG << "TransposeSinking failed for Pad op " << n->get_name()
<< " . Output shape of Transpose op must be static.";
return false;
}
return true;
}
static void sink_concat(const shared_ptr<Concat>& n,
static bool sink_concat(const shared_ptr<Concat>& n,
TransposeMap& reorders,
set<shared_ptr<Node>>& transposes_to_delete) {
auto n_in = n->input_value(0);
auto arg_transpose = read_transposemap(reorders, n_in);
auto arg_transpose_order = as_type_ptr<Constant>(arg_transpose->input_value(1).get_node_shared_ptr());
auto order = arg_transpose_order->get_axis_vector_val();
// we need the correct input shape to produce the right output shape
// we are going to create a label of the right input shape,
// so a new concat will have the right shape
auto def_order = permutation_to_default_order(order);
if (arg_transpose->get_output_partial_shape(0).is_static()) {
auto arg_transpose_order = as_type_ptr<Constant>(arg_transpose->input_value(1).get_node_shared_ptr());
auto order = arg_transpose_order->get_axis_vector_val();
// we need the correct input shape to produce the right output shape
// we are going to create a label of the right input shape,
// so a new concat will have the right shape
auto def_order = permutation_to_default_order(order);
auto input_shape = apply_permutation(arg_transpose->get_shape(), def_order);
auto input_shape = apply_permutation(arg_transpose->get_shape(), def_order);
auto dummy_correct_shape =
make_shared<ov::pass::pattern::op::Label>(arg_transpose->get_element_type(), input_shape);
auto dummy_correct_shape =
make_shared<ov::pass::pattern::op::Label>(arg_transpose->get_element_type(), input_shape);
NodeVector new_args;
new_args.push_back(dummy_correct_shape);
NodeVector new_args;
new_args.push_back(dummy_correct_shape);
for (size_t i = 1; i < n->get_input_size(); i++) {
auto iarg = n->input_value(i);
auto iarg_transpose = read_transposemap(reorders, iarg);
auto iarg_transpose_order = as_type_ptr<Constant>(iarg_transpose->input_value(1).get_node_shared_ptr());
auto iorder = iarg_transpose_order->get_axis_vector_val();
if (iorder != order) {
OPENVINO_DEBUG << " input order at " << i << "-th arg is different from first arg";
materialize_shapes(n, reorders, transposes_to_delete);
return;
for (size_t i = 1; i < n->get_input_size(); i++) {
auto iarg = n->input_value(i);
auto iarg_transpose = read_transposemap(reorders, iarg);
auto iarg_transpose_order = as_type_ptr<Constant>(iarg_transpose->input_value(1).get_node_shared_ptr());
auto iorder = iarg_transpose_order->get_axis_vector_val();
if (iorder != order) {
OPENVINO_DEBUG << " input order at " << i << "-th arg is different from first arg";
return false;
}
if (iarg_transpose->get_output_partial_shape(0).is_dynamic()) {
OPENVINO_DEBUG << "TransposeSinking failed for Concat op " << n->get_name()
<< " . Input Transpose ops"
" must have static shapes. ";
return false;
}
auto iinput_shape = apply_permutation(iarg_transpose->get_shape(), def_order);
auto idummy_correct_shape =
make_shared<ov::pass::pattern::op::Label>(iarg_transpose->get_element_type(), iinput_shape);
new_args.push_back(idummy_correct_shape);
}
auto iinput_shape = apply_permutation(iarg_transpose->get_shape(), def_order);
auto idummy_correct_shape =
make_shared<ov::pass::pattern::op::Label>(iarg_transpose->get_element_type(), iinput_shape);
new_args.push_back(idummy_correct_shape);
auto new_axis = order.at(n->get_concatenation_axis());
auto new_concat = make_shared<Concat>(new_args, new_axis);
// put back the original arguments
for (size_t i = 0; i < new_concat->get_input_size(); i++) {
OPENVINO_DEBUG << "Replacing " << new_concat->get_name() << " input " << i << " with " << n->get_name()
<< " input " << i;
new_concat->input(i).replace_source_output(n->input_value(i));
}
OPENVINO_DEBUG << "Replacing " << n->get_name() << " with " << new_concat->get_name();
replace_node(n, new_concat);
auto new_transpose = make_transpose(new_concat, order);
OPENVINO_DEBUG << "Propagating " << describe<Transpose>(new_transpose) << " for " << n->get_name();
write_transposemap(reorders, new_concat, new_transpose);
} else {
OPENVINO_DEBUG << "TransposeSinking failed for Concat op " << n->get_name()
<< " . Output shape of Transpose op must be static.";
return false;
}
auto new_axis = order.at(n->get_concatenation_axis());
auto new_concat = make_shared<Concat>(new_args, new_axis);
// put back the original arguments
for (size_t i = 0; i < new_concat->get_input_size(); i++) {
OPENVINO_DEBUG << "Replacing " << new_concat->get_name() << " input " << i << " with " << n->get_name()
<< " input " << i;
new_concat->input(i).replace_source_output(n->input_value(i));
}
OPENVINO_DEBUG << "Replacing " << n->get_name() << " with " << new_concat->get_name();
replace_node(n, new_concat);
auto new_transpose = make_transpose(new_concat, order);
OPENVINO_DEBUG << "Propagating " << describe<Transpose>(new_transpose) << " for " << n->get_name();
write_transposemap(reorders, new_concat, new_transpose);
return true;
}
static void sink_prelu(const shared_ptr<PRelu>& prelu,
static bool sink_prelu(const shared_ptr<PRelu>& prelu,
TransposeMap& reorders,
set<shared_ptr<Node>>& transposes_to_delete) {
FRONT_END_GENERAL_CHECK(prelu, "Null pointer is given to PRelu node.");
@ -377,9 +425,9 @@ static void sink_prelu(const shared_ptr<PRelu>& prelu,
OPENVINO_DEBUG << "Propagating " << describe<Transpose>(arg_transpose) << " for " << prelu->get_name();
write_transposemap(reorders, prelu, arg_transpose);
} else {
// TODO: handle other cases with non-scalar slope
materialize_shapes(prelu, reorders, transposes_to_delete);
return false;
}
return true;
}
void purge_transposes(const set<shared_ptr<Node>>& transposes_to_delete) {
@ -400,7 +448,7 @@ void purge_transposes(const set<shared_ptr<Node>>& transposes_to_delete) {
bool ov::frontend::tensorflow::pass::TransposeSinking::run_on_model(const shared_ptr<Model>& f) {
TransposeMap reorders;
set<shared_ptr<Node>> transposes_to_delete;
unordered_map<std::string, Shape> orig_result_out_shape;
unordered_map<std::string, PartialShape> orig_result_out_shape;
// STEP 1 : Sink or Swim transposes away for op clusters
try {
@ -408,24 +456,28 @@ bool ov::frontend::tensorflow::pass::TransposeSinking::run_on_model(const shared
OPENVINO_DEBUG << "Processing " << n->get_name();
// collect output shape of all Result nodes for a sanity check
if (ov::op::util::is_output(n)) {
orig_result_out_shape[n->get_name()] = n->get_output_shape(0);
orig_result_out_shape[n->get_name()] = n->get_output_partial_shape(0);
}
bool sink_res = false;
if (auto transpose = as_type_ptr<opset8::Transpose>(n)) {
sink_transpose(transpose, reorders, transposes_to_delete);
sink_res = sink_transpose(transpose, reorders, transposes_to_delete);
} else if (ov::op::util::is_unary_elementwise_arithmetic(n) || as_type_ptr<Clamp>(n) ||
as_type_ptr<Elu>(n) || as_type_ptr<SoftPlus>(n) || as_type_ptr<LogicalNot>(n)) {
// Some unary operations are inherrited from Op class
// so we need explicitly to check them
sink_unary(n, reorders, transposes_to_delete);
sink_res = sink_unary(n, reorders, transposes_to_delete);
} else if (ov::op::util::is_binary_elementwise_arithmetic(n)) {
sink_binary(n, reorders, transposes_to_delete);
sink_res = sink_binary(n, reorders, transposes_to_delete);
} else if (auto pad = as_type_ptr<Pad>(n)) {
sink_pad(pad, reorders, transposes_to_delete);
sink_res = sink_pad(pad, reorders, transposes_to_delete);
} else if (auto concat = as_type_ptr<Concat>(n)) {
sink_concat(concat, reorders, transposes_to_delete);
sink_res = sink_concat(concat, reorders, transposes_to_delete);
} else if (auto prelu = as_type_ptr<PRelu>(n)) {
sink_prelu(prelu, reorders, transposes_to_delete);
} else {
sink_res = sink_prelu(prelu, reorders, transposes_to_delete);
}
if (!sink_res) {
materialize_shapes(n, reorders, transposes_to_delete);
}
}
@ -448,16 +500,16 @@ bool ov::frontend::tensorflow::pass::TransposeSinking::run_on_model(const shared
const ResultVector& results = f->get_results();
for (const auto& r : results) {
// make sure shapes are always materialized before results
FRONT_END_GENERAL_CHECK(
r->get_shape() == r->get_input_shape(0) && r->get_element_type() == r->input_value(0).get_element_type(),
" op::Result = ",
*r,
", Arg = ",
r->input_value(0).get_node());
FRONT_END_GENERAL_CHECK(r->get_output_partial_shape(0) == r->get_input_partial_shape(0) &&
r->get_element_type() == r->input_value(0).get_element_type(),
" op::Result = ",
*r,
", Arg = ",
r->input_value(0).get_node());
// make sure that after TransposeSinking pass the output_shape for Result
// does not change from the expected output_shape before the pass
FRONT_END_GENERAL_CHECK(r->get_output_shape(0) == orig_result_out_shape[r->get_name()],
FRONT_END_GENERAL_CHECK(r->get_output_partial_shape(0) == orig_result_out_shape[r->get_name()],
" op::Result = ",
*r,
" expected output shape = ",

View File

@ -26,10 +26,25 @@ int64_t count_ops_of_type(const shared_ptr<Model>& f) {
return cnt;
}
TEST(TransposeSinkingTest, PassProperty) {
auto pass = std::make_shared<TransposeSinking>();
ASSERT_TRUE(pass->get_property(ov::pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_FALSE(pass->get_property(ov::pass::PassProperty::CHANGE_DYNAMIC_STATE));
TEST(TransposeSinkingTest, DynamicShape) {
ov::PartialShape shape_nhwc(vector<Dimension>(4, Dimension::dynamic()));
auto a = make_shared<Parameter>(ngraph::element::i32, shape_nhwc);
auto ng_order = std::make_shared<Constant>(ngraph::element::u64, ngraph::Shape{4}, ngraph::Shape{0, 3, 1, 2});
auto transpose = make_shared<Transpose>(a, ng_order);
auto absn = make_shared<Abs>(transpose);
auto absn2 = make_shared<Abs>(absn);
absn2->output(0).set_names({"out_name"});
auto res = make_shared<Result>(absn2);
auto func = make_shared<ngraph::Function>(ngraph::OutputVector{res}, ngraph::ParameterVector{a});
ov::pass::Manager pass_manager;
pass_manager.register_pass<TransposeSinking>();
pass_manager.run_passes(func);
auto new_transpose =
ngraph::as_type_ptr<Transpose>(func->get_results().at(0)->input_value(0).get_node_shared_ptr());
ASSERT_TRUE(new_transpose);
EXPECT_EQ(new_transpose->output(0).get_names(), std::unordered_set<std::string>({"out_name"}));
}
TEST(TransposeSinkingTest, TensorNames) {