[Model Caching] Enabling per-device cache dir (#10774)
* Initial commit 10 more caching tests * Fix clang-format * Added brief explanations to each test * Fix review comments
This commit is contained in:
@@ -90,7 +90,7 @@ public:
|
||||
* @brief Constructor
|
||||
*
|
||||
*/
|
||||
FileStorageCacheManager(std::string&& cachePath) : m_cachePath(std::move(cachePath)) {}
|
||||
FileStorageCacheManager(std::string cachePath) : m_cachePath(std::move(cachePath)) {}
|
||||
|
||||
/**
|
||||
* @brief Destructor
|
||||
|
||||
@@ -150,27 +150,75 @@ class CoreImpl : public ie::ICore, public std::enable_shared_from_this<ie::ICore
|
||||
auto it = config.find(CONFIG_KEY(CACHE_DIR));
|
||||
if (it != config.end()) {
|
||||
std::lock_guard<std::mutex> lock(_cacheConfigMutex);
|
||||
_cacheConfig._cacheDir = it->second;
|
||||
if (!it->second.empty()) {
|
||||
FileUtils::createDirectoryRecursive(it->second);
|
||||
_cacheConfig._cacheManager = std::make_shared<ie::FileStorageCacheManager>(std::move(it->second));
|
||||
} else {
|
||||
_cacheConfig._cacheManager = nullptr;
|
||||
fillConfig(_cacheConfig, it->second);
|
||||
for (auto& deviceCfg : _cacheConfigPerDevice) {
|
||||
fillConfig(deviceCfg.second, it->second);
|
||||
}
|
||||
|
||||
config.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
// Creating thread-safe copy of config including shared_ptr to ICacheManager
|
||||
CacheConfig getCacheConfig() const {
|
||||
void setCacheForDevice(const std::string& dir, const std::string& name) {
|
||||
std::lock_guard<std::mutex> lock(_cacheConfigMutex);
|
||||
return _cacheConfig;
|
||||
fillConfig(_cacheConfigPerDevice[name], dir);
|
||||
}
|
||||
|
||||
// Creating thread-safe copy of config including shared_ptr to ICacheManager
|
||||
// Passing empty or not-existing name will return global cache config
|
||||
CacheConfig getCacheConfigForDevice(const std::string& device_name,
|
||||
bool deviceSupportsCacheDir,
|
||||
std::map<std::string, std::string>& parsedConfig) const {
|
||||
if (parsedConfig.count(CONFIG_KEY(CACHE_DIR))) {
|
||||
CoreConfig::CacheConfig tempConfig;
|
||||
CoreConfig::fillConfig(tempConfig, parsedConfig.at(CONFIG_KEY(CACHE_DIR)));
|
||||
if (!deviceSupportsCacheDir) {
|
||||
parsedConfig.erase(CONFIG_KEY(CACHE_DIR));
|
||||
}
|
||||
return tempConfig;
|
||||
} else {
|
||||
std::lock_guard<std::mutex> lock(_cacheConfigMutex);
|
||||
if (_cacheConfigPerDevice.count(device_name) > 0) {
|
||||
return _cacheConfigPerDevice.at(device_name);
|
||||
} else {
|
||||
return _cacheConfig;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CacheConfig getCacheConfigForDevice(const std::string& device_name) const {
|
||||
std::lock_guard<std::mutex> lock(_cacheConfigMutex);
|
||||
if (_cacheConfigPerDevice.count(device_name) > 0) {
|
||||
return _cacheConfigPerDevice.at(device_name);
|
||||
} else {
|
||||
return _cacheConfig;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static void fillConfig(CacheConfig& config, const std::string& dir) {
|
||||
config._cacheDir = dir;
|
||||
if (!dir.empty()) {
|
||||
FileUtils::createDirectoryRecursive(dir);
|
||||
config._cacheManager = std::make_shared<ie::FileStorageCacheManager>(dir);
|
||||
} else {
|
||||
config._cacheManager = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
mutable std::mutex _cacheConfigMutex;
|
||||
CacheConfig _cacheConfig;
|
||||
std::map<std::string, CacheConfig> _cacheConfigPerDevice;
|
||||
};
|
||||
|
||||
struct CacheContent {
|
||||
explicit CacheContent(const std::shared_ptr<ie::ICacheManager>& cache_manager,
|
||||
const std::string model_path = {})
|
||||
: cacheManager(cache_manager),
|
||||
modelPath(model_path) {}
|
||||
std::shared_ptr<ie::ICacheManager> cacheManager;
|
||||
std::string blobId = {};
|
||||
std::string modelPath = {};
|
||||
};
|
||||
|
||||
// Core settings (cache config, etc)
|
||||
@@ -246,46 +294,42 @@ class CoreImpl : public ie::ICore, public std::enable_shared_from_this<ie::ICore
|
||||
ov::InferencePlugin& plugin,
|
||||
const std::map<std::string, std::string>& parsedConfig,
|
||||
const ie::RemoteContext::Ptr& context,
|
||||
const std::string& blobID,
|
||||
const std::string& modelPath = std::string(),
|
||||
const CacheContent& cacheContent,
|
||||
bool forceDisableCache = false) {
|
||||
OV_ITT_SCOPED_TASK(ov::itt::domains::IE, "CoreImpl::compile_model_impl");
|
||||
ov::SoPtr<ie::IExecutableNetworkInternal> execNetwork;
|
||||
execNetwork = context ? plugin.compile_model(network, context, parsedConfig)
|
||||
: plugin.compile_model(network, parsedConfig);
|
||||
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
|
||||
if (!forceDisableCache && cacheManager && DeviceSupportsImportExport(plugin)) {
|
||||
if (!forceDisableCache && cacheContent.cacheManager && DeviceSupportsImportExport(plugin)) {
|
||||
try {
|
||||
// need to export network for further import from "cache"
|
||||
OV_ITT_SCOPE(FIRST_INFERENCE, ie::itt::domains::IE_LT, "Core::LoadNetwork::Export");
|
||||
cacheManager->writeCacheEntry(blobID, [&](std::ostream& networkStream) {
|
||||
cacheContent.cacheManager->writeCacheEntry(cacheContent.blobId, [&](std::ostream& networkStream) {
|
||||
networkStream << ie::CompiledBlobHeader(
|
||||
ie::GetInferenceEngineVersion()->buildNumber,
|
||||
ie::NetworkCompilationContext::calculateFileInfo(modelPath));
|
||||
ie::NetworkCompilationContext::calculateFileInfo(cacheContent.modelPath));
|
||||
execNetwork->Export(networkStream);
|
||||
});
|
||||
} catch (...) {
|
||||
cacheManager->removeCacheEntry(blobID);
|
||||
cacheContent.cacheManager->removeCacheEntry(cacheContent.blobId);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
return execNetwork;
|
||||
}
|
||||
|
||||
ov::SoPtr<ie::IExecutableNetworkInternal> LoadNetworkFromCache(
|
||||
const std::shared_ptr<ie::ICacheManager>& cacheManager,
|
||||
const std::string& blobId,
|
||||
static ov::SoPtr<ie::IExecutableNetworkInternal> LoadNetworkFromCache(
|
||||
const CacheContent& cacheContent,
|
||||
ov::InferencePlugin& plugin,
|
||||
const std::map<std::string, std::string>& config,
|
||||
const std::shared_ptr<ie::RemoteContext>& context,
|
||||
bool& networkIsImported,
|
||||
const std::string& modelPath = std::string()) {
|
||||
bool& networkIsImported) {
|
||||
ov::SoPtr<ie::IExecutableNetworkInternal> execNetwork;
|
||||
struct HeaderException {};
|
||||
|
||||
OPENVINO_ASSERT(cacheManager != nullptr);
|
||||
OPENVINO_ASSERT(cacheContent.cacheManager != nullptr);
|
||||
try {
|
||||
cacheManager->readCacheEntry(blobId, [&](std::istream& networkStream) {
|
||||
cacheContent.cacheManager->readCacheEntry(cacheContent.blobId, [&](std::istream& networkStream) {
|
||||
OV_ITT_SCOPE(FIRST_INFERENCE,
|
||||
ie::itt::domains::IE_LT,
|
||||
"Core::LoadNetworkFromCache::ReadStreamAndImport");
|
||||
@@ -296,7 +340,8 @@ class CoreImpl : public ie::ICore, public std::enable_shared_from_this<ie::ICore
|
||||
// Build number mismatch, don't use this cache
|
||||
throw ie::NetworkNotRead("Version does not match");
|
||||
}
|
||||
if (header.getFileInfo() != ie::NetworkCompilationContext::calculateFileInfo(modelPath)) {
|
||||
if (header.getFileInfo() !=
|
||||
ie::NetworkCompilationContext::calculateFileInfo(cacheContent.modelPath)) {
|
||||
// Original file is changed, don't use cache
|
||||
throw ie::NetworkNotRead("Original model file is changed");
|
||||
}
|
||||
@@ -310,10 +355,10 @@ class CoreImpl : public ie::ICore, public std::enable_shared_from_this<ie::ICore
|
||||
});
|
||||
} catch (const HeaderException&) {
|
||||
// For these exceptions just remove old cache and set that import didn't work
|
||||
cacheManager->removeCacheEntry(blobId);
|
||||
cacheContent.cacheManager->removeCacheEntry(cacheContent.blobId);
|
||||
networkIsImported = false;
|
||||
} catch (...) {
|
||||
cacheManager->removeCacheEntry(blobId);
|
||||
cacheContent.cacheManager->removeCacheEntry(cacheContent.blobId);
|
||||
networkIsImported = false;
|
||||
// TODO: temporary disabled by #54335. In future don't throw only for new 'blob_outdated' exception
|
||||
// throw;
|
||||
@@ -538,20 +583,23 @@ public:
|
||||
|
||||
auto plugin = GetCPPPluginByName(parsed._deviceName);
|
||||
ov::SoPtr<ie::IExecutableNetworkInternal> res;
|
||||
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
|
||||
auto cacheManager =
|
||||
coreConfig.getCacheConfigForDevice(parsed._deviceName, DeviceSupportsCacheDir(plugin), parsed._config)
|
||||
._cacheManager;
|
||||
auto cacheContent = CacheContent{cacheManager};
|
||||
if (cacheManager && DeviceSupportsImportExport(plugin)) {
|
||||
auto hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
|
||||
cacheContent.blobId = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
|
||||
bool loadedFromCache = false;
|
||||
auto lock = cacheGuard.getHashLock(hash);
|
||||
res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, context, loadedFromCache);
|
||||
auto lock = cacheGuard.getHashLock(cacheContent.blobId);
|
||||
res = LoadNetworkFromCache(cacheContent, plugin, parsed._config, context, loadedFromCache);
|
||||
if (!loadedFromCache) {
|
||||
res = compile_model_impl(network, plugin, parsed._config, context, hash);
|
||||
res = compile_model_impl(network, plugin, parsed._config, context, cacheContent);
|
||||
} else {
|
||||
// Temporary workaround until all plugins support caching of original model inputs
|
||||
InferenceEngine::SetExeNetworkInfo(res._ptr, network.getFunction(), isNewAPI());
|
||||
}
|
||||
} else {
|
||||
res = compile_model_impl(network, plugin, parsed._config, context, {});
|
||||
res = compile_model_impl(network, plugin, parsed._config, context, cacheContent);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@@ -634,20 +682,23 @@ public:
|
||||
}
|
||||
auto plugin = GetCPPPluginByName(parsed._deviceName);
|
||||
ov::SoPtr<ie::IExecutableNetworkInternal> res;
|
||||
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
|
||||
auto cacheManager =
|
||||
coreConfig.getCacheConfigForDevice(parsed._deviceName, DeviceSupportsCacheDir(plugin), parsed._config)
|
||||
._cacheManager;
|
||||
auto cacheContent = CacheContent{cacheManager};
|
||||
if (!forceDisableCache && cacheManager && DeviceSupportsImportExport(plugin)) {
|
||||
auto hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
|
||||
cacheContent.blobId = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
|
||||
bool loadedFromCache = false;
|
||||
auto lock = cacheGuard.getHashLock(hash);
|
||||
res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, nullptr, loadedFromCache);
|
||||
auto lock = cacheGuard.getHashLock(cacheContent.blobId);
|
||||
res = LoadNetworkFromCache(cacheContent, plugin, parsed._config, nullptr, loadedFromCache);
|
||||
if (!loadedFromCache) {
|
||||
res = compile_model_impl(network, plugin, parsed._config, nullptr, hash, {}, forceDisableCache);
|
||||
res = compile_model_impl(network, plugin, parsed._config, nullptr, cacheContent, forceDisableCache);
|
||||
} else {
|
||||
// Temporary workaround until all plugins support caching of original model inputs
|
||||
InferenceEngine::SetExeNetworkInfo(res._ptr, network.getFunction(), isNewAPI());
|
||||
}
|
||||
} else {
|
||||
res = compile_model_impl(network, plugin, parsed._config, nullptr, {}, {}, forceDisableCache);
|
||||
res = compile_model_impl(network, plugin, parsed._config, nullptr, cacheContent, forceDisableCache);
|
||||
}
|
||||
return {res._ptr, res._so};
|
||||
}
|
||||
@@ -660,18 +711,21 @@ public:
|
||||
auto parsed = parseDeviceNameIntoConfig(deviceName, config);
|
||||
auto plugin = GetCPPPluginByName(parsed._deviceName);
|
||||
ov::SoPtr<ie::IExecutableNetworkInternal> res;
|
||||
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
|
||||
auto cacheManager =
|
||||
coreConfig.getCacheConfigForDevice(parsed._deviceName, DeviceSupportsCacheDir(plugin), parsed._config)
|
||||
._cacheManager;
|
||||
auto cacheContent = CacheContent{cacheManager, modelPath};
|
||||
if (cacheManager && DeviceSupportsImportExport(plugin)) {
|
||||
bool loadedFromCache = false;
|
||||
auto hash = CalculateFileHash(modelPath, parsed._deviceName, plugin, parsed._config);
|
||||
auto lock = cacheGuard.getHashLock(hash);
|
||||
res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, nullptr, loadedFromCache, modelPath);
|
||||
cacheContent.blobId = CalculateFileHash(modelPath, parsed._deviceName, plugin, parsed._config);
|
||||
auto lock = cacheGuard.getHashLock(cacheContent.blobId);
|
||||
res = LoadNetworkFromCache(cacheContent, plugin, parsed._config, nullptr, loadedFromCache);
|
||||
if (!loadedFromCache) {
|
||||
auto cnnNetwork = ReadNetwork(modelPath, std::string());
|
||||
if (val) {
|
||||
val(cnnNetwork);
|
||||
}
|
||||
res = compile_model_impl(cnnNetwork, plugin, parsed._config, nullptr, hash, modelPath);
|
||||
res = compile_model_impl(cnnNetwork, plugin, parsed._config, nullptr, cacheContent);
|
||||
}
|
||||
} else if (cacheManager) {
|
||||
// TODO: 'validation' for dynamic API doesn't work for this case, as it affects a lot of plugin API
|
||||
@@ -681,7 +735,7 @@ public:
|
||||
if (val) {
|
||||
val(cnnNetwork);
|
||||
}
|
||||
res = compile_model_impl(cnnNetwork, plugin, parsed._config, nullptr, {}, modelPath);
|
||||
res = compile_model_impl(cnnNetwork, plugin, parsed._config, nullptr, cacheContent);
|
||||
}
|
||||
return {res._ptr, res._so};
|
||||
}
|
||||
@@ -921,10 +975,13 @@ public:
|
||||
// configuring
|
||||
{
|
||||
if (DeviceSupportsCacheDir(plugin)) {
|
||||
auto cacheConfig = coreConfig.getCacheConfig();
|
||||
auto cacheConfig = coreConfig.getCacheConfigForDevice(deviceName);
|
||||
if (cacheConfig._cacheManager) {
|
||||
desc.defaultConfig[CONFIG_KEY(CACHE_DIR)] = cacheConfig._cacheDir;
|
||||
}
|
||||
} else if (desc.defaultConfig.count(CONFIG_KEY(CACHE_DIR)) > 0) {
|
||||
// Remove "CACHE_DIR" from config if it is not supported by plugin
|
||||
desc.defaultConfig.erase(CONFIG_KEY(CACHE_DIR));
|
||||
}
|
||||
allowNotImplemented([&]() {
|
||||
// Add device specific value to support device_name.device_id cases
|
||||
@@ -1056,6 +1113,11 @@ public:
|
||||
|
||||
if (deviceName.empty()) {
|
||||
coreConfig.setAndUpdate(config);
|
||||
} else {
|
||||
auto cache_it = config.find(CONFIG_KEY(CACHE_DIR));
|
||||
if (cache_it != config.end()) {
|
||||
coreConfig.setCacheForDevice(cache_it->second, clearDeviceName);
|
||||
}
|
||||
}
|
||||
|
||||
auto base_desc = pluginRegistry.find(clearDeviceName);
|
||||
@@ -1085,10 +1147,13 @@ public:
|
||||
allowNotImplemented([&]() {
|
||||
auto configCopy = config;
|
||||
if (DeviceSupportsCacheDir(plugin.second)) {
|
||||
auto cacheConfig = coreConfig.getCacheConfig();
|
||||
auto cacheConfig = coreConfig.getCacheConfigForDevice(deviceName);
|
||||
if (cacheConfig._cacheManager) {
|
||||
configCopy[CONFIG_KEY(CACHE_DIR)] = cacheConfig._cacheDir;
|
||||
}
|
||||
} else if (configCopy.count(CONFIG_KEY(CACHE_DIR)) > 0) {
|
||||
// Remove "CACHE_DIR" from config if it is not supported by plugin
|
||||
configCopy.erase(CONFIG_KEY(CACHE_DIR));
|
||||
}
|
||||
// Add device specific value to support device_name.device_id cases
|
||||
std::vector<std::string> supportedConfigKeys =
|
||||
|
||||
@@ -194,6 +194,8 @@ public:
|
||||
CNNCallback m_cnnCallback = nullptr;
|
||||
std::map<std::string, InputsDataMap> m_inputs_map;
|
||||
std::map<std::string, OutputsDataMap> m_outputs_map;
|
||||
using CheckConfigCb = std::function<void(const std::map<std::string, std::string> &)>;
|
||||
CheckConfigCb m_checkConfigCb = nullptr;
|
||||
|
||||
static std::string get_mock_engine_name() {
|
||||
std::string mockEngineName("mock_engine");
|
||||
@@ -377,6 +379,10 @@ private:
|
||||
ov::device::capabilities.name(),
|
||||
ov::device::architecture.name()};
|
||||
}));
|
||||
|
||||
ON_CALL(plugin, GetMetric(METRIC_KEY(OPTIMIZATION_CAPABILITIES), _)).
|
||||
WillByDefault(Return(std::vector<std::string>()));
|
||||
|
||||
ON_CALL(plugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).
|
||||
WillByDefault(Return(true));
|
||||
|
||||
@@ -401,7 +407,10 @@ private:
|
||||
|
||||
ON_CALL(plugin, ImportNetwork(_, _, _)).
|
||||
WillByDefault(Invoke([&](std::istream &istr, const RemoteContext::Ptr&,
|
||||
const std::map<std::string, std::string> &) {
|
||||
const std::map<std::string, std::string> &config) {
|
||||
if (m_checkConfigCb) {
|
||||
m_checkConfigCb(config);
|
||||
}
|
||||
std::string name;
|
||||
istr >> name;
|
||||
char space;
|
||||
@@ -411,7 +420,10 @@ private:
|
||||
}));
|
||||
|
||||
ON_CALL(plugin, ImportNetwork(_, _)).
|
||||
WillByDefault(Invoke([&](std::istream &istr, const std::map<std::string, std::string> &) {
|
||||
WillByDefault(Invoke([&](std::istream &istr, const std::map<std::string, std::string> &config) {
|
||||
if (m_checkConfigCb) {
|
||||
m_checkConfigCb(config);
|
||||
}
|
||||
std::string name;
|
||||
istr >> name;
|
||||
char space;
|
||||
@@ -422,7 +434,10 @@ private:
|
||||
|
||||
ON_CALL(plugin, LoadExeNetworkImpl(_, _, _)).
|
||||
WillByDefault(Invoke([&](const CNNNetwork & cnn, const RemoteContext::Ptr&,
|
||||
const std::map<std::string, std::string> &) {
|
||||
const std::map<std::string, std::string> &config) {
|
||||
if (m_checkConfigCb) {
|
||||
m_checkConfigCb(config);
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(mock_creation_mutex);
|
||||
std::string name = cnn.getFunction()->get_friendly_name();
|
||||
m_inputs_map[name] = cnn.getInputsInfo();
|
||||
@@ -440,7 +455,10 @@ private:
|
||||
|
||||
ON_CALL(plugin, LoadExeNetworkImpl(_, _)).
|
||||
WillByDefault(Invoke([&](const CNNNetwork & cnn,
|
||||
const std::map<std::string, std::string> &) {
|
||||
const std::map<std::string, std::string> &config) {
|
||||
if (m_checkConfigCb) {
|
||||
m_checkConfigCb(config);
|
||||
}
|
||||
std::string name = cnn.getFunction()->get_friendly_name();
|
||||
std::lock_guard<std::mutex> lock(mock_creation_mutex);
|
||||
m_inputs_map[name] = cnn.getInputsInfo();
|
||||
@@ -517,6 +535,44 @@ TEST_P(CachingTest, TestLoad) {
|
||||
}
|
||||
}
|
||||
|
||||
/// \brief Verifies that ie.SetConfig({{"CACHE_DIR", <dir>}}, "deviceName"}}); enables caching for one device
|
||||
TEST_P(CachingTest, TestLoad_by_device_name) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(0);
|
||||
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
|
||||
EXPECT_CALL(net, Export(_)).Times(1);
|
||||
});
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}, "mock");
|
||||
m_testFunction(ie);
|
||||
});
|
||||
EXPECT_EQ(networks.size(), 1);
|
||||
}
|
||||
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||
for (auto& net : networks) {
|
||||
EXPECT_CALL(*net, Export(_)).Times(0); // No more 'Export' for existing networks
|
||||
}
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}, "mock");
|
||||
m_testFunction(ie);
|
||||
});
|
||||
EXPECT_EQ(networks.size(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CachingTest, TestLoadCustomImportExport) {
|
||||
const char customData[] = {1, 2, 3, 4, 5};
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).Times(AnyNumber());
|
||||
@@ -626,6 +682,45 @@ TEST_P(CachingTest, TestChangeLoadConfig) {
|
||||
}
|
||||
}
|
||||
|
||||
/// \brief Verifies that ie.LoadNetwork(cnn, "deviceName", {{"CACHE_DIR", <dir>>}}) works
|
||||
TEST_P(CachingTest, TestChangeLoadConfig_With_Cache_Dir_inline) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||
ON_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).
|
||||
WillByDefault(Invoke([&](const std::string &, const std::map<std::string, Parameter> &) {
|
||||
return std::vector<std::string>{};
|
||||
}));
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(0);
|
||||
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
|
||||
EXPECT_CALL(net, Export(_)).Times(1);
|
||||
});
|
||||
testLoad([&](Core &ie) {
|
||||
m_testFunctionWithCfg(ie, {{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
});
|
||||
}
|
||||
m_post_mock_net_callbacks.pop_back();
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||
for (auto& net : networks) {
|
||||
EXPECT_CALL(*net, Export(_)).Times(0); // No more 'Export' for existing networks
|
||||
}
|
||||
testLoad([&](Core &ie) {
|
||||
m_testFunctionWithCfg(ie, {{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
});
|
||||
EXPECT_EQ(networks.size(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CachingTest, TestNoCacheEnabled) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _)).Times(AnyNumber());
|
||||
@@ -697,6 +792,32 @@ TEST_P(CachingTest, TestNoCacheMetricSupported) {
|
||||
}
|
||||
}
|
||||
|
||||
/// \brief If device doesn't support 'cache_dir' or 'import_export' - setting cache_dir is ignored
|
||||
TEST_P(CachingTest, TestNoCacheMetricSupported_by_device_name) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _)).Times(AnyNumber())
|
||||
.WillRepeatedly(Return(std::vector<ov::PropertyName>{}));
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
|
||||
.Times(AnyNumber()).WillRepeatedly(Return(std::vector<std::string>{}));
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(ov::device::capabilities.name(), _)).Times(0);
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(0);
|
||||
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
|
||||
EXPECT_CALL(net, Export(_)).Times(0);
|
||||
});
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}, "mock");
|
||||
m_testFunction(ie);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CachingTest, TestNoCacheMetric_hasCacheDirConfig) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
|
||||
.Times(AnyNumber()).WillRepeatedly(
|
||||
@@ -723,6 +844,62 @@ TEST_P(CachingTest, TestNoCacheMetric_hasCacheDirConfig) {
|
||||
}
|
||||
}
|
||||
|
||||
/// \brief If device supports 'cache_dir' or 'import_export' - setting cache_dir is passed to plugin on ie.LoadNetwork
|
||||
TEST_P(CachingTest, TestNoCacheMetric_hasCacheDirConfig_inline) {
|
||||
m_checkConfigCb = [](const std::map<std::string, std::string>& config) {
|
||||
EXPECT_NE(config.count(CONFIG_KEY(CACHE_DIR)), 0);
|
||||
};
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
|
||||
.Times(AnyNumber()).WillRepeatedly(
|
||||
Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
|
||||
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<std::string>{CONFIG_KEY(CACHE_DIR)}));
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _))
|
||||
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<ov::PropertyName>{
|
||||
ov::supported_properties.name(), ov::cache_dir.name()}));
|
||||
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 1 : 0);
|
||||
ASSERT_NO_THROW(
|
||||
testLoad([&](Core &ie) {
|
||||
m_testFunctionWithCfg(ie, {{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
/// \brief ie.SetConfig(<cachedir>, "deviceName") is propagated to plugin's SetConfig if device supports CACHE_DIR
|
||||
TEST_P(CachingTest, TestNoCacheMetric_hasCacheDirConfig_by_device_name) {
|
||||
m_checkConfigCb = [](const std::map<std::string, std::string>& config) {
|
||||
// Shall be '0' as appropriate 'cache_dir' is expected in SetConfig, not in Load/Import network
|
||||
EXPECT_EQ(config.count(CONFIG_KEY(CACHE_DIR)), 0);
|
||||
};
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
|
||||
.Times(AnyNumber()).WillRepeatedly(
|
||||
Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
|
||||
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<std::string>{CONFIG_KEY(CACHE_DIR)}));
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _))
|
||||
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<ov::PropertyName>{
|
||||
ov::supported_properties.name(), ov::cache_dir.name()}));
|
||||
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AtLeast(1)).WillRepeatedly(
|
||||
Invoke([](const std::map<std::string, std::string>& config) {
|
||||
ASSERT_GT(config.count(CONFIG_KEY(CACHE_DIR)), 0);
|
||||
}));
|
||||
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 1 : 0);
|
||||
ASSERT_NO_THROW(
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}, "mock");
|
||||
m_testFunction(ie);
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CachingTest, TestCacheEnabled_noConfig) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
|
||||
.Times(AnyNumber()).WillRepeatedly(
|
||||
@@ -750,6 +927,9 @@ TEST_P(CachingTest, TestCacheEnabled_noConfig) {
|
||||
|
||||
|
||||
TEST_P(CachingTest, TestNoCacheMetric_configThrow) {
|
||||
m_checkConfigCb = [](const std::map<std::string, std::string>& config) {
|
||||
EXPECT_NE(config.count(CONFIG_KEY(CACHE_DIR)), 0);
|
||||
};
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
|
||||
.Times(AnyNumber()).WillRepeatedly(
|
||||
Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
|
||||
@@ -833,6 +1013,198 @@ TEST_P(CachingTest, TestLoadChangeCacheDir) {
|
||||
}
|
||||
}
|
||||
|
||||
/// \brief Change CACHE_DIR during working with same 'Core' object. Verifies that new dir is used for caching
|
||||
TEST_P(CachingTest, TestLoadChangeCacheDirOneCore) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AnyNumber()).WillRepeatedly(
|
||||
Invoke([](const std::map<std::string, std::string>& config) {
|
||||
ASSERT_EQ(config.count(CONFIG_KEY(CACHE_DIR)), 0);
|
||||
}));
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 2 : 0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 2 : 0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(0);
|
||||
testLoad([&](Core &ie) {
|
||||
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
|
||||
EXPECT_CALL(net, Export(_)).Times(1);
|
||||
});
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
m_testFunction(ie);
|
||||
std::string newCacheDir = m_cacheDir + "2";
|
||||
m_post_mock_net_callbacks.pop_back();
|
||||
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
|
||||
EXPECT_CALL(net, Export(_)).Times(1);
|
||||
});
|
||||
MkDirGuard dir(newCacheDir);
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), newCacheDir}});
|
||||
m_testFunction(ie);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// \brief Change CACHE_DIR during working with same 'Core' object
|
||||
/// Initially set for 'device', then is overwritten with global 'cache_dir' for all devices
|
||||
TEST_P(CachingTest, TestLoadChangeCacheDirOneCore_overwrite_device_dir) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AnyNumber()).WillRepeatedly(
|
||||
Invoke([](const std::map<std::string, std::string>& config) {
|
||||
ASSERT_EQ(config.count(CONFIG_KEY(CACHE_DIR)), 0);
|
||||
}));
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 2 : 0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 2 : 0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(0);
|
||||
testLoad([&](Core &ie) {
|
||||
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
|
||||
EXPECT_CALL(net, Export(_)).Times(1);
|
||||
});
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}, "mock");
|
||||
m_testFunction(ie);
|
||||
std::string newCacheDir = m_cacheDir + "2";
|
||||
m_post_mock_net_callbacks.pop_back();
|
||||
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
|
||||
EXPECT_CALL(net, Export(_)).Times(1);
|
||||
});
|
||||
MkDirGuard dir(newCacheDir);
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), newCacheDir}});
|
||||
m_testFunction(ie);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// \brief Change CACHE_DIR during working with same 'Core' object for device which supports 'CACHE_DIR' config, not import_export
|
||||
/// Expectation is that SetConfig for plugin will be called 2 times - with appropriate cache_dir values
|
||||
TEST_P(CachingTest, TestLoadChangeCacheDirOneCore_SupportsCacheDir_NoImportExport) {
|
||||
m_checkConfigCb = [](const std::map<std::string, std::string>& config) {
|
||||
EXPECT_EQ(config.count(CONFIG_KEY(CACHE_DIR)), 0);
|
||||
};
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
|
||||
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<std::string>{CONFIG_KEY(CACHE_DIR)}));
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _))
|
||||
.Times(AnyNumber()).WillRepeatedly(Return(std::vector<ov::PropertyName>{
|
||||
ov::supported_properties.name(), ov::cache_dir.name()}));
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(ov::device::capabilities.name(), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
|
||||
.Times(AnyNumber()).WillRepeatedly(
|
||||
Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(ov::device::capabilities.name(), _)).Times(AnyNumber()).
|
||||
WillRepeatedly(Return(decltype(ov::device::capabilities)::value_type{}));
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||
std::string set_cache_dir = {};
|
||||
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AtLeast(2)).WillRepeatedly(
|
||||
Invoke([&](const std::map<std::string, std::string>& config) {
|
||||
ASSERT_NE(config.count(CONFIG_KEY(CACHE_DIR)), 0);
|
||||
set_cache_dir = config.at(CONFIG_KEY(CACHE_DIR));
|
||||
}));
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 2 : 0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 2 : 0);
|
||||
EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 2 : 0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(0);
|
||||
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
|
||||
EXPECT_CALL(net, Export(_)).Times(0);
|
||||
});
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
m_testFunction(ie);
|
||||
EXPECT_EQ(set_cache_dir, m_cacheDir);
|
||||
|
||||
std::string new_cache_dir = m_cacheDir + "2";
|
||||
MkDirGuard dir(new_cache_dir);
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), new_cache_dir}});
|
||||
m_testFunction(ie);
|
||||
EXPECT_EQ(set_cache_dir, new_cache_dir);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// \brief Change CACHE_DIR per device during working with same 'Core' object - expected that new cache dir is used
|
||||
TEST_P(CachingTest, TestLoadChangeCacheDirOneCore_by_device_name) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AnyNumber()).WillRepeatedly(
|
||||
Invoke([](const std::map<std::string, std::string>& config) {
|
||||
ASSERT_EQ(config.count(CONFIG_KEY(CACHE_DIR)), 0);
|
||||
}));
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 2 : 0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 2 : 0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(0);
|
||||
testLoad([&](Core &ie) {
|
||||
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
|
||||
EXPECT_CALL(net, Export(_)).Times(1);
|
||||
});
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}, "mock");
|
||||
m_testFunction(ie);
|
||||
m_post_mock_net_callbacks.pop_back();
|
||||
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
|
||||
EXPECT_CALL(net, Export(_)).Times(1);
|
||||
});
|
||||
std::string newCacheDir = m_cacheDir + "2";
|
||||
MkDirGuard dir(newCacheDir);
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), newCacheDir}}, "mock");
|
||||
m_testFunction(ie);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// \brief Change CACHE_DIR per device during working with same 'Core' object - device supports CACHE_DIR
|
||||
/// Verifies that no 'export' is called and cache_dir is propagated to set_config
|
||||
TEST_P(CachingTest, TestLoadChangeCacheDirOneCore_by_device_name_supports_cache_dir) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
|
||||
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<std::string>{CONFIG_KEY(CACHE_DIR)}));
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _)).Times(AnyNumber())
|
||||
.WillRepeatedly(Return(std::vector<ov::PropertyName>{
|
||||
ov::cache_dir.name()}));
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber())
|
||||
.WillRepeatedly(Return(false));
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(ov::device::capabilities.name(), _)).Times(AnyNumber()).
|
||||
WillRepeatedly(Return(decltype(ov::device::capabilities)::value_type{}));
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AtLeast(2)).WillRepeatedly(
|
||||
Invoke([](const std::map<std::string, std::string>& config) {
|
||||
ASSERT_GT(config.count(CONFIG_KEY(CACHE_DIR)), 0);
|
||||
}));
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 2 : 0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 2 : 0);
|
||||
EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 2 : 0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetwork(_, _)).Times(0);
|
||||
testLoad([&](Core &ie) {
|
||||
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
|
||||
EXPECT_CALL(net, Export(_)).Times(0);
|
||||
});
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}, "mock");
|
||||
m_testFunction(ie);
|
||||
m_post_mock_net_callbacks.pop_back();
|
||||
m_post_mock_net_callbacks.emplace_back([&](MockExecutableNetwork& net) {
|
||||
EXPECT_CALL(net, Export(_)).Times(0);
|
||||
});
|
||||
std::string newCacheDir = m_cacheDir + "2";
|
||||
MkDirGuard dir(newCacheDir);
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), newCacheDir}}, "mock");
|
||||
m_testFunction(ie);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CachingTest, TestClearCacheDir) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(ov::supported_properties.name(), _)).Times(AnyNumber());
|
||||
|
||||
Reference in New Issue
Block a user