From 5857c4438b9e97a9f8a0d87e3c3ff9502af1b465 Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Wed, 26 Apr 2023 12:50:43 +0400 Subject: [PATCH] [PT FE]: switch on tracing as main path if example inputs provided (#17194) --- .../src/openvino/frontend/pytorch/decoder.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py index 01f09e373c5..a35663e93b4 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py @@ -141,17 +141,17 @@ class TorchScriptPythonDecoder (Decoder): input_signature = None if isinstance(pt_module, torch.nn.Module) and not isinstance(pt_module, (torch.jit._trace.TopLevelTracedModule, torch.jit._script.RecursiveScriptModule)): input_signature = list(inspect.signature(pt_module.forward).parameters.keys()) - try: + if example_inputs is None: scripted = torch.jit.script(pt_module) - except Exception as scripting_err: - if example_inputs is not None: - inputs, input_signature = prepare_example_inputs(example_inputs, input_signature) + else: + inputs, input_signature = prepare_example_inputs(example_inputs, input_signature) + try: + scripted = torch.jit.trace(pt_module, inputs) + except Exception: try: - scripted = torch.jit.trace(pt_module, inputs) - except Exception as tracing_e: - raise tracing_e - else: - raise scripting_err + scripted = torch.jit.script(pt_module) + except Exception: + scripted = torch.jit.trace(pt_module, inputs, strict=False) else: scripted = pt_module if freeze: @@ -253,7 +253,7 @@ class TorchScriptPythonDecoder (Decoder): def get_subgraph_size(self) -> int: if isinstance(self.graph_element, torch.Node): - return len(self.get_subgraphs()) + return len(self.get_subgraphs()) else: return 1