[GNA] fix TensorIterator unrolling with one iteration (#5217)

This commit is contained in:
Elizaveta Lobanova 2021-04-14 18:55:52 +03:00 committed by GitHub
parent 7ac7215924
commit f7cf92e52a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 3 deletions

View File

@ -615,8 +615,11 @@ bool unrollTI(CNNLayerPtr cur, CNNNetwork& net) {
auto out_data = ti->outData[rule.from];
if (num == 1) {
getInputTo(body_list[0].outputs[rule.to]) = getInputTo(out_data);
getInputTo(body_list[0].outputs[rule.to]).begin()->second->insData[0] = body_list[0].outputs[rule.to];
auto to_data = body_list[0].outputs[rule.to];
auto parent = getCreatorLayer(to_data).lock();
std::replace(parent->outData.begin(), parent->outData.end(), to_data, out_data);
getCreatorLayer(out_data) = parent;
CombineData(out_data, to_data);
continue;
}

View File

@ -1,4 +1,4 @@
// Copyright (C) 2018-2021 Intel Corporation
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

View File

@ -0,0 +1,37 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include <ngraph/op/util/attr_types.hpp>
#include "single_layer_tests/tensor_iterator.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
namespace {
const std::vector<bool> should_decompose = {false};
const std::vector<size_t> seqLengths = {1};
const std::vector<size_t> batches = {1};
const std::vector<size_t> hiddenSizes = {128, 200, 300};
const std::vector<size_t> seqAxes = {0, 1};
const std::vector<float> clip = {0.f};
const std::vector<InferenceEngine::Precision> netPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16
};
INSTANTIATE_TEST_CASE_P(smoke_TensorIterator, TensorIteratorTest,
::testing::Combine(
::testing::ValuesIn(should_decompose),
::testing::ValuesIn(seqLengths),
::testing::ValuesIn(batches),
::testing::ValuesIn(hiddenSizes),
::testing::ValuesIn(seqAxes),
::testing::ValuesIn(clip),
::testing::Values(ngraph::helpers::TensorIteratorBody::LSTM),
::testing::Values(ngraph::op::RecurrentSequenceDirection::FORWARD),
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GNA)),
TensorIteratorTest::getTestCaseName);
} // namespace