Files
openvino/inference-engine/include/cpp/ie_cnn_network.h
2020-06-03 06:04:50 +03:00

347 lines
10 KiB
C++

// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
/**
* @brief A header file that provides wrapper for ICNNNetwork object
*
* @file ie_cnn_network.h
*/
#pragma once
#include <ie_icnn_net_reader.h>
#include <details/ie_cnn_network_iterator.hpp>
#include <details/ie_exception_conversion.hpp>
#include <ie_icnn_network.hpp>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ie_blob.h"
#include "ie_common.h"
#include "ie_data.h"
// Forward declaration of ngraph::Function. Within this header the type is only
// named through std::shared_ptr (constructor argument and getFunction() results),
// so presumably the full class definition is intentionally not pulled in here.
namespace ngraph {
class Function;
} // namespace ngraph
namespace InferenceEngine {
/**
* @brief This class contains all the information about the Neural Network and the related binary information
*/
class INFERENCE_ENGINE_API_CLASS(CNNNetwork) {
public:
/**
* @brief A default constructor
*/
CNNNetwork() = default;
/**
* @brief Allows helper class to manage lifetime of network object
*
* @param network Pointer to the network object
*/
explicit CNNNetwork(std::shared_ptr<ICNNNetwork> network): network(network) {
actual = network.get();
if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
}
/**
* @brief A constructor from ngraph::Function object
* @param network Pointer to the ngraph::Function object
*/
explicit CNNNetwork(const std::shared_ptr<const ngraph::Function>& network);
/**
* @brief A constructor from ICNNNetReader object
*
* @param reader Pointer to the ICNNNetReader object
*/
IE_SUPPRESS_DEPRECATED_START
explicit CNNNetwork(CNNNetReaderPtr reader_): reader(reader_) {
if (reader == nullptr) {
THROW_IE_EXCEPTION << "ICNNNetReader was not initialized.";
}
if ((actual = reader->getNetwork(nullptr)) == nullptr) {
THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
}
}
IE_SUPPRESS_DEPRECATED_END
/**
* @brief A destructor
*/
virtual ~CNNNetwork() {}
/**
* @deprecated Network precision does not make sence, use precision on egdes. The method will be removed in 2021.1
* @copybrief ICNNNetwork::getPrecision
*
* Wraps ICNNNetwork::getPrecision
*
* @return A precision type
*/
INFERENCE_ENGINE_DEPRECATED("Network precision does not make sence, use precision on egdes. The method will be removed in 2021.1")
virtual Precision getPrecision() const;
/**
* @copybrief ICNNNetwork::getOutputsInfo
*
* Wraps ICNNNetwork::getOutputsInfo
*
* @return outputs Reference to the OutputsDataMap object
*/
virtual OutputsDataMap getOutputsInfo() const {
if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
OutputsDataMap outputs;
actual->getOutputsInfo(outputs);
return outputs;
}
/**
* @copybrief ICNNNetwork::getInputsInfo
*
* Wraps ICNNNetwork::getInputsInfo
*
* @return inputs Reference to InputsDataMap object
*/
virtual InputsDataMap getInputsInfo() const {
if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
InputsDataMap inputs;
actual->getInputsInfo(inputs);
return inputs;
}
/**
* @copybrief ICNNNetwork::layerCount
*
* Wraps ICNNNetwork::layerCount
*
* @return The number of layers as an integer value
*/
size_t layerCount() const {
if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
return actual->layerCount();
}
/**
* @copybrief ICNNNetwork::getName
*
* Wraps ICNNNetwork::getName
*
* @return Network name
*/
const std::string& getName() const {
if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
return actual->getName();
}
/**
* @copybrief ICNNNetwork::setBatchSize
*
* Wraps ICNNNetwork::setBatchSize
*
* @param size Size of batch to set
* @return Status code of the operation
*/
virtual void setBatchSize(const size_t size) {
CALL_STATUS_FNC(setBatchSize, size);
}
/**
* @copybrief ICNNNetwork::getBatchSize
*
* Wraps ICNNNetwork::getBatchSize
*
* @return The size of batch as a size_t value
*/
virtual size_t getBatchSize() const {
if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
return actual->getBatchSize();
}
/**
* @brief An overloaded operator cast to get pointer on current network
*
* @return A shared pointer of the current network
*/
operator ICNNNetwork::Ptr() {
return network;
}
/**
* @brief An overloaded operator & to get current network
*
* @return An instance of the current network
*/
operator ICNNNetwork&() {
if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
return *actual;
}
/**
* @brief An overloaded operator & to get current network
*
* @return A const reference of the current network
*/
operator const ICNNNetwork&() const {
if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
return *actual;
}
/**
* @brief Returns constant nGraph function
*
* @return constant nGraph function
*/
std::shared_ptr<ngraph::Function> getFunction() {
if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
return actual->getFunction();
}
/**
* @brief Returns constant nGraph function
*
* @return constant nGraph function
*/
std::shared_ptr<const ngraph::Function> getFunction() const {
if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
return actual->getFunction();
}
/**
* @copybrief ICNNNetwork::addOutput
*
* Wraps ICNNNetwork::addOutput
*
* @param layerName Name of the layer
* @param outputIndex Index of the output
*/
void addOutput(const std::string& layerName, size_t outputIndex = 0) {
CALL_STATUS_FNC(addOutput, layerName, outputIndex);
}
/**
* @deprecated Migrate to IR v10 and work with ngraph::Function directly. The method will be removed in 2021.1
* @copybrief ICNNNetwork::getLayerByName
*
* Wraps ICNNNetwork::getLayerByName
*
* @param layerName Given name of the layer
* @return Status code of the operation. InferenceEngine::OK if succeeded
*/
INFERENCE_ENGINE_DEPRECATED("Migrate to IR v10 and work with ngraph::Function directly. The method will be removed in 2021.1")
CNNLayerPtr getLayerByName(const char* layerName) const;
/**
* @deprecated Use CNNNetwork::getFunction() and work with ngraph::Function directly. The method will be removed in 2021.1
* @brief Begin layer iterator
*
* Order of layers is implementation specific,
* and can be changed in future
*
* @return Iterator pointing to a layer
*/
IE_SUPPRESS_DEPRECATED_START
INFERENCE_ENGINE_DEPRECATED("Use CNNNetwork::getFunction() and work with ngraph::Function directly. The method will be removed in 2021.1")
details::CNNNetworkIterator begin() const;
/**
* @deprecated Use CNNNetwork::getFunction() and work with ngraph::Function directly. The method will be removed in 2021.1
* @brief End layer iterator
* @return Iterator pointing to a layer
*/
INFERENCE_ENGINE_DEPRECATED("Use CNNNetwork::getFunction() and work with ngraph::Function directly. The method will be removed in 2021.1")
details::CNNNetworkIterator end() const;
IE_SUPPRESS_DEPRECATED_END
/**
* @deprecated Use CNNNetwork::layerCount() instead. The method will be removed in 2021.1
* @brief Number of layers in network object
*
* @return Number of layers.
*/
INFERENCE_ENGINE_DEPRECATED("Use CNNNetwork::layerCount() instead. The method will be removed in 2021.1")
size_t size() const;
/**
* @deprecated Use Core::AddExtension to add an extension to the library
* @brief Registers extension within the plugin
*
* @param extension Pointer to already loaded reader extension with shape propagation implementations
*/
INFERENCE_ENGINE_DEPRECATED("Use Core::AddExtension to add an extension to the library")
void AddExtension(InferenceEngine::IShapeInferExtensionPtr extension);
/**
* @brief Helper method to get collect all input shapes with names of corresponding Data objects
*
* @return Map of pairs: input name and its dimension.
*/
virtual ICNNNetwork::InputShapes getInputShapes() const {
if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
ICNNNetwork::InputShapes shapes;
InputsDataMap inputs;
actual->getInputsInfo(inputs);
for (const auto& pair : inputs) {
auto info = pair.second;
if (info) {
auto data = info->getInputData();
if (data) {
shapes[data->getName()] = data->getTensorDesc().getDims();
}
}
}
return shapes;
}
/**
* @brief Run shape inference with new input shapes for the network
*
* @param inputShapes - map of pairs: name of corresponding data and its dimension.
*/
virtual void reshape(const ICNNNetwork::InputShapes& inputShapes) {
CALL_STATUS_FNC(reshape, inputShapes);
}
/**
* @brief Serialize network to IR and weights files.
*
* @param xmlPath Path to output IR file.
* @param binPath Path to output weights file. The parameter is skipped in case
* of executable graph info serialization.
*/
void serialize(const std::string& xmlPath, const std::string& binPath = "") const {
CALL_STATUS_FNC(serialize, xmlPath, binPath);
}
protected:
/**
* @brief Reader extra reference, might be nullptr
*/
IE_SUPPRESS_DEPRECATED_START
CNNNetReaderPtr reader;
IE_SUPPRESS_DEPRECATED_END
/**
* @brief Network extra interface, might be nullptr
*/
std::shared_ptr<ICNNNetwork> network;
/**
* @brief A pointer to the current network
*/
ICNNNetwork* actual = nullptr;
/**
* @brief A pointer to output data
*/
DataPtr output;
};
} // namespace InferenceEngine