diff --git a/ngraph/frontend/onnx_editor/CMakeLists.txt b/ngraph/frontend/onnx_editor/CMakeLists.txt index 0714b1e02be..d893f40a4fa 100644 --- a/ngraph/frontend/onnx_editor/CMakeLists.txt +++ b/ngraph/frontend/onnx_editor/CMakeLists.txt @@ -22,7 +22,7 @@ add_library(ngraph::onnx_editor ALIAS ${TARGET_NAME}) # TODO Add handling ie_faster_build -target_link_libraries(${TARGET_NAME} PRIVATE onnx_common +target_link_libraries(${TARGET_NAME} PRIVATE onnx_common onnx_importer PUBLIC ngraph) set(ONNX_EDITOR_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include) diff --git a/ngraph/frontend/onnx_editor/include/onnx_editor/editor.hpp b/ngraph/frontend/onnx_editor/include/onnx_editor/editor.hpp index 72a6ececff3..465890cab11 100644 --- a/ngraph/frontend/onnx_editor/include/onnx_editor/editor.hpp +++ b/ngraph/frontend/onnx_editor/include/onnx_editor/editor.hpp @@ -8,6 +8,7 @@ #include #include +#include "ngraph/function.hpp" #include "ngraph/op/constant.hpp" #include "ngraph/partial_shape.hpp" #include "ngraph/type/element_type.hpp" @@ -86,15 +87,12 @@ namespace ngraph void set_input_values( const std::map>& input_values); - /// \brief Returns a non-const reference to the underlying ModelProto object, possibly - /// modified by the editor's API calls - /// - /// \return A reference to ONNX ModelProto object containing the in-memory model - ONNX_NAMESPACE::ModelProto& model() const; - /// \brief Returns a serialized ONNX model, possibly modified by the editor. std::string model_string() const; + /// \brief Converts an edited ONNX model to an nGraph Function representation. + std::shared_ptr get_function() const; + /// \brief Returns a list of all inputs of the in-memory model, including initializers. /// The returned value might depend on the previous operations executed on an /// instance of the model editor, in particular the subgraph extraction which diff --git a/ngraph/frontend/onnx_editor/src/editor.cpp b/ngraph/frontend/onnx_editor/src/editor.cpp index ad60dd6c702..d4b24300bea 100644 --- a/ngraph/frontend/onnx_editor/src/editor.cpp +++ b/ngraph/frontend/onnx_editor/src/editor.cpp @@ -11,6 +11,7 @@ #include "onnx_common/parser.hpp" #include "onnx_common/utils.hpp" #include "onnx_editor/editor.hpp" +#include "onnx_import/utils/onnx_internal.hpp" using namespace ngraph; @@ -217,11 +218,6 @@ onnx_editor::ONNXModelEditor::ONNXModelEditor(const std::string& model_path) { } -ONNX_NAMESPACE::ModelProto& onnx_editor::ONNXModelEditor::model() const -{ - return m_pimpl->m_model_proto; -} - const std::string& onnx_editor::ONNXModelEditor::model_path() const { return m_model_path; @@ -330,6 +326,11 @@ std::string onnx_editor::ONNXModelEditor::model_string() const return m_pimpl->m_model_proto.SerializeAsString(); } +std::shared_ptr onnx_editor::ONNXModelEditor::get_function() const +{ + return onnx_import::detail::import_onnx_model(m_pimpl->m_model_proto, m_model_path); +} + void onnx_editor::ONNXModelEditor::set_input_values( const std::map>& input_values) { diff --git a/ngraph/frontend/onnx_import/include/onnx_import/utils/onnx_internal.hpp b/ngraph/frontend/onnx_import/include/onnx_import/utils/onnx_internal.hpp new file mode 100644 index 00000000000..58554bd3c99 --- /dev/null +++ b/ngraph/frontend/onnx_import/include/onnx_import/utils/onnx_internal.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "ngraph/function.hpp" +#include "onnx_import/utils/onnx_importer_visibility.hpp" + +namespace ONNX_NAMESPACE +{ + class ModelProto; +} + +namespace ngraph +{ + namespace onnx_import + { + namespace detail + { + /// \brief Imports and converts an serialized ONNX model from a ModelProto + /// to an nGraph Function representation. + /// + /// \note The function can be used only internally by OV components! + /// Passing ModelProto between componets which use different protobuf + /// library can cause segfaults. If stream parsing fails or the ONNX model + /// contains unsupported ops, the function throws an ngraph_error exception. + /// + /// \param[in] model_proto Reference to a GraphProto object. + /// \param[in] model_path The path to the imported onnx model. + /// It is required if the imported model uses data saved in + /// external files. + /// + /// \return An nGraph function that represents a single output from the created + /// graph. + ONNX_IMPORTER_API + std::shared_ptr import_onnx_model(ONNX_NAMESPACE::ModelProto& model_proto, + const std::string& model_path); + } // namespace detail + } // namespace onnx_import +} // namespace ngraph diff --git a/ngraph/frontend/onnx_import/src/onnx.cpp b/ngraph/frontend/onnx_import/src/onnx.cpp index 3fbd799ab41..09f6623611d 100644 --- a/ngraph/frontend/onnx_import/src/onnx.cpp +++ b/ngraph/frontend/onnx_import/src/onnx.cpp @@ -4,47 +4,18 @@ #include #include +#include -#include "core/graph.hpp" -#include "core/model.hpp" -#include "core/transform.hpp" #include "ngraph/except.hpp" #include "onnx_common/parser.hpp" #include "onnx_import/onnx.hpp" +#include "onnx_import/utils/onnx_internal.hpp" #include "ops_bridge.hpp" namespace ngraph { namespace onnx_import { - namespace detail - { - std::shared_ptr - convert_to_ng_function(const ONNX_NAMESPACE::ModelProto& model_proto) - { - Model model{model_proto}; - Graph graph{model_proto.graph(), model}; - auto function = std::make_shared( - graph.get_ng_outputs(), graph.get_ng_parameters(), graph.get_name()); - for (std::size_t i{0}; i < function->get_output_size(); ++i) - { - function->get_output_op(i)->set_friendly_name( - graph.get_outputs().at(i).get_name()); - } - return function; - } - - std::shared_ptr import_onnx_model(ONNX_NAMESPACE::ModelProto& model_proto, - const std::string& model_path) - { - transform::expand_onnx_functions(model_proto); - transform::fixup_legacy_operators(model_proto); - transform::update_external_data_paths(model_proto, model_path); - - return detail::convert_to_ng_function(model_proto); - } - } // namespace detail - std::shared_ptr import_onnx_model(std::istream& stream, const std::string& model_path) { diff --git a/ngraph/frontend/onnx_import/src/utils/onnx_internal.cpp b/ngraph/frontend/onnx_import/src/utils/onnx_internal.cpp new file mode 100644 index 00000000000..00544c1fabf --- /dev/null +++ b/ngraph/frontend/onnx_import/src/utils/onnx_internal.cpp @@ -0,0 +1,44 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "core/graph.hpp" +#include "core/model.hpp" +#include "core/transform.hpp" +#include "onnx_import/utils/onnx_internal.hpp" + +namespace ngraph +{ + namespace onnx_import + { + namespace detail + { + std::shared_ptr + convert_to_ng_function(const ONNX_NAMESPACE::ModelProto& model_proto) + { + Model model{model_proto}; + Graph graph{model_proto.graph(), model}; + auto function = std::make_shared( + graph.get_ng_outputs(), graph.get_ng_parameters(), graph.get_name()); + for (std::size_t i{0}; i < function->get_output_size(); ++i) + { + function->get_output_op(i)->set_friendly_name( + graph.get_outputs().at(i).get_name()); + } + return function; + } + + std::shared_ptr import_onnx_model(ONNX_NAMESPACE::ModelProto& model_proto, + const std::string& model_path) + { + transform::expand_onnx_functions(model_proto); + transform::fixup_legacy_operators(model_proto); + transform::update_external_data_paths(model_proto, model_path); + + return detail::convert_to_ng_function(model_proto); + } + } // namespace detail + } // namespace onnx_import +} // namespace ngraph diff --git a/ngraph/test/onnx/onnx_editor.cpp b/ngraph/test/onnx/onnx_editor.cpp index 898d5e449b1..defcf6745ba 100644 --- a/ngraph/test/onnx/onnx_editor.cpp +++ b/ngraph/test/onnx/onnx_editor.cpp @@ -59,9 +59,7 @@ NGRAPH_TEST(onnx_editor, types__single_input_type_substitution) editor.set_input_types({{"A", element::i64}}); - std::istringstream model_stream(editor.model_string()); - const auto function = onnx_import::import_onnx_model(model_stream); - + const auto function = editor.get_function(); const auto graph_inputs = function->get_parameters(); const auto float_inputs_count = std::count_if( @@ -84,8 +82,7 @@ NGRAPH_TEST(onnx_editor, types__all_inputs_type_substitution) editor.set_input_types({{"A", element::i8}, {"B", element::i8}, {"C", element::i8}}); - std::istringstream model_stream(editor.model_string()); - const auto function = onnx_import::import_onnx_model(model_stream); + const auto function = editor.get_function(); const auto graph_inputs = function->get_parameters(); @@ -142,8 +139,7 @@ NGRAPH_TEST(onnx_editor, types__elem_type_missing_in_input) // the "elem_type" is missing in the model but it should be possible to set the type anyway EXPECT_NO_THROW(editor.set_input_types({{"A", element::i64}})); - std::istringstream model_stream(editor.model_string()); - const auto function = onnx_import::import_onnx_model(model_stream); + const auto function = editor.get_function(); const auto graph_inputs = function->get_parameters(); @@ -165,9 +161,7 @@ NGRAPH_TEST(onnx_editor, shapes__modify_single_input) editor.set_input_shapes({{"B", new_shape}}); - std::istringstream model_stream(editor.model_string()); - const auto function = onnx_import::import_onnx_model(model_stream); - + const auto function = editor.get_function(); const auto graph_inputs = function->get_parameters(); EXPECT_TRUE(find_input(graph_inputs, "B")->get_partial_shape().same_scheme(new_shape)); @@ -182,9 +176,7 @@ NGRAPH_TEST(onnx_editor, shapes__modify_all_inputs) editor.set_input_shapes({{"A", new_shape}, {"B", new_shape}}); - std::istringstream model_stream(editor.model_string()); - const auto function = onnx_import::import_onnx_model(model_stream); - + const auto function = editor.get_function(); const auto graph_inputs = function->get_parameters(); for (const auto& input : graph_inputs) @@ -203,9 +195,7 @@ NGRAPH_TEST(onnx_editor, shapes__dynamic_rank_in_model) const auto expected_shape_of_A = PartialShape{1, 2}; EXPECT_NO_THROW(editor.set_input_shapes({{"A", expected_shape_of_A}})); - std::istringstream model_stream(editor.model_string()); - const auto function = onnx_import::import_onnx_model(model_stream); - + const auto function = editor.get_function(); const auto graph_inputs = function->get_parameters(); EXPECT_TRUE( @@ -221,9 +211,7 @@ NGRAPH_TEST(onnx_editor, shapes__set_dynamic_dimension) editor.set_input_shapes({{"A", new_shape}}); - std::istringstream model_stream(editor.model_string()); - const auto function = onnx_import::import_onnx_model(model_stream); - + const auto function = editor.get_function(); const auto graph_inputs = function->get_parameters(); EXPECT_TRUE(find_input(graph_inputs, "A")->get_partial_shape().same_scheme(new_shape)); @@ -239,9 +227,7 @@ NGRAPH_TEST(onnx_editor, shapes__set_mixed_dimensions) editor.set_input_shapes({{"A", new_shape_A}, {"B", new_shape_B}}); - std::istringstream model_stream(editor.model_string()); - const auto function = onnx_import::import_onnx_model(model_stream); - + const auto function = editor.get_function(); const auto graph_inputs = function->get_parameters(); const auto input_A = find_input(graph_inputs, "A"); @@ -260,9 +246,7 @@ NGRAPH_TEST(onnx_editor, shapes__set_scalar_inputs) editor.set_input_shapes({{"A", new_shape}, {"B", new_shape}}); - std::istringstream model_stream(editor.model_string()); - const auto function = onnx_import::import_onnx_model(model_stream); - + const auto function = editor.get_function(); const auto graph_inputs = function->get_parameters(); const auto input_A = find_input(graph_inputs, "A"); @@ -281,9 +265,7 @@ NGRAPH_TEST(onnx_editor, shapes__static_to_dynamic_rank_substitution) editor.set_input_shapes({{"A", new_shape}, {"B", new_shape}}); - std::istringstream model_stream(editor.model_string()); - const auto function = onnx_import::import_onnx_model(model_stream); - + const auto function = editor.get_function(); const auto graph_inputs = function->get_parameters(); for (const auto& input : graph_inputs) @@ -687,8 +669,7 @@ NGRAPH_TEST(onnx_editor, values__append_one_initializer) in_vals.emplace("A", op::Constant::create(element::i64, Shape{2}, {1, 2})); editor.set_input_values(in_vals); - std::istringstream model_stream(editor.model_string()); - const auto function = onnx_import::import_onnx_model(model_stream); + const auto function = editor.get_function(); auto test_case = test::TestCase(function); test_case.add_input(Shape{2}, {5, 6}); test_case.add_expected_output(Shape{2}, {6, 8}); @@ -705,8 +686,7 @@ NGRAPH_TEST(onnx_editor, values__append_two_initializers_to_invalid) in_vals.emplace("B", op::Constant::create(element::i64, Shape{2}, {1, 3})); editor.set_input_values(in_vals); - std::istringstream model_stream(editor.model_string()); - const auto function = onnx_import::import_onnx_model(model_stream); + const auto function = editor.get_function(); auto test_case = test::TestCase(function); test_case.add_expected_output(Shape{2}, {5, 5}); test_case.run(); @@ -721,8 +701,7 @@ NGRAPH_TEST(onnx_editor, values__modify_one_initializer) in_vals.emplace("B", op::Constant::create(element::i64, Shape{2}, {3, 4})); editor.set_input_values(in_vals); - std::istringstream model_stream(editor.model_string()); - const auto function = onnx_import::import_onnx_model(model_stream); + const auto function = editor.get_function(); auto test_case = test::TestCase(function); test_case.add_expected_output(Shape{2}, {4, 6}); test_case.run(); @@ -738,8 +717,7 @@ NGRAPH_TEST(onnx_editor, values__modify_two_initializers) in_vals.emplace("B", op::Constant::create(element::i64, Shape{2}, {2, 1})); editor.set_input_values(in_vals); - std::istringstream model_stream(editor.model_string()); - const auto function = onnx_import::import_onnx_model(model_stream); + const auto function = editor.get_function(); auto test_case = test::TestCase(function); test_case.add_expected_output(Shape{2}, {5, 7}); test_case.run(); @@ -755,8 +733,7 @@ NGRAPH_TEST(onnx_editor, values__no_inputs_modify_two_initializers) in_vals.emplace("B", op::Constant::create(element::i64, Shape{2}, {11, 22})); editor.set_input_values(in_vals); - std::istringstream model_stream(editor.model_string()); - const auto function = onnx_import::import_onnx_model(model_stream); + const auto function = editor.get_function(); auto test_case = test::TestCase(function); test_case.add_expected_output(Shape{2}, {12, 24}); test_case.run(); @@ -772,8 +749,7 @@ NGRAPH_TEST(onnx_editor, values__append_two_initializers_change_shape_type) in_vals.emplace("B", op::Constant::create(element::i8, Shape{2, 1}, {-2, 2})); editor.set_input_values(in_vals); - std::istringstream model_stream(editor.model_string()); - const auto function = onnx_import::import_onnx_model(model_stream); + const auto function = editor.get_function(); auto test_case = test::TestCase(function); test_case.add_expected_output(Shape{2, 1}, {-3, 3}); test_case.run(); @@ -790,8 +766,7 @@ NGRAPH_TEST(onnx_editor, values__append_two_initializers_mixed_types) in_vals.emplace("indices", op::Constant::create(element::i32, Shape{2, 2, 1}, {0, 1, 0, 1})); editor.set_input_values(in_vals); - std::istringstream model_stream(editor.model_string()); - const auto function = onnx_import::import_onnx_model(model_stream); + const auto function = editor.get_function(); auto test_case = test::TestCase(function); test_case.add_expected_output(Shape{2, 2, 1}, {1, 4, 5, 8}); test_case.run(); diff --git a/ngraph/test/onnx/onnx_test_utils.in.cpp b/ngraph/test/onnx/onnx_test_utils.in.cpp index ede1003fb11..3e14360ae73 100644 --- a/ngraph/test/onnx/onnx_test_utils.in.cpp +++ b/ngraph/test/onnx/onnx_test_utils.in.cpp @@ -41,8 +41,7 @@ TYPED_TEST_P(ElemTypesTests, onnx_test_add_abc_set_precission) editor.set_input_types({{"A", ng_type}, {"B", ng_type}, {"C", ng_type}}); - std::istringstream model_stream(editor.model_string()); - const auto function = onnx_import::import_onnx_model(model_stream); + const auto function = editor.get_function(); auto test_case = test::TestCase(function); test_case.add_input(std::vector{1, 2, 3}); test_case.add_input(std::vector{4, 5, 6}); @@ -61,8 +60,7 @@ TYPED_TEST_P(ElemTypesTests, onnx_test_split_multioutput_set_precission) editor.set_input_types({{"input", ng_type}}); - std::istringstream model_stream(editor.model_string()); - const auto function = onnx_import::import_onnx_model(model_stream); + const auto function = editor.get_function(); auto test_case = test::TestCase(function); test_case.add_input(std::vector{1, 2, 3, 4, 5, 6}); test_case.add_expected_output(Shape{2}, std::vector{1, 2});