Moved Cache manager to new API (#15872)

* Moved Cache manager to new API

* Moved cache guard to ov namespace

* Added new files
This commit is contained in:
Ilya Churaev 2023-02-22 14:51:33 +04:00 committed by GitHub
parent 548f972e19
commit 877018bab6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 39 additions and 39 deletions

View File

@ -2,11 +2,11 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "ie_cache_guard.hpp"
#include "cache_guard.hpp"
#include "ie_common.h"
namespace InferenceEngine {
namespace ov {
CacheGuardEntry::CacheGuardEntry(CacheGuard& cacheGuard,
const std::string& hash,
@ -23,16 +23,16 @@ CacheGuardEntry::CacheGuardEntry(CacheGuard& cacheGuard,
CacheGuardEntry::~CacheGuardEntry() {
m_refCount--;
m_mutex->unlock();
m_cacheGuard.checkForRemove(m_hash);
m_cacheGuard.check_for_remove(m_hash);
}
void CacheGuardEntry::performLock() {
void CacheGuardEntry::perform_lock() {
m_mutex->lock();
}
//////////////////////////////////////////////////////
std::unique_ptr<CacheGuardEntry> CacheGuard::getHashLock(const std::string& hash) {
std::unique_ptr<CacheGuardEntry> CacheGuard::get_hash_lock(const std::string& hash) {
std::unique_lock<std::mutex> lock(m_tableMutex);
auto& data = m_table[hash];
std::unique_ptr<CacheGuardEntry> res;
@ -47,12 +47,12 @@ std::unique_ptr<CacheGuardEntry> CacheGuard::getHashLock(const std::string& hash
}
throw;
}
lock.unlock(); // can unlock table lock here, as refCounter is positive and nobody can remove entry
res->performLock(); // in case of exception, 'res' will be destroyed and item will be cleaned up from table
lock.unlock(); // can unlock table lock here, as refCounter is positive and nobody can remove entry
res->perform_lock(); // in case of exception, 'res' will be destroyed and item will be cleaned up from table
return res;
}
void CacheGuard::checkForRemove(const std::string& hash) {
void CacheGuard::check_for_remove(const std::string& hash) {
std::lock_guard<std::mutex> lock(m_tableMutex);
if (m_table.count(hash)) {
auto& data = m_table[hash];
@ -63,4 +63,4 @@ void CacheGuard::checkForRemove(const std::string& hash) {
}
}
} // namespace InferenceEngine
} // namespace ov

View File

@ -7,7 +7,7 @@
/**
* @brief This is a header file for the Inference Engine Cache Guard class C++ API
*
* @file ie_cache_guard.hpp
* @file cache_guard.hpp
*/
#include <atomic>
@ -17,7 +17,7 @@
#include <string>
#include <unordered_map>
namespace InferenceEngine {
namespace ov {
class CacheGuard;
/**
@ -58,7 +58,7 @@ public:
*
* @note Will be called only by CacheGuard, it shall not be called from client's code
*/
void performLock();
void perform_lock();
private:
CacheGuard& m_cacheGuard;
@ -92,7 +92,7 @@ public:
*
* @return RAII pointer to CacheGuardEntry
*/
std::unique_ptr<CacheGuardEntry> getHashLock(const std::string& hash);
std::unique_ptr<CacheGuardEntry> get_hash_lock(const std::string& hash);
/**
 * @brief Checks whether there are any clients holding the lock after CacheGuardEntry deletion
@ -103,7 +103,7 @@ public:
*
 * @param hash String uniquely representing the network (hash) whose lock entry may be removed
*/
void checkForRemove(const std::string& hash);
void check_for_remove(const std::string& hash);
private:
struct Item {
@ -121,4 +121,4 @@ private:
std::unordered_map<std::string, Item> m_table;
};
} // namespace InferenceEngine
} // namespace ov

View File

@ -279,7 +279,7 @@ ov::SoPtr<ov::ICompiledModel> ov::CoreImpl::compile_model(const std::shared_ptr<
model,
create_compile_config(plugin, parsed._deviceName, parsed._config));
bool loadedFromCache = false;
auto lock = cacheGuard.getHashLock(cacheContent.blobId);
auto lock = cacheGuard.get_hash_lock(cacheContent.blobId);
res = load_model_from_cache(cacheContent, plugin, parsed._config, {}, loadedFromCache);
if (!loadedFromCache) {
res = compile_model_impl(model, plugin, parsed._config, {}, cacheContent, forceDisableCache);
@ -317,7 +317,7 @@ ov::SoPtr<ov::ICompiledModel> ov::CoreImpl::compile_model(const std::shared_ptr<
model,
create_compile_config(plugin, parsed._deviceName, parsed._config));
bool loadedFromCache = false;
auto lock = cacheGuard.getHashLock(cacheContent.blobId);
auto lock = cacheGuard.get_hash_lock(cacheContent.blobId);
res = load_model_from_cache(cacheContent, plugin, parsed._config, context, loadedFromCache);
if (!loadedFromCache) {
res = compile_model_impl(model, plugin, parsed._config, context, cacheContent);
@ -367,7 +367,7 @@ ov::SoPtr<ov::ICompiledModel> ov::CoreImpl::compile_model(const std::string& mod
cacheContent.blobId = ov::NetworkCompilationContext::compute_hash(
model_path,
create_compile_config(plugin, parsed._deviceName, parsed._config));
auto lock = cacheGuard.getHashLock(cacheContent.blobId);
auto lock = cacheGuard.get_hash_lock(cacheContent.blobId);
res = load_model_from_cache(cacheContent, plugin, parsed._config, {}, loadedFromCache);
if (!loadedFromCache) {
auto cnnNetwork = ReadNetwork(model_path, std::string());
@ -400,7 +400,7 @@ ov::SoPtr<ov::ICompiledModel> ov::CoreImpl::compile_model(const std::string& mod
model_str,
weights,
create_compile_config(plugin, parsed._deviceName, parsed._config));
auto lock = cacheGuard.getHashLock(cacheContent.blobId);
auto lock = cacheGuard.get_hash_lock(cacheContent.blobId);
res = load_model_from_cache(cacheContent, plugin, parsed._config, {}, loadedFromCache);
if (!loadedFromCache) {
auto cnnNetwork = read_model(model_str, weights);
@ -858,14 +858,14 @@ ov::SoPtr<ov::ICompiledModel> ov::CoreImpl::compile_model_impl(const std::shared
try {
// need to export network for further import from "cache"
OV_ITT_SCOPE(FIRST_INFERENCE, InferenceEngine::itt::domains::IE_LT, "Core::compile_model::Export");
cacheContent.cacheManager->writeCacheEntry(cacheContent.blobId, [&](std::ostream& networkStream) {
cacheContent.cacheManager->write_cache_entry(cacheContent.blobId, [&](std::ostream& networkStream) {
networkStream << ov::CompiledBlobHeader(
InferenceEngine::GetInferenceEngineVersion()->buildNumber,
ov::NetworkCompilationContext::calculate_file_info(cacheContent.modelPath));
execNetwork->export_model(networkStream);
});
} catch (...) {
cacheContent.cacheManager->removeCacheEntry(cacheContent.blobId);
cacheContent.cacheManager->remove_cache_entry(cacheContent.blobId);
throw;
}
}
@ -882,7 +882,7 @@ ov::SoPtr<ov::ICompiledModel> ov::CoreImpl::load_model_from_cache(const CacheCon
OPENVINO_ASSERT(cacheContent.cacheManager != nullptr);
try {
cacheContent.cacheManager->readCacheEntry(cacheContent.blobId, [&](std::istream& networkStream) {
cacheContent.cacheManager->read_cache_entry(cacheContent.blobId, [&](std::istream& networkStream) {
OV_ITT_SCOPE(FIRST_INFERENCE,
InferenceEngine::itt::domains::IE_LT,
"Core::LoadNetworkFromCache::ReadStreamAndImport");
@ -909,10 +909,10 @@ ov::SoPtr<ov::ICompiledModel> ov::CoreImpl::load_model_from_cache(const CacheCon
});
} catch (const HeaderException&) {
// For these exceptions just remove old cache and set that import didn't work
cacheContent.cacheManager->removeCacheEntry(cacheContent.blobId);
cacheContent.cacheManager->remove_cache_entry(cacheContent.blobId);
networkIsImported = false;
} catch (...) {
cacheContent.cacheManager->removeCacheEntry(cacheContent.blobId);
cacheContent.cacheManager->remove_cache_entry(cacheContent.blobId);
networkIsImported = false;
// TODO: temporary disabled by #54335. In future don't throw only for new 'blob_outdated' exception
// throw;
@ -1052,7 +1052,7 @@ void ov::CoreImpl::CoreConfig::fill_config(CacheConfig& config, const std::strin
config._cacheDir = dir;
if (!dir.empty()) {
FileUtils::createDirectoryRecursive(dir);
config._cacheManager = std::make_shared<InferenceEngine::FileStorageCacheManager>(dir);
config._cacheManager = std::make_shared<ov::FileStorageCacheManager>(dir);
} else {
config._cacheManager = nullptr;
}

View File

@ -9,9 +9,9 @@
#include <ie_remote_context.hpp>
#include "any_copy.hpp"
#include "cache_guard.hpp"
#include "cpp_interfaces/interface/ie_iplugin_internal.hpp"
#include "dev/plugin.hpp"
#include "ie_cache_guard.hpp"
#include "ie_cache_manager.hpp"
#include "ie_extension.h"
#include "ie_icore.hpp"
@ -92,7 +92,7 @@ private:
public:
struct CacheConfig {
std::string _cacheDir;
std::shared_ptr<InferenceEngine::ICacheManager> _cacheManager;
std::shared_ptr<ov::ICacheManager> _cacheManager;
};
bool flag_allow_auto_batching = true;
@ -120,11 +120,11 @@ private:
};
struct CacheContent {
explicit CacheContent(const std::shared_ptr<InferenceEngine::ICacheManager>& cache_manager,
explicit CacheContent(const std::shared_ptr<ov::ICacheManager>& cache_manager,
const std::string model_path = {})
: cacheManager(cache_manager),
modelPath(model_path) {}
std::shared_ptr<InferenceEngine::ICacheManager> cacheManager;
std::shared_ptr<ov::ICacheManager> cacheManager;
std::string blobId = {};
std::string modelPath = {};
};
@ -134,7 +134,7 @@ private:
Any get_property_for_core(const std::string& name) const;
mutable InferenceEngine::CacheGuard cacheGuard;
mutable ov::CacheGuard cacheGuard;
struct PluginDescriptor {
ov::util::FilePath libraryLocation;

View File

@ -17,7 +17,7 @@
#include "file_utils.h"
#include "ie_api.h"
namespace InferenceEngine {
namespace ov {
/**
* @brief This class represents private interface for Cache Manager
@ -44,7 +44,7 @@ public:
* @param id Id of cache (hash of the network)
* @param writer Lambda function to be called when stream is created
*/
virtual void writeCacheEntry(const std::string& id, StreamWriter writer) = 0;
virtual void write_cache_entry(const std::string& id, StreamWriter writer) = 0;
/**
* @brief Function passing created input stream
@ -60,7 +60,7 @@ public:
* @param id Id of cache (hash of the network)
* @param reader Lambda function to be called when input stream is created
*/
virtual void readCacheEntry(const std::string& id, StreamReader reader) = 0;
virtual void read_cache_entry(const std::string& id, StreamReader reader) = 0;
/**
* @brief Callback when Inference Engine intends to remove cache entry
@ -69,7 +69,7 @@ public:
*
* @param id Id of cache (hash of the network)
*/
virtual void removeCacheEntry(const std::string& id) = 0;
virtual void remove_cache_entry(const std::string& id) = 0;
};
/**
@ -99,12 +99,12 @@ public:
~FileStorageCacheManager() override = default;
private:
void writeCacheEntry(const std::string& id, StreamWriter writer) override {
void write_cache_entry(const std::string& id, StreamWriter writer) override {
std::ofstream stream(getBlobFile(id), std::ios_base::binary | std::ofstream::out);
writer(stream);
}
void readCacheEntry(const std::string& id, StreamReader reader) override {
void read_cache_entry(const std::string& id, StreamReader reader) override {
auto blobFileName = getBlobFile(id);
if (FileUtils::fileExist(blobFileName)) {
std::ifstream stream(blobFileName, std::ios_base::binary);
@ -112,11 +112,11 @@ private:
}
}
void removeCacheEntry(const std::string& id) override {
void remove_cache_entry(const std::string& id) override {
auto blobFileName = getBlobFile(id);
if (FileUtils::fileExist(blobFileName))
std::remove(blobFileName.c_str());
}
};
} // namespace InferenceEngine
} // namespace ov

View File

@ -14,6 +14,7 @@
#include <vector>
#include "any_copy.hpp"
#include "cache_guard.hpp"
#include "check_network_batchable.hpp"
#include "cnn_network_ngraph_impl.hpp"
#include "compilation_context.hpp"
@ -23,7 +24,6 @@
#include "dev/converter_utils.hpp"
#include "dev/core_impl.hpp"
#include "file_utils.h"
#include "ie_cache_guard.hpp"
#include "ie_cache_manager.hpp"
#include "ie_icore.hpp"
#include "ie_itt.hpp"