diff --git a/ngraph/frontend/onnx_import/include/onnx_import/core/graph_cache.hpp b/ngraph/frontend/onnx_import/include/onnx_import/core/graph_cache.hpp index 07ed67acb42..c34cbf46ae2 100644 --- a/ngraph/frontend/onnx_import/include/onnx_import/core/graph_cache.hpp +++ b/ngraph/frontend/onnx_import/include/onnx_import/core/graph_cache.hpp @@ -43,12 +43,17 @@ namespace ngraph public: /// \brief Add node to the cache or override the existing one. /// - /// \note GraphCahce takes ownership of the node. + /// \note GraphCache takes ownership of the node. /// /// \param[in] name The name of node added to the cache. /// \param[in] node The node added to the cache. void emplace_node(const std::string& name, Output&& node); + /// \brief Remove node from the cache + /// + /// \param[in] name The name of node to be removed + void remove_node(const std::string& name); + /// \brief Get the node from the cache /// /// \note If the node is not found the ngraph_error exception is thrown. diff --git a/ngraph/frontend/onnx_import/src/core/graph.cpp b/ngraph/frontend/onnx_import/src/core/graph.cpp index 605fdfba27c..463747ad8c2 100644 --- a/ngraph/frontend/onnx_import/src/core/graph.cpp +++ b/ngraph/frontend/onnx_import/src/core/graph.cpp @@ -66,6 +66,25 @@ namespace ngraph Graph::Graph(const ONNX_NAMESPACE::GraphProto& graph_proto, Model& model) : Graph(graph_proto, model, std::unique_ptr(new GraphCache())) { + // Remove dangling Parameters + for (auto param_it = m_parameters.begin(); param_it != m_parameters.end();) + { + if ((*param_it)->get_output_target_inputs(0).size() == 0) + { + const auto& name = (*param_it)->get_friendly_name(); + auto out_it = std::find_if( + m_outputs.begin(), m_outputs.end(), [&name](const ValueInfo& info) { + return info.get_name() == name; + }); + if (out_it == m_outputs.end()) + { + m_cache->remove_node(name); + param_it = m_parameters.erase(param_it); + continue; + } + } + param_it++; + } } Graph::Graph(const ONNX_NAMESPACE::GraphProto& graph_proto, diff --git a/ngraph/frontend/onnx_import/src/core/graph_cache.cpp b/ngraph/frontend/onnx_import/src/core/graph_cache.cpp index 2155bb0e01d..925a91302bf 100644 --- a/ngraph/frontend/onnx_import/src/core/graph_cache.cpp +++ b/ngraph/frontend/onnx_import/src/core/graph_cache.cpp @@ -26,6 +26,15 @@ namespace ngraph m_graph_cache_map[name] = std::move(node); } + void GraphCache::remove_node(const std::string& name) + { + auto it = m_graph_cache_map.find(name); + if (it != m_graph_cache_map.end()) + { + m_graph_cache_map.erase(it); + } + } + Output GraphCache::get_node(const std::string& name) const { try diff --git a/ngraph/test/models/onnx/dangling_parameter.prototxt b/ngraph/test/models/onnx/dangling_parameter.prototxt new file mode 100644 index 00000000000..1e07a93fef8 --- /dev/null +++ b/ngraph/test/models/onnx/dangling_parameter.prototxt @@ -0,0 +1,53 @@ +ir_version: 4 +producer_name: "pytorch" +producer_version: "1.2" +graph { + node { + input: "X" + output: "Y" + op_type: "Abs" + } + name: "torch-jit-export" + input { + name: "X" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + } + } + } + } + input { + name: "unused_bool" + type { + tensor_type { + elem_type: 9 + shape { + dim { + dim_value: 1 + } + } + } + } + } + output { + name: "Y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + } + } + } + } +} +opset_import { + version: 9 +} diff --git a/ngraph/test/onnx/onnx_import.in.cpp b/ngraph/test/onnx/onnx_import.in.cpp index 28eab8c7ac6..267ef98233a 100644 --- a/ngraph/test/onnx/onnx_import.in.cpp +++ b/ngraph/test/onnx/onnx_import.in.cpp @@ -765,7 +765,6 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softmax_0D) file_util::path_join(SERIALIZED_ZOO, "onnx/softmax_0D.prototxt")); auto test_case = test::TestCase(function); - test_case.add_input({3.141592}); test_case.add_expected_output({1.0}); test_case.run(); } @@ -2597,7 +2596,6 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_eye_like) file_util::path_join(SERIALIZED_ZOO, "onnx/eye_like.prototxt")); auto test_case = test::TestCase(function); - test_case.add_input({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); test_case.add_expected_output( Shape{3, 4}, {0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f}); @@ -3287,3 +3285,15 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_hard_sigmoid) test_case.add_expected_output(Shape{4}, {1.0f, 0.0f, 0.5f, 0.699999988079071f}); test_case.run(); } + +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dangling_parameter) +{ + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/dangling_parameter.prototxt")); + + auto test_case = test::TestCase(function); + + test_case.add_input({-1.0f, 2.0f, -3.0f}); + test_case.add_expected_output(Shape{3}, {1.0f, 2.0f, 3.0f}); + test_case.run(); +}