From 325a0a4f5e9f0024b321dc3eb4806d8436763c29 Mon Sep 17 00:00:00 2001 From: Vladimir Paramuzov Date: Thu, 3 Dec 2020 17:53:15 +0300 Subject: [PATCH] Added cloneNetwork method into plugins api (#3450) * Added cloneNetwork method into plugins api * Fixed cloneNetwork call in MKLDNN and GNA plugins to pick correct function * Changed return type --- .../src/gna_plugin/gna_plugin.cpp | 2 +- .../src/gna_plugin/gna_plugin_internal.hpp | 2 +- .../src/inference_engine/ie_ngraph_utils.cpp | 23 +++++++++++++++++++ .../src/mkldnn_plugin/mkldnn_plugin.cpp | 4 ++-- .../src/plugin_api/ie_ngraph_utils.hpp | 9 ++++++++ 5 files changed, 36 insertions(+), 4 deletions(-) create mode 100644 inference-engine/src/inference_engine/ie_ngraph_utils.cpp diff --git a/inference-engine/src/gna_plugin/gna_plugin.cpp b/inference-engine/src/gna_plugin/gna_plugin.cpp index b8f8be8d793..a724cefc376 100644 --- a/inference-engine/src/gna_plugin/gna_plugin.cpp +++ b/inference-engine/src/gna_plugin/gna_plugin.cpp @@ -436,7 +436,7 @@ void GNAPlugin::UpdateInputScaleFromNetwork(InferenceEngine::ICNNNetwork & netwo void GNAPlugin::LoadNetwork(CNNNetwork & _network) { std::shared_ptr convertedNetwork; if (_network.getFunction()) { - std::shared_ptr clonedNetwork = cloneNetwork(_network); + std::shared_ptr clonedNetwork = InferenceEngine::cloneNetwork(_network); const auto& graph = clonedNetwork->getFunction(); // Disable shape inference (WA for generic operations) ngraph::op::GenericIE::DisableReshape noReshape(graph); diff --git a/inference-engine/src/gna_plugin/gna_plugin_internal.hpp b/inference-engine/src/gna_plugin/gna_plugin_internal.hpp index 0b3e80e921e..08f1efb8bae 100644 --- a/inference-engine/src/gna_plugin/gna_plugin_internal.hpp +++ b/inference-engine/src/gna_plugin/gna_plugin_internal.hpp @@ -36,7 +36,7 @@ public: updated_config.UpdateFromMap(config); auto plg = std::make_shared(updated_config.key_config_map); plgPtr = plg; - InferenceEngine::CNNNetwork clonedNetwork(cloneNetwork(network)); + InferenceEngine::CNNNetwork clonedNetwork(InferenceEngine::cloneNetwork(network)); return std::make_shared(clonedNetwork, plg); } diff --git a/inference-engine/src/inference_engine/ie_ngraph_utils.cpp b/inference-engine/src/inference_engine/ie_ngraph_utils.cpp new file mode 100644 index 00000000000..1663b35c012 --- /dev/null +++ b/inference-engine/src/inference_engine/ie_ngraph_utils.cpp @@ -0,0 +1,23 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include "cnn_network_ngraph_impl.hpp" +#include "ie_itt.hpp" + +namespace InferenceEngine { +namespace details { + +CNNNetwork cloneNetwork(const CNNNetwork& network) { + OV_ITT_SCOPED_TASK(itt::domains::IE, "cloneNetwork"); + + if (network.getFunction()) { + return CNNNetwork(std::make_shared(network)); + } + + THROW_IE_EXCEPTION << "InferenceEngine::details::cloneNetwork requires ngraph-based `network` object to clone"; +} + +} // namespace details +} // namespace InferenceEngine diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp index d5d8d8316fd..1f4553a3ca3 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp @@ -298,7 +298,7 @@ Engine::LoadExeNetworkImpl(const InferenceEngine::CNNNetwork &network, const std conf.batchLimit = static_cast(network.getBatchSize()); } - std::shared_ptr clonedNetwork = cloneNetwork(network); + std::shared_ptr clonedNetwork = InferenceEngine::cloneNetwork(network); bool is_transformed = false; if (clonedNetwork->getFunction()) { @@ -437,7 +437,7 @@ QueryNetworkResult Engine::QueryNetwork(const CNNNetwork& network, const std::ma conf.batchLimit = static_cast(network.getBatchSize()); } - auto clonedNetwork = cloneNetwork(network); + auto clonedNetwork = InferenceEngine::cloneNetwork(network); Transformation(clonedNetwork, conf); std::unordered_set supported; std::unordered_set unsupported; diff --git a/inference-engine/src/plugin_api/ie_ngraph_utils.hpp b/inference-engine/src/plugin_api/ie_ngraph_utils.hpp index 3a05fcbcfe5..22cd621cd47 100644 --- a/inference-engine/src/plugin_api/ie_ngraph_utils.hpp +++ b/inference-engine/src/plugin_api/ie_ngraph_utils.hpp @@ -8,6 +8,7 @@ #include #include #include +#include namespace InferenceEngine { namespace details { @@ -126,5 +127,13 @@ inline Precision convertPrecision(const ::ngraph::element::Type& precision) { } } +/** + * @brief Clones input network including all layers and internal data objects + * @note Blobs inside layers are reused + * @param network A network to clone + * @return A cloned object + */ +INFERENCE_ENGINE_API_CPP(CNNNetwork) cloneNetwork(const CNNNetwork& network); + } // namespace details } // namespace InferenceEngine