[ONNX] remove dangling Parameters from graph (#3774)
We have models with inputs with no initializer and not connected to any other node.
This commit is contained in:
parent
9ab44f8d5c
commit
9996994b8d
@ -43,12 +43,17 @@ namespace ngraph
|
||||
public:
|
||||
/// \brief Add node to the cache or override the existing one.
|
||||
///
|
||||
/// \note GraphCahce takes ownership of the node.
|
||||
/// \note GraphCache takes ownership of the node.
|
||||
///
|
||||
/// \param[in] name The name of node added to the cache.
|
||||
/// \param[in] node The node added to the cache.
|
||||
void emplace_node(const std::string& name, Output<ngraph::Node>&& node);
|
||||
|
||||
/// \brief Remove node from the cache
|
||||
///
|
||||
/// \param[in] name The name of node to be removed
|
||||
void remove_node(const std::string& name);
|
||||
|
||||
/// \brief Get the node from the cache
|
||||
///
|
||||
/// \note If the node is not found the ngraph_error exception is thrown.
|
||||
|
@ -66,6 +66,25 @@ namespace ngraph
|
||||
Graph::Graph(const ONNX_NAMESPACE::GraphProto& graph_proto, Model& model)
|
||||
: Graph(graph_proto, model, std::unique_ptr<GraphCache>(new GraphCache()))
|
||||
{
|
||||
// Remove dangling Parameters
|
||||
for (auto param_it = m_parameters.begin(); param_it != m_parameters.end();)
|
||||
{
|
||||
if ((*param_it)->get_output_target_inputs(0).size() == 0)
|
||||
{
|
||||
const auto& name = (*param_it)->get_friendly_name();
|
||||
auto out_it = std::find_if(
|
||||
m_outputs.begin(), m_outputs.end(), [&name](const ValueInfo& info) {
|
||||
return info.get_name() == name;
|
||||
});
|
||||
if (out_it == m_outputs.end())
|
||||
{
|
||||
m_cache->remove_node(name);
|
||||
param_it = m_parameters.erase(param_it);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
param_it++;
|
||||
}
|
||||
}
|
||||
|
||||
Graph::Graph(const ONNX_NAMESPACE::GraphProto& graph_proto,
|
||||
|
@ -26,6 +26,15 @@ namespace ngraph
|
||||
m_graph_cache_map[name] = std::move(node);
|
||||
}
|
||||
|
||||
void GraphCache::remove_node(const std::string& name)
|
||||
{
|
||||
auto it = m_graph_cache_map.find(name);
|
||||
if (it != m_graph_cache_map.end())
|
||||
{
|
||||
m_graph_cache_map.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
Output<ngraph::Node> GraphCache::get_node(const std::string& name) const
|
||||
{
|
||||
try
|
||||
|
53
ngraph/test/models/onnx/dangling_parameter.prototxt
Normal file
53
ngraph/test/models/onnx/dangling_parameter.prototxt
Normal file
@ -0,0 +1,53 @@
|
||||
ir_version: 4
|
||||
producer_name: "pytorch"
|
||||
producer_version: "1.2"
|
||||
graph {
|
||||
node {
|
||||
input: "X"
|
||||
output: "Y"
|
||||
op_type: "Abs"
|
||||
}
|
||||
name: "torch-jit-export"
|
||||
input {
|
||||
name: "X"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "unused_bool"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "Y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 9
|
||||
}
|
@ -765,7 +765,6 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_softmax_0D)
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/softmax_0D.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
test_case.add_input<float>({3.141592});
|
||||
test_case.add_expected_output<float>({1.0});
|
||||
test_case.run();
|
||||
}
|
||||
@ -2597,7 +2596,6 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_eye_like)
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/eye_like.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
test_case.add_input<float>({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
|
||||
test_case.add_expected_output<float>(
|
||||
Shape{3, 4}, {0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f});
|
||||
|
||||
@ -3287,3 +3285,15 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_hard_sigmoid)
|
||||
test_case.add_expected_output<float>(Shape{4}, {1.0f, 0.0f, 0.5f, 0.699999988079071f});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dangling_parameter)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/dangling_parameter.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
||||
test_case.add_input<float>({-1.0f, 2.0f, -3.0f});
|
||||
test_case.add_expected_output<float>(Shape{3}, {1.0f, 2.0f, 3.0f});
|
||||
test_case.run();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user