Check if input of Unsqueeze is parameter during NopElimination (#1622)
This commit is contained in:
parent
ae48d9deb8
commit
e88c7b5ed7
@ -254,7 +254,6 @@ static bool eliminate_unsqueeze(const std::shared_ptr<Node>& node)
|
||||
|
||||
auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(node);
|
||||
auto input = unsqueeze->input_value(0).get_node_shared_ptr();
|
||||
auto data_shape = input->input_value(0).get_partial_shape();
|
||||
auto squeeze = as_type_ptr<opset3::Squeeze>(input);
|
||||
auto replace_unsqueeze_only = [&](const vector<int64_t>& axes) {
|
||||
auto axes_const = opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
|
||||
@ -269,6 +268,7 @@ static bool eliminate_unsqueeze(const std::shared_ptr<Node>& node)
|
||||
// eliminate redundant squeeze->unsqueeze
|
||||
if (squeeze)
|
||||
{
|
||||
const auto& data_shape = squeeze->input_value(0).get_partial_shape();
|
||||
if (ngraph::compare_constants(squeeze->input_value(1).get_node_shared_ptr(),
|
||||
unsqueeze->input_value(1).get_node_shared_ptr()))
|
||||
{
|
||||
@ -310,6 +310,7 @@ static bool eliminate_unsqueeze(const std::shared_ptr<Node>& node)
|
||||
auto unsqueeze_i = as_type_ptr<opset3::Unsqueeze>(input);
|
||||
if (unsqueeze_i)
|
||||
{
|
||||
const auto& data_shape = unsqueeze_i->input_value(0).get_partial_shape();
|
||||
if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic())
|
||||
{
|
||||
return false;
|
||||
|
@ -225,6 +225,22 @@ TEST(nop_elimination, concat_elimination_single_input_dynamic)
|
||||
ASSERT_EQ(count_ops_of_type<op::v0::Concat>(f), 0);
|
||||
}
|
||||
|
||||
TEST(nop_elimination, unsqueeze_elimination)
|
||||
{
|
||||
const auto axis = op::Constant::create<int64_t>(element::i64, {}, {0});
|
||||
const auto A = make_shared<op::Parameter>(
|
||||
element::f32, PartialShape{3, Dimension::dynamic(), Dimension::dynamic()});
|
||||
const auto unsqueeze = make_shared<op::v0::Unsqueeze>(A, axis);
|
||||
auto f = make_shared<Function>(unsqueeze, ParameterVector{A});
|
||||
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::Validate>();
|
||||
pass_manager.register_pass<pass::NopElimination>();
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(f), 1);
|
||||
}
|
||||
|
||||
TEST(nop_elimination, squeeze_unsqueeze_overlap_elimination)
|
||||
{
|
||||
auto check_usecase = [](const PartialShape& shape,
|
||||
|
Loading…
Reference in New Issue
Block a user