[nGraph] Reorder nGraph LSTMSequence input and output dimensions (#560)

* Reorder nGraph LSTMSequence input/output dimensions
* Update nGraph Python API for LSTMSequence
* Reorder axes in ONNX importer LSTM
* Tests update
* Fix clang warning
* Use opset3 namespace
* Style apply
* Tests update
* Use opset1 namespace
* Remove usage of GetOutputElement in ONNX importer LSTM
* Remove opset0 header
* Use Node::output()
parent a4f13ae9fe
commit 5f8f9ec108
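In short, the public LSTMSequence interface moves from the ONNX-style time-major layout [seq_length, batch_size, input_size] to a batch-major layout [batch_size, seq_length, input_size] (and [batch_size, num_directions, hidden_size] for the states). A minimal NumPy sketch of the axis reordering the importer now performs — the variable names here are illustrative, not part of the patch:

import numpy as np

seq_length, batch_size, input_size = 2, 1, 4

# ONNX delivers X time-major: [seq_length, batch_size, input_size].
x_onnx = np.zeros((seq_length, batch_size, input_size), dtype=np.float32)

# The importer applies reorder_axes(X, {1, 0, 2}), so nGraph's LSTMSequence
# sees a batch-major tensor: [batch_size, seq_length, input_size].
x_ngraph = np.transpose(x_onnx, (1, 0, 2))
assert x_ngraph.shape == (batch_size, seq_length, input_size)

# States are reordered the same way: [num_directions, batch_size, hidden_size]
# becomes [batch_size, num_directions, hidden_size].
num_directions, hidden_size = 1, 3
h0_onnx = np.zeros((num_directions, batch_size, hidden_size), dtype=np.float32)
h0_ngraph = np.transpose(h0_onnx, (1, 0, 2))
assert h0_ngraph.shape == (batch_size, num_directions, hidden_size)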
@@ -472,11 +472,11 @@ def lstm_sequence(
 ) -> Node:
     """Return a node which performs LSTMSequence operation.
 
-    :param X: The input tensor. Shape: [seq_length, batch_size, input_size].
+    :param X: The input tensor. Shape: [batch_size, seq_length, input_size].
     :param initial_hidden_state: The hidden state tensor.
-                                 Shape: [num_directions, batch_size, hidden_size].
+                                 Shape: [batch_size, num_directions, hidden_size].
     :param initial_cell_state: The cell state tensor.
-                               Shape: [num_directions, batch_size, hidden_size].
+                               Shape: [batch_size, num_directions, hidden_size].
     :param sequence_lengths: Specifies real sequence lengths for each batch element.
                              Shape: [batch_size]. Integer type.
     :param W: Tensor with weights for matrix multiplication operation with input portion of data.
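As a usage sketch against the updated docstring — the parameter construction and the exact lstm_sequence argument order past the tensor inputs are assumptions inferred from the docstring and tests, not confirmed by this diff:

import numpy as np
import ngraph as ng  # assumes the nGraph Python API is installed

batch_size, seq_length, input_size = 1, 2, 4
num_directions, hidden_size = 1, 3

# Shapes follow the updated, batch-major docstring above.
X = ng.parameter([batch_size, seq_length, input_size], name="X", dtype=np.float32)
H_t = ng.parameter([batch_size, num_directions, hidden_size], name="H_t", dtype=np.float32)
C_t = ng.parameter([batch_size, num_directions, hidden_size], name="C_t", dtype=np.float32)
seq_lengths = ng.parameter([batch_size], name="seq_lengths", dtype=np.int32)
W = ng.parameter([num_directions, 4 * hidden_size, input_size], name="W", dtype=np.float32)
R = ng.parameter([num_directions, 4 * hidden_size, hidden_size], name="R", dtype=np.float32)
B = ng.parameter([num_directions, 4 * hidden_size], name="B", dtype=np.float32)

# hidden_size and direction trailing arguments are assumed from the tests below.
node = ng.lstm_sequence(X, H_t, C_t, seq_lengths, W, R, B, hidden_size, "forward")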
@@ -258,9 +258,9 @@ def test_lstm_sequence_operator_bidirectional(dtype):
     num_directions = 2
     seq_length = 2
 
-    X_shape = [seq_length, batch_size, input_size]
-    H_t_shape = [num_directions, batch_size, hidden_size]
-    C_t_shape = [num_directions, batch_size, hidden_size]
+    X_shape = [batch_size, seq_length, input_size]
+    H_t_shape = [batch_size, num_directions, hidden_size]
+    C_t_shape = [batch_size, num_directions, hidden_size]
     seq_len_shape = [batch_size]
     W_shape = [num_directions, 4 * hidden_size, input_size]
     R_shape = [num_directions, 4 * hidden_size, hidden_size]
@@ -323,9 +323,9 @@ def test_lstm_sequence_operator_reverse(dtype):
     num_directions = 1
     seq_length = 2
 
-    X_shape = [seq_length, batch_size, input_size]
-    H_t_shape = [num_directions, batch_size, hidden_size]
-    C_t_shape = [num_directions, batch_size, hidden_size]
+    X_shape = [batch_size, seq_length, input_size]
+    H_t_shape = [batch_size, num_directions, hidden_size]
+    C_t_shape = [batch_size, num_directions, hidden_size]
     seq_len_shape = [batch_size]
     W_shape = [num_directions, 4 * hidden_size, input_size]
     R_shape = [num_directions, 4 * hidden_size, hidden_size]
@@ -389,9 +389,9 @@ def test_lstm_sequence_operator_forward(dtype):
     num_directions = 1
     seq_length = 2
 
-    X_shape = [seq_length, batch_size, input_size]
-    H_t_shape = [num_directions, batch_size, hidden_size]
-    C_t_shape = [num_directions, batch_size, hidden_size]
+    X_shape = [batch_size, seq_length, input_size]
+    H_t_shape = [batch_size, num_directions, hidden_size]
+    C_t_shape = [batch_size, num_directions, hidden_size]
     seq_len_shape = [batch_size]
     W_shape = [num_directions, 4 * hidden_size, input_size]
     R_shape = [num_directions, 4 * hidden_size, hidden_size]
@@ -24,13 +24,13 @@
 #include "default_opset.hpp"
 #include "exceptions.hpp"
 #include "lstm.hpp"
 #include "ngraph/builder/reshape.hpp"
 #include "ngraph/builder/split.hpp"
 #include "ngraph/frontend/onnx_import/op/lstm.hpp"
 #include "ngraph/op/add.hpp"
 #include "ngraph/op/constant.hpp"
 #include "ngraph/op/fused/lstm_sequence.hpp"
-#include "ngraph/op/get_output_element.hpp"
-#include "ngraph/opsets/opset0.hpp"
 #include "ngraph/shape.hpp"
 #include "ngraph/type/element_type.hpp"
@@ -71,7 +71,8 @@ namespace ngraph
 
                     // ----- Mandatory inputs ------
                     // Packed input sequences. Shape: [seq_length, batch_size, input_size]
-                    m_map[LSTMInput::LSTM_INPUT_X] = ng_inputs.at(0);
+                    m_map[LSTMInput::LSTM_INPUT_X] =
+                        builder::opset1::reorder_axes(ng_inputs.at(0), {1, 0, 2});
                     // Weight tensor for the gates.
                     // Shape: [num_directions, 4*hidden_size, input_size]
                     m_map[LSTMInput::LSTM_INPUT_W] = ng_inputs.at(1);
@@ -82,7 +83,7 @@ namespace ngraph
                     const std::size_t hidden_size =
                         m_map[LSTMInput::LSTM_INPUT_R]->get_shape().back();
                     const std::size_t batch_size =
-                        m_map[LSTMInput::LSTM_INPUT_X]->get_shape().at(1);
+                        m_map[LSTMInput::LSTM_INPUT_X]->get_shape().at(0);
                     const std::size_t num_directions =
                         m_map[LSTMInput::LSTM_INPUT_W]->get_shape().front();
 
@@ -115,33 +116,35 @@ namespace ngraph
                             Shape{batch_size},
                             std::vector<std::int32_t>(
                                 batch_size,
-                                m_map[LSTMInput::LSTM_INPUT_X]->get_shape().at(0)));
+                                m_map[LSTMInput::LSTM_INPUT_X]->get_shape().at(1)));
                     }
                     // The initial value of the hidden.
                     // Shape [num_directions, batch_size, hidden_size]
                     if (ng_inputs.size() > 5 && !ng_inputs.at(5)->is_null())
                     {
-                        m_map[LSTMInput::LSTM_INPUT_INIT_H] = ng_inputs.at(5);
+                        m_map[LSTMInput::LSTM_INPUT_INIT_H] =
+                            builder::opset1::reorder_axes(ng_inputs.at(5), {1, 0, 2});
                     }
                     else
                     {
                         m_map[LSTMInput::LSTM_INPUT_INIT_H] = default_opset::Constant::create(
                             element::f32,
-                            Shape{num_directions, batch_size, hidden_size},
-                            std::vector<float>(num_directions * batch_size * hidden_size, 0.f));
+                            Shape{batch_size, num_directions, hidden_size},
+                            std::vector<float>(batch_size * num_directions * hidden_size, 0.f));
                     }
                     // The initial value of the cell.
                     // Shape [num_directions, batch_size, hidden_size]
                     if (ng_inputs.size() > 6 && !ng_inputs.at(6)->is_null())
                     {
-                        m_map[LSTMInput::LSTM_INPUT_INIT_C] = ng_inputs.at(6);
+                        m_map[LSTMInput::LSTM_INPUT_INIT_C] =
+                            builder::opset1::reorder_axes(ng_inputs.at(6), {1, 0, 2});
                     }
                     else
                     {
                         m_map[LSTMInput::LSTM_INPUT_INIT_C] = default_opset::Constant::create(
                             element::f32,
-                            Shape{num_directions, batch_size, hidden_size},
-                            std::vector<float>(num_directions * batch_size * hidden_size, 0.f));
+                            Shape{batch_size, num_directions, hidden_size},
+                            std::vector<float>(batch_size * num_directions * hidden_size, 0.f));
                     }
                     // The weight tensor for peepholes. Shape [num_directions, 3*hidde_size]
                     if (ng_inputs.size() > 7 && !ng_inputs.at(7)->is_null())
@@ -239,9 +242,14 @@ namespace ngraph
                         attributes.m_activations,
                         attributes.m_clip_threshold,
                         attributes.m_input_forget);
-                    return {std::make_shared<ngraph::opset0::GetOutputElement>(lstmSequence, 0),
-                            std::make_shared<ngraph::opset0::GetOutputElement>(lstmSequence, 1),
-                            std::make_shared<ngraph::opset0::GetOutputElement>(lstmSequence, 2)};
+
+                    const auto Y = lstmSequence->output(0);
+                    const auto Y_h = lstmSequence->output(1);
+                    const auto Y_c = lstmSequence->output(2);
+
+                    return {builder::opset1::reorder_axes(Y, {2, 1, 0, 3}),
+                            builder::opset1::reorder_axes(Y_h, {1, 0, 2}),
+                            builder::opset1::reorder_axes(Y_c, {1, 0, 2})};
                 }
             } // namespace set_1
 
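The importer therefore transposes on the way in and back out: nGraph's LSTMSequence now emits Y as [batch_size, num_directions, seq_length, hidden_size], while ONNX expects [seq_length, num_directions, batch_size, hidden_size], hence the {2, 1, 0, 3} permutation above. A NumPy sketch of that output round-trip (shapes taken from the diff; variable names are illustrative):

import numpy as np

batch_size, num_directions, seq_length, hidden_size = 1, 1, 2, 3

# nGraph LSTMSequence output Y: [batch_size, num_directions, seq_length, hidden_size]
y_ngraph = np.zeros((batch_size, num_directions, seq_length, hidden_size), np.float32)

# reorder_axes(Y, {2, 1, 0, 3}) restores the ONNX layout:
# [seq_length, num_directions, batch_size, hidden_size]
y_onnx = np.transpose(y_ngraph, (2, 1, 0, 3))
assert y_onnx.shape == (seq_length, num_directions, batch_size, hidden_size)

# Y_h / Y_c: [batch_size, num_directions, hidden_size]
# -> [num_directions, batch_size, hidden_size]
yh_ngraph = np.zeros((batch_size, num_directions, hidden_size), np.float32)
yh_onnx = np.transpose(yh_ngraph, (1, 0, 2))
assert yh_onnx.shape == (num_directions, batch_size, hidden_size)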
@@ -20,19 +20,13 @@
 #include "ngraph/builder/autobroadcast.hpp"
 #include "ngraph/builder/reshape.hpp"
 #include "ngraph/builder/split.hpp"
 #include "ngraph/frontend/onnx_import/utils/reshape.hpp"
-#include "ngraph/op/concat.hpp"
-#include "ngraph/op/constant.hpp"
-#include "ngraph/op/fused/lstm_cell.hpp"
-#include "ngraph/op/get_output_element.hpp"
-#include "ngraph/op/greater.hpp"
-#include "ngraph/op/reverse_sequence.hpp"
-#include "ngraph/op/select.hpp"
+#include "ngraph/opsets/opset1.hpp"
 
 using namespace ngraph;
 using namespace std;
 
-constexpr NodeTypeInfo op::LSTMSequence::type_info;
+constexpr NodeTypeInfo op::v0::LSTMSequence::type_info;
 
 bool ngraph::op::v0::LSTMSequence::visit_attributes(AttributeVisitor& visitor)
 {
     visitor.on_attribute("hidden_size", m_hidden_size);
@@ -46,7 +40,7 @@ bool ngraph::op::v0::LSTMSequence::visit_attributes(AttributeVisitor& visitor)
     visitor.on_attribute("weights_format", m_weights_format);
     return true;
 }
 
-NodeVector op::LSTMSequence::decompose_op() const
+NodeVector op::v0::LSTMSequence::decompose_op() const
 {
     NodeVector results;
     if (m_direction == direction::FORWARD || m_direction == direction::REVERSE)
@@ -60,55 +54,55 @@ NodeVector op::v0::LSTMSequence::decompose_op() const
 
         // Stack together respective outputs from both forward and reverse passess.
         shared_ptr<Node> Y{
-            make_shared<op::Concat>(NodeVector{fwd_results.at(0), rev_results.at(0)}, 1)};
+            make_shared<opset1::Concat>(NodeVector{fwd_results.at(0), rev_results.at(0)}, 1)};
         shared_ptr<Node> Y_h{
-            make_shared<op::Concat>(NodeVector{fwd_results.at(1), rev_results.at(1)}, 0)};
+            make_shared<opset1::Concat>(NodeVector{fwd_results.at(1), rev_results.at(1)}, 1)};
         shared_ptr<Node> Y_c{
-            make_shared<op::Concat>(NodeVector{fwd_results.at(2), rev_results.at(2)}, 0)};
+            make_shared<opset1::Concat>(NodeVector{fwd_results.at(2), rev_results.at(2)}, 1)};
         results = NodeVector{Y, Y_h, Y_c};
     }
     return results;
 }
 
-shared_ptr<Node> op::LSTMSequence::clone_with_new_inputs(const OutputVector& new_args) const
+shared_ptr<Node> op::v0::LSTMSequence::clone_with_new_inputs(const OutputVector& new_args) const
 {
     check_new_args_count(this, new_args);
     if (new_args.size() == 8)
     {
-        return make_shared<LSTMSequence>(new_args.at(0), // X
-                                         new_args.at(1), // initial_hidden_state
-                                         new_args.at(2), // initial_cell_state
-                                         new_args.at(3), // sequence_lengths
-                                         new_args.at(4), // W
-                                         new_args.at(5), // R
-                                         new_args.at(6), // B
-                                         new_args.at(7), // P
-                                         m_hidden_size,
-                                         m_direction,
-                                         m_weights_format,
-                                         m_activations_alpha,
-                                         m_activations_beta,
-                                         m_activations,
-                                         m_clip_threshold,
-                                         m_input_forget);
+        return make_shared<op::v0::LSTMSequence>(new_args.at(0), // X
+                                                 new_args.at(1), // initial_hidden_state
+                                                 new_args.at(2), // initial_cell_state
+                                                 new_args.at(3), // sequence_lengths
+                                                 new_args.at(4), // W
+                                                 new_args.at(5), // R
+                                                 new_args.at(6), // B
+                                                 new_args.at(7), // P
+                                                 m_hidden_size,
+                                                 m_direction,
+                                                 m_weights_format,
+                                                 m_activations_alpha,
+                                                 m_activations_beta,
+                                                 m_activations,
+                                                 m_clip_threshold,
+                                                 m_input_forget);
     }
     else if (new_args.size() == 7)
     {
-        return make_shared<LSTMSequence>(new_args.at(0), // X
-                                         new_args.at(1), // initial_hidden_state
-                                         new_args.at(2), // initial_cell_state
-                                         new_args.at(3), // sequence_lengths
-                                         new_args.at(4), // W
-                                         new_args.at(5), // R
-                                         new_args.at(6), // B
-                                         m_hidden_size,
-                                         m_direction,
-                                         m_weights_format,
-                                         m_activations_alpha,
-                                         m_activations_beta,
-                                         m_activations,
-                                         m_clip_threshold,
-                                         m_input_forget);
+        return make_shared<op::v0::LSTMSequence>(new_args.at(0), // X
+                                                 new_args.at(1), // initial_hidden_state
+                                                 new_args.at(2), // initial_cell_state
+                                                 new_args.at(3), // sequence_lengths
+                                                 new_args.at(4), // W
+                                                 new_args.at(5), // R
+                                                 new_args.at(6), // B
+                                                 m_hidden_size,
+                                                 m_direction,
+                                                 m_weights_format,
+                                                 m_activations_alpha,
+                                                 m_activations_beta,
+                                                 m_activations,
+                                                 m_clip_threshold,
+                                                 m_input_forget);
     }
     else
     {
@@ -116,46 +110,44 @@ shared_ptr<Node> op::LSTMSequence::clone_with_new_inputs(const OutputVector& new
     }
 }
 
-shared_ptr<Node> op::LSTMSequence::get_masked_node(const Output<Node>& data,
-                                                   int32_t time_step,
-                                                   size_t batch_axis,
-                                                   const Output<Node>& default_value) const
+shared_ptr<Node> op::v0::LSTMSequence::get_masked_node(const Output<Node>& data,
+                                                       int32_t time_step,
+                                                       size_t batch_axis,
+                                                       const Output<Node>& default_value) const
 {
     Output<Node> mask_value = default_value;
     // Create zero mask value node.
     if (!mask_value.get_node_shared_ptr())
     {
-        mask_value = op::Constant::create(data.get_element_type(),
-                                          data.get_shape(),
-                                          vector<float>(shape_size(data.get_shape()), 0.f));
+        mask_value = opset1::Constant::create(data.get_element_type(),
+                                              data.get_shape(),
+                                              vector<float>(shape_size(data.get_shape()), 0.f));
    }
 
     // Create predicate nodes. The condition is whether current time step value
     // is greater than sequence length for respective batch inputs.
-    shared_ptr<Node> curr_time_step_node = op::Constant::create(
+    shared_ptr<Node> curr_time_step_node = opset1::Constant::create(
         element::i32, data.get_shape(), vector<int32_t>(shape_size(data.get_shape()), time_step));
 
-    Output<Node> batch_seq_length =
-        builder::legacy_broadcast_for_binary_operation(
-            curr_time_step_node, input_value(3).get_node_shared_ptr(), batch_axis)
-            .at(1);
+    Output<Node> batch_seq_length = builder::opset1::legacy_broadcast_for_binary_operation(
+        curr_time_step_node, input_value(3).get_node_shared_ptr(), batch_axis);
 
     // Create mask node deciding whether or not to mask batch data.
     shared_ptr<Node> mask_condition =
-        make_shared<op::Greater>(curr_time_step_node, batch_seq_length);
+        make_shared<opset1::Greater>(curr_time_step_node, batch_seq_length);
 
     // Select values depnding on mask_condition.
     // Select(<condition>, <true_value>, <false_value>)
-    return make_shared<op::Select>(mask_condition, mask_value, data);
+    return make_shared<opset1::Select>(mask_condition, mask_value, data);
 }
 
-NodeVector op::LSTMSequence::lstm_pass(bool is_reverse) const
+NodeVector op::v0::LSTMSequence::lstm_pass(bool is_reverse) const
 {
     // ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------
     // The names used below are analogous to the one used in ONNX documentation.
     //
     // ------ INPUTS ------
-    // X - The input tensor. [seq_length, batch_size, input_size]
+    // X - The input tensor. [batch_size, seq_length, input_size]
     // W - The weight tensor. [num_directions, 4*hidden_size, input_size]
     // R - The recurrence weight tensor. [num_directions, 4*hidden_size, hidden_size]
     // B - The bias tensor for input gate. [num_directions, 8*hidden_size]
@@ -167,14 +159,14 @@ NodeVector op::LSTMSequence::lstm_pass(bool is_reverse) const
     // c - cell gate
     // t - time step (t-1 means previous time step)
     // ------ VARIABLE NAMES ------
-    // H_t - Hidden state vector at current time step.
-    // C_t - Cell state vector at current time step.
+    // H_t - Hidden state vector at current time step. [batch_size, num_directions, hidden_size]
+    // C_t - Cell state vector at current time step. [batch_size, num_directions, hidden_size]
     // h_list - The list of hidden states at all processed time steps.
 
     NodeVector h_list;
     shared_ptr<Node> X = input_value(0).get_node_shared_ptr();
-    shared_ptr<Node> H_t = prepare_input(input_value(1), is_reverse);
-    shared_ptr<Node> C_t = prepare_input(input_value(2), is_reverse);
+    shared_ptr<Node> H_t = prepare_input(input_value(1), is_reverse, 1);
+    shared_ptr<Node> C_t = prepare_input(input_value(2), is_reverse, 1);
     shared_ptr<Node> seq_lengths = input_value(3).get_node_shared_ptr();
     shared_ptr<Node> W = prepare_input(input_value(4), is_reverse);
     shared_ptr<Node> R = prepare_input(input_value(5), is_reverse);
@@ -183,34 +175,34 @@ NodeVector op::LSTMSequence::lstm_pass(bool is_reverse) const
 
     if (is_reverse)
     {
-        X = make_shared<op::ReverseSequence>(X, seq_lengths, 1 /*batch_axis*/, 0 /*seq_axis*/);
+        X = make_shared<opset1::ReverseSequence>(X, seq_lengths, 0 /*batch_axis*/, 1 /*seq_axis*/);
     }
 
-    NodeVector in_seqs = builder::split(X, X->get_shape().at(0));
+    NodeVector in_seqs = builder::opset1::split(X, X->get_shape().at(1), 1);
 
     for (auto& in_x : in_seqs)
     {
-        // remove first empty dim, after above split.
-        in_x = builder::squeeze(in_x);
+        // Remove empty dim, after above split.
+        in_x = builder::opset1::squeeze(in_x, {1});
     }
 
     int32_t time_step{1};
     for (const auto& in_x : in_seqs)
     {
-        shared_ptr<Node> lstm_cell = make_shared<op::LSTMCell>(in_x,
-                                                               H_t,
-                                                               C_t,
-                                                               W,
-                                                               R,
-                                                               B,
-                                                               P,
-                                                               m_hidden_size,
-                                                               m_weights_format,
-                                                               m_activations,
-                                                               m_activations_alpha,
-                                                               m_activations_beta,
-                                                               m_clip_threshold,
-                                                               m_input_forget);
+        shared_ptr<Node> lstm_cell = make_shared<opset1::LSTMCell>(in_x,
+                                                                   H_t,
+                                                                   C_t,
+                                                                   W,
+                                                                   R,
+                                                                   B,
+                                                                   P,
+                                                                   m_hidden_size,
+                                                                   m_weights_format,
+                                                                   m_activations,
+                                                                   m_activations_alpha,
+                                                                   m_activations_beta,
+                                                                   m_clip_threshold,
+                                                                   m_input_forget);
 
         Output<Node> H = lstm_cell->output(0);
         Output<Node> C = lstm_cell->output(1);
@@ -220,7 +212,7 @@ NodeVector op::LSTMSequence::lstm_pass(bool is_reverse) const
         // Mask hidden state tensor in order to handle mixed sequence lengths.
         // This results in zeroing out values in batches with sequence shorter
         // than current time_step.
-        h_list.push_back(get_masked_node(builder::expand_dims(H), time_step, 1));
+        h_list.push_back(get_masked_node(builder::opset1::expand_dims(H, 1), time_step, 0));
         // Reference implementation in ONNX Runtime doesn't mask values of Y_h
        // and Y_c outputs, thus here we make sure that only appropriate batches
        // (in respect to its sequence length) are updated. Those batches which
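The masking above has a simple dense analogue: batches whose sequence length is shorter than the current time step get their hidden state zeroed. A NumPy sketch of the Greater + Select combination (illustrative, not the op graph itself):

import numpy as np

batch_size, hidden_size = 4, 3
seq_lengths = np.array([3, 1, 2, 3])  # per-batch sequence lengths
time_step = 2                         # 1-based, as in the loop above

h = np.ones((batch_size, hidden_size), dtype=np.float32)  # current hidden state

# Greater: has this batch's sequence already finished at this time step?
mask = time_step > seq_lengths                  # shape [batch_size]

# Select(condition, mask_value, data): zero out finished batches.
h_masked = np.where(mask[:, None], 0.0, h)
# Row 1 (seq_lengths == 1) is zeroed; the rest pass through unchanged.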
@@ -230,36 +222,38 @@ NodeVector op::LSTMSequence::lstm_pass(bool is_reverse) const
         time_step++;
     }
     // The tensor that concats all the intermediate output values of the hidden.
-    // It has shape [seq_length, batch_size, hidden_size]
-    shared_ptr<Node> Y{make_shared<op::Concat>(h_list, 0)};
+    // It has shape [batch_size, seq_length, hidden_size]
+    shared_ptr<Node> Y{make_shared<opset1::Concat>(h_list, 1)};
 
     // Get back the original order of the output data.
     if (is_reverse)
     {
-        Y = make_shared<op::ReverseSequence>(Y, seq_lengths, 1 /*batch_axis*/, 0 /*seq_axis*/);
+        Y = make_shared<opset1::ReverseSequence>(Y, seq_lengths, 0 /*batch_axis*/, 1 /*seq_axis*/);
     }
 
     // Expand Y so that it has expected shape:
-    // [seq_length, num_directions, batch_size, hidden_size]
-    Y = builder::expand_dims(Y, 1);
+    // [batch_size, num_directions, seq_length, hidden_size]
+    Y = builder::opset1::expand_dims(Y, 1);
 
     // expand H_t and C_t so that it has expected shape:
-    // [num_directions, batch_size, hidden_size]
-    auto Y_h = builder::expand_dims(H_t);
-    auto Y_c = builder::expand_dims(C_t);
+    // [batch_size, num_directions, hidden_size]
+    auto Y_h = builder::opset1::expand_dims(H_t, 1);
+    auto Y_c = builder::opset1::expand_dims(C_t, 1);
     return {Y, Y_h, Y_c};
 }
 
-shared_ptr<Node> op::LSTMSequence::prepare_input(Output<Node> node, bool is_reverse) const
+shared_ptr<Node> op::v0::LSTMSequence::prepare_input(Output<Node> node,
+                                                     bool is_reverse,
+                                                     size_t num_direction_axis) const
 {
     // In bidirectional mode inputs are stacked together, so we must split them.
     shared_ptr<Node> tmp = node.get_node_shared_ptr();
     if (m_direction == direction::BIDIRECTIONAL)
     {
-        tmp = builder::split(node, 2).at(is_reverse ? 1 : 0);
+        tmp = builder::opset1::split(node, 2, num_direction_axis).at(is_reverse ? 1 : 0);
     }
     // Since we have forward LSTM we can squeeze `num_directions` axis from inputs.
-    return builder::squeeze(tmp);
+    return builder::opset1::squeeze(tmp, {num_direction_axis});
 }
 
 namespace ngraph
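prepare_input now splits and squeezes along an explicit num_directions axis: state tensors carry it at axis 1 ([batch_size, num_directions, hidden_size]) while weight tensors keep it at axis 0. A NumPy sketch of the per-direction extraction (names are illustrative):

import numpy as np

batch_size, num_directions, hidden_size = 4, 2, 3
h0 = np.zeros((batch_size, num_directions, hidden_size), dtype=np.float32)

def prepare_input(node, is_reverse, num_direction_axis=0):
    """Dense analogue of op::v0::LSTMSequence::prepare_input in the
    bidirectional case: split on the num_directions axis, pick the
    requested direction, then squeeze that axis away."""
    halves = np.split(node, 2, axis=num_direction_axis)
    picked = halves[1] if is_reverse else halves[0]
    return np.squeeze(picked, axis=num_direction_axis)

h_fwd = prepare_input(h0, is_reverse=False, num_direction_axis=1)
h_rev = prepare_input(h0, is_reverse=True, num_direction_axis=1)
assert h_fwd.shape == h_rev.shape == (batch_size, hidden_size)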
@@ -173,7 +173,9 @@ namespace ngraph
            NodeVector lstm_pass(bool is_reverse = false) const;

            // Split(bi-directional) and squeeze input data to remove 'num_direction' dimension.
-           std::shared_ptr<Node> prepare_input(Output<Node> node, bool is_reverse) const;
+           std::shared_ptr<Node> prepare_input(Output<Node> node,
+                                               bool is_reverse,
+                                               size_t num_direction_axis = 0) const;

            std::vector<float> m_activations_alpha;
            std::vector<float> m_activations_beta;
@@ -1104,16 +1104,27 @@ TEST(attributes, lstm_cell_op)
 TEST(attributes, lstm_sequence_op)
 {
     FactoryRegistry<Node>::get().register_factory<opset1::LSTMSequence>();
-    const auto X = make_shared<op::Parameter>(element::f32, Shape{1, 2, 4});
-    const auto initial_hidden_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
-    const auto initial_cell_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
-    const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{2});
-    const auto W = make_shared<op::Parameter>(element::f32, Shape{1, 12, 4});
-    const auto R = make_shared<op::Parameter>(element::f32, Shape{1, 12, 3});
-    const auto B = make_shared<op::Parameter>(element::f32, Shape{1, 12});
 
-    const auto hidden_size = 3;
-    const auto lstm_direction = op::LSTMSequence::direction::FORWARD;
+    const auto batch_size = 4;
+    const auto num_directions = 2;
+    const auto seq_length = 8;
+    const auto input_size = 16;
+    const auto hidden_size = 64;
+
+    const auto X =
+        make_shared<op::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
+    const auto initial_hidden_state =
+        make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
+    const auto initial_cell_state =
+        make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
+    const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{batch_size});
+    const auto W = make_shared<op::Parameter>(element::f32,
+                                              Shape{num_directions, 4 * hidden_size, input_size});
+    const auto R = make_shared<op::Parameter>(element::f32,
+                                              Shape{num_directions, 4 * hidden_size, hidden_size});
+    const auto B = make_shared<op::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});
+
+    const auto lstm_direction = op::LSTMSequence::direction::BIDIRECTIONAL;
     const auto weights_format = op::LSTMWeightsFormat::ICOF;
     const std::vector<float> activations_alpha = {1, 2, 3};
     const std::vector<float> activations_beta = {4, 5, 6};
@@ -39,6 +39,7 @@ using namespace ngraph;
 
 static std::string s_manifest = "${MANIFEST}";
 
+// ONNX LSTM tests (implemented by nGraph LSTMCell and LSTMSequence)
 NGRAPH_TEST(${BACKEND_NAME}, onnx_model_lstm_fwd_with_clip)
 {
     auto function = onnx_import::import_onnx_model(
@@ -21,16 +21,28 @@
 using namespace std;
 using namespace ngraph;
 
-TEST(type_prop, lstm_sequence)
+TEST(type_prop, lstm_sequence_forward)
 {
-    const auto X = make_shared<op::Parameter>(element::f32, Shape{1, 2, 4});
-    const auto W = make_shared<op::Parameter>(element::f32, Shape{1, 12, 4});
-    const auto R = make_shared<op::Parameter>(element::f32, Shape{1, 12, 3});
-    const auto initial_hidden_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
-    const auto initial_cell_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
-    const auto B = make_shared<op::Parameter>(element::f32, Shape{1, 12});
-    const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{2});
-    const auto hidden_size = 3;
+    const auto batch_size = 8;
+    const auto num_directions = 1;
+    const auto seq_length = 6;
+    const auto input_size = 4;
+    const auto hidden_size = 128;
+
+    const auto X =
+        make_shared<op::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
+    const auto initial_hidden_state =
+        make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
+    const auto initial_cell_state =
+        make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
+    const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{batch_size});
+    const auto W = make_shared<op::Parameter>(element::f32,
+                                              Shape{num_directions, 4 * hidden_size, input_size});
+    const auto R = make_shared<op::Parameter>(element::f32,
+                                              Shape{num_directions, 4 * hidden_size, hidden_size});
+    const auto B = make_shared<op::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});
+
+    const auto lstm_direction = op::LSTMSequence::direction::FORWARD;
 
     const auto lstm_sequence = make_shared<op::LSTMSequence>(X,
                                                              initial_hidden_state,
@@ -40,7 +52,7 @@ TEST(type_prop, lstm_sequence)
                                                              R,
                                                              B,
                                                              hidden_size,
-                                                             op::LSTMSequence::direction::FORWARD);
+                                                             lstm_direction);
     EXPECT_EQ(lstm_sequence->get_hidden_size(), hidden_size);
     EXPECT_EQ(lstm_sequence->get_direction(), op::LSTMSequence::direction::FORWARD);
     EXPECT_EQ(lstm_sequence->get_weights_format(), op::LSTMWeightsFormat::IFCO);
@@ -52,9 +64,69 @@ TEST(type_prop, lstm_sequence)
     EXPECT_EQ(lstm_sequence->get_clip_threshold(), 0.f);
     EXPECT_FALSE(lstm_sequence->get_input_forget());
     EXPECT_EQ(lstm_sequence->get_output_element_type(0), element::f32);
-    EXPECT_EQ(lstm_sequence->get_output_shape(0), (Shape{1, 1, 2, 3}));
+    EXPECT_EQ(lstm_sequence->get_output_shape(0),
+              (Shape{batch_size, num_directions, seq_length, hidden_size}));
     EXPECT_EQ(lstm_sequence->get_output_element_type(1), element::f32);
-    EXPECT_EQ(lstm_sequence->get_output_shape(1), (Shape{1, 2, 3}));
+    EXPECT_EQ(lstm_sequence->get_output_shape(1), (Shape{batch_size, num_directions, hidden_size}));
     EXPECT_EQ(lstm_sequence->get_output_element_type(2), element::f32);
-    EXPECT_EQ(lstm_sequence->get_output_shape(2), (Shape{1, 2, 3}));
+    EXPECT_EQ(lstm_sequence->get_output_shape(2), (Shape{batch_size, num_directions, hidden_size}));
 }
 
+TEST(type_prop, lstm_sequence_bidirectional)
+{
+    const auto batch_size = 24;
+    const auto num_directions = 2;
+    const auto seq_length = 12;
+    const auto input_size = 8;
+    const auto hidden_size = 256;
+
+    const auto X =
+        make_shared<op::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
+    const auto initial_hidden_state =
+        make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
+    const auto initial_cell_state =
+        make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
+    const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{batch_size});
+    const auto W = make_shared<op::Parameter>(element::f32,
+                                              Shape{num_directions, 4 * hidden_size, input_size});
+    const auto R = make_shared<op::Parameter>(element::f32,
+                                              Shape{num_directions, 4 * hidden_size, hidden_size});
+    const auto B = make_shared<op::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});
+
+    const auto weights_format = op::LSTMWeightsFormat::FICO;
+    const auto lstm_direction = op::LSTMSequence::direction::BIDIRECTIONAL;
+    const std::vector<float> activations_alpha = {2.7, 7.0, 32.367};
+    const std::vector<float> activations_beta = {0.0, 5.49, 6.0};
+    const std::vector<std::string> activations = {"tanh", "sigmoid", "sigmoid"};
+
+    const auto lstm_sequence = make_shared<op::LSTMSequence>(X,
+                                                             initial_hidden_state,
+                                                             initial_cell_state,
+                                                             sequence_lengths,
+                                                             W,
+                                                             R,
+                                                             B,
+                                                             hidden_size,
+                                                             lstm_direction,
+                                                             weights_format,
+                                                             activations_alpha,
+                                                             activations_beta,
+                                                             activations);
+    EXPECT_EQ(lstm_sequence->get_hidden_size(), hidden_size);
+    EXPECT_EQ(lstm_sequence->get_direction(), op::LSTMSequence::direction::BIDIRECTIONAL);
+    EXPECT_EQ(lstm_sequence->get_weights_format(), op::LSTMWeightsFormat::FICO);
+    EXPECT_EQ(lstm_sequence->get_activations_alpha(), activations_alpha);
+    EXPECT_EQ(lstm_sequence->get_activations_beta(), activations_beta);
+    EXPECT_EQ(lstm_sequence->get_activations()[0], "tanh");
+    EXPECT_EQ(lstm_sequence->get_activations()[1], "sigmoid");
+    EXPECT_EQ(lstm_sequence->get_activations()[2], "sigmoid");
+    EXPECT_EQ(lstm_sequence->get_clip_threshold(), 0.f);
+    EXPECT_FALSE(lstm_sequence->get_input_forget());
+    EXPECT_EQ(lstm_sequence->get_output_element_type(0), element::f32);
+    EXPECT_EQ(lstm_sequence->get_output_shape(0),
+              (Shape{batch_size, num_directions, seq_length, hidden_size}));
+    EXPECT_EQ(lstm_sequence->get_output_element_type(1), element::f32);
+    EXPECT_EQ(lstm_sequence->get_output_shape(1), (Shape{batch_size, num_directions, hidden_size}));
+    EXPECT_EQ(lstm_sequence->get_output_element_type(2), element::f32);
+    EXPECT_EQ(lstm_sequence->get_output_shape(2), (Shape{batch_size, num_directions, hidden_size}));
+}
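The expected output shapes in these tests follow directly from the new layout rule. A small Python sketch of that rule (hypothetical helper, not part of the test suite):

def lstm_sequence_output_shapes(batch_size, num_directions, seq_length, hidden_size):
    """Output shapes of the reordered (batch-major) LSTMSequence: Y, Y_h, Y_c."""
    y = (batch_size, num_directions, seq_length, hidden_size)
    y_h = (batch_size, num_directions, hidden_size)
    y_c = (batch_size, num_directions, hidden_size)
    return y, y_h, y_c

# Matches the lstm_sequence_bidirectional expectations above.
assert lstm_sequence_output_shapes(24, 2, 12, 256)[0] == (24, 2, 12, 256)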