fix constant folding for sub graph ops (#3534)
This commit is contained in:
parent
6f512142b6
commit
1ac3caf472
@ -27,20 +27,10 @@ bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr<ngraph::Func
|
||||
{
|
||||
bool rewritten = false;
|
||||
|
||||
for (auto&& node : f->get_ordered_ops())
|
||||
for (const auto& node : f->get_ordered_ops())
|
||||
{
|
||||
node->revalidate_and_infer_types();
|
||||
|
||||
// recursively constant fold operators containing subgraphs (ie: TensorIterator)
|
||||
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node))
|
||||
{
|
||||
if (auto sub_graph = sub_graph_node->get_function())
|
||||
{
|
||||
rewritten |= run_on_function(sub_graph);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
OutputVector replacements(node->get_output_size());
|
||||
if (node->constant_fold(replacements, node->input_values()))
|
||||
{
|
||||
@ -72,6 +62,17 @@ bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr<ngraph::Func
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// recursively constant fold operators containing subgraphs (ie: TensorIterator, Loop)
|
||||
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node))
|
||||
{
|
||||
if (const auto& sub_graph = sub_graph_node->get_function())
|
||||
{
|
||||
rewritten |= run_on_function(sub_graph);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return rewritten;
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/opsets/opset5.hpp"
|
||||
#include "ngraph/pass/constant_folding.hpp"
|
||||
#include "ngraph/pass/manager.hpp"
|
||||
#include "util/all_close_f.hpp"
|
||||
@ -3107,3 +3108,65 @@ TEST(constant_folding, disable_constant_folding)
|
||||
ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 1);
|
||||
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
|
||||
}
|
||||
|
||||
TEST(constant_folding, constant_loop)
|
||||
{
|
||||
auto X = make_shared<opset5::Constant>(
|
||||
element::f32, Shape{2, 1, 3}, std::vector<int64_t>{0, 1, 2, 3, 4, 5});
|
||||
auto Y =
|
||||
make_shared<opset5::Constant>(element::f32, Shape{1, 1, 3}, std::vector<int64_t>{1, 2, 3});
|
||||
|
||||
// Body parameters
|
||||
auto Xi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Yi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto body_condition = std::make_shared<ngraph::opset5::Constant>(
|
||||
ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||
|
||||
auto trip_count =
|
||||
std::make_shared<ngraph::opset5::Constant>(ngraph::element::i64, ngraph::Shape{1}, 2);
|
||||
auto exec_condition = std::make_shared<ngraph::opset5::Constant>(
|
||||
ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||
// Body
|
||||
auto sum = make_shared<ngraph::opset5::Add>(Xi, Yi);
|
||||
auto body =
|
||||
make_shared<ngraph::Function>(OutputVector{body_condition, sum}, ParameterVector{Xi, Yi});
|
||||
auto loop = make_shared<opset5::Loop>(trip_count, exec_condition);
|
||||
loop->set_function(body);
|
||||
loop->set_special_body_ports(ngraph::opset5::Loop::SpecialBodyPorts{-1, 0});
|
||||
|
||||
loop->set_sliced_input(Xi, X, 0, 1, 1, -1, 0);
|
||||
loop->set_invariant_input(Yi, Y);
|
||||
|
||||
auto out0 = loop->get_iter_value(sum, -1);
|
||||
auto out1 = loop->get_concatenated_slices(sum, 0, 1, 1, -1, 0);
|
||||
|
||||
auto result0 = make_shared<opset5::Result>(out0);
|
||||
auto result1 = make_shared<opset5::Result>(out1);
|
||||
|
||||
auto results = ResultVector{result0, result1};
|
||||
auto f = make_shared<Function>(results, ParameterVector{});
|
||||
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::ConstantFolding>();
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<ngraph::opset5::Loop>(f), 0);
|
||||
ASSERT_EQ(count_ops_of_type<ngraph::opset5::Constant>(f), 2);
|
||||
|
||||
auto result_node_0 =
|
||||
as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
|
||||
auto result_node_1 =
|
||||
as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
|
||||
ASSERT_TRUE(result_node_0);
|
||||
ASSERT_TRUE(result_node_1);
|
||||
|
||||
const ngraph::Shape shape_0{1, 1, 3};
|
||||
const ngraph::Shape shape_1{2, 1, 3};
|
||||
|
||||
ASSERT_EQ(shape_0, result_node_0->get_output_shape(0));
|
||||
ASSERT_EQ(shape_1, result_node_1->get_output_shape(0));
|
||||
std::vector<float> expected_0{4, 6, 8};
|
||||
std::vector<float> expected_1{1, 3, 5, 4, 6, 8};
|
||||
range_test_check(result_node_0->cast_vector<float>(), expected_0);
|
||||
range_test_check(result_node_1->cast_vector<float>(), expected_1);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user