diff --git a/ngraph/core/include/ngraph/partial_shape.hpp b/ngraph/core/include/ngraph/partial_shape.hpp index e42b4f776dc..8c89a7d57dd 100644 --- a/ngraph/core/include/ngraph/partial_shape.hpp +++ b/ngraph/core/include/ngraph/partial_shape.hpp @@ -55,28 +55,18 @@ namespace ngraph /// PartialShape s{}; // rank=0 /// PartialShape s{2,Dimension::dynamic(),3}; // rank=3, dimension 1 dynamic /// \endcode - PartialShape(std::initializer_list init) - : PartialShape(true, init) - { - } + PartialShape(std::initializer_list init); /// \brief Constructs a PartialShape with static rank from a vector of Dimension. /// \param dimensions The Dimension values for the constructed shape. - PartialShape(const std::vector& dimensions) - : m_rank_is_static(true) - , m_dimensions(dimensions) - { - } + PartialShape(const std::vector& dimensions); /// \brief Constructs a PartialShape with static rank from a vector of dimensions values. /// \param dimensions The Dimension values for the constructed shape. PartialShape(const std::vector& dimensions); /// \brief Constructs a static PartialShape with zero rank (the shape of a scalar). - PartialShape() - : PartialShape(std::initializer_list{}) - { - } + PartialShape(); /// \brief Constructs a static PartialShape from a Shape. /// \param shape The Shape to convert into PartialShape. @@ -235,15 +225,18 @@ namespace ngraph private: // Private constructor for PartialShape::dynamic(). - PartialShape(bool rank_is_static, std::vector dimensions) - : m_rank_is_static(rank_is_static) - , m_dimensions(dimensions) - { - } + PartialShape(bool rank_is_static, const std::vector& dimensions); // True if the shape's rank is static. bool m_rank_is_static; + // True if the shape is static. + mutable enum class ShapeType { + SHAPE_IS_UNKNOWN, + SHAPE_IS_STATIC, + SHAPE_IS_DYNAMIC + } m_shape_type{ShapeType::SHAPE_IS_UNKNOWN}; + // Shape dimensions. This has no meaning if m_rank_is_static is false. std::vector m_dimensions; }; diff --git a/ngraph/core/src/partial_shape.cpp b/ngraph/core/src/partial_shape.cpp index e31afb0174f..3f0c38d1977 100644 --- a/ngraph/core/src/partial_shape.cpp +++ b/ngraph/core/src/partial_shape.cpp @@ -23,26 +23,53 @@ using namespace ngraph; +PartialShape::PartialShape() + : PartialShape(std::initializer_list{}) +{ +} + +PartialShape::PartialShape(std::initializer_list init) + : PartialShape(true, init) +{ +} + PartialShape::PartialShape(const std::vector& dimensions) : m_rank_is_static(true) + , m_dimensions(dimensions.begin(), dimensions.end()) { - std::transform(dimensions.cbegin(), - dimensions.cend(), - std::back_inserter(m_dimensions), - [](const Dimension::value_type& dimension) { return dimension; }); } PartialShape::PartialShape(const Shape& shape) - : PartialShape(true, {}) + : m_rank_is_static(true) + , m_shape_type(ShapeType::SHAPE_IS_STATIC) + , m_dimensions(shape.begin(), shape.end()) +{ +} + +PartialShape::PartialShape(bool rank_is_static, const std::vector& dimensions) + : m_rank_is_static(rank_is_static) + , m_dimensions(dimensions) +{ +} + +PartialShape::PartialShape(const std::vector& dimensions) + : m_rank_is_static(true) + , m_dimensions(dimensions) { - m_dimensions.assign(shape.begin(), shape.end()); } bool ngraph::PartialShape::is_static() const { - return m_rank_is_static && std::all_of(m_dimensions.begin(), - m_dimensions.end(), - [](const Dimension& d) { return d.is_static(); }); + if (m_shape_type == ShapeType::SHAPE_IS_UNKNOWN) + { + m_shape_type = + m_rank_is_static && std::all_of(m_dimensions.begin(), + m_dimensions.end(), + [](const Dimension& d) { return d.is_static(); }) + ? ShapeType::SHAPE_IS_STATIC + : ShapeType::SHAPE_IS_DYNAMIC; + } + return m_shape_type == ShapeType::SHAPE_IS_STATIC; } bool ngraph::PartialShape::operator==(const PartialShape& partial_shape) const @@ -282,6 +309,7 @@ bool PartialShape::merge_rank(Rank r) { m_rank_is_static = true; m_dimensions = std::vector(r.get_length(), Dimension::dynamic()); + m_shape_type = ShapeType::SHAPE_IS_UNKNOWN; return true; } else @@ -297,13 +325,13 @@ Shape PartialShape::to_shape() const throw std::invalid_argument("to_shape was called on a dynamic shape."); } - std::vector dimensions_to_shape(m_dimensions.size()); + std::vector shape_dimensions(m_dimensions.size()); std::transform(m_dimensions.begin(), m_dimensions.end(), - dimensions_to_shape.begin(), + shape_dimensions.begin(), [](const Dimension& d) { return d.get_length(); }); - return Shape(dimensions_to_shape.begin(), dimensions_to_shape.end()); + return shape_dimensions; } bool PartialShape::merge_into(PartialShape& dst, const PartialShape& src) @@ -444,6 +472,8 @@ Dimension& PartialShape::operator[](size_t i) { throw std::out_of_range("Accessing out-of-range dimension in Dimension[]"); } + m_shape_type = + ShapeType::SHAPE_IS_UNKNOWN; // We can't guarantee that the shape remains static or dynamic. return m_dimensions[i]; } diff --git a/ngraph/core/src/type/element_type.cpp b/ngraph/core/src/type/element_type.cpp index 588a14006ac..a807bb56a7c 100644 --- a/ngraph/core/src/type/element_type.cpp +++ b/ngraph/core/src/type/element_type.cpp @@ -15,8 +15,9 @@ //***************************************************************************** #include +#include #include -#include +#include #include "ngraph/log.hpp" #include "ngraph/type/element_type.hpp" @@ -69,9 +70,16 @@ public: std::string m_type_name; }; -static const map& get_type_info_map() +struct element_type_hash { - static map s_type_info_map{ + size_t operator()(element::Type_t t) const { return static_cast(t); } +}; + +typedef unordered_map element_types_map_t; + +static const element_types_map_t& get_type_info_map() +{ + static element_types_map_t s_type_info_map{ {element::Type_t::undefined, TypeInfo( std::numeric_limits::max(), false, false, false, "undefined", "undefined")},