[PYTHON API] Add bin, bool, bf16 precisions (#4560)

Anastasia Kuporosova 2021-03-02 21:32:17 +03:00 committed by GitHub
parent b04a11697e
commit 6b4abc49a4
3 changed files with 39 additions and 7 deletions


@@ -18,7 +18,8 @@ from .cimport ie_api_impl_defs as C
 import numpy as np
 from enum import Enum
-supported_precisions = ["FP32", "FP64", "FP16", "I64", "U64", "I32", "U32", "I16", "I8", "U16", "U8", "BOOL"]
+supported_precisions = ["FP32", "FP64", "FP16", "I64", "U64", "I32", "U32",
+                        "I16", "I8", "U16", "U8", "BOOL", "BIN", "BF16"]
 known_plugins = ['CPU', 'GPU', 'FPGA', 'MYRIAD', 'HETERO', 'HDDL', 'MULTI']
@@ -28,14 +29,18 @@ layout_int_to_str_map = {0: "ANY", 1: "NCHW", 2: "NHWC", 3: "NCDHW", 4: "NDHWC",
 format_map = {
     'FP32' : np.float32,
     'FP64' : np.float64,
-    'I32' : np.int32,
     'FP16' : np.float16,
+    'I64' : np.int64,
+    'U64' : np.uint64,
+    'I32' : np.int32,
+    'U32' : np.uint32,
     'I16' : np.int16,
     'U16' : np.uint16,
     'I8' : np.int8,
     'U8' : np.uint8,
-    'I64' : np.int64,
-    'BOOL' : np.uint8
+    'BOOL' : np.uint8,
+    'BIN' : np.int8,
+    'BF16' : np.float16,
 }
 layout_str_to_enum = {'ANY': C.Layout.ANY,
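Note: numpy has no native bfloat16 or 1-bit type, so the table above reuses np.float16 as the 16-bit host dtype for "BF16" and np.int8 as the byte container for "BIN" (the Blob code below re-views the BF16 buffer as int16 when wrapping it). A minimal sketch of how a caller might pick a host buffer dtype from this mapping; the make_host_buffer helper is made up for illustration:

import numpy as np

# Same mapping as format_map above, trimmed to the entries added here.
precision_to_dtype = {
    'BOOL': np.uint8,    # booleans travel as unsigned bytes
    'BIN': np.int8,      # binary data carried in signed bytes
    'BF16': np.float16,  # 16-bit container; bits are re-viewed as int16 internally
}

def make_host_buffer(precision, shape):
    # Hypothetical helper: allocate a numpy array whose dtype matches
    # what the Python API expects for the given precision string.
    return np.zeros(shape, dtype=precision_to_dtype[precision])

buf = make_host_buffer("BF16", (1, 3, 32, 32))  # dtype float16, itemsize 2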


@@ -23,7 +23,7 @@ from libcpp.pair cimport pair
 from libcpp.map cimport map
 from libcpp.memory cimport unique_ptr
 from libc.stdlib cimport malloc, free
-from libc.stdint cimport int64_t, uint8_t, int8_t, int32_t, uint16_t, int16_t
+from libc.stdint cimport int64_t, uint8_t, int8_t, int32_t, uint16_t, int16_t, uint32_t, uint64_t
 from libc.stddef cimport size_t
 from libc.string cimport memcpy
@@ -134,6 +134,8 @@ cdef class Blob:
         cdef int8_t[::1] I8_array_memview
         cdef int32_t[::1] I32_array_memview
         cdef int64_t[::1] I64_array_memview
+        cdef uint32_t[::1] U32_array_memview
+        cdef uint64_t[::1] U64_array_memview
         cdef int16_t[:] x_as_uint
         cdef int16_t[:] y_as_uint
@@ -153,7 +155,7 @@ cdef class Blob:
                 self._ptr = C.make_shared_blob[float](c_tensor_desc)
             elif precision == "FP64":
                 self._ptr = C.make_shared_blob[double](c_tensor_desc)
-            elif precision == "FP16" or precision == "I16":
+            elif precision == "FP16" or precision == "I16" or precision == "BF16":
                 self._ptr = C.make_shared_blob[int16_t](c_tensor_desc)
             elif precision == "Q78" or precision == "U16":
                 self._ptr = C.make_shared_blob[uint16_t](c_tensor_desc)
@@ -163,8 +165,12 @@ cdef class Blob:
                 self._ptr = C.make_shared_blob[int8_t](c_tensor_desc)
             elif precision == "I32":
                 self._ptr = C.make_shared_blob[int32_t](c_tensor_desc)
+            elif precision == "U32":
+                self._ptr = C.make_shared_blob[uint32_t](c_tensor_desc)
             elif precision == "I64":
                 self._ptr = C.make_shared_blob[int64_t](c_tensor_desc)
+            elif precision == "U64":
+                self._ptr = C.make_shared_blob[uint64_t](c_tensor_desc)
             else:
                 raise AttributeError(f"Unsupported precision {precision} for blob")
             deref(self._ptr).allocate()
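With the U32/U64 branches added (and BF16 sharing the int16 branch), an empty blob can now be allocated directly for the new precisions. A minimal usage sketch, assuming the TensorDesc and Blob classes exported by this package; the shape and layout are arbitrary:

from openvino.inference_engine import TensorDesc, Blob

# No array is passed, so the Blob allocates its own memory through the
# matching make_shared_blob branch instead of raising AttributeError.
tensor_desc = TensorDesc("BF16", [1, 3, 127, 127], "NCHW")
blob = Blob(tensor_desc)
print(blob.tensor_desc.precision)  # "BF16"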
@@ -187,7 +193,7 @@ cdef class Blob:
             elif precision == "FP64":
                 fp64_array_memview = self._array_data
                 self._ptr = C.make_shared_blob[double](c_tensor_desc, &fp64_array_memview[0], fp64_array_memview.shape[0])
-            elif precision == "FP16":
+            elif precision == "FP16" or precision == "BF16":
                 I16_array_memview = self._array_data.view(dtype=np.int16)
                 self._ptr = C.make_shared_blob[int16_t](c_tensor_desc, &I16_array_memview[0], I16_array_memview.shape[0])
             elif precision == "I16":
@@ -205,9 +211,15 @@ cdef class Blob:
             elif precision == "I32":
                 I32_array_memview = self._array_data
                 self._ptr = C.make_shared_blob[int32_t](c_tensor_desc, &I32_array_memview[0], I32_array_memview.shape[0])
+            elif precision == "U32":
+                U32_array_memview = self._array_data
+                self._ptr = C.make_shared_blob[uint32_t](c_tensor_desc, &U32_array_memview[0], U32_array_memview.shape[0])
             elif precision == "I64":
                 I64_array_memview = self._array_data
                 self._ptr = C.make_shared_blob[int64_t](c_tensor_desc, &I64_array_memview[0], I64_array_memview.shape[0])
+            elif precision == "U64":
+                U64_array_memview = self._array_data
+                self._ptr = C.make_shared_blob[uint64_t](c_tensor_desc, &U64_array_memview[0], U64_array_memview.shape[0])
             else:
                 raise AttributeError(f"Unsupported precision {precision} for blob")
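This second code path wraps an existing numpy array rather than allocating new memory. A short sketch of how the new precisions are exercised from user code, under the same TensorDesc/Blob assumption; for BF16 the float16 buffer is re-viewed as int16 before being handed to make_shared_blob, as shown above:

import numpy as np
from openvino.inference_engine import TensorDesc, Blob

shape = [1, 3, 2, 2]
# U32 data is wrapped directly through the new uint32_t memoryview branch.
u32_blob = Blob(TensorDesc("U32", shape, "NCHW"),
                np.ones(shape, dtype=np.uint32))
# BF16 reuses the FP16 branch: the caller supplies a float16 array and the
# blob wraps its 16-bit storage without copying.
bf16_blob = Blob(TensorDesc("BF16", shape, "NCHW"),
                 np.ones(shape, dtype=np.float16))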
@@ -1520,6 +1532,9 @@ cdef class BlobBuffer:
             'U32': 'I', # unsigned int
             'I64': 'q', # signed long int
             'U64': 'Q', # unsigned long int
+            'BOOL': 'B', # unsigned char
+            'BF16': 'h', # signed short
+            'BIN': 'b', # signed char
         }
         if name not in precision_to_format:
             raise ValueError(f"Unknown Blob precision: {name}")
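These single-character codes are standard struct/buffer format strings, so a BF16 blob exposes its buffer as 2-byte signed integers while BIN and BOOL expose 1-byte items. A quick sanity check of the new codes with plain Python, no Inference Engine required:

import struct

assert struct.calcsize('h') == 2  # 'h' = signed short  -> BF16 items
assert struct.calcsize('b') == 1  # 'b' = signed char   -> BIN items
assert struct.calcsize('B') == 1  # 'B' = unsigned char -> BOOL items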


@@ -96,6 +96,18 @@ def test_write_to_buffer_int64():
     write_to_buffer("I64", np.int64)
+def test_write_to_buffer_bool():
+    write_to_buffer("BOOL", np.uint8)
+def test_write_to_buffer_bin():
+    write_to_buffer("BIN", np.int8)
+def test_write_to_buffer_bf16():
+    write_to_buffer("BF16", np.float16)
 def test_write_numpy_scalar_int64():
     tensor_desc = TensorDesc("I64", [], "SCALAR")
     scalar = np.array(0, dtype=np.int64)
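The write_to_buffer helper used by these tests is defined earlier in the test file and is not part of this diff; the sketch below shows the kind of round trip such a helper performs, with the shape and the exact assertion being assumptions rather than the real implementation:

import numpy as np
from openvino.inference_engine import TensorDesc, Blob

def write_to_buffer(precision, numpy_dtype):
    # Assumed shape; the real helper may use a different one.
    tensor_desc = TensorDesc(precision, [1, 3, 127, 127], "NCHW")
    array = np.zeros(shape=(1, 3, 127, 127), dtype=numpy_dtype)
    blob = Blob(tensor_desc, array)
    ones_arr = np.ones(shape=(1, 3, 127, 127), dtype=numpy_dtype)
    blob.buffer[:] = ones_arr  # write through the BlobBuffer view
    assert np.array_equal(blob.buffer, ones_arr)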