[GNA] Pre/Post processing via evaluate (#15691)
This commit is contained in: parent 6fc0b6479e · commit 55e9cae54f
@@ -37,6 +37,9 @@ struct GnaDesc {
uint32_t allocated_size = 0;
std::vector<void*> ptrs = {};  // ptr per each infer request

// pre/post processing model
std::shared_ptr<ov::Model> pre_post_process_model = nullptr;

// help methods
uint32_t get_required_size() const {
return num_elements * tensor_precision.size();

@@ -16,22 +16,6 @@
#include "layers/gna_split_layer.hpp"
#include "memory/gna_memory.hpp"

struct TranspositionInfo {
bool transpose;
size_t num_transpose_rows;
size_t num_transpose_columns;
};

using TranspositionInfoMap = std::map<std::string, std::vector<TranspositionInfo>>;

static inline bool FoundPartToTranspose(const std::vector<TranspositionInfo>& transpositionInfo) {
auto partToTranspose =
std::find_if(std::begin(transpositionInfo), std::end(transpositionInfo), [](const TranspositionInfo& infoPart) {
return infoPart.transpose;
});
return partToTranspose != std::end(transpositionInfo);
}

namespace ov {
namespace intel_gna {

@@ -45,6 +29,7 @@ using ConcatConnection = std::unordered_map<std::string, GNAConcatLayer>;
using SplitConnection = std::unordered_map<std::string, GNASplitLayer>;
using CropConnection = std::unordered_map<std::string, GNACropLayer>;
using ConstConnections = std::unordered_map<std::string, void*>;
using PrePostProcessModels = std::unordered_map<std::string, std::shared_ptr<ov::Model>>;

} // namespace intel_gna
} // namespace ov

@@ -12,10 +12,13 @@
#include "layers/gna_layer_info.hpp"
#include "log/debug.hpp"
#include "ops/util/util.hpp"
#include "pre_post_process/transposition_info.hpp"

namespace ov {
namespace intel_gna {

using TranspositionInfo = pre_post_processing::TranspositionInfo;

/**
* @brief checks if it's a reshape from 4d to 3d tensor
* @param layer Non-functional layer

@@ -39,7 +39,6 @@
#include "gna_fused_iterator.hpp"
#include "gna_graph_patterns.hpp"
#include "gna_itt.hpp"
#include "gna_model_serial.hpp"
#include "gna_plugin_config.hpp"
#include "gna_tensor_tools.hpp"
#include "gna_transformations_pipeline.hpp"
@@ -47,12 +46,14 @@
#include "log/log.hpp"
#include "memory/gna_memory_state.hpp"
#include "orientation_helper.hpp"
#include "preprocessing.hpp"
#include "pre_post_process/preprocessing.hpp"
#include "pre_post_process/transposition_info.hpp"
#include "request/model_wrapper_factory.hpp"
#include "request/worker_factory.hpp"
#include "request/worker_pool_impl.hpp"
#include "runtime/gna_float_runtime.hpp"
#include "scale_factor_helper.hpp"
#include "serial/gna_model_serial.hpp"

using namespace ov::intel_gna::ngraph_util;

@@ -81,6 +82,7 @@ using namespace InferenceEngine::details;

using namespace ov::intel_gna::memory;
using namespace ov::intel_gna::frontend;
using namespace ov::intel_gna::pre_post_processing;

namespace InferenceEngine {
template <>
@@ -341,6 +343,27 @@ void GNAPlugin::ImportFrames(void* ptr_dst,
}
}

void GNAPlugin::PrePostProcess(InferenceEngine::Blob::Ptr input_blob,
InferenceEngine::Blob::Ptr output_blob,
std::shared_ptr<ov::Model> model) {
const ov::element::Type input_type = details::convertPrecision(input_blob->getTensorDesc().getPrecision());
const ov::element::Type output_type = details::convertPrecision(output_blob->getTensorDesc().getPrecision());
const ov::Shape& input_shape = input_blob->getTensorDesc().getDims();
const ov::Shape& output_shape = output_blob->getTensorDesc().getDims();

for (auto param : model->get_parameters()) {
param->set_element_type(input_type);
}
model->validate_nodes_and_infer_types();

ov::TensorVector inputs = {ov::Tensor(input_type, input_shape, input_blob->cbuffer().as<void*>())};
ov::TensorVector results = {ov::Tensor(output_type, output_shape, output_blob->buffer().as<void*>())};

if (!model->evaluate(results, inputs)) {
THROW_GNA_EXCEPTION << "Failed to evaluate model " << model->get_friendly_name() << std::endl;
}
}
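
Illustration (not part of this change, names are illustrative): PrePostProcess relies on ov::Model::evaluate, which runs a model through the CPU reference implementations directly over user-provided memory, without creating an infer request. A minimal standalone sketch of that mechanism with a tiny Parameter->Transpose->Result model:

// Sketch: evaluate a small model in place over raw buffers, as PrePostProcess does above.
#include <openvino/openvino.hpp>
#include <openvino/opsets/opset10.hpp>
#include <vector>

int main() {
    using namespace ov::opset10;
    auto param = std::make_shared<Parameter>(ov::element::f32, ov::Shape{1, 2, 3});
    auto order = Constant::create(ov::element::i32, ov::Shape{3}, {0, 2, 1});
    auto transpose = std::make_shared<Transpose>(param, order);
    auto model = std::make_shared<ov::Model>(ov::OutputVector{transpose}, ov::ParameterVector{param});

    std::vector<float> in_buf = {0, 1, 2, 3, 4, 5};  // 1x2x3 input
    std::vector<float> out_buf(6, 0.f);              // 1x3x2 after transpose

    // Wrap existing memory; evaluate() writes the result into out_buf in place.
    ov::TensorVector inputs = {ov::Tensor(ov::element::f32, {1, 2, 3}, in_buf.data())};
    ov::TensorVector outputs = {ov::Tensor(ov::element::f32, {1, 3, 2}, out_buf.data())};
    return model->evaluate(outputs, inputs) ? 0 : 1;
}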

GNAPlugin::GNAPlugin() : graphCompiler(config) {
Init();
UpdateFieldsFromConfig();
@@ -463,6 +486,12 @@ void GNAPlugin::UpdateInputs(const std::vector<std::shared_ptr<const ov::Node>>&
const std::string ie_name = param->get_friendly_name();
(*inputs_ptr_)[ie_name].name = param->get_friendly_name();
(*inputs_ptr_)[ie_name].tensor_names = param->get_output_tensor(0).get_names();

// find pre-processing model
auto subgraph_it = m_input_output_subgraphs.find(ie_name);
if (subgraph_it != m_input_output_subgraphs.end()) {
(*inputs_ptr_)[ie_name].pre_post_process_model = subgraph_it->second;
}
}
}

@@ -472,6 +501,12 @@ void GNAPlugin::UpdateOutputs(const std::vector<std::shared_ptr<const ov::Node>>
const std::string ie_name = ov::op::util::create_ie_output_name(result->input_value(0));
outputs_[ie_name].name = ie_name;
outputs_[ie_name].tensor_names = result->get_output_tensor(0).get_names();

// find postprocessing model
auto subgraph_it = m_input_output_subgraphs.find(ie_name);
if (subgraph_it != m_input_output_subgraphs.end()) {
outputs_[ie_name].pre_post_process_model = subgraph_it->second;
}
}
}

@@ -668,7 +703,7 @@ void GNAPlugin::LoadNetwork(const CNNNetwork& _network) {
if (_network.getFunction()) {
CNNNetwork clonedNetwork = InferenceEngine::cloneNetwork(_network);
auto model = clonedNetwork.getFunction();
transformer.apply(model);
transformer.apply(model, &m_input_output_subgraphs);
limitations::check_all_ops_supported(model, effectiveCompileTarget, config.gnaPrecision);
convertedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(model, clonedNetwork);
}
@@ -679,6 +714,7 @@ void GNAPlugin::LoadNetwork(const CNNNetwork& _network) {
transformer.convert_precision_legacy(network);

// Check the network

std::string error;
if (!limitations::AreLayersSupported(network, error)) {
THROW_GNA_EXCEPTION << error.c_str();
@@ -951,6 +987,15 @@ void GNAPlugin::LoadNetwork(const CNNNetwork& _network) {
{TranspositionInfo{dnn->do_rotate_input, dnn->num_rotate_rows, dnn->num_rotate_columns}}});
}
}

// TODO: Need to remove this conversion when ngraph NCHW->NHWC transformation is enabled
if (!transpose_inputs_info.empty()) {
ConvertTransposeMapToModel(transpose_inputs_info, inputs_ptr_->Get());
}
if (!transpose_outputs_info.empty()) {
ConvertTransposeMapToModel(transpose_outputs_info, outputs_.Get());
}

DumpXNNToFile();

#ifdef PLOT
@@ -1072,40 +1117,42 @@ uint32_t GNAPlugin::QueueInference(const InferenceEngine::BlobMap& inputs, Infer

int inputNum = 0;
for (auto& input : inputs) {
auto inputLayout = input.second->getTensorDesc().getLayout();
if (inputLayout != InferenceEngine::Layout::C && inputLayout != InferenceEngine::Layout::NC &&
inputLayout != InferenceEngine::Layout::CN && inputLayout != InferenceEngine::Layout::CHW &&
inputLayout != InferenceEngine::Layout::NCHW) {
std::string input_name = input.first;
InferenceEngine::Layout input_layout = input.second->getTensorDesc().getLayout();

if (input_layout != InferenceEngine::Layout::C && input_layout != InferenceEngine::Layout::NC &&
input_layout != InferenceEngine::Layout::CN && input_layout != InferenceEngine::Layout::CHW &&
input_layout != InferenceEngine::Layout::NCHW) {
THROW_GNA_EXCEPTION << "Expected input blob to have Layout::C, Layout::NC, Layout::CN, Layout::NCHW or "
"Layout::CHW. But was: "
<< input.second->getTensorDesc().getLayout();
<< input_layout;
}

if (inputLayout == InferenceEngine::Layout::NCHW || inputLayout == InferenceEngine::Layout::CHW) {
if (input_layout == InferenceEngine::Layout::NCHW || input_layout == InferenceEngine::Layout::CHW) {
// specific case that can be squeezed to 2d
inputLayout = InferenceEngine::Layout::NC;
input_layout = InferenceEngine::Layout::NC;
}

auto is1D = input.second->getTensorDesc().getLayout() == InferenceEngine::Layout::C;
auto is3D = input.second->getTensorDesc().getLayout() == InferenceEngine::Layout::CHW;
auto is1D = input_layout == InferenceEngine::Layout::C;
auto is3D = input_layout == InferenceEngine::Layout::CHW;

if (inputs_ptr_->at(input.first).ptrs.empty()) {
if (inputs_ptr_->at(input_name).ptrs.empty()) {
// should not happen in user code however might happen if there any non executable network based integration
// of GNAPlugin instance
THROW_GNA_EXCEPTION << "network not loaded : input pointer for " << input.first << " not set";
THROW_GNA_EXCEPTION << "network not loaded : input pointer for " << input_name << " not set";
}

if (inputs_ptr_->at(input.first).ptrs[index] == nullptr) {
if (inputs_ptr_->at(input_name).ptrs[index] == nullptr) {
// should not happen in user code however might happen if there any non executable network based integration
// of GNAPlugin instance
THROW_GNA_EXCEPTION << "network not loaded : input pointer for (" << input.first << " at inferRequest #"
THROW_GNA_EXCEPTION << "network not loaded : input pointer for (" << input_name << " at inferRequest #"
<< index << " not set";
}
const auto inputOrientation = inputs_ptr_->at(input.first).orientation;
const auto inputOrientation = inputs_ptr_->at(input_name).orientation;
if (inputOrientation == kDnnUnknownOrientation) {
// should not happen in user code however might happen if there any non executable network based integration
// of GNAPlugin instance
THROW_GNA_EXCEPTION << "network not loaded : input orientation for " << input.first << " not set";
THROW_GNA_EXCEPTION << "network not loaded : input orientation for " << input_name << " not set";
}

for (auto& output : outputs_.Get()) {
@@ -1125,43 +1172,49 @@ uint32_t GNAPlugin::QueueInference(const InferenceEngine::BlobMap& inputs, Infer
auto importedElementSizeBytes = gnaFlags->sw_fp32 ? 4 : (gnaFlags->input_low_precision ? 1 : 2);
auto importedBytes = importedElements * importedFrames * importedElementSizeBytes;

if (inputs_ptr_->at(input.first).get_required_size() < importedBytes) {
THROW_GNA_EXCEPTION << "Cannot import input frames for :" << input.first
<< ", allocated size: " << inputs_ptr_->at(input.first).get_required_size()
if (inputs_ptr_->at(input_name).get_required_size() < importedBytes) {
THROW_GNA_EXCEPTION << "Cannot import input frames for :" << input_name
<< ", allocated size: " << inputs_ptr_->at(input_name).get_required_size()
<< ", but input blob size: " << importedBytes;
}

ImportFrames(inputs_ptr_->at(input.first).ptrs[index],
// Perform pre-processing on CPU.
// When we need to perform pre-processing on CPU using ngraph model we copy user input to the buffer,
// then set preprocessing output blob as gna input blob.
std::shared_ptr<ov::Model> model = inputs_ptr_->at(input_name).pre_post_process_model;
Blob::Ptr buff_blob = nullptr;
TensorDesc buff_tensor_desc(input.second->getTensorDesc());
buff_tensor_desc.setPrecision(inputs_ptr_->at(input_name).tensor_precision);

if (model) {
// WA: evaluate gather with int16 precision as fp16
if (buff_tensor_desc.getPrecision() == Precision::I16) {
buff_tensor_desc.setPrecision(Precision::FP16);
}
buff_blob = make_blob_with_precision(buff_tensor_desc);
buff_blob->allocate();
} else {
buff_blob = make_blob_with_precision(buff_tensor_desc, inputs_ptr_->at(input_name).ptrs[index]);
}

ImportFrames(buff_blob->buffer(),
input.second->cbuffer().as<float*>(),
input.second->getTensorDesc().getPrecision(),
gnaFlags->sw_fp32 ? kScaleFactorDefault : inputs_ptr_->at(input.first).scale_factor,
gnaFlags->sw_fp32 ? kScaleFactorDefault : inputs_ptr_->at(input_name).scale_factor,
inputOrientation,
importedFrames,
targetGroups,
importedElements,
importedElements);

auto transpose_info = transpose_inputs_info.find(input.first);
if (transpose_info != std::end(transpose_inputs_info)) {
size_t batchSize = (dims.size() > 1) ? dims[0] : 1;
size_t elementsPerBatch = (dims.size() > 1) ? InferenceEngine::details::product(dims) / dims[0] : dims[0];
size_t transposed_data_size = 0;
for (const auto& part_transposition_info : transpose_info->second) {
transposed_data_size +=
part_transposition_info.num_transpose_rows * part_transposition_info.num_transpose_columns;
}
if (elementsPerBatch != transposed_data_size) {
THROW_GNA_EXCEPTION << "Transposed data size (" << transposed_data_size
<< ") do not match input buffer length of " << elementsPerBatch;
}
auto input_ptr = reinterpret_cast<uint8_t*>(inputs_ptr_->at(input.first).ptrs[index]);
ConvertTensorFromNCHWToNHWC(gnadevice ? 2 : 4,
batchSize,
elementsPerBatch,
input_ptr,
true,
transpose_info->second);
if (model) {
Precision output_prc = buff_blob->getTensorDesc().getPrecision();
SizeVector output_dims = model->get_result()->get_shape();
TensorDesc output_desc(output_prc, output_dims, InferenceEngine::Layout::ANY);
Blob::Ptr output_blob = make_blob_with_precision(output_desc, inputs_ptr_->at(input_name).ptrs[index]);
PrePostProcess(buff_blob, output_blob, model);
}

++inputNum;
}

@@ -1229,58 +1282,64 @@ RequestStatus GNAPlugin::WaitFor(uint32_t request_idx, int64_t millisTimeout) {
dnn->WriteInputAndOutputTextGNA(*worker.model());
#endif
for (auto&& outputBlobIt : requestResult) {
auto& outputBlob = outputBlobIt.second;
auto& outputDesc = outputs_.at(outputBlobIt.first);
if (outputBlob->getTensorDesc().getLayout() != InferenceEngine::Layout::C &&
outputBlob->getTensorDesc().getLayout() != InferenceEngine::Layout::NC &&
outputBlob->getTensorDesc().getLayout() != InferenceEngine::Layout::CN &&
outputBlob->getTensorDesc().getLayout() != InferenceEngine::Layout::NCHW &&
outputBlob->getTensorDesc().getLayout() != InferenceEngine::Layout::CHW &&
outputBlob->getTensorDesc().getLayout() != InferenceEngine::Layout::SCALAR) {
const std::string& output_name = outputBlobIt.first;
Blob::Ptr output_blob = outputBlobIt.second;
const InferenceEngine::Layout output_layout = output_blob->getTensorDesc().getLayout();

if (output_layout != InferenceEngine::Layout::C && output_layout != InferenceEngine::Layout::NC &&
output_layout != InferenceEngine::Layout::CN && output_layout != InferenceEngine::Layout::NCHW &&
output_layout != InferenceEngine::Layout::CHW && output_layout != InferenceEngine::Layout::SCALAR) {
THROW_GNA_EXCEPTION << "Expected output blob to have Layout::C, Layout::NC, Layout::CN, Layout::NCHW or "
"Layout::CHW. But was "
<< outputBlob->getTensorDesc().getLayout();
<< output_layout;
}

auto dims = outputBlob->getTensorDesc().getDims();
auto is1D = outputBlob->getTensorDesc().getLayout() == InferenceEngine::Layout::C;
auto isScalar = outputBlob->getTensorDesc().getLayout() == InferenceEngine::Layout::SCALAR;
auto is3D = outputBlob->getTensorDesc().getLayout() == InferenceEngine::Layout::CHW;
auto dims = output_blob->getTensorDesc().getDims();
auto is1D = output_layout == InferenceEngine::Layout::C;
auto isScalar = output_layout == InferenceEngine::Layout::SCALAR;
auto is3D = output_layout == InferenceEngine::Layout::CHW;
auto batchSize = (is1D || isScalar || is3D) ? 1 : dims[0];
auto elementsPerBatch =
isScalar ? 1
: (is1D ? dims.front() : InferenceEngine::details::product(++std::begin(dims), std::end(dims)));
isScalar ? 1 : (is1D ? dims.front() : details::product(++std::begin(dims), std::end(dims)));

auto transpose_output_info = transpose_outputs_info.find(outputBlobIt.first);
if (transpose_output_info != std::end(transpose_outputs_info) &&
FoundPartToTranspose(transpose_output_info->second)) {
size_t transposed_data_size = 0;
for (const auto& part_transposition_info : transpose_output_info->second) {
transposed_data_size +=
part_transposition_info.num_transpose_rows * part_transposition_info.num_transpose_columns;
}
if (elementsPerBatch != transposed_data_size) {
THROW_GNA_EXCEPTION << "Transposed data size (" << transposed_data_size
<< ") do not match output buffer length of " << elementsPerBatch;
}
ConvertTensorFromNCHWToNHWC(outputDesc.tensor_precision.size(),
batchSize,
elementsPerBatch,
reinterpret_cast<uint8_t*>(outputDesc.ptrs[request_idx]),
true,
transpose_output_info->second);
OutputDesc& gna_output_desc = outputs_.at(output_name);
Blob::Ptr gna_output_blob = nullptr;

// Perform postprocessing on CPU
std::shared_ptr<ov::Model> model = gna_output_desc.pre_post_process_model;
if (model) {
// WA: evaluate gather with int16 precision as fp16
Precision preproc_prc = (gna_output_desc.tensor_precision == Precision::I16)
? Precision(Precision::FP16)
: gna_output_desc.tensor_precision;
const SizeVector& input_dims = model->get_parameters().front()->get_shape();
TensorDesc input_desc(preproc_prc, input_dims, InferenceEngine::Layout::ANY);
Blob::Ptr input_blob = make_blob_with_precision(input_desc, gna_output_desc.ptrs[request_idx]);

const SizeVector& output_dims = model->get_result()->get_shape();
TensorDesc output_desc(preproc_prc, output_dims, InferenceEngine::Layout::ANY);
gna_output_blob = make_blob_with_precision(output_desc);
gna_output_blob->allocate();

PrePostProcess(input_blob, gna_output_blob, model);
} else {
log::debug() << "Postprocessing for output " << output_name << " is not required" << std::endl;
TensorDesc output_desc(gna_output_desc.tensor_precision,
gna_output_desc.dims,
gna_output_desc.model_layout);
gna_output_blob = make_blob_with_precision(output_desc, gna_output_desc.ptrs[request_idx]);
}

ExportScores(outputBlob->buffer(),
outputDesc.ptrs[request_idx],
outputDesc.orientation,
ExportScores(output_blob->buffer(),
gna_output_blob->cbuffer(),
gna_output_desc.orientation,
batchSize,
batchSize,
elementsPerBatch,
elementsPerBatch,
elementsPerBatch,
outputDesc.tensor_precision,
outputDesc.model_precision);
gna_output_desc.tensor_precision,
gna_output_desc.model_precision);

if (gnadevice) {
#ifdef PLOT
@@ -1295,11 +1354,11 @@ RequestStatus GNAPlugin::WaitFor(uint32_t request_idx, int64_t millisTimeout) {
num_infers++;
if (f) {
if (isScalar) {
fprintf(f, "%d ", outputBlob->cbuffer().as<int32_t*>()[0]);
fprintf(f, "%d ", output_blob->cbuffer().as<int32_t*>()[0]);
} else {
for (int i = 0; i < batchSize; i++) {
for (int j = 0; j < dims[dims.size() - 1]; j++) {
fprintf(f, "%d ", outputBlob->cbuffer().as<int32_t*>()[dims[dims.size() - 1] * i + j]);
fprintf(f, "%d ", output_blob->cbuffer().as<int32_t*>()[dims[dims.size() - 1] * i + j]);
}
fprintf(f, "\n");
}
@@ -1307,25 +1366,25 @@ RequestStatus GNAPlugin::WaitFor(uint32_t request_idx, int64_t millisTimeout) {
fprintf(f, "\n\n");
}
#endif
switch (outputBlob->getTensorDesc().getPrecision()) {
switch (output_blob->getTensorDesc().getPrecision()) {
case InferenceEngine::Precision::FP32:
UnscaleAndCast(outputBlob->buffer().as<float*>(),
outputBlob->buffer().as<int32_t*>(),
UnscaleAndCast(output_blob->buffer().as<float*>(),
output_blob->buffer().as<int32_t*>(),
elementsPerBatch,
batchSize,
outputDesc.scale_factor);
gna_output_desc.scale_factor);
break;

case InferenceEngine::Precision::I32:
UnscaleAndCast(outputBlob->buffer().as<int32_t*>(),
outputBlob->buffer().as<int32_t*>(),
UnscaleAndCast(output_blob->buffer().as<int32_t*>(),
output_blob->buffer().as<int32_t*>(),
elementsPerBatch,
batchSize,
outputDesc.scale_factor);
gna_output_desc.scale_factor);
break;

default:
THROW_GNA_EXCEPTION << "Unsupported target precision: " << outputBlob->getTensorDesc().getPrecision()
THROW_GNA_EXCEPTION << "Unsupported target precision: " << output_blob->getTensorDesc().getPrecision()
<< std::endl;
break;
}
@@ -1333,12 +1392,12 @@ RequestStatus GNAPlugin::WaitFor(uint32_t request_idx, int64_t millisTimeout) {
#ifdef PLOT
if (f) {
if (isScalar) {
fprintf(f, "%.7f ", outputBlob->cbuffer().as<float*>()[0]);
fprintf(f, "%.7f ", output_blob->cbuffer().as<float*>()[0]);
} else {
auto dims = outputBlob->getTensorDesc().getDims();
auto dims = output_blob->getTensorDesc().getDims();
for (int i = 0; i < batchSize; i++) {
for (int j = 0; j < dims[dims.size() - 1]; j++) {
fprintf(f, "%.7f ", outputBlob->cbuffer().as<float*>()[dims[dims.size() - 1] * i + j]);
fprintf(f, "%.7f ", output_blob->cbuffer().as<float*>()[dims[dims.size() - 1] * i + j]);
}
fprintf(f, "\n");
}
@@ -1503,6 +1562,14 @@ InferenceEngine::IExecutableNetworkInternal::Ptr GNAPlugin::ImportNetwork(std::i
}
}

// Support model versions <= 2.8
if (!transpose_inputs_info.empty()) {
ConvertTransposeMapToModel(transpose_inputs_info, inputs_ptr_->Get());
}
if (!transpose_outputs_info.empty()) {
ConvertTransposeMapToModel(transpose_outputs_info, outputs_.Get());
}

for (auto&& memory : mt) {
GNAMemoryLayer memoryLayer(nullptr, nullptr, gnaFlags->sw_fp32 ? 4 : 2);
std::string name;

@@ -27,11 +27,11 @@
#include "gna_plugin_config.hpp"
#include "log/debug.hpp"
#include "log/log.hpp"
#include "pre_post_process/transposition_info.hpp"

namespace ov {
namespace intel_gna {
namespace request {

class ModelWrapper;
class WorkerPool;
class Worker;
@@ -51,8 +51,11 @@ protected:
GNAGraphCompiler graphCompiler;

uint32_t activeLayerIndex = 0xffffffff;
TranspositionInfoMap transpose_inputs_info;
TranspositionInfoMap transpose_outputs_info;
// TODO: transpose_inputs_info and transpose_outputs_info should be moved to GNAModelSerial class when ngraph
// migration is finished. Those structures are needed to support the exported models <= 2.8.
pre_post_processing::TranspositionInfoMap transpose_inputs_info;
pre_post_processing::TranspositionInfoMap transpose_outputs_info;
PrePostProcessModels m_input_output_subgraphs;

uint32_t dnn_dump_write_index = 0;
intel_dnn_number_type_t output_type = kDnnInt;
@@ -188,6 +191,17 @@ protected:
void InitGNADevice();

void DumpXNNToFile() const;
/**
* @brief Run ngraph model on CPU to modify input or output (transposing, gathering)
* Method supports only models with 1 input and 1 output.
* @param input_blob input blob memory
* @param output_blob output blob memory
* @param model ngraph function needs to be executed to modify input blob and put result to the output blob
* @return void
*/
void PrePostProcess(InferenceEngine::Blob::Ptr input_blob,
InferenceEngine::Blob::Ptr output_blob,
std::shared_ptr<ov::Model> model);

void ImportFrames(void* ptr_dst,
const void* ptr_src,

@@ -57,7 +57,8 @@
namespace ov {
namespace intel_gna {

void TransformationsPipeline::apply(const std::shared_ptr<ov::Model>& model) {
void TransformationsPipeline::apply(const std::shared_ptr<ov::Model>& model,
ov::intel_gna::PrePostProcessModels* subgraph_cpu_map) {
OV_ITT_SCOPED_TASK(itt::domains::GNAPlugin, "TransformationsPipeline::apply");

fake_quantized = ov::op::util::has_op_with_type<ngraph::op::FakeQuantize>(model);

@@ -7,6 +7,7 @@
#include <memory>

#include "cpp/ie_cnn_network.h"
#include "gna_data_types.hpp"
#include "gna_plugin_config.hpp"
#include "openvino/core/model.hpp"

@@ -18,7 +19,8 @@ public:
explicit TransformationsPipeline(const Config& config) : config(config) {
effective_compile_target = config.target->get_effective_compile_target();
}
void apply(const std::shared_ptr<ov::Model>& model);
void apply(const std::shared_ptr<ov::Model>& model,
ov::intel_gna::PrePostProcessModels* subgraph_cpu_map = nullptr);
IE_SUPPRESS_DEPRECATED_START
void apply_legacy(const InferenceEngine::CNNNetwork& network, bool runBeforeCopy);
void convert_precision_legacy(InferenceEngine::CNNNetwork& network);

@@ -7,7 +7,7 @@
#include "frontend/quantized_layer_params.hpp"
#include "gna_graph_tools.hpp"
#include "ie_layouts.h"
#include "preprocessing.hpp"
#include "pre_post_process/preprocessing.hpp"

namespace ov {
namespace intel_gna {
@@ -70,11 +70,11 @@ void GNAVariableState::SetState(const InferenceEngine::Blob::Ptr& newState) {
auto quantized =
InferenceEngine::getInjectedData<ov::intel_gna::frontend::QuantizedLayerParams>(state->getInput());
auto scale_factor = quantized != nullptr ? quantized->_dst_quant.GetScale() : state->scale_factor;
ConvertToInt16(static_cast<int16_t*>(state->gna_ptr),
newState->buffer().as<float*>(),
1,
data_elements,
scale_factor);
pre_post_processing::ConvertToInt16(static_cast<int16_t*>(state->gna_ptr),
newState->buffer().as<float*>(),
1,
data_elements,
scale_factor);
} else {
THROW_GNA_EXCEPTION
<< "Failed to SetState for VariableState " << name

@@ -43,11 +43,13 @@
#include "layers/gna_layer_info.hpp"
#include "log/debug.hpp"
#include "log/log.hpp"
#include "pre_post_process/transposition_info.hpp"

using namespace InferenceEngine;
using namespace InferenceEngine::details;
using namespace ov::intel_gna::frontend;
using namespace ov::intel_gna::common;
using namespace ov::intel_gna::pre_post_processing;

namespace ov {
namespace intel_gna {

@@ -6,6 +6,7 @@

namespace ov {
namespace intel_gna {
namespace pre_post_processing {

int16_t ConvertFloatToInt16(float src) {
float rounding_value = (src > 0) ? 0.5f : -0.5f;
@@ -42,5 +43,6 @@ void ConvertToInt16(int16_t* ptr_dst,
}
}

} // namespace pre_post_processing
} // namespace intel_gna
} // namespace ov
@@ -8,6 +8,7 @@

namespace ov {
namespace intel_gna {
namespace pre_post_processing {

void ConvertToInt16(int16_t* ptr_dst,
const float* ptr_src,
@@ -36,5 +37,6 @@ inline void UnscaleAndCast(T2* ptr_dst,
}
}

} // namespace pre_post_processing
} // namespace intel_gna
} // namespace ov
} // namespace ov
@@ -0,0 +1,101 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transposition_info.hpp"

#include <vector>

#include "log/debug.hpp"
#include "openvino/core/model.hpp"
#include "openvino/core/shape.hpp"
#include "openvino/opsets/opset10.hpp"

namespace ov {
namespace intel_gna {
namespace pre_post_processing {

using namespace ov::opset10;

std::shared_ptr<ov::Model> ToProcessModel(const TranspositionInfo& t_info) {
int32_t c_size = t_info.num_transpose_rows;
int32_t hw_size = t_info.num_transpose_columns;

if (!t_info.transpose) {
return nullptr;
}

ov::PartialShape input_shape{1, c_size, hw_size};
auto param = std::make_shared<Parameter>(ov::element::f32, input_shape);

// legacy way was to swap C and HW dimensions in the reshaped tensor
std::vector<int32_t> reshape_pattern{-1, c_size, hw_size};
auto reshape_const =
std::make_shared<Constant>(ov::element::i32, ov::Shape{reshape_pattern.size()}, reshape_pattern);
auto reshape = std::make_shared<Reshape>(param, reshape_const, false);

// CHW -> HWC or HWC -> CHW
std::vector<int8_t> transpose_order{0, 2, 1};
auto transpose_const =
std::make_shared<Constant>(ov::element::i8, ov::Shape{transpose_order.size()}, transpose_order);
auto transpose = std::make_shared<Transpose>(reshape, transpose_const);

auto result = std::make_shared<Result>(transpose);

std::shared_ptr<ov::Model> model =
std::make_shared<ov::Model>(ov::ResultVector{result}, ov::ParameterVector{param});

return model;
}

std::shared_ptr<ov::Model> ToProcessModel(const std::vector<TranspositionInfo>& transposes) {
// count transposition parts need to be transposed
int count_transposes = std::count_if(transposes.begin(), transposes.end(), [](TranspositionInfo t_info) {
return t_info.transpose || t_info.num_transpose_rows != 1 || t_info.num_transpose_columns != 1;
});
if (count_transposes == 0) {
return nullptr;
}

// case when the input should be transposed entirely
if (transposes.size() == 1) {
return ToProcessModel(transposes.front());
}

std::vector<int32_t> indexes = {};
for (auto& transpose : transposes) {
size_t c_size = transpose.num_transpose_rows;
size_t hw_size = transpose.num_transpose_columns;
if (c_size == 0 || hw_size == 0) {
THROW_GNA_EXCEPTION << "Incorrect transposition dimentions";
}
size_t chw_size = c_size * hw_size;
size_t id = indexes.size();
for (size_t i{0}; i < chw_size; ++i) {
size_t idx = (transpose.transpose) ? hw_size * (i % c_size) + i / c_size : i;
indexes.push_back(id + idx);
}
}

auto param = std::make_shared<Parameter>(ov::element::f32, ov::Shape{1, indexes.size()});
// legacy way was to swap C and HW dimensions in the reshaped tensor
std::vector<int32_t> reshape_pattern{-1, static_cast<int32_t>(indexes.size())};
auto reshape_const =
std::make_shared<Constant>(ov::element::i32, ov::Shape{reshape_pattern.size()}, reshape_pattern);
auto reshape = std::make_shared<Reshape>(param, reshape_const, false);

// CHW -> HWC or HWC -> CHW
auto gather_indexes = std::make_shared<Constant>(ov::element::i32, ov::Shape{indexes.size()}, indexes);
auto gather_axis = std::make_shared<Constant>(ov::element::i8, ov::Shape{1}, std::vector<int8_t>{1});
auto gather = std::make_shared<Gather>(reshape, gather_indexes, gather_axis);

auto result = std::make_shared<Result>(gather);

std::shared_ptr<ov::Model> model =
std::make_shared<ov::Model>(ov::ResultVector{result}, ov::ParameterVector{param});

return model;
}
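
Illustration (not part of the commit): the Gather branch above flattens all fragments and permutes them with the index formula hw_size * (i % c_size) + i / c_size, which turns channel-major (NCHW-like) data into interleaved (NHWC-like) order within each fragment. A tiny self-contained check of that formula:

// Sketch: verify the permutation used to build the Gather indices for one fragment.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
    const std::size_t c_size = 2, hw_size = 3;
    std::vector<int> nchw = {0, 1, 2, 10, 11, 12};  // two channels, three spatial positions
    std::vector<int> nhwc(nchw.size());
    for (std::size_t i = 0; i < nchw.size(); ++i) {
        const std::size_t idx = hw_size * (i % c_size) + i / c_size;
        nhwc[i] = nchw[idx];
    }
    for (int v : nhwc) std::cout << v << ' ';  // prints: 0 10 1 11 2 12 (channels interleaved)
    return 0;
}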

} // namespace pre_post_processing
} // namespace intel_gna
} // namespace ov
@@ -0,0 +1,60 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <cstdint>
#include <vector>

#include "openvino/core/model.hpp"

namespace ov {
namespace intel_gna {
namespace pre_post_processing {

struct TranspositionInfo {
bool transpose;
size_t num_transpose_rows;
size_t num_transpose_columns;
};

using TranspositionInfoMap = std::map<std::string, std::vector<TranspositionInfo>>;

/*
* Converts TranspositionInfo struct to ngraph function.
* This method creates ngraph function with Transpose layer.
*/
std::shared_ptr<ov::Model> ToProcessModel(const TranspositionInfo& t_info);
/*
* Converts several TranspositionInfo structures to ngraph function.
* This method creates ngraph function with Gather layer.
*/
std::shared_ptr<ov::Model> ToProcessModel(const std::vector<TranspositionInfo>& transposes);

/*
* Converts transposition maps to ngraph model, which will be ran on CPU as pre/post-processing step.
* This conversion is needed to support the exported models version <= 2.8 (OV < 2023.0)
* @return
*/
template <class T1, class T2>
void ConvertTransposeMapToModel(T1& transposes, T2& nodes) {
for (auto&& node : nodes) {
auto t_it = transposes.find(node.name);
if (t_it != transposes.end() && !t_it->second.empty()) {
node.pre_post_process_model = ToProcessModel(t_it->second);
}
}
};

static inline bool FoundPartToTranspose(const std::vector<TranspositionInfo>& transposes) {
auto part_to_transpose =
std::find_if(std::begin(transposes), std::end(transposes), [](const TranspositionInfo& t_info) {
return t_info.transpose;
});
return part_to_transpose != std::end(transposes);
}

} // namespace pre_post_processing
} // namespace intel_gna
} // namespace ov
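
Illustration (hypothetical usage sketch, not from the commit): ConvertTransposeMapToModel works on any container of descriptors exposing `name` and `pre_post_process_model`; the FakeNode type below is an illustrative stand-in for GnaDesc.

// Sketch, assuming the header above is available as "pre_post_process/transposition_info.hpp".
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "openvino/core/model.hpp"
#include "pre_post_process/transposition_info.hpp"

struct FakeNode {  // stand-in for GnaDesc: same two fields used by the template
    std::string name;
    std::shared_ptr<ov::Model> pre_post_process_model;
};

void example() {
    using namespace ov::intel_gna::pre_post_processing;
    TranspositionInfoMap transposes = {{"input_1", {TranspositionInfo{true, 3, 16}}}};
    std::vector<FakeNode> nodes = {{"input_1", nullptr}, {"input_2", nullptr}};
    ConvertTransposeMapToModel(transposes, nodes);
    // nodes[0].pre_post_process_model now holds the Reshape->Transpose model; nodes[1] is untouched.
}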

@@ -16,18 +16,19 @@
# include <malloc.h>
#else
# include <mm_malloc.h>

# include <serial/headers/2dot2/gna_model_header.hpp>
# include <serial/headers/2dot5/gna_model_header.hpp>
# include <serial/headers/2dot7/gna_model_header.hpp>
# include <serial/headers/2dot8/gna_model_header.hpp>

#endif

#include "common/versioning.hpp"
#include "gna2_model_helper.hpp"
#include "gna_model_serial.hpp"
#include "gna_plugin.hpp"
#include "openvino/pass/serialize.hpp"
#include "openvino/runtime/core.hpp"
#include "serial/headers/2dot2/gna_model_header.hpp"
#include "serial/headers/2dot5/gna_model_header.hpp"
#include "serial/headers/2dot7/gna_model_header.hpp"
#include "serial/headers/2dot8/gna_model_header.hpp"
#include "serial/headers/2dot9/gna_model_header.hpp"
#include "serial/headers/latest/gna_model_header.hpp"

using namespace ov::intel_gna;
@@ -48,6 +49,25 @@ inline void writeString(const std::string& str, std::ostream& os) {
writeNBytes(c_str, str_len, os);
}

inline void write_pre_processing_model(const std::shared_ptr<ov::Model>& model, std::ostream& os) {
// allocate buffer for ir.xml
std::ostringstream xml_buf;
// allocate buffer for ir.bin
std::ostringstream bin_buf;

// serialize IR to stream buffer (.xml + .bin)
ov::pass::Serialize serializer(xml_buf, bin_buf);
serializer.run_on_model(model);

// write IR
writeString(xml_buf.str(), os);

// write BIN
size_t ir_bin_size = bin_buf.str().size();
writeBits(ir_bin_size, os);
writeNBytes(bin_buf.str().c_str(), ir_bin_size, os);
}
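
Illustration (sketch under assumptions, not code from this commit): the exported blob embeds the pre/post-processing model as an in-memory IR pair (xml string plus weights), so it can be round-tripped with ov::pass::Serialize on export and ov::Core::read_model on import, mirroring write_pre_processing_model above and the read path added to ImportNodes below.

// Sketch: serialize a model to in-memory IR buffers and restore it from them.
#include <openvino/openvino.hpp>
#include <openvino/pass/serialize.hpp>
#include <cstring>
#include <sstream>

std::shared_ptr<ov::Model> roundtrip(const std::shared_ptr<ov::Model>& model) {
    std::ostringstream xml_buf, bin_buf;
    ov::pass::Serialize(xml_buf, bin_buf).run_on_model(model);  // IR -> two string buffers

    const std::string bin = bin_buf.str();
    ov::Tensor weights(ov::element::u8, ov::Shape{bin.size()});
    std::memcpy(weights.data(), bin.data(), bin.size());

    ov::Core core;
    return core.read_model(xml_buf.str(), weights);  // same call the v2.9 importer relies on
}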

template <class T>
inline void readBits(T& obj, std::istream& is) {
is.read(reinterpret_cast<char*>(&obj), sizeof(T));
@@ -168,11 +188,12 @@ header_latest::ModelHeader GNAModelSerial::ReadHeader(std::istream& is) {
case 6:
case 7:
case 8:
case 9:
readNBytes(&header, sizeof(header_latest::ModelHeader), is);
break;
default:
THROW_GNA_EXCEPTION
<< "Imported file unsupported. minor version should have values in range 1 to 8 and is: "
<< "Imported file unsupported. minor version should have values in range 1 to 9 and is: "
<< header.version.minor;
}
break;
@@ -217,11 +238,12 @@ header_latest::RuntimeEndPoint GNAModelSerial::ReadEndPoint(std::istream& is) {
break;
}
case 8:
case 9:
readNBytes(&endPoint, sizeof(header_latest::RuntimeEndPoint), is);
break;
default:
THROW_GNA_EXCEPTION
<< "Imported file unsupported. minor version should have values in range 1 to 8 and is: "
<< "Imported file unsupported. minor version should have values in range 1 to 9 and is: "
<< model_header_.version.minor;
}
break;
@@ -269,7 +291,8 @@ void GNAModelSerial::Import(void* basePointer,
(model_header_.version.minor >= 3) ? readString(is) : std::string("input" + std::to_string(inputIndex));
inputs[name] = InputDesc(name);
}
if (model_header_.version.minor >= 5) {
// Plugin uses ngraph pre/post-processing function to transpose inputs/outputs starting from version 2.9
if (model_header_.version.minor >= 5 && model_header_.version.minor <= 8) {
// 3. Read transposition input info
for (int inputIx = 0; inputIx < model_header_.nTransposeInputs; ++inputIx) {
std::string inputName;
@@ -287,7 +310,7 @@ void GNAModelSerial::Import(void* basePointer,
}
}
// 5. Read Inputs endpoints
ImportInputs(is, basePointer, inputs);
ImportNodes(is, basePointer, inputs);
// 6. Read output names
if (model_header_.version.major == 2) {
for (auto outputIndex = 0; outputIndex < model_header_.nOutputs; outputIndex++) {
@@ -297,7 +320,7 @@ void GNAModelSerial::Import(void* basePointer,
}
}
// 7. Read outputs
ImportOutputs(is, basePointer, outputs);
ImportNodes(is, basePointer, outputs);

for (auto operation = gna2model_->Operations; operation != gna2model_->Operations + gna2model_->NumberOfOperations;
++operation) {
@@ -463,10 +486,8 @@ void GNAModelSerial::Export(const GnaAllocations& allocations, std::ostream& os)
// Write the input name
writeString(input.name, os);
}
// 3. Write transposition input info
ExportTranspositionInfo(os, inputs_transpose_info_);
// 4. Write transposition output info
ExportTranspositionInfo(os, outputs_transpose_info_);
// 3. Write transposition input info - removed in v.2.9
// 4. Write transposition output info - removed in v.2.9
// 5. Write input endpoints and tensor names
for (const auto& input : inputs_.Get()) {
// write RuntimeEndPoint
@@ -475,6 +496,13 @@ void GNAModelSerial::Export(const GnaAllocations& allocations, std::ostream& os)
for (const auto& tname : input.tensor_names) {
writeString(tname, os);
}
// write pre-processing model
if (input.pre_post_process_model) {
write_pre_processing_model(input.pre_post_process_model, os);
} else {
// write empty string to detect that model is absent during the import
writeString("", os);
}
}
// 6. Write outputs names
for (auto& output : outputs_.Get()) {
@@ -489,6 +517,14 @@ void GNAModelSerial::Export(const GnaAllocations& allocations, std::ostream& os)
for (auto& tname : output.tensor_names) {
writeString(tname, os);
}

// write post-processing model
if (output.pre_post_process_model) {
write_pre_processing_model(output.pre_post_process_model, os);
} else {
// write empty string to detect that model is absent during the import
writeString("", os);
}
}
// 8. Write layers
for (const auto& layer : layers) {
@@ -563,61 +599,49 @@ void GNAModelSerial::Export(const GnaAllocations& allocations, std::ostream& os)
version_.Export(os);
}

void GNAModelSerial::ImportInputs(std::istream& is, void* basePtr, GnaInputs& inputs) {
for (auto& input : inputs.Get()) {
template <class T>
void GNAModelSerial::ImportNodes(std::istream& is, void* base_ptr, T& nodes) {
for (auto& node : nodes.Get()) {
header_latest::RuntimeEndPoint ep = ReadEndPoint(is);

input.ptrs.push_back(reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(basePtr) + ep.descriptor_offset));
input.orientation = ep.orientation;
input.num_elements = ep.elements_count;
input.scale_factor = ep.scaleFactor;
input.model_precision =
node.ptrs.push_back(reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(base_ptr) + ep.descriptor_offset));
node.orientation = ep.orientation;
node.num_elements = ep.elements_count;
node.scale_factor = ep.scaleFactor;
node.model_precision =
InferenceEngine::Precision(static_cast<InferenceEngine::Precision::ePrecision>(ep.precision));
input.set_precision(ep.element_size);
input.model_layout = static_cast<InferenceEngine::Layout>(ep.layout);
input.allocated_size = input.get_required_size();
node.set_precision(ep.element_size);
node.model_layout = static_cast<InferenceEngine::Layout>(ep.layout);
node.allocated_size = node.get_required_size();

auto inputDims = InferenceEngine::SizeVector();
for (auto i = 0; i < ep.shape.NumberOfDimensions; ++i) {
inputDims.push_back(ep.shape.Dimensions[i]);
}
input.dims = inputDims;
node.dims = inputDims;

// read tensor names
for (uint8_t tId = 0; tId < ep.tensor_names_count; ++tId) {
input.tensor_names.insert(readString(is));
node.tensor_names.insert(readString(is));
}
AppendTensorNameIfNeeded(node);

AppendTensorNameIfNeeded(input);
}
}
// read pre-sprocessing model
if (model_header_.version.major == 2 && model_header_.version.minor >= 9) {
std::string ir_xml_str = readString(is);
if (!ir_xml_str.empty()) {
// read IR bin
size_t ir_bin_size = 0;
readBits(ir_bin_size, is);

void GNAModelSerial::ImportOutputs(std::istream& is, void* basePtr, GnaOutputs& outputs) {
for (auto& output : outputs.Get()) {
header_latest::RuntimeEndPoint ep = ReadEndPoint(is);
ov::Tensor ir_bin_tensor(ov::element::u8, ov::Shape({ir_bin_size}));
readNBytes(ir_bin_tensor.data(), ir_bin_size, is);

output.ptrs.push_back(reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(basePtr) + ep.descriptor_offset));
output.orientation = ep.orientation;
output.num_elements = ep.elements_count;
output.scale_factor = ep.scaleFactor;
output.set_precision(ep.element_size);
output.model_precision =
InferenceEngine::Precision(static_cast<InferenceEngine::Precision::ePrecision>(ep.precision));
output.model_layout = static_cast<InferenceEngine::Layout>(ep.layout);
output.allocated_size = output.get_required_size();

auto outputDims = InferenceEngine::SizeVector();
for (auto i = 0; i < ep.shape.NumberOfDimensions; ++i) {
outputDims.push_back(ep.shape.Dimensions[i]);
// restore model
ov::Core core;
node.pre_post_process_model = core.read_model(ir_xml_str, ir_bin_tensor);
}
}
output.dims = outputDims;

// read tensor names
for (uint8_t tId = 0; tId < ep.tensor_names_count; ++tId) {
output.tensor_names.insert(readString(is));
}

AppendTensorNameIfNeeded(output);
}
}

@@ -637,19 +661,6 @@ void GNAModelSerial::ImportTranspositionInfo(std::istream& is,
}
}

void GNAModelSerial::ExportTranspositionInfo(std::ostream& os, const TranspositionInfoMap& transpositionInfoMap) const {
for (const auto& transpositionInfo : transpositionInfoMap) {
auto nameSize = strlen(transpositionInfo.first.c_str());
writeBits(static_cast<uint32_t>(nameSize), os);
writeNBytes(transpositionInfo.first.c_str(), nameSize, os);
auto fragmentsNum = transpositionInfo.second.size();
writeBits(static_cast<uint32_t>(fragmentsNum), os);
for (const auto& transposeFragmentInfo : transpositionInfo.second) {
writeNBytes(&transposeFragmentInfo, sizeof(TranspositionInfo), os);
}
}
}

void GNAModelSerial::AppendTensorNameIfNeeded(GnaDesc& nodeDesc) const {
static constexpr header_2_dot_8::ModelHeader::Version kHasTensorNamesVersion;

@@ -12,8 +12,15 @@
#include "descriptions/gna_desc.hpp"
#include "gna2-model-api.h"
#include "gna_device_allocation.hpp"
#include "pre_post_process/transposition_info.hpp"
#include "serial/headers/latest/gna_model_header.hpp"

namespace ov {
namespace intel_gna {

using TranspositionInfo = pre_post_processing::TranspositionInfo;
using TranspositionInfoMap = pre_post_processing::TranspositionInfoMap;

/**
* @brief helper class for GNAGraph serialization tasks
*/
@@ -40,16 +47,13 @@ private:
ov::intel_gna::header_latest::ModelHeader model_header_;
GNAVersionSerializer version_;

void ImportInputs(std::istream& is, void* basePtr, ov::intel_gna::GnaInputs& inputs);

void ImportOutputs(std::istream& is, void* basePtr, ov::intel_gna::GnaOutputs& outputs);
template <class T>
void ImportNodes(std::istream& is, void* basePtr, T& inputs); // inputs or outputs

void ImportTranspositionInfo(std::istream& is,
std::string& name,
std::vector<TranspositionInfo>& transpositionInfo);

void ExportTranspositionInfo(std::ostream& os, const TranspositionInfoMap& transpositionInfoMap) const;

/**
* @brief Update input or output description to support importing of < 2.8 format where tensor_names were not
* present
@@ -126,3 +130,6 @@ public:
*/
void Export(const GnaAllocations& allocations, std::ostream& os) const;
};

} // namespace intel_gna
} // namespace ov
@@ -176,7 +176,7 @@ struct RuntimeEndPoint {
*this = header_2_dot_8::RuntimeEndPoint(ep_v7);
}

RuntimeEndPoint(header_2_dot_7::RuntimeEndPoint& old) {
RuntimeEndPoint(const header_2_dot_7::RuntimeEndPoint& old) {
scaleFactor = old.scaleFactor;
descriptor_ptr = old.descriptor_ptr;
element_size = old.element_size;

@@ -0,0 +1,223 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <cstdint>
#include <map>

#include "backend/dnn_types.hpp"
#include "gna_data_types.hpp"
#include "serial/headers/2dot8/gna_model_header.hpp"

#pragma pack(push, 1)

namespace ov {
namespace intel_gna {
namespace header_2_dot_9 {

/**
Maximal number of supported shape dimensions.
*/
#define GNA_SHAPE_MAXIMUM_NUMBER_OF_DIMENSIONS 8

/**
* @brief Header version 2.9
*/
struct ModelHeader {
/**
*@brief MagicNumber – GNAM in ascii table, equals to hex 0x474e414d
*/
char gnam[4] = {};
/**
* @brief if header size is not equal to sizeof ModelHeader - some reserved data append in the end of header
* usually it is an indicator of working with version of model different that is current export function produce
*/
uint32_t headerSize = 0u;
struct Version {
/**
* @details Version of format Major – unsigned int, ex: 0x0001
* every change in the header or in the layers definition should be reflected in version change
* for backward compatibility new parsers can read old versions of model with certain restrictions
*/
uint16_t major = 2u;
/**
* @details Version of Format Minor – unsigned int, corresponding to build revision for example
* changes in minor version are not affected layout of model
*/
uint32_t minor = 9u;
} version;
/**
* @brief Memory required to be allocated using GNAAlloc()
*/
uint64_t gnaMemSize = 0ull;
/**
* @brief Number of GNA Layers
*/
uint64_t layersCount = 0ull;
/**
* @brief Grouping level
* This is depricted field and used for old models only (<=2.6)
*/
uint32_t nGroup = 0u;

/**
* Convolution related setting - they are affecting input transformation
*/
uint32_t nRotateRows = 0u;
uint32_t nRotateColumns = 0u;
bool doRotateInput = false;

uint32_t nInputs = 0u;
uint32_t nOutputs = 0u;

/**
* Convolution related setting - they are affecting output transformation
*/
uint32_t nRotateOutputRows = 0u;
uint32_t nRotateOutputColumns = 0u;
bool doRotateOutput = false;

uint32_t nTransposeInputs = 0u;
uint32_t nTransposeOutputs = 0u;

/**
* Reserved Data might be here
*/

ModelHeader() = default;

ModelHeader(header_2_dot_1::ModelHeader const& old) {
gnaMemSize = old.gnaMemSize;
layersCount = old.layersCount;
nGroup = old.nGroup;
nRotateRows = old.nRotateRows;
nRotateColumns = old.nRotateColumns;
nInputs = old.nInputs;
nOutputs = old.nOutputs;
version.minor = old.version.minor;
}

ModelHeader(header_2_dot_4::ModelHeader const& old) {
gnaMemSize = old.gnaMemSize;
layersCount = old.layersCount;
nGroup = old.nGroup;
nRotateRows = old.nRotateRows;
nRotateColumns = old.nRotateColumns;
nInputs = old.nInputs;
nOutputs = old.nOutputs;
nRotateOutputRows = old.nRotateOutputRows;
nRotateOutputColumns = old.nRotateOutputColumns;
doRotateOutput = old.doRotateOutput;
version.minor = old.version.minor;
}
};
#pragma pack(pop)

/*
* In runtime endpoint mostly same as in serial version, except of descriptor field
*/
struct RuntimeEndPoint {
/**
* if scale factor is different then pased into infer , network might need to be requantized
*/
float scaleFactor = 0;
/**
* Pointer descriptor
*/
void* descriptor_ptr = nullptr;
/**
* Endpoint resolution in bytes.
*/
uint32_t element_size = 0;
/**
* Number of elements
*/
uint32_t elements_count = 0;
/**
* Offset in bytes of pointer descriptor
*/
uint64_t descriptor_offset = 0ull;
/**
Shape specifying dimension values.
*/
struct Shape {
/**
Number of dimensions or rank or order.
*/
uint32_t NumberOfDimensions = 0;
/**
array specifying value of each dimension.
Set all zeros for scalars.
*/
uint32_t Dimensions[GNA_SHAPE_MAXIMUM_NUMBER_OF_DIMENSIONS] = {0};
} shape;
/**
* Blob layout
*/
uint8_t layout = InferenceEngine::Layout::NC;
/**
* Blob precision
*/
uint8_t precision = InferenceEngine::Precision::FP32;
/**
* Number of tensor names
*/
uint8_t tensor_names_count = 0;

intel_dnn_orientation_t orientation = kDnnUnknownOrientation;

RuntimeEndPoint() = default;

// support of previous versions
RuntimeEndPoint(const header_2_dot_6::RuntimeEndPoint& old, uint32_t ngroup) {
header_2_dot_7::RuntimeEndPoint ep_v7 = header_2_dot_7::RuntimeEndPoint(old, ngroup);
*this = header_2_dot_9::RuntimeEndPoint(ep_v7);
}

// support of previous versions
RuntimeEndPoint(const header_2_dot_7::RuntimeEndPoint& old) {
header_2_dot_8::RuntimeEndPoint ep_v8 = header_2_dot_8::RuntimeEndPoint(old);
*this = header_2_dot_9::RuntimeEndPoint(ep_v8);
}

RuntimeEndPoint(header_2_dot_8::RuntimeEndPoint& old) {
scaleFactor = old.scaleFactor;
descriptor_ptr = old.descriptor_ptr;
element_size = old.element_size;
elements_count = old.elements_count;
orientation = old.orientation;
layout = old.layout;
precision = old.precision;
descriptor_offset = old.descriptor_offset;
shape.NumberOfDimensions = old.shape.NumberOfDimensions;
for (uint32_t i = 0; i < shape.NumberOfDimensions; i++) {
shape.Dimensions[i] = old.shape.Dimensions[i];
}
tensor_names_count = 0;
}

RuntimeEndPoint(double scaleFactor,
void* descriptor_ptr,
uint32_t element_size,
uint32_t elements_count,
Shape shape,
uint8_t layout,
uint8_t precision,
uint8_t tensor_names_count,
intel_dnn_orientation_t orientation)
: scaleFactor(static_cast<float>(scaleFactor)),
descriptor_ptr(descriptor_ptr),
element_size(element_size),
elements_count(elements_count),
shape(shape),
layout(layout),
precision(precision),
tensor_names_count(tensor_names_count),
orientation(orientation) {}
};

} // namespace header_2_dot_9
} // namespace intel_gna
} // namespace ov
@@ -4,14 +4,14 @@

#pragma once

#include "serial/headers/2dot8/gna_model_header.hpp"
#include "serial/headers/2dot9/gna_model_header.hpp"

namespace ov {
namespace intel_gna {
namespace header_latest {

using ModelHeader = header_2_dot_8::ModelHeader;
using RuntimeEndPoint = header_2_dot_8::RuntimeEndPoint;
using ModelHeader = header_2_dot_9::ModelHeader;
using RuntimeEndPoint = header_2_dot_9::RuntimeEndPoint;

template <typename A, typename B>
bool IsFirstVersionLower(const A& first, const B& second) {

@@ -8,7 +8,7 @@
// to suppress deprecated definition errors
#define IMPLEMENT_INFERENCE_ENGINE_PLUGIN
#include "common/versioning.hpp"
#include "gna_model_serial.hpp"
#include "serial/gna_model_serial.hpp"

using namespace testing;