268 lines
7.1 KiB
C++
268 lines
7.1 KiB
C++
// Copyright (C) 2018-2019 Intel Corporation
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
//
|
|
|
|
/**
|
|
 * @brief A header file that provides a wrapper for the ICNNNetwork object
|
|
* @file ie_cnn_network.h
|
|
*/
|
|
#pragma once
|
|
|
|
#include <details/ie_exception_conversion.hpp>
|
|
#include <details/ie_cnn_network_iterator.hpp>
|
|
#include <ie_icnn_network.hpp>
|
|
#include <ie_icnn_net_reader.h>
|
|
#include "ie_common.h"
|
|
#include "ie_data.h"
|
|
#include "ie_blob.h"
|
|
#include <vector>
|
|
#include <string>
|
|
#include <map>
|
|
#include <utility>
|
|
#include <memory>
|
|
|
|
namespace InferenceEngine {
|
|
|
|
/**
|
|
* @brief This class contains all the information about the Neural Network and the related binary information
|
|
*/
|
|
class CNNNetwork {
|
|
public:
|
|
/**
|
|
* @brief A default constructor
|
|
*/
|
|
CNNNetwork() = default;
|
|
|
|
/**
|
|
* @deprecated Use CNNNetwork::CNNNetwork(std::shared_ptr<ICNNNetwork>) to construct a network
|
|
* @brief Initialises helper class from externally managed pointer
|
|
* @param actual Pointer to the network object
|
|
*/
|
|
INFERENCE_ENGINE_DEPRECATED
|
|
explicit CNNNetwork(ICNNNetwork* actual) : actual(actual) {
|
|
if (actual == nullptr) {
|
|
THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
|
|
}
|
|
}
|
|
|
|
/**
|
|
* @brief Allows helper class to manage lifetime of network object
|
|
* @param network Pointer to the network object
|
|
*/
|
|
explicit CNNNetwork(std::shared_ptr<ICNNNetwork> network)
|
|
: network(network) {
|
|
actual = network.get();
|
|
if (actual == nullptr) {
|
|
THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
|
|
}
|
|
}
|
|
|
|
/**
|
|
* @brief A constructor from ICNNNetReader object
|
|
* @param reader Pointer to the ICNNNetReader object
|
|
*/
|
|
explicit CNNNetwork(std::shared_ptr<ICNNNetReader> reader)
|
|
: reader(reader)
|
|
, actual(reader->getNetwork(nullptr)) {
|
|
if (actual == nullptr) {
|
|
THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
|
|
}
|
|
}
|
|
|
|
/**
|
|
* @brief A destructor
|
|
*/
|
|
virtual ~CNNNetwork() {}
|
|
|
|
/**
|
|
* @brief Wraps original method
|
|
* ICNNNetwork::getPrecision
|
|
*/
|
|
virtual Precision getPrecision() const {
|
|
return actual->getPrecision();
|
|
}
|
|
|
|
/**
|
|
* @brief Wraps original method
|
|
* ICNNNetwork::getOutputsInfo
|
|
*/
|
|
virtual OutputsDataMap getOutputsInfo() const {
|
|
OutputsDataMap outputs;
|
|
actual->getOutputsInfo(outputs);
|
|
return outputs;
|
|
}
|
|
|
|
/**
|
|
* @brief Wraps original method
|
|
* ICNNNetwork::getInputsInfo
|
|
*/
|
|
virtual InputsDataMap getInputsInfo() const {
|
|
InputsDataMap inputs;
|
|
actual->getInputsInfo(inputs);
|
|
return inputs;
|
|
}
|
|
|
|
/**
|
|
* @brief Wraps original method
|
|
* ICNNNetwork::layerCount
|
|
*/
|
|
size_t layerCount() const {
|
|
return actual->layerCount();
|
|
}
|
|
|
|
/**
|
|
* @brief Wraps original method
|
|
* ICNNNetwork::getName
|
|
*/
|
|
const std::string& getName() const noexcept {
|
|
return actual->getName();
|
|
}
|
|
|
|
/**
|
|
* @brief Wraps original method
|
|
* ICNNNetwork::setBatchSize
|
|
*/
|
|
virtual void setBatchSize(const size_t size) {
|
|
CALL_STATUS_FNC(setBatchSize, size);
|
|
}
|
|
|
|
/**
|
|
* @brief Wraps original method
|
|
* ICNNNetwork::getBatchSize
|
|
*/
|
|
virtual size_t getBatchSize() const {
|
|
return actual->getBatchSize();
|
|
}
|
|
|
|
/**
|
|
* @brief An overloaded operator & to get current network
|
|
* @return An instance of the current network
|
|
*/
|
|
operator ICNNNetwork &() const {
|
|
return *actual;
|
|
}
|
|
|
|
/**
|
|
* @deprecated No needs to specify target device to the network. Use InferenceEngine::Core with target device directly
|
|
* @brief Sets tha target device
|
|
* @param device Device instance to set
|
|
*/
|
|
#ifndef _WIN32
|
|
INFERENCE_ENGINE_DEPRECATED
|
|
#endif
|
|
void setTargetDevice(TargetDevice device) {
|
|
IE_SUPPRESS_DEPRECATED_START
|
|
actual->setTargetDevice(device);
|
|
IE_SUPPRESS_DEPRECATED_END
|
|
}
|
|
|
|
/**
|
|
* @brief Wraps original method
|
|
* ICNNNetwork::addOutput
|
|
*/
|
|
void addOutput(const std::string &layerName, size_t outputIndex = 0) {
|
|
CALL_STATUS_FNC(addOutput, layerName, outputIndex);
|
|
}
|
|
|
|
/**
|
|
* @brief Wraps original method
|
|
* ICNNNetwork::getLayerByName
|
|
*/
|
|
CNNLayerPtr getLayerByName(const char *layerName) const {
|
|
CNNLayerPtr layer;
|
|
CALL_STATUS_FNC(getLayerByName, layerName, layer);
|
|
return layer;
|
|
}
|
|
|
|
/**
|
|
* @brief Begin layer iterator
|
|
* Order of layers is implementation specific,
|
|
* and can be changed in future
|
|
*/
|
|
details::CNNNetworkIterator begin() const {
|
|
return details::CNNNetworkIterator(actual);
|
|
}
|
|
|
|
/**
|
|
* @brief End layer iterator
|
|
*/
|
|
details::CNNNetworkIterator end() const {
|
|
return details::CNNNetworkIterator();
|
|
}
|
|
|
|
/**
|
|
* @brief number of layers in network object
|
|
* @return
|
|
*/
|
|
size_t size() const {
|
|
return std::distance(std::begin(*this), std::end(*this));
|
|
}
|
|
|
|
/**
|
|
* @brief Registers extension within the plugin
|
|
* @param extension Pointer to already loaded reader extension with shape propagation implementations
|
|
*/
|
|
void AddExtension(InferenceEngine::IShapeInferExtensionPtr extension) {
|
|
CALL_STATUS_FNC(AddExtension, extension);
|
|
}
|
|
|
|
/**
|
|
* @brief - Helper method to get collect all input shapes with names of corresponding Data objects
|
|
* @return Map of pairs: input's name and its dimension.
|
|
*/
|
|
virtual ICNNNetwork::InputShapes getInputShapes() const {
|
|
ICNNNetwork::InputShapes shapes;
|
|
InputsDataMap inputs;
|
|
actual->getInputsInfo(inputs);
|
|
for (const auto& pair : inputs) {
|
|
auto info = pair.second;
|
|
if (info) {
|
|
auto data = info->getInputData();
|
|
if (data) {
|
|
shapes[data->getName()] = data->getTensorDesc().getDims();
|
|
}
|
|
}
|
|
}
|
|
return shapes;
|
|
}
|
|
|
|
/**
|
|
* @brief Run shape inference with new input shapes for the network
|
|
* @param inputShapes - map of pairs: name of corresponding data and its dimension.
|
|
*/
|
|
virtual void reshape(const ICNNNetwork::InputShapes &inputShapes) {
|
|
CALL_STATUS_FNC(reshape, inputShapes);
|
|
}
|
|
|
|
/**
|
|
* @brief Serialize network to IR and weights files.
|
|
* @param xmlPath Path to output IR file.
|
|
* @param binPath Path to output weights file. The parameter is skipped in case
|
|
* of executable graph info serialization.
|
|
*/
|
|
void serialize(const std::string &xmlPath, const std::string &binPath = "") const {
|
|
CALL_STATUS_FNC(serialize, xmlPath, binPath);
|
|
}
|
|
|
|
protected:
|
|
/**
|
|
* @brief reader extra reference, might be nullptr
|
|
*/
|
|
std::shared_ptr<ICNNNetReader> reader;
|
|
/**
|
|
* @brief network extra interface, might be nullptr
|
|
*/
|
|
std::shared_ptr<ICNNNetwork> network;
|
|
|
|
/**
|
|
* @brief A pointer to the current network
|
|
*/
|
|
ICNNNetwork *actual = nullptr;
|
|
/**
|
|
* @brief A pointer to output data
|
|
*/
|
|
DataPtr output;
|
|
};
|
|
|
|
} // namespace InferenceEngine
|