ConvertTensorIteratorToLSTMSequence fix (#9541)
This commit is contained in:
parent
c6079ccc11
commit
a49f1b3bc6
@ -78,7 +78,7 @@ bool convertTensorIteratorToSequence(
|
||||
ordered_out_descs[0] = output_desc;
|
||||
} else if (res->input_value(0) == found_cell->output(0)) {
|
||||
ordered_out_descs[1] = output_desc;
|
||||
} else if (found_cell->get_output_size() == 3 && res->input_value(0) == found_cell->output(1)) {
|
||||
} else if (found_cell->get_output_size() == 2 && res->input_value(0) == found_cell->output(1)) {
|
||||
ordered_out_descs[2] = output_desc;
|
||||
} else {
|
||||
return false;
|
||||
|
@ -43,24 +43,33 @@ TEST(TransformationTests, ConvertTensorIteratorToLSTMSequence) {
|
||||
auto B = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{512}, b_val);
|
||||
|
||||
auto lstm_cell = std::make_shared<opset5::LSTMCell>(squeeze, Yi, Zi, W, R, B, 128);
|
||||
auto res_1 = std::make_shared<opset5::Result>(lstm_cell);
|
||||
auto lstm_res_1 = std::make_shared<opset5::Result>(lstm_cell->output(0));
|
||||
auto lstm_res_2 = std::make_shared<opset5::Result>(lstm_cell->output(1));
|
||||
auto reshape_pattern_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 1, 128});
|
||||
auto unsqueeze = std::make_shared<opset5::Reshape>(lstm_cell, reshape_pattern_2, false);
|
||||
auto res_2 = std::make_shared<opset5::Result>(unsqueeze);
|
||||
auto body = std::make_shared<Function>(OutputVector{res_1, res_2}, ParameterVector{Xi, Yi, Zi});
|
||||
auto unsqueeze = std::make_shared<opset5::Reshape>(lstm_cell->output(0), reshape_pattern_2, false);
|
||||
auto lstm_res1_unsqueeze = std::make_shared<opset5::Result>(unsqueeze);
|
||||
auto body = std::make_shared<Function>(OutputVector{lstm_res_1, lstm_res1_unsqueeze, lstm_res_2}, ParameterVector{Xi, Yi, Zi});
|
||||
|
||||
auto tensor_iterator = std::make_shared<opset5::TensorIterator>();
|
||||
tensor_iterator->set_body(body);
|
||||
|
||||
tensor_iterator->set_invariant_input(Zi, Z);
|
||||
tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 1);
|
||||
tensor_iterator->set_merged_input(Yi, Y, res_1);
|
||||
tensor_iterator->set_merged_input(Yi, Y, lstm_res_1);
|
||||
|
||||
auto out0 = tensor_iterator->get_iter_value(res_1, -1);
|
||||
auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 1);
|
||||
auto out0 = tensor_iterator->get_concatenated_slices(lstm_res1_unsqueeze, 0, 1, 1, -1, 1);
|
||||
auto out1 = tensor_iterator->get_iter_value(lstm_res_1, -1);
|
||||
auto out2 = tensor_iterator->get_iter_value(lstm_res_2, -1);
|
||||
|
||||
auto res_ti_0 = std::make_shared<opset5::Result>(tensor_iterator->output(0));
|
||||
auto res_ti_1 = std::make_shared<opset5::Result>(tensor_iterator->output(1));
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y, Z});
|
||||
auto res_ti_2 = std::make_shared<opset5::Result>(tensor_iterator->output(2));
|
||||
|
||||
res_ti_0->set_friendly_name("Result1");
|
||||
res_ti_1->set_friendly_name("Result2");
|
||||
res_ti_2->set_friendly_name("Result3");
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_0, res_ti_1, res_ti_2},
|
||||
ngraph::ParameterVector{X, Y, Z});
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
@ -89,8 +98,18 @@ TEST(TransformationTests, ConvertTensorIteratorToLSTMSequence) {
|
||||
auto lstm_seq = std::make_shared<opset5::LSTMSequence>(X, in_1, in_2, seq_lengths, W, R, B, 128, op::RecurrentSequenceDirection::FORWARD);
|
||||
auto axis_out = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto out_0 = std::make_shared<ngraph::opset5::Squeeze>(lstm_seq->output(0), axis_out);
|
||||
auto res_ti_1 = std::make_shared<opset5::Result>(out_0);
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y, Z});
|
||||
auto out_1 = std::make_shared<ngraph::opset5::Squeeze>(lstm_seq->output(1), axis_out);
|
||||
auto out_2 = std::make_shared<ngraph::opset5::Squeeze>(lstm_seq->output(2), axis_out);
|
||||
|
||||
auto res_ti_0 = std::make_shared<opset5::Result>(out_0);
|
||||
auto res_ti_1 = std::make_shared<opset5::Result>(out_1);
|
||||
auto res_ti_2 = std::make_shared<opset5::Result>(out_2);
|
||||
res_ti_0->set_friendly_name("Result1");
|
||||
res_ti_1->set_friendly_name("Result2");
|
||||
res_ti_2->set_friendly_name("Result3");
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_0, res_ti_1, res_ti_2},
|
||||
ngraph::ParameterVector{X, Y, Z});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
|
Loading…
Reference in New Issue
Block a user