Int4 support for nGraph (#4712)
* Start int4 type * Changed int4 type * Fixed code style * Added u4 type * Added tests * Add U4 I4 precisions to the IE * Removed redundant tests * Added convert precision * Fixed tests * Fixed C tests * Fixed logic for signed values * Fixed nGraph tests * Added more tests * Fixed code style * Update inference-engine/src/transformations/src/transformations/convert_precision.cpp Co-authored-by: Evgenya Stepyreva <evgenya.stepyreva@intel.com> Co-authored-by: Evgenya Stepyreva <evgenya.stepyreva@intel.com>
This commit is contained in:
@@ -121,15 +121,19 @@ enum precision_e{
|
||||
FP16 = 11, /**< 16bit floating point value */
|
||||
|
||||
BF16 = 12, /**< 16bit floating point value, 8 bit for exponent, 7 bit for mantisa*/
|
||||
|
||||
|
||||
FP64 = 13, /**< 64bit floating point value */
|
||||
|
||||
Q78 = 20, /**< 16bit specific signed fixed point precision */
|
||||
|
||||
I16 = 30, /**< 16bit signed integer value */
|
||||
|
||||
U4 = 39, /**< 4bit unsigned integer value */
|
||||
|
||||
U8 = 40, /**< 8bit unsigned integer value */
|
||||
|
||||
I4 = 49, /**< 4bit signed integer value */
|
||||
|
||||
I8 = 50, /**< 8bit signed integer value */
|
||||
|
||||
U16 = 60, /**< 16bit unsigned integer value */
|
||||
|
||||
@@ -177,7 +177,9 @@ typedef enum {
|
||||
FP64 = 13, /**< 64bit floating point value */
|
||||
Q78 = 20, /**< 16bit specific signed fixed point precision */
|
||||
I16 = 30, /**< 16bit signed integer value */
|
||||
U4 = 39, /**< 4bit unsigned integer value */
|
||||
U8 = 40, /**< 8bit unsigned integer value */
|
||||
I4 = 49, /**< 4bit signed integer value */
|
||||
I8 = 50, /**< 8bit signed integer value */
|
||||
U16 = 60, /**< 16bit unsigned integer value */
|
||||
I32 = 70, /**< 32bit signed integer value */
|
||||
|
||||
@@ -82,7 +82,9 @@ std::map<IE::Precision, precision_e> precision_map = {{IE::Precision::UNSPECIFIE
|
||||
{IE::Precision::FP64, precision_e::FP64},
|
||||
{IE::Precision::Q78, precision_e::Q78},
|
||||
{IE::Precision::I16, precision_e::I16},
|
||||
{IE::Precision::U4, precision_e::U4},
|
||||
{IE::Precision::U8, precision_e::U8},
|
||||
{IE::Precision::I4, precision_e::I4},
|
||||
{IE::Precision::I8, precision_e::I8},
|
||||
{IE::Precision::U16, precision_e::U16},
|
||||
{IE::Precision::I32, precision_e::I32},
|
||||
@@ -1369,7 +1371,7 @@ IEStatusCode ie_blob_make_memory(const tensor_desc_t *tensorDesc, ie_blob_t **bl
|
||||
_blob->object = IE::make_shared_blob<uint8_t>(tensor);
|
||||
} else if (prec == IE::Precision::U16) {
|
||||
_blob->object = IE::make_shared_blob<uint16_t>(tensor);
|
||||
} else if (prec == IE::Precision::I8 || prec == IE::Precision::BIN) {
|
||||
} else if (prec == IE::Precision::I8 || prec == IE::Precision::BIN || prec == IE::Precision::I4 || prec == IE::Precision::U4) {
|
||||
_blob->object = IE::make_shared_blob<int8_t>(tensor);
|
||||
} else if (prec == IE::Precision::I16 || prec == IE::Precision::FP16 || prec == IE::Precision::Q78) {
|
||||
_blob->object = IE::make_shared_blob<int16_t>(tensor);
|
||||
@@ -1434,7 +1436,7 @@ IEStatusCode ie_blob_make_memory_from_preallocated(const tensor_desc_t *tensorDe
|
||||
} else if (prec == IE::Precision::U16) {
|
||||
uint16_t *p = reinterpret_cast<uint16_t *>(ptr);
|
||||
_blob->object = IE::make_shared_blob(tensor, p, size);
|
||||
} else if (prec == IE::Precision::I8 || prec == IE::Precision::BIN) {
|
||||
} else if (prec == IE::Precision::I8 || prec == IE::Precision::BIN || prec == IE::Precision::I4 || prec == IE::Precision::U4) {
|
||||
int8_t *p = reinterpret_cast<int8_t *>(ptr);
|
||||
_blob->object = IE::make_shared_blob(tensor, p, size);
|
||||
} else if (prec == IE::Precision::I16 || prec == IE::Precision::FP16 || prec == IE::Precision::Q78) {
|
||||
|
||||
@@ -19,7 +19,7 @@ import numpy as np
|
||||
from enum import Enum
|
||||
|
||||
supported_precisions = ["FP32", "FP64", "FP16", "I64", "U64", "I32", "U32",
|
||||
"I16", "I8", "U16", "U8", "BOOL", "BIN", "BF16"]
|
||||
"I16", "I4", "I8", "U16", "U4", "U8", "BOOL", "BIN", "BF16"]
|
||||
|
||||
known_plugins = ['CPU', 'GPU', 'FPGA', 'MYRIAD', 'HETERO', 'HDDL', 'MULTI']
|
||||
|
||||
@@ -36,7 +36,9 @@ format_map = {
|
||||
'U32' : np.uint32,
|
||||
'I16' : np.int16,
|
||||
'U16' : np.uint16,
|
||||
'I4' : np.int8,
|
||||
'I8' : np.int8,
|
||||
'U4' : np.int8,
|
||||
'U8' : np.uint8,
|
||||
'BOOL' : np.uint8,
|
||||
'BIN' : np.int8,
|
||||
|
||||
@@ -161,7 +161,7 @@ cdef class Blob:
|
||||
self._ptr = C.make_shared_blob[uint16_t](c_tensor_desc)
|
||||
elif precision == "U8" or precision == "BOOL":
|
||||
self._ptr = C.make_shared_blob[uint8_t](c_tensor_desc)
|
||||
elif precision == "I8" or precision == "BIN":
|
||||
elif precision == "I8" or precision == "BIN" or precision == "I4" or precision == "U4":
|
||||
self._ptr = C.make_shared_blob[int8_t](c_tensor_desc)
|
||||
elif precision == "I32":
|
||||
self._ptr = C.make_shared_blob[int32_t](c_tensor_desc)
|
||||
@@ -205,7 +205,7 @@ cdef class Blob:
|
||||
elif precision == "U8" or precision == "BOOL":
|
||||
U8_array_memview = self._array_data
|
||||
self._ptr = C.make_shared_blob[uint8_t](c_tensor_desc, &U8_array_memview[0], U8_array_memview.shape[0])
|
||||
elif precision == "I8" or precision == "BIN":
|
||||
elif precision == "I8" or precision == "BIN" or precision == "I4" or precision == "U4":
|
||||
I8_array_memview = self._array_data
|
||||
self._ptr = C.make_shared_blob[int8_t](c_tensor_desc, &I8_array_memview[0], I8_array_memview.shape[0])
|
||||
elif precision == "I32":
|
||||
|
||||
@@ -33,7 +33,9 @@ public:
|
||||
FP64 = 13, /**< 64bit floating point value */
|
||||
Q78 = 20, /**< 16bit specific signed fixed point precision */
|
||||
I16 = 30, /**< 16bit signed integer value */
|
||||
U4 = 39, /**< 4bit unsigned integer value */
|
||||
U8 = 40, /**< 8bit unsigned integer value */
|
||||
I4 = 49, /**< 4bit signed integer value */
|
||||
I8 = 50, /**< 8bit signed integer value */
|
||||
U16 = 60, /**< 16bit unsigned integer value */
|
||||
I32 = 70, /**< 32bit signed integer value */
|
||||
@@ -116,10 +118,12 @@ public:
|
||||
CASE(FP64, double);
|
||||
CASE2(FP16, int16_t, uint16_t);
|
||||
CASE2(BF16, int16_t, uint16_t);
|
||||
CASE2(I4, int8_t, uint8_t);
|
||||
CASE(I8, int8_t);
|
||||
CASE(I16, int16_t);
|
||||
CASE(I32, int32_t);
|
||||
CASE(I64, int64_t);
|
||||
CASE(U4, uint8_t);
|
||||
CASE(U8, uint8_t);
|
||||
CASE(U16, uint16_t);
|
||||
CASE(U32, uint32_t);
|
||||
@@ -224,8 +228,8 @@ public:
|
||||
static const std::unordered_map<std::string, ePrecision> names = {
|
||||
#define PRECISION_NAME(s) {#s, s}
|
||||
PRECISION_NAME(Q78), PRECISION_NAME(BOOL), PRECISION_NAME(BF16),
|
||||
PRECISION_NAME(I8), PRECISION_NAME(I16), PRECISION_NAME(I32), PRECISION_NAME(I64),
|
||||
PRECISION_NAME(U8), PRECISION_NAME(U16), PRECISION_NAME(U32), PRECISION_NAME(U64),
|
||||
PRECISION_NAME(I4), PRECISION_NAME(I8), PRECISION_NAME(I16), PRECISION_NAME(I32), PRECISION_NAME(I64),
|
||||
PRECISION_NAME(U4), PRECISION_NAME(U8), PRECISION_NAME(U16), PRECISION_NAME(U32), PRECISION_NAME(U64),
|
||||
PRECISION_NAME(FP32), PRECISION_NAME(FP64), PRECISION_NAME(FP16), PRECISION_NAME(MIXED),
|
||||
PRECISION_NAME(BIN),
|
||||
#undef PRECISION_NAME
|
||||
@@ -264,7 +268,7 @@ public:
|
||||
(precisionInfo.value == Precision::I16) || (precisionInfo.value == Precision::I8) ||
|
||||
(precisionInfo.value == Precision::I32) || (precisionInfo.value == Precision::I64) ||
|
||||
(precisionInfo.value == Precision::BIN) || (precisionInfo.value == Precision::BF16) ||
|
||||
(precisionInfo.value == Precision::CUSTOM);
|
||||
(precisionInfo.value == Precision::CUSTOM) || (precisionInfo.value == Precision::I4);
|
||||
}
|
||||
|
||||
protected:
|
||||
@@ -309,10 +313,12 @@ protected:
|
||||
CASE(FP64);
|
||||
CASE(FP16);
|
||||
CASE(BF16);
|
||||
CASE(I4);
|
||||
CASE(I8);
|
||||
CASE(I16);
|
||||
CASE(I32);
|
||||
CASE(I64);
|
||||
CASE(U4);
|
||||
CASE(U8);
|
||||
CASE(U16);
|
||||
CASE(U32);
|
||||
@@ -366,10 +372,18 @@ struct PrecisionTrait<Precision::U16> {
|
||||
using value_type = uint16_t;
|
||||
};
|
||||
template <>
|
||||
struct PrecisionTrait<Precision::U4> {
|
||||
using value_type = uint8_t;
|
||||
};
|
||||
template <>
|
||||
struct PrecisionTrait<Precision::U8> {
|
||||
using value_type = uint8_t;
|
||||
};
|
||||
template <>
|
||||
struct PrecisionTrait<Precision::I4> {
|
||||
using value_type = int8_t;
|
||||
};
|
||||
template <>
|
||||
struct PrecisionTrait<Precision::I8> {
|
||||
using value_type = int8_t;
|
||||
};
|
||||
|
||||
@@ -110,10 +110,12 @@ InferenceEngine::Blob::Ptr make_blob_with_precision(InferenceEngine::Precision p
|
||||
USE_FACTORY(FP64);
|
||||
USE_FACTORY(FP16);
|
||||
USE_FACTORY(Q78);
|
||||
USE_FACTORY(I4);
|
||||
USE_FACTORY(I8);
|
||||
USE_FACTORY(I16);
|
||||
USE_FACTORY(I32);
|
||||
USE_FACTORY(I64);
|
||||
USE_FACTORY(U4);
|
||||
USE_FACTORY(U8);
|
||||
USE_FACTORY(U16);
|
||||
USE_FACTORY(U32);
|
||||
|
||||
@@ -26,8 +26,12 @@ inline ::ngraph::element::Type convertPrecision(const Precision& precision) {
|
||||
return ::ngraph::element::Type(::ngraph::element::Type_t::f16);
|
||||
case Precision::BF16:
|
||||
return ::ngraph::element::Type(::ngraph::element::Type_t::bf16);
|
||||
case Precision::U4:
|
||||
return ::ngraph::element::Type(::ngraph::element::Type_t::u4);
|
||||
case Precision::U8:
|
||||
return ::ngraph::element::Type(::ngraph::element::Type_t::u8);
|
||||
case Precision::I4:
|
||||
return ::ngraph::element::Type(::ngraph::element::Type_t::i4);
|
||||
case Precision::I8:
|
||||
return ::ngraph::element::Type(::ngraph::element::Type_t::i8);
|
||||
case Precision::U16:
|
||||
@@ -63,6 +67,8 @@ inline ::ngraph::element::Type convertPrecision(const std::string& precision) {
|
||||
return ::ngraph::element::Type(::ngraph::element::Type_t::bf16);
|
||||
} else if (precision == "f64" || precision == "FP64") {
|
||||
return ::ngraph::element::Type(::ngraph::element::Type_t::f64);
|
||||
} else if (precision == "i4" || precision == "I4") {
|
||||
return ::ngraph::element::Type(::ngraph::element::Type_t::i4);
|
||||
} else if (precision == "i8" || precision == "I8") {
|
||||
return ::ngraph::element::Type(::ngraph::element::Type_t::i8);
|
||||
} else if (precision == "i16" || precision == "I16") {
|
||||
@@ -73,6 +79,8 @@ inline ::ngraph::element::Type convertPrecision(const std::string& precision) {
|
||||
return ::ngraph::element::Type(::ngraph::element::Type_t::i64);
|
||||
} else if (precision == "u1" || precision == "U1" || precision == "BIN" || precision == "bin") {
|
||||
return ::ngraph::element::Type(::ngraph::element::Type_t::u1);
|
||||
} else if (precision == "u4" || precision == "U4") {
|
||||
return ::ngraph::element::Type(::ngraph::element::Type_t::u4);
|
||||
} else if (precision == "u8" || precision == "U8") {
|
||||
return ::ngraph::element::Type(::ngraph::element::Type_t::u8);
|
||||
} else if (precision == "u16" || precision == "U16") {
|
||||
@@ -102,6 +110,8 @@ inline Precision convertPrecision(const ::ngraph::element::Type& precision) {
|
||||
return Precision(Precision::FP64);
|
||||
case ::ngraph::element::Type_t::bf16:
|
||||
return Precision(Precision::BF16);
|
||||
case ::ngraph::element::Type_t::i4:
|
||||
return Precision(Precision::I4);
|
||||
case ::ngraph::element::Type_t::i8:
|
||||
return Precision(Precision::I8);
|
||||
case ::ngraph::element::Type_t::i16:
|
||||
@@ -110,6 +120,8 @@ inline Precision convertPrecision(const ::ngraph::element::Type& precision) {
|
||||
return Precision(Precision::I32);
|
||||
case ::ngraph::element::Type_t::i64:
|
||||
return Precision(Precision::I64);
|
||||
case ::ngraph::element::Type_t::u4:
|
||||
return Precision(Precision::U4);
|
||||
case ::ngraph::element::Type_t::u8:
|
||||
return Precision(Precision::U8);
|
||||
case ::ngraph::element::Type_t::u16:
|
||||
|
||||
@@ -355,42 +355,194 @@ inline int32_t convert_value<uint32_t, int32_t>(uint32_t val) {
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <element::Type_t PREC_FROM, element::Type_t PREC_TO>
|
||||
std::shared_ptr<ngraph::Node> change_constant_precision(std::shared_ptr<opset4::Constant>& constant) {
|
||||
using src_type = typename element_type_traits<PREC_FROM>::value_type;
|
||||
using dst_type = typename element_type_traits<PREC_TO>::value_type;
|
||||
template <element::Type_t PREC_FROM, element::Type_t PREC_TO>
|
||||
std::shared_ptr<ngraph::Node> change_constant_precision(std::shared_ptr<opset4::Constant>& constant) {
|
||||
using src_type = typename element_type_traits<PREC_FROM>::value_type;
|
||||
using dst_type = typename element_type_traits<PREC_TO>::value_type;
|
||||
|
||||
const auto * src_data = constant->get_data_ptr<src_type>();
|
||||
const auto size = shape_size(constant->get_shape());
|
||||
const auto * src_data = constant->get_data_ptr<src_type>();
|
||||
const auto size = shape_size(constant->get_shape());
|
||||
|
||||
auto new_constant = std::make_shared<ngraph::opset4::Constant>(PREC_TO, constant->get_shape());
|
||||
auto * dst_data = const_cast<dst_type *>(reinterpret_cast<const dst_type *>(new_constant->get_data_ptr()));
|
||||
if (dst_data == nullptr)
|
||||
throw ngraph_error("Can't get destination data pointer");
|
||||
auto new_constant = std::make_shared<ngraph::opset4::Constant>(PREC_TO, constant->get_shape());
|
||||
auto * dst_data = const_cast<dst_type *>(reinterpret_cast<const dst_type *>(new_constant->get_data_ptr()));
|
||||
if (dst_data == nullptr)
|
||||
throw ngraph_error("Can't get destination data pointer");
|
||||
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
dst_data[i] = convert_value<src_type, dst_type>(src_data[i]);
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
dst_data[i] = convert_value<src_type, dst_type>(src_data[i]);
|
||||
}
|
||||
return new_constant;
|
||||
}
|
||||
|
||||
template <>
|
||||
std::shared_ptr<Node> change_constant_precision<element::Type_t::f16, element::Type_t::f32>(std::shared_ptr<opset4::Constant>& constant) {
|
||||
using src_type = typename element_type_traits<element::Type_t::f16>::value_type;
|
||||
using dst_type = typename element_type_traits<element::Type_t::f32>::value_type;
|
||||
|
||||
const auto * src_data = constant->get_data_ptr<src_type>();
|
||||
const auto size = shape_size(constant->get_shape());
|
||||
|
||||
auto new_constant = std::make_shared<ngraph::opset4::Constant>(element::Type_t::f32, constant->get_shape());
|
||||
auto * dst_data = const_cast<dst_type *>(reinterpret_cast<const dst_type *>(new_constant->get_data_ptr()));
|
||||
if (dst_data == nullptr)
|
||||
throw ngraph_error("Can't get destination data pointer");
|
||||
|
||||
ngraph::runtime::reference::convert<src_type, dst_type>(src_data, dst_data, size);
|
||||
|
||||
return new_constant;
|
||||
}
|
||||
|
||||
struct EnumClassHash {
|
||||
template<class T>
|
||||
std::size_t operator()(T t) const {
|
||||
return static_cast<size_t>(t);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Method converts low precision integer types
|
||||
*
|
||||
* @param src source value !!! the type must be unsigned !!!
|
||||
* @param dst destination value !!! the type must be unsigned !!!
|
||||
* @param src_offset source offset (for custom data types)
|
||||
* @param src_size source size (for custom data types)
|
||||
* @param dst_offset destination offset
|
||||
* @param dst_size destination size
|
||||
* @param is_signed the type of source data
|
||||
*/
|
||||
template <class SRC, class DST>
|
||||
void convert_lp_value(const SRC& src, DST& dst, size_t src_offset, size_t src_size, size_t dst_offset, size_t dst_size, bool is_signed) {
|
||||
constexpr SRC src_max = std::numeric_limits<SRC>::max();
|
||||
constexpr DST dst_max = std::numeric_limits<DST>::max();
|
||||
// Make a shift for the source value
|
||||
// src [11101000] offset 2, size 4
|
||||
// val [00011101]
|
||||
SRC val = src >> src_offset;
|
||||
// dst [10001111 00000100] offset 5 size 9
|
||||
// new_val [00000000 00000000]
|
||||
DST new_val = 0;
|
||||
// If source type is signed
|
||||
if (is_signed) {
|
||||
// Get the sign of value
|
||||
// sign [00000000]
|
||||
// invert value in order to use XOR
|
||||
SRC sign = (~(val >> (src_size - 1))) & 0b1;
|
||||
// Calculate diff in order to clean bits which don't exist in the source value
|
||||
// diff 5
|
||||
size_t diff = sizeof(SRC)*8 - src_size + 1;
|
||||
// Clean unnecessary bits
|
||||
// val [10100000]
|
||||
val = val << diff;
|
||||
// val [00000101]
|
||||
val = (val >> diff);
|
||||
|
||||
// Negative number
|
||||
if (!sign) {
|
||||
// val [11110101]
|
||||
val |= (src_max << (diff - 1));
|
||||
// new_val [00000001 11111111]
|
||||
new_val = (sign << (dst_size - 1)) ^ (dst_max >> (sizeof(DST) * 8 - dst_size));
|
||||
// new_val [00000001 11110101]
|
||||
new_val &= (dst_max << sizeof(SRC)*8) | val;
|
||||
} else {
|
||||
// new_val [00000000 00000101]
|
||||
new_val = val;
|
||||
}
|
||||
return new_constant;
|
||||
} else {
|
||||
// Calculate diff in order to clean bits which don't exist in the source value
|
||||
// diff 4
|
||||
size_t diff = sizeof(SRC)*8 - src_size;
|
||||
// Clean unnecessary bits
|
||||
// val [11010000]
|
||||
val = val << diff;
|
||||
// val [00001101]
|
||||
val = val >> diff;
|
||||
// new_val [00000000 00001101]
|
||||
new_val = val;
|
||||
}
|
||||
|
||||
template <>
|
||||
std::shared_ptr<Node> change_constant_precision<element::Type_t::f16, element::Type_t::f32>(std::shared_ptr<opset4::Constant>& constant) {
|
||||
using src_type = typename element_type_traits<element::Type_t::f16>::value_type;
|
||||
using dst_type = typename element_type_traits<element::Type_t::f32>::value_type;
|
||||
// Make a mask in order to save other values if DST contains several values
|
||||
// mask [11000000 00011111]
|
||||
DST mask = 0;
|
||||
if (dst_offset + dst_size < sizeof(DST) * 8)
|
||||
mask = (dst_max << (dst_offset + dst_size));
|
||||
if (dst_offset != 0)
|
||||
mask |= (dst_max >> (sizeof(DST) * 8 - dst_offset));
|
||||
|
||||
const auto * src_data = constant->get_data_ptr<src_type>();
|
||||
const auto size = shape_size(constant->get_shape());
|
||||
// Add mask to our converted value
|
||||
// signed: new_val [11100000 10111111]
|
||||
// unsigned: new_val [11000001 10111111]
|
||||
new_val = mask | (new_val << dst_offset);
|
||||
|
||||
auto new_constant = std::make_shared<ngraph::opset4::Constant>(element::Type_t::f32, constant->get_shape());
|
||||
auto * dst_data = const_cast<dst_type *>(reinterpret_cast<const dst_type *>(new_constant->get_data_ptr()));
|
||||
if (dst_data == nullptr)
|
||||
throw ngraph_error("Can't get destination data pointer");
|
||||
// Add our value to destination
|
||||
// dst: [10111111 11100100]
|
||||
dst |= ~mask;
|
||||
// signed: dst [10100000 10100100]
|
||||
// unsigned: dst [10000001 10100100]
|
||||
dst &= new_val;
|
||||
}
|
||||
|
||||
ngraph::runtime::reference::convert<src_type, dst_type>(src_data, dst_data, size);
|
||||
std::shared_ptr<Node> convert_low_precisions_int(std::shared_ptr<opset4::Constant>& constant, element::Type to) {
|
||||
// Supported integer precisions
|
||||
static const std::unordered_set<element::Type_t, EnumClassHash> supported_integer_precisions = {
|
||||
element::i4,
|
||||
element::u4,
|
||||
element::u1
|
||||
};
|
||||
// Get source element type and source data
|
||||
auto src_type = constant->get_element_type();
|
||||
const auto* src_data = reinterpret_cast<const uint8_t*>(constant->get_data_ptr());
|
||||
|
||||
return new_constant;
|
||||
// We support conversion only if several elements can be represented in one instance of some C++ common data type without any exception,
|
||||
// destination data type should be bigger than source and destination data type should be real
|
||||
if (!supported_integer_precisions.count(src_type) || (src_type.size() * 8) % src_type.bitwidth() ||
|
||||
(to.size() * 8) % to.bitwidth() || to.is_real() || to.bitwidth() < src_type.bitwidth())
|
||||
throw ngraph_error("Convert low precision for " + constant->get_element_type().get_type_name() + " to " +
|
||||
to.get_type_name() + " is not implemented!");
|
||||
|
||||
// Create a new constant operation and get destination data
|
||||
auto new_constant = std::make_shared<ngraph::opset4::Constant>(to, constant->get_shape());
|
||||
auto* dst_data = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(new_constant->get_data_ptr()));
|
||||
// Check pointers
|
||||
if (src_data == nullptr || dst_data == nullptr)
|
||||
throw ngraph_error("Can't get data pointer");
|
||||
|
||||
// Convert values
|
||||
const auto size = shape_size(constant->get_shape());
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
// Calculate indexes
|
||||
size_t dst_idx = i / ((to.size() * 8) / to.bitwidth());
|
||||
size_t src_idx = i / ((src_type.size() * 8) / src_type.bitwidth());
|
||||
// Calculate offsets inside the indexes
|
||||
size_t dst_off = (to.size() * 8 - to.bitwidth()) - to.bitwidth() * (i % ((to.size() * 8) / to.bitwidth()));
|
||||
size_t src_off = (src_type.size() * 8 - src_type.bitwidth()) - src_type.bitwidth() * (i % ((src_type.size() * 8) / src_type.bitwidth()));
|
||||
// Source type at the current moment always less than 1 byte
|
||||
// Select the right destination type
|
||||
switch (to.size()) {
|
||||
case 1:
|
||||
convert_lp_value<uint8_t, uint8_t>(src_data[src_idx], dst_data[dst_idx], src_off, src_type.bitwidth(),
|
||||
dst_off, to.bitwidth(), src_type.is_signed());
|
||||
break;
|
||||
case 2:
|
||||
convert_lp_value<uint8_t, uint16_t>(src_data[src_idx], reinterpret_cast<uint16_t*>(dst_data)[dst_idx], src_off, src_type.bitwidth(),
|
||||
dst_off, to.bitwidth(), src_type.is_signed());
|
||||
break;
|
||||
case 4:
|
||||
convert_lp_value<uint8_t, uint32_t>(src_data[src_idx], reinterpret_cast<uint32_t*>(dst_data)[dst_idx], src_off, src_type.bitwidth(),
|
||||
dst_off, to.bitwidth(), src_type.is_signed());
|
||||
break;
|
||||
case 8:
|
||||
convert_lp_value<uint8_t, uint64_t>(src_data[src_idx], reinterpret_cast<uint64_t*>(dst_data)[dst_idx], src_off, src_type.bitwidth(),
|
||||
dst_off, to.bitwidth(), src_type.is_signed());
|
||||
break;
|
||||
default:
|
||||
throw ngraph_error("Unsupported element size!");
|
||||
}
|
||||
}
|
||||
|
||||
return new_constant;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool fuse_type_to_constant(const std::shared_ptr<ngraph::Node> & node, element::Type to, const std::vector<Input<Node>> & consumers) {
|
||||
@@ -417,8 +569,10 @@ bool fuse_type_to_constant(const std::shared_ptr<ngraph::Node> & node, element::
|
||||
new_const = change_constant_precision<element::Type_t::boolean, element::Type_t::u8>(constant);
|
||||
} else if (from == element::boolean && to == element::i32) {
|
||||
new_const = change_constant_precision<element::Type_t::boolean, element::Type_t::i32>(constant);
|
||||
} else if (from == element::i4 || from == element::u4 || from == element::u1) {
|
||||
new_const = convert_low_precisions_int(constant, to);
|
||||
} else {
|
||||
throw ngraph_error("not supported");
|
||||
throw ngraph_error("Precision conversion from " + from.get_type_name() + " to " + to.get_type_name() + " is not supported");
|
||||
}
|
||||
for (auto & output : consumers) {
|
||||
output.replace_source_output(new_const);
|
||||
|
||||
@@ -62,9 +62,11 @@ InferenceEngine::Blob::Ptr createBlob(InferenceEngine::Precision precision, Size
|
||||
return make_shared_blob<uint64_t>(tensorDesc);
|
||||
case InferenceEngine::Precision::U16:
|
||||
return make_shared_blob<uint16_t>(tensorDesc);
|
||||
case InferenceEngine::Precision::I4:
|
||||
case InferenceEngine::Precision::I8:
|
||||
case InferenceEngine::Precision::BIN:
|
||||
return make_shared_blob<int8_t>(tensorDesc);
|
||||
case InferenceEngine::Precision::U4:
|
||||
case InferenceEngine::Precision::U8:
|
||||
return make_shared_blob<uint8_t>(tensorDesc);
|
||||
default:
|
||||
@@ -134,9 +136,11 @@ void FillBlob(Blob::Ptr& inputBlob) {
|
||||
return FillBlobRandom<uint64_t>(inputBlob);
|
||||
case InferenceEngine::Precision::U16:
|
||||
return FillBlobRandom<uint16_t>(inputBlob);
|
||||
case InferenceEngine::Precision::I4:
|
||||
case InferenceEngine::Precision::I8:
|
||||
case InferenceEngine::Precision::BIN:
|
||||
return FillBlobRandom<int8_t>(inputBlob);
|
||||
case InferenceEngine::Precision::U4:
|
||||
case InferenceEngine::Precision::U8:
|
||||
return FillBlobRandom<uint8_t>(inputBlob);
|
||||
default:
|
||||
@@ -224,9 +228,11 @@ bool IsCorrectBlobCopy(Blob::Ptr& srcBlob, Blob::Ptr& dstBlob) {
|
||||
return IsCorrectBlobCopy_Impl<uint64_t >(srcBlob, dstBlob);
|
||||
case InferenceEngine::Precision::U16:
|
||||
return IsCorrectBlobCopy_Impl<uint16_t>(srcBlob, dstBlob);
|
||||
case InferenceEngine::Precision::I4:
|
||||
case InferenceEngine::Precision::I8:
|
||||
case InferenceEngine::Precision::BIN:
|
||||
return IsCorrectBlobCopy_Impl<int8_t>(srcBlob, dstBlob);
|
||||
case InferenceEngine::Precision::U4:
|
||||
case InferenceEngine::Precision::U8:
|
||||
return IsCorrectBlobCopy_Impl<uint8_t>(srcBlob, dstBlob);
|
||||
default:
|
||||
@@ -351,9 +357,11 @@ bool IsEqualBlobCopy(Blob::Ptr& srcBlob, Blob::Ptr& dstBlob) {
|
||||
return IsEqualBlobCopy_Impl<uint64_t>(srcBlob, dstBlob);
|
||||
case InferenceEngine::Precision::I64:
|
||||
return IsEqualBlobCopy_Impl<int64_t>(srcBlob, dstBlob);
|
||||
case InferenceEngine::Precision::I4:
|
||||
case InferenceEngine::Precision::I8:
|
||||
case InferenceEngine::Precision::BIN:
|
||||
return IsEqualBlobCopy_Impl<int8_t>(srcBlob, dstBlob);
|
||||
case InferenceEngine::Precision::U4:
|
||||
case InferenceEngine::Precision::U8:
|
||||
return IsEqualBlobCopy_Impl<uint8_t>(srcBlob, dstBlob);
|
||||
case InferenceEngine::Precision::U16:
|
||||
@@ -404,9 +412,11 @@ void copy3DBlobsAllBytesWithReLayoutWrapper(const Blob::Ptr& srcLayoutBlob, Blob
|
||||
return copy3DBlobsAllBytesWithReLayout<int64_t>(srcLayoutBlob, trgLayoutBlob);
|
||||
case InferenceEngine::Precision::U16:
|
||||
return copy3DBlobsAllBytesWithReLayout<uint16_t>(srcLayoutBlob, trgLayoutBlob);
|
||||
case InferenceEngine::Precision::I4:
|
||||
case InferenceEngine::Precision::I8:
|
||||
case InferenceEngine::Precision::BIN:
|
||||
return copy3DBlobsAllBytesWithReLayout<int8_t>(srcLayoutBlob, trgLayoutBlob);
|
||||
case InferenceEngine::Precision::U4:
|
||||
case InferenceEngine::Precision::U8:
|
||||
return copy3DBlobsAllBytesWithReLayout<uint8_t>(srcLayoutBlob, trgLayoutBlob);
|
||||
default:
|
||||
|
||||
@@ -22,8 +22,10 @@ TEST_F(PrecisionTests, ShowsCorrectPrecisionNames) {
|
||||
ASSERT_STREQ(Precision(Precision::I32).name(), "I32");
|
||||
ASSERT_STREQ(Precision(Precision::U32).name(), "U32");
|
||||
ASSERT_STREQ(Precision(Precision::U16).name(), "U16");
|
||||
ASSERT_STREQ(Precision(Precision::I4).name(), "I4");
|
||||
ASSERT_STREQ(Precision(Precision::I8).name(), "I8");
|
||||
ASSERT_STREQ(Precision(Precision::Q78).name(), "Q78");
|
||||
ASSERT_STREQ(Precision(Precision::U4).name(), "U4");
|
||||
ASSERT_STREQ(Precision(Precision::U8).name(), "U8");
|
||||
ASSERT_STREQ(Precision(Precision::MIXED).name(), "MIXED");
|
||||
ASSERT_STREQ(Precision(Precision::UNSPECIFIED).name(), "UNSPECIFIED");
|
||||
@@ -41,9 +43,11 @@ TEST_F(PrecisionTests, sizeIsCorrect) {
|
||||
ASSERT_EQ(Precision(Precision::U32).size(), 4);
|
||||
ASSERT_EQ(Precision(Precision::I16).size(), 2);
|
||||
ASSERT_EQ(Precision(Precision::U16).size(), 2);
|
||||
ASSERT_EQ(Precision(Precision::I4).size(), 1);
|
||||
ASSERT_EQ(Precision(Precision::I8).size(), 1);
|
||||
ASSERT_EQ(Precision(Precision::Q78).size(), 2);
|
||||
ASSERT_EQ(Precision(Precision::U8).size(), 1);
|
||||
ASSERT_EQ(Precision(Precision::U4).size(), 1);
|
||||
ASSERT_EQ(Precision(10 * 8).size(), 10);
|
||||
ASSERT_ANY_THROW(Precision(Precision::MIXED).size());
|
||||
ASSERT_ANY_THROW(Precision(Precision::UNSPECIFIED).size());
|
||||
@@ -60,7 +64,9 @@ TEST_F(PrecisionTests, is_float) {
|
||||
ASSERT_FALSE(Precision(Precision::I16).is_float());
|
||||
ASSERT_FALSE(Precision(Precision::U16).is_float());
|
||||
ASSERT_FALSE(Precision(Precision::I8).is_float());
|
||||
ASSERT_FALSE(Precision(Precision::I4).is_float());
|
||||
ASSERT_FALSE(Precision(Precision::Q78).is_float());
|
||||
ASSERT_FALSE(Precision(Precision::U4).is_float());
|
||||
ASSERT_FALSE(Precision(Precision::U8).is_float());
|
||||
ASSERT_FALSE(Precision(Precision::MIXED).is_float());
|
||||
ASSERT_FALSE(Precision(10).is_float());
|
||||
@@ -78,8 +84,10 @@ TEST_F(PrecisionTests, constructFromSTR) {
|
||||
ASSERT_EQ(Precision(Precision::U32), Precision::FromStr("U32"));
|
||||
ASSERT_EQ(Precision(Precision::I16), Precision::FromStr("I16"));
|
||||
ASSERT_EQ(Precision(Precision::U16), Precision::FromStr("U16"));
|
||||
ASSERT_EQ(Precision(Precision::I4), Precision::FromStr("I4"));
|
||||
ASSERT_EQ(Precision(Precision::I8), Precision::FromStr("I8"));
|
||||
ASSERT_EQ(Precision(Precision::Q78), Precision::FromStr("Q78"));
|
||||
ASSERT_EQ(Precision(Precision::U4), Precision::FromStr("U4"));
|
||||
ASSERT_EQ(Precision(Precision::U8), Precision::FromStr("U8"));
|
||||
ASSERT_EQ(Precision(Precision::MIXED), Precision::FromStr("MIXED"));
|
||||
ASSERT_EQ(Precision(static_cast<Precision::ePrecision >(-3)), Precision::FromStr("UNSPECIFIED"));
|
||||
|
||||
@@ -567,7 +567,7 @@ void constant_convert_test(element::Type_t type_from, element::Type_t type_to, F
|
||||
std::shared_ptr<ngraph::Function> f(nullptr);
|
||||
std::string expected_friendly_name;
|
||||
{
|
||||
auto c = opset4::Constant::create(type_from, Shape{}, {value});
|
||||
auto c = std::make_shared<opset4::Constant>(type_from, Shape{}, &value);
|
||||
expected_friendly_name = c->get_friendly_name();
|
||||
f = std::make_shared<Function>(NodeVector{c}, ParameterVector{});
|
||||
|
||||
@@ -608,7 +608,7 @@ TEST(TransformationTests, ConvertPrecision_ConstantConversion_U64MaxToI32) {
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U64ToI32) {
|
||||
constant_convert_test(element::Type_t::u64, element::Type_t::i32, 42, 42);
|
||||
constant_convert_test<uint64_t, int32_t>(element::Type_t::u64, element::Type_t::i32, 42, 42);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U32MinToI32) {
|
||||
@@ -630,3 +630,152 @@ TEST(TransformationTests, ConvertPrecision_ConstantConversion_BoolToU8) {
|
||||
constant_convert_test(element::Type_t::boolean, element::Type_t::u8, true, 1);
|
||||
constant_convert_test(element::Type_t::boolean, element::Type_t::u8, false, 0);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToI8) {
|
||||
constant_convert_test(element::Type_t::u4, element::Type_t::i8, 171, 10);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToU8) {
|
||||
constant_convert_test(element::Type_t::u4, element::Type_t::u8, 171, 10);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToI8_2) {
|
||||
constant_convert_test(element::Type_t::u4, element::Type_t::i8, 96, 6);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToU8_96) {
|
||||
constant_convert_test(element::Type_t::u4, element::Type_t::u8, 96, 6);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU8) {
|
||||
constant_convert_test(element::Type_t::i4, element::Type_t::u8, 96, 6);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI8) {
|
||||
constant_convert_test(element::Type_t::i4, element::Type_t::i8, 96, 6);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU8_neg) {
|
||||
constant_convert_test(element::Type_t::i4, element::Type_t::u8, 171, 242);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI8_neg) {
|
||||
constant_convert_test(element::Type_t::i4, element::Type_t::i8, 171, -14);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToI32) {
|
||||
constant_convert_test(element::Type_t::u4, element::Type_t::i32, 171, 10);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToU32) {
|
||||
constant_convert_test(element::Type_t::u4, element::Type_t::u32, 171, 10);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU32) {
|
||||
constant_convert_test(element::Type_t::i4, element::Type_t::u32, 96, 6);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI32) {
|
||||
constant_convert_test(element::Type_t::i4, element::Type_t::i32, 96, 6);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU32_neg) {
|
||||
constant_convert_test(element::Type_t::i4, element::Type_t::u32, 171, -14);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI32_neg) {
|
||||
constant_convert_test(element::Type_t::i4, element::Type_t::i32, 171, -14);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToI16) {
|
||||
constant_convert_test(element::Type_t::u4, element::Type_t::i16, 171, 10);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToU16) {
|
||||
constant_convert_test(element::Type_t::u4, element::Type_t::u16, 171, 10);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU16) {
|
||||
constant_convert_test(element::Type_t::i4, element::Type_t::u16, 96, 6);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI16) {
|
||||
constant_convert_test(element::Type_t::i4, element::Type_t::i16, 96, 6);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU16_neg) {
|
||||
constant_convert_test(element::Type_t::i4, element::Type_t::u16, 171, 65522);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI16_neg) {
|
||||
constant_convert_test(element::Type_t::i4, element::Type_t::i16, 171, -14);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToI64) {
|
||||
constant_convert_test(element::Type_t::u4, element::Type_t::i64, 171, 10);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToU64) {
|
||||
constant_convert_test(element::Type_t::u4, element::Type_t::u64, 171, 10);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU64) {
|
||||
constant_convert_test(element::Type_t::i4, element::Type_t::u64, 96, 6);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI64) {
|
||||
constant_convert_test(element::Type_t::i4, element::Type_t::i64, 96, 6);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU64_neg) {
|
||||
constant_convert_test(element::Type_t::i4, element::Type_t::u64, 171, -14);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI64_neg) {
|
||||
constant_convert_test(element::Type_t::i4, element::Type_t::i64, 171, -14);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U1ToU8) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr);
|
||||
uint8_t value = 171;
|
||||
std::string expected_friendly_name;
|
||||
{
|
||||
auto c = std::make_shared<opset4::Constant>(element::Type_t::u1, Shape{2}, &value);
|
||||
expected_friendly_name = c->get_friendly_name();
|
||||
f = std::make_shared<Function>(NodeVector{c}, ParameterVector{});
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::ConvertPrecision>(element::Type_t::u1, element::Type_t::u8);
|
||||
manager.run_passes(f);
|
||||
}
|
||||
auto ops = f->get_ordered_ops();
|
||||
auto c = std::dynamic_pointer_cast<opset4::Constant>(ops[0]);
|
||||
ASSERT_NE(c, nullptr);
|
||||
ASSERT_EQ(c->get_friendly_name(), expected_friendly_name);
|
||||
|
||||
auto* actual = reinterpret_cast<const uint8_t*>(c->get_data_ptr());
|
||||
ASSERT_EQ(1, actual[0]);
|
||||
ASSERT_EQ(0, actual[1]);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U1ToU4) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr);
|
||||
uint8_t value = 171;
|
||||
std::string expected_friendly_name;
|
||||
{
|
||||
auto c = std::make_shared<opset4::Constant>(element::Type_t::u1, Shape{2}, &value);
|
||||
expected_friendly_name = c->get_friendly_name();
|
||||
f = std::make_shared<Function>(NodeVector{c}, ParameterVector{});
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::ConvertPrecision>(element::Type_t::u1, element::Type_t::u4);
|
||||
manager.run_passes(f);
|
||||
}
|
||||
auto ops = f->get_ordered_ops();
|
||||
auto c = std::dynamic_pointer_cast<opset4::Constant>(ops[0]);
|
||||
ASSERT_NE(c, nullptr);
|
||||
ASSERT_EQ(c->get_friendly_name(), expected_friendly_name);
|
||||
|
||||
auto actual = reinterpret_cast<const uint8_t*>(c->get_data_ptr())[0];
|
||||
ASSERT_EQ(16, actual);
|
||||
}
|
||||
|
||||
@@ -97,6 +97,10 @@ namespace ngraph
|
||||
throw ngraph_error("make_constant: Unsupported element type 'boolean'");
|
||||
case element::Type_t::u1:
|
||||
throw ngraph_error("make_constant: Unsupported element type 'u1'");
|
||||
case element::Type_t::i4:
|
||||
throw ngraph_error("make_constant: Unsupported element type 'i4'");
|
||||
case element::Type_t::u4:
|
||||
throw ngraph_error("make_constant: Unsupported element type 'u4'");
|
||||
case element::Type_t::undefined:
|
||||
throw ngraph_error("make_constant: Unsupported element type 'undefined'");
|
||||
}
|
||||
|
||||
@@ -208,8 +208,10 @@ namespace ngraph
|
||||
typename element_type_traits<element::Type_t::u64>::value_type>(
|
||||
value));
|
||||
break;
|
||||
case element::Type_t::u1: throw std::runtime_error("unsupported type");
|
||||
case element::Type_t::undefined: throw std::runtime_error("unsupported type");
|
||||
case element::Type_t::i4:
|
||||
case element::Type_t::u1:
|
||||
case element::Type_t::u4:
|
||||
case element::Type_t::undefined:
|
||||
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
|
||||
}
|
||||
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
||||
@@ -586,8 +588,10 @@ namespace ngraph
|
||||
case element::Type_t::u64:
|
||||
write_buffer<uint64_t, T>(target, source, target_element_count);
|
||||
break;
|
||||
case element::Type_t::u1: throw std::runtime_error("unsupported type");
|
||||
case element::Type_t::undefined: throw std::runtime_error("unsupported type");
|
||||
case element::Type_t::i4:
|
||||
case element::Type_t::u1:
|
||||
case element::Type_t::u4:
|
||||
case element::Type_t::undefined:
|
||||
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
|
||||
}
|
||||
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
||||
|
||||
@@ -46,11 +46,13 @@ namespace ngraph
|
||||
f16,
|
||||
f32,
|
||||
f64,
|
||||
i4,
|
||||
i8,
|
||||
i16,
|
||||
i32,
|
||||
i64,
|
||||
u1,
|
||||
u4,
|
||||
u8,
|
||||
u16,
|
||||
u32,
|
||||
@@ -134,11 +136,13 @@ namespace ngraph
|
||||
constexpr Type f16(Type_t::f16);
|
||||
constexpr Type f32(Type_t::f32);
|
||||
constexpr Type f64(Type_t::f64);
|
||||
constexpr Type i4(Type_t::i4);
|
||||
constexpr Type i8(Type_t::i8);
|
||||
constexpr Type i16(Type_t::i16);
|
||||
constexpr Type i32(Type_t::i32);
|
||||
constexpr Type i64(Type_t::i64);
|
||||
constexpr Type u1(Type_t::u1);
|
||||
constexpr Type u4(Type_t::u4);
|
||||
constexpr Type u8(Type_t::u8);
|
||||
constexpr Type u16(Type_t::u16);
|
||||
constexpr Type u32(Type_t::u32);
|
||||
|
||||
@@ -55,6 +55,12 @@ namespace ngraph
|
||||
using value_type = double;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct element_type_traits<element::Type_t::i4>
|
||||
{
|
||||
using value_type = int8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct element_type_traits<element::Type_t::i8>
|
||||
{
|
||||
@@ -85,6 +91,12 @@ namespace ngraph
|
||||
using value_type = int8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct element_type_traits<element::Type_t::u4>
|
||||
{
|
||||
using value_type = int8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct element_type_traits<element::Type_t::u8>
|
||||
{
|
||||
|
||||
@@ -179,6 +179,14 @@ op::Constant::Constant(const element::Type& type,
|
||||
{
|
||||
throw std::runtime_error("deserialize unsupported type dynamic");
|
||||
}
|
||||
case element::Type_t::i4:
|
||||
{
|
||||
throw std::runtime_error("deserialize unsupported type i4");
|
||||
}
|
||||
case element::Type_t::u4:
|
||||
{
|
||||
throw std::runtime_error("deserialize unsupported type u4");
|
||||
}
|
||||
case element::Type_t::u1:
|
||||
{
|
||||
throw std::runtime_error("deserialize unsupported type u1");
|
||||
@@ -291,6 +299,8 @@ op::Constant::Constant(const element::Type& type,
|
||||
throw std::runtime_error("deserialize unsupported type undefined");
|
||||
case element::Type_t::dynamic:
|
||||
throw std::runtime_error("deserialize unsupported type dynamic");
|
||||
case element::Type_t::i4: throw std::runtime_error("deserialize unsupported type i4");
|
||||
case element::Type_t::u4: throw std::runtime_error("deserialize unsupported type u4");
|
||||
case element::Type_t::u1: throw std::runtime_error("deserialize unsupported type u1");
|
||||
}
|
||||
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
|
||||
@@ -351,6 +361,21 @@ string op::Constant::convert_value_to_string(size_t index) const
|
||||
break;
|
||||
case element::Type_t::f32: rc = to_cpp_string(get_data_ptr<float>()[index]); break;
|
||||
case element::Type_t::f64: rc = to_cpp_string(get_data_ptr<double>()[index]); break;
|
||||
case element::Type_t::i4:
|
||||
{
|
||||
uint8_t i4data = (get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 0 : 4)) & 0x0F;
|
||||
int8_t data = i4data;
|
||||
if ((i4data >> 3) & 0b1)
|
||||
{
|
||||
// negative number
|
||||
data = (i4data & 0x7) | 0xF0;
|
||||
}
|
||||
rc = to_string(data);
|
||||
break;
|
||||
}
|
||||
case element::Type_t::u4:
|
||||
rc = to_string((get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 0 : 4)) & 0x0F);
|
||||
break;
|
||||
case element::Type_t::i8: rc = to_string(get_data_ptr<int8_t>()[index]); break;
|
||||
case element::Type_t::i16: rc = to_string(get_data_ptr<int16_t>()[index]); break;
|
||||
case element::Type_t::i32: rc = to_string(get_data_ptr<int32_t>()[index]); break;
|
||||
@@ -460,8 +485,10 @@ vector<string> op::Constant::get_value_strings() const
|
||||
rc.push_back(to_string(value));
|
||||
}
|
||||
break;
|
||||
case element::Type_t::u1: throw runtime_error("unsupported type");
|
||||
case element::Type_t::undefined: throw runtime_error("unsupported type");
|
||||
case element::Type_t::u1:
|
||||
case element::Type_t::i4:
|
||||
case element::Type_t::u4:
|
||||
case element::Type_t::undefined:
|
||||
case element::Type_t::dynamic: throw runtime_error("unsupported type");
|
||||
}
|
||||
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
||||
@@ -615,7 +642,9 @@ bool op::Constant::are_all_data_elements_bitwise_identical() const
|
||||
rc = test_bitwise_identical<uint64_t>(this);
|
||||
break;
|
||||
}
|
||||
case element::Type_t::i4:
|
||||
case element::Type_t::u1:
|
||||
case element::Type_t::u4:
|
||||
case element::Type_t::undefined:
|
||||
case element::Type_t::dynamic: break;
|
||||
}
|
||||
|
||||
@@ -476,6 +476,8 @@ void op::v0::Range::validate_and_infer_types()
|
||||
case element::Type_t::u64: result_shape = infer_output_shape<uint64_t>(this, result_et); break;
|
||||
case element::Type_t::dynamic: result_shape = PartialShape::dynamic(1); break;
|
||||
case element::Type_t::u1:
|
||||
case element::Type_t::i4:
|
||||
case element::Type_t::u4:
|
||||
case element::Type_t::undefined:
|
||||
case element::Type_t::boolean:
|
||||
NODE_VALIDATION_CHECK(
|
||||
|
||||
@@ -389,6 +389,8 @@ std::string pass::VisualizeTree::get_constant_value(std::shared_ptr<Node> node,
|
||||
case element::Type_t::undefined: ss << "[ undefined value ]"; break;
|
||||
case element::Type_t::dynamic: ss << "[ dynamic value ]"; break;
|
||||
case element::Type_t::u1: ss << "[ u1 value ]"; break;
|
||||
case element::Type_t::u4: ss << "[ u4 value ]"; break;
|
||||
case element::Type_t::i4: ss << "[ i4 value ]"; break;
|
||||
case element::Type_t::bf16:
|
||||
case element::Type_t::f16:
|
||||
case element::Type_t::f32:
|
||||
|
||||
@@ -72,11 +72,13 @@ static const element_types_map_t& get_type_info_map()
|
||||
{element::Type_t::f16, TypeInfo(16, true, true, false, "float16", "f16")},
|
||||
{element::Type_t::f32, TypeInfo(32, true, true, false, "float", "f32")},
|
||||
{element::Type_t::f64, TypeInfo(64, true, true, false, "double", "f64")},
|
||||
{element::Type_t::i4, TypeInfo(4, false, true, true, "int4_t", "i4")},
|
||||
{element::Type_t::i8, TypeInfo(8, false, true, true, "int8_t", "i8")},
|
||||
{element::Type_t::i16, TypeInfo(16, false, true, false, "int16_t", "i16")},
|
||||
{element::Type_t::i32, TypeInfo(32, false, true, true, "int32_t", "i32")},
|
||||
{element::Type_t::i64, TypeInfo(64, false, true, false, "int64_t", "i64")},
|
||||
{element::Type_t::u1, TypeInfo(1, false, false, false, "uint1_t", "u1")},
|
||||
{element::Type_t::u4, TypeInfo(4, false, false, false, "uint4_t", "u4")},
|
||||
{element::Type_t::u8, TypeInfo(8, false, false, true, "uint8_t", "u8")},
|
||||
{element::Type_t::u16, TypeInfo(16, false, false, false, "uint16_t", "u16")},
|
||||
{element::Type_t::u32, TypeInfo(32, false, false, false, "uint32_t", "u32")},
|
||||
@@ -93,11 +95,13 @@ std::vector<const element::Type*> element::Type::get_known_types()
|
||||
&element::f16,
|
||||
&element::f32,
|
||||
&element::f64,
|
||||
&element::i4,
|
||||
&element::i8,
|
||||
&element::i16,
|
||||
&element::i32,
|
||||
&element::i64,
|
||||
&element::u1,
|
||||
&element::u4,
|
||||
&element::u8,
|
||||
&element::u16,
|
||||
&element::u32,
|
||||
@@ -294,11 +298,13 @@ size_t ngraph::compiler_byte_size(element::Type_t et)
|
||||
ET_CASE(f16);
|
||||
ET_CASE(f32);
|
||||
ET_CASE(f64);
|
||||
ET_CASE(i4);
|
||||
ET_CASE(i8);
|
||||
ET_CASE(i16);
|
||||
ET_CASE(i32);
|
||||
ET_CASE(i64);
|
||||
ET_CASE(u1);
|
||||
ET_CASE(u4);
|
||||
ET_CASE(u8);
|
||||
ET_CASE(u16);
|
||||
ET_CASE(u32);
|
||||
@@ -326,11 +332,13 @@ namespace ngraph
|
||||
{"f16", element::Type_t::f16},
|
||||
{"f32", element::Type_t::f32},
|
||||
{"f64", element::Type_t::f64},
|
||||
{"i4", element::Type_t::i4},
|
||||
{"i8", element::Type_t::i8},
|
||||
{"i16", element::Type_t::i16},
|
||||
{"i32", element::Type_t::i32},
|
||||
{"i64", element::Type_t::i64},
|
||||
{"u1", element::Type_t::u1},
|
||||
{"u4", element::Type_t::u4},
|
||||
{"u8", element::Type_t::u8},
|
||||
{"u16", element::Type_t::u16},
|
||||
{"u32", element::Type_t::u32},
|
||||
|
||||
@@ -64,6 +64,7 @@ set(SRC
|
||||
graph_rewrite.cpp
|
||||
includes.cpp
|
||||
input_output_assign.cpp
|
||||
int4.cpp
|
||||
intervals.cpp
|
||||
main.cpp
|
||||
matcher_pass.cpp
|
||||
@@ -211,6 +212,7 @@ set(SRC
|
||||
type_prop/unsqueeze.cpp
|
||||
type_prop/variadic_split.cpp
|
||||
type_prop_layers.cpp
|
||||
uint4.cpp
|
||||
util.cpp
|
||||
)
|
||||
|
||||
|
||||
35
ngraph/test/int4.cpp
Normal file
35
ngraph/test/int4.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/all_close_f.hpp"
|
||||
#include "util/test_tools.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace std;
|
||||
|
||||
TEST(int4, convert_i4_to_string)
|
||||
{
|
||||
vector<uint8_t> values{171, 16};
|
||||
auto constant = make_shared<op::Constant>(element::i4, Shape{3}, &values[0]);
|
||||
|
||||
vector<string> ref{"-14", "-13", "1"};
|
||||
for (size_t i = 0; i < 3; ++i)
|
||||
{
|
||||
ASSERT_EQ(constant->convert_value_to_string(i), ref[i]);
|
||||
}
|
||||
}
|
||||
@@ -124,6 +124,8 @@ void pass::DynElimination::construct_range()
|
||||
case element::Type_t::u64:
|
||||
replacement = make_range_replacement<uint64_t>(et, shape, start_arg, step_arg);
|
||||
break;
|
||||
case element::Type_t::i4:
|
||||
case element::Type_t::u4:
|
||||
case element::Type_t::u1:
|
||||
case element::Type_t::undefined:
|
||||
case element::Type_t::dynamic:
|
||||
|
||||
35
ngraph/test/uint4.cpp
Normal file
35
ngraph/test/uint4.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/all_close_f.hpp"
|
||||
#include "util/test_tools.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace std;
|
||||
|
||||
TEST(uint4, convert_u4_to_string)
|
||||
{
|
||||
vector<uint8_t> values{171, 16};
|
||||
auto constant = make_shared<op::Constant>(element::u4, Shape{3}, &values[0]);
|
||||
|
||||
vector<string> ref{"10", "11", "1"};
|
||||
for (size_t i = 0; i < 3; ++i)
|
||||
{
|
||||
ASSERT_EQ(constant->convert_value_to_string(i), ref[i]);
|
||||
}
|
||||
}
|
||||
@@ -149,7 +149,9 @@ namespace
|
||||
case element::Type_t::u32: return InferenceEngine::Precision::U32; break;
|
||||
case element::Type_t::u64: return InferenceEngine::Precision::U64; break;
|
||||
case element::Type_t::u1: return InferenceEngine::Precision::BIN; break;
|
||||
case element::Type_t::undefined: throw std::runtime_error("unsupported type");
|
||||
case element::Type_t::i4:
|
||||
case element::Type_t::u4:
|
||||
case element::Type_t::undefined:
|
||||
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
|
||||
}
|
||||
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
|
||||
|
||||
Reference in New Issue
Block a user