Fix bidirectional mode in reference implementations of GRU/LSTM/RNN Sequences (#2264)

* fix the bidirectional case in the reference implementations of the sequence ops; enable decomposition of bidirectional cases in CommonOptimizations

* introduce new opset5, include GRU/RNN/LSTM Sequences to opset5

* Revert "introduce new opset5, include GRU/RNN/LSTM Sequences to opset5"

This reverts commit 73c22a11db.
Ivan Tikhonov 2020-09-18 10:14:01 +03:00 committed by GitHub
parent 93074590de
commit 1b7dfc6e4c
6 changed files with 119 additions and 86 deletions

@@ -19,6 +19,9 @@ ngraph::pass::BidirectionalLSTMSequenceDecomposition::BidirectionalLSTMSequenceD
             return false;
         }
 
+        if (lstm_sequence->get_direction() != ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
+            return false;
+
         auto axis_0 = ngraph::opset4::Constant::create(element::i64, Shape{}, {0});
         auto axis_1 = ngraph::opset4::Constant::create(element::i64, Shape{}, {1});
         auto H = std::make_shared<opset4::Split>(lstm_sequence->input_value(1), axis_1, 2);
@@ -84,6 +87,9 @@ ngraph::pass::BidirectionalGRUSequenceDecomposition::BidirectionalGRUSequenceDec
             return false;
         }
 
+        if (gru_sequence->get_direction() != ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
+            return false;
+
         auto axis_0 = ngraph::opset4::Constant::create(element::i64, Shape{}, {0});
         auto axis_1 = ngraph::opset4::Constant::create(element::i64, Shape{}, {1});
         auto H = std::make_shared<opset4::Split>(gru_sequence->input_value(1), axis_1, 2);
@@ -145,6 +151,9 @@ ngraph::pass::BidirectionalRNNSequenceDecomposition::BidirectionalRNNSequenceDec
             return false;
         }
 
+        if (rnn_sequence->get_direction() != ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
+            return false;
+
         auto axis_0 = ngraph::opset4::Constant::create(element::i64, Shape{}, {0});
         auto axis_1 = ngraph::opset4::Constant::create(element::i64, Shape{}, {1});
         auto H = std::make_shared<opset4::Split>(rnn_sequence->input_value(1), axis_1, 2);
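All three decomposition passes gain the same two-step prologue: bail out early when the matched sequence is not bidirectional, then split the initial hidden state into its forward and reverse halves. A minimal sketch of that shared pattern, shown for LSTM and assuming an ngraph build on the include path and `opset4::LSTMSequence` as the matched node type (the free-standing `starts_decomposition` wrapper is hypothetical, added only to make the fragment self-contained):

    #include <memory>
    #include <ngraph/opsets/opset4.hpp>

    // Hypothetical wrapper around the prologue the three passes share.
    bool starts_decomposition(const std::shared_ptr<ngraph::opset4::LSTMSequence>& lstm_sequence)
    {
        // New early exit: FORWARD/REVERSE sequences need no decomposition.
        if (lstm_sequence->get_direction() != ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
            return false;

        // num_directions sits on axis 1 of the initial hidden state (input 1),
        // so a two-way split separates the forward and reverse halves.
        auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1});
        auto H = std::make_shared<ngraph::opset4::Split>(lstm_sequence->input_value(1), axis_1, 2);
        // H->output(0) / H->output(1) would feed the per-direction ops (omitted here).
        return true;
    }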

@@ -21,6 +21,7 @@
 #include "transformations/hswish_fusion.hpp"
 #include "transformations/normalize_l2_fusion.hpp"
 #include "transformations/convert_quantize_dequantize.hpp"
+#include "transformations/bidirectional_sequences_decomposition.hpp"
 
 #include <ngraph/pass/manager.hpp>
 #include <ngraph/pass/constant_folding.hpp>
@@ -50,6 +51,9 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
     manager.register_pass<ngraph::pass::HSwishFusion>();
     manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution>();
     manager.register_pass<ngraph::pass::NormalizeL2Fusion>();
+    manager.register_pass<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
+    manager.register_pass<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
+    manager.register_pass<ngraph::pass::BidirectionalGRUSequenceDecomposition>();
 
     manager.set_callback(m_transformation_callback);
     manager.run_passes(f);
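With the three passes registered inside CommonOptimizations, any pipeline that runs this pass group decomposes bidirectional sequences automatically, which is why the three shared-test classes below drop their manual pass-manager blocks. A minimal usage sketch; the CommonOptimizations header path is an assumption based on the include style above:

    #include <memory>
    #include <ngraph/function.hpp>
    #include <ngraph/pass/manager.hpp>
    #include "transformations/common_optimizations.hpp" // assumed path

    // Running the group is enough; no per-caller registration of the
    // Bidirectional*SequenceDecomposition passes is needed any more.
    void optimize(const std::shared_ptr<ngraph::Function>& f)
    {
        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::CommonOptimizations>();
        manager.run_passes(f);
    }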

@@ -84,11 +84,6 @@ namespace LayerTestsDefinitions {
         ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(gru_sequence->output(0)),
                                      std::make_shared<ngraph::opset1::Result>(gru_sequence->output(1))};
         function = std::make_shared<ngraph::Function>(results, params, "gru_sequence");
-        if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) {
-            ngraph::pass::Manager m;
-            m.register_pass<ngraph::pass::BidirectionalGRUSequenceDecomposition>();
-            m.run_passes(function);
-        }
     }

@@ -82,11 +82,6 @@ namespace LayerTestsDefinitions {
                                      std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(1)),
                                      std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(2))};
         function = std::make_shared<ngraph::Function>(results, params, "lstm_sequence");
-        if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) {
-            ngraph::pass::Manager m;
-            m.register_pass<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
-            m.run_passes(function);
-        }
     }

@@ -82,11 +82,6 @@ namespace LayerTestsDefinitions {
         ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(rnn_sequence->output(0)),
                                      std::make_shared<ngraph::opset1::Result>(rnn_sequence->output(1))};
         function = std::make_shared<ngraph::Function>(results, params, "rnn_sequence");
-        if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) {
-            ngraph::pass::Manager m;
-            m.register_pass<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
-            m.run_passes(function);
-        }
     }

@@ -218,15 +218,15 @@ namespace ngraph
                     // Split bidirectional case to forward + reverse passes.
                     // split inputs
                     std::vector<std::vector<char>> H_split(
-                        2, std::vector<char>(ngraph::shape_size(H_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(H_shape) / 2));
                     std::vector<std::vector<char>> C_split(
-                        2, std::vector<char>(ngraph::shape_size(C_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(C_shape) / 2));
                     std::vector<std::vector<char>> W_split(
-                        2, std::vector<char>(ngraph::shape_size(W_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(W_shape) / 2));
                     std::vector<std::vector<char>> R_split(
-                        2, std::vector<char>(ngraph::shape_size(R_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(R_shape) / 2));
                     std::vector<std::vector<char>> B_split(
-                        2, std::vector<char>(ngraph::shape_size(B_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(B_shape) / 2));
                     char* h_pointers[2] = {H_split[0].data(), H_split[1].data()};
                     char* c_pointers[2] = {C_split[0].data(), C_split[1].data()};
                     char* w_pointers[2] = {W_split[0].data(), W_split[1].data()};
@@ -234,13 +234,17 @@ namespace ngraph
                     char* b_pointers[2] = {B_split[0].data(), B_split[1].data()};
                     reference::split(H, H_shape, sizeof(T), 1, 2, h_pointers);
                     reference::split(C, C_shape, sizeof(T), 1, 2, c_pointers);
-                    reference::split(W, W_shape, sizeof(T), 1, 2, w_pointers);
-                    reference::split(R, R_shape, sizeof(T), 1, 2, r_pointers);
-                    reference::split(B, B_shape, sizeof(T), 1, 2, b_pointers);
+                    reference::split(W, W_shape, sizeof(T), 0, 2, w_pointers);
+                    reference::split(R, R_shape, sizeof(T), 0, 2, r_pointers);
+                    reference::split(B, B_shape, sizeof(T), 0, 2, b_pointers);
+                    std::vector<char> forward_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+                                                    X_shape[1]);
+                    std::vector<char> reverse_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+                                                    X_shape[1]);
                     std::vector<std::vector<char>> forward_res(
-                        3, std::vector<char>(H_shape[0] * H_shape[2]));
+                        2, std::vector<char>(sizeof(T) * H_shape[0] * H_shape[2]));
                     std::vector<std::vector<char>> reverse_res(
-                        3, std::vector<char>(H_shape[0] * H_shape[2]));
+                        2, std::vector<char>(sizeof(T) * H_shape[0] * H_shape[2]));
 
                     CellArgs args;
                     args.activation_f = activation_f;
@@ -249,6 +253,13 @@ namespace ngraph
                     args.clip = clip;
                     std::vector<Shape> shapes = {
                         X_shape, seq_lengths_shape, H_shape, C_shape, W_shape, R_shape, B_shape};
+                    // update H,C,W,R,B shapes after split
+                    shapes[2][1] = 1;
+                    shapes[3][1] = 1;
+                    for (int i = 4; i < shapes.size(); ++i)
+                    {
+                        shapes[i][0] = 1;
+                    }
                     // forward pass
                     cell_pass<T>(
                         CellType::LSTM,
@@ -260,7 +271,7 @@ namespace ngraph
                          r_pointers[0],
                          b_pointers[0]},
                         shapes,
-                        {forward_res[0].data(), forward_res[1].data(), forward_res[2].data()},
+                        {forward_res_y.data(), forward_res[0].data(), forward_res[1].data()},
                         args,
                         false);
                     // reverse pass
@@ -274,32 +285,34 @@ namespace ngraph
                          r_pointers[1],
                          b_pointers[1]},
                         shapes,
-                        {reverse_res[0].data(), reverse_res[1].data(), reverse_res[2].data()},
+                        {reverse_res_y.data(), reverse_res[0].data(), reverse_res[1].data()},
                         args,
                         true);
 
                     // Stack together respective outputs from both forward and reverse passes.
-                    std::vector<Shape> in_shapes = {{H_shape[0], 1, H_shape[2]},
-                                                    {H_shape[0], 1, H_shape[2]},
-                                                    {H_shape[0], 1, H_shape[2]}};
-                    Shape output_shape = {{H_shape[0], 2, H_shape[2]}};
+                    std::vector<Shape> in_shapes_y = {{H_shape[0], 1, X_shape[1], H_shape[2]},
+                                                      {H_shape[0], 1, X_shape[1], H_shape[2]}};
+                    std::vector<Shape> in_shapes_h_c = {{H_shape[0], 1, H_shape[2]},
+                                                        {H_shape[0], 1, H_shape[2]}};
+                    Shape output_shape_y{H_shape[0], 2, X_shape[1], H_shape[2]};
+                    Shape output_shape_h_c{H_shape[0], 2, H_shape[2]};
 
-                    runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
+                    runtime::reference::concat({forward_res_y.data(), reverse_res_y.data()},
                                                Y,
-                                               in_shapes,
-                                               output_shape,
+                                               in_shapes_y,
+                                               output_shape_y,
                                                1,
                                                sizeof(T));
-                    runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()},
+                    runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
                                                Ho,
-                                               in_shapes,
-                                               output_shape,
+                                               in_shapes_h_c,
+                                               output_shape_h_c,
                                                1,
                                                sizeof(T));
-                    runtime::reference::concat({forward_res[2].data(), reverse_res[2].data()},
+                    runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()},
                                                Co,
-                                               in_shapes,
-                                               output_shape,
+                                               in_shapes_h_c,
+                                               output_shape_h_c,
                                                1,
                                                sizeof(T));
                 }
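Two separate bugs are fixed in the LSTM reference hunks above. First, the scratch buffers are std::vector<char>, so their sizes must be in bytes; the old code sized them in elements of T, under-allocating by a factor of sizeof(T). Second, W, R and B carry num_directions on axis 0 (unlike H and C, which carry it on axis 1), so they must be split along axis 0. A standalone sketch of the sizing issue, with illustrative numbers:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        using T = float;
        // Stand-in for ngraph::shape_size(H_shape) / 2: the element count
        // of one direction's half of the hidden state.
        const std::size_t elems = 10;

        std::vector<char> before(elems);            // 10 bytes: too small for 10 floats
        std::vector<char> after(sizeof(T) * elems); // 40 bytes: exactly 10 floats

        std::cout << before.size() << " vs " << after.size() << " bytes\n";
        return 0;
    }

Writing ten floats through a char* into the first buffer overruns it, which is what the forward and reverse passes did before this change.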
@@ -351,25 +364,27 @@ namespace ngraph
                     // Split bidirectional case to forward + reverse passes.
                     // split inputs
                     std::vector<std::vector<char>> H_split(
-                        2, std::vector<char>(ngraph::shape_size(H_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(H_shape) / 2));
                     std::vector<std::vector<char>> W_split(
-                        2, std::vector<char>(ngraph::shape_size(W_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(W_shape) / 2));
                     std::vector<std::vector<char>> R_split(
-                        2, std::vector<char>(ngraph::shape_size(R_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(R_shape) / 2));
                     std::vector<std::vector<char>> B_split(
-                        2, std::vector<char>(ngraph::shape_size(B_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(B_shape) / 2));
                     char* h_pointers[2] = {H_split[0].data(), H_split[1].data()};
                     char* w_pointers[2] = {W_split[0].data(), W_split[1].data()};
                     char* r_pointers[2] = {R_split[0].data(), R_split[1].data()};
                     char* b_pointers[2] = {B_split[0].data(), B_split[1].data()};
                     reference::split(H, H_shape, sizeof(T), 1, 2, h_pointers);
-                    reference::split(W, W_shape, sizeof(T), 1, 2, w_pointers);
-                    reference::split(R, R_shape, sizeof(T), 1, 2, r_pointers);
-                    reference::split(B, B_shape, sizeof(T), 1, 2, b_pointers);
-                    std::vector<std::vector<char>> forward_res(
-                        2, std::vector<char>(H_shape[0] * H_shape[2]));
-                    std::vector<std::vector<char>> reverse_res(
-                        2, std::vector<char>(H_shape[0] * H_shape[2]));
+                    reference::split(W, W_shape, sizeof(T), 0, 2, w_pointers);
+                    reference::split(R, R_shape, sizeof(T), 0, 2, r_pointers);
+                    reference::split(B, B_shape, sizeof(T), 0, 2, b_pointers);
+                    std::vector<char> forward_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+                                                    X_shape[1]);
+                    std::vector<char> forward_res_h(sizeof(T) * H_shape[0] * H_shape[2]);
+                    std::vector<char> reverse_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+                                                    X_shape[1]);
+                    std::vector<char> reverse_res_h(sizeof(T) * H_shape[0] * H_shape[2]);
 
                     CellArgs args;
                     args.activation_f = activation_f;
@@ -378,6 +393,12 @@ namespace ngraph
                     args.clip = clip;
                     std::vector<Shape> shapes = {
                         X_shape, seq_lengths_shape, H_shape, W_shape, R_shape, B_shape};
+                    // update H,W,R,B shapes after split
+                    shapes[2][1] = 1;
+                    for (int i = 3; i < shapes.size(); ++i)
+                    {
+                        shapes[i][0] = 1;
+                    }
                     // forward pass
                     cell_pass<T>(CellType::GRU,
                                  {X,
@@ -387,7 +408,7 @@ namespace ngraph
                                  r_pointers[0],
                                  b_pointers[0]},
                                 shapes,
-                                {forward_res[0].data(), forward_res[1].data()},
+                                {forward_res_y.data(), forward_res_h.data()},
                                 args,
                                 false);
                     // reverse pass
@@ -399,25 +420,28 @@ namespace ngraph
                                  r_pointers[1],
                                  b_pointers[1]},
                                 shapes,
-                                {reverse_res[0].data(), reverse_res[1].data()},
+                                {reverse_res_y.data(), reverse_res_h.data()},
                                 args,
                                 true);
 
                     // Stack together respective outputs from both forward and reverse passes.
-                    std::vector<Shape> in_shapes = {{H_shape[0], 1, H_shape[2]},
-                                                    {H_shape[0], 1, H_shape[2]}};
-                    Shape output_shape = {{H_shape[0], 2, H_shape[2]}};
+                    std::vector<Shape> in_shapes_y = {{H_shape[0], 1, X_shape[1], H_shape[2]},
+                                                      {H_shape[0], 1, X_shape[1], H_shape[2]}};
+                    std::vector<Shape> in_shapes_h = {{H_shape[0], 1, H_shape[2]},
+                                                      {H_shape[0], 1, H_shape[2]}};
+                    Shape output_shape_y{H_shape[0], 2, X_shape[1], H_shape[2]};
+                    Shape output_shape_h{H_shape[0], 2, H_shape[2]};
 
-                    runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
+                    runtime::reference::concat({forward_res_y.data(), reverse_res_y.data()},
                                                Y,
-                                               in_shapes,
-                                               output_shape,
+                                               in_shapes_y,
+                                               output_shape_y,
                                                1,
                                                sizeof(T));
-                    runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()},
+                    runtime::reference::concat({forward_res_h.data(), reverse_res_h.data()},
                                                Ho,
-                                               in_shapes,
-                                               output_shape,
+                                               in_shapes_h,
+                                               output_shape_h,
                                                1,
                                                sizeof(T));
                 }
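The GRU hunks repeat the byte-sizing and split-axis fixes and add the same shape bookkeeping: after the two-way split, each cell_pass sees a single direction, so the num_directions dimension in the shapes it receives must become 1. A standalone sketch of that fix-up (the Shape alias and function name are illustrative):

    #include <cstddef>
    #include <vector>

    using Shape = std::vector<std::size_t>;

    // shapes = {X, seq_lengths, H, W, R, B} in the GRU/RNN layout:
    // H is [batch, num_directions, hidden_size], while W, R and B
    // carry num_directions on axis 0.
    void to_single_direction(std::vector<Shape>& shapes)
    {
        shapes[2][1] = 1;                              // H: num_directions on axis 1
        for (std::size_t i = 3; i < shapes.size(); ++i)
            shapes[i][0] = 1;                          // W/R/B: num_directions on axis 0
    }

The LSTM variant above does the same and additionally sets shapes[3][1] = 1 for the cell state C.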
@@ -465,31 +489,39 @@ namespace ngraph
                     // Split bidirectional case to forward + reverse passes.
                     // split inputs
                     std::vector<std::vector<char>> H_split(
-                        2, std::vector<char>(ngraph::shape_size(H_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(H_shape) / 2));
                     std::vector<std::vector<char>> W_split(
-                        2, std::vector<char>(ngraph::shape_size(W_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(W_shape) / 2));
                     std::vector<std::vector<char>> R_split(
-                        2, std::vector<char>(ngraph::shape_size(R_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(R_shape) / 2));
                     std::vector<std::vector<char>> B_split(
-                        2, std::vector<char>(ngraph::shape_size(B_shape) / 2));
+                        2, std::vector<char>(sizeof(T) * ngraph::shape_size(B_shape) / 2));
                     char* h_pointers[2] = {H_split[0].data(), H_split[1].data()};
                     char* w_pointers[2] = {W_split[0].data(), W_split[1].data()};
                     char* r_pointers[2] = {R_split[0].data(), R_split[1].data()};
                     char* b_pointers[2] = {B_split[0].data(), B_split[1].data()};
                     reference::split(H, H_shape, sizeof(T), 1, 2, h_pointers);
-                    reference::split(W, W_shape, sizeof(T), 1, 2, w_pointers);
-                    reference::split(R, R_shape, sizeof(T), 1, 2, r_pointers);
-                    reference::split(B, B_shape, sizeof(T), 1, 2, b_pointers);
-                    std::vector<std::vector<char>> forward_res(
-                        2, std::vector<char>(H_shape[0] * H_shape[2]));
-                    std::vector<std::vector<char>> reverse_res(
-                        2, std::vector<char>(H_shape[0] * H_shape[2]));
+                    reference::split(W, W_shape, sizeof(T), 0, 2, w_pointers);
+                    reference::split(R, R_shape, sizeof(T), 0, 2, r_pointers);
+                    reference::split(B, B_shape, sizeof(T), 0, 2, b_pointers);
+                    std::vector<char> forward_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+                                                    X_shape[1]);
+                    std::vector<char> forward_res_h(sizeof(T) * H_shape[0] * H_shape[2]);
+                    std::vector<char> reverse_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
+                                                    X_shape[1]);
+                    std::vector<char> reverse_res_h(sizeof(T) * H_shape[0] * H_shape[2]);
 
                     CellArgs args;
                     args.activation_f = activation_f;
                     args.clip = clip;
                     std::vector<Shape> shapes = {
                         X_shape, seq_lengths_shape, H_shape, W_shape, R_shape, B_shape};
+                    // update H,W,R,B shapes after split
+                    shapes[2][1] = 1;
+                    for (int i = 3; i < shapes.size(); ++i)
+                    {
+                        shapes[i][0] = 1;
+                    }
                     // forward pass
                     cell_pass<T>(CellType::RNN,
                                  {X,
@@ -499,7 +531,7 @@ namespace ngraph
                                  r_pointers[0],
                                  b_pointers[0]},
                                 shapes,
-                                {forward_res[0].data(), forward_res[1].data()},
+                                {forward_res_y.data(), forward_res_h.data()},
                                 args,
                                 false);
                     // reverse pass
@@ -511,25 +543,28 @@ namespace ngraph
                                  r_pointers[1],
                                  b_pointers[1]},
                                 shapes,
-                                {reverse_res[0].data(), reverse_res[1].data()},
+                                {reverse_res_y.data(), reverse_res_h.data()},
                                 args,
                                 true);
 
                     // Stack together respective outputs from both forward and reverse passes.
-                    std::vector<Shape> in_shapes = {{H_shape[0], 1, H_shape[2]},
-                                                    {H_shape[0], 1, H_shape[2]}};
-                    Shape output_shape = {{H_shape[0], 2, H_shape[2]}};
+                    std::vector<Shape> in_shapes_y = {{H_shape[0], 1, X_shape[1], H_shape[2]},
+                                                      {H_shape[0], 1, X_shape[1], H_shape[2]}};
+                    std::vector<Shape> in_shapes_h = {{H_shape[0], 1, H_shape[2]},
+                                                      {H_shape[0], 1, H_shape[2]}};
+                    Shape output_shape_y{H_shape[0], 2, X_shape[1], H_shape[2]};
+                    Shape output_shape_h{H_shape[0], 2, H_shape[2]};
 
-                    runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
+                    runtime::reference::concat({forward_res_y.data(), reverse_res_y.data()},
                                                Y,
-                                               in_shapes,
-                                               output_shape,
+                                               in_shapes_y,
+                                               output_shape_y,
                                                1,
                                                sizeof(T));
-                    runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()},
+                    runtime::reference::concat({forward_res_h.data(), reverse_res_h.data()},
                                                Ho,
-                                               in_shapes,
-                                               output_shape,
+                                               in_shapes_h,
+                                               output_shape_h,
                                                1,
                                                sizeof(T));
                 }
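As in the GRU case, the RNN stacking now keeps separate shapes for the two outputs: per direction, Y is the 4-D [batch, 1, seq_len, hidden_size] while Ho is the 3-D [batch, 1, hidden_size], and each gets its own input/output shapes for the axis-1 concat. The old code reused the 3-D shapes for both, so only a 1/seq_len fraction of Y was ever assembled. A standalone check with illustrative sizes:

    #include <cstddef>
    #include <iostream>

    int main()
    {
        // Illustrative stand-ins for H_shape[0], X_shape[1], H_shape[2].
        const std::size_t batch = 2, seq_len = 5, hidden = 8;

        const std::size_t y_per_dir  = batch * 1 * seq_len * hidden; // 80 elements
        const std::size_t ho_per_dir = batch * 1 * hidden;           // 16 elements

        // Concatenating Y with Ho's 3-D shapes would touch only ho_per_dir
        // elements per direction, leaving the rest of Y unwritten.
        std::cout << "Y per direction: " << y_per_dir
                  << ", Ho per direction: " << ho_per_dir << '\n';
        return 0;
    }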