Multi plugin - override loading network from file (#5677)
* Multi plugin - override loading network from file When caching is enabled, MULTI plugin will check all devices - For devices with caching supported - call LoadNetwork(modelPath, ...) - For others - ReadNetwork once and then LoadNetwork(cnnNetwork) for each device Caching unit test is added for both cases Additional helper methods: - ICore::ToExecutableNetwork - converts internal ExeNetwork to ExecutableNetwork - ICore::DeviceSupportsImportExport - checks if device supports import and export functionality. Used by Hetero and Multi * Updated according to review comments * fixed sporadic failure of 'multi-device' test cases Root cause: Currently only one 'ExecutableNetwork' object is created for each LoadNetwork For Multi-testing several threads could call simultaneously setNetworkInputs/Outputs/SetPointerToPlugin It caused race condition and invalid data structures * Fix build issues after rebase * Multi: Set network inputs/outputs/pointerToPlugin for load-from-file case Overloaded function doesn't call these methods, thus multi executable network was unusable Added caching test verifying that inputs/outputs are copied now from first loaded device network
This commit is contained in:
@@ -484,7 +484,7 @@ HeteroExecutableNetwork::HeteroExecutableNetwork(std::istream&
|
||||
InferenceEngine::SoExecutableNetworkInternal executableNetwork;
|
||||
CNNNetwork cnnnetwork;
|
||||
bool loaded = false;
|
||||
if (ImportExportSupported(deviceName)) {
|
||||
if (_heteroPlugin->GetCore()->DeviceSupportsImportExport(deviceName)) {
|
||||
executableNetwork = _heteroPlugin->GetCore()->ImportNetwork(heteroModel, deviceName, loadConfig);
|
||||
} else {
|
||||
// read XML content
|
||||
@@ -612,7 +612,7 @@ void HeteroExecutableNetwork::ExportImpl(std::ostream& heteroModel) {
|
||||
heteroModel << std::endl;
|
||||
|
||||
for (auto&& subnetwork : _networks) {
|
||||
if (ImportExportSupported(subnetwork._device)) {
|
||||
if (_heteroPlugin->GetCore()->DeviceSupportsImportExport(subnetwork._device)) {
|
||||
subnetwork._network->Export(heteroModel);
|
||||
} else {
|
||||
auto subnet = subnetwork._clonedNetwork;
|
||||
@@ -801,13 +801,3 @@ InferenceEngine::Parameter HeteroExecutableNetwork::GetMetric(const std::string
|
||||
IE_THROW() << "Unsupported ExecutableNetwork metric: " << name;
|
||||
}
|
||||
}
|
||||
|
||||
bool HeteroExecutableNetwork::ImportExportSupported(const std::string& deviceName) const {
|
||||
std::vector<std::string> supportedMetricKeys = _heteroPlugin->GetCore()->GetMetric(
|
||||
deviceName, METRIC_KEY(SUPPORTED_METRICS));
|
||||
auto it = std::find(supportedMetricKeys.begin(), supportedMetricKeys.end(),
|
||||
METRIC_KEY(IMPORT_EXPORT_SUPPORT));
|
||||
bool supported = (it != supportedMetricKeys.end()) &&
|
||||
_heteroPlugin->GetCore()->GetMetric(deviceName, METRIC_KEY(IMPORT_EXPORT_SUPPORT));
|
||||
return supported;
|
||||
}
|
||||
|
||||
@@ -63,7 +63,6 @@ public:
|
||||
private:
|
||||
void InitCNNImpl(const InferenceEngine::CNNNetwork& network);
|
||||
void InitNgraph(const InferenceEngine::CNNNetwork& network);
|
||||
bool ImportExportSupported(const std::string& deviceName) const;
|
||||
|
||||
struct NetworkDesc {
|
||||
std::string _device;
|
||||
|
||||
@@ -225,6 +225,12 @@ class Core::Impl : public ICore {
|
||||
std::map<std::string, PluginDescriptor> pluginRegistry;
|
||||
mutable std::mutex pluginsMutex; // to lock parallel access to pluginRegistry and plugins
|
||||
|
||||
bool DeviceSupportsImportExport(const std::string& deviceName) const override {
|
||||
auto parsed = parseDeviceNameIntoConfig(deviceName);
|
||||
auto plugin = GetCPPPluginByName(parsed._deviceName);
|
||||
return DeviceSupportsImportExport(plugin);
|
||||
}
|
||||
|
||||
bool DeviceSupportsImportExport(const InferencePlugin& plugin) const {
|
||||
std::vector<std::string> supportedMetricKeys = plugin.GetMetric(METRIC_KEY(SUPPORTED_METRICS), {});
|
||||
auto it = std::find(supportedMetricKeys.begin(), supportedMetricKeys.end(),
|
||||
|
||||
@@ -142,13 +142,48 @@ InferenceEngine::Parameter MultiDeviceInferencePlugin::GetMetric(const std::stri
|
||||
}
|
||||
}
|
||||
|
||||
void MultiDeviceInferencePlugin::SetExeNetworkInfo(InferenceEngine::ExecutableNetworkInternal::Ptr exeNetwork,
|
||||
const InferenceEngine::ConstInputsDataMap& devInputs,
|
||||
const InferenceEngine::ConstOutputsDataMap& devOutputs) {
|
||||
// Set inputs/outputs and pointer to plugin manually here
|
||||
InputsDataMap _inputs, clonedInputs;
|
||||
OutputsDataMap _outputs, clonedOutputs;
|
||||
for (auto& it : devInputs) {
|
||||
InputInfo::CPtr devData = it.second;
|
||||
InputInfo::Ptr data = std::make_shared<InputInfo>(*devData);
|
||||
_inputs[it.first] = data;
|
||||
}
|
||||
for (auto& it : devOutputs) {
|
||||
CDataPtr devData = it.second;
|
||||
DataPtr data = std::make_shared<Data>(*devData);
|
||||
_outputs[it.first] = data;
|
||||
}
|
||||
copyInputOutputInfo(_inputs, _outputs, clonedInputs, clonedOutputs);
|
||||
exeNetwork->setNetworkInputs(clonedInputs);
|
||||
exeNetwork->setNetworkOutputs(clonedOutputs);
|
||||
exeNetwork->SetPointerToPlugin(shared_from_this());
|
||||
}
|
||||
|
||||
// Is called only when caching is enabled
|
||||
IExecutableNetworkInternal::Ptr MultiDeviceInferencePlugin::LoadNetwork(const std::string& modelPath,
|
||||
const std::map<std::string, std::string>& config) {
|
||||
CNNNetwork network;
|
||||
return LoadExeNetworkImpl(modelPath, network, config);
|
||||
}
|
||||
|
||||
ExecutableNetworkInternal::Ptr MultiDeviceInferencePlugin::LoadExeNetworkImpl(const CNNNetwork &network,
|
||||
const std::map<std::string, std::string>& config) {
|
||||
return LoadExeNetworkImpl({}, network, config);
|
||||
}
|
||||
|
||||
ExecutableNetworkInternal::Ptr MultiDeviceInferencePlugin::LoadExeNetworkImpl(const std::string& modelPath,
|
||||
CNNNetwork network,
|
||||
const std::map<std::string, std::string>& config) {
|
||||
if (GetCore() == nullptr) {
|
||||
IE_THROW() << "Please, work with MULTI device via InferencEngine::Core object";
|
||||
IE_THROW() << "Please, work with MULTI device via InferenceEngine::Core object";
|
||||
}
|
||||
|
||||
if (network.getFunction() == nullptr) {
|
||||
if (modelPath.empty() && network.getFunction() == nullptr) {
|
||||
IE_THROW() << "MULTI device supports just ngraph network representation";
|
||||
}
|
||||
|
||||
@@ -167,11 +202,22 @@ ExecutableNetworkInternal::Ptr MultiDeviceInferencePlugin::LoadExeNetworkImpl(co
|
||||
DeviceMap<SoExecutableNetworkInternal> executableNetworkPerDevice;
|
||||
std::mutex load_mutex;
|
||||
std::vector<Task> loads;
|
||||
std::once_flag readNetworkFlag;
|
||||
for (auto& p : metaDevices) {
|
||||
loads.push_back([&]() {
|
||||
const auto &deviceName = p.deviceName;
|
||||
const auto &deviceConfig = p.config;
|
||||
auto exec_net = GetCore()->LoadNetwork(network, deviceName, deviceConfig);
|
||||
SoExecutableNetworkInternal exec_net;
|
||||
if (modelPath.empty()) {
|
||||
exec_net = GetCore()->LoadNetwork(network, deviceName, deviceConfig);
|
||||
} else if (GetCore()->DeviceSupportsImportExport(deviceName)) {
|
||||
exec_net = GetCore()->LoadNetwork(modelPath, deviceName, deviceConfig);
|
||||
} else {
|
||||
std::call_once(readNetworkFlag, [&]() {
|
||||
network = GetCore()->ReadNetwork(modelPath, std::string());
|
||||
});
|
||||
exec_net = GetCore()->LoadNetwork(network, deviceName, deviceConfig);
|
||||
}
|
||||
std::unique_lock<std::mutex> lock{load_mutex};
|
||||
executableNetworkPerDevice.insert({deviceName, exec_net});
|
||||
multiNetworkConfig.insert(deviceConfig.begin(), deviceConfig.end());
|
||||
@@ -199,10 +245,16 @@ ExecutableNetworkInternal::Ptr MultiDeviceInferencePlugin::LoadExeNetworkImpl(co
|
||||
}
|
||||
// MULTI can enable the perf counters only if all devices support/enable that
|
||||
bool enablePerfCounters = num_plugins_supporting_perf_counters == executableNetworkPerDevice.size();
|
||||
return std::make_shared<MultiDeviceExecutableNetwork>(executableNetworkPerDevice,
|
||||
metaDevices,
|
||||
multiNetworkConfig,
|
||||
enablePerfCounters);
|
||||
auto impl = std::make_shared<MultiDeviceExecutableNetwork>(executableNetworkPerDevice,
|
||||
metaDevices,
|
||||
multiNetworkConfig,
|
||||
enablePerfCounters);
|
||||
if (!modelPath.empty()) {
|
||||
SetExeNetworkInfo(impl,
|
||||
executableNetworkPerDevice.begin()->second->GetInputsInfo(),
|
||||
executableNetworkPerDevice.begin()->second->GetOutputsInfo());
|
||||
}
|
||||
return impl;
|
||||
}
|
||||
|
||||
QueryNetworkResult MultiDeviceInferencePlugin::QueryNetwork(const CNNNetwork& network,
|
||||
|
||||
@@ -23,6 +23,9 @@ public:
|
||||
InferenceEngine::ExecutableNetworkInternal::Ptr LoadExeNetworkImpl(const InferenceEngine::CNNNetwork& network,
|
||||
const std::map<std::string, std::string>& config) override;
|
||||
|
||||
InferenceEngine::IExecutableNetworkInternal::Ptr LoadNetwork(const std::string& modelPath,
|
||||
const std::map<std::string, std::string>& config) override;
|
||||
|
||||
void SetConfig(const std::map<std::string, std::string>& config) override;
|
||||
InferenceEngine::Parameter GetConfig(const std::string& name, const std::map<std::string, InferenceEngine::Parameter> & options) const override;
|
||||
InferenceEngine::QueryNetworkResult QueryNetwork(const InferenceEngine::CNNNetwork& network,
|
||||
@@ -36,6 +39,15 @@ public:
|
||||
protected:
|
||||
std::map<std::string, std::string> GetSupportedConfig(const std::map<std::string, std::string>& config,
|
||||
const MultiDevicePlugin::DeviceName & deviceName) const;
|
||||
|
||||
private:
|
||||
InferenceEngine::ExecutableNetworkInternal::Ptr LoadExeNetworkImpl(const std::string& modelPath,
|
||||
InferenceEngine::CNNNetwork network,
|
||||
const std::map<std::string, std::string>& config);
|
||||
|
||||
void SetExeNetworkInfo(InferenceEngine::ExecutableNetworkInternal::Ptr exeNetwork,
|
||||
const InferenceEngine::ConstInputsDataMap& inputs,
|
||||
const InferenceEngine::ConstOutputsDataMap& outputs);
|
||||
};
|
||||
|
||||
} // namespace MultiDevicePlugin
|
||||
|
||||
@@ -72,6 +72,10 @@ public:
|
||||
return impl;
|
||||
}
|
||||
|
||||
// NOTE:
|
||||
// In case of overloading this method, make sure that executable network
|
||||
// has correctly setNetworkInputs/setNetworkOutputs/SetPointerToPlugin
|
||||
// Base implementation does this via GetCore()->LoadNetwork(cnnNet)
|
||||
IExecutableNetworkInternal::Ptr LoadNetwork(const std::string& modelPath,
|
||||
const std::map<std::string, std::string>& config) override {
|
||||
auto cnnNet = GetCore()->ReadNetwork(modelPath, std::string());
|
||||
|
||||
@@ -123,6 +123,15 @@ public:
|
||||
*/
|
||||
virtual std::vector<std::string> GetAvailableDevices() const = 0;
|
||||
|
||||
/**
|
||||
* @brief Checks whether device supports Export & Import functionality of network
|
||||
*
|
||||
* @param deviceName - A name of a device to get a metric value.
|
||||
* @return True if device has IMPORT_EXPORT_SUPPORT metric in SUPPORTED_METRICS and
|
||||
* this metric returns 'true', False otherwise.
|
||||
*/
|
||||
virtual bool DeviceSupportsImportExport(const std::string& deviceName) const = 0;
|
||||
|
||||
/**
|
||||
* @brief Default virtual destructor
|
||||
*/
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
#include <mutex>
|
||||
#include <functional>
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
@@ -68,7 +69,24 @@ public:
|
||||
MOCK_QUALIFIED_METHOD0(getParams, const, ParamMap());
|
||||
};
|
||||
|
||||
class MockCachingInferencePlugin : public InferenceEngine::InferencePluginInternal {
|
||||
class MockCachingInferencePluginBase : public InferenceEngine::InferencePluginInternal {
|
||||
public:
|
||||
MockCachingInferencePluginBase() = default;
|
||||
~MockCachingInferencePluginBase() = default;
|
||||
|
||||
IExecutableNetworkInternal::Ptr LoadNetwork(const std::string& modelPath,
|
||||
const std::map<std::string, std::string>& config) override {
|
||||
// In GTEST, it is not possible to call base implementation inside of mocked class
|
||||
// Thus, we define a proxy callback and will use
|
||||
// EXPECT_CALL(OnLoadNetworkFromFile) instead of EXPECT_CALL(LoadNetwork)
|
||||
OnLoadNetworkFromFile();
|
||||
return InferenceEngine::InferencePluginInternal::LoadNetwork(modelPath, config);
|
||||
}
|
||||
|
||||
virtual void OnLoadNetworkFromFile() const {}
|
||||
};
|
||||
|
||||
class MockCachingInferencePlugin : public MockCachingInferencePluginBase {
|
||||
public:
|
||||
MockCachingInferencePlugin() = default;
|
||||
~MockCachingInferencePlugin() = default;
|
||||
@@ -79,6 +97,8 @@ public:
|
||||
MOCK_METHOD3(LoadExeNetworkImpl, ExecutableNetworkInternal::Ptr(const CNNNetwork& network, RemoteContext::Ptr context,
|
||||
const std::map<std::string, std::string>& config));
|
||||
|
||||
MOCK_CONST_METHOD0(OnLoadNetworkFromFile, void(void));
|
||||
|
||||
MOCK_METHOD2(ImportNetworkImpl, ExecutableNetworkInternal::Ptr(std::istream& networkModel,
|
||||
const std::map<std::string, std::string>& config));
|
||||
|
||||
@@ -94,6 +114,8 @@ public:
|
||||
};
|
||||
|
||||
class MockExecutableNetwork : public ExecutableNetworkInternal {
|
||||
std::mutex m_pluginMutex;
|
||||
|
||||
public:
|
||||
MockExecutableNetwork() {}
|
||||
MOCK_METHOD1(ExportImpl, void(std::ostream& networkModel));
|
||||
@@ -103,6 +125,18 @@ public:
|
||||
MOCK_CONST_METHOD1(GetConfig, Parameter(const std::string& name));
|
||||
MOCK_CONST_METHOD1(GetMetric, Parameter(const std::string& name));
|
||||
MOCK_METHOD2(CreateInferRequestImpl, IInferRequestInternal::Ptr(InputsDataMap, OutputsDataMap));
|
||||
MOCK_METHOD1(setNetworkInputs, void(const InputsDataMap networkInputs));
|
||||
MOCK_METHOD1(setNetworkOutputs, void(const OutputsDataMap networkOutputs));
|
||||
|
||||
void Export(std::ostream& networkModel) override {
|
||||
std::lock_guard<std::mutex> guard(m_pluginMutex);
|
||||
ExecutableNetworkInternal::Export(networkModel);
|
||||
}
|
||||
|
||||
void SetPointerToPlugin(IInferencePlugin::Ptr plugin) override {
|
||||
std::lock_guard<std::mutex> guard(m_pluginMutex);
|
||||
ExecutableNetworkInternal::SetPointerToPlugin(plugin);
|
||||
}
|
||||
};
|
||||
|
||||
//------------------------------------------------------
|
||||
@@ -139,7 +173,7 @@ public:
|
||||
std::unique_ptr<MkDirGuard> m_dirCreator;
|
||||
TestLoadType m_type;
|
||||
std::string m_cacheDir;
|
||||
using LoadFunction = std::function<void(Core&)>;
|
||||
using LoadFunction = std::function<ExecutableNetwork(Core&)>;
|
||||
using LoadFunctionWithCfg = std::function<void(Core&, const std::map<std::string, std::string> &)>;
|
||||
LoadFunction m_testFunction;
|
||||
LoadFunctionWithCfg m_testFunctionWithCfg;
|
||||
@@ -208,11 +242,11 @@ public:
|
||||
LoadFunction getLoadFunction(TestLoadType type) const {
|
||||
switch (type) {
|
||||
case TestLoadType::ECNN:
|
||||
return [&](Core& ie) { performReadAndLoad(ie); };
|
||||
return [&](Core& ie) { return performReadAndLoad(ie); };
|
||||
case TestLoadType::EContext:
|
||||
return [&](Core& ie) { performReadAndLoadWithContext(ie); };
|
||||
return [&](Core& ie) { return performReadAndLoadWithContext(ie); };
|
||||
case TestLoadType::EModelName:
|
||||
return [&](Core& ie) { performLoadByName(ie); };
|
||||
return [&](Core& ie) { return performLoadByName(ie); };
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
@@ -229,22 +263,22 @@ public:
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void performLoadByName(Core& ie, const std::map<std::string, std::string>& config = {}) const {
|
||||
ie.LoadNetwork(modelName, deviceToLoad, config);
|
||||
ExecutableNetwork performLoadByName(Core& ie, const std::map<std::string, std::string>& config = {}) const {
|
||||
return ie.LoadNetwork(modelName, deviceToLoad, config);
|
||||
}
|
||||
|
||||
void performReadAndLoad(Core& ie, const std::map<std::string, std::string>& config = {}) const {
|
||||
ExecutableNetwork performReadAndLoad(Core& ie, const std::map<std::string, std::string>& config = {}) const {
|
||||
auto cnnNetwork = ie.ReadNetwork(modelName);
|
||||
if (m_cnnCallback) m_cnnCallback(cnnNetwork);
|
||||
ie.LoadNetwork(cnnNetwork, deviceToLoad, config);
|
||||
return ie.LoadNetwork(cnnNetwork, deviceToLoad, config);
|
||||
}
|
||||
|
||||
void performReadAndLoadWithContext(Core& ie, const std::map<std::string, std::string>& config = {}) const {
|
||||
ExecutableNetwork performReadAndLoadWithContext(Core& ie, const std::map<std::string, std::string>& config = {}) const {
|
||||
auto cnnNetwork = ie.ReadNetwork(modelName);
|
||||
EXPECT_CALL(*mockPlugin, GetDefaultContext(_)).Times(AnyNumber());
|
||||
auto context = ie.GetDefaultContext(deviceToLoad);
|
||||
if (m_cnnCallback) m_cnnCallback(cnnNetwork);
|
||||
ie.LoadNetwork(cnnNetwork, context, config);
|
||||
return ie.LoadNetwork(cnnNetwork, context, config);
|
||||
}
|
||||
|
||||
std::shared_ptr<MockExecutableNetwork> createMockIExecutableNet() {
|
||||
@@ -350,6 +384,8 @@ private:
|
||||
EXPECT_CALL(*inferReq, SetCallback(_)).Times(AnyNumber());
|
||||
return inferReq;
|
||||
}));
|
||||
EXPECT_CALL(*net, setNetworkInputs(_)).Times(AnyNumber());
|
||||
EXPECT_CALL(*net, setNetworkOutputs(_)).Times(AnyNumber());
|
||||
}
|
||||
};
|
||||
|
||||
@@ -500,6 +536,7 @@ TEST_P(CachingTest, TestNoCacheSupported) {
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||
@@ -518,6 +555,7 @@ TEST_P(CachingTest, TestNoCacheMetricSupported) {
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 1 : 0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||
@@ -1308,6 +1346,8 @@ TEST_P(CachingTest, LoadMulti_Archs) {
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(2);
|
||||
// Load network from file shall not be called for plugins with caching supported
|
||||
EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(0);
|
||||
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(TEST_DEVICE_MAX_COUNT - 2)
|
||||
@@ -1322,6 +1362,54 @@ TEST_P(CachingTest, LoadMulti_Archs) {
|
||||
}
|
||||
}
|
||||
|
||||
// MULTI-DEVICE test
|
||||
// Test loading of devices which don't support caching
|
||||
TEST_P(CachingTest, LoadMulti_NoCachingOnDevice) {
|
||||
const auto TEST_DEVICE_MAX_COUNT = 100; // Looks enough to catch potential race conditions
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _))
|
||||
.Times(AnyNumber()).WillRepeatedly(Return(false));
|
||||
EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber());
|
||||
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||
DataPtr inData = std::make_shared<Data>("in", Precision::FP32);
|
||||
InputInfo inpInfo;
|
||||
inpInfo.setInputData(inData);
|
||||
InputInfo::CPtr cptr = std::make_shared<InputInfo>(inpInfo);
|
||||
ConstInputsDataMap inputMap {{"Input1", cptr}};
|
||||
CDataPtr dataptr = std::make_shared<Data>("out", Precision::FP32);
|
||||
ConstOutputsDataMap outputMap {{"Output1", dataptr}};
|
||||
EXPECT_CALL(*net, GetInputsInfo()).Times(AnyNumber()).WillRepeatedly(Return(inputMap));
|
||||
EXPECT_CALL(*net, GetOutputsInfo()).Times(AnyNumber()).WillRepeatedly(Return(outputMap));
|
||||
if (m_remoteContext) {
|
||||
return; // skip the remote Context test for Multi plugin
|
||||
}
|
||||
|
||||
deviceToLoad = CommonTestUtils::DEVICE_MULTI;
|
||||
deviceToLoad += ":mock.0";
|
||||
for (int i = 1; i < TEST_DEVICE_MAX_COUNT; i++) {
|
||||
deviceToLoad += ",mock." + std::to_string(i);
|
||||
}
|
||||
|
||||
{
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(TEST_DEVICE_MAX_COUNT);
|
||||
// Load network from file shall not be called by Multi plugin for devices with caching supported
|
||||
EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(0);
|
||||
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||
testLoad([&](Core &ie) {
|
||||
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||
ExecutableNetwork exeNet;
|
||||
ASSERT_NO_THROW(exeNet = m_testFunction(ie));
|
||||
// Verify that inputs and outputs are set for Multi Executable Network
|
||||
ASSERT_EQ(exeNet.GetInputsInfo().size(), inputMap.size());
|
||||
ASSERT_EQ(exeNet.GetOutputsInfo().size(), outputMap.size());
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(CachingTest, CachingTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(loadVariants),
|
||||
|
||||
@@ -31,6 +31,7 @@ public:
|
||||
|
||||
MOCK_QUALIFIED_METHOD2(GetMetric, const, InferenceEngine::Parameter(const std::string&, const std::string&));
|
||||
MOCK_QUALIFIED_METHOD0(GetAvailableDevices, const, std::vector<std::string>());
|
||||
MOCK_QUALIFIED_METHOD1(DeviceSupportsImportExport, const, bool(const std::string&)); // NOLINT not a cast to bool
|
||||
|
||||
~MockICore() = default;
|
||||
};
|
||||
|
||||
@@ -49,6 +49,16 @@ MockPlugin::LoadNetwork(const CNNNetwork& network, const std::map<std::string, s
|
||||
}
|
||||
}
|
||||
|
||||
InferenceEngine::IExecutableNetworkInternal::Ptr
|
||||
MockPlugin::LoadNetwork(const std::string &modelPath,
|
||||
const std::map<std::string, std::string> &config) {
|
||||
if (_target) {
|
||||
return _target->LoadNetwork(modelPath, config);
|
||||
} else {
|
||||
return InferenceEngine::InferencePluginInternal::LoadNetwork(modelPath, config);
|
||||
}
|
||||
}
|
||||
|
||||
ExecutableNetworkInternal::Ptr
|
||||
MockPlugin::LoadExeNetworkImpl(const CNNNetwork& network,
|
||||
const std::map<std::string, std::string>& config) {
|
||||
@@ -94,6 +104,27 @@ MockPlugin::QueryNetwork(const InferenceEngine::CNNNetwork& network,
|
||||
}
|
||||
}
|
||||
|
||||
void MockPlugin::SetCore(InferenceEngine::ICore* core) noexcept {
|
||||
if (_target) {
|
||||
_target->SetCore(core);
|
||||
}
|
||||
InferenceEngine::InferencePluginInternal::SetCore(core);
|
||||
}
|
||||
|
||||
void MockPlugin::SetName(const std::string& name) noexcept {
|
||||
if (_target) {
|
||||
_target->SetName(name);
|
||||
}
|
||||
InferenceEngine::InferencePluginInternal::SetName(name);
|
||||
}
|
||||
|
||||
std::string MockPlugin::GetName() const noexcept {
|
||||
if (_target) {
|
||||
return _target->GetName();
|
||||
}
|
||||
return InferenceEngine::InferencePluginInternal::GetName();
|
||||
}
|
||||
|
||||
|
||||
InferenceEngine::IInferencePlugin *__target = nullptr;
|
||||
|
||||
|
||||
@@ -30,6 +30,10 @@ public:
|
||||
LoadExeNetworkImpl(const InferenceEngine::CNNNetwork& network,
|
||||
const std::map<std::string, std::string>& config) override;
|
||||
|
||||
InferenceEngine::IExecutableNetworkInternal::Ptr
|
||||
LoadNetwork(const std::string &modelPath,
|
||||
const std::map<std::string, std::string> &config) override;
|
||||
|
||||
std::shared_ptr<InferenceEngine::ExecutableNetworkInternal>
|
||||
ImportNetworkImpl(std::istream& networkModel,
|
||||
const std::map<std::string, std::string>& config) override;
|
||||
@@ -47,5 +51,11 @@ public:
|
||||
InferenceEngine::QueryNetworkResult QueryNetwork(const InferenceEngine::CNNNetwork& network,
|
||||
const std::map<std::string, std::string>& config) const override;
|
||||
|
||||
void SetCore(InferenceEngine::ICore* core) noexcept override;
|
||||
|
||||
void SetName(const std::string& name) noexcept override;
|
||||
|
||||
std::string GetName() const noexcept override;
|
||||
|
||||
std::map<std::string, std::string> config;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user