Fixed replace_output_update_name method to preserve rt info (#1983)
* Fixed replace_output_update_name method to preserve rt info * added test
This commit is contained in:
parent
25856f4cdc
commit
34b595d218
@ -93,15 +93,20 @@ static bool eliminate_reshape_v1(const std::shared_ptr<Node>& node) {
|
|||||||
return replace_output_update_name(node->output(0), input);
|
return replace_output_update_name(node->output(0), input);
|
||||||
}
|
}
|
||||||
// eliminate redundant reshape, squeeze, or unsqueeze
|
// eliminate redundant reshape, squeeze, or unsqueeze
|
||||||
if (is_type<opset3::Squeeze>(input.get_node()) ||
|
auto input_node = input.get_node_shared_ptr();
|
||||||
is_type<opset3::Unsqueeze>(input.get_node()) || is_type<opset3::Reshape>(input.get_node())) {
|
if (as_type_ptr<opset3::Squeeze>(input_node) ||
|
||||||
|
as_type_ptr<opset3::Unsqueeze>(input_node) ||
|
||||||
|
as_type_ptr<opset3::Reshape>(input_node)) {
|
||||||
auto shape = node->get_output_shape(0);
|
auto shape = node->get_output_shape(0);
|
||||||
std::vector<int64_t> vi;
|
std::vector<int64_t> vi;
|
||||||
vi.assign(shape.begin(), shape.end());
|
vi.assign(shape.begin(), shape.end());
|
||||||
auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
|
auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
|
||||||
auto new_reshape =
|
auto new_reshape =
|
||||||
make_shared<opset3::Reshape>(input.get_node()->input_value(0), pat, false);
|
make_shared<opset3::Reshape>(input.get_node()->input_value(0), pat, false);
|
||||||
return replace_node_update_name(node, new_reshape);
|
new_reshape->set_friendly_name(node->get_friendly_name());
|
||||||
|
copy_runtime_info({input_node, node}, new_reshape);
|
||||||
|
replace_node(node, new_reshape);
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
|
@ -35,7 +35,7 @@ std::vector<std::string> FusedNames::getVectorNames() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void FusedNames::fuseWith(const FusedNames &names) {
|
void FusedNames::fuseWith(const FusedNames &names) {
|
||||||
for (auto name : names.fused_names) {
|
for (const auto & name : names.fused_names) {
|
||||||
fused_names.insert(name);
|
fused_names.insert(name);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
#include <transformations/common_optimizations/nop_elimination.hpp>
|
#include <transformations/common_optimizations/nop_elimination.hpp>
|
||||||
#include <transformations/utils/utils.hpp>
|
#include <transformations/utils/utils.hpp>
|
||||||
#include <transformations/init_node_info.hpp>
|
#include <transformations/init_node_info.hpp>
|
||||||
|
#include <transformations/rt_info/fused_names_attribute.hpp>
|
||||||
|
|
||||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
|
||||||
@ -139,6 +140,45 @@ TEST(nop_elimination, reshape_elimination_v1) {
|
|||||||
ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(func_zero) == 1);
|
ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(func_zero) == 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(nop_elimination, squeeze_reshape_elimination_check_info) {
|
||||||
|
std::shared_ptr<Function> f;
|
||||||
|
{
|
||||||
|
auto arg = std::make_shared<opset4::Parameter>(element::f32, PartialShape{8, 16, 1, 3});
|
||||||
|
|
||||||
|
auto relu = std::make_shared<opset4::Relu>(arg);
|
||||||
|
relu->set_friendly_name("relu");
|
||||||
|
|
||||||
|
auto squeeze_axes = opset4::Constant::create(element::i64, Shape{1}, {2});
|
||||||
|
auto squeeze = std::make_shared<opset4::Squeeze>(relu, squeeze_axes);
|
||||||
|
squeeze->set_friendly_name("squeeze");
|
||||||
|
|
||||||
|
auto reshape_shape = opset4::Constant::create(element::i64, Shape{4}, {8, 16, 1, 3});
|
||||||
|
auto reshape = std::make_shared<opset4::Reshape>(squeeze, reshape_shape, false);
|
||||||
|
reshape->set_friendly_name("reshape");
|
||||||
|
|
||||||
|
auto abs = std::make_shared<opset4::Abs>(reshape);
|
||||||
|
|
||||||
|
f = std::make_shared<Function>(NodeVector{abs}, ParameterVector{arg});
|
||||||
|
}
|
||||||
|
|
||||||
|
pass::Manager pass_manager;
|
||||||
|
pass_manager.register_pass<pass::InitNodeInfo>();
|
||||||
|
pass_manager.register_pass<pass::NopElimination>();
|
||||||
|
pass_manager.run_passes(f);
|
||||||
|
|
||||||
|
bool reshape_is_missing = true;
|
||||||
|
for (auto node : f->get_ops()) {
|
||||||
|
if (node->get_friendly_name() == "reshape") {
|
||||||
|
reshape_is_missing = false;
|
||||||
|
ASSERT_TRUE(std::dynamic_pointer_cast<opset4::Reshape>(node));
|
||||||
|
auto original_names = getFusedNamesVector(node);
|
||||||
|
sort(original_names.begin(), original_names.end());
|
||||||
|
ASSERT_EQ(original_names, std::vector<std::string>({"reshape", "squeeze"}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ASSERT_FALSE(reshape_is_missing);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(nop_elimination, reshape_elimination_v1_dynamic) {
|
TEST(nop_elimination, reshape_elimination_v1_dynamic) {
|
||||||
auto arg = std::make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
auto arg = std::make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||||
auto pattern = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
|
auto pattern = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
|
||||||
|
@ -902,10 +902,10 @@ bool ngraph::replace_output_update_name(Output<Node> output, const Output<Node>&
|
|||||||
if (has_result_output && !is_type<ngraph::op::Parameter>(replacement.get_node()))
|
if (has_result_output && !is_type<ngraph::op::Parameter>(replacement.get_node()))
|
||||||
{
|
{
|
||||||
replacement.get_node()->set_friendly_name(output.get_node()->get_friendly_name());
|
replacement.get_node()->set_friendly_name(output.get_node()->get_friendly_name());
|
||||||
copy_runtime_info({replacement.get_node_shared_ptr(), output.get_node_shared_ptr()},
|
|
||||||
replacement.get_node_shared_ptr());
|
|
||||||
}
|
}
|
||||||
output.replace(replacement);
|
output.replace(replacement);
|
||||||
|
copy_runtime_info({replacement.get_node_shared_ptr(), output.get_node_shared_ptr()},
|
||||||
|
replacement.get_node_shared_ptr());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
Loading…
Reference in New Issue
Block a user