Files
openvino/inference-engine/include/ie_precision.hpp

421 lines
13 KiB
C++
Raw Normal View History

2020-02-11 22:48:49 +03:00
// Copyright (C) 2018-2020 Intel Corporation
2018-10-16 13:45:03 +03:00
// SPDX-License-Identifier: Apache-2.0
//
/**
* @brief A header file that provides class for describing precision of data
2020-02-11 22:48:49 +03:00
*
2018-10-16 13:45:03 +03:00
* @file ie_precision.hpp
*/
#pragma once
#include <cstdint>
#include <ostream>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <unordered_map>
#include <vector>

#include "details/ie_exception.hpp"
namespace InferenceEngine {
/**
 * @brief This class holds precision value and provides precision related operations
 */
class Precision {
public:
    /** @brief Enumeration of the built-in precision values */
    enum ePrecision : uint8_t {
        UNSPECIFIED = 255, /**< Unspecified value. Used by default */
        MIXED = 0,         /**< Mixed value. Can be received from network. Not applicable for tensors */
        FP32 = 10,         /**< 32bit floating point value */
        FP16 = 11,         /**< 16bit floating point value, 5 bit for exponent, 10 bit for mantissa */
        BF16 = 12,         /**< 16bit floating point value, 8 bit for exponent, 7 bit for mantissa */
        Q78 = 20,          /**< 16bit specific signed fixed point precision */
        I16 = 30,          /**< 16bit signed integer value */
        U8 = 40,           /**< 8bit unsigned integer value */
        I8 = 50,           /**< 8bit signed integer value */
        U16 = 60,          /**< 16bit unsigned integer value */
        I32 = 70,          /**< 32bit signed integer value */
        I64 = 72,          /**< 64bit signed integer value */
        U64 = 73,          /**< 64bit unsigned integer value */
        BIN = 71,          /**< 1bit integer value */
        BOOL = 41,         /**< 8bit bool type */
        CUSTOM = 80        /**< custom precision has its own name and size of elements */
    };

private:
    struct PrecisionInfo {
        /** @brief Size of underlying element in bits */
        size_t bitsSize = 0;

        /** @brief Null terminated string with precision name (pointer is not owned) */
        const char* name = "UNSPECIFIED";

        /** @brief Whether the precision is a floating point one */
        bool isFloat = false;

        /** @brief Enumerated precision value */
        ePrecision value = Precision::UNSPECIFIED;
    };
    PrecisionInfo precisionInfo;

public:
    /** @brief Default constructor: produces an UNSPECIFIED precision */
    Precision() = default;

    /** @brief Constructor with specified precision */
    Precision(const Precision::ePrecision value) {  // NOLINT
        precisionInfo = getPrecisionInfo(value);
    }

    /**
     * @brief Custom precision constructor
     *
     * @param bitsSize size of elements in bits; must be non-zero
     * @param name optional name string, used in serialisation. Only the pointer is stored, so the
     *        string must outlive this object. Defaults to "CUSTOM" when null.
     * @throws (via THROW_IE_EXCEPTION) when @p bitsSize is 0
     */
    explicit Precision(size_t bitsSize, const char* name = nullptr) {
        if (bitsSize == 0) {
            THROW_IE_EXCEPTION << "Precision with 0 elements size not supported";
        }
        precisionInfo.bitsSize = bitsSize;
        if (name == nullptr) {
            precisionInfo.name = "CUSTOM";
        } else {
            precisionInfo.name = name;
        }
        precisionInfo.value = CUSTOM;
    }

    /**
     * @brief Creates custom precision with specific underlying type
     * @tparam T element type; the element size is 8 * sizeof(T) bits
     * @param typeName optional precision name; defaults to typeid(T).name()
     */
    template <class T>
    static Precision fromType(const char* typeName = nullptr) {
        return Precision(8 * sizeof(T), typeName == nullptr ? typeid(T).name() : typeName);
    }

    /**
     * @brief checks whether given storage class T can be used to store objects of current precision
     * @tparam T candidate storage type
     * @param typeName for non built-in precisions, name to compare with; defaults to typeid(T).name()
     * @return true if T matches the precision, false otherwise (including when the size is unknown)
     */
    template <class T>
    bool hasStorageType(const char* typeName = nullptr) const noexcept {
        try {
            // BIN elements are 1 bit wide (size() reports 0), so the byte-size check is skipped for it.
            if (precisionInfo.value != BIN) {
                if (sizeof(T) != size()) {
                    return false;
                }
            }
#define CASE(x, y) \
    case x:        \
        return std::is_same<T, y>()
#define CASE2(x, y1, y2) \
    case x:              \
        return std::is_same<T, y1>() || std::is_same<T, y2>()

            switch (precisionInfo.value) {
                CASE(FP32, float);
                CASE2(FP16, int16_t, uint16_t);
                CASE2(BF16, int16_t, uint16_t);
                CASE(I16, int16_t);
                CASE(I32, int32_t);
                CASE(I64, int64_t);
                CASE(U64, uint64_t);
                CASE(U16, uint16_t);
                CASE(U8, uint8_t);
                CASE(I8, int8_t);
                CASE(BOOL, uint8_t);
                CASE2(Q78, int16_t, uint16_t);
                CASE2(BIN, int8_t, uint8_t);
            default:
                return areSameStrings(name(), typeName == nullptr ? typeid(T).name() : typeName);
#undef CASE
#undef CASE2
            }
        } catch (...) {
            // size() throws for precisions without an element size (bitsSize == 0).
            return false;
        }
    }

    /** @brief Equality operator with Precision object: compares value, element size and name */
    bool operator==(const Precision& p) const noexcept {
        return precisionInfo.value == p.precisionInfo.value && precisionInfo.bitsSize == p.precisionInfo.bitsSize &&
               areSameStrings(precisionInfo.name, p.precisionInfo.name);
    }

    /**
     * @brief Inequality operator with Precision object
     *
     * Exact negation of operator==(const Precision&). Without it, `a != b` would resolve to the
     * ePrecision overload and ignore element size and name of CUSTOM precisions.
     */
    bool operator!=(const Precision& p) const noexcept {
        return !(*this == p);
    }

    /** @brief Equality operator with ePrecision enum value */
    bool operator==(const ePrecision p) const noexcept {
        return precisionInfo.value == p;
    }

    /** @brief Inequality operator with ePrecision enum value */
    bool operator!=(const ePrecision p) const noexcept {
        return precisionInfo.value != p;
    }

    /** @brief Assignment operator with ePrecision enum value */
    Precision& operator=(const ePrecision p) noexcept {
        precisionInfo = getPrecisionInfo(p);
        return *this;
    }

    /** @brief Cast operator to a bool: true unless the precision is UNSPECIFIED */
    explicit operator bool() const noexcept {
        return precisionInfo.value != UNSPECIFIED;
    }

    /** @brief Logical negation operator: true if the precision is UNSPECIFIED */
    bool operator!() const noexcept {
        return precisionInfo.value == UNSPECIFIED;
    }

    /** @brief Cast operator to a ePrecision */
    operator Precision::ePrecision() const noexcept {
        return precisionInfo.value;
    }

    /**
     * @brief Gets the precision value of type ePrecision.
     * @return The precision value.
     */
    constexpr uint8_t getPrecVal() const noexcept {
        return precisionInfo.value;
    }

    /** @brief Getter of precision name */
    const char* name() const noexcept {
        return precisionInfo.name;
    }

    /**
     * @brief Creates from string with precision name
     * @param str enumerator name such as "FP32"; CUSTOM and UNSPECIFIED are not recognized
     * @return the matching precision, or a default (UNSPECIFIED) one if @p str is unknown
     */
    static Precision FromStr(const std::string& str) {
        static const std::unordered_map<std::string, ePrecision> names = {
#define PRECISION_NAME(s) {#s, s}
            PRECISION_NAME(Q78),  PRECISION_NAME(U8),   PRECISION_NAME(I8),    PRECISION_NAME(I16),
            PRECISION_NAME(I32),  PRECISION_NAME(I64),  PRECISION_NAME(U64),   PRECISION_NAME(U16),
            PRECISION_NAME(FP32), PRECISION_NAME(FP16), PRECISION_NAME(MIXED), PRECISION_NAME(BIN),
            PRECISION_NAME(BOOL), PRECISION_NAME(BF16),
#undef PRECISION_NAME
        };
        auto i = names.find(str);
        return i == names.end() ? Precision() : Precision(i->second);
    }

    /**
     * @brief Returns size of single element of that precision in bytes
     *
     * Computed as the element bit size divided by 8 (truncating), so sub-byte precisions such as
     * BIN report 0.
     * @returns Number of bytes per element
     * @throws (via THROW_IE_EXCEPTION) if the precision has no element size (e.g. UNSPECIFIED, MIXED)
     */
    size_t size() const {
        if (precisionInfo.bitsSize == 0) {
            THROW_IE_EXCEPTION << " cannot estimate element if precision is " << precisionInfo.name;
        }
        return precisionInfo.bitsSize >> 3;
    }

    /**
     * @brief Checks if it is a floating point value
     * @return True if precision is float point, `false` otherwise
     */
    bool is_float() const noexcept {
        return precisionInfo.isFloat;
    }

    /**
     * @brief Checks if it is a signed value
     * @return True if precision is signed, `false` otherwise
     */
    bool isSigned() const noexcept {
        return (precisionInfo.value == Precision::UNSPECIFIED) || (precisionInfo.value == Precision::MIXED) ||
               (precisionInfo.value == Precision::FP32) || (precisionInfo.value == Precision::FP16) ||
               (precisionInfo.value == Precision::BF16) || (precisionInfo.value == Precision::Q78) ||
               (precisionInfo.value == Precision::I16) || (precisionInfo.value == Precision::I8) ||
               (precisionInfo.value == Precision::I32) || (precisionInfo.value == Precision::I64) ||
               (precisionInfo.value == Precision::BIN) || (precisionInfo.value == Precision::CUSTOM);
    }

protected:
    /**
     * @brief Creates PrecisionInfo by @p precision with a specified name
     * @tparam precision A precision to create PrecisionInfo for
     * @param name Name of precision
     * @return A PrecisionInfo object
     */
    template <Precision::ePrecision precision>
    static PrecisionInfo makePrecisionInfo(const char* name);

    /**
     * @brief Compare two c-strings
     *
     * @param l Const pointer to first string
     * @param r Const pointer to another string
     * @returns True if strings are the same (two null pointers compare equal, one null does not)
     */
    static bool areSameStrings(const char* l, const char* r) noexcept {
        if (l == r) return true;

        if (l == nullptr || r == nullptr) return false;

        for (; *l && *r; l++, r++) {
            if (*l != *r) return false;
        }
        return *l == *r;
    }

    /**
     * @brief Return PrecisionInfo for a built-in precision value
     *
     * CUSTOM and unknown values map to the UNSPECIFIED info.
     */
    static PrecisionInfo getPrecisionInfo(ePrecision v) {
#define CASE(x) \
    case x:     \
        return makePrecisionInfo<x>(#x);
        switch (v) {
            CASE(FP32);
            CASE(FP16);
            CASE(BF16);
            CASE(I16);
            CASE(I32);
            CASE(I64);
            CASE(U64);
            CASE(U16);
            CASE(U8);
            CASE(I8);
            CASE(Q78);
            CASE(MIXED);
            CASE(BIN);
            CASE(BOOL);
        default:
            return makePrecisionInfo<UNSPECIFIED>("UNSPECIFIED");
#undef CASE
        }
    }
};
/**
* @brief Particular precision traits
*/
2020-02-11 22:48:49 +03:00
template <Precision::ePrecision p>
struct PrecisionTrait {};
2018-10-16 13:45:03 +03:00
/** @cond INTERNAL */
2020-02-11 22:48:49 +03:00
template <>
2018-10-16 13:45:03 +03:00
struct PrecisionTrait<Precision::FP32> {
using value_type = float;
};
2020-02-11 22:48:49 +03:00
template <>
2018-10-16 13:45:03 +03:00
struct PrecisionTrait<Precision::FP16> {
using value_type = int16_t;
2018-10-16 13:45:03 +03:00
};
2020-02-11 22:48:49 +03:00
template <>
struct PrecisionTrait<Precision::BF16> {
using value_type = int16_t;
};
template<>
2018-10-16 13:45:03 +03:00
struct PrecisionTrait<Precision::Q78> {
using value_type = uint16_t;
};
2020-02-11 22:48:49 +03:00
template <>
2018-10-16 13:45:03 +03:00
struct PrecisionTrait<Precision::I16> {
using value_type = int16_t;
};
2020-02-11 22:48:49 +03:00
template <>
2018-10-16 13:45:03 +03:00
struct PrecisionTrait<Precision::U16> {
using value_type = uint16_t;
};
2020-02-11 22:48:49 +03:00
template <>
2018-10-16 13:45:03 +03:00
struct PrecisionTrait<Precision::U8> {
using value_type = uint8_t;
};
2020-02-11 22:48:49 +03:00
template <>
2018-10-16 13:45:03 +03:00
struct PrecisionTrait<Precision::I8> {
using value_type = int8_t;
};
2020-02-11 22:48:49 +03:00
template <>
struct PrecisionTrait<Precision::BOOL> {
using value_type = uint8_t;
};
template <>
2018-10-16 13:45:03 +03:00
struct PrecisionTrait<Precision::I32> {
using value_type = int32_t;
};
2020-02-11 22:48:49 +03:00
template <>
2019-08-09 19:02:42 +03:00
struct PrecisionTrait<Precision::I64> {
using value_type = int64_t;
};
2020-02-11 22:48:49 +03:00
template <>
2020-04-13 21:17:23 +03:00
struct PrecisionTrait<Precision::U64> {
using value_type = uint64_t;
};
template <>
2019-04-12 18:25:53 +03:00
struct PrecisionTrait<Precision::BIN> {
using value_type = int8_t;
};
2018-10-16 13:45:03 +03:00
2020-02-11 22:48:49 +03:00
/**
 * @brief Size of @p T in bytes, narrowed to uint8_t.
 * @tparam T element storage type; a `void` specialization yields 0
 */
template <class T>
inline uint8_t type_size_or_zero() {
    return static_cast<uint8_t>(sizeof(T));
}
2020-02-11 22:48:49 +03:00
template <>
2018-10-16 13:45:03 +03:00
struct PrecisionTrait<Precision::UNSPECIFIED> {
using value_type = void;
};
2020-02-11 22:48:49 +03:00
template <>
struct PrecisionTrait<Precision::MIXED> : PrecisionTrait<Precision::UNSPECIFIED> {};
2018-10-16 13:45:03 +03:00
2020-02-11 22:48:49 +03:00
template <>
2018-10-16 13:45:03 +03:00
inline uint8_t type_size_or_zero<void>() {
return 0;
}
2020-02-11 22:48:49 +03:00
template <Precision::ePrecision T>
inline typename std::enable_if<std::is_same<std::integral_constant<Precision::ePrecision, Precision::FP16>,
std::integral_constant<Precision::ePrecision, T>>::value,
bool>::type
is_floating() {
2018-10-16 13:45:03 +03:00
return true;
}
2020-02-11 22:48:49 +03:00
template <Precision::ePrecision T>
inline typename std::enable_if<!std::is_same<std::integral_constant<Precision::ePrecision, Precision::FP16>,
std::integral_constant<Precision::ePrecision, T>>::value,
bool>::type
is_floating() {
2018-10-16 13:45:03 +03:00
return std::is_floating_point<typename PrecisionTrait<T>::value_type>::value;
}
2020-02-11 22:48:49 +03:00
template <Precision::ePrecision precision>
inline Precision::PrecisionInfo Precision::makePrecisionInfo(const char* name) {
    // BIN elements occupy a single bit each; every other precision counts 8 bits per storage byte.
    const size_t bitsPerByte = (precision == BIN) ? 1 : 8;
    const size_t storageBytes = type_size_or_zero<typename PrecisionTrait<precision>::value_type>();

    Precision::PrecisionInfo result;
    result.value = precision;
    result.name = name;
    result.isFloat = is_floating<precision>();
    result.bitsSize = bitsPerByte * storageBytes;
    return result;
}
2020-02-11 22:48:49 +03:00
/** @brief Writes the name of precision @p p (e.g. "FP32") to @p out. */
inline std::ostream& operator<<(std::ostream& out, const InferenceEngine::Precision& p) {
    const char* const precisionName = p.name();
    out << precisionName;
    return out;
}
2020-02-11 22:48:49 +03:00
/** @brief Writes the name of enumerated precision @p p by wrapping it in a Precision object. */
inline std::ostream& operator<<(std::ostream& out, const InferenceEngine::Precision::ePrecision& p) {
    const Precision wrapped(p);
    return out << wrapped.name();
}
2020-02-11 22:48:49 +03:00
inline std::ostream& operator<<(std::ostream& os, const std::vector<Precision>& values) {
os << "{ ";
for (size_t i = 0; i < values.size(); ++i) {
os << values[i];
if (i != (values.size() - 1ul)) {
os << ", ";
}
}
os << " }";
return os;
}
/**
 * @brief Packs up to four precisions into a single 32-bit mask, one byte per precision.
 *
 * @p precision1 occupies the lowest byte and @p precision4 the highest; unused slots default to MIXED (0).
 * Each value is widened to uint32_t before shifting. Otherwise it would be promoted to int, and a
 * value >= 128 (e.g. UNSPECIFIED == 255) shifted into the top byte would overflow a signed int.
 */
inline constexpr uint32_t getPrecisionMask(
    InferenceEngine::Precision::ePrecision precision1, InferenceEngine::Precision::ePrecision precision2,
    InferenceEngine::Precision::ePrecision precision3 = InferenceEngine::Precision::MIXED,
    InferenceEngine::Precision::ePrecision precision4 = InferenceEngine::Precision::MIXED) {
    return static_cast<uint32_t>(precision1) | (static_cast<uint32_t>(precision2) << 8) |
           (static_cast<uint32_t>(precision3) << 16) | (static_cast<uint32_t>(precision4) << 24);
}
2018-10-16 13:45:03 +03:00
/** @endcond */
} // namespace InferenceEngine