Add get_function to ONNX Editor API (#5045)

This commit is contained in:
Mateusz Bencer 2021-04-02 14:17:39 +02:00 committed by GitHub
parent 3a6fba913c
commit 69e71b7287
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 120 additions and 89 deletions

View File

@ -22,7 +22,7 @@ add_library(ngraph::onnx_editor ALIAS ${TARGET_NAME})
# TODO Add handling ie_faster_build
target_link_libraries(${TARGET_NAME} PRIVATE onnx_common
target_link_libraries(${TARGET_NAME} PRIVATE onnx_common onnx_importer
PUBLIC ngraph)
set(ONNX_EDITOR_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include)

View File

@ -8,6 +8,7 @@
#include <map>
#include <memory>
#include "ngraph/function.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/type/element_type.hpp"
@ -86,15 +87,12 @@ namespace ngraph
void set_input_values(
const std::map<std::string, std::shared_ptr<ngraph::op::Constant>>& input_values);
/// \brief Returns a non-const reference to the underlying ModelProto object, possibly
/// modified by the editor's API calls
///
/// \return A reference to ONNX ModelProto object containing the in-memory model
ONNX_NAMESPACE::ModelProto& model() const;
/// \brief Returns a serialized ONNX model, possibly modified by the editor.
std::string model_string() const;
/// \brief Converts an edited ONNX model to an nGraph Function representation.
std::shared_ptr<Function> get_function() const;
/// \brief Returns a list of all inputs of the in-memory model, including initializers.
/// The returned value might depend on the previous operations executed on an
/// instance of the model editor, in particular the subgraph extraction which

View File

@ -11,6 +11,7 @@
#include "onnx_common/parser.hpp"
#include "onnx_common/utils.hpp"
#include "onnx_editor/editor.hpp"
#include "onnx_import/utils/onnx_internal.hpp"
using namespace ngraph;
@ -217,11 +218,6 @@ onnx_editor::ONNXModelEditor::ONNXModelEditor(const std::string& model_path)
{
}
ONNX_NAMESPACE::ModelProto& onnx_editor::ONNXModelEditor::model() const
{
return m_pimpl->m_model_proto;
}
const std::string& onnx_editor::ONNXModelEditor::model_path() const
{
return m_model_path;
@ -330,6 +326,11 @@ std::string onnx_editor::ONNXModelEditor::model_string() const
return m_pimpl->m_model_proto.SerializeAsString();
}
std::shared_ptr<Function> onnx_editor::ONNXModelEditor::get_function() const
{
return onnx_import::detail::import_onnx_model(m_pimpl->m_model_proto, m_model_path);
}
void onnx_editor::ONNXModelEditor::set_input_values(
const std::map<std::string, std::shared_ptr<ngraph::op::Constant>>& input_values)
{

View File

@ -0,0 +1,44 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <string>
#include "ngraph/function.hpp"
#include "onnx_import/utils/onnx_importer_visibility.hpp"
namespace ONNX_NAMESPACE
{
class ModelProto;
}
namespace ngraph
{
namespace onnx_import
{
namespace detail
{
/// \brief Imports and converts an serialized ONNX model from a ModelProto
/// to an nGraph Function representation.
///
/// \note The function can be used only internally by OV components!
/// Passing ModelProto between componets which use different protobuf
/// library can cause segfaults. If stream parsing fails or the ONNX model
/// contains unsupported ops, the function throws an ngraph_error exception.
///
/// \param[in] model_proto Reference to a GraphProto object.
/// \param[in] model_path The path to the imported onnx model.
/// It is required if the imported model uses data saved in
/// external files.
///
/// \return An nGraph function that represents a single output from the created
/// graph.
ONNX_IMPORTER_API
std::shared_ptr<Function> import_onnx_model(ONNX_NAMESPACE::ModelProto& model_proto,
const std::string& model_path);
} // namespace detail
} // namespace onnx_import
} // namespace ngraph

View File

@ -4,47 +4,18 @@
#include <fstream>
#include <memory>
#include <onnx/onnx_pb.h>
#include "core/graph.hpp"
#include "core/model.hpp"
#include "core/transform.hpp"
#include "ngraph/except.hpp"
#include "onnx_common/parser.hpp"
#include "onnx_import/onnx.hpp"
#include "onnx_import/utils/onnx_internal.hpp"
#include "ops_bridge.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace detail
{
std::shared_ptr<Function>
convert_to_ng_function(const ONNX_NAMESPACE::ModelProto& model_proto)
{
Model model{model_proto};
Graph graph{model_proto.graph(), model};
auto function = std::make_shared<Function>(
graph.get_ng_outputs(), graph.get_ng_parameters(), graph.get_name());
for (std::size_t i{0}; i < function->get_output_size(); ++i)
{
function->get_output_op(i)->set_friendly_name(
graph.get_outputs().at(i).get_name());
}
return function;
}
std::shared_ptr<Function> import_onnx_model(ONNX_NAMESPACE::ModelProto& model_proto,
const std::string& model_path)
{
transform::expand_onnx_functions(model_proto);
transform::fixup_legacy_operators(model_proto);
transform::update_external_data_paths(model_proto, model_path);
return detail::convert_to_ng_function(model_proto);
}
} // namespace detail
std::shared_ptr<Function> import_onnx_model(std::istream& stream,
const std::string& model_path)
{

View File

@ -0,0 +1,44 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <onnx/onnx_pb.h>
#include "core/graph.hpp"
#include "core/model.hpp"
#include "core/transform.hpp"
#include "onnx_import/utils/onnx_internal.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace detail
{
std::shared_ptr<Function>
convert_to_ng_function(const ONNX_NAMESPACE::ModelProto& model_proto)
{
Model model{model_proto};
Graph graph{model_proto.graph(), model};
auto function = std::make_shared<Function>(
graph.get_ng_outputs(), graph.get_ng_parameters(), graph.get_name());
for (std::size_t i{0}; i < function->get_output_size(); ++i)
{
function->get_output_op(i)->set_friendly_name(
graph.get_outputs().at(i).get_name());
}
return function;
}
std::shared_ptr<Function> import_onnx_model(ONNX_NAMESPACE::ModelProto& model_proto,
const std::string& model_path)
{
transform::expand_onnx_functions(model_proto);
transform::fixup_legacy_operators(model_proto);
transform::update_external_data_paths(model_proto, model_path);
return detail::convert_to_ng_function(model_proto);
}
} // namespace detail
} // namespace onnx_import
} // namespace ngraph

View File

@ -59,9 +59,7 @@ NGRAPH_TEST(onnx_editor, types__single_input_type_substitution)
editor.set_input_types({{"A", element::i64}});
std::istringstream model_stream(editor.model_string());
const auto function = onnx_import::import_onnx_model(model_stream);
const auto function = editor.get_function();
const auto graph_inputs = function->get_parameters();
const auto float_inputs_count = std::count_if(
@ -84,8 +82,7 @@ NGRAPH_TEST(onnx_editor, types__all_inputs_type_substitution)
editor.set_input_types({{"A", element::i8}, {"B", element::i8}, {"C", element::i8}});
std::istringstream model_stream(editor.model_string());
const auto function = onnx_import::import_onnx_model(model_stream);
const auto function = editor.get_function();
const auto graph_inputs = function->get_parameters();
@ -142,8 +139,7 @@ NGRAPH_TEST(onnx_editor, types__elem_type_missing_in_input)
// the "elem_type" is missing in the model but it should be possible to set the type anyway
EXPECT_NO_THROW(editor.set_input_types({{"A", element::i64}}));
std::istringstream model_stream(editor.model_string());
const auto function = onnx_import::import_onnx_model(model_stream);
const auto function = editor.get_function();
const auto graph_inputs = function->get_parameters();
@ -165,9 +161,7 @@ NGRAPH_TEST(onnx_editor, shapes__modify_single_input)
editor.set_input_shapes({{"B", new_shape}});
std::istringstream model_stream(editor.model_string());
const auto function = onnx_import::import_onnx_model(model_stream);
const auto function = editor.get_function();
const auto graph_inputs = function->get_parameters();
EXPECT_TRUE(find_input(graph_inputs, "B")->get_partial_shape().same_scheme(new_shape));
@ -182,9 +176,7 @@ NGRAPH_TEST(onnx_editor, shapes__modify_all_inputs)
editor.set_input_shapes({{"A", new_shape}, {"B", new_shape}});
std::istringstream model_stream(editor.model_string());
const auto function = onnx_import::import_onnx_model(model_stream);
const auto function = editor.get_function();
const auto graph_inputs = function->get_parameters();
for (const auto& input : graph_inputs)
@ -203,9 +195,7 @@ NGRAPH_TEST(onnx_editor, shapes__dynamic_rank_in_model)
const auto expected_shape_of_A = PartialShape{1, 2};
EXPECT_NO_THROW(editor.set_input_shapes({{"A", expected_shape_of_A}}));
std::istringstream model_stream(editor.model_string());
const auto function = onnx_import::import_onnx_model(model_stream);
const auto function = editor.get_function();
const auto graph_inputs = function->get_parameters();
EXPECT_TRUE(
@ -221,9 +211,7 @@ NGRAPH_TEST(onnx_editor, shapes__set_dynamic_dimension)
editor.set_input_shapes({{"A", new_shape}});
std::istringstream model_stream(editor.model_string());
const auto function = onnx_import::import_onnx_model(model_stream);
const auto function = editor.get_function();
const auto graph_inputs = function->get_parameters();
EXPECT_TRUE(find_input(graph_inputs, "A")->get_partial_shape().same_scheme(new_shape));
@ -239,9 +227,7 @@ NGRAPH_TEST(onnx_editor, shapes__set_mixed_dimensions)
editor.set_input_shapes({{"A", new_shape_A}, {"B", new_shape_B}});
std::istringstream model_stream(editor.model_string());
const auto function = onnx_import::import_onnx_model(model_stream);
const auto function = editor.get_function();
const auto graph_inputs = function->get_parameters();
const auto input_A = find_input(graph_inputs, "A");
@ -260,9 +246,7 @@ NGRAPH_TEST(onnx_editor, shapes__set_scalar_inputs)
editor.set_input_shapes({{"A", new_shape}, {"B", new_shape}});
std::istringstream model_stream(editor.model_string());
const auto function = onnx_import::import_onnx_model(model_stream);
const auto function = editor.get_function();
const auto graph_inputs = function->get_parameters();
const auto input_A = find_input(graph_inputs, "A");
@ -281,9 +265,7 @@ NGRAPH_TEST(onnx_editor, shapes__static_to_dynamic_rank_substitution)
editor.set_input_shapes({{"A", new_shape}, {"B", new_shape}});
std::istringstream model_stream(editor.model_string());
const auto function = onnx_import::import_onnx_model(model_stream);
const auto function = editor.get_function();
const auto graph_inputs = function->get_parameters();
for (const auto& input : graph_inputs)
@ -687,8 +669,7 @@ NGRAPH_TEST(onnx_editor, values__append_one_initializer)
in_vals.emplace("A", op::Constant::create(element::i64, Shape{2}, {1, 2}));
editor.set_input_values(in_vals);
std::istringstream model_stream(editor.model_string());
const auto function = onnx_import::import_onnx_model(model_stream);
const auto function = editor.get_function();
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<int64_t>(Shape{2}, {5, 6});
test_case.add_expected_output<int64_t>(Shape{2}, {6, 8});
@ -705,8 +686,7 @@ NGRAPH_TEST(onnx_editor, values__append_two_initializers_to_invalid)
in_vals.emplace("B", op::Constant::create(element::i64, Shape{2}, {1, 3}));
editor.set_input_values(in_vals);
std::istringstream model_stream(editor.model_string());
const auto function = onnx_import::import_onnx_model(model_stream);
const auto function = editor.get_function();
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_expected_output<int64_t>(Shape{2}, {5, 5});
test_case.run();
@ -721,8 +701,7 @@ NGRAPH_TEST(onnx_editor, values__modify_one_initializer)
in_vals.emplace("B", op::Constant::create(element::i64, Shape{2}, {3, 4}));
editor.set_input_values(in_vals);
std::istringstream model_stream(editor.model_string());
const auto function = onnx_import::import_onnx_model(model_stream);
const auto function = editor.get_function();
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_expected_output<int64_t>(Shape{2}, {4, 6});
test_case.run();
@ -738,8 +717,7 @@ NGRAPH_TEST(onnx_editor, values__modify_two_initializers)
in_vals.emplace("B", op::Constant::create(element::i64, Shape{2}, {2, 1}));
editor.set_input_values(in_vals);
std::istringstream model_stream(editor.model_string());
const auto function = onnx_import::import_onnx_model(model_stream);
const auto function = editor.get_function();
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_expected_output<int64_t>(Shape{2}, {5, 7});
test_case.run();
@ -755,8 +733,7 @@ NGRAPH_TEST(onnx_editor, values__no_inputs_modify_two_initializers)
in_vals.emplace("B", op::Constant::create(element::i64, Shape{2}, {11, 22}));
editor.set_input_values(in_vals);
std::istringstream model_stream(editor.model_string());
const auto function = onnx_import::import_onnx_model(model_stream);
const auto function = editor.get_function();
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_expected_output<int64_t>(Shape{2}, {12, 24});
test_case.run();
@ -772,8 +749,7 @@ NGRAPH_TEST(onnx_editor, values__append_two_initializers_change_shape_type)
in_vals.emplace("B", op::Constant::create(element::i8, Shape{2, 1}, {-2, 2}));
editor.set_input_values(in_vals);
std::istringstream model_stream(editor.model_string());
const auto function = onnx_import::import_onnx_model(model_stream);
const auto function = editor.get_function();
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_expected_output<int8_t>(Shape{2, 1}, {-3, 3});
test_case.run();
@ -790,8 +766,7 @@ NGRAPH_TEST(onnx_editor, values__append_two_initializers_mixed_types)
in_vals.emplace("indices", op::Constant::create(element::i32, Shape{2, 2, 1}, {0, 1, 0, 1}));
editor.set_input_values(in_vals);
std::istringstream model_stream(editor.model_string());
const auto function = onnx_import::import_onnx_model(model_stream);
const auto function = editor.get_function();
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_expected_output<int16_t>(Shape{2, 2, 1}, {1, 4, 5, 8});
test_case.run();

View File

@ -41,8 +41,7 @@ TYPED_TEST_P(ElemTypesTests, onnx_test_add_abc_set_precission)
editor.set_input_types({{"A", ng_type}, {"B", ng_type}, {"C", ng_type}});
std::istringstream model_stream(editor.model_string());
const auto function = onnx_import::import_onnx_model(model_stream);
const auto function = editor.get_function();
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<DataType>(std::vector<DataType>{1, 2, 3});
test_case.add_input<DataType>(std::vector<DataType>{4, 5, 6});
@ -61,8 +60,7 @@ TYPED_TEST_P(ElemTypesTests, onnx_test_split_multioutput_set_precission)
editor.set_input_types({{"input", ng_type}});
std::istringstream model_stream(editor.model_string());
const auto function = onnx_import::import_onnx_model(model_stream);
const auto function = editor.get_function();
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<DataType>(std::vector<DataType>{1, 2, 3, 4, 5, 6});
test_case.add_expected_output<DataType>(Shape{2}, std::vector<DataType>{1, 2});