Fixed ReduceL2Decomposition (#3452)

* Fixed ReduceL2Decomposition

* Added test
This commit is contained in:
Liubov Batanina 2020-12-04 13:56:01 +03:00 committed by GitHub
parent 256e047ad2
commit 29b8ffa40b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 4 deletions

View File

@ -28,7 +28,7 @@ ngraph::pass::ReduceL2Decomposition::ReduceL2Decomposition() {
auto square = std::make_shared<ngraph::opset4::Power>(reduce_l2_node->input_value(0), const_2); auto square = std::make_shared<ngraph::opset4::Power>(reduce_l2_node->input_value(0), const_2);
auto reduce_sum = register_new_node<ngraph::opset4::ReduceSum>(square, reduce_l2_node->input_value(1), reduce_l2_node->get_keep_dims()); auto reduce_sum = register_new_node<ngraph::opset4::ReduceSum>(square, reduce_l2_node->input_value(1), reduce_l2_node->get_keep_dims());
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum); auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
reduce_sum->set_friendly_name(m.get_match_root()->get_friendly_name()); sqrt->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info(reduce_l2_node, ngraph::copy_runtime_info(reduce_l2_node,
{sqrt, reduce_sum, square, const_2}); {sqrt, reduce_sum, square, const_2});
ngraph::replace_node(m.get_match_root(), sqrt); ngraph::replace_node(m.get_match_root(), sqrt);

View File

@ -23,10 +23,10 @@ TEST(TransformationTests, ReduceL2DecompositionTest) {
{ {
auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1)); auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
auto axes = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::i32, ngraph::Shape{1}); auto axes = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::i32, ngraph::Shape{1});
auto reduce_l1 = std::make_shared<ngraph::opset4::ReduceL2>(data, axes, true); auto reduce_l2 = std::make_shared<ngraph::opset4::ReduceL2>(data, axes, true);
reduce_l2->set_friendly_name("reduce_l2");
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reduce_l1}, ngraph::ParameterVector{data, axes});
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reduce_l2}, ngraph::ParameterVector{data, axes});
ngraph::pass::Manager manager; ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>(); manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ReduceL2Decomposition>(); manager.register_pass<ngraph::pass::ReduceL2Decomposition>();
@ -46,4 +46,8 @@ TEST(TransformationTests, ReduceL2DecompositionTest) {
auto res = compare_functions(f, f_ref); auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second; ASSERT_TRUE(res.first) << res.second;
auto result_node_of_converted_f = f->get_output_op(0);
auto output_node = result_node_of_converted_f->input(0).get_source_output().get_node_shared_ptr();
ASSERT_TRUE(output_node->get_friendly_name() == "reduce_l2") << "Transformation ReduceL2Decomposition should keep output names.\n";
} }