[ONNX] Fix memleak caused by shared_ptr cyclic dependency (#9236)
ONNXFrameworkNode had it own copy of shared_ptr<Graph> so in convert phase, it can be used to produce real ngraph nodes (by graph->make_ng_nodes(..)). But Graph also keeps ONNXFrameworkNodes in its cache and in consequence its own shared_ptr, which is causing a dependency cycle. This change removes shared_ptr<Graph> from ONNXFrameworkNode class and moves it to decoded function runtime info, so Graph is in a single place now and its lifetime ends when decoded function is destroyed.
This commit is contained in:
parent
d9ecb108f1
commit
38bbc30a29
@ -26,6 +26,32 @@ def create_onnx_model():
|
||||
return make_model(graph, producer_name="ngraph ONNX Importer")
|
||||
|
||||
|
||||
def create_onnx_model_with_subgraphs():
|
||||
A = onnx.helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, [3])
|
||||
B = onnx.helper.make_tensor_value_info("B", onnx.TensorProto.FLOAT, [3])
|
||||
add_out = onnx.helper.make_tensor_value_info("add_out", onnx.TensorProto.FLOAT, [3])
|
||||
sub_out = onnx.helper.make_tensor_value_info("sub_out", onnx.TensorProto.FLOAT, [3])
|
||||
|
||||
add = onnx.helper.make_node("Add", inputs=["A", "B"], outputs=["add_out"])
|
||||
sub = onnx.helper.make_node("Sub", inputs=["A", "B"], outputs=["sub_out"])
|
||||
|
||||
then_body = make_graph([add], "then_body", [], [add_out])
|
||||
else_body = make_graph([sub], "else_body", [], [sub_out])
|
||||
|
||||
if_node = onnx.helper.make_node(
|
||||
"If",
|
||||
inputs=["cond"],
|
||||
outputs=["res"],
|
||||
then_branch=then_body,
|
||||
else_branch=else_body
|
||||
)
|
||||
cond = onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, [])
|
||||
res = onnx.helper.make_tensor_value_info("res", onnx.TensorProto.FLOAT, [3])
|
||||
|
||||
graph = make_graph([if_node], "graph", [cond, A, B], [res])
|
||||
return make_model(graph, producer_name="ngraph ONNX Importer")
|
||||
|
||||
|
||||
def run_function(function, *inputs, expected):
|
||||
runtime = get_runtime()
|
||||
computation = runtime.computation(function)
|
||||
@ -37,15 +63,18 @@ def run_function(function, *inputs, expected):
|
||||
|
||||
fem = FrontEndManager()
|
||||
onnx_model_filename = "model.onnx"
|
||||
onnx_model_with_subgraphs_filename = "model_subgraphs.onnx"
|
||||
ONNX_FRONTEND_NAME = "onnx"
|
||||
|
||||
|
||||
def setup_module():
|
||||
onnx.save_model(create_onnx_model(), onnx_model_filename)
|
||||
onnx.save_model(create_onnx_model_with_subgraphs(), onnx_model_with_subgraphs_filename)
|
||||
|
||||
|
||||
def teardown_module():
|
||||
os.remove(onnx_model_filename)
|
||||
os.remove(onnx_model_with_subgraphs_filename)
|
||||
|
||||
|
||||
def skip_if_onnx_frontend_is_disabled():
|
||||
@ -72,17 +101,29 @@ def test_convert():
|
||||
run_function(function, a, b, expected=[expected])
|
||||
|
||||
|
||||
def test_decode_and_convert():
|
||||
@pytest.mark.parametrize("model_filename, inputs, expected", [
|
||||
[onnx_model_filename,
|
||||
[np.array([[1, 2], [3, 4]], dtype=np.float32),
|
||||
np.array([[2, 3], [4, 5]], dtype=np.float32)],
|
||||
np.array([[1.5, 5], [10.5, 18]], dtype=np.float32)],
|
||||
[onnx_model_with_subgraphs_filename,
|
||||
[np.array(False, dtype=bool),
|
||||
np.array([1, 2, 3], dtype=np.float32),
|
||||
np.array([2, 3, 5], dtype=np.float32)],
|
||||
np.array([-1, -1, -2], dtype=np.float32)],
|
||||
])
|
||||
def test_decode_and_convert(model_filename, inputs, expected):
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
assert fe
|
||||
|
||||
model = fe.load(onnx_model_filename)
|
||||
model = fe.load(model_filename)
|
||||
assert model
|
||||
|
||||
decoded_function = fe.decode(model)
|
||||
assert decoded_function
|
||||
|
||||
for op in decoded_function.get_ordered_ops():
|
||||
assert op.get_type_name() in ["Parameter", "Constant", "ONNXFrameworkNode",
|
||||
"ONNXSubgraphFrameworkNode", "Result"]
|
||||
@ -92,10 +133,7 @@ def test_decode_and_convert():
|
||||
for op in decoded_function.get_ordered_ops():
|
||||
assert op.get_type_name() not in ["ONNXFrameworkNode", "ONNXSubgraphFrameworkNode"]
|
||||
|
||||
a = np.array([[1, 2], [3, 4]], dtype=np.float32)
|
||||
b = np.array([[2, 3], [4, 5]], dtype=np.float32)
|
||||
expected = np.array([[1.5, 5], [10.5, 18]], dtype=np.float32)
|
||||
run_function(decoded_function, a, b, expected=[expected])
|
||||
run_function(decoded_function, *inputs, expected=[expected])
|
||||
|
||||
|
||||
def test_load_by_model():
|
||||
|
@ -380,7 +380,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_initializer_wo_input) {
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_${BACKEND_NAME}, onnx_expand_function) {
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_expand_function) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/quantization/dynamicquantizelinear.onnx"));
|
||||
|
||||
@ -392,7 +392,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, onnx_expand_function) {
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_${BACKEND_NAME}, onnx_expand_function_dependency_to_created_subgraph) {
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_expand_function_dependency_to_created_subgraph) {
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/transformations/greater_or_equal.onnx"));
|
||||
|
||||
@ -403,7 +403,7 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, onnx_expand_function_dependency_to_created_sub
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(onnx_${BACKEND_NAME}, onnx_expand_context_dependent_function) {
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_expand_context_dependent_function) {
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/transformations/softmax_crossentropy_consumed.onnx"));
|
||||
|
||||
|
@ -199,9 +199,10 @@ void Graph::decode_to_framework_nodes() {
|
||||
if (node.has_subgraphs()) {
|
||||
const auto& subgraphs = node.get_subgraphs();
|
||||
auto inputs = node.get_ng_inputs();
|
||||
std::vector<std::shared_ptr<Function>> functions;
|
||||
for (const auto& kv : subgraphs) {
|
||||
auto& subgraph = kv.second;
|
||||
subgraph->decode();
|
||||
functions.push_back(subgraph->decode());
|
||||
for (const auto& input : subgraph->get_inputs_from_parent()) {
|
||||
const auto& name = input.get_node()->get_friendly_name();
|
||||
if (std::find_if(inputs.begin(), inputs.end(), [&name](const Output<ngraph::Node>& n) -> bool {
|
||||
@ -211,10 +212,9 @@ void Graph::decode_to_framework_nodes() {
|
||||
}
|
||||
}
|
||||
}
|
||||
framework_node =
|
||||
std::make_shared<ngraph::frontend::ONNXSubgraphFrameworkNode>(shared_from_this(), node, inputs);
|
||||
framework_node = std::make_shared<frontend::ONNXSubgraphFrameworkNode>(node, functions, inputs);
|
||||
} else {
|
||||
framework_node = std::make_shared<ngraph::frontend::ONNXFrameworkNode>(shared_from_this(), node);
|
||||
framework_node = std::make_shared<frontend::ONNXFrameworkNode>(node);
|
||||
}
|
||||
OutputVector ng_nodes{framework_node->outputs()};
|
||||
set_friendly_names(node, ng_nodes);
|
||||
@ -240,7 +240,10 @@ std::shared_ptr<Function> Graph::create_function() {
|
||||
|
||||
std::shared_ptr<Function> Graph::decode() {
|
||||
decode_to_framework_nodes();
|
||||
return create_function();
|
||||
auto function = create_function();
|
||||
auto& rt_info = function->get_rt_info();
|
||||
rt_info[ONNX_GRAPH_RT_ATTRIBUTE] = shared_from_this();
|
||||
return function;
|
||||
}
|
||||
|
||||
bool Graph::is_ng_node_in_cache(const std::string& name) const {
|
||||
@ -399,7 +402,8 @@ void Subgraph::find_inputs_from_parent() {
|
||||
for (const auto& out_name : node_proto.output()) {
|
||||
if (m_cache->contains(out_name)) {
|
||||
auto node_to_replace_input = m_cache->get_node(out_name).get_node();
|
||||
if (!dynamic_cast<op::util::MultiSubGraphOp*>(node_to_replace_input))
|
||||
if (!ov::is_type<op::util::MultiSubGraphOp>(node_to_replace_input) &&
|
||||
!ov::is_type<frontend::ONNXSubgraphFrameworkNode>(node_to_replace_input))
|
||||
continue;
|
||||
auto inputs = node_to_replace_input->input_values();
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
|
@ -121,6 +121,8 @@ inline std::ostream& operator<<(std::ostream& outs, const Graph& graph) {
|
||||
return (outs << "<Graph: " << graph.get_name() << ">");
|
||||
}
|
||||
|
||||
static const char* const ONNX_GRAPH_RT_ATTRIBUTE = "onnx_graph";
|
||||
|
||||
} // namespace onnx_import
|
||||
|
||||
} // namespace ngraph
|
||||
|
@ -21,10 +21,14 @@ namespace frontend {
|
||||
NGRAPH_RTTI_DEFINITION(ONNXFrameworkNode, "ONNXFrameworkNode", 1);
|
||||
|
||||
std::shared_ptr<Node> ONNXFrameworkNode::clone_with_new_inputs(const OutputVector& inputs) const {
|
||||
return std::make_shared<ONNXFrameworkNode>(m_graph, m_node, inputs);
|
||||
return std::make_shared<ONNXFrameworkNode>(m_node, inputs);
|
||||
}
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ONNXSubgraphFrameworkNode, "ONNXSubgraphFrameworkNode", 1);
|
||||
|
||||
std::shared_ptr<Node> ONNXSubgraphFrameworkNode::clone_with_new_inputs(const OutputVector& inputs) const {
|
||||
return std::make_shared<ONNXSubgraphFrameworkNode>(m_node, m_functions, inputs);
|
||||
}
|
||||
|
||||
} // namespace frontend
|
||||
} // namespace ngraph
|
||||
|
@ -38,20 +38,16 @@ class ONNXFrameworkNode : public ov::op::util::FrameworkNode {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
ONNXFrameworkNode(std::shared_ptr<onnx_import::Graph> graph, const onnx_import::Node& node)
|
||||
ONNXFrameworkNode(const onnx_import::Node& node)
|
||||
: ov::op::util::FrameworkNode(node.get_ng_inputs(), node.get_outputs_size()),
|
||||
m_node(node),
|
||||
m_graph(graph) {}
|
||||
m_node(node) {}
|
||||
|
||||
ONNXFrameworkNode(std::shared_ptr<onnx_import::Graph> graph,
|
||||
const onnx_import::Node& node,
|
||||
const OutputVector& inputs)
|
||||
ONNXFrameworkNode(const onnx_import::Node& node, const OutputVector& inputs)
|
||||
: ov::op::util::FrameworkNode(inputs, node.get_outputs_size()),
|
||||
m_node(node),
|
||||
m_graph(graph) {}
|
||||
m_node(node) {}
|
||||
|
||||
OutputVector get_ng_nodes() const {
|
||||
OutputVector ng_nodes{m_graph->make_ng_nodes(m_node)};
|
||||
OutputVector get_ng_nodes(const std::shared_ptr<onnx_import::Graph>& graph) const {
|
||||
OutputVector ng_nodes{graph->make_ng_nodes(m_node)};
|
||||
if (ng_nodes.size() > get_output_size()) {
|
||||
ng_nodes.resize(get_output_size());
|
||||
}
|
||||
@ -71,35 +67,31 @@ public:
|
||||
|
||||
protected:
|
||||
onnx_import::Node m_node;
|
||||
|
||||
private:
|
||||
std::shared_ptr<onnx_import::Graph> m_graph;
|
||||
};
|
||||
|
||||
class ONNXSubgraphFrameworkNode : public ONNXFrameworkNode {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
ONNXSubgraphFrameworkNode(std::shared_ptr<onnx_import::Graph> graph,
|
||||
const onnx_import::Node& node,
|
||||
ONNXSubgraphFrameworkNode(const onnx_import::Node& node,
|
||||
const std::vector<std::shared_ptr<Function>>& functions,
|
||||
const OutputVector& inputs)
|
||||
: ONNXFrameworkNode(graph, node, inputs) {}
|
||||
: ONNXFrameworkNode(node, inputs),
|
||||
m_functions(functions) {}
|
||||
|
||||
void infer_inputs_from_parent() {
|
||||
for (auto& subgraph : m_node.get_subgraphs())
|
||||
subgraph.second->infer_inputs_from_parent();
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<Function>> get_subgraph_functions() const {
|
||||
std::vector<std::shared_ptr<Function>> ret;
|
||||
for (const auto& kv : m_node.get_subgraphs()) {
|
||||
auto& subgraph = kv.second;
|
||||
ret.push_back(std::make_shared<Function>(subgraph->get_ng_outputs(),
|
||||
subgraph->get_ng_parameters(),
|
||||
subgraph->get_name()));
|
||||
}
|
||||
return ret;
|
||||
const std::vector<std::shared_ptr<Function>>& get_subgraph_functions() const {
|
||||
return m_functions;
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override;
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<Function>> m_functions;
|
||||
};
|
||||
|
||||
} // namespace frontend
|
||||
|
@ -60,6 +60,12 @@ void apply_transformations(ONNX_NAMESPACE::ModelProto& model_proto, const std::s
|
||||
} // namespace
|
||||
|
||||
void convert_decoded_function(std::shared_ptr<Function> function) {
|
||||
auto& rt_info = function->get_rt_info();
|
||||
auto it = rt_info.find(ONNX_GRAPH_RT_ATTRIBUTE);
|
||||
OPENVINO_ASSERT(it != rt_info.end(),
|
||||
"Could not find '" + std::string(ONNX_GRAPH_RT_ATTRIBUTE) +
|
||||
"' attribute in decoded model. Model probably wasn't created by FrontEnd::decode function.");
|
||||
auto onnx_graph = it->second.as<std::shared_ptr<onnx_import::Graph>>();
|
||||
for (const auto& node : function->get_ordered_ops()) {
|
||||
if (auto raw_node = std::dynamic_pointer_cast<frontend::ONNXFrameworkNode>(node)) {
|
||||
if (auto subgraph_node = std::dynamic_pointer_cast<frontend::ONNXSubgraphFrameworkNode>(node)) {
|
||||
@ -68,7 +74,7 @@ void convert_decoded_function(std::shared_ptr<Function> function) {
|
||||
convert_decoded_function(function);
|
||||
}
|
||||
}
|
||||
auto ng_nodes = raw_node->get_ng_nodes();
|
||||
auto ng_nodes = raw_node->get_ng_nodes(onnx_graph);
|
||||
replace_node(raw_node, ng_nodes);
|
||||
} else {
|
||||
// Have to revalidate node because new intpus can affect shape/type
|
||||
@ -76,6 +82,7 @@ void convert_decoded_function(std::shared_ptr<Function> function) {
|
||||
node->revalidate_and_infer_types();
|
||||
}
|
||||
}
|
||||
rt_info.erase(it);
|
||||
detail::remove_dangling_parameters(function);
|
||||
detail::remove_dangling_results(function);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user