diff --git a/inference-engine/src/inference_engine/ie_network_reader.cpp b/inference-engine/src/inference_engine/ie_network_reader.cpp index b406c225b77..41359306480 100644 --- a/inference-engine/src/inference_engine/ie_network_reader.cpp +++ b/inference-engine/src/inference_engine/ie_network_reader.cpp @@ -183,7 +183,6 @@ CNNNetwork details::ReadNetwork(const std::string& modelPath, const std::string& auto reader = it->second; // Check that reader supports the model if (reader->supportModel(modelStream)) { - modelStream.seekg(0, modelStream.beg); // Find weights std::string bPath = binPath; if (bPath.empty()) { @@ -236,7 +235,6 @@ CNNNetwork details::ReadNetwork(const std::string& model, const Blob::CPtr& weig for (auto it = readers.begin(); it != readers.end(); it++) { auto reader = it->second; if (reader->supportModel(modelStream)) { - modelStream.seekg(0, modelStream.beg); if (weights) return reader->read(modelStream, binStream, exts); return reader->read(modelStream, exts); diff --git a/inference-engine/src/readers/onnx_reader/ie_onnx_reader.cpp b/inference-engine/src/readers/onnx_reader/ie_onnx_reader.cpp index 0892239f9c3..866806131c2 100644 --- a/inference-engine/src/readers/onnx_reader/ie_onnx_reader.cpp +++ b/inference-engine/src/readers/onnx_reader/ie_onnx_reader.cpp @@ -10,17 +10,36 @@ using namespace InferenceEngine; namespace { - std::string readPathFromStream(std::istream& stream) { - if (stream.pword(0) == nullptr) { - return {}; - } - // read saved path from extensible array - return std::string{static_cast(stream.pword(0))}; +std::string readPathFromStream(std::istream& stream) { + if (stream.pword(0) == nullptr) { + return {}; } + // read saved path from extensible array + return std::string{static_cast(stream.pword(0))}; } +/** + * This helper struct uses RAII to rewind/reset the stream so that it points to the beginning + * of the underlying resource (string, file, ...). It works similarily to std::lock_guard + * which releases a mutex upon destruction. + * + * This makes sure that the stream is always reset (exception, successful and unsuccessful + * model validation). + */ +struct StreamRewinder { + StreamRewinder(std::istream& stream) : m_stream(stream) { + m_stream.seekg(0, m_stream.beg); + } + ~StreamRewinder() { + m_stream.seekg(0, m_stream.beg); + } +private: + std::istream& m_stream; +}; +} // namespace + bool ONNXReader::supportModel(std::istream& model) const { - model.seekg(0, model.beg); + StreamRewinder rwd{model}; const auto model_path = readPathFromStream(model); diff --git a/ngraph/test/CMakeLists.txt b/ngraph/test/CMakeLists.txt index 7abaff8cb15..d7d2a118022 100644 --- a/ngraph/test/CMakeLists.txt +++ b/ngraph/test/CMakeLists.txt @@ -365,13 +365,13 @@ if (NGRAPH_ONNX_IMPORT_ENABLE AND NOT NGRAPH_USE_PROTOBUF_LITE) onnx/onnx_import_convpool.in.cpp onnx/onnx_import_dyn_shapes.in.cpp onnx/onnx_import_external_data.in.cpp - onnx/onnx_import_library.in.cpp onnx/onnx_import_provenance.in.cpp onnx/onnx_import_reshape.in.cpp onnx/onnx_import_rnn.in.cpp onnx/onnx_import_quant.in.cpp) list(APPEND SRC - onnx/onnx_import_exceptions.cpp) + onnx/onnx_import_exceptions.cpp + onnx/onnx_import_library.cpp) endif() foreach(BACKEND_NAME ${ACTIVE_BACKEND_LIST}) diff --git a/ngraph/test/onnx/onnx_import_library.in.cpp b/ngraph/test/onnx/onnx_import_library.cpp similarity index 54% rename from ngraph/test/onnx/onnx_import_library.in.cpp rename to ngraph/test/onnx/onnx_import_library.cpp index 5d032b5da48..5553e1ca559 100644 --- a/ngraph/test/onnx/onnx_import_library.in.cpp +++ b/ngraph/test/onnx/onnx_import_library.cpp @@ -19,15 +19,13 @@ #include "onnx/defs/schema.h" #include "gtest/gtest.h" -#include "util/all_close.hpp" -#include "util/test_case.hpp" #include "util/test_control.hpp" using namespace ngraph; static std::string s_manifest = "${MANIFEST}"; -NGRAPH_TEST(onnx_${BACKEND_NAME}, get_function_op_with_version) +NGRAPH_TEST(onnx, get_function_op_with_version) { const auto* schema = ONNX_NAMESPACE::OpSchemaRegistry::Schema("MeanVarianceNormalization", 9, ""); @@ -36,3 +34,19 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, get_function_op_with_version) auto func = schema->GetFunction(); EXPECT_EQ(func->name(), "MeanVarianceNormalization"); } + +NGRAPH_TEST(onnx, check_ir_version_support) +{ + // It appears you've changed the ONNX library version used by nGraph. Please update the value + // tested below to make sure it equals the current IR_VERSION enum value defined in ONNX headers + // + // You should also check the onnx_reader/onnx_model_validator.cpp file and make sure that + // the details::onnx::is_correct_onnx_field() handles any new fields added in the new release + // of the ONNX library. Make sure to update the "Field" enum and the function mentioned above. + // + // The last step is to also update the details::onnx::contains_onnx_model_keys() function + // in the same file to make sure that prototxt format validation also covers the changes in ONNX + EXPECT_EQ(ONNX_NAMESPACE::Version::IR_VERSION, 6) + << "The IR_VERSION defined in ONNX does not match the version that OpenVINO supports. " + "Please check the source code of this test for details and explanation how to proceed."; +}