ConstantFolding as a FunctionPass instead of GraphRewrite (#3065)

* Redundant op::Max CF removal

* Redundant op::Min CF removal

* Redundant op::Sum & op::Product CF removal

* CF Min and Max using evaluate()

* Arithmetic reduction CF pass removal

* Quantize op CF pass removal

* Convert op CF pass removal

* Logical reduction CF pass removal

* Select op CF pass removal

* OneHot CF pass removal

* Code formatting

* ScatterElements CF pass removal

* Gather CF pass removal

* Disable a Quantize op test that fails in CI

* CF pass cleanup

* Convert op cleanup and test adaptation to spec

* Possible fix for failing VPU tests

* Limit the types used in OneHot::evaluate

* Quantize op evaluator removal

* Refactor of Gather evaluator

* CF pass cleanup and adaptation to FunctionPass interface

* New CF pass implementation

* Fix the Tile::evaluate method for dynamic shapes

* Node shapes revalidation in CF pass

* Obsolete code cleanup

* Obsolete include removal

* Recursively fold subgraph nodes

* Obsolete include removal

* Revalidate each node in CF

* PR feedback

* Missing RTTI symbol definition
This commit is contained in:
Tomasz Dołbniak 2020-11-20 18:36:05 +01:00 committed by GitHub
parent 4b22a99a69
commit 6899a95e1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 48 additions and 57 deletions

View File

@@ -16,28 +16,27 @@
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/util.hpp"
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class ConstantFolding;
bool revalidate_and_ensure_static(std::shared_ptr<ngraph::Node> n);
}
}
/**
* @brief Constant folding iterates over the function and tries to evaluate nodes
* with constant inputs. Such nodes are then replaced with new Constants containing
* the result of a folded operation.
*/
class NGRAPH_API ConstantFolding : public FunctionPass
{
public:
NGRAPH_RTTI_DECLARATION;
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
class NGRAPH_API ngraph::pass::ConstantFolding : public ngraph::pass::GraphRewrite
{
public:
NGRAPH_RTTI_DECLARATION;
ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap());
private:
void copy_runtime_info_to_target_inputs(const std::shared_ptr<Node>& node,
const Output<Node>& replacement);
};
private:
void copy_runtime_info_to_target_inputs(const std::shared_ptr<Node>& node,
const Output<Node>& replacement);
ngraph::BuildNodeExecutorMap m_cfmap;
};
} // namespace pass
} // namespace ngraph

View File

@@ -230,15 +230,6 @@ namespace ngraph
NGRAPH_API
AxisVector get_default_order(const Shape& shape);
// NodeExecutors are used in compiler optimization passes like ConstantFolding to execute a node
// using the supplied input and output memory locations.
// A BuildNodeExecutor returns a backend-specific NodeExecutor for a given Node type
using NodeExecutorTy =
std::function<void(const std::vector<void*>& inputs, std::vector<void*>& outputs)>;
using BuildNodeExecutor = std::function<NodeExecutorTy(const ngraph::Node*)>;
using BuildNodeExecutorMap = std::unordered_map<std::type_index, BuildNodeExecutor>;
//
// EnumMask is intended to work with a scoped enum type. It's used to store
// a combination of enum values and provides easy access and manipulation

View File

@@ -100,7 +100,6 @@ namespace shape_of
auto output_type = shape_of_node->get_output_element_type(0);
if (partial_shape.is_static())
{
NGRAPH_CHECK(pass::revalidate_and_ensure_static(shape_of_node->shared_from_this()));
auto arg_shape = shape_of_input.get_shape();
auto result_tensor =
make_shared<HostTensor>(output_type, shape_of_node->get_output_shape(0));

View File

@@ -115,6 +115,12 @@ bool op::v0::Tile::evaluate(const HostTensorVector& outputs, const HostTensorVec
{
output_shape[i] = data_shape[i] * repeats_val[i];
}
if (!output->get_is_allocated())
{
output->set_shape(output_shape);
}
runtime::reference::tile(data->get_data_ptr<const char>(),
output->get_data_ptr<char>(),
data->get_shape(),

View File

@@ -16,31 +16,38 @@
#include "constant_folding.hpp"
#include <ngraph/rt_info.hpp>
#include "ngraph/op/util/sub_graph_base.hpp"
using namespace std;
using namespace ngraph;
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConstantFolding, "ConstantFolding", 0);
ngraph::pass::ConstantFolding::ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap)
: GraphRewrite()
, m_cfmap{cfmap}
bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr<ngraph::Function> f)
{
m_enable_shape_inference = true;
bool rewritten = false;
m_matchers.push_back(std::make_shared<MatcherPass>(
"Constant folding defaults",
nullptr,
[=](const std::shared_ptr<Node>& node) -> bool {
OutputVector replacements(node->get_output_size());
if (!node->constant_fold(replacements, node->input_values()))
for (auto&& node : f->get_ordered_ops())
{
node->revalidate_and_infer_types();
// recursively constant fold operators containing subgraphs (ie: TensorIterator)
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node))
{
if (auto sub_graph = sub_graph_node->get_function())
{
return false;
rewritten |= run_on_function(sub_graph);
continue;
}
}
OutputVector replacements(node->get_output_size());
if (node->constant_fold(replacements, node->input_values()))
{
NGRAPH_CHECK(replacements.size() == node->get_output_size(),
"constant_fold_default returned incorrect number of replacements for ",
node);
bool result{false};
for (size_t i = 0; i < replacements.size(); ++i)
{
auto node_output = node->output(i);
@@ -60,25 +67,14 @@ ngraph::pass::ConstantFolding::ConstantFolding(const ngraph::BuildNodeExecutorMa
node_output.replace(replacement);
// Propagate runtime info attributes to replacement consumer nodes
copy_runtime_info_to_target_inputs(node, replacement);
result = true;
rewritten = true;
}
}
return result;
},
PassProperty::CHANGE_DYNAMIC_STATE));
}
bool ngraph::pass::revalidate_and_ensure_static(shared_ptr<Node> n)
{
n->revalidate_and_infer_types();
for (auto& o : n->outputs())
{
if (o.get_partial_shape().is_dynamic() || o.get_element_type().is_dynamic())
{
return false;
}
}
return true;
return rewritten;
}
void ngraph::pass::ConstantFolding::copy_runtime_info_to_target_inputs(

View File

@@ -1991,8 +1991,8 @@ TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant)
auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
auto constant_shape_a = make_shared<op::Constant>(element::i64, shape_shape, values_shape_a);
auto constant_shape_b = make_shared<op::Constant>(element::i64, shape_shape, values_shape_b);
auto dyn_reshape =
make_shared<op::v1::Reshape>(constant_in, constant_shape_a + constant_shape_b, false);
auto dyn_reshape = make_shared<op::v1::Reshape>(
constant_in, std::make_shared<op::v1::Add>(constant_shape_a, constant_shape_b), false);
dyn_reshape->set_friendly_name("test");
auto f = make_shared<Function>(dyn_reshape, ParameterVector{});