LSTM/GRU/RNN Sequences: support for seq_lengths input (#2788)

* sequences to TI transformations, support for seq_lengths input, update reference implementations, add new tests

* fix Python API, update sequences to TI transformation

* fix sequences to TI transformation

* Update sequences to TI transformation: fix reverse sequence support

* update single layer tests, fix TI reference impl, fix Sequences to TI transformations

* ngraph code style

* fix build

* fix ngraph python api

* resolve review comments, refactoring

* Resolve review remarks

* delete xfail
Ivan Tikhonov 2020-11-17 07:04:20 +03:00 committed by GitHub
parent 89f06586cf
commit b45e1a25a5
34 changed files with 2593 additions and 342 deletions
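As a quick orientation before the per-file diffs: the three new passes are plain MatcherPass registrations, so they can be applied to any ngraph::Function through a pass manager. A minimal sketch, mirroring the plugin changes below (the helper name convert_sequences_to_ti is hypothetical, and `function` is assumed to be an already-built network):

#include <ngraph/pass/manager.hpp>
#include <transformations/op_conversions/convert_sequences_to_tensor_iterator.hpp>

// Convert all single-direction RNN/GRU/LSTM Sequence ops in `function` to
// TensorIterator. Bidirectional ops must be decomposed first (see the
// BidirectionalSequenceDecomposition passes in CommonOptimizations below).
void convert_sequences_to_ti(std::shared_ptr<ngraph::Function> function) {
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::ConvertRNNSequenceToTensorIterator>();
    manager.register_pass<ngraph::pass::ConvertGRUSequenceToTensorIterator>();
    manager.register_pass<ngraph::pass::ConvertLSTMSequenceToTensorIterator>();
    manager.run_passes(function);
}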


@@ -31,6 +31,7 @@
#include <transformations/common_optimizations/common_optimizations.hpp>
#include <transformations/opset_conversions/convert_opset2_to_opset1.hpp>
#include <transformations/opset_conversions/convert_opset3_to_opset2.hpp>
#include <transformations/op_conversions/convert_sequences_to_tensor_iterator.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/convert_precision.hpp>
#include <transformations/rt_info/fused_names_attribute.hpp>
@@ -149,6 +150,9 @@ InferenceEngine::ICNNNetwork::Ptr clDNNEngine::CloneAndTransformNetwork(const In
// WA: ConvertPriorBox must be executed before the 1st ConstantFolding pass
manager.register_pass<ngraph::pass::ConvertPriorBox>();
manager.register_pass<ngraph::pass::CommonOptimizations>();
manager.register_pass<ngraph::pass::ConvertRNNSequenceToTensorIterator>();
manager.register_pass<ngraph::pass::ConvertGRUSequenceToTensorIterator>();
manager.register_pass<ngraph::pass::ConvertLSTMSequenceToTensorIterator>();
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();


@@ -32,7 +32,6 @@
#include <transformations/common_optimizations/common_optimizations.hpp>
#include <transformations/common_optimizations/depth_to_space_fusion.hpp>
#include <transformations/control_flow/unroll_tensor_iterator.hpp>
#include <transformations/op_conversions/convert_depth_to_space.hpp>
#include <transformations/op_conversions/convert_space_to_depth.hpp>
#include <transformations/op_conversions/convert_gelu.hpp>
@@ -44,6 +43,8 @@
#include <transformations/op_conversions/softplus_decomposition.hpp>
#include <transformations/op_conversions/convert_space_to_batch.hpp>
#include <transformations/op_conversions/convert_batch_to_space.hpp>
#include <transformations/op_conversions/convert_sequences_to_tensor_iterator.hpp>
#include <transformations/control_flow/unroll_tensor_iterator.hpp>
#include <transformations/op_conversions/convert_mod.hpp>
#include <transformations/op_conversions/log_softmax_decomposition.hpp>
#include <transformations/convert_precision.hpp>
@@ -96,6 +97,9 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork, const Config& conf)
// WA: ConvertPriorBox must be executed before the 1st ConstantFolding pass
manager.register_pass<ngraph::pass::ConvertPriorBox>();
manager.register_pass<ngraph::pass::CommonOptimizations>();
manager.register_pass<ngraph::pass::ConvertRNNSequenceToTensorIterator>();
manager.register_pass<ngraph::pass::ConvertGRUSequenceToTensorIterator>();
manager.register_pass<ngraph::pass::ConvertLSTMSequenceToTensorIterator>();
manager.register_pass<ngraph::pass::ConvertOpSet3ToOpSet2>();
manager.register_pass<ngraph::pass::ConvertOpSet2ToOpSet1>();


@@ -0,0 +1,58 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API ConvertRNNSequenceToTensorIterator;
class TRANSFORMATIONS_API ConvertGRUSequenceToTensorIterator;
class TRANSFORMATIONS_API ConvertLSTMSequenceToTensorIterator;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief ConvertRNNSequenceToTensorIterator transformation converts RNNSequence layer to TensorIterator
*
*/
class ngraph::pass::ConvertRNNSequenceToTensorIterator: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertRNNSequenceToTensorIterator();
};
/**
* @ingroup ie_transformation_common_api
* @brief ConvertGRUSequenceToTensorIterator transformation converts GRUSequence layer to TensorIterator
*
*/
class ngraph::pass::ConvertGRUSequenceToTensorIterator: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertGRUSequenceToTensorIterator();
};
/**
* @ingroup ie_transformation_common_api
* @brief ConvertLSTMSequenceToTensorIterator transformation converts LSTMSequence layer to TensorIterator
*
*/
class ngraph::pass::ConvertLSTMSequenceToTensorIterator: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertLSTMSequenceToTensorIterator();
};


@@ -73,11 +73,11 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<ngraph::pass::HSigmoidFusion>();
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
manager.register_pass<ngraph::pass::NormalizeL2Fusion>();
manager.register_pass<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
manager.register_pass<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
manager.register_pass<ngraph::pass::BidirectionalGRUSequenceDecomposition>();
auto decomp = manager.register_pass<ngraph::pass::GraphRewrite>();
decomp->add_matcher<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
decomp->add_matcher<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
decomp->add_matcher<ngraph::pass::BidirectionalGRUSequenceDecomposition>();
decomp->add_matcher<ngraph::pass::ReduceL1Decomposition>();
decomp->add_matcher<ngraph::pass::ReduceL2Decomposition>();
decomp->add_matcher<ngraph::pass::HSwishDecomposition>();


@@ -0,0 +1,586 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/builder/autobroadcast.hpp"
#include "transformations/op_conversions/convert_sequences_to_tensor_iterator.hpp"
#include <memory>
#include <transformations/utils/utils.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/op/util/activation_functions.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertRNNSequenceToTensorIterator, "ConvertRNNSequenceToTensorIterator", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertGRUSequenceToTensorIterator, "ConvertGRUSequenceToTensorIterator", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertLSTMSequenceToTensorIterator, "ConvertLSTMSequenceToTensorIterator", 0);
namespace {
ngraph::Output<ngraph::Node> get_current_iter(ngraph::ParameterVector &body_params,
ngraph::ResultVector &body_results,
const ngraph::Output<ngraph::Node> &seq_lengths) {
auto curr_iter_body_param = std::make_shared<ngraph::opset5::Parameter>(seq_lengths.get_element_type(),
ngraph::Shape{1});
// increment current iteration
auto one = std::make_shared<ngraph::opset5::Constant>(seq_lengths.get_element_type(), ngraph::Shape{1},
std::vector<int64_t>{1});
auto add = std::make_shared<ngraph::opset5::Add>(curr_iter_body_param, one);
auto curr_iter_result = std::make_shared<ngraph::opset5::Result>(add);
body_params.push_back(curr_iter_body_param);
body_results.push_back(curr_iter_result);
return curr_iter_body_param;
}
ngraph::Output<ngraph::Node> get_masked_value(const std::shared_ptr<ngraph::opset5::TensorIterator> &ti,
ngraph::ParameterVector &body_params,
ngraph::ResultVector &body_results,
const ngraph::Output<ngraph::Node> &current_iter,
const ngraph::Output<ngraph::Node> &data,
const ngraph::Output<ngraph::Node> &seq_lengths) {
const auto &data_type = data.get_element_type();
const auto &data_shape = data.get_shape();
const auto &data_shape_size = ngraph::shape_size(data_shape);
// body parameters
auto aggregated_Y_h_body_param = std::make_shared<ngraph::opset5::Parameter>(data_type, data_shape);
body_params.push_back(aggregated_Y_h_body_param);
// Create mask node deciding whether or not to mask batch data.
ngraph::Output<ngraph::Node> batch_seq_length = ngraph::builder::opset1::legacy_broadcast_for_binary_operation(
data, seq_lengths, 0);
auto mask_value = std::make_shared<ngraph::opset5::Constant>(data_type, data_shape, std::vector<float>(data_shape_size, 0.f));
auto mask_condition = std::make_shared<ngraph::opset5::Greater>(current_iter, batch_seq_length);
auto mask_Y_h = std::make_shared<ngraph::opset5::Equal>(current_iter, batch_seq_length);
// Select values depending on mask.
// Select(<condition>, <true_value>, <false_value>)
auto select_aggregated_H = std::make_shared<ngraph::opset5::Select>(mask_Y_h, data, aggregated_Y_h_body_param);
auto aggregated_result = std::make_shared<ngraph::opset5::Result>(select_aggregated_H);
body_results.push_back(aggregated_result);
return std::make_shared<ngraph::opset5::Select>(mask_condition, mask_value, data);
}
ngraph::NodeVector squeeze_nodes(const ngraph::OutputVector &nodes_to_squeeze, const ngraph::OutputVector &axes) {
ngraph::NodeVector squeezed_nodes(nodes_to_squeeze.size());
for (size_t i = 0; i < nodes_to_squeeze.size(); ++i) {
squeezed_nodes[i] = std::make_shared<ngraph::opset5::Squeeze>(nodes_to_squeeze[i], axes[i]);
}
return squeezed_nodes;
}
bool should_enable_mask(const ngraph::Output<ngraph::Node> &seq_lengths, size_t max_seq_len) {
// disable the mask if all values of seq_lengths input are equal to max_seq_len (X_shape[1])
if (const auto &seq_len_const = std::dynamic_pointer_cast<ngraph::opset5::Constant>(
seq_lengths.get_node_shared_ptr())) {
const auto &seq_len_values = seq_len_const->cast_vector<int64_t>();
return std::any_of(seq_len_values.begin(), seq_len_values.end(), [max_seq_len](const int64_t val) {
return val != max_seq_len;
});
}
return true;
}
} // namespace
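The Select pair in get_masked_value() implements per-batch masking driven by an iteration counter. A scalar sketch of the intended semantics, assuming the counter starts at 1 (as the transformation initializes it below) and is compared against each batch element's sequence length; the two helper functions here are a hypothetical model, not part of the codebase:

// Hypothetical scalar model of get_masked_value(): past the end of a
// sequence the emitted state is zeroed, while the aggregated state
// latches the value produced at the last valid step.
float masked_output(int64_t curr_iter, int64_t seq_len, float h) {
    return curr_iter > seq_len ? 0.f : h;    // Greater -> mask_value
}
float aggregated_Y_h(int64_t curr_iter, int64_t seq_len, float h, float prev) {
    return curr_iter == seq_len ? h : prev;  // Equal -> latch last valid H
}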
ngraph::pass::ConvertRNNSequenceToTensorIterator::ConvertRNNSequenceToTensorIterator() {
// X, H, seq_lengths - static, W,R,B - any
auto rnn_seq = ngraph::pattern::wrap_type<opset5::RNNSequence>({pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(),
pattern::any_input(),
pattern::any_input()});
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher &m) {
auto sequence = std::dynamic_pointer_cast<ngraph::opset5::RNNSequence>(m.get_match_root());
// Bidirectional Sequence op should be decomposed to Reverse + Forward
// (e.g. apply BidirectionalRNNSequenceDecomposition transformation before this one)
if (!sequence || sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) {
return false;
}
NodeVector new_nodes;
const auto &X = sequence->input_value(0); // split
const auto &H_t = sequence->input_value(1); // merged (init value + back edge)
const auto &seq_lengths = sequence->input_value(2); // invariant
const auto &W = sequence->input_value(3); // const in the body
const auto &R = sequence->input_value(4); // const in the body
const auto &B = sequence->input_value(5); // const in the body
bool is_reverse = sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::REVERSE;
auto tensor_iterator = std::make_shared<opset5::TensorIterator>();
auto max_seq_len = X.get_shape().at(1);
bool enable_mask = should_enable_mask(seq_lengths, max_seq_len);
std::shared_ptr<Node> reverse_seq_before;
if (is_reverse && enable_mask) {
reverse_seq_before = std::make_shared<opset5::ReverseSequence>(X, seq_lengths, 0, 1);
}
// TensorIterator Body: begin
Shape X_param_shape = X.get_shape();
X_param_shape.at(1) = 1; // split by seq_lengths dimension
auto X_body_param = std::make_shared<opset5::Parameter>(X.get_element_type(), X_param_shape);
auto H_body_param = std::make_shared<opset5::Parameter>(H_t.get_element_type(),
H_t.get_shape());
auto seq_body_param = std::make_shared<opset5::Parameter>(seq_lengths.get_element_type(),
seq_lengths.get_partial_shape());
auto axis_0 = std::make_shared<opset5::Constant>(element::i64, Shape{1}, std::vector<int64_t>{0});
auto axis_1 = std::make_shared<opset5::Constant>(element::i64, Shape{1}, std::vector<int64_t>{1});
const auto& ins = squeeze_nodes({X_body_param, H_body_param, W, R, B}, {axis_1, axis_1, axis_0, axis_0, axis_0});
auto cell = std::make_shared<opset5::RNNCell>(ins[0],
ins[1],
ins[2],
ins[3],
ins[4],
sequence->get_hidden_size(),
sequence->get_activations(),
sequence->get_activations_alpha(),
sequence->get_activations_beta(),
sequence->get_clip());
ParameterVector body_params;
ResultVector body_results;
auto unsqueeze_dum_dir = std::make_shared<opset5::Unsqueeze>(cell->output(0), axis_1);
Output<Node> h_node_to_result = unsqueeze_dum_dir;
if (enable_mask) {
auto current_iter = get_current_iter(body_params, body_results, seq_body_param);
h_node_to_result = get_masked_value(tensor_iterator, body_params, body_results, current_iter,
unsqueeze_dum_dir, seq_body_param);
}
auto H_res = std::make_shared<opset5::Result>(h_node_to_result);
auto unsqueeze_seq_len = std::make_shared<opset5::Unsqueeze>(h_node_to_result, axis_1);
auto concat_res = std::make_shared<opset5::Result>(unsqueeze_seq_len);
body_params.push_back(X_body_param);
body_params.push_back(H_body_param);
body_params.push_back(seq_body_param);
body_results.push_back(concat_res);
body_results.push_back(H_res);
auto body = std::make_shared<ngraph::Function>(body_results, body_params);
tensor_iterator->set_function(body);
// TensorIterator Body: end
if (is_reverse) {
if (!enable_mask) {
// Reversed order, stride -1
tensor_iterator->set_sliced_input(X_body_param, X, -1, -1, 1, 0, 1);
tensor_iterator->get_concatenated_slices(concat_res, -1, -1, 1, 0, 2);
} else {
// use ReverseSequence as initializer
tensor_iterator->set_sliced_input(X_body_param, reverse_seq_before, 0, 1, 1, -1, 1);
tensor_iterator->get_concatenated_slices(concat_res, 0, 1, 1, -1, 2);
}
} else {
// forward order
tensor_iterator->set_sliced_input(X_body_param, X, 0, 1, 1, -1, 1);
tensor_iterator->get_concatenated_slices(concat_res, 0, 1, 1, -1, 2);
}
tensor_iterator->set_merged_input(H_body_param, H_t, H_res);
tensor_iterator->set_invariant_input(seq_body_param, seq_lengths);
Output<Node> H_out = H_res;
if (enable_mask) {
// create initial values for body_parameters in outer graph
// aggregated Y_h - concatenation of the last non-zero values for each batch
auto aggregated_Y_h = std::make_shared<ngraph::opset5::Constant>(H_body_param->get_element_type(),
H_body_param->get_shape(),
std::vector<float>(shape_size(H_body_param->get_shape()),
0.f));
auto init_val_curr_iter = std::make_shared<ngraph::opset5::Constant>(seq_lengths.get_element_type(),
ngraph::Shape{1},
std::vector<int64_t>{1});
ngraph::copy_runtime_info(sequence, {aggregated_Y_h, init_val_curr_iter});
// set initial value and back edge for current iteration
tensor_iterator->set_merged_input(body_params.at(0), init_val_curr_iter, body_results.at(0));
// set initial value and back edge for aggregated H
tensor_iterator->set_merged_input(body_params.at(1), aggregated_Y_h, body_results.at(1));
H_out = tensor_iterator->get_function()->get_results()[1];
}
tensor_iterator->get_iter_value(H_out);
tensor_iterator->set_friendly_name(sequence->get_friendly_name());
if (enable_mask && is_reverse) {
auto reverse_seq_after = std::make_shared<opset5::ReverseSequence>(tensor_iterator->output(0), seq_lengths, 0, 2);
// Resolve name collisions between data nodes in the CNN Network in the reverse-with-mask case.
/*
* Before transformation (no collisions):
* RNN/LSTM/GRU Sequence [rnn_name] -- (data_node: rnn_name.0) -> Result1
*                                  -- (data_node: rnn_name.1) -> Result2
*
* After transformation (without identity, there are collisions):
* We need to assign rnn_name.0 to RevSequence to preserve the result name.
* TI [rnn_name] -- (data_node: rnn_name.0) --> RevSequence [rnn_name.0] -- (data_node: rnn_name.0) -> Result1
*               -- (data_node: rnn_name.1) --> Result2
*
* After transformation (with identity, no collisions):
* TI gets other_name, which doesn't affect the result names since the TI is not connected to the Results directly.
* TI [other_name] -- (data_node: other_name.0) --> RevSequence [rnn_name.0] -- (data_node: rnn_name.0) -> Result1
*                 -- (data_node: other_name.1) --> Identity(rnn_name.1) -- (data_node: rnn_name.1) -> Result2
*/
auto identity_1 = std::make_shared<opset5::Unsqueeze>(tensor_iterator->output(1), axis_1);
auto identity_2 = std::make_shared<opset5::Squeeze>(identity_1, axis_1);
ngraph::copy_runtime_info(sequence, {reverse_seq_after, tensor_iterator, identity_1, identity_2, reverse_seq_before});
ngraph::replace_node(sequence, {reverse_seq_after, identity_2});
tensor_iterator->set_friendly_name(sequence->get_friendly_name() + "/tensor_iterator");
reverse_seq_after->set_friendly_name(sequence->get_friendly_name() + ".0");
identity_2->set_friendly_name(sequence->get_friendly_name() + ".1");
} else {
ngraph::copy_runtime_info(sequence, tensor_iterator);
ngraph::replace_node(sequence, tensor_iterator);
}
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(rnn_seq, "ConvertRNNSequenceToTensorIterator");
register_matcher(m, callback);
}
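For readers of the slicing calls above: the positional arguments after the two tensors are start, stride, part_size, end and axis, with -1 denoting the last element along the axis. A tiny standalone sketch of the resulting iteration order under that interpretation (sequence axis of length 5; purely illustrative):

#include <cstdio>

int main() {
    const int len = 5;
    printf("forward (start=0, stride=1, end=-1):");
    for (int i = 0; i < len; ++i) printf(" %d", i);       // 0 1 2 3 4
    printf("\nreverse (start=-1, stride=-1, end=0):");
    for (int i = len - 1; i >= 0; --i) printf(" %d", i);  // 4 3 2 1 0
    printf("\n");
    return 0;
}

get_concatenated_slices() follows the same convention on axis 2, i.e. the per-iteration outputs are stitched back along the sequence axis of Y.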
ngraph::pass::ConvertGRUSequenceToTensorIterator::ConvertGRUSequenceToTensorIterator() {
// X, H, seq_lengths - static, W,R,B - any
auto rnn_seq = ngraph::pattern::wrap_type<opset5::GRUSequence>({pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(),
pattern::any_input(),
pattern::any_input()});
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher &m) {
auto sequence = std::dynamic_pointer_cast<ngraph::opset5::GRUSequence>(m.get_match_root());
// Bidirectional Sequence op should be decomposed to Reverse + Forward
// (e.g. apply BidirectionalRNNSequenceDecomposition transformation before this one)
if (!sequence || sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) {
return false;
}
NodeVector new_nodes;
const auto &X = sequence->input_value(0); // split
const auto &H_t = sequence->input_value(1); // merged (init value + back edge)
const auto &seq_lengths = sequence->input_value(2); // invariant
const auto &W = sequence->input_value(3); // const in the body
const auto &R = sequence->input_value(4); // const in the body
const auto &B = sequence->input_value(5); // const in the body
bool is_reverse = sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::REVERSE;
auto tensor_iterator = std::make_shared<opset5::TensorIterator>();
auto max_seq_len = X.get_shape().at(1);
bool enable_mask = should_enable_mask(seq_lengths, max_seq_len);
std::shared_ptr<Node> reverse_seq_before;
if (is_reverse && enable_mask) {
reverse_seq_before = std::make_shared<opset5::ReverseSequence>(X, seq_lengths, 0, 1);
}
// TensorIterator Body: begin
Shape X_param_shape = X.get_shape();
X_param_shape.at(1) = 1; // split by seq_lengths dimension
auto X_body_param = std::make_shared<opset5::Parameter>(X.get_element_type(), X_param_shape);
auto H_body_param = std::make_shared<opset5::Parameter>(H_t.get_element_type(),
H_t.get_shape());
auto seq_body_param = std::make_shared<opset5::Parameter>(seq_lengths.get_element_type(),
seq_lengths.get_partial_shape());
auto axis_0 = std::make_shared<opset5::Constant>(element::i64, Shape{1}, std::vector<int64_t>{0});
auto axis_1 = std::make_shared<opset5::Constant>(element::i64, Shape{1}, std::vector<int64_t>{1});
const auto& ins = squeeze_nodes({X_body_param, H_body_param, W, R, B}, {axis_1, axis_1, axis_0, axis_0, axis_0});
auto cell = std::make_shared<opset5::GRUCell>(ins[0],
ins[1],
ins[2],
ins[3],
ins[4],
sequence->get_hidden_size(),
sequence->get_activations(),
sequence->get_activations_alpha(),
sequence->get_activations_beta(),
sequence->get_clip(),
sequence->get_linear_before_reset());
ParameterVector body_params;
ResultVector body_results;
auto unsqueeze_dum_dir = std::make_shared<opset5::Unsqueeze>(cell->output(0), axis_1);
Output<Node> h_node_to_result = unsqueeze_dum_dir;
if (enable_mask) {
auto current_iter = get_current_iter(body_params, body_results, seq_body_param);
h_node_to_result = get_masked_value(tensor_iterator, body_params, body_results, current_iter,
unsqueeze_dum_dir, seq_body_param);
}
auto H_res = std::make_shared<opset5::Result>(h_node_to_result);
auto unsqueeze_seq_len = std::make_shared<opset5::Unsqueeze>(h_node_to_result, axis_1);
auto concat_res = std::make_shared<opset5::Result>(unsqueeze_seq_len);
body_params.push_back(X_body_param);
body_params.push_back(H_body_param);
body_params.push_back(seq_body_param);
body_results.push_back(concat_res);
body_results.push_back(H_res);
auto body = std::make_shared<ngraph::Function>(body_results, body_params);
tensor_iterator->set_function(body);
// TensorIterator Body: end
if (is_reverse) {
if (!enable_mask) {
// Reversed order, stride -1
tensor_iterator->set_sliced_input(X_body_param, X, -1, -1, 1, 0, 1);
tensor_iterator->get_concatenated_slices(concat_res, -1, -1, 1, 0, 2);
} else {
// use ReverseSequence as initializer
tensor_iterator->set_sliced_input(X_body_param, reverse_seq_before, 0, 1, 1, -1, 1);
tensor_iterator->get_concatenated_slices(concat_res, 0, 1, 1, -1, 2);
}
} else {
// forward order
tensor_iterator->set_sliced_input(X_body_param, X, 0, 1, 1, -1, 1);
tensor_iterator->get_concatenated_slices(concat_res, 0, 1, 1, -1, 2);
}
tensor_iterator->set_merged_input(H_body_param, H_t, H_res);
tensor_iterator->set_invariant_input(seq_body_param, seq_lengths);
Output<Node> H_out = H_res;
if (enable_mask) {
// create initial values for body_parameters in outer graph
// aggregated Y_h - concatenation of the last non-zero values for each batch
auto aggregated_Y_h = std::make_shared<ngraph::opset5::Constant>(H_body_param->get_element_type(),
H_body_param->get_shape(),
std::vector<float>(shape_size(H_body_param->get_shape()),
0.f));
auto init_val_curr_iter = std::make_shared<ngraph::opset5::Constant>(seq_lengths.get_element_type(),
ngraph::Shape{1},
std::vector<int64_t>{1});
ngraph::copy_runtime_info(sequence, {aggregated_Y_h, init_val_curr_iter});
// set initial value and back edge for current iteration
tensor_iterator->set_merged_input(body_params.at(0), init_val_curr_iter, body_results.at(0));
// set initial value and back edge for aggregated H
tensor_iterator->set_merged_input(body_params.at(1), aggregated_Y_h, body_results.at(1));
H_out = tensor_iterator->get_function()->get_results()[1];
}
tensor_iterator->get_iter_value(H_out);
tensor_iterator->set_friendly_name(sequence->get_friendly_name());
if (enable_mask && is_reverse) {
auto reverse_seq_after = std::make_shared<opset5::ReverseSequence>(tensor_iterator->output(0), seq_lengths, 0, 2);
// Resolve name collisions between data nodes in the CNN Network in the reverse-with-mask case.
/*
* Before transformation (no collisions):
* RNN/LSTM/GRU Sequence [rnn_name] -- (data_node: rnn_name.0) -> Result1
*                                  -- (data_node: rnn_name.1) -> Result2
*
* After transformation (without identity, there are collisions):
* We need to assign rnn_name.0 to RevSequence to preserve the result name.
* TI [rnn_name] -- (data_node: rnn_name.0) --> RevSequence [rnn_name.0] -- (data_node: rnn_name.0) -> Result1
*               -- (data_node: rnn_name.1) --> Result2
*
* After transformation (with identity, no collisions):
* TI gets other_name, which doesn't affect the result names since the TI is not connected to the Results directly.
* TI [other_name] -- (data_node: other_name.0) --> RevSequence [rnn_name.0] -- (data_node: rnn_name.0) -> Result1
*                 -- (data_node: other_name.1) --> Identity(rnn_name.1) -- (data_node: rnn_name.1) -> Result2
*/
auto identity_1 = std::make_shared<opset5::Unsqueeze>(tensor_iterator->output(1), axis_1);
auto identity_2 = std::make_shared<opset5::Squeeze>(identity_1, axis_1);
ngraph::copy_runtime_info(sequence, {reverse_seq_after, tensor_iterator, reverse_seq_before, identity_2, identity_1});
ngraph::replace_node(sequence, {reverse_seq_after, identity_2});
tensor_iterator->set_friendly_name(sequence->get_friendly_name() + "/tensor_iterator");
reverse_seq_after->set_friendly_name(sequence->get_friendly_name() + ".0");
identity_2->set_friendly_name(sequence->get_friendly_name() + ".1");
} else {
ngraph::copy_runtime_info(sequence, tensor_iterator);
ngraph::replace_node(sequence, tensor_iterator);
}
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(rnn_seq, "ConvertGRUSequenceToTensorIterator");
register_matcher(m, callback);
}
ngraph::pass::ConvertLSTMSequenceToTensorIterator::ConvertLSTMSequenceToTensorIterator() {
// X, H, C, seq_lengths - static, W,R,B - any
auto rnn_seq = ngraph::pattern::wrap_type<opset5::LSTMSequence>({pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(),
pattern::any_input(),
pattern::any_input()});
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher &m) {
auto sequence = std::dynamic_pointer_cast<ngraph::opset5::LSTMSequence>(m.get_match_root());
// Bidirectional Sequence op should be decomposed to Reverse + Forward
// (e.g. apply BidirectionalRNNSequenceDecomposition transformation before this one)
if (!sequence || sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) {
return false;
}
NodeVector new_nodes;
const auto &X = sequence->input_value(0); // split
const auto &H_t = sequence->input_value(1); // merged (init value + back edge)
const auto &C_t = sequence->input_value(2); // merged (init value + back edge)
const auto &seq_lengths = sequence->input_value(3); // invariant
const auto &W = sequence->input_value(4); // const in the body
const auto &R = sequence->input_value(5); // const in the body
const auto &B = sequence->input_value(6); // const in the body
bool is_reverse = sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::REVERSE;
auto tensor_iterator = std::make_shared<opset5::TensorIterator>();
auto max_seq_len = X.get_shape().at(1);
bool enable_mask = should_enable_mask(seq_lengths, max_seq_len);
std::shared_ptr<Node> reverse_seq_before;
if (is_reverse && enable_mask) {
reverse_seq_before = std::make_shared<opset5::ReverseSequence>(X, seq_lengths, 0, 1);
}
// TensorIterator Body: begin
Shape X_param_shape = X.get_shape();
X_param_shape.at(1) = 1; // split by seq_lengths dimension
auto X_body_param = std::make_shared<opset5::Parameter>(X.get_element_type(), X_param_shape);
auto H_body_param = std::make_shared<opset5::Parameter>(H_t.get_element_type(),
H_t.get_shape());
auto C_body_param = std::make_shared<opset5::Parameter>(C_t.get_element_type(),
C_t.get_partial_shape());
auto seq_body_param = std::make_shared<opset5::Parameter>(seq_lengths.get_element_type(),
seq_lengths.get_partial_shape());
auto axis_0 = std::make_shared<opset5::Constant>(element::i64, Shape{1}, std::vector<int64_t>{0});
auto axis_1 = std::make_shared<opset5::Constant>(element::i64, Shape{1}, std::vector<int64_t>{1});
const auto& ins = squeeze_nodes({X_body_param, H_body_param, C_body_param, W, R, B},
{axis_1, axis_1, axis_1, axis_0, axis_0, axis_0});
auto cell = std::make_shared<opset5::LSTMCell>(ins[0],
ins[1],
ins[2],
ins[3],
ins[4],
ins[5],
sequence->get_hidden_size(),
sequence->get_activations(),
sequence->get_activations_alpha(),
sequence->get_activations_beta(),
sequence->get_clip());
ParameterVector body_params;
ResultVector body_results;
auto unsqueeze_dum_dir_h = std::make_shared<opset5::Unsqueeze>(cell->output(0), axis_1);
auto unsqueeze_dum_dir_c = std::make_shared<opset5::Unsqueeze>(cell->output(1), axis_1);
Output<Node> h_node_to_result = unsqueeze_dum_dir_h;
Output<Node> c_node_to_result = unsqueeze_dum_dir_c;
if (enable_mask) {
auto current_iter = get_current_iter(body_params, body_results, seq_body_param);
h_node_to_result = get_masked_value(tensor_iterator, body_params, body_results, current_iter,
unsqueeze_dum_dir_h, seq_body_param);
c_node_to_result = get_masked_value(tensor_iterator, body_params, body_results, current_iter,
unsqueeze_dum_dir_c, seq_body_param);
}
auto H_res = std::make_shared<opset5::Result>(h_node_to_result);
auto C_res = std::make_shared<opset5::Result>(c_node_to_result);
auto unsqueeze_seq_len = std::make_shared<opset5::Unsqueeze>(h_node_to_result, axis_1);
auto concat_res = std::make_shared<opset5::Result>(unsqueeze_seq_len);
body_params.push_back(X_body_param);
body_params.push_back(H_body_param);
body_params.push_back(C_body_param);
body_params.push_back(seq_body_param);
body_results.push_back(concat_res);
body_results.push_back(H_res);
body_results.push_back(C_res);
auto body = std::make_shared<ngraph::Function>(body_results, body_params);
tensor_iterator->set_function(body);
// TensorIterator Body: end
if (is_reverse) {
if (!enable_mask) {
// Reversed order, stride -1
tensor_iterator->set_sliced_input(X_body_param, X, -1, -1, 1, 0, 1);
tensor_iterator->get_concatenated_slices(concat_res, -1, -1, 1, 0, 2);
} else {
// use ReverseSequence as initializer
tensor_iterator->set_sliced_input(X_body_param, reverse_seq_before, 0, 1, 1, -1, 1);
tensor_iterator->get_concatenated_slices(concat_res, 0, 1, 1, -1, 2);
}
} else {
// forward order
tensor_iterator->set_sliced_input(X_body_param, X, 0, 1, 1, -1, 1);
tensor_iterator->get_concatenated_slices(concat_res, 0, 1, 1, -1, 2);
}
tensor_iterator->set_merged_input(H_body_param, H_t, H_res);
tensor_iterator->set_merged_input(C_body_param, C_t, C_res);
tensor_iterator->set_invariant_input(seq_body_param, seq_lengths);
Output<Node> H_out = H_res;
Output<Node> C_out = C_res;
if (enable_mask) {
// create initial values for body_parameters in outer graph
// aggregated Y_h - concatenation of the last non-zero values for each batch
auto aggregated_Y_h = std::make_shared<ngraph::opset5::Constant>(H_body_param->get_element_type(),
H_body_param->get_shape(),
std::vector<float>(shape_size(H_body_param->get_shape()),
0.f));
auto aggregated_Y_c = std::make_shared<ngraph::opset5::Constant>(C_body_param->get_element_type(),
C_body_param->get_shape(),
std::vector<float>(shape_size(C_body_param->get_shape()),
0.f));
auto init_val_curr_iter = std::make_shared<ngraph::opset5::Constant>(seq_lengths.get_element_type(),
ngraph::Shape{1},
std::vector<int64_t>{1});
ngraph::copy_runtime_info(sequence, {aggregated_Y_h, init_val_curr_iter, aggregated_Y_c});
// set initial value and back edge for current iteration
tensor_iterator->set_merged_input(body_params.at(0), init_val_curr_iter, body_results.at(0));
// set initial value and back edge for aggregated H
tensor_iterator->set_merged_input(body_params.at(1), aggregated_Y_h, body_results.at(1));
// set initial value and back edge for aggregated C
tensor_iterator->set_merged_input(body_params.at(2), aggregated_Y_c, body_results.at(2));
H_out = tensor_iterator->get_function()->get_results()[1];
C_out = tensor_iterator->get_function()->get_results()[2];
}
tensor_iterator->get_iter_value(H_out);
tensor_iterator->get_iter_value(C_out);
tensor_iterator->set_friendly_name(sequence->get_friendly_name());
if (enable_mask && is_reverse) {
auto reverse_seq_after = std::make_shared<opset5::ReverseSequence>(tensor_iterator->output(0), seq_lengths, 0, 2);
// Resolve name collisions between data nodes in the CNN Network in the reverse-with-mask case.
/*
* Before transformation (no collisions):
* RNN/LSTM/GRU Sequence [rnn_name] -- (data_node: rnn_name.0) -> Result1
*                                  -- (data_node: rnn_name.1) -> Result2
*
* After transformation (without identity, there are collisions):
* We need to assign rnn_name.0 to RevSequence to preserve the result name.
* TI [rnn_name] -- (data_node: rnn_name.0) --> RevSequence [rnn_name.0] -- (data_node: rnn_name.0) -> Result1
*               -- (data_node: rnn_name.1) --> Result2
*
* After transformation (with identity, no collisions):
* TI gets other_name, which doesn't affect the result names since the TI is not connected to the Results directly.
* TI [other_name] -- (data_node: other_name.0) --> RevSequence [rnn_name.0] -- (data_node: rnn_name.0) -> Result1
*                 -- (data_node: other_name.1) --> Identity(rnn_name.1) -- (data_node: rnn_name.1) -> Result2
*/
auto identity_1_h = std::make_shared<opset5::Unsqueeze>(tensor_iterator->output(1), axis_1);
auto identity_2_h = std::make_shared<opset5::Squeeze>(identity_1_h, axis_1);
auto identity_1_c = std::make_shared<opset5::Unsqueeze>(tensor_iterator->output(2), axis_1);
auto identity_2_c = std::make_shared<opset5::Squeeze>(identity_1_c, axis_1);
ngraph::copy_runtime_info(sequence, {reverse_seq_after, tensor_iterator, reverse_seq_before, identity_2_c, identity_1_c,
identity_1_h, identity_2_h});
ngraph::replace_node(sequence, {reverse_seq_after, identity_2_h, identity_2_c});
tensor_iterator->set_friendly_name(sequence->get_friendly_name() + "/tensor_iterator");
reverse_seq_after->set_friendly_name(sequence->get_friendly_name() + ".0");
identity_2_h->set_friendly_name(sequence->get_friendly_name() + ".1");
identity_2_c->set_friendly_name(sequence->get_friendly_name() + ".2");
} else {
ngraph::copy_runtime_info(sequence, tensor_iterator);
ngraph::replace_node(sequence, tensor_iterator);
}
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(rnn_seq, "ConvertLSTMSequenceToTensorIterator");
register_matcher(m, callback);
}
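The functional tests further down verify these conversions by running the pass and asserting that a TensorIterator appears in the graph. Condensed from their SetUp() (assumes `function` is built as in those tests; shown for the LSTM variant):

ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();  // only needed for BIDIRECTIONAL
manager.register_pass<ngraph::pass::ConvertLSTMSequenceToTensorIterator>();
manager.run_passes(function);
bool ti_found = ngraph::helpers::is_tensor_iterator_exist(function);
EXPECT_EQ(ti_found, true);  // the sequence op must now be a TensorIterator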


@@ -0,0 +1,470 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <string>
#include "ngraph_reader_tests.hpp"
#include "common_test_utils/data_utils.hpp"
TEST_F(NGraphReaderTests, LSTMSeqNetwork) {
std::string model = R"V0G0N(
<net name="LSTMSeqNetwork" version="10">
<layers>
<layer id="0" name="0" type="Parameter" version="opset1">
<data element_type="f32" shape="10,3,512"/>
<output>
<port id="0" precision="FP32">
<dim>10</dim>
<dim>3</dim>
<dim>512</dim>
</port>
</output>
</layer>
<layer id="1" name="1" type="Parameter" version="opset1">
<data element_type="f32" shape="10,1,256"/>
<output>
<port id="0" precision="FP32">
<dim>10</dim>
<dim>1</dim>
<dim>256</dim>
</port>
</output>
</layer>
<layer id="2" name="2" type="Parameter" version="opset1">
<data element_type="f32" shape="10,1,256"/>
<output>
<port id="0" precision="FP32">
<dim>10</dim>
<dim>1</dim>
<dim>256</dim>
</port>
</output>
</layer>
<layer id="3" name="3" type="Parameter" version="opset1">
<data element_type="f32" shape="10"/>
<output>
<port id="0" precision="FP32">
<dim>10</dim>
</port>
</output>
</layer>
<layer id="4" name="4" type="Parameter" version="opset1">
<data element_type="f32" shape="1,1024,512"/>
<output>
<port id="0" precision="FP32">
<dim>1</dim>
<dim>1024</dim>
<dim>512</dim>
</port>
</output>
</layer>
<layer id="5" name="5" type="Parameter" version="opset1">
<data element_type="f32" shape="1,1024,256"/>
<output>
<port id="0" precision="FP32">
<dim>1</dim>
<dim>1024</dim>
<dim>256</dim>
</port>
</output>
</layer>
<layer id="6" name="6" type="Parameter" version="opset1">
<data element_type="f32" shape="1,1024"/>
<output>
<port id="0" precision="FP32">
<dim>1</dim>
<dim>1024</dim>
</port>
</output>
</layer>
<layer id="7" name="layer/LSTMSequence" type="LSTMSequence" version="opset5">
<data hidden_size="256"/>
<input>
<port id="0">
<dim>10</dim>
<dim>3</dim>
<dim>512</dim>
</port>
<port id="1">
<dim>10</dim>
<dim>1</dim>
<dim>256</dim>
</port>
<port id="2">
<dim>10</dim>
<dim>1</dim>
<dim>256</dim>
</port>
<port id="3">
<dim>10</dim>
</port>
<port id="4">
<dim>1</dim>
<dim>1024</dim>
<dim>512</dim>
</port>
<port id="5">
<dim>1</dim>
<dim>1024</dim>
<dim>256</dim>
</port>
<port id="6">
<dim>1</dim>
<dim>1024</dim>
</port>
</input>
<output>
<port id="7" precision="FP32">
<dim>10</dim>
<dim>1</dim>
<dim>3</dim>
<dim>256</dim>
</port>
<port id="8" precision="FP32">
<dim>10</dim>
<dim>1</dim>
<dim>256</dim>
</port>
<port id="9" precision="FP32">
<dim>10</dim>
<dim>1</dim>
<dim>256</dim>
</port>
</output>
</layer>
<layer id="8" name="8" type="Result" version="opset1">
<input>
<port id="0" precision="FP32">
<dim>10</dim>
<dim>1</dim>
<dim>3</dim>
<dim>256</dim>
</port>
</input>
</layer>
<layer id="9" name="9" type="Result" version="opset1">
<input>
<port id="0" precision="FP32">
<dim>10</dim>
<dim>1</dim>
<dim>256</dim>
</port>
</input>
</layer>
<layer id="10" name="10" type="Result" version="opset1">
<input>
<port id="0" precision="FP32">
<dim>10</dim>
<dim>1</dim>
<dim>256</dim>
</port>
</input>
</layer>
</layers>
<edges>
<edge from-layer="0" from-port="0" to-layer="7" to-port="0"/>
<edge from-layer="1" from-port="0" to-layer="7" to-port="1"/>
<edge from-layer="2" from-port="0" to-layer="7" to-port="2"/>
<edge from-layer="3" from-port="0" to-layer="7" to-port="3"/>
<edge from-layer="4" from-port="0" to-layer="7" to-port="4"/>
<edge from-layer="5" from-port="0" to-layer="7" to-port="5"/>
<edge from-layer="6" from-port="0" to-layer="7" to-port="6"/>
<edge from-layer="7" from-port="7" to-layer="8" to-port="0"/>
<edge from-layer="7" from-port="8" to-layer="9" to-port="0"/>
<edge from-layer="7" from-port="9" to-layer="10" to-port="0"/>
</edges>
</net>
)V0G0N";
Blob::CPtr blob;
Core reader;
reader.ReadNetwork(model, blob);
}
TEST_F(NGraphReaderTests, GRUSeqNetwork) {
std::string model = R"V0G0N(
<net name="GRUSeqNetwork" version="10">
<layers>
<layer id="0" name="0" type="Parameter" version="opset1">
<data element_type="f32" shape="10,3,512"/>
<output>
<port id="0" precision="FP32">
<dim>10</dim>
<dim>3</dim>
<dim>512</dim>
</port>
</output>
</layer>
<layer id="1" name="1" type="Parameter" version="opset1">
<data element_type="f32" shape="10,1,256"/>
<output>
<port id="0" precision="FP32">
<dim>10</dim>
<dim>1</dim>
<dim>256</dim>
</port>
</output>
</layer>
<layer id="3" name="3" type="Parameter" version="opset1">
<data element_type="f32" shape="10"/>
<output>
<port id="0" precision="FP32">
<dim>10</dim>
</port>
</output>
</layer>
<layer id="4" name="4" type="Parameter" version="opset1">
<data element_type="f32" shape="1,768,512"/>
<output>
<port id="0" precision="FP32">
<dim>1</dim>
<dim>768</dim>
<dim>512</dim>
</port>
</output>
</layer>
<layer id="5" name="5" type="Parameter" version="opset1">
<data element_type="f32" shape="1,768,256"/>
<output>
<port id="0" precision="FP32">
<dim>1</dim>
<dim>768</dim>
<dim>256</dim>
</port>
</output>
</layer>
<layer id="6" name="6" type="Parameter" version="opset1">
<data element_type="f32" shape="1,768"/>
<output>
<port id="0" precision="FP32">
<dim>1</dim>
<dim>768</dim>
</port>
</output>
</layer>
<layer id="7" name="layer/LSTMSequence" type="GRUSequence" version="opset5">
<data hidden_size="256"/>
<input>
<port id="0">
<dim>10</dim>
<dim>3</dim>
<dim>512</dim>
</port>
<port id="1">
<dim>10</dim>
<dim>1</dim>
<dim>256</dim>
</port>
<port id="3">
<dim>10</dim>
</port>
<port id="4">
<dim>1</dim>
<dim>768</dim>
<dim>512</dim>
</port>
<port id="5">
<dim>1</dim>
<dim>768</dim>
<dim>256</dim>
</port>
<port id="6">
<dim>1</dim>
<dim>768</dim>
</port>
</input>
<output>
<port id="7" precision="FP32">
<dim>10</dim>
<dim>1</dim>
<dim>3</dim>
<dim>256</dim>
</port>
<port id="8" precision="FP32">
<dim>10</dim>
<dim>1</dim>
<dim>256</dim>
</port>
</output>
</layer>
<layer id="8" name="8" type="Result" version="opset1">
<input>
<port id="0" precision="FP32">
<dim>10</dim>
<dim>1</dim>
<dim>3</dim>
<dim>256</dim>
</port>
</input>
</layer>
<layer id="9" name="9" type="Result" version="opset1">
<input>
<port id="0" precision="FP32">
<dim>10</dim>
<dim>1</dim>
<dim>256</dim>
</port>
</input>
</layer>
</layers>
<edges>
<edge from-layer="0" from-port="0" to-layer="7" to-port="0"/>
<edge from-layer="1" from-port="0" to-layer="7" to-port="1"/>
<edge from-layer="3" from-port="0" to-layer="7" to-port="3"/>
<edge from-layer="4" from-port="0" to-layer="7" to-port="4"/>
<edge from-layer="5" from-port="0" to-layer="7" to-port="5"/>
<edge from-layer="6" from-port="0" to-layer="7" to-port="6"/>
<edge from-layer="7" from-port="7" to-layer="8" to-port="0"/>
<edge from-layer="7" from-port="8" to-layer="9" to-port="0"/>
</edges>
</net>
)V0G0N";
Blob::CPtr blob;
Core reader;
reader.ReadNetwork(model, blob);
}
TEST_F(NGraphReaderTests, RNNSeqNetwork) {
std::string model = R"V0G0N(
<net name="RNNSeqNetwork" version="10">
<layers>
<layer id="0" name="0" type="Parameter" version="opset1">
<data element_type="f32" shape="10,3,512"/>
<output>
<port id="0" precision="FP32">
<dim>10</dim>
<dim>3</dim>
<dim>512</dim>
</port>
</output>
</layer>
<layer id="1" name="1" type="Parameter" version="opset1">
<data element_type="f32" shape="10,1,256"/>
<output>
<port id="0" precision="FP32">
<dim>10</dim>
<dim>1</dim>
<dim>256</dim>
</port>
</output>
</layer>
<layer id="3" name="3" type="Parameter" version="opset1">
<data element_type="f32" shape="10"/>
<output>
<port id="0" precision="FP32">
<dim>10</dim>
</port>
</output>
</layer>
<layer id="4" name="4" type="Parameter" version="opset1">
<data element_type="f32" shape="1,256,512"/>
<output>
<port id="0" precision="FP32">
<dim>1</dim>
<dim>256</dim>
<dim>512</dim>
</port>
</output>
</layer>
<layer id="5" name="5" type="Parameter" version="opset1">
<data element_type="f32" shape="1,256,256"/>
<output>
<port id="0" precision="FP32">
<dim>1</dim>
<dim>256</dim>
<dim>256</dim>
</port>
</output>
</layer>
<layer id="6" name="6" type="Parameter" version="opset1">
<data element_type="f32" shape="1,256"/>
<output>
<port id="0" precision="FP32">
<dim>1</dim>
<dim>256</dim>
</port>
</output>
</layer>
<layer id="7" name="layer/LSTMSequence" type="RNNSequence" version="opset5">
<data hidden_size="256"/>
<input>
<port id="0">
<dim>10</dim>
<dim>3</dim>
<dim>512</dim>
</port>
<port id="1">
<dim>10</dim>
<dim>1</dim>
<dim>256</dim>
</port>
<port id="3">
<dim>10</dim>
</port>
<port id="4">
<dim>1</dim>
<dim>256</dim>
<dim>512</dim>
</port>
<port id="5">
<dim>1</dim>
<dim>256</dim>
<dim>256</dim>
</port>
<port id="6">
<dim>1</dim>
<dim>256</dim>
</port>
</input>
<output>
<port id="7" precision="FP32">
<dim>10</dim>
<dim>1</dim>
<dim>3</dim>
<dim>256</dim>
</port>
<port id="8" precision="FP32">
<dim>10</dim>
<dim>1</dim>
<dim>256</dim>
</port>
</output>
</layer>
<layer id="8" name="8" type="Result" version="opset1">
<input>
<port id="0" precision="FP32">
<dim>10</dim>
<dim>1</dim>
<dim>3</dim>
<dim>256</dim>
</port>
</input>
</layer>
<layer id="9" name="9" type="Result" version="opset1">
<input>
<port id="0" precision="FP32">
<dim>10</dim>
<dim>1</dim>
<dim>256</dim>
</port>
</input>
</layer>
</layers>
<edges>
<edge from-layer="0" from-port="0" to-layer="7" to-port="0"/>
<edge from-layer="1" from-port="0" to-layer="7" to-port="1"/>
<edge from-layer="3" from-port="0" to-layer="7" to-port="3"/>
<edge from-layer="4" from-port="0" to-layer="7" to-port="4"/>
<edge from-layer="5" from-port="0" to-layer="7" to-port="5"/>
<edge from-layer="6" from-port="0" to-layer="7" to-port="6"/>
<edge from-layer="7" from-port="7" to-layer="8" to-port="0"/>
<edge from-layer="7" from-port="8" to-layer="9" to-port="0"/>
</edges>
</net>
)V0G0N";
Blob::CPtr blob;
Core reader;
reader.ReadNetwork(model, blob);
}
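The W/R/B shapes in the three IRs above follow from the per-cell gate counts with hidden_size = 256: an LSTM cell has 4 gates, a GRU cell has 3, and a vanilla RNN cell has 1, so the gate dimension of W is gates * hidden_size. A quick sanity check:

#include <cassert>

int main() {
    const int hidden_size = 256;
    assert(4 * hidden_size == 1024);  // LSTM: W is 1 x 1024 x 512 above
    assert(3 * hidden_size == 768);   // GRU:  W is 1 x 768  x 512 above
    assert(1 * hidden_size == 256);   // RNN:  W is 1 x 256  x 512 above
    return 0;
}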


@@ -10,12 +10,16 @@
using namespace LayerTestsDefinitions;
namespace {
std::vector<ngraph::helpers::SequenceTestsMode> mode{ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM,
ngraph::helpers::SequenceTestsMode::PURE_SEQ};
// output values increase rapidly without clip, so use only seq_lengths = 2
std::vector<size_t> seq_lengths_zero_clip{2};
std::vector<size_t> seq_lengths_clip_non_zero{20};
std::vector<size_t> batch{1, 10};
std::vector<size_t> batch{10};
std::vector<size_t> hidden_size{1, 10};
std::vector<size_t> input_size{10};
// std::vector<size_t> input_size{10};
std::vector<std::vector<std::string>> activations = {{"relu", "tanh"}, {"tanh", "sigmoid"}, {"sigmoid", "tanh"},
{"tanh", "relu"}};
std::vector<bool> linear_before_reset = {true, false};
@@ -30,10 +34,11 @@ namespace {
INSTANTIATE_TEST_CASE_P(smoke_GRUSequenceCommonZeroClip, GRUSequenceTest,
::testing::Combine(
::testing::ValuesIn(mode),
::testing::ValuesIn(seq_lengths_zero_clip),
::testing::ValuesIn(batch),
::testing::ValuesIn(hidden_size),
::testing::ValuesIn(input_size),
// ::testing::ValuesIn(input_size), // hardcoded to 10 because testing::Combine supports at most 10 arguments
::testing::ValuesIn(activations),
::testing::ValuesIn(clip),
::testing::ValuesIn(linear_before_reset),
@@ -44,10 +49,11 @@ namespace {
INSTANTIATE_TEST_CASE_P(smoke_GRUSequenceCommonClip, GRUSequenceTest,
::testing::Combine(
::testing::ValuesIn(mode),
::testing::ValuesIn(seq_lengths_clip_non_zero),
::testing::ValuesIn(batch),
::testing::ValuesIn(hidden_size),
::testing::ValuesIn(input_size),
// ::testing::ValuesIn(input_size), // hardcoded to 10 because testing::Combine supports at most 10 arguments
::testing::ValuesIn(activations),
::testing::ValuesIn(clip_non_zeros),
::testing::ValuesIn(linear_before_reset),


@@ -10,10 +10,14 @@
using namespace LayerTestsDefinitions;
namespace {
std::vector<ngraph::helpers::SequenceTestsMode> mode{ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM,
ngraph::helpers::SequenceTestsMode::PURE_SEQ};
// output values increase rapidly without clip, so use only seq_lengths = 2
std::vector<size_t> seq_lengths_zero_clip{2};
std::vector<size_t> seq_lengths_clip_non_zero{20};
std::vector<size_t> batch{1, 10};
std::vector<size_t> batch{10};
std::vector<size_t> hidden_size{1, 10};
std::vector<size_t> input_size{10};
std::vector<std::vector<std::string>> activations = {{"relu", "sigmoid", "tanh"}, {"sigmoid", "tanh", "tanh"},
@@ -30,6 +34,7 @@ namespace {
INSTANTIATE_TEST_CASE_P(smoke_LSTMSequenceCommonZeroClip, LSTMSequenceTest,
::testing::Combine(
::testing::ValuesIn(mode),
::testing::ValuesIn(seq_lengths_zero_clip),
::testing::ValuesIn(batch),
::testing::ValuesIn(hidden_size),
@@ -43,6 +48,7 @@ namespace {
INSTANTIATE_TEST_CASE_P(smoke_LSTMSequenceCommonClip, LSTMSequenceTest,
::testing::Combine(
::testing::ValuesIn(mode),
::testing::ValuesIn(seq_lengths_clip_non_zero),
::testing::ValuesIn(batch),
::testing::ValuesIn(hidden_size),


@@ -10,6 +10,10 @@
using namespace LayerTestsDefinitions;
namespace {
std::vector<ngraph::helpers::SequenceTestsMode> mode{ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST,
ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM,
ngraph::helpers::SequenceTestsMode::PURE_SEQ};
// output values increase rapidly without clip, so use only seq_lengths = 2
std::vector<size_t> seq_lengths_zero_clip{2};
std::vector<size_t> seq_lengths_clip_non_zero{20};
@@ -21,13 +25,13 @@ namespace {
std::vector<float> clip_non_zeros{0.7f};
std::vector<ngraph::op::RecurrentSequenceDirection> direction = {ngraph::op::RecurrentSequenceDirection::FORWARD,
ngraph::op::RecurrentSequenceDirection::REVERSE,
ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL
ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL,
};
std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16};
std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32};
INSTANTIATE_TEST_CASE_P(smoke_RNNSequenceCommonZeroClip, RNNSequenceTest,
::testing::Combine(
::testing::ValuesIn(mode),
::testing::ValuesIn(seq_lengths_zero_clip),
::testing::ValuesIn(batch),
::testing::ValuesIn(hidden_size),
@@ -41,6 +45,7 @@ namespace {
INSTANTIATE_TEST_CASE_P(smoke_RNNSequenceCommonClip, RNNSequenceTest,
::testing::Combine(
::testing::ValuesIn(mode),
::testing::ValuesIn(seq_lengths_clip_non_zero),
::testing::ValuesIn(batch),
::testing::ValuesIn(hidden_size),


@@ -16,11 +16,12 @@
namespace LayerTestsDefinitions {
using GRUSequenceParams = typename std::tuple<
// bool, // using decompose to sub-ops transformation
ngraph::helpers::SequenceTestsMode, // pure Sequence or TensorIterator
size_t, // seq_lengths
size_t, // batch
size_t, // hidden size
size_t, // input size
// TODO: fix. Input size is hardcoded to 10 due to the 10-argument limit of gtest's Combine().
//size_t, // input size
std::vector<std::string>, // activations
float, // clip
bool, // linear_before_reset
@@ -35,6 +36,11 @@ public:
protected:
void SetUp() override;
void Infer() override;
private:
ngraph::helpers::SequenceTestsMode m_mode;
int64_t m_max_seq_len = 0;
};
} // namespace LayerTestsDefinitions


@@ -16,7 +16,7 @@
namespace LayerTestsDefinitions {
using LSTMSequenceParams = typename std::tuple<
// bool, // using decompose to sub-ops transformation
ngraph::helpers::SequenceTestsMode, // pure Sequence or TensorIterator
size_t, // seq_lengths
size_t, // batch
size_t, // hidden size
@@ -31,9 +31,13 @@ class LSTMSequenceTest : public testing::WithParamInterface<LSTMSequenceParams>,
virtual public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(const testing::TestParamInfo<LSTMSequenceParams> &obj);
protected:
void Infer() override;
void SetUp() override;
private:
ngraph::helpers::SequenceTestsMode m_mode;
int64_t m_max_seq_len = 0;
};
} // namespace LayerTestsDefinitions


@@ -16,7 +16,7 @@
namespace LayerTestsDefinitions {
using RNNSequenceParams = typename std::tuple<
// bool, // using decompose to sub-ops transformation
ngraph::helpers::SequenceTestsMode, // pure Sequence or TensorIterator
size_t, // seq_lengths
size_t, // batch
size_t, // hidden size
@@ -34,6 +34,11 @@ public:
protected:
void SetUp() override;
void Infer() override;
private:
ngraph::helpers::SequenceTestsMode m_mode;
int64_t m_max_seq_len = 0;
};
} // namespace LayerTestsDefinitions


@@ -18,15 +18,16 @@
#include "single_layer_tests/gru_sequence.hpp"
#include <transformations/op_conversions/bidirectional_sequences_decomposition.hpp>
#include <transformations/op_conversions/convert_sequences_to_tensor_iterator.hpp>
namespace LayerTestsDefinitions {
std::string GRUSequenceTest::getTestCaseName(const testing::TestParamInfo<GRUSequenceParams> &obj) {
//bool should_decompose;
ngraph::helpers::SequenceTestsMode mode;
size_t seq_lenghts;
size_t batch;
size_t hidden_size;
size_t input_size;
size_t input_size = 10;
std::vector<std::string> activations;
std::vector<float> activations_alpha;
std::vector<float> activations_beta;
@@ -35,13 +36,14 @@ namespace LayerTestsDefinitions {
ngraph::op::RecurrentSequenceDirection direction;
InferenceEngine::Precision netPrecision;
std::string targetDevice;
std::tie(seq_lenghts, batch, hidden_size, input_size, activations, clip, linear_before_reset, direction, netPrecision,
std::tie(mode, seq_lenghts, batch, hidden_size, activations, clip, linear_before_reset, direction, netPrecision,
targetDevice) = obj.param;
std::vector<std::vector<size_t>> inputShapes = {
{{batch, input_size}, {batch, hidden_size}, {batch, hidden_size}, {3 * hidden_size, input_size},
{3 * hidden_size, hidden_size}, {(linear_before_reset ? 4 : 3) * hidden_size}},
};
std::ostringstream result;
result << "mode=" << mode << "_";
result << "seq_lenghts=" << seq_lenghts << "_";
result << "batch=" << batch << "_";
result << "hidden_size=" << hidden_size << "_";
@@ -57,10 +59,9 @@ namespace LayerTestsDefinitions {
void GRUSequenceTest::SetUp() {
size_t seq_lenghts;
// bool should_decompose;
size_t batch;
size_t hidden_size;
size_t input_size;
size_t input_size = 10;
std::vector<std::string> activations;
std::vector<float> activations_alpha;
std::vector<float> activations_beta;
@@ -68,7 +69,7 @@ namespace LayerTestsDefinitions {
bool linear_before_reset;
ngraph::op::RecurrentSequenceDirection direction;
InferenceEngine::Precision netPrecision;
std::tie(seq_lenghts, batch, hidden_size, input_size, activations, clip, linear_before_reset, direction, netPrecision,
std::tie(m_mode, seq_lenghts, batch, hidden_size, activations, clip, linear_before_reset, direction, netPrecision,
targetDevice) = this->GetParam();
size_t num_directions = direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL ? 2 : 1;
std::vector<std::vector<size_t>> inputShapes = {
@@ -76,16 +77,57 @@ namespace LayerTestsDefinitions {
{num_directions, 3 * hidden_size, input_size}, {num_directions, 3 * hidden_size, hidden_size},
{num_directions, (linear_before_reset ? 4 : 3) * hidden_size}},
};
m_max_seq_len = seq_lenghts;
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]});
if (m_mode == ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_PARAM ||
m_mode == ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM) {
auto seq_lengths = ngraph::builder::makeParams(ngraph::element::i64, {inputShapes[2]}).at(0);
seq_lengths->set_friendly_name("seq_lengths");
params.push_back(seq_lengths);
}
std::vector<ngraph::Shape> WRB = {inputShapes[3], inputShapes[4], inputShapes[5], inputShapes[2]};
auto gru_sequence = ngraph::builder::makeGRU(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)),
WRB, hidden_size, activations, {}, {}, clip, linear_before_reset, true, direction);
WRB, hidden_size, activations, {}, {}, clip, linear_before_reset, true, direction,
m_mode);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(gru_sequence->output(0)),
std::make_shared<ngraph::opset1::Result>(gru_sequence->output(1))};
function = std::make_shared<ngraph::Function>(results, params, "gru_sequence");
if (m_mode != ngraph::helpers::SequenceTestsMode::PURE_SEQ) {
ngraph::pass::Manager manager;
if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
manager.register_pass<ngraph::pass::BidirectionalGRUSequenceDecomposition>();
manager.register_pass<ngraph::pass::ConvertGRUSequenceToTensorIterator>();
manager.run_passes(function);
bool ti_found = ngraph::helpers::is_tensor_iterator_exist(function);
EXPECT_EQ(ti_found, true);
} else {
bool ti_found = ngraph::helpers::is_tensor_iterator_exist(function);
EXPECT_EQ(ti_found, false);
}
}
void GRUSequenceTest::Infer() {
inferRequest = executableNetwork.CreateInferRequest();
inputs.clear();
for (const auto &input : executableNetwork.GetInputsInfo()) {
const auto &info = input.second;
auto blob = GenerateInput(*info);
if (input.first == "seq_lengths") {
blob = FuncTestUtils::createAndFillBlob(info->getTensorDesc(), m_max_seq_len, 0);
}
inferRequest.SetBlob(info->name(), blob);
inputs.push_back(blob);
}
if (configuration.count(InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED) &&
configuration.count(InferenceEngine::PluginConfigParams::YES)) {
auto batchSize = executableNetwork.GetInputsInfo().begin()->second->getTensorDesc().getDims()[0] / 2;
inferRequest.SetBatch(batchSize);
}
inferRequest.Infer();
}
TEST_P(GRUSequenceTest, CompareWithRefs) {
Run();


@@ -18,11 +18,13 @@
#include "single_layer_tests/lstm_sequence.hpp"
#include <transformations/op_conversions/bidirectional_sequences_decomposition.hpp>
#include <transformations/op_conversions/convert_sequences_to_tensor_iterator.hpp>
#include <ngraph/pass/visualize_tree.hpp>
namespace LayerTestsDefinitions {
std::string LSTMSequenceTest::getTestCaseName(const testing::TestParamInfo<LSTMSequenceParams> &obj) {
//bool should_decompose;
ngraph::helpers::SequenceTestsMode mode;
size_t seq_lenghts;
size_t batch;
size_t hidden_size;
@@ -34,13 +36,14 @@ namespace LayerTestsDefinitions {
ngraph::op::RecurrentSequenceDirection direction;
InferenceEngine::Precision netPrecision;
std::string targetDevice;
std::tie(seq_lenghts, batch, hidden_size, input_size, activations, clip, direction, netPrecision,
std::tie(mode, seq_lenghts, batch, hidden_size, input_size, activations, clip, direction, netPrecision,
targetDevice) = obj.param;
std::vector<std::vector<size_t>> inputShapes = {
{{batch, input_size}, {batch, hidden_size}, {batch, hidden_size}, {4 * hidden_size, input_size},
{4 * hidden_size, hidden_size}, {4 * hidden_size}},
};
std::ostringstream result;
result << "mode=" << mode << "_";
result << "seq_lenghts=" << seq_lenghts << "_";
result << "batch=" << batch << "_";
result << "hidden_size=" << hidden_size << "_";
@@ -56,7 +59,7 @@ namespace LayerTestsDefinitions {
void LSTMSequenceTest::SetUp() {
size_t seq_lenghts;
// bool should_decompose;
size_t batch;
size_t hidden_size;
size_t input_size;
@@ -66,24 +69,65 @@ namespace LayerTestsDefinitions {
float clip;
ngraph::op::RecurrentSequenceDirection direction;
InferenceEngine::Precision netPrecision;
std::tie(seq_lenghts, batch, hidden_size, input_size, activations, clip, direction, netPrecision,
std::tie(m_mode, seq_lenghts, batch, hidden_size, input_size, activations, clip, direction, netPrecision,
targetDevice) = this->GetParam();
size_t num_directions = direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL ? 2 : 1;
m_max_seq_len = seq_lenghts;
std::vector<std::vector<size_t>> inputShapes = {
{{batch, seq_lenghts, input_size}, {batch, num_directions, hidden_size}, {batch, num_directions, hidden_size},
{batch}, {num_directions, 4 * hidden_size, input_size}, {num_directions, 4 * hidden_size, hidden_size}, {num_directions, 4 * hidden_size}},
};
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1], inputShapes[2]});
if (m_mode == ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_PARAM ||
m_mode == ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM) {
auto seq_lengths = ngraph::builder::makeParams(ngraph::element::i64, {inputShapes[3]}).at(0);
seq_lengths->set_friendly_name("seq_lengths");
params.push_back(seq_lengths);
}
std::vector<ngraph::Shape> WRB = {inputShapes[4], inputShapes[5], inputShapes[6], inputShapes[3]};
auto lstm_sequence = ngraph::builder::makeLSTM(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)),
WRB, hidden_size, activations, {}, {}, clip, true, direction);
WRB, hidden_size, activations, {}, {}, clip, true, direction,
m_mode);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(0)),
std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(1)),
std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(2))};
function = std::make_shared<ngraph::Function>(results, params, "lstm_sequence");
if (m_mode != ngraph::helpers::SequenceTestsMode::PURE_SEQ) {
ngraph::pass::Manager manager;
if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
manager.register_pass<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
manager.register_pass<ngraph::pass::ConvertLSTMSequenceToTensorIterator>();
manager.run_passes(function);
bool ti_found = ngraph::helpers::is_tensor_iterator_exist(function);
EXPECT_EQ(ti_found, true);
} else {
bool ti_found = ngraph::helpers::is_tensor_iterator_exist(function);
EXPECT_EQ(ti_found, false);
}
}
void LSTMSequenceTest::Infer() {
inferRequest = executableNetwork.CreateInferRequest();
inputs.clear();
for (const auto &input : executableNetwork.GetInputsInfo()) {
const auto &info = input.second;
auto blob = GenerateInput(*info);
if (input.first == "seq_lengths") {
blob = FuncTestUtils::createAndFillBlob(info->getTensorDesc(), m_max_seq_len, 0);
}
inferRequest.SetBlob(info->name(), blob);
inputs.push_back(blob);
}
if (configuration.count(InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED) &&
configuration.at(InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED) ==
InferenceEngine::PluginConfigParams::YES) {
auto batchSize = executableNetwork.GetInputsInfo().begin()->second->getTensorDesc().getDims()[0] / 2;
inferRequest.SetBatch(batchSize);
}
inferRequest.Infer();
}
TEST_P(LSTMSequenceTest, CompareWithRefs) {
Run();

View File

@ -18,11 +18,12 @@
#include "single_layer_tests/rnn_sequence.hpp"
#include <transformations/op_conversions/bidirectional_sequences_decomposition.hpp>
#include <transformations/op_conversions/convert_sequences_to_tensor_iterator.hpp>
namespace LayerTestsDefinitions {
std::string RNNSequenceTest::getTestCaseName(const testing::TestParamInfo<RNNSequenceParams> &obj) {
//bool should_decompose;
ngraph::helpers::SequenceTestsMode mode;
size_t seq_lenghts;
size_t batch;
size_t hidden_size;
@ -34,13 +35,14 @@ namespace LayerTestsDefinitions {
ngraph::op::RecurrentSequenceDirection direction;
InferenceEngine::Precision netPrecision;
std::string targetDevice;
std::tie(seq_lenghts, batch, hidden_size, input_size, activations, clip, direction, netPrecision,
std::tie(mode, seq_lenghts, batch, hidden_size, input_size, activations, clip, direction, netPrecision,
targetDevice) = obj.param;
std::vector<std::vector<size_t>> inputShapes = {
{{batch, input_size}, {batch, hidden_size}, {batch, hidden_size}, {hidden_size, input_size},
{hidden_size, hidden_size}, {hidden_size}},
};
std::ostringstream result;
result << "mode=" << mode << "_";
result << "seq_lenghts=" << seq_lenghts << "_";
result << "batch=" << batch << "_";
result << "hidden_size=" << hidden_size << "_";
@ -56,7 +58,6 @@ namespace LayerTestsDefinitions {
void RNNSequenceTest::SetUp() {
size_t seq_lenghts;
// bool should_decompose;
size_t batch;
size_t hidden_size;
size_t input_size;
@ -66,7 +67,7 @@ namespace LayerTestsDefinitions {
float clip;
ngraph::op::RecurrentSequenceDirection direction;
InferenceEngine::Precision netPrecision;
std::tie(seq_lenghts, batch, hidden_size, input_size, activations, clip, direction, netPrecision,
std::tie(m_mode, seq_lenghts, batch, hidden_size, input_size, activations, clip, direction, netPrecision,
targetDevice) = this->GetParam();
size_t num_directions = direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL ? 2 : 1;
std::vector<std::vector<size_t>> inputShapes = {
@ -74,16 +75,57 @@ namespace LayerTestsDefinitions {
{num_directions, hidden_size, input_size}, {num_directions, hidden_size, hidden_size},
{num_directions, hidden_size}},
};
m_max_seq_len = seq_lenghts;
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]});
if (m_mode == ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_PARAM ||
m_mode == ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM) {
auto seq_lengths = ngraph::builder::makeParams(ngraph::element::i64, {inputShapes[2]}).at(0);
seq_lengths->set_friendly_name("seq_lengths");
params.push_back(seq_lengths);
}
std::vector<ngraph::Shape> WRB = {inputShapes[3], inputShapes[4], inputShapes[5], inputShapes[2]};
auto rnn_sequence = ngraph::builder::makeRNN(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)),
WRB, hidden_size, activations, {}, {}, clip, true, direction);
WRB, hidden_size, activations, {}, {}, clip, true, direction,
m_mode);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(rnn_sequence->output(0)),
std::make_shared<ngraph::opset1::Result>(rnn_sequence->output(1))};
function = std::make_shared<ngraph::Function>(results, params, "rnn_sequence");
if (m_mode != ngraph::helpers::SequenceTestsMode::PURE_SEQ) {
ngraph::pass::Manager manager;
if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
manager.register_pass<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
manager.register_pass<ngraph::pass::ConvertRNNSequenceToTensorIterator>();
manager.run_passes(function);
bool ti_found = ngraph::helpers::is_tensor_iterator_exist(function);
EXPECT_EQ(ti_found, true);
} else {
bool ti_found = ngraph::helpers::is_tensor_iterator_exist(function);
EXPECT_EQ(ti_found, false);
}
}
void RNNSequenceTest::Infer() {
inferRequest = executableNetwork.CreateInferRequest();
inputs.clear();
for (const auto &input : executableNetwork.GetInputsInfo()) {
const auto &info = input.second;
auto blob = GenerateInput(*info);
if (input.first == "seq_lengths") {
blob = FuncTestUtils::createAndFillBlob(info->getTensorDesc(), m_max_seq_len, 0);
}
inferRequest.SetBlob(info->name(), blob);
inputs.push_back(blob);
}
if (configuration.count(InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED) &&
configuration.at(InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED) ==
InferenceEngine::PluginConfigParams::YES) {
auto batchSize = executableNetwork.GetInputsInfo().begin()->second->getTensorDesc().getDims()[0] / 2;
inferRequest.SetBatch(batchSize);
}
inferRequest.Infer();
}
TEST_P(RNNSequenceTest, CompareWithRefs) {
Run();

View File

@ -124,7 +124,7 @@ namespace LayerTestsDefinitions {
NGRAPH_CHECK(false, "Bidirectional case is not supported.");
}
tensor_iterator->set_invariant_input(body_params[1], outer_params[1]);
tensor_iterator->set_merged_input(body_params[1], outer_params[1], results[1]);
tensor_iterator->set_invariant_input(body_params[2], outer_params[2]);
tensor_iterator->get_iter_value(results[1]);
tensor_iterator->get_iter_value(results[2]);
@ -166,7 +166,7 @@ namespace LayerTestsDefinitions {
NGRAPH_CHECK(false, "Bidirectional case is not supported.");
}
tensor_iterator->set_invariant_input(body_params[1], outer_params[1]);
tensor_iterator->set_merged_input(body_params[1], outer_params[1], results[0]);
tensor_iterator->get_iter_value(results[0]);
// 3. Outer function
@ -205,7 +205,7 @@ namespace LayerTestsDefinitions {
NGRAPH_CHECK(false, "Bidirectional case is not supported.");
}
tensor_iterator->set_invariant_input(body_params[1], outer_params[1]);
tensor_iterator->set_merged_input(body_params[1], outer_params[1], results[0]);
tensor_iterator->get_iter_value(results[0]);
// 3. Outer function

View File

@ -407,7 +407,8 @@ std::shared_ptr<ngraph::Node> makeLSTM(const OutputVector& in,
const std::vector<float>& activations_beta = {},
float clip = 0.f,
bool make_sequence = false,
ngraph::op::RecurrentSequenceDirection direction = ngraph::op::RecurrentSequenceDirection::FORWARD);
ngraph::op::RecurrentSequenceDirection direction = ngraph::op::RecurrentSequenceDirection::FORWARD,
ngraph::helpers::SequenceTestsMode mode = ngraph::helpers::SequenceTestsMode::PURE_SEQ);
std::shared_ptr<ngraph::Node> makeGRU(const OutputVector& in,
const std::vector<ngraph::Shape>& constants,
@ -419,7 +420,8 @@ std::shared_ptr<ngraph::Node> makeGRU(const OutputVector& in,
float clip = 0.f,
bool linear_before_reset = false,
bool make_sequence = false,
ngraph::op::RecurrentSequenceDirection direction = ngraph::op::RecurrentSequenceDirection::FORWARD);
ngraph::op::RecurrentSequenceDirection direction = ngraph::op::RecurrentSequenceDirection::FORWARD,
ngraph::helpers::SequenceTestsMode mode = ngraph::helpers::SequenceTestsMode::PURE_SEQ);
std::shared_ptr<ngraph::Node> makeRNN(const OutputVector& in,
const std::vector<ngraph::Shape>& constants,
@ -429,7 +431,8 @@ std::shared_ptr<ngraph::Node> makeRNN(const OutputVector& in,
const std::vector<float>& activations_beta = {},
float clip = 0.f,
bool make_sequence = false,
ngraph::op::RecurrentSequenceDirection direction = ngraph::op::RecurrentSequenceDirection::FORWARD);
ngraph::op::RecurrentSequenceDirection direction = ngraph::op::RecurrentSequenceDirection::FORWARD,
ngraph::helpers::SequenceTestsMode mode = ngraph::helpers::SequenceTestsMode::PURE_SEQ);
std::shared_ptr<ngraph::Node> makeGatherND(
const ngraph::Output<Node>& dataNode,

View File

@ -189,9 +189,19 @@ enum class TensorIteratorBody {
// CNN todo: implement
};
enum class SequenceTestsMode {
PURE_SEQ,
CONVERT_TO_TI_MAX_SEQ_LEN_CONST,
CONVERT_TO_TI_MAX_SEQ_LEN_PARAM,
CONVERT_TO_TI_RAND_SEQ_LEN_CONST,
CONVERT_TO_TI_RAND_SEQ_LEN_PARAM,
};
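// PURE_SEQ keeps the Sequence op as is; the CONVERT_TO_TI_* modes exercise the
// Sequence-to-TensorIterator transformations with seq_lengths passed either as a
// Constant (MAX_SEQ_LEN: all values equal to the max length, RAND_SEQ_LEN: random
// per batch element) or as a Parameter input of the function.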
std::ostream &operator<<(std::ostream &os, const ReductionType &m);
std::ostream &operator<<(std::ostream &os, const PadMode &m);
bool is_tensor_iterator_exist(const std::shared_ptr<ngraph::Function> & func);
inline std::string quantizationGranularityToString(const QuantizationGranularity &granularity) {
static std::map<QuantizationGranularity, std::string> names = {
{Pertensor, "Pertensor"},
@ -267,5 +277,7 @@ std::ostream& operator<<(std::ostream & os, ngraph::op::v4::Interpolate::ShapeCa
std::ostream& operator<<(std::ostream & os, TensorIteratorBody type);
std::ostream& operator<<(std::ostream & os, SequenceTestsMode type);
} // namespace helpers
} // namespace ngraph

View File

@ -19,7 +19,8 @@ std::shared_ptr<ngraph::Node> makeGRU(const OutputVector& in,
float clip,
bool linear_before_reset,
bool make_sequence,
ngraph::op::RecurrentSequenceDirection direction) {
ngraph::op::RecurrentSequenceDirection direction,
ngraph::helpers::SequenceTestsMode mode) {
std::vector<float> empty;
auto W = ngraph::builder::makeConstant(in[0].get_element_type(), constants[0], empty, true);
auto R = ngraph::builder::makeConstant(in[0].get_element_type(), constants[1], empty, true);
@ -29,9 +30,32 @@ std::shared_ptr<ngraph::Node> makeGRU(const OutputVector& in,
activations_alpha, activations_beta, clip,
linear_before_reset);
} else {
std::vector<float> lenghts(in[0].get_shape()[0], in[0].get_shape()[1]);
auto seq_lenghts = ngraph::builder::makeConstant(in[0].get_element_type(), constants[3], lenghts, false);
return std::make_shared<ngraph::opset5::GRUSequence>(in[0], in[1], seq_lenghts, W, R, B, hidden_size, direction,
std::shared_ptr<Node> seq_lengths;
switch (mode) {
case ngraph::helpers::SequenceTestsMode::PURE_SEQ:
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_CONST: {
std::vector<float> lengths(in[0].get_shape()[0], in[0].get_shape()[1]);
seq_lengths = ngraph::builder::makeConstant(element::i64, constants[3], lengths, false);
break;
}
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST: {
// an empty data vector plus random = true makes makeConstant fill the
// constant with random lengths in [0, max sequence length]
std::vector<float> lengths;
seq_lengths = ngraph::builder::makeConstant(element::i64, constants[3], lengths, true,
in[0].get_shape()[1], 0);
break;
}
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM:
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_PARAM: {
// seq_lengths has to be a Parameter node for these two modes
seq_lengths = in.at(2).get_node_shared_ptr();
break;
}
default:
throw std::runtime_error("Incorrect mode for creation of Sequence operation");
}
return std::make_shared<ngraph::opset5::GRUSequence>(in[0], in[1], seq_lengths, W, R, B, hidden_size, direction,
activations, activations_alpha, activations_beta, clip, linear_before_reset);
}
}

View File

@ -18,7 +18,8 @@ std::shared_ptr<ngraph::Node> makeLSTM(const std::vector<ngraph::Output<Node>>&
const std::vector<float>& activations_beta,
float clip,
bool make_sequence,
ngraph::op::RecurrentSequenceDirection direction) {
ngraph::op::RecurrentSequenceDirection direction,
ngraph::helpers::SequenceTestsMode mode) {
std::vector<float> empty;
auto W = ngraph::builder::makeConstant(in[0].get_element_type(), constants[0], empty, true);
auto R = ngraph::builder::makeConstant(in[0].get_element_type(), constants[1], empty, true);
@ -27,9 +28,32 @@ std::shared_ptr<ngraph::Node> makeLSTM(const std::vector<ngraph::Output<Node>>&
return std::make_shared<ngraph::opset4::LSTMCell>(in[0], in[1], in[2], W, R, B, hidden_size, activations,
activations_alpha, activations_beta, clip);
} else {
std::vector<float> lenghts(in[0].get_shape()[0], in[0].get_shape()[1]);
auto seq_lenghts = ngraph::builder::makeConstant(in[0].get_element_type(), constants[3], lenghts, false);
return std::make_shared<ngraph::opset5::LSTMSequence>(in[0], in[1], in[2], seq_lenghts, W, R, B, hidden_size, direction,
std::shared_ptr<Node> seq_lengths;
switch (mode) {
case ngraph::helpers::SequenceTestsMode::PURE_SEQ:
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_CONST: {
std::vector<float> lengths(in[0].get_shape()[0], in[0].get_shape()[1]);
seq_lengths = ngraph::builder::makeConstant(element::i64, constants[3], lengths, false);
break;
}
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST: {
// an empty data vector plus random = true makes makeConstant fill the
// constant with random lengths in [0, max sequence length]
std::vector<float> lengths;
seq_lengths = ngraph::builder::makeConstant(element::i64, constants[3], lengths, true,
in[0].get_shape()[1], 0);
break;
}
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM:
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_PARAM: {
// seq_lengths has to be a Parameter node for these two modes
seq_lengths = in.at(3).get_node_shared_ptr();
break;
}
default:
throw std::runtime_error("Incorrect mode for creation of Sequence operation");
}
return std::make_shared<ngraph::opset5::LSTMSequence>(in[0], in[1], in[2], seq_lengths, W, R, B, hidden_size, direction,
activations_alpha, activations_beta, activations, clip);
}
}

View File

@ -18,7 +18,8 @@ std::shared_ptr<ngraph::Node> makeRNN(const OutputVector& in,
const std::vector<float>& activations_beta,
float clip,
bool make_sequence,
ngraph::op::RecurrentSequenceDirection direction) {
ngraph::op::RecurrentSequenceDirection direction,
ngraph::helpers::SequenceTestsMode mode) {
std::vector<float> empty;
auto W = ngraph::builder::makeConstant(in[0].get_element_type(), constants[0], empty, true);
auto R = ngraph::builder::makeConstant(in[0].get_element_type(), constants[1], empty, true);
@ -27,9 +28,32 @@ std::shared_ptr<ngraph::Node> makeRNN(const OutputVector& in,
return std::make_shared<ngraph::opset4::RNNCell>(in[0], in[1], W, R, B, hidden_size, activations,
activations_alpha, activations_beta, clip);
} else {
std::vector<float> lenghts(in[0].get_shape()[0], in[0].get_shape()[1]);
auto seq_lenghts = ngraph::builder::makeConstant(in[0].get_element_type(), constants[3], lenghts, false);
return std::make_shared<ngraph::opset5::RNNSequence>(in[0], in[1], seq_lenghts, W, R, B, hidden_size, direction,
std::shared_ptr<Node> seq_lengths;
switch (mode) {
case ngraph::helpers::SequenceTestsMode::PURE_SEQ:
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_CONST: {
std::vector<float> lengths(in[0].get_shape()[0], in[0].get_shape()[1]);
seq_lengths = ngraph::builder::makeConstant(element::i64, constants[3], lengths, false);
break;
}
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST: {
// an empty data vector plus random = true makes makeConstant fill the
// constant with random lengths in [0, max sequence length]
std::vector<float> lengths;
seq_lengths = ngraph::builder::makeConstant(element::i64, constants[3], lengths, true,
in[0].get_shape()[1], 0);
break;
}
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM:
case ngraph::helpers::SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_PARAM: {
// seq_lengths has to be a Parameter node for these two modes
seq_lengths = in.at(2).get_node_shared_ptr();
break;
}
default:
throw std::runtime_error("Incorrect mode for creation of Sequence operation");
}
return std::make_shared<ngraph::opset5::RNNSequence>(in[0], in[1], seq_lengths, W, R, B, hidden_size, direction,
activations, activations_alpha, activations_beta, clip);
}
}

View File

@ -9,6 +9,7 @@
#include <ngraph/op/util/op_types.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <ngraph/specialize_function.hpp>
@ -255,6 +256,17 @@ std::vector<std::uint8_t> convertPrecision(std::vector<std::uint8_t> &buffer, co
return convertedData;
}
bool is_tensor_iterator_exist(const std::shared_ptr<ngraph::Function> & func) {
const auto& ops = func->get_ops();
for (const auto& node : ops) {
const auto& ti = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator>(node);
if (ti) {
return true;
}
}
return false;
}
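// typical usage in the sequence tests above: run the Convert*SequenceToTensorIterator
// passes, then EXPECT_EQ(is_tensor_iterator_exist(function), true)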
std::vector<std::uint8_t> convertOutputPrecision(std::vector<std::uint8_t> &output, const element::Type_t &fromPrecision, const element::Type_t &toPrecision,
const size_t elementsCount) {
switch (fromPrecision) {
@ -748,5 +760,28 @@ std::ostream& operator<<(std::ostream & os, TensorIteratorBody type) {
}
return os;
}
std::ostream& operator<<(std::ostream & os, SequenceTestsMode type) {
switch (type) {
case SequenceTestsMode::PURE_SEQ:
os << "PURE_SEQ";
break;
case SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_PARAM:
os << "CONVERT_TO_TI_RAND_SEQ_LEN_PARAM";
break;
case SequenceTestsMode::CONVERT_TO_TI_RAND_SEQ_LEN_CONST:
os << "CONVERT_TO_TI_RAND_SEQ_LEN_CONST";
break;
case SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_PARAM:
os << "CONVERT_TO_TI_MAX_SEQ_LEN_PARAM";
break;
case SequenceTestsMode::CONVERT_TO_TI_MAX_SEQ_LEN_CONST:
os << "CONVERT_TO_TI_MAX_SEQ_LEN_CONST";
break;
default:
throw std::runtime_error("NOT_SUPPORTED_OP_TYPE");
}
return os;
}
} // namespace helpers
} // namespace ngraph

View File

@ -45,7 +45,7 @@ namespace ngraph
bool linear_before_reset = false; // GRU
};
template <typename T>
template <typename T, typename U>
void cell_pass(CellType type,
const std::vector<const char*>& inputs,
const std::vector<Shape>& shapes,
@ -70,17 +70,49 @@ namespace ngraph
// split X
size_t num_splits = shapes[0].at(1);
std::vector<std::vector<char>> in_seqs(
num_splits, std::vector<char>(x_shape_size / num_splits * sizeof(T)));
size_t part_size = x_shape_size / num_splits * sizeof(T);
std::vector<char> in_seqs(x_shape_size * sizeof(T));
std::vector<char*> pointers(num_splits);
for (size_t i = 0; i < num_splits; ++i)
pointers[is_reverse ? num_splits - i - 1 : i] = in_seqs[i].data();
reference::split(inputs[0], shapes[0], sizeof(T), 1, num_splits, pointers.data());
Shape part_shape{shapes[0][0], 1, shapes[2][2]};
// if the seq_lengths input is filled with values != max_seq_length,
// part of the outputs has to be filled with zeros (a mask is applied)
size_t batch = shapes[0][0];
size_t hidden_size = shapes[2][2];
int64_t max_seq_lengths = num_splits;
const auto* seq_len_values = reinterpret_cast<const U*>(inputs[1]);
bool enable_mask = false;
for (size_t i = 0; i < batch; ++i)
{
enable_mask |= (max_seq_lengths != seq_len_values[i]);
}
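// e.g. for batch = 2, num_splits = 5 and seq_len_values = {3, 5} the mask is
// enabled and the outputs of batch element 0 stay zero for time steps 3 and 4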
std::vector<char> temp_buffer(x_shape_size * sizeof(T));
if (is_reverse)
{
reference::reverse_sequence<T, U>(reinterpret_cast<const T*>(inputs[0]),
reinterpret_cast<T*>(temp_buffer.data()),
shapes[0],
0,
1,
seq_len_values);
}
else
{
memcpy(temp_buffer.data(), inputs[0], x_shape_size * sizeof(T));
}
for (size_t i = 0; i < num_splits; ++i)
pointers[i] = in_seqs.data() + i * part_size;
reference::split(
temp_buffer.data(), shapes[0], sizeof(T), 1, num_splits, pointers.data());
Shape part_shape{batch, 1, hidden_size};
size_t part_shape_size = ngraph::shape_size(part_shape);
std::vector<std::vector<char>> h_list(
num_splits, std::vector<char>(ngraph::shape_size(part_shape) * sizeof(T)));
num_splits, std::vector<char>(part_shape_size * sizeof(T), 0));
std::vector<std::vector<char>> c_list(
num_splits, std::vector<char>(part_shape_size * sizeof(T), 0));
// use outputs as a buffer for temporary values
char* H_i = outputs[1];
@ -98,7 +130,7 @@ namespace ngraph
if (type == CellType::LSTM)
{
runtime::reference::lstm_cell<T>(
reinterpret_cast<const T*>(in_seqs[time_step].data()),
reinterpret_cast<const T*>(in_seqs.data() + time_step * part_size),
squeeze_axis(shapes[0], 1),
reinterpret_cast<const T*>(H_i),
squeeze_axis(shapes[2], 1),
@ -120,7 +152,7 @@ namespace ngraph
else if (type == CellType::RNN)
{
runtime::reference::rnn_cell<T>(
reinterpret_cast<const T*>(in_seqs[time_step].data()),
reinterpret_cast<const T*>(in_seqs.data() + time_step * part_size),
squeeze_axis(shapes[0], 1),
reinterpret_cast<const T*>(H_i),
squeeze_axis(shapes[2], 1),
@ -137,7 +169,7 @@ namespace ngraph
else if (type == CellType::GRU)
{
runtime::reference::gru_cell<T>(
reinterpret_cast<const T*>(in_seqs[time_step].data()),
reinterpret_cast<const T*>(in_seqs.data() + time_step * part_size),
squeeze_axis(shapes[0], 1),
reinterpret_cast<const T*>(H_i),
squeeze_axis(shapes[2], 1),
@ -153,23 +185,87 @@ namespace ngraph
args.clip,
args.linear_before_reset);
}
std::memcpy(h_list[time_step].data(), outputs[1], part_shape_size * sizeof(T));
if (enable_mask)
{
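// h_list/c_list entries are zero-initialized: batch elements whose sequence
// has already ended keep zeros, the others receive the freshly computed state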
size_t part_size_single_batch = part_shape_size / batch * sizeof(T);
for (size_t i = 0; i < batch; ++i)
{
if ((time_step + 1) > seq_len_values[i])
{
continue;
}
std::memcpy(h_list[time_step].data() + i * part_size_single_batch,
outputs[1] + i * part_size_single_batch,
part_size_single_batch);
if (type == CellType::LSTM)
{
std::memcpy(c_list[time_step].data() + i * part_size_single_batch,
outputs[2] + i * part_size_single_batch,
part_size_single_batch);
}
}
if ((num_splits - time_step) > 1)
{
std::memcpy(
outputs[1], h_list[time_step].data(), part_shape_size * sizeof(T));
if (type == CellType::LSTM)
{
std::memcpy(outputs[2],
c_list[time_step].data(),
part_shape_size * sizeof(T));
}
}
else
{
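// last time step: gather the final state of each batch element from the
// step indexed by its real sequence length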
for (size_t i = 0; i < batch; ++i)
{
std::memcpy(outputs[1] + i * part_size_single_batch,
h_list[seq_len_values[i] - 1].data() +
i * part_size_single_batch,
part_size_single_batch);
if (type == CellType::LSTM)
{
std::memcpy(outputs[2] + i * part_size_single_batch,
c_list[seq_len_values[i] - 1].data() +
i * part_size_single_batch,
part_size_single_batch);
}
}
}
}
else
{
std::memcpy(
h_list[time_step].data(), outputs[1], part_shape_size * sizeof(T));
}
}
// The tensor that concatenates all the intermediate output values of the hidden state.
// It has shape [batch_size, seq_length, hidden_size]
std::vector<Shape> in_shapes(num_splits, part_shape);
std::vector<const char*> to_concat_pointers(num_splits);
Shape out_shape{batch, num_splits, hidden_size};
for (size_t i = 0; i < num_splits; ++i)
to_concat_pointers[is_reverse ? num_splits - i - 1 : i] = h_list[i].data();
runtime::reference::concat(to_concat_pointers,
outputs[0],
in_shapes,
{shapes[0][0], shapes[0][1], shapes[2][2]},
1,
sizeof(T));
to_concat_pointers[i] = h_list[i].data();
runtime::reference::concat(
to_concat_pointers, outputs[0], in_shapes, out_shape, 1, sizeof(T));
// undo the initial per-element reversal so that the padding ends up at
// the tail of each sequence again
if (is_reverse)
{
temp_buffer.resize(shape_size(out_shape) * sizeof(T));
reference::reverse_sequence<T, U>(reinterpret_cast<const T*>(outputs[0]),
reinterpret_cast<T*>(temp_buffer.data()),
out_shape,
0,
1,
seq_len_values);
std::memcpy(outputs[0], temp_buffer.data(), shape_size(out_shape) * sizeof(T));
}
}
template <typename T>
template <typename T, typename U>
void lstm_sequence(const char* X,
const Shape& X_shape,
const char* H,
@ -206,12 +302,12 @@ namespace ngraph
std::vector<char*> outputs = {Y, Ho, Co};
std::vector<Shape> shapes = {
X_shape, seq_lengths_shape, H_shape, C_shape, W_shape, R_shape, B_shape};
cell_pass<T>(CellType::LSTM,
inputs,
shapes,
outputs,
args,
direction == op::RecurrentSequenceDirection::REVERSE);
cell_pass<T, U>(CellType::LSTM,
inputs,
shapes,
outputs,
args,
direction == op::RecurrentSequenceDirection::REVERSE);
}
else if (direction == op::RecurrentSequenceDirection::BIDIRECTIONAL)
{
@ -261,7 +357,7 @@ namespace ngraph
shapes[i][0] = 1;
}
// forward pass
cell_pass<T>(
cell_pass<T, U>(
CellType::LSTM,
{X,
seq_lengths,
@ -275,7 +371,7 @@ namespace ngraph
args,
false);
// reverse pass
cell_pass<T>(
cell_pass<T, U>(
CellType::LSTM,
{X,
seq_lengths,
@ -318,7 +414,7 @@ namespace ngraph
}
}
template <typename T>
template <typename T, typename U>
void gru_sequence(const char* X,
const Shape& X_shape,
const char* H,
@ -352,12 +448,12 @@ namespace ngraph
std::vector<char*> outputs = {Y, Ho};
std::vector<Shape> shapes = {
X_shape, seq_lengths_shape, H_shape, W_shape, R_shape, B_shape};
cell_pass<T>(CellType::GRU,
inputs,
shapes,
outputs,
args,
direction == op::RecurrentSequenceDirection::REVERSE);
cell_pass<T, U>(CellType::GRU,
inputs,
shapes,
outputs,
args,
direction == op::RecurrentSequenceDirection::REVERSE);
}
else if (direction == op::RecurrentSequenceDirection::BIDIRECTIONAL)
{
@ -400,29 +496,29 @@ namespace ngraph
shapes[i][0] = 1;
}
// forward pass
cell_pass<T>(CellType::GRU,
{X,
seq_lengths,
h_pointers[0],
w_pointers[0],
r_pointers[0],
b_pointers[0]},
shapes,
{forward_res_y.data(), forward_res_h.data()},
args,
false);
cell_pass<T, U>(CellType::GRU,
{X,
seq_lengths,
h_pointers[0],
w_pointers[0],
r_pointers[0],
b_pointers[0]},
shapes,
{forward_res_y.data(), forward_res_h.data()},
args,
false);
// reverse pass
cell_pass<T>(CellType::GRU,
{X,
seq_lengths,
h_pointers[1],
w_pointers[1],
r_pointers[1],
b_pointers[1]},
shapes,
{reverse_res_y.data(), reverse_res_h.data()},
args,
true);
cell_pass<T, U>(CellType::GRU,
{X,
seq_lengths,
h_pointers[1],
w_pointers[1],
r_pointers[1],
b_pointers[1]},
shapes,
{reverse_res_y.data(), reverse_res_h.data()},
args,
true);
// Stack together respective outputs from both forward and reverse passes.
std::vector<Shape> in_shapes_y = {{H_shape[0], 1, X_shape[1], H_shape[2]},
@ -447,7 +543,7 @@ namespace ngraph
}
}
template <typename T>
template <typename T, typename U>
void rnn_sequence(const char* X,
const Shape& X_shape,
const char* H,
@ -477,12 +573,12 @@ namespace ngraph
std::vector<char*> outputs = {Y, Ho};
std::vector<Shape> shapes = {
X_shape, seq_lengths_shape, H_shape, W_shape, R_shape, B_shape};
cell_pass<T>(CellType::RNN,
inputs,
shapes,
outputs,
args,
direction == op::RecurrentSequenceDirection::REVERSE);
cell_pass<T, U>(CellType::RNN,
inputs,
shapes,
outputs,
args,
direction == op::RecurrentSequenceDirection::REVERSE);
}
else if (direction == op::RecurrentSequenceDirection::BIDIRECTIONAL)
{
@ -523,29 +619,29 @@ namespace ngraph
shapes[i][0] = 1;
}
// forward pass
cell_pass<T>(CellType::RNN,
{X,
seq_lengths,
h_pointers[0],
w_pointers[0],
r_pointers[0],
b_pointers[0]},
shapes,
{forward_res_y.data(), forward_res_h.data()},
args,
false);
cell_pass<T, U>(CellType::RNN,
{X,
seq_lengths,
h_pointers[0],
w_pointers[0],
r_pointers[0],
b_pointers[0]},
shapes,
{forward_res_y.data(), forward_res_h.data()},
args,
false);
// reverse pass
cell_pass<T>(CellType::RNN,
{X,
seq_lengths,
h_pointers[1],
w_pointers[1],
r_pointers[1],
b_pointers[1]},
shapes,
{reverse_res_y.data(), reverse_res_h.data()},
args,
true);
cell_pass<T, U>(CellType::RNN,
{X,
seq_lengths,
h_pointers[1],
w_pointers[1],
r_pointers[1],
b_pointers[1]},
shapes,
{reverse_res_y.data(), reverse_res_h.data()},
args,
true);
// Stack together respective outputs from both forward and reverse passes.
std::vector<Shape> in_shapes_y = {{H_shape[0], 1, X_shape[1], H_shape[2]},

View File

@ -91,7 +91,9 @@ namespace ngraph
std::vector<char*> pointers_to_data(num_iterations);
for (size_t j = 0; j < pointers_to_data.size(); ++j)
{
pointers_to_data[j] =
pointers_to_data[slice_desc->m_stride > 0
? j
: (pointers_to_data.size() - j - 1)] =
sliced_values[slice_in_idx][j]->get_data_ptr<char>();
}
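// a negative stride means the TensorIterator consumes the input slices in
// reverse order, so the split results are mapped back-to-front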
reference::split(args[slice_desc->m_input_index]->get_data_ptr<char>(),
@ -118,6 +120,7 @@ namespace ngraph
}
// Evaluate body
body_outputs.clear();
if (!evaluate)
{
reference::function(func, inputs_to_body, body_outputs);
@ -164,9 +167,13 @@ namespace ngraph
out[concat_desc->m_output_index]->set_shape(shape);
std::vector<const char*> pointers_on_values;
pointers_on_values.reserve(values_to_concat[i].size());
for (const auto& vec : values_to_concat[i])
for (size_t j = 0; j < values_to_concat[i].size(); ++j)
{
pointers_on_values.push_back(vec->get_data_ptr<char>());
pointers_on_values.push_back(
values_to_concat[i][concat_desc->m_stride > 0
? j
: (values_to_concat[i].size() - j - 1)]
->get_data_ptr<char>());
}
reference::concat(pointers_on_values,
out[concat_desc->m_output_index]->get_data_ptr<char>(),

View File

@ -1500,17 +1500,17 @@ def lstm_sequence(
Shape: [batch_size]. Integer type.
@param W: Tensor with weights for matrix multiplication operation with input portion of data.
Shape: [num_directions, 4*hidden_size, input_size].
@param R: The tensor with weights for matrix multiplication operation with hidden state.
Shape: [num_directions, 4*hidden_size, input_size].
@param B: The tensor with biases.
:param R: The tensor with weights for matrix multiplication operation with hidden state.
Shape: [num_directions, 4*hidden_size, hidden_size].
@param hidden_size: Specifies hidden state size.
@param direction: Specifies if the RNN is forward, reverse, or bidirectional.
@param activations: The list of three activation functions for gates.
@param activations_alpha: The list of alpha parameters for activation functions.
@param activations_beta: The list of beta parameters for activation functions.
@param clip: Specifies bound values [-C, C] for tensor clipping performed before activations.
@param name: An optional name of the output node.
:param B: The tensor with biases.
Shape: [num_directions, 4*hidden_size].
:param hidden_size: Specifies hidden state size.
:param direction: Specifies if the RNN is forward, reverse, or bidirectional.
:param activations: The list of three activation functions for gates.
:param activations_alpha: The list of alpha parameters for activation functions.
:param activations_beta: The list of beta parameters for activation functions.
:param clip: Specifies bound values [-C, C] for tensor clipping performed before activations.
:param name: An optional name of the output node.
@return The new node represents LSTMSequence. Node outputs count: 3.
"""

View File

@ -312,7 +312,7 @@ def gru_cell(
@return The new node performing a GRUCell operation on tensor from input node.
"""
if activations is None:
activations = ["relu", "sigmoid", "tanh"]
activations = ["sigmoid", "tanh"]
if activations_alpha is None:
activations_alpha = []
if activations_beta is None:
@ -432,7 +432,7 @@ def rnn_cell(
@param W: The weight tensor with shape: [hidden_size, input_size].
@param R: The recurrence weight tensor with shape: [hidden_size,
hidden_size].
@param B: The bias tensor for input gate with shape: [2*hidden_size].
@param B: The sum of biases (weight and recurrence) with shape: [hidden_size].
@param hidden_size: The number of hidden units for recurrent cell.
Specifies hidden state size.
@param activations: The vector of activation functions used inside recurrent cell.
@ -446,7 +446,7 @@ def rnn_cell(
@return The new node performing a RNNCell operation on tensor from input node.
"""
if activations is None:
activations = ["sigmoid", "tanh"]
activations = ["tanh"]
if activations_alpha is None:
activations_alpha = []
if activations_beta is None:
@ -557,7 +557,7 @@ def shape_of(data: NodeInput, output_type: str = "i64", name: Optional[str] = No
"""! Return a node which produces a tensor containing the shape of its input data.
@param data: The tensor containing the input data.
:para output_type: Output element type.
@param output_type: Output element type.
@return ShapeOf node
"""
return _get_node_factory_opset3().create(

View File

@ -70,15 +70,15 @@ def batch_norm_inference(
) -> Node:
"""Perform layer normalizes a input tensor by mean and variance with appling scale and offset.
:param data: The input tensor with data for normalization.
:param gamma: The scalar scaling for normalized value.
:param beta: The bias added to the scaled normalized value.
:param mean: The value for mean normalization.
:param variance: The value for variance normalization.
:param epsilon: The number to be added to the variance to avoid division
@param data: The input tensor with data for normalization.
@param gamma: The scalar scaling for normalized value.
@param beta: The bias added to the scaled normalized value.
@param mean: The value for mean normalization.
@param variance: The value for variance normalization.
@param epsilon: The number to be added to the variance to avoid division
by zero when normalizing a value.
:param name: The optional name of the output node.
:return: The new node which performs BatchNormInference.
@param name: The optional name of the output node.
@return: The new node which performs BatchNormInference.
"""
inputs = as_nodes(data, gamma, beta, mean, variance)
return _get_node_factory_opset5().create("BatchNormInference", inputs, {"epsilon": epsilon})
@ -93,10 +93,10 @@ def gather_nd(
) -> Node:
"""Return a node which performs GatherND.
:param data: N-D tensor with data for gathering
:param indices: K-D tensor of tuples with indices by which data is gathered
:param batch_dims: Scalar value of batch dimensions
:return: The new node which performs GatherND
@param data: N-D tensor with data for gathering
@param indices: K-D tensor of tuples with indices by which data is gathered
@param batch_dims: Scalar value of batch dimensions
@return: The new node which performs GatherND
"""
inputs = as_nodes(data, indices)
@ -111,9 +111,9 @@ def gather_nd(
def log_softmax(data: NodeInput, axis: int, name: Optional[str] = None) -> Node:
"""Apply LogSoftmax operation on each element of input tensor.
:param data: The tensor providing input data.
:param axis: An axis along which LogSoftmax should be calculated
:return: The new node with LogSoftmax operation applied on each element.
@param data: The tensor providing input data.
@param axis: An axis along which LogSoftmax should be calculated
@return: The new node with LogSoftmax operation applied on each element.
"""
return _get_node_factory_opset5().create("LogSoftmax", [as_node(data)], {"axis": axis})
@ -133,18 +133,18 @@ def non_max_suppression(
) -> Node:
"""Return a node which performs NonMaxSuppression.
:param boxes: Tensor with box coordinates.
:param scores: Tensor with box scores.
:param max_output_boxes_per_class: Tensor Specifying maximum number of boxes
@param boxes: Tensor with box coordinates.
@param scores: Tensor with box scores.
@param max_output_boxes_per_class: Tensor specifying the maximum number of boxes
to be selected per class.
:param iou_threshold: Tensor specifying intersection over union threshold
:param score_threshold: Tensor specifying minimum score to consider box for the processing.
:param soft_nms_sigma: Tensor specifying the sigma parameter for Soft-NMS.
:param box_encoding: Format of boxes data encoding.
:param sort_result_descending: Flag that specifies whenever it is necessary to sort selected
@param iou_threshold: Tensor specifying intersection over union threshold
@param score_threshold: Tensor specifying minimum score to consider box for the processing.
@param soft_nms_sigma: Tensor specifying the sigma parameter for Soft-NMS.
@param box_encoding: Format of boxes data encoding.
@param sort_result_descending: Flag that specifies whether it is necessary to sort selected
boxes across batches or not.
:param output_type: Output element type.
:return: The new node which performs NonMaxSuppression
@param output_type: Output element type.
@return: The new node which performs NonMaxSuppression
"""
if max_output_boxes_per_class is None:
max_output_boxes_per_class = make_constant_node(0, np.int64)
@ -174,63 +174,131 @@ def non_max_suppression(
def round(data: NodeInput, mode: str = "half_to_even", name: Optional[str] = None) -> Node:
"""Apply Round operation on each element of input tensor.
:param data: The tensor providing input data.
:param mode: Rule to round halfway cases. If set to 'half_to_even' then halfs round to the nearest even
@param data: The tensor providing input data.
@param mode: Rule to round halfway cases. If set to 'half_to_even', halves round to the nearest
even integer; if set to 'half_away_from_zero', the result heads away from zero.
:param name: An optional name of the output node.
:return: The new node with Round operation applied on each element.
@param name: An optional name of the output node.
@return: The new node with Round operation applied on each element.
"""
return _get_node_factory_opset5().create("Round", as_nodes(data), {"mode": mode.upper()})
@nameable_op
def lstm_sequence(
X: NodeInput,
initial_hidden_state: NodeInput,
initial_cell_state: NodeInput,
sequence_lengths: NodeInput,
W: NodeInput,
R: NodeInput,
B: NodeInput,
hidden_size: int,
direction: str,
activations: List[str] = None,
activations_alpha: List[float] = None,
activations_beta: List[float] = None,
clip: float = 0.0,
name: Optional[str] = None,
) -> Node:
"""Return a node which performs LSTMSequence operation.
@param X: The input tensor. Shape: [batch_size, seq_length, input_size].
@param initial_hidden_state: The hidden state tensor.
Shape: [batch_size, num_directions, hidden_size].
@param initial_cell_state: The cell state tensor.
Shape: [batch_size, num_directions, hidden_size].
@param sequence_lengths: Specifies real sequence lengths for each batch element.
Shape: [batch_size]. Integer type.
@param W: Tensor with weights for matrix multiplication operation with input portion of data.
Expected format: fico
Shape: [num_directions, 4*hidden_size, input_size].
@param R: The tensor with weights for matrix multiplication operation with hidden state.
Expected format: fico
Shape: [num_directions, 4*hidden_size, hidden_size].
@param B: The sum of biases (weight and recurrence). Expected format: fico
Shape: [num_directions, 4*hidden_size].
@param hidden_size: Specifies hidden state size.
@param direction: Specifies if the RNN is forward, reverse, or bidirectional.
@param activations: The list of three activation functions for gates.
@param activations_alpha: The list of alpha parameters for activation functions.
@param activations_beta: The list of beta parameters for activation functions.
@param clip: Specifies bound values [-C, C] for tensor clipping performed before activations.
@param name: An optional name of the output node.
@return: The new node represents LSTMSequence. Node outputs count: 3.
"""
if activations is None:
activations = ["sigmoid", "tanh", "tanh"]
if activations_alpha is None:
activations_alpha = []
if activations_beta is None:
activations_beta = []
node_inputs = as_nodes(X, initial_hidden_state, initial_cell_state, sequence_lengths, W, R, B)
attributes = {
"hidden_size": hidden_size,
"direction": direction.lower(),
"activations": activations,
"activations_alpha": activations_alpha,
"activations_beta": activations_beta,
"clip": clip,
}
return _get_node_factory_opset5().create("LSTMSequence", node_inputs, attributes)
def hsigmoid(data: NodeInput, name: Optional[str] = None,) -> Node:
"""Return a node which performs HSigmoid.
:param data: Tensor with input data floating point type.
:return: The new node which performs HSigmoid
@param data: Tensor with input data floating point type.
@return: The new node which performs HSigmoid
"""
return _get_node_factory_opset5().create("HSigmoid", as_nodes(data), {})
@nameable_op
def gru_sequence(
X: NodeInput,
H_t: NodeInput,
sequence_lengths: NodeInput,
W: NodeInput,
R: NodeInput,
B: NodeInput,
hidden_size: int,
direction: str,
activations: List[str] = None,
activations_alpha: List[float] = None,
activations_beta: List[float] = None,
clip: float = 0.0,
linear_before_reset: bool = False,
name: Optional[str] = None,
X: NodeInput,
initial_hidden_state: NodeInput,
sequence_lengths: NodeInput,
W: NodeInput,
R: NodeInput,
B: NodeInput,
hidden_size: int,
direction: str,
activations: List[str] = None,
activations_alpha: List[float] = None,
activations_beta: List[float] = None,
clip: float = 0.0,
linear_before_reset: bool = False,
name: Optional[str] = None,
) -> Node:
"""Return a node which performs GRUSequence.
"""Return a node which performs GRUSequence operation.
:param X: 3D tensor, input data.
:param H_t: 3D tensor, input hidden state data.
:param sequence_lengths: 1D tensor, specifies sequence lenghts
for each batch element.
:param W: 3D tensor, weights matrix.
:param R: 3D tensor, recurrence weights matrix.
:param B: 2D tensor, sum of biases.
:param hidden_size: Size of the hidden state.
:param direction: Specify if the RNN is forward, reverse, or bidirectional.
:param activations: Activation functions for gates.
:param activations_alpha: Attributes of function; applicability and meaning
of these attributes depends on choosen activation function.
:param activations_beta: Attributes of function; applicability and meaning
of these attributes depends on choosen activation function.
:param clip: Specifies bound values *[-clip, clip]* for tensor clipping.
:param linear_before_reset: During the computation of the output of
the hidden gate, apply the linear transformation.
:return: The new node which performs GRUSequence
@param X: The input tensor. Shape: [batch_size, seq_length, input_size].
@param initial_hidden_state: The hidden state tensor.
Shape: [batch_size, num_directions, hidden_size].
@param sequence_lengths: Specifies real sequence lengths for each batch element.
Shape: [batch_size]. Integer type.
@param W: Tensor with weights for matrix multiplication operation with input portion of data.
Shape: [num_directions, 3*hidden_size, input_size].
@param R: The tensor with weights for matrix multiplication operation with hidden state.
Shape: [num_directions, 3*hidden_size, hidden_size].
@param B: The sum of biases (weight and recurrence).
For linear_before_reset set True the shape is [num_directions, 4*hidden_size].
Otherwise the shape is [num_directions, 3*hidden_size].
@param hidden_size: Specifies hidden state size.
@param direction: Specifies if the RNN is forward, reverse, or bidirectional.
@param activations: The list of three activation functions for gates.
@param activations_alpha: The list of alpha parameters for activation functions.
@param activations_beta: The list of beta parameters for activation functions.
@param clip: Specifies bound values [-C, C] for tensor clipping performed before activations.
@param linear_before_reset: Flag denotes if the layer behaves according to the modification
of GRU described in the formula in the ONNX documentation.
@param name: An optional name of the output node.
@return: The new node represents GRUSequence. Node outputs count: 2.
"""
if activations is None:
activations = ["sigmoid", "tanh"]
@ -239,54 +307,58 @@ def gru_sequence(
if activations_beta is None:
activations_beta = []
inputs = as_nodes(X, H_t, sequence_lengths, W, R, B)
node_inputs = as_nodes(X, initial_hidden_state, sequence_lengths, W, R, B)
attributes = {
"hidden_size": hidden_size,
"direction": direction.lower(),
"activations": activations,
"activations_alpha": activations_alpha,
"activations_beta": activations_alpha,
"clip": clip,
"activations_beta": activations_beta,
"linear_before_reset": linear_before_reset,
"clip": clip,
}
return _get_node_factory_opset5().create("GRUSequence", inputs, attributes)
return _get_node_factory_opset5().create("GRUSequence", node_inputs, attributes)
@nameable_op
def rnn_sequence(
X: NodeInput,
H_t: NodeInput,
sequence_lengths: NodeInput,
W: NodeInput,
R: NodeInput,
B: NodeInput,
hidden_size: int,
direction: str,
activations: List[str] = None,
activations_alpha: List[float] = None,
activations_beta: List[float] = None,
clip: float = 0.0,
name: Optional[str] = None,
X: NodeInput,
initial_hidden_state: NodeInput,
sequence_lengths: NodeInput,
W: NodeInput,
R: NodeInput,
B: NodeInput,
hidden_size: int,
direction: str,
activations: List[str] = None,
activations_alpha: List[float] = None,
activations_beta: List[float] = None,
clip: float = 0.0,
name: Optional[str] = None,
) -> Node:
"""Return a node which performs RNNSequence.
"""Return a node which performs RNNSequence operation.
:param X: 3D tensor, input data.
:param H_t: 3D tensor, input hidden state data.
:param sequence_lengths: 1D tensor, specifies sequence lenghts
for each batch element.
:param W: 3D tensor, weights matrix.
:param R: 3D tensor, recurrence weights matrix.
:param B: 2D tensor, sum of biases.
:param hidden_size: Size of the hidden state.
:param direction: Specify if the RNN is forward, reverse, or bidirectional.
:param activations: Activation functions for gates.
:param activations_alpha: Attributes of function; applicability and meaning
of these attributes depends on choosen activation function.
:param activations_beta: Attributes of function; applicability and meaning
of these attributes depends on choosen activation function.
:param clip: Specifies bound values *[-clip, clip]* for tensor clipping.
:return: The new node which performs RNNSequence
@param X: The input tensor. Shape: [batch_size, seq_length, input_size].
@param initial_hidden_state: The hidden state tensor.
Shape: [batch_size, num_directions, hidden_size].
@param sequence_lengths: Specifies real sequence lengths for each batch element.
Shape: [batch_size]. Integer type.
@param W: Tensor with weights for matrix multiplication operation with input portion of data.
Shape: [num_directions, hidden_size, input_size].
@param R: The tensor with weights for matrix multiplication operation with hidden state.
Shape: [num_directions, hidden_size, hidden_size].
@param B: The sum of biases (weight and recurrence).
Shape: [num_directions, hidden_size].
@param hidden_size: Specifies hidden state size.
@param direction: Specifies if the RNN is forward, reverse, or bidirectional.
@param activations: The list of three activation functions for gates.
@param activations_alpha: The list of alpha parameters for activation functions.
@param activations_beta: The list of beta parameters for activation functions.
@param clip: Specifies bound values [-C, C] for tensor clipping performed before activations.
@param name: An optional name of the output node.
@return: The new node represents RNNSequence. Node outputs count: 2.
"""
if activations is None:
activations = ["tanh"]
@ -295,13 +367,14 @@ def rnn_sequence(
if activations_beta is None:
activations_beta = []
inputs = as_nodes(X, H_t, sequence_lengths, W, R, B)
inputs = as_nodes(X, initial_hidden_state, sequence_lengths, W, R, B)
attributes = {
"hidden_size": hidden_size,
"direction": direction.lower(),
"activations": activations,
"activations_alpha": activations_alpha,
"activations_beta": activations_alpha,
"activations_beta": activations_beta,
"clip": clip,
}
@ -316,11 +389,11 @@ def loop(
) -> Node:
"""Return a node which performs Loop.
:param trip_count: A scalar or 1D tensor with 1 element specifying
@param trip_count: A scalar or 1D tensor with 1 element specifying
maximum number of iterations.
:param execution_condition: A scalar or 1D tensor with 1 element
@param execution_condition: A scalar or 1D tensor with 1 element
specifying whether to execute the first iteration or not.
:return: The new node which performs Loop.
@return: The new node which performs Loop.
"""
inputs = as_nodes(trip_count, execution_condition)

View File

@ -235,4 +235,3 @@ xfail_issue_39663 = xfail_test(reason="RuntimeError: Unsupported primitive of ty
xfail_issue_41815 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v5::NonMaxSuppression casted "
"(yolo_evaluation_layer_1/concat_6:0_btc[0]:f32{1,2535,4},")
xfail_issue_41894 = xfail_test(reason="CPU plugin elementwise computation missmatch")
xfail_issue_42818 = xfail_test(reason="AssertionError: This model has no test data")

View File

@ -290,7 +290,7 @@ def test_lstm_cell_operator_opset1(dtype):
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_lstm_sequence_operator_bidirectional(dtype):
def test_lstm_sequence_operator_bidirectional_opset1(dtype):
batch_size = 1
input_size = 16
hidden_size = 128
@ -355,7 +355,7 @@ def test_lstm_sequence_operator_bidirectional(dtype):
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_lstm_sequence_operator_reverse(dtype):
def test_lstm_sequence_operator_reverse_opset1(dtype):
batch_size = 2
input_size = 4
hidden_size = 3
@ -421,7 +421,7 @@ def test_lstm_sequence_operator_reverse(dtype):
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_lstm_sequence_operator_forward(dtype):
def test_lstm_sequence_operator_forward_opset1(dtype):
batch_size = 2
input_size = 4
hidden_size = 3
@ -1089,3 +1089,582 @@ def test_extract_image_patches():
assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == [64, 27, 2, 2]
assert node.get_output_element_type(0) == Type.i32
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_lstm_sequence_operator_bidirectional(dtype):
batch_size = 1
input_size = 16
hidden_size = 128
num_directions = 2
seq_length = 2
X_shape = [batch_size, seq_length, input_size]
H_t_shape = [batch_size, num_directions, hidden_size]
C_t_shape = [batch_size, num_directions, hidden_size]
seq_len_shape = [batch_size]
W_shape = [num_directions, 4 * hidden_size, input_size]
R_shape = [num_directions, 4 * hidden_size, hidden_size]
B_shape = [num_directions, 4 * hidden_size]
parameter_X = ng.parameter(X_shape, name="X", dtype=dtype)
parameter_H_t = ng.parameter(H_t_shape, name="H_t", dtype=dtype)
parameter_C_t = ng.parameter(C_t_shape, name="C_t", dtype=dtype)
parameter_seq_len = ng.parameter(seq_len_shape, name="seq_len", dtype=np.int32)
parameter_W = ng.parameter(W_shape, name="W", dtype=dtype)
parameter_R = ng.parameter(R_shape, name="R", dtype=dtype)
parameter_B = ng.parameter(B_shape, name="B", dtype=dtype)
direction = "BIDIRECTIONAL"
node = ng.lstm_sequence(
parameter_X,
parameter_H_t,
parameter_C_t,
parameter_seq_len,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
direction,
)
assert node.get_type_name() == "LSTMSequence"
assert node.get_output_size() == 3
activations = ["RELU", "tanh", "Sigmoid"]
activation_alpha = [1.0, 2.0, 3.0]
activation_beta = [3.0, 2.0, 1.0]
clip = 1.22
node_param = ng.lstm_sequence(
parameter_X,
parameter_H_t,
parameter_C_t,
parameter_seq_len,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
direction,
activations,
activation_alpha,
activation_beta,
clip,
)
assert node_param.get_type_name() == "LSTMSequence"
assert node_param.get_output_size() == 3
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_lstm_sequence_operator_reverse(dtype):
batch_size = 2
input_size = 4
hidden_size = 3
num_directions = 1
seq_length = 2
X_shape = [batch_size, seq_length, input_size]
H_t_shape = [batch_size, num_directions, hidden_size]
C_t_shape = [batch_size, num_directions, hidden_size]
seq_len_shape = [batch_size]
W_shape = [num_directions, 4 * hidden_size, input_size]
R_shape = [num_directions, 4 * hidden_size, hidden_size]
B_shape = [num_directions, 4 * hidden_size]
parameter_X = ng.parameter(X_shape, name="X", dtype=dtype)
parameter_H_t = ng.parameter(H_t_shape, name="H_t", dtype=dtype)
parameter_C_t = ng.parameter(C_t_shape, name="C_t", dtype=dtype)
parameter_seq_len = ng.parameter(seq_len_shape, name="seq_len", dtype=np.int32)
parameter_W = ng.parameter(W_shape, name="W", dtype=dtype)
parameter_R = ng.parameter(R_shape, name="R", dtype=dtype)
parameter_B = ng.parameter(B_shape, name="B", dtype=dtype)
direction = "REVERSE"
node_default = ng.lstm_sequence(
parameter_X,
parameter_H_t,
parameter_C_t,
parameter_seq_len,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
direction,
)
assert node_default.get_type_name() == "LSTMSequence"
assert node_default.get_output_size() == 3
activations = ["RELU", "tanh", "Sigmoid"]
activation_alpha = [1.0, 2.0, 3.0]
activation_beta = [3.0, 2.0, 1.0]
clip = 1.22
node_param = ng.lstm_sequence(
parameter_X,
parameter_H_t,
parameter_C_t,
parameter_seq_len,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
direction,
activations,
activation_alpha,
activation_beta,
clip,
)
assert node_param.get_type_name() == "LSTMSequence"
assert node_param.get_output_size() == 3
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_lstm_sequence_operator_forward(dtype):
batch_size = 2
input_size = 4
hidden_size = 3
num_directions = 1
seq_length = 2
X_shape = [batch_size, seq_length, input_size]
H_t_shape = [batch_size, num_directions, hidden_size]
C_t_shape = [batch_size, num_directions, hidden_size]
seq_len_shape = [batch_size]
W_shape = [num_directions, 4 * hidden_size, input_size]
R_shape = [num_directions, 4 * hidden_size, hidden_size]
B_shape = [num_directions, 4 * hidden_size]
parameter_X = ng.parameter(X_shape, name="X", dtype=dtype)
parameter_H_t = ng.parameter(H_t_shape, name="H_t", dtype=dtype)
parameter_C_t = ng.parameter(C_t_shape, name="C_t", dtype=dtype)
parameter_seq_len = ng.parameter(seq_len_shape, name="seq_len", dtype=np.int32)
parameter_W = ng.parameter(W_shape, name="W", dtype=dtype)
parameter_R = ng.parameter(R_shape, name="R", dtype=dtype)
parameter_B = ng.parameter(B_shape, name="B", dtype=dtype)
direction = "forward"
node_default = ng.lstm_sequence(
parameter_X,
parameter_H_t,
parameter_C_t,
parameter_seq_len,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
direction,
)
assert node_default.get_type_name() == "LSTMSequence"
assert node_default.get_output_size() == 3
activations = ["RELU", "tanh", "Sigmoid"]
activation_alpha = [2.0]
activation_beta = [1.0]
clip = 0.5
node = ng.lstm_sequence(
parameter_X,
parameter_H_t,
parameter_C_t,
parameter_seq_len,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
direction,
activations,
activation_alpha,
activation_beta,
clip,
)
assert node.get_type_name() == "LSTMSequence"
assert node.get_output_size() == 3
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_gru_sequence_operator_bidirectional(dtype):
batch_size = 1
input_size = 16
hidden_size = 128
num_directions = 2
seq_length = 2
X_shape = [batch_size, seq_length, input_size]
H_t_shape = [batch_size, num_directions, hidden_size]
seq_len_shape = [batch_size]
W_shape = [num_directions, 3 * hidden_size, input_size]
R_shape = [num_directions, 3 * hidden_size, hidden_size]
B_shape = [num_directions, 3 * hidden_size]
parameter_X = ng.parameter(X_shape, name="X", dtype=dtype)
parameter_H_t = ng.parameter(H_t_shape, name="H_t", dtype=dtype)
parameter_seq_len = ng.parameter(seq_len_shape, name="seq_len", dtype=np.int32)
parameter_W = ng.parameter(W_shape, name="W", dtype=dtype)
parameter_R = ng.parameter(R_shape, name="R", dtype=dtype)
parameter_B = ng.parameter(B_shape, name="B", dtype=dtype)
direction = "BIDIRECTIONAL"
node = ng.gru_sequence(
parameter_X,
parameter_H_t,
parameter_seq_len,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
direction,
)
assert node.get_type_name() == "GRUSequence"
assert node.get_output_size() == 2
activations = ["RELU", "tanh"]
activation_alpha = [1.0, 2.0, 3.0]
activation_beta = [3.0, 2.0, 1.0]
clip = 1.22
linear_before_reset = True
B_shape = [num_directions, 4 * hidden_size]
parameter_B = ng.parameter(B_shape, name="B", dtype=dtype)
node_param = ng.gru_sequence(
parameter_X,
parameter_H_t,
parameter_seq_len,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
direction,
activations,
activation_alpha,
activation_beta,
clip,
linear_before_reset,
)
assert node_param.get_type_name() == "GRUSequence"
assert node_param.get_output_size() == 2
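Note the B reassignment above: with linear_before_reset=True the GRU keeps a separate recurrence bias for the hidden gate, so per direction B grows from 3 * hidden_size to 4 * hidden_size. A tiny helper (hypothetical, for illustration only) captures the arithmetic:

def gru_bias_size(hidden_size, linear_before_reset):
    # z, r, h gate biases, plus a separate Rb_h term when linear_before_reset is set
    return (4 if linear_before_reset else 3) * hidden_size

assert gru_bias_size(128, False) == 384  # default B_shape in this test
assert gru_bias_size(128, True) == 512   # the reshaped B used with linear_before_reset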
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_gru_sequence_operator_reverse(dtype):
batch_size = 2
input_size = 4
hidden_size = 3
num_directions = 1
seq_length = 2
X_shape = [batch_size, seq_length, input_size]
H_t_shape = [batch_size, num_directions, hidden_size]
seq_len_shape = [batch_size]
W_shape = [num_directions, 3 * hidden_size, input_size]
R_shape = [num_directions, 3 * hidden_size, hidden_size]
B_shape = [num_directions, 3 * hidden_size]
parameter_X = ng.parameter(X_shape, name="X", dtype=dtype)
parameter_H_t = ng.parameter(H_t_shape, name="H_t", dtype=dtype)
parameter_seq_len = ng.parameter(seq_len_shape, name="seq_len", dtype=np.int32)
parameter_W = ng.parameter(W_shape, name="W", dtype=dtype)
parameter_R = ng.parameter(R_shape, name="R", dtype=dtype)
parameter_B = ng.parameter(B_shape, name="B", dtype=dtype)
direction = "REVERSE"
node_default = ng.gru_sequence(
parameter_X,
parameter_H_t,
parameter_seq_len,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
direction,
)
assert node_default.get_type_name() == "GRUSequence"
assert node_default.get_output_size() == 2
activations = ["RELU", "tanh"]
activation_alpha = [1.0, 2.0, 3.0]
activation_beta = [3.0, 2.0, 1.0]
clip = 1.22
linear_before_reset = True
B_shape = [num_directions, 4 * hidden_size]
parameter_B = ng.parameter(B_shape, name="B", dtype=dtype)
node_param = ng.gru_sequence(
parameter_X,
parameter_H_t,
parameter_seq_len,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
direction,
activations,
activation_alpha,
activation_beta,
clip,
linear_before_reset,
)
assert node_param.get_type_name() == "GRUSequence"
assert node_param.get_output_size() == 2

@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_gru_sequence_operator_forward(dtype):
batch_size = 2
input_size = 4
hidden_size = 3
num_directions = 1
seq_length = 2
X_shape = [batch_size, seq_length, input_size]
H_t_shape = [batch_size, num_directions, hidden_size]
seq_len_shape = [batch_size]
W_shape = [num_directions, 3 * hidden_size, input_size]
R_shape = [num_directions, 3 * hidden_size, hidden_size]
B_shape = [num_directions, 3 * hidden_size]
parameter_X = ng.parameter(X_shape, name="X", dtype=dtype)
parameter_H_t = ng.parameter(H_t_shape, name="H_t", dtype=dtype)
parameter_seq_len = ng.parameter(seq_len_shape, name="seq_len", dtype=np.int32)
parameter_W = ng.parameter(W_shape, name="W", dtype=dtype)
parameter_R = ng.parameter(R_shape, name="R", dtype=dtype)
parameter_B = ng.parameter(B_shape, name="B", dtype=dtype)
direction = "forward"
node_default = ng.gru_sequence(
parameter_X,
parameter_H_t,
parameter_seq_len,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
direction,
)
assert node_default.get_type_name() == "GRUSequence"
assert node_default.get_output_size() == 2
activations = ["RELU", "tanh"]
activation_alpha = [2.0]
activation_beta = [1.0]
clip = 0.5
linear_before_reset = True
B_shape = [num_directions, 4 * hidden_size]
parameter_B = ng.parameter(B_shape, name="B", dtype=dtype)
node = ng.gru_sequence(
parameter_X,
parameter_H_t,
parameter_seq_len,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
direction,
activations,
activation_alpha,
activation_beta,
clip,
linear_before_reset,
)
assert node.get_type_name() == "GRUSequence"
assert node.get_output_size() == 2

@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_rnn_sequence_operator_bidirectional(dtype):
batch_size = 1
input_size = 16
hidden_size = 128
num_directions = 2
seq_length = 2
X_shape = [batch_size, seq_length, input_size]
H_t_shape = [batch_size, num_directions, hidden_size]
seq_len_shape = [batch_size]
W_shape = [num_directions, hidden_size, input_size]
R_shape = [num_directions, hidden_size, hidden_size]
B_shape = [num_directions, hidden_size]
parameter_X = ng.parameter(X_shape, name="X", dtype=dtype)
parameter_H_t = ng.parameter(H_t_shape, name="H_t", dtype=dtype)
parameter_seq_len = ng.parameter(seq_len_shape, name="seq_len", dtype=np.int32)
parameter_W = ng.parameter(W_shape, name="W", dtype=dtype)
parameter_R = ng.parameter(R_shape, name="R", dtype=dtype)
parameter_B = ng.parameter(B_shape, name="B", dtype=dtype)
direction = "BIDIRECTIONAL"
node = ng.rnn_sequence(
parameter_X,
parameter_H_t,
parameter_seq_len,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
direction,
)
assert node.get_type_name() == "RNNSequence"
assert node.get_output_size() == 2
activations = ["RELU", "tanh"]
activation_alpha = [1.0, 2.0, 3.0]
activation_beta = [3.0, 2.0, 1.0]
clip = 1.22
node_param = ng.rnn_sequence(
parameter_X,
parameter_H_t,
parameter_seq_len,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
direction,
activations,
activation_alpha,
activation_beta,
clip,
)
assert node_param.get_type_name() == "RNNSequence"
assert node_param.get_output_size() == 2

@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_rnn_sequence_operator_reverse(dtype):
batch_size = 2
input_size = 4
hidden_size = 3
num_directions = 1
seq_length = 2
X_shape = [batch_size, seq_length, input_size]
H_t_shape = [batch_size, num_directions, hidden_size]
seq_len_shape = [batch_size]
W_shape = [num_directions, hidden_size, input_size]
R_shape = [num_directions, hidden_size, hidden_size]
B_shape = [num_directions, hidden_size]
parameter_X = ng.parameter(X_shape, name="X", dtype=dtype)
parameter_H_t = ng.parameter(H_t_shape, name="H_t", dtype=dtype)
parameter_seq_len = ng.parameter(seq_len_shape, name="seq_len", dtype=np.int32)
parameter_W = ng.parameter(W_shape, name="W", dtype=dtype)
parameter_R = ng.parameter(R_shape, name="R", dtype=dtype)
parameter_B = ng.parameter(B_shape, name="B", dtype=dtype)
direction = "REVERSE"
node_default = ng.rnn_sequence(
parameter_X,
parameter_H_t,
parameter_seq_len,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
direction,
)
assert node_default.get_type_name() == "RNNSequence"
assert node_default.get_output_size() == 2
activations = ["RELU", "tanh"]
activation_alpha = [1.0, 2.0, 3.0]
activation_beta = [3.0, 2.0, 1.0]
clip = 1.22
node_param = ng.rnn_sequence(
parameter_X,
parameter_H_t,
parameter_seq_len,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
direction,
activations,
activation_alpha,
activation_beta,
clip,
)
assert node_param.get_type_name() == "RNNSequence"
assert node_param.get_output_size() == 2

@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_rnn_sequence_operator_forward(dtype):
batch_size = 2
input_size = 4
hidden_size = 3
num_directions = 1
seq_length = 2
X_shape = [batch_size, seq_length, input_size]
H_t_shape = [batch_size, num_directions, hidden_size]
seq_len_shape = [batch_size]
W_shape = [num_directions, hidden_size, input_size]
R_shape = [num_directions, hidden_size, hidden_size]
B_shape = [num_directions, hidden_size]
parameter_X = ng.parameter(X_shape, name="X", dtype=dtype)
parameter_H_t = ng.parameter(H_t_shape, name="H_t", dtype=dtype)
parameter_seq_len = ng.parameter(seq_len_shape, name="seq_len", dtype=np.int32)
parameter_W = ng.parameter(W_shape, name="W", dtype=dtype)
parameter_R = ng.parameter(R_shape, name="R", dtype=dtype)
parameter_B = ng.parameter(B_shape, name="B", dtype=dtype)
direction = "forward"
node_default = ng.rnn_sequence(
parameter_X,
parameter_H_t,
parameter_seq_len,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
direction,
)
assert node_default.get_type_name() == "RNNSequence"
assert node_default.get_output_size() == 2
activations = ["RELU", "tanh"]
activation_alpha = [2.0]
activation_beta = [1.0]
clip = 0.5
node = ng.rnn_sequence(
parameter_X,
parameter_H_t,
parameter_seq_len,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
direction,
activations,
activation_alpha,
activation_beta,
clip,
)
assert node.get_type_name() == "RNNSequence"
assert node.get_output_size() == 2
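Taken together, the three sequence ops differ in their weight shapes only by the gate multiplier (with GRU's B further growing to 4 * hidden_size under linear_before_reset, as noted earlier). A small helper (hypothetical names, distilled from the shapes in these tests) makes the pattern explicit:

GATE_COUNT = {"LSTMSequence": 4, "GRUSequence": 3, "RNNSequence": 1}

def weight_shapes(op, num_directions, hidden_size, input_size):
    g = GATE_COUNT[op] * hidden_size
    return ([num_directions, g, input_size],   # W
            [num_directions, g, hidden_size],  # R
            [num_directions, g])               # B

assert weight_shapes("RNNSequence", 1, 3, 4) == ([1, 3, 4], [1, 3, 3], [1, 3])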

View File

@ -38,8 +38,7 @@ from tests import (
xfail_issue_39669,
xfail_issue_38726,
xfail_issue_40686,
xfail_issue_42779,
xfail_issue_42818)
xfail_issue_42779)
MODELS_ROOT_DIR = tests.MODEL_ZOO_DIR
@ -182,7 +181,6 @@ if len(zoo_models) > 0:
(xfail_issue_41815, "test_MSFT_opset11_tinyyolov3_yolov3_tiny_cpu"),
(xfail_issue_41815, "test_MSFT_opset10_yolov3_yolov3_cpu"),
(xfail_issue_42818, "test_MSFT_opset9_LSTM_Seq_lens_unpacked_model_cpu"),
]
for test_case in import_xfail_list + execution_xfail_list:
xfail, test_name = test_case

View File

@ -1689,7 +1689,7 @@ NGRAPH_TEST_F(${BACKEND_NAME}, RNNSequenceOp, onnx_model_rnn_fwd_mixed_seq_len_c
-0.18203181f,
0.9996245f,
});
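// Mixed per-batch seq_len accumulates slightly more rounding error than the
// uniform-length path, hence one extra tolerance bit for this test.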
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 3);
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 4);
}
NGRAPH_TEST_F(${BACKEND_NAME}, RNNSequenceOp, onnx_model_rnn_fwd_mixed_seq_len)

View File

@ -237,12 +237,12 @@ IE_CPU.onnx_model_rnn_fwd_bias_initial_h
IE_CPU.onnx_model_rnn_bidirectional
## RNN/GRU/LSTM Sequence: Output values mismatch - seq_lengths not supported
IE_CPU.onnx_model_lstm_fwd_mixed_seq_const
IE_CPU.onnx_model_lstm_reverse_mixed_seq_const
IE_CPU.onnx_model_rnn_fwd_mixed_seq_len
IE_CPU.onnx_model_rnn_fwd_mixed_seq_len_const
IE_CPU.onnx_model_gru_fwd_mixed_seq_len
IE_CPU.onnx_model_gru_fwd_mixed_seq_len_const
IE_GPU.onnx_model_lstm_fwd_mixed_seq_const
IE_GPU.onnx_model_lstm_reverse_mixed_seq_const
IE_GPU.onnx_model_rnn_fwd_mixed_seq_len
IE_GPU.onnx_model_rnn_fwd_mixed_seq_len_const
IE_GPU.onnx_model_gru_fwd_mixed_seq_len
IE_GPU.onnx_model_gru_fwd_mixed_seq_len_const
#-------------------------------------------------------------------------------

View File

@ -770,74 +770,171 @@ protected:
case OP_TYPEID::LSTMSequence_v5:
{
auto lstm_seq = static_cast<const op::v5::LSTMSequence*>(&node);
runtime::reference::lstm_sequence<T>(args[0]->get_data_ptr<char>(),
args[0]->get_shape(),
args[1]->get_data_ptr<char>(),
args[1]->get_shape(),
args[2]->get_data_ptr<char>(),
args[2]->get_shape(),
args[3]->get_data_ptr<char>(),
args[3]->get_shape(),
args[4]->get_data_ptr<char>(),
args[4]->get_shape(),
args[5]->get_data_ptr<char>(),
args[5]->get_shape(),
args[6]->get_data_ptr<char>(),
args[6]->get_shape(),
out[0]->get_data_ptr<char>(),
out[1]->get_data_ptr<char>(),
out[2]->get_data_ptr<char>(),
lstm_seq->get_activations()[0],
lstm_seq->get_activations()[1],
lstm_seq->get_activations()[2],
lstm_seq->get_clip(),
lstm_seq->get_direction());
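// The seq_lengths input (args[3]) may be a 32- or 64-bit integer tensor,
// so the reference kernel is now dispatched on its element type.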
auto type = args[3]->get_element_type();
if (type == element::i64 || type == element::u64)
{
runtime::reference::lstm_sequence<T, int64_t>(args[0]->get_data_ptr<char>(),
args[0]->get_shape(),
args[1]->get_data_ptr<char>(),
args[1]->get_shape(),
args[2]->get_data_ptr<char>(),
args[2]->get_shape(),
args[3]->get_data_ptr<char>(),
args[3]->get_shape(),
args[4]->get_data_ptr<char>(),
args[4]->get_shape(),
args[5]->get_data_ptr<char>(),
args[5]->get_shape(),
args[6]->get_data_ptr<char>(),
args[6]->get_shape(),
out[0]->get_data_ptr<char>(),
out[1]->get_data_ptr<char>(),
out[2]->get_data_ptr<char>(),
lstm_seq->get_activations()[0],
lstm_seq->get_activations()[1],
lstm_seq->get_activations()[2],
lstm_seq->get_clip(),
lstm_seq->get_direction());
}
else if (type == element::i32 || type == element::u32)
{
runtime::reference::lstm_sequence<T, int32_t>(args[0]->get_data_ptr<char>(),
args[0]->get_shape(),
args[1]->get_data_ptr<char>(),
args[1]->get_shape(),
args[2]->get_data_ptr<char>(),
args[2]->get_shape(),
args[3]->get_data_ptr<char>(),
args[3]->get_shape(),
args[4]->get_data_ptr<char>(),
args[4]->get_shape(),
args[5]->get_data_ptr<char>(),
args[5]->get_shape(),
args[6]->get_data_ptr<char>(),
args[6]->get_shape(),
out[0]->get_data_ptr<char>(),
out[1]->get_data_ptr<char>(),
out[2]->get_data_ptr<char>(),
lstm_seq->get_activations()[0],
lstm_seq->get_activations()[1],
lstm_seq->get_activations()[2],
lstm_seq->get_clip(),
lstm_seq->get_direction());
}
else
{
std::stringstream ss;
ss << "unsupported element type " << type << " op LSTMSequence";
throw std::runtime_error(ss.str());
}
break;
}
case OP_TYPEID::GRUSequence_v5:
{
auto gru_seq = static_cast<const op::v5::GRUSequence*>(&node);
runtime::reference::gru_sequence<T>(args[0]->get_data_ptr<char>(),
args[0]->get_shape(),
args[1]->get_data_ptr<char>(),
args[1]->get_shape(),
args[2]->get_data_ptr<char>(),
args[2]->get_shape(),
args[3]->get_data_ptr<char>(),
args[3]->get_shape(),
args[4]->get_data_ptr<char>(),
args[4]->get_shape(),
args[5]->get_data_ptr<char>(),
args[5]->get_shape(),
out[0]->get_data_ptr<char>(),
out[1]->get_data_ptr<char>(),
gru_seq->get_activations()[0],
gru_seq->get_activations()[1],
gru_seq->get_clip(),
gru_seq->get_direction(),
gru_seq->get_linear_before_reset());
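// Same seq_lengths element-type dispatch as LSTMSequence, with
// seq_lengths arriving at args[2] for GRUSequence.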
auto type = args[2]->get_element_type();
if (type == element::i64 || type == element::u64)
{
runtime::reference::gru_sequence<T, int64_t>(args[0]->get_data_ptr<char>(),
args[0]->get_shape(),
args[1]->get_data_ptr<char>(),
args[1]->get_shape(),
args[2]->get_data_ptr<char>(),
args[2]->get_shape(),
args[3]->get_data_ptr<char>(),
args[3]->get_shape(),
args[4]->get_data_ptr<char>(),
args[4]->get_shape(),
args[5]->get_data_ptr<char>(),
args[5]->get_shape(),
out[0]->get_data_ptr<char>(),
out[1]->get_data_ptr<char>(),
gru_seq->get_activations()[0],
gru_seq->get_activations()[1],
gru_seq->get_clip(),
gru_seq->get_direction(),
gru_seq->get_linear_before_reset());
}
else if (type == element::i32 || type == element::u32)
{
runtime::reference::gru_sequence<T, int32_t>(args[0]->get_data_ptr<char>(),
args[0]->get_shape(),
args[1]->get_data_ptr<char>(),
args[1]->get_shape(),
args[2]->get_data_ptr<char>(),
args[2]->get_shape(),
args[3]->get_data_ptr<char>(),
args[3]->get_shape(),
args[4]->get_data_ptr<char>(),
args[4]->get_shape(),
args[5]->get_data_ptr<char>(),
args[5]->get_shape(),
out[0]->get_data_ptr<char>(),
out[1]->get_data_ptr<char>(),
gru_seq->get_activations()[0],
gru_seq->get_activations()[1],
gru_seq->get_clip(),
gru_seq->get_direction(),
gru_seq->get_linear_before_reset());
}
else
{
std::stringstream ss;
ss << "unsupported element type " << type << " op GRUSequence";
throw std::runtime_error(ss.str());
}
break;
}
case OP_TYPEID::RNNSequence_v5:
{
auto rnn_seq = static_cast<const op::v5::RNNSequence*>(&node);
runtime::reference::rnn_sequence<T>(args[0]->get_data_ptr<char>(),
args[0]->get_shape(),
args[1]->get_data_ptr<char>(),
args[1]->get_shape(),
args[2]->get_data_ptr<char>(),
args[2]->get_shape(),
args[3]->get_data_ptr<char>(),
args[3]->get_shape(),
args[4]->get_data_ptr<char>(),
args[4]->get_shape(),
args[5]->get_data_ptr<char>(),
args[5]->get_shape(),
out[0]->get_data_ptr<char>(),
out[1]->get_data_ptr<char>(),
rnn_seq->get_activations()[0],
rnn_seq->get_clip(),
rnn_seq->get_direction());
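// Same seq_lengths element-type dispatch again; RNNSequence also
// carries seq_lengths at args[2].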
auto type = args[2]->get_element_type();
if (type == element::i64 || type == element::u64)
{
runtime::reference::rnn_sequence<T, int64_t>(args[0]->get_data_ptr<char>(),
args[0]->get_shape(),
args[1]->get_data_ptr<char>(),
args[1]->get_shape(),
args[2]->get_data_ptr<char>(),
args[2]->get_shape(),
args[3]->get_data_ptr<char>(),
args[3]->get_shape(),
args[4]->get_data_ptr<char>(),
args[4]->get_shape(),
args[5]->get_data_ptr<char>(),
args[5]->get_shape(),
out[0]->get_data_ptr<char>(),
out[1]->get_data_ptr<char>(),
rnn_seq->get_activations()[0],
rnn_seq->get_clip(),
rnn_seq->get_direction());
}
else if (type == element::i32 || type == element::u32)
{
runtime::reference::rnn_sequence<T, int32_t>(args[0]->get_data_ptr<char>(),
args[0]->get_shape(),
args[1]->get_data_ptr<char>(),
args[1]->get_shape(),
args[2]->get_data_ptr<char>(),
args[2]->get_shape(),
args[3]->get_data_ptr<char>(),
args[3]->get_shape(),
args[4]->get_data_ptr<char>(),
args[4]->get_shape(),
args[5]->get_data_ptr<char>(),
args[5]->get_shape(),
out[0]->get_data_ptr<char>(),
out[1]->get_data_ptr<char>(),
rnn_seq->get_activations()[0],
rnn_seq->get_clip(),
rnn_seq->get_direction());
}
else
{
std::stringstream ss;
ss << "unsupported element type " << type << " op RNNSequence";
throw std::runtime_error(ss.str());
}
break;
}
case OP_TYPEID::Log:
@ -1171,6 +1268,15 @@ protected:
reverse->get_sequence_axis(),
args[1]->get_data_ptr<const int32_t>());
}
else if (node.get_input_element_type(1) == element::i64)
{
reference::reverse_sequence<T, int64_t>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
reverse->get_batch_axis(),
reverse->get_sequence_axis(),
args[1]->get_data_ptr<const int64_t>());
}
else
{
throw ngraph_error("only int32 and int64 indices are supported");
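For intuition, the semantics the new int64 branch must match: reverse only the first seq_lengths[b] entries along the sequence axis for each batch b, leaving the tail untouched. A NumPy sketch of that behavior (an illustration, not the reference kernel itself; batch axis 0 and sequence axis 1 assumed):

import numpy as np

def reverse_sequence(data, seq_lens, batch_axis=0, seq_axis=1):
    out = data.copy()
    data_v = np.moveaxis(data, (batch_axis, seq_axis), (0, 1))
    out_v = np.moveaxis(out, (batch_axis, seq_axis), (0, 1))
    for b, n in enumerate(seq_lens):
        out_v[b, :n] = data_v[b, :n][::-1]  # reverse only the valid prefix
    return out

x = np.arange(8).reshape(2, 4)
print(reverse_sequence(x, np.array([3, 2], dtype=np.int64)))
# [[2 1 0 3]
#  [5 4 6 7]]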

View File

@ -96,17 +96,6 @@ INTERPRETER.onnx_model_conv_integer_zero_point_zero
INTERPRETER.onnx_model_conv_integer_no_zero_point
INTERPRETER.onnx_model_conv_integer_pads
# GRU/RNN/LSTM Sequence: Output values mismatch - seq_lengths not supported
onnx_model_lstm_fwd_mixed_seq_const
onnx_model_lstm_reverse_mixed_seq_const
onnx_model_lstm_fwd_mixed_seq
onnx_model_lstm_mixed_seq_reverse
onnx_model_gru_fwd_mixed_seq_len
onnx_model_gru_fwd_mixed_seq_len_const
onnx_model_rnn_fwd_mixed_seq_len
onnx_model_rnn_fwd_mixed_seq_len_const
# Activation function hardsigmoid is not supported.
gru_cell_activation_function
lstm_cell_activaction_functions