Caching: pass global CACHE_DIR setting to plugin (#5893)
* Caching: pass global CACHE_DIR setting to plugin This can be helpful for GPU - it doesn't support Import/Export but can significantly speed up load time when CACHE_DIR is set for device only * Ignore exception in 'DeviceSupportsConfigKey' if plugin doesn't support GetMetric at all
This commit is contained in:
@@ -180,6 +180,7 @@ class Core::Impl : public ICore {
|
||||
class CoreConfig final {
|
||||
public:
|
||||
struct CacheConfig {
|
||||
std::string _cacheDir;
|
||||
std::shared_ptr<ICacheManager> _cacheManager;
|
||||
};
|
||||
|
||||
@@ -187,6 +188,7 @@ class Core::Impl : public ICore {
|
||||
auto it = config.find(CONFIG_KEY(CACHE_DIR));
|
||||
if (it != config.end()) {
|
||||
std::lock_guard<std::mutex> lock(_cacheConfigMutex);
|
||||
_cacheConfig._cacheDir = it->second;
|
||||
if (!it->second.empty()) {
|
||||
FileUtils::createDirectoryRecursive(it->second);
|
||||
_cacheConfig._cacheManager = std::make_shared<FileStorageCacheManager>(std::move(it->second));
|
||||
@@ -241,6 +243,27 @@ class Core::Impl : public ICore {
|
||||
return supported;
|
||||
}
|
||||
|
||||
bool DeviceSupportsCacheDir(const InferencePlugin& plugin) const {
|
||||
return DeviceSupportsConfigKey(plugin, CONFIG_KEY(CACHE_DIR));
|
||||
}
|
||||
|
||||
bool DeviceSupportsConfigKey(const InferencePlugin& plugin, const std::string& key) const {
|
||||
bool supported = false;
|
||||
std::vector<std::string> supportedMetricKeys;
|
||||
try {
|
||||
// If plugin doesn't support 'SUPPORTED_METRICS' - treat it as config is not supported as well
|
||||
supportedMetricKeys =
|
||||
plugin.GetMetric(METRIC_KEY(SUPPORTED_METRICS), {}).as<std::vector<std::string>>();
|
||||
} catch(...) {}
|
||||
auto it = std::find(supportedMetricKeys.begin(), supportedMetricKeys.end(),
|
||||
METRIC_KEY(SUPPORTED_CONFIG_KEYS));
|
||||
if (it != supportedMetricKeys.end()) {
|
||||
std::vector<std::string> configKeys = plugin.GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), {});
|
||||
supported = std::find(configKeys.begin(), configKeys.end(), key) != configKeys.end();
|
||||
}
|
||||
return supported;
|
||||
}
|
||||
|
||||
SoExecutableNetworkInternal LoadNetworkImpl(const CNNNetwork& network,
|
||||
InferencePlugin& plugin,
|
||||
const std::map<std::string, std::string>& parsedConfig,
|
||||
@@ -700,6 +723,12 @@ public:
|
||||
|
||||
// configuring
|
||||
{
|
||||
if (DeviceSupportsCacheDir(plugin)) {
|
||||
auto cacheConfig = coreConfig.getCacheConfig();
|
||||
if (cacheConfig._cacheManager) {
|
||||
desc.defaultConfig[CONFIG_KEY(CACHE_DIR)] = cacheConfig._cacheDir;
|
||||
}
|
||||
}
|
||||
allowNotImplemented([&]() {
|
||||
plugin.SetConfig(desc.defaultConfig);
|
||||
});
|
||||
@@ -816,7 +845,14 @@ public:
|
||||
for (auto& plugin : plugins) {
|
||||
if (deviceName.empty() || deviceName == plugin.first) {
|
||||
allowNotImplemented([&]() {
|
||||
plugin.second.SetConfig(config);
|
||||
auto configCopy = config;
|
||||
if (DeviceSupportsCacheDir(plugin.second)) {
|
||||
auto cacheConfig = coreConfig.getCacheConfig();
|
||||
if (cacheConfig._cacheManager) {
|
||||
configCopy[CONFIG_KEY(CACHE_DIR)] = cacheConfig._cacheDir;
|
||||
}
|
||||
}
|
||||
plugin.second.SetConfig(configCopy);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,6 +111,7 @@ public:
|
||||
const std::map<std::string, std::string>& config));
|
||||
|
||||
MOCK_CONST_METHOD2(GetMetric, Parameter(const std::string& name, const std::map<std::string, Parameter>& options));
|
||||
MOCK_METHOD1(SetConfig, void(const std::map<std::string, std::string>& options));
|
||||
MOCK_METHOD1(GetDefaultContext, RemoteContext::Ptr(const ParamMap& params));
|
||||
};
|
||||
|
||||
@@ -362,6 +363,11 @@ private:
|
||||
return res;
|
||||
}));
|
||||
|
||||
EXPECT_CALL(plugin, SetConfig(_)).Times(AnyNumber()).WillRepeatedly(
|
||||
Invoke([](const std::map<std::string, std::string>) {
|
||||
throw InferenceEngine::NotImplemented("Not implemented");
|
||||
}));
|
||||
|
||||
EXPECT_CALL(*net, GetInputsInfo()).Times(AnyNumber())
|
||||
.WillRepeatedly(Return(ConstInputsDataMap{}));
|
||||
EXPECT_CALL(*net, GetOutputsInfo()).Times(AnyNumber())
|
||||
@@ -567,6 +573,93 @@ TEST_P(CachingTest, TestNoCacheMetricSupported) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CachingTest, TestNoCacheMetric_hasCacheDirConfig) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
|
||||
.Times(AnyNumber()).WillRepeatedly(
|
||||
Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
|
||||
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<std::string>{CONFIG_KEY(CACHE_DIR)}));
|
||||
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AtLeast(1)).WillRepeatedly(
|
||||
Invoke([](const std::map<std::string, std::string>& config) {
|
||||
ASSERT_GT(config.count(CONFIG_KEY(CACHE_DIR)), 0);
|
||||
}));
|
||||
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 1 : 0);
|
||||
ASSERT_NO_THROW(
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
m_testFunction(ie);
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CachingTest, TestCacheEnabled_noConfig) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
|
||||
.Times(AnyNumber()).WillRepeatedly(
|
||||
Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
|
||||
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<std::string>{}));
|
||||
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AnyNumber()).WillRepeatedly(
|
||||
Invoke([](const std::map<std::string, std::string>& config) {
|
||||
ASSERT_EQ(config.count(CONFIG_KEY(CACHE_DIR)), 0);
|
||||
}));
|
||||
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 1 : 0);
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
m_testFunction(ie);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST_P(CachingTest, TestNoCacheMetric_configThrow) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
|
||||
.Times(AnyNumber()).WillRepeatedly(
|
||||
Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
|
||||
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<std::string>{CONFIG_KEY(CACHE_DIR)}));
|
||||
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AtLeast(1)).WillRepeatedly(
|
||||
Invoke([](const std::map<std::string, std::string>& config) {
|
||||
ASSERT_GT(config.count(CONFIG_KEY(CACHE_DIR)), 0);
|
||||
throw InferenceEngine::GeneralError("Error occurred");
|
||||
}));
|
||||
|
||||
ASSERT_ANY_THROW(
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
m_testFunction(ie);
|
||||
}));
|
||||
}
|
||||
|
||||
TEST_P(CachingTest, TestNoCacheEnabled_cacheDirConfig) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
|
||||
.Times(AnyNumber()).WillRepeatedly(
|
||||
Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
|
||||
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<std::string>{CONFIG_KEY(CACHE_DIR)}));
|
||||
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AnyNumber()).WillRepeatedly(
|
||||
Invoke([](const std::map<std::string, std::string>& config) {
|
||||
ASSERT_EQ(config.count(CONFIG_KEY(CACHE_DIR)), 0);
|
||||
}));
|
||||
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||
testLoad([&](Core &ie) {
|
||||
m_testFunction(ie);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CachingTest, TestLoadChangeCacheDir) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||
|
||||
@@ -17,8 +17,11 @@ MockPlugin::MockPlugin(InferenceEngine::IInferencePlugin *target) {
|
||||
_target = target;
|
||||
}
|
||||
|
||||
void MockPlugin::SetConfig(const std::map<std::string, std::string>& config) {
|
||||
this->config = config;
|
||||
void MockPlugin::SetConfig(const std::map<std::string, std::string>& _config) {
|
||||
this->config = _config;
|
||||
if (_target) {
|
||||
_target->SetConfig(config);
|
||||
}
|
||||
}
|
||||
|
||||
Parameter MockPlugin::GetMetric(const std::string& name, const std::map<std::string, InferenceEngine::Parameter>& options) const {
|
||||
|
||||
Reference in New Issue
Block a user