[PT FE] Support bfloat16 constants (#18534)
* [PT FE] Support bfloat16 constants * Update src/bindings/python/src/openvino/frontend/pytorch/decoder.py * Add tests for tracing
This commit is contained in:
parent
3f67b3948d
commit
acb14d5d6b
@ -6,7 +6,7 @@
|
||||
|
||||
from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder
|
||||
from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
|
||||
from openvino.runtime import op, PartialShape, Type as OVType, OVAny, Shape
|
||||
from openvino.runtime import op, PartialShape, Type as OVType, OVAny, Shape, Tensor
|
||||
|
||||
import typing
|
||||
from packaging.version import parse
|
||||
@ -39,12 +39,17 @@ def ivalue_to_constant(ivalue):
|
||||
return op.Constant(ov_type, Shape([len(ivalue)]), ivalue).outputs()
|
||||
|
||||
if isinstance(ivalue, torch.Tensor):
|
||||
if ivalue.dim() == 0:
|
||||
assert str(ivalue.dtype) in pt_to_ov_type_map, f"Type is not known {ivalue.dtype}"
|
||||
ov_type = pt_to_ov_type_map[str(ivalue.dtype)]
|
||||
ov_const = op.Constant(ov_type, Shape([]), [ivalue.item()])
|
||||
ivalue = ivalue.to(memory_format=torch.contiguous_format)
|
||||
if ivalue.dtype == torch.bfloat16:
|
||||
# reinterpret bfloat16 data as float16 to allow conversion to numpy
|
||||
ivalue = ivalue.view(torch.float16)
|
||||
narr = ivalue.numpy(force=True)
|
||||
if not narr.flags['C_CONTIGUOUS']:
|
||||
narr = np.ascontiguousarray(narr)
|
||||
# TODO: this tensor doesn't share memory with initial tensor
|
||||
tensor = Tensor(narr, ivalue.shape, OVType.bf16)
|
||||
ov_const = op.Constant(tensor, shared_memory=True)
|
||||
else:
|
||||
ivalue = ivalue.to(memory_format=torch.contiguous_format)
|
||||
narr = ivalue.numpy(force=True)
|
||||
if not narr.flags['C_CONTIGUOUS']:
|
||||
narr = np.ascontiguousarray(narr)
|
||||
@ -76,6 +81,7 @@ pt_to_ov_type_map = {
|
||||
"float": OVType.f32,
|
||||
"int": OVType.i32,
|
||||
"bool": OVType.boolean,
|
||||
"torch.bfloat16": OVType.bf16,
|
||||
"torch.float16": OVType.f16,
|
||||
"torch.float32": OVType.f32,
|
||||
"torch.float64": OVType.f64,
|
||||
@ -153,7 +159,7 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
inputs = ordered_inputs
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
inputs = [inputs]
|
||||
|
||||
|
||||
return {"example_inputs": inputs}, input_signature
|
||||
|
||||
if isinstance(pt_module, torch.nn.Module):
|
||||
@ -184,7 +190,7 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
f_model = scripted
|
||||
else:
|
||||
f_model = pt_module
|
||||
|
||||
|
||||
self._input_signature = input_signature
|
||||
return f_model
|
||||
|
||||
|
@ -24,6 +24,7 @@ def get_scripted_model(model):
|
||||
print(model.inlined_graph) # will help debugging
|
||||
return model
|
||||
|
||||
|
||||
def get_traced_model(model, inputs=[], frozen=True):
|
||||
with torch.no_grad():
|
||||
model = torch.jit.trace(model, example_inputs=inputs)
|
||||
@ -116,6 +117,28 @@ def test_pytorch_decoder_can_convert_fp16_tensor():
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([2])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_bf16_tensor():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from openvino.runtime import PartialShape, Type
|
||||
|
||||
class SomeTensor(torch.nn.Module):
|
||||
def forward(self):
|
||||
return torch.tensor([1, 2], dtype=torch.bfloat16)
|
||||
|
||||
model = get_scripted_model(SomeTensor())
|
||||
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
|
||||
"prim::Constant"]
|
||||
assert len(consts) > 0
|
||||
some_const = consts[0]
|
||||
nc_decoder = TorchScriptPythonDecoder(model, some_const)
|
||||
ov_const = nc_decoder.as_constant()
|
||||
assert ov_const is not None
|
||||
assert len(ov_const) == 1
|
||||
assert ov_const[0].get_element_type() == Type.bf16
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([2])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_fp32_tensor():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
@ -453,6 +476,7 @@ def test_pytorch_decoder_can_convert_empty_list():
|
||||
assert ov_const[0].get_element_type() == Type.i32
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([0])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_int_scalar_tensor():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
@ -483,6 +507,7 @@ def test_pytorch_decoder_can_convert_int_scalar_tensor():
|
||||
assert ov_const[0].get_element_type() == Type.i32
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_float_scalar_tensor():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
@ -498,10 +523,9 @@ def test_pytorch_decoder_can_convert_float_scalar_tensor():
|
||||
# would create nore with output being Tensor with IValue of type float.
|
||||
return torch.add(torch.tensor([1.], dtype=torch.float), self.value + 1)
|
||||
|
||||
|
||||
model = get_traced_model(SomeTensor(), frozen=False)
|
||||
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
|
||||
"prim::Constant"]
|
||||
"prim::Constant"]
|
||||
assert len(consts) > 0
|
||||
some_const = consts[6]
|
||||
node_output = list(some_const.outputs())[0]
|
||||
@ -514,6 +538,7 @@ def test_pytorch_decoder_can_convert_float_scalar_tensor():
|
||||
assert ov_const[0].get_element_type() == Type.f32
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_tensor_list():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
@ -522,7 +547,8 @@ def test_pytorch_decoder_can_convert_tensor_list():
|
||||
|
||||
class SomeTensor(torch.nn.Module):
|
||||
def forward(self):
|
||||
l = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones((1, 3, 3), dtype=torch.float),])
|
||||
l = torch.jit.annotate(List[Optional[torch.Tensor]], [
|
||||
torch.ones((1, 3, 3), dtype=torch.float),])
|
||||
return l
|
||||
|
||||
model = get_scripted_model(SomeTensor())
|
||||
@ -531,10 +557,12 @@ def test_pytorch_decoder_can_convert_tensor_list():
|
||||
nc_decoder = TorchScriptPythonDecoder(model)
|
||||
graph = nc_decoder.graph_element
|
||||
converted_const_nodes = list(graph.findAllNodes("prim::Constant"))
|
||||
converted_listconstruct_nodes = list(graph.findAllNodes("prim::ListConstruct"))
|
||||
converted_listconstruct_nodes = list(
|
||||
graph.findAllNodes("prim::ListConstruct"))
|
||||
# # Assert that replaced const exist and is not used
|
||||
assert len(converted_const_nodes) == 2
|
||||
assert len([node for node in converted_const_nodes if not node.hasUses()]) == 1
|
||||
assert len(
|
||||
[node for node in converted_const_nodes if not node.hasUses()]) == 1
|
||||
# Assert that prim::ListConstruct exist and has uses
|
||||
assert len(converted_listconstruct_nodes) == 1
|
||||
assert converted_listconstruct_nodes[0].kind() == "prim::ListConstruct"
|
||||
@ -542,9 +570,12 @@ def test_pytorch_decoder_can_convert_tensor_list():
|
||||
assert len(list(converted_listconstruct_nodes[0].inputs())) == 1
|
||||
created_const = converted_listconstruct_nodes[0].input().node()
|
||||
assert created_const in converted_const_nodes
|
||||
created_const_decoder = TorchScriptPythonDecoder(model, created_const).as_constant()
|
||||
created_const_decoder = TorchScriptPythonDecoder(
|
||||
model, created_const).as_constant()
|
||||
assert created_const_decoder[0].get_element_type() == Type.f32
|
||||
assert created_const_decoder[0].get_partial_shape() == PartialShape([1, 3, 3])
|
||||
assert created_const_decoder[0].get_partial_shape() == PartialShape([
|
||||
1, 3, 3])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_tensor_list_empty():
|
||||
@ -562,7 +593,8 @@ def test_pytorch_decoder_can_convert_tensor_list_empty():
|
||||
nc_decoder = TorchScriptPythonDecoder(model)
|
||||
graph = nc_decoder.graph_element
|
||||
converted_const_nodes = list(graph.findAllNodes("prim::Constant"))
|
||||
converted_listconstruct_nodes = list(graph.findAllNodes("prim::ListConstruct"))
|
||||
converted_listconstruct_nodes = list(
|
||||
graph.findAllNodes("prim::ListConstruct"))
|
||||
# Assert that replaced const exist and is not used
|
||||
assert len(converted_const_nodes) == 1
|
||||
assert not converted_const_nodes[0].hasUses()
|
||||
@ -572,10 +604,12 @@ def test_pytorch_decoder_can_convert_tensor_list_empty():
|
||||
assert converted_listconstruct_nodes[0].hasUses()
|
||||
assert len(list(converted_listconstruct_nodes[0].inputs())) == 0
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_optional_tensor_none():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from typing import Optional
|
||||
|
||||
class SomeTensor(torch.nn.Module):
|
||||
def forward(self):
|
||||
l = torch.jit.annotate(Optional[torch.Tensor], None)
|
||||
@ -587,7 +621,8 @@ def test_pytorch_decoder_can_convert_optional_tensor_none():
|
||||
nc_decoder = TorchScriptPythonDecoder(model)
|
||||
graph = nc_decoder.graph_element
|
||||
converted_const_nodes = list(graph.findAllNodes("prim::Constant"))
|
||||
removed_consts = [node for node in converted_const_nodes if not node.hasUses()]
|
||||
removed_consts = [
|
||||
node for node in converted_const_nodes if not node.hasUses()]
|
||||
created_consts = [node for node in converted_const_nodes if node.hasUses()]
|
||||
assert len(removed_consts) == len(created_consts) == 1
|
||||
# Assert that unused const has torch.OptionalType dtype
|
||||
|
56
tests/layer_tests/pytorch_tests/test_fp16.py
Normal file
56
tests/layer_tests/pytorch_tests/test_fp16.py
Normal file
@ -0,0 +1,56 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestBF16(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(10).astype(np.float32),)
|
||||
|
||||
def create_model(self):
|
||||
class aten_add(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(aten_add, self).__init__()
|
||||
self.y = torch.randn(10, dtype=torch.bfloat16)
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.y.to(torch.float32)
|
||||
|
||||
return aten_add(), None, "aten::add"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.parametrize("to_trace", [True, False])
|
||||
def test_bf16(self, ie_device, precision, ir_version, to_trace):
|
||||
self._test(*self.create_model(), ie_device, precision,
|
||||
ir_version, trace_model=to_trace, freeze_model=False)
|
||||
|
||||
|
||||
class TestFP16(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(10).astype(np.float32),)
|
||||
|
||||
def create_model(self):
|
||||
class aten_add(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(aten_add, self).__init__()
|
||||
self.y = torch.randn(10, dtype=torch.float16)
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.y.to(torch.float32)
|
||||
|
||||
return aten_add(), None, "aten::add"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.parametrize("to_trace", [True, False])
|
||||
def test_fp16(self, ie_device, precision, ir_version, to_trace):
|
||||
self._test(*self.create_model(), ie_device, precision,
|
||||
ir_version, trace_model=to_trace, freeze_model=False)
|
Loading…
Reference in New Issue
Block a user