fix constant folding for sub graph ops (#3534)
This commit is contained in:
parent
6f512142b6
commit
1ac3caf472
@ -27,20 +27,10 @@ bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr<ngraph::Func
|
|||||||
{
|
{
|
||||||
bool rewritten = false;
|
bool rewritten = false;
|
||||||
|
|
||||||
for (auto&& node : f->get_ordered_ops())
|
for (const auto& node : f->get_ordered_ops())
|
||||||
{
|
{
|
||||||
node->revalidate_and_infer_types();
|
node->revalidate_and_infer_types();
|
||||||
|
|
||||||
// recursively constant fold operators containing subgraphs (ie: TensorIterator)
|
|
||||||
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node))
|
|
||||||
{
|
|
||||||
if (auto sub_graph = sub_graph_node->get_function())
|
|
||||||
{
|
|
||||||
rewritten |= run_on_function(sub_graph);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
OutputVector replacements(node->get_output_size());
|
OutputVector replacements(node->get_output_size());
|
||||||
if (node->constant_fold(replacements, node->input_values()))
|
if (node->constant_fold(replacements, node->input_values()))
|
||||||
{
|
{
|
||||||
@ -72,6 +62,17 @@ bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr<ngraph::Func
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// recursively constant fold operators containing subgraphs (ie: TensorIterator, Loop)
|
||||||
|
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node))
|
||||||
|
{
|
||||||
|
if (const auto& sub_graph = sub_graph_node->get_function())
|
||||||
|
{
|
||||||
|
rewritten |= run_on_function(sub_graph);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return rewritten;
|
return rewritten;
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
#include "ngraph/ngraph.hpp"
|
#include "ngraph/ngraph.hpp"
|
||||||
|
#include "ngraph/opsets/opset5.hpp"
|
||||||
#include "ngraph/pass/constant_folding.hpp"
|
#include "ngraph/pass/constant_folding.hpp"
|
||||||
#include "ngraph/pass/manager.hpp"
|
#include "ngraph/pass/manager.hpp"
|
||||||
#include "util/all_close_f.hpp"
|
#include "util/all_close_f.hpp"
|
||||||
@ -3107,3 +3108,65 @@ TEST(constant_folding, disable_constant_folding)
|
|||||||
ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 1);
|
ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 1);
|
||||||
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
|
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(constant_folding, constant_loop)
|
||||||
|
{
|
||||||
|
auto X = make_shared<opset5::Constant>(
|
||||||
|
element::f32, Shape{2, 1, 3}, std::vector<int64_t>{0, 1, 2, 3, 4, 5});
|
||||||
|
auto Y =
|
||||||
|
make_shared<opset5::Constant>(element::f32, Shape{1, 1, 3}, std::vector<int64_t>{1, 2, 3});
|
||||||
|
|
||||||
|
// Body parameters
|
||||||
|
auto Xi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||||
|
auto Yi = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic());
|
||||||
|
auto body_condition = std::make_shared<ngraph::opset5::Constant>(
|
||||||
|
ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||||
|
|
||||||
|
auto trip_count =
|
||||||
|
std::make_shared<ngraph::opset5::Constant>(ngraph::element::i64, ngraph::Shape{1}, 2);
|
||||||
|
auto exec_condition = std::make_shared<ngraph::opset5::Constant>(
|
||||||
|
ngraph::element::boolean, ngraph::Shape{1}, true);
|
||||||
|
// Body
|
||||||
|
auto sum = make_shared<ngraph::opset5::Add>(Xi, Yi);
|
||||||
|
auto body =
|
||||||
|
make_shared<ngraph::Function>(OutputVector{body_condition, sum}, ParameterVector{Xi, Yi});
|
||||||
|
auto loop = make_shared<opset5::Loop>(trip_count, exec_condition);
|
||||||
|
loop->set_function(body);
|
||||||
|
loop->set_special_body_ports(ngraph::opset5::Loop::SpecialBodyPorts{-1, 0});
|
||||||
|
|
||||||
|
loop->set_sliced_input(Xi, X, 0, 1, 1, -1, 0);
|
||||||
|
loop->set_invariant_input(Yi, Y);
|
||||||
|
|
||||||
|
auto out0 = loop->get_iter_value(sum, -1);
|
||||||
|
auto out1 = loop->get_concatenated_slices(sum, 0, 1, 1, -1, 0);
|
||||||
|
|
||||||
|
auto result0 = make_shared<opset5::Result>(out0);
|
||||||
|
auto result1 = make_shared<opset5::Result>(out1);
|
||||||
|
|
||||||
|
auto results = ResultVector{result0, result1};
|
||||||
|
auto f = make_shared<Function>(results, ParameterVector{});
|
||||||
|
|
||||||
|
pass::Manager pass_manager;
|
||||||
|
pass_manager.register_pass<pass::ConstantFolding>();
|
||||||
|
pass_manager.run_passes(f);
|
||||||
|
|
||||||
|
ASSERT_EQ(count_ops_of_type<ngraph::opset5::Loop>(f), 0);
|
||||||
|
ASSERT_EQ(count_ops_of_type<ngraph::opset5::Constant>(f), 2);
|
||||||
|
|
||||||
|
auto result_node_0 =
|
||||||
|
as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
|
||||||
|
auto result_node_1 =
|
||||||
|
as_type_ptr<op::Constant>(f->get_results().at(1)->input_value(0).get_node_shared_ptr());
|
||||||
|
ASSERT_TRUE(result_node_0);
|
||||||
|
ASSERT_TRUE(result_node_1);
|
||||||
|
|
||||||
|
const ngraph::Shape shape_0{1, 1, 3};
|
||||||
|
const ngraph::Shape shape_1{2, 1, 3};
|
||||||
|
|
||||||
|
ASSERT_EQ(shape_0, result_node_0->get_output_shape(0));
|
||||||
|
ASSERT_EQ(shape_1, result_node_1->get_output_shape(0));
|
||||||
|
std::vector<float> expected_0{4, 6, 8};
|
||||||
|
std::vector<float> expected_1{1, 3, 5, 4, 6, 8};
|
||||||
|
range_test_check(result_node_0->cast_vector<float>(), expected_0);
|
||||||
|
range_test_check(result_node_1->cast_vector<float>(), expected_1);
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user