diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/nop_elimination.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/nop_elimination.cpp index 21a5d0dbe76..9bd439e42ce 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/nop_elimination.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/nop_elimination.cpp @@ -93,15 +93,20 @@ static bool eliminate_reshape_v1(const std::shared_ptr& node) { return replace_output_update_name(node->output(0), input); } // eliminate redundant reshape, squeeze, or unsqueeze - if (is_type(input.get_node()) || - is_type(input.get_node()) || is_type(input.get_node())) { + auto input_node = input.get_node_shared_ptr(); + if (as_type_ptr(input_node) || + as_type_ptr(input_node) || + as_type_ptr(input_node)) { auto shape = node->get_output_shape(0); std::vector vi; vi.assign(shape.begin(), shape.end()); auto pat = opset3::Constant::create(element::i64, Shape{vi.size()}, vi); auto new_reshape = make_shared(input.get_node()->input_value(0), pat, false); - return replace_node_update_name(node, new_reshape); + new_reshape->set_friendly_name(node->get_friendly_name()); + copy_runtime_info({input_node, node}, new_reshape); + replace_node(node, new_reshape); + return true; } return false; diff --git a/inference-engine/src/transformations/src/transformations/rt_info/fused_names_attribute.cpp b/inference-engine/src/transformations/src/transformations/rt_info/fused_names_attribute.cpp index eef8b4b1271..cd3857b0d58 100644 --- a/inference-engine/src/transformations/src/transformations/rt_info/fused_names_attribute.cpp +++ b/inference-engine/src/transformations/src/transformations/rt_info/fused_names_attribute.cpp @@ -35,7 +35,7 @@ std::vector FusedNames::getVectorNames() const { } void FusedNames::fuseWith(const FusedNames &names) { - for (auto name : names.fused_names) { + for (const auto & name : names.fused_names) { fused_names.insert(name); } } diff --git a/inference-engine/tests/functional/inference_engine/transformations/nop_elimination.cpp b/inference-engine/tests/functional/inference_engine/transformations/nop_elimination.cpp index 71ff51c2936..3b2befc66d2 100644 --- a/inference-engine/tests/functional/inference_engine/transformations/nop_elimination.cpp +++ b/inference-engine/tests/functional/inference_engine/transformations/nop_elimination.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include "common_test_utils/ngraph_test_utils.hpp" @@ -139,6 +140,45 @@ TEST(nop_elimination, reshape_elimination_v1) { ASSERT_TRUE(count_ops_of_type(func_zero) == 1); } +TEST(nop_elimination, squeeze_reshape_elimination_check_info) { + std::shared_ptr f; + { + auto arg = std::make_shared(element::f32, PartialShape{8, 16, 1, 3}); + + auto relu = std::make_shared(arg); + relu->set_friendly_name("relu"); + + auto squeeze_axes = opset4::Constant::create(element::i64, Shape{1}, {2}); + auto squeeze = std::make_shared(relu, squeeze_axes); + squeeze->set_friendly_name("squeeze"); + + auto reshape_shape = opset4::Constant::create(element::i64, Shape{4}, {8, 16, 1, 3}); + auto reshape = std::make_shared(squeeze, reshape_shape, false); + reshape->set_friendly_name("reshape"); + + auto abs = std::make_shared(reshape); + + f = std::make_shared(NodeVector{abs}, ParameterVector{arg}); + } + + pass::Manager pass_manager; + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.run_passes(f); + + bool reshape_is_missing = true; + for (auto node : f->get_ops()) { + if (node->get_friendly_name() == "reshape") { + reshape_is_missing = false; + ASSERT_TRUE(std::dynamic_pointer_cast(node)); + auto original_names = getFusedNamesVector(node); + sort(original_names.begin(), original_names.end()); + ASSERT_EQ(original_names, std::vector({"reshape", "squeeze"})); + } + } + ASSERT_FALSE(reshape_is_missing); +} + TEST(nop_elimination, reshape_elimination_v1_dynamic) { auto arg = std::make_shared(element::i64, PartialShape::dynamic()); auto pattern = make_shared(element::i64, PartialShape::dynamic(1)); diff --git a/ngraph/core/src/graph_util.cpp b/ngraph/core/src/graph_util.cpp index 604b902fb3f..d9c0f44d6b4 100644 --- a/ngraph/core/src/graph_util.cpp +++ b/ngraph/core/src/graph_util.cpp @@ -902,10 +902,10 @@ bool ngraph::replace_output_update_name(Output output, const Output& if (has_result_output && !is_type(replacement.get_node())) { replacement.get_node()->set_friendly_name(output.get_node()->get_friendly_name()); - copy_runtime_info({replacement.get_node_shared_ptr(), output.get_node_shared_ptr()}, - replacement.get_node_shared_ptr()); } output.replace(replacement); + copy_runtime_info({replacement.get_node_shared_ptr(), output.get_node_shared_ptr()}, + replacement.get_node_shared_ptr()); return true; } return false;