From fa3c7452637203ddc13a22677d104114e7278757 Mon Sep 17 00:00:00 2001
From: Roman Kazantsev <roman.kazantsev@intel.com>
Date: Wed, 5 Oct 2022 06:00:33 +0300
Subject: [PATCH] [TF FE] Support DeepSpeech model by TF FE (#13316)

* [TF FE] Support DeepSpeech model by TF FE

Add the final part to support the BlockLSTM operation with the cell
state sliced from the last time step.

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Apply code-review feedback: use get_pattern_map, no nullptr

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
---
 src/frontends/tensorflow/src/frontend.cpp     |   2 +-
 .../helper_transforms/block_lstm_replacer.cpp | 118 +++++++++++++-----
 .../helper_transforms/block_lstm_replacer.hpp |  10 +-
 .../tensorflow/tests/block_lstm_replacer.cpp  |  45 ++++++-
 4 files changed, 131 insertions(+), 44 deletions(-)

diff --git a/src/frontends/tensorflow/src/frontend.cpp b/src/frontends/tensorflow/src/frontend.cpp
index 20618a8234f..b113c743020 100644
--- a/src/frontends/tensorflow/src/frontend.cpp
+++ b/src/frontends/tensorflow/src/frontend.cpp
@@ -430,7 +430,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& function) const {
     // Runs middle transformations to convert sub-graphs with intermediate (frontend internal) operations
     // into sub-graphs with only OpenVINO operations
     manager.register_pass<ov::frontend::tensorflow::pass::EmbeddingSegmentSingleFeatureFusion>();
-    manager.register_pass<ov::frontend::tensorflow::pass::BlockLSTMToLSTMSequenceOneOutput>();
+    manager.register_pass<ov::frontend::tensorflow::pass::BlockLSTMReplacer>();
     manager.register_pass<ov::frontend::tensorflow::pass::GRUBlockCellReplacer>();
 
     // TODO: reimplement TransposeSinking that does not corrupt filters for Convolution
diff --git a/src/frontends/tensorflow/src/helper_transforms/block_lstm_replacer.cpp b/src/frontends/tensorflow/src/helper_transforms/block_lstm_replacer.cpp
index 9f88e2dfa22..0d359380064 100644
--- a/src/frontends/tensorflow/src/helper_transforms/block_lstm_replacer.cpp
+++ b/src/frontends/tensorflow/src/helper_transforms/block_lstm_replacer.cpp
@@ -18,20 +18,69 @@
 using namespace std;
 using namespace ov::pass;
+using namespace ov::pass::pattern;
 using namespace ov::opset9;
 using namespace ov::frontend::tensorflow;
 
-pass::BlockLSTMToLSTMSequenceOneOutput::BlockLSTMToLSTMSequenceOneOutput() {
-    auto block_lstm = pattern::wrap_type<BlockLSTM>();
+namespace {
+std::function<bool(ov::Output<ov::Node>)> can_have_outputs(const std::vector<size_t>& allowed_output_indices) {
+    return [=](ov::Output<ov::Node> output) -> bool {
+        auto block_lstm_node = output.get_node_shared_ptr();
+        auto output_size = block_lstm_node->get_output_size();
+        for (size_t output_ind = 0; output_ind < output_size; ++output_ind) {
+            if (std::find(allowed_output_indices.begin(), allowed_output_indices.end(), output_ind) !=
+                allowed_output_indices.end()) {
+                continue;
+            }
+            if (block_lstm_node->output(output_ind).get_target_inputs().size() > 0) {
+                return false;
+            }
+        }
+        return true;
+    };
+}
+}  // namespace
+
+pass::BlockLSTMReplacer::BlockLSTMReplacer() {
+    // Pattern 1: BlockLSTM with the cell state output sliced from the last time step
+    // (BlockLSTM -> Concat -> GatherND), as used in the DeepSpeech model
+    auto block_lstm_1 = pattern::wrap_type<BlockLSTM>(can_have_outputs({1, 6}));
+    auto states_cell_1 = pattern::wrap_type<Concat>({pattern::any_input(), block_lstm_1});
+    auto pattern1 = pattern::wrap_type<GatherND>({states_cell_1, pattern::any_input()});
+
+    // Pattern 2: BlockLSTM with just one used output, the concatenated hidden states
+    auto pattern2 = pattern::wrap_type<BlockLSTM>(can_have_outputs({6}));
+
+    auto root = std::make_shared<pattern::op::Or>(OutputVector{pattern1, pattern2});
 
     matcher_pass_callback callback = [=](pattern::Matcher& m) {
-        NodeRegistry rg;
+        auto pattern_map = m.get_pattern_map();
+        auto is_pattern1 = (pattern_map.find(pattern1) != std::end(pattern_map));
+        auto is_pattern2 = (pattern_map.find(pattern2) != std::end(pattern_map));
 
-        auto block_lstm_node = std::dynamic_pointer_cast<BlockLSTM>(m.get_match_root());
+        // for each matched pattern, find the BlockLSTM node whose inputs are adjusted
+        // and check its attributes before the transformation
+        std::shared_ptr<BlockLSTM> block_lstm_node;
+        std::shared_ptr<ov::Node> last_state_c_node;
+        ov::NodeVector rt_info_from;
+        if (is_pattern1) {
+            block_lstm_node = std::dynamic_pointer_cast<BlockLSTM>(pattern_map.at(block_lstm_1));
+            auto concat_node = std::dynamic_pointer_cast<Concat>(pattern_map.at(states_cell_1));
+            if (!concat_node || concat_node->get_axis() != 0) {
+                // the time step must be the first dimension
+                return false;
+            }
+            last_state_c_node = pattern_map.at(pattern1);
+            rt_info_from = {block_lstm_node, concat_node, last_state_c_node};
+        } else if (is_pattern2) {
+            block_lstm_node = std::dynamic_pointer_cast<BlockLSTM>(pattern_map.at(pattern2));
+            rt_info_from = {block_lstm_node};
+        }
         if (!block_lstm_node) {
             return false;
         }
 
+        NodeRegistry rg;
+
         // currently, LSTMSequence does not support peephole and cell clip
         if (block_lstm_node->get_use_peephole()) {
             return false;
@@ -47,7 +96,7 @@ pass::BlockLSTMToLSTMSequenceOneOutput::BlockLSTMToLSTMSequenceOneOutput() {
             return false;
         }
 
-        auto block_lstm_node_name = block_lstm->get_friendly_name();
+        auto block_lstm_node_name = block_lstm_node->get_friendly_name();
         auto seq_len_max = block_lstm_node->input_value(0);
         auto x = block_lstm_node->input_value(1);
         auto cs_prev = block_lstm_node->input_value(2);
@@ -84,15 +133,6 @@ pass::BlockLSTMToLSTMSequenceOneOutput::BlockLSTMToLSTMSequenceOneOutput() {
         auto hidden_size_const =
             rg.make<Constant>(element::i64, Shape{1}, std::vector<int64_t>{hidden_size.get_length()});
 
-        // this transformation expects only one output - concatenated hidden states
-        // the only output of BlockLSTM that is supported by LSTMSequence
-        std::vector<size_t> restricted_output_indices = {0, 1, 2, 3, 4, 5};
-        for (size_t output_ind : restricted_output_indices) {
-            if (block_lstm_node->output(output_ind).get_target_inputs().size() > 0) {
-                return false;
-            }
-        }
-
         // adjust weights and bias
         // 1. reshape weights and bias to highlight channel dimension
         auto new_weight_shape = rg.make<Constant>(element::i64, Shape{3}, std::vector<int64_t>{0, 4, -1});
@@ -152,28 +192,42 @@ pass::BlockLSTMToLSTMSequenceOneOutput::BlockLSTMToLSTMSequenceOneOutput() {
                                              hidden_size.get_length(),
                                              LSTMSequence::direction::FORWARD);
 
-        // adjust output of concatenated of hidden states from LSTMSequence to have it in a format [time_len,
-        // batch_size, hidden_size]
-        // 1. squeeze extra dimension - num_directions
-        auto squeeze_axis = rg.make<Constant>(element::i64, Shape{1}, std::vector<int64_t>{1});
-        auto squeeze_output_hidden_states = rg.make<Squeeze>(lstm_sequence->output(0), squeeze_axis);
-        // 2. transpose the output to rotate batch and time dimensions
-        auto output_hidden_states_order = rg.make<Constant>(element::i64, Shape{3}, std::vector<int64_t>{1, 0, 2});
-        auto output_hidden_states = rg.make<Transpose>(squeeze_output_hidden_states, output_hidden_states_order);
+        if (block_lstm_node->output(1).get_target_inputs().size() > 0) {
+            // adjust the last cell state output and connect it to the main graph:
+            // squeeze the extra dimension - num_directions
+            auto squeeze_axis = rg.make<Constant>(element::i64, Shape{1}, std::vector<int64_t>{1});
+            auto squeeze_last_state_cell = rg.make<Squeeze>(lstm_sequence->output(2), squeeze_axis);
 
-        // preserve names of the node and the output tensor
-        output_hidden_states->set_friendly_name(m.get_match_root()->get_friendly_name() + ":6");
-        copy_runtime_info(block_lstm_node, rg.get());
+            // preserve names of the node and the output tensor
+            squeeze_last_state_cell->set_friendly_name(last_state_c_node->get_friendly_name());
+
+            ov::replace_node(last_state_c_node, squeeze_last_state_cell);
+        }
+
+        if (block_lstm_node->output(6).get_target_inputs().size() > 0) {
+            // adjust the concatenated hidden states output of LSTMSequence
+            // to have it in the format [time_len, batch_size, hidden_size]
+            // 1. squeeze the extra dimension - num_directions
+            auto squeeze_axis = rg.make<Constant>(element::i64, Shape{1}, std::vector<int64_t>{1});
+            auto squeeze_output_hidden_states = rg.make<Squeeze>(lstm_sequence->output(0), squeeze_axis);
+            // 2. transpose the output to swap batch and time dimensions
+            auto output_hidden_states_order = rg.make<Constant>(element::i64, Shape{3}, std::vector<int64_t>{1, 0, 2});
+            auto output_hidden_states = rg.make<Transpose>(squeeze_output_hidden_states, output_hidden_states_order);
+
+            // preserve names of the node and the output tensor
+            output_hidden_states->set_friendly_name(block_lstm_node->get_friendly_name() + ":6");
+
+            // replace BlockLSTM with LSTMSequence manually instead of calling
+            // ov::replace_node(m.get_match_root(), lstm_sequence);
+            // because BlockLSTM has seven outputs and LSTMSequence has three
+            block_lstm_node->output(6).replace(output_hidden_states->output(0));
+        }
+
+        copy_runtime_info(rt_info_from, rg.get());
 
-        // replace BlockLSTM with LSTMSequence manually instead of calling
-        // ov::replace_node(m.get_match_root(), lstm_sequence);
-        // because BlockLSTM has 7 outputs and LSTMSequence has three outputs
-        m.get_match_root()->output(6).replace(output_hidden_states->output(0));
         return true;
     };
 
-    auto m =
-        std::make_shared<pattern::Matcher>(block_lstm,
-                                           "ov::frontend::tensorflow::pass::BlockLSTMToLSTMSequenceOneOutput");
+    auto m = std::make_shared<pattern::Matcher>(root, "ov::frontend::tensorflow::pass::BlockLSTMReplacer");
     register_matcher(m, callback);
 }
diff --git a/src/frontends/tensorflow/src/helper_transforms/block_lstm_replacer.hpp b/src/frontends/tensorflow/src/helper_transforms/block_lstm_replacer.hpp
index 324fc78665c..1c9fd4638f1 100644
--- a/src/frontends/tensorflow/src/helper_transforms/block_lstm_replacer.hpp
+++ b/src/frontends/tensorflow/src/helper_transforms/block_lstm_replacer.hpp
@@ -16,12 +16,12 @@ namespace frontend {
 namespace tensorflow {
 namespace pass {
 
-// This transformation handles BlockLSTM with just one output, concatenation of all the intermediate
-// output values of the hidden.
-class TENSORFLOW_API BlockLSTMToLSTMSequenceOneOutput : public ov::pass::MatcherPass {
+// This transformation replaces BlockLSTM whose only used outputs are the concatenated
+// hidden states and the cell state from the last time step.
+class TENSORFLOW_API BlockLSTMReplacer : public ov::pass::MatcherPass {
 public:
-    OPENVINO_RTTI("ov::frontend::tensorflow::pass::BlockLSTMToLSTMSequenceOneOutput");
-    BlockLSTMToLSTMSequenceOneOutput();
+    OPENVINO_RTTI("ov::frontend::tensorflow::pass::BlockLSTMReplacer");
+    BlockLSTMReplacer();
 };
 
 }  // namespace pass
diff --git a/src/frontends/tensorflow/tests/block_lstm_replacer.cpp b/src/frontends/tensorflow/tests/block_lstm_replacer.cpp
index f7b09e52e30..bf8ec7a1cc9 100644
--- a/src/frontends/tensorflow/tests/block_lstm_replacer.cpp
+++ b/src/frontends/tensorflow/tests/block_lstm_replacer.cpp
@@ -29,7 +29,8 @@ shared_ptr<Model> gen_model(Dimension batch_size,
                             Dimension input_size,
                             float forget_bias,
                             float cell_clip,
-                            bool use_peephole) {
+                            bool use_peephole,
+                            bool with_two_outputs = false) {
     auto seq_len_max = make_shared<Parameter>(i64, Shape{});
     auto x = make_shared<Parameter>(f32, PartialShape{time_len, batch_size, input_size});
     auto cs_prev = make_shared<Parameter>(f32, PartialShape::dynamic());
@@ -43,6 +44,20 @@ shared_ptr<Model> gen_model(Dimension batch_size,
     auto block_lstm = make_shared<BlockLSTM>(seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b,
                                              forget_bias, cell_clip, use_peephole);
+    if (with_two_outputs) {
+        auto prev_cell_states = make_shared<Constant>(
+            ov::element::f32,
+            ov::Shape{1, static_cast<size_t>(batch_size.get_length()), static_cast<size_t>(hidden_size)},
+            0);
+        auto concat = make_shared<Concat>(OutputVector{prev_cell_states, block_lstm->output(1)}, 0);
+        auto indices_const = make_shared<Constant>(ov::element::i32,
+                                                   ov::Shape{2},
+                                                   vector<int32_t>{static_cast<int32_t>(time_len.get_length()), 0});
+        auto gather_nd = make_shared<GatherND>(concat, indices_const);
+        return make_shared<Model>(OutputVector{gather_nd->output(0), block_lstm->output(6)},
+                                  ParameterVector{seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b});
+    }
+
     return make_shared<Model>(OutputVector{block_lstm->output(6)},
                               ParameterVector{seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b});
 }
@@ -51,7 +66,8 @@ shared_ptr<Model> gen_model_ref(Dimension m_batch_size,
                                 Dimension m_time_len,
                                 int64_t m_hidden_size,
                                 Dimension m_input_size,
-                                float forget_bias) {
+                                float forget_bias,
+                                bool with_two_outputs = false) {
     auto seq_len_max = make_shared<Parameter>(i64, Shape{});
     auto x = make_shared<Parameter>(f32, PartialShape{m_time_len, m_batch_size, m_input_size});
     auto cs_prev = make_shared<Parameter>(f32, PartialShape::dynamic());
@@ -150,24 +166,41 @@ shared_ptr<Model> gen_model_ref(Dimension m_batch_size,
     auto output_hidden_states_order = make_shared<Constant>(element::i64, Shape{3}, std::vector<int64_t>{1, 0, 2});
     auto output_hidden_states = make_shared<Transpose>(squeeze_output_hidden_states, output_hidden_states_order);
 
+    if (with_two_outputs) {
+        // adjust the last cell state output and connect it to the main graph:
+        // squeeze the extra dimension - num_directions
+        auto squeeze_axis = make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{1});
+        auto squeeze_last_state_cell = make_shared<Squeeze>(lstm_sequence->output(2), squeeze_axis);
+        return make_shared<Model>(OutputVector{squeeze_last_state_cell->output(0), output_hidden_states->output(0)},
+                                  ParameterVector{seq_len_max, x, cs_prev, h_prev, weights, bias});
+    }
+
     return make_shared<Model>(OutputVector{output_hidden_states->output(0)},
                               ParameterVector{seq_len_max, x, cs_prev, h_prev, weights, bias});
 }
 }  // namespace
 
-TEST_F(TransformationTestsF, BlockLSTMReplacerOneOutput) {
+TEST_F(TransformationTestsF, BlockLSTMReplacerWithHiddenOutput) {
     {
         function = gen_model(2, 10, 120, 20, 1.0f, -1.0f, false);
-        manager.register_pass<pass::BlockLSTMToLSTMSequenceOneOutput>();
+        manager.register_pass<pass::BlockLSTMReplacer>();
     }
     { function_ref = gen_model_ref(2, 10, 120, 20, 1.0f); }
 }
 
-TEST_F(TransformationTestsF, BlockLSTMReplacerOneOutputPeepHole) {
+TEST_F(TransformationTestsF, BlockLSTMReplacerWithHiddenOutputAndLastCellState) {
+    {
+        function = gen_model(2, 10, 120, 20, 1.0f, -1.0f, false, true);
+        manager.register_pass<pass::BlockLSTMReplacer>();
+    }
+    { function_ref = gen_model_ref(2, 10, 120, 20, 1.0f, true); }
+}
+
+TEST_F(TransformationTestsF, BlockLSTMReplacerWithPeepHole) {
     {
         function = gen_model(2, 10, 120, 20, 1.0f, -1.0f, true);
-        manager.register_pass<pass::BlockLSTMToLSTMSequenceOneOutput>();
+        manager.register_pass<pass::BlockLSTMReplacer>();
     }
     {
         // the transformation is not applied for the peep hole case
         function_ref = gen_model(2, 10, 120, 20, 1.0f, -1.0f, true);
     }
 }
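
Reviewer note (not part of the patch): TensorFlow's BlockLSTM op has seven
outputs, indexed 0..6 as i, cs, f, o, ci, co, h. The replacer supports graphs
that consume only output 6 (hidden states concatenated over time) and, for the
DeepSpeech case, output 1 (cell states concatenated over time) sliced at the
last time step through Concat + GatherND. After the rewrite, output 6 is
served by LSTMSequence output 0 (plus Squeeze and Transpose to restore the
[time_len, batch_size, hidden_size] layout), and the last cell state by
LSTMSequence output 2 (plus Squeeze).

A minimal sketch of running the new pass on an already converted model, which
is what FrontEnd::normalize does internally. It assumes the frontend's
internal header helper_transforms/block_lstm_replacer.hpp is visible to the
consuming target (it is not a public API), and the wrapper function name is
hypothetical:

    #include <memory>

    #include "helper_transforms/block_lstm_replacer.hpp"  // internal TF FE header (assumption)
    #include "openvino/core/model.hpp"
    #include "openvino/pass/manager.hpp"

    // Rewrites supported BlockLSTM sub-graphs into LSTMSequence-based ones
    void run_block_lstm_replacer(const std::shared_ptr<ov::Model>& model) {
        ov::pass::Manager manager;
        manager.register_pass<ov::frontend::tensorflow::pass::BlockLSTMReplacer>();
        manager.run_passes(model);
    }

Any model whose BlockLSTM feeds a non-supported output (i, f, o, ci, or co)
is left untouched, which the can_have_outputs predicates in the pattern
definitions enforce.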