ConstantFolding as a FunctionPass instead of GraphRewrite (#3065)
* Redundant op::Max CF removal * Redundant op::Min CF removal * Redundant op::Sum & op::Product CF removal * CF Min and Max using evaluate() * Arithmetic reduction CF pass removal * Quantize op CF pass removal * Convert op CF pass removal * Logical reduction CF pass removal * Select op CF pass removal * OneHot CF pass removal * Code formatting * ScatterElements CF pass removal * Gather CF pass removal * Disable a Quantize op test that fails in CI * CF pass cleanup * Convert op cleanup and test adaptation to spec * Possible fix for failing VPU tests * Limit the types used in OneHot::evaluate * Quantize op evaluator removal * Refactor of Gather evaluator * CF pass cleanup and adaptation to FunctionPass interface * New CF pass implementation * Fix the Tile::evaluate method for dynamic shapes * Node shapes revalidation in CF pass * Obsolete code cleanup * Obsolete include removal * Recursively fold subgraph nodes * Obsolete include removal * Revalidate each node in CF * PR feedback * Missing RTTI symbol definition
This commit is contained in:
parent
4b22a99a69
commit
6899a95e1c
@ -16,28 +16,27 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/pass/graph_rewrite.hpp"
|
||||
#include "ngraph/runtime/aligned_buffer.hpp"
|
||||
#include "ngraph/util.hpp"
|
||||
#include "ngraph/pass/pass.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace pass
|
||||
{
|
||||
class ConstantFolding;
|
||||
bool revalidate_and_ensure_static(std::shared_ptr<ngraph::Node> n);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Constant folding iterates over the function and tries to evaluate nodes
|
||||
* with constant inputs. Such nodes are then replaced with new Constants containing
|
||||
* the result of a folded operation.
|
||||
*/
|
||||
class NGRAPH_API ConstantFolding : public FunctionPass
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
|
||||
class NGRAPH_API ngraph::pass::ConstantFolding : public ngraph::pass::GraphRewrite
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap());
|
||||
private:
|
||||
void copy_runtime_info_to_target_inputs(const std::shared_ptr<Node>& node,
|
||||
const Output<Node>& replacement);
|
||||
};
|
||||
|
||||
private:
|
||||
void copy_runtime_info_to_target_inputs(const std::shared_ptr<Node>& node,
|
||||
const Output<Node>& replacement);
|
||||
|
||||
ngraph::BuildNodeExecutorMap m_cfmap;
|
||||
};
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -230,15 +230,6 @@ namespace ngraph
|
||||
NGRAPH_API
|
||||
AxisVector get_default_order(const Shape& shape);
|
||||
|
||||
// NodeExecutors are used in compiler optimization passes like ConstantFolding to execute a node
|
||||
// using the supplied input and output memory locations.
|
||||
// A BuildNodeExecutor returns a backend-specific NodeExecutor for a given Node type
|
||||
using NodeExecutorTy =
|
||||
std::function<void(const std::vector<void*>& inputs, std::vector<void*>& outputs)>;
|
||||
using BuildNodeExecutor = std::function<NodeExecutorTy(const ngraph::Node*)>;
|
||||
|
||||
using BuildNodeExecutorMap = std::unordered_map<std::type_index, BuildNodeExecutor>;
|
||||
|
||||
//
|
||||
// EnumMask is intended to work with a scoped enum type. It's used to store
|
||||
// a combination of enum values and provides easy access and manipulation
|
||||
|
@ -100,7 +100,6 @@ namespace shape_of
|
||||
auto output_type = shape_of_node->get_output_element_type(0);
|
||||
if (partial_shape.is_static())
|
||||
{
|
||||
NGRAPH_CHECK(pass::revalidate_and_ensure_static(shape_of_node->shared_from_this()));
|
||||
auto arg_shape = shape_of_input.get_shape();
|
||||
auto result_tensor =
|
||||
make_shared<HostTensor>(output_type, shape_of_node->get_output_shape(0));
|
||||
|
@ -115,6 +115,12 @@ bool op::v0::Tile::evaluate(const HostTensorVector& outputs, const HostTensorVec
|
||||
{
|
||||
output_shape[i] = data_shape[i] * repeats_val[i];
|
||||
}
|
||||
|
||||
if (!output->get_is_allocated())
|
||||
{
|
||||
output->set_shape(output_shape);
|
||||
}
|
||||
|
||||
runtime::reference::tile(data->get_data_ptr<const char>(),
|
||||
output->get_data_ptr<char>(),
|
||||
data->get_shape(),
|
||||
|
@ -16,31 +16,38 @@
|
||||
|
||||
#include "constant_folding.hpp"
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include "ngraph/op/util/sub_graph_base.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConstantFolding, "ConstantFolding", 0);
|
||||
|
||||
ngraph::pass::ConstantFolding::ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap)
|
||||
: GraphRewrite()
|
||||
, m_cfmap{cfmap}
|
||||
bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr<ngraph::Function> f)
|
||||
{
|
||||
m_enable_shape_inference = true;
|
||||
bool rewritten = false;
|
||||
|
||||
m_matchers.push_back(std::make_shared<MatcherPass>(
|
||||
"Constant folding defaults",
|
||||
nullptr,
|
||||
[=](const std::shared_ptr<Node>& node) -> bool {
|
||||
OutputVector replacements(node->get_output_size());
|
||||
if (!node->constant_fold(replacements, node->input_values()))
|
||||
for (auto&& node : f->get_ordered_ops())
|
||||
{
|
||||
node->revalidate_and_infer_types();
|
||||
|
||||
// recursively constant fold operators containing subgraphs (ie: TensorIterator)
|
||||
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node))
|
||||
{
|
||||
if (auto sub_graph = sub_graph_node->get_function())
|
||||
{
|
||||
return false;
|
||||
rewritten |= run_on_function(sub_graph);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
OutputVector replacements(node->get_output_size());
|
||||
if (node->constant_fold(replacements, node->input_values()))
|
||||
{
|
||||
NGRAPH_CHECK(replacements.size() == node->get_output_size(),
|
||||
"constant_fold_default returned incorrect number of replacements for ",
|
||||
node);
|
||||
bool result{false};
|
||||
|
||||
for (size_t i = 0; i < replacements.size(); ++i)
|
||||
{
|
||||
auto node_output = node->output(i);
|
||||
@ -60,25 +67,14 @@ ngraph::pass::ConstantFolding::ConstantFolding(const ngraph::BuildNodeExecutorMa
|
||||
node_output.replace(replacement);
|
||||
// Propagate runtime info attributes to replacement consumer nodes
|
||||
copy_runtime_info_to_target_inputs(node, replacement);
|
||||
result = true;
|
||||
|
||||
rewritten = true;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
},
|
||||
PassProperty::CHANGE_DYNAMIC_STATE));
|
||||
}
|
||||
|
||||
bool ngraph::pass::revalidate_and_ensure_static(shared_ptr<Node> n)
|
||||
{
|
||||
n->revalidate_and_infer_types();
|
||||
for (auto& o : n->outputs())
|
||||
{
|
||||
if (o.get_partial_shape().is_dynamic() || o.get_element_type().is_dynamic())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
||||
return rewritten;
|
||||
}
|
||||
|
||||
void ngraph::pass::ConstantFolding::copy_runtime_info_to_target_inputs(
|
||||
|
@ -1991,8 +1991,8 @@ TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant)
|
||||
auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
|
||||
auto constant_shape_a = make_shared<op::Constant>(element::i64, shape_shape, values_shape_a);
|
||||
auto constant_shape_b = make_shared<op::Constant>(element::i64, shape_shape, values_shape_b);
|
||||
auto dyn_reshape =
|
||||
make_shared<op::v1::Reshape>(constant_in, constant_shape_a + constant_shape_b, false);
|
||||
auto dyn_reshape = make_shared<op::v1::Reshape>(
|
||||
constant_in, std::make_shared<op::v1::Add>(constant_shape_a, constant_shape_b), false);
|
||||
dyn_reshape->set_friendly_name("test");
|
||||
auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user