Ngraph improvements (#2058)

This commit is contained in:
Vladislav Volkov 2020-09-07 10:36:52 +03:00 committed by GitHub
parent 6730cab192
commit 50c6f02a2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 64 additions and 33 deletions

View File

@ -55,28 +55,18 @@ namespace ngraph
/// PartialShape s{}; // rank=0
/// PartialShape s{2,Dimension::dynamic(),3}; // rank=3, dimension 1 dynamic
/// \endcode
PartialShape(std::initializer_list<Dimension> init)
: PartialShape(true, init)
{
}
PartialShape(std::initializer_list<Dimension> init);
/// \brief Constructs a PartialShape with static rank from a vector of Dimension.
/// \param dimensions The Dimension values for the constructed shape.
PartialShape(const std::vector<Dimension>& dimensions)
: m_rank_is_static(true)
, m_dimensions(dimensions)
{
}
PartialShape(const std::vector<Dimension>& dimensions);
/// \brief Constructs a PartialShape with static rank from a vector of dimensions values.
/// \param dimensions The Dimension values for the constructed shape.
PartialShape(const std::vector<Dimension::value_type>& dimensions);
/// \brief Constructs a static PartialShape with zero rank (the shape of a scalar).
PartialShape()
: PartialShape(std::initializer_list<Dimension>{})
{
}
PartialShape();
/// \brief Constructs a static PartialShape from a Shape.
/// \param shape The Shape to convert into PartialShape.
@ -235,15 +225,18 @@ namespace ngraph
private:
// Private constructor for PartialShape::dynamic().
PartialShape(bool rank_is_static, std::vector<Dimension> dimensions)
: m_rank_is_static(rank_is_static)
, m_dimensions(dimensions)
{
}
PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions);
// True if the shape's rank is static.
bool m_rank_is_static;
// True if the shape is static.
mutable enum class ShapeType {
SHAPE_IS_UNKNOWN,
SHAPE_IS_STATIC,
SHAPE_IS_DYNAMIC
} m_shape_type{ShapeType::SHAPE_IS_UNKNOWN};
// Shape dimensions. This has no meaning if m_rank_is_static is false.
std::vector<Dimension> m_dimensions;
};

View File

@ -23,26 +23,53 @@
using namespace ngraph;
PartialShape::PartialShape()
: PartialShape(std::initializer_list<Dimension>{})
{
}
PartialShape::PartialShape(std::initializer_list<Dimension> init)
: PartialShape(true, init)
{
}
PartialShape::PartialShape(const std::vector<Dimension::value_type>& dimensions)
: m_rank_is_static(true)
, m_dimensions(dimensions.begin(), dimensions.end())
{
std::transform(dimensions.cbegin(),
dimensions.cend(),
std::back_inserter(m_dimensions),
[](const Dimension::value_type& dimension) { return dimension; });
}
PartialShape::PartialShape(const Shape& shape)
: PartialShape(true, {})
: m_rank_is_static(true)
, m_shape_type(ShapeType::SHAPE_IS_STATIC)
, m_dimensions(shape.begin(), shape.end())
{
}
PartialShape::PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions)
: m_rank_is_static(rank_is_static)
, m_dimensions(dimensions)
{
}
PartialShape::PartialShape(const std::vector<Dimension>& dimensions)
: m_rank_is_static(true)
, m_dimensions(dimensions)
{
m_dimensions.assign(shape.begin(), shape.end());
}
bool ngraph::PartialShape::is_static() const
{
return m_rank_is_static && std::all_of(m_dimensions.begin(),
m_dimensions.end(),
[](const Dimension& d) { return d.is_static(); });
if (m_shape_type == ShapeType::SHAPE_IS_UNKNOWN)
{
m_shape_type =
m_rank_is_static && std::all_of(m_dimensions.begin(),
m_dimensions.end(),
[](const Dimension& d) { return d.is_static(); })
? ShapeType::SHAPE_IS_STATIC
: ShapeType::SHAPE_IS_DYNAMIC;
}
return m_shape_type == ShapeType::SHAPE_IS_STATIC;
}
bool ngraph::PartialShape::operator==(const PartialShape& partial_shape) const
@ -282,6 +309,7 @@ bool PartialShape::merge_rank(Rank r)
{
m_rank_is_static = true;
m_dimensions = std::vector<Dimension>(r.get_length(), Dimension::dynamic());
m_shape_type = ShapeType::SHAPE_IS_UNKNOWN;
return true;
}
else
@ -297,13 +325,13 @@ Shape PartialShape::to_shape() const
throw std::invalid_argument("to_shape was called on a dynamic shape.");
}
std::vector<size_t> dimensions_to_shape(m_dimensions.size());
std::vector<size_t> shape_dimensions(m_dimensions.size());
std::transform(m_dimensions.begin(),
m_dimensions.end(),
dimensions_to_shape.begin(),
shape_dimensions.begin(),
[](const Dimension& d) { return d.get_length(); });
return Shape(dimensions_to_shape.begin(), dimensions_to_shape.end());
return shape_dimensions;
}
bool PartialShape::merge_into(PartialShape& dst, const PartialShape& src)
@ -444,6 +472,8 @@ Dimension& PartialShape::operator[](size_t i)
{
throw std::out_of_range("Accessing out-of-range dimension in Dimension[]");
}
m_shape_type =
ShapeType::SHAPE_IS_UNKNOWN; // We can't guarantee that the shape remains static or dynamic.
return m_dimensions[i];
}

View File

@ -15,8 +15,9 @@
//*****************************************************************************
#include <cmath>
#include <functional>
#include <iostream>
#include <map>
#include <unordered_map>
#include "ngraph/log.hpp"
#include "ngraph/type/element_type.hpp"
@ -69,9 +70,16 @@ public:
std::string m_type_name;
};
static const map<element::Type_t, const TypeInfo>& get_type_info_map()
struct element_type_hash
{
static map<element::Type_t, const TypeInfo> s_type_info_map{
size_t operator()(element::Type_t t) const { return static_cast<size_t>(t); }
};
typedef unordered_map<element::Type_t, const TypeInfo, element_type_hash> element_types_map_t;
static const element_types_map_t& get_type_info_map()
{
static element_types_map_t s_type_info_map{
{element::Type_t::undefined,
TypeInfo(
std::numeric_limits<size_t>::max(), false, false, false, "undefined", "undefined")},