Ngraph improvements (#2058)
This commit is contained in:
parent
6730cab192
commit
50c6f02a2e
@ -55,28 +55,18 @@ namespace ngraph
|
|||||||
/// PartialShape s{}; // rank=0
|
/// PartialShape s{}; // rank=0
|
||||||
/// PartialShape s{2,Dimension::dynamic(),3}; // rank=3, dimension 1 dynamic
|
/// PartialShape s{2,Dimension::dynamic(),3}; // rank=3, dimension 1 dynamic
|
||||||
/// \endcode
|
/// \endcode
|
||||||
PartialShape(std::initializer_list<Dimension> init)
|
PartialShape(std::initializer_list<Dimension> init);
|
||||||
: PartialShape(true, init)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Constructs a PartialShape with static rank from a vector of Dimension.
|
/// \brief Constructs a PartialShape with static rank from a vector of Dimension.
|
||||||
/// \param dimensions The Dimension values for the constructed shape.
|
/// \param dimensions The Dimension values for the constructed shape.
|
||||||
PartialShape(const std::vector<Dimension>& dimensions)
|
PartialShape(const std::vector<Dimension>& dimensions);
|
||||||
: m_rank_is_static(true)
|
|
||||||
, m_dimensions(dimensions)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Constructs a PartialShape with static rank from a vector of dimensions values.
|
/// \brief Constructs a PartialShape with static rank from a vector of dimensions values.
|
||||||
/// \param dimensions The Dimension values for the constructed shape.
|
/// \param dimensions The Dimension values for the constructed shape.
|
||||||
PartialShape(const std::vector<Dimension::value_type>& dimensions);
|
PartialShape(const std::vector<Dimension::value_type>& dimensions);
|
||||||
|
|
||||||
/// \brief Constructs a static PartialShape with zero rank (the shape of a scalar).
|
/// \brief Constructs a static PartialShape with zero rank (the shape of a scalar).
|
||||||
PartialShape()
|
PartialShape();
|
||||||
: PartialShape(std::initializer_list<Dimension>{})
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Constructs a static PartialShape from a Shape.
|
/// \brief Constructs a static PartialShape from a Shape.
|
||||||
/// \param shape The Shape to convert into PartialShape.
|
/// \param shape The Shape to convert into PartialShape.
|
||||||
@ -235,15 +225,18 @@ namespace ngraph
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
// Private constructor for PartialShape::dynamic().
|
// Private constructor for PartialShape::dynamic().
|
||||||
PartialShape(bool rank_is_static, std::vector<Dimension> dimensions)
|
PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions);
|
||||||
: m_rank_is_static(rank_is_static)
|
|
||||||
, m_dimensions(dimensions)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
// True if the shape's rank is static.
|
// True if the shape's rank is static.
|
||||||
bool m_rank_is_static;
|
bool m_rank_is_static;
|
||||||
|
|
||||||
|
// True if the shape is static.
|
||||||
|
mutable enum class ShapeType {
|
||||||
|
SHAPE_IS_UNKNOWN,
|
||||||
|
SHAPE_IS_STATIC,
|
||||||
|
SHAPE_IS_DYNAMIC
|
||||||
|
} m_shape_type{ShapeType::SHAPE_IS_UNKNOWN};
|
||||||
|
|
||||||
// Shape dimensions. This has no meaning if m_rank_is_static is false.
|
// Shape dimensions. This has no meaning if m_rank_is_static is false.
|
||||||
std::vector<Dimension> m_dimensions;
|
std::vector<Dimension> m_dimensions;
|
||||||
};
|
};
|
||||||
|
@ -23,26 +23,53 @@
|
|||||||
|
|
||||||
using namespace ngraph;
|
using namespace ngraph;
|
||||||
|
|
||||||
|
PartialShape::PartialShape()
|
||||||
|
: PartialShape(std::initializer_list<Dimension>{})
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
PartialShape::PartialShape(std::initializer_list<Dimension> init)
|
||||||
|
: PartialShape(true, init)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
PartialShape::PartialShape(const std::vector<Dimension::value_type>& dimensions)
|
PartialShape::PartialShape(const std::vector<Dimension::value_type>& dimensions)
|
||||||
: m_rank_is_static(true)
|
: m_rank_is_static(true)
|
||||||
|
, m_dimensions(dimensions.begin(), dimensions.end())
|
||||||
{
|
{
|
||||||
std::transform(dimensions.cbegin(),
|
|
||||||
dimensions.cend(),
|
|
||||||
std::back_inserter(m_dimensions),
|
|
||||||
[](const Dimension::value_type& dimension) { return dimension; });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
PartialShape::PartialShape(const Shape& shape)
|
PartialShape::PartialShape(const Shape& shape)
|
||||||
: PartialShape(true, {})
|
: m_rank_is_static(true)
|
||||||
|
, m_shape_type(ShapeType::SHAPE_IS_STATIC)
|
||||||
|
, m_dimensions(shape.begin(), shape.end())
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
PartialShape::PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions)
|
||||||
|
: m_rank_is_static(rank_is_static)
|
||||||
|
, m_dimensions(dimensions)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
PartialShape::PartialShape(const std::vector<Dimension>& dimensions)
|
||||||
|
: m_rank_is_static(true)
|
||||||
|
, m_dimensions(dimensions)
|
||||||
{
|
{
|
||||||
m_dimensions.assign(shape.begin(), shape.end());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ngraph::PartialShape::is_static() const
|
bool ngraph::PartialShape::is_static() const
|
||||||
{
|
{
|
||||||
return m_rank_is_static && std::all_of(m_dimensions.begin(),
|
if (m_shape_type == ShapeType::SHAPE_IS_UNKNOWN)
|
||||||
m_dimensions.end(),
|
{
|
||||||
[](const Dimension& d) { return d.is_static(); });
|
m_shape_type =
|
||||||
|
m_rank_is_static && std::all_of(m_dimensions.begin(),
|
||||||
|
m_dimensions.end(),
|
||||||
|
[](const Dimension& d) { return d.is_static(); })
|
||||||
|
? ShapeType::SHAPE_IS_STATIC
|
||||||
|
: ShapeType::SHAPE_IS_DYNAMIC;
|
||||||
|
}
|
||||||
|
return m_shape_type == ShapeType::SHAPE_IS_STATIC;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ngraph::PartialShape::operator==(const PartialShape& partial_shape) const
|
bool ngraph::PartialShape::operator==(const PartialShape& partial_shape) const
|
||||||
@ -282,6 +309,7 @@ bool PartialShape::merge_rank(Rank r)
|
|||||||
{
|
{
|
||||||
m_rank_is_static = true;
|
m_rank_is_static = true;
|
||||||
m_dimensions = std::vector<Dimension>(r.get_length(), Dimension::dynamic());
|
m_dimensions = std::vector<Dimension>(r.get_length(), Dimension::dynamic());
|
||||||
|
m_shape_type = ShapeType::SHAPE_IS_UNKNOWN;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
@ -297,13 +325,13 @@ Shape PartialShape::to_shape() const
|
|||||||
throw std::invalid_argument("to_shape was called on a dynamic shape.");
|
throw std::invalid_argument("to_shape was called on a dynamic shape.");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<size_t> dimensions_to_shape(m_dimensions.size());
|
std::vector<size_t> shape_dimensions(m_dimensions.size());
|
||||||
std::transform(m_dimensions.begin(),
|
std::transform(m_dimensions.begin(),
|
||||||
m_dimensions.end(),
|
m_dimensions.end(),
|
||||||
dimensions_to_shape.begin(),
|
shape_dimensions.begin(),
|
||||||
[](const Dimension& d) { return d.get_length(); });
|
[](const Dimension& d) { return d.get_length(); });
|
||||||
|
|
||||||
return Shape(dimensions_to_shape.begin(), dimensions_to_shape.end());
|
return shape_dimensions;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool PartialShape::merge_into(PartialShape& dst, const PartialShape& src)
|
bool PartialShape::merge_into(PartialShape& dst, const PartialShape& src)
|
||||||
@ -444,6 +472,8 @@ Dimension& PartialShape::operator[](size_t i)
|
|||||||
{
|
{
|
||||||
throw std::out_of_range("Accessing out-of-range dimension in Dimension[]");
|
throw std::out_of_range("Accessing out-of-range dimension in Dimension[]");
|
||||||
}
|
}
|
||||||
|
m_shape_type =
|
||||||
|
ShapeType::SHAPE_IS_UNKNOWN; // We can't guarantee that the shape remains static or dynamic.
|
||||||
return m_dimensions[i];
|
return m_dimensions[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,8 +15,9 @@
|
|||||||
//*****************************************************************************
|
//*****************************************************************************
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <functional>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <map>
|
#include <unordered_map>
|
||||||
|
|
||||||
#include "ngraph/log.hpp"
|
#include "ngraph/log.hpp"
|
||||||
#include "ngraph/type/element_type.hpp"
|
#include "ngraph/type/element_type.hpp"
|
||||||
@ -69,9 +70,16 @@ public:
|
|||||||
std::string m_type_name;
|
std::string m_type_name;
|
||||||
};
|
};
|
||||||
|
|
||||||
static const map<element::Type_t, const TypeInfo>& get_type_info_map()
|
struct element_type_hash
|
||||||
{
|
{
|
||||||
static map<element::Type_t, const TypeInfo> s_type_info_map{
|
size_t operator()(element::Type_t t) const { return static_cast<size_t>(t); }
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef unordered_map<element::Type_t, const TypeInfo, element_type_hash> element_types_map_t;
|
||||||
|
|
||||||
|
static const element_types_map_t& get_type_info_map()
|
||||||
|
{
|
||||||
|
static element_types_map_t s_type_info_map{
|
||||||
{element::Type_t::undefined,
|
{element::Type_t::undefined,
|
||||||
TypeInfo(
|
TypeInfo(
|
||||||
std::numeric_limits<size_t>::max(), false, false, false, "undefined", "undefined")},
|
std::numeric_limits<size_t>::max(), false, false, false, "undefined", "undefined")},
|
||||||
|
Loading…
Reference in New Issue
Block a user