Avoid redundant clone and reshape (#1376)

* Avoid redundant clone and reshape

* Removed some constructors

* Fixed output precision
This commit is contained in:
Ilya Churaev 2020-07-29 19:30:59 +03:00 committed by GitHub
parent 2b1fc60435
commit 6c3b7ee8ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 80 additions and 35 deletions

View File

@ -51,9 +51,11 @@ public:
/**
* @brief A constructor from ngraph::Function object
* This constructor wraps existing ngraph::Function
* If you want to avoid modification of original Function, please create a copy
* @param network Pointer to the ngraph::Function object
*/
explicit CNNNetwork(const std::shared_ptr<const ngraph::Function>& network);
explicit CNNNetwork(const std::shared_ptr<ngraph::Function>& network);
/**
* @brief A destructor

View File

@ -143,7 +143,7 @@ void dumpGraph(InferenceEngine::ICNNNetwork &network,
void dumpGraph(InferenceEngine::ICNNNetwork& network,
const std::vector<std::shared_ptr<const ngraph::Function>>& subFunctions,
const std::vector<std::shared_ptr<ngraph::Function>>& subFunctions,
std::ostream& stream) {
static const std::array<const char *, 9> colors{{"#FFC405",
"#20F608",
@ -665,13 +665,13 @@ void HeteroExecutableNetwork::InitNgraph(const InferenceEngine::ICNNNetwork& net
InputsDataMap externalInputsData;
network.getInputsInfo(externalInputsData);
networks.resize(orderedSubgraphs.size());
std::vector<std::shared_ptr<const ngraph::Function>> subFunctions(orderedSubgraphs.size());
std::vector<std::shared_ptr<ngraph::Function>> subFunctions(orderedSubgraphs.size());
std::vector<bool> isInputSubnetwork(orderedSubgraphs.size());
int id = 0;
for (auto&& subgraph : orderedSubgraphs) {
networks[id]._device = subgraph._affinity;
subFunctions[id] =
std::make_shared<const ngraph::Function>(subgraph._results, subgraph._parameters,
std::make_shared<ngraph::Function>(subgraph._results, subgraph._parameters,
_name + '_' + std::to_string(id));
networks[id]._clonedNetwork = CNNNetwork{subFunctions[id]};
// update of pre-processing info

View File

@ -71,14 +71,13 @@ static std::shared_ptr<ngraph::Function> copyFunction(const std::shared_ptr<cons
return specialized_function;
}
// WA: for cnnNetwork ngraph constructor
CNNNetwork::CNNNetwork(const std::shared_ptr<const ngraph::Function>& graph) {
CNNNetwork::CNNNetwork(const std::shared_ptr<ngraph::Function>& graph) {
if (graph == nullptr) {
THROW_IE_EXCEPTION << "CNNNetwork was not initialized: 'graph' object is empty";
}
// Copy nGraph function
network = std::make_shared<CNNNetworkNGraphImpl>(copyFunction(graph, false, {}));
// Create CNNNetworkNGraphImpl
network = std::make_shared<CNNNetworkNGraphImpl>(graph);
actual = network.get();
if (actual == nullptr) {
THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
@ -146,6 +145,36 @@ CNNNetworkNGraphImpl::CNNNetworkNGraphImpl(const std::shared_ptr<Function>& nGra
}
}
CNNNetworkNGraphImpl::CNNNetworkNGraphImpl(const ICNNNetwork& network) {
if (network.getFunction() == nullptr) {
THROW_IE_EXCEPTION << "Cannot create CNNNetwork with nGraph from legacy network format!";
}
_ngraph_function = copyFunction(network.getFunction(), false, {});
InputsDataMap inputs;
OutputsDataMap outputs;
network.getInputsInfo(inputs);
network.getOutputsInfo(outputs);
for (const auto& outputInfo : outputs) {
const auto& name = outputInfo.second->getName();
DataPtr output = std::make_shared<Data>(name, outputInfo.second->getTensorDesc());
_outputData[name] = output;
_data[name] = output;
}
for (const auto& inputInfo : inputs) {
InputInfo::Ptr info = std::make_shared<InputInfo>();
const auto& name = inputInfo.second->getInputData()->getName();
DataPtr input = std::make_shared<Data>(name, inputInfo.second->getInputData()->getTensorDesc());
_data[name] = input;
info->setInputData(input);
info->getPreProcess() = inputInfo.second->getPreProcess();
info->setPrecision(inputInfo.second->getPrecision());
info->setLayout(inputInfo.second->getLayout());
_inputData[name] = info;
}
}
void CNNNetworkNGraphImpl::setInputInfo(InputInfo::Ptr data) {
if (cnnNetwork) cnnNetwork->setInputInfo(data);
_inputData[data->name()] = data;

View File

@ -43,6 +43,7 @@ namespace details {
class INFERENCE_ENGINE_API_CLASS(CNNNetworkNGraphImpl): public ICNNNetwork {
public:
CNNNetworkNGraphImpl(const std::shared_ptr<::ngraph::Function>& nGraph);
CNNNetworkNGraphImpl(const ICNNNetwork& nGraph);
~CNNNetworkNGraphImpl() override = default;
void getOutputsInfo(std::map<std::string, DataPtr>& out) const noexcept override;

View File

@ -24,6 +24,7 @@
#include "graph_tools.hpp"
#include "net_pass.h"
#include "precision_utils.h"
#include "cnn_network_ngraph_impl.hpp"
using std::string;
@ -148,30 +149,8 @@ CNNLayerPtr clonelayer(const CNNLayer& source) {
}
std::shared_ptr<ICNNNetwork> cloneNetwork(const ICNNNetwork& network) {
if (auto func = network.getFunction()) {
CNNNetwork net(func);
InputsDataMap originInputs;
OutputsDataMap originOutputs;
network.getInputsInfo(originInputs);
network.getOutputsInfo(originOutputs);
InputsDataMap clonedInputs = net.getInputsInfo();
OutputsDataMap clonedOutputs = net.getOutputsInfo();
for (const auto& outputInfo : originOutputs) {
if (clonedOutputs.find(outputInfo.first) == clonedOutputs.end())
THROW_IE_EXCEPTION << "Cannot clone network! Cloned network doesn't contain all outputs";
clonedOutputs[outputInfo.first]->setPrecision(outputInfo.second->getPrecision());
clonedOutputs[outputInfo.first]->setLayout(outputInfo.second->getLayout());
}
for (const auto& inputInfo : originInputs) {
if (clonedInputs.find(inputInfo.first) == clonedInputs.end())
THROW_IE_EXCEPTION << "Cannot clone network! Cloned network doesn't contain all inputs";
clonedInputs[inputInfo.first]->setPrecision(inputInfo.second->getPrecision());
clonedInputs[inputInfo.first]->setLayout(inputInfo.second->getLayout());
clonedInputs[inputInfo.first]->getPreProcess() = inputInfo.second->getPreProcess();
}
return net;
if (network.getFunction()) {
return std::make_shared<details::CNNNetworkNGraphImpl>(network);
}
return cloneNet(network);

View File

@ -15,7 +15,7 @@ TEST_F(CNNNetworkTests, throwsOnInitWithNull) {
}
TEST_F(CNNNetworkTests, throwsOnInitWithNullNgraph) {
std::shared_ptr<const ngraph::Function> nlptr = nullptr;
std::shared_ptr<ngraph::Function> nlptr = nullptr;
ASSERT_THROW(CNNNetwork network(nlptr), InferenceEngine::details::InferenceEngineException);
}

View File

@ -21,6 +21,7 @@
#include <ngraph/op/relu.hpp>
#include <ngraph/op/result.hpp>
#include <ngraph/opsets/opset.hpp>
#include <ngraph/graph_util.hpp>
#include <ie_util_internal.hpp>
#include <ie_core.hpp>
@ -121,6 +122,39 @@ TEST_F(NGraphReshapeTests, ReshapeSpatialReLU) {
}
TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLU) {
std::shared_ptr<const ngraph::Function> ngraph;
{
ngraph::PartialShape shape({1, 3, 22, 22});
ngraph::element::Type type(ngraph::element::Type_t::f32);
auto param = std::make_shared<ngraph::op::Parameter>(type, shape);
param->set_friendly_name("data");
auto relu = std::make_shared<ngraph::op::Relu>(param);
auto result = std::make_shared<ngraph::op::Result>(relu);
ngraph::ParameterVector params = {param};
ngraph::ResultVector results = {result};
ngraph = std::make_shared<const ngraph::Function>(results, params);
}
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
CNNNetwork cnnNetwork(ngraph::clone_function(*ngraph));
std::map<std::string, std::vector<size_t>> shapes;
shapes["data"] = {1, 3, 25, 25};
ASSERT_NO_THROW(cnnNetwork.reshape(shapes));
auto changedFunction = cnnNetwork.getFunction();
ASSERT_NE(nullptr, changedFunction);
ASSERT_EQ(changedFunction->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
ASSERT_EQ(changedFunction->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
}
TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLUWithoutCloneFunction) {
std::shared_ptr<ngraph::Function> ngraph;
{
ngraph::PartialShape shape({1, 3, 22, 22});
@ -149,8 +183,8 @@ TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLU) {
ASSERT_NE(nullptr, changedFunction);
ASSERT_EQ(changedFunction->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
ASSERT_EQ(changedFunction->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
}
class CustomTestOp: public ngraph::op::Op {