From 1b7dfc6e4c8efbac04a9d2dc57eca90641e387f1 Mon Sep 17 00:00:00 2001 From: Ivan Tikhonov Date: Fri, 18 Sep 2020 10:14:01 +0300 Subject: [PATCH] Fix bidirectional mode in reference implementations of GRU/LSTM/RNN Sequences (#2264) * fix bidirectional case in references of sequences ops, enable decomposition of bidirectional cases in CommonOptimizations * introduce new opset5, include GRU/RNN/LSTM Sequences to opset5 * Revert "introduce new opset5, include GRU/RNN/LSTM Sequences to opset5" This reverts commit 73c22a11dbd724d2cfa9212ff211db74ef09cf2a. --- .../bidirectional_sequences_decomposition.cpp | 9 + .../common_optimizations.cpp | 4 + .../src/single_layer_tests/gru_sequence.cpp | 5 - .../src/single_layer_tests/lstm_sequence.cpp | 5 - .../src/single_layer_tests/rnn_sequence.cpp | 5 - .../ngraph/runtime/reference/sequences.hpp | 177 +++++++++++------- 6 files changed, 119 insertions(+), 86 deletions(-) diff --git a/inference-engine/src/transformations/src/transformations/bidirectional_sequences_decomposition.cpp b/inference-engine/src/transformations/src/transformations/bidirectional_sequences_decomposition.cpp index 281d470fa14..1df3c61f7c1 100644 --- a/inference-engine/src/transformations/src/transformations/bidirectional_sequences_decomposition.cpp +++ b/inference-engine/src/transformations/src/transformations/bidirectional_sequences_decomposition.cpp @@ -19,6 +19,9 @@ ngraph::pass::BidirectionalLSTMSequenceDecomposition::BidirectionalLSTMSequenceD return false; } + if (lstm_sequence->get_direction() != ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) + return false; + auto axis_0 = ngraph::opset4::Constant::create(element::i64, Shape{}, {0}); auto axis_1 = ngraph::opset4::Constant::create(element::i64, Shape{}, {1}); auto H = std::make_shared(lstm_sequence->input_value(1), axis_1, 2); @@ -84,6 +87,9 @@ ngraph::pass::BidirectionalGRUSequenceDecomposition::BidirectionalGRUSequenceDec return false; } + if (gru_sequence->get_direction() != ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) + return false; + auto axis_0 = ngraph::opset4::Constant::create(element::i64, Shape{}, {0}); auto axis_1 = ngraph::opset4::Constant::create(element::i64, Shape{}, {1}); auto H = std::make_shared(gru_sequence->input_value(1), axis_1, 2); @@ -145,6 +151,9 @@ ngraph::pass::BidirectionalRNNSequenceDecomposition::BidirectionalRNNSequenceDec return false; } + if (rnn_sequence->get_direction() != ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) + return false; + auto axis_0 = ngraph::opset4::Constant::create(element::i64, Shape{}, {0}); auto axis_1 = ngraph::opset4::Constant::create(element::i64, Shape{}, {1}); auto H = std::make_shared(rnn_sequence->input_value(1), axis_1, 2); diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp index 4ec3c63a867..e45776a9c78 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -21,6 +21,7 @@ #include "transformations/hswish_fusion.hpp" #include "transformations/normalize_l2_fusion.hpp" #include "transformations/convert_quantize_dequantize.hpp" +#include "transformations/bidirectional_sequences_decomposition.hpp" #include #include @@ -50,6 +51,9 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); + manager.register_pass(); + manager.register_pass(); manager.set_callback(m_transformation_callback); manager.run_passes(f); diff --git a/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/gru_sequence.cpp b/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/gru_sequence.cpp index 1327831a457..f1d8afee399 100644 --- a/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/gru_sequence.cpp +++ b/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/gru_sequence.cpp @@ -84,11 +84,6 @@ namespace LayerTestsDefinitions { ngraph::ResultVector results{std::make_shared(gru_sequence->output(0)), std::make_shared(gru_sequence->output(1))}; function = std::make_shared(results, params, "gru_sequence"); - if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) { - ngraph::pass::Manager m; - m.register_pass(); - m.run_passes(function); - } } diff --git a/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/lstm_sequence.cpp b/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/lstm_sequence.cpp index b1edaa9b0a4..d910194540b 100644 --- a/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/lstm_sequence.cpp +++ b/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/lstm_sequence.cpp @@ -82,11 +82,6 @@ namespace LayerTestsDefinitions { std::make_shared(lstm_sequence->output(1)), std::make_shared(lstm_sequence->output(2))}; function = std::make_shared(results, params, "lstm_sequence"); - if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) { - ngraph::pass::Manager m; - m.register_pass(); - m.run_passes(function); - } } diff --git a/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/rnn_sequence.cpp b/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/rnn_sequence.cpp index 63f9e850026..90ac191c738 100644 --- a/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/rnn_sequence.cpp +++ b/inference-engine/tests/functional/plugin/shared/src/single_layer_tests/rnn_sequence.cpp @@ -82,11 +82,6 @@ namespace LayerTestsDefinitions { ngraph::ResultVector results{std::make_shared(rnn_sequence->output(0)), std::make_shared(rnn_sequence->output(1))}; function = std::make_shared(results, params, "rnn_sequence"); - if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) { - ngraph::pass::Manager m; - m.register_pass(); - m.run_passes(function); - } } diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/sequences.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/sequences.hpp index e236bbdb574..894f1c39d07 100644 --- a/ngraph/core/reference/include/ngraph/runtime/reference/sequences.hpp +++ b/ngraph/core/reference/include/ngraph/runtime/reference/sequences.hpp @@ -218,15 +218,15 @@ namespace ngraph // Split bidirectional case to forward + reverse passes. // split inputs std::vector> H_split( - 2, std::vector(ngraph::shape_size(H_shape) / 2)); + 2, std::vector(sizeof(T) * ngraph::shape_size(H_shape) / 2)); std::vector> C_split( - 2, std::vector(ngraph::shape_size(C_shape) / 2)); + 2, std::vector(sizeof(T) * ngraph::shape_size(C_shape) / 2)); std::vector> W_split( - 2, std::vector(ngraph::shape_size(W_shape) / 2)); + 2, std::vector(sizeof(T) * ngraph::shape_size(W_shape) / 2)); std::vector> R_split( - 2, std::vector(ngraph::shape_size(R_shape) / 2)); + 2, std::vector(sizeof(T) * ngraph::shape_size(R_shape) / 2)); std::vector> B_split( - 2, std::vector(ngraph::shape_size(B_shape) / 2)); + 2, std::vector(sizeof(T) * ngraph::shape_size(B_shape) / 2)); char* h_pointers[2] = {H_split[0].data(), H_split[1].data()}; char* c_pointers[2] = {C_split[0].data(), C_split[1].data()}; char* w_pointers[2] = {W_split[0].data(), W_split[1].data()}; @@ -234,13 +234,17 @@ namespace ngraph char* b_pointers[2] = {B_split[0].data(), B_split[1].data()}; reference::split(H, H_shape, sizeof(T), 1, 2, h_pointers); reference::split(C, C_shape, sizeof(T), 1, 2, c_pointers); - reference::split(W, W_shape, sizeof(T), 1, 2, w_pointers); - reference::split(R, R_shape, sizeof(T), 1, 2, r_pointers); - reference::split(B, B_shape, sizeof(T), 1, 2, b_pointers); + reference::split(W, W_shape, sizeof(T), 0, 2, w_pointers); + reference::split(R, R_shape, sizeof(T), 0, 2, r_pointers); + reference::split(B, B_shape, sizeof(T), 0, 2, b_pointers); + std::vector forward_res_y(sizeof(T) * H_shape[0] * H_shape[2] * + X_shape[1]); + std::vector reverse_res_y(sizeof(T) * H_shape[0] * H_shape[2] * + X_shape[1]); std::vector> forward_res( - 3, std::vector(H_shape[0] * H_shape[2])); + 2, std::vector(sizeof(T) * H_shape[0] * H_shape[2])); std::vector> reverse_res( - 3, std::vector(H_shape[0] * H_shape[2])); + 2, std::vector(sizeof(T) * H_shape[0] * H_shape[2])); CellArgs args; args.activation_f = activation_f; @@ -249,6 +253,13 @@ namespace ngraph args.clip = clip; std::vector shapes = { X_shape, seq_lengths_shape, H_shape, C_shape, W_shape, R_shape, B_shape}; + // update H,C,W,R,B shapes after split + shapes[2][1] = 1; + shapes[3][1] = 1; + for (int i = 4; i < shapes.size(); ++i) + { + shapes[i][0] = 1; + } // forward pass cell_pass( CellType::LSTM, @@ -260,7 +271,7 @@ namespace ngraph r_pointers[0], b_pointers[0]}, shapes, - {forward_res[0].data(), forward_res[1].data(), forward_res[2].data()}, + {forward_res_y.data(), forward_res[0].data(), forward_res[1].data()}, args, false); // reverse pass @@ -274,32 +285,34 @@ namespace ngraph r_pointers[1], b_pointers[1]}, shapes, - {reverse_res[0].data(), reverse_res[1].data(), reverse_res[2].data()}, + {reverse_res_y.data(), reverse_res[0].data(), reverse_res[1].data()}, args, true); // Stack together respective outputs from both forward and reverse passes. - std::vector in_shapes = {{H_shape[0], 1, H_shape[2]}, - {H_shape[0], 1, H_shape[2]}, - {H_shape[0], 1, H_shape[2]}}; - Shape output_shape = {{H_shape[0], 2, H_shape[2]}}; + std::vector in_shapes_y = {{H_shape[0], 1, X_shape[1], H_shape[2]}, + {H_shape[0], 1, X_shape[1], H_shape[2]}}; + std::vector in_shapes_h_c = {{H_shape[0], 1, H_shape[2]}, + {H_shape[0], 1, H_shape[2]}}; + Shape output_shape_y{H_shape[0], 2, X_shape[1], H_shape[2]}; + Shape output_shape_h_c{H_shape[0], 2, H_shape[2]}; - runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()}, + runtime::reference::concat({forward_res_y.data(), reverse_res_y.data()}, Y, - in_shapes, - output_shape, + in_shapes_y, + output_shape_y, + 1, + sizeof(T)); + runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()}, + Ho, + in_shapes_h_c, + output_shape_h_c, 1, sizeof(T)); runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()}, - Ho, - in_shapes, - output_shape, - 1, - sizeof(T)); - runtime::reference::concat({forward_res[2].data(), reverse_res[2].data()}, Co, - in_shapes, - output_shape, + in_shapes_h_c, + output_shape_h_c, 1, sizeof(T)); } @@ -351,25 +364,27 @@ namespace ngraph // Split bidirectional case to forward + reverse passes. // split inputs std::vector> H_split( - 2, std::vector(ngraph::shape_size(H_shape) / 2)); + 2, std::vector(sizeof(T) * ngraph::shape_size(H_shape) / 2)); std::vector> W_split( - 2, std::vector(ngraph::shape_size(W_shape) / 2)); + 2, std::vector(sizeof(T) * ngraph::shape_size(W_shape) / 2)); std::vector> R_split( - 2, std::vector(ngraph::shape_size(R_shape) / 2)); + 2, std::vector(sizeof(T) * ngraph::shape_size(R_shape) / 2)); std::vector> B_split( - 2, std::vector(ngraph::shape_size(B_shape) / 2)); + 2, std::vector(sizeof(T) * ngraph::shape_size(B_shape) / 2)); char* h_pointers[2] = {H_split[0].data(), H_split[1].data()}; char* w_pointers[2] = {W_split[0].data(), W_split[1].data()}; char* r_pointers[2] = {R_split[0].data(), R_split[1].data()}; char* b_pointers[2] = {B_split[0].data(), B_split[1].data()}; reference::split(H, H_shape, sizeof(T), 1, 2, h_pointers); - reference::split(W, W_shape, sizeof(T), 1, 2, w_pointers); - reference::split(R, R_shape, sizeof(T), 1, 2, r_pointers); - reference::split(B, B_shape, sizeof(T), 1, 2, b_pointers); - std::vector> forward_res( - 2, std::vector(H_shape[0] * H_shape[2])); - std::vector> reverse_res( - 2, std::vector(H_shape[0] * H_shape[2])); + reference::split(W, W_shape, sizeof(T), 0, 2, w_pointers); + reference::split(R, R_shape, sizeof(T), 0, 2, r_pointers); + reference::split(B, B_shape, sizeof(T), 0, 2, b_pointers); + std::vector forward_res_y(sizeof(T) * H_shape[0] * H_shape[2] * + X_shape[1]); + std::vector forward_res_h(sizeof(T) * H_shape[0] * H_shape[2]); + std::vector reverse_res_y(sizeof(T) * H_shape[0] * H_shape[2] * + X_shape[1]); + std::vector reverse_res_h(sizeof(T) * H_shape[0] * H_shape[2]); CellArgs args; args.activation_f = activation_f; @@ -378,6 +393,12 @@ namespace ngraph args.clip = clip; std::vector shapes = { X_shape, seq_lengths_shape, H_shape, W_shape, R_shape, B_shape}; + // update H,W,R,B shapes after split + shapes[2][1] = 1; + for (int i = 3; i < shapes.size(); ++i) + { + shapes[i][0] = 1; + } // forward pass cell_pass(CellType::GRU, {X, @@ -387,7 +408,7 @@ namespace ngraph r_pointers[0], b_pointers[0]}, shapes, - {forward_res[0].data(), forward_res[1].data()}, + {forward_res_y.data(), forward_res_h.data()}, args, false); // reverse pass @@ -399,25 +420,28 @@ namespace ngraph r_pointers[1], b_pointers[1]}, shapes, - {reverse_res[0].data(), reverse_res[1].data()}, + {reverse_res_y.data(), reverse_res_h.data()}, args, true); // Stack together respective outputs from both forward and reverse passes. - std::vector in_shapes = {{H_shape[0], 1, H_shape[2]}, - {H_shape[0], 1, H_shape[2]}}; - Shape output_shape = {{H_shape[0], 2, H_shape[2]}}; + std::vector in_shapes_y = {{H_shape[0], 1, X_shape[1], H_shape[2]}, + {H_shape[0], 1, X_shape[1], H_shape[2]}}; + std::vector in_shapes_h = {{H_shape[0], 1, H_shape[2]}, + {H_shape[0], 1, H_shape[2]}}; + Shape output_shape_y{H_shape[0], 2, X_shape[1], H_shape[2]}; + Shape output_shape_h{H_shape[0], 2, H_shape[2]}; - runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()}, + runtime::reference::concat({forward_res_y.data(), reverse_res_y.data()}, Y, - in_shapes, - output_shape, + in_shapes_y, + output_shape_y, 1, sizeof(T)); - runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()}, + runtime::reference::concat({forward_res_h.data(), reverse_res_h.data()}, Ho, - in_shapes, - output_shape, + in_shapes_h, + output_shape_h, 1, sizeof(T)); } @@ -465,31 +489,39 @@ namespace ngraph // Split bidirectional case to forward + reverse passes. // split inputs std::vector> H_split( - 2, std::vector(ngraph::shape_size(H_shape) / 2)); + 2, std::vector(sizeof(T) * ngraph::shape_size(H_shape) / 2)); std::vector> W_split( - 2, std::vector(ngraph::shape_size(W_shape) / 2)); + 2, std::vector(sizeof(T) * ngraph::shape_size(W_shape) / 2)); std::vector> R_split( - 2, std::vector(ngraph::shape_size(R_shape) / 2)); + 2, std::vector(sizeof(T) * ngraph::shape_size(R_shape) / 2)); std::vector> B_split( - 2, std::vector(ngraph::shape_size(B_shape) / 2)); + 2, std::vector(sizeof(T) * ngraph::shape_size(B_shape) / 2)); char* h_pointers[2] = {H_split[0].data(), H_split[1].data()}; char* w_pointers[2] = {W_split[0].data(), W_split[1].data()}; char* r_pointers[2] = {R_split[0].data(), R_split[1].data()}; char* b_pointers[2] = {B_split[0].data(), B_split[1].data()}; reference::split(H, H_shape, sizeof(T), 1, 2, h_pointers); - reference::split(W, W_shape, sizeof(T), 1, 2, w_pointers); - reference::split(R, R_shape, sizeof(T), 1, 2, r_pointers); - reference::split(B, B_shape, sizeof(T), 1, 2, b_pointers); - std::vector> forward_res( - 2, std::vector(H_shape[0] * H_shape[2])); - std::vector> reverse_res( - 2, std::vector(H_shape[0] * H_shape[2])); + reference::split(W, W_shape, sizeof(T), 0, 2, w_pointers); + reference::split(R, R_shape, sizeof(T), 0, 2, r_pointers); + reference::split(B, B_shape, sizeof(T), 0, 2, b_pointers); + std::vector forward_res_y(sizeof(T) * H_shape[0] * H_shape[2] * + X_shape[1]); + std::vector forward_res_h(sizeof(T) * H_shape[0] * H_shape[2]); + std::vector reverse_res_y(sizeof(T) * H_shape[0] * H_shape[2] * + X_shape[1]); + std::vector reverse_res_h(sizeof(T) * H_shape[0] * H_shape[2]); CellArgs args; args.activation_f = activation_f; args.clip = clip; std::vector shapes = { X_shape, seq_lengths_shape, H_shape, W_shape, R_shape, B_shape}; + // update H,W,R,B shapes after split + shapes[2][1] = 1; + for (int i = 3; i < shapes.size(); ++i) + { + shapes[i][0] = 1; + } // forward pass cell_pass(CellType::RNN, {X, @@ -499,7 +531,7 @@ namespace ngraph r_pointers[0], b_pointers[0]}, shapes, - {forward_res[0].data(), forward_res[1].data()}, + {forward_res_y.data(), forward_res_h.data()}, args, false); // reverse pass @@ -511,25 +543,28 @@ namespace ngraph r_pointers[1], b_pointers[1]}, shapes, - {reverse_res[0].data(), reverse_res[1].data()}, + {reverse_res_y.data(), reverse_res_h.data()}, args, true); // Stack together respective outputs from both forward and reverse passes. - std::vector in_shapes = {{H_shape[0], 1, H_shape[2]}, - {H_shape[0], 1, H_shape[2]}}; - Shape output_shape = {{H_shape[0], 2, H_shape[2]}}; + std::vector in_shapes_y = {{H_shape[0], 1, X_shape[1], H_shape[2]}, + {H_shape[0], 1, X_shape[1], H_shape[2]}}; + std::vector in_shapes_h = {{H_shape[0], 1, H_shape[2]}, + {H_shape[0], 1, H_shape[2]}}; + Shape output_shape_y{H_shape[0], 2, X_shape[1], H_shape[2]}; + Shape output_shape_h{H_shape[0], 2, H_shape[2]}; - runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()}, + runtime::reference::concat({forward_res_y.data(), reverse_res_y.data()}, Y, - in_shapes, - output_shape, + in_shapes_y, + output_shape_y, 1, sizeof(T)); - runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()}, + runtime::reference::concat({forward_res_h.data(), reverse_res_h.data()}, Ho, - in_shapes, - output_shape, + in_shapes_h, + output_shape_h, 1, sizeof(T)); }