diff --git a/ngraph/frontend/onnx_common/src/utils.cpp b/ngraph/frontend/onnx_common/src/utils.cpp
index 5c63e0430b1..998f5f4daa8 100644
--- a/ngraph/frontend/onnx_common/src/utils.cpp
+++ b/ngraph/frontend/onnx_common/src/utils.cpp
@@ -29,6 +29,7 @@ namespace ngraph
             case ONNX_NAMESPACE::TensorProto_DataType_UINT16: return sizeof(uint16_t);
             case ONNX_NAMESPACE::TensorProto_DataType_UINT32: return sizeof(uint32_t);
             case ONNX_NAMESPACE::TensorProto_DataType_UINT64: return sizeof(uint64_t);
+            case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: return sizeof(uint16_t);
             }
 #ifdef NGRAPH_USE_PROTOBUF_LITE
             throw ngraph_error("unsupported element type");
diff --git a/ngraph/frontend/onnx_import/src/utils/common.cpp b/ngraph/frontend/onnx_import/src/utils/common.cpp
index fd3d387000d..67431ac8c52 100644
--- a/ngraph/frontend/onnx_import/src/utils/common.cpp
+++ b/ngraph/frontend/onnx_import/src/utils/common.cpp
@@ -31,6 +31,7 @@ namespace ngraph
             case ONNX_NAMESPACE::TensorProto_DataType_UINT32: return element::u32;
             case ONNX_NAMESPACE::TensorProto_DataType_UINT64: return element::u64;
             case ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED: return element::dynamic;
+            case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: return element::bf16;
             }
 #ifdef NGRAPH_USE_PROTOBUF_LITE
             throw ngraph_error("unsupported element type");
diff --git a/ngraph/python/src/ngraph/utils/types.py b/ngraph/python/src/ngraph/utils/types.py
index 4e3b2f63132..b40ec700a53 100644
--- a/ngraph/python/src/ngraph/utils/types.py
+++ b/ngraph/python/src/ngraph/utils/types.py
@@ -34,11 +34,11 @@ ngraph_to_numpy_types_map = [
     (NgraphType.u16, np.uint16),
     (NgraphType.u32, np.uint32),
     (NgraphType.u64, np.uint64),
+    (NgraphType.bf16, np.uint16),
 ]
 
 ngraph_to_numpy_types_str_map = [
     ("boolean", np.bool),
-    # ('bf16', ???),
     ("f16", np.float16),
     ("f32", np.float32),
     ("f64", np.float64),
diff --git a/ngraph/python/src/pyngraph/types/element_type.cpp b/ngraph/python/src/pyngraph/types/element_type.cpp
index 7ae833a4fa5..db1bac50ed0 100644
--- a/ngraph/python/src/pyngraph/types/element_type.cpp
+++ b/ngraph/python/src/pyngraph/types/element_type.cpp
@@ -28,6 +28,7 @@ void regclass_pyngraph_Type(py::module m)
     type.attr("u16") = ngraph::element::u16;
    type.attr("u32") = ngraph::element::u32;
     type.attr("u64") = ngraph::element::u64;
+    type.attr("bf16") = ngraph::element::bf16;
 
     type.def("__repr__", [](const ngraph::element::Type& self) {
         std::string bitwidth = std::to_string(self.bitwidth());
diff --git a/ngraph/python/src/pyngraph/types/element_type.hpp b/ngraph/python/src/pyngraph/types/element_type.hpp
index 94a67165c20..763b08e51c9 100644
--- a/ngraph/python/src/pyngraph/types/element_type.hpp
+++ b/ngraph/python/src/pyngraph/types/element_type.hpp
@@ -20,3 +20,4 @@ void regclass_pyngraph_UInt8(py::module m);
 // void regclass_pyngraph_UInt16(py::module m);
 void regclass_pyngraph_UInt32(py::module m);
 void regclass_pyngraph_UInt64(py::module m);
+void regclass_pyngraph_BFloat16(py::module m);
diff --git a/ngraph/python/tests/__init__.py b/ngraph/python/tests/__init__.py
index 34fc00636b4..65b7040f679 100644
--- a/ngraph/python/tests/__init__.py
+++ b/ngraph/python/tests/__init__.py
@@ -118,7 +118,6 @@ xfail_issue_44956 = xfail_test(reason="E Unsupported dynamic op: Loop")
 xfail_issue_44957 = xfail_test(reason="E Unsupported dynamic op: NonZero")
 xfail_issue_44958 = xfail_test(reason="E Unsupported dynamic op: Interpolate")
 xfail_issue_44965 = xfail_test(reason="E RuntimeError: value info has no element")
-xfail_issue_44967 = xfail_test(reason="E RuntimeError: unsupported element type: BFLOAT16")
 xfail_issue_44968 = xfail_test(reason="E Unsupported dynamic op: Squeeze")
 xfail_issue_44970 = xfail_test(reason="Assertion error")
 xfail_issue_44976 = xfail_test(reason="E RuntimeError: Quantize layer with name:"
diff --git a/ngraph/python/tests/runtime.py b/ngraph/python/tests/runtime.py
index 20035fb66b7..16b79b85b5b 100644
--- a/ngraph/python/tests/runtime.py
+++ b/ngraph/python/tests/runtime.py
@@ -121,6 +121,17 @@ class Computation(object):
         out_name = self._get_ie_output_blob_name(output_blobs, ng_result)
         return output_blobs[out_name].buffer
 
+    def convert_buffers(self, source_buffers, target_dtypes):
+        converted_buffers = []
+        for i in range(len(source_buffers)):
+            target_dtype = target_dtypes[i]
+            # custom conversion for bf16
+            if self.results[i].get_output_element_type(0) == Type.bf16:
+                converted_buffers.append((source_buffers[i].view(np.uint32) >> 16).astype(np.uint16))
+            else:
+                converted_buffers.append(source_buffers[i].astype(target_dtype))
+        return converted_buffers
+
     def __call__(self, *input_values: NumericData) -> List[NumericData]:
         """Run computation on input values and return result."""
         # Input validation
@@ -173,6 +184,5 @@ class Computation(object):
 
         # Since OV overwrite result data type we have to convert results to the original one.
         original_dtypes = [get_dtype(result.get_output_element_type(0)) for result in self.results]
-        converted_buffers = [buffer.astype(original_dtype) for buffer, original_dtype in
-                             zip(result_buffers, original_dtypes)]
+        converted_buffers = self.convert_buffers(result_buffers, original_dtypes)
         return converted_buffers
diff --git a/ngraph/python/tests/test_onnx/test_backend.py b/ngraph/python/tests/test_onnx/test_backend.py
index 3b65634c9b0..aa136fd1525 100644
--- a/ngraph/python/tests/test_onnx/test_backend.py
+++ b/ngraph/python/tests/test_onnx/test_backend.py
@@ -48,7 +48,6 @@ from tests import (BACKEND_NAME,
                    xfail_issue_44957,
                    xfail_issue_44958,
                    xfail_issue_44965,
-                   xfail_issue_44967,
                    xfail_issue_44968,
                    xfail_issue_44976,
                    xfail_issue_45180,
@@ -382,9 +381,6 @@ tests_expected_to_fail = [
     "OnnxBackendNodeModelTest.test_loop13_seq_cpu",
     "OnnxBackendNodeModelTest.test_sequence_insert_at_back_cpu",
     "OnnxBackendNodeModelTest.test_sequence_insert_at_front_cpu",),
-    (xfail_issue_44967,
-     "OnnxBackendNodeModelTest.test_cast_BFLOAT16_to_FLOAT_cpu",
-     "OnnxBackendNodeModelTest.test_cast_FLOAT_to_BFLOAT16_cpu",),
     (xfail_issue_44968,
      "OnnxBackendNodeModelTest.test_squeeze_cpu",
      "OnnxBackendNodeModelTest.test_squeeze_negative_axes_cpu",),