Fix bidirectional mode in reference implementations of GRU/LSTM/RNN Sequences (#2264)

* fix bidirectional case in references of sequences ops, enable decomposition of bidirectional cases in CommonOptimizations

* introduce new opset5, include GRU/RNN/LSTM Sequences to opset5

* Revert "introduce new opset5, include GRU/RNN/LSTM Sequences to opset5"

This reverts commit 73c22a11db.
This commit is contained in:
Ivan Tikhonov
2020-09-18 10:14:01 +03:00
committed by GitHub
parent 93074590de
commit 1b7dfc6e4c
6 changed files with 119 additions and 86 deletions

View File

@@ -84,11 +84,6 @@ namespace LayerTestsDefinitions {
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(gru_sequence->output(0)),
std::make_shared<ngraph::opset1::Result>(gru_sequence->output(1))};
function = std::make_shared<ngraph::Function>(results, params, "gru_sequence");
if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) {
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::BidirectionalGRUSequenceDecomposition>();
m.run_passes(function);
}
}

View File

@@ -82,11 +82,6 @@ namespace LayerTestsDefinitions {
std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(1)),
std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(2))};
function = std::make_shared<ngraph::Function>(results, params, "lstm_sequence");
if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) {
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
m.run_passes(function);
}
}

View File

@@ -82,11 +82,6 @@ namespace LayerTestsDefinitions {
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(rnn_sequence->output(0)),
std::make_shared<ngraph::opset1::Result>(rnn_sequence->output(1))};
function = std::make_shared<ngraph::Function>(results, params, "rnn_sequence");
if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) {
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
m.run_passes(function);
}
}