[nGraph] Reorder nGraph LSTMSequence inputs and outputs dimensions (#560)

* Reorder nGraph LSTMSequence input/output dimensions

* Update nGraph Python API for LSTMSequence

* Reorder axes in ONNX importer LSTM

* Tests update

* Fix clang warning

* Use opset3 namespace

* Style apply

* Tests update

* Use opset1 namespace

* Remove usage of GetOutputElement in ONNX importer LSTM

* Remove opset0 header

* Use Node::output()
Katarzyna Mitrus 2020-05-29 13:29:18 +02:00 committed by GitHub
parent a4f13ae9fe
commit 5f8f9ec108
8 changed files with 231 additions and 143 deletions
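
In short: this PR flips nGraph's LSTMSequence from a sequence-major to a batch-major layout. X changes from [seq_length, batch_size, input_size] to [batch_size, seq_length, input_size], and the initial hidden/cell states change from [num_directions, batch_size, hidden_size] to [batch_size, num_directions, hidden_size]. A minimal numpy sketch of the layout swap (illustration only, not code from the commit; all sizes are hypothetical):

import numpy as np

seq_length, batch_size, input_size = 2, 1, 4
num_directions, hidden_size = 1, 3

# Old sequence-major input -> new batch-major input: swap the first two axes.
x_old = np.zeros((seq_length, batch_size, input_size))
x_new = x_old.transpose(1, 0, 2)
assert x_new.shape == (batch_size, seq_length, input_size)

# The same swap applies to the initial hidden and cell states.
h_old = np.zeros((num_directions, batch_size, hidden_size))
h_new = h_old.transpose(1, 0, 2)
assert h_new.shape == (batch_size, num_directions, hidden_size)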

View File

@@ -472,11 +472,11 @@ def lstm_sequence(
) -> Node:
"""Return a node which performs LSTMSequence operation.
:param X: The input tensor. Shape: [seq_length, batch_size, input_size].
:param X: The input tensor. Shape: [batch_size, seq_length, input_size].
:param initial_hidden_state: The hidden state tensor.
Shape: [num_directions, batch_size, hidden_size].
Shape: [batch_size, num_directions, hidden_size].
:param initial_cell_state: The cell state tensor.
Shape: [num_directions, batch_size, hidden_size].
Shape: [batch_size, num_directions, hidden_size].
:param sequence_lengths: Specifies real sequence lengths for each batch element.
Shape: [batch_size]. Integer type.
:param W: Tensor with weights for matrix multiplication operation with input portion of data.
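
A hedged usage sketch of the updated Python binding with the new batch-major shapes. The first seven arguments follow the docstring above; the trailing hidden_size and direction arguments (and the "forward" literal) are assumptions, since this hunk doesn't show the full signature:

import numpy as np
import ngraph as ng

batch_size, seq_length, input_size = 1, 2, 4
num_directions, hidden_size = 1, 3

X = ng.parameter([batch_size, seq_length, input_size], name="X", dtype=np.float32)
H_t = ng.parameter([batch_size, num_directions, hidden_size], name="H_t", dtype=np.float32)
C_t = ng.parameter([batch_size, num_directions, hidden_size], name="C_t", dtype=np.float32)
seq_lengths = ng.parameter([batch_size], name="seq_lengths", dtype=np.int32)
W = ng.parameter([num_directions, 4 * hidden_size, input_size], name="W", dtype=np.float32)
R = ng.parameter([num_directions, 4 * hidden_size, hidden_size], name="R", dtype=np.float32)
B = ng.parameter([num_directions, 4 * hidden_size], name="B", dtype=np.float32)

node = ng.lstm_sequence(X, H_t, C_t, seq_lengths, W, R, B, hidden_size, "forward")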

View File

@@ -258,9 +258,9 @@ def test_lstm_sequence_operator_bidirectional(dtype):
num_directions = 2
seq_length = 2
X_shape = [seq_length, batch_size, input_size]
H_t_shape = [num_directions, batch_size, hidden_size]
C_t_shape = [num_directions, batch_size, hidden_size]
X_shape = [batch_size, seq_length, input_size]
H_t_shape = [batch_size, num_directions, hidden_size]
C_t_shape = [batch_size, num_directions, hidden_size]
seq_len_shape = [batch_size]
W_shape = [num_directions, 4 * hidden_size, input_size]
R_shape = [num_directions, 4 * hidden_size, hidden_size]
@@ -323,9 +323,9 @@ def test_lstm_sequence_operator_reverse(dtype):
num_directions = 1
seq_length = 2
X_shape = [seq_length, batch_size, input_size]
H_t_shape = [num_directions, batch_size, hidden_size]
C_t_shape = [num_directions, batch_size, hidden_size]
X_shape = [batch_size, seq_length, input_size]
H_t_shape = [batch_size, num_directions, hidden_size]
C_t_shape = [batch_size, num_directions, hidden_size]
seq_len_shape = [batch_size]
W_shape = [num_directions, 4 * hidden_size, input_size]
R_shape = [num_directions, 4 * hidden_size, hidden_size]
@@ -389,9 +389,9 @@ def test_lstm_sequence_operator_forward(dtype):
num_directions = 1
seq_length = 2
X_shape = [seq_length, batch_size, input_size]
H_t_shape = [num_directions, batch_size, hidden_size]
C_t_shape = [num_directions, batch_size, hidden_size]
X_shape = [batch_size, seq_length, input_size]
H_t_shape = [batch_size, num_directions, hidden_size]
C_t_shape = [batch_size, num_directions, hidden_size]
seq_len_shape = [batch_size]
W_shape = [num_directions, 4 * hidden_size, input_size]
R_shape = [num_directions, 4 * hidden_size, hidden_size]
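
All three tests repeat the same shape recipe; a hypothetical helper that captures it. The 4 * hidden_size factor comes from the four LSTM gates being stacked along one axis of W, R, and B:

def lstm_shapes(batch_size, seq_length, input_size, hidden_size, num_directions):
    """Batch-major shapes used by the updated tests (illustrative only)."""
    return {
        "X": [batch_size, seq_length, input_size],
        "H_t": [batch_size, num_directions, hidden_size],
        "C_t": [batch_size, num_directions, hidden_size],
        "seq_len": [batch_size],
        "W": [num_directions, 4 * hidden_size, input_size],   # 4 stacked gates
        "R": [num_directions, 4 * hidden_size, hidden_size],
    }

assert lstm_shapes(1, 2, 4, 3, 2)["W"] == [2, 12, 4]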

View File

@@ -24,13 +24,13 @@
#include "default_opset.hpp"
#include "exceptions.hpp"
#include "lstm.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/frontend/onnx_import/op/lstm.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
@@ -71,7 +71,8 @@ namespace ngraph
// ----- Mandatory inputs ------
// Packed input sequences. Shape: [seq_length, batch_size, input_size]
m_map[LSTMInput::LSTM_INPUT_X] = ng_inputs.at(0);
m_map[LSTMInput::LSTM_INPUT_X] =
builder::opset1::reorder_axes(ng_inputs.at(0), {1, 0, 2});
// Weight tensor for the gates.
// Shape: [num_directions, 4*hidden_size, input_size]
m_map[LSTMInput::LSTM_INPUT_W] = ng_inputs.at(1);
@@ -82,7 +83,7 @@ namespace ngraph
const std::size_t hidden_size =
m_map[LSTMInput::LSTM_INPUT_R]->get_shape().back();
const std::size_t batch_size =
m_map[LSTMInput::LSTM_INPUT_X]->get_shape().at(1);
m_map[LSTMInput::LSTM_INPUT_X]->get_shape().at(0);
const std::size_t num_directions =
m_map[LSTMInput::LSTM_INPUT_W]->get_shape().front();
@@ -115,33 +116,35 @@ namespace ngraph
Shape{batch_size},
std::vector<std::int32_t>(
batch_size,
m_map[LSTMInput::LSTM_INPUT_X]->get_shape().at(0)));
m_map[LSTMInput::LSTM_INPUT_X]->get_shape().at(1)));
}
// The initial value of the hidden.
// Shape [num_directions, batch_size, hidden_size]
if (ng_inputs.size() > 5 && !ng_inputs.at(5)->is_null())
{
m_map[LSTMInput::LSTM_INPUT_INIT_H] = ng_inputs.at(5);
m_map[LSTMInput::LSTM_INPUT_INIT_H] =
builder::opset1::reorder_axes(ng_inputs.at(5), {1, 0, 2});
}
else
{
m_map[LSTMInput::LSTM_INPUT_INIT_H] = default_opset::Constant::create(
element::f32,
Shape{num_directions, batch_size, hidden_size},
std::vector<float>(num_directions * batch_size * hidden_size, 0.f));
Shape{batch_size, num_directions, hidden_size},
std::vector<float>(batch_size * num_directions * hidden_size, 0.f));
}
// The initial value of the cell.
// Shape [num_directions, batch_size, hidden_size]
if (ng_inputs.size() > 6 && !ng_inputs.at(6)->is_null())
{
m_map[LSTMInput::LSTM_INPUT_INIT_C] = ng_inputs.at(6);
m_map[LSTMInput::LSTM_INPUT_INIT_C] =
builder::opset1::reorder_axes(ng_inputs.at(6), {1, 0, 2});
}
else
{
m_map[LSTMInput::LSTM_INPUT_INIT_C] = default_opset::Constant::create(
element::f32,
Shape{num_directions, batch_size, hidden_size},
std::vector<float>(num_directions * batch_size * hidden_size, 0.f));
Shape{batch_size, num_directions, hidden_size},
std::vector<float>(batch_size * num_directions * hidden_size, 0.f));
}
// The weight tensor for peepholes. Shape [num_directions, 3*hidden_size]
if (ng_inputs.size() > 7 && !ng_inputs.at(7)->is_null())
@@ -239,9 +242,14 @@ namespace ngraph
attributes.m_activations,
attributes.m_clip_threshold,
attributes.m_input_forget);
return {std::make_shared<ngraph::opset0::GetOutputElement>(lstmSequence, 0),
std::make_shared<ngraph::opset0::GetOutputElement>(lstmSequence, 1),
std::make_shared<ngraph::opset0::GetOutputElement>(lstmSequence, 2)};
const auto Y = lstmSequence->output(0);
const auto Y_h = lstmSequence->output(1);
const auto Y_c = lstmSequence->output(2);
return {builder::opset1::reorder_axes(Y, {2, 1, 0, 3}),
builder::opset1::reorder_axes(Y_h, {1, 0, 2}),
builder::opset1::reorder_axes(Y_c, {1, 0, 2})};
}
} // namespace set_1
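
The importer keeps ONNX's sequence-major layout at its boundary: inputs are transposed into the op's new batch-major layout on the way in, and outputs are transposed back on the way out. A numpy sketch of the two permutations (hypothetical sizes, purely to illustrate the axis orders):

import numpy as np

seq_length, batch_size, input_size = 2, 1, 4
num_directions, hidden_size = 1, 3

# Input side: ONNX X is [seq_length, batch_size, input_size];
# reorder_axes(..., {1, 0, 2}) makes it batch-major.
x_onnx = np.zeros((seq_length, batch_size, input_size))
x_ng = x_onnx.transpose(1, 0, 2)
assert x_ng.shape == (batch_size, seq_length, input_size)

# Output side: nGraph Y is [batch, num_dirs, seq, hidden];
# reorder_axes(Y, {2, 1, 0, 3}) restores ONNX's [seq, num_dirs, batch, hidden].
y_ng = np.zeros((batch_size, num_directions, seq_length, hidden_size))
y_onnx = y_ng.transpose(2, 1, 0, 3)
assert y_onnx.shape == (seq_length, num_directions, batch_size, hidden_size)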

View File

@@ -20,19 +20,13 @@
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/frontend/onnx_import/utils/reshape.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/opsets/opset1.hpp"
using namespace ngraph;
using namespace std;
constexpr NodeTypeInfo op::LSTMSequence::type_info;
constexpr NodeTypeInfo op::v0::LSTMSequence::type_info;
bool ngraph::op::v0::LSTMSequence::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("hidden_size", m_hidden_size);
@@ -46,7 +40,7 @@ bool ngraph::op::v0::LSTMSequence::visit_attributes(AttributeVisitor& visitor)
visitor.on_attribute("weights_format", m_weights_format);
return true;
}
NodeVector op::LSTMSequence::decompose_op() const
NodeVector op::v0::LSTMSequence::decompose_op() const
{
NodeVector results;
if (m_direction == direction::FORWARD || m_direction == direction::REVERSE)
@@ -60,55 +54,55 @@ NodeVector op::LSTMSequence::decompose_op() const
// Stack together respective outputs from both forward and reverse passes.
shared_ptr<Node> Y{
make_shared<op::Concat>(NodeVector{fwd_results.at(0), rev_results.at(0)}, 1)};
make_shared<opset1::Concat>(NodeVector{fwd_results.at(0), rev_results.at(0)}, 1)};
shared_ptr<Node> Y_h{
make_shared<op::Concat>(NodeVector{fwd_results.at(1), rev_results.at(1)}, 0)};
make_shared<opset1::Concat>(NodeVector{fwd_results.at(1), rev_results.at(1)}, 1)};
shared_ptr<Node> Y_c{
make_shared<op::Concat>(NodeVector{fwd_results.at(2), rev_results.at(2)}, 0)};
make_shared<opset1::Concat>(NodeVector{fwd_results.at(2), rev_results.at(2)}, 1)};
results = NodeVector{Y, Y_h, Y_c};
}
return results;
}
shared_ptr<Node> op::LSTMSequence::clone_with_new_inputs(const OutputVector& new_args) const
shared_ptr<Node> op::v0::LSTMSequence::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
if (new_args.size() == 8)
{
return make_shared<LSTMSequence>(new_args.at(0), // X
new_args.at(1), // initial_hidden_state
new_args.at(2), // initial_cell_state
new_args.at(3), // sequence_lengths
new_args.at(4), // W
new_args.at(5), // R
new_args.at(6), // B
new_args.at(7), // P
m_hidden_size,
m_direction,
m_weights_format,
m_activations_alpha,
m_activations_beta,
m_activations,
m_clip_threshold,
m_input_forget);
return make_shared<op::v0::LSTMSequence>(new_args.at(0), // X
new_args.at(1), // initial_hidden_state
new_args.at(2), // initial_cell_state
new_args.at(3), // sequence_lengths
new_args.at(4), // W
new_args.at(5), // R
new_args.at(6), // B
new_args.at(7), // P
m_hidden_size,
m_direction,
m_weights_format,
m_activations_alpha,
m_activations_beta,
m_activations,
m_clip_threshold,
m_input_forget);
}
else if (new_args.size() == 7)
{
return make_shared<LSTMSequence>(new_args.at(0), // X
new_args.at(1), // initial_hidden_state
new_args.at(2), // initial_cell_state
new_args.at(3), // sequence_lengths
new_args.at(4), // W
new_args.at(5), // R
new_args.at(6), // B
m_hidden_size,
m_direction,
m_weights_format,
m_activations_alpha,
m_activations_beta,
m_activations,
m_clip_threshold,
m_input_forget);
return make_shared<op::v0::LSTMSequence>(new_args.at(0), // X
new_args.at(1), // initial_hidden_state
new_args.at(2), // initial_cell_state
new_args.at(3), // sequence_lengths
new_args.at(4), // W
new_args.at(5), // R
new_args.at(6), // B
m_hidden_size,
m_direction,
m_weights_format,
m_activations_alpha,
m_activations_beta,
m_activations,
m_clip_threshold,
m_input_forget);
}
else
{
@@ -116,46 +110,44 @@ shared_ptr<Node> op::LSTMSequence::clone_with_new_inputs(const OutputVector& new
}
}
shared_ptr<Node> op::LSTMSequence::get_masked_node(const Output<Node>& data,
int32_t time_step,
size_t batch_axis,
const Output<Node>& default_value) const
shared_ptr<Node> op::v0::LSTMSequence::get_masked_node(const Output<Node>& data,
int32_t time_step,
size_t batch_axis,
const Output<Node>& default_value) const
{
Output<Node> mask_value = default_value;
// Create zero mask value node.
if (!mask_value.get_node_shared_ptr())
{
mask_value = op::Constant::create(data.get_element_type(),
data.get_shape(),
vector<float>(shape_size(data.get_shape()), 0.f));
mask_value = opset1::Constant::create(data.get_element_type(),
data.get_shape(),
vector<float>(shape_size(data.get_shape()), 0.f));
}
// Create predicate nodes. The condition is whether current time step value
// is greater than sequence length for respective batch inputs.
shared_ptr<Node> curr_time_step_node = op::Constant::create(
shared_ptr<Node> curr_time_step_node = opset1::Constant::create(
element::i32, data.get_shape(), vector<int32_t>(shape_size(data.get_shape()), time_step));
Output<Node> batch_seq_length =
builder::legacy_broadcast_for_binary_operation(
curr_time_step_node, input_value(3).get_node_shared_ptr(), batch_axis)
.at(1);
Output<Node> batch_seq_length = builder::opset1::legacy_broadcast_for_binary_operation(
curr_time_step_node, input_value(3).get_node_shared_ptr(), batch_axis);
// Create mask node deciding whether or not to mask batch data.
shared_ptr<Node> mask_condition =
make_shared<op::Greater>(curr_time_step_node, batch_seq_length);
make_shared<opset1::Greater>(curr_time_step_node, batch_seq_length);
// Select values depending on mask_condition.
// Select(<condition>, <true_value>, <false_value>)
return make_shared<op::Select>(mask_condition, mask_value, data);
return make_shared<opset1::Select>(mask_condition, mask_value, data);
}
NodeVector op::LSTMSequence::lstm_pass(bool is_reverse) const
NodeVector op::v0::LSTMSequence::lstm_pass(bool is_reverse) const
{
// ------ VARIABLE NAMES AND ACRONYM DEFINITIONS ------
// The names used below are analogous to the ones used in the ONNX documentation.
//
// ------ INPUTS ------
// X - The input tensor. [seq_length, batch_size, input_size]
// X - The input tensor. [batch_size, seq_length, input_size]
// W - The weight tensor. [num_directions, 4*hidden_size, input_size]
// R - The recurrence weight tensor. [num_directions, 4*hidden_size, hidden_size]
// B - The bias tensor for input gate. [num_directions, 8*hidden_size]
@@ -167,14 +159,14 @@ NodeVector op::LSTMSequence::lstm_pass(bool is_reverse) const
// c - cell gate
// t - time step (t-1 means previous time step)
// ------ VARIABLE NAMES ------
// H_t - Hidden state vector at current time step.
// C_t - Cell state vector at current time step.
// H_t - Hidden state vector at current time step. [batch_size, num_directions, hidden_size]
// C_t - Cell state vector at current time step. [batch_size, num_directions, hidden_size]
// h_list - The list of hidden states at all processed time steps.
NodeVector h_list;
shared_ptr<Node> X = input_value(0).get_node_shared_ptr();
shared_ptr<Node> H_t = prepare_input(input_value(1), is_reverse);
shared_ptr<Node> C_t = prepare_input(input_value(2), is_reverse);
shared_ptr<Node> H_t = prepare_input(input_value(1), is_reverse, 1);
shared_ptr<Node> C_t = prepare_input(input_value(2), is_reverse, 1);
shared_ptr<Node> seq_lengths = input_value(3).get_node_shared_ptr();
shared_ptr<Node> W = prepare_input(input_value(4), is_reverse);
shared_ptr<Node> R = prepare_input(input_value(5), is_reverse);
@@ -183,34 +175,34 @@ NodeVector op::LSTMSequence::lstm_pass(bool is_reverse) const
if (is_reverse)
{
X = make_shared<op::ReverseSequence>(X, seq_lengths, 1 /*batch_axis*/, 0 /*seq_axis*/);
X = make_shared<opset1::ReverseSequence>(X, seq_lengths, 0 /*batch_axis*/, 1 /*seq_axis*/);
}
NodeVector in_seqs = builder::split(X, X->get_shape().at(0));
NodeVector in_seqs = builder::opset1::split(X, X->get_shape().at(1), 1);
for (auto& in_x : in_seqs)
{
// remove first empty dim, after above split.
in_x = builder::squeeze(in_x);
// Remove empty dim, after above split.
in_x = builder::opset1::squeeze(in_x, {1});
}
int32_t time_step{1};
for (const auto& in_x : in_seqs)
{
shared_ptr<Node> lstm_cell = make_shared<op::LSTMCell>(in_x,
H_t,
C_t,
W,
R,
B,
P,
m_hidden_size,
m_weights_format,
m_activations,
m_activations_alpha,
m_activations_beta,
m_clip_threshold,
m_input_forget);
shared_ptr<Node> lstm_cell = make_shared<opset1::LSTMCell>(in_x,
H_t,
C_t,
W,
R,
B,
P,
m_hidden_size,
m_weights_format,
m_activations,
m_activations_alpha,
m_activations_beta,
m_clip_threshold,
m_input_forget);
Output<Node> H = lstm_cell->output(0);
Output<Node> C = lstm_cell->output(1);
@@ -220,7 +212,7 @@ NodeVector op::LSTMSequence::lstm_pass(bool is_reverse) const
// Mask hidden state tensor in order to handle mixed sequence lengths.
// This results in zeroing out values in batches with sequence shorter
// than current time_step.
h_list.push_back(get_masked_node(builder::expand_dims(H), time_step, 1));
h_list.push_back(get_masked_node(builder::opset1::expand_dims(H, 1), time_step, 0));
// Reference implementation in ONNX Runtime doesn't mask values of Y_h
// and Y_c outputs, thus here we make sure that only appropriate batches
// (with respect to their sequence length) are updated. Those batches which
@@ -230,36 +222,38 @@ NodeVector op::LSTMSequence::lstm_pass(bool is_reverse) const
time_step++;
}
// The tensor that concatenates all the intermediate output values of the hidden state.
// It has shape [seq_length, batch_size, hidden_size]
shared_ptr<Node> Y{make_shared<op::Concat>(h_list, 0)};
// It has shape [batch_size, seq_length, hidden_size]
shared_ptr<Node> Y{make_shared<opset1::Concat>(h_list, 1)};
// Get back the original order of the output data.
if (is_reverse)
{
Y = make_shared<op::ReverseSequence>(Y, seq_lengths, 1 /*batch_axis*/, 0 /*seq_axis*/);
Y = make_shared<opset1::ReverseSequence>(Y, seq_lengths, 0 /*batch_axis*/, 1 /*seq_axis*/);
}
// Expand Y so that it has the expected shape:
// [seq_length, num_directions, batch_size, hidden_size]
Y = builder::expand_dims(Y, 1);
// [batch_size, num_directions, seq_length, hidden_size]
Y = builder::opset1::expand_dims(Y, 1);
// expand H_t and C_t so that they have the expected shape:
// [num_directions, batch_size, hidden_size]
auto Y_h = builder::expand_dims(H_t);
auto Y_c = builder::expand_dims(C_t);
// [batch_size, num_directions, hidden_size]
auto Y_h = builder::opset1::expand_dims(H_t, 1);
auto Y_c = builder::opset1::expand_dims(C_t, 1);
return {Y, Y_h, Y_c};
}
shared_ptr<Node> op::LSTMSequence::prepare_input(Output<Node> node, bool is_reverse) const
shared_ptr<Node> op::v0::LSTMSequence::prepare_input(Output<Node> node,
bool is_reverse,
size_t num_direction_axis) const
{
// In bidirectional mode inputs are stacked together, so we must split them.
shared_ptr<Node> tmp = node.get_node_shared_ptr();
if (m_direction == direction::BIDIRECTIONAL)
{
tmp = builder::split(node, 2).at(is_reverse ? 1 : 0);
tmp = builder::opset1::split(node, 2, num_direction_axis).at(is_reverse ? 1 : 0);
}
// Since we have forward LSTM we can squeeze `num_directions` axis from inputs.
return builder::squeeze(tmp);
return builder::opset1::squeeze(tmp, {num_direction_axis});
}
namespace ngraph
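
The reworked lstm_pass slices time steps along axis 1 instead of axis 0 and concatenates the per-step hidden states back along axis 1. A hedged Python sketch of just that control flow (step_fn stands in for LSTMCell; this is not the actual implementation):

import numpy as np

def lstm_pass_sketch(X, h0, step_fn):
    """Batch-major time loop: X is [batch, seq, input], h0 is [batch, hidden]."""
    h_list, h_t = [], h0
    for t in range(X.shape[1]):         # split(X, seq_length, 1): slice per step
        x_t = X[:, t, :]                # squeeze the time axis after the split
        h_t = step_fn(x_t, h_t)         # one cell step -> [batch, hidden]
        h_list.append(h_t[:, None, :])  # expand_dims(H, 1) before concatenation
    return np.concatenate(h_list, axis=1)  # Y: [batch, seq, hidden]

# Toy usage with a linear stand-in "cell" (hypothetical weights):
batch, seq, inp, hid = 8, 6, 4, 5
Wx, Wh = np.ones((inp, hid)), np.ones((hid, hid))
Y = lstm_pass_sketch(np.ones((batch, seq, inp)), np.zeros((batch, hid)),
                     lambda x, h: np.tanh(x @ Wx + h @ Wh))
assert Y.shape == (batch, seq, hid)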

View File

@@ -173,7 +173,9 @@ namespace ngraph
NodeVector lstm_pass(bool is_reverse = false) const;
// Split (bi-directional) and squeeze input data to remove 'num_direction' dimension.
std::shared_ptr<Node> prepare_input(Output<Node> node, bool is_reverse) const;
std::shared_ptr<Node> prepare_input(Output<Node> node,
bool is_reverse,
size_t num_direction_axis = 0) const;
std::vector<float> m_activations_alpha;
std::vector<float> m_activations_beta;
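
The new num_direction_axis parameter exists because the num_directions axis of the state inputs moved from position 0 to position 1, while W, R, and B keep it at position 0. A numpy sketch of the split-then-squeeze this performs on a bidirectional state (hypothetical sizes):

import numpy as np

batch_size, hidden_size = 8, 16
h0 = np.zeros((batch_size, 2, hidden_size))  # [batch, num_directions, hidden]

# split(node, 2, num_direction_axis).at(is_reverse ? 1 : 0), then squeeze:
h_fwd = h0[:, 0, :]  # forward-pass state, [batch, hidden]
h_rev = h0[:, 1, :]  # reverse-pass state, [batch, hidden]
assert h_fwd.shape == h_rev.shape == (batch_size, hidden_size)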

View File

@@ -1104,16 +1104,27 @@ TEST(attributes, lstm_cell_op)
TEST(attributes, lstm_sequence_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::LSTMSequence>();
const auto X = make_shared<op::Parameter>(element::f32, Shape{1, 2, 4});
const auto initial_hidden_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto initial_cell_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{2});
const auto W = make_shared<op::Parameter>(element::f32, Shape{1, 12, 4});
const auto R = make_shared<op::Parameter>(element::f32, Shape{1, 12, 3});
const auto B = make_shared<op::Parameter>(element::f32, Shape{1, 12});
const auto hidden_size = 3;
const auto lstm_direction = op::LSTMSequence::direction::FORWARD;
const auto batch_size = 4;
const auto num_directions = 2;
const auto seq_length = 8;
const auto input_size = 16;
const auto hidden_size = 64;
const auto X =
make_shared<op::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
const auto initial_hidden_state =
make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto initial_cell_state =
make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{batch_size});
const auto W = make_shared<op::Parameter>(element::f32,
Shape{num_directions, 4 * hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32,
Shape{num_directions, 4 * hidden_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});
const auto lstm_direction = op::LSTMSequence::direction::BIDIRECTIONAL;
const auto weights_format = op::LSTMWeightsFormat::ICOF;
const std::vector<float> activations_alpha = {1, 2, 3};
const std::vector<float> activations_beta = {4, 5, 6};

View File

@ -39,6 +39,7 @@ using namespace ngraph;
static std::string s_manifest = "${MANIFEST}";
// ONNX LSTM tests (implemented by nGraph LSTMCell and LSTMSequence)
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_lstm_fwd_with_clip)
{
auto function = onnx_import::import_onnx_model(

View File

@@ -21,16 +21,28 @@
using namespace std;
using namespace ngraph;
TEST(type_prop, lstm_sequence)
TEST(type_prop, lstm_sequence_forward)
{
const auto X = make_shared<op::Parameter>(element::f32, Shape{1, 2, 4});
const auto W = make_shared<op::Parameter>(element::f32, Shape{1, 12, 4});
const auto R = make_shared<op::Parameter>(element::f32, Shape{1, 12, 3});
const auto initial_hidden_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto initial_cell_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto B = make_shared<op::Parameter>(element::f32, Shape{1, 12});
const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{2});
const auto hidden_size = 3;
const auto batch_size = 8;
const auto num_directions = 1;
const auto seq_length = 6;
const auto input_size = 4;
const auto hidden_size = 128;
const auto X =
make_shared<op::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
const auto initial_hidden_state =
make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto initial_cell_state =
make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{batch_size});
const auto W = make_shared<op::Parameter>(element::f32,
Shape{num_directions, 4 * hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32,
Shape{num_directions, 4 * hidden_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});
const auto lstm_direction = op::LSTMSequence::direction::FORWARD;
const auto lstm_sequence = make_shared<op::LSTMSequence>(X,
initial_hidden_state,
@@ -40,7 +52,7 @@ TEST(type_prop, lstm_sequence)
R,
B,
hidden_size,
op::LSTMSequence::direction::FORWARD);
lstm_direction);
EXPECT_EQ(lstm_sequence->get_hidden_size(), hidden_size);
EXPECT_EQ(lstm_sequence->get_direction(), op::LSTMSequence::direction::FORWARD);
EXPECT_EQ(lstm_sequence->get_weights_format(), op::LSTMWeightsFormat::IFCO);
@@ -52,9 +64,69 @@ TEST(type_prop, lstm_sequence)
EXPECT_EQ(lstm_sequence->get_clip_threshold(), 0.f);
EXPECT_FALSE(lstm_sequence->get_input_forget());
EXPECT_EQ(lstm_sequence->get_output_element_type(0), element::f32);
EXPECT_EQ(lstm_sequence->get_output_shape(0), (Shape{1, 1, 2, 3}));
EXPECT_EQ(lstm_sequence->get_output_shape(0),
(Shape{batch_size, num_directions, seq_length, hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_element_type(1), element::f32);
EXPECT_EQ(lstm_sequence->get_output_shape(1), (Shape{1, 2, 3}));
EXPECT_EQ(lstm_sequence->get_output_shape(1), (Shape{batch_size, num_directions, hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_element_type(2), element::f32);
EXPECT_EQ(lstm_sequence->get_output_shape(2), (Shape{1, 2, 3}));
EXPECT_EQ(lstm_sequence->get_output_shape(2), (Shape{batch_size, num_directions, hidden_size}));
}
TEST(type_prop, lstm_sequence_bidirectional)
{
const auto batch_size = 24;
const auto num_directions = 2;
const auto seq_length = 12;
const auto input_size = 8;
const auto hidden_size = 256;
const auto X =
make_shared<op::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
const auto initial_hidden_state =
make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto initial_cell_state =
make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{batch_size});
const auto W = make_shared<op::Parameter>(element::f32,
Shape{num_directions, 4 * hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32,
Shape{num_directions, 4 * hidden_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});
const auto weights_format = op::LSTMWeightsFormat::FICO;
const auto lstm_direction = op::LSTMSequence::direction::BIDIRECTIONAL;
const std::vector<float> activations_alpha = {2.7, 7.0, 32.367};
const std::vector<float> activations_beta = {0.0, 5.49, 6.0};
const std::vector<std::string> activations = {"tanh", "sigmoid", "sigmoid"};
const auto lstm_sequence = make_shared<op::LSTMSequence>(X,
initial_hidden_state,
initial_cell_state,
sequence_lengths,
W,
R,
B,
hidden_size,
lstm_direction,
weights_format,
activations_alpha,
activations_beta,
activations);
EXPECT_EQ(lstm_sequence->get_hidden_size(), hidden_size);
EXPECT_EQ(lstm_sequence->get_direction(), op::LSTMSequence::direction::BIDIRECTIONAL);
EXPECT_EQ(lstm_sequence->get_weights_format(), op::LSTMWeightsFormat::FICO);
EXPECT_EQ(lstm_sequence->get_activations_alpha(), activations_alpha);
EXPECT_EQ(lstm_sequence->get_activations_beta(), activations_beta);
EXPECT_EQ(lstm_sequence->get_activations()[0], "tanh");
EXPECT_EQ(lstm_sequence->get_activations()[1], "sigmoid");
EXPECT_EQ(lstm_sequence->get_activations()[2], "sigmoid");
EXPECT_EQ(lstm_sequence->get_clip_threshold(), 0.f);
EXPECT_FALSE(lstm_sequence->get_input_forget());
EXPECT_EQ(lstm_sequence->get_output_element_type(0), element::f32);
EXPECT_EQ(lstm_sequence->get_output_shape(0),
(Shape{batch_size, num_directions, seq_length, hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_element_type(1), element::f32);
EXPECT_EQ(lstm_sequence->get_output_shape(1), (Shape{batch_size, num_directions, hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_element_type(2), element::f32);
EXPECT_EQ(lstm_sequence->get_output_shape(2), (Shape{batch_size, num_directions, hidden_size}));
}
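
For reference, the output shapes both tests assert, collected in one hedged helper (batch-major throughout; the function is illustrative, not part of the codebase):

def lstm_sequence_output_shapes(batch_size, num_directions, seq_length, hidden_size):
    """Expected LSTMSequence output shapes after this PR."""
    Y = (batch_size, num_directions, seq_length, hidden_size)  # output 0
    Y_h = (batch_size, num_directions, hidden_size)            # output 1
    Y_c = (batch_size, num_directions, hidden_size)            # output 2
    return Y, Y_h, Y_c

assert lstm_sequence_output_shapes(24, 2, 12, 256)[0] == (24, 2, 12, 256)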