Files
openvino/inference-engine/include/ie_preprocess.hpp
2020-02-11 22:48:49 +03:00

226 lines
7.0 KiB
C++

// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
/**
* @brief This header file provides structures to store info about pre-processing of network inputs (scale, mean image,
* ...)
*
* @file ie_preprocess.hpp
*/
#pragma once
#include <memory>
#include <vector>
#include "ie_blob.h"
namespace InferenceEngine {
/**
* @brief This structure stores pre-processing parameters (scale, mean value, mean data) for a single input channel
*/
struct PreProcessChannel {
    /** @brief Shared-pointer alias used to hold and pass channel descriptors */
    using Ptr = std::shared_ptr<PreProcessChannel>;

    /** @brief Scale parameter of the channel; defaults to 1 */
    float stdScale = 1.0f;
    /** @brief Mean value of the channel; defaults to 0 */
    float meanValue = 0.0f;
    /** @brief Mean data of the channel; left default-constructed until assigned */
    Blob::Ptr meanData;
};
/**
* @brief Defines available types of mean
*/
enum MeanVariant {
    /** A mean value is given for every input pixel */
    MEAN_IMAGE = 0,
    /** A mean value is given for every input channel */
    MEAN_VALUE = 1,
    /** No mean value is given */
    NONE = 2,
};
/**
* @enum ResizeAlgorithm
* @brief Represents the list of supported resize algorithms.
*/
enum ResizeAlgorithm {
    /** No resize is requested */
    NO_RESIZE = 0,
    /** Bilinear resize */
    RESIZE_BILINEAR = 1,
    /** Area resize */
    RESIZE_AREA = 2
};
/**
* @brief This class stores pre-process information for the input
*/
class PreProcessInfo {
// Channel data
std::vector<PreProcessChannel::Ptr> _channelsInfo;
MeanVariant _variant = NONE;
// Resize Algorithm to be applied for input before inference if needed.
ResizeAlgorithm _resizeAlg = NO_RESIZE;
// Color format to be used in on-demand color conversions applied to input before inference
ColorFormat _colorFormat = ColorFormat::RAW;
public:
/**
* @brief Overloaded [] operator to safely get the channel by an index
*
* Throws an exception if channels are empty
*
* @param index Index of the channel to get
* @return The pre-process channel instance
*/
PreProcessChannel::Ptr& operator[](size_t index) {
if (_channelsInfo.empty()) {
THROW_IE_EXCEPTION << "accessing pre-process when nothing was set.";
}
if (index >= _channelsInfo.size()) {
THROW_IE_EXCEPTION << "pre process index " << index << " is out of bounds.";
}
return _channelsInfo[index];
}
/**
* @brief operator [] to safely get the channel preprocessing information by index.
*
* Throws exception if channels are empty or index is out of border
*
* @param index Index of the channel to get
* @return The const preprocess channel instance
*/
const PreProcessChannel::Ptr& operator[](size_t index) const {
if (_channelsInfo.empty()) {
THROW_IE_EXCEPTION << "accessing pre-process when nothing was set.";
}
if (index >= _channelsInfo.size()) {
THROW_IE_EXCEPTION << "pre process index " << index << " is out of bounds.";
}
return _channelsInfo[index];
}
/**
* @brief Returns a number of channels to preprocess
*
* @return The number of channels
*/
size_t getNumberOfChannels() const {
return _channelsInfo.size();
}
/**
* @brief Initializes with given number of channels
*
* @param numberOfChannels Number of channels to initialize
*/
void init(const size_t numberOfChannels) {
_channelsInfo.resize(numberOfChannels);
for (auto& channelInfo : _channelsInfo) {
channelInfo = std::make_shared<PreProcessChannel>();
}
}
/**
* @brief Sets mean image values if operation is applicable.
*
* Also sets the mean type to MEAN_IMAGE for all channels
*
* @param meanImage Blob with a mean image
*/
void setMeanImage(const Blob::Ptr& meanImage) {
if (meanImage.get() == nullptr) {
THROW_IE_EXCEPTION << "Failed to set invalid mean image: nullptr";
} else if (meanImage.get()->getTensorDesc().getLayout() != Layout::CHW) {
THROW_IE_EXCEPTION << "Mean image layout should be CHW";
} else if (meanImage.get()->getTensorDesc().getDims().size() != 3) {
THROW_IE_EXCEPTION << "Failed to set invalid mean image: number of dimensions != 3";
} else if (meanImage.get()->getTensorDesc().getDims()[0] != getNumberOfChannels()) {
THROW_IE_EXCEPTION << "Failed to set invalid mean image: number of channels != " << getNumberOfChannels();
}
_variant = MEAN_IMAGE;
}
/**
* @brief Sets mean image values if operation is applicable.
*
* Also sets the mean type to MEAN_IMAGE for a particular channel
*
* @param meanImage Blob with a mean image
* @param channel Index of a particular channel
*/
void setMeanImageForChannel(const Blob::Ptr& meanImage, const size_t channel) {
if (meanImage.get() == nullptr) {
THROW_IE_EXCEPTION << "Failed to set invalid mean image for channel: nullptr";
} else if (meanImage.get()->getTensorDesc().getDims().size() != 2) {
THROW_IE_EXCEPTION << "Failed to set invalid mean image for channel: number of dimensions != 2";
} else if (channel >= _channelsInfo.size()) {
THROW_IE_EXCEPTION << "Channel " << channel
<< " exceed number of PreProcess channels: " << _channelsInfo.size();
}
_variant = MEAN_IMAGE;
_channelsInfo[channel]->meanData = meanImage;
}
/**
* @brief Sets a type of mean operation
*
* @param variant Type of mean operation to set
*/
void setVariant(const MeanVariant& variant) {
_variant = variant;
}
/**
* @brief Gets a type of mean operation
*
* @return The type of mean operation
*/
MeanVariant getMeanVariant() const {
return _variant;
}
/**
* @brief Sets resize algorithm to be used during pre-processing
*
* @param alg Resize algorithm
*/
void setResizeAlgorithm(const ResizeAlgorithm& alg) {
_resizeAlg = alg;
}
/**
* @brief Gets preconfigured resize algorithm
*
* @return Resize algorithm
*/
ResizeAlgorithm getResizeAlgorithm() const {
return _resizeAlg;
}
/**
* @brief Changes the color format of the input data provided by the user
*
* This function should be called before loading the network to the plugin
* Setting color format different from ColorFormat::RAW enables automatic color conversion
* (as a part of built-in preprocessing routine)
*
* @param fmt A new color format associated with the input
*/
void setColorFormat(ColorFormat fmt) {
_colorFormat = fmt;
}
/**
* @brief Gets a color format associated with the input
*
* @details By default, the color format is ColorFormat::RAW meaning
* there is no particular color format assigned to the input
* @return Color format.
*/
ColorFormat getColorFormat() const {
return _colorFormat;
}
};
} // namespace InferenceEngine