Check if input of Unsqueeze is parameter during NopElimination (#1622)

This commit is contained in:
Mateusz Bencer 2020-08-10 13:45:58 +02:00 committed by GitHub
parent ae48d9deb8
commit e88c7b5ed7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 1 deletions

View File

@ -254,7 +254,6 @@ static bool eliminate_unsqueeze(const std::shared_ptr<Node>& node)
auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(node);
auto input = unsqueeze->input_value(0).get_node_shared_ptr();
auto data_shape = input->input_value(0).get_partial_shape();
auto squeeze = as_type_ptr<opset3::Squeeze>(input);
auto replace_unsqueeze_only = [&](const vector<int64_t>& axes) {
auto axes_const = opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
@ -269,6 +268,7 @@ static bool eliminate_unsqueeze(const std::shared_ptr<Node>& node)
// eliminate redundant squeeze->unsqueeze
if (squeeze)
{
const auto& data_shape = squeeze->input_value(0).get_partial_shape();
if (ngraph::compare_constants(squeeze->input_value(1).get_node_shared_ptr(),
unsqueeze->input_value(1).get_node_shared_ptr()))
{
@ -310,6 +310,7 @@ static bool eliminate_unsqueeze(const std::shared_ptr<Node>& node)
auto unsqueeze_i = as_type_ptr<opset3::Unsqueeze>(input);
if (unsqueeze_i)
{
const auto& data_shape = unsqueeze_i->input_value(0).get_partial_shape();
if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic())
{
return false;

View File

@ -225,6 +225,22 @@ TEST(nop_elimination, concat_elimination_single_input_dynamic)
ASSERT_EQ(count_ops_of_type<op::v0::Concat>(f), 0);
}
TEST(nop_elimination, unsqueeze_elimination)
{
const auto axis = op::Constant::create<int64_t>(element::i64, {}, {0});
const auto A = make_shared<op::Parameter>(
element::f32, PartialShape{3, Dimension::dynamic(), Dimension::dynamic()});
const auto unsqueeze = make_shared<op::v0::Unsqueeze>(A, axis);
auto f = make_shared<Function>(unsqueeze, ParameterVector{A});
pass::Manager pass_manager;
pass_manager.register_pass<pass::Validate>();
pass_manager.register_pass<pass::NopElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(f), 1);
}
TEST(nop_elimination, squeeze_unsqueeze_overlap_elimination)
{
auto check_usecase = [](const PartialShape& shape,