Ngraph improvements (#2058)

This commit is contained in:
Vladislav Volkov 2020-09-07 10:36:52 +03:00 committed by GitHub
parent 6730cab192
commit 50c6f02a2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 64 additions and 33 deletions

View File

@ -55,28 +55,18 @@ namespace ngraph
/// PartialShape s{}; // rank=0 /// PartialShape s{}; // rank=0
/// PartialShape s{2,Dimension::dynamic(),3}; // rank=3, dimension 1 dynamic /// PartialShape s{2,Dimension::dynamic(),3}; // rank=3, dimension 1 dynamic
/// \endcode /// \endcode
PartialShape(std::initializer_list<Dimension> init) PartialShape(std::initializer_list<Dimension> init);
: PartialShape(true, init)
{
}
/// \brief Constructs a PartialShape with static rank from a vector of Dimension. /// \brief Constructs a PartialShape with static rank from a vector of Dimension.
/// \param dimensions The Dimension values for the constructed shape. /// \param dimensions The Dimension values for the constructed shape.
PartialShape(const std::vector<Dimension>& dimensions) PartialShape(const std::vector<Dimension>& dimensions);
: m_rank_is_static(true)
, m_dimensions(dimensions)
{
}
/// \brief Constructs a PartialShape with static rank from a vector of dimensions values. /// \brief Constructs a PartialShape with static rank from a vector of dimensions values.
/// \param dimensions The Dimension values for the constructed shape. /// \param dimensions The Dimension values for the constructed shape.
PartialShape(const std::vector<Dimension::value_type>& dimensions); PartialShape(const std::vector<Dimension::value_type>& dimensions);
/// \brief Constructs a static PartialShape with zero rank (the shape of a scalar). /// \brief Constructs a static PartialShape with zero rank (the shape of a scalar).
PartialShape() PartialShape();
: PartialShape(std::initializer_list<Dimension>{})
{
}
/// \brief Constructs a static PartialShape from a Shape. /// \brief Constructs a static PartialShape from a Shape.
/// \param shape The Shape to convert into PartialShape. /// \param shape The Shape to convert into PartialShape.
@ -235,15 +225,18 @@ namespace ngraph
private: private:
// Private constructor for PartialShape::dynamic(). // Private constructor for PartialShape::dynamic().
PartialShape(bool rank_is_static, std::vector<Dimension> dimensions) PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions);
: m_rank_is_static(rank_is_static)
, m_dimensions(dimensions)
{
}
// True if the shape's rank is static. // True if the shape's rank is static.
bool m_rank_is_static; bool m_rank_is_static;
// True if the shape is static.
mutable enum class ShapeType {
SHAPE_IS_UNKNOWN,
SHAPE_IS_STATIC,
SHAPE_IS_DYNAMIC
} m_shape_type{ShapeType::SHAPE_IS_UNKNOWN};
// Shape dimensions. This has no meaning if m_rank_is_static is false. // Shape dimensions. This has no meaning if m_rank_is_static is false.
std::vector<Dimension> m_dimensions; std::vector<Dimension> m_dimensions;
}; };

View File

@ -23,26 +23,53 @@
using namespace ngraph; using namespace ngraph;
PartialShape::PartialShape()
: PartialShape(std::initializer_list<Dimension>{})
{
}
PartialShape::PartialShape(std::initializer_list<Dimension> init)
: PartialShape(true, init)
{
}
PartialShape::PartialShape(const std::vector<Dimension::value_type>& dimensions) PartialShape::PartialShape(const std::vector<Dimension::value_type>& dimensions)
: m_rank_is_static(true) : m_rank_is_static(true)
, m_dimensions(dimensions.begin(), dimensions.end())
{ {
std::transform(dimensions.cbegin(),
dimensions.cend(),
std::back_inserter(m_dimensions),
[](const Dimension::value_type& dimension) { return dimension; });
} }
PartialShape::PartialShape(const Shape& shape) PartialShape::PartialShape(const Shape& shape)
: PartialShape(true, {}) : m_rank_is_static(true)
, m_shape_type(ShapeType::SHAPE_IS_STATIC)
, m_dimensions(shape.begin(), shape.end())
{
}
PartialShape::PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions)
: m_rank_is_static(rank_is_static)
, m_dimensions(dimensions)
{
}
PartialShape::PartialShape(const std::vector<Dimension>& dimensions)
: m_rank_is_static(true)
, m_dimensions(dimensions)
{ {
m_dimensions.assign(shape.begin(), shape.end());
} }
bool ngraph::PartialShape::is_static() const bool ngraph::PartialShape::is_static() const
{ {
return m_rank_is_static && std::all_of(m_dimensions.begin(), if (m_shape_type == ShapeType::SHAPE_IS_UNKNOWN)
m_dimensions.end(), {
[](const Dimension& d) { return d.is_static(); }); m_shape_type =
m_rank_is_static && std::all_of(m_dimensions.begin(),
m_dimensions.end(),
[](const Dimension& d) { return d.is_static(); })
? ShapeType::SHAPE_IS_STATIC
: ShapeType::SHAPE_IS_DYNAMIC;
}
return m_shape_type == ShapeType::SHAPE_IS_STATIC;
} }
bool ngraph::PartialShape::operator==(const PartialShape& partial_shape) const bool ngraph::PartialShape::operator==(const PartialShape& partial_shape) const
@ -282,6 +309,7 @@ bool PartialShape::merge_rank(Rank r)
{ {
m_rank_is_static = true; m_rank_is_static = true;
m_dimensions = std::vector<Dimension>(r.get_length(), Dimension::dynamic()); m_dimensions = std::vector<Dimension>(r.get_length(), Dimension::dynamic());
m_shape_type = ShapeType::SHAPE_IS_UNKNOWN;
return true; return true;
} }
else else
@ -297,13 +325,13 @@ Shape PartialShape::to_shape() const
throw std::invalid_argument("to_shape was called on a dynamic shape."); throw std::invalid_argument("to_shape was called on a dynamic shape.");
} }
std::vector<size_t> dimensions_to_shape(m_dimensions.size()); std::vector<size_t> shape_dimensions(m_dimensions.size());
std::transform(m_dimensions.begin(), std::transform(m_dimensions.begin(),
m_dimensions.end(), m_dimensions.end(),
dimensions_to_shape.begin(), shape_dimensions.begin(),
[](const Dimension& d) { return d.get_length(); }); [](const Dimension& d) { return d.get_length(); });
return Shape(dimensions_to_shape.begin(), dimensions_to_shape.end()); return shape_dimensions;
} }
bool PartialShape::merge_into(PartialShape& dst, const PartialShape& src) bool PartialShape::merge_into(PartialShape& dst, const PartialShape& src)
@ -444,6 +472,8 @@ Dimension& PartialShape::operator[](size_t i)
{ {
throw std::out_of_range("Accessing out-of-range dimension in Dimension[]"); throw std::out_of_range("Accessing out-of-range dimension in Dimension[]");
} }
m_shape_type =
ShapeType::SHAPE_IS_UNKNOWN; // We can't guarantee that the shape remains static or dynamic.
return m_dimensions[i]; return m_dimensions[i];
} }

View File

@ -15,8 +15,9 @@
//***************************************************************************** //*****************************************************************************
#include <cmath> #include <cmath>
#include <functional>
#include <iostream> #include <iostream>
#include <map> #include <unordered_map>
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
@ -69,9 +70,16 @@ public:
std::string m_type_name; std::string m_type_name;
}; };
static const map<element::Type_t, const TypeInfo>& get_type_info_map() struct element_type_hash
{ {
static map<element::Type_t, const TypeInfo> s_type_info_map{ size_t operator()(element::Type_t t) const { return static_cast<size_t>(t); }
};
typedef unordered_map<element::Type_t, const TypeInfo, element_type_hash> element_types_map_t;
static const element_types_map_t& get_type_info_map()
{
static element_types_map_t s_type_info_map{
{element::Type_t::undefined, {element::Type_t::undefined,
TypeInfo( TypeInfo(
std::numeric_limits<size_t>::max(), false, false, false, "undefined", "undefined")}, std::numeric_limits<size_t>::max(), false, false, false, "undefined", "undefined")},