Added cloneNetwork method into plugins api (#3450)
* Added cloneNetwork method into plugins api * Fixed cloneNetwork call in MKLDNN and GNA plugins to pick correct function * Changed return type
This commit is contained in:
parent
2d75d8aff2
commit
325a0a4f5e
@ -436,7 +436,7 @@ void GNAPlugin::UpdateInputScaleFromNetwork(InferenceEngine::ICNNNetwork & netwo
|
||||
void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
|
||||
std::shared_ptr<InferenceEngine::details::CNNNetworkImpl> convertedNetwork;
|
||||
if (_network.getFunction()) {
|
||||
std::shared_ptr<ICNNNetwork> clonedNetwork = cloneNetwork(_network);
|
||||
std::shared_ptr<ICNNNetwork> clonedNetwork = InferenceEngine::cloneNetwork(_network);
|
||||
const auto& graph = clonedNetwork->getFunction();
|
||||
// Disable shape inference (WA for generic operations)
|
||||
ngraph::op::GenericIE::DisableReshape noReshape(graph);
|
||||
|
@ -36,7 +36,7 @@ public:
|
||||
updated_config.UpdateFromMap(config);
|
||||
auto plg = std::make_shared<GNAPlugin>(updated_config.key_config_map);
|
||||
plgPtr = plg;
|
||||
InferenceEngine::CNNNetwork clonedNetwork(cloneNetwork(network));
|
||||
InferenceEngine::CNNNetwork clonedNetwork(InferenceEngine::cloneNetwork(network));
|
||||
return std::make_shared<GNAExecutableNetwork>(clonedNetwork, plg);
|
||||
}
|
||||
|
||||
|
23
inference-engine/src/inference_engine/ie_ngraph_utils.cpp
Normal file
23
inference-engine/src/inference_engine/ie_ngraph_utils.cpp
Normal file
@ -0,0 +1,23 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <ie_ngraph_utils.hpp>
|
||||
#include "cnn_network_ngraph_impl.hpp"
|
||||
#include "ie_itt.hpp"
|
||||
|
||||
namespace InferenceEngine {
|
||||
namespace details {
|
||||
|
||||
CNNNetwork cloneNetwork(const CNNNetwork& network) {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::IE, "cloneNetwork");
|
||||
|
||||
if (network.getFunction()) {
|
||||
return CNNNetwork(std::make_shared<details::CNNNetworkNGraphImpl>(network));
|
||||
}
|
||||
|
||||
THROW_IE_EXCEPTION << "InferenceEngine::details::cloneNetwork requires ngraph-based `network` object to clone";
|
||||
}
|
||||
|
||||
} // namespace details
|
||||
} // namespace InferenceEngine
|
@ -298,7 +298,7 @@ Engine::LoadExeNetworkImpl(const InferenceEngine::CNNNetwork &network, const std
|
||||
conf.batchLimit = static_cast<int>(network.getBatchSize());
|
||||
}
|
||||
|
||||
std::shared_ptr<ICNNNetwork> clonedNetwork = cloneNetwork(network);
|
||||
std::shared_ptr<ICNNNetwork> clonedNetwork = InferenceEngine::cloneNetwork(network);
|
||||
|
||||
bool is_transformed = false;
|
||||
if (clonedNetwork->getFunction()) {
|
||||
@ -437,7 +437,7 @@ QueryNetworkResult Engine::QueryNetwork(const CNNNetwork& network, const std::ma
|
||||
conf.batchLimit = static_cast<int>(network.getBatchSize());
|
||||
}
|
||||
|
||||
auto clonedNetwork = cloneNetwork(network);
|
||||
auto clonedNetwork = InferenceEngine::cloneNetwork(network);
|
||||
Transformation(clonedNetwork, conf);
|
||||
std::unordered_set<std::string> supported;
|
||||
std::unordered_set<std::string> unsupported;
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include <ngraph/type/element_type.hpp>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <cpp/ie_cnn_network.h>
|
||||
|
||||
namespace InferenceEngine {
|
||||
namespace details {
|
||||
@ -126,5 +127,13 @@ inline Precision convertPrecision(const ::ngraph::element::Type& precision) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Clones input network including all layers and internal data objects
|
||||
* @note Blobs inside layers are reused
|
||||
* @param network A network to clone
|
||||
* @return A cloned object
|
||||
*/
|
||||
INFERENCE_ENGINE_API_CPP(CNNNetwork) cloneNetwork(const CNNNetwork& network);
|
||||
|
||||
} // namespace details
|
||||
} // namespace InferenceEngine
|
||||
|
Loading…
Reference in New Issue
Block a user