Caching: pass global CACHE_DIR setting to plugin (#5893)

* Caching: pass global CACHE_DIR setting to plugin

This can be helpful for GPU — the GPU plugin doesn't support Import/Export, but
load time can still be sped up significantly when CACHE_DIR is set for that device only

* Ignore exception in 'DeviceSupportsConfigKey' if plugin doesn't support GetMetric at all
This commit is contained in:
Mikhail Nosov
2021-05-31 19:09:07 +03:00
committed by GitHub
parent 7fb9bac24a
commit e29169db47
3 changed files with 135 additions and 3 deletions

View File

@@ -180,6 +180,7 @@ class Core::Impl : public ICore {
class CoreConfig final {
public:
struct CacheConfig {
std::string _cacheDir;
std::shared_ptr<ICacheManager> _cacheManager;
};
@@ -187,6 +188,7 @@ class Core::Impl : public ICore {
auto it = config.find(CONFIG_KEY(CACHE_DIR));
if (it != config.end()) {
std::lock_guard<std::mutex> lock(_cacheConfigMutex);
_cacheConfig._cacheDir = it->second;
if (!it->second.empty()) {
FileUtils::createDirectoryRecursive(it->second);
_cacheConfig._cacheManager = std::make_shared<FileStorageCacheManager>(std::move(it->second));
@@ -241,6 +243,27 @@ class Core::Impl : public ICore {
return supported;
}
bool DeviceSupportsCacheDir(const InferencePlugin& plugin) const {
    // A device accepts a cache directory iff CACHE_DIR appears among the
    // config keys it reports as supported.
    const bool cacheDirSupported = DeviceSupportsConfigKey(plugin, CONFIG_KEY(CACHE_DIR));
    return cacheDirSupported;
}
bool DeviceSupportsConfigKey(const InferencePlugin& plugin, const std::string& key) const {
    // Ask the plugin which metrics it supports. A plugin that does not
    // implement GetMetric (or SUPPORTED_METRICS) at all is treated as not
    // supporting `key` either, hence the deliberately swallowed exception.
    std::vector<std::string> metricKeys;
    try {
        metricKeys = plugin.GetMetric(METRIC_KEY(SUPPORTED_METRICS), {}).as<std::vector<std::string>>();
    } catch (...) {
        // Intentionally ignored: no metric support means "config key not supported".
    }
    const bool hasConfigKeysMetric =
        std::find(metricKeys.begin(), metricKeys.end(), METRIC_KEY(SUPPORTED_CONFIG_KEYS)) != metricKeys.end();
    if (!hasConfigKeysMetric) {
        return false;
    }
    // SUPPORTED_CONFIG_KEYS is advertised, so this second query is expected to succeed.
    std::vector<std::string> configKeys = plugin.GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), {});
    return std::find(configKeys.begin(), configKeys.end(), key) != configKeys.end();
}
SoExecutableNetworkInternal LoadNetworkImpl(const CNNNetwork& network,
InferencePlugin& plugin,
const std::map<std::string, std::string>& parsedConfig,
@@ -700,6 +723,12 @@ public:
// configuring
{
if (DeviceSupportsCacheDir(plugin)) {
auto cacheConfig = coreConfig.getCacheConfig();
if (cacheConfig._cacheManager) {
desc.defaultConfig[CONFIG_KEY(CACHE_DIR)] = cacheConfig._cacheDir;
}
}
allowNotImplemented([&]() {
plugin.SetConfig(desc.defaultConfig);
});
@@ -816,7 +845,14 @@ public:
for (auto& plugin : plugins) {
if (deviceName.empty() || deviceName == plugin.first) {
allowNotImplemented([&]() {
plugin.second.SetConfig(config);
auto configCopy = config;
if (DeviceSupportsCacheDir(plugin.second)) {
auto cacheConfig = coreConfig.getCacheConfig();
if (cacheConfig._cacheManager) {
configCopy[CONFIG_KEY(CACHE_DIR)] = cacheConfig._cacheDir;
}
}
plugin.second.SetConfig(configCopy);
});
}
}

View File

@@ -111,6 +111,7 @@ public:
const std::map<std::string, std::string>& config));
MOCK_CONST_METHOD2(GetMetric, Parameter(const std::string& name, const std::map<std::string, Parameter>& options));
MOCK_METHOD1(SetConfig, void(const std::map<std::string, std::string>& options));
MOCK_METHOD1(GetDefaultContext, RemoteContext::Ptr(const ParamMap& params));
};
@@ -362,6 +363,11 @@ private:
return res;
}));
EXPECT_CALL(plugin, SetConfig(_)).Times(AnyNumber()).WillRepeatedly(
Invoke([](const std::map<std::string, std::string>) {
throw InferenceEngine::NotImplemented("Not implemented");
}));
EXPECT_CALL(*net, GetInputsInfo()).Times(AnyNumber())
.WillRepeatedly(Return(ConstInputsDataMap{}));
EXPECT_CALL(*net, GetOutputsInfo()).Times(AnyNumber())
@@ -567,6 +573,93 @@ TEST_P(CachingTest, TestNoCacheMetricSupported) {
}
}
TEST_P(CachingTest, TestNoCacheMetric_hasCacheDirConfig) {
    // Plugin exposes no IMPORT_EXPORT metric but does list CACHE_DIR among its
    // supported config keys: Core must forward the global CACHE_DIR setting to
    // the plugin through SetConfig, and loading must succeed without caching.
    EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
        .Times(AnyNumber())
        .WillRepeatedly(Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
    EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
        .Times(AtLeast(1))
        .WillRepeatedly(Return(std::vector<std::string>{CONFIG_KEY(CACHE_DIR)}));
    EXPECT_CALL(*mockPlugin, SetConfig(_))
        .Times(AtLeast(1))
        .WillRepeatedly(Invoke([](const std::map<std::string, std::string>& conf) {
            // Every configuration pushed to the plugin must carry the cache directory.
            ASSERT_GT(conf.count(CONFIG_KEY(CACHE_DIR)), 0);
        }));
    {
        EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
        EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(m_remoteContext ? 0 : 1);
        EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 1 : 0);
        ASSERT_NO_THROW(
            testLoad([&](Core& ie) {
                ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
                m_testFunction(ie);
            }));
    }
}
TEST_P(CachingTest, TestCacheEnabled_noConfig) {
    // Plugin supports the SUPPORTED_CONFIG_KEYS metric but reports an empty
    // key list: Core must NOT pass CACHE_DIR down to the plugin even though
    // caching was enabled globally.
    EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
        .Times(AnyNumber())
        .WillRepeatedly(Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
    EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
        .Times(AtLeast(1))
        .WillRepeatedly(Return(std::vector<std::string>{}));
    EXPECT_CALL(*mockPlugin, SetConfig(_))
        .Times(AnyNumber())
        .WillRepeatedly(Invoke([](const std::map<std::string, std::string>& conf) {
            // The cache directory must never leak into an unsupporting plugin.
            ASSERT_EQ(conf.count(CONFIG_KEY(CACHE_DIR)), 0);
        }));
    {
        EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
        EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(m_remoteContext ? 0 : 1);
        EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 1 : 0);
        testLoad([&](Core& ie) {
            ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
            m_testFunction(ie);
        });
    }
}
TEST_P(CachingTest, TestNoCacheMetric_configThrow) {
    // Plugin accepts CACHE_DIR as a config key but its SetConfig throws:
    // the error must propagate out of the load call instead of being swallowed.
    EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
        .Times(AnyNumber())
        .WillRepeatedly(Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
    EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
        .Times(AtLeast(1))
        .WillRepeatedly(Return(std::vector<std::string>{CONFIG_KEY(CACHE_DIR)}));
    EXPECT_CALL(*mockPlugin, SetConfig(_))
        .Times(AtLeast(1))
        .WillRepeatedly(Invoke([](const std::map<std::string, std::string>& conf) {
            // CACHE_DIR is still expected to be forwarded before the failure.
            ASSERT_GT(conf.count(CONFIG_KEY(CACHE_DIR)), 0);
            throw InferenceEngine::GeneralError("Error occurred");
        }));
    ASSERT_ANY_THROW(
        testLoad([&](Core& ie) {
            ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
            m_testFunction(ie);
        }));
}
TEST_P(CachingTest, TestNoCacheEnabled_cacheDirConfig) {
    // Plugin supports the CACHE_DIR config key, but the user never enables
    // caching on the Core: CACHE_DIR must not be injected into SetConfig and
    // no Import path may be taken.
    EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
        .Times(AnyNumber())
        .WillRepeatedly(Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
    EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
        .Times(AtLeast(1))
        .WillRepeatedly(Return(std::vector<std::string>{CONFIG_KEY(CACHE_DIR)}));
    EXPECT_CALL(*mockPlugin, SetConfig(_))
        .Times(AnyNumber())
        .WillRepeatedly(Invoke([](const std::map<std::string, std::string>& conf) {
            // No global cache dir was configured, so none may be forwarded.
            ASSERT_EQ(conf.count(CONFIG_KEY(CACHE_DIR)), 0);
        }));
    {
        EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
        EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(m_remoteContext ? 0 : 1);
        EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
        EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
        testLoad([&](Core& ie) {
            m_testFunction(ie);
        });
    }
}
TEST_P(CachingTest, TestLoadChangeCacheDir) {
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());

View File

@@ -17,8 +17,11 @@ MockPlugin::MockPlugin(InferenceEngine::IInferencePlugin *target) {
_target = target;
}
void MockPlugin::SetConfig(const std::map<std::string, std::string>& config) {
this->config = config;
void MockPlugin::SetConfig(const std::map<std::string, std::string>& _config) {
this->config = _config;
if (_target) {
_target->SetConfig(config);
}
}
Parameter MockPlugin::GetMetric(const std::string& name, const std::map<std::string, InferenceEngine::Parameter>& options) const {