Add LoadNetwork(modelPath) to plugin interface (#5606)

Author: Mikhail Nosov, 2021-05-12 21:43:35 +03:00, committed by GitHub
parent f2f44ce160
commit 4d7eeede35
11 changed files with 53 additions and 10 deletions

View File

@@ -110,6 +110,8 @@ class GNAPlugin : public InferenceEngine::IInferencePlugin {
     InferenceEngine::IExecutableNetworkInternal::Ptr LoadNetwork(const InferenceEngine::CNNNetwork &network,
                                                                  const std::map<std::string, std::string> &config_map,
                                                                  InferenceEngine::RemoteContext::Ptr context) override { THROW_GNA_EXCEPTION << "Not implemented"; }
+    InferenceEngine::ExecutableNetwork LoadNetwork(const std::string &modelPath,
+                                                   const std::map<std::string, std::string> &config_map) override { THROW_GNA_EXCEPTION << "Not implemented"; }
     bool Infer(const InferenceEngine::Blob &input, InferenceEngine::Blob &result);
     void SetCore(InferenceEngine::ICore*) noexcept override {}
     InferenceEngine::ICore* GetCore() const noexcept override {return nullptr;}

View File

@@ -493,9 +493,8 @@ public:
         return res;
     }

-    // TODO: In future this method can be added to ICore interface
     ExecutableNetwork LoadNetwork(const std::string& modelPath, const std::string& deviceName,
-                                  const std::map<std::string, std::string>& config) {
+                                  const std::map<std::string, std::string>& config) override {
         OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::IE_LT, "Core::LoadNetwork::Path");
         auto parsed = parseDeviceNameIntoConfig(deviceName, config);
         auto plugin = GetCPPPluginByName(parsed._deviceName);
@@ -511,6 +510,8 @@ public:
                 auto cnnNetwork = ReadNetwork(modelPath, std::string());
                 res = LoadNetworkImpl(cnnNetwork, plugin, parsed._config, nullptr, hash, modelPath);
             }
+        } else if (cacheManager) {
+            res = plugin.LoadNetwork(modelPath, parsed._config);
         } else {
             auto cnnNetwork = ReadNetwork(modelPath, std::string());
             res = LoadNetworkImpl(cnnNetwork, plugin, parsed._config, nullptr, {}, modelPath);
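Taken together, these two hunks give Core::Impl::LoadNetwork a three-way dispatch on a model path. A standalone sketch of that order, using stub types rather than the Inference Engine classes (the real code also computes a cache hash and falls back when an import fails):

    #include <iostream>
    #include <string>

    struct StubPlugin {
        bool supportsImportExport = false;
        void loadFromPath(const std::string& path) {
            std::cout << "plugin compiles directly from " << path << "\n";
        }
    };

    // 1) cache on + device can import/export -> serve from the model cache;
    // 2) cache on, no import/export          -> new in this commit: hand the
    //    raw path to the plugin, which may resolve (and cache) it itself;
    // 3) no cache                            -> ReadNetwork + compile, as before.
    void loadNetworkByPath(const std::string& modelPath, StubPlugin& plugin, bool cacheEnabled) {
        if (cacheEnabled && plugin.supportsImportExport) {
            std::cout << "import cached blob, or compile once and export to cache\n";
        } else if (cacheEnabled) {
            plugin.loadFromPath(modelPath);  // the branch this commit adds
        } else {
            std::cout << "ReadNetwork(" << modelPath << ") then LoadNetworkImpl\n";
        }
    }

    int main() {
        StubPlugin plugin;
        loadNetworkByPath("model.xml", plugin, /*cacheEnabled=*/true);
    }

The middle branch is the one this commit enables end to end: with a cache manager present but no device import/export support, the Core now hands the raw path to the plugin instead of always parsing the file into a CNNNetwork first.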

View File

@@ -88,6 +88,10 @@ public:
         PLUGIN_CALL_STATEMENT(return ExecutableNetwork(actual->LoadNetwork(network, config, context), actual));
     }

+    ExecutableNetwork LoadNetwork(const std::string& modelPath, const std::map<std::string, std::string>& config) {
+        PLUGIN_CALL_STATEMENT(return actual->LoadNetwork(modelPath, config));
+    }
+
     QueryNetworkResult QueryNetwork(const CNNNetwork& network,
                                     const std::map<std::string, std::string>& config) const {
         QueryNetworkResult res;

View File

@@ -72,6 +72,12 @@ public:
         return impl;
     }

+    ExecutableNetwork LoadNetwork(const std::string& modelPath,
+                                  const std::map<std::string, std::string>& config) override {
+        auto cnnNet = GetCore()->ReadNetwork(modelPath, std::string());
+        return GetCore()->LoadNetwork(cnnNet, GetName(), config);
+    }
+
     IExecutableNetworkInternal::Ptr ImportNetwork(const std::string& modelFileName,
                                                   const std::map<std::string, std::string>& config) override {
         (void)modelFileName;
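This override is the compatibility piece of the change: since the method added to the plugin interface (next file) is pure virtual, this common base supplies the obvious fallback, reading the file and then loading the resulting CNNNetwork through the core, so a plugin only needs code of its own here when it can do something smarter with a path; the GNA stub in the first hunk instead opts out by throwing.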

View File

@@ -166,6 +166,16 @@ public:
     virtual std::shared_ptr<IExecutableNetworkInternal> LoadNetwork(const CNNNetwork& network,
                                                                     const std::map<std::string, std::string>& config,
                                                                     RemoteContext::Ptr context) = 0;
+
+    /**
+     * @brief Creates an executable network from model file path
+     * @param modelPath A path to model
+     * @param config A string-string map of config parameters relevant only for this load operation
+     * @return Created Executable Network object
+     */
+    virtual ExecutableNetwork LoadNetwork(const std::string& modelPath,
+                                          const std::map<std::string, std::string>& config) = 0;
+
     /**
      * @brief Registers extension within plugin
      * @param extension - pointer to already loaded extension
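A device whose compiler can consume the model file directly could override this new entry point to skip the CNNNetwork round-trip. A hypothetical fragment, not from this commit: MyPathAwarePlugin and compileFromFile are invented names, and the remaining IInferencePlugin members are elided:

    class MyPathAwarePlugin : public InferenceEngine::IInferencePlugin {
    public:
        InferenceEngine::ExecutableNetwork LoadNetwork(
                const std::string& modelPath,
                const std::map<std::string, std::string>& config) override {
            // Hypothetical device-specific fast path: compile straight from
            // the file instead of going through ReadNetwork + CNNNetwork.
            return compileFromFile(modelPath, config);
        }
        // ... other IInferencePlugin members elided ...
    };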

View File

@@ -66,6 +66,21 @@
     virtual ExecutableNetwork LoadNetwork(const CNNNetwork& network, const std::string& deviceName,
                                           const std::map<std::string, std::string>& config = {}) = 0;

+    /**
+     * @brief Creates an executable network from a model file.
+     *
+     * Users can create as many networks as they need and use
+     * them simultaneously (up to the limitation of the hardware resources)
+     *
+     * @param modelPath Path to model
+     * @param deviceName Name of device to load network to
+     * @param config Optional map of pairs: (config parameter name, config parameter value) relevant only for this load
+     * operation
+     * @return An executable network reference
+     */
+    virtual ExecutableNetwork LoadNetwork(const std::string& modelPath, const std::string& deviceName,
+                                          const std::map<std::string, std::string>& config) = 0;
+
     /**
      * @brief Creates an executable network from a previously exported network
      * @param networkModel network model stream
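This mirrors the path-based overload already exposed on the public InferenceEngine::Core (the ie_core.cpp hunk above just marks its implementation as an override). A minimal usage sketch, assuming an IR file model.xml and an available CPU device:

    #include <inference_engine.hpp>

    int main() {
        InferenceEngine::Core core;
        // Loading by path lets the runtime skip building an intermediate
        // CNNNetwork when the model cache or a path-aware plugin can take
        // the file directly.
        auto exeNet = core.LoadNetwork("model.xml", "CPU", {});
        auto request = exeNet.CreateInferRequest();
        (void)request;  // run inference as usual from here
        return 0;
    }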

View File

@@ -18,6 +18,8 @@ public:
         const InferenceEngine::CNNNetwork&, const std::string&, const std::map<std::string, std::string>&));
     MOCK_METHOD3(LoadNetwork, InferenceEngine::ExecutableNetwork(
         const InferenceEngine::CNNNetwork&, const InferenceEngine::RemoteContext::Ptr &, const std::map<std::string, std::string>&));
+    MOCK_METHOD3(LoadNetwork, InferenceEngine::ExecutableNetwork(
+        const std::string &, const std::string &, const std::map<std::string, std::string>&));
     MOCK_METHOD3(ImportNetwork, InferenceEngine::ExecutableNetwork(
         std::istream&, const std::string&, const std::map<std::string, std::string>&));

View File

@@ -15,6 +15,8 @@ public:
     MOCK_METHOD1(AddExtension, void(InferenceEngine::IExtensionPtr));
     MOCK_METHOD2(LoadNetwork, std::shared_ptr<InferenceEngine::IExecutableNetworkInternal>(
         const InferenceEngine::CNNNetwork&, const std::map<std::string, std::string>&));
+    MOCK_METHOD2(LoadNetwork, InferenceEngine::ExecutableNetwork(
+        const std::string&, const std::map<std::string, std::string>&));
     MOCK_METHOD2(ImportNetwork, std::shared_ptr<InferenceEngine::IExecutableNetworkInternal>(
         const std::string&, const std::map<std::string, std::string>&));
     MOCK_METHOD1(SetConfig, void(const std::map<std::string, std::string> &));

View File

@@ -199,9 +199,9 @@ protected:
         mockIExeNet = std::make_shared<MockIExecutableNetworkInternal>();
         ON_CALL(*mockIExeNet, CreateInferRequest()).WillByDefault(Return(mock_request));
         std::unique_ptr<MockIInferencePlugin> mockIPluginPtr{new MockIInferencePlugin};
-        ON_CALL(*mockIPluginPtr, LoadNetwork(_, _)).WillByDefault(Return(mockIExeNet));
+        ON_CALL(*mockIPluginPtr, LoadNetwork(MatcherCast<const CNNNetwork&>(_), _)).WillByDefault(Return(mockIExeNet));
         plugin = InferenceEngine::InferencePlugin{InferenceEngine::details::SOPointer<MockIInferencePlugin>{mockIPluginPtr.release()}};
-        exeNetwork = plugin.LoadNetwork({}, {});
+        exeNetwork = plugin.LoadNetwork(CNNNetwork{}, {});
         request = exeNetwork.CreateInferRequest();
         _incorrectName = "incorrect_name";
         _inputName = MockNotEmptyICNNNetwork::INPUT_BLOB_NAME;
@@ -223,9 +223,9 @@ protected:
         auto mockIExeNet = std::make_shared<MockIExecutableNetworkInternal>();
         ON_CALL(*mockIExeNet, CreateInferRequest()).WillByDefault(Return(mockInferRequestInternal));
         std::unique_ptr<MockIInferencePlugin> mockIPluginPtr{new MockIInferencePlugin};
-        ON_CALL(*mockIPluginPtr, LoadNetwork(_, _)).WillByDefault(Return(mockIExeNet));
+        ON_CALL(*mockIPluginPtr, LoadNetwork(MatcherCast<const CNNNetwork&>(_), _)).WillByDefault(Return(mockIExeNet));
         auto plugin = InferenceEngine::InferencePlugin{InferenceEngine::details::SOPointer<MockIInferencePlugin>{mockIPluginPtr.release()}};
-        auto exeNetwork = plugin.LoadNetwork({}, {});
+        auto exeNetwork = plugin.LoadNetwork(CNNNetwork{}, {});
         return exeNetwork.CreateInferRequest();
     }

View File

@@ -36,9 +36,9 @@ class VariableStateTests : public ::testing::Test {
         mockVariableStateInternal = make_shared<MockIVariableStateInternal>();
         ON_CALL(*mockExeNetworkInternal, CreateInferRequest()).WillByDefault(Return(mockInferRequestInternal));
         std::unique_ptr<MockIInferencePlugin> mockIPluginPtr{new MockIInferencePlugin};
-        ON_CALL(*mockIPluginPtr, LoadNetwork(_, _)).WillByDefault(Return(mockExeNetworkInternal));
+        ON_CALL(*mockIPluginPtr, LoadNetwork(MatcherCast<const CNNNetwork&>(_), _)).WillByDefault(Return(mockExeNetworkInternal));
         plugin = InferenceEngine::InferencePlugin{InferenceEngine::details::SOPointer<MockIInferencePlugin>{mockIPluginPtr.release()}};
-        net = plugin.LoadNetwork({}, {});
+        net = plugin.LoadNetwork(CNNNetwork{}, {});
         req = net.CreateInferRequest();
     }
 };

View File

@@ -20,6 +20,7 @@
 #include "unit_test_utils/mocks/cpp_interfaces/interface/mock_iinference_plugin.hpp"

 using testing::_;
+using testing::MatcherCast;
 using testing::Throw;
 using testing::Ref;
 using testing::Return;
@@ -52,9 +53,9 @@ protected:
     virtual void SetUp() {
         mockIExeNet = std::make_shared<MockIExecutableNetworkInternal>();
         std::unique_ptr<MockIInferencePlugin> mockIPluginPtr{new MockIInferencePlugin};
-        ON_CALL(*mockIPluginPtr, LoadNetwork(_, _)).WillByDefault(Return(mockIExeNet));
+        ON_CALL(*mockIPluginPtr, LoadNetwork(MatcherCast<const CNNNetwork&>(_), _)).WillByDefault(Return(mockIExeNet));
         plugin = InferenceEngine::InferencePlugin{InferenceEngine::details::SOPointer<MockIInferencePlugin>{mockIPluginPtr.release()}};
-        exeNetwork = plugin.LoadNetwork({}, {});
+        exeNetwork = plugin.LoadNetwork(CNNNetwork{}, {});
     }
 };
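The fixture changes in the last three files are all the same fix: MockIInferencePlugin now has two LoadNetwork overloads of the same arity, so a bare wildcard in ON_CALL no longer selects a unique overload, and MatcherCast pins the matcher to the intended parameter type (likewise, the braced {} at the call sites becomes an explicit CNNNetwork{}). A self-contained illustration of the technique with toy types, not the IE mocks:

    #include <gmock/gmock.h>
    #include <map>
    #include <string>

    using testing::_;
    using testing::MatcherCast;
    using testing::Return;

    struct Network {};

    class MockPlugin {
    public:
        // Two overloads with the same arity: `_` alone would be ambiguous here.
        MOCK_METHOD2(LoadNetwork, int(const Network&, const std::map<std::string, std::string>&));
        MOCK_METHOD2(LoadNetwork, int(const std::string&, const std::map<std::string, std::string>&));
    };

    int main(int argc, char** argv) {
        testing::InitGoogleMock(&argc, argv);
        MockPlugin plugin;
        // MatcherCast fixes the parameter type, selecting exactly one overload.
        ON_CALL(plugin, LoadNetwork(MatcherCast<const Network&>(_), _)).WillByDefault(Return(1));
        ON_CALL(plugin, LoadNetwork(MatcherCast<const std::string&>(_), _)).WillByDefault(Return(2));
        return plugin.LoadNetwork(Network{}, {}) == 1 ? 0 : 1;
    }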