TensorIterator to RNN/GRU/LSTM Sequence transformation (#2146)

* TensorIterator to sequences transformations

* fix sequences-to-sequences IE conversion

* resolve review remarks

* resolve review remarks; fix TI-to-sequences transformations to support batch > 1 when slice axis == 0

* temporary enable ngraph ti transformations for cpu plugin

* fix includes

* Revert "fix includes"

This reverts commit 6cf15b97be.

* Revert "temporary enable ngraph ti transformations for cpu plugin"

This reverts commit fd528d7216.

* delete TODO comments
Ivan Tikhonov 2020-09-15 10:11:51 +03:00 committed by GitHub
parent ac2370b420
commit cd722d72df
4 changed files with 839 additions and 44 deletions

@@ -0,0 +1,58 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API ConvertTensorIteratorToLSTMSequence;
class TRANSFORMATIONS_API ConvertTensorIteratorToRNNSequence;
class TRANSFORMATIONS_API ConvertTensorIteratorToGRUSequence;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief Finds all TensorIterator layers, detects the pattern Squeeze->LSTMCell->Unsqueeze in the TensorIterator body,
 * converts this pattern to an LSTMSequence layer and replaces the TensorIterator with it.
*/
class ngraph::pass::ConvertTensorIteratorToLSTMSequence: public ngraph::pass::MatcherPass {
public:
ConvertTensorIteratorToLSTMSequence();
};
/**
* @ingroup ie_transformation_common_api
* @brief Finds all TensorIterator layers, detects the pattern Squeeze->RNNCell->Unsqueeze in the TensorIterator body,
 * converts this pattern to an RNNSequence layer and replaces the TensorIterator with it.
*/
class ngraph::pass::ConvertTensorIteratorToRNNSequence: public ngraph::pass::MatcherPass {
public:
ConvertTensorIteratorToRNNSequence();
};
/**
* @ingroup ie_transformation_common_api
* @brief Finds all TensorIterator layers, detects the pattern Squeeze->GRUCell->Unsqueeze in the TensorIterator body,
 * converts this pattern to a GRUSequence layer and replaces the TensorIterator with it.
*/
class ngraph::pass::ConvertTensorIteratorToGRUSequence: public ngraph::pass::MatcherPass {
public:
ConvertTensorIteratorToGRUSequence();
};
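For reference, these passes are driven through ngraph::pass::Manager with the transformation callback enabled, the same way the unit tests further down do it. A minimal sketch of that wiring (the helper name and the assumption that f contains a matching TensorIterator are illustrative, not part of this commit):

#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/tensor_iterator_transformations/convert_ti_to_sequences.h>

// Runs the TI-to-sequence passes on a function that is assumed to contain
// a TensorIterator with a Squeeze -> {LSTM,GRU,RNN}Cell -> Unsqueeze body.
inline void apply_ti_to_sequence_passes(const std::shared_ptr<ngraph::Function>& f) {
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::InitNodeInfo>();
    manager.register_pass<ngraph::pass::ConvertTensorIteratorToLSTMSequence>();
    manager.register_pass<ngraph::pass::ConvertTensorIteratorToRNNSequence>();
    manager.register_pass<ngraph::pass::ConvertTensorIteratorToGRUSequence>();
    // The matcher callbacks bail out unless the transformation callback returns true.
    manager.set_callback([](const std::shared_ptr<const ngraph::Node>&) -> bool { return true; });
    manager.run_passes(f);
}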

@@ -23,31 +23,26 @@ ngraph::pass::ConvertLSTMSequenceMatcher::ConvertLSTMSequenceMatcher() {
  return false;
  }
- const auto& W = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
-         lstm_sequence->input_value(4).get_node_shared_ptr());
- if (!W) {
-     return false;
- }
+ const auto& W = lstm_sequence->input_value(4);
+ const auto& R = lstm_sequence->input_value(5);
- const auto& R = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
-         lstm_sequence->input_value(5).get_node_shared_ptr());
- if (!R) {
+ // Bidirectional cases are not supported
+ if (lstm_sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
      return false;
- }
  // for forward/reverse cases we can squeeze num_direction dimension
  auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
- auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input(1).get_source_output(), axis_1);
- auto in_2 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input(2).get_source_output(), axis_1);
- auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::NodeVector({W, R}), 2);
+ auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input_value(1), axis_1);
+ auto in_2 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input_value(2), axis_1);
+ auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::OutputVector{W, R}, 2);
  auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
  auto in_3 = std::make_shared<ngraph::opset4::Squeeze>(concat->output(0), axis_2);
- auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input(6).get_source_output(), axis_2);
+ auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input_value(6), axis_2);
  auto lstm_sequence_ie = std::make_shared<ngraph::op::LSTMSequenceIE>(
      lstm_sequence->input(0).get_source_output(), // X
      in_1, // initial_hidden_state
      in_2, // initial_cell_state
-     lstm_sequence->input(3).get_source_output(),
+     lstm_sequence->input_value(3),
      in_3, // WR
      in_4, // B
      lstm_sequence->get_hidden_size(),
@@ -84,34 +79,25 @@ ngraph::pass::ConvertGRUSequenceMatcher::ConvertGRUSequenceMatcher() {
  return false;
  }
- auto W = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
-         gru_sequence->input_value(3).get_node_shared_ptr());
- if (!W) {
-     return false;
- }
+ auto W = gru_sequence->input_value(3);
+ auto R = gru_sequence->input_value(4);
- auto R = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
-         gru_sequence->input_value(4).get_node_shared_ptr());
- if (!R) {
-     return false;
- }
- // todo: add exception?
+ // Bidirectional cases are not supported
+ if (gru_sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
+     return false;
  // for forward/reverse cases we can squeeze num_direction dimension
  auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
- auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(gru_sequence->input(1).get_source_output(), axis_1);
- auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::NodeVector({W, R}), 2);
+ auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(gru_sequence->input_value(1), axis_1);
+ auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::OutputVector{W, R}, 2);
  auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
  auto in_3 = std::make_shared<ngraph::opset4::Squeeze>(concat->output(0), axis_2);
- auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(gru_sequence->input(5).get_source_output(), axis_2);
+ auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(gru_sequence->input_value(5), axis_2);
  auto gru_sequence_ie = std::make_shared<ngraph::op::GRUSequenceIE>(
-     gru_sequence->input(0).get_source_output(), // X
+     gru_sequence->input_value(0), // X
      in_1, // initial_hidden_state
-     gru_sequence->input(2).get_source_output(),
+     gru_sequence->input_value(2),
      in_3, // WR
      in_4, // B
      gru_sequence->get_hidden_size(),
@@ -146,27 +132,22 @@ ngraph::pass::ConvertRNNSequenceMatcher::ConvertRNNSequenceMatcher() {
  return false;
  }
- auto W = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
-         rnn_sequence->input_value(3).get_node_shared_ptr());
- if (!W) {
+ // Bidirectional cases are not supported
+ if (rnn_sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
      return false;
- }
- auto R = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
-         rnn_sequence->input_value(4).get_node_shared_ptr());
- if (!R) {
-     return false;
- }
+ auto W = rnn_sequence->input_value(3);
+ auto R = rnn_sequence->input_value(4);
  // for forward/reverse cases we can squeeze num_direction dimension
  auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
- auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(rnn_sequence->input(1).get_source_output(), axis_1);
- auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::NodeVector({W, R}), 2);
+ auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(rnn_sequence->input_value(1), axis_1);
+ auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::OutputVector{W, R}, 2);
  auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
  auto in_3 = std::make_shared<ngraph::opset4::Squeeze>(concat->output(0), axis_2);
- auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(rnn_sequence->input(5).get_source_output(), axis_2);
+ auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(rnn_sequence->input_value(5), axis_2);
  auto rnn_sequence_ie = std::make_shared<ngraph::op::RNNSequenceIE>(
-     rnn_sequence->input(0).get_source_output(), // X
+     rnn_sequence->input_value(0), // X
      in_1, // initial_hidden_state
      rnn_sequence->input_value(2),
      in_3, // WR

@@ -0,0 +1,478 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/tensor_iterator_transformations/convert_ti_to_sequences.h"
#include "transformations/utils/utils.hpp"
#include <memory>
#include <vector>
#include <ngraph/node.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/graph_util.hpp>
#include <ngraph/specialize_function.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
ngraph::pass::ConvertTensorIteratorToLSTMSequence::ConvertTensorIteratorToLSTMSequence() {
auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset4::TensorIterator>());
ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
auto ti = std::dynamic_pointer_cast<ngraph::opset4::TensorIterator>(m.get_match_root());
if (!ti || !m_transformation_callback(ti))
return false;
// create pattern
auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 1});
auto axis_squeeze = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{1}, 1);
auto input_data = std::make_shared<ngraph::opset4::Squeeze>(data, axis_squeeze);
auto input_H_state = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
auto input_C_state = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
auto input_W = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{4, 1});
auto input_R = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{4, 1});
auto input_B = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{4});
auto cell = std::make_shared<ngraph::opset4::LSTMCell>(input_data, input_H_state, input_C_state,
input_W, input_R, input_B, 1);
auto axis_unsqueeze = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{1}, 1);
auto unsqueeze = std::make_shared<ngraph::opset4::Unsqueeze>(cell, axis_unsqueeze);
ngraph::pattern::Matcher matcher(unsqueeze);
bool match = false;
auto func = ti->get_body();
for (const auto& res : func->get_results()) {
match = matcher.match((res->get_input_source_output(0)));
if (match)
break;
}
// All nodes in the TI body should be matched by the pattern
if (!match || (matcher.get_matched_nodes().size() + func->get_results().size()) != func->get_ops().size())
return false;
auto pattern_map = matcher.get_pattern_map();
auto params = func->get_parameters();
std::vector<std::shared_ptr<ngraph::opset4::TensorIterator::InputDescription>> ordered_in_descs(3);
int64_t stride = 0, slice_axis = 0;
size_t batch_size = 0;
for (const auto& input_desc : ti->get_input_descriptions()) {
auto param = params[input_desc->m_body_parameter_index];
if (param == pattern_map[data]) {
// to get batch size value
if (param->get_partial_shape().is_dynamic()) {
return false;
}
auto slice_input
= std::dynamic_pointer_cast<ngraph::opset4::TensorIterator::SliceInputDescription>(input_desc);
if (!slice_input)
return false;
stride = slice_input->m_stride;
slice_axis = slice_input->m_axis;
if (!(slice_axis == 0 || slice_axis == 1)) {
return false;
}
batch_size = param->get_shape()[slice_axis == 0 ? 1 : 0];
ordered_in_descs[0] = input_desc;
} else if (param == pattern_map[input_H_state]) {
ordered_in_descs[1] = input_desc;
} else if (param == pattern_map[input_C_state]) {
ordered_in_descs[2] = input_desc;
} else {
return false;
}
}
auto results = func->get_results();
std::vector<std::shared_ptr<ngraph::opset4::TensorIterator::OutputDescription>> ordered_out_descs(3);
for (const auto& output_desc : ti->get_output_descriptions()) {
std::shared_ptr<opset4::Result> res = results[output_desc->m_body_value_index];
if (res->get_input_source_output(0) == pattern_map[unsqueeze]) {
auto concat_output
= std::dynamic_pointer_cast<ngraph::opset4::TensorIterator::ConcatOutputDescription>(output_desc);
if (!concat_output)
return false;
stride = concat_output->m_stride;
ordered_out_descs[0] = output_desc;
} else if (res->get_input_source_output(0) == pattern_map[cell]->output(0)) {
ordered_out_descs[1] = output_desc;
} else if (res->get_input_source_output(0) == pattern_map[cell]->output(1)) {
ordered_out_descs[2] = output_desc;
} else {
return false;
}
}
auto seq_lengths = ngraph::opset4::Constant::create(element::i32, Shape{batch_size}, {ti->get_num_iterations()});
const auto& lstm_cell = std::dynamic_pointer_cast<ngraph::opset4::LSTMCell>(pattern_map[cell]);
auto in_0 = ti->input_values()[ordered_in_descs[0]->m_input_index];
if (slice_axis == 0) {
auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
in_0 = std::make_shared<ngraph::opset4::Transpose>(ti->input_values()[ordered_in_descs[0]->m_input_index], order);
}
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto in_1 = std::make_shared<ngraph::opset4::Unsqueeze>(ti->input_values()[ordered_in_descs[1]->m_input_index], axis_1);
auto in_2 = std::make_shared<ngraph::opset4::Unsqueeze>(ti->input_values()[ordered_in_descs[2]->m_input_index], axis_1);
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto in_4 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_W]->output(0).get_node_shared_ptr(), axis_2);
auto in_5 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_R]->output(0).get_node_shared_ptr(), axis_2);
auto in_6 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_B]->output(0).get_node_shared_ptr(), axis_2);
auto sequence = std::make_shared<op::v5::LSTMSequence>(
in_0,
in_1,
in_2,
seq_lengths,
in_4,
in_5,
in_6,
lstm_cell->get_hidden_size(),
stride > 0 ? ngraph::op::RecurrentSequenceDirection::FORWARD: ngraph::op::RecurrentSequenceDirection::REVERSE,
lstm_cell->get_activations_alpha(),
lstm_cell->get_activations_beta(),
lstm_cell->get_activations(),
lstm_cell->get_clip());
auto axis_out = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto out_0 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(0), axis_out);
auto out_1 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(1), axis_out);
auto out_2 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(2), axis_out);
std::shared_ptr<Node> out = out_0;
if (slice_axis == 0) {
auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
out = std::make_shared<ngraph::opset4::Transpose>(out_0, order);
}
ngraph::NodeVector outputs = {out, out_1, out_2};
for (size_t i = 0; i < ordered_out_descs.size(); ++i) {
if (ordered_out_descs[i]) {
for (const auto &input : ti->output(ordered_out_descs[i]->m_output_index).get_target_inputs()) {
input.replace_source_output(outputs[i]->output(0));
}
outputs[i]->get_output_tensor(0).set_name(op::util::create_ie_output_name(ti->output(ordered_out_descs[i]->m_output_index)));
}
}
ngraph::NodeVector new_nodes = {in_1, in_2, in_4, in_5, in_6, sequence, out_0, out_1, out_2};
if (slice_axis == 0) {
new_nodes.push_back(out);
new_nodes.push_back(in_0.get_node_shared_ptr());
}
copy_runtime_info(ti, new_nodes);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(tensor_iterator, "ConvertTensorIteratorToLSTMSequence");
register_matcher(m, callback);
}
ngraph::pass::ConvertTensorIteratorToRNNSequence::ConvertTensorIteratorToRNNSequence() {
auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset4::TensorIterator>());
ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
auto ti = std::dynamic_pointer_cast<ngraph::opset4::TensorIterator>(m.get_match_root());
if (!ti || !m_transformation_callback(ti))
return false;
// create pattern
auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 1});
auto axis_squeeze = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{1}, 0);
auto input_data = std::make_shared<ngraph::opset4::Squeeze>(data, axis_squeeze);
auto input_H_state = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
auto input_W = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{1, 1});
auto input_R = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{1, 1});
auto input_B = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{1});
auto cell = std::make_shared<ngraph::opset4::RNNCell>(input_data, input_H_state, input_W, input_R, input_B, 1);
auto axis_unsqueeze = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{1}, 0);
auto unsqueeze = std::make_shared<ngraph::opset4::Unsqueeze>(cell, axis_unsqueeze);
ngraph::pattern::Matcher matcher(unsqueeze);
bool match = false;
auto func = ti->get_body();
for (const auto& res : func->get_results()) {
match = matcher.match((res->get_input_source_output(0)));
if (match)
break;
}
// All nodes in the TI body should be matched by the pattern
if (!match || (matcher.get_matched_nodes().size() + func->get_results().size()) != func->get_ops().size())
return false;
auto pattern_map = matcher.get_pattern_map();
auto params = func->get_parameters();
std::vector<std::shared_ptr<ngraph::opset4::TensorIterator::InputDescription>> ordered_in_descs(3);
int64_t stride = 0, slice_axis = 0;
size_t batch_size = 0;
for (const auto& input_desc : ti->get_input_descriptions()) {
auto param = params[input_desc->m_body_parameter_index];
if (param == pattern_map[data]) {
// to get batch size value
if (param->get_partial_shape().is_dynamic()) {
return false;
}
auto slice_input
= std::dynamic_pointer_cast<ngraph::opset4::TensorIterator::SliceInputDescription>(input_desc);
if (!slice_input)
return false;
stride = slice_input->m_stride;
slice_axis = slice_input->m_axis;
if (!(slice_axis == 0 || slice_axis == 1)) {
return false;
}
batch_size = param->get_shape()[slice_axis == 0 ? 1 : 0];
ordered_in_descs[0] = input_desc;
} else if (param == pattern_map[input_H_state]) {
ordered_in_descs[1] = input_desc;
} else {
return false;
}
}
auto seq_lengths = ngraph::opset4::Constant::create(element::i32, Shape{batch_size}, {ti->get_num_iterations()});
auto results = func->get_results();
std::vector<std::shared_ptr<ngraph::opset4::TensorIterator::OutputDescription>> ordered_out_descs(2);
for (const auto& output_desc : ti->get_output_descriptions()) {
std::shared_ptr<opset4::Result> res = results[output_desc->m_body_value_index];
if (res->get_input_source_output(0) == pattern_map[unsqueeze]) {
auto concat_output
= std::dynamic_pointer_cast<ngraph::opset4::TensorIterator::ConcatOutputDescription>(output_desc);
if (!concat_output)
return false;
stride = concat_output->m_stride;
ordered_out_descs[0] = output_desc;
} else if (res->get_input_source_output(0) == pattern_map[cell]->output(0)) {
ordered_out_descs[1] = output_desc;
} else {
return false;
}
}
const auto& rnn_cell = std::dynamic_pointer_cast<ngraph::opset4::RNNCell>(pattern_map[cell]);
auto in_0 = ti->input_values()[ordered_in_descs[0]->m_input_index];
if (slice_axis == 0) {
auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
in_0 = std::make_shared<ngraph::opset4::Transpose>(ti->input_values()[ordered_in_descs[0]->m_input_index], order);
}
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto in_1 = std::make_shared<ngraph::opset4::Unsqueeze>(ti->input_values()[ordered_in_descs[1]->m_input_index], axis_1);
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto in_3 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_W]->output(0).get_node_shared_ptr(), axis_2);
auto in_4 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_R]->output(0).get_node_shared_ptr(), axis_2);
auto in_5 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_B]->output(0).get_node_shared_ptr(), axis_2);
auto sequence = std::make_shared<op::v5::RNNSequence>(
in_0,
in_1,
seq_lengths,
in_3,
in_4,
in_5,
rnn_cell->get_hidden_size(),
stride > 0 ? ngraph::op::RecurrentSequenceDirection::FORWARD: ngraph::op::RecurrentSequenceDirection::REVERSE,
rnn_cell->get_activations(),
rnn_cell->get_activations_alpha(),
rnn_cell->get_activations_beta(),
rnn_cell->get_clip());
auto axis_out = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto out_0 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(0), axis_out);
auto out_1 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(1), axis_out);
std::shared_ptr<Node> out = out_0;
if (slice_axis == 0) {
auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
out = std::make_shared<ngraph::opset4::Transpose>(out_0, order);
}
ngraph::NodeVector outputs = {out, out_1};
for (size_t i = 0; i < ordered_out_descs.size(); ++i) {
if (ordered_out_descs[i]) {
for (const auto &input : ti->output(ordered_out_descs[i]->m_output_index).get_target_inputs()) {
input.replace_source_output(outputs[i]->output(0));
}
outputs[i]->get_output_tensor(0).set_name(op::util::create_ie_output_name(ti->output(ordered_out_descs[i]->m_output_index)));
}
}
ngraph::OutputVector new_nodes = {in_1, in_3, in_4, in_5, sequence, out_0, out_1};
if (slice_axis == 0) {
new_nodes.push_back(out);
new_nodes.push_back(in_0);
}
copy_runtime_info(ti, as_node_vector(new_nodes));
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(tensor_iterator, "ConvertTensorIteratorToRNNSequence");
register_matcher(m, callback);
}
ngraph::pass::ConvertTensorIteratorToGRUSequence::ConvertTensorIteratorToGRUSequence() {
auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset4::TensorIterator>());
ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
auto ti = std::dynamic_pointer_cast<ngraph::opset4::TensorIterator>(m.get_match_root());
if (!ti || !m_transformation_callback(ti))
return false;
// create pattern
auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 1});
auto axis_squeeze = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{1}, 0);
auto input_data = std::make_shared<ngraph::opset4::Squeeze>(data, axis_squeeze);
auto input_H_state = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
auto input_W = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{3, 1});
auto input_R = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{3, 1});
auto input_B = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{3});
auto cell = std::make_shared<ngraph::opset4::GRUCell>(input_data, input_H_state, input_W, input_R, input_B, 1);
auto axis_unsqueeze = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{1}, 0);
auto unsqueeze = std::make_shared<ngraph::opset4::Unsqueeze>(cell, axis_unsqueeze);
ngraph::pattern::Matcher matcher(unsqueeze);
bool match = false;
auto func = ti->get_body();
for (const auto& res : func->get_results()) {
match = matcher.match((res->get_input_source_output(0)));
if (match)
break;
}
// All nodes in the TI body should be matched by the pattern
if (!match || (matcher.get_matched_nodes().size() + func->get_results().size()) != func->get_ops().size())
return false;
auto pattern_map = matcher.get_pattern_map();
auto params = func->get_parameters();
std::vector<std::shared_ptr<ngraph::opset4::TensorIterator::InputDescription>> ordered_in_descs(3);
int64_t stride = 0, slice_axis = 0;
size_t batch_size = 0;
for (const auto& input_desc : ti->get_input_descriptions()) {
auto param = params[input_desc->m_body_parameter_index];
if (param == pattern_map[data]) {
// to get batch size value
if (param->get_partial_shape().is_dynamic()) {
return false;
}
auto slice_input
= std::dynamic_pointer_cast<ngraph::opset4::TensorIterator::SliceInputDescription>(input_desc);
if (!slice_input)
return false;
stride = slice_input->m_stride;
slice_axis = slice_input->m_axis;
if (!(slice_axis == 0 || slice_axis == 1)) {
return false;
}
batch_size = param->get_shape()[slice_axis == 0 ? 1 : 0];
ordered_in_descs[0] = input_desc;
} else if (param == pattern_map[input_H_state]) {
ordered_in_descs[1] = input_desc;
} else {
return false;
}
}
auto seq_lengths = ngraph::opset4::Constant::create(element::i32, Shape{batch_size}, {ti->get_num_iterations()});
auto results = func->get_results();
std::vector<std::shared_ptr<ngraph::opset4::TensorIterator::OutputDescription>> ordered_out_descs(2);
for (const auto& output_desc : ti->get_output_descriptions()) {
std::shared_ptr<opset4::Result> res = results[output_desc->m_body_value_index];
if (res->get_input_source_output(0) == pattern_map[unsqueeze]) {
auto concat_output
= std::dynamic_pointer_cast<ngraph::opset4::TensorIterator::ConcatOutputDescription>(output_desc);
if (!concat_output)
return false;
stride = concat_output->m_stride;
ordered_out_descs[0] = output_desc;
} else if (res->get_input_source_output(0) == pattern_map[cell]->output(0)) {
ordered_out_descs[1] = output_desc;
} else {
return false;
}
}
const auto& rnn_cell = std::dynamic_pointer_cast<ngraph::opset4::GRUCell>(pattern_map[cell]);
auto in_0 = ti->input_values()[ordered_in_descs[0]->m_input_index];
if (slice_axis == 0) {
auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
in_0 = std::make_shared<ngraph::opset4::Transpose>(ti->input_values()[ordered_in_descs[0]->m_input_index], order);
}
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto in_1 = std::make_shared<ngraph::opset4::Unsqueeze>(ti->input_values()[ordered_in_descs[1]->m_input_index], axis_1);
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto in_3 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_W]->output(0).get_node_shared_ptr(), axis_2);
auto in_4 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_R]->output(0).get_node_shared_ptr(), axis_2);
auto in_5 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_B]->output(0).get_node_shared_ptr(), axis_2);
auto sequence = std::make_shared<op::v5::GRUSequence>(
in_0,
in_1,
seq_lengths,
in_3,
in_4,
in_5,
rnn_cell->get_hidden_size(),
stride > 0 ? ngraph::op::RecurrentSequenceDirection::FORWARD: ngraph::op::RecurrentSequenceDirection::REVERSE,
rnn_cell->get_activations(),
rnn_cell->get_activations_alpha(),
rnn_cell->get_activations_beta(),
rnn_cell->get_clip(),
rnn_cell->get_linear_before_reset());
auto axis_out = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto out_0 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(0), axis_out);
auto out_1 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(1), axis_out);
std::shared_ptr<Node> out = out_0;
if (slice_axis == 0) {
auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
out = std::make_shared<ngraph::opset4::Transpose>(out_0, order);
}
ngraph::NodeVector outputs = {out, out_1};
for (size_t i = 0; i < ordered_out_descs.size(); ++i) {
if (ordered_out_descs[i]) {
for (const auto &input : ti->output(ordered_out_descs[i]->m_output_index).get_target_inputs()) {
input.replace_source_output(outputs[i]->output(0));
}
outputs[i]->get_output_tensor(0).set_name(op::util::create_ie_output_name(ti->output(ordered_out_descs[i]->m_output_index)));
}
}
ngraph::OutputVector new_nodes = {in_1, in_3, in_4, in_5, sequence, out_0, out_1};
if (slice_axis == 0) {
new_nodes.push_back(out);
new_nodes.push_back(in_0);
}
copy_runtime_info(ti, as_node_vector(new_nodes));
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(tensor_iterator, "ConvertTensorIteratorToGRUSequence");
register_matcher(m, callback);
}
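The commit message's note about supporting batch > 1 when the slice axis is 0 comes down to the Transpose pair built in each callback above: with slice axis 0 the TensorIterator input arrives as [seq_len, batch, input_size], while the sequence ops expect [batch, seq_len, input_size], so the data is transposed before the sequence op and its first output is transposed back. A condensed sketch of that axis swap (the helper name swap_batch_and_seq is illustrative, not part of the commit):

#include <memory>
#include <ngraph/opsets/opset4.hpp>

// Swaps the first two axes of a rank-3 tensor, e.g. [seq_len, batch, size] <-> [batch, seq_len, size],
// mirroring the slice_axis == 0 handling in the callbacks above.
static ngraph::Output<ngraph::Node> swap_batch_and_seq(const ngraph::Output<ngraph::Node>& value) {
    auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
    return std::make_shared<ngraph::opset4::Transpose>(value, order);
}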

@@ -0,0 +1,278 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "common_test_utils/test_common.hpp"
#include <string>
#include <memory>
#include <queue>
#include <ngraph/pass/manager.hpp>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph_ops/fully_connected.hpp>
#include <transformations/tensor_iterator_transformations/convert_ti_to_sequences.h>
#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ngraph;
TEST(TransformationTests, ConvertTensorIteratorToLSTMSequence) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 2, 16});
auto Y = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
auto Z = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
auto Xi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 1, 16});
auto Yi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
auto Zi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
// Body
auto axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto squeeze = std::make_shared<opset4::Squeeze>(Xi, axis);
auto w_val = std::vector<float>(512 * 16, 0);
auto r_val = std::vector<float>(512 * 128, 0);
auto b_val = std::vector<float>(512, 0);
auto W = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{512, 16}, w_val);
auto R = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{512, 128}, r_val);
auto B = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{512}, b_val);
auto lstm_cell = std::make_shared<opset4::LSTMCell>(squeeze, Yi, Zi, W, R, B, 128);
auto res_1 = std::make_shared<opset4::Result>(lstm_cell);
auto axis_unsqueeze = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto unsqueeze = std::make_shared<opset4::Unsqueeze>(lstm_cell, axis_unsqueeze);
auto res_2 = std::make_shared<opset4::Result>(unsqueeze);
auto body = std::make_shared<Function>(OutputVector{res_1, res_2},
ParameterVector{Xi, Yi, Zi});
auto tensor_iterator = std::make_shared<opset4::TensorIterator>();
tensor_iterator->set_body(body);
tensor_iterator->set_invariant_input(Zi, Z);
tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 1);
tensor_iterator->set_merged_input(Yi, Y, res_1);
auto out0 = tensor_iterator->get_iter_value(res_1, -1);
auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 1);
auto res_ti_1 = std::make_shared<opset4::Result>(tensor_iterator->output(1));
//auto res_ti_2 = std::make_shared<opset4::Result>(tensor_iterator->output(0));
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1},
ngraph::ParameterVector{X, Y, Z});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::ConvertTensorIteratorToLSTMSequence>();
m.set_callback([](const std::shared_ptr<const Node>&) -> bool { return true; });
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 2, 16});
auto Y = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
auto Z = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
auto w_val = std::vector<float>(512 * 16, 0);
auto r_val = std::vector<float>(512 * 128, 0);
auto b_val = std::vector<float>(512, 0);
auto W = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{512, 16}, w_val);
auto R = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{512, 128}, r_val);
auto B = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{512}, b_val);
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto in_1 = std::make_shared<ngraph::opset4::Unsqueeze>(Y, axis_1);
auto in_2 = std::make_shared<ngraph::opset4::Unsqueeze>(Z, axis_1);
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto in_4 = std::make_shared<ngraph::opset4::Unsqueeze>(W, axis_2);
auto in_5 = std::make_shared<ngraph::opset4::Unsqueeze>(R, axis_2);
auto in_6 = std::make_shared<ngraph::opset4::Unsqueeze>(B, axis_2);
auto seq_lengths = ngraph::opset4::Constant::create(element::i32, Shape{1}, {2});
auto lstm_seq = std::make_shared<op::v5::LSTMSequence>(X, in_1, in_2, seq_lengths, in_4, in_5, in_6,
128, ngraph::op::RecurrentSequenceDirection::FORWARD);
auto axis_out = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto out_0 = std::make_shared<ngraph::opset4::Squeeze>(lstm_seq->output(0), axis_out);
auto out_1 = std::make_shared<ngraph::opset4::Squeeze>(lstm_seq->output(1), axis_out);
auto out_2 = std::make_shared<ngraph::opset4::Squeeze>(lstm_seq->output(2), axis_out);
auto res_ti_1 = std::make_shared<opset4::Result>(out_0);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y, Z});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ConvertTensorIteratorToRNNSequence) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 2, 16});
auto Y = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
auto Xi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 1, 16});
auto Yi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
// Body
auto axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto squeeze = std::make_shared<opset4::Squeeze>(Xi, axis);
auto w_val = std::vector<float>(128 * 16, 0);
auto r_val = std::vector<float>(128 * 128, 0);
auto b_val = std::vector<float>(128, 0);
auto W = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{128, 16}, w_val);
auto R = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{128, 128}, r_val);
auto B = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{128}, b_val);
auto rnn_cell = std::make_shared<opset4::RNNCell>(squeeze, Yi, W, R, B, 128);
auto res_1 = std::make_shared<opset4::Result>(rnn_cell);
auto axis_unsqueeze = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto unsqueeze = std::make_shared<opset4::Unsqueeze>(rnn_cell, axis_unsqueeze);
auto res_2 = std::make_shared<opset4::Result>(unsqueeze);
auto body = std::make_shared<Function>(OutputVector{res_1, res_2},
ParameterVector{Xi, Yi});
auto tensor_iterator = std::make_shared<opset4::TensorIterator>();
tensor_iterator->set_body(body);
tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 1);
tensor_iterator->set_merged_input(Yi, Y, res_1);
auto out0 = tensor_iterator->get_iter_value(res_1, -1);
auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 1);
auto res_ti_1 = std::make_shared<opset4::Result>(tensor_iterator->output(1));
//auto res_ti_2 = std::make_shared<opset4::Result>(tensor_iterator->output(0));
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1},
ngraph::ParameterVector{X, Y});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::ConvertTensorIteratorToRNNSequence>();
m.set_callback([](const std::shared_ptr<const Node>&) -> bool { return true; });
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 2, 16});
auto Y = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
auto w_val = std::vector<float>(128 * 16, 0);
auto r_val = std::vector<float>(128 * 128, 0);
auto b_val = std::vector<float>(128, 0);
auto W = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{128, 16}, w_val);
auto R = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{128, 128}, r_val);
auto B = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{128}, b_val);
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto in_1 = std::make_shared<ngraph::opset4::Unsqueeze>(Y, axis_1);
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto in_3 = std::make_shared<ngraph::opset4::Unsqueeze>(W, axis_2);
auto in_4 = std::make_shared<ngraph::opset4::Unsqueeze>(R, axis_2);
auto in_5 = std::make_shared<ngraph::opset4::Unsqueeze>(B, axis_2);
auto seq_lengths = ngraph::opset4::Constant::create(element::i32, Shape{1}, {2});
auto rnn_sequence = std::make_shared<op::v5::RNNSequence>(X, in_1, seq_lengths, in_3, in_4, in_5,
128, ngraph::op::RecurrentSequenceDirection::FORWARD);
auto axis_out = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto out_0 = std::make_shared<ngraph::opset4::Squeeze>(rnn_sequence->output(0), axis_out);
auto out_1 = std::make_shared<ngraph::opset4::Squeeze>(rnn_sequence->output(1), axis_out);
auto res_ti_1 = std::make_shared<opset4::Result>(out_0);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ConvertTensorIteratorToGRUSequence) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 2, 16});
auto Y = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
auto Xi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 1, 16});
auto Yi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
// Body
auto axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto squeeze = std::make_shared<opset4::Squeeze>(Xi, axis);
auto w_val = std::vector<float>(384 * 16, 0);
auto r_val = std::vector<float>(384 * 128, 0);
auto b_val = std::vector<float>(384, 0);
auto W = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{384, 16}, w_val);
auto R = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{384, 128}, r_val);
auto B = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{384}, b_val);
auto gru_cell = std::make_shared<opset4::GRUCell>(squeeze, Yi, W, R, B, 128);
auto res_1 = std::make_shared<opset4::Result>(gru_cell);
auto axis_unsqueeze = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto unsqueeze = std::make_shared<opset4::Unsqueeze>(gru_cell, axis_unsqueeze);
auto res_2 = std::make_shared<opset4::Result>(unsqueeze);
auto body = std::make_shared<Function>(OutputVector{res_1, res_2},
ParameterVector{Xi, Yi});
auto tensor_iterator = std::make_shared<opset4::TensorIterator>();
tensor_iterator->set_body(body);
tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 1);
tensor_iterator->set_merged_input(Yi, Y, res_1);
auto out0 = tensor_iterator->get_iter_value(res_1, -1);
auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 1);
auto res_ti_1 = std::make_shared<opset4::Result>(tensor_iterator->output(1));
//auto res_ti_2 = std::make_shared<opset4::Result>(tensor_iterator->output(0));
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1},
ngraph::ParameterVector{X, Y});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::ConvertTensorIteratorToGRUSequence>();
m.set_callback([](const std::shared_ptr<const Node>&) -> bool { return true; });
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 2, 16});
auto Y = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
auto w_val = std::vector<float>(384 * 16, 0);
auto r_val = std::vector<float>(384 * 128, 0);
auto b_val = std::vector<float>(384, 0);
auto W = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{384, 16}, w_val);
auto R = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{384, 128}, r_val);
auto B = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{384}, b_val);
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto in_1 = std::make_shared<ngraph::opset4::Unsqueeze>(Y, axis_1);
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto in_3 = std::make_shared<ngraph::opset4::Unsqueeze>(W, axis_2);
auto in_4 = std::make_shared<ngraph::opset4::Unsqueeze>(R, axis_2);
auto in_5 = std::make_shared<ngraph::opset4::Unsqueeze>(B, axis_2);
auto seq_lengths = ngraph::opset4::Constant::create(element::i32, Shape{1}, {2});
auto gru_sequence = std::make_shared<op::v5::GRUSequence>(X, in_1, seq_lengths, in_3, in_4, in_5,
128, ngraph::op::RecurrentSequenceDirection::FORWARD);
auto axis_out = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto out_0 = std::make_shared<ngraph::opset4::Squeeze>(gru_sequence->output(0), axis_out);
auto out_1 = std::make_shared<ngraph::opset4::Squeeze>(gru_sequence->output(1), axis_out);
auto res_ti_1 = std::make_shared<opset4::Result>(out_0);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}