Int4 support for nGraph (#4712)

* Start int4 type

* Changed int4 type

* Fixed code style

* Added u4 type

* Added tests

* Add U4 I4 precisions to the IE

* Removed redundant tests

* Added convert precision

* Fixed tests

* Fixed C tests

* Fixed logic for signed values

* Fixed nGraph tests

* Added more tests

* Fixed code style

* Update inference-engine/src/transformations/src/transformations/convert_precision.cpp

Co-authored-by: Evgenya Stepyreva <evgenya.stepyreva@intel.com>

Co-authored-by: Evgenya Stepyreva <evgenya.stepyreva@intel.com>
This commit is contained in:
Ilya Churaev
2021-03-24 16:17:43 +03:00
committed by GitHub
parent b62f49b66a
commit f0e574903a
25 changed files with 544 additions and 44 deletions

View File

@@ -121,15 +121,19 @@ enum precision_e{
FP16 = 11, /**< 16bit floating point value */
BF16 = 12, /**< 16bit floating point value, 8 bit for exponent, 7 bit for mantisa*/
FP64 = 13, /**< 64bit floating point value */
Q78 = 20, /**< 16bit specific signed fixed point precision */
I16 = 30, /**< 16bit signed integer value */
U4 = 39, /**< 4bit unsigned integer value */
U8 = 40, /**< 8bit unsigned integer value */
I4 = 49, /**< 4bit signed integer value */
I8 = 50, /**< 8bit signed integer value */
U16 = 60, /**< 16bit unsigned integer value */

View File

@@ -177,7 +177,9 @@ typedef enum {
FP64 = 13, /**< 64bit floating point value */
Q78 = 20, /**< 16bit specific signed fixed point precision */
I16 = 30, /**< 16bit signed integer value */
U4 = 39, /**< 4bit unsigned integer value */
U8 = 40, /**< 8bit unsigned integer value */
I4 = 49, /**< 4bit signed integer value */
I8 = 50, /**< 8bit signed integer value */
U16 = 60, /**< 16bit unsigned integer value */
I32 = 70, /**< 32bit signed integer value */

View File

@@ -82,7 +82,9 @@ std::map<IE::Precision, precision_e> precision_map = {{IE::Precision::UNSPECIFIE
{IE::Precision::FP64, precision_e::FP64},
{IE::Precision::Q78, precision_e::Q78},
{IE::Precision::I16, precision_e::I16},
{IE::Precision::U4, precision_e::U4},
{IE::Precision::U8, precision_e::U8},
{IE::Precision::I4, precision_e::I4},
{IE::Precision::I8, precision_e::I8},
{IE::Precision::U16, precision_e::U16},
{IE::Precision::I32, precision_e::I32},
@@ -1369,7 +1371,7 @@ IEStatusCode ie_blob_make_memory(const tensor_desc_t *tensorDesc, ie_blob_t **bl
_blob->object = IE::make_shared_blob<uint8_t>(tensor);
} else if (prec == IE::Precision::U16) {
_blob->object = IE::make_shared_blob<uint16_t>(tensor);
} else if (prec == IE::Precision::I8 || prec == IE::Precision::BIN) {
} else if (prec == IE::Precision::I8 || prec == IE::Precision::BIN || prec == IE::Precision::I4 || prec == IE::Precision::U4) {
_blob->object = IE::make_shared_blob<int8_t>(tensor);
} else if (prec == IE::Precision::I16 || prec == IE::Precision::FP16 || prec == IE::Precision::Q78) {
_blob->object = IE::make_shared_blob<int16_t>(tensor);
@@ -1434,7 +1436,7 @@ IEStatusCode ie_blob_make_memory_from_preallocated(const tensor_desc_t *tensorDe
} else if (prec == IE::Precision::U16) {
uint16_t *p = reinterpret_cast<uint16_t *>(ptr);
_blob->object = IE::make_shared_blob(tensor, p, size);
} else if (prec == IE::Precision::I8 || prec == IE::Precision::BIN) {
} else if (prec == IE::Precision::I8 || prec == IE::Precision::BIN || prec == IE::Precision::I4 || prec == IE::Precision::U4) {
int8_t *p = reinterpret_cast<int8_t *>(ptr);
_blob->object = IE::make_shared_blob(tensor, p, size);
} else if (prec == IE::Precision::I16 || prec == IE::Precision::FP16 || prec == IE::Precision::Q78) {

View File

@@ -19,7 +19,7 @@ import numpy as np
from enum import Enum
supported_precisions = ["FP32", "FP64", "FP16", "I64", "U64", "I32", "U32",
"I16", "I8", "U16", "U8", "BOOL", "BIN", "BF16"]
"I16", "I4", "I8", "U16", "U4", "U8", "BOOL", "BIN", "BF16"]
known_plugins = ['CPU', 'GPU', 'FPGA', 'MYRIAD', 'HETERO', 'HDDL', 'MULTI']
@@ -36,7 +36,9 @@ format_map = {
'U32' : np.uint32,
'I16' : np.int16,
'U16' : np.uint16,
'I4' : np.int8,
'I8' : np.int8,
'U4' : np.int8,
'U8' : np.uint8,
'BOOL' : np.uint8,
'BIN' : np.int8,

View File

@@ -161,7 +161,7 @@ cdef class Blob:
self._ptr = C.make_shared_blob[uint16_t](c_tensor_desc)
elif precision == "U8" or precision == "BOOL":
self._ptr = C.make_shared_blob[uint8_t](c_tensor_desc)
elif precision == "I8" or precision == "BIN":
elif precision == "I8" or precision == "BIN" or precision == "I4" or precision == "U4":
self._ptr = C.make_shared_blob[int8_t](c_tensor_desc)
elif precision == "I32":
self._ptr = C.make_shared_blob[int32_t](c_tensor_desc)
@@ -205,7 +205,7 @@ cdef class Blob:
elif precision == "U8" or precision == "BOOL":
U8_array_memview = self._array_data
self._ptr = C.make_shared_blob[uint8_t](c_tensor_desc, &U8_array_memview[0], U8_array_memview.shape[0])
elif precision == "I8" or precision == "BIN":
elif precision == "I8" or precision == "BIN" or precision == "I4" or precision == "U4":
I8_array_memview = self._array_data
self._ptr = C.make_shared_blob[int8_t](c_tensor_desc, &I8_array_memview[0], I8_array_memview.shape[0])
elif precision == "I32":

View File

@@ -33,7 +33,9 @@ public:
FP64 = 13, /**< 64bit floating point value */
Q78 = 20, /**< 16bit specific signed fixed point precision */
I16 = 30, /**< 16bit signed integer value */
U4 = 39, /**< 4bit unsigned integer value */
U8 = 40, /**< 8bit unsigned integer value */
I4 = 49, /**< 4bit signed integer value */
I8 = 50, /**< 8bit signed integer value */
U16 = 60, /**< 16bit unsigned integer value */
I32 = 70, /**< 32bit signed integer value */
@@ -116,10 +118,12 @@ public:
CASE(FP64, double);
CASE2(FP16, int16_t, uint16_t);
CASE2(BF16, int16_t, uint16_t);
CASE2(I4, int8_t, uint8_t);
CASE(I8, int8_t);
CASE(I16, int16_t);
CASE(I32, int32_t);
CASE(I64, int64_t);
CASE(U4, uint8_t);
CASE(U8, uint8_t);
CASE(U16, uint16_t);
CASE(U32, uint32_t);
@@ -224,8 +228,8 @@ public:
static const std::unordered_map<std::string, ePrecision> names = {
#define PRECISION_NAME(s) {#s, s}
PRECISION_NAME(Q78), PRECISION_NAME(BOOL), PRECISION_NAME(BF16),
PRECISION_NAME(I8), PRECISION_NAME(I16), PRECISION_NAME(I32), PRECISION_NAME(I64),
PRECISION_NAME(U8), PRECISION_NAME(U16), PRECISION_NAME(U32), PRECISION_NAME(U64),
PRECISION_NAME(I4), PRECISION_NAME(I8), PRECISION_NAME(I16), PRECISION_NAME(I32), PRECISION_NAME(I64),
PRECISION_NAME(U4), PRECISION_NAME(U8), PRECISION_NAME(U16), PRECISION_NAME(U32), PRECISION_NAME(U64),
PRECISION_NAME(FP32), PRECISION_NAME(FP64), PRECISION_NAME(FP16), PRECISION_NAME(MIXED),
PRECISION_NAME(BIN),
#undef PRECISION_NAME
@@ -264,7 +268,7 @@ public:
(precisionInfo.value == Precision::I16) || (precisionInfo.value == Precision::I8) ||
(precisionInfo.value == Precision::I32) || (precisionInfo.value == Precision::I64) ||
(precisionInfo.value == Precision::BIN) || (precisionInfo.value == Precision::BF16) ||
(precisionInfo.value == Precision::CUSTOM);
(precisionInfo.value == Precision::CUSTOM) || (precisionInfo.value == Precision::I4);
}
protected:
@@ -309,10 +313,12 @@ protected:
CASE(FP64);
CASE(FP16);
CASE(BF16);
CASE(I4);
CASE(I8);
CASE(I16);
CASE(I32);
CASE(I64);
CASE(U4);
CASE(U8);
CASE(U16);
CASE(U32);
@@ -366,10 +372,18 @@ struct PrecisionTrait<Precision::U16> {
using value_type = uint16_t;
};
template <>
struct PrecisionTrait<Precision::U4> {
using value_type = uint8_t;
};
template <>
struct PrecisionTrait<Precision::U8> {
using value_type = uint8_t;
};
template <>
struct PrecisionTrait<Precision::I4> {
using value_type = int8_t;
};
template <>
struct PrecisionTrait<Precision::I8> {
using value_type = int8_t;
};

View File

@@ -110,10 +110,12 @@ InferenceEngine::Blob::Ptr make_blob_with_precision(InferenceEngine::Precision p
USE_FACTORY(FP64);
USE_FACTORY(FP16);
USE_FACTORY(Q78);
USE_FACTORY(I4);
USE_FACTORY(I8);
USE_FACTORY(I16);
USE_FACTORY(I32);
USE_FACTORY(I64);
USE_FACTORY(U4);
USE_FACTORY(U8);
USE_FACTORY(U16);
USE_FACTORY(U32);

View File

@@ -26,8 +26,12 @@ inline ::ngraph::element::Type convertPrecision(const Precision& precision) {
return ::ngraph::element::Type(::ngraph::element::Type_t::f16);
case Precision::BF16:
return ::ngraph::element::Type(::ngraph::element::Type_t::bf16);
case Precision::U4:
return ::ngraph::element::Type(::ngraph::element::Type_t::u4);
case Precision::U8:
return ::ngraph::element::Type(::ngraph::element::Type_t::u8);
case Precision::I4:
return ::ngraph::element::Type(::ngraph::element::Type_t::i4);
case Precision::I8:
return ::ngraph::element::Type(::ngraph::element::Type_t::i8);
case Precision::U16:
@@ -63,6 +67,8 @@ inline ::ngraph::element::Type convertPrecision(const std::string& precision) {
return ::ngraph::element::Type(::ngraph::element::Type_t::bf16);
} else if (precision == "f64" || precision == "FP64") {
return ::ngraph::element::Type(::ngraph::element::Type_t::f64);
} else if (precision == "i4" || precision == "I4") {
return ::ngraph::element::Type(::ngraph::element::Type_t::i4);
} else if (precision == "i8" || precision == "I8") {
return ::ngraph::element::Type(::ngraph::element::Type_t::i8);
} else if (precision == "i16" || precision == "I16") {
@@ -73,6 +79,8 @@ inline ::ngraph::element::Type convertPrecision(const std::string& precision) {
return ::ngraph::element::Type(::ngraph::element::Type_t::i64);
} else if (precision == "u1" || precision == "U1" || precision == "BIN" || precision == "bin") {
return ::ngraph::element::Type(::ngraph::element::Type_t::u1);
} else if (precision == "u4" || precision == "U4") {
return ::ngraph::element::Type(::ngraph::element::Type_t::u4);
} else if (precision == "u8" || precision == "U8") {
return ::ngraph::element::Type(::ngraph::element::Type_t::u8);
} else if (precision == "u16" || precision == "U16") {
@@ -102,6 +110,8 @@ inline Precision convertPrecision(const ::ngraph::element::Type& precision) {
return Precision(Precision::FP64);
case ::ngraph::element::Type_t::bf16:
return Precision(Precision::BF16);
case ::ngraph::element::Type_t::i4:
return Precision(Precision::I4);
case ::ngraph::element::Type_t::i8:
return Precision(Precision::I8);
case ::ngraph::element::Type_t::i16:
@@ -110,6 +120,8 @@ inline Precision convertPrecision(const ::ngraph::element::Type& precision) {
return Precision(Precision::I32);
case ::ngraph::element::Type_t::i64:
return Precision(Precision::I64);
case ::ngraph::element::Type_t::u4:
return Precision(Precision::U4);
case ::ngraph::element::Type_t::u8:
return Precision(Precision::U8);
case ::ngraph::element::Type_t::u16:

View File

@@ -355,42 +355,194 @@ inline int32_t convert_value<uint32_t, int32_t>(uint32_t val) {
}
namespace {
template <element::Type_t PREC_FROM, element::Type_t PREC_TO>
std::shared_ptr<ngraph::Node> change_constant_precision(std::shared_ptr<opset4::Constant>& constant) {
using src_type = typename element_type_traits<PREC_FROM>::value_type;
using dst_type = typename element_type_traits<PREC_TO>::value_type;
template <element::Type_t PREC_FROM, element::Type_t PREC_TO>
std::shared_ptr<ngraph::Node> change_constant_precision(std::shared_ptr<opset4::Constant>& constant) {
using src_type = typename element_type_traits<PREC_FROM>::value_type;
using dst_type = typename element_type_traits<PREC_TO>::value_type;
const auto * src_data = constant->get_data_ptr<src_type>();
const auto size = shape_size(constant->get_shape());
const auto * src_data = constant->get_data_ptr<src_type>();
const auto size = shape_size(constant->get_shape());
auto new_constant = std::make_shared<ngraph::opset4::Constant>(PREC_TO, constant->get_shape());
auto * dst_data = const_cast<dst_type *>(reinterpret_cast<const dst_type *>(new_constant->get_data_ptr()));
if (dst_data == nullptr)
throw ngraph_error("Can't get destination data pointer");
auto new_constant = std::make_shared<ngraph::opset4::Constant>(PREC_TO, constant->get_shape());
auto * dst_data = const_cast<dst_type *>(reinterpret_cast<const dst_type *>(new_constant->get_data_ptr()));
if (dst_data == nullptr)
throw ngraph_error("Can't get destination data pointer");
for (size_t i = 0; i < size; ++i) {
dst_data[i] = convert_value<src_type, dst_type>(src_data[i]);
for (size_t i = 0; i < size; ++i) {
dst_data[i] = convert_value<src_type, dst_type>(src_data[i]);
}
return new_constant;
}
template <>
std::shared_ptr<Node> change_constant_precision<element::Type_t::f16, element::Type_t::f32>(std::shared_ptr<opset4::Constant>& constant) {
using src_type = typename element_type_traits<element::Type_t::f16>::value_type;
using dst_type = typename element_type_traits<element::Type_t::f32>::value_type;
const auto * src_data = constant->get_data_ptr<src_type>();
const auto size = shape_size(constant->get_shape());
auto new_constant = std::make_shared<ngraph::opset4::Constant>(element::Type_t::f32, constant->get_shape());
auto * dst_data = const_cast<dst_type *>(reinterpret_cast<const dst_type *>(new_constant->get_data_ptr()));
if (dst_data == nullptr)
throw ngraph_error("Can't get destination data pointer");
ngraph::runtime::reference::convert<src_type, dst_type>(src_data, dst_data, size);
return new_constant;
}
struct EnumClassHash {
template<class T>
std::size_t operator()(T t) const {
return static_cast<size_t>(t);
}
};
/**
* @brief Method converts low precision integer types
*
* @param src source value !!! the type must be unsigned !!!
* @param dst destination value !!! the type must be unsigned !!!
* @param src_offset source offset (for custom data types)
* @param src_size source size (for custom data types)
* @param dst_offset destination offset
* @param dst_size destination size
* @param is_signed the type of source data
*/
template <class SRC, class DST>
void convert_lp_value(const SRC& src, DST& dst, size_t src_offset, size_t src_size, size_t dst_offset, size_t dst_size, bool is_signed) {
constexpr SRC src_max = std::numeric_limits<SRC>::max();
constexpr DST dst_max = std::numeric_limits<DST>::max();
// Make a shift for the source value
// src [11101000] offset 2, size 4
// val [00011101]
SRC val = src >> src_offset;
// dst [10001111 00000100] offset 5 size 9
// new_val [00000000 00000000]
DST new_val = 0;
// If source type is signed
if (is_signed) {
// Get the sign of value
// sign [00000000]
// invert value in order to use XOR
SRC sign = (~(val >> (src_size - 1))) & 0b1;
// Calculate diff in order to clean bits which don't exist in the source value
// diff 5
size_t diff = sizeof(SRC)*8 - src_size + 1;
// Clean unnecessary bits
// val [10100000]
val = val << diff;
// val [00000101]
val = (val >> diff);
// Negative number
if (!sign) {
// val [11110101]
val |= (src_max << (diff - 1));
// new_val [00000001 11111111]
new_val = (sign << (dst_size - 1)) ^ (dst_max >> (sizeof(DST) * 8 - dst_size));
// new_val [00000001 11110101]
new_val &= (dst_max << sizeof(SRC)*8) | val;
} else {
// new_val [00000000 00000101]
new_val = val;
}
return new_constant;
} else {
// Calculate diff in order to clean bits which don't exist in the source value
// diff 4
size_t diff = sizeof(SRC)*8 - src_size;
// Clean unnecessary bits
// val [11010000]
val = val << diff;
// val [00001101]
val = val >> diff;
// new_val [00000000 00001101]
new_val = val;
}
template <>
std::shared_ptr<Node> change_constant_precision<element::Type_t::f16, element::Type_t::f32>(std::shared_ptr<opset4::Constant>& constant) {
using src_type = typename element_type_traits<element::Type_t::f16>::value_type;
using dst_type = typename element_type_traits<element::Type_t::f32>::value_type;
// Make a mask in order to save other values if DST contains several values
// mask [11000000 00011111]
DST mask = 0;
if (dst_offset + dst_size < sizeof(DST) * 8)
mask = (dst_max << (dst_offset + dst_size));
if (dst_offset != 0)
mask |= (dst_max >> (sizeof(DST) * 8 - dst_offset));
const auto * src_data = constant->get_data_ptr<src_type>();
const auto size = shape_size(constant->get_shape());
// Add mask to our converted value
// signed: new_val [11100000 10111111]
// unsigned: new_val [11000001 10111111]
new_val = mask | (new_val << dst_offset);
auto new_constant = std::make_shared<ngraph::opset4::Constant>(element::Type_t::f32, constant->get_shape());
auto * dst_data = const_cast<dst_type *>(reinterpret_cast<const dst_type *>(new_constant->get_data_ptr()));
if (dst_data == nullptr)
throw ngraph_error("Can't get destination data pointer");
// Add our value to destination
// dst: [10111111 11100100]
dst |= ~mask;
// signed: dst [10100000 10100100]
// unsigned: dst [10000001 10100100]
dst &= new_val;
}
ngraph::runtime::reference::convert<src_type, dst_type>(src_data, dst_data, size);
std::shared_ptr<Node> convert_low_precisions_int(std::shared_ptr<opset4::Constant>& constant, element::Type to) {
// Supported integer precisions
static const std::unordered_set<element::Type_t, EnumClassHash> supported_integer_precisions = {
element::i4,
element::u4,
element::u1
};
// Get source element type and source data
auto src_type = constant->get_element_type();
const auto* src_data = reinterpret_cast<const uint8_t*>(constant->get_data_ptr());
return new_constant;
// We support conversion only if several elements can be represented in one instance of some C++ common data type without any exception,
// destination data type should be bigger than source and destination data type should be real
if (!supported_integer_precisions.count(src_type) || (src_type.size() * 8) % src_type.bitwidth() ||
(to.size() * 8) % to.bitwidth() || to.is_real() || to.bitwidth() < src_type.bitwidth())
throw ngraph_error("Convert low precision for " + constant->get_element_type().get_type_name() + " to " +
to.get_type_name() + " is not implemented!");
// Create a new constant operation and get destination data
auto new_constant = std::make_shared<ngraph::opset4::Constant>(to, constant->get_shape());
auto* dst_data = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(new_constant->get_data_ptr()));
// Check pointers
if (src_data == nullptr || dst_data == nullptr)
throw ngraph_error("Can't get data pointer");
// Convert values
const auto size = shape_size(constant->get_shape());
for (size_t i = 0; i < size; i++) {
// Calculate indexes
size_t dst_idx = i / ((to.size() * 8) / to.bitwidth());
size_t src_idx = i / ((src_type.size() * 8) / src_type.bitwidth());
// Calculate offsets inside the indexes
size_t dst_off = (to.size() * 8 - to.bitwidth()) - to.bitwidth() * (i % ((to.size() * 8) / to.bitwidth()));
size_t src_off = (src_type.size() * 8 - src_type.bitwidth()) - src_type.bitwidth() * (i % ((src_type.size() * 8) / src_type.bitwidth()));
// Source type at the current moment always less than 1 byte
// Select the right destination type
switch (to.size()) {
case 1:
convert_lp_value<uint8_t, uint8_t>(src_data[src_idx], dst_data[dst_idx], src_off, src_type.bitwidth(),
dst_off, to.bitwidth(), src_type.is_signed());
break;
case 2:
convert_lp_value<uint8_t, uint16_t>(src_data[src_idx], reinterpret_cast<uint16_t*>(dst_data)[dst_idx], src_off, src_type.bitwidth(),
dst_off, to.bitwidth(), src_type.is_signed());
break;
case 4:
convert_lp_value<uint8_t, uint32_t>(src_data[src_idx], reinterpret_cast<uint32_t*>(dst_data)[dst_idx], src_off, src_type.bitwidth(),
dst_off, to.bitwidth(), src_type.is_signed());
break;
case 8:
convert_lp_value<uint8_t, uint64_t>(src_data[src_idx], reinterpret_cast<uint64_t*>(dst_data)[dst_idx], src_off, src_type.bitwidth(),
dst_off, to.bitwidth(), src_type.is_signed());
break;
default:
throw ngraph_error("Unsupported element size!");
}
}
return new_constant;
}
} // namespace
bool fuse_type_to_constant(const std::shared_ptr<ngraph::Node> & node, element::Type to, const std::vector<Input<Node>> & consumers) {
@@ -417,8 +569,10 @@ bool fuse_type_to_constant(const std::shared_ptr<ngraph::Node> & node, element::
new_const = change_constant_precision<element::Type_t::boolean, element::Type_t::u8>(constant);
} else if (from == element::boolean && to == element::i32) {
new_const = change_constant_precision<element::Type_t::boolean, element::Type_t::i32>(constant);
} else if (from == element::i4 || from == element::u4 || from == element::u1) {
new_const = convert_low_precisions_int(constant, to);
} else {
throw ngraph_error("not supported");
throw ngraph_error("Precision conversion from " + from.get_type_name() + " to " + to.get_type_name() + " is not supported");
}
for (auto & output : consumers) {
output.replace_source_output(new_const);

View File

@@ -62,9 +62,11 @@ InferenceEngine::Blob::Ptr createBlob(InferenceEngine::Precision precision, Size
return make_shared_blob<uint64_t>(tensorDesc);
case InferenceEngine::Precision::U16:
return make_shared_blob<uint16_t>(tensorDesc);
case InferenceEngine::Precision::I4:
case InferenceEngine::Precision::I8:
case InferenceEngine::Precision::BIN:
return make_shared_blob<int8_t>(tensorDesc);
case InferenceEngine::Precision::U4:
case InferenceEngine::Precision::U8:
return make_shared_blob<uint8_t>(tensorDesc);
default:
@@ -134,9 +136,11 @@ void FillBlob(Blob::Ptr& inputBlob) {
return FillBlobRandom<uint64_t>(inputBlob);
case InferenceEngine::Precision::U16:
return FillBlobRandom<uint16_t>(inputBlob);
case InferenceEngine::Precision::I4:
case InferenceEngine::Precision::I8:
case InferenceEngine::Precision::BIN:
return FillBlobRandom<int8_t>(inputBlob);
case InferenceEngine::Precision::U4:
case InferenceEngine::Precision::U8:
return FillBlobRandom<uint8_t>(inputBlob);
default:
@@ -224,9 +228,11 @@ bool IsCorrectBlobCopy(Blob::Ptr& srcBlob, Blob::Ptr& dstBlob) {
return IsCorrectBlobCopy_Impl<uint64_t >(srcBlob, dstBlob);
case InferenceEngine::Precision::U16:
return IsCorrectBlobCopy_Impl<uint16_t>(srcBlob, dstBlob);
case InferenceEngine::Precision::I4:
case InferenceEngine::Precision::I8:
case InferenceEngine::Precision::BIN:
return IsCorrectBlobCopy_Impl<int8_t>(srcBlob, dstBlob);
case InferenceEngine::Precision::U4:
case InferenceEngine::Precision::U8:
return IsCorrectBlobCopy_Impl<uint8_t>(srcBlob, dstBlob);
default:
@@ -351,9 +357,11 @@ bool IsEqualBlobCopy(Blob::Ptr& srcBlob, Blob::Ptr& dstBlob) {
return IsEqualBlobCopy_Impl<uint64_t>(srcBlob, dstBlob);
case InferenceEngine::Precision::I64:
return IsEqualBlobCopy_Impl<int64_t>(srcBlob, dstBlob);
case InferenceEngine::Precision::I4:
case InferenceEngine::Precision::I8:
case InferenceEngine::Precision::BIN:
return IsEqualBlobCopy_Impl<int8_t>(srcBlob, dstBlob);
case InferenceEngine::Precision::U4:
case InferenceEngine::Precision::U8:
return IsEqualBlobCopy_Impl<uint8_t>(srcBlob, dstBlob);
case InferenceEngine::Precision::U16:
@@ -404,9 +412,11 @@ void copy3DBlobsAllBytesWithReLayoutWrapper(const Blob::Ptr& srcLayoutBlob, Blob
return copy3DBlobsAllBytesWithReLayout<int64_t>(srcLayoutBlob, trgLayoutBlob);
case InferenceEngine::Precision::U16:
return copy3DBlobsAllBytesWithReLayout<uint16_t>(srcLayoutBlob, trgLayoutBlob);
case InferenceEngine::Precision::I4:
case InferenceEngine::Precision::I8:
case InferenceEngine::Precision::BIN:
return copy3DBlobsAllBytesWithReLayout<int8_t>(srcLayoutBlob, trgLayoutBlob);
case InferenceEngine::Precision::U4:
case InferenceEngine::Precision::U8:
return copy3DBlobsAllBytesWithReLayout<uint8_t>(srcLayoutBlob, trgLayoutBlob);
default:

View File

@@ -22,8 +22,10 @@ TEST_F(PrecisionTests, ShowsCorrectPrecisionNames) {
ASSERT_STREQ(Precision(Precision::I32).name(), "I32");
ASSERT_STREQ(Precision(Precision::U32).name(), "U32");
ASSERT_STREQ(Precision(Precision::U16).name(), "U16");
ASSERT_STREQ(Precision(Precision::I4).name(), "I4");
ASSERT_STREQ(Precision(Precision::I8).name(), "I8");
ASSERT_STREQ(Precision(Precision::Q78).name(), "Q78");
ASSERT_STREQ(Precision(Precision::U4).name(), "U4");
ASSERT_STREQ(Precision(Precision::U8).name(), "U8");
ASSERT_STREQ(Precision(Precision::MIXED).name(), "MIXED");
ASSERT_STREQ(Precision(Precision::UNSPECIFIED).name(), "UNSPECIFIED");
@@ -41,9 +43,11 @@ TEST_F(PrecisionTests, sizeIsCorrect) {
ASSERT_EQ(Precision(Precision::U32).size(), 4);
ASSERT_EQ(Precision(Precision::I16).size(), 2);
ASSERT_EQ(Precision(Precision::U16).size(), 2);
ASSERT_EQ(Precision(Precision::I4).size(), 1);
ASSERT_EQ(Precision(Precision::I8).size(), 1);
ASSERT_EQ(Precision(Precision::Q78).size(), 2);
ASSERT_EQ(Precision(Precision::U8).size(), 1);
ASSERT_EQ(Precision(Precision::U4).size(), 1);
ASSERT_EQ(Precision(10 * 8).size(), 10);
ASSERT_ANY_THROW(Precision(Precision::MIXED).size());
ASSERT_ANY_THROW(Precision(Precision::UNSPECIFIED).size());
@@ -60,7 +64,9 @@ TEST_F(PrecisionTests, is_float) {
ASSERT_FALSE(Precision(Precision::I16).is_float());
ASSERT_FALSE(Precision(Precision::U16).is_float());
ASSERT_FALSE(Precision(Precision::I8).is_float());
ASSERT_FALSE(Precision(Precision::I4).is_float());
ASSERT_FALSE(Precision(Precision::Q78).is_float());
ASSERT_FALSE(Precision(Precision::U4).is_float());
ASSERT_FALSE(Precision(Precision::U8).is_float());
ASSERT_FALSE(Precision(Precision::MIXED).is_float());
ASSERT_FALSE(Precision(10).is_float());
@@ -78,8 +84,10 @@ TEST_F(PrecisionTests, constructFromSTR) {
ASSERT_EQ(Precision(Precision::U32), Precision::FromStr("U32"));
ASSERT_EQ(Precision(Precision::I16), Precision::FromStr("I16"));
ASSERT_EQ(Precision(Precision::U16), Precision::FromStr("U16"));
ASSERT_EQ(Precision(Precision::I4), Precision::FromStr("I4"));
ASSERT_EQ(Precision(Precision::I8), Precision::FromStr("I8"));
ASSERT_EQ(Precision(Precision::Q78), Precision::FromStr("Q78"));
ASSERT_EQ(Precision(Precision::U4), Precision::FromStr("U4"));
ASSERT_EQ(Precision(Precision::U8), Precision::FromStr("U8"));
ASSERT_EQ(Precision(Precision::MIXED), Precision::FromStr("MIXED"));
ASSERT_EQ(Precision(static_cast<Precision::ePrecision >(-3)), Precision::FromStr("UNSPECIFIED"));

View File

@@ -567,7 +567,7 @@ void constant_convert_test(element::Type_t type_from, element::Type_t type_to, F
std::shared_ptr<ngraph::Function> f(nullptr);
std::string expected_friendly_name;
{
auto c = opset4::Constant::create(type_from, Shape{}, {value});
auto c = std::make_shared<opset4::Constant>(type_from, Shape{}, &value);
expected_friendly_name = c->get_friendly_name();
f = std::make_shared<Function>(NodeVector{c}, ParameterVector{});
@@ -608,7 +608,7 @@ TEST(TransformationTests, ConvertPrecision_ConstantConversion_U64MaxToI32) {
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U64ToI32) {
constant_convert_test(element::Type_t::u64, element::Type_t::i32, 42, 42);
constant_convert_test<uint64_t, int32_t>(element::Type_t::u64, element::Type_t::i32, 42, 42);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U32MinToI32) {
@@ -630,3 +630,152 @@ TEST(TransformationTests, ConvertPrecision_ConstantConversion_BoolToU8) {
constant_convert_test(element::Type_t::boolean, element::Type_t::u8, true, 1);
constant_convert_test(element::Type_t::boolean, element::Type_t::u8, false, 0);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToI8) {
constant_convert_test(element::Type_t::u4, element::Type_t::i8, 171, 10);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToU8) {
constant_convert_test(element::Type_t::u4, element::Type_t::u8, 171, 10);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToI8_2) {
constant_convert_test(element::Type_t::u4, element::Type_t::i8, 96, 6);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToU8_96) {
constant_convert_test(element::Type_t::u4, element::Type_t::u8, 96, 6);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU8) {
constant_convert_test(element::Type_t::i4, element::Type_t::u8, 96, 6);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI8) {
constant_convert_test(element::Type_t::i4, element::Type_t::i8, 96, 6);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU8_neg) {
constant_convert_test(element::Type_t::i4, element::Type_t::u8, 171, 242);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI8_neg) {
constant_convert_test(element::Type_t::i4, element::Type_t::i8, 171, -14);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToI32) {
constant_convert_test(element::Type_t::u4, element::Type_t::i32, 171, 10);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToU32) {
constant_convert_test(element::Type_t::u4, element::Type_t::u32, 171, 10);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU32) {
constant_convert_test(element::Type_t::i4, element::Type_t::u32, 96, 6);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI32) {
constant_convert_test(element::Type_t::i4, element::Type_t::i32, 96, 6);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU32_neg) {
constant_convert_test(element::Type_t::i4, element::Type_t::u32, 171, -14);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI32_neg) {
constant_convert_test(element::Type_t::i4, element::Type_t::i32, 171, -14);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToI16) {
constant_convert_test(element::Type_t::u4, element::Type_t::i16, 171, 10);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToU16) {
constant_convert_test(element::Type_t::u4, element::Type_t::u16, 171, 10);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU16) {
constant_convert_test(element::Type_t::i4, element::Type_t::u16, 96, 6);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI16) {
constant_convert_test(element::Type_t::i4, element::Type_t::i16, 96, 6);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU16_neg) {
constant_convert_test(element::Type_t::i4, element::Type_t::u16, 171, 65522);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI16_neg) {
constant_convert_test(element::Type_t::i4, element::Type_t::i16, 171, -14);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToI64) {
constant_convert_test(element::Type_t::u4, element::Type_t::i64, 171, 10);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U4ToU64) {
constant_convert_test(element::Type_t::u4, element::Type_t::u64, 171, 10);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU64) {
constant_convert_test(element::Type_t::i4, element::Type_t::u64, 96, 6);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI64) {
constant_convert_test(element::Type_t::i4, element::Type_t::i64, 96, 6);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToU64_neg) {
constant_convert_test(element::Type_t::i4, element::Type_t::u64, 171, -14);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I4ToI64_neg) {
constant_convert_test(element::Type_t::i4, element::Type_t::i64, 171, -14);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U1ToU8) {
std::shared_ptr<ngraph::Function> f(nullptr);
uint8_t value = 171;
std::string expected_friendly_name;
{
auto c = std::make_shared<opset4::Constant>(element::Type_t::u1, Shape{2}, &value);
expected_friendly_name = c->get_friendly_name();
f = std::make_shared<Function>(NodeVector{c}, ParameterVector{});
pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(element::Type_t::u1, element::Type_t::u8);
manager.run_passes(f);
}
auto ops = f->get_ordered_ops();
auto c = std::dynamic_pointer_cast<opset4::Constant>(ops[0]);
ASSERT_NE(c, nullptr);
ASSERT_EQ(c->get_friendly_name(), expected_friendly_name);
auto* actual = reinterpret_cast<const uint8_t*>(c->get_data_ptr());
ASSERT_EQ(1, actual[0]);
ASSERT_EQ(0, actual[1]);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U1ToU4) {
std::shared_ptr<ngraph::Function> f(nullptr);
uint8_t value = 171;
std::string expected_friendly_name;
{
auto c = std::make_shared<opset4::Constant>(element::Type_t::u1, Shape{2}, &value);
expected_friendly_name = c->get_friendly_name();
f = std::make_shared<Function>(NodeVector{c}, ParameterVector{});
pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(element::Type_t::u1, element::Type_t::u4);
manager.run_passes(f);
}
auto ops = f->get_ordered_ops();
auto c = std::dynamic_pointer_cast<opset4::Constant>(ops[0]);
ASSERT_NE(c, nullptr);
ASSERT_EQ(c->get_friendly_name(), expected_friendly_name);
auto actual = reinterpret_cast<const uint8_t*>(c->get_data_ptr())[0];
ASSERT_EQ(16, actual);
}

View File

@@ -97,6 +97,10 @@ namespace ngraph
throw ngraph_error("make_constant: Unsupported element type 'boolean'");
case element::Type_t::u1:
throw ngraph_error("make_constant: Unsupported element type 'u1'");
case element::Type_t::i4:
throw ngraph_error("make_constant: Unsupported element type 'i4'");
case element::Type_t::u4:
throw ngraph_error("make_constant: Unsupported element type 'u4'");
case element::Type_t::undefined:
throw ngraph_error("make_constant: Unsupported element type 'undefined'");
}

View File

@@ -208,8 +208,10 @@ namespace ngraph
typename element_type_traits<element::Type_t::u64>::value_type>(
value));
break;
case element::Type_t::u1: throw std::runtime_error("unsupported type");
case element::Type_t::undefined: throw std::runtime_error("unsupported type");
case element::Type_t::i4:
case element::Type_t::u1:
case element::Type_t::u4:
case element::Type_t::undefined:
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
}
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
@@ -586,8 +588,10 @@ namespace ngraph
case element::Type_t::u64:
write_buffer<uint64_t, T>(target, source, target_element_count);
break;
case element::Type_t::u1: throw std::runtime_error("unsupported type");
case element::Type_t::undefined: throw std::runtime_error("unsupported type");
case element::Type_t::i4:
case element::Type_t::u1:
case element::Type_t::u4:
case element::Type_t::undefined:
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
}
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)

View File

@@ -46,11 +46,13 @@ namespace ngraph
f16,
f32,
f64,
i4,
i8,
i16,
i32,
i64,
u1,
u4,
u8,
u16,
u32,
@@ -134,11 +136,13 @@ namespace ngraph
constexpr Type f16(Type_t::f16);
constexpr Type f32(Type_t::f32);
constexpr Type f64(Type_t::f64);
constexpr Type i4(Type_t::i4);
constexpr Type i8(Type_t::i8);
constexpr Type i16(Type_t::i16);
constexpr Type i32(Type_t::i32);
constexpr Type i64(Type_t::i64);
constexpr Type u1(Type_t::u1);
constexpr Type u4(Type_t::u4);
constexpr Type u8(Type_t::u8);
constexpr Type u16(Type_t::u16);
constexpr Type u32(Type_t::u32);

View File

@@ -55,6 +55,12 @@ namespace ngraph
using value_type = double;
};
template <>
struct element_type_traits<element::Type_t::i4>
{
using value_type = int8_t;
};
template <>
struct element_type_traits<element::Type_t::i8>
{
@@ -85,6 +91,12 @@ namespace ngraph
using value_type = int8_t;
};
template <>
struct element_type_traits<element::Type_t::u4>
{
using value_type = int8_t;
};
template <>
struct element_type_traits<element::Type_t::u8>
{

View File

@@ -179,6 +179,14 @@ op::Constant::Constant(const element::Type& type,
{
throw std::runtime_error("deserialize unsupported type dynamic");
}
case element::Type_t::i4:
{
throw std::runtime_error("deserialize unsupported type i4");
}
case element::Type_t::u4:
{
throw std::runtime_error("deserialize unsupported type u4");
}
case element::Type_t::u1:
{
throw std::runtime_error("deserialize unsupported type u1");
@@ -291,6 +299,8 @@ op::Constant::Constant(const element::Type& type,
throw std::runtime_error("deserialize unsupported type undefined");
case element::Type_t::dynamic:
throw std::runtime_error("deserialize unsupported type dynamic");
case element::Type_t::i4: throw std::runtime_error("deserialize unsupported type i4");
case element::Type_t::u4: throw std::runtime_error("deserialize unsupported type u4");
case element::Type_t::u1: throw std::runtime_error("deserialize unsupported type u1");
}
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
@@ -351,6 +361,21 @@ string op::Constant::convert_value_to_string(size_t index) const
break;
case element::Type_t::f32: rc = to_cpp_string(get_data_ptr<float>()[index]); break;
case element::Type_t::f64: rc = to_cpp_string(get_data_ptr<double>()[index]); break;
case element::Type_t::i4:
{
uint8_t i4data = (get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 0 : 4)) & 0x0F;
int8_t data = i4data;
if ((i4data >> 3) & 0b1)
{
// negative number
data = (i4data & 0x7) | 0xF0;
}
rc = to_string(data);
break;
}
case element::Type_t::u4:
rc = to_string((get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 0 : 4)) & 0x0F);
break;
case element::Type_t::i8: rc = to_string(get_data_ptr<int8_t>()[index]); break;
case element::Type_t::i16: rc = to_string(get_data_ptr<int16_t>()[index]); break;
case element::Type_t::i32: rc = to_string(get_data_ptr<int32_t>()[index]); break;
@@ -460,8 +485,10 @@ vector<string> op::Constant::get_value_strings() const
rc.push_back(to_string(value));
}
break;
case element::Type_t::u1: throw runtime_error("unsupported type");
case element::Type_t::undefined: throw runtime_error("unsupported type");
case element::Type_t::u1:
case element::Type_t::i4:
case element::Type_t::u4:
case element::Type_t::undefined:
case element::Type_t::dynamic: throw runtime_error("unsupported type");
}
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
@@ -615,7 +642,9 @@ bool op::Constant::are_all_data_elements_bitwise_identical() const
rc = test_bitwise_identical<uint64_t>(this);
break;
}
case element::Type_t::i4:
case element::Type_t::u1:
case element::Type_t::u4:
case element::Type_t::undefined:
case element::Type_t::dynamic: break;
}

View File

@@ -476,6 +476,8 @@ void op::v0::Range::validate_and_infer_types()
case element::Type_t::u64: result_shape = infer_output_shape<uint64_t>(this, result_et); break;
case element::Type_t::dynamic: result_shape = PartialShape::dynamic(1); break;
case element::Type_t::u1:
case element::Type_t::i4:
case element::Type_t::u4:
case element::Type_t::undefined:
case element::Type_t::boolean:
NODE_VALIDATION_CHECK(

View File

@@ -389,6 +389,8 @@ std::string pass::VisualizeTree::get_constant_value(std::shared_ptr<Node> node,
case element::Type_t::undefined: ss << "[ undefined value ]"; break;
case element::Type_t::dynamic: ss << "[ dynamic value ]"; break;
case element::Type_t::u1: ss << "[ u1 value ]"; break;
case element::Type_t::u4: ss << "[ u4 value ]"; break;
case element::Type_t::i4: ss << "[ i4 value ]"; break;
case element::Type_t::bf16:
case element::Type_t::f16:
case element::Type_t::f32:

View File

@@ -72,11 +72,13 @@ static const element_types_map_t& get_type_info_map()
{element::Type_t::f16, TypeInfo(16, true, true, false, "float16", "f16")},
{element::Type_t::f32, TypeInfo(32, true, true, false, "float", "f32")},
{element::Type_t::f64, TypeInfo(64, true, true, false, "double", "f64")},
{element::Type_t::i4, TypeInfo(4, false, true, true, "int4_t", "i4")},
{element::Type_t::i8, TypeInfo(8, false, true, true, "int8_t", "i8")},
{element::Type_t::i16, TypeInfo(16, false, true, false, "int16_t", "i16")},
{element::Type_t::i32, TypeInfo(32, false, true, true, "int32_t", "i32")},
{element::Type_t::i64, TypeInfo(64, false, true, false, "int64_t", "i64")},
{element::Type_t::u1, TypeInfo(1, false, false, false, "uint1_t", "u1")},
{element::Type_t::u4, TypeInfo(4, false, false, false, "uint4_t", "u4")},
{element::Type_t::u8, TypeInfo(8, false, false, true, "uint8_t", "u8")},
{element::Type_t::u16, TypeInfo(16, false, false, false, "uint16_t", "u16")},
{element::Type_t::u32, TypeInfo(32, false, false, false, "uint32_t", "u32")},
@@ -93,11 +95,13 @@ std::vector<const element::Type*> element::Type::get_known_types()
&element::f16,
&element::f32,
&element::f64,
&element::i4,
&element::i8,
&element::i16,
&element::i32,
&element::i64,
&element::u1,
&element::u4,
&element::u8,
&element::u16,
&element::u32,
@@ -294,11 +298,13 @@ size_t ngraph::compiler_byte_size(element::Type_t et)
ET_CASE(f16);
ET_CASE(f32);
ET_CASE(f64);
ET_CASE(i4);
ET_CASE(i8);
ET_CASE(i16);
ET_CASE(i32);
ET_CASE(i64);
ET_CASE(u1);
ET_CASE(u4);
ET_CASE(u8);
ET_CASE(u16);
ET_CASE(u32);
@@ -326,11 +332,13 @@ namespace ngraph
{"f16", element::Type_t::f16},
{"f32", element::Type_t::f32},
{"f64", element::Type_t::f64},
{"i4", element::Type_t::i4},
{"i8", element::Type_t::i8},
{"i16", element::Type_t::i16},
{"i32", element::Type_t::i32},
{"i64", element::Type_t::i64},
{"u1", element::Type_t::u1},
{"u4", element::Type_t::u4},
{"u8", element::Type_t::u8},
{"u16", element::Type_t::u16},
{"u32", element::Type_t::u32},

View File

@@ -64,6 +64,7 @@ set(SRC
graph_rewrite.cpp
includes.cpp
input_output_assign.cpp
int4.cpp
intervals.cpp
main.cpp
matcher_pass.cpp
@@ -211,6 +212,7 @@ set(SRC
type_prop/unsqueeze.cpp
type_prop/variadic_split.cpp
type_prop_layers.cpp
uint4.cpp
util.cpp
)

35
ngraph/test/int4.cpp Normal file
View File

@@ -0,0 +1,35 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/all_close_f.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(int4, convert_i4_to_string)
{
vector<uint8_t> values{171, 16};
auto constant = make_shared<op::Constant>(element::i4, Shape{3}, &values[0]);
vector<string> ref{"-14", "-13", "1"};
for (size_t i = 0; i < 3; ++i)
{
ASSERT_EQ(constant->convert_value_to_string(i), ref[i]);
}
}

View File

@@ -124,6 +124,8 @@ void pass::DynElimination::construct_range()
case element::Type_t::u64:
replacement = make_range_replacement<uint64_t>(et, shape, start_arg, step_arg);
break;
case element::Type_t::i4:
case element::Type_t::u4:
case element::Type_t::u1:
case element::Type_t::undefined:
case element::Type_t::dynamic:

35
ngraph/test/uint4.cpp Normal file
View File

@@ -0,0 +1,35 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/all_close_f.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(uint4, convert_u4_to_string)
{
vector<uint8_t> values{171, 16};
auto constant = make_shared<op::Constant>(element::u4, Shape{3}, &values[0]);
vector<string> ref{"10", "11", "1"};
for (size_t i = 0; i < 3; ++i)
{
ASSERT_EQ(constant->convert_value_to_string(i), ref[i]);
}
}

View File

@@ -149,7 +149,9 @@ namespace
case element::Type_t::u32: return InferenceEngine::Precision::U32; break;
case element::Type_t::u64: return InferenceEngine::Precision::U64; break;
case element::Type_t::u1: return InferenceEngine::Precision::BIN; break;
case element::Type_t::undefined: throw std::runtime_error("unsupported type");
case element::Type_t::i4:
case element::Type_t::u4:
case element::Type_t::undefined:
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
}
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)