[ONNX] Add support for BF16 datatype (#5194)

commit 684dcf0d92
parent 032ed451fd
@@ -29,6 +29,7 @@ namespace ngraph
         case ONNX_NAMESPACE::TensorProto_DataType_UINT16: return sizeof(uint16_t);
         case ONNX_NAMESPACE::TensorProto_DataType_UINT32: return sizeof(uint32_t);
         case ONNX_NAMESPACE::TensorProto_DataType_UINT64: return sizeof(uint64_t);
+        case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: return sizeof(uint16_t);
         }
 #ifdef NGRAPH_USE_PROTOBUF_LITE
         throw ngraph_error("unsupported element type");
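Why sizeof(uint16_t) for BFLOAT16: bfloat16 is a 16-bit format (1 sign bit, 8 exponent bits, 7 mantissa bits) that is exactly the upper half of an IEEE-754 float32, so each raw tensor element occupies two bytes. A minimal numpy sketch of that relationship (numpy has no native bfloat16 dtype, so uint16 stands in for the raw bit pattern; the names are illustrative):

    import numpy as np

    f32 = np.array([1.5], dtype=np.float32)
    # bf16 keeps only the upper 16 bits of the float32 bit pattern
    bf16_bits = (f32.view(np.uint32) >> 16).astype(np.uint16)
    assert bf16_bits.itemsize == 2  # two bytes per element, matching sizeof(uint16_t)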
@@ -31,6 +31,7 @@ namespace ngraph
         case ONNX_NAMESPACE::TensorProto_DataType_UINT32: return element::u32;
         case ONNX_NAMESPACE::TensorProto_DataType_UINT64: return element::u64;
         case ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED: return element::dynamic;
+        case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: return element::bf16;
         }
 #ifdef NGRAPH_USE_PROTOBUF_LITE
         throw ngraph_error("unsupported element type");
@@ -34,11 +34,11 @@ ngraph_to_numpy_types_map = [
     (NgraphType.u16, np.uint16),
     (NgraphType.u32, np.uint32),
     (NgraphType.u64, np.uint64),
+    (NgraphType.bf16, np.uint16),
 ]

 ngraph_to_numpy_types_str_map = [
     ("boolean", np.bool),
-    # ('bf16', ???),
     ("f16", np.float16),
     ("f32", np.float32),
     ("f64", np.float64),
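The new entry pairs NgraphType.bf16 with np.uint16 because numpy has no native bfloat16 dtype; the uint16 array simply carries the raw bf16 bit patterns (which is also why the `# ('bf16', ???)` placeholder is dropped from the string map). A sketch of how such raw bits can be widened back to float32 for inspection; bf16_to_float32 is a hypothetical helper, not part of this change:

    import numpy as np

    def bf16_to_float32(bits):
        # widen the 16-bit patterns into the upper half of a float32
        return (bits.astype(np.uint32) << 16).view(np.float32)

    bits = np.array([0x3FC0], dtype=np.uint16)  # bf16 bit pattern for 1.5
    assert bf16_to_float32(bits)[0] == np.float32(1.5)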
@@ -28,6 +28,7 @@ void regclass_pyngraph_Type(py::module m)
     type.attr("u16") = ngraph::element::u16;
     type.attr("u32") = ngraph::element::u32;
     type.attr("u64") = ngraph::element::u64;
+    type.attr("bf16") = ngraph::element::bf16;

     type.def("__repr__", [](const ngraph::element::Type& self) {
         std::string bitwidth = std::to_string(self.bitwidth());
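Once this attribute is registered, the new element type should be reachable from Python like the existing ones; a sketch assuming the usual ngraph.impl import path (the exact repr text is not confirmed by this diff):

    from ngraph.impl import Type  # assumed import path

    print(Type.bf16)  # repr includes the bitwidth, analogous to Type.f32 and friends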
@@ -20,3 +20,4 @@ void regclass_pyngraph_UInt8(py::module m);
 // void regclass_pyngraph_UInt16(py::module m);
 void regclass_pyngraph_UInt32(py::module m);
 void regclass_pyngraph_UInt64(py::module m);
+void regclass_pyngraph_BFloat16(py::module m);
@@ -118,7 +118,6 @@ xfail_issue_44956 = xfail_test(reason="E Unsupported dynamic op: Loop")
 xfail_issue_44957 = xfail_test(reason="E Unsupported dynamic op: NonZero")
 xfail_issue_44958 = xfail_test(reason="E Unsupported dynamic op: Interpolate")
 xfail_issue_44965 = xfail_test(reason="E RuntimeError: value info has no element")
-xfail_issue_44967 = xfail_test(reason="E RuntimeError: unsupported element type: BFLOAT16")
 xfail_issue_44968 = xfail_test(reason="E Unsupported dynamic op: Squeeze")
 xfail_issue_44970 = xfail_test(reason="Assertion error")
 xfail_issue_44976 = xfail_test(reason="E RuntimeError: Quantize layer with name:"
@@ -121,6 +121,17 @@ class Computation(object):
         out_name = self._get_ie_output_blob_name(output_blobs, ng_result)
         return output_blobs[out_name].buffer

+    def convert_buffers(self, source_buffers, target_dtypes):
+        converted_buffers = []
+        for i in range(len(source_buffers)):
+            target_dtype = target_dtypes[i]
+            # custom conversion for bf16
+            if self.results[i].get_output_element_type(0) == Type.bf16:
+                converted_buffers.append((source_buffers[i].view(np.uint32) >> 16).astype(np.uint16))
+            else:
+                converted_buffers.append(source_buffers[i].astype(target_dtype))
+        return converted_buffers
+
     def __call__(self, *input_values: NumericData) -> List[NumericData]:
         """Run computation on input values and return result."""
         # Input validation
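The bf16 branch in convert_buffers assumes the inference engine hands bf16 results back widened to 32 bits per element; viewing the buffer as uint32 and shifting right by 16 keeps the upper half of each element, which is precisely the bf16 bit pattern (plain truncation of the low mantissa bits, not round-to-nearest-even). A self-contained sketch of that trick with stand-in data:

    import numpy as np

    # stand-in for a 32-bit result buffer holding bf16-representable values
    widened = np.array([1.5, -2.0], dtype=np.float32)
    bf16 = (widened.view(np.uint32) >> 16).astype(np.uint16)
    assert list(bf16) == [0x3FC0, 0xC000]  # bf16 bit patterns for 1.5 and -2.0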
@@ -173,6 +184,5 @@ class Computation(object):

         # Since OV overwrite result data type we have to convert results to the original one.
         original_dtypes = [get_dtype(result.get_output_element_type(0)) for result in self.results]
-        converted_buffers = [buffer.astype(original_dtype) for buffer, original_dtype in
-                             zip(result_buffers, original_dtypes)]
+        converted_buffers = self.convert_buffers(result_buffers, original_dtypes)
         return converted_buffers
@@ -48,7 +48,6 @@ from tests import (BACKEND_NAME,
                    xfail_issue_44957,
                    xfail_issue_44958,
                    xfail_issue_44965,
-                   xfail_issue_44967,
                    xfail_issue_44968,
                    xfail_issue_44976,
                    xfail_issue_45180,
@@ -382,9 +381,6 @@ tests_expected_to_fail = [
      "OnnxBackendNodeModelTest.test_loop13_seq_cpu",
      "OnnxBackendNodeModelTest.test_sequence_insert_at_back_cpu",
      "OnnxBackendNodeModelTest.test_sequence_insert_at_front_cpu",),
-    (xfail_issue_44967,
-     "OnnxBackendNodeModelTest.test_cast_BFLOAT16_to_FLOAT_cpu",
-     "OnnxBackendNodeModelTest.test_cast_FLOAT_to_BFLOAT16_cpu",),
     (xfail_issue_44968,
      "OnnxBackendNodeModelTest.test_squeeze_cpu",
      "OnnxBackendNodeModelTest.test_squeeze_negative_axes_cpu",),