From 4021cb75197ae411f8099ba97ff7aa30f5e3144a Mon Sep 17 00:00:00 2001 From: Mikhail Nosov Date: Thu, 1 Apr 2021 14:42:48 +0300 Subject: [PATCH] Caching support of multi-device scenario (#5040) * Caching support of multi-device scenario - IE_CORE: introduce CacheGuard which can create locks for specific cache identified by 'hash' - Added functional tests for it Fixes of Thread Sanitizer failures: - ngraph::Serialize - m_ref[i] can create new element, cast to 'const' to avoid this - ngraph::get_opset operations: reworked to use std::call_once instead of double bool check * Added docs for ie_cache_guard.hpp * Fix Debian 9 compilation issue * Fix build for CentOS 6 Added assert to verify that table of locked hashes is empty on destruction * Fixed review comments --- .../src/inference_engine/ie_cache_guard.cpp | 64 ++++++ .../src/inference_engine/ie_cache_guard.hpp | 122 ++++++++++++ .../src/inference_engine/ie_core.cpp | 49 +++-- .../inference_engine/caching_test.cpp | 183 +++++++++++++++++- .../ie_compilation_context_test.cpp | 27 +++ ngraph/core/src/opsets/opset.cpp | 91 ++------- ngraph/core/src/partial_shape.cpp | 2 +- ngraph/test/op.cpp | 31 +++ 8 files changed, 469 insertions(+), 100 deletions(-) create mode 100644 inference-engine/src/inference_engine/ie_cache_guard.cpp create mode 100644 inference-engine/src/inference_engine/ie_cache_guard.hpp diff --git a/inference-engine/src/inference_engine/ie_cache_guard.cpp b/inference-engine/src/inference_engine/ie_cache_guard.cpp new file mode 100644 index 00000000000..fa776d13038 --- /dev/null +++ b/inference-engine/src/inference_engine/ie_cache_guard.cpp @@ -0,0 +1,64 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "ie_cache_guard.hpp" +#include "ie_common.h" + +namespace InferenceEngine { + +CacheGuardEntry::CacheGuardEntry(CacheGuard& cacheGuard, const std::string& hash, + std::shared_ptr m, std::atomic_int& refCount): + m_cacheGuard(cacheGuard), 
m_hash(hash), m_mutex(m), m_refCount(refCount) { + // Don't lock mutex right here for exception-safe considerations + m_refCount++; +} + +CacheGuardEntry::~CacheGuardEntry() { + m_refCount--; + m_mutex->unlock(); + m_cacheGuard.checkForRemove(m_hash); +} + +void CacheGuardEntry::performLock() { + m_mutex->lock(); +} + +////////////////////////////////////////////////////// + +CacheGuard::~CacheGuard() { + IE_ASSERT(m_table.size() == 0); +} + +std::unique_ptr CacheGuard::getHashLock(const std::string& hash) { + std::unique_lock lock(m_tableMutex); + auto& data = m_table[hash]; + std::unique_ptr res; + try { + // TODO: use std::make_unique when migrated to C++14 + res = std::unique_ptr( + new CacheGuardEntry(*this, hash, data.m_mutexPtr, data.m_itemRefCounter)); + } catch (...) { + // In case of exception, we shall remove hash entry if it is not used + if (data.m_itemRefCounter == 0) { + m_table.erase(hash); + } + throw; + } + lock.unlock(); // can unlock table lock here, as refCounter is positive and nobody can remove entry + res->performLock(); // in case of exception, 'res' will be destroyed and item will be cleaned up from table + return res; +} + +void CacheGuard::checkForRemove(const std::string& hash) { + std::lock_guard lock(m_tableMutex); + if (m_table.count(hash)) { + auto &data = m_table[hash]; + if (data.m_itemRefCounter == 0) { + // Nobody is using this and nobody is waiting for it - can be removed + m_table.erase(hash); + } + } +} + +} // namespace InferenceEngine diff --git a/inference-engine/src/inference_engine/ie_cache_guard.hpp b/inference-engine/src/inference_engine/ie_cache_guard.hpp new file mode 100644 index 00000000000..1fe1954d479 --- /dev/null +++ b/inference-engine/src/inference_engine/ie_cache_guard.hpp @@ -0,0 +1,122 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +/** + * @brief This is a header file for the Inference Engine Cache Guard class C++ API + * + * @file 
ie_cache_guard.hpp + */ + +#include + +#include +#include +#include +#include +#include + +namespace InferenceEngine { + +class CacheGuard; +/** + * @brief This class is an RAII guard that prevents multiple threads from modifying the same cached network + * Use CacheGuard::getHashLock(hash) to acquire lock for specific cache entry identified by its 'hash' + * On destruction, lock will be released + * @see CacheGuard + */ +class CacheGuardEntry { +public: + /** + * @brief Internal constructor, will be called by CacheGuard + * + * @param cacheGuard Reference link to parent's Cache Guard + * @param hash String representing hash of network + * @param m Shared pointer to mutex for internal locking + * @param refCount Reference counter. Will be decremented on CacheGuardEntry destruction + */ + CacheGuardEntry(CacheGuard& cacheGuard, const std::string& hash, + std::shared_ptr m, std::atomic_int& refCount); + CacheGuardEntry(const CacheGuardEntry&) = delete; + + /** + * @brief Destructor, will perform the following cleanup + * + * Decrement reference counter + * Unlock associated mutex + * Call CacheGuard::checkForRemove to check if appropriate table hash entry is not used anymore and can be deleted + */ + ~CacheGuardEntry(); + + /** + * @brief Performs real lock of associated mutex + * It is separated from construction due to exception safety considerations + * + * @note Will be called only by CacheGuard, it shall not be called from client's code + */ + void performLock(); + +private: + CacheGuard& m_cacheGuard; + std::string m_hash; + std::shared_ptr m_mutex; + std::atomic_int& m_refCount; +}; + +/** + * @brief This class holds a table of currently locked hashes + * Inference engine core will need to obtain a lock for a specific cache to get exclusive access to it + * It is needed to avoid race situations when multiple threads try to write to the same cache simultaneously + * + * Usage example: + * auto hash = ; + * { + * auto lock = m_cacheGuard.getHashLock(hash); 
+ * + * } + */ +class CacheGuard { +public: + CacheGuard() = default; + ~CacheGuard(); + + /** + * @brief Gets a lock for a specific cache entry identified by its hash value + * Once returned, client has an exclusive access to cache entry for read/write/delete + * If any other thread holds a lock to same hash - this function will not return until it is unlocked + * + * @param hash String representing hash of network + * + * @return RAII pointer to CacheGuardEntry + */ + std::unique_ptr getHashLock(const std::string& hash); + + /** + * @brief Checks whether there are any clients holding the lock after CacheGuardEntry deletion + * It will be called on destruction of CacheGuardEntry and shall not be used directly by client's code + * If no more clients are holding the lock, the associated entry will be removed from the table + * + * @param hash String representing hash of network + * + * @note No return value; the table entry is erased only when its reference counter is zero + */ + void checkForRemove(const std::string& hash); + +private: + struct Item { + std::shared_ptr m_mutexPtr { std::make_shared() }; + // Reference counter for item usage + std::atomic_int m_itemRefCounter {0}; + + Item() = default; + Item(const Item& other): m_mutexPtr(other.m_mutexPtr), + m_itemRefCounter(other.m_itemRefCounter.load()) {} + Item(Item&& other): m_mutexPtr(std::move(other.m_mutexPtr)), + m_itemRefCounter(other.m_itemRefCounter.load()) {} + }; + std::mutex m_tableMutex; + std::unordered_map m_table; +}; + +} // namespace InferenceEngine diff --git a/inference-engine/src/inference_engine/ie_core.cpp b/inference-engine/src/inference_engine/ie_core.cpp index 35cb82a3ddb..2122fda276e 100644 --- a/inference-engine/src/inference_engine/ie_core.cpp +++ b/inference-engine/src/inference_engine/ie_core.cpp @@ -21,6 +21,7 @@ #include "ie_plugin_cpp.hpp" #include "ie_plugin_config.hpp" #include "ie_cache_manager.hpp" +#include "ie_cache_guard.hpp" #include "ie_itt.hpp" #include "file_utils.h" #include "ie_network_reader.hpp" @@ -197,6 
+198,8 @@ class Core::Impl : public ICore { // Core settings (cache config, etc) CoreConfig coreConfig; + CacheGuard cacheGuard; + struct PluginDescriptor { FileUtils::FilePath libraryLocation; std::map defaultConfig; @@ -447,17 +450,18 @@ public: } auto parsed = parseDeviceNameIntoConfig(context->getDeviceName(), config); auto plugin = GetCPPPluginByName(parsed._deviceName); - bool loadedFromCache = false; ExecutableNetwork res; - std::string hash; auto cacheManager = coreConfig.getCacheConfig()._cacheManager; if (cacheManager && DeviceSupportsImportExport(plugin)) { - hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config); + auto hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config); + bool loadedFromCache = false; + auto lock = cacheGuard.getHashLock(hash); res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, context, loadedFromCache); - } - - if (!loadedFromCache) { - res = LoadNetworkImpl(network, plugin, parsed._config, context, hash); + if (!loadedFromCache) { + res = LoadNetworkImpl(network, plugin, parsed._config, context, hash); + } + } else { + res = LoadNetworkImpl(network, plugin, parsed._config, context, {}); } return res; } @@ -472,17 +476,18 @@ public: parsed._config.erase(CONFIG_KEY_INTERNAL(FORCE_DISABLE_CACHE)); } auto plugin = GetCPPPluginByName(parsed._deviceName); - bool loadedFromCache = false; ExecutableNetwork res; - std::string hash; auto cacheManager = coreConfig.getCacheConfig()._cacheManager; if (!forceDisableCache && cacheManager && DeviceSupportsImportExport(plugin)) { - hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config); + auto hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config); + bool loadedFromCache = false; + auto lock = cacheGuard.getHashLock(hash); res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, nullptr, loadedFromCache); - } - - if (!loadedFromCache) { - res = 
LoadNetworkImpl(network, plugin, parsed._config, nullptr, hash, {}, forceDisableCache); + if (!loadedFromCache) { + res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, hash, {}, forceDisableCache); + } + } else { + res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, {}, {}, forceDisableCache); } return res; } @@ -493,19 +498,21 @@ public: OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::LoadNetwork::Path"); auto parsed = parseDeviceNameIntoConfig(deviceName, config); auto plugin = GetCPPPluginByName(parsed._deviceName); - bool loadedFromCache = false; ExecutableNetwork res; - std::string hash; auto cacheManager = coreConfig.getCacheConfig()._cacheManager; if (cacheManager && DeviceSupportsImportExport(plugin)) { - hash = CalculateFileHash(modelPath, parsed._deviceName, plugin, parsed._config); + bool loadedFromCache = false; + auto hash = CalculateFileHash(modelPath, parsed._deviceName, plugin, parsed._config); + auto lock = cacheGuard.getHashLock(hash); res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, nullptr, loadedFromCache, modelPath); - } - - if (!loadedFromCache) { + if (!loadedFromCache) { + auto cnnNetwork = ReadNetwork(modelPath, std::string()); + res = LoadNetworkImpl(cnnNetwork, plugin, parsed._config, nullptr, hash, modelPath); + } + } else { auto cnnNetwork = ReadNetwork(modelPath, std::string()); - res = LoadNetworkImpl(cnnNetwork, plugin, parsed._config, nullptr, hash, modelPath); + res = LoadNetworkImpl(cnnNetwork, plugin, parsed._config, nullptr, {}, modelPath); } return res; } diff --git a/inference-engine/tests/functional/inference_engine/caching_test.cpp b/inference-engine/tests/functional/inference_engine/caching_test.cpp index 20b3ac5bc7c..203b5d91d7e 100644 --- a/inference-engine/tests/functional/inference_engine/caching_test.cpp +++ b/inference-engine/tests/functional/inference_engine/caching_test.cpp @@ -27,6 +27,7 @@ #include "functional_test_utils/network_utils.hpp" #include 
"unit_test_utils/mocks/mock_iexecutable_network.hpp" +#include "unit_test_utils/mocks/mock_iinfer_request.hpp" using namespace InferenceEngine; using namespace ::testing; @@ -97,6 +98,8 @@ public: MOCK_METHOD0(CreateInferRequest, IInferRequest::Ptr()); MOCK_CONST_METHOD0(GetInputsInfo, ConstInputsDataMap()); MOCK_CONST_METHOD0(GetOutputsInfo, ConstOutputsDataMap()); + MOCK_CONST_METHOD1(GetConfig, Parameter(const std::string& name)); + MOCK_CONST_METHOD1(GetMetric, Parameter(const std::string& name)); }; //------------------------------------------------------ @@ -240,6 +243,31 @@ public: ie.LoadNetwork(cnnNetwork, context, config); } + ExecutableNetwork createMockIExecutableNet() { + auto mock = std::make_shared(); + EXPECT_CALL(*mock, GetInputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK)); + EXPECT_CALL(*mock, GetOutputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK)); + EXPECT_CALL(*mock, GetConfig(PluginConfigParams::KEY_PERF_COUNT, _, _)).Times(AnyNumber()) + .WillRepeatedly(Invoke([&](const std::string &name, Parameter &result, ResponseDesc *resp) { + result = PluginConfigParams::NO; + return OK; + })); + EXPECT_CALL(*mock, GetMetric(METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS), _, _)).Times(AnyNumber()) + .WillRepeatedly(Invoke([&](const std::string &name, Parameter &result, ResponseDesc *resp) { + result = (unsigned int) 1; + return OK; + })); + EXPECT_CALL(*mock, CreateInferRequest(_, _)).Times(AnyNumber()) + .WillRepeatedly(Invoke([&](IInferRequest::Ptr &req, ResponseDesc*) { + auto ptr = std::make_shared(); + EXPECT_CALL(*ptr, SetCompletionCallback(_)).Times(AnyNumber()).WillRepeatedly(Return(OK)); + EXPECT_CALL(*ptr, SetUserData(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK)); + req = ptr; + return OK; + })); + return ExecutableNetwork(mock); + } + private: template std::function make_std_function(const std::string& functionName) { @@ -271,18 +299,12 @@ private: ON_CALL(plugin, ImportNetworkImpl(_, _, _)). 
WillByDefault(Invoke([&](std::istream &istr, RemoteContext::Ptr, const std::map &) { - auto mock = std::make_shared(); - EXPECT_CALL(*mock, GetInputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK)); - EXPECT_CALL(*mock, GetOutputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK)); - return ExecutableNetwork(mock); + return createMockIExecutableNet(); })); ON_CALL(plugin, ImportNetworkImpl(_, _)). WillByDefault(Invoke([&](std::istream &istr, const std::map &) { - auto mock = std::make_shared(); - EXPECT_CALL(*mock, GetInputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK)); - EXPECT_CALL(*mock, GetOutputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK)); - return ExecutableNetwork(mock); + return createMockIExecutableNet(); })); ON_CALL(plugin, LoadExeNetworkImpl(_, _, _)). @@ -318,6 +340,27 @@ private: .WillRepeatedly(Return(ConstInputsDataMap{})); EXPECT_CALL(*net, GetOutputsInfo()).Times(AnyNumber()) .WillRepeatedly(Return(ConstOutputsDataMap{})); + EXPECT_CALL(*net, GetConfig(PluginConfigParams::KEY_PERF_COUNT)).Times(AnyNumber()) + .WillRepeatedly(Return(PluginConfigParams::NO)); + EXPECT_CALL(*net, GetMetric(METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS))).Times(AnyNumber()) + .WillRepeatedly(Return((unsigned int) 1)); + EXPECT_CALL(*net, GetMetric(METRIC_KEY(NETWORK_NAME))).Times(AnyNumber()) + .WillRepeatedly(Return("mock_net")); + EXPECT_CALL(*net, GetMetric(METRIC_KEY(SUPPORTED_METRICS))).Times(AnyNumber()) + .WillRepeatedly(Invoke([&](const std::string &) { + std::vector res; + res.push_back(METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS)); + res.push_back(METRIC_KEY(NETWORK_NAME)); + return res; + })); + EXPECT_CALL(*net, CreateInferRequest()).Times(AnyNumber()) + .WillRepeatedly(Invoke([&]() { + std::vector res; + auto inferReq = std::make_shared(); + EXPECT_CALL(*inferReq, SetCompletionCallback(_)).Times(AnyNumber()).WillRepeatedly(Return(OK)); + EXPECT_CALL(*inferReq, SetUserData(_, 
_)).Times(AnyNumber()).WillRepeatedly(Return(OK)); + return inferReq; + })); } }; @@ -1172,6 +1215,130 @@ TEST_P(CachingTest, LoadHetero_MultiArchs_TargetFallback_FromCore) { } } +// MULTI-DEVICE test +// Test that it is safe to load multiple devices sharing same cache +TEST_P(CachingTest, LoadMulti_race) { + const auto TEST_DURATION_MS = 2000; + const auto TEST_DEVICE_MAX_COUNT = 10; + EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber()); + if (m_remoteContext) { + return; // skip the remote Context test for Multi plugin + } + int index = 0; + auto start = high_resolution_clock::now(); + do { + std::string cacheDir = m_cacheDir + std::to_string(index); + MkDirGuard guard(cacheDir); + int devCount = 1 + index % (TEST_DEVICE_MAX_COUNT - 1); // try dynamic number of devices from 1 to max + deviceToLoad = CommonTestUtils::DEVICE_MULTI; + deviceToLoad += ":mock.0"; + for (int i = 1; i < devCount; i++) { + deviceToLoad += ",mock." + std::to_string(i); + } + + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(1); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(devCount - 1); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), cacheDir}}); + ASSERT_NO_THROW(m_testFunction(ie)); + }); + index++; + } while (duration_cast(high_resolution_clock::now() - start).count() < TEST_DURATION_MS); + std::cout << "Caching LoadMulti Test completed. 
Tried " << index << " times" << std::endl; +} + +TEST_P(CachingTest, Load_threads) { + const auto TEST_DURATION_MS = 2000; + const auto THREADS_COUNT = 4; + EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber()); + if (m_remoteContext) { + return; // skip the remote Context test for Multi plugin + } + auto start = high_resolution_clock::now(); + int index = 0; + do { + std::string cacheDir = m_cacheDir + std::to_string(index); + MkDirGuard guard(cacheDir); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(1); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(THREADS_COUNT - 1); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), cacheDir}}); + std::vector threads; + for (int i = 0; i < THREADS_COUNT; i++) { + threads.emplace_back(([&]() { m_testFunction(ie); })); + } + for (int i = 0; i < THREADS_COUNT; i++) { + threads[i].join(); + } + }); + index++; + } while (duration_cast(high_resolution_clock::now() - start).count() < TEST_DURATION_MS); + std::cout << "Caching Load multiple threads test completed. 
Tried " << index << " times" << std::endl; +} + +// MULTI-DEVICE test +// Test that loading of device with one architecture doesn't block loading of device with another architecture +TEST_P(CachingTest, LoadMulti_Archs) { + const auto IMPORT_DELAY_LONG_MS = 3000; + const auto TEST_DEVICE_MAX_COUNT = 30; // Shall be >= 2 + const auto IMPORT_DELAY_SHORT_MS = 100; + const auto EXP_MAX_EXEC_TIME_MS = 5500; + EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber()) + .WillRepeatedly(Invoke([&](const std::string &, const std::map &options) { + auto id = options.at("DEVICE_ID").as(); + if (std::stoi(id) < 2) { + return "mock_first_architecture"; + } else { + return "mock_another_architecture"; + } + })); + if (m_remoteContext) { + return; // skip the remote Context test for Multi plugin + } + + deviceToLoad = CommonTestUtils::DEVICE_MULTI; + deviceToLoad += ":mock.0"; + for (int i = 1; i < TEST_DEVICE_MAX_COUNT; i++) { + deviceToLoad += ",mock." 
+ std::to_string(i); + } + + auto start = high_resolution_clock::now(); + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(2); + + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(TEST_DEVICE_MAX_COUNT - 2) + .WillRepeatedly(Invoke([&](std::istream &, const std::map &opt) { + auto id = opt.at("DEVICE_ID"); + if (std::stoi(id) < 2) { + std::this_thread::sleep_for(milliseconds(IMPORT_DELAY_LONG_MS)); + } else { + std::this_thread::sleep_for(milliseconds(IMPORT_DELAY_SHORT_MS)); + } + return createMockIExecutableNet(); + })); + EXPECT_CALL(*net, ExportImpl(_)).Times(2); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + ASSERT_NO_THROW(m_testFunction(ie)); + }); + } + ASSERT_LT(duration_cast(high_resolution_clock::now() - start).count(), EXP_MAX_EXEC_TIME_MS); +} + INSTANTIATE_TEST_CASE_P(CachingTest, CachingTest, ::testing::Combine( ::testing::ValuesIn(loadVariants), diff --git a/inference-engine/tests/unit/inference_engine/ie_compilation_context_test.cpp b/inference-engine/tests/unit/inference_engine/ie_compilation_context_test.cpp index c3d428dbf8b..a52ce386a7a 100644 --- a/inference-engine/tests/unit/inference_engine/ie_compilation_context_test.cpp +++ b/inference-engine/tests/unit/inference_engine/ie_compilation_context_test.cpp @@ -349,6 +349,33 @@ TEST(NetworkContext_CNNNetwork, HashWithDifferentMeanValues) { NetworkCompilationContext::computeHash(net3, {})); } +// Verify all internal hash calculations are thread-safe (like ngraph::function serialization) +TEST(NetworkContext_CNNNetwork, HashOfSameMultiThreading) { + auto net1 = createNetwork(); + auto net2 = createNetwork(); + std::atomic_bool fail{false}; + const auto TEST_DURATION_MS = 1000; + auto start = high_resolution_clock::now(); + int t1Count = 0, t2Count = 0; + auto threadFun = [&](int& count) { + do { + count++; + auto 
hash1 = NetworkCompilationContext::computeHash(net1, {}); + auto hash2 = NetworkCompilationContext::computeHash(net2, {}); + if (hash1 != hash2) { + fail = true; + break; + } + } while (!fail && duration_cast(high_resolution_clock::now() - start).count() < TEST_DURATION_MS); + }; + std::thread t1(threadFun, std::ref(t1Count)); + std::thread t2(threadFun, std::ref(t2Count)); + t1.join(); + t2.join(); + std::cout << "Hash threading test finished. Total runs = " << t1Count + t2Count << std::endl; + ASSERT_FALSE(fail); +} + //////////////////////////////////////////// TEST(NetworkContext_ModelName, HashOfSame) { diff --git a/ngraph/core/src/opsets/opset.cpp b/ngraph/core/src/opsets/opset.cpp index a59ca4c3726..ea09eec98c1 100644 --- a/ngraph/core/src/opsets/opset.cpp +++ b/ngraph/core/src/opsets/opset.cpp @@ -34,133 +34,84 @@ ngraph::Node* ngraph::OpSet::create_insensitive(const std::string& name) const const ngraph::OpSet& ngraph::get_opset1() { - static std::mutex init_mutex; - static bool opset_is_initialized = false; static OpSet opset; - if (!opset_is_initialized) - { - std::lock_guard guard(init_mutex); - if (!opset_is_initialized) - { + static std::once_flag flag; + std::call_once(flag, [&]() { #define NGRAPH_OP(NAME, NAMESPACE) opset.insert(); #include "ngraph/opsets/opset1_tbl.hpp" #undef NGRAPH_OP - opset_is_initialized = true; - } - } + }); return opset; } const ngraph::OpSet& ngraph::get_opset2() { - static std::mutex init_mutex; - static bool opset_is_initialized = false; static OpSet opset; - if (!opset_is_initialized) - { - std::lock_guard guard(init_mutex); - if (!opset_is_initialized) - { + static std::once_flag flag; + std::call_once(flag, [&]() { #define NGRAPH_OP(NAME, NAMESPACE) opset.insert(); #include "ngraph/opsets/opset2_tbl.hpp" #undef NGRAPH_OP - opset_is_initialized = true; - } - } + }); return opset; } const ngraph::OpSet& ngraph::get_opset3() { - static std::mutex init_mutex; - static bool opset_is_initialized = false; static OpSet opset; 
- if (!opset_is_initialized) - { - std::lock_guard guard(init_mutex); - if (!opset_is_initialized) - { + static std::once_flag flag; + std::call_once(flag, [&]() { #define NGRAPH_OP(NAME, NAMESPACE) opset.insert(); #include "ngraph/opsets/opset3_tbl.hpp" #undef NGRAPH_OP - opset_is_initialized = true; - } - } + }); return opset; } const ngraph::OpSet& ngraph::get_opset4() { - static std::mutex init_mutex; - static bool opset_is_initialized = false; static OpSet opset; - if (!opset_is_initialized) - { - std::lock_guard guard(init_mutex); - if (!opset_is_initialized) - { + static std::once_flag flag; + std::call_once(flag, [&]() { #define NGRAPH_OP(NAME, NAMESPACE) opset.insert(); #include "ngraph/opsets/opset4_tbl.hpp" #undef NGRAPH_OP - opset_is_initialized = true; - } - } + }); return opset; } const ngraph::OpSet& ngraph::get_opset5() { - static std::mutex init_mutex; - static bool opset_is_initialized = false; static OpSet opset; - if (!opset_is_initialized) - { - std::lock_guard guard(init_mutex); - if (!opset_is_initialized) - { + static std::once_flag flag; + std::call_once(flag, [&]() { #define NGRAPH_OP(NAME, NAMESPACE) opset.insert(); #include "ngraph/opsets/opset5_tbl.hpp" #undef NGRAPH_OP - opset_is_initialized = true; - } - } + }); return opset; } const ngraph::OpSet& ngraph::get_opset6() { - static std::mutex init_mutex; - static bool opset_is_initialized = false; static OpSet opset; - if (!opset_is_initialized) - { - std::lock_guard guard(init_mutex); - if (!opset_is_initialized) - { + static std::once_flag flag; + std::call_once(flag, [&]() { #define NGRAPH_OP(NAME, NAMESPACE) opset.insert(); #include "ngraph/opsets/opset6_tbl.hpp" #undef NGRAPH_OP - opset_is_initialized = true; - } - } + }); return opset; } const ngraph::OpSet& ngraph::get_opset7() { - static std::mutex init_mutex; - static bool opset_is_initialized = false; static OpSet opset; - if (!opset_is_initialized) - { - std::lock_guard guard(init_mutex); - if (!opset_is_initialized) - { + 
static std::once_flag flag; + std::call_once(flag, [&]() { #define NGRAPH_OP(NAME, NAMESPACE) opset.insert(); #include "ngraph/opsets/opset7_tbl.hpp" #undef NGRAPH_OP - opset_is_initialized = true; - } - } + }); return opset; } diff --git a/ngraph/core/src/partial_shape.cpp b/ngraph/core/src/partial_shape.cpp index cb1fea2de3e..35929dc9e8d 100644 --- a/ngraph/core/src/partial_shape.cpp +++ b/ngraph/core/src/partial_shape.cpp @@ -484,7 +484,7 @@ const std::vector& ngraph::AttributeAdapter::get( { for (size_t i = 0; i < m_ref.rank().get_length(); ++i) { - auto& elt = m_ref[i]; + const auto& elt = static_cast(m_ref)[i]; m_buffer.push_back(elt.is_dynamic() ? -1 : elt.get_length()); } } diff --git a/ngraph/test/op.cpp b/ngraph/test/op.cpp index d8cfceb0167..32cfcf09821 100644 --- a/ngraph/test/op.cpp +++ b/ngraph/test/op.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -13,6 +14,7 @@ #include "ngraph/graph_util.hpp" #include "ngraph/ngraph.hpp" #include "ngraph/variant.hpp" +#include "ngraph/opsets/opset.hpp" NGRAPH_SUPPRESS_DEPRECATED_START @@ -50,6 +52,35 @@ TEST(op, provenance_tag) ASSERT_TRUE(tags.find(tag2) != tags.end()); } +TEST(op, opset_multi_thread) { + auto doTest = [&](std::function fun) { + std::atomic opset {nullptr}; + std::atomic_bool failed {false}; + auto threadFun = [&] () { + const ngraph::OpSet* op = &fun(); + const ngraph::OpSet* current = opset; + do { + if (current != nullptr && current != op) { + failed = true; + break; + } + } while (opset.compare_exchange_strong(op, current)); + }; + std::thread t1 {threadFun}; + std::thread t2 {threadFun}; + t1.join(); + t2.join(); + ASSERT_FALSE(failed); + }; + doTest(ngraph::get_opset1); + doTest(ngraph::get_opset2); + doTest(ngraph::get_opset3); + doTest(ngraph::get_opset4); + doTest(ngraph::get_opset5); + doTest(ngraph::get_opset6); + doTest(ngraph::get_opset7); +} + struct Ship { std::string name;