[Model Caching] Enabling per-device cache dir (#10774)

* Initial commit

10 more caching tests

* Fix clang-format

* Added brief explanations to each test

* Fix review comments
This commit is contained in:
Mikhail Nosov
2022-03-24 11:24:47 +03:00
committed by GitHub
parent 0cc119cd86
commit 9d865a2133
3 changed files with 489 additions and 52 deletions

View File

@@ -90,7 +90,7 @@ public:
* @brief Constructor
*
*/
FileStorageCacheManager(std::string&& cachePath) : m_cachePath(std::move(cachePath)) {}
FileStorageCacheManager(std::string cachePath) : m_cachePath(std::move(cachePath)) {}
/**
* @brief Destructor

View File

@@ -150,27 +150,75 @@ class CoreImpl : public ie::ICore, public std::enable_shared_from_this<ie::ICore
auto it = config.find(CONFIG_KEY(CACHE_DIR));
if (it != config.end()) {
std::lock_guard<std::mutex> lock(_cacheConfigMutex);
_cacheConfig._cacheDir = it->second;
if (!it->second.empty()) {
FileUtils::createDirectoryRecursive(it->second);
_cacheConfig._cacheManager = std::make_shared<ie::FileStorageCacheManager>(std::move(it->second));
} else {
_cacheConfig._cacheManager = nullptr;
fillConfig(_cacheConfig, it->second);
for (auto& deviceCfg : _cacheConfigPerDevice) {
fillConfig(deviceCfg.second, it->second);
}
config.erase(it);
}
}
// Creating thread-safe copy of config including shared_ptr to ICacheManager
CacheConfig getCacheConfig() const {
void setCacheForDevice(const std::string& dir, const std::string& name) {
std::lock_guard<std::mutex> lock(_cacheConfigMutex);
return _cacheConfig;
fillConfig(_cacheConfigPerDevice[name], dir);
}
// Creating thread-safe copy of config including shared_ptr to ICacheManager
// Passing empty or not-existing name will return global cache config
CacheConfig getCacheConfigForDevice(const std::string& device_name,
bool deviceSupportsCacheDir,
std::map<std::string, std::string>& parsedConfig) const {
if (parsedConfig.count(CONFIG_KEY(CACHE_DIR))) {
CoreConfig::CacheConfig tempConfig;
CoreConfig::fillConfig(tempConfig, parsedConfig.at(CONFIG_KEY(CACHE_DIR)));
if (!deviceSupportsCacheDir) {
parsedConfig.erase(CONFIG_KEY(CACHE_DIR));
}
return tempConfig;
} else {
std::lock_guard<std::mutex> lock(_cacheConfigMutex);
if (_cacheConfigPerDevice.count(device_name) > 0) {
return _cacheConfigPerDevice.at(device_name);
} else {
return _cacheConfig;
}
}
}
CacheConfig getCacheConfigForDevice(const std::string& device_name) const {
std::lock_guard<std::mutex> lock(_cacheConfigMutex);
if (_cacheConfigPerDevice.count(device_name) > 0) {
return _cacheConfigPerDevice.at(device_name);
} else {
return _cacheConfig;
}
}
private:
static void fillConfig(CacheConfig& config, const std::string& dir) {
config._cacheDir = dir;
if (!dir.empty()) {
FileUtils::createDirectoryRecursive(dir);
config._cacheManager = std::make_shared<ie::FileStorageCacheManager>(dir);
} else {
config._cacheManager = nullptr;
}
}
private:
mutable std::mutex _cacheConfigMutex;
CacheConfig _cacheConfig;
std::map<std::string, CacheConfig> _cacheConfigPerDevice;
};
struct CacheContent {
explicit CacheContent(const std::shared_ptr<ie::ICacheManager>& cache_manager,
const std::string model_path = {})
: cacheManager(cache_manager),
modelPath(model_path) {}
std::shared_ptr<ie::ICacheManager> cacheManager;
std::string blobId = {};
std::string modelPath = {};
};
// Core settings (cache config, etc)
@@ -246,46 +294,42 @@ class CoreImpl : public ie::ICore, public std::enable_shared_from_this<ie::ICore
ov::InferencePlugin& plugin,
const std::map<std::string, std::string>& parsedConfig,
const ie::RemoteContext::Ptr& context,
const std::string& blobID,
const std::string& modelPath = std::string(),
const CacheContent& cacheContent,
bool forceDisableCache = false) {
OV_ITT_SCOPED_TASK(ov::itt::domains::IE, "CoreImpl::compile_model_impl");
ov::SoPtr<ie::IExecutableNetworkInternal> execNetwork;
execNetwork = context ? plugin.compile_model(network, context, parsedConfig)
: plugin.compile_model(network, parsedConfig);
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
if (!forceDisableCache && cacheManager && DeviceSupportsImportExport(plugin)) {
if (!forceDisableCache && cacheContent.cacheManager && DeviceSupportsImportExport(plugin)) {
try {
// need to export network for further import from "cache"
OV_ITT_SCOPE(FIRST_INFERENCE, ie::itt::domains::IE_LT, "Core::LoadNetwork::Export");
cacheManager->writeCacheEntry(blobID, [&](std::ostream& networkStream) {
cacheContent.cacheManager->writeCacheEntry(cacheContent.blobId, [&](std::ostream& networkStream) {
networkStream << ie::CompiledBlobHeader(
ie::GetInferenceEngineVersion()->buildNumber,
ie::NetworkCompilationContext::calculateFileInfo(modelPath));
ie::NetworkCompilationContext::calculateFileInfo(cacheContent.modelPath));
execNetwork->Export(networkStream);
});
} catch (...) {
cacheManager->removeCacheEntry(blobID);
cacheContent.cacheManager->removeCacheEntry(cacheContent.blobId);
throw;
}
}
return execNetwork;
}
ov::SoPtr<ie::IExecutableNetworkInternal> LoadNetworkFromCache(
const std::shared_ptr<ie::ICacheManager>& cacheManager,
const std::string& blobId,
static ov::SoPtr<ie::IExecutableNetworkInternal> LoadNetworkFromCache(
const CacheContent& cacheContent,
ov::InferencePlugin& plugin,
const std::map<std::string, std::string>& config,
const std::shared_ptr<ie::RemoteContext>& context,
bool& networkIsImported,
const std::string& modelPath = std::string()) {
bool& networkIsImported) {
ov::SoPtr<ie::IExecutableNetworkInternal> execNetwork;
struct HeaderException {};
OPENVINO_ASSERT(cacheManager != nullptr);
OPENVINO_ASSERT(cacheContent.cacheManager != nullptr);
try {
cacheManager->readCacheEntry(blobId, [&](std::istream& networkStream) {
cacheContent.cacheManager->readCacheEntry(cacheContent.blobId, [&](std::istream& networkStream) {
OV_ITT_SCOPE(FIRST_INFERENCE,
ie::itt::domains::IE_LT,
"Core::LoadNetworkFromCache::ReadStreamAndImport");
@@ -296,7 +340,8 @@ class CoreImpl : public ie::ICore, public std::enable_shared_from_this<ie::ICore
// Build number mismatch, don't use this cache
throw ie::NetworkNotRead("Version does not match");
}
if (header.getFileInfo() != ie::NetworkCompilationContext::calculateFileInfo(modelPath)) {
if (header.getFileInfo() !=
ie::NetworkCompilationContext::calculateFileInfo(cacheContent.modelPath)) {
// Original file is changed, don't use cache
throw ie::NetworkNotRead("Original model file is changed");
}
@@ -310,10 +355,10 @@ class CoreImpl : public ie::ICore, public std::enable_shared_from_this<ie::ICore
});
} catch (const HeaderException&) {
// For these exceptions just remove old cache and set that import didn't work
cacheManager->removeCacheEntry(blobId);
cacheContent.cacheManager->removeCacheEntry(cacheContent.blobId);
networkIsImported = false;
} catch (...) {
cacheManager->removeCacheEntry(blobId);
cacheContent.cacheManager->removeCacheEntry(cacheContent.blobId);
networkIsImported = false;
// TODO: temporary disabled by #54335. In future don't throw only for new 'blob_outdated' exception
// throw;
@@ -538,20 +583,23 @@ public:
auto plugin = GetCPPPluginByName(parsed._deviceName);
ov::SoPtr<ie::IExecutableNetworkInternal> res;
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
auto cacheManager =
coreConfig.getCacheConfigForDevice(parsed._deviceName, DeviceSupportsCacheDir(plugin), parsed._config)
._cacheManager;
auto cacheContent = CacheContent{cacheManager};
if (cacheManager && DeviceSupportsImportExport(plugin)) {
auto hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
cacheContent.blobId = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
bool loadedFromCache = false;
auto lock = cacheGuard.getHashLock(hash);
res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, context, loadedFromCache);
auto lock = cacheGuard.getHashLock(cacheContent.blobId);
res = LoadNetworkFromCache(cacheContent, plugin, parsed._config, context, loadedFromCache);
if (!loadedFromCache) {
res = compile_model_impl(network, plugin, parsed._config, context, hash);
res = compile_model_impl(network, plugin, parsed._config, context, cacheContent);
} else {
// Temporary workaround until all plugins support caching of original model inputs
InferenceEngine::SetExeNetworkInfo(res._ptr, network.getFunction(), isNewAPI());
}
} else {
res = compile_model_impl(network, plugin, parsed._config, context, {});
res = compile_model_impl(network, plugin, parsed._config, context, cacheContent);
}
return res;
}
@@ -634,20 +682,23 @@ public:
}
auto plugin = GetCPPPluginByName(parsed._deviceName);
ov::SoPtr<ie::IExecutableNetworkInternal> res;
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
auto cacheManager =
coreConfig.getCacheConfigForDevice(parsed._deviceName, DeviceSupportsCacheDir(plugin), parsed._config)
._cacheManager;
auto cacheContent = CacheContent{cacheManager};
if (!forceDisableCache && cacheManager && DeviceSupportsImportExport(plugin)) {
auto hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
cacheContent.blobId = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
bool loadedFromCache = false;
auto lock = cacheGuard.getHashLock(hash);
res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, nullptr, loadedFromCache);
auto lock = cacheGuard.getHashLock(cacheContent.blobId);
res = LoadNetworkFromCache(cacheContent, plugin, parsed._config, nullptr, loadedFromCache);
if (!loadedFromCache) {
res = compile_model_impl(network, plugin, parsed._config, nullptr, hash, {}, forceDisableCache);
res = compile_model_impl(network, plugin, parsed._config, nullptr, cacheContent, forceDisableCache);
} else {
// Temporary workaround until all plugins support caching of original model inputs
InferenceEngine::SetExeNetworkInfo(res._ptr, network.getFunction(), isNewAPI());
}
} else {
res = compile_model_impl(network, plugin, parsed._config, nullptr, {}, {}, forceDisableCache);
res = compile_model_impl(network, plugin, parsed._config, nullptr, cacheContent, forceDisableCache);
}
return {res._ptr, res._so};
}
@@ -660,18 +711,21 @@ public:
auto parsed = parseDeviceNameIntoConfig(deviceName, config);
auto plugin = GetCPPPluginByName(parsed._deviceName);
ov::SoPtr<ie::IExecutableNetworkInternal> res;
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
auto cacheManager =
coreConfig.getCacheConfigForDevice(parsed._deviceName, DeviceSupportsCacheDir(plugin), parsed._config)
._cacheManager;
auto cacheContent = CacheContent{cacheManager, modelPath};
if (cacheManager && DeviceSupportsImportExport(plugin)) {
bool loadedFromCache = false;
auto hash = CalculateFileHash(modelPath, parsed._deviceName, plugin, parsed._config);
auto lock = cacheGuard.getHashLock(hash);
res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, nullptr, loadedFromCache, modelPath);
cacheContent.blobId = CalculateFileHash(modelPath, parsed._deviceName, plugin, parsed._config);
auto lock = cacheGuard.getHashLock(cacheContent.blobId);
res = LoadNetworkFromCache(cacheContent, plugin, parsed._config, nullptr, loadedFromCache);
if (!loadedFromCache) {
auto cnnNetwork = ReadNetwork(modelPath, std::string());
if (val) {
val(cnnNetwork);
}
res = compile_model_impl(cnnNetwork, plugin, parsed._config, nullptr, hash, modelPath);
res = compile_model_impl(cnnNetwork, plugin, parsed._config, nullptr, cacheContent);
}
} else if (cacheManager) {
// TODO: 'validation' for dynamic API doesn't work for this case, as it affects a lot of plugin API
@@ -681,7 +735,7 @@ public:
if (val) {
val(cnnNetwork);
}
res = compile_model_impl(cnnNetwork, plugin, parsed._config, nullptr, {}, modelPath);
res = compile_model_impl(cnnNetwork, plugin, parsed._config, nullptr, cacheContent);
}
return {res._ptr, res._so};
}
@@ -921,10 +975,13 @@ public:
// configuring
{
if (DeviceSupportsCacheDir(plugin)) {
auto cacheConfig = coreConfig.getCacheConfig();
auto cacheConfig = coreConfig.getCacheConfigForDevice(deviceName);
if (cacheConfig._cacheManager) {
desc.defaultConfig[CONFIG_KEY(CACHE_DIR)] = cacheConfig._cacheDir;
}
} else if (desc.defaultConfig.count(CONFIG_KEY(CACHE_DIR)) > 0) {
// Remove "CACHE_DIR" from config if it is not supported by plugin
desc.defaultConfig.erase(CONFIG_KEY(CACHE_DIR));
}
allowNotImplemented([&]() {
// Add device specific value to support device_name.device_id cases
@@ -1056,6 +1113,11 @@ public:
if (deviceName.empty()) {
coreConfig.setAndUpdate(config);
} else {
auto cache_it = config.find(CONFIG_KEY(CACHE_DIR));
if (cache_it != config.end()) {
coreConfig.setCacheForDevice(cache_it->second, clearDeviceName);
}
}
auto base_desc = pluginRegistry.find(clearDeviceName);
@@ -1085,10 +1147,13 @@ public:
allowNotImplemented([&]() {
auto configCopy = config;
if (DeviceSupportsCacheDir(plugin.second)) {
auto cacheConfig = coreConfig.getCacheConfig();
auto cacheConfig = coreConfig.getCacheConfigForDevice(deviceName);
if (cacheConfig._cacheManager) {
configCopy[CONFIG_KEY(CACHE_DIR)] = cacheConfig._cacheDir;
}
} else if (configCopy.count(CONFIG_KEY(CACHE_DIR)) > 0) {
// Remove "CACHE_DIR" from config if it is not supported by plugin
configCopy.erase(CONFIG_KEY(CACHE_DIR));
}
// Add device specific value to support device_name.device_id cases
std::vector<std::string> supportedConfigKeys =

View File

@@ -194,6 +194,8 @@ public:
CNNCallback m_cnnCallback = nullptr;
std::map<std::string, InputsDataMap> m_inputs_map;
std::map<std::string, OutputsDataMap> m_outputs_map;
using CheckConfigCb = std::function<void(const std::map<std::string, std::string> &)>;
CheckConfigCb m_checkConfigCb = nullptr;
static std::string get_mock_engine_name() {
std::string mockEngineName("mock_engine");
@@ -377,6 +379,10 @@ private:
ov::device::capabilities.name(),
ov::device::architecture.name()};
}));
ON_CALL(plugin, GetMetric(METRIC_KEY(OPTIMIZATION_CAPABILITIES), _)).
WillByDefault(Return(std::vector<std::string>()));
ON_CALL(plugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).
WillByDefault(Return(true));
@@ -401,7 +407,10 @@ private:
ON_CALL(plugin, ImportNetwork(_, _, _)).
WillByDefault(Invoke([&](std::istream &istr, const RemoteContext::Ptr&,
const std::map<std::string, std::string> &) {
const std::map<std::string, std::string> &config) {
if (m_checkConfigCb) {
m_checkConfigCb(config);
}
std::string name;
istr >> name;
char space;
@@ -411,7 +420,10 @@ private:
}));
ON_CALL(plugin, ImportNetwork(_, _)).
WillByDefault(Invoke([&](std::istream &istr, const std::map<std::string, std::string> &) {
WillByDefault(Invoke([&](std::istream &istr, const std::map<std::string, std::string> &config) {
if (m_checkConfigCb) {
m_checkConfigCb(config);
}
std::string name;
istr >> name;
char space;
@@ -422,7 +434,10 @@ private:
ON_CALL(plugin, LoadExeNetworkImpl(_, _, _)).
WillByDefault(Invoke([&](const CNNNetwork & cnn, const RemoteContext::Ptr&,
const std::map<std::string, std::string> &) {
const std::map<std::string, std::string> &config) {
if (m_checkConfigCb) {
m_checkConfigCb(config);
}
std::lock_guard<std::mutex> lock(mock_creation_mutex);
std::string name = cnn.getFunction()->get_friendly_name();
m_inputs_map[name] = cnn.getInputsInfo();
@@ -440,7 +455,10 @@ private:
ON_CALL(plugin, LoadExeNetworkImpl(_, _)).
WillByDefault(Invoke([&](const CNNNetwork & cnn,
const std::map<std::string, std::string> &) {
const std::map<std::string, std::string> &config) {
if (m_checkConfigCb) {
m_checkConfigCb(config);
}
std::string name = cnn.getFunction()->get_friendly_name();
std::lock_guard<std::mutex> lock(mock_creation_mutex);
m_inputs_map[name] = cnn.getInputsInfo();
@@ -517,6 +535,44 @@ TEST_P(CachingTest, TestLoad) {
}
}
/// \brief Verifies that ie.SetConfig({{"CACHE_DIR", <dir>}}, "deviceName"}}); enables caching for one device
TEST_P(CachingTest, TestLoad_by_device_name) {
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(0);
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
EXPECT_CALL(net, Export(_)).Times(1);
});
testLoad([&](Core &ie) {
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}, "mock");
m_testFunction(ie);
});
EXPECT_EQ(networks.size(), 1);
}
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(m_remoteContext ? 1 : 0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(!m_remoteContext ? 1 : 0);
for (auto& net : networks) {
EXPECT_CALL(*net, Export(_)).Times(0); // No more 'Export' for existing networks
}
testLoad([&](Core &ie) {
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}, "mock");
m_testFunction(ie);
});
EXPECT_EQ(networks.size(), 1);
}
}
TEST_P(CachingTest, TestLoadCustomImportExport) {
const char customData[] = {1, 2, 3, 4, 5};
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).Times(AnyNumber());
@@ -626,6 +682,45 @@ TEST_P(CachingTest, TestChangeLoadConfig) {
}
}
/// \brief Verifies that ie.LoadNetwork(cnn, "deviceName", {{"CACHE_DIR", <dir>>}}) works
TEST_P(CachingTest, TestChangeLoadConfig_With_Cache_Dir_inline) {
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
ON_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).
WillByDefault(Invoke([&](const std::string &, const std::map<std::string, Parameter> &) {
return std::vector<std::string>{};
}));
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(0);
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
EXPECT_CALL(net, Export(_)).Times(1);
});
testLoad([&](Core &ie) {
m_testFunctionWithCfg(ie, {{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
});
}
m_post_mock_net_callbacks.pop_back();
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(m_remoteContext ? 1 : 0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(!m_remoteContext ? 1 : 0);
for (auto& net : networks) {
EXPECT_CALL(*net, Export(_)).Times(0); // No more 'Export' for existing networks
}
testLoad([&](Core &ie) {
m_testFunctionWithCfg(ie, {{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
});
EXPECT_EQ(networks.size(), 1);
}
}
TEST_P(CachingTest, TestNoCacheEnabled) {
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _)).Times(AnyNumber());
@@ -697,6 +792,32 @@ TEST_P(CachingTest, TestNoCacheMetricSupported) {
}
}
/// \brief If device doesn't support 'cache_dir' or 'import_export' - setting cache_dir is ignored
TEST_P(CachingTest, TestNoCacheMetricSupported_by_device_name) {
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _)).Times(AnyNumber())
.WillRepeatedly(Return(std::vector<ov::PropertyName>{}));
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
.Times(AnyNumber()).WillRepeatedly(Return(std::vector<std::string>{}));
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(0);
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(0);
EXPECT_CALL(*mockPlugin, GetMetric(ov::device::capabilities.name(), _)).Times(0);
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 1 : 0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(0);
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
EXPECT_CALL(net, Export(_)).Times(0);
});
testLoad([&](Core &ie) {
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}, "mock");
m_testFunction(ie);
});
}
}
TEST_P(CachingTest, TestNoCacheMetric_hasCacheDirConfig) {
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
.Times(AnyNumber()).WillRepeatedly(
@@ -723,6 +844,62 @@ TEST_P(CachingTest, TestNoCacheMetric_hasCacheDirConfig) {
}
}
/// \brief If device supports 'cache_dir' or 'import_export' - setting cache_dir is passed to plugin on ie.LoadNetwork
TEST_P(CachingTest, TestNoCacheMetric_hasCacheDirConfig_inline) {
m_checkConfigCb = [](const std::map<std::string, std::string>& config) {
EXPECT_NE(config.count(CONFIG_KEY(CACHE_DIR)), 0);
};
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
.Times(AnyNumber()).WillRepeatedly(
Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<std::string>{CONFIG_KEY(CACHE_DIR)}));
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _))
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<ov::PropertyName>{
ov::supported_properties.name(), ov::cache_dir.name()}));
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 1 : 0);
ASSERT_NO_THROW(
testLoad([&](Core &ie) {
m_testFunctionWithCfg(ie, {{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
}));
}
}
/// \brief ie.SetConfig(<cachedir>, "deviceName") is propagated to plugin's SetConfig if device supports CACHE_DIR
TEST_P(CachingTest, TestNoCacheMetric_hasCacheDirConfig_by_device_name) {
m_checkConfigCb = [](const std::map<std::string, std::string>& config) {
// Shall be '0' as appropriate 'cache_dir' is expected in SetConfig, not in Load/Import network
EXPECT_EQ(config.count(CONFIG_KEY(CACHE_DIR)), 0);
};
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
.Times(AnyNumber()).WillRepeatedly(
Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<std::string>{CONFIG_KEY(CACHE_DIR)}));
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _))
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<ov::PropertyName>{
ov::supported_properties.name(), ov::cache_dir.name()}));
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AtLeast(1)).WillRepeatedly(
Invoke([](const std::map<std::string, std::string>& config) {
ASSERT_GT(config.count(CONFIG_KEY(CACHE_DIR)), 0);
}));
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 1 : 0);
ASSERT_NO_THROW(
testLoad([&](Core &ie) {
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}, "mock");
m_testFunction(ie);
}));
}
}
TEST_P(CachingTest, TestCacheEnabled_noConfig) {
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
.Times(AnyNumber()).WillRepeatedly(
@@ -750,6 +927,9 @@ TEST_P(CachingTest, TestCacheEnabled_noConfig) {
TEST_P(CachingTest, TestNoCacheMetric_configThrow) {
m_checkConfigCb = [](const std::map<std::string, std::string>& config) {
EXPECT_NE(config.count(CONFIG_KEY(CACHE_DIR)), 0);
};
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
.Times(AnyNumber()).WillRepeatedly(
Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
@@ -833,6 +1013,198 @@ TEST_P(CachingTest, TestLoadChangeCacheDir) {
}
}
/// \brief Change CACHE_DIR during working with same 'Core' object. Verifies that new dir is used for caching
TEST_P(CachingTest, TestLoadChangeCacheDirOneCore) {
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AnyNumber()).WillRepeatedly(
Invoke([](const std::map<std::string, std::string>& config) {
ASSERT_EQ(config.count(CONFIG_KEY(CACHE_DIR)), 0);
}));
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 2 : 0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 2 : 0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(0);
testLoad([&](Core &ie) {
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
EXPECT_CALL(net, Export(_)).Times(1);
});
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
m_testFunction(ie);
std::string newCacheDir = m_cacheDir + "2";
m_post_mock_net_callbacks.pop_back();
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
EXPECT_CALL(net, Export(_)).Times(1);
});
MkDirGuard dir(newCacheDir);
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), newCacheDir}});
m_testFunction(ie);
});
}
}
/// \brief Change CACHE_DIR during working with same 'Core' object
/// Initially set for 'device', then is overwritten with global 'cache_dir' for all devices
TEST_P(CachingTest, TestLoadChangeCacheDirOneCore_overwrite_device_dir) {
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AnyNumber()).WillRepeatedly(
Invoke([](const std::map<std::string, std::string>& config) {
ASSERT_EQ(config.count(CONFIG_KEY(CACHE_DIR)), 0);
}));
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 2 : 0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 2 : 0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(0);
testLoad([&](Core &ie) {
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
EXPECT_CALL(net, Export(_)).Times(1);
});
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}, "mock");
m_testFunction(ie);
std::string newCacheDir = m_cacheDir + "2";
m_post_mock_net_callbacks.pop_back();
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
EXPECT_CALL(net, Export(_)).Times(1);
});
MkDirGuard dir(newCacheDir);
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), newCacheDir}});
m_testFunction(ie);
});
}
}
/// \brief Change CACHE_DIR during working with same 'Core' object for device which supports 'CACHE_DIR' config, not import_export
/// Expectation is that SetConfig for plugin will be called 2 times - with appropriate cache_dir values
TEST_P(CachingTest, TestLoadChangeCacheDirOneCore_SupportsCacheDir_NoImportExport) {
m_checkConfigCb = [](const std::map<std::string, std::string>& config) {
EXPECT_EQ(config.count(CONFIG_KEY(CACHE_DIR)), 0);
};
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<std::string>{CONFIG_KEY(CACHE_DIR)}));
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _))
.Times(AnyNumber()).WillRepeatedly(Return(std::vector<ov::PropertyName>{
ov::supported_properties.name(), ov::cache_dir.name()}));
EXPECT_CALL(*mockPlugin, GetMetric(ov::device::capabilities.name(), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
.Times(AnyNumber()).WillRepeatedly(
Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
EXPECT_CALL(*mockPlugin, GetMetric(ov::device::capabilities.name(), _)).Times(AnyNumber()).
WillRepeatedly(Return(decltype(ov::device::capabilities)::value_type{}));
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
std::string set_cache_dir = {};
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AtLeast(2)).WillRepeatedly(
Invoke([&](const std::map<std::string, std::string>& config) {
ASSERT_NE(config.count(CONFIG_KEY(CACHE_DIR)), 0);
set_cache_dir = config.at(CONFIG_KEY(CACHE_DIR));
}));
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 2 : 0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 2 : 0);
EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 2 : 0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(0);
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
EXPECT_CALL(net, Export(_)).Times(0);
});
testLoad([&](Core &ie) {
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
m_testFunction(ie);
EXPECT_EQ(set_cache_dir, m_cacheDir);
std::string new_cache_dir = m_cacheDir + "2";
MkDirGuard dir(new_cache_dir);
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), new_cache_dir}});
m_testFunction(ie);
EXPECT_EQ(set_cache_dir, new_cache_dir);
});
}
}
/// \brief Change CACHE_DIR per device during working with same 'Core' object - expected that new cache dir is used
TEST_P(CachingTest, TestLoadChangeCacheDirOneCore_by_device_name) {
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AnyNumber()).WillRepeatedly(
Invoke([](const std::map<std::string, std::string>& config) {
ASSERT_EQ(config.count(CONFIG_KEY(CACHE_DIR)), 0);
}));
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 2 : 0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 2 : 0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(0);
testLoad([&](Core &ie) {
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
EXPECT_CALL(net, Export(_)).Times(1);
});
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}, "mock");
m_testFunction(ie);
m_post_mock_net_callbacks.pop_back();
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
EXPECT_CALL(net, Export(_)).Times(1);
});
std::string newCacheDir = m_cacheDir + "2";
MkDirGuard dir(newCacheDir);
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), newCacheDir}}, "mock");
m_testFunction(ie);
});
}
}
/// \brief Change CACHE_DIR per device during working with same 'Core' object - device supports CACHE_DIR
/// Verifies that no 'export' is called and cache_dir is propagated to set_config
TEST_P(CachingTest, TestLoadChangeCacheDirOneCore_by_device_name_supports_cache_dir) {
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<std::string>{CONFIG_KEY(CACHE_DIR)}));
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _)).Times(AnyNumber())
.WillRepeatedly(Return(std::vector<ov::PropertyName>{
ov::cache_dir.name()}));
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber())
.WillRepeatedly(Return(false));
EXPECT_CALL(*mockPlugin, GetMetric(ov::device::capabilities.name(), _)).Times(AnyNumber()).
WillRepeatedly(Return(decltype(ov::device::capabilities)::value_type{}));
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AtLeast(2)).WillRepeatedly(
Invoke([](const std::map<std::string, std::string>& config) {
ASSERT_GT(config.count(CONFIG_KEY(CACHE_DIR)), 0);
}));
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 2 : 0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 2 : 0);
EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 2 : 0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(0);
testLoad([&](Core &ie) {
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
EXPECT_CALL(net, Export(_)).Times(0);
});
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}, "mock");
m_testFunction(ie);
m_post_mock_net_callbacks.pop_back();
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
EXPECT_CALL(net, Export(_)).Times(0);
});
std::string newCacheDir = m_cacheDir + "2";
MkDirGuard dir(newCacheDir);
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), newCacheDir}}, "mock");
m_testFunction(ie);
});
}
}
TEST_P(CachingTest, TestClearCacheDir) {
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _)).Times(AnyNumber());