[PT FE]: switch on tracing as main path if example inputs provided (#17194)
This commit is contained in:
parent
09265083ed
commit
5857c4438b
@ -141,17 +141,17 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
input_signature = None
|
||||
if isinstance(pt_module, torch.nn.Module) and not isinstance(pt_module, (torch.jit._trace.TopLevelTracedModule, torch.jit._script.RecursiveScriptModule)):
|
||||
input_signature = list(inspect.signature(pt_module.forward).parameters.keys())
|
||||
try:
|
||||
if example_inputs is None:
|
||||
scripted = torch.jit.script(pt_module)
|
||||
except Exception as scripting_err:
|
||||
if example_inputs is not None:
|
||||
inputs, input_signature = prepare_example_inputs(example_inputs, input_signature)
|
||||
else:
|
||||
inputs, input_signature = prepare_example_inputs(example_inputs, input_signature)
|
||||
try:
|
||||
scripted = torch.jit.trace(pt_module, inputs)
|
||||
except Exception:
|
||||
try:
|
||||
scripted = torch.jit.trace(pt_module, inputs)
|
||||
except Exception as tracing_e:
|
||||
raise tracing_e
|
||||
else:
|
||||
raise scripting_err
|
||||
scripted = torch.jit.script(pt_module)
|
||||
except Exception:
|
||||
scripted = torch.jit.trace(pt_module, inputs, strict=False)
|
||||
else:
|
||||
scripted = pt_module
|
||||
if freeze:
|
||||
@ -253,7 +253,7 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
|
||||
def get_subgraph_size(self) -> int:
|
||||
if isinstance(self.graph_element, torch.Node):
|
||||
return len(self.get_subgraphs())
|
||||
return len(self.get_subgraphs())
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user