Do not convert Sequences to TensorIterator, when plugin supports the original primitive (#4631)

* do not convert Sequences to TensorIterator when plugin supports Sequence primitive

* fix referece implementations for Sequences, processing seq_len value == 0 case

* Adding new mode for LSTMSequence single layer tests

* update single layer tests

* fix failed unit tests, updated single layer tests for rnn/gru sequences

* fix failed unit tests

* fix single layer tests

* ignore failed single layer tests on gpu (known issue), fix review remarks
This commit is contained in:
Ivan Tikhonov 2021-03-22 06:46:46 +03:00 committed by GitHub
parent f8c037c238
commit a868d88d49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 191 additions and 72 deletions

View File

@ -18,7 +18,7 @@
#include <ngraph/opsets/opset2.hpp> #include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp> #include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset4.hpp> #include <ngraph/opsets/opset4.hpp>
#include <ngraph/opsets/opset5.hpp> #include <ngraph/opsets/opset6.hpp>
#include <ngraph/pass/manager.hpp> #include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp> #include <ngraph/pass/constant_folding.hpp>
#include <ie_ngraph_utils.hpp> #include <ie_ngraph_utils.hpp>
@ -221,31 +221,56 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
}); });
auto isCellPrimitiveSupported = [](const_node_ptr &node) -> bool { auto isCellPrimitiveSupported = [](const_node_ptr &node) -> bool {
if (std::dynamic_pointer_cast<const ngraph::op::v0::RNNCell>(node) || std::dynamic_pointer_cast<const ngraph::op::v5::RNNSequence>(node)) { if (std::dynamic_pointer_cast<const ngraph::opset6::RNNCell>(node)) {
return false; return false;
} else if (std::dynamic_pointer_cast<const ngraph::op::v3::GRUCell>(node) || } else if (std::dynamic_pointer_cast<const ngraph::opset6::GRUCell>(node)) {
std::dynamic_pointer_cast<const ngraph::op::v5::GRUSequence>(node)) {
return false; return false;
} else if (const auto &lstm_cell = std::dynamic_pointer_cast<const ngraph::op::v4::LSTMCell>(node)) { } else if (const auto &lstm_cell = std::dynamic_pointer_cast<const ngraph::opset6::LSTMCell>(node)) {
return lstm_cell->get_clip() == 0.0f && lstm_cell->get_activations() == std::vector<std::string>{"sigmoid", "tanh", "tanh"}; return lstm_cell->get_clip() == 0.0f && lstm_cell->get_activations() == std::vector<std::string>{"sigmoid", "tanh", "tanh"};
} else if (const auto &lstm_cell_v1 = std::dynamic_pointer_cast<const ngraph::op::v0::LSTMCell>(node)) { } else if (const auto &lstm_cell_v1 = std::dynamic_pointer_cast<const ngraph::opset1::LSTMCell>(node)) {
return lstm_cell_v1->get_clip() == 0.0f && lstm_cell_v1->get_activations() == std::vector<std::string>{"sigmoid", "tanh", "tanh"}; return lstm_cell_v1->get_clip() == 0.0f && lstm_cell_v1->get_activations() == std::vector<std::string>{"sigmoid", "tanh", "tanh"};
} else if (const auto &lstm_sequence = std::dynamic_pointer_cast<const ngraph::op::v5::LSTMSequence>(node)) {
return lstm_sequence->get_clip() == 0.0f && lstm_sequence->get_activations() == std::vector<std::string>{"sigmoid", "tanh", "tanh"};
} }
return false; return false;
}; };
pass_config->set_callback<ngraph::pass::ConvertRNNSequenceToTensorIterator, // Sequences supported by the plugin shouldn't be converted to TensorIterator.
ngraph::pass::ConvertGRUSequenceToTensorIterator, // sequence_length input is not supported in all Sequences, so if is_seq_len_provided() == true, we
ngraph::pass::ConvertLSTMSequenceToTensorIterator, // should always convert to TensorIterator.
ngraph::pass::RNNCellDecomposition, // RNN/GRU Sequences are not supported in GPU plugin
// LSTM Sequence supported with clip == 0, and activations have default values (sigmoid, tanh, tanh)
auto isSequencePrimitiveSupported = [](const_node_ptr &node) -> bool {
const auto& data = node->input(0);
const auto& data_pshape = data.get_partial_shape();
if (data_pshape.rank().is_static() && data_pshape.rank().get_length() > 1 && !data_pshape[1].is_static())
return false;
auto max_seq_len = data.get_shape().at(1);
if (std::dynamic_pointer_cast<const ngraph::opset6::RNNSequence>(node)) {
return false;
} else if (std::dynamic_pointer_cast<const ngraph::opset6::GRUSequence>(node)) {
return false;
} else if (const auto &lstm_seq = std::dynamic_pointer_cast<const ngraph::opset6::LSTMSequence>(node)) {
return lstm_seq->get_clip() == 0.0f &&
lstm_seq->get_activations() == std::vector<std::string>{"sigmoid", "tanh", "tanh"} &&
!ngraph::op::util::is_seq_len_provided(lstm_seq->get_input_node_shared_ptr(3),
max_seq_len);
}
return false;
};
pass_config->set_callback<ngraph::pass::RNNCellDecomposition,
ngraph::pass::GRUCellDecomposition, ngraph::pass::GRUCellDecomposition,
ngraph::pass::LSTMCellDecomposition>( ngraph::pass::LSTMCellDecomposition>(
[isCellPrimitiveSupported](const_node_ptr &node) -> bool { [isCellPrimitiveSupported](const_node_ptr &node) -> bool {
return isCellPrimitiveSupported(node); return isCellPrimitiveSupported(node);
}); });
pass_config->set_callback<ngraph::pass::ConvertRNNSequenceToTensorIterator,
ngraph::pass::ConvertGRUSequenceToTensorIterator,
ngraph::pass::ConvertLSTMSequenceToTensorIterator>(
[isSequencePrimitiveSupported](const_node_ptr &node) -> bool {
return isSequencePrimitiveSupported(node);
});
pass_config->set_callback<ngraph::pass::ConvertTensorIteratorToRNNSequence, pass_config->set_callback<ngraph::pass::ConvertTensorIteratorToRNNSequence,
ngraph::pass::ConvertTensorIteratorToLSTMSequence, ngraph::pass::ConvertTensorIteratorToLSTMSequence,
ngraph::pass::ConvertTensorIteratorToGRUSequence>( ngraph::pass::ConvertTensorIteratorToGRUSequence>(

View File

@ -116,6 +116,7 @@ ngraph::pass::ConvertLSTMSequenceMatcher::ConvertLSTMSequenceMatcher() {
if (lstm_target_inputs.empty()) if (lstm_target_inputs.empty())
return false; return false;
auto transpose_after = lstm_target_inputs.begin()->get_node()->shared_from_this(); auto transpose_after = lstm_target_inputs.begin()->get_node()->shared_from_this();
unsqueeze_1->set_friendly_name(transpose_after->get_friendly_name());
ngraph::replace_node(transpose_after, unsqueeze_1); ngraph::replace_node(transpose_after, unsqueeze_1);
ngraph::replace_node(lstm_sequence, {lstm_sequence_ie->output(0), unsqueeze_2->output(0), unsqueeze_3->output(0)}); ngraph::replace_node(lstm_sequence, {lstm_sequence_ie->output(0), unsqueeze_2->output(0), unsqueeze_3->output(0)});
} }
@ -186,6 +187,7 @@ ngraph::pass::ConvertGRUSequenceMatcher::ConvertGRUSequenceMatcher() {
if (gru_target_inputs.empty()) if (gru_target_inputs.empty())
return false; return false;
auto transpose_after = gru_target_inputs.begin()->get_node()->shared_from_this(); auto transpose_after = gru_target_inputs.begin()->get_node()->shared_from_this();
unsqueeze_1->set_friendly_name(transpose_after->get_friendly_name());
ngraph::replace_node(transpose_after, unsqueeze_1); ngraph::replace_node(transpose_after, unsqueeze_1);
ngraph::replace_node(gru_sequence, {gru_sequence_ie->output(0), unsqueeze_2->output(0)}); ngraph::replace_node(gru_sequence, {gru_sequence_ie->output(0), unsqueeze_2->output(0)});
} }
@ -257,6 +259,7 @@ ngraph::pass::ConvertRNNSequenceMatcher::ConvertRNNSequenceMatcher() {
if (rnn_target_inputs.empty()) if (rnn_target_inputs.empty())
return false; return false;
auto transpose_after = rnn_target_inputs.begin()->get_node()->shared_from_this(); auto transpose_after = rnn_target_inputs.begin()->get_node()->shared_from_this();
unsqueeze_1->set_friendly_name(transpose_after->get_friendly_name());
ngraph::replace_node(transpose_after, unsqueeze_1); ngraph::replace_node(transpose_after, unsqueeze_1);
ngraph::replace_node(rnn_sequence, {rnn_sequence_ie->output(0), unsqueeze_2->output(0)}); ngraph::replace_node(rnn_sequence, {rnn_sequence_ie->output(0), unsqueeze_2->output(0)});
} }

View File

@ -68,6 +68,7 @@
#include <ngraph/opsets/opset2.hpp> #include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp> #include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset4.hpp> #include <ngraph/opsets/opset4.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/op/util/op_types.hpp> #include <ngraph/op/util/op_types.hpp>
#include <ngraph/pass/manager.hpp> #include <ngraph/pass/manager.hpp>
@ -196,6 +197,42 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
return false; return false;
}; };
// Sequences supported by the plugin shouldn't be converted to TensorIterator.
// sequence_length input is not supported in all Sequences, so if is_seq_len_provided() == true, we
// should always convert to TensorIterator.
// RNN/GRU/LSTM Sequences are supported with clip == 0, and with default activations.
auto isSequencePrimitiveSupported = [](const_node_ptr &node) -> bool {
const auto& data = node->input(0);
const auto& data_pshape = data.get_partial_shape();
if (data_pshape.rank().is_static() && data_pshape.rank().get_length() > 1 && !data_pshape[1].is_static())
return false;
auto max_seq_len = data.get_shape().at(1);
if (const auto &rnn_seq = std::dynamic_pointer_cast<const ngraph::opset6::RNNSequence>(node)) {
return rnn_seq->get_clip() == 0.0f &&
!ngraph::op::util::is_seq_len_provided(rnn_seq->get_input_node_shared_ptr(2),
max_seq_len);
} else if (const auto &gru_seq = std::dynamic_pointer_cast<const ngraph::opset6::GRUSequence>(
node)) {
return gru_seq->get_clip() == 0.0f &&
gru_seq->get_activations() == std::vector<std::string>{"sigmoid", "tanh"} &&
!ngraph::op::util::is_seq_len_provided(gru_seq->get_input_node_shared_ptr(2),
max_seq_len);
} else if (const auto &lstm_seq = std::dynamic_pointer_cast<const ngraph::opset6::LSTMSequence>(
node)) {
return lstm_seq->get_clip() == 0.0f &&
lstm_seq->get_activations() == std::vector<std::string>{"sigmoid", "tanh", "tanh"} &&
!ngraph::op::util::is_seq_len_provided(lstm_seq->get_input_node_shared_ptr(3),
max_seq_len);
}
return false;
};
pass_config->set_callback<ngraph::pass::ConvertRNNSequenceToTensorIterator, ngraph::pass::ConvertGRUSequenceToTensorIterator,
ngraph::pass::ConvertLSTMSequenceToTensorIterator>(
[isSequencePrimitiveSupported](const_node_ptr &node) -> bool {
return isSequencePrimitiveSupported(node);
});
pass_config->set_callback<ngraph::pass::RNNCellDecomposition, ngraph::pass::GRUCellDecomposition, pass_config->set_callback<ngraph::pass::RNNCellDecomposition, ngraph::pass::GRUCellDecomposition,
ngraph::pass::LSTMCellDecomposition>( ngraph::pass::LSTMCellDecomposition>(
[isCellPrimitiveSupported](const_node_ptr &node) -> bool { [isCellPrimitiveSupported](const_node_ptr &node) -> bool {

View File

@ -1,4 +1,4 @@
// Copyright (C) 2018-2020 Intel Corporation // Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
@ -104,6 +104,9 @@ TRANSFORMATIONS_API bool check_for_broadcast(const ngraph::Shape &ref_shape, con
TRANSFORMATIONS_API std::shared_ptr<ngraph::Node> activation(const std::string& activation_name, TRANSFORMATIONS_API std::shared_ptr<ngraph::Node> activation(const std::string& activation_name,
const ngraph::Output<ngraph::Node>& apply_to); const ngraph::Output<ngraph::Node>& apply_to);
TRANSFORMATIONS_API bool is_seq_len_provided(const std::shared_ptr<Node> &seq_len_input, int64_t max_seq_len);
template <class T> template <class T>
Output<Node> eltwise_fold(const Output<Node> & input0, const Output<Node> & input1) { Output<Node> eltwise_fold(const Output<Node> & input0, const Output<Node> & input1) {
auto eltwise = std::make_shared<T>(input0, input1); auto eltwise = std::make_shared<T>(input0, input1);

View File

@ -70,18 +70,6 @@ namespace {
} }
return squeezed_nodes; return squeezed_nodes;
} }
bool should_enable_mask(const ngraph::Output<ngraph::Node> &seq_lengths, int64_t max_seq_len) {
// disable the mask if all values of seq_lengths input are equal to max_seq_len (X_shape[1])
if (const auto &seq_len_const = std::dynamic_pointer_cast<ngraph::opset5::Constant>(
seq_lengths.get_node_shared_ptr())) {
const auto &seq_len_values = seq_len_const->cast_vector<int64_t>();
return std::any_of(seq_len_values.begin(), seq_len_values.end(), [max_seq_len](const int64_t val) {
return val != max_seq_len;
});
}
return true;
}
} // namespace } // namespace
ngraph::pass::ConvertRNNSequenceToTensorIterator::ConvertRNNSequenceToTensorIterator() { ngraph::pass::ConvertRNNSequenceToTensorIterator::ConvertRNNSequenceToTensorIterator() {
@ -93,12 +81,12 @@ ngraph::pass::ConvertRNNSequenceToTensorIterator::ConvertRNNSequenceToTensorIter
pattern::any_input(), pattern::any_input(),
pattern::any_input(), pattern::any_input(),
pattern::any_input()}); pattern::any_input()});
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher &m) { ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher &m) {
auto sequence = std::dynamic_pointer_cast<ngraph::opset5::RNNSequence>(m.get_match_root()); auto sequence = std::dynamic_pointer_cast<ngraph::opset5::RNNSequence>(m.get_match_root());
// Bidirectional Sequence op should be decomposed to Reverse + Forward // Bidirectional Sequence op should be decomposed to Reverse + Forward
// (e.g. apply BidirectionalRNNSequenceDecomposition transformation before this one) // (e.g. apply BidirectionalRNNSequenceDecomposition transformation before this one)
if (!sequence || sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) { if (!sequence || sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL || transformation_callback(sequence)) {
return false; return false;
} }
@ -113,7 +101,7 @@ ngraph::pass::ConvertRNNSequenceToTensorIterator::ConvertRNNSequenceToTensorIter
auto tensor_iterator = std::make_shared<opset5::TensorIterator>(); auto tensor_iterator = std::make_shared<opset5::TensorIterator>();
auto max_seq_len = X.get_shape().at(1); auto max_seq_len = X.get_shape().at(1);
bool enable_mask = should_enable_mask(seq_lengths, max_seq_len); bool enable_mask = ngraph::op::util::is_seq_len_provided(seq_lengths.get_node_shared_ptr(), max_seq_len);
std::shared_ptr<Node> reverse_seq_before; std::shared_ptr<Node> reverse_seq_before;
if (is_reverse && enable_mask) { if (is_reverse && enable_mask) {
@ -252,12 +240,12 @@ ngraph::pass::ConvertGRUSequenceToTensorIterator::ConvertGRUSequenceToTensorIter
pattern::any_input(), pattern::any_input(),
pattern::any_input(), pattern::any_input(),
pattern::any_input()}); pattern::any_input()});
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher &m) { ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher &m) {
auto sequence = std::dynamic_pointer_cast<ngraph::opset5::GRUSequence>(m.get_match_root()); auto sequence = std::dynamic_pointer_cast<ngraph::opset5::GRUSequence>(m.get_match_root());
// Bidirectional Sequence op should be decomposed to Reverse + Forward // Bidirectional Sequence op should be decomposed to Reverse + Forward
// (e.g. apply BidirectionalRNNSequenceDecomposition transformation before this one) // (e.g. apply BidirectionalRNNSequenceDecomposition transformation before this one)
if (!sequence || sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) { if (!sequence || sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL || transformation_callback(sequence)) {
return false; return false;
} }
@ -272,7 +260,7 @@ ngraph::pass::ConvertGRUSequenceToTensorIterator::ConvertGRUSequenceToTensorIter
auto tensor_iterator = std::make_shared<opset5::TensorIterator>(); auto tensor_iterator = std::make_shared<opset5::TensorIterator>();
auto max_seq_len = X.get_shape().at(1); auto max_seq_len = X.get_shape().at(1);
bool enable_mask = should_enable_mask(seq_lengths, max_seq_len); bool enable_mask = ngraph::op::util::is_seq_len_provided(seq_lengths.get_node_shared_ptr(), max_seq_len);
std::shared_ptr<Node> reverse_seq_before; std::shared_ptr<Node> reverse_seq_before;
if (is_reverse && enable_mask) { if (is_reverse && enable_mask) {
@ -412,12 +400,12 @@ ngraph::pass::ConvertLSTMSequenceToTensorIterator::ConvertLSTMSequenceToTensorIt
pattern::any_input(), pattern::any_input(),
pattern::any_input(), pattern::any_input(),
pattern::any_input()}); pattern::any_input()});
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher &m) { ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher &m) {
auto sequence = std::dynamic_pointer_cast<ngraph::opset5::LSTMSequence>(m.get_match_root()); auto sequence = std::dynamic_pointer_cast<ngraph::opset5::LSTMSequence>(m.get_match_root());
// Bidirectional Sequence op should be decomposed to Reverse + Forward // Bidirectional Sequence op should be decomposed to Reverse + Forward
// (e.g. apply BidirectionalRNNSequenceDecomposition transformation before this one) // (e.g. apply BidirectionalRNNSequenceDecomposition transformation before this one)
if (!sequence || sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) { if (!sequence || sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL || transformation_callback(sequence)) {
return false; return false;
} }
@ -433,7 +421,7 @@ ngraph::pass::ConvertLSTMSequenceToTensorIterator::ConvertLSTMSequenceToTensorIt
auto tensor_iterator = std::make_shared<opset5::TensorIterator>(); auto tensor_iterator = std::make_shared<opset5::TensorIterator>();
auto max_seq_len = X.get_shape().at(1); auto max_seq_len = X.get_shape().at(1);
bool enable_mask = should_enable_mask(seq_lengths, max_seq_len); bool enable_mask = ngraph::op::util::is_seq_len_provided(seq_lengths.get_node_shared_ptr(), max_seq_len);
std::shared_ptr<Node> reverse_seq_before; std::shared_ptr<Node> reverse_seq_before;
if (is_reverse && enable_mask) { if (is_reverse && enable_mask) {

View File

@ -120,6 +120,16 @@ std::shared_ptr<ngraph::Node> activation(const std::string& activation_name, con
} }
} }
bool is_seq_len_provided(const std::shared_ptr<Node> &seq_len_input, int64_t max_seq_len) {
if (const auto &seq_len_const = std::dynamic_pointer_cast<ngraph::op::Constant>(seq_len_input)) {
const auto &seq_len_values = seq_len_const->cast_vector<int64_t>();
return std::any_of(seq_len_values.begin(), seq_len_values.end(), [max_seq_len](const int64_t val) {
return val != max_seq_len;
});
}
return true;
}
} // namespace util } // namespace util
} // namespace op } // namespace op
} // namespace ngraph } // namespace ngraph

View File

@ -13,6 +13,8 @@ namespace {
std::vector<ngraph::helpers::SequenceTestsMode> mode{ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_CONST, std::vector<ngraph::helpers::SequenceTestsMode> mode{ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST, ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM, ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM,
ngraph::helpers::SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_PARAM,
ngraph::helpers::SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::PURE_SEQ}; ngraph::helpers::SequenceTestsMode::PURE_SEQ};
// output values increase rapidly without clip, so use only seq_lenghts = 2 // output values increase rapidly without clip, so use only seq_lenghts = 2
std::vector<size_t> seq_lengths_zero_clip{2}; std::vector<size_t> seq_lengths_zero_clip{2};

View File

@ -13,7 +13,9 @@ namespace {
std::vector<ngraph::helpers::SequenceTestsMode> mode{ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_CONST, std::vector<ngraph::helpers::SequenceTestsMode> mode{ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST, ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM, ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM,
ngraph::helpers::SequenceTestsMode::PURE_SEQ}; ngraph::helpers::SequenceTestsMode::PURE_SEQ,
ngraph::helpers::SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_PARAM};
// output values increase rapidly without clip, so use only seq_lenghts = 2 // output values increase rapidly without clip, so use only seq_lenghts = 2
std::vector<size_t> seq_lengths_zero_clip{2}; std::vector<size_t> seq_lengths_zero_clip{2};
std::vector<size_t> seq_lengths_clip_non_zero{20}; std::vector<size_t> seq_lengths_clip_non_zero{20};

View File

@ -13,6 +13,8 @@ namespace {
std::vector<ngraph::helpers::SequenceTestsMode> mode{ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_CONST, std::vector<ngraph::helpers::SequenceTestsMode> mode{ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST, ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM, ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM,
ngraph::helpers::SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_PARAM,
ngraph::helpers::SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::PURE_SEQ}; ngraph::helpers::SequenceTestsMode::PURE_SEQ};
// output values increase rapidly without clip, so use only seq_lenghts = 2 // output values increase rapidly without clip, so use only seq_lenghts = 2
std::vector<size_t> seq_lengths_zero_clip{2}; std::vector<size_t> seq_lengths_zero_clip{2};

View File

@ -13,6 +13,8 @@ namespace {
std::vector<ngraph::helpers::SequenceTestsMode> mode{ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_CONST, std::vector<ngraph::helpers::SequenceTestsMode> mode{ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST, ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM, ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM,
ngraph::helpers::SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_PARAM,
ngraph::helpers::SequenceTestsMode::PURE_SEQ}; ngraph::helpers::SequenceTestsMode::PURE_SEQ};
// output values increase rapidly without clip, so use only seq_lenghts = 2 // output values increase rapidly without clip, so use only seq_lenghts = 2
std::vector<size_t> seq_lengths_zero_clip{2}; std::vector<size_t> seq_lengths_zero_clip{2};

View File

@ -13,6 +13,8 @@ namespace {
std::vector<ngraph::helpers::SequenceTestsMode> mode{ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_CONST, std::vector<ngraph::helpers::SequenceTestsMode> mode{ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST, ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM, ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM,
ngraph::helpers::SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_PARAM,
ngraph::helpers::SequenceTestsMode::PURE_SEQ}; ngraph::helpers::SequenceTestsMode::PURE_SEQ};
// output values increase rapidly without clip, so use only seq_lenghts = 2 // output values increase rapidly without clip, so use only seq_lenghts = 2
std::vector<size_t> seq_lengths_zero_clip{2}; std::vector<size_t> seq_lengths_zero_clip{2};

View File

@ -13,6 +13,8 @@ namespace {
std::vector<ngraph::helpers::SequenceTestsMode> mode{ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_CONST, std::vector<ngraph::helpers::SequenceTestsMode> mode{ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST, ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM, ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM,
ngraph::helpers::SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_PARAM,
ngraph::helpers::SequenceTestsMode::PURE_SEQ}; ngraph::helpers::SequenceTestsMode::PURE_SEQ};
// output values increase rapidly without clip, so use only seq_lenghts = 2 // output values increase rapidly without clip, so use only seq_lenghts = 2
std::vector<size_t> seq_lengths_zero_clip{2}; std::vector<size_t> seq_lengths_zero_clip{2};

View File

@ -48,7 +48,7 @@ std::vector<std::string> disabledTestPatterns() {
// Need to update activation primitive to support any broadcastable constant to enable these cases. // Need to update activation primitive to support any broadcastable constant to enable these cases.
R"(.*ActivationParamLayerTest.*)", R"(.*ActivationParamLayerTest.*)",
// Unknown issues // Unknown issues
R"(.*(LSTMSequence).*mode=CONVERT_TO_TI_RAND_SEQ_LEN.*)", R"(.*(LSTMSequence).*mode=.*_RAND_SEQ_LEN_CONST.*)",
R"(.*(smoke_DetectionOutput3In).*)", R"(.*(smoke_DetectionOutput3In).*)",
R"(.*(smoke_DetectionOutput5In).*)", R"(.*(smoke_DetectionOutput5In).*)",
// TODO: Issue: 47773 // TODO: Issue: 47773

View File

@ -44,6 +44,7 @@ namespace LayerTestsDefinitions {
} }
void GRUSequenceTest::SetUp() { void GRUSequenceTest::SetUp() {
using namespace ngraph::helpers;
size_t seq_lenghts; size_t seq_lenghts;
size_t batch; size_t batch;
size_t hidden_size; size_t hidden_size;
@ -66,8 +67,9 @@ namespace LayerTestsDefinitions {
m_max_seq_len = seq_lenghts; m_max_seq_len = seq_lenghts;
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]}); auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]});
if (m_mode == ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_PARAM || if (m_mode == SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_PARAM ||
m_mode == ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM) { m_mode == SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM ||
m_mode == SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_PARAM) {
auto seq_lengths = ngraph::builder::makeParams(ngraph::element::i64, {inputShapes[2]}).at(0); auto seq_lengths = ngraph::builder::makeParams(ngraph::element::i64, {inputShapes[2]}).at(0);
seq_lengths->set_friendly_name("seq_lengths"); seq_lengths->set_friendly_name("seq_lengths");
params.push_back(seq_lengths); params.push_back(seq_lengths);
@ -79,16 +81,19 @@ namespace LayerTestsDefinitions {
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(gru_sequence->output(0)), ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(gru_sequence->output(0)),
std::make_shared<ngraph::opset1::Result>(gru_sequence->output(1))}; std::make_shared<ngraph::opset1::Result>(gru_sequence->output(1))};
function = std::make_shared<ngraph::Function>(results, params, "gru_sequence"); function = std::make_shared<ngraph::Function>(results, params, "gru_sequence");
if (m_mode != ngraph::helpers::SequenceTestsMode::PURE_SEQ) { bool is_pure_sequence = (m_mode == SequenceTestsMode::PURE_SEQ ||
m_mode == SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_PARAM ||
m_mode == SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_CONST);
if (!is_pure_sequence) {
ngraph::pass::Manager manager; ngraph::pass::Manager manager;
if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
manager.register_pass<ngraph::pass::BidirectionalGRUSequenceDecomposition>(); manager.register_pass<ngraph::pass::BidirectionalGRUSequenceDecomposition>();
manager.register_pass<ngraph::pass::ConvertGRUSequenceToTensorIterator>(); manager.register_pass<ngraph::pass::ConvertGRUSequenceToTensorIterator>();
manager.run_passes(function); manager.run_passes(function);
bool ti_found = ngraph::helpers::is_tensor_iterator_exist(function); bool ti_found = is_tensor_iterator_exist(function);
EXPECT_EQ(ti_found, true); EXPECT_EQ(ti_found, true);
} else { } else {
bool ti_found = ngraph::helpers::is_tensor_iterator_exist(function); bool ti_found = is_tensor_iterator_exist(function);
EXPECT_EQ(ti_found, false); EXPECT_EQ(ti_found, false);
} }
} }

View File

@ -44,6 +44,8 @@ namespace LayerTestsDefinitions {
} }
void LSTMSequenceTest::SetUp() { void LSTMSequenceTest::SetUp() {
using namespace ngraph::helpers;
using namespace ngraph::builder;
size_t seq_lenghts; size_t seq_lenghts;
size_t batch; size_t batch;
@ -64,33 +66,36 @@ namespace LayerTestsDefinitions {
{batch}, {num_directions, 4 * hidden_size, input_size}, {num_directions, 4 * hidden_size, hidden_size}, {num_directions, 4 * hidden_size}}, {batch}, {num_directions, 4 * hidden_size, input_size}, {num_directions, 4 * hidden_size, hidden_size}, {num_directions, 4 * hidden_size}},
}; };
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1], inputShapes[2]}); auto params = makeParams(ngPrc, {inputShapes[0], inputShapes[1], inputShapes[2]});
if (m_mode == ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_PARAM || if (m_mode == SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_PARAM ||
m_mode == ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM) { m_mode == SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM ||
auto seq_lengths = ngraph::builder::makeParams(ngraph::element::i64, {inputShapes[3]}).at(0); m_mode == SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_PARAM) {
auto seq_lengths = makeParams(ngraph::element::i64, {inputShapes[3]}).at(0);
seq_lengths->set_friendly_name("seq_lengths"); seq_lengths->set_friendly_name("seq_lengths");
params.push_back(seq_lengths); params.push_back(seq_lengths);
} }
std::vector<ngraph::Shape> WRB = {inputShapes[4], inputShapes[5], inputShapes[6], inputShapes[3]}; std::vector<ngraph::Shape> WRB = {inputShapes[4], inputShapes[5], inputShapes[6], inputShapes[3]};
auto lstm_sequence = ngraph::builder::makeLSTM(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)), auto lstm_sequence = makeLSTM(convert2OutputVector(castOps2Nodes(params)), WRB, hidden_size, activations,
WRB, hidden_size, activations, {}, {}, clip, true, direction, {}, {}, clip, true, direction, m_mode);
m_mode);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(0)), ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(0)),
std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(1)), std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(1)),
std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(2))}; std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(2))};
function = std::make_shared<ngraph::Function>(results, params, "lstm_sequence"); function = std::make_shared<ngraph::Function>(results, params, "lstm_sequence");
if (m_mode != ngraph::helpers::SequenceTestsMode::PURE_SEQ) { bool is_pure_sequence = (m_mode == SequenceTestsMode::PURE_SEQ ||
m_mode == SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_PARAM ||
m_mode == SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_CONST);
if (!is_pure_sequence) {
ngraph::pass::Manager manager; ngraph::pass::Manager manager;
if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
manager.register_pass<ngraph::pass::BidirectionalLSTMSequenceDecomposition>(); manager.register_pass<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
manager.register_pass<ngraph::pass::ConvertLSTMSequenceToTensorIterator>(); manager.register_pass<ngraph::pass::ConvertLSTMSequenceToTensorIterator>();
manager.run_passes(function); manager.run_passes(function);
bool ti_found = ngraph::helpers::is_tensor_iterator_exist(function); bool ti_found = is_tensor_iterator_exist(function);
EXPECT_EQ(ti_found, true); EXPECT_EQ(ti_found, true);
} else { } else {
bool ti_found = ngraph::helpers::is_tensor_iterator_exist(function); bool ti_found = is_tensor_iterator_exist(function);
EXPECT_EQ(ti_found, false); EXPECT_EQ(ti_found, false);
} }
} }
void LSTMSequenceTest::GenerateInputs() { void LSTMSequenceTest::GenerateInputs() {

View File

@ -43,6 +43,7 @@ namespace LayerTestsDefinitions {
} }
void RNNSequenceTest::SetUp() { void RNNSequenceTest::SetUp() {
using namespace ngraph::helpers;
size_t seq_lenghts; size_t seq_lenghts;
size_t batch; size_t batch;
size_t hidden_size; size_t hidden_size;
@ -64,8 +65,9 @@ namespace LayerTestsDefinitions {
m_max_seq_len = seq_lenghts; m_max_seq_len = seq_lenghts;
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]}); auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]});
if (m_mode == ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_PARAM || if (m_mode == SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_PARAM ||
m_mode == ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM) { m_mode == SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM ||
m_mode == SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_PARAM) {
auto seq_lengths = ngraph::builder::makeParams(ngraph::element::i64, {inputShapes[2]}).at(0); auto seq_lengths = ngraph::builder::makeParams(ngraph::element::i64, {inputShapes[2]}).at(0);
seq_lengths->set_friendly_name("seq_lengths"); seq_lengths->set_friendly_name("seq_lengths");
params.push_back(seq_lengths); params.push_back(seq_lengths);
@ -77,7 +79,10 @@ namespace LayerTestsDefinitions {
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(rnn_sequence->output(0)), ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(rnn_sequence->output(0)),
std::make_shared<ngraph::opset1::Result>(rnn_sequence->output(1))}; std::make_shared<ngraph::opset1::Result>(rnn_sequence->output(1))};
function = std::make_shared<ngraph::Function>(results, params, "rnn_sequence"); function = std::make_shared<ngraph::Function>(results, params, "rnn_sequence");
if (m_mode != ngraph::helpers::SequenceTestsMode::PURE_SEQ) { bool is_pure_sequence = (m_mode == SequenceTestsMode::PURE_SEQ ||
m_mode == SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_PARAM ||
m_mode == SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_CONST);
if (!is_pure_sequence) {
ngraph::pass::Manager manager; ngraph::pass::Manager manager;
if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
manager.register_pass<ngraph::pass::BidirectionalRNNSequenceDecomposition>(); manager.register_pass<ngraph::pass::BidirectionalRNNSequenceDecomposition>();

View File

@ -198,6 +198,8 @@ enum class TensorIteratorBody {
enum class SequenceTestsMode { enum class SequenceTestsMode {
PURE_SEQ, PURE_SEQ,
PURE_SEQ_RAND_SEQ_LEN_CONST,
PURE_SEQ_RAND_SEQ_LEN_PARAM,
CONVERT_TO_TI_MAX_SEQ_LEN_CONST, CONVERT_TO_TI_MAX_SEQ_LEN_CONST,
CONVERT_TO_TI_MAX_SEQ_LEN_PARAM, CONVERT_TO_TI_MAX_SEQ_LEN_PARAM,
CONVERT_TO_TI_RAND_SEQ_LEN_CONST, CONVERT_TO_TI_RAND_SEQ_LEN_CONST,

View File

@ -38,7 +38,8 @@ std::shared_ptr<ngraph::Node> makeGRU(const OutputVector& in,
seq_lengths = ngraph::builder::makeConstant(element::i64, constants[3], lengths, false); seq_lengths = ngraph::builder::makeConstant(element::i64, constants[3], lengths, false);
break; break;
} }
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST: { case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST:
case ngraph::helpers::SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_CONST: {
for (size_t i = 0; i <= in[0].get_shape().at(0); ++i) { for (size_t i = 0; i <= in[0].get_shape().at(0); ++i) {
std::vector<float> lengths; std::vector<float> lengths;
seq_lengths = ngraph::builder::makeConstant(element::i64, constants[3], lengths, true, seq_lengths = ngraph::builder::makeConstant(element::i64, constants[3], lengths, true,
@ -47,7 +48,8 @@ std::shared_ptr<ngraph::Node> makeGRU(const OutputVector& in,
break; break;
} }
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM: case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM:
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_PARAM: { case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_PARAM:
case ngraph::helpers::SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_PARAM: {
// Seq_lengths should be as a Parameter node for these two modes // Seq_lengths should be as a Parameter node for these two modes
seq_lengths = in.at(2).get_node_shared_ptr(); seq_lengths = in.at(2).get_node_shared_ptr();
break; break;

View File

@ -36,7 +36,8 @@ std::shared_ptr<ngraph::Node> makeLSTM(const std::vector<ngraph::Output<Node>>&
seq_lengths = ngraph::builder::makeConstant(element::i64, constants[3], lengths, false); seq_lengths = ngraph::builder::makeConstant(element::i64, constants[3], lengths, false);
break; break;
} }
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST: { case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST:
case ngraph::helpers::SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_CONST: {
for (size_t i = 0; i <= in[0].get_shape().at(0); ++i) { for (size_t i = 0; i <= in[0].get_shape().at(0); ++i) {
std::vector<float> lengths; std::vector<float> lengths;
seq_lengths = ngraph::builder::makeConstant(element::i64, constants[3], lengths, true, seq_lengths = ngraph::builder::makeConstant(element::i64, constants[3], lengths, true,
@ -44,6 +45,7 @@ std::shared_ptr<ngraph::Node> makeLSTM(const std::vector<ngraph::Output<Node>>&
} }
break; break;
} }
case ngraph::helpers::SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_PARAM:
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM: case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM:
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_PARAM: { case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_PARAM: {
// Seq_lengths should be as a Parameter node for these two modes // Seq_lengths should be as a Parameter node for these two modes

View File

@ -36,7 +36,8 @@ std::shared_ptr<ngraph::Node> makeRNN(const OutputVector& in,
seq_lengths = ngraph::builder::makeConstant(element::i64, constants[3], lengths, false); seq_lengths = ngraph::builder::makeConstant(element::i64, constants[3], lengths, false);
break; break;
} }
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST: { case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST:
case ngraph::helpers::SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_CONST: {
for (size_t i = 0; i <= in[0].get_shape().at(0); ++i) { for (size_t i = 0; i <= in[0].get_shape().at(0); ++i) {
std::vector<float> lengths; std::vector<float> lengths;
seq_lengths = ngraph::builder::makeConstant(element::i64, constants[3], lengths, true, seq_lengths = ngraph::builder::makeConstant(element::i64, constants[3], lengths, true,
@ -45,7 +46,8 @@ std::shared_ptr<ngraph::Node> makeRNN(const OutputVector& in,
break; break;
} }
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM: case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM:
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_PARAM: { case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_PARAM:
case ngraph::helpers::SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_PARAM: {
// Seq_lengths should be as a Parameter node for these two modes // Seq_lengths should be as a Parameter node for these two modes
seq_lengths = in.at(2).get_node_shared_ptr(); seq_lengths = in.at(2).get_node_shared_ptr();
break; break;

View File

@ -943,6 +943,12 @@ std::ostream& operator<<(std::ostream & os, SequenceTestsMode type) {
case SequenceTestsMode::PURE_SEQ: case SequenceTestsMode::PURE_SEQ:
os << "PURE_SEQ"; os << "PURE_SEQ";
break; break;
case SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_CONST:
os << "PURE_SEQ_RAND_SEQ_LEN_CONST";
break;
case SequenceTestsMode::PURE_SEQ_RAND_SEQ_LEN_PARAM:
os << "PURE_SEQ_RAND_SEQ_LEN_PARAM";
break;
case SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM: case SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM:
os << "CONVERT_TO_TI_RAND_SEQ_LEN_PARAM"; os << "CONVERT_TO_TI_RAND_SEQ_LEN_PARAM";
break; break;

View File

@ -191,17 +191,18 @@ namespace ngraph
size_t part_size_single_batch = part_shape_size / batch * sizeof(T); size_t part_size_single_batch = part_shape_size / batch * sizeof(T);
for (int i = 0; i < batch; ++i) for (int i = 0; i < batch; ++i)
{ {
auto shift = i * part_size_single_batch;
if ((time_step + 1) > seq_len_values[i]) if ((time_step + 1) > seq_len_values[i])
{ {
continue; continue;
} }
std::memcpy(h_list[time_step].data() + i * part_size_single_batch, std::memcpy(h_list[time_step].data() + shift,
outputs[1] + i * part_size_single_batch, outputs[1] + shift,
part_size_single_batch); part_size_single_batch);
if (type == CellType::LSTM) if (type == CellType::LSTM)
{ {
std::memcpy(c_list[time_step].data() + i * part_size_single_batch, std::memcpy(c_list[time_step].data() + shift,
outputs[2] + i * part_size_single_batch, outputs[2] + shift,
part_size_single_batch); part_size_single_batch);
} }
} }
@ -220,16 +221,27 @@ namespace ngraph
{ {
for (int i = 0; i < batch; ++i) for (int i = 0; i < batch; ++i)
{ {
std::memcpy(outputs[1] + i * part_size_single_batch, auto idx = seq_len_values[i] - 1;
h_list[seq_len_values[i] - 1].data() + auto shift = i * part_size_single_batch;
i * part_size_single_batch, if (idx >= 0 && idx < h_list.size())
part_size_single_batch);
if (type == CellType::LSTM)
{ {
std::memcpy(outputs[2] + i * part_size_single_batch, std::memcpy(outputs[1] + shift,
c_list[seq_len_values[i] - 1].data() + h_list[idx].data() + shift,
i * part_size_single_batch,
part_size_single_batch); part_size_single_batch);
if (type == CellType::LSTM)
{
std::memcpy(outputs[2] + shift,
c_list[idx].data() + shift,
part_size_single_batch);
}
}
else
{
std::memset(outputs[1] + shift, 0, part_size_single_batch);
if (type == CellType::LSTM)
{
std::memset(outputs[2] + shift, 0, part_size_single_batch);
}
} }
} }
} }