[ONNX] Fix memleak caused by shared_ptr cyclic dependency (#9236)
ONNXFrameworkNode had it own copy of shared_ptr<Graph> so in convert phase, it can be used to produce real ngraph nodes (by graph->make_ng_nodes(..)). But Graph also keeps ONNXFrameworkNodes in its cache and in consequence its own shared_ptr, which is causing a dependency cycle. This change removes shared_ptr<Graph> from ONNXFrameworkNode class and moves it to decoded function runtime info, so Graph is in a single place now and its lifetime ends when decoded function is destroyed.
This commit is contained in:
@@ -26,6 +26,32 @@ def create_onnx_model():
|
||||
return make_model(graph, producer_name="ngraph ONNX Importer")
|
||||
|
||||
|
||||
def create_onnx_model_with_subgraphs():
|
||||
A = onnx.helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, [3])
|
||||
B = onnx.helper.make_tensor_value_info("B", onnx.TensorProto.FLOAT, [3])
|
||||
add_out = onnx.helper.make_tensor_value_info("add_out", onnx.TensorProto.FLOAT, [3])
|
||||
sub_out = onnx.helper.make_tensor_value_info("sub_out", onnx.TensorProto.FLOAT, [3])
|
||||
|
||||
add = onnx.helper.make_node("Add", inputs=["A", "B"], outputs=["add_out"])
|
||||
sub = onnx.helper.make_node("Sub", inputs=["A", "B"], outputs=["sub_out"])
|
||||
|
||||
then_body = make_graph([add], "then_body", [], [add_out])
|
||||
else_body = make_graph([sub], "else_body", [], [sub_out])
|
||||
|
||||
if_node = onnx.helper.make_node(
|
||||
"If",
|
||||
inputs=["cond"],
|
||||
outputs=["res"],
|
||||
then_branch=then_body,
|
||||
else_branch=else_body
|
||||
)
|
||||
cond = onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, [])
|
||||
res = onnx.helper.make_tensor_value_info("res", onnx.TensorProto.FLOAT, [3])
|
||||
|
||||
graph = make_graph([if_node], "graph", [cond, A, B], [res])
|
||||
return make_model(graph, producer_name="ngraph ONNX Importer")
|
||||
|
||||
|
||||
def run_function(function, *inputs, expected):
|
||||
runtime = get_runtime()
|
||||
computation = runtime.computation(function)
|
||||
@@ -37,15 +63,18 @@ def run_function(function, *inputs, expected):
|
||||
|
||||
fem = FrontEndManager()
|
||||
onnx_model_filename = "model.onnx"
|
||||
onnx_model_with_subgraphs_filename = "model_subgraphs.onnx"
|
||||
ONNX_FRONTEND_NAME = "onnx"
|
||||
|
||||
|
||||
def setup_module():
|
||||
onnx.save_model(create_onnx_model(), onnx_model_filename)
|
||||
onnx.save_model(create_onnx_model_with_subgraphs(), onnx_model_with_subgraphs_filename)
|
||||
|
||||
|
||||
def teardown_module():
|
||||
os.remove(onnx_model_filename)
|
||||
os.remove(onnx_model_with_subgraphs_filename)
|
||||
|
||||
|
||||
def skip_if_onnx_frontend_is_disabled():
|
||||
@@ -72,17 +101,29 @@ def test_convert():
|
||||
run_function(function, a, b, expected=[expected])
|
||||
|
||||
|
||||
def test_decode_and_convert():
|
||||
@pytest.mark.parametrize("model_filename, inputs, expected", [
|
||||
[onnx_model_filename,
|
||||
[np.array([[1, 2], [3, 4]], dtype=np.float32),
|
||||
np.array([[2, 3], [4, 5]], dtype=np.float32)],
|
||||
np.array([[1.5, 5], [10.5, 18]], dtype=np.float32)],
|
||||
[onnx_model_with_subgraphs_filename,
|
||||
[np.array(False, dtype=bool),
|
||||
np.array([1, 2, 3], dtype=np.float32),
|
||||
np.array([2, 3, 5], dtype=np.float32)],
|
||||
np.array([-1, -1, -2], dtype=np.float32)],
|
||||
])
|
||||
def test_decode_and_convert(model_filename, inputs, expected):
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
assert fe
|
||||
|
||||
model = fe.load(onnx_model_filename)
|
||||
model = fe.load(model_filename)
|
||||
assert model
|
||||
|
||||
decoded_function = fe.decode(model)
|
||||
assert decoded_function
|
||||
|
||||
for op in decoded_function.get_ordered_ops():
|
||||
assert op.get_type_name() in ["Parameter", "Constant", "ONNXFrameworkNode",
|
||||
"ONNXSubgraphFrameworkNode", "Result"]
|
||||
@@ -92,10 +133,7 @@ def test_decode_and_convert():
|
||||
for op in decoded_function.get_ordered_ops():
|
||||
assert op.get_type_name() not in ["ONNXFrameworkNode", "ONNXSubgraphFrameworkNode"]
|
||||
|
||||
a = np.array([[1, 2], [3, 4]], dtype=np.float32)
|
||||
b = np.array([[2, 3], [4, 5]], dtype=np.float32)
|
||||
expected = np.array([[1.5, 5], [10.5, 18]], dtype=np.float32)
|
||||
run_function(decoded_function, a, b, expected=[expected])
|
||||
run_function(decoded_function, *inputs, expected=[expected])
|
||||
|
||||
|
||||
def test_load_by_model():
|
||||
|
||||
Reference in New Issue
Block a user