Fixed replace_output_update_name method to preserve rt info (#1983)
* Fixed replace_output_update_name method to preserve rt info * added test
This commit is contained in:
parent
25856f4cdc
commit
34b595d218
@ -93,15 +93,20 @@ static bool eliminate_reshape_v1(const std::shared_ptr<Node>& node) {
|
||||
return replace_output_update_name(node->output(0), input);
|
||||
}
|
||||
// eliminate redundant reshape, squeeze, or unsqueeze
|
||||
if (is_type<opset3::Squeeze>(input.get_node()) ||
|
||||
is_type<opset3::Unsqueeze>(input.get_node()) || is_type<opset3::Reshape>(input.get_node())) {
|
||||
auto input_node = input.get_node_shared_ptr();
|
||||
if (as_type_ptr<opset3::Squeeze>(input_node) ||
|
||||
as_type_ptr<opset3::Unsqueeze>(input_node) ||
|
||||
as_type_ptr<opset3::Reshape>(input_node)) {
|
||||
auto shape = node->get_output_shape(0);
|
||||
std::vector<int64_t> vi;
|
||||
vi.assign(shape.begin(), shape.end());
|
||||
auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
|
||||
auto new_reshape =
|
||||
make_shared<opset3::Reshape>(input.get_node()->input_value(0), pat, false);
|
||||
return replace_node_update_name(node, new_reshape);
|
||||
new_reshape->set_friendly_name(node->get_friendly_name());
|
||||
copy_runtime_info({input_node, node}, new_reshape);
|
||||
replace_node(node, new_reshape);
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
|
@ -35,7 +35,7 @@ std::vector<std::string> FusedNames::getVectorNames() const {
|
||||
}
|
||||
|
||||
void FusedNames::fuseWith(const FusedNames &names) {
|
||||
for (auto name : names.fused_names) {
|
||||
for (const auto & name : names.fused_names) {
|
||||
fused_names.insert(name);
|
||||
}
|
||||
}
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include <transformations/common_optimizations/nop_elimination.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/rt_info/fused_names_attribute.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
@ -139,6 +140,45 @@ TEST(nop_elimination, reshape_elimination_v1) {
|
||||
ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(func_zero) == 1);
|
||||
}
|
||||
|
||||
TEST(nop_elimination, squeeze_reshape_elimination_check_info) {
|
||||
std::shared_ptr<Function> f;
|
||||
{
|
||||
auto arg = std::make_shared<opset4::Parameter>(element::f32, PartialShape{8, 16, 1, 3});
|
||||
|
||||
auto relu = std::make_shared<opset4::Relu>(arg);
|
||||
relu->set_friendly_name("relu");
|
||||
|
||||
auto squeeze_axes = opset4::Constant::create(element::i64, Shape{1}, {2});
|
||||
auto squeeze = std::make_shared<opset4::Squeeze>(relu, squeeze_axes);
|
||||
squeeze->set_friendly_name("squeeze");
|
||||
|
||||
auto reshape_shape = opset4::Constant::create(element::i64, Shape{4}, {8, 16, 1, 3});
|
||||
auto reshape = std::make_shared<opset4::Reshape>(squeeze, reshape_shape, false);
|
||||
reshape->set_friendly_name("reshape");
|
||||
|
||||
auto abs = std::make_shared<opset4::Abs>(reshape);
|
||||
|
||||
f = std::make_shared<Function>(NodeVector{abs}, ParameterVector{arg});
|
||||
}
|
||||
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::InitNodeInfo>();
|
||||
pass_manager.register_pass<pass::NopElimination>();
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
bool reshape_is_missing = true;
|
||||
for (auto node : f->get_ops()) {
|
||||
if (node->get_friendly_name() == "reshape") {
|
||||
reshape_is_missing = false;
|
||||
ASSERT_TRUE(std::dynamic_pointer_cast<opset4::Reshape>(node));
|
||||
auto original_names = getFusedNamesVector(node);
|
||||
sort(original_names.begin(), original_names.end());
|
||||
ASSERT_EQ(original_names, std::vector<std::string>({"reshape", "squeeze"}));
|
||||
}
|
||||
}
|
||||
ASSERT_FALSE(reshape_is_missing);
|
||||
}
|
||||
|
||||
TEST(nop_elimination, reshape_elimination_v1_dynamic) {
|
||||
auto arg = std::make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
auto pattern = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
|
||||
|
@ -902,10 +902,10 @@ bool ngraph::replace_output_update_name(Output<Node> output, const Output<Node>&
|
||||
if (has_result_output && !is_type<ngraph::op::Parameter>(replacement.get_node()))
|
||||
{
|
||||
replacement.get_node()->set_friendly_name(output.get_node()->get_friendly_name());
|
||||
copy_runtime_info({replacement.get_node_shared_ptr(), output.get_node_shared_ptr()},
|
||||
replacement.get_node_shared_ptr());
|
||||
}
|
||||
output.replace(replacement);
|
||||
copy_runtime_info({replacement.get_node_shared_ptr(), output.get_node_shared_ptr()},
|
||||
replacement.get_node_shared_ptr());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
Loading…
Reference in New Issue
Block a user