Fix constant folding for subgraph ops (#3534)

This commit is contained in:
Ivan Tikhonov 2020-12-10 11:56:16 +03:00 committed by GitHub
parent 6f512142b6
commit 1ac3caf472
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 75 additions and 11 deletions

View File

@ -27,20 +27,10 @@ bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr<ngraph::Func
{
bool rewritten = false;
for (auto&& node : f->get_ordered_ops())
for (const auto& node : f->get_ordered_ops())
{
node->revalidate_and_infer_types();
// recursively constant fold operators containing subgraphs (ie: TensorIterator)
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node))
{
if (auto sub_graph = sub_graph_node->get_function())
{
rewritten |= run_on_function(sub_graph);
continue;
}
}
OutputVector replacements(node->get_output_size());
if (node->constant_fold(replacements, node->input_values()))
{
@ -72,6 +62,17 @@ bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr<ngraph::Func
}
}
}
else
{
// recursively constant fold operators containing subgraphs (ie: TensorIterator, Loop)
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node))
{
if (const auto& sub_graph = sub_graph_node->get_function())
{
rewritten |= run_on_function(sub_graph);
}
}
}
}
return rewritten;

View File

@ -17,6 +17,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset5.hpp"
#include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/manager.hpp"
#include "util/all_close_f.hpp"
@ -3107,3 +3108,65 @@ TEST(constant_folding, disable_constant_folding)
ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
}
// Verifies that ConstantFolding recursively folds a Loop whose inputs are all
// Constants: the entire Loop is executed at compile time and replaced by
// Constant nodes holding its outputs.
TEST(constant_folding, constant_loop)
{
    // Loop data inputs. The element type is f32, so the initializer data is
    // float as well (previously std::vector<int64_t>, which relied on a silent
    // per-element conversion inside the Constant constructor).
    auto X = make_shared<opset5::Constant>(
        element::f32, Shape{2, 1, 3}, std::vector<float>{0, 1, 2, 3, 4, 5});
    auto Y =
        make_shared<opset5::Constant>(element::f32, Shape{1, 1, 3}, std::vector<float>{1, 2, 3});

    // Body parameters: dynamic shapes, resolved from the actual Loop inputs.
    auto Xi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
    auto Yi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());

    // Loop control: exactly 2 iterations (trip_count = 2, conditions always true).
    auto body_condition = std::make_shared<ngraph::opset5::Constant>(
        ngraph::element::boolean, ngraph::Shape{1}, true);
    auto trip_count =
        std::make_shared<ngraph::opset5::Constant>(ngraph::element::i64, ngraph::Shape{1}, 2);
    auto exec_condition = std::make_shared<ngraph::opset5::Constant>(
        ngraph::element::boolean, ngraph::Shape{1}, true);

    // Body computes sum = Xi + Yi; body output 0 is the continuation condition.
    auto sum = make_shared<ngraph::opset5::Add>(Xi, Yi);
    auto body =
        make_shared<ngraph::Function>(OutputVector{body_condition, sum}, ParameterVector{Xi, Yi});
    auto loop = make_shared<opset5::Loop>(trip_count, exec_condition);
    loop->set_function(body);
    // current_iteration body input is unused (-1); condition is body output 0.
    loop->set_special_body_ports(ngraph::opset5::Loop::SpecialBodyPorts{-1, 0});

    // X is sliced along axis 0 (one {1,1,3} slice per iteration); Y is invariant.
    loop->set_sliced_input(Xi, X, 0, 1, 1, -1, 0);
    loop->set_invariant_input(Yi, Y);

    // out0: 'sum' from the last iteration; out1: per-iteration sums concatenated
    // along axis 0.
    auto out0 = loop->get_iter_value(sum, -1);
    auto out1 = loop->get_concatenated_slices(sum, 0, 1, 1, -1, 0);

    auto result0 = make_shared<opset5::Result>(out0);
    auto result1 = make_shared<opset5::Result>(out1);
    auto results = ResultVector{result0, result1};
    auto f = make_shared<Function>(results, ParameterVector{});

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ConstantFolding>();
    pass_manager.run_passes(f);

    // The Loop must be fully folded away: no Loop ops remain, only the two
    // result Constants.
    ASSERT_EQ(count_ops_of_type<ngraph::opset5::Loop>(f), 0);
    ASSERT_EQ(count_ops_of_type<ngraph::opset5::Constant>(f), 2);

    auto result_node_0 =
        as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
    auto result_node_1 =
        as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
    ASSERT_TRUE(result_node_0);
    ASSERT_TRUE(result_node_1);

    // Last-iteration sum keeps the slice shape {1,1,3}; the concatenation of the
    // two slices restores X's shape {2,1,3}.
    const ngraph::Shape shape_0{1, 1, 3};
    const ngraph::Shape shape_1{2, 1, 3};
    ASSERT_EQ(shape_0, result_node_0->get_output_shape(0));
    ASSERT_EQ(shape_1, result_node_1->get_output_shape(0));

    // Iteration 1: X[0] + Y = {0,1,2} + {1,2,3} = {1,3,5}
    // Iteration 2: X[1] + Y = {3,4,5} + {1,2,3} = {4,6,8}
    std::vector<float> expected_0{4, 6, 8};
    std::vector<float> expected_1{1, 3, 5, 4, 6, 8};
    range_test_check(result_node_0->cast_vector<float>(), expected_0);
    range_test_check(result_node_1->cast_vector<float>(), expected_1);
}