Fixed replace_output_update_name method to preserve rt info (#1983)

* Fixed replace_output_update_name method to preserve rt info

* added test
This commit is contained in:
Gleb Kazantaev 2020-09-01 17:02:46 +03:00 committed by GitHub
parent 25856f4cdc
commit 34b595d218
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 51 additions and 6 deletions

View File

@ -93,15 +93,20 @@ static bool eliminate_reshape_v1(const std::shared_ptr<Node>& node) {
return replace_output_update_name(node->output(0), input);
}
// eliminate redundant reshape, squeeze, or unsqueeze
if (is_type<opset3::Squeeze>(input.get_node()) ||
is_type<opset3::Unsqueeze>(input.get_node()) || is_type<opset3::Reshape>(input.get_node())) {
auto input_node = input.get_node_shared_ptr();
if (as_type_ptr<opset3::Squeeze>(input_node) ||
as_type_ptr<opset3::Unsqueeze>(input_node) ||
as_type_ptr<opset3::Reshape>(input_node)) {
auto shape = node->get_output_shape(0);
std::vector<int64_t> vi;
vi.assign(shape.begin(), shape.end());
auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
auto new_reshape =
make_shared<opset3::Reshape>(input.get_node()->input_value(0), pat, false);
return replace_node_update_name(node, new_reshape);
new_reshape->set_friendly_name(node->get_friendly_name());
copy_runtime_info({input_node, node}, new_reshape);
replace_node(node, new_reshape);
return true;
}
return false;

View File

@ -35,7 +35,7 @@ std::vector<std::string> FusedNames::getVectorNames() const {
}
void FusedNames::fuseWith(const FusedNames &names) {
for (auto name : names.fused_names) {
for (const auto & name : names.fused_names) {
fused_names.insert(name);
}
}

View File

@ -17,6 +17,7 @@
#include <transformations/common_optimizations/nop_elimination.hpp>
#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/rt_info/fused_names_attribute.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
@ -139,6 +140,45 @@ TEST(nop_elimination, reshape_elimination_v1) {
ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(func_zero) == 1);
}
TEST(nop_elimination, squeeze_reshape_elimination_check_info) {
std::shared_ptr<Function> f;
{
auto arg = std::make_shared<opset4::Parameter>(element::f32, PartialShape{8, 16, 1, 3});
auto relu = std::make_shared<opset4::Relu>(arg);
relu->set_friendly_name("relu");
auto squeeze_axes = opset4::Constant::create(element::i64, Shape{1}, {2});
auto squeeze = std::make_shared<opset4::Squeeze>(relu, squeeze_axes);
squeeze->set_friendly_name("squeeze");
auto reshape_shape = opset4::Constant::create(element::i64, Shape{4}, {8, 16, 1, 3});
auto reshape = std::make_shared<opset4::Reshape>(squeeze, reshape_shape, false);
reshape->set_friendly_name("reshape");
auto abs = std::make_shared<opset4::Abs>(reshape);
f = std::make_shared<Function>(NodeVector{abs}, ParameterVector{arg});
}
pass::Manager pass_manager;
pass_manager.register_pass<pass::InitNodeInfo>();
pass_manager.register_pass<pass::NopElimination>();
pass_manager.run_passes(f);
bool reshape_is_missing = true;
for (auto node : f->get_ops()) {
if (node->get_friendly_name() == "reshape") {
reshape_is_missing = false;
ASSERT_TRUE(std::dynamic_pointer_cast<opset4::Reshape>(node));
auto original_names = getFusedNamesVector(node);
sort(original_names.begin(), original_names.end());
ASSERT_EQ(original_names, std::vector<std::string>({"reshape", "squeeze"}));
}
}
ASSERT_FALSE(reshape_is_missing);
}
TEST(nop_elimination, reshape_elimination_v1_dynamic) {
auto arg = std::make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto pattern = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));

View File

@ -902,10 +902,10 @@ bool ngraph::replace_output_update_name(Output<Node> output, const Output<Node>&
if (has_result_output && !is_type<ngraph::op::Parameter>(replacement.get_node()))
{
replacement.get_node()->set_friendly_name(output.get_node()->get_friendly_name());
copy_runtime_info({replacement.get_node_shared_ptr(), output.get_node_shared_ptr()},
replacement.get_node_shared_ptr());
}
output.replace(replacement);
copy_runtime_info({replacement.get_node_shared_ptr(), output.get_node_shared_ptr()},
replacement.get_node_shared_ptr());
return true;
}
return false;