Split IR readers (#1167)
* Split IR readers * Fixed tests * CMAKE: Removed add_clang_format_target usage from readers
This commit is contained in:
parent
0e904405f7
commit
ef6280ab99
@ -88,6 +88,9 @@ function(ie_add_plugin)
|
||||
if(TARGET inference_engine_ir_reader)
|
||||
add_dependencies(${IE_PLUGIN_NAME} inference_engine_ir_reader)
|
||||
endif()
|
||||
if(TARGET inference_engine_ir_reader_v7)
|
||||
add_dependencies(${IE_PLUGIN_NAME} inference_engine_ir_reader_v7)
|
||||
endif()
|
||||
if(TARGET inference_engine_onnx_reader)
|
||||
add_dependencies(${IE_PLUGIN_NAME} inference_engine_onnx_reader)
|
||||
endif()
|
||||
|
@ -34,26 +34,6 @@ namespace InferenceEngine {
|
||||
|
||||
namespace {
|
||||
|
||||
std::once_flag flag;
|
||||
InferenceEngine::details::SharedObjectLoader::Ptr cnnReaderLoader;
|
||||
|
||||
InferenceEngine::details::SharedObjectLoader::Ptr createCnnReaderLoader() {
|
||||
std::call_once(flag, [&] () {
|
||||
FileUtils::FilePath libraryName = FileUtils::toFilePath(std::string("inference_engine_ir_reader") + std::string(IE_BUILD_POSTFIX));
|
||||
FileUtils::FilePath irReadersLibraryPath = FileUtils::makeSharedLibraryName(getInferenceEngineLibraryPath(), libraryName);
|
||||
|
||||
if (!FileUtils::fileExist(irReadersLibraryPath)) {
|
||||
THROW_IE_EXCEPTION << "Please, make sure that Inference Engine IR readers library "
|
||||
<< FileUtils::fromFilePath(::FileUtils::makeSharedLibraryName({}, libraryName)) << " is in "
|
||||
<< getIELibraryPath();
|
||||
}
|
||||
cnnReaderLoader = std::shared_ptr<InferenceEngine::details::SharedObjectLoader>(
|
||||
new InferenceEngine::details::SharedObjectLoader(irReadersLibraryPath.c_str()));
|
||||
});
|
||||
|
||||
return cnnReaderLoader;
|
||||
}
|
||||
|
||||
IInferencePluginAPI* getInferencePluginAPIInterface(IInferencePlugin* iplugin) {
|
||||
return dynamic_cast<IInferencePluginAPI*>(iplugin);
|
||||
}
|
||||
|
@ -37,7 +37,6 @@ public:
|
||||
* @brief This class is a wrapper for reader interfaces
|
||||
*/
|
||||
class Reader: public IReader {
|
||||
private:
|
||||
InferenceEngine::details::SOPointer<IReader> ptr;
|
||||
std::once_flag readFlag;
|
||||
std::string name;
|
||||
@ -120,10 +119,16 @@ void registerReaders() {
|
||||
readers.emplace("prototxt", onnxReader);
|
||||
}
|
||||
|
||||
// try to load IR reader if library exists
|
||||
auto irReader = create_if_exists("IR", std::string("inference_engine_ir_reader") + std::string(IE_BUILD_POSTFIX));
|
||||
if (irReader)
|
||||
readers.emplace("xml", irReader);
|
||||
// try to load IR reader v10 if library exists
|
||||
auto irReaderv10 = create_if_exists("IRv10", std::string("inference_engine_ir_reader") + std::string(IE_BUILD_POSTFIX));
|
||||
if (irReaderv10)
|
||||
readers.emplace("xml", irReaderv10);
|
||||
|
||||
// try to load IR reader v7 if library exists
|
||||
auto irReaderv7 = create_if_exists("IRv7", std::string("inference_engine_ir_reader_v7") + std::string(IE_BUILD_POSTFIX));
|
||||
if (irReaderv7)
|
||||
readers.emplace("xml", irReaderv7);
|
||||
|
||||
initialized = true;
|
||||
}
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
#
|
||||
set(TARGET_NAME inference_engine_reader_api)
|
||||
|
||||
# Reader API library
|
||||
# Reader API interface library
|
||||
add_library(${TARGET_NAME} INTERFACE)
|
||||
|
||||
target_include_directories(${TARGET_NAME} INTERFACE
|
||||
@ -16,6 +16,8 @@ file(GLOB_RECURSE reader_api_src "${CMAKE_CURRENT_SOURCE_DIR}/reader_api/*.hpp"
|
||||
add_cpplint_target(${TARGET_NAME}_cpplint FOR_SOURCES ${reader_api_src})
|
||||
|
||||
add_subdirectory(ir_reader)
|
||||
add_subdirectory(ir_reader_v7)
|
||||
|
||||
if(NGRAPH_ONNX_IMPORT_ENABLE)
|
||||
add_subdirectory(onnx_reader)
|
||||
endif()
|
||||
|
@ -24,7 +24,8 @@ source_group("include" FILES ${PUBLIC_HEADERS})
|
||||
add_library(${TARGET_NAME} SHARED ${LIBRARY_SRC} ${PUBLIC_HEADERS})
|
||||
|
||||
target_compile_definitions(${TARGET_NAME} PRIVATE IMPLEMENT_INFERENCE_ENGINE_API
|
||||
IMPLEMENT_INFERENCE_ENGINE_PLUGIN)
|
||||
IMPLEMENT_INFERENCE_ENGINE_PLUGIN
|
||||
IR_READER_V10)
|
||||
|
||||
target_include_directories(${TARGET_NAME} PUBLIC ${PUBLIC_HEADERS_DIR})
|
||||
target_include_directories(${TARGET_NAME} PRIVATE "${IE_MAIN_SOURCE_DIR}/src/inference_engine")
|
||||
@ -35,11 +36,6 @@ target_link_libraries(${TARGET_NAME} PRIVATE pugixml)
|
||||
# code style
|
||||
|
||||
add_cpplint_target(${TARGET_NAME}_cpplint FOR_TARGETS ${TARGET_NAME})
|
||||
add_clang_format_target(${TARGET_NAME}_clang_format FOR_TARGETS ${TARGET_NAME})
|
||||
|
||||
# developer package
|
||||
|
||||
ie_developer_export_targets(${TARGET_NAME})
|
||||
|
||||
# install
|
||||
|
||||
|
@ -27,25 +27,17 @@
|
||||
|
||||
#include <cpp/ie_cnn_network.h>
|
||||
#include "ie_blob_stream.hpp"
|
||||
#include "cnn_network_impl.hpp"
|
||||
#include "details/caseless.hpp"
|
||||
#include "details/ie_cnn_network_tools.h"
|
||||
#include "ie_format_parser.h"
|
||||
#include "ie_ngraph_utils.hpp"
|
||||
#include "generic_ie.hpp"
|
||||
#include "precision_utils.h"
|
||||
#include "blob_factory.hpp"
|
||||
#include "ie_cnn_net_reader_impl.h"
|
||||
|
||||
using namespace InferenceEngine;
|
||||
using namespace XMLParseUtils;
|
||||
|
||||
IRParser::IRParser(size_t version): IRParser(version, {}) {}
|
||||
IRParser::IRParser(size_t version, const std::vector<InferenceEngine::IExtensionPtr>& exts) {
|
||||
if (version < 10) {
|
||||
parser = std::make_shared<CNNParser>();
|
||||
return;
|
||||
}
|
||||
switch (version) {
|
||||
case 10:
|
||||
parser = std::make_shared<V10Parser>(exts);
|
||||
@ -72,45 +64,6 @@ public:
|
||||
originBlob(weights) { }
|
||||
};
|
||||
|
||||
std::shared_ptr<ICNNNetwork> CNNParser::parse(const pugi::xml_node& root, std::istream& binStream) {
|
||||
auto getBlobStream = [](std::istream& binStream) {
|
||||
details::BlobStream* blobStream = dynamic_cast<details::BlobStream*>(&binStream);
|
||||
if (blobStream == nullptr) {
|
||||
details::BlobStream helper({});
|
||||
std::string typeStream = typeid(binStream).name();
|
||||
std::string typeBlobStream = typeid(helper).name();
|
||||
if (typeStream == typeBlobStream)
|
||||
blobStream = static_cast<details::BlobStream*>(&binStream);
|
||||
}
|
||||
return blobStream;
|
||||
};
|
||||
details::CNNNetReaderImpl reader(std::make_shared<details::V2FormatParserCreator>());
|
||||
ResponseDesc resp;
|
||||
StatusCode ret = reader.ReadNetwork(root, &resp);
|
||||
if (ret != OK)
|
||||
THROW_IE_EXCEPTION << resp.msg;
|
||||
TBlob<uint8_t>::Ptr weightsPtr;
|
||||
|
||||
// Try to get BlobStream to work with original blob
|
||||
details::BlobStream* blobStream = getBlobStream(binStream);
|
||||
if (blobStream != nullptr) {
|
||||
weightsPtr = std::make_shared<WeightsHolderBlob>(blobStream->getBlob());
|
||||
} else {
|
||||
// Allocate a blob for weights
|
||||
binStream.seekg(0, std::ios::end);
|
||||
size_t length = binStream.tellg();
|
||||
weightsPtr = std::make_shared<TBlob<uint8_t>>(TensorDesc(Precision::U8, {length}, Layout::C));
|
||||
weightsPtr->allocate();
|
||||
char* data = weightsPtr->buffer().as<char*>();
|
||||
binStream.seekg(0, std::ios::beg);
|
||||
binStream.read(data, length);
|
||||
}
|
||||
ret = reader.SetWeights(weightsPtr, &resp);
|
||||
if (ret != OK)
|
||||
THROW_IE_EXCEPTION << resp.msg;
|
||||
return reader.getNetwork();
|
||||
}
|
||||
|
||||
V10Parser::V10Parser(const std::vector<IExtensionPtr>& exts) {
|
||||
// Load default opsets
|
||||
opsets["opset1"] = ngraph::get_opset1();
|
||||
|
@ -4,7 +4,9 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ngraph/opsets/opset.hpp>
|
||||
#include <ie_blob.h>
|
||||
#include <ie_icnn_network.hpp>
|
||||
#include <ie_iextension.h>
|
||||
#include <xml_parse_utils.h>
|
||||
|
||||
@ -18,7 +20,6 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "cnn_network_impl.hpp"
|
||||
#include "ie_ngraph_utils.hpp"
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
@ -27,17 +27,38 @@ bool IRReader::supportModel(std::istream& model) const {
|
||||
const int header_size = 128;
|
||||
std::string header(header_size, ' ');
|
||||
model.read(&header[0], header_size);
|
||||
|
||||
// find '<net ' substring in the .xml file
|
||||
return (header.find("<net ") != std::string::npos) || (header.find("<Net ") != std::string::npos);
|
||||
bool supports = (header.find("<net ") != std::string::npos) ||
|
||||
(header.find("<Net ") != std::string::npos);
|
||||
|
||||
if (supports) {
|
||||
pugi::xml_document xmlDoc;
|
||||
model.seekg(0, model.beg);
|
||||
pugi::xml_parse_result res = xmlDoc.load(model);
|
||||
if (res.status != pugi::status_ok) {
|
||||
supports = false;
|
||||
} else {
|
||||
pugi::xml_node root = xmlDoc.document_element();
|
||||
auto version = GetIRVersion(root);
|
||||
#ifdef IR_READER_V10
|
||||
supports = version == 10;
|
||||
#else
|
||||
supports = version < 10;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
model.seekg(0, model.beg);
|
||||
return supports;
|
||||
}
|
||||
|
||||
CNNNetwork IRReader::read(std::istream& model, const std::vector<IExtensionPtr>& exts) const {
|
||||
std::istringstream emptyStream;
|
||||
return read(model, emptyStream, exts);
|
||||
}
|
||||
|
||||
CNNNetwork IRReader::read(std::istream& model, std::istream& weights, const std::vector<IExtensionPtr>& exts) const {
|
||||
model.seekg(0, model.beg);
|
||||
weights.seekg(0, weights.beg);
|
||||
pugi::xml_document xmlDoc;
|
||||
pugi::xml_parse_result res = xmlDoc.load(model);
|
||||
if (res.status != pugi::status_ok) {
|
||||
|
@ -32,8 +32,6 @@ namespace InferenceEngine {
|
||||
*/
|
||||
class IRReader: public IReader {
|
||||
public:
|
||||
IRReader() = default;
|
||||
|
||||
void Release() noexcept override {
|
||||
delete this;
|
||||
}
|
||||
|
39
inference-engine/src/readers/ir_reader_v7/CMakeLists.txt
Normal file
39
inference-engine/src/readers/ir_reader_v7/CMakeLists.txt
Normal file
@ -0,0 +1,39 @@
|
||||
# Copyright (C) 2018-2019 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
|
||||
set(TARGET_NAME "inference_engine_ir_reader_v7")
|
||||
|
||||
if(ENABLE_LTO)
|
||||
ie_enable_lto()
|
||||
endif()
|
||||
|
||||
set(PUBLIC_HEADERS_DIR "${CMAKE_CURRENT_SOURCE_DIR}/")
|
||||
|
||||
file(GLOB_RECURSE LIBRARY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
|
||||
list(APPEND LIBRARY_SRC ${IE_MAIN_SOURCE_DIR}/src/readers/ir_reader/ie_ir_reader.cpp)
|
||||
file(GLOB_RECURSE PUBLIC_HEADERS ${PUBLIC_HEADERS_DIR}/*.h ${PUBLIC_HEADERS_DIR}/*.hpp)
|
||||
|
||||
# Create named folders for the sources within the .vcproj
|
||||
# Empty name lists them directly under the .vcproj
|
||||
|
||||
source_group("src" FILES ${LIBRARY_SRC})
|
||||
source_group("include" FILES ${PUBLIC_HEADERS})
|
||||
|
||||
# Create shared library
|
||||
|
||||
add_library(${TARGET_NAME} SHARED ${LIBRARY_SRC} ${PUBLIC_HEADERS})
|
||||
|
||||
target_compile_definitions(${TARGET_NAME} PRIVATE IMPLEMENT_INFERENCE_ENGINE_API
|
||||
IMPLEMENT_INFERENCE_ENGINE_PLUGIN)
|
||||
|
||||
target_include_directories(${TARGET_NAME} PUBLIC ${PUBLIC_HEADERS_DIR})
|
||||
target_include_directories(${TARGET_NAME} PRIVATE "${IE_MAIN_SOURCE_DIR}/src/inference_engine"
|
||||
"${IE_MAIN_SOURCE_DIR}/src/readers/ir_reader")
|
||||
|
||||
target_link_libraries(${TARGET_NAME} PUBLIC inference_engine_reader_api inference_engine_plugin_api ${NGRAPH_LIBRARIES} inference_engine)
|
||||
target_link_libraries(${TARGET_NAME} PRIVATE pugixml)
|
||||
|
||||
# code style
|
||||
|
||||
add_cpplint_target(${TARGET_NAME}_cpplint FOR_TARGETS ${TARGET_NAME})
|
@ -17,7 +17,6 @@
|
||||
#include "cnn_network_ngraph_impl.hpp"
|
||||
#include "details/os/os_filesystem.hpp"
|
||||
#include "ie_format_parser.h"
|
||||
#include "ie_ir_parser.hpp"
|
||||
#include "ie_profiling.hpp"
|
||||
#include "parsers.h"
|
||||
#include "xml_parse_utils.h"
|
||||
@ -29,20 +28,13 @@ using namespace InferenceEngine::details;
|
||||
CNNNetReaderImpl::CNNNetReaderImpl(const FormatParserCreator::Ptr& _creator)
|
||||
: parseSuccess(false), _version(0), parserCreator(_creator) {}
|
||||
|
||||
CNNNetReaderImpl::~CNNNetReaderImpl() { }
|
||||
|
||||
StatusCode CNNNetReaderImpl::SetWeights(const TBlob<uint8_t>::Ptr& weights, ResponseDesc* desc) noexcept {
|
||||
if (!_parser && _version < 10) {
|
||||
return DescriptionBuffer(desc) << "network must be read first";
|
||||
}
|
||||
|
||||
try {
|
||||
if (_version == 10) {
|
||||
// It's time to perform actual reading of V10 network and instantiate CNNNetworkNGraphImpl
|
||||
IRParser parser(_version, extensions);
|
||||
pugi::xml_node root = xmlDoc->document_element();
|
||||
details::BlobStream blobStream(weights);
|
||||
network = parser.parse(root, blobStream);
|
||||
} else if (weights) {
|
||||
if (_version < 10) {
|
||||
_parser->SetWeights(weights);
|
||||
}
|
||||
} catch (const InferenceEngineException& iee) {
|
||||
@ -54,7 +46,7 @@ StatusCode CNNNetReaderImpl::SetWeights(const TBlob<uint8_t>::Ptr& weights, Resp
|
||||
return OK;
|
||||
}
|
||||
|
||||
size_t CNNNetReaderImpl::GetFileVersion(pugi::xml_node& root) {
|
||||
static size_t GetFileVersion(pugi::xml_node& root) {
|
||||
return XMLParseUtils::GetUIntAttr(root, "version", 0);
|
||||
}
|
||||
|
||||
@ -126,12 +118,8 @@ StatusCode CNNNetReaderImpl::ReadNetwork(const pugi::xml_node& const_root, Respo
|
||||
pugi::xml_node root = *const_cast<pugi::xml_node*>(&const_root);
|
||||
_version = GetFileVersion(root);
|
||||
if (_version < 2) THROW_IE_EXCEPTION << "deprecated IR version: " << _version;
|
||||
if (_version == 10) {
|
||||
// Activate an alternative code path for V10 that should be read into ngraph::Function
|
||||
// We cannot proceed with reading right now, because there is not binary file loaded.
|
||||
// So we are postponing real read until weights are specified.
|
||||
parseSuccess = true;
|
||||
} else if (_version < 10) {
|
||||
|
||||
if (_version < 10) {
|
||||
_parser = parserCreator->create(_version);
|
||||
InferenceEngine::details::CNNNetworkImplPtr local_network = _parser->Parse(root);
|
||||
name = local_network->getName();
|
||||
@ -185,10 +173,6 @@ StatusCode CNNNetReaderImpl::ReadNetwork() {
|
||||
return OK;
|
||||
}
|
||||
|
||||
void CNNNetReaderImpl::addExtensions(const std::vector<InferenceEngine::IExtensionPtr>& ext) {
|
||||
extensions = ext;
|
||||
}
|
||||
|
||||
std::shared_ptr<IFormatParser> V2FormatParserCreator::create(size_t version) {
|
||||
return std::make_shared<FormatParser>(version);
|
||||
}
|
@ -58,14 +58,6 @@ public:
|
||||
return network;
|
||||
}
|
||||
|
||||
bool isParseSuccess(ResponseDesc* resp) noexcept {
|
||||
return parseSuccess;
|
||||
}
|
||||
|
||||
StatusCode getDescription(ResponseDesc* desc) noexcept {
|
||||
return DescriptionBuffer(OK, desc) << description;
|
||||
}
|
||||
|
||||
StatusCode getName(char* name, size_t len, ResponseDesc* resp) noexcept {
|
||||
if (len > 0) {
|
||||
size_t length = std::min(this->name.size(), len - 1); // cut the name if buffer is too small
|
||||
@ -75,21 +67,8 @@ public:
|
||||
return OK;
|
||||
}
|
||||
|
||||
int getVersion(ResponseDesc* resp) noexcept {
|
||||
return _version;
|
||||
}
|
||||
|
||||
void Release() noexcept {
|
||||
delete this;
|
||||
}
|
||||
|
||||
void addExtensions(const std::vector<InferenceEngine::IExtensionPtr>& ext);
|
||||
|
||||
~CNNNetReaderImpl();
|
||||
|
||||
private:
|
||||
std::shared_ptr<InferenceEngine::details::IFormatParser> _parser;
|
||||
size_t GetFileVersion(pugi::xml_node& root);
|
||||
StatusCode ReadNetwork();
|
||||
|
||||
std::string description;
|
||||
@ -101,7 +80,6 @@ private:
|
||||
|
||||
// Stashed xmlDoc that is needed to delayed loading of V10 IR version
|
||||
std::shared_ptr<pugi::xml_document> xmlDoc;
|
||||
std::vector<InferenceEngine::IExtensionPtr> extensions;
|
||||
};
|
||||
|
||||
} // namespace details
|
@ -496,6 +496,9 @@ Blob::Ptr FormatParser::GetBlobFromSegment(const TBlob<uint8_t>::Ptr& weights, c
|
||||
}
|
||||
|
||||
void FormatParser::SetWeights(const TBlob<uint8_t>::Ptr& weights) {
|
||||
if (weights == nullptr)
|
||||
return;
|
||||
|
||||
for (auto& kvp : _network->allLayers()) {
|
||||
auto fit = layersParseInfo.find(kvp.second->name);
|
||||
// todo: may check that earlier - while parsing...
|
76
inference-engine/src/readers/ir_reader_v7/ie_ir_parser.cpp
Normal file
76
inference-engine/src/readers/ir_reader_v7/ie_ir_parser.cpp
Normal file
@ -0,0 +1,76 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ie_reader.hpp"
|
||||
#include "ie_ir_parser.hpp"
|
||||
#include "ie_blob_stream.hpp"
|
||||
#include "ie_cnn_net_reader_impl.h"
|
||||
|
||||
using namespace InferenceEngine;
|
||||
|
||||
IRParser::IRParser(size_t version): IRParser(version, {}) {}
|
||||
IRParser::IRParser(size_t version, const std::vector<InferenceEngine::IExtensionPtr>& exts) {
|
||||
if (version < 10) {
|
||||
parser = std::make_shared<CNNParser>();
|
||||
return;
|
||||
} else {
|
||||
THROW_IE_EXCEPTION << "Unsupported IR version: " << version;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ICNNNetwork> IRParser::parse(const pugi::xml_node& root, std::istream& binStream) {
|
||||
return parser->parse(root, binStream);
|
||||
}
|
||||
|
||||
/**
|
||||
* Hold original blob in order to avoid situations when original blob is allocated on stack
|
||||
*/
|
||||
class WeightsHolderBlob : public TBlob<uint8_t> {
|
||||
Blob::CPtr originBlob;
|
||||
|
||||
public:
|
||||
explicit WeightsHolderBlob(const Blob::CPtr& weights) :
|
||||
TBlob<uint8_t>(weights->getTensorDesc(),
|
||||
weights->cbuffer().as<uint8_t*>()),
|
||||
originBlob(weights) { }
|
||||
};
|
||||
|
||||
std::shared_ptr<ICNNNetwork> CNNParser::parse(const pugi::xml_node& root, std::istream& binStream) {
|
||||
auto getBlobStream = [](std::istream& binStream) {
|
||||
details::BlobStream* blobStream = dynamic_cast<details::BlobStream*>(&binStream);
|
||||
if (blobStream == nullptr) {
|
||||
details::BlobStream helper({});
|
||||
std::string typeStream = typeid(binStream).name();
|
||||
std::string typeBlobStream = typeid(helper).name();
|
||||
if (typeStream == typeBlobStream)
|
||||
blobStream = static_cast<details::BlobStream*>(&binStream);
|
||||
}
|
||||
return blobStream;
|
||||
};
|
||||
details::CNNNetReaderImpl reader(std::make_shared<details::V2FormatParserCreator>());
|
||||
ResponseDesc resp;
|
||||
StatusCode ret = reader.ReadNetwork(root, &resp);
|
||||
if (ret != OK)
|
||||
THROW_IE_EXCEPTION << resp.msg;
|
||||
TBlob<uint8_t>::Ptr weightsPtr;
|
||||
|
||||
// Try to get BlobStream to work with original blob
|
||||
details::BlobStream* blobStream = getBlobStream(binStream);
|
||||
if (blobStream != nullptr) {
|
||||
weightsPtr = std::make_shared<WeightsHolderBlob>(blobStream->getBlob());
|
||||
} else {
|
||||
// Allocate a blob for weights
|
||||
binStream.seekg(0, std::ios::end);
|
||||
size_t length = binStream.tellg();
|
||||
weightsPtr = std::make_shared<TBlob<uint8_t>>(TensorDesc(Precision::U8, {length}, Layout::C));
|
||||
weightsPtr->allocate();
|
||||
char* data = weightsPtr->buffer().as<char*>();
|
||||
binStream.seekg(0, std::ios::beg);
|
||||
binStream.read(data, length);
|
||||
}
|
||||
ret = reader.SetWeights(weightsPtr, &resp);
|
||||
if (ret != OK)
|
||||
THROW_IE_EXCEPTION << resp.msg;
|
||||
return reader.getNetwork();
|
||||
}
|
@ -13,6 +13,7 @@ bool ONNXReader::supportModel(std::istream& model) const {
|
||||
const int header_size = 128;
|
||||
std::string header(header_size, ' ');
|
||||
model.read(&header[0], header_size);
|
||||
model.seekg(0, model.beg);
|
||||
// find 'onnx' substring in the .onnx files
|
||||
// find 'ir_version' and 'graph' for prototxt
|
||||
// return (header.find("onnx") != std::string::npos) || (header.find("pytorch") != std::string::npos) ||
|
||||
@ -21,7 +22,6 @@ bool ONNXReader::supportModel(std::istream& model) const {
|
||||
}
|
||||
|
||||
CNNNetwork ONNXReader::read(std::istream& model, const std::vector<IExtensionPtr>& exts) const {
|
||||
model.seekg(0, model.beg);
|
||||
return CNNNetwork(ngraph::onnx_import::import_onnx_model(model));
|
||||
}
|
||||
|
||||
|
@ -10,7 +10,6 @@ list(APPEND EXPORT_DEPENDENCIES
|
||||
commonTestUtils_s
|
||||
inference_engine_s
|
||||
inference_engine_lp_transformations
|
||||
inference_engine_ir_reader
|
||||
gmock)
|
||||
|
||||
addIeTarget(
|
||||
|
@ -6,7 +6,6 @@
|
||||
#include "unit_test_utils/mocks/mock_icnn_network.hpp"
|
||||
#include "unit_test_utils/mocks/mock_ie_imemory_state.hpp"
|
||||
#include "unit_test_utils/mocks/mock_iexecutable_network.hpp"
|
||||
#include "unit_test_utils/mocks/mock_iformat_parser.hpp"
|
||||
#include "unit_test_utils/mocks/mock_iinfer_request.hpp"
|
||||
#include "unit_test_utils/mocks/mock_iinference_plugin.hpp"
|
||||
#include "unit_test_utils/mocks/mock_not_empty_icnn_network.hpp"
|
||||
|
@ -1,24 +0,0 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
/**
|
||||
* \brief mock file for header file for IFormatParser
|
||||
* \file mock_iformat_parser.hpp
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
|
||||
#include "ie_icnn_network.hpp"
|
||||
#include <ie_cnn_net_reader_impl.h>
|
||||
#include <parsers.h>
|
||||
#include "pugixml.hpp"
|
||||
|
||||
struct MockIFormatParser : public InferenceEngine::details::IFormatParser {
|
||||
public:
|
||||
MOCK_METHOD1(Parse, InferenceEngine::details::CNNNetworkImplPtr(pugi::xml_node &));
|
||||
|
||||
MOCK_METHOD1(SetWeights, void(const InferenceEngine::TBlob<uint8_t>::Ptr &));
|
||||
};
|
||||
|
@ -9,6 +9,7 @@ addIeTargetTest(
|
||||
ROOT ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
LINK_LIBRARIES
|
||||
unitTestUtils
|
||||
inference_engine_ir_reader_v7
|
||||
ADD_CPPLINT
|
||||
DEPENDENCIES
|
||||
mock_engine
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
#include "classification_matcher.hpp"
|
||||
#include <gtest/gtest.h>
|
||||
#include <xml_helper.hpp>
|
||||
#include <fstream>
|
||||
#include "details/ie_cnn_network_iterator.hpp"
|
||||
|
||||
using namespace Regression ;
|
||||
|
@ -2,7 +2,6 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <xml_helper.hpp>
|
||||
#include "object_detection_matcher.hpp"
|
||||
#include "details/ie_cnn_network_iterator.hpp"
|
||||
|
||||
|
@ -6,7 +6,6 @@
|
||||
#include <tests_common.hpp>
|
||||
#include <tests_common_func.hpp>
|
||||
#include <memory>
|
||||
#include "xml_helper.hpp"
|
||||
#include <ie_core.hpp>
|
||||
|
||||
#define XBYAK_NO_OP_NAMES
|
||||
|
@ -7,7 +7,6 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <tests_common.hpp>
|
||||
#include <ie_format_parser.h>
|
||||
#include <ie_layers_internal.hpp>
|
||||
#include <details/ie_cnn_network_iterator.hpp>
|
||||
#include <functional_test_utils/plugin_cache.hpp>
|
||||
|
@ -5,7 +5,6 @@
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
#include <xml_helper.hpp>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "cpp_interfaces/interface/ie_internal_plugin_config.hpp"
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include <ngraph_functions/subgraph_builders.hpp>
|
||||
#include "myriad_layers_tests.hpp"
|
||||
#include "vpu_tests_config.hpp"
|
||||
#include <fstream>
|
||||
|
||||
using namespace InferenceEngine;
|
||||
using namespace ::testing;
|
||||
|
@ -8,7 +8,6 @@
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
|
||||
#include "ie_ir_reader.hpp"
|
||||
#include "functional_test_utils/plugin_cache.hpp"
|
||||
|
||||
using namespace InferenceEngine;
|
||||
|
@ -3,7 +3,7 @@
|
||||
//
|
||||
|
||||
#include "vpu_layer_tests_utils.hpp"
|
||||
|
||||
#include <fstream>
|
||||
#include "common_test_utils/common_layers_params.hpp"
|
||||
|
||||
using namespace InferenceEngine;
|
||||
|
@ -15,7 +15,6 @@
|
||||
#include <common/include/vpu/utils/error.hpp>
|
||||
|
||||
#include "blob_factory.hpp"
|
||||
#include "ie_ir_reader.hpp"
|
||||
#include "debug.h"
|
||||
#include "vpu_tests_config.hpp"
|
||||
|
||||
|
@ -21,7 +21,6 @@ function(add_helpers target_name)
|
||||
|
||||
target_include_directories(${target_name} PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}"
|
||||
"${IE_MAIN_SOURCE_DIR}/src/inference_engine"
|
||||
$<TARGET_PROPERTY:inference_engine_ir_reader,INTERFACE_INCLUDE_DIRECTORIES>
|
||||
$<TARGET_PROPERTY:inference_engine_lp_transformations,INTERFACE_INCLUDE_DIRECTORIES>
|
||||
$<TARGET_PROPERTY:pugixml,INTERFACE_INCLUDE_DIRECTORIES>
|
||||
"${IE_MAIN_SOURCE_DIR}/src/vpu/"
|
||||
|
@ -6,10 +6,10 @@
|
||||
|
||||
#include <ie_blob.h>
|
||||
#include <ie_core.hpp>
|
||||
#include <cnn_network_impl.hpp>
|
||||
#include <ie_layers_property.hpp>
|
||||
#include <precision_utils.h>
|
||||
#include <common_test_utils/xml_net_builder/xml_net_builder.hpp>
|
||||
#include <xml_helper.hpp>
|
||||
#include <tests_common.hpp>
|
||||
|
||||
#if defined(__GNUC__) && (__GNUC__ <= 4) && (__GNUC_MINOR__ < 9) && !defined(__clang__)
|
||||
|
@ -143,7 +143,7 @@ target_link_libraries(${TARGET_NAME} PRIVATE
|
||||
# dynamic libraries
|
||||
inference_engine_lp_transformations
|
||||
inference_engine_transformations
|
||||
inference_engine_ir_reader
|
||||
inference_engine_ir_reader_v7
|
||||
${CMAKE_DL_LIBS})
|
||||
|
||||
if(TARGET libGNAStubs)
|
||||
|
@ -10,7 +10,6 @@
|
||||
#include <thread>
|
||||
|
||||
#include "unit_test_utils/mocks/mock_icnn_network.hpp"
|
||||
#include "unit_test_utils/mocks/mock_iformat_parser.hpp"
|
||||
#include "common_test_utils/common_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
@ -24,184 +23,6 @@ public:
|
||||
ResponseDesc resp;
|
||||
};
|
||||
|
||||
struct MockFormatParserCreator : public FormatParserCreator {
|
||||
MockFormatParserCreator() {
|
||||
_parser = make_shared<MockIFormatParser>();
|
||||
}
|
||||
|
||||
std::shared_ptr<IFormatParser> create(size_t version) override {
|
||||
return _parser;
|
||||
}
|
||||
|
||||
MockIFormatParser* getParser() {
|
||||
return _parser.get();
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<MockIFormatParser> _parser;
|
||||
};
|
||||
|
||||
TEST_F(CNNNetReaderImplTest, validateIsCalled) {
|
||||
std::string model =
|
||||
"<net name=\"PVANET\" version=\"2\" batch=\"1\">"
|
||||
" <layers>"
|
||||
" <layer name=\"data\" type=\"Input\" precision=\"FP32\" id=\"0\">"
|
||||
" <output>"
|
||||
" <port id=\"0\">"
|
||||
" <dim>1</dim>"
|
||||
" <dim>3</dim>"
|
||||
" <dim>544</dim>"
|
||||
" <dim>992</dim>"
|
||||
" </port>"
|
||||
" </output>"
|
||||
" </layer>"
|
||||
" <layer name=\"conv1_1_conv\" type=\"Convolution\" precision=\"FP32\" id=\"2\">"
|
||||
" <convolution_data stride-x=\"2\" stride-y=\"2\" pad-x=\"3\" pad-y=\"3\" kernel-x=\"7\" kernel-y=\"7\" output=\"16\" group=\"1\"/>"
|
||||
" <input>"
|
||||
" <port id=\"2\">"
|
||||
" <dim>1</dim>"
|
||||
" <dim>3</dim>"
|
||||
" <dim>544</dim>"
|
||||
" <dim>992</dim>"
|
||||
" </port>"
|
||||
" </input>"
|
||||
" <output>"
|
||||
" <port id=\"3\">"
|
||||
" <dim>1</dim>"
|
||||
" <dim>16</dim>"
|
||||
" <dim>272</dim>"
|
||||
" <dim>496</dim>"
|
||||
" </port>"
|
||||
" </output>"
|
||||
" <weights offset=\"0\" size=\"9408\"/>"
|
||||
" <biases offset=\"9408\" size=\"64\"/>"
|
||||
" </layer>"
|
||||
" <layer name=\"conv1_1_neg\" type=\"Power\" precision=\"FP32\" id=\"3\">"
|
||||
" <power_data power=\"1\" scale=\"-1\" shift=\"0\"/>"
|
||||
" <input>"
|
||||
" <port id=\"4\">"
|
||||
" <dim>1</dim>"
|
||||
" <dim>16</dim>"
|
||||
" <dim>272</dim>"
|
||||
" <dim>496</dim>"
|
||||
" </port>"
|
||||
" </input>"
|
||||
" <output>"
|
||||
" <port id=\"5\">"
|
||||
" <dim>1</dim>"
|
||||
" <dim>16</dim>"
|
||||
" <dim>272</dim>"
|
||||
" <dim>496</dim>"
|
||||
" </port>"
|
||||
" </output>"
|
||||
" </layer>"
|
||||
" <layer name=\"conv1_1_concat\" type=\"Concat\" precision=\"FP32\" id=\"4\">"
|
||||
" <concat_data axis=\"1\"/>"
|
||||
" <input>"
|
||||
" <port id=\"6\">"
|
||||
" <dim>1</dim>"
|
||||
" <dim>16</dim>"
|
||||
" <dim>272</dim>"
|
||||
" <dim>496</dim>"
|
||||
" </port>"
|
||||
" <port id=\"7\">"
|
||||
" <dim>1</dim>"
|
||||
" <dim>16</dim>"
|
||||
" <dim>272</dim>"
|
||||
" <dim>496</dim>"
|
||||
" </port>"
|
||||
" </input>"
|
||||
" <output>"
|
||||
" <port id=\"8\">"
|
||||
" <dim>1</dim>"
|
||||
" <dim>32</dim>"
|
||||
" <dim>272</dim>"
|
||||
" <dim>496</dim>"
|
||||
" </port>"
|
||||
" </output>"
|
||||
" </layer>"
|
||||
" <layer name=\"conv1_1_scale\" type=\"ScaleShift\" precision=\"FP32\" id=\"5\">"
|
||||
" <input>"
|
||||
" <port id=\"9\">"
|
||||
" <dim>1</dim>"
|
||||
" <dim>32</dim>"
|
||||
" <dim>272</dim>"
|
||||
" <dim>496</dim>"
|
||||
" </port>"
|
||||
" </input>"
|
||||
" <output>"
|
||||
" <port id=\"10\">"
|
||||
" <dim>1</dim>"
|
||||
" <dim>32</dim>"
|
||||
" <dim>272</dim>"
|
||||
" <dim>496</dim>"
|
||||
" </port>"
|
||||
" </output>"
|
||||
" <weights offset=\"9472\" size=\"128\"/>"
|
||||
" <biases offset=\"9600\" size=\"128\"/>"
|
||||
" </layer>"
|
||||
" <layer name=\"conv1_1_relu\" type=\"ReLU\" precision=\"FP32\" id=\"6\">"
|
||||
" <data negative_slope=\"0\" engine=\"caffe.ReLUParameter.DEFAULT\"/>"
|
||||
" <input>"
|
||||
" <port id=\"11\">"
|
||||
" <dim>1</dim>"
|
||||
" <dim>32</dim>"
|
||||
" <dim>272</dim>"
|
||||
" <dim>496</dim>"
|
||||
" </port>"
|
||||
" </input>"
|
||||
" <output>"
|
||||
" <port id=\"12\">"
|
||||
" <dim>1</dim>"
|
||||
" <dim>32</dim>"
|
||||
" <dim>272</dim>"
|
||||
" <dim>496</dim>"
|
||||
" </port>"
|
||||
" </output>"
|
||||
" </layer>"
|
||||
" <layer name=\"pool1\" type=\"Pooling\" precision=\"FP32\" id=\"7\">"
|
||||
" <pooling_data kernel-x=\"3\" kernel-y=\"3\" pad-x=\"0\" pad-y=\"0\" stride-x=\"2\" stride-y=\"2\" rounding-type=\"ceil\" pool-method=\"max\"/>"
|
||||
" <input>"
|
||||
" <port id=\"13\">"
|
||||
" <dim>1</dim>"
|
||||
" <dim>32</dim>"
|
||||
" <dim>272</dim>"
|
||||
" <dim>496</dim>"
|
||||
" </port>"
|
||||
" </input>"
|
||||
" <output>"
|
||||
" <port id=\"14\">"
|
||||
" <dim>1</dim>"
|
||||
" <dim>32</dim>"
|
||||
" <dim>136</dim>"
|
||||
" <dim>248</dim>"
|
||||
" </port>"
|
||||
" </output>"
|
||||
" </layer>"
|
||||
" </layers>"
|
||||
" <edges>"
|
||||
" <edge from-layer=\"0\" from-port=\"0\" to-layer=\"2\" to-port=\"2\"/>"
|
||||
" <edge from-layer=\"2\" from-port=\"3\" to-layer=\"3\" to-port=\"4\"/>"
|
||||
" <edge from-layer=\"2\" from-port=\"3\" to-layer=\"4\" to-port=\"6\"/>"
|
||||
" <edge from-layer=\"3\" from-port=\"5\" to-layer=\"4\" to-port=\"7\"/>"
|
||||
" <edge from-layer=\"4\" from-port=\"8\" to-layer=\"5\" to-port=\"9\"/>"
|
||||
" <edge from-layer=\"5\" from-port=\"10\" to-layer=\"6\" to-port=\"11\"/>"
|
||||
" <edge from-layer=\"6\" from-port=\"12\" to-layer=\"7\" to-port=\"13\"/>"
|
||||
" </edges>"
|
||||
"</net>";
|
||||
auto parserCreator = make_shared<MockFormatParserCreator>();
|
||||
CNNNetReaderImpl reader(parserCreator);
|
||||
auto network = make_shared<MockCNNNetworkImpl>();
|
||||
auto name = std::string{"AlexNet"};
|
||||
|
||||
EXPECT_CALL(*parserCreator->getParser(), Parse(_)).Times(1).WillOnce(Return(network));
|
||||
EXPECT_CALL(*network.get(), validate(_)).Times(1);
|
||||
EXPECT_CALL(*network.get(), getName()).Times(1).WillOnce(ReturnRef(name));
|
||||
|
||||
ASSERT_NO_THROW(sts = reader.ReadNetwork(model.data(), model.length(), &resp));
|
||||
ASSERT_EQ(OK, sts);
|
||||
}
|
||||
|
||||
TEST_F(CNNNetReaderImplTest, cycleIsDetectedInReader) {
|
||||
std::string model =
|
||||
"<net batch=\"1\" name=\"model\" version=\"2\">"
|
||||
|
Loading…
Reference in New Issue
Block a user