Move pass pattern to ov (#7255)
* Moved ngraph::Node to ov namespace * Fixed code style * Fixed VPU * Fixed GNA * Fixed tests * Added aliases for backward compatibility * Fix clDNN * Try to fix build * Fixed comment * Renamed RTTI macros * Add new headers * Fixed ngraph build * Fixed unit tests * Try to fix Serialize
This commit is contained in:
parent
07f7061f96
commit
9eca6ba9d5
@ -26,7 +26,7 @@ endif()
|
||||
# resolving dependencies for the project
|
||||
message (STATUS "PROJECT ............................... " ${PROJECT_NAME})
|
||||
message (STATUS "CMAKE_BINARY_DIR ...................... " ${CMAKE_BINARY_DIR})
|
||||
message (STATUS "OpenVINO_SOURCE_DIR .... .......... " ${OpenVINO_SOURCE_DIR})
|
||||
message (STATUS "OpenVINO_SOURCE_DIR ................... " ${OpenVINO_SOURCE_DIR})
|
||||
message (STATUS "CMAKE_GENERATOR ....................... " ${CMAKE_GENERATOR})
|
||||
message (STATUS "CMAKE_C_COMPILER_ID ................... " ${CMAKE_C_COMPILER_ID})
|
||||
message (STATUS "CMAKE_BUILD_TYPE ...................... " ${CMAKE_BUILD_TYPE})
|
||||
|
@ -811,8 +811,34 @@ void ngfunction_2_irv10(pugi::xml_node& netXml,
|
||||
f.validate_nodes_and_infer_types();
|
||||
}
|
||||
}
|
||||
|
||||
std::string valid_xml_path(const std::string &path) {
|
||||
NGRAPH_CHECK(path.length() > 4, "Path for xml file is to short: \"" + path + "\"");
|
||||
|
||||
const char *const extension = ".xml";
|
||||
const bool has_xml_extension = path.rfind(extension) == path.size() - std::strlen(extension);
|
||||
NGRAPH_CHECK(has_xml_extension,
|
||||
"Path for xml file doesn't contains file name with 'xml' extension: \"" +
|
||||
path + "\"");
|
||||
return path;
|
||||
}
|
||||
|
||||
std::string provide_bin_path(const std::string &xmlPath, const std::string &binPath) {
|
||||
if (!binPath.empty()) {
|
||||
return binPath;
|
||||
}
|
||||
assert(xmlPath.size() > 4); // should be check by valid_xml_path
|
||||
std::string bestPath = xmlPath;
|
||||
const char *const extension = "bin";
|
||||
const auto ext_size = std::strlen(extension);
|
||||
bestPath.replace(bestPath.size() - ext_size, ext_size, extension);
|
||||
return bestPath;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace ngraph {
|
||||
|
||||
// ! [function_pass:serialize_cpp]
|
||||
// serialize.cpp
|
||||
bool pass::Serialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||
@ -868,33 +894,6 @@ bool pass::Serialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
std::string valid_xml_path(const std::string &path) {
|
||||
NGRAPH_CHECK(path.length() > 4, "Path for xml file is to short: \"" + path + "\"");
|
||||
|
||||
const char *const extension = ".xml";
|
||||
const bool has_xml_extension = path.rfind(extension) == path.size() - std::strlen(extension);
|
||||
NGRAPH_CHECK(has_xml_extension,
|
||||
"Path for xml file doesn't contains file name with 'xml' extension: \"" +
|
||||
path + "\"");
|
||||
return path;
|
||||
}
|
||||
|
||||
std::string provide_bin_path(const std::string &xmlPath, const std::string &binPath) {
|
||||
if (!binPath.empty()) {
|
||||
return binPath;
|
||||
}
|
||||
assert(xmlPath.size() > 4); // should be check by valid_xml_path
|
||||
std::string bestPath = xmlPath;
|
||||
const char *const extension = "bin";
|
||||
const auto ext_size = std::strlen(extension);
|
||||
bestPath.replace(bestPath.size() - ext_size, ext_size, extension);
|
||||
return bestPath;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
pass::Serialize::Serialize(std::ostream& xmlFile,
|
||||
std::ostream& binFile,
|
||||
pass::Serialize::Version version,
|
||||
@ -921,3 +920,4 @@ pass::Serialize::Serialize(const std::string& xmlPath,
|
||||
{
|
||||
}
|
||||
// ! [function_pass:serialize_cpp]
|
||||
} // namespace ngraph
|
||||
|
@ -5,24 +5,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/pass/pass.hpp"
|
||||
#include "openvino/pass/constant_folding.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
/**
|
||||
* @brief Constant folding iterates over the function and tries to evaluate nodes
|
||||
* with constant inputs. Such nodes are then replaced with new Constants containing
|
||||
* the result of a folded operation.
|
||||
*/
|
||||
class NGRAPH_API ConstantFolding : public FunctionPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
|
||||
private:
|
||||
void copy_runtime_info_to_target_inputs(const std::shared_ptr<Node>& node, const Output<Node>& replacement);
|
||||
/// \brief Folds pre-calculated output tensor values to constants in case lower and
|
||||
/// upper estimations are equal. Traverses graph backwards starting from the results.
|
||||
bool pre_calculated_values_folding(const std::shared_ptr<ngraph::Function>& f);
|
||||
};
|
||||
using ov::pass::ConstantFolding;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -4,14 +4,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
#include "ngraph/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/convert_fp32_to_fp16.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
class NGRAPH_API ConvertFP32ToFP16 : public ngraph::pass::FunctionPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
|
||||
};
|
||||
using ov::pass::ConvertFP32ToFP16;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -10,240 +10,17 @@
|
||||
|
||||
#include "ngraph/pass/pass.hpp"
|
||||
#include "ngraph/pattern/matcher.hpp"
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
using matcher_pass_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
|
||||
using graph_rewrite_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
|
||||
using recurrent_graph_rewrite_callback = std::function<bool(ngraph::pattern::RecurrentMatcher& m)>;
|
||||
using handler_callback = std::function<bool(const std::shared_ptr<Node>& node)>;
|
||||
using ov::graph_rewrite_callback;
|
||||
using ov::handler_callback;
|
||||
using ov::matcher_pass_callback;
|
||||
using ov::recurrent_graph_rewrite_callback;
|
||||
namespace pass {
|
||||
/// \brief MatcherPass is a basic block for pattern based transformations. It describes
|
||||
/// pattern and
|
||||
/// action that is applied if pattern is matched.
|
||||
///
|
||||
/// MatcherPass consists of Matcher and matcher_pass_callback that needs to be implemented
|
||||
/// and
|
||||
/// finally registered by using \sa register_matcher. MatcherPass can be executed on node
|
||||
/// within
|
||||
/// \sa apply method. To run matcher pass on Function use GraphRewrite.
|
||||
/// In addition MatcherPass provides a way for adding new operations into GraphRewrite
|
||||
/// execution
|
||||
/// queue. That means that operations that were created inside transformation callback can
|
||||
/// be added
|
||||
/// for matching. To register node use \sa register_new_node method. GraphRewrite
|
||||
/// automatically
|
||||
/// takes registered nodes and put them to execution queue. If multiple nodes were register
|
||||
/// make
|
||||
/// sure that they were registered in topological order.
|
||||
/// Note: when implementing pattern for Matcher make sure that root node is an operation
|
||||
/// from opset
|
||||
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher
|
||||
/// passes more
|
||||
/// efficient.
|
||||
|
||||
class NGRAPH_API MatcherPass : public ngraph::pass::PassBase {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
MatcherPass() = default;
|
||||
|
||||
MatcherPass(const MatcherPass&) = delete;
|
||||
MatcherPass& operator=(const MatcherPass&) = delete;
|
||||
|
||||
explicit MatcherPass(const std::string& name,
|
||||
const std::shared_ptr<pattern::Matcher>& m,
|
||||
const handler_callback& handler,
|
||||
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE)
|
||||
: PassBase(),
|
||||
m_handler(handler),
|
||||
m_matcher(m) {
|
||||
set_name(name);
|
||||
set_property(property, true);
|
||||
}
|
||||
|
||||
bool apply(std::shared_ptr<ngraph::Node> node);
|
||||
|
||||
template <typename T, class... Args>
|
||||
std::shared_ptr<T> register_new_node(Args&&... args) {
|
||||
auto node = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
m_new_nodes.push_back(node);
|
||||
return node;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<T> register_new_node(const std::shared_ptr<T>& node) {
|
||||
m_new_nodes.push_back(node);
|
||||
return node;
|
||||
}
|
||||
|
||||
const std::vector<std::shared_ptr<ngraph::Node>>& get_new_nodes() {
|
||||
return m_new_nodes;
|
||||
}
|
||||
void clear_new_nodes() {
|
||||
m_new_nodes.clear();
|
||||
}
|
||||
std::shared_ptr<pattern::Matcher> get_matcher() {
|
||||
return m_matcher;
|
||||
}
|
||||
|
||||
protected:
|
||||
void register_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const ngraph::graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
|
||||
private:
|
||||
handler_callback m_handler;
|
||||
std::shared_ptr<pattern::Matcher> m_matcher;
|
||||
std::vector<std::shared_ptr<ngraph::Node>> m_new_nodes;
|
||||
};
|
||||
|
||||
/// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function
|
||||
/// in
|
||||
/// efficient way
|
||||
///
|
||||
/// Graph rewrite pass is used for matcher passes execution on Function.
|
||||
/// To register MatcherPass use \sa add_matcher<T>(args) method where T is a MatcherPass
|
||||
/// class.
|
||||
/// As a default algorithm graph rewrite pass traverse Function in topological order and
|
||||
/// applies
|
||||
/// registered matcher passes for each node. But if all registered matcher passes have type
|
||||
/// based
|
||||
/// root node in Matcher pattern then efficient mechanism is used to execute them.
|
||||
/// Matcher pattern root is type based if it's operation from opset or
|
||||
/// pattern::op::WrapType.
|
||||
/// Note: when implementing pattern for Matcher make sure that root node is an operation
|
||||
/// from opset
|
||||
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher
|
||||
/// passes more
|
||||
/// efficient.
|
||||
|
||||
class NGRAPH_API GraphRewrite : public ngraph::pass::FunctionPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
GraphRewrite() = default;
|
||||
|
||||
explicit GraphRewrite(const std::shared_ptr<MatcherPass>& pass) : FunctionPass() {
|
||||
m_matchers.push_back(pass);
|
||||
}
|
||||
|
||||
/// \brief Register given transformation class type to GraphRewrite execution list
|
||||
/// All registered transformations will be executed in a single graph traversal.
|
||||
/// Example below show the basic usage of pass::GraphRewrite
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// auto anchor = manager.register_pass<GraphRewrite>();
|
||||
/// anchor->add_matcher<MatcherPassA>();
|
||||
/// anchor->add_matcher<MatcherPassB>();
|
||||
/// anchor->set_name("CommonMatchers");
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// For some purposes transformation can be registered and disabled by default.
|
||||
///
|
||||
/// anchor->add_matcher<MatcherPassB, false>();
|
||||
///
|
||||
/// \return shared_ptr to the transformation instance
|
||||
template <typename T,
|
||||
bool Enabled = true,
|
||||
class... Args,
|
||||
typename std::enable_if<std::is_base_of<pass::MatcherPass, T>::value, bool>::type = true>
|
||||
std::shared_ptr<T> add_matcher(Args&&... args) {
|
||||
static_assert(std::is_base_of<pass::MatcherPass, T>::value, "pass not derived from MatcherPass");
|
||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
auto pass_config = get_pass_config();
|
||||
pass->set_pass_config(pass_config);
|
||||
if (!Enabled && !pass_config->is_enabled<T>()) {
|
||||
pass_config->disable<T>();
|
||||
}
|
||||
m_matchers.push_back(pass);
|
||||
return pass;
|
||||
}
|
||||
|
||||
/// \brief Register passes from GraphRewrite class that contains sequence of matcher
|
||||
/// passes registered in its ctor.
|
||||
/// For example:
|
||||
///
|
||||
/// class ngraph::pass::LinFusions: public ngraph::pass::GraphRewrite {
|
||||
/// public:
|
||||
/// NGRAPH_RTTI_DECLARATION;
|
||||
/// Fusions() {
|
||||
/// add_matcher<ngraph::pass::AddFusion>();
|
||||
/// add_matcher<ngraph::pass::MulFusion>();
|
||||
/// }
|
||||
/// };
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// auto anchor = manager.register_pass<GraphRewrite>();
|
||||
/// anchor->add_matcher<LinFusions>();
|
||||
/// anchor->add_matcher<OtherFusions>();
|
||||
/// anchor->set_name("CommonFusions");
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// In this case all matcher passes from LinFusions pass will be united with other
|
||||
/// registered matchers.
|
||||
template <typename T,
|
||||
class... Args,
|
||||
typename std::enable_if<std::is_base_of<pass::GraphRewrite, T>::value, bool>::type = true>
|
||||
void add_matcher(Args&&... args) {
|
||||
static_assert(std::is_base_of<pass::GraphRewrite, T>::value, "pass not derived from GraphRewrite");
|
||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
auto pass_config = get_pass_config();
|
||||
|
||||
for (auto& matcher : pass->m_matchers) {
|
||||
pass->set_pass_config(pass_config);
|
||||
m_matchers.push_back(matcher);
|
||||
}
|
||||
}
|
||||
|
||||
NGRAPH_DEPRECATED("Use MatcherPass instead")
|
||||
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const ngraph::graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property);
|
||||
|
||||
NGRAPH_DEPRECATED("Use MatcherPass instead")
|
||||
void add_matcher(const std::shared_ptr<pattern::Matcher>& m, const ngraph::graph_rewrite_callback& callback);
|
||||
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
|
||||
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
|
||||
|
||||
protected:
|
||||
bool apply_matcher_passes(std::shared_ptr<Function> f, std::deque<std::weak_ptr<Node>> nodes_to_run);
|
||||
|
||||
bool m_enable_shape_inference = false;
|
||||
|
||||
std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
|
||||
};
|
||||
|
||||
class NGRAPH_API BackwardGraphRewrite : public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
BackwardGraphRewrite() = default;
|
||||
|
||||
explicit BackwardGraphRewrite(const std::shared_ptr<MatcherPass>& pass) : GraphRewrite(pass) {}
|
||||
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
};
|
||||
|
||||
class NGRAPH_API RecurrentGraphRewrite : public ngraph::pass::FunctionPass {
|
||||
public:
|
||||
RecurrentGraphRewrite(size_t num_iters = 10) : FunctionPass(), m_num_iters(num_iters) {}
|
||||
|
||||
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
||||
const ngraph::recurrent_graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property);
|
||||
|
||||
// TODO: This interface may deprecate after all passes are refactored.
|
||||
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
||||
const ngraph::recurrent_graph_rewrite_callback& callback);
|
||||
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
|
||||
private:
|
||||
size_t m_num_iters;
|
||||
|
||||
std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
|
||||
};
|
||||
using ov::pass::BackwardGraphRewrite;
|
||||
using ov::pass::GraphRewrite;
|
||||
using ov::pass::MatcherPass;
|
||||
using ov::pass::RecurrentGraphRewrite;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -5,10 +5,12 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
#include <ngraph/pass/pass.hpp>
|
||||
#include <vector>
|
||||
|
||||
#include "ngraph/pass/graph_rewrite.hpp"
|
||||
#include "ngraph/pass/pass.hpp"
|
||||
#include "openvino/pass/low_latency.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
/**
|
||||
@ -46,38 +48,6 @@ public:
|
||||
LowLatency();
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief The transformation finds all TensorIterator/Loop layers in the network,
|
||||
* processes all back edges that describe a connection between Result and Parameter
|
||||
* of the TensorIterator/Loop bodies,and inserts ReadValue and Assign layers at the
|
||||
* input and output corresponding to this back edge.
|
||||
* Supported platforms: CPU, GNA.
|
||||
*
|
||||
* The example below describes the changes made by the transformation
|
||||
* [] - TensorIterator body
|
||||
* () - new layer
|
||||
* BE - back-edge
|
||||
*
|
||||
* before applying the transformation:
|
||||
* -> input1[BE_1 -> Parameter -> Layers ... -> Result -> BE_1 ]output1->
|
||||
*
|
||||
* after applying the transformation:
|
||||
* ->(ReadValue)-> input1[BE_1 ->Parameter->Layers ...->Result->BE_1]output1 ->(Assign)
|
||||
* \
|
||||
* ->...
|
||||
* After applying the transformation, the resulting network can be inferred
|
||||
* step by step, the states will store between inferences.
|
||||
*/
|
||||
class NGRAPH_API LowLatency2 : public ngraph::pass::FunctionPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
explicit LowLatency2(bool use_const_initializer = true) : m_use_const_initializer(use_const_initializer) {}
|
||||
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
|
||||
private:
|
||||
bool m_use_const_initializer;
|
||||
};
|
||||
using ov::pass::LowLatency2;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -11,106 +11,10 @@
|
||||
|
||||
#include "ngraph/pass/pass.hpp"
|
||||
#include "ngraph/pass/validate.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
class NGRAPH_API Manager {
|
||||
public:
|
||||
Manager();
|
||||
~Manager();
|
||||
|
||||
//// \brief Construct Manager with shared PassConfig instance
|
||||
explicit Manager(std::shared_ptr<PassConfig> pass_config);
|
||||
|
||||
/// \brief Register given transformation class type to execution list
|
||||
/// Example below show the basic usage of pass::Manager
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// manager.register_pass<MyTransformation>(/*transformation constructor ars*/);
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// For some purposes transformation can be registered and disabled by default.
|
||||
///
|
||||
/// manager.register_pass<MyTransformation, false>();
|
||||
///
|
||||
/// \return shared_ptr to the transformation instance
|
||||
template <typename T, bool Enable = true, class... Args>
|
||||
std::shared_ptr<T> register_pass(Args&&... args) {
|
||||
auto rc = push_pass<T>(std::forward<Args>(args)...);
|
||||
rc->set_pass_config(m_pass_config);
|
||||
if (m_per_pass_validation) {
|
||||
push_pass<Validate>();
|
||||
}
|
||||
if (!Enable && !m_pass_config->is_enabled<T>()) {
|
||||
m_pass_config->disable<T>();
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
void run_passes(std::shared_ptr<Function>);
|
||||
|
||||
void set_pass_visualization(bool new_state) {
|
||||
m_visualize = new_state;
|
||||
}
|
||||
/// \brief Set flag to enable/disable running Validate pass after executing
|
||||
/// each registered pass
|
||||
/// \param new_state Value "true" enables Validate pass run; "false", otherwise
|
||||
void set_per_pass_validation(bool new_state) {
|
||||
m_per_pass_validation = new_state;
|
||||
}
|
||||
/// \brief Callback is a lambda function that can be used by registered transformations.
|
||||
/// The main purpose of this callback is to provide a way for plugins to disable/enable
|
||||
/// transformations based on some conditions. In some cases plugins may want not to
|
||||
/// execute some
|
||||
/// transformations.
|
||||
/// For example plugin can disable unpleasant decompositions because of performance
|
||||
/// reasons for
|
||||
/// some cases.
|
||||
/// Callback example:
|
||||
/// auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
||||
/// return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) !=
|
||||
/// nullptr;
|
||||
/// };
|
||||
/// This callback returns true in case of DepthToSpace operation. So when execution
|
||||
/// DepthToSpace
|
||||
/// decomposition pass will check is this decomposition needed or plugin can execute
|
||||
/// this
|
||||
/// operation directly. And of course on transformation side we need to have a response
|
||||
/// for this
|
||||
/// callback.
|
||||
/// if (transformation_callback(batch_to_space)) {
|
||||
/// return false;
|
||||
/// }
|
||||
/// \param callback lamda function that returns true in case if node is supported by
|
||||
/// plugin and
|
||||
/// transformation is not needed
|
||||
NGRAPH_DEPRECATED("Please use get_pass_config() to configure transformation pipeline")
|
||||
void set_callback(const param_callback& callback) {
|
||||
m_pass_config->set_callback(callback);
|
||||
}
|
||||
/// \return PassConfig shared object. This object is used for transformations pipeline
|
||||
/// configuration.
|
||||
/// This object allows to disable/enable transformations execution, set callback to
|
||||
/// particular
|
||||
/// transformation. For mo details see PassConfig class.
|
||||
std::shared_ptr<PassConfig> get_pass_config() {
|
||||
return m_pass_config;
|
||||
}
|
||||
|
||||
protected:
|
||||
template <typename T, class... Args>
|
||||
std::shared_ptr<T> push_pass(Args&&... args) {
|
||||
static_assert(std::is_base_of<pass::PassBase, T>::value, "pass not derived from pass base");
|
||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
auto pass_base = std::static_pointer_cast<PassBase>(pass);
|
||||
m_pass_list.push_back(pass_base);
|
||||
return pass;
|
||||
}
|
||||
|
||||
std::shared_ptr<PassConfig> m_pass_config;
|
||||
std::vector<std::shared_ptr<PassBase>> m_pass_list;
|
||||
bool m_visualize = false;
|
||||
bool m_per_pass_validation = true;
|
||||
};
|
||||
using ov::pass::Manager;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -13,105 +13,32 @@
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/pass/pass_config.hpp"
|
||||
#include "ngraph/util.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class Manager;
|
||||
|
||||
}
|
||||
} // namespace ov
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
enum class PassProperty : uint32_t {
|
||||
// Pass requires node shapes to be static
|
||||
REQUIRE_STATIC_SHAPE = 0x1,
|
||||
// Pass transformation will change the function's dynamic state
|
||||
CHANGE_DYNAMIC_STATE = 1 << 1,
|
||||
};
|
||||
|
||||
typedef EnumMask<PassProperty> PassPropertyMask;
|
||||
using ov::pass::FunctionPass;
|
||||
using ov::pass::FusionType;
|
||||
using ov::pass::FusionTypeMask;
|
||||
using ov::pass::Manager;
|
||||
using ov::pass::PassBase;
|
||||
using ov::pass::PassProperty;
|
||||
using ov::pass::PassPropertyMask;
|
||||
NGRAPH_DEPRECATED("This variable is deprecated and will be removed soon.")
|
||||
const PassPropertyMask all_pass_property_off;
|
||||
|
||||
class NGRAPH_API PassBase {
|
||||
friend class Manager;
|
||||
|
||||
public:
|
||||
PassBase();
|
||||
virtual ~PassBase() {}
|
||||
/// Check if this pass has all the pass properties.
|
||||
bool get_property(const PassPropertyMask& prop_mask) const;
|
||||
|
||||
void set_name(const std::string& name) {
|
||||
m_name = name;
|
||||
}
|
||||
std::string get_name() const;
|
||||
|
||||
/// \brief Set callback for particular transformation type.
|
||||
/// This method set global callback. For more details see PassConfig class
|
||||
/// documentation.
|
||||
/// \param callback lambda function that takes node and returns bool
|
||||
void set_callback(const param_callback& callback);
|
||||
|
||||
/// \brief Set PassConfig for particular transformation instance
|
||||
/// \param pass_config is a PassConfig shared_ptr
|
||||
virtual void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) {
|
||||
m_pass_config = pass_config;
|
||||
}
|
||||
|
||||
/// \brief Allows to access PassConfig shared instance
|
||||
/// \return Shared instance of PassConfig class
|
||||
std::shared_ptr<PassConfig> get_pass_config() {
|
||||
return m_pass_config;
|
||||
}
|
||||
/// \brief Applies callback for given node. By default callback returns false.
|
||||
/// This method remains here only for backward compatibility and will be removed
|
||||
/// after all transformations are moved to transformation_callback() method.
|
||||
/// \return result of callback execution for given node
|
||||
NGRAPH_DEPRECATED("Please use transformation_callback method instead")
|
||||
bool m_transformation_callback(const std::shared_ptr<const Node>& node) {
|
||||
return m_pass_config->get_callback(get_type_info())(node);
|
||||
}
|
||||
|
||||
/// \brief Applies callback for given node. By default callback returns false.
|
||||
/// \param node which will be used inside callback
|
||||
/// \return result of callback execution for given node
|
||||
bool transformation_callback(const std::shared_ptr<const Node>& node) {
|
||||
return m_pass_config->get_callback(get_type_info())(node);
|
||||
}
|
||||
|
||||
using type_info_t = DiscreteTypeInfo;
|
||||
|
||||
virtual const type_info_t& get_type_info() const = 0;
|
||||
|
||||
protected:
|
||||
void set_property(const PassPropertyMask& prop, bool value);
|
||||
|
||||
private:
|
||||
PassPropertyMask m_property;
|
||||
|
||||
std::string m_name;
|
||||
std::shared_ptr<PassConfig> m_pass_config;
|
||||
};
|
||||
|
||||
class NGRAPH_API FunctionPass : public PassBase {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
virtual ~FunctionPass();
|
||||
virtual bool run_on_function(std::shared_ptr<ngraph::Function>) = 0;
|
||||
};
|
||||
|
||||
class NGRAPH_DEPRECATED("Use MatcherPass or FunctionPass instead.") NGRAPH_API NodePass : public PassBase {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
virtual ~NodePass();
|
||||
~NodePass() override;
|
||||
virtual bool run_on_node(std::shared_ptr<ngraph::Node>) = 0;
|
||||
};
|
||||
|
||||
class Manager;
|
||||
enum class FusionType : uint32_t {
|
||||
//`DIFFERENTIABLE_FUSIONS` produce ops that support autodiff
|
||||
// i.e. implement `generate_adjoints`
|
||||
DIFFERENTIABLE_FUSIONS = 0x1,
|
||||
REGULAR_FUSIONS = 0x2,
|
||||
//`FOP_FUSIONS` produce ops in the FusedOps category that might
|
||||
// not be supported by all backends
|
||||
FOP_FUSIONS = 0x4,
|
||||
ALL_FUSIONS = 0xFFFFFFFF
|
||||
};
|
||||
typedef EnumMask<FusionType> FusionTypeMask;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -12,164 +12,12 @@
|
||||
#include "ngraph/function.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/util.hpp"
|
||||
#include "openvino/pass/pass_config.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
using param_callback = std::function<bool(const std::shared_ptr<const ::ngraph::Node>)>;
|
||||
using param_callback_map = std::map<ngraph::DiscreteTypeInfo, param_callback>;
|
||||
|
||||
/// \brief Class representing a transformations config that is used for disabling/enabling
|
||||
/// transformations registered inside pass::Manager and also allows to set callback for all
|
||||
/// transformations or for particular transformation.
|
||||
///
|
||||
/// When pass::Manager is created all passes registered inside this manager including nested
|
||||
/// passes will share the same instance of PassConfig class.
|
||||
/// To work with this class first you need to get shared instance of this class by calling
|
||||
/// manager.get_pass_config() method. Then you will be able to disable/enable passes based
|
||||
/// on transformations type_info. For example:
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// manager.register_pass<CommonOptimizations>();
|
||||
/// auto pass_config = manager.get_pass_config();
|
||||
/// pass_config->disable<ConvertGELU>(); // this will disable nested pass inside
|
||||
/// // CommonOptimizations pipeline
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// Sometimes it is needed to call transformation inside other transformation manually. And
|
||||
/// for that case before running transformation you need manually check that this pass is
|
||||
/// not disabled and then you need to set current PassConfig instance to this
|
||||
/// transformation. For example:
|
||||
///
|
||||
/// // Inside MatcherPass callback or inside FunctionPass run_on_function() method
|
||||
/// // you need to call get_pass_config() method to get shared instance of PassConfig
|
||||
/// auto pass_config = get_pass_config();
|
||||
///
|
||||
/// // Before running nested transformation you need to check is it disabled or not
|
||||
/// if (!pass_config->is_disabled<ConvertGELU>()) {
|
||||
/// auto pass = ConvertGELU();
|
||||
/// pass->set_pass_config(pass_config);
|
||||
/// pass.apply(node);
|
||||
/// }
|
||||
///
|
||||
/// Following this logic inside your transformations you will guaranty that transformations
|
||||
/// will be executed in a right way.
|
||||
class NGRAPH_API PassConfig {
|
||||
public:
|
||||
/// \brief Disable transformation by its type_info
|
||||
/// \param type_info Transformation type_info
|
||||
void disable(const DiscreteTypeInfo& type_info);
|
||||
/// \brief Disable transformation by its class type (based on type_info)
|
||||
template <typename T>
|
||||
void disable() {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
disable(T::type_info);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
/// \brief Enable transformation by its type_info
|
||||
/// \param type_info Transformation type_info
|
||||
void enable(const DiscreteTypeInfo& type_info);
|
||||
/// \brief Enable transformation by its class type (based on type_info)
|
||||
template <typename T>
|
||||
void enable() {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
enable(T::type_info);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
/// \brief Set callback for all kind of transformations
|
||||
void set_callback(const param_callback& callback) {
|
||||
m_callback = callback;
|
||||
}
|
||||
template <typename... Args>
|
||||
typename std::enable_if<sizeof...(Args) == 0>::type set_callback(const param_callback& callback) {}
|
||||
|
||||
/// \brief Set callback for particular transformation class types
|
||||
///
|
||||
/// Example below show how to set callback for one or multiple passes using this method.
|
||||
///
|
||||
/// pass_config->set_callback<ngraph::pass::ConvertBatchToSpace,
|
||||
/// ngraph::pass::ConvertSpaceToBatch>(
|
||||
/// [](const_node_ptr &node) -> bool {
|
||||
/// // Disable transformations for cases when input shape rank is not
|
||||
/// equal to 4
|
||||
/// const auto input_shape_rank =
|
||||
/// node->get_output_partial_shape(0).rank().get_length();
|
||||
/// if (input_shape_rank != 4) {
|
||||
/// return false;
|
||||
/// }
|
||||
/// return true;
|
||||
/// });
|
||||
///
|
||||
/// Note that inside transformations you must provide code that work with this callback.
|
||||
/// See example below:
|
||||
///
|
||||
/// if (transformation_callback(node)) {
|
||||
/// return false; // exit from transformation
|
||||
/// }
|
||||
///
|
||||
template <typename T, class... Args>
|
||||
void set_callback(const param_callback& callback) {
|
||||
m_callback_map[T::type_info] = callback;
|
||||
set_callback<Args...>(callback);
|
||||
}
|
||||
|
||||
/// \brief Get callback for given transformation type_info
|
||||
/// \param type_info Transformation type_info
|
||||
///
|
||||
/// In case if callback wasn't set for given transformation type then global callback
|
||||
/// will be returned. But if even global callback wasn't set then default callback will
|
||||
/// be returned.
|
||||
param_callback get_callback(const DiscreteTypeInfo& type_info) const;
|
||||
|
||||
/// \brief Get callback for given transformation class type
|
||||
/// \return callback lambda function
|
||||
template <typename T>
|
||||
param_callback get_callback() const {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
return get_callback(T::type_info);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
/// \brief Check either transformation type is disabled or not
|
||||
/// \param type_info Transformation type_info
|
||||
/// \return true if transformation type was disabled and false otherwise
|
||||
bool is_disabled(const DiscreteTypeInfo& type_info) const {
|
||||
return m_disabled.count(type_info);
|
||||
}
|
||||
|
||||
/// \brief Check either transformation class type is disabled or not
|
||||
/// \return true if transformation type was disabled and false otherwise
|
||||
template <typename T>
|
||||
bool is_disabled() const {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
return is_disabled(T::type_info);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
/// \brief Check either transformation type is force enabled or not
|
||||
/// \param type_info Transformation type_info
|
||||
/// \return true if transformation type was force enabled and false otherwise
|
||||
bool is_enabled(const DiscreteTypeInfo& type_info) const {
|
||||
return m_enabled.count(type_info);
|
||||
}
|
||||
|
||||
/// \brief Check either transformation class type is force enabled or not
|
||||
/// \return true if transformation type was force enabled and false otherwise
|
||||
template <typename T>
|
||||
bool is_enabled() const {
|
||||
return is_enabled(T::type_info);
|
||||
}
|
||||
|
||||
void add_disabled_passes(const PassConfig& rhs);
|
||||
|
||||
private:
|
||||
param_callback m_callback = [](const std::shared_ptr<const ::ngraph::Node>&) {
|
||||
return false;
|
||||
};
|
||||
param_callback_map m_callback_map;
|
||||
std::unordered_set<DiscreteTypeInfo> m_disabled;
|
||||
std::unordered_set<DiscreteTypeInfo> m_enabled;
|
||||
};
|
||||
using ov::pass::param_callback;
|
||||
using ov::pass::param_callback_map;
|
||||
using ov::pass::PassConfig;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -5,27 +5,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/pass/pass.hpp"
|
||||
#include "openvino/pass/validate.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
/// \brief The Validate pass performs sanity checks on attributes and inputs, and
|
||||
/// computes output shapes and element types for all computation nodes in a given
|
||||
/// computation graph.
|
||||
///
|
||||
/// \details The verification and inference is done via invoking each node's specific
|
||||
/// implementation of \link ngraph::Node::validate_and_infer_types() \endlink function.
|
||||
///
|
||||
/// By default, the \ref ngraph::pass::Manager runs this pass after executing every
|
||||
/// optimization pass. This is to ensure that any update to the graph by an optimization
|
||||
/// pass does not break the shape and data type requirement on a computation node.
|
||||
/// This default validation run can be changed via calling the
|
||||
/// \link ngraph::pass::Manager::set_per_pass_validation(bool) \endlink function.
|
||||
class NGRAPH_API Validate : public FunctionPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
Validate() : FunctionPass() {}
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
};
|
||||
using ov::pass::Validate;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -14,44 +14,10 @@
|
||||
#include <utility>
|
||||
|
||||
#include "ngraph/pass/pass.hpp"
|
||||
|
||||
class HeightMap;
|
||||
|
||||
using visualize_tree_ops_map_t =
|
||||
std::unordered_map<ngraph::Node::type_info_t, std::function<void(const ngraph::Node&, std::ostream& ss)>>;
|
||||
#include "openvino/pass/visualize_tree.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
class NGRAPH_API VisualizeTree : public FunctionPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
using node_modifiers_t = std::function<void(const Node& node, std::vector<std::string>& attributes)>;
|
||||
VisualizeTree(const std::string& file_name, node_modifiers_t nm = nullptr, bool dot_only = false);
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
|
||||
|
||||
void set_ops_to_details(const visualize_tree_ops_map_t& ops_map) {
|
||||
m_ops_to_details = ops_map;
|
||||
}
|
||||
|
||||
protected:
|
||||
void add_node_arguments(std::shared_ptr<Node> node,
|
||||
std::unordered_map<Node*, HeightMap>& height_maps,
|
||||
size_t& fake_node_ctr);
|
||||
std::string add_attributes(std::shared_ptr<Node> node);
|
||||
virtual std::string get_attributes(std::shared_ptr<Node> node);
|
||||
virtual std::string get_node_name(std::shared_ptr<Node> node);
|
||||
std::string get_constant_value(std::shared_ptr<Node> node, size_t max_elements = 7);
|
||||
|
||||
void render() const;
|
||||
|
||||
std::stringstream m_ss;
|
||||
std::string m_name;
|
||||
std::set<std::shared_ptr<Node>> m_nodes_with_attributes;
|
||||
visualize_tree_ops_map_t m_ops_to_details;
|
||||
node_modifiers_t m_node_modifiers = nullptr;
|
||||
bool m_dot_only;
|
||||
static const int max_jump_distance;
|
||||
};
|
||||
using ov::pass::VisualizeTree;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -16,255 +16,21 @@
|
||||
#include "ngraph/pattern/op/any_output.hpp"
|
||||
#include "ngraph/pattern/op/label.hpp"
|
||||
#include "ngraph/pattern/op/skip.hpp"
|
||||
#include "openvino/pass/pattern/matcher.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
class GraphRewrite;
|
||||
}
|
||||
} // namespace ov
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
using ov::pass::GraphRewrite;
|
||||
}
|
||||
|
||||
namespace pattern {
|
||||
class Matcher;
|
||||
|
||||
class NGRAPH_API MatcherState {
|
||||
public:
|
||||
MatcherState(Matcher*);
|
||||
bool finish(bool is_successful);
|
||||
~MatcherState();
|
||||
|
||||
protected:
|
||||
Matcher* m_matcher;
|
||||
PatternValueMap m_pattern_value_map;
|
||||
PatternValueMaps m_pattern_value_maps;
|
||||
size_t m_watermark;
|
||||
size_t m_capture_size;
|
||||
bool m_restore{true};
|
||||
};
|
||||
|
||||
/// Matcher looks for node patterns in a computation graph. The patterns are described by an
|
||||
/// automaton that is described by an extended computation graph. The matcher executes
|
||||
/// by attempting to match the start node of the pattern to a computation graph value
|
||||
/// (output of a Node). In addition to determing if a match occurs, a pattern node may add
|
||||
/// graph nodes to a list of matched nodes, associate nodes with graph values, and start
|
||||
/// submatches. Submatches add match state changes to the enclosing match if the submatch
|
||||
/// succeeds; otherwise the state is reverted.
|
||||
///
|
||||
/// The default match behavior of a pattern node with a graph nodes is that the computation
|
||||
/// graph value is added to the end of the matched value list and the match succeeds if the
|
||||
/// node/pattern types match and the input values match. In the case of a commutative node,
|
||||
/// the inputs can match in any order. If the matcher is in strict mode, the graph value
|
||||
/// element type and shape must also match.
|
||||
///
|
||||
/// Pattern nodes that have different match behavior are in ngraph::pattern::op and have
|
||||
/// descriptions of their match behavior.
|
||||
class NGRAPH_API Matcher {
|
||||
public:
|
||||
using PatternMap = ngraph::pattern::PatternMap;
|
||||
|
||||
// Avoid implicit string construction from nullptr.
|
||||
Matcher(const std::shared_ptr<Node> pattern_node, std::nullptr_t name) = delete;
|
||||
|
||||
Matcher() {}
|
||||
Matcher(Output<Node>& pattern_node) : m_pattern_node{pattern_node} {}
|
||||
|
||||
Matcher(Output<Node>& pattern_node, const std::string& name) : m_pattern_node(pattern_node), m_name{name} {}
|
||||
|
||||
/// \brief Constructs a Matcher object
|
||||
///
|
||||
/// \param pattern_node is a pattern sub graph that will be matched against input graphs
|
||||
/// \param name is a string which is used for logging and disabling a matcher
|
||||
/// \param strict_mode forces a matcher to consider shapes and ET of nodes
|
||||
Matcher(const Output<Node>& pattern_node, const std::string& name, bool strict_mode)
|
||||
: m_pattern_node(pattern_node),
|
||||
m_name(name),
|
||||
m_strict_mode(strict_mode) {}
|
||||
|
||||
// Some matches should start on a node rather than an output. These three constructors
|
||||
// are transition until we work out the right way to do that.
|
||||
Matcher(std::shared_ptr<Node> pattern_node);
|
||||
Matcher(std::shared_ptr<Node> pattern_node, const std::string& name);
|
||||
Matcher(std::shared_ptr<Node> pattern_node, const std::string& name, bool strict_mode);
|
||||
|
||||
virtual ~Matcher() {}
|
||||
/// \brief Matches a pattern to \p graph_node
|
||||
///
|
||||
/// \param graph_value is an input graph to be matched against
|
||||
bool match(const Output<Node>& graph_value);
|
||||
|
||||
bool match(std::shared_ptr<Node> graph_node);
|
||||
|
||||
/// \brief Matches a pattern to \p graph_node
|
||||
///
|
||||
/// \param graph_value is an input graph to be matched against
|
||||
/// \param previous_matches contains previous mappings from labels to nodes to use
|
||||
bool match(const Output<Node>& graph_value, const PatternMap& previous_matches);
|
||||
bool match(const Output<Node>& graph_value, const PatternValueMap& previous_matches);
|
||||
|
||||
template <typename T>
|
||||
static std::shared_ptr<T> unique_match(std::shared_ptr<Node> node) {
|
||||
std::shared_ptr<T> matched;
|
||||
for (auto arg : node->input_values()) {
|
||||
if (auto t_casted = ov::as_type_ptr<T>(arg.get_node_shared_ptr())) {
|
||||
if (matched) {
|
||||
throw ngraph_error("There's more than two arguments of the same type");
|
||||
} else {
|
||||
matched = t_casted;
|
||||
}
|
||||
}
|
||||
}
|
||||
return matched;
|
||||
}
|
||||
|
||||
bool is_contained_match(const NodeVector& exclusions = {}, bool ignore_unused = true);
|
||||
const NodeVector get_matched_nodes() {
|
||||
return as_node_vector(m_matched_list);
|
||||
}
|
||||
const OutputVector& get_matched_values() const {
|
||||
return m_matched_list;
|
||||
}
|
||||
OutputVector& get_matched_values() {
|
||||
return m_matched_list;
|
||||
}
|
||||
void reset() {}
|
||||
const std::string& get_name() {
|
||||
return m_name;
|
||||
}
|
||||
std::shared_ptr<Node> get_pattern() {
|
||||
return m_pattern_node.get_node_shared_ptr();
|
||||
}
|
||||
Output<Node> get_pattern_value() {
|
||||
return m_pattern_node;
|
||||
}
|
||||
std::shared_ptr<Node> get_match_root();
|
||||
Output<Node> get_match_value();
|
||||
PatternMap get_pattern_map() const;
|
||||
PatternValueMap& get_pattern_value_map() {
|
||||
return m_pattern_map;
|
||||
}
|
||||
PatternValueMaps& get_pattern_value_maps() {
|
||||
return m_pattern_value_maps;
|
||||
}
|
||||
/// \brief Low-level helper to match recurring patterns
|
||||
///
|
||||
/// \param graph is a graph to be matched against
|
||||
/// \param pattern is a recurring pattern
|
||||
/// \param rpattern specifies a node to recur from next
|
||||
/// \param patterns a map from labels to matches
|
||||
|
||||
size_t add_node(Output<Node> node);
|
||||
|
||||
bool virtual match_value(const ngraph::Output<Node>& pattern_value, const ngraph::Output<Node>& graph_value);
|
||||
|
||||
bool is_strict_mode() {
|
||||
return m_strict_mode;
|
||||
}
|
||||
virtual bool match_arguments(Node* pattern_node, const std::shared_ptr<Node>& graph_node);
|
||||
|
||||
void capture(const std::set<Node*>& static_nodes);
|
||||
|
||||
void clear_state();
|
||||
|
||||
size_t get_number_of_recurrent_matches() const {
|
||||
return m_pattern_value_maps.size();
|
||||
}
|
||||
NodeVector get_bound_nodes_for_pattern(const Output<Node>& pattern) const;
|
||||
size_t get_number_of_bound_labels() const;
|
||||
/// \brief Try a match
|
||||
MatcherState start_match();
|
||||
|
||||
Output<Node> m_match_root;
|
||||
Output<Node> m_pattern_node;
|
||||
PatternValueMap m_pattern_map;
|
||||
PatternValueMaps m_pattern_value_maps;
|
||||
OutputVector m_matched_list;
|
||||
|
||||
protected:
|
||||
bool match_permutation(const OutputVector& pattern_args, const OutputVector& args);
|
||||
|
||||
std::string m_name{"unnamed"};
|
||||
bool m_strict_mode{false};
|
||||
};
|
||||
|
||||
class NGRAPH_API RecurrentMatcher {
|
||||
public:
|
||||
/// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
|
||||
/// repeating patterns (e.g. RNN, LSTM, GRU cells)
|
||||
///
|
||||
/// \param initial_pattern is a pattern sub graph describing the initial cell
|
||||
/// \param pattern is a pattern sub graph describing an individual cell
|
||||
/// \param rpattern is a (recurring) label to denote which node the next match should
|
||||
/// start at
|
||||
/// \param correlated_patterns is a set of labels whose bound nodes must remain the same
|
||||
/// across all cells
|
||||
RecurrentMatcher(const Output<Node>& initial_pattern,
|
||||
const Output<Node>& pattern,
|
||||
const std::shared_ptr<Node>& rpattern,
|
||||
const std::set<std::shared_ptr<Node>>& correlated_patterns)
|
||||
: m_initial_pattern(initial_pattern),
|
||||
m_pattern(pattern),
|
||||
m_recurrent_pattern(rpattern),
|
||||
m_correlated_patterns(correlated_patterns) {}
|
||||
|
||||
/// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
|
||||
/// repeating patterns (e.g. RNN, LSTM, GRU cells)
|
||||
///
|
||||
/// \param pattern is a pattern sub graph describing an individual cell
|
||||
/// \param rpattern is a (recurring) label to denote which node the next match should
|
||||
/// start at
|
||||
/// \param correlated_patterns is a set of labels whose bound nodes must remain the same
|
||||
/// across all cells
|
||||
RecurrentMatcher(const Output<Node>& pattern,
|
||||
const std::shared_ptr<Node>& rpattern,
|
||||
const std::set<std::shared_ptr<Node>>& correlated_patterns)
|
||||
: RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns) {}
|
||||
|
||||
RecurrentMatcher(const Output<Node>& initial_pattern,
|
||||
const Output<Node>& pattern,
|
||||
const std::shared_ptr<Node>& rpattern,
|
||||
const std::set<std::shared_ptr<op::Label>>& correlated_patterns);
|
||||
|
||||
RecurrentMatcher(const Output<Node>& pattern,
|
||||
const std::shared_ptr<Node>& rpattern,
|
||||
const std::set<std::shared_ptr<op::Label>>& correlated_patterns)
|
||||
: RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns) {}
|
||||
|
||||
/// \brief Returns a vector of bound nodes for a given label (used in a pattern
|
||||
/// describing an individual cell
|
||||
NodeVector get_bound_nodes_for_pattern(const std::shared_ptr<Node>& pattern) const {
|
||||
if (m_matches.count(pattern) == 0) {
|
||||
throw ngraph_error("No bound nodes for a given label");
|
||||
}
|
||||
|
||||
return as_node_vector(m_matches.at(pattern));
|
||||
}
|
||||
|
||||
size_t get_number_of_recurrent_matches() const {
|
||||
if (m_matches.size() == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return (*m_matches.begin()).second.size();
|
||||
}
|
||||
|
||||
size_t get_number_of_bound_labels() const {
|
||||
return m_matches.size();
|
||||
}
|
||||
/// \brief Tries to match a pattern for an individual cell to a given \p graph
|
||||
bool match(Output<Node> graph);
|
||||
|
||||
std::shared_ptr<Node> get_match_root() {
|
||||
return m_match_root.get_node_shared_ptr();
|
||||
}
|
||||
Output<Node> get_match_value() {
|
||||
return m_match_root;
|
||||
}
|
||||
|
||||
private:
|
||||
Output<Node> m_initial_pattern;
|
||||
Output<Node> m_pattern;
|
||||
std::shared_ptr<Node> m_recurrent_pattern;
|
||||
const std::set<std::shared_ptr<Node>> m_correlated_patterns;
|
||||
RPatternValueMap m_matches;
|
||||
Output<Node> m_match_root;
|
||||
};
|
||||
using ov::pass::pattern::Matcher;
|
||||
using ov::pass::pattern::MatcherState;
|
||||
using ov::pass::pattern::RecurrentMatcher;
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
@ -6,38 +6,12 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/pattern/op/pattern.hpp"
|
||||
#include "openvino/pass/pattern/op/any.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// The graph value is to the matched value list. If the predicate is true for the node
|
||||
/// and the arguments match, the match succeeds.
|
||||
class NGRAPH_API Any : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternAny", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
/// \brief creates a Any node containing a sub-pattern described by \sa type and \sa
|
||||
/// shape.
|
||||
Any(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values)
|
||||
: Pattern(wrapped_values, pred) {
|
||||
set_output_type(0, type, s);
|
||||
}
|
||||
Any(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values)
|
||||
: Any(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
|
||||
/// \brief creates a Any node containing a sub-pattern described by the type and
|
||||
/// shape of \sa node.
|
||||
Any(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
|
||||
: Any(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
|
||||
Any(const Output<Node>& node, NodePredicate pred, const NodeVector& wrapped_values)
|
||||
: Any(node.get_element_type(),
|
||||
node.get_partial_shape(),
|
||||
as_value_predicate(pred),
|
||||
as_output_vector(wrapped_values)) {}
|
||||
|
||||
bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
};
|
||||
using ov::pass::pattern::op::Any;
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
@ -6,47 +6,12 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/pattern/op/pattern.hpp"
|
||||
#include "openvino/pass/pattern/op/any_of.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// The graph value is added to the matched values list. If the predicate is true for
|
||||
/// the
|
||||
/// graph node, a submatch is performed on the input of AnyOf and each input of the
|
||||
/// graph node. The first match that succeeds results in a successful match. Otherwise
|
||||
/// the match fails.
|
||||
///
|
||||
/// AnyOf may be given a type and shape for use in strict mode.
|
||||
class NGRAPH_API AnyOf : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternAnyOf", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
/// \brief creates a AnyOf node containing a sub-pattern described by \sa type and
|
||||
/// \sa shape.
|
||||
AnyOf(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values)
|
||||
: Pattern(wrapped_values, pred) {
|
||||
if (wrapped_values.size() != 1) {
|
||||
throw ngraph_error("AnyOf expects exactly one argument");
|
||||
}
|
||||
set_output_type(0, type, s);
|
||||
}
|
||||
AnyOf(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values)
|
||||
: AnyOf(
|
||||
type,
|
||||
s,
|
||||
[pred](const Output<Node>& value) {
|
||||
return pred(value.get_node_shared_ptr());
|
||||
},
|
||||
as_output_vector(wrapped_values)) {}
|
||||
|
||||
/// \brief creates a AnyOf node containing a sub-pattern described by the type and
|
||||
/// shape of \sa node.
|
||||
AnyOf(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
|
||||
: AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
|
||||
AnyOf(std::shared_ptr<Node> node, NodePredicate pred, const NodeVector& wrapped_values)
|
||||
: AnyOf(node, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
|
||||
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
|
||||
};
|
||||
using ov::pass::pattern::op::AnyOf;
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
@ -6,23 +6,12 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/pattern/op/pattern.hpp"
|
||||
#include "openvino/pass/pattern/op/any_output.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// Matches any output of a node
|
||||
class NGRAPH_API AnyOutput : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternAnyOutput", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
/// \brief creates an AnyOutput node matching any output of a node
|
||||
/// \param node The node to match
|
||||
AnyOutput(const std::shared_ptr<Node>& pattern) : Pattern({pattern->output(0)}) {}
|
||||
|
||||
bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
};
|
||||
using ov::pass::pattern::op::AnyOutput;
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
@ -6,48 +6,12 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/pattern/op/pattern.hpp"
|
||||
#include "openvino/pass/pattern/op/branch.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// A branch adds a loop to the pattern. The branch match is successful if the
|
||||
/// destination node pattern matches the graph value. The destination node is a node in
|
||||
/// the pattern graph that will not have been created some time after the Branch node is
|
||||
/// created; use set_destination to add it.
|
||||
///
|
||||
/// The branch destination is not stored as a shared pointer to prevent reference
|
||||
/// cycles. Thus the destination node must be referenced in some other way to prevent it
|
||||
/// from being deleted.
|
||||
class NGRAPH_API Branch : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternBranch", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
/// \brief Creates a Branch pattern
|
||||
/// \param pattern the destinationing pattern
|
||||
/// \param labels Labels where the destination may occur
|
||||
Branch() : Pattern(OutputVector{}) {
|
||||
set_output_type(0, element::f32, Shape{});
|
||||
}
|
||||
|
||||
void set_destination(const Output<Node>& destination) {
|
||||
m_destination_node = destination.get_node();
|
||||
m_destination_index = destination.get_index();
|
||||
}
|
||||
|
||||
Output<Node> get_destination() const {
|
||||
return m_destination_node == nullptr
|
||||
? Output<Node>()
|
||||
: Output<Node>{m_destination_node->shared_from_this(), m_destination_index};
|
||||
}
|
||||
|
||||
bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
|
||||
protected:
|
||||
Node* m_destination_node{nullptr};
|
||||
size_t m_destination_index{0};
|
||||
};
|
||||
using ov::pass::pattern::op::Branch;
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
@ -6,37 +6,12 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/pattern/op/pattern.hpp"
|
||||
#include "openvino/pass/pattern/op/capture.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// Experimental for support of recurrent matches.
|
||||
///
|
||||
/// Capture adds the pattern value map to a list of pattern value maps and resets
|
||||
/// matches for pattern nodes not in the static node list. The match always succeeds.
|
||||
class NGRAPH_API Capture : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternCapture", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
Capture(const Output<Node>& arg) : Pattern({arg}) {
|
||||
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
|
||||
}
|
||||
|
||||
/// \brief static nodes are retained after a capture. All other nodes are dropped
|
||||
std::set<Node*> get_static_nodes() {
|
||||
return m_static_nodes;
|
||||
}
|
||||
void set_static_nodes(const std::set<Node*>& static_nodes) {
|
||||
m_static_nodes = static_nodes;
|
||||
}
|
||||
|
||||
virtual bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
|
||||
protected:
|
||||
std::set<Node*> m_static_nodes;
|
||||
};
|
||||
using ov::pass::pattern::op::Capture;
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
@ -6,106 +6,14 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/pattern/op/pattern.hpp"
|
||||
#include "openvino/pass/pattern/op/label.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// Fails if the predicate returns false on the graph value.
|
||||
///
|
||||
/// The graph value is added to the matched values list. If the Label is already
|
||||
/// associated with a value, the match succeeds if the value is the same as the graph
|
||||
/// value. Otherwise, the label is associated with the graph value and the match
|
||||
/// succeeds if the pattern input matches the graph value.
|
||||
///
|
||||
/// DEPRECATED: If no inputs are given to Label, a True node is serves as the input. If
|
||||
/// more than one inputs are given, an Or pattern of the inputs serves as the input.
|
||||
class NGRAPH_API Label : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternLabel", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
/// \brief creates a Label node containing a sub-pattern described by \sa type and
|
||||
/// \sa shape.
|
||||
///
|
||||
/// this Label node can be bound only to the nodes in the input graph
|
||||
/// that match the pattern specified by \sa wrapped_nodes
|
||||
/// Example:
|
||||
/// \code{.cpp}
|
||||
/// auto add = a + b; // a and b are op::Parameter in this example
|
||||
/// auto label = std::make_shared<pattern::op::Label>(element::f32,
|
||||
/// Shape{2,2},
|
||||
/// nullptr,
|
||||
/// OutputVector{add});
|
||||
/// \endcode
|
||||
Label(const element::Type& type,
|
||||
const PartialShape& s,
|
||||
const ValuePredicate pred,
|
||||
const OutputVector& wrapped_values)
|
||||
: Pattern(OutputVector{wrap_values(wrapped_values)}, pred) {
|
||||
set_output_type(0, type, s);
|
||||
}
|
||||
|
||||
explicit Label(const element::Type& type = element::dynamic, const PartialShape& s = PartialShape::dynamic())
|
||||
: Label(
|
||||
type,
|
||||
s,
|
||||
[](const Output<Node>&) {
|
||||
return true;
|
||||
},
|
||||
OutputVector()) {}
|
||||
|
||||
Label(const element::Type& type, const PartialShape& s, ValuePredicate pred)
|
||||
: Label(type, s, pred, OutputVector{}) {}
|
||||
|
||||
Label(const element::Type& type, const PartialShape& s, NodePredicate pred)
|
||||
: Label(type, s, as_value_predicate(pred), OutputVector{}) {}
|
||||
|
||||
Label(const element::Type& type, const PartialShape& s, const NodePredicate pred, const NodeVector& wrapped_values)
|
||||
: Label(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
|
||||
|
||||
/// \brief creates a Label node containing a sub-pattern described by the type and
|
||||
/// shape of \sa node.
|
||||
///
|
||||
/// this Label node can be bound only to the nodes in the input graph
|
||||
/// that match the pattern specified by \sa wrapped_values
|
||||
/// Example:
|
||||
/// \code{.cpp}
|
||||
/// auto add = a + b; // a and b are op::Parameter in this example
|
||||
/// auto label = std::make_shared<pattern::op::Label>(add,
|
||||
/// nullptr,
|
||||
/// OutputVector{add});
|
||||
/// \endcode
|
||||
Label(const Output<Node>& value, const ValuePredicate pred, const OutputVector& wrapped_values)
|
||||
: Label(value.get_element_type(), value.get_partial_shape(), pred, wrapped_values) {}
|
||||
Label(const Output<Node>& value, const ValuePredicate pred)
|
||||
: Label(value.get_element_type(), value.get_partial_shape(), pred, OutputVector{}) {}
|
||||
|
||||
Label(const Output<Node>& value, const NodePredicate pred)
|
||||
: Label(value.get_element_type(), value.get_partial_shape(), as_value_predicate(pred), OutputVector{}) {}
|
||||
Label(const Output<Node>& value)
|
||||
: Label(
|
||||
value.get_element_type(),
|
||||
value.get_partial_shape(),
|
||||
[](const Output<Node>&) {
|
||||
return true;
|
||||
},
|
||||
OutputVector{}) {}
|
||||
Label(const Output<Node>& node, const NodePredicate pred, const NodeVector& wrapped_values)
|
||||
: Label(node.get_element_type(),
|
||||
node.get_partial_shape(),
|
||||
as_value_predicate(pred),
|
||||
as_output_vector(wrapped_values)) {}
|
||||
|
||||
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
|
||||
|
||||
protected:
|
||||
static Output<Node> wrap_values(const OutputVector& wrapped_values);
|
||||
};
|
||||
using ov::pass::pattern::op::Label;
|
||||
} // namespace op
|
||||
|
||||
NGRAPH_API
|
||||
std::shared_ptr<Node> any_input();
|
||||
|
||||
NGRAPH_API
|
||||
std::shared_ptr<Node> any_input(const pattern::op::ValuePredicate& pred);
|
||||
using ov::pass::pattern::any_input;
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
@ -6,25 +6,12 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/pattern/op/pattern.hpp"
|
||||
#include "openvino/pass/pattern/op/or.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// A submatch on the graph value is performed on each input to the Or; the match
|
||||
/// succeeds on the first match. Otherwise the match fails.
|
||||
class NGRAPH_API Or : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternOr", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
/// \brief creates an Or node matching one of several sub-patterns in order. Does
|
||||
/// not add node to match list.
|
||||
/// \param patterns The patterns to try for matching
|
||||
Or(const OutputVector& patterns) : Pattern(patterns) {}
|
||||
|
||||
bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
};
|
||||
using ov::pass::pattern::op::Or;
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
@ -7,8 +7,10 @@
|
||||
#include <functional>
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
class Label;
|
||||
@ -16,79 +18,42 @@ class Label;
|
||||
|
||||
class Matcher;
|
||||
class MatchState;
|
||||
|
||||
using RPatternValueMap = std::map<std::shared_ptr<Node>, OutputVector>;
|
||||
using PatternValueMap = std::map<std::shared_ptr<Node>, Output<Node>>;
|
||||
using PatternValueMaps = std::vector<PatternValueMap>;
|
||||
|
||||
using PatternMap = std::map<std::shared_ptr<Node>, std::shared_ptr<Node>>;
|
||||
|
||||
PatternMap as_pattern_map(const PatternValueMap& pattern_value_map);
|
||||
PatternValueMap as_pattern_value_map(const PatternMap& pattern_map);
|
||||
|
||||
template <typename T>
|
||||
std::function<bool(std::shared_ptr<Node>)> has_class() {
|
||||
auto pred = [](std::shared_ptr<Node> node) -> bool {
|
||||
return ov::is_type<T>(node);
|
||||
};
|
||||
|
||||
return pred;
|
||||
} // namespace pattern
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
namespace ngraph {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
using ov::pass::pattern::op::Label;
|
||||
}
|
||||
|
||||
NGRAPH_API
|
||||
std::function<bool(Output<Node>)> consumers_count(size_t n);
|
||||
using ov::pass::pattern::Matcher;
|
||||
using ov::pass::pattern::MatcherState;
|
||||
|
||||
NGRAPH_API
|
||||
std::function<bool(Output<Node>)> has_static_dim(size_t pos);
|
||||
using ov::pass::pattern::PatternValueMap;
|
||||
using ov::pass::pattern::PatternValueMaps;
|
||||
using ov::pass::pattern::RPatternValueMap;
|
||||
|
||||
NGRAPH_API
|
||||
std::function<bool(Output<Node>)> has_static_dims(const std::vector<size_t>& dims);
|
||||
using ov::pass::pattern::PatternMap;
|
||||
|
||||
NGRAPH_API
|
||||
std::function<bool(Output<Node>)> has_static_shape();
|
||||
|
||||
NGRAPH_API
|
||||
std::function<bool(Output<Node>)> has_static_rank();
|
||||
|
||||
NGRAPH_API
|
||||
std::function<bool(Output<Node>)> rank_equals(const Dimension& expected_rank);
|
||||
|
||||
NGRAPH_API
|
||||
std::function<bool(Output<Node>)> type_matches(const element::Type& type);
|
||||
|
||||
NGRAPH_API
|
||||
std::function<bool(Output<Node>)> type_matches_any(const std::vector<element::Type>& types);
|
||||
using ov::pass::pattern::as_pattern_map;
|
||||
using ov::pass::pattern::as_pattern_value_map;
|
||||
using ov::pass::pattern::consumers_count;
|
||||
using ov::pass::pattern::has_class;
|
||||
using ov::pass::pattern::has_static_dim;
|
||||
using ov::pass::pattern::has_static_dims;
|
||||
using ov::pass::pattern::has_static_rank;
|
||||
using ov::pass::pattern::has_static_shape;
|
||||
using ov::pass::pattern::rank_equals;
|
||||
using ov::pass::pattern::type_matches;
|
||||
using ov::pass::pattern::type_matches_any;
|
||||
|
||||
namespace op {
|
||||
using NodePredicate = std::function<bool(std::shared_ptr<Node>)>;
|
||||
using ValuePredicate = std::function<bool(const Output<Node>& value)>;
|
||||
using ov::pass::pattern::op::NodePredicate;
|
||||
using ov::pass::pattern::op::ValuePredicate;
|
||||
|
||||
NGRAPH_API
|
||||
ValuePredicate as_value_predicate(NodePredicate pred);
|
||||
|
||||
class NGRAPH_API Pattern : public Node {
|
||||
public:
|
||||
/// \brief \p a base class for \sa Skip and \sa Label
|
||||
///
|
||||
Pattern(const OutputVector& patterns, ValuePredicate pred) : Node(patterns), m_predicate(pred) {
|
||||
if (!m_predicate) {
|
||||
m_predicate = [](const Output<Node>&) {
|
||||
return true;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
Pattern(const OutputVector& patterns) : Pattern(patterns, nullptr) {}
|
||||
|
||||
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& /* new_args */) const override {
|
||||
throw ngraph_error("Uncopyable");
|
||||
}
|
||||
|
||||
ValuePredicate get_predicate() const;
|
||||
|
||||
protected:
|
||||
ValuePredicate m_predicate;
|
||||
};
|
||||
using ov::pass::pattern::op::as_value_predicate;
|
||||
using ov::pass::pattern::op::Pattern;
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
@ -6,37 +6,12 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/pattern/op/pattern.hpp"
|
||||
#include "openvino/pass/pattern/op/skip.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// The graph value is added to the matched value list. If the predicate is true, the
|
||||
/// match succeeds if the arguments match; if the predicate is false, the match succeeds
|
||||
/// if the pattern input matches the graph value.
|
||||
class NGRAPH_API Skip : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternSkip", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
Skip(const Output<Node>& arg, ValuePredicate pred) : Pattern({arg}, pred) {
|
||||
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
|
||||
}
|
||||
|
||||
Skip(const Output<Node>& arg, NodePredicate pred = nullptr) : Pattern({arg}, as_value_predicate(pred)) {
|
||||
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
|
||||
}
|
||||
|
||||
Skip(const OutputVector& args, ValuePredicate pred) : Pattern(args, pred) {
|
||||
set_output_type(0, args.at(0).get_element_type(), args.at(0).get_partial_shape());
|
||||
}
|
||||
|
||||
Skip(const OutputVector& args, NodePredicate pred = nullptr) : Pattern(args, as_value_predicate(pred)) {
|
||||
set_output_type(0, args.at(0).get_element_type(), args.at(0).get_partial_shape());
|
||||
}
|
||||
|
||||
virtual bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
};
|
||||
using ov::pass::pattern::op::Skip;
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
@ -6,21 +6,12 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/pattern/op/pattern.hpp"
|
||||
#include "openvino/pass/pattern/op/true.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// \brief The match always succeeds.
|
||||
class NGRAPH_API True : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternTrue", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
/// \brief Always matches, does not add node to match list.
|
||||
True() : Pattern(OutputVector{}) {}
|
||||
bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
};
|
||||
using ov::pass::pattern::op::True;
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
@ -6,68 +6,14 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/pattern/op/pattern.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
class NGRAPH_API WrapType : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternAnyType", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
|
||||
explicit WrapType(
|
||||
NodeTypeInfo wrapped_type,
|
||||
const ValuePredicate& pred =
|
||||
[](const Output<Node>& output) {
|
||||
return true;
|
||||
},
|
||||
const OutputVector& input_values = {})
|
||||
: Pattern(input_values, pred),
|
||||
m_wrapped_types({wrapped_type}) {
|
||||
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
|
||||
}
|
||||
|
||||
explicit WrapType(
|
||||
std::vector<NodeTypeInfo> wrapped_types,
|
||||
const ValuePredicate& pred =
|
||||
[](const Output<Node>& output) {
|
||||
return true;
|
||||
},
|
||||
const OutputVector& input_values = {})
|
||||
: Pattern(input_values, pred),
|
||||
m_wrapped_types(std::move(wrapped_types)) {
|
||||
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
|
||||
}
|
||||
|
||||
bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
|
||||
NodeTypeInfo get_wrapped_type() const;
|
||||
|
||||
const std::vector<NodeTypeInfo>& get_wrapped_types() const;
|
||||
|
||||
private:
|
||||
std::vector<NodeTypeInfo> m_wrapped_types;
|
||||
};
|
||||
using ov::pass::pattern::op::WrapType;
|
||||
} // namespace op
|
||||
|
||||
template <class... Args>
|
||||
std::shared_ptr<Node> wrap_type(const OutputVector& inputs, const pattern::op::ValuePredicate& pred) {
|
||||
std::vector<DiscreteTypeInfo> info{Args::type_info...};
|
||||
return std::make_shared<op::WrapType>(info, pred, inputs);
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
std::shared_ptr<Node> wrap_type(const OutputVector& inputs = {}) {
|
||||
return wrap_type<Args...>(inputs, [](const Output<Node>& output) {
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
std::shared_ptr<Node> wrap_type(const pattern::op::ValuePredicate& pred) {
|
||||
return wrap_type<Args...>({}, pred);
|
||||
}
|
||||
using ov::pass::pattern::wrap_type;
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
@ -50,12 +50,14 @@ class Result;
|
||||
} // namespace v0
|
||||
} // namespace op
|
||||
|
||||
namespace pattern {
|
||||
class Matcher;
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace pattern {
|
||||
class Matcher;
|
||||
} // namespace pattern
|
||||
} // namespace pass
|
||||
using HostTensor = ngraph::runtime::HostTensor;
|
||||
using HostTensorPtr = std::shared_ptr<HostTensor>;
|
||||
using HostTensorVector = std::vector<HostTensorPtr>;
|
||||
@ -487,11 +489,11 @@ public:
|
||||
}
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
|
||||
virtual bool match_value(ngraph::pattern::Matcher* matcher,
|
||||
virtual bool match_value(ov::pass::pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value);
|
||||
|
||||
virtual bool match_node(ngraph::pattern::Matcher* matcher, const Output<Node>& graph_value);
|
||||
virtual bool match_node(ov::pass::pattern::Matcher* matcher, const Output<Node>& graph_value);
|
||||
|
||||
private:
|
||||
descriptor::Input& get_input_descriptor(size_t position);
|
||||
|
28
ngraph/core/include/openvino/pass/constant_folding.hpp
Normal file
28
ngraph/core/include/openvino/pass/constant_folding.hpp
Normal file
@ -0,0 +1,28 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/pass.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
/**
|
||||
* @brief Constant folding iterates over the function and tries to evaluate nodes
|
||||
* with constant inputs. Such nodes are then replaced with new Constants containing
|
||||
* the result of a folded operation.
|
||||
*/
|
||||
class OPENVINO_API ConstantFolding : public FunctionPass {
|
||||
public:
|
||||
OPENVINO_RTTI_DECLARATION;
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
|
||||
private:
|
||||
void copy_runtime_info_to_target_inputs(const std::shared_ptr<Node>& node, const Output<Node>& replacement);
|
||||
/// \brief Folds pre-calculated output tensor values to constants in case lower and
|
||||
/// upper estimations are equal. Traverses graph backwards starting from the results.
|
||||
bool pre_calculated_values_folding(const std::shared_ptr<ngraph::Function>& f);
|
||||
};
|
||||
} // namespace pass
|
||||
} // namespace ov
|
17
ngraph/core/include/openvino/pass/convert_fp32_to_fp16.hpp
Normal file
17
ngraph/core/include/openvino/pass/convert_fp32_to_fp16.hpp
Normal file
@ -0,0 +1,17 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
class OPENVINO_API ConvertFP32ToFP16 : public FunctionPass {
|
||||
public:
|
||||
OPENVINO_RTTI_DECLARATION;
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
|
||||
};
|
||||
} // namespace pass
|
||||
} // namespace ov
|
249
ngraph/core/include/openvino/pass/graph_rewrite.hpp
Normal file
249
ngraph/core/include/openvino/pass/graph_rewrite.hpp
Normal file
@ -0,0 +1,249 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "openvino/pass/pattern/matcher.hpp"
|
||||
|
||||
namespace ov {
|
||||
using matcher_pass_callback = std::function<bool(pass::pattern::Matcher& m)>;
|
||||
using graph_rewrite_callback = std::function<bool(pass::pattern::Matcher& m)>;
|
||||
using recurrent_graph_rewrite_callback = std::function<bool(pass::pattern::RecurrentMatcher& m)>;
|
||||
using handler_callback = std::function<bool(const std::shared_ptr<Node>& node)>;
|
||||
namespace pass {
|
||||
/// \brief MatcherPass is a basic block for pattern based transformations. It describes
|
||||
/// pattern and
|
||||
/// action that is applied if pattern is matched.
|
||||
///
|
||||
/// MatcherPass consists of Matcher and matcher_pass_callback that needs to be implemented
|
||||
/// and
|
||||
/// finally registered by using \sa register_matcher. MatcherPass can be executed on node
|
||||
/// within
|
||||
/// \sa apply method. To run matcher pass on Function use GraphRewrite.
|
||||
/// In addition MatcherPass provides a way for adding new operations into GraphRewrite
|
||||
/// execution
|
||||
/// queue. That means that operations that were created inside transformation callback can
|
||||
/// be added
|
||||
/// for matching. To register node use \sa register_new_node method. GraphRewrite
|
||||
/// automatically
|
||||
/// takes registered nodes and put them to execution queue. If multiple nodes were register
|
||||
/// make
|
||||
/// sure that they were registered in topological order.
|
||||
/// Note: when implementing pattern for Matcher make sure that root node is an operation
|
||||
/// from opset
|
||||
/// or has ov::pass::pattern::op::WrapType. That will help GraphRewrite to execute matcher
|
||||
/// passes more
|
||||
/// efficient.
|
||||
|
||||
class OPENVINO_API MatcherPass : public PassBase {
|
||||
public:
|
||||
OPENVINO_RTTI_DECLARATION;
|
||||
|
||||
MatcherPass() = default;
|
||||
|
||||
MatcherPass(const MatcherPass&) = delete;
|
||||
MatcherPass& operator=(const MatcherPass&) = delete;
|
||||
|
||||
explicit MatcherPass(const std::string& name,
|
||||
const std::shared_ptr<pattern::Matcher>& m,
|
||||
const handler_callback& handler,
|
||||
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE)
|
||||
: PassBase(),
|
||||
m_handler(handler),
|
||||
m_matcher(m) {
|
||||
set_name(name);
|
||||
set_property(property, true);
|
||||
}
|
||||
|
||||
bool apply(std::shared_ptr<ov::Node> node);
|
||||
|
||||
template <typename T, class... Args>
|
||||
std::shared_ptr<T> register_new_node(Args&&... args) {
|
||||
auto node = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
m_new_nodes.push_back(node);
|
||||
return node;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<T> register_new_node(const std::shared_ptr<T>& node) {
|
||||
m_new_nodes.push_back(node);
|
||||
return node;
|
||||
}
|
||||
|
||||
const std::vector<std::shared_ptr<ov::Node>>& get_new_nodes() {
|
||||
return m_new_nodes;
|
||||
}
|
||||
void clear_new_nodes() {
|
||||
m_new_nodes.clear();
|
||||
}
|
||||
std::shared_ptr<pattern::Matcher> get_matcher() {
|
||||
return m_matcher;
|
||||
}
|
||||
|
||||
protected:
|
||||
void register_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
|
||||
private:
|
||||
handler_callback m_handler;
|
||||
std::shared_ptr<pattern::Matcher> m_matcher;
|
||||
std::vector<std::shared_ptr<ov::Node>> m_new_nodes;
|
||||
};
|
||||
|
||||
/// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function
|
||||
/// in
|
||||
/// efficient way
|
||||
///
|
||||
/// Graph rewrite pass is used for matcher passes execution on Function.
|
||||
/// To register MatcherPass use \sa add_matcher<T>(args) method where T is a MatcherPass
|
||||
/// class.
|
||||
/// As a default algorithm graph rewrite pass traverse Function in topological order and
|
||||
/// applies
|
||||
/// registered matcher passes for each node. But if all registered matcher passes have type
|
||||
/// based
|
||||
/// root node in Matcher pattern then efficient mechanism is used to execute them.
|
||||
/// Matcher pattern root is type based if it's operation from opset or
|
||||
/// pattern::op::WrapType.
|
||||
/// Note: when implementing pattern for Matcher make sure that root node is an operation
|
||||
/// from opset
|
||||
/// or has ov::pattern::op::WrapType. That will help GraphRewrite to execute matcher
|
||||
/// passes more
|
||||
/// efficient.
|
||||
|
||||
class OPENVINO_API GraphRewrite : public FunctionPass {
|
||||
public:
|
||||
OPENVINO_RTTI_DECLARATION;
|
||||
|
||||
GraphRewrite() = default;
|
||||
|
||||
explicit GraphRewrite(const std::shared_ptr<MatcherPass>& pass) : FunctionPass() {
|
||||
m_matchers.push_back(pass);
|
||||
}
|
||||
|
||||
/// \brief Register given transformation class type to GraphRewrite execution list
|
||||
/// All registered transformations will be executed in a single graph traversal.
|
||||
/// Example below show the basic usage of pass::GraphRewrite
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// auto anchor = manager.register_pass<GraphRewrite>();
|
||||
/// anchor->add_matcher<MatcherPassA>();
|
||||
/// anchor->add_matcher<MatcherPassB>();
|
||||
/// anchor->set_name("CommonMatchers");
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// For some purposes transformation can be registered and disabled by default.
|
||||
///
|
||||
/// anchor->add_matcher<MatcherPassB, false>();
|
||||
///
|
||||
/// \return shared_ptr to the transformation instance
|
||||
template <typename T,
|
||||
bool Enabled = true,
|
||||
class... Args,
|
||||
typename std::enable_if<std::is_base_of<pass::MatcherPass, T>::value, bool>::type = true>
|
||||
std::shared_ptr<T> add_matcher(Args&&... args) {
|
||||
static_assert(std::is_base_of<pass::MatcherPass, T>::value, "pass not derived from MatcherPass");
|
||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
auto pass_config = get_pass_config();
|
||||
pass->set_pass_config(pass_config);
|
||||
if (!Enabled && !pass_config->is_enabled<T>()) {
|
||||
pass_config->disable<T>();
|
||||
}
|
||||
m_matchers.push_back(pass);
|
||||
return pass;
|
||||
}
|
||||
|
||||
/// \brief Register passes from GraphRewrite class that contains sequence of matcher
|
||||
/// passes registered in its ctor.
|
||||
/// For example:
|
||||
///
|
||||
/// class ov::pass::LinFusions: public ov::pass::GraphRewrite {
|
||||
/// public:
|
||||
/// OPENVINO_RTTI_DECLARATION;
|
||||
/// Fusions() {
|
||||
/// add_matcher<ov::pass::AddFusion>();
|
||||
/// add_matcher<ov::pass::MulFusion>();
|
||||
/// }
|
||||
/// };
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// auto anchor = manager.register_pass<GraphRewrite>();
|
||||
/// anchor->add_matcher<LinFusions>();
|
||||
/// anchor->add_matcher<OtherFusions>();
|
||||
/// anchor->set_name("CommonFusions");
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// In this case all matcher passes from LinFusions pass will be united with other
|
||||
/// registered matchers.
|
||||
template <typename T,
|
||||
class... Args,
|
||||
typename std::enable_if<std::is_base_of<pass::GraphRewrite, T>::value, bool>::type = true>
|
||||
void add_matcher(Args&&... args) {
|
||||
static_assert(std::is_base_of<pass::GraphRewrite, T>::value, "pass not derived from GraphRewrite");
|
||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
auto pass_config = get_pass_config();
|
||||
|
||||
for (auto& matcher : pass->m_matchers) {
|
||||
pass->set_pass_config(pass_config);
|
||||
m_matchers.push_back(matcher);
|
||||
}
|
||||
}
|
||||
|
||||
OPENVINO_DEPRECATED("Use MatcherPass instead")
|
||||
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property);
|
||||
|
||||
OPENVINO_DEPRECATED("Use MatcherPass instead")
|
||||
void add_matcher(const std::shared_ptr<pattern::Matcher>& m, const ov::graph_rewrite_callback& callback);
|
||||
|
||||
bool run_on_function(std::shared_ptr<ov::Function> f) override;
|
||||
|
||||
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
|
||||
|
||||
protected:
|
||||
bool apply_matcher_passes(std::shared_ptr<Function> f, std::deque<std::weak_ptr<Node>> nodes_to_run);
|
||||
|
||||
bool m_enable_shape_inference = false;
|
||||
|
||||
std::vector<std::shared_ptr<ov::pass::MatcherPass>> m_matchers;
|
||||
};
|
||||
|
||||
class OPENVINO_API BackwardGraphRewrite : public GraphRewrite {
|
||||
public:
|
||||
OPENVINO_RTTI_DECLARATION;
|
||||
|
||||
BackwardGraphRewrite() = default;
|
||||
|
||||
explicit BackwardGraphRewrite(const std::shared_ptr<MatcherPass>& pass) : GraphRewrite(pass) {}
|
||||
|
||||
bool run_on_function(std::shared_ptr<ov::Function> f) override;
|
||||
};
|
||||
|
||||
class OPENVINO_API RecurrentGraphRewrite : public FunctionPass {
|
||||
public:
|
||||
RecurrentGraphRewrite(size_t num_iters = 10) : FunctionPass(), m_num_iters(num_iters) {}
|
||||
|
||||
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
||||
const ov::recurrent_graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property);
|
||||
|
||||
// TODO: This interface may deprecate after all passes are refactored.
|
||||
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
||||
const ov::recurrent_graph_rewrite_callback& callback);
|
||||
|
||||
bool run_on_function(std::shared_ptr<ov::Function> f) override;
|
||||
|
||||
private:
|
||||
size_t m_num_iters;
|
||||
|
||||
std::vector<std::shared_ptr<ov::pass::MatcherPass>> m_matchers;
|
||||
};
|
||||
} // namespace pass
|
||||
} // namespace ov
|
48
ngraph/core/include/openvino/pass/low_latency.hpp
Normal file
48
ngraph/core/include/openvino/pass/low_latency.hpp
Normal file
@ -0,0 +1,48 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "openvino/pass/pass.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
/**
|
||||
* @brief The transformation finds all TensorIterator/Loop layers in the network,
|
||||
* processes all back edges that describe a connection between Result and Parameter
|
||||
* of the TensorIterator/Loop bodies,and inserts ReadValue and Assign layers at the
|
||||
* input and output corresponding to this back edge.
|
||||
* Supported platforms: CPU, GNA.
|
||||
*
|
||||
* The example below describes the changes made by the transformation
|
||||
* [] - TensorIterator body
|
||||
* () - new layer
|
||||
* BE - back-edge
|
||||
*
|
||||
* before applying the transformation:
|
||||
* -> input1[BE_1 -> Parameter -> Layers ... -> Result -> BE_1 ]output1->
|
||||
*
|
||||
* after applying the transformation:
|
||||
* ->(ReadValue)-> input1[BE_1 ->Parameter->Layers ...->Result->BE_1]output1 ->(Assign)
|
||||
* \
|
||||
* ->...
|
||||
* After applying the transformation, the resulting network can be inferred
|
||||
* step by step, the states will store between inferences.
|
||||
*/
|
||||
class OPENVINO_API LowLatency2 : public FunctionPass {
|
||||
public:
|
||||
OPENVINO_RTTI_DECLARATION;
|
||||
|
||||
explicit LowLatency2(bool use_const_initializer = true) : m_use_const_initializer(use_const_initializer) {}
|
||||
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
|
||||
private:
|
||||
bool m_use_const_initializer;
|
||||
};
|
||||
} // namespace pass
|
||||
} // namespace ov
|
116
ngraph/core/include/openvino/pass/manager.hpp
Normal file
116
ngraph/core/include/openvino/pass/manager.hpp
Normal file
@ -0,0 +1,116 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <typeinfo>
|
||||
#include <vector>
|
||||
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "openvino/pass/validate.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
class OPENVINO_API Manager {
|
||||
public:
|
||||
Manager();
|
||||
~Manager();
|
||||
|
||||
//// \brief Construct Manager with shared PassConfig instance
|
||||
explicit Manager(std::shared_ptr<PassConfig> pass_config);
|
||||
|
||||
/// \brief Register given transformation class type to execution list
|
||||
/// Example below show the basic usage of pass::Manager
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// manager.register_pass<MyTransformation>(/*transformation constructor ars*/);
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// For some purposes transformation can be registered and disabled by default.
|
||||
///
|
||||
/// manager.register_pass<MyTransformation, false>();
|
||||
///
|
||||
/// \return shared_ptr to the transformation instance
|
||||
template <typename T, bool Enable = true, class... Args>
|
||||
std::shared_ptr<T> register_pass(Args&&... args) {
|
||||
auto rc = push_pass<T>(std::forward<Args>(args)...);
|
||||
rc->set_pass_config(m_pass_config);
|
||||
if (m_per_pass_validation) {
|
||||
push_pass<Validate>();
|
||||
}
|
||||
if (!Enable && !m_pass_config->is_enabled<T>()) {
|
||||
m_pass_config->disable<T>();
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
void run_passes(std::shared_ptr<Function>);
|
||||
|
||||
void set_pass_visualization(bool new_state) {
|
||||
m_visualize = new_state;
|
||||
}
|
||||
/// \brief Set flag to enable/disable running Validate pass after executing
|
||||
/// each registered pass
|
||||
/// \param new_state Value "true" enables Validate pass run; "false", otherwise
|
||||
void set_per_pass_validation(bool new_state) {
|
||||
m_per_pass_validation = new_state;
|
||||
}
|
||||
/// \brief Callback is a lambda function that can be used by registered transformations.
|
||||
/// The main purpose of this callback is to provide a way for plugins to disable/enable
|
||||
/// transformations based on some conditions. In some cases plugins may want not to
|
||||
/// execute some
|
||||
/// transformations.
|
||||
/// For example plugin can disable unpleasant decompositions because of performance
|
||||
/// reasons for
|
||||
/// some cases.
|
||||
/// Callback example:
|
||||
/// auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
||||
/// return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) !=
|
||||
/// nullptr;
|
||||
/// };
|
||||
/// This callback returns true in case of DepthToSpace operation. So when execution
|
||||
/// DepthToSpace
|
||||
/// decomposition pass will check is this decomposition needed or plugin can execute
|
||||
/// this
|
||||
/// operation directly. And of course on transformation side we need to have a response
|
||||
/// for this
|
||||
/// callback.
|
||||
/// if (transformation_callback(batch_to_space)) {
|
||||
/// return false;
|
||||
/// }
|
||||
/// \param callback lamda function that returns true in case if node is supported by
|
||||
/// plugin and
|
||||
/// transformation is not needed
|
||||
OPENVINO_DEPRECATED("Please use get_pass_config() to configure transformation pipeline")
|
||||
void set_callback(const param_callback& callback) {
|
||||
m_pass_config->set_callback(callback);
|
||||
}
|
||||
/// \return PassConfig shared object. This object is used for transformations pipeline
|
||||
/// configuration.
|
||||
/// This object allows to disable/enable transformations execution, set callback to
|
||||
/// particular
|
||||
/// transformation. For mo details see PassConfig class.
|
||||
std::shared_ptr<PassConfig> get_pass_config() {
|
||||
return m_pass_config;
|
||||
}
|
||||
|
||||
protected:
|
||||
template <typename T, class... Args>
|
||||
std::shared_ptr<T> push_pass(Args&&... args) {
|
||||
static_assert(std::is_base_of<pass::PassBase, T>::value, "pass not derived from pass base");
|
||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
auto pass_base = std::static_pointer_cast<PassBase>(pass);
|
||||
m_pass_list.push_back(pass_base);
|
||||
return pass;
|
||||
}
|
||||
|
||||
std::shared_ptr<PassConfig> m_pass_config;
|
||||
std::vector<std::shared_ptr<PassBase>> m_pass_list;
|
||||
bool m_visualize = false;
|
||||
bool m_per_pass_validation = true;
|
||||
};
|
||||
} // namespace pass
|
||||
} // namespace ov
|
110
ngraph/core/include/openvino/pass/pass.hpp
Normal file
110
ngraph/core/include/openvino/pass/pass.hpp
Normal file
@ -0,0 +1,110 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "ngraph/util.hpp"
|
||||
#include "openvino/core/core_visibility.hpp"
|
||||
#include "openvino/core/deprecated.hpp"
|
||||
#include "openvino/core/function.hpp"
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/pass/pass_config.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
enum class PassProperty : uint32_t {
|
||||
// Pass requires node shapes to be static
|
||||
REQUIRE_STATIC_SHAPE = 0x1,
|
||||
// Pass transformation will change the function's dynamic state
|
||||
CHANGE_DYNAMIC_STATE = 1 << 1,
|
||||
};
|
||||
|
||||
using PassPropertyMask = ngraph::EnumMask<PassProperty>;
|
||||
|
||||
class OPENVINO_API PassBase {
|
||||
friend class Manager;
|
||||
|
||||
public:
|
||||
PassBase();
|
||||
virtual ~PassBase() = default;
|
||||
/// Check if this pass has all the pass properties.
|
||||
bool get_property(const PassPropertyMask& prop_mask) const;
|
||||
|
||||
void set_name(const std::string& name) {
|
||||
m_name = name;
|
||||
}
|
||||
std::string get_name() const;
|
||||
|
||||
/// \brief Set callback for particular transformation type.
|
||||
/// This method set global callback. For more details see PassConfig class
|
||||
/// documentation.
|
||||
/// \param callback lambda function that takes node and returns bool
|
||||
void set_callback(const param_callback& callback);
|
||||
|
||||
/// \brief Set PassConfig for particular transformation instance
|
||||
/// \param pass_config is a PassConfig shared_ptr
|
||||
virtual void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) {
|
||||
m_pass_config = pass_config;
|
||||
}
|
||||
|
||||
/// \brief Allows to access PassConfig shared instance
|
||||
/// \return Shared instance of PassConfig class
|
||||
std::shared_ptr<PassConfig> get_pass_config() {
|
||||
return m_pass_config;
|
||||
}
|
||||
/// \brief Applies callback for given node. By default callback returns false.
|
||||
/// This method remains here only for backward compatibility and will be removed
|
||||
/// after all transformations are moved to transformation_callback() method.
|
||||
/// \return result of callback execution for given node
|
||||
NGRAPH_DEPRECATED("Please use transformation_callback method instead")
|
||||
bool m_transformation_callback(const std::shared_ptr<const Node>& node) {
|
||||
return m_pass_config->get_callback(get_type_info())(node);
|
||||
}
|
||||
|
||||
/// \brief Applies callback for given node. By default callback returns false.
|
||||
/// \param node which will be used inside callback
|
||||
/// \return result of callback execution for given node
|
||||
bool transformation_callback(const std::shared_ptr<const Node>& node) {
|
||||
return m_pass_config->get_callback(get_type_info())(node);
|
||||
}
|
||||
|
||||
using type_info_t = DiscreteTypeInfo;
|
||||
|
||||
virtual const type_info_t& get_type_info() const = 0;
|
||||
|
||||
protected:
|
||||
void set_property(const PassPropertyMask& prop, bool value);
|
||||
|
||||
private:
|
||||
PassPropertyMask m_property;
|
||||
|
||||
std::string m_name;
|
||||
std::shared_ptr<PassConfig> m_pass_config;
|
||||
};
|
||||
|
||||
class OPENVINO_API FunctionPass : public PassBase {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
~FunctionPass() override;
|
||||
virtual bool run_on_function(std::shared_ptr<ngraph::Function>) = 0;
|
||||
};
|
||||
|
||||
class Manager;
|
||||
enum class FusionType : uint32_t {
|
||||
//`DIFFERENTIABLE_FUSIONS` produce ops that support autodiff
|
||||
// i.e. implement `generate_adjoints`
|
||||
DIFFERENTIABLE_FUSIONS = 0x1,
|
||||
REGULAR_FUSIONS = 0x2,
|
||||
//`FOP_FUSIONS` produce ops in the FusedOps category that might
|
||||
// not be supported by all backends
|
||||
FOP_FUSIONS = 0x4,
|
||||
ALL_FUSIONS = 0xFFFFFFFF
|
||||
};
|
||||
using FusionTypeMask = ngraph::EnumMask<FusionType>;
|
||||
} // namespace pass
|
||||
} // namespace ov
|
176
ngraph/core/include/openvino/pass/pass_config.hpp
Normal file
176
ngraph/core/include/openvino/pass/pass_config.hpp
Normal file
@ -0,0 +1,176 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "ngraph/util.hpp"
|
||||
#include "openvino/core/core_visibility.hpp"
|
||||
#include "openvino/core/deprecated.hpp"
|
||||
#include "openvino/core/function.hpp"
|
||||
#include "openvino/core/node.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
using param_callback = std::function<bool(const std::shared_ptr<const ::ov::Node>)>;
|
||||
using param_callback_map = std::map<ov::DiscreteTypeInfo, param_callback>;
|
||||
|
||||
/// \brief Class representing a transformations config that is used for disabling/enabling
|
||||
/// transformations registered inside pass::Manager and also allows to set callback for all
|
||||
/// transformations or for particular transformation.
|
||||
///
|
||||
/// When pass::Manager is created all passes registered inside this manager including nested
|
||||
/// passes will share the same instance of PassConfig class.
|
||||
/// To work with this class first you need to get shared instance of this class by calling
|
||||
/// manager.get_pass_config() method. Then you will be able to disable/enable passes based
|
||||
/// on transformations type_info. For example:
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// manager.register_pass<CommonOptimizations>();
|
||||
/// auto pass_config = manager.get_pass_config();
|
||||
/// pass_config->disable<ConvertGELU>(); // this will disable nested pass inside
|
||||
/// // CommonOptimizations pipeline
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// Sometimes it is needed to call transformation inside other transformation manually. And
|
||||
/// for that case before running transformation you need manually check that this pass is
|
||||
/// not disabled and then you need to set current PassConfig instance to this
|
||||
/// transformation. For example:
|
||||
///
|
||||
/// // Inside MatcherPass callback or inside FunctionPass run_on_function() method
|
||||
/// // you need to call get_pass_config() method to get shared instance of PassConfig
|
||||
/// auto pass_config = get_pass_config();
|
||||
///
|
||||
/// // Before running nested transformation you need to check is it disabled or not
|
||||
/// if (!pass_config->is_disabled<ConvertGELU>()) {
|
||||
/// auto pass = ConvertGELU();
|
||||
/// pass->set_pass_config(pass_config);
|
||||
/// pass.apply(node);
|
||||
/// }
|
||||
///
|
||||
/// Following this logic inside your transformations you will guaranty that transformations
|
||||
/// will be executed in a right way.
|
||||
class OPENVINO_API PassConfig {
|
||||
public:
|
||||
/// \brief Disable transformation by its type_info
|
||||
/// \param type_info Transformation type_info
|
||||
void disable(const DiscreteTypeInfo& type_info);
|
||||
/// \brief Disable transformation by its class type (based on type_info)
|
||||
template <typename T>
|
||||
void disable() {
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
disable(T::type_info);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
/// \brief Enable transformation by its type_info
|
||||
/// \param type_info Transformation type_info
|
||||
void enable(const DiscreteTypeInfo& type_info);
|
||||
/// \brief Enable transformation by its class type (based on type_info)
|
||||
template <typename T>
|
||||
void enable() {
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
enable(T::type_info);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
/// \brief Set callback for all kind of transformations
|
||||
void set_callback(const param_callback& callback) {
|
||||
m_callback = callback;
|
||||
}
|
||||
template <typename... Args>
|
||||
typename std::enable_if<sizeof...(Args) == 0>::type set_callback(const param_callback& callback) {}
|
||||
|
||||
/// \brief Set callback for particular transformation class types
|
||||
///
|
||||
/// Example below show how to set callback for one or multiple passes using this method.
|
||||
///
|
||||
/// pass_config->set_callback<ngraph::pass::ConvertBatchToSpace,
|
||||
/// ngraph::pass::ConvertSpaceToBatch>(
|
||||
/// [](const_node_ptr &node) -> bool {
|
||||
/// // Disable transformations for cases when input shape rank is not
|
||||
/// equal to 4
|
||||
/// const auto input_shape_rank =
|
||||
/// node->get_output_partial_shape(0).rank().get_length();
|
||||
/// if (input_shape_rank != 4) {
|
||||
/// return false;
|
||||
/// }
|
||||
/// return true;
|
||||
/// });
|
||||
///
|
||||
/// Note that inside transformations you must provide code that work with this callback.
|
||||
/// See example below:
|
||||
///
|
||||
/// if (transformation_callback(node)) {
|
||||
/// return false; // exit from transformation
|
||||
/// }
|
||||
///
|
||||
template <typename T, class... Args>
|
||||
void set_callback(const param_callback& callback) {
|
||||
m_callback_map[T::type_info] = callback;
|
||||
set_callback<Args...>(callback);
|
||||
}
|
||||
|
||||
/// \brief Get callback for given transformation type_info
|
||||
/// \param type_info Transformation type_info
|
||||
///
|
||||
/// In case if callback wasn't set for given transformation type then global callback
|
||||
/// will be returned. But if even global callback wasn't set then default callback will
|
||||
/// be returned.
|
||||
param_callback get_callback(const DiscreteTypeInfo& type_info) const;
|
||||
|
||||
/// \brief Get callback for given transformation class type
|
||||
/// \return callback lambda function
|
||||
template <typename T>
|
||||
param_callback get_callback() const {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
return get_callback(T::type_info);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
/// \brief Check either transformation type is disabled or not
|
||||
/// \param type_info Transformation type_info
|
||||
/// \return true if transformation type was disabled and false otherwise
|
||||
bool is_disabled(const DiscreteTypeInfo& type_info) const {
|
||||
return m_disabled.count(type_info);
|
||||
}
|
||||
|
||||
/// \brief Check either transformation class type is disabled or not
|
||||
/// \return true if transformation type was disabled and false otherwise
|
||||
template <typename T>
|
||||
bool is_disabled() const {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
return is_disabled(T::type_info);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
/// \brief Check either transformation type is force enabled or not
|
||||
/// \param type_info Transformation type_info
|
||||
/// \return true if transformation type was force enabled and false otherwise
|
||||
bool is_enabled(const DiscreteTypeInfo& type_info) const {
|
||||
return m_enabled.count(type_info);
|
||||
}
|
||||
|
||||
/// \brief Check either transformation class type is force enabled or not
|
||||
/// \return true if transformation type was force enabled and false otherwise
|
||||
template <typename T>
|
||||
bool is_enabled() const {
|
||||
return is_enabled(T::type_info);
|
||||
}
|
||||
|
||||
void add_disabled_passes(const PassConfig& rhs);
|
||||
|
||||
private:
|
||||
param_callback m_callback = [](const std::shared_ptr<const ::ngraph::Node>&) {
|
||||
return false;
|
||||
};
|
||||
param_callback_map m_callback_map;
|
||||
std::unordered_set<DiscreteTypeInfo> m_disabled;
|
||||
std::unordered_set<DiscreteTypeInfo> m_enabled;
|
||||
};
|
||||
} // namespace pass
|
||||
} // namespace ov
|
271
ngraph/core/include/openvino/pass/pattern/matcher.hpp
Normal file
271
ngraph/core/include/openvino/pass/pattern/matcher.hpp
Normal file
@ -0,0 +1,271 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "openvino/core/except.hpp"
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/pass/pattern/op/any.hpp"
|
||||
#include "openvino/pass/pattern/op/any_of.hpp"
|
||||
#include "openvino/pass/pattern/op/any_output.hpp"
|
||||
#include "openvino/pass/pattern/op/label.hpp"
|
||||
#include "openvino/pass/pattern/op/skip.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
class GraphRewrite;
|
||||
|
||||
namespace pattern {
|
||||
class Matcher;
|
||||
|
||||
class OPENVINO_API MatcherState {
|
||||
public:
|
||||
MatcherState(Matcher*);
|
||||
bool finish(bool is_successful);
|
||||
~MatcherState();
|
||||
|
||||
protected:
|
||||
Matcher* m_matcher;
|
||||
PatternValueMap m_pattern_value_map;
|
||||
PatternValueMaps m_pattern_value_maps;
|
||||
size_t m_watermark;
|
||||
size_t m_capture_size;
|
||||
bool m_restore{true};
|
||||
};
|
||||
|
||||
/// Matcher looks for node patterns in a computation graph. The patterns are described by an
|
||||
/// automaton that is described by an extended computation graph. The matcher executes
|
||||
/// by attempting to match the start node of the pattern to a computation graph value
|
||||
/// (output of a Node). In addition to determing if a match occurs, a pattern node may add
|
||||
/// graph nodes to a list of matched nodes, associate nodes with graph values, and start
|
||||
/// submatches. Submatches add match state changes to the enclosing match if the submatch
|
||||
/// succeeds; otherwise the state is reverted.
|
||||
///
|
||||
/// The default match behavior of a pattern node with a graph nodes is that the computation
|
||||
/// graph value is added to the end of the matched value list and the match succeeds if the
|
||||
/// node/pattern types match and the input values match. In the case of a commutative node,
|
||||
/// the inputs can match in any order. If the matcher is in strict mode, the graph value
|
||||
/// element type and shape must also match.
|
||||
///
|
||||
/// Pattern nodes that have different match behavior are in ov::pass::pattern::op and have
|
||||
/// descriptions of their match behavior.
|
||||
class OPENVINO_API Matcher {
|
||||
public:
|
||||
using PatternMap = ov::pass::pattern::PatternMap;
|
||||
|
||||
// Avoid implicit string construction from nullptr.
|
||||
Matcher(const std::shared_ptr<Node> pattern_node, std::nullptr_t name) = delete;
|
||||
|
||||
Matcher() = default;
|
||||
Matcher(Output<Node>& pattern_node) : m_pattern_node{pattern_node} {}
|
||||
|
||||
Matcher(Output<Node>& pattern_node, const std::string& name) : m_pattern_node(pattern_node), m_name{name} {}
|
||||
|
||||
/// \brief Constructs a Matcher object
|
||||
///
|
||||
/// \param pattern_node is a pattern sub graph that will be matched against input graphs
|
||||
/// \param name is a string which is used for logging and disabling a matcher
|
||||
/// \param strict_mode forces a matcher to consider shapes and ET of nodes
|
||||
Matcher(const Output<Node>& pattern_node, const std::string& name, bool strict_mode)
|
||||
: m_pattern_node(pattern_node),
|
||||
m_name(name),
|
||||
m_strict_mode(strict_mode) {}
|
||||
|
||||
// Some matches should start on a node rather than an output. These three constructors
|
||||
// are transition until we work out the right way to do that.
|
||||
Matcher(std::shared_ptr<Node> pattern_node);
|
||||
Matcher(std::shared_ptr<Node> pattern_node, const std::string& name);
|
||||
Matcher(std::shared_ptr<Node> pattern_node, const std::string& name, bool strict_mode);
|
||||
|
||||
virtual ~Matcher() = default;
|
||||
/// \brief Matches a pattern to \p graph_node
|
||||
///
|
||||
/// \param graph_value is an input graph to be matched against
|
||||
bool match(const Output<Node>& graph_value);
|
||||
|
||||
bool match(std::shared_ptr<Node> graph_node);
|
||||
|
||||
/// \brief Matches a pattern to \p graph_node
|
||||
///
|
||||
/// \param graph_value is an input graph to be matched against
|
||||
/// \param previous_matches contains previous mappings from labels to nodes to use
|
||||
bool match(const Output<Node>& graph_value, const PatternMap& previous_matches);
|
||||
bool match(const Output<Node>& graph_value, const PatternValueMap& previous_matches);
|
||||
|
||||
template <typename T>
|
||||
static std::shared_ptr<T> unique_match(const std::shared_ptr<Node>& node) {
|
||||
std::shared_ptr<T> matched;
|
||||
for (const auto& arg : node->input_values()) {
|
||||
if (auto t_casted = ov::as_type_ptr<T>(arg.get_node_shared_ptr())) {
|
||||
if (matched) {
|
||||
throw Exception("There's more than two arguments of the same type");
|
||||
} else {
|
||||
matched = t_casted;
|
||||
}
|
||||
}
|
||||
}
|
||||
return matched;
|
||||
}
|
||||
|
||||
bool is_contained_match(const NodeVector& exclusions = {}, bool ignore_unused = true);
|
||||
const NodeVector get_matched_nodes() {
|
||||
return as_node_vector(m_matched_list);
|
||||
}
|
||||
const OutputVector& get_matched_values() const {
|
||||
return m_matched_list;
|
||||
}
|
||||
OutputVector& get_matched_values() {
|
||||
return m_matched_list;
|
||||
}
|
||||
void reset() {}
|
||||
const std::string& get_name() {
|
||||
return m_name;
|
||||
}
|
||||
std::shared_ptr<Node> get_pattern() {
|
||||
return m_pattern_node.get_node_shared_ptr();
|
||||
}
|
||||
Output<Node> get_pattern_value() {
|
||||
return m_pattern_node;
|
||||
}
|
||||
std::shared_ptr<Node> get_match_root();
|
||||
Output<Node> get_match_value();
|
||||
PatternMap get_pattern_map() const;
|
||||
PatternValueMap& get_pattern_value_map() {
|
||||
return m_pattern_map;
|
||||
}
|
||||
PatternValueMaps& get_pattern_value_maps() {
|
||||
return m_pattern_value_maps;
|
||||
}
|
||||
/// \brief Low-level helper to match recurring patterns
|
||||
///
|
||||
/// \param graph is a graph to be matched against
|
||||
/// \param pattern is a recurring pattern
|
||||
/// \param rpattern specifies a node to recur from next
|
||||
/// \param patterns a map from labels to matches
|
||||
|
||||
size_t add_node(Output<Node> node);
|
||||
|
||||
bool virtual match_value(const ov::Output<Node>& pattern_value, const ov::Output<Node>& graph_value);
|
||||
|
||||
bool is_strict_mode() {
|
||||
return m_strict_mode;
|
||||
}
|
||||
virtual bool match_arguments(Node* pattern_node, const std::shared_ptr<Node>& graph_node);
|
||||
|
||||
void capture(const std::set<Node*>& static_nodes);
|
||||
|
||||
void clear_state();
|
||||
|
||||
size_t get_number_of_recurrent_matches() const {
|
||||
return m_pattern_value_maps.size();
|
||||
}
|
||||
NodeVector get_bound_nodes_for_pattern(const Output<Node>& pattern) const;
|
||||
size_t get_number_of_bound_labels() const;
|
||||
/// \brief Try a match
|
||||
MatcherState start_match();
|
||||
|
||||
Output<Node> m_match_root;
|
||||
Output<Node> m_pattern_node;
|
||||
PatternValueMap m_pattern_map;
|
||||
PatternValueMaps m_pattern_value_maps;
|
||||
OutputVector m_matched_list;
|
||||
|
||||
protected:
|
||||
bool match_permutation(const OutputVector& pattern_args, const OutputVector& args);
|
||||
|
||||
std::string m_name{"unnamed"};
|
||||
bool m_strict_mode{false};
|
||||
};
|
||||
|
||||
class OPENVINO_API RecurrentMatcher {
|
||||
public:
|
||||
/// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
|
||||
/// repeating patterns (e.g. RNN, LSTM, GRU cells)
|
||||
///
|
||||
/// \param initial_pattern is a pattern sub graph describing the initial cell
|
||||
/// \param pattern is a pattern sub graph describing an individual cell
|
||||
/// \param rpattern is a (recurring) label to denote which node the next match should
|
||||
/// start at
|
||||
/// \param correlated_patterns is a set of labels whose bound nodes must remain the same
|
||||
/// across all cells
|
||||
RecurrentMatcher(const Output<Node>& initial_pattern,
|
||||
const Output<Node>& pattern,
|
||||
const std::shared_ptr<Node>& rpattern,
|
||||
const std::set<std::shared_ptr<Node>>& correlated_patterns)
|
||||
: m_initial_pattern(initial_pattern),
|
||||
m_pattern(pattern),
|
||||
m_recurrent_pattern(rpattern),
|
||||
m_correlated_patterns(correlated_patterns) {}
|
||||
|
||||
/// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
|
||||
/// repeating patterns (e.g. RNN, LSTM, GRU cells)
|
||||
///
|
||||
/// \param pattern is a pattern sub graph describing an individual cell
|
||||
/// \param rpattern is a (recurring) label to denote which node the next match should
|
||||
/// start at
|
||||
/// \param correlated_patterns is a set of labels whose bound nodes must remain the same
|
||||
/// across all cells
|
||||
RecurrentMatcher(const Output<Node>& pattern,
|
||||
const std::shared_ptr<Node>& rpattern,
|
||||
const std::set<std::shared_ptr<Node>>& correlated_patterns)
|
||||
: RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns) {}
|
||||
|
||||
RecurrentMatcher(const Output<Node>& initial_pattern,
|
||||
const Output<Node>& pattern,
|
||||
const std::shared_ptr<Node>& rpattern,
|
||||
const std::set<std::shared_ptr<op::Label>>& correlated_patterns);
|
||||
|
||||
RecurrentMatcher(const Output<Node>& pattern,
|
||||
const std::shared_ptr<Node>& rpattern,
|
||||
const std::set<std::shared_ptr<op::Label>>& correlated_patterns)
|
||||
: RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns) {}
|
||||
|
||||
/// \brief Returns a vector of bound nodes for a given label (used in a pattern
|
||||
/// describing an individual cell
|
||||
NodeVector get_bound_nodes_for_pattern(const std::shared_ptr<Node>& pattern) const {
|
||||
if (m_matches.count(pattern) == 0) {
|
||||
throw Exception("No bound nodes for a given label");
|
||||
}
|
||||
|
||||
return as_node_vector(m_matches.at(pattern));
|
||||
}
|
||||
|
||||
size_t get_number_of_recurrent_matches() const {
|
||||
if (m_matches.size() == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return (*m_matches.begin()).second.size();
|
||||
}
|
||||
|
||||
size_t get_number_of_bound_labels() const {
|
||||
return m_matches.size();
|
||||
}
|
||||
/// \brief Tries to match a pattern for an individual cell to a given \p graph
|
||||
bool match(Output<Node> graph);
|
||||
|
||||
std::shared_ptr<Node> get_match_root() {
|
||||
return m_match_root.get_node_shared_ptr();
|
||||
}
|
||||
Output<Node> get_match_value() {
|
||||
return m_match_root;
|
||||
}
|
||||
|
||||
private:
|
||||
Output<Node> m_initial_pattern;
|
||||
Output<Node> m_pattern;
|
||||
std::shared_ptr<Node> m_recurrent_pattern;
|
||||
const std::set<std::shared_ptr<Node>> m_correlated_patterns;
|
||||
RPatternValueMap m_matches;
|
||||
Output<Node> m_match_root;
|
||||
};
|
||||
} // namespace pattern
|
||||
} // namespace pass
|
||||
} // namespace ov
|
45
ngraph/core/include/openvino/pass/pattern/op/any.hpp
Normal file
45
ngraph/core/include/openvino/pass/pattern/op/any.hpp
Normal file
@ -0,0 +1,45 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// The graph value is to the matched value list. If the predicate is true for the node
|
||||
/// and the arguments match, the match succeeds.
|
||||
class OPENVINO_API Any : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternAny", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
/// \brief creates a Any node containing a sub-pattern described by \sa type and \sa
|
||||
/// shape.
|
||||
Any(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values)
|
||||
: Pattern(wrapped_values, pred) {
|
||||
set_output_type(0, type, s);
|
||||
}
|
||||
Any(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values)
|
||||
: Any(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
|
||||
/// \brief creates a Any node containing a sub-pattern described by the type and
|
||||
/// shape of \sa node.
|
||||
Any(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
|
||||
: Any(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
|
||||
Any(const Output<Node>& node, NodePredicate pred, const NodeVector& wrapped_values)
|
||||
: Any(node.get_element_type(),
|
||||
node.get_partial_shape(),
|
||||
as_value_predicate(pred),
|
||||
as_output_vector(wrapped_values)) {}
|
||||
|
||||
bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
};
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace pass
|
||||
} // namespace ov
|
54
ngraph/core/include/openvino/pass/pattern/op/any_of.hpp
Normal file
54
ngraph/core/include/openvino/pass/pattern/op/any_of.hpp
Normal file
@ -0,0 +1,54 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// The graph value is added to the matched values list. If the predicate is true for
|
||||
/// the
|
||||
/// graph node, a submatch is performed on the input of AnyOf and each input of the
|
||||
/// graph node. The first match that succeeds results in a successful match. Otherwise
|
||||
/// the match fails.
|
||||
///
|
||||
/// AnyOf may be given a type and shape for use in strict mode.
|
||||
class OPENVINO_API AnyOf : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternAnyOf", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
/// \brief creates a AnyOf node containing a sub-pattern described by \sa type and
|
||||
/// \sa shape.
|
||||
AnyOf(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values)
|
||||
: Pattern(wrapped_values, pred) {
|
||||
if (wrapped_values.size() != 1) {
|
||||
throw Exception("AnyOf expects exactly one argument");
|
||||
}
|
||||
set_output_type(0, type, s);
|
||||
}
|
||||
AnyOf(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values)
|
||||
: AnyOf(
|
||||
type,
|
||||
s,
|
||||
[pred](const Output<Node>& value) {
|
||||
return pred(value.get_node_shared_ptr());
|
||||
},
|
||||
as_output_vector(wrapped_values)) {}
|
||||
|
||||
/// \brief creates a AnyOf node containing a sub-pattern described by the type and
|
||||
/// shape of \sa node.
|
||||
AnyOf(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
|
||||
: AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
|
||||
AnyOf(const std::shared_ptr<Node>& node, NodePredicate pred, const NodeVector& wrapped_values)
|
||||
: AnyOf(node, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
|
||||
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
|
||||
};
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace pass
|
||||
} // namespace ov
|
30
ngraph/core/include/openvino/pass/pattern/op/any_output.hpp
Normal file
30
ngraph/core/include/openvino/pass/pattern/op/any_output.hpp
Normal file
@ -0,0 +1,30 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// Matches any output of a node
|
||||
class OPENVINO_API AnyOutput : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternAnyOutput", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
/// \brief creates an AnyOutput node matching any output of a node
|
||||
/// \param node The node to match
|
||||
AnyOutput(const std::shared_ptr<Node>& pattern) : Pattern({pattern->output(0)}) {}
|
||||
|
||||
bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
};
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace pass
|
||||
} // namespace ov
|
55
ngraph/core/include/openvino/pass/pattern/op/branch.hpp
Normal file
55
ngraph/core/include/openvino/pass/pattern/op/branch.hpp
Normal file
@ -0,0 +1,55 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// A branch adds a loop to the pattern. The branch match is successful if the
|
||||
/// destination node pattern matches the graph value. The destination node is a node in
|
||||
/// the pattern graph that will not have been created some time after the Branch node is
|
||||
/// created; use set_destination to add it.
|
||||
///
|
||||
/// The branch destination is not stored as a shared pointer to prevent reference
|
||||
/// cycles. Thus the destination node must be referenced in some other way to prevent it
|
||||
/// from being deleted.
|
||||
class OPENVINO_API Branch : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternBranch", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
/// \brief Creates a Branch pattern
|
||||
/// \param pattern the destinationing pattern
|
||||
/// \param labels Labels where the destination may occur
|
||||
Branch() : Pattern(OutputVector{}) {
|
||||
set_output_type(0, element::f32, ngraph::Shape{});
|
||||
}
|
||||
|
||||
void set_destination(const Output<Node>& destination) {
|
||||
m_destination_node = destination.get_node();
|
||||
m_destination_index = destination.get_index();
|
||||
}
|
||||
|
||||
Output<Node> get_destination() const {
|
||||
return m_destination_node == nullptr
|
||||
? Output<Node>()
|
||||
: Output<Node>{m_destination_node->shared_from_this(), m_destination_index};
|
||||
}
|
||||
|
||||
bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
|
||||
protected:
|
||||
Node* m_destination_node{nullptr};
|
||||
size_t m_destination_index{0};
|
||||
};
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace pass
|
||||
} // namespace ov
|
44
ngraph/core/include/openvino/pass/pattern/op/capture.hpp
Normal file
44
ngraph/core/include/openvino/pass/pattern/op/capture.hpp
Normal file
@ -0,0 +1,44 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// Experimental for support of recurrent matches.
|
||||
///
|
||||
/// Capture adds the pattern value map to a list of pattern value maps and resets
|
||||
/// matches for pattern nodes not in the static node list. The match always succeeds.
|
||||
class OPENVINO_API Capture : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternCapture", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
Capture(const Output<Node>& arg) : Pattern({arg}) {
|
||||
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
|
||||
}
|
||||
|
||||
/// \brief static nodes are retained after a capture. All other nodes are dropped
|
||||
std::set<Node*> get_static_nodes() {
|
||||
return m_static_nodes;
|
||||
}
|
||||
void set_static_nodes(const std::set<Node*>& static_nodes) {
|
||||
m_static_nodes = static_nodes;
|
||||
}
|
||||
|
||||
bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
|
||||
protected:
|
||||
std::set<Node*> m_static_nodes;
|
||||
};
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace pass
|
||||
} // namespace ov
|
113
ngraph/core/include/openvino/pass/pattern/op/label.hpp
Normal file
113
ngraph/core/include/openvino/pass/pattern/op/label.hpp
Normal file
@ -0,0 +1,113 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// Fails if the predicate returns false on the graph value.
|
||||
///
|
||||
/// The graph value is added to the matched values list. If the Label is already
|
||||
/// associated with a value, the match succeeds if the value is the same as the graph
|
||||
/// value. Otherwise, the label is associated with the graph value and the match
|
||||
/// succeeds if the pattern input matches the graph value.
|
||||
///
|
||||
/// DEPRECATED: If no inputs are given to Label, a True node is serves as the input. If
|
||||
/// more than one inputs are given, an Or pattern of the inputs serves as the input.
|
||||
class OPENVINO_API Label : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternLabel", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
/// \brief creates a Label node containing a sub-pattern described by \sa type and
|
||||
/// \sa shape.
|
||||
///
|
||||
/// this Label node can be bound only to the nodes in the input graph
|
||||
/// that match the pattern specified by \sa wrapped_nodes
|
||||
/// Example:
|
||||
/// \code{.cpp}
|
||||
/// auto add = a + b; // a and b are op::Parameter in this example
|
||||
/// auto label = std::make_shared<pattern::op::Label>(element::f32,
|
||||
/// Shape{2,2},
|
||||
/// nullptr,
|
||||
/// OutputVector{add});
|
||||
/// \endcode
|
||||
Label(const element::Type& type,
|
||||
const PartialShape& s,
|
||||
const ValuePredicate pred,
|
||||
const OutputVector& wrapped_values)
|
||||
: Pattern(OutputVector{wrap_values(wrapped_values)}, pred) {
|
||||
set_output_type(0, type, s);
|
||||
}
|
||||
|
||||
explicit Label(const element::Type& type = element::dynamic, const PartialShape& s = PartialShape::dynamic())
|
||||
: Label(
|
||||
type,
|
||||
s,
|
||||
[](const Output<Node>&) {
|
||||
return true;
|
||||
},
|
||||
OutputVector()) {}
|
||||
|
||||
Label(const element::Type& type, const PartialShape& s, ValuePredicate pred)
|
||||
: Label(type, s, pred, OutputVector{}) {}
|
||||
|
||||
Label(const element::Type& type, const PartialShape& s, NodePredicate pred)
|
||||
: Label(type, s, as_value_predicate(pred), OutputVector{}) {}
|
||||
|
||||
Label(const element::Type& type, const PartialShape& s, const NodePredicate pred, const NodeVector& wrapped_values)
|
||||
: Label(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
|
||||
|
||||
/// \brief creates a Label node containing a sub-pattern described by the type and
|
||||
/// shape of \sa node.
|
||||
///
|
||||
/// this Label node can be bound only to the nodes in the input graph
|
||||
/// that match the pattern specified by \sa wrapped_values
|
||||
/// Example:
|
||||
/// \code{.cpp}
|
||||
/// auto add = a + b; // a and b are op::Parameter in this example
|
||||
/// auto label = std::make_shared<pattern::op::Label>(add,
|
||||
/// nullptr,
|
||||
/// OutputVector{add});
|
||||
/// \endcode
|
||||
Label(const Output<Node>& value, const ValuePredicate pred, const OutputVector& wrapped_values)
|
||||
: Label(value.get_element_type(), value.get_partial_shape(), pred, wrapped_values) {}
|
||||
Label(const Output<Node>& value, const ValuePredicate pred)
|
||||
: Label(value.get_element_type(), value.get_partial_shape(), pred, OutputVector{}) {}
|
||||
|
||||
Label(const Output<Node>& value, const NodePredicate pred)
|
||||
: Label(value.get_element_type(), value.get_partial_shape(), as_value_predicate(pred), OutputVector{}) {}
|
||||
Label(const Output<Node>& value)
|
||||
: Label(
|
||||
value.get_element_type(),
|
||||
value.get_partial_shape(),
|
||||
[](const Output<Node>&) {
|
||||
return true;
|
||||
},
|
||||
OutputVector{}) {}
|
||||
Label(const Output<Node>& node, const NodePredicate pred, const NodeVector& wrapped_values)
|
||||
: Label(node.get_element_type(),
|
||||
node.get_partial_shape(),
|
||||
as_value_predicate(pred),
|
||||
as_output_vector(wrapped_values)) {}
|
||||
|
||||
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
|
||||
|
||||
protected:
|
||||
static Output<Node> wrap_values(const OutputVector& wrapped_values);
|
||||
};
|
||||
} // namespace op
|
||||
|
||||
OPENVINO_API
|
||||
std::shared_ptr<Node> any_input();
|
||||
|
||||
OPENVINO_API
|
||||
std::shared_ptr<Node> any_input(const pattern::op::ValuePredicate& pred);
|
||||
} // namespace pattern
|
||||
} // namespace pass
|
||||
} // namespace ov
|
32
ngraph/core/include/openvino/pass/pattern/op/or.hpp
Normal file
32
ngraph/core/include/openvino/pass/pattern/op/or.hpp
Normal file
@ -0,0 +1,32 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// A submatch on the graph value is performed on each input to the Or; the match
|
||||
/// succeeds on the first match. Otherwise the match fails.
|
||||
class OPENVINO_API Or : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternOr", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
/// \brief creates an Or node matching one of several sub-patterns in order. Does
|
||||
/// not add node to match list.
|
||||
/// \param patterns The patterns to try for matching
|
||||
Or(const OutputVector& patterns) : Pattern(patterns) {}
|
||||
|
||||
bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
};
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace pass
|
||||
} // namespace ov
|
96
ngraph/core/include/openvino/pass/pattern/op/pattern.hpp
Normal file
96
ngraph/core/include/openvino/pass/pattern/op/pattern.hpp
Normal file
@ -0,0 +1,96 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "openvino/core/node.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
class Label;
|
||||
}
|
||||
|
||||
class Matcher;
|
||||
class MatcherState;
|
||||
|
||||
using RPatternValueMap = std::map<std::shared_ptr<Node>, OutputVector>;
|
||||
using PatternValueMap = std::map<std::shared_ptr<Node>, Output<Node>>;
|
||||
using PatternValueMaps = std::vector<PatternValueMap>;
|
||||
|
||||
using PatternMap = std::map<std::shared_ptr<Node>, std::shared_ptr<Node>>;
|
||||
|
||||
PatternMap as_pattern_map(const PatternValueMap& pattern_value_map);
|
||||
PatternValueMap as_pattern_value_map(const PatternMap& pattern_map);
|
||||
|
||||
template <typename T>
|
||||
std::function<bool(std::shared_ptr<Node>)> has_class() {
|
||||
auto pred = [](std::shared_ptr<Node> node) -> bool {
|
||||
return ov::is_type<T>(node);
|
||||
};
|
||||
|
||||
return pred;
|
||||
}
|
||||
|
||||
OPENVINO_API
|
||||
std::function<bool(Output<Node>)> consumers_count(size_t n);
|
||||
|
||||
OPENVINO_API
|
||||
std::function<bool(Output<Node>)> has_static_dim(size_t pos);
|
||||
|
||||
OPENVINO_API
|
||||
std::function<bool(Output<Node>)> has_static_dims(const std::vector<size_t>& dims);
|
||||
|
||||
OPENVINO_API
|
||||
std::function<bool(Output<Node>)> has_static_shape();
|
||||
|
||||
OPENVINO_API
|
||||
std::function<bool(Output<Node>)> has_static_rank();
|
||||
|
||||
OPENVINO_API
|
||||
std::function<bool(Output<Node>)> rank_equals(const Dimension& expected_rank);
|
||||
|
||||
OPENVINO_API
|
||||
std::function<bool(Output<Node>)> type_matches(const element::Type& type);
|
||||
|
||||
OPENVINO_API
|
||||
std::function<bool(Output<Node>)> type_matches_any(const std::vector<element::Type>& types);
|
||||
|
||||
namespace op {
|
||||
using NodePredicate = std::function<bool(std::shared_ptr<Node>)>;
|
||||
using ValuePredicate = std::function<bool(const Output<Node>& value)>;
|
||||
|
||||
OPENVINO_API
|
||||
ValuePredicate as_value_predicate(NodePredicate pred);
|
||||
|
||||
class OPENVINO_API Pattern : public Node {
|
||||
public:
|
||||
/// \brief \p a base class for \sa Skip and \sa Label
|
||||
///
|
||||
Pattern(const OutputVector& patterns, ValuePredicate pred) : Node(patterns), m_predicate(pred) {
|
||||
if (!m_predicate) {
|
||||
m_predicate = [](const Output<Node>&) {
|
||||
return true;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
Pattern(const OutputVector& patterns) : Pattern(patterns, nullptr) {}
|
||||
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& /* new_args */) const override {
|
||||
throw Exception("Uncopyable");
|
||||
}
|
||||
|
||||
ValuePredicate get_predicate() const;
|
||||
|
||||
protected:
|
||||
ValuePredicate m_predicate;
|
||||
};
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace pass
|
||||
} // namespace ov
|
44
ngraph/core/include/openvino/pass/pattern/op/skip.hpp
Normal file
44
ngraph/core/include/openvino/pass/pattern/op/skip.hpp
Normal file
@ -0,0 +1,44 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// The graph value is added to the matched value list. If the predicate is true, the
|
||||
/// match succeeds if the arguments match; if the predicate is false, the match succeeds
|
||||
/// if the pattern input matches the graph value.
|
||||
class OPENVINO_API Skip : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternSkip", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
Skip(const Output<Node>& arg, ValuePredicate pred) : Pattern({arg}, pred) {
|
||||
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
|
||||
}
|
||||
|
||||
Skip(const Output<Node>& arg, NodePredicate pred = nullptr) : Pattern({arg}, as_value_predicate(pred)) {
|
||||
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
|
||||
}
|
||||
|
||||
Skip(const OutputVector& args, ValuePredicate pred) : Pattern(args, pred) {
|
||||
set_output_type(0, args.at(0).get_element_type(), args.at(0).get_partial_shape());
|
||||
}
|
||||
|
||||
Skip(const OutputVector& args, NodePredicate pred = nullptr) : Pattern(args, as_value_predicate(pred)) {
|
||||
set_output_type(0, args.at(0).get_element_type(), args.at(0).get_partial_shape());
|
||||
}
|
||||
|
||||
bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
};
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace pass
|
||||
} // namespace ov
|
28
ngraph/core/include/openvino/pass/pattern/op/true.hpp
Normal file
28
ngraph/core/include/openvino/pass/pattern/op/true.hpp
Normal file
@ -0,0 +1,28 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// \brief The match always succeeds.
|
||||
class OPENVINO_API True : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternTrue", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
/// \brief Always matches, does not add node to match list.
|
||||
True() : Pattern(OutputVector{}) {}
|
||||
bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
};
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace pass
|
||||
} // namespace ov
|
75
ngraph/core/include/openvino/pass/pattern/op/wrap_type.hpp
Normal file
75
ngraph/core/include/openvino/pass/pattern/op/wrap_type.hpp
Normal file
@ -0,0 +1,75 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
class OPENVINO_API WrapType : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternAnyType", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
|
||||
explicit WrapType(
|
||||
NodeTypeInfo wrapped_type,
|
||||
const ValuePredicate& pred =
|
||||
[](const Output<Node>& output) {
|
||||
return true;
|
||||
},
|
||||
const OutputVector& input_values = {})
|
||||
: Pattern(input_values, pred),
|
||||
m_wrapped_types({wrapped_type}) {
|
||||
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
|
||||
}
|
||||
|
||||
explicit WrapType(
|
||||
std::vector<NodeTypeInfo> wrapped_types,
|
||||
const ValuePredicate& pred =
|
||||
[](const Output<Node>& output) {
|
||||
return true;
|
||||
},
|
||||
const OutputVector& input_values = {})
|
||||
: Pattern(input_values, pred),
|
||||
m_wrapped_types(std::move(wrapped_types)) {
|
||||
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
|
||||
}
|
||||
|
||||
bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
|
||||
NodeTypeInfo get_wrapped_type() const;
|
||||
|
||||
const std::vector<NodeTypeInfo>& get_wrapped_types() const;
|
||||
|
||||
private:
|
||||
std::vector<NodeTypeInfo> m_wrapped_types;
|
||||
};
|
||||
} // namespace op
|
||||
|
||||
template <class... Args>
|
||||
std::shared_ptr<Node> wrap_type(const OutputVector& inputs, const pattern::op::ValuePredicate& pred) {
|
||||
std::vector<DiscreteTypeInfo> info{Args::type_info...};
|
||||
return std::make_shared<op::WrapType>(info, pred, inputs);
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
std::shared_ptr<Node> wrap_type(const OutputVector& inputs = {}) {
|
||||
return wrap_type<Args...>(inputs, [](const Output<Node>& output) {
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
std::shared_ptr<Node> wrap_type(const pattern::op::ValuePredicate& pred) {
|
||||
return wrap_type<Args...>({}, pred);
|
||||
}
|
||||
} // namespace pattern
|
||||
} // namespace pass
|
||||
} // namespace ov
|
32
ngraph/core/include/openvino/pass/validate.hpp
Normal file
32
ngraph/core/include/openvino/pass/validate.hpp
Normal file
@ -0,0 +1,32 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/core/core_visibility.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
/// \brief The Validate pass performs sanity checks on attributes and inputs, and
|
||||
/// computes output shapes and element types for all computation nodes in a given
|
||||
/// computation graph.
|
||||
///
|
||||
/// \details The verification and inference is done via invoking each node's specific
|
||||
/// implementation of \link ov::Node::validate_and_infer_types() \endlink function.
|
||||
///
|
||||
/// By default, the \ref ov::pass::Manager runs this pass after executing every
|
||||
/// optimization pass. This is to ensure that any update to the graph by an optimization
|
||||
/// pass does not break the shape and data type requirement on a computation node.
|
||||
/// This default validation run can be changed via calling the
|
||||
/// \link ov::pass::Manager::set_per_pass_validation(bool) \endlink function.
|
||||
class OPENVINO_API Validate : public FunctionPass {
|
||||
public:
|
||||
OPENVINO_RTTI_DECLARATION;
|
||||
|
||||
Validate() : FunctionPass() {}
|
||||
bool run_on_function(std::shared_ptr<ov::Function> f) override;
|
||||
};
|
||||
} // namespace pass
|
||||
} // namespace ov
|
57
ngraph/core/include/openvino/pass/visualize_tree.hpp
Normal file
57
ngraph/core/include/openvino/pass/visualize_tree.hpp
Normal file
@ -0,0 +1,57 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <typeindex>
|
||||
#include <typeinfo>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "openvino/pass/pass.hpp"
|
||||
|
||||
class HeightMap;
|
||||
|
||||
using visualize_tree_ops_map_t =
|
||||
std::unordered_map<ov::Node::type_info_t, std::function<void(const ov::Node&, std::ostream& ss)>>;
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
class OPENVINO_API VisualizeTree : public FunctionPass {
|
||||
public:
|
||||
OPENVINO_RTTI_DECLARATION;
|
||||
|
||||
using node_modifiers_t = std::function<void(const Node& node, std::vector<std::string>& attributes)>;
|
||||
VisualizeTree(const std::string& file_name, node_modifiers_t nm = nullptr, bool dot_only = false);
|
||||
bool run_on_function(std::shared_ptr<ov::Function>) override;
|
||||
|
||||
void set_ops_to_details(const visualize_tree_ops_map_t& ops_map) {
|
||||
m_ops_to_details = ops_map;
|
||||
}
|
||||
|
||||
protected:
|
||||
void add_node_arguments(std::shared_ptr<Node> node,
|
||||
std::unordered_map<Node*, HeightMap>& height_maps,
|
||||
size_t& fake_node_ctr);
|
||||
std::string add_attributes(std::shared_ptr<Node> node);
|
||||
virtual std::string get_attributes(std::shared_ptr<Node> node);
|
||||
virtual std::string get_node_name(std::shared_ptr<Node> node);
|
||||
std::string get_constant_value(std::shared_ptr<Node> node, size_t max_elements = 7);
|
||||
|
||||
void render() const;
|
||||
|
||||
std::stringstream m_ss;
|
||||
std::string m_name;
|
||||
std::set<std::shared_ptr<Node>> m_nodes_with_attributes;
|
||||
visualize_tree_ops_map_t m_ops_to_details;
|
||||
node_modifiers_t m_node_modifiers = nullptr;
|
||||
bool m_dot_only;
|
||||
static const int max_jump_distance;
|
||||
};
|
||||
} // namespace pass
|
||||
} // namespace ov
|
@ -11,11 +11,10 @@
|
||||
#include "ngraph/validation_util.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConstantFolding, "ConstantFolding", 0);
|
||||
OPENVINO_RTTI_DEFINITION(ov::pass::ConstantFolding, "ConstantFolding", 0);
|
||||
|
||||
bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||
bool ov::pass::ConstantFolding::run_on_function(std::shared_ptr<ov::Function> f) {
|
||||
bool rewritten = pre_calculated_values_folding(f);
|
||||
|
||||
for (const auto& node : f->get_ordered_ops()) {
|
||||
@ -48,7 +47,7 @@ bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr<ngraph::Func
|
||||
}
|
||||
} else {
|
||||
// recursively constant fold operators containing subgraphs (ie: TensorIterator, Loop)
|
||||
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
|
||||
if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node)) {
|
||||
if (const auto& sub_graph = sub_graph_node->get_function()) {
|
||||
rewritten |= run_on_function(sub_graph);
|
||||
}
|
||||
@ -79,14 +78,14 @@ bool ngraph::pass::ConstantFolding::pre_calculated_values_folding(const std::sha
|
||||
while (!nodes.empty()) {
|
||||
auto curr_node = nodes.front();
|
||||
nodes.pop_front();
|
||||
if (visited.count(curr_node) || ov::is_type<op::Constant>(curr_node))
|
||||
if (visited.count(curr_node) || ov::is_type<ngraph::op::Constant>(curr_node))
|
||||
continue;
|
||||
visited.insert(curr_node);
|
||||
|
||||
for (auto& input_value : curr_node->input_values()) {
|
||||
// Check that ConstantFolding is not disabled on this path
|
||||
std::vector<Node*> order;
|
||||
auto status = could_propagate(input_value, order);
|
||||
auto status = ngraph::could_propagate(input_value, order);
|
||||
if (status) {
|
||||
for (const auto& node : order) {
|
||||
const auto& rt_info = node->get_rt_info();
|
||||
@ -99,8 +98,8 @@ bool ngraph::pass::ConstantFolding::pre_calculated_values_folding(const std::sha
|
||||
|
||||
if (status && input_value.get_tensor().has_and_set_bound()) {
|
||||
auto input_node = input_value.get_node_shared_ptr();
|
||||
auto replacement = std::make_shared<op::Constant>(input_value.get_tensor().get_lower_value());
|
||||
if (replacement && !ov::is_type<op::Constant>(input_node)) {
|
||||
auto replacement = std::make_shared<ngraph::op::Constant>(input_value.get_tensor().get_lower_value());
|
||||
if (replacement && !ov::is_type<ngraph::op::Constant>(input_node)) {
|
||||
if (input_node->get_output_size() == 1) {
|
||||
replacement->set_friendly_name(input_node->get_friendly_name());
|
||||
} else {
|
||||
|
@ -9,12 +9,11 @@
|
||||
#include "transformations/convert_precision.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertFP32ToFP16, "ConvertFP32ToFP16", 0);
|
||||
OPENVINO_RTTI_DEFINITION(ov::pass::ConvertFP32ToFP16, "ConvertFP32ToFP16", 0);
|
||||
|
||||
bool ngraph::pass::ConvertFP32ToFP16::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||
ngraph::pass::Manager m(get_pass_config());
|
||||
bool ov::pass::ConvertFP32ToFP16::run_on_function(std::shared_ptr<ov::Function> f) {
|
||||
ov::pass::Manager m(get_pass_config());
|
||||
m.register_pass<ngraph::pass::ConvertPrecision>(precisions_array{{ngraph::element::f32, ngraph::element::f16}});
|
||||
m.run_passes(f);
|
||||
return false;
|
||||
|
@ -18,9 +18,6 @@
|
||||
#include "ngraph/op/util/sub_graph_base.hpp"
|
||||
#include "perf_counters.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
/* GraphRewrite algorithm:
|
||||
* GraphRewrite processes an input graph in an topological order(i.e. args before users)
|
||||
* Given the following graph: Abs2
|
||||
@ -33,7 +30,7 @@ using namespace ngraph;
|
||||
* Note, `Abs2` comes before `Neg3` as `Abs2`'s id = 2 is *less* than `Neg3`'s one (id = 3)
|
||||
* Next, GraphRewrite will invoke matchers passes registered in add_matcher order.
|
||||
* For example:
|
||||
* ngraph::pass::GraphRewrite pass;
|
||||
* ov::pass::GraphRewrite pass;
|
||||
* pass.add_matcher<m1>();
|
||||
* pass.add_matcher<m2>();
|
||||
* pass.add_matcher<m3>();
|
||||
@ -53,13 +50,13 @@ using namespace ngraph;
|
||||
* If MatcherPass register more than one node make sure that this nodes are registered in
|
||||
* topological order. */
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::GraphRewrite, "ngraph::pass::GraphRewrite", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ov::pass::GraphRewrite, "ov::pass::GraphRewrite", 0);
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::BackwardGraphRewrite, "ngraph::pass::BackwardGraphRewrite", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ov::pass::BackwardGraphRewrite, "ov::pass::BackwardGraphRewrite", 0);
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::MatcherPass, "ngraph::pass::MatcherPass", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ov::pass::MatcherPass, "ov::pass::MatcherPass", 0);
|
||||
|
||||
namespace ngraph {
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace internal {
|
||||
PerfCounters& perf_counters_graph_rewrite() {
|
||||
@ -68,27 +65,28 @@ PerfCounters& perf_counters_graph_rewrite() {
|
||||
}
|
||||
} // namespace internal
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
} // namespace ov
|
||||
|
||||
bool pass::BackwardGraphRewrite::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||
bool ov::pass::BackwardGraphRewrite::run_on_function(std::shared_ptr<ov::Function> f) {
|
||||
// Initialize execution queue with nodes in topological order
|
||||
deque<std::weak_ptr<Node>> nodes_to_run;
|
||||
std::deque<std::weak_ptr<Node>> nodes_to_run;
|
||||
for (auto& node : f->get_ordered_ops()) {
|
||||
nodes_to_run.emplace_front(node);
|
||||
}
|
||||
return apply_matcher_passes(f, std::move(nodes_to_run));
|
||||
}
|
||||
|
||||
bool pass::GraphRewrite::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||
bool ov::pass::GraphRewrite::run_on_function(std::shared_ptr<ov::Function> f) {
|
||||
// Initialize execution queue with nodes in topological order
|
||||
deque<std::weak_ptr<Node>> nodes_to_run;
|
||||
std::deque<std::weak_ptr<Node>> nodes_to_run;
|
||||
for (auto& node : f->get_ordered_ops()) {
|
||||
nodes_to_run.emplace_back(node);
|
||||
}
|
||||
return apply_matcher_passes(f, std::move(nodes_to_run));
|
||||
}
|
||||
|
||||
bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f, deque<std::weak_ptr<Node>> nodes_to_run) {
|
||||
bool ov::pass::GraphRewrite::apply_matcher_passes(std::shared_ptr<Function> f,
|
||||
std::deque<std::weak_ptr<Node>> nodes_to_run) {
|
||||
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, "pass::GraphRewrite::run_on_function");
|
||||
|
||||
bool rewritten = false;
|
||||
@ -111,7 +109,7 @@ bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f, deque<std:
|
||||
auto root = matcher->get_pattern_value().get_node_shared_ptr();
|
||||
// pattern::op::AnyOutput operation automatically appends for multi output operations inside
|
||||
// Matcher and to gen actual root node we need to take it's parent.
|
||||
if (auto any_type = dynamic_pointer_cast<pattern::op::AnyOutput>(root)) {
|
||||
if (auto any_type = std::dynamic_pointer_cast<pattern::op::AnyOutput>(root)) {
|
||||
root = any_type->input_value(0).get_node_shared_ptr();
|
||||
}
|
||||
|
||||
@ -119,8 +117,8 @@ bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f, deque<std:
|
||||
// it's type
|
||||
// and use it in unordered_map as key for fast MatcherPass search. Otherwise type is unknown
|
||||
// and default algorithm is used.
|
||||
if (auto p = dynamic_pointer_cast<pattern::op::Pattern>(root)) {
|
||||
if (auto any_type = dynamic_pointer_cast<pattern::op::WrapType>(p)) {
|
||||
if (auto p = std::dynamic_pointer_cast<pattern::op::Pattern>(root)) {
|
||||
if (auto any_type = std::dynamic_pointer_cast<pattern::op::WrapType>(p)) {
|
||||
for (const auto& root_type_info : any_type->get_wrapped_types()) {
|
||||
type_to_matcher[root_type_info].push_back(matcher_index);
|
||||
}
|
||||
@ -180,7 +178,7 @@ bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f, deque<std:
|
||||
continue;
|
||||
|
||||
// Recursive apply Matchers for sub-graph based nodes
|
||||
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
|
||||
if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node)) {
|
||||
if (auto sub_graph = sub_graph_node->get_function()) {
|
||||
run_on_function(sub_graph);
|
||||
}
|
||||
@ -236,9 +234,9 @@ bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f, deque<std:
|
||||
return rewritten;
|
||||
}
|
||||
|
||||
void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
|
||||
const graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property) {
|
||||
void ov::pass::GraphRewrite::add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property) {
|
||||
m_matchers.push_back(std::make_shared<MatcherPass>(
|
||||
m->get_name(),
|
||||
m,
|
||||
@ -258,7 +256,8 @@ void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
|
||||
property));
|
||||
}
|
||||
|
||||
void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m, const graph_rewrite_callback& callback) {
|
||||
void ov::pass::GraphRewrite::add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const graph_rewrite_callback& callback) {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
// TODO: before deprecate this function, by default expect the
|
||||
// callback require static shape.
|
||||
@ -266,7 +265,7 @@ void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m, cons
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
void pass::GraphRewrite::set_pass_config(const std::shared_ptr<PassConfig>& rhs) {
|
||||
void ov::pass::GraphRewrite::set_pass_config(const std::shared_ptr<PassConfig>& rhs) {
|
||||
auto pass_config = get_pass_config();
|
||||
// We have to preserve disabled passes because in case when we register matchers inside
|
||||
// GraphRewrite c-tor we work with local PassConfig instance.
|
||||
@ -293,9 +292,9 @@ void pass::GraphRewrite::set_pass_config(const std::shared_ptr<PassConfig>& rhs)
|
||||
}
|
||||
}
|
||||
|
||||
void pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
||||
const ngraph::recurrent_graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property) {
|
||||
void ov::pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
||||
const ov::recurrent_graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property) {
|
||||
m_matchers.push_back(std::make_shared<MatcherPass>(
|
||||
"Recurrent matcher",
|
||||
nullptr,
|
||||
@ -310,24 +309,24 @@ void pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::Rec
|
||||
property));
|
||||
}
|
||||
|
||||
void pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
||||
const ngraph::recurrent_graph_rewrite_callback& callback) {
|
||||
void ov::pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
||||
const ov::recurrent_graph_rewrite_callback& callback) {
|
||||
// TODO: before deprecate this function, by default expect the
|
||||
// callback require static shape.
|
||||
add_matcher(m, callback, {PassProperty::REQUIRE_STATIC_SHAPE});
|
||||
}
|
||||
|
||||
bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f) {
|
||||
bool ov::pass::RecurrentGraphRewrite::run_on_function(std::shared_ptr<Function> f) {
|
||||
bool changed = false;
|
||||
size_t i = 0;
|
||||
|
||||
// This check is very expensive and is only needed for experimental features, so we will hide
|
||||
// it behind an environment variable for now. TODO: Find a less expensive way to handle this.
|
||||
static bool s_rerun_dynamic_check = getenv_bool("NGRAPH_GRAPH_REWRITE_RERUN_DYNAMIC_CHECK");
|
||||
static bool s_rerun_dynamic_check = ngraph::getenv_bool("NGRAPH_GRAPH_REWRITE_RERUN_DYNAMIC_CHECK");
|
||||
|
||||
auto run_matchers = [&]() -> bool {
|
||||
bool is_dyn_func = s_rerun_dynamic_check && f->is_dynamic();
|
||||
for (auto node : f->get_ops()) {
|
||||
for (const auto& node : f->get_ops()) {
|
||||
for (auto& m_pass : m_matchers) {
|
||||
if (is_dyn_func && m_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE)) {
|
||||
NGRAPH_DEBUG << "matcher callback requires static shape but the "
|
||||
@ -356,9 +355,9 @@ bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f) {
|
||||
return changed;
|
||||
}
|
||||
|
||||
void ngraph::pass::MatcherPass::register_matcher(const std::shared_ptr<ngraph::pattern::Matcher>& m,
|
||||
const ngraph::graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property) {
|
||||
void ov::pass::MatcherPass::register_matcher(const std::shared_ptr<ov::pass::pattern::Matcher>& m,
|
||||
const ov::graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property) {
|
||||
set_name(m->get_name());
|
||||
set_property(property, true);
|
||||
m_matcher = m;
|
||||
@ -376,7 +375,7 @@ void ngraph::pass::MatcherPass::register_matcher(const std::shared_ptr<ngraph::p
|
||||
};
|
||||
}
|
||||
|
||||
bool ngraph::pass::MatcherPass::apply(std::shared_ptr<ngraph::Node> node) {
|
||||
bool ov::pass::MatcherPass::apply(std::shared_ptr<ov::Node> node) {
|
||||
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, pass::internal::perf_counters_graph_rewrite()[get_type_info()]);
|
||||
m_new_nodes.clear();
|
||||
if (m_handler)
|
||||
|
@ -12,13 +12,12 @@
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/variant.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::LowLatency2, "LowLatency2", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ov::pass::LowLatency2, "LowLatency2", 0);
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::LowLatency, "LowLatency", 0);
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
namespace {
|
||||
string generate_variable_name(const string& op_name, const string& param_name, int variable_idx) {
|
||||
@ -27,8 +26,8 @@ string generate_variable_name(const string& op_name, const string& param_name, i
|
||||
|
||||
} // namespace
|
||||
ngraph::pass::LowLatency::LowLatency() {
|
||||
auto tensor_iterator = ngraph::pattern::wrap_type<opset6::TensorIterator, opset6::Loop>();
|
||||
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
|
||||
auto tensor_iterator = ov::pass::pattern::wrap_type<opset6::TensorIterator, opset6::Loop>();
|
||||
ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
|
||||
const auto& sub_graph_op = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(m.get_match_root());
|
||||
if (!sub_graph_op) {
|
||||
return false;
|
||||
@ -38,7 +37,7 @@ ngraph::pass::LowLatency::LowLatency() {
|
||||
const auto& trip_count = std::dynamic_pointer_cast<opset6::Constant>(loop->get_input_node_shared_ptr(0));
|
||||
const auto& num_iter = loop->get_num_iterations();
|
||||
if (trip_count && num_iter > 0 && trip_count->get_output_target_inputs(0).size() == 1) {
|
||||
auto single_iter = std::make_shared<opset6::Constant>(ngraph::element::i64, Shape{}, 1);
|
||||
auto single_iter = std::make_shared<opset6::Constant>(ov::element::i64, Shape{}, 1);
|
||||
replace_node(trip_count, single_iter);
|
||||
} else {
|
||||
// count of iterations is dynamic;
|
||||
@ -47,7 +46,7 @@ ngraph::pass::LowLatency::LowLatency() {
|
||||
}
|
||||
// Mark the TI layer to be unrolled. Enable unconditional ti unrolling for all plugins.
|
||||
auto& rt_info = sub_graph_op->get_rt_info();
|
||||
rt_info["UNROLL_TI"] = std::make_shared<ngraph::VariantWrapper<int64_t>>(1);
|
||||
rt_info["UNROLL_TI"] = std::make_shared<ov::VariantWrapper<int64_t>>(1);
|
||||
|
||||
int64_t variable_id = 0;
|
||||
std::vector<std::shared_ptr<ngraph::op::Sink>> assigns;
|
||||
@ -87,13 +86,14 @@ ngraph::pass::LowLatency::LowLatency() {
|
||||
return false;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(tensor_iterator, "LowLatency");
|
||||
auto m = std::make_shared<ov::pass::pattern::Matcher>(tensor_iterator, "LowLatency");
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
|
||||
void UnrollSingleIteration(const shared_ptr<op::util::SubGraphOp>& sub_graph_op, const shared_ptr<Function>& outer_f) {
|
||||
using namespace opset7;
|
||||
void UnrollSingleIteration(const shared_ptr<ngraph::op::util::SubGraphOp>& sub_graph_op,
|
||||
const shared_ptr<ov::Function>& outer_f) {
|
||||
using namespace ngraph::opset7;
|
||||
|
||||
const auto& params = sub_graph_op->get_function()->get_parameters();
|
||||
const auto& results = sub_graph_op->get_function()->get_results();
|
||||
@ -109,7 +109,7 @@ void UnrollSingleIteration(const shared_ptr<op::util::SubGraphOp>& sub_graph_op,
|
||||
|
||||
// before: TI [...-> Layer1 -> Result -> output] -> Layer2 -> ...
|
||||
// after: ...-> Layer1 -> Layer2 -> ...
|
||||
NodeVector new_ops;
|
||||
ov::NodeVector new_ops;
|
||||
for (const auto& out : sub_graph_op->get_output_descriptions()) {
|
||||
const auto& connect_to = results.at(out->m_body_value_index)->get_input_source_output(0);
|
||||
for (auto& input_to : sub_graph_op->output(out->m_output_index).get_target_inputs()) {
|
||||
@ -120,7 +120,7 @@ void UnrollSingleIteration(const shared_ptr<op::util::SubGraphOp>& sub_graph_op,
|
||||
|
||||
// IECompatibility: insert identity (Unsqueeze + Squeeze) to store the TensorIterator
|
||||
// output names
|
||||
auto axis_1 = Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto axis_1 = Constant::create(ov::element::i64, ngraph::Shape{1}, {1});
|
||||
auto identity_1 = std::make_shared<Unsqueeze>(connect_to, axis_1);
|
||||
auto identity_2 = std::make_shared<Squeeze>(identity_1, axis_1);
|
||||
identity_2->set_friendly_name(out_name);
|
||||
@ -135,36 +135,38 @@ void UnrollSingleIteration(const shared_ptr<op::util::SubGraphOp>& sub_graph_op,
|
||||
ngraph::copy_runtime_info(sub_graph_op, new_ops);
|
||||
}
|
||||
|
||||
Output<Node> create_init_subgraph(const shared_ptr<op::util::SubGraphOp>& sub_graph_op, const Output<Node>& in_node) {
|
||||
using namespace opset7;
|
||||
ngraph::Output<ngraph::Node> create_init_subgraph(const shared_ptr<ngraph::op::util::SubGraphOp>& sub_graph_op,
|
||||
const ngraph::Output<ngraph::Node>& in_node) {
|
||||
using namespace ngraph::opset7;
|
||||
|
||||
auto const_zero = make_shared<Constant>(in_node.get_element_type(), Shape{1}, 0);
|
||||
auto const_zero = make_shared<Constant>(in_node.get_element_type(), ngraph::Shape{1}, 0);
|
||||
auto shape_of = make_shared<ShapeOf>(in_node);
|
||||
auto broadcast = make_shared<Broadcast>(const_zero, shape_of);
|
||||
copy_runtime_info(sub_graph_op, {const_zero, shape_of, broadcast});
|
||||
return broadcast->output(0);
|
||||
}
|
||||
|
||||
bool pass::LowLatency2::run_on_function(shared_ptr<Function> f) {
|
||||
using namespace opset7;
|
||||
bool ov::pass::LowLatency2::run_on_function(shared_ptr<Function> f) {
|
||||
using namespace ngraph::opset7;
|
||||
|
||||
SinkVector assigns;
|
||||
ngraph::SinkVector assigns;
|
||||
for (const auto& op : f->get_ordered_ops()) {
|
||||
if (const auto& sub_graph_op = dynamic_pointer_cast<op::util::SubGraphOp>(op)) {
|
||||
if (const auto& sub_graph_op = dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(op)) {
|
||||
int64_t variable_id = 0;
|
||||
const auto& func = sub_graph_op->get_function();
|
||||
const auto& params = func->get_parameters();
|
||||
for (const auto& in : sub_graph_op->get_input_descriptions()) {
|
||||
// Process all back edges
|
||||
if (const auto& merged_in = dynamic_pointer_cast<op::util::SubGraphOp::MergedInputDescription>(in)) {
|
||||
if (const auto& merged_in =
|
||||
dynamic_pointer_cast<ngraph::op::util::SubGraphOp::MergedInputDescription>(in)) {
|
||||
// create new Variable
|
||||
const string& param_name = params.at(merged_in->m_body_parameter_index)->get_friendly_name();
|
||||
const string& var_name =
|
||||
generate_variable_name(sub_graph_op->get_friendly_name(), param_name, variable_id);
|
||||
|
||||
const auto& input = sub_graph_op->input(merged_in->m_input_index);
|
||||
if (std::dynamic_pointer_cast<op::ReadValueBase>(input.get_source_output().get_node_shared_ptr()) !=
|
||||
nullptr) {
|
||||
if (std::dynamic_pointer_cast<ngraph::op::ReadValueBase>(
|
||||
input.get_source_output().get_node_shared_ptr()) != nullptr) {
|
||||
NGRAPH_DEBUG << "LowLatency2 transformation cannot be applied because the "
|
||||
<< "ReadValue node is already an input to the TensorIterator."
|
||||
<< "LowLatency2 transformation may have already been applied, please "
|
||||
@ -175,7 +177,7 @@ bool pass::LowLatency2::run_on_function(shared_ptr<Function> f) {
|
||||
const auto& param =
|
||||
sub_graph_op->get_function()->get_parameters().at(merged_in->m_body_parameter_index);
|
||||
for (const auto& in_to : param->output(0).get_target_inputs()) {
|
||||
if (dynamic_cast<op::ReadValueBase*>(in_to.get_node()) != nullptr) {
|
||||
if (dynamic_cast<ngraph::op::ReadValueBase*>(in_to.get_node()) != nullptr) {
|
||||
NGRAPH_DEBUG << "LowLatency2 transformation cannot be applied because the "
|
||||
<< "ReadValue node is already inside the TensorIterator. "
|
||||
<< "LowLatency transformation may have been applied, please do "
|
||||
@ -184,8 +186,8 @@ bool pass::LowLatency2::run_on_function(shared_ptr<Function> f) {
|
||||
}
|
||||
}
|
||||
|
||||
VariableInfo var_info{PartialShape::dynamic(), element::dynamic, var_name};
|
||||
auto variable = make_shared<Variable>(var_info);
|
||||
ngraph::VariableInfo var_info{PartialShape::dynamic(), element::dynamic, var_name};
|
||||
auto variable = make_shared<ngraph::Variable>(var_info);
|
||||
|
||||
// insert ReadValue
|
||||
// Layers -> [new op: ReadValue] -> Subgraph operation
|
||||
@ -204,12 +206,12 @@ bool pass::LowLatency2::run_on_function(shared_ptr<Function> f) {
|
||||
// ---> Layers -> ...
|
||||
*/
|
||||
const auto& out_desc = sub_graph_op->get_output_descriptions();
|
||||
bool is_output_exist =
|
||||
std::any_of(out_desc.begin(),
|
||||
out_desc.end(),
|
||||
[&merged_in](const std::shared_ptr<op::util::SubGraphOp::OutputDescription>& out) {
|
||||
return out->m_body_value_index == merged_in->m_body_value_index;
|
||||
});
|
||||
bool is_output_exist = std::any_of(
|
||||
out_desc.begin(),
|
||||
out_desc.end(),
|
||||
[&merged_in](const std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>& out) {
|
||||
return out->m_body_value_index == merged_in->m_body_value_index;
|
||||
});
|
||||
// Create new output if it doesn't exist.
|
||||
if (!is_output_exist) {
|
||||
sub_graph_op->get_iter_value(func->get_results().at(merged_in->m_body_value_index));
|
||||
@ -217,7 +219,7 @@ bool pass::LowLatency2::run_on_function(shared_ptr<Function> f) {
|
||||
for (const auto& out : sub_graph_op->get_output_descriptions()) {
|
||||
if (out->m_body_value_index == merged_in->m_body_value_index) {
|
||||
auto assign = make_shared<Assign>(sub_graph_op->output(out->m_output_index), variable);
|
||||
ngraph::copy_runtime_info(sub_graph_op, assign);
|
||||
copy_runtime_info(sub_graph_op, assign);
|
||||
// control dependency so that ReadValue is processed before Assign
|
||||
assign->add_control_dependency(read_value);
|
||||
assigns.emplace_back(assign);
|
||||
|
@ -24,9 +24,8 @@
|
||||
#include "perf_counters.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
namespace ngraph {
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace internal {
|
||||
PerfCounters& perf_counters() {
|
||||
@ -35,25 +34,25 @@ PerfCounters& perf_counters() {
|
||||
}
|
||||
} // namespace internal
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
} // namespace ov
|
||||
|
||||
pass::Manager::Manager()
|
||||
ov::pass::Manager::Manager()
|
||||
: m_pass_config(std::make_shared<PassConfig>()),
|
||||
m_visualize(getenv_bool("NGRAPH_ENABLE_VISUALIZE_TRACING")) {}
|
||||
m_visualize(ngraph::getenv_bool("NGRAPH_ENABLE_VISUALIZE_TRACING")) {}
|
||||
|
||||
pass::Manager::~Manager() {}
|
||||
ov::pass::Manager::~Manager() = default;
|
||||
|
||||
pass::Manager::Manager(std::shared_ptr<ngraph::pass::PassConfig> pass_config) : m_pass_config(std::move(pass_config)) {}
|
||||
ov::pass::Manager::Manager(std::shared_ptr<ov::pass::PassConfig> pass_config) : m_pass_config(std::move(pass_config)) {}
|
||||
|
||||
void pass::Manager::run_passes(shared_ptr<Function> func) {
|
||||
void ov::pass::Manager::run_passes(shared_ptr<ov::Function> func) {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, "pass::Manager::run_passes");
|
||||
|
||||
static bool profile_enabled = getenv_bool("NGRAPH_PROFILE_PASS_ENABLE");
|
||||
static bool profile_enabled = ngraph::getenv_bool("NGRAPH_PROFILE_PASS_ENABLE");
|
||||
|
||||
size_t index = 0;
|
||||
stopwatch pass_timer;
|
||||
stopwatch overall_timer;
|
||||
ngraph::stopwatch pass_timer;
|
||||
ngraph::stopwatch overall_timer;
|
||||
overall_timer.start();
|
||||
bool function_changed = false;
|
||||
for (auto& pass : m_pass_list) {
|
||||
@ -96,13 +95,13 @@ void pass::Manager::run_passes(shared_ptr<Function> func) {
|
||||
} else {
|
||||
function_changed = function_pass->run_on_function(func);
|
||||
}
|
||||
} else if (auto node_pass = dynamic_pointer_cast<NodePass>(pass)) {
|
||||
} else if (auto node_pass = dynamic_pointer_cast<ngraph::pass::NodePass>(pass)) {
|
||||
if (node_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && func->is_dynamic()) {
|
||||
NGRAPH_DEBUG << "Pass " << pass->get_name() << " requires static shape but the "
|
||||
<< "function is dynamic. Skipping this transformation";
|
||||
continue;
|
||||
}
|
||||
for (shared_ptr<Node> n : func->get_ops()) {
|
||||
for (const shared_ptr<Node>& n : func->get_ops()) {
|
||||
function_changed |= node_pass->run_on_node(n);
|
||||
}
|
||||
}
|
||||
@ -115,7 +114,7 @@ void pass::Manager::run_passes(shared_ptr<Function> func) {
|
||||
auto base_filename = func->get_name() + std::string("_") + index_str + std::string("_") + pass->get_name();
|
||||
|
||||
if (m_visualize) {
|
||||
static const string format = getenv_string("NGRAPH_VISUALIZE_TRACING_FORMAT");
|
||||
static const string format = ngraph::getenv_string("NGRAPH_VISUALIZE_TRACING_FORMAT");
|
||||
auto file_ext = format.empty() ? "svg" : format;
|
||||
pass::VisualizeTree vt(base_filename + std::string(".") + file_ext);
|
||||
vt.run_on_function(func);
|
||||
|
@ -7,21 +7,20 @@
|
||||
# include <cxxabi.h>
|
||||
#endif
|
||||
|
||||
#include "ngraph/pass/manager.hpp"
|
||||
#include "ngraph/pass/pass.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::FunctionPass, "ngraph::pass::FunctionPass", 0);
|
||||
OPENVINO_RTTI_DEFINITION(ov::pass::FunctionPass, "ov::pass::FunctionPass", 0);
|
||||
|
||||
pass::PassBase::PassBase() : m_property{all_pass_property_off}, m_pass_config(std::make_shared<PassConfig>()) {}
|
||||
ov::pass::PassBase::PassBase() : m_property(), m_pass_config(std::make_shared<PassConfig>()) {}
|
||||
|
||||
bool pass::PassBase::get_property(const PassPropertyMask& prop) const {
|
||||
bool ov::pass::PassBase::get_property(const PassPropertyMask& prop) const {
|
||||
return m_property.is_set(prop);
|
||||
}
|
||||
|
||||
void pass::PassBase::set_property(const PassPropertyMask& prop, bool value) {
|
||||
void ov::pass::PassBase::set_property(const PassPropertyMask& prop, bool value) {
|
||||
if (value) {
|
||||
m_property.set(prop);
|
||||
} else {
|
||||
@ -29,7 +28,7 @@ void pass::PassBase::set_property(const PassPropertyMask& prop, bool value) {
|
||||
}
|
||||
}
|
||||
|
||||
std::string pass::PassBase::get_name() const {
|
||||
std::string ov::pass::PassBase::get_name() const {
|
||||
if (m_name.empty()) {
|
||||
const PassBase* p = this;
|
||||
std::string pass_name = typeid(*p).name();
|
||||
@ -43,16 +42,16 @@ std::string pass::PassBase::get_name() const {
|
||||
}
|
||||
}
|
||||
|
||||
void pass::PassBase::set_callback(const param_callback& callback) {
|
||||
void ov::pass::PassBase::set_callback(const param_callback& callback) {
|
||||
m_pass_config->set_callback(callback);
|
||||
}
|
||||
|
||||
// The symbols are requiered to be in cpp file to workaround RTTI issue on Android LLVM
|
||||
|
||||
pass::FunctionPass::~FunctionPass() {}
|
||||
ov::pass::FunctionPass::~FunctionPass() = default;
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::NodePass, "ngraph::pass::NodePass", 0);
|
||||
OPENVINO_RTTI_DEFINITION(ngraph::pass::NodePass, "ngraph::pass::NodePass", 0);
|
||||
|
||||
pass::NodePass::~NodePass() {}
|
||||
ngraph::pass::NodePass::~NodePass() = default;
|
||||
|
@ -2,11 +2,9 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph/pass/pass_config.hpp"
|
||||
#include "openvino/pass/pass_config.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
pass::param_callback pass::PassConfig::get_callback(const DiscreteTypeInfo& type_info) const {
|
||||
ov::pass::param_callback ov::pass::PassConfig::get_callback(const DiscreteTypeInfo& type_info) const {
|
||||
const auto& it = m_callback_map.find(type_info);
|
||||
if (it != m_callback_map.end()) {
|
||||
return it->second;
|
||||
@ -15,17 +13,17 @@ pass::param_callback pass::PassConfig::get_callback(const DiscreteTypeInfo& type
|
||||
}
|
||||
}
|
||||
|
||||
void pass::PassConfig::enable(const ngraph::DiscreteTypeInfo& type_info) {
|
||||
void ov::pass::PassConfig::enable(const ngraph::DiscreteTypeInfo& type_info) {
|
||||
m_disabled.erase(type_info);
|
||||
m_enabled.insert(type_info);
|
||||
}
|
||||
|
||||
void pass::PassConfig::disable(const ngraph::DiscreteTypeInfo& type_info) {
|
||||
void ov::pass::PassConfig::disable(const ngraph::DiscreteTypeInfo& type_info) {
|
||||
m_enabled.erase(type_info);
|
||||
m_disabled.insert(type_info);
|
||||
}
|
||||
|
||||
void pass::PassConfig::add_disabled_passes(const PassConfig& rhs) {
|
||||
void ov::pass::PassConfig::add_disabled_passes(const PassConfig& rhs) {
|
||||
for (const auto& pass : rhs.m_disabled) {
|
||||
if (is_enabled(pass))
|
||||
continue;
|
||||
|
@ -3,7 +3,7 @@
|
||||
//
|
||||
#include "perf_counters.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
openvino::itt::handle_t PerfCounters::operator[](::ngraph::Node::type_info_t const& type_inf) {
|
||||
std::lock_guard<std::mutex> guard(m_mutex);
|
||||
@ -13,4 +13,4 @@ openvino::itt::handle_t PerfCounters::operator[](::ngraph::Node::type_info_t con
|
||||
return m_counters[&type_inf] = openvino::itt::handle(type_inf.name);
|
||||
}
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
} // namespace ov
|
||||
|
@ -7,7 +7,7 @@
|
||||
#include <ngraph/node.hpp>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace ngraph {
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
class PerfCounters {
|
||||
PerfCounters(PerfCounters const&) = delete;
|
||||
@ -27,4 +27,4 @@ private:
|
||||
counters_map m_counters;
|
||||
};
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
} // namespace ov
|
||||
|
@ -2,16 +2,16 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph/pass/validate.hpp"
|
||||
#include "openvino/pass/validate.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/graph_util.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::Validate, "ngraph::pass::Validate", 0);
|
||||
OPENVINO_RTTI_DEFINITION(ov::pass::Validate, "ov::pass::Validate", 0);
|
||||
|
||||
bool pass::Validate::run_on_function(std::shared_ptr<Function> f) {
|
||||
bool ov::pass::Validate::run_on_function(std::shared_ptr<Function> f) {
|
||||
f->validate_nodes_and_infer_types();
|
||||
return false;
|
||||
}
|
||||
|
@ -13,7 +13,8 @@
|
||||
#include "ngraph/op/parameter.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace pattern {
|
||||
MatcherState::MatcherState(Matcher* matcher)
|
||||
: m_matcher(matcher),
|
||||
@ -88,7 +89,7 @@ bool Matcher::is_contained_match(const NodeVector& exclusions, bool ignore_unuse
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
if (exclusions.empty()) {
|
||||
NodeVector label_exclusions;
|
||||
for (auto entry : m_pattern_map) {
|
||||
for (const auto& entry : m_pattern_map) {
|
||||
// leaf label
|
||||
if (entry.first->get_input_size() == 0) {
|
||||
label_exclusions.push_back(entry.second.get_node_shared_ptr());
|
||||
@ -108,7 +109,7 @@ bool Matcher::match_value(const ngraph::Output<Node>& pattern_value, const ngrap
|
||||
// This env var allows one to specify node name patterns to abort pattern matching
|
||||
// at particular nodes. The upshot is that one can quickly zero in on an offending
|
||||
// fusion by disabling individual fusions or optimizations that use Matcher.
|
||||
static const std::string node_skip_cregex = getenv_string("NGRAPH_FAIL_MATCH_AT");
|
||||
static const std::string node_skip_cregex = ngraph::getenv_string("NGRAPH_FAIL_MATCH_AT");
|
||||
if (!node_skip_cregex.empty()) {
|
||||
static const std::regex node_skip_regex(node_skip_cregex);
|
||||
if (std::regex_match(graph_node->get_name(), node_skip_regex)) {
|
||||
@ -201,7 +202,7 @@ void Matcher::clear_state() {
|
||||
namespace {
|
||||
std::set<std::shared_ptr<Node>> as_node_set(const std::set<std::shared_ptr<op::Label>>& label_set) {
|
||||
std::set<std::shared_ptr<Node>> result;
|
||||
for (auto label : label_set) {
|
||||
for (const auto& label : label_set) {
|
||||
result.insert(label);
|
||||
}
|
||||
return result;
|
||||
@ -230,7 +231,7 @@ bool RecurrentMatcher::match(Output<Node> graph) {
|
||||
graph = m.get_pattern_value_map()[m_recurrent_pattern];
|
||||
|
||||
// copy bound nodes for the current pattern graph into a global matches map
|
||||
for (auto cur_match : m.get_pattern_value_map()) {
|
||||
for (const auto& cur_match : m.get_pattern_value_map()) {
|
||||
m_matches[cur_match.first].push_back(cur_match.second);
|
||||
}
|
||||
|
||||
@ -238,7 +239,7 @@ bool RecurrentMatcher::match(Output<Node> graph) {
|
||||
// from the current match. Only bound nodes whose labels are in
|
||||
// correlated_patterns are pre-populated. Skip other labels are
|
||||
// unbounded by default
|
||||
for (auto cor_pat : m_correlated_patterns) {
|
||||
for (const auto& cor_pat : m_correlated_patterns) {
|
||||
previous_matches[cor_pat] = m.get_pattern_value_map()[cor_pat];
|
||||
}
|
||||
m = m_repeat;
|
||||
@ -251,4 +252,5 @@ bool RecurrentMatcher::match(Output<Node> graph) {
|
||||
return matched;
|
||||
}
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
@ -7,17 +7,16 @@
|
||||
#include "ngraph/pattern/matcher.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo pattern::op::Any::type_info;
|
||||
constexpr ov::NodeTypeInfo ov::pass::pattern::op::Any::type_info;
|
||||
|
||||
const NodeTypeInfo& pattern::op::Any::get_type_info() const {
|
||||
const ov::NodeTypeInfo& ov::pass::pattern::op::Any::get_type_info() const {
|
||||
return type_info;
|
||||
}
|
||||
|
||||
bool pattern::op::Any::match_value(Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) {
|
||||
bool ov::pass::pattern::op::Any::match_value(Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) {
|
||||
matcher->add_node(graph_value);
|
||||
return m_predicate(graph_value) &&
|
||||
matcher->match_arguments(pattern_value.get_node(), graph_value.get_node_shared_ptr());
|
||||
|
@ -7,20 +7,19 @@
|
||||
#include "ngraph/pattern/matcher.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo pattern::op::AnyOf::type_info;
|
||||
constexpr ov::NodeTypeInfo ov::pass::pattern::op::AnyOf::type_info;
|
||||
|
||||
const NodeTypeInfo& pattern::op::AnyOf::get_type_info() const {
|
||||
const ov::NodeTypeInfo& ov::pass::pattern::op::AnyOf::get_type_info() const {
|
||||
return type_info;
|
||||
}
|
||||
|
||||
bool pattern::op::AnyOf::match_value(Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) {
|
||||
bool ov::pass::pattern::op::AnyOf::match_value(Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) {
|
||||
matcher->add_node(graph_value);
|
||||
return m_predicate(graph_value) && ([&]() {
|
||||
for (auto arg : graph_value.get_node_shared_ptr()->input_values()) {
|
||||
for (const auto& arg : graph_value.get_node_shared_ptr()->input_values()) {
|
||||
auto saved = matcher->start_match();
|
||||
if (matcher->match_value(input_value(0), arg)) {
|
||||
return saved.finish(true);
|
||||
|
@ -7,16 +7,15 @@
|
||||
#include "ngraph/pattern/matcher.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo pattern::op::AnyOutput::type_info;
|
||||
constexpr ov::NodeTypeInfo ov::pass::pattern::op::AnyOutput::type_info;
|
||||
|
||||
const NodeTypeInfo& pattern::op::AnyOutput::get_type_info() const {
|
||||
const ov::NodeTypeInfo& ov::pass::pattern::op::AnyOutput::get_type_info() const {
|
||||
return type_info;
|
||||
}
|
||||
|
||||
bool pattern::op::AnyOutput::match_value(Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) {
|
||||
bool ov::pass::pattern::op::AnyOutput::match_value(Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) {
|
||||
return input_value(0).get_node()->match_node(matcher, graph_value);
|
||||
}
|
||||
|
@ -9,15 +9,14 @@
|
||||
#include "ngraph/pattern/op/true.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo pattern::op::Label::type_info;
|
||||
constexpr ov::NodeTypeInfo ov::pass::pattern::op::Label::type_info;
|
||||
|
||||
const NodeTypeInfo& pattern::op::Label::get_type_info() const {
|
||||
const ov::NodeTypeInfo& ov::pass::pattern::op::Label::get_type_info() const {
|
||||
return type_info;
|
||||
}
|
||||
|
||||
Output<Node> pattern::op::Label::wrap_values(const OutputVector& wrapped_values) {
|
||||
ov::Output<ov::Node> ov::pass::pattern::op::Label::wrap_values(const ov::OutputVector& wrapped_values) {
|
||||
switch (wrapped_values.size()) {
|
||||
case 0:
|
||||
return make_shared<pattern::op::True>()->output(0);
|
||||
@ -28,9 +27,9 @@ Output<Node> pattern::op::Label::wrap_values(const OutputVector& wrapped_values)
|
||||
}
|
||||
}
|
||||
|
||||
bool pattern::op::Label::match_value(Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) {
|
||||
bool ov::pass::pattern::op::Label::match_value(ov::pass::pattern::Matcher* matcher,
|
||||
const ov::Output<ov::Node>& pattern_value,
|
||||
const ov::Output<ov::Node>& graph_value) {
|
||||
if (m_predicate(graph_value)) {
|
||||
auto& pattern_map = matcher->get_pattern_value_map();
|
||||
auto saved = matcher->start_match();
|
||||
@ -45,10 +44,10 @@ bool pattern::op::Label::match_value(Matcher* matcher,
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> pattern::any_input() {
|
||||
std::shared_ptr<ov::Node> ov::pass::pattern::any_input() {
|
||||
return std::make_shared<pattern::op::Label>();
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> pattern::any_input(const pattern::op::ValuePredicate& pred) {
|
||||
std::shared_ptr<ov::Node> ov::pass::pattern::any_input(const ov::pass::pattern::op::ValuePredicate& pred) {
|
||||
return std::make_shared<pattern::op::Label>(element::dynamic, PartialShape::dynamic(), pred);
|
||||
}
|
||||
|
@ -7,7 +7,8 @@
|
||||
#include <algorithm>
|
||||
#include <regex>
|
||||
|
||||
namespace ngraph {
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
// The symbols are required to be in cpp file to workaround RTTI issue on Android LLVM
|
||||
@ -101,4 +102,5 @@ std::function<bool(Output<Node>)> type_matches_any(const std::vector<element::Ty
|
||||
};
|
||||
}
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
@ -2,9 +2,10 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "dyn_elimination.hpp"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "dyn_elimination.hpp"
|
||||
#include "ngraph/builder/reshape.hpp"
|
||||
#include "ngraph/op/broadcast.hpp"
|
||||
#include "ngraph/op/range.hpp"
|
||||
@ -19,9 +20,7 @@ NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
pass::DynElimination::DynElimination()
|
||||
: GraphRewrite()
|
||||
{
|
||||
pass::DynElimination::DynElimination() : GraphRewrite() {
|
||||
construct_range();
|
||||
}
|
||||
|
||||
@ -29,28 +28,22 @@ template <typename T>
|
||||
std::shared_ptr<op::Constant> make_range_replacement(const element::Type& et,
|
||||
const Shape& shape,
|
||||
const std::shared_ptr<op::Constant>& start_arg,
|
||||
const std::shared_ptr<op::Constant>& step_arg)
|
||||
{
|
||||
const std::shared_ptr<op::Constant>& step_arg) {
|
||||
std::vector<T> elements(shape_size(shape));
|
||||
std::vector<T> start_vec = start_arg->get_vector<T>();
|
||||
std::vector<T> step_vec = step_arg->get_vector<T>();
|
||||
|
||||
NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1);
|
||||
|
||||
runtime::reference::range<T>(
|
||||
start_vec.data(), step_vec.data(), shape_size(shape), elements.data());
|
||||
runtime::reference::range<T>(start_vec.data(), step_vec.data(), shape_size(shape), elements.data());
|
||||
|
||||
return make_shared<op::Constant>(et, shape, elements);
|
||||
}
|
||||
|
||||
void pass::DynElimination::construct_range()
|
||||
{
|
||||
auto start_arg_label =
|
||||
make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
|
||||
auto stop_arg_label =
|
||||
make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
|
||||
auto step_arg_label =
|
||||
make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
|
||||
void pass::DynElimination::construct_range() {
|
||||
auto start_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
|
||||
auto stop_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
|
||||
auto step_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
|
||||
|
||||
auto range_pat = make_shared<op::Range>(start_arg_label, stop_arg_label, step_arg_label);
|
||||
|
||||
@ -70,12 +63,11 @@ void pass::DynElimination::construct_range()
|
||||
std::shared_ptr<op::Constant> replacement;
|
||||
|
||||
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic error "-Wswitch"
|
||||
#pragma GCC diagnostic error "-Wswitch-enum"
|
||||
# pragma GCC diagnostic push
|
||||
# pragma GCC diagnostic error "-Wswitch"
|
||||
# pragma GCC diagnostic error "-Wswitch-enum"
|
||||
#endif
|
||||
switch (et)
|
||||
{
|
||||
switch (et) {
|
||||
case element::Type_t::bf16:
|
||||
replacement = make_range_replacement<bfloat16>(et, shape, start_arg, step_arg);
|
||||
break;
|
||||
@ -122,7 +114,7 @@ void pass::DynElimination::construct_range()
|
||||
break;
|
||||
}
|
||||
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
||||
#pragma GCC diagnostic pop
|
||||
# pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
replace_node(range_node, replacement);
|
||||
|
Loading…
Reference in New Issue
Block a user