From 9eca6ba9d5b4b4422f726f417e25c826ef72e0c9 Mon Sep 17 00:00:00 2001 From: Ilya Churaev Date: Thu, 2 Sep 2021 10:03:04 +0300 Subject: [PATCH] Move pass pattern to ov (#7255) * Moved ngraph::Node to ov namespace * Fixed code style * Fixed VPU * Fixed GNA * Fixed tests * Added aliases for backward compatibility * Fix clDNN * Try to fix build * Fixed comment * Renamed RTTI macros * Add new headers * Fixed ngraph build * Fixed unit tests * Try to fix Serialize --- CMakeLists.txt | 2 +- .../src/transformations/serialize.cpp | 54 ++-- .../include/ngraph/pass/constant_folding.hpp | 18 +- .../ngraph/pass/convert_fp32_to_fp16.hpp | 9 +- .../include/ngraph/pass/graph_rewrite.hpp | 241 +--------------- .../core/include/ngraph/pass/low_latency.hpp | 40 +-- ngraph/core/include/ngraph/pass/manager.hpp | 100 +------ ngraph/core/include/ngraph/pass/pass.hpp | 107 ++----- .../core/include/ngraph/pass/pass_config.hpp | 160 +---------- ngraph/core/include/ngraph/pass/validate.hpp | 21 +- .../include/ngraph/pass/visualize_tree.hpp | 38 +-- .../core/include/ngraph/pattern/matcher.hpp | 254 +--------------- ngraph/core/include/ngraph/pattern/op/any.hpp | 30 +- .../core/include/ngraph/pattern/op/any_of.hpp | 39 +-- .../include/ngraph/pattern/op/any_output.hpp | 15 +- .../core/include/ngraph/pattern/op/branch.hpp | 40 +-- .../include/ngraph/pattern/op/capture.hpp | 29 +- .../core/include/ngraph/pattern/op/label.hpp | 98 +------ ngraph/core/include/ngraph/pattern/op/or.hpp | 17 +- .../include/ngraph/pattern/op/pattern.hpp | 97 ++----- .../core/include/ngraph/pattern/op/skip.hpp | 29 +- .../core/include/ngraph/pattern/op/true.hpp | 13 +- .../include/ngraph/pattern/op/wrap_type.hpp | 60 +--- ngraph/core/include/openvino/core/node.hpp | 12 +- .../openvino/pass/constant_folding.hpp | 28 ++ .../openvino/pass/convert_fp32_to_fp16.hpp | 17 ++ .../include/openvino/pass/graph_rewrite.hpp | 249 ++++++++++++++++ .../include/openvino/pass/low_latency.hpp | 48 ++++ ngraph/core/include/openvino/pass/manager.hpp | 116 ++++++++ ngraph/core/include/openvino/pass/pass.hpp | 110 +++++++ .../include/openvino/pass/pass_config.hpp | 176 ++++++++++++ .../include/openvino/pass/pattern/matcher.hpp | 271 ++++++++++++++++++ .../include/openvino/pass/pattern/op/any.hpp | 45 +++ .../openvino/pass/pattern/op/any_of.hpp | 54 ++++ .../openvino/pass/pattern/op/any_output.hpp | 30 ++ .../openvino/pass/pattern/op/branch.hpp | 55 ++++ .../openvino/pass/pattern/op/capture.hpp | 44 +++ .../openvino/pass/pattern/op/label.hpp | 113 ++++++++ .../include/openvino/pass/pattern/op/or.hpp | 32 +++ .../openvino/pass/pattern/op/pattern.hpp | 96 +++++++ .../include/openvino/pass/pattern/op/skip.hpp | 44 +++ .../include/openvino/pass/pattern/op/true.hpp | 28 ++ .../openvino/pass/pattern/op/wrap_type.hpp | 75 +++++ .../core/include/openvino/pass/validate.hpp | 32 +++ .../include/openvino/pass/visualize_tree.hpp | 57 ++++ ngraph/core/src/pass/constant_folding.cpp | 15 +- ngraph/core/src/pass/convert_fp32_to_fp16.cpp | 7 +- ngraph/core/src/pass/graph_rewrite.cpp | 69 +++-- ngraph/core/src/pass/low_latency.cpp | 64 +++-- ngraph/core/src/pass/manager.cpp | 27 +- ngraph/core/src/pass/pass.cpp | 23 +- ngraph/core/src/pass/pass_config.cpp | 12 +- ngraph/core/src/pass/perf_counters.cpp | 4 +- ngraph/core/src/pass/perf_counters.hpp | 4 +- ngraph/core/src/pass/validate.cpp | 6 +- ngraph/core/src/pattern/matcher.cpp | 16 +- ngraph/core/src/pattern/op/any.cpp | 11 +- ngraph/core/src/pattern/op/any_of.cpp | 13 +- ngraph/core/src/pattern/op/any_output.cpp | 11 +- ngraph/core/src/pattern/op/label.cpp | 17 +- ngraph/core/src/pattern/op/pattern.cpp | 6 +- ngraph/test/runtime/pass/dyn_elimination.cpp | 36 +-- 62 files changed, 2028 insertions(+), 1556 deletions(-) create mode 100644 ngraph/core/include/openvino/pass/constant_folding.hpp create mode 100644 ngraph/core/include/openvino/pass/convert_fp32_to_fp16.hpp create mode 100644 ngraph/core/include/openvino/pass/graph_rewrite.hpp create mode 100644 ngraph/core/include/openvino/pass/low_latency.hpp create mode 100644 ngraph/core/include/openvino/pass/manager.hpp create mode 100644 ngraph/core/include/openvino/pass/pass.hpp create mode 100644 ngraph/core/include/openvino/pass/pass_config.hpp create mode 100644 ngraph/core/include/openvino/pass/pattern/matcher.hpp create mode 100644 ngraph/core/include/openvino/pass/pattern/op/any.hpp create mode 100644 ngraph/core/include/openvino/pass/pattern/op/any_of.hpp create mode 100644 ngraph/core/include/openvino/pass/pattern/op/any_output.hpp create mode 100644 ngraph/core/include/openvino/pass/pattern/op/branch.hpp create mode 100644 ngraph/core/include/openvino/pass/pattern/op/capture.hpp create mode 100644 ngraph/core/include/openvino/pass/pattern/op/label.hpp create mode 100644 ngraph/core/include/openvino/pass/pattern/op/or.hpp create mode 100644 ngraph/core/include/openvino/pass/pattern/op/pattern.hpp create mode 100644 ngraph/core/include/openvino/pass/pattern/op/skip.hpp create mode 100644 ngraph/core/include/openvino/pass/pattern/op/true.hpp create mode 100644 ngraph/core/include/openvino/pass/pattern/op/wrap_type.hpp create mode 100644 ngraph/core/include/openvino/pass/validate.hpp create mode 100644 ngraph/core/include/openvino/pass/visualize_tree.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 61a96ae9f4c..2cec8d2d5e8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,7 +26,7 @@ endif() # resolving dependencies for the project message (STATUS "PROJECT ............................... " ${PROJECT_NAME}) message (STATUS "CMAKE_BINARY_DIR ...................... " ${CMAKE_BINARY_DIR}) -message (STATUS "OpenVINO_SOURCE_DIR .... .......... " ${OpenVINO_SOURCE_DIR}) +message (STATUS "OpenVINO_SOURCE_DIR ................... " ${OpenVINO_SOURCE_DIR}) message (STATUS "CMAKE_GENERATOR ....................... " ${CMAKE_GENERATOR}) message (STATUS "CMAKE_C_COMPILER_ID ................... " ${CMAKE_C_COMPILER_ID}) message (STATUS "CMAKE_BUILD_TYPE ...................... " ${CMAKE_BUILD_TYPE}) diff --git a/inference-engine/src/transformations/src/transformations/serialize.cpp b/inference-engine/src/transformations/src/transformations/serialize.cpp index 3fbb1463f08..b10c5da24fa 100644 --- a/inference-engine/src/transformations/src/transformations/serialize.cpp +++ b/inference-engine/src/transformations/src/transformations/serialize.cpp @@ -811,8 +811,34 @@ void ngfunction_2_irv10(pugi::xml_node& netXml, f.validate_nodes_and_infer_types(); } } + +std::string valid_xml_path(const std::string &path) { + NGRAPH_CHECK(path.length() > 4, "Path for xml file is to short: \"" + path + "\""); + + const char *const extension = ".xml"; + const bool has_xml_extension = path.rfind(extension) == path.size() - std::strlen(extension); + NGRAPH_CHECK(has_xml_extension, + "Path for xml file doesn't contains file name with 'xml' extension: \"" + + path + "\""); + return path; +} + +std::string provide_bin_path(const std::string &xmlPath, const std::string &binPath) { + if (!binPath.empty()) { + return binPath; + } + assert(xmlPath.size() > 4); // should be check by valid_xml_path + std::string bestPath = xmlPath; + const char *const extension = "bin"; + const auto ext_size = std::strlen(extension); + bestPath.replace(bestPath.size() - ext_size, ext_size, extension); + return bestPath; +} + } // namespace +namespace ngraph { + // ! [function_pass:serialize_cpp] // serialize.cpp bool pass::Serialize::run_on_function(std::shared_ptr f) { @@ -868,33 +894,6 @@ bool pass::Serialize::run_on_function(std::shared_ptr f) { return false; } -namespace { - -std::string valid_xml_path(const std::string &path) { - NGRAPH_CHECK(path.length() > 4, "Path for xml file is to short: \"" + path + "\""); - - const char *const extension = ".xml"; - const bool has_xml_extension = path.rfind(extension) == path.size() - std::strlen(extension); - NGRAPH_CHECK(has_xml_extension, - "Path for xml file doesn't contains file name with 'xml' extension: \"" + - path + "\""); - return path; -} - -std::string provide_bin_path(const std::string &xmlPath, const std::string &binPath) { - if (!binPath.empty()) { - return binPath; - } - assert(xmlPath.size() > 4); // should be check by valid_xml_path - std::string bestPath = xmlPath; - const char *const extension = "bin"; - const auto ext_size = std::strlen(extension); - bestPath.replace(bestPath.size() - ext_size, ext_size, extension); - return bestPath; -} - -} // namespace - pass::Serialize::Serialize(std::ostream& xmlFile, std::ostream& binFile, pass::Serialize::Version version, @@ -921,3 +920,4 @@ pass::Serialize::Serialize(const std::string& xmlPath, { } // ! [function_pass:serialize_cpp] +} // namespace ngraph diff --git a/ngraph/core/include/ngraph/pass/constant_folding.hpp b/ngraph/core/include/ngraph/pass/constant_folding.hpp index 479704dc699..884bfa4551c 100644 --- a/ngraph/core/include/ngraph/pass/constant_folding.hpp +++ b/ngraph/core/include/ngraph/pass/constant_folding.hpp @@ -5,24 +5,10 @@ #pragma once #include "ngraph/pass/pass.hpp" +#include "openvino/pass/constant_folding.hpp" namespace ngraph { namespace pass { -/** - * @brief Constant folding iterates over the function and tries to evaluate nodes - * with constant inputs. Such nodes are then replaced with new Constants containing - * the result of a folded operation. - */ -class NGRAPH_API ConstantFolding : public FunctionPass { -public: - NGRAPH_RTTI_DECLARATION; - bool run_on_function(std::shared_ptr f) override; - -private: - void copy_runtime_info_to_target_inputs(const std::shared_ptr& node, const Output& replacement); - /// \brief Folds pre-calculated output tensor values to constants in case lower and - /// upper estimations are equal. Traverses graph backwards starting from the results. - bool pre_calculated_values_folding(const std::shared_ptr& f); -}; +using ov::pass::ConstantFolding; } // namespace pass } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pass/convert_fp32_to_fp16.hpp b/ngraph/core/include/ngraph/pass/convert_fp32_to_fp16.hpp index 5753c0fd7ea..302eb89677a 100644 --- a/ngraph/core/include/ngraph/pass/convert_fp32_to_fp16.hpp +++ b/ngraph/core/include/ngraph/pass/convert_fp32_to_fp16.hpp @@ -4,14 +4,11 @@ #pragma once -#include +#include "ngraph/pass/graph_rewrite.hpp" +#include "openvino/pass/convert_fp32_to_fp16.hpp" namespace ngraph { namespace pass { -class NGRAPH_API ConvertFP32ToFP16 : public ngraph::pass::FunctionPass { -public: - NGRAPH_RTTI_DECLARATION; - bool run_on_function(std::shared_ptr) override; -}; +using ov::pass::ConvertFP32ToFP16; } // namespace pass } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pass/graph_rewrite.hpp b/ngraph/core/include/ngraph/pass/graph_rewrite.hpp index 17fd2b732f6..66daa15fb3b 100644 --- a/ngraph/core/include/ngraph/pass/graph_rewrite.hpp +++ b/ngraph/core/include/ngraph/pass/graph_rewrite.hpp @@ -10,240 +10,17 @@ #include "ngraph/pass/pass.hpp" #include "ngraph/pattern/matcher.hpp" +#include "openvino/pass/graph_rewrite.hpp" namespace ngraph { -using matcher_pass_callback = std::function; -using graph_rewrite_callback = std::function; -using recurrent_graph_rewrite_callback = std::function; -using handler_callback = std::function& node)>; +using ov::graph_rewrite_callback; +using ov::handler_callback; +using ov::matcher_pass_callback; +using ov::recurrent_graph_rewrite_callback; namespace pass { -/// \brief MatcherPass is a basic block for pattern based transformations. It describes -/// pattern and -/// action that is applied if pattern is matched. -/// -/// MatcherPass consists of Matcher and matcher_pass_callback that needs to be implemented -/// and -/// finally registered by using \sa register_matcher. MatcherPass can be executed on node -/// within -/// \sa apply method. To run matcher pass on Function use GraphRewrite. -/// In addition MatcherPass provides a way for adding new operations into GraphRewrite -/// execution -/// queue. That means that operations that were created inside transformation callback can -/// be added -/// for matching. To register node use \sa register_new_node method. GraphRewrite -/// automatically -/// takes registered nodes and put them to execution queue. If multiple nodes were register -/// make -/// sure that they were registered in topological order. -/// Note: when implementing pattern for Matcher make sure that root node is an operation -/// from opset -/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher -/// passes more -/// efficient. - -class NGRAPH_API MatcherPass : public ngraph::pass::PassBase { -public: - NGRAPH_RTTI_DECLARATION; - - MatcherPass() = default; - - MatcherPass(const MatcherPass&) = delete; - MatcherPass& operator=(const MatcherPass&) = delete; - - explicit MatcherPass(const std::string& name, - const std::shared_ptr& m, - const handler_callback& handler, - const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE) - : PassBase(), - m_handler(handler), - m_matcher(m) { - set_name(name); - set_property(property, true); - } - - bool apply(std::shared_ptr node); - - template - std::shared_ptr register_new_node(Args&&... args) { - auto node = std::make_shared(std::forward(args)...); - m_new_nodes.push_back(node); - return node; - } - - template - std::shared_ptr register_new_node(const std::shared_ptr& node) { - m_new_nodes.push_back(node); - return node; - } - - const std::vector>& get_new_nodes() { - return m_new_nodes; - } - void clear_new_nodes() { - m_new_nodes.clear(); - } - std::shared_ptr get_matcher() { - return m_matcher; - } - -protected: - void register_matcher(const std::shared_ptr& m, - const ngraph::graph_rewrite_callback& callback, - const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE); - -private: - handler_callback m_handler; - std::shared_ptr m_matcher; - std::vector> m_new_nodes; -}; - -/// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function -/// in -/// efficient way -/// -/// Graph rewrite pass is used for matcher passes execution on Function. -/// To register MatcherPass use \sa add_matcher(args) method where T is a MatcherPass -/// class. -/// As a default algorithm graph rewrite pass traverse Function in topological order and -/// applies -/// registered matcher passes for each node. But if all registered matcher passes have type -/// based -/// root node in Matcher pattern then efficient mechanism is used to execute them. -/// Matcher pattern root is type based if it's operation from opset or -/// pattern::op::WrapType. -/// Note: when implementing pattern for Matcher make sure that root node is an operation -/// from opset -/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher -/// passes more -/// efficient. - -class NGRAPH_API GraphRewrite : public ngraph::pass::FunctionPass { -public: - NGRAPH_RTTI_DECLARATION; - - GraphRewrite() = default; - - explicit GraphRewrite(const std::shared_ptr& pass) : FunctionPass() { - m_matchers.push_back(pass); - } - - /// \brief Register given transformation class type to GraphRewrite execution list - /// All registered transformations will be executed in a single graph traversal. - /// Example below show the basic usage of pass::GraphRewrite - /// - /// pass::Manager manager; - /// auto anchor = manager.register_pass(); - /// anchor->add_matcher(); - /// anchor->add_matcher(); - /// anchor->set_name("CommonMatchers"); - /// manager.run_passes(f); - /// - /// For some purposes transformation can be registered and disabled by default. - /// - /// anchor->add_matcher(); - /// - /// \return shared_ptr to the transformation instance - template ::value, bool>::type = true> - std::shared_ptr add_matcher(Args&&... args) { - static_assert(std::is_base_of::value, "pass not derived from MatcherPass"); - auto pass = std::make_shared(std::forward(args)...); - auto pass_config = get_pass_config(); - pass->set_pass_config(pass_config); - if (!Enabled && !pass_config->is_enabled()) { - pass_config->disable(); - } - m_matchers.push_back(pass); - return pass; - } - - /// \brief Register passes from GraphRewrite class that contains sequence of matcher - /// passes registered in its ctor. - /// For example: - /// - /// class ngraph::pass::LinFusions: public ngraph::pass::GraphRewrite { - /// public: - /// NGRAPH_RTTI_DECLARATION; - /// Fusions() { - /// add_matcher(); - /// add_matcher(); - /// } - /// }; - /// - /// pass::Manager manager; - /// auto anchor = manager.register_pass(); - /// anchor->add_matcher(); - /// anchor->add_matcher(); - /// anchor->set_name("CommonFusions"); - /// manager.run_passes(f); - /// - /// In this case all matcher passes from LinFusions pass will be united with other - /// registered matchers. - template ::value, bool>::type = true> - void add_matcher(Args&&... args) { - static_assert(std::is_base_of::value, "pass not derived from GraphRewrite"); - auto pass = std::make_shared(std::forward(args)...); - auto pass_config = get_pass_config(); - - for (auto& matcher : pass->m_matchers) { - pass->set_pass_config(pass_config); - m_matchers.push_back(matcher); - } - } - - NGRAPH_DEPRECATED("Use MatcherPass instead") - void add_matcher(const std::shared_ptr& m, - const ngraph::graph_rewrite_callback& callback, - const PassPropertyMask& property); - - NGRAPH_DEPRECATED("Use MatcherPass instead") - void add_matcher(const std::shared_ptr& m, const ngraph::graph_rewrite_callback& callback); - - bool run_on_function(std::shared_ptr f) override; - - void set_pass_config(const std::shared_ptr& pass_config) override; - -protected: - bool apply_matcher_passes(std::shared_ptr f, std::deque> nodes_to_run); - - bool m_enable_shape_inference = false; - - std::vector> m_matchers; -}; - -class NGRAPH_API BackwardGraphRewrite : public ngraph::pass::GraphRewrite { -public: - NGRAPH_RTTI_DECLARATION; - - BackwardGraphRewrite() = default; - - explicit BackwardGraphRewrite(const std::shared_ptr& pass) : GraphRewrite(pass) {} - - bool run_on_function(std::shared_ptr f) override; -}; - -class NGRAPH_API RecurrentGraphRewrite : public ngraph::pass::FunctionPass { -public: - RecurrentGraphRewrite(size_t num_iters = 10) : FunctionPass(), m_num_iters(num_iters) {} - - void add_matcher(const std::shared_ptr& m, - const ngraph::recurrent_graph_rewrite_callback& callback, - const PassPropertyMask& property); - - // TODO: This interface may deprecate after all passes are refactored. - void add_matcher(const std::shared_ptr& m, - const ngraph::recurrent_graph_rewrite_callback& callback); - - bool run_on_function(std::shared_ptr f) override; - -private: - size_t m_num_iters; - - std::vector> m_matchers; -}; +using ov::pass::BackwardGraphRewrite; +using ov::pass::GraphRewrite; +using ov::pass::MatcherPass; +using ov::pass::RecurrentGraphRewrite; } // namespace pass } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pass/low_latency.hpp b/ngraph/core/include/ngraph/pass/low_latency.hpp index 49607893b2b..f003bba095a 100644 --- a/ngraph/core/include/ngraph/pass/low_latency.hpp +++ b/ngraph/core/include/ngraph/pass/low_latency.hpp @@ -5,10 +5,12 @@ #pragma once #include -#include -#include #include +#include "ngraph/pass/graph_rewrite.hpp" +#include "ngraph/pass/pass.hpp" +#include "openvino/pass/low_latency.hpp" + namespace ngraph { namespace pass { /** @@ -46,38 +48,6 @@ public: LowLatency(); }; -/** - * @brief The transformation finds all TensorIterator/Loop layers in the network, - * processes all back edges that describe a connection between Result and Parameter - * of the TensorIterator/Loop bodies,and inserts ReadValue and Assign layers at the - * input and output corresponding to this back edge. - * Supported platforms: CPU, GNA. - * - * The example below describes the changes made by the transformation - * [] - TensorIterator body - * () - new layer - * BE - back-edge - * - * before applying the transformation: - * -> input1[BE_1 -> Parameter -> Layers ... -> Result -> BE_1 ]output1-> - * - * after applying the transformation: - * ->(ReadValue)-> input1[BE_1 ->Parameter->Layers ...->Result->BE_1]output1 ->(Assign) - * \ - * ->... - * After applying the transformation, the resulting network can be inferred - * step by step, the states will store between inferences. - */ -class NGRAPH_API LowLatency2 : public ngraph::pass::FunctionPass { -public: - NGRAPH_RTTI_DECLARATION; - - explicit LowLatency2(bool use_const_initializer = true) : m_use_const_initializer(use_const_initializer) {} - - bool run_on_function(std::shared_ptr f) override; - -private: - bool m_use_const_initializer; -}; +using ov::pass::LowLatency2; } // namespace pass } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pass/manager.hpp b/ngraph/core/include/ngraph/pass/manager.hpp index 573877046ef..3ecef3dcaae 100644 --- a/ngraph/core/include/ngraph/pass/manager.hpp +++ b/ngraph/core/include/ngraph/pass/manager.hpp @@ -11,106 +11,10 @@ #include "ngraph/pass/pass.hpp" #include "ngraph/pass/validate.hpp" +#include "openvino/pass/manager.hpp" namespace ngraph { namespace pass { -class NGRAPH_API Manager { -public: - Manager(); - ~Manager(); - - //// \brief Construct Manager with shared PassConfig instance - explicit Manager(std::shared_ptr pass_config); - - /// \brief Register given transformation class type to execution list - /// Example below show the basic usage of pass::Manager - /// - /// pass::Manager manager; - /// manager.register_pass(/*transformation constructor ars*/); - /// manager.run_passes(f); - /// - /// For some purposes transformation can be registered and disabled by default. - /// - /// manager.register_pass(); - /// - /// \return shared_ptr to the transformation instance - template - std::shared_ptr register_pass(Args&&... args) { - auto rc = push_pass(std::forward(args)...); - rc->set_pass_config(m_pass_config); - if (m_per_pass_validation) { - push_pass(); - } - if (!Enable && !m_pass_config->is_enabled()) { - m_pass_config->disable(); - } - return rc; - } - - void run_passes(std::shared_ptr); - - void set_pass_visualization(bool new_state) { - m_visualize = new_state; - } - /// \brief Set flag to enable/disable running Validate pass after executing - /// each registered pass - /// \param new_state Value "true" enables Validate pass run; "false", otherwise - void set_per_pass_validation(bool new_state) { - m_per_pass_validation = new_state; - } - /// \brief Callback is a lambda function that can be used by registered transformations. - /// The main purpose of this callback is to provide a way for plugins to disable/enable - /// transformations based on some conditions. In some cases plugins may want not to - /// execute some - /// transformations. - /// For example plugin can disable unpleasant decompositions because of performance - /// reasons for - /// some cases. - /// Callback example: - /// auto callback = [](const std::shared_ptr & node) -> bool { - /// return std::dynamic_pointer_cast(node) != - /// nullptr; - /// }; - /// This callback returns true in case of DepthToSpace operation. So when execution - /// DepthToSpace - /// decomposition pass will check is this decomposition needed or plugin can execute - /// this - /// operation directly. And of course on transformation side we need to have a response - /// for this - /// callback. - /// if (transformation_callback(batch_to_space)) { - /// return false; - /// } - /// \param callback lamda function that returns true in case if node is supported by - /// plugin and - /// transformation is not needed - NGRAPH_DEPRECATED("Please use get_pass_config() to configure transformation pipeline") - void set_callback(const param_callback& callback) { - m_pass_config->set_callback(callback); - } - /// \return PassConfig shared object. This object is used for transformations pipeline - /// configuration. - /// This object allows to disable/enable transformations execution, set callback to - /// particular - /// transformation. For mo details see PassConfig class. - std::shared_ptr get_pass_config() { - return m_pass_config; - } - -protected: - template - std::shared_ptr push_pass(Args&&... args) { - static_assert(std::is_base_of::value, "pass not derived from pass base"); - auto pass = std::make_shared(std::forward(args)...); - auto pass_base = std::static_pointer_cast(pass); - m_pass_list.push_back(pass_base); - return pass; - } - - std::shared_ptr m_pass_config; - std::vector> m_pass_list; - bool m_visualize = false; - bool m_per_pass_validation = true; -}; +using ov::pass::Manager; } // namespace pass } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pass/pass.hpp b/ngraph/core/include/ngraph/pass/pass.hpp index 349439d36cd..db0932b1227 100644 --- a/ngraph/core/include/ngraph/pass/pass.hpp +++ b/ngraph/core/include/ngraph/pass/pass.hpp @@ -13,105 +13,32 @@ #include "ngraph/node.hpp" #include "ngraph/pass/pass_config.hpp" #include "ngraph/util.hpp" +#include "openvino/pass/pass.hpp" +namespace ov { +namespace pass { + +class Manager; + +} +} // namespace ov namespace ngraph { namespace pass { -enum class PassProperty : uint32_t { - // Pass requires node shapes to be static - REQUIRE_STATIC_SHAPE = 0x1, - // Pass transformation will change the function's dynamic state - CHANGE_DYNAMIC_STATE = 1 << 1, -}; - -typedef EnumMask PassPropertyMask; +using ov::pass::FunctionPass; +using ov::pass::FusionType; +using ov::pass::FusionTypeMask; +using ov::pass::Manager; +using ov::pass::PassBase; +using ov::pass::PassProperty; +using ov::pass::PassPropertyMask; +NGRAPH_DEPRECATED("This variable is deprecated and will be removed soon.") const PassPropertyMask all_pass_property_off; -class NGRAPH_API PassBase { - friend class Manager; - -public: - PassBase(); - virtual ~PassBase() {} - /// Check if this pass has all the pass properties. - bool get_property(const PassPropertyMask& prop_mask) const; - - void set_name(const std::string& name) { - m_name = name; - } - std::string get_name() const; - - /// \brief Set callback for particular transformation type. - /// This method set global callback. For more details see PassConfig class - /// documentation. - /// \param callback lambda function that takes node and returns bool - void set_callback(const param_callback& callback); - - /// \brief Set PassConfig for particular transformation instance - /// \param pass_config is a PassConfig shared_ptr - virtual void set_pass_config(const std::shared_ptr& pass_config) { - m_pass_config = pass_config; - } - - /// \brief Allows to access PassConfig shared instance - /// \return Shared instance of PassConfig class - std::shared_ptr get_pass_config() { - return m_pass_config; - } - /// \brief Applies callback for given node. By default callback returns false. - /// This method remains here only for backward compatibility and will be removed - /// after all transformations are moved to transformation_callback() method. - /// \return result of callback execution for given node - NGRAPH_DEPRECATED("Please use transformation_callback method instead") - bool m_transformation_callback(const std::shared_ptr& node) { - return m_pass_config->get_callback(get_type_info())(node); - } - - /// \brief Applies callback for given node. By default callback returns false. - /// \param node which will be used inside callback - /// \return result of callback execution for given node - bool transformation_callback(const std::shared_ptr& node) { - return m_pass_config->get_callback(get_type_info())(node); - } - - using type_info_t = DiscreteTypeInfo; - - virtual const type_info_t& get_type_info() const = 0; - -protected: - void set_property(const PassPropertyMask& prop, bool value); - -private: - PassPropertyMask m_property; - - std::string m_name; - std::shared_ptr m_pass_config; -}; - -class NGRAPH_API FunctionPass : public PassBase { -public: - NGRAPH_RTTI_DECLARATION; - virtual ~FunctionPass(); - virtual bool run_on_function(std::shared_ptr) = 0; -}; - class NGRAPH_DEPRECATED("Use MatcherPass or FunctionPass instead.") NGRAPH_API NodePass : public PassBase { public: NGRAPH_RTTI_DECLARATION; - virtual ~NodePass(); + ~NodePass() override; virtual bool run_on_node(std::shared_ptr) = 0; }; - -class Manager; -enum class FusionType : uint32_t { - //`DIFFERENTIABLE_FUSIONS` produce ops that support autodiff - // i.e. implement `generate_adjoints` - DIFFERENTIABLE_FUSIONS = 0x1, - REGULAR_FUSIONS = 0x2, - //`FOP_FUSIONS` produce ops in the FusedOps category that might - // not be supported by all backends - FOP_FUSIONS = 0x4, - ALL_FUSIONS = 0xFFFFFFFF -}; -typedef EnumMask FusionTypeMask; } // namespace pass } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pass/pass_config.hpp b/ngraph/core/include/ngraph/pass/pass_config.hpp index 6f9bd8bd755..69d89d94a4d 100644 --- a/ngraph/core/include/ngraph/pass/pass_config.hpp +++ b/ngraph/core/include/ngraph/pass/pass_config.hpp @@ -12,164 +12,12 @@ #include "ngraph/function.hpp" #include "ngraph/node.hpp" #include "ngraph/util.hpp" +#include "openvino/pass/pass_config.hpp" namespace ngraph { namespace pass { -using param_callback = std::function)>; -using param_callback_map = std::map; - -/// \brief Class representing a transformations config that is used for disabling/enabling -/// transformations registered inside pass::Manager and also allows to set callback for all -/// transformations or for particular transformation. -/// -/// When pass::Manager is created all passes registered inside this manager including nested -/// passes will share the same instance of PassConfig class. -/// To work with this class first you need to get shared instance of this class by calling -/// manager.get_pass_config() method. Then you will be able to disable/enable passes based -/// on transformations type_info. For example: -/// -/// pass::Manager manager; -/// manager.register_pass(); -/// auto pass_config = manager.get_pass_config(); -/// pass_config->disable(); // this will disable nested pass inside -/// // CommonOptimizations pipeline -/// manager.run_passes(f); -/// -/// Sometimes it is needed to call transformation inside other transformation manually. And -/// for that case before running transformation you need manually check that this pass is -/// not disabled and then you need to set current PassConfig instance to this -/// transformation. For example: -/// -/// // Inside MatcherPass callback or inside FunctionPass run_on_function() method -/// // you need to call get_pass_config() method to get shared instance of PassConfig -/// auto pass_config = get_pass_config(); -/// -/// // Before running nested transformation you need to check is it disabled or not -/// if (!pass_config->is_disabled()) { -/// auto pass = ConvertGELU(); -/// pass->set_pass_config(pass_config); -/// pass.apply(node); -/// } -/// -/// Following this logic inside your transformations you will guaranty that transformations -/// will be executed in a right way. -class NGRAPH_API PassConfig { -public: - /// \brief Disable transformation by its type_info - /// \param type_info Transformation type_info - void disable(const DiscreteTypeInfo& type_info); - /// \brief Disable transformation by its class type (based on type_info) - template - void disable() { - NGRAPH_SUPPRESS_DEPRECATED_START - disable(T::type_info); - NGRAPH_SUPPRESS_DEPRECATED_END - } - - /// \brief Enable transformation by its type_info - /// \param type_info Transformation type_info - void enable(const DiscreteTypeInfo& type_info); - /// \brief Enable transformation by its class type (based on type_info) - template - void enable() { - NGRAPH_SUPPRESS_DEPRECATED_START - enable(T::type_info); - NGRAPH_SUPPRESS_DEPRECATED_END - } - - /// \brief Set callback for all kind of transformations - void set_callback(const param_callback& callback) { - m_callback = callback; - } - template - typename std::enable_if::type set_callback(const param_callback& callback) {} - - /// \brief Set callback for particular transformation class types - /// - /// Example below show how to set callback for one or multiple passes using this method. - /// - /// pass_config->set_callback( - /// [](const_node_ptr &node) -> bool { - /// // Disable transformations for cases when input shape rank is not - /// equal to 4 - /// const auto input_shape_rank = - /// node->get_output_partial_shape(0).rank().get_length(); - /// if (input_shape_rank != 4) { - /// return false; - /// } - /// return true; - /// }); - /// - /// Note that inside transformations you must provide code that work with this callback. - /// See example below: - /// - /// if (transformation_callback(node)) { - /// return false; // exit from transformation - /// } - /// - template - void set_callback(const param_callback& callback) { - m_callback_map[T::type_info] = callback; - set_callback(callback); - } - - /// \brief Get callback for given transformation type_info - /// \param type_info Transformation type_info - /// - /// In case if callback wasn't set for given transformation type then global callback - /// will be returned. But if even global callback wasn't set then default callback will - /// be returned. - param_callback get_callback(const DiscreteTypeInfo& type_info) const; - - /// \brief Get callback for given transformation class type - /// \return callback lambda function - template - param_callback get_callback() const { - NGRAPH_SUPPRESS_DEPRECATED_START - return get_callback(T::type_info); - NGRAPH_SUPPRESS_DEPRECATED_END - } - - /// \brief Check either transformation type is disabled or not - /// \param type_info Transformation type_info - /// \return true if transformation type was disabled and false otherwise - bool is_disabled(const DiscreteTypeInfo& type_info) const { - return m_disabled.count(type_info); - } - - /// \brief Check either transformation class type is disabled or not - /// \return true if transformation type was disabled and false otherwise - template - bool is_disabled() const { - NGRAPH_SUPPRESS_DEPRECATED_START - return is_disabled(T::type_info); - NGRAPH_SUPPRESS_DEPRECATED_END - } - - /// \brief Check either transformation type is force enabled or not - /// \param type_info Transformation type_info - /// \return true if transformation type was force enabled and false otherwise - bool is_enabled(const DiscreteTypeInfo& type_info) const { - return m_enabled.count(type_info); - } - - /// \brief Check either transformation class type is force enabled or not - /// \return true if transformation type was force enabled and false otherwise - template - bool is_enabled() const { - return is_enabled(T::type_info); - } - - void add_disabled_passes(const PassConfig& rhs); - -private: - param_callback m_callback = [](const std::shared_ptr&) { - return false; - }; - param_callback_map m_callback_map; - std::unordered_set m_disabled; - std::unordered_set m_enabled; -}; +using ov::pass::param_callback; +using ov::pass::param_callback_map; +using ov::pass::PassConfig; } // namespace pass } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pass/validate.hpp b/ngraph/core/include/ngraph/pass/validate.hpp index 6daa6ca09d3..3f8b42d14f2 100644 --- a/ngraph/core/include/ngraph/pass/validate.hpp +++ b/ngraph/core/include/ngraph/pass/validate.hpp @@ -5,27 +5,10 @@ #pragma once #include "ngraph/pass/pass.hpp" +#include "openvino/pass/validate.hpp" namespace ngraph { namespace pass { -/// \brief The Validate pass performs sanity checks on attributes and inputs, and -/// computes output shapes and element types for all computation nodes in a given -/// computation graph. -/// -/// \details The verification and inference is done via invoking each node's specific -/// implementation of \link ngraph::Node::validate_and_infer_types() \endlink function. -/// -/// By default, the \ref ngraph::pass::Manager runs this pass after executing every -/// optimization pass. This is to ensure that any update to the graph by an optimization -/// pass does not break the shape and data type requirement on a computation node. -/// This default validation run can be changed via calling the -/// \link ngraph::pass::Manager::set_per_pass_validation(bool) \endlink function. -class NGRAPH_API Validate : public FunctionPass { -public: - NGRAPH_RTTI_DECLARATION; - - Validate() : FunctionPass() {} - bool run_on_function(std::shared_ptr f) override; -}; +using ov::pass::Validate; } // namespace pass } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pass/visualize_tree.hpp b/ngraph/core/include/ngraph/pass/visualize_tree.hpp index 6c4b0863329..b37c0ec502f 100644 --- a/ngraph/core/include/ngraph/pass/visualize_tree.hpp +++ b/ngraph/core/include/ngraph/pass/visualize_tree.hpp @@ -14,44 +14,10 @@ #include #include "ngraph/pass/pass.hpp" - -class HeightMap; - -using visualize_tree_ops_map_t = - std::unordered_map>; +#include "openvino/pass/visualize_tree.hpp" namespace ngraph { namespace pass { -class NGRAPH_API VisualizeTree : public FunctionPass { -public: - NGRAPH_RTTI_DECLARATION; - - using node_modifiers_t = std::function& attributes)>; - VisualizeTree(const std::string& file_name, node_modifiers_t nm = nullptr, bool dot_only = false); - bool run_on_function(std::shared_ptr) override; - - void set_ops_to_details(const visualize_tree_ops_map_t& ops_map) { - m_ops_to_details = ops_map; - } - -protected: - void add_node_arguments(std::shared_ptr node, - std::unordered_map& height_maps, - size_t& fake_node_ctr); - std::string add_attributes(std::shared_ptr node); - virtual std::string get_attributes(std::shared_ptr node); - virtual std::string get_node_name(std::shared_ptr node); - std::string get_constant_value(std::shared_ptr node, size_t max_elements = 7); - - void render() const; - - std::stringstream m_ss; - std::string m_name; - std::set> m_nodes_with_attributes; - visualize_tree_ops_map_t m_ops_to_details; - node_modifiers_t m_node_modifiers = nullptr; - bool m_dot_only; - static const int max_jump_distance; -}; +using ov::pass::VisualizeTree; } // namespace pass } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pattern/matcher.hpp b/ngraph/core/include/ngraph/pattern/matcher.hpp index f361365bdcb..add3789429d 100644 --- a/ngraph/core/include/ngraph/pattern/matcher.hpp +++ b/ngraph/core/include/ngraph/pattern/matcher.hpp @@ -16,255 +16,21 @@ #include "ngraph/pattern/op/any_output.hpp" #include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/skip.hpp" +#include "openvino/pass/pattern/matcher.hpp" -namespace ngraph { +namespace ov { namespace pass { class GraphRewrite; } +} // namespace ov +namespace ngraph { +namespace pass { +using ov::pass::GraphRewrite; +} namespace pattern { -class Matcher; - -class NGRAPH_API MatcherState { -public: - MatcherState(Matcher*); - bool finish(bool is_successful); - ~MatcherState(); - -protected: - Matcher* m_matcher; - PatternValueMap m_pattern_value_map; - PatternValueMaps m_pattern_value_maps; - size_t m_watermark; - size_t m_capture_size; - bool m_restore{true}; -}; - -/// Matcher looks for node patterns in a computation graph. The patterns are described by an -/// automaton that is described by an extended computation graph. The matcher executes -/// by attempting to match the start node of the pattern to a computation graph value -/// (output of a Node). In addition to determing if a match occurs, a pattern node may add -/// graph nodes to a list of matched nodes, associate nodes with graph values, and start -/// submatches. Submatches add match state changes to the enclosing match if the submatch -/// succeeds; otherwise the state is reverted. -/// -/// The default match behavior of a pattern node with a graph nodes is that the computation -/// graph value is added to the end of the matched value list and the match succeeds if the -/// node/pattern types match and the input values match. In the case of a commutative node, -/// the inputs can match in any order. If the matcher is in strict mode, the graph value -/// element type and shape must also match. -/// -/// Pattern nodes that have different match behavior are in ngraph::pattern::op and have -/// descriptions of their match behavior. -class NGRAPH_API Matcher { -public: - using PatternMap = ngraph::pattern::PatternMap; - - // Avoid implicit string construction from nullptr. - Matcher(const std::shared_ptr pattern_node, std::nullptr_t name) = delete; - - Matcher() {} - Matcher(Output& pattern_node) : m_pattern_node{pattern_node} {} - - Matcher(Output& pattern_node, const std::string& name) : m_pattern_node(pattern_node), m_name{name} {} - - /// \brief Constructs a Matcher object - /// - /// \param pattern_node is a pattern sub graph that will be matched against input graphs - /// \param name is a string which is used for logging and disabling a matcher - /// \param strict_mode forces a matcher to consider shapes and ET of nodes - Matcher(const Output& pattern_node, const std::string& name, bool strict_mode) - : m_pattern_node(pattern_node), - m_name(name), - m_strict_mode(strict_mode) {} - - // Some matches should start on a node rather than an output. These three constructors - // are transition until we work out the right way to do that. - Matcher(std::shared_ptr pattern_node); - Matcher(std::shared_ptr pattern_node, const std::string& name); - Matcher(std::shared_ptr pattern_node, const std::string& name, bool strict_mode); - - virtual ~Matcher() {} - /// \brief Matches a pattern to \p graph_node - /// - /// \param graph_value is an input graph to be matched against - bool match(const Output& graph_value); - - bool match(std::shared_ptr graph_node); - - /// \brief Matches a pattern to \p graph_node - /// - /// \param graph_value is an input graph to be matched against - /// \param previous_matches contains previous mappings from labels to nodes to use - bool match(const Output& graph_value, const PatternMap& previous_matches); - bool match(const Output& graph_value, const PatternValueMap& previous_matches); - - template - static std::shared_ptr unique_match(std::shared_ptr node) { - std::shared_ptr matched; - for (auto arg : node->input_values()) { - if (auto t_casted = ov::as_type_ptr(arg.get_node_shared_ptr())) { - if (matched) { - throw ngraph_error("There's more than two arguments of the same type"); - } else { - matched = t_casted; - } - } - } - return matched; - } - - bool is_contained_match(const NodeVector& exclusions = {}, bool ignore_unused = true); - const NodeVector get_matched_nodes() { - return as_node_vector(m_matched_list); - } - const OutputVector& get_matched_values() const { - return m_matched_list; - } - OutputVector& get_matched_values() { - return m_matched_list; - } - void reset() {} - const std::string& get_name() { - return m_name; - } - std::shared_ptr get_pattern() { - return m_pattern_node.get_node_shared_ptr(); - } - Output get_pattern_value() { - return m_pattern_node; - } - std::shared_ptr get_match_root(); - Output get_match_value(); - PatternMap get_pattern_map() const; - PatternValueMap& get_pattern_value_map() { - return m_pattern_map; - } - PatternValueMaps& get_pattern_value_maps() { - return m_pattern_value_maps; - } - /// \brief Low-level helper to match recurring patterns - /// - /// \param graph is a graph to be matched against - /// \param pattern is a recurring pattern - /// \param rpattern specifies a node to recur from next - /// \param patterns a map from labels to matches - - size_t add_node(Output node); - - bool virtual match_value(const ngraph::Output& pattern_value, const ngraph::Output& graph_value); - - bool is_strict_mode() { - return m_strict_mode; - } - virtual bool match_arguments(Node* pattern_node, const std::shared_ptr& graph_node); - - void capture(const std::set& static_nodes); - - void clear_state(); - - size_t get_number_of_recurrent_matches() const { - return m_pattern_value_maps.size(); - } - NodeVector get_bound_nodes_for_pattern(const Output& pattern) const; - size_t get_number_of_bound_labels() const; - /// \brief Try a match - MatcherState start_match(); - - Output m_match_root; - Output m_pattern_node; - PatternValueMap m_pattern_map; - PatternValueMaps m_pattern_value_maps; - OutputVector m_matched_list; - -protected: - bool match_permutation(const OutputVector& pattern_args, const OutputVector& args); - - std::string m_name{"unnamed"}; - bool m_strict_mode{false}; -}; - -class NGRAPH_API RecurrentMatcher { -public: - /// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match - /// repeating patterns (e.g. RNN, LSTM, GRU cells) - /// - /// \param initial_pattern is a pattern sub graph describing the initial cell - /// \param pattern is a pattern sub graph describing an individual cell - /// \param rpattern is a (recurring) label to denote which node the next match should - /// start at - /// \param correlated_patterns is a set of labels whose bound nodes must remain the same - /// across all cells - RecurrentMatcher(const Output& initial_pattern, - const Output& pattern, - const std::shared_ptr& rpattern, - const std::set>& correlated_patterns) - : m_initial_pattern(initial_pattern), - m_pattern(pattern), - m_recurrent_pattern(rpattern), - m_correlated_patterns(correlated_patterns) {} - - /// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match - /// repeating patterns (e.g. RNN, LSTM, GRU cells) - /// - /// \param pattern is a pattern sub graph describing an individual cell - /// \param rpattern is a (recurring) label to denote which node the next match should - /// start at - /// \param correlated_patterns is a set of labels whose bound nodes must remain the same - /// across all cells - RecurrentMatcher(const Output& pattern, - const std::shared_ptr& rpattern, - const std::set>& correlated_patterns) - : RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns) {} - - RecurrentMatcher(const Output& initial_pattern, - const Output& pattern, - const std::shared_ptr& rpattern, - const std::set>& correlated_patterns); - - RecurrentMatcher(const Output& pattern, - const std::shared_ptr& rpattern, - const std::set>& correlated_patterns) - : RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns) {} - - /// \brief Returns a vector of bound nodes for a given label (used in a pattern - /// describing an individual cell - NodeVector get_bound_nodes_for_pattern(const std::shared_ptr& pattern) const { - if (m_matches.count(pattern) == 0) { - throw ngraph_error("No bound nodes for a given label"); - } - - return as_node_vector(m_matches.at(pattern)); - } - - size_t get_number_of_recurrent_matches() const { - if (m_matches.size() == 0) { - return 0; - } - - return (*m_matches.begin()).second.size(); - } - - size_t get_number_of_bound_labels() const { - return m_matches.size(); - } - /// \brief Tries to match a pattern for an individual cell to a given \p graph - bool match(Output graph); - - std::shared_ptr get_match_root() { - return m_match_root.get_node_shared_ptr(); - } - Output get_match_value() { - return m_match_root; - } - -private: - Output m_initial_pattern; - Output m_pattern; - std::shared_ptr m_recurrent_pattern; - const std::set> m_correlated_patterns; - RPatternValueMap m_matches; - Output m_match_root; -}; +using ov::pass::pattern::Matcher; +using ov::pass::pattern::MatcherState; +using ov::pass::pattern::RecurrentMatcher; } // namespace pattern } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pattern/op/any.hpp b/ngraph/core/include/ngraph/pattern/op/any.hpp index d9d2a85cd8b..bc2ac780159 100644 --- a/ngraph/core/include/ngraph/pattern/op/any.hpp +++ b/ngraph/core/include/ngraph/pattern/op/any.hpp @@ -6,38 +6,12 @@ #include "ngraph/node.hpp" #include "ngraph/pattern/op/pattern.hpp" +#include "openvino/pass/pattern/op/any.hpp" namespace ngraph { namespace pattern { namespace op { -/// The graph value is to the matched value list. If the predicate is true for the node -/// and the arguments match, the match succeeds. -class NGRAPH_API Any : public Pattern { -public: - static constexpr NodeTypeInfo type_info{"patternAny", 0}; - const NodeTypeInfo& get_type_info() const override; - /// \brief creates a Any node containing a sub-pattern described by \sa type and \sa - /// shape. - Any(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values) - : Pattern(wrapped_values, pred) { - set_output_type(0, type, s); - } - Any(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values) - : Any(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {} - /// \brief creates a Any node containing a sub-pattern described by the type and - /// shape of \sa node. - Any(const Output& node, ValuePredicate pred, const OutputVector& wrapped_values) - : Any(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {} - Any(const Output& node, NodePredicate pred, const NodeVector& wrapped_values) - : Any(node.get_element_type(), - node.get_partial_shape(), - as_value_predicate(pred), - as_output_vector(wrapped_values)) {} - - bool match_value(pattern::Matcher* matcher, - const Output& pattern_value, - const Output& graph_value) override; -}; +using ov::pass::pattern::op::Any; } // namespace op } // namespace pattern } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pattern/op/any_of.hpp b/ngraph/core/include/ngraph/pattern/op/any_of.hpp index 4970626e136..166bd763271 100644 --- a/ngraph/core/include/ngraph/pattern/op/any_of.hpp +++ b/ngraph/core/include/ngraph/pattern/op/any_of.hpp @@ -6,47 +6,12 @@ #include "ngraph/node.hpp" #include "ngraph/pattern/op/pattern.hpp" +#include "openvino/pass/pattern/op/any_of.hpp" namespace ngraph { namespace pattern { namespace op { -/// The graph value is added to the matched values list. If the predicate is true for -/// the -/// graph node, a submatch is performed on the input of AnyOf and each input of the -/// graph node. The first match that succeeds results in a successful match. Otherwise -/// the match fails. -/// -/// AnyOf may be given a type and shape for use in strict mode. -class NGRAPH_API AnyOf : public Pattern { -public: - static constexpr NodeTypeInfo type_info{"patternAnyOf", 0}; - const NodeTypeInfo& get_type_info() const override; - /// \brief creates a AnyOf node containing a sub-pattern described by \sa type and - /// \sa shape. - AnyOf(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values) - : Pattern(wrapped_values, pred) { - if (wrapped_values.size() != 1) { - throw ngraph_error("AnyOf expects exactly one argument"); - } - set_output_type(0, type, s); - } - AnyOf(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values) - : AnyOf( - type, - s, - [pred](const Output& value) { - return pred(value.get_node_shared_ptr()); - }, - as_output_vector(wrapped_values)) {} - - /// \brief creates a AnyOf node containing a sub-pattern described by the type and - /// shape of \sa node. - AnyOf(const Output& node, ValuePredicate pred, const OutputVector& wrapped_values) - : AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {} - AnyOf(std::shared_ptr node, NodePredicate pred, const NodeVector& wrapped_values) - : AnyOf(node, as_value_predicate(pred), as_output_vector(wrapped_values)) {} - bool match_value(Matcher* matcher, const Output& pattern_value, const Output& graph_value) override; -}; +using ov::pass::pattern::op::AnyOf; } // namespace op } // namespace pattern } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pattern/op/any_output.hpp b/ngraph/core/include/ngraph/pattern/op/any_output.hpp index 46cf734f57b..58fd0c7a044 100644 --- a/ngraph/core/include/ngraph/pattern/op/any_output.hpp +++ b/ngraph/core/include/ngraph/pattern/op/any_output.hpp @@ -6,23 +6,12 @@ #include "ngraph/node.hpp" #include "ngraph/pattern/op/pattern.hpp" +#include "openvino/pass/pattern/op/any_output.hpp" namespace ngraph { namespace pattern { namespace op { -/// Matches any output of a node -class NGRAPH_API AnyOutput : public Pattern { -public: - static constexpr NodeTypeInfo type_info{"patternAnyOutput", 0}; - const NodeTypeInfo& get_type_info() const override; - /// \brief creates an AnyOutput node matching any output of a node - /// \param node The node to match - AnyOutput(const std::shared_ptr& pattern) : Pattern({pattern->output(0)}) {} - - bool match_value(pattern::Matcher* matcher, - const Output& pattern_value, - const Output& graph_value) override; -}; +using ov::pass::pattern::op::AnyOutput; } // namespace op } // namespace pattern } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pattern/op/branch.hpp b/ngraph/core/include/ngraph/pattern/op/branch.hpp index 13a4e93db57..c6bbf9a45e5 100644 --- a/ngraph/core/include/ngraph/pattern/op/branch.hpp +++ b/ngraph/core/include/ngraph/pattern/op/branch.hpp @@ -6,48 +6,12 @@ #include "ngraph/node.hpp" #include "ngraph/pattern/op/pattern.hpp" +#include "openvino/pass/pattern/op/branch.hpp" namespace ngraph { namespace pattern { namespace op { -/// A branch adds a loop to the pattern. The branch match is successful if the -/// destination node pattern matches the graph value. The destination node is a node in -/// the pattern graph that will not have been created some time after the Branch node is -/// created; use set_destination to add it. -/// -/// The branch destination is not stored as a shared pointer to prevent reference -/// cycles. Thus the destination node must be referenced in some other way to prevent it -/// from being deleted. -class NGRAPH_API Branch : public Pattern { -public: - static constexpr NodeTypeInfo type_info{"patternBranch", 0}; - const NodeTypeInfo& get_type_info() const override; - /// \brief Creates a Branch pattern - /// \param pattern the destinationing pattern - /// \param labels Labels where the destination may occur - Branch() : Pattern(OutputVector{}) { - set_output_type(0, element::f32, Shape{}); - } - - void set_destination(const Output& destination) { - m_destination_node = destination.get_node(); - m_destination_index = destination.get_index(); - } - - Output get_destination() const { - return m_destination_node == nullptr - ? Output() - : Output{m_destination_node->shared_from_this(), m_destination_index}; - } - - bool match_value(pattern::Matcher* matcher, - const Output& pattern_value, - const Output& graph_value) override; - -protected: - Node* m_destination_node{nullptr}; - size_t m_destination_index{0}; -}; +using ov::pass::pattern::op::Branch; } // namespace op } // namespace pattern } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pattern/op/capture.hpp b/ngraph/core/include/ngraph/pattern/op/capture.hpp index d5f816588fa..586e0f697c8 100644 --- a/ngraph/core/include/ngraph/pattern/op/capture.hpp +++ b/ngraph/core/include/ngraph/pattern/op/capture.hpp @@ -6,37 +6,12 @@ #include "ngraph/node.hpp" #include "ngraph/pattern/op/pattern.hpp" +#include "openvino/pass/pattern/op/capture.hpp" namespace ngraph { namespace pattern { namespace op { -/// Experimental for support of recurrent matches. -/// -/// Capture adds the pattern value map to a list of pattern value maps and resets -/// matches for pattern nodes not in the static node list. The match always succeeds. -class NGRAPH_API Capture : public Pattern { -public: - static constexpr NodeTypeInfo type_info{"patternCapture", 0}; - const NodeTypeInfo& get_type_info() const override; - Capture(const Output& arg) : Pattern({arg}) { - set_output_type(0, arg.get_element_type(), arg.get_partial_shape()); - } - - /// \brief static nodes are retained after a capture. All other nodes are dropped - std::set get_static_nodes() { - return m_static_nodes; - } - void set_static_nodes(const std::set& static_nodes) { - m_static_nodes = static_nodes; - } - - virtual bool match_value(pattern::Matcher* matcher, - const Output& pattern_value, - const Output& graph_value) override; - -protected: - std::set m_static_nodes; -}; +using ov::pass::pattern::op::Capture; } // namespace op } // namespace pattern } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pattern/op/label.hpp b/ngraph/core/include/ngraph/pattern/op/label.hpp index 50da1e63b2b..d1893f375ef 100644 --- a/ngraph/core/include/ngraph/pattern/op/label.hpp +++ b/ngraph/core/include/ngraph/pattern/op/label.hpp @@ -6,106 +6,14 @@ #include "ngraph/node.hpp" #include "ngraph/pattern/op/pattern.hpp" +#include "openvino/pass/pattern/op/label.hpp" namespace ngraph { namespace pattern { namespace op { -/// Fails if the predicate returns false on the graph value. -/// -/// The graph value is added to the matched values list. If the Label is already -/// associated with a value, the match succeeds if the value is the same as the graph -/// value. Otherwise, the label is associated with the graph value and the match -/// succeeds if the pattern input matches the graph value. -/// -/// DEPRECATED: If no inputs are given to Label, a True node is serves as the input. If -/// more than one inputs are given, an Or pattern of the inputs serves as the input. -class NGRAPH_API Label : public Pattern { -public: - static constexpr NodeTypeInfo type_info{"patternLabel", 0}; - const NodeTypeInfo& get_type_info() const override; - /// \brief creates a Label node containing a sub-pattern described by \sa type and - /// \sa shape. - /// - /// this Label node can be bound only to the nodes in the input graph - /// that match the pattern specified by \sa wrapped_nodes - /// Example: - /// \code{.cpp} - /// auto add = a + b; // a and b are op::Parameter in this example - /// auto label = std::make_shared(element::f32, - /// Shape{2,2}, - /// nullptr, - /// OutputVector{add}); - /// \endcode - Label(const element::Type& type, - const PartialShape& s, - const ValuePredicate pred, - const OutputVector& wrapped_values) - : Pattern(OutputVector{wrap_values(wrapped_values)}, pred) { - set_output_type(0, type, s); - } - - explicit Label(const element::Type& type = element::dynamic, const PartialShape& s = PartialShape::dynamic()) - : Label( - type, - s, - [](const Output&) { - return true; - }, - OutputVector()) {} - - Label(const element::Type& type, const PartialShape& s, ValuePredicate pred) - : Label(type, s, pred, OutputVector{}) {} - - Label(const element::Type& type, const PartialShape& s, NodePredicate pred) - : Label(type, s, as_value_predicate(pred), OutputVector{}) {} - - Label(const element::Type& type, const PartialShape& s, const NodePredicate pred, const NodeVector& wrapped_values) - : Label(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {} - - /// \brief creates a Label node containing a sub-pattern described by the type and - /// shape of \sa node. - /// - /// this Label node can be bound only to the nodes in the input graph - /// that match the pattern specified by \sa wrapped_values - /// Example: - /// \code{.cpp} - /// auto add = a + b; // a and b are op::Parameter in this example - /// auto label = std::make_shared(add, - /// nullptr, - /// OutputVector{add}); - /// \endcode - Label(const Output& value, const ValuePredicate pred, const OutputVector& wrapped_values) - : Label(value.get_element_type(), value.get_partial_shape(), pred, wrapped_values) {} - Label(const Output& value, const ValuePredicate pred) - : Label(value.get_element_type(), value.get_partial_shape(), pred, OutputVector{}) {} - - Label(const Output& value, const NodePredicate pred) - : Label(value.get_element_type(), value.get_partial_shape(), as_value_predicate(pred), OutputVector{}) {} - Label(const Output& value) - : Label( - value.get_element_type(), - value.get_partial_shape(), - [](const Output&) { - return true; - }, - OutputVector{}) {} - Label(const Output& node, const NodePredicate pred, const NodeVector& wrapped_values) - : Label(node.get_element_type(), - node.get_partial_shape(), - as_value_predicate(pred), - as_output_vector(wrapped_values)) {} - - bool match_value(Matcher* matcher, const Output& pattern_value, const Output& graph_value) override; - -protected: - static Output wrap_values(const OutputVector& wrapped_values); -}; +using ov::pass::pattern::op::Label; } // namespace op -NGRAPH_API -std::shared_ptr any_input(); - -NGRAPH_API -std::shared_ptr any_input(const pattern::op::ValuePredicate& pred); +using ov::pass::pattern::any_input; } // namespace pattern } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pattern/op/or.hpp b/ngraph/core/include/ngraph/pattern/op/or.hpp index 53368949712..19f02c6c5c9 100644 --- a/ngraph/core/include/ngraph/pattern/op/or.hpp +++ b/ngraph/core/include/ngraph/pattern/op/or.hpp @@ -6,25 +6,12 @@ #include "ngraph/node.hpp" #include "ngraph/pattern/op/pattern.hpp" +#include "openvino/pass/pattern/op/or.hpp" namespace ngraph { namespace pattern { namespace op { -/// A submatch on the graph value is performed on each input to the Or; the match -/// succeeds on the first match. Otherwise the match fails. -class NGRAPH_API Or : public Pattern { -public: - static constexpr NodeTypeInfo type_info{"patternOr", 0}; - const NodeTypeInfo& get_type_info() const override; - /// \brief creates an Or node matching one of several sub-patterns in order. Does - /// not add node to match list. - /// \param patterns The patterns to try for matching - Or(const OutputVector& patterns) : Pattern(patterns) {} - - bool match_value(pattern::Matcher* matcher, - const Output& pattern_value, - const Output& graph_value) override; -}; +using ov::pass::pattern::op::Or; } // namespace op } // namespace pattern } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pattern/op/pattern.hpp b/ngraph/core/include/ngraph/pattern/op/pattern.hpp index 931c3d9e4e0..99e4edfc570 100644 --- a/ngraph/core/include/ngraph/pattern/op/pattern.hpp +++ b/ngraph/core/include/ngraph/pattern/op/pattern.hpp @@ -7,8 +7,10 @@ #include #include "ngraph/node.hpp" +#include "openvino/pass/pattern/op/pattern.hpp" -namespace ngraph { +namespace ov { +namespace pass { namespace pattern { namespace op { class Label; @@ -16,79 +18,42 @@ class Label; class Matcher; class MatchState; - -using RPatternValueMap = std::map, OutputVector>; -using PatternValueMap = std::map, Output>; -using PatternValueMaps = std::vector; - -using PatternMap = std::map, std::shared_ptr>; - -PatternMap as_pattern_map(const PatternValueMap& pattern_value_map); -PatternValueMap as_pattern_value_map(const PatternMap& pattern_map); - -template -std::function)> has_class() { - auto pred = [](std::shared_ptr node) -> bool { - return ov::is_type(node); - }; - - return pred; +} // namespace pattern +} // namespace pass +} // namespace ov +namespace ngraph { +namespace pattern { +namespace op { +using ov::pass::pattern::op::Label; } -NGRAPH_API -std::function)> consumers_count(size_t n); +using ov::pass::pattern::Matcher; +using ov::pass::pattern::MatcherState; -NGRAPH_API -std::function)> has_static_dim(size_t pos); +using ov::pass::pattern::PatternValueMap; +using ov::pass::pattern::PatternValueMaps; +using ov::pass::pattern::RPatternValueMap; -NGRAPH_API -std::function)> has_static_dims(const std::vector& dims); +using ov::pass::pattern::PatternMap; -NGRAPH_API -std::function)> has_static_shape(); - -NGRAPH_API -std::function)> has_static_rank(); - -NGRAPH_API -std::function)> rank_equals(const Dimension& expected_rank); - -NGRAPH_API -std::function)> type_matches(const element::Type& type); - -NGRAPH_API -std::function)> type_matches_any(const std::vector& types); +using ov::pass::pattern::as_pattern_map; +using ov::pass::pattern::as_pattern_value_map; +using ov::pass::pattern::consumers_count; +using ov::pass::pattern::has_class; +using ov::pass::pattern::has_static_dim; +using ov::pass::pattern::has_static_dims; +using ov::pass::pattern::has_static_rank; +using ov::pass::pattern::has_static_shape; +using ov::pass::pattern::rank_equals; +using ov::pass::pattern::type_matches; +using ov::pass::pattern::type_matches_any; namespace op { -using NodePredicate = std::function)>; -using ValuePredicate = std::function& value)>; +using ov::pass::pattern::op::NodePredicate; +using ov::pass::pattern::op::ValuePredicate; -NGRAPH_API -ValuePredicate as_value_predicate(NodePredicate pred); - -class NGRAPH_API Pattern : public Node { -public: - /// \brief \p a base class for \sa Skip and \sa Label - /// - Pattern(const OutputVector& patterns, ValuePredicate pred) : Node(patterns), m_predicate(pred) { - if (!m_predicate) { - m_predicate = [](const Output&) { - return true; - }; - } - } - - Pattern(const OutputVector& patterns) : Pattern(patterns, nullptr) {} - - virtual std::shared_ptr clone_with_new_inputs(const OutputVector& /* new_args */) const override { - throw ngraph_error("Uncopyable"); - } - - ValuePredicate get_predicate() const; - -protected: - ValuePredicate m_predicate; -}; +using ov::pass::pattern::op::as_value_predicate; +using ov::pass::pattern::op::Pattern; } // namespace op } // namespace pattern } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pattern/op/skip.hpp b/ngraph/core/include/ngraph/pattern/op/skip.hpp index f16bb667069..edf08ef69da 100644 --- a/ngraph/core/include/ngraph/pattern/op/skip.hpp +++ b/ngraph/core/include/ngraph/pattern/op/skip.hpp @@ -6,37 +6,12 @@ #include "ngraph/node.hpp" #include "ngraph/pattern/op/pattern.hpp" +#include "openvino/pass/pattern/op/skip.hpp" namespace ngraph { namespace pattern { namespace op { -/// The graph value is added to the matched value list. If the predicate is true, the -/// match succeeds if the arguments match; if the predicate is false, the match succeeds -/// if the pattern input matches the graph value. -class NGRAPH_API Skip : public Pattern { -public: - static constexpr NodeTypeInfo type_info{"patternSkip", 0}; - const NodeTypeInfo& get_type_info() const override; - Skip(const Output& arg, ValuePredicate pred) : Pattern({arg}, pred) { - set_output_type(0, arg.get_element_type(), arg.get_partial_shape()); - } - - Skip(const Output& arg, NodePredicate pred = nullptr) : Pattern({arg}, as_value_predicate(pred)) { - set_output_type(0, arg.get_element_type(), arg.get_partial_shape()); - } - - Skip(const OutputVector& args, ValuePredicate pred) : Pattern(args, pred) { - set_output_type(0, args.at(0).get_element_type(), args.at(0).get_partial_shape()); - } - - Skip(const OutputVector& args, NodePredicate pred = nullptr) : Pattern(args, as_value_predicate(pred)) { - set_output_type(0, args.at(0).get_element_type(), args.at(0).get_partial_shape()); - } - - virtual bool match_value(pattern::Matcher* matcher, - const Output& pattern_value, - const Output& graph_value) override; -}; +using ov::pass::pattern::op::Skip; } // namespace op } // namespace pattern } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pattern/op/true.hpp b/ngraph/core/include/ngraph/pattern/op/true.hpp index ba08d05acff..1ccfa9e57fa 100644 --- a/ngraph/core/include/ngraph/pattern/op/true.hpp +++ b/ngraph/core/include/ngraph/pattern/op/true.hpp @@ -6,21 +6,12 @@ #include "ngraph/node.hpp" #include "ngraph/pattern/op/pattern.hpp" +#include "openvino/pass/pattern/op/true.hpp" namespace ngraph { namespace pattern { namespace op { -/// \brief The match always succeeds. -class NGRAPH_API True : public Pattern { -public: - static constexpr NodeTypeInfo type_info{"patternTrue", 0}; - const NodeTypeInfo& get_type_info() const override; - /// \brief Always matches, does not add node to match list. - True() : Pattern(OutputVector{}) {} - bool match_value(pattern::Matcher* matcher, - const Output& pattern_value, - const Output& graph_value) override; -}; +using ov::pass::pattern::op::True; } // namespace op } // namespace pattern } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pattern/op/wrap_type.hpp b/ngraph/core/include/ngraph/pattern/op/wrap_type.hpp index 5d08553138e..875e890bd85 100644 --- a/ngraph/core/include/ngraph/pattern/op/wrap_type.hpp +++ b/ngraph/core/include/ngraph/pattern/op/wrap_type.hpp @@ -6,68 +6,14 @@ #include "ngraph/node.hpp" #include "ngraph/pattern/op/pattern.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" namespace ngraph { namespace pattern { namespace op { -class NGRAPH_API WrapType : public Pattern { -public: - static constexpr NodeTypeInfo type_info{"patternAnyType", 0}; - const NodeTypeInfo& get_type_info() const override; - - explicit WrapType( - NodeTypeInfo wrapped_type, - const ValuePredicate& pred = - [](const Output& output) { - return true; - }, - const OutputVector& input_values = {}) - : Pattern(input_values, pred), - m_wrapped_types({wrapped_type}) { - set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic()); - } - - explicit WrapType( - std::vector wrapped_types, - const ValuePredicate& pred = - [](const Output& output) { - return true; - }, - const OutputVector& input_values = {}) - : Pattern(input_values, pred), - m_wrapped_types(std::move(wrapped_types)) { - set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic()); - } - - bool match_value(pattern::Matcher* matcher, - const Output& pattern_value, - const Output& graph_value) override; - - NodeTypeInfo get_wrapped_type() const; - - const std::vector& get_wrapped_types() const; - -private: - std::vector m_wrapped_types; -}; +using ov::pass::pattern::op::WrapType; } // namespace op -template -std::shared_ptr wrap_type(const OutputVector& inputs, const pattern::op::ValuePredicate& pred) { - std::vector info{Args::type_info...}; - return std::make_shared(info, pred, inputs); -} - -template -std::shared_ptr wrap_type(const OutputVector& inputs = {}) { - return wrap_type(inputs, [](const Output& output) { - return true; - }); -} - -template -std::shared_ptr wrap_type(const pattern::op::ValuePredicate& pred) { - return wrap_type({}, pred); -} +using ov::pass::pattern::wrap_type; } // namespace pattern } // namespace ngraph diff --git a/ngraph/core/include/openvino/core/node.hpp b/ngraph/core/include/openvino/core/node.hpp index 3ba65a5cd78..6ea71d68853 100644 --- a/ngraph/core/include/openvino/core/node.hpp +++ b/ngraph/core/include/openvino/core/node.hpp @@ -50,12 +50,14 @@ class Result; } // namespace v0 } // namespace op -namespace pattern { -class Matcher; -} // namespace pattern } // namespace ngraph namespace ov { +namespace pass { +namespace pattern { +class Matcher; +} // namespace pattern +} // namespace pass using HostTensor = ngraph::runtime::HostTensor; using HostTensorPtr = std::shared_ptr; using HostTensorVector = std::vector; @@ -487,11 +489,11 @@ public: } OPENVINO_SUPPRESS_DEPRECATED_END - virtual bool match_value(ngraph::pattern::Matcher* matcher, + virtual bool match_value(ov::pass::pattern::Matcher* matcher, const Output& pattern_value, const Output& graph_value); - virtual bool match_node(ngraph::pattern::Matcher* matcher, const Output& graph_value); + virtual bool match_node(ov::pass::pattern::Matcher* matcher, const Output& graph_value); private: descriptor::Input& get_input_descriptor(size_t position); diff --git a/ngraph/core/include/openvino/pass/constant_folding.hpp b/ngraph/core/include/openvino/pass/constant_folding.hpp new file mode 100644 index 00000000000..de2bb3c0f83 --- /dev/null +++ b/ngraph/core/include/openvino/pass/constant_folding.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/pass.hpp" + +namespace ov { +namespace pass { +/** + * @brief Constant folding iterates over the function and tries to evaluate nodes + * with constant inputs. Such nodes are then replaced with new Constants containing + * the result of a folded operation. + */ +class OPENVINO_API ConstantFolding : public FunctionPass { +public: + OPENVINO_RTTI_DECLARATION; + bool run_on_function(std::shared_ptr f) override; + +private: + void copy_runtime_info_to_target_inputs(const std::shared_ptr& node, const Output& replacement); + /// \brief Folds pre-calculated output tensor values to constants in case lower and + /// upper estimations are equal. Traverses graph backwards starting from the results. + bool pre_calculated_values_folding(const std::shared_ptr& f); +}; +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/convert_fp32_to_fp16.hpp b/ngraph/core/include/openvino/pass/convert_fp32_to_fp16.hpp new file mode 100644 index 00000000000..3d20e16bba3 --- /dev/null +++ b/ngraph/core/include/openvino/pass/convert_fp32_to_fp16.hpp @@ -0,0 +1,17 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" + +namespace ov { +namespace pass { +class OPENVINO_API ConvertFP32ToFP16 : public FunctionPass { +public: + OPENVINO_RTTI_DECLARATION; + bool run_on_function(std::shared_ptr) override; +}; +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/graph_rewrite.hpp b/ngraph/core/include/openvino/pass/graph_rewrite.hpp new file mode 100644 index 00000000000..154b10f9039 --- /dev/null +++ b/ngraph/core/include/openvino/pass/graph_rewrite.hpp @@ -0,0 +1,249 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +#include "openvino/pass/pass.hpp" +#include "openvino/pass/pattern/matcher.hpp" + +namespace ov { +using matcher_pass_callback = std::function; +using graph_rewrite_callback = std::function; +using recurrent_graph_rewrite_callback = std::function; +using handler_callback = std::function& node)>; +namespace pass { +/// \brief MatcherPass is a basic block for pattern based transformations. It describes +/// pattern and +/// action that is applied if pattern is matched. +/// +/// MatcherPass consists of Matcher and matcher_pass_callback that needs to be implemented +/// and +/// finally registered by using \sa register_matcher. MatcherPass can be executed on node +/// within +/// \sa apply method. To run matcher pass on Function use GraphRewrite. +/// In addition MatcherPass provides a way for adding new operations into GraphRewrite +/// execution +/// queue. That means that operations that were created inside transformation callback can +/// be added +/// for matching. To register node use \sa register_new_node method. GraphRewrite +/// automatically +/// takes registered nodes and put them to execution queue. If multiple nodes were register +/// make +/// sure that they were registered in topological order. +/// Note: when implementing pattern for Matcher make sure that root node is an operation +/// from opset +/// or has ov::pass::pattern::op::WrapType. That will help GraphRewrite to execute matcher +/// passes more +/// efficient. + +class OPENVINO_API MatcherPass : public PassBase { +public: + OPENVINO_RTTI_DECLARATION; + + MatcherPass() = default; + + MatcherPass(const MatcherPass&) = delete; + MatcherPass& operator=(const MatcherPass&) = delete; + + explicit MatcherPass(const std::string& name, + const std::shared_ptr& m, + const handler_callback& handler, + const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE) + : PassBase(), + m_handler(handler), + m_matcher(m) { + set_name(name); + set_property(property, true); + } + + bool apply(std::shared_ptr node); + + template + std::shared_ptr register_new_node(Args&&... args) { + auto node = std::make_shared(std::forward(args)...); + m_new_nodes.push_back(node); + return node; + } + + template + std::shared_ptr register_new_node(const std::shared_ptr& node) { + m_new_nodes.push_back(node); + return node; + } + + const std::vector>& get_new_nodes() { + return m_new_nodes; + } + void clear_new_nodes() { + m_new_nodes.clear(); + } + std::shared_ptr get_matcher() { + return m_matcher; + } + +protected: + void register_matcher(const std::shared_ptr& m, + const graph_rewrite_callback& callback, + const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE); + +private: + handler_callback m_handler; + std::shared_ptr m_matcher; + std::vector> m_new_nodes; +}; + +/// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function +/// in +/// efficient way +/// +/// Graph rewrite pass is used for matcher passes execution on Function. +/// To register MatcherPass use \sa add_matcher(args) method where T is a MatcherPass +/// class. +/// As a default algorithm graph rewrite pass traverse Function in topological order and +/// applies +/// registered matcher passes for each node. But if all registered matcher passes have type +/// based +/// root node in Matcher pattern then efficient mechanism is used to execute them. +/// Matcher pattern root is type based if it's operation from opset or +/// pattern::op::WrapType. +/// Note: when implementing pattern for Matcher make sure that root node is an operation +/// from opset +/// or has ov::pattern::op::WrapType. That will help GraphRewrite to execute matcher +/// passes more +/// efficient. + +class OPENVINO_API GraphRewrite : public FunctionPass { +public: + OPENVINO_RTTI_DECLARATION; + + GraphRewrite() = default; + + explicit GraphRewrite(const std::shared_ptr& pass) : FunctionPass() { + m_matchers.push_back(pass); + } + + /// \brief Register given transformation class type to GraphRewrite execution list + /// All registered transformations will be executed in a single graph traversal. + /// Example below show the basic usage of pass::GraphRewrite + /// + /// pass::Manager manager; + /// auto anchor = manager.register_pass(); + /// anchor->add_matcher(); + /// anchor->add_matcher(); + /// anchor->set_name("CommonMatchers"); + /// manager.run_passes(f); + /// + /// For some purposes transformation can be registered and disabled by default. + /// + /// anchor->add_matcher(); + /// + /// \return shared_ptr to the transformation instance + template ::value, bool>::type = true> + std::shared_ptr add_matcher(Args&&... args) { + static_assert(std::is_base_of::value, "pass not derived from MatcherPass"); + auto pass = std::make_shared(std::forward(args)...); + auto pass_config = get_pass_config(); + pass->set_pass_config(pass_config); + if (!Enabled && !pass_config->is_enabled()) { + pass_config->disable(); + } + m_matchers.push_back(pass); + return pass; + } + + /// \brief Register passes from GraphRewrite class that contains sequence of matcher + /// passes registered in its ctor. + /// For example: + /// + /// class ov::pass::LinFusions: public ov::pass::GraphRewrite { + /// public: + /// OPENVINO_RTTI_DECLARATION; + /// Fusions() { + /// add_matcher(); + /// add_matcher(); + /// } + /// }; + /// + /// pass::Manager manager; + /// auto anchor = manager.register_pass(); + /// anchor->add_matcher(); + /// anchor->add_matcher(); + /// anchor->set_name("CommonFusions"); + /// manager.run_passes(f); + /// + /// In this case all matcher passes from LinFusions pass will be united with other + /// registered matchers. + template ::value, bool>::type = true> + void add_matcher(Args&&... args) { + static_assert(std::is_base_of::value, "pass not derived from GraphRewrite"); + auto pass = std::make_shared(std::forward(args)...); + auto pass_config = get_pass_config(); + + for (auto& matcher : pass->m_matchers) { + pass->set_pass_config(pass_config); + m_matchers.push_back(matcher); + } + } + + OPENVINO_DEPRECATED("Use MatcherPass instead") + void add_matcher(const std::shared_ptr& m, + const graph_rewrite_callback& callback, + const PassPropertyMask& property); + + OPENVINO_DEPRECATED("Use MatcherPass instead") + void add_matcher(const std::shared_ptr& m, const ov::graph_rewrite_callback& callback); + + bool run_on_function(std::shared_ptr f) override; + + void set_pass_config(const std::shared_ptr& pass_config) override; + +protected: + bool apply_matcher_passes(std::shared_ptr f, std::deque> nodes_to_run); + + bool m_enable_shape_inference = false; + + std::vector> m_matchers; +}; + +class OPENVINO_API BackwardGraphRewrite : public GraphRewrite { +public: + OPENVINO_RTTI_DECLARATION; + + BackwardGraphRewrite() = default; + + explicit BackwardGraphRewrite(const std::shared_ptr& pass) : GraphRewrite(pass) {} + + bool run_on_function(std::shared_ptr f) override; +}; + +class OPENVINO_API RecurrentGraphRewrite : public FunctionPass { +public: + RecurrentGraphRewrite(size_t num_iters = 10) : FunctionPass(), m_num_iters(num_iters) {} + + void add_matcher(const std::shared_ptr& m, + const ov::recurrent_graph_rewrite_callback& callback, + const PassPropertyMask& property); + + // TODO: This interface may deprecate after all passes are refactored. + void add_matcher(const std::shared_ptr& m, + const ov::recurrent_graph_rewrite_callback& callback); + + bool run_on_function(std::shared_ptr f) override; + +private: + size_t m_num_iters; + + std::vector> m_matchers; +}; +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/low_latency.hpp b/ngraph/core/include/openvino/pass/low_latency.hpp new file mode 100644 index 00000000000..214b5545370 --- /dev/null +++ b/ngraph/core/include/openvino/pass/low_latency.hpp @@ -0,0 +1,48 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "openvino/pass/pass.hpp" + +namespace ov { +namespace pass { +/** + * @brief The transformation finds all TensorIterator/Loop layers in the network, + * processes all back edges that describe a connection between Result and Parameter + * of the TensorIterator/Loop bodies,and inserts ReadValue and Assign layers at the + * input and output corresponding to this back edge. + * Supported platforms: CPU, GNA. + * + * The example below describes the changes made by the transformation + * [] - TensorIterator body + * () - new layer + * BE - back-edge + * + * before applying the transformation: + * -> input1[BE_1 -> Parameter -> Layers ... -> Result -> BE_1 ]output1-> + * + * after applying the transformation: + * ->(ReadValue)-> input1[BE_1 ->Parameter->Layers ...->Result->BE_1]output1 ->(Assign) + * \ + * ->... + * After applying the transformation, the resulting network can be inferred + * step by step, the states will store between inferences. + */ +class OPENVINO_API LowLatency2 : public FunctionPass { +public: + OPENVINO_RTTI_DECLARATION; + + explicit LowLatency2(bool use_const_initializer = true) : m_use_const_initializer(use_const_initializer) {} + + bool run_on_function(std::shared_ptr f) override; + +private: + bool m_use_const_initializer; +}; +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/manager.hpp b/ngraph/core/include/openvino/pass/manager.hpp new file mode 100644 index 00000000000..6f5926b41b8 --- /dev/null +++ b/ngraph/core/include/openvino/pass/manager.hpp @@ -0,0 +1,116 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "openvino/pass/pass.hpp" +#include "openvino/pass/validate.hpp" + +namespace ov { +namespace pass { +class OPENVINO_API Manager { +public: + Manager(); + ~Manager(); + + //// \brief Construct Manager with shared PassConfig instance + explicit Manager(std::shared_ptr pass_config); + + /// \brief Register given transformation class type to execution list + /// Example below show the basic usage of pass::Manager + /// + /// pass::Manager manager; + /// manager.register_pass(/*transformation constructor ars*/); + /// manager.run_passes(f); + /// + /// For some purposes transformation can be registered and disabled by default. + /// + /// manager.register_pass(); + /// + /// \return shared_ptr to the transformation instance + template + std::shared_ptr register_pass(Args&&... args) { + auto rc = push_pass(std::forward(args)...); + rc->set_pass_config(m_pass_config); + if (m_per_pass_validation) { + push_pass(); + } + if (!Enable && !m_pass_config->is_enabled()) { + m_pass_config->disable(); + } + return rc; + } + + void run_passes(std::shared_ptr); + + void set_pass_visualization(bool new_state) { + m_visualize = new_state; + } + /// \brief Set flag to enable/disable running Validate pass after executing + /// each registered pass + /// \param new_state Value "true" enables Validate pass run; "false", otherwise + void set_per_pass_validation(bool new_state) { + m_per_pass_validation = new_state; + } + /// \brief Callback is a lambda function that can be used by registered transformations. + /// The main purpose of this callback is to provide a way for plugins to disable/enable + /// transformations based on some conditions. In some cases plugins may want not to + /// execute some + /// transformations. + /// For example plugin can disable unpleasant decompositions because of performance + /// reasons for + /// some cases. + /// Callback example: + /// auto callback = [](const std::shared_ptr & node) -> bool { + /// return std::dynamic_pointer_cast(node) != + /// nullptr; + /// }; + /// This callback returns true in case of DepthToSpace operation. So when execution + /// DepthToSpace + /// decomposition pass will check is this decomposition needed or plugin can execute + /// this + /// operation directly. And of course on transformation side we need to have a response + /// for this + /// callback. + /// if (transformation_callback(batch_to_space)) { + /// return false; + /// } + /// \param callback lamda function that returns true in case if node is supported by + /// plugin and + /// transformation is not needed + OPENVINO_DEPRECATED("Please use get_pass_config() to configure transformation pipeline") + void set_callback(const param_callback& callback) { + m_pass_config->set_callback(callback); + } + /// \return PassConfig shared object. This object is used for transformations pipeline + /// configuration. + /// This object allows to disable/enable transformations execution, set callback to + /// particular + /// transformation. For mo details see PassConfig class. + std::shared_ptr get_pass_config() { + return m_pass_config; + } + +protected: + template + std::shared_ptr push_pass(Args&&... args) { + static_assert(std::is_base_of::value, "pass not derived from pass base"); + auto pass = std::make_shared(std::forward(args)...); + auto pass_base = std::static_pointer_cast(pass); + m_pass_list.push_back(pass_base); + return pass; + } + + std::shared_ptr m_pass_config; + std::vector> m_pass_list; + bool m_visualize = false; + bool m_per_pass_validation = true; +}; +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/pass.hpp b/ngraph/core/include/openvino/pass/pass.hpp new file mode 100644 index 00000000000..c7056390a4f --- /dev/null +++ b/ngraph/core/include/openvino/pass/pass.hpp @@ -0,0 +1,110 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +#include "ngraph/util.hpp" +#include "openvino/core/core_visibility.hpp" +#include "openvino/core/deprecated.hpp" +#include "openvino/core/function.hpp" +#include "openvino/core/node.hpp" +#include "openvino/pass/pass_config.hpp" + +namespace ov { +namespace pass { +enum class PassProperty : uint32_t { + // Pass requires node shapes to be static + REQUIRE_STATIC_SHAPE = 0x1, + // Pass transformation will change the function's dynamic state + CHANGE_DYNAMIC_STATE = 1 << 1, +}; + +using PassPropertyMask = ngraph::EnumMask; + +class OPENVINO_API PassBase { + friend class Manager; + +public: + PassBase(); + virtual ~PassBase() = default; + /// Check if this pass has all the pass properties. + bool get_property(const PassPropertyMask& prop_mask) const; + + void set_name(const std::string& name) { + m_name = name; + } + std::string get_name() const; + + /// \brief Set callback for particular transformation type. + /// This method set global callback. For more details see PassConfig class + /// documentation. + /// \param callback lambda function that takes node and returns bool + void set_callback(const param_callback& callback); + + /// \brief Set PassConfig for particular transformation instance + /// \param pass_config is a PassConfig shared_ptr + virtual void set_pass_config(const std::shared_ptr& pass_config) { + m_pass_config = pass_config; + } + + /// \brief Allows to access PassConfig shared instance + /// \return Shared instance of PassConfig class + std::shared_ptr get_pass_config() { + return m_pass_config; + } + /// \brief Applies callback for given node. By default callback returns false. + /// This method remains here only for backward compatibility and will be removed + /// after all transformations are moved to transformation_callback() method. + /// \return result of callback execution for given node + NGRAPH_DEPRECATED("Please use transformation_callback method instead") + bool m_transformation_callback(const std::shared_ptr& node) { + return m_pass_config->get_callback(get_type_info())(node); + } + + /// \brief Applies callback for given node. By default callback returns false. + /// \param node which will be used inside callback + /// \return result of callback execution for given node + bool transformation_callback(const std::shared_ptr& node) { + return m_pass_config->get_callback(get_type_info())(node); + } + + using type_info_t = DiscreteTypeInfo; + + virtual const type_info_t& get_type_info() const = 0; + +protected: + void set_property(const PassPropertyMask& prop, bool value); + +private: + PassPropertyMask m_property; + + std::string m_name; + std::shared_ptr m_pass_config; +}; + +class OPENVINO_API FunctionPass : public PassBase { +public: + NGRAPH_RTTI_DECLARATION; + ~FunctionPass() override; + virtual bool run_on_function(std::shared_ptr) = 0; +}; + +class Manager; +enum class FusionType : uint32_t { + //`DIFFERENTIABLE_FUSIONS` produce ops that support autodiff + // i.e. implement `generate_adjoints` + DIFFERENTIABLE_FUSIONS = 0x1, + REGULAR_FUSIONS = 0x2, + //`FOP_FUSIONS` produce ops in the FusedOps category that might + // not be supported by all backends + FOP_FUSIONS = 0x4, + ALL_FUSIONS = 0xFFFFFFFF +}; +using FusionTypeMask = ngraph::EnumMask; +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/pass_config.hpp b/ngraph/core/include/openvino/pass/pass_config.hpp new file mode 100644 index 00000000000..5dcd72f3e89 --- /dev/null +++ b/ngraph/core/include/openvino/pass/pass_config.hpp @@ -0,0 +1,176 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +#include "ngraph/util.hpp" +#include "openvino/core/core_visibility.hpp" +#include "openvino/core/deprecated.hpp" +#include "openvino/core/function.hpp" +#include "openvino/core/node.hpp" + +namespace ov { +namespace pass { +using param_callback = std::function)>; +using param_callback_map = std::map; + +/// \brief Class representing a transformations config that is used for disabling/enabling +/// transformations registered inside pass::Manager and also allows to set callback for all +/// transformations or for particular transformation. +/// +/// When pass::Manager is created all passes registered inside this manager including nested +/// passes will share the same instance of PassConfig class. +/// To work with this class first you need to get shared instance of this class by calling +/// manager.get_pass_config() method. Then you will be able to disable/enable passes based +/// on transformations type_info. For example: +/// +/// pass::Manager manager; +/// manager.register_pass(); +/// auto pass_config = manager.get_pass_config(); +/// pass_config->disable(); // this will disable nested pass inside +/// // CommonOptimizations pipeline +/// manager.run_passes(f); +/// +/// Sometimes it is needed to call transformation inside other transformation manually. And +/// for that case before running transformation you need manually check that this pass is +/// not disabled and then you need to set current PassConfig instance to this +/// transformation. For example: +/// +/// // Inside MatcherPass callback or inside FunctionPass run_on_function() method +/// // you need to call get_pass_config() method to get shared instance of PassConfig +/// auto pass_config = get_pass_config(); +/// +/// // Before running nested transformation you need to check is it disabled or not +/// if (!pass_config->is_disabled()) { +/// auto pass = ConvertGELU(); +/// pass->set_pass_config(pass_config); +/// pass.apply(node); +/// } +/// +/// Following this logic inside your transformations you will guaranty that transformations +/// will be executed in a right way. +class OPENVINO_API PassConfig { +public: + /// \brief Disable transformation by its type_info + /// \param type_info Transformation type_info + void disable(const DiscreteTypeInfo& type_info); + /// \brief Disable transformation by its class type (based on type_info) + template + void disable() { + OPENVINO_SUPPRESS_DEPRECATED_START + disable(T::type_info); + OPENVINO_SUPPRESS_DEPRECATED_END + } + + /// \brief Enable transformation by its type_info + /// \param type_info Transformation type_info + void enable(const DiscreteTypeInfo& type_info); + /// \brief Enable transformation by its class type (based on type_info) + template + void enable() { + OPENVINO_SUPPRESS_DEPRECATED_START + enable(T::type_info); + OPENVINO_SUPPRESS_DEPRECATED_END + } + + /// \brief Set callback for all kind of transformations + void set_callback(const param_callback& callback) { + m_callback = callback; + } + template + typename std::enable_if::type set_callback(const param_callback& callback) {} + + /// \brief Set callback for particular transformation class types + /// + /// Example below show how to set callback for one or multiple passes using this method. + /// + /// pass_config->set_callback( + /// [](const_node_ptr &node) -> bool { + /// // Disable transformations for cases when input shape rank is not + /// equal to 4 + /// const auto input_shape_rank = + /// node->get_output_partial_shape(0).rank().get_length(); + /// if (input_shape_rank != 4) { + /// return false; + /// } + /// return true; + /// }); + /// + /// Note that inside transformations you must provide code that work with this callback. + /// See example below: + /// + /// if (transformation_callback(node)) { + /// return false; // exit from transformation + /// } + /// + template + void set_callback(const param_callback& callback) { + m_callback_map[T::type_info] = callback; + set_callback(callback); + } + + /// \brief Get callback for given transformation type_info + /// \param type_info Transformation type_info + /// + /// In case if callback wasn't set for given transformation type then global callback + /// will be returned. But if even global callback wasn't set then default callback will + /// be returned. + param_callback get_callback(const DiscreteTypeInfo& type_info) const; + + /// \brief Get callback for given transformation class type + /// \return callback lambda function + template + param_callback get_callback() const { + NGRAPH_SUPPRESS_DEPRECATED_START + return get_callback(T::type_info); + NGRAPH_SUPPRESS_DEPRECATED_END + } + + /// \brief Check either transformation type is disabled or not + /// \param type_info Transformation type_info + /// \return true if transformation type was disabled and false otherwise + bool is_disabled(const DiscreteTypeInfo& type_info) const { + return m_disabled.count(type_info); + } + + /// \brief Check either transformation class type is disabled or not + /// \return true if transformation type was disabled and false otherwise + template + bool is_disabled() const { + NGRAPH_SUPPRESS_DEPRECATED_START + return is_disabled(T::type_info); + NGRAPH_SUPPRESS_DEPRECATED_END + } + + /// \brief Check either transformation type is force enabled or not + /// \param type_info Transformation type_info + /// \return true if transformation type was force enabled and false otherwise + bool is_enabled(const DiscreteTypeInfo& type_info) const { + return m_enabled.count(type_info); + } + + /// \brief Check either transformation class type is force enabled or not + /// \return true if transformation type was force enabled and false otherwise + template + bool is_enabled() const { + return is_enabled(T::type_info); + } + + void add_disabled_passes(const PassConfig& rhs); + +private: + param_callback m_callback = [](const std::shared_ptr&) { + return false; + }; + param_callback_map m_callback_map; + std::unordered_set m_disabled; + std::unordered_set m_enabled; +}; +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/pattern/matcher.hpp b/ngraph/core/include/openvino/pass/pattern/matcher.hpp new file mode 100644 index 00000000000..261be2f86bc --- /dev/null +++ b/ngraph/core/include/openvino/pass/pattern/matcher.hpp @@ -0,0 +1,271 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include +#include + +#include "ngraph/op/constant.hpp" +#include "openvino/core/except.hpp" +#include "openvino/core/node.hpp" +#include "openvino/pass/pattern/op/any.hpp" +#include "openvino/pass/pattern/op/any_of.hpp" +#include "openvino/pass/pattern/op/any_output.hpp" +#include "openvino/pass/pattern/op/label.hpp" +#include "openvino/pass/pattern/op/skip.hpp" + +namespace ov { +namespace pass { +class GraphRewrite; + +namespace pattern { +class Matcher; + +class OPENVINO_API MatcherState { +public: + MatcherState(Matcher*); + bool finish(bool is_successful); + ~MatcherState(); + +protected: + Matcher* m_matcher; + PatternValueMap m_pattern_value_map; + PatternValueMaps m_pattern_value_maps; + size_t m_watermark; + size_t m_capture_size; + bool m_restore{true}; +}; + +/// Matcher looks for node patterns in a computation graph. The patterns are described by an +/// automaton that is described by an extended computation graph. The matcher executes +/// by attempting to match the start node of the pattern to a computation graph value +/// (output of a Node). In addition to determing if a match occurs, a pattern node may add +/// graph nodes to a list of matched nodes, associate nodes with graph values, and start +/// submatches. Submatches add match state changes to the enclosing match if the submatch +/// succeeds; otherwise the state is reverted. +/// +/// The default match behavior of a pattern node with a graph nodes is that the computation +/// graph value is added to the end of the matched value list and the match succeeds if the +/// node/pattern types match and the input values match. In the case of a commutative node, +/// the inputs can match in any order. If the matcher is in strict mode, the graph value +/// element type and shape must also match. +/// +/// Pattern nodes that have different match behavior are in ov::pass::pattern::op and have +/// descriptions of their match behavior. +class OPENVINO_API Matcher { +public: + using PatternMap = ov::pass::pattern::PatternMap; + + // Avoid implicit string construction from nullptr. + Matcher(const std::shared_ptr pattern_node, std::nullptr_t name) = delete; + + Matcher() = default; + Matcher(Output& pattern_node) : m_pattern_node{pattern_node} {} + + Matcher(Output& pattern_node, const std::string& name) : m_pattern_node(pattern_node), m_name{name} {} + + /// \brief Constructs a Matcher object + /// + /// \param pattern_node is a pattern sub graph that will be matched against input graphs + /// \param name is a string which is used for logging and disabling a matcher + /// \param strict_mode forces a matcher to consider shapes and ET of nodes + Matcher(const Output& pattern_node, const std::string& name, bool strict_mode) + : m_pattern_node(pattern_node), + m_name(name), + m_strict_mode(strict_mode) {} + + // Some matches should start on a node rather than an output. These three constructors + // are transition until we work out the right way to do that. + Matcher(std::shared_ptr pattern_node); + Matcher(std::shared_ptr pattern_node, const std::string& name); + Matcher(std::shared_ptr pattern_node, const std::string& name, bool strict_mode); + + virtual ~Matcher() = default; + /// \brief Matches a pattern to \p graph_node + /// + /// \param graph_value is an input graph to be matched against + bool match(const Output& graph_value); + + bool match(std::shared_ptr graph_node); + + /// \brief Matches a pattern to \p graph_node + /// + /// \param graph_value is an input graph to be matched against + /// \param previous_matches contains previous mappings from labels to nodes to use + bool match(const Output& graph_value, const PatternMap& previous_matches); + bool match(const Output& graph_value, const PatternValueMap& previous_matches); + + template + static std::shared_ptr unique_match(const std::shared_ptr& node) { + std::shared_ptr matched; + for (const auto& arg : node->input_values()) { + if (auto t_casted = ov::as_type_ptr(arg.get_node_shared_ptr())) { + if (matched) { + throw Exception("There's more than two arguments of the same type"); + } else { + matched = t_casted; + } + } + } + return matched; + } + + bool is_contained_match(const NodeVector& exclusions = {}, bool ignore_unused = true); + const NodeVector get_matched_nodes() { + return as_node_vector(m_matched_list); + } + const OutputVector& get_matched_values() const { + return m_matched_list; + } + OutputVector& get_matched_values() { + return m_matched_list; + } + void reset() {} + const std::string& get_name() { + return m_name; + } + std::shared_ptr get_pattern() { + return m_pattern_node.get_node_shared_ptr(); + } + Output get_pattern_value() { + return m_pattern_node; + } + std::shared_ptr get_match_root(); + Output get_match_value(); + PatternMap get_pattern_map() const; + PatternValueMap& get_pattern_value_map() { + return m_pattern_map; + } + PatternValueMaps& get_pattern_value_maps() { + return m_pattern_value_maps; + } + /// \brief Low-level helper to match recurring patterns + /// + /// \param graph is a graph to be matched against + /// \param pattern is a recurring pattern + /// \param rpattern specifies a node to recur from next + /// \param patterns a map from labels to matches + + size_t add_node(Output node); + + bool virtual match_value(const ov::Output& pattern_value, const ov::Output& graph_value); + + bool is_strict_mode() { + return m_strict_mode; + } + virtual bool match_arguments(Node* pattern_node, const std::shared_ptr& graph_node); + + void capture(const std::set& static_nodes); + + void clear_state(); + + size_t get_number_of_recurrent_matches() const { + return m_pattern_value_maps.size(); + } + NodeVector get_bound_nodes_for_pattern(const Output& pattern) const; + size_t get_number_of_bound_labels() const; + /// \brief Try a match + MatcherState start_match(); + + Output m_match_root; + Output m_pattern_node; + PatternValueMap m_pattern_map; + PatternValueMaps m_pattern_value_maps; + OutputVector m_matched_list; + +protected: + bool match_permutation(const OutputVector& pattern_args, const OutputVector& args); + + std::string m_name{"unnamed"}; + bool m_strict_mode{false}; +}; + +class OPENVINO_API RecurrentMatcher { +public: + /// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match + /// repeating patterns (e.g. RNN, LSTM, GRU cells) + /// + /// \param initial_pattern is a pattern sub graph describing the initial cell + /// \param pattern is a pattern sub graph describing an individual cell + /// \param rpattern is a (recurring) label to denote which node the next match should + /// start at + /// \param correlated_patterns is a set of labels whose bound nodes must remain the same + /// across all cells + RecurrentMatcher(const Output& initial_pattern, + const Output& pattern, + const std::shared_ptr& rpattern, + const std::set>& correlated_patterns) + : m_initial_pattern(initial_pattern), + m_pattern(pattern), + m_recurrent_pattern(rpattern), + m_correlated_patterns(correlated_patterns) {} + + /// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match + /// repeating patterns (e.g. RNN, LSTM, GRU cells) + /// + /// \param pattern is a pattern sub graph describing an individual cell + /// \param rpattern is a (recurring) label to denote which node the next match should + /// start at + /// \param correlated_patterns is a set of labels whose bound nodes must remain the same + /// across all cells + RecurrentMatcher(const Output& pattern, + const std::shared_ptr& rpattern, + const std::set>& correlated_patterns) + : RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns) {} + + RecurrentMatcher(const Output& initial_pattern, + const Output& pattern, + const std::shared_ptr& rpattern, + const std::set>& correlated_patterns); + + RecurrentMatcher(const Output& pattern, + const std::shared_ptr& rpattern, + const std::set>& correlated_patterns) + : RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns) {} + + /// \brief Returns a vector of bound nodes for a given label (used in a pattern + /// describing an individual cell + NodeVector get_bound_nodes_for_pattern(const std::shared_ptr& pattern) const { + if (m_matches.count(pattern) == 0) { + throw Exception("No bound nodes for a given label"); + } + + return as_node_vector(m_matches.at(pattern)); + } + + size_t get_number_of_recurrent_matches() const { + if (m_matches.size() == 0) { + return 0; + } + + return (*m_matches.begin()).second.size(); + } + + size_t get_number_of_bound_labels() const { + return m_matches.size(); + } + /// \brief Tries to match a pattern for an individual cell to a given \p graph + bool match(Output graph); + + std::shared_ptr get_match_root() { + return m_match_root.get_node_shared_ptr(); + } + Output get_match_value() { + return m_match_root; + } + +private: + Output m_initial_pattern; + Output m_pattern; + std::shared_ptr m_recurrent_pattern; + const std::set> m_correlated_patterns; + RPatternValueMap m_matches; + Output m_match_root; +}; +} // namespace pattern +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/pattern/op/any.hpp b/ngraph/core/include/openvino/pass/pattern/op/any.hpp new file mode 100644 index 00000000000..3552e25ebc0 --- /dev/null +++ b/ngraph/core/include/openvino/pass/pattern/op/any.hpp @@ -0,0 +1,45 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/node.hpp" +#include "openvino/pass/pattern/op/pattern.hpp" + +namespace ov { +namespace pass { +namespace pattern { +namespace op { +/// The graph value is to the matched value list. If the predicate is true for the node +/// and the arguments match, the match succeeds. +class OPENVINO_API Any : public Pattern { +public: + static constexpr NodeTypeInfo type_info{"patternAny", 0}; + const NodeTypeInfo& get_type_info() const override; + /// \brief creates a Any node containing a sub-pattern described by \sa type and \sa + /// shape. + Any(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values) + : Pattern(wrapped_values, pred) { + set_output_type(0, type, s); + } + Any(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values) + : Any(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {} + /// \brief creates a Any node containing a sub-pattern described by the type and + /// shape of \sa node. + Any(const Output& node, ValuePredicate pred, const OutputVector& wrapped_values) + : Any(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {} + Any(const Output& node, NodePredicate pred, const NodeVector& wrapped_values) + : Any(node.get_element_type(), + node.get_partial_shape(), + as_value_predicate(pred), + as_output_vector(wrapped_values)) {} + + bool match_value(pattern::Matcher* matcher, + const Output& pattern_value, + const Output& graph_value) override; +}; +} // namespace op +} // namespace pattern +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/pattern/op/any_of.hpp b/ngraph/core/include/openvino/pass/pattern/op/any_of.hpp new file mode 100644 index 00000000000..ce07522173c --- /dev/null +++ b/ngraph/core/include/openvino/pass/pattern/op/any_of.hpp @@ -0,0 +1,54 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/node.hpp" +#include "openvino/pass/pattern/op/pattern.hpp" + +namespace ov { +namespace pass { +namespace pattern { +namespace op { +/// The graph value is added to the matched values list. If the predicate is true for +/// the +/// graph node, a submatch is performed on the input of AnyOf and each input of the +/// graph node. The first match that succeeds results in a successful match. Otherwise +/// the match fails. +/// +/// AnyOf may be given a type and shape for use in strict mode. +class OPENVINO_API AnyOf : public Pattern { +public: + static constexpr NodeTypeInfo type_info{"patternAnyOf", 0}; + const NodeTypeInfo& get_type_info() const override; + /// \brief creates a AnyOf node containing a sub-pattern described by \sa type and + /// \sa shape. + AnyOf(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values) + : Pattern(wrapped_values, pred) { + if (wrapped_values.size() != 1) { + throw Exception("AnyOf expects exactly one argument"); + } + set_output_type(0, type, s); + } + AnyOf(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values) + : AnyOf( + type, + s, + [pred](const Output& value) { + return pred(value.get_node_shared_ptr()); + }, + as_output_vector(wrapped_values)) {} + + /// \brief creates a AnyOf node containing a sub-pattern described by the type and + /// shape of \sa node. + AnyOf(const Output& node, ValuePredicate pred, const OutputVector& wrapped_values) + : AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {} + AnyOf(const std::shared_ptr& node, NodePredicate pred, const NodeVector& wrapped_values) + : AnyOf(node, as_value_predicate(pred), as_output_vector(wrapped_values)) {} + bool match_value(Matcher* matcher, const Output& pattern_value, const Output& graph_value) override; +}; +} // namespace op +} // namespace pattern +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/pattern/op/any_output.hpp b/ngraph/core/include/openvino/pass/pattern/op/any_output.hpp new file mode 100644 index 00000000000..cbd7865ecda --- /dev/null +++ b/ngraph/core/include/openvino/pass/pattern/op/any_output.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/node.hpp" +#include "openvino/pass/pattern/op/pattern.hpp" + +namespace ov { +namespace pass { +namespace pattern { +namespace op { +/// Matches any output of a node +class OPENVINO_API AnyOutput : public Pattern { +public: + static constexpr NodeTypeInfo type_info{"patternAnyOutput", 0}; + const NodeTypeInfo& get_type_info() const override; + /// \brief creates an AnyOutput node matching any output of a node + /// \param node The node to match + AnyOutput(const std::shared_ptr& pattern) : Pattern({pattern->output(0)}) {} + + bool match_value(pattern::Matcher* matcher, + const Output& pattern_value, + const Output& graph_value) override; +}; +} // namespace op +} // namespace pattern +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/pattern/op/branch.hpp b/ngraph/core/include/openvino/pass/pattern/op/branch.hpp new file mode 100644 index 00000000000..902934f64ab --- /dev/null +++ b/ngraph/core/include/openvino/pass/pattern/op/branch.hpp @@ -0,0 +1,55 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/node.hpp" +#include "openvino/pass/pattern/op/pattern.hpp" + +namespace ov { +namespace pass { +namespace pattern { +namespace op { +/// A branch adds a loop to the pattern. The branch match is successful if the +/// destination node pattern matches the graph value. The destination node is a node in +/// the pattern graph that will not have been created some time after the Branch node is +/// created; use set_destination to add it. +/// +/// The branch destination is not stored as a shared pointer to prevent reference +/// cycles. Thus the destination node must be referenced in some other way to prevent it +/// from being deleted. +class OPENVINO_API Branch : public Pattern { +public: + static constexpr NodeTypeInfo type_info{"patternBranch", 0}; + const NodeTypeInfo& get_type_info() const override; + /// \brief Creates a Branch pattern + /// \param pattern the destinationing pattern + /// \param labels Labels where the destination may occur + Branch() : Pattern(OutputVector{}) { + set_output_type(0, element::f32, ngraph::Shape{}); + } + + void set_destination(const Output& destination) { + m_destination_node = destination.get_node(); + m_destination_index = destination.get_index(); + } + + Output get_destination() const { + return m_destination_node == nullptr + ? Output() + : Output{m_destination_node->shared_from_this(), m_destination_index}; + } + + bool match_value(pattern::Matcher* matcher, + const Output& pattern_value, + const Output& graph_value) override; + +protected: + Node* m_destination_node{nullptr}; + size_t m_destination_index{0}; +}; +} // namespace op +} // namespace pattern +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/pattern/op/capture.hpp b/ngraph/core/include/openvino/pass/pattern/op/capture.hpp new file mode 100644 index 00000000000..0b2a5eca940 --- /dev/null +++ b/ngraph/core/include/openvino/pass/pattern/op/capture.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/node.hpp" +#include "openvino/pass/pattern/op/pattern.hpp" + +namespace ov { +namespace pass { +namespace pattern { +namespace op { +/// Experimental for support of recurrent matches. +/// +/// Capture adds the pattern value map to a list of pattern value maps and resets +/// matches for pattern nodes not in the static node list. The match always succeeds. +class OPENVINO_API Capture : public Pattern { +public: + static constexpr NodeTypeInfo type_info{"patternCapture", 0}; + const NodeTypeInfo& get_type_info() const override; + Capture(const Output& arg) : Pattern({arg}) { + set_output_type(0, arg.get_element_type(), arg.get_partial_shape()); + } + + /// \brief static nodes are retained after a capture. All other nodes are dropped + std::set get_static_nodes() { + return m_static_nodes; + } + void set_static_nodes(const std::set& static_nodes) { + m_static_nodes = static_nodes; + } + + bool match_value(pattern::Matcher* matcher, + const Output& pattern_value, + const Output& graph_value) override; + +protected: + std::set m_static_nodes; +}; +} // namespace op +} // namespace pattern +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/pattern/op/label.hpp b/ngraph/core/include/openvino/pass/pattern/op/label.hpp new file mode 100644 index 00000000000..507a8036ce4 --- /dev/null +++ b/ngraph/core/include/openvino/pass/pattern/op/label.hpp @@ -0,0 +1,113 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/node.hpp" +#include "openvino/pass/pattern/op/pattern.hpp" + +namespace ov { +namespace pass { +namespace pattern { +namespace op { +/// Fails if the predicate returns false on the graph value. +/// +/// The graph value is added to the matched values list. If the Label is already +/// associated with a value, the match succeeds if the value is the same as the graph +/// value. Otherwise, the label is associated with the graph value and the match +/// succeeds if the pattern input matches the graph value. +/// +/// DEPRECATED: If no inputs are given to Label, a True node is serves as the input. If +/// more than one inputs are given, an Or pattern of the inputs serves as the input. +class OPENVINO_API Label : public Pattern { +public: + static constexpr NodeTypeInfo type_info{"patternLabel", 0}; + const NodeTypeInfo& get_type_info() const override; + /// \brief creates a Label node containing a sub-pattern described by \sa type and + /// \sa shape. + /// + /// this Label node can be bound only to the nodes in the input graph + /// that match the pattern specified by \sa wrapped_nodes + /// Example: + /// \code{.cpp} + /// auto add = a + b; // a and b are op::Parameter in this example + /// auto label = std::make_shared(element::f32, + /// Shape{2,2}, + /// nullptr, + /// OutputVector{add}); + /// \endcode + Label(const element::Type& type, + const PartialShape& s, + const ValuePredicate pred, + const OutputVector& wrapped_values) + : Pattern(OutputVector{wrap_values(wrapped_values)}, pred) { + set_output_type(0, type, s); + } + + explicit Label(const element::Type& type = element::dynamic, const PartialShape& s = PartialShape::dynamic()) + : Label( + type, + s, + [](const Output&) { + return true; + }, + OutputVector()) {} + + Label(const element::Type& type, const PartialShape& s, ValuePredicate pred) + : Label(type, s, pred, OutputVector{}) {} + + Label(const element::Type& type, const PartialShape& s, NodePredicate pred) + : Label(type, s, as_value_predicate(pred), OutputVector{}) {} + + Label(const element::Type& type, const PartialShape& s, const NodePredicate pred, const NodeVector& wrapped_values) + : Label(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {} + + /// \brief creates a Label node containing a sub-pattern described by the type and + /// shape of \sa node. + /// + /// this Label node can be bound only to the nodes in the input graph + /// that match the pattern specified by \sa wrapped_values + /// Example: + /// \code{.cpp} + /// auto add = a + b; // a and b are op::Parameter in this example + /// auto label = std::make_shared(add, + /// nullptr, + /// OutputVector{add}); + /// \endcode + Label(const Output& value, const ValuePredicate pred, const OutputVector& wrapped_values) + : Label(value.get_element_type(), value.get_partial_shape(), pred, wrapped_values) {} + Label(const Output& value, const ValuePredicate pred) + : Label(value.get_element_type(), value.get_partial_shape(), pred, OutputVector{}) {} + + Label(const Output& value, const NodePredicate pred) + : Label(value.get_element_type(), value.get_partial_shape(), as_value_predicate(pred), OutputVector{}) {} + Label(const Output& value) + : Label( + value.get_element_type(), + value.get_partial_shape(), + [](const Output&) { + return true; + }, + OutputVector{}) {} + Label(const Output& node, const NodePredicate pred, const NodeVector& wrapped_values) + : Label(node.get_element_type(), + node.get_partial_shape(), + as_value_predicate(pred), + as_output_vector(wrapped_values)) {} + + bool match_value(Matcher* matcher, const Output& pattern_value, const Output& graph_value) override; + +protected: + static Output wrap_values(const OutputVector& wrapped_values); +}; +} // namespace op + +OPENVINO_API +std::shared_ptr any_input(); + +OPENVINO_API +std::shared_ptr any_input(const pattern::op::ValuePredicate& pred); +} // namespace pattern +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/pattern/op/or.hpp b/ngraph/core/include/openvino/pass/pattern/op/or.hpp new file mode 100644 index 00000000000..1f173bdd418 --- /dev/null +++ b/ngraph/core/include/openvino/pass/pattern/op/or.hpp @@ -0,0 +1,32 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/node.hpp" +#include "openvino/pass/pattern/op/pattern.hpp" + +namespace ov { +namespace pass { +namespace pattern { +namespace op { +/// A submatch on the graph value is performed on each input to the Or; the match +/// succeeds on the first match. Otherwise the match fails. +class OPENVINO_API Or : public Pattern { +public: + static constexpr NodeTypeInfo type_info{"patternOr", 0}; + const NodeTypeInfo& get_type_info() const override; + /// \brief creates an Or node matching one of several sub-patterns in order. Does + /// not add node to match list. + /// \param patterns The patterns to try for matching + Or(const OutputVector& patterns) : Pattern(patterns) {} + + bool match_value(pattern::Matcher* matcher, + const Output& pattern_value, + const Output& graph_value) override; +}; +} // namespace op +} // namespace pattern +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/pattern/op/pattern.hpp b/ngraph/core/include/openvino/pass/pattern/op/pattern.hpp new file mode 100644 index 00000000000..9fe455e0079 --- /dev/null +++ b/ngraph/core/include/openvino/pass/pattern/op/pattern.hpp @@ -0,0 +1,96 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "openvino/core/node.hpp" + +namespace ov { +namespace pass { +namespace pattern { +namespace op { +class Label; +} + +class Matcher; +class MatcherState; + +using RPatternValueMap = std::map, OutputVector>; +using PatternValueMap = std::map, Output>; +using PatternValueMaps = std::vector; + +using PatternMap = std::map, std::shared_ptr>; + +PatternMap as_pattern_map(const PatternValueMap& pattern_value_map); +PatternValueMap as_pattern_value_map(const PatternMap& pattern_map); + +template +std::function)> has_class() { + auto pred = [](std::shared_ptr node) -> bool { + return ov::is_type(node); + }; + + return pred; +} + +OPENVINO_API +std::function)> consumers_count(size_t n); + +OPENVINO_API +std::function)> has_static_dim(size_t pos); + +OPENVINO_API +std::function)> has_static_dims(const std::vector& dims); + +OPENVINO_API +std::function)> has_static_shape(); + +OPENVINO_API +std::function)> has_static_rank(); + +OPENVINO_API +std::function)> rank_equals(const Dimension& expected_rank); + +OPENVINO_API +std::function)> type_matches(const element::Type& type); + +OPENVINO_API +std::function)> type_matches_any(const std::vector& types); + +namespace op { +using NodePredicate = std::function)>; +using ValuePredicate = std::function& value)>; + +OPENVINO_API +ValuePredicate as_value_predicate(NodePredicate pred); + +class OPENVINO_API Pattern : public Node { +public: + /// \brief \p a base class for \sa Skip and \sa Label + /// + Pattern(const OutputVector& patterns, ValuePredicate pred) : Node(patterns), m_predicate(pred) { + if (!m_predicate) { + m_predicate = [](const Output&) { + return true; + }; + } + } + + Pattern(const OutputVector& patterns) : Pattern(patterns, nullptr) {} + + std::shared_ptr clone_with_new_inputs(const OutputVector& /* new_args */) const override { + throw Exception("Uncopyable"); + } + + ValuePredicate get_predicate() const; + +protected: + ValuePredicate m_predicate; +}; +} // namespace op +} // namespace pattern +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/pattern/op/skip.hpp b/ngraph/core/include/openvino/pass/pattern/op/skip.hpp new file mode 100644 index 00000000000..02888a75022 --- /dev/null +++ b/ngraph/core/include/openvino/pass/pattern/op/skip.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/node.hpp" +#include "openvino/pass/pattern/op/pattern.hpp" + +namespace ov { +namespace pass { +namespace pattern { +namespace op { +/// The graph value is added to the matched value list. If the predicate is true, the +/// match succeeds if the arguments match; if the predicate is false, the match succeeds +/// if the pattern input matches the graph value. +class OPENVINO_API Skip : public Pattern { +public: + static constexpr NodeTypeInfo type_info{"patternSkip", 0}; + const NodeTypeInfo& get_type_info() const override; + Skip(const Output& arg, ValuePredicate pred) : Pattern({arg}, pred) { + set_output_type(0, arg.get_element_type(), arg.get_partial_shape()); + } + + Skip(const Output& arg, NodePredicate pred = nullptr) : Pattern({arg}, as_value_predicate(pred)) { + set_output_type(0, arg.get_element_type(), arg.get_partial_shape()); + } + + Skip(const OutputVector& args, ValuePredicate pred) : Pattern(args, pred) { + set_output_type(0, args.at(0).get_element_type(), args.at(0).get_partial_shape()); + } + + Skip(const OutputVector& args, NodePredicate pred = nullptr) : Pattern(args, as_value_predicate(pred)) { + set_output_type(0, args.at(0).get_element_type(), args.at(0).get_partial_shape()); + } + + bool match_value(pattern::Matcher* matcher, + const Output& pattern_value, + const Output& graph_value) override; +}; +} // namespace op +} // namespace pattern +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/pattern/op/true.hpp b/ngraph/core/include/openvino/pass/pattern/op/true.hpp new file mode 100644 index 00000000000..b99170ce799 --- /dev/null +++ b/ngraph/core/include/openvino/pass/pattern/op/true.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/node.hpp" +#include "openvino/pass/pattern/op/pattern.hpp" + +namespace ov { +namespace pass { +namespace pattern { +namespace op { +/// \brief The match always succeeds. +class OPENVINO_API True : public Pattern { +public: + static constexpr NodeTypeInfo type_info{"patternTrue", 0}; + const NodeTypeInfo& get_type_info() const override; + /// \brief Always matches, does not add node to match list. + True() : Pattern(OutputVector{}) {} + bool match_value(pattern::Matcher* matcher, + const Output& pattern_value, + const Output& graph_value) override; +}; +} // namespace op +} // namespace pattern +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/pattern/op/wrap_type.hpp b/ngraph/core/include/openvino/pass/pattern/op/wrap_type.hpp new file mode 100644 index 00000000000..01498b81b5f --- /dev/null +++ b/ngraph/core/include/openvino/pass/pattern/op/wrap_type.hpp @@ -0,0 +1,75 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/node.hpp" +#include "openvino/pass/pattern/op/pattern.hpp" + +namespace ov { +namespace pass { +namespace pattern { +namespace op { +class OPENVINO_API WrapType : public Pattern { +public: + static constexpr NodeTypeInfo type_info{"patternAnyType", 0}; + const NodeTypeInfo& get_type_info() const override; + + explicit WrapType( + NodeTypeInfo wrapped_type, + const ValuePredicate& pred = + [](const Output& output) { + return true; + }, + const OutputVector& input_values = {}) + : Pattern(input_values, pred), + m_wrapped_types({wrapped_type}) { + set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic()); + } + + explicit WrapType( + std::vector wrapped_types, + const ValuePredicate& pred = + [](const Output& output) { + return true; + }, + const OutputVector& input_values = {}) + : Pattern(input_values, pred), + m_wrapped_types(std::move(wrapped_types)) { + set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic()); + } + + bool match_value(pattern::Matcher* matcher, + const Output& pattern_value, + const Output& graph_value) override; + + NodeTypeInfo get_wrapped_type() const; + + const std::vector& get_wrapped_types() const; + +private: + std::vector m_wrapped_types; +}; +} // namespace op + +template +std::shared_ptr wrap_type(const OutputVector& inputs, const pattern::op::ValuePredicate& pred) { + std::vector info{Args::type_info...}; + return std::make_shared(info, pred, inputs); +} + +template +std::shared_ptr wrap_type(const OutputVector& inputs = {}) { + return wrap_type(inputs, [](const Output& output) { + return true; + }); +} + +template +std::shared_ptr wrap_type(const pattern::op::ValuePredicate& pred) { + return wrap_type({}, pred); +} +} // namespace pattern +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/validate.hpp b/ngraph/core/include/openvino/pass/validate.hpp new file mode 100644 index 00000000000..63b502a1a28 --- /dev/null +++ b/ngraph/core/include/openvino/pass/validate.hpp @@ -0,0 +1,32 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/core_visibility.hpp" +#include "openvino/pass/pass.hpp" + +namespace ov { +namespace pass { +/// \brief The Validate pass performs sanity checks on attributes and inputs, and +/// computes output shapes and element types for all computation nodes in a given +/// computation graph. +/// +/// \details The verification and inference is done via invoking each node's specific +/// implementation of \link ov::Node::validate_and_infer_types() \endlink function. +/// +/// By default, the \ref ov::pass::Manager runs this pass after executing every +/// optimization pass. This is to ensure that any update to the graph by an optimization +/// pass does not break the shape and data type requirement on a computation node. +/// This default validation run can be changed via calling the +/// \link ov::pass::Manager::set_per_pass_validation(bool) \endlink function. +class OPENVINO_API Validate : public FunctionPass { +public: + OPENVINO_RTTI_DECLARATION; + + Validate() : FunctionPass() {} + bool run_on_function(std::shared_ptr f) override; +}; +} // namespace pass +} // namespace ov diff --git a/ngraph/core/include/openvino/pass/visualize_tree.hpp b/ngraph/core/include/openvino/pass/visualize_tree.hpp new file mode 100644 index 00000000000..7eab8b8f2c9 --- /dev/null +++ b/ngraph/core/include/openvino/pass/visualize_tree.hpp @@ -0,0 +1,57 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "openvino/pass/pass.hpp" + +class HeightMap; + +using visualize_tree_ops_map_t = + std::unordered_map>; + +namespace ov { +namespace pass { +class OPENVINO_API VisualizeTree : public FunctionPass { +public: + OPENVINO_RTTI_DECLARATION; + + using node_modifiers_t = std::function& attributes)>; + VisualizeTree(const std::string& file_name, node_modifiers_t nm = nullptr, bool dot_only = false); + bool run_on_function(std::shared_ptr) override; + + void set_ops_to_details(const visualize_tree_ops_map_t& ops_map) { + m_ops_to_details = ops_map; + } + +protected: + void add_node_arguments(std::shared_ptr node, + std::unordered_map& height_maps, + size_t& fake_node_ctr); + std::string add_attributes(std::shared_ptr node); + virtual std::string get_attributes(std::shared_ptr node); + virtual std::string get_node_name(std::shared_ptr node); + std::string get_constant_value(std::shared_ptr node, size_t max_elements = 7); + + void render() const; + + std::stringstream m_ss; + std::string m_name; + std::set> m_nodes_with_attributes; + visualize_tree_ops_map_t m_ops_to_details; + node_modifiers_t m_node_modifiers = nullptr; + bool m_dot_only; + static const int max_jump_distance; +}; +} // namespace pass +} // namespace ov diff --git a/ngraph/core/src/pass/constant_folding.cpp b/ngraph/core/src/pass/constant_folding.cpp index f9818321d1f..56be7361736 100644 --- a/ngraph/core/src/pass/constant_folding.cpp +++ b/ngraph/core/src/pass/constant_folding.cpp @@ -11,11 +11,10 @@ #include "ngraph/validation_util.hpp" using namespace std; -using namespace ngraph; -NGRAPH_RTTI_DEFINITION(ngraph::pass::ConstantFolding, "ConstantFolding", 0); +OPENVINO_RTTI_DEFINITION(ov::pass::ConstantFolding, "ConstantFolding", 0); -bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr f) { +bool ov::pass::ConstantFolding::run_on_function(std::shared_ptr f) { bool rewritten = pre_calculated_values_folding(f); for (const auto& node : f->get_ordered_ops()) { @@ -48,7 +47,7 @@ bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr(node)) { + if (auto sub_graph_node = std::dynamic_pointer_cast(node)) { if (const auto& sub_graph = sub_graph_node->get_function()) { rewritten |= run_on_function(sub_graph); } @@ -79,14 +78,14 @@ bool ngraph::pass::ConstantFolding::pre_calculated_values_folding(const std::sha while (!nodes.empty()) { auto curr_node = nodes.front(); nodes.pop_front(); - if (visited.count(curr_node) || ov::is_type(curr_node)) + if (visited.count(curr_node) || ov::is_type(curr_node)) continue; visited.insert(curr_node); for (auto& input_value : curr_node->input_values()) { // Check that ConstantFolding is not disabled on this path std::vector order; - auto status = could_propagate(input_value, order); + auto status = ngraph::could_propagate(input_value, order); if (status) { for (const auto& node : order) { const auto& rt_info = node->get_rt_info(); @@ -99,8 +98,8 @@ bool ngraph::pass::ConstantFolding::pre_calculated_values_folding(const std::sha if (status && input_value.get_tensor().has_and_set_bound()) { auto input_node = input_value.get_node_shared_ptr(); - auto replacement = std::make_shared(input_value.get_tensor().get_lower_value()); - if (replacement && !ov::is_type(input_node)) { + auto replacement = std::make_shared(input_value.get_tensor().get_lower_value()); + if (replacement && !ov::is_type(input_node)) { if (input_node->get_output_size() == 1) { replacement->set_friendly_name(input_node->get_friendly_name()); } else { diff --git a/ngraph/core/src/pass/convert_fp32_to_fp16.cpp b/ngraph/core/src/pass/convert_fp32_to_fp16.cpp index 74386bb9da6..8d98e5f83fc 100644 --- a/ngraph/core/src/pass/convert_fp32_to_fp16.cpp +++ b/ngraph/core/src/pass/convert_fp32_to_fp16.cpp @@ -9,12 +9,11 @@ #include "transformations/convert_precision.hpp" using namespace std; -using namespace ngraph; -NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertFP32ToFP16, "ConvertFP32ToFP16", 0); +OPENVINO_RTTI_DEFINITION(ov::pass::ConvertFP32ToFP16, "ConvertFP32ToFP16", 0); -bool ngraph::pass::ConvertFP32ToFP16::run_on_function(std::shared_ptr f) { - ngraph::pass::Manager m(get_pass_config()); +bool ov::pass::ConvertFP32ToFP16::run_on_function(std::shared_ptr f) { + ov::pass::Manager m(get_pass_config()); m.register_pass(precisions_array{{ngraph::element::f32, ngraph::element::f16}}); m.run_passes(f); return false; diff --git a/ngraph/core/src/pass/graph_rewrite.cpp b/ngraph/core/src/pass/graph_rewrite.cpp index 07e8472ce47..b9342a88562 100644 --- a/ngraph/core/src/pass/graph_rewrite.cpp +++ b/ngraph/core/src/pass/graph_rewrite.cpp @@ -18,9 +18,6 @@ #include "ngraph/op/util/sub_graph_base.hpp" #include "perf_counters.hpp" -using namespace std; -using namespace ngraph; - /* GraphRewrite algorithm: * GraphRewrite processes an input graph in an topological order(i.e. args before users) * Given the following graph: Abs2 @@ -33,7 +30,7 @@ using namespace ngraph; * Note, `Abs2` comes before `Neg3` as `Abs2`'s id = 2 is *less* than `Neg3`'s one (id = 3) * Next, GraphRewrite will invoke matchers passes registered in add_matcher order. * For example: - * ngraph::pass::GraphRewrite pass; + * ov::pass::GraphRewrite pass; * pass.add_matcher(); * pass.add_matcher(); * pass.add_matcher(); @@ -53,13 +50,13 @@ using namespace ngraph; * If MatcherPass register more than one node make sure that this nodes are registered in * topological order. */ -NGRAPH_RTTI_DEFINITION(ngraph::pass::GraphRewrite, "ngraph::pass::GraphRewrite", 0); +NGRAPH_RTTI_DEFINITION(ov::pass::GraphRewrite, "ov::pass::GraphRewrite", 0); -NGRAPH_RTTI_DEFINITION(ngraph::pass::BackwardGraphRewrite, "ngraph::pass::BackwardGraphRewrite", 0); +NGRAPH_RTTI_DEFINITION(ov::pass::BackwardGraphRewrite, "ov::pass::BackwardGraphRewrite", 0); -NGRAPH_RTTI_DEFINITION(ngraph::pass::MatcherPass, "ngraph::pass::MatcherPass", 0); +NGRAPH_RTTI_DEFINITION(ov::pass::MatcherPass, "ov::pass::MatcherPass", 0); -namespace ngraph { +namespace ov { namespace pass { namespace internal { PerfCounters& perf_counters_graph_rewrite() { @@ -68,27 +65,28 @@ PerfCounters& perf_counters_graph_rewrite() { } } // namespace internal } // namespace pass -} // namespace ngraph +} // namespace ov -bool pass::BackwardGraphRewrite::run_on_function(std::shared_ptr f) { +bool ov::pass::BackwardGraphRewrite::run_on_function(std::shared_ptr f) { // Initialize execution queue with nodes in topological order - deque> nodes_to_run; + std::deque> nodes_to_run; for (auto& node : f->get_ordered_ops()) { nodes_to_run.emplace_front(node); } return apply_matcher_passes(f, std::move(nodes_to_run)); } -bool pass::GraphRewrite::run_on_function(std::shared_ptr f) { +bool ov::pass::GraphRewrite::run_on_function(std::shared_ptr f) { // Initialize execution queue with nodes in topological order - deque> nodes_to_run; + std::deque> nodes_to_run; for (auto& node : f->get_ordered_ops()) { nodes_to_run.emplace_back(node); } return apply_matcher_passes(f, std::move(nodes_to_run)); } -bool pass::GraphRewrite::apply_matcher_passes(shared_ptr f, deque> nodes_to_run) { +bool ov::pass::GraphRewrite::apply_matcher_passes(std::shared_ptr f, + std::deque> nodes_to_run) { OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, "pass::GraphRewrite::run_on_function"); bool rewritten = false; @@ -111,7 +109,7 @@ bool pass::GraphRewrite::apply_matcher_passes(shared_ptr f, dequeget_pattern_value().get_node_shared_ptr(); // pattern::op::AnyOutput operation automatically appends for multi output operations inside // Matcher and to gen actual root node we need to take it's parent. - if (auto any_type = dynamic_pointer_cast(root)) { + if (auto any_type = std::dynamic_pointer_cast(root)) { root = any_type->input_value(0).get_node_shared_ptr(); } @@ -119,8 +117,8 @@ bool pass::GraphRewrite::apply_matcher_passes(shared_ptr f, deque(root)) { - if (auto any_type = dynamic_pointer_cast(p)) { + if (auto p = std::dynamic_pointer_cast(root)) { + if (auto any_type = std::dynamic_pointer_cast(p)) { for (const auto& root_type_info : any_type->get_wrapped_types()) { type_to_matcher[root_type_info].push_back(matcher_index); } @@ -180,7 +178,7 @@ bool pass::GraphRewrite::apply_matcher_passes(shared_ptr f, deque(node)) { + if (auto sub_graph_node = std::dynamic_pointer_cast(node)) { if (auto sub_graph = sub_graph_node->get_function()) { run_on_function(sub_graph); } @@ -236,9 +234,9 @@ bool pass::GraphRewrite::apply_matcher_passes(shared_ptr f, deque& m, - const graph_rewrite_callback& callback, - const PassPropertyMask& property) { +void ov::pass::GraphRewrite::add_matcher(const std::shared_ptr& m, + const graph_rewrite_callback& callback, + const PassPropertyMask& property) { m_matchers.push_back(std::make_shared( m->get_name(), m, @@ -258,7 +256,8 @@ void pass::GraphRewrite::add_matcher(const shared_ptr& m, property)); } -void pass::GraphRewrite::add_matcher(const shared_ptr& m, const graph_rewrite_callback& callback) { +void ov::pass::GraphRewrite::add_matcher(const std::shared_ptr& m, + const graph_rewrite_callback& callback) { NGRAPH_SUPPRESS_DEPRECATED_START // TODO: before deprecate this function, by default expect the // callback require static shape. @@ -266,7 +265,7 @@ void pass::GraphRewrite::add_matcher(const shared_ptr& m, cons NGRAPH_SUPPRESS_DEPRECATED_END } -void pass::GraphRewrite::set_pass_config(const std::shared_ptr& rhs) { +void ov::pass::GraphRewrite::set_pass_config(const std::shared_ptr& rhs) { auto pass_config = get_pass_config(); // We have to preserve disabled passes because in case when we register matchers inside // GraphRewrite c-tor we work with local PassConfig instance. @@ -293,9 +292,9 @@ void pass::GraphRewrite::set_pass_config(const std::shared_ptr& rhs) } } -void pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr& m, - const ngraph::recurrent_graph_rewrite_callback& callback, - const PassPropertyMask& property) { +void ov::pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr& m, + const ov::recurrent_graph_rewrite_callback& callback, + const PassPropertyMask& property) { m_matchers.push_back(std::make_shared( "Recurrent matcher", nullptr, @@ -310,24 +309,24 @@ void pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr& m, - const ngraph::recurrent_graph_rewrite_callback& callback) { +void ov::pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr& m, + const ov::recurrent_graph_rewrite_callback& callback) { // TODO: before deprecate this function, by default expect the // callback require static shape. add_matcher(m, callback, {PassProperty::REQUIRE_STATIC_SHAPE}); } -bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr f) { +bool ov::pass::RecurrentGraphRewrite::run_on_function(std::shared_ptr f) { bool changed = false; size_t i = 0; // This check is very expensive and is only needed for experimental features, so we will hide // it behind an environment variable for now. TODO: Find a less expensive way to handle this. - static bool s_rerun_dynamic_check = getenv_bool("NGRAPH_GRAPH_REWRITE_RERUN_DYNAMIC_CHECK"); + static bool s_rerun_dynamic_check = ngraph::getenv_bool("NGRAPH_GRAPH_REWRITE_RERUN_DYNAMIC_CHECK"); auto run_matchers = [&]() -> bool { bool is_dyn_func = s_rerun_dynamic_check && f->is_dynamic(); - for (auto node : f->get_ops()) { + for (const auto& node : f->get_ops()) { for (auto& m_pass : m_matchers) { if (is_dyn_func && m_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE)) { NGRAPH_DEBUG << "matcher callback requires static shape but the " @@ -356,9 +355,9 @@ bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr f) { return changed; } -void ngraph::pass::MatcherPass::register_matcher(const std::shared_ptr& m, - const ngraph::graph_rewrite_callback& callback, - const PassPropertyMask& property) { +void ov::pass::MatcherPass::register_matcher(const std::shared_ptr& m, + const ov::graph_rewrite_callback& callback, + const PassPropertyMask& property) { set_name(m->get_name()); set_property(property, true); m_matcher = m; @@ -376,7 +375,7 @@ void ngraph::pass::MatcherPass::register_matcher(const std::shared_ptr node) { +bool ov::pass::MatcherPass::apply(std::shared_ptr node) { OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, pass::internal::perf_counters_graph_rewrite()[get_type_info()]); m_new_nodes.clear(); if (m_handler) diff --git a/ngraph/core/src/pass/low_latency.cpp b/ngraph/core/src/pass/low_latency.cpp index 87946243230..d4e0abf2143 100644 --- a/ngraph/core/src/pass/low_latency.cpp +++ b/ngraph/core/src/pass/low_latency.cpp @@ -12,13 +12,12 @@ #include #include -NGRAPH_RTTI_DEFINITION(ngraph::pass::LowLatency2, "LowLatency2", 0); +NGRAPH_RTTI_DEFINITION(ov::pass::LowLatency2, "LowLatency2", 0); NGRAPH_SUPPRESS_DEPRECATED_START NGRAPH_RTTI_DEFINITION(ngraph::pass::LowLatency, "LowLatency", 0); using namespace std; -using namespace ngraph; namespace { string generate_variable_name(const string& op_name, const string& param_name, int variable_idx) { @@ -27,8 +26,8 @@ string generate_variable_name(const string& op_name, const string& param_name, i } // namespace ngraph::pass::LowLatency::LowLatency() { - auto tensor_iterator = ngraph::pattern::wrap_type(); - ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) { + auto tensor_iterator = ov::pass::pattern::wrap_type(); + ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) { const auto& sub_graph_op = std::dynamic_pointer_cast(m.get_match_root()); if (!sub_graph_op) { return false; @@ -38,7 +37,7 @@ ngraph::pass::LowLatency::LowLatency() { const auto& trip_count = std::dynamic_pointer_cast(loop->get_input_node_shared_ptr(0)); const auto& num_iter = loop->get_num_iterations(); if (trip_count && num_iter > 0 && trip_count->get_output_target_inputs(0).size() == 1) { - auto single_iter = std::make_shared(ngraph::element::i64, Shape{}, 1); + auto single_iter = std::make_shared(ov::element::i64, Shape{}, 1); replace_node(trip_count, single_iter); } else { // count of iterations is dynamic; @@ -47,7 +46,7 @@ ngraph::pass::LowLatency::LowLatency() { } // Mark the TI layer to be unrolled. Enable unconditional ti unrolling for all plugins. auto& rt_info = sub_graph_op->get_rt_info(); - rt_info["UNROLL_TI"] = std::make_shared>(1); + rt_info["UNROLL_TI"] = std::make_shared>(1); int64_t variable_id = 0; std::vector> assigns; @@ -87,13 +86,14 @@ ngraph::pass::LowLatency::LowLatency() { return false; }; - auto m = std::make_shared(tensor_iterator, "LowLatency"); + auto m = std::make_shared(tensor_iterator, "LowLatency"); register_matcher(m, callback); } NGRAPH_SUPPRESS_DEPRECATED_END -void UnrollSingleIteration(const shared_ptr& sub_graph_op, const shared_ptr& outer_f) { - using namespace opset7; +void UnrollSingleIteration(const shared_ptr& sub_graph_op, + const shared_ptr& outer_f) { + using namespace ngraph::opset7; const auto& params = sub_graph_op->get_function()->get_parameters(); const auto& results = sub_graph_op->get_function()->get_results(); @@ -109,7 +109,7 @@ void UnrollSingleIteration(const shared_ptr& sub_graph_op, // before: TI [...-> Layer1 -> Result -> output] -> Layer2 -> ... // after: ...-> Layer1 -> Layer2 -> ... - NodeVector new_ops; + ov::NodeVector new_ops; for (const auto& out : sub_graph_op->get_output_descriptions()) { const auto& connect_to = results.at(out->m_body_value_index)->get_input_source_output(0); for (auto& input_to : sub_graph_op->output(out->m_output_index).get_target_inputs()) { @@ -120,7 +120,7 @@ void UnrollSingleIteration(const shared_ptr& sub_graph_op, // IECompatibility: insert identity (Unsqueeze + Squeeze) to store the TensorIterator // output names - auto axis_1 = Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1}); + auto axis_1 = Constant::create(ov::element::i64, ngraph::Shape{1}, {1}); auto identity_1 = std::make_shared(connect_to, axis_1); auto identity_2 = std::make_shared(identity_1, axis_1); identity_2->set_friendly_name(out_name); @@ -135,36 +135,38 @@ void UnrollSingleIteration(const shared_ptr& sub_graph_op, ngraph::copy_runtime_info(sub_graph_op, new_ops); } -Output create_init_subgraph(const shared_ptr& sub_graph_op, const Output& in_node) { - using namespace opset7; +ngraph::Output create_init_subgraph(const shared_ptr& sub_graph_op, + const ngraph::Output& in_node) { + using namespace ngraph::opset7; - auto const_zero = make_shared(in_node.get_element_type(), Shape{1}, 0); + auto const_zero = make_shared(in_node.get_element_type(), ngraph::Shape{1}, 0); auto shape_of = make_shared(in_node); auto broadcast = make_shared(const_zero, shape_of); copy_runtime_info(sub_graph_op, {const_zero, shape_of, broadcast}); return broadcast->output(0); } -bool pass::LowLatency2::run_on_function(shared_ptr f) { - using namespace opset7; +bool ov::pass::LowLatency2::run_on_function(shared_ptr f) { + using namespace ngraph::opset7; - SinkVector assigns; + ngraph::SinkVector assigns; for (const auto& op : f->get_ordered_ops()) { - if (const auto& sub_graph_op = dynamic_pointer_cast(op)) { + if (const auto& sub_graph_op = dynamic_pointer_cast(op)) { int64_t variable_id = 0; const auto& func = sub_graph_op->get_function(); const auto& params = func->get_parameters(); for (const auto& in : sub_graph_op->get_input_descriptions()) { // Process all back edges - if (const auto& merged_in = dynamic_pointer_cast(in)) { + if (const auto& merged_in = + dynamic_pointer_cast(in)) { // create new Variable const string& param_name = params.at(merged_in->m_body_parameter_index)->get_friendly_name(); const string& var_name = generate_variable_name(sub_graph_op->get_friendly_name(), param_name, variable_id); const auto& input = sub_graph_op->input(merged_in->m_input_index); - if (std::dynamic_pointer_cast(input.get_source_output().get_node_shared_ptr()) != - nullptr) { + if (std::dynamic_pointer_cast( + input.get_source_output().get_node_shared_ptr()) != nullptr) { NGRAPH_DEBUG << "LowLatency2 transformation cannot be applied because the " << "ReadValue node is already an input to the TensorIterator." << "LowLatency2 transformation may have already been applied, please " @@ -175,7 +177,7 @@ bool pass::LowLatency2::run_on_function(shared_ptr f) { const auto& param = sub_graph_op->get_function()->get_parameters().at(merged_in->m_body_parameter_index); for (const auto& in_to : param->output(0).get_target_inputs()) { - if (dynamic_cast(in_to.get_node()) != nullptr) { + if (dynamic_cast(in_to.get_node()) != nullptr) { NGRAPH_DEBUG << "LowLatency2 transformation cannot be applied because the " << "ReadValue node is already inside the TensorIterator. " << "LowLatency transformation may have been applied, please do " @@ -184,8 +186,8 @@ bool pass::LowLatency2::run_on_function(shared_ptr f) { } } - VariableInfo var_info{PartialShape::dynamic(), element::dynamic, var_name}; - auto variable = make_shared(var_info); + ngraph::VariableInfo var_info{PartialShape::dynamic(), element::dynamic, var_name}; + auto variable = make_shared(var_info); // insert ReadValue // Layers -> [new op: ReadValue] -> Subgraph operation @@ -204,12 +206,12 @@ bool pass::LowLatency2::run_on_function(shared_ptr f) { // ---> Layers -> ... */ const auto& out_desc = sub_graph_op->get_output_descriptions(); - bool is_output_exist = - std::any_of(out_desc.begin(), - out_desc.end(), - [&merged_in](const std::shared_ptr& out) { - return out->m_body_value_index == merged_in->m_body_value_index; - }); + bool is_output_exist = std::any_of( + out_desc.begin(), + out_desc.end(), + [&merged_in](const std::shared_ptr& out) { + return out->m_body_value_index == merged_in->m_body_value_index; + }); // Create new output if it doesn't exist. if (!is_output_exist) { sub_graph_op->get_iter_value(func->get_results().at(merged_in->m_body_value_index)); @@ -217,7 +219,7 @@ bool pass::LowLatency2::run_on_function(shared_ptr f) { for (const auto& out : sub_graph_op->get_output_descriptions()) { if (out->m_body_value_index == merged_in->m_body_value_index) { auto assign = make_shared(sub_graph_op->output(out->m_output_index), variable); - ngraph::copy_runtime_info(sub_graph_op, assign); + copy_runtime_info(sub_graph_op, assign); // control dependency so that ReadValue is processed before Assign assign->add_control_dependency(read_value); assigns.emplace_back(assign); diff --git a/ngraph/core/src/pass/manager.cpp b/ngraph/core/src/pass/manager.cpp index 613ec73bff3..288003dd08a 100644 --- a/ngraph/core/src/pass/manager.cpp +++ b/ngraph/core/src/pass/manager.cpp @@ -24,9 +24,8 @@ #include "perf_counters.hpp" using namespace std; -using namespace ngraph; -namespace ngraph { +namespace ov { namespace pass { namespace internal { PerfCounters& perf_counters() { @@ -35,25 +34,25 @@ PerfCounters& perf_counters() { } } // namespace internal } // namespace pass -} // namespace ngraph +} // namespace ov -pass::Manager::Manager() +ov::pass::Manager::Manager() : m_pass_config(std::make_shared()), - m_visualize(getenv_bool("NGRAPH_ENABLE_VISUALIZE_TRACING")) {} + m_visualize(ngraph::getenv_bool("NGRAPH_ENABLE_VISUALIZE_TRACING")) {} -pass::Manager::~Manager() {} +ov::pass::Manager::~Manager() = default; -pass::Manager::Manager(std::shared_ptr pass_config) : m_pass_config(std::move(pass_config)) {} +ov::pass::Manager::Manager(std::shared_ptr pass_config) : m_pass_config(std::move(pass_config)) {} -void pass::Manager::run_passes(shared_ptr func) { +void ov::pass::Manager::run_passes(shared_ptr func) { NGRAPH_SUPPRESS_DEPRECATED_START OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, "pass::Manager::run_passes"); - static bool profile_enabled = getenv_bool("NGRAPH_PROFILE_PASS_ENABLE"); + static bool profile_enabled = ngraph::getenv_bool("NGRAPH_PROFILE_PASS_ENABLE"); size_t index = 0; - stopwatch pass_timer; - stopwatch overall_timer; + ngraph::stopwatch pass_timer; + ngraph::stopwatch overall_timer; overall_timer.start(); bool function_changed = false; for (auto& pass : m_pass_list) { @@ -96,13 +95,13 @@ void pass::Manager::run_passes(shared_ptr func) { } else { function_changed = function_pass->run_on_function(func); } - } else if (auto node_pass = dynamic_pointer_cast(pass)) { + } else if (auto node_pass = dynamic_pointer_cast(pass)) { if (node_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && func->is_dynamic()) { NGRAPH_DEBUG << "Pass " << pass->get_name() << " requires static shape but the " << "function is dynamic. Skipping this transformation"; continue; } - for (shared_ptr n : func->get_ops()) { + for (const shared_ptr& n : func->get_ops()) { function_changed |= node_pass->run_on_node(n); } } @@ -115,7 +114,7 @@ void pass::Manager::run_passes(shared_ptr func) { auto base_filename = func->get_name() + std::string("_") + index_str + std::string("_") + pass->get_name(); if (m_visualize) { - static const string format = getenv_string("NGRAPH_VISUALIZE_TRACING_FORMAT"); + static const string format = ngraph::getenv_string("NGRAPH_VISUALIZE_TRACING_FORMAT"); auto file_ext = format.empty() ? "svg" : format; pass::VisualizeTree vt(base_filename + std::string(".") + file_ext); vt.run_on_function(func); diff --git a/ngraph/core/src/pass/pass.cpp b/ngraph/core/src/pass/pass.cpp index 7ab35f96c8d..4d8e3308582 100644 --- a/ngraph/core/src/pass/pass.cpp +++ b/ngraph/core/src/pass/pass.cpp @@ -7,21 +7,20 @@ # include #endif -#include "ngraph/pass/manager.hpp" #include "ngraph/pass/pass.hpp" +#include "openvino/pass/manager.hpp" using namespace std; -using namespace ngraph; -NGRAPH_RTTI_DEFINITION(ngraph::pass::FunctionPass, "ngraph::pass::FunctionPass", 0); +OPENVINO_RTTI_DEFINITION(ov::pass::FunctionPass, "ov::pass::FunctionPass", 0); -pass::PassBase::PassBase() : m_property{all_pass_property_off}, m_pass_config(std::make_shared()) {} +ov::pass::PassBase::PassBase() : m_property(), m_pass_config(std::make_shared()) {} -bool pass::PassBase::get_property(const PassPropertyMask& prop) const { +bool ov::pass::PassBase::get_property(const PassPropertyMask& prop) const { return m_property.is_set(prop); } -void pass::PassBase::set_property(const PassPropertyMask& prop, bool value) { +void ov::pass::PassBase::set_property(const PassPropertyMask& prop, bool value) { if (value) { m_property.set(prop); } else { @@ -29,7 +28,7 @@ void pass::PassBase::set_property(const PassPropertyMask& prop, bool value) { } } -std::string pass::PassBase::get_name() const { +std::string ov::pass::PassBase::get_name() const { if (m_name.empty()) { const PassBase* p = this; std::string pass_name = typeid(*p).name(); @@ -43,16 +42,16 @@ std::string pass::PassBase::get_name() const { } } -void pass::PassBase::set_callback(const param_callback& callback) { +void ov::pass::PassBase::set_callback(const param_callback& callback) { m_pass_config->set_callback(callback); } // The symbols are requiered to be in cpp file to workaround RTTI issue on Android LLVM -pass::FunctionPass::~FunctionPass() {} +ov::pass::FunctionPass::~FunctionPass() = default; -NGRAPH_SUPPRESS_DEPRECATED_START +OPENVINO_SUPPRESS_DEPRECATED_START -NGRAPH_RTTI_DEFINITION(ngraph::pass::NodePass, "ngraph::pass::NodePass", 0); +OPENVINO_RTTI_DEFINITION(ngraph::pass::NodePass, "ngraph::pass::NodePass", 0); -pass::NodePass::~NodePass() {} +ngraph::pass::NodePass::~NodePass() = default; diff --git a/ngraph/core/src/pass/pass_config.cpp b/ngraph/core/src/pass/pass_config.cpp index af7aa863801..58e33dba8e8 100644 --- a/ngraph/core/src/pass/pass_config.cpp +++ b/ngraph/core/src/pass/pass_config.cpp @@ -2,11 +2,9 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "ngraph/pass/pass_config.hpp" +#include "openvino/pass/pass_config.hpp" -using namespace ngraph; - -pass::param_callback pass::PassConfig::get_callback(const DiscreteTypeInfo& type_info) const { +ov::pass::param_callback ov::pass::PassConfig::get_callback(const DiscreteTypeInfo& type_info) const { const auto& it = m_callback_map.find(type_info); if (it != m_callback_map.end()) { return it->second; @@ -15,17 +13,17 @@ pass::param_callback pass::PassConfig::get_callback(const DiscreteTypeInfo& type } } -void pass::PassConfig::enable(const ngraph::DiscreteTypeInfo& type_info) { +void ov::pass::PassConfig::enable(const ngraph::DiscreteTypeInfo& type_info) { m_disabled.erase(type_info); m_enabled.insert(type_info); } -void pass::PassConfig::disable(const ngraph::DiscreteTypeInfo& type_info) { +void ov::pass::PassConfig::disable(const ngraph::DiscreteTypeInfo& type_info) { m_enabled.erase(type_info); m_disabled.insert(type_info); } -void pass::PassConfig::add_disabled_passes(const PassConfig& rhs) { +void ov::pass::PassConfig::add_disabled_passes(const PassConfig& rhs) { for (const auto& pass : rhs.m_disabled) { if (is_enabled(pass)) continue; diff --git a/ngraph/core/src/pass/perf_counters.cpp b/ngraph/core/src/pass/perf_counters.cpp index b59e703ac28..87f6ba59d18 100644 --- a/ngraph/core/src/pass/perf_counters.cpp +++ b/ngraph/core/src/pass/perf_counters.cpp @@ -3,7 +3,7 @@ // #include "perf_counters.hpp" -namespace ngraph { +namespace ov { namespace pass { openvino::itt::handle_t PerfCounters::operator[](::ngraph::Node::type_info_t const& type_inf) { std::lock_guard guard(m_mutex); @@ -13,4 +13,4 @@ openvino::itt::handle_t PerfCounters::operator[](::ngraph::Node::type_info_t con return m_counters[&type_inf] = openvino::itt::handle(type_inf.name); } } // namespace pass -} // namespace ngraph +} // namespace ov diff --git a/ngraph/core/src/pass/perf_counters.hpp b/ngraph/core/src/pass/perf_counters.hpp index a359c284a77..b5dad50e235 100644 --- a/ngraph/core/src/pass/perf_counters.hpp +++ b/ngraph/core/src/pass/perf_counters.hpp @@ -7,7 +7,7 @@ #include #include -namespace ngraph { +namespace ov { namespace pass { class PerfCounters { PerfCounters(PerfCounters const&) = delete; @@ -27,4 +27,4 @@ private: counters_map m_counters; }; } // namespace pass -} // namespace ngraph +} // namespace ov diff --git a/ngraph/core/src/pass/validate.cpp b/ngraph/core/src/pass/validate.cpp index 832354bb3f9..0bfb642765a 100644 --- a/ngraph/core/src/pass/validate.cpp +++ b/ngraph/core/src/pass/validate.cpp @@ -2,16 +2,16 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "ngraph/pass/validate.hpp" +#include "openvino/pass/validate.hpp" #include "itt.hpp" #include "ngraph/graph_util.hpp" using namespace ngraph; -NGRAPH_RTTI_DEFINITION(ngraph::pass::Validate, "ngraph::pass::Validate", 0); +OPENVINO_RTTI_DEFINITION(ov::pass::Validate, "ov::pass::Validate", 0); -bool pass::Validate::run_on_function(std::shared_ptr f) { +bool ov::pass::Validate::run_on_function(std::shared_ptr f) { f->validate_nodes_and_infer_types(); return false; } diff --git a/ngraph/core/src/pattern/matcher.cpp b/ngraph/core/src/pattern/matcher.cpp index 9edc0df3309..dc24e44ae54 100644 --- a/ngraph/core/src/pattern/matcher.cpp +++ b/ngraph/core/src/pattern/matcher.cpp @@ -13,7 +13,8 @@ #include "ngraph/op/parameter.hpp" #include "ngraph/op/util/op_types.hpp" -namespace ngraph { +namespace ov { +namespace pass { namespace pattern { MatcherState::MatcherState(Matcher* matcher) : m_matcher(matcher), @@ -88,7 +89,7 @@ bool Matcher::is_contained_match(const NodeVector& exclusions, bool ignore_unuse NGRAPH_SUPPRESS_DEPRECATED_START if (exclusions.empty()) { NodeVector label_exclusions; - for (auto entry : m_pattern_map) { + for (const auto& entry : m_pattern_map) { // leaf label if (entry.first->get_input_size() == 0) { label_exclusions.push_back(entry.second.get_node_shared_ptr()); @@ -108,7 +109,7 @@ bool Matcher::match_value(const ngraph::Output& pattern_value, const ngrap // This env var allows one to specify node name patterns to abort pattern matching // at particular nodes. The upshot is that one can quickly zero in on an offending // fusion by disabling individual fusions or optimizations that use Matcher. - static const std::string node_skip_cregex = getenv_string("NGRAPH_FAIL_MATCH_AT"); + static const std::string node_skip_cregex = ngraph::getenv_string("NGRAPH_FAIL_MATCH_AT"); if (!node_skip_cregex.empty()) { static const std::regex node_skip_regex(node_skip_cregex); if (std::regex_match(graph_node->get_name(), node_skip_regex)) { @@ -201,7 +202,7 @@ void Matcher::clear_state() { namespace { std::set> as_node_set(const std::set>& label_set) { std::set> result; - for (auto label : label_set) { + for (const auto& label : label_set) { result.insert(label); } return result; @@ -230,7 +231,7 @@ bool RecurrentMatcher::match(Output graph) { graph = m.get_pattern_value_map()[m_recurrent_pattern]; // copy bound nodes for the current pattern graph into a global matches map - for (auto cur_match : m.get_pattern_value_map()) { + for (const auto& cur_match : m.get_pattern_value_map()) { m_matches[cur_match.first].push_back(cur_match.second); } @@ -238,7 +239,7 @@ bool RecurrentMatcher::match(Output graph) { // from the current match. Only bound nodes whose labels are in // correlated_patterns are pre-populated. Skip other labels are // unbounded by default - for (auto cor_pat : m_correlated_patterns) { + for (const auto& cor_pat : m_correlated_patterns) { previous_matches[cor_pat] = m.get_pattern_value_map()[cor_pat]; } m = m_repeat; @@ -251,4 +252,5 @@ bool RecurrentMatcher::match(Output graph) { return matched; } } // namespace pattern -} // namespace ngraph +} // namespace pass +} // namespace ov diff --git a/ngraph/core/src/pattern/op/any.cpp b/ngraph/core/src/pattern/op/any.cpp index 58ca9aed446..c2cf96c9efe 100644 --- a/ngraph/core/src/pattern/op/any.cpp +++ b/ngraph/core/src/pattern/op/any.cpp @@ -7,17 +7,16 @@ #include "ngraph/pattern/matcher.hpp" using namespace std; -using namespace ngraph; -constexpr NodeTypeInfo pattern::op::Any::type_info; +constexpr ov::NodeTypeInfo ov::pass::pattern::op::Any::type_info; -const NodeTypeInfo& pattern::op::Any::get_type_info() const { +const ov::NodeTypeInfo& ov::pass::pattern::op::Any::get_type_info() const { return type_info; } -bool pattern::op::Any::match_value(Matcher* matcher, - const Output& pattern_value, - const Output& graph_value) { +bool ov::pass::pattern::op::Any::match_value(Matcher* matcher, + const Output& pattern_value, + const Output& graph_value) { matcher->add_node(graph_value); return m_predicate(graph_value) && matcher->match_arguments(pattern_value.get_node(), graph_value.get_node_shared_ptr()); diff --git a/ngraph/core/src/pattern/op/any_of.cpp b/ngraph/core/src/pattern/op/any_of.cpp index affc3c48c52..6827ca8a69b 100644 --- a/ngraph/core/src/pattern/op/any_of.cpp +++ b/ngraph/core/src/pattern/op/any_of.cpp @@ -7,20 +7,19 @@ #include "ngraph/pattern/matcher.hpp" using namespace std; -using namespace ngraph; -constexpr NodeTypeInfo pattern::op::AnyOf::type_info; +constexpr ov::NodeTypeInfo ov::pass::pattern::op::AnyOf::type_info; -const NodeTypeInfo& pattern::op::AnyOf::get_type_info() const { +const ov::NodeTypeInfo& ov::pass::pattern::op::AnyOf::get_type_info() const { return type_info; } -bool pattern::op::AnyOf::match_value(Matcher* matcher, - const Output& pattern_value, - const Output& graph_value) { +bool ov::pass::pattern::op::AnyOf::match_value(Matcher* matcher, + const Output& pattern_value, + const Output& graph_value) { matcher->add_node(graph_value); return m_predicate(graph_value) && ([&]() { - for (auto arg : graph_value.get_node_shared_ptr()->input_values()) { + for (const auto& arg : graph_value.get_node_shared_ptr()->input_values()) { auto saved = matcher->start_match(); if (matcher->match_value(input_value(0), arg)) { return saved.finish(true); diff --git a/ngraph/core/src/pattern/op/any_output.cpp b/ngraph/core/src/pattern/op/any_output.cpp index 940eb7865fb..d9e36cf9cc0 100644 --- a/ngraph/core/src/pattern/op/any_output.cpp +++ b/ngraph/core/src/pattern/op/any_output.cpp @@ -7,16 +7,15 @@ #include "ngraph/pattern/matcher.hpp" using namespace std; -using namespace ngraph; -constexpr NodeTypeInfo pattern::op::AnyOutput::type_info; +constexpr ov::NodeTypeInfo ov::pass::pattern::op::AnyOutput::type_info; -const NodeTypeInfo& pattern::op::AnyOutput::get_type_info() const { +const ov::NodeTypeInfo& ov::pass::pattern::op::AnyOutput::get_type_info() const { return type_info; } -bool pattern::op::AnyOutput::match_value(Matcher* matcher, - const Output& pattern_value, - const Output& graph_value) { +bool ov::pass::pattern::op::AnyOutput::match_value(Matcher* matcher, + const Output& pattern_value, + const Output& graph_value) { return input_value(0).get_node()->match_node(matcher, graph_value); } diff --git a/ngraph/core/src/pattern/op/label.cpp b/ngraph/core/src/pattern/op/label.cpp index 621e0d4ebe9..025ac805f29 100644 --- a/ngraph/core/src/pattern/op/label.cpp +++ b/ngraph/core/src/pattern/op/label.cpp @@ -9,15 +9,14 @@ #include "ngraph/pattern/op/true.hpp" using namespace std; -using namespace ngraph; -constexpr NodeTypeInfo pattern::op::Label::type_info; +constexpr ov::NodeTypeInfo ov::pass::pattern::op::Label::type_info; -const NodeTypeInfo& pattern::op::Label::get_type_info() const { +const ov::NodeTypeInfo& ov::pass::pattern::op::Label::get_type_info() const { return type_info; } -Output pattern::op::Label::wrap_values(const OutputVector& wrapped_values) { +ov::Output ov::pass::pattern::op::Label::wrap_values(const ov::OutputVector& wrapped_values) { switch (wrapped_values.size()) { case 0: return make_shared()->output(0); @@ -28,9 +27,9 @@ Output pattern::op::Label::wrap_values(const OutputVector& wrapped_values) } } -bool pattern::op::Label::match_value(Matcher* matcher, - const Output& pattern_value, - const Output& graph_value) { +bool ov::pass::pattern::op::Label::match_value(ov::pass::pattern::Matcher* matcher, + const ov::Output& pattern_value, + const ov::Output& graph_value) { if (m_predicate(graph_value)) { auto& pattern_map = matcher->get_pattern_value_map(); auto saved = matcher->start_match(); @@ -45,10 +44,10 @@ bool pattern::op::Label::match_value(Matcher* matcher, return false; } -std::shared_ptr pattern::any_input() { +std::shared_ptr ov::pass::pattern::any_input() { return std::make_shared(); } -std::shared_ptr pattern::any_input(const pattern::op::ValuePredicate& pred) { +std::shared_ptr ov::pass::pattern::any_input(const ov::pass::pattern::op::ValuePredicate& pred) { return std::make_shared(element::dynamic, PartialShape::dynamic(), pred); } diff --git a/ngraph/core/src/pattern/op/pattern.cpp b/ngraph/core/src/pattern/op/pattern.cpp index 0c7c34a4745..f3fed5add9c 100644 --- a/ngraph/core/src/pattern/op/pattern.cpp +++ b/ngraph/core/src/pattern/op/pattern.cpp @@ -7,7 +7,8 @@ #include #include -namespace ngraph { +namespace ov { +namespace pass { namespace pattern { namespace op { // The symbols are required to be in cpp file to workaround RTTI issue on Android LLVM @@ -101,4 +102,5 @@ std::function)> type_matches_any(const std::vector -#include "dyn_elimination.hpp" #include "ngraph/builder/reshape.hpp" #include "ngraph/op/broadcast.hpp" #include "ngraph/op/range.hpp" @@ -19,9 +20,7 @@ NGRAPH_SUPPRESS_DEPRECATED_START using namespace std; using namespace ngraph; -pass::DynElimination::DynElimination() - : GraphRewrite() -{ +pass::DynElimination::DynElimination() : GraphRewrite() { construct_range(); } @@ -29,28 +28,22 @@ template std::shared_ptr make_range_replacement(const element::Type& et, const Shape& shape, const std::shared_ptr& start_arg, - const std::shared_ptr& step_arg) -{ + const std::shared_ptr& step_arg) { std::vector elements(shape_size(shape)); std::vector start_vec = start_arg->get_vector(); std::vector step_vec = step_arg->get_vector(); NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1); - runtime::reference::range( - start_vec.data(), step_vec.data(), shape_size(shape), elements.data()); + runtime::reference::range(start_vec.data(), step_vec.data(), shape_size(shape), elements.data()); return make_shared(et, shape, elements); } -void pass::DynElimination::construct_range() -{ - auto start_arg_label = - make_shared(element::f32, Shape{}, pattern::has_class()); - auto stop_arg_label = - make_shared(element::f32, Shape{}, pattern::has_class()); - auto step_arg_label = - make_shared(element::f32, Shape{}, pattern::has_class()); +void pass::DynElimination::construct_range() { + auto start_arg_label = make_shared(element::f32, Shape{}, pattern::has_class()); + auto stop_arg_label = make_shared(element::f32, Shape{}, pattern::has_class()); + auto step_arg_label = make_shared(element::f32, Shape{}, pattern::has_class()); auto range_pat = make_shared(start_arg_label, stop_arg_label, step_arg_label); @@ -70,12 +63,11 @@ void pass::DynElimination::construct_range() std::shared_ptr replacement; #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8) -#pragma GCC diagnostic push -#pragma GCC diagnostic error "-Wswitch" -#pragma GCC diagnostic error "-Wswitch-enum" +# pragma GCC diagnostic push +# pragma GCC diagnostic error "-Wswitch" +# pragma GCC diagnostic error "-Wswitch-enum" #endif - switch (et) - { + switch (et) { case element::Type_t::bf16: replacement = make_range_replacement(et, shape, start_arg, step_arg); break; @@ -122,7 +114,7 @@ void pass::DynElimination::construct_range() break; } #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8) -#pragma GCC diagnostic pop +# pragma GCC diagnostic pop #endif replace_node(range_node, replacement);