From acb14d5d6b10ec78e1c4d13fd1c6c40f5ae7fd18 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Fri, 14 Jul 2023 08:29:55 +0200 Subject: [PATCH] [PT FE] Support bfloat16 constants (#18534) * [PT FE] Support bfloat16 constants * Update src/bindings/python/src/openvino/frontend/pytorch/decoder.py * Add tests for tracing --- .../src/openvino/frontend/pytorch/decoder.py | 22 +++++--- .../py_frontend_tests/test_torch_decoder.py | 53 +++++++++++++++--- tests/layer_tests/pytorch_tests/test_fp16.py | 56 +++++++++++++++++++ 3 files changed, 114 insertions(+), 17 deletions(-) create mode 100644 tests/layer_tests/pytorch_tests/test_fp16.py diff --git a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py index 0e3cc5f511e..04d95ed55c7 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py @@ -6,7 +6,7 @@ from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType -from openvino.runtime import op, PartialShape, Type as OVType, OVAny, Shape +from openvino.runtime import op, PartialShape, Type as OVType, OVAny, Shape, Tensor import typing from packaging.version import parse @@ -39,12 +39,17 @@ def ivalue_to_constant(ivalue): return op.Constant(ov_type, Shape([len(ivalue)]), ivalue).outputs() if isinstance(ivalue, torch.Tensor): - if ivalue.dim() == 0: - assert str(ivalue.dtype) in pt_to_ov_type_map, f"Type is not known {ivalue.dtype}" - ov_type = pt_to_ov_type_map[str(ivalue.dtype)] - ov_const = op.Constant(ov_type, Shape([]), [ivalue.item()]) + ivalue = ivalue.to(memory_format=torch.contiguous_format) + if ivalue.dtype == torch.bfloat16: + # reinterpret bfloat16 data as float16 to allow conversion to numpy + ivalue = ivalue.view(torch.float16) + narr = ivalue.numpy(force=True) + if not narr.flags['C_CONTIGUOUS']: + narr = np.ascontiguousarray(narr) + # TODO: this tensor doesn't share memory with initial tensor + tensor = Tensor(narr, ivalue.shape, OVType.bf16) + ov_const = op.Constant(tensor, shared_memory=True) else: - ivalue = ivalue.to(memory_format=torch.contiguous_format) narr = ivalue.numpy(force=True) if not narr.flags['C_CONTIGUOUS']: narr = np.ascontiguousarray(narr) @@ -76,6 +81,7 @@ pt_to_ov_type_map = { "float": OVType.f32, "int": OVType.i32, "bool": OVType.boolean, + "torch.bfloat16": OVType.bf16, "torch.float16": OVType.f16, "torch.float32": OVType.f32, "torch.float64": OVType.f64, @@ -153,7 +159,7 @@ class TorchScriptPythonDecoder (Decoder): inputs = ordered_inputs if isinstance(inputs, torch.Tensor): inputs = [inputs] - + return {"example_inputs": inputs}, input_signature if isinstance(pt_module, torch.nn.Module): @@ -184,7 +190,7 @@ class TorchScriptPythonDecoder (Decoder): f_model = scripted else: f_model = pt_module - + self._input_signature = input_signature return f_model diff --git a/tests/layer_tests/py_frontend_tests/test_torch_decoder.py b/tests/layer_tests/py_frontend_tests/test_torch_decoder.py index 31bb0108453..d248bb55e86 100644 --- a/tests/layer_tests/py_frontend_tests/test_torch_decoder.py +++ b/tests/layer_tests/py_frontend_tests/test_torch_decoder.py @@ -24,6 +24,7 @@ def get_scripted_model(model): print(model.inlined_graph) # will help debugging return model + def get_traced_model(model, inputs=[], frozen=True): with torch.no_grad(): model = torch.jit.trace(model, example_inputs=inputs) @@ -116,6 +117,28 @@ def test_pytorch_decoder_can_convert_fp16_tensor(): assert ov_const[0].get_partial_shape() == PartialShape([2]) +@pytest.mark.precommit +def test_pytorch_decoder_can_convert_bf16_tensor(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + from openvino.runtime import PartialShape, Type + + class SomeTensor(torch.nn.Module): + def forward(self): + return torch.tensor([1, 2], dtype=torch.bfloat16) + + model = get_scripted_model(SomeTensor()) + consts = [n for n in model.inlined_graph.nodes() if n.kind() == + "prim::Constant"] + assert len(consts) > 0 + some_const = consts[0] + nc_decoder = TorchScriptPythonDecoder(model, some_const) + ov_const = nc_decoder.as_constant() + assert ov_const is not None + assert len(ov_const) == 1 + assert ov_const[0].get_element_type() == Type.bf16 + assert ov_const[0].get_partial_shape() == PartialShape([2]) + + @pytest.mark.precommit def test_pytorch_decoder_can_convert_fp32_tensor(): from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder @@ -453,6 +476,7 @@ def test_pytorch_decoder_can_convert_empty_list(): assert ov_const[0].get_element_type() == Type.i32 assert ov_const[0].get_partial_shape() == PartialShape([0]) + @pytest.mark.precommit def test_pytorch_decoder_can_convert_int_scalar_tensor(): from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder @@ -483,6 +507,7 @@ def test_pytorch_decoder_can_convert_int_scalar_tensor(): assert ov_const[0].get_element_type() == Type.i32 assert ov_const[0].get_partial_shape() == PartialShape([]) + @pytest.mark.precommit def test_pytorch_decoder_can_convert_float_scalar_tensor(): from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder @@ -498,10 +523,9 @@ def test_pytorch_decoder_can_convert_float_scalar_tensor(): # would create nore with output being Tensor with IValue of type float. return torch.add(torch.tensor([1.], dtype=torch.float), self.value + 1) - model = get_traced_model(SomeTensor(), frozen=False) consts = [n for n in model.inlined_graph.nodes() if n.kind() == - "prim::Constant"] + "prim::Constant"] assert len(consts) > 0 some_const = consts[6] node_output = list(some_const.outputs())[0] @@ -514,6 +538,7 @@ def test_pytorch_decoder_can_convert_float_scalar_tensor(): assert ov_const[0].get_element_type() == Type.f32 assert ov_const[0].get_partial_shape() == PartialShape([]) + @pytest.mark.precommit def test_pytorch_decoder_can_convert_tensor_list(): from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder @@ -522,7 +547,8 @@ def test_pytorch_decoder_can_convert_tensor_list(): class SomeTensor(torch.nn.Module): def forward(self): - l = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones((1, 3, 3), dtype=torch.float),]) + l = torch.jit.annotate(List[Optional[torch.Tensor]], [ + torch.ones((1, 3, 3), dtype=torch.float),]) return l model = get_scripted_model(SomeTensor()) @@ -531,10 +557,12 @@ def test_pytorch_decoder_can_convert_tensor_list(): nc_decoder = TorchScriptPythonDecoder(model) graph = nc_decoder.graph_element converted_const_nodes = list(graph.findAllNodes("prim::Constant")) - converted_listconstruct_nodes = list(graph.findAllNodes("prim::ListConstruct")) + converted_listconstruct_nodes = list( + graph.findAllNodes("prim::ListConstruct")) # # Assert that replaced const exist and is not used assert len(converted_const_nodes) == 2 - assert len([node for node in converted_const_nodes if not node.hasUses()]) == 1 + assert len( + [node for node in converted_const_nodes if not node.hasUses()]) == 1 # Assert that prim::ListConstruct exist and has uses assert len(converted_listconstruct_nodes) == 1 assert converted_listconstruct_nodes[0].kind() == "prim::ListConstruct" @@ -542,9 +570,12 @@ def test_pytorch_decoder_can_convert_tensor_list(): assert len(list(converted_listconstruct_nodes[0].inputs())) == 1 created_const = converted_listconstruct_nodes[0].input().node() assert created_const in converted_const_nodes - created_const_decoder = TorchScriptPythonDecoder(model, created_const).as_constant() + created_const_decoder = TorchScriptPythonDecoder( + model, created_const).as_constant() assert created_const_decoder[0].get_element_type() == Type.f32 - assert created_const_decoder[0].get_partial_shape() == PartialShape([1, 3, 3]) + assert created_const_decoder[0].get_partial_shape() == PartialShape([ + 1, 3, 3]) + @pytest.mark.precommit def test_pytorch_decoder_can_convert_tensor_list_empty(): @@ -562,7 +593,8 @@ def test_pytorch_decoder_can_convert_tensor_list_empty(): nc_decoder = TorchScriptPythonDecoder(model) graph = nc_decoder.graph_element converted_const_nodes = list(graph.findAllNodes("prim::Constant")) - converted_listconstruct_nodes = list(graph.findAllNodes("prim::ListConstruct")) + converted_listconstruct_nodes = list( + graph.findAllNodes("prim::ListConstruct")) # Assert that replaced const exist and is not used assert len(converted_const_nodes) == 1 assert not converted_const_nodes[0].hasUses() @@ -572,10 +604,12 @@ def test_pytorch_decoder_can_convert_tensor_list_empty(): assert converted_listconstruct_nodes[0].hasUses() assert len(list(converted_listconstruct_nodes[0].inputs())) == 0 + @pytest.mark.precommit def test_pytorch_decoder_can_convert_optional_tensor_none(): from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder from typing import Optional + class SomeTensor(torch.nn.Module): def forward(self): l = torch.jit.annotate(Optional[torch.Tensor], None) @@ -587,7 +621,8 @@ def test_pytorch_decoder_can_convert_optional_tensor_none(): nc_decoder = TorchScriptPythonDecoder(model) graph = nc_decoder.graph_element converted_const_nodes = list(graph.findAllNodes("prim::Constant")) - removed_consts = [node for node in converted_const_nodes if not node.hasUses()] + removed_consts = [ + node for node in converted_const_nodes if not node.hasUses()] created_consts = [node for node in converted_const_nodes if node.hasUses()] assert len(removed_consts) == len(created_consts) == 1 # Assert that unused const has torch.OptionalType dtype diff --git a/tests/layer_tests/pytorch_tests/test_fp16.py b/tests/layer_tests/pytorch_tests/test_fp16.py new file mode 100644 index 00000000000..b7543067279 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_fp16.py @@ -0,0 +1,56 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import torch + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestBF16(PytorchLayerTest): + + def _prepare_input(self): + return (np.random.randn(10).astype(np.float32),) + + def create_model(self): + class aten_add(torch.nn.Module): + def __init__(self): + super(aten_add, self).__init__() + self.y = torch.randn(10, dtype=torch.bfloat16) + + def forward(self, x): + return x + self.y.to(torch.float32) + + return aten_add(), None, "aten::add" + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("to_trace", [True, False]) + def test_bf16(self, ie_device, precision, ir_version, to_trace): + self._test(*self.create_model(), ie_device, precision, + ir_version, trace_model=to_trace, freeze_model=False) + + +class TestFP16(PytorchLayerTest): + + def _prepare_input(self): + return (np.random.randn(10).astype(np.float32),) + + def create_model(self): + class aten_add(torch.nn.Module): + def __init__(self): + super(aten_add, self).__init__() + self.y = torch.randn(10, dtype=torch.float16) + + def forward(self, x): + return x + self.y.to(torch.float32) + + return aten_add(), None, "aten::add" + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("to_trace", [True, False]) + def test_fp16(self, ie_device, precision, ir_version, to_trace): + self._test(*self.create_model(), ie_device, precision, + ir_version, trace_model=to_trace, freeze_model=False)