[nGraph] Reorder nGraph LSTMSequence inputs and outputs dimensions (#560)
* Reorder nGraph LSTMSequence input/outpt dimensions * Update nGraph pythonAPI for LSTMSequence * Reorder axes in ONNX importer LSTM * Tests update * Fix clang warning * Use opset3 namespace * Style apply * Tests update * Use opset1 namespace * Remove usage of GetOutputElement in ONNX importer LSTM * Remove opset0 header * Use Node::output()
This commit is contained in:
@@ -472,11 +472,11 @@ def lstm_sequence(
|
||||
) -> Node:
|
||||
"""Return a node which performs LSTMSequence operation.
|
||||
|
||||
:param X: The input tensor. Shape: [seq_length, batch_size, input_size].
|
||||
:param X: The input tensor. Shape: [batch_size, seq_length, input_size].
|
||||
:param initial_hidden_state: The hidden state tensor.
|
||||
Shape: [num_directions, batch_size, hidden_size].
|
||||
Shape: [batch_size, num_directions, hidden_size].
|
||||
:param initial_cell_state: The cell state tensor.
|
||||
Shape: [num_directions, batch_size, hidden_size].
|
||||
Shape: [batch_size, num_directions, hidden_size].
|
||||
:param sequence_lengths: Specifies real sequence lengths for each batch element.
|
||||
Shape: [batch_size]. Integer type.
|
||||
:param W: Tensor with weights for matrix multiplication operation with input portion of data.
|
||||
|
||||
@@ -258,9 +258,9 @@ def test_lstm_sequence_operator_bidirectional(dtype):
|
||||
num_directions = 2
|
||||
seq_length = 2
|
||||
|
||||
X_shape = [seq_length, batch_size, input_size]
|
||||
H_t_shape = [num_directions, batch_size, hidden_size]
|
||||
C_t_shape = [num_directions, batch_size, hidden_size]
|
||||
X_shape = [batch_size, seq_length, input_size]
|
||||
H_t_shape = [batch_size, num_directions, hidden_size]
|
||||
C_t_shape = [batch_size, num_directions, hidden_size]
|
||||
seq_len_shape = [batch_size]
|
||||
W_shape = [num_directions, 4 * hidden_size, input_size]
|
||||
R_shape = [num_directions, 4 * hidden_size, hidden_size]
|
||||
@@ -323,9 +323,9 @@ def test_lstm_sequence_operator_reverse(dtype):
|
||||
num_directions = 1
|
||||
seq_length = 2
|
||||
|
||||
X_shape = [seq_length, batch_size, input_size]
|
||||
H_t_shape = [num_directions, batch_size, hidden_size]
|
||||
C_t_shape = [num_directions, batch_size, hidden_size]
|
||||
X_shape = [batch_size, seq_length, input_size]
|
||||
H_t_shape = [batch_size, num_directions, hidden_size]
|
||||
C_t_shape = [batch_size, num_directions, hidden_size]
|
||||
seq_len_shape = [batch_size]
|
||||
W_shape = [num_directions, 4 * hidden_size, input_size]
|
||||
R_shape = [num_directions, 4 * hidden_size, hidden_size]
|
||||
@@ -389,9 +389,9 @@ def test_lstm_sequence_operator_forward(dtype):
|
||||
num_directions = 1
|
||||
seq_length = 2
|
||||
|
||||
X_shape = [seq_length, batch_size, input_size]
|
||||
H_t_shape = [num_directions, batch_size, hidden_size]
|
||||
C_t_shape = [num_directions, batch_size, hidden_size]
|
||||
X_shape = [batch_size, seq_length, input_size]
|
||||
H_t_shape = [batch_size, num_directions, hidden_size]
|
||||
C_t_shape = [batch_size, num_directions, hidden_size]
|
||||
seq_len_shape = [batch_size]
|
||||
W_shape = [num_directions, 4 * hidden_size, input_size]
|
||||
R_shape = [num_directions, 4 * hidden_size, hidden_size]
|
||||
|
||||
Reference in New Issue
Block a user