[ONNX] remove dangling Parameters from graph (#3774)

We have models whose inputs have no initializer and are not connected to any other node; such dangling inputs should be removed from the graph.
This commit is contained in:
Mateusz Tabaka 2021-01-12 14:12:57 +01:00 committed by GitHub
parent 9ab44f8d5c
commit 9996994b8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 99 additions and 3 deletions

View File

@ -43,12 +43,17 @@ namespace ngraph
public:
/// \brief Add node to the cache or override the existing one.
///
/// \note GraphCahce takes ownership of the node.
/// \note GraphCache takes ownership of the node.
///
/// \param[in] name The name of the node added to the cache.
/// \param[in] node The node added to the cache.
void emplace_node(const std::string& name, Output<ngraph::Node>&& node);
/// \brief Remove node from the cache
///
/// \param[in] name The name of the node to be removed.
void remove_node(const std::string& name);
/// \brief Get the node from the cache
///
/// \note If the node is not found the ngraph_error exception is thrown.

View File

@ -66,6 +66,25 @@ namespace ngraph
Graph::Graph(const ONNX_NAMESPACE::GraphProto& graph_proto, Model& model)
    : Graph(graph_proto, model, std::unique_ptr<GraphCache>(new GraphCache()))
{
    // Remove dangling Parameters: graph inputs that are consumed by no node.
    // An input that is also listed among the graph outputs (i.e. passed
    // straight through) must be preserved even though nothing consumes it.
    for (auto param_it = m_parameters.begin(); param_it != m_parameters.end();)
    {
        const auto& name = (*param_it)->get_friendly_name();
        // The code only ever inspects output index 0 of a Parameter; an empty
        // target-input set there means nothing in the graph consumes it.
        // Thanks to short-circuiting, the std::any_of scan over m_outputs runs
        // only for Parameters that are actually dangling.
        const bool keep =
            !(*param_it)->get_output_target_inputs(0).empty() ||
            std::any_of(m_outputs.begin(), m_outputs.end(), [&name](const ValueInfo& info) {
                return info.get_name() == name;
            });
        if (keep)
        {
            ++param_it;
        }
        else
        {
            // Drop both the cached node entry and the Parameter itself.
            m_cache->remove_node(name);
            param_it = m_parameters.erase(param_it);
        }
    }
}
Graph::Graph(const ONNX_NAMESPACE::GraphProto& graph_proto,

View File

@ -26,6 +26,15 @@ namespace ngraph
m_graph_cache_map[name] = std::move(node);
}
void GraphCache::remove_node(const std::string& name)
{
    // Erase-by-key is a no-op when the key is absent, so the explicit
    // find-then-erase dance is unnecessary.
    m_graph_cache_map.erase(name);
}
Output<ngraph::Node> GraphCache::get_node(const std::string& name) const
{
try

View File

@ -0,0 +1,53 @@
ir_version: 4
producer_name: "pytorch"
producer_version: "1.2"
graph {
node {
input: "X"
output: "Y"
op_type: "Abs"
}
name: "torch-jit-export"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
}
}
}
}
input {
name: "unused_bool"
type {
tensor_type {
elem_type: 9
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@ -765,7 +765,6 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softmax_0D)
file_util::path_join(SERIALIZED_ZOO, "onnx/softmax_0D.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>({3.141592});
test_case.add_expected_output<float>({1.0});
test_case.run();
}
@ -2597,7 +2596,6 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_eye_like)
file_util::path_join(SERIALIZED_ZOO, "onnx/eye_like.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
test_case.add_expected_output<float>(
Shape{3, 4}, {0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f});
@ -3287,3 +3285,15 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_hard_sigmoid)
test_case.add_expected_output<float>(Shape{4}, {1.0f, 0.0f, 0.5f, 0.699999988079071f});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dangling_parameter)
{
    // The model declares an extra input ("unused_bool") that no node consumes;
    // the importer is expected to drop it, so the imported function is driven
    // by the single remaining input "X".
    const auto model = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/dangling_parameter.prototxt"));
    auto test = test::TestCase<TestEngine>(model);
    test.add_input<float>({-1.0f, 2.0f, -3.0f});
    // The graph computes Abs(X).
    test.add_expected_output<float>(Shape{3}, {1.0f, 2.0f, 3.0f});
    test.run();
}