Ngraph improvements (#2058)
This commit is contained in:
parent
6730cab192
commit
50c6f02a2e
@ -55,28 +55,18 @@ namespace ngraph
|
||||
/// PartialShape s{}; // rank=0
|
||||
/// PartialShape s{2,Dimension::dynamic(),3}; // rank=3, dimension 1 dynamic
|
||||
/// \endcode
|
||||
PartialShape(std::initializer_list<Dimension> init)
|
||||
: PartialShape(true, init)
|
||||
{
|
||||
}
|
||||
PartialShape(std::initializer_list<Dimension> init);
|
||||
|
||||
/// \brief Constructs a PartialShape with static rank from a vector of Dimension.
|
||||
/// \param dimensions The Dimension values for the constructed shape.
|
||||
PartialShape(const std::vector<Dimension>& dimensions)
|
||||
: m_rank_is_static(true)
|
||||
, m_dimensions(dimensions)
|
||||
{
|
||||
}
|
||||
PartialShape(const std::vector<Dimension>& dimensions);
|
||||
|
||||
/// \brief Constructs a PartialShape with static rank from a vector of dimensions values.
|
||||
/// \param dimensions The Dimension values for the constructed shape.
|
||||
PartialShape(const std::vector<Dimension::value_type>& dimensions);
|
||||
|
||||
/// \brief Constructs a static PartialShape with zero rank (the shape of a scalar).
|
||||
PartialShape()
|
||||
: PartialShape(std::initializer_list<Dimension>{})
|
||||
{
|
||||
}
|
||||
PartialShape();
|
||||
|
||||
/// \brief Constructs a static PartialShape from a Shape.
|
||||
/// \param shape The Shape to convert into PartialShape.
|
||||
@ -235,15 +225,18 @@ namespace ngraph
|
||||
|
||||
private:
|
||||
// Private constructor for PartialShape::dynamic().
|
||||
PartialShape(bool rank_is_static, std::vector<Dimension> dimensions)
|
||||
: m_rank_is_static(rank_is_static)
|
||||
, m_dimensions(dimensions)
|
||||
{
|
||||
}
|
||||
PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions);
|
||||
|
||||
// True if the shape's rank is static.
|
||||
bool m_rank_is_static;
|
||||
|
||||
// True if the shape is static.
|
||||
mutable enum class ShapeType {
|
||||
SHAPE_IS_UNKNOWN,
|
||||
SHAPE_IS_STATIC,
|
||||
SHAPE_IS_DYNAMIC
|
||||
} m_shape_type{ShapeType::SHAPE_IS_UNKNOWN};
|
||||
|
||||
// Shape dimensions. This has no meaning if m_rank_is_static is false.
|
||||
std::vector<Dimension> m_dimensions;
|
||||
};
|
||||
|
@ -23,26 +23,53 @@
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
PartialShape::PartialShape()
|
||||
: PartialShape(std::initializer_list<Dimension>{})
|
||||
{
|
||||
}
|
||||
|
||||
PartialShape::PartialShape(std::initializer_list<Dimension> init)
|
||||
: PartialShape(true, init)
|
||||
{
|
||||
}
|
||||
|
||||
PartialShape::PartialShape(const std::vector<Dimension::value_type>& dimensions)
|
||||
: m_rank_is_static(true)
|
||||
, m_dimensions(dimensions.begin(), dimensions.end())
|
||||
{
|
||||
std::transform(dimensions.cbegin(),
|
||||
dimensions.cend(),
|
||||
std::back_inserter(m_dimensions),
|
||||
[](const Dimension::value_type& dimension) { return dimension; });
|
||||
}
|
||||
|
||||
PartialShape::PartialShape(const Shape& shape)
|
||||
: PartialShape(true, {})
|
||||
: m_rank_is_static(true)
|
||||
, m_shape_type(ShapeType::SHAPE_IS_STATIC)
|
||||
, m_dimensions(shape.begin(), shape.end())
|
||||
{
|
||||
}
|
||||
|
||||
PartialShape::PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions)
|
||||
: m_rank_is_static(rank_is_static)
|
||||
, m_dimensions(dimensions)
|
||||
{
|
||||
}
|
||||
|
||||
PartialShape::PartialShape(const std::vector<Dimension>& dimensions)
|
||||
: m_rank_is_static(true)
|
||||
, m_dimensions(dimensions)
|
||||
{
|
||||
m_dimensions.assign(shape.begin(), shape.end());
|
||||
}
|
||||
|
||||
bool ngraph::PartialShape::is_static() const
|
||||
{
|
||||
return m_rank_is_static && std::all_of(m_dimensions.begin(),
|
||||
m_dimensions.end(),
|
||||
[](const Dimension& d) { return d.is_static(); });
|
||||
if (m_shape_type == ShapeType::SHAPE_IS_UNKNOWN)
|
||||
{
|
||||
m_shape_type =
|
||||
m_rank_is_static && std::all_of(m_dimensions.begin(),
|
||||
m_dimensions.end(),
|
||||
[](const Dimension& d) { return d.is_static(); })
|
||||
? ShapeType::SHAPE_IS_STATIC
|
||||
: ShapeType::SHAPE_IS_DYNAMIC;
|
||||
}
|
||||
return m_shape_type == ShapeType::SHAPE_IS_STATIC;
|
||||
}
|
||||
|
||||
bool ngraph::PartialShape::operator==(const PartialShape& partial_shape) const
|
||||
@ -282,6 +309,7 @@ bool PartialShape::merge_rank(Rank r)
|
||||
{
|
||||
m_rank_is_static = true;
|
||||
m_dimensions = std::vector<Dimension>(r.get_length(), Dimension::dynamic());
|
||||
m_shape_type = ShapeType::SHAPE_IS_UNKNOWN;
|
||||
return true;
|
||||
}
|
||||
else
|
||||
@ -297,13 +325,13 @@ Shape PartialShape::to_shape() const
|
||||
throw std::invalid_argument("to_shape was called on a dynamic shape.");
|
||||
}
|
||||
|
||||
std::vector<size_t> dimensions_to_shape(m_dimensions.size());
|
||||
std::vector<size_t> shape_dimensions(m_dimensions.size());
|
||||
std::transform(m_dimensions.begin(),
|
||||
m_dimensions.end(),
|
||||
dimensions_to_shape.begin(),
|
||||
shape_dimensions.begin(),
|
||||
[](const Dimension& d) { return d.get_length(); });
|
||||
|
||||
return Shape(dimensions_to_shape.begin(), dimensions_to_shape.end());
|
||||
return shape_dimensions;
|
||||
}
|
||||
|
||||
bool PartialShape::merge_into(PartialShape& dst, const PartialShape& src)
|
||||
@ -444,6 +472,8 @@ Dimension& PartialShape::operator[](size_t i)
|
||||
{
|
||||
throw std::out_of_range("Accessing out-of-range dimension in Dimension[]");
|
||||
}
|
||||
m_shape_type =
|
||||
ShapeType::SHAPE_IS_UNKNOWN; // We can't guarantee that the shape remains static or dynamic.
|
||||
return m_dimensions[i];
|
||||
}
|
||||
|
||||
|
@ -15,8 +15,9 @@
|
||||
//*****************************************************************************
|
||||
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "ngraph/log.hpp"
|
||||
#include "ngraph/type/element_type.hpp"
|
||||
@ -69,9 +70,16 @@ public:
|
||||
std::string m_type_name;
|
||||
};
|
||||
|
||||
static const map<element::Type_t, const TypeInfo>& get_type_info_map()
|
||||
struct element_type_hash
|
||||
{
|
||||
static map<element::Type_t, const TypeInfo> s_type_info_map{
|
||||
size_t operator()(element::Type_t t) const { return static_cast<size_t>(t); }
|
||||
};
|
||||
|
||||
typedef unordered_map<element::Type_t, const TypeInfo, element_type_hash> element_types_map_t;
|
||||
|
||||
static const element_types_map_t& get_type_info_map()
|
||||
{
|
||||
static element_types_map_t s_type_info_map{
|
||||
{element::Type_t::undefined,
|
||||
TypeInfo(
|
||||
std::numeric_limits<size_t>::max(), false, false, false, "undefined", "undefined")},
|
||||
|
Loading…
Reference in New Issue
Block a user