Throw special exception if IR v7 is passed, but no IR v7 reader (#1293)

This commit is contained in:
Ilya Lavrenov 2020-07-13 06:13:59 +03:00 committed by GitHub
parent bce6ca07df
commit 71a7e913d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 74 additions and 34 deletions

View File

@ -119,15 +119,11 @@ add_library(${TARGET_NAME}_obj OBJECT
target_compile_definitions(${TARGET_NAME}_obj PRIVATE IMPLEMENT_INFERENCE_ENGINE_API)
# TODO: Remove this definitios when readers will be loaded from xml
if(NGRAPH_ONNX_IMPORT_ENABLE)
target_compile_definitions(${TARGET_NAME}_obj PRIVATE ONNX_IMPORT_ENABLE)
endif()
target_include_directories(${TARGET_NAME}_obj SYSTEM PRIVATE $<TARGET_PROPERTY:ngraph::ngraph,INTERFACE_INCLUDE_DIRECTORIES>
$<TARGET_PROPERTY:pugixml,INTERFACE_INCLUDE_DIRECTORIES>)
target_include_directories(${TARGET_NAME}_obj PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}"
"${IE_MAIN_SOURCE_DIR}/src/readers/ir_reader" # for ie_ir_version.hpp
$<TARGET_PROPERTY:${TARGET_NAME}_reader_api,INTERFACE_INCLUDE_DIRECTORIES>
$<TARGET_PROPERTY:${TARGET_NAME}_plugin_api,INTERFACE_INCLUDE_DIRECTORIES>)

View File

@ -9,6 +9,7 @@
#include <ie_blob_stream.hpp>
#include <ie_profiling.hpp>
#include <ie_reader.hpp>
#include <ie_ir_version.hpp>
#include <fstream>
#include <istream>
@ -132,6 +133,23 @@ void registerReaders() {
initialized = true;
}
void assertIfIRv7LikeModel(std::istream & modelStream) {
auto irVersion = details::GetIRVersion(modelStream);
bool isIRv7 = irVersion > 1 && irVersion <= 7;
if (!isIRv7)
return;
for (auto && kvp : readers) {
Reader::Ptr reader = kvp.second;
if (reader->getName() == "IRv7") {
return;
}
}
THROW_IE_EXCEPTION << "IR v" << irVersion << " is deprecated. Please, migrate to IR v10 version";
}
} // namespace
CNNNetwork details::ReadNetwork(const std::string& modelPath, const std::string& binPath, const std::vector<IExtensionPtr>& exts) {
@ -150,6 +168,8 @@ CNNNetwork details::ReadNetwork(const std::string& modelPath, const std::string&
if (!modelStream.is_open())
THROW_IE_EXCEPTION << "Model file " << modelPath << " cannot be opened!";
assertIfIRv7LikeModel(modelStream);
// Find reader for model extension
auto fileExt = modelPath.substr(modelPath.find_last_of(".") + 1);
for (auto it = readers.lower_bound(fileExt); it != readers.upper_bound(fileExt); it++) {
@ -203,6 +223,8 @@ CNNNetwork details::ReadNetwork(const std::string& model, const Blob::CPtr& weig
std::istringstream modelStream(model);
details::BlobStream binStream(weights);
assertIfIRv7LikeModel(modelStream);
for (auto it = readers.begin(); it != readers.end(); it++) {
auto reader = it->second;
if (reader->supportModel(modelStream)) {

View File

@ -5,6 +5,7 @@
#include <file_utils.h>
#include <xml_parse_utils.h>
#include <ie_ir_version.hpp>
#include <ie_ir_reader.hpp>
#include <memory>
#include <ngraph/ngraph.hpp>
@ -20,39 +21,14 @@
using namespace InferenceEngine;
static size_t GetIRVersion(pugi::xml_node& root) {
return XMLParseUtils::GetUIntAttr(root, "version", 0);
}
bool IRReader::supportModel(std::istream& model) const {
model.seekg(0, model.beg);
const int header_size = 128;
std::string header(header_size, ' ');
model.read(&header[0], header_size);
model.seekg(0, model.beg);
auto version = details::GetIRVersion(model);
pugi::xml_document doc;
auto res = doc.load_string(header.c_str(), pugi::parse_default | pugi::parse_fragment);
bool supports = false;
if (res == pugi::status_ok) {
pugi::xml_node root = doc.document_element();
std::string node_name = root.name();
std::transform(node_name.begin(), node_name.end(), node_name.begin(), ::tolower);
if (node_name == "net") {
size_t const version = GetIRVersion(root);
#ifdef IR_READER_V10
supports = version == 10;
return version == 10;
#else
supports = version < 10;
return version > 1 && version <= 7;
#endif
}
}
return supports;
}
CNNNetwork IRReader::read(std::istream& model, const std::vector<IExtensionPtr>& exts) const {
@ -68,7 +44,7 @@ CNNNetwork IRReader::read(std::istream& model, std::istream& weights, const std:
}
pugi::xml_node root = xmlDoc.document_element();
auto version = GetIRVersion(root);
auto version = details::GetIRVersion(root);
IRParser parser(version, exts);
return CNNNetwork(parser.parse(root, weights));
}

View File

@ -0,0 +1,46 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <fstream>
#include <xml_parse_utils.h>
namespace InferenceEngine {
namespace details {
inline size_t GetIRVersion(pugi::xml_node& root) {
return XMLParseUtils::GetUIntAttr(root, "version", 0);
}
/**
* @brief Extracts IR version from model stream
* @param model Models stream
* @return IR version, 0 if model does represent IR
*/
size_t GetIRVersion(std::istream& model) {
model.seekg(0, model.beg);
const int header_size = 128;
std::string header(header_size, ' ');
model.read(&header[0], header_size);
model.seekg(0, model.beg);
pugi::xml_document doc;
auto res = doc.load_string(header.c_str(), pugi::parse_default | pugi::parse_fragment);
if (res == pugi::status_ok) {
pugi::xml_node root = doc.document_element();
std::string node_name = root.name();
std::transform(node_name.begin(), node_name.end(), node_name.begin(), ::tolower);
if (node_name == "net") {
return GetIRVersion(root);
}
}
return 0;
}
} // namespace details
} // namespace InferenceEngine