TorchFX caching fix (#18813)

* TorchFX caching fix

* Added required newlines for code formatting

* TorchFX model caching file_name check added
This commit is contained in:
Mustafa Cavus
2023-07-28 03:25:18 -07:00
committed by GitHub
parent fd085f870f
commit c720052f40
3 changed files with 18 additions and 1 deletions

View File

@@ -127,3 +127,7 @@ def fx_openvino(subgraph, example_inputs):
except Exception as e:
log.debug(f"Failed in OpenVINO execution: {e}")
return compile_fx(subgraph, example_inputs)
def reset():
clear_caches()

View File

@@ -52,6 +52,8 @@ def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None):
for idx, input_data in enumerate(args): # subgraph.example_inputs):
input_types.append(input_data.type())
input_shapes.append(input_data.size())
if file_name is not None:
file_name += "_" + str(input_data.type()) + str(input_data.size())[11:-1].replace(" ", "")
decoder = TorchFXPythonDecoder(gm, gm, input_shapes=input_shapes, input_types=input_types)

View File

@@ -143,9 +143,20 @@ def openvino_execute_partitioned(gm: GraphModule, *args, executor_parameters=Non
signature = str(id(gm))
for idx, input_data in enumerate(args):
signature = signature + "_" + str(idx) + ":" + str(input_data.type())[6:] + ":" + str(input_data.size())[11:-1].replace(" ", "")
if isinstance(input_data, torch.Tensor):
signature = signature + "_" + str(idx) + ":" + str(input_data.type())[6:] + ":" + str(input_data.size())[11:-1].replace(" ", "")
else:
signature = signature + "_" + str(idx) + ":" + type(input_data).__name__ + ":val(" + str(input_data) + ")"
if signature not in partitioned_modules:
partitioned_modules[signature] = partition_graph(gm, use_python_fusion_cache=use_python_fusion_cache,
model_hash_str=model_hash_str)
return partitioned_modules[signature](*args)
def clear_caches():
global partitioned_modules
global compiled_cache
compiled_cache.clear()
partitioned_modules.clear()