Caching support of multi-device scenario (#5040)
* Caching support of multi-device scenario - IE_CORE: introduce CacheGuard which can create locks for specific cache identified by 'hash' - Added functional tests for it Fixes of Thread Sanitizer failures: - ngraph::Serialize - m_ref[i] can create new element, casted to 'const' to avoid this - ngraph::get_opset oprations: reworked to use std::call_once instead of double bool check * Added docs for ie_cache_guard.hpp * Fix Debian 9 compilation issue * Fix build for CentOS 6 Added assert to verify that table of locked hashes is empty on destruction * Fixed review comments
This commit is contained in:
64
inference-engine/src/inference_engine/ie_cache_guard.cpp
Normal file
64
inference-engine/src/inference_engine/ie_cache_guard.cpp
Normal file
@@ -0,0 +1,64 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ie_cache_guard.hpp"
|
||||
#include "ie_common.h"
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
CacheGuardEntry::CacheGuardEntry(CacheGuard& cacheGuard, const std::string& hash,
|
||||
std::shared_ptr<std::mutex> m, std::atomic_int& refCount):
|
||||
m_cacheGuard(cacheGuard), m_hash(hash), m_mutex(m), m_refCount(refCount) {
|
||||
// Don't lock mutex right here for exception-safe considerations
|
||||
m_refCount++;
|
||||
}
|
||||
|
||||
CacheGuardEntry::~CacheGuardEntry() {
|
||||
m_refCount--;
|
||||
m_mutex->unlock();
|
||||
m_cacheGuard.checkForRemove(m_hash);
|
||||
}
|
||||
|
||||
void CacheGuardEntry::performLock() {
|
||||
m_mutex->lock();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////
|
||||
|
||||
CacheGuard::~CacheGuard() {
|
||||
IE_ASSERT(m_table.size() == 0);
|
||||
}
|
||||
|
||||
std::unique_ptr<CacheGuardEntry> CacheGuard::getHashLock(const std::string& hash) {
|
||||
std::unique_lock<std::mutex> lock(m_tableMutex);
|
||||
auto& data = m_table[hash];
|
||||
std::unique_ptr<CacheGuardEntry> res;
|
||||
try {
|
||||
// TODO: use std::make_unique when migrated to C++14
|
||||
res = std::unique_ptr<CacheGuardEntry>(
|
||||
new CacheGuardEntry(*this, hash, data.m_mutexPtr, data.m_itemRefCounter));
|
||||
} catch (...) {
|
||||
// In case of exception, we shall remove hash entry if it is not used
|
||||
if (data.m_itemRefCounter == 0) {
|
||||
m_table.erase(hash);
|
||||
}
|
||||
throw;
|
||||
}
|
||||
lock.unlock(); // can unlock table lock here, as refCounter is positive and nobody can remove entry
|
||||
res->performLock(); // in case of exception, 'res' will be destroyed and item will be cleaned up from table
|
||||
return res;
|
||||
}
|
||||
|
||||
void CacheGuard::checkForRemove(const std::string& hash) {
|
||||
std::lock_guard<std::mutex> lock(m_tableMutex);
|
||||
if (m_table.count(hash)) {
|
||||
auto &data = m_table[hash];
|
||||
if (data.m_itemRefCounter == 0) {
|
||||
// Nobody is using this and nobody is waiting for it - can be removed
|
||||
m_table.erase(hash);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace InferenceEngine
|
||||
122
inference-engine/src/inference_engine/ie_cache_guard.hpp
Normal file
122
inference-engine/src/inference_engine/ie_cache_guard.hpp
Normal file
@@ -0,0 +1,122 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* @brief This is a header file for the Inference Engine Cache Guard class C++ API
|
||||
*
|
||||
* @file ie_cache_guard.hpp
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <memory>
|
||||
#include <atomic>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
class CacheGuard;
|
||||
/**
|
||||
* @brief This class represents RAII guard class to protect multiple threads to modify the same cached network
|
||||
* Use CacheGuard::getHashLock(hash) to acquire lock for specific cache entry identified by its 'hash'
|
||||
* On destruction, lock will be released
|
||||
* @see CacheGuard
|
||||
*/
|
||||
class CacheGuardEntry {
|
||||
public:
|
||||
/**
|
||||
* @brief Internal constructor, will be called by @CacheGuard
|
||||
*
|
||||
* @param cacheGuard Reference link to parent's Cache Guard
|
||||
* @param hash String representing hash of network
|
||||
* @param m Shared pointer to mutex for internal locking
|
||||
* @param refCount Reference counter. Will be decremented on CacheGuardEntry destruction
|
||||
*/
|
||||
CacheGuardEntry(CacheGuard& cacheGuard, const std::string& hash,
|
||||
std::shared_ptr<std::mutex> m, std::atomic_int& refCount);
|
||||
CacheGuardEntry(const CacheGuardEntry&) = delete;
|
||||
|
||||
/**
|
||||
* @brief Destructor, will perform the following cleanup
|
||||
*
|
||||
* Decrement reference counter
|
||||
* Unlock associated mutex
|
||||
* Call CacheGuard::checkForRemove to check if appropriate table hash entry is not used anymore and can be deleted
|
||||
*/
|
||||
~CacheGuardEntry();
|
||||
|
||||
/**
|
||||
* @brief Performs real lock of associated mutex
|
||||
* It is separated from construction due to exception safety considerations
|
||||
*
|
||||
* @note Will be called only by CacheGuard, it shall not be called from client's code
|
||||
*/
|
||||
void performLock();
|
||||
|
||||
private:
|
||||
CacheGuard& m_cacheGuard;
|
||||
std::string m_hash;
|
||||
std::shared_ptr<std::mutex> m_mutex;
|
||||
std::atomic_int& m_refCount;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief This class holds a table of currently locked hashes
|
||||
* Inference engine core will need to obtain a lock for a specific cache to get exclusive access to it
|
||||
* It is needed to avoid race situations when multiple threads try to to write to the same cache simultaneously
|
||||
*
|
||||
* Usage example:
|
||||
* auto hash = <calculate hash for network>;
|
||||
* {
|
||||
* auto lock = m_cacheGuard.getHashLock(hash);
|
||||
* <work with cache entry exclusively>
|
||||
* }
|
||||
*/
|
||||
class CacheGuard {
|
||||
public:
|
||||
CacheGuard() = default;
|
||||
~CacheGuard();
|
||||
|
||||
/**
|
||||
* @brief Gets a lock for a specific cache entry identified by it's hash value
|
||||
* Once returned, client has an exclusive access to cache entry for read/write/delete
|
||||
* If any other thread holds a lock to same hash - this function will not return until it is unlocked
|
||||
*
|
||||
* @param hash String representing hash of network
|
||||
*
|
||||
* @return RAII pointer to CacheGuardEntry
|
||||
*/
|
||||
std::unique_ptr<CacheGuardEntry> getHashLock(const std::string& hash);
|
||||
|
||||
/**
|
||||
* @brief Checks whether there is any clients holding the lock after CacheGuardEntry deletion
|
||||
* It will be called on destruction of CacheGuardEntry and shall not be used directly by client's code
|
||||
* If there is no more clients holding the lock, associated entry will be removed from table unlocked
|
||||
*
|
||||
* @param hash String representing hash of network
|
||||
*
|
||||
* @return RAII pointer to CacheGuardEntry
|
||||
*/
|
||||
void checkForRemove(const std::string& hash);
|
||||
|
||||
private:
|
||||
struct Item {
|
||||
std::shared_ptr<std::mutex> m_mutexPtr { std::make_shared<std::mutex>() };
|
||||
// Reference counter for item usage
|
||||
std::atomic_int m_itemRefCounter {0};
|
||||
|
||||
Item() = default;
|
||||
Item(const Item& other): m_mutexPtr(other.m_mutexPtr),
|
||||
m_itemRefCounter(other.m_itemRefCounter.load()) {}
|
||||
Item(Item&& other): m_mutexPtr(std::move(other.m_mutexPtr)),
|
||||
m_itemRefCounter(other.m_itemRefCounter.load()) {}
|
||||
};
|
||||
std::mutex m_tableMutex;
|
||||
std::unordered_map<std::string, Item> m_table;
|
||||
};
|
||||
|
||||
} // namespace InferenceEngine
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "ie_plugin_cpp.hpp"
|
||||
#include "ie_plugin_config.hpp"
|
||||
#include "ie_cache_manager.hpp"
|
||||
#include "ie_cache_guard.hpp"
|
||||
#include "ie_itt.hpp"
|
||||
#include "file_utils.h"
|
||||
#include "ie_network_reader.hpp"
|
||||
@@ -197,6 +198,8 @@ class Core::Impl : public ICore {
|
||||
// Core settings (cache config, etc)
|
||||
CoreConfig coreConfig;
|
||||
|
||||
CacheGuard cacheGuard;
|
||||
|
||||
struct PluginDescriptor {
|
||||
FileUtils::FilePath libraryLocation;
|
||||
std::map<std::string, std::string> defaultConfig;
|
||||
@@ -447,17 +450,18 @@ public:
|
||||
}
|
||||
auto parsed = parseDeviceNameIntoConfig(context->getDeviceName(), config);
|
||||
auto plugin = GetCPPPluginByName(parsed._deviceName);
|
||||
bool loadedFromCache = false;
|
||||
ExecutableNetwork res;
|
||||
std::string hash;
|
||||
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
|
||||
if (cacheManager && DeviceSupportsImportExport(plugin)) {
|
||||
hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
|
||||
auto hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
|
||||
bool loadedFromCache = false;
|
||||
auto lock = cacheGuard.getHashLock(hash);
|
||||
res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, context, loadedFromCache);
|
||||
}
|
||||
|
||||
if (!loadedFromCache) {
|
||||
res = LoadNetworkImpl(network, plugin, parsed._config, context, hash);
|
||||
if (!loadedFromCache) {
|
||||
res = LoadNetworkImpl(network, plugin, parsed._config, context, hash);
|
||||
}
|
||||
} else {
|
||||
res = LoadNetworkImpl(network, plugin, parsed._config, context, {});
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@@ -472,17 +476,18 @@ public:
|
||||
parsed._config.erase(CONFIG_KEY_INTERNAL(FORCE_DISABLE_CACHE));
|
||||
}
|
||||
auto plugin = GetCPPPluginByName(parsed._deviceName);
|
||||
bool loadedFromCache = false;
|
||||
ExecutableNetwork res;
|
||||
std::string hash;
|
||||
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
|
||||
if (!forceDisableCache && cacheManager && DeviceSupportsImportExport(plugin)) {
|
||||
hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
|
||||
auto hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
|
||||
bool loadedFromCache = false;
|
||||
auto lock = cacheGuard.getHashLock(hash);
|
||||
res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, nullptr, loadedFromCache);
|
||||
}
|
||||
|
||||
if (!loadedFromCache) {
|
||||
res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, hash, {}, forceDisableCache);
|
||||
if (!loadedFromCache) {
|
||||
res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, hash, {}, forceDisableCache);
|
||||
}
|
||||
} else {
|
||||
res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, {}, {}, forceDisableCache);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@@ -493,19 +498,21 @@ public:
|
||||
OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::LoadNetwork::Path");
|
||||
auto parsed = parseDeviceNameIntoConfig(deviceName, config);
|
||||
auto plugin = GetCPPPluginByName(parsed._deviceName);
|
||||
bool loadedFromCache = false;
|
||||
ExecutableNetwork res;
|
||||
std::string hash;
|
||||
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
|
||||
if (cacheManager && DeviceSupportsImportExport(plugin)) {
|
||||
hash = CalculateFileHash(modelPath, parsed._deviceName, plugin, parsed._config);
|
||||
bool loadedFromCache = false;
|
||||
auto hash = CalculateFileHash(modelPath, parsed._deviceName, plugin, parsed._config);
|
||||
auto lock = cacheGuard.getHashLock(hash);
|
||||
res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config,
|
||||
nullptr, loadedFromCache, modelPath);
|
||||
}
|
||||
|
||||
if (!loadedFromCache) {
|
||||
if (!loadedFromCache) {
|
||||
auto cnnNetwork = ReadNetwork(modelPath, std::string());
|
||||
res = LoadNetworkImpl(cnnNetwork, plugin, parsed._config, nullptr, hash, modelPath);
|
||||
}
|
||||
} else {
|
||||
auto cnnNetwork = ReadNetwork(modelPath, std::string());
|
||||
res = LoadNetworkImpl(cnnNetwork, plugin, parsed._config, nullptr, hash, modelPath);
|
||||
res = LoadNetworkImpl(cnnNetwork, plugin, parsed._config, nullptr, {}, modelPath);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
#include "functional_test_utils/network_utils.hpp"
|
||||
|
||||
#include "unit_test_utils/mocks/mock_iexecutable_network.hpp"
|
||||
#include "unit_test_utils/mocks/mock_iinfer_request.hpp"
|
||||
|
||||
using namespace InferenceEngine;
|
||||
using namespace ::testing;
|
||||
@@ -97,6 +98,8 @@ public:
|
||||
MOCK_METHOD0(CreateInferRequest, IInferRequest::Ptr());
|
||||
MOCK_CONST_METHOD0(GetInputsInfo, ConstInputsDataMap());
|
||||
MOCK_CONST_METHOD0(GetOutputsInfo, ConstOutputsDataMap());
|
||||
MOCK_CONST_METHOD1(GetConfig, Parameter(const std::string& name));
|
||||
MOCK_CONST_METHOD1(GetMetric, Parameter(const std::string& name));
|
||||
};
|
||||
|
||||
//------------------------------------------------------
|
||||
@@ -240,6 +243,31 @@ public:
|
||||
ie.LoadNetwork(cnnNetwork, context, config);
|
||||
}
|
||||
|
||||
ExecutableNetwork createMockIExecutableNet() {
|
||||
auto mock = std::make_shared<MockIExecutableNetwork>();
|
||||
EXPECT_CALL(*mock, GetInputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
EXPECT_CALL(*mock, GetOutputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
EXPECT_CALL(*mock, GetConfig(PluginConfigParams::KEY_PERF_COUNT, _, _)).Times(AnyNumber())
|
||||
.WillRepeatedly(Invoke([&](const std::string &name, Parameter &result, ResponseDesc *resp) {
|
||||
result = PluginConfigParams::NO;
|
||||
return OK;
|
||||
}));
|
||||
EXPECT_CALL(*mock, GetMetric(METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS), _, _)).Times(AnyNumber())
|
||||
.WillRepeatedly(Invoke([&](const std::string &name, Parameter &result, ResponseDesc *resp) {
|
||||
result = (unsigned int) 1;
|
||||
return OK;
|
||||
}));
|
||||
EXPECT_CALL(*mock, CreateInferRequest(_, _)).Times(AnyNumber())
|
||||
.WillRepeatedly(Invoke([&](IInferRequest::Ptr &req, ResponseDesc*) {
|
||||
auto ptr = std::make_shared<MockIInferRequest>();
|
||||
EXPECT_CALL(*ptr, SetCompletionCallback(_)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
EXPECT_CALL(*ptr, SetUserData(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
req = ptr;
|
||||
return OK;
|
||||
}));
|
||||
return ExecutableNetwork(mock);
|
||||
}
|
||||
|
||||
private:
|
||||
template <class T>
|
||||
std::function<T> make_std_function(const std::string& functionName) {
|
||||
@@ -271,18 +299,12 @@ private:
|
||||
ON_CALL(plugin, ImportNetworkImpl(_, _, _)).
|
||||
WillByDefault(Invoke([&](std::istream &istr, RemoteContext::Ptr,
|
||||
const std::map<std::string, std::string> &) {
|
||||
auto mock = std::make_shared<MockIExecutableNetwork>();
|
||||
EXPECT_CALL(*mock, GetInputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
EXPECT_CALL(*mock, GetOutputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
return ExecutableNetwork(mock);
|
||||
return createMockIExecutableNet();
|
||||
}));
|
||||
|
||||
ON_CALL(plugin, ImportNetworkImpl(_, _)).
|
||||
WillByDefault(Invoke([&](std::istream &istr, const std::map<std::string, std::string> &) {
|
||||
auto mock = std::make_shared<MockIExecutableNetwork>();
|
||||
EXPECT_CALL(*mock, GetInputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
EXPECT_CALL(*mock, GetOutputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
return ExecutableNetwork(mock);
|
||||
return createMockIExecutableNet();
|
||||
}));
|
||||
|
||||
ON_CALL(plugin, LoadExeNetworkImpl(_, _, _)).
|
||||
@@ -318,6 +340,27 @@ private:
|
||||
.WillRepeatedly(Return(ConstInputsDataMap{}));
|
||||
EXPECT_CALL(*net, GetOutputsInfo()).Times(AnyNumber())
|
||||
.WillRepeatedly(Return(ConstOutputsDataMap{}));
|
||||
EXPECT_CALL(*net, GetConfig(PluginConfigParams::KEY_PERF_COUNT)).Times(AnyNumber())
|
||||
.WillRepeatedly(Return(PluginConfigParams::NO));
|
||||
EXPECT_CALL(*net, GetMetric(METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS))).Times(AnyNumber())
|
||||
.WillRepeatedly(Return((unsigned int) 1));
|
||||
EXPECT_CALL(*net, GetMetric(METRIC_KEY(NETWORK_NAME))).Times(AnyNumber())
|
||||
.WillRepeatedly(Return("mock_net"));
|
||||
EXPECT_CALL(*net, GetMetric(METRIC_KEY(SUPPORTED_METRICS))).Times(AnyNumber())
|
||||
.WillRepeatedly(Invoke([&](const std::string &) {
|
||||
std::vector<std::string> res;
|
||||
res.push_back(METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS));
|
||||
res.push_back(METRIC_KEY(NETWORK_NAME));
|
||||
return res;
|
||||
}));
|
||||
EXPECT_CALL(*net, CreateInferRequest()).Times(AnyNumber())
|
||||
.WillRepeatedly(Invoke([&]() {
|
||||
std::vector<std::string> res;
|
||||
auto inferReq = std::make_shared<MockIInferRequest>();
|
||||
EXPECT_CALL(*inferReq, SetCompletionCallback(_)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
EXPECT_CALL(*inferReq, SetUserData(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
return inferReq;
|
||||
}));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1172,6 +1215,130 @@ TEST_P(CachingTest, LoadHetero_MultiArchs_TargetFallback_FromCore) {
|
||||
}
|
||||
}
|
||||
|
||||
// MULTI-DEVICE test
|
||||
// Test that it is safe to load multiple devices sharing same cache
|
||||
TEST_P(CachingTest, LoadMulti_race) {
|
||||
const auto TEST_DURATION_MS = 2000;
|
||||
const auto TEST_DEVICE_MAX_COUNT = 10;
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||
if (m_remoteContext) {
|
||||
return; // skip the remote Context test for Multi plugin
|
||||
}
|
||||
int index = 0;
|
||||
auto start = high_resolution_clock::now();
|
||||
do {
|
||||
std::string cacheDir = m_cacheDir + std::to_string(index);
|
||||
MkDirGuard guard(cacheDir);
|
||||
int devCount = 1 + index % (TEST_DEVICE_MAX_COUNT - 1); // try dynamic number of devices from 1 to max
|
||||
deviceToLoad = CommonTestUtils::DEVICE_MULTI;
|
||||
deviceToLoad += ":mock.0";
|
||||
for (int i = 1; i < devCount; i++) {
|
||||
deviceToLoad += ",mock." + std::to_string(i);
|
||||
}
|
||||
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(1);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(devCount - 1);
|
||||
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), cacheDir}});
|
||||
ASSERT_NO_THROW(m_testFunction(ie));
|
||||
});
|
||||
index++;
|
||||
} while (duration_cast<milliseconds>(high_resolution_clock::now() - start).count() < TEST_DURATION_MS);
|
||||
std::cout << "Caching LoadMulti Test completed. Tried " << index << " times" << std::endl;
|
||||
}
|
||||
|
||||
TEST_P(CachingTest, Load_threads) {
|
||||
const auto TEST_DURATION_MS = 2000;
|
||||
const auto THREADS_COUNT = 4;
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||
if (m_remoteContext) {
|
||||
return; // skip the remote Context test for Multi plugin
|
||||
}
|
||||
auto start = high_resolution_clock::now();
|
||||
int index = 0;
|
||||
do {
|
||||
std::string cacheDir = m_cacheDir + std::to_string(index);
|
||||
MkDirGuard guard(cacheDir);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(1);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(THREADS_COUNT - 1);
|
||||
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), cacheDir}});
|
||||
std::vector<std::thread> threads;
|
||||
for (int i = 0; i < THREADS_COUNT; i++) {
|
||||
threads.emplace_back(([&]() { m_testFunction(ie); }));
|
||||
}
|
||||
for (int i = 0; i < THREADS_COUNT; i++) {
|
||||
threads[i].join();
|
||||
}
|
||||
});
|
||||
index++;
|
||||
} while (duration_cast<milliseconds>(high_resolution_clock::now() - start).count() < TEST_DURATION_MS);
|
||||
std::cout << "Caching Load multiple threads test completed. Tried " << index << " times" << std::endl;
|
||||
}
|
||||
|
||||
// MULTI-DEVICE test
|
||||
// Test that loading of device with one architecture doesn't block loading of device with another architecture
|
||||
TEST_P(CachingTest, LoadMulti_Archs) {
|
||||
const auto IMPORT_DELAY_LONG_MS = 3000;
|
||||
const auto TEST_DEVICE_MAX_COUNT = 30; // Shall be >= 2
|
||||
const auto IMPORT_DELAY_SHORT_MS = 100;
|
||||
const auto EXP_MAX_EXEC_TIME_MS = 5500;
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber())
|
||||
.WillRepeatedly(Invoke([&](const std::string &, const std::map<std::string, Parameter> &options) {
|
||||
auto id = options.at("DEVICE_ID").as<std::string>();
|
||||
if (std::stoi(id) < 2) {
|
||||
return "mock_first_architecture";
|
||||
} else {
|
||||
return "mock_another_architecture";
|
||||
}
|
||||
}));
|
||||
if (m_remoteContext) {
|
||||
return; // skip the remote Context test for Multi plugin
|
||||
}
|
||||
|
||||
deviceToLoad = CommonTestUtils::DEVICE_MULTI;
|
||||
deviceToLoad += ":mock.0";
|
||||
for (int i = 1; i < TEST_DEVICE_MAX_COUNT; i++) {
|
||||
deviceToLoad += ",mock." + std::to_string(i);
|
||||
}
|
||||
|
||||
auto start = high_resolution_clock::now();
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(2);
|
||||
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(TEST_DEVICE_MAX_COUNT - 2)
|
||||
.WillRepeatedly(Invoke([&](std::istream &, const std::map<std::string, std::string> &opt) {
|
||||
auto id = opt.at("DEVICE_ID");
|
||||
if (std::stoi(id) < 2) {
|
||||
std::this_thread::sleep_for(milliseconds(IMPORT_DELAY_LONG_MS));
|
||||
} else {
|
||||
std::this_thread::sleep_for(milliseconds(IMPORT_DELAY_SHORT_MS));
|
||||
}
|
||||
return createMockIExecutableNet();
|
||||
}));
|
||||
EXPECT_CALL(*net, ExportImpl(_)).Times(2);
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
ASSERT_NO_THROW(m_testFunction(ie));
|
||||
});
|
||||
}
|
||||
ASSERT_LT(duration_cast<milliseconds>(high_resolution_clock::now() - start).count(), EXP_MAX_EXEC_TIME_MS);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(CachingTest, CachingTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(loadVariants),
|
||||
|
||||
@@ -349,6 +349,33 @@ TEST(NetworkContext_CNNNetwork, HashWithDifferentMeanValues) {
|
||||
NetworkCompilationContext::computeHash(net3, {}));
|
||||
}
|
||||
|
||||
// Verify all internal hash calculations are thread-safe (like ngraph::function serialization)
|
||||
TEST(NetworkContext_CNNNetwork, HashOfSameMultiThreading) {
|
||||
auto net1 = createNetwork();
|
||||
auto net2 = createNetwork();
|
||||
std::atomic_bool fail{false};
|
||||
const auto TEST_DURATION_MS = 1000;
|
||||
auto start = high_resolution_clock::now();
|
||||
int t1Count = 0, t2Count = 0;
|
||||
auto threadFun = [&](int& count) {
|
||||
do {
|
||||
count++;
|
||||
auto hash1 = NetworkCompilationContext::computeHash(net1, {});
|
||||
auto hash2 = NetworkCompilationContext::computeHash(net2, {});
|
||||
if (hash1 != hash2) {
|
||||
fail = true;
|
||||
break;
|
||||
}
|
||||
} while (!fail && duration_cast<milliseconds>(high_resolution_clock::now() - start).count() < TEST_DURATION_MS);
|
||||
};
|
||||
std::thread t1(threadFun, std::ref(t1Count));
|
||||
std::thread t2(threadFun, std::ref(t2Count));
|
||||
t1.join();
|
||||
t2.join();
|
||||
std::cout << "Hash threading test finished. Total runs = " << t1Count + t2Count << std::endl;
|
||||
ASSERT_FALSE(fail);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////
|
||||
|
||||
TEST(NetworkContext_ModelName, HashOfSame) {
|
||||
|
||||
@@ -34,133 +34,84 @@ ngraph::Node* ngraph::OpSet::create_insensitive(const std::string& name) const
|
||||
|
||||
const ngraph::OpSet& ngraph::get_opset1()
|
||||
{
|
||||
static std::mutex init_mutex;
|
||||
static bool opset_is_initialized = false;
|
||||
static OpSet opset;
|
||||
if (!opset_is_initialized)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(init_mutex);
|
||||
if (!opset_is_initialized)
|
||||
{
|
||||
static std::once_flag flag;
|
||||
std::call_once(flag, [&]() {
|
||||
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
|
||||
#include "ngraph/opsets/opset1_tbl.hpp"
|
||||
#undef NGRAPH_OP
|
||||
opset_is_initialized = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
return opset;
|
||||
}
|
||||
|
||||
const ngraph::OpSet& ngraph::get_opset2()
|
||||
{
|
||||
static std::mutex init_mutex;
|
||||
static bool opset_is_initialized = false;
|
||||
static OpSet opset;
|
||||
if (!opset_is_initialized)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(init_mutex);
|
||||
if (!opset_is_initialized)
|
||||
{
|
||||
static std::once_flag flag;
|
||||
std::call_once(flag, [&]() {
|
||||
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
|
||||
#include "ngraph/opsets/opset2_tbl.hpp"
|
||||
#undef NGRAPH_OP
|
||||
opset_is_initialized = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
return opset;
|
||||
}
|
||||
|
||||
const ngraph::OpSet& ngraph::get_opset3()
|
||||
{
|
||||
static std::mutex init_mutex;
|
||||
static bool opset_is_initialized = false;
|
||||
static OpSet opset;
|
||||
if (!opset_is_initialized)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(init_mutex);
|
||||
if (!opset_is_initialized)
|
||||
{
|
||||
static std::once_flag flag;
|
||||
std::call_once(flag, [&]() {
|
||||
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
|
||||
#include "ngraph/opsets/opset3_tbl.hpp"
|
||||
#undef NGRAPH_OP
|
||||
opset_is_initialized = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
return opset;
|
||||
}
|
||||
|
||||
const ngraph::OpSet& ngraph::get_opset4()
|
||||
{
|
||||
static std::mutex init_mutex;
|
||||
static bool opset_is_initialized = false;
|
||||
static OpSet opset;
|
||||
if (!opset_is_initialized)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(init_mutex);
|
||||
if (!opset_is_initialized)
|
||||
{
|
||||
static std::once_flag flag;
|
||||
std::call_once(flag, [&]() {
|
||||
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
|
||||
#include "ngraph/opsets/opset4_tbl.hpp"
|
||||
#undef NGRAPH_OP
|
||||
opset_is_initialized = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
return opset;
|
||||
}
|
||||
|
||||
const ngraph::OpSet& ngraph::get_opset5()
|
||||
{
|
||||
static std::mutex init_mutex;
|
||||
static bool opset_is_initialized = false;
|
||||
static OpSet opset;
|
||||
if (!opset_is_initialized)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(init_mutex);
|
||||
if (!opset_is_initialized)
|
||||
{
|
||||
static std::once_flag flag;
|
||||
std::call_once(flag, [&]() {
|
||||
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
|
||||
#include "ngraph/opsets/opset5_tbl.hpp"
|
||||
#undef NGRAPH_OP
|
||||
opset_is_initialized = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
return opset;
|
||||
}
|
||||
|
||||
const ngraph::OpSet& ngraph::get_opset6()
|
||||
{
|
||||
static std::mutex init_mutex;
|
||||
static bool opset_is_initialized = false;
|
||||
static OpSet opset;
|
||||
if (!opset_is_initialized)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(init_mutex);
|
||||
if (!opset_is_initialized)
|
||||
{
|
||||
static std::once_flag flag;
|
||||
std::call_once(flag, [&]() {
|
||||
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
|
||||
#include "ngraph/opsets/opset6_tbl.hpp"
|
||||
#undef NGRAPH_OP
|
||||
opset_is_initialized = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
return opset;
|
||||
}
|
||||
|
||||
const ngraph::OpSet& ngraph::get_opset7()
|
||||
{
|
||||
static std::mutex init_mutex;
|
||||
static bool opset_is_initialized = false;
|
||||
static OpSet opset;
|
||||
if (!opset_is_initialized)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(init_mutex);
|
||||
if (!opset_is_initialized)
|
||||
{
|
||||
static std::once_flag flag;
|
||||
std::call_once(flag, [&]() {
|
||||
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
|
||||
#include "ngraph/opsets/opset7_tbl.hpp"
|
||||
#undef NGRAPH_OP
|
||||
opset_is_initialized = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
return opset;
|
||||
}
|
||||
|
||||
@@ -484,7 +484,7 @@ const std::vector<int64_t>& ngraph::AttributeAdapter<ngraph::PartialShape>::get(
|
||||
{
|
||||
for (size_t i = 0; i < m_ref.rank().get_length(); ++i)
|
||||
{
|
||||
auto& elt = m_ref[i];
|
||||
const auto& elt = static_cast<const PartialShape&>(m_ref)[i];
|
||||
m_buffer.push_back(elt.is_dynamic() ? -1 : elt.get_length());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest.h"
|
||||
@@ -13,6 +14,7 @@
|
||||
#include "ngraph/graph_util.hpp"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/variant.hpp"
|
||||
#include "ngraph/opsets/opset.hpp"
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
@@ -50,6 +52,35 @@ TEST(op, provenance_tag)
|
||||
ASSERT_TRUE(tags.find(tag2) != tags.end());
|
||||
}
|
||||
|
||||
TEST(op, opset_multi_thread) {
|
||||
auto doTest = [&](std::function<const ngraph::OpSet&()> fun) {
|
||||
std::atomic<const ngraph::OpSet*> opset {nullptr};
|
||||
std::atomic_bool failed {false};
|
||||
auto threadFun = [&] () {
|
||||
const ngraph::OpSet* op = &fun();
|
||||
const ngraph::OpSet* current = opset;
|
||||
do {
|
||||
if (current != nullptr && current != op) {
|
||||
failed = true;
|
||||
break;
|
||||
}
|
||||
} while (opset.compare_exchange_strong(op, current));
|
||||
};
|
||||
std::thread t1 {threadFun};
|
||||
std::thread t2 {threadFun};
|
||||
t1.join();
|
||||
t2.join();
|
||||
ASSERT_FALSE(failed);
|
||||
};
|
||||
doTest(ngraph::get_opset1);
|
||||
doTest(ngraph::get_opset2);
|
||||
doTest(ngraph::get_opset3);
|
||||
doTest(ngraph::get_opset4);
|
||||
doTest(ngraph::get_opset5);
|
||||
doTest(ngraph::get_opset6);
|
||||
doTest(ngraph::get_opset7);
|
||||
}
|
||||
|
||||
struct Ship
|
||||
{
|
||||
std::string name;
|
||||
|
||||
Reference in New Issue
Block a user