Move pass pattern to ov (#7255)
* Moved ngraph::Node to ov namespace * Fixed code style * Fixed VPU * Fixed GNA * Fixed tests * Added aliases for backward compatibility * Fix clDNN * Try to fix build * Fixed comment * Renamed RTTI macros * Add new headers * Fixed ngraph build * Fixed unit tests * Try to fix Serialize
This commit is contained in:
parent
07f7061f96
commit
9eca6ba9d5
@ -26,7 +26,7 @@ endif()
|
|||||||
# resolving dependencies for the project
|
# resolving dependencies for the project
|
||||||
message (STATUS "PROJECT ............................... " ${PROJECT_NAME})
|
message (STATUS "PROJECT ............................... " ${PROJECT_NAME})
|
||||||
message (STATUS "CMAKE_BINARY_DIR ...................... " ${CMAKE_BINARY_DIR})
|
message (STATUS "CMAKE_BINARY_DIR ...................... " ${CMAKE_BINARY_DIR})
|
||||||
message (STATUS "OpenVINO_SOURCE_DIR .... .......... " ${OpenVINO_SOURCE_DIR})
|
message (STATUS "OpenVINO_SOURCE_DIR ................... " ${OpenVINO_SOURCE_DIR})
|
||||||
message (STATUS "CMAKE_GENERATOR ....................... " ${CMAKE_GENERATOR})
|
message (STATUS "CMAKE_GENERATOR ....................... " ${CMAKE_GENERATOR})
|
||||||
message (STATUS "CMAKE_C_COMPILER_ID ................... " ${CMAKE_C_COMPILER_ID})
|
message (STATUS "CMAKE_C_COMPILER_ID ................... " ${CMAKE_C_COMPILER_ID})
|
||||||
message (STATUS "CMAKE_BUILD_TYPE ...................... " ${CMAKE_BUILD_TYPE})
|
message (STATUS "CMAKE_BUILD_TYPE ...................... " ${CMAKE_BUILD_TYPE})
|
||||||
|
@ -811,8 +811,34 @@ void ngfunction_2_irv10(pugi::xml_node& netXml,
|
|||||||
f.validate_nodes_and_infer_types();
|
f.validate_nodes_and_infer_types();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string valid_xml_path(const std::string &path) {
|
||||||
|
NGRAPH_CHECK(path.length() > 4, "Path for xml file is to short: \"" + path + "\"");
|
||||||
|
|
||||||
|
const char *const extension = ".xml";
|
||||||
|
const bool has_xml_extension = path.rfind(extension) == path.size() - std::strlen(extension);
|
||||||
|
NGRAPH_CHECK(has_xml_extension,
|
||||||
|
"Path for xml file doesn't contains file name with 'xml' extension: \"" +
|
||||||
|
path + "\"");
|
||||||
|
return path;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string provide_bin_path(const std::string &xmlPath, const std::string &binPath) {
|
||||||
|
if (!binPath.empty()) {
|
||||||
|
return binPath;
|
||||||
|
}
|
||||||
|
assert(xmlPath.size() > 4); // should be check by valid_xml_path
|
||||||
|
std::string bestPath = xmlPath;
|
||||||
|
const char *const extension = "bin";
|
||||||
|
const auto ext_size = std::strlen(extension);
|
||||||
|
bestPath.replace(bestPath.size() - ext_size, ext_size, extension);
|
||||||
|
return bestPath;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace ngraph {
|
||||||
|
|
||||||
// ! [function_pass:serialize_cpp]
|
// ! [function_pass:serialize_cpp]
|
||||||
// serialize.cpp
|
// serialize.cpp
|
||||||
bool pass::Serialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
bool pass::Serialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||||
@ -868,33 +894,6 @@ bool pass::Serialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
std::string valid_xml_path(const std::string &path) {
|
|
||||||
NGRAPH_CHECK(path.length() > 4, "Path for xml file is to short: \"" + path + "\"");
|
|
||||||
|
|
||||||
const char *const extension = ".xml";
|
|
||||||
const bool has_xml_extension = path.rfind(extension) == path.size() - std::strlen(extension);
|
|
||||||
NGRAPH_CHECK(has_xml_extension,
|
|
||||||
"Path for xml file doesn't contains file name with 'xml' extension: \"" +
|
|
||||||
path + "\"");
|
|
||||||
return path;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string provide_bin_path(const std::string &xmlPath, const std::string &binPath) {
|
|
||||||
if (!binPath.empty()) {
|
|
||||||
return binPath;
|
|
||||||
}
|
|
||||||
assert(xmlPath.size() > 4); // should be check by valid_xml_path
|
|
||||||
std::string bestPath = xmlPath;
|
|
||||||
const char *const extension = "bin";
|
|
||||||
const auto ext_size = std::strlen(extension);
|
|
||||||
bestPath.replace(bestPath.size() - ext_size, ext_size, extension);
|
|
||||||
return bestPath;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
pass::Serialize::Serialize(std::ostream& xmlFile,
|
pass::Serialize::Serialize(std::ostream& xmlFile,
|
||||||
std::ostream& binFile,
|
std::ostream& binFile,
|
||||||
pass::Serialize::Version version,
|
pass::Serialize::Version version,
|
||||||
@ -921,3 +920,4 @@ pass::Serialize::Serialize(const std::string& xmlPath,
|
|||||||
{
|
{
|
||||||
}
|
}
|
||||||
// ! [function_pass:serialize_cpp]
|
// ! [function_pass:serialize_cpp]
|
||||||
|
} // namespace ngraph
|
||||||
|
@ -5,24 +5,10 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "ngraph/pass/pass.hpp"
|
#include "ngraph/pass/pass.hpp"
|
||||||
|
#include "openvino/pass/constant_folding.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
/**
|
using ov::pass::ConstantFolding;
|
||||||
* @brief Constant folding iterates over the function and tries to evaluate nodes
|
|
||||||
* with constant inputs. Such nodes are then replaced with new Constants containing
|
|
||||||
* the result of a folded operation.
|
|
||||||
*/
|
|
||||||
class NGRAPH_API ConstantFolding : public FunctionPass {
|
|
||||||
public:
|
|
||||||
NGRAPH_RTTI_DECLARATION;
|
|
||||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
void copy_runtime_info_to_target_inputs(const std::shared_ptr<Node>& node, const Output<Node>& replacement);
|
|
||||||
/// \brief Folds pre-calculated output tensor values to constants in case lower and
|
|
||||||
/// upper estimations are equal. Traverses graph backwards starting from the results.
|
|
||||||
bool pre_calculated_values_folding(const std::shared_ptr<ngraph::Function>& f);
|
|
||||||
};
|
|
||||||
} // namespace pass
|
} // namespace pass
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -4,14 +4,11 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <ngraph/pass/graph_rewrite.hpp>
|
#include "ngraph/pass/graph_rewrite.hpp"
|
||||||
|
#include "openvino/pass/convert_fp32_to_fp16.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
class NGRAPH_API ConvertFP32ToFP16 : public ngraph::pass::FunctionPass {
|
using ov::pass::ConvertFP32ToFP16;
|
||||||
public:
|
|
||||||
NGRAPH_RTTI_DECLARATION;
|
|
||||||
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
|
|
||||||
};
|
|
||||||
} // namespace pass
|
} // namespace pass
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -10,240 +10,17 @@
|
|||||||
|
|
||||||
#include "ngraph/pass/pass.hpp"
|
#include "ngraph/pass/pass.hpp"
|
||||||
#include "ngraph/pattern/matcher.hpp"
|
#include "ngraph/pattern/matcher.hpp"
|
||||||
|
#include "openvino/pass/graph_rewrite.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
using matcher_pass_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
|
using ov::graph_rewrite_callback;
|
||||||
using graph_rewrite_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
|
using ov::handler_callback;
|
||||||
using recurrent_graph_rewrite_callback = std::function<bool(ngraph::pattern::RecurrentMatcher& m)>;
|
using ov::matcher_pass_callback;
|
||||||
using handler_callback = std::function<bool(const std::shared_ptr<Node>& node)>;
|
using ov::recurrent_graph_rewrite_callback;
|
||||||
namespace pass {
|
namespace pass {
|
||||||
/// \brief MatcherPass is a basic block for pattern based transformations. It describes
|
using ov::pass::BackwardGraphRewrite;
|
||||||
/// pattern and
|
using ov::pass::GraphRewrite;
|
||||||
/// action that is applied if pattern is matched.
|
using ov::pass::MatcherPass;
|
||||||
///
|
using ov::pass::RecurrentGraphRewrite;
|
||||||
/// MatcherPass consists of Matcher and matcher_pass_callback that needs to be implemented
|
|
||||||
/// and
|
|
||||||
/// finally registered by using \sa register_matcher. MatcherPass can be executed on node
|
|
||||||
/// within
|
|
||||||
/// \sa apply method. To run matcher pass on Function use GraphRewrite.
|
|
||||||
/// In addition MatcherPass provides a way for adding new operations into GraphRewrite
|
|
||||||
/// execution
|
|
||||||
/// queue. That means that operations that were created inside transformation callback can
|
|
||||||
/// be added
|
|
||||||
/// for matching. To register node use \sa register_new_node method. GraphRewrite
|
|
||||||
/// automatically
|
|
||||||
/// takes registered nodes and put them to execution queue. If multiple nodes were register
|
|
||||||
/// make
|
|
||||||
/// sure that they were registered in topological order.
|
|
||||||
/// Note: when implementing pattern for Matcher make sure that root node is an operation
|
|
||||||
/// from opset
|
|
||||||
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher
|
|
||||||
/// passes more
|
|
||||||
/// efficient.
|
|
||||||
|
|
||||||
class NGRAPH_API MatcherPass : public ngraph::pass::PassBase {
|
|
||||||
public:
|
|
||||||
NGRAPH_RTTI_DECLARATION;
|
|
||||||
|
|
||||||
MatcherPass() = default;
|
|
||||||
|
|
||||||
MatcherPass(const MatcherPass&) = delete;
|
|
||||||
MatcherPass& operator=(const MatcherPass&) = delete;
|
|
||||||
|
|
||||||
explicit MatcherPass(const std::string& name,
|
|
||||||
const std::shared_ptr<pattern::Matcher>& m,
|
|
||||||
const handler_callback& handler,
|
|
||||||
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE)
|
|
||||||
: PassBase(),
|
|
||||||
m_handler(handler),
|
|
||||||
m_matcher(m) {
|
|
||||||
set_name(name);
|
|
||||||
set_property(property, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool apply(std::shared_ptr<ngraph::Node> node);
|
|
||||||
|
|
||||||
template <typename T, class... Args>
|
|
||||||
std::shared_ptr<T> register_new_node(Args&&... args) {
|
|
||||||
auto node = std::make_shared<T>(std::forward<Args>(args)...);
|
|
||||||
m_new_nodes.push_back(node);
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
std::shared_ptr<T> register_new_node(const std::shared_ptr<T>& node) {
|
|
||||||
m_new_nodes.push_back(node);
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::vector<std::shared_ptr<ngraph::Node>>& get_new_nodes() {
|
|
||||||
return m_new_nodes;
|
|
||||||
}
|
|
||||||
void clear_new_nodes() {
|
|
||||||
m_new_nodes.clear();
|
|
||||||
}
|
|
||||||
std::shared_ptr<pattern::Matcher> get_matcher() {
|
|
||||||
return m_matcher;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
void register_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
|
||||||
const ngraph::graph_rewrite_callback& callback,
|
|
||||||
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE);
|
|
||||||
|
|
||||||
private:
|
|
||||||
handler_callback m_handler;
|
|
||||||
std::shared_ptr<pattern::Matcher> m_matcher;
|
|
||||||
std::vector<std::shared_ptr<ngraph::Node>> m_new_nodes;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function
|
|
||||||
/// in
|
|
||||||
/// efficient way
|
|
||||||
///
|
|
||||||
/// Graph rewrite pass is used for matcher passes execution on Function.
|
|
||||||
/// To register MatcherPass use \sa add_matcher<T>(args) method where T is a MatcherPass
|
|
||||||
/// class.
|
|
||||||
/// As a default algorithm graph rewrite pass traverse Function in topological order and
|
|
||||||
/// applies
|
|
||||||
/// registered matcher passes for each node. But if all registered matcher passes have type
|
|
||||||
/// based
|
|
||||||
/// root node in Matcher pattern then efficient mechanism is used to execute them.
|
|
||||||
/// Matcher pattern root is type based if it's operation from opset or
|
|
||||||
/// pattern::op::WrapType.
|
|
||||||
/// Note: when implementing pattern for Matcher make sure that root node is an operation
|
|
||||||
/// from opset
|
|
||||||
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher
|
|
||||||
/// passes more
|
|
||||||
/// efficient.
|
|
||||||
|
|
||||||
class NGRAPH_API GraphRewrite : public ngraph::pass::FunctionPass {
|
|
||||||
public:
|
|
||||||
NGRAPH_RTTI_DECLARATION;
|
|
||||||
|
|
||||||
GraphRewrite() = default;
|
|
||||||
|
|
||||||
explicit GraphRewrite(const std::shared_ptr<MatcherPass>& pass) : FunctionPass() {
|
|
||||||
m_matchers.push_back(pass);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Register given transformation class type to GraphRewrite execution list
|
|
||||||
/// All registered transformations will be executed in a single graph traversal.
|
|
||||||
/// Example below show the basic usage of pass::GraphRewrite
|
|
||||||
///
|
|
||||||
/// pass::Manager manager;
|
|
||||||
/// auto anchor = manager.register_pass<GraphRewrite>();
|
|
||||||
/// anchor->add_matcher<MatcherPassA>();
|
|
||||||
/// anchor->add_matcher<MatcherPassB>();
|
|
||||||
/// anchor->set_name("CommonMatchers");
|
|
||||||
/// manager.run_passes(f);
|
|
||||||
///
|
|
||||||
/// For some purposes transformation can be registered and disabled by default.
|
|
||||||
///
|
|
||||||
/// anchor->add_matcher<MatcherPassB, false>();
|
|
||||||
///
|
|
||||||
/// \return shared_ptr to the transformation instance
|
|
||||||
template <typename T,
|
|
||||||
bool Enabled = true,
|
|
||||||
class... Args,
|
|
||||||
typename std::enable_if<std::is_base_of<pass::MatcherPass, T>::value, bool>::type = true>
|
|
||||||
std::shared_ptr<T> add_matcher(Args&&... args) {
|
|
||||||
static_assert(std::is_base_of<pass::MatcherPass, T>::value, "pass not derived from MatcherPass");
|
|
||||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
|
||||||
auto pass_config = get_pass_config();
|
|
||||||
pass->set_pass_config(pass_config);
|
|
||||||
if (!Enabled && !pass_config->is_enabled<T>()) {
|
|
||||||
pass_config->disable<T>();
|
|
||||||
}
|
|
||||||
m_matchers.push_back(pass);
|
|
||||||
return pass;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Register passes from GraphRewrite class that contains sequence of matcher
|
|
||||||
/// passes registered in its ctor.
|
|
||||||
/// For example:
|
|
||||||
///
|
|
||||||
/// class ngraph::pass::LinFusions: public ngraph::pass::GraphRewrite {
|
|
||||||
/// public:
|
|
||||||
/// NGRAPH_RTTI_DECLARATION;
|
|
||||||
/// Fusions() {
|
|
||||||
/// add_matcher<ngraph::pass::AddFusion>();
|
|
||||||
/// add_matcher<ngraph::pass::MulFusion>();
|
|
||||||
/// }
|
|
||||||
/// };
|
|
||||||
///
|
|
||||||
/// pass::Manager manager;
|
|
||||||
/// auto anchor = manager.register_pass<GraphRewrite>();
|
|
||||||
/// anchor->add_matcher<LinFusions>();
|
|
||||||
/// anchor->add_matcher<OtherFusions>();
|
|
||||||
/// anchor->set_name("CommonFusions");
|
|
||||||
/// manager.run_passes(f);
|
|
||||||
///
|
|
||||||
/// In this case all matcher passes from LinFusions pass will be united with other
|
|
||||||
/// registered matchers.
|
|
||||||
template <typename T,
|
|
||||||
class... Args,
|
|
||||||
typename std::enable_if<std::is_base_of<pass::GraphRewrite, T>::value, bool>::type = true>
|
|
||||||
void add_matcher(Args&&... args) {
|
|
||||||
static_assert(std::is_base_of<pass::GraphRewrite, T>::value, "pass not derived from GraphRewrite");
|
|
||||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
|
||||||
auto pass_config = get_pass_config();
|
|
||||||
|
|
||||||
for (auto& matcher : pass->m_matchers) {
|
|
||||||
pass->set_pass_config(pass_config);
|
|
||||||
m_matchers.push_back(matcher);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
NGRAPH_DEPRECATED("Use MatcherPass instead")
|
|
||||||
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
|
||||||
const ngraph::graph_rewrite_callback& callback,
|
|
||||||
const PassPropertyMask& property);
|
|
||||||
|
|
||||||
NGRAPH_DEPRECATED("Use MatcherPass instead")
|
|
||||||
void add_matcher(const std::shared_ptr<pattern::Matcher>& m, const ngraph::graph_rewrite_callback& callback);
|
|
||||||
|
|
||||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
|
||||||
|
|
||||||
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
bool apply_matcher_passes(std::shared_ptr<Function> f, std::deque<std::weak_ptr<Node>> nodes_to_run);
|
|
||||||
|
|
||||||
bool m_enable_shape_inference = false;
|
|
||||||
|
|
||||||
std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
|
|
||||||
};
|
|
||||||
|
|
||||||
class NGRAPH_API BackwardGraphRewrite : public ngraph::pass::GraphRewrite {
|
|
||||||
public:
|
|
||||||
NGRAPH_RTTI_DECLARATION;
|
|
||||||
|
|
||||||
BackwardGraphRewrite() = default;
|
|
||||||
|
|
||||||
explicit BackwardGraphRewrite(const std::shared_ptr<MatcherPass>& pass) : GraphRewrite(pass) {}
|
|
||||||
|
|
||||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
|
||||||
};
|
|
||||||
|
|
||||||
class NGRAPH_API RecurrentGraphRewrite : public ngraph::pass::FunctionPass {
|
|
||||||
public:
|
|
||||||
RecurrentGraphRewrite(size_t num_iters = 10) : FunctionPass(), m_num_iters(num_iters) {}
|
|
||||||
|
|
||||||
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
|
||||||
const ngraph::recurrent_graph_rewrite_callback& callback,
|
|
||||||
const PassPropertyMask& property);
|
|
||||||
|
|
||||||
// TODO: This interface may deprecate after all passes are refactored.
|
|
||||||
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
|
||||||
const ngraph::recurrent_graph_rewrite_callback& callback);
|
|
||||||
|
|
||||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
size_t m_num_iters;
|
|
||||||
|
|
||||||
std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
|
|
||||||
};
|
|
||||||
} // namespace pass
|
} // namespace pass
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -5,10 +5,12 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <ngraph/pass/graph_rewrite.hpp>
|
|
||||||
#include <ngraph/pass/pass.hpp>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "ngraph/pass/graph_rewrite.hpp"
|
||||||
|
#include "ngraph/pass/pass.hpp"
|
||||||
|
#include "openvino/pass/low_latency.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
/**
|
/**
|
||||||
@ -46,38 +48,6 @@ public:
|
|||||||
LowLatency();
|
LowLatency();
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
using ov::pass::LowLatency2;
|
||||||
* @brief The transformation finds all TensorIterator/Loop layers in the network,
|
|
||||||
* processes all back edges that describe a connection between Result and Parameter
|
|
||||||
* of the TensorIterator/Loop bodies,and inserts ReadValue and Assign layers at the
|
|
||||||
* input and output corresponding to this back edge.
|
|
||||||
* Supported platforms: CPU, GNA.
|
|
||||||
*
|
|
||||||
* The example below describes the changes made by the transformation
|
|
||||||
* [] - TensorIterator body
|
|
||||||
* () - new layer
|
|
||||||
* BE - back-edge
|
|
||||||
*
|
|
||||||
* before applying the transformation:
|
|
||||||
* -> input1[BE_1 -> Parameter -> Layers ... -> Result -> BE_1 ]output1->
|
|
||||||
*
|
|
||||||
* after applying the transformation:
|
|
||||||
* ->(ReadValue)-> input1[BE_1 ->Parameter->Layers ...->Result->BE_1]output1 ->(Assign)
|
|
||||||
* \
|
|
||||||
* ->...
|
|
||||||
* After applying the transformation, the resulting network can be inferred
|
|
||||||
* step by step, the states will store between inferences.
|
|
||||||
*/
|
|
||||||
class NGRAPH_API LowLatency2 : public ngraph::pass::FunctionPass {
|
|
||||||
public:
|
|
||||||
NGRAPH_RTTI_DECLARATION;
|
|
||||||
|
|
||||||
explicit LowLatency2(bool use_const_initializer = true) : m_use_const_initializer(use_const_initializer) {}
|
|
||||||
|
|
||||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
bool m_use_const_initializer;
|
|
||||||
};
|
|
||||||
} // namespace pass
|
} // namespace pass
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -11,106 +11,10 @@
|
|||||||
|
|
||||||
#include "ngraph/pass/pass.hpp"
|
#include "ngraph/pass/pass.hpp"
|
||||||
#include "ngraph/pass/validate.hpp"
|
#include "ngraph/pass/validate.hpp"
|
||||||
|
#include "openvino/pass/manager.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
class NGRAPH_API Manager {
|
using ov::pass::Manager;
|
||||||
public:
|
|
||||||
Manager();
|
|
||||||
~Manager();
|
|
||||||
|
|
||||||
//// \brief Construct Manager with shared PassConfig instance
|
|
||||||
explicit Manager(std::shared_ptr<PassConfig> pass_config);
|
|
||||||
|
|
||||||
/// \brief Register given transformation class type to execution list
|
|
||||||
/// Example below show the basic usage of pass::Manager
|
|
||||||
///
|
|
||||||
/// pass::Manager manager;
|
|
||||||
/// manager.register_pass<MyTransformation>(/*transformation constructor ars*/);
|
|
||||||
/// manager.run_passes(f);
|
|
||||||
///
|
|
||||||
/// For some purposes transformation can be registered and disabled by default.
|
|
||||||
///
|
|
||||||
/// manager.register_pass<MyTransformation, false>();
|
|
||||||
///
|
|
||||||
/// \return shared_ptr to the transformation instance
|
|
||||||
template <typename T, bool Enable = true, class... Args>
|
|
||||||
std::shared_ptr<T> register_pass(Args&&... args) {
|
|
||||||
auto rc = push_pass<T>(std::forward<Args>(args)...);
|
|
||||||
rc->set_pass_config(m_pass_config);
|
|
||||||
if (m_per_pass_validation) {
|
|
||||||
push_pass<Validate>();
|
|
||||||
}
|
|
||||||
if (!Enable && !m_pass_config->is_enabled<T>()) {
|
|
||||||
m_pass_config->disable<T>();
|
|
||||||
}
|
|
||||||
return rc;
|
|
||||||
}
|
|
||||||
|
|
||||||
void run_passes(std::shared_ptr<Function>);
|
|
||||||
|
|
||||||
void set_pass_visualization(bool new_state) {
|
|
||||||
m_visualize = new_state;
|
|
||||||
}
|
|
||||||
/// \brief Set flag to enable/disable running Validate pass after executing
|
|
||||||
/// each registered pass
|
|
||||||
/// \param new_state Value "true" enables Validate pass run; "false", otherwise
|
|
||||||
void set_per_pass_validation(bool new_state) {
|
|
||||||
m_per_pass_validation = new_state;
|
|
||||||
}
|
|
||||||
/// \brief Callback is a lambda function that can be used by registered transformations.
|
|
||||||
/// The main purpose of this callback is to provide a way for plugins to disable/enable
|
|
||||||
/// transformations based on some conditions. In some cases plugins may want not to
|
|
||||||
/// execute some
|
|
||||||
/// transformations.
|
|
||||||
/// For example plugin can disable unpleasant decompositions because of performance
|
|
||||||
/// reasons for
|
|
||||||
/// some cases.
|
|
||||||
/// Callback example:
|
|
||||||
/// auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
|
||||||
/// return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) !=
|
|
||||||
/// nullptr;
|
|
||||||
/// };
|
|
||||||
/// This callback returns true in case of DepthToSpace operation. So when execution
|
|
||||||
/// DepthToSpace
|
|
||||||
/// decomposition pass will check is this decomposition needed or plugin can execute
|
|
||||||
/// this
|
|
||||||
/// operation directly. And of course on transformation side we need to have a response
|
|
||||||
/// for this
|
|
||||||
/// callback.
|
|
||||||
/// if (transformation_callback(batch_to_space)) {
|
|
||||||
/// return false;
|
|
||||||
/// }
|
|
||||||
/// \param callback lamda function that returns true in case if node is supported by
|
|
||||||
/// plugin and
|
|
||||||
/// transformation is not needed
|
|
||||||
NGRAPH_DEPRECATED("Please use get_pass_config() to configure transformation pipeline")
|
|
||||||
void set_callback(const param_callback& callback) {
|
|
||||||
m_pass_config->set_callback(callback);
|
|
||||||
}
|
|
||||||
/// \return PassConfig shared object. This object is used for transformations pipeline
|
|
||||||
/// configuration.
|
|
||||||
/// This object allows to disable/enable transformations execution, set callback to
|
|
||||||
/// particular
|
|
||||||
/// transformation. For mo details see PassConfig class.
|
|
||||||
std::shared_ptr<PassConfig> get_pass_config() {
|
|
||||||
return m_pass_config;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
template <typename T, class... Args>
|
|
||||||
std::shared_ptr<T> push_pass(Args&&... args) {
|
|
||||||
static_assert(std::is_base_of<pass::PassBase, T>::value, "pass not derived from pass base");
|
|
||||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
|
||||||
auto pass_base = std::static_pointer_cast<PassBase>(pass);
|
|
||||||
m_pass_list.push_back(pass_base);
|
|
||||||
return pass;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<PassConfig> m_pass_config;
|
|
||||||
std::vector<std::shared_ptr<PassBase>> m_pass_list;
|
|
||||||
bool m_visualize = false;
|
|
||||||
bool m_per_pass_validation = true;
|
|
||||||
};
|
|
||||||
} // namespace pass
|
} // namespace pass
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -13,105 +13,32 @@
|
|||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
#include "ngraph/pass/pass_config.hpp"
|
#include "ngraph/pass/pass_config.hpp"
|
||||||
#include "ngraph/util.hpp"
|
#include "ngraph/util.hpp"
|
||||||
|
#include "openvino/pass/pass.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
|
||||||
|
class Manager;
|
||||||
|
|
||||||
|
}
|
||||||
|
} // namespace ov
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
enum class PassProperty : uint32_t {
|
using ov::pass::FunctionPass;
|
||||||
// Pass requires node shapes to be static
|
using ov::pass::FusionType;
|
||||||
REQUIRE_STATIC_SHAPE = 0x1,
|
using ov::pass::FusionTypeMask;
|
||||||
// Pass transformation will change the function's dynamic state
|
using ov::pass::Manager;
|
||||||
CHANGE_DYNAMIC_STATE = 1 << 1,
|
using ov::pass::PassBase;
|
||||||
};
|
using ov::pass::PassProperty;
|
||||||
|
using ov::pass::PassPropertyMask;
|
||||||
typedef EnumMask<PassProperty> PassPropertyMask;
|
NGRAPH_DEPRECATED("This variable is deprecated and will be removed soon.")
|
||||||
const PassPropertyMask all_pass_property_off;
|
const PassPropertyMask all_pass_property_off;
|
||||||
|
|
||||||
class NGRAPH_API PassBase {
|
|
||||||
friend class Manager;
|
|
||||||
|
|
||||||
public:
|
|
||||||
PassBase();
|
|
||||||
virtual ~PassBase() {}
|
|
||||||
/// Check if this pass has all the pass properties.
|
|
||||||
bool get_property(const PassPropertyMask& prop_mask) const;
|
|
||||||
|
|
||||||
void set_name(const std::string& name) {
|
|
||||||
m_name = name;
|
|
||||||
}
|
|
||||||
std::string get_name() const;
|
|
||||||
|
|
||||||
/// \brief Set callback for particular transformation type.
|
|
||||||
/// This method set global callback. For more details see PassConfig class
|
|
||||||
/// documentation.
|
|
||||||
/// \param callback lambda function that takes node and returns bool
|
|
||||||
void set_callback(const param_callback& callback);
|
|
||||||
|
|
||||||
/// \brief Set PassConfig for particular transformation instance
|
|
||||||
/// \param pass_config is a PassConfig shared_ptr
|
|
||||||
virtual void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) {
|
|
||||||
m_pass_config = pass_config;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Allows to access PassConfig shared instance
|
|
||||||
/// \return Shared instance of PassConfig class
|
|
||||||
std::shared_ptr<PassConfig> get_pass_config() {
|
|
||||||
return m_pass_config;
|
|
||||||
}
|
|
||||||
/// \brief Applies callback for given node. By default callback returns false.
|
|
||||||
/// This method remains here only for backward compatibility and will be removed
|
|
||||||
/// after all transformations are moved to transformation_callback() method.
|
|
||||||
/// \return result of callback execution for given node
|
|
||||||
NGRAPH_DEPRECATED("Please use transformation_callback method instead")
|
|
||||||
bool m_transformation_callback(const std::shared_ptr<const Node>& node) {
|
|
||||||
return m_pass_config->get_callback(get_type_info())(node);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Applies callback for given node. By default callback returns false.
|
|
||||||
/// \param node which will be used inside callback
|
|
||||||
/// \return result of callback execution for given node
|
|
||||||
bool transformation_callback(const std::shared_ptr<const Node>& node) {
|
|
||||||
return m_pass_config->get_callback(get_type_info())(node);
|
|
||||||
}
|
|
||||||
|
|
||||||
using type_info_t = DiscreteTypeInfo;
|
|
||||||
|
|
||||||
virtual const type_info_t& get_type_info() const = 0;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
void set_property(const PassPropertyMask& prop, bool value);
|
|
||||||
|
|
||||||
private:
|
|
||||||
PassPropertyMask m_property;
|
|
||||||
|
|
||||||
std::string m_name;
|
|
||||||
std::shared_ptr<PassConfig> m_pass_config;
|
|
||||||
};
|
|
||||||
|
|
||||||
class NGRAPH_API FunctionPass : public PassBase {
|
|
||||||
public:
|
|
||||||
NGRAPH_RTTI_DECLARATION;
|
|
||||||
virtual ~FunctionPass();
|
|
||||||
virtual bool run_on_function(std::shared_ptr<ngraph::Function>) = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
class NGRAPH_DEPRECATED("Use MatcherPass or FunctionPass instead.") NGRAPH_API NodePass : public PassBase {
|
class NGRAPH_DEPRECATED("Use MatcherPass or FunctionPass instead.") NGRAPH_API NodePass : public PassBase {
|
||||||
public:
|
public:
|
||||||
NGRAPH_RTTI_DECLARATION;
|
NGRAPH_RTTI_DECLARATION;
|
||||||
virtual ~NodePass();
|
~NodePass() override;
|
||||||
virtual bool run_on_node(std::shared_ptr<ngraph::Node>) = 0;
|
virtual bool run_on_node(std::shared_ptr<ngraph::Node>) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
class Manager;
|
|
||||||
enum class FusionType : uint32_t {
|
|
||||||
//`DIFFERENTIABLE_FUSIONS` produce ops that support autodiff
|
|
||||||
// i.e. implement `generate_adjoints`
|
|
||||||
DIFFERENTIABLE_FUSIONS = 0x1,
|
|
||||||
REGULAR_FUSIONS = 0x2,
|
|
||||||
//`FOP_FUSIONS` produce ops in the FusedOps category that might
|
|
||||||
// not be supported by all backends
|
|
||||||
FOP_FUSIONS = 0x4,
|
|
||||||
ALL_FUSIONS = 0xFFFFFFFF
|
|
||||||
};
|
|
||||||
typedef EnumMask<FusionType> FusionTypeMask;
|
|
||||||
} // namespace pass
|
} // namespace pass
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -12,164 +12,12 @@
|
|||||||
#include "ngraph/function.hpp"
|
#include "ngraph/function.hpp"
|
||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
#include "ngraph/util.hpp"
|
#include "ngraph/util.hpp"
|
||||||
|
#include "openvino/pass/pass_config.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
using param_callback = std::function<bool(const std::shared_ptr<const ::ngraph::Node>)>;
|
using ov::pass::param_callback;
|
||||||
using param_callback_map = std::map<ngraph::DiscreteTypeInfo, param_callback>;
|
using ov::pass::param_callback_map;
|
||||||
|
using ov::pass::PassConfig;
|
||||||
/// \brief Class representing a transformations config that is used for disabling/enabling
|
|
||||||
/// transformations registered inside pass::Manager and also allows to set callback for all
|
|
||||||
/// transformations or for particular transformation.
|
|
||||||
///
|
|
||||||
/// When pass::Manager is created all passes registered inside this manager including nested
|
|
||||||
/// passes will share the same instance of PassConfig class.
|
|
||||||
/// To work with this class first you need to get shared instance of this class by calling
|
|
||||||
/// manager.get_pass_config() method. Then you will be able to disable/enable passes based
|
|
||||||
/// on transformations type_info. For example:
|
|
||||||
///
|
|
||||||
/// pass::Manager manager;
|
|
||||||
/// manager.register_pass<CommonOptimizations>();
|
|
||||||
/// auto pass_config = manager.get_pass_config();
|
|
||||||
/// pass_config->disable<ConvertGELU>(); // this will disable nested pass inside
|
|
||||||
/// // CommonOptimizations pipeline
|
|
||||||
/// manager.run_passes(f);
|
|
||||||
///
|
|
||||||
/// Sometimes it is needed to call transformation inside other transformation manually. And
|
|
||||||
/// for that case before running transformation you need manually check that this pass is
|
|
||||||
/// not disabled and then you need to set current PassConfig instance to this
|
|
||||||
/// transformation. For example:
|
|
||||||
///
|
|
||||||
/// // Inside MatcherPass callback or inside FunctionPass run_on_function() method
|
|
||||||
/// // you need to call get_pass_config() method to get shared instance of PassConfig
|
|
||||||
/// auto pass_config = get_pass_config();
|
|
||||||
///
|
|
||||||
/// // Before running nested transformation you need to check is it disabled or not
|
|
||||||
/// if (!pass_config->is_disabled<ConvertGELU>()) {
|
|
||||||
/// auto pass = ConvertGELU();
|
|
||||||
/// pass->set_pass_config(pass_config);
|
|
||||||
/// pass.apply(node);
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// Following this logic inside your transformations you will guaranty that transformations
|
|
||||||
/// will be executed in a right way.
|
|
||||||
class NGRAPH_API PassConfig {
|
|
||||||
public:
|
|
||||||
/// \brief Disable transformation by its type_info
|
|
||||||
/// \param type_info Transformation type_info
|
|
||||||
void disable(const DiscreteTypeInfo& type_info);
|
|
||||||
/// \brief Disable transformation by its class type (based on type_info)
|
|
||||||
template <typename T>
|
|
||||||
void disable() {
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
|
||||||
disable(T::type_info);
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Enable transformation by its type_info
|
|
||||||
/// \param type_info Transformation type_info
|
|
||||||
void enable(const DiscreteTypeInfo& type_info);
|
|
||||||
/// \brief Enable transformation by its class type (based on type_info)
|
|
||||||
template <typename T>
|
|
||||||
void enable() {
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
|
||||||
enable(T::type_info);
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Set callback for all kind of transformations
|
|
||||||
void set_callback(const param_callback& callback) {
|
|
||||||
m_callback = callback;
|
|
||||||
}
|
|
||||||
template <typename... Args>
|
|
||||||
typename std::enable_if<sizeof...(Args) == 0>::type set_callback(const param_callback& callback) {}
|
|
||||||
|
|
||||||
/// \brief Set callback for particular transformation class types
|
|
||||||
///
|
|
||||||
/// Example below show how to set callback for one or multiple passes using this method.
|
|
||||||
///
|
|
||||||
/// pass_config->set_callback<ngraph::pass::ConvertBatchToSpace,
|
|
||||||
/// ngraph::pass::ConvertSpaceToBatch>(
|
|
||||||
/// [](const_node_ptr &node) -> bool {
|
|
||||||
/// // Disable transformations for cases when input shape rank is not
|
|
||||||
/// equal to 4
|
|
||||||
/// const auto input_shape_rank =
|
|
||||||
/// node->get_output_partial_shape(0).rank().get_length();
|
|
||||||
/// if (input_shape_rank != 4) {
|
|
||||||
/// return false;
|
|
||||||
/// }
|
|
||||||
/// return true;
|
|
||||||
/// });
|
|
||||||
///
|
|
||||||
/// Note that inside transformations you must provide code that work with this callback.
|
|
||||||
/// See example below:
|
|
||||||
///
|
|
||||||
/// if (transformation_callback(node)) {
|
|
||||||
/// return false; // exit from transformation
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
template <typename T, class... Args>
|
|
||||||
void set_callback(const param_callback& callback) {
|
|
||||||
m_callback_map[T::type_info] = callback;
|
|
||||||
set_callback<Args...>(callback);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Get callback for given transformation type_info
|
|
||||||
/// \param type_info Transformation type_info
|
|
||||||
///
|
|
||||||
/// In case if callback wasn't set for given transformation type then global callback
|
|
||||||
/// will be returned. But if even global callback wasn't set then default callback will
|
|
||||||
/// be returned.
|
|
||||||
param_callback get_callback(const DiscreteTypeInfo& type_info) const;
|
|
||||||
|
|
||||||
/// \brief Get callback for given transformation class type
|
|
||||||
/// \return callback lambda function
|
|
||||||
template <typename T>
|
|
||||||
param_callback get_callback() const {
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
|
||||||
return get_callback(T::type_info);
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Check either transformation type is disabled or not
|
|
||||||
/// \param type_info Transformation type_info
|
|
||||||
/// \return true if transformation type was disabled and false otherwise
|
|
||||||
bool is_disabled(const DiscreteTypeInfo& type_info) const {
|
|
||||||
return m_disabled.count(type_info);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Check either transformation class type is disabled or not
|
|
||||||
/// \return true if transformation type was disabled and false otherwise
|
|
||||||
template <typename T>
|
|
||||||
bool is_disabled() const {
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
|
||||||
return is_disabled(T::type_info);
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Check either transformation type is force enabled or not
|
|
||||||
/// \param type_info Transformation type_info
|
|
||||||
/// \return true if transformation type was force enabled and false otherwise
|
|
||||||
bool is_enabled(const DiscreteTypeInfo& type_info) const {
|
|
||||||
return m_enabled.count(type_info);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Check either transformation class type is force enabled or not
|
|
||||||
/// \return true if transformation type was force enabled and false otherwise
|
|
||||||
template <typename T>
|
|
||||||
bool is_enabled() const {
|
|
||||||
return is_enabled(T::type_info);
|
|
||||||
}
|
|
||||||
|
|
||||||
void add_disabled_passes(const PassConfig& rhs);
|
|
||||||
|
|
||||||
private:
|
|
||||||
param_callback m_callback = [](const std::shared_ptr<const ::ngraph::Node>&) {
|
|
||||||
return false;
|
|
||||||
};
|
|
||||||
param_callback_map m_callback_map;
|
|
||||||
std::unordered_set<DiscreteTypeInfo> m_disabled;
|
|
||||||
std::unordered_set<DiscreteTypeInfo> m_enabled;
|
|
||||||
};
|
|
||||||
} // namespace pass
|
} // namespace pass
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -5,27 +5,10 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "ngraph/pass/pass.hpp"
|
#include "ngraph/pass/pass.hpp"
|
||||||
|
#include "openvino/pass/validate.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
/// \brief The Validate pass performs sanity checks on attributes and inputs, and
|
using ov::pass::Validate;
|
||||||
/// computes output shapes and element types for all computation nodes in a given
|
|
||||||
/// computation graph.
|
|
||||||
///
|
|
||||||
/// \details The verification and inference is done via invoking each node's specific
|
|
||||||
/// implementation of \link ngraph::Node::validate_and_infer_types() \endlink function.
|
|
||||||
///
|
|
||||||
/// By default, the \ref ngraph::pass::Manager runs this pass after executing every
|
|
||||||
/// optimization pass. This is to ensure that any update to the graph by an optimization
|
|
||||||
/// pass does not break the shape and data type requirement on a computation node.
|
|
||||||
/// This default validation run can be changed via calling the
|
|
||||||
/// \link ngraph::pass::Manager::set_per_pass_validation(bool) \endlink function.
|
|
||||||
class NGRAPH_API Validate : public FunctionPass {
|
|
||||||
public:
|
|
||||||
NGRAPH_RTTI_DECLARATION;
|
|
||||||
|
|
||||||
Validate() : FunctionPass() {}
|
|
||||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
|
||||||
};
|
|
||||||
} // namespace pass
|
} // namespace pass
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -14,44 +14,10 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "ngraph/pass/pass.hpp"
|
#include "ngraph/pass/pass.hpp"
|
||||||
|
#include "openvino/pass/visualize_tree.hpp"
|
||||||
class HeightMap;
|
|
||||||
|
|
||||||
using visualize_tree_ops_map_t =
|
|
||||||
std::unordered_map<ngraph::Node::type_info_t, std::function<void(const ngraph::Node&, std::ostream& ss)>>;
|
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
class NGRAPH_API VisualizeTree : public FunctionPass {
|
using ov::pass::VisualizeTree;
|
||||||
public:
|
|
||||||
NGRAPH_RTTI_DECLARATION;
|
|
||||||
|
|
||||||
using node_modifiers_t = std::function<void(const Node& node, std::vector<std::string>& attributes)>;
|
|
||||||
VisualizeTree(const std::string& file_name, node_modifiers_t nm = nullptr, bool dot_only = false);
|
|
||||||
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
|
|
||||||
|
|
||||||
void set_ops_to_details(const visualize_tree_ops_map_t& ops_map) {
|
|
||||||
m_ops_to_details = ops_map;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
void add_node_arguments(std::shared_ptr<Node> node,
|
|
||||||
std::unordered_map<Node*, HeightMap>& height_maps,
|
|
||||||
size_t& fake_node_ctr);
|
|
||||||
std::string add_attributes(std::shared_ptr<Node> node);
|
|
||||||
virtual std::string get_attributes(std::shared_ptr<Node> node);
|
|
||||||
virtual std::string get_node_name(std::shared_ptr<Node> node);
|
|
||||||
std::string get_constant_value(std::shared_ptr<Node> node, size_t max_elements = 7);
|
|
||||||
|
|
||||||
void render() const;
|
|
||||||
|
|
||||||
std::stringstream m_ss;
|
|
||||||
std::string m_name;
|
|
||||||
std::set<std::shared_ptr<Node>> m_nodes_with_attributes;
|
|
||||||
visualize_tree_ops_map_t m_ops_to_details;
|
|
||||||
node_modifiers_t m_node_modifiers = nullptr;
|
|
||||||
bool m_dot_only;
|
|
||||||
static const int max_jump_distance;
|
|
||||||
};
|
|
||||||
} // namespace pass
|
} // namespace pass
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -16,255 +16,21 @@
|
|||||||
#include "ngraph/pattern/op/any_output.hpp"
|
#include "ngraph/pattern/op/any_output.hpp"
|
||||||
#include "ngraph/pattern/op/label.hpp"
|
#include "ngraph/pattern/op/label.hpp"
|
||||||
#include "ngraph/pattern/op/skip.hpp"
|
#include "ngraph/pattern/op/skip.hpp"
|
||||||
|
#include "openvino/pass/pattern/matcher.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ov {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
class GraphRewrite;
|
class GraphRewrite;
|
||||||
}
|
}
|
||||||
|
} // namespace ov
|
||||||
|
namespace ngraph {
|
||||||
|
namespace pass {
|
||||||
|
using ov::pass::GraphRewrite;
|
||||||
|
}
|
||||||
|
|
||||||
namespace pattern {
|
namespace pattern {
|
||||||
class Matcher;
|
using ov::pass::pattern::Matcher;
|
||||||
|
using ov::pass::pattern::MatcherState;
|
||||||
class NGRAPH_API MatcherState {
|
using ov::pass::pattern::RecurrentMatcher;
|
||||||
public:
|
|
||||||
MatcherState(Matcher*);
|
|
||||||
bool finish(bool is_successful);
|
|
||||||
~MatcherState();
|
|
||||||
|
|
||||||
protected:
|
|
||||||
Matcher* m_matcher;
|
|
||||||
PatternValueMap m_pattern_value_map;
|
|
||||||
PatternValueMaps m_pattern_value_maps;
|
|
||||||
size_t m_watermark;
|
|
||||||
size_t m_capture_size;
|
|
||||||
bool m_restore{true};
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Matcher looks for node patterns in a computation graph. The patterns are described by an
|
|
||||||
/// automaton that is described by an extended computation graph. The matcher executes
|
|
||||||
/// by attempting to match the start node of the pattern to a computation graph value
|
|
||||||
/// (output of a Node). In addition to determing if a match occurs, a pattern node may add
|
|
||||||
/// graph nodes to a list of matched nodes, associate nodes with graph values, and start
|
|
||||||
/// submatches. Submatches add match state changes to the enclosing match if the submatch
|
|
||||||
/// succeeds; otherwise the state is reverted.
|
|
||||||
///
|
|
||||||
/// The default match behavior of a pattern node with a graph nodes is that the computation
|
|
||||||
/// graph value is added to the end of the matched value list and the match succeeds if the
|
|
||||||
/// node/pattern types match and the input values match. In the case of a commutative node,
|
|
||||||
/// the inputs can match in any order. If the matcher is in strict mode, the graph value
|
|
||||||
/// element type and shape must also match.
|
|
||||||
///
|
|
||||||
/// Pattern nodes that have different match behavior are in ngraph::pattern::op and have
|
|
||||||
/// descriptions of their match behavior.
|
|
||||||
class NGRAPH_API Matcher {
|
|
||||||
public:
|
|
||||||
using PatternMap = ngraph::pattern::PatternMap;
|
|
||||||
|
|
||||||
// Avoid implicit string construction from nullptr.
|
|
||||||
Matcher(const std::shared_ptr<Node> pattern_node, std::nullptr_t name) = delete;
|
|
||||||
|
|
||||||
Matcher() {}
|
|
||||||
Matcher(Output<Node>& pattern_node) : m_pattern_node{pattern_node} {}
|
|
||||||
|
|
||||||
Matcher(Output<Node>& pattern_node, const std::string& name) : m_pattern_node(pattern_node), m_name{name} {}
|
|
||||||
|
|
||||||
/// \brief Constructs a Matcher object
|
|
||||||
///
|
|
||||||
/// \param pattern_node is a pattern sub graph that will be matched against input graphs
|
|
||||||
/// \param name is a string which is used for logging and disabling a matcher
|
|
||||||
/// \param strict_mode forces a matcher to consider shapes and ET of nodes
|
|
||||||
Matcher(const Output<Node>& pattern_node, const std::string& name, bool strict_mode)
|
|
||||||
: m_pattern_node(pattern_node),
|
|
||||||
m_name(name),
|
|
||||||
m_strict_mode(strict_mode) {}
|
|
||||||
|
|
||||||
// Some matches should start on a node rather than an output. These three constructors
|
|
||||||
// are transition until we work out the right way to do that.
|
|
||||||
Matcher(std::shared_ptr<Node> pattern_node);
|
|
||||||
Matcher(std::shared_ptr<Node> pattern_node, const std::string& name);
|
|
||||||
Matcher(std::shared_ptr<Node> pattern_node, const std::string& name, bool strict_mode);
|
|
||||||
|
|
||||||
virtual ~Matcher() {}
|
|
||||||
/// \brief Matches a pattern to \p graph_node
|
|
||||||
///
|
|
||||||
/// \param graph_value is an input graph to be matched against
|
|
||||||
bool match(const Output<Node>& graph_value);
|
|
||||||
|
|
||||||
bool match(std::shared_ptr<Node> graph_node);
|
|
||||||
|
|
||||||
/// \brief Matches a pattern to \p graph_node
|
|
||||||
///
|
|
||||||
/// \param graph_value is an input graph to be matched against
|
|
||||||
/// \param previous_matches contains previous mappings from labels to nodes to use
|
|
||||||
bool match(const Output<Node>& graph_value, const PatternMap& previous_matches);
|
|
||||||
bool match(const Output<Node>& graph_value, const PatternValueMap& previous_matches);
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static std::shared_ptr<T> unique_match(std::shared_ptr<Node> node) {
|
|
||||||
std::shared_ptr<T> matched;
|
|
||||||
for (auto arg : node->input_values()) {
|
|
||||||
if (auto t_casted = ov::as_type_ptr<T>(arg.get_node_shared_ptr())) {
|
|
||||||
if (matched) {
|
|
||||||
throw ngraph_error("There's more than two arguments of the same type");
|
|
||||||
} else {
|
|
||||||
matched = t_casted;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return matched;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_contained_match(const NodeVector& exclusions = {}, bool ignore_unused = true);
|
|
||||||
const NodeVector get_matched_nodes() {
|
|
||||||
return as_node_vector(m_matched_list);
|
|
||||||
}
|
|
||||||
const OutputVector& get_matched_values() const {
|
|
||||||
return m_matched_list;
|
|
||||||
}
|
|
||||||
OutputVector& get_matched_values() {
|
|
||||||
return m_matched_list;
|
|
||||||
}
|
|
||||||
void reset() {}
|
|
||||||
const std::string& get_name() {
|
|
||||||
return m_name;
|
|
||||||
}
|
|
||||||
std::shared_ptr<Node> get_pattern() {
|
|
||||||
return m_pattern_node.get_node_shared_ptr();
|
|
||||||
}
|
|
||||||
Output<Node> get_pattern_value() {
|
|
||||||
return m_pattern_node;
|
|
||||||
}
|
|
||||||
std::shared_ptr<Node> get_match_root();
|
|
||||||
Output<Node> get_match_value();
|
|
||||||
PatternMap get_pattern_map() const;
|
|
||||||
PatternValueMap& get_pattern_value_map() {
|
|
||||||
return m_pattern_map;
|
|
||||||
}
|
|
||||||
PatternValueMaps& get_pattern_value_maps() {
|
|
||||||
return m_pattern_value_maps;
|
|
||||||
}
|
|
||||||
/// \brief Low-level helper to match recurring patterns
|
|
||||||
///
|
|
||||||
/// \param graph is a graph to be matched against
|
|
||||||
/// \param pattern is a recurring pattern
|
|
||||||
/// \param rpattern specifies a node to recur from next
|
|
||||||
/// \param patterns a map from labels to matches
|
|
||||||
|
|
||||||
size_t add_node(Output<Node> node);
|
|
||||||
|
|
||||||
bool virtual match_value(const ngraph::Output<Node>& pattern_value, const ngraph::Output<Node>& graph_value);
|
|
||||||
|
|
||||||
bool is_strict_mode() {
|
|
||||||
return m_strict_mode;
|
|
||||||
}
|
|
||||||
virtual bool match_arguments(Node* pattern_node, const std::shared_ptr<Node>& graph_node);
|
|
||||||
|
|
||||||
void capture(const std::set<Node*>& static_nodes);
|
|
||||||
|
|
||||||
void clear_state();
|
|
||||||
|
|
||||||
size_t get_number_of_recurrent_matches() const {
|
|
||||||
return m_pattern_value_maps.size();
|
|
||||||
}
|
|
||||||
NodeVector get_bound_nodes_for_pattern(const Output<Node>& pattern) const;
|
|
||||||
size_t get_number_of_bound_labels() const;
|
|
||||||
/// \brief Try a match
|
|
||||||
MatcherState start_match();
|
|
||||||
|
|
||||||
Output<Node> m_match_root;
|
|
||||||
Output<Node> m_pattern_node;
|
|
||||||
PatternValueMap m_pattern_map;
|
|
||||||
PatternValueMaps m_pattern_value_maps;
|
|
||||||
OutputVector m_matched_list;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
bool match_permutation(const OutputVector& pattern_args, const OutputVector& args);
|
|
||||||
|
|
||||||
std::string m_name{"unnamed"};
|
|
||||||
bool m_strict_mode{false};
|
|
||||||
};
|
|
||||||
|
|
||||||
class NGRAPH_API RecurrentMatcher {
|
|
||||||
public:
|
|
||||||
/// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
|
|
||||||
/// repeating patterns (e.g. RNN, LSTM, GRU cells)
|
|
||||||
///
|
|
||||||
/// \param initial_pattern is a pattern sub graph describing the initial cell
|
|
||||||
/// \param pattern is a pattern sub graph describing an individual cell
|
|
||||||
/// \param rpattern is a (recurring) label to denote which node the next match should
|
|
||||||
/// start at
|
|
||||||
/// \param correlated_patterns is a set of labels whose bound nodes must remain the same
|
|
||||||
/// across all cells
|
|
||||||
RecurrentMatcher(const Output<Node>& initial_pattern,
|
|
||||||
const Output<Node>& pattern,
|
|
||||||
const std::shared_ptr<Node>& rpattern,
|
|
||||||
const std::set<std::shared_ptr<Node>>& correlated_patterns)
|
|
||||||
: m_initial_pattern(initial_pattern),
|
|
||||||
m_pattern(pattern),
|
|
||||||
m_recurrent_pattern(rpattern),
|
|
||||||
m_correlated_patterns(correlated_patterns) {}
|
|
||||||
|
|
||||||
/// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
|
|
||||||
/// repeating patterns (e.g. RNN, LSTM, GRU cells)
|
|
||||||
///
|
|
||||||
/// \param pattern is a pattern sub graph describing an individual cell
|
|
||||||
/// \param rpattern is a (recurring) label to denote which node the next match should
|
|
||||||
/// start at
|
|
||||||
/// \param correlated_patterns is a set of labels whose bound nodes must remain the same
|
|
||||||
/// across all cells
|
|
||||||
RecurrentMatcher(const Output<Node>& pattern,
|
|
||||||
const std::shared_ptr<Node>& rpattern,
|
|
||||||
const std::set<std::shared_ptr<Node>>& correlated_patterns)
|
|
||||||
: RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns) {}
|
|
||||||
|
|
||||||
RecurrentMatcher(const Output<Node>& initial_pattern,
|
|
||||||
const Output<Node>& pattern,
|
|
||||||
const std::shared_ptr<Node>& rpattern,
|
|
||||||
const std::set<std::shared_ptr<op::Label>>& correlated_patterns);
|
|
||||||
|
|
||||||
RecurrentMatcher(const Output<Node>& pattern,
|
|
||||||
const std::shared_ptr<Node>& rpattern,
|
|
||||||
const std::set<std::shared_ptr<op::Label>>& correlated_patterns)
|
|
||||||
: RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns) {}
|
|
||||||
|
|
||||||
/// \brief Returns a vector of bound nodes for a given label (used in a pattern
|
|
||||||
/// describing an individual cell
|
|
||||||
NodeVector get_bound_nodes_for_pattern(const std::shared_ptr<Node>& pattern) const {
|
|
||||||
if (m_matches.count(pattern) == 0) {
|
|
||||||
throw ngraph_error("No bound nodes for a given label");
|
|
||||||
}
|
|
||||||
|
|
||||||
return as_node_vector(m_matches.at(pattern));
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t get_number_of_recurrent_matches() const {
|
|
||||||
if (m_matches.size() == 0) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (*m_matches.begin()).second.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t get_number_of_bound_labels() const {
|
|
||||||
return m_matches.size();
|
|
||||||
}
|
|
||||||
/// \brief Tries to match a pattern for an individual cell to a given \p graph
|
|
||||||
bool match(Output<Node> graph);
|
|
||||||
|
|
||||||
std::shared_ptr<Node> get_match_root() {
|
|
||||||
return m_match_root.get_node_shared_ptr();
|
|
||||||
}
|
|
||||||
Output<Node> get_match_value() {
|
|
||||||
return m_match_root;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
Output<Node> m_initial_pattern;
|
|
||||||
Output<Node> m_pattern;
|
|
||||||
std::shared_ptr<Node> m_recurrent_pattern;
|
|
||||||
const std::set<std::shared_ptr<Node>> m_correlated_patterns;
|
|
||||||
RPatternValueMap m_matches;
|
|
||||||
Output<Node> m_match_root;
|
|
||||||
};
|
|
||||||
} // namespace pattern
|
} // namespace pattern
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -6,38 +6,12 @@
|
|||||||
|
|
||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
#include "ngraph/pattern/op/pattern.hpp"
|
#include "ngraph/pattern/op/pattern.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/any.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pattern {
|
namespace pattern {
|
||||||
namespace op {
|
namespace op {
|
||||||
/// The graph value is to the matched value list. If the predicate is true for the node
|
using ov::pass::pattern::op::Any;
|
||||||
/// and the arguments match, the match succeeds.
|
|
||||||
class NGRAPH_API Any : public Pattern {
|
|
||||||
public:
|
|
||||||
static constexpr NodeTypeInfo type_info{"patternAny", 0};
|
|
||||||
const NodeTypeInfo& get_type_info() const override;
|
|
||||||
/// \brief creates a Any node containing a sub-pattern described by \sa type and \sa
|
|
||||||
/// shape.
|
|
||||||
Any(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values)
|
|
||||||
: Pattern(wrapped_values, pred) {
|
|
||||||
set_output_type(0, type, s);
|
|
||||||
}
|
|
||||||
Any(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values)
|
|
||||||
: Any(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
|
|
||||||
/// \brief creates a Any node containing a sub-pattern described by the type and
|
|
||||||
/// shape of \sa node.
|
|
||||||
Any(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
|
|
||||||
: Any(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
|
|
||||||
Any(const Output<Node>& node, NodePredicate pred, const NodeVector& wrapped_values)
|
|
||||||
: Any(node.get_element_type(),
|
|
||||||
node.get_partial_shape(),
|
|
||||||
as_value_predicate(pred),
|
|
||||||
as_output_vector(wrapped_values)) {}
|
|
||||||
|
|
||||||
bool match_value(pattern::Matcher* matcher,
|
|
||||||
const Output<Node>& pattern_value,
|
|
||||||
const Output<Node>& graph_value) override;
|
|
||||||
};
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace pattern
|
} // namespace pattern
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -6,47 +6,12 @@
|
|||||||
|
|
||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
#include "ngraph/pattern/op/pattern.hpp"
|
#include "ngraph/pattern/op/pattern.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/any_of.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pattern {
|
namespace pattern {
|
||||||
namespace op {
|
namespace op {
|
||||||
/// The graph value is added to the matched values list. If the predicate is true for
|
using ov::pass::pattern::op::AnyOf;
|
||||||
/// the
|
|
||||||
/// graph node, a submatch is performed on the input of AnyOf and each input of the
|
|
||||||
/// graph node. The first match that succeeds results in a successful match. Otherwise
|
|
||||||
/// the match fails.
|
|
||||||
///
|
|
||||||
/// AnyOf may be given a type and shape for use in strict mode.
|
|
||||||
class NGRAPH_API AnyOf : public Pattern {
|
|
||||||
public:
|
|
||||||
static constexpr NodeTypeInfo type_info{"patternAnyOf", 0};
|
|
||||||
const NodeTypeInfo& get_type_info() const override;
|
|
||||||
/// \brief creates a AnyOf node containing a sub-pattern described by \sa type and
|
|
||||||
/// \sa shape.
|
|
||||||
AnyOf(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values)
|
|
||||||
: Pattern(wrapped_values, pred) {
|
|
||||||
if (wrapped_values.size() != 1) {
|
|
||||||
throw ngraph_error("AnyOf expects exactly one argument");
|
|
||||||
}
|
|
||||||
set_output_type(0, type, s);
|
|
||||||
}
|
|
||||||
AnyOf(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values)
|
|
||||||
: AnyOf(
|
|
||||||
type,
|
|
||||||
s,
|
|
||||||
[pred](const Output<Node>& value) {
|
|
||||||
return pred(value.get_node_shared_ptr());
|
|
||||||
},
|
|
||||||
as_output_vector(wrapped_values)) {}
|
|
||||||
|
|
||||||
/// \brief creates a AnyOf node containing a sub-pattern described by the type and
|
|
||||||
/// shape of \sa node.
|
|
||||||
AnyOf(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
|
|
||||||
: AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
|
|
||||||
AnyOf(std::shared_ptr<Node> node, NodePredicate pred, const NodeVector& wrapped_values)
|
|
||||||
: AnyOf(node, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
|
|
||||||
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
|
|
||||||
};
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace pattern
|
} // namespace pattern
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -6,23 +6,12 @@
|
|||||||
|
|
||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
#include "ngraph/pattern/op/pattern.hpp"
|
#include "ngraph/pattern/op/pattern.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/any_output.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pattern {
|
namespace pattern {
|
||||||
namespace op {
|
namespace op {
|
||||||
/// Matches any output of a node
|
using ov::pass::pattern::op::AnyOutput;
|
||||||
class NGRAPH_API AnyOutput : public Pattern {
|
|
||||||
public:
|
|
||||||
static constexpr NodeTypeInfo type_info{"patternAnyOutput", 0};
|
|
||||||
const NodeTypeInfo& get_type_info() const override;
|
|
||||||
/// \brief creates an AnyOutput node matching any output of a node
|
|
||||||
/// \param node The node to match
|
|
||||||
AnyOutput(const std::shared_ptr<Node>& pattern) : Pattern({pattern->output(0)}) {}
|
|
||||||
|
|
||||||
bool match_value(pattern::Matcher* matcher,
|
|
||||||
const Output<Node>& pattern_value,
|
|
||||||
const Output<Node>& graph_value) override;
|
|
||||||
};
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace pattern
|
} // namespace pattern
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -6,48 +6,12 @@
|
|||||||
|
|
||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
#include "ngraph/pattern/op/pattern.hpp"
|
#include "ngraph/pattern/op/pattern.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/branch.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pattern {
|
namespace pattern {
|
||||||
namespace op {
|
namespace op {
|
||||||
/// A branch adds a loop to the pattern. The branch match is successful if the
|
using ov::pass::pattern::op::Branch;
|
||||||
/// destination node pattern matches the graph value. The destination node is a node in
|
|
||||||
/// the pattern graph that will not have been created some time after the Branch node is
|
|
||||||
/// created; use set_destination to add it.
|
|
||||||
///
|
|
||||||
/// The branch destination is not stored as a shared pointer to prevent reference
|
|
||||||
/// cycles. Thus the destination node must be referenced in some other way to prevent it
|
|
||||||
/// from being deleted.
|
|
||||||
class NGRAPH_API Branch : public Pattern {
|
|
||||||
public:
|
|
||||||
static constexpr NodeTypeInfo type_info{"patternBranch", 0};
|
|
||||||
const NodeTypeInfo& get_type_info() const override;
|
|
||||||
/// \brief Creates a Branch pattern
|
|
||||||
/// \param pattern the destinationing pattern
|
|
||||||
/// \param labels Labels where the destination may occur
|
|
||||||
Branch() : Pattern(OutputVector{}) {
|
|
||||||
set_output_type(0, element::f32, Shape{});
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_destination(const Output<Node>& destination) {
|
|
||||||
m_destination_node = destination.get_node();
|
|
||||||
m_destination_index = destination.get_index();
|
|
||||||
}
|
|
||||||
|
|
||||||
Output<Node> get_destination() const {
|
|
||||||
return m_destination_node == nullptr
|
|
||||||
? Output<Node>()
|
|
||||||
: Output<Node>{m_destination_node->shared_from_this(), m_destination_index};
|
|
||||||
}
|
|
||||||
|
|
||||||
bool match_value(pattern::Matcher* matcher,
|
|
||||||
const Output<Node>& pattern_value,
|
|
||||||
const Output<Node>& graph_value) override;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
Node* m_destination_node{nullptr};
|
|
||||||
size_t m_destination_index{0};
|
|
||||||
};
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace pattern
|
} // namespace pattern
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -6,37 +6,12 @@
|
|||||||
|
|
||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
#include "ngraph/pattern/op/pattern.hpp"
|
#include "ngraph/pattern/op/pattern.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/capture.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pattern {
|
namespace pattern {
|
||||||
namespace op {
|
namespace op {
|
||||||
/// Experimental for support of recurrent matches.
|
using ov::pass::pattern::op::Capture;
|
||||||
///
|
|
||||||
/// Capture adds the pattern value map to a list of pattern value maps and resets
|
|
||||||
/// matches for pattern nodes not in the static node list. The match always succeeds.
|
|
||||||
class NGRAPH_API Capture : public Pattern {
|
|
||||||
public:
|
|
||||||
static constexpr NodeTypeInfo type_info{"patternCapture", 0};
|
|
||||||
const NodeTypeInfo& get_type_info() const override;
|
|
||||||
Capture(const Output<Node>& arg) : Pattern({arg}) {
|
|
||||||
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief static nodes are retained after a capture. All other nodes are dropped
|
|
||||||
std::set<Node*> get_static_nodes() {
|
|
||||||
return m_static_nodes;
|
|
||||||
}
|
|
||||||
void set_static_nodes(const std::set<Node*>& static_nodes) {
|
|
||||||
m_static_nodes = static_nodes;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual bool match_value(pattern::Matcher* matcher,
|
|
||||||
const Output<Node>& pattern_value,
|
|
||||||
const Output<Node>& graph_value) override;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
std::set<Node*> m_static_nodes;
|
|
||||||
};
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace pattern
|
} // namespace pattern
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -6,106 +6,14 @@
|
|||||||
|
|
||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
#include "ngraph/pattern/op/pattern.hpp"
|
#include "ngraph/pattern/op/pattern.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/label.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pattern {
|
namespace pattern {
|
||||||
namespace op {
|
namespace op {
|
||||||
/// Fails if the predicate returns false on the graph value.
|
using ov::pass::pattern::op::Label;
|
||||||
///
|
|
||||||
/// The graph value is added to the matched values list. If the Label is already
|
|
||||||
/// associated with a value, the match succeeds if the value is the same as the graph
|
|
||||||
/// value. Otherwise, the label is associated with the graph value and the match
|
|
||||||
/// succeeds if the pattern input matches the graph value.
|
|
||||||
///
|
|
||||||
/// DEPRECATED: If no inputs are given to Label, a True node is serves as the input. If
|
|
||||||
/// more than one inputs are given, an Or pattern of the inputs serves as the input.
|
|
||||||
class NGRAPH_API Label : public Pattern {
|
|
||||||
public:
|
|
||||||
static constexpr NodeTypeInfo type_info{"patternLabel", 0};
|
|
||||||
const NodeTypeInfo& get_type_info() const override;
|
|
||||||
/// \brief creates a Label node containing a sub-pattern described by \sa type and
|
|
||||||
/// \sa shape.
|
|
||||||
///
|
|
||||||
/// this Label node can be bound only to the nodes in the input graph
|
|
||||||
/// that match the pattern specified by \sa wrapped_nodes
|
|
||||||
/// Example:
|
|
||||||
/// \code{.cpp}
|
|
||||||
/// auto add = a + b; // a and b are op::Parameter in this example
|
|
||||||
/// auto label = std::make_shared<pattern::op::Label>(element::f32,
|
|
||||||
/// Shape{2,2},
|
|
||||||
/// nullptr,
|
|
||||||
/// OutputVector{add});
|
|
||||||
/// \endcode
|
|
||||||
Label(const element::Type& type,
|
|
||||||
const PartialShape& s,
|
|
||||||
const ValuePredicate pred,
|
|
||||||
const OutputVector& wrapped_values)
|
|
||||||
: Pattern(OutputVector{wrap_values(wrapped_values)}, pred) {
|
|
||||||
set_output_type(0, type, s);
|
|
||||||
}
|
|
||||||
|
|
||||||
explicit Label(const element::Type& type = element::dynamic, const PartialShape& s = PartialShape::dynamic())
|
|
||||||
: Label(
|
|
||||||
type,
|
|
||||||
s,
|
|
||||||
[](const Output<Node>&) {
|
|
||||||
return true;
|
|
||||||
},
|
|
||||||
OutputVector()) {}
|
|
||||||
|
|
||||||
Label(const element::Type& type, const PartialShape& s, ValuePredicate pred)
|
|
||||||
: Label(type, s, pred, OutputVector{}) {}
|
|
||||||
|
|
||||||
Label(const element::Type& type, const PartialShape& s, NodePredicate pred)
|
|
||||||
: Label(type, s, as_value_predicate(pred), OutputVector{}) {}
|
|
||||||
|
|
||||||
Label(const element::Type& type, const PartialShape& s, const NodePredicate pred, const NodeVector& wrapped_values)
|
|
||||||
: Label(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
|
|
||||||
|
|
||||||
/// \brief creates a Label node containing a sub-pattern described by the type and
|
|
||||||
/// shape of \sa node.
|
|
||||||
///
|
|
||||||
/// this Label node can be bound only to the nodes in the input graph
|
|
||||||
/// that match the pattern specified by \sa wrapped_values
|
|
||||||
/// Example:
|
|
||||||
/// \code{.cpp}
|
|
||||||
/// auto add = a + b; // a and b are op::Parameter in this example
|
|
||||||
/// auto label = std::make_shared<pattern::op::Label>(add,
|
|
||||||
/// nullptr,
|
|
||||||
/// OutputVector{add});
|
|
||||||
/// \endcode
|
|
||||||
Label(const Output<Node>& value, const ValuePredicate pred, const OutputVector& wrapped_values)
|
|
||||||
: Label(value.get_element_type(), value.get_partial_shape(), pred, wrapped_values) {}
|
|
||||||
Label(const Output<Node>& value, const ValuePredicate pred)
|
|
||||||
: Label(value.get_element_type(), value.get_partial_shape(), pred, OutputVector{}) {}
|
|
||||||
|
|
||||||
Label(const Output<Node>& value, const NodePredicate pred)
|
|
||||||
: Label(value.get_element_type(), value.get_partial_shape(), as_value_predicate(pred), OutputVector{}) {}
|
|
||||||
Label(const Output<Node>& value)
|
|
||||||
: Label(
|
|
||||||
value.get_element_type(),
|
|
||||||
value.get_partial_shape(),
|
|
||||||
[](const Output<Node>&) {
|
|
||||||
return true;
|
|
||||||
},
|
|
||||||
OutputVector{}) {}
|
|
||||||
Label(const Output<Node>& node, const NodePredicate pred, const NodeVector& wrapped_values)
|
|
||||||
: Label(node.get_element_type(),
|
|
||||||
node.get_partial_shape(),
|
|
||||||
as_value_predicate(pred),
|
|
||||||
as_output_vector(wrapped_values)) {}
|
|
||||||
|
|
||||||
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
static Output<Node> wrap_values(const OutputVector& wrapped_values);
|
|
||||||
};
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
|
||||||
NGRAPH_API
|
using ov::pass::pattern::any_input;
|
||||||
std::shared_ptr<Node> any_input();
|
|
||||||
|
|
||||||
NGRAPH_API
|
|
||||||
std::shared_ptr<Node> any_input(const pattern::op::ValuePredicate& pred);
|
|
||||||
} // namespace pattern
|
} // namespace pattern
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -6,25 +6,12 @@
|
|||||||
|
|
||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
#include "ngraph/pattern/op/pattern.hpp"
|
#include "ngraph/pattern/op/pattern.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/or.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pattern {
|
namespace pattern {
|
||||||
namespace op {
|
namespace op {
|
||||||
/// A submatch on the graph value is performed on each input to the Or; the match
|
using ov::pass::pattern::op::Or;
|
||||||
/// succeeds on the first match. Otherwise the match fails.
|
|
||||||
class NGRAPH_API Or : public Pattern {
|
|
||||||
public:
|
|
||||||
static constexpr NodeTypeInfo type_info{"patternOr", 0};
|
|
||||||
const NodeTypeInfo& get_type_info() const override;
|
|
||||||
/// \brief creates an Or node matching one of several sub-patterns in order. Does
|
|
||||||
/// not add node to match list.
|
|
||||||
/// \param patterns The patterns to try for matching
|
|
||||||
Or(const OutputVector& patterns) : Pattern(patterns) {}
|
|
||||||
|
|
||||||
bool match_value(pattern::Matcher* matcher,
|
|
||||||
const Output<Node>& pattern_value,
|
|
||||||
const Output<Node>& graph_value) override;
|
|
||||||
};
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace pattern
|
} // namespace pattern
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -7,8 +7,10 @@
|
|||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
namespace pattern {
|
namespace pattern {
|
||||||
namespace op {
|
namespace op {
|
||||||
class Label;
|
class Label;
|
||||||
@ -16,79 +18,42 @@ class Label;
|
|||||||
|
|
||||||
class Matcher;
|
class Matcher;
|
||||||
class MatchState;
|
class MatchState;
|
||||||
|
} // namespace pattern
|
||||||
using RPatternValueMap = std::map<std::shared_ptr<Node>, OutputVector>;
|
} // namespace pass
|
||||||
using PatternValueMap = std::map<std::shared_ptr<Node>, Output<Node>>;
|
} // namespace ov
|
||||||
using PatternValueMaps = std::vector<PatternValueMap>;
|
namespace ngraph {
|
||||||
|
namespace pattern {
|
||||||
using PatternMap = std::map<std::shared_ptr<Node>, std::shared_ptr<Node>>;
|
namespace op {
|
||||||
|
using ov::pass::pattern::op::Label;
|
||||||
PatternMap as_pattern_map(const PatternValueMap& pattern_value_map);
|
|
||||||
PatternValueMap as_pattern_value_map(const PatternMap& pattern_map);
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
std::function<bool(std::shared_ptr<Node>)> has_class() {
|
|
||||||
auto pred = [](std::shared_ptr<Node> node) -> bool {
|
|
||||||
return ov::is_type<T>(node);
|
|
||||||
};
|
|
||||||
|
|
||||||
return pred;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
NGRAPH_API
|
using ov::pass::pattern::Matcher;
|
||||||
std::function<bool(Output<Node>)> consumers_count(size_t n);
|
using ov::pass::pattern::MatcherState;
|
||||||
|
|
||||||
NGRAPH_API
|
using ov::pass::pattern::PatternValueMap;
|
||||||
std::function<bool(Output<Node>)> has_static_dim(size_t pos);
|
using ov::pass::pattern::PatternValueMaps;
|
||||||
|
using ov::pass::pattern::RPatternValueMap;
|
||||||
|
|
||||||
NGRAPH_API
|
using ov::pass::pattern::PatternMap;
|
||||||
std::function<bool(Output<Node>)> has_static_dims(const std::vector<size_t>& dims);
|
|
||||||
|
|
||||||
NGRAPH_API
|
using ov::pass::pattern::as_pattern_map;
|
||||||
std::function<bool(Output<Node>)> has_static_shape();
|
using ov::pass::pattern::as_pattern_value_map;
|
||||||
|
using ov::pass::pattern::consumers_count;
|
||||||
NGRAPH_API
|
using ov::pass::pattern::has_class;
|
||||||
std::function<bool(Output<Node>)> has_static_rank();
|
using ov::pass::pattern::has_static_dim;
|
||||||
|
using ov::pass::pattern::has_static_dims;
|
||||||
NGRAPH_API
|
using ov::pass::pattern::has_static_rank;
|
||||||
std::function<bool(Output<Node>)> rank_equals(const Dimension& expected_rank);
|
using ov::pass::pattern::has_static_shape;
|
||||||
|
using ov::pass::pattern::rank_equals;
|
||||||
NGRAPH_API
|
using ov::pass::pattern::type_matches;
|
||||||
std::function<bool(Output<Node>)> type_matches(const element::Type& type);
|
using ov::pass::pattern::type_matches_any;
|
||||||
|
|
||||||
NGRAPH_API
|
|
||||||
std::function<bool(Output<Node>)> type_matches_any(const std::vector<element::Type>& types);
|
|
||||||
|
|
||||||
namespace op {
|
namespace op {
|
||||||
using NodePredicate = std::function<bool(std::shared_ptr<Node>)>;
|
using ov::pass::pattern::op::NodePredicate;
|
||||||
using ValuePredicate = std::function<bool(const Output<Node>& value)>;
|
using ov::pass::pattern::op::ValuePredicate;
|
||||||
|
|
||||||
NGRAPH_API
|
using ov::pass::pattern::op::as_value_predicate;
|
||||||
ValuePredicate as_value_predicate(NodePredicate pred);
|
using ov::pass::pattern::op::Pattern;
|
||||||
|
|
||||||
class NGRAPH_API Pattern : public Node {
|
|
||||||
public:
|
|
||||||
/// \brief \p a base class for \sa Skip and \sa Label
|
|
||||||
///
|
|
||||||
Pattern(const OutputVector& patterns, ValuePredicate pred) : Node(patterns), m_predicate(pred) {
|
|
||||||
if (!m_predicate) {
|
|
||||||
m_predicate = [](const Output<Node>&) {
|
|
||||||
return true;
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Pattern(const OutputVector& patterns) : Pattern(patterns, nullptr) {}
|
|
||||||
|
|
||||||
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& /* new_args */) const override {
|
|
||||||
throw ngraph_error("Uncopyable");
|
|
||||||
}
|
|
||||||
|
|
||||||
ValuePredicate get_predicate() const;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
ValuePredicate m_predicate;
|
|
||||||
};
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace pattern
|
} // namespace pattern
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -6,37 +6,12 @@
|
|||||||
|
|
||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
#include "ngraph/pattern/op/pattern.hpp"
|
#include "ngraph/pattern/op/pattern.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/skip.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pattern {
|
namespace pattern {
|
||||||
namespace op {
|
namespace op {
|
||||||
/// The graph value is added to the matched value list. If the predicate is true, the
|
using ov::pass::pattern::op::Skip;
|
||||||
/// match succeeds if the arguments match; if the predicate is false, the match succeeds
|
|
||||||
/// if the pattern input matches the graph value.
|
|
||||||
class NGRAPH_API Skip : public Pattern {
|
|
||||||
public:
|
|
||||||
static constexpr NodeTypeInfo type_info{"patternSkip", 0};
|
|
||||||
const NodeTypeInfo& get_type_info() const override;
|
|
||||||
Skip(const Output<Node>& arg, ValuePredicate pred) : Pattern({arg}, pred) {
|
|
||||||
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
|
|
||||||
}
|
|
||||||
|
|
||||||
Skip(const Output<Node>& arg, NodePredicate pred = nullptr) : Pattern({arg}, as_value_predicate(pred)) {
|
|
||||||
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
|
|
||||||
}
|
|
||||||
|
|
||||||
Skip(const OutputVector& args, ValuePredicate pred) : Pattern(args, pred) {
|
|
||||||
set_output_type(0, args.at(0).get_element_type(), args.at(0).get_partial_shape());
|
|
||||||
}
|
|
||||||
|
|
||||||
Skip(const OutputVector& args, NodePredicate pred = nullptr) : Pattern(args, as_value_predicate(pred)) {
|
|
||||||
set_output_type(0, args.at(0).get_element_type(), args.at(0).get_partial_shape());
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual bool match_value(pattern::Matcher* matcher,
|
|
||||||
const Output<Node>& pattern_value,
|
|
||||||
const Output<Node>& graph_value) override;
|
|
||||||
};
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace pattern
|
} // namespace pattern
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -6,21 +6,12 @@
|
|||||||
|
|
||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
#include "ngraph/pattern/op/pattern.hpp"
|
#include "ngraph/pattern/op/pattern.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/true.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pattern {
|
namespace pattern {
|
||||||
namespace op {
|
namespace op {
|
||||||
/// \brief The match always succeeds.
|
using ov::pass::pattern::op::True;
|
||||||
class NGRAPH_API True : public Pattern {
|
|
||||||
public:
|
|
||||||
static constexpr NodeTypeInfo type_info{"patternTrue", 0};
|
|
||||||
const NodeTypeInfo& get_type_info() const override;
|
|
||||||
/// \brief Always matches, does not add node to match list.
|
|
||||||
True() : Pattern(OutputVector{}) {}
|
|
||||||
bool match_value(pattern::Matcher* matcher,
|
|
||||||
const Output<Node>& pattern_value,
|
|
||||||
const Output<Node>& graph_value) override;
|
|
||||||
};
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace pattern
|
} // namespace pattern
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -6,68 +6,14 @@
|
|||||||
|
|
||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
#include "ngraph/pattern/op/pattern.hpp"
|
#include "ngraph/pattern/op/pattern.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pattern {
|
namespace pattern {
|
||||||
namespace op {
|
namespace op {
|
||||||
class NGRAPH_API WrapType : public Pattern {
|
using ov::pass::pattern::op::WrapType;
|
||||||
public:
|
|
||||||
static constexpr NodeTypeInfo type_info{"patternAnyType", 0};
|
|
||||||
const NodeTypeInfo& get_type_info() const override;
|
|
||||||
|
|
||||||
explicit WrapType(
|
|
||||||
NodeTypeInfo wrapped_type,
|
|
||||||
const ValuePredicate& pred =
|
|
||||||
[](const Output<Node>& output) {
|
|
||||||
return true;
|
|
||||||
},
|
|
||||||
const OutputVector& input_values = {})
|
|
||||||
: Pattern(input_values, pred),
|
|
||||||
m_wrapped_types({wrapped_type}) {
|
|
||||||
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
|
|
||||||
}
|
|
||||||
|
|
||||||
explicit WrapType(
|
|
||||||
std::vector<NodeTypeInfo> wrapped_types,
|
|
||||||
const ValuePredicate& pred =
|
|
||||||
[](const Output<Node>& output) {
|
|
||||||
return true;
|
|
||||||
},
|
|
||||||
const OutputVector& input_values = {})
|
|
||||||
: Pattern(input_values, pred),
|
|
||||||
m_wrapped_types(std::move(wrapped_types)) {
|
|
||||||
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
|
|
||||||
}
|
|
||||||
|
|
||||||
bool match_value(pattern::Matcher* matcher,
|
|
||||||
const Output<Node>& pattern_value,
|
|
||||||
const Output<Node>& graph_value) override;
|
|
||||||
|
|
||||||
NodeTypeInfo get_wrapped_type() const;
|
|
||||||
|
|
||||||
const std::vector<NodeTypeInfo>& get_wrapped_types() const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::vector<NodeTypeInfo> m_wrapped_types;
|
|
||||||
};
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
|
||||||
template <class... Args>
|
using ov::pass::pattern::wrap_type;
|
||||||
std::shared_ptr<Node> wrap_type(const OutputVector& inputs, const pattern::op::ValuePredicate& pred) {
|
|
||||||
std::vector<DiscreteTypeInfo> info{Args::type_info...};
|
|
||||||
return std::make_shared<op::WrapType>(info, pred, inputs);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class... Args>
|
|
||||||
std::shared_ptr<Node> wrap_type(const OutputVector& inputs = {}) {
|
|
||||||
return wrap_type<Args...>(inputs, [](const Output<Node>& output) {
|
|
||||||
return true;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class... Args>
|
|
||||||
std::shared_ptr<Node> wrap_type(const pattern::op::ValuePredicate& pred) {
|
|
||||||
return wrap_type<Args...>({}, pred);
|
|
||||||
}
|
|
||||||
} // namespace pattern
|
} // namespace pattern
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -50,12 +50,14 @@ class Result;
|
|||||||
} // namespace v0
|
} // namespace v0
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
|
||||||
namespace pattern {
|
|
||||||
class Matcher;
|
|
||||||
} // namespace pattern
|
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
|
||||||
namespace ov {
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
namespace pattern {
|
||||||
|
class Matcher;
|
||||||
|
} // namespace pattern
|
||||||
|
} // namespace pass
|
||||||
using HostTensor = ngraph::runtime::HostTensor;
|
using HostTensor = ngraph::runtime::HostTensor;
|
||||||
using HostTensorPtr = std::shared_ptr<HostTensor>;
|
using HostTensorPtr = std::shared_ptr<HostTensor>;
|
||||||
using HostTensorVector = std::vector<HostTensorPtr>;
|
using HostTensorVector = std::vector<HostTensorPtr>;
|
||||||
@ -487,11 +489,11 @@ public:
|
|||||||
}
|
}
|
||||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||||
|
|
||||||
virtual bool match_value(ngraph::pattern::Matcher* matcher,
|
virtual bool match_value(ov::pass::pattern::Matcher* matcher,
|
||||||
const Output<Node>& pattern_value,
|
const Output<Node>& pattern_value,
|
||||||
const Output<Node>& graph_value);
|
const Output<Node>& graph_value);
|
||||||
|
|
||||||
virtual bool match_node(ngraph::pattern::Matcher* matcher, const Output<Node>& graph_value);
|
virtual bool match_node(ov::pass::pattern::Matcher* matcher, const Output<Node>& graph_value);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
descriptor::Input& get_input_descriptor(size_t position);
|
descriptor::Input& get_input_descriptor(size_t position);
|
||||||
|
28
ngraph/core/include/openvino/pass/constant_folding.hpp
Normal file
28
ngraph/core/include/openvino/pass/constant_folding.hpp
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/pass/pass.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
/**
|
||||||
|
* @brief Constant folding iterates over the function and tries to evaluate nodes
|
||||||
|
* with constant inputs. Such nodes are then replaced with new Constants containing
|
||||||
|
* the result of a folded operation.
|
||||||
|
*/
|
||||||
|
class OPENVINO_API ConstantFolding : public FunctionPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI_DECLARATION;
|
||||||
|
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void copy_runtime_info_to_target_inputs(const std::shared_ptr<Node>& node, const Output<Node>& replacement);
|
||||||
|
/// \brief Folds pre-calculated output tensor values to constants in case lower and
|
||||||
|
/// upper estimations are equal. Traverses graph backwards starting from the results.
|
||||||
|
bool pre_calculated_values_folding(const std::shared_ptr<ngraph::Function>& f);
|
||||||
|
};
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
17
ngraph/core/include/openvino/pass/convert_fp32_to_fp16.hpp
Normal file
17
ngraph/core/include/openvino/pass/convert_fp32_to_fp16.hpp
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/pass/graph_rewrite.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
class OPENVINO_API ConvertFP32ToFP16 : public FunctionPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI_DECLARATION;
|
||||||
|
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
|
||||||
|
};
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
249
ngraph/core/include/openvino/pass/graph_rewrite.hpp
Normal file
249
ngraph/core/include/openvino/pass/graph_rewrite.hpp
Normal file
@ -0,0 +1,249 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <memory>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
|
#include "openvino/pass/pass.hpp"
|
||||||
|
#include "openvino/pass/pattern/matcher.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
using matcher_pass_callback = std::function<bool(pass::pattern::Matcher& m)>;
|
||||||
|
using graph_rewrite_callback = std::function<bool(pass::pattern::Matcher& m)>;
|
||||||
|
using recurrent_graph_rewrite_callback = std::function<bool(pass::pattern::RecurrentMatcher& m)>;
|
||||||
|
using handler_callback = std::function<bool(const std::shared_ptr<Node>& node)>;
|
||||||
|
namespace pass {
|
||||||
|
/// \brief MatcherPass is a basic block for pattern based transformations. It describes
|
||||||
|
/// pattern and
|
||||||
|
/// action that is applied if pattern is matched.
|
||||||
|
///
|
||||||
|
/// MatcherPass consists of Matcher and matcher_pass_callback that needs to be implemented
|
||||||
|
/// and
|
||||||
|
/// finally registered by using \sa register_matcher. MatcherPass can be executed on node
|
||||||
|
/// within
|
||||||
|
/// \sa apply method. To run matcher pass on Function use GraphRewrite.
|
||||||
|
/// In addition MatcherPass provides a way for adding new operations into GraphRewrite
|
||||||
|
/// execution
|
||||||
|
/// queue. That means that operations that were created inside transformation callback can
|
||||||
|
/// be added
|
||||||
|
/// for matching. To register node use \sa register_new_node method. GraphRewrite
|
||||||
|
/// automatically
|
||||||
|
/// takes registered nodes and put them to execution queue. If multiple nodes were register
|
||||||
|
/// make
|
||||||
|
/// sure that they were registered in topological order.
|
||||||
|
/// Note: when implementing pattern for Matcher make sure that root node is an operation
|
||||||
|
/// from opset
|
||||||
|
/// or has ov::pass::pattern::op::WrapType. That will help GraphRewrite to execute matcher
|
||||||
|
/// passes more
|
||||||
|
/// efficient.
|
||||||
|
|
||||||
|
class OPENVINO_API MatcherPass : public PassBase {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI_DECLARATION;
|
||||||
|
|
||||||
|
MatcherPass() = default;
|
||||||
|
|
||||||
|
MatcherPass(const MatcherPass&) = delete;
|
||||||
|
MatcherPass& operator=(const MatcherPass&) = delete;
|
||||||
|
|
||||||
|
explicit MatcherPass(const std::string& name,
|
||||||
|
const std::shared_ptr<pattern::Matcher>& m,
|
||||||
|
const handler_callback& handler,
|
||||||
|
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE)
|
||||||
|
: PassBase(),
|
||||||
|
m_handler(handler),
|
||||||
|
m_matcher(m) {
|
||||||
|
set_name(name);
|
||||||
|
set_property(property, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool apply(std::shared_ptr<ov::Node> node);
|
||||||
|
|
||||||
|
template <typename T, class... Args>
|
||||||
|
std::shared_ptr<T> register_new_node(Args&&... args) {
|
||||||
|
auto node = std::make_shared<T>(std::forward<Args>(args)...);
|
||||||
|
m_new_nodes.push_back(node);
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::shared_ptr<T> register_new_node(const std::shared_ptr<T>& node) {
|
||||||
|
m_new_nodes.push_back(node);
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<std::shared_ptr<ov::Node>>& get_new_nodes() {
|
||||||
|
return m_new_nodes;
|
||||||
|
}
|
||||||
|
void clear_new_nodes() {
|
||||||
|
m_new_nodes.clear();
|
||||||
|
}
|
||||||
|
std::shared_ptr<pattern::Matcher> get_matcher() {
|
||||||
|
return m_matcher;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void register_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||||
|
const graph_rewrite_callback& callback,
|
||||||
|
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE);
|
||||||
|
|
||||||
|
private:
|
||||||
|
handler_callback m_handler;
|
||||||
|
std::shared_ptr<pattern::Matcher> m_matcher;
|
||||||
|
std::vector<std::shared_ptr<ov::Node>> m_new_nodes;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function
|
||||||
|
/// in
|
||||||
|
/// efficient way
|
||||||
|
///
|
||||||
|
/// Graph rewrite pass is used for matcher passes execution on Function.
|
||||||
|
/// To register MatcherPass use \sa add_matcher<T>(args) method where T is a MatcherPass
|
||||||
|
/// class.
|
||||||
|
/// As a default algorithm graph rewrite pass traverse Function in topological order and
|
||||||
|
/// applies
|
||||||
|
/// registered matcher passes for each node. But if all registered matcher passes have type
|
||||||
|
/// based
|
||||||
|
/// root node in Matcher pattern then efficient mechanism is used to execute them.
|
||||||
|
/// Matcher pattern root is type based if it's operation from opset or
|
||||||
|
/// pattern::op::WrapType.
|
||||||
|
/// Note: when implementing pattern for Matcher make sure that root node is an operation
|
||||||
|
/// from opset
|
||||||
|
/// or has ov::pattern::op::WrapType. That will help GraphRewrite to execute matcher
|
||||||
|
/// passes more
|
||||||
|
/// efficient.
|
||||||
|
|
||||||
|
class OPENVINO_API GraphRewrite : public FunctionPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI_DECLARATION;
|
||||||
|
|
||||||
|
GraphRewrite() = default;
|
||||||
|
|
||||||
|
explicit GraphRewrite(const std::shared_ptr<MatcherPass>& pass) : FunctionPass() {
|
||||||
|
m_matchers.push_back(pass);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Register given transformation class type to GraphRewrite execution list
|
||||||
|
/// All registered transformations will be executed in a single graph traversal.
|
||||||
|
/// Example below show the basic usage of pass::GraphRewrite
|
||||||
|
///
|
||||||
|
/// pass::Manager manager;
|
||||||
|
/// auto anchor = manager.register_pass<GraphRewrite>();
|
||||||
|
/// anchor->add_matcher<MatcherPassA>();
|
||||||
|
/// anchor->add_matcher<MatcherPassB>();
|
||||||
|
/// anchor->set_name("CommonMatchers");
|
||||||
|
/// manager.run_passes(f);
|
||||||
|
///
|
||||||
|
/// For some purposes transformation can be registered and disabled by default.
|
||||||
|
///
|
||||||
|
/// anchor->add_matcher<MatcherPassB, false>();
|
||||||
|
///
|
||||||
|
/// \return shared_ptr to the transformation instance
|
||||||
|
template <typename T,
|
||||||
|
bool Enabled = true,
|
||||||
|
class... Args,
|
||||||
|
typename std::enable_if<std::is_base_of<pass::MatcherPass, T>::value, bool>::type = true>
|
||||||
|
std::shared_ptr<T> add_matcher(Args&&... args) {
|
||||||
|
static_assert(std::is_base_of<pass::MatcherPass, T>::value, "pass not derived from MatcherPass");
|
||||||
|
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||||
|
auto pass_config = get_pass_config();
|
||||||
|
pass->set_pass_config(pass_config);
|
||||||
|
if (!Enabled && !pass_config->is_enabled<T>()) {
|
||||||
|
pass_config->disable<T>();
|
||||||
|
}
|
||||||
|
m_matchers.push_back(pass);
|
||||||
|
return pass;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Register passes from GraphRewrite class that contains sequence of matcher
|
||||||
|
/// passes registered in its ctor.
|
||||||
|
/// For example:
|
||||||
|
///
|
||||||
|
/// class ov::pass::LinFusions: public ov::pass::GraphRewrite {
|
||||||
|
/// public:
|
||||||
|
/// OPENVINO_RTTI_DECLARATION;
|
||||||
|
/// Fusions() {
|
||||||
|
/// add_matcher<ov::pass::AddFusion>();
|
||||||
|
/// add_matcher<ov::pass::MulFusion>();
|
||||||
|
/// }
|
||||||
|
/// };
|
||||||
|
///
|
||||||
|
/// pass::Manager manager;
|
||||||
|
/// auto anchor = manager.register_pass<GraphRewrite>();
|
||||||
|
/// anchor->add_matcher<LinFusions>();
|
||||||
|
/// anchor->add_matcher<OtherFusions>();
|
||||||
|
/// anchor->set_name("CommonFusions");
|
||||||
|
/// manager.run_passes(f);
|
||||||
|
///
|
||||||
|
/// In this case all matcher passes from LinFusions pass will be united with other
|
||||||
|
/// registered matchers.
|
||||||
|
template <typename T,
|
||||||
|
class... Args,
|
||||||
|
typename std::enable_if<std::is_base_of<pass::GraphRewrite, T>::value, bool>::type = true>
|
||||||
|
void add_matcher(Args&&... args) {
|
||||||
|
static_assert(std::is_base_of<pass::GraphRewrite, T>::value, "pass not derived from GraphRewrite");
|
||||||
|
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||||
|
auto pass_config = get_pass_config();
|
||||||
|
|
||||||
|
for (auto& matcher : pass->m_matchers) {
|
||||||
|
pass->set_pass_config(pass_config);
|
||||||
|
m_matchers.push_back(matcher);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OPENVINO_DEPRECATED("Use MatcherPass instead")
|
||||||
|
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||||
|
const graph_rewrite_callback& callback,
|
||||||
|
const PassPropertyMask& property);
|
||||||
|
|
||||||
|
OPENVINO_DEPRECATED("Use MatcherPass instead")
|
||||||
|
void add_matcher(const std::shared_ptr<pattern::Matcher>& m, const ov::graph_rewrite_callback& callback);
|
||||||
|
|
||||||
|
bool run_on_function(std::shared_ptr<ov::Function> f) override;
|
||||||
|
|
||||||
|
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
bool apply_matcher_passes(std::shared_ptr<Function> f, std::deque<std::weak_ptr<Node>> nodes_to_run);
|
||||||
|
|
||||||
|
bool m_enable_shape_inference = false;
|
||||||
|
|
||||||
|
std::vector<std::shared_ptr<ov::pass::MatcherPass>> m_matchers;
|
||||||
|
};
|
||||||
|
|
||||||
|
class OPENVINO_API BackwardGraphRewrite : public GraphRewrite {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI_DECLARATION;
|
||||||
|
|
||||||
|
BackwardGraphRewrite() = default;
|
||||||
|
|
||||||
|
explicit BackwardGraphRewrite(const std::shared_ptr<MatcherPass>& pass) : GraphRewrite(pass) {}
|
||||||
|
|
||||||
|
bool run_on_function(std::shared_ptr<ov::Function> f) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class OPENVINO_API RecurrentGraphRewrite : public FunctionPass {
|
||||||
|
public:
|
||||||
|
RecurrentGraphRewrite(size_t num_iters = 10) : FunctionPass(), m_num_iters(num_iters) {}
|
||||||
|
|
||||||
|
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
||||||
|
const ov::recurrent_graph_rewrite_callback& callback,
|
||||||
|
const PassPropertyMask& property);
|
||||||
|
|
||||||
|
// TODO: This interface may deprecate after all passes are refactored.
|
||||||
|
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
||||||
|
const ov::recurrent_graph_rewrite_callback& callback);
|
||||||
|
|
||||||
|
bool run_on_function(std::shared_ptr<ov::Function> f) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t m_num_iters;
|
||||||
|
|
||||||
|
std::vector<std::shared_ptr<ov::pass::MatcherPass>> m_matchers;
|
||||||
|
};
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
48
ngraph/core/include/openvino/pass/low_latency.hpp
Normal file
48
ngraph/core/include/openvino/pass/low_latency.hpp
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "openvino/pass/pass.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
/**
|
||||||
|
* @brief The transformation finds all TensorIterator/Loop layers in the network,
|
||||||
|
* processes all back edges that describe a connection between Result and Parameter
|
||||||
|
* of the TensorIterator/Loop bodies,and inserts ReadValue and Assign layers at the
|
||||||
|
* input and output corresponding to this back edge.
|
||||||
|
* Supported platforms: CPU, GNA.
|
||||||
|
*
|
||||||
|
* The example below describes the changes made by the transformation
|
||||||
|
* [] - TensorIterator body
|
||||||
|
* () - new layer
|
||||||
|
* BE - back-edge
|
||||||
|
*
|
||||||
|
* before applying the transformation:
|
||||||
|
* -> input1[BE_1 -> Parameter -> Layers ... -> Result -> BE_1 ]output1->
|
||||||
|
*
|
||||||
|
* after applying the transformation:
|
||||||
|
* ->(ReadValue)-> input1[BE_1 ->Parameter->Layers ...->Result->BE_1]output1 ->(Assign)
|
||||||
|
* \
|
||||||
|
* ->...
|
||||||
|
* After applying the transformation, the resulting network can be inferred
|
||||||
|
* step by step, the states will store between inferences.
|
||||||
|
*/
|
||||||
|
class OPENVINO_API LowLatency2 : public FunctionPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI_DECLARATION;
|
||||||
|
|
||||||
|
explicit LowLatency2(bool use_const_initializer = true) : m_use_const_initializer(use_const_initializer) {}
|
||||||
|
|
||||||
|
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool m_use_const_initializer;
|
||||||
|
};
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
116
ngraph/core/include/openvino/pass/manager.hpp
Normal file
116
ngraph/core/include/openvino/pass/manager.hpp
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <list>
|
||||||
|
#include <memory>
|
||||||
|
#include <typeinfo>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "openvino/pass/pass.hpp"
|
||||||
|
#include "openvino/pass/validate.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
class OPENVINO_API Manager {
|
||||||
|
public:
|
||||||
|
Manager();
|
||||||
|
~Manager();
|
||||||
|
|
||||||
|
//// \brief Construct Manager with shared PassConfig instance
|
||||||
|
explicit Manager(std::shared_ptr<PassConfig> pass_config);
|
||||||
|
|
||||||
|
/// \brief Register given transformation class type to execution list
|
||||||
|
/// Example below show the basic usage of pass::Manager
|
||||||
|
///
|
||||||
|
/// pass::Manager manager;
|
||||||
|
/// manager.register_pass<MyTransformation>(/*transformation constructor ars*/);
|
||||||
|
/// manager.run_passes(f);
|
||||||
|
///
|
||||||
|
/// For some purposes transformation can be registered and disabled by default.
|
||||||
|
///
|
||||||
|
/// manager.register_pass<MyTransformation, false>();
|
||||||
|
///
|
||||||
|
/// \return shared_ptr to the transformation instance
|
||||||
|
template <typename T, bool Enable = true, class... Args>
|
||||||
|
std::shared_ptr<T> register_pass(Args&&... args) {
|
||||||
|
auto rc = push_pass<T>(std::forward<Args>(args)...);
|
||||||
|
rc->set_pass_config(m_pass_config);
|
||||||
|
if (m_per_pass_validation) {
|
||||||
|
push_pass<Validate>();
|
||||||
|
}
|
||||||
|
if (!Enable && !m_pass_config->is_enabled<T>()) {
|
||||||
|
m_pass_config->disable<T>();
|
||||||
|
}
|
||||||
|
return rc;
|
||||||
|
}
|
||||||
|
|
||||||
|
void run_passes(std::shared_ptr<Function>);
|
||||||
|
|
||||||
|
void set_pass_visualization(bool new_state) {
|
||||||
|
m_visualize = new_state;
|
||||||
|
}
|
||||||
|
/// \brief Set flag to enable/disable running Validate pass after executing
|
||||||
|
/// each registered pass
|
||||||
|
/// \param new_state Value "true" enables Validate pass run; "false", otherwise
|
||||||
|
void set_per_pass_validation(bool new_state) {
|
||||||
|
m_per_pass_validation = new_state;
|
||||||
|
}
|
||||||
|
/// \brief Callback is a lambda function that can be used by registered transformations.
|
||||||
|
/// The main purpose of this callback is to provide a way for plugins to disable/enable
|
||||||
|
/// transformations based on some conditions. In some cases plugins may want not to
|
||||||
|
/// execute some
|
||||||
|
/// transformations.
|
||||||
|
/// For example plugin can disable unpleasant decompositions because of performance
|
||||||
|
/// reasons for
|
||||||
|
/// some cases.
|
||||||
|
/// Callback example:
|
||||||
|
/// auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
||||||
|
/// return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) !=
|
||||||
|
/// nullptr;
|
||||||
|
/// };
|
||||||
|
/// This callback returns true in case of DepthToSpace operation. So when execution
|
||||||
|
/// DepthToSpace
|
||||||
|
/// decomposition pass will check is this decomposition needed or plugin can execute
|
||||||
|
/// this
|
||||||
|
/// operation directly. And of course on transformation side we need to have a response
|
||||||
|
/// for this
|
||||||
|
/// callback.
|
||||||
|
/// if (transformation_callback(batch_to_space)) {
|
||||||
|
/// return false;
|
||||||
|
/// }
|
||||||
|
/// \param callback lamda function that returns true in case if node is supported by
|
||||||
|
/// plugin and
|
||||||
|
/// transformation is not needed
|
||||||
|
OPENVINO_DEPRECATED("Please use get_pass_config() to configure transformation pipeline")
|
||||||
|
void set_callback(const param_callback& callback) {
|
||||||
|
m_pass_config->set_callback(callback);
|
||||||
|
}
|
||||||
|
/// \return PassConfig shared object. This object is used for transformations pipeline
|
||||||
|
/// configuration.
|
||||||
|
/// This object allows to disable/enable transformations execution, set callback to
|
||||||
|
/// particular
|
||||||
|
/// transformation. For mo details see PassConfig class.
|
||||||
|
std::shared_ptr<PassConfig> get_pass_config() {
|
||||||
|
return m_pass_config;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
template <typename T, class... Args>
|
||||||
|
std::shared_ptr<T> push_pass(Args&&... args) {
|
||||||
|
static_assert(std::is_base_of<pass::PassBase, T>::value, "pass not derived from pass base");
|
||||||
|
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||||
|
auto pass_base = std::static_pointer_cast<PassBase>(pass);
|
||||||
|
m_pass_list.push_back(pass_base);
|
||||||
|
return pass;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<PassConfig> m_pass_config;
|
||||||
|
std::vector<std::shared_ptr<PassBase>> m_pass_list;
|
||||||
|
bool m_visualize = false;
|
||||||
|
bool m_per_pass_validation = true;
|
||||||
|
};
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
110
ngraph/core/include/openvino/pass/pass.hpp
Normal file
110
ngraph/core/include/openvino/pass/pass.hpp
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <list>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "ngraph/util.hpp"
|
||||||
|
#include "openvino/core/core_visibility.hpp"
|
||||||
|
#include "openvino/core/deprecated.hpp"
|
||||||
|
#include "openvino/core/function.hpp"
|
||||||
|
#include "openvino/core/node.hpp"
|
||||||
|
#include "openvino/pass/pass_config.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
enum class PassProperty : uint32_t {
|
||||||
|
// Pass requires node shapes to be static
|
||||||
|
REQUIRE_STATIC_SHAPE = 0x1,
|
||||||
|
// Pass transformation will change the function's dynamic state
|
||||||
|
CHANGE_DYNAMIC_STATE = 1 << 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
using PassPropertyMask = ngraph::EnumMask<PassProperty>;
|
||||||
|
|
||||||
|
class OPENVINO_API PassBase {
|
||||||
|
friend class Manager;
|
||||||
|
|
||||||
|
public:
|
||||||
|
PassBase();
|
||||||
|
virtual ~PassBase() = default;
|
||||||
|
/// Check if this pass has all the pass properties.
|
||||||
|
bool get_property(const PassPropertyMask& prop_mask) const;
|
||||||
|
|
||||||
|
void set_name(const std::string& name) {
|
||||||
|
m_name = name;
|
||||||
|
}
|
||||||
|
std::string get_name() const;
|
||||||
|
|
||||||
|
/// \brief Set callback for particular transformation type.
|
||||||
|
/// This method set global callback. For more details see PassConfig class
|
||||||
|
/// documentation.
|
||||||
|
/// \param callback lambda function that takes node and returns bool
|
||||||
|
void set_callback(const param_callback& callback);
|
||||||
|
|
||||||
|
/// \brief Set PassConfig for particular transformation instance
|
||||||
|
/// \param pass_config is a PassConfig shared_ptr
|
||||||
|
virtual void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) {
|
||||||
|
m_pass_config = pass_config;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Allows to access PassConfig shared instance
|
||||||
|
/// \return Shared instance of PassConfig class
|
||||||
|
std::shared_ptr<PassConfig> get_pass_config() {
|
||||||
|
return m_pass_config;
|
||||||
|
}
|
||||||
|
/// \brief Applies callback for given node. By default callback returns false.
|
||||||
|
/// This method remains here only for backward compatibility and will be removed
|
||||||
|
/// after all transformations are moved to transformation_callback() method.
|
||||||
|
/// \return result of callback execution for given node
|
||||||
|
NGRAPH_DEPRECATED("Please use transformation_callback method instead")
|
||||||
|
bool m_transformation_callback(const std::shared_ptr<const Node>& node) {
|
||||||
|
return m_pass_config->get_callback(get_type_info())(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Applies callback for given node. By default callback returns false.
|
||||||
|
/// \param node which will be used inside callback
|
||||||
|
/// \return result of callback execution for given node
|
||||||
|
bool transformation_callback(const std::shared_ptr<const Node>& node) {
|
||||||
|
return m_pass_config->get_callback(get_type_info())(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
using type_info_t = DiscreteTypeInfo;
|
||||||
|
|
||||||
|
virtual const type_info_t& get_type_info() const = 0;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void set_property(const PassPropertyMask& prop, bool value);
|
||||||
|
|
||||||
|
private:
|
||||||
|
PassPropertyMask m_property;
|
||||||
|
|
||||||
|
std::string m_name;
|
||||||
|
std::shared_ptr<PassConfig> m_pass_config;
|
||||||
|
};
|
||||||
|
|
||||||
|
class OPENVINO_API FunctionPass : public PassBase {
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
~FunctionPass() override;
|
||||||
|
virtual bool run_on_function(std::shared_ptr<ngraph::Function>) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class Manager;
|
||||||
|
enum class FusionType : uint32_t {
|
||||||
|
//`DIFFERENTIABLE_FUSIONS` produce ops that support autodiff
|
||||||
|
// i.e. implement `generate_adjoints`
|
||||||
|
DIFFERENTIABLE_FUSIONS = 0x1,
|
||||||
|
REGULAR_FUSIONS = 0x2,
|
||||||
|
//`FOP_FUSIONS` produce ops in the FusedOps category that might
|
||||||
|
// not be supported by all backends
|
||||||
|
FOP_FUSIONS = 0x4,
|
||||||
|
ALL_FUSIONS = 0xFFFFFFFF
|
||||||
|
};
|
||||||
|
using FusionTypeMask = ngraph::EnumMask<FusionType>;
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
176
ngraph/core/include/openvino/pass/pass_config.hpp
Normal file
176
ngraph/core/include/openvino/pass/pass_config.hpp
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <list>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "ngraph/util.hpp"
|
||||||
|
#include "openvino/core/core_visibility.hpp"
|
||||||
|
#include "openvino/core/deprecated.hpp"
|
||||||
|
#include "openvino/core/function.hpp"
|
||||||
|
#include "openvino/core/node.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
using param_callback = std::function<bool(const std::shared_ptr<const ::ov::Node>)>;
|
||||||
|
using param_callback_map = std::map<ov::DiscreteTypeInfo, param_callback>;
|
||||||
|
|
||||||
|
/// \brief Class representing a transformations config that is used for disabling/enabling
|
||||||
|
/// transformations registered inside pass::Manager and also allows to set callback for all
|
||||||
|
/// transformations or for particular transformation.
|
||||||
|
///
|
||||||
|
/// When pass::Manager is created all passes registered inside this manager including nested
|
||||||
|
/// passes will share the same instance of PassConfig class.
|
||||||
|
/// To work with this class first you need to get shared instance of this class by calling
|
||||||
|
/// manager.get_pass_config() method. Then you will be able to disable/enable passes based
|
||||||
|
/// on transformations type_info. For example:
|
||||||
|
///
|
||||||
|
/// pass::Manager manager;
|
||||||
|
/// manager.register_pass<CommonOptimizations>();
|
||||||
|
/// auto pass_config = manager.get_pass_config();
|
||||||
|
/// pass_config->disable<ConvertGELU>(); // this will disable nested pass inside
|
||||||
|
/// // CommonOptimizations pipeline
|
||||||
|
/// manager.run_passes(f);
|
||||||
|
///
|
||||||
|
/// Sometimes it is needed to call transformation inside other transformation manually. And
|
||||||
|
/// for that case before running transformation you need manually check that this pass is
|
||||||
|
/// not disabled and then you need to set current PassConfig instance to this
|
||||||
|
/// transformation. For example:
|
||||||
|
///
|
||||||
|
/// // Inside MatcherPass callback or inside FunctionPass run_on_function() method
|
||||||
|
/// // you need to call get_pass_config() method to get shared instance of PassConfig
|
||||||
|
/// auto pass_config = get_pass_config();
|
||||||
|
///
|
||||||
|
/// // Before running nested transformation you need to check is it disabled or not
|
||||||
|
/// if (!pass_config->is_disabled<ConvertGELU>()) {
|
||||||
|
/// auto pass = ConvertGELU();
|
||||||
|
/// pass->set_pass_config(pass_config);
|
||||||
|
/// pass.apply(node);
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// Following this logic inside your transformations you will guaranty that transformations
|
||||||
|
/// will be executed in a right way.
|
||||||
|
class OPENVINO_API PassConfig {
|
||||||
|
public:
|
||||||
|
/// \brief Disable transformation by its type_info
|
||||||
|
/// \param type_info Transformation type_info
|
||||||
|
void disable(const DiscreteTypeInfo& type_info);
|
||||||
|
/// \brief Disable transformation by its class type (based on type_info)
|
||||||
|
template <typename T>
|
||||||
|
void disable() {
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||||
|
disable(T::type_info);
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Enable transformation by its type_info
|
||||||
|
/// \param type_info Transformation type_info
|
||||||
|
void enable(const DiscreteTypeInfo& type_info);
|
||||||
|
/// \brief Enable transformation by its class type (based on type_info)
|
||||||
|
template <typename T>
|
||||||
|
void enable() {
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||||
|
enable(T::type_info);
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Set callback for all kind of transformations
|
||||||
|
void set_callback(const param_callback& callback) {
|
||||||
|
m_callback = callback;
|
||||||
|
}
|
||||||
|
template <typename... Args>
|
||||||
|
typename std::enable_if<sizeof...(Args) == 0>::type set_callback(const param_callback& callback) {}
|
||||||
|
|
||||||
|
/// \brief Set callback for particular transformation class types
|
||||||
|
///
|
||||||
|
/// Example below show how to set callback for one or multiple passes using this method.
|
||||||
|
///
|
||||||
|
/// pass_config->set_callback<ngraph::pass::ConvertBatchToSpace,
|
||||||
|
/// ngraph::pass::ConvertSpaceToBatch>(
|
||||||
|
/// [](const_node_ptr &node) -> bool {
|
||||||
|
/// // Disable transformations for cases when input shape rank is not
|
||||||
|
/// equal to 4
|
||||||
|
/// const auto input_shape_rank =
|
||||||
|
/// node->get_output_partial_shape(0).rank().get_length();
|
||||||
|
/// if (input_shape_rank != 4) {
|
||||||
|
/// return false;
|
||||||
|
/// }
|
||||||
|
/// return true;
|
||||||
|
/// });
|
||||||
|
///
|
||||||
|
/// Note that inside transformations you must provide code that work with this callback.
|
||||||
|
/// See example below:
|
||||||
|
///
|
||||||
|
/// if (transformation_callback(node)) {
|
||||||
|
/// return false; // exit from transformation
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
template <typename T, class... Args>
|
||||||
|
void set_callback(const param_callback& callback) {
|
||||||
|
m_callback_map[T::type_info] = callback;
|
||||||
|
set_callback<Args...>(callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Get callback for given transformation type_info
|
||||||
|
/// \param type_info Transformation type_info
|
||||||
|
///
|
||||||
|
/// In case if callback wasn't set for given transformation type then global callback
|
||||||
|
/// will be returned. But if even global callback wasn't set then default callback will
|
||||||
|
/// be returned.
|
||||||
|
param_callback get_callback(const DiscreteTypeInfo& type_info) const;
|
||||||
|
|
||||||
|
/// \brief Get callback for given transformation class type
|
||||||
|
/// \return callback lambda function
|
||||||
|
template <typename T>
|
||||||
|
param_callback get_callback() const {
|
||||||
|
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||||
|
return get_callback(T::type_info);
|
||||||
|
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Check either transformation type is disabled or not
|
||||||
|
/// \param type_info Transformation type_info
|
||||||
|
/// \return true if transformation type was disabled and false otherwise
|
||||||
|
bool is_disabled(const DiscreteTypeInfo& type_info) const {
|
||||||
|
return m_disabled.count(type_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Check either transformation class type is disabled or not
|
||||||
|
/// \return true if transformation type was disabled and false otherwise
|
||||||
|
template <typename T>
|
||||||
|
bool is_disabled() const {
|
||||||
|
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||||
|
return is_disabled(T::type_info);
|
||||||
|
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Check either transformation type is force enabled or not
|
||||||
|
/// \param type_info Transformation type_info
|
||||||
|
/// \return true if transformation type was force enabled and false otherwise
|
||||||
|
bool is_enabled(const DiscreteTypeInfo& type_info) const {
|
||||||
|
return m_enabled.count(type_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Check either transformation class type is force enabled or not
|
||||||
|
/// \return true if transformation type was force enabled and false otherwise
|
||||||
|
template <typename T>
|
||||||
|
bool is_enabled() const {
|
||||||
|
return is_enabled(T::type_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
void add_disabled_passes(const PassConfig& rhs);
|
||||||
|
|
||||||
|
private:
|
||||||
|
param_callback m_callback = [](const std::shared_ptr<const ::ngraph::Node>&) {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
param_callback_map m_callback_map;
|
||||||
|
std::unordered_set<DiscreteTypeInfo> m_disabled;
|
||||||
|
std::unordered_set<DiscreteTypeInfo> m_enabled;
|
||||||
|
};
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
271
ngraph/core/include/openvino/pass/pattern/matcher.hpp
Normal file
271
ngraph/core/include/openvino/pass/pattern/matcher.hpp
Normal file
@ -0,0 +1,271 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
#include "ngraph/op/constant.hpp"
|
||||||
|
#include "openvino/core/except.hpp"
|
||||||
|
#include "openvino/core/node.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/any.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/any_of.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/any_output.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/label.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/skip.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
class GraphRewrite;
|
||||||
|
|
||||||
|
namespace pattern {
|
||||||
|
class Matcher;
|
||||||
|
|
||||||
|
class OPENVINO_API MatcherState {
|
||||||
|
public:
|
||||||
|
MatcherState(Matcher*);
|
||||||
|
bool finish(bool is_successful);
|
||||||
|
~MatcherState();
|
||||||
|
|
||||||
|
protected:
|
||||||
|
Matcher* m_matcher;
|
||||||
|
PatternValueMap m_pattern_value_map;
|
||||||
|
PatternValueMaps m_pattern_value_maps;
|
||||||
|
size_t m_watermark;
|
||||||
|
size_t m_capture_size;
|
||||||
|
bool m_restore{true};
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Matcher looks for node patterns in a computation graph. The patterns are described by an
|
||||||
|
/// automaton that is described by an extended computation graph. The matcher executes
|
||||||
|
/// by attempting to match the start node of the pattern to a computation graph value
|
||||||
|
/// (output of a Node). In addition to determing if a match occurs, a pattern node may add
|
||||||
|
/// graph nodes to a list of matched nodes, associate nodes with graph values, and start
|
||||||
|
/// submatches. Submatches add match state changes to the enclosing match if the submatch
|
||||||
|
/// succeeds; otherwise the state is reverted.
|
||||||
|
///
|
||||||
|
/// The default match behavior of a pattern node with a graph nodes is that the computation
|
||||||
|
/// graph value is added to the end of the matched value list and the match succeeds if the
|
||||||
|
/// node/pattern types match and the input values match. In the case of a commutative node,
|
||||||
|
/// the inputs can match in any order. If the matcher is in strict mode, the graph value
|
||||||
|
/// element type and shape must also match.
|
||||||
|
///
|
||||||
|
/// Pattern nodes that have different match behavior are in ov::pass::pattern::op and have
|
||||||
|
/// descriptions of their match behavior.
|
||||||
|
class OPENVINO_API Matcher {
|
||||||
|
public:
|
||||||
|
using PatternMap = ov::pass::pattern::PatternMap;
|
||||||
|
|
||||||
|
// Avoid implicit string construction from nullptr.
|
||||||
|
Matcher(const std::shared_ptr<Node> pattern_node, std::nullptr_t name) = delete;
|
||||||
|
|
||||||
|
Matcher() = default;
|
||||||
|
Matcher(Output<Node>& pattern_node) : m_pattern_node{pattern_node} {}
|
||||||
|
|
||||||
|
Matcher(Output<Node>& pattern_node, const std::string& name) : m_pattern_node(pattern_node), m_name{name} {}
|
||||||
|
|
||||||
|
/// \brief Constructs a Matcher object
|
||||||
|
///
|
||||||
|
/// \param pattern_node is a pattern sub graph that will be matched against input graphs
|
||||||
|
/// \param name is a string which is used for logging and disabling a matcher
|
||||||
|
/// \param strict_mode forces a matcher to consider shapes and ET of nodes
|
||||||
|
Matcher(const Output<Node>& pattern_node, const std::string& name, bool strict_mode)
|
||||||
|
: m_pattern_node(pattern_node),
|
||||||
|
m_name(name),
|
||||||
|
m_strict_mode(strict_mode) {}
|
||||||
|
|
||||||
|
// Some matches should start on a node rather than an output. These three constructors
|
||||||
|
// are transition until we work out the right way to do that.
|
||||||
|
Matcher(std::shared_ptr<Node> pattern_node);
|
||||||
|
Matcher(std::shared_ptr<Node> pattern_node, const std::string& name);
|
||||||
|
Matcher(std::shared_ptr<Node> pattern_node, const std::string& name, bool strict_mode);
|
||||||
|
|
||||||
|
virtual ~Matcher() = default;
|
||||||
|
/// \brief Matches a pattern to \p graph_node
|
||||||
|
///
|
||||||
|
/// \param graph_value is an input graph to be matched against
|
||||||
|
bool match(const Output<Node>& graph_value);
|
||||||
|
|
||||||
|
bool match(std::shared_ptr<Node> graph_node);
|
||||||
|
|
||||||
|
/// \brief Matches a pattern to \p graph_node
|
||||||
|
///
|
||||||
|
/// \param graph_value is an input graph to be matched against
|
||||||
|
/// \param previous_matches contains previous mappings from labels to nodes to use
|
||||||
|
bool match(const Output<Node>& graph_value, const PatternMap& previous_matches);
|
||||||
|
bool match(const Output<Node>& graph_value, const PatternValueMap& previous_matches);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static std::shared_ptr<T> unique_match(const std::shared_ptr<Node>& node) {
|
||||||
|
std::shared_ptr<T> matched;
|
||||||
|
for (const auto& arg : node->input_values()) {
|
||||||
|
if (auto t_casted = ov::as_type_ptr<T>(arg.get_node_shared_ptr())) {
|
||||||
|
if (matched) {
|
||||||
|
throw Exception("There's more than two arguments of the same type");
|
||||||
|
} else {
|
||||||
|
matched = t_casted;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return matched;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_contained_match(const NodeVector& exclusions = {}, bool ignore_unused = true);
|
||||||
|
const NodeVector get_matched_nodes() {
|
||||||
|
return as_node_vector(m_matched_list);
|
||||||
|
}
|
||||||
|
const OutputVector& get_matched_values() const {
|
||||||
|
return m_matched_list;
|
||||||
|
}
|
||||||
|
OutputVector& get_matched_values() {
|
||||||
|
return m_matched_list;
|
||||||
|
}
|
||||||
|
void reset() {}
|
||||||
|
const std::string& get_name() {
|
||||||
|
return m_name;
|
||||||
|
}
|
||||||
|
std::shared_ptr<Node> get_pattern() {
|
||||||
|
return m_pattern_node.get_node_shared_ptr();
|
||||||
|
}
|
||||||
|
Output<Node> get_pattern_value() {
|
||||||
|
return m_pattern_node;
|
||||||
|
}
|
||||||
|
std::shared_ptr<Node> get_match_root();
|
||||||
|
Output<Node> get_match_value();
|
||||||
|
PatternMap get_pattern_map() const;
|
||||||
|
PatternValueMap& get_pattern_value_map() {
|
||||||
|
return m_pattern_map;
|
||||||
|
}
|
||||||
|
PatternValueMaps& get_pattern_value_maps() {
|
||||||
|
return m_pattern_value_maps;
|
||||||
|
}
|
||||||
|
/// \brief Low-level helper to match recurring patterns
|
||||||
|
///
|
||||||
|
/// \param graph is a graph to be matched against
|
||||||
|
/// \param pattern is a recurring pattern
|
||||||
|
/// \param rpattern specifies a node to recur from next
|
||||||
|
/// \param patterns a map from labels to matches
|
||||||
|
|
||||||
|
size_t add_node(Output<Node> node);
|
||||||
|
|
||||||
|
bool virtual match_value(const ov::Output<Node>& pattern_value, const ov::Output<Node>& graph_value);
|
||||||
|
|
||||||
|
bool is_strict_mode() {
|
||||||
|
return m_strict_mode;
|
||||||
|
}
|
||||||
|
virtual bool match_arguments(Node* pattern_node, const std::shared_ptr<Node>& graph_node);
|
||||||
|
|
||||||
|
void capture(const std::set<Node*>& static_nodes);
|
||||||
|
|
||||||
|
void clear_state();
|
||||||
|
|
||||||
|
size_t get_number_of_recurrent_matches() const {
|
||||||
|
return m_pattern_value_maps.size();
|
||||||
|
}
|
||||||
|
NodeVector get_bound_nodes_for_pattern(const Output<Node>& pattern) const;
|
||||||
|
size_t get_number_of_bound_labels() const;
|
||||||
|
/// \brief Try a match
|
||||||
|
MatcherState start_match();
|
||||||
|
|
||||||
|
Output<Node> m_match_root;
|
||||||
|
Output<Node> m_pattern_node;
|
||||||
|
PatternValueMap m_pattern_map;
|
||||||
|
PatternValueMaps m_pattern_value_maps;
|
||||||
|
OutputVector m_matched_list;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
bool match_permutation(const OutputVector& pattern_args, const OutputVector& args);
|
||||||
|
|
||||||
|
std::string m_name{"unnamed"};
|
||||||
|
bool m_strict_mode{false};
|
||||||
|
};
|
||||||
|
|
||||||
|
class OPENVINO_API RecurrentMatcher {
|
||||||
|
public:
|
||||||
|
/// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
|
||||||
|
/// repeating patterns (e.g. RNN, LSTM, GRU cells)
|
||||||
|
///
|
||||||
|
/// \param initial_pattern is a pattern sub graph describing the initial cell
|
||||||
|
/// \param pattern is a pattern sub graph describing an individual cell
|
||||||
|
/// \param rpattern is a (recurring) label to denote which node the next match should
|
||||||
|
/// start at
|
||||||
|
/// \param correlated_patterns is a set of labels whose bound nodes must remain the same
|
||||||
|
/// across all cells
|
||||||
|
RecurrentMatcher(const Output<Node>& initial_pattern,
|
||||||
|
const Output<Node>& pattern,
|
||||||
|
const std::shared_ptr<Node>& rpattern,
|
||||||
|
const std::set<std::shared_ptr<Node>>& correlated_patterns)
|
||||||
|
: m_initial_pattern(initial_pattern),
|
||||||
|
m_pattern(pattern),
|
||||||
|
m_recurrent_pattern(rpattern),
|
||||||
|
m_correlated_patterns(correlated_patterns) {}
|
||||||
|
|
||||||
|
/// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
|
||||||
|
/// repeating patterns (e.g. RNN, LSTM, GRU cells)
|
||||||
|
///
|
||||||
|
/// \param pattern is a pattern sub graph describing an individual cell
|
||||||
|
/// \param rpattern is a (recurring) label to denote which node the next match should
|
||||||
|
/// start at
|
||||||
|
/// \param correlated_patterns is a set of labels whose bound nodes must remain the same
|
||||||
|
/// across all cells
|
||||||
|
RecurrentMatcher(const Output<Node>& pattern,
|
||||||
|
const std::shared_ptr<Node>& rpattern,
|
||||||
|
const std::set<std::shared_ptr<Node>>& correlated_patterns)
|
||||||
|
: RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns) {}
|
||||||
|
|
||||||
|
RecurrentMatcher(const Output<Node>& initial_pattern,
|
||||||
|
const Output<Node>& pattern,
|
||||||
|
const std::shared_ptr<Node>& rpattern,
|
||||||
|
const std::set<std::shared_ptr<op::Label>>& correlated_patterns);
|
||||||
|
|
||||||
|
RecurrentMatcher(const Output<Node>& pattern,
|
||||||
|
const std::shared_ptr<Node>& rpattern,
|
||||||
|
const std::set<std::shared_ptr<op::Label>>& correlated_patterns)
|
||||||
|
: RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns) {}
|
||||||
|
|
||||||
|
/// \brief Returns a vector of bound nodes for a given label (used in a pattern
|
||||||
|
/// describing an individual cell
|
||||||
|
NodeVector get_bound_nodes_for_pattern(const std::shared_ptr<Node>& pattern) const {
|
||||||
|
if (m_matches.count(pattern) == 0) {
|
||||||
|
throw Exception("No bound nodes for a given label");
|
||||||
|
}
|
||||||
|
|
||||||
|
return as_node_vector(m_matches.at(pattern));
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t get_number_of_recurrent_matches() const {
|
||||||
|
if (m_matches.size() == 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (*m_matches.begin()).second.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t get_number_of_bound_labels() const {
|
||||||
|
return m_matches.size();
|
||||||
|
}
|
||||||
|
/// \brief Tries to match a pattern for an individual cell to a given \p graph
|
||||||
|
bool match(Output<Node> graph);
|
||||||
|
|
||||||
|
std::shared_ptr<Node> get_match_root() {
|
||||||
|
return m_match_root.get_node_shared_ptr();
|
||||||
|
}
|
||||||
|
Output<Node> get_match_value() {
|
||||||
|
return m_match_root;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Output<Node> m_initial_pattern;
|
||||||
|
Output<Node> m_pattern;
|
||||||
|
std::shared_ptr<Node> m_recurrent_pattern;
|
||||||
|
const std::set<std::shared_ptr<Node>> m_correlated_patterns;
|
||||||
|
RPatternValueMap m_matches;
|
||||||
|
Output<Node> m_match_root;
|
||||||
|
};
|
||||||
|
} // namespace pattern
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
45
ngraph/core/include/openvino/pass/pattern/op/any.hpp
Normal file
45
ngraph/core/include/openvino/pass/pattern/op/any.hpp
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/core/node.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
namespace pattern {
|
||||||
|
namespace op {
|
||||||
|
/// The graph value is to the matched value list. If the predicate is true for the node
|
||||||
|
/// and the arguments match, the match succeeds.
|
||||||
|
class OPENVINO_API Any : public Pattern {
|
||||||
|
public:
|
||||||
|
static constexpr NodeTypeInfo type_info{"patternAny", 0};
|
||||||
|
const NodeTypeInfo& get_type_info() const override;
|
||||||
|
/// \brief creates a Any node containing a sub-pattern described by \sa type and \sa
|
||||||
|
/// shape.
|
||||||
|
Any(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values)
|
||||||
|
: Pattern(wrapped_values, pred) {
|
||||||
|
set_output_type(0, type, s);
|
||||||
|
}
|
||||||
|
Any(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values)
|
||||||
|
: Any(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
|
||||||
|
/// \brief creates a Any node containing a sub-pattern described by the type and
|
||||||
|
/// shape of \sa node.
|
||||||
|
Any(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
|
||||||
|
: Any(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
|
||||||
|
Any(const Output<Node>& node, NodePredicate pred, const NodeVector& wrapped_values)
|
||||||
|
: Any(node.get_element_type(),
|
||||||
|
node.get_partial_shape(),
|
||||||
|
as_value_predicate(pred),
|
||||||
|
as_output_vector(wrapped_values)) {}
|
||||||
|
|
||||||
|
bool match_value(pattern::Matcher* matcher,
|
||||||
|
const Output<Node>& pattern_value,
|
||||||
|
const Output<Node>& graph_value) override;
|
||||||
|
};
|
||||||
|
} // namespace op
|
||||||
|
} // namespace pattern
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
54
ngraph/core/include/openvino/pass/pattern/op/any_of.hpp
Normal file
54
ngraph/core/include/openvino/pass/pattern/op/any_of.hpp
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/core/node.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
namespace pattern {
|
||||||
|
namespace op {
|
||||||
|
/// The graph value is added to the matched values list. If the predicate is true for
|
||||||
|
/// the
|
||||||
|
/// graph node, a submatch is performed on the input of AnyOf and each input of the
|
||||||
|
/// graph node. The first match that succeeds results in a successful match. Otherwise
|
||||||
|
/// the match fails.
|
||||||
|
///
|
||||||
|
/// AnyOf may be given a type and shape for use in strict mode.
|
||||||
|
class OPENVINO_API AnyOf : public Pattern {
|
||||||
|
public:
|
||||||
|
static constexpr NodeTypeInfo type_info{"patternAnyOf", 0};
|
||||||
|
const NodeTypeInfo& get_type_info() const override;
|
||||||
|
/// \brief creates a AnyOf node containing a sub-pattern described by \sa type and
|
||||||
|
/// \sa shape.
|
||||||
|
AnyOf(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values)
|
||||||
|
: Pattern(wrapped_values, pred) {
|
||||||
|
if (wrapped_values.size() != 1) {
|
||||||
|
throw Exception("AnyOf expects exactly one argument");
|
||||||
|
}
|
||||||
|
set_output_type(0, type, s);
|
||||||
|
}
|
||||||
|
AnyOf(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values)
|
||||||
|
: AnyOf(
|
||||||
|
type,
|
||||||
|
s,
|
||||||
|
[pred](const Output<Node>& value) {
|
||||||
|
return pred(value.get_node_shared_ptr());
|
||||||
|
},
|
||||||
|
as_output_vector(wrapped_values)) {}
|
||||||
|
|
||||||
|
/// \brief creates a AnyOf node containing a sub-pattern described by the type and
|
||||||
|
/// shape of \sa node.
|
||||||
|
AnyOf(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
|
||||||
|
: AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
|
||||||
|
AnyOf(const std::shared_ptr<Node>& node, NodePredicate pred, const NodeVector& wrapped_values)
|
||||||
|
: AnyOf(node, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
|
||||||
|
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
|
||||||
|
};
|
||||||
|
} // namespace op
|
||||||
|
} // namespace pattern
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
30
ngraph/core/include/openvino/pass/pattern/op/any_output.hpp
Normal file
30
ngraph/core/include/openvino/pass/pattern/op/any_output.hpp
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/core/node.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
namespace pattern {
|
||||||
|
namespace op {
|
||||||
|
/// Matches any output of a node
|
||||||
|
class OPENVINO_API AnyOutput : public Pattern {
|
||||||
|
public:
|
||||||
|
static constexpr NodeTypeInfo type_info{"patternAnyOutput", 0};
|
||||||
|
const NodeTypeInfo& get_type_info() const override;
|
||||||
|
/// \brief creates an AnyOutput node matching any output of a node
|
||||||
|
/// \param node The node to match
|
||||||
|
AnyOutput(const std::shared_ptr<Node>& pattern) : Pattern({pattern->output(0)}) {}
|
||||||
|
|
||||||
|
bool match_value(pattern::Matcher* matcher,
|
||||||
|
const Output<Node>& pattern_value,
|
||||||
|
const Output<Node>& graph_value) override;
|
||||||
|
};
|
||||||
|
} // namespace op
|
||||||
|
} // namespace pattern
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
55
ngraph/core/include/openvino/pass/pattern/op/branch.hpp
Normal file
55
ngraph/core/include/openvino/pass/pattern/op/branch.hpp
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/core/node.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
namespace pattern {
|
||||||
|
namespace op {
|
||||||
|
/// A branch adds a loop to the pattern. The branch match is successful if the
|
||||||
|
/// destination node pattern matches the graph value. The destination node is a node in
|
||||||
|
/// the pattern graph that will not have been created some time after the Branch node is
|
||||||
|
/// created; use set_destination to add it.
|
||||||
|
///
|
||||||
|
/// The branch destination is not stored as a shared pointer to prevent reference
|
||||||
|
/// cycles. Thus the destination node must be referenced in some other way to prevent it
|
||||||
|
/// from being deleted.
|
||||||
|
class OPENVINO_API Branch : public Pattern {
|
||||||
|
public:
|
||||||
|
static constexpr NodeTypeInfo type_info{"patternBranch", 0};
|
||||||
|
const NodeTypeInfo& get_type_info() const override;
|
||||||
|
/// \brief Creates a Branch pattern
|
||||||
|
/// \param pattern the destinationing pattern
|
||||||
|
/// \param labels Labels where the destination may occur
|
||||||
|
Branch() : Pattern(OutputVector{}) {
|
||||||
|
set_output_type(0, element::f32, ngraph::Shape{});
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_destination(const Output<Node>& destination) {
|
||||||
|
m_destination_node = destination.get_node();
|
||||||
|
m_destination_index = destination.get_index();
|
||||||
|
}
|
||||||
|
|
||||||
|
Output<Node> get_destination() const {
|
||||||
|
return m_destination_node == nullptr
|
||||||
|
? Output<Node>()
|
||||||
|
: Output<Node>{m_destination_node->shared_from_this(), m_destination_index};
|
||||||
|
}
|
||||||
|
|
||||||
|
bool match_value(pattern::Matcher* matcher,
|
||||||
|
const Output<Node>& pattern_value,
|
||||||
|
const Output<Node>& graph_value) override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
Node* m_destination_node{nullptr};
|
||||||
|
size_t m_destination_index{0};
|
||||||
|
};
|
||||||
|
} // namespace op
|
||||||
|
} // namespace pattern
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
44
ngraph/core/include/openvino/pass/pattern/op/capture.hpp
Normal file
44
ngraph/core/include/openvino/pass/pattern/op/capture.hpp
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/core/node.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
namespace pattern {
|
||||||
|
namespace op {
|
||||||
|
/// Experimental for support of recurrent matches.
|
||||||
|
///
|
||||||
|
/// Capture adds the pattern value map to a list of pattern value maps and resets
|
||||||
|
/// matches for pattern nodes not in the static node list. The match always succeeds.
|
||||||
|
class OPENVINO_API Capture : public Pattern {
|
||||||
|
public:
|
||||||
|
static constexpr NodeTypeInfo type_info{"patternCapture", 0};
|
||||||
|
const NodeTypeInfo& get_type_info() const override;
|
||||||
|
Capture(const Output<Node>& arg) : Pattern({arg}) {
|
||||||
|
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief static nodes are retained after a capture. All other nodes are dropped
|
||||||
|
std::set<Node*> get_static_nodes() {
|
||||||
|
return m_static_nodes;
|
||||||
|
}
|
||||||
|
void set_static_nodes(const std::set<Node*>& static_nodes) {
|
||||||
|
m_static_nodes = static_nodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool match_value(pattern::Matcher* matcher,
|
||||||
|
const Output<Node>& pattern_value,
|
||||||
|
const Output<Node>& graph_value) override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::set<Node*> m_static_nodes;
|
||||||
|
};
|
||||||
|
} // namespace op
|
||||||
|
} // namespace pattern
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
113
ngraph/core/include/openvino/pass/pattern/op/label.hpp
Normal file
113
ngraph/core/include/openvino/pass/pattern/op/label.hpp
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/core/node.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
namespace pattern {
|
||||||
|
namespace op {
|
||||||
|
/// Fails if the predicate returns false on the graph value.
|
||||||
|
///
|
||||||
|
/// The graph value is added to the matched values list. If the Label is already
|
||||||
|
/// associated with a value, the match succeeds if the value is the same as the graph
|
||||||
|
/// value. Otherwise, the label is associated with the graph value and the match
|
||||||
|
/// succeeds if the pattern input matches the graph value.
|
||||||
|
///
|
||||||
|
/// DEPRECATED: If no inputs are given to Label, a True node is serves as the input. If
|
||||||
|
/// more than one inputs are given, an Or pattern of the inputs serves as the input.
|
||||||
|
class OPENVINO_API Label : public Pattern {
|
||||||
|
public:
|
||||||
|
static constexpr NodeTypeInfo type_info{"patternLabel", 0};
|
||||||
|
const NodeTypeInfo& get_type_info() const override;
|
||||||
|
/// \brief creates a Label node containing a sub-pattern described by \sa type and
|
||||||
|
/// \sa shape.
|
||||||
|
///
|
||||||
|
/// this Label node can be bound only to the nodes in the input graph
|
||||||
|
/// that match the pattern specified by \sa wrapped_nodes
|
||||||
|
/// Example:
|
||||||
|
/// \code{.cpp}
|
||||||
|
/// auto add = a + b; // a and b are op::Parameter in this example
|
||||||
|
/// auto label = std::make_shared<pattern::op::Label>(element::f32,
|
||||||
|
/// Shape{2,2},
|
||||||
|
/// nullptr,
|
||||||
|
/// OutputVector{add});
|
||||||
|
/// \endcode
|
||||||
|
Label(const element::Type& type,
|
||||||
|
const PartialShape& s,
|
||||||
|
const ValuePredicate pred,
|
||||||
|
const OutputVector& wrapped_values)
|
||||||
|
: Pattern(OutputVector{wrap_values(wrapped_values)}, pred) {
|
||||||
|
set_output_type(0, type, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
explicit Label(const element::Type& type = element::dynamic, const PartialShape& s = PartialShape::dynamic())
|
||||||
|
: Label(
|
||||||
|
type,
|
||||||
|
s,
|
||||||
|
[](const Output<Node>&) {
|
||||||
|
return true;
|
||||||
|
},
|
||||||
|
OutputVector()) {}
|
||||||
|
|
||||||
|
Label(const element::Type& type, const PartialShape& s, ValuePredicate pred)
|
||||||
|
: Label(type, s, pred, OutputVector{}) {}
|
||||||
|
|
||||||
|
Label(const element::Type& type, const PartialShape& s, NodePredicate pred)
|
||||||
|
: Label(type, s, as_value_predicate(pred), OutputVector{}) {}
|
||||||
|
|
||||||
|
Label(const element::Type& type, const PartialShape& s, const NodePredicate pred, const NodeVector& wrapped_values)
|
||||||
|
: Label(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
|
||||||
|
|
||||||
|
/// \brief creates a Label node containing a sub-pattern described by the type and
|
||||||
|
/// shape of \sa node.
|
||||||
|
///
|
||||||
|
/// this Label node can be bound only to the nodes in the input graph
|
||||||
|
/// that match the pattern specified by \sa wrapped_values
|
||||||
|
/// Example:
|
||||||
|
/// \code{.cpp}
|
||||||
|
/// auto add = a + b; // a and b are op::Parameter in this example
|
||||||
|
/// auto label = std::make_shared<pattern::op::Label>(add,
|
||||||
|
/// nullptr,
|
||||||
|
/// OutputVector{add});
|
||||||
|
/// \endcode
|
||||||
|
Label(const Output<Node>& value, const ValuePredicate pred, const OutputVector& wrapped_values)
|
||||||
|
: Label(value.get_element_type(), value.get_partial_shape(), pred, wrapped_values) {}
|
||||||
|
Label(const Output<Node>& value, const ValuePredicate pred)
|
||||||
|
: Label(value.get_element_type(), value.get_partial_shape(), pred, OutputVector{}) {}
|
||||||
|
|
||||||
|
Label(const Output<Node>& value, const NodePredicate pred)
|
||||||
|
: Label(value.get_element_type(), value.get_partial_shape(), as_value_predicate(pred), OutputVector{}) {}
|
||||||
|
Label(const Output<Node>& value)
|
||||||
|
: Label(
|
||||||
|
value.get_element_type(),
|
||||||
|
value.get_partial_shape(),
|
||||||
|
[](const Output<Node>&) {
|
||||||
|
return true;
|
||||||
|
},
|
||||||
|
OutputVector{}) {}
|
||||||
|
Label(const Output<Node>& node, const NodePredicate pred, const NodeVector& wrapped_values)
|
||||||
|
: Label(node.get_element_type(),
|
||||||
|
node.get_partial_shape(),
|
||||||
|
as_value_predicate(pred),
|
||||||
|
as_output_vector(wrapped_values)) {}
|
||||||
|
|
||||||
|
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
static Output<Node> wrap_values(const OutputVector& wrapped_values);
|
||||||
|
};
|
||||||
|
} // namespace op
|
||||||
|
|
||||||
|
OPENVINO_API
|
||||||
|
std::shared_ptr<Node> any_input();
|
||||||
|
|
||||||
|
OPENVINO_API
|
||||||
|
std::shared_ptr<Node> any_input(const pattern::op::ValuePredicate& pred);
|
||||||
|
} // namespace pattern
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
32
ngraph/core/include/openvino/pass/pattern/op/or.hpp
Normal file
32
ngraph/core/include/openvino/pass/pattern/op/or.hpp
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/core/node.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
namespace pattern {
|
||||||
|
namespace op {
|
||||||
|
/// A submatch on the graph value is performed on each input to the Or; the match
|
||||||
|
/// succeeds on the first match. Otherwise the match fails.
|
||||||
|
class OPENVINO_API Or : public Pattern {
|
||||||
|
public:
|
||||||
|
static constexpr NodeTypeInfo type_info{"patternOr", 0};
|
||||||
|
const NodeTypeInfo& get_type_info() const override;
|
||||||
|
/// \brief creates an Or node matching one of several sub-patterns in order. Does
|
||||||
|
/// not add node to match list.
|
||||||
|
/// \param patterns The patterns to try for matching
|
||||||
|
Or(const OutputVector& patterns) : Pattern(patterns) {}
|
||||||
|
|
||||||
|
bool match_value(pattern::Matcher* matcher,
|
||||||
|
const Output<Node>& pattern_value,
|
||||||
|
const Output<Node>& graph_value) override;
|
||||||
|
};
|
||||||
|
} // namespace op
|
||||||
|
} // namespace pattern
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
96
ngraph/core/include/openvino/pass/pattern/op/pattern.hpp
Normal file
96
ngraph/core/include/openvino/pass/pattern/op/pattern.hpp
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
#include "openvino/core/node.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
namespace pattern {
|
||||||
|
namespace op {
|
||||||
|
class Label;
|
||||||
|
}
|
||||||
|
|
||||||
|
class Matcher;
|
||||||
|
class MatcherState;
|
||||||
|
|
||||||
|
using RPatternValueMap = std::map<std::shared_ptr<Node>, OutputVector>;
|
||||||
|
using PatternValueMap = std::map<std::shared_ptr<Node>, Output<Node>>;
|
||||||
|
using PatternValueMaps = std::vector<PatternValueMap>;
|
||||||
|
|
||||||
|
using PatternMap = std::map<std::shared_ptr<Node>, std::shared_ptr<Node>>;
|
||||||
|
|
||||||
|
PatternMap as_pattern_map(const PatternValueMap& pattern_value_map);
|
||||||
|
PatternValueMap as_pattern_value_map(const PatternMap& pattern_map);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::function<bool(std::shared_ptr<Node>)> has_class() {
|
||||||
|
auto pred = [](std::shared_ptr<Node> node) -> bool {
|
||||||
|
return ov::is_type<T>(node);
|
||||||
|
};
|
||||||
|
|
||||||
|
return pred;
|
||||||
|
}
|
||||||
|
|
||||||
|
OPENVINO_API
|
||||||
|
std::function<bool(Output<Node>)> consumers_count(size_t n);
|
||||||
|
|
||||||
|
OPENVINO_API
|
||||||
|
std::function<bool(Output<Node>)> has_static_dim(size_t pos);
|
||||||
|
|
||||||
|
OPENVINO_API
|
||||||
|
std::function<bool(Output<Node>)> has_static_dims(const std::vector<size_t>& dims);
|
||||||
|
|
||||||
|
OPENVINO_API
|
||||||
|
std::function<bool(Output<Node>)> has_static_shape();
|
||||||
|
|
||||||
|
OPENVINO_API
|
||||||
|
std::function<bool(Output<Node>)> has_static_rank();
|
||||||
|
|
||||||
|
OPENVINO_API
|
||||||
|
std::function<bool(Output<Node>)> rank_equals(const Dimension& expected_rank);
|
||||||
|
|
||||||
|
OPENVINO_API
|
||||||
|
std::function<bool(Output<Node>)> type_matches(const element::Type& type);
|
||||||
|
|
||||||
|
OPENVINO_API
|
||||||
|
std::function<bool(Output<Node>)> type_matches_any(const std::vector<element::Type>& types);
|
||||||
|
|
||||||
|
namespace op {
|
||||||
|
using NodePredicate = std::function<bool(std::shared_ptr<Node>)>;
|
||||||
|
using ValuePredicate = std::function<bool(const Output<Node>& value)>;
|
||||||
|
|
||||||
|
OPENVINO_API
|
||||||
|
ValuePredicate as_value_predicate(NodePredicate pred);
|
||||||
|
|
||||||
|
class OPENVINO_API Pattern : public Node {
|
||||||
|
public:
|
||||||
|
/// \brief \p a base class for \sa Skip and \sa Label
|
||||||
|
///
|
||||||
|
Pattern(const OutputVector& patterns, ValuePredicate pred) : Node(patterns), m_predicate(pred) {
|
||||||
|
if (!m_predicate) {
|
||||||
|
m_predicate = [](const Output<Node>&) {
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Pattern(const OutputVector& patterns) : Pattern(patterns, nullptr) {}
|
||||||
|
|
||||||
|
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& /* new_args */) const override {
|
||||||
|
throw Exception("Uncopyable");
|
||||||
|
}
|
||||||
|
|
||||||
|
ValuePredicate get_predicate() const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
ValuePredicate m_predicate;
|
||||||
|
};
|
||||||
|
} // namespace op
|
||||||
|
} // namespace pattern
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
44
ngraph/core/include/openvino/pass/pattern/op/skip.hpp
Normal file
44
ngraph/core/include/openvino/pass/pattern/op/skip.hpp
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/core/node.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
namespace pattern {
|
||||||
|
namespace op {
|
||||||
|
/// The graph value is added to the matched value list. If the predicate is true, the
|
||||||
|
/// match succeeds if the arguments match; if the predicate is false, the match succeeds
|
||||||
|
/// if the pattern input matches the graph value.
|
||||||
|
class OPENVINO_API Skip : public Pattern {
|
||||||
|
public:
|
||||||
|
static constexpr NodeTypeInfo type_info{"patternSkip", 0};
|
||||||
|
const NodeTypeInfo& get_type_info() const override;
|
||||||
|
Skip(const Output<Node>& arg, ValuePredicate pred) : Pattern({arg}, pred) {
|
||||||
|
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
Skip(const Output<Node>& arg, NodePredicate pred = nullptr) : Pattern({arg}, as_value_predicate(pred)) {
|
||||||
|
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
Skip(const OutputVector& args, ValuePredicate pred) : Pattern(args, pred) {
|
||||||
|
set_output_type(0, args.at(0).get_element_type(), args.at(0).get_partial_shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
Skip(const OutputVector& args, NodePredicate pred = nullptr) : Pattern(args, as_value_predicate(pred)) {
|
||||||
|
set_output_type(0, args.at(0).get_element_type(), args.at(0).get_partial_shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
bool match_value(pattern::Matcher* matcher,
|
||||||
|
const Output<Node>& pattern_value,
|
||||||
|
const Output<Node>& graph_value) override;
|
||||||
|
};
|
||||||
|
} // namespace op
|
||||||
|
} // namespace pattern
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
28
ngraph/core/include/openvino/pass/pattern/op/true.hpp
Normal file
28
ngraph/core/include/openvino/pass/pattern/op/true.hpp
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/core/node.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
namespace pattern {
|
||||||
|
namespace op {
|
||||||
|
/// \brief The match always succeeds.
|
||||||
|
class OPENVINO_API True : public Pattern {
|
||||||
|
public:
|
||||||
|
static constexpr NodeTypeInfo type_info{"patternTrue", 0};
|
||||||
|
const NodeTypeInfo& get_type_info() const override;
|
||||||
|
/// \brief Always matches, does not add node to match list.
|
||||||
|
True() : Pattern(OutputVector{}) {}
|
||||||
|
bool match_value(pattern::Matcher* matcher,
|
||||||
|
const Output<Node>& pattern_value,
|
||||||
|
const Output<Node>& graph_value) override;
|
||||||
|
};
|
||||||
|
} // namespace op
|
||||||
|
} // namespace pattern
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
75
ngraph/core/include/openvino/pass/pattern/op/wrap_type.hpp
Normal file
75
ngraph/core/include/openvino/pass/pattern/op/wrap_type.hpp
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/core/node.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
namespace pattern {
|
||||||
|
namespace op {
|
||||||
|
class OPENVINO_API WrapType : public Pattern {
|
||||||
|
public:
|
||||||
|
static constexpr NodeTypeInfo type_info{"patternAnyType", 0};
|
||||||
|
const NodeTypeInfo& get_type_info() const override;
|
||||||
|
|
||||||
|
explicit WrapType(
|
||||||
|
NodeTypeInfo wrapped_type,
|
||||||
|
const ValuePredicate& pred =
|
||||||
|
[](const Output<Node>& output) {
|
||||||
|
return true;
|
||||||
|
},
|
||||||
|
const OutputVector& input_values = {})
|
||||||
|
: Pattern(input_values, pred),
|
||||||
|
m_wrapped_types({wrapped_type}) {
|
||||||
|
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
|
||||||
|
}
|
||||||
|
|
||||||
|
explicit WrapType(
|
||||||
|
std::vector<NodeTypeInfo> wrapped_types,
|
||||||
|
const ValuePredicate& pred =
|
||||||
|
[](const Output<Node>& output) {
|
||||||
|
return true;
|
||||||
|
},
|
||||||
|
const OutputVector& input_values = {})
|
||||||
|
: Pattern(input_values, pred),
|
||||||
|
m_wrapped_types(std::move(wrapped_types)) {
|
||||||
|
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
|
||||||
|
}
|
||||||
|
|
||||||
|
bool match_value(pattern::Matcher* matcher,
|
||||||
|
const Output<Node>& pattern_value,
|
||||||
|
const Output<Node>& graph_value) override;
|
||||||
|
|
||||||
|
NodeTypeInfo get_wrapped_type() const;
|
||||||
|
|
||||||
|
const std::vector<NodeTypeInfo>& get_wrapped_types() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<NodeTypeInfo> m_wrapped_types;
|
||||||
|
};
|
||||||
|
} // namespace op
|
||||||
|
|
||||||
|
template <class... Args>
|
||||||
|
std::shared_ptr<Node> wrap_type(const OutputVector& inputs, const pattern::op::ValuePredicate& pred) {
|
||||||
|
std::vector<DiscreteTypeInfo> info{Args::type_info...};
|
||||||
|
return std::make_shared<op::WrapType>(info, pred, inputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class... Args>
|
||||||
|
std::shared_ptr<Node> wrap_type(const OutputVector& inputs = {}) {
|
||||||
|
return wrap_type<Args...>(inputs, [](const Output<Node>& output) {
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class... Args>
|
||||||
|
std::shared_ptr<Node> wrap_type(const pattern::op::ValuePredicate& pred) {
|
||||||
|
return wrap_type<Args...>({}, pred);
|
||||||
|
}
|
||||||
|
} // namespace pattern
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
32
ngraph/core/include/openvino/pass/validate.hpp
Normal file
32
ngraph/core/include/openvino/pass/validate.hpp
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "openvino/core/core_visibility.hpp"
|
||||||
|
#include "openvino/pass/pass.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
/// \brief The Validate pass performs sanity checks on attributes and inputs, and
|
||||||
|
/// computes output shapes and element types for all computation nodes in a given
|
||||||
|
/// computation graph.
|
||||||
|
///
|
||||||
|
/// \details The verification and inference is done via invoking each node's specific
|
||||||
|
/// implementation of \link ov::Node::validate_and_infer_types() \endlink function.
|
||||||
|
///
|
||||||
|
/// By default, the \ref ov::pass::Manager runs this pass after executing every
|
||||||
|
/// optimization pass. This is to ensure that any update to the graph by an optimization
|
||||||
|
/// pass does not break the shape and data type requirement on a computation node.
|
||||||
|
/// This default validation run can be changed via calling the
|
||||||
|
/// \link ov::pass::Manager::set_per_pass_validation(bool) \endlink function.
|
||||||
|
class OPENVINO_API Validate : public FunctionPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI_DECLARATION;
|
||||||
|
|
||||||
|
Validate() : FunctionPass() {}
|
||||||
|
bool run_on_function(std::shared_ptr<ov::Function> f) override;
|
||||||
|
};
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
57
ngraph/core/include/openvino/pass/visualize_tree.hpp
Normal file
57
ngraph/core/include/openvino/pass/visualize_tree.hpp
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <set>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <typeindex>
|
||||||
|
#include <typeinfo>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "openvino/pass/pass.hpp"
|
||||||
|
|
||||||
|
class HeightMap;
|
||||||
|
|
||||||
|
using visualize_tree_ops_map_t =
|
||||||
|
std::unordered_map<ov::Node::type_info_t, std::function<void(const ov::Node&, std::ostream& ss)>>;
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
class OPENVINO_API VisualizeTree : public FunctionPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI_DECLARATION;
|
||||||
|
|
||||||
|
using node_modifiers_t = std::function<void(const Node& node, std::vector<std::string>& attributes)>;
|
||||||
|
VisualizeTree(const std::string& file_name, node_modifiers_t nm = nullptr, bool dot_only = false);
|
||||||
|
bool run_on_function(std::shared_ptr<ov::Function>) override;
|
||||||
|
|
||||||
|
void set_ops_to_details(const visualize_tree_ops_map_t& ops_map) {
|
||||||
|
m_ops_to_details = ops_map;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void add_node_arguments(std::shared_ptr<Node> node,
|
||||||
|
std::unordered_map<Node*, HeightMap>& height_maps,
|
||||||
|
size_t& fake_node_ctr);
|
||||||
|
std::string add_attributes(std::shared_ptr<Node> node);
|
||||||
|
virtual std::string get_attributes(std::shared_ptr<Node> node);
|
||||||
|
virtual std::string get_node_name(std::shared_ptr<Node> node);
|
||||||
|
std::string get_constant_value(std::shared_ptr<Node> node, size_t max_elements = 7);
|
||||||
|
|
||||||
|
void render() const;
|
||||||
|
|
||||||
|
std::stringstream m_ss;
|
||||||
|
std::string m_name;
|
||||||
|
std::set<std::shared_ptr<Node>> m_nodes_with_attributes;
|
||||||
|
visualize_tree_ops_map_t m_ops_to_details;
|
||||||
|
node_modifiers_t m_node_modifiers = nullptr;
|
||||||
|
bool m_dot_only;
|
||||||
|
static const int max_jump_distance;
|
||||||
|
};
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
@ -11,11 +11,10 @@
|
|||||||
#include "ngraph/validation_util.hpp"
|
#include "ngraph/validation_util.hpp"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConstantFolding, "ConstantFolding", 0);
|
OPENVINO_RTTI_DEFINITION(ov::pass::ConstantFolding, "ConstantFolding", 0);
|
||||||
|
|
||||||
bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
bool ov::pass::ConstantFolding::run_on_function(std::shared_ptr<ov::Function> f) {
|
||||||
bool rewritten = pre_calculated_values_folding(f);
|
bool rewritten = pre_calculated_values_folding(f);
|
||||||
|
|
||||||
for (const auto& node : f->get_ordered_ops()) {
|
for (const auto& node : f->get_ordered_ops()) {
|
||||||
@ -48,7 +47,7 @@ bool ngraph::pass::ConstantFolding::run_on_function(std::shared_ptr<ngraph::Func
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// recursively constant fold operators containing subgraphs (ie: TensorIterator, Loop)
|
// recursively constant fold operators containing subgraphs (ie: TensorIterator, Loop)
|
||||||
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
|
if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node)) {
|
||||||
if (const auto& sub_graph = sub_graph_node->get_function()) {
|
if (const auto& sub_graph = sub_graph_node->get_function()) {
|
||||||
rewritten |= run_on_function(sub_graph);
|
rewritten |= run_on_function(sub_graph);
|
||||||
}
|
}
|
||||||
@ -79,14 +78,14 @@ bool ngraph::pass::ConstantFolding::pre_calculated_values_folding(const std::sha
|
|||||||
while (!nodes.empty()) {
|
while (!nodes.empty()) {
|
||||||
auto curr_node = nodes.front();
|
auto curr_node = nodes.front();
|
||||||
nodes.pop_front();
|
nodes.pop_front();
|
||||||
if (visited.count(curr_node) || ov::is_type<op::Constant>(curr_node))
|
if (visited.count(curr_node) || ov::is_type<ngraph::op::Constant>(curr_node))
|
||||||
continue;
|
continue;
|
||||||
visited.insert(curr_node);
|
visited.insert(curr_node);
|
||||||
|
|
||||||
for (auto& input_value : curr_node->input_values()) {
|
for (auto& input_value : curr_node->input_values()) {
|
||||||
// Check that ConstantFolding is not disabled on this path
|
// Check that ConstantFolding is not disabled on this path
|
||||||
std::vector<Node*> order;
|
std::vector<Node*> order;
|
||||||
auto status = could_propagate(input_value, order);
|
auto status = ngraph::could_propagate(input_value, order);
|
||||||
if (status) {
|
if (status) {
|
||||||
for (const auto& node : order) {
|
for (const auto& node : order) {
|
||||||
const auto& rt_info = node->get_rt_info();
|
const auto& rt_info = node->get_rt_info();
|
||||||
@ -99,8 +98,8 @@ bool ngraph::pass::ConstantFolding::pre_calculated_values_folding(const std::sha
|
|||||||
|
|
||||||
if (status && input_value.get_tensor().has_and_set_bound()) {
|
if (status && input_value.get_tensor().has_and_set_bound()) {
|
||||||
auto input_node = input_value.get_node_shared_ptr();
|
auto input_node = input_value.get_node_shared_ptr();
|
||||||
auto replacement = std::make_shared<op::Constant>(input_value.get_tensor().get_lower_value());
|
auto replacement = std::make_shared<ngraph::op::Constant>(input_value.get_tensor().get_lower_value());
|
||||||
if (replacement && !ov::is_type<op::Constant>(input_node)) {
|
if (replacement && !ov::is_type<ngraph::op::Constant>(input_node)) {
|
||||||
if (input_node->get_output_size() == 1) {
|
if (input_node->get_output_size() == 1) {
|
||||||
replacement->set_friendly_name(input_node->get_friendly_name());
|
replacement->set_friendly_name(input_node->get_friendly_name());
|
||||||
} else {
|
} else {
|
||||||
|
@ -9,12 +9,11 @@
|
|||||||
#include "transformations/convert_precision.hpp"
|
#include "transformations/convert_precision.hpp"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertFP32ToFP16, "ConvertFP32ToFP16", 0);
|
OPENVINO_RTTI_DEFINITION(ov::pass::ConvertFP32ToFP16, "ConvertFP32ToFP16", 0);
|
||||||
|
|
||||||
bool ngraph::pass::ConvertFP32ToFP16::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
bool ov::pass::ConvertFP32ToFP16::run_on_function(std::shared_ptr<ov::Function> f) {
|
||||||
ngraph::pass::Manager m(get_pass_config());
|
ov::pass::Manager m(get_pass_config());
|
||||||
m.register_pass<ngraph::pass::ConvertPrecision>(precisions_array{{ngraph::element::f32, ngraph::element::f16}});
|
m.register_pass<ngraph::pass::ConvertPrecision>(precisions_array{{ngraph::element::f32, ngraph::element::f16}});
|
||||||
m.run_passes(f);
|
m.run_passes(f);
|
||||||
return false;
|
return false;
|
||||||
|
@ -18,9 +18,6 @@
|
|||||||
#include "ngraph/op/util/sub_graph_base.hpp"
|
#include "ngraph/op/util/sub_graph_base.hpp"
|
||||||
#include "perf_counters.hpp"
|
#include "perf_counters.hpp"
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
using namespace ngraph;
|
|
||||||
|
|
||||||
/* GraphRewrite algorithm:
|
/* GraphRewrite algorithm:
|
||||||
* GraphRewrite processes an input graph in an topological order(i.e. args before users)
|
* GraphRewrite processes an input graph in an topological order(i.e. args before users)
|
||||||
* Given the following graph: Abs2
|
* Given the following graph: Abs2
|
||||||
@ -33,7 +30,7 @@ using namespace ngraph;
|
|||||||
* Note, `Abs2` comes before `Neg3` as `Abs2`'s id = 2 is *less* than `Neg3`'s one (id = 3)
|
* Note, `Abs2` comes before `Neg3` as `Abs2`'s id = 2 is *less* than `Neg3`'s one (id = 3)
|
||||||
* Next, GraphRewrite will invoke matchers passes registered in add_matcher order.
|
* Next, GraphRewrite will invoke matchers passes registered in add_matcher order.
|
||||||
* For example:
|
* For example:
|
||||||
* ngraph::pass::GraphRewrite pass;
|
* ov::pass::GraphRewrite pass;
|
||||||
* pass.add_matcher<m1>();
|
* pass.add_matcher<m1>();
|
||||||
* pass.add_matcher<m2>();
|
* pass.add_matcher<m2>();
|
||||||
* pass.add_matcher<m3>();
|
* pass.add_matcher<m3>();
|
||||||
@ -53,13 +50,13 @@ using namespace ngraph;
|
|||||||
* If MatcherPass register more than one node make sure that this nodes are registered in
|
* If MatcherPass register more than one node make sure that this nodes are registered in
|
||||||
* topological order. */
|
* topological order. */
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::GraphRewrite, "ngraph::pass::GraphRewrite", 0);
|
NGRAPH_RTTI_DEFINITION(ov::pass::GraphRewrite, "ov::pass::GraphRewrite", 0);
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::BackwardGraphRewrite, "ngraph::pass::BackwardGraphRewrite", 0);
|
NGRAPH_RTTI_DEFINITION(ov::pass::BackwardGraphRewrite, "ov::pass::BackwardGraphRewrite", 0);
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::MatcherPass, "ngraph::pass::MatcherPass", 0);
|
NGRAPH_RTTI_DEFINITION(ov::pass::MatcherPass, "ov::pass::MatcherPass", 0);
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ov {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
namespace internal {
|
namespace internal {
|
||||||
PerfCounters& perf_counters_graph_rewrite() {
|
PerfCounters& perf_counters_graph_rewrite() {
|
||||||
@ -68,27 +65,28 @@ PerfCounters& perf_counters_graph_rewrite() {
|
|||||||
}
|
}
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
} // namespace pass
|
} // namespace pass
|
||||||
} // namespace ngraph
|
} // namespace ov
|
||||||
|
|
||||||
bool pass::BackwardGraphRewrite::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
bool ov::pass::BackwardGraphRewrite::run_on_function(std::shared_ptr<ov::Function> f) {
|
||||||
// Initialize execution queue with nodes in topological order
|
// Initialize execution queue with nodes in topological order
|
||||||
deque<std::weak_ptr<Node>> nodes_to_run;
|
std::deque<std::weak_ptr<Node>> nodes_to_run;
|
||||||
for (auto& node : f->get_ordered_ops()) {
|
for (auto& node : f->get_ordered_ops()) {
|
||||||
nodes_to_run.emplace_front(node);
|
nodes_to_run.emplace_front(node);
|
||||||
}
|
}
|
||||||
return apply_matcher_passes(f, std::move(nodes_to_run));
|
return apply_matcher_passes(f, std::move(nodes_to_run));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool pass::GraphRewrite::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
bool ov::pass::GraphRewrite::run_on_function(std::shared_ptr<ov::Function> f) {
|
||||||
// Initialize execution queue with nodes in topological order
|
// Initialize execution queue with nodes in topological order
|
||||||
deque<std::weak_ptr<Node>> nodes_to_run;
|
std::deque<std::weak_ptr<Node>> nodes_to_run;
|
||||||
for (auto& node : f->get_ordered_ops()) {
|
for (auto& node : f->get_ordered_ops()) {
|
||||||
nodes_to_run.emplace_back(node);
|
nodes_to_run.emplace_back(node);
|
||||||
}
|
}
|
||||||
return apply_matcher_passes(f, std::move(nodes_to_run));
|
return apply_matcher_passes(f, std::move(nodes_to_run));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f, deque<std::weak_ptr<Node>> nodes_to_run) {
|
bool ov::pass::GraphRewrite::apply_matcher_passes(std::shared_ptr<Function> f,
|
||||||
|
std::deque<std::weak_ptr<Node>> nodes_to_run) {
|
||||||
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, "pass::GraphRewrite::run_on_function");
|
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, "pass::GraphRewrite::run_on_function");
|
||||||
|
|
||||||
bool rewritten = false;
|
bool rewritten = false;
|
||||||
@ -111,7 +109,7 @@ bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f, deque<std:
|
|||||||
auto root = matcher->get_pattern_value().get_node_shared_ptr();
|
auto root = matcher->get_pattern_value().get_node_shared_ptr();
|
||||||
// pattern::op::AnyOutput operation automatically appends for multi output operations inside
|
// pattern::op::AnyOutput operation automatically appends for multi output operations inside
|
||||||
// Matcher and to gen actual root node we need to take it's parent.
|
// Matcher and to gen actual root node we need to take it's parent.
|
||||||
if (auto any_type = dynamic_pointer_cast<pattern::op::AnyOutput>(root)) {
|
if (auto any_type = std::dynamic_pointer_cast<pattern::op::AnyOutput>(root)) {
|
||||||
root = any_type->input_value(0).get_node_shared_ptr();
|
root = any_type->input_value(0).get_node_shared_ptr();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -119,8 +117,8 @@ bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f, deque<std:
|
|||||||
// it's type
|
// it's type
|
||||||
// and use it in unordered_map as key for fast MatcherPass search. Otherwise type is unknown
|
// and use it in unordered_map as key for fast MatcherPass search. Otherwise type is unknown
|
||||||
// and default algorithm is used.
|
// and default algorithm is used.
|
||||||
if (auto p = dynamic_pointer_cast<pattern::op::Pattern>(root)) {
|
if (auto p = std::dynamic_pointer_cast<pattern::op::Pattern>(root)) {
|
||||||
if (auto any_type = dynamic_pointer_cast<pattern::op::WrapType>(p)) {
|
if (auto any_type = std::dynamic_pointer_cast<pattern::op::WrapType>(p)) {
|
||||||
for (const auto& root_type_info : any_type->get_wrapped_types()) {
|
for (const auto& root_type_info : any_type->get_wrapped_types()) {
|
||||||
type_to_matcher[root_type_info].push_back(matcher_index);
|
type_to_matcher[root_type_info].push_back(matcher_index);
|
||||||
}
|
}
|
||||||
@ -180,7 +178,7 @@ bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f, deque<std:
|
|||||||
continue;
|
continue;
|
||||||
|
|
||||||
// Recursive apply Matchers for sub-graph based nodes
|
// Recursive apply Matchers for sub-graph based nodes
|
||||||
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
|
if (auto sub_graph_node = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node)) {
|
||||||
if (auto sub_graph = sub_graph_node->get_function()) {
|
if (auto sub_graph = sub_graph_node->get_function()) {
|
||||||
run_on_function(sub_graph);
|
run_on_function(sub_graph);
|
||||||
}
|
}
|
||||||
@ -236,9 +234,9 @@ bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f, deque<std:
|
|||||||
return rewritten;
|
return rewritten;
|
||||||
}
|
}
|
||||||
|
|
||||||
void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
|
void ov::pass::GraphRewrite::add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||||
const graph_rewrite_callback& callback,
|
const graph_rewrite_callback& callback,
|
||||||
const PassPropertyMask& property) {
|
const PassPropertyMask& property) {
|
||||||
m_matchers.push_back(std::make_shared<MatcherPass>(
|
m_matchers.push_back(std::make_shared<MatcherPass>(
|
||||||
m->get_name(),
|
m->get_name(),
|
||||||
m,
|
m,
|
||||||
@ -258,7 +256,8 @@ void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
|
|||||||
property));
|
property));
|
||||||
}
|
}
|
||||||
|
|
||||||
void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m, const graph_rewrite_callback& callback) {
|
void ov::pass::GraphRewrite::add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||||
|
const graph_rewrite_callback& callback) {
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||||
// TODO: before deprecate this function, by default expect the
|
// TODO: before deprecate this function, by default expect the
|
||||||
// callback require static shape.
|
// callback require static shape.
|
||||||
@ -266,7 +265,7 @@ void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m, cons
|
|||||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||||
}
|
}
|
||||||
|
|
||||||
void pass::GraphRewrite::set_pass_config(const std::shared_ptr<PassConfig>& rhs) {
|
void ov::pass::GraphRewrite::set_pass_config(const std::shared_ptr<PassConfig>& rhs) {
|
||||||
auto pass_config = get_pass_config();
|
auto pass_config = get_pass_config();
|
||||||
// We have to preserve disabled passes because in case when we register matchers inside
|
// We have to preserve disabled passes because in case when we register matchers inside
|
||||||
// GraphRewrite c-tor we work with local PassConfig instance.
|
// GraphRewrite c-tor we work with local PassConfig instance.
|
||||||
@ -293,9 +292,9 @@ void pass::GraphRewrite::set_pass_config(const std::shared_ptr<PassConfig>& rhs)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
void ov::pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
||||||
const ngraph::recurrent_graph_rewrite_callback& callback,
|
const ov::recurrent_graph_rewrite_callback& callback,
|
||||||
const PassPropertyMask& property) {
|
const PassPropertyMask& property) {
|
||||||
m_matchers.push_back(std::make_shared<MatcherPass>(
|
m_matchers.push_back(std::make_shared<MatcherPass>(
|
||||||
"Recurrent matcher",
|
"Recurrent matcher",
|
||||||
nullptr,
|
nullptr,
|
||||||
@ -310,24 +309,24 @@ void pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::Rec
|
|||||||
property));
|
property));
|
||||||
}
|
}
|
||||||
|
|
||||||
void pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
void ov::pass::RecurrentGraphRewrite::add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
||||||
const ngraph::recurrent_graph_rewrite_callback& callback) {
|
const ov::recurrent_graph_rewrite_callback& callback) {
|
||||||
// TODO: before deprecate this function, by default expect the
|
// TODO: before deprecate this function, by default expect the
|
||||||
// callback require static shape.
|
// callback require static shape.
|
||||||
add_matcher(m, callback, {PassProperty::REQUIRE_STATIC_SHAPE});
|
add_matcher(m, callback, {PassProperty::REQUIRE_STATIC_SHAPE});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f) {
|
bool ov::pass::RecurrentGraphRewrite::run_on_function(std::shared_ptr<Function> f) {
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
|
|
||||||
// This check is very expensive and is only needed for experimental features, so we will hide
|
// This check is very expensive and is only needed for experimental features, so we will hide
|
||||||
// it behind an environment variable for now. TODO: Find a less expensive way to handle this.
|
// it behind an environment variable for now. TODO: Find a less expensive way to handle this.
|
||||||
static bool s_rerun_dynamic_check = getenv_bool("NGRAPH_GRAPH_REWRITE_RERUN_DYNAMIC_CHECK");
|
static bool s_rerun_dynamic_check = ngraph::getenv_bool("NGRAPH_GRAPH_REWRITE_RERUN_DYNAMIC_CHECK");
|
||||||
|
|
||||||
auto run_matchers = [&]() -> bool {
|
auto run_matchers = [&]() -> bool {
|
||||||
bool is_dyn_func = s_rerun_dynamic_check && f->is_dynamic();
|
bool is_dyn_func = s_rerun_dynamic_check && f->is_dynamic();
|
||||||
for (auto node : f->get_ops()) {
|
for (const auto& node : f->get_ops()) {
|
||||||
for (auto& m_pass : m_matchers) {
|
for (auto& m_pass : m_matchers) {
|
||||||
if (is_dyn_func && m_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE)) {
|
if (is_dyn_func && m_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE)) {
|
||||||
NGRAPH_DEBUG << "matcher callback requires static shape but the "
|
NGRAPH_DEBUG << "matcher callback requires static shape but the "
|
||||||
@ -356,9 +355,9 @@ bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f) {
|
|||||||
return changed;
|
return changed;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ngraph::pass::MatcherPass::register_matcher(const std::shared_ptr<ngraph::pattern::Matcher>& m,
|
void ov::pass::MatcherPass::register_matcher(const std::shared_ptr<ov::pass::pattern::Matcher>& m,
|
||||||
const ngraph::graph_rewrite_callback& callback,
|
const ov::graph_rewrite_callback& callback,
|
||||||
const PassPropertyMask& property) {
|
const PassPropertyMask& property) {
|
||||||
set_name(m->get_name());
|
set_name(m->get_name());
|
||||||
set_property(property, true);
|
set_property(property, true);
|
||||||
m_matcher = m;
|
m_matcher = m;
|
||||||
@ -376,7 +375,7 @@ void ngraph::pass::MatcherPass::register_matcher(const std::shared_ptr<ngraph::p
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ngraph::pass::MatcherPass::apply(std::shared_ptr<ngraph::Node> node) {
|
bool ov::pass::MatcherPass::apply(std::shared_ptr<ov::Node> node) {
|
||||||
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, pass::internal::perf_counters_graph_rewrite()[get_type_info()]);
|
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, pass::internal::perf_counters_graph_rewrite()[get_type_info()]);
|
||||||
m_new_nodes.clear();
|
m_new_nodes.clear();
|
||||||
if (m_handler)
|
if (m_handler)
|
||||||
|
@ -12,13 +12,12 @@
|
|||||||
#include <ngraph/rt_info.hpp>
|
#include <ngraph/rt_info.hpp>
|
||||||
#include <ngraph/variant.hpp>
|
#include <ngraph/variant.hpp>
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::LowLatency2, "LowLatency2", 0);
|
NGRAPH_RTTI_DEFINITION(ov::pass::LowLatency2, "LowLatency2", 0);
|
||||||
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::LowLatency, "LowLatency", 0);
|
NGRAPH_RTTI_DEFINITION(ngraph::pass::LowLatency, "LowLatency", 0);
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
string generate_variable_name(const string& op_name, const string& param_name, int variable_idx) {
|
string generate_variable_name(const string& op_name, const string& param_name, int variable_idx) {
|
||||||
@ -27,8 +26,8 @@ string generate_variable_name(const string& op_name, const string& param_name, i
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
ngraph::pass::LowLatency::LowLatency() {
|
ngraph::pass::LowLatency::LowLatency() {
|
||||||
auto tensor_iterator = ngraph::pattern::wrap_type<opset6::TensorIterator, opset6::Loop>();
|
auto tensor_iterator = ov::pass::pattern::wrap_type<opset6::TensorIterator, opset6::Loop>();
|
||||||
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
|
ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
|
||||||
const auto& sub_graph_op = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(m.get_match_root());
|
const auto& sub_graph_op = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(m.get_match_root());
|
||||||
if (!sub_graph_op) {
|
if (!sub_graph_op) {
|
||||||
return false;
|
return false;
|
||||||
@ -38,7 +37,7 @@ ngraph::pass::LowLatency::LowLatency() {
|
|||||||
const auto& trip_count = std::dynamic_pointer_cast<opset6::Constant>(loop->get_input_node_shared_ptr(0));
|
const auto& trip_count = std::dynamic_pointer_cast<opset6::Constant>(loop->get_input_node_shared_ptr(0));
|
||||||
const auto& num_iter = loop->get_num_iterations();
|
const auto& num_iter = loop->get_num_iterations();
|
||||||
if (trip_count && num_iter > 0 && trip_count->get_output_target_inputs(0).size() == 1) {
|
if (trip_count && num_iter > 0 && trip_count->get_output_target_inputs(0).size() == 1) {
|
||||||
auto single_iter = std::make_shared<opset6::Constant>(ngraph::element::i64, Shape{}, 1);
|
auto single_iter = std::make_shared<opset6::Constant>(ov::element::i64, Shape{}, 1);
|
||||||
replace_node(trip_count, single_iter);
|
replace_node(trip_count, single_iter);
|
||||||
} else {
|
} else {
|
||||||
// count of iterations is dynamic;
|
// count of iterations is dynamic;
|
||||||
@ -47,7 +46,7 @@ ngraph::pass::LowLatency::LowLatency() {
|
|||||||
}
|
}
|
||||||
// Mark the TI layer to be unrolled. Enable unconditional ti unrolling for all plugins.
|
// Mark the TI layer to be unrolled. Enable unconditional ti unrolling for all plugins.
|
||||||
auto& rt_info = sub_graph_op->get_rt_info();
|
auto& rt_info = sub_graph_op->get_rt_info();
|
||||||
rt_info["UNROLL_TI"] = std::make_shared<ngraph::VariantWrapper<int64_t>>(1);
|
rt_info["UNROLL_TI"] = std::make_shared<ov::VariantWrapper<int64_t>>(1);
|
||||||
|
|
||||||
int64_t variable_id = 0;
|
int64_t variable_id = 0;
|
||||||
std::vector<std::shared_ptr<ngraph::op::Sink>> assigns;
|
std::vector<std::shared_ptr<ngraph::op::Sink>> assigns;
|
||||||
@ -87,13 +86,14 @@ ngraph::pass::LowLatency::LowLatency() {
|
|||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto m = std::make_shared<ngraph::pattern::Matcher>(tensor_iterator, "LowLatency");
|
auto m = std::make_shared<ov::pass::pattern::Matcher>(tensor_iterator, "LowLatency");
|
||||||
register_matcher(m, callback);
|
register_matcher(m, callback);
|
||||||
}
|
}
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||||
|
|
||||||
void UnrollSingleIteration(const shared_ptr<op::util::SubGraphOp>& sub_graph_op, const shared_ptr<Function>& outer_f) {
|
void UnrollSingleIteration(const shared_ptr<ngraph::op::util::SubGraphOp>& sub_graph_op,
|
||||||
using namespace opset7;
|
const shared_ptr<ov::Function>& outer_f) {
|
||||||
|
using namespace ngraph::opset7;
|
||||||
|
|
||||||
const auto& params = sub_graph_op->get_function()->get_parameters();
|
const auto& params = sub_graph_op->get_function()->get_parameters();
|
||||||
const auto& results = sub_graph_op->get_function()->get_results();
|
const auto& results = sub_graph_op->get_function()->get_results();
|
||||||
@ -109,7 +109,7 @@ void UnrollSingleIteration(const shared_ptr<op::util::SubGraphOp>& sub_graph_op,
|
|||||||
|
|
||||||
// before: TI [...-> Layer1 -> Result -> output] -> Layer2 -> ...
|
// before: TI [...-> Layer1 -> Result -> output] -> Layer2 -> ...
|
||||||
// after: ...-> Layer1 -> Layer2 -> ...
|
// after: ...-> Layer1 -> Layer2 -> ...
|
||||||
NodeVector new_ops;
|
ov::NodeVector new_ops;
|
||||||
for (const auto& out : sub_graph_op->get_output_descriptions()) {
|
for (const auto& out : sub_graph_op->get_output_descriptions()) {
|
||||||
const auto& connect_to = results.at(out->m_body_value_index)->get_input_source_output(0);
|
const auto& connect_to = results.at(out->m_body_value_index)->get_input_source_output(0);
|
||||||
for (auto& input_to : sub_graph_op->output(out->m_output_index).get_target_inputs()) {
|
for (auto& input_to : sub_graph_op->output(out->m_output_index).get_target_inputs()) {
|
||||||
@ -120,7 +120,7 @@ void UnrollSingleIteration(const shared_ptr<op::util::SubGraphOp>& sub_graph_op,
|
|||||||
|
|
||||||
// IECompatibility: insert identity (Unsqueeze + Squeeze) to store the TensorIterator
|
// IECompatibility: insert identity (Unsqueeze + Squeeze) to store the TensorIterator
|
||||||
// output names
|
// output names
|
||||||
auto axis_1 = Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
auto axis_1 = Constant::create(ov::element::i64, ngraph::Shape{1}, {1});
|
||||||
auto identity_1 = std::make_shared<Unsqueeze>(connect_to, axis_1);
|
auto identity_1 = std::make_shared<Unsqueeze>(connect_to, axis_1);
|
||||||
auto identity_2 = std::make_shared<Squeeze>(identity_1, axis_1);
|
auto identity_2 = std::make_shared<Squeeze>(identity_1, axis_1);
|
||||||
identity_2->set_friendly_name(out_name);
|
identity_2->set_friendly_name(out_name);
|
||||||
@ -135,36 +135,38 @@ void UnrollSingleIteration(const shared_ptr<op::util::SubGraphOp>& sub_graph_op,
|
|||||||
ngraph::copy_runtime_info(sub_graph_op, new_ops);
|
ngraph::copy_runtime_info(sub_graph_op, new_ops);
|
||||||
}
|
}
|
||||||
|
|
||||||
Output<Node> create_init_subgraph(const shared_ptr<op::util::SubGraphOp>& sub_graph_op, const Output<Node>& in_node) {
|
ngraph::Output<ngraph::Node> create_init_subgraph(const shared_ptr<ngraph::op::util::SubGraphOp>& sub_graph_op,
|
||||||
using namespace opset7;
|
const ngraph::Output<ngraph::Node>& in_node) {
|
||||||
|
using namespace ngraph::opset7;
|
||||||
|
|
||||||
auto const_zero = make_shared<Constant>(in_node.get_element_type(), Shape{1}, 0);
|
auto const_zero = make_shared<Constant>(in_node.get_element_type(), ngraph::Shape{1}, 0);
|
||||||
auto shape_of = make_shared<ShapeOf>(in_node);
|
auto shape_of = make_shared<ShapeOf>(in_node);
|
||||||
auto broadcast = make_shared<Broadcast>(const_zero, shape_of);
|
auto broadcast = make_shared<Broadcast>(const_zero, shape_of);
|
||||||
copy_runtime_info(sub_graph_op, {const_zero, shape_of, broadcast});
|
copy_runtime_info(sub_graph_op, {const_zero, shape_of, broadcast});
|
||||||
return broadcast->output(0);
|
return broadcast->output(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool pass::LowLatency2::run_on_function(shared_ptr<Function> f) {
|
bool ov::pass::LowLatency2::run_on_function(shared_ptr<Function> f) {
|
||||||
using namespace opset7;
|
using namespace ngraph::opset7;
|
||||||
|
|
||||||
SinkVector assigns;
|
ngraph::SinkVector assigns;
|
||||||
for (const auto& op : f->get_ordered_ops()) {
|
for (const auto& op : f->get_ordered_ops()) {
|
||||||
if (const auto& sub_graph_op = dynamic_pointer_cast<op::util::SubGraphOp>(op)) {
|
if (const auto& sub_graph_op = dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(op)) {
|
||||||
int64_t variable_id = 0;
|
int64_t variable_id = 0;
|
||||||
const auto& func = sub_graph_op->get_function();
|
const auto& func = sub_graph_op->get_function();
|
||||||
const auto& params = func->get_parameters();
|
const auto& params = func->get_parameters();
|
||||||
for (const auto& in : sub_graph_op->get_input_descriptions()) {
|
for (const auto& in : sub_graph_op->get_input_descriptions()) {
|
||||||
// Process all back edges
|
// Process all back edges
|
||||||
if (const auto& merged_in = dynamic_pointer_cast<op::util::SubGraphOp::MergedInputDescription>(in)) {
|
if (const auto& merged_in =
|
||||||
|
dynamic_pointer_cast<ngraph::op::util::SubGraphOp::MergedInputDescription>(in)) {
|
||||||
// create new Variable
|
// create new Variable
|
||||||
const string& param_name = params.at(merged_in->m_body_parameter_index)->get_friendly_name();
|
const string& param_name = params.at(merged_in->m_body_parameter_index)->get_friendly_name();
|
||||||
const string& var_name =
|
const string& var_name =
|
||||||
generate_variable_name(sub_graph_op->get_friendly_name(), param_name, variable_id);
|
generate_variable_name(sub_graph_op->get_friendly_name(), param_name, variable_id);
|
||||||
|
|
||||||
const auto& input = sub_graph_op->input(merged_in->m_input_index);
|
const auto& input = sub_graph_op->input(merged_in->m_input_index);
|
||||||
if (std::dynamic_pointer_cast<op::ReadValueBase>(input.get_source_output().get_node_shared_ptr()) !=
|
if (std::dynamic_pointer_cast<ngraph::op::ReadValueBase>(
|
||||||
nullptr) {
|
input.get_source_output().get_node_shared_ptr()) != nullptr) {
|
||||||
NGRAPH_DEBUG << "LowLatency2 transformation cannot be applied because the "
|
NGRAPH_DEBUG << "LowLatency2 transformation cannot be applied because the "
|
||||||
<< "ReadValue node is already an input to the TensorIterator."
|
<< "ReadValue node is already an input to the TensorIterator."
|
||||||
<< "LowLatency2 transformation may have already been applied, please "
|
<< "LowLatency2 transformation may have already been applied, please "
|
||||||
@ -175,7 +177,7 @@ bool pass::LowLatency2::run_on_function(shared_ptr<Function> f) {
|
|||||||
const auto& param =
|
const auto& param =
|
||||||
sub_graph_op->get_function()->get_parameters().at(merged_in->m_body_parameter_index);
|
sub_graph_op->get_function()->get_parameters().at(merged_in->m_body_parameter_index);
|
||||||
for (const auto& in_to : param->output(0).get_target_inputs()) {
|
for (const auto& in_to : param->output(0).get_target_inputs()) {
|
||||||
if (dynamic_cast<op::ReadValueBase*>(in_to.get_node()) != nullptr) {
|
if (dynamic_cast<ngraph::op::ReadValueBase*>(in_to.get_node()) != nullptr) {
|
||||||
NGRAPH_DEBUG << "LowLatency2 transformation cannot be applied because the "
|
NGRAPH_DEBUG << "LowLatency2 transformation cannot be applied because the "
|
||||||
<< "ReadValue node is already inside the TensorIterator. "
|
<< "ReadValue node is already inside the TensorIterator. "
|
||||||
<< "LowLatency transformation may have been applied, please do "
|
<< "LowLatency transformation may have been applied, please do "
|
||||||
@ -184,8 +186,8 @@ bool pass::LowLatency2::run_on_function(shared_ptr<Function> f) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
VariableInfo var_info{PartialShape::dynamic(), element::dynamic, var_name};
|
ngraph::VariableInfo var_info{PartialShape::dynamic(), element::dynamic, var_name};
|
||||||
auto variable = make_shared<Variable>(var_info);
|
auto variable = make_shared<ngraph::Variable>(var_info);
|
||||||
|
|
||||||
// insert ReadValue
|
// insert ReadValue
|
||||||
// Layers -> [new op: ReadValue] -> Subgraph operation
|
// Layers -> [new op: ReadValue] -> Subgraph operation
|
||||||
@ -204,12 +206,12 @@ bool pass::LowLatency2::run_on_function(shared_ptr<Function> f) {
|
|||||||
// ---> Layers -> ...
|
// ---> Layers -> ...
|
||||||
*/
|
*/
|
||||||
const auto& out_desc = sub_graph_op->get_output_descriptions();
|
const auto& out_desc = sub_graph_op->get_output_descriptions();
|
||||||
bool is_output_exist =
|
bool is_output_exist = std::any_of(
|
||||||
std::any_of(out_desc.begin(),
|
out_desc.begin(),
|
||||||
out_desc.end(),
|
out_desc.end(),
|
||||||
[&merged_in](const std::shared_ptr<op::util::SubGraphOp::OutputDescription>& out) {
|
[&merged_in](const std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>& out) {
|
||||||
return out->m_body_value_index == merged_in->m_body_value_index;
|
return out->m_body_value_index == merged_in->m_body_value_index;
|
||||||
});
|
});
|
||||||
// Create new output if it doesn't exist.
|
// Create new output if it doesn't exist.
|
||||||
if (!is_output_exist) {
|
if (!is_output_exist) {
|
||||||
sub_graph_op->get_iter_value(func->get_results().at(merged_in->m_body_value_index));
|
sub_graph_op->get_iter_value(func->get_results().at(merged_in->m_body_value_index));
|
||||||
@ -217,7 +219,7 @@ bool pass::LowLatency2::run_on_function(shared_ptr<Function> f) {
|
|||||||
for (const auto& out : sub_graph_op->get_output_descriptions()) {
|
for (const auto& out : sub_graph_op->get_output_descriptions()) {
|
||||||
if (out->m_body_value_index == merged_in->m_body_value_index) {
|
if (out->m_body_value_index == merged_in->m_body_value_index) {
|
||||||
auto assign = make_shared<Assign>(sub_graph_op->output(out->m_output_index), variable);
|
auto assign = make_shared<Assign>(sub_graph_op->output(out->m_output_index), variable);
|
||||||
ngraph::copy_runtime_info(sub_graph_op, assign);
|
copy_runtime_info(sub_graph_op, assign);
|
||||||
// control dependency so that ReadValue is processed before Assign
|
// control dependency so that ReadValue is processed before Assign
|
||||||
assign->add_control_dependency(read_value);
|
assign->add_control_dependency(read_value);
|
||||||
assigns.emplace_back(assign);
|
assigns.emplace_back(assign);
|
||||||
|
@ -24,9 +24,8 @@
|
|||||||
#include "perf_counters.hpp"
|
#include "perf_counters.hpp"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ov {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
namespace internal {
|
namespace internal {
|
||||||
PerfCounters& perf_counters() {
|
PerfCounters& perf_counters() {
|
||||||
@ -35,25 +34,25 @@ PerfCounters& perf_counters() {
|
|||||||
}
|
}
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
} // namespace pass
|
} // namespace pass
|
||||||
} // namespace ngraph
|
} // namespace ov
|
||||||
|
|
||||||
pass::Manager::Manager()
|
ov::pass::Manager::Manager()
|
||||||
: m_pass_config(std::make_shared<PassConfig>()),
|
: m_pass_config(std::make_shared<PassConfig>()),
|
||||||
m_visualize(getenv_bool("NGRAPH_ENABLE_VISUALIZE_TRACING")) {}
|
m_visualize(ngraph::getenv_bool("NGRAPH_ENABLE_VISUALIZE_TRACING")) {}
|
||||||
|
|
||||||
pass::Manager::~Manager() {}
|
ov::pass::Manager::~Manager() = default;
|
||||||
|
|
||||||
pass::Manager::Manager(std::shared_ptr<ngraph::pass::PassConfig> pass_config) : m_pass_config(std::move(pass_config)) {}
|
ov::pass::Manager::Manager(std::shared_ptr<ov::pass::PassConfig> pass_config) : m_pass_config(std::move(pass_config)) {}
|
||||||
|
|
||||||
void pass::Manager::run_passes(shared_ptr<Function> func) {
|
void ov::pass::Manager::run_passes(shared_ptr<ov::Function> func) {
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||||
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, "pass::Manager::run_passes");
|
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, "pass::Manager::run_passes");
|
||||||
|
|
||||||
static bool profile_enabled = getenv_bool("NGRAPH_PROFILE_PASS_ENABLE");
|
static bool profile_enabled = ngraph::getenv_bool("NGRAPH_PROFILE_PASS_ENABLE");
|
||||||
|
|
||||||
size_t index = 0;
|
size_t index = 0;
|
||||||
stopwatch pass_timer;
|
ngraph::stopwatch pass_timer;
|
||||||
stopwatch overall_timer;
|
ngraph::stopwatch overall_timer;
|
||||||
overall_timer.start();
|
overall_timer.start();
|
||||||
bool function_changed = false;
|
bool function_changed = false;
|
||||||
for (auto& pass : m_pass_list) {
|
for (auto& pass : m_pass_list) {
|
||||||
@ -96,13 +95,13 @@ void pass::Manager::run_passes(shared_ptr<Function> func) {
|
|||||||
} else {
|
} else {
|
||||||
function_changed = function_pass->run_on_function(func);
|
function_changed = function_pass->run_on_function(func);
|
||||||
}
|
}
|
||||||
} else if (auto node_pass = dynamic_pointer_cast<NodePass>(pass)) {
|
} else if (auto node_pass = dynamic_pointer_cast<ngraph::pass::NodePass>(pass)) {
|
||||||
if (node_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && func->is_dynamic()) {
|
if (node_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && func->is_dynamic()) {
|
||||||
NGRAPH_DEBUG << "Pass " << pass->get_name() << " requires static shape but the "
|
NGRAPH_DEBUG << "Pass " << pass->get_name() << " requires static shape but the "
|
||||||
<< "function is dynamic. Skipping this transformation";
|
<< "function is dynamic. Skipping this transformation";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
for (shared_ptr<Node> n : func->get_ops()) {
|
for (const shared_ptr<Node>& n : func->get_ops()) {
|
||||||
function_changed |= node_pass->run_on_node(n);
|
function_changed |= node_pass->run_on_node(n);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -115,7 +114,7 @@ void pass::Manager::run_passes(shared_ptr<Function> func) {
|
|||||||
auto base_filename = func->get_name() + std::string("_") + index_str + std::string("_") + pass->get_name();
|
auto base_filename = func->get_name() + std::string("_") + index_str + std::string("_") + pass->get_name();
|
||||||
|
|
||||||
if (m_visualize) {
|
if (m_visualize) {
|
||||||
static const string format = getenv_string("NGRAPH_VISUALIZE_TRACING_FORMAT");
|
static const string format = ngraph::getenv_string("NGRAPH_VISUALIZE_TRACING_FORMAT");
|
||||||
auto file_ext = format.empty() ? "svg" : format;
|
auto file_ext = format.empty() ? "svg" : format;
|
||||||
pass::VisualizeTree vt(base_filename + std::string(".") + file_ext);
|
pass::VisualizeTree vt(base_filename + std::string(".") + file_ext);
|
||||||
vt.run_on_function(func);
|
vt.run_on_function(func);
|
||||||
|
@ -7,21 +7,20 @@
|
|||||||
# include <cxxabi.h>
|
# include <cxxabi.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "ngraph/pass/manager.hpp"
|
|
||||||
#include "ngraph/pass/pass.hpp"
|
#include "ngraph/pass/pass.hpp"
|
||||||
|
#include "openvino/pass/manager.hpp"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::FunctionPass, "ngraph::pass::FunctionPass", 0);
|
OPENVINO_RTTI_DEFINITION(ov::pass::FunctionPass, "ov::pass::FunctionPass", 0);
|
||||||
|
|
||||||
pass::PassBase::PassBase() : m_property{all_pass_property_off}, m_pass_config(std::make_shared<PassConfig>()) {}
|
ov::pass::PassBase::PassBase() : m_property(), m_pass_config(std::make_shared<PassConfig>()) {}
|
||||||
|
|
||||||
bool pass::PassBase::get_property(const PassPropertyMask& prop) const {
|
bool ov::pass::PassBase::get_property(const PassPropertyMask& prop) const {
|
||||||
return m_property.is_set(prop);
|
return m_property.is_set(prop);
|
||||||
}
|
}
|
||||||
|
|
||||||
void pass::PassBase::set_property(const PassPropertyMask& prop, bool value) {
|
void ov::pass::PassBase::set_property(const PassPropertyMask& prop, bool value) {
|
||||||
if (value) {
|
if (value) {
|
||||||
m_property.set(prop);
|
m_property.set(prop);
|
||||||
} else {
|
} else {
|
||||||
@ -29,7 +28,7 @@ void pass::PassBase::set_property(const PassPropertyMask& prop, bool value) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string pass::PassBase::get_name() const {
|
std::string ov::pass::PassBase::get_name() const {
|
||||||
if (m_name.empty()) {
|
if (m_name.empty()) {
|
||||||
const PassBase* p = this;
|
const PassBase* p = this;
|
||||||
std::string pass_name = typeid(*p).name();
|
std::string pass_name = typeid(*p).name();
|
||||||
@ -43,16 +42,16 @@ std::string pass::PassBase::get_name() const {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void pass::PassBase::set_callback(const param_callback& callback) {
|
void ov::pass::PassBase::set_callback(const param_callback& callback) {
|
||||||
m_pass_config->set_callback(callback);
|
m_pass_config->set_callback(callback);
|
||||||
}
|
}
|
||||||
|
|
||||||
// The symbols are requiered to be in cpp file to workaround RTTI issue on Android LLVM
|
// The symbols are requiered to be in cpp file to workaround RTTI issue on Android LLVM
|
||||||
|
|
||||||
pass::FunctionPass::~FunctionPass() {}
|
ov::pass::FunctionPass::~FunctionPass() = default;
|
||||||
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::NodePass, "ngraph::pass::NodePass", 0);
|
OPENVINO_RTTI_DEFINITION(ngraph::pass::NodePass, "ngraph::pass::NodePass", 0);
|
||||||
|
|
||||||
pass::NodePass::~NodePass() {}
|
ngraph::pass::NodePass::~NodePass() = default;
|
||||||
|
@ -2,11 +2,9 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
#include "ngraph/pass/pass_config.hpp"
|
#include "openvino/pass/pass_config.hpp"
|
||||||
|
|
||||||
using namespace ngraph;
|
ov::pass::param_callback ov::pass::PassConfig::get_callback(const DiscreteTypeInfo& type_info) const {
|
||||||
|
|
||||||
pass::param_callback pass::PassConfig::get_callback(const DiscreteTypeInfo& type_info) const {
|
|
||||||
const auto& it = m_callback_map.find(type_info);
|
const auto& it = m_callback_map.find(type_info);
|
||||||
if (it != m_callback_map.end()) {
|
if (it != m_callback_map.end()) {
|
||||||
return it->second;
|
return it->second;
|
||||||
@ -15,17 +13,17 @@ pass::param_callback pass::PassConfig::get_callback(const DiscreteTypeInfo& type
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void pass::PassConfig::enable(const ngraph::DiscreteTypeInfo& type_info) {
|
void ov::pass::PassConfig::enable(const ngraph::DiscreteTypeInfo& type_info) {
|
||||||
m_disabled.erase(type_info);
|
m_disabled.erase(type_info);
|
||||||
m_enabled.insert(type_info);
|
m_enabled.insert(type_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
void pass::PassConfig::disable(const ngraph::DiscreteTypeInfo& type_info) {
|
void ov::pass::PassConfig::disable(const ngraph::DiscreteTypeInfo& type_info) {
|
||||||
m_enabled.erase(type_info);
|
m_enabled.erase(type_info);
|
||||||
m_disabled.insert(type_info);
|
m_disabled.insert(type_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
void pass::PassConfig::add_disabled_passes(const PassConfig& rhs) {
|
void ov::pass::PassConfig::add_disabled_passes(const PassConfig& rhs) {
|
||||||
for (const auto& pass : rhs.m_disabled) {
|
for (const auto& pass : rhs.m_disabled) {
|
||||||
if (is_enabled(pass))
|
if (is_enabled(pass))
|
||||||
continue;
|
continue;
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
//
|
//
|
||||||
#include "perf_counters.hpp"
|
#include "perf_counters.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ov {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
openvino::itt::handle_t PerfCounters::operator[](::ngraph::Node::type_info_t const& type_inf) {
|
openvino::itt::handle_t PerfCounters::operator[](::ngraph::Node::type_info_t const& type_inf) {
|
||||||
std::lock_guard<std::mutex> guard(m_mutex);
|
std::lock_guard<std::mutex> guard(m_mutex);
|
||||||
@ -13,4 +13,4 @@ openvino::itt::handle_t PerfCounters::operator[](::ngraph::Node::type_info_t con
|
|||||||
return m_counters[&type_inf] = openvino::itt::handle(type_inf.name);
|
return m_counters[&type_inf] = openvino::itt::handle(type_inf.name);
|
||||||
}
|
}
|
||||||
} // namespace pass
|
} // namespace pass
|
||||||
} // namespace ngraph
|
} // namespace ov
|
||||||
|
@ -7,7 +7,7 @@
|
|||||||
#include <ngraph/node.hpp>
|
#include <ngraph/node.hpp>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ov {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
class PerfCounters {
|
class PerfCounters {
|
||||||
PerfCounters(PerfCounters const&) = delete;
|
PerfCounters(PerfCounters const&) = delete;
|
||||||
@ -27,4 +27,4 @@ private:
|
|||||||
counters_map m_counters;
|
counters_map m_counters;
|
||||||
};
|
};
|
||||||
} // namespace pass
|
} // namespace pass
|
||||||
} // namespace ngraph
|
} // namespace ov
|
||||||
|
@ -2,16 +2,16 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
#include "ngraph/pass/validate.hpp"
|
#include "openvino/pass/validate.hpp"
|
||||||
|
|
||||||
#include "itt.hpp"
|
#include "itt.hpp"
|
||||||
#include "ngraph/graph_util.hpp"
|
#include "ngraph/graph_util.hpp"
|
||||||
|
|
||||||
using namespace ngraph;
|
using namespace ngraph;
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::Validate, "ngraph::pass::Validate", 0);
|
OPENVINO_RTTI_DEFINITION(ov::pass::Validate, "ov::pass::Validate", 0);
|
||||||
|
|
||||||
bool pass::Validate::run_on_function(std::shared_ptr<Function> f) {
|
bool ov::pass::Validate::run_on_function(std::shared_ptr<Function> f) {
|
||||||
f->validate_nodes_and_infer_types();
|
f->validate_nodes_and_infer_types();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -13,7 +13,8 @@
|
|||||||
#include "ngraph/op/parameter.hpp"
|
#include "ngraph/op/parameter.hpp"
|
||||||
#include "ngraph/op/util/op_types.hpp"
|
#include "ngraph/op/util/op_types.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
namespace pattern {
|
namespace pattern {
|
||||||
MatcherState::MatcherState(Matcher* matcher)
|
MatcherState::MatcherState(Matcher* matcher)
|
||||||
: m_matcher(matcher),
|
: m_matcher(matcher),
|
||||||
@ -88,7 +89,7 @@ bool Matcher::is_contained_match(const NodeVector& exclusions, bool ignore_unuse
|
|||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||||
if (exclusions.empty()) {
|
if (exclusions.empty()) {
|
||||||
NodeVector label_exclusions;
|
NodeVector label_exclusions;
|
||||||
for (auto entry : m_pattern_map) {
|
for (const auto& entry : m_pattern_map) {
|
||||||
// leaf label
|
// leaf label
|
||||||
if (entry.first->get_input_size() == 0) {
|
if (entry.first->get_input_size() == 0) {
|
||||||
label_exclusions.push_back(entry.second.get_node_shared_ptr());
|
label_exclusions.push_back(entry.second.get_node_shared_ptr());
|
||||||
@ -108,7 +109,7 @@ bool Matcher::match_value(const ngraph::Output<Node>& pattern_value, const ngrap
|
|||||||
// This env var allows one to specify node name patterns to abort pattern matching
|
// This env var allows one to specify node name patterns to abort pattern matching
|
||||||
// at particular nodes. The upshot is that one can quickly zero in on an offending
|
// at particular nodes. The upshot is that one can quickly zero in on an offending
|
||||||
// fusion by disabling individual fusions or optimizations that use Matcher.
|
// fusion by disabling individual fusions or optimizations that use Matcher.
|
||||||
static const std::string node_skip_cregex = getenv_string("NGRAPH_FAIL_MATCH_AT");
|
static const std::string node_skip_cregex = ngraph::getenv_string("NGRAPH_FAIL_MATCH_AT");
|
||||||
if (!node_skip_cregex.empty()) {
|
if (!node_skip_cregex.empty()) {
|
||||||
static const std::regex node_skip_regex(node_skip_cregex);
|
static const std::regex node_skip_regex(node_skip_cregex);
|
||||||
if (std::regex_match(graph_node->get_name(), node_skip_regex)) {
|
if (std::regex_match(graph_node->get_name(), node_skip_regex)) {
|
||||||
@ -201,7 +202,7 @@ void Matcher::clear_state() {
|
|||||||
namespace {
|
namespace {
|
||||||
std::set<std::shared_ptr<Node>> as_node_set(const std::set<std::shared_ptr<op::Label>>& label_set) {
|
std::set<std::shared_ptr<Node>> as_node_set(const std::set<std::shared_ptr<op::Label>>& label_set) {
|
||||||
std::set<std::shared_ptr<Node>> result;
|
std::set<std::shared_ptr<Node>> result;
|
||||||
for (auto label : label_set) {
|
for (const auto& label : label_set) {
|
||||||
result.insert(label);
|
result.insert(label);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
@ -230,7 +231,7 @@ bool RecurrentMatcher::match(Output<Node> graph) {
|
|||||||
graph = m.get_pattern_value_map()[m_recurrent_pattern];
|
graph = m.get_pattern_value_map()[m_recurrent_pattern];
|
||||||
|
|
||||||
// copy bound nodes for the current pattern graph into a global matches map
|
// copy bound nodes for the current pattern graph into a global matches map
|
||||||
for (auto cur_match : m.get_pattern_value_map()) {
|
for (const auto& cur_match : m.get_pattern_value_map()) {
|
||||||
m_matches[cur_match.first].push_back(cur_match.second);
|
m_matches[cur_match.first].push_back(cur_match.second);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -238,7 +239,7 @@ bool RecurrentMatcher::match(Output<Node> graph) {
|
|||||||
// from the current match. Only bound nodes whose labels are in
|
// from the current match. Only bound nodes whose labels are in
|
||||||
// correlated_patterns are pre-populated. Skip other labels are
|
// correlated_patterns are pre-populated. Skip other labels are
|
||||||
// unbounded by default
|
// unbounded by default
|
||||||
for (auto cor_pat : m_correlated_patterns) {
|
for (const auto& cor_pat : m_correlated_patterns) {
|
||||||
previous_matches[cor_pat] = m.get_pattern_value_map()[cor_pat];
|
previous_matches[cor_pat] = m.get_pattern_value_map()[cor_pat];
|
||||||
}
|
}
|
||||||
m = m_repeat;
|
m = m_repeat;
|
||||||
@ -251,4 +252,5 @@ bool RecurrentMatcher::match(Output<Node> graph) {
|
|||||||
return matched;
|
return matched;
|
||||||
}
|
}
|
||||||
} // namespace pattern
|
} // namespace pattern
|
||||||
} // namespace ngraph
|
} // namespace pass
|
||||||
|
} // namespace ov
|
||||||
|
@ -7,17 +7,16 @@
|
|||||||
#include "ngraph/pattern/matcher.hpp"
|
#include "ngraph/pattern/matcher.hpp"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
|
||||||
|
|
||||||
constexpr NodeTypeInfo pattern::op::Any::type_info;
|
constexpr ov::NodeTypeInfo ov::pass::pattern::op::Any::type_info;
|
||||||
|
|
||||||
const NodeTypeInfo& pattern::op::Any::get_type_info() const {
|
const ov::NodeTypeInfo& ov::pass::pattern::op::Any::get_type_info() const {
|
||||||
return type_info;
|
return type_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool pattern::op::Any::match_value(Matcher* matcher,
|
bool ov::pass::pattern::op::Any::match_value(Matcher* matcher,
|
||||||
const Output<Node>& pattern_value,
|
const Output<Node>& pattern_value,
|
||||||
const Output<Node>& graph_value) {
|
const Output<Node>& graph_value) {
|
||||||
matcher->add_node(graph_value);
|
matcher->add_node(graph_value);
|
||||||
return m_predicate(graph_value) &&
|
return m_predicate(graph_value) &&
|
||||||
matcher->match_arguments(pattern_value.get_node(), graph_value.get_node_shared_ptr());
|
matcher->match_arguments(pattern_value.get_node(), graph_value.get_node_shared_ptr());
|
||||||
|
@ -7,20 +7,19 @@
|
|||||||
#include "ngraph/pattern/matcher.hpp"
|
#include "ngraph/pattern/matcher.hpp"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
|
||||||
|
|
||||||
constexpr NodeTypeInfo pattern::op::AnyOf::type_info;
|
constexpr ov::NodeTypeInfo ov::pass::pattern::op::AnyOf::type_info;
|
||||||
|
|
||||||
const NodeTypeInfo& pattern::op::AnyOf::get_type_info() const {
|
const ov::NodeTypeInfo& ov::pass::pattern::op::AnyOf::get_type_info() const {
|
||||||
return type_info;
|
return type_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool pattern::op::AnyOf::match_value(Matcher* matcher,
|
bool ov::pass::pattern::op::AnyOf::match_value(Matcher* matcher,
|
||||||
const Output<Node>& pattern_value,
|
const Output<Node>& pattern_value,
|
||||||
const Output<Node>& graph_value) {
|
const Output<Node>& graph_value) {
|
||||||
matcher->add_node(graph_value);
|
matcher->add_node(graph_value);
|
||||||
return m_predicate(graph_value) && ([&]() {
|
return m_predicate(graph_value) && ([&]() {
|
||||||
for (auto arg : graph_value.get_node_shared_ptr()->input_values()) {
|
for (const auto& arg : graph_value.get_node_shared_ptr()->input_values()) {
|
||||||
auto saved = matcher->start_match();
|
auto saved = matcher->start_match();
|
||||||
if (matcher->match_value(input_value(0), arg)) {
|
if (matcher->match_value(input_value(0), arg)) {
|
||||||
return saved.finish(true);
|
return saved.finish(true);
|
||||||
|
@ -7,16 +7,15 @@
|
|||||||
#include "ngraph/pattern/matcher.hpp"
|
#include "ngraph/pattern/matcher.hpp"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
|
||||||
|
|
||||||
constexpr NodeTypeInfo pattern::op::AnyOutput::type_info;
|
constexpr ov::NodeTypeInfo ov::pass::pattern::op::AnyOutput::type_info;
|
||||||
|
|
||||||
const NodeTypeInfo& pattern::op::AnyOutput::get_type_info() const {
|
const ov::NodeTypeInfo& ov::pass::pattern::op::AnyOutput::get_type_info() const {
|
||||||
return type_info;
|
return type_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool pattern::op::AnyOutput::match_value(Matcher* matcher,
|
bool ov::pass::pattern::op::AnyOutput::match_value(Matcher* matcher,
|
||||||
const Output<Node>& pattern_value,
|
const Output<Node>& pattern_value,
|
||||||
const Output<Node>& graph_value) {
|
const Output<Node>& graph_value) {
|
||||||
return input_value(0).get_node()->match_node(matcher, graph_value);
|
return input_value(0).get_node()->match_node(matcher, graph_value);
|
||||||
}
|
}
|
||||||
|
@ -9,15 +9,14 @@
|
|||||||
#include "ngraph/pattern/op/true.hpp"
|
#include "ngraph/pattern/op/true.hpp"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
|
||||||
|
|
||||||
constexpr NodeTypeInfo pattern::op::Label::type_info;
|
constexpr ov::NodeTypeInfo ov::pass::pattern::op::Label::type_info;
|
||||||
|
|
||||||
const NodeTypeInfo& pattern::op::Label::get_type_info() const {
|
const ov::NodeTypeInfo& ov::pass::pattern::op::Label::get_type_info() const {
|
||||||
return type_info;
|
return type_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
Output<Node> pattern::op::Label::wrap_values(const OutputVector& wrapped_values) {
|
ov::Output<ov::Node> ov::pass::pattern::op::Label::wrap_values(const ov::OutputVector& wrapped_values) {
|
||||||
switch (wrapped_values.size()) {
|
switch (wrapped_values.size()) {
|
||||||
case 0:
|
case 0:
|
||||||
return make_shared<pattern::op::True>()->output(0);
|
return make_shared<pattern::op::True>()->output(0);
|
||||||
@ -28,9 +27,9 @@ Output<Node> pattern::op::Label::wrap_values(const OutputVector& wrapped_values)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool pattern::op::Label::match_value(Matcher* matcher,
|
bool ov::pass::pattern::op::Label::match_value(ov::pass::pattern::Matcher* matcher,
|
||||||
const Output<Node>& pattern_value,
|
const ov::Output<ov::Node>& pattern_value,
|
||||||
const Output<Node>& graph_value) {
|
const ov::Output<ov::Node>& graph_value) {
|
||||||
if (m_predicate(graph_value)) {
|
if (m_predicate(graph_value)) {
|
||||||
auto& pattern_map = matcher->get_pattern_value_map();
|
auto& pattern_map = matcher->get_pattern_value_map();
|
||||||
auto saved = matcher->start_match();
|
auto saved = matcher->start_match();
|
||||||
@ -45,10 +44,10 @@ bool pattern::op::Label::match_value(Matcher* matcher,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Node> pattern::any_input() {
|
std::shared_ptr<ov::Node> ov::pass::pattern::any_input() {
|
||||||
return std::make_shared<pattern::op::Label>();
|
return std::make_shared<pattern::op::Label>();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<Node> pattern::any_input(const pattern::op::ValuePredicate& pred) {
|
std::shared_ptr<ov::Node> ov::pass::pattern::any_input(const ov::pass::pattern::op::ValuePredicate& pred) {
|
||||||
return std::make_shared<pattern::op::Label>(element::dynamic, PartialShape::dynamic(), pred);
|
return std::make_shared<pattern::op::Label>(element::dynamic, PartialShape::dynamic(), pred);
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,8 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <regex>
|
#include <regex>
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
namespace pattern {
|
namespace pattern {
|
||||||
namespace op {
|
namespace op {
|
||||||
// The symbols are required to be in cpp file to workaround RTTI issue on Android LLVM
|
// The symbols are required to be in cpp file to workaround RTTI issue on Android LLVM
|
||||||
@ -101,4 +102,5 @@ std::function<bool(Output<Node>)> type_matches_any(const std::vector<element::Ty
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
} // namespace pattern
|
} // namespace pattern
|
||||||
} // namespace ngraph
|
} // namespace pass
|
||||||
|
} // namespace ov
|
||||||
|
@ -2,9 +2,10 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
|
#include "dyn_elimination.hpp"
|
||||||
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
#include "dyn_elimination.hpp"
|
|
||||||
#include "ngraph/builder/reshape.hpp"
|
#include "ngraph/builder/reshape.hpp"
|
||||||
#include "ngraph/op/broadcast.hpp"
|
#include "ngraph/op/broadcast.hpp"
|
||||||
#include "ngraph/op/range.hpp"
|
#include "ngraph/op/range.hpp"
|
||||||
@ -19,9 +20,7 @@ NGRAPH_SUPPRESS_DEPRECATED_START
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
using namespace ngraph;
|
||||||
|
|
||||||
pass::DynElimination::DynElimination()
|
pass::DynElimination::DynElimination() : GraphRewrite() {
|
||||||
: GraphRewrite()
|
|
||||||
{
|
|
||||||
construct_range();
|
construct_range();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -29,28 +28,22 @@ template <typename T>
|
|||||||
std::shared_ptr<op::Constant> make_range_replacement(const element::Type& et,
|
std::shared_ptr<op::Constant> make_range_replacement(const element::Type& et,
|
||||||
const Shape& shape,
|
const Shape& shape,
|
||||||
const std::shared_ptr<op::Constant>& start_arg,
|
const std::shared_ptr<op::Constant>& start_arg,
|
||||||
const std::shared_ptr<op::Constant>& step_arg)
|
const std::shared_ptr<op::Constant>& step_arg) {
|
||||||
{
|
|
||||||
std::vector<T> elements(shape_size(shape));
|
std::vector<T> elements(shape_size(shape));
|
||||||
std::vector<T> start_vec = start_arg->get_vector<T>();
|
std::vector<T> start_vec = start_arg->get_vector<T>();
|
||||||
std::vector<T> step_vec = step_arg->get_vector<T>();
|
std::vector<T> step_vec = step_arg->get_vector<T>();
|
||||||
|
|
||||||
NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1);
|
NGRAPH_CHECK(start_vec.size() == 1 && step_vec.size() == 1);
|
||||||
|
|
||||||
runtime::reference::range<T>(
|
runtime::reference::range<T>(start_vec.data(), step_vec.data(), shape_size(shape), elements.data());
|
||||||
start_vec.data(), step_vec.data(), shape_size(shape), elements.data());
|
|
||||||
|
|
||||||
return make_shared<op::Constant>(et, shape, elements);
|
return make_shared<op::Constant>(et, shape, elements);
|
||||||
}
|
}
|
||||||
|
|
||||||
void pass::DynElimination::construct_range()
|
void pass::DynElimination::construct_range() {
|
||||||
{
|
auto start_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
|
||||||
auto start_arg_label =
|
auto stop_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
|
||||||
make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
|
auto step_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
|
||||||
auto stop_arg_label =
|
|
||||||
make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
|
|
||||||
auto step_arg_label =
|
|
||||||
make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<op::Constant>());
|
|
||||||
|
|
||||||
auto range_pat = make_shared<op::Range>(start_arg_label, stop_arg_label, step_arg_label);
|
auto range_pat = make_shared<op::Range>(start_arg_label, stop_arg_label, step_arg_label);
|
||||||
|
|
||||||
@ -70,12 +63,11 @@ void pass::DynElimination::construct_range()
|
|||||||
std::shared_ptr<op::Constant> replacement;
|
std::shared_ptr<op::Constant> replacement;
|
||||||
|
|
||||||
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
||||||
#pragma GCC diagnostic push
|
# pragma GCC diagnostic push
|
||||||
#pragma GCC diagnostic error "-Wswitch"
|
# pragma GCC diagnostic error "-Wswitch"
|
||||||
#pragma GCC diagnostic error "-Wswitch-enum"
|
# pragma GCC diagnostic error "-Wswitch-enum"
|
||||||
#endif
|
#endif
|
||||||
switch (et)
|
switch (et) {
|
||||||
{
|
|
||||||
case element::Type_t::bf16:
|
case element::Type_t::bf16:
|
||||||
replacement = make_range_replacement<bfloat16>(et, shape, start_arg, step_arg);
|
replacement = make_range_replacement<bfloat16>(et, shape, start_arg, step_arg);
|
||||||
break;
|
break;
|
||||||
@ -122,7 +114,7 @@ void pass::DynElimination::construct_range()
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
||||||
#pragma GCC diagnostic pop
|
# pragma GCC diagnostic pop
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
replace_node(range_node, replacement);
|
replace_node(range_node, replacement);
|
||||||
|
Loading…
Reference in New Issue
Block a user