Added cloneNetwork method into plugins api (#3450)

* Added cloneNetwork method into plugins api

* Fixed cloneNetwork call in MKLDNN and GNA plugins to pick correct function

* Changed return type
Vladimir Paramuzov 2020-12-03 17:53:15 +03:00 committed by GitHub
parent 2d75d8aff2
commit 325a0a4f5e
5 changed files with 36 additions and 4 deletions


@@ -436,7 +436,7 @@ void GNAPlugin::UpdateInputScaleFromNetwork(InferenceEngine::ICNNNetwork & netwo
 void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
     std::shared_ptr<InferenceEngine::details::CNNNetworkImpl> convertedNetwork;
     if (_network.getFunction()) {
-        std::shared_ptr<ICNNNetwork> clonedNetwork = cloneNetwork(_network);
+        std::shared_ptr<ICNNNetwork> clonedNetwork = InferenceEngine::cloneNetwork(_network);
         const auto& graph = clonedNetwork->getFunction();
         // Disable shape inference (WA for generic operations)
         ngraph::op::GenericIE::DisableReshape noReshape(graph);


@@ -36,7 +36,7 @@ public:
         updated_config.UpdateFromMap(config);
         auto plg = std::make_shared<GNAPlugin>(updated_config.key_config_map);
         plgPtr = plg;
-        InferenceEngine::CNNNetwork clonedNetwork(cloneNetwork(network));
+        InferenceEngine::CNNNetwork clonedNetwork(InferenceEngine::cloneNetwork(network));
         return std::make_shared<GNAExecutableNetwork>(clonedNetwork, plg);
     }


@@ -0,0 +1,23 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <ie_ngraph_utils.hpp>
+#include "cnn_network_ngraph_impl.hpp"
+#include "ie_itt.hpp"
+
+namespace InferenceEngine {
+namespace details {
+
+CNNNetwork cloneNetwork(const CNNNetwork& network) {
+    OV_ITT_SCOPED_TASK(itt::domains::IE, "cloneNetwork");
+
+    if (network.getFunction()) {
+        return CNNNetwork(std::make_shared<details::CNNNetworkNGraphImpl>(network));
+    }
+
+    THROW_IE_EXCEPTION << "InferenceEngine::details::cloneNetwork requires ngraph-based `network` object to clone";
+}
+
+}  // namespace details
+}  // namespace InferenceEngine
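
The independence of the clone from the source network comes from CNNNetworkNGraphImpl presumably copying the underlying ngraph::Function when it is constructed from a CNNNetwork (that constructor is not part of this diff). A standalone sketch of the same idea using the public ngraph::clone_function helper, not the commit's internal code:

#include <ngraph/function.hpp>
#include <ngraph/graph_util.hpp>  // ngraph::clone_function

// Sketch only: deep-copy an ngraph function so that transformation passes run
// on the copy leave the original graph untouched.
std::shared_ptr<ngraph::Function> deepCopy(const ngraph::Function& original) {
    return ngraph::clone_function(original);
}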


@@ -298,7 +298,7 @@ Engine::LoadExeNetworkImpl(const InferenceEngine::CNNNetwork &network, const std
         conf.batchLimit = static_cast<int>(network.getBatchSize());
     }
 
-    std::shared_ptr<ICNNNetwork> clonedNetwork = cloneNetwork(network);
+    std::shared_ptr<ICNNNetwork> clonedNetwork = InferenceEngine::cloneNetwork(network);
     bool is_transformed = false;
     if (clonedNetwork->getFunction()) {
@@ -437,7 +437,7 @@ QueryNetworkResult Engine::QueryNetwork(const CNNNetwork& network, const std::ma
         conf.batchLimit = static_cast<int>(network.getBatchSize());
     }
 
-    auto clonedNetwork = cloneNetwork(network);
+    auto clonedNetwork = InferenceEngine::cloneNetwork(network);
     Transformation(clonedNetwork, conf);
     std::unordered_set<std::string> supported;
     std::unordered_set<std::string> unsupported;


@@ -8,6 +8,7 @@
 #include <ngraph/type/element_type.hpp>
 #include <string>
 #include <algorithm>
+#include <cpp/ie_cnn_network.h>
 
 namespace InferenceEngine {
 namespace details {
@@ -126,5 +127,13 @@ inline Precision convertPrecision(const ::ngraph::element::Type& precision) {
     }
 }
 
+/**
+ * @brief Clones input network including all layers and internal data objects
+ * @note Blobs inside layers are reused
+ * @param network A network to clone
+ * @return A cloned object
+ */
+INFERENCE_ENGINE_API_CPP(CNNNetwork) cloneNetwork(const CNNNetwork& network);
+
 }  // namespace details
 }  // namespace InferenceEngine
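
A minimal usage sketch of the helper declared above (not part of the commit; the Core instance, model path, and plugin-side context are assumptions):

#include <ie_core.hpp>
#include <ie_ngraph_utils.hpp>

void cloneExample(InferenceEngine::Core& core) {
    // Read an ngraph-based (IR v10) network; the file name is a placeholder.
    InferenceEngine::CNNNetwork original = core.ReadNetwork("model.xml");

    // cloneNetwork throws for legacy (non-ngraph) networks, so code that may
    // still receive such networks checks getFunction() first.
    if (original.getFunction()) {
        InferenceEngine::CNNNetwork copy = InferenceEngine::cloneNetwork(original);
        // Passes applied to `copy` do not affect `original`, which is why the
        // plugins above clone the network before running their transformations.
    }
}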