Throw special exception if IR v7 is passed, but no IR v7 reader (#1293)
This commit is contained in:
parent
bce6ca07df
commit
71a7e913d1
@ -119,15 +119,11 @@ add_library(${TARGET_NAME}_obj OBJECT
|
||||
|
||||
target_compile_definitions(${TARGET_NAME}_obj PRIVATE IMPLEMENT_INFERENCE_ENGINE_API)
|
||||
|
||||
# TODO: Remove this definitios when readers will be loaded from xml
|
||||
if(NGRAPH_ONNX_IMPORT_ENABLE)
|
||||
target_compile_definitions(${TARGET_NAME}_obj PRIVATE ONNX_IMPORT_ENABLE)
|
||||
endif()
|
||||
|
||||
target_include_directories(${TARGET_NAME}_obj SYSTEM PRIVATE $<TARGET_PROPERTY:ngraph::ngraph,INTERFACE_INCLUDE_DIRECTORIES>
|
||||
$<TARGET_PROPERTY:pugixml,INTERFACE_INCLUDE_DIRECTORIES>)
|
||||
|
||||
target_include_directories(${TARGET_NAME}_obj PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}"
|
||||
"${IE_MAIN_SOURCE_DIR}/src/readers/ir_reader" # for ie_ir_version.hpp
|
||||
$<TARGET_PROPERTY:${TARGET_NAME}_reader_api,INTERFACE_INCLUDE_DIRECTORIES>
|
||||
$<TARGET_PROPERTY:${TARGET_NAME}_plugin_api,INTERFACE_INCLUDE_DIRECTORIES>)
|
||||
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include <ie_blob_stream.hpp>
|
||||
#include <ie_profiling.hpp>
|
||||
#include <ie_reader.hpp>
|
||||
#include <ie_ir_version.hpp>
|
||||
|
||||
#include <fstream>
|
||||
#include <istream>
|
||||
@ -132,6 +133,23 @@ void registerReaders() {
|
||||
initialized = true;
|
||||
}
|
||||
|
||||
void assertIfIRv7LikeModel(std::istream & modelStream) {
|
||||
auto irVersion = details::GetIRVersion(modelStream);
|
||||
bool isIRv7 = irVersion > 1 && irVersion <= 7;
|
||||
|
||||
if (!isIRv7)
|
||||
return;
|
||||
|
||||
for (auto && kvp : readers) {
|
||||
Reader::Ptr reader = kvp.second;
|
||||
if (reader->getName() == "IRv7") {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
THROW_IE_EXCEPTION << "IR v" << irVersion << " is deprecated. Please, migrate to IR v10 version";
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
CNNNetwork details::ReadNetwork(const std::string& modelPath, const std::string& binPath, const std::vector<IExtensionPtr>& exts) {
|
||||
@ -150,6 +168,8 @@ CNNNetwork details::ReadNetwork(const std::string& modelPath, const std::string&
|
||||
if (!modelStream.is_open())
|
||||
THROW_IE_EXCEPTION << "Model file " << modelPath << " cannot be opened!";
|
||||
|
||||
assertIfIRv7LikeModel(modelStream);
|
||||
|
||||
// Find reader for model extension
|
||||
auto fileExt = modelPath.substr(modelPath.find_last_of(".") + 1);
|
||||
for (auto it = readers.lower_bound(fileExt); it != readers.upper_bound(fileExt); it++) {
|
||||
@ -203,6 +223,8 @@ CNNNetwork details::ReadNetwork(const std::string& model, const Blob::CPtr& weig
|
||||
std::istringstream modelStream(model);
|
||||
details::BlobStream binStream(weights);
|
||||
|
||||
assertIfIRv7LikeModel(modelStream);
|
||||
|
||||
for (auto it = readers.begin(); it != readers.end(); it++) {
|
||||
auto reader = it->second;
|
||||
if (reader->supportModel(modelStream)) {
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <file_utils.h>
|
||||
#include <xml_parse_utils.h>
|
||||
|
||||
#include <ie_ir_version.hpp>
|
||||
#include <ie_ir_reader.hpp>
|
||||
#include <memory>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
@ -20,39 +21,14 @@
|
||||
|
||||
using namespace InferenceEngine;
|
||||
|
||||
static size_t GetIRVersion(pugi::xml_node& root) {
|
||||
return XMLParseUtils::GetUIntAttr(root, "version", 0);
|
||||
}
|
||||
|
||||
bool IRReader::supportModel(std::istream& model) const {
|
||||
model.seekg(0, model.beg);
|
||||
const int header_size = 128;
|
||||
std::string header(header_size, ' ');
|
||||
model.read(&header[0], header_size);
|
||||
model.seekg(0, model.beg);
|
||||
auto version = details::GetIRVersion(model);
|
||||
|
||||
pugi::xml_document doc;
|
||||
auto res = doc.load_string(header.c_str(), pugi::parse_default | pugi::parse_fragment);
|
||||
|
||||
bool supports = false;
|
||||
|
||||
if (res == pugi::status_ok) {
|
||||
pugi::xml_node root = doc.document_element();
|
||||
|
||||
std::string node_name = root.name();
|
||||
std::transform(node_name.begin(), node_name.end(), node_name.begin(), ::tolower);
|
||||
|
||||
if (node_name == "net") {
|
||||
size_t const version = GetIRVersion(root);
|
||||
#ifdef IR_READER_V10
|
||||
supports = version == 10;
|
||||
return version == 10;
|
||||
#else
|
||||
supports = version < 10;
|
||||
return version > 1 && version <= 7;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
return supports;
|
||||
}
|
||||
|
||||
CNNNetwork IRReader::read(std::istream& model, const std::vector<IExtensionPtr>& exts) const {
|
||||
@ -68,7 +44,7 @@ CNNNetwork IRReader::read(std::istream& model, std::istream& weights, const std:
|
||||
}
|
||||
pugi::xml_node root = xmlDoc.document_element();
|
||||
|
||||
auto version = GetIRVersion(root);
|
||||
auto version = details::GetIRVersion(root);
|
||||
IRParser parser(version, exts);
|
||||
return CNNNetwork(parser.parse(root, weights));
|
||||
}
|
||||
|
46
inference-engine/src/readers/ir_reader/ie_ir_version.hpp
Normal file
46
inference-engine/src/readers/ir_reader/ie_ir_version.hpp
Normal file
@ -0,0 +1,46 @@
|
||||
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <fstream>
|
||||
#include <xml_parse_utils.h>
|
||||
|
||||
namespace InferenceEngine {
|
||||
namespace details {
|
||||
|
||||
inline size_t GetIRVersion(pugi::xml_node& root) {
|
||||
return XMLParseUtils::GetUIntAttr(root, "version", 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Extracts IR version from model stream
|
||||
* @param model Models stream
|
||||
* @return IR version, 0 if model does represent IR
|
||||
*/
|
||||
size_t GetIRVersion(std::istream& model) {
|
||||
model.seekg(0, model.beg);
|
||||
const int header_size = 128;
|
||||
std::string header(header_size, ' ');
|
||||
model.read(&header[0], header_size);
|
||||
model.seekg(0, model.beg);
|
||||
|
||||
pugi::xml_document doc;
|
||||
auto res = doc.load_string(header.c_str(), pugi::parse_default | pugi::parse_fragment);
|
||||
|
||||
if (res == pugi::status_ok) {
|
||||
pugi::xml_node root = doc.document_element();
|
||||
|
||||
std::string node_name = root.name();
|
||||
std::transform(node_name.begin(), node_name.end(), node_name.begin(), ::tolower);
|
||||
|
||||
if (node_name == "net") {
|
||||
return GetIRVersion(root);
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace details
|
||||
} // namespace InferenceEngine
|
Loading…
Reference in New Issue
Block a user