Validate speedup (#6779)

* Add minor speedup changes.

* inline clip

* reduce clip calls

* more Interval::size - move to header

* terminate instead of throwing exception

* back to throw exception when element type was not found

* rename variable
This commit is contained in:
Patryk Elszkowski 2021-07-30 07:59:36 +02:00 committed by GitHub
parent 7ab92b5845
commit 0861a5c910
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 119 additions and 119 deletions

View File

@ -28,7 +28,7 @@ namespace ngraph
});
return rc;
};
for (auto p : get().m_string_enums)
for (const auto& p : get().m_string_enums)
{
if (to_lower(p.first) == to_lower(name))
{
@ -41,7 +41,7 @@ namespace ngraph
/// Converts enum values to strings
static const std::string& as_string(EnumType e)
{
for (auto& p : get().m_string_enums)
for (const auto& p : get().m_string_enums)
{
if (p.second == e)
{

View File

@ -41,9 +41,16 @@ namespace ngraph
Interval& operator=(const Interval& interval) = default;
/// \brief The number of elements in the interval. Zero if max < min.
size_type size() const;
size_type size() const
{
if (m_max_val == s_max)
{
return m_min_val == s_max ? 0 : s_max;
}
return m_max_val - m_min_val + 1;
}
/// \brief Returns true if the interval has no elements
bool empty() const;
bool empty() const { return m_min_val == s_max; }
/// \brief the inclusive lower bound of the interval
value_type get_min_val() const { return m_min_val; }
/// \brief Set the inclusive lower bound of the interval
@ -84,7 +91,7 @@ namespace ngraph
Interval& operator&=(const Interval& interval);
/// \brief True if this interval includes value
bool contains(value_type value) const;
bool contains(value_type value) const { return m_min_val <= value && value <= m_max_val; }
/// \brief True if this interval includes all the values in interval
bool contains(const Interval& interval) const;
@ -93,10 +100,6 @@ namespace ngraph
protected:
void canonicalize();
static value_type clip(value_type value);
static value_type clip_times(value_type a, value_type b);
static value_type clip_add(value_type a, value_type b);
static value_type clip_minus(value_type a, value_type b);
value_type m_min_val{0};
value_type m_max_val{s_max};

View File

@ -54,7 +54,7 @@ namespace ngraph
/// \brief Constructs a PartialShape with static rank from a vector of Dimension.
/// \param dimensions The Dimension values for the constructed shape.
PartialShape(const std::vector<Dimension>& dimensions);
PartialShape(std::vector<Dimension> dimensions);
/// \brief Constructs a PartialShape with static rank from a vector of dimensions values.
/// \param dimensions The Dimension values for the constructed shape.
@ -269,7 +269,7 @@ namespace ngraph
private:
// Private constructor for PartialShape::dynamic().
PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions);
PartialShape(bool rank_is_static, std::vector<Dimension> dimensions);
// True if the shape's rank is static.
bool m_rank_is_static;

View File

@ -99,12 +99,12 @@ bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)
bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimension d2)
{
if (d1.m_dimension.size() == 1 && d1.m_dimension.get_min_val() == 1)
if (d1.m_dimension.get_min_val() == 1 && d1.m_dimension.size() == 1)
{
dst = d2;
return true;
}
if (d2.m_dimension.size() == 1 && d2.m_dimension.get_min_val() == 1)
if (d2.m_dimension.get_min_val() == 1 && d2.m_dimension.size() == 1)
{
dst = d1;
return true;

View File

@ -6,6 +6,46 @@
using namespace ngraph;
namespace
{
Interval::value_type clip(Interval::value_type value)
{
return std::max(Interval::value_type(0), std::min(Interval::s_max, value));
}
Interval::value_type clip_times(Interval::value_type a, Interval::value_type b)
{
if (a == 0 || b == 0)
{
return 0;
}
else if (a == Interval::s_max || b == Interval::s_max)
{
return Interval::s_max;
}
else
{
return a * b;
}
}
Interval::value_type clip_add(Interval::value_type a, Interval::value_type b)
{
return (a == Interval::s_max || b == Interval::s_max) ? Interval::s_max : a + b;
}
Interval::value_type clip_minus(Interval::value_type a, Interval::value_type b)
{
if (a <= b)
{
return 0;
}
if (a == Interval::s_max)
{
return Interval::s_max;
}
return a - b;
}
} // namespace
void Interval::canonicalize()
{
if (m_max_val < m_min_val)
@ -28,22 +68,9 @@ Interval::Interval(value_type min_val, value_type max_val)
}
Interval::Interval(value_type val)
: Interval(val, val)
{
}
Interval::size_type Interval::size() const
{
if (m_max_val == s_max)
{
return m_min_val == s_max ? 0 : s_max;
}
return m_max_val - m_min_val + 1;
}
bool Interval::empty() const
{
return m_min_val == s_max;
m_min_val = clip(val);
m_max_val = m_min_val;
}
bool Interval::operator==(const Interval& interval) const
@ -116,55 +143,11 @@ Interval& Interval::operator&=(const Interval& interval)
return *this = *this & interval;
}
bool Interval::contains(value_type value) const
{
return m_min_val <= value && value <= m_max_val;
}
bool Interval::contains(const Interval& interval) const
{
return contains(interval.m_min_val) && contains(interval.m_max_val);
}
Interval::value_type Interval::clip(value_type value)
{
return std::max(value_type(0), std::min(s_max, value));
}
Interval::value_type Interval::clip_add(value_type a, value_type b)
{
return (a == s_max || b == s_max) ? s_max : a + b;
}
Interval::value_type Interval::clip_minus(value_type a, value_type b)
{
if (a <= b)
{
return 0;
}
if (a == s_max)
{
return s_max;
}
return a - b;
}
Interval::value_type Interval::clip_times(value_type a, value_type b)
{
if (a == 0 || b == 0)
{
return 0;
}
else if (a == s_max || b == s_max)
{
return s_max;
}
else
{
return a * b;
}
}
constexpr Interval::value_type Interval::s_max;
namespace ngraph

View File

@ -34,15 +34,15 @@ PartialShape::PartialShape(const Shape& shape)
{
}
PartialShape::PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions)
PartialShape::PartialShape(bool rank_is_static, std::vector<Dimension> dimensions)
: m_rank_is_static(rank_is_static)
, m_dimensions(dimensions)
, m_dimensions(std::move(dimensions))
{
}
PartialShape::PartialShape(const std::vector<Dimension>& dimensions)
PartialShape::PartialShape(std::vector<Dimension> dimensions)
: m_rank_is_static(true)
, m_dimensions(dimensions)
, m_dimensions(std::move(dimensions))
{
}
@ -387,7 +387,7 @@ bool PartialShape::broadcast_merge_into(PartialShape& dst,
i < (new_rank - src_rank) ? Dimension(1) : src[i - (new_rank - src_rank)];
success &= Dimension::broadcast_merge(dims[i], dsti, srci);
}
dst = PartialShape(dims);
dst = PartialShape(std::move(dims));
return success;
}
}

View File

@ -12,13 +12,13 @@
#include "ngraph/type/element_type_traits.hpp"
using namespace ngraph;
using namespace std;
constexpr DiscreteTypeInfo AttributeAdapter<element::Type>::type_info;
class TypeInfo
namespace
{
public:
class TypeInfo
{
public:
TypeInfo(size_t bitwidth,
bool is_real,
bool is_signed,
@ -39,18 +39,20 @@ public:
bool m_is_quantized;
std::string m_cname;
std::string m_type_name;
};
};
struct element_type_hash
{
struct ElementTypes
{
struct TypeHash
{
size_t operator()(element::Type_t t) const { return static_cast<size_t>(t); }
};
};
typedef unordered_map<element::Type_t, const TypeInfo, element_type_hash> element_types_map_t;
using ElementsMap = std::unordered_map<element::Type_t, TypeInfo, TypeHash>;
static const ElementsMap elements_map;
};
static const element_types_map_t& get_type_info_map()
{
static element_types_map_t s_type_info_map{
const ElementTypes::ElementsMap ElementTypes::elements_map{
{element::Type_t::undefined,
TypeInfo(
std::numeric_limits<size_t>::max(), false, false, false, "undefined", "undefined")},
@ -72,8 +74,20 @@ static const element_types_map_t& get_type_info_map()
{element::Type_t::u32, TypeInfo(32, false, false, false, "uint32_t", "u32")},
{element::Type_t::u64, TypeInfo(64, false, false, false, "uint64_t", "u64")},
};
return s_type_info_map;
};
const ElementTypes::ElementsMap& get_type_info_map() { return ElementTypes::elements_map; };
const TypeInfo& get_type_info(element::Type_t type)
{
const auto& tim = get_type_info_map();
const auto& found = tim.find(type);
if (found == tim.end())
{
throw std::out_of_range{"element::Type_t not supported"};
}
return found->second;
};
} // namespace
std::vector<const element::Type*> element::Type::get_known_types()
{
@ -103,7 +117,7 @@ element::Type::Type(size_t bitwidth,
bool is_quantized,
const std::string& /* cname */)
{
for (auto& t : get_type_info_map())
for (const auto& t : get_type_info_map())
{
const TypeInfo& info = t.second;
if (bitwidth == info.m_bitwidth && is_real == info.m_is_real &&
@ -117,7 +131,7 @@ element::Type::Type(size_t bitwidth,
const std::string& element::Type::c_type_string() const
{
return get_type_info_map().at(m_type).m_cname;
return get_type_info(m_type).m_cname;
}
size_t element::Type::size() const
@ -132,7 +146,7 @@ size_t element::Type::hash() const
const std::string& element::Type::get_type_name() const
{
return get_type_info_map().at(m_type).m_type_name;
return get_type_info(m_type).m_type_name;
}
namespace ngraph
@ -247,12 +261,12 @@ bool element::Type::merge(element::Type& dst, const element::Type& t1, const ele
bool element::Type::is_static() const
{
return get_type_info_map().at(m_type).m_bitwidth != 0;
return get_type_info(m_type).m_bitwidth != 0;
}
bool element::Type::is_real() const
{
return get_type_info_map().at(m_type).m_is_real;
return get_type_info(m_type).m_is_real;
}
bool element::Type::is_integral_number() const
@ -262,17 +276,17 @@ bool element::Type::is_integral_number() const
bool element::Type::is_signed() const
{
return get_type_info_map().at(m_type).m_is_signed;
return get_type_info(m_type).m_is_signed;
}
bool element::Type::is_quantized() const
{
return get_type_info_map().at(m_type).m_is_quantized;
return get_type_info(m_type).m_is_quantized;
}
size_t element::Type::bitwidth() const
{
return get_type_info_map().at(m_type).m_bitwidth;
return get_type_info(m_type).m_bitwidth;
}
size_t ngraph::compiler_byte_size(element::Type_t et)