[ONNX] Add support for BF16 datatype (#5194)
This commit is contained in:
parent
032ed451fd
commit
684dcf0d92
@ -29,6 +29,7 @@ namespace ngraph
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT16: return sizeof(uint16_t);
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT32: return sizeof(uint32_t);
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT64: return sizeof(uint64_t);
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: return sizeof(uint16_t);
|
||||
}
|
||||
#ifdef NGRAPH_USE_PROTOBUF_LITE
|
||||
throw ngraph_error("unsupported element type");
|
||||
|
@ -31,6 +31,7 @@ namespace ngraph
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT32: return element::u32;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UINT64: return element::u64;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED: return element::dynamic;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: return element::bf16;
|
||||
}
|
||||
#ifdef NGRAPH_USE_PROTOBUF_LITE
|
||||
throw ngraph_error("unsupported element type");
|
||||
|
@ -34,11 +34,11 @@ ngraph_to_numpy_types_map = [
|
||||
(NgraphType.u16, np.uint16),
|
||||
(NgraphType.u32, np.uint32),
|
||||
(NgraphType.u64, np.uint64),
|
||||
(NgraphType.bf16, np.uint16),
|
||||
]
|
||||
|
||||
ngraph_to_numpy_types_str_map = [
|
||||
("boolean", np.bool),
|
||||
# ('bf16', ???),
|
||||
("f16", np.float16),
|
||||
("f32", np.float32),
|
||||
("f64", np.float64),
|
||||
|
@ -28,6 +28,7 @@ void regclass_pyngraph_Type(py::module m)
|
||||
type.attr("u16") = ngraph::element::u16;
|
||||
type.attr("u32") = ngraph::element::u32;
|
||||
type.attr("u64") = ngraph::element::u64;
|
||||
type.attr("bf16") = ngraph::element::bf16;
|
||||
|
||||
type.def("__repr__", [](const ngraph::element::Type& self) {
|
||||
std::string bitwidth = std::to_string(self.bitwidth());
|
||||
|
@ -20,3 +20,4 @@ void regclass_pyngraph_UInt8(py::module m);
|
||||
// void regclass_pyngraph_UInt16(py::module m);
|
||||
void regclass_pyngraph_UInt32(py::module m);
|
||||
void regclass_pyngraph_UInt64(py::module m);
|
||||
void regclass_pyngraph_BFloat16(py::module m);
|
||||
|
@ -118,7 +118,6 @@ xfail_issue_44956 = xfail_test(reason="E Unsupported dynamic op: Loop")
|
||||
xfail_issue_44957 = xfail_test(reason="E Unsupported dynamic op: NonZero")
|
||||
xfail_issue_44958 = xfail_test(reason="E Unsupported dynamic op: Interpolate")
|
||||
xfail_issue_44965 = xfail_test(reason="E RuntimeError: value info has no element")
|
||||
xfail_issue_44967 = xfail_test(reason="E RuntimeError: unsupported element type: BFLOAT16")
|
||||
xfail_issue_44968 = xfail_test(reason="E Unsupported dynamic op: Squeeze")
|
||||
xfail_issue_44970 = xfail_test(reason="Assertion error")
|
||||
xfail_issue_44976 = xfail_test(reason="E RuntimeError: Quantize layer with name:"
|
||||
|
@ -121,6 +121,17 @@ class Computation(object):
|
||||
out_name = self._get_ie_output_blob_name(output_blobs, ng_result)
|
||||
return output_blobs[out_name].buffer
|
||||
|
||||
def convert_buffers(self, source_buffers, target_dtypes):
    """Convert each result buffer to its requested dtype.

    For every position in ``source_buffers`` the matching entry of
    ``self.results`` is queried for its output element type. When that
    type equals ``Type.bf16`` the buffer is reinterpreted as ``np.uint32``,
    shifted right by 16 bits and stored as ``np.uint16``; the requested
    target dtype is not used in that case. Every other buffer is converted
    with ``astype`` to the dtype found at the same position in
    ``target_dtypes``.

    :param source_buffers: Sequence of buffers, index-aligned with
                           ``self.results``.
    :param target_dtypes: Sequence of dtypes, index-aligned with
                          ``source_buffers``.
    :return: A new list holding the converted buffers in the same order.
    """
    converted = []
    for index, buffer in enumerate(source_buffers):
        requested_dtype = target_dtypes[index]
        is_bf16 = self.results[index].get_output_element_type(0) == Type.bf16
        if not is_bf16:
            converted.append(buffer.astype(requested_dtype))
            continue
        # Custom conversion for bf16: keep only the upper 16 bits of each
        # 32-bit word.
        # NOTE(review): presumably the buffer holds float32 data here, so the
        # upper half is the bf16 bit pattern - confirm against the backend.
        upper_halves = buffer.view(np.uint32) >> 16
        converted.append(upper_halves.astype(np.uint16))
    return converted
|
||||
|
||||
def __call__(self, *input_values: NumericData) -> List[NumericData]:
|
||||
"""Run computation on input values and return result."""
|
||||
# Input validation
|
||||
@ -173,6 +184,5 @@ class Computation(object):
|
||||
|
||||
# Since OV overwrites the result data type, we have to convert results back to the original one.
|
||||
original_dtypes = [get_dtype(result.get_output_element_type(0)) for result in self.results]
|
||||
converted_buffers = [buffer.astype(original_dtype) for buffer, original_dtype in
|
||||
zip(result_buffers, original_dtypes)]
|
||||
converted_buffers = self.convert_buffers(result_buffers, original_dtypes)
|
||||
return converted_buffers
|
||||
|
@ -48,7 +48,6 @@ from tests import (BACKEND_NAME,
|
||||
xfail_issue_44957,
|
||||
xfail_issue_44958,
|
||||
xfail_issue_44965,
|
||||
xfail_issue_44967,
|
||||
xfail_issue_44968,
|
||||
xfail_issue_44976,
|
||||
xfail_issue_45180,
|
||||
@ -382,9 +381,6 @@ tests_expected_to_fail = [
|
||||
"OnnxBackendNodeModelTest.test_loop13_seq_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sequence_insert_at_back_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sequence_insert_at_front_cpu",),
|
||||
(xfail_issue_44967,
|
||||
"OnnxBackendNodeModelTest.test_cast_BFLOAT16_to_FLOAT_cpu",
|
||||
"OnnxBackendNodeModelTest.test_cast_FLOAT_to_BFLOAT16_cpu",),
|
||||
(xfail_issue_44968,
|
||||
"OnnxBackendNodeModelTest.test_squeeze_cpu",
|
||||
"OnnxBackendNodeModelTest.test_squeeze_negative_axes_cpu",),
|
||||
|
Loading…
Reference in New Issue
Block a user