[PT FE] Support bfloat16 constants (#18534)

* [PT FE] Support bfloat16 constants

* Update src/bindings/python/src/openvino/frontend/pytorch/decoder.py

* Add tests for tracing
This commit is contained in:
Maxim Vafin 2023-07-14 08:29:55 +02:00 committed by GitHub
parent 3f67b3948d
commit acb14d5d6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 114 additions and 17 deletions

View File

@ -6,7 +6,7 @@
from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder
from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
from openvino.runtime import op, PartialShape, Type as OVType, OVAny, Shape
from openvino.runtime import op, PartialShape, Type as OVType, OVAny, Shape, Tensor
import typing
from packaging.version import parse
@ -39,12 +39,17 @@ def ivalue_to_constant(ivalue):
return op.Constant(ov_type, Shape([len(ivalue)]), ivalue).outputs()
if isinstance(ivalue, torch.Tensor):
if ivalue.dim() == 0:
assert str(ivalue.dtype) in pt_to_ov_type_map, f"Type is not known {ivalue.dtype}"
ov_type = pt_to_ov_type_map[str(ivalue.dtype)]
ov_const = op.Constant(ov_type, Shape([]), [ivalue.item()])
ivalue = ivalue.to(memory_format=torch.contiguous_format)
if ivalue.dtype == torch.bfloat16:
# reinterpret bfloat16 data as float16 to allow conversion to numpy
ivalue = ivalue.view(torch.float16)
narr = ivalue.numpy(force=True)
if not narr.flags['C_CONTIGUOUS']:
narr = np.ascontiguousarray(narr)
# TODO: this tensor doesn't share memory with initial tensor
tensor = Tensor(narr, ivalue.shape, OVType.bf16)
ov_const = op.Constant(tensor, shared_memory=True)
else:
ivalue = ivalue.to(memory_format=torch.contiguous_format)
narr = ivalue.numpy(force=True)
if not narr.flags['C_CONTIGUOUS']:
narr = np.ascontiguousarray(narr)
@ -76,6 +81,7 @@ pt_to_ov_type_map = {
"float": OVType.f32,
"int": OVType.i32,
"bool": OVType.boolean,
"torch.bfloat16": OVType.bf16,
"torch.float16": OVType.f16,
"torch.float32": OVType.f32,
"torch.float64": OVType.f64,
@ -153,7 +159,7 @@ class TorchScriptPythonDecoder (Decoder):
inputs = ordered_inputs
if isinstance(inputs, torch.Tensor):
inputs = [inputs]
return {"example_inputs": inputs}, input_signature
if isinstance(pt_module, torch.nn.Module):
@ -184,7 +190,7 @@ class TorchScriptPythonDecoder (Decoder):
f_model = scripted
else:
f_model = pt_module
self._input_signature = input_signature
return f_model

View File

@ -24,6 +24,7 @@ def get_scripted_model(model):
print(model.inlined_graph) # will help debugging
return model
def get_traced_model(model, inputs=[], frozen=True):
with torch.no_grad():
model = torch.jit.trace(model, example_inputs=inputs)
@ -116,6 +117,28 @@ def test_pytorch_decoder_can_convert_fp16_tensor():
assert ov_const[0].get_partial_shape() == PartialShape([2])
@pytest.mark.precommit
def test_pytorch_decoder_can_convert_bf16_tensor():
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
from openvino.runtime import PartialShape, Type
class SomeTensor(torch.nn.Module):
def forward(self):
return torch.tensor([1, 2], dtype=torch.bfloat16)
model = get_scripted_model(SomeTensor())
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
"prim::Constant"]
assert len(consts) > 0
some_const = consts[0]
nc_decoder = TorchScriptPythonDecoder(model, some_const)
ov_const = nc_decoder.as_constant()
assert ov_const is not None
assert len(ov_const) == 1
assert ov_const[0].get_element_type() == Type.bf16
assert ov_const[0].get_partial_shape() == PartialShape([2])
@pytest.mark.precommit
def test_pytorch_decoder_can_convert_fp32_tensor():
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
@ -453,6 +476,7 @@ def test_pytorch_decoder_can_convert_empty_list():
assert ov_const[0].get_element_type() == Type.i32
assert ov_const[0].get_partial_shape() == PartialShape([0])
@pytest.mark.precommit
def test_pytorch_decoder_can_convert_int_scalar_tensor():
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
@ -483,6 +507,7 @@ def test_pytorch_decoder_can_convert_int_scalar_tensor():
assert ov_const[0].get_element_type() == Type.i32
assert ov_const[0].get_partial_shape() == PartialShape([])
@pytest.mark.precommit
def test_pytorch_decoder_can_convert_float_scalar_tensor():
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
@ -498,10 +523,9 @@ def test_pytorch_decoder_can_convert_float_scalar_tensor():
# would create nore with output being Tensor with IValue of type float.
return torch.add(torch.tensor([1.], dtype=torch.float), self.value + 1)
model = get_traced_model(SomeTensor(), frozen=False)
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
"prim::Constant"]
"prim::Constant"]
assert len(consts) > 0
some_const = consts[6]
node_output = list(some_const.outputs())[0]
@ -514,6 +538,7 @@ def test_pytorch_decoder_can_convert_float_scalar_tensor():
assert ov_const[0].get_element_type() == Type.f32
assert ov_const[0].get_partial_shape() == PartialShape([])
@pytest.mark.precommit
def test_pytorch_decoder_can_convert_tensor_list():
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
@ -522,7 +547,8 @@ def test_pytorch_decoder_can_convert_tensor_list():
class SomeTensor(torch.nn.Module):
def forward(self):
l = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones((1, 3, 3), dtype=torch.float),])
l = torch.jit.annotate(List[Optional[torch.Tensor]], [
torch.ones((1, 3, 3), dtype=torch.float),])
return l
model = get_scripted_model(SomeTensor())
@ -531,10 +557,12 @@ def test_pytorch_decoder_can_convert_tensor_list():
nc_decoder = TorchScriptPythonDecoder(model)
graph = nc_decoder.graph_element
converted_const_nodes = list(graph.findAllNodes("prim::Constant"))
converted_listconstruct_nodes = list(graph.findAllNodes("prim::ListConstruct"))
converted_listconstruct_nodes = list(
graph.findAllNodes("prim::ListConstruct"))
# # Assert that replaced const exist and is not used
assert len(converted_const_nodes) == 2
assert len([node for node in converted_const_nodes if not node.hasUses()]) == 1
assert len(
[node for node in converted_const_nodes if not node.hasUses()]) == 1
# Assert that prim::ListConstruct exist and has uses
assert len(converted_listconstruct_nodes) == 1
assert converted_listconstruct_nodes[0].kind() == "prim::ListConstruct"
@ -542,9 +570,12 @@ def test_pytorch_decoder_can_convert_tensor_list():
assert len(list(converted_listconstruct_nodes[0].inputs())) == 1
created_const = converted_listconstruct_nodes[0].input().node()
assert created_const in converted_const_nodes
created_const_decoder = TorchScriptPythonDecoder(model, created_const).as_constant()
created_const_decoder = TorchScriptPythonDecoder(
model, created_const).as_constant()
assert created_const_decoder[0].get_element_type() == Type.f32
assert created_const_decoder[0].get_partial_shape() == PartialShape([1, 3, 3])
assert created_const_decoder[0].get_partial_shape() == PartialShape([
1, 3, 3])
@pytest.mark.precommit
def test_pytorch_decoder_can_convert_tensor_list_empty():
@ -562,7 +593,8 @@ def test_pytorch_decoder_can_convert_tensor_list_empty():
nc_decoder = TorchScriptPythonDecoder(model)
graph = nc_decoder.graph_element
converted_const_nodes = list(graph.findAllNodes("prim::Constant"))
converted_listconstruct_nodes = list(graph.findAllNodes("prim::ListConstruct"))
converted_listconstruct_nodes = list(
graph.findAllNodes("prim::ListConstruct"))
# Assert that replaced const exist and is not used
assert len(converted_const_nodes) == 1
assert not converted_const_nodes[0].hasUses()
@ -572,10 +604,12 @@ def test_pytorch_decoder_can_convert_tensor_list_empty():
assert converted_listconstruct_nodes[0].hasUses()
assert len(list(converted_listconstruct_nodes[0].inputs())) == 0
@pytest.mark.precommit
def test_pytorch_decoder_can_convert_optional_tensor_none():
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
from typing import Optional
class SomeTensor(torch.nn.Module):
def forward(self):
l = torch.jit.annotate(Optional[torch.Tensor], None)
@ -587,7 +621,8 @@ def test_pytorch_decoder_can_convert_optional_tensor_none():
nc_decoder = TorchScriptPythonDecoder(model)
graph = nc_decoder.graph_element
converted_const_nodes = list(graph.findAllNodes("prim::Constant"))
removed_consts = [node for node in converted_const_nodes if not node.hasUses()]
removed_consts = [
node for node in converted_const_nodes if not node.hasUses()]
created_consts = [node for node in converted_const_nodes if node.hasUses()]
assert len(removed_consts) == len(created_consts) == 1
# Assert that unused const has torch.OptionalType dtype

View File

@ -0,0 +1,56 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest
class TestBF16(PytorchLayerTest):
def _prepare_input(self):
return (np.random.randn(10).astype(np.float32),)
def create_model(self):
class aten_add(torch.nn.Module):
def __init__(self):
super(aten_add, self).__init__()
self.y = torch.randn(10, dtype=torch.bfloat16)
def forward(self, x):
return x + self.y.to(torch.float32)
return aten_add(), None, "aten::add"
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("to_trace", [True, False])
def test_bf16(self, ie_device, precision, ir_version, to_trace):
self._test(*self.create_model(), ie_device, precision,
ir_version, trace_model=to_trace, freeze_model=False)
class TestFP16(PytorchLayerTest):
def _prepare_input(self):
return (np.random.randn(10).astype(np.float32),)
def create_model(self):
class aten_add(torch.nn.Module):
def __init__(self):
super(aten_add, self).__init__()
self.y = torch.randn(10, dtype=torch.float16)
def forward(self, x):
return x + self.y.to(torch.float32)
return aten_add(), None, "aten::add"
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("to_trace", [True, False])
def test_fp16(self, ie_device, precision, ir_version, to_trace):
self._test(*self.create_model(), ie_device, precision,
ir_version, trace_model=to_trace, freeze_model=False)