diff --git a/ngraph/core/src/pass/constant_folding.cpp b/ngraph/core/src/pass/constant_folding.cpp
index 6bbeb759704..a09ecffb03b 100644
--- a/ngraph/core/src/pass/constant_folding.cpp
+++ b/ngraph/core/src/pass/constant_folding.cpp
@@ -27,20 +27,10 @@ bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr<Function> f)
 {
     bool rewritten = false;
 
-    for (auto node : f->get_ordered_ops())
+    for (const auto& node : f->get_ordered_ops())
     {
         node->revalidate_and_infer_types();
 
-        // recursively constant fold operators containing subgraphs (ie: TensorIterator)
-        if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node))
-        {
-            if (auto sub_graph = sub_graph_node->get_function())
-            {
-                rewritten |= run_on_function(sub_graph);
-                continue;
-            }
-        }
-
         OutputVector replacements(node->get_output_size());
         if (node->constant_fold(replacements, node->input_values()))
         {
@@ -72,6 +62,17 @@ bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr<Function> f)
                 }
             }
         }
+        else
+        {
+            // recursively constant fold operators containing subgraphs (ie: TensorIterator)
+            if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node))
+            {
+                if (const auto& sub_graph = sub_graph_node->get_function())
+                {
+                    rewritten |= run_on_function(sub_graph);
+                }
+            }
+        }
     }
 
     return rewritten;
diff --git a/ngraph/test/constant_folding.cpp b/ngraph/test/constant_folding.cpp
index ed315acac9e..3c039e372e1 100644
--- a/ngraph/test/constant_folding.cpp
+++ b/ngraph/test/constant_folding.cpp
@@ -17,6 +17,7 @@
 
 #include "gtest/gtest.h"
 #include "ngraph/ngraph.hpp"
+#include "ngraph/opsets/opset5.hpp"
 #include "ngraph/pass/constant_folding.hpp"
 #include "ngraph/pass/manager.hpp"
 #include "util/all_close_f.hpp"
@@ -3107,3 +3108,65 @@ TEST(constant_folding, disable_constant_folding)
     ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 1);
     ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
 }
+
+TEST(constant_folding, constant_loop)
+{
+    auto X = make_shared<opset5::Constant>(
+        element::f32, Shape{2, 1, 3}, std::vector<float>{0, 1, 2, 3, 4, 5});
+    auto Y =
+        make_shared<opset5::Constant>(element::f32, Shape{1, 1, 3}, std::vector<float>{1, 2, 3});
+
+    // Body parameters
+    auto Xi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
+    auto Yi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
+    auto body_condition = std::make_shared<ngraph::opset5::Constant>(
+        ngraph::element::boolean, ngraph::Shape{1}, true);
+
+    auto trip_count =
+        std::make_shared<ngraph::opset5::Constant>(ngraph::element::i64, ngraph::Shape{1}, 2);
+    auto exec_condition = std::make_shared<ngraph::opset5::Constant>(
+        ngraph::element::boolean, ngraph::Shape{1}, true);
+    // Body
+    auto sum = make_shared<ngraph::opset5::Add>(Xi, Yi);
+    auto body =
+        make_shared<ngraph::Function>(OutputVector{body_condition, sum}, ParameterVector{Xi, Yi});
+    auto loop = make_shared<opset5::Loop>(trip_count, exec_condition);
+    loop->set_function(body);
+    loop->set_special_body_ports(ngraph::opset5::Loop::SpecialBodyPorts{-1, 0});
+
+    loop->set_sliced_input(Xi, X, 0, 1, 1, -1, 0);
+    loop->set_invariant_input(Yi, Y);
+
+    auto out0 = loop->get_iter_value(sum, -1);
+    auto out1 = loop->get_concatenated_slices(sum, 0, 1, 1, -1, 0);
+
+    auto result0 = make_shared<opset5::Result>(out0);
+    auto result1 = make_shared<opset5::Result>(out1);
+
+    auto results = ResultVector{result0, result1};
+    auto f = make_shared<Function>(results, ParameterVector{});
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::ConstantFolding>();
+    pass_manager.run_passes(f);
+
+    ASSERT_EQ(count_ops_of_type<opset5::Loop>(f), 0);
+    ASSERT_EQ(count_ops_of_type<opset5::Constant>(f), 2);
+
+    auto result_node_0 =
+        as_type_ptr<opset5::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
+    auto result_node_1 =
+        as_type_ptr<opset5::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
+    ASSERT_TRUE(result_node_0);
+    ASSERT_TRUE(result_node_1);
+
+    const ngraph::Shape shape_0{1, 1, 3};
+    const ngraph::Shape shape_1{2, 1, 3};
+
+    ASSERT_EQ(shape_0, result_node_0->get_output_shape(0));
+    ASSERT_EQ(shape_1, result_node_1->get_output_shape(0));
+    std::vector<float> expected_0{4, 6, 8};
+    std::vector<float> expected_1{1, 3, 5, 4, 6, 8};
+    range_test_check(result_node_0->cast_vector<float>(), expected_0);
+    range_test_check(result_node_1->cast_vector<float>(), expected_1);
+}