[PT FE] Fix for prim::Constant optional or containing list of tensors (#16754)
* Fix Constant list of tensor * Write TorchScript transformation * Handle Optional Tensor Constants * Improve tests * Add comments * Try fix flake
This commit is contained in:
parent
b452dab8f0
commit
8e5b0650a0
@ -109,6 +109,10 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
if self._input_signature is not None and self.raw_inputs[0].debugName() == "self":
|
||||
self._input_signature.insert(0, "self")
|
||||
|
||||
if isinstance(self.graph_element, torch.Graph):
|
||||
self._transform_tensor_list_constants_to_listconstruct(self.graph_element)
|
||||
self._transform_optional_constants(self.graph_element)
|
||||
|
||||
def _get_scripted_model(self, pt_module, example_inputs=None, freeze=True):
|
||||
import torch
|
||||
import inspect
|
||||
@ -277,6 +281,7 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
return decoder
|
||||
|
||||
def get_op_type(self) -> str:
|
||||
assert isinstance(self.graph_element, torch.Node), "Function can be called only when self.graph_element is of type torch.Node"
|
||||
return self.graph_element.kind()
|
||||
|
||||
def get_schema(self) -> str:
|
||||
@ -309,10 +314,11 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
return []
|
||||
|
||||
def as_constant(self):
|
||||
if not isinstance(self.graph_element, torch.Node):
|
||||
return None
|
||||
if not self.get_op_type() == "prim::Constant":
|
||||
return None
|
||||
pt_value = self._raw_output(0)
|
||||
|
||||
pt_type = pt_value.type()
|
||||
if isinstance(pt_type, torch.TensorType):
|
||||
return ivalue_to_constant(pt_value.toIValue())
|
||||
@ -369,3 +375,48 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
pt_value = get_value_from_getattr(in_node, self.pt_module)
|
||||
return pt_value is None
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _transform_tensor_list_constants_to_listconstruct(graph: torch.Graph):
|
||||
# Function replaces prim::Constant containing List of Tensors with
|
||||
# prim::ListConstruct containing prim::Constant Tensors.
|
||||
assert isinstance(graph, torch.Graph), "Function can be called only with parameters of type torch.Graph."
|
||||
for node in graph.nodes():
|
||||
if node.kind() != "prim::Constant":
|
||||
continue
|
||||
output_type = node.output().type()
|
||||
allowed_types = [
|
||||
output_type.isSubtypeOf(torch.ListType.ofTensors()),
|
||||
output_type.isSubtypeOf(torch.ListType(torch.OptionalType.ofTensor())),
|
||||
]
|
||||
if not any(allowed_types):
|
||||
continue
|
||||
const_inputs = []
|
||||
for val in node.output().toIValue():
|
||||
const_input = graph.insertConstant(val)
|
||||
const_input.node().moveBefore(node)
|
||||
const_input.node().copyMetadata(node)
|
||||
const_inputs.append(const_input)
|
||||
|
||||
replacement = graph.create("prim::ListConstruct", const_inputs)
|
||||
replacement.insertBefore(node)
|
||||
replacement.output().setType(torch.ListType.ofTensors())
|
||||
replacement.copyMetadata(node)
|
||||
node.output().replaceAllUsesWith(replacement.output())
|
||||
|
||||
@staticmethod
|
||||
def _transform_optional_constants(graph: torch.Graph):
|
||||
# Function replaces prim::Constant containing torch.OptionalType with
|
||||
# prim::Constant containing torch.NoneType or type of IValue.
|
||||
assert isinstance(graph, torch.Graph), "Function can be called only with parameters of type torch.Graph."
|
||||
for node in graph.nodes():
|
||||
if node.kind() != "prim::Constant":
|
||||
continue
|
||||
output_type = node.output().type()
|
||||
if not isinstance(output_type, torch.OptionalType):
|
||||
continue
|
||||
value = node.output().toIValue()
|
||||
const_input = graph.insertConstant(value)
|
||||
const_input.node().moveBefore(node)
|
||||
const_input.node().copyMetadata(node)
|
||||
node.output().replaceAllUsesWith(const_input)
|
||||
|
@ -513,3 +513,88 @@ def test_pytorch_decoder_can_convert_float_scalar_tensor():
|
||||
assert len(ov_const) == 1
|
||||
assert ov_const[0].get_element_type() == Type.f32
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([])
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_tensor_list():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from openvino.runtime import PartialShape, Type
|
||||
from typing import List, Optional
|
||||
|
||||
class SomeTensor(torch.nn.Module):
|
||||
def forward(self):
|
||||
l = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones((1, 3, 3), dtype=torch.float),])
|
||||
return l
|
||||
|
||||
model = get_scripted_model(SomeTensor())
|
||||
consts = list(model.graph.findAllNodes("prim::Constant"))
|
||||
assert len(consts) == 1, "Input model should contain 1 prim::Constant"
|
||||
nc_decoder = TorchScriptPythonDecoder(model)
|
||||
graph = nc_decoder.graph_element
|
||||
converted_const_nodes = list(graph.findAllNodes("prim::Constant"))
|
||||
converted_listconstruct_nodes = list(graph.findAllNodes("prim::ListConstruct"))
|
||||
# # Assert that replaced const exist and is not used
|
||||
assert len(converted_const_nodes) == 2
|
||||
assert len([node for node in converted_const_nodes if not node.hasUses()]) == 1
|
||||
# Assert that prim::ListConstruct exist and has uses
|
||||
assert len(converted_listconstruct_nodes) == 1
|
||||
assert converted_listconstruct_nodes[0].kind() == "prim::ListConstruct"
|
||||
assert converted_listconstruct_nodes[0].hasUses()
|
||||
assert len(list(converted_listconstruct_nodes[0].inputs())) == 1
|
||||
created_const = converted_listconstruct_nodes[0].input().node()
|
||||
assert created_const in converted_const_nodes
|
||||
created_const_decoder = TorchScriptPythonDecoder(model, created_const).as_constant()
|
||||
assert created_const_decoder[0].get_element_type() == Type.f32
|
||||
assert created_const_decoder[0].get_partial_shape() == PartialShape([1, 3, 3])
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_tensor_list_empty():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from typing import List, Optional
|
||||
|
||||
class SomeTensor(torch.nn.Module):
|
||||
def forward(self):
|
||||
l = torch.jit.annotate(List[Optional[torch.Tensor]], [])
|
||||
return l
|
||||
|
||||
model = get_scripted_model(SomeTensor())
|
||||
consts = list(model.graph.findAllNodes("prim::Constant"))
|
||||
assert len(consts) == 1, "Input model should contain 1 prim::Constant"
|
||||
nc_decoder = TorchScriptPythonDecoder(model)
|
||||
graph = nc_decoder.graph_element
|
||||
converted_const_nodes = list(graph.findAllNodes("prim::Constant"))
|
||||
converted_listconstruct_nodes = list(graph.findAllNodes("prim::ListConstruct"))
|
||||
# Assert that replaced const exist and is not used
|
||||
assert len(converted_const_nodes) == 1
|
||||
assert not converted_const_nodes[0].hasUses()
|
||||
# Assert that prim::ListConstruct exist, has uses and dont have inputs
|
||||
assert len(converted_listconstruct_nodes) == 1
|
||||
assert converted_listconstruct_nodes[0].kind() == "prim::ListConstruct"
|
||||
assert converted_listconstruct_nodes[0].hasUses()
|
||||
assert len(list(converted_listconstruct_nodes[0].inputs())) == 0
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_optional_tensor_none():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from typing import Optional
|
||||
class SomeTensor(torch.nn.Module):
|
||||
def forward(self):
|
||||
l = torch.jit.annotate(Optional[torch.Tensor], None)
|
||||
return l
|
||||
|
||||
model = get_scripted_model(SomeTensor())
|
||||
consts = list(model.graph.findAllNodes("prim::Constant"))
|
||||
assert len(consts) == 1, "Input model should contain 1 prim::Constant"
|
||||
nc_decoder = TorchScriptPythonDecoder(model)
|
||||
graph = nc_decoder.graph_element
|
||||
converted_const_nodes = list(graph.findAllNodes("prim::Constant"))
|
||||
removed_consts = [node for node in converted_const_nodes if not node.hasUses()]
|
||||
created_consts = [node for node in converted_const_nodes if node.hasUses()]
|
||||
assert len(removed_consts) == len(created_consts) == 1
|
||||
# Assert that unused const has torch.OptionalType dtype
|
||||
assert isinstance(removed_consts[0].output().type(), torch.OptionalType)
|
||||
# Assert that replacer const has correct dtype
|
||||
assert isinstance(created_consts[0].output().type(), torch.NoneType)
|
||||
# Assert that graph has correct output
|
||||
outputs = list(nc_decoder.graph_element.outputs())
|
||||
assert len(outputs) == 1
|
||||
assert isinstance(outputs[0].type(), torch.NoneType)
|
||||
|
Loading…
Reference in New Issue
Block a user