SequenceFusion transformation (#12845)

* SequenceFusion transformation and tests

* Enable SequenceFusion transformation in MOC

* add missed includes

* fix type

* fix build, apply review comments

* fix build

* fix build

* fix build again

* use ov namespace in has_result_consumers function

* fix win build

* try to fix win build

* investigate issue on win platform

* investigate issue on win platform

* investigate issue on win

* issue on win platform

* remove the transformation from MOC

* fix LSTMCell fusion, simplify transformation implementation, fix copying tensor and friendly names

* clean up

* add support for LSTMCell v0, resolve review comments, enable additional tests
This commit is contained in:
Ivan Tikhonov 2022-10-03 09:36:54 +03:00 committed by GitHub
parent 0c54905587
commit 4f002c46b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 621 additions and 0 deletions

View File

@ -0,0 +1,34 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <openvino/pass/graph_rewrite.hpp>
#include <transformations_visibility.hpp>
#include <vector>
namespace ov {
namespace pass {
class TRANSFORMATIONS_API SequenceFusion;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief SequenceFusion transformation replaces a chain of Cells
* operations with single Sequence op.
*
* Supported cells: GRUCell, LSTMCell, RNNCell, AUGRUCell
* Prerequisites: the source of W,R,B inputs must be the same or
* it can be different Constants with the same type, shape and value.
*/
// Matcher pass: fires on any RNNCellBase node; the callback detects whether it
// terminates a fusable chain of cells and replaces the chain with a Sequence op.
class ov::pass::SequenceFusion : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("SequenceFusion", "0");
    SequenceFusion();
};

View File

@ -0,0 +1,355 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/sequence_fusion.hpp"
#include <memory>
#include "itt.hpp"
#include "ngraph_ops/augru_cell.hpp"
#include "ngraph_ops/augru_sequence.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/opsets/opset3.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
using namespace std;
using namespace ov::element;
using namespace ov::opset9;
using namespace ov::op::util;
namespace {
// Returns true when both nodes are Constants with the same element type, the
// same shape and byte-wise identical data.
bool is_equal_consts(const shared_ptr<ov::Node>& l, const shared_ptr<ov::Node>& r) {
    const auto lhs = dynamic_pointer_cast<Constant>(l);
    const auto rhs = dynamic_pointer_cast<Constant>(r);
    if (!lhs || !rhs) {
        return false;
    }
    if (lhs->get_element_type() != rhs->get_element_type() || lhs->get_shape() != rhs->get_shape()) {
        return false;
    }
    const auto* lhs_data = lhs->get_data_ptr();
    const auto* rhs_data = rhs->get_data_ptr();
    const size_t byte_count = shape_size(lhs->get_shape()) * lhs->get_element_type().size();
    // same buffer, or equal contents
    return lhs_data == rhs_data || memcmp(lhs_data, rhs_data, byte_count) == 0;
}
// Checks that two cells take the same W, R and B inputs: either literally the
// same producer nodes, or Constants with equal type, shape and values.
bool check_WRB(const shared_ptr<RNNCellBase>& cell_1, const shared_ptr<RNNCellBase>& cell_2) {
    // default cell input layout: 0 - X, 1 - H, 2 - W, 3 - R, 4 - B
    int64_t idx_W = 2, idx_R = 3, idx_B = 4;
    auto increase_indexes = [&]() {
        ++idx_B;
        ++idx_R;
        ++idx_W;
    };
    auto lstm_cell_v4_1 = dynamic_pointer_cast<LSTMCell>(cell_1);
    auto lstm_cell_v4_2 = dynamic_pointer_cast<LSTMCell>(cell_2);
    // 2nd input is Cell State; it shifts W/R/B indices by one
    if (lstm_cell_v4_1 && lstm_cell_v4_2) {
        increase_indexes();
    }
    auto lstm_cell_v0_1 = dynamic_pointer_cast<ov::opset3::LSTMCell>(cell_1);
    auto lstm_cell_v0_2 = dynamic_pointer_cast<ov::opset3::LSTMCell>(cell_2);
    if (lstm_cell_v0_1 && lstm_cell_v0_2) {
        // v0-specific attributes must also match for the cells to be fusable
        if (lstm_cell_v0_1->get_weights_format() != lstm_cell_v0_2->get_weights_format() ||
            lstm_cell_v0_1->get_input_forget() != lstm_cell_v0_2->get_input_forget()) {
            return false;
        }
        increase_indexes();
    }
    auto lW = cell_1->input_value(idx_W).get_node_shared_ptr();
    auto lR = cell_1->input_value(idx_R).get_node_shared_ptr();
    auto lB = cell_1->input_value(idx_B).get_node_shared_ptr();
    auto rW = cell_2->input_value(idx_W).get_node_shared_ptr();
    auto rR = cell_2->input_value(idx_R).get_node_shared_ptr();
    auto rB = cell_2->input_value(idx_B).get_node_shared_ptr();
    // same producer node, or byte-wise equal Constants
    bool is_equal = (lW.get() == rW.get() || is_equal_consts(lW, rW));
    is_equal = is_equal && (lR.get() == rR.get() || is_equal_consts(lR, rR));
    is_equal = is_equal && (lB.get() == rB.get() || is_equal_consts(lB, rB));
    return is_equal;
}
// Returns true when two cells are of the same type with identical
// hyper-parameters and matching W/R/B inputs, i.e. they can be folded into a
// single Sequence op.
bool is_equal_cells(const shared_ptr<RNNCellBase>& cell_1, const shared_ptr<RNNCellBase>& cell_2) {
    bool is_equal = true;
    auto gru_cell_1 = dynamic_pointer_cast<GRUCell>(cell_1);
    auto gru_cell_2 = dynamic_pointer_cast<GRUCell>(cell_2);
    if (gru_cell_1 && gru_cell_2) {
        // GRU-specific attribute must match as well
        is_equal = gru_cell_1->get_linear_before_reset() == gru_cell_2->get_linear_before_reset();
    }
    // NOTE(review): if get_type_name() returns const char*, the == below compares
    // pointers; that still works for same-type cells (shared static type-info
    // string) but confirm the return type to be safe.
    is_equal = is_equal && cell_1->get_type_name() == cell_2->get_type_name() &&
               cell_1->get_hidden_size() == cell_2->get_hidden_size() &&
               cell_1->get_activations() == cell_2->get_activations() &&
               cell_1->get_activations_alpha() == cell_2->get_activations_alpha() &&
               cell_1->get_activations_beta() == cell_2->get_activations_beta() &&
               cell_1->get_clip() == cell_2->get_clip() && check_WRB(cell_1, cell_2);
    return is_equal;
}
// Verifies that the intermediate Cell State (output 1) of `prev_cell` is either
// unused or consumed only by `current_cell` as its C input (input index 2).
// Non-LSTM cells always pass.
bool check_lstm_cell(const shared_ptr<RNNCellBase>& prev_cell, const shared_ptr<RNNCellBase>& current_cell) {
    // check intermediate C outputs in case of LSTMCell
    // LSTMCell - C -> LSTMCell
    if ((dynamic_pointer_cast<LSTMCell>(prev_cell) || dynamic_pointer_cast<ov::opset3::LSTMCell>(prev_cell))) {
        const auto& target_inputs = prev_cell->get_output_target_inputs(1);
        bool valid = target_inputs.empty() ||
                     (target_inputs.size() == 1 &&
                      dynamic_cast<RNNCellBase*>(target_inputs.begin()->get_node()) == current_cell.get() &&
                      target_inputs.begin()->get_index() == 2);
        // if intermediate C output is connected to other node, except LSTMCell,
        // we can't replace cells with sequence. Sequence doesn't provide access to these outputs.
        return valid;
    }
    return true;
}
// Walks up the chain of equal cells starting from `current_cell`, following the
// HiddenState input (input 1). For every visited cell it records:
//  - the per-step X input, unsqueezed along axis 1 (x_to_concat),
//  - the attention input for AUGRU cells (attention_to_concat),
//  - the H output keyed by its 1-based distance from the chain end
//    (h_outputs_to_redirect; key 1 == the last cell of the chain).
// On return cells_cnt holds the chain length. Visited cells are registered in
// cp_from, newly created nodes in cp_to (used later for rt-info propagation).
// Returns the first (earliest) cell of the chain.
shared_ptr<RNNCellBase> find_cell_chain(ov::pass::NodeRegistry& cp_from,
                                        ov::pass::NodeRegistry& cp_to,
                                        const shared_ptr<RNNCellBase>& current_cell,
                                        ov::OutputVector& x_to_concat,
                                        ov::OutputVector& attention_to_concat,
                                        map<int, ov::Output<ov::Node>>& h_outputs_to_redirect,
                                        int& cells_cnt,
                                        const shared_ptr<ov::Node>& axis_1) {
    cells_cnt = 1;
    shared_ptr<RNNCellBase> current = current_cell;
    while (true) {
        cp_from.add(current);
        // check the source node of HiddenState input
        auto prev = current->input_value(1).get_node_shared_ptr();
        auto prev_cell = dynamic_pointer_cast<RNNCellBase>(prev);
        auto in_X = current->input(0);
        // add the sequence dimension to the per-step X input
        x_to_concat.push_back(cp_to.make<Unsqueeze>(in_X.get_source_output(), axis_1));
        h_outputs_to_redirect[cells_cnt] = current->output(0);
        if (auto augru = dynamic_pointer_cast<ov::op::internal::AUGRUCell>(current)) {
            // AUGRU carries an extra attention input (input 5)
            attention_to_concat.push_back(cp_to.make<Unsqueeze>(augru->input_value(5), axis_1));
        }
        // continue only while the producer of H is a compatible, fusable cell
        if (prev_cell && is_equal_cells(prev_cell, current) && check_lstm_cell(prev_cell, current)) {
            current = prev_cell;
            cells_cnt++;
        } else {
            break;
        }
    }
    // inputs were collected from the last cell backwards; restore forward order
    reverse(x_to_concat.begin(), x_to_concat.end());
    reverse(attention_to_concat.begin(), attention_to_concat.end());
    // the first cell in the chain
    return current;
}
// Builds the Sequence op replacing the chain [first_cell .. last_cell] and
// redirects consumers of the original per-step H outputs (and the final C
// output for LSTM) to Squeeze/Split nodes on the Sequence outputs.
// Returns false when the configuration cannot be represented as a Sequence
// (unsupported cell type, or LSTMCell-v0 with input_forget); in that case no
// output has been redirected yet.
bool create_sequence(ov::pass::NodeRegistry& cp_to,
                     const shared_ptr<RNNCellBase>& first_cell,
                     const shared_ptr<RNNCellBase>& last_cell,
                     const ov::OutputVector& x_to_concat,
                     const ov::OutputVector& attention_to_concat,
                     const map<int, ov::Output<ov::Node>>& h_outputs_to_redirect,
                     int cells_cnt,
                     const shared_ptr<ov::Node>& axis_0,
                     const shared_ptr<ov::Node>& axis_1) {
    // default cell input layout: 0 - X, 1 - H, 2 - W, 3 - R, 4 - B
    int64_t idx_W = 2, idx_R = 3, idx_B = 4;
    // 2nd input is Cell State
    bool is_lstm = false;
    if (dynamic_pointer_cast<LSTMCell>(last_cell) || dynamic_pointer_cast<ov::opset3::LSTMCell>(last_cell)) {
        is_lstm = true;
        idx_B++;
        idx_R++;
        idx_W++;
    }
    // X: per-step inputs stacked along the new sequence dimension (axis 1)
    const auto X_in = cp_to.make<Concat>(x_to_concat, 1);
    // H/W/R/B of the first cell gain the num_directions dimension
    const auto Ht_in = cp_to.make<Unsqueeze>(first_cell->input_value(1), axis_1);
    const auto W_in = cp_to.make<Unsqueeze>(first_cell->input_value(idx_W), axis_0);
    const auto R_in = cp_to.make<Unsqueeze>(first_cell->input_value(idx_R), axis_0);
    const auto B_in = cp_to.make<Unsqueeze>(first_cell->input_value(idx_B), axis_0);
    // sequence_lengths: chain length broadcast over the (possibly dynamic) batch
    const auto& shape_node = cp_to.add(ngraph::op::util::make_try_fold<ShapeOf>(first_cell->input_value(0)));
    const auto& zero = cp_to.make<Constant>(i64, ov::Shape{1}, 0);
    const auto& batch_dimension = cp_to.add(ngraph::op::util::make_try_fold<Gather>(shape_node, zero, axis_0));
    auto seq_lengths_scalar = cp_to.make<Constant>(i64, ov::Shape{}, cells_cnt);
    auto sequence_lengths_in =
        cp_to.add(ngraph::op::util::make_try_fold<Broadcast>(seq_lengths_scalar, batch_dimension));
    shared_ptr<ov::Node> sequence;
    // NOTE(review): `outputs` is written below but never read afterwards -
    // looks like dead code left from an earlier revision.
    ov::OutputVector outputs(1);
    if (dynamic_pointer_cast<LSTMCell>(first_cell)) {
        // LSTMCell v4: extra C input, three outputs on the sequence
        const auto Ct_in = cp_to.make<Unsqueeze>(first_cell->input_value(2), axis_1);
        sequence = cp_to.make<LSTMSequence>(X_in,
                                            Ht_in,
                                            Ct_in,
                                            sequence_lengths_in,
                                            W_in,
                                            R_in,
                                            B_in,
                                            first_cell->get_hidden_size(),
                                            ov::op::RecurrentSequenceDirection::FORWARD,
                                            first_cell->get_activations_alpha(),
                                            first_cell->get_activations_beta(),
                                            first_cell->get_activations(),
                                            first_cell->get_clip());
        outputs.resize(2);
        outputs[1] = cp_to.make<Squeeze>(sequence->output(2), axis_1);
    } else if (auto lstm_cell_v0 = dynamic_pointer_cast<ov::opset3::LSTMCell>(first_cell)) {
        // input_forget modification is not supported
        if (lstm_cell_v0->get_input_forget()) {
            return false;
        }
        // LSTMSequence expects FICO weights; convert when the v0 cell used another layout
        auto weights_format = lstm_cell_v0->get_weights_format();
        ov::Output<ov::Node> W = W_in, R = R_in, B = B_in;
        if (weights_format != ov::op::LSTMWeightsFormat::FICO) {
            W = ov::op::util::convert_lstm_node_format(W_in, convert_lstm_weights_enums(weights_format));
            R = ov::op::util::convert_lstm_node_format(R_in, convert_lstm_weights_enums(weights_format));
            B = ov::op::util::convert_lstm_node_format(B_in, convert_lstm_weights_enums(weights_format));
        }
        const auto Ct_in = cp_to.make<Unsqueeze>(first_cell->input_value(2), axis_1);
        sequence = cp_to.make<LSTMSequence>(X_in,
                                            Ht_in,
                                            Ct_in,
                                            sequence_lengths_in,
                                            W,
                                            R,
                                            B,
                                            first_cell->get_hidden_size(),
                                            ov::op::RecurrentSequenceDirection::FORWARD,
                                            first_cell->get_activations_alpha(),
                                            first_cell->get_activations_beta(),
                                            first_cell->get_activations(),
                                            first_cell->get_clip());
        outputs.resize(2);
        outputs[1] = cp_to.make<Squeeze>(sequence->output(2), axis_1);
    } else if (auto gru_cell = dynamic_pointer_cast<GRUCell>(first_cell)) {
        sequence = cp_to.make<GRUSequence>(X_in,
                                           Ht_in,
                                           sequence_lengths_in,
                                           W_in,
                                           R_in,
                                           B_in,
                                           first_cell->get_hidden_size(),
                                           ov::op::RecurrentSequenceDirection::FORWARD,
                                           first_cell->get_activations(),
                                           first_cell->get_activations_alpha(),
                                           first_cell->get_activations_beta(),
                                           first_cell->get_clip(),
                                           gru_cell->get_linear_before_reset());
    } else if (dynamic_pointer_cast<RNNCell>(first_cell)) {
        sequence = cp_to.make<RNNSequence>(X_in,
                                           Ht_in,
                                           sequence_lengths_in,
                                           W_in,
                                           R_in,
                                           B_in,
                                           first_cell->get_hidden_size(),
                                           ov::op::RecurrentSequenceDirection::FORWARD,
                                           first_cell->get_activations(),
                                           first_cell->get_activations_alpha(),
                                           first_cell->get_activations_beta(),
                                           first_cell->get_clip());
    } else if (dynamic_pointer_cast<ov::op::internal::AUGRUCell>(first_cell)) {
        // AUGRU additionally consumes the concatenated attention inputs
        const auto A_in = cp_to.make<Concat>(attention_to_concat, 1);
        sequence = cp_to.make<ov::op::internal::AUGRUSequence>(X_in,
                                                               Ht_in,
                                                               sequence_lengths_in,
                                                               W_in,
                                                               R_in,
                                                               B_in,
                                                               A_in,
                                                               first_cell->get_hidden_size());
    } else {
        // cell is not supported;
        return false;
    }
    if (!h_outputs_to_redirect.empty()) {
        // split the Y output back into per-step H outputs and re-wire original consumers
        auto squeeze_Y = cp_to.make<Squeeze>(sequence->output(0), axis_1);
        auto split = cp_to.make<Split>(squeeze_Y, axis_1, cells_cnt);
        for (auto it : h_outputs_to_redirect) {
            // key 1 corresponds to the last cell of the chain
            auto Hi = split->output(cells_cnt - it.first);
            auto friendly_name = it.second.get_node_shared_ptr()->get_friendly_name();
            if (it.first == 1) {
                // use the final-H output of the Sequence directly for the last cell
                Hi = sequence->output(1);
            }
            auto squeeze = cp_to.make<Squeeze>(Hi, axis_1);
            it.second.replace(squeeze);
            if (is_lstm) {
                // presumably mirrors the multi-output naming (name:port) of LSTM cells
                friendly_name += ":1";
            }
            squeeze->set_friendly_name(friendly_name);
        }
    }
    if (is_lstm) {
        // re-wire consumers of the final C output of the last cell
        auto squeeze = cp_to.make<Squeeze>(sequence->output(2), axis_1);
        last_cell->output(1).replace(squeeze);
        squeeze->set_friendly_name(last_cell->get_friendly_name() + ":2");
    }
    return true;
}
} // namespace
ov::pass::SequenceFusion::SequenceFusion() {
    MATCHER_SCOPE(SequenceFusion);
    // match any recurrent cell; detailed suitability is verified in the callback
    auto cell = pattern::wrap_type<RNNCellBase>();
    matcher_pass_callback callback = [=](pattern::Matcher& m) {
        NodeRegistry copy_from;
        NodeRegistry copy_to;
        auto cell = m.get_match_root();
        shared_ptr<RNNCellBase> current_cell = dynamic_pointer_cast<RNNCellBase>(cell);
        if (!current_cell) {
            return false;
        }
        // check that this is the last Cell in the chain, e.g.
        // GRUCell -> GRUCell (the last cell) -> OtherNode
        // GRUCell (hidden_size = 128) -> GRUCell (hs = 128, the last) -> GRUCell (hs = 64)
        for (const auto& target : cell->get_output_target_inputs(0)) {
            auto cell_1 = dynamic_pointer_cast<RNNCellBase>(target.get_node()->shared_from_this());
            if (cell_1 && is_equal_cells(cell_1, current_cell)) {
                // a fusable cell consumes our H output; the matcher will fire on it instead
                return false;
            }
        }
        int cells_cnt;
        ov::OutputVector x_to_concat;
        ov::OutputVector attention_to_concat;
        map<int, ov::Output<ov::Node>> h_outputs_to_redirect;
        auto axis_0 = copy_to.make<Constant>(i64, Shape{}, 0);
        auto axis_1 = copy_to.make<Constant>(i64, Shape{}, 1);
        // detect chain (Cell->Cell->Cell->..)
        auto first_cell = find_cell_chain(copy_from,
                                          copy_to,
                                          current_cell,
                                          x_to_concat,
                                          attention_to_concat,
                                          h_outputs_to_redirect,
                                          cells_cnt,
                                          axis_1);
        // NOTE(review): find_cell_chain never returns nullptr as written; defensive check
        if (!first_cell) {
            return false;
        }
        // no reasons to create sequence if the single cell detected.
        // TODO: investigate optimal cnt of cells
        constexpr int optimal_cnt_of_cells = 2;
        if (cells_cnt < optimal_cnt_of_cells) {
            return false;
        }
        auto res = create_sequence(copy_to,
                                   first_cell,
                                   current_cell,
                                   x_to_concat,
                                   attention_to_concat,
                                   h_outputs_to_redirect,
                                   cells_cnt,
                                   axis_0,
                                   axis_1);
        if (!res) {
            return false;
        }
        // propagate runtime info from the replaced cells to the new subgraph
        copy_runtime_info(copy_from.get(), copy_to.get());
        return true;
    };
    auto m = make_shared<pattern::Matcher>(cell, matcher_name);
    this->register_matcher(m, callback);
}

View File

@ -23,6 +23,8 @@ enum class LSTMWeightsFormat {
IOFC, // ONNX
};
ov::op::util::LSTMWeightsFormat convert_lstm_weights_enums(LSTMWeightsFormat format);
namespace v0 {
///
/// \brief Class for single lstm cell node.

View File

@ -253,6 +253,22 @@ NGRAPH_API EnumNames<ngraph::op::LSTMWeightsFormat>& EnumNames<ngraph::op::LSTMW
BWDCMP_RTTI_DEFINITION(AttributeAdapter<ov::op::LSTMWeightsFormat>);
// Maps the op-level LSTMWeightsFormat enum to its op::util counterpart.
// Asserts on values outside the known set.
ov::op::util::LSTMWeightsFormat op::convert_lstm_weights_enums(op::LSTMWeightsFormat format) {
    using DstFormat = ov::op::util::LSTMWeightsFormat;
    if (format == LSTMWeightsFormat::FICO)
        return DstFormat::FICO;
    if (format == LSTMWeightsFormat::ICOF)
        return DstFormat::ICOF;
    if (format == LSTMWeightsFormat::IFCO)
        return DstFormat::IFCO;
    if (format == LSTMWeightsFormat::IFOC)
        return DstFormat::IFOC;
    if (format == LSTMWeightsFormat::IOFC)
        return DstFormat::IOFC;
    OPENVINO_ASSERT(false, "Incorrect LSTM weights format");
}
} // namespace ov
std::ostream& ov::operator<<(std::ostream& s, const op::LSTMWeightsFormat& type) {

View File

@ -0,0 +1,214 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <queue>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "openvino/opsets/opset3.hpp"
#include "openvino/opsets/opset9.hpp"
#include "transformations/common_optimizations/sequence_fusion.hpp"
#include "ngraph_ops/augru_sequence.hpp"
#include "ngraph_ops/augru_cell.hpp"
using namespace ov;
using namespace std;
using namespace testing;
using namespace ov::opset9;
using namespace ov::element;
namespace {
// Kinds of recurrent cells covered by the SequenceFusion tests.
enum class RNN_TYPE {
    LSTM_v0,
    LSTM_v4,
    GRU,
    RNN,
    AUGRU
};

// Returns the number of gates for the given cell kind; the gate count scales
// the first dimension of the W/R/B weight shapes.
// Improvement: the original if-chain initialized gate = 1 and then redundantly
// re-assigned 1 in the RNN branch; an exhaustive switch makes the mapping explicit.
int get_gate_by_rnn_type(RNN_TYPE rnn_type) {
    switch (rnn_type) {
    case RNN_TYPE::LSTM_v0:
    case RNN_TYPE::LSTM_v4:
        return 4;
    case RNN_TYPE::GRU:
    case RNN_TYPE::AUGRU:
        return 3;
    case RNN_TYPE::RNN:
    default:
        return 1;
    }
}
// Builds a chain of `cells_cnt` identical cells of the requested kind.
// Returns {concatenated per-step H (with a sequence dim), <outputs of the last cell>}.
OutputVector create_cell(RNN_TYPE rnn_type,
                         const shared_ptr<Node>& X,
                         const shared_ptr<Node>& H,
                         const shared_ptr<Node>& C,
                         const shared_ptr<Node>& W,
                         const shared_ptr<Node>& R,
                         const shared_ptr<Node>& B,
                         const shared_ptr<Node>& A,
                         size_t hidden_size,
                         int64_t cells_cnt) {
    const auto seq_axis = make_shared<Constant>(i64, Shape{}, 1);
    shared_ptr<Node> last_cell;
    Output<Node> hidden = H;
    Output<Node> cell_state = C;
    OutputVector per_step_hidden;
    for (int64_t step = 0; step < cells_cnt; ++step) {
        switch (rnn_type) {
        case RNN_TYPE::LSTM_v4:
            last_cell = make_shared<LSTMCell>(X, hidden, cell_state, W, R, B, hidden_size);
            cell_state = last_cell->output(1);
            break;
        case RNN_TYPE::LSTM_v0:
            last_cell =
                make_shared<opset3::LSTMCell>(X, hidden, cell_state, W, R, B, hidden_size, ov::op::LSTMWeightsFormat::FICO);
            cell_state = last_cell->output(1);
            break;
        case RNN_TYPE::GRU:
            last_cell = make_shared<GRUCell>(X, hidden, W, R, B, hidden_size);
            break;
        case RNN_TYPE::RNN:
            last_cell = make_shared<RNNCell>(X, hidden, W, R, B, hidden_size);
            break;
        case RNN_TYPE::AUGRU:
            last_cell = make_shared<op::internal::AUGRUCell>(X, hidden, W, R, B, A, hidden_size);
            break;
        }
        hidden = last_cell->output(0);
        per_step_hidden.push_back(make_shared<Unsqueeze>(hidden, seq_axis));
    }
    auto concat = make_shared<Concat>(per_step_hidden, 1);
    OutputVector result{concat->output(0)};
    const auto last_outputs = last_cell->outputs();
    result.insert(result.end(), last_outputs.begin(), last_outputs.end());
    return result;
}
// Builds the model under test: a chain of `cells_cnt` identical cells (see
// create_cell). Only the parameters the chosen cell kind actually consumes are
// registered: C for LSTM variants, A for AUGRU.
shared_ptr<Model> gen_model(RNN_TYPE rnn_type, size_t batch, size_t hidden_size, size_t input_size,
                            int64_t cells_cnt) {
    int gate = get_gate_by_rnn_type(rnn_type);
    auto X = make_shared<Parameter>(f32, Shape{batch, input_size});
    auto H = make_shared<Parameter>(f32, Shape{batch, hidden_size});
    auto C = make_shared<Parameter>(f32, Shape{batch, hidden_size});
    auto W = make_shared<Parameter>(f32, Shape{gate * hidden_size, input_size});
    auto R = make_shared<Parameter>(f32, Shape{gate * hidden_size, hidden_size});
    auto B = make_shared<Parameter>(f32, Shape{gate * hidden_size});
    auto A = make_shared<Parameter>(f32, Shape{batch, 1});
    auto outputs = create_cell(rnn_type, X, H, C, W, R, B, A, hidden_size, cells_cnt);
    ParameterVector params = {X, H, W, R, B};
    if (rnn_type == RNN_TYPE::LSTM_v4 || rnn_type == RNN_TYPE::LSTM_v0) {
        params.push_back(C);
    } else if (rnn_type == RNN_TYPE::AUGRU) {
        params.push_back(A);
    }
    return make_shared<Model>(outputs, params);
}
// Builds the expected model after fusion: a single *Sequence op fed by
// unsqueezed/concatenated cell inputs, followed by Squeeze/Split nodes that
// reconstruct the per-step H outputs the original chain exposed.
// Improvement: loop counters fixed — `int i` compared against
// split->outputs().size() was a signed/unsigned mismatch (build warning),
// and the cells_cnt loop now uses int64_t to match the parameter type.
shared_ptr<Model> gen_reference(RNN_TYPE rnn_type, size_t batch, size_t hidden_size, size_t input_size,
                                int64_t cells_cnt) {
    int gate = get_gate_by_rnn_type(rnn_type);
    auto axis_0 = make_shared<Constant>(i64, Shape{}, 0);
    auto axis_1 = make_shared<Constant>(i64, Shape{}, 1);
    auto seq_len = make_shared<Constant>(i64, Shape{batch}, cells_cnt);
    auto X = make_shared<Parameter>(f32, Shape{batch, input_size});
    auto H = make_shared<Parameter>(f32, Shape{batch, hidden_size});
    auto C = make_shared<Parameter>(f32, Shape{batch, hidden_size});
    auto W = make_shared<Parameter>(f32, Shape{gate * hidden_size, input_size});
    auto R = make_shared<Parameter>(f32, Shape{gate * hidden_size, hidden_size});
    auto B = make_shared<Parameter>(f32, Shape{gate * hidden_size});
    auto A = make_shared<Parameter>(f32, Shape{batch, 1});
    // C is an input only for LSTM variants, A only for AUGRU
    ParameterVector params = {X, H, W, R, B};
    if (rnn_type == RNN_TYPE::LSTM_v4 || rnn_type == RNN_TYPE::LSTM_v0) {
        params.push_back(C);
    } else if (rnn_type == RNN_TYPE::AUGRU) {
        params.push_back(A);
    }
    // add the num_directions dimension expected by Sequence ops
    auto unH = make_shared<Unsqueeze>(H, axis_1);
    auto unC = make_shared<Unsqueeze>(C, axis_1);
    auto unW = make_shared<Unsqueeze>(W, axis_0);
    auto unR = make_shared<Unsqueeze>(R, axis_0);
    auto unB = make_shared<Unsqueeze>(B, axis_0);
    // every step of the original chain consumed the same X (and A), so the
    // sequence input is cells_cnt concatenated copies
    OutputVector in_X;
    OutputVector in_A;
    for (int64_t i = 0; i < cells_cnt; ++i) {
        in_X.push_back(make_shared<Unsqueeze>(X, axis_1));
        in_A.push_back(make_shared<Unsqueeze>(A, axis_1));
    }
    auto concat_X = make_shared<Concat>(in_X, 1);
    auto concat_A = make_shared<Concat>(in_A, 1);
    shared_ptr<Node> seq;
    if (rnn_type == RNN_TYPE::LSTM_v4 || rnn_type == RNN_TYPE::LSTM_v0) {
        seq = make_shared<LSTMSequence>(concat_X, unH, unC, seq_len, unW, unR, unB, hidden_size,
                                        op::RecurrentSequenceDirection::FORWARD);
    } else if (rnn_type == RNN_TYPE::GRU) {
        seq = make_shared<GRUSequence>(concat_X, unH, seq_len, unW, unR, unB, hidden_size,
                                       op::RecurrentSequenceDirection::FORWARD);
    } else if (rnn_type == RNN_TYPE::RNN) {
        seq = make_shared<RNNSequence>(concat_X, unH, seq_len, unW, unR, unB, hidden_size,
                                       op::RecurrentSequenceDirection::FORWARD);
    } else if (rnn_type == RNN_TYPE::AUGRU) {
        seq = make_shared<op::internal::AUGRUSequence>(concat_X, unH, seq_len, unW, unR, unB, concat_A, hidden_size);
    }
    // reconstruct the per-step H outputs: squeeze num_directions, split along
    // the sequence dim, then re-stack them the way the original graph did
    auto squeeze_H = make_shared<Squeeze>(seq->output(0), axis_1);
    auto _axis_1 = make_shared<Constant>(i64, Shape{}, 1);
    auto split = make_shared<Split>(squeeze_H, axis_1, cells_cnt);
    OutputVector in_vec;
    for (size_t i = 0; i < split->outputs().size(); ++i) {
        auto squeeze = make_shared<Squeeze>(split->output(i), axis_1);
        in_vec.push_back(make_shared<Unsqueeze>(squeeze, _axis_1));
    }
    auto concat = make_shared<Concat>(in_vec, 1);
    auto squeeze_Ht = make_shared<Squeeze>(seq->output(1), axis_1);
    OutputVector outputs = {concat->output(0), squeeze_Ht};
    if (rnn_type == RNN_TYPE::LSTM_v4 || rnn_type == RNN_TYPE::LSTM_v0) {
        // LSTM additionally exposes the final cell state
        auto squeeze_Ct = make_shared<Squeeze>(seq->output(2), axis_1);
        outputs.push_back(squeeze_Ct);
    }
    return make_shared<Model>(outputs, params);
}
} // namespace
// Test-case description: cell kind, tensor dimensions and chain length.
struct SequenceFusionParams {
    RNN_TYPE rnn_type;   // which cell op the chain is built from
    size_t batch;        // batch size of X/H/C
    size_t hidden_size;  // hidden-state size
    size_t input_size;   // X feature size
    int64_t cell_cnt;    // number of chained cells (1 => no fusion expected)
};
// Parameterized fixture: TransformationTestsF runs the registered pass on
// `function` and compares the result against `function_ref`.
class SequenceFusionTest
    : public WithParamInterface<SequenceFusionParams>,
      public TransformationTestsF {
};
// Applies SequenceFusion to a generated cell-chain model and compares the
// result with the expected fused reference model.
TEST_P(SequenceFusionTest, SequencePattern) {
    const auto& p = GetParam();
    {
        function = gen_model(p.rnn_type, p.batch,
                             p.hidden_size, p.input_size, p.cell_cnt);
        manager.register_pass<pass::SequenceFusion>();
    }
    // the transformation won't be applied for single cell
    if (p.cell_cnt > 1) {
        function_ref = gen_reference(p.rnn_type, p.batch, p.hidden_size, p.input_size, p.cell_cnt);
    }
    comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
    comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}
// cell_cnt == 1 cases check that single cells are left untouched.
static const std::vector<SequenceFusionParams> params = {
    SequenceFusionParams{RNN_TYPE::LSTM_v4, 2, 128, 32, 1},
    SequenceFusionParams{RNN_TYPE::LSTM_v4, 2, 128, 32, 10},
    SequenceFusionParams{RNN_TYPE::LSTM_v0, 2, 128, 32, 1},
    SequenceFusionParams{RNN_TYPE::LSTM_v0, 2, 128, 32, 10},
    SequenceFusionParams{RNN_TYPE::GRU, 2, 128, 32, 1},
    SequenceFusionParams{RNN_TYPE::GRU, 2, 128, 32, 10},
    SequenceFusionParams{RNN_TYPE::RNN, 2, 128, 32, 1},
    SequenceFusionParams{RNN_TYPE::RNN, 2, 128, 32, 10},
    SequenceFusionParams{RNN_TYPE::AUGRU, 2, 128, 32, 1},
    SequenceFusionParams{RNN_TYPE::AUGRU, 2, 128, 32, 10},
};
INSTANTIATE_TEST_SUITE_P(SequenceFusionTest, SequenceFusionTest, ValuesIn(params));