Validate speedup (#6779)
* Add minor speedup changes. * inline clip * reduce clip calls * more Interval::size - move to header * terminate instead of throwing exception * back to throw exception when element type was not found * rename variable
This commit is contained in:
parent
7ab92b5845
commit
0861a5c910
@ -28,7 +28,7 @@ namespace ngraph
|
|||||||
});
|
});
|
||||||
return rc;
|
return rc;
|
||||||
};
|
};
|
||||||
for (auto p : get().m_string_enums)
|
for (const auto& p : get().m_string_enums)
|
||||||
{
|
{
|
||||||
if (to_lower(p.first) == to_lower(name))
|
if (to_lower(p.first) == to_lower(name))
|
||||||
{
|
{
|
||||||
@ -41,7 +41,7 @@ namespace ngraph
|
|||||||
/// Converts enum values to strings
|
/// Converts enum values to strings
|
||||||
static const std::string& as_string(EnumType e)
|
static const std::string& as_string(EnumType e)
|
||||||
{
|
{
|
||||||
for (auto& p : get().m_string_enums)
|
for (const auto& p : get().m_string_enums)
|
||||||
{
|
{
|
||||||
if (p.second == e)
|
if (p.second == e)
|
||||||
{
|
{
|
||||||
|
@ -41,9 +41,16 @@ namespace ngraph
|
|||||||
Interval& operator=(const Interval& interval) = default;
|
Interval& operator=(const Interval& interval) = default;
|
||||||
|
|
||||||
/// \brief The number of elements in the interval. Zero if max < min.
|
/// \brief The number of elements in the interval. Zero if max < min.
|
||||||
size_type size() const;
|
size_type size() const
|
||||||
|
{
|
||||||
|
if (m_max_val == s_max)
|
||||||
|
{
|
||||||
|
return m_min_val == s_max ? 0 : s_max;
|
||||||
|
}
|
||||||
|
return m_max_val - m_min_val + 1;
|
||||||
|
}
|
||||||
/// \brief Returns true if the interval has no elements
|
/// \brief Returns true if the interval has no elements
|
||||||
bool empty() const;
|
bool empty() const { return m_min_val == s_max; }
|
||||||
/// \brief the inclusive lower bound of the interval
|
/// \brief the inclusive lower bound of the interval
|
||||||
value_type get_min_val() const { return m_min_val; }
|
value_type get_min_val() const { return m_min_val; }
|
||||||
/// \brief Set the inclusive lower bound of the interval
|
/// \brief Set the inclusive lower bound of the interval
|
||||||
@ -84,7 +91,7 @@ namespace ngraph
|
|||||||
Interval& operator&=(const Interval& interval);
|
Interval& operator&=(const Interval& interval);
|
||||||
|
|
||||||
/// \brief True if this interval includes value
|
/// \brief True if this interval includes value
|
||||||
bool contains(value_type value) const;
|
bool contains(value_type value) const { return m_min_val <= value && value <= m_max_val; }
|
||||||
/// \brief True if this interval includes all the values in interval
|
/// \brief True if this interval includes all the values in interval
|
||||||
bool contains(const Interval& interval) const;
|
bool contains(const Interval& interval) const;
|
||||||
|
|
||||||
@ -93,10 +100,6 @@ namespace ngraph
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
void canonicalize();
|
void canonicalize();
|
||||||
static value_type clip(value_type value);
|
|
||||||
static value_type clip_times(value_type a, value_type b);
|
|
||||||
static value_type clip_add(value_type a, value_type b);
|
|
||||||
static value_type clip_minus(value_type a, value_type b);
|
|
||||||
|
|
||||||
value_type m_min_val{0};
|
value_type m_min_val{0};
|
||||||
value_type m_max_val{s_max};
|
value_type m_max_val{s_max};
|
||||||
|
@ -54,7 +54,7 @@ namespace ngraph
|
|||||||
|
|
||||||
/// \brief Constructs a PartialShape with static rank from a vector of Dimension.
|
/// \brief Constructs a PartialShape with static rank from a vector of Dimension.
|
||||||
/// \param dimensions The Dimension values for the constructed shape.
|
/// \param dimensions The Dimension values for the constructed shape.
|
||||||
PartialShape(const std::vector<Dimension>& dimensions);
|
PartialShape(std::vector<Dimension> dimensions);
|
||||||
|
|
||||||
/// \brief Constructs a PartialShape with static rank from a vector of dimensions values.
|
/// \brief Constructs a PartialShape with static rank from a vector of dimensions values.
|
||||||
/// \param dimensions The Dimension values for the constructed shape.
|
/// \param dimensions The Dimension values for the constructed shape.
|
||||||
@ -269,7 +269,7 @@ namespace ngraph
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
// Private constructor for PartialShape::dynamic().
|
// Private constructor for PartialShape::dynamic().
|
||||||
PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions);
|
PartialShape(bool rank_is_static, std::vector<Dimension> dimensions);
|
||||||
|
|
||||||
// True if the shape's rank is static.
|
// True if the shape's rank is static.
|
||||||
bool m_rank_is_static;
|
bool m_rank_is_static;
|
||||||
|
@ -99,12 +99,12 @@ bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)
|
|||||||
|
|
||||||
bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimension d2)
|
bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimension d2)
|
||||||
{
|
{
|
||||||
if (d1.m_dimension.size() == 1 && d1.m_dimension.get_min_val() == 1)
|
if (d1.m_dimension.get_min_val() == 1 && d1.m_dimension.size() == 1)
|
||||||
{
|
{
|
||||||
dst = d2;
|
dst = d2;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (d2.m_dimension.size() == 1 && d2.m_dimension.get_min_val() == 1)
|
if (d2.m_dimension.get_min_val() == 1 && d2.m_dimension.size() == 1)
|
||||||
{
|
{
|
||||||
dst = d1;
|
dst = d1;
|
||||||
return true;
|
return true;
|
||||||
|
@ -6,6 +6,46 @@
|
|||||||
|
|
||||||
using namespace ngraph;
|
using namespace ngraph;
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
Interval::value_type clip(Interval::value_type value)
|
||||||
|
{
|
||||||
|
return std::max(Interval::value_type(0), std::min(Interval::s_max, value));
|
||||||
|
}
|
||||||
|
|
||||||
|
Interval::value_type clip_times(Interval::value_type a, Interval::value_type b)
|
||||||
|
{
|
||||||
|
if (a == 0 || b == 0)
|
||||||
|
{
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
else if (a == Interval::s_max || b == Interval::s_max)
|
||||||
|
{
|
||||||
|
return Interval::s_max;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
return a * b;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Interval::value_type clip_add(Interval::value_type a, Interval::value_type b)
|
||||||
|
{
|
||||||
|
return (a == Interval::s_max || b == Interval::s_max) ? Interval::s_max : a + b;
|
||||||
|
}
|
||||||
|
Interval::value_type clip_minus(Interval::value_type a, Interval::value_type b)
|
||||||
|
{
|
||||||
|
if (a <= b)
|
||||||
|
{
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
if (a == Interval::s_max)
|
||||||
|
{
|
||||||
|
return Interval::s_max;
|
||||||
|
}
|
||||||
|
return a - b;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
void Interval::canonicalize()
|
void Interval::canonicalize()
|
||||||
{
|
{
|
||||||
if (m_max_val < m_min_val)
|
if (m_max_val < m_min_val)
|
||||||
@ -28,22 +68,9 @@ Interval::Interval(value_type min_val, value_type max_val)
|
|||||||
}
|
}
|
||||||
|
|
||||||
Interval::Interval(value_type val)
|
Interval::Interval(value_type val)
|
||||||
: Interval(val, val)
|
|
||||||
{
|
{
|
||||||
}
|
m_min_val = clip(val);
|
||||||
|
m_max_val = m_min_val;
|
||||||
Interval::size_type Interval::size() const
|
|
||||||
{
|
|
||||||
if (m_max_val == s_max)
|
|
||||||
{
|
|
||||||
return m_min_val == s_max ? 0 : s_max;
|
|
||||||
}
|
|
||||||
return m_max_val - m_min_val + 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Interval::empty() const
|
|
||||||
{
|
|
||||||
return m_min_val == s_max;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Interval::operator==(const Interval& interval) const
|
bool Interval::operator==(const Interval& interval) const
|
||||||
@ -116,55 +143,11 @@ Interval& Interval::operator&=(const Interval& interval)
|
|||||||
return *this = *this & interval;
|
return *this = *this & interval;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Interval::contains(value_type value) const
|
|
||||||
{
|
|
||||||
return m_min_val <= value && value <= m_max_val;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Interval::contains(const Interval& interval) const
|
bool Interval::contains(const Interval& interval) const
|
||||||
{
|
{
|
||||||
return contains(interval.m_min_val) && contains(interval.m_max_val);
|
return contains(interval.m_min_val) && contains(interval.m_max_val);
|
||||||
}
|
}
|
||||||
|
|
||||||
Interval::value_type Interval::clip(value_type value)
|
|
||||||
{
|
|
||||||
return std::max(value_type(0), std::min(s_max, value));
|
|
||||||
}
|
|
||||||
|
|
||||||
Interval::value_type Interval::clip_add(value_type a, value_type b)
|
|
||||||
{
|
|
||||||
return (a == s_max || b == s_max) ? s_max : a + b;
|
|
||||||
}
|
|
||||||
|
|
||||||
Interval::value_type Interval::clip_minus(value_type a, value_type b)
|
|
||||||
{
|
|
||||||
if (a <= b)
|
|
||||||
{
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
if (a == s_max)
|
|
||||||
{
|
|
||||||
return s_max;
|
|
||||||
}
|
|
||||||
return a - b;
|
|
||||||
}
|
|
||||||
|
|
||||||
Interval::value_type Interval::clip_times(value_type a, value_type b)
|
|
||||||
{
|
|
||||||
if (a == 0 || b == 0)
|
|
||||||
{
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
else if (a == s_max || b == s_max)
|
|
||||||
{
|
|
||||||
return s_max;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
return a * b;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr Interval::value_type Interval::s_max;
|
constexpr Interval::value_type Interval::s_max;
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph
|
||||||
|
@ -34,15 +34,15 @@ PartialShape::PartialShape(const Shape& shape)
|
|||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
PartialShape::PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions)
|
PartialShape::PartialShape(bool rank_is_static, std::vector<Dimension> dimensions)
|
||||||
: m_rank_is_static(rank_is_static)
|
: m_rank_is_static(rank_is_static)
|
||||||
, m_dimensions(dimensions)
|
, m_dimensions(std::move(dimensions))
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
PartialShape::PartialShape(const std::vector<Dimension>& dimensions)
|
PartialShape::PartialShape(std::vector<Dimension> dimensions)
|
||||||
: m_rank_is_static(true)
|
: m_rank_is_static(true)
|
||||||
, m_dimensions(dimensions)
|
, m_dimensions(std::move(dimensions))
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -387,7 +387,7 @@ bool PartialShape::broadcast_merge_into(PartialShape& dst,
|
|||||||
i < (new_rank - src_rank) ? Dimension(1) : src[i - (new_rank - src_rank)];
|
i < (new_rank - src_rank) ? Dimension(1) : src[i - (new_rank - src_rank)];
|
||||||
success &= Dimension::broadcast_merge(dims[i], dsti, srci);
|
success &= Dimension::broadcast_merge(dims[i], dsti, srci);
|
||||||
}
|
}
|
||||||
dst = PartialShape(dims);
|
dst = PartialShape(std::move(dims));
|
||||||
return success;
|
return success;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -12,45 +12,47 @@
|
|||||||
#include "ngraph/type/element_type_traits.hpp"
|
#include "ngraph/type/element_type_traits.hpp"
|
||||||
|
|
||||||
using namespace ngraph;
|
using namespace ngraph;
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
constexpr DiscreteTypeInfo AttributeAdapter<element::Type>::type_info;
|
constexpr DiscreteTypeInfo AttributeAdapter<element::Type>::type_info;
|
||||||
|
namespace
|
||||||
class TypeInfo
|
|
||||||
{
|
{
|
||||||
public:
|
class TypeInfo
|
||||||
TypeInfo(size_t bitwidth,
|
|
||||||
bool is_real,
|
|
||||||
bool is_signed,
|
|
||||||
bool is_quantized,
|
|
||||||
const std::string& cname,
|
|
||||||
const std::string& type_name)
|
|
||||||
: m_bitwidth{bitwidth}
|
|
||||||
, m_is_real{is_real}
|
|
||||||
, m_is_signed{is_signed}
|
|
||||||
, m_is_quantized{is_quantized}
|
|
||||||
, m_cname{cname}
|
|
||||||
, m_type_name{type_name}
|
|
||||||
{
|
{
|
||||||
}
|
public:
|
||||||
size_t m_bitwidth;
|
TypeInfo(size_t bitwidth,
|
||||||
bool m_is_real;
|
bool is_real,
|
||||||
bool m_is_signed;
|
bool is_signed,
|
||||||
bool m_is_quantized;
|
bool is_quantized,
|
||||||
std::string m_cname;
|
const std::string& cname,
|
||||||
std::string m_type_name;
|
const std::string& type_name)
|
||||||
};
|
: m_bitwidth{bitwidth}
|
||||||
|
, m_is_real{is_real}
|
||||||
|
, m_is_signed{is_signed}
|
||||||
|
, m_is_quantized{is_quantized}
|
||||||
|
, m_cname{cname}
|
||||||
|
, m_type_name{type_name}
|
||||||
|
{
|
||||||
|
}
|
||||||
|
size_t m_bitwidth;
|
||||||
|
bool m_is_real;
|
||||||
|
bool m_is_signed;
|
||||||
|
bool m_is_quantized;
|
||||||
|
std::string m_cname;
|
||||||
|
std::string m_type_name;
|
||||||
|
};
|
||||||
|
|
||||||
struct element_type_hash
|
struct ElementTypes
|
||||||
{
|
{
|
||||||
size_t operator()(element::Type_t t) const { return static_cast<size_t>(t); }
|
struct TypeHash
|
||||||
};
|
{
|
||||||
|
size_t operator()(element::Type_t t) const { return static_cast<size_t>(t); }
|
||||||
|
};
|
||||||
|
|
||||||
typedef unordered_map<element::Type_t, const TypeInfo, element_type_hash> element_types_map_t;
|
using ElementsMap = std::unordered_map<element::Type_t, TypeInfo, TypeHash>;
|
||||||
|
static const ElementsMap elements_map;
|
||||||
|
};
|
||||||
|
|
||||||
static const element_types_map_t& get_type_info_map()
|
const ElementTypes::ElementsMap ElementTypes::elements_map{
|
||||||
{
|
|
||||||
static element_types_map_t s_type_info_map{
|
|
||||||
{element::Type_t::undefined,
|
{element::Type_t::undefined,
|
||||||
TypeInfo(
|
TypeInfo(
|
||||||
std::numeric_limits<size_t>::max(), false, false, false, "undefined", "undefined")},
|
std::numeric_limits<size_t>::max(), false, false, false, "undefined", "undefined")},
|
||||||
@ -72,8 +74,20 @@ static const element_types_map_t& get_type_info_map()
|
|||||||
{element::Type_t::u32, TypeInfo(32, false, false, false, "uint32_t", "u32")},
|
{element::Type_t::u32, TypeInfo(32, false, false, false, "uint32_t", "u32")},
|
||||||
{element::Type_t::u64, TypeInfo(64, false, false, false, "uint64_t", "u64")},
|
{element::Type_t::u64, TypeInfo(64, false, false, false, "uint64_t", "u64")},
|
||||||
};
|
};
|
||||||
return s_type_info_map;
|
|
||||||
};
|
const ElementTypes::ElementsMap& get_type_info_map() { return ElementTypes::elements_map; };
|
||||||
|
|
||||||
|
const TypeInfo& get_type_info(element::Type_t type)
|
||||||
|
{
|
||||||
|
const auto& tim = get_type_info_map();
|
||||||
|
const auto& found = tim.find(type);
|
||||||
|
if (found == tim.end())
|
||||||
|
{
|
||||||
|
throw std::out_of_range{"element::Type_t not supported"};
|
||||||
|
}
|
||||||
|
return found->second;
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
std::vector<const element::Type*> element::Type::get_known_types()
|
std::vector<const element::Type*> element::Type::get_known_types()
|
||||||
{
|
{
|
||||||
@ -103,7 +117,7 @@ element::Type::Type(size_t bitwidth,
|
|||||||
bool is_quantized,
|
bool is_quantized,
|
||||||
const std::string& /* cname */)
|
const std::string& /* cname */)
|
||||||
{
|
{
|
||||||
for (auto& t : get_type_info_map())
|
for (const auto& t : get_type_info_map())
|
||||||
{
|
{
|
||||||
const TypeInfo& info = t.second;
|
const TypeInfo& info = t.second;
|
||||||
if (bitwidth == info.m_bitwidth && is_real == info.m_is_real &&
|
if (bitwidth == info.m_bitwidth && is_real == info.m_is_real &&
|
||||||
@ -117,7 +131,7 @@ element::Type::Type(size_t bitwidth,
|
|||||||
|
|
||||||
const std::string& element::Type::c_type_string() const
|
const std::string& element::Type::c_type_string() const
|
||||||
{
|
{
|
||||||
return get_type_info_map().at(m_type).m_cname;
|
return get_type_info(m_type).m_cname;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t element::Type::size() const
|
size_t element::Type::size() const
|
||||||
@ -132,7 +146,7 @@ size_t element::Type::hash() const
|
|||||||
|
|
||||||
const std::string& element::Type::get_type_name() const
|
const std::string& element::Type::get_type_name() const
|
||||||
{
|
{
|
||||||
return get_type_info_map().at(m_type).m_type_name;
|
return get_type_info(m_type).m_type_name;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph
|
||||||
@ -247,12 +261,12 @@ bool element::Type::merge(element::Type& dst, const element::Type& t1, const ele
|
|||||||
|
|
||||||
bool element::Type::is_static() const
|
bool element::Type::is_static() const
|
||||||
{
|
{
|
||||||
return get_type_info_map().at(m_type).m_bitwidth != 0;
|
return get_type_info(m_type).m_bitwidth != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool element::Type::is_real() const
|
bool element::Type::is_real() const
|
||||||
{
|
{
|
||||||
return get_type_info_map().at(m_type).m_is_real;
|
return get_type_info(m_type).m_is_real;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool element::Type::is_integral_number() const
|
bool element::Type::is_integral_number() const
|
||||||
@ -262,17 +276,17 @@ bool element::Type::is_integral_number() const
|
|||||||
|
|
||||||
bool element::Type::is_signed() const
|
bool element::Type::is_signed() const
|
||||||
{
|
{
|
||||||
return get_type_info_map().at(m_type).m_is_signed;
|
return get_type_info(m_type).m_is_signed;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool element::Type::is_quantized() const
|
bool element::Type::is_quantized() const
|
||||||
{
|
{
|
||||||
return get_type_info_map().at(m_type).m_is_quantized;
|
return get_type_info(m_type).m_is_quantized;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t element::Type::bitwidth() const
|
size_t element::Type::bitwidth() const
|
||||||
{
|
{
|
||||||
return get_type_info_map().at(m_type).m_bitwidth;
|
return get_type_info(m_type).m_bitwidth;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t ngraph::compiler_byte_size(element::Type_t et)
|
size_t ngraph::compiler_byte_size(element::Type_t et)
|
||||||
|
Loading…
Reference in New Issue
Block a user