ConvertTensorIteratorToLSTMSequence fix (#9541)

This commit is contained in:
Vladislav Golubev 2022-01-11 17:23:32 +03:00 committed by GitHub
parent c6079ccc11
commit a49f1b3bc6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 11 deletions

View File

@ -78,7 +78,7 @@ bool convertTensorIteratorToSequence(
ordered_out_descs[0] = output_desc;
} else if (res->input_value(0) == found_cell->output(0)) {
ordered_out_descs[1] = output_desc;
} else if (found_cell->get_output_size() == 3 && res->input_value(0) == found_cell->output(1)) {
} else if (found_cell->get_output_size() == 2 && res->input_value(0) == found_cell->output(1)) {
ordered_out_descs[2] = output_desc;
} else {
return false;

View File

@ -43,24 +43,33 @@ TEST(TransformationTests, ConvertTensorIteratorToLSTMSequence) {
auto B = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{512}, b_val);
auto lstm_cell = std::make_shared<opset5::LSTMCell>(squeeze, Yi, Zi, W, R, B, 128);
auto res_1 = std::make_shared<opset5::Result>(lstm_cell);
auto lstm_res_1 = std::make_shared<opset5::Result>(lstm_cell->output(0));
auto lstm_res_2 = std::make_shared<opset5::Result>(lstm_cell->output(1));
auto reshape_pattern_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 1, 128});
auto unsqueeze = std::make_shared<opset5::Reshape>(lstm_cell, reshape_pattern_2, false);
auto res_2 = std::make_shared<opset5::Result>(unsqueeze);
auto body = std::make_shared<Function>(OutputVector{res_1, res_2}, ParameterVector{Xi, Yi, Zi});
auto unsqueeze = std::make_shared<opset5::Reshape>(lstm_cell->output(0), reshape_pattern_2, false);
auto lstm_res1_unsqueeze = std::make_shared<opset5::Result>(unsqueeze);
auto body = std::make_shared<Function>(OutputVector{lstm_res_1, lstm_res1_unsqueeze, lstm_res_2}, ParameterVector{Xi, Yi, Zi});
auto tensor_iterator = std::make_shared<opset5::TensorIterator>();
tensor_iterator->set_body(body);
tensor_iterator->set_invariant_input(Zi, Z);
tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 1);
tensor_iterator->set_merged_input(Yi, Y, res_1);
tensor_iterator->set_merged_input(Yi, Y, lstm_res_1);
auto out0 = tensor_iterator->get_iter_value(res_1, -1);
auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 1);
auto out0 = tensor_iterator->get_concatenated_slices(lstm_res1_unsqueeze, 0, 1, 1, -1, 1);
auto out1 = tensor_iterator->get_iter_value(lstm_res_1, -1);
auto out2 = tensor_iterator->get_iter_value(lstm_res_2, -1);
auto res_ti_0 = std::make_shared<opset5::Result>(tensor_iterator->output(0));
auto res_ti_1 = std::make_shared<opset5::Result>(tensor_iterator->output(1));
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y, Z});
auto res_ti_2 = std::make_shared<opset5::Result>(tensor_iterator->output(2));
res_ti_0->set_friendly_name("Result1");
res_ti_1->set_friendly_name("Result2");
res_ti_2->set_friendly_name("Result3");
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_0, res_ti_1, res_ti_2},
ngraph::ParameterVector{X, Y, Z});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
@ -89,8 +98,18 @@ TEST(TransformationTests, ConvertTensorIteratorToLSTMSequence) {
auto lstm_seq = std::make_shared<opset5::LSTMSequence>(X, in_1, in_2, seq_lengths, W, R, B, 128, op::RecurrentSequenceDirection::FORWARD);
auto axis_out = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
auto out_0 = std::make_shared<ngraph::opset5::Squeeze>(lstm_seq->output(0), axis_out);
auto res_ti_1 = std::make_shared<opset5::Result>(out_0);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y, Z});
auto out_1 = std::make_shared<ngraph::opset5::Squeeze>(lstm_seq->output(1), axis_out);
auto out_2 = std::make_shared<ngraph::opset5::Squeeze>(lstm_seq->output(2), axis_out);
auto res_ti_0 = std::make_shared<opset5::Result>(out_0);
auto res_ti_1 = std::make_shared<opset5::Result>(out_1);
auto res_ti_2 = std::make_shared<opset5::Result>(out_2);
res_ti_0->set_friendly_name("Result1");
res_ti_1->set_friendly_name("Result2");
res_ti_2->set_friendly_name("Result3");
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_0, res_ti_1, res_ti_2},
ngraph::ParameterVector{X, Y, Z});
}
auto res = compare_functions(f, f_ref);