diff --git a/ngraph/core/include/ngraph/pass/constant_folding.hpp b/ngraph/core/include/ngraph/pass/constant_folding.hpp
index d7bc91512c5..b48c41dc65f 100644
--- a/ngraph/core/include/ngraph/pass/constant_folding.hpp
+++ b/ngraph/core/include/ngraph/pass/constant_folding.hpp
@@ -16,28 +16,27 @@
 #pragma once
 
-#include "ngraph/pass/graph_rewrite.hpp"
-#include "ngraph/runtime/aligned_buffer.hpp"
-#include "ngraph/util.hpp"
+#include "ngraph/pass/pass.hpp"
 
 namespace ngraph
 {
     namespace pass
     {
-        class ConstantFolding;
-        bool revalidate_and_ensure_static(std::shared_ptr<Node> n);
-    }
-}
+        /**
+         * @brief Constant folding iterates over the function and tries to evaluate nodes
+         *        with constant inputs. Such nodes are then replaced with new Constants containing
+         *        the result of a folded operation.
+         */
+        class NGRAPH_API ConstantFolding : public FunctionPass
+        {
+        public:
+            NGRAPH_RTTI_DECLARATION;
+            bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
 
-class NGRAPH_API ngraph::pass::ConstantFolding : public ngraph::pass::GraphRewrite
-{
-public:
-    NGRAPH_RTTI_DECLARATION;
-    ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap());
+        private:
+            void copy_runtime_info_to_target_inputs(const std::shared_ptr<Node>& node,
+                                                    const Output<Node>& replacement);
+        };
 
-private:
-    void copy_runtime_info_to_target_inputs(const std::shared_ptr<Node>& node,
-                                            const Output<Node>& replacement);
-
-    ngraph::BuildNodeExecutorMap m_cfmap;
-};
+    } // namespace pass
+} // namespace ngraph
 
diff --git a/ngraph/core/include/ngraph/util.hpp b/ngraph/core/include/ngraph/util.hpp
index 1b6b5a2fbef..654e3302265 100644
--- a/ngraph/core/include/ngraph/util.hpp
+++ b/ngraph/core/include/ngraph/util.hpp
@@ -230,15 +230,6 @@ namespace ngraph
     NGRAPH_API
     AxisVector get_default_order(const Shape& shape);
 
-    // NodeExecutors are used in compiler optimization passes like ConstantFolding to execute a node
-    // using the supplied input and output memory locations.
-    // A BuildNodeExecutor returns a backend-specific NodeExecutor for a given Node type
-    using NodeExecutorTy =
-        std::function<void(const std::vector<void*>& inputs, std::vector<void*>& outputs)>;
-    using BuildNodeExecutor = std::function<NodeExecutorTy(const ngraph::Node*)>;
-
-    using BuildNodeExecutorMap = std::unordered_map<std::type_index, BuildNodeExecutor>;
-
     //
     // EnumMask is intended to work with a scoped enum type. It's used to store
     // a combination of enum values and provides easy access and manipulation
diff --git a/ngraph/core/src/op/shape_of.cpp b/ngraph/core/src/op/shape_of.cpp
index 84f9b9c93cc..78923352831 100644
--- a/ngraph/core/src/op/shape_of.cpp
+++ b/ngraph/core/src/op/shape_of.cpp
@@ -100,7 +100,6 @@ namespace shape_of
         auto output_type = shape_of_node->get_output_element_type(0);
         if (partial_shape.is_static())
         {
-            NGRAPH_CHECK(pass::revalidate_and_ensure_static(shape_of_node->shared_from_this()));
             auto arg_shape = shape_of_input.get_shape();
             auto result_tensor =
                 make_shared<HostTensor>(output_type, shape_of_node->get_output_shape(0));
diff --git a/ngraph/core/src/op/tile.cpp b/ngraph/core/src/op/tile.cpp
index ee7f415914d..7b8a67b7af5 100644
--- a/ngraph/core/src/op/tile.cpp
+++ b/ngraph/core/src/op/tile.cpp
@@ -115,6 +115,12 @@ bool op::v0::Tile::evaluate(const HostTensorVector& outputs, const HostTensorVec
     {
         output_shape[i] = data_shape[i] * repeats_val[i];
     }
+
+    if (!output->get_is_allocated())
+    {
+        output->set_shape(output_shape);
+    }
+
     runtime::reference::tile(data->get_data_ptr<const char>(),
                              output->get_data_ptr<char>(),
                              data->get_shape(),
diff --git a/ngraph/core/src/pass/constant_folding.cpp b/ngraph/core/src/pass/constant_folding.cpp
index e24164c0a30..6bbeb759704 100644
--- a/ngraph/core/src/pass/constant_folding.cpp
+++ b/ngraph/core/src/pass/constant_folding.cpp
@@ -16,31 +16,38 @@
 #include "constant_folding.hpp"
 
 #include <ngraph/rt_info.hpp>
+#include "ngraph/op/util/sub_graph_base.hpp"
 
 using namespace std;
 using namespace ngraph;
 
 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConstantFolding, "ConstantFolding", 0);
 
-ngraph::pass::ConstantFolding::ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap)
-    : GraphRewrite()
-    , m_cfmap{cfmap}
+bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr<ngraph::Function> f)
 {
-    m_enable_shape_inference = true;
+    bool rewritten = false;
 
-    m_matchers.push_back(std::make_shared<MatcherPass>(
-        "Constant folding defaults",
-        nullptr,
-        [=](const std::shared_ptr<Node>& node) -> bool {
-            OutputVector replacements(node->get_output_size());
-            if (!node->constant_fold(replacements, node->input_values()))
+    for (auto&& node : f->get_ordered_ops())
+    {
+        node->revalidate_and_infer_types();
+
+        // recursively constant fold operators containing subgraphs (ie: TensorIterator)
+        if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node))
+        {
+            if (auto sub_graph = sub_graph_node->get_function())
             {
-                return false;
+                rewritten |= run_on_function(sub_graph);
+                continue;
             }
+        }
+
+        OutputVector replacements(node->get_output_size());
+        if (node->constant_fold(replacements, node->input_values()))
+        {
             NGRAPH_CHECK(replacements.size() == node->get_output_size(),
                          "constant_fold_default returned incorrect number of replacements for ",
                          node);
-            bool result{false};
+
             for (size_t i = 0; i < replacements.size(); ++i)
             {
                 auto node_output = node->output(i);
@@ -60,25 +67,14 @@ ngraph::pass::ConstantFolding::ConstantFolding(const ngraph::BuildNodeExecutorMa
                     node_output.replace(replacement);
                     // Propagate runtime info attributes to replacement consumer nodes
                     copy_runtime_info_to_target_inputs(node, replacement);
-                    result = true;
+
+                    rewritten = true;
                 }
             }
-            return result;
-        },
-        PassProperty::CHANGE_DYNAMIC_STATE));
-}
-
-bool ngraph::pass::revalidate_and_ensure_static(shared_ptr<Node> n)
-{
-    n->revalidate_and_infer_types();
-    for (auto& o : n->outputs())
-    {
-        if (o.get_partial_shape().is_dynamic() || o.get_element_type().is_dynamic())
-        {
-            return false;
-        }
-    }
-    return true;
+        }
+    }
+
+    return rewritten;
 }
 
 void ngraph::pass::ConstantFolding::copy_runtime_info_to_target_inputs(
diff --git a/ngraph/test/constant_folding.cpp b/ngraph/test/constant_folding.cpp
index d563b73737c..be87409a08a 100644
--- a/ngraph/test/constant_folding.cpp
+++ b/ngraph/test/constant_folding.cpp
@@ -1991,8 +1991,8 @@ TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant)
     auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
     auto constant_shape_a = make_shared<op::Constant>(element::i64, shape_shape, values_shape_a);
     auto constant_shape_b = make_shared<op::Constant>(element::i64, shape_shape, values_shape_b);
-    auto dyn_reshape =
-        make_shared<op::v1::Reshape>(constant_in, constant_shape_a + constant_shape_b, false);
+    auto dyn_reshape = make_shared<op::v1::Reshape>(
+        constant_in, std::make_shared<op::v1::Add>(constant_shape_a, constant_shape_b), false);
     dyn_reshape->set_friendly_name("test");
     auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
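
Reviewer note (not part of the patch): with the BuildNodeExecutorMap constructor removed, ConstantFolding is constructed with no arguments and run like any other FunctionPass. A minimal usage sketch, assuming the existing ngraph::pass::Manager API; fold_constants and f are illustrative names, not identifiers introduced by this change:

    #include <memory>
    #include <ngraph/function.hpp>
    #include <ngraph/pass/constant_folding.hpp>
    #include <ngraph/pass/manager.hpp>

    // Runs constant folding over a function: every node whose inputs are all
    // Constants (including nodes inside subgraph ops such as TensorIterator,
    // which run_on_function now folds recursively) is evaluated and replaced
    // with new Constant nodes.
    void fold_constants(const std::shared_ptr<ngraph::Function>& f)
    {
        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::ConstantFolding>();
        manager.run_passes(f);
    }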