[GNA] Fixed incorrect output layout for convolution (#2426)

* [GNA] add permute for NCHW Convolution H=1

* [GNA] fix permute after convolution case

* [GNA] fix Reshape1DOps transformation case

* [GNA] fix rm_cnn4a tests

* [GNA] fix wsj model tests

* quick fix

* [GNA] fix sw_exact mode

* [GNA] Add rotateOutput for convolution op instead of permute

* [GNA] fix CI

* [GNA] apply changes from review
This commit is contained in:
Anna Alberska 2020-11-23 10:23:57 +01:00 committed by GitHub
parent 55dd8b0a2d
commit f86065ce7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 435 additions and 29 deletions

View File

@ -29,8 +29,11 @@ public:
num_left_context(0), num_left_context(0),
num_right_context(0), num_right_context(0),
do_rotate_input(false), do_rotate_input(false),
do_rotate_output(false),
num_rotate_rows(0), num_rotate_rows(0),
num_rotate_columns(0), num_rotate_columns(0),
num_rotate_output_rows(0),
num_rotate_output_columns(0),
softmax_type(kSoftmaxNone), softmax_type(kSoftmaxNone),
ptr_sumgroup_sizes(NULL), ptr_sumgroup_sizes(NULL),
num_sumgroup_sizes(0), num_sumgroup_sizes(0),
@ -309,8 +312,11 @@ public:
uint32_t num_right_context; uint32_t num_right_context;
uint32_t new_num_conv_columns = 0; uint32_t new_num_conv_columns = 0;
bool do_rotate_input; bool do_rotate_input;
bool do_rotate_output;
uint32_t num_rotate_rows = 0; uint32_t num_rotate_rows = 0;
uint32_t num_rotate_columns = 0; uint32_t num_rotate_columns = 0;
uint32_t num_rotate_output_rows = 0;
uint32_t num_rotate_output_columns = 0;
DnnSoftmaxType softmax_type; DnnSoftmaxType softmax_type;
uint32_t *ptr_sumgroup_sizes; uint32_t *ptr_sumgroup_sizes;
uint32_t num_sumgroup_sizes; uint32_t num_sumgroup_sizes;

View File

@ -240,9 +240,6 @@ void GNAGraphCompiler::ConvolutionPrimitive(InferenceEngine::CNNLayerPtr layer)
inputs->getLayout() != Layout::NC) { inputs->getLayout() != Layout::NC) {
THROW_GNA_LAYER_EXCEPTION(layer) << "with layout " << inputs->getLayout() << " isn't currently supported on GNA"; THROW_GNA_LAYER_EXCEPTION(layer) << "with layout " << inputs->getLayout() << " isn't currently supported on GNA";
} }
if (inputs->getLayout() != outputs->getLayout()) {
THROW_GNA_LAYER_EXCEPTION(layer) << "I/O layout mismatch: " << inputs->getLayout() << " vs " << outputs->getLayout();
}
auto in_order = getFromIRDimsOrderNCHW(inputs->getLayout()); auto in_order = getFromIRDimsOrderNCHW(inputs->getLayout());
auto in_batch = FROM_IR_DIM(inputs, in_order[0]); auto in_batch = FROM_IR_DIM(inputs, in_order[0]);
@ -432,6 +429,29 @@ void GNAGraphCompiler::ConvolutionPrimitive(InferenceEngine::CNNLayerPtr layer)
connectOutput(layer, ptr_outputs, num_data_bytes_out); connectOutput(layer, ptr_outputs, num_data_bytes_out);
// When there's a NCHW convolution as a last layer, the output needs to be transposed back to NCHW
// TODO: Jira: 43659 - the issue also appears when after conv there's an eltwise or activation
// For last layer or when next ones are only non functional, the data can be reordered when exporting scores
// For other cases inserting permute is required if data are reordered
auto isNonFunctional = [](CNNLayerPtr l) {
return LayerInfo(l).isNonFunctional();
};
if (getInputTo(layer->outData.front()).empty() || !CNNNetHasNextLayerSkipCertain(layer, 0, 0, isNonFunctional)) {
// if height dim and width dim both equal 1, the permute is not needed to return correct results
// if height dim doesn't equal 1, the case requires additional permute
auto inputDimsCheck = (in_channels != 1 ||
(in_height == 1 && in_width == 1) ||
in_height != 1);
//if kernel is pow of 2 and heigher than 8, then the issue doesn't appear
auto kernelCheck = convolution._kernel_x > 15 && !(convolution._kernel_x & (convolution._kernel_x - 1));
if (!inputDimsCheck && !kernelCheck) {
dnn->do_rotate_output = true;
dnn->num_rotate_output_rows = out_width;
dnn->num_rotate_output_columns = out_channels;
}
}
std::vector<uint8_t> transposedWeights; std::vector<uint8_t> transposedWeights;
for (uint32_t k = 0; k < convolution._out_depth; k++) { for (uint32_t k = 0; k < convolution._out_depth; k++) {
uint8_t * ptr_filt_current uint8_t * ptr_filt_current
@ -1544,7 +1564,7 @@ void GNAGraphCompiler::PWLPrimitive(InferenceEngine::CNNLayerPtr layer) {
if (dnn->new_num_conv_columns) { if (dnn->new_num_conv_columns) {
num_rows = dnn->new_num_conv_columns; num_rows = dnn->new_num_conv_columns;
if (inputs->getDims().size() == 4) num_rows /= FROM_IR_DIM(inputs, 3); if (inputs->getDims().size() == 4) num_rows /= num_columns;
dnn->new_num_conv_columns = 0; dnn->new_num_conv_columns = 0;
} }

View File

@ -107,14 +107,17 @@ GNAPluginNS::HeaderLatest::ModelHeader GNAModelSerial::ReadHeader(std::istream &
switch (header.version.minor) { switch (header.version.minor) {
case 1: case 1:
readBits(tempHeader2dot1, is); readBits(tempHeader2dot1, is);
header = Header2dot3::ModelHeader(tempHeader2dot1); header = HeaderLatest::ModelHeader(tempHeader2dot1);
break; break;
case 2: case 2:
case 3: case 3:
readBits(header, is); readNBytes(&header, sizeof(Header2dot3::ModelHeader), is);
break;
case 4:
readNBytes(&header, sizeof(Header2dot4::ModelHeader), is);
break; break;
default: default:
THROW_GNA_EXCEPTION << "Imported file unsupported. minor version should be equal to 1 or 2 and is: " << header.version.minor; THROW_GNA_EXCEPTION << "Imported file unsupported. minor version should have values in range 1 to 4 and is: " << header.version.minor;
} }
break; break;
default: default:
@ -331,7 +334,9 @@ void GNAModelSerial::Export(void * basePointer, size_t gnaGraphSize, std::ostrea
header.nRotateRows = nRotateRows; header.nRotateRows = nRotateRows;
header.nRotateColumns = nRotateColumns; header.nRotateColumns = nRotateColumns;
header.doRotateInput = doRotateInput; header.doRotateInput = doRotateInput;
header.nRotateOutputRows = nRotateOutputRows;
header.nRotateOutputColumns = nRotateOutputColumns;
header.doRotateOutput = doRotateOutput;
writeBits(header, os); writeBits(header, os);

View File

@ -37,6 +37,9 @@ private:
uint32_t nRotateRows = 0; uint32_t nRotateRows = 0;
uint32_t nRotateColumns = 0; uint32_t nRotateColumns = 0;
bool doRotateInput = false; bool doRotateInput = false;
uint32_t nRotateOutputRows = 0;
uint32_t nRotateOutputColumns = 0;
bool doRotateOutput = false;
MemoryType states, *pstates = nullptr; MemoryType states, *pstates = nullptr;
GNAPluginNS::HeaderLatest::ModelHeader modelHeader; GNAPluginNS::HeaderLatest::ModelHeader modelHeader;
@ -109,6 +112,13 @@ private:
return *this; return *this;
} }
GNAModelSerial& SetOutputRotation(uint32_t nRotateOutputRows, uint32_t nRotateOutputColumns, bool do_rotate_outputs) {
this->nRotateOutputColumns = nRotateOutputColumns;
this->nRotateOutputRows = nRotateOutputRows;
this->doRotateOutput = do_rotate_outputs;
return *this;
}
/** /**
* mark certain part of gna_blob as state (in future naming is possible) * mark certain part of gna_blob as state (in future naming is possible)
* @param descriptor_ptr * @param descriptor_ptr

View File

@ -879,6 +879,10 @@ void GNAPlugin::LoadNetwork(ICNNNetwork & _network) {
num_rotate_rows = dnn->num_rotate_rows; num_rotate_rows = dnn->num_rotate_rows;
num_rotate_columns = dnn->num_rotate_columns; num_rotate_columns = dnn->num_rotate_columns;
do_rotate_output = dnn->do_rotate_output;
num_rotate_output_rows = dnn->num_rotate_output_rows;
num_rotate_output_columns = dnn->num_rotate_output_columns;
DumpXNNToFile(); DumpXNNToFile();
#ifdef PLOT #ifdef PLOT
@ -1166,29 +1170,35 @@ GnaWaitStatus GNAPlugin::WaitFor(uint32_t request_idx, int64_t millisTimeout) {
auto & outputDesc = outputsDesc[output_idx]; auto & outputDesc = outputsDesc[output_idx];
if (outputBlob->getTensorDesc().getLayout() == Layout::NC || outputBlob->getTensorDesc().getLayout() == Layout::CN if (outputBlob->getTensorDesc().getLayout() == Layout::NC || outputBlob->getTensorDesc().getLayout() == Layout::CN
|| outputBlob->getTensorDesc().getLayout() == Layout::NCHW || outputBlob->getTensorDesc().getLayout() == Layout::CHW) { || outputBlob->getTensorDesc().getLayout() == Layout::NCHW || outputBlob->getTensorDesc().getLayout() == Layout::CHW) {
// TODO: rotate can be incorporated with exporting - used only in unit tests so far auto dims = outputBlob->getTensorDesc().getDims();
// TODO: restore:
// if (orientation_out != kDnnInterleavedOrientation) {
// if (inputs.size() != 1) {
// THROW_GNA_EXCEPTION << "Invalid number of inputs for for deinterleave " << inputs.size()
// << ", only 1 supported";
// }
// auto dims = inputs.begin()->second->dims();
// RotateFeatures(reinterpret_cast<uint8_t*>(ptr_outputs_global),
// gnadevice ? 2 : 4,
// dims[dims.size() - 1],
// dims[0], // num_feature_vectors looks batch should be there
// dims[0],
// dims[dims.size() - 1]);
// }
auto is2D = outputBlob->getTensorDesc().getLayout() == Layout::NC || outputBlob->getTensorDesc().getLayout() == Layout::CN; auto is2D = outputBlob->getTensorDesc().getLayout() == Layout::NC || outputBlob->getTensorDesc().getLayout() == Layout::CN;
auto is3D = outputBlob->getTensorDesc().getLayout() == Layout::CHW; auto is3D = outputBlob->getTensorDesc().getLayout() == Layout::CHW;
auto& exportOutputDims = outputBlob->getTensorDesc().getDims(); auto& exportOutputDims = outputBlob->getTensorDesc().getDims();
auto batchSize = is3D ? 1 : exportOutputDims[0]; auto batchSize = is3D ? 1 : exportOutputDims[0];
auto elementsPerBatch = is2D ? exportOutputDims[exportOutputDims.size() - 1] auto elementsPerBatch = is2D ? exportOutputDims[exportOutputDims.size() - 1]
: exportOutputDims[exportOutputDims.size() - 1] : exportOutputDims[exportOutputDims.size() - 1]
* exportOutputDims[exportOutputDims.size() - 2] * exportOutputDims[exportOutputDims.size() - 2]
* exportOutputDims[exportOutputDims.size() - 3]; * exportOutputDims[exportOutputDims.size() - 3];
if (do_rotate_output) {
if (batchSize * elementsPerBatch != num_rotate_output_columns * num_rotate_output_rows) {
THROW_GNA_EXCEPTION << "Rotate output dimensions (" << num_rotate_output_rows << "," << num_rotate_output_columns
<< ") do not match output buffer length of " << batchSize * elementsPerBatch;
}
uint32_t element_size = outputDesc.num_bytes_per_element;
std::vector<uint8_t> temp(num_rotate_output_columns * num_rotate_output_rows * element_size);
for (uint32_t k = 0; k < num_rotate_output_columns; ++k) {
uint8_t* ptr_in = reinterpret_cast<uint8_t*>(outputDesc.ptrs[request_idx]) + k * element_size;
for (uint32_t i = 0; i < num_rotate_output_rows; ++i) {
ie_memcpy(&temp.front() + (k *num_rotate_output_rows + i) * element_size,
element_size,
ptr_in + (i * num_rotate_output_columns) * element_size,
element_size);
}
}
ie_memcpy(outputDesc.ptrs[request_idx], num_rotate_output_columns * num_rotate_output_rows * element_size,
&temp.front(), num_rotate_output_columns * num_rotate_output_rows * element_size);
}
ExportScores(outputBlob->buffer(), ExportScores(outputBlob->buffer(),
outputDesc.ptrs[request_idx], outputDesc.ptrs[request_idx],
@ -1366,6 +1376,10 @@ InferenceEngine::ExecutableNetwork GNAPlugin::ImportNetwork(std::istream& networ
num_rotate_rows = header.nRotateRows; num_rotate_rows = header.nRotateRows;
num_rotate_columns = header.nRotateColumns; num_rotate_columns = header.nRotateColumns;
do_rotate_output = header.doRotateOutput;
num_rotate_output_rows = header.nRotateOutputRows;
num_rotate_output_columns = header.nRotateOutputColumns;
for (auto && memory : mt) { for (auto && memory : mt) {
GNAMemoryLayer memoryLayer(nullptr, nullptr, gnaFlags->sw_fp32 ? 4 : 2); GNAMemoryLayer memoryLayer(nullptr, nullptr, gnaFlags->sw_fp32 ? 4 : 2);
memoryLayer.gna_ptr = memory.first; memoryLayer.gna_ptr = memory.first;
@ -1416,7 +1430,8 @@ void GNAPlugin::Export(const std::string &fileName) {
outputsDesc, outputsDesc,
inputsDataMap, inputsDataMap,
outputsDataMap) outputsDataMap)
.SetInputRotation(dnn->num_rotate_rows, dnn->num_rotate_columns, dnn->do_rotate_input); .SetInputRotation(dnn->num_rotate_rows, dnn->num_rotate_columns, dnn->do_rotate_input)
.SetOutputRotation(dnn->num_rotate_output_rows, dnn->num_rotate_output_columns, dnn->do_rotate_output);
for (auto && memoryConnection : graphCompiler.memory_connection) { for (auto && memoryConnection : graphCompiler.memory_connection) {
serial.AddState(memoryConnection.second.gna_ptr, memoryConnection.second.reserved_size); serial.AddState(memoryConnection.second.gna_ptr, memoryConnection.second.reserved_size);

View File

@ -58,6 +58,9 @@ class GNAPlugin : public InferenceEngine::IInferencePlugin {
bool do_rotate_input = false; bool do_rotate_input = false;
uint32_t num_rotate_rows = 0; uint32_t num_rotate_rows = 0;
uint32_t num_rotate_columns = 0; uint32_t num_rotate_columns = 0;
bool do_rotate_output = false;
uint32_t num_rotate_output_rows = 0;
uint32_t num_rotate_output_columns = 0;
uint32_t *ptr_active_indices = nullptr; uint32_t *ptr_active_indices = nullptr;
uint32_t num_active_indices = 0; uint32_t num_active_indices = 0;
uint32_t num_group_in = 0; uint32_t num_group_in = 0;

View File

@ -0,0 +1,128 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <cstdint>
#include "backend/dnn_types.h"
#include "serial/headers/2dot3/gna_model_header.hpp"
#pragma pack(push, 1)
namespace GNAPluginNS {
namespace Header2dot4 {
/**
* @brief Header version 2.4
*/
struct ModelHeader {
/**
*@brief MagicNumber GNAM in ascii table, equals to hex 0x474e414d
*/
char gnam[4] = {};
/**
* @brief if header size is not equal to sizeof ModelHeader - some reserved data append in the end of header
* usually it is an indicator of working with version of model different that is current export function produce
*/
uint32_t headerSize = 0u;
struct Version {
/**
* @details Version of format Major unsigned int, ex: 0x0001
* every change in the header or in the layers definition should be reflected in version change
* for backward compatibility new parsers can read old versions of model with certain restrictions
*/
uint16_t major = 2u;
/**
* @details Version of Format Minor unsigned int, corresponding to build revision for example
* changes in minor version are not affected layout of model
*/
uint32_t minor = 4u;
} version;
/**
* @brief Memory required to be allocated using GNAAlloc()
*/
uint64_t gnaMemSize = 0ull;
/**
* @brief Number of GNA Layers
*/
uint64_t layersCount = 0ull;
/**
* @brief Grouping level
*/
uint32_t nGroup = 0u;
/**
* Convolution related setting - they are affecting input transformation
*/
uint32_t nRotateRows = 0u;
uint32_t nRotateColumns = 0u;
bool doRotateInput = false;
uint32_t nInputs = 0u;
uint32_t nOutputs = 0u;
/**
* Convolution related setting - they are affecting output transformation
*/
uint32_t nRotateOutputRows = 0u;
uint32_t nRotateOutputColumns = 0u;
bool doRotateOutput = false;
/**
* Reserved Data might be here
*/
ModelHeader() = default;
ModelHeader(GNAPluginNS::Header2dot1::ModelHeader const &old) {
gnaMemSize = old.gnaMemSize;
layersCount = old.layersCount;
nGroup = old.nGroup;
nRotateRows = old.nRotateRows;
nRotateColumns = old.nRotateColumns;
nInputs = old.nInputs;
nOutputs = old.nOutputs;
}
};
#pragma pack(pop)
/*
* In runtime endpoint mostly same as in serial version, except of descriptor field
*/
struct RuntimeEndPoint {
/**
* if scale factor is different then pased into infer , network might need to be requantized
*/
float scaleFactor = 0;
/**
* Pointer descriptor
*/
void* descriptor_ptr = nullptr;
/**
* Endpoint resolution in bytes.
*/
uint32_t element_size = 0;
/**
* Number of elements
*/
uint32_t elements_count = 0;
/**
* Offset in bytes of pointer descriptor
*/
uint64_t descriptor_offset = 0ull;
intel_dnn_orientation_t orientation = kDnnUnknownOrientation;
RuntimeEndPoint() = default;
RuntimeEndPoint(double scaleFactor,
void* descriptor_ptr,
uint32_t element_size,
uint32_t elements_count,
intel_dnn_orientation_t orientation) : scaleFactor(scaleFactor),
descriptor_ptr(descriptor_ptr),
element_size(element_size),
elements_count(elements_count),
orientation(orientation) { }
};
} // namespace Header2dot4
} // namespace GNAPluginNS

View File

@ -4,11 +4,11 @@
#pragma once #pragma once
#include "serial/headers/2dot3/gna_model_header.hpp" #include "serial/headers/2dot4/gna_model_header.hpp"
namespace GNAPluginNS { namespace GNAPluginNS {
namespace HeaderLatest { namespace HeaderLatest {
using ModelHeader = GNAPluginNS::Header2dot3::ModelHeader; using ModelHeader = GNAPluginNS::Header2dot4::ModelHeader;
using RuntimeEndPoint = GNAPluginNS::Header2dot3::RuntimeEndPoint; using RuntimeEndPoint = GNAPluginNS::Header2dot4::RuntimeEndPoint;
} }
} }

View File

@ -0,0 +1,58 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include "common_test_utils/test_constants.hpp"
#include "subgraph_tests/input_conv.hpp"
using namespace LayerTestsDefinitions;
namespace {
const std::vector<InferenceEngine::Precision> netPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16
};
const std::vector<std::map<std::string, std::string>> configs = {
{
{"GNA_DEVICE_MODE", "GNA_SW_FP32"}
},
{
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
{"GNA_SCALE_FACTOR_0", "163.835"}
}
};
std::vector<convParams> params = {
std::make_tuple(
std::vector<size_t>{1, 1, 1, 16}, //InputShape
std::vector<size_t>{1, 8}, //KernelShape
1), //Stride
std::make_tuple(std::vector<size_t>{1, 1, 1, 16}, std::vector<size_t>{1, 9}, 1),
std::make_tuple(std::vector<size_t>{1, 1, 1, 168}, std::vector<size_t>{1, 9}, 1),
std::make_tuple(std::vector<size_t>{1, 1, 1, 168}, std::vector<size_t>{1, 8}, 1),
std::make_tuple(std::vector<size_t>{1, 1, 1, 640}, std::vector<size_t>{1, 512}, 128)
};
std::vector<size_t> outputChannels = {
4,
8
};
std::vector<bool> addReshape = {
true,
false
};
INSTANTIATE_TEST_CASE_P(smoke_InputConv, InputConvTest,
::testing::Combine(
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GNA),
::testing::ValuesIn(configs),
::testing::ValuesIn(params),
::testing::ValuesIn(outputChannels),
::testing::ValuesIn(addReshape)),
InputConvTest::getTestCaseName);
} // namespace

View File

@ -0,0 +1,43 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include "functional_test_utils/layer_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp"
typedef std::tuple<
std::vector<size_t>, // Input Shapes
std::vector<size_t>, // Kernel Shape
size_t // Stride
> convParams;
typedef std::tuple<
InferenceEngine::Precision, // Network Precision
std::string, // Target Device
std::map<std::string, std::string>, // Configuration
convParams, // Convolution Params
size_t, // Output Channels
bool // If Add Reshape at the end of the model to reshape to 2D
> inputConvParams;
namespace LayerTestsDefinitions {
class InputConvTest : public testing::WithParamInterface<inputConvParams>,
public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(testing::TestParamInfo<inputConvParams> obj);
InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo& info) const override;
protected:
void SetUp() override;
};
} // namespace LayerTestsDefinitions

View File

@ -0,0 +1,118 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <ie_core.hpp>
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include <ie_plugin_config.hpp>
#include "common_test_utils/common_utils.hpp"
#include "functional_test_utils/blob_utils.hpp"
#include "functional_test_utils/layer_test_utils.hpp"
#include "functional_test_utils/plugin_cache.hpp"
#include "ngraph_functions/pass/convert_prc.hpp"
#include "subgraph_tests/input_conv.hpp"
#include "ngraph_functions/builders.hpp"
namespace LayerTestsDefinitions {
std::string InputConvTest::getTestCaseName(testing::TestParamInfo<inputConvParams> obj) {
InferenceEngine::Precision netPrecision;
std::string targetDevice;
std::map<std::string, std::string> configuration;
size_t outputChannels;
convParams convolutionParams;
std::vector<size_t> inputShape;
std::vector<size_t> kernelShape;
size_t stride;
bool addReshape;
std::tie(netPrecision, targetDevice, configuration, convolutionParams, outputChannels, addReshape) = obj.param;
std::tie(inputShape, kernelShape, stride) = convolutionParams;
std::ostringstream result;
result << "IS=" << CommonTestUtils::vec2str(inputShape) << "_";
result << "KS=" << CommonTestUtils::vec2str(kernelShape) << "_";
result << "S=" << stride << "_";
result << "OC=" << outputChannels << "_";
result << "addReshape=" << addReshape << "_";
result << "netPRC=" << netPrecision.name() << "_";
result << "targetDevice=" << targetDevice;
for (auto const& configItem : configuration) {
result << "_configItem=" << configItem.first << "_" << configItem.second;
}
return result.str();
}
InferenceEngine::Blob::Ptr InputConvTest::GenerateInput(const InferenceEngine::InputInfo& info) const {
InferenceEngine::Blob::Ptr blob = make_blob_with_precision(info.getTensorDesc());
blob->allocate();
auto precision = info.getPrecision();
auto* rawBlobDataPtr = blob->buffer().as<float*>();
for (size_t i = 0; i < blob->size(); i++) {
float value = i % 16;
if (typeid(precision) == typeid(typename InferenceEngine::PrecisionTrait<InferenceEngine::Precision::FP16>::value_type)) {
rawBlobDataPtr[i] = ngraph::float16(value).to_bits();
} else {
rawBlobDataPtr[i] = value;
}
}
return blob;
}
void InputConvTest::SetUp() {
auto generateWeights = [](std::size_t out_channels, std::size_t kernel_size) {
std::vector<float> res;
for (int i = 0; i < out_channels; ++i) {
for (int j = 0; j < kernel_size; ++j) {
j == 0 ? res.emplace_back(1.0f) : res.emplace_back(0.0f);
}
}
return res;
};
InferenceEngine::Precision netPrecision;
std::map<std::string, std::string> tempConfig;
convParams convolutionParams;
size_t outputChannels;
bool addReshape;
std::tie(netPrecision, targetDevice, tempConfig, convolutionParams, outputChannels, addReshape) = this->GetParam();
configuration.insert(tempConfig.begin(), tempConfig.end());
std::vector<size_t> inputShape;
std::vector<size_t> kernelShape;
size_t stride;
std::tie(inputShape, kernelShape, stride) = convolutionParams;
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto params = ngraph::builder::makeParams(ngPrc, { inputShape });
auto conv0 = ngraph::builder::makeConvolution(params[0], ngPrc, { kernelShape[0], kernelShape[1] }, { stride, stride }, { 0, 0 },
{ 0, 0 }, { 1, 1 }, ngraph::op::PadType::VALID, outputChannels, true,
generateWeights(outputChannels, kernelShape[1]));
if (addReshape) {
size_t numOutputWidth = (((inputShape[1] * inputShape[2] * inputShape[3] - kernelShape[1] * kernelShape[0]) / (inputShape[1] * stride)) + 1);
std::vector<size_t> outFormShapes0 = { 1, outputChannels * numOutputWidth };
auto pattern0 = std::make_shared<ngraph::opset1::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{ 2 }, outFormShapes0);
auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(conv0, pattern0, false);
ngraph::ResultVector results{ std::make_shared<ngraph::op::Result>(reshape0) };
function = std::make_shared<ngraph::Function>(results, params, "InputConvTest");
} else {
ngraph::ResultVector results{ std::make_shared<ngraph::op::Result>(conv0) };
function = std::make_shared<ngraph::Function>(results, params, "InputConvTest");
}
}
TEST_P(InputConvTest, CompareWithRefImpl) {
Run();
};
} // namespace LayerTestsDefinitions