[PT FE]: switch on tracing as main path if example inputs provided (#17194)

This commit is contained in:
Ekaterina Aidova 2023-04-26 12:50:43 +04:00 committed by GitHub
parent 09265083ed
commit 5857c4438b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -141,17 +141,17 @@ class TorchScriptPythonDecoder (Decoder):
input_signature = None
if isinstance(pt_module, torch.nn.Module) and not isinstance(pt_module, (torch.jit._trace.TopLevelTracedModule, torch.jit._script.RecursiveScriptModule)):
input_signature = list(inspect.signature(pt_module.forward).parameters.keys())
try:
if example_inputs is None:
scripted = torch.jit.script(pt_module)
except Exception as scripting_err:
if example_inputs is not None:
inputs, input_signature = prepare_example_inputs(example_inputs, input_signature)
else:
inputs, input_signature = prepare_example_inputs(example_inputs, input_signature)
try:
scripted = torch.jit.trace(pt_module, inputs)
except Exception:
try:
scripted = torch.jit.trace(pt_module, inputs)
except Exception as tracing_e:
raise tracing_e
else:
raise scripting_err
scripted = torch.jit.script(pt_module)
except Exception:
scripted = torch.jit.trace(pt_module, inputs, strict=False)
else:
scripted = pt_module
if freeze:
@ -253,7 +253,7 @@ class TorchScriptPythonDecoder (Decoder):
def get_subgraph_size(self) -> int:
if isinstance(self.graph_element, torch.Node):
return len(self.get_subgraphs())
return len(self.get_subgraphs())
else:
return 1