[Transformations] Convert TI to sequences: dynamic case fix (#9742)
* Convert TI to sequences: dynamic case fix * tests corrected
This commit is contained in:
committed by
GitHub
parent
cf4d4db4c7
commit
43130622d3
@@ -102,12 +102,9 @@ bool convertTensorIteratorToSequence(
|
||||
std::make_shared<ngraph::opset5::Unsqueeze>(ti_inputs[ordered_in_descs[2]->m_input_index], axis_1);
|
||||
|
||||
const size_t batch_dim = slice_axis == 0 ? 1 : 0;
|
||||
//TODO: replace to ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_node(target_param, { batch_dimension });
|
||||
auto shape_node = ngraph::op::util::make_try_fold<ngraph::opset5::ShapeOf>(ti_inputs[ordered_in_descs[0]->m_input_index]);
|
||||
auto batch_dimension = ngraph::op::util::make_try_fold<ngraph::opset5::Gather>(
|
||||
shape_node,
|
||||
ngraph::opset5::Constant::create(ngraph::element::i64, { 1 }, { batch_dim }),
|
||||
ngraph::opset5::Constant::create(ngraph::element::i64, {}, { 0 }));
|
||||
auto batch_dimension = ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_source(
|
||||
ti_inputs[ordered_in_descs[0]->m_input_index],
|
||||
{batch_dim});
|
||||
|
||||
auto seq_lengths_scalar = ngraph::opset5::Constant::create(ngraph::element::i32, {}, { ti->get_num_iterations() });
|
||||
auto seq_lengths = ngraph::op::util::make_try_fold<ngraph::opset5::Broadcast>(seq_lengths_scalar, batch_dimension);
|
||||
@@ -211,8 +208,8 @@ bool convertTensorIteratorToSequence(
|
||||
new_nodes.emplace_back(initial_cell_state);
|
||||
}
|
||||
if (!std::dynamic_pointer_cast<ngraph::opset5::Constant>(seq_lengths)) {
|
||||
new_nodes.emplace_back(shape_node);
|
||||
new_nodes.emplace_back(batch_dimension);
|
||||
new_nodes.emplace_back(batch_dimension->get_input_node_shared_ptr(0));
|
||||
new_nodes.emplace_back(seq_lengths_scalar);
|
||||
new_nodes.emplace_back(seq_lengths);
|
||||
}
|
||||
@@ -238,7 +235,7 @@ ngraph::pass::ConvertTensorIteratorToLSTMSequence::ConvertTensorIteratorToLSTMSe
|
||||
// create a pattern for the TensorIterator body
|
||||
auto data = ngraph::pattern::wrap_type<ngraph::opset5::Parameter>(ngraph::pattern::rank_equals(3));
|
||||
auto pattern_1 = ngraph::pattern::wrap_type<ngraph::opset5::Constant>(ngraph::pattern::rank_equals(1));
|
||||
auto squeeze = ngraph::pattern::wrap_type<ngraph::opset5::Reshape>({ data, pattern_1 });
|
||||
auto squeeze = ngraph::pattern::wrap_type<ngraph::opset5::Reshape, ngraph::opset5::Squeeze>({data, pattern_1});
|
||||
|
||||
auto input_H_state = ngraph::pattern::wrap_type<ngraph::opset5::Parameter>(ngraph::pattern::rank_equals(2));
|
||||
auto input_C_state = ngraph::pattern::wrap_type<ngraph::opset5::Parameter>(ngraph::pattern::rank_equals(2));
|
||||
@@ -246,7 +243,7 @@ ngraph::pass::ConvertTensorIteratorToLSTMSequence::ConvertTensorIteratorToLSTMSe
|
||||
auto input_R = ngraph::pattern::wrap_type<ngraph::opset5::Constant>(ngraph::pattern::rank_equals(2));
|
||||
auto input_B = ngraph::pattern::wrap_type<ngraph::opset5::Constant>(ngraph::pattern::rank_equals(1));
|
||||
|
||||
ngraph::OutputVector cell_inputs{ squeeze, input_H_state, input_C_state, input_W, input_R, input_B };
|
||||
ngraph::OutputVector cell_inputs{squeeze, input_H_state, input_C_state, input_W, input_R, input_B};
|
||||
auto cell = ngraph::pattern::wrap_type<ngraph::opset1::LSTMCell, ngraph::opset5::LSTMCell>(cell_inputs);
|
||||
|
||||
auto pattern_2 = ngraph::pattern::wrap_type<ngraph::opset5::Constant>(ngraph::pattern::rank_equals(1));
|
||||
@@ -297,14 +294,14 @@ ngraph::pass::ConvertTensorIteratorToRNNSequence::ConvertTensorIteratorToRNNSequ
|
||||
// create a pattern for the TensorIterator body
|
||||
auto data = ngraph::pattern::wrap_type<ngraph::opset5::Parameter>(ngraph::pattern::rank_equals(3));
|
||||
auto pattern_1 = ngraph::pattern::wrap_type<ngraph::opset5::Constant>(ngraph::pattern::rank_equals(1));
|
||||
auto squeeze = ngraph::pattern::wrap_type<ngraph::opset5::Reshape>({ data, pattern_1 });
|
||||
auto squeeze = ngraph::pattern::wrap_type<ngraph::opset5::Reshape, ngraph::opset5::Squeeze>({data, pattern_1});
|
||||
|
||||
auto input_H_state = ngraph::pattern::wrap_type<ngraph::opset5::Parameter>(ngraph::pattern::rank_equals(2));
|
||||
auto input_W = ngraph::pattern::wrap_type<ngraph::opset5::Constant>(ngraph::pattern::rank_equals(2));
|
||||
auto input_R = ngraph::pattern::wrap_type<ngraph::opset5::Constant>(ngraph::pattern::rank_equals(2));
|
||||
auto input_B = ngraph::pattern::wrap_type<ngraph::opset5::Constant>(ngraph::pattern::rank_equals(1));
|
||||
|
||||
ngraph::OutputVector cell_inputs{ squeeze, input_H_state, input_W, input_R, input_B };
|
||||
ngraph::OutputVector cell_inputs{squeeze, input_H_state, input_W, input_R, input_B};
|
||||
auto cell = ngraph::pattern::wrap_type<ngraph::opset5::RNNCell>(cell_inputs);
|
||||
|
||||
auto pattern_2 = ngraph::pattern::wrap_type<ngraph::opset5::Constant>(ngraph::pattern::rank_equals(1));
|
||||
@@ -354,14 +351,14 @@ ngraph::pass::ConvertTensorIteratorToGRUSequence::ConvertTensorIteratorToGRUSequ
|
||||
// create a pattern for the TensorIterator body
|
||||
auto data = ngraph::pattern::wrap_type<ngraph::opset5::Parameter>(ngraph::pattern::rank_equals(3));
|
||||
auto pattern_1 = ngraph::pattern::wrap_type<ngraph::opset5::Constant>(ngraph::pattern::rank_equals(1));
|
||||
auto squeeze = ngraph::pattern::wrap_type<ngraph::opset5::Reshape>({ data, pattern_1 });
|
||||
auto squeeze = ngraph::pattern::wrap_type<ngraph::opset5::Reshape, ngraph::opset5::Squeeze>({data, pattern_1});
|
||||
|
||||
auto input_H_state = ngraph::pattern::wrap_type<ngraph::opset5::Parameter>(ngraph::pattern::rank_equals(2));
|
||||
auto input_W = ngraph::pattern::wrap_type<ngraph::opset5::Constant>(ngraph::pattern::rank_equals(2));
|
||||
auto input_R = ngraph::pattern::wrap_type<ngraph::opset5::Constant>(ngraph::pattern::rank_equals(2));
|
||||
auto input_B = ngraph::pattern::wrap_type<ngraph::opset5::Constant>(ngraph::pattern::rank_equals(1));
|
||||
|
||||
ngraph::OutputVector cell_inputs{ squeeze, input_H_state, input_W, input_R, input_B };
|
||||
ngraph::OutputVector cell_inputs{squeeze, input_H_state, input_W, input_R, input_B};
|
||||
auto cell = ngraph::pattern::wrap_type<ngraph::opset5::GRUCell>(cell_inputs);
|
||||
|
||||
auto pattern_2 = ngraph::pattern::wrap_type<ngraph::opset5::Constant>(ngraph::pattern::rank_equals(1));
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <ngraph/variant.hpp>
|
||||
@@ -56,7 +57,7 @@ namespace {
|
||||
ngraph::Output<ngraph::Node> in_0 = sequenceOp->get_input_node_shared_ptr(0)->input_value(0);
|
||||
|
||||
auto shapeBeforeTranspose = ngraph::op::util::make_try_fold<ngraph::opset1::ShapeOf>(in_0);
|
||||
auto newInShape = ngraph::op::util::make_try_fold<ngraph::opset1::Gather>(shapeBeforeTranspose,
|
||||
auto newInShape = ngraph::op::util::make_try_fold<ngraph::opset7::Gather>(shapeBeforeTranspose,
|
||||
ngraph::opset1::Constant::create(ngraph::element::i32, { 3 }, { 1, 0, 2 }),
|
||||
ngraph::opset1::Constant::create(ngraph::element::i32, {}, { 0 }));
|
||||
auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(in_0, newInShape, false);
|
||||
@@ -69,7 +70,7 @@ namespace {
|
||||
auto transposeAfter = seqTargetInputs.begin()->get_node()->shared_from_this();
|
||||
|
||||
auto lstmOutShape = ngraph::op::util::make_try_fold<ngraph::opset1::ShapeOf>(sequenceOp->output(0));
|
||||
auto newOutShape = ngraph::op::util::make_try_fold<ngraph::opset1::Gather>(lstmOutShape,
|
||||
auto newOutShape = ngraph::op::util::make_try_fold<ngraph::opset7::Gather>(lstmOutShape,
|
||||
ngraph::opset1::Constant::create(ngraph::element::i32, { 4 }, { 2, 1, 0, 3 }),
|
||||
ngraph::opset1::Constant::create(ngraph::element::i32, {}, { 0 }));
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <transformations/op_conversions/convert_ti_to_sequences.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
@@ -116,7 +117,7 @@ TEST(TransformationTests, ConvertTensorIteratorToLSTMSequence) {
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertTensorIteratorToLSTMSequenceDynamic) {
|
||||
TEST(TransformationTests, ConvertTensorIteratorToLSTMSequenceDynamicReshapeCase) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto X = std::make_shared<opset5::Parameter>(element::f32, PartialShape{ -1, 2, -1 });
|
||||
@@ -176,7 +177,7 @@ TEST(TransformationTests, ConvertTensorIteratorToLSTMSequenceDynamic) {
|
||||
auto in_2 = std::make_shared<ngraph::opset5::Unsqueeze>(Z, axis_1);
|
||||
|
||||
auto shape_of = std::make_shared<opset5::ShapeOf>(X);
|
||||
auto batch_dimension = ngraph::op::util::make_try_fold<ngraph::opset5::Gather>(
|
||||
auto batch_dimension = ngraph::op::util::make_try_fold<ngraph::opset7::Gather>(
|
||||
shape_of,
|
||||
ngraph::opset5::Constant::create(ngraph::element::i64, { 1 }, { 0 }),
|
||||
ngraph::opset5::Constant::create(ngraph::element::i64, {}, { 0 }));
|
||||
@@ -203,6 +204,103 @@ TEST(TransformationTests, ConvertTensorIteratorToLSTMSequenceDynamic) {
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertTensorIteratorToLSTMSequenceDynamicSqueezeCase) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto X = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, 2, -1});
|
||||
auto Y = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, -1});
|
||||
auto Z = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, -1});
|
||||
|
||||
auto Xi = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, 1, -1});
|
||||
auto Yi = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, -1});
|
||||
auto Zi = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, -1});
|
||||
|
||||
// Body
|
||||
auto axis = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto squeeze = std::make_shared<opset5::Squeeze>(Xi, axis);
|
||||
|
||||
auto w_val = std::vector<float>(512 * 16, 0);
|
||||
auto r_val = std::vector<float>(512 * 128, 0);
|
||||
auto b_val = std::vector<float>(512, 0);
|
||||
auto W = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{512, 16}, w_val);
|
||||
auto R = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{512, 128}, r_val);
|
||||
auto B = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{512}, b_val);
|
||||
|
||||
auto lstm_cell = std::make_shared<opset5::LSTMCell>(squeeze, Yi, Zi, W, R, B, 128);
|
||||
|
||||
auto res_1 = std::make_shared<opset5::Result>(lstm_cell);
|
||||
auto reshape_pattern_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 1, 128});
|
||||
auto unsqueeze = std::make_shared<opset5::Reshape>(lstm_cell, reshape_pattern_2, false);
|
||||
auto res_2 = std::make_shared<opset5::Result>(unsqueeze);
|
||||
auto body = std::make_shared<Function>(OutputVector{res_1, res_2}, ParameterVector{Xi, Yi, Zi});
|
||||
|
||||
auto tensor_iterator = std::make_shared<opset5::TensorIterator>();
|
||||
tensor_iterator->set_body(body);
|
||||
|
||||
tensor_iterator->set_invariant_input(Zi, Z);
|
||||
tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 1);
|
||||
tensor_iterator->set_merged_input(Yi, Y, res_1);
|
||||
|
||||
auto out0 = tensor_iterator->get_iter_value(res_1, -1);
|
||||
auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 1);
|
||||
|
||||
auto res_ti_1 = std::make_shared<opset5::Result>(tensor_iterator->output(1));
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y, Z});
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ngraph::pass::ConvertTensorIteratorToLSTMSequence>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto X = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, 2, -1});
|
||||
auto Y = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, -1});
|
||||
auto Z = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, -1});
|
||||
|
||||
auto axis_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto in_1 = std::make_shared<ngraph::opset5::Unsqueeze>(Y, axis_1);
|
||||
auto in_2 = std::make_shared<ngraph::opset5::Unsqueeze>(Z, axis_1);
|
||||
|
||||
auto shape_of = std::make_shared<opset5::ShapeOf>(X);
|
||||
auto batch_dimension = ngraph::op::util::make_try_fold<ngraph::opset7::Gather>(
|
||||
shape_of,
|
||||
ngraph::opset5::Constant::create(ngraph::element::i64, {1}, {0}),
|
||||
ngraph::opset5::Constant::create(ngraph::element::i64, {}, {0}));
|
||||
auto seq_lengths =
|
||||
std::make_shared<opset5::Broadcast>(ngraph::opset5::Constant::create(element::i32, Shape{}, {2}),
|
||||
batch_dimension);
|
||||
|
||||
auto w_val = std::vector<float>(512 * 16, 0);
|
||||
auto r_val = std::vector<float>(512 * 128, 0);
|
||||
auto b_val = std::vector<float>(512, 0);
|
||||
auto W = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{1, 512, 16}, w_val);
|
||||
auto R = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{1, 512, 128}, r_val);
|
||||
auto B = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{1, 512}, b_val);
|
||||
|
||||
auto lstm_seq = std::make_shared<opset5::LSTMSequence>(X,
|
||||
in_1,
|
||||
in_2,
|
||||
seq_lengths,
|
||||
W,
|
||||
R,
|
||||
B,
|
||||
128,
|
||||
op::RecurrentSequenceDirection::FORWARD);
|
||||
|
||||
auto axis_out = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto out_0 = std::make_shared<ngraph::opset5::Squeeze>(lstm_seq->output(0), axis_out);
|
||||
auto out_1 = std::make_shared<ngraph::opset5::Squeeze>(lstm_seq->output(1), axis_out);
|
||||
auto out_2 = std::make_shared<ngraph::opset5::Squeeze>(lstm_seq->output(2), axis_out);
|
||||
auto res_ti_1 = std::make_shared<opset5::Result>(out_0);
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y, Z});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertTensorIteratorToRNNSequence) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
@@ -276,7 +374,7 @@ TEST(TransformationTests, ConvertTensorIteratorToRNNSequence) {
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertTensorIteratorToRNNSequenceDynamic) {
|
||||
TEST(TransformationTests, ConvertTensorIteratorToRNNSequenceDynamicReshapeCase) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto X = std::make_shared<opset5::Parameter>(element::f32, PartialShape{ -1, 2, -1 });
|
||||
@@ -337,7 +435,7 @@ TEST(TransformationTests, ConvertTensorIteratorToRNNSequenceDynamic) {
|
||||
auto in_1 = std::make_shared<ngraph::opset5::Unsqueeze>(Y, axis_1);
|
||||
|
||||
auto shape_of = std::make_shared<opset5::ShapeOf>(X);
|
||||
auto batch_dimension = ngraph::op::util::make_try_fold<ngraph::opset5::Gather>(
|
||||
auto batch_dimension = ngraph::op::util::make_try_fold<ngraph::opset7::Gather>(
|
||||
shape_of,
|
||||
ngraph::opset5::Constant::create(ngraph::element::i64, { 1 }, { 0 }),
|
||||
ngraph::opset5::Constant::create(ngraph::element::i64, {}, { 0 }));
|
||||
@@ -355,6 +453,94 @@ TEST(TransformationTests, ConvertTensorIteratorToRNNSequenceDynamic) {
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertTensorIteratorToRNNSequenceDynamicSqueezeCase) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto X = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, 2, -1});
|
||||
auto Y = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, -1});
|
||||
|
||||
auto Xi = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, 1, -1});
|
||||
auto Yi = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, -1});
|
||||
|
||||
// Body
|
||||
auto axis = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto squeeze = std::make_shared<opset5::Squeeze>(Xi, axis);
|
||||
|
||||
auto w_val = std::vector<float>(128 * 16, 0);
|
||||
auto r_val = std::vector<float>(128 * 128, 0);
|
||||
auto b_val = std::vector<float>(128, 0);
|
||||
auto W = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{128, 16}, w_val);
|
||||
auto R = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{128, 128}, r_val);
|
||||
auto B = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{128}, b_val);
|
||||
|
||||
auto rnn_cell = std::make_shared<opset5::RNNCell>(squeeze, Yi, W, R, B, 128);
|
||||
auto res_1 = std::make_shared<opset5::Result>(rnn_cell);
|
||||
auto reshape_pattern_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 1, -1});
|
||||
auto unsqueeze = std::make_shared<opset5::Reshape>(rnn_cell, reshape_pattern_2, false);
|
||||
auto res_2 = std::make_shared<opset5::Result>(unsqueeze);
|
||||
auto body = std::make_shared<Function>(OutputVector{res_1, res_2}, ParameterVector{Xi, Yi});
|
||||
|
||||
auto tensor_iterator = std::make_shared<opset5::TensorIterator>();
|
||||
tensor_iterator->set_body(body);
|
||||
|
||||
tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 1);
|
||||
tensor_iterator->set_merged_input(Yi, Y, res_1);
|
||||
|
||||
auto out0 = tensor_iterator->get_iter_value(res_1, -1);
|
||||
auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 1);
|
||||
|
||||
auto res_ti_1 = std::make_shared<opset5::Result>(tensor_iterator->output(1));
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y});
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ngraph::pass::ConvertTensorIteratorToRNNSequence>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto X = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, 2, -1});
|
||||
auto Y = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, -1});
|
||||
|
||||
auto w_val = std::vector<float>(128 * 16, 0);
|
||||
auto r_val = std::vector<float>(128 * 128, 0);
|
||||
auto b_val = std::vector<float>(128, 0);
|
||||
auto W = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{1, 128, 16}, w_val);
|
||||
auto R = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{1, 128, 128}, r_val);
|
||||
auto B = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{1, 128}, b_val);
|
||||
|
||||
auto axis_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto in_1 = std::make_shared<ngraph::opset5::Unsqueeze>(Y, axis_1);
|
||||
|
||||
auto shape_of = std::make_shared<opset5::ShapeOf>(X);
|
||||
auto batch_dimension = ngraph::op::util::make_try_fold<ngraph::opset7::Gather>(
|
||||
shape_of,
|
||||
ngraph::opset5::Constant::create(ngraph::element::i64, {1}, {0}),
|
||||
ngraph::opset5::Constant::create(ngraph::element::i64, {}, {0}));
|
||||
auto seq_lengths =
|
||||
std::make_shared<opset5::Broadcast>(ngraph::opset5::Constant::create(element::i32, Shape{}, {2}),
|
||||
batch_dimension);
|
||||
|
||||
auto rnn_sequence = std::make_shared<opset5::RNNSequence>(X,
|
||||
in_1,
|
||||
seq_lengths,
|
||||
W,
|
||||
R,
|
||||
B,
|
||||
128,
|
||||
op::RecurrentSequenceDirection::FORWARD);
|
||||
auto axis_out = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto out_0 = std::make_shared<ngraph::opset5::Squeeze>(rnn_sequence->output(0), axis_out);
|
||||
auto out_1 = std::make_shared<ngraph::opset5::Squeeze>(rnn_sequence->output(1), axis_out);
|
||||
auto res_ti_1 = std::make_shared<opset5::Result>(out_0);
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertTensorIteratorToGRUSequence) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
@@ -428,7 +614,7 @@ TEST(TransformationTests, ConvertTensorIteratorToGRUSequence) {
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertTensorIteratorToGRUSequenceDynamic) {
|
||||
TEST(TransformationTests, ConvertTensorIteratorToGRUSequenceDynamicReshapeCase) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto X = std::make_shared<opset5::Parameter>(element::f32, PartialShape{ -1, 2, -1 });
|
||||
@@ -489,7 +675,7 @@ TEST(TransformationTests, ConvertTensorIteratorToGRUSequenceDynamic) {
|
||||
auto B = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 384 }, b_val);
|
||||
|
||||
auto shape_of = std::make_shared<opset5::ShapeOf>(X);
|
||||
auto batch_dimension = ngraph::op::util::make_try_fold<ngraph::opset5::Gather>(
|
||||
auto batch_dimension = ngraph::op::util::make_try_fold<ngraph::opset7::Gather>(
|
||||
shape_of,
|
||||
ngraph::opset5::Constant::create(ngraph::element::i64, { 1 }, { 0 }),
|
||||
ngraph::opset5::Constant::create(ngraph::element::i64, {}, { 0 }));
|
||||
@@ -506,3 +692,91 @@ TEST(TransformationTests, ConvertTensorIteratorToGRUSequenceDynamic) {
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertTensorIteratorToGRUSequenceDynamicSqueezeCase) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto X = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, 2, -1});
|
||||
auto Y = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, -1});
|
||||
|
||||
auto Xi = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, 1, -1});
|
||||
auto Yi = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, -1});
|
||||
|
||||
// Body
|
||||
auto axis = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto squeeze = std::make_shared<opset5::Squeeze>(Xi, axis);
|
||||
|
||||
auto w_val = std::vector<float>(384 * 16, 0);
|
||||
auto r_val = std::vector<float>(384 * 128, 0);
|
||||
auto b_val = std::vector<float>(384, 0);
|
||||
auto W = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{384, 16}, w_val);
|
||||
auto R = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{384, 128}, r_val);
|
||||
auto B = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{384}, b_val);
|
||||
|
||||
auto gru_cell = std::make_shared<opset5::GRUCell>(squeeze, Yi, W, R, B, 128);
|
||||
auto res_1 = std::make_shared<opset5::Result>(gru_cell);
|
||||
auto reshape_pattern_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 1, 128});
|
||||
auto unsqueeze = std::make_shared<opset5::Reshape>(gru_cell, reshape_pattern_2, false);
|
||||
auto res_2 = std::make_shared<opset5::Result>(unsqueeze);
|
||||
auto body = std::make_shared<Function>(OutputVector{res_1, res_2}, ParameterVector{Xi, Yi});
|
||||
|
||||
auto tensor_iterator = std::make_shared<opset5::TensorIterator>();
|
||||
tensor_iterator->set_body(body);
|
||||
|
||||
tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 1);
|
||||
tensor_iterator->set_merged_input(Yi, Y, res_1);
|
||||
|
||||
auto out0 = tensor_iterator->get_iter_value(res_1, -1);
|
||||
auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 1);
|
||||
|
||||
auto res_ti_1 = std::make_shared<opset5::Result>(tensor_iterator->output(1));
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y});
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ngraph::pass::ConvertTensorIteratorToGRUSequence>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto X = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, 2, -1});
|
||||
auto Y = std::make_shared<opset5::Parameter>(element::f32, PartialShape{-1, -1});
|
||||
|
||||
auto axis_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto in_1 = std::make_shared<ngraph::opset5::Unsqueeze>(Y, axis_1);
|
||||
|
||||
auto w_val = std::vector<float>(384 * 16, 0);
|
||||
auto r_val = std::vector<float>(384 * 128, 0);
|
||||
auto b_val = std::vector<float>(384, 0);
|
||||
auto W = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{1, 384, 16}, w_val);
|
||||
auto R = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{1, 384, 128}, r_val);
|
||||
auto B = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{1, 384}, b_val);
|
||||
|
||||
auto shape_of = std::make_shared<opset5::ShapeOf>(X);
|
||||
auto batch_dimension = ngraph::op::util::make_try_fold<ngraph::opset7::Gather>(
|
||||
shape_of,
|
||||
ngraph::opset5::Constant::create(ngraph::element::i64, {1}, {0}),
|
||||
ngraph::opset5::Constant::create(ngraph::element::i64, {}, {0}));
|
||||
auto seq_lengths =
|
||||
std::make_shared<opset5::Broadcast>(ngraph::opset5::Constant::create(element::i32, Shape{}, {2}),
|
||||
batch_dimension);
|
||||
|
||||
auto gru_sequence = std::make_shared<opset5::GRUSequence>(X,
|
||||
in_1,
|
||||
seq_lengths,
|
||||
W,
|
||||
R,
|
||||
B,
|
||||
128,
|
||||
op::RecurrentSequenceDirection::FORWARD);
|
||||
auto axis_out = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto out_0 = std::make_shared<ngraph::opset5::Squeeze>(gru_sequence->output(0), axis_out);
|
||||
auto out_1 = std::make_shared<ngraph::opset5::Squeeze>(gru_sequence->output(1), axis_out);
|
||||
auto res_ti_1 = std::make_shared<opset5::Result>(out_0);
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph_transformations/rnn_sequences_optimization.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
@@ -147,7 +148,7 @@ TEST(TransformationTests, OptimizeLSTMSequenceTransposesDynamicTest) {
|
||||
auto B = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 512 }, b_val);
|
||||
|
||||
auto data = std::make_shared<ngraph::opset1::ShapeOf>(X);
|
||||
auto reshape_before_pattern = std::make_shared<ngraph::opset1::Gather>(data,
|
||||
auto reshape_before_pattern = std::make_shared<ngraph::opset7::Gather>(data,
|
||||
ngraph::opset1::Constant::create(ngraph::element::i32, { 3 }, { 1, 0, 2 }),
|
||||
ngraph::opset1::Constant::create(ngraph::element::i32, {}, { 0 }));
|
||||
auto reshape_before = std::make_shared<ngraph::opset1::Reshape>(X, reshape_before_pattern, false);
|
||||
@@ -290,7 +291,7 @@ TEST(TransformationTests, OptimizeRNNSequenceTransposesDynamicTest) {
|
||||
auto B = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 128 }, b_val);
|
||||
|
||||
auto data = std::make_shared<ngraph::opset1::ShapeOf>(X);
|
||||
auto reshape_before_pattern = std::make_shared<ngraph::opset1::Gather>(data,
|
||||
auto reshape_before_pattern = std::make_shared<ngraph::opset7::Gather>(data,
|
||||
ngraph::opset1::Constant::create(ngraph::element::i32, { 3 }, { 1, 0, 2 }),
|
||||
ngraph::opset1::Constant::create(ngraph::element::i32, {}, { 0 }));
|
||||
auto reshape_before = std::make_shared<ngraph::opset1::Reshape>(X, reshape_before_pattern, false);
|
||||
@@ -431,7 +432,7 @@ TEST(TransformationTests, OptimizeGRUSequenceTransposesDynamicTest) {
|
||||
auto B = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 384 }, b_val);
|
||||
|
||||
auto data = std::make_shared<ngraph::opset1::ShapeOf>(X);
|
||||
auto reshape_before_pattern = std::make_shared<ngraph::opset1::Gather>(data,
|
||||
auto reshape_before_pattern = std::make_shared<ngraph::opset7::Gather>(data,
|
||||
ngraph::opset1::Constant::create(ngraph::element::i32, { 3 }, { 1, 0, 2 }),
|
||||
ngraph::opset1::Constant::create(ngraph::element::i32, {}, { 0 }));
|
||||
auto reshape_before = std::make_shared<ngraph::opset1::Reshape>(X, reshape_before_pattern, false);
|
||||
|
||||
Reference in New Issue
Block a user