From 72cd81305c5756d6b588b60fe585684c1f63538f Mon Sep 17 00:00:00 2001 From: Ilya Churaev Date: Mon, 28 Dec 2020 13:33:26 +0300 Subject: [PATCH] Fixed headers for doxygen genration (#3746) --- .../include/ngraph/pass/constant_folding.hpp | 1 - .../ngraph/pass/convert_fp32_to_fp16.hpp | 33 +- .../include/ngraph/pass/graph_rewrite.hpp | 328 +++++++++--------- ngraph/core/include/ngraph/pass/manager.hpp | 186 +++++----- .../include/ngraph/pass/visualize_tree.hpp | 72 ++-- .../include/ngraph/runtime/aligned_buffer.hpp | 101 +++--- .../include/ngraph/runtime/host_tensor.hpp | 202 +++++------ 7 files changed, 472 insertions(+), 451 deletions(-) diff --git a/ngraph/core/include/ngraph/pass/constant_folding.hpp b/ngraph/core/include/ngraph/pass/constant_folding.hpp index b48c41dc65f..fd7cfb1eaf3 100644 --- a/ngraph/core/include/ngraph/pass/constant_folding.hpp +++ b/ngraph/core/include/ngraph/pass/constant_folding.hpp @@ -37,6 +37,5 @@ namespace ngraph void copy_runtime_info_to_target_inputs(const std::shared_ptr& node, const Output& replacement); }; - } // namespace pass } // namespace ngraph diff --git a/ngraph/core/include/ngraph/pass/convert_fp32_to_fp16.hpp b/ngraph/core/include/ngraph/pass/convert_fp32_to_fp16.hpp index 80b483463a8..0326f222add 100644 --- a/ngraph/core/include/ngraph/pass/convert_fp32_to_fp16.hpp +++ b/ngraph/core/include/ngraph/pass/convert_fp32_to_fp16.hpp @@ -22,24 +22,21 @@ namespace ngraph { namespace pass { - class ConvertFP32ToFP16; + class NGRAPH_API ConvertFP32ToFP16 : public ngraph::pass::GraphRewrite + { + public: + NGRAPH_RTTI_DECLARATION; + ConvertFP32ToFP16() + : GraphRewrite() + { + convert_constants_precision(); + convert_parameters_precision(); + } + private: + void convert_constants_precision(); + + void convert_parameters_precision(); + }; } // namespace pass } // namespace ngraph - -class NGRAPH_API ngraph::pass::ConvertFP32ToFP16 : public ngraph::pass::GraphRewrite -{ -public: - NGRAPH_RTTI_DECLARATION; - ConvertFP32ToFP16() - : GraphRewrite() - { - convert_constants_precision(); - convert_parameters_precision(); - } - -private: - void convert_constants_precision(); - - void convert_parameters_precision(); -}; diff --git a/ngraph/core/include/ngraph/pass/graph_rewrite.hpp b/ngraph/core/include/ngraph/pass/graph_rewrite.hpp index 9cb572d1c8e..6ce5bb06a08 100644 --- a/ngraph/core/include/ngraph/pass/graph_rewrite.hpp +++ b/ngraph/core/include/ngraph/pass/graph_rewrite.hpp @@ -25,178 +25,194 @@ namespace ngraph { - namespace pass - { - class GraphRewrite; - class RecurrentGraphRewrite; - class MatcherPass; - } - using matcher_pass_callback = std::function; using graph_rewrite_callback = std::function; using recurrent_graph_rewrite_callback = std::function; using handler_callback = std::function& node)>; -} - -/// \brief MatcherPass is a basic block for pattern based transformations. It describes pattern and -/// action that is applied if pattern is matched. -/// -/// MatcherPass consists of Matcher and matcher_pass_callback that needs to be implemented and -/// finally registered by using \sa register_matcher. MatcherPass can be executed on node within -/// \sa apply method. To run matcher pass on Function use GraphRewrite. -/// In addition MatcherPass provides a way for adding new operations into GraphRewrite execution -/// queue. That means that operations that were created inside transformation callback can be added -/// for matching. To register node use \sa register_new_node method. GraphRewrite automatically -/// takes registered nodes and put them to execution queue. If multiple nodes were register make -/// sure that they were registered in topological order. -/// Note: when implementing pattern for Matcher make sure that root node is an operation from opset -/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher passes more -/// efficient. - -class NGRAPH_API ngraph::pass::MatcherPass : public ngraph::pass::PassBase -{ -public: - NGRAPH_RTTI_DECLARATION; - - MatcherPass() = default; - - MatcherPass(const MatcherPass&) = delete; - MatcherPass& operator=(const MatcherPass&) = delete; - - explicit MatcherPass(const std::string& name, - const std::shared_ptr& m, - const handler_callback& handler, - const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE) - : PassBase() - , m_handler(handler) - , m_matcher(m) + namespace pass { - set_name(name); - set_property(property, true); - } + /// \brief MatcherPass is a basic block for pattern based transformations. It describes + /// pattern and + /// action that is applied if pattern is matched. + /// + /// MatcherPass consists of Matcher and matcher_pass_callback that needs to be implemented + /// and + /// finally registered by using \sa register_matcher. MatcherPass can be executed on node + /// within + /// \sa apply method. To run matcher pass on Function use GraphRewrite. + /// In addition MatcherPass provides a way for adding new operations into GraphRewrite + /// execution + /// queue. That means that operations that were created inside transformation callback can + /// be added + /// for matching. To register node use \sa register_new_node method. GraphRewrite + /// automatically + /// takes registered nodes and put them to execution queue. If multiple nodes were register + /// make + /// sure that they were registered in topological order. + /// Note: when implementing pattern for Matcher make sure that root node is an operation + /// from opset + /// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher + /// passes more + /// efficient. - bool apply(std::shared_ptr node); - - template - std::shared_ptr register_new_node(Args&&... args) - { - auto node = std::make_shared(std::forward(args)...); - m_new_nodes.push_back(node); - return node; - } - - const std::vector>& get_new_nodes() { return m_new_nodes; } - void clear_new_nodes() { m_new_nodes.clear(); } - std::shared_ptr get_matcher() { return m_matcher; } -protected: - void register_matcher(const std::shared_ptr& m, - const ngraph::graph_rewrite_callback& callback, - const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE); - -private: - handler_callback m_handler; - std::shared_ptr m_matcher; - std::vector> m_new_nodes; -}; - -/// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function in -/// efficient way -/// -/// Graph rewrite pass is used for matcher passes execution on Function. -/// To register MatcherPass use \sa add_matcher(args) method where T is a MatcherPass class. -/// As a default algorithm graph rewrite pass traverse Function in topological order and applies -/// registered matcher passes for each node. But if all registered matcher passes have type based -/// root node in Matcher pattern then efficient mechanism is used to execute them. -/// Matcher pattern root is type based if it's operation from opset or pattern::op::WrapType. -/// Note: when implementing pattern for Matcher make sure that root node is an operation from opset -/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher passes more -/// efficient. - -class NGRAPH_API ngraph::pass::GraphRewrite : public ngraph::pass::FunctionPass -{ -public: - NGRAPH_RTTI_DECLARATION; - - GraphRewrite() = default; - - explicit GraphRewrite(const std::shared_ptr& pass) - : FunctionPass() - { - m_matchers.push_back(pass); - } - - /// \brief Register given transformation class type to GraphRewrite execution list - /// All registered transformations will be executed in a single graph traversal. - /// Example below show the basic usage of pass::GraphRewrite - /// - /// pass::Manager manager; - /// auto anchor = manager.register_pass(); - /// anchor->add_matcher(); - /// anchor->add_matcher(); - /// anchor->set_name("CommonMathcers"); - /// manager.run_passes(f); - /// - /// For some purposes transformation can be registered and disabled by default. - /// - /// anchor->add_matcher(); - /// - /// \return shared_ptr to the transformation instance - template - std::shared_ptr add_matcher(Args&&... args) - { - static_assert(std::is_base_of::value, - "pass not derived from MatcherPass"); - auto pass = std::make_shared(std::forward(args)...); - auto pass_config = get_pass_config(); - pass->set_pass_config(pass_config); - if (!Enabled && !pass_config->is_enabled()) + class NGRAPH_API MatcherPass : public ngraph::pass::PassBase { - pass_config->disable(); - } - m_matchers.push_back(pass); - return pass; - } - NGRAPH_DEPRECATED("Use MatcherPass instead") - void add_matcher(const std::shared_ptr& m, - const ngraph::graph_rewrite_callback& callback, - const PassPropertyMask& property); + public: + NGRAPH_RTTI_DECLARATION; - NGRAPH_DEPRECATED("Use MatcherPass instead") - void add_matcher(const std::shared_ptr& m, - const ngraph::graph_rewrite_callback& callback); + MatcherPass() = default; - bool run_on_function(std::shared_ptr f) override; + MatcherPass(const MatcherPass&) = delete; + MatcherPass& operator=(const MatcherPass&) = delete; - void set_pass_config(const std::shared_ptr& pass_config) override; + explicit MatcherPass( + const std::string& name, + const std::shared_ptr& m, + const handler_callback& handler, + const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE) + : PassBase() + , m_handler(handler) + , m_matcher(m) + { + set_name(name); + set_property(property, true); + } -protected: - bool m_enable_shape_inference = false; + bool apply(std::shared_ptr node); - std::vector> m_matchers; -}; + template + std::shared_ptr register_new_node(Args&&... args) + { + auto node = std::make_shared(std::forward(args)...); + m_new_nodes.push_back(node); + return node; + } -class NGRAPH_API ngraph::pass::RecurrentGraphRewrite : public ngraph::pass::FunctionPass -{ -public: - RecurrentGraphRewrite(size_t num_iters = 10) - : FunctionPass() - , m_num_iters(num_iters) - { - } + const std::vector>& get_new_nodes() + { + return m_new_nodes; + } + void clear_new_nodes() { m_new_nodes.clear(); } + std::shared_ptr get_matcher() { return m_matcher; } + protected: + void register_matcher( + const std::shared_ptr& m, + const ngraph::graph_rewrite_callback& callback, + const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE); - void add_matcher(const std::shared_ptr& m, - const ngraph::recurrent_graph_rewrite_callback& callback, - const PassPropertyMask& property); + private: + handler_callback m_handler; + std::shared_ptr m_matcher; + std::vector> m_new_nodes; + }; - // TODO: This interface may deprecate after all passes are refactored. - void add_matcher(const std::shared_ptr& m, - const ngraph::recurrent_graph_rewrite_callback& callback); + /// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function + /// in + /// efficient way + /// + /// Graph rewrite pass is used for matcher passes execution on Function. + /// To register MatcherPass use \sa add_matcher(args) method where T is a MatcherPass + /// class. + /// As a default algorithm graph rewrite pass traverse Function in topological order and + /// applies + /// registered matcher passes for each node. But if all registered matcher passes have type + /// based + /// root node in Matcher pattern then efficient mechanism is used to execute them. + /// Matcher pattern root is type based if it's operation from opset or + /// pattern::op::WrapType. + /// Note: when implementing pattern for Matcher make sure that root node is an operation + /// from opset + /// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher + /// passes more + /// efficient. - virtual bool run_on_function(std::shared_ptr f); + class NGRAPH_API GraphRewrite : public ngraph::pass::FunctionPass + { + public: + NGRAPH_RTTI_DECLARATION; -private: - size_t m_num_iters; + GraphRewrite() = default; - std::vector> m_matchers; -}; + explicit GraphRewrite(const std::shared_ptr& pass) + : FunctionPass() + { + m_matchers.push_back(pass); + } + + /// \brief Register given transformation class type to GraphRewrite execution list + /// All registered transformations will be executed in a single graph traversal. + /// Example below show the basic usage of pass::GraphRewrite + /// + /// pass::Manager manager; + /// auto anchor = manager.register_pass(); + /// anchor->add_matcher(); + /// anchor->add_matcher(); + /// anchor->set_name("CommonMathcers"); + /// manager.run_passes(f); + /// + /// For some purposes transformation can be registered and disabled by default. + /// + /// anchor->add_matcher(); + /// + /// \return shared_ptr to the transformation instance + template + std::shared_ptr add_matcher(Args&&... args) + { + static_assert(std::is_base_of::value, + "pass not derived from MatcherPass"); + auto pass = std::make_shared(std::forward(args)...); + auto pass_config = get_pass_config(); + pass->set_pass_config(pass_config); + if (!Enabled && !pass_config->is_enabled()) + { + pass_config->disable(); + } + m_matchers.push_back(pass); + return pass; + } + NGRAPH_DEPRECATED("Use MatcherPass instead") + void add_matcher(const std::shared_ptr& m, + const ngraph::graph_rewrite_callback& callback, + const PassPropertyMask& property); + + NGRAPH_DEPRECATED("Use MatcherPass instead") + void add_matcher(const std::shared_ptr& m, + const ngraph::graph_rewrite_callback& callback); + + bool run_on_function(std::shared_ptr f) override; + + void set_pass_config(const std::shared_ptr& pass_config) override; + + protected: + bool m_enable_shape_inference = false; + + std::vector> m_matchers; + }; + + class NGRAPH_API RecurrentGraphRewrite : public ngraph::pass::FunctionPass + { + public: + RecurrentGraphRewrite(size_t num_iters = 10) + : FunctionPass() + , m_num_iters(num_iters) + { + } + + void add_matcher(const std::shared_ptr& m, + const ngraph::recurrent_graph_rewrite_callback& callback, + const PassPropertyMask& property); + + // TODO: This interface may deprecate after all passes are refactored. + void add_matcher(const std::shared_ptr& m, + const ngraph::recurrent_graph_rewrite_callback& callback); + + virtual bool run_on_function(std::shared_ptr f); + + private: + size_t m_num_iters; + + std::vector> m_matchers; + }; + } // namespace pass +} // namespace ngraph diff --git a/ngraph/core/include/ngraph/pass/manager.hpp b/ngraph/core/include/ngraph/pass/manager.hpp index 508c6ea93af..0a672b65e99 100644 --- a/ngraph/core/include/ngraph/pass/manager.hpp +++ b/ngraph/core/include/ngraph/pass/manager.hpp @@ -28,93 +28,103 @@ namespace ngraph { namespace pass { - class Manager; + class NGRAPH_API Manager + { + public: + Manager(); + ~Manager(); + + //// \brief Construct Manager with shared PassConfig instance + explicit Manager(std::shared_ptr pass_config); + + /// \brief Register given transformation class type to execution list + /// Example below show the basic usage of pass::Manager + /// + /// pass::Manager manager; + /// manager.register_pass(/*transformation constructor ars*/); + /// manager.run_passes(f); + /// + /// For some purposes transformation can be registered and disabled by default. + /// + /// manager.register_pass(); + /// + /// \return shared_ptr to the transformation instance + template + std::shared_ptr register_pass(Args&&... args) + { + auto rc = push_pass(std::forward(args)...); + rc->set_pass_config(m_pass_config); + if (m_per_pass_validation) + { + push_pass(); + } + if (!Enable && !m_pass_config->is_enabled()) + { + m_pass_config->disable(); + } + return rc; + } + + void run_passes(std::shared_ptr); + + void set_pass_visualization(bool new_state) { m_visualize = new_state; } + /// \brief Set flag to enable/disable running Validate pass after executing + /// each registered pass + /// \param new_state Value "true" enables Validate pass run; "false", otherwise + void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; } + /// \brief Callback is a lambda function that can be used by registered transformations. + /// The main purpose of this callback is to provide a way for plugins to disable/enable + /// transformations based on some conditions. In some cases plugins may want not to + /// execute some + /// transformations. + /// For example plugin can disable unpleasant decompositions because of performance + /// reasons for + /// some cases. + /// Callback example: + /// auto callback = [](const std::shared_ptr & node) -> bool { + /// return std::dynamic_pointer_cast(node) != + /// nullptr; + /// }; + /// This callback returns true in case of DepthToSpace operation. So when execution + /// DepthToSpace + /// decomposition pass will check is this decomposition needed or plugin can execute + /// this + /// operation directly. And of course on transformation side we need to have a response + /// for this + /// callback. + /// if (transformation_callback(batch_to_space)) { + /// return false; + /// } + /// \param callback lamda function that returns true in case if node is supported by + /// plugin and + /// transformation is not needed + NGRAPH_DEPRECATED("Please use get_pass_config() to configure transformation pipeline") + void set_callback(const param_callback& callback) + { + m_pass_config->set_callback(callback); + } + /// \return PassConfig shared object. This object is used for transformations pipeline + /// configuration. + /// This object allows to disable/enable transformations execution, set callback to + /// particular + /// transformation. For mo details see PassConfig class. + std::shared_ptr get_pass_config() { return m_pass_config; } + protected: + template + std::shared_ptr push_pass(Args&&... args) + { + static_assert(std::is_base_of::value, + "pass not derived from pass base"); + auto pass = std::make_shared(std::forward(args)...); + auto pass_base = std::static_pointer_cast(pass); + m_pass_list.push_back(pass_base); + return pass; + } + + std::shared_ptr m_pass_config; + std::vector> m_pass_list; + bool m_visualize = false; + bool m_per_pass_validation = true; + }; } } - -class NGRAPH_API ngraph::pass::Manager -{ -public: - Manager(); - ~Manager(); - - //// \brief Construct Manager with shared PassConfig instance - explicit Manager(std::shared_ptr pass_config); - - /// \brief Register given transformation class type to execution list - /// Example below show the basic usage of pass::Manager - /// - /// pass::Manager manager; - /// manager.register_pass(/*transformation constructor ars*/); - /// manager.run_passes(f); - /// - /// For some purposes transformation can be registered and disabled by default. - /// - /// manager.register_pass(); - /// - /// \return shared_ptr to the transformation instance - template - std::shared_ptr register_pass(Args&&... args) - { - auto rc = push_pass(std::forward(args)...); - rc->set_pass_config(m_pass_config); - if (m_per_pass_validation) - { - push_pass(); - } - if (!Enable && !m_pass_config->is_enabled()) - { - m_pass_config->disable(); - } - return rc; - } - - void run_passes(std::shared_ptr); - - void set_pass_visualization(bool new_state) { m_visualize = new_state; } - /// \brief Set flag to enable/disable running Validate pass after executing - /// each registered pass - /// \param new_state Value "true" enables Validate pass run; "false", otherwise - void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; } - /// \brief Callback is a lambda function that can be used by registered transformations. - /// The main purpose of this callback is to provide a way for plugins to disable/enable - /// transformations based on some conditions. In some cases plugins may want not to execute some - /// transformations. - /// For example plugin can disable unpleasant decompositions because of performance reasons for - /// some cases. - /// Callback example: - /// auto callback = [](const std::shared_ptr & node) -> bool { - /// return std::dynamic_pointer_cast(node) != nullptr; - /// }; - /// This callback returns true in case of DepthToSpace operation. So when execution DepthToSpace - /// decomposition pass will check is this decomposition needed or plugin can execute this - /// operation directly. And of course on transformation side we need to have a response for this - /// callback. - /// if (transformation_callback(batch_to_space)) { - /// return false; - /// } - /// \param callback lamda function that returns true in case if node is supported by plugin and - /// transformation is not needed - NGRAPH_DEPRECATED("Please use get_pass_config() to configure transformation pipeline") - void set_callback(const param_callback& callback) { m_pass_config->set_callback(callback); } - /// \return PassConfig shared object. This object is used for transformations pipeline - /// configuration. - /// This object allows to disable/enable transformations execution, set callback to particular - /// transformation. For mo details see PassConfig class. - std::shared_ptr get_pass_config() { return m_pass_config; } -protected: - template - std::shared_ptr push_pass(Args&&... args) - { - static_assert(std::is_base_of::value, "pass not derived from pass base"); - auto pass = std::make_shared(std::forward(args)...); - auto pass_base = std::static_pointer_cast(pass); - m_pass_list.push_back(pass_base); - return pass; - } - - std::shared_ptr m_pass_config; - std::vector> m_pass_list; - bool m_visualize = false; - bool m_per_pass_validation = true; -}; diff --git a/ngraph/core/include/ngraph/pass/visualize_tree.hpp b/ngraph/core/include/ngraph/pass/visualize_tree.hpp index b0fdbe2e09d..0d882703d83 100644 --- a/ngraph/core/include/ngraph/pass/visualize_tree.hpp +++ b/ngraph/core/include/ngraph/pass/visualize_tree.hpp @@ -27,49 +27,51 @@ #include "ngraph/pass/pass.hpp" -namespace ngraph -{ - namespace pass - { - class VisualizeTree; - } -} - class HeightMap; using visualize_tree_ops_map_t = std::unordered_map>; -class NGRAPH_API ngraph::pass::VisualizeTree : public FunctionPass +namespace ngraph { -public: - NGRAPH_RTTI_DECLARATION; + namespace pass + { + class NGRAPH_API VisualizeTree : public FunctionPass + { + public: + NGRAPH_RTTI_DECLARATION; - using node_modifiers_t = - std::function& attributes)>; - VisualizeTree(const std::string& file_name, - node_modifiers_t nm = nullptr, - bool dot_only = false); - bool run_on_function(std::shared_ptr) override; + using node_modifiers_t = + std::function& attributes)>; + VisualizeTree(const std::string& file_name, + node_modifiers_t nm = nullptr, + bool dot_only = false); + bool run_on_function(std::shared_ptr) override; - void set_ops_to_details(const visualize_tree_ops_map_t& ops_map) { m_ops_to_details = ops_map; } -protected: - void add_node_arguments(std::shared_ptr node, - std::unordered_map& height_maps, - size_t& fake_node_ctr); - std::string add_attributes(std::shared_ptr node); - virtual std::string get_attributes(std::shared_ptr node); - virtual std::string get_node_name(std::shared_ptr node); - std::string get_constant_value(std::shared_ptr node, size_t max_elements = 7); + void set_ops_to_details(const visualize_tree_ops_map_t& ops_map) + { + m_ops_to_details = ops_map; + } - void render() const; + protected: + void add_node_arguments(std::shared_ptr node, + std::unordered_map& height_maps, + size_t& fake_node_ctr); + std::string add_attributes(std::shared_ptr node); + virtual std::string get_attributes(std::shared_ptr node); + virtual std::string get_node_name(std::shared_ptr node); + std::string get_constant_value(std::shared_ptr node, size_t max_elements = 7); - std::stringstream m_ss; - std::string m_name; - std::set> m_nodes_with_attributes; - visualize_tree_ops_map_t m_ops_to_details; - node_modifiers_t m_node_modifiers = nullptr; - bool m_dot_only; - static const int max_jump_distance; -}; + void render() const; + + std::stringstream m_ss; + std::string m_name; + std::set> m_nodes_with_attributes; + visualize_tree_ops_map_t m_ops_to_details; + node_modifiers_t m_node_modifiers = nullptr; + bool m_dot_only; + static const int max_jump_distance; + }; + } +} diff --git a/ngraph/core/include/ngraph/runtime/aligned_buffer.hpp b/ngraph/core/include/ngraph/runtime/aligned_buffer.hpp index 1e6852ba929..80ba5727b72 100644 --- a/ngraph/core/include/ngraph/runtime/aligned_buffer.hpp +++ b/ngraph/core/include/ngraph/runtime/aligned_buffer.hpp @@ -24,60 +24,55 @@ namespace ngraph { namespace runtime { - class AlignedBuffer; + /// \brief Allocates a block of memory on the specified alignment. The actual size of the + /// allocated memory is larger than the requested size by the alignment, so allocating 1 + /// byte + /// on 64 byte alignment will allocate 65 bytes. + class NGRAPH_API AlignedBuffer + { + public: + // Allocator objects and the allocation interfaces are owned by the + // creators of AlignedBuffers. They need to ensure that the lifetime of + // allocator exceeds the lifetime of this AlignedBuffer. + AlignedBuffer(size_t byte_size, size_t alignment = 64); + + AlignedBuffer(); + ~AlignedBuffer(); + + AlignedBuffer(AlignedBuffer&& other); + AlignedBuffer& operator=(AlignedBuffer&& other); + + size_t size() const { return m_byte_size; } + void* get_ptr(size_t offset) const { return m_aligned_buffer + offset; } + void* get_ptr() { return m_aligned_buffer; } + const void* get_ptr() const { return m_aligned_buffer; } + template + T* get_ptr() + { + return reinterpret_cast(m_aligned_buffer); + } + template + const T* get_ptr() const + { + return reinterpret_cast(m_aligned_buffer); + } + + template + explicit operator T*() + { + return get_ptr(); + } + + private: + AlignedBuffer(const AlignedBuffer&) = delete; + AlignedBuffer& operator=(const AlignedBuffer&) = delete; + + protected: + char* m_allocated_buffer; + char* m_aligned_buffer; + size_t m_byte_size; + }; } -} - -/// \brief Allocates a block of memory on the specified alignment. The actual size of the -/// allocated memory is larger than the requested size by the alignment, so allocating 1 byte -/// on 64 byte alignment will allocate 65 bytes. -class NGRAPH_API ngraph::runtime::AlignedBuffer -{ -public: - // Allocator objects and the allocation interfaces are owned by the - // creators of AlignedBuffers. They need to ensure that the lifetime of - // allocator exceeds the lifetime of this AlignedBuffer. - AlignedBuffer(size_t byte_size, size_t alignment = 64); - - AlignedBuffer(); - ~AlignedBuffer(); - - AlignedBuffer(AlignedBuffer&& other); - AlignedBuffer& operator=(AlignedBuffer&& other); - - size_t size() const { return m_byte_size; } - void* get_ptr(size_t offset) const { return m_aligned_buffer + offset; } - void* get_ptr() { return m_aligned_buffer; } - const void* get_ptr() const { return m_aligned_buffer; } - template - T* get_ptr() - { - return reinterpret_cast(m_aligned_buffer); - } - template - const T* get_ptr() const - { - return reinterpret_cast(m_aligned_buffer); - } - - template - explicit operator T*() - { - return get_ptr(); - } - -private: - AlignedBuffer(const AlignedBuffer&) = delete; - AlignedBuffer& operator=(const AlignedBuffer&) = delete; - -protected: - char* m_allocated_buffer; - char* m_aligned_buffer; - size_t m_byte_size; -}; - -namespace ngraph -{ template <> class NGRAPH_API AttributeAdapter> : public ValueAccessor diff --git a/ngraph/core/include/ngraph/runtime/host_tensor.hpp b/ngraph/core/include/ngraph/runtime/host_tensor.hpp index 5b337c87376..ba4723f5c3d 100644 --- a/ngraph/core/include/ngraph/runtime/host_tensor.hpp +++ b/ngraph/core/include/ngraph/runtime/host_tensor.hpp @@ -25,10 +25,6 @@ namespace ngraph { - namespace runtime - { - class HostTensor; - } namespace op { namespace v0 @@ -36,100 +32,106 @@ namespace ngraph class Constant; } } + namespace runtime + { + class NGRAPH_API HostTensor : public ngraph::runtime::Tensor + { + public: + HostTensor(const element::Type& element_type, + const Shape& shape, + void* memory_pointer, + const std::string& name = ""); + HostTensor(const element::Type& element_type, + const Shape& shape, + const std::string& name = ""); + HostTensor(const element::Type& element_type, + const PartialShape& partial_shape, + const std::string& name = ""); + HostTensor(const std::string& name = ""); + explicit HostTensor(const Output&); + explicit HostTensor(const std::shared_ptr& constant); + virtual ~HostTensor() override; + + void initialize(const std::shared_ptr& constant); + + void* get_data_ptr(); + const void* get_data_ptr() const; + + template + T* get_data_ptr() + { + return static_cast(get_data_ptr()); + } + + template + const T* get_data_ptr() const + { + return static_cast(get_data_ptr()); + } + + template + typename element_type_traits::value_type* get_data_ptr() + { + NGRAPH_CHECK(ET == get_element_type(), + "get_data_ptr() called for incorrect element type."); + return static_cast::value_type*>(get_data_ptr()); + } + + template + const typename element_type_traits::value_type* get_data_ptr() const + { + NGRAPH_CHECK(ET == get_element_type(), + "get_data_ptr() called for incorrect element type."); + return static_cast::value_type>(get_data_ptr()); + } + + /// \brief Write bytes directly into the tensor + /// \param p Pointer to source of data + /// \param n Number of bytes to write, must be integral number of elements. + void write(const void* p, size_t n) override; + + /// \brief Read bytes directly from the tensor + /// \param p Pointer to destination for data + /// \param n Number of bytes to read, must be integral number of elements. + void read(void* p, size_t n) const override; + + bool get_is_allocated() const; + /// \brief Set the element type. Must be compatible with the current element type. + /// \param element_type The element type + void set_element_type(const element::Type& element_type); + /// \brief Set the actual shape of the tensor compatibly with the partial shape. + /// \param shape The shape being set + void set_shape(const Shape& shape); + /// \brief Set the shape of a node from an input + /// \param arg The input argument + void set_unary(const HostTensorPtr& arg); + /// \brief Set the shape of the tensor using broadcast rules + /// \param autob The broadcast mode + /// \param arg0 The first argument + /// \param arg1 The second argument + void set_broadcast(const op::AutoBroadcastSpec& autob, + const HostTensorPtr& arg0, + const HostTensorPtr& arg1); + /// \brief Set the shape of the tensor using broadcast rules + /// \param autob The broadcast mode + /// \param arg0 The first argument + /// \param arg1 The second argument + /// \param element_type The output element type + void set_broadcast(const op::AutoBroadcastSpec& autob, + const HostTensorPtr& arg0, + const HostTensorPtr& arg1, + const element::Type& element_type); + + private: + void allocate_buffer(); + HostTensor(const HostTensor&) = delete; + HostTensor(HostTensor&&) = delete; + HostTensor& operator=(const HostTensor&) = delete; + + void* m_memory_pointer{nullptr}; + void* m_allocated_buffer_pool{nullptr}; + void* m_aligned_buffer_pool{nullptr}; + size_t m_buffer_size; + }; + } } - -class NGRAPH_API ngraph::runtime::HostTensor : public ngraph::runtime::Tensor -{ -public: - HostTensor(const element::Type& element_type, - const Shape& shape, - void* memory_pointer, - const std::string& name = ""); - HostTensor(const element::Type& element_type, const Shape& shape, const std::string& name = ""); - HostTensor(const element::Type& element_type, - const PartialShape& partial_shape, - const std::string& name = ""); - HostTensor(const std::string& name = ""); - explicit HostTensor(const Output&); - explicit HostTensor(const std::shared_ptr& constant); - virtual ~HostTensor() override; - - void initialize(const std::shared_ptr& constant); - - void* get_data_ptr(); - const void* get_data_ptr() const; - - template - T* get_data_ptr() - { - return static_cast(get_data_ptr()); - } - - template - const T* get_data_ptr() const - { - return static_cast(get_data_ptr()); - } - - template - typename element_type_traits::value_type* get_data_ptr() - { - NGRAPH_CHECK(ET == get_element_type(), "get_data_ptr() called for incorrect element type."); - return static_cast::value_type*>(get_data_ptr()); - } - - template - const typename element_type_traits::value_type* get_data_ptr() const - { - NGRAPH_CHECK(ET == get_element_type(), "get_data_ptr() called for incorrect element type."); - return static_cast::value_type>(get_data_ptr()); - } - - /// \brief Write bytes directly into the tensor - /// \param p Pointer to source of data - /// \param n Number of bytes to write, must be integral number of elements. - void write(const void* p, size_t n) override; - - /// \brief Read bytes directly from the tensor - /// \param p Pointer to destination for data - /// \param n Number of bytes to read, must be integral number of elements. - void read(void* p, size_t n) const override; - - bool get_is_allocated() const; - /// \brief Set the element type. Must be compatible with the current element type. - /// \param element_type The element type - void set_element_type(const element::Type& element_type); - /// \brief Set the actual shape of the tensor compatibly with the partial shape. - /// \param shape The shape being set - void set_shape(const Shape& shape); - /// \brief Set the shape of a node from an input - /// \param arg The input argument - void set_unary(const HostTensorPtr& arg); - /// \brief Set the shape of the tensor using broadcast rules - /// \param autob The broadcast mode - /// \param arg0 The first argument - /// \param arg1 The second argument - void set_broadcast(const op::AutoBroadcastSpec& autob, - const HostTensorPtr& arg0, - const HostTensorPtr& arg1); - /// \brief Set the shape of the tensor using broadcast rules - /// \param autob The broadcast mode - /// \param arg0 The first argument - /// \param arg1 The second argument - /// \param element_type The output element type - void set_broadcast(const op::AutoBroadcastSpec& autob, - const HostTensorPtr& arg0, - const HostTensorPtr& arg1, - const element::Type& element_type); - -private: - void allocate_buffer(); - HostTensor(const HostTensor&) = delete; - HostTensor(HostTensor&&) = delete; - HostTensor& operator=(const HostTensor&) = delete; - - void* m_memory_pointer{nullptr}; - void* m_allocated_buffer_pool{nullptr}; - void* m_aligned_buffer_pool{nullptr}; - size_t m_buffer_size; -};