[GNA] fix TensorIterator unrolling with one iteration (#5217)
This commit is contained in:
parent
7ac7215924
commit
f7cf92e52a
@ -615,8 +615,11 @@ bool unrollTI(CNNLayerPtr cur, CNNNetwork& net) {
|
|||||||
auto out_data = ti->outData[rule.from];
|
auto out_data = ti->outData[rule.from];
|
||||||
|
|
||||||
if (num == 1) {
|
if (num == 1) {
|
||||||
getInputTo(body_list[0].outputs[rule.to]) = getInputTo(out_data);
|
auto to_data = body_list[0].outputs[rule.to];
|
||||||
getInputTo(body_list[0].outputs[rule.to]).begin()->second->insData[0] = body_list[0].outputs[rule.to];
|
auto parent = getCreatorLayer(to_data).lock();
|
||||||
|
std::replace(parent->outData.begin(), parent->outData.end(), to_data, out_data);
|
||||||
|
getCreatorLayer(out_data) = parent;
|
||||||
|
CombineData(out_data, to_data);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright (C) 2018-2021 Intel Corporation
|
// Copyright (C) 2021 Intel Corporation
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
|
@ -0,0 +1,37 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <ngraph/op/util/attr_types.hpp>
|
||||||
|
#include "single_layer_tests/tensor_iterator.hpp"
|
||||||
|
#include "common_test_utils/test_constants.hpp"
|
||||||
|
|
||||||
|
using namespace LayerTestsDefinitions;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
const std::vector<bool> should_decompose = {false};
|
||||||
|
const std::vector<size_t> seqLengths = {1};
|
||||||
|
const std::vector<size_t> batches = {1};
|
||||||
|
const std::vector<size_t> hiddenSizes = {128, 200, 300};
|
||||||
|
const std::vector<size_t> seqAxes = {0, 1};
|
||||||
|
const std::vector<float> clip = {0.f};
|
||||||
|
const std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||||
|
InferenceEngine::Precision::FP32,
|
||||||
|
InferenceEngine::Precision::FP16
|
||||||
|
};
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(smoke_TensorIterator, TensorIteratorTest,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::ValuesIn(should_decompose),
|
||||||
|
::testing::ValuesIn(seqLengths),
|
||||||
|
::testing::ValuesIn(batches),
|
||||||
|
::testing::ValuesIn(hiddenSizes),
|
||||||
|
::testing::ValuesIn(seqAxes),
|
||||||
|
::testing::ValuesIn(clip),
|
||||||
|
::testing::Values(ngraph::helpers::TensorIteratorBody::LSTM),
|
||||||
|
::testing::Values(ngraph::op::RecurrentSequenceDirection::FORWARD),
|
||||||
|
::testing::ValuesIn(netPrecisions),
|
||||||
|
::testing::Values(CommonTestUtils::DEVICE_GNA)),
|
||||||
|
TensorIteratorTest::getTestCaseName);
|
||||||
|
} // namespace
|
Loading…
Reference in New Issue
Block a user