[GNA] fix TensorIterator unrolling with one iteration (#5217)
This commit is contained in:
parent
7ac7215924
commit
f7cf92e52a
@ -615,8 +615,11 @@ bool unrollTI(CNNLayerPtr cur, CNNNetwork& net) {
|
||||
auto out_data = ti->outData[rule.from];
|
||||
|
||||
if (num == 1) {
|
||||
getInputTo(body_list[0].outputs[rule.to]) = getInputTo(out_data);
|
||||
getInputTo(body_list[0].outputs[rule.to]).begin()->second->insData[0] = body_list[0].outputs[rule.to];
|
||||
auto to_data = body_list[0].outputs[rule.to];
|
||||
auto parent = getCreatorLayer(to_data).lock();
|
||||
std::replace(parent->outData.begin(), parent->outData.end(), to_data, out_data);
|
||||
getCreatorLayer(out_data) = parent;
|
||||
CombineData(out_data, to_data);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
|
@ -0,0 +1,37 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
#include <ngraph/op/util/attr_types.hpp>
|
||||
#include "single_layer_tests/tensor_iterator.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
namespace {
|
||||
const std::vector<bool> should_decompose = {false};
|
||||
const std::vector<size_t> seqLengths = {1};
|
||||
const std::vector<size_t> batches = {1};
|
||||
const std::vector<size_t> hiddenSizes = {128, 200, 300};
|
||||
const std::vector<size_t> seqAxes = {0, 1};
|
||||
const std::vector<float> clip = {0.f};
|
||||
const std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||
InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::FP16
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_TensorIterator, TensorIteratorTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(should_decompose),
|
||||
::testing::ValuesIn(seqLengths),
|
||||
::testing::ValuesIn(batches),
|
||||
::testing::ValuesIn(hiddenSizes),
|
||||
::testing::ValuesIn(seqAxes),
|
||||
::testing::ValuesIn(clip),
|
||||
::testing::Values(ngraph::helpers::TensorIteratorBody::LSTM),
|
||||
::testing::Values(ngraph::op::RecurrentSequenceDirection::FORWARD),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GNA)),
|
||||
TensorIteratorTest::getTestCaseName);
|
||||
} // namespace
|
Loading…
Reference in New Issue
Block a user