* Fix node name issue introduced by #5854 * Compare names in TransposeFuse tests
This commit is contained in:
parent
226cb952ae
commit
14bbf7f5e2
@ -212,8 +212,9 @@ ngraph::pass::TransposeFuse::TransposeFuse() {
|
||||
auto new_order = ngraph::opset7::Constant::create(element::i64, {order2.size()}, order2);
|
||||
auto new_transpose = register_new_node<ngraph::opset7::Transpose>(input, new_order);
|
||||
|
||||
new_transpose->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::copy_runtime_info({ transpose1, transpose2 }, new_transpose);
|
||||
ngraph::replace_node(transpose2, new_transpose);
|
||||
ngraph::replace_node(m.get_match_root(), new_transpose);
|
||||
}
|
||||
|
||||
return true;
|
||||
|
@ -239,10 +239,13 @@ TEST(TransformationTests, TransposeFuses) {
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 640, 20, 2, 2 });
|
||||
auto tr1_order = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 6 }, { 0, 5, 1, 2, 3, 4 });
|
||||
auto transpose1 = std::make_shared<ngraph::opset6::Transpose>(input, tr1_order);
|
||||
transpose1->set_friendly_name("transpose1");
|
||||
auto tr2_order = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 6 }, { 0, 1, 3, 4, 2, 5 });
|
||||
auto transpose2 = std::make_shared<ngraph::opset6::Transpose>(transpose1, tr2_order);
|
||||
transpose2->set_friendly_name("transpose2");
|
||||
auto add_const = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 1 }, { 1 });
|
||||
auto add = std::make_shared<ngraph::opset6::Add>(transpose2, add_const);
|
||||
add->set_friendly_name("add");
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input });
|
||||
|
||||
@ -257,12 +260,15 @@ TEST(TransformationTests, TransposeFuses) {
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 2, 640, 20, 2, 2 });
|
||||
auto tr_order = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 6 }, { 0, 5, 2, 3, 1, 4 });
|
||||
auto transpose = std::make_shared<ngraph::opset6::Transpose>(input, tr_order);
|
||||
transpose->set_friendly_name("transpose2");
|
||||
auto add_const = ngraph::opset6::Constant::create(ngraph::element::f32, ngraph::Shape{ 1 }, { 1 });
|
||||
auto add = std::make_shared<ngraph::opset6::Add>(transpose, add_const);
|
||||
add->set_friendly_name("add");
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::NAMES);
|
||||
const FunctionsComparator::Result res = func_comparator(f, f_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user