Enabled Unit tests and remove IReaderPtr (#653)
* Enabled Unit tests and remove IReaderPtr * Fixed unicode tests for Windows * Fixed typo
This commit is contained in:
parent
5f6999ed7e
commit
e51e1682ca
@ -5,33 +5,28 @@
|
||||
#include "ie_core.hpp"
|
||||
|
||||
#include <unordered_set>
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <streambuf>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <istream>
|
||||
#include <mutex>
|
||||
|
||||
#include "ie_blob_stream.hpp"
|
||||
#include <ie_reader_ptr.hpp>
|
||||
#include <ngraph/opsets/opset.hpp>
|
||||
#include "cpp/ie_cnn_net_reader.h"
|
||||
#include "cpp/ie_plugin_cpp.hpp"
|
||||
#include "cpp_interfaces/base/ie_plugin_base.hpp"
|
||||
#include "details/ie_exception_conversion.hpp"
|
||||
#include "details/ie_so_pointer.hpp"
|
||||
#include "file_utils.h"
|
||||
#include "ie_icore.hpp"
|
||||
#include "ie_plugin.hpp"
|
||||
#include "ie_plugin_config.hpp"
|
||||
#include "ie_profiling.hpp"
|
||||
#include "ie_util_internal.hpp"
|
||||
#include "ie_network_reader.hpp"
|
||||
#include "multi-device/multi_device_config.hpp"
|
||||
#include "xml_parse_utils.h"
|
||||
|
||||
@ -133,79 +128,6 @@ Parameter copyParameterValue(const Parameter & value) {
|
||||
|
||||
} // namespace
|
||||
|
||||
class Reader: public IReader {
|
||||
private:
|
||||
InferenceEngine::IReaderPtr ptr;
|
||||
std::once_flag readFlag;
|
||||
std::string name;
|
||||
std::string location;
|
||||
|
||||
InferenceEngine::IReaderPtr getReaderPtr() {
|
||||
std::call_once(readFlag, [&] () {
|
||||
FileUtils::FilePath libraryName = FileUtils::toFilePath(location);
|
||||
FileUtils::FilePath readersLibraryPath = FileUtils::makeSharedLibraryName(getInferenceEngineLibraryPath(), libraryName);
|
||||
|
||||
if (!FileUtils::fileExist(readersLibraryPath)) {
|
||||
THROW_IE_EXCEPTION << "Please, make sure that Inference Engine ONNX reader library "
|
||||
<< FileUtils::fromFilePath(::FileUtils::makeSharedLibraryName({}, libraryName)) << " is in "
|
||||
<< getIELibraryPath();
|
||||
}
|
||||
ptr = IReaderPtr(readersLibraryPath);
|
||||
});
|
||||
|
||||
return ptr;
|
||||
}
|
||||
|
||||
InferenceEngine::IReaderPtr getReaderPtr() const {
|
||||
return const_cast<Reader*>(this)->getReaderPtr();
|
||||
}
|
||||
|
||||
void Release() noexcept override {
|
||||
delete this;
|
||||
}
|
||||
|
||||
public:
|
||||
using Ptr = std::shared_ptr<Reader>;
|
||||
Reader(const std::string& name, const std::string location): name(name), location(location) {}
|
||||
bool supportModel(std::istream& model) const override {
|
||||
auto reader = getReaderPtr();
|
||||
return reader->supportModel(model);
|
||||
}
|
||||
CNNNetwork read(std::istream& model, const std::vector<IExtensionPtr>& exts) const override {
|
||||
auto reader = getReaderPtr();
|
||||
return reader->read(model, exts);
|
||||
}
|
||||
CNNNetwork read(std::istream& model, std::istream& weights, const std::vector<IExtensionPtr>& exts) const override {
|
||||
auto reader = getReaderPtr();
|
||||
return reader->read(model, weights, exts);
|
||||
}
|
||||
std::vector<std::string> getDataFileExtensions() const override {
|
||||
auto reader = getReaderPtr();
|
||||
return reader->getDataFileExtensions();
|
||||
}
|
||||
std::string getName() const {
|
||||
return name;
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
// Extension to plugins creator
|
||||
std::multimap<std::string, Reader::Ptr> readers;
|
||||
|
||||
void registerReaders() {
|
||||
static std::mutex readerMutex;
|
||||
std::lock_guard<std::mutex> lock(readerMutex);
|
||||
// TODO: Read readers info from XML
|
||||
auto onnxReader = std::make_shared<Reader>("ONNX", std::string("inference_engine_onnx_reader") + std::string(IE_BUILD_POSTFIX));
|
||||
readers.emplace("onnx", onnxReader);
|
||||
readers.emplace("prototxt", onnxReader);
|
||||
auto irReader = std::make_shared<Reader>("IR", std::string("inference_engine_ir_reader") + std::string(IE_BUILD_POSTFIX));
|
||||
readers.emplace("xml", irReader);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
CNNNetReaderPtr CreateCNNNetReaderPtr() noexcept {
|
||||
auto loader = createCnnReaderLoader();
|
||||
return CNNNetReaderPtr(loader);
|
||||
@ -374,57 +296,12 @@ public:
|
||||
|
||||
CNNNetwork ReadNetwork(const std::string& modelPath, const std::string& binPath) const override {
|
||||
IE_PROFILING_AUTO_SCOPE(Core::ReadNetwork)
|
||||
|
||||
std::ifstream modelStream(modelPath, std::ios::binary);
|
||||
if (!modelStream.is_open())
|
||||
THROW_IE_EXCEPTION << "Model file " << modelPath << " cannot be opened!";
|
||||
|
||||
auto fileExt = modelPath.substr(modelPath.find_last_of(".") + 1);
|
||||
for (auto it = readers.lower_bound(fileExt); it != readers.upper_bound(fileExt); it++) {
|
||||
auto reader = it->second;
|
||||
if (reader->supportModel(modelStream)) {
|
||||
// Find weights
|
||||
std::string bPath = binPath;
|
||||
if (bPath.empty()) {
|
||||
auto pathWoExt = modelPath;
|
||||
auto pos = modelPath.rfind('.');
|
||||
if (pos != std::string::npos) pathWoExt = modelPath.substr(0, pos);
|
||||
for (const auto& ext : reader->getDataFileExtensions()) {
|
||||
bPath = pathWoExt + "." + ext;
|
||||
if (!FileUtils::fileExist(bPath)) {
|
||||
bPath.clear();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!bPath.empty()) {
|
||||
std::ifstream binStream;
|
||||
binStream.open(bPath, std::ios::binary);
|
||||
if (!binStream.is_open())
|
||||
THROW_IE_EXCEPTION << "Weights file " << bPath << " cannot be opened!";
|
||||
return reader->read(modelStream, binStream, extensions);
|
||||
}
|
||||
return reader->read(modelStream, extensions);
|
||||
}
|
||||
}
|
||||
THROW_IE_EXCEPTION << "Unknown model format! Cannot read the model: " << modelPath;
|
||||
return details::ReadNetwork(modelPath, binPath, extensions);
|
||||
}
|
||||
|
||||
CNNNetwork ReadNetwork(const std::string& model, const Blob::CPtr& weights) const override {
|
||||
IE_PROFILING_AUTO_SCOPE(Core::ReadNetwork)
|
||||
std::istringstream modelStream(model);
|
||||
details::BlobStream binStream(weights);
|
||||
|
||||
for (auto it = readers.begin(); it != readers.end(); it++) {
|
||||
auto reader = it->second;
|
||||
if (reader->supportModel(modelStream)) {
|
||||
if (weights)
|
||||
return reader->read(modelStream, binStream, extensions);
|
||||
return reader->read(modelStream, extensions);
|
||||
}
|
||||
}
|
||||
THROW_IE_EXCEPTION << "Unknown model format! Cannot read the model from string!";
|
||||
return details::ReadNetwork(model, weights, extensions);
|
||||
}
|
||||
|
||||
ExecutableNetwork LoadNetwork(const CNNNetwork& network, const std::string& deviceName,
|
||||
@ -704,7 +581,6 @@ Core::Impl::Impl() {
|
||||
opsetNames.insert("opset1");
|
||||
opsetNames.insert("opset2");
|
||||
opsetNames.insert("opset3");
|
||||
registerReaders();
|
||||
}
|
||||
|
||||
Core::Impl::~Impl() {}
|
||||
|
193
inference-engine/src/inference_engine/ie_network_reader.cpp
Normal file
193
inference-engine/src/inference_engine/ie_network_reader.cpp
Normal file
@ -0,0 +1,193 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ie_network_reader.hpp"
|
||||
|
||||
#include <details/ie_so_pointer.hpp>
|
||||
#include <file_utils.h>
|
||||
#include <ie_blob_stream.hpp>
|
||||
#include <ie_profiling.hpp>
|
||||
#include <ie_reader.hpp>
|
||||
|
||||
#include <fstream>
|
||||
#include <istream>
|
||||
#include <mutex>
|
||||
#include <map>
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
namespace details {
|
||||
|
||||
/**
|
||||
* @brief This class defines the name of the fabric for creating an IReader object in DLL
|
||||
*/
|
||||
template <>
|
||||
class SOCreatorTrait<IReader> {
|
||||
public:
|
||||
/**
|
||||
* @brief A name of the fabric for creating IReader object in DLL
|
||||
*/
|
||||
static constexpr auto name = "CreateReader";
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
|
||||
/**
|
||||
* @brief This class is a wrapper for reader interfaces
|
||||
*/
|
||||
class Reader: public IReader {
|
||||
private:
|
||||
InferenceEngine::details::SOPointer<IReader> ptr;
|
||||
std::once_flag readFlag;
|
||||
std::string name;
|
||||
std::string location;
|
||||
|
||||
InferenceEngine::details::SOPointer<IReader> getReaderPtr() {
|
||||
std::call_once(readFlag, [&] () {
|
||||
FileUtils::FilePath libraryName = FileUtils::toFilePath(location);
|
||||
FileUtils::FilePath readersLibraryPath = FileUtils::makeSharedLibraryName(getInferenceEngineLibraryPath(), libraryName);
|
||||
|
||||
if (!FileUtils::fileExist(readersLibraryPath)) {
|
||||
THROW_IE_EXCEPTION << "Please, make sure that Inference Engine ONNX reader library "
|
||||
<< FileUtils::fromFilePath(::FileUtils::makeSharedLibraryName({}, libraryName)) << " is in "
|
||||
<< getIELibraryPath();
|
||||
}
|
||||
ptr = InferenceEngine::details::SOPointer<IReader>(readersLibraryPath);
|
||||
});
|
||||
|
||||
return ptr;
|
||||
}
|
||||
|
||||
InferenceEngine::details::SOPointer<IReader> getReaderPtr() const {
|
||||
return const_cast<Reader*>(this)->getReaderPtr();
|
||||
}
|
||||
|
||||
void Release() noexcept override {
|
||||
delete this;
|
||||
}
|
||||
|
||||
public:
|
||||
using Ptr = std::shared_ptr<Reader>;
|
||||
Reader(const std::string& name, const std::string location): name(name), location(location) {}
|
||||
bool supportModel(std::istream& model) const override {
|
||||
auto reader = getReaderPtr();
|
||||
return reader->supportModel(model);
|
||||
}
|
||||
CNNNetwork read(std::istream& model, const std::vector<IExtensionPtr>& exts) const override {
|
||||
auto reader = getReaderPtr();
|
||||
return reader->read(model, exts);
|
||||
}
|
||||
CNNNetwork read(std::istream& model, std::istream& weights, const std::vector<IExtensionPtr>& exts) const override {
|
||||
auto reader = getReaderPtr();
|
||||
return reader->read(model, weights, exts);
|
||||
}
|
||||
std::vector<std::string> getDataFileExtensions() const override {
|
||||
auto reader = getReaderPtr();
|
||||
return reader->getDataFileExtensions();
|
||||
}
|
||||
std::string getName() const {
|
||||
return name;
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
// Extension to plugins creator
|
||||
std::multimap<std::string, Reader::Ptr> readers;
|
||||
|
||||
void registerReaders() {
|
||||
IE_PROFILING_AUTO_SCOPE(details::registerReaders)
|
||||
static bool initialized = false;
|
||||
static std::mutex readerMutex;
|
||||
std::lock_guard<std::mutex> lock(readerMutex);
|
||||
if (initialized) return;
|
||||
// TODO: Read readers info from XML
|
||||
auto onnxReader = std::make_shared<Reader>("ONNX", std::string("inference_engine_onnx_reader") + std::string(IE_BUILD_POSTFIX));
|
||||
readers.emplace("onnx", onnxReader);
|
||||
readers.emplace("prototxt", onnxReader);
|
||||
auto irReader = std::make_shared<Reader>("IR", std::string("inference_engine_ir_reader") + std::string(IE_BUILD_POSTFIX));
|
||||
readers.emplace("xml", irReader);
|
||||
initialized = true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
CNNNetwork details::ReadNetwork(const std::string& modelPath, const std::string& binPath, const std::vector<IExtensionPtr>& exts) {
|
||||
IE_PROFILING_AUTO_SCOPE(details::ReadNetwork)
|
||||
// Register readers if it is needed
|
||||
registerReaders();
|
||||
|
||||
// Fix unicode name
|
||||
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
|
||||
std::wstring model_path = InferenceEngine::details::multiByteCharToWString(modelPath.c_str());
|
||||
#else
|
||||
std::string model_path = modelPath;
|
||||
#endif
|
||||
// Try to open model file
|
||||
std::ifstream modelStream(model_path, std::ios::binary);
|
||||
if (!modelStream.is_open())
|
||||
THROW_IE_EXCEPTION << "Model file " << modelPath << " cannot be opened!";
|
||||
|
||||
// Find reader for model extension
|
||||
auto fileExt = modelPath.substr(modelPath.find_last_of(".") + 1);
|
||||
for (auto it = readers.lower_bound(fileExt); it != readers.upper_bound(fileExt); it++) {
|
||||
auto reader = it->second;
|
||||
// Check that reader supports the model
|
||||
if (reader->supportModel(modelStream)) {
|
||||
// Find weights
|
||||
std::string bPath = binPath;
|
||||
if (bPath.empty()) {
|
||||
auto pathWoExt = modelPath;
|
||||
auto pos = modelPath.rfind('.');
|
||||
if (pos != std::string::npos) pathWoExt = modelPath.substr(0, pos);
|
||||
for (const auto& ext : reader->getDataFileExtensions()) {
|
||||
bPath = pathWoExt + "." + ext;
|
||||
if (!FileUtils::fileExist(bPath)) {
|
||||
bPath.clear();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!bPath.empty()) {
|
||||
// Open weights file
|
||||
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
|
||||
std::wstring weights_path = InferenceEngine::details::multiByteCharToWString(bPath.c_str());
|
||||
#else
|
||||
std::string weights_path = bPath;
|
||||
#endif
|
||||
std::ifstream binStream;
|
||||
binStream.open(weights_path, std::ios::binary);
|
||||
if (!binStream.is_open())
|
||||
THROW_IE_EXCEPTION << "Weights file " << bPath << " cannot be opened!";
|
||||
|
||||
// read model with weights
|
||||
return reader->read(modelStream, binStream, exts);
|
||||
}
|
||||
// read model without weights
|
||||
return reader->read(modelStream, exts);
|
||||
}
|
||||
}
|
||||
THROW_IE_EXCEPTION << "Unknown model format! Cannot read the model: " << modelPath;
|
||||
}
|
||||
|
||||
CNNNetwork details::ReadNetwork(const std::string& model, const Blob::CPtr& weights, const std::vector<IExtensionPtr>& exts) {
|
||||
IE_PROFILING_AUTO_SCOPE(details::ReadNetwork)
|
||||
// Register readers if it is needed
|
||||
registerReaders();
|
||||
std::istringstream modelStream(model);
|
||||
details::BlobStream binStream(weights);
|
||||
|
||||
for (auto it = readers.begin(); it != readers.end(); it++) {
|
||||
auto reader = it->second;
|
||||
if (reader->supportModel(modelStream)) {
|
||||
if (weights)
|
||||
return reader->read(modelStream, binStream, exts);
|
||||
return reader->read(modelStream, exts);
|
||||
}
|
||||
}
|
||||
THROW_IE_EXCEPTION << "Unknown model format! Cannot read the model from string!";
|
||||
}
|
||||
|
||||
} // namespace InferenceEngine
|
33
inference-engine/src/inference_engine/ie_network_reader.hpp
Normal file
33
inference-engine/src/inference_engine/ie_network_reader.hpp
Normal file
@ -0,0 +1,33 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cpp/ie_cnn_network.h>
|
||||
#include <ie_blob.h>
|
||||
#include <string>
|
||||
|
||||
namespace InferenceEngine {
|
||||
namespace details {
|
||||
|
||||
/**
|
||||
* @brief Reads IR xml and bin files
|
||||
* @param modelPath path to IR file
|
||||
* @param binPath path to bin file, if path is empty, will try to read bin file with the same name as xml and
|
||||
* if bin file with the same name was not found, will load IR without weights.
|
||||
* @param exts vector with extensions
|
||||
* @return CNNNetwork
|
||||
*/
|
||||
CNNNetwork ReadNetwork(const std::string& modelPath, const std::string& binPath, const std::vector<IExtensionPtr>& exts);
|
||||
/**
|
||||
* @brief Reads IR xml and bin (with the same name) files
|
||||
* @param model string with IR
|
||||
* @param weights shared pointer to constant blob with weights
|
||||
* @param exts vector with extensions
|
||||
* @return CNNNetwork
|
||||
*/
|
||||
CNNNetwork ReadNetwork(const std::string& model, const Blob::CPtr& weights, const std::vector<IExtensionPtr>& exts);
|
||||
|
||||
} // namespace details
|
||||
} // namespace InferenceEngine
|
@ -1,36 +0,0 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include <details/ie_so_pointer.hpp>
|
||||
#include "ie_reader.hpp"
|
||||
|
||||
namespace InferenceEngine {
|
||||
namespace details {
|
||||
|
||||
/**
|
||||
* @brief This class defines the name of the fabric for creating an IReader object in DLL
|
||||
*/
|
||||
template <>
|
||||
class SOCreatorTrait<IReader> {
|
||||
public:
|
||||
/**
|
||||
* @brief A name of the fabric for creating IReader object in DLL
|
||||
*/
|
||||
static constexpr auto name = "CreateReader";
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
|
||||
/**
|
||||
* @brief A C++ helper to work with objects created by the plugin.
|
||||
*
|
||||
* Implements different interfaces.
|
||||
*/
|
||||
using IReaderPtr = InferenceEngine::details::SOPointer<IReader>;
|
||||
|
||||
} // namespace InferenceEngine
|
@ -107,7 +107,7 @@ TEST_P(NetReaderTest, ReadNetworkTwiceSeparately) {
|
||||
|
||||
#ifdef ENABLE_UNICODE_PATH_SUPPORT
|
||||
|
||||
TEST_P(NetReaderTest, DISABLED_ReadCorrectModelWithWeightsUnicodePath) {
|
||||
TEST_P(NetReaderTest, ReadCorrectModelWithWeightsUnicodePath) {
|
||||
GTEST_COUT << "params.modelPath: '" << _modelPath << "'" << std::endl;
|
||||
GTEST_COUT << "params.weightsPath: '" << _weightsPath << "'" << std::endl;
|
||||
GTEST_COUT << "params.netPrc: '" << _netPrc.name() << "'" << std::endl;
|
||||
|
Loading…
Reference in New Issue
Block a user