Avoid redundant clone and reshape (#1376)
* Avoid redundant clone and reshape * Removed some constructors * Fixed output precision
This commit is contained in:
parent
2b1fc60435
commit
6c3b7ee8ca
@ -51,9 +51,11 @@ public:
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief A constructor from ngraph::Function object
|
* @brief A constructor from ngraph::Function object
|
||||||
|
* This constructor wraps existing ngraph::Function
|
||||||
|
* If you want to avoid modification of original Function, please create a copy
|
||||||
* @param network Pointer to the ngraph::Function object
|
* @param network Pointer to the ngraph::Function object
|
||||||
*/
|
*/
|
||||||
explicit CNNNetwork(const std::shared_ptr<const ngraph::Function>& network);
|
explicit CNNNetwork(const std::shared_ptr<ngraph::Function>& network);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief A destructor
|
* @brief A destructor
|
||||||
|
@ -143,7 +143,7 @@ void dumpGraph(InferenceEngine::ICNNNetwork &network,
|
|||||||
|
|
||||||
|
|
||||||
void dumpGraph(InferenceEngine::ICNNNetwork& network,
|
void dumpGraph(InferenceEngine::ICNNNetwork& network,
|
||||||
const std::vector<std::shared_ptr<const ngraph::Function>>& subFunctions,
|
const std::vector<std::shared_ptr<ngraph::Function>>& subFunctions,
|
||||||
std::ostream& stream) {
|
std::ostream& stream) {
|
||||||
static const std::array<const char *, 9> colors{{"#FFC405",
|
static const std::array<const char *, 9> colors{{"#FFC405",
|
||||||
"#20F608",
|
"#20F608",
|
||||||
@ -665,13 +665,13 @@ void HeteroExecutableNetwork::InitNgraph(const InferenceEngine::ICNNNetwork& net
|
|||||||
InputsDataMap externalInputsData;
|
InputsDataMap externalInputsData;
|
||||||
network.getInputsInfo(externalInputsData);
|
network.getInputsInfo(externalInputsData);
|
||||||
networks.resize(orderedSubgraphs.size());
|
networks.resize(orderedSubgraphs.size());
|
||||||
std::vector<std::shared_ptr<const ngraph::Function>> subFunctions(orderedSubgraphs.size());
|
std::vector<std::shared_ptr<ngraph::Function>> subFunctions(orderedSubgraphs.size());
|
||||||
std::vector<bool> isInputSubnetwork(orderedSubgraphs.size());
|
std::vector<bool> isInputSubnetwork(orderedSubgraphs.size());
|
||||||
int id = 0;
|
int id = 0;
|
||||||
for (auto&& subgraph : orderedSubgraphs) {
|
for (auto&& subgraph : orderedSubgraphs) {
|
||||||
networks[id]._device = subgraph._affinity;
|
networks[id]._device = subgraph._affinity;
|
||||||
subFunctions[id] =
|
subFunctions[id] =
|
||||||
std::make_shared<const ngraph::Function>(subgraph._results, subgraph._parameters,
|
std::make_shared<ngraph::Function>(subgraph._results, subgraph._parameters,
|
||||||
_name + '_' + std::to_string(id));
|
_name + '_' + std::to_string(id));
|
||||||
networks[id]._clonedNetwork = CNNNetwork{subFunctions[id]};
|
networks[id]._clonedNetwork = CNNNetwork{subFunctions[id]};
|
||||||
// update of pre-processing info
|
// update of pre-processing info
|
||||||
|
@ -71,14 +71,13 @@ static std::shared_ptr<ngraph::Function> copyFunction(const std::shared_ptr<cons
|
|||||||
return specialized_function;
|
return specialized_function;
|
||||||
}
|
}
|
||||||
|
|
||||||
// WA: for cnnNetwork ngraph constructor
|
CNNNetwork::CNNNetwork(const std::shared_ptr<ngraph::Function>& graph) {
|
||||||
CNNNetwork::CNNNetwork(const std::shared_ptr<const ngraph::Function>& graph) {
|
|
||||||
if (graph == nullptr) {
|
if (graph == nullptr) {
|
||||||
THROW_IE_EXCEPTION << "CNNNetwork was not initialized: 'graph' object is empty";
|
THROW_IE_EXCEPTION << "CNNNetwork was not initialized: 'graph' object is empty";
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy nGraph function
|
// Create CNNNetworkNGraphImpl
|
||||||
network = std::make_shared<CNNNetworkNGraphImpl>(copyFunction(graph, false, {}));
|
network = std::make_shared<CNNNetworkNGraphImpl>(graph);
|
||||||
actual = network.get();
|
actual = network.get();
|
||||||
if (actual == nullptr) {
|
if (actual == nullptr) {
|
||||||
THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
|
THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
|
||||||
@ -146,6 +145,36 @@ CNNNetworkNGraphImpl::CNNNetworkNGraphImpl(const std::shared_ptr<Function>& nGra
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CNNNetworkNGraphImpl::CNNNetworkNGraphImpl(const ICNNNetwork& network) {
|
||||||
|
if (network.getFunction() == nullptr) {
|
||||||
|
THROW_IE_EXCEPTION << "Cannot create CNNNetwork with nGraph from legacy network format!";
|
||||||
|
}
|
||||||
|
|
||||||
|
_ngraph_function = copyFunction(network.getFunction(), false, {});
|
||||||
|
InputsDataMap inputs;
|
||||||
|
OutputsDataMap outputs;
|
||||||
|
network.getInputsInfo(inputs);
|
||||||
|
network.getOutputsInfo(outputs);
|
||||||
|
|
||||||
|
for (const auto& outputInfo : outputs) {
|
||||||
|
const auto& name = outputInfo.second->getName();
|
||||||
|
DataPtr output = std::make_shared<Data>(name, outputInfo.second->getTensorDesc());
|
||||||
|
_outputData[name] = output;
|
||||||
|
_data[name] = output;
|
||||||
|
}
|
||||||
|
for (const auto& inputInfo : inputs) {
|
||||||
|
InputInfo::Ptr info = std::make_shared<InputInfo>();
|
||||||
|
const auto& name = inputInfo.second->getInputData()->getName();
|
||||||
|
DataPtr input = std::make_shared<Data>(name, inputInfo.second->getInputData()->getTensorDesc());
|
||||||
|
_data[name] = input;
|
||||||
|
info->setInputData(input);
|
||||||
|
info->getPreProcess() = inputInfo.second->getPreProcess();
|
||||||
|
info->setPrecision(inputInfo.second->getPrecision());
|
||||||
|
info->setLayout(inputInfo.second->getLayout());
|
||||||
|
_inputData[name] = info;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void CNNNetworkNGraphImpl::setInputInfo(InputInfo::Ptr data) {
|
void CNNNetworkNGraphImpl::setInputInfo(InputInfo::Ptr data) {
|
||||||
if (cnnNetwork) cnnNetwork->setInputInfo(data);
|
if (cnnNetwork) cnnNetwork->setInputInfo(data);
|
||||||
_inputData[data->name()] = data;
|
_inputData[data->name()] = data;
|
||||||
|
@ -43,6 +43,7 @@ namespace details {
|
|||||||
class INFERENCE_ENGINE_API_CLASS(CNNNetworkNGraphImpl): public ICNNNetwork {
|
class INFERENCE_ENGINE_API_CLASS(CNNNetworkNGraphImpl): public ICNNNetwork {
|
||||||
public:
|
public:
|
||||||
CNNNetworkNGraphImpl(const std::shared_ptr<::ngraph::Function>& nGraph);
|
CNNNetworkNGraphImpl(const std::shared_ptr<::ngraph::Function>& nGraph);
|
||||||
|
CNNNetworkNGraphImpl(const ICNNNetwork& nGraph);
|
||||||
~CNNNetworkNGraphImpl() override = default;
|
~CNNNetworkNGraphImpl() override = default;
|
||||||
|
|
||||||
void getOutputsInfo(std::map<std::string, DataPtr>& out) const noexcept override;
|
void getOutputsInfo(std::map<std::string, DataPtr>& out) const noexcept override;
|
||||||
|
@ -24,6 +24,7 @@
|
|||||||
#include "graph_tools.hpp"
|
#include "graph_tools.hpp"
|
||||||
#include "net_pass.h"
|
#include "net_pass.h"
|
||||||
#include "precision_utils.h"
|
#include "precision_utils.h"
|
||||||
|
#include "cnn_network_ngraph_impl.hpp"
|
||||||
|
|
||||||
using std::string;
|
using std::string;
|
||||||
|
|
||||||
@ -148,30 +149,8 @@ CNNLayerPtr clonelayer(const CNNLayer& source) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ICNNNetwork> cloneNetwork(const ICNNNetwork& network) {
|
std::shared_ptr<ICNNNetwork> cloneNetwork(const ICNNNetwork& network) {
|
||||||
if (auto func = network.getFunction()) {
|
if (network.getFunction()) {
|
||||||
CNNNetwork net(func);
|
return std::make_shared<details::CNNNetworkNGraphImpl>(network);
|
||||||
|
|
||||||
InputsDataMap originInputs;
|
|
||||||
OutputsDataMap originOutputs;
|
|
||||||
network.getInputsInfo(originInputs);
|
|
||||||
network.getOutputsInfo(originOutputs);
|
|
||||||
InputsDataMap clonedInputs = net.getInputsInfo();
|
|
||||||
OutputsDataMap clonedOutputs = net.getOutputsInfo();
|
|
||||||
|
|
||||||
for (const auto& outputInfo : originOutputs) {
|
|
||||||
if (clonedOutputs.find(outputInfo.first) == clonedOutputs.end())
|
|
||||||
THROW_IE_EXCEPTION << "Cannot clone network! Cloned network doesn't contain all outputs";
|
|
||||||
clonedOutputs[outputInfo.first]->setPrecision(outputInfo.second->getPrecision());
|
|
||||||
clonedOutputs[outputInfo.first]->setLayout(outputInfo.second->getLayout());
|
|
||||||
}
|
|
||||||
for (const auto& inputInfo : originInputs) {
|
|
||||||
if (clonedInputs.find(inputInfo.first) == clonedInputs.end())
|
|
||||||
THROW_IE_EXCEPTION << "Cannot clone network! Cloned network doesn't contain all inputs";
|
|
||||||
clonedInputs[inputInfo.first]->setPrecision(inputInfo.second->getPrecision());
|
|
||||||
clonedInputs[inputInfo.first]->setLayout(inputInfo.second->getLayout());
|
|
||||||
clonedInputs[inputInfo.first]->getPreProcess() = inputInfo.second->getPreProcess();
|
|
||||||
}
|
|
||||||
return net;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return cloneNet(network);
|
return cloneNet(network);
|
||||||
|
@ -15,7 +15,7 @@ TEST_F(CNNNetworkTests, throwsOnInitWithNull) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CNNNetworkTests, throwsOnInitWithNullNgraph) {
|
TEST_F(CNNNetworkTests, throwsOnInitWithNullNgraph) {
|
||||||
std::shared_ptr<const ngraph::Function> nlptr = nullptr;
|
std::shared_ptr<ngraph::Function> nlptr = nullptr;
|
||||||
ASSERT_THROW(CNNNetwork network(nlptr), InferenceEngine::details::InferenceEngineException);
|
ASSERT_THROW(CNNNetwork network(nlptr), InferenceEngine::details::InferenceEngineException);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@
|
|||||||
#include <ngraph/op/relu.hpp>
|
#include <ngraph/op/relu.hpp>
|
||||||
#include <ngraph/op/result.hpp>
|
#include <ngraph/op/result.hpp>
|
||||||
#include <ngraph/opsets/opset.hpp>
|
#include <ngraph/opsets/opset.hpp>
|
||||||
|
#include <ngraph/graph_util.hpp>
|
||||||
|
|
||||||
#include <ie_util_internal.hpp>
|
#include <ie_util_internal.hpp>
|
||||||
#include <ie_core.hpp>
|
#include <ie_core.hpp>
|
||||||
@ -121,6 +122,39 @@ TEST_F(NGraphReshapeTests, ReshapeSpatialReLU) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLU) {
|
TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLU) {
|
||||||
|
std::shared_ptr<const ngraph::Function> ngraph;
|
||||||
|
{
|
||||||
|
ngraph::PartialShape shape({1, 3, 22, 22});
|
||||||
|
ngraph::element::Type type(ngraph::element::Type_t::f32);
|
||||||
|
auto param = std::make_shared<ngraph::op::Parameter>(type, shape);
|
||||||
|
param->set_friendly_name("data");
|
||||||
|
auto relu = std::make_shared<ngraph::op::Relu>(param);
|
||||||
|
auto result = std::make_shared<ngraph::op::Result>(relu);
|
||||||
|
|
||||||
|
ngraph::ParameterVector params = {param};
|
||||||
|
ngraph::ResultVector results = {result};
|
||||||
|
|
||||||
|
ngraph = std::make_shared<const ngraph::Function>(results, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
|
||||||
|
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
|
||||||
|
|
||||||
|
CNNNetwork cnnNetwork(ngraph::clone_function(*ngraph));
|
||||||
|
std::map<std::string, std::vector<size_t>> shapes;
|
||||||
|
shapes["data"] = {1, 3, 25, 25};
|
||||||
|
|
||||||
|
ASSERT_NO_THROW(cnnNetwork.reshape(shapes));
|
||||||
|
|
||||||
|
auto changedFunction = cnnNetwork.getFunction();
|
||||||
|
ASSERT_NE(nullptr, changedFunction);
|
||||||
|
ASSERT_EQ(changedFunction->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
|
||||||
|
ASSERT_EQ(changedFunction->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
|
||||||
|
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
|
||||||
|
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLUWithoutCloneFunction) {
|
||||||
std::shared_ptr<ngraph::Function> ngraph;
|
std::shared_ptr<ngraph::Function> ngraph;
|
||||||
{
|
{
|
||||||
ngraph::PartialShape shape({1, 3, 22, 22});
|
ngraph::PartialShape shape({1, 3, 22, 22});
|
||||||
@ -149,8 +183,8 @@ TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLU) {
|
|||||||
ASSERT_NE(nullptr, changedFunction);
|
ASSERT_NE(nullptr, changedFunction);
|
||||||
ASSERT_EQ(changedFunction->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
|
ASSERT_EQ(changedFunction->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
|
||||||
ASSERT_EQ(changedFunction->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
|
ASSERT_EQ(changedFunction->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
|
||||||
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
|
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
|
||||||
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
|
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
|
||||||
}
|
}
|
||||||
|
|
||||||
class CustomTestOp: public ngraph::op::Op {
|
class CustomTestOp: public ngraph::op::Op {
|
||||||
|
Loading…
Reference in New Issue
Block a user