[PT FE] Fix for prim::Constant optional or containing list of tensors (#16754)

* Fix Constant list of tensor

* Write TorchScript transformation

* Handle Optional Tensor Constants

* Improve tests

* Add comments

* Try fix flake
Mateusz Mikolajczyk 2023-04-24 22:56:42 +02:00 committed by GitHub
parent b452dab8f0
commit 8e5b0650a0
2 changed files with 137 additions and 1 deletion

View File

@@ -109,6 +109,10 @@ class TorchScriptPythonDecoder (Decoder):
        if self._input_signature is not None and self.raw_inputs[0].debugName() == "self":
            self._input_signature.insert(0, "self")

        if isinstance(self.graph_element, torch.Graph):
            self._transform_tensor_list_constants_to_listconstruct(self.graph_element)
            self._transform_optional_constants(self.graph_element)

    def _get_scripted_model(self, pt_module, example_inputs=None, freeze=True):
        import torch
        import inspect
@@ -277,6 +281,7 @@ class TorchScriptPythonDecoder (Decoder):
        return decoder

    def get_op_type(self) -> str:
        assert isinstance(self.graph_element, torch.Node), "Function can be called only when self.graph_element is of type torch.Node"
        return self.graph_element.kind()

    def get_schema(self) -> str:
@@ -309,10 +314,11 @@ class TorchScriptPythonDecoder (Decoder):
        return []

    def as_constant(self):
        if not isinstance(self.graph_element, torch.Node):
            return None
        if not self.get_op_type() == "prim::Constant":
            return None
        pt_value = self._raw_output(0)
        pt_type = pt_value.type()
        if isinstance(pt_type, torch.TensorType):
            return ivalue_to_constant(pt_value.toIValue())
@@ -369,3 +375,48 @@ class TorchScriptPythonDecoder (Decoder):
                pt_value = get_value_from_getattr(in_node, self.pt_module)
                return pt_value is None
        return False

    @staticmethod
    def _transform_tensor_list_constants_to_listconstruct(graph: torch.Graph):
        # Function replaces prim::Constant containing List of Tensors with
        # prim::ListConstruct containing prim::Constant Tensors.
        assert isinstance(graph, torch.Graph), "Function can be called only with parameters of type torch.Graph."
        for node in graph.nodes():
            if node.kind() != "prim::Constant":
                continue
            output_type = node.output().type()
            allowed_types = [
                output_type.isSubtypeOf(torch.ListType.ofTensors()),
                output_type.isSubtypeOf(torch.ListType(torch.OptionalType.ofTensor())),
            ]
            if not any(allowed_types):
                continue
            const_inputs = []
            for val in node.output().toIValue():
                const_input = graph.insertConstant(val)
                const_input.node().moveBefore(node)
                const_input.node().copyMetadata(node)
                const_inputs.append(const_input)

            replacement = graph.create("prim::ListConstruct", const_inputs)
            replacement.insertBefore(node)
            replacement.output().setType(torch.ListType.ofTensors())
            replacement.copyMetadata(node)
            node.output().replaceAllUsesWith(replacement.output())

    @staticmethod
    def _transform_optional_constants(graph: torch.Graph):
        # Function replaces prim::Constant containing torch.OptionalType with
        # prim::Constant containing torch.NoneType or type of IValue.
        assert isinstance(graph, torch.Graph), "Function can be called only with parameters of type torch.Graph."
        for node in graph.nodes():
            if node.kind() != "prim::Constant":
                continue
            output_type = node.output().type()
            if not isinstance(output_type, torch.OptionalType):
                continue
            value = node.output().toIValue()
            const_input = graph.insertConstant(value)
            const_input.node().moveBefore(node)
            const_input.node().copyMetadata(node)
            node.output().replaceAllUsesWith(const_input)
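
For context, a minimal sketch (not part of the diff) of what these two transformations do to a scripted graph. Calling the static methods directly, and relying on torch.jit.freeze to fold the list into a single prim::Constant, are assumptions made for illustration; in the decoder they run automatically from __init__ as shown above.

import torch
from typing import List, Optional
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder

class ListConst(torch.nn.Module):
    def forward(self):
        # Freezing is expected to fold this into one prim::Constant of type Tensor[]
        # (behavior may vary across TorchScript versions).
        return torch.jit.annotate(List[Optional[torch.Tensor]],
                                  [torch.ones((1, 3, 3), dtype=torch.float)])

frozen = torch.jit.freeze(torch.jit.script(ListConst().eval()))
graph = frozen.graph
print(graph)  # before: a single prim::Constant holding a list of tensors

# Hypothetical direct calls; the decoder invokes these itself on its graph_element.
TorchScriptPythonDecoder._transform_tensor_list_constants_to_listconstruct(graph)
TorchScriptPythonDecoder._transform_optional_constants(graph)
print(graph)  # after: prim::ListConstruct fed by one prim::Constant per tensor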

View File

@@ -513,3 +513,88 @@ def test_pytorch_decoder_can_convert_float_scalar_tensor():
    assert len(ov_const) == 1
    assert ov_const[0].get_element_type() == Type.f32
    assert ov_const[0].get_partial_shape() == PartialShape([])


@pytest.mark.precommit
def test_pytorch_decoder_can_convert_tensor_list():
    from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
    from openvino.runtime import PartialShape, Type
    from typing import List, Optional

    class SomeTensor(torch.nn.Module):
        def forward(self):
            l = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones((1, 3, 3), dtype=torch.float),])
            return l

    model = get_scripted_model(SomeTensor())
    consts = list(model.graph.findAllNodes("prim::Constant"))
    assert len(consts) == 1, "Input model should contain 1 prim::Constant"
    nc_decoder = TorchScriptPythonDecoder(model)
    graph = nc_decoder.graph_element
    converted_const_nodes = list(graph.findAllNodes("prim::Constant"))
    converted_listconstruct_nodes = list(graph.findAllNodes("prim::ListConstruct"))
    # Assert that the replaced const exists and is not used
    assert len(converted_const_nodes) == 2
    assert len([node for node in converted_const_nodes if not node.hasUses()]) == 1
    # Assert that prim::ListConstruct exists and has uses
    assert len(converted_listconstruct_nodes) == 1
    assert converted_listconstruct_nodes[0].kind() == "prim::ListConstruct"
    assert converted_listconstruct_nodes[0].hasUses()
    assert len(list(converted_listconstruct_nodes[0].inputs())) == 1
    created_const = converted_listconstruct_nodes[0].input().node()
    assert created_const in converted_const_nodes
    created_const_decoder = TorchScriptPythonDecoder(model, created_const).as_constant()
    assert created_const_decoder[0].get_element_type() == Type.f32
    assert created_const_decoder[0].get_partial_shape() == PartialShape([1, 3, 3])


@pytest.mark.precommit
def test_pytorch_decoder_can_convert_tensor_list_empty():
    from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
    from typing import List, Optional

    class SomeTensor(torch.nn.Module):
        def forward(self):
            l = torch.jit.annotate(List[Optional[torch.Tensor]], [])
            return l

    model = get_scripted_model(SomeTensor())
    consts = list(model.graph.findAllNodes("prim::Constant"))
    assert len(consts) == 1, "Input model should contain 1 prim::Constant"
    nc_decoder = TorchScriptPythonDecoder(model)
    graph = nc_decoder.graph_element
    converted_const_nodes = list(graph.findAllNodes("prim::Constant"))
    converted_listconstruct_nodes = list(graph.findAllNodes("prim::ListConstruct"))
    # Assert that the replaced const exists and is not used
    assert len(converted_const_nodes) == 1
    assert not converted_const_nodes[0].hasUses()
    # Assert that prim::ListConstruct exists, has uses and has no inputs
    assert len(converted_listconstruct_nodes) == 1
    assert converted_listconstruct_nodes[0].kind() == "prim::ListConstruct"
    assert converted_listconstruct_nodes[0].hasUses()
    assert len(list(converted_listconstruct_nodes[0].inputs())) == 0


@pytest.mark.precommit
def test_pytorch_decoder_can_convert_optional_tensor_none():
    from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
    from typing import Optional

    class SomeTensor(torch.nn.Module):
        def forward(self):
            l = torch.jit.annotate(Optional[torch.Tensor], None)
            return l

    model = get_scripted_model(SomeTensor())
    consts = list(model.graph.findAllNodes("prim::Constant"))
    assert len(consts) == 1, "Input model should contain 1 prim::Constant"
    nc_decoder = TorchScriptPythonDecoder(model)
    graph = nc_decoder.graph_element
    converted_const_nodes = list(graph.findAllNodes("prim::Constant"))
    removed_consts = [node for node in converted_const_nodes if not node.hasUses()]
    created_consts = [node for node in converted_const_nodes if node.hasUses()]
    assert len(removed_consts) == len(created_consts) == 1
    # Assert that the unused const has torch.OptionalType
    assert isinstance(removed_consts[0].output().type(), torch.OptionalType)
    # Assert that the replacement const has the correct type
    assert isinstance(created_consts[0].output().type(), torch.NoneType)
    # Assert that the graph has the correct output
    outputs = list(nc_decoder.graph_element.outputs())
    assert len(outputs) == 1
    assert isinstance(outputs[0].type(), torch.NoneType)
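
For the Optional case exercised by the last test, a minimal standalone sketch; using plain torch.jit.script instead of the test helper get_scripted_model is an assumption, and on some TorchScript versions the annotated None may already be scripted as a plain NoneType constant, making the rewrite a no-op.

import torch
from typing import Optional
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder

class NoneTensor(torch.nn.Module):
    def forward(self):
        # Annotated None may appear as a prim::Constant of type Optional[Tensor].
        return torch.jit.annotate(Optional[torch.Tensor], None)

scripted = torch.jit.script(NoneTensor())
nc_decoder = TorchScriptPythonDecoder(scripted)  # __init__ applies both transformations
for node in nc_decoder.graph_element.findAllNodes("prim::Constant"):
    # Expect any Optional-typed constant to be left unused and a NoneType
    # constant to feed the graph output instead.
    print(node.output().type(), node.hasUses())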