From 2e49b4e4d85304512750fd9f636c148f676136f3 Mon Sep 17 00:00:00 2001
From: Gleb Kazantaev
Date: Fri, 9 Oct 2020 15:33:19 +0300
Subject: [PATCH] Fine-Grain Transformation pipeline tuning (#2547)

* Initial version of transformation callback refactoring
* Improved fine-grain tuning for transformation pipeline
* Check disabled matchers in GraphRewrite
* Avoid deprecated classes inside PassConfig
* Enabled DepthToSpace fusion by default
* Removed double search in map
* Moved back pass_config.hpp; Added doxygen documentation for new class and methods
* Added doxygen comments for new Manager and GraphRewrite methods
---
 .../src/mkldnn_plugin/mkldnn_plugin.cpp | 74 +++++---
 .../common_optimizations.cpp | 5 +-
 .../convert_depth_to_space.cpp | 2 +-
 .../convert_opset1_to_legacy.cpp | 2 +-
 .../convert_opset2_to_opset1.cpp | 2 +-
 .../convert_opset3_to_opset2.cpp | 2 +-
 .../convert_pad_to_group_conv.cpp | 2 +-
 .../transformations/depth_to_space_fusion.cpp | 5 -
 .../include/ngraph/pass/graph_rewrite.hpp | 24 ++-
 ngraph/core/include/ngraph/pass/manager.hpp | 47 +++--
 ngraph/core/include/ngraph/pass/pass.hpp | 40 ++++-
 .../core/include/ngraph/pass/pass_config.hpp | 168 +++++++++++++++---
 ngraph/core/src/pass/graph_rewrite.cpp | 14 +-
 ngraph/core/src/pass/manager.cpp | 9 +-
 ngraph/core/src/pass/pass.cpp | 4 +-
 ngraph/core/src/pass/pass_config.cpp | 92 +---------
 ngraph/test/graph_rewrite.cpp | 144 ++++++++++++++-
 ngraph/test/runtime/CMakeLists.txt | 2 +
 ngraph/test/runtime/backend.cpp | 8 -
 ngraph/test/runtime/backend.hpp | 9 -
 20 files changed, 462 insertions(+), 193 deletions(-)

diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp index 57d23d3c6b1..25a2140027c 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp @@ -27,8 +27,21 @@ #include #include #include +#include +#include #include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include #include @@ -63,31 +76,6 @@ Engine::~Engine() { static void Transformation(ICNNNetwork::Ptr& clonedNetwork) { OV_ITT_SCOPED_TASK(MKLDNNPlugin::itt::domains::MKLDNNPlugin, "Transformation"); - const auto transformations_callback = [](const std::shared_ptr &node) -> bool { - // DepthToSpace node implementation supports only equal input/output tensors with rank <= 5 - if (auto dtsOp = std::dynamic_pointer_cast(node)) { - return dtsOp->input_value(0).get_shape().size() <= 5lu && dtsOp->input_value(0).get_shape().size() == dtsOp->get_output_shape(0).size(); - } - - // SpaceToDepth node implementation supports only equal input/output tensors with rank <= 5 - if (auto stdOp = std::dynamic_pointer_cast(node)) { - return stdOp->input_value(0).get_shape().size() <= 5lu && stdOp->input_value(0).get_shape().size() == stdOp->get_output_shape(0).size(); - } - - if (auto fc_op = std::dynamic_pointer_cast(node)) { - return fc_op->input_value(0).get_shape().size() == 3ul; - } - - return std::dynamic_pointer_cast(node) || - std::dynamic_pointer_cast(node) || - std::dynamic_pointer_cast(node) || - std::dynamic_pointer_cast(node) || - std::dynamic_pointer_cast(node) || - std::dynamic_pointer_cast(node) || - std::dynamic_pointer_cast(node) || - std::dynamic_pointer_cast(node) || - std::dynamic_pointer_cast(node); - }; auto nGraphFunc = clonedNetwork->getFunction(); // Disable shape inference (WA for generic operations) ngraph::op::GenericIE::DisableReshape
noReshape(nGraphFunc); @@ -116,7 +104,41 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork) { manager.register_pass(); manager.register_pass(ngraph::element::i64, ngraph::element::i32); - manager.set_callback(transformations_callback); + auto pass_config = manager.get_pass_config(); + + using const_node_ptr = const std::shared_ptr; + + // SpaceToDepth/ DepthToSpace node implementation supports only equal input/output tensors with rank <= 5 + pass_config->set_callback( + [](const_node_ptr &node) -> bool { + return node->input_value(0).get_shape().size() <= 5lu && + node->input_value(0).get_shape().size() == node->get_output_shape(0).size(); + }); + + // Disable FC reshaping for 3D case + pass_config->set_callback( + [](const_node_ptr &node) -> bool { + return node->input_value(0).get_shape().size() == 3ul; + }); + + pass_config->set_callback( + [](const_node_ptr &node) -> bool { + const auto & rank = node->input(0).get_partial_shape().rank().get_length(); + return rank == 4lu || rank == 5lu; + }); + + // List of enabled/disabled transformations + pass_config->disable(); + pass_config->disable(); + pass_config->disable(); + pass_config->disable(); + pass_config->disable(); + pass_config->disable(); + + pass_config->enable(); + manager.run_passes(nGraphFunc); clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork); diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp index fae8de5cfcb..49894093854 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -68,7 +68,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr(); manager.register_pass(); manager.register_pass(); - manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(); manager.register_pass(); @@ -111,7 +111,8 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptradd_matcher(); fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions"); - manager.set_callback(m_transformation_callback); + // Propagate local PassConfig to internal pass::Manager + manager.set_pass_config(get_pass_config()); manager.run_passes(f); return true; } diff --git a/inference-engine/src/transformations/src/transformations/convert_depth_to_space.cpp b/inference-engine/src/transformations/src/transformations/convert_depth_to_space.cpp index 854301986ec..e2d9de35f92 100644 --- a/inference-engine/src/transformations/src/transformations/convert_depth_to_space.cpp +++ b/inference-engine/src/transformations/src/transformations/convert_depth_to_space.cpp @@ -18,7 +18,7 @@ ngraph::pass::ConvertDepthToSpace::ConvertDepthToSpace() { ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) { auto dts_node = std::dynamic_pointer_cast (m.get_match_root()); - if (!dts_node || m_transformation_callback(dts_node)) { + if (!dts_node || transformation_callback(dts_node)) { return false; } diff --git a/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.cpp b/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.cpp index b6784268879..77eec042231 100644 --- 
a/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.cpp +++ b/inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.cpp @@ -154,7 +154,7 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr(); - manager.set_callback(m_transformation_callback); + manager.set_pass_config(get_pass_config()); manager.run_passes(f); return true; } diff --git a/inference-engine/src/transformations/src/transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.cpp b/inference-engine/src/transformations/src/transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.cpp index 80b66b3ac0d..ded57834ab6 100644 --- a/inference-engine/src/transformations/src/transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.cpp +++ b/inference-engine/src/transformations/src/transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.cpp @@ -24,7 +24,7 @@ bool ngraph::pass::ConvertOpSet2ToOpSet1::run_on_function(std::shared_ptr(); manager.register_pass(); - manager.set_callback(m_transformation_callback); + manager.set_pass_config(get_pass_config()); manager.run_passes(f); return true; } diff --git a/inference-engine/src/transformations/src/transformations/convert_opset3_to_opset2/convert_opset3_to_opset2.cpp b/inference-engine/src/transformations/src/transformations/convert_opset3_to_opset2/convert_opset3_to_opset2.cpp index 72db0fbc9b5..968925691e2 100644 --- a/inference-engine/src/transformations/src/transformations/convert_opset3_to_opset2/convert_opset3_to_opset2.cpp +++ b/inference-engine/src/transformations/src/transformations/convert_opset3_to_opset2/convert_opset3_to_opset2.cpp @@ -33,7 +33,7 @@ bool ngraph::pass::ConvertOpSet3ToOpSet2::run_on_function(std::shared_ptr(); manager.register_pass(); - manager.set_callback(m_transformation_callback); + manager.set_pass_config(get_pass_config()); manager.run_passes(f); return true; } diff --git a/inference-engine/src/transformations/src/transformations/convert_pad_to_group_conv.cpp b/inference-engine/src/transformations/src/transformations/convert_pad_to_group_conv.cpp index 9576aa7bcf5..b987e9622c6 100644 --- a/inference-engine/src/transformations/src/transformations/convert_pad_to_group_conv.cpp +++ b/inference-engine/src/transformations/src/transformations/convert_pad_to_group_conv.cpp @@ -19,7 +19,7 @@ ngraph::pass::ConvertPadToGroupConvolution::ConvertPadToGroupConvolution() { ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) { auto pad = std::dynamic_pointer_cast (m.get_match_root()); - if (!pad || !m_transformation_callback(pad) /* disabled by default */) { + if (!pad) { return false; } diff --git a/inference-engine/src/transformations/src/transformations/depth_to_space_fusion.cpp b/inference-engine/src/transformations/src/transformations/depth_to_space_fusion.cpp index 75f38a1dfb1..411bda63fff 100644 --- a/inference-engine/src/transformations/src/transformations/depth_to_space_fusion.cpp +++ b/inference-engine/src/transformations/src/transformations/depth_to_space_fusion.cpp @@ -155,11 +155,6 @@ void ngraph::pass::DepthToSpaceFusion::depth_to_space_fusion() { std::make_shared(reshape_before->input_value(0), mode, block_size); depth_to_space->set_friendly_name(reshape_after->get_friendly_name()); ngraph::copy_runtime_info({reshape_before, permute, reshape_after}, depth_to_space); - - if (!m_transformation_callback(depth_to_space)) { - return false; - } - ngraph::replace_node(reshape_after, 
depth_to_space); return true; }; diff --git a/ngraph/core/include/ngraph/pass/graph_rewrite.hpp b/ngraph/core/include/ngraph/pass/graph_rewrite.hpp index dbfb47097d9..4b19bb5894c 100644 --- a/ngraph/core/include/ngraph/pass/graph_rewrite.hpp +++ b/ngraph/core/include/ngraph/pass/graph_rewrite.hpp @@ -126,12 +126,34 @@ public: m_matchers.push_back(pass); } - template + /// \brief Register a given transformation class type in the GraphRewrite execution list. + /// All registered transformations will be executed in a single graph traversal. + /// Example below shows the basic usage of pass::GraphRewrite + /// + /// pass::Manager manager; + /// auto anchor = manager.register_pass(); + /// anchor->add_matcher(); + /// anchor->add_matcher(); + /// anchor->set_name("CommonMatchers"); + /// manager.run_passes(f); + /// + /// For some purposes a transformation can be registered but disabled by default. + /// + /// anchor->add_matcher(); + /// + /// \return shared_ptr to the transformation instance + template std::shared_ptr add_matcher(Args&&... args) { static_assert(std::is_base_of::value, "pass not derived from MatcherPass"); auto pass = std::make_shared(std::forward(args)...); + auto pass_config = get_pass_config(); + pass->set_pass_config(pass_config); + if (!Enabled) + { + pass_config->disable(); + } m_matchers.push_back(pass); return pass; } diff --git a/ngraph/core/include/ngraph/pass/manager.hpp b/ngraph/core/include/ngraph/pass/manager.hpp index 5438a6c29f0..6a0060f406b 100644 --- a/ngraph/core/include/ngraph/pass/manager.hpp +++ b/ngraph/core/include/ngraph/pass/manager.hpp @@ -22,7 +22,6 @@ #include #include "ngraph/pass/pass.hpp" -#include "ngraph/pass/pass_config.hpp" #include "ngraph/pass/validate.hpp" namespace ngraph @@ -39,14 +38,31 @@ public: Manager(); ~Manager(); - template + /// \brief Register a given transformation class type in the execution list. + /// Example below shows the basic usage of pass::Manager + /// + /// pass::Manager manager; + /// manager.register_pass(/*transformation constructor args*/); + /// manager.run_passes(f); + /// + /// For some purposes a transformation can be registered but disabled by default. + /// + /// manager.register_pass(); + /// + /// \return shared_ptr to the transformation instance + template std::shared_ptr register_pass(Args&&... args) { auto rc = push_pass(std::forward(args)...); + rc->set_pass_config(m_pass_config); if (m_per_pass_validation) { push_pass(); } + if (!Enable) + { + m_pass_config->disable(); + } return rc; } @@ -59,8 +75,10 @@ public: void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; } /// \brief Callback is a lambda function that can be used by registered transformations. /// The main purpose of this callback is to provide a way for plugins to disable/enable - /// transformations. In some cases plugins may want not to execute some transformations. - /// For example plugin can disable unpleasant decompositions because of performance reasons. + /// transformations based on some conditions. In some cases plugins may not want to execute some + /// transformations. + /// For example a plugin can disable unwanted decompositions because of performance reasons for + /// some cases. /// Callback example: /// auto callback = [](const std::shared_ptr & node) -> bool { /// return std::dynamic_pointer_cast(node) != nullptr; /// }; /// This callback returns true in case of DepthToSpace operation. So when execution DepthToSpace /// decomposition pass will check is this decomposition needed or plugin can execute this /// operation directly.
And of course on transformation side we need to have a response for this /// callback. - /// if (m_transformation_callback(batch_to_space)) { + /// if (transformation_callback(batch_to_space)) { /// return false; /// } /// \param callback lambda function that returns true in case if node is supported by plugin and /// transformation is not needed - void set_callback(param_callback callback) + NGRAPH_DEPRECATED("Please use get_pass_config() to configure transformation pipeline") + void set_callback(const param_callback& callback) { m_pass_config->set_callback(callback); } + /// \return PassConfig shared object. This object is used for transformation pipeline + /// configuration. + /// This object allows disabling/enabling transformation execution and setting callbacks for + /// particular transformations. For more details see the PassConfig class. + std::shared_ptr get_pass_config() { return m_pass_config; } + /// \brief Set external PassConfig object. + void set_pass_config(const std::shared_ptr& pass_config) { - m_transformation_callback = callback; - m_has_default_callback = false; + *m_pass_config = *pass_config; } protected: @@ -91,11 +116,7 @@ protected: return pass; } - param_callback m_transformation_callback = [](const std::shared_ptr&) -> bool { - return false; - }; - bool m_has_default_callback = true; - + std::shared_ptr m_pass_config; std::vector> m_pass_list; bool m_visualize = false; bool m_per_pass_validation = true; diff --git a/ngraph/core/include/ngraph/pass/pass.hpp b/ngraph/core/include/ngraph/pass/pass.hpp index ae5039417d2..100d45438e1 100644 --- a/ngraph/core/include/ngraph/pass/pass.hpp +++ b/ngraph/core/include/ngraph/pass/pass.hpp @@ -23,6 +23,7 @@ #include "ngraph/deprecated.hpp" #include "ngraph/function.hpp" #include "ngraph/node.hpp" +#include "ngraph/pass/pass_config.hpp" #include "ngraph/util.hpp" namespace ngraph @@ -39,7 +40,6 @@ namespace ngraph typedef EnumMask PassPropertyMask; const PassPropertyMask all_pass_property_off; - using param_callback = std::function)>; class NGRAPH_API PassBase { @@ -54,8 +54,40 @@ namespace ngraph void set_name(const std::string& name) { m_name = name; } std::string get_name() const; + /// \brief Set callback for particular transformation type. + /// This method sets the global callback. For more details see the PassConfig class + /// documentation. + /// \param callback lambda function that takes node and returns bool void set_callback(const param_callback& callback); + /// \brief Set PassConfig for particular transformation instance + /// \param pass_config is a PassConfig shared_ptr + void set_pass_config(const std::shared_ptr& pass_config) + { + m_pass_config = pass_config; + } + + /// \brief Allows access to the PassConfig shared instance + /// \return Shared instance of PassConfig class + std::shared_ptr get_pass_config() { return m_pass_config; } + /// \brief Applies callback for given node. By default callback returns false. + /// This method remains here only for backward compatibility and will be removed + /// after all transformations are moved to the transformation_callback() method. + /// \return result of callback execution for given node + NGRAPH_DEPRECATED("Please use transformation_callback method instead") + bool m_transformation_callback(const std::shared_ptr& node) + { + return m_pass_config->get_callback(get_type_info())(node); + } + + /// \brief Applies callback for given node. By default callback returns false.
+ /// \param node which will be used inside callback + /// \return result of callback execution for given node + bool transformation_callback(const std::shared_ptr& node) + { + return m_pass_config->get_callback(get_type_info())(node); + } + using type_info_t = DiscreteTypeInfo; virtual const type_info_t& get_type_info() const = 0; @@ -63,13 +95,11 @@ namespace ngraph protected: void set_property(const PassPropertyMask& prop, bool value); - param_callback m_transformation_callback = - [](const std::shared_ptr&) -> bool { return false; }; - bool m_has_default_callback = true; - private: PassPropertyMask m_property; + std::string m_name; + std::shared_ptr m_pass_config; }; class NGRAPH_API FunctionPass : public PassBase diff --git a/ngraph/core/include/ngraph/pass/pass_config.hpp b/ngraph/core/include/ngraph/pass/pass_config.hpp index a592d879f07..7d85d1a26f3 100644 --- a/ngraph/core/include/ngraph/pass/pass_config.hpp +++ b/ngraph/core/include/ngraph/pass/pass_config.hpp @@ -16,31 +16,157 @@ #pragma once -#include -#include +#include +#include +#include -#include +#include "ngraph/deprecated.hpp" +#include "ngraph/function.hpp" +#include "ngraph/node.hpp" +#include "ngraph/util.hpp" namespace ngraph { namespace pass { - class PassConfig; + using param_callback = std::function)>; + using param_callback_map = std::map; + + /// \brief Class representing a transformation config that is used for disabling/enabling + /// transformations registered inside pass::Manager and also allows setting a callback for all + /// transformations or for a particular transformation. + /// + /// When pass::Manager is created, all passes registered inside this manager, including nested + /// passes, will share the same instance of the PassConfig class. + /// To work with this class, first get the shared instance of it by calling + /// the manager.get_pass_config() method. Then you will be able to disable/enable passes based + /// on the transformation type_info. For example: + /// + /// pass::Manager manager; + /// manager.register_pass(); + /// auto pass_config = manager.get_pass_config(); + /// pass_config->disable(); // this will disable the nested pass inside + /// // the CommonOptimizations pipeline + /// manager.run_passes(f); + /// + /// Sometimes a transformation needs to be called manually inside another transformation. In + /// that case, before running the nested transformation you need to manually check that this + /// pass is not disabled, and then you need to set the current PassConfig instance to this + /// transformation. For example: + /// + /// // Inside MatcherPass callback or inside FunctionPass run_on_function() method + /// // you need to call get_pass_config() method to get shared instance of PassConfig + /// auto pass_config = get_pass_config(); + /// + /// // Before running the nested transformation, check whether it is disabled + /// if (!pass_config->is_disabled()) { + /// auto pass = ConvertGELU(); + /// pass.set_pass_config(pass_config); + /// pass.apply(node); + /// } + /// + /// Following this logic inside your transformations, you guarantee that transformations + /// are executed in the right way.
+ class NGRAPH_API PassConfig + { + public: + /// \brief Disable transformation by its type_info + /// \param type_info Transformation type_info + void disable(const DiscreteTypeInfo& type_info) { m_disabled.insert(type_info); } + /// \brief Disable transformation by its class type (based on type_info) + template + void disable() + { + disable(T::type_info); + } + + /// \brief Enable transformation by its type_info + /// \param type_info Transformation type_info + void enable(const DiscreteTypeInfo& type_info) { m_disabled.erase(type_info); } + /// \brief Enable transformation by its class type (based on type_info) + template + void enable() + { + enable(T::type_info); + } + + /// \brief Set callback for all kinds of transformations + void set_callback(const param_callback& callback) { m_callback = callback; } + template + typename std::enable_if::type + set_callback(const param_callback& callback) + { + } + + /// \brief Set callback for particular transformation class types + /// + /// Example below shows how to set a callback for one or multiple passes using this method. + /// + /// pass_config->set_callback( + /// [](const_node_ptr &node) -> bool { + /// // Disable transformations for cases when input shape rank is not + /// equal to 4 + /// const auto input_shape_rank = + /// node->get_output_partial_shape(0).rank().get_length(); + /// if (input_shape_rank != 4) { + /// return false; + /// } + /// return true; + /// }); + /// + /// Note that inside transformations you must provide code that works with this callback. + /// See example below: + /// + /// if (transformation_callback(node)) { + /// return false; // exit from transformation + /// } + /// + template + void set_callback(const param_callback& callback) + { + m_callback_map[T::type_info] = callback; + set_callback(callback); + } + + /// \brief Get callback for given transformation type_info + /// \param type_info Transformation type_info + /// + /// If a callback wasn't set for the given transformation type then the global callback + /// will be returned. And if the global callback wasn't set either then the default callback will + /// be returned.
+ param_callback get_callback(const DiscreteTypeInfo& type_info) const; + + /// \brief Get callback for given transformation class type + /// \return callback lambda function + template + param_callback get_callback() const + { + return get_callback(T::type_info); + } + + /// \brief Check either transformation type is disabled or not + /// \param type_info Transformation type_info + /// \return true if transformation type was disabled and false otherwise + bool is_disabled(const DiscreteTypeInfo& type_info) const + { + return m_disabled.count(type_info); + } + + /// \brief Check either transformation class type is disabled or not + /// \return true if transformation type was disabled and false otherwise + template + bool is_disabled() const + { + return is_disabled(T::type_info); + } + + private: + param_callback m_callback = [](const std::shared_ptr&) { + return false; + }; + param_callback_map m_callback_map; + std::unordered_set m_disabled; + }; } -} - -class NGRAPH_API ngraph::pass::PassConfig -{ -public: - PassConfig(); - const std::map& get_enables() const { return m_pass_enables; } - void set_pass_enable(const std::string& name, bool enable); - bool get_pass_enable(const std::string& name) const; - const std::map& get_pass_attributes() const { return m_pass_attributes; } - void set_pass_attribute(const std::string& name, bool enable); - bool get_pass_attribute(const std::string& name) const; - -private: - std::map m_pass_enables; - std::map m_pass_attributes; -}; +} \ No newline at end of file diff --git a/ngraph/core/src/pass/graph_rewrite.cpp b/ngraph/core/src/pass/graph_rewrite.cpp index ef65eea37c1..66993ebd05a 100644 --- a/ngraph/core/src/pass/graph_rewrite.cpp +++ b/ngraph/core/src/pass/graph_rewrite.cpp @@ -72,6 +72,7 @@ bool pass::GraphRewrite::run_on_function(shared_ptr f) OV_ITT_SCOPED_TASK(itt::domains::nGraph, "pass::GraphRewrite::run_on_function"); bool rewritten = false; + const auto& pass_config = get_pass_config(); // Initialize execution queue with nodes in topological order deque> nodes_to_run; @@ -85,6 +86,10 @@ bool pass::GraphRewrite::run_on_function(shared_ptr f) std::unordered_map> type_to_matcher; for (size_t matcher_index = 0; matcher_index < m_matchers.size(); ++matcher_index) { + // Skip passes that are disabled + if (pass_config->is_disabled(m_matchers[matcher_index]->get_type_info())) + continue; + auto matcher = m_matchers[matcher_index]->get_matcher(); if (!matcher) { @@ -139,11 +144,6 @@ bool pass::GraphRewrite::run_on_function(shared_ptr f) return false; } - if (!m_has_default_callback) - { - m_pass->set_callback(m_transformation_callback); - } - // Apply MatcherPass. 
In case if it returns true no other MatcherPasses will apply // to this node bool status = m_pass->apply(node); @@ -224,6 +224,10 @@ bool pass::GraphRewrite::run_on_function(shared_ptr f) { for (auto& m_pass : m_matchers) { + // Skip passes that are disabled + if (pass_config->is_disabled(m_pass->get_type_info())) + continue; + if (run_matcher_pass(m_pass, node)) { rewritten = true; diff --git a/ngraph/core/src/pass/manager.cpp b/ngraph/core/src/pass/manager.cpp index 67045a7543d..e4c7044897f 100644 --- a/ngraph/core/src/pass/manager.cpp +++ b/ngraph/core/src/pass/manager.cpp @@ -36,6 +36,7 @@ using namespace ngraph; pass::Manager::Manager() : m_visualize(getenv_bool("NGRAPH_ENABLE_VISUALIZE_TRACING")) + , m_pass_config(std::make_shared()) { } @@ -56,12 +57,14 @@ void pass::Manager::run_passes(shared_ptr func) bool function_changed = false; for (auto& pass : m_pass_list) { - pass_timer.start(); - if (!m_has_default_callback) + if (m_pass_config->is_disabled(pass->get_type_info())) { - pass->set_callback(m_transformation_callback); + NGRAPH_DEBUG << "Pass " << pass->get_name() << " is disabled"; + continue; } + pass_timer.start(); + NGRAPH_SUPPRESS_DEPRECATED_START if (auto matcher_pass = dynamic_pointer_cast(pass)) { diff --git a/ngraph/core/src/pass/pass.cpp b/ngraph/core/src/pass/pass.cpp index bd4d3b03b16..4229e5a1fc1 100644 --- a/ngraph/core/src/pass/pass.cpp +++ b/ngraph/core/src/pass/pass.cpp @@ -33,6 +33,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::NodePass, "ngraph::pass::NodePass", 0); pass::PassBase::PassBase() : m_property{all_pass_property_off} + , m_pass_config(std::make_shared()) { } @@ -73,8 +74,7 @@ std::string pass::PassBase::get_name() const void pass::PassBase::set_callback(const param_callback& callback) { - m_transformation_callback = callback; - m_has_default_callback = false; + m_pass_config->set_callback(callback); } // The symbols are requiered to be in cpp file to workaround RTTI issue on Android LLVM diff --git a/ngraph/core/src/pass/pass_config.cpp b/ngraph/core/src/pass/pass_config.cpp index 1d2435a62e5..c123d4b18dd 100644 --- a/ngraph/core/src/pass/pass_config.cpp +++ b/ngraph/core/src/pass/pass_config.cpp @@ -15,100 +15,18 @@ //***************************************************************************** #include "ngraph/pass/pass_config.hpp" -#include "ngraph/env_util.hpp" -#include "ngraph/except.hpp" -#include "ngraph/log.hpp" -#include "ngraph/util.hpp" -using namespace std; using namespace ngraph; -// TODO: Add file-based configuration support -pass::PassConfig::PassConfig() +pass::param_callback pass::PassConfig::get_callback(const DiscreteTypeInfo& type_info) const { - // - // Parses the semi-colon separated environment string passed through NGRAPH_PASS_ENABLES - // and returns the pass names and whether they should be enabled or disabled in the - // provided unordered_map. 
Implementation of pass selection is up to the backend - // E.g., NGRAPH_PASS_ENABLES="CoreFusion:0;LikeReplacement:1;CPUCollapseDims" would - // set disables on CoreFusion and enables on LikeReplacement and CPUCollapseDims - // - string pass_enables = getenv_string("NGRAPH_PASS_ENABLES"); - if (!pass_enables.empty()) - { - stringstream ss; - ss << pass_enables; - while (ss.good()) - { - string substr; - getline(ss, substr, ';'); - auto split_str = split(substr, ':', false); - switch (split_str.size()) - { - case 1: m_pass_enables.emplace(split_str[0], true); break; - case 2: m_pass_enables.emplace(split_str[0], parse_string(split_str[1])); break; - default: throw ngraph_error("Unexpected string in NGRAPH_PASS_ENABLES: " + substr); - } - } - } - // - // Parses the semi-colon separated environment string passed through NGRAPH_PASS_ATTRIBUTES - // and returns the pass attributes and whether they should be enabled or disabled in the - // provided unordered_map. Naming of pass attributes is up to the backends. - // - // For example: - // NGRAPH_PASS_ATTRIBUTES="OptimizeForMemory=0;MemoryAssignment::ReuseMemory=1;UseDefaultLayouts" - // would set false on "OptimizeForMemory", true on "MemoryAssignment::ReuseMemory" and true on - // "UseDefaultLayouts" - // - static const string pass_attributes = getenv_string("NGRAPH_PASS_ATTRIBUTES"); - if (!pass_attributes.empty()) - { - stringstream ss; - ss << pass_attributes; - while (ss.good()) - { - string substr; - getline(ss, substr, ';'); - auto split_str = split(substr, '=', false); - switch (split_str.size()) - { - case 1: m_pass_attributes.emplace(split_str[0], true); break; - case 2: - m_pass_attributes.emplace(split_str[0], parse_string(split_str[1])); - break; - default: throw ngraph_error("Unexpected string in NGRAPH_PASS_ATTRIBUTES: " + substr); - } - } - } -} - -void pass::PassConfig::set_pass_enable(const string& name, bool enable) -{ - m_pass_enables[name] = enable; -} - -bool pass::PassConfig::get_pass_enable(const string& name) const -{ - auto it = m_pass_enables.find(name); - if (it != m_pass_enables.end()) + const auto& it = m_callback_map.find(type_info); + if (it != m_callback_map.end()) { return it->second; } - return false; -} - -void pass::PassConfig::set_pass_attribute(const string& name, bool enable) -{ - m_pass_attributes[name] = enable; -} - -bool pass::PassConfig::get_pass_attribute(const string& name) const -{ - auto it = m_pass_attributes.find(name); - if (it != m_pass_attributes.end()) + else { - return it->second; + return m_callback; } - return false; } diff --git a/ngraph/test/graph_rewrite.cpp b/ngraph/test/graph_rewrite.cpp index 5f3b13c53b3..e5c435d7e2a 100644 --- a/ngraph/test/graph_rewrite.cpp +++ b/ngraph/test/graph_rewrite.cpp @@ -15,6 +15,7 @@ using namespace ngraph; class TestPass : public ngraph::pass::MatcherPass { public: + NGRAPH_RTTI_DECLARATION; TestPass() : MatcherPass() { @@ -39,12 +40,16 @@ public: class Anchor : public ngraph::pass::GraphRewrite { public: + NGRAPH_RTTI_DECLARATION; Anchor() : GraphRewrite() { } }; +NGRAPH_RTTI_DEFINITION(TestPass, "TestPass", 0); +NGRAPH_RTTI_DEFINITION(Anchor, "Anchor", 0); + std::shared_ptr get_function() { auto data = @@ -93,7 +98,7 @@ TEST(GraphRewriteTest, GraphRewriteCallback) ASSERT_EQ(count_ops_of_type(f), 1); } -TEST(GraphRewriteTest, ManagerCallback) +TEST(GraphRewriteTest, ManagerCallbackDeprecated) { auto f = get_function(); @@ -106,6 +111,20 @@ TEST(GraphRewriteTest, ManagerCallback) ASSERT_EQ(count_ops_of_type(f), 1); } +TEST(GraphRewriteTest, 
ManagerCallback) +{ + auto f = get_function(); + + pass::Manager manager; + auto anchor = manager.register_pass(); + anchor->add_matcher(); + auto pass_config = manager.get_pass_config(); + pass_config->set_callback(get_callback()); + manager.run_passes(f); + + ASSERT_EQ(count_ops_of_type(f), 1); +} + TEST(GraphRewriteTest, ManagerCallback2) { auto f = get_function(); @@ -244,4 +263,127 @@ TEST(GraphRewriteTest, TypeBasedMatcherPassOrder2) anchor.run_on_function(f); ASSERT_EQ(count_ops_of_type(f), 1); +} + +TEST(PassConfigTest, Test1) +{ + { + auto f = get_function(); + + pass::Manager manager; + manager.register_pass(); + + auto pass_config = manager.get_pass_config(); + pass_config->set_callback(get_callback()); + + manager.run_passes(f); + + ASSERT_EQ(count_ops_of_type(f), 1); + } + + { + auto f = get_function(); + + pass::Manager manager; + manager.register_pass(); + + auto pass_config = manager.get_pass_config(); + pass_config->set_callback(get_callback()); + + manager.run_passes(f); + + ASSERT_EQ(count_ops_of_type(f), 1); + } + + { + auto f = get_function(); + + pass::Manager manager; + manager.register_pass(); + + auto pass_config = std::make_shared(); + pass_config->set_callback(get_callback()); + + manager.set_pass_config(pass_config); + manager.run_passes(f); + + ASSERT_EQ(count_ops_of_type(f), 1); + } + + { + auto f = get_function(); + + pass::Manager manager; + auto anchor = manager.register_pass(); + anchor->add_matcher(); + + auto pass_config = anchor->get_pass_config(); + pass_config->set_callback(get_callback()); + + manager.run_passes(f); + + ASSERT_EQ(count_ops_of_type(f), 1); + } + + { + auto f = get_function(); + + pass::Manager manager; + auto anchor = manager.register_pass(); + anchor->add_matcher(); + + auto pass_config = anchor->get_pass_config(); + pass_config->set_callback(get_callback()); + + manager.run_passes(f); + + ASSERT_EQ(count_ops_of_type(f), 1); + } + + { + auto pass_config = std::make_shared(); + + pass::Manager manager1; + pass::Manager manager2; + manager1.set_pass_config(pass_config); + manager2.set_pass_config(pass_config); + ASSERT_EQ(pass_config.use_count(), 1); + } + + { + auto f = get_function(); + + pass::Manager manager; + manager.register_pass(); + + auto pass_config = manager.get_pass_config(); + pass_config->set_callback(get_callback()); + + pass_config->disable(); + manager.run_passes(f); + ASSERT_EQ(count_ops_of_type(f), 0); + + pass_config->enable(); + manager.run_passes(f); + ASSERT_EQ(count_ops_of_type(f), 1); + } + + { + auto f = get_function(); + + pass::Manager manager; + auto anchor = manager.register_pass(); + anchor->add_matcher(); + + auto pass_config = manager.get_pass_config(); + pass_config->set_callback(get_callback()); + + pass_config->disable(); + manager.run_passes(f); + ASSERT_EQ(count_ops_of_type(f), 0); + + pass_config->enable(); + manager.run_passes(f); + ASSERT_EQ(count_ops_of_type(f), 1); + } } \ No newline at end of file diff --git a/ngraph/test/runtime/CMakeLists.txt b/ngraph/test/runtime/CMakeLists.txt index e37aba8b7bd..27cf6b0e61f 100644 --- a/ngraph/test/runtime/CMakeLists.txt +++ b/ngraph/test/runtime/CMakeLists.txt @@ -51,6 +51,8 @@ set (SRC pass/shape_relevance.hpp ) +disable_deprecated_warnings() + add_library(ngraph_backend SHARED ${SRC}) target_compile_definitions(ngraph_backend PRIVATE diff --git a/ngraph/test/runtime/backend.cpp b/ngraph/test/runtime/backend.cpp index da5a7bab7a7..2a2444a4208 100644 --- a/ngraph/test/runtime/backend.cpp +++ b/ngraph/test/runtime/backend.cpp @@ -102,14 +102,6 @@ 
std::shared_ptr throw std::invalid_argument("This backend does not support dynamic tensors"); } -std::shared_ptr - runtime::Backend::compile(std::shared_ptr func, - ngraph::pass::PassConfig& /* pass_config */, - bool enable_performance_data) -{ - return compile(func, enable_performance_data); -} - bool runtime::Backend::is_supported(const Node& /* node */) const { // The default behavior is that a backend does not support any ops. If this is not the case diff --git a/ngraph/test/runtime/backend.hpp b/ngraph/test/runtime/backend.hpp index b5dd2577504..b8757817377 100644 --- a/ngraph/test/runtime/backend.hpp +++ b/ngraph/test/runtime/backend.hpp @@ -22,7 +22,6 @@ #include "backend_visibility.hpp" #include "executable.hpp" #include "ngraph/function.hpp" -#include "ngraph/pass/pass_config.hpp" #include "ngraph/shape.hpp" #include "ngraph/type/element_type.hpp" #include "ngraph/util.hpp" @@ -111,14 +110,6 @@ public: virtual std::shared_ptr compile(std::shared_ptr func, bool enable_performance_data = false) = 0; - /// \brief Compiles a Function. - /// \param func The function to compile - /// \param pass_config Configuration object for defining compilation options - /// \returns compiled function or nullptr on failure - virtual std::shared_ptr compile(std::shared_ptr func, - ngraph::pass::PassConfig& pass_config, - bool enable_performance_data = false); - /// \brief Loads a previously saved Executable object from a stream. /// \param input_stream the opened input stream containing the saved Executable /// \returns A compiled function or throws an exception on error
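
Taken together, the changes above replace the single monolithic transformation callback with a PassConfig object that is shared by a pass::Manager and every pass registered in it, including nested passes. The sketch below gathers the new API introduced by this patch into one place. It is only an illustration: the quoted include paths follow the header locations touched by this patch, and the pass types used as template arguments (CommonOptimizations, ConvertDepthToSpace, DepthToSpaceFusion) are stand-ins picked from files in this patch, because the real template arguments in the hunks above were stripped and the exact pass lists tuned by the MKLDNN plugin cannot be recovered from this text.

    #include <memory>
    #include "ngraph/function.hpp"
    #include "ngraph/pass/manager.hpp"
    #include "ngraph/pass/pass_config.hpp"
    // Headers declaring CommonOptimizations, ConvertDepthToSpace and DepthToSpaceFusion live under
    // inference-engine/src/transformations; their exact include paths are not shown in this patch.

    void configure_pipeline(const std::shared_ptr<ngraph::Function>& f)
    {
        ngraph::pass::Manager manager;

        // Every pass registered here, including passes nested inside it, shares one PassConfig.
        manager.register_pass<ngraph::pass::CommonOptimizations>();
        auto pass_config = manager.get_pass_config();

        // Fine-grain tuning: turn individual passes on or off by their type.
        pass_config->disable<ngraph::pass::ConvertDepthToSpace>();  // illustrative choice
        pass_config->enable<ngraph::pass::DepthToSpaceFusion>();    // illustrative choice

        // Per-pass callback: return true when the plugin executes the node natively,
        // so the transformation should skip it (the pass checks this via transformation_callback()).
        pass_config->set_callback<ngraph::pass::ConvertDepthToSpace>(
            [](const std::shared_ptr<const ngraph::Node>& node) -> bool {
                return node->input_value(0).get_shape().size() <= 5;
            });

        manager.run_passes(f);
    }

On the transformation side the callback is consumed exactly as the convert_depth_to_space.cpp hunk above shows: the matcher calls transformation_callback(node) and returns false when the plugin reports that it handles the node directly.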