Fixed ReduceL2Decomposition (#3452)
* Fixed ReduceL2Decomposition * Added test
This commit is contained in:
parent
256e047ad2
commit
29b8ffa40b
@ -28,7 +28,7 @@ ngraph::pass::ReduceL2Decomposition::ReduceL2Decomposition() {
|
||||
auto square = std::make_shared<ngraph::opset4::Power>(reduce_l2_node->input_value(0), const_2);
|
||||
auto reduce_sum = register_new_node<ngraph::opset4::ReduceSum>(square, reduce_l2_node->input_value(1), reduce_l2_node->get_keep_dims());
|
||||
auto sqrt = std::make_shared<ngraph::opset4::Sqrt>(reduce_sum);
|
||||
reduce_sum->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
sqrt->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::copy_runtime_info(reduce_l2_node,
|
||||
{sqrt, reduce_sum, square, const_2});
|
||||
ngraph::replace_node(m.get_match_root(), sqrt);
|
||||
|
@ -23,10 +23,10 @@ TEST(TransformationTests, ReduceL2DecompositionTest) {
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic(1));
|
||||
auto axes = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::i32, ngraph::Shape{1});
|
||||
auto reduce_l1 = std::make_shared<ngraph::opset4::ReduceL2>(data, axes, true);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reduce_l1}, ngraph::ParameterVector{data, axes});
|
||||
auto reduce_l2 = std::make_shared<ngraph::opset4::ReduceL2>(data, axes, true);
|
||||
reduce_l2->set_friendly_name("reduce_l2");
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reduce_l2}, ngraph::ParameterVector{data, axes});
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ReduceL2Decomposition>();
|
||||
@ -46,4 +46,8 @@ TEST(TransformationTests, ReduceL2DecompositionTest) {
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
|
||||
auto result_node_of_converted_f = f->get_output_op(0);
|
||||
auto output_node = result_node_of_converted_f->input(0).get_source_output().get_node_shared_ptr();
|
||||
ASSERT_TRUE(output_node->get_friendly_name() == "reduce_l2") << "Transformation ReduceL2Decomposition should keep output names.\n";
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user