[Hetero plugin] Model caching support (#4977)
* [Hetero plugin] Model caching support - Enable IMPORT_EXPORT_SUPPORT metric - Introduce internal FORCE_DISABLE_CACHE flag to avoid caching of subnetworks - Added test for hetero with plugins which don't support caching - Added test for hetero with plugins supported different cache architecture * Hetero plugin - support DEVICE_ARCHITECTURE Test setup: mock.1 - mock.9 returns "one" for DEVICE_ARCHITECTURE mock.10 - mock.99 returns "two" Test: Load "HETERO:mock.1,mock.51". Load "HETERO:mock.2,mock.52" - cache shall be reused * Fixed review comments Covered use case case ie.SetConfig({{"TARGET_FALLBACK", "CPU"}}, "HETERO"); ie.LoadNetwork(network, "HETERO"); * Fixed more comments and failed tests Don't propagate FORCE_DISABLE_CACHE to plugins as they can throw exception Fixed case with set TARGET_FALLBACK from core with different architectures of one plugin * Fix unit tests Add 'FORCE_DISABLE_CACHE' config key only for LoadExeNetwork It is not needed to have in in QueryNetwork and other places * Attempt to fix failed func test on Windows
This commit is contained in:
parent
a0b06303cf
commit
65bad6c3a8
@ -418,6 +418,7 @@ HeteroExecutableNetwork::HeteroExecutableNetwork(const InferenceEngine::CNNNetwo
|
||||
}
|
||||
for (auto&& network : networks) {
|
||||
auto metaDevices = _heteroPlugin->GetDevicePlugins(network._device, _config);
|
||||
metaDevices[network._device].emplace(CONFIG_KEY_INTERNAL(FORCE_DISABLE_CACHE), "");
|
||||
network._network = _heteroPlugin->GetCore()->LoadNetwork(network._clonedNetwork,
|
||||
network._device, metaDevices[network._device]);
|
||||
}
|
||||
@ -481,9 +482,9 @@ HeteroExecutableNetwork::HeteroExecutableNetwork(std::istream&
|
||||
InferenceEngine::ExecutableNetwork executableNetwork;
|
||||
CNNNetwork cnnnetwork;
|
||||
bool loaded = false;
|
||||
try {
|
||||
if (ImportExportSupported(deviceName)) {
|
||||
executableNetwork = _heteroPlugin->GetCore()->ImportNetwork(heteroModel, deviceName, loadConfig);
|
||||
} catch (const InferenceEngine::NotImplemented& ex) {
|
||||
} else {
|
||||
// read XML content
|
||||
std::string xmlString;
|
||||
std::uint64_t dataSize = 0;
|
||||
@ -609,9 +610,9 @@ void HeteroExecutableNetwork::ExportImpl(std::ostream& heteroModel) {
|
||||
heteroModel << std::endl;
|
||||
|
||||
for (auto&& subnetwork : networks) {
|
||||
try {
|
||||
if (ImportExportSupported(subnetwork._device)) {
|
||||
subnetwork._network.Export(heteroModel);
|
||||
} catch (const InferenceEngine::NotImplemented& ex) {
|
||||
} else {
|
||||
auto subnet = subnetwork._clonedNetwork;
|
||||
if (!subnet.getFunction()) {
|
||||
IE_THROW() << "Hetero plugin supports only ngraph function representation";
|
||||
@ -798,3 +799,13 @@ InferenceEngine::Parameter HeteroExecutableNetwork::GetMetric(const std::string
|
||||
IE_THROW() << "Unsupported ExecutableNetwork metric: " << name;
|
||||
}
|
||||
}
|
||||
|
||||
bool HeteroExecutableNetwork::ImportExportSupported(const std::string& deviceName) const {
|
||||
std::vector<std::string> supportedMetricKeys = _heteroPlugin->GetCore()->GetMetric(
|
||||
deviceName, METRIC_KEY(SUPPORTED_METRICS));
|
||||
auto it = std::find(supportedMetricKeys.begin(), supportedMetricKeys.end(),
|
||||
METRIC_KEY(IMPORT_EXPORT_SUPPORT));
|
||||
bool supported = (it != supportedMetricKeys.end()) &&
|
||||
_heteroPlugin->GetCore()->GetMetric(deviceName, METRIC_KEY(IMPORT_EXPORT_SUPPORT));
|
||||
return supported;
|
||||
}
|
||||
|
@ -63,6 +63,7 @@ public:
|
||||
private:
|
||||
void InitCNNImpl(const InferenceEngine::CNNNetwork& network);
|
||||
void InitNgraph(const InferenceEngine::CNNNetwork& network);
|
||||
bool ImportExportSupported(const std::string& deviceName) const;
|
||||
|
||||
struct NetworkDesc {
|
||||
std::string _device;
|
||||
|
@ -154,12 +154,14 @@ QueryNetworkResult Engine::QueryNetwork(const CNNNetwork &network, const Configs
|
||||
return qr;
|
||||
}
|
||||
|
||||
Parameter Engine::GetMetric(const std::string& name, const std::map<std::string, Parameter> & /*options*/) const {
|
||||
Parameter Engine::GetMetric(const std::string& name, const std::map<std::string, Parameter>& options) const {
|
||||
if (METRIC_KEY(SUPPORTED_METRICS) == name) {
|
||||
IE_SET_METRIC_RETURN(SUPPORTED_METRICS, std::vector<std::string>{
|
||||
METRIC_KEY(SUPPORTED_METRICS),
|
||||
METRIC_KEY(FULL_DEVICE_NAME),
|
||||
METRIC_KEY(SUPPORTED_CONFIG_KEYS)});
|
||||
METRIC_KEY(SUPPORTED_CONFIG_KEYS),
|
||||
METRIC_KEY(DEVICE_ARCHITECTURE),
|
||||
METRIC_KEY(IMPORT_EXPORT_SUPPORT)});
|
||||
} else if (METRIC_KEY(SUPPORTED_CONFIG_KEYS) == name) {
|
||||
IE_SET_METRIC_RETURN(SUPPORTED_CONFIG_KEYS, std::vector<std::string>{
|
||||
HETERO_CONFIG_KEY(DUMP_GRAPH_DOT),
|
||||
@ -167,10 +169,37 @@ Parameter Engine::GetMetric(const std::string& name, const std::map<std::string,
|
||||
CONFIG_KEY(EXCLUSIVE_ASYNC_REQUESTS)});
|
||||
} else if (METRIC_KEY(FULL_DEVICE_NAME) == name) {
|
||||
IE_SET_METRIC_RETURN(FULL_DEVICE_NAME, std::string{"HETERO"});
|
||||
} else if (METRIC_KEY(IMPORT_EXPORT_SUPPORT) == name) {
|
||||
IE_SET_METRIC_RETURN(IMPORT_EXPORT_SUPPORT, true);
|
||||
} else if (METRIC_KEY(DEVICE_ARCHITECTURE) == name) {
|
||||
auto deviceIt = options.find("TARGET_FALLBACK");
|
||||
std::string targetFallback;
|
||||
if (deviceIt != options.end()) {
|
||||
targetFallback = deviceIt->second.as<std::string>();
|
||||
} else {
|
||||
targetFallback = GetConfig("TARGET_FALLBACK", {}).as<std::string>();
|
||||
}
|
||||
IE_SET_METRIC_RETURN(DEVICE_ARCHITECTURE, DeviceArchitecture(targetFallback));
|
||||
} else {
|
||||
IE_THROW() << "Unsupported Plugin metric: " << name;
|
||||
}
|
||||
}
|
||||
std::string Engine::DeviceArchitecture(const std::string& targetFallback) const {
|
||||
auto fallbackDevices = InferenceEngine::DeviceIDParser::getHeteroDevices(targetFallback);
|
||||
std::string resArch;
|
||||
for (const auto& device : fallbackDevices) {
|
||||
InferenceEngine::DeviceIDParser parser(device);
|
||||
|
||||
std::vector<std::string> supportedMetricKeys = GetCore()->GetMetric(
|
||||
parser.getDeviceName(), METRIC_KEY(SUPPORTED_METRICS));
|
||||
auto it = std::find(supportedMetricKeys.begin(), supportedMetricKeys.end(),
|
||||
METRIC_KEY(DEVICE_ARCHITECTURE));
|
||||
auto arch = (it != supportedMetricKeys.end()) ?
|
||||
GetCore()->GetMetric(device, METRIC_KEY(DEVICE_ARCHITECTURE)).as<std::string>() : parser.getDeviceName();
|
||||
resArch += " " + arch;
|
||||
}
|
||||
return resArch;
|
||||
}
|
||||
|
||||
Parameter Engine::GetConfig(const std::string& name, const std::map<std::string, Parameter> & /*options*/) const {
|
||||
if (name == HETERO_CONFIG_KEY(DUMP_GRAPH_DOT)) {
|
||||
|
@ -44,5 +44,6 @@ public:
|
||||
|
||||
private:
|
||||
Configs GetSupportedConfig(const Configs& config, const std::string & deviceName) const;
|
||||
std::string DeviceArchitecture(const std::string& targetFallback) const;
|
||||
};
|
||||
} // namespace HeteroPlugin
|
||||
|
@ -25,6 +25,7 @@
|
||||
#include "file_utils.h"
|
||||
#include "ie_network_reader.hpp"
|
||||
#include "xml_parse_utils.h"
|
||||
#include "cpp_interfaces/interface/ie_internal_plugin_config.hpp"
|
||||
|
||||
using namespace InferenceEngine::PluginConfigParams;
|
||||
using namespace std::placeholders;
|
||||
@ -222,13 +223,14 @@ class Core::Impl : public ICore {
|
||||
const std::map<std::string, std::string>& parsedConfig,
|
||||
const RemoteContext::Ptr& context,
|
||||
const std::string& blobID,
|
||||
const std::string& modelPath = std::string()) {
|
||||
const std::string& modelPath = std::string(),
|
||||
bool forceDisableCache = false) {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::Impl::LoadNetworkImpl");
|
||||
ExecutableNetwork execNetwork;
|
||||
execNetwork = context ? plugin.LoadNetwork(network, context, parsedConfig) :
|
||||
plugin.LoadNetwork(network, parsedConfig);
|
||||
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
|
||||
if (cacheManager && DeviceSupportsImportExport(plugin)) {
|
||||
if (!forceDisableCache && cacheManager && DeviceSupportsImportExport(plugin)) {
|
||||
try {
|
||||
// need to export network for further import from "cache"
|
||||
OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::LoadNetwork::Export");
|
||||
@ -296,14 +298,21 @@ class Core::Impl : public ICore {
|
||||
std::map<std::string, Parameter> getMetricConfig;
|
||||
auto compileConfig = origConfig;
|
||||
|
||||
// 0. remove DEVICE_ID key
|
||||
// 0. Remove TARGET_FALLBACK key, move it to getMetricConfig
|
||||
auto targetFallbackIt = compileConfig.find("TARGET_FALLBACK");
|
||||
if (targetFallbackIt != compileConfig.end()) {
|
||||
getMetricConfig[targetFallbackIt->first] = targetFallbackIt->second;
|
||||
compileConfig.erase(targetFallbackIt);
|
||||
}
|
||||
|
||||
// 1. remove DEVICE_ID key
|
||||
auto deviceIt = compileConfig.find(CONFIG_KEY(DEVICE_ID));
|
||||
if (deviceIt != compileConfig.end()) {
|
||||
getMetricConfig[deviceIt->first] = deviceIt->second;
|
||||
compileConfig.erase(deviceIt);
|
||||
}
|
||||
|
||||
// 1. replace it with DEVICE_ARCHITECTURE value
|
||||
// 2. replace it with DEVICE_ARCHITECTURE value
|
||||
std::vector<std::string> supportedMetricKeys =
|
||||
plugin.GetMetric(METRIC_KEY(SUPPORTED_METRICS), getMetricConfig);
|
||||
auto archIt = std::find(supportedMetricKeys.begin(), supportedMetricKeys.end(),
|
||||
@ -456,19 +465,24 @@ public:
|
||||
ExecutableNetwork LoadNetwork(const CNNNetwork& network, const std::string& deviceName,
|
||||
const std::map<std::string, std::string>& config) override {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::LoadNetwork::CNN");
|
||||
bool forceDisableCache = config.count(CONFIG_KEY_INTERNAL(FORCE_DISABLE_CACHE)) > 0;
|
||||
auto parsed = parseDeviceNameIntoConfig(deviceName, config);
|
||||
if (forceDisableCache) {
|
||||
// remove this config key from parsed as plugins can throw unsupported exception
|
||||
parsed._config.erase(CONFIG_KEY_INTERNAL(FORCE_DISABLE_CACHE));
|
||||
}
|
||||
auto plugin = GetCPPPluginByName(parsed._deviceName);
|
||||
bool loadedFromCache = false;
|
||||
ExecutableNetwork res;
|
||||
std::string hash;
|
||||
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
|
||||
if (cacheManager && DeviceSupportsImportExport(plugin)) {
|
||||
if (!forceDisableCache && cacheManager && DeviceSupportsImportExport(plugin)) {
|
||||
hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
|
||||
res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, nullptr, loadedFromCache);
|
||||
}
|
||||
|
||||
if (!loadedFromCache) {
|
||||
res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, hash);
|
||||
res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, hash, {}, forceDisableCache);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
@ -38,6 +38,13 @@ DECLARE_CONFIG_KEY(LP_TRANSFORMS_MODE);
|
||||
*/
|
||||
DECLARE_CONFIG_KEY(CPU_THREADS_PER_STREAM);
|
||||
|
||||
/**
|
||||
* @brief This key should be used to force disable export while loading network even if global cache dir is defined
|
||||
* Used by HETERO plugin to disable automatic caching of subnetworks (set value to YES)
|
||||
* @ingroup ie_dev_api_plugin_api
|
||||
*/
|
||||
DECLARE_CONFIG_KEY(FORCE_DISABLE_CACHE);
|
||||
|
||||
} // namespace PluginConfigInternalParams
|
||||
|
||||
} // namespace InferenceEngine
|
||||
|
@ -95,6 +95,8 @@ public:
|
||||
MockExecutableNetwork() {}
|
||||
MOCK_METHOD1(ExportImpl, void(std::ostream& networkModel));
|
||||
MOCK_METHOD0(CreateInferRequest, IInferRequest::Ptr());
|
||||
MOCK_CONST_METHOD0(GetInputsInfo, ConstInputsDataMap());
|
||||
MOCK_CONST_METHOD0(GetOutputsInfo, ConstOutputsDataMap());
|
||||
};
|
||||
|
||||
//------------------------------------------------------
|
||||
@ -267,14 +269,20 @@ private:
|
||||
WillByDefault(Return("mock"));
|
||||
|
||||
ON_CALL(plugin, ImportNetworkImpl(_, _, _)).
|
||||
WillByDefault(Invoke([&](std::istream &, RemoteContext::Ptr,
|
||||
WillByDefault(Invoke([&](std::istream &istr, RemoteContext::Ptr,
|
||||
const std::map<std::string, std::string> &) {
|
||||
return ExecutableNetwork(std::make_shared<MockIExecutableNetwork>());
|
||||
auto mock = std::make_shared<MockIExecutableNetwork>();
|
||||
EXPECT_CALL(*mock, GetInputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
EXPECT_CALL(*mock, GetOutputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
return ExecutableNetwork(mock);
|
||||
}));
|
||||
|
||||
ON_CALL(plugin, ImportNetworkImpl(_, _)).
|
||||
WillByDefault(Invoke([&](std::istream &, const std::map<std::string, std::string> &) {
|
||||
return ExecutableNetwork(std::make_shared<MockIExecutableNetwork>());
|
||||
WillByDefault(Invoke([&](std::istream &istr, const std::map<std::string, std::string> &) {
|
||||
auto mock = std::make_shared<MockIExecutableNetwork>();
|
||||
EXPECT_CALL(*mock, GetInputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
EXPECT_CALL(*mock, GetOutputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
return ExecutableNetwork(mock);
|
||||
}));
|
||||
|
||||
ON_CALL(plugin, LoadExeNetworkImpl(_, _, _)).
|
||||
@ -305,6 +313,11 @@ private:
|
||||
}
|
||||
return res;
|
||||
}));
|
||||
|
||||
EXPECT_CALL(*net, GetInputsInfo()).Times(AnyNumber())
|
||||
.WillRepeatedly(Return(ConstInputsDataMap{}));
|
||||
EXPECT_CALL(*net, GetOutputsInfo()).Times(AnyNumber())
|
||||
.WillRepeatedly(Return(ConstOutputsDataMap{}));
|
||||
}
|
||||
};
|
||||
|
||||
@ -348,7 +361,10 @@ TEST_P(CachingTest, TestLoadCustomImportExport) {
|
||||
int a;
|
||||
s >> a;
|
||||
EXPECT_EQ(customNumber, a);
|
||||
return ExecutableNetwork(std::make_shared<MockIExecutableNetwork>());
|
||||
auto mock = std::make_shared<MockIExecutableNetwork>();
|
||||
EXPECT_CALL(*mock, GetInputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
EXPECT_CALL(*mock, GetOutputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
return ExecutableNetwork(mock);
|
||||
}));
|
||||
|
||||
ON_CALL(*mockPlugin, ImportNetworkImpl(_, _)).
|
||||
@ -356,7 +372,10 @@ TEST_P(CachingTest, TestLoadCustomImportExport) {
|
||||
int a;
|
||||
s >> a;
|
||||
EXPECT_EQ(customNumber, a);
|
||||
return ExecutableNetwork(std::make_shared<MockIExecutableNetwork>());
|
||||
auto mock = std::make_shared<MockIExecutableNetwork>();
|
||||
EXPECT_CALL(*mock, GetInputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
EXPECT_CALL(*mock, GetOutputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
return ExecutableNetwork(mock);
|
||||
}));
|
||||
|
||||
ON_CALL(*net, ExportImpl(_)).WillByDefault(Invoke([&] (std::ostream& s) {
|
||||
@ -916,14 +935,37 @@ TEST_P(CachingTest, TestCacheFileOldVersion) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CachingTest, LoadHeteroWithCorrectConfig) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber());
|
||||
TEST_P(CachingTest, LoadHetero_NoCacheMetric) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
|
||||
.Times(AnyNumber()).WillRepeatedly(Return(std::vector<std::string>{}));
|
||||
EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber());
|
||||
// TODO: test also HETERO with 1 plugin but different architectures, e.g. "HETERO:mock.1,mock.51"
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
|
||||
.Times(AnyNumber()).WillRepeatedly(Return(std::vector<std::string>{}));
|
||||
// Hetero supports Import/Export, but mock plugin does not
|
||||
deviceToLoad = CommonTestUtils::DEVICE_HETERO + std::string(":mock.1,mock.2");
|
||||
if (m_remoteContext) {
|
||||
return; // skip the remote Context test for Hetero plugin
|
||||
}
|
||||
for (int i = 0; i < 2; i++) {
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(1);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
m_testFunction(ie);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CachingTest, LoadHetero_OneDevice) {
|
||||
EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber());
|
||||
deviceToLoad = CommonTestUtils::DEVICE_HETERO + std::string(":mock");
|
||||
if (m_remoteContext) {
|
||||
return; // skip the remote Context test for Hetero plugin
|
||||
}
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(1);
|
||||
@ -934,6 +976,8 @@ TEST_P(CachingTest, LoadHeteroWithCorrectConfig) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
m_testFunction(ie);
|
||||
});
|
||||
// Ensure that only 1 blob (for Hetero) is created
|
||||
EXPECT_EQ(CommonTestUtils::listFilesWithExt(m_cacheDir, "blob").size(), 1);
|
||||
}
|
||||
|
||||
{
|
||||
@ -949,6 +993,185 @@ TEST_P(CachingTest, LoadHeteroWithCorrectConfig) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CachingTest, LoadHetero_TargetFallbackFromCore) {
|
||||
EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber());
|
||||
deviceToLoad = CommonTestUtils::DEVICE_HETERO;
|
||||
if (m_remoteContext) {
|
||||
return; // skip the remote Context test for Hetero plugin
|
||||
}
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(1);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
ie.SetConfig({{"TARGET_FALLBACK", "mock"}}, CommonTestUtils::DEVICE_HETERO);
|
||||
m_testFunction(ie);
|
||||
});
|
||||
// Ensure that only 1 blob (for Hetero) is created
|
||||
EXPECT_EQ(CommonTestUtils::listFilesWithExt(m_cacheDir, "blob").size(), 1);
|
||||
}
|
||||
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(1);
|
||||
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
ie.SetConfig({{"TARGET_FALLBACK", "mock"}}, CommonTestUtils::DEVICE_HETERO);
|
||||
m_testFunction(ie);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CachingTest, LoadHetero_MultiArchs) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber());
|
||||
int customNumber = 1234;
|
||||
ON_CALL(*mockPlugin, ImportNetworkImpl(_, _)).
|
||||
WillByDefault(Invoke([&](std::istream &s, const std::map<std::string, std::string> &) {
|
||||
int a;
|
||||
s >> a;
|
||||
EXPECT_EQ(customNumber, a);
|
||||
auto mock = std::make_shared<MockIExecutableNetwork>();
|
||||
EXPECT_CALL(*mock, GetInputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
EXPECT_CALL(*mock, GetOutputsInfo(_, _)).Times(AnyNumber()).WillRepeatedly(Return(OK));
|
||||
return ExecutableNetwork(mock);
|
||||
}));
|
||||
|
||||
ON_CALL(*net, ExportImpl(_)).WillByDefault(Invoke([&] (std::ostream& s) {
|
||||
s << customNumber;
|
||||
}));
|
||||
EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber()).WillRepeatedly(
|
||||
Invoke([&](const CNNNetwork &network, const std::map<std::string, std::string> &config) {
|
||||
QueryNetworkResult res;
|
||||
auto function = network.getFunction();
|
||||
EXPECT_TRUE(function);
|
||||
|
||||
auto id = config.at("DEVICE_ID");
|
||||
bool supportsRelu = std::stoi(id) < 10;
|
||||
|
||||
for (auto &&node : function->get_ops()) {
|
||||
std::string nodeType = node->get_type_name();
|
||||
if ((nodeType == "Relu" && supportsRelu) ||
|
||||
(nodeType != "Relu" && !supportsRelu)) {
|
||||
res.supportedLayersMap.emplace(node->get_friendly_name(), deviceName + "." + id);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}));
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber())
|
||||
.WillRepeatedly(Invoke([&](const std::string &, const std::map<std::string, Parameter> &options) {
|
||||
auto id = options.at("DEVICE_ID").as<std::string>();
|
||||
if (std::stoi(id) < 10) {
|
||||
return "mock_first_architecture";
|
||||
} else {
|
||||
return "mock_another_architecture";
|
||||
}
|
||||
}));
|
||||
deviceToLoad = CommonTestUtils::DEVICE_HETERO + std::string(":mock.1,mock.51");
|
||||
if (m_remoteContext) {
|
||||
return; // skip the remote Context test for Hetero plugin
|
||||
}
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(AtLeast(2)); // for .1 and for .51
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||
EXPECT_CALL(*net, ExportImpl(_)).Times(AtLeast(2)); // for .1 and for .51
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
m_testFunction(ie);
|
||||
});
|
||||
// Ensure that only 1 blob (for Hetero) is created
|
||||
EXPECT_EQ(CommonTestUtils::listFilesWithExt(m_cacheDir, "blob").size(), 1);
|
||||
}
|
||||
|
||||
deviceToLoad = CommonTestUtils::DEVICE_HETERO + std::string(":mock.2,mock.52");
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(AtLeast(2)); // for .2 and for .52
|
||||
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
m_testFunction(ie);
|
||||
});
|
||||
}
|
||||
deviceToLoad = CommonTestUtils::DEVICE_HETERO + std::string(":mock.53,mock.3");
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(AtLeast(1));
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||
EXPECT_CALL(*net, ExportImpl(_)).Times(AtLeast(1));
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
m_testFunction(ie);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(CachingTest, LoadHetero_MultiArchs_TargetFallback_FromCore) {
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber())
|
||||
.WillRepeatedly(Invoke([&](const std::string &, const std::map<std::string, Parameter> &options) {
|
||||
auto id = options.at("DEVICE_ID").as<std::string>();
|
||||
if (std::stoi(id) < 10) {
|
||||
return "mock_first_architecture";
|
||||
} else {
|
||||
return "mock_another_architecture";
|
||||
}
|
||||
}));
|
||||
deviceToLoad = CommonTestUtils::DEVICE_HETERO;
|
||||
if (m_remoteContext) {
|
||||
return; // skip the remote Context test for Hetero plugin
|
||||
}
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(1);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
ie.SetConfig({{"TARGET_FALLBACK", "mock.1"}}, CommonTestUtils::DEVICE_HETERO);
|
||||
m_testFunction(ie);
|
||||
});
|
||||
}
|
||||
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(1);
|
||||
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{"TARGET_FALLBACK", "mock.1"}}, CommonTestUtils::DEVICE_HETERO);
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
m_testFunction(ie);
|
||||
});
|
||||
}
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(1);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{"TARGET_FALLBACK", "mock.51"}}, CommonTestUtils::DEVICE_HETERO);
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
m_testFunction(ie);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(CachingTest, CachingTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(loadVariants),
|
||||
|
Loading…
Reference in New Issue
Block a user