TensorIterator to RNN/GRU/LSTM Sequence transformation (#2146)
* ti to sequences transformations * fix sequences to sequences ie conversion * resolve review remarks * resolve review remarks, fix ti to sequences transformations to support batch > 1 if slice axis == 0 * temporary enable ngraph ti transformations for cpu plugin * fix includes * Revert "fix includes" This reverts commit 6cf15b97be.
* Revert "temporary enable ngraph ti transformations for cpu plugin" This reverts commit fd528d7216.
* delete todo comments
This commit is contained in:
parent
ac2370b420
commit
cd722d72df
@ -0,0 +1,58 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API ConvertTensorIteratorToLSTMSequence;
|
||||
class TRANSFORMATIONS_API ConvertTensorIteratorToRNNSequence;
|
||||
class TRANSFORMATIONS_API ConvertTensorIteratorToGRUSequence;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief Finds all TensorIterator layers, detects the pattern Squeeze->LSTMCell->Unsqueeze in the TensorIterator body,
|
||||
* converts this pattern to LSTMSequence layer and replaces them TensorIterator.
|
||||
*/
|
||||
|
||||
class ngraph::pass::ConvertTensorIteratorToLSTMSequence: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
ConvertTensorIteratorToLSTMSequence();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief Finds all TensorIterator layers, detects the pattern Squeeze->RNNCell->Unsqueeze in the TensorIterator body,
|
||||
* converts this pattern to RNNSequence layer and replaces them TensorIterator.
|
||||
*/
|
||||
|
||||
class ngraph::pass::ConvertTensorIteratorToRNNSequence: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
ConvertTensorIteratorToRNNSequence();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief Finds all TensorIterator layers, detects the pattern Squeeze->GRUCell->Unsqueeze in the TensorIterator body,
|
||||
* converts this pattern to GRUSequence layer and replaces them TensorIterator.
|
||||
*/
|
||||
|
||||
class ngraph::pass::ConvertTensorIteratorToGRUSequence: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
ConvertTensorIteratorToGRUSequence();
|
||||
};
|
@ -23,31 +23,26 @@ ngraph::pass::ConvertLSTMSequenceMatcher::ConvertLSTMSequenceMatcher() {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& W = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
|
||||
lstm_sequence->input_value(4).get_node_shared_ptr());
|
||||
if (!W) {
|
||||
return false;
|
||||
}
|
||||
const auto& W = lstm_sequence->input_value(4);
|
||||
const auto& R = lstm_sequence->input_value(5);
|
||||
|
||||
const auto& R = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
|
||||
lstm_sequence->input_value(5).get_node_shared_ptr());
|
||||
if (!R) {
|
||||
// Bidirectional cases are not supported
|
||||
if (lstm_sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
|
||||
return false;
|
||||
}
|
||||
|
||||
// for forward/reverse cases we can squeeze num_direction dimension
|
||||
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input(1).get_source_output(), axis_1);
|
||||
auto in_2 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input(2).get_source_output(), axis_1);
|
||||
auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::NodeVector({W, R}), 2);
|
||||
auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input_value(1), axis_1);
|
||||
auto in_2 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input_value(2), axis_1);
|
||||
auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::OutputVector{W, R}, 2);
|
||||
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
|
||||
auto in_3 = std::make_shared<ngraph::opset4::Squeeze>(concat->output(0), axis_2);
|
||||
auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input(6).get_source_output(), axis_2);
|
||||
auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input_value(6), axis_2);
|
||||
auto lstm_sequence_ie = std::make_shared<ngraph::op::LSTMSequenceIE>(
|
||||
lstm_sequence->input(0).get_source_output(), // X
|
||||
in_1, // initial_hidden_state
|
||||
in_2, // initial_cell_state
|
||||
lstm_sequence->input(3).get_source_output(),
|
||||
lstm_sequence->input_value(3),
|
||||
in_3, // WR
|
||||
in_4, // B
|
||||
lstm_sequence->get_hidden_size(),
|
||||
@ -84,34 +79,25 @@ ngraph::pass::ConvertGRUSequenceMatcher::ConvertGRUSequenceMatcher() {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto W = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
|
||||
gru_sequence->input_value(3).get_node_shared_ptr());
|
||||
if (!W) {
|
||||
return false;
|
||||
}
|
||||
auto W = gru_sequence->input_value(3);
|
||||
auto R = gru_sequence->input_value(4);
|
||||
|
||||
auto R = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
|
||||
gru_sequence->input_value(4).get_node_shared_ptr());
|
||||
if (!R) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// todo: add exception?
|
||||
// Bidirectional cases are not supported
|
||||
if (gru_sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
|
||||
return false;
|
||||
|
||||
// for forward/reverse cases we can squeeze num_direction dimension
|
||||
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(gru_sequence->input(1).get_source_output(), axis_1);
|
||||
auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::NodeVector({W, R}), 2);
|
||||
auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(gru_sequence->input_value(1), axis_1);
|
||||
auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::OutputVector{W, R}, 2);
|
||||
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
|
||||
auto in_3 = std::make_shared<ngraph::opset4::Squeeze>(concat->output(0), axis_2);
|
||||
auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(gru_sequence->input(5).get_source_output(), axis_2);
|
||||
auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(gru_sequence->input_value(5), axis_2);
|
||||
|
||||
auto gru_sequence_ie = std::make_shared<ngraph::op::GRUSequenceIE>(
|
||||
gru_sequence->input(0).get_source_output(), // X
|
||||
gru_sequence->input_value(0), // X
|
||||
in_1, // initial_hidden_state
|
||||
gru_sequence->input(2).get_source_output(),
|
||||
gru_sequence->input_value(2),
|
||||
in_3, // WR
|
||||
in_4, // B
|
||||
gru_sequence->get_hidden_size(),
|
||||
@ -146,27 +132,22 @@ ngraph::pass::ConvertRNNSequenceMatcher::ConvertRNNSequenceMatcher() {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto W = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
|
||||
rnn_sequence->input_value(3).get_node_shared_ptr());
|
||||
if (!W) {
|
||||
// Bidirectional cases are not supported
|
||||
if (rnn_sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
|
||||
return false;
|
||||
}
|
||||
|
||||
auto R = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
|
||||
rnn_sequence->input_value(4).get_node_shared_ptr());
|
||||
if (!R) {
|
||||
return false;
|
||||
}
|
||||
auto W = rnn_sequence->input_value(3);
|
||||
auto R = rnn_sequence->input_value(4);
|
||||
|
||||
// for forward/reverse cases we can squeeze num_direction dimension
|
||||
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(rnn_sequence->input(1).get_source_output(), axis_1);
|
||||
auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::NodeVector({W, R}), 2);
|
||||
auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(rnn_sequence->input_value(1), axis_1);
|
||||
auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::OutputVector{W, R}, 2);
|
||||
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
|
||||
auto in_3 = std::make_shared<ngraph::opset4::Squeeze>(concat->output(0), axis_2);
|
||||
auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(rnn_sequence->input(5).get_source_output(), axis_2);
|
||||
auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(rnn_sequence->input_value(5), axis_2);
|
||||
auto rnn_sequence_ie = std::make_shared<ngraph::op::RNNSequenceIE>(
|
||||
rnn_sequence->input(0).get_source_output(), // X
|
||||
rnn_sequence->input_value(0), // X
|
||||
in_1, // initial_hidden_state
|
||||
rnn_sequence->input_value(2),
|
||||
in_3, // WR
|
||||
|
@ -0,0 +1,478 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/tensor_iterator_transformations/convert_ti_to_sequences.h"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include <ngraph/node.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/graph_util.hpp>
|
||||
#include <ngraph/specialize_function.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
ngraph::pass::ConvertTensorIteratorToLSTMSequence::ConvertTensorIteratorToLSTMSequence() {
|
||||
auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
|
||||
ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset4::TensorIterator>());
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
|
||||
auto ti = std::dynamic_pointer_cast<ngraph::opset4::TensorIterator>(m.get_match_root());
|
||||
if (!ti || !m_transformation_callback(ti))
|
||||
return false;
|
||||
|
||||
// create pattern
|
||||
auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 1});
|
||||
auto axis_squeeze = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{1}, 1);
|
||||
|
||||
auto input_data = std::make_shared<ngraph::opset4::Squeeze>(data, axis_squeeze);
|
||||
auto input_H_state = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
|
||||
auto input_C_state = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
|
||||
auto input_W = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{4, 1});
|
||||
auto input_R = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{4, 1});
|
||||
auto input_B = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{4});
|
||||
|
||||
auto cell = std::make_shared<ngraph::opset4::LSTMCell>(input_data, input_H_state, input_C_state,
|
||||
input_W, input_R, input_B, 1);
|
||||
|
||||
auto axis_unsqueeze = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{1}, 1);
|
||||
auto unsqueeze = std::make_shared<ngraph::opset4::Unsqueeze>(cell, axis_unsqueeze);
|
||||
ngraph::pattern::Matcher matcher(unsqueeze);
|
||||
|
||||
bool match = false;
|
||||
auto func = ti->get_body();
|
||||
for (const auto& res : func->get_results()) {
|
||||
match = matcher.match((res->get_input_source_output(0)));
|
||||
if (match)
|
||||
break;
|
||||
}
|
||||
|
||||
// All nodes are in the TI body should be matched in pattern
|
||||
if (!match || (matcher.get_matched_nodes().size() + func->get_results().size()) != func->get_ops().size())
|
||||
return false;
|
||||
|
||||
auto pattern_map = matcher.get_pattern_map();
|
||||
|
||||
auto params = func->get_parameters();
|
||||
std::vector<std::shared_ptr<ngraph::opset4::TensorIterator::InputDescription>> ordered_in_descs(3);
|
||||
int64_t stride = 0, slice_axis = 0;
|
||||
size_t batch_size = 0;
|
||||
for (const auto& input_desc : ti->get_input_descriptions()) {
|
||||
auto param = params[input_desc->m_body_parameter_index];
|
||||
if (param == pattern_map[data]) {
|
||||
// to get batch size value
|
||||
if (param->get_partial_shape().is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
auto slice_input
|
||||
= std::dynamic_pointer_cast<ngraph::opset4::TensorIterator::SliceInputDescription>(input_desc);
|
||||
if (!slice_input)
|
||||
return false;
|
||||
|
||||
stride = slice_input->m_stride;
|
||||
slice_axis = slice_input->m_axis;
|
||||
|
||||
if (!(slice_axis == 0 || slice_axis == 1)) {
|
||||
return false;
|
||||
}
|
||||
batch_size = param->get_shape()[slice_axis == 0 ? 1 : 0];
|
||||
ordered_in_descs[0] = input_desc;
|
||||
} else if (param == pattern_map[input_H_state]) {
|
||||
ordered_in_descs[1] = input_desc;
|
||||
} else if (param == pattern_map[input_C_state]) {
|
||||
ordered_in_descs[2] = input_desc;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
auto results = func->get_results();
|
||||
std::vector<std::shared_ptr<ngraph::opset4::TensorIterator::OutputDescription>> ordered_out_descs(3);
|
||||
for (const auto& output_desc : ti->get_output_descriptions()) {
|
||||
std::shared_ptr<opset4::Result> res = results[output_desc->m_body_value_index];
|
||||
if (res->get_input_source_output(0) == pattern_map[unsqueeze]) {
|
||||
auto concat_output
|
||||
= std::dynamic_pointer_cast<ngraph::opset4::TensorIterator::ConcatOutputDescription>(output_desc);
|
||||
if (!concat_output)
|
||||
return false;
|
||||
|
||||
stride = concat_output->m_stride;
|
||||
ordered_out_descs[0] = output_desc;
|
||||
} else if (res->get_input_source_output(0) == pattern_map[cell]->output(0)) {
|
||||
ordered_out_descs[1] = output_desc;
|
||||
} else if (res->get_input_source_output(0) == pattern_map[cell]->output(1)) {
|
||||
ordered_out_descs[2] = output_desc;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
auto seq_lengths = ngraph::opset4::Constant::create(element::i32, Shape{batch_size}, {ti->get_num_iterations()});
|
||||
const auto& lstm_cell = std::dynamic_pointer_cast<ngraph::opset4::LSTMCell>(pattern_map[cell]);
|
||||
auto in_0 = ti->input_values()[ordered_in_descs[0]->m_input_index];
|
||||
if (slice_axis == 0) {
|
||||
auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
|
||||
in_0 = std::make_shared<ngraph::opset4::Transpose>(ti->input_values()[ordered_in_descs[0]->m_input_index], order);
|
||||
}
|
||||
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto in_1 = std::make_shared<ngraph::opset4::Unsqueeze>(ti->input_values()[ordered_in_descs[1]->m_input_index], axis_1);
|
||||
auto in_2 = std::make_shared<ngraph::opset4::Unsqueeze>(ti->input_values()[ordered_in_descs[2]->m_input_index], axis_1);
|
||||
|
||||
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
|
||||
auto in_4 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_W]->output(0).get_node_shared_ptr(), axis_2);
|
||||
auto in_5 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_R]->output(0).get_node_shared_ptr(), axis_2);
|
||||
auto in_6 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_B]->output(0).get_node_shared_ptr(), axis_2);
|
||||
auto sequence = std::make_shared<op::v5::LSTMSequence>(
|
||||
in_0,
|
||||
in_1,
|
||||
in_2,
|
||||
seq_lengths,
|
||||
in_4,
|
||||
in_5,
|
||||
in_6,
|
||||
lstm_cell->get_hidden_size(),
|
||||
stride > 0 ? ngraph::op::RecurrentSequenceDirection::FORWARD: ngraph::op::RecurrentSequenceDirection::REVERSE,
|
||||
lstm_cell->get_activations_alpha(),
|
||||
lstm_cell->get_activations_beta(),
|
||||
lstm_cell->get_activations(),
|
||||
lstm_cell->get_clip());
|
||||
|
||||
auto axis_out = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto out_0 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(0), axis_out);
|
||||
auto out_1 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(1), axis_out);
|
||||
auto out_2 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(2), axis_out);
|
||||
|
||||
std::shared_ptr<Node> out = out_0;
|
||||
if (slice_axis == 0) {
|
||||
auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
|
||||
out = std::make_shared<ngraph::opset4::Transpose>(out_0, order);
|
||||
}
|
||||
|
||||
ngraph::NodeVector outputs = {out, out_1, out_2};
|
||||
for (size_t i = 0; i < ordered_out_descs.size(); ++i) {
|
||||
if (ordered_out_descs[i]) {
|
||||
for (const auto &input : ti->output(ordered_out_descs[i]->m_output_index).get_target_inputs()) {
|
||||
input.replace_source_output(outputs[i]->output(0));
|
||||
}
|
||||
outputs[i]->get_output_tensor(0).set_name(op::util::create_ie_output_name(ti->output(ordered_out_descs[i]->m_output_index)));
|
||||
}
|
||||
}
|
||||
|
||||
ngraph::NodeVector new_nodes = {in_1, in_2, in_4, in_5, in_6, sequence, out_0, out_1, out_2};
|
||||
if (slice_axis == 0) {
|
||||
new_nodes.push_back(out);
|
||||
new_nodes.push_back(in_0.get_node_shared_ptr());
|
||||
}
|
||||
copy_runtime_info(ti, new_nodes);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(tensor_iterator, "ConvertTensorIteratorToLSTMSequence");
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
ngraph::pass::ConvertTensorIteratorToRNNSequence::ConvertTensorIteratorToRNNSequence() {
|
||||
auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
|
||||
ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset4::TensorIterator>());
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
|
||||
auto ti = std::dynamic_pointer_cast<ngraph::opset4::TensorIterator>(m.get_match_root());
|
||||
if (!ti || !m_transformation_callback(ti))
|
||||
return false;
|
||||
|
||||
// create pattern
|
||||
auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 1});
|
||||
auto axis_squeeze = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{1}, 0);
|
||||
auto input_data = std::make_shared<ngraph::opset4::Squeeze>(data, axis_squeeze);
|
||||
|
||||
auto input_H_state = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
|
||||
auto input_W = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{1, 1});
|
||||
auto input_R = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{1, 1});
|
||||
auto input_B = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{1});
|
||||
|
||||
auto cell = std::make_shared<ngraph::opset4::RNNCell>(input_data, input_H_state, input_W, input_R, input_B, 1);
|
||||
|
||||
auto axis_unsqueeze = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{1}, 0);
|
||||
auto unsqueeze = std::make_shared<ngraph::opset4::Unsqueeze>(cell, axis_unsqueeze);
|
||||
ngraph::pattern::Matcher matcher(unsqueeze);
|
||||
|
||||
bool match = false;
|
||||
auto func = ti->get_body();
|
||||
for (const auto& res : func->get_results()) {
|
||||
match = matcher.match((res->get_input_source_output(0)));
|
||||
if (match)
|
||||
break;
|
||||
}
|
||||
|
||||
// All nodes are in the TI body should be matched in pattern
|
||||
if (!match || (matcher.get_matched_nodes().size() + func->get_results().size()) != func->get_ops().size())
|
||||
return false;
|
||||
|
||||
auto pattern_map = matcher.get_pattern_map();
|
||||
|
||||
auto params = func->get_parameters();
|
||||
std::vector<std::shared_ptr<ngraph::opset4::TensorIterator::InputDescription>> ordered_in_descs(3);
|
||||
int64_t stride = 0, slice_axis = 0;
|
||||
size_t batch_size = 0;
|
||||
for (const auto& input_desc : ti->get_input_descriptions()) {
|
||||
auto param = params[input_desc->m_body_parameter_index];
|
||||
if (param == pattern_map[data]) {
|
||||
// to get batch size value
|
||||
if (param->get_partial_shape().is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
auto slice_input
|
||||
= std::dynamic_pointer_cast<ngraph::opset4::TensorIterator::SliceInputDescription>(input_desc);
|
||||
if (!slice_input)
|
||||
return false;
|
||||
|
||||
stride = slice_input->m_stride;
|
||||
slice_axis = slice_input->m_axis;
|
||||
if (!(slice_axis == 0 || slice_axis == 1)) {
|
||||
return false;
|
||||
}
|
||||
batch_size = param->get_shape()[slice_axis == 0 ? 1 : 0];
|
||||
ordered_in_descs[0] = input_desc;
|
||||
} else if (param == pattern_map[input_H_state]) {
|
||||
ordered_in_descs[1] = input_desc;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
auto seq_lengths = ngraph::opset4::Constant::create(element::i32, Shape{batch_size}, {ti->get_num_iterations()});
|
||||
|
||||
auto results = func->get_results();
|
||||
std::vector<std::shared_ptr<ngraph::opset4::TensorIterator::OutputDescription>> ordered_out_descs(2);
|
||||
for (const auto& output_desc : ti->get_output_descriptions()) {
|
||||
std::shared_ptr<opset4::Result> res = results[output_desc->m_body_value_index];
|
||||
if (res->get_input_source_output(0) == pattern_map[unsqueeze]) {
|
||||
auto concat_output
|
||||
= std::dynamic_pointer_cast<ngraph::opset4::TensorIterator::ConcatOutputDescription>(output_desc);
|
||||
if (!concat_output)
|
||||
return false;
|
||||
|
||||
stride = concat_output->m_stride;
|
||||
ordered_out_descs[0] = output_desc;
|
||||
} else if (res->get_input_source_output(0) == pattern_map[cell]->output(0)) {
|
||||
ordered_out_descs[1] = output_desc;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const auto& rnn_cell = std::dynamic_pointer_cast<ngraph::opset4::RNNCell>(pattern_map[cell]);
|
||||
|
||||
auto in_0 = ti->input_values()[ordered_in_descs[0]->m_input_index];
|
||||
if (slice_axis == 0) {
|
||||
auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
|
||||
in_0 = std::make_shared<ngraph::opset4::Transpose>(ti->input_values()[ordered_in_descs[0]->m_input_index], order);
|
||||
}
|
||||
|
||||
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto in_1 = std::make_shared<ngraph::opset4::Unsqueeze>(ti->input_values()[ordered_in_descs[1]->m_input_index], axis_1);
|
||||
|
||||
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
|
||||
auto in_3 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_W]->output(0).get_node_shared_ptr(), axis_2);
|
||||
auto in_4 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_R]->output(0).get_node_shared_ptr(), axis_2);
|
||||
auto in_5 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_B]->output(0).get_node_shared_ptr(), axis_2);
|
||||
auto sequence = std::make_shared<op::v5::RNNSequence>(
|
||||
in_0,
|
||||
in_1,
|
||||
seq_lengths,
|
||||
in_3,
|
||||
in_4,
|
||||
in_5,
|
||||
rnn_cell->get_hidden_size(),
|
||||
stride > 0 ? ngraph::op::RecurrentSequenceDirection::FORWARD: ngraph::op::RecurrentSequenceDirection::REVERSE,
|
||||
rnn_cell->get_activations(),
|
||||
rnn_cell->get_activations_alpha(),
|
||||
rnn_cell->get_activations_beta(),
|
||||
rnn_cell->get_clip());
|
||||
|
||||
auto axis_out = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto out_0 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(0), axis_out);
|
||||
auto out_1 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(1), axis_out);
|
||||
|
||||
std::shared_ptr<Node> out = out_0;
|
||||
if (slice_axis == 0) {
|
||||
auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
|
||||
out = std::make_shared<ngraph::opset4::Transpose>(out_0, order);
|
||||
}
|
||||
|
||||
ngraph::NodeVector outputs = {out, out_1};
|
||||
for (size_t i = 0; i < ordered_out_descs.size(); ++i) {
|
||||
if (ordered_out_descs[i]) {
|
||||
for (const auto &input : ti->output(ordered_out_descs[i]->m_output_index).get_target_inputs()) {
|
||||
input.replace_source_output(outputs[i]->output(0));
|
||||
}
|
||||
outputs[i]->get_output_tensor(0).set_name(op::util::create_ie_output_name(ti->output(ordered_out_descs[i]->m_output_index)));
|
||||
}
|
||||
}
|
||||
|
||||
ngraph::OutputVector new_nodes = {in_1, in_3, in_4, in_5, sequence, out_0, out_1};
|
||||
if (slice_axis == 0) {
|
||||
new_nodes.push_back(out);
|
||||
new_nodes.push_back(in_0);
|
||||
}
|
||||
copy_runtime_info(ti, as_node_vector(new_nodes));
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(tensor_iterator, "ConvertTensorIteratorToRNNSequence");
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
ngraph::pass::ConvertTensorIteratorToGRUSequence::ConvertTensorIteratorToGRUSequence() {
|
||||
auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
|
||||
ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset4::TensorIterator>());
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
|
||||
auto ti = std::dynamic_pointer_cast<ngraph::opset4::TensorIterator>(m.get_match_root());
|
||||
if (!ti || !m_transformation_callback(ti))
|
||||
return false;
|
||||
|
||||
// create pattern
|
||||
auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 1});
|
||||
auto axis_squeeze = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{1}, 0);
|
||||
auto input_data = std::make_shared<ngraph::opset4::Squeeze>(data, axis_squeeze);
|
||||
|
||||
auto input_H_state = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
|
||||
auto input_W = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{3, 1});
|
||||
auto input_R = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{3, 1});
|
||||
auto input_B = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{3});
|
||||
|
||||
auto cell = std::make_shared<ngraph::opset4::GRUCell>(input_data, input_H_state, input_W, input_R, input_B, 1);
|
||||
|
||||
auto axis_unsqueeze = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{1}, 0);
|
||||
auto unsqueeze = std::make_shared<ngraph::opset4::Unsqueeze>(cell, axis_unsqueeze);
|
||||
ngraph::pattern::Matcher matcher(unsqueeze);
|
||||
|
||||
bool match = false;
|
||||
auto func = ti->get_body();
|
||||
for (const auto& res : func->get_results()) {
|
||||
match = matcher.match((res->get_input_source_output(0)));
|
||||
if (match)
|
||||
break;
|
||||
}
|
||||
|
||||
// All nodes are in the TI body should be matched in pattern
|
||||
if (!match || (matcher.get_matched_nodes().size() + func->get_results().size()) != func->get_ops().size())
|
||||
return false;
|
||||
|
||||
auto pattern_map = matcher.get_pattern_map();
|
||||
|
||||
auto params = func->get_parameters();
|
||||
std::vector<std::shared_ptr<ngraph::opset4::TensorIterator::InputDescription>> ordered_in_descs(3);
|
||||
int64_t stride = 0, slice_axis = 0;
|
||||
size_t batch_size = 0;
|
||||
for (const auto& input_desc : ti->get_input_descriptions()) {
|
||||
auto param = params[input_desc->m_body_parameter_index];
|
||||
if (param == pattern_map[data]) {
|
||||
// to get batch size value
|
||||
if (param->get_partial_shape().is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
auto slice_input
|
||||
= std::dynamic_pointer_cast<ngraph::opset4::TensorIterator::SliceInputDescription>(input_desc);
|
||||
if (!slice_input)
|
||||
return false;
|
||||
|
||||
stride = slice_input->m_stride;
|
||||
slice_axis = slice_input->m_axis;
|
||||
if (!(slice_axis == 0 || slice_axis == 1)) {
|
||||
return false;
|
||||
}
|
||||
batch_size = param->get_shape()[slice_axis == 0 ? 1 : 0];
|
||||
ordered_in_descs[0] = input_desc;
|
||||
} else if (param == pattern_map[input_H_state]) {
|
||||
ordered_in_descs[1] = input_desc;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
auto seq_lengths = ngraph::opset4::Constant::create(element::i32, Shape{batch_size}, {ti->get_num_iterations()});
|
||||
|
||||
auto results = func->get_results();
|
||||
std::vector<std::shared_ptr<ngraph::opset4::TensorIterator::OutputDescription>> ordered_out_descs(2);
|
||||
for (const auto& output_desc : ti->get_output_descriptions()) {
|
||||
std::shared_ptr<opset4::Result> res = results[output_desc->m_body_value_index];
|
||||
if (res->get_input_source_output(0) == pattern_map[unsqueeze]) {
|
||||
auto concat_output
|
||||
= std::dynamic_pointer_cast<ngraph::opset4::TensorIterator::ConcatOutputDescription>(output_desc);
|
||||
if (!concat_output)
|
||||
return false;
|
||||
|
||||
stride = concat_output->m_stride;
|
||||
ordered_out_descs[0] = output_desc;
|
||||
} else if (res->get_input_source_output(0) == pattern_map[cell]->output(0)) {
|
||||
ordered_out_descs[1] = output_desc;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const auto& rnn_cell = std::dynamic_pointer_cast<ngraph::opset4::GRUCell>(pattern_map[cell]);
|
||||
|
||||
auto in_0 = ti->input_values()[ordered_in_descs[0]->m_input_index];
|
||||
if (slice_axis == 0) {
|
||||
auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
|
||||
in_0 = std::make_shared<ngraph::opset4::Transpose>(ti->input_values()[ordered_in_descs[0]->m_input_index], order);
|
||||
}
|
||||
|
||||
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto in_1 = std::make_shared<ngraph::opset4::Unsqueeze>(ti->input_values()[ordered_in_descs[1]->m_input_index], axis_1);
|
||||
|
||||
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
|
||||
auto in_3 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_W]->output(0).get_node_shared_ptr(), axis_2);
|
||||
auto in_4 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_R]->output(0).get_node_shared_ptr(), axis_2);
|
||||
auto in_5 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_B]->output(0).get_node_shared_ptr(), axis_2);
|
||||
auto sequence = std::make_shared<op::v5::GRUSequence>(
|
||||
in_0,
|
||||
in_1,
|
||||
seq_lengths,
|
||||
in_3,
|
||||
in_4,
|
||||
in_5,
|
||||
rnn_cell->get_hidden_size(),
|
||||
stride > 0 ? ngraph::op::RecurrentSequenceDirection::FORWARD: ngraph::op::RecurrentSequenceDirection::REVERSE,
|
||||
rnn_cell->get_activations(),
|
||||
rnn_cell->get_activations_alpha(),
|
||||
rnn_cell->get_activations_beta(),
|
||||
rnn_cell->get_clip(),
|
||||
rnn_cell->get_linear_before_reset());
|
||||
|
||||
auto axis_out = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto out_0 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(0), axis_out);
|
||||
auto out_1 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(1), axis_out);
|
||||
|
||||
std::shared_ptr<Node> out = out_0;
|
||||
if (slice_axis == 0) {
|
||||
auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
|
||||
out = std::make_shared<ngraph::opset4::Transpose>(out_0, order);
|
||||
}
|
||||
|
||||
ngraph::NodeVector outputs = {out, out_1};
|
||||
for (size_t i = 0; i < ordered_out_descs.size(); ++i) {
|
||||
if (ordered_out_descs[i]) {
|
||||
for (const auto &input : ti->output(ordered_out_descs[i]->m_output_index).get_target_inputs()) {
|
||||
input.replace_source_output(outputs[i]->output(0));
|
||||
}
|
||||
outputs[i]->get_output_tensor(0).set_name(op::util::create_ie_output_name(ti->output(ordered_out_descs[i]->m_output_index)));
|
||||
}
|
||||
}
|
||||
|
||||
ngraph::OutputVector new_nodes = {in_1, in_3, in_4, in_5, sequence, out_0, out_1};
|
||||
if (slice_axis == 0) {
|
||||
new_nodes.push_back(out);
|
||||
new_nodes.push_back(in_0);
|
||||
}
|
||||
copy_runtime_info(ti, as_node_vector(new_nodes));
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(tensor_iterator, "ConvertTensorIteratorToGRUSequence");
|
||||
register_matcher(m, callback);
|
||||
}
|
@ -0,0 +1,278 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common_test_utils/test_common.hpp"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph_ops/fully_connected.hpp>
|
||||
#include <transformations/tensor_iterator_transformations/convert_ti_to_sequences.h>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph;
|
||||
|
||||
TEST(TransformationTests, ConvertTensorIteratorToLSTMSequence) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        // Function under test: a TensorIterator whose body is
        // Squeeze -> LSTMCell -> Unsqueeze — exactly the pattern that
        // ConvertTensorIteratorToLSTMSequence folds into a single LSTMSequence.
        auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 2, 16});
        auto Y = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
        auto Z = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});

        // Body parameters (per-iteration views of X / Y / Z).
        auto Xi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 1, 16});
        auto Yi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
        auto Zi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});

        // Body
        auto axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
        auto squeeze = std::make_shared<opset4::Squeeze>(Xi, axis);

        // Zero-filled weights; 512 = 4 gates * hidden_size (128).
        auto w_val = std::vector<float>(512 * 16, 0);
        auto r_val = std::vector<float>(512 * 128, 0);
        auto b_val = std::vector<float>(512, 0);
        auto W = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{512, 16}, w_val);
        auto R = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{512, 128}, r_val);
        auto B = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{512}, b_val);

        auto lstm_cell = std::make_shared<opset4::LSTMCell>(squeeze, Yi, Zi, W, R, B, 128);
        auto res_1 = std::make_shared<opset4::Result>(lstm_cell);
        auto axis_unsqueeze = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
        auto unsqueeze = std::make_shared<opset4::Unsqueeze>(lstm_cell, axis_unsqueeze);
        auto res_2 = std::make_shared<opset4::Result>(unsqueeze);
        auto body = std::make_shared<Function>(OutputVector{res_1, res_2},
                                               ParameterVector{Xi, Yi, Zi});

        auto tensor_iterator = std::make_shared<opset4::TensorIterator>();
        tensor_iterator->set_body(body);

        tensor_iterator->set_invariant_input(Zi, Z);
        tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 1);
        tensor_iterator->set_merged_input(Yi, Y, res_1);

        auto out0 = tensor_iterator->get_iter_value(res_1, -1);
        auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 1);

        // Only the concatenated per-iteration output (port 1) participates in
        // the comparison below; the last-iteration value (port 0) is unused.
        auto res_ti_1 = std::make_shared<opset4::Result>(tensor_iterator->output(1));
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1},
                                               ngraph::ParameterVector{X, Y, Z});

        ngraph::pass::Manager m;
        m.register_pass<ngraph::pass::InitNodeInfo>();
        m.register_pass<ngraph::pass::ConvertTensorIteratorToLSTMSequence>();
        // Force the transformation callback to fire for every matched node.
        m.set_callback([](const std::shared_ptr<const Node>&) -> bool { return true; });
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference function: the single LSTMSequence the transformation is
        // expected to produce, with weights/states unsqueezed to add the
        // num_directions dimension.
        auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 2, 16});
        auto Y = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
        auto Z = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});

        auto w_val = std::vector<float>(512 * 16, 0);
        auto r_val = std::vector<float>(512 * 128, 0);
        auto b_val = std::vector<float>(512, 0);
        auto W = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{512, 16}, w_val);
        auto R = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{512, 128}, r_val);
        auto B = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{512}, b_val);

        auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
        auto in_1 = std::make_shared<ngraph::opset4::Unsqueeze>(Y, axis_1);
        auto in_2 = std::make_shared<ngraph::opset4::Unsqueeze>(Z, axis_1);

        auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
        auto in_4 = std::make_shared<ngraph::opset4::Unsqueeze>(W, axis_2);
        auto in_5 = std::make_shared<ngraph::opset4::Unsqueeze>(R, axis_2);
        auto in_6 = std::make_shared<ngraph::opset4::Unsqueeze>(B, axis_2);

        auto seq_lengths = ngraph::opset4::Constant::create(element::i32, Shape{1}, {2});
        auto lstm_seq = std::make_shared<op::v5::LSTMSequence>(X, in_1, in_2, seq_lengths, in_4, in_5, in_6,
                                                               128, ngraph::op::RecurrentSequenceDirection::FORWARD);
        auto axis_out = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
        // Only the Y output (port 0) is part of the reference graph. The
        // previous revision also built squeezes of the Ho/Co outputs, with a
        // copy-paste bug squeezing output(1) twice instead of output(2); both
        // were dead nodes unreachable from f_ref, so they are removed here.
        auto out_0 = std::make_shared<ngraph::opset4::Squeeze>(lstm_seq->output(0), axis_out);
        auto res_ti_1 = std::make_shared<opset4::Result>(out_0);
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y, Z});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
|
||||
|
||||
TEST(TransformationTests, ConvertTensorIteratorToRNNSequence) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);

    // Build a TensorIterator whose body is Squeeze -> RNNCell -> Unsqueeze and
    // run ConvertTensorIteratorToRNNSequence over it.
    {
        auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 2, 16});
        auto Y = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});

        // Per-iteration body parameters.
        auto Xi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 1, 16});
        auto Yi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});

        // Body: drop the sequence dimension, apply the cell, restore it.
        auto squeeze_axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
        auto squeezed_input = std::make_shared<opset4::Squeeze>(Xi, squeeze_axis);

        // Zero-filled weights; 128 = hidden_size (single RNN gate).
        const std::vector<float> w_data(128 * 16, 0);
        const std::vector<float> r_data(128 * 128, 0);
        const std::vector<float> b_data(128, 0);
        auto W = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{128, 16}, w_data);
        auto R = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{128, 128}, r_data);
        auto B = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{128}, b_data);

        auto cell = std::make_shared<opset4::RNNCell>(squeezed_input, Yi, W, R, B, 128);
        auto res_1 = std::make_shared<opset4::Result>(cell);
        auto unsqueeze_axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
        auto unsqueezed_hidden = std::make_shared<opset4::Unsqueeze>(cell, unsqueeze_axis);
        auto res_2 = std::make_shared<opset4::Result>(unsqueezed_hidden);
        auto body = std::make_shared<Function>(OutputVector{res_1, res_2}, ParameterVector{Xi, Yi});

        auto ti = std::make_shared<opset4::TensorIterator>();
        ti->set_body(body);
        ti->set_sliced_input(Xi, X, 0, 1, 1, -1, 1);
        ti->set_merged_input(Yi, Y, res_1);

        auto last_value = ti->get_iter_value(res_1, -1);
        auto concatenated = ti->get_concatenated_slices(res_2, 0, 1, 1, -1, 1);

        // Only the concatenated slices output (port 1) is compared.
        auto result = std::make_shared<opset4::Result>(ti->output(1));
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{result},
                                               ngraph::ParameterVector{X, Y});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::ConvertTensorIteratorToRNNSequence>();
        // Force the transformation to apply to every matched node.
        manager.set_callback([](const std::shared_ptr<const Node>&) -> bool { return true; });
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    // Reference function: the equivalent single RNNSequence.
    {
        auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 2, 16});
        auto Y = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});

        const std::vector<float> w_data(128 * 16, 0);
        const std::vector<float> r_data(128 * 128, 0);
        const std::vector<float> b_data(128, 0);
        auto W = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{128, 16}, w_data);
        auto R = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{128, 128}, r_data);
        auto B = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{128}, b_data);

        // Unsqueeze initial state and weights to add the num_directions dim.
        auto state_axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
        auto initial_h = std::make_shared<ngraph::opset4::Unsqueeze>(Y, state_axis);

        auto weight_axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
        auto seq_w = std::make_shared<ngraph::opset4::Unsqueeze>(W, weight_axis);
        auto seq_r = std::make_shared<ngraph::opset4::Unsqueeze>(R, weight_axis);
        auto seq_b = std::make_shared<ngraph::opset4::Unsqueeze>(B, weight_axis);

        auto seq_lengths = ngraph::opset4::Constant::create(element::i32, Shape{1}, {2});
        auto rnn_sequence = std::make_shared<op::v5::RNNSequence>(X, initial_h, seq_lengths, seq_w, seq_r, seq_b,
                                                                  128, ngraph::op::RecurrentSequenceDirection::FORWARD);
        auto out_axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
        auto seq_output = std::make_shared<ngraph::opset4::Squeeze>(rnn_sequence->output(0), out_axis);
        auto seq_state = std::make_shared<ngraph::opset4::Squeeze>(rnn_sequence->output(1), out_axis);
        auto result = std::make_shared<opset4::Result>(seq_output);
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{result}, ngraph::ParameterVector{X, Y});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
|
||||
|
||||
TEST(TransformationTests, ConvertTensorIteratorToGRUSequence) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        // Function under test: a TensorIterator whose body is
        // Squeeze -> GRUCell -> Unsqueeze — the pattern the transformation
        // folds into a single GRUSequence.
        auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 2, 16});
        auto Y = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});

        // Per-iteration body parameters.
        auto Xi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 1, 16});
        auto Yi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});

        // Body
        auto axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
        auto squeeze = std::make_shared<opset4::Squeeze>(Xi, axis);

        // Zero-filled weights; 384 = 3 gates * hidden_size (128).
        auto w_val = std::vector<float>(384 * 16, 0);
        auto r_val = std::vector<float>(384 * 128, 0);
        auto b_val = std::vector<float>(384, 0);
        auto W = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{384, 16}, w_val);
        auto R = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{384, 128}, r_val);
        auto B = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{384}, b_val);

        auto gru_cell = std::make_shared<opset4::GRUCell>(squeeze, Yi, W, R, B, 128);
        auto res_1 = std::make_shared<opset4::Result>(gru_cell);
        auto axis_unsqueeze = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
        auto unsqueeze = std::make_shared<opset4::Unsqueeze>(gru_cell, axis_unsqueeze);
        auto res_2 = std::make_shared<opset4::Result>(unsqueeze);
        auto body = std::make_shared<Function>(OutputVector{res_1, res_2},
                                               ParameterVector{Xi, Yi});

        auto tensor_iterator = std::make_shared<opset4::TensorIterator>();
        tensor_iterator->set_body(body);

        tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 1);
        tensor_iterator->set_merged_input(Yi, Y, res_1);

        auto out0 = tensor_iterator->get_iter_value(res_1, -1);
        auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 1);

        // Only the concatenated per-iteration output (port 1) participates in
        // the comparison; the last-iteration value (port 0) is unused.
        auto res_ti_1 = std::make_shared<opset4::Result>(tensor_iterator->output(1));
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1},
                                               ngraph::ParameterVector{X, Y});

        ngraph::pass::Manager m;
        m.register_pass<ngraph::pass::InitNodeInfo>();
        m.register_pass<ngraph::pass::ConvertTensorIteratorToGRUSequence>();
        // Force the transformation callback to fire for every matched node.
        m.set_callback([](const std::shared_ptr<const Node>&) -> bool { return true; });
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference function: the single GRUSequence the transformation is
        // expected to produce, with the num_directions dimension added to the
        // initial state and the weights via Unsqueeze.
        auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 2, 16});
        auto Y = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});

        auto w_val = std::vector<float>(384 * 16, 0);
        auto r_val = std::vector<float>(384 * 128, 0);
        auto b_val = std::vector<float>(384, 0);
        auto W = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{384, 16}, w_val);
        auto R = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{384, 128}, r_val);
        auto B = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{384}, b_val);

        auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
        auto in_1 = std::make_shared<ngraph::opset4::Unsqueeze>(Y, axis_1);

        auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
        auto in_3 = std::make_shared<ngraph::opset4::Unsqueeze>(W, axis_2);
        auto in_4 = std::make_shared<ngraph::opset4::Unsqueeze>(R, axis_2);
        auto in_5 = std::make_shared<ngraph::opset4::Unsqueeze>(B, axis_2);

        auto seq_lengths = ngraph::opset4::Constant::create(element::i32, Shape{1}, {2});
        auto gru_sequence = std::make_shared<op::v5::GRUSequence>(X, in_1, seq_lengths, in_3, in_4, in_5,
                                                                  128, ngraph::op::RecurrentSequenceDirection::FORWARD);
        auto axis_out = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
        auto out_0 = std::make_shared<ngraph::opset4::Squeeze>(gru_sequence->output(0), axis_out);
        auto out_1 = std::make_shared<ngraph::opset4::Squeeze>(gru_sequence->output(1), axis_out);
        // Only out_0 (the Y output) is reachable from the reference function.
        auto res_ti_1 = std::make_shared<opset4::Result>(out_0);
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
|
Loading…
Reference in New Issue
Block a user