From 3ce0f2573a4e17abcd914cb0887fc49f1086173e Mon Sep 17 00:00:00 2001 From: Mang Guo Date: Wed, 27 Oct 2021 14:00:13 +0800 Subject: [PATCH] [pdpd] Specify SEQ_LEN for each batch (#8057) * Specify SEQ_LEN for each batch * Generate different seq_len for each batch * Remove print --- ngraph/frontend/paddlepaddle/src/op/lstm.cpp | 16 +++-- .../test/frontend/paddlepaddle/op_fuzzy.cpp | 2 + .../gen_scripts/generate_rnn_lstm.py | 61 +++++++++++++++---- 3 files changed, 61 insertions(+), 18 deletions(-) diff --git a/ngraph/frontend/paddlepaddle/src/op/lstm.cpp b/ngraph/frontend/paddlepaddle/src/op/lstm.cpp index 87c7faa1752..039f5c061bb 100644 --- a/ngraph/frontend/paddlepaddle/src/op/lstm.cpp +++ b/ngraph/frontend/paddlepaddle/src/op/lstm.cpp @@ -94,12 +94,16 @@ struct LSTMNgInputMap { auto batch_size_node = std::make_shared<opset6::Gather>(shape_of_x, opset6::Constant::create(element::i64, Shape{1}, {0}), axes); - auto seq_length_node = - std::make_shared<opset6::Gather>(shape_of_x, opset6::Constant::create(element::i64, Shape{1}, {1}), axes); - - // TODO Specify SEQ_LEN for each batch #55404 - m_input_map[LSTMInput::LSTM_INPUT_SEQ_LENGTHS] = - std::make_shared<opset6::Broadcast>(seq_length_node, batch_size_node); + if (node.has_ng_input("SequenceLength")) { + m_input_map[LSTMInput::LSTM_INPUT_SEQ_LENGTHS] = node.get_ng_input("SequenceLength"); + } else { + auto seq_length_node = + std::make_shared<opset6::Gather>(shape_of_x, + opset6::Constant::create(element::i64, Shape{1}, {1}), + axes); + m_input_map[LSTMInput::LSTM_INPUT_SEQ_LENGTHS] = + std::make_shared<opset6::Broadcast>(seq_length_node, batch_size_node); + } auto init_states = node.get_ng_inputs("PreState"); // 0 for init_h, 1 for init_cell, update bidirect_len for init states diff --git a/ngraph/test/frontend/paddlepaddle/op_fuzzy.cpp b/ngraph/test/frontend/paddlepaddle/op_fuzzy.cpp index 1eb14827ff4..d895829b272 100644 --- a/ngraph/test/frontend/paddlepaddle/op_fuzzy.cpp +++ b/ngraph/test/frontend/paddlepaddle/op_fuzzy.cpp @@ -189,6 +189,8 @@ static const std::vector<std::string>
models{std::string("argmax"), std::string("rnn_lstm_layer_1_forward"), std::string("rnn_lstm_layer_2_bidirectional"), std::string("rnn_lstm_layer_2_forward"), + std::string("rnn_lstm_layer_1_forward_seq_len_4"), + std::string("rnn_lstm_layer_2_bidirectional_seq_len_4"), std::string("scale_bias_after_float32"), std::string("scale_bias_after_int32"), std::string("scale_bias_after_int64"), diff --git a/ngraph/test/frontend/paddlepaddle/test_models/gen_scripts/generate_rnn_lstm.py b/ngraph/test/frontend/paddlepaddle/test_models/gen_scripts/generate_rnn_lstm.py index b6b4f5b265c..c5e62b75bc2 100644 --- a/ngraph/test/frontend/paddlepaddle/test_models/gen_scripts/generate_rnn_lstm.py +++ b/ngraph/test/frontend/paddlepaddle/test_models/gen_scripts/generate_rnn_lstm.py @@ -3,7 +3,7 @@ from save_model import saveModel import sys -def pdpd_rnn_lstm(input_size, hidden_size, layers, direction): +def pdpd_rnn_lstm(input_size, hidden_size, layers, direction, seq_len): import paddle as pdpd pdpd.enable_static() main_program = pdpd.static.Program() @@ -14,22 +14,40 @@ def pdpd_rnn_lstm(input_size, hidden_size, layers, direction): rnn = pdpd.nn.LSTM(input_size, hidden_size, layers, direction) - data = pdpd.static.data(name='x', shape=[4, 3, input_size], dtype='float32') - prev_h = pdpd.ones(shape=[layers * num_of_directions, 4, hidden_size], dtype=np.float32) - prev_c = pdpd.ones(shape=[layers * num_of_directions, 4, hidden_size], dtype=np.float32) + data = pdpd.static.data( + name='x', shape=[4, 3, input_size], dtype='float32') + prev_h = pdpd.ones( + shape=[layers * num_of_directions, 4, hidden_size], dtype=np.float32) + prev_c = pdpd.ones( + shape=[layers * num_of_directions, 4, hidden_size], dtype=np.float32) - y, (h, c) = rnn(data, (prev_h, prev_c)) + if seq_len: + seq_lengths = pdpd.static.data(name='sl', shape=[4], dtype='int32') + y, (h, c) = rnn(data, (prev_h, prev_c), seq_lengths) + else: + y, (h, c) = rnn(data, (prev_h, prev_c)) cpu = pdpd.static.cpu_places(1) exe = 
pdpd.static.Executor(cpu[0]) exe.run(startup_program) - outs = exe.run( - feed={'x': np.ones([4, 3, input_size]).astype(np.float32)}, - fetch_list=[y, h, c], - program=main_program) - saveModel("rnn_lstm_layer_" + str(layers) + '_' + str(direction), exe, feedkeys=['x'], - fetchlist=[y, h, c], inputs=[np.ones([4, 3, input_size]).astype(np.float32)], outputs=[outs[0], outs[1], outs[2]], target_dir=sys.argv[1]) + if seq_len: + outs = exe.run( + feed={'x': np.ones([4, 3, input_size]).astype( + np.float32), 'sl': np.array(seq_len).astype(np.int32)}, + fetch_list=[y, h, c], + program=main_program) + saveModel("rnn_lstm_layer_" + str(layers) + '_' + str(direction) + '_seq_len_' + str(len(seq_len)), exe, feedkeys=['x', 'sl'], + fetchlist=[y, h, c], inputs=[np.ones([4, 3, input_size]).astype(np.float32), np.array(seq_len).astype(np.int32)], outputs=[outs[0], outs[1], outs[2]], target_dir=sys.argv[1]) + else: + outs = exe.run( + feed={'x': np.ones([4, 3, input_size]).astype( + np.float32)}, + fetch_list=[y, h, c], + program=main_program) + saveModel("rnn_lstm_layer_" + str(layers) + '_' + str(direction), exe, feedkeys=['x'], + fetchlist=[y, h, c], inputs=[np.ones([4, 3, input_size]).astype(np.float32)], outputs=[outs[0], outs[1], outs[2]], target_dir=sys.argv[1]) + return outs[0] @@ -41,26 +59,45 @@ if __name__ == "__main__": 'hidden_size': 2, 'layers': 1, 'direction': 'forward', + 'seq_len': [], }, { 'input_size': 2, 'hidden_size': 2, 'layers': 1, 'direction': 'bidirectional', + 'seq_len': [], }, { 'input_size': 2, 'hidden_size': 2, 'layers': 2, 'direction': 'forward', + 'seq_len': [], }, { 'input_size': 2, 'hidden_size': 2, 'layers': 2, 'direction': 'bidirectional', + 'seq_len': [], + }, + { + 'input_size': 2, + 'hidden_size': 2, + 'layers': 1, + 'direction': 'forward', + 'seq_len': [1, 2, 3, 3], + }, + { + 'input_size': 2, + 'hidden_size': 2, + 'layers': 2, + 'direction': 'bidirectional', + 'seq_len': [2, 2, 3, 3], } ] for test in testCases: - 
pdpd_rnn_lstm(test['input_size'], test['hidden_size'], test['layers'], test['direction']) \ No newline at end of file + pdpd_rnn_lstm(test['input_size'], test['hidden_size'], + test['layers'], test['direction'], test['seq_len'])