Validate speedup (#6779)
* Add minor speedup changes. * inline clip * reduce clip calls * more Interval::size - move to header * terminate instead of throwing exception * back to throw exception when element type was not found * rename variable
This commit is contained in:
parent
7ab92b5845
commit
0861a5c910
@ -28,7 +28,7 @@ namespace ngraph
|
||||
});
|
||||
return rc;
|
||||
};
|
||||
for (auto p : get().m_string_enums)
|
||||
for (const auto& p : get().m_string_enums)
|
||||
{
|
||||
if (to_lower(p.first) == to_lower(name))
|
||||
{
|
||||
@ -41,7 +41,7 @@ namespace ngraph
|
||||
/// Converts enum values to strings
|
||||
static const std::string& as_string(EnumType e)
|
||||
{
|
||||
for (auto& p : get().m_string_enums)
|
||||
for (const auto& p : get().m_string_enums)
|
||||
{
|
||||
if (p.second == e)
|
||||
{
|
||||
|
@ -41,9 +41,16 @@ namespace ngraph
|
||||
Interval& operator=(const Interval& interval) = default;
|
||||
|
||||
/// \brief The number of elements in the interval. Zero if max < min.
|
||||
size_type size() const;
|
||||
size_type size() const
|
||||
{
|
||||
if (m_max_val == s_max)
|
||||
{
|
||||
return m_min_val == s_max ? 0 : s_max;
|
||||
}
|
||||
return m_max_val - m_min_val + 1;
|
||||
}
|
||||
/// \brief Returns true if the interval has no elements
|
||||
bool empty() const;
|
||||
bool empty() const { return m_min_val == s_max; }
|
||||
/// \brief the inclusive lower bound of the interval
|
||||
value_type get_min_val() const { return m_min_val; }
|
||||
/// \brief Set the inclusive lower bound of the interval
|
||||
@ -84,7 +91,7 @@ namespace ngraph
|
||||
Interval& operator&=(const Interval& interval);
|
||||
|
||||
/// \brief True if this interval includes value
|
||||
bool contains(value_type value) const;
|
||||
bool contains(value_type value) const { return m_min_val <= value && value <= m_max_val; }
|
||||
/// \brief True if this interval includes all the values in interval
|
||||
bool contains(const Interval& interval) const;
|
||||
|
||||
@ -93,10 +100,6 @@ namespace ngraph
|
||||
|
||||
protected:
|
||||
void canonicalize();
|
||||
static value_type clip(value_type value);
|
||||
static value_type clip_times(value_type a, value_type b);
|
||||
static value_type clip_add(value_type a, value_type b);
|
||||
static value_type clip_minus(value_type a, value_type b);
|
||||
|
||||
value_type m_min_val{0};
|
||||
value_type m_max_val{s_max};
|
||||
|
@ -54,7 +54,7 @@ namespace ngraph
|
||||
|
||||
/// \brief Constructs a PartialShape with static rank from a vector of Dimension.
|
||||
/// \param dimensions The Dimension values for the constructed shape.
|
||||
PartialShape(const std::vector<Dimension>& dimensions);
|
||||
PartialShape(std::vector<Dimension> dimensions);
|
||||
|
||||
/// \brief Constructs a PartialShape with static rank from a vector of dimensions values.
|
||||
/// \param dimensions The Dimension values for the constructed shape.
|
||||
@ -269,7 +269,7 @@ namespace ngraph
|
||||
|
||||
private:
|
||||
// Private constructor for PartialShape::dynamic().
|
||||
PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions);
|
||||
PartialShape(bool rank_is_static, std::vector<Dimension> dimensions);
|
||||
|
||||
// True if the shape's rank is static.
|
||||
bool m_rank_is_static;
|
||||
|
@ -99,12 +99,12 @@ bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)
|
||||
|
||||
bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimension d2)
|
||||
{
|
||||
if (d1.m_dimension.size() == 1 && d1.m_dimension.get_min_val() == 1)
|
||||
if (d1.m_dimension.get_min_val() == 1 && d1.m_dimension.size() == 1)
|
||||
{
|
||||
dst = d2;
|
||||
return true;
|
||||
}
|
||||
if (d2.m_dimension.size() == 1 && d2.m_dimension.get_min_val() == 1)
|
||||
if (d2.m_dimension.get_min_val() == 1 && d2.m_dimension.size() == 1)
|
||||
{
|
||||
dst = d1;
|
||||
return true;
|
||||
|
@ -6,6 +6,46 @@
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
namespace
|
||||
{
|
||||
Interval::value_type clip(Interval::value_type value)
|
||||
{
|
||||
return std::max(Interval::value_type(0), std::min(Interval::s_max, value));
|
||||
}
|
||||
|
||||
Interval::value_type clip_times(Interval::value_type a, Interval::value_type b)
|
||||
{
|
||||
if (a == 0 || b == 0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else if (a == Interval::s_max || b == Interval::s_max)
|
||||
{
|
||||
return Interval::s_max;
|
||||
}
|
||||
else
|
||||
{
|
||||
return a * b;
|
||||
}
|
||||
}
|
||||
Interval::value_type clip_add(Interval::value_type a, Interval::value_type b)
|
||||
{
|
||||
return (a == Interval::s_max || b == Interval::s_max) ? Interval::s_max : a + b;
|
||||
}
|
||||
Interval::value_type clip_minus(Interval::value_type a, Interval::value_type b)
|
||||
{
|
||||
if (a <= b)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
if (a == Interval::s_max)
|
||||
{
|
||||
return Interval::s_max;
|
||||
}
|
||||
return a - b;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void Interval::canonicalize()
|
||||
{
|
||||
if (m_max_val < m_min_val)
|
||||
@ -28,22 +68,9 @@ Interval::Interval(value_type min_val, value_type max_val)
|
||||
}
|
||||
|
||||
Interval::Interval(value_type val)
|
||||
: Interval(val, val)
|
||||
{
|
||||
}
|
||||
|
||||
Interval::size_type Interval::size() const
|
||||
{
|
||||
if (m_max_val == s_max)
|
||||
{
|
||||
return m_min_val == s_max ? 0 : s_max;
|
||||
}
|
||||
return m_max_val - m_min_val + 1;
|
||||
}
|
||||
|
||||
bool Interval::empty() const
|
||||
{
|
||||
return m_min_val == s_max;
|
||||
m_min_val = clip(val);
|
||||
m_max_val = m_min_val;
|
||||
}
|
||||
|
||||
bool Interval::operator==(const Interval& interval) const
|
||||
@ -116,55 +143,11 @@ Interval& Interval::operator&=(const Interval& interval)
|
||||
return *this = *this & interval;
|
||||
}
|
||||
|
||||
bool Interval::contains(value_type value) const
|
||||
{
|
||||
return m_min_val <= value && value <= m_max_val;
|
||||
}
|
||||
|
||||
bool Interval::contains(const Interval& interval) const
|
||||
{
|
||||
return contains(interval.m_min_val) && contains(interval.m_max_val);
|
||||
}
|
||||
|
||||
Interval::value_type Interval::clip(value_type value)
|
||||
{
|
||||
return std::max(value_type(0), std::min(s_max, value));
|
||||
}
|
||||
|
||||
Interval::value_type Interval::clip_add(value_type a, value_type b)
|
||||
{
|
||||
return (a == s_max || b == s_max) ? s_max : a + b;
|
||||
}
|
||||
|
||||
Interval::value_type Interval::clip_minus(value_type a, value_type b)
|
||||
{
|
||||
if (a <= b)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
if (a == s_max)
|
||||
{
|
||||
return s_max;
|
||||
}
|
||||
return a - b;
|
||||
}
|
||||
|
||||
Interval::value_type Interval::clip_times(value_type a, value_type b)
|
||||
{
|
||||
if (a == 0 || b == 0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else if (a == s_max || b == s_max)
|
||||
{
|
||||
return s_max;
|
||||
}
|
||||
else
|
||||
{
|
||||
return a * b;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr Interval::value_type Interval::s_max;
|
||||
|
||||
namespace ngraph
|
||||
|
@ -34,15 +34,15 @@ PartialShape::PartialShape(const Shape& shape)
|
||||
{
|
||||
}
|
||||
|
||||
PartialShape::PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions)
|
||||
PartialShape::PartialShape(bool rank_is_static, std::vector<Dimension> dimensions)
|
||||
: m_rank_is_static(rank_is_static)
|
||||
, m_dimensions(dimensions)
|
||||
, m_dimensions(std::move(dimensions))
|
||||
{
|
||||
}
|
||||
|
||||
PartialShape::PartialShape(const std::vector<Dimension>& dimensions)
|
||||
PartialShape::PartialShape(std::vector<Dimension> dimensions)
|
||||
: m_rank_is_static(true)
|
||||
, m_dimensions(dimensions)
|
||||
, m_dimensions(std::move(dimensions))
|
||||
{
|
||||
}
|
||||
|
||||
@ -387,7 +387,7 @@ bool PartialShape::broadcast_merge_into(PartialShape& dst,
|
||||
i < (new_rank - src_rank) ? Dimension(1) : src[i - (new_rank - src_rank)];
|
||||
success &= Dimension::broadcast_merge(dims[i], dsti, srci);
|
||||
}
|
||||
dst = PartialShape(dims);
|
||||
dst = PartialShape(std::move(dims));
|
||||
return success;
|
||||
}
|
||||
}
|
||||
|
@ -12,13 +12,13 @@
|
||||
#include "ngraph/type/element_type_traits.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace std;
|
||||
|
||||
constexpr DiscreteTypeInfo AttributeAdapter<element::Type>::type_info;
|
||||
|
||||
class TypeInfo
|
||||
namespace
|
||||
{
|
||||
public:
|
||||
class TypeInfo
|
||||
{
|
||||
public:
|
||||
TypeInfo(size_t bitwidth,
|
||||
bool is_real,
|
||||
bool is_signed,
|
||||
@ -39,18 +39,20 @@ public:
|
||||
bool m_is_quantized;
|
||||
std::string m_cname;
|
||||
std::string m_type_name;
|
||||
};
|
||||
};
|
||||
|
||||
struct element_type_hash
|
||||
{
|
||||
struct ElementTypes
|
||||
{
|
||||
struct TypeHash
|
||||
{
|
||||
size_t operator()(element::Type_t t) const { return static_cast<size_t>(t); }
|
||||
};
|
||||
};
|
||||
|
||||
typedef unordered_map<element::Type_t, const TypeInfo, element_type_hash> element_types_map_t;
|
||||
using ElementsMap = std::unordered_map<element::Type_t, TypeInfo, TypeHash>;
|
||||
static const ElementsMap elements_map;
|
||||
};
|
||||
|
||||
static const element_types_map_t& get_type_info_map()
|
||||
{
|
||||
static element_types_map_t s_type_info_map{
|
||||
const ElementTypes::ElementsMap ElementTypes::elements_map{
|
||||
{element::Type_t::undefined,
|
||||
TypeInfo(
|
||||
std::numeric_limits<size_t>::max(), false, false, false, "undefined", "undefined")},
|
||||
@ -72,8 +74,20 @@ static const element_types_map_t& get_type_info_map()
|
||||
{element::Type_t::u32, TypeInfo(32, false, false, false, "uint32_t", "u32")},
|
||||
{element::Type_t::u64, TypeInfo(64, false, false, false, "uint64_t", "u64")},
|
||||
};
|
||||
return s_type_info_map;
|
||||
};
|
||||
|
||||
const ElementTypes::ElementsMap& get_type_info_map() { return ElementTypes::elements_map; };
|
||||
|
||||
const TypeInfo& get_type_info(element::Type_t type)
|
||||
{
|
||||
const auto& tim = get_type_info_map();
|
||||
const auto& found = tim.find(type);
|
||||
if (found == tim.end())
|
||||
{
|
||||
throw std::out_of_range{"element::Type_t not supported"};
|
||||
}
|
||||
return found->second;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::vector<const element::Type*> element::Type::get_known_types()
|
||||
{
|
||||
@ -103,7 +117,7 @@ element::Type::Type(size_t bitwidth,
|
||||
bool is_quantized,
|
||||
const std::string& /* cname */)
|
||||
{
|
||||
for (auto& t : get_type_info_map())
|
||||
for (const auto& t : get_type_info_map())
|
||||
{
|
||||
const TypeInfo& info = t.second;
|
||||
if (bitwidth == info.m_bitwidth && is_real == info.m_is_real &&
|
||||
@ -117,7 +131,7 @@ element::Type::Type(size_t bitwidth,
|
||||
|
||||
const std::string& element::Type::c_type_string() const
|
||||
{
|
||||
return get_type_info_map().at(m_type).m_cname;
|
||||
return get_type_info(m_type).m_cname;
|
||||
}
|
||||
|
||||
size_t element::Type::size() const
|
||||
@ -132,7 +146,7 @@ size_t element::Type::hash() const
|
||||
|
||||
const std::string& element::Type::get_type_name() const
|
||||
{
|
||||
return get_type_info_map().at(m_type).m_type_name;
|
||||
return get_type_info(m_type).m_type_name;
|
||||
}
|
||||
|
||||
namespace ngraph
|
||||
@ -247,12 +261,12 @@ bool element::Type::merge(element::Type& dst, const element::Type& t1, const ele
|
||||
|
||||
bool element::Type::is_static() const
|
||||
{
|
||||
return get_type_info_map().at(m_type).m_bitwidth != 0;
|
||||
return get_type_info(m_type).m_bitwidth != 0;
|
||||
}
|
||||
|
||||
bool element::Type::is_real() const
|
||||
{
|
||||
return get_type_info_map().at(m_type).m_is_real;
|
||||
return get_type_info(m_type).m_is_real;
|
||||
}
|
||||
|
||||
bool element::Type::is_integral_number() const
|
||||
@ -262,17 +276,17 @@ bool element::Type::is_integral_number() const
|
||||
|
||||
bool element::Type::is_signed() const
|
||||
{
|
||||
return get_type_info_map().at(m_type).m_is_signed;
|
||||
return get_type_info(m_type).m_is_signed;
|
||||
}
|
||||
|
||||
bool element::Type::is_quantized() const
|
||||
{
|
||||
return get_type_info_map().at(m_type).m_is_quantized;
|
||||
return get_type_info(m_type).m_is_quantized;
|
||||
}
|
||||
|
||||
size_t element::Type::bitwidth() const
|
||||
{
|
||||
return get_type_info_map().at(m_type).m_bitwidth;
|
||||
return get_type_info(m_type).m_bitwidth;
|
||||
}
|
||||
|
||||
size_t ngraph::compiler_byte_size(element::Type_t et)
|
||||
|
Loading…
Reference in New Issue
Block a user