set_output_type speedup (#6754)
* set_output_type speedup * style * Final optimization * Removed extra include, removed unnecessary lock_guard * Typo * Apply suggestions from code review Co-authored-by: Mikhail Nosov <mikhail.nosov@intel.com> * Update ngraph/core/include/ngraph/descriptor/tensor.hpp Co-authored-by: Mikhail Nosov <mikhail.nosov@intel.com> Co-authored-by: Mikhail Nosov <mikhail.nosov@intel.com>
This commit is contained in:
parent
feb1eaef05
commit
b907c3b84f
@ -4,7 +4,9 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
@ -74,16 +76,24 @@ namespace ngraph
|
||||
protected:
|
||||
element::Type m_element_type;
|
||||
|
||||
// TODO(amprocte): For now we are maintaining both m_shape and m_partial_shape fields,
|
||||
// with m_shape possibly being invalid (get_shape will throw an exception if it
|
||||
// is). This is because get_shape() returns a const reference. I think ideally we
|
||||
// should refactor so that get_shape returns by value.
|
||||
Shape m_shape;
|
||||
PartialShape m_partial_shape;
|
||||
Node* m_node{nullptr};
|
||||
HostTensorPtr m_lower_value, m_upper_value;
|
||||
size_t m_node_output_number{0};
|
||||
// TODO: remove along with get_shape
|
||||
// Initially there was ngraph::Shape m_shape only available to keep shape information.
|
||||
// Support for dynamic shapes required transition to ngraph::PartialShape.
|
||||
// To smoothly transition to ngraph::PartialShape we introduced m_partial_shape
|
||||
// and kept m_shape in sync with m_partial_shape. Synchronization point was placed
|
||||
// in set_partial_shape which dramatically affected performance of ngraph::Function
|
||||
// validation. Since we have started the transition to ngraph::PartialShape and reduced
|
||||
// ngraph::Shape usage the only user of m_shape was get_shape method with signature:
|
||||
// const Shape& descriptor::Tensor::get_shape() const
|
||||
// It was decided to move m_shape and m_partial_shape synchronization point there and
|
||||
// to keep methods signature backward compatible.
|
||||
mutable std::mutex shape_mutex;
|
||||
mutable std::atomic_bool m_shape_changed;
|
||||
mutable Shape m_shape;
|
||||
// TODO: end
|
||||
|
||||
PartialShape m_partial_shape;
|
||||
HostTensorPtr m_lower_value, m_upper_value;
|
||||
std::string m_name;
|
||||
std::unordered_set<std::string> m_names;
|
||||
};
|
||||
|
@ -4,7 +4,6 @@
|
||||
|
||||
#include "ngraph/descriptor/tensor.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace std;
|
||||
@ -13,9 +12,9 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
|
||||
const PartialShape& pshape,
|
||||
const std::string& name)
|
||||
: m_element_type(element_type)
|
||||
, m_shape(pshape.is_static() ? pshape.to_shape() : Shape{})
|
||||
, m_partial_shape(pshape)
|
||||
, m_name(name)
|
||||
, m_shape_changed(true)
|
||||
{
|
||||
}
|
||||
|
||||
@ -24,10 +23,8 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
|
||||
Node* node,
|
||||
size_t node_output_number)
|
||||
: m_element_type(element_type)
|
||||
, m_shape(pshape.is_static() ? pshape.to_shape() : Shape{})
|
||||
, m_partial_shape(pshape)
|
||||
, m_node(node)
|
||||
, m_node_output_number(node_output_number)
|
||||
, m_shape_changed(true)
|
||||
{
|
||||
}
|
||||
|
||||
@ -46,14 +43,7 @@ void descriptor::Tensor::set_element_type(const element::Type& element_type)
|
||||
void descriptor::Tensor::set_partial_shape(const PartialShape& partial_shape)
|
||||
{
|
||||
m_partial_shape = partial_shape;
|
||||
if (m_partial_shape.is_static())
|
||||
{
|
||||
m_shape = m_partial_shape.to_shape();
|
||||
}
|
||||
else
|
||||
{
|
||||
m_shape = Shape{};
|
||||
}
|
||||
m_shape_changed = true;
|
||||
}
|
||||
|
||||
void descriptor::Tensor::invalidate_values()
|
||||
@ -82,6 +72,15 @@ const Shape& descriptor::Tensor::get_shape() const
|
||||
{
|
||||
if (m_partial_shape.is_static())
|
||||
{
|
||||
if (m_shape_changed.load(std::memory_order_relaxed))
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(shape_mutex);
|
||||
if (m_shape_changed) // double check after mutex lock
|
||||
{
|
||||
m_shape = m_partial_shape.to_shape();
|
||||
m_shape_changed = false;
|
||||
}
|
||||
}
|
||||
return m_shape;
|
||||
}
|
||||
else
|
||||
|
@ -210,7 +210,7 @@ descriptor::Output& Node::get_output_descriptor(size_t position)
|
||||
make_shared<descriptor::Tensor>(element::dynamic, PartialShape::dynamic(), this, i);
|
||||
m_outputs.emplace_back(this, i, tensor_descriptor);
|
||||
}
|
||||
return m_outputs.at(position);
|
||||
return m_outputs[position];
|
||||
}
|
||||
|
||||
void Node::set_argument(size_t position, const Output<Node>& argument)
|
||||
|
Loading…
Reference in New Issue
Block a user