Fixed headers for doxygen genration (#3746)
This commit is contained in:
parent
e82257d021
commit
72cd81305c
@ -37,6 +37,5 @@ namespace ngraph
|
||||
void copy_runtime_info_to_target_inputs(const std::shared_ptr<Node>& node,
|
||||
const Output<Node>& replacement);
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -22,24 +22,21 @@ namespace ngraph
|
||||
{
|
||||
namespace pass
|
||||
{
|
||||
class ConvertFP32ToFP16;
|
||||
class NGRAPH_API ConvertFP32ToFP16 : public ngraph::pass::GraphRewrite
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertFP32ToFP16()
|
||||
: GraphRewrite()
|
||||
{
|
||||
convert_constants_precision();
|
||||
convert_parameters_precision();
|
||||
}
|
||||
|
||||
private:
|
||||
void convert_constants_precision();
|
||||
|
||||
void convert_parameters_precision();
|
||||
};
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class NGRAPH_API ngraph::pass::ConvertFP32ToFP16 : public ngraph::pass::GraphRewrite
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertFP32ToFP16()
|
||||
: GraphRewrite()
|
||||
{
|
||||
convert_constants_precision();
|
||||
convert_parameters_precision();
|
||||
}
|
||||
|
||||
private:
|
||||
void convert_constants_precision();
|
||||
|
||||
void convert_parameters_precision();
|
||||
};
|
||||
|
@ -25,178 +25,194 @@
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace pass
|
||||
{
|
||||
class GraphRewrite;
|
||||
class RecurrentGraphRewrite;
|
||||
class MatcherPass;
|
||||
}
|
||||
|
||||
using matcher_pass_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
|
||||
using graph_rewrite_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
|
||||
using recurrent_graph_rewrite_callback =
|
||||
std::function<bool(ngraph::pattern::RecurrentMatcher& m)>;
|
||||
using handler_callback = std::function<bool(const std::shared_ptr<Node>& node)>;
|
||||
}
|
||||
|
||||
/// \brief MatcherPass is a basic block for pattern based transformations. It describes pattern and
|
||||
/// action that is applied if pattern is matched.
|
||||
///
|
||||
/// MatcherPass consists of Matcher and matcher_pass_callback that needs to be implemented and
|
||||
/// finally registered by using \sa register_matcher. MatcherPass can be executed on node within
|
||||
/// \sa apply method. To run matcher pass on Function use GraphRewrite.
|
||||
/// In addition MatcherPass provides a way for adding new operations into GraphRewrite execution
|
||||
/// queue. That means that operations that were created inside transformation callback can be added
|
||||
/// for matching. To register node use \sa register_new_node method. GraphRewrite automatically
|
||||
/// takes registered nodes and put them to execution queue. If multiple nodes were register make
|
||||
/// sure that they were registered in topological order.
|
||||
/// Note: when implementing pattern for Matcher make sure that root node is an operation from opset
|
||||
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher passes more
|
||||
/// efficient.
|
||||
|
||||
class NGRAPH_API ngraph::pass::MatcherPass : public ngraph::pass::PassBase
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
MatcherPass() = default;
|
||||
|
||||
MatcherPass(const MatcherPass&) = delete;
|
||||
MatcherPass& operator=(const MatcherPass&) = delete;
|
||||
|
||||
explicit MatcherPass(const std::string& name,
|
||||
const std::shared_ptr<pattern::Matcher>& m,
|
||||
const handler_callback& handler,
|
||||
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE)
|
||||
: PassBase()
|
||||
, m_handler(handler)
|
||||
, m_matcher(m)
|
||||
namespace pass
|
||||
{
|
||||
set_name(name);
|
||||
set_property(property, true);
|
||||
}
|
||||
/// \brief MatcherPass is a basic block for pattern based transformations. It describes
|
||||
/// pattern and
|
||||
/// action that is applied if pattern is matched.
|
||||
///
|
||||
/// MatcherPass consists of Matcher and matcher_pass_callback that needs to be implemented
|
||||
/// and
|
||||
/// finally registered by using \sa register_matcher. MatcherPass can be executed on node
|
||||
/// within
|
||||
/// \sa apply method. To run matcher pass on Function use GraphRewrite.
|
||||
/// In addition MatcherPass provides a way for adding new operations into GraphRewrite
|
||||
/// execution
|
||||
/// queue. That means that operations that were created inside transformation callback can
|
||||
/// be added
|
||||
/// for matching. To register node use \sa register_new_node method. GraphRewrite
|
||||
/// automatically
|
||||
/// takes registered nodes and put them to execution queue. If multiple nodes were register
|
||||
/// make
|
||||
/// sure that they were registered in topological order.
|
||||
/// Note: when implementing pattern for Matcher make sure that root node is an operation
|
||||
/// from opset
|
||||
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher
|
||||
/// passes more
|
||||
/// efficient.
|
||||
|
||||
bool apply(std::shared_ptr<ngraph::Node> node);
|
||||
|
||||
template <typename T, class... Args>
|
||||
std::shared_ptr<T> register_new_node(Args&&... args)
|
||||
{
|
||||
auto node = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
m_new_nodes.push_back(node);
|
||||
return node;
|
||||
}
|
||||
|
||||
const std::vector<std::shared_ptr<ngraph::Node>>& get_new_nodes() { return m_new_nodes; }
|
||||
void clear_new_nodes() { m_new_nodes.clear(); }
|
||||
std::shared_ptr<pattern::Matcher> get_matcher() { return m_matcher; }
|
||||
protected:
|
||||
void register_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const ngraph::graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
|
||||
private:
|
||||
handler_callback m_handler;
|
||||
std::shared_ptr<pattern::Matcher> m_matcher;
|
||||
std::vector<std::shared_ptr<ngraph::Node>> m_new_nodes;
|
||||
};
|
||||
|
||||
/// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function in
|
||||
/// efficient way
|
||||
///
|
||||
/// Graph rewrite pass is used for matcher passes execution on Function.
|
||||
/// To register MatcherPass use \sa add_matcher<T>(args) method where T is a MatcherPass class.
|
||||
/// As a default algorithm graph rewrite pass traverse Function in topological order and applies
|
||||
/// registered matcher passes for each node. But if all registered matcher passes have type based
|
||||
/// root node in Matcher pattern then efficient mechanism is used to execute them.
|
||||
/// Matcher pattern root is type based if it's operation from opset or pattern::op::WrapType.
|
||||
/// Note: when implementing pattern for Matcher make sure that root node is an operation from opset
|
||||
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher passes more
|
||||
/// efficient.
|
||||
|
||||
class NGRAPH_API ngraph::pass::GraphRewrite : public ngraph::pass::FunctionPass
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
GraphRewrite() = default;
|
||||
|
||||
explicit GraphRewrite(const std::shared_ptr<MatcherPass>& pass)
|
||||
: FunctionPass()
|
||||
{
|
||||
m_matchers.push_back(pass);
|
||||
}
|
||||
|
||||
/// \brief Register given transformation class type to GraphRewrite execution list
|
||||
/// All registered transformations will be executed in a single graph traversal.
|
||||
/// Example below show the basic usage of pass::GraphRewrite
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// auto anchor = manager.register_pass<GraphRewrite>();
|
||||
/// anchor->add_matcher<MatcherPassA>();
|
||||
/// anchor->add_matcher<MatcherPassB>();
|
||||
/// anchor->set_name("CommonMathcers");
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// For some purposes transformation can be registered and disabled by default.
|
||||
///
|
||||
/// anchor->add_matcher<MatcherPassB, false>();
|
||||
///
|
||||
/// \return shared_ptr to the transformation instance
|
||||
template <typename T, bool Enabled = true, class... Args>
|
||||
std::shared_ptr<T> add_matcher(Args&&... args)
|
||||
{
|
||||
static_assert(std::is_base_of<pass::MatcherPass, T>::value,
|
||||
"pass not derived from MatcherPass");
|
||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
auto pass_config = get_pass_config();
|
||||
pass->set_pass_config(pass_config);
|
||||
if (!Enabled && !pass_config->is_enabled<T>())
|
||||
class NGRAPH_API MatcherPass : public ngraph::pass::PassBase
|
||||
{
|
||||
pass_config->disable<T>();
|
||||
}
|
||||
m_matchers.push_back(pass);
|
||||
return pass;
|
||||
}
|
||||
NGRAPH_DEPRECATED("Use MatcherPass instead")
|
||||
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const ngraph::graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property);
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
NGRAPH_DEPRECATED("Use MatcherPass instead")
|
||||
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const ngraph::graph_rewrite_callback& callback);
|
||||
MatcherPass() = default;
|
||||
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
MatcherPass(const MatcherPass&) = delete;
|
||||
MatcherPass& operator=(const MatcherPass&) = delete;
|
||||
|
||||
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
|
||||
explicit MatcherPass(
|
||||
const std::string& name,
|
||||
const std::shared_ptr<pattern::Matcher>& m,
|
||||
const handler_callback& handler,
|
||||
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE)
|
||||
: PassBase()
|
||||
, m_handler(handler)
|
||||
, m_matcher(m)
|
||||
{
|
||||
set_name(name);
|
||||
set_property(property, true);
|
||||
}
|
||||
|
||||
protected:
|
||||
bool m_enable_shape_inference = false;
|
||||
bool apply(std::shared_ptr<ngraph::Node> node);
|
||||
|
||||
std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
|
||||
};
|
||||
template <typename T, class... Args>
|
||||
std::shared_ptr<T> register_new_node(Args&&... args)
|
||||
{
|
||||
auto node = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
m_new_nodes.push_back(node);
|
||||
return node;
|
||||
}
|
||||
|
||||
class NGRAPH_API ngraph::pass::RecurrentGraphRewrite : public ngraph::pass::FunctionPass
|
||||
{
|
||||
public:
|
||||
RecurrentGraphRewrite(size_t num_iters = 10)
|
||||
: FunctionPass()
|
||||
, m_num_iters(num_iters)
|
||||
{
|
||||
}
|
||||
const std::vector<std::shared_ptr<ngraph::Node>>& get_new_nodes()
|
||||
{
|
||||
return m_new_nodes;
|
||||
}
|
||||
void clear_new_nodes() { m_new_nodes.clear(); }
|
||||
std::shared_ptr<pattern::Matcher> get_matcher() { return m_matcher; }
|
||||
protected:
|
||||
void register_matcher(
|
||||
const std::shared_ptr<pattern::Matcher>& m,
|
||||
const ngraph::graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
|
||||
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
||||
const ngraph::recurrent_graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property);
|
||||
private:
|
||||
handler_callback m_handler;
|
||||
std::shared_ptr<pattern::Matcher> m_matcher;
|
||||
std::vector<std::shared_ptr<ngraph::Node>> m_new_nodes;
|
||||
};
|
||||
|
||||
// TODO: This interface may deprecate after all passes are refactored.
|
||||
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
||||
const ngraph::recurrent_graph_rewrite_callback& callback);
|
||||
/// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function
|
||||
/// in
|
||||
/// efficient way
|
||||
///
|
||||
/// Graph rewrite pass is used for matcher passes execution on Function.
|
||||
/// To register MatcherPass use \sa add_matcher<T>(args) method where T is a MatcherPass
|
||||
/// class.
|
||||
/// As a default algorithm graph rewrite pass traverse Function in topological order and
|
||||
/// applies
|
||||
/// registered matcher passes for each node. But if all registered matcher passes have type
|
||||
/// based
|
||||
/// root node in Matcher pattern then efficient mechanism is used to execute them.
|
||||
/// Matcher pattern root is type based if it's operation from opset or
|
||||
/// pattern::op::WrapType.
|
||||
/// Note: when implementing pattern for Matcher make sure that root node is an operation
|
||||
/// from opset
|
||||
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher
|
||||
/// passes more
|
||||
/// efficient.
|
||||
|
||||
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
|
||||
class NGRAPH_API GraphRewrite : public ngraph::pass::FunctionPass
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
private:
|
||||
size_t m_num_iters;
|
||||
GraphRewrite() = default;
|
||||
|
||||
std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
|
||||
};
|
||||
explicit GraphRewrite(const std::shared_ptr<MatcherPass>& pass)
|
||||
: FunctionPass()
|
||||
{
|
||||
m_matchers.push_back(pass);
|
||||
}
|
||||
|
||||
/// \brief Register given transformation class type to GraphRewrite execution list
|
||||
/// All registered transformations will be executed in a single graph traversal.
|
||||
/// Example below show the basic usage of pass::GraphRewrite
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// auto anchor = manager.register_pass<GraphRewrite>();
|
||||
/// anchor->add_matcher<MatcherPassA>();
|
||||
/// anchor->add_matcher<MatcherPassB>();
|
||||
/// anchor->set_name("CommonMathcers");
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// For some purposes transformation can be registered and disabled by default.
|
||||
///
|
||||
/// anchor->add_matcher<MatcherPassB, false>();
|
||||
///
|
||||
/// \return shared_ptr to the transformation instance
|
||||
template <typename T, bool Enabled = true, class... Args>
|
||||
std::shared_ptr<T> add_matcher(Args&&... args)
|
||||
{
|
||||
static_assert(std::is_base_of<pass::MatcherPass, T>::value,
|
||||
"pass not derived from MatcherPass");
|
||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
auto pass_config = get_pass_config();
|
||||
pass->set_pass_config(pass_config);
|
||||
if (!Enabled && !pass_config->is_enabled<T>())
|
||||
{
|
||||
pass_config->disable<T>();
|
||||
}
|
||||
m_matchers.push_back(pass);
|
||||
return pass;
|
||||
}
|
||||
NGRAPH_DEPRECATED("Use MatcherPass instead")
|
||||
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const ngraph::graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property);
|
||||
|
||||
NGRAPH_DEPRECATED("Use MatcherPass instead")
|
||||
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const ngraph::graph_rewrite_callback& callback);
|
||||
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
|
||||
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
|
||||
|
||||
protected:
|
||||
bool m_enable_shape_inference = false;
|
||||
|
||||
std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
|
||||
};
|
||||
|
||||
class NGRAPH_API RecurrentGraphRewrite : public ngraph::pass::FunctionPass
|
||||
{
|
||||
public:
|
||||
RecurrentGraphRewrite(size_t num_iters = 10)
|
||||
: FunctionPass()
|
||||
, m_num_iters(num_iters)
|
||||
{
|
||||
}
|
||||
|
||||
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
||||
const ngraph::recurrent_graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property);
|
||||
|
||||
// TODO: This interface may deprecate after all passes are refactored.
|
||||
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
||||
const ngraph::recurrent_graph_rewrite_callback& callback);
|
||||
|
||||
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
|
||||
|
||||
private:
|
||||
size_t m_num_iters;
|
||||
|
||||
std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
|
||||
};
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -28,93 +28,103 @@ namespace ngraph
|
||||
{
|
||||
namespace pass
|
||||
{
|
||||
class Manager;
|
||||
class NGRAPH_API Manager
|
||||
{
|
||||
public:
|
||||
Manager();
|
||||
~Manager();
|
||||
|
||||
//// \brief Construct Manager with shared PassConfig instance
|
||||
explicit Manager(std::shared_ptr<PassConfig> pass_config);
|
||||
|
||||
/// \brief Register given transformation class type to execution list
|
||||
/// Example below show the basic usage of pass::Manager
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// manager.register_pass<MyTransformation>(/*transformation constructor ars*/);
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// For some purposes transformation can be registered and disabled by default.
|
||||
///
|
||||
/// manager.register_pass<MyTransformation, false>();
|
||||
///
|
||||
/// \return shared_ptr to the transformation instance
|
||||
template <typename T, bool Enable = true, class... Args>
|
||||
std::shared_ptr<T> register_pass(Args&&... args)
|
||||
{
|
||||
auto rc = push_pass<T>(std::forward<Args>(args)...);
|
||||
rc->set_pass_config(m_pass_config);
|
||||
if (m_per_pass_validation)
|
||||
{
|
||||
push_pass<Validate>();
|
||||
}
|
||||
if (!Enable && !m_pass_config->is_enabled<T>())
|
||||
{
|
||||
m_pass_config->disable<T>();
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
void run_passes(std::shared_ptr<Function>);
|
||||
|
||||
void set_pass_visualization(bool new_state) { m_visualize = new_state; }
|
||||
/// \brief Set flag to enable/disable running Validate pass after executing
|
||||
/// each registered pass
|
||||
/// \param new_state Value "true" enables Validate pass run; "false", otherwise
|
||||
void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
|
||||
/// \brief Callback is a lambda function that can be used by registered transformations.
|
||||
/// The main purpose of this callback is to provide a way for plugins to disable/enable
|
||||
/// transformations based on some conditions. In some cases plugins may want not to
|
||||
/// execute some
|
||||
/// transformations.
|
||||
/// For example plugin can disable unpleasant decompositions because of performance
|
||||
/// reasons for
|
||||
/// some cases.
|
||||
/// Callback example:
|
||||
/// auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
||||
/// return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) !=
|
||||
/// nullptr;
|
||||
/// };
|
||||
/// This callback returns true in case of DepthToSpace operation. So when execution
|
||||
/// DepthToSpace
|
||||
/// decomposition pass will check is this decomposition needed or plugin can execute
|
||||
/// this
|
||||
/// operation directly. And of course on transformation side we need to have a response
|
||||
/// for this
|
||||
/// callback.
|
||||
/// if (transformation_callback(batch_to_space)) {
|
||||
/// return false;
|
||||
/// }
|
||||
/// \param callback lamda function that returns true in case if node is supported by
|
||||
/// plugin and
|
||||
/// transformation is not needed
|
||||
NGRAPH_DEPRECATED("Please use get_pass_config() to configure transformation pipeline")
|
||||
void set_callback(const param_callback& callback)
|
||||
{
|
||||
m_pass_config->set_callback(callback);
|
||||
}
|
||||
/// \return PassConfig shared object. This object is used for transformations pipeline
|
||||
/// configuration.
|
||||
/// This object allows to disable/enable transformations execution, set callback to
|
||||
/// particular
|
||||
/// transformation. For mo details see PassConfig class.
|
||||
std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; }
|
||||
protected:
|
||||
template <typename T, class... Args>
|
||||
std::shared_ptr<T> push_pass(Args&&... args)
|
||||
{
|
||||
static_assert(std::is_base_of<pass::PassBase, T>::value,
|
||||
"pass not derived from pass base");
|
||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
auto pass_base = std::static_pointer_cast<PassBase>(pass);
|
||||
m_pass_list.push_back(pass_base);
|
||||
return pass;
|
||||
}
|
||||
|
||||
std::shared_ptr<PassConfig> m_pass_config;
|
||||
std::vector<std::shared_ptr<PassBase>> m_pass_list;
|
||||
bool m_visualize = false;
|
||||
bool m_per_pass_validation = true;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
class NGRAPH_API ngraph::pass::Manager
|
||||
{
|
||||
public:
|
||||
Manager();
|
||||
~Manager();
|
||||
|
||||
//// \brief Construct Manager with shared PassConfig instance
|
||||
explicit Manager(std::shared_ptr<PassConfig> pass_config);
|
||||
|
||||
/// \brief Register given transformation class type to execution list
|
||||
/// Example below show the basic usage of pass::Manager
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// manager.register_pass<MyTransformation>(/*transformation constructor ars*/);
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// For some purposes transformation can be registered and disabled by default.
|
||||
///
|
||||
/// manager.register_pass<MyTransformation, false>();
|
||||
///
|
||||
/// \return shared_ptr to the transformation instance
|
||||
template <typename T, bool Enable = true, class... Args>
|
||||
std::shared_ptr<T> register_pass(Args&&... args)
|
||||
{
|
||||
auto rc = push_pass<T>(std::forward<Args>(args)...);
|
||||
rc->set_pass_config(m_pass_config);
|
||||
if (m_per_pass_validation)
|
||||
{
|
||||
push_pass<Validate>();
|
||||
}
|
||||
if (!Enable && !m_pass_config->is_enabled<T>())
|
||||
{
|
||||
m_pass_config->disable<T>();
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
void run_passes(std::shared_ptr<Function>);
|
||||
|
||||
void set_pass_visualization(bool new_state) { m_visualize = new_state; }
|
||||
/// \brief Set flag to enable/disable running Validate pass after executing
|
||||
/// each registered pass
|
||||
/// \param new_state Value "true" enables Validate pass run; "false", otherwise
|
||||
void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
|
||||
/// \brief Callback is a lambda function that can be used by registered transformations.
|
||||
/// The main purpose of this callback is to provide a way for plugins to disable/enable
|
||||
/// transformations based on some conditions. In some cases plugins may want not to execute some
|
||||
/// transformations.
|
||||
/// For example plugin can disable unpleasant decompositions because of performance reasons for
|
||||
/// some cases.
|
||||
/// Callback example:
|
||||
/// auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
||||
/// return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
|
||||
/// };
|
||||
/// This callback returns true in case of DepthToSpace operation. So when execution DepthToSpace
|
||||
/// decomposition pass will check is this decomposition needed or plugin can execute this
|
||||
/// operation directly. And of course on transformation side we need to have a response for this
|
||||
/// callback.
|
||||
/// if (transformation_callback(batch_to_space)) {
|
||||
/// return false;
|
||||
/// }
|
||||
/// \param callback lamda function that returns true in case if node is supported by plugin and
|
||||
/// transformation is not needed
|
||||
NGRAPH_DEPRECATED("Please use get_pass_config() to configure transformation pipeline")
|
||||
void set_callback(const param_callback& callback) { m_pass_config->set_callback(callback); }
|
||||
/// \return PassConfig shared object. This object is used for transformations pipeline
|
||||
/// configuration.
|
||||
/// This object allows to disable/enable transformations execution, set callback to particular
|
||||
/// transformation. For mo details see PassConfig class.
|
||||
std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; }
|
||||
protected:
|
||||
template <typename T, class... Args>
|
||||
std::shared_ptr<T> push_pass(Args&&... args)
|
||||
{
|
||||
static_assert(std::is_base_of<pass::PassBase, T>::value, "pass not derived from pass base");
|
||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
auto pass_base = std::static_pointer_cast<PassBase>(pass);
|
||||
m_pass_list.push_back(pass_base);
|
||||
return pass;
|
||||
}
|
||||
|
||||
std::shared_ptr<PassConfig> m_pass_config;
|
||||
std::vector<std::shared_ptr<PassBase>> m_pass_list;
|
||||
bool m_visualize = false;
|
||||
bool m_per_pass_validation = true;
|
||||
};
|
||||
|
@ -27,49 +27,51 @@
|
||||
|
||||
#include "ngraph/pass/pass.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace pass
|
||||
{
|
||||
class VisualizeTree;
|
||||
}
|
||||
}
|
||||
|
||||
class HeightMap;
|
||||
|
||||
using visualize_tree_ops_map_t =
|
||||
std::unordered_map<ngraph::Node::type_info_t,
|
||||
std::function<void(const ngraph::Node&, std::ostream& ss)>>;
|
||||
|
||||
class NGRAPH_API ngraph::pass::VisualizeTree : public FunctionPass
|
||||
namespace ngraph
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
namespace pass
|
||||
{
|
||||
class NGRAPH_API VisualizeTree : public FunctionPass
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
using node_modifiers_t =
|
||||
std::function<void(const Node& node, std::vector<std::string>& attributes)>;
|
||||
VisualizeTree(const std::string& file_name,
|
||||
node_modifiers_t nm = nullptr,
|
||||
bool dot_only = false);
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
|
||||
using node_modifiers_t =
|
||||
std::function<void(const Node& node, std::vector<std::string>& attributes)>;
|
||||
VisualizeTree(const std::string& file_name,
|
||||
node_modifiers_t nm = nullptr,
|
||||
bool dot_only = false);
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
|
||||
|
||||
void set_ops_to_details(const visualize_tree_ops_map_t& ops_map) { m_ops_to_details = ops_map; }
|
||||
protected:
|
||||
void add_node_arguments(std::shared_ptr<Node> node,
|
||||
std::unordered_map<Node*, HeightMap>& height_maps,
|
||||
size_t& fake_node_ctr);
|
||||
std::string add_attributes(std::shared_ptr<Node> node);
|
||||
virtual std::string get_attributes(std::shared_ptr<Node> node);
|
||||
virtual std::string get_node_name(std::shared_ptr<Node> node);
|
||||
std::string get_constant_value(std::shared_ptr<Node> node, size_t max_elements = 7);
|
||||
void set_ops_to_details(const visualize_tree_ops_map_t& ops_map)
|
||||
{
|
||||
m_ops_to_details = ops_map;
|
||||
}
|
||||
|
||||
void render() const;
|
||||
protected:
|
||||
void add_node_arguments(std::shared_ptr<Node> node,
|
||||
std::unordered_map<Node*, HeightMap>& height_maps,
|
||||
size_t& fake_node_ctr);
|
||||
std::string add_attributes(std::shared_ptr<Node> node);
|
||||
virtual std::string get_attributes(std::shared_ptr<Node> node);
|
||||
virtual std::string get_node_name(std::shared_ptr<Node> node);
|
||||
std::string get_constant_value(std::shared_ptr<Node> node, size_t max_elements = 7);
|
||||
|
||||
std::stringstream m_ss;
|
||||
std::string m_name;
|
||||
std::set<std::shared_ptr<Node>> m_nodes_with_attributes;
|
||||
visualize_tree_ops_map_t m_ops_to_details;
|
||||
node_modifiers_t m_node_modifiers = nullptr;
|
||||
bool m_dot_only;
|
||||
static const int max_jump_distance;
|
||||
};
|
||||
void render() const;
|
||||
|
||||
std::stringstream m_ss;
|
||||
std::string m_name;
|
||||
std::set<std::shared_ptr<Node>> m_nodes_with_attributes;
|
||||
visualize_tree_ops_map_t m_ops_to_details;
|
||||
node_modifiers_t m_node_modifiers = nullptr;
|
||||
bool m_dot_only;
|
||||
static const int max_jump_distance;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
@ -24,60 +24,55 @@ namespace ngraph
|
||||
{
|
||||
namespace runtime
|
||||
{
|
||||
class AlignedBuffer;
|
||||
/// \brief Allocates a block of memory on the specified alignment. The actual size of the
|
||||
/// allocated memory is larger than the requested size by the alignment, so allocating 1
|
||||
/// byte
|
||||
/// on 64 byte alignment will allocate 65 bytes.
|
||||
class NGRAPH_API AlignedBuffer
|
||||
{
|
||||
public:
|
||||
// Allocator objects and the allocation interfaces are owned by the
|
||||
// creators of AlignedBuffers. They need to ensure that the lifetime of
|
||||
// allocator exceeds the lifetime of this AlignedBuffer.
|
||||
AlignedBuffer(size_t byte_size, size_t alignment = 64);
|
||||
|
||||
AlignedBuffer();
|
||||
~AlignedBuffer();
|
||||
|
||||
AlignedBuffer(AlignedBuffer&& other);
|
||||
AlignedBuffer& operator=(AlignedBuffer&& other);
|
||||
|
||||
size_t size() const { return m_byte_size; }
|
||||
void* get_ptr(size_t offset) const { return m_aligned_buffer + offset; }
|
||||
void* get_ptr() { return m_aligned_buffer; }
|
||||
const void* get_ptr() const { return m_aligned_buffer; }
|
||||
template <typename T>
|
||||
T* get_ptr()
|
||||
{
|
||||
return reinterpret_cast<T*>(m_aligned_buffer);
|
||||
}
|
||||
template <typename T>
|
||||
const T* get_ptr() const
|
||||
{
|
||||
return reinterpret_cast<const T*>(m_aligned_buffer);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
explicit operator T*()
|
||||
{
|
||||
return get_ptr<T>();
|
||||
}
|
||||
|
||||
private:
|
||||
AlignedBuffer(const AlignedBuffer&) = delete;
|
||||
AlignedBuffer& operator=(const AlignedBuffer&) = delete;
|
||||
|
||||
protected:
|
||||
char* m_allocated_buffer;
|
||||
char* m_aligned_buffer;
|
||||
size_t m_byte_size;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// \brief Allocates a block of memory on the specified alignment. The actual size of the
|
||||
/// allocated memory is larger than the requested size by the alignment, so allocating 1 byte
|
||||
/// on 64 byte alignment will allocate 65 bytes.
|
||||
class NGRAPH_API ngraph::runtime::AlignedBuffer
|
||||
{
|
||||
public:
|
||||
// Allocator objects and the allocation interfaces are owned by the
|
||||
// creators of AlignedBuffers. They need to ensure that the lifetime of
|
||||
// allocator exceeds the lifetime of this AlignedBuffer.
|
||||
AlignedBuffer(size_t byte_size, size_t alignment = 64);
|
||||
|
||||
AlignedBuffer();
|
||||
~AlignedBuffer();
|
||||
|
||||
AlignedBuffer(AlignedBuffer&& other);
|
||||
AlignedBuffer& operator=(AlignedBuffer&& other);
|
||||
|
||||
size_t size() const { return m_byte_size; }
|
||||
void* get_ptr(size_t offset) const { return m_aligned_buffer + offset; }
|
||||
void* get_ptr() { return m_aligned_buffer; }
|
||||
const void* get_ptr() const { return m_aligned_buffer; }
|
||||
template <typename T>
|
||||
T* get_ptr()
|
||||
{
|
||||
return reinterpret_cast<T*>(m_aligned_buffer);
|
||||
}
|
||||
template <typename T>
|
||||
const T* get_ptr() const
|
||||
{
|
||||
return reinterpret_cast<const T*>(m_aligned_buffer);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
explicit operator T*()
|
||||
{
|
||||
return get_ptr<T>();
|
||||
}
|
||||
|
||||
private:
|
||||
AlignedBuffer(const AlignedBuffer&) = delete;
|
||||
AlignedBuffer& operator=(const AlignedBuffer&) = delete;
|
||||
|
||||
protected:
|
||||
char* m_allocated_buffer;
|
||||
char* m_aligned_buffer;
|
||||
size_t m_byte_size;
|
||||
};
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<std::shared_ptr<runtime::AlignedBuffer>>
|
||||
: public ValueAccessor<void*>
|
||||
|
@ -25,10 +25,6 @@
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace runtime
|
||||
{
|
||||
class HostTensor;
|
||||
}
|
||||
namespace op
|
||||
{
|
||||
namespace v0
|
||||
@ -36,100 +32,106 @@ namespace ngraph
|
||||
class Constant;
|
||||
}
|
||||
}
|
||||
namespace runtime
|
||||
{
|
||||
class NGRAPH_API HostTensor : public ngraph::runtime::Tensor
|
||||
{
|
||||
public:
|
||||
HostTensor(const element::Type& element_type,
|
||||
const Shape& shape,
|
||||
void* memory_pointer,
|
||||
const std::string& name = "");
|
||||
HostTensor(const element::Type& element_type,
|
||||
const Shape& shape,
|
||||
const std::string& name = "");
|
||||
HostTensor(const element::Type& element_type,
|
||||
const PartialShape& partial_shape,
|
||||
const std::string& name = "");
|
||||
HostTensor(const std::string& name = "");
|
||||
explicit HostTensor(const Output<Node>&);
|
||||
explicit HostTensor(const std::shared_ptr<op::v0::Constant>& constant);
|
||||
virtual ~HostTensor() override;
|
||||
|
||||
void initialize(const std::shared_ptr<op::v0::Constant>& constant);
|
||||
|
||||
void* get_data_ptr();
|
||||
const void* get_data_ptr() const;
|
||||
|
||||
template <typename T>
|
||||
T* get_data_ptr()
|
||||
{
|
||||
return static_cast<T*>(get_data_ptr());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T* get_data_ptr() const
|
||||
{
|
||||
return static_cast<T*>(get_data_ptr());
|
||||
}
|
||||
|
||||
template <element::Type_t ET>
|
||||
typename element_type_traits<ET>::value_type* get_data_ptr()
|
||||
{
|
||||
NGRAPH_CHECK(ET == get_element_type(),
|
||||
"get_data_ptr() called for incorrect element type.");
|
||||
return static_cast<typename element_type_traits<ET>::value_type*>(get_data_ptr());
|
||||
}
|
||||
|
||||
template <element::Type_t ET>
|
||||
const typename element_type_traits<ET>::value_type* get_data_ptr() const
|
||||
{
|
||||
NGRAPH_CHECK(ET == get_element_type(),
|
||||
"get_data_ptr() called for incorrect element type.");
|
||||
return static_cast<typename element_type_traits<ET>::value_type>(get_data_ptr());
|
||||
}
|
||||
|
||||
/// \brief Write bytes directly into the tensor
|
||||
/// \param p Pointer to source of data
|
||||
/// \param n Number of bytes to write, must be integral number of elements.
|
||||
void write(const void* p, size_t n) override;
|
||||
|
||||
/// \brief Read bytes directly from the tensor
|
||||
/// \param p Pointer to destination for data
|
||||
/// \param n Number of bytes to read, must be integral number of elements.
|
||||
void read(void* p, size_t n) const override;
|
||||
|
||||
bool get_is_allocated() const;
|
||||
/// \brief Set the element type. Must be compatible with the current element type.
|
||||
/// \param element_type The element type
|
||||
void set_element_type(const element::Type& element_type);
|
||||
/// \brief Set the actual shape of the tensor compatibly with the partial shape.
|
||||
/// \param shape The shape being set
|
||||
void set_shape(const Shape& shape);
|
||||
/// \brief Set the shape of a node from an input
|
||||
/// \param arg The input argument
|
||||
void set_unary(const HostTensorPtr& arg);
|
||||
/// \brief Set the shape of the tensor using broadcast rules
|
||||
/// \param autob The broadcast mode
|
||||
/// \param arg0 The first argument
|
||||
/// \param arg1 The second argument
|
||||
void set_broadcast(const op::AutoBroadcastSpec& autob,
|
||||
const HostTensorPtr& arg0,
|
||||
const HostTensorPtr& arg1);
|
||||
/// \brief Set the shape of the tensor using broadcast rules
|
||||
/// \param autob The broadcast mode
|
||||
/// \param arg0 The first argument
|
||||
/// \param arg1 The second argument
|
||||
/// \param element_type The output element type
|
||||
void set_broadcast(const op::AutoBroadcastSpec& autob,
|
||||
const HostTensorPtr& arg0,
|
||||
const HostTensorPtr& arg1,
|
||||
const element::Type& element_type);
|
||||
|
||||
private:
|
||||
void allocate_buffer();
|
||||
HostTensor(const HostTensor&) = delete;
|
||||
HostTensor(HostTensor&&) = delete;
|
||||
HostTensor& operator=(const HostTensor&) = delete;
|
||||
|
||||
void* m_memory_pointer{nullptr};
|
||||
void* m_allocated_buffer_pool{nullptr};
|
||||
void* m_aligned_buffer_pool{nullptr};
|
||||
size_t m_buffer_size;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
class NGRAPH_API ngraph::runtime::HostTensor : public ngraph::runtime::Tensor
|
||||
{
|
||||
public:
|
||||
HostTensor(const element::Type& element_type,
|
||||
const Shape& shape,
|
||||
void* memory_pointer,
|
||||
const std::string& name = "");
|
||||
HostTensor(const element::Type& element_type, const Shape& shape, const std::string& name = "");
|
||||
HostTensor(const element::Type& element_type,
|
||||
const PartialShape& partial_shape,
|
||||
const std::string& name = "");
|
||||
HostTensor(const std::string& name = "");
|
||||
explicit HostTensor(const Output<Node>&);
|
||||
explicit HostTensor(const std::shared_ptr<op::v0::Constant>& constant);
|
||||
virtual ~HostTensor() override;
|
||||
|
||||
void initialize(const std::shared_ptr<op::v0::Constant>& constant);
|
||||
|
||||
void* get_data_ptr();
|
||||
const void* get_data_ptr() const;
|
||||
|
||||
template <typename T>
|
||||
T* get_data_ptr()
|
||||
{
|
||||
return static_cast<T*>(get_data_ptr());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T* get_data_ptr() const
|
||||
{
|
||||
return static_cast<T*>(get_data_ptr());
|
||||
}
|
||||
|
||||
template <element::Type_t ET>
|
||||
typename element_type_traits<ET>::value_type* get_data_ptr()
|
||||
{
|
||||
NGRAPH_CHECK(ET == get_element_type(), "get_data_ptr() called for incorrect element type.");
|
||||
return static_cast<typename element_type_traits<ET>::value_type*>(get_data_ptr());
|
||||
}
|
||||
|
||||
template <element::Type_t ET>
|
||||
const typename element_type_traits<ET>::value_type* get_data_ptr() const
|
||||
{
|
||||
NGRAPH_CHECK(ET == get_element_type(), "get_data_ptr() called for incorrect element type.");
|
||||
return static_cast<typename element_type_traits<ET>::value_type>(get_data_ptr());
|
||||
}
|
||||
|
||||
/// \brief Write bytes directly into the tensor
|
||||
/// \param p Pointer to source of data
|
||||
/// \param n Number of bytes to write, must be integral number of elements.
|
||||
void write(const void* p, size_t n) override;
|
||||
|
||||
/// \brief Read bytes directly from the tensor
|
||||
/// \param p Pointer to destination for data
|
||||
/// \param n Number of bytes to read, must be integral number of elements.
|
||||
void read(void* p, size_t n) const override;
|
||||
|
||||
bool get_is_allocated() const;
|
||||
/// \brief Set the element type. Must be compatible with the current element type.
|
||||
/// \param element_type The element type
|
||||
void set_element_type(const element::Type& element_type);
|
||||
/// \brief Set the actual shape of the tensor compatibly with the partial shape.
|
||||
/// \param shape The shape being set
|
||||
void set_shape(const Shape& shape);
|
||||
/// \brief Set the shape of a node from an input
|
||||
/// \param arg The input argument
|
||||
void set_unary(const HostTensorPtr& arg);
|
||||
/// \brief Set the shape of the tensor using broadcast rules
|
||||
/// \param autob The broadcast mode
|
||||
/// \param arg0 The first argument
|
||||
/// \param arg1 The second argument
|
||||
void set_broadcast(const op::AutoBroadcastSpec& autob,
|
||||
const HostTensorPtr& arg0,
|
||||
const HostTensorPtr& arg1);
|
||||
/// \brief Set the shape of the tensor using broadcast rules
|
||||
/// \param autob The broadcast mode
|
||||
/// \param arg0 The first argument
|
||||
/// \param arg1 The second argument
|
||||
/// \param element_type The output element type
|
||||
void set_broadcast(const op::AutoBroadcastSpec& autob,
|
||||
const HostTensorPtr& arg0,
|
||||
const HostTensorPtr& arg1,
|
||||
const element::Type& element_type);
|
||||
|
||||
private:
|
||||
void allocate_buffer();
|
||||
HostTensor(const HostTensor&) = delete;
|
||||
HostTensor(HostTensor&&) = delete;
|
||||
HostTensor& operator=(const HostTensor&) = delete;
|
||||
|
||||
void* m_memory_pointer{nullptr};
|
||||
void* m_allocated_buffer_pool{nullptr};
|
||||
void* m_aligned_buffer_pool{nullptr};
|
||||
size_t m_buffer_size;
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user