SequenceFusion transformation (#12845)
* SequenceFusion transformation and tests * Enable SequenceFusion transformation in MOC * add missed includes * fix type * fix build, apply review comments * fix build * fix build * fix build again * use ov namespace in has_result_consumers function * fix win build * try to fix win build * investigate issue on win platform * investigate issue on win platform * investigate issue on win * issue on win platform * remove the transformation from MOC * fix LSTMCell fusion, simplify transformation implementation, fix copying tensor and friendly names * clean up * add support for LSTMCell v0, resolve review comments, enable additional tests
This commit is contained in:
parent
0c54905587
commit
4f002c46b9
@ -0,0 +1,34 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <openvino/pass/graph_rewrite.hpp>
|
||||
#include <transformations_visibility.hpp>
|
||||
#include <vector>
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API SequenceFusion;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief SequenceFusion transformation replaces a chain of Cells
|
||||
* operations with single Sequence op.
|
||||
*
|
||||
* Supported cells: GRUCell, LSTMCell, RNNCell, AUGRUCell
|
||||
* Prerequisites: the source of W,R,B inputs must be the same or
|
||||
* it can be different Constants with the same type, shape and value.
|
||||
*/
|
||||
|
||||
class ov::pass::SequenceFusion : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("SequenceFusion", "0");
|
||||
SequenceFusion();
|
||||
};
|
@ -0,0 +1,355 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/sequence_fusion.hpp"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "ngraph_ops/augru_cell.hpp"
|
||||
#include "ngraph_ops/augru_sequence.hpp"
|
||||
#include "openvino/core/rt_info.hpp"
|
||||
#include "openvino/opsets/opset3.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov::element;
|
||||
using namespace ov::opset9;
|
||||
using namespace ov::op::util;
|
||||
|
||||
namespace {
|
||||
bool is_equal_consts(const shared_ptr<ov::Node>& l, const shared_ptr<ov::Node>& r) {
|
||||
auto l_const = dynamic_pointer_cast<Constant>(l);
|
||||
auto r_const = dynamic_pointer_cast<Constant>(r);
|
||||
if (l_const && r_const) {
|
||||
auto l_ptr = l_const->get_data_ptr();
|
||||
auto r_ptr = r_const->get_data_ptr();
|
||||
size_t bytes = shape_size(l_const->get_shape()) * l_const->get_element_type().size();
|
||||
return l_const->get_element_type() == r_const->get_element_type() &&
|
||||
l_const->get_shape() == r_const->get_shape() && (l_ptr == r_ptr || memcmp(l_ptr, r_ptr, bytes) == 0);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool check_WRB(const shared_ptr<RNNCellBase>& cell_1, const shared_ptr<RNNCellBase>& cell_2) {
|
||||
int64_t idx_W = 2, idx_R = 3, idx_B = 4;
|
||||
auto increase_indexes = [&]() {
|
||||
++idx_B;
|
||||
++idx_R;
|
||||
++idx_W;
|
||||
};
|
||||
auto lstm_cell_v4_1 = dynamic_pointer_cast<LSTMCell>(cell_1);
|
||||
auto lstm_cell_v4_2 = dynamic_pointer_cast<LSTMCell>(cell_2);
|
||||
// 2nd input is Cell State
|
||||
if (lstm_cell_v4_1 && lstm_cell_v4_2) {
|
||||
increase_indexes();
|
||||
}
|
||||
auto lstm_cell_v0_1 = dynamic_pointer_cast<ov::opset3::LSTMCell>(cell_1);
|
||||
auto lstm_cell_v0_2 = dynamic_pointer_cast<ov::opset3::LSTMCell>(cell_2);
|
||||
if (lstm_cell_v0_1 && lstm_cell_v0_2) {
|
||||
if (lstm_cell_v0_1->get_weights_format() != lstm_cell_v0_2->get_weights_format() ||
|
||||
lstm_cell_v0_1->get_input_forget() != lstm_cell_v0_2->get_input_forget()) {
|
||||
return false;
|
||||
}
|
||||
increase_indexes();
|
||||
}
|
||||
auto lW = cell_1->input_value(idx_W).get_node_shared_ptr();
|
||||
auto lR = cell_1->input_value(idx_R).get_node_shared_ptr();
|
||||
auto lB = cell_1->input_value(idx_B).get_node_shared_ptr();
|
||||
auto rW = cell_2->input_value(idx_W).get_node_shared_ptr();
|
||||
auto rR = cell_2->input_value(idx_R).get_node_shared_ptr();
|
||||
auto rB = cell_2->input_value(idx_B).get_node_shared_ptr();
|
||||
bool is_equal = (lW.get() == rW.get() || is_equal_consts(lW, rW));
|
||||
is_equal = is_equal && (lR.get() == rR.get() || is_equal_consts(lR, rR));
|
||||
is_equal = is_equal && (lB.get() == rB.get() || is_equal_consts(lB, rB));
|
||||
return is_equal;
|
||||
}
|
||||
|
||||
bool is_equal_cells(const shared_ptr<RNNCellBase>& cell_1, const shared_ptr<RNNCellBase>& cell_2) {
|
||||
bool is_equal = true;
|
||||
auto gru_cell_1 = dynamic_pointer_cast<GRUCell>(cell_1);
|
||||
auto gru_cell_2 = dynamic_pointer_cast<GRUCell>(cell_2);
|
||||
if (gru_cell_1 && gru_cell_2) {
|
||||
is_equal = gru_cell_1->get_linear_before_reset() == gru_cell_2->get_linear_before_reset();
|
||||
}
|
||||
is_equal = is_equal && cell_1->get_type_name() == cell_2->get_type_name() &&
|
||||
cell_1->get_hidden_size() == cell_2->get_hidden_size() &&
|
||||
cell_1->get_activations() == cell_2->get_activations() &&
|
||||
cell_1->get_activations_alpha() == cell_2->get_activations_alpha() &&
|
||||
cell_1->get_activations_beta() == cell_2->get_activations_beta() &&
|
||||
cell_1->get_clip() == cell_2->get_clip() && check_WRB(cell_1, cell_2);
|
||||
return is_equal;
|
||||
}
|
||||
|
||||
bool check_lstm_cell(const shared_ptr<RNNCellBase>& prev_cell, const shared_ptr<RNNCellBase>& current_cell) {
|
||||
// check intermediate C outputs in case of LSTMCell
|
||||
// LSTMCell - C -> LSTMCell
|
||||
if ((dynamic_pointer_cast<LSTMCell>(prev_cell) || dynamic_pointer_cast<ov::opset3::LSTMCell>(prev_cell))) {
|
||||
const auto& target_inputs = prev_cell->get_output_target_inputs(1);
|
||||
bool valid = target_inputs.empty() ||
|
||||
(target_inputs.size() == 1 &&
|
||||
dynamic_cast<RNNCellBase*>(target_inputs.begin()->get_node()) == current_cell.get() &&
|
||||
target_inputs.begin()->get_index() == 2);
|
||||
|
||||
// if intermediate C output is connected to other node, except LSTMCell,
|
||||
// we can't replace cells with sequence. Sequence doesn't provide access to these outputs.
|
||||
return valid;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
shared_ptr<RNNCellBase> find_cell_chain(ov::pass::NodeRegistry& cp_from,
|
||||
ov::pass::NodeRegistry& cp_to,
|
||||
const shared_ptr<RNNCellBase>& current_cell,
|
||||
ov::OutputVector& x_to_concat,
|
||||
ov::OutputVector& attention_to_concat,
|
||||
map<int, ov::Output<ov::Node>>& h_outputs_to_redirect,
|
||||
int& cells_cnt,
|
||||
const shared_ptr<ov::Node>& axis_1) {
|
||||
cells_cnt = 1;
|
||||
shared_ptr<RNNCellBase> current = current_cell;
|
||||
while (true) {
|
||||
cp_from.add(current);
|
||||
// check the source node of HiddenState input
|
||||
auto prev = current->input_value(1).get_node_shared_ptr();
|
||||
auto prev_cell = dynamic_pointer_cast<RNNCellBase>(prev);
|
||||
|
||||
auto in_X = current->input(0);
|
||||
x_to_concat.push_back(cp_to.make<Unsqueeze>(in_X.get_source_output(), axis_1));
|
||||
h_outputs_to_redirect[cells_cnt] = current->output(0);
|
||||
|
||||
if (auto augru = dynamic_pointer_cast<ov::op::internal::AUGRUCell>(current)) {
|
||||
attention_to_concat.push_back(cp_to.make<Unsqueeze>(augru->input_value(5), axis_1));
|
||||
}
|
||||
|
||||
if (prev_cell && is_equal_cells(prev_cell, current) && check_lstm_cell(prev_cell, current)) {
|
||||
current = prev_cell;
|
||||
cells_cnt++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
reverse(x_to_concat.begin(), x_to_concat.end());
|
||||
reverse(attention_to_concat.begin(), attention_to_concat.end());
|
||||
// the first cell in the chain
|
||||
return current;
|
||||
}
|
||||
|
||||
bool create_sequence(ov::pass::NodeRegistry& cp_to,
|
||||
const shared_ptr<RNNCellBase>& first_cell,
|
||||
const shared_ptr<RNNCellBase>& last_cell,
|
||||
const ov::OutputVector& x_to_concat,
|
||||
const ov::OutputVector& attention_to_concat,
|
||||
const map<int, ov::Output<ov::Node>>& h_outputs_to_redirect,
|
||||
int cells_cnt,
|
||||
const shared_ptr<ov::Node>& axis_0,
|
||||
const shared_ptr<ov::Node>& axis_1) {
|
||||
int64_t idx_W = 2, idx_R = 3, idx_B = 4;
|
||||
// 2nd input is Cell State
|
||||
bool is_lstm = false;
|
||||
if (dynamic_pointer_cast<LSTMCell>(last_cell) || dynamic_pointer_cast<ov::opset3::LSTMCell>(last_cell)) {
|
||||
is_lstm = true;
|
||||
idx_B++;
|
||||
idx_R++;
|
||||
idx_W++;
|
||||
}
|
||||
|
||||
const auto X_in = cp_to.make<Concat>(x_to_concat, 1);
|
||||
const auto Ht_in = cp_to.make<Unsqueeze>(first_cell->input_value(1), axis_1);
|
||||
const auto W_in = cp_to.make<Unsqueeze>(first_cell->input_value(idx_W), axis_0);
|
||||
const auto R_in = cp_to.make<Unsqueeze>(first_cell->input_value(idx_R), axis_0);
|
||||
const auto B_in = cp_to.make<Unsqueeze>(first_cell->input_value(idx_B), axis_0);
|
||||
|
||||
const auto& shape_node = cp_to.add(ngraph::op::util::make_try_fold<ShapeOf>(first_cell->input_value(0)));
|
||||
const auto& zero = cp_to.make<Constant>(i64, ov::Shape{1}, 0);
|
||||
const auto& batch_dimension = cp_to.add(ngraph::op::util::make_try_fold<Gather>(shape_node, zero, axis_0));
|
||||
auto seq_lengths_scalar = cp_to.make<Constant>(i64, ov::Shape{}, cells_cnt);
|
||||
auto sequence_lengths_in =
|
||||
cp_to.add(ngraph::op::util::make_try_fold<Broadcast>(seq_lengths_scalar, batch_dimension));
|
||||
shared_ptr<ov::Node> sequence;
|
||||
ov::OutputVector outputs(1);
|
||||
if (dynamic_pointer_cast<LSTMCell>(first_cell)) {
|
||||
const auto Ct_in = cp_to.make<Unsqueeze>(first_cell->input_value(2), axis_1);
|
||||
sequence = cp_to.make<LSTMSequence>(X_in,
|
||||
Ht_in,
|
||||
Ct_in,
|
||||
sequence_lengths_in,
|
||||
W_in,
|
||||
R_in,
|
||||
B_in,
|
||||
first_cell->get_hidden_size(),
|
||||
ov::op::RecurrentSequenceDirection::FORWARD,
|
||||
first_cell->get_activations_alpha(),
|
||||
first_cell->get_activations_beta(),
|
||||
first_cell->get_activations(),
|
||||
first_cell->get_clip());
|
||||
outputs.resize(2);
|
||||
outputs[1] = cp_to.make<Squeeze>(sequence->output(2), axis_1);
|
||||
} else if (auto lstm_cell_v0 = dynamic_pointer_cast<ov::opset3::LSTMCell>(first_cell)) {
|
||||
// input_forget modification is not supported
|
||||
if (lstm_cell_v0->get_input_forget()) {
|
||||
return false;
|
||||
}
|
||||
auto weights_format = lstm_cell_v0->get_weights_format();
|
||||
ov::Output<ov::Node> W = W_in, R = R_in, B = B_in;
|
||||
if (weights_format != ov::op::LSTMWeightsFormat::FICO) {
|
||||
W = ov::op::util::convert_lstm_node_format(W_in, convert_lstm_weights_enums(weights_format));
|
||||
R = ov::op::util::convert_lstm_node_format(R_in, convert_lstm_weights_enums(weights_format));
|
||||
B = ov::op::util::convert_lstm_node_format(B_in, convert_lstm_weights_enums(weights_format));
|
||||
}
|
||||
const auto Ct_in = cp_to.make<Unsqueeze>(first_cell->input_value(2), axis_1);
|
||||
sequence = cp_to.make<LSTMSequence>(X_in,
|
||||
Ht_in,
|
||||
Ct_in,
|
||||
sequence_lengths_in,
|
||||
W,
|
||||
R,
|
||||
B,
|
||||
first_cell->get_hidden_size(),
|
||||
ov::op::RecurrentSequenceDirection::FORWARD,
|
||||
first_cell->get_activations_alpha(),
|
||||
first_cell->get_activations_beta(),
|
||||
first_cell->get_activations(),
|
||||
first_cell->get_clip());
|
||||
outputs.resize(2);
|
||||
outputs[1] = cp_to.make<Squeeze>(sequence->output(2), axis_1);
|
||||
} else if (auto gru_cell = dynamic_pointer_cast<GRUCell>(first_cell)) {
|
||||
sequence = cp_to.make<GRUSequence>(X_in,
|
||||
Ht_in,
|
||||
sequence_lengths_in,
|
||||
W_in,
|
||||
R_in,
|
||||
B_in,
|
||||
first_cell->get_hidden_size(),
|
||||
ov::op::RecurrentSequenceDirection::FORWARD,
|
||||
first_cell->get_activations(),
|
||||
first_cell->get_activations_alpha(),
|
||||
first_cell->get_activations_beta(),
|
||||
first_cell->get_clip(),
|
||||
gru_cell->get_linear_before_reset());
|
||||
} else if (dynamic_pointer_cast<RNNCell>(first_cell)) {
|
||||
sequence = cp_to.make<RNNSequence>(X_in,
|
||||
Ht_in,
|
||||
sequence_lengths_in,
|
||||
W_in,
|
||||
R_in,
|
||||
B_in,
|
||||
first_cell->get_hidden_size(),
|
||||
ov::op::RecurrentSequenceDirection::FORWARD,
|
||||
first_cell->get_activations(),
|
||||
first_cell->get_activations_alpha(),
|
||||
first_cell->get_activations_beta(),
|
||||
first_cell->get_clip());
|
||||
} else if (dynamic_pointer_cast<ov::op::internal::AUGRUCell>(first_cell)) {
|
||||
const auto A_in = cp_to.make<Concat>(attention_to_concat, 1);
|
||||
sequence = cp_to.make<ov::op::internal::AUGRUSequence>(X_in,
|
||||
Ht_in,
|
||||
sequence_lengths_in,
|
||||
W_in,
|
||||
R_in,
|
||||
B_in,
|
||||
A_in,
|
||||
first_cell->get_hidden_size());
|
||||
} else {
|
||||
// cell is not supported;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!h_outputs_to_redirect.empty()) {
|
||||
auto squeeze_Y = cp_to.make<Squeeze>(sequence->output(0), axis_1);
|
||||
auto split = cp_to.make<Split>(squeeze_Y, axis_1, cells_cnt);
|
||||
|
||||
for (auto it : h_outputs_to_redirect) {
|
||||
auto Hi = split->output(cells_cnt - it.first);
|
||||
auto friendly_name = it.second.get_node_shared_ptr()->get_friendly_name();
|
||||
if (it.first == 1) {
|
||||
Hi = sequence->output(1);
|
||||
}
|
||||
auto squeeze = cp_to.make<Squeeze>(Hi, axis_1);
|
||||
it.second.replace(squeeze);
|
||||
if (is_lstm) {
|
||||
friendly_name += ":1";
|
||||
}
|
||||
squeeze->set_friendly_name(friendly_name);
|
||||
}
|
||||
}
|
||||
if (is_lstm) {
|
||||
auto squeeze = cp_to.make<Squeeze>(sequence->output(2), axis_1);
|
||||
last_cell->output(1).replace(squeeze);
|
||||
squeeze->set_friendly_name(last_cell->get_friendly_name() + ":2");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
ov::pass::SequenceFusion::SequenceFusion() {
|
||||
MATCHER_SCOPE(SequenceFusion);
|
||||
|
||||
auto cell = pattern::wrap_type<RNNCellBase>();
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
NodeRegistry copy_from;
|
||||
NodeRegistry copy_to;
|
||||
auto cell = m.get_match_root();
|
||||
shared_ptr<RNNCellBase> current_cell = dynamic_pointer_cast<RNNCellBase>(cell);
|
||||
if (!current_cell) {
|
||||
return false;
|
||||
}
|
||||
// check that this is the last Cell in the chain, e.g.
|
||||
// GRUCell -> GRUCell (the last cell) -> OtherNode
|
||||
// GRUCell (hidden_size = 128) -> GRUCell (hs = 128, the last) -> GRUCell (hs = 64)
|
||||
for (const auto& target : cell->get_output_target_inputs(0)) {
|
||||
auto cell_1 = dynamic_pointer_cast<RNNCellBase>(target.get_node()->shared_from_this());
|
||||
if (cell_1 && is_equal_cells(cell_1, current_cell)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
int cells_cnt;
|
||||
ov::OutputVector x_to_concat;
|
||||
ov::OutputVector attention_to_concat;
|
||||
map<int, ov::Output<ov::Node>> h_outputs_to_redirect;
|
||||
auto axis_0 = copy_to.make<Constant>(i64, Shape{}, 0);
|
||||
auto axis_1 = copy_to.make<Constant>(i64, Shape{}, 1);
|
||||
|
||||
// detect chain (Cell->Cell->Cell->..)
|
||||
auto first_cell = find_cell_chain(copy_from,
|
||||
copy_to,
|
||||
current_cell,
|
||||
x_to_concat,
|
||||
attention_to_concat,
|
||||
h_outputs_to_redirect,
|
||||
cells_cnt,
|
||||
axis_1);
|
||||
if (!first_cell) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// no reasons to create sequence if the single cell detected.
|
||||
// TODO: investigate optimal cnt of cells
|
||||
constexpr int optimal_cnt_of_cells = 2;
|
||||
if (cells_cnt < optimal_cnt_of_cells) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto res = create_sequence(copy_to,
|
||||
first_cell,
|
||||
current_cell,
|
||||
x_to_concat,
|
||||
attention_to_concat,
|
||||
h_outputs_to_redirect,
|
||||
cells_cnt,
|
||||
axis_0,
|
||||
axis_1);
|
||||
if (!res) {
|
||||
return false;
|
||||
}
|
||||
copy_runtime_info(copy_from.get(), copy_to.get());
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = make_shared<pattern::Matcher>(cell, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
@ -23,6 +23,8 @@ enum class LSTMWeightsFormat {
|
||||
IOFC, // ONNX
|
||||
};
|
||||
|
||||
ov::op::util::LSTMWeightsFormat convert_lstm_weights_enums(LSTMWeightsFormat format);
|
||||
|
||||
namespace v0 {
|
||||
///
|
||||
/// \brief Class for single lstm cell node.
|
||||
|
@ -253,6 +253,22 @@ NGRAPH_API EnumNames<ngraph::op::LSTMWeightsFormat>& EnumNames<ngraph::op::LSTMW
|
||||
|
||||
BWDCMP_RTTI_DEFINITION(AttributeAdapter<ov::op::LSTMWeightsFormat>);
|
||||
|
||||
ov::op::util::LSTMWeightsFormat op::convert_lstm_weights_enums(op::LSTMWeightsFormat format) {
|
||||
switch (format) {
|
||||
case LSTMWeightsFormat::FICO:
|
||||
return ov::op::util::LSTMWeightsFormat::FICO;
|
||||
case LSTMWeightsFormat::ICOF:
|
||||
return ov::op::util::LSTMWeightsFormat::ICOF;
|
||||
case LSTMWeightsFormat::IFCO:
|
||||
return ov::op::util::LSTMWeightsFormat::IFCO;
|
||||
case LSTMWeightsFormat::IFOC:
|
||||
return ov::op::util::LSTMWeightsFormat::IFOC;
|
||||
case LSTMWeightsFormat::IOFC:
|
||||
return ov::op::util::LSTMWeightsFormat::IOFC;
|
||||
default:
|
||||
OPENVINO_ASSERT(false, "Incorrect LSTM weights format");
|
||||
}
|
||||
}
|
||||
} // namespace ov
|
||||
|
||||
std::ostream& ov::operator<<(std::ostream& s, const op::LSTMWeightsFormat& type) {
|
||||
|
@ -0,0 +1,214 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <queue>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "openvino/opsets/opset3.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "transformations/common_optimizations/sequence_fusion.hpp"
|
||||
#include "ngraph_ops/augru_sequence.hpp"
|
||||
#include "ngraph_ops/augru_cell.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace std;
|
||||
using namespace testing;
|
||||
using namespace ov::opset9;
|
||||
using namespace ov::element;
|
||||
|
||||
namespace {
|
||||
enum class RNN_TYPE {
|
||||
LSTM_v0,
|
||||
LSTM_v4,
|
||||
GRU,
|
||||
RNN,
|
||||
AUGRU
|
||||
};
|
||||
|
||||
int get_gate_by_rnn_type(RNN_TYPE rnn_type) {
|
||||
int gate = 1;
|
||||
if (rnn_type == RNN_TYPE::LSTM_v4 || rnn_type == RNN_TYPE::LSTM_v0) {
|
||||
gate = 4;
|
||||
} else if (rnn_type == RNN_TYPE::GRU || rnn_type == RNN_TYPE::AUGRU) {
|
||||
gate = 3;
|
||||
} else if (rnn_type == RNN_TYPE::RNN) {
|
||||
gate = 1;
|
||||
}
|
||||
return gate;
|
||||
}
|
||||
|
||||
OutputVector create_cell(RNN_TYPE rnn_type,
|
||||
const shared_ptr<Node>& X,
|
||||
const shared_ptr<Node>& H,
|
||||
const shared_ptr<Node>& C,
|
||||
const shared_ptr<Node>& W,
|
||||
const shared_ptr<Node>& R,
|
||||
const shared_ptr<Node>& B,
|
||||
const shared_ptr<Node>& A,
|
||||
size_t hidden_size,
|
||||
int64_t cells_cnt) {
|
||||
shared_ptr<Node> cell;
|
||||
Output<Node> cur_H = H;
|
||||
Output<Node> cur_C = C;
|
||||
OutputVector hidden_vec;
|
||||
auto axis_1 = make_shared<Constant>(i64, Shape{}, 1);
|
||||
|
||||
for (int i = 0; i < cells_cnt; ++i) {
|
||||
if (rnn_type == RNN_TYPE::LSTM_v4) {
|
||||
cell = make_shared<LSTMCell>(X, cur_H, cur_C, W, R, B, hidden_size);
|
||||
cur_C = cell->output(1);
|
||||
} else if (rnn_type == RNN_TYPE::LSTM_v0) {
|
||||
cell = make_shared<opset3::LSTMCell>(X, cur_H, cur_C, W, R, B, hidden_size, ov::op::LSTMWeightsFormat::FICO);
|
||||
cur_C = cell->output(1);
|
||||
} else if (rnn_type == RNN_TYPE::GRU) {
|
||||
cell = make_shared<GRUCell>(X, cur_H, W, R, B, hidden_size);
|
||||
} else if (rnn_type == RNN_TYPE::RNN) {
|
||||
cell = make_shared<RNNCell>(X, cur_H, W, R, B, hidden_size);
|
||||
} else if (rnn_type == RNN_TYPE::AUGRU) {
|
||||
cell = make_shared<op::internal::AUGRUCell>(X, cur_H, W, R, B, A, hidden_size);
|
||||
}
|
||||
cur_H = cell->output(0);
|
||||
hidden_vec.push_back(make_shared<Unsqueeze>(cur_H, axis_1));
|
||||
}
|
||||
auto concat = make_shared<Concat>(hidden_vec, 1);
|
||||
OutputVector outputs = {concat->output(0)};
|
||||
auto cell_outputs = cell->outputs();
|
||||
outputs.insert(outputs.end(), cell_outputs.begin(), cell_outputs.end());
|
||||
return outputs;
|
||||
}
|
||||
|
||||
shared_ptr<Model> gen_model(RNN_TYPE rnn_type, size_t batch, size_t hidden_size, size_t input_size,
|
||||
int64_t cells_cnt) {
|
||||
int gate = get_gate_by_rnn_type(rnn_type);
|
||||
auto X = make_shared<Parameter>(f32, Shape{batch, input_size});
|
||||
auto H = make_shared<Parameter>(f32, Shape{batch, hidden_size});
|
||||
auto C = make_shared<Parameter>(f32, Shape{batch, hidden_size});
|
||||
auto W = make_shared<Parameter>(f32, Shape{gate * hidden_size, input_size});
|
||||
auto R = make_shared<Parameter>(f32, Shape{gate * hidden_size, hidden_size});
|
||||
auto B = make_shared<Parameter>(f32, Shape{gate * hidden_size});
|
||||
auto A = make_shared<Parameter>(f32, Shape{batch, 1});
|
||||
|
||||
auto outputs = create_cell(rnn_type, X, H, C, W, R, B, A, hidden_size, cells_cnt);
|
||||
ParameterVector params = {X, H, W, R, B};
|
||||
if (rnn_type == RNN_TYPE::LSTM_v4 || rnn_type == RNN_TYPE::LSTM_v0) {
|
||||
params.push_back(C);
|
||||
} else if (rnn_type == RNN_TYPE::AUGRU) {
|
||||
params.push_back(A);
|
||||
}
|
||||
return make_shared<Model>(outputs, params);
|
||||
}
|
||||
|
||||
shared_ptr<Model> gen_reference(RNN_TYPE rnn_type, size_t batch, size_t hidden_size, size_t input_size,
|
||||
int64_t cells_cnt) {
|
||||
int gate = get_gate_by_rnn_type(rnn_type);
|
||||
auto axis_0 = make_shared<Constant>(i64, Shape{}, 0);
|
||||
auto axis_1 = make_shared<Constant>(i64, Shape{}, 1);
|
||||
auto seq_len = make_shared<Constant>(i64, Shape{batch}, cells_cnt);
|
||||
|
||||
auto X = make_shared<Parameter>(f32, Shape{batch, input_size});
|
||||
auto H = make_shared<Parameter>(f32, Shape{batch, hidden_size});
|
||||
auto C = make_shared<Parameter>(f32, Shape{batch, hidden_size});
|
||||
auto W = make_shared<Parameter>(f32, Shape{gate * hidden_size, input_size});
|
||||
auto R = make_shared<Parameter>(f32, Shape{gate * hidden_size, hidden_size});
|
||||
auto B = make_shared<Parameter>(f32, Shape{gate * hidden_size});
|
||||
auto A = make_shared<Parameter>(f32, Shape{batch, 1});
|
||||
|
||||
ParameterVector params = {X, H, W, R, B};
|
||||
if (rnn_type == RNN_TYPE::LSTM_v4 || rnn_type == RNN_TYPE::LSTM_v0) {
|
||||
params.push_back(C);
|
||||
} else if (rnn_type == RNN_TYPE::AUGRU) {
|
||||
params.push_back(A);
|
||||
}
|
||||
auto unH = make_shared<Unsqueeze>(H, axis_1);
|
||||
auto unC = make_shared<Unsqueeze>(C, axis_1);
|
||||
auto unW = make_shared<Unsqueeze>(W, axis_0);
|
||||
auto unR = make_shared<Unsqueeze>(R, axis_0);
|
||||
auto unB = make_shared<Unsqueeze>(B, axis_0);
|
||||
|
||||
OutputVector in_X;
|
||||
OutputVector in_A;
|
||||
for (int i = 0; i < cells_cnt; ++i) {
|
||||
in_X.push_back(make_shared<Unsqueeze>(X, axis_1));
|
||||
in_A.push_back(make_shared<Unsqueeze>(A, axis_1));
|
||||
}
|
||||
auto concat_X = make_shared<Concat>(in_X, 1);
|
||||
auto concat_A = make_shared<Concat>(in_A, 1);
|
||||
|
||||
shared_ptr<Node> seq;
|
||||
if (rnn_type == RNN_TYPE::LSTM_v4 || rnn_type == RNN_TYPE::LSTM_v0) {
|
||||
seq = make_shared<LSTMSequence>(concat_X, unH, unC, seq_len, unW, unR, unB, hidden_size, op::RecurrentSequenceDirection::FORWARD);
|
||||
} else if (rnn_type == RNN_TYPE::GRU) {
|
||||
seq = make_shared<GRUSequence>(concat_X, unH, seq_len, unW, unR, unB, hidden_size, op::RecurrentSequenceDirection::FORWARD);
|
||||
} else if (rnn_type == RNN_TYPE::RNN) {
|
||||
seq = make_shared<RNNSequence>(concat_X, unH, seq_len, unW, unR, unB, hidden_size, op::RecurrentSequenceDirection::FORWARD);
|
||||
} else if (rnn_type == RNN_TYPE::AUGRU) {
|
||||
seq = make_shared<op::internal::AUGRUSequence>(concat_X, unH, seq_len, unW, unR, unB, concat_A, hidden_size);
|
||||
}
|
||||
|
||||
auto squeeze_H = make_shared<Squeeze>(seq->output(0), axis_1);
|
||||
auto _axis_1 = make_shared<Constant>(i64, Shape{}, 1);
|
||||
auto split = make_shared<Split>(squeeze_H, axis_1, cells_cnt);
|
||||
|
||||
OutputVector in_vec;
|
||||
for (int i = 0; i < split->outputs().size(); ++i) {
|
||||
auto squeeze = make_shared<Squeeze>(split->output(i), axis_1);
|
||||
in_vec.push_back(make_shared<Unsqueeze>(squeeze, _axis_1));
|
||||
}
|
||||
auto concat = make_shared<Concat>(in_vec, 1);
|
||||
|
||||
auto squeeze_Ht = make_shared<Squeeze>(seq->output(1), axis_1);
|
||||
OutputVector outputs = {concat->output(0), squeeze_Ht};
|
||||
if (rnn_type == RNN_TYPE::LSTM_v4 || rnn_type == RNN_TYPE::LSTM_v0) {
|
||||
auto squeeze_Ct = make_shared<Squeeze>(seq->output(2), axis_1);
|
||||
outputs.push_back(squeeze_Ct);
|
||||
}
|
||||
return make_shared<Model>(outputs, params);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
struct SequenceFusionParams {
|
||||
RNN_TYPE rnn_type;
|
||||
size_t batch;
|
||||
size_t hidden_size;
|
||||
size_t input_size;
|
||||
int64_t cell_cnt;
|
||||
};
|
||||
|
||||
class SequenceFusionTest
|
||||
: public WithParamInterface<SequenceFusionParams>,
|
||||
public TransformationTestsF {
|
||||
};
|
||||
|
||||
TEST_P(SequenceFusionTest, SequencePattern) {
|
||||
const auto& p = GetParam();
|
||||
{
|
||||
function = gen_model(p.rnn_type, p.batch,
|
||||
p.hidden_size, p.input_size, p.cell_cnt);
|
||||
manager.register_pass<pass::SequenceFusion>();
|
||||
}
|
||||
|
||||
// the transformation won't be applied for single cell
|
||||
if (p.cell_cnt > 1) {
|
||||
function_ref = gen_reference(p.rnn_type, p.batch, p.hidden_size, p.input_size, p.cell_cnt);
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
}
|
||||
|
||||
static const std::vector<SequenceFusionParams> params = {
|
||||
SequenceFusionParams{RNN_TYPE::LSTM_v4, 2, 128, 32, 1},
|
||||
SequenceFusionParams{RNN_TYPE::LSTM_v4, 2, 128, 32, 10},
|
||||
SequenceFusionParams{RNN_TYPE::LSTM_v0, 2, 128, 32, 1},
|
||||
SequenceFusionParams{RNN_TYPE::LSTM_v0, 2, 128, 32, 10},
|
||||
SequenceFusionParams{RNN_TYPE::GRU, 2, 128, 32, 1},
|
||||
SequenceFusionParams{RNN_TYPE::GRU, 2, 128, 32, 10},
|
||||
SequenceFusionParams{RNN_TYPE::RNN, 2, 128, 32, 1},
|
||||
SequenceFusionParams{RNN_TYPE::RNN, 2, 128, 32, 10},
|
||||
SequenceFusionParams{RNN_TYPE::AUGRU, 2, 128, 32, 1},
|
||||
SequenceFusionParams{RNN_TYPE::AUGRU, 2, 128, 32, 10},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(SequenceFusionTest, SequenceFusionTest, ValuesIn(params));
|
Loading…
Reference in New Issue
Block a user