From 6c3b7ee8cab5b6d7cd3c82f73cae32163a8c9472 Mon Sep 17 00:00:00 2001 From: Ilya Churaev Date: Wed, 29 Jul 2020 19:30:59 +0300 Subject: [PATCH] Avoid redundant clone and reshape (#1376) * Avoid redundant clone and reshape * Removed some constructors * Fixed output precision --- inference-engine/include/cpp/ie_cnn_network.h | 4 +- .../hetero_executable_network.cpp | 6 +-- .../cnn_network_ngraph_impl.cpp | 37 ++++++++++++++++-- .../cnn_network_ngraph_impl.hpp | 1 + .../src/legacy_api/src/ie_util_internal.cpp | 27 ++----------- .../inference_engine/cnn_network_test.cpp | 2 +- .../inference_engine/ngraph_reshape_tests.cpp | 38 ++++++++++++++++++- 7 files changed, 80 insertions(+), 35 deletions(-) diff --git a/inference-engine/include/cpp/ie_cnn_network.h b/inference-engine/include/cpp/ie_cnn_network.h index f71dd32b09d..e1b2c719916 100644 --- a/inference-engine/include/cpp/ie_cnn_network.h +++ b/inference-engine/include/cpp/ie_cnn_network.h @@ -51,9 +51,11 @@ public: /** * @brief A constructor from ngraph::Function object + * This constructor wraps existing ngraph::Function + * If you want to avoid modification of original Function, please create a copy * @param network Pointer to the ngraph::Function object */ - explicit CNNNetwork(const std::shared_ptr& network); + explicit CNNNetwork(const std::shared_ptr& network); /** * @brief A destructor diff --git a/inference-engine/src/hetero_plugin/hetero_executable_network.cpp b/inference-engine/src/hetero_plugin/hetero_executable_network.cpp index e9186122668..36e027bcd78 100644 --- a/inference-engine/src/hetero_plugin/hetero_executable_network.cpp +++ b/inference-engine/src/hetero_plugin/hetero_executable_network.cpp @@ -143,7 +143,7 @@ void dumpGraph(InferenceEngine::ICNNNetwork &network, void dumpGraph(InferenceEngine::ICNNNetwork& network, - const std::vector>& subFunctions, + const std::vector>& subFunctions, std::ostream& stream) { static const std::array colors{{"#FFC405", "#20F608", @@ -665,13 +665,13 @@ void HeteroExecutableNetwork::InitNgraph(const InferenceEngine::ICNNNetwork& net InputsDataMap externalInputsData; network.getInputsInfo(externalInputsData); networks.resize(orderedSubgraphs.size()); - std::vector> subFunctions(orderedSubgraphs.size()); + std::vector> subFunctions(orderedSubgraphs.size()); std::vector isInputSubnetwork(orderedSubgraphs.size()); int id = 0; for (auto&& subgraph : orderedSubgraphs) { networks[id]._device = subgraph._affinity; subFunctions[id] = - std::make_shared(subgraph._results, subgraph._parameters, + std::make_shared(subgraph._results, subgraph._parameters, _name + '_' + std::to_string(id)); networks[id]._clonedNetwork = CNNNetwork{subFunctions[id]}; // update of pre-processing info diff --git a/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp b/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp index 8a5edef782a..678a3039832 100644 --- a/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp +++ b/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp @@ -71,14 +71,13 @@ static std::shared_ptr copyFunction(const std::shared_ptr& graph) { +CNNNetwork::CNNNetwork(const std::shared_ptr& graph) { if (graph == nullptr) { THROW_IE_EXCEPTION << "CNNNetwork was not initialized: 'graph' object is empty"; } - // Copy nGraph function - network = std::make_shared(copyFunction(graph, false, {})); + // Create CNNNetworkNGraphImpl + network = std::make_shared(graph); actual = network.get(); if (actual == nullptr) { THROW_IE_EXCEPTION << "CNNNetwork was not initialized."; @@ -146,6 +145,36 @@ CNNNetworkNGraphImpl::CNNNetworkNGraphImpl(const std::shared_ptr& nGra } } +CNNNetworkNGraphImpl::CNNNetworkNGraphImpl(const ICNNNetwork& network) { + if (network.getFunction() == nullptr) { + THROW_IE_EXCEPTION << "Cannot create CNNNetwork with nGraph from legacy network format!"; + } + + _ngraph_function = copyFunction(network.getFunction(), false, {}); + InputsDataMap inputs; + OutputsDataMap outputs; + network.getInputsInfo(inputs); + network.getOutputsInfo(outputs); + + for (const auto& outputInfo : outputs) { + const auto& name = outputInfo.second->getName(); + DataPtr output = std::make_shared(name, outputInfo.second->getTensorDesc()); + _outputData[name] = output; + _data[name] = output; + } + for (const auto& inputInfo : inputs) { + InputInfo::Ptr info = std::make_shared(); + const auto& name = inputInfo.second->getInputData()->getName(); + DataPtr input = std::make_shared(name, inputInfo.second->getInputData()->getTensorDesc()); + _data[name] = input; + info->setInputData(input); + info->getPreProcess() = inputInfo.second->getPreProcess(); + info->setPrecision(inputInfo.second->getPrecision()); + info->setLayout(inputInfo.second->getLayout()); + _inputData[name] = info; + } +} + void CNNNetworkNGraphImpl::setInputInfo(InputInfo::Ptr data) { if (cnnNetwork) cnnNetwork->setInputInfo(data); _inputData[data->name()] = data; diff --git a/inference-engine/src/inference_engine/cnn_network_ngraph_impl.hpp b/inference-engine/src/inference_engine/cnn_network_ngraph_impl.hpp index 7b042c72b12..e776093dd22 100644 --- a/inference-engine/src/inference_engine/cnn_network_ngraph_impl.hpp +++ b/inference-engine/src/inference_engine/cnn_network_ngraph_impl.hpp @@ -43,6 +43,7 @@ namespace details { class INFERENCE_ENGINE_API_CLASS(CNNNetworkNGraphImpl): public ICNNNetwork { public: CNNNetworkNGraphImpl(const std::shared_ptr<::ngraph::Function>& nGraph); + CNNNetworkNGraphImpl(const ICNNNetwork& nGraph); ~CNNNetworkNGraphImpl() override = default; void getOutputsInfo(std::map& out) const noexcept override; diff --git a/inference-engine/src/legacy_api/src/ie_util_internal.cpp b/inference-engine/src/legacy_api/src/ie_util_internal.cpp index c8b4767e6ba..9fa73ea5bb7 100644 --- a/inference-engine/src/legacy_api/src/ie_util_internal.cpp +++ b/inference-engine/src/legacy_api/src/ie_util_internal.cpp @@ -24,6 +24,7 @@ #include "graph_tools.hpp" #include "net_pass.h" #include "precision_utils.h" +#include "cnn_network_ngraph_impl.hpp" using std::string; @@ -148,30 +149,8 @@ CNNLayerPtr clonelayer(const CNNLayer& source) { } std::shared_ptr cloneNetwork(const ICNNNetwork& network) { - if (auto func = network.getFunction()) { - CNNNetwork net(func); - - InputsDataMap originInputs; - OutputsDataMap originOutputs; - network.getInputsInfo(originInputs); - network.getOutputsInfo(originOutputs); - InputsDataMap clonedInputs = net.getInputsInfo(); - OutputsDataMap clonedOutputs = net.getOutputsInfo(); - - for (const auto& outputInfo : originOutputs) { - if (clonedOutputs.find(outputInfo.first) == clonedOutputs.end()) - THROW_IE_EXCEPTION << "Cannot clone network! Cloned network doesn't contain all outputs"; - clonedOutputs[outputInfo.first]->setPrecision(outputInfo.second->getPrecision()); - clonedOutputs[outputInfo.first]->setLayout(outputInfo.second->getLayout()); - } - for (const auto& inputInfo : originInputs) { - if (clonedInputs.find(inputInfo.first) == clonedInputs.end()) - THROW_IE_EXCEPTION << "Cannot clone network! Cloned network doesn't contain all inputs"; - clonedInputs[inputInfo.first]->setPrecision(inputInfo.second->getPrecision()); - clonedInputs[inputInfo.first]->setLayout(inputInfo.second->getLayout()); - clonedInputs[inputInfo.first]->getPreProcess() = inputInfo.second->getPreProcess(); - } - return net; + if (network.getFunction()) { + return std::make_shared(network); } return cloneNet(network); diff --git a/inference-engine/tests/functional/inference_engine/cnn_network_test.cpp b/inference-engine/tests/functional/inference_engine/cnn_network_test.cpp index c506949bcb0..c07a59916e9 100644 --- a/inference-engine/tests/functional/inference_engine/cnn_network_test.cpp +++ b/inference-engine/tests/functional/inference_engine/cnn_network_test.cpp @@ -15,7 +15,7 @@ TEST_F(CNNNetworkTests, throwsOnInitWithNull) { } TEST_F(CNNNetworkTests, throwsOnInitWithNullNgraph) { - std::shared_ptr nlptr = nullptr; + std::shared_ptr nlptr = nullptr; ASSERT_THROW(CNNNetwork network(nlptr), InferenceEngine::details::InferenceEngineException); } diff --git a/inference-engine/tests/functional/inference_engine/ngraph_reshape_tests.cpp b/inference-engine/tests/functional/inference_engine/ngraph_reshape_tests.cpp index cd7fb7af826..f02993e2ea7 100644 --- a/inference-engine/tests/functional/inference_engine/ngraph_reshape_tests.cpp +++ b/inference-engine/tests/functional/inference_engine/ngraph_reshape_tests.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -121,6 +122,39 @@ TEST_F(NGraphReshapeTests, ReshapeSpatialReLU) { } TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLU) { + std::shared_ptr ngraph; + { + ngraph::PartialShape shape({1, 3, 22, 22}); + ngraph::element::Type type(ngraph::element::Type_t::f32); + auto param = std::make_shared(type, shape); + param->set_friendly_name("data"); + auto relu = std::make_shared(param); + auto result = std::make_shared(relu); + + ngraph::ParameterVector params = {param}; + ngraph::ResultVector results = {result}; + + ngraph = std::make_shared(results, params); + } + + ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22})); + ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22})); + + CNNNetwork cnnNetwork(ngraph::clone_function(*ngraph)); + std::map> shapes; + shapes["data"] = {1, 3, 25, 25}; + + ASSERT_NO_THROW(cnnNetwork.reshape(shapes)); + + auto changedFunction = cnnNetwork.getFunction(); + ASSERT_NE(nullptr, changedFunction); + ASSERT_EQ(changedFunction->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25})); + ASSERT_EQ(changedFunction->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25})); + ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22})); + ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22})); +} + +TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLUWithoutCloneFunction) { std::shared_ptr ngraph; { ngraph::PartialShape shape({1, 3, 22, 22}); @@ -149,8 +183,8 @@ TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLU) { ASSERT_NE(nullptr, changedFunction); ASSERT_EQ(changedFunction->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25})); ASSERT_EQ(changedFunction->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25})); - ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22})); - ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22})); + ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25})); + ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25})); } class CustomTestOp: public ngraph::op::Op {