Fix bidirectional mode in reference implementations of GRU/LSTM/RNN Sequences (#2264)
* fix the bidirectional case in the reference implementations of the sequence ops; enable decomposition of bidirectional cases in CommonOptimizations
* introduce new opset5, include GRU/RNN/LSTM Sequences to opset5
* Revert "introduce new opset5, include GRU/RNN/LSTM Sequences to opset5"
This reverts commit 73c22a11db.
This commit is contained in:
parent
93074590de
commit
1b7dfc6e4c
@ -19,6 +19,9 @@ ngraph::pass::BidirectionalLSTMSequenceDecomposition::BidirectionalLSTMSequenceD
|
||||
return false;
|
||||
}
|
||||
|
||||
if (lstm_sequence->get_direction() != ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
|
||||
return false;
|
||||
|
||||
auto axis_0 = ngraph::opset4::Constant::create(element::i64, Shape{}, {0});
|
||||
auto axis_1 = ngraph::opset4::Constant::create(element::i64, Shape{}, {1});
|
||||
auto H = std::make_shared<opset4::Split>(lstm_sequence->input_value(1), axis_1, 2);
|
||||
@ -84,6 +87,9 @@ ngraph::pass::BidirectionalGRUSequenceDecomposition::BidirectionalGRUSequenceDec
|
||||
return false;
|
||||
}
|
||||
|
||||
if (gru_sequence->get_direction() != ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
|
||||
return false;
|
||||
|
||||
auto axis_0 = ngraph::opset4::Constant::create(element::i64, Shape{}, {0});
|
||||
auto axis_1 = ngraph::opset4::Constant::create(element::i64, Shape{}, {1});
|
||||
auto H = std::make_shared<opset4::Split>(gru_sequence->input_value(1), axis_1, 2);
|
||||
@ -145,6 +151,9 @@ ngraph::pass::BidirectionalRNNSequenceDecomposition::BidirectionalRNNSequenceDec
|
||||
return false;
|
||||
}
|
||||
|
||||
if (rnn_sequence->get_direction() != ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
|
||||
return false;
|
||||
|
||||
auto axis_0 = ngraph::opset4::Constant::create(element::i64, Shape{}, {0});
|
||||
auto axis_1 = ngraph::opset4::Constant::create(element::i64, Shape{}, {1});
|
||||
auto H = std::make_shared<opset4::Split>(rnn_sequence->input_value(1), axis_1, 2);
|
||||
|
@ -21,6 +21,7 @@
|
||||
#include "transformations/hswish_fusion.hpp"
|
||||
#include "transformations/normalize_l2_fusion.hpp"
|
||||
#include "transformations/convert_quantize_dequantize.hpp"
|
||||
#include "transformations/bidirectional_sequences_decomposition.hpp"
|
||||
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
@ -50,6 +51,9 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
manager.register_pass<ngraph::pass::HSwishFusion>();
|
||||
manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution>();
|
||||
manager.register_pass<ngraph::pass::NormalizeL2Fusion>();
|
||||
manager.register_pass<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
|
||||
manager.register_pass<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
|
||||
manager.register_pass<ngraph::pass::BidirectionalGRUSequenceDecomposition>();
|
||||
|
||||
manager.set_callback(m_transformation_callback);
|
||||
manager.run_passes(f);
|
||||
|
@ -84,11 +84,6 @@ namespace LayerTestsDefinitions {
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(gru_sequence->output(0)),
|
||||
std::make_shared<ngraph::opset1::Result>(gru_sequence->output(1))};
|
||||
function = std::make_shared<ngraph::Function>(results, params, "gru_sequence");
|
||||
if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) {
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::BidirectionalGRUSequenceDecomposition>();
|
||||
m.run_passes(function);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -82,11 +82,6 @@ namespace LayerTestsDefinitions {
|
||||
std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(1)),
|
||||
std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(2))};
|
||||
function = std::make_shared<ngraph::Function>(results, params, "lstm_sequence");
|
||||
if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) {
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
|
||||
m.run_passes(function);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -82,11 +82,6 @@ namespace LayerTestsDefinitions {
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(rnn_sequence->output(0)),
|
||||
std::make_shared<ngraph::opset1::Result>(rnn_sequence->output(1))};
|
||||
function = std::make_shared<ngraph::Function>(results, params, "rnn_sequence");
|
||||
if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) {
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
|
||||
m.run_passes(function);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -218,15 +218,15 @@ namespace ngraph
|
||||
// Split bidirectional case to forward + reverse passes.
|
||||
// split inputs
|
||||
std::vector<std::vector<char>> H_split(
|
||||
2, std::vector<char>(ngraph::shape_size(H_shape) / 2));
|
||||
2, std::vector<char>(sizeof(T) * ngraph::shape_size(H_shape) / 2));
|
||||
std::vector<std::vector<char>> C_split(
|
||||
2, std::vector<char>(ngraph::shape_size(C_shape) / 2));
|
||||
2, std::vector<char>(sizeof(T) * ngraph::shape_size(C_shape) / 2));
|
||||
std::vector<std::vector<char>> W_split(
|
||||
2, std::vector<char>(ngraph::shape_size(W_shape) / 2));
|
||||
2, std::vector<char>(sizeof(T) * ngraph::shape_size(W_shape) / 2));
|
||||
std::vector<std::vector<char>> R_split(
|
||||
2, std::vector<char>(ngraph::shape_size(R_shape) / 2));
|
||||
2, std::vector<char>(sizeof(T) * ngraph::shape_size(R_shape) / 2));
|
||||
std::vector<std::vector<char>> B_split(
|
||||
2, std::vector<char>(ngraph::shape_size(B_shape) / 2));
|
||||
2, std::vector<char>(sizeof(T) * ngraph::shape_size(B_shape) / 2));
|
||||
char* h_pointers[2] = {H_split[0].data(), H_split[1].data()};
|
||||
char* c_pointers[2] = {C_split[0].data(), C_split[1].data()};
|
||||
char* w_pointers[2] = {W_split[0].data(), W_split[1].data()};
|
||||
@ -234,13 +234,17 @@ namespace ngraph
|
||||
char* b_pointers[2] = {B_split[0].data(), B_split[1].data()};
|
||||
reference::split(H, H_shape, sizeof(T), 1, 2, h_pointers);
|
||||
reference::split(C, C_shape, sizeof(T), 1, 2, c_pointers);
|
||||
reference::split(W, W_shape, sizeof(T), 1, 2, w_pointers);
|
||||
reference::split(R, R_shape, sizeof(T), 1, 2, r_pointers);
|
||||
reference::split(B, B_shape, sizeof(T), 1, 2, b_pointers);
|
||||
reference::split(W, W_shape, sizeof(T), 0, 2, w_pointers);
|
||||
reference::split(R, R_shape, sizeof(T), 0, 2, r_pointers);
|
||||
reference::split(B, B_shape, sizeof(T), 0, 2, b_pointers);
|
||||
std::vector<char> forward_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
|
||||
X_shape[1]);
|
||||
std::vector<char> reverse_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
|
||||
X_shape[1]);
|
||||
std::vector<std::vector<char>> forward_res(
|
||||
3, std::vector<char>(H_shape[0] * H_shape[2]));
|
||||
2, std::vector<char>(sizeof(T) * H_shape[0] * H_shape[2]));
|
||||
std::vector<std::vector<char>> reverse_res(
|
||||
3, std::vector<char>(H_shape[0] * H_shape[2]));
|
||||
2, std::vector<char>(sizeof(T) * H_shape[0] * H_shape[2]));
|
||||
|
||||
CellArgs args;
|
||||
args.activation_f = activation_f;
|
||||
@ -249,6 +253,13 @@ namespace ngraph
|
||||
args.clip = clip;
|
||||
std::vector<Shape> shapes = {
|
||||
X_shape, seq_lengths_shape, H_shape, C_shape, W_shape, R_shape, B_shape};
|
||||
// update H,C,W,R,B shapes after split
|
||||
shapes[2][1] = 1;
|
||||
shapes[3][1] = 1;
|
||||
for (int i = 4; i < shapes.size(); ++i)
|
||||
{
|
||||
shapes[i][0] = 1;
|
||||
}
|
||||
// forward pass
|
||||
cell_pass<T>(
|
||||
CellType::LSTM,
|
||||
@ -260,7 +271,7 @@ namespace ngraph
|
||||
r_pointers[0],
|
||||
b_pointers[0]},
|
||||
shapes,
|
||||
{forward_res[0].data(), forward_res[1].data(), forward_res[2].data()},
|
||||
{forward_res_y.data(), forward_res[0].data(), forward_res[1].data()},
|
||||
args,
|
||||
false);
|
||||
// reverse pass
|
||||
@ -274,32 +285,34 @@ namespace ngraph
|
||||
r_pointers[1],
|
||||
b_pointers[1]},
|
||||
shapes,
|
||||
{reverse_res[0].data(), reverse_res[1].data(), reverse_res[2].data()},
|
||||
{reverse_res_y.data(), reverse_res[0].data(), reverse_res[1].data()},
|
||||
args,
|
||||
true);
|
||||
|
||||
// Stack together respective outputs from both forward and reverse passes.
|
||||
std::vector<Shape> in_shapes = {{H_shape[0], 1, H_shape[2]},
|
||||
{H_shape[0], 1, H_shape[2]},
|
||||
{H_shape[0], 1, H_shape[2]}};
|
||||
Shape output_shape = {{H_shape[0], 2, H_shape[2]}};
|
||||
std::vector<Shape> in_shapes_y = {{H_shape[0], 1, X_shape[1], H_shape[2]},
|
||||
{H_shape[0], 1, X_shape[1], H_shape[2]}};
|
||||
std::vector<Shape> in_shapes_h_c = {{H_shape[0], 1, H_shape[2]},
|
||||
{H_shape[0], 1, H_shape[2]}};
|
||||
Shape output_shape_y{H_shape[0], 2, X_shape[1], H_shape[2]};
|
||||
Shape output_shape_h_c{H_shape[0], 2, H_shape[2]};
|
||||
|
||||
runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
|
||||
runtime::reference::concat({forward_res_y.data(), reverse_res_y.data()},
|
||||
Y,
|
||||
in_shapes,
|
||||
output_shape,
|
||||
in_shapes_y,
|
||||
output_shape_y,
|
||||
1,
|
||||
sizeof(T));
|
||||
runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
|
||||
Ho,
|
||||
in_shapes_h_c,
|
||||
output_shape_h_c,
|
||||
1,
|
||||
sizeof(T));
|
||||
runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()},
|
||||
Ho,
|
||||
in_shapes,
|
||||
output_shape,
|
||||
1,
|
||||
sizeof(T));
|
||||
runtime::reference::concat({forward_res[2].data(), reverse_res[2].data()},
|
||||
Co,
|
||||
in_shapes,
|
||||
output_shape,
|
||||
in_shapes_h_c,
|
||||
output_shape_h_c,
|
||||
1,
|
||||
sizeof(T));
|
||||
}
|
||||
@ -351,25 +364,27 @@ namespace ngraph
|
||||
// Split bidirectional case to forward + reverse passes.
|
||||
// split inputs
|
||||
std::vector<std::vector<char>> H_split(
|
||||
2, std::vector<char>(ngraph::shape_size(H_shape) / 2));
|
||||
2, std::vector<char>(sizeof(T) * ngraph::shape_size(H_shape) / 2));
|
||||
std::vector<std::vector<char>> W_split(
|
||||
2, std::vector<char>(ngraph::shape_size(W_shape) / 2));
|
||||
2, std::vector<char>(sizeof(T) * ngraph::shape_size(W_shape) / 2));
|
||||
std::vector<std::vector<char>> R_split(
|
||||
2, std::vector<char>(ngraph::shape_size(R_shape) / 2));
|
||||
2, std::vector<char>(sizeof(T) * ngraph::shape_size(R_shape) / 2));
|
||||
std::vector<std::vector<char>> B_split(
|
||||
2, std::vector<char>(ngraph::shape_size(B_shape) / 2));
|
||||
2, std::vector<char>(sizeof(T) * ngraph::shape_size(B_shape) / 2));
|
||||
char* h_pointers[2] = {H_split[0].data(), H_split[1].data()};
|
||||
char* w_pointers[2] = {W_split[0].data(), W_split[1].data()};
|
||||
char* r_pointers[2] = {R_split[0].data(), R_split[1].data()};
|
||||
char* b_pointers[2] = {B_split[0].data(), B_split[1].data()};
|
||||
reference::split(H, H_shape, sizeof(T), 1, 2, h_pointers);
|
||||
reference::split(W, W_shape, sizeof(T), 1, 2, w_pointers);
|
||||
reference::split(R, R_shape, sizeof(T), 1, 2, r_pointers);
|
||||
reference::split(B, B_shape, sizeof(T), 1, 2, b_pointers);
|
||||
std::vector<std::vector<char>> forward_res(
|
||||
2, std::vector<char>(H_shape[0] * H_shape[2]));
|
||||
std::vector<std::vector<char>> reverse_res(
|
||||
2, std::vector<char>(H_shape[0] * H_shape[2]));
|
||||
reference::split(W, W_shape, sizeof(T), 0, 2, w_pointers);
|
||||
reference::split(R, R_shape, sizeof(T), 0, 2, r_pointers);
|
||||
reference::split(B, B_shape, sizeof(T), 0, 2, b_pointers);
|
||||
std::vector<char> forward_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
|
||||
X_shape[1]);
|
||||
std::vector<char> forward_res_h(sizeof(T) * H_shape[0] * H_shape[2]);
|
||||
std::vector<char> reverse_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
|
||||
X_shape[1]);
|
||||
std::vector<char> reverse_res_h(sizeof(T) * H_shape[0] * H_shape[2]);
|
||||
|
||||
CellArgs args;
|
||||
args.activation_f = activation_f;
|
||||
@ -378,6 +393,12 @@ namespace ngraph
|
||||
args.clip = clip;
|
||||
std::vector<Shape> shapes = {
|
||||
X_shape, seq_lengths_shape, H_shape, W_shape, R_shape, B_shape};
|
||||
// update H,W,R,B shapes after split
|
||||
shapes[2][1] = 1;
|
||||
for (int i = 3; i < shapes.size(); ++i)
|
||||
{
|
||||
shapes[i][0] = 1;
|
||||
}
|
||||
// forward pass
|
||||
cell_pass<T>(CellType::GRU,
|
||||
{X,
|
||||
@ -387,7 +408,7 @@ namespace ngraph
|
||||
r_pointers[0],
|
||||
b_pointers[0]},
|
||||
shapes,
|
||||
{forward_res[0].data(), forward_res[1].data()},
|
||||
{forward_res_y.data(), forward_res_h.data()},
|
||||
args,
|
||||
false);
|
||||
// reverse pass
|
||||
@ -399,25 +420,28 @@ namespace ngraph
|
||||
r_pointers[1],
|
||||
b_pointers[1]},
|
||||
shapes,
|
||||
{reverse_res[0].data(), reverse_res[1].data()},
|
||||
{reverse_res_y.data(), reverse_res_h.data()},
|
||||
args,
|
||||
true);
|
||||
|
||||
// Stack together respective outputs from both forward and reverse passes.
|
||||
std::vector<Shape> in_shapes = {{H_shape[0], 1, H_shape[2]},
|
||||
{H_shape[0], 1, H_shape[2]}};
|
||||
Shape output_shape = {{H_shape[0], 2, H_shape[2]}};
|
||||
std::vector<Shape> in_shapes_y = {{H_shape[0], 1, X_shape[1], H_shape[2]},
|
||||
{H_shape[0], 1, X_shape[1], H_shape[2]}};
|
||||
std::vector<Shape> in_shapes_h = {{H_shape[0], 1, H_shape[2]},
|
||||
{H_shape[0], 1, H_shape[2]}};
|
||||
Shape output_shape_y{H_shape[0], 2, X_shape[1], H_shape[2]};
|
||||
Shape output_shape_h{H_shape[0], 2, H_shape[2]};
|
||||
|
||||
runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
|
||||
runtime::reference::concat({forward_res_y.data(), reverse_res_y.data()},
|
||||
Y,
|
||||
in_shapes,
|
||||
output_shape,
|
||||
in_shapes_y,
|
||||
output_shape_y,
|
||||
1,
|
||||
sizeof(T));
|
||||
runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()},
|
||||
runtime::reference::concat({forward_res_h.data(), reverse_res_h.data()},
|
||||
Ho,
|
||||
in_shapes,
|
||||
output_shape,
|
||||
in_shapes_h,
|
||||
output_shape_h,
|
||||
1,
|
||||
sizeof(T));
|
||||
}
|
||||
@ -465,31 +489,39 @@ namespace ngraph
|
||||
// Split bidirectional case to forward + reverse passes.
|
||||
// split inputs
|
||||
std::vector<std::vector<char>> H_split(
|
||||
2, std::vector<char>(ngraph::shape_size(H_shape) / 2));
|
||||
2, std::vector<char>(sizeof(T) * ngraph::shape_size(H_shape) / 2));
|
||||
std::vector<std::vector<char>> W_split(
|
||||
2, std::vector<char>(ngraph::shape_size(W_shape) / 2));
|
||||
2, std::vector<char>(sizeof(T) * ngraph::shape_size(W_shape) / 2));
|
||||
std::vector<std::vector<char>> R_split(
|
||||
2, std::vector<char>(ngraph::shape_size(R_shape) / 2));
|
||||
2, std::vector<char>(sizeof(T) * ngraph::shape_size(R_shape) / 2));
|
||||
std::vector<std::vector<char>> B_split(
|
||||
2, std::vector<char>(ngraph::shape_size(B_shape) / 2));
|
||||
2, std::vector<char>(sizeof(T) * ngraph::shape_size(B_shape) / 2));
|
||||
char* h_pointers[2] = {H_split[0].data(), H_split[1].data()};
|
||||
char* w_pointers[2] = {W_split[0].data(), W_split[1].data()};
|
||||
char* r_pointers[2] = {R_split[0].data(), R_split[1].data()};
|
||||
char* b_pointers[2] = {B_split[0].data(), B_split[1].data()};
|
||||
reference::split(H, H_shape, sizeof(T), 1, 2, h_pointers);
|
||||
reference::split(W, W_shape, sizeof(T), 1, 2, w_pointers);
|
||||
reference::split(R, R_shape, sizeof(T), 1, 2, r_pointers);
|
||||
reference::split(B, B_shape, sizeof(T), 1, 2, b_pointers);
|
||||
std::vector<std::vector<char>> forward_res(
|
||||
2, std::vector<char>(H_shape[0] * H_shape[2]));
|
||||
std::vector<std::vector<char>> reverse_res(
|
||||
2, std::vector<char>(H_shape[0] * H_shape[2]));
|
||||
reference::split(W, W_shape, sizeof(T), 0, 2, w_pointers);
|
||||
reference::split(R, R_shape, sizeof(T), 0, 2, r_pointers);
|
||||
reference::split(B, B_shape, sizeof(T), 0, 2, b_pointers);
|
||||
std::vector<char> forward_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
|
||||
X_shape[1]);
|
||||
std::vector<char> forward_res_h(sizeof(T) * H_shape[0] * H_shape[2]);
|
||||
std::vector<char> reverse_res_y(sizeof(T) * H_shape[0] * H_shape[2] *
|
||||
X_shape[1]);
|
||||
std::vector<char> reverse_res_h(sizeof(T) * H_shape[0] * H_shape[2]);
|
||||
|
||||
CellArgs args;
|
||||
args.activation_f = activation_f;
|
||||
args.clip = clip;
|
||||
std::vector<Shape> shapes = {
|
||||
X_shape, seq_lengths_shape, H_shape, W_shape, R_shape, B_shape};
|
||||
// update H,W,R,B shapes after split
|
||||
shapes[2][1] = 1;
|
||||
for (int i = 3; i < shapes.size(); ++i)
|
||||
{
|
||||
shapes[i][0] = 1;
|
||||
}
|
||||
// forward pass
|
||||
cell_pass<T>(CellType::RNN,
|
||||
{X,
|
||||
@ -499,7 +531,7 @@ namespace ngraph
|
||||
r_pointers[0],
|
||||
b_pointers[0]},
|
||||
shapes,
|
||||
{forward_res[0].data(), forward_res[1].data()},
|
||||
{forward_res_y.data(), forward_res_h.data()},
|
||||
args,
|
||||
false);
|
||||
// reverse pass
|
||||
@ -511,25 +543,28 @@ namespace ngraph
|
||||
r_pointers[1],
|
||||
b_pointers[1]},
|
||||
shapes,
|
||||
{reverse_res[0].data(), reverse_res[1].data()},
|
||||
{reverse_res_y.data(), reverse_res_h.data()},
|
||||
args,
|
||||
true);
|
||||
|
||||
// Stack together respective outputs from both forward and reverse passes.
|
||||
std::vector<Shape> in_shapes = {{H_shape[0], 1, H_shape[2]},
|
||||
{H_shape[0], 1, H_shape[2]}};
|
||||
Shape output_shape = {{H_shape[0], 2, H_shape[2]}};
|
||||
std::vector<Shape> in_shapes_y = {{H_shape[0], 1, X_shape[1], H_shape[2]},
|
||||
{H_shape[0], 1, X_shape[1], H_shape[2]}};
|
||||
std::vector<Shape> in_shapes_h = {{H_shape[0], 1, H_shape[2]},
|
||||
{H_shape[0], 1, H_shape[2]}};
|
||||
Shape output_shape_y{H_shape[0], 2, X_shape[1], H_shape[2]};
|
||||
Shape output_shape_h{H_shape[0], 2, H_shape[2]};
|
||||
|
||||
runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
|
||||
runtime::reference::concat({forward_res_y.data(), reverse_res_y.data()},
|
||||
Y,
|
||||
in_shapes,
|
||||
output_shape,
|
||||
in_shapes_y,
|
||||
output_shape_y,
|
||||
1,
|
||||
sizeof(T));
|
||||
runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()},
|
||||
runtime::reference::concat({forward_res_h.data(), reverse_res_h.data()},
|
||||
Ho,
|
||||
in_shapes,
|
||||
output_shape,
|
||||
in_shapes_h,
|
||||
output_shape_h,
|
||||
1,
|
||||
sizeof(T));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user