Fixed headers for doxygen genration (#3746)

This commit is contained in:
Ilya Churaev 2020-12-28 13:33:26 +03:00 committed by GitHub
parent e82257d021
commit 72cd81305c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 472 additions and 451 deletions

View File

@ -37,6 +37,5 @@ namespace ngraph
void copy_runtime_info_to_target_inputs(const std::shared_ptr<Node>& node,
const Output<Node>& replacement);
};
} // namespace pass
} // namespace ngraph

View File

@ -22,24 +22,21 @@ namespace ngraph
{
namespace pass
{
class ConvertFP32ToFP16;
/// \brief GraphRewrite-derived pass whose constructor calls convert_constants_precision()
/// followed by convert_parameters_precision().
/// NOTE(review): judging by the names, the helpers presumably set up the FP32 -> FP16
/// conversion of constants and parameters -- confirm against their definitions.
class NGRAPH_API ConvertFP32ToFP16 : public ngraph::pass::GraphRewrite
{
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Default-constructs the GraphRewrite base, then runs both private helpers
/// (constants first, parameters second).
ConvertFP32ToFP16()
: GraphRewrite()
{
convert_constants_precision();
convert_parameters_precision();
}
private:
// Only declared here; the definitions are not visible in this header.
void convert_constants_precision();
void convert_parameters_precision();
};
} // namespace pass
} // namespace ngraph
/// \brief GraphRewrite-derived pass whose constructor calls convert_constants_precision()
/// followed by convert_parameters_precision().
/// NOTE(review): judging by the names, the helpers presumably set up the FP32 -> FP16
/// conversion of constants and parameters -- confirm against their definitions.
class NGRAPH_API ngraph::pass::ConvertFP32ToFP16 : public ngraph::pass::GraphRewrite
{
public:
NGRAPH_RTTI_DECLARATION;
/// \brief Default-constructs the GraphRewrite base, then runs both private helpers
/// (constants first, parameters second).
ConvertFP32ToFP16()
: GraphRewrite()
{
convert_constants_precision();
convert_parameters_precision();
}
private:
// Only declared here; the definitions are not visible in this header.
void convert_constants_precision();
void convert_parameters_precision();
};

View File

@ -25,178 +25,194 @@
namespace ngraph
{
namespace pass
{
class GraphRewrite;
class RecurrentGraphRewrite;
class MatcherPass;
}
/// Callable that receives a pattern::Matcher by reference and returns a bool.
/// NOTE(review): the meaning of the returned bool is not stated here -- presumably
/// "true if the graph was changed"; confirm against the callers.
using matcher_pass_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
/// Same signature as matcher_pass_callback.
using graph_rewrite_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
/// Variant of the callbacks above that receives a pattern::RecurrentMatcher instead.
using recurrent_graph_rewrite_callback =
std::function<bool(ngraph::pattern::RecurrentMatcher& m)>;
/// Callable that receives a node (shared_ptr<Node>, by const reference) and returns a bool.
using handler_callback = std::function<bool(const std::shared_ptr<Node>& node)>;
}
/// \brief MatcherPass is a basic block for pattern based transformations. It describes pattern and
/// action that is applied if pattern is matched.
///
/// MatcherPass consists of Matcher and matcher_pass_callback that needs to be implemented and
/// finally registered by using \sa register_matcher. MatcherPass can be executed on node within
/// \sa apply method. To run matcher pass on Function use GraphRewrite.
/// In addition MatcherPass provides a way for adding new operations into GraphRewrite execution
/// queue. That means that operations that were created inside transformation callback can be added
/// for matching. To register node use \sa register_new_node method. GraphRewrite automatically
/// takes registered nodes and puts them into the execution queue. If multiple nodes were registered, make
/// sure that they were registered in topological order.
/// Note: when implementing pattern for Matcher make sure that root node is an operation from opset
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher passes more
/// efficiently.
class NGRAPH_API ngraph::pass::MatcherPass : public ngraph::pass::PassBase
{
public:
NGRAPH_RTTI_DECLARATION;
MatcherPass() = default;
MatcherPass(const MatcherPass&) = delete;
MatcherPass& operator=(const MatcherPass&) = delete;
explicit MatcherPass(const std::string& name,
const std::shared_ptr<pattern::Matcher>& m,
const handler_callback& handler,
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE)
: PassBase()
, m_handler(handler)
, m_matcher(m)
namespace pass
{
set_name(name);
set_property(property, true);
}
/// \brief MatcherPass is a basic block for pattern based transformations. It describes
/// pattern and
/// action that is applied if pattern is matched.
///
/// MatcherPass consists of Matcher and matcher_pass_callback that needs to be implemented
/// and
/// finally registered by using \sa register_matcher. MatcherPass can be executed on node
/// within
/// \sa apply method. To run matcher pass on Function use GraphRewrite.
/// In addition MatcherPass provides a way for adding new operations into GraphRewrite
/// execution
/// queue. That means that operations that were created inside transformation callback can
/// be added
/// for matching. To register node use \sa register_new_node method. GraphRewrite
/// automatically
/// takes registered nodes and puts them into the execution queue. If multiple nodes were registered,
/// make
/// sure that they were registered in topological order.
/// Note: when implementing pattern for Matcher make sure that root node is an operation
/// from opset
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher
/// passes more
/// efficiently.
bool apply(std::shared_ptr<ngraph::Node> node);
template <typename T, class... Args>
std::shared_ptr<T> register_new_node(Args&&... args)
{
auto node = std::make_shared<T>(std::forward<Args>(args)...);
m_new_nodes.push_back(node);
return node;
}
const std::vector<std::shared_ptr<ngraph::Node>>& get_new_nodes() { return m_new_nodes; }
void clear_new_nodes() { m_new_nodes.clear(); }
std::shared_ptr<pattern::Matcher> get_matcher() { return m_matcher; }
protected:
void register_matcher(const std::shared_ptr<pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback,
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE);
private:
handler_callback m_handler;
std::shared_ptr<pattern::Matcher> m_matcher;
std::vector<std::shared_ptr<ngraph::Node>> m_new_nodes;
};
/// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function in
/// efficient way
///
/// Graph rewrite pass is used for matcher passes execution on Function.
/// To register MatcherPass use \sa add_matcher<T>(args) method where T is a MatcherPass class.
/// As a default algorithm graph rewrite pass traverse Function in topological order and applies
/// registered matcher passes for each node. But if all registered matcher passes have type based
/// root node in Matcher pattern then efficient mechanism is used to execute them.
/// Matcher pattern root is type based if it's operation from opset or pattern::op::WrapType.
/// Note: when implementing pattern for Matcher make sure that root node is an operation from opset
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher passes more
/// efficiently.
class NGRAPH_API ngraph::pass::GraphRewrite : public ngraph::pass::FunctionPass
{
public:
NGRAPH_RTTI_DECLARATION;
GraphRewrite() = default;
explicit GraphRewrite(const std::shared_ptr<MatcherPass>& pass)
: FunctionPass()
{
m_matchers.push_back(pass);
}
/// \brief Register given transformation class type to GraphRewrite execution list
/// All registered transformations will be executed in a single graph traversal.
/// Example below show the basic usage of pass::GraphRewrite
///
/// pass::Manager manager;
/// auto anchor = manager.register_pass<GraphRewrite>();
/// anchor->add_matcher<MatcherPassA>();
/// anchor->add_matcher<MatcherPassB>();
/// anchor->set_name("CommonMatchers");
/// manager.run_passes(f);
///
/// For some purposes transformation can be registered and disabled by default.
///
/// anchor->add_matcher<MatcherPassB, false>();
///
/// \return shared_ptr to the transformation instance
template <typename T, bool Enabled = true, class... Args>
std::shared_ptr<T> add_matcher(Args&&... args)
{
static_assert(std::is_base_of<pass::MatcherPass, T>::value,
"pass not derived from MatcherPass");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto pass_config = get_pass_config();
pass->set_pass_config(pass_config);
if (!Enabled && !pass_config->is_enabled<T>())
class NGRAPH_API MatcherPass : public ngraph::pass::PassBase
{
pass_config->disable<T>();
}
m_matchers.push_back(pass);
return pass;
}
NGRAPH_DEPRECATED("Use MatcherPass instead")
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback,
const PassPropertyMask& property);
public:
NGRAPH_RTTI_DECLARATION;
NGRAPH_DEPRECATED("Use MatcherPass instead")
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback);
MatcherPass() = default;
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
MatcherPass(const MatcherPass&) = delete;
MatcherPass& operator=(const MatcherPass&) = delete;
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
explicit MatcherPass(
const std::string& name,
const std::shared_ptr<pattern::Matcher>& m,
const handler_callback& handler,
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE)
: PassBase()
, m_handler(handler)
, m_matcher(m)
{
set_name(name);
set_property(property, true);
}
protected:
bool m_enable_shape_inference = false;
bool apply(std::shared_ptr<ngraph::Node> node);
std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
};
template <typename T, class... Args>
std::shared_ptr<T> register_new_node(Args&&... args)
{
auto node = std::make_shared<T>(std::forward<Args>(args)...);
m_new_nodes.push_back(node);
return node;
}
class NGRAPH_API ngraph::pass::RecurrentGraphRewrite : public ngraph::pass::FunctionPass
{
public:
RecurrentGraphRewrite(size_t num_iters = 10)
: FunctionPass()
, m_num_iters(num_iters)
{
}
const std::vector<std::shared_ptr<ngraph::Node>>& get_new_nodes()
{
return m_new_nodes;
}
void clear_new_nodes() { m_new_nodes.clear(); }
std::shared_ptr<pattern::Matcher> get_matcher() { return m_matcher; }
protected:
void register_matcher(
const std::shared_ptr<pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback,
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE);
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback,
const PassPropertyMask& property);
private:
handler_callback m_handler;
std::shared_ptr<pattern::Matcher> m_matcher;
std::vector<std::shared_ptr<ngraph::Node>> m_new_nodes;
};
// TODO: This interface may deprecate after all passes are refactored.
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback);
/// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function
/// in
/// efficient way
///
/// Graph rewrite pass is used for matcher passes execution on Function.
/// To register MatcherPass use \sa add_matcher<T>(args) method where T is a MatcherPass
/// class.
/// As a default algorithm graph rewrite pass traverse Function in topological order and
/// applies
/// registered matcher passes for each node. But if all registered matcher passes have type
/// based
/// root node in Matcher pattern then efficient mechanism is used to execute them.
/// Matcher pattern root is type based if it's operation from opset or
/// pattern::op::WrapType.
/// Note: when implementing pattern for Matcher make sure that root node is an operation
/// from opset
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher
/// passes more
/// efficiently.
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
class NGRAPH_API GraphRewrite : public ngraph::pass::FunctionPass
{
public:
NGRAPH_RTTI_DECLARATION;
private:
size_t m_num_iters;
GraphRewrite() = default;
std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
};
explicit GraphRewrite(const std::shared_ptr<MatcherPass>& pass)
: FunctionPass()
{
m_matchers.push_back(pass);
}
/// \brief Register given transformation class type to GraphRewrite execution list
/// All registered transformations will be executed in a single graph traversal.
/// Example below show the basic usage of pass::GraphRewrite
///
/// pass::Manager manager;
/// auto anchor = manager.register_pass<GraphRewrite>();
/// anchor->add_matcher<MatcherPassA>();
/// anchor->add_matcher<MatcherPassB>();
/// anchor->set_name("CommonMatchers");
/// manager.run_passes(f);
///
/// For some purposes transformation can be registered and disabled by default.
///
/// anchor->add_matcher<MatcherPassB, false>();
///
/// \return shared_ptr to the transformation instance
template <typename T, bool Enabled = true, class... Args>
std::shared_ptr<T> add_matcher(Args&&... args)
{
static_assert(std::is_base_of<pass::MatcherPass, T>::value,
"pass not derived from MatcherPass");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto pass_config = get_pass_config();
pass->set_pass_config(pass_config);
if (!Enabled && !pass_config->is_enabled<T>())
{
pass_config->disable<T>();
}
m_matchers.push_back(pass);
return pass;
}
NGRAPH_DEPRECATED("Use MatcherPass instead")
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback,
const PassPropertyMask& property);
NGRAPH_DEPRECATED("Use MatcherPass instead")
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback);
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
protected:
bool m_enable_shape_inference = false;
std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
};
class NGRAPH_API RecurrentGraphRewrite : public ngraph::pass::FunctionPass
{
public:
RecurrentGraphRewrite(size_t num_iters = 10)
: FunctionPass()
, m_num_iters(num_iters)
{
}
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback,
const PassPropertyMask& property);
// TODO: This interface may deprecate after all passes are refactored.
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback);
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
private:
size_t m_num_iters;
std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
};
} // namespace pass
} // namespace ngraph

View File

@ -28,93 +28,103 @@ namespace ngraph
{
namespace pass
{
class Manager;
/// \brief Keeps a list of passes (m_pass_list) together with a shared PassConfig.
/// Passes are appended through register_pass(); run_passes() is only declared here.
class NGRAPH_API Manager
{
public:
Manager();
~Manager();
/// \brief Construct Manager with shared PassConfig instance
explicit Manager(std::shared_ptr<PassConfig> pass_config);
/// \brief Register given transformation class type to execution list
/// The example below shows the basic usage of pass::Manager
///
/// pass::Manager manager;
/// manager.register_pass<MyTransformation>(/*transformation constructor args*/);
/// manager.run_passes(f);
///
/// For some purposes transformation can be registered and disabled by default.
///
/// manager.register_pass<MyTransformation, false>();
///
/// \return shared_ptr to the transformation instance
template <typename T, bool Enable = true, class... Args>
std::shared_ptr<T> register_pass(Args&&... args)
{
auto rc = push_pass<T>(std::forward<Args>(args)...);
// Every registered pass shares the manager's PassConfig.
rc->set_pass_config(m_pass_config);
// With per-pass validation on, a Validate pass is queued right behind the new pass.
if (m_per_pass_validation)
{
push_pass<Validate>();
}
// Enable == false: disable T in the shared config, but only when is_enabled<T>()
// returns false.
if (!Enable && !m_pass_config->is_enabled<T>())
{
m_pass_config->disable<T>();
}
return rc;
}
/// \brief Declared only; the definition is not visible in this header.
void run_passes(std::shared_ptr<Function>);
/// \brief Stores new_state in m_visualize.
void set_pass_visualization(bool new_state) { m_visualize = new_state; }
/// \brief Set flag to enable/disable running Validate pass after executing
/// each registered pass
/// \param new_state Value "true" enables Validate pass run; "false", otherwise
void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
/// \brief Callback is a lambda function that can be used by registered transformations.
/// The main purpose of this callback is to provide a way for plugins to disable/enable
/// transformations based on some conditions. In some cases plugins may want not to
/// execute some transformations.
/// For example plugin can disable unpleasant decompositions because of performance
/// reasons for some cases.
/// Callback example:
/// auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
/// return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) !=
/// nullptr;
/// };
/// This callback returns true in case of DepthToSpace operation. So when executing,
/// the DepthToSpace decomposition pass will check whether this decomposition is needed
/// or the plugin can execute this operation directly. And of course on transformation
/// side we need to have a response for this callback.
/// if (transformation_callback(batch_to_space)) {
/// return false;
/// }
/// \param callback lambda function that returns true in case if node is supported by
/// plugin and transformation is not needed
NGRAPH_DEPRECATED("Please use get_pass_config() to configure transformation pipeline")
void set_callback(const param_callback& callback)
{
// Forwards straight to the shared PassConfig.
m_pass_config->set_callback(callback);
}
/// \return PassConfig shared object. This object is used for transformations pipeline
/// configuration.
/// This object allows to disable/enable transformations execution, set callback to
/// particular transformation. For more details see PassConfig class.
std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; }
protected:
/// \brief Constructs a pass of type T from args, appends it to m_pass_list as a
/// PassBase pointer and returns the typed pointer. T must derive from pass::PassBase
/// (checked at compile time).
template <typename T, class... Args>
std::shared_ptr<T> push_pass(Args&&... args)
{
static_assert(std::is_base_of<pass::PassBase, T>::value,
"pass not derived from pass base");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto pass_base = std::static_pointer_cast<PassBase>(pass);
m_pass_list.push_back(pass_base);
return pass;
}
// Configuration object handed to every registered pass.
std::shared_ptr<PassConfig> m_pass_config;
// Registered passes, in push_pass() order.
std::vector<std::shared_ptr<PassBase>> m_pass_list;
bool m_visualize = false;
// Defaults to true, so register_pass() queues a Validate pass after each pass.
bool m_per_pass_validation = true;
};
}
}
/// \brief Keeps a list of passes (m_pass_list) together with a shared PassConfig.
/// Passes are appended through register_pass(); run_passes() is only declared here.
class NGRAPH_API ngraph::pass::Manager
{
public:
Manager();
~Manager();
/// \brief Construct Manager with shared PassConfig instance
explicit Manager(std::shared_ptr<PassConfig> pass_config);
/// \brief Register given transformation class type to execution list
/// The example below shows the basic usage of pass::Manager
///
/// pass::Manager manager;
/// manager.register_pass<MyTransformation>(/*transformation constructor args*/);
/// manager.run_passes(f);
///
/// For some purposes transformation can be registered and disabled by default.
///
/// manager.register_pass<MyTransformation, false>();
///
/// \return shared_ptr to the transformation instance
template <typename T, bool Enable = true, class... Args>
std::shared_ptr<T> register_pass(Args&&... args)
{
auto rc = push_pass<T>(std::forward<Args>(args)...);
// Every registered pass shares the manager's PassConfig.
rc->set_pass_config(m_pass_config);
// With per-pass validation on, a Validate pass is queued right behind the new pass.
if (m_per_pass_validation)
{
push_pass<Validate>();
}
// Enable == false: disable T in the shared config, but only when is_enabled<T>()
// returns false.
if (!Enable && !m_pass_config->is_enabled<T>())
{
m_pass_config->disable<T>();
}
return rc;
}
/// \brief Declared only; the definition is not visible in this header.
void run_passes(std::shared_ptr<Function>);
/// \brief Stores new_state in m_visualize.
void set_pass_visualization(bool new_state) { m_visualize = new_state; }
/// \brief Set flag to enable/disable running Validate pass after executing
/// each registered pass
/// \param new_state Value "true" enables Validate pass run; "false", otherwise
void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
/// \brief Callback is a lambda function that can be used by registered transformations.
/// The main purpose of this callback is to provide a way for plugins to disable/enable
/// transformations based on some conditions. In some cases plugins may want not to execute some
/// transformations.
/// For example plugin can disable unpleasant decompositions because of performance reasons for
/// some cases.
/// Callback example:
/// auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
/// return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
/// };
/// This callback returns true in case of DepthToSpace operation. So when executing, the
/// DepthToSpace decomposition pass will check whether this decomposition is needed or the
/// plugin can execute this operation directly. And of course on transformation side we need
/// to have a response for this callback.
/// if (transformation_callback(batch_to_space)) {
/// return false;
/// }
/// \param callback lambda function that returns true in case if node is supported by plugin and
/// transformation is not needed
NGRAPH_DEPRECATED("Please use get_pass_config() to configure transformation pipeline")
void set_callback(const param_callback& callback) { m_pass_config->set_callback(callback); }
/// \return PassConfig shared object. This object is used for transformations pipeline
/// configuration.
/// This object allows to disable/enable transformations execution, set callback to particular
/// transformation. For more details see PassConfig class.
std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; }
protected:
/// \brief Constructs a pass of type T from args, appends it to m_pass_list as a
/// PassBase pointer and returns the typed pointer. T must derive from pass::PassBase
/// (checked at compile time).
template <typename T, class... Args>
std::shared_ptr<T> push_pass(Args&&... args)
{
static_assert(std::is_base_of<pass::PassBase, T>::value, "pass not derived from pass base");
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
auto pass_base = std::static_pointer_cast<PassBase>(pass);
m_pass_list.push_back(pass_base);
return pass;
}
// Configuration object handed to every registered pass.
std::shared_ptr<PassConfig> m_pass_config;
// Registered passes, in push_pass() order.
std::vector<std::shared_ptr<PassBase>> m_pass_list;
bool m_visualize = false;
// Defaults to true, so register_pass() queues a Validate pass after each pass.
bool m_per_pass_validation = true;
};

View File

@ -27,49 +27,51 @@
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class VisualizeTree;
}
}
class HeightMap;
/// Map keyed by Node::type_info_t whose values are callbacks receiving the node
/// (by const reference) and an output stream.
/// NOTE(review): presumably each callback prints extra per-op details into the stream --
/// confirm against the VisualizeTree implementation.
using visualize_tree_ops_map_t =
std::unordered_map<ngraph::Node::type_info_t,
std::function<void(const ngraph::Node&, std::ostream& ss)>>;
class NGRAPH_API ngraph::pass::VisualizeTree : public FunctionPass
namespace ngraph
{
public:
NGRAPH_RTTI_DECLARATION;
namespace pass
{
class NGRAPH_API VisualizeTree : public FunctionPass
{
public:
NGRAPH_RTTI_DECLARATION;
using node_modifiers_t =
std::function<void(const Node& node, std::vector<std::string>& attributes)>;
VisualizeTree(const std::string& file_name,
node_modifiers_t nm = nullptr,
bool dot_only = false);
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
using node_modifiers_t =
std::function<void(const Node& node, std::vector<std::string>& attributes)>;
VisualizeTree(const std::string& file_name,
node_modifiers_t nm = nullptr,
bool dot_only = false);
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
void set_ops_to_details(const visualize_tree_ops_map_t& ops_map) { m_ops_to_details = ops_map; }
protected:
void add_node_arguments(std::shared_ptr<Node> node,
std::unordered_map<Node*, HeightMap>& height_maps,
size_t& fake_node_ctr);
std::string add_attributes(std::shared_ptr<Node> node);
virtual std::string get_attributes(std::shared_ptr<Node> node);
virtual std::string get_node_name(std::shared_ptr<Node> node);
std::string get_constant_value(std::shared_ptr<Node> node, size_t max_elements = 7);
void set_ops_to_details(const visualize_tree_ops_map_t& ops_map)
{
m_ops_to_details = ops_map;
}
void render() const;
protected:
void add_node_arguments(std::shared_ptr<Node> node,
std::unordered_map<Node*, HeightMap>& height_maps,
size_t& fake_node_ctr);
std::string add_attributes(std::shared_ptr<Node> node);
virtual std::string get_attributes(std::shared_ptr<Node> node);
virtual std::string get_node_name(std::shared_ptr<Node> node);
std::string get_constant_value(std::shared_ptr<Node> node, size_t max_elements = 7);
std::stringstream m_ss;
std::string m_name;
std::set<std::shared_ptr<Node>> m_nodes_with_attributes;
visualize_tree_ops_map_t m_ops_to_details;
node_modifiers_t m_node_modifiers = nullptr;
bool m_dot_only;
static const int max_jump_distance;
};
void render() const;
std::stringstream m_ss;
std::string m_name;
std::set<std::shared_ptr<Node>> m_nodes_with_attributes;
visualize_tree_ops_map_t m_ops_to_details;
node_modifiers_t m_node_modifiers = nullptr;
bool m_dot_only;
static const int max_jump_distance;
};
}
}

View File

@ -24,60 +24,55 @@ namespace ngraph
{
namespace runtime
{
class AlignedBuffer;
/// \brief Allocates a block of memory on the specified alignment. The actual size of the
/// allocated memory is larger than the requested size by the alignment, so allocating 1
/// byte on 64 byte alignment will allocate 65 bytes.
/// Copy construction and copy assignment are deleted; move construction and move
/// assignment are declared.
class NGRAPH_API AlignedBuffer
{
public:
// Allocator objects and the allocation interfaces are owned by the
// creators of AlignedBuffers. They need to ensure that the lifetime of
// allocator exceeds the lifetime of this AlignedBuffer.
/// \param byte_size Requested size in bytes
/// \param alignment Requested alignment in bytes (defaults to 64)
AlignedBuffer(size_t byte_size, size_t alignment = 64);
AlignedBuffer();
~AlignedBuffer();
AlignedBuffer(AlignedBuffer&& other);
AlignedBuffer& operator=(AlignedBuffer&& other);
/// \return The stored byte size (m_byte_size)
size_t size() const { return m_byte_size; }
/// \return m_aligned_buffer advanced by offset bytes; no bounds check is done here.
/// NOTE(review): const-qualified, yet it hands out a mutable void*.
void* get_ptr(size_t offset) const { return m_aligned_buffer + offset; }
/// \return The aligned buffer pointer
void* get_ptr() { return m_aligned_buffer; }
const void* get_ptr() const { return m_aligned_buffer; }
/// \return The aligned buffer pointer reinterpreted as T*
template <typename T>
T* get_ptr()
{
return reinterpret_cast<T*>(m_aligned_buffer);
}
/// \return The aligned buffer pointer reinterpreted as const T*
template <typename T>
const T* get_ptr() const
{
return reinterpret_cast<const T*>(m_aligned_buffer);
}
/// \brief Explicit conversion to T*, equivalent to get_ptr<T>()
template <typename T>
explicit operator T*()
{
return get_ptr<T>();
}
private:
AlignedBuffer(const AlignedBuffer&) = delete;
AlignedBuffer& operator=(const AlignedBuffer&) = delete;
protected:
// NOTE(review): presumably the raw allocation and the aligned address inside it --
// confirm in the implementation file.
char* m_allocated_buffer;
char* m_aligned_buffer;
// Value returned by size().
size_t m_byte_size;
};
}
}
/// \brief Allocates a block of memory on the specified alignment. The actual size of the
/// allocated memory is larger than the requested size by the alignment, so allocating 1 byte
/// on 64 byte alignment will allocate 65 bytes.
/// Copy construction and copy assignment are deleted; move construction and move
/// assignment are declared.
class NGRAPH_API ngraph::runtime::AlignedBuffer
{
public:
// Allocator objects and the allocation interfaces are owned by the
// creators of AlignedBuffers. They need to ensure that the lifetime of
// allocator exceeds the lifetime of this AlignedBuffer.
/// \param byte_size Requested size in bytes
/// \param alignment Requested alignment in bytes (defaults to 64)
AlignedBuffer(size_t byte_size, size_t alignment = 64);
AlignedBuffer();
~AlignedBuffer();
AlignedBuffer(AlignedBuffer&& other);
AlignedBuffer& operator=(AlignedBuffer&& other);
/// \return The stored byte size (m_byte_size)
size_t size() const { return m_byte_size; }
/// \return m_aligned_buffer advanced by offset bytes; no bounds check is done here.
/// NOTE(review): const-qualified, yet it hands out a mutable void*.
void* get_ptr(size_t offset) const { return m_aligned_buffer + offset; }
/// \return The aligned buffer pointer
void* get_ptr() { return m_aligned_buffer; }
const void* get_ptr() const { return m_aligned_buffer; }
/// \return The aligned buffer pointer reinterpreted as T*
template <typename T>
T* get_ptr()
{
return reinterpret_cast<T*>(m_aligned_buffer);
}
/// \return The aligned buffer pointer reinterpreted as const T*
template <typename T>
const T* get_ptr() const
{
return reinterpret_cast<const T*>(m_aligned_buffer);
}
/// \brief Explicit conversion to T*, equivalent to get_ptr<T>()
template <typename T>
explicit operator T*()
{
return get_ptr<T>();
}
private:
AlignedBuffer(const AlignedBuffer&) = delete;
AlignedBuffer& operator=(const AlignedBuffer&) = delete;
protected:
// NOTE(review): presumably the raw allocation and the aligned address inside it --
// confirm in the implementation file.
char* m_allocated_buffer;
char* m_aligned_buffer;
// Value returned by size().
size_t m_byte_size;
};
namespace ngraph
{
template <>
class NGRAPH_API AttributeAdapter<std::shared_ptr<runtime::AlignedBuffer>>
: public ValueAccessor<void*>

View File

@ -25,10 +25,6 @@
namespace ngraph
{
namespace runtime
{
class HostTensor;
}
namespace op
{
namespace v0
@ -36,100 +32,106 @@ namespace ngraph
class Constant;
}
}
namespace runtime
{
/// \brief Tensor implementation derived from ngraph::runtime::Tensor that exposes its
/// storage through the get_data_ptr() accessors. Copy and move construction and copy
/// assignment are deleted.
class NGRAPH_API HostTensor : public ngraph::runtime::Tensor
{
public:
    /// \brief Construct from an element type, a static shape and a caller-supplied pointer.
    /// NOTE(review): presumably the tensor does not take ownership of memory_pointer --
    /// confirm in the implementation.
    HostTensor(const element::Type& element_type,
               const Shape& shape,
               void* memory_pointer,
               const std::string& name = "");
    HostTensor(const element::Type& element_type,
               const Shape& shape,
               const std::string& name = "");
    HostTensor(const element::Type& element_type,
               const PartialShape& partial_shape,
               const std::string& name = "");
    HostTensor(const std::string& name = "");
    explicit HostTensor(const Output<Node>&);
    explicit HostTensor(const std::shared_ptr<op::v0::Constant>& constant);
    virtual ~HostTensor() override;

    void initialize(const std::shared_ptr<op::v0::Constant>& constant);

    /// \return Untyped pointer to the tensor data
    void* get_data_ptr();
    const void* get_data_ptr() const;

    /// \return The data pointer cast to T*
    template <typename T>
    T* get_data_ptr()
    {
        return static_cast<T*>(get_data_ptr());
    }

    /// \return The data pointer cast to const T*
    template <typename T>
    const T* get_data_ptr() const
    {
        // The const overload yields const void*; the cast has to keep the const
        // qualifier, otherwise instantiating this template does not compile.
        return static_cast<const T*>(get_data_ptr());
    }

    /// \return The data pointer cast to the value type associated with ET.
    /// NGRAPH_CHECK fails when ET differs from get_element_type().
    template <element::Type_t ET>
    typename element_type_traits<ET>::value_type* get_data_ptr()
    {
        NGRAPH_CHECK(ET == get_element_type(),
                     "get_data_ptr() called for incorrect element type.");
        return static_cast<typename element_type_traits<ET>::value_type*>(get_data_ptr());
    }

    /// \return The data pointer cast to a const pointer to the value type associated
    /// with ET. NGRAPH_CHECK fails when ET differs from get_element_type().
    template <element::Type_t ET>
    const typename element_type_traits<ET>::value_type* get_data_ptr() const
    {
        NGRAPH_CHECK(ET == get_element_type(),
                     "get_data_ptr() called for incorrect element type.");
        // Cast to a const pointer to value_type (not to value_type itself) so the
        // expression matches the declared return type.
        return static_cast<const typename element_type_traits<ET>::value_type*>(
            get_data_ptr());
    }

    /// \brief Write bytes directly into the tensor
    /// \param p Pointer to source of data
    /// \param n Number of bytes to write, must be integral number of elements.
    void write(const void* p, size_t n) override;

    /// \brief Read bytes directly from the tensor
    /// \param p Pointer to destination for data
    /// \param n Number of bytes to read, must be integral number of elements.
    void read(void* p, size_t n) const override;

    bool get_is_allocated() const;

    /// \brief Set the element type. Must be compatible with the current element type.
    /// \param element_type The element type
    void set_element_type(const element::Type& element_type);

    /// \brief Set the actual shape of the tensor compatibly with the partial shape.
    /// \param shape The shape being set
    void set_shape(const Shape& shape);

    /// \brief Set the shape of a node from an input
    /// \param arg The input argument
    void set_unary(const HostTensorPtr& arg);

    /// \brief Set the shape of the tensor using broadcast rules
    /// \param autob The broadcast mode
    /// \param arg0 The first argument
    /// \param arg1 The second argument
    void set_broadcast(const op::AutoBroadcastSpec& autob,
                       const HostTensorPtr& arg0,
                       const HostTensorPtr& arg1);

    /// \brief Set the shape of the tensor using broadcast rules
    /// \param autob The broadcast mode
    /// \param arg0 The first argument
    /// \param arg1 The second argument
    /// \param element_type The output element type
    void set_broadcast(const op::AutoBroadcastSpec& autob,
                       const HostTensorPtr& arg0,
                       const HostTensorPtr& arg1,
                       const element::Type& element_type);

private:
    void allocate_buffer();
    HostTensor(const HostTensor&) = delete;
    HostTensor(HostTensor&&) = delete;
    HostTensor& operator=(const HostTensor&) = delete;

    void* m_memory_pointer{nullptr};
    void* m_allocated_buffer_pool{nullptr};
    void* m_aligned_buffer_pool{nullptr};
    // NOTE(review): no default initializer; presumably assigned by the constructors.
    size_t m_buffer_size;
};
}
}
/// \brief Tensor implementation derived from ngraph::runtime::Tensor that exposes its
/// storage through the get_data_ptr() accessors. Copy and move construction and copy
/// assignment are deleted.
class NGRAPH_API ngraph::runtime::HostTensor : public ngraph::runtime::Tensor
{
public:
    /// \brief Construct from an element type, a static shape and a caller-supplied pointer.
    /// NOTE(review): presumably the tensor does not take ownership of memory_pointer --
    /// confirm in the implementation.
    HostTensor(const element::Type& element_type,
               const Shape& shape,
               void* memory_pointer,
               const std::string& name = "");
    HostTensor(const element::Type& element_type, const Shape& shape, const std::string& name = "");
    HostTensor(const element::Type& element_type,
               const PartialShape& partial_shape,
               const std::string& name = "");
    HostTensor(const std::string& name = "");
    explicit HostTensor(const Output<Node>&);
    explicit HostTensor(const std::shared_ptr<op::v0::Constant>& constant);
    virtual ~HostTensor() override;

    void initialize(const std::shared_ptr<op::v0::Constant>& constant);

    /// \return Untyped pointer to the tensor data
    void* get_data_ptr();
    const void* get_data_ptr() const;

    /// \return The data pointer cast to T*
    template <typename T>
    T* get_data_ptr()
    {
        return static_cast<T*>(get_data_ptr());
    }

    /// \return The data pointer cast to const T*
    template <typename T>
    const T* get_data_ptr() const
    {
        // The const overload yields const void*; the cast has to keep the const
        // qualifier, otherwise instantiating this template does not compile.
        return static_cast<const T*>(get_data_ptr());
    }

    /// \return The data pointer cast to the value type associated with ET.
    /// NGRAPH_CHECK fails when ET differs from get_element_type().
    template <element::Type_t ET>
    typename element_type_traits<ET>::value_type* get_data_ptr()
    {
        NGRAPH_CHECK(ET == get_element_type(), "get_data_ptr() called for incorrect element type.");
        return static_cast<typename element_type_traits<ET>::value_type*>(get_data_ptr());
    }

    /// \return The data pointer cast to a const pointer to the value type associated
    /// with ET. NGRAPH_CHECK fails when ET differs from get_element_type().
    template <element::Type_t ET>
    const typename element_type_traits<ET>::value_type* get_data_ptr() const
    {
        NGRAPH_CHECK(ET == get_element_type(), "get_data_ptr() called for incorrect element type.");
        // Cast to a const pointer to value_type (not to value_type itself) so the
        // expression matches the declared return type.
        return static_cast<const typename element_type_traits<ET>::value_type*>(get_data_ptr());
    }

    /// \brief Write bytes directly into the tensor
    /// \param p Pointer to source of data
    /// \param n Number of bytes to write, must be integral number of elements.
    void write(const void* p, size_t n) override;

    /// \brief Read bytes directly from the tensor
    /// \param p Pointer to destination for data
    /// \param n Number of bytes to read, must be integral number of elements.
    void read(void* p, size_t n) const override;

    bool get_is_allocated() const;

    /// \brief Set the element type. Must be compatible with the current element type.
    /// \param element_type The element type
    void set_element_type(const element::Type& element_type);

    /// \brief Set the actual shape of the tensor compatibly with the partial shape.
    /// \param shape The shape being set
    void set_shape(const Shape& shape);

    /// \brief Set the shape of a node from an input
    /// \param arg The input argument
    void set_unary(const HostTensorPtr& arg);

    /// \brief Set the shape of the tensor using broadcast rules
    /// \param autob The broadcast mode
    /// \param arg0 The first argument
    /// \param arg1 The second argument
    void set_broadcast(const op::AutoBroadcastSpec& autob,
                       const HostTensorPtr& arg0,
                       const HostTensorPtr& arg1);

    /// \brief Set the shape of the tensor using broadcast rules
    /// \param autob The broadcast mode
    /// \param arg0 The first argument
    /// \param arg1 The second argument
    /// \param element_type The output element type
    void set_broadcast(const op::AutoBroadcastSpec& autob,
                       const HostTensorPtr& arg0,
                       const HostTensorPtr& arg1,
                       const element::Type& element_type);

private:
    void allocate_buffer();
    HostTensor(const HostTensor&) = delete;
    HostTensor(HostTensor&&) = delete;
    HostTensor& operator=(const HostTensor&) = delete;

    void* m_memory_pointer{nullptr};
    void* m_allocated_buffer_pool{nullptr};
    void* m_aligned_buffer_pool{nullptr};
    // NOTE(review): no default initializer; presumably assigned by the constructors.
    size_t m_buffer_size;
};