[ONNX] Add support for BF16 datatype (#5194)

This commit is contained in:
Tomasz Socha 2021-04-13 11:21:21 +02:00 committed by GitHub
parent 032ed451fd
commit 684dcf0d92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 17 additions and 8 deletions

View File

@@ -29,6 +29,7 @@ namespace ngraph
case ONNX_NAMESPACE::TensorProto_DataType_UINT16: return sizeof(uint16_t);
case ONNX_NAMESPACE::TensorProto_DataType_UINT32: return sizeof(uint32_t);
case ONNX_NAMESPACE::TensorProto_DataType_UINT64: return sizeof(uint64_t);
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: return sizeof(uint16_t);
}
#ifdef NGRAPH_USE_PROTOBUF_LITE
throw ngraph_error("unsupported element type");

View File

@@ -31,6 +31,7 @@ namespace ngraph
case ONNX_NAMESPACE::TensorProto_DataType_UINT32: return element::u32;
case ONNX_NAMESPACE::TensorProto_DataType_UINT64: return element::u64;
case ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED: return element::dynamic;
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: return element::bf16;
}
#ifdef NGRAPH_USE_PROTOBUF_LITE
throw ngraph_error("unsupported element type");

View File

@@ -34,11 +34,11 @@ ngraph_to_numpy_types_map = [
(NgraphType.u16, np.uint16),
(NgraphType.u32, np.uint32),
(NgraphType.u64, np.uint64),
(NgraphType.bf16, np.uint16),
]
ngraph_to_numpy_types_str_map = [
("boolean", np.bool),
# ('bf16', ???),
("f16", np.float16),
("f32", np.float32),
("f64", np.float64),

View File

@@ -28,6 +28,7 @@ void regclass_pyngraph_Type(py::module m)
type.attr("u16") = ngraph::element::u16;
type.attr("u32") = ngraph::element::u32;
type.attr("u64") = ngraph::element::u64;
type.attr("bf16") = ngraph::element::bf16;
type.def("__repr__", [](const ngraph::element::Type& self) {
std::string bitwidth = std::to_string(self.bitwidth());

View File

@@ -20,3 +20,4 @@ void regclass_pyngraph_UInt8(py::module m);
// void regclass_pyngraph_UInt16(py::module m);
void regclass_pyngraph_UInt32(py::module m);
void regclass_pyngraph_UInt64(py::module m);
void regclass_pyngraph_BFloat16(py::module m);

View File

@@ -118,7 +118,6 @@ xfail_issue_44956 = xfail_test(reason="E Unsupported dynamic op: Loop")
xfail_issue_44957 = xfail_test(reason="E Unsupported dynamic op: NonZero")
xfail_issue_44958 = xfail_test(reason="E Unsupported dynamic op: Interpolate")
xfail_issue_44965 = xfail_test(reason="E RuntimeError: value info has no element")
xfail_issue_44967 = xfail_test(reason="E RuntimeError: unsupported element type: BFLOAT16")
xfail_issue_44968 = xfail_test(reason="E Unsupported dynamic op: Squeeze")
xfail_issue_44970 = xfail_test(reason="Assertion error")
xfail_issue_44976 = xfail_test(reason="E RuntimeError: Quantize layer with name:"

View File

@@ -121,6 +121,17 @@ class Computation(object):
out_name = self._get_ie_output_blob_name(output_blobs, ng_result)
return output_blobs[out_name].buffer
def convert_buffers(self, source_buffers, target_dtypes):
    """Convert raw inference result buffers to the requested numpy dtypes.

    :param source_buffers: numpy arrays produced by the inference engine,
                           one per result node in ``self.results``.
    :param target_dtypes: numpy dtypes to convert each buffer to, in the
                          same order as ``source_buffers``.
    :return: list of converted numpy arrays, order preserved.
    """
    converted_buffers = []
    for result, buffer, target_dtype in zip(self.results, source_buffers, target_dtypes):
        if result.get_output_element_type(0) == Type.bf16:
            # bf16 results come back widened to 32-bit storage: reinterpret
            # the buffer as uint32 and keep the high 16 bits, which hold the
            # raw bf16 bit pattern.
            # NOTE(review): assumes a little-endian host — confirm.
            converted_buffers.append((buffer.view(np.uint32) >> 16).astype(np.uint16))
        else:
            converted_buffers.append(buffer.astype(target_dtype))
    return converted_buffers
def __call__(self, *input_values: NumericData) -> List[NumericData]:
"""Run computation on input values and return result."""
# Input validation
@@ -173,6 +184,5 @@ class Computation(object):
# Since OV overwrite result data type we have to convert results to the original one.
original_dtypes = [get_dtype(result.get_output_element_type(0)) for result in self.results]
converted_buffers = [buffer.astype(original_dtype) for buffer, original_dtype in
zip(result_buffers, original_dtypes)]
converted_buffers = self.convert_buffers(result_buffers, original_dtypes)
return converted_buffers

View File

@@ -48,7 +48,6 @@ from tests import (BACKEND_NAME,
xfail_issue_44957,
xfail_issue_44958,
xfail_issue_44965,
xfail_issue_44967,
xfail_issue_44968,
xfail_issue_44976,
xfail_issue_45180,
@@ -382,9 +381,6 @@ tests_expected_to_fail = [
"OnnxBackendNodeModelTest.test_loop13_seq_cpu",
"OnnxBackendNodeModelTest.test_sequence_insert_at_back_cpu",
"OnnxBackendNodeModelTest.test_sequence_insert_at_front_cpu",),
(xfail_issue_44967,
"OnnxBackendNodeModelTest.test_cast_BFLOAT16_to_FLOAT_cpu",
"OnnxBackendNodeModelTest.test_cast_FLOAT_to_BFLOAT16_cpu",),
(xfail_issue_44968,
"OnnxBackendNodeModelTest.test_squeeze_cpu",
"OnnxBackendNodeModelTest.test_squeeze_negative_axes_cpu",),