[ONNX] Fix memleak caused by shared_ptr cyclic dependency (#9236)

ONNXFrameworkNode had its own copy of shared_ptr<Graph> so that, in the convert
phase, it could be used to produce real ngraph nodes (by graph->make_ng_nodes(..)).
But Graph also keeps ONNXFrameworkNodes in its cache and, in consequence,
a shared_ptr to itself, which causes a dependency cycle.

This change removes shared_ptr<Graph> from the ONNXFrameworkNode class
and moves it to the decoded function's runtime info, so the Graph is now held in
a single place and its lifetime ends when the decoded function is destroyed.
This commit is contained in:
Mateusz Tabaka 2021-12-15 21:24:35 +01:00 committed by GitHub
parent d9ecb108f1
commit 38bbc30a29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 89 additions and 42 deletions

View File

@ -26,6 +26,32 @@ def create_onnx_model():
return make_model(graph, producer_name="ngraph ONNX Importer") return make_model(graph, producer_name="ngraph ONNX Importer")
def create_onnx_model_with_subgraphs():
A = onnx.helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, [3])
B = onnx.helper.make_tensor_value_info("B", onnx.TensorProto.FLOAT, [3])
add_out = onnx.helper.make_tensor_value_info("add_out", onnx.TensorProto.FLOAT, [3])
sub_out = onnx.helper.make_tensor_value_info("sub_out", onnx.TensorProto.FLOAT, [3])
add = onnx.helper.make_node("Add", inputs=["A", "B"], outputs=["add_out"])
sub = onnx.helper.make_node("Sub", inputs=["A", "B"], outputs=["sub_out"])
then_body = make_graph([add], "then_body", [], [add_out])
else_body = make_graph([sub], "else_body", [], [sub_out])
if_node = onnx.helper.make_node(
"If",
inputs=["cond"],
outputs=["res"],
then_branch=then_body,
else_branch=else_body
)
cond = onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, [])
res = onnx.helper.make_tensor_value_info("res", onnx.TensorProto.FLOAT, [3])
graph = make_graph([if_node], "graph", [cond, A, B], [res])
return make_model(graph, producer_name="ngraph ONNX Importer")
def run_function(function, *inputs, expected): def run_function(function, *inputs, expected):
runtime = get_runtime() runtime = get_runtime()
computation = runtime.computation(function) computation = runtime.computation(function)
@ -37,15 +63,18 @@ def run_function(function, *inputs, expected):
fem = FrontEndManager() fem = FrontEndManager()
onnx_model_filename = "model.onnx" onnx_model_filename = "model.onnx"
onnx_model_with_subgraphs_filename = "model_subgraphs.onnx"
ONNX_FRONTEND_NAME = "onnx" ONNX_FRONTEND_NAME = "onnx"
def setup_module(): def setup_module():
onnx.save_model(create_onnx_model(), onnx_model_filename) onnx.save_model(create_onnx_model(), onnx_model_filename)
onnx.save_model(create_onnx_model_with_subgraphs(), onnx_model_with_subgraphs_filename)
def teardown_module(): def teardown_module():
os.remove(onnx_model_filename) os.remove(onnx_model_filename)
os.remove(onnx_model_with_subgraphs_filename)
def skip_if_onnx_frontend_is_disabled(): def skip_if_onnx_frontend_is_disabled():
@ -72,17 +101,29 @@ def test_convert():
run_function(function, a, b, expected=[expected]) run_function(function, a, b, expected=[expected])
def test_decode_and_convert(): @pytest.mark.parametrize("model_filename, inputs, expected", [
[onnx_model_filename,
[np.array([[1, 2], [3, 4]], dtype=np.float32),
np.array([[2, 3], [4, 5]], dtype=np.float32)],
np.array([[1.5, 5], [10.5, 18]], dtype=np.float32)],
[onnx_model_with_subgraphs_filename,
[np.array(False, dtype=bool),
np.array([1, 2, 3], dtype=np.float32),
np.array([2, 3, 5], dtype=np.float32)],
np.array([-1, -1, -2], dtype=np.float32)],
])
def test_decode_and_convert(model_filename, inputs, expected):
skip_if_onnx_frontend_is_disabled() skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
assert fe assert fe
model = fe.load(onnx_model_filename) model = fe.load(model_filename)
assert model assert model
decoded_function = fe.decode(model) decoded_function = fe.decode(model)
assert decoded_function assert decoded_function
for op in decoded_function.get_ordered_ops(): for op in decoded_function.get_ordered_ops():
assert op.get_type_name() in ["Parameter", "Constant", "ONNXFrameworkNode", assert op.get_type_name() in ["Parameter", "Constant", "ONNXFrameworkNode",
"ONNXSubgraphFrameworkNode", "Result"] "ONNXSubgraphFrameworkNode", "Result"]
@ -92,10 +133,7 @@ def test_decode_and_convert():
for op in decoded_function.get_ordered_ops(): for op in decoded_function.get_ordered_ops():
assert op.get_type_name() not in ["ONNXFrameworkNode", "ONNXSubgraphFrameworkNode"] assert op.get_type_name() not in ["ONNXFrameworkNode", "ONNXSubgraphFrameworkNode"]
a = np.array([[1, 2], [3, 4]], dtype=np.float32) run_function(decoded_function, *inputs, expected=[expected])
b = np.array([[2, 3], [4, 5]], dtype=np.float32)
expected = np.array([[1.5, 5], [10.5, 18]], dtype=np.float32)
run_function(decoded_function, a, b, expected=[expected])
def test_load_by_model(): def test_load_by_model():

View File

@ -380,7 +380,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_initializer_wo_input) {
test_case.run(); test_case.run();
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, onnx_expand_function) { NGRAPH_TEST(${BACKEND_NAME}, onnx_expand_function) {
const auto function = onnx_import::import_onnx_model( const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/quantization/dynamicquantizelinear.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/quantization/dynamicquantizelinear.onnx"));
@ -392,7 +392,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, onnx_expand_function) {
test_case.run(); test_case.run();
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, onnx_expand_function_dependency_to_created_subgraph) { NGRAPH_TEST(${BACKEND_NAME}, onnx_expand_function_dependency_to_created_subgraph) {
const auto function = onnx_import::import_onnx_model( const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/transformations/greater_or_equal.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/transformations/greater_or_equal.onnx"));
@ -403,7 +403,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, onnx_expand_function_dependency_to_created_sub
test_case.run(); test_case.run();
} }
NGRAPH_TEST(onnx_${BACKEND_NAME}, onnx_expand_context_dependent_function) { NGRAPH_TEST(${BACKEND_NAME}, onnx_expand_context_dependent_function) {
auto function = onnx_import::import_onnx_model( auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/transformations/softmax_crossentropy_consumed.onnx")); file_util::path_join(SERIALIZED_ZOO, "onnx/transformations/softmax_crossentropy_consumed.onnx"));

View File

@ -199,9 +199,10 @@ void Graph::decode_to_framework_nodes() {
if (node.has_subgraphs()) { if (node.has_subgraphs()) {
const auto& subgraphs = node.get_subgraphs(); const auto& subgraphs = node.get_subgraphs();
auto inputs = node.get_ng_inputs(); auto inputs = node.get_ng_inputs();
std::vector<std::shared_ptr<Function>> functions;
for (const auto& kv : subgraphs) { for (const auto& kv : subgraphs) {
auto& subgraph = kv.second; auto& subgraph = kv.second;
subgraph->decode(); functions.push_back(subgraph->decode());
for (const auto& input : subgraph->get_inputs_from_parent()) { for (const auto& input : subgraph->get_inputs_from_parent()) {
const auto& name = input.get_node()->get_friendly_name(); const auto& name = input.get_node()->get_friendly_name();
if (std::find_if(inputs.begin(), inputs.end(), [&name](const Output<ngraph::Node>& n) -> bool { if (std::find_if(inputs.begin(), inputs.end(), [&name](const Output<ngraph::Node>& n) -> bool {
@ -211,10 +212,9 @@ void Graph::decode_to_framework_nodes() {
} }
} }
} }
framework_node = framework_node = std::make_shared<frontend::ONNXSubgraphFrameworkNode>(node, functions, inputs);
std::make_shared<ngraph::frontend::ONNXSubgraphFrameworkNode>(shared_from_this(), node, inputs);
} else { } else {
framework_node = std::make_shared<ngraph::frontend::ONNXFrameworkNode>(shared_from_this(), node); framework_node = std::make_shared<frontend::ONNXFrameworkNode>(node);
} }
OutputVector ng_nodes{framework_node->outputs()}; OutputVector ng_nodes{framework_node->outputs()};
set_friendly_names(node, ng_nodes); set_friendly_names(node, ng_nodes);
@ -240,7 +240,10 @@ std::shared_ptr<Function> Graph::create_function() {
std::shared_ptr<Function> Graph::decode() { std::shared_ptr<Function> Graph::decode() {
decode_to_framework_nodes(); decode_to_framework_nodes();
return create_function(); auto function = create_function();
auto& rt_info = function->get_rt_info();
rt_info[ONNX_GRAPH_RT_ATTRIBUTE] = shared_from_this();
return function;
} }
bool Graph::is_ng_node_in_cache(const std::string& name) const { bool Graph::is_ng_node_in_cache(const std::string& name) const {
@ -399,7 +402,8 @@ void Subgraph::find_inputs_from_parent() {
for (const auto& out_name : node_proto.output()) { for (const auto& out_name : node_proto.output()) {
if (m_cache->contains(out_name)) { if (m_cache->contains(out_name)) {
auto node_to_replace_input = m_cache->get_node(out_name).get_node(); auto node_to_replace_input = m_cache->get_node(out_name).get_node();
if (!dynamic_cast<op::util::MultiSubGraphOp*>(node_to_replace_input)) if (!ov::is_type<op::util::MultiSubGraphOp>(node_to_replace_input) &&
!ov::is_type<frontend::ONNXSubgraphFrameworkNode>(node_to_replace_input))
continue; continue;
auto inputs = node_to_replace_input->input_values(); auto inputs = node_to_replace_input->input_values();
for (size_t i = 0; i < inputs.size(); i++) { for (size_t i = 0; i < inputs.size(); i++) {

View File

@ -121,6 +121,8 @@ inline std::ostream& operator<<(std::ostream& outs, const Graph& graph) {
return (outs << "<Graph: " << graph.get_name() << ">"); return (outs << "<Graph: " << graph.get_name() << ">");
} }
static const char* const ONNX_GRAPH_RT_ATTRIBUTE = "onnx_graph";
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph

View File

@ -21,10 +21,14 @@ namespace frontend {
NGRAPH_RTTI_DEFINITION(ONNXFrameworkNode, "ONNXFrameworkNode", 1); NGRAPH_RTTI_DEFINITION(ONNXFrameworkNode, "ONNXFrameworkNode", 1);
std::shared_ptr<Node> ONNXFrameworkNode::clone_with_new_inputs(const OutputVector& inputs) const { std::shared_ptr<Node> ONNXFrameworkNode::clone_with_new_inputs(const OutputVector& inputs) const {
return std::make_shared<ONNXFrameworkNode>(m_graph, m_node, inputs); return std::make_shared<ONNXFrameworkNode>(m_node, inputs);
} }
NGRAPH_RTTI_DEFINITION(ONNXSubgraphFrameworkNode, "ONNXSubgraphFrameworkNode", 1); NGRAPH_RTTI_DEFINITION(ONNXSubgraphFrameworkNode, "ONNXSubgraphFrameworkNode", 1);
std::shared_ptr<Node> ONNXSubgraphFrameworkNode::clone_with_new_inputs(const OutputVector& inputs) const {
return std::make_shared<ONNXSubgraphFrameworkNode>(m_node, m_functions, inputs);
}
} // namespace frontend } // namespace frontend
} // namespace ngraph } // namespace ngraph

View File

@ -38,20 +38,16 @@ class ONNXFrameworkNode : public ov::op::util::FrameworkNode {
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
ONNXFrameworkNode(std::shared_ptr<onnx_import::Graph> graph, const onnx_import::Node& node) ONNXFrameworkNode(const onnx_import::Node& node)
: ov::op::util::FrameworkNode(node.get_ng_inputs(), node.get_outputs_size()), : ov::op::util::FrameworkNode(node.get_ng_inputs(), node.get_outputs_size()),
m_node(node), m_node(node) {}
m_graph(graph) {}
ONNXFrameworkNode(std::shared_ptr<onnx_import::Graph> graph, ONNXFrameworkNode(const onnx_import::Node& node, const OutputVector& inputs)
const onnx_import::Node& node,
const OutputVector& inputs)
: ov::op::util::FrameworkNode(inputs, node.get_outputs_size()), : ov::op::util::FrameworkNode(inputs, node.get_outputs_size()),
m_node(node), m_node(node) {}
m_graph(graph) {}
OutputVector get_ng_nodes() const { OutputVector get_ng_nodes(const std::shared_ptr<onnx_import::Graph>& graph) const {
OutputVector ng_nodes{m_graph->make_ng_nodes(m_node)}; OutputVector ng_nodes{graph->make_ng_nodes(m_node)};
if (ng_nodes.size() > get_output_size()) { if (ng_nodes.size() > get_output_size()) {
ng_nodes.resize(get_output_size()); ng_nodes.resize(get_output_size());
} }
@ -71,35 +67,31 @@ public:
protected: protected:
onnx_import::Node m_node; onnx_import::Node m_node;
private:
std::shared_ptr<onnx_import::Graph> m_graph;
}; };
class ONNXSubgraphFrameworkNode : public ONNXFrameworkNode { class ONNXSubgraphFrameworkNode : public ONNXFrameworkNode {
public: public:
NGRAPH_RTTI_DECLARATION; NGRAPH_RTTI_DECLARATION;
ONNXSubgraphFrameworkNode(std::shared_ptr<onnx_import::Graph> graph, ONNXSubgraphFrameworkNode(const onnx_import::Node& node,
const onnx_import::Node& node, const std::vector<std::shared_ptr<Function>>& functions,
const OutputVector& inputs) const OutputVector& inputs)
: ONNXFrameworkNode(graph, node, inputs) {} : ONNXFrameworkNode(node, inputs),
m_functions(functions) {}
void infer_inputs_from_parent() { void infer_inputs_from_parent() {
for (auto& subgraph : m_node.get_subgraphs()) for (auto& subgraph : m_node.get_subgraphs())
subgraph.second->infer_inputs_from_parent(); subgraph.second->infer_inputs_from_parent();
} }
std::vector<std::shared_ptr<Function>> get_subgraph_functions() const { const std::vector<std::shared_ptr<Function>>& get_subgraph_functions() const {
std::vector<std::shared_ptr<Function>> ret; return m_functions;
for (const auto& kv : m_node.get_subgraphs()) {
auto& subgraph = kv.second;
ret.push_back(std::make_shared<Function>(subgraph->get_ng_outputs(),
subgraph->get_ng_parameters(),
subgraph->get_name()));
}
return ret;
} }
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override;
private:
std::vector<std::shared_ptr<Function>> m_functions;
}; };
} // namespace frontend } // namespace frontend

View File

@ -60,6 +60,12 @@ void apply_transformations(ONNX_NAMESPACE::ModelProto& model_proto, const std::s
} // namespace } // namespace
void convert_decoded_function(std::shared_ptr<Function> function) { void convert_decoded_function(std::shared_ptr<Function> function) {
auto& rt_info = function->get_rt_info();
auto it = rt_info.find(ONNX_GRAPH_RT_ATTRIBUTE);
OPENVINO_ASSERT(it != rt_info.end(),
"Could not find '" + std::string(ONNX_GRAPH_RT_ATTRIBUTE) +
"' attribute in decoded model. Model probably wasn't created by FrontEnd::decode function.");
auto onnx_graph = it->second.as<std::shared_ptr<onnx_import::Graph>>();
for (const auto& node : function->get_ordered_ops()) { for (const auto& node : function->get_ordered_ops()) {
if (auto raw_node = std::dynamic_pointer_cast<frontend::ONNXFrameworkNode>(node)) { if (auto raw_node = std::dynamic_pointer_cast<frontend::ONNXFrameworkNode>(node)) {
if (auto subgraph_node = std::dynamic_pointer_cast<frontend::ONNXSubgraphFrameworkNode>(node)) { if (auto subgraph_node = std::dynamic_pointer_cast<frontend::ONNXSubgraphFrameworkNode>(node)) {
@ -68,7 +74,7 @@ void convert_decoded_function(std::shared_ptr<Function> function) {
convert_decoded_function(function); convert_decoded_function(function);
} }
} }
auto ng_nodes = raw_node->get_ng_nodes(); auto ng_nodes = raw_node->get_ng_nodes(onnx_graph);
replace_node(raw_node, ng_nodes); replace_node(raw_node, ng_nodes);
} else { } else {
// Have to revalidate node because new intpus can affect shape/type // Have to revalidate node because new intpus can affect shape/type
@ -76,6 +82,7 @@ void convert_decoded_function(std::shared_ptr<Function> function) {
node->revalidate_and_infer_types(); node->revalidate_and_infer_types();
} }
} }
rt_info.erase(it);
detail::remove_dangling_parameters(function); detail::remove_dangling_parameters(function);
detail::remove_dangling_results(function); detail::remove_dangling_results(function);
} }