From 9f5498982461b7616b8c885b71d3ceeca3c60d71 Mon Sep 17 00:00:00 2001
From: Kamil Magierski
Date: Fri, 13 Nov 2020 16:12:45 +0100
Subject: [PATCH] [GNA] 4D concat align pass (#2970)

* [GNA] Fix RemovePermutationsNHWCToNCHWPass in cases that permute input has many outData

* style

* [GNA] linux test fail fix
---
 .../src/gna_plugin/gna_graph_tools.hpp        | 41 +++++++-
 .../src/gna_plugin/gna_plugin.cpp             |  1 +
 .../src/gna_plugin/gna_plugin_policy.hpp      |  5 +
 .../gna_plugin/optimizer/gna_pass_manager.cpp | 96 ++++++++++++++++++-
 .../gna_plugin/optimizer/gna_pass_manager.hpp |  5 +
 .../single_layer_tests/concat_4D.cpp          | 34 +++++++
 .../include/single_layer_tests/concat_4D.hpp  | 32 +++++++
 .../src/single_layer_tests/concat_4D.cpp      | 70 ++++++++++++++
 .../src/subgraph_tests/matmul_squeeze_add.cpp | 19 +---
 .../src/subgraph_tests/memory_LSTMCell.cpp    | 27 ++----
 .../src/subgraph_tests/multiple_LSTMCell.cpp  | 29 ++----
 .../src/subgraph_tests/multiple_concat.cpp    | 17 +---
 .../subgraph_tests/perm_conv_perm_concat.cpp  | 18 +---
 .../common_test_utils/data_utils.hpp          | 15 +++
 14 files changed, 320 insertions(+), 89 deletions(-)
 create mode 100644 inference-engine/tests/functional/plugin/gna/shared_tests_instances/single_layer_tests/concat_4D.cpp
 create mode 100644 inference-engine/tests/functional/plugin/shared/include/single_layer_tests/concat_4D.hpp
 create mode 100644 inference-engine/tests/functional/plugin/shared/src/single_layer_tests/concat_4D.cpp

diff --git a/inference-engine/src/gna_plugin/gna_graph_tools.hpp b/inference-engine/src/gna_plugin/gna_graph_tools.hpp
index 137543bc347..e358e97606c 100644
--- a/inference-engine/src/gna_plugin/gna_graph_tools.hpp
+++ b/inference-engine/src/gna_plugin/gna_graph_tools.hpp
@@ -6,7 +6,7 @@
 #include
 #include "gna_plugin_log.hpp"
-
+#include "frontend/quantized_layer_params.hpp"
 #include
 #include
 #include
@@ -441,7 +441,45 @@ inline void CNNNetSwapLayers(InferenceEngine::CNNLayerPtr lhs,
     lhs->outData.front()->setDims(rhs->outData.front()->getDims());
 }
 
+/**
+* @brief Changes the TensorDesc of data by creating a new one with the correct description and replacing the original one
+*/
+inline DataPtr CNNReplaceDataWithChangedTensorDescription(DataPtr old_data, TensorDesc& new_td) {
+    auto new_dataPtr = std::make_shared(old_data->getName() + "_reshaped", new_td);
+    getInputTo(new_dataPtr) = getInputTo(old_data);
+    auto creatorLayer = getCreatorLayer(old_data).lock();
+    getCreatorLayer(new_dataPtr) = creatorLayer;
+    size_t idx = -1;
+    for (size_t i=0; i < creatorLayer->outData.size(); i++) {
+        if (areEqualDatas(old_data, creatorLayer->outData[i])) {
+            idx = i;
+            break;
+        }
+    }
+    if (idx == -1) THROW_GNA_EXCEPTION << "No idx for data was found";
+    creatorLayer->outData[idx] = new_dataPtr;
+    auto input_to = getInputTo(new_dataPtr);
+    for (auto& input : input_to) {
+        for (auto& input_idx : CNNLayerFindInsDataIdxes(old_data, input.second)) {
+            input.second->insData[input_idx] = new_dataPtr;
+        }
+    }
+    return new_dataPtr;
+}
+
+/**
+* @brief Creates a Reshape layer with the given name and tensor description
+*/
+inline CNNLayerPtr CNNNetworkCreateReshape(TensorDesc td, std::string name, bool quantized) {
+    auto reshape = std::make_shared(LayerParams({name, "reshape", Precision::FP32}));
+    auto reshapeLayerWithQuant = quantized ?
+        InferenceEngine::injectData(reshape) : reshape;
+    auto dataPtr = std::make_shared(name + "_data", td);
+    getCreatorLayer(dataPtr) = reshapeLayerWithQuant;
+    reshapeLayerWithQuant->outData.push_back(dataPtr);
+
+    return reshapeLayerWithQuant;
+}
 
 /**
  * @@brief insertLayer between given layers
@@ -594,6 +632,7 @@ std::vector > CNNNetGetPrevLayersSkip(CNNLayerPtr or
  * @brief remove given layer from topology, currently only layers with one input data and one output data supported
  */
 inline void CNNNetworkRemoveLayer(CNNLayerPtr layer, bool checkDims = true) {
+    gnalog() << "Removing " << layer->name << " layer";
     if (!layer) {
         THROW_IE_EXCEPTION << "Cannot remove layer pointed to NULL";
     }
diff --git a/inference-engine/src/gna_plugin/gna_plugin.cpp b/inference-engine/src/gna_plugin/gna_plugin.cpp
index 7d6e6768ce9..52c55d74f15 100644
--- a/inference-engine/src/gna_plugin/gna_plugin.cpp
+++ b/inference-engine/src/gna_plugin/gna_plugin.cpp
@@ -408,6 +408,7 @@ void GNAPlugin::LoadNetwork(ICNNNetwork & _network) {
     passes->registerPass();
     passes->registerPass();
+    passes->registerPass();
     passes->registerPass();
     passes->registerPass();
     if (policy.PermutePolicy != Policy::Permute::DISABLED) {
diff --git a/inference-engine/src/gna_plugin/gna_plugin_policy.hpp b/inference-engine/src/gna_plugin/gna_plugin_policy.hpp
index 6fee8751391..6880b9ec57d 100644
--- a/inference-engine/src/gna_plugin/gna_plugin_policy.hpp
+++ b/inference-engine/src/gna_plugin/gna_plugin_policy.hpp
@@ -34,6 +34,11 @@ class Policy {
         AUTO_PERMUTE
     } PermutePolicy = Permute::DISABLED;
 
+    enum class Concat4Dto2DConversion {
+        DISABLED,
+        ENABLED
+    } ConcatConversionPolicy = Concat4Dto2DConversion::ENABLED;
+
     enum class ConcatAlignment {
         DISABLED,
         DISABLED_FOR_FP32,
diff --git a/inference-engine/src/gna_plugin/optimizer/gna_pass_manager.cpp b/inference-engine/src/gna_plugin/optimizer/gna_pass_manager.cpp
index 3bae2546c3b..0cbc01b685c 100644
--- a/inference-engine/src/gna_plugin/optimizer/gna_pass_manager.cpp
+++ b/inference-engine/src/gna_plugin/optimizer/gna_pass_manager.cpp
@@ -634,6 +634,10 @@ void RemovePermutationsNHWCToNCHWPass::run() {
             continue;
         }
 
+        if (l->outData.size() != 1) {
+            continue;
+        }
+
         if (getInputTo(l->outData.front()).empty()) {
             continue;
         }
@@ -661,7 +665,18 @@ void RemovePermutationsNHWCToNCHWPass::run() {
             next->input()->setDims(toRemove->input()->getDims());
             next->input()->setLayout(Layout::NHWC);
             auto layerBeforePermute = CNNNetPrevLayer(toRemove);
-            layerBeforePermute->outData[0]->setLayout(Layout::NHWC);
+
+            DataPtr output = nullptr;
+            for (auto before_output : layerBeforePermute->outData) {
+                if (areEqualDatas(toRemove->input(), before_output)) {
+                    output = before_output;
+                    output->setLayout(Layout::NHWC);
+                    break;
+                }
+            }
+            if (output == nullptr) {
+                THROW_GNA_EXCEPTION << "Could not find correct data link between " << toRemove->name << " and " << layerBeforePermute->name;
+            }
 
             auto* convolution = dynamic_cast(next.get());
             if (!convolution) {
@@ -808,6 +823,85 @@ void InsertCopyLayerPass::run() {
     }
 }
 
+void Concat4Dto2DPass::run() {
+    // Find 4D concat layers that would have to use ConcatAlignFilters and can be substituted by a 2D concat.
+    // If a 4D concat has unaligned inputs, ConcatAlignFilters need to be used; if the sizes before the
+    // concat axis are all ones, the concat can instead be changed to 2D. For example, if all inputs have the
+    // same shape 1, 1, 5, 3, the change is made for axis 0, 1 or 2 and the inputs are reshaped to 1, 15;
+    // but for shape 2, 1, 5, 3 only axis 0 is valid and the inputs
+    // will be reshaped to 1, 30.
+    auto quantized = InferenceEngine::getInjectedData(pLayers->front());
+
+    if (getPassManager()->getPolicy().ConcatConversionPolicy == Policy::Concat4Dto2DConversion::DISABLED) return;
+    if (getPassManager()->getPolicy().ConcatAlignmentPolicy == Policy::ConcatAlignment::DISABLED) return;
+    if (getPassManager()->getPolicy().ConcatAlignmentPolicy == Policy::ConcatAlignment::DISABLED_FOR_FP32 && !quantized) return;
+
+    for (auto & l : *pLayers) {
+        LayerInfo info(l);
+        auto concatLayer = info.as();
+        if (!concatLayer) continue;
+        if (concatLayer->insData.size() < 1) continue;
+
+        auto dims_size = concatLayer->insData[0].lock()->getDims().size();
+        if (dims_size > 2) {
+            auto axis = concatLayer->_axis;
+            bool skip_layer = false;
+            for (int i = 0; i < axis; i++) {
+                if (concatLayer->insData[0].lock()->getDims()[i] != 1) skip_layer = true;
+            }
+            if (skip_layer) continue;
+            skip_layer = true;
+            std::vector total_sizes;
+            for (auto& input : concatLayer->insData) {
+                auto input_dims = input.lock()->getDims();
+                total_sizes.push_back(std::accumulate(input_dims.begin(), input_dims.end(), size_t(1), std::multiplies()));
+                if (total_sizes.back() % 64 != 0) skip_layer = false;
+            }
+            if (skip_layer) continue;
+
+            for (size_t input_idx = 0; input_idx != concatLayer->insData.size(); input_idx++) {
+                auto getLayerByIndex = [&concatLayer](int idx) {
+                    auto input = concatLayer->insData[idx];
+                    auto lockedInput = input.lock();
+                    if (!lockedInput) {
+                        THROW_GNA_EXCEPTION << "cannot get insdata : "<< idx << " for layer: " << concatLayer->name;
+                    }
+                    return lockedInput;
+                };
+
+                auto concatInput = getLayerByIndex(input_idx);
+
+                auto tensor = InferenceEngine::TensorDesc(concatInput->getTensorDesc());
+                tensor.reshape(SizeVector({1, total_sizes[input_idx]}), Layout::NC);
+                auto reshapeName = l->name + "_input_"+ std::to_string(input_idx) +"_reshape";
+                auto reshape = CNNNetworkCreateReshape(tensor, reshapeName, quantized);
+
+                CNNNetworkInsertLayer(getCreatorLayer(concatInput).lock(), l, reshape);
+                gnalog() << "\tInserted " << reshapeName << " between " << getCreatorLayer(concatInput).lock()->name << " and " << l->name << std::endl;
+            }
+
+            for (auto output_idx = 0; output_idx != concatLayer->outData.size(); output_idx++) {
+                auto output = concatLayer->outData[output_idx];
+                auto output_tensor_copy = TensorDesc(output->getTensorDesc());
+
+                auto dims = output_tensor_copy.getDims();
+                auto total_size = std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies());
+
+                auto new_tensor = output->getTensorDesc();
+                new_tensor.reshape(SizeVector({1, total_size}), Layout::NC);
+
+                auto new_output = CNNReplaceDataWithChangedTensorDescription(output, new_tensor);
+                gnalog() << "\tChanged " << output->getName() << " dims to 2D" << std::endl;
+
+                auto reshapeName = l->name + "_output_"+ std::to_string(output_idx) +"_reshape";
+
+                auto reshape = CNNNetworkCreateReshape(output_tensor_copy, reshapeName, quantized);
+                CNNNetworkInsertLayer(l, nullptr, reshape, output_idx);
+                gnalog() << "\tInserted " << reshapeName << " after " << l->name << std::endl;
+            }
+        }
+    }
+}
+
 void InsertConcatAligningFilterPass::run() {
     auto quantized = InferenceEngine::getInjectedData(pLayers->front());
 
diff --git a/inference-engine/src/gna_plugin/optimizer/gna_pass_manager.hpp b/inference-engine/src/gna_plugin/optimizer/gna_pass_manager.hpp
index 6ee8b5ce8d7..033c99e6b6f 100644
--- a/inference-engine/src/gna_plugin/optimizer/gna_pass_manager.hpp
+++ b/inference-engine/src/gna_plugin/optimizer/gna_pass_manager.hpp
@@ -141,6 +141,11 @@ DECL_PASS(InsertCopyLayer);
 */
 DECL_PASS(InsertSplitAligningFilter);
 
+/**
+* @brief Pass that changes 4D concat to 2D concat in cases that would otherwise have to use ConcatAlignFilter
+*/
+DECL_PASS(Concat4Dto2D);
+
 /**
 * @brief concat-aligning filter layer insertion required in cases when concat inputs size are not 64-aligned
 */
diff --git a/inference-engine/tests/functional/plugin/gna/shared_tests_instances/single_layer_tests/concat_4D.cpp b/inference-engine/tests/functional/plugin/gna/shared_tests_instances/single_layer_tests/concat_4D.cpp
new file mode 100644
index 00000000000..bdabe6c7669
--- /dev/null
+++ b/inference-engine/tests/functional/plugin/gna/shared_tests_instances/single_layer_tests/concat_4D.cpp
@@ -0,0 +1,34 @@
+// Copyright (C) 2019 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include
+
+#include "single_layer_tests/concat_4D.hpp"
+#include "common_test_utils/test_constants.hpp"
+
+using namespace LayerTestsDefinitions;
+
+namespace {
+std::vector> inShapes = {
+    {1, 1, 33, 16},
+    {1, 1, 65, 16},
+};
+
+std::vector netPrecisions = {InferenceEngine::Precision::FP32,
+                             InferenceEngine::Precision::FP16};
+
+std::map additional_config = {
+    {"GNA_COMPACT_MODE", "NO"},
+    {"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
+    {"GNA_SCALE_FACTOR_0", "2000.0"},
+};
+
+INSTANTIATE_TEST_CASE_P(smoke_Concat4D_Basic, Concat4DLayerTest,
+    ::testing::Combine(
+        ::testing::ValuesIn(inShapes),
+        ::testing::ValuesIn(netPrecisions),
+        ::testing::Values(CommonTestUtils::DEVICE_GNA),
+        ::testing::Values(additional_config)),
+    Concat4DLayerTest::getTestCaseName);
+} // namespace
diff --git a/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/concat_4D.hpp b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/concat_4D.hpp
new file mode 100644
index 00000000000..ca6adcd29d4
--- /dev/null
+++ b/inference-engine/tests/functional/plugin/shared/include/single_layer_tests/concat_4D.hpp
@@ -0,0 +1,32 @@
+// Copyright (C) 2019 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#include "functional_test_utils/layer_test_utils.hpp"
+#include "ngraph_functions/builders.hpp"
+#include "ngraph_functions/utils/ngraph_helpers.hpp"
+
+namespace LayerTestsDefinitions {
+using concat4DParamsTuple = typename std::tuple<
+    std::vector,                // Inputs shape
+    InferenceEngine::Precision,     // Network precision
+    std::string,                    // Device name
+    std::map        // Configuration
+>;
+
+class Concat4DLayerTest : public testing::WithParamInterface,
+                          virtual public LayerTestsUtils::LayerTestsCommon {
+public:
+    static std::string getTestCaseName(const testing::TestParamInfo &obj);
+protected:
+    void SetUp() override;
+};
+
+} // namespace LayerTestsDefinitions
diff --git a/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/concat_4D.cpp b/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/concat_4D.cpp
new file mode 100644
index 00000000000..b4e5e509b72
--- /dev/null
+++ b/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/concat_4D.cpp
@@ -0,0 +1,70 @@
+// Copyright (C) 2019 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include
+#include
+#include
+#include
+#include
+
+#include "ie_core.hpp"
+
+#include "common_test_utils/common_utils.hpp"
+#include "functional_test_utils/blob_utils.hpp"
+#include "common_test_utils/data_utils.hpp"
+#include "functional_test_utils/precision_utils.hpp"
"functional_test_utils/plugin_cache.hpp" +#include "functional_test_utils/skip_tests_config.hpp" + +#include "single_layer_tests/concat_4D.hpp" + +namespace LayerTestsDefinitions { + + std::string Concat4DLayerTest::getTestCaseName(const testing::TestParamInfo &obj) { + int axis; + std::vector inputShapes; + InferenceEngine::Precision netPrecision; + InferenceEngine::Precision inPrc, outPrc; + InferenceEngine::Layout inLayout, outLayout; + std::string targetName; + std::map config; + std::tie(inputShapes, netPrecision, targetName, config) = obj.param; + std::ostringstream result; + result << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_"; + result << "netPRC=" << netPrecision.name() << "_"; + result << "trgDev=" << targetName << "_"; + return result.str(); + } + + void Concat4DLayerTest::SetUp() { + int axis = 1; + InferenceEngine::SizeVector inputShape; + InferenceEngine::Precision netPrecision; + std::map additional_config; + std::tie(inputShape, netPrecision, targetDevice, additional_config) = this->GetParam(); + configuration.insert(additional_config.begin(), additional_config.end()); + + auto total_size = std::accumulate(inputShape.begin(), inputShape.end(), static_cast(1), std::multiplies()); + auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); + auto params = ngraph::builder::makeParams(ngPrc, {inputShape}); + auto input = params[0]; + + auto constant_values = CommonTestUtils::generate_float_numbers(total_size, 11.0f, 12.0f); + auto constant = ngraph::builder::makeConstant(ngPrc, std::vector({1, total_size}), constant_values); + auto first_reshape_pattern = std::make_shared(ngraph::element::i64, + ngraph::Shape{4}, std::vector(inputShape)); + auto first_reshape = std::make_shared(constant, first_reshape_pattern, false); + auto constant_2 = ngraph::builder::makeConstant(ngPrc, inputShape, constant_values); + + auto concat = std::make_shared(ngraph::OutputVector({first_reshape, input, constant_2}), axis); + auto act = ngraph::builder::makeActivation(concat, ngPrc, ngraph::helpers::ActivationTypes::Relu); + ngraph::ResultVector results{std::make_shared(act)}; + function = std::make_shared(results, params, "concat"); + } + + + TEST_P(Concat4DLayerTest, CompareWithRefs) { + Run(); + }; +} // namespace LayerTestsDefinitions diff --git a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/matmul_squeeze_add.cpp b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/matmul_squeeze_add.cpp index b071fdbe1c7..afeb81b2ef4 100644 --- a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/matmul_squeeze_add.cpp +++ b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/matmul_squeeze_add.cpp @@ -41,20 +41,7 @@ std::string MatmulSqueezeAddTest::getTestCaseName(testing::TestParamInfo res; - - std::mt19937 gen( - static_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count())); - - std::uniform_real_distribution dist(startFrom, upTo); - - for (int i = 0; i < vec_len; i++) - res.emplace_back(static_cast(dist(gen))); - - return res; - }; - + auto seed = std::chrono::high_resolution_clock::now().time_since_epoch().count(); InferenceEngine::Precision netPrecision; std::map tempConfig; std::vector inputShape; @@ -67,14 +54,14 @@ void MatmulSqueezeAddTest::SetUp() { auto params = ngraph::builder::makeParams(ngPrc, { inputShape }); auto constant_0 = ngraph::builder::makeConstant(ngPrc, { outputSize, inputShape[1] }, - generateFloatNumbers(0, 1, outputSize * inputShape[1]), false); + 
+        CommonTestUtils::generate_float_numbers(outputSize * inputShape[1], 0, 1, seed), false);
     auto matmul_0 = std::make_shared(params[0], constant_0, false, true);
 
     auto constant_1 = std::make_shared(ngraph::element::Type_t::i64, ngraph::Shape{ 1 }, std::vector{0});
     auto unsqueeze_0 = std::make_shared(matmul_0, constant_1);
 
     auto constant_2 = ngraph::builder::makeConstant(ngPrc, { 1, inputShape[0], outputSize },
-        generateFloatNumbers(0, 1, inputShape[0] * outputSize), false);
+        CommonTestUtils::generate_float_numbers(inputShape[0] * outputSize, 0, 1, seed), false);
     auto add_0 = std::make_shared(unsqueeze_0, constant_2);
 
     auto constant_3 = std::make_shared(ngraph::element::Type_t::i64, ngraph::Shape{ 1 }, std::vector{0});
diff --git a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/memory_LSTMCell.cpp b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/memory_LSTMCell.cpp
index dcbeb7c68d3..93a883741a0 100644
--- a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/memory_LSTMCell.cpp
+++ b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/memory_LSTMCell.cpp
@@ -58,26 +58,13 @@ namespace SubgraphTestsDefinitions {
     std::vector hidden_memory_dims {1, hiddenSize};
     std::vector cell_memory_dims {1, hiddenSize};
 
-    const int seed = 0;
-    std::mt19937 gen(static_cast(seed));
-
-    auto generateFloatNumbers = [gen](std::size_t vec_len, float min, float max) mutable {
-        std::vector res;
-
-        std::uniform_real_distribution dist(min, max);
-        for (int i = 0; i < vec_len; i++)
-            res.emplace_back(static_cast(dist(gen)));
-
-        return res;
-    };
-
-    input_bias = generateFloatNumbers(inputSize, -0.25f, 0.0f);
-    input_weights = generateFloatNumbers(inputSize, 0.0f, 0.15f);
-    hidden_memory_init = generateFloatNumbers(hiddenSize, -0.2f, 0.2f);
-    cell_memory_init = generateFloatNumbers(hiddenSize, -0.2f, 0.2f);
-    weights_vals = generateFloatNumbers(4 * hiddenSize * inputSize, -0.1f, 0.1f);
-    reccurrenceWeights_vals = generateFloatNumbers(4 * hiddenSize * hiddenSize, -0.1f, 0.1f);
-    bias_vals = generateFloatNumbers(4 * hiddenSize, -0.25f, 0.15f);
+    input_bias = CommonTestUtils::generate_float_numbers(inputSize, -0.2f, 0.0f);
+    input_weights = CommonTestUtils::generate_float_numbers(inputSize, 0.0f, 0.1f);
+    hidden_memory_init = CommonTestUtils::generate_float_numbers(hiddenSize, -0.2f, 0.2f);
+    cell_memory_init = CommonTestUtils::generate_float_numbers(hiddenSize, -0.2f, 0.2f);
+    weights_vals = CommonTestUtils::generate_float_numbers(4 * hiddenSize * inputSize, -0.1f, 0.1f);
+    reccurrenceWeights_vals = CommonTestUtils::generate_float_numbers(4 * hiddenSize * hiddenSize, -0.1f, 0.1f);
+    bias_vals = CommonTestUtils::generate_float_numbers(4 * hiddenSize, -0.2f, 0.1f);
 
     auto input_parameter = ngraph::builder::makeParams(ngPrc, {input_dims});
diff --git a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/multiple_LSTMCell.cpp b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/multiple_LSTMCell.cpp
index 1df0e7baf26..9463031be87 100644
--- a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/multiple_LSTMCell.cpp
+++ b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/multiple_LSTMCell.cpp
@@ -55,27 +55,14 @@ void MultipleLSTMCellTest::SetUp() {
     std::vector hidden_memory_dims {1, hiddenSize};
     std::vector cell_memory_dims {1, hiddenSize};
 
-    const int seed = 0;
-    std::mt19937 gen(static_cast(seed));
-
-    auto generateFloatNumbers = [gen](std::size_t vec_len, float min, float max) mutable {
-        std::vector res;
-
-        std::uniform_real_distribution dist(min, max);
-        for (int i = 0; i < vec_len; i++)
-            res.emplace_back(static_cast(dist(gen)));
-
-        return res;
-    };
-
-    input_bias = generateFloatNumbers(inputSize, -0.25f, 0.0f);
-    input_weights = generateFloatNumbers(inputSize, 0.0f, 0.15f);
-    hidden_memory_init = generateFloatNumbers(hiddenSize, -0.2f, 0.2f);
-    cell_memory_init = generateFloatNumbers(hiddenSize, -0.2f, 0.2f);
-    weights_vals = generateFloatNumbers(4 * hiddenSize * inputSize, -0.1f, 0.1f);
-    weights_2_vals = generateFloatNumbers(4 * hiddenSize * hiddenSize, -0.1f, 0.1f);
-    reccurrenceWeights_vals = generateFloatNumbers(4 * hiddenSize * hiddenSize, -0.1f, 0.1f);
-    bias_vals = generateFloatNumbers(4 * hiddenSize, -0.25f, 0.15f);
+    input_bias = CommonTestUtils::generate_float_numbers(inputSize, -0.25f, 0.0f);
+    input_weights = CommonTestUtils::generate_float_numbers(inputSize, 0.0f, 0.15f);
+    hidden_memory_init = CommonTestUtils::generate_float_numbers(hiddenSize, -0.2f, 0.2f);
+    cell_memory_init = CommonTestUtils::generate_float_numbers(hiddenSize, -0.2f, 0.2f);
+    weights_vals = CommonTestUtils::generate_float_numbers(4 * hiddenSize * inputSize, -0.1f, 0.1f);
+    weights_2_vals = CommonTestUtils::generate_float_numbers(4 * hiddenSize * hiddenSize, -0.1f, 0.1f);
+    reccurrenceWeights_vals = CommonTestUtils::generate_float_numbers(4 * hiddenSize * hiddenSize, -0.1f, 0.1f);
+    bias_vals = CommonTestUtils::generate_float_numbers(4 * hiddenSize, -0.25f, 0.15f);
 
     auto input_parameter = ngraph::builder::makeParams(ngPrc, {input_dims});
diff --git a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/multiple_concat.cpp b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/multiple_concat.cpp
index 4fbd710d15f..01291111b5f 100644
--- a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/multiple_concat.cpp
+++ b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/multiple_concat.cpp
@@ -49,21 +49,8 @@ void MultipleConcatTest::SetUp() {
     std::vector input_dims { 1, inputSize };
     std::vector constant_dims {1, constantSize};
 
-    const int seed = 0;
-    std::mt19937 gen(static_cast(seed));
-
-    auto generateFloatNumbers = [gen](std::size_t vec_len, float min, float max) mutable {
-        std::vector res;
-
-        std::uniform_real_distribution dist(min, max);
-        for (int i = 0; i < vec_len; i++)
-            res.emplace_back(static_cast(dist(gen)));
-
-        return res;
-    };
-
-    auto concat_1_vals = generateFloatNumbers(constantSize, -2.0f, 2.0f);
-    auto concat_2_vals = generateFloatNumbers(constantSize, -5.0f, 5.0f);
+    auto concat_1_vals = CommonTestUtils::generate_float_numbers(constantSize, -2.0f, 2.0f);
+    auto concat_2_vals = CommonTestUtils::generate_float_numbers(constantSize, -5.0f, 5.0f);
 
     auto input_parameter = ngraph::builder::makeParams(ngPrc, {input_dims});
diff --git a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/perm_conv_perm_concat.cpp b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/perm_conv_perm_concat.cpp
index 62ab624794a..b816e3aeace 100644
--- a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/perm_conv_perm_concat.cpp
+++ b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/perm_conv_perm_concat.cpp
@@ -52,19 +52,6 @@ void PermConvPermConcat::SetUp() {
     std::vector permute_in_order = { 0, 3, 1, 2 };
     std::vector permute_out_order = { 0, 2, 3, 1 };
 
-    const int seed = 0;
-    std::mt19937 gen(static_cast(seed));
-
-    auto generateFloatNumbers = [gen](std::size_t vec_len, float min, float max) mutable {
-        std::vector res;
-
-        std::uniform_real_distribution dist(min, max);
-        for (int i = 0; i < vec_len; i++)
-            res.emplace_back(static_cast(dist(gen)));
-
-        return res;
-    };
-
     auto input_parameter = ngraph::builder::makeParams(ngPrc, {input_dims});
 
     auto reshape_in_pattern = std::make_shared(ngraph::element::i64,
@@ -79,7 +66,7 @@ void PermConvPermConcat::SetUp() {
     auto conv_in_shape = permute_in->get_output_shape(0);
     auto conv_weights_size = output_channels * (conv_in_shape[1]) * kernel_shape[0] * kernel_shape[1];
     auto conv = ngraph::builder::makeConvolution(permute_in, ngPrc, {kernel_shape[0], kernel_shape[1]}, {1, 1}, {0, 0}, {0, 0}, {1, 1},
-        ngraph::op::PadType::VALID, output_channels, false, generateFloatNumbers(conv_weights_size, -0.5f, 0.5f));
+        ngraph::op::PadType::VALID, output_channels, false, CommonTestUtils::generate_float_numbers(conv_weights_size, -0.5f, 0.5f));
 
     auto permute_out_params = std::make_shared(ngraph::element::i64,
         ngraph::Shape{4},
@@ -88,7 +75,8 @@ void PermConvPermConcat::SetUp() {
 
     auto permute_out_shape = permute_out->get_output_shape(0);
 
-    auto concat_const = ngraph::builder::makeConstant(ngPrc, {1, 1, 1, permute_out_shape[3]}, generateFloatNumbers(permute_out_shape[3], -10, 10));
+    auto concat_const = ngraph::builder::makeConstant(ngPrc, {1, 1, 1, permute_out_shape[3]},
+        CommonTestUtils::generate_float_numbers(permute_out_shape[3], -10, 10));
 
     auto concat = ngraph::builder::makeConcat({permute_out, concat_const}, 2);
diff --git a/inference-engine/tests/ie_test_utils/common_test_utils/data_utils.hpp b/inference-engine/tests/ie_test_utils/common_test_utils/data_utils.hpp
index 8d46d855d3b..fcbb64cf041 100644
--- a/inference-engine/tests/ie_test_utils/common_test_utils/data_utils.hpp
+++ b/inference-engine/tests/ie_test_utils/common_test_utils/data_utils.hpp
@@ -31,6 +31,21 @@ static void fill_data_sine(float *data, size_t size, float center, float ampl, f
     }
 }
 
+/**
+ * @brief Creates a vector of vec_len floats with values uniformly distributed between min and max,
+ * generated from the given seed (default 0)
+ */
+static inline std::vector generate_float_numbers(std::size_t vec_len, float min, float max, int seed = 0) {
+    std::vector res;
+    std::mt19937 gen(static_cast(seed));
+
+    std::uniform_real_distribution dist(min, max);
+    for (int i = 0; i < vec_len; i++)
+        res.emplace_back(static_cast(dist(gen)));
+
+    return res;
+}
+
 /**
 * Fill blob with value data blob. Broadcast semantic is included.
 * Broadcasting with alignment through last dimension.
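
The applicability rule the new Concat4Dto2DPass encodes can be restated in isolation. The sketch below is a minimal, self-contained illustration of that check, not the plugin code: the helper name isConvertibleTo2D, the Shape alias, and the demo driver are hypothetical. A 4D concat qualifies for the 2D rewrite when every dimension before the concat axis equals 1 (so flattening preserves element order), and the rewrite is only performed when at least one input's total element count is not a multiple of 64, since otherwise no ConcatAlignFilter would be inserted in the first place.

#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

using Shape = std::vector<size_t>;

// Returns true when the pass would convert the concat: all dimensions before
// the axis are 1, and at least one input's element count is not 64-aligned.
bool isConvertibleTo2D(const std::vector<Shape>& inputs, size_t axis) {
    if (inputs.empty()) return false;
    for (size_t i = 0; i < axis; ++i) {
        if (inputs.front()[i] != 1) return false;  // flattening would reorder elements
    }
    for (const auto& in : inputs) {
        auto total = std::accumulate(in.begin(), in.end(), size_t{1}, std::multiplies<size_t>());
        if (total % 64 != 0) return true;  // unaligned input found: conversion pays off
    }
    return false;  // all inputs already 64-aligned: keep the 4D concat as-is
}

int main() {
    std::vector<Shape> a = {{1, 1, 5, 3}, {1, 1, 5, 3}};  // 15 elements per input
    std::vector<Shape> b = {{2, 1, 5, 3}, {2, 1, 5, 3}};  // 30 elements per input
    std::cout << isConvertibleTo2D(a, 2) << "\n";  // 1: dims before axis 2 are all 1
    std::cout << isConvertibleTo2D(b, 2) << "\n";  // 0: leading dim is 2
    std::cout << isConvertibleTo2D(b, 0) << "\n";  // 1: axis 0 always qualifies
    return 0;
}

Running this prints 1, 0, 1, matching the worked example in the pass comment: inputs shaped 1, 1, 5, 3 may be flattened to 1, 15 for axis 0, 1 or 2, while inputs shaped 2, 1, 5, 3 qualify only at axis 0, flattening to 1, 30.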