[nGraph] Reorder nGraph LSTMSequence inputs and outputs dimensions (#560)

* Reorder nGraph LSTMSequence input/output dimensions

* Update nGraph Python API for LSTMSequence

* Reorder axes in ONNX importer LSTM

* Tests update

* Fix clang warning

* Use opset3 namespace

* Style apply

* Tests update

* Use opset1 namespace

* Remove usage of GetOutputElement in ONNX importer LSTM

* Remove opset0 header

* Use Node::output()
This commit is contained in:
Katarzyna Mitrus
2020-05-29 13:29:18 +02:00
committed by GitHub
parent a4f13ae9fe
commit 5f8f9ec108
8 changed files with 231 additions and 143 deletions

View File

@@ -472,11 +472,11 @@ def lstm_sequence(
) -> Node:
"""Return a node which performs LSTMSequence operation.
:param X: The input tensor. Shape: [seq_length, batch_size, input_size].
:param X: The input tensor. Shape: [batch_size, seq_length, input_size].
:param initial_hidden_state: The hidden state tensor.
Shape: [num_directions, batch_size, hidden_size].
Shape: [batch_size, num_directions, hidden_size].
:param initial_cell_state: The cell state tensor.
Shape: [num_directions, batch_size, hidden_size].
Shape: [batch_size, num_directions, hidden_size].
:param sequence_lengths: Specifies real sequence lengths for each batch element.
Shape: [batch_size]. Integer type.
:param W: Tensor with weights for matrix multiplication operation with input portion of data.

View File

@@ -258,9 +258,9 @@ def test_lstm_sequence_operator_bidirectional(dtype):
num_directions = 2
seq_length = 2
X_shape = [seq_length, batch_size, input_size]
H_t_shape = [num_directions, batch_size, hidden_size]
C_t_shape = [num_directions, batch_size, hidden_size]
X_shape = [batch_size, seq_length, input_size]
H_t_shape = [batch_size, num_directions, hidden_size]
C_t_shape = [batch_size, num_directions, hidden_size]
seq_len_shape = [batch_size]
W_shape = [num_directions, 4 * hidden_size, input_size]
R_shape = [num_directions, 4 * hidden_size, hidden_size]
@@ -323,9 +323,9 @@ def test_lstm_sequence_operator_reverse(dtype):
num_directions = 1
seq_length = 2
X_shape = [seq_length, batch_size, input_size]
H_t_shape = [num_directions, batch_size, hidden_size]
C_t_shape = [num_directions, batch_size, hidden_size]
X_shape = [batch_size, seq_length, input_size]
H_t_shape = [batch_size, num_directions, hidden_size]
C_t_shape = [batch_size, num_directions, hidden_size]
seq_len_shape = [batch_size]
W_shape = [num_directions, 4 * hidden_size, input_size]
R_shape = [num_directions, 4 * hidden_size, hidden_size]
@@ -389,9 +389,9 @@ def test_lstm_sequence_operator_forward(dtype):
num_directions = 1
seq_length = 2
X_shape = [seq_length, batch_size, input_size]
H_t_shape = [num_directions, batch_size, hidden_size]
C_t_shape = [num_directions, batch_size, hidden_size]
X_shape = [batch_size, seq_length, input_size]
H_t_shape = [batch_size, num_directions, hidden_size]
C_t_shape = [batch_size, num_directions, hidden_size]
seq_len_shape = [batch_size]
W_shape = [num_directions, 4 * hidden_size, input_size]
R_shape = [num_directions, 4 * hidden_size, hidden_size]