From 8e5b0650a0e71b83918be4bcf003a7e112f5e84c Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Mon, 24 Apr 2023 22:56:42 +0200 Subject: [PATCH] [PT FE] Fix for prim::Constant optional or containing list of tensors (#16754) * Fix Constant list of tensor * Write TorchScript transformation * Handle Optional Tensor Constants * Improve tests * Add comments * Try fix flake --- .../src/openvino/frontend/pytorch/decoder.py | 53 +++++++++++- .../py_frontend_tests/test_torch_decoder.py | 85 +++++++++++++++++++ 2 files changed, 137 insertions(+), 1 deletion(-) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py index d994a771de8..01f09e373c5 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py @@ -109,6 +109,10 @@ class TorchScriptPythonDecoder (Decoder): if self._input_signature is not None and self.raw_inputs[0].debugName() == "self": self._input_signature.insert(0, "self") + if isinstance(self.graph_element, torch.Graph): + self._transform_tensor_list_constants_to_listconstruct(self.graph_element) + self._transform_optional_constants(self.graph_element) + def _get_scripted_model(self, pt_module, example_inputs=None, freeze=True): import torch import inspect @@ -277,6 +281,7 @@ class TorchScriptPythonDecoder (Decoder): return decoder def get_op_type(self) -> str: + assert isinstance(self.graph_element, torch.Node), "Function can be called only when self.graph_element is of type torch.Node" return self.graph_element.kind() def get_schema(self) -> str: @@ -309,10 +314,11 @@ class TorchScriptPythonDecoder (Decoder): return [] def as_constant(self): + if not isinstance(self.graph_element, torch.Node): + return None if not self.get_op_type() == "prim::Constant": return None pt_value = self._raw_output(0) - pt_type = pt_value.type() if isinstance(pt_type, torch.TensorType): return ivalue_to_constant(pt_value.toIValue()) @@ -369,3 +375,48 @@ class TorchScriptPythonDecoder (Decoder): pt_value = get_value_from_getattr(in_node, self.pt_module) return pt_value is None return False + + @staticmethod + def _transform_tensor_list_constants_to_listconstruct(graph: torch.Graph): + # Function replaces prim::Constant containing List of Tensors with + # prim::ListConstruct containing prim::Constant Tensors. + assert isinstance(graph, torch.Graph), "Function can be called only with parameters of type torch.Graph." + for node in graph.nodes(): + if node.kind() != "prim::Constant": + continue + output_type = node.output().type() + allowed_types = [ + output_type.isSubtypeOf(torch.ListType.ofTensors()), + output_type.isSubtypeOf(torch.ListType(torch.OptionalType.ofTensor())), + ] + if not any(allowed_types): + continue + const_inputs = [] + for val in node.output().toIValue(): + const_input = graph.insertConstant(val) + const_input.node().moveBefore(node) + const_input.node().copyMetadata(node) + const_inputs.append(const_input) + + replacement = graph.create("prim::ListConstruct", const_inputs) + replacement.insertBefore(node) + replacement.output().setType(torch.ListType.ofTensors()) + replacement.copyMetadata(node) + node.output().replaceAllUsesWith(replacement.output()) + + @staticmethod + def _transform_optional_constants(graph: torch.Graph): + # Function replaces prim::Constant containing torch.OptionalType with + # prim::Constant containing torch.NoneType or type of IValue. + assert isinstance(graph, torch.Graph), "Function can be called only with parameters of type torch.Graph." + for node in graph.nodes(): + if node.kind() != "prim::Constant": + continue + output_type = node.output().type() + if not isinstance(output_type, torch.OptionalType): + continue + value = node.output().toIValue() + const_input = graph.insertConstant(value) + const_input.node().moveBefore(node) + const_input.node().copyMetadata(node) + node.output().replaceAllUsesWith(const_input) diff --git a/tests/layer_tests/py_frontend_tests/test_torch_decoder.py b/tests/layer_tests/py_frontend_tests/test_torch_decoder.py index 1df325338a2..e19b0e97fdd 100644 --- a/tests/layer_tests/py_frontend_tests/test_torch_decoder.py +++ b/tests/layer_tests/py_frontend_tests/test_torch_decoder.py @@ -513,3 +513,88 @@ def test_pytorch_decoder_can_convert_float_scalar_tensor(): assert len(ov_const) == 1 assert ov_const[0].get_element_type() == Type.f32 assert ov_const[0].get_partial_shape() == PartialShape([]) + +@pytest.mark.precommit +def test_pytorch_decoder_can_convert_tensor_list(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + from openvino.runtime import PartialShape, Type + from typing import List, Optional + + class SomeTensor(torch.nn.Module): + def forward(self): + l = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones((1, 3, 3), dtype=torch.float),]) + return l + + model = get_scripted_model(SomeTensor()) + consts = list(model.graph.findAllNodes("prim::Constant")) + assert len(consts) == 1, "Input model should contain 1 prim::Constant" + nc_decoder = TorchScriptPythonDecoder(model) + graph = nc_decoder.graph_element + converted_const_nodes = list(graph.findAllNodes("prim::Constant")) + converted_listconstruct_nodes = list(graph.findAllNodes("prim::ListConstruct")) + # # Assert that replaced const exist and is not used + assert len(converted_const_nodes) == 2 + assert len([node for node in converted_const_nodes if not node.hasUses()]) == 1 + # Assert that prim::ListConstruct exist and has uses + assert len(converted_listconstruct_nodes) == 1 + assert converted_listconstruct_nodes[0].kind() == "prim::ListConstruct" + assert converted_listconstruct_nodes[0].hasUses() + assert len(list(converted_listconstruct_nodes[0].inputs())) == 1 + created_const = converted_listconstruct_nodes[0].input().node() + assert created_const in converted_const_nodes + created_const_decoder = TorchScriptPythonDecoder(model, created_const).as_constant() + assert created_const_decoder[0].get_element_type() == Type.f32 + assert created_const_decoder[0].get_partial_shape() == PartialShape([1, 3, 3]) + +@pytest.mark.precommit +def test_pytorch_decoder_can_convert_tensor_list_empty(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + from typing import List, Optional + + class SomeTensor(torch.nn.Module): + def forward(self): + l = torch.jit.annotate(List[Optional[torch.Tensor]], []) + return l + + model = get_scripted_model(SomeTensor()) + consts = list(model.graph.findAllNodes("prim::Constant")) + assert len(consts) == 1, "Input model should contain 1 prim::Constant" + nc_decoder = TorchScriptPythonDecoder(model) + graph = nc_decoder.graph_element + converted_const_nodes = list(graph.findAllNodes("prim::Constant")) + converted_listconstruct_nodes = list(graph.findAllNodes("prim::ListConstruct")) + # Assert that replaced const exist and is not used + assert len(converted_const_nodes) == 1 + assert not converted_const_nodes[0].hasUses() + # Assert that prim::ListConstruct exist, has uses and dont have inputs + assert len(converted_listconstruct_nodes) == 1 + assert converted_listconstruct_nodes[0].kind() == "prim::ListConstruct" + assert converted_listconstruct_nodes[0].hasUses() + assert len(list(converted_listconstruct_nodes[0].inputs())) == 0 + +@pytest.mark.precommit +def test_pytorch_decoder_can_convert_optional_tensor_none(): + from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder + from typing import Optional + class SomeTensor(torch.nn.Module): + def forward(self): + l = torch.jit.annotate(Optional[torch.Tensor], None) + return l + + model = get_scripted_model(SomeTensor()) + consts = list(model.graph.findAllNodes("prim::Constant")) + assert len(consts) == 1, "Input model should contain 1 prim::Constant" + nc_decoder = TorchScriptPythonDecoder(model) + graph = nc_decoder.graph_element + converted_const_nodes = list(graph.findAllNodes("prim::Constant")) + removed_consts = [node for node in converted_const_nodes if not node.hasUses()] + created_consts = [node for node in converted_const_nodes if node.hasUses()] + assert len(removed_consts) == len(created_consts) == 1 + # Assert that unused const has torch.OptionalType dtype + assert isinstance(removed_consts[0].output().type(), torch.OptionalType) + # Assert that replacer const has correct dtype + assert isinstance(created_consts[0].output().type(), torch.NoneType) + # Assert that graph has correct output + outputs = list(nc_decoder.graph_element.outputs()) + assert len(outputs) == 1 + assert isinstance(outputs[0].type(), torch.NoneType)