Caching support of multi-device scenario (#5040)

* Caching support of multi-device scenario

- IE_CORE: introduce CacheGuard which can create locks for specific cache identified by 'hash'
- Added functional tests for it

Fixes of Thread Sanitizer failures:
- ngraph::Serialize - m_ref[i] can create new element, casted to 'const' to avoid this
- ngraph::get_opset operations: reworked to use std::call_once instead of double bool check

* Added docs for ie_cache_guard.hpp

* Fix Debian 9 compilation issue

* Fix build for CentOS 6

Added assert to verify that table of locked hashes is empty on destruction

* Fixed review comments
This commit is contained in:
Mikhail Nosov
2021-04-01 14:42:48 +03:00
committed by GitHub
parent ce5aa7dc1b
commit 4021cb7519
8 changed files with 469 additions and 100 deletions

View File

@@ -0,0 +1,64 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ie_cache_guard.hpp"
#include "ie_common.h"
namespace InferenceEngine {
// Registers this entry as a user of the per-hash lock.
// Only the reference counter is incremented here; the actual mutex lock is
// taken later via performLock(), so that if construction of the owning
// unique_ptr throws, no mutex is left locked without an owner to unlock it.
CacheGuardEntry::CacheGuardEntry(CacheGuard& cacheGuard, const std::string& hash,
std::shared_ptr<std::mutex> m, std::atomic_int& refCount):
m_cacheGuard(cacheGuard), m_hash(hash), m_mutex(m), m_refCount(refCount) {
// Don't lock mutex right here for exception-safe considerations
m_refCount++;
}
// Releases the per-hash lock:
//  1. drop this user's reference,
//  2. unlock the shared mutex so that a waiting thread can proceed,
//  3. ask the parent CacheGuard to erase the table entry if nobody uses it.
CacheGuardEntry::~CacheGuardEntry() {
m_refCount--;
m_mutex->unlock();
m_cacheGuard.checkForRemove(m_hash);
}
// Blocks until exclusive ownership of the shared per-hash mutex is obtained.
// Called by CacheGuard::getHashLock() only, after the entry is fully constructed.
void CacheGuardEntry::performLock() {
m_mutex->lock();
}
//////////////////////////////////////////////////////
// Destructor; verifies that no lock is still held at teardown.
// Every CacheGuardEntry removes its table entry (via checkForRemove) when the
// last user releases it, so a non-empty table here indicates a leaked lock.
CacheGuard::~CacheGuard() {
    // empty() expresses the intent directly (vs. comparing size() with 0)
    IE_ASSERT(m_table.empty());
}
// Returns an RAII entry holding an exclusive lock for 'hash'.
// Blocks if another thread currently holds the lock for the same hash.
// Note the ordering: the entry (and its refcount increment) is created while
// the table mutex is held, then the table mutex is released BEFORE the
// per-hash mutex is taken, so waiting on one hash never blocks other hashes.
std::unique_ptr<CacheGuardEntry> CacheGuard::getHashLock(const std::string& hash) {
std::unique_lock<std::mutex> lock(m_tableMutex);
// operator[] creates a fresh Item (new mutex, refcount 0) on first use of this hash
auto& data = m_table[hash];
std::unique_ptr<CacheGuardEntry> res;
try {
// TODO: use std::make_unique when migrated to C++14
res = std::unique_ptr<CacheGuardEntry>(
new CacheGuardEntry(*this, hash, data.m_mutexPtr, data.m_itemRefCounter));
} catch (...) {
// In case of exception, we shall remove hash entry if it is not used
if (data.m_itemRefCounter == 0) {
m_table.erase(hash);
}
throw;
}
lock.unlock(); // can unlock table lock here, as refCounter is positive and nobody can remove entry
res->performLock(); // in case of exception, 'res' will be destroyed and item will be cleaned up from table
return res;
}
// Called on CacheGuardEntry destruction (not by client code).
// Erases the table entry for 'hash' when no client holds or waits for its lock.
void CacheGuard::checkForRemove(const std::string& hash) {
    std::lock_guard<std::mutex> lock(m_tableMutex);
    // Single find() instead of count()+operator[]: avoids a second hash lookup
    // and, unlike operator[], can never insert a new element by accident.
    auto it = m_table.find(hash);
    if (it != m_table.end() && it->second.m_itemRefCounter == 0) {
        // Nobody is using this and nobody is waiting for it - can be removed
        m_table.erase(it);
    }
}
} // namespace InferenceEngine

View File

@@ -0,0 +1,122 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
/**
* @brief This is a header file for the Inference Engine Cache Guard class C++ API
*
* @file ie_cache_guard.hpp
*/
#include <string>
#include <map>
#include <mutex>
#include <memory>
#include <atomic>
#include <unordered_map>
namespace InferenceEngine {
class CacheGuard;
/**
 * @brief RAII guard protecting a single cached network from concurrent modification
 * by multiple threads.
 * Use CacheGuard::getHashLock(hash) to acquire a lock for the specific cache entry
 * identified by its 'hash'; the lock is released when the returned entry is destroyed.
 * @see CacheGuard
 */
class CacheGuardEntry {
public:
/**
 * @brief Internal constructor, called only by CacheGuard
 *
 * @param cacheGuard Reference to the parent CacheGuard
 * @param hash String representing hash of network
 * @param m Shared pointer to the mutex used for the actual locking
 * @param refCount Reference counter; incremented here, decremented on destruction
 */
CacheGuardEntry(CacheGuard& cacheGuard, const std::string& hash,
std::shared_ptr<std::mutex> m, std::atomic_int& refCount);
CacheGuardEntry(const CacheGuardEntry&) = delete;  // non-copyable: represents a held lock
/**
 * @brief Destructor; performs the following cleanup:
 *
 * Decrements the reference counter
 * Unlocks the associated mutex
 * Calls CacheGuard::checkForRemove to erase the hash table entry if it is no longer used
 */
~CacheGuardEntry();
/**
 * @brief Performs the real lock of the associated mutex
 * It is separated from construction for exception-safety reasons
 *
 * @note Will be called only by CacheGuard, it shall not be called from client's code
 */
void performLock();
private:
CacheGuard& m_cacheGuard;
std::string m_hash;
std::shared_ptr<std::mutex> m_mutex;
std::atomic_int& m_refCount;
};
/**
 * @brief Holds a table of currently locked hashes
 * Inference Engine core obtains a lock for a specific cache entry to get exclusive access to it
 * This avoids races when multiple threads try to write to the same cached network simultaneously
 *
 * Usage example:
 * auto hash = <calculate hash for network>;
 * {
 *     auto lock = m_cacheGuard.getHashLock(hash);
 *     <work with cache entry exclusively>
 * }
 */
class CacheGuard {
public:
CacheGuard() = default;
~CacheGuard();
/**
 * @brief Gets a lock for a specific cache entry identified by its hash value
 * Once returned, the client has exclusive access to the cache entry for read/write/delete
 * If any other thread holds a lock for the same hash, this function blocks until it is unlocked
 *
 * @param hash String representing hash of network
 *
 * @return RAII pointer to CacheGuardEntry; the lock is released on its destruction
 */
std::unique_ptr<CacheGuardEntry> getHashLock(const std::string& hash);
/**
 * @brief Checks whether any client still holds the lock after a CacheGuardEntry deletion
 * It is called on destruction of CacheGuardEntry and shall not be used directly by client's code
 * If no more clients hold or wait for the lock, the associated entry is removed from the table
 *
 * @param hash String representing hash of network
 */
void checkForRemove(const std::string& hash);
private:
// Per-hash bookkeeping: the shared mutex plus a counter of current users/waiters
struct Item {
std::shared_ptr<std::mutex> m_mutexPtr { std::make_shared<std::mutex>() };
// Reference counter for item usage
std::atomic_int m_itemRefCounter {0};
Item() = default;
// std::atomic is neither copyable nor movable, so transfer its value via load()
Item(const Item& other): m_mutexPtr(other.m_mutexPtr),
m_itemRefCounter(other.m_itemRefCounter.load()) {}
Item(Item&& other): m_mutexPtr(std::move(other.m_mutexPtr)),
m_itemRefCounter(other.m_itemRefCounter.load()) {}
};
std::mutex m_tableMutex;  // guards m_table itself (not the per-hash mutexes)
std::unordered_map<std::string, Item> m_table;
};
} // namespace InferenceEngine

View File

@@ -21,6 +21,7 @@
#include "ie_plugin_cpp.hpp"
#include "ie_plugin_config.hpp"
#include "ie_cache_manager.hpp"
#include "ie_cache_guard.hpp"
#include "ie_itt.hpp"
#include "file_utils.h"
#include "ie_network_reader.hpp"
@@ -197,6 +198,8 @@ class Core::Impl : public ICore {
// Core settings (cache config, etc)
CoreConfig coreConfig;
CacheGuard cacheGuard;
struct PluginDescriptor {
FileUtils::FilePath libraryLocation;
std::map<std::string, std::string> defaultConfig;
@@ -447,17 +450,18 @@ public:
}
auto parsed = parseDeviceNameIntoConfig(context->getDeviceName(), config);
auto plugin = GetCPPPluginByName(parsed._deviceName);
bool loadedFromCache = false;
ExecutableNetwork res;
std::string hash;
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
if (cacheManager && DeviceSupportsImportExport(plugin)) {
hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
auto hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
bool loadedFromCache = false;
auto lock = cacheGuard.getHashLock(hash);
res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, context, loadedFromCache);
}
if (!loadedFromCache) {
res = LoadNetworkImpl(network, plugin, parsed._config, context, hash);
if (!loadedFromCache) {
res = LoadNetworkImpl(network, plugin, parsed._config, context, hash);
}
} else {
res = LoadNetworkImpl(network, plugin, parsed._config, context, {});
}
return res;
}
@@ -472,17 +476,18 @@ public:
parsed._config.erase(CONFIG_KEY_INTERNAL(FORCE_DISABLE_CACHE));
}
auto plugin = GetCPPPluginByName(parsed._deviceName);
bool loadedFromCache = false;
ExecutableNetwork res;
std::string hash;
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
if (!forceDisableCache && cacheManager && DeviceSupportsImportExport(plugin)) {
hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
auto hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
bool loadedFromCache = false;
auto lock = cacheGuard.getHashLock(hash);
res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, nullptr, loadedFromCache);
}
if (!loadedFromCache) {
res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, hash, {}, forceDisableCache);
if (!loadedFromCache) {
res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, hash, {}, forceDisableCache);
}
} else {
res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, {}, {}, forceDisableCache);
}
return res;
}
@@ -493,19 +498,21 @@ public:
OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::LoadNetwork::Path");
auto parsed = parseDeviceNameIntoConfig(deviceName, config);
auto plugin = GetCPPPluginByName(parsed._deviceName);
bool loadedFromCache = false;
ExecutableNetwork res;
std::string hash;
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
if (cacheManager && DeviceSupportsImportExport(plugin)) {
hash = CalculateFileHash(modelPath, parsed._deviceName, plugin, parsed._config);
bool loadedFromCache = false;
auto hash = CalculateFileHash(modelPath, parsed._deviceName, plugin, parsed._config);
auto lock = cacheGuard.getHashLock(hash);
res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config,
nullptr, loadedFromCache, modelPath);
}
if (!loadedFromCache) {
if (!loadedFromCache) {
auto cnnNetwork = ReadNetwork(modelPath, std::string());
res = LoadNetworkImpl(cnnNetwork, plugin, parsed._config, nullptr, hash, modelPath);
}
} else {
auto cnnNetwork = ReadNetwork(modelPath, std::string());
res = LoadNetworkImpl(cnnNetwork, plugin, parsed._config, nullptr, hash, modelPath);
res = LoadNetworkImpl(cnnNetwork, plugin, parsed._config, nullptr, {}, modelPath);
}
return res;
}

View File

@@ -27,6 +27,7 @@
#include "functional_test_utils/network_utils.hpp"
#include "unit_test_utils/mocks/mock_iexecutable_network.hpp"
#include "unit_test_utils/mocks/mock_iinfer_request.hpp"
using namespace InferenceEngine;
using namespace ::testing;
@@ -97,6 +98,8 @@ public:
MOCK_METHOD0(CreateInferRequest, IInferRequest::Ptr());
MOCK_CONST_METHOD0(GetInputsInfo, ConstInputsDataMap());
MOCK_CONST_METHOD0(GetOutputsInfo, ConstOutputsDataMap());
MOCK_CONST_METHOD1(GetConfig, Parameter(const std::string& name));
MOCK_CONST_METHOD1(GetMetric, Parameter(const std::string& name));
};
//------------------------------------------------------
@@ -240,6 +243,31 @@ public:
ie.LoadNetwork(cnnNetwork, context, config);
}
ExecutableNetwork createMockIExecutableNet() {
auto mock = std::make_shared<MockIExecutableNetwork>();
EXPECT_CALL(*mock, GetInputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
EXPECT_CALL(*mock, GetOutputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
EXPECT_CALL(*mock, GetConfig(PluginConfigParams::KEY_PERF_COUNT, _, _)).Times(AnyNumber())
.WillRepeatedly(Invoke([&](const std::string &name, Parameter &result, ResponseDesc *resp) {
result = PluginConfigParams::NO;
return OK;
}));
EXPECT_CALL(*mock, GetMetric(METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS), _, _)).Times(AnyNumber())
.WillRepeatedly(Invoke([&](const std::string &name, Parameter &result, ResponseDesc *resp) {
result = (unsigned int) 1;
return OK;
}));
EXPECT_CALL(*mock, CreateInferRequest(_, _)).Times(AnyNumber())
.WillRepeatedly(Invoke([&](IInferRequest::Ptr &req, ResponseDesc*) {
auto ptr = std::make_shared<MockIInferRequest>();
EXPECT_CALL(*ptr, SetCompletionCallback(_)).Times(AnyNumber()).WillRepeatedly(Return(OK));
EXPECT_CALL(*ptr, SetUserData(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
req = ptr;
return OK;
}));
return ExecutableNetwork(mock);
}
private:
template <class T>
std::function<T> make_std_function(const std::string& functionName) {
@@ -271,18 +299,12 @@ private:
ON_CALL(plugin, ImportNetworkImpl(_, _, _)).
WillByDefault(Invoke([&](std::istream &istr, RemoteContext::Ptr,
const std::map<std::string, std::string> &) {
auto mock = std::make_shared<MockIExecutableNetwork>();
EXPECT_CALL(*mock, GetInputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
EXPECT_CALL(*mock, GetOutputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
return ExecutableNetwork(mock);
return createMockIExecutableNet();
}));
ON_CALL(plugin, ImportNetworkImpl(_, _)).
WillByDefault(Invoke([&](std::istream &istr, const std::map<std::string, std::string> &) {
auto mock = std::make_shared<MockIExecutableNetwork>();
EXPECT_CALL(*mock, GetInputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
EXPECT_CALL(*mock, GetOutputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
return ExecutableNetwork(mock);
return createMockIExecutableNet();
}));
ON_CALL(plugin, LoadExeNetworkImpl(_, _, _)).
@@ -318,6 +340,27 @@ private:
.WillRepeatedly(Return(ConstInputsDataMap{}));
EXPECT_CALL(*net, GetOutputsInfo()).Times(AnyNumber())
.WillRepeatedly(Return(ConstOutputsDataMap{}));
EXPECT_CALL(*net, GetConfig(PluginConfigParams::KEY_PERF_COUNT)).Times(AnyNumber())
.WillRepeatedly(Return(PluginConfigParams::NO));
EXPECT_CALL(*net, GetMetric(METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS))).Times(AnyNumber())
.WillRepeatedly(Return((unsigned int) 1));
EXPECT_CALL(*net, GetMetric(METRIC_KEY(NETWORK_NAME))).Times(AnyNumber())
.WillRepeatedly(Return("mock_net"));
EXPECT_CALL(*net, GetMetric(METRIC_KEY(SUPPORTED_METRICS))).Times(AnyNumber())
.WillRepeatedly(Invoke([&](const std::string &) {
std::vector<std::string> res;
res.push_back(METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS));
res.push_back(METRIC_KEY(NETWORK_NAME));
return res;
}));
EXPECT_CALL(*net, CreateInferRequest()).Times(AnyNumber())
.WillRepeatedly(Invoke([&]() {
std::vector<std::string> res;
auto inferReq = std::make_shared<MockIInferRequest>();
EXPECT_CALL(*inferReq, SetCompletionCallback(_)).Times(AnyNumber()).WillRepeatedly(Return(OK));
EXPECT_CALL(*inferReq, SetUserData(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
return inferReq;
}));
}
};
@@ -1172,6 +1215,130 @@ TEST_P(CachingTest, LoadHetero_MultiArchs_TargetFallback_FromCore) {
}
}
// MULTI-DEVICE test
// Test that it is safe to load multiple devices sharing same cache
TEST_P(CachingTest, LoadMulti_race) {
const auto TEST_DURATION_MS = 2000;
const auto TEST_DEVICE_MAX_COUNT = 10;
EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
if (m_remoteContext) {
return; // skip the remote Context test for Multi plugin
}
int index = 0;
auto start = high_resolution_clock::now();
do {
std::string cacheDir = m_cacheDir + std::to_string(index);
MkDirGuard guard(cacheDir);
int devCount = 1 + index % (TEST_DEVICE_MAX_COUNT - 1); // try dynamic number of devices from 1 to max
deviceToLoad = CommonTestUtils::DEVICE_MULTI;
deviceToLoad += ":mock.0";
for (int i = 1; i < devCount; i++) {
deviceToLoad += ",mock." + std::to_string(i);
}
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(1);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(devCount - 1);
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
testLoad([&](Core &ie) {
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), cacheDir}});
ASSERT_NO_THROW(m_testFunction(ie));
});
index++;
} while (duration_cast<milliseconds>(high_resolution_clock::now() - start).count() < TEST_DURATION_MS);
std::cout << "Caching LoadMulti Test completed. Tried " << index << " times" << std::endl;
}
TEST_P(CachingTest, Load_threads) {
const auto TEST_DURATION_MS = 2000;
const auto THREADS_COUNT = 4;
EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
if (m_remoteContext) {
return; // skip the remote Context test for Multi plugin
}
auto start = high_resolution_clock::now();
int index = 0;
do {
std::string cacheDir = m_cacheDir + std::to_string(index);
MkDirGuard guard(cacheDir);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(1);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(THREADS_COUNT - 1);
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
testLoad([&](Core &ie) {
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), cacheDir}});
std::vector<std::thread> threads;
for (int i = 0; i < THREADS_COUNT; i++) {
threads.emplace_back(([&]() { m_testFunction(ie); }));
}
for (int i = 0; i < THREADS_COUNT; i++) {
threads[i].join();
}
});
index++;
} while (duration_cast<milliseconds>(high_resolution_clock::now() - start).count() < TEST_DURATION_MS);
std::cout << "Caching Load multiple threads test completed. Tried " << index << " times" << std::endl;
}
// MULTI-DEVICE test
// Test that loading of device with one architecture doesn't block loading of device with another architecture
TEST_P(CachingTest, LoadMulti_Archs) {
const auto IMPORT_DELAY_LONG_MS = 3000;
const auto TEST_DEVICE_MAX_COUNT = 30; // Shall be >= 2
const auto IMPORT_DELAY_SHORT_MS = 100;
const auto EXP_MAX_EXEC_TIME_MS = 5500;
EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber())
.WillRepeatedly(Invoke([&](const std::string &, const std::map<std::string, Parameter> &options) {
auto id = options.at("DEVICE_ID").as<std::string>();
if (std::stoi(id) < 2) {
return "mock_first_architecture";
} else {
return "mock_another_architecture";
}
}));
if (m_remoteContext) {
return; // skip the remote Context test for Multi plugin
}
deviceToLoad = CommonTestUtils::DEVICE_MULTI;
deviceToLoad += ":mock.0";
for (int i = 1; i < TEST_DEVICE_MAX_COUNT; i++) {
deviceToLoad += ",mock." + std::to_string(i);
}
auto start = high_resolution_clock::now();
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(2);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(TEST_DEVICE_MAX_COUNT - 2)
.WillRepeatedly(Invoke([&](std::istream &, const std::map<std::string, std::string> &opt) {
auto id = opt.at("DEVICE_ID");
if (std::stoi(id) < 2) {
std::this_thread::sleep_for(milliseconds(IMPORT_DELAY_LONG_MS));
} else {
std::this_thread::sleep_for(milliseconds(IMPORT_DELAY_SHORT_MS));
}
return createMockIExecutableNet();
}));
EXPECT_CALL(*net, ExportImpl(_)).Times(2);
testLoad([&](Core &ie) {
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
ASSERT_NO_THROW(m_testFunction(ie));
});
}
ASSERT_LT(duration_cast<milliseconds>(high_resolution_clock::now() - start).count(), EXP_MAX_EXEC_TIME_MS);
}
INSTANTIATE_TEST_CASE_P(CachingTest, CachingTest,
::testing::Combine(
::testing::ValuesIn(loadVariants),

View File

@@ -349,6 +349,33 @@ TEST(NetworkContext_CNNNetwork, HashWithDifferentMeanValues) {
NetworkCompilationContext::computeHash(net3, {}));
}
// Verify all internal hash calculations are thread-safe (like ngraph::function serialization)
TEST(NetworkContext_CNNNetwork, HashOfSameMultiThreading) {
auto net1 = createNetwork();
auto net2 = createNetwork();
std::atomic_bool fail{false};
const auto TEST_DURATION_MS = 1000;
auto start = high_resolution_clock::now();
int t1Count = 0, t2Count = 0;
auto threadFun = [&](int& count) {
do {
count++;
auto hash1 = NetworkCompilationContext::computeHash(net1, {});
auto hash2 = NetworkCompilationContext::computeHash(net2, {});
if (hash1 != hash2) {
fail = true;
break;
}
} while (!fail && duration_cast<milliseconds>(high_resolution_clock::now() - start).count() < TEST_DURATION_MS);
};
std::thread t1(threadFun, std::ref(t1Count));
std::thread t2(threadFun, std::ref(t2Count));
t1.join();
t2.join();
std::cout << "Hash threading test finished. Total runs = " << t1Count + t2Count << std::endl;
ASSERT_FALSE(fail);
}
////////////////////////////////////////////
TEST(NetworkContext_ModelName, HashOfSame) {

View File

@@ -34,133 +34,84 @@ ngraph::Node* ngraph::OpSet::create_insensitive(const std::string& name) const
const ngraph::OpSet& ngraph::get_opset1()
{
static std::mutex init_mutex;
static bool opset_is_initialized = false;
static OpSet opset;
if (!opset_is_initialized)
{
std::lock_guard<std::mutex> guard(init_mutex);
if (!opset_is_initialized)
{
static std::once_flag flag;
std::call_once(flag, [&]() {
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
#include "ngraph/opsets/opset1_tbl.hpp"
#undef NGRAPH_OP
opset_is_initialized = true;
}
}
});
return opset;
}
const ngraph::OpSet& ngraph::get_opset2()
{
static std::mutex init_mutex;
static bool opset_is_initialized = false;
static OpSet opset;
if (!opset_is_initialized)
{
std::lock_guard<std::mutex> guard(init_mutex);
if (!opset_is_initialized)
{
static std::once_flag flag;
std::call_once(flag, [&]() {
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
#include "ngraph/opsets/opset2_tbl.hpp"
#undef NGRAPH_OP
opset_is_initialized = true;
}
}
});
return opset;
}
const ngraph::OpSet& ngraph::get_opset3()
{
static std::mutex init_mutex;
static bool opset_is_initialized = false;
static OpSet opset;
if (!opset_is_initialized)
{
std::lock_guard<std::mutex> guard(init_mutex);
if (!opset_is_initialized)
{
static std::once_flag flag;
std::call_once(flag, [&]() {
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
#include "ngraph/opsets/opset3_tbl.hpp"
#undef NGRAPH_OP
opset_is_initialized = true;
}
}
});
return opset;
}
const ngraph::OpSet& ngraph::get_opset4()
{
static std::mutex init_mutex;
static bool opset_is_initialized = false;
static OpSet opset;
if (!opset_is_initialized)
{
std::lock_guard<std::mutex> guard(init_mutex);
if (!opset_is_initialized)
{
static std::once_flag flag;
std::call_once(flag, [&]() {
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
#include "ngraph/opsets/opset4_tbl.hpp"
#undef NGRAPH_OP
opset_is_initialized = true;
}
}
});
return opset;
}
const ngraph::OpSet& ngraph::get_opset5()
{
static std::mutex init_mutex;
static bool opset_is_initialized = false;
static OpSet opset;
if (!opset_is_initialized)
{
std::lock_guard<std::mutex> guard(init_mutex);
if (!opset_is_initialized)
{
static std::once_flag flag;
std::call_once(flag, [&]() {
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
#include "ngraph/opsets/opset5_tbl.hpp"
#undef NGRAPH_OP
opset_is_initialized = true;
}
}
});
return opset;
}
const ngraph::OpSet& ngraph::get_opset6()
{
static std::mutex init_mutex;
static bool opset_is_initialized = false;
static OpSet opset;
if (!opset_is_initialized)
{
std::lock_guard<std::mutex> guard(init_mutex);
if (!opset_is_initialized)
{
static std::once_flag flag;
std::call_once(flag, [&]() {
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
#include "ngraph/opsets/opset6_tbl.hpp"
#undef NGRAPH_OP
opset_is_initialized = true;
}
}
});
return opset;
}
const ngraph::OpSet& ngraph::get_opset7()
{
static std::mutex init_mutex;
static bool opset_is_initialized = false;
static OpSet opset;
if (!opset_is_initialized)
{
std::lock_guard<std::mutex> guard(init_mutex);
if (!opset_is_initialized)
{
static std::once_flag flag;
std::call_once(flag, [&]() {
#define NGRAPH_OP(NAME, NAMESPACE) opset.insert<NAMESPACE::NAME>();
#include "ngraph/opsets/opset7_tbl.hpp"
#undef NGRAPH_OP
opset_is_initialized = true;
}
}
});
return opset;
}

View File

@@ -484,7 +484,7 @@ const std::vector<int64_t>& ngraph::AttributeAdapter<ngraph::PartialShape>::get(
{
for (size_t i = 0; i < m_ref.rank().get_length(); ++i)
{
auto& elt = m_ref[i];
const auto& elt = static_cast<const PartialShape&>(m_ref)[i];
m_buffer.push_back(elt.is_dynamic() ? -1 : elt.get_length());
}
}

View File

@@ -6,6 +6,7 @@
#include <sstream>
#include <string>
#include <vector>
#include <thread>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -13,6 +14,7 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/variant.hpp"
#include "ngraph/opsets/opset.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
@@ -50,6 +52,35 @@ TEST(op, provenance_tag)
ASSERT_TRUE(tags.find(tag2) != tags.end());
}
TEST(op, opset_multi_thread) {
auto doTest = [&](std::function<const ngraph::OpSet&()> fun) {
std::atomic<const ngraph::OpSet*> opset {nullptr};
std::atomic_bool failed {false};
auto threadFun = [&] () {
const ngraph::OpSet* op = &fun();
const ngraph::OpSet* current = opset;
do {
if (current != nullptr && current != op) {
failed = true;
break;
}
} while (opset.compare_exchange_strong(op, current));
};
std::thread t1 {threadFun};
std::thread t2 {threadFun};
t1.join();
t2.join();
ASSERT_FALSE(failed);
};
doTest(ngraph::get_opset1);
doTest(ngraph::get_opset2);
doTest(ngraph::get_opset3);
doTest(ngraph::get_opset4);
doTest(ngraph::get_opset5);
doTest(ngraph::get_opset6);
doTest(ngraph::get_opset7);
}
struct Ship
{
std::string name;