Validate speedup (#6779)

* Add minor speedup changes.

* inline clip

* reduce clip calls

* more Interval::size - move to header

* terminate instead of throwing exception

* back to throw exception when element type was not found

* rename variable
This commit is contained in:
Patryk Elszkowski 2021-07-30 07:59:36 +02:00 committed by GitHub
parent 7ab92b5845
commit 0861a5c910
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 119 additions and 119 deletions

View File

@ -28,7 +28,7 @@ namespace ngraph
}); });
return rc; return rc;
}; };
for (auto p : get().m_string_enums) for (const auto& p : get().m_string_enums)
{ {
if (to_lower(p.first) == to_lower(name)) if (to_lower(p.first) == to_lower(name))
{ {
@ -41,7 +41,7 @@ namespace ngraph
/// Converts enum values to strings /// Converts enum values to strings
static const std::string& as_string(EnumType e) static const std::string& as_string(EnumType e)
{ {
for (auto& p : get().m_string_enums) for (const auto& p : get().m_string_enums)
{ {
if (p.second == e) if (p.second == e)
{ {

View File

@ -41,9 +41,16 @@ namespace ngraph
Interval& operator=(const Interval& interval) = default; Interval& operator=(const Interval& interval) = default;
/// \brief The number of elements in the interval. Zero if max < min. /// \brief The number of elements in the interval. Zero if max < min.
size_type size() const; size_type size() const
{
if (m_max_val == s_max)
{
return m_min_val == s_max ? 0 : s_max;
}
return m_max_val - m_min_val + 1;
}
/// \brief Returns true if the interval has no elements /// \brief Returns true if the interval has no elements
bool empty() const; bool empty() const { return m_min_val == s_max; }
/// \brief the inclusive lower bound of the interval /// \brief the inclusive lower bound of the interval
value_type get_min_val() const { return m_min_val; } value_type get_min_val() const { return m_min_val; }
/// \brief Set the inclusive lower bound of the interval /// \brief Set the inclusive lower bound of the interval
@ -84,7 +91,7 @@ namespace ngraph
Interval& operator&=(const Interval& interval); Interval& operator&=(const Interval& interval);
/// \brief True if this interval includes value /// \brief True if this interval includes value
bool contains(value_type value) const; bool contains(value_type value) const { return m_min_val <= value && value <= m_max_val; }
/// \brief True if this interval includes all the values in interval /// \brief True if this interval includes all the values in interval
bool contains(const Interval& interval) const; bool contains(const Interval& interval) const;
@ -93,10 +100,6 @@ namespace ngraph
protected: protected:
void canonicalize(); void canonicalize();
static value_type clip(value_type value);
static value_type clip_times(value_type a, value_type b);
static value_type clip_add(value_type a, value_type b);
static value_type clip_minus(value_type a, value_type b);
value_type m_min_val{0}; value_type m_min_val{0};
value_type m_max_val{s_max}; value_type m_max_val{s_max};

View File

@ -54,7 +54,7 @@ namespace ngraph
/// \brief Constructs a PartialShape with static rank from a vector of Dimension. /// \brief Constructs a PartialShape with static rank from a vector of Dimension.
/// \param dimensions The Dimension values for the constructed shape. /// \param dimensions The Dimension values for the constructed shape.
PartialShape(const std::vector<Dimension>& dimensions); PartialShape(std::vector<Dimension> dimensions);
/// \brief Constructs a PartialShape with static rank from a vector of dimensions values. /// \brief Constructs a PartialShape with static rank from a vector of dimensions values.
/// \param dimensions The Dimension values for the constructed shape. /// \param dimensions The Dimension values for the constructed shape.
@ -269,7 +269,7 @@ namespace ngraph
private: private:
// Private constructor for PartialShape::dynamic(). // Private constructor for PartialShape::dynamic().
PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions); PartialShape(bool rank_is_static, std::vector<Dimension> dimensions);
// True if the shape's rank is static. // True if the shape's rank is static.
bool m_rank_is_static; bool m_rank_is_static;

View File

@ -99,12 +99,12 @@ bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)
bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimension d2) bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimension d2)
{ {
if (d1.m_dimension.size() == 1 && d1.m_dimension.get_min_val() == 1) if (d1.m_dimension.get_min_val() == 1 && d1.m_dimension.size() == 1)
{ {
dst = d2; dst = d2;
return true; return true;
} }
if (d2.m_dimension.size() == 1 && d2.m_dimension.get_min_val() == 1) if (d2.m_dimension.get_min_val() == 1 && d2.m_dimension.size() == 1)
{ {
dst = d1; dst = d1;
return true; return true;

View File

@ -6,6 +6,46 @@
using namespace ngraph; using namespace ngraph;
namespace
{
Interval::value_type clip(Interval::value_type value)
{
return std::max(Interval::value_type(0), std::min(Interval::s_max, value));
}
Interval::value_type clip_times(Interval::value_type a, Interval::value_type b)
{
if (a == 0 || b == 0)
{
return 0;
}
else if (a == Interval::s_max || b == Interval::s_max)
{
return Interval::s_max;
}
else
{
return a * b;
}
}
Interval::value_type clip_add(Interval::value_type a, Interval::value_type b)
{
return (a == Interval::s_max || b == Interval::s_max) ? Interval::s_max : a + b;
}
Interval::value_type clip_minus(Interval::value_type a, Interval::value_type b)
{
if (a <= b)
{
return 0;
}
if (a == Interval::s_max)
{
return Interval::s_max;
}
return a - b;
}
} // namespace
void Interval::canonicalize() void Interval::canonicalize()
{ {
if (m_max_val < m_min_val) if (m_max_val < m_min_val)
@ -28,22 +68,9 @@ Interval::Interval(value_type min_val, value_type max_val)
} }
Interval::Interval(value_type val) Interval::Interval(value_type val)
: Interval(val, val)
{ {
} m_min_val = clip(val);
m_max_val = m_min_val;
Interval::size_type Interval::size() const
{
if (m_max_val == s_max)
{
return m_min_val == s_max ? 0 : s_max;
}
return m_max_val - m_min_val + 1;
}
bool Interval::empty() const
{
return m_min_val == s_max;
} }
bool Interval::operator==(const Interval& interval) const bool Interval::operator==(const Interval& interval) const
@ -116,55 +143,11 @@ Interval& Interval::operator&=(const Interval& interval)
return *this = *this & interval; return *this = *this & interval;
} }
bool Interval::contains(value_type value) const
{
return m_min_val <= value && value <= m_max_val;
}
bool Interval::contains(const Interval& interval) const bool Interval::contains(const Interval& interval) const
{ {
return contains(interval.m_min_val) && contains(interval.m_max_val); return contains(interval.m_min_val) && contains(interval.m_max_val);
} }
Interval::value_type Interval::clip(value_type value)
{
return std::max(value_type(0), std::min(s_max, value));
}
Interval::value_type Interval::clip_add(value_type a, value_type b)
{
return (a == s_max || b == s_max) ? s_max : a + b;
}
Interval::value_type Interval::clip_minus(value_type a, value_type b)
{
if (a <= b)
{
return 0;
}
if (a == s_max)
{
return s_max;
}
return a - b;
}
Interval::value_type Interval::clip_times(value_type a, value_type b)
{
if (a == 0 || b == 0)
{
return 0;
}
else if (a == s_max || b == s_max)
{
return s_max;
}
else
{
return a * b;
}
}
constexpr Interval::value_type Interval::s_max; constexpr Interval::value_type Interval::s_max;
namespace ngraph namespace ngraph

View File

@ -34,15 +34,15 @@ PartialShape::PartialShape(const Shape& shape)
{ {
} }
PartialShape::PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions) PartialShape::PartialShape(bool rank_is_static, std::vector<Dimension> dimensions)
: m_rank_is_static(rank_is_static) : m_rank_is_static(rank_is_static)
, m_dimensions(dimensions) , m_dimensions(std::move(dimensions))
{ {
} }
PartialShape::PartialShape(const std::vector<Dimension>& dimensions) PartialShape::PartialShape(std::vector<Dimension> dimensions)
: m_rank_is_static(true) : m_rank_is_static(true)
, m_dimensions(dimensions) , m_dimensions(std::move(dimensions))
{ {
} }
@ -387,7 +387,7 @@ bool PartialShape::broadcast_merge_into(PartialShape& dst,
i < (new_rank - src_rank) ? Dimension(1) : src[i - (new_rank - src_rank)]; i < (new_rank - src_rank) ? Dimension(1) : src[i - (new_rank - src_rank)];
success &= Dimension::broadcast_merge(dims[i], dsti, srci); success &= Dimension::broadcast_merge(dims[i], dsti, srci);
} }
dst = PartialShape(dims); dst = PartialShape(std::move(dims));
return success; return success;
} }
} }

View File

@ -12,45 +12,47 @@
#include "ngraph/type/element_type_traits.hpp" #include "ngraph/type/element_type_traits.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std;
constexpr DiscreteTypeInfo AttributeAdapter<element::Type>::type_info; constexpr DiscreteTypeInfo AttributeAdapter<element::Type>::type_info;
namespace
class TypeInfo
{ {
public: class TypeInfo
TypeInfo(size_t bitwidth,
bool is_real,
bool is_signed,
bool is_quantized,
const std::string& cname,
const std::string& type_name)
: m_bitwidth{bitwidth}
, m_is_real{is_real}
, m_is_signed{is_signed}
, m_is_quantized{is_quantized}
, m_cname{cname}
, m_type_name{type_name}
{ {
} public:
size_t m_bitwidth; TypeInfo(size_t bitwidth,
bool m_is_real; bool is_real,
bool m_is_signed; bool is_signed,
bool m_is_quantized; bool is_quantized,
std::string m_cname; const std::string& cname,
std::string m_type_name; const std::string& type_name)
}; : m_bitwidth{bitwidth}
, m_is_real{is_real}
, m_is_signed{is_signed}
, m_is_quantized{is_quantized}
, m_cname{cname}
, m_type_name{type_name}
{
}
size_t m_bitwidth;
bool m_is_real;
bool m_is_signed;
bool m_is_quantized;
std::string m_cname;
std::string m_type_name;
};
struct element_type_hash struct ElementTypes
{ {
size_t operator()(element::Type_t t) const { return static_cast<size_t>(t); } struct TypeHash
}; {
size_t operator()(element::Type_t t) const { return static_cast<size_t>(t); }
};
typedef unordered_map<element::Type_t, const TypeInfo, element_type_hash> element_types_map_t; using ElementsMap = std::unordered_map<element::Type_t, TypeInfo, TypeHash>;
static const ElementsMap elements_map;
};
static const element_types_map_t& get_type_info_map() const ElementTypes::ElementsMap ElementTypes::elements_map{
{
static element_types_map_t s_type_info_map{
{element::Type_t::undefined, {element::Type_t::undefined,
TypeInfo( TypeInfo(
std::numeric_limits<size_t>::max(), false, false, false, "undefined", "undefined")}, std::numeric_limits<size_t>::max(), false, false, false, "undefined", "undefined")},
@ -72,8 +74,20 @@ static const element_types_map_t& get_type_info_map()
{element::Type_t::u32, TypeInfo(32, false, false, false, "uint32_t", "u32")}, {element::Type_t::u32, TypeInfo(32, false, false, false, "uint32_t", "u32")},
{element::Type_t::u64, TypeInfo(64, false, false, false, "uint64_t", "u64")}, {element::Type_t::u64, TypeInfo(64, false, false, false, "uint64_t", "u64")},
}; };
return s_type_info_map;
}; const ElementTypes::ElementsMap& get_type_info_map() { return ElementTypes::elements_map; };
const TypeInfo& get_type_info(element::Type_t type)
{
const auto& tim = get_type_info_map();
const auto& found = tim.find(type);
if (found == tim.end())
{
throw std::out_of_range{"element::Type_t not supported"};
}
return found->second;
};
} // namespace
std::vector<const element::Type*> element::Type::get_known_types() std::vector<const element::Type*> element::Type::get_known_types()
{ {
@ -103,7 +117,7 @@ element::Type::Type(size_t bitwidth,
bool is_quantized, bool is_quantized,
const std::string& /* cname */) const std::string& /* cname */)
{ {
for (auto& t : get_type_info_map()) for (const auto& t : get_type_info_map())
{ {
const TypeInfo& info = t.second; const TypeInfo& info = t.second;
if (bitwidth == info.m_bitwidth && is_real == info.m_is_real && if (bitwidth == info.m_bitwidth && is_real == info.m_is_real &&
@ -117,7 +131,7 @@ element::Type::Type(size_t bitwidth,
const std::string& element::Type::c_type_string() const const std::string& element::Type::c_type_string() const
{ {
return get_type_info_map().at(m_type).m_cname; return get_type_info(m_type).m_cname;
} }
size_t element::Type::size() const size_t element::Type::size() const
@ -132,7 +146,7 @@ size_t element::Type::hash() const
const std::string& element::Type::get_type_name() const const std::string& element::Type::get_type_name() const
{ {
return get_type_info_map().at(m_type).m_type_name; return get_type_info(m_type).m_type_name;
} }
namespace ngraph namespace ngraph
@ -247,12 +261,12 @@ bool element::Type::merge(element::Type& dst, const element::Type& t1, const ele
bool element::Type::is_static() const bool element::Type::is_static() const
{ {
return get_type_info_map().at(m_type).m_bitwidth != 0; return get_type_info(m_type).m_bitwidth != 0;
} }
bool element::Type::is_real() const bool element::Type::is_real() const
{ {
return get_type_info_map().at(m_type).m_is_real; return get_type_info(m_type).m_is_real;
} }
bool element::Type::is_integral_number() const bool element::Type::is_integral_number() const
@ -262,17 +276,17 @@ bool element::Type::is_integral_number() const
bool element::Type::is_signed() const bool element::Type::is_signed() const
{ {
return get_type_info_map().at(m_type).m_is_signed; return get_type_info(m_type).m_is_signed;
} }
bool element::Type::is_quantized() const bool element::Type::is_quantized() const
{ {
return get_type_info_map().at(m_type).m_is_quantized; return get_type_info(m_type).m_is_quantized;
} }
size_t element::Type::bitwidth() const size_t element::Type::bitwidth() const
{ {
return get_type_info_map().at(m_type).m_bitwidth; return get_type_info(m_type).m_bitwidth;
} }
size_t ngraph::compiler_byte_size(element::Type_t et) size_t ngraph::compiler_byte_size(element::Type_t et)