[Hetero plugin] Model caching support (#4977)

* [Hetero plugin] Model caching support

- Enable IMPORT_EXPORT_SUPPORT metric
- Introduce internal FORCE_DISABLE_CACHE flag to avoid caching of subnetworks
- Added test for hetero with plugins which don't support caching
- Added test for hetero with plugins supported different cache architecture

* Hetero plugin - support DEVICE_ARCHITECTURE

Test setup:
mock.1 - mock.9 returns "one" for DEVICE_ARCHITECTURE
mock.10 - mock.99 returns "two"

Test:
Load "HETERO:mock.1,mock.51".
Load "HETERO:mock.2,mock.52" - cache shall be reused

* Fixed review comments

Covered use case case
ie.SetConfig({{"TARGET_FALLBACK", "CPU"}}, "HETERO");
ie.LoadNetwork(network, "HETERO");

* Fixed more comments and failed tests

Don't propagate FORCE_DISABLE_CACHE to plugins as they can throw exception
Fixed case with set TARGET_FALLBACK from core with different architectures of one plugin

* Fix unit tests

Add 'FORCE_DISABLE_CACHE' config key only for LoadExeNetwork
It is not needed to have in in QueryNetwork and other places

* Attempt to fix failed func test on Windows
This commit is contained in:
Mikhail Nosov 2021-03-27 22:39:10 +03:00 committed by GitHub
parent a0b06303cf
commit 65bad6c3a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 307 additions and 21 deletions

View File

@ -418,6 +418,7 @@ HeteroExecutableNetwork::HeteroExecutableNetwork(const InferenceEngine::CNNNetwo
}
for (auto&& network : networks) {
auto metaDevices = _heteroPlugin->GetDevicePlugins(network._device, _config);
metaDevices[network._device].emplace(CONFIG_KEY_INTERNAL(FORCE_DISABLE_CACHE), "");
network._network = _heteroPlugin->GetCore()->LoadNetwork(network._clonedNetwork,
network._device, metaDevices[network._device]);
}
@ -481,9 +482,9 @@ HeteroExecutableNetwork::HeteroExecutableNetwork(std::istream&
InferenceEngine::ExecutableNetwork executableNetwork;
CNNNetwork cnnnetwork;
bool loaded = false;
try {
if (ImportExportSupported(deviceName)) {
executableNetwork = _heteroPlugin->GetCore()->ImportNetwork(heteroModel, deviceName, loadConfig);
} catch (const InferenceEngine::NotImplemented& ex) {
} else {
// read XML content
std::string xmlString;
std::uint64_t dataSize = 0;
@ -609,9 +610,9 @@ void HeteroExecutableNetwork::ExportImpl(std::ostream& heteroModel) {
heteroModel << std::endl;
for (auto&& subnetwork : networks) {
try {
if (ImportExportSupported(subnetwork._device)) {
subnetwork._network.Export(heteroModel);
} catch (const InferenceEngine::NotImplemented& ex) {
} else {
auto subnet = subnetwork._clonedNetwork;
if (!subnet.getFunction()) {
IE_THROW() << "Hetero plugin supports only ngraph function representation";
@ -798,3 +799,13 @@ InferenceEngine::Parameter HeteroExecutableNetwork::GetMetric(const std::string
IE_THROW() << "Unsupported ExecutableNetwork metric: " << name;
}
}
bool HeteroExecutableNetwork::ImportExportSupported(const std::string& deviceName) const {
std::vector<std::string> supportedMetricKeys = _heteroPlugin->GetCore()->GetMetric(
deviceName, METRIC_KEY(SUPPORTED_METRICS));
auto it = std::find(supportedMetricKeys.begin(), supportedMetricKeys.end(),
METRIC_KEY(IMPORT_EXPORT_SUPPORT));
bool supported = (it != supportedMetricKeys.end()) &&
_heteroPlugin->GetCore()->GetMetric(deviceName, METRIC_KEY(IMPORT_EXPORT_SUPPORT));
return supported;
}

View File

@ -63,6 +63,7 @@ public:
private:
void InitCNNImpl(const InferenceEngine::CNNNetwork& network);
void InitNgraph(const InferenceEngine::CNNNetwork& network);
bool ImportExportSupported(const std::string& deviceName) const;
struct NetworkDesc {
std::string _device;

View File

@ -154,12 +154,14 @@ QueryNetworkResult Engine::QueryNetwork(const CNNNetwork &network, const Configs
return qr;
}
Parameter Engine::GetMetric(const std::string& name, const std::map<std::string, Parameter> & /*options*/) const {
Parameter Engine::GetMetric(const std::string& name, const std::map<std::string, Parameter>& options) const {
if (METRIC_KEY(SUPPORTED_METRICS) == name) {
IE_SET_METRIC_RETURN(SUPPORTED_METRICS, std::vector<std::string>{
METRIC_KEY(SUPPORTED_METRICS),
METRIC_KEY(FULL_DEVICE_NAME),
METRIC_KEY(SUPPORTED_CONFIG_KEYS)});
METRIC_KEY(SUPPORTED_CONFIG_KEYS),
METRIC_KEY(DEVICE_ARCHITECTURE),
METRIC_KEY(IMPORT_EXPORT_SUPPORT)});
} else if (METRIC_KEY(SUPPORTED_CONFIG_KEYS) == name) {
IE_SET_METRIC_RETURN(SUPPORTED_CONFIG_KEYS, std::vector<std::string>{
HETERO_CONFIG_KEY(DUMP_GRAPH_DOT),
@ -167,10 +169,37 @@ Parameter Engine::GetMetric(const std::string& name, const std::map<std::string,
CONFIG_KEY(EXCLUSIVE_ASYNC_REQUESTS)});
} else if (METRIC_KEY(FULL_DEVICE_NAME) == name) {
IE_SET_METRIC_RETURN(FULL_DEVICE_NAME, std::string{"HETERO"});
} else if (METRIC_KEY(IMPORT_EXPORT_SUPPORT) == name) {
IE_SET_METRIC_RETURN(IMPORT_EXPORT_SUPPORT, true);
} else if (METRIC_KEY(DEVICE_ARCHITECTURE) == name) {
auto deviceIt = options.find("TARGET_FALLBACK");
std::string targetFallback;
if (deviceIt != options.end()) {
targetFallback = deviceIt->second.as<std::string>();
} else {
targetFallback = GetConfig("TARGET_FALLBACK", {}).as<std::string>();
}
IE_SET_METRIC_RETURN(DEVICE_ARCHITECTURE, DeviceArchitecture(targetFallback));
} else {
IE_THROW() << "Unsupported Plugin metric: " << name;
}
}
std::string Engine::DeviceArchitecture(const std::string& targetFallback) const {
auto fallbackDevices = InferenceEngine::DeviceIDParser::getHeteroDevices(targetFallback);
std::string resArch;
for (const auto& device : fallbackDevices) {
InferenceEngine::DeviceIDParser parser(device);
std::vector<std::string> supportedMetricKeys = GetCore()->GetMetric(
parser.getDeviceName(), METRIC_KEY(SUPPORTED_METRICS));
auto it = std::find(supportedMetricKeys.begin(), supportedMetricKeys.end(),
METRIC_KEY(DEVICE_ARCHITECTURE));
auto arch = (it != supportedMetricKeys.end()) ?
GetCore()->GetMetric(device, METRIC_KEY(DEVICE_ARCHITECTURE)).as<std::string>() : parser.getDeviceName();
resArch += " " + arch;
}
return resArch;
}
Parameter Engine::GetConfig(const std::string& name, const std::map<std::string, Parameter> & /*options*/) const {
if (name == HETERO_CONFIG_KEY(DUMP_GRAPH_DOT)) {

View File

@ -44,5 +44,6 @@ public:
private:
Configs GetSupportedConfig(const Configs& config, const std::string & deviceName) const;
std::string DeviceArchitecture(const std::string& targetFallback) const;
};
} // namespace HeteroPlugin

View File

@ -25,6 +25,7 @@
#include "file_utils.h"
#include "ie_network_reader.hpp"
#include "xml_parse_utils.h"
#include "cpp_interfaces/interface/ie_internal_plugin_config.hpp"
using namespace InferenceEngine::PluginConfigParams;
using namespace std::placeholders;
@ -222,13 +223,14 @@ class Core::Impl : public ICore {
const std::map<std::string, std::string>& parsedConfig,
const RemoteContext::Ptr& context,
const std::string& blobID,
const std::string& modelPath = std::string()) {
const std::string& modelPath = std::string(),
bool forceDisableCache = false) {
OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::Impl::LoadNetworkImpl");
ExecutableNetwork execNetwork;
execNetwork = context ? plugin.LoadNetwork(network, context, parsedConfig) :
plugin.LoadNetwork(network, parsedConfig);
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
if (cacheManager && DeviceSupportsImportExport(plugin)) {
if (!forceDisableCache && cacheManager && DeviceSupportsImportExport(plugin)) {
try {
// need to export network for further import from "cache"
OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::LoadNetwork::Export");
@ -296,14 +298,21 @@ class Core::Impl : public ICore {
std::map<std::string, Parameter> getMetricConfig;
auto compileConfig = origConfig;
// 0. remove DEVICE_ID key
// 0. Remove TARGET_FALLBACK key, move it to getMetricConfig
auto targetFallbackIt = compileConfig.find("TARGET_FALLBACK");
if (targetFallbackIt != compileConfig.end()) {
getMetricConfig[targetFallbackIt->first] = targetFallbackIt->second;
compileConfig.erase(targetFallbackIt);
}
// 1. remove DEVICE_ID key
auto deviceIt = compileConfig.find(CONFIG_KEY(DEVICE_ID));
if (deviceIt != compileConfig.end()) {
getMetricConfig[deviceIt->first] = deviceIt->second;
compileConfig.erase(deviceIt);
}
// 1. replace it with DEVICE_ARCHITECTURE value
// 2. replace it with DEVICE_ARCHITECTURE value
std::vector<std::string> supportedMetricKeys =
plugin.GetMetric(METRIC_KEY(SUPPORTED_METRICS), getMetricConfig);
auto archIt = std::find(supportedMetricKeys.begin(), supportedMetricKeys.end(),
@ -456,19 +465,24 @@ public:
ExecutableNetwork LoadNetwork(const CNNNetwork& network, const std::string& deviceName,
const std::map<std::string, std::string>& config) override {
OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::LoadNetwork::CNN");
bool forceDisableCache = config.count(CONFIG_KEY_INTERNAL(FORCE_DISABLE_CACHE)) > 0;
auto parsed = parseDeviceNameIntoConfig(deviceName, config);
if (forceDisableCache) {
// remove this config key from parsed as plugins can throw unsupported exception
parsed._config.erase(CONFIG_KEY_INTERNAL(FORCE_DISABLE_CACHE));
}
auto plugin = GetCPPPluginByName(parsed._deviceName);
bool loadedFromCache = false;
ExecutableNetwork res;
std::string hash;
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
if (cacheManager && DeviceSupportsImportExport(plugin)) {
if (!forceDisableCache && cacheManager && DeviceSupportsImportExport(plugin)) {
hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, nullptr, loadedFromCache);
}
if (!loadedFromCache) {
res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, hash);
res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, hash, {}, forceDisableCache);
}
return res;
}

View File

@ -38,6 +38,13 @@ DECLARE_CONFIG_KEY(LP_TRANSFORMS_MODE);
*/
DECLARE_CONFIG_KEY(CPU_THREADS_PER_STREAM);
/**
* @brief This key should be used to force disable export while loading network even if global cache dir is defined
* Used by HETERO plugin to disable automatic caching of subnetworks (set value to YES)
* @ingroup ie_dev_api_plugin_api
*/
DECLARE_CONFIG_KEY(FORCE_DISABLE_CACHE);
} // namespace PluginConfigInternalParams
} // namespace InferenceEngine

View File

@ -95,6 +95,8 @@ public:
MockExecutableNetwork() {}
MOCK_METHOD1(ExportImpl, void(std::ostream& networkModel));
MOCK_METHOD0(CreateInferRequest, IInferRequest::Ptr());
MOCK_CONST_METHOD0(GetInputsInfo, ConstInputsDataMap());
MOCK_CONST_METHOD0(GetOutputsInfo, ConstOutputsDataMap());
};
//------------------------------------------------------
@ -267,14 +269,20 @@ private:
WillByDefault(Return("mock"));
ON_CALL(plugin, ImportNetworkImpl(_, _, _)).
WillByDefault(Invoke([&](std::istream &, RemoteContext::Ptr,
WillByDefault(Invoke([&](std::istream &istr, RemoteContext::Ptr,
const std::map<std::string, std::string> &) {
return ExecutableNetwork(std::make_shared<MockIExecutableNetwork>());
auto mock = std::make_shared<MockIExecutableNetwork>();
EXPECT_CALL(*mock, GetInputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
EXPECT_CALL(*mock, GetOutputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
return ExecutableNetwork(mock);
}));
ON_CALL(plugin, ImportNetworkImpl(_, _)).
WillByDefault(Invoke([&](std::istream &, const std::map<std::string, std::string> &) {
return ExecutableNetwork(std::make_shared<MockIExecutableNetwork>());
WillByDefault(Invoke([&](std::istream &istr, const std::map<std::string, std::string> &) {
auto mock = std::make_shared<MockIExecutableNetwork>();
EXPECT_CALL(*mock, GetInputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
EXPECT_CALL(*mock, GetOutputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
return ExecutableNetwork(mock);
}));
ON_CALL(plugin, LoadExeNetworkImpl(_, _, _)).
@ -305,6 +313,11 @@ private:
}
return res;
}));
EXPECT_CALL(*net, GetInputsInfo()).Times(AnyNumber())
.WillRepeatedly(Return(ConstInputsDataMap{}));
EXPECT_CALL(*net, GetOutputsInfo()).Times(AnyNumber())
.WillRepeatedly(Return(ConstOutputsDataMap{}));
}
};
@ -348,7 +361,10 @@ TEST_P(CachingTest, TestLoadCustomImportExport) {
int a;
s >> a;
EXPECT_EQ(customNumber, a);
return ExecutableNetwork(std::make_shared<MockIExecutableNetwork>());
auto mock = std::make_shared<MockIExecutableNetwork>();
EXPECT_CALL(*mock, GetInputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
EXPECT_CALL(*mock, GetOutputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
return ExecutableNetwork(mock);
}));
ON_CALL(*mockPlugin, ImportNetworkImpl(_, _)).
@ -356,7 +372,10 @@ TEST_P(CachingTest, TestLoadCustomImportExport) {
int a;
s >> a;
EXPECT_EQ(customNumber, a);
return ExecutableNetwork(std::make_shared<MockIExecutableNetwork>());
auto mock = std::make_shared<MockIExecutableNetwork>();
EXPECT_CALL(*mock, GetInputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
EXPECT_CALL(*mock, GetOutputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
return ExecutableNetwork(mock);
}));
ON_CALL(*net, ExportImpl(_)).WillByDefault(Invoke([&] (std::ostream& s) {
@ -916,14 +935,37 @@ TEST_P(CachingTest, TestCacheFileOldVersion) {
}
}
TEST_P(CachingTest, LoadHeteroWithCorrectConfig) {
EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber());
TEST_P(CachingTest, LoadHetero_NoCacheMetric) {
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
.Times(AnyNumber()).WillRepeatedly(Return(std::vector<std::string>{}));
EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber());
// TODO: test also HETERO with 1 plugin but different architectures, e.g. "HETERO:mock.1,mock.51"
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
.Times(AnyNumber()).WillRepeatedly(Return(std::vector<std::string>{}));
// Hetero supports Import/Export, but mock plugin does not
deviceToLoad = CommonTestUtils::DEVICE_HETERO + std::string(":mock.1,mock.2");
if (m_remoteContext) {
return; // skip the remote Context test for Hetero plugin
}
for (int i = 0; i < 2; i++) {
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(1);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
testLoad([&](Core &ie) {
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
m_testFunction(ie);
});
}
}
TEST_P(CachingTest, LoadHetero_OneDevice) {
EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber());
deviceToLoad = CommonTestUtils::DEVICE_HETERO + std::string(":mock");
if (m_remoteContext) {
return; // skip the remote Context test for Hetero plugin
}
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(1);
@ -934,6 +976,8 @@ TEST_P(CachingTest, LoadHeteroWithCorrectConfig) {
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
m_testFunction(ie);
});
// Ensure that only 1 blob (for Hetero) is created
EXPECT_EQ(CommonTestUtils::listFilesWithExt(m_cacheDir, "blob").size(), 1);
}
{
@ -949,6 +993,185 @@ TEST_P(CachingTest, LoadHeteroWithCorrectConfig) {
}
}
TEST_P(CachingTest, LoadHetero_TargetFallbackFromCore) {
EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber());
deviceToLoad = CommonTestUtils::DEVICE_HETERO;
if (m_remoteContext) {
return; // skip the remote Context test for Hetero plugin
}
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(1);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
testLoad([&](Core &ie) {
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
ie.SetConfig({{"TARGET_FALLBACK", "mock"}}, CommonTestUtils::DEVICE_HETERO);
m_testFunction(ie);
});
// Ensure that only 1 blob (for Hetero) is created
EXPECT_EQ(CommonTestUtils::listFilesWithExt(m_cacheDir, "blob").size(), 1);
}
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(1);
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
testLoad([&](Core &ie) {
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
ie.SetConfig({{"TARGET_FALLBACK", "mock"}}, CommonTestUtils::DEVICE_HETERO);
m_testFunction(ie);
});
}
}
TEST_P(CachingTest, LoadHetero_MultiArchs) {
EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber());
int customNumber = 1234;
ON_CALL(*mockPlugin, ImportNetworkImpl(_, _)).
WillByDefault(Invoke([&](std::istream &s, const std::map<std::string, std::string> &) {
int a;
s >> a;
EXPECT_EQ(customNumber, a);
auto mock = std::make_shared<MockIExecutableNetwork>();
EXPECT_CALL(*mock, GetInputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
EXPECT_CALL(*mock, GetOutputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
return ExecutableNetwork(mock);
}));
ON_CALL(*net, ExportImpl(_)).WillByDefault(Invoke([&] (std::ostream& s) {
s << customNumber;
}));
EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber()).WillRepeatedly(
Invoke([&](const CNNNetwork &network, const std::map<std::string, std::string> &config) {
QueryNetworkResult res;
auto function = network.getFunction();
EXPECT_TRUE(function);
auto id = config.at("DEVICE_ID");
bool supportsRelu = std::stoi(id) < 10;
for (auto &&node : function->get_ops()) {
std::string nodeType = node->get_type_name();
if ((nodeType == "Relu" && supportsRelu) ||
(nodeType != "Relu" && !supportsRelu)) {
res.supportedLayersMap.emplace(node->get_friendly_name(), deviceName + "." + id);
}
}
return res;
}));
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber())
.WillRepeatedly(Invoke([&](const std::string &, const std::map<std::string, Parameter> &options) {
auto id = options.at("DEVICE_ID").as<std::string>();
if (std::stoi(id) < 10) {
return "mock_first_architecture";
} else {
return "mock_another_architecture";
}
}));
deviceToLoad = CommonTestUtils::DEVICE_HETERO + std::string(":mock.1,mock.51");
if (m_remoteContext) {
return; // skip the remote Context test for Hetero plugin
}
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(AtLeast(2)); // for .1 and for .51
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
EXPECT_CALL(*net, ExportImpl(_)).Times(AtLeast(2)); // for .1 and for .51
testLoad([&](Core &ie) {
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
m_testFunction(ie);
});
// Ensure that only 1 blob (for Hetero) is created
EXPECT_EQ(CommonTestUtils::listFilesWithExt(m_cacheDir, "blob").size(), 1);
}
deviceToLoad = CommonTestUtils::DEVICE_HETERO + std::string(":mock.2,mock.52");
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(AtLeast(2)); // for .2 and for .52
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
testLoad([&](Core &ie) {
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
m_testFunction(ie);
});
}
deviceToLoad = CommonTestUtils::DEVICE_HETERO + std::string(":mock.53,mock.3");
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(AtLeast(1));
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
EXPECT_CALL(*net, ExportImpl(_)).Times(AtLeast(1));
testLoad([&](Core &ie) {
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
m_testFunction(ie);
});
}
}
TEST_P(CachingTest, LoadHetero_MultiArchs_TargetFallback_FromCore) {
EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber())
.WillRepeatedly(Invoke([&](const std::string &, const std::map<std::string, Parameter> &options) {
auto id = options.at("DEVICE_ID").as<std::string>();
if (std::stoi(id) < 10) {
return "mock_first_architecture";
} else {
return "mock_another_architecture";
}
}));
deviceToLoad = CommonTestUtils::DEVICE_HETERO;
if (m_remoteContext) {
return; // skip the remote Context test for Hetero plugin
}
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(1);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
testLoad([&](Core &ie) {
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
ie.SetConfig({{"TARGET_FALLBACK", "mock.1"}}, CommonTestUtils::DEVICE_HETERO);
m_testFunction(ie);
});
}
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(1);
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
testLoad([&](Core &ie) {
ie.SetConfig({{"TARGET_FALLBACK", "mock.1"}}, CommonTestUtils::DEVICE_HETERO);
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
m_testFunction(ie);
});
}
{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(1);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
testLoad([&](Core &ie) {
ie.SetConfig({{"TARGET_FALLBACK", "mock.51"}}, CommonTestUtils::DEVICE_HETERO);
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
m_testFunction(ie);
});
}
}
INSTANTIATE_TEST_CASE_P(CachingTest, CachingTest,
::testing::Combine(
::testing::ValuesIn(loadVariants),