[ONNX] Fix memleak caused by shared_ptr cyclic dependency (#9236)

ONNXFrameworkNode had it own copy of shared_ptr<Graph> so in convert phase,
it can be used to produce real ngraph nodes (by graph->make_ng_nodes(..)).
But Graph also keeps ONNXFrameworkNodes in its cache and in consequence
its own shared_ptr, which is causing a dependency cycle.

This change removes shared_ptr<Graph> from ONNXFrameworkNode class
and moves it to decoded function runtime info, so Graph is in a single
place now and its lifetime ends when decoded function is destroyed.
This commit is contained in:
Mateusz Tabaka
2021-12-15 21:24:35 +01:00
committed by GitHub
parent d9ecb108f1
commit 38bbc30a29
7 changed files with 89 additions and 42 deletions

View File

@@ -26,6 +26,32 @@ def create_onnx_model():
return make_model(graph, producer_name="ngraph ONNX Importer")
def create_onnx_model_with_subgraphs():
A = onnx.helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, [3])
B = onnx.helper.make_tensor_value_info("B", onnx.TensorProto.FLOAT, [3])
add_out = onnx.helper.make_tensor_value_info("add_out", onnx.TensorProto.FLOAT, [3])
sub_out = onnx.helper.make_tensor_value_info("sub_out", onnx.TensorProto.FLOAT, [3])
add = onnx.helper.make_node("Add", inputs=["A", "B"], outputs=["add_out"])
sub = onnx.helper.make_node("Sub", inputs=["A", "B"], outputs=["sub_out"])
then_body = make_graph([add], "then_body", [], [add_out])
else_body = make_graph([sub], "else_body", [], [sub_out])
if_node = onnx.helper.make_node(
"If",
inputs=["cond"],
outputs=["res"],
then_branch=then_body,
else_branch=else_body
)
cond = onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, [])
res = onnx.helper.make_tensor_value_info("res", onnx.TensorProto.FLOAT, [3])
graph = make_graph([if_node], "graph", [cond, A, B], [res])
return make_model(graph, producer_name="ngraph ONNX Importer")
def run_function(function, *inputs, expected):
runtime = get_runtime()
computation = runtime.computation(function)
@@ -37,15 +63,18 @@ def run_function(function, *inputs, expected):
fem = FrontEndManager()
onnx_model_filename = "model.onnx"
onnx_model_with_subgraphs_filename = "model_subgraphs.onnx"
ONNX_FRONTEND_NAME = "onnx"
def setup_module():
onnx.save_model(create_onnx_model(), onnx_model_filename)
onnx.save_model(create_onnx_model_with_subgraphs(), onnx_model_with_subgraphs_filename)
def teardown_module():
os.remove(onnx_model_filename)
os.remove(onnx_model_with_subgraphs_filename)
def skip_if_onnx_frontend_is_disabled():
@@ -72,17 +101,29 @@ def test_convert():
run_function(function, a, b, expected=[expected])
def test_decode_and_convert():
@pytest.mark.parametrize("model_filename, inputs, expected", [
[onnx_model_filename,
[np.array([[1, 2], [3, 4]], dtype=np.float32),
np.array([[2, 3], [4, 5]], dtype=np.float32)],
np.array([[1.5, 5], [10.5, 18]], dtype=np.float32)],
[onnx_model_with_subgraphs_filename,
[np.array(False, dtype=bool),
np.array([1, 2, 3], dtype=np.float32),
np.array([2, 3, 5], dtype=np.float32)],
np.array([-1, -1, -2], dtype=np.float32)],
])
def test_decode_and_convert(model_filename, inputs, expected):
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
assert fe
model = fe.load(onnx_model_filename)
model = fe.load(model_filename)
assert model
decoded_function = fe.decode(model)
assert decoded_function
for op in decoded_function.get_ordered_ops():
assert op.get_type_name() in ["Parameter", "Constant", "ONNXFrameworkNode",
"ONNXSubgraphFrameworkNode", "Result"]
@@ -92,10 +133,7 @@ def test_decode_and_convert():
for op in decoded_function.get_ordered_ops():
assert op.get_type_name() not in ["ONNXFrameworkNode", "ONNXSubgraphFrameworkNode"]
a = np.array([[1, 2], [3, 4]], dtype=np.float32)
b = np.array([[2, 3], [4, 5]], dtype=np.float32)
expected = np.array([[1.5, 5], [10.5, 18]], dtype=np.float32)
run_function(decoded_function, a, b, expected=[expected])
run_function(decoded_function, *inputs, expected=[expected])
def test_load_by_model():