Avoid redundant clone and reshape (#1376)
* Avoid redundant clone and reshape * Removed some constructors * Fixed output precision
This commit is contained in:
parent
2b1fc60435
commit
6c3b7ee8ca
@ -51,9 +51,11 @@ public:
|
||||
|
||||
/**
|
||||
* @brief A constructor from ngraph::Function object
|
||||
* This constructor wraps existing ngraph::Function
|
||||
* If you want to avoid modification of original Function, please create a copy
|
||||
* @param network Pointer to the ngraph::Function object
|
||||
*/
|
||||
explicit CNNNetwork(const std::shared_ptr<const ngraph::Function>& network);
|
||||
explicit CNNNetwork(const std::shared_ptr<ngraph::Function>& network);
|
||||
|
||||
/**
|
||||
* @brief A destructor
|
||||
|
@ -143,7 +143,7 @@ void dumpGraph(InferenceEngine::ICNNNetwork &network,
|
||||
|
||||
|
||||
void dumpGraph(InferenceEngine::ICNNNetwork& network,
|
||||
const std::vector<std::shared_ptr<const ngraph::Function>>& subFunctions,
|
||||
const std::vector<std::shared_ptr<ngraph::Function>>& subFunctions,
|
||||
std::ostream& stream) {
|
||||
static const std::array<const char *, 9> colors{{"#FFC405",
|
||||
"#20F608",
|
||||
@ -665,13 +665,13 @@ void HeteroExecutableNetwork::InitNgraph(const InferenceEngine::ICNNNetwork& net
|
||||
InputsDataMap externalInputsData;
|
||||
network.getInputsInfo(externalInputsData);
|
||||
networks.resize(orderedSubgraphs.size());
|
||||
std::vector<std::shared_ptr<const ngraph::Function>> subFunctions(orderedSubgraphs.size());
|
||||
std::vector<std::shared_ptr<ngraph::Function>> subFunctions(orderedSubgraphs.size());
|
||||
std::vector<bool> isInputSubnetwork(orderedSubgraphs.size());
|
||||
int id = 0;
|
||||
for (auto&& subgraph : orderedSubgraphs) {
|
||||
networks[id]._device = subgraph._affinity;
|
||||
subFunctions[id] =
|
||||
std::make_shared<const ngraph::Function>(subgraph._results, subgraph._parameters,
|
||||
std::make_shared<ngraph::Function>(subgraph._results, subgraph._parameters,
|
||||
_name + '_' + std::to_string(id));
|
||||
networks[id]._clonedNetwork = CNNNetwork{subFunctions[id]};
|
||||
// update of pre-processing info
|
||||
|
@ -71,14 +71,13 @@ static std::shared_ptr<ngraph::Function> copyFunction(const std::shared_ptr<cons
|
||||
return specialized_function;
|
||||
}
|
||||
|
||||
// WA: for cnnNetwork ngraph constructor
|
||||
CNNNetwork::CNNNetwork(const std::shared_ptr<const ngraph::Function>& graph) {
|
||||
CNNNetwork::CNNNetwork(const std::shared_ptr<ngraph::Function>& graph) {
|
||||
if (graph == nullptr) {
|
||||
THROW_IE_EXCEPTION << "CNNNetwork was not initialized: 'graph' object is empty";
|
||||
}
|
||||
|
||||
// Copy nGraph function
|
||||
network = std::make_shared<CNNNetworkNGraphImpl>(copyFunction(graph, false, {}));
|
||||
// Create CNNNetworkNGraphImpl
|
||||
network = std::make_shared<CNNNetworkNGraphImpl>(graph);
|
||||
actual = network.get();
|
||||
if (actual == nullptr) {
|
||||
THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
|
||||
@ -146,6 +145,36 @@ CNNNetworkNGraphImpl::CNNNetworkNGraphImpl(const std::shared_ptr<Function>& nGra
|
||||
}
|
||||
}
|
||||
|
||||
CNNNetworkNGraphImpl::CNNNetworkNGraphImpl(const ICNNNetwork& network) {
|
||||
if (network.getFunction() == nullptr) {
|
||||
THROW_IE_EXCEPTION << "Cannot create CNNNetwork with nGraph from legacy network format!";
|
||||
}
|
||||
|
||||
_ngraph_function = copyFunction(network.getFunction(), false, {});
|
||||
InputsDataMap inputs;
|
||||
OutputsDataMap outputs;
|
||||
network.getInputsInfo(inputs);
|
||||
network.getOutputsInfo(outputs);
|
||||
|
||||
for (const auto& outputInfo : outputs) {
|
||||
const auto& name = outputInfo.second->getName();
|
||||
DataPtr output = std::make_shared<Data>(name, outputInfo.second->getTensorDesc());
|
||||
_outputData[name] = output;
|
||||
_data[name] = output;
|
||||
}
|
||||
for (const auto& inputInfo : inputs) {
|
||||
InputInfo::Ptr info = std::make_shared<InputInfo>();
|
||||
const auto& name = inputInfo.second->getInputData()->getName();
|
||||
DataPtr input = std::make_shared<Data>(name, inputInfo.second->getInputData()->getTensorDesc());
|
||||
_data[name] = input;
|
||||
info->setInputData(input);
|
||||
info->getPreProcess() = inputInfo.second->getPreProcess();
|
||||
info->setPrecision(inputInfo.second->getPrecision());
|
||||
info->setLayout(inputInfo.second->getLayout());
|
||||
_inputData[name] = info;
|
||||
}
|
||||
}
|
||||
|
||||
void CNNNetworkNGraphImpl::setInputInfo(InputInfo::Ptr data) {
|
||||
if (cnnNetwork) cnnNetwork->setInputInfo(data);
|
||||
_inputData[data->name()] = data;
|
||||
|
@ -43,6 +43,7 @@ namespace details {
|
||||
class INFERENCE_ENGINE_API_CLASS(CNNNetworkNGraphImpl): public ICNNNetwork {
|
||||
public:
|
||||
CNNNetworkNGraphImpl(const std::shared_ptr<::ngraph::Function>& nGraph);
|
||||
CNNNetworkNGraphImpl(const ICNNNetwork& nGraph);
|
||||
~CNNNetworkNGraphImpl() override = default;
|
||||
|
||||
void getOutputsInfo(std::map<std::string, DataPtr>& out) const noexcept override;
|
||||
|
@ -24,6 +24,7 @@
|
||||
#include "graph_tools.hpp"
|
||||
#include "net_pass.h"
|
||||
#include "precision_utils.h"
|
||||
#include "cnn_network_ngraph_impl.hpp"
|
||||
|
||||
using std::string;
|
||||
|
||||
@ -148,30 +149,8 @@ CNNLayerPtr clonelayer(const CNNLayer& source) {
|
||||
}
|
||||
|
||||
std::shared_ptr<ICNNNetwork> cloneNetwork(const ICNNNetwork& network) {
|
||||
if (auto func = network.getFunction()) {
|
||||
CNNNetwork net(func);
|
||||
|
||||
InputsDataMap originInputs;
|
||||
OutputsDataMap originOutputs;
|
||||
network.getInputsInfo(originInputs);
|
||||
network.getOutputsInfo(originOutputs);
|
||||
InputsDataMap clonedInputs = net.getInputsInfo();
|
||||
OutputsDataMap clonedOutputs = net.getOutputsInfo();
|
||||
|
||||
for (const auto& outputInfo : originOutputs) {
|
||||
if (clonedOutputs.find(outputInfo.first) == clonedOutputs.end())
|
||||
THROW_IE_EXCEPTION << "Cannot clone network! Cloned network doesn't contain all outputs";
|
||||
clonedOutputs[outputInfo.first]->setPrecision(outputInfo.second->getPrecision());
|
||||
clonedOutputs[outputInfo.first]->setLayout(outputInfo.second->getLayout());
|
||||
}
|
||||
for (const auto& inputInfo : originInputs) {
|
||||
if (clonedInputs.find(inputInfo.first) == clonedInputs.end())
|
||||
THROW_IE_EXCEPTION << "Cannot clone network! Cloned network doesn't contain all inputs";
|
||||
clonedInputs[inputInfo.first]->setPrecision(inputInfo.second->getPrecision());
|
||||
clonedInputs[inputInfo.first]->setLayout(inputInfo.second->getLayout());
|
||||
clonedInputs[inputInfo.first]->getPreProcess() = inputInfo.second->getPreProcess();
|
||||
}
|
||||
return net;
|
||||
if (network.getFunction()) {
|
||||
return std::make_shared<details::CNNNetworkNGraphImpl>(network);
|
||||
}
|
||||
|
||||
return cloneNet(network);
|
||||
|
@ -15,7 +15,7 @@ TEST_F(CNNNetworkTests, throwsOnInitWithNull) {
|
||||
}
|
||||
|
||||
TEST_F(CNNNetworkTests, throwsOnInitWithNullNgraph) {
|
||||
std::shared_ptr<const ngraph::Function> nlptr = nullptr;
|
||||
std::shared_ptr<ngraph::Function> nlptr = nullptr;
|
||||
ASSERT_THROW(CNNNetwork network(nlptr), InferenceEngine::details::InferenceEngineException);
|
||||
}
|
||||
|
||||
|
@ -21,6 +21,7 @@
|
||||
#include <ngraph/op/relu.hpp>
|
||||
#include <ngraph/op/result.hpp>
|
||||
#include <ngraph/opsets/opset.hpp>
|
||||
#include <ngraph/graph_util.hpp>
|
||||
|
||||
#include <ie_util_internal.hpp>
|
||||
#include <ie_core.hpp>
|
||||
@ -121,6 +122,39 @@ TEST_F(NGraphReshapeTests, ReshapeSpatialReLU) {
|
||||
}
|
||||
|
||||
TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLU) {
|
||||
std::shared_ptr<const ngraph::Function> ngraph;
|
||||
{
|
||||
ngraph::PartialShape shape({1, 3, 22, 22});
|
||||
ngraph::element::Type type(ngraph::element::Type_t::f32);
|
||||
auto param = std::make_shared<ngraph::op::Parameter>(type, shape);
|
||||
param->set_friendly_name("data");
|
||||
auto relu = std::make_shared<ngraph::op::Relu>(param);
|
||||
auto result = std::make_shared<ngraph::op::Result>(relu);
|
||||
|
||||
ngraph::ParameterVector params = {param};
|
||||
ngraph::ResultVector results = {result};
|
||||
|
||||
ngraph = std::make_shared<const ngraph::Function>(results, params);
|
||||
}
|
||||
|
||||
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
|
||||
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
|
||||
|
||||
CNNNetwork cnnNetwork(ngraph::clone_function(*ngraph));
|
||||
std::map<std::string, std::vector<size_t>> shapes;
|
||||
shapes["data"] = {1, 3, 25, 25};
|
||||
|
||||
ASSERT_NO_THROW(cnnNetwork.reshape(shapes));
|
||||
|
||||
auto changedFunction = cnnNetwork.getFunction();
|
||||
ASSERT_NE(nullptr, changedFunction);
|
||||
ASSERT_EQ(changedFunction->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
|
||||
ASSERT_EQ(changedFunction->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
|
||||
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
|
||||
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
|
||||
}
|
||||
|
||||
TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLUWithoutCloneFunction) {
|
||||
std::shared_ptr<ngraph::Function> ngraph;
|
||||
{
|
||||
ngraph::PartialShape shape({1, 3, 22, 22});
|
||||
@ -149,8 +183,8 @@ TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLU) {
|
||||
ASSERT_NE(nullptr, changedFunction);
|
||||
ASSERT_EQ(changedFunction->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
|
||||
ASSERT_EQ(changedFunction->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
|
||||
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
|
||||
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
|
||||
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
|
||||
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
|
||||
}
|
||||
|
||||
class CustomTestOp: public ngraph::op::Op {
|
||||
|
Loading…
Reference in New Issue
Block a user