Move pass pattern to ov (#7255)

* Moved ngraph::Node to ov namespace

* Fixed code style

* Fixed VPU

* Fixed GNA

* Fixed tests

* Added aliases for backward compatibility

* Fix clDNN

* Try to fix build

* Fixed comment

* Renamed RTTI macros

* Add new headers

* Fixed ngraph build

* Fixed unit tests

* Try to fix Serialize
This commit is contained in:
Ilya Churaev 2021-09-02 10:03:04 +03:00 committed by GitHub
parent 07f7061f96
commit 9eca6ba9d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
62 changed files with 2028 additions and 1556 deletions

View File

@ -26,7 +26,7 @@ endif()
# resolving dependencies for the project # resolving dependencies for the project
message (STATUS "PROJECT ............................... " ${PROJECT_NAME}) message (STATUS "PROJECT ............................... " ${PROJECT_NAME})
message (STATUS "CMAKE_BINARY_DIR ...................... " ${CMAKE_BINARY_DIR}) message (STATUS "CMAKE_BINARY_DIR ...................... " ${CMAKE_BINARY_DIR})
message (STATUS "OpenVINO_SOURCE_DIR .... .......... " ${OpenVINO_SOURCE_DIR}) message (STATUS "OpenVINO_SOURCE_DIR ................... " ${OpenVINO_SOURCE_DIR})
message (STATUS "CMAKE_GENERATOR ....................... " ${CMAKE_GENERATOR}) message (STATUS "CMAKE_GENERATOR ....................... " ${CMAKE_GENERATOR})
message (STATUS "CMAKE_C_COMPILER_ID ................... " ${CMAKE_C_COMPILER_ID}) message (STATUS "CMAKE_C_COMPILER_ID ................... " ${CMAKE_C_COMPILER_ID})
message (STATUS "CMAKE_BUILD_TYPE ...................... " ${CMAKE_BUILD_TYPE}) message (STATUS "CMAKE_BUILD_TYPE ...................... " ${CMAKE_BUILD_TYPE})

View File

@ -811,8 +811,34 @@ void ngfunction_2_irv10(pugi::xml_node& netXml,
f.validate_nodes_and_infer_types(); f.validate_nodes_and_infer_types();
} }
} }
std::string valid_xml_path(const std::string &path) {
NGRAPH_CHECK(path.length() > 4, "Path for xml file is to short: \"" + path + "\"");
const char *const extension = ".xml";
const bool has_xml_extension = path.rfind(extension) == path.size() - std::strlen(extension);
NGRAPH_CHECK(has_xml_extension,
"Path for xml file doesn't contains file name with 'xml' extension: \"" +
path + "\"");
return path;
}
std::string provide_bin_path(const std::string &xmlPath, const std::string &binPath) {
if (!binPath.empty()) {
return binPath;
}
assert(xmlPath.size() > 4); // should be check by valid_xml_path
std::string bestPath = xmlPath;
const char *const extension = "bin";
const auto ext_size = std::strlen(extension);
bestPath.replace(bestPath.size() - ext_size, ext_size, extension);
return bestPath;
}
} // namespace } // namespace
namespace ngraph {
// ! [function_pass:serialize_cpp] // ! [function_pass:serialize_cpp]
// serialize.cpp // serialize.cpp
bool pass::Serialize::run_on_function(std::shared_ptr<ngraph::Function> f) { bool pass::Serialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
@ -868,33 +894,6 @@ bool pass::Serialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
return false; return false;
} }
namespace {
std::string valid_xml_path(const std::string &path) {
NGRAPH_CHECK(path.length() > 4, "Path for xml file is to short: \"" + path + "\"");
const char *const extension = ".xml";
const bool has_xml_extension = path.rfind(extension) == path.size() - std::strlen(extension);
NGRAPH_CHECK(has_xml_extension,
"Path for xml file doesn't contains file name with 'xml' extension: \"" +
path + "\"");
return path;
}
std::string provide_bin_path(const std::string &xmlPath, const std::string &binPath) {
if (!binPath.empty()) {
return binPath;
}
assert(xmlPath.size() > 4); // should be check by valid_xml_path
std::string bestPath = xmlPath;
const char *const extension = "bin";
const auto ext_size = std::strlen(extension);
bestPath.replace(bestPath.size() - ext_size, ext_size, extension);
return bestPath;
}
} // namespace
pass::Serialize::Serialize(std::ostream& xmlFile, pass::Serialize::Serialize(std::ostream& xmlFile,
std::ostream& binFile, std::ostream& binFile,
pass::Serialize::Version version, pass::Serialize::Version version,
@ -921,3 +920,4 @@ pass::Serialize::Serialize(const std::string& xmlPath,
{ {
} }
// ! [function_pass:serialize_cpp] // ! [function_pass:serialize_cpp]
} // namespace ngraph

View File

@ -5,24 +5,10 @@
#pragma once #pragma once
#include "ngraph/pass/pass.hpp" #include "ngraph/pass/pass.hpp"
#include "openvino/pass/constant_folding.hpp"
namespace ngraph { namespace ngraph {
namespace pass { namespace pass {
/** using ov::pass::ConstantFolding;
* @brief Constant folding iterates over the function and tries to evaluate nodes
* with constant inputs. Such nodes are then replaced with new Constants containing
* the result of a folded operation.
*/
class NGRAPH_API ConstantFolding : public FunctionPass {
public:
NGRAPH_RTTI_DECLARATION;
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
private:
void copy_runtime_info_to_target_inputs(const std::shared_ptr<Node>& node, const Output<Node>& replacement);
/// \brief Folds pre-calculated output tensor values to constants in case lower and
/// upper estimations are equal. Traverses graph backwards starting from the results.
bool pre_calculated_values_folding(const std::shared_ptr<ngraph::Function>& f);
};
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ngraph

View File

@ -4,14 +4,11 @@
#pragma once #pragma once
#include <ngraph/pass/graph_rewrite.hpp> #include "ngraph/pass/graph_rewrite.hpp"
#include "openvino/pass/convert_fp32_to_fp16.hpp"
namespace ngraph { namespace ngraph {
namespace pass { namespace pass {
class NGRAPH_API ConvertFP32ToFP16 : public ngraph::pass::FunctionPass { using ov::pass::ConvertFP32ToFP16;
public:
NGRAPH_RTTI_DECLARATION;
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
};
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ngraph

View File

@ -10,240 +10,17 @@
#include "ngraph/pass/pass.hpp" #include "ngraph/pass/pass.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
#include "openvino/pass/graph_rewrite.hpp"
namespace ngraph { namespace ngraph {
using matcher_pass_callback = std::function<bool(ngraph::pattern::Matcher& m)>; using ov::graph_rewrite_callback;
using graph_rewrite_callback = std::function<bool(ngraph::pattern::Matcher& m)>; using ov::handler_callback;
using recurrent_graph_rewrite_callback = std::function<bool(ngraph::pattern::RecurrentMatcher& m)>; using ov::matcher_pass_callback;
using handler_callback = std::function<bool(const std::shared_ptr<Node>& node)>; using ov::recurrent_graph_rewrite_callback;
namespace pass { namespace pass {
/// \brief MatcherPass is a basic block for pattern based transformations. It describes using ov::pass::BackwardGraphRewrite;
/// pattern and using ov::pass::GraphRewrite;
/// action that is applied if pattern is matched. using ov::pass::MatcherPass;
/// using ov::pass::RecurrentGraphRewrite;
/// MatcherPass consists of Matcher and matcher_pass_callback that needs to be implemented
/// and
/// finally registered by using \sa register_matcher. MatcherPass can be executed on node
/// within
/// \sa apply method. To run matcher pass on Function use GraphRewrite.
/// In addition MatcherPass provides a way for adding new operations into GraphRewrite
/// execution
/// queue. That means that operations that were created inside transformation callback can
/// be added
/// for matching. To register node use \sa register_new_node method. GraphRewrite
/// automatically
/// takes registered nodes and put them to execution queue. If multiple nodes were register
/// make
/// sure that they were registered in topological order.
/// Note: when implementing pattern for Matcher make sure that root node is an operation
/// from opset
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher
/// passes more
/// efficient.
class NGRAPH_API MatcherPass : public ngraph::pass::PassBase {
public:
NGRAPH_RTTI_DECLARATION;
MatcherPass() = default;
MatcherPass(const MatcherPass&) = delete;
MatcherPass& operator=(const MatcherPass&) = delete;
explicit MatcherPass(const std::string& name,
const std::shared_ptr<pattern::Matcher>& m,
const handler_callback& handler,
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE)
: PassBase(),
m_handler(handler),
m_matcher(m) {
set_name(name);
set_property(property, true);
}
bool apply(std::shared_ptr<ngraph::Node> node);
template <typename T, class... Args>
std::shared_ptr<T> register_new_node(Args&&... args) {
auto node = std::make_shared<T>(std::forward<Args>(args)...);
m_new_nodes.push_back(node);
return node;
}
template <typename T>
std::shared_ptr<T> register_new_node(const std::shared_ptr<T>& node) {
m_new_nodes.push_back(node);
return node;
}
const std::vector<std::shared_ptr<ngraph::Node>>& get_new_nodes() {
return m_new_nodes;
}
void clear_new_nodes() {
m_new_nodes.clear();
}
std::shared_ptr<pattern::Matcher> get_matcher() {
return m_matcher;
}
protected:
void register_matcher(const std::shared_ptr<pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback,
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE);
private:
handler_callback m_handler;
std::shared_ptr<pattern::Matcher> m_matcher;
std::vector<std::shared_ptr<ngraph::Node>> m_new_nodes;
};
/// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function
/// in
/// efficient way
///
/// Graph rewrite pass is used for matcher passes execution on Function.
/// To register MatcherPass use \sa add_matcher<T>(args) method where T is a MatcherPass
/// class.
/// As a default algorithm graph rewrite pass traverse Function in topological order and
/// applies
/// registered matcher passes for each node. But if all registered matcher passes have type
/// based
/// root node in Matcher pattern then efficient mechanism is used to execute them.
/// Matcher pattern root is type based if it's operation from opset or
/// pattern::op::WrapType.
/// Note: when implementing pattern for Matcher make sure that root node is an operation
/// from opset
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher
/// passes more
/// efficient.
class NGRAPH_API GraphRewrite : public ngraph::pass::FunctionPass {
public:
NGRAPH_RTTI_DECLARATION;
GraphRewrite() = default;
explicit GraphRewrite(const std::shared_ptr<MatcherPass>& pass) : FunctionPass() {
m_matchers.push_back(pass);
}
/// \brief Register given transformation class type to GraphRewrite execution list
/// All registered transformations will be executed in a single graph traversal.
/// Example below show the basic usage of pass::GraphRewrite
///
/// pass::Manager manager;
/// auto anchor = manager.register_pass<GraphRewrite>();
/// anchor->add_matcher<MatcherPassA>();
/// anchor->add_matcher<MatcherPassB>();
/// anchor->set_name("CommonMatchers");
/// manager.run_passes(f);
///
/// For some purposes transformation can be registered and disabled by default.
///
/// anchor->add_matcher<MatcherPassB, false>();
///
/// \return shared_ptr to the transformation instance
template <typename T,
bool Enabled = true,
class... Args,
typename std::enable_if<std::is_base_of<pass::MatcherPass, T>::value, bool>::type = true>
std::shared_ptr<T> add_matcher(Args&&... args) {
static_assert(std::is_base_of<pass::MatcherPass, T>::value, "pass not derived from MatcherPass");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto pass_config = get_pass_config();
pass->set_pass_config(pass_config);
if (!Enabled && !pass_config->is_enabled<T>()) {
pass_config->disable<T>();
}
m_matchers.push_back(pass);
return pass;
}
/// \brief Register passes from GraphRewrite class that contains sequence of matcher
/// passes registered in its ctor.
/// For example:
///
/// class ngraph::pass::LinFusions: public ngraph::pass::GraphRewrite {
/// public:
/// NGRAPH_RTTI_DECLARATION;
/// Fusions() {
/// add_matcher<ngraph::pass::AddFusion>();
/// add_matcher<ngraph::pass::MulFusion>();
/// }
/// };
///
/// pass::Manager manager;
/// auto anchor = manager.register_pass<GraphRewrite>();
/// anchor->add_matcher<LinFusions>();
/// anchor->add_matcher<OtherFusions>();
/// anchor->set_name("CommonFusions");
/// manager.run_passes(f);
///
/// In this case all matcher passes from LinFusions pass will be united with other
/// registered matchers.
template <typename T,
class... Args,
typename std::enable_if<std::is_base_of<pass::GraphRewrite, T>::value, bool>::type = true>
void add_matcher(Args&&... args) {
static_assert(std::is_base_of<pass::GraphRewrite, T>::value, "pass not derived from GraphRewrite");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto pass_config = get_pass_config();
for (auto& matcher : pass->m_matchers) {
pass->set_pass_config(pass_config);
m_matchers.push_back(matcher);
}
}
NGRAPH_DEPRECATED("Use MatcherPass instead")
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback,
const PassPropertyMask& property);
NGRAPH_DEPRECATED("Use MatcherPass instead")
void add_matcher(const std::shared_ptr<pattern::Matcher>& m, const ngraph::graph_rewrite_callback& callback);
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
protected:
bool apply_matcher_passes(std::shared_ptr<Function> f, std::deque<std::weak_ptr<Node>> nodes_to_run);
bool m_enable_shape_inference = false;
std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
};
class NGRAPH_API BackwardGraphRewrite : public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
BackwardGraphRewrite() = default;
explicit BackwardGraphRewrite(const std::shared_ptr<MatcherPass>& pass) : GraphRewrite(pass) {}
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
class NGRAPH_API RecurrentGraphRewrite : public ngraph::pass::FunctionPass {
public:
RecurrentGraphRewrite(size_t num_iters = 10) : FunctionPass(), m_num_iters(num_iters) {}
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback,
const PassPropertyMask& property);
// TODO: This interface may deprecate after all passes are refactored.
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback);
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
private:
size_t m_num_iters;
std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
};
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ngraph

View File

@ -5,10 +5,12 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/pass/pass.hpp>
#include <vector> #include <vector>
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/pass.hpp"
#include "openvino/pass/low_latency.hpp"
namespace ngraph { namespace ngraph {
namespace pass { namespace pass {
/** /**
@ -46,38 +48,6 @@ public:
LowLatency(); LowLatency();
}; };
/** using ov::pass::LowLatency2;
* @brief The transformation finds all TensorIterator/Loop layers in the network,
* processes all back edges that describe a connection between Result and Parameter
* of the TensorIterator/Loop bodies,and inserts ReadValue and Assign layers at the
* input and output corresponding to this back edge.
* Supported platforms: CPU, GNA.
*
* The example below describes the changes made by the transformation
* [] - TensorIterator body
* () - new layer
* BE - back-edge
*
* before applying the transformation:
* -> input1[BE_1 -> Parameter -> Layers ... -> Result -> BE_1 ]output1->
*
* after applying the transformation:
* ->(ReadValue)-> input1[BE_1 ->Parameter->Layers ...->Result->BE_1]output1 ->(Assign)
* \
* ->...
* After applying the transformation, the resulting network can be inferred
* step by step, the states will store between inferences.
*/
class NGRAPH_API LowLatency2 : public ngraph::pass::FunctionPass {
public:
NGRAPH_RTTI_DECLARATION;
explicit LowLatency2(bool use_const_initializer = true) : m_use_const_initializer(use_const_initializer) {}
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
private:
bool m_use_const_initializer;
};
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ngraph

View File

@ -11,106 +11,10 @@
#include "ngraph/pass/pass.hpp" #include "ngraph/pass/pass.hpp"
#include "ngraph/pass/validate.hpp" #include "ngraph/pass/validate.hpp"
#include "openvino/pass/manager.hpp"
namespace ngraph { namespace ngraph {
namespace pass { namespace pass {
class NGRAPH_API Manager { using ov::pass::Manager;
public:
Manager();
~Manager();
//// \brief Construct Manager with shared PassConfig instance
explicit Manager(std::shared_ptr<PassConfig> pass_config);
/// \brief Register given transformation class type to execution list
/// Example below show the basic usage of pass::Manager
///
/// pass::Manager manager;
/// manager.register_pass<MyTransformation>(/*transformation constructor ars*/);
/// manager.run_passes(f);
///
/// For some purposes transformation can be registered and disabled by default.
///
/// manager.register_pass<MyTransformation, false>();
///
/// \return shared_ptr to the transformation instance
template <typename T, bool Enable = true, class... Args>
std::shared_ptr<T> register_pass(Args&&... args) {
auto rc = push_pass<T>(std::forward<Args>(args)...);
rc->set_pass_config(m_pass_config);
if (m_per_pass_validation) {
push_pass<Validate>();
}
if (!Enable && !m_pass_config->is_enabled<T>()) {
m_pass_config->disable<T>();
}
return rc;
}
void run_passes(std::shared_ptr<Function>);
void set_pass_visualization(bool new_state) {
m_visualize = new_state;
}
/// \brief Set flag to enable/disable running Validate pass after executing
/// each registered pass
/// \param new_state Value "true" enables Validate pass run; "false", otherwise
void set_per_pass_validation(bool new_state) {
m_per_pass_validation = new_state;
}
/// \brief Callback is a lambda function that can be used by registered transformations.
/// The main purpose of this callback is to provide a way for plugins to disable/enable
/// transformations based on some conditions. In some cases plugins may want not to
/// execute some
/// transformations.
/// For example plugin can disable unpleasant decompositions because of performance
/// reasons for
/// some cases.
/// Callback example:
/// auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
/// return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) !=
/// nullptr;
/// };
/// This callback returns true in case of DepthToSpace operation. So when execution
/// DepthToSpace
/// decomposition pass will check is this decomposition needed or plugin can execute
/// this
/// operation directly. And of course on transformation side we need to have a response
/// for this
/// callback.
/// if (transformation_callback(batch_to_space)) {
/// return false;
/// }
/// \param callback lamda function that returns true in case if node is supported by
/// plugin and
/// transformation is not needed
NGRAPH_DEPRECATED("Please use get_pass_config() to configure transformation pipeline")
void set_callback(const param_callback& callback) {
m_pass_config->set_callback(callback);
}
/// \return PassConfig shared object. This object is used for transformations pipeline
/// configuration.
/// This object allows to disable/enable transformations execution, set callback to
/// particular
/// transformation. For mo details see PassConfig class.
std::shared_ptr<PassConfig> get_pass_config() {
return m_pass_config;
}
protected:
template <typename T, class... Args>
std::shared_ptr<T> push_pass(Args&&... args) {
static_assert(std::is_base_of<pass::PassBase, T>::value, "pass not derived from pass base");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto pass_base = std::static_pointer_cast<PassBase>(pass);
m_pass_list.push_back(pass_base);
return pass;
}
std::shared_ptr<PassConfig> m_pass_config;
std::vector<std::shared_ptr<PassBase>> m_pass_list;
bool m_visualize = false;
bool m_per_pass_validation = true;
};
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ngraph

View File

@ -13,105 +13,32 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pass/pass_config.hpp" #include "ngraph/pass/pass_config.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "openvino/pass/pass.hpp"
namespace ov {
namespace pass {
class Manager;
}
} // namespace ov
namespace ngraph { namespace ngraph {
namespace pass { namespace pass {
enum class PassProperty : uint32_t { using ov::pass::FunctionPass;
// Pass requires node shapes to be static using ov::pass::FusionType;
REQUIRE_STATIC_SHAPE = 0x1, using ov::pass::FusionTypeMask;
// Pass transformation will change the function's dynamic state using ov::pass::Manager;
CHANGE_DYNAMIC_STATE = 1 << 1, using ov::pass::PassBase;
}; using ov::pass::PassProperty;
using ov::pass::PassPropertyMask;
typedef EnumMask<PassProperty> PassPropertyMask; NGRAPH_DEPRECATED("This variable is deprecated and will be removed soon.")
const PassPropertyMask all_pass_property_off; const PassPropertyMask all_pass_property_off;
class NGRAPH_API PassBase {
friend class Manager;
public:
PassBase();
virtual ~PassBase() {}
/// Check if this pass has all the pass properties.
bool get_property(const PassPropertyMask& prop_mask) const;
void set_name(const std::string& name) {
m_name = name;
}
std::string get_name() const;
/// \brief Set callback for particular transformation type.
/// This method set global callback. For more details see PassConfig class
/// documentation.
/// \param callback lambda function that takes node and returns bool
void set_callback(const param_callback& callback);
/// \brief Set PassConfig for particular transformation instance
/// \param pass_config is a PassConfig shared_ptr
virtual void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) {
m_pass_config = pass_config;
}
/// \brief Allows to access PassConfig shared instance
/// \return Shared instance of PassConfig class
std::shared_ptr<PassConfig> get_pass_config() {
return m_pass_config;
}
/// \brief Applies callback for given node. By default callback returns false.
/// This method remains here only for backward compatibility and will be removed
/// after all transformations are moved to transformation_callback() method.
/// \return result of callback execution for given node
NGRAPH_DEPRECATED("Please use transformation_callback method instead")
bool m_transformation_callback(const std::shared_ptr<const Node>& node) {
return m_pass_config->get_callback(get_type_info())(node);
}
/// \brief Applies callback for given node. By default callback returns false.
/// \param node which will be used inside callback
/// \return result of callback execution for given node
bool transformation_callback(const std::shared_ptr<const Node>& node) {
return m_pass_config->get_callback(get_type_info())(node);
}
using type_info_t = DiscreteTypeInfo;
virtual const type_info_t& get_type_info() const = 0;
protected:
void set_property(const PassPropertyMask& prop, bool value);
private:
PassPropertyMask m_property;
std::string m_name;
std::shared_ptr<PassConfig> m_pass_config;
};
class NGRAPH_API FunctionPass : public PassBase {
public:
NGRAPH_RTTI_DECLARATION;
virtual ~FunctionPass();
virtual bool run_on_function(std::shared_ptr<ngraph::Function>) = 0;
};
class NGRAPH_DEPRECATED("Use MatcherPass or FunctionPass instead.") NGRAPH_API NodePass : public PassBase { class NGRAPH_DEPRECATED("Use MatcherPass or FunctionPass instead.") NGRAPH_API NodePass : public PassBase {
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
virtual ~NodePass(); ~NodePass() override;
virtual bool run_on_node(std::shared_ptr<ngraph::Node>) = 0; virtual bool run_on_node(std::shared_ptr<ngraph::Node>) = 0;
}; };
class Manager;
enum class FusionType : uint32_t {
//`DIFFERENTIABLE_FUSIONS` produce ops that support autodiff
// i.e. implement `generate_adjoints`
DIFFERENTIABLE_FUSIONS = 0x1,
REGULAR_FUSIONS = 0x2,
//`FOP_FUSIONS` produce ops in the FusedOps category that might
// not be supported by all backends
FOP_FUSIONS = 0x4,
ALL_FUSIONS = 0xFFFFFFFF
};
typedef EnumMask<FusionType> FusionTypeMask;
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ngraph

View File

@ -12,164 +12,12 @@
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "openvino/pass/pass_config.hpp"
namespace ngraph { namespace ngraph {
namespace pass { namespace pass {
using param_callback = std::function<bool(const std::shared_ptr<const ::ngraph::Node>)>; using ov::pass::param_callback;
using param_callback_map = std::map<ngraph::DiscreteTypeInfo, param_callback>; using ov::pass::param_callback_map;
using ov::pass::PassConfig;
/// \brief Class representing a transformations config that is used for disabling/enabling
/// transformations registered inside pass::Manager and also allows to set callback for all
/// transformations or for particular transformation.
///
/// When pass::Manager is created all passes registered inside this manager including nested
/// passes will share the same instance of PassConfig class.
/// To work with this class first you need to get shared instance of this class by calling
/// manager.get_pass_config() method. Then you will be able to disable/enable passes based
/// on transformations type_info. For example:
///
/// pass::Manager manager;
/// manager.register_pass<CommonOptimizations>();
/// auto pass_config = manager.get_pass_config();
/// pass_config->disable<ConvertGELU>(); // this will disable nested pass inside
/// // CommonOptimizations pipeline
/// manager.run_passes(f);
///
/// Sometimes it is needed to call transformation inside other transformation manually. And
/// for that case before running transformation you need manually check that this pass is
/// not disabled and then you need to set current PassConfig instance to this
/// transformation. For example:
///
/// // Inside MatcherPass callback or inside FunctionPass run_on_function() method
/// // you need to call get_pass_config() method to get shared instance of PassConfig
/// auto pass_config = get_pass_config();
///
/// // Before running nested transformation you need to check is it disabled or not
/// if (!pass_config->is_disabled<ConvertGELU>()) {
/// auto pass = ConvertGELU();
/// pass->set_pass_config(pass_config);
/// pass.apply(node);
/// }
///
/// Following this logic inside your transformations you will guaranty that transformations
/// will be executed in a right way.
class NGRAPH_API PassConfig {
public:
/// \brief Disable transformation by its type_info
/// \param type_info Transformation type_info
void disable(const DiscreteTypeInfo& type_info);
/// \brief Disable transformation by its class type (based on type_info)
template <typename T>
void disable() {
NGRAPH_SUPPRESS_DEPRECATED_START
disable(T::type_info);
NGRAPH_SUPPRESS_DEPRECATED_END
}
/// \brief Enable transformation by its type_info
/// \param type_info Transformation type_info
void enable(const DiscreteTypeInfo& type_info);
/// \brief Enable transformation by its class type (based on type_info)
template <typename T>
void enable() {
NGRAPH_SUPPRESS_DEPRECATED_START
enable(T::type_info);
NGRAPH_SUPPRESS_DEPRECATED_END
}
/// \brief Set callback for all kind of transformations
void set_callback(const param_callback& callback) {
m_callback = callback;
}
template <typename... Args>
typename std::enable_if<sizeof...(Args) == 0>::type set_callback(const param_callback& callback) {}
/// \brief Set callback for particular transformation class types
///
/// Example below show how to set callback for one or multiple passes using this method.
///
/// pass_config->set_callback<ngraph::pass::ConvertBatchToSpace,
/// ngraph::pass::ConvertSpaceToBatch>(
/// [](const_node_ptr &node) -> bool {
/// // Disable transformations for cases when input shape rank is not
/// equal to 4
/// const auto input_shape_rank =
/// node->get_output_partial_shape(0).rank().get_length();
/// if (input_shape_rank != 4) {
/// return false;
/// }
/// return true;
/// });
///
/// Note that inside transformations you must provide code that work with this callback.
/// See example below:
///
/// if (transformation_callback(node)) {
/// return false; // exit from transformation
/// }
///
template <typename T, class... Args>
void set_callback(const param_callback& callback) {
m_callback_map[T::type_info] = callback;
set_callback<Args...>(callback);
}
/// \brief Get callback for given transformation type_info
/// \param type_info Transformation type_info
///
/// In case if callback wasn't set for given transformation type then global callback
/// will be returned. But if even global callback wasn't set then default callback will
/// be returned.
param_callback get_callback(const DiscreteTypeInfo& type_info) const;
/// \brief Get callback for given transformation class type
/// \return callback lambda function
template <typename T>
param_callback get_callback() const {
NGRAPH_SUPPRESS_DEPRECATED_START
return get_callback(T::type_info);
NGRAPH_SUPPRESS_DEPRECATED_END
}
/// \brief Check either transformation type is disabled or not
/// \param type_info Transformation type_info
/// \return true if transformation type was disabled and false otherwise
bool is_disabled(const DiscreteTypeInfo& type_info) const {
return m_disabled.count(type_info);
}
/// \brief Check either transformation class type is disabled or not
/// \return true if transformation type was disabled and false otherwise
template <typename T>
bool is_disabled() const {
NGRAPH_SUPPRESS_DEPRECATED_START
return is_disabled(T::type_info);
NGRAPH_SUPPRESS_DEPRECATED_END
}
/// \brief Check either transformation type is force enabled or not
/// \param type_info Transformation type_info
/// \return true if transformation type was force enabled and false otherwise
bool is_enabled(const DiscreteTypeInfo& type_info) const {
return m_enabled.count(type_info);
}
/// \brief Check either transformation class type is force enabled or not
/// \return true if transformation type was force enabled and false otherwise
template <typename T>
bool is_enabled() const {
return is_enabled(T::type_info);
}
void add_disabled_passes(const PassConfig& rhs);
private:
param_callback m_callback = [](const std::shared_ptr<const ::ngraph::Node>&) {
return false;
};
param_callback_map m_callback_map;
std::unordered_set<DiscreteTypeInfo> m_disabled;
std::unordered_set<DiscreteTypeInfo> m_enabled;
};
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ngraph

View File

@ -5,27 +5,10 @@
#pragma once #pragma once
#include "ngraph/pass/pass.hpp" #include "ngraph/pass/pass.hpp"
#include "openvino/pass/validate.hpp"
namespace ngraph { namespace ngraph {
namespace pass { namespace pass {
/// \brief The Validate pass performs sanity checks on attributes and inputs, and using ov::pass::Validate;
/// computes output shapes and element types for all computation nodes in a given
/// computation graph.
///
/// \details The verification and inference is done via invoking each node's specific
/// implementation of \link ngraph::Node::validate_and_infer_types() \endlink function.
///
/// By default, the \ref ngraph::pass::Manager runs this pass after executing every
/// optimization pass. This is to ensure that any update to the graph by an optimization
/// pass does not break the shape and data type requirement on a computation node.
/// This default validation run can be changed via calling the
/// \link ngraph::pass::Manager::set_per_pass_validation(bool) \endlink function.
class NGRAPH_API Validate : public FunctionPass {
public:
NGRAPH_RTTI_DECLARATION;
Validate() : FunctionPass() {}
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ngraph

View File

@ -14,44 +14,10 @@
#include <utility> #include <utility>
#include "ngraph/pass/pass.hpp" #include "ngraph/pass/pass.hpp"
#include "openvino/pass/visualize_tree.hpp"
class HeightMap;
using visualize_tree_ops_map_t =
std::unordered_map<ngraph::Node::type_info_t, std::function<void(const ngraph::Node&, std::ostream& ss)>>;
namespace ngraph { namespace ngraph {
namespace pass { namespace pass {
class NGRAPH_API VisualizeTree : public FunctionPass { using ov::pass::VisualizeTree;
public:
NGRAPH_RTTI_DECLARATION;
using node_modifiers_t = std::function<void(const Node& node, std::vector<std::string>& attributes)>;
VisualizeTree(const std::string& file_name, node_modifiers_t nm = nullptr, bool dot_only = false);
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
void set_ops_to_details(const visualize_tree_ops_map_t& ops_map) {
m_ops_to_details = ops_map;
}
protected:
void add_node_arguments(std::shared_ptr<Node> node,
std::unordered_map<Node*, HeightMap>& height_maps,
size_t& fake_node_ctr);
std::string add_attributes(std::shared_ptr<Node> node);
virtual std::string get_attributes(std::shared_ptr<Node> node);
virtual std::string get_node_name(std::shared_ptr<Node> node);
std::string get_constant_value(std::shared_ptr<Node> node, size_t max_elements = 7);
void render() const;
std::stringstream m_ss;
std::string m_name;
std::set<std::shared_ptr<Node>> m_nodes_with_attributes;
visualize_tree_ops_map_t m_ops_to_details;
node_modifiers_t m_node_modifiers = nullptr;
bool m_dot_only;
static const int max_jump_distance;
};
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ngraph

View File

@ -16,255 +16,21 @@
#include "ngraph/pattern/op/any_output.hpp" #include "ngraph/pattern/op/any_output.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp" #include "ngraph/pattern/op/skip.hpp"
#include "openvino/pass/pattern/matcher.hpp"
namespace ngraph { namespace ov {
namespace pass { namespace pass {
class GraphRewrite; class GraphRewrite;
} }
} // namespace ov
namespace ngraph {
namespace pass {
using ov::pass::GraphRewrite;
}
namespace pattern { namespace pattern {
class Matcher; using ov::pass::pattern::Matcher;
using ov::pass::pattern::MatcherState;
class NGRAPH_API MatcherState { using ov::pass::pattern::RecurrentMatcher;
public:
MatcherState(Matcher*);
bool finish(bool is_successful);
~MatcherState();
protected:
Matcher* m_matcher;
PatternValueMap m_pattern_value_map;
PatternValueMaps m_pattern_value_maps;
size_t m_watermark;
size_t m_capture_size;
bool m_restore{true};
};
/// Matcher looks for node patterns in a computation graph. The patterns are described by an
/// automaton that is described by an extended computation graph. The matcher executes
/// by attempting to match the start node of the pattern to a computation graph value
/// (output of a Node). In addition to determing if a match occurs, a pattern node may add
/// graph nodes to a list of matched nodes, associate nodes with graph values, and start
/// submatches. Submatches add match state changes to the enclosing match if the submatch
/// succeeds; otherwise the state is reverted.
///
/// The default match behavior of a pattern node with a graph nodes is that the computation
/// graph value is added to the end of the matched value list and the match succeeds if the
/// node/pattern types match and the input values match. In the case of a commutative node,
/// the inputs can match in any order. If the matcher is in strict mode, the graph value
/// element type and shape must also match.
///
/// Pattern nodes that have different match behavior are in ngraph::pattern::op and have
/// descriptions of their match behavior.
class NGRAPH_API Matcher {
public:
using PatternMap = ngraph::pattern::PatternMap;
// Avoid implicit string construction from nullptr.
Matcher(const std::shared_ptr<Node> pattern_node, std::nullptr_t name) = delete;
Matcher() {}
Matcher(Output<Node>& pattern_node) : m_pattern_node{pattern_node} {}
Matcher(Output<Node>& pattern_node, const std::string& name) : m_pattern_node(pattern_node), m_name{name} {}
/// \brief Constructs a Matcher object
///
/// \param pattern_node is a pattern sub graph that will be matched against input graphs
/// \param name is a string which is used for logging and disabling a matcher
/// \param strict_mode forces a matcher to consider shapes and ET of nodes
Matcher(const Output<Node>& pattern_node, const std::string& name, bool strict_mode)
: m_pattern_node(pattern_node),
m_name(name),
m_strict_mode(strict_mode) {}
// Some matches should start on a node rather than an output. These three constructors
// are transition until we work out the right way to do that.
Matcher(std::shared_ptr<Node> pattern_node);
Matcher(std::shared_ptr<Node> pattern_node, const std::string& name);
Matcher(std::shared_ptr<Node> pattern_node, const std::string& name, bool strict_mode);
virtual ~Matcher() {}
/// \brief Matches a pattern to \p graph_node
///
/// \param graph_value is an input graph to be matched against
bool match(const Output<Node>& graph_value);
bool match(std::shared_ptr<Node> graph_node);
/// \brief Matches a pattern to \p graph_node
///
/// \param graph_value is an input graph to be matched against
/// \param previous_matches contains previous mappings from labels to nodes to use
bool match(const Output<Node>& graph_value, const PatternMap& previous_matches);
bool match(const Output<Node>& graph_value, const PatternValueMap& previous_matches);
template <typename T>
static std::shared_ptr<T> unique_match(std::shared_ptr<Node> node) {
std::shared_ptr<T> matched;
for (auto arg : node->input_values()) {
if (auto t_casted = ov::as_type_ptr<T>(arg.get_node_shared_ptr())) {
if (matched) {
throw ngraph_error("There's more than two arguments of the same type");
} else {
matched = t_casted;
}
}
}
return matched;
}
bool is_contained_match(const NodeVector& exclusions = {}, bool ignore_unused = true);
const NodeVector get_matched_nodes() {
return as_node_vector(m_matched_list);
}
const OutputVector& get_matched_values() const {
return m_matched_list;
}
OutputVector& get_matched_values() {
return m_matched_list;
}
void reset() {}
const std::string& get_name() {
return m_name;
}
std::shared_ptr<Node> get_pattern() {
return m_pattern_node.get_node_shared_ptr();
}
Output<Node> get_pattern_value() {
return m_pattern_node;
}
std::shared_ptr<Node> get_match_root();
Output<Node> get_match_value();
PatternMap get_pattern_map() const;
PatternValueMap& get_pattern_value_map() {
return m_pattern_map;
}
PatternValueMaps& get_pattern_value_maps() {
return m_pattern_value_maps;
}
/// \brief Low-level helper to match recurring patterns
///
/// \param graph is a graph to be matched against
/// \param pattern is a recurring pattern
/// \param rpattern specifies a node to recur from next
/// \param patterns a map from labels to matches
size_t add_node(Output<Node> node);
bool virtual match_value(const ngraph::Output<Node>& pattern_value, const ngraph::Output<Node>& graph_value);
bool is_strict_mode() {
return m_strict_mode;
}
virtual bool match_arguments(Node* pattern_node, const std::shared_ptr<Node>& graph_node);
void capture(const std::set<Node*>& static_nodes);
void clear_state();
size_t get_number_of_recurrent_matches() const {
return m_pattern_value_maps.size();
}
NodeVector get_bound_nodes_for_pattern(const Output<Node>& pattern) const;
size_t get_number_of_bound_labels() const;
/// \brief Try a match
MatcherState start_match();
Output<Node> m_match_root;
Output<Node> m_pattern_node;
PatternValueMap m_pattern_map;
PatternValueMaps m_pattern_value_maps;
OutputVector m_matched_list;
protected:
bool match_permutation(const OutputVector& pattern_args, const OutputVector& args);
std::string m_name{"unnamed"};
bool m_strict_mode{false};
};
class NGRAPH_API RecurrentMatcher {
public:
/// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
/// repeating patterns (e.g. RNN, LSTM, GRU cells)
///
/// \param initial_pattern is a pattern sub graph describing the initial cell
/// \param pattern is a pattern sub graph describing an individual cell
/// \param rpattern is a (recurring) label to denote which node the next match should
/// start at
/// \param correlated_patterns is a set of labels whose bound nodes must remain the same
/// across all cells
RecurrentMatcher(const Output<Node>& initial_pattern,
const Output<Node>& pattern,
const std::shared_ptr<Node>& rpattern,
const std::set<std::shared_ptr<Node>>& correlated_patterns)
: m_initial_pattern(initial_pattern),
m_pattern(pattern),
m_recurrent_pattern(rpattern),
m_correlated_patterns(correlated_patterns) {}
/// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
/// repeating patterns (e.g. RNN, LSTM, GRU cells)
///
/// \param pattern is a pattern sub graph describing an individual cell
/// \param rpattern is a (recurring) label to denote which node the next match should
/// start at
/// \param correlated_patterns is a set of labels whose bound nodes must remain the same
/// across all cells
RecurrentMatcher(const Output<Node>& pattern,
const std::shared_ptr<Node>& rpattern,
const std::set<std::shared_ptr<Node>>& correlated_patterns)
: RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns) {}
RecurrentMatcher(const Output<Node>& initial_pattern,
const Output<Node>& pattern,
const std::shared_ptr<Node>& rpattern,
const std::set<std::shared_ptr<op::Label>>& correlated_patterns);
RecurrentMatcher(const Output<Node>& pattern,
const std::shared_ptr<Node>& rpattern,
const std::set<std::shared_ptr<op::Label>>& correlated_patterns)
: RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns) {}
/// \brief Returns a vector of bound nodes for a given label (used in a pattern
/// describing an individual cell
NodeVector get_bound_nodes_for_pattern(const std::shared_ptr<Node>& pattern) const {
if (m_matches.count(pattern) == 0) {
throw ngraph_error("No bound nodes for a given label");
}
return as_node_vector(m_matches.at(pattern));
}
size_t get_number_of_recurrent_matches() const {
if (m_matches.size() == 0) {
return 0;
}
return (*m_matches.begin()).second.size();
}
size_t get_number_of_bound_labels() const {
return m_matches.size();
}
/// \brief Tries to match a pattern for an individual cell to a given \p graph
bool match(Output<Node> graph);
std::shared_ptr<Node> get_match_root() {
return m_match_root.get_node_shared_ptr();
}
Output<Node> get_match_value() {
return m_match_root;
}
private:
Output<Node> m_initial_pattern;
Output<Node> m_pattern;
std::shared_ptr<Node> m_recurrent_pattern;
const std::set<std::shared_ptr<Node>> m_correlated_patterns;
RPatternValueMap m_matches;
Output<Node> m_match_root;
};
} // namespace pattern } // namespace pattern
} // namespace ngraph } // namespace ngraph

View File

@ -6,38 +6,12 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp" #include "ngraph/pattern/op/pattern.hpp"
#include "openvino/pass/pattern/op/any.hpp"
namespace ngraph { namespace ngraph {
namespace pattern { namespace pattern {
namespace op { namespace op {
/// The graph value is to the matched value list. If the predicate is true for the node using ov::pass::pattern::op::Any;
/// and the arguments match, the match succeeds.
class NGRAPH_API Any : public Pattern {
public:
static constexpr NodeTypeInfo type_info{"patternAny", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief creates a Any node containing a sub-pattern described by \sa type and \sa
/// shape.
Any(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values)
: Pattern(wrapped_values, pred) {
set_output_type(0, type, s);
}
Any(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values)
: Any(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
/// \brief creates a Any node containing a sub-pattern described by the type and
/// shape of \sa node.
Any(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
: Any(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
Any(const Output<Node>& node, NodePredicate pred, const NodeVector& wrapped_values)
: Any(node.get_element_type(),
node.get_partial_shape(),
as_value_predicate(pred),
as_output_vector(wrapped_values)) {}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
} // namespace op } // namespace op
} // namespace pattern } // namespace pattern
} // namespace ngraph } // namespace ngraph

View File

@ -6,47 +6,12 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp" #include "ngraph/pattern/op/pattern.hpp"
#include "openvino/pass/pattern/op/any_of.hpp"
namespace ngraph { namespace ngraph {
namespace pattern { namespace pattern {
namespace op { namespace op {
/// The graph value is added to the matched values list. If the predicate is true for using ov::pass::pattern::op::AnyOf;
/// the
/// graph node, a submatch is performed on the input of AnyOf and each input of the
/// graph node. The first match that succeeds results in a successful match. Otherwise
/// the match fails.
///
/// AnyOf may be given a type and shape for use in strict mode.
class NGRAPH_API AnyOf : public Pattern {
public:
static constexpr NodeTypeInfo type_info{"patternAnyOf", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief creates a AnyOf node containing a sub-pattern described by \sa type and
/// \sa shape.
AnyOf(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values)
: Pattern(wrapped_values, pred) {
if (wrapped_values.size() != 1) {
throw ngraph_error("AnyOf expects exactly one argument");
}
set_output_type(0, type, s);
}
AnyOf(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values)
: AnyOf(
type,
s,
[pred](const Output<Node>& value) {
return pred(value.get_node_shared_ptr());
},
as_output_vector(wrapped_values)) {}
/// \brief creates a AnyOf node containing a sub-pattern described by the type and
/// shape of \sa node.
AnyOf(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
: AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
AnyOf(std::shared_ptr<Node> node, NodePredicate pred, const NodeVector& wrapped_values)
: AnyOf(node, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
};
} // namespace op } // namespace op
} // namespace pattern } // namespace pattern
} // namespace ngraph } // namespace ngraph

View File

@ -6,23 +6,12 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp" #include "ngraph/pattern/op/pattern.hpp"
#include "openvino/pass/pattern/op/any_output.hpp"
namespace ngraph { namespace ngraph {
namespace pattern { namespace pattern {
namespace op { namespace op {
/// Matches any output of a node using ov::pass::pattern::op::AnyOutput;
class NGRAPH_API AnyOutput : public Pattern {
public:
static constexpr NodeTypeInfo type_info{"patternAnyOutput", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief creates an AnyOutput node matching any output of a node
/// \param node The node to match
AnyOutput(const std::shared_ptr<Node>& pattern) : Pattern({pattern->output(0)}) {}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
} // namespace op } // namespace op
} // namespace pattern } // namespace pattern
} // namespace ngraph } // namespace ngraph

View File

@ -6,48 +6,12 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp" #include "ngraph/pattern/op/pattern.hpp"
#include "openvino/pass/pattern/op/branch.hpp"
namespace ngraph { namespace ngraph {
namespace pattern { namespace pattern {
namespace op { namespace op {
/// A branch adds a loop to the pattern. The branch match is successful if the using ov::pass::pattern::op::Branch;
/// destination node pattern matches the graph value. The destination node is a node in
/// the pattern graph that will not have been created some time after the Branch node is
/// created; use set_destination to add it.
///
/// The branch destination is not stored as a shared pointer to prevent reference
/// cycles. Thus the destination node must be referenced in some other way to prevent it
/// from being deleted.
class NGRAPH_API Branch : public Pattern {
public:
static constexpr NodeTypeInfo type_info{"patternBranch", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief Creates a Branch pattern
/// \param pattern the destinationing pattern
/// \param labels Labels where the destination may occur
Branch() : Pattern(OutputVector{}) {
set_output_type(0, element::f32, Shape{});
}
void set_destination(const Output<Node>& destination) {
m_destination_node = destination.get_node();
m_destination_index = destination.get_index();
}
Output<Node> get_destination() const {
return m_destination_node == nullptr
? Output<Node>()
: Output<Node>{m_destination_node->shared_from_this(), m_destination_index};
}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
protected:
Node* m_destination_node{nullptr};
size_t m_destination_index{0};
};
} // namespace op } // namespace op
} // namespace pattern } // namespace pattern
} // namespace ngraph } // namespace ngraph

View File

@ -6,37 +6,12 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp" #include "ngraph/pattern/op/pattern.hpp"
#include "openvino/pass/pattern/op/capture.hpp"
namespace ngraph { namespace ngraph {
namespace pattern { namespace pattern {
namespace op { namespace op {
/// Experimental for support of recurrent matches. using ov::pass::pattern::op::Capture;
///
/// Capture adds the pattern value map to a list of pattern value maps and resets
/// matches for pattern nodes not in the static node list. The match always succeeds.
class NGRAPH_API Capture : public Pattern {
public:
static constexpr NodeTypeInfo type_info{"patternCapture", 0};
const NodeTypeInfo& get_type_info() const override;
Capture(const Output<Node>& arg) : Pattern({arg}) {
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
}
/// \brief static nodes are retained after a capture. All other nodes are dropped
std::set<Node*> get_static_nodes() {
return m_static_nodes;
}
void set_static_nodes(const std::set<Node*>& static_nodes) {
m_static_nodes = static_nodes;
}
virtual bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
protected:
std::set<Node*> m_static_nodes;
};
} // namespace op } // namespace op
} // namespace pattern } // namespace pattern
} // namespace ngraph } // namespace ngraph

View File

@ -6,106 +6,14 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp" #include "ngraph/pattern/op/pattern.hpp"
#include "openvino/pass/pattern/op/label.hpp"
namespace ngraph { namespace ngraph {
namespace pattern { namespace pattern {
namespace op { namespace op {
/// Fails if the predicate returns false on the graph value. using ov::pass::pattern::op::Label;
///
/// The graph value is added to the matched values list. If the Label is already
/// associated with a value, the match succeeds if the value is the same as the graph
/// value. Otherwise, the label is associated with the graph value and the match
/// succeeds if the pattern input matches the graph value.
///
/// DEPRECATED: If no inputs are given to Label, a True node is serves as the input. If
/// more than one inputs are given, an Or pattern of the inputs serves as the input.
class NGRAPH_API Label : public Pattern {
public:
static constexpr NodeTypeInfo type_info{"patternLabel", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief creates a Label node containing a sub-pattern described by \sa type and
/// \sa shape.
///
/// this Label node can be bound only to the nodes in the input graph
/// that match the pattern specified by \sa wrapped_nodes
/// Example:
/// \code{.cpp}
/// auto add = a + b; // a and b are op::Parameter in this example
/// auto label = std::make_shared<pattern::op::Label>(element::f32,
/// Shape{2,2},
/// nullptr,
/// OutputVector{add});
/// \endcode
Label(const element::Type& type,
const PartialShape& s,
const ValuePredicate pred,
const OutputVector& wrapped_values)
: Pattern(OutputVector{wrap_values(wrapped_values)}, pred) {
set_output_type(0, type, s);
}
explicit Label(const element::Type& type = element::dynamic, const PartialShape& s = PartialShape::dynamic())
: Label(
type,
s,
[](const Output<Node>&) {
return true;
},
OutputVector()) {}
Label(const element::Type& type, const PartialShape& s, ValuePredicate pred)
: Label(type, s, pred, OutputVector{}) {}
Label(const element::Type& type, const PartialShape& s, NodePredicate pred)
: Label(type, s, as_value_predicate(pred), OutputVector{}) {}
Label(const element::Type& type, const PartialShape& s, const NodePredicate pred, const NodeVector& wrapped_values)
: Label(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
/// \brief creates a Label node containing a sub-pattern described by the type and
/// shape of \sa node.
///
/// this Label node can be bound only to the nodes in the input graph
/// that match the pattern specified by \sa wrapped_values
/// Example:
/// \code{.cpp}
/// auto add = a + b; // a and b are op::Parameter in this example
/// auto label = std::make_shared<pattern::op::Label>(add,
/// nullptr,
/// OutputVector{add});
/// \endcode
Label(const Output<Node>& value, const ValuePredicate pred, const OutputVector& wrapped_values)
: Label(value.get_element_type(), value.get_partial_shape(), pred, wrapped_values) {}
Label(const Output<Node>& value, const ValuePredicate pred)
: Label(value.get_element_type(), value.get_partial_shape(), pred, OutputVector{}) {}
Label(const Output<Node>& value, const NodePredicate pred)
: Label(value.get_element_type(), value.get_partial_shape(), as_value_predicate(pred), OutputVector{}) {}
Label(const Output<Node>& value)
: Label(
value.get_element_type(),
value.get_partial_shape(),
[](const Output<Node>&) {
return true;
},
OutputVector{}) {}
Label(const Output<Node>& node, const NodePredicate pred, const NodeVector& wrapped_values)
: Label(node.get_element_type(),
node.get_partial_shape(),
as_value_predicate(pred),
as_output_vector(wrapped_values)) {}
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
protected:
static Output<Node> wrap_values(const OutputVector& wrapped_values);
};
} // namespace op } // namespace op
NGRAPH_API using ov::pass::pattern::any_input;
std::shared_ptr<Node> any_input();
NGRAPH_API
std::shared_ptr<Node> any_input(const pattern::op::ValuePredicate& pred);
} // namespace pattern } // namespace pattern
} // namespace ngraph } // namespace ngraph

View File

@ -6,25 +6,12 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp" #include "ngraph/pattern/op/pattern.hpp"
#include "openvino/pass/pattern/op/or.hpp"
namespace ngraph { namespace ngraph {
namespace pattern { namespace pattern {
namespace op { namespace op {
/// A submatch on the graph value is performed on each input to the Or; the match using ov::pass::pattern::op::Or;
/// succeeds on the first match. Otherwise the match fails.
class NGRAPH_API Or : public Pattern {
public:
static constexpr NodeTypeInfo type_info{"patternOr", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief creates an Or node matching one of several sub-patterns in order. Does
/// not add node to match list.
/// \param patterns The patterns to try for matching
Or(const OutputVector& patterns) : Pattern(patterns) {}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
} // namespace op } // namespace op
} // namespace pattern } // namespace pattern
} // namespace ngraph } // namespace ngraph

View File

@ -7,8 +7,10 @@
#include <functional> #include <functional>
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
namespace ngraph { namespace ov {
namespace pass {
namespace pattern { namespace pattern {
namespace op { namespace op {
class Label; class Label;
@ -16,79 +18,42 @@ class Label;
class Matcher; class Matcher;
class MatchState; class MatchState;
} // namespace pattern
using RPatternValueMap = std::map<std::shared_ptr<Node>, OutputVector>; } // namespace pass
using PatternValueMap = std::map<std::shared_ptr<Node>, Output<Node>>; } // namespace ov
using PatternValueMaps = std::vector<PatternValueMap>; namespace ngraph {
namespace pattern {
using PatternMap = std::map<std::shared_ptr<Node>, std::shared_ptr<Node>>; namespace op {
using ov::pass::pattern::op::Label;
PatternMap as_pattern_map(const PatternValueMap& pattern_value_map);
PatternValueMap as_pattern_value_map(const PatternMap& pattern_map);
template <typename T>
std::function<bool(std::shared_ptr<Node>)> has_class() {
auto pred = [](std::shared_ptr<Node> node) -> bool {
return ov::is_type<T>(node);
};
return pred;
} }
NGRAPH_API using ov::pass::pattern::Matcher;
std::function<bool(Output<Node>)> consumers_count(size_t n); using ov::pass::pattern::MatcherState;
NGRAPH_API using ov::pass::pattern::PatternValueMap;
std::function<bool(Output<Node>)> has_static_dim(size_t pos); using ov::pass::pattern::PatternValueMaps;
using ov::pass::pattern::RPatternValueMap;
NGRAPH_API using ov::pass::pattern::PatternMap;
std::function<bool(Output<Node>)> has_static_dims(const std::vector<size_t>& dims);
NGRAPH_API using ov::pass::pattern::as_pattern_map;
std::function<bool(Output<Node>)> has_static_shape(); using ov::pass::pattern::as_pattern_value_map;
using ov::pass::pattern::consumers_count;
NGRAPH_API using ov::pass::pattern::has_class;
std::function<bool(Output<Node>)> has_static_rank(); using ov::pass::pattern::has_static_dim;
using ov::pass::pattern::has_static_dims;
NGRAPH_API using ov::pass::pattern::has_static_rank;
std::function<bool(Output<Node>)> rank_equals(const Dimension& expected_rank); using ov::pass::pattern::has_static_shape;
using ov::pass::pattern::rank_equals;
NGRAPH_API using ov::pass::pattern::type_matches;
std::function<bool(Output<Node>)> type_matches(const element::Type& type); using ov::pass::pattern::type_matches_any;
NGRAPH_API
std::function<bool(Output<Node>)> type_matches_any(const std::vector<element::Type>& types);
namespace op { namespace op {
using NodePredicate = std::function<bool(std::shared_ptr<Node>)>; using ov::pass::pattern::op::NodePredicate;
using ValuePredicate = std::function<bool(const Output<Node>& value)>; using ov::pass::pattern::op::ValuePredicate;
NGRAPH_API using ov::pass::pattern::op::as_value_predicate;
ValuePredicate as_value_predicate(NodePredicate pred); using ov::pass::pattern::op::Pattern;
class NGRAPH_API Pattern : public Node {
public:
/// \brief \p a base class for \sa Skip and \sa Label
///
Pattern(const OutputVector& patterns, ValuePredicate pred) : Node(patterns), m_predicate(pred) {
if (!m_predicate) {
m_predicate = [](const Output<Node>&) {
return true;
};
}
}
Pattern(const OutputVector& patterns) : Pattern(patterns, nullptr) {}
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& /* new_args */) const override {
throw ngraph_error("Uncopyable");
}
ValuePredicate get_predicate() const;
protected:
ValuePredicate m_predicate;
};
} // namespace op } // namespace op
} // namespace pattern } // namespace pattern
} // namespace ngraph } // namespace ngraph

View File

@ -6,37 +6,12 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp" #include "ngraph/pattern/op/pattern.hpp"
#include "openvino/pass/pattern/op/skip.hpp"
namespace ngraph { namespace ngraph {
namespace pattern { namespace pattern {
namespace op { namespace op {
/// The graph value is added to the matched value list. If the predicate is true, the using ov::pass::pattern::op::Skip;
/// match succeeds if the arguments match; if the predicate is false, the match succeeds
/// if the pattern input matches the graph value.
class NGRAPH_API Skip : public Pattern {
public:
static constexpr NodeTypeInfo type_info{"patternSkip", 0};
const NodeTypeInfo& get_type_info() const override;
Skip(const Output<Node>& arg, ValuePredicate pred) : Pattern({arg}, pred) {
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
}
Skip(const Output<Node>& arg, NodePredicate pred = nullptr) : Pattern({arg}, as_value_predicate(pred)) {
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
}
Skip(const OutputVector& args, ValuePredicate pred) : Pattern(args, pred) {
set_output_type(0, args.at(0).get_element_type(), args.at(0).get_partial_shape());
}
Skip(const OutputVector& args, NodePredicate pred = nullptr) : Pattern(args, as_value_predicate(pred)) {
set_output_type(0, args.at(0).get_element_type(), args.at(0).get_partial_shape());
}
virtual bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
} // namespace op } // namespace op
} // namespace pattern } // namespace pattern
} // namespace ngraph } // namespace ngraph

View File

@ -6,21 +6,12 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp" #include "ngraph/pattern/op/pattern.hpp"
#include "openvino/pass/pattern/op/true.hpp"
namespace ngraph { namespace ngraph {
namespace pattern { namespace pattern {
namespace op { namespace op {
/// \brief The match always succeeds. using ov::pass::pattern::op::True;
class NGRAPH_API True : public Pattern {
public:
static constexpr NodeTypeInfo type_info{"patternTrue", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief Always matches, does not add node to match list.
True() : Pattern(OutputVector{}) {}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
} // namespace op } // namespace op
} // namespace pattern } // namespace pattern
} // namespace ngraph } // namespace ngraph

View File

@ -6,68 +6,14 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp" #include "ngraph/pattern/op/pattern.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
namespace ngraph { namespace ngraph {
namespace pattern { namespace pattern {
namespace op { namespace op {
class NGRAPH_API WrapType : public Pattern { using ov::pass::pattern::op::WrapType;
public:
static constexpr NodeTypeInfo type_info{"patternAnyType", 0};
const NodeTypeInfo& get_type_info() const override;
explicit WrapType(
NodeTypeInfo wrapped_type,
const ValuePredicate& pred =
[](const Output<Node>& output) {
return true;
},
const OutputVector& input_values = {})
: Pattern(input_values, pred),
m_wrapped_types({wrapped_type}) {
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
}
explicit WrapType(
std::vector<NodeTypeInfo> wrapped_types,
const ValuePredicate& pred =
[](const Output<Node>& output) {
return true;
},
const OutputVector& input_values = {})
: Pattern(input_values, pred),
m_wrapped_types(std::move(wrapped_types)) {
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
NodeTypeInfo get_wrapped_type() const;
const std::vector<NodeTypeInfo>& get_wrapped_types() const;
private:
std::vector<NodeTypeInfo> m_wrapped_types;
};
} // namespace op } // namespace op
template <class... Args> using ov::pass::pattern::wrap_type;
std::shared_ptr<Node> wrap_type(const OutputVector& inputs, const pattern::op::ValuePredicate& pred) {
std::vector<DiscreteTypeInfo> info{Args::type_info...};
return std::make_shared<op::WrapType>(info, pred, inputs);
}
template <class... Args>
std::shared_ptr<Node> wrap_type(const OutputVector& inputs = {}) {
return wrap_type<Args...>(inputs, [](const Output<Node>& output) {
return true;
});
}
template <class... Args>
std::shared_ptr<Node> wrap_type(const pattern::op::ValuePredicate& pred) {
return wrap_type<Args...>({}, pred);
}
} // namespace pattern } // namespace pattern
} // namespace ngraph } // namespace ngraph

View File

@ -50,12 +50,14 @@ class Result;
} // namespace v0 } // namespace v0
} // namespace op } // namespace op
namespace pattern {
class Matcher;
} // namespace pattern
} // namespace ngraph } // namespace ngraph
namespace ov { namespace ov {
namespace pass {
namespace pattern {
class Matcher;
} // namespace pattern
} // namespace pass
using HostTensor = ngraph::runtime::HostTensor; using HostTensor = ngraph::runtime::HostTensor;
using HostTensorPtr = std::shared_ptr<HostTensor>; using HostTensorPtr = std::shared_ptr<HostTensor>;
using HostTensorVector = std::vector<HostTensorPtr>; using HostTensorVector = std::vector<HostTensorPtr>;
@ -487,11 +489,11 @@ public:
} }
OPENVINO_SUPPRESS_DEPRECATED_END OPENVINO_SUPPRESS_DEPRECATED_END
virtual bool match_value(ngraph::pattern::Matcher* matcher, virtual bool match_value(ov::pass::pattern::Matcher* matcher,
const Output<Node>& pattern_value, const Output<Node>& pattern_value,
const Output<Node>& graph_value); const Output<Node>& graph_value);
virtual bool match_node(ngraph::pattern::Matcher* matcher, const Output<Node>& graph_value); virtual bool match_node(ov::pass::pattern::Matcher* matcher, const Output<Node>& graph_value);
private: private:
descriptor::Input& get_input_descriptor(size_t position); descriptor::Input& get_input_descriptor(size_t position);

View File

@ -0,0 +1,28 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/pass.hpp"
namespace ov {
namespace pass {
/**
* @brief Constant folding iterates over the function and tries to evaluate nodes
* with constant inputs. Such nodes are then replaced with new Constants containing
* the result of a folded operation.
*/
class OPENVINO_API ConstantFolding : public FunctionPass {
public:
OPENVINO_RTTI_DECLARATION;
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
private:
void copy_runtime_info_to_target_inputs(const std::shared_ptr<Node>& node, const Output<Node>& replacement);
/// \brief Folds pre-calculated output tensor values to constants in case lower and
/// upper estimations are equal. Traverses graph backwards starting from the results.
bool pre_calculated_values_folding(const std::shared_ptr<ngraph::Function>& f);
};
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,17 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
namespace ov {
namespace pass {
class OPENVINO_API ConvertFP32ToFP16 : public FunctionPass {
public:
OPENVINO_RTTI_DECLARATION;
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
};
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,249 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <functional>
#include <memory>
#include <set>
#include "openvino/pass/pass.hpp"
#include "openvino/pass/pattern/matcher.hpp"
namespace ov {
using matcher_pass_callback = std::function<bool(pass::pattern::Matcher& m)>;
using graph_rewrite_callback = std::function<bool(pass::pattern::Matcher& m)>;
using recurrent_graph_rewrite_callback = std::function<bool(pass::pattern::RecurrentMatcher& m)>;
using handler_callback = std::function<bool(const std::shared_ptr<Node>& node)>;
namespace pass {
/// \brief MatcherPass is a basic block for pattern based transformations. It describes
/// pattern and
/// action that is applied if pattern is matched.
///
/// MatcherPass consists of Matcher and matcher_pass_callback that needs to be implemented
/// and
/// finally registered by using \sa register_matcher. MatcherPass can be executed on node
/// within
/// \sa apply method. To run matcher pass on Function use GraphRewrite.
/// In addition MatcherPass provides a way for adding new operations into GraphRewrite
/// execution
/// queue. That means that operations that were created inside transformation callback can
/// be added
/// for matching. To register node use \sa register_new_node method. GraphRewrite
/// automatically
/// takes registered nodes and put them to execution queue. If multiple nodes were register
/// make
/// sure that they were registered in topological order.
/// Note: when implementing pattern for Matcher make sure that root node is an operation
/// from opset
/// or has ov::pass::pattern::op::WrapType. That will help GraphRewrite to execute matcher
/// passes more
/// efficient.
class OPENVINO_API MatcherPass : public PassBase {
public:
OPENVINO_RTTI_DECLARATION;
MatcherPass() = default;
MatcherPass(const MatcherPass&) = delete;
MatcherPass& operator=(const MatcherPass&) = delete;
explicit MatcherPass(const std::string& name,
const std::shared_ptr<pattern::Matcher>& m,
const handler_callback& handler,
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE)
: PassBase(),
m_handler(handler),
m_matcher(m) {
set_name(name);
set_property(property, true);
}
bool apply(std::shared_ptr<ov::Node> node);
template <typename T, class... Args>
std::shared_ptr<T> register_new_node(Args&&... args) {
auto node = std::make_shared<T>(std::forward<Args>(args)...);
m_new_nodes.push_back(node);
return node;
}
template <typename T>
std::shared_ptr<T> register_new_node(const std::shared_ptr<T>& node) {
m_new_nodes.push_back(node);
return node;
}
const std::vector<std::shared_ptr<ov::Node>>& get_new_nodes() {
return m_new_nodes;
}
void clear_new_nodes() {
m_new_nodes.clear();
}
std::shared_ptr<pattern::Matcher> get_matcher() {
return m_matcher;
}
protected:
void register_matcher(const std::shared_ptr<pattern::Matcher>& m,
const graph_rewrite_callback& callback,
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE);
private:
handler_callback m_handler;
std::shared_ptr<pattern::Matcher> m_matcher;
std::vector<std::shared_ptr<ov::Node>> m_new_nodes;
};
/// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function
/// in
/// efficient way
///
/// Graph rewrite pass is used for matcher passes execution on Function.
/// To register MatcherPass use \sa add_matcher<T>(args) method where T is a MatcherPass
/// class.
/// As a default algorithm graph rewrite pass traverse Function in topological order and
/// applies
/// registered matcher passes for each node. But if all registered matcher passes have type
/// based
/// root node in Matcher pattern then efficient mechanism is used to execute them.
/// Matcher pattern root is type based if it's operation from opset or
/// pattern::op::WrapType.
/// Note: when implementing pattern for Matcher make sure that root node is an operation
/// from opset
/// or has ov::pattern::op::WrapType. That will help GraphRewrite to execute matcher
/// passes more
/// efficient.
class OPENVINO_API GraphRewrite : public FunctionPass {
public:
OPENVINO_RTTI_DECLARATION;
GraphRewrite() = default;
explicit GraphRewrite(const std::shared_ptr<MatcherPass>& pass) : FunctionPass() {
m_matchers.push_back(pass);
}
/// \brief Register given transformation class type to GraphRewrite execution list
/// All registered transformations will be executed in a single graph traversal.
/// Example below show the basic usage of pass::GraphRewrite
///
/// pass::Manager manager;
/// auto anchor = manager.register_pass<GraphRewrite>();
/// anchor->add_matcher<MatcherPassA>();
/// anchor->add_matcher<MatcherPassB>();
/// anchor->set_name("CommonMatchers");
/// manager.run_passes(f);
///
/// For some purposes transformation can be registered and disabled by default.
///
/// anchor->add_matcher<MatcherPassB, false>();
///
/// \return shared_ptr to the transformation instance
template <typename T,
bool Enabled = true,
class... Args,
typename std::enable_if<std::is_base_of<pass::MatcherPass, T>::value, bool>::type = true>
std::shared_ptr<T> add_matcher(Args&&... args) {
static_assert(std::is_base_of<pass::MatcherPass, T>::value, "pass not derived from MatcherPass");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto pass_config = get_pass_config();
pass->set_pass_config(pass_config);
if (!Enabled && !pass_config->is_enabled<T>()) {
pass_config->disable<T>();
}
m_matchers.push_back(pass);
return pass;
}
/// \brief Register passes from GraphRewrite class that contains sequence of matcher
/// passes registered in its ctor.
/// For example:
///
/// class ov::pass::LinFusions: public ov::pass::GraphRewrite {
/// public:
/// OPENVINO_RTTI_DECLARATION;
/// Fusions() {
/// add_matcher<ov::pass::AddFusion>();
/// add_matcher<ov::pass::MulFusion>();
/// }
/// };
///
/// pass::Manager manager;
/// auto anchor = manager.register_pass<GraphRewrite>();
/// anchor->add_matcher<LinFusions>();
/// anchor->add_matcher<OtherFusions>();
/// anchor->set_name("CommonFusions");
/// manager.run_passes(f);
///
/// In this case all matcher passes from LinFusions pass will be united with other
/// registered matchers.
template <typename T,
class... Args,
typename std::enable_if<std::is_base_of<pass::GraphRewrite, T>::value, bool>::type = true>
void add_matcher(Args&&... args) {
static_assert(std::is_base_of<pass::GraphRewrite, T>::value, "pass not derived from GraphRewrite");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto pass_config = get_pass_config();
for (auto& matcher : pass->m_matchers) {
pass->set_pass_config(pass_config);
m_matchers.push_back(matcher);
}
}
OPENVINO_DEPRECATED("Use MatcherPass instead")
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const graph_rewrite_callback& callback,
const PassPropertyMask& property);
OPENVINO_DEPRECATED("Use MatcherPass instead")
void add_matcher(const std::shared_ptr<pattern::Matcher>& m, const ov::graph_rewrite_callback& callback);
bool run_on_function(std::shared_ptr<ov::Function> f) override;
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
protected:
bool apply_matcher_passes(std::shared_ptr<Function> f, std::deque<std::weak_ptr<Node>> nodes_to_run);
bool m_enable_shape_inference = false;
std::vector<std::shared_ptr<ov::pass::MatcherPass>> m_matchers;
};
class OPENVINO_API BackwardGraphRewrite : public GraphRewrite {
public:
OPENVINO_RTTI_DECLARATION;
BackwardGraphRewrite() = default;
explicit BackwardGraphRewrite(const std::shared_ptr<MatcherPass>& pass) : GraphRewrite(pass) {}
bool run_on_function(std::shared_ptr<ov::Function> f) override;
};
class OPENVINO_API RecurrentGraphRewrite : public FunctionPass {
public:
RecurrentGraphRewrite(size_t num_iters = 10) : FunctionPass(), m_num_iters(num_iters) {}
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ov::recurrent_graph_rewrite_callback& callback,
const PassPropertyMask& property);
// TODO: This interface may deprecate after all passes are refactored.
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ov::recurrent_graph_rewrite_callback& callback);
bool run_on_function(std::shared_ptr<ov::Function> f) override;
private:
size_t m_num_iters;
std::vector<std::shared_ptr<ov::pass::MatcherPass>> m_matchers;
};
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,48 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <vector>
#include "openvino/pass/pass.hpp"
namespace ov {
namespace pass {
/**
* @brief The transformation finds all TensorIterator/Loop layers in the network,
* processes all back edges that describe a connection between Result and Parameter
* of the TensorIterator/Loop bodies,and inserts ReadValue and Assign layers at the
* input and output corresponding to this back edge.
* Supported platforms: CPU, GNA.
*
* The example below describes the changes made by the transformation
* [] - TensorIterator body
* () - new layer
* BE - back-edge
*
* before applying the transformation:
* -> input1[BE_1 -> Parameter -> Layers ... -> Result -> BE_1 ]output1->
*
* after applying the transformation:
* ->(ReadValue)-> input1[BE_1 ->Parameter->Layers ...->Result->BE_1]output1 ->(Assign)
* \
* ->...
* After applying the transformation, the resulting network can be inferred
* step by step, the states will store between inferences.
*/
class OPENVINO_API LowLatency2 : public FunctionPass {
public:
OPENVINO_RTTI_DECLARATION;
explicit LowLatency2(bool use_const_initializer = true) : m_use_const_initializer(use_const_initializer) {}
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
private:
bool m_use_const_initializer;
};
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,116 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <list>
#include <memory>
#include <typeinfo>
#include <vector>
#include "openvino/pass/pass.hpp"
#include "openvino/pass/validate.hpp"
namespace ov {
namespace pass {
class OPENVINO_API Manager {
public:
Manager();
~Manager();
//// \brief Construct Manager with shared PassConfig instance
explicit Manager(std::shared_ptr<PassConfig> pass_config);
/// \brief Register given transformation class type to execution list
/// Example below show the basic usage of pass::Manager
///
/// pass::Manager manager;
/// manager.register_pass<MyTransformation>(/*transformation constructor ars*/);
/// manager.run_passes(f);
///
/// For some purposes transformation can be registered and disabled by default.
///
/// manager.register_pass<MyTransformation, false>();
///
/// \return shared_ptr to the transformation instance
template <typename T, bool Enable = true, class... Args>
std::shared_ptr<T> register_pass(Args&&... args) {
auto rc = push_pass<T>(std::forward<Args>(args)...);
rc->set_pass_config(m_pass_config);
if (m_per_pass_validation) {
push_pass<Validate>();
}
if (!Enable && !m_pass_config->is_enabled<T>()) {
m_pass_config->disable<T>();
}
return rc;
}
void run_passes(std::shared_ptr<Function>);
void set_pass_visualization(bool new_state) {
m_visualize = new_state;
}
/// \brief Set flag to enable/disable running Validate pass after executing
/// each registered pass
/// \param new_state Value "true" enables Validate pass run; "false", otherwise
void set_per_pass_validation(bool new_state) {
m_per_pass_validation = new_state;
}
/// \brief Callback is a lambda function that can be used by registered transformations.
/// The main purpose of this callback is to provide a way for plugins to disable/enable
/// transformations based on some conditions. In some cases plugins may want not to
/// execute some
/// transformations.
/// For example plugin can disable unpleasant decompositions because of performance
/// reasons for
/// some cases.
/// Callback example:
/// auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
/// return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) !=
/// nullptr;
/// };
/// This callback returns true in case of DepthToSpace operation. So when execution
/// DepthToSpace
/// decomposition pass will check is this decomposition needed or plugin can execute
/// this
/// operation directly. And of course on transformation side we need to have a response
/// for this
/// callback.
/// if (transformation_callback(batch_to_space)) {
/// return false;
/// }
/// \param callback lamda function that returns true in case if node is supported by
/// plugin and
/// transformation is not needed
OPENVINO_DEPRECATED("Please use get_pass_config() to configure transformation pipeline")
void set_callback(const param_callback& callback) {
m_pass_config->set_callback(callback);
}
/// \return PassConfig shared object. This object is used for transformations pipeline
/// configuration.
/// This object allows to disable/enable transformations execution, set callback to
/// particular
/// transformation. For mo details see PassConfig class.
std::shared_ptr<PassConfig> get_pass_config() {
return m_pass_config;
}
protected:
template <typename T, class... Args>
std::shared_ptr<T> push_pass(Args&&... args) {
static_assert(std::is_base_of<pass::PassBase, T>::value, "pass not derived from pass base");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto pass_base = std::static_pointer_cast<PassBase>(pass);
m_pass_list.push_back(pass_base);
return pass;
}
std::shared_ptr<PassConfig> m_pass_config;
std::vector<std::shared_ptr<PassBase>> m_pass_list;
bool m_visualize = false;
bool m_per_pass_validation = true;
};
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,110 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <list>
#include <memory>
#include <vector>
#include "ngraph/util.hpp"
#include "openvino/core/core_visibility.hpp"
#include "openvino/core/deprecated.hpp"
#include "openvino/core/function.hpp"
#include "openvino/core/node.hpp"
#include "openvino/pass/pass_config.hpp"
namespace ov {
namespace pass {
enum class PassProperty : uint32_t {
// Pass requires node shapes to be static
REQUIRE_STATIC_SHAPE = 0x1,
// Pass transformation will change the function's dynamic state
CHANGE_DYNAMIC_STATE = 1 << 1,
};
using PassPropertyMask = ngraph::EnumMask<PassProperty>;
class OPENVINO_API PassBase {
friend class Manager;
public:
PassBase();
virtual ~PassBase() = default;
/// Check if this pass has all the pass properties.
bool get_property(const PassPropertyMask& prop_mask) const;
void set_name(const std::string& name) {
m_name = name;
}
std::string get_name() const;
/// \brief Set callback for particular transformation type.
/// This method set global callback. For more details see PassConfig class
/// documentation.
/// \param callback lambda function that takes node and returns bool
void set_callback(const param_callback& callback);
/// \brief Set PassConfig for particular transformation instance
/// \param pass_config is a PassConfig shared_ptr
virtual void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) {
m_pass_config = pass_config;
}
/// \brief Allows to access PassConfig shared instance
/// \return Shared instance of PassConfig class
std::shared_ptr<PassConfig> get_pass_config() {
return m_pass_config;
}
/// \brief Applies callback for given node. By default callback returns false.
/// This method remains here only for backward compatibility and will be removed
/// after all transformations are moved to transformation_callback() method.
/// \return result of callback execution for given node
NGRAPH_DEPRECATED("Please use transformation_callback method instead")
bool m_transformation_callback(const std::shared_ptr<const Node>& node) {
return m_pass_config->get_callback(get_type_info())(node);
}
/// \brief Applies callback for given node. By default callback returns false.
/// \param node which will be used inside callback
/// \return result of callback execution for given node
bool transformation_callback(const std::shared_ptr<const Node>& node) {
return m_pass_config->get_callback(get_type_info())(node);
}
using type_info_t = DiscreteTypeInfo;
virtual const type_info_t& get_type_info() const = 0;
protected:
void set_property(const PassPropertyMask& prop, bool value);
private:
PassPropertyMask m_property;
std::string m_name;
std::shared_ptr<PassConfig> m_pass_config;
};
class OPENVINO_API FunctionPass : public PassBase {
public:
NGRAPH_RTTI_DECLARATION;
~FunctionPass() override;
virtual bool run_on_function(std::shared_ptr<ngraph::Function>) = 0;
};
class Manager;
enum class FusionType : uint32_t {
//`DIFFERENTIABLE_FUSIONS` produce ops that support autodiff
// i.e. implement `generate_adjoints`
DIFFERENTIABLE_FUSIONS = 0x1,
REGULAR_FUSIONS = 0x2,
//`FOP_FUSIONS` produce ops in the FusedOps category that might
// not be supported by all backends
FOP_FUSIONS = 0x4,
ALL_FUSIONS = 0xFFFFFFFF
};
using FusionTypeMask = ngraph::EnumMask<FusionType>;
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,176 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <list>
#include <memory>
#include <vector>
#include "ngraph/util.hpp"
#include "openvino/core/core_visibility.hpp"
#include "openvino/core/deprecated.hpp"
#include "openvino/core/function.hpp"
#include "openvino/core/node.hpp"
namespace ov {
namespace pass {
using param_callback = std::function<bool(const std::shared_ptr<const ::ov::Node>)>;
using param_callback_map = std::map<ov::DiscreteTypeInfo, param_callback>;
/// \brief Class representing a transformations config that is used for disabling/enabling
/// transformations registered inside pass::Manager and also allows to set callback for all
/// transformations or for particular transformation.
///
/// When pass::Manager is created all passes registered inside this manager including nested
/// passes will share the same instance of PassConfig class.
/// To work with this class first you need to get shared instance of this class by calling
/// manager.get_pass_config() method. Then you will be able to disable/enable passes based
/// on transformations type_info. For example:
///
/// pass::Manager manager;
/// manager.register_pass<CommonOptimizations>();
/// auto pass_config = manager.get_pass_config();
/// pass_config->disable<ConvertGELU>(); // this will disable nested pass inside
/// // CommonOptimizations pipeline
/// manager.run_passes(f);
///
/// Sometimes it is needed to call transformation inside other transformation manually. And
/// for that case before running transformation you need manually check that this pass is
/// not disabled and then you need to set current PassConfig instance to this
/// transformation. For example:
///
/// // Inside MatcherPass callback or inside FunctionPass run_on_function() method
/// // you need to call get_pass_config() method to get shared instance of PassConfig
/// auto pass_config = get_pass_config();
///
/// // Before running nested transformation you need to check is it disabled or not
/// if (!pass_config->is_disabled<ConvertGELU>()) {
/// auto pass = ConvertGELU();
/// pass->set_pass_config(pass_config);
/// pass.apply(node);
/// }
///
/// Following this logic inside your transformations you will guaranty that transformations
/// will be executed in a right way.
class OPENVINO_API PassConfig {
public:
/// \brief Disable transformation by its type_info
/// \param type_info Transformation type_info
void disable(const DiscreteTypeInfo& type_info);
/// \brief Disable transformation by its class type (based on type_info)
template <typename T>
void disable() {
OPENVINO_SUPPRESS_DEPRECATED_START
disable(T::type_info);
OPENVINO_SUPPRESS_DEPRECATED_END
}
/// \brief Enable transformation by its type_info
/// \param type_info Transformation type_info
void enable(const DiscreteTypeInfo& type_info);
/// \brief Enable transformation by its class type (based on type_info)
template <typename T>
void enable() {
OPENVINO_SUPPRESS_DEPRECATED_START
enable(T::type_info);
OPENVINO_SUPPRESS_DEPRECATED_END
}
/// \brief Set callback for all kind of transformations
void set_callback(const param_callback& callback) {
m_callback = callback;
}
template <typename... Args>
typename std::enable_if<sizeof...(Args) == 0>::type set_callback(const param_callback& callback) {}
/// \brief Set callback for particular transformation class types
///
/// Example below show how to set callback for one or multiple passes using this method.
///
/// pass_config->set_callback<ngraph::pass::ConvertBatchToSpace,
/// ngraph::pass::ConvertSpaceToBatch>(
/// [](const_node_ptr &node) -> bool {
/// // Disable transformations for cases when input shape rank is not
/// equal to 4
/// const auto input_shape_rank =
/// node->get_output_partial_shape(0).rank().get_length();
/// if (input_shape_rank != 4) {
/// return false;
/// }
/// return true;
/// });
///
/// Note that inside transformations you must provide code that work with this callback.
/// See example below:
///
/// if (transformation_callback(node)) {
/// return false; // exit from transformation
/// }
///
template <typename T, class... Args>
void set_callback(const param_callback& callback) {
m_callback_map[T::type_info] = callback;
set_callback<Args...>(callback);
}
/// \brief Get callback for given transformation type_info
/// \param type_info Transformation type_info
///
/// In case if callback wasn't set for given transformation type then global callback
/// will be returned. But if even global callback wasn't set then default callback will
/// be returned.
param_callback get_callback(const DiscreteTypeInfo& type_info) const;
/// \brief Get callback for given transformation class type
/// \return callback lambda function
template <typename T>
param_callback get_callback() const {
NGRAPH_SUPPRESS_DEPRECATED_START
return get_callback(T::type_info);
NGRAPH_SUPPRESS_DEPRECATED_END
}
/// \brief Check either transformation type is disabled or not
/// \param type_info Transformation type_info
/// \return true if transformation type was disabled and false otherwise
bool is_disabled(const DiscreteTypeInfo& type_info) const {
return m_disabled.count(type_info);
}
/// \brief Check either transformation class type is disabled or not
/// \return true if transformation type was disabled and false otherwise
template <typename T>
bool is_disabled() const {
NGRAPH_SUPPRESS_DEPRECATED_START
return is_disabled(T::type_info);
NGRAPH_SUPPRESS_DEPRECATED_END
}
/// \brief Check either transformation type is force enabled or not
/// \param type_info Transformation type_info
/// \return true if transformation type was force enabled and false otherwise
bool is_enabled(const DiscreteTypeInfo& type_info) const {
return m_enabled.count(type_info);
}
/// \brief Check either transformation class type is force enabled or not
/// \return true if transformation type was force enabled and false otherwise
template <typename T>
bool is_enabled() const {
return is_enabled(T::type_info);
}
void add_disabled_passes(const PassConfig& rhs);
private:
param_callback m_callback = [](const std::shared_ptr<const ::ngraph::Node>&) {
return false;
};
param_callback_map m_callback_map;
std::unordered_set<DiscreteTypeInfo> m_disabled;
std::unordered_set<DiscreteTypeInfo> m_enabled;
};
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,271 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory.h>
#include <algorithm>
#include <functional>
#include "ngraph/op/constant.hpp"
#include "openvino/core/except.hpp"
#include "openvino/core/node.hpp"
#include "openvino/pass/pattern/op/any.hpp"
#include "openvino/pass/pattern/op/any_of.hpp"
#include "openvino/pass/pattern/op/any_output.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/skip.hpp"
namespace ov {
namespace pass {
class GraphRewrite;
namespace pattern {
class Matcher;
class OPENVINO_API MatcherState {
public:
MatcherState(Matcher*);
bool finish(bool is_successful);
~MatcherState();
protected:
Matcher* m_matcher;
PatternValueMap m_pattern_value_map;
PatternValueMaps m_pattern_value_maps;
size_t m_watermark;
size_t m_capture_size;
bool m_restore{true};
};
/// Matcher looks for node patterns in a computation graph. The patterns are described by an
/// automaton that is described by an extended computation graph. The matcher executes
/// by attempting to match the start node of the pattern to a computation graph value
/// (output of a Node). In addition to determing if a match occurs, a pattern node may add
/// graph nodes to a list of matched nodes, associate nodes with graph values, and start
/// submatches. Submatches add match state changes to the enclosing match if the submatch
/// succeeds; otherwise the state is reverted.
///
/// The default match behavior of a pattern node with a graph nodes is that the computation
/// graph value is added to the end of the matched value list and the match succeeds if the
/// node/pattern types match and the input values match. In the case of a commutative node,
/// the inputs can match in any order. If the matcher is in strict mode, the graph value
/// element type and shape must also match.
///
/// Pattern nodes that have different match behavior are in ov::pass::pattern::op and have
/// descriptions of their match behavior.
class OPENVINO_API Matcher {
public:
using PatternMap = ov::pass::pattern::PatternMap;
// Avoid implicit string construction from nullptr.
Matcher(const std::shared_ptr<Node> pattern_node, std::nullptr_t name) = delete;
Matcher() = default;
Matcher(Output<Node>& pattern_node) : m_pattern_node{pattern_node} {}
Matcher(Output<Node>& pattern_node, const std::string& name) : m_pattern_node(pattern_node), m_name{name} {}
/// \brief Constructs a Matcher object
///
/// \param pattern_node is a pattern sub graph that will be matched against input graphs
/// \param name is a string which is used for logging and disabling a matcher
/// \param strict_mode forces a matcher to consider shapes and ET of nodes
Matcher(const Output<Node>& pattern_node, const std::string& name, bool strict_mode)
: m_pattern_node(pattern_node),
m_name(name),
m_strict_mode(strict_mode) {}
// Some matches should start on a node rather than an output. These three constructors
// are transition until we work out the right way to do that.
Matcher(std::shared_ptr<Node> pattern_node);
Matcher(std::shared_ptr<Node> pattern_node, const std::string& name);
Matcher(std::shared_ptr<Node> pattern_node, const std::string& name, bool strict_mode);
virtual ~Matcher() = default;
/// \brief Matches a pattern to \p graph_node
///
/// \param graph_value is an input graph to be matched against
bool match(const Output<Node>& graph_value);
bool match(std::shared_ptr<Node> graph_node);
/// \brief Matches a pattern to \p graph_node
///
/// \param graph_value is an input graph to be matched against
/// \param previous_matches contains previous mappings from labels to nodes to use
bool match(const Output<Node>& graph_value, const PatternMap& previous_matches);
bool match(const Output<Node>& graph_value, const PatternValueMap& previous_matches);
template <typename T>
static std::shared_ptr<T> unique_match(const std::shared_ptr<Node>& node) {
std::shared_ptr<T> matched;
for (const auto& arg : node->input_values()) {
if (auto t_casted = ov::as_type_ptr<T>(arg.get_node_shared_ptr())) {
if (matched) {
throw Exception("There's more than two arguments of the same type");
} else {
matched = t_casted;
}
}
}
return matched;
}
bool is_contained_match(const NodeVector& exclusions = {}, bool ignore_unused = true);
const NodeVector get_matched_nodes() {
return as_node_vector(m_matched_list);
}
const OutputVector& get_matched_values() const {
return m_matched_list;
}
OutputVector& get_matched_values() {
return m_matched_list;
}
void reset() {}
const std::string& get_name() {
return m_name;
}
std::shared_ptr<Node> get_pattern() {
return m_pattern_node.get_node_shared_ptr();
}
Output<Node> get_pattern_value() {
return m_pattern_node;
}
std::shared_ptr<Node> get_match_root();
Output<Node> get_match_value();
PatternMap get_pattern_map() const;
PatternValueMap& get_pattern_value_map() {
return m_pattern_map;
}
PatternValueMaps& get_pattern_value_maps() {
return m_pattern_value_maps;
}
/// \brief Low-level helper to match recurring patterns
///
/// \param graph is a graph to be matched against
/// \param pattern is a recurring pattern
/// \param rpattern specifies a node to recur from next
/// \param patterns a map from labels to matches
size_t add_node(Output<Node> node);
bool virtual match_value(const ov::Output<Node>& pattern_value, const ov::Output<Node>& graph_value);
bool is_strict_mode() {
return m_strict_mode;
}
virtual bool match_arguments(Node* pattern_node, const std::shared_ptr<Node>& graph_node);
void capture(const std::set<Node*>& static_nodes);
void clear_state();
size_t get_number_of_recurrent_matches() const {
return m_pattern_value_maps.size();
}
NodeVector get_bound_nodes_for_pattern(const Output<Node>& pattern) const;
size_t get_number_of_bound_labels() const;
/// \brief Try a match
MatcherState start_match();
Output<Node> m_match_root;
Output<Node> m_pattern_node;
PatternValueMap m_pattern_map;
PatternValueMaps m_pattern_value_maps;
OutputVector m_matched_list;
protected:
bool match_permutation(const OutputVector& pattern_args, const OutputVector& args);
std::string m_name{"unnamed"};
bool m_strict_mode{false};
};
class OPENVINO_API RecurrentMatcher {
public:
/// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
/// repeating patterns (e.g. RNN, LSTM, GRU cells)
///
/// \param initial_pattern is a pattern sub graph describing the initial cell
/// \param pattern is a pattern sub graph describing an individual cell
/// \param rpattern is a (recurring) label to denote which node the next match should
/// start at
/// \param correlated_patterns is a set of labels whose bound nodes must remain the same
/// across all cells
RecurrentMatcher(const Output<Node>& initial_pattern,
const Output<Node>& pattern,
const std::shared_ptr<Node>& rpattern,
const std::set<std::shared_ptr<Node>>& correlated_patterns)
: m_initial_pattern(initial_pattern),
m_pattern(pattern),
m_recurrent_pattern(rpattern),
m_correlated_patterns(correlated_patterns) {}
/// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
/// repeating patterns (e.g. RNN, LSTM, GRU cells)
///
/// \param pattern is a pattern sub graph describing an individual cell
/// \param rpattern is a (recurring) label to denote which node the next match should
/// start at
/// \param correlated_patterns is a set of labels whose bound nodes must remain the same
/// across all cells
RecurrentMatcher(const Output<Node>& pattern,
const std::shared_ptr<Node>& rpattern,
const std::set<std::shared_ptr<Node>>& correlated_patterns)
: RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns) {}
RecurrentMatcher(const Output<Node>& initial_pattern,
const Output<Node>& pattern,
const std::shared_ptr<Node>& rpattern,
const std::set<std::shared_ptr<op::Label>>& correlated_patterns);
RecurrentMatcher(const Output<Node>& pattern,
const std::shared_ptr<Node>& rpattern,
const std::set<std::shared_ptr<op::Label>>& correlated_patterns)
: RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns) {}
/// \brief Returns a vector of bound nodes for a given label (used in a pattern
/// describing an individual cell
NodeVector get_bound_nodes_for_pattern(const std::shared_ptr<Node>& pattern) const {
if (m_matches.count(pattern) == 0) {
throw Exception("No bound nodes for a given label");
}
return as_node_vector(m_matches.at(pattern));
}
size_t get_number_of_recurrent_matches() const {
if (m_matches.size() == 0) {
return 0;
}
return (*m_matches.begin()).second.size();
}
size_t get_number_of_bound_labels() const {
return m_matches.size();
}
/// \brief Tries to match a pattern for an individual cell to a given \p graph
bool match(Output<Node> graph);
std::shared_ptr<Node> get_match_root() {
return m_match_root.get_node_shared_ptr();
}
Output<Node> get_match_value() {
return m_match_root;
}
private:
Output<Node> m_initial_pattern;
Output<Node> m_pattern;
std::shared_ptr<Node> m_recurrent_pattern;
const std::set<std::shared_ptr<Node>> m_correlated_patterns;
RPatternValueMap m_matches;
Output<Node> m_match_root;
};
} // namespace pattern
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,45 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/node.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
namespace ov {
namespace pass {
namespace pattern {
namespace op {
/// The graph value is to the matched value list. If the predicate is true for the node
/// and the arguments match, the match succeeds.
class OPENVINO_API Any : public Pattern {
public:
static constexpr NodeTypeInfo type_info{"patternAny", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief creates a Any node containing a sub-pattern described by \sa type and \sa
/// shape.
Any(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values)
: Pattern(wrapped_values, pred) {
set_output_type(0, type, s);
}
Any(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values)
: Any(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
/// \brief creates a Any node containing a sub-pattern described by the type and
/// shape of \sa node.
Any(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
: Any(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
Any(const Output<Node>& node, NodePredicate pred, const NodeVector& wrapped_values)
: Any(node.get_element_type(),
node.get_partial_shape(),
as_value_predicate(pred),
as_output_vector(wrapped_values)) {}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
} // namespace op
} // namespace pattern
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,54 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/node.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
namespace ov {
namespace pass {
namespace pattern {
namespace op {
/// The graph value is added to the matched values list. If the predicate is true for
/// the
/// graph node, a submatch is performed on the input of AnyOf and each input of the
/// graph node. The first match that succeeds results in a successful match. Otherwise
/// the match fails.
///
/// AnyOf may be given a type and shape for use in strict mode.
class OPENVINO_API AnyOf : public Pattern {
public:
static constexpr NodeTypeInfo type_info{"patternAnyOf", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief creates a AnyOf node containing a sub-pattern described by \sa type and
/// \sa shape.
AnyOf(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values)
: Pattern(wrapped_values, pred) {
if (wrapped_values.size() != 1) {
throw Exception("AnyOf expects exactly one argument");
}
set_output_type(0, type, s);
}
AnyOf(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values)
: AnyOf(
type,
s,
[pred](const Output<Node>& value) {
return pred(value.get_node_shared_ptr());
},
as_output_vector(wrapped_values)) {}
/// \brief creates a AnyOf node containing a sub-pattern described by the type and
/// shape of \sa node.
AnyOf(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
: AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
AnyOf(const std::shared_ptr<Node>& node, NodePredicate pred, const NodeVector& wrapped_values)
: AnyOf(node, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
};
} // namespace op
} // namespace pattern
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,30 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/node.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
namespace ov {
namespace pass {
namespace pattern {
namespace op {
/// Matches any output of a node
class OPENVINO_API AnyOutput : public Pattern {
public:
static constexpr NodeTypeInfo type_info{"patternAnyOutput", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief creates an AnyOutput node matching any output of a node
/// \param node The node to match
AnyOutput(const std::shared_ptr<Node>& pattern) : Pattern({pattern->output(0)}) {}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
} // namespace op
} // namespace pattern
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,55 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/node.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
namespace ov {
namespace pass {
namespace pattern {
namespace op {
/// A branch adds a loop to the pattern. The branch match is successful if the
/// destination node pattern matches the graph value. The destination node is a node in
/// the pattern graph that will not have been created some time after the Branch node is
/// created; use set_destination to add it.
///
/// The branch destination is not stored as a shared pointer to prevent reference
/// cycles. Thus the destination node must be referenced in some other way to prevent it
/// from being deleted.
class OPENVINO_API Branch : public Pattern {
public:
static constexpr NodeTypeInfo type_info{"patternBranch", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief Creates a Branch pattern
/// \param pattern the destinationing pattern
/// \param labels Labels where the destination may occur
Branch() : Pattern(OutputVector{}) {
set_output_type(0, element::f32, ngraph::Shape{});
}
void set_destination(const Output<Node>& destination) {
m_destination_node = destination.get_node();
m_destination_index = destination.get_index();
}
Output<Node> get_destination() const {
return m_destination_node == nullptr
? Output<Node>()
: Output<Node>{m_destination_node->shared_from_this(), m_destination_index};
}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
protected:
Node* m_destination_node{nullptr};
size_t m_destination_index{0};
};
} // namespace op
} // namespace pattern
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,44 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/node.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
namespace ov {
namespace pass {
namespace pattern {
namespace op {
/// Experimental for support of recurrent matches.
///
/// Capture adds the pattern value map to a list of pattern value maps and resets
/// matches for pattern nodes not in the static node list. The match always succeeds.
class OPENVINO_API Capture : public Pattern {
public:
static constexpr NodeTypeInfo type_info{"patternCapture", 0};
const NodeTypeInfo& get_type_info() const override;
Capture(const Output<Node>& arg) : Pattern({arg}) {
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
}
/// \brief static nodes are retained after a capture. All other nodes are dropped
std::set<Node*> get_static_nodes() {
return m_static_nodes;
}
void set_static_nodes(const std::set<Node*>& static_nodes) {
m_static_nodes = static_nodes;
}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
protected:
std::set<Node*> m_static_nodes;
};
} // namespace op
} // namespace pattern
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,113 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/node.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
namespace ov {
namespace pass {
namespace pattern {
namespace op {
/// Fails if the predicate returns false on the graph value.
///
/// The graph value is added to the matched values list. If the Label is already
/// associated with a value, the match succeeds if the value is the same as the graph
/// value. Otherwise, the label is associated with the graph value and the match
/// succeeds if the pattern input matches the graph value.
///
/// DEPRECATED: If no inputs are given to Label, a True node is serves as the input. If
/// more than one inputs are given, an Or pattern of the inputs serves as the input.
class OPENVINO_API Label : public Pattern {
public:
static constexpr NodeTypeInfo type_info{"patternLabel", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief creates a Label node containing a sub-pattern described by \sa type and
/// \sa shape.
///
/// this Label node can be bound only to the nodes in the input graph
/// that match the pattern specified by \sa wrapped_nodes
/// Example:
/// \code{.cpp}
/// auto add = a + b; // a and b are op::Parameter in this example
/// auto label = std::make_shared<pattern::op::Label>(element::f32,
/// Shape{2,2},
/// nullptr,
/// OutputVector{add});
/// \endcode
Label(const element::Type& type,
const PartialShape& s,
const ValuePredicate pred,
const OutputVector& wrapped_values)
: Pattern(OutputVector{wrap_values(wrapped_values)}, pred) {
set_output_type(0, type, s);
}
explicit Label(const element::Type& type = element::dynamic, const PartialShape& s = PartialShape::dynamic())
: Label(
type,
s,
[](const Output<Node>&) {
return true;
},
OutputVector()) {}
Label(const element::Type& type, const PartialShape& s, ValuePredicate pred)
: Label(type, s, pred, OutputVector{}) {}
Label(const element::Type& type, const PartialShape& s, NodePredicate pred)
: Label(type, s, as_value_predicate(pred), OutputVector{}) {}
Label(const element::Type& type, const PartialShape& s, const NodePredicate pred, const NodeVector& wrapped_values)
: Label(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
/// \brief creates a Label node containing a sub-pattern described by the type and
/// shape of \sa node.
///
/// this Label node can be bound only to the nodes in the input graph
/// that match the pattern specified by \sa wrapped_values
/// Example:
/// \code{.cpp}
/// auto add = a + b; // a and b are op::Parameter in this example
/// auto label = std::make_shared<pattern::op::Label>(add,
/// nullptr,
/// OutputVector{add});
/// \endcode
Label(const Output<Node>& value, const ValuePredicate pred, const OutputVector& wrapped_values)
: Label(value.get_element_type(), value.get_partial_shape(), pred, wrapped_values) {}
Label(const Output<Node>& value, const ValuePredicate pred)
: Label(value.get_element_type(), value.get_partial_shape(), pred, OutputVector{}) {}
Label(const Output<Node>& value, const NodePredicate pred)
: Label(value.get_element_type(), value.get_partial_shape(), as_value_predicate(pred), OutputVector{}) {}
Label(const Output<Node>& value)
: Label(
value.get_element_type(),
value.get_partial_shape(),
[](const Output<Node>&) {
return true;
},
OutputVector{}) {}
Label(const Output<Node>& node, const NodePredicate pred, const NodeVector& wrapped_values)
: Label(node.get_element_type(),
node.get_partial_shape(),
as_value_predicate(pred),
as_output_vector(wrapped_values)) {}
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
protected:
static Output<Node> wrap_values(const OutputVector& wrapped_values);
};
} // namespace op
OPENVINO_API
std::shared_ptr<Node> any_input();
OPENVINO_API
std::shared_ptr<Node> any_input(const pattern::op::ValuePredicate& pred);
} // namespace pattern
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,32 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/node.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
namespace ov {
namespace pass {
namespace pattern {
namespace op {
/// A submatch on the graph value is performed on each input to the Or; the match
/// succeeds on the first match. Otherwise the match fails.
class OPENVINO_API Or : public Pattern {
public:
static constexpr NodeTypeInfo type_info{"patternOr", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief creates an Or node matching one of several sub-patterns in order. Does
/// not add node to match list.
/// \param patterns The patterns to try for matching
Or(const OutputVector& patterns) : Pattern(patterns) {}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
} // namespace op
} // namespace pattern
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,96 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <functional>
#include "openvino/core/node.hpp"
namespace ov {
namespace pass {
namespace pattern {
namespace op {
class Label;
}
class Matcher;
class MatcherState;
using RPatternValueMap = std::map<std::shared_ptr<Node>, OutputVector>;
using PatternValueMap = std::map<std::shared_ptr<Node>, Output<Node>>;
using PatternValueMaps = std::vector<PatternValueMap>;
using PatternMap = std::map<std::shared_ptr<Node>, std::shared_ptr<Node>>;
PatternMap as_pattern_map(const PatternValueMap& pattern_value_map);
PatternValueMap as_pattern_value_map(const PatternMap& pattern_map);
template <typename T>
std::function<bool(std::shared_ptr<Node>)> has_class() {
auto pred = [](std::shared_ptr<Node> node) -> bool {
return ov::is_type<T>(node);
};
return pred;
}
OPENVINO_API
std::function<bool(Output<Node>)> consumers_count(size_t n);
OPENVINO_API
std::function<bool(Output<Node>)> has_static_dim(size_t pos);
OPENVINO_API
std::function<bool(Output<Node>)> has_static_dims(const std::vector<size_t>& dims);
OPENVINO_API
std::function<bool(Output<Node>)> has_static_shape();
OPENVINO_API
std::function<bool(Output<Node>)> has_static_rank();
OPENVINO_API
std::function<bool(Output<Node>)> rank_equals(const Dimension& expected_rank);
OPENVINO_API
std::function<bool(Output<Node>)> type_matches(const element::Type& type);
OPENVINO_API
std::function<bool(Output<Node>)> type_matches_any(const std::vector<element::Type>& types);
namespace op {
using NodePredicate = std::function<bool(std::shared_ptr<Node>)>;
using ValuePredicate = std::function<bool(const Output<Node>& value)>;
OPENVINO_API
ValuePredicate as_value_predicate(NodePredicate pred);
class OPENVINO_API Pattern : public Node {
public:
/// \brief \p a base class for \sa Skip and \sa Label
///
Pattern(const OutputVector& patterns, ValuePredicate pred) : Node(patterns), m_predicate(pred) {
if (!m_predicate) {
m_predicate = [](const Output<Node>&) {
return true;
};
}
}
Pattern(const OutputVector& patterns) : Pattern(patterns, nullptr) {}
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& /* new_args */) const override {
throw Exception("Uncopyable");
}
ValuePredicate get_predicate() const;
protected:
ValuePredicate m_predicate;
};
} // namespace op
} // namespace pattern
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,44 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/node.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
namespace ov {
namespace pass {
namespace pattern {
namespace op {
/// The graph value is added to the matched value list. If the predicate is true, the
/// match succeeds if the arguments match; if the predicate is false, the match succeeds
/// if the pattern input matches the graph value.
class OPENVINO_API Skip : public Pattern {
public:
static constexpr NodeTypeInfo type_info{"patternSkip", 0};
const NodeTypeInfo& get_type_info() const override;
Skip(const Output<Node>& arg, ValuePredicate pred) : Pattern({arg}, pred) {
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
}
Skip(const Output<Node>& arg, NodePredicate pred = nullptr) : Pattern({arg}, as_value_predicate(pred)) {
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
}
Skip(const OutputVector& args, ValuePredicate pred) : Pattern(args, pred) {
set_output_type(0, args.at(0).get_element_type(), args.at(0).get_partial_shape());
}
Skip(const OutputVector& args, NodePredicate pred = nullptr) : Pattern(args, as_value_predicate(pred)) {
set_output_type(0, args.at(0).get_element_type(), args.at(0).get_partial_shape());
}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
} // namespace op
} // namespace pattern
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,28 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/node.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
namespace ov {
namespace pass {
namespace pattern {
namespace op {
/// \brief The match always succeeds.
class OPENVINO_API True : public Pattern {
public:
static constexpr NodeTypeInfo type_info{"patternTrue", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief Always matches, does not add node to match list.
True() : Pattern(OutputVector{}) {}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
} // namespace op
} // namespace pattern
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,75 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/node.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
namespace ov {
namespace pass {
namespace pattern {
namespace op {
class OPENVINO_API WrapType : public Pattern {
public:
static constexpr NodeTypeInfo type_info{"patternAnyType", 0};
const NodeTypeInfo& get_type_info() const override;
explicit WrapType(
NodeTypeInfo wrapped_type,
const ValuePredicate& pred =
[](const Output<Node>& output) {
return true;
},
const OutputVector& input_values = {})
: Pattern(input_values, pred),
m_wrapped_types({wrapped_type}) {
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
}
explicit WrapType(
std::vector<NodeTypeInfo> wrapped_types,
const ValuePredicate& pred =
[](const Output<Node>& output) {
return true;
},
const OutputVector& input_values = {})
: Pattern(input_values, pred),
m_wrapped_types(std::move(wrapped_types)) {
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
NodeTypeInfo get_wrapped_type() const;
const std::vector<NodeTypeInfo>& get_wrapped_types() const;
private:
std::vector<NodeTypeInfo> m_wrapped_types;
};
} // namespace op
template <class... Args>
std::shared_ptr<Node> wrap_type(const OutputVector& inputs, const pattern::op::ValuePredicate& pred) {
std::vector<DiscreteTypeInfo> info{Args::type_info...};
return std::make_shared<op::WrapType>(info, pred, inputs);
}
template <class... Args>
std::shared_ptr<Node> wrap_type(const OutputVector& inputs = {}) {
return wrap_type<Args...>(inputs, [](const Output<Node>& output) {
return true;
});
}
template <class... Args>
std::shared_ptr<Node> wrap_type(const pattern::op::ValuePredicate& pred) {
return wrap_type<Args...>({}, pred);
}
} // namespace pattern
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,32 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/core_visibility.hpp"
#include "openvino/pass/pass.hpp"
namespace ov {
namespace pass {
/// \brief The Validate pass performs sanity checks on attributes and inputs, and
/// computes output shapes and element types for all computation nodes in a given
/// computation graph.
///
/// \details The verification and inference is done via invoking each node's specific
/// implementation of \link ov::Node::validate_and_infer_types() \endlink function.
///
/// By default, the \ref ov::pass::Manager runs this pass after executing every
/// optimization pass. This is to ensure that any update to the graph by an optimization
/// pass does not break the shape and data type requirement on a computation node.
/// This default validation run can be changed via calling the
/// \link ov::pass::Manager::set_per_pass_validation(bool) \endlink function.
class OPENVINO_API Validate : public FunctionPass {
public:
OPENVINO_RTTI_DECLARATION;
Validate() : FunctionPass() {}
bool run_on_function(std::shared_ptr<ov::Function> f) override;
};
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,57 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <functional>
#include <set>
#include <sstream>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include "openvino/pass/pass.hpp"
class HeightMap;
using visualize_tree_ops_map_t =
std::unordered_map<ov::Node::type_info_t, std::function<void(const ov::Node&, std::ostream& ss)>>;
namespace ov {
namespace pass {
class OPENVINO_API VisualizeTree : public FunctionPass {
public:
OPENVINO_RTTI_DECLARATION;
using node_modifiers_t = std::function<void(const Node& node, std::vector<std::string>& attributes)>;
VisualizeTree(const std::string& file_name, node_modifiers_t nm = nullptr, bool dot_only = false);
bool run_on_function(std::shared_ptr<ov::Function>) override;
void set_ops_to_details(const visualize_tree_ops_map_t& ops_map) {
m_ops_to_details = ops_map;
}
protected:
void add_node_arguments(std::shared_ptr<Node> node,
std::unordered_map<Node*, HeightMap>& height_maps,
size_t& fake_node_ctr);
std::string add_attributes(std::shared_ptr<Node> node);
virtual std::string get_attributes(std::shared_ptr<Node> node);
virtual std::string get_node_name(std::shared_ptr<Node> node);
std::string get_constant_value(std::shared_ptr<Node> node, size_t max_elements = 7);
void render() const;
std::stringstream m_ss;
std::string m_name;
std::set<std::shared_ptr<Node>> m_nodes_with_attributes;
visualize_tree_ops_map_t m_ops_to_details;
node_modifiers_t m_node_modifiers = nullptr;
bool m_dot_only;
static const int max_jump_distance;
};
} // namespace pass
} // namespace ov

View File

@ -11,11 +11,10 @@
#include "ngraph/validation_util.hpp" #include "ngraph/validation_util.hpp"
using namespace std; using namespace std;
using namespace ngraph;
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConstantFolding, "ConstantFolding", 0); OPENVINO_RTTI_DEFINITION(ov::pass::ConstantFolding, "ConstantFolding", 0);
bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr<ngraph::Function> f) { bool ov::pass::ConstantFolding::run_on_function(std::shared_ptr<ov::Function> f) {
bool rewritten = pre_calculated_values_folding(f); bool rewritten = pre_calculated_values_folding(f);
for (const auto& node : f->get_ordered_ops()) { for (const auto& node : f->get_ordered_ops()) {
@ -48,7 +47,7 @@ bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr<ngraph::Func
} }
} else { } else {
// recursively constant fold operators containing subgraphs (ie: TensorIterator, Loop) // recursively constant fold operators containing subgraphs (ie: TensorIterator, Loop)
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) { if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node)) {
if (const auto& sub_graph = sub_graph_node->get_function()) { if (const auto& sub_graph = sub_graph_node->get_function()) {
rewritten |= run_on_function(sub_graph); rewritten |= run_on_function(sub_graph);
} }
@ -79,14 +78,14 @@ bool ngraph::pass::ConstantFolding::pre_calculated_values_folding(const std::sha
while (!nodes.empty()) { while (!nodes.empty()) {
auto curr_node = nodes.front(); auto curr_node = nodes.front();
nodes.pop_front(); nodes.pop_front();
if (visited.count(curr_node) || ov::is_type<op::Constant>(curr_node)) if (visited.count(curr_node) || ov::is_type<ngraph::op::Constant>(curr_node))
continue; continue;
visited.insert(curr_node); visited.insert(curr_node);
for (auto& input_value : curr_node->input_values()) { for (auto& input_value : curr_node->input_values()) {
// Check that ConstantFolding is not disabled on this path // Check that ConstantFolding is not disabled on this path
std::vector<Node*> order; std::vector<Node*> order;
auto status = could_propagate(input_value, order); auto status = ngraph::could_propagate(input_value, order);
if (status) { if (status) {
for (const auto& node : order) { for (const auto& node : order) {
const auto& rt_info = node->get_rt_info(); const auto& rt_info = node->get_rt_info();
@ -99,8 +98,8 @@ bool ngraph::pass::ConstantFolding::pre_calculated_values_folding(const std::sha
if (status && input_value.get_tensor().has_and_set_bound()) { if (status && input_value.get_tensor().has_and_set_bound()) {
auto input_node = input_value.get_node_shared_ptr(); auto input_node = input_value.get_node_shared_ptr();
auto replacement = std::make_shared<op::Constant>(input_value.get_tensor().get_lower_value()); auto replacement = std::make_shared<ngraph::op::Constant>(input_value.get_tensor().get_lower_value());
if (replacement && !ov::is_type<op::Constant>(input_node)) { if (replacement && !ov::is_type<ngraph::op::Constant>(input_node)) {
if (input_node->get_output_size() == 1) { if (input_node->get_output_size() == 1) {
replacement->set_friendly_name(input_node->get_friendly_name()); replacement->set_friendly_name(input_node->get_friendly_name());
} else { } else {

View File

@ -9,12 +9,11 @@
#include "transformations/convert_precision.hpp" #include "transformations/convert_precision.hpp"
using namespace std; using namespace std;
using namespace ngraph;
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertFP32ToFP16, "ConvertFP32ToFP16", 0); OPENVINO_RTTI_DEFINITION(ov::pass::ConvertFP32ToFP16, "ConvertFP32ToFP16", 0);
bool ngraph::pass::ConvertFP32ToFP16::run_on_function(std::shared_ptr<ngraph::Function> f) { bool ov::pass::ConvertFP32ToFP16::run_on_function(std::shared_ptr<ov::Function> f) {
ngraph::pass::Manager m(get_pass_config()); ov::pass::Manager m(get_pass_config());
m.register_pass<ngraph::pass::ConvertPrecision>(precisions_array{{ngraph::element::f32, ngraph::element::f16}}); m.register_pass<ngraph::pass::ConvertPrecision>(precisions_array{{ngraph::element::f32, ngraph::element::f16}});
m.run_passes(f); m.run_passes(f);
return false; return false;

View File

@ -18,9 +18,6 @@
#include "ngraph/op/util/sub_graph_base.hpp" #include "ngraph/op/util/sub_graph_base.hpp"
#include "perf_counters.hpp" #include "perf_counters.hpp"
using namespace std;
using namespace ngraph;
/* GraphRewrite algorithm: /* GraphRewrite algorithm:
* GraphRewrite processes an input graph in an topological order(i.e. args before users) * GraphRewrite processes an input graph in an topological order(i.e. args before users)
* Given the following graph: Abs2 * Given the following graph: Abs2
@ -33,7 +30,7 @@ using namespace ngraph;
* Note, `Abs2` comes before `Neg3` as `Abs2`'s id = 2 is *less* than `Neg3`'s one (id = 3) * Note, `Abs2` comes before `Neg3` as `Abs2`'s id = 2 is *less* than `Neg3`'s one (id = 3)
* Next, GraphRewrite will invoke matchers passes registered in add_matcher order. * Next, GraphRewrite will invoke matchers passes registered in add_matcher order.
* For example: * For example:
* ngraph::pass::GraphRewrite pass; * ov::pass::GraphRewrite pass;
* pass.add_matcher<m1>(); * pass.add_matcher<m1>();
* pass.add_matcher<m2>(); * pass.add_matcher<m2>();
* pass.add_matcher<m3>(); * pass.add_matcher<m3>();
@ -53,13 +50,13 @@ using namespace ngraph;
* If MatcherPass register more than one node make sure that this nodes are registered in * If MatcherPass register more than one node make sure that this nodes are registered in
* topological order. */ * topological order. */
NGRAPH_RTTI_DEFINITION(ngraph::pass::GraphRewrite, "ngraph::pass::GraphRewrite", 0); NGRAPH_RTTI_DEFINITION(ov::pass::GraphRewrite, "ov::pass::GraphRewrite", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::BackwardGraphRewrite, "ngraph::pass::BackwardGraphRewrite", 0); NGRAPH_RTTI_DEFINITION(ov::pass::BackwardGraphRewrite, "ov::pass::BackwardGraphRewrite", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::MatcherPass, "ngraph::pass::MatcherPass", 0); NGRAPH_RTTI_DEFINITION(ov::pass::MatcherPass, "ov::pass::MatcherPass", 0);
namespace ngraph { namespace ov {
namespace pass { namespace pass {
namespace internal { namespace internal {
PerfCounters& perf_counters_graph_rewrite() { PerfCounters& perf_counters_graph_rewrite() {
@ -68,27 +65,28 @@ PerfCounters& perf_counters_graph_rewrite() {
} }
} // namespace internal } // namespace internal
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ov
bool pass::BackwardGraphRewrite::run_on_function(std::shared_ptr<ngraph::Function> f) { bool ov::pass::BackwardGraphRewrite::run_on_function(std::shared_ptr<ov::Function> f) {
// Initialize execution queue with nodes in topological order // Initialize execution queue with nodes in topological order
deque<std::weak_ptr<Node>> nodes_to_run; std::deque<std::weak_ptr<Node>> nodes_to_run;
for (auto& node : f->get_ordered_ops()) { for (auto& node : f->get_ordered_ops()) {
nodes_to_run.emplace_front(node); nodes_to_run.emplace_front(node);
} }
return apply_matcher_passes(f, std::move(nodes_to_run)); return apply_matcher_passes(f, std::move(nodes_to_run));
} }
bool pass::GraphRewrite::run_on_function(std::shared_ptr<ngraph::Function> f) { bool ov::pass::GraphRewrite::run_on_function(std::shared_ptr<ov::Function> f) {
// Initialize execution queue with nodes in topological order // Initialize execution queue with nodes in topological order
deque<std::weak_ptr<Node>> nodes_to_run; std::deque<std::weak_ptr<Node>> nodes_to_run;
for (auto& node : f->get_ordered_ops()) { for (auto& node : f->get_ordered_ops()) {
nodes_to_run.emplace_back(node); nodes_to_run.emplace_back(node);
} }
return apply_matcher_passes(f, std::move(nodes_to_run)); return apply_matcher_passes(f, std::move(nodes_to_run));
} }
bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f, deque<std::weak_ptr<Node>> nodes_to_run) { bool ov::pass::GraphRewrite::apply_matcher_passes(std::shared_ptr<Function> f,
std::deque<std::weak_ptr<Node>> nodes_to_run) {
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, "pass::GraphRewrite::run_on_function"); OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, "pass::GraphRewrite::run_on_function");
bool rewritten = false; bool rewritten = false;
@ -111,7 +109,7 @@ bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f, deque<std:
auto root = matcher->get_pattern_value().get_node_shared_ptr(); auto root = matcher->get_pattern_value().get_node_shared_ptr();
// pattern::op::AnyOutput operation automatically appends for multi output operations inside // pattern::op::AnyOutput operation automatically appends for multi output operations inside
// Matcher and to gen actual root node we need to take it's parent. // Matcher and to gen actual root node we need to take it's parent.
if (auto any_type = dynamic_pointer_cast<pattern::op::AnyOutput>(root)) { if (auto any_type = std::dynamic_pointer_cast<pattern::op::AnyOutput>(root)) {
root = any_type->input_value(0).get_node_shared_ptr(); root = any_type->input_value(0).get_node_shared_ptr();
} }
@ -119,8 +117,8 @@ bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f, deque<std:
// it's type // it's type
// and use it in unordered_map as key for fast MatcherPass search. Otherwise type is unknown // and use it in unordered_map as key for fast MatcherPass search. Otherwise type is unknown
// and default algorithm is used. // and default algorithm is used.
if (auto p = dynamic_pointer_cast<pattern::op::Pattern>(root)) { if (auto p = std::dynamic_pointer_cast<pattern::op::Pattern>(root)) {
if (auto any_type = dynamic_pointer_cast<pattern::op::WrapType>(p)) { if (auto any_type = std::dynamic_pointer_cast<pattern::op::WrapType>(p)) {
for (const auto& root_type_info : any_type->get_wrapped_types()) { for (const auto& root_type_info : any_type->get_wrapped_types()) {
type_to_matcher[root_type_info].push_back(matcher_index); type_to_matcher[root_type_info].push_back(matcher_index);
} }
@ -180,7 +178,7 @@ bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f, deque<std:
continue; continue;
// Recursive apply Matchers for sub-graph based nodes // Recursive apply Matchers for sub-graph based nodes
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) { if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node)) {
if (auto sub_graph = sub_graph_node->get_function()) { if (auto sub_graph = sub_graph_node->get_function()) {
run_on_function(sub_graph); run_on_function(sub_graph);
} }
@ -236,7 +234,7 @@ bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f, deque<std:
return rewritten; return rewritten;
} }
void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m, void ov::pass::GraphRewrite::add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const graph_rewrite_callback& callback, const graph_rewrite_callback& callback,
const PassPropertyMask& property) { const PassPropertyMask& property) {
m_matchers.push_back(std::make_shared<MatcherPass>( m_matchers.push_back(std::make_shared<MatcherPass>(
@ -258,7 +256,8 @@ void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
property)); property));
} }
void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m, const graph_rewrite_callback& callback) { void ov::pass::GraphRewrite::add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const graph_rewrite_callback& callback) {
NGRAPH_SUPPRESS_DEPRECATED_START NGRAPH_SUPPRESS_DEPRECATED_START
// TODO: before deprecate this function, by default expect the // TODO: before deprecate this function, by default expect the
// callback require static shape. // callback require static shape.
@ -266,7 +265,7 @@ void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m, cons
NGRAPH_SUPPRESS_DEPRECATED_END NGRAPH_SUPPRESS_DEPRECATED_END
} }
void pass::GraphRewrite::set_pass_config(const std::shared_ptr<PassConfig>& rhs) { void ov::pass::GraphRewrite::set_pass_config(const std::shared_ptr<PassConfig>& rhs) {
auto pass_config = get_pass_config(); auto pass_config = get_pass_config();
// We have to preserve disabled passes because in case when we register matchers inside // We have to preserve disabled passes because in case when we register matchers inside
// GraphRewrite c-tor we work with local PassConfig instance. // GraphRewrite c-tor we work with local PassConfig instance.
@ -293,8 +292,8 @@ void pass::GraphRewrite::set_pass_config(const std::shared_ptr<PassConfig>& rhs)
} }
} }
void pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m, void ov::pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback, const ov::recurrent_graph_rewrite_callback& callback,
const PassPropertyMask& property) { const PassPropertyMask& property) {
m_matchers.push_back(std::make_shared<MatcherPass>( m_matchers.push_back(std::make_shared<MatcherPass>(
"Recurrent matcher", "Recurrent matcher",
@ -310,24 +309,24 @@ void pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::Rec
property)); property));
} }
void pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m, void ov::pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback) { const ov::recurrent_graph_rewrite_callback& callback) {
// TODO: before deprecate this function, by default expect the // TODO: before deprecate this function, by default expect the
// callback require static shape. // callback require static shape.
add_matcher(m, callback, {PassProperty::REQUIRE_STATIC_SHAPE}); add_matcher(m, callback, {PassProperty::REQUIRE_STATIC_SHAPE});
} }
bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f) { bool ov::pass::RecurrentGraphRewrite::run_on_function(std::shared_ptr<Function> f) {
bool changed = false; bool changed = false;
size_t i = 0; size_t i = 0;
// This check is very expensive and is only needed for experimental features, so we will hide // This check is very expensive and is only needed for experimental features, so we will hide
// it behind an environment variable for now. TODO: Find a less expensive way to handle this. // it behind an environment variable for now. TODO: Find a less expensive way to handle this.
static bool s_rerun_dynamic_check = getenv_bool("NGRAPH_GRAPH_REWRITE_RERUN_DYNAMIC_CHECK"); static bool s_rerun_dynamic_check = ngraph::getenv_bool("NGRAPH_GRAPH_REWRITE_RERUN_DYNAMIC_CHECK");
auto run_matchers = [&]() -> bool { auto run_matchers = [&]() -> bool {
bool is_dyn_func = s_rerun_dynamic_check && f->is_dynamic(); bool is_dyn_func = s_rerun_dynamic_check && f->is_dynamic();
for (auto node : f->get_ops()) { for (const auto& node : f->get_ops()) {
for (auto& m_pass : m_matchers) { for (auto& m_pass : m_matchers) {
if (is_dyn_func && m_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE)) { if (is_dyn_func && m_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE)) {
NGRAPH_DEBUG << "matcher callback requires static shape but the " NGRAPH_DEBUG << "matcher callback requires static shape but the "
@ -356,8 +355,8 @@ bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f) {
return changed; return changed;
} }
void ngraph::pass::MatcherPass::register_matcher(const std::shared_ptr<ngraph::pattern::Matcher>& m, void ov::pass::MatcherPass::register_matcher(const std::shared_ptr<ov::pass::pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback, const ov::graph_rewrite_callback& callback,
const PassPropertyMask& property) { const PassPropertyMask& property) {
set_name(m->get_name()); set_name(m->get_name());
set_property(property, true); set_property(property, true);
@ -376,7 +375,7 @@ void ngraph::pass::MatcherPass::register_matcher(const std::shared_ptr<ngraph::p
}; };
} }
bool ngraph::pass::MatcherPass::apply(std::shared_ptr<ngraph::Node> node) { bool ov::pass::MatcherPass::apply(std::shared_ptr<ov::Node> node) {
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, pass::internal::perf_counters_graph_rewrite()[get_type_info()]); OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, pass::internal::perf_counters_graph_rewrite()[get_type_info()]);
m_new_nodes.clear(); m_new_nodes.clear();
if (m_handler) if (m_handler)

View File

@ -12,13 +12,12 @@
#include <ngraph/rt_info.hpp> #include <ngraph/rt_info.hpp>
#include <ngraph/variant.hpp> #include <ngraph/variant.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::LowLatency2, "LowLatency2", 0); NGRAPH_RTTI_DEFINITION(ov::pass::LowLatency2, "LowLatency2", 0);
NGRAPH_SUPPRESS_DEPRECATED_START NGRAPH_SUPPRESS_DEPRECATED_START
NGRAPH_RTTI_DEFINITION(ngraph::pass::LowLatency, "LowLatency", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::LowLatency, "LowLatency", 0);
using namespace std; using namespace std;
using namespace ngraph;
namespace { namespace {
string generate_variable_name(const string& op_name, const string& param_name, int variable_idx) { string generate_variable_name(const string& op_name, const string& param_name, int variable_idx) {
@ -27,8 +26,8 @@ string generate_variable_name(const string& op_name, const string& param_name, i
} // namespace } // namespace
ngraph::pass::LowLatency::LowLatency() { ngraph::pass::LowLatency::LowLatency() {
auto tensor_iterator = ngraph::pattern::wrap_type<opset6::TensorIterator, opset6::Loop>(); auto tensor_iterator = ov::pass::pattern::wrap_type<opset6::TensorIterator, opset6::Loop>();
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) { ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
const auto& sub_graph_op = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(m.get_match_root()); const auto& sub_graph_op = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(m.get_match_root());
if (!sub_graph_op) { if (!sub_graph_op) {
return false; return false;
@ -38,7 +37,7 @@ ngraph::pass::LowLatency::LowLatency() {
const auto& trip_count = std::dynamic_pointer_cast<opset6::Constant>(loop->get_input_node_shared_ptr(0)); const auto& trip_count = std::dynamic_pointer_cast<opset6::Constant>(loop->get_input_node_shared_ptr(0));
const auto& num_iter = loop->get_num_iterations(); const auto& num_iter = loop->get_num_iterations();
if (trip_count && num_iter > 0 && trip_count->get_output_target_inputs(0).size() == 1) { if (trip_count && num_iter > 0 && trip_count->get_output_target_inputs(0).size() == 1) {
auto single_iter = std::make_shared<opset6::Constant>(ngraph::element::i64, Shape{}, 1); auto single_iter = std::make_shared<opset6::Constant>(ov::element::i64, Shape{}, 1);
replace_node(trip_count, single_iter); replace_node(trip_count, single_iter);
} else { } else {
// count of iterations is dynamic; // count of iterations is dynamic;
@ -47,7 +46,7 @@ ngraph::pass::LowLatency::LowLatency() {
} }
// Mark the TI layer to be unrolled. Enable unconditional ti unrolling for all plugins. // Mark the TI layer to be unrolled. Enable unconditional ti unrolling for all plugins.
auto& rt_info = sub_graph_op->get_rt_info(); auto& rt_info = sub_graph_op->get_rt_info();
rt_info["UNROLL_TI"] = std::make_shared<ngraph::VariantWrapper<int64_t>>(1); rt_info["UNROLL_TI"] = std::make_shared<ov::VariantWrapper<int64_t>>(1);
int64_t variable_id = 0; int64_t variable_id = 0;
std::vector<std::shared_ptr<ngraph::op::Sink>> assigns; std::vector<std::shared_ptr<ngraph::op::Sink>> assigns;
@ -87,13 +86,14 @@ ngraph::pass::LowLatency::LowLatency() {
return false; return false;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(tensor_iterator, "LowLatency"); auto m = std::make_shared<ov::pass::pattern::Matcher>(tensor_iterator, "LowLatency");
register_matcher(m, callback); register_matcher(m, callback);
} }
NGRAPH_SUPPRESS_DEPRECATED_END NGRAPH_SUPPRESS_DEPRECATED_END
void UnrollSingleIteration(const shared_ptr<op::util::SubGraphOp>& sub_graph_op, const shared_ptr<Function>& outer_f) { void UnrollSingleIteration(const shared_ptr<ngraph::op::util::SubGraphOp>& sub_graph_op,
using namespace opset7; const shared_ptr<ov::Function>& outer_f) {
using namespace ngraph::opset7;
const auto& params = sub_graph_op->get_function()->get_parameters(); const auto& params = sub_graph_op->get_function()->get_parameters();
const auto& results = sub_graph_op->get_function()->get_results(); const auto& results = sub_graph_op->get_function()->get_results();
@ -109,7 +109,7 @@ void UnrollSingleIteration(const shared_ptr<op::util::SubGraphOp>& sub_graph_op,
// before: TI [...-> Layer1 -> Result -> output] -> Layer2 -> ... // before: TI [...-> Layer1 -> Result -> output] -> Layer2 -> ...
// after: ...-> Layer1 -> Layer2 -> ... // after: ...-> Layer1 -> Layer2 -> ...
NodeVector new_ops; ov::NodeVector new_ops;
for (const auto& out : sub_graph_op->get_output_descriptions()) { for (const auto& out : sub_graph_op->get_output_descriptions()) {
const auto& connect_to = results.at(out->m_body_value_index)->get_input_source_output(0); const auto& connect_to = results.at(out->m_body_value_index)->get_input_source_output(0);
for (auto& input_to : sub_graph_op->output(out->m_output_index).get_target_inputs()) { for (auto& input_to : sub_graph_op->output(out->m_output_index).get_target_inputs()) {
@ -120,7 +120,7 @@ void UnrollSingleIteration(const shared_ptr<op::util::SubGraphOp>& sub_graph_op,
// IECompatibility: insert identity (Unsqueeze + Squeeze) to store the TensorIterator // IECompatibility: insert identity (Unsqueeze + Squeeze) to store the TensorIterator
// output names // output names
auto axis_1 = Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1}); auto axis_1 = Constant::create(ov::element::i64, ngraph::Shape{1}, {1});
auto identity_1 = std::make_shared<Unsqueeze>(connect_to, axis_1); auto identity_1 = std::make_shared<Unsqueeze>(connect_to, axis_1);
auto identity_2 = std::make_shared<Squeeze>(identity_1, axis_1); auto identity_2 = std::make_shared<Squeeze>(identity_1, axis_1);
identity_2->set_friendly_name(out_name); identity_2->set_friendly_name(out_name);
@ -135,36 +135,38 @@ void UnrollSingleIteration(const shared_ptr<op::util::SubGraphOp>& sub_graph_op,
ngraph::copy_runtime_info(sub_graph_op, new_ops); ngraph::copy_runtime_info(sub_graph_op, new_ops);
} }
Output<Node> create_init_subgraph(const shared_ptr<op::util::SubGraphOp>& sub_graph_op, const Output<Node>& in_node) { ngraph::Output<ngraph::Node> create_init_subgraph(const shared_ptr<ngraph::op::util::SubGraphOp>& sub_graph_op,
using namespace opset7; const ngraph::Output<ngraph::Node>& in_node) {
using namespace ngraph::opset7;
auto const_zero = make_shared<Constant>(in_node.get_element_type(), Shape{1}, 0); auto const_zero = make_shared<Constant>(in_node.get_element_type(), ngraph::Shape{1}, 0);
auto shape_of = make_shared<ShapeOf>(in_node); auto shape_of = make_shared<ShapeOf>(in_node);
auto broadcast = make_shared<Broadcast>(const_zero, shape_of); auto broadcast = make_shared<Broadcast>(const_zero, shape_of);
copy_runtime_info(sub_graph_op, {const_zero, shape_of, broadcast}); copy_runtime_info(sub_graph_op, {const_zero, shape_of, broadcast});
return broadcast->output(0); return broadcast->output(0);
} }
bool pass::LowLatency2::run_on_function(shared_ptr<Function> f) { bool ov::pass::LowLatency2::run_on_function(shared_ptr<Function> f) {
using namespace opset7; using namespace ngraph::opset7;
SinkVector assigns; ngraph::SinkVector assigns;
for (const auto& op : f->get_ordered_ops()) { for (const auto& op : f->get_ordered_ops()) {
if (const auto& sub_graph_op = dynamic_pointer_cast<op::util::SubGraphOp>(op)) { if (const auto& sub_graph_op = dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(op)) {
int64_t variable_id = 0; int64_t variable_id = 0;
const auto& func = sub_graph_op->get_function(); const auto& func = sub_graph_op->get_function();
const auto& params = func->get_parameters(); const auto& params = func->get_parameters();
for (const auto& in : sub_graph_op->get_input_descriptions()) { for (const auto& in : sub_graph_op->get_input_descriptions()) {
// Process all back edges // Process all back edges
if (const auto& merged_in = dynamic_pointer_cast<op::util::SubGraphOp::MergedInputDescription>(in)) { if (const auto& merged_in =
dynamic_pointer_cast<ngraph::op::util::SubGraphOp::MergedInputDescription>(in)) {
// create new Variable // create new Variable
const string& param_name = params.at(merged_in->m_body_parameter_index)->get_friendly_name(); const string& param_name = params.at(merged_in->m_body_parameter_index)->get_friendly_name();
const string& var_name = const string& var_name =
generate_variable_name(sub_graph_op->get_friendly_name(), param_name, variable_id); generate_variable_name(sub_graph_op->get_friendly_name(), param_name, variable_id);
const auto& input = sub_graph_op->input(merged_in->m_input_index); const auto& input = sub_graph_op->input(merged_in->m_input_index);
if (std::dynamic_pointer_cast<op::ReadValueBase>(input.get_source_output().get_node_shared_ptr()) != if (std::dynamic_pointer_cast<ngraph::op::ReadValueBase>(
nullptr) { input.get_source_output().get_node_shared_ptr()) != nullptr) {
NGRAPH_DEBUG << "LowLatency2 transformation cannot be applied because the " NGRAPH_DEBUG << "LowLatency2 transformation cannot be applied because the "
<< "ReadValue node is already an input to the TensorIterator." << "ReadValue node is already an input to the TensorIterator."
<< "LowLatency2 transformation may have already been applied, please " << "LowLatency2 transformation may have already been applied, please "
@ -175,7 +177,7 @@ bool pass::LowLatency2::run_on_function(shared_ptr<Function> f) {
const auto& param = const auto& param =
sub_graph_op->get_function()->get_parameters().at(merged_in->m_body_parameter_index); sub_graph_op->get_function()->get_parameters().at(merged_in->m_body_parameter_index);
for (const auto& in_to : param->output(0).get_target_inputs()) { for (const auto& in_to : param->output(0).get_target_inputs()) {
if (dynamic_cast<op::ReadValueBase*>(in_to.get_node()) != nullptr) { if (dynamic_cast<ngraph::op::ReadValueBase*>(in_to.get_node()) != nullptr) {
NGRAPH_DEBUG << "LowLatency2 transformation cannot be applied because the " NGRAPH_DEBUG << "LowLatency2 transformation cannot be applied because the "
<< "ReadValue node is already inside the TensorIterator. " << "ReadValue node is already inside the TensorIterator. "
<< "LowLatency transformation may have been applied, please do " << "LowLatency transformation may have been applied, please do "
@ -184,8 +186,8 @@ bool pass::LowLatency2::run_on_function(shared_ptr<Function> f) {
} }
} }
VariableInfo var_info{PartialShape::dynamic(), element::dynamic, var_name}; ngraph::VariableInfo var_info{PartialShape::dynamic(), element::dynamic, var_name};
auto variable = make_shared<Variable>(var_info); auto variable = make_shared<ngraph::Variable>(var_info);
// insert ReadValue // insert ReadValue
// Layers -> [new op: ReadValue] -> Subgraph operation // Layers -> [new op: ReadValue] -> Subgraph operation
@ -204,10 +206,10 @@ bool pass::LowLatency2::run_on_function(shared_ptr<Function> f) {
// ---> Layers -> ... // ---> Layers -> ...
*/ */
const auto& out_desc = sub_graph_op->get_output_descriptions(); const auto& out_desc = sub_graph_op->get_output_descriptions();
bool is_output_exist = bool is_output_exist = std::any_of(
std::any_of(out_desc.begin(), out_desc.begin(),
out_desc.end(), out_desc.end(),
[&merged_in](const std::shared_ptr<op::util::SubGraphOp::OutputDescription>& out) { [&merged_in](const std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>& out) {
return out->m_body_value_index == merged_in->m_body_value_index; return out->m_body_value_index == merged_in->m_body_value_index;
}); });
// Create new output if it doesn't exist. // Create new output if it doesn't exist.
@ -217,7 +219,7 @@ bool pass::LowLatency2::run_on_function(shared_ptr<Function> f) {
for (const auto& out : sub_graph_op->get_output_descriptions()) { for (const auto& out : sub_graph_op->get_output_descriptions()) {
if (out->m_body_value_index == merged_in->m_body_value_index) { if (out->m_body_value_index == merged_in->m_body_value_index) {
auto assign = make_shared<Assign>(sub_graph_op->output(out->m_output_index), variable); auto assign = make_shared<Assign>(sub_graph_op->output(out->m_output_index), variable);
ngraph::copy_runtime_info(sub_graph_op, assign); copy_runtime_info(sub_graph_op, assign);
// control dependency so that ReadValue is processed before Assign // control dependency so that ReadValue is processed before Assign
assign->add_control_dependency(read_value); assign->add_control_dependency(read_value);
assigns.emplace_back(assign); assigns.emplace_back(assign);

View File

@ -24,9 +24,8 @@
#include "perf_counters.hpp" #include "perf_counters.hpp"
using namespace std; using namespace std;
using namespace ngraph;
namespace ngraph { namespace ov {
namespace pass { namespace pass {
namespace internal { namespace internal {
PerfCounters& perf_counters() { PerfCounters& perf_counters() {
@ -35,25 +34,25 @@ PerfCounters& perf_counters() {
} }
} // namespace internal } // namespace internal
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ov
pass::Manager::Manager() ov::pass::Manager::Manager()
: m_pass_config(std::make_shared<PassConfig>()), : m_pass_config(std::make_shared<PassConfig>()),
m_visualize(getenv_bool("NGRAPH_ENABLE_VISUALIZE_TRACING")) {} m_visualize(ngraph::getenv_bool("NGRAPH_ENABLE_VISUALIZE_TRACING")) {}
pass::Manager::~Manager() {} ov::pass::Manager::~Manager() = default;
pass::Manager::Manager(std::shared_ptr<ngraph::pass::PassConfig> pass_config) : m_pass_config(std::move(pass_config)) {} ov::pass::Manager::Manager(std::shared_ptr<ov::pass::PassConfig> pass_config) : m_pass_config(std::move(pass_config)) {}
void pass::Manager::run_passes(shared_ptr<Function> func) { void ov::pass::Manager::run_passes(shared_ptr<ov::Function> func) {
NGRAPH_SUPPRESS_DEPRECATED_START NGRAPH_SUPPRESS_DEPRECATED_START
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, "pass::Manager::run_passes"); OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, "pass::Manager::run_passes");
static bool profile_enabled = getenv_bool("NGRAPH_PROFILE_PASS_ENABLE"); static bool profile_enabled = ngraph::getenv_bool("NGRAPH_PROFILE_PASS_ENABLE");
size_t index = 0; size_t index = 0;
stopwatch pass_timer; ngraph::stopwatch pass_timer;
stopwatch overall_timer; ngraph::stopwatch overall_timer;
overall_timer.start(); overall_timer.start();
bool function_changed = false; bool function_changed = false;
for (auto& pass : m_pass_list) { for (auto& pass : m_pass_list) {
@ -96,13 +95,13 @@ void pass::Manager::run_passes(shared_ptr<Function> func) {
} else { } else {
function_changed = function_pass->run_on_function(func); function_changed = function_pass->run_on_function(func);
} }
} else if (auto node_pass = dynamic_pointer_cast<NodePass>(pass)) { } else if (auto node_pass = dynamic_pointer_cast<ngraph::pass::NodePass>(pass)) {
if (node_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && func->is_dynamic()) { if (node_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && func->is_dynamic()) {
NGRAPH_DEBUG << "Pass " << pass->get_name() << " requires static shape but the " NGRAPH_DEBUG << "Pass " << pass->get_name() << " requires static shape but the "
<< "function is dynamic. Skipping this transformation"; << "function is dynamic. Skipping this transformation";
continue; continue;
} }
for (shared_ptr<Node> n : func->get_ops()) { for (const shared_ptr<Node>& n : func->get_ops()) {
function_changed |= node_pass->run_on_node(n); function_changed |= node_pass->run_on_node(n);
} }
} }
@ -115,7 +114,7 @@ void pass::Manager::run_passes(shared_ptr<Function> func) {
auto base_filename = func->get_name() + std::string("_") + index_str + std::string("_") + pass->get_name(); auto base_filename = func->get_name() + std::string("_") + index_str + std::string("_") + pass->get_name();
if (m_visualize) { if (m_visualize) {
static const string format = getenv_string("NGRAPH_VISUALIZE_TRACING_FORMAT"); static const string format = ngraph::getenv_string("NGRAPH_VISUALIZE_TRACING_FORMAT");
auto file_ext = format.empty() ? "svg" : format; auto file_ext = format.empty() ? "svg" : format;
pass::VisualizeTree vt(base_filename + std::string(".") + file_ext); pass::VisualizeTree vt(base_filename + std::string(".") + file_ext);
vt.run_on_function(func); vt.run_on_function(func);

View File

@ -7,21 +7,20 @@
# include <cxxabi.h> # include <cxxabi.h>
#endif #endif
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/pass.hpp" #include "ngraph/pass/pass.hpp"
#include "openvino/pass/manager.hpp"
using namespace std; using namespace std;
using namespace ngraph;
NGRAPH_RTTI_DEFINITION(ngraph::pass::FunctionPass, "ngraph::pass::FunctionPass", 0); OPENVINO_RTTI_DEFINITION(ov::pass::FunctionPass, "ov::pass::FunctionPass", 0);
pass::PassBase::PassBase() : m_property{all_pass_property_off}, m_pass_config(std::make_shared<PassConfig>()) {} ov::pass::PassBase::PassBase() : m_property(), m_pass_config(std::make_shared<PassConfig>()) {}
bool pass::PassBase::get_property(const PassPropertyMask& prop) const { bool ov::pass::PassBase::get_property(const PassPropertyMask& prop) const {
return m_property.is_set(prop); return m_property.is_set(prop);
} }
void pass::PassBase::set_property(const PassPropertyMask& prop, bool value) { void ov::pass::PassBase::set_property(const PassPropertyMask& prop, bool value) {
if (value) { if (value) {
m_property.set(prop); m_property.set(prop);
} else { } else {
@ -29,7 +28,7 @@ void pass::PassBase::set_property(const PassPropertyMask& prop, bool value) {
} }
} }
std::string pass::PassBase::get_name() const { std::string ov::pass::PassBase::get_name() const {
if (m_name.empty()) { if (m_name.empty()) {
const PassBase* p = this; const PassBase* p = this;
std::string pass_name = typeid(*p).name(); std::string pass_name = typeid(*p).name();
@ -43,16 +42,16 @@ std::string pass::PassBase::get_name() const {
} }
} }
void pass::PassBase::set_callback(const param_callback& callback) { void ov::pass::PassBase::set_callback(const param_callback& callback) {
m_pass_config->set_callback(callback); m_pass_config->set_callback(callback);
} }
// The symbols are requiered to be in cpp file to workaround RTTI issue on Android LLVM // The symbols are requiered to be in cpp file to workaround RTTI issue on Android LLVM
pass::FunctionPass::~FunctionPass() {} ov::pass::FunctionPass::~FunctionPass() = default;
NGRAPH_SUPPRESS_DEPRECATED_START OPENVINO_SUPPRESS_DEPRECATED_START
NGRAPH_RTTI_DEFINITION(ngraph::pass::NodePass, "ngraph::pass::NodePass", 0); OPENVINO_RTTI_DEFINITION(ngraph::pass::NodePass, "ngraph::pass::NodePass", 0);
pass::NodePass::~NodePass() {} ngraph::pass::NodePass::~NodePass() = default;

View File

@ -2,11 +2,9 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include "ngraph/pass/pass_config.hpp" #include "openvino/pass/pass_config.hpp"
using namespace ngraph; ov::pass::param_callback ov::pass::PassConfig::get_callback(const DiscreteTypeInfo& type_info) const {
pass::param_callback pass::PassConfig::get_callback(const DiscreteTypeInfo& type_info) const {
const auto& it = m_callback_map.find(type_info); const auto& it = m_callback_map.find(type_info);
if (it != m_callback_map.end()) { if (it != m_callback_map.end()) {
return it->second; return it->second;
@ -15,17 +13,17 @@ pass::param_callback pass::PassConfig::get_callback(const DiscreteTypeInfo& type
} }
} }
void pass::PassConfig::enable(const ngraph::DiscreteTypeInfo& type_info) { void ov::pass::PassConfig::enable(const ngraph::DiscreteTypeInfo& type_info) {
m_disabled.erase(type_info); m_disabled.erase(type_info);
m_enabled.insert(type_info); m_enabled.insert(type_info);
} }
void pass::PassConfig::disable(const ngraph::DiscreteTypeInfo& type_info) { void ov::pass::PassConfig::disable(const ngraph::DiscreteTypeInfo& type_info) {
m_enabled.erase(type_info); m_enabled.erase(type_info);
m_disabled.insert(type_info); m_disabled.insert(type_info);
} }
void pass::PassConfig::add_disabled_passes(const PassConfig& rhs) { void ov::pass::PassConfig::add_disabled_passes(const PassConfig& rhs) {
for (const auto& pass : rhs.m_disabled) { for (const auto& pass : rhs.m_disabled) {
if (is_enabled(pass)) if (is_enabled(pass))
continue; continue;

View File

@ -3,7 +3,7 @@
// //
#include "perf_counters.hpp" #include "perf_counters.hpp"
namespace ngraph { namespace ov {
namespace pass { namespace pass {
openvino::itt::handle_t PerfCounters::operator[](::ngraph::Node::type_info_t const& type_inf) { openvino::itt::handle_t PerfCounters::operator[](::ngraph::Node::type_info_t const& type_inf) {
std::lock_guard<std::mutex> guard(m_mutex); std::lock_guard<std::mutex> guard(m_mutex);
@ -13,4 +13,4 @@ openvino::itt::handle_t PerfCounters::operator[](::ngraph::Node::type_info_t con
return m_counters[&type_inf] = openvino::itt::handle(type_inf.name); return m_counters[&type_inf] = openvino::itt::handle(type_inf.name);
} }
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ov

View File

@ -7,7 +7,7 @@
#include <ngraph/node.hpp> #include <ngraph/node.hpp>
#include <unordered_map> #include <unordered_map>
namespace ngraph { namespace ov {
namespace pass { namespace pass {
class PerfCounters { class PerfCounters {
PerfCounters(PerfCounters const&) = delete; PerfCounters(PerfCounters const&) = delete;
@ -27,4 +27,4 @@ private:
counters_map m_counters; counters_map m_counters;
}; };
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ov

View File

@ -2,16 +2,16 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include "ngraph/pass/validate.hpp" #include "openvino/pass/validate.hpp"
#include "itt.hpp" #include "itt.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
using namespace ngraph; using namespace ngraph;
NGRAPH_RTTI_DEFINITION(ngraph::pass::Validate, "ngraph::pass::Validate", 0); OPENVINO_RTTI_DEFINITION(ov::pass::Validate, "ov::pass::Validate", 0);
bool pass::Validate::run_on_function(std::shared_ptr<Function> f) { bool ov::pass::Validate::run_on_function(std::shared_ptr<Function> f) {
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
return false; return false;
} }

View File

@ -13,7 +13,8 @@
#include "ngraph/op/parameter.hpp" #include "ngraph/op/parameter.hpp"
#include "ngraph/op/util/op_types.hpp" #include "ngraph/op/util/op_types.hpp"
namespace ngraph { namespace ov {
namespace pass {
namespace pattern { namespace pattern {
MatcherState::MatcherState(Matcher* matcher) MatcherState::MatcherState(Matcher* matcher)
: m_matcher(matcher), : m_matcher(matcher),
@ -88,7 +89,7 @@ bool Matcher::is_contained_match(const NodeVector& exclusions, bool ignore_unuse
NGRAPH_SUPPRESS_DEPRECATED_START NGRAPH_SUPPRESS_DEPRECATED_START
if (exclusions.empty()) { if (exclusions.empty()) {
NodeVector label_exclusions; NodeVector label_exclusions;
for (auto entry : m_pattern_map) { for (const auto& entry : m_pattern_map) {
// leaf label // leaf label
if (entry.first->get_input_size() == 0) { if (entry.first->get_input_size() == 0) {
label_exclusions.push_back(entry.second.get_node_shared_ptr()); label_exclusions.push_back(entry.second.get_node_shared_ptr());
@ -108,7 +109,7 @@ bool Matcher::match_value(const ngraph::Output<Node>& pattern_value, const ngrap
// This env var allows one to specify node name patterns to abort pattern matching // This env var allows one to specify node name patterns to abort pattern matching
// at particular nodes. The upshot is that one can quickly zero in on an offending // at particular nodes. The upshot is that one can quickly zero in on an offending
// fusion by disabling individual fusions or optimizations that use Matcher. // fusion by disabling individual fusions or optimizations that use Matcher.
static const std::string node_skip_cregex = getenv_string("NGRAPH_FAIL_MATCH_AT"); static const std::string node_skip_cregex = ngraph::getenv_string("NGRAPH_FAIL_MATCH_AT");
if (!node_skip_cregex.empty()) { if (!node_skip_cregex.empty()) {
static const std::regex node_skip_regex(node_skip_cregex); static const std::regex node_skip_regex(node_skip_cregex);
if (std::regex_match(graph_node->get_name(), node_skip_regex)) { if (std::regex_match(graph_node->get_name(), node_skip_regex)) {
@ -201,7 +202,7 @@ void Matcher::clear_state() {
namespace { namespace {
std::set<std::shared_ptr<Node>> as_node_set(const std::set<std::shared_ptr<op::Label>>& label_set) { std::set<std::shared_ptr<Node>> as_node_set(const std::set<std::shared_ptr<op::Label>>& label_set) {
std::set<std::shared_ptr<Node>> result; std::set<std::shared_ptr<Node>> result;
for (auto label : label_set) { for (const auto& label : label_set) {
result.insert(label); result.insert(label);
} }
return result; return result;
@ -230,7 +231,7 @@ bool RecurrentMatcher::match(Output<Node> graph) {
graph = m.get_pattern_value_map()[m_recurrent_pattern]; graph = m.get_pattern_value_map()[m_recurrent_pattern];
// copy bound nodes for the current pattern graph into a global matches map // copy bound nodes for the current pattern graph into a global matches map
for (auto cur_match : m.get_pattern_value_map()) { for (const auto& cur_match : m.get_pattern_value_map()) {
m_matches[cur_match.first].push_back(cur_match.second); m_matches[cur_match.first].push_back(cur_match.second);
} }
@ -238,7 +239,7 @@ bool RecurrentMatcher::match(Output<Node> graph) {
// from the current match. Only bound nodes whose labels are in // from the current match. Only bound nodes whose labels are in
// correlated_patterns are pre-populated. Skip other labels are // correlated_patterns are pre-populated. Skip other labels are
// unbounded by default // unbounded by default
for (auto cor_pat : m_correlated_patterns) { for (const auto& cor_pat : m_correlated_patterns) {
previous_matches[cor_pat] = m.get_pattern_value_map()[cor_pat]; previous_matches[cor_pat] = m.get_pattern_value_map()[cor_pat];
} }
m = m_repeat; m = m_repeat;
@ -251,4 +252,5 @@ bool RecurrentMatcher::match(Output<Node> graph) {
return matched; return matched;
} }
} // namespace pattern } // namespace pattern
} // namespace ngraph } // namespace pass
} // namespace ov

View File

@ -7,15 +7,14 @@
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
using namespace std; using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::Any::type_info; constexpr ov::NodeTypeInfo ov::pass::pattern::op::Any::type_info;
const NodeTypeInfo& pattern::op::Any::get_type_info() const { const ov::NodeTypeInfo& ov::pass::pattern::op::Any::get_type_info() const {
return type_info; return type_info;
} }
bool pattern::op::Any::match_value(Matcher* matcher, bool ov::pass::pattern::op::Any::match_value(Matcher* matcher,
const Output<Node>& pattern_value, const Output<Node>& pattern_value,
const Output<Node>& graph_value) { const Output<Node>& graph_value) {
matcher->add_node(graph_value); matcher->add_node(graph_value);

View File

@ -7,20 +7,19 @@
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
using namespace std; using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::AnyOf::type_info; constexpr ov::NodeTypeInfo ov::pass::pattern::op::AnyOf::type_info;
const NodeTypeInfo& pattern::op::AnyOf::get_type_info() const { const ov::NodeTypeInfo& ov::pass::pattern::op::AnyOf::get_type_info() const {
return type_info; return type_info;
} }
bool pattern::op::AnyOf::match_value(Matcher* matcher, bool ov::pass::pattern::op::AnyOf::match_value(Matcher* matcher,
const Output<Node>& pattern_value, const Output<Node>& pattern_value,
const Output<Node>& graph_value) { const Output<Node>& graph_value) {
matcher->add_node(graph_value); matcher->add_node(graph_value);
return m_predicate(graph_value) && ([&]() { return m_predicate(graph_value) && ([&]() {
for (auto arg : graph_value.get_node_shared_ptr()->input_values()) { for (const auto& arg : graph_value.get_node_shared_ptr()->input_values()) {
auto saved = matcher->start_match(); auto saved = matcher->start_match();
if (matcher->match_value(input_value(0), arg)) { if (matcher->match_value(input_value(0), arg)) {
return saved.finish(true); return saved.finish(true);

View File

@ -7,15 +7,14 @@
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
using namespace std; using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::AnyOutput::type_info; constexpr ov::NodeTypeInfo ov::pass::pattern::op::AnyOutput::type_info;
const NodeTypeInfo& pattern::op::AnyOutput::get_type_info() const { const ov::NodeTypeInfo& ov::pass::pattern::op::AnyOutput::get_type_info() const {
return type_info; return type_info;
} }
bool pattern::op::AnyOutput::match_value(Matcher* matcher, bool ov::pass::pattern::op::AnyOutput::match_value(Matcher* matcher,
const Output<Node>& pattern_value, const Output<Node>& pattern_value,
const Output<Node>& graph_value) { const Output<Node>& graph_value) {
return input_value(0).get_node()->match_node(matcher, graph_value); return input_value(0).get_node()->match_node(matcher, graph_value);

View File

@ -9,15 +9,14 @@
#include "ngraph/pattern/op/true.hpp" #include "ngraph/pattern/op/true.hpp"
using namespace std; using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::Label::type_info; constexpr ov::NodeTypeInfo ov::pass::pattern::op::Label::type_info;
const NodeTypeInfo& pattern::op::Label::get_type_info() const { const ov::NodeTypeInfo& ov::pass::pattern::op::Label::get_type_info() const {
return type_info; return type_info;
} }
Output<Node> pattern::op::Label::wrap_values(const OutputVector& wrapped_values) { ov::Output<ov::Node> ov::pass::pattern::op::Label::wrap_values(const ov::OutputVector& wrapped_values) {
switch (wrapped_values.size()) { switch (wrapped_values.size()) {
case 0: case 0:
return make_shared<pattern::op::True>()->output(0); return make_shared<pattern::op::True>()->output(0);
@ -28,9 +27,9 @@ Output<Node> pattern::op::Label::wrap_values(const OutputVector& wrapped_values)
} }
} }
bool pattern::op::Label::match_value(Matcher* matcher, bool ov::pass::pattern::op::Label::match_value(ov::pass::pattern::Matcher* matcher,
const Output<Node>& pattern_value, const ov::Output<ov::Node>& pattern_value,
const Output<Node>& graph_value) { const ov::Output<ov::Node>& graph_value) {
if (m_predicate(graph_value)) { if (m_predicate(graph_value)) {
auto& pattern_map = matcher->get_pattern_value_map(); auto& pattern_map = matcher->get_pattern_value_map();
auto saved = matcher->start_match(); auto saved = matcher->start_match();
@ -45,10 +44,10 @@ bool pattern::op::Label::match_value(Matcher* matcher,
return false; return false;
} }
std::shared_ptr<Node> pattern::any_input() { std::shared_ptr<ov::Node> ov::pass::pattern::any_input() {
return std::make_shared<pattern::op::Label>(); return std::make_shared<pattern::op::Label>();
} }
std::shared_ptr<Node> pattern::any_input(const pattern::op::ValuePredicate& pred) { std::shared_ptr<ov::Node> ov::pass::pattern::any_input(const ov::pass::pattern::op::ValuePredicate& pred) {
return std::make_shared<pattern::op::Label>(element::dynamic, PartialShape::dynamic(), pred); return std::make_shared<pattern::op::Label>(element::dynamic, PartialShape::dynamic(), pred);
} }

View File

@ -7,7 +7,8 @@
#include <algorithm> #include <algorithm>
#include <regex> #include <regex>
namespace ngraph { namespace ov {
namespace pass {
namespace pattern { namespace pattern {
namespace op { namespace op {
// The symbols are required to be in cpp file to workaround RTTI issue on Android LLVM // The symbols are required to be in cpp file to workaround RTTI issue on Android LLVM
@ -101,4 +102,5 @@ std::function<bool(Output<Node>)> type_matches_any(const std::vector<element::Ty
}; };
} }
} // namespace pattern } // namespace pattern
} // namespace ngraph } // namespace pass
} // namespace ov

View File

@ -2,9 +2,10 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include "dyn_elimination.hpp"
#include <numeric> #include <numeric>
#include "dyn_elimination.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/range.hpp" #include "ngraph/op/range.hpp"
@ -19,9 +20,7 @@ NGRAPH_SUPPRESS_DEPRECATED_START
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
pass::DynElimination::DynElimination() pass::DynElimination::DynElimination() : GraphRewrite() {
: GraphRewrite()
{
construct_range(); construct_range();
} }
@ -29,28 +28,22 @@ template <typename T>
std::shared_ptr<op::Constant> make_range_replacement(const element::Type& et, std::shared_ptr<op::Constant> make_range_replacement(const element::Type& et,
const Shape& shape, const Shape& shape,
const std::shared_ptr<op::Constant>& start_arg, const std::shared_ptr<op::Constant>& start_arg,
const std::shared_ptr<op::Constant>& step_arg) const std::shared_ptr<op::Constant>& step_arg) {
{
std::vector<T> elements(shape_size(shape)); std::vector<T> elements(shape_size(shape));
std::vector<T> start_vec = start_arg->get_vector<T>(); std::vector<T> start_vec = start_arg->get_vector<T>();
std::vector<T> step_vec = step_arg->get_vector<T>(); std::vector<T> step_vec = step_arg->get_vector<T>();
NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1); NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1);
runtime::reference::range<T>( runtime::reference::range<T>(start_vec.data(), step_vec.data(), shape_size(shape), elements.data());
start_vec.data(), step_vec.data(), shape_size(shape), elements.data());
return make_shared<op::Constant>(et, shape, elements); return make_shared<op::Constant>(et, shape, elements);
} }
void pass::DynElimination::construct_range() void pass::DynElimination::construct_range() {
{ auto start_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
auto start_arg_label = auto stop_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>()); auto step_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
auto stop_arg_label =
make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
auto step_arg_label =
make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
auto range_pat = make_shared<op::Range>(start_arg_label, stop_arg_label, step_arg_label); auto range_pat = make_shared<op::Range>(start_arg_label, stop_arg_label, step_arg_label);
@ -74,8 +67,7 @@ void pass::DynElimination::construct_range()
# pragma GCC diagnostic error "-Wswitch" # pragma GCC diagnostic error "-Wswitch"
# pragma GCC diagnostic error "-Wswitch-enum" # pragma GCC diagnostic error "-Wswitch-enum"
#endif #endif
switch (et) switch (et) {
{
case element::Type_t::bf16: case element::Type_t::bf16:
replacement = make_range_replacement<bfloat16>(et, shape, start_arg, step_arg); replacement = make_range_replacement<bfloat16>(et, shape, start_arg, step_arg);
break; break;