[GNA] Fix transformation from NCHW to NHWC for 4d concat (#4852)
Add additional checks for transformation from NCHW to NHWC
This commit is contained in:
parent
5272bd4ba9
commit
27268f008e
@ -133,16 +133,20 @@ inline bool MustBeConvertedFromNCHWToNHWC(const std::vector<InferenceEngine::CNN
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief returns rotation information for a layer based on the previous convolution or pooling dimensions order
|
||||
* @param layer layer from which rotation info search must be started
|
||||
* @return bool value which identifies if rotation info is found and rotation information
|
||||
* @brief returns transposition information for a layer based on the previous convolution or pooling dimensions order
|
||||
* @param layer layer from which transposition info search must be started
|
||||
* @return bool value which identifies if transposition info is found and transposition information
|
||||
*/
|
||||
inline std::vector<TranspositionInfo> FindTranspositionInfoFromPrevLayers(InferenceEngine::CNNLayerPtr layer) {
|
||||
std::function<std::vector<TranspositionInfo>(InferenceEngine::CNNLayerPtr)> findTranspositionInfoRecursive =
|
||||
[&findTranspositionInfoRecursive](InferenceEngine::CNNLayerPtr layer) -> std::vector<TranspositionInfo> {
|
||||
auto getTransposeInfoFromData = [](InferenceEngine::DataPtr data, bool transpose = true) {
|
||||
auto rows = FROM_IR_DIM(data, 3);
|
||||
auto columns = FROM_IR_DIM(data, 1) * FROM_IR_DIM(data, 2);
|
||||
return std::vector<TranspositionInfo>{{transpose, rows, columns}};
|
||||
};
|
||||
if (LayerInfo(layer).isConvolution() || LayerInfo(layer).isPooling()) {
|
||||
auto out_dims = layer->outData[0]->getDims();
|
||||
return {{true, out_dims[1], out_dims[2] * out_dims[3]}};
|
||||
return getTransposeInfoFromData(layer->outData[0]);
|
||||
}
|
||||
|
||||
/* If a fullyconnected or input layers are reached, it means that transposition isn't needed, but we should keep
|
||||
@ -160,6 +164,46 @@ inline std::vector<TranspositionInfo> FindTranspositionInfoFromPrevLayers(Infere
|
||||
return findTranspositionInfoRecursive(input1);
|
||||
}
|
||||
|
||||
/* If it's a concat along not channel axis and its inputs are transposed the whole concat output must be transposed,
|
||||
* otherwise every part corresponding to some input must be transposed separately */
|
||||
if (LayerInfo(layer).isConcat() && !layer->insData.empty()) {
|
||||
auto concatLayer = LayerInfo(layer).as<InferenceEngine::ConcatLayer*>();
|
||||
IE_ASSERT(concatLayer != nullptr);
|
||||
if (concatLayer->_axis > 1) {
|
||||
for (const auto& input : layer->insData) {
|
||||
auto in_dims = input.lock()->getDims();
|
||||
if (in_dims.size() <= 2) {
|
||||
THROW_GNA_EXCEPTION << layer->name << " Invalid number of input dimensions " << in_dims.size()
|
||||
<< " for a concat with axis=" << concatLayer->_axis;
|
||||
}
|
||||
if (concatLayer->_axis == in_dims.size() - 1 && in_dims[in_dims.size() - 2] > 1) {
|
||||
std::ostringstream in_dims_oss;
|
||||
std::copy(in_dims.begin(), in_dims.end(), std::ostream_iterator<size_t>(in_dims_oss, ","));
|
||||
THROW_GNA_EXCEPTION << layer->name << " Unsupported concatenation axis=" << concatLayer->_axis
|
||||
<< " for input dimensions: " << in_dims_oss.str();
|
||||
}
|
||||
}
|
||||
// Check if non-const inputs are transposed
|
||||
bool transpose = false;
|
||||
int nonConstInputIx = 0;
|
||||
for (int i = 0; InferenceEngine::CNNNetHasPrevLayer(layer.get(), i); ++i) {
|
||||
auto input = InferenceEngine::CNNNetPrevLayer(layer, i);
|
||||
if (LayerInfo(input).isConst()) continue;
|
||||
auto transpositionInfo = FindTranspositionInfoFromPrevLayers(input);
|
||||
auto partToTranspose = std::find_if(std::begin(transpositionInfo), std::end(transpositionInfo),
|
||||
[](const TranspositionInfo &infoPart) { return infoPart.transpose; });
|
||||
bool inputTranspose = (partToTranspose != std::end(transpositionInfo));
|
||||
if (nonConstInputIx == 0) {
|
||||
transpose = inputTranspose;
|
||||
} else if (inputTranspose != transpose) {
|
||||
THROW_GNA_EXCEPTION << layer->name << " concat has inputs with different layouts";
|
||||
}
|
||||
++nonConstInputIx;
|
||||
}
|
||||
return getTransposeInfoFromData(layer->outData[0], transpose);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<TranspositionInfo> transpositionInfo;
|
||||
for (int idx = 0; idx < layer->insData.size(); ++idx) {
|
||||
if (!InferenceEngine::CNNNetHasPrevLayer(layer.get(), idx)) continue;
|
||||
@ -169,8 +213,8 @@ inline std::vector<TranspositionInfo> FindTranspositionInfoFromPrevLayers(Infere
|
||||
auto in_dims = layer->insData[idx].lock()->getDims();
|
||||
transpositionInfo.push_back({false, 1, InferenceEngine::details::product(std::begin(in_dims), std::end(in_dims))});
|
||||
} else if (LayerInfo(layer).isConcat() && LayerInfo(inputLayer).isConst()) {
|
||||
// If a concat input is a const we should keep its size to skip this part during transposition
|
||||
auto in_dims = layer->insData[idx].lock()->getDims();
|
||||
// We should keep its size to skip this part during transposition
|
||||
auto data_size = InferenceEngine::details::product(std::begin(in_dims), std::end(in_dims));
|
||||
transpositionInfo.push_back({false, 1, data_size});
|
||||
} else {
|
||||
@ -184,16 +228,17 @@ inline std::vector<TranspositionInfo> FindTranspositionInfoFromPrevLayers(Infere
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief returns rotation information for a layer based on the next convolution layer dimensions order
|
||||
* @param layer layer from which rotation info search must be started
|
||||
* @return bool value which identifies if rotation info is found and rotation information
|
||||
* @brief returns transposition information for a layer based on the next convolution layer dimensions order
|
||||
* @param layer layer from which transposition info search must be started
|
||||
* @return bool value which identifies if transposition info is found and transposition information
|
||||
*/
|
||||
inline std::vector<TranspositionInfo> FindTranspositionInfoFromNextLayers(InferenceEngine::CNNLayerPtr layer) {
|
||||
std::function<std::vector<TranspositionInfo>(InferenceEngine::CNNLayerPtr)> findTranspositionInfoRecursive =
|
||||
[&findTranspositionInfoRecursive](InferenceEngine::CNNLayerPtr layer) -> std::vector<TranspositionInfo> {
|
||||
if (LayerInfo(layer).isConvolution()) {
|
||||
auto in_dims = layer->input()->getDims();
|
||||
return {{true, in_dims[1], in_dims[2] * in_dims[3]}};
|
||||
auto rows = FROM_IR_DIM(layer->input(), 3);
|
||||
auto columns = FROM_IR_DIM(layer->input(), 1) * FROM_IR_DIM(layer->input(), 2);
|
||||
return {{true, rows, columns}};
|
||||
}
|
||||
|
||||
/* If a fullyconnected or output layers are reached, it means that transposition isn't needed, but we should keep
|
||||
|
@ -1087,7 +1087,7 @@ void InsertConcatAligningFilterPass::run() {
|
||||
std::make_shared<WeightableLayer>(LayerParams({filterName, "ConcatAlignFilter", Precision::FP32}));
|
||||
|
||||
if (dims.size() != 2) {
|
||||
THROW_GNA_EXCEPTION << "unsupported concat input a of dims.size()=" << dims.size() << ", layer=" << prevLayer->name;
|
||||
THROW_GNA_EXCEPTION << "unsupported concat input of dims.size()=" << dims.size() << ", layer=" << prevLayer->name;
|
||||
}
|
||||
|
||||
auto num_rows_in = dims[1];
|
||||
@ -2150,7 +2150,10 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
|
||||
transpositionInfo = FindTranspositionInfoFromNextLayers(getInputTo(l->outData[0]).begin()->second);
|
||||
}
|
||||
}
|
||||
if (!transpositionInfo.empty()) {
|
||||
if (foundPartToTranspose(transpositionInfo)) {
|
||||
if (l->input()->getDims().front() > 1) {
|
||||
THROW_GNA_EXCEPTION << l->name << " Weights transposition is not supported for a layer with batch size > 1";
|
||||
}
|
||||
auto weightable = dynamic_cast<WeightableLayer*>(l.get());
|
||||
IE_ASSERT(weightable != nullptr);
|
||||
ConvertTensorFromNCHWToNHWC(weightable->precision.size(), 1, weightable->_weights->size(),
|
||||
@ -2175,8 +2178,17 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
|
||||
auto weightsColumns = InferenceEngine::details::product(std::begin(in_dims) + 1, std::end(in_dims));
|
||||
// Find a convolution in previous layers to rotate weights rows
|
||||
if (InferenceEngine::CNNNetHasPrevLayer(l.get())) {
|
||||
auto transpositionInfo = FindTranspositionInfoFromPrevLayers(InferenceEngine::CNNNetPrevLayer(l));
|
||||
if (!transpositionInfo.empty()) {
|
||||
std::vector<TranspositionInfo> transpositionInfo;
|
||||
auto prevLayer = InferenceEngine::CNNNetPrevLayer(l);
|
||||
transpositionInfo = FindTranspositionInfoFromPrevLayers(prevLayer);
|
||||
if (foundPartToTranspose(transpositionInfo)) {
|
||||
if (l->input()->getDims().front() > 1) {
|
||||
THROW_GNA_EXCEPTION << l->name << " Weights transposition is not supported for a layer with batch size > 1";
|
||||
}
|
||||
if (LayerInfo(prevLayer).isSplit()) {
|
||||
// If we found a split it's not possible to rotate data
|
||||
THROW_GNA_EXCEPTION << l->name << " won't be transposed due to a split before it";
|
||||
}
|
||||
size_t totalColumns = 0;
|
||||
for (auto && transpositionInfoPart : transpositionInfo) {
|
||||
totalColumns += transpositionInfoPart.num_transpose_rows * transpositionInfoPart.num_transpose_columns;
|
||||
@ -2193,14 +2205,23 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
|
||||
}
|
||||
// Find a convolution in next layers to rotate weights columns
|
||||
if (!l->outData.empty() && !getInputTo(l->outData[0]).empty() && !l->outData.empty() && !getInputTo(l->outData[0]).empty()) {
|
||||
auto transpositionInfo = FindTranspositionInfoFromNextLayers(getInputTo(l->outData[0]).begin()->second);
|
||||
if (!transpositionInfo.empty()) {
|
||||
std::vector<TranspositionInfo> transpositionInfo;
|
||||
auto nextLayer = getInputTo(l->outData[0]).begin()->second;
|
||||
transpositionInfo = FindTranspositionInfoFromNextLayers(nextLayer);
|
||||
if (foundPartToTranspose(transpositionInfo)) {
|
||||
if (l->outData[0]->getDims().front() > 1) {
|
||||
THROW_GNA_EXCEPTION << l->name << " Weights transposition is not supported for a layer with batch size > 1";
|
||||
}
|
||||
if (LayerInfo(nextLayer).isConcat()) {
|
||||
// If we found a concat it's not possible to rotate data
|
||||
THROW_GNA_EXCEPTION << l->name << " won't be transposed due to a concat after it";
|
||||
}
|
||||
size_t totalRows = 0;
|
||||
for (const auto& transpositionInfoPart : transpositionInfo) {
|
||||
totalRows += transpositionInfoPart.num_transpose_rows * transpositionInfoPart.num_transpose_columns;
|
||||
}
|
||||
if (weightsRows != totalRows) {
|
||||
THROW_GNA_EXCEPTION << l->name << "weights rows from transposition info (" << totalRows
|
||||
THROW_GNA_EXCEPTION << l->name << " weights rows from transposition info (" << totalRows
|
||||
<< ") don't match output dimensions (" << weightsRows << ")";
|
||||
}
|
||||
ConvertTensorFromNCHWToNHWC(precision, weightsRows, weightsColumns, weightable->_weights->cbuffer().as<uint8_t*>(),
|
||||
@ -2227,14 +2248,55 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
|
||||
if (!foundPartToTranspose(transpositionInfo)) {
|
||||
transpositionInfo = FindTranspositionInfoFromNextLayers(getInputTo(l->outData[0]).begin()->second);
|
||||
}
|
||||
if (!transpositionInfo.empty()) {
|
||||
if (foundPartToTranspose(transpositionInfo)) {
|
||||
auto blob = secondInput->blobs["custom"];
|
||||
ConvertTensorFromNCHWToNHWC(blob->getTensorDesc().getPrecision().size(), 1, blob->size(),
|
||||
blob->buffer().as<uint8_t*>(), true, transpositionInfo);
|
||||
gnalog() << l->name << " data transposition info:\n";
|
||||
gnalog() << secondInput->name << " data transposition info:\n";
|
||||
printTranspositionInfo(transpositionInfo);
|
||||
}
|
||||
}
|
||||
|
||||
if (LayerInfo(l).isConcat()) {
|
||||
auto concatLayer = LayerInfo(l).as<InferenceEngine::ConcatLayer*>();
|
||||
IE_ASSERT(concatLayer != nullptr);
|
||||
// If concatenation is along channel axis constant input transposition isn't required
|
||||
if (concatLayer->_axis <= 1) continue;
|
||||
|
||||
std::vector<InferenceEngine::CNNLayerPtr> constInputs;
|
||||
bool transpose = false;
|
||||
int nonConstInputIx = 0;
|
||||
// Check if non-const inputs are transposed
|
||||
for (int i = 0; InferenceEngine::CNNNetHasPrevLayer(l.get(), i); ++i) {
|
||||
auto input = InferenceEngine::CNNNetPrevLayer(l, i);
|
||||
if (LayerInfo(input).isConst()) {
|
||||
constInputs.push_back(input);
|
||||
continue;
|
||||
}
|
||||
auto transpositionInfo = FindTranspositionInfoFromPrevLayers(input);
|
||||
bool transposeInput = foundPartToTranspose(transpositionInfo);
|
||||
if (nonConstInputIx == 0) {
|
||||
transpose = transposeInput;
|
||||
} else if (transposeInput != transpose) {
|
||||
THROW_GNA_EXCEPTION << "Concat layer " << l->name << " inputs have different layouts";
|
||||
}
|
||||
++nonConstInputIx;
|
||||
}
|
||||
if (!transpose) continue;
|
||||
|
||||
// Transpose all constant inputs
|
||||
for (auto && input : constInputs) {
|
||||
auto rows = FROM_IR_DIM(input->outData[0], 3);
|
||||
auto columns = FROM_IR_DIM(input->outData[0], 1) * FROM_IR_DIM(input->outData[0], 2);
|
||||
auto blob = input->blobs["custom"];
|
||||
// A constant should have the same number of channels since concatenation will be in height/weight dimension
|
||||
TranspositionInfo concatTranspositionInfo{true, rows, columns};
|
||||
ConvertTensorFromNCHWToNHWC(blob->getTensorDesc().getPrecision().size(), 1, blob->size(),
|
||||
blob->buffer().as<uint8_t*>(), true, {concatTranspositionInfo});
|
||||
gnalog() << input->name << " data transposition info:\n";
|
||||
printTranspositionInfo({concatTranspositionInfo});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,56 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
#include "subgraph_tests/const_conv_concat.hpp"
|
||||
|
||||
using namespace SubgraphTestsDefinitions;
|
||||
|
||||
namespace {
|
||||
const std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||
InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::FP16
|
||||
};
|
||||
|
||||
const std::vector<std::map<std::string, std::string>> configs = {
|
||||
{
|
||||
{"GNA_DEVICE_MODE", "GNA_SW_FP32"}
|
||||
},
|
||||
{
|
||||
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"}
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<convParams> params = {
|
||||
std::make_tuple(
|
||||
std::vector<size_t>{1, 64}, //InputShape
|
||||
std::vector<size_t>{1, 3}, //KernelShape
|
||||
1), //Stride
|
||||
std::make_tuple(std::vector<size_t>{1, 128}, std::vector<size_t>{1, 5}, 1),
|
||||
std::make_tuple(std::vector<size_t>{1, 168}, std::vector<size_t>{1, 3}, 2),
|
||||
std::make_tuple(std::vector<size_t>{1, 320}, std::vector<size_t>{1, 8}, 4)
|
||||
};
|
||||
|
||||
std::vector<size_t> inputChannels = {
|
||||
1,
|
||||
4,
|
||||
8
|
||||
};
|
||||
|
||||
std::vector<size_t> outputChannels = {
|
||||
64
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_ConstConvConcatTest, ConstConvConcatTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GNA),
|
||||
::testing::ValuesIn(configs),
|
||||
::testing::ValuesIn(params),
|
||||
::testing::ValuesIn(inputChannels),
|
||||
::testing::ValuesIn(outputChannels)),
|
||||
ConstConvConcatTest::getTestCaseName);
|
||||
} // namespace
|
@ -0,0 +1,19 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "shared_test_classes/subgraph/const_conv_concat.hpp"
|
||||
|
||||
namespace SubgraphTestsDefinitions {
|
||||
|
||||
TEST_P(ConstConvConcatTest, CompareWithRefImpl) {
|
||||
LoadNetwork();
|
||||
GenerateInputs();
|
||||
Infer();
|
||||
// Create another copy of function for validation since some data will be changed by GNA plugin
|
||||
SetUp();
|
||||
Validate();
|
||||
};
|
||||
} // namespace SubgraphTestsDefinitions
|
@ -0,0 +1,43 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "shared_test_classes/base/layer_test_utils.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "ngraph_functions/utils/ngraph_helpers.hpp"
|
||||
|
||||
namespace SubgraphTestsDefinitions {
|
||||
|
||||
typedef std::tuple<
|
||||
std::vector<size_t>, // Input Shapes
|
||||
std::vector<size_t>, // Kernel Shape
|
||||
size_t // Stride
|
||||
> convParams;
|
||||
|
||||
typedef std::tuple<
|
||||
InferenceEngine::Precision, // Network Precision
|
||||
std::string, // Target Device
|
||||
std::map<std::string, std::string>, // Configuration
|
||||
convParams, // Convolution Params
|
||||
size_t, // Input Channels
|
||||
size_t // Output Channels
|
||||
> ConstConvConcatParams;
|
||||
|
||||
class ConstConvConcatTest : public testing::WithParamInterface<ConstConvConcatParams>,
|
||||
public LayerTestsUtils::LayerTestsCommon {
|
||||
public:
|
||||
static std::string getTestCaseName(testing::TestParamInfo<ConstConvConcatParams> obj);
|
||||
InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo& info) const override;
|
||||
|
||||
protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
|
||||
} // namespace SubgraphTestsDefinitions
|
@ -0,0 +1,88 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "shared_test_classes/subgraph/const_conv_concat.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
|
||||
namespace SubgraphTestsDefinitions {
|
||||
|
||||
std::string ConstConvConcatTest::getTestCaseName(testing::TestParamInfo<ConstConvConcatParams> obj) {
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::string targetDevice;
|
||||
std::map<std::string, std::string> configuration;
|
||||
size_t inputChannels;
|
||||
size_t outputChannels;
|
||||
convParams convolutionParams;
|
||||
std::vector<size_t> inputShape;
|
||||
std::vector<size_t> kernelShape;
|
||||
size_t stride;
|
||||
std::tie(netPrecision, targetDevice, configuration, convolutionParams, inputChannels, outputChannels) = obj.param;
|
||||
std::tie(inputShape, kernelShape, stride) = convolutionParams;
|
||||
|
||||
std::ostringstream result;
|
||||
result << "IS=" << CommonTestUtils::vec2str(inputShape) << "_";
|
||||
result << "KS=" << CommonTestUtils::vec2str(kernelShape) << "_";
|
||||
result << "S=" << stride << "_";
|
||||
result << "IC=" << inputChannels << "_";
|
||||
result << "OC=" << outputChannels << "_";
|
||||
result << "netPRC=" << netPrecision.name() << "_";
|
||||
result << "targetDevice=" << targetDevice;
|
||||
for (auto const& configItem : configuration) {
|
||||
result << "_configItem=" << configItem.first << "_" << configItem.second;
|
||||
}
|
||||
return result.str();
|
||||
}
|
||||
|
||||
InferenceEngine::Blob::Ptr ConstConvConcatTest::GenerateInput(const InferenceEngine::InputInfo& info) const {
|
||||
InferenceEngine::Blob::Ptr blob = make_blob_with_precision(info.getTensorDesc());
|
||||
blob->allocate();
|
||||
|
||||
auto* rawBlobDataPtr = blob->buffer().as<float*>();
|
||||
std::vector<float> values = CommonTestUtils::generate_float_numbers(blob->size(), -0.2f, 0.2f);
|
||||
for (size_t i = 0; i < blob->size(); i++) {
|
||||
rawBlobDataPtr[i] = values[i];
|
||||
}
|
||||
return blob;
|
||||
}
|
||||
|
||||
void ConstConvConcatTest::SetUp() {
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::map<std::string, std::string> tempConfig;
|
||||
convParams convolutionParams;
|
||||
size_t inputChannels;
|
||||
size_t outputChannels;
|
||||
std::tie(netPrecision, targetDevice, tempConfig, convolutionParams, inputChannels, outputChannels) = this->GetParam();
|
||||
configuration.insert(tempConfig.begin(), tempConfig.end());
|
||||
|
||||
std::vector<size_t> inputShape;
|
||||
std::vector<size_t> kernelShape;
|
||||
size_t stride;
|
||||
std::tie(inputShape, kernelShape, stride) = convolutionParams;
|
||||
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
auto params = ngraph::builder::makeParams(ngPrc, { inputShape });
|
||||
|
||||
std::vector<size_t> convInputShape = {inputShape[0], inputChannels, 1, inputShape[1] / inputChannels};
|
||||
auto reshapePattern1 = std::make_shared<ngraph::opset1::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{ 4 }, convInputShape);
|
||||
auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(params[0], reshapePattern1, false);
|
||||
|
||||
auto filterWeights = CommonTestUtils::generate_float_numbers(outputChannels * convInputShape[1] * kernelShape[0] * kernelShape[1],
|
||||
0.0f, 0.1f);
|
||||
auto conv = ngraph::builder::makeConvolution(reshape1, ngPrc, { kernelShape[0], kernelShape[1] }, { stride, stride }, { 0, 0 },
|
||||
{ 0, 0 }, { 1, 1 }, ngraph::op::PadType::VALID, outputChannels, false, filterWeights);
|
||||
|
||||
auto widthAfterConv = (convInputShape[3] - kernelShape[1]) / stride + 1;
|
||||
std::vector<size_t> outFormShapes = {1, outputChannels * widthAfterConv };
|
||||
|
||||
auto const_values = CommonTestUtils::generate_float_numbers(outputChannels * widthAfterConv, -0.2f, 0.2f);
|
||||
auto constant = ngraph::builder::makeConstant(ngPrc, {1, outputChannels, 1, widthAfterConv}, const_values);
|
||||
auto concat = ngraph::builder::makeConcat({constant, conv}, 3);
|
||||
|
||||
auto reshapePattern2 = std::make_shared<ngraph::opset1::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{ 2 },
|
||||
std::vector<size_t>{1, 2 * outputChannels * widthAfterConv });
|
||||
auto reshape2 = std::make_shared<ngraph::opset1::Reshape>(concat, reshapePattern2, false);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(reshape2, params, "ConstConvConcatTest");
|
||||
}
|
||||
} // namespace SubgraphTestsDefinitions
|
Loading…
Reference in New Issue
Block a user