[TF FE] Port changes from 2022.2 branch - Add Transpose Sinking for Prelu operation (#12848)

* [TF FE] Add Transpose Sinking for Prelu operation

Now it covers a case with a scalar slope.

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Add unit-tests for Transpose sinking of Prelu

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Fix non-scalar slope case

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2022-09-01 09:29:17 +03:00 committed by GitHub
parent 4b104f0a13
commit 3856d69ae1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 71 additions and 0 deletions

View File

@ -348,6 +348,23 @@ static void sink_concat(const shared_ptr<Concat>& n,
write_transposemap(reorders, new_concat, new_transpose);
}
static void sink_prelu(const shared_ptr<PRelu>& prelu,
TransposeMap& reorders,
set<shared_ptr<Node>>& transposes_to_delete) {
FRONT_END_GENERAL_CHECK(prelu, "Null pointer is given to PRelu node.");
FRONT_END_GENERAL_CHECK(prelu->get_input_size() > 1, "The PRelu node must contain at least two inputs.");
auto slope_shape = prelu->input_value(1).get_partial_shape();
if (slope_shape.is_static() && shape_size(slope_shape.to_shape()) == 1) {
// handle a case covering LeakyRelu decomposition
auto arg_transpose = read_transposemap(reorders, prelu->input_value(0));
OPENVINO_DEBUG << "Propagating " << describe<Transpose>(arg_transpose) << " for " << prelu->get_name();
write_transposemap(reorders, prelu, arg_transpose);
} else {
// TODO: handle other cases with non-scalar slope
materialize_shapes(prelu, reorders, transposes_to_delete);
}
}
// The goal of TransposeSinking is to remove
// round-trip transposes(i.e. nhwc->nchw(nchw-only-op)->nhwc)
// around nchw-only-op (e.g.Convolution, Batchnorm, Avg/MaxPool)
@ -383,6 +400,8 @@ bool ov::frontend::tensorflow::pass::TransposeSinking::run_on_model(const shared
sink_pad(pad, reorders, transposes_to_delete);
} else if (auto concat = as_type_ptr<Concat>(n)) {
sink_concat(concat, reorders, transposes_to_delete);
} else if (auto prelu = as_type_ptr<PRelu>(n)) {
sink_prelu(prelu, reorders, transposes_to_delete);
} else {
materialize_shapes(n, reorders, transposes_to_delete);
}

View File

@ -343,6 +343,58 @@ TEST(TransposeSinkingTest, SimpleUnary) {
EXPECT_EQ(after_count, 0);
}
TEST(TransposeSinkingTest, SinkingThroughPreLUWithScalarSlope) {
auto input = make_shared<Parameter>(ov::element::f32, ov::Shape{1, 105, 30, 30});
auto transpose_before =
make_shared<Transpose>(input,
make_shared<Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 2, 3, 1}));
auto prelu = make_shared<PRelu>(transpose_before,
make_shared<Constant>(ov::element::f32, ov::Shape{1}, std::vector<float>{0.8}));
auto transpose_after =
make_shared<Transpose>(prelu,
make_shared<Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 3, 1, 2}));
auto model = make_shared<ov::Model>(ov::OutputVector{transpose_after}, ov::ParameterVector{input});
size_t before_count = count_ops_of_type<Transpose>(model);
ov::pass::Manager pass_manager;
pass_manager.register_pass<TransposeSinking>();
pass_manager.run_passes(model);
size_t after_count = count_ops_of_type<Transpose>(model);
EXPECT_EQ(before_count, 2);
EXPECT_EQ(after_count, 0);
}
TEST(TransposeSinkingTest, SinkingThroughPreLUWithNonScalarSlope) {
auto input = make_shared<Parameter>(ov::element::f32, ov::Shape{1, 3, 3, 3});
auto transpose_before =
make_shared<Transpose>(input,
make_shared<Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 2, 3, 1}));
auto prelu =
make_shared<PRelu>(transpose_before,
make_shared<Constant>(ov::element::f32, ov::Shape{3}, std::vector<float>{0.8, 0.7, 0.1}));
auto transpose_after =
make_shared<Transpose>(prelu,
make_shared<Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 3, 1, 2}));
auto model = make_shared<ov::Model>(ov::OutputVector{transpose_after}, ov::ParameterVector{input});
size_t before_count = count_ops_of_type<Transpose>(model);
ov::pass::Manager pass_manager;
pass_manager.register_pass<TransposeSinking>();
pass_manager.run_passes(model);
size_t after_count = count_ops_of_type<Transpose>(model);
EXPECT_EQ(before_count, 2);
// Now Transpose Sinking is not applied to Prelu with non-scalar slope
EXPECT_EQ(after_count, 2);
}
// X (NCHW)
// |
// Transpose1