Enabled Unit tests and remove IReaderPtr (#653)

* Enabled Unit tests and remove IReaderPtr

* Fixed unicode tests for Windows

* Fixed typo
This commit is contained in:
Ilya Churaev 2020-05-28 22:40:20 +03:00 committed by GitHub
parent 5f6999ed7e
commit e51e1682ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 230 additions and 164 deletions

View File

@ -5,33 +5,28 @@
#include "ie_core.hpp"
#include <unordered_set>
#include <fstream>
#include <functional>
#include <limits>
#include <map>
#include <memory>
#include <sstream>
#include <streambuf>
#include <string>
#include <utility>
#include <vector>
#include <istream>
#include <mutex>
#include "ie_blob_stream.hpp"
#include <ie_reader_ptr.hpp>
#include <ngraph/opsets/opset.hpp>
#include "cpp/ie_cnn_net_reader.h"
#include "cpp/ie_plugin_cpp.hpp"
#include "cpp_interfaces/base/ie_plugin_base.hpp"
#include "details/ie_exception_conversion.hpp"
#include "details/ie_so_pointer.hpp"
#include "file_utils.h"
#include "ie_icore.hpp"
#include "ie_plugin.hpp"
#include "ie_plugin_config.hpp"
#include "ie_profiling.hpp"
#include "ie_util_internal.hpp"
#include "ie_network_reader.hpp"
#include "multi-device/multi_device_config.hpp"
#include "xml_parse_utils.h"
@ -133,79 +128,6 @@ Parameter copyParameterValue(const Parameter & value) {
} // namespace
class Reader: public IReader {
private:
InferenceEngine::IReaderPtr ptr;
std::once_flag readFlag;
std::string name;
std::string location;
InferenceEngine::IReaderPtr getReaderPtr() {
std::call_once(readFlag, [&] () {
FileUtils::FilePath libraryName = FileUtils::toFilePath(location);
FileUtils::FilePath readersLibraryPath = FileUtils::makeSharedLibraryName(getInferenceEngineLibraryPath(), libraryName);
if (!FileUtils::fileExist(readersLibraryPath)) {
THROW_IE_EXCEPTION << "Please, make sure that Inference Engine ONNX reader library "
<< FileUtils::fromFilePath(::FileUtils::makeSharedLibraryName({}, libraryName)) << " is in "
<< getIELibraryPath();
}
ptr = IReaderPtr(readersLibraryPath);
});
return ptr;
}
InferenceEngine::IReaderPtr getReaderPtr() const {
return const_cast<Reader*>(this)->getReaderPtr();
}
void Release() noexcept override {
delete this;
}
public:
using Ptr = std::shared_ptr<Reader>;
Reader(const std::string& name, const std::string location): name(name), location(location) {}
bool supportModel(std::istream& model) const override {
auto reader = getReaderPtr();
return reader->supportModel(model);
}
CNNNetwork read(std::istream& model, const std::vector<IExtensionPtr>& exts) const override {
auto reader = getReaderPtr();
return reader->read(model, exts);
}
CNNNetwork read(std::istream& model, std::istream& weights, const std::vector<IExtensionPtr>& exts) const override {
auto reader = getReaderPtr();
return reader->read(model, weights, exts);
}
std::vector<std::string> getDataFileExtensions() const override {
auto reader = getReaderPtr();
return reader->getDataFileExtensions();
}
std::string getName() const {
return name;
}
};
namespace {
// Extension to plugins creator
std::multimap<std::string, Reader::Ptr> readers;
void registerReaders() {
static std::mutex readerMutex;
std::lock_guard<std::mutex> lock(readerMutex);
// TODO: Read readers info from XML
auto onnxReader = std::make_shared<Reader>("ONNX", std::string("inference_engine_onnx_reader") + std::string(IE_BUILD_POSTFIX));
readers.emplace("onnx", onnxReader);
readers.emplace("prototxt", onnxReader);
auto irReader = std::make_shared<Reader>("IR", std::string("inference_engine_ir_reader") + std::string(IE_BUILD_POSTFIX));
readers.emplace("xml", irReader);
}
} // namespace
CNNNetReaderPtr CreateCNNNetReaderPtr() noexcept {
auto loader = createCnnReaderLoader();
return CNNNetReaderPtr(loader);
@ -374,57 +296,12 @@ public:
CNNNetwork ReadNetwork(const std::string& modelPath, const std::string& binPath) const override {
IE_PROFILING_AUTO_SCOPE(Core::ReadNetwork)
std::ifstream modelStream(modelPath, std::ios::binary);
if (!modelStream.is_open())
THROW_IE_EXCEPTION << "Model file " << modelPath << " cannot be opened!";
auto fileExt = modelPath.substr(modelPath.find_last_of(".") + 1);
for (auto it = readers.lower_bound(fileExt); it != readers.upper_bound(fileExt); it++) {
auto reader = it->second;
if (reader->supportModel(modelStream)) {
// Find weights
std::string bPath = binPath;
if (bPath.empty()) {
auto pathWoExt = modelPath;
auto pos = modelPath.rfind('.');
if (pos != std::string::npos) pathWoExt = modelPath.substr(0, pos);
for (const auto& ext : reader->getDataFileExtensions()) {
bPath = pathWoExt + "." + ext;
if (!FileUtils::fileExist(bPath)) {
bPath.clear();
} else {
break;
}
}
}
if (!bPath.empty()) {
std::ifstream binStream;
binStream.open(bPath, std::ios::binary);
if (!binStream.is_open())
THROW_IE_EXCEPTION << "Weights file " << bPath << " cannot be opened!";
return reader->read(modelStream, binStream, extensions);
}
return reader->read(modelStream, extensions);
}
}
THROW_IE_EXCEPTION << "Unknown model format! Cannot read the model: " << modelPath;
return details::ReadNetwork(modelPath, binPath, extensions);
}
CNNNetwork ReadNetwork(const std::string& model, const Blob::CPtr& weights) const override {
IE_PROFILING_AUTO_SCOPE(Core::ReadNetwork)
std::istringstream modelStream(model);
details::BlobStream binStream(weights);
for (auto it = readers.begin(); it != readers.end(); it++) {
auto reader = it->second;
if (reader->supportModel(modelStream)) {
if (weights)
return reader->read(modelStream, binStream, extensions);
return reader->read(modelStream, extensions);
}
}
THROW_IE_EXCEPTION << "Unknown model format! Cannot read the model from string!";
return details::ReadNetwork(model, weights, extensions);
}
ExecutableNetwork LoadNetwork(const CNNNetwork& network, const std::string& deviceName,
@ -704,7 +581,6 @@ Core::Impl::Impl() {
opsetNames.insert("opset1");
opsetNames.insert("opset2");
opsetNames.insert("opset3");
registerReaders();
}
Core::Impl::~Impl() {}

View File

@ -0,0 +1,193 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ie_network_reader.hpp"
#include <details/ie_so_pointer.hpp>
#include <file_utils.h>
#include <ie_blob_stream.hpp>
#include <ie_profiling.hpp>
#include <ie_reader.hpp>
#include <fstream>
#include <istream>
#include <mutex>
#include <map>
namespace InferenceEngine {
namespace details {
/**
* @brief This class defines the name of the fabric for creating an IReader object in DLL
*/
template <>
class SOCreatorTrait<IReader> {
public:
/**
* @brief A name of the fabric for creating IReader object in DLL
*/
static constexpr auto name = "CreateReader";
};
} // namespace details
/**
* @brief This class is a wrapper for reader interfaces
*/
class Reader: public IReader {
private:
InferenceEngine::details::SOPointer<IReader> ptr;
std::once_flag readFlag;
std::string name;
std::string location;
InferenceEngine::details::SOPointer<IReader> getReaderPtr() {
std::call_once(readFlag, [&] () {
FileUtils::FilePath libraryName = FileUtils::toFilePath(location);
FileUtils::FilePath readersLibraryPath = FileUtils::makeSharedLibraryName(getInferenceEngineLibraryPath(), libraryName);
if (!FileUtils::fileExist(readersLibraryPath)) {
THROW_IE_EXCEPTION << "Please, make sure that Inference Engine ONNX reader library "
<< FileUtils::fromFilePath(::FileUtils::makeSharedLibraryName({}, libraryName)) << " is in "
<< getIELibraryPath();
}
ptr = InferenceEngine::details::SOPointer<IReader>(readersLibraryPath);
});
return ptr;
}
InferenceEngine::details::SOPointer<IReader> getReaderPtr() const {
return const_cast<Reader*>(this)->getReaderPtr();
}
void Release() noexcept override {
delete this;
}
public:
using Ptr = std::shared_ptr<Reader>;
Reader(const std::string& name, const std::string location): name(name), location(location) {}
bool supportModel(std::istream& model) const override {
auto reader = getReaderPtr();
return reader->supportModel(model);
}
CNNNetwork read(std::istream& model, const std::vector<IExtensionPtr>& exts) const override {
auto reader = getReaderPtr();
return reader->read(model, exts);
}
CNNNetwork read(std::istream& model, std::istream& weights, const std::vector<IExtensionPtr>& exts) const override {
auto reader = getReaderPtr();
return reader->read(model, weights, exts);
}
std::vector<std::string> getDataFileExtensions() const override {
auto reader = getReaderPtr();
return reader->getDataFileExtensions();
}
std::string getName() const {
return name;
}
};
namespace {
// Extension to plugins creator
std::multimap<std::string, Reader::Ptr> readers;
void registerReaders() {
IE_PROFILING_AUTO_SCOPE(details::registerReaders)
static bool initialized = false;
static std::mutex readerMutex;
std::lock_guard<std::mutex> lock(readerMutex);
if (initialized) return;
// TODO: Read readers info from XML
auto onnxReader = std::make_shared<Reader>("ONNX", std::string("inference_engine_onnx_reader") + std::string(IE_BUILD_POSTFIX));
readers.emplace("onnx", onnxReader);
readers.emplace("prototxt", onnxReader);
auto irReader = std::make_shared<Reader>("IR", std::string("inference_engine_ir_reader") + std::string(IE_BUILD_POSTFIX));
readers.emplace("xml", irReader);
initialized = true;
}
} // namespace
CNNNetwork details::ReadNetwork(const std::string& modelPath, const std::string& binPath, const std::vector<IExtensionPtr>& exts) {
IE_PROFILING_AUTO_SCOPE(details::ReadNetwork)
// Register readers if it is needed
registerReaders();
// Fix unicode name
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
std::wstring model_path = InferenceEngine::details::multiByteCharToWString(modelPath.c_str());
#else
std::string model_path = modelPath;
#endif
// Try to open model file
std::ifstream modelStream(model_path, std::ios::binary);
if (!modelStream.is_open())
THROW_IE_EXCEPTION << "Model file " << modelPath << " cannot be opened!";
// Find reader for model extension
auto fileExt = modelPath.substr(modelPath.find_last_of(".") + 1);
for (auto it = readers.lower_bound(fileExt); it != readers.upper_bound(fileExt); it++) {
auto reader = it->second;
// Check that reader supports the model
if (reader->supportModel(modelStream)) {
// Find weights
std::string bPath = binPath;
if (bPath.empty()) {
auto pathWoExt = modelPath;
auto pos = modelPath.rfind('.');
if (pos != std::string::npos) pathWoExt = modelPath.substr(0, pos);
for (const auto& ext : reader->getDataFileExtensions()) {
bPath = pathWoExt + "." + ext;
if (!FileUtils::fileExist(bPath)) {
bPath.clear();
} else {
break;
}
}
}
if (!bPath.empty()) {
// Open weights file
#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
std::wstring weights_path = InferenceEngine::details::multiByteCharToWString(bPath.c_str());
#else
std::string weights_path = bPath;
#endif
std::ifstream binStream;
binStream.open(weights_path, std::ios::binary);
if (!binStream.is_open())
THROW_IE_EXCEPTION << "Weights file " << bPath << " cannot be opened!";
// read model with weights
return reader->read(modelStream, binStream, exts);
}
// read model without weights
return reader->read(modelStream, exts);
}
}
THROW_IE_EXCEPTION << "Unknown model format! Cannot read the model: " << modelPath;
}
CNNNetwork details::ReadNetwork(const std::string& model, const Blob::CPtr& weights, const std::vector<IExtensionPtr>& exts) {
IE_PROFILING_AUTO_SCOPE(details::ReadNetwork)
// Register readers if it is needed
registerReaders();
std::istringstream modelStream(model);
details::BlobStream binStream(weights);
for (auto it = readers.begin(); it != readers.end(); it++) {
auto reader = it->second;
if (reader->supportModel(modelStream)) {
if (weights)
return reader->read(modelStream, binStream, exts);
return reader->read(modelStream, exts);
}
}
THROW_IE_EXCEPTION << "Unknown model format! Cannot read the model from string!";
}
} // namespace InferenceEngine

View File

@ -0,0 +1,33 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <cpp/ie_cnn_network.h>
#include <ie_blob.h>
#include <string>
namespace InferenceEngine {
namespace details {
/**
* @brief Reads IR xml and bin files
* @param modelPath path to IR file
* @param binPath path to bin file, if path is empty, will try to read bin file with the same name as xml and
* if bin file with the same name was not found, will load IR without weights.
* @param exts vector with extensions
* @return CNNNetwork
*/
CNNNetwork ReadNetwork(const std::string& modelPath, const std::string& binPath, const std::vector<IExtensionPtr>& exts);
/**
* @brief Reads IR xml and bin (with the same name) files
* @param model string with IR
* @param weights shared pointer to constant blob with weights
* @param exts vector with extensions
* @return CNNNetwork
*/
CNNNetwork ReadNetwork(const std::string& model, const Blob::CPtr& weights, const std::vector<IExtensionPtr>& exts);
} // namespace details
} // namespace InferenceEngine

View File

@ -1,36 +0,0 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <string>
#include <details/ie_so_pointer.hpp>
#include "ie_reader.hpp"
namespace InferenceEngine {
namespace details {
/**
* @brief This class defines the name of the fabric for creating an IReader object in DLL
*/
template <>
class SOCreatorTrait<IReader> {
public:
/**
* @brief A name of the fabric for creating IReader object in DLL
*/
static constexpr auto name = "CreateReader";
};
} // namespace details
/**
* @brief A C++ helper to work with objects created by the plugin.
*
* Implements different interfaces.
*/
using IReaderPtr = InferenceEngine::details::SOPointer<IReader>;
} // namespace InferenceEngine

View File

@ -107,7 +107,7 @@ TEST_P(NetReaderTest, ReadNetworkTwiceSeparately) {
#ifdef ENABLE_UNICODE_PATH_SUPPORT
TEST_P(NetReaderTest, DISABLED_ReadCorrectModelWithWeightsUnicodePath) {
TEST_P(NetReaderTest, ReadCorrectModelWithWeightsUnicodePath) {
GTEST_COUT << "params.modelPath: '" << _modelPath << "'" << std::endl;
GTEST_COUT << "params.weightsPath: '" << _weightsPath << "'" << std::endl;
GTEST_COUT << "params.netPrc: '" << _netPrc.name() << "'" << std::endl;