[TF FE] Support DeepSpeech model by TF FE (#13316)

* [TF FE] Support DeepSpeech model by TF FE

Add the final part to support the BlockLSTM operation with the cell state
sliced from the last time step.

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Apply code-review feedback: use get_pattern_map, no nullptr

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in: parent 76fc9cb109, commit fa3c745263
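For orientation, the output indices used throughout this diff follow TensorFlow's BlockLSTM definition. The reference below is not part of the commit; it summarizes the TF op spec as I understand it:

    // TensorFlow's BlockLSTM produces seven time-concatenated outputs:
    //   0: i  - input gate          3: o  - output gate       6: h  - hidden states
    //   1: cs - cell states         4: ci - cell input
    //   2: f  - forget gate         5: co - cell output
    // The replacer below supports consumers only on outputs 1 (cs) and 6 (h):
    // output 6 maps to LSTMSequence's concatenated hidden states, and the last
    // row of output 1 maps to LSTMSequence's final cell state.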
@@ -430,7 +430,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& function) const {
     // Runs middle transformations to convert sub-graphs with intermediate (frontend internal) operations
     // into sub-graphs with only OpenVINO operations
     manager.register_pass<ov::frontend::tensorflow::pass::EmbeddingSegmentSingleFeatureFusion>();
-    manager.register_pass<ov::frontend::tensorflow::pass::BlockLSTMToLSTMSequenceOneOutput>();
+    manager.register_pass<ov::frontend::tensorflow::pass::BlockLSTMReplacer>();
     manager.register_pass<ov::frontend::tensorflow::pass::GRUBlockCellReplacer>();
 
     // TODO: reimplement TransposeSinking that does not corrupt filters for Convolution
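The registration above wires the renamed pass into FrontEnd::normalize(). For local experiments the pass can also be run standalone through a pass manager; a minimal sketch, assuming an already converted ov::Model named model (the pass itself is frontend-internal, so this only builds inside the TF frontend):

    #include "openvino/pass/manager.hpp"

    void run_block_lstm_replacer(const std::shared_ptr<ov::Model>& model) {
        ov::pass::Manager manager;
        // frontend-internal pass registered by this commit
        manager.register_pass<ov::frontend::tensorflow::pass::BlockLSTMReplacer>();
        manager.run_passes(model);
    }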
@@ -18,20 +18,69 @@
 using namespace std;
 using namespace ov::pass;
 using namespace ov::pass::pattern;
 using namespace ov::opset9;
 using namespace ov::frontend::tensorflow;
 
-pass::BlockLSTMToLSTMSequenceOneOutput::BlockLSTMToLSTMSequenceOneOutput() {
-    auto block_lstm = pattern::wrap_type<BlockLSTM>();
+namespace {
+std::function<bool(ov::Output<ov::Node>)> can_have_outputs(const std::vector<size_t>& allowed_output_indices) {
+    return [=](ov::Output<ov::Node> output) -> bool {
+        auto block_lstm_node = output.get_node_shared_ptr();
+        auto output_size = block_lstm_node->get_output_size();
+        for (size_t output_ind = 0; output_ind < output_size; ++output_ind) {
+            if (std::find(allowed_output_indices.begin(), allowed_output_indices.end(), output_ind) !=
+                allowed_output_indices.end()) {
+                continue;
+            }
+            if (block_lstm_node->output(output_ind).get_target_inputs().size() > 0) {
+                return false;
+            }
+        }
+        return true;
+    };
+}
+}  // namespace
+
+pass::BlockLSTMReplacer::BlockLSTMReplacer() {
+    // Pattern 1: BlockLSTM with last state cell output (BlockLSTM -> Concat -> GatherND)
+    // used in DeepSpeech model
+    auto block_lstm_1 = pattern::wrap_type<BlockLSTM>(can_have_outputs({1, 6}));
+    auto states_cell_1 = pattern::wrap_type<Concat>({pattern::any_input(), block_lstm_1});
+    auto pattern1 = pattern::wrap_type<GatherND>({states_cell_1, pattern::any_input()});
+
+    // Pattern 2: BlockLSTM with just one output, concatenated hidden states (BlockLSTM)
+    auto pattern2 = pattern::wrap_type<BlockLSTM>(can_have_outputs({6}));
+
+    auto root = std::make_shared<pattern::op::Or>(OutputVector{pattern1, pattern2});
 
     matcher_pass_callback callback = [=](pattern::Matcher& m) {
-        NodeRegistry rg;
+        auto pattern_map = m.get_pattern_map();
+        auto is_pattern1 = (pattern_map.find(pattern1) != std::end(pattern_map));
+        auto is_pattern2 = (pattern_map.find(pattern2) != std::end(pattern_map));
 
-        auto block_lstm_node = std::dynamic_pointer_cast<BlockLSTM>(m.get_match_root());
+        // find for each pattern BlockLSTM node for which we adjust inputs
+        // and check its attributes before the transformation
+        std::shared_ptr<BlockLSTM> block_lstm_node;
+        std::shared_ptr<Node> last_state_c_node;
+        ov::NodeVector rt_info_from;
+        if (is_pattern1) {
+            block_lstm_node = std::dynamic_pointer_cast<BlockLSTM>(pattern_map.at(block_lstm_1));
+            auto concat_node = std::dynamic_pointer_cast<Concat>(pattern_map.at(states_cell_1));
+            if (!concat_node || concat_node->get_axis() != 0) {
+                // timestep is the first dimension
+                return false;
+            }
+            last_state_c_node = pattern_map.at(pattern1);
+            rt_info_from = {block_lstm_node, concat_node, last_state_c_node};
+        } else if (is_pattern2) {
+            block_lstm_node = std::dynamic_pointer_cast<BlockLSTM>(pattern_map.at(pattern2));
+            rt_info_from = {block_lstm_node};
+        }
         if (!block_lstm_node) {
             return false;
         }
 
+        NodeRegistry rg;
         // currently, LSTMSequence does not support peephole and cell clip
         if (block_lstm_node->get_use_peephole()) {
             return false;
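A note on the can_have_outputs helper: it builds a predicate for pattern::wrap_type, so a BlockLSTM node matches only if every output outside the allowed set has no consumers. One subtlety worth spelling out is that the predicate receives a single Output handle but inspects the whole node. A hand evaluation, as a sketch of the code above on the DeepSpeech sub-graph:

    // Hand evaluation of can_have_outputs({1, 6}) on a BlockLSTM node:
    //   output 0 (i):  no consumers        -> ok
    //   output 1 (cs): consumed by Concat  -> allowed index, skipped
    //   outputs 2-5:   no consumers        -> ok
    //   output 6 (h):  consumed downstream -> allowed index, skipped
    //   => predicate returns true, so the node can match block_lstm_1
    auto pred = can_have_outputs({1, 6});
    bool node_ok = pred(block_lstm->output(0));  // any output handle of the node works;
                                                 // the lambda walks all 7 outputs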
@@ -47,7 +96,7 @@ pass::BlockLSTMToLSTMSequenceOneOutput::BlockLSTMToLSTMSequenceOneOutput() {
             return false;
         }
 
-        auto block_lstm_node_name = block_lstm->get_friendly_name();
+        auto block_lstm_node_name = block_lstm_node->get_friendly_name();
         auto seq_len_max = block_lstm_node->input_value(0);
         auto x = block_lstm_node->input_value(1);
         auto cs_prev = block_lstm_node->input_value(2);
@@ -84,15 +133,6 @@ pass::BlockLSTMToLSTMSequenceOneOutput::BlockLSTMToLSTMSequenceOneOutput() {
         auto hidden_size_const =
             rg.make<Constant>(element::i64, Shape{1}, std::vector<int64_t>{hidden_size.get_length()});
 
-        // this transformation expects only one output - concatenated hidden states
-        // the only output of BlockLSTM that is supported by LSTMSequence
-        std::vector<int> restricted_output_indices = {0, 1, 2, 3, 4, 5};
-        for (size_t output_ind : restricted_output_indices) {
-            if (block_lstm_node->output(output_ind).get_target_inputs().size() > 0) {
-                return false;
-            }
-        }
-
         // adjust weights and bias
         // 1. reshape weights and bias to highlight channel dimension
         auto new_weight_shape = rg.make<Constant>(element::i64, Shape{3}, std::vector<int64_t>{0, 4, -1});
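The {0, 4, -1} target shape relies on Reshape's special-zero semantics: 0 copies the corresponding input dimension and -1 is inferred. For TF BlockLSTM weights of shape [input_size + hidden_size, 4 * hidden_size], this exposes the four gates as a separate axis. The Reshape call itself falls outside this hunk, so the following step is an assumption sketched for illustration (variable names hypothetical):

    // weights: [input_size + hidden_size, 4 * hidden_size]  (TF BlockLSTM layout)
    auto new_weight_shape = rg.make<Constant>(element::i64, Shape{3}, std::vector<int64_t>{0, 4, -1});
    // special_zero = true: dim 0 is kept as-is, -1 infers hidden_size
    // result: [input_size + hidden_size, 4, hidden_size]
    auto weights_per_gate = rg.make<Reshape>(weights, new_weight_shape, true);  // hypothetical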
@@ -152,28 +192,42 @@ pass::BlockLSTMToLSTMSequenceOneOutput::BlockLSTMToLSTMSequenceOneOutput() {
                                                  hidden_size.get_length(),
                                                  LSTMSequence::direction::FORWARD);
 
-        // adjust output of concatenated of hidden states from LSTMSequence to have it in a format [time_len,
-        // batch_size, hidden_size]
-        // 1. squeeze extra dimension - num_directions
-        auto squeeze_axis = rg.make<Constant>(element::i64, Shape{1}, std::vector<int64_t>{1});
-        auto squeeze_output_hidden_states = rg.make<Squeeze>(lstm_sequence->output(0), squeeze_axis);
-        // 2. transpose the output to rotate batch and time dimensions
-        auto output_hidden_states_order = rg.make<Constant>(element::i64, Shape{3}, std::vector<int64_t>{1, 0, 2});
-        auto output_hidden_states = rg.make<Transpose>(squeeze_output_hidden_states, output_hidden_states_order);
-
-        // preserve names of the node and the output tensor
-        output_hidden_states->set_friendly_name(m.get_match_root()->get_friendly_name() + ":6");
-        copy_runtime_info(block_lstm_node, rg.get());
-
-        // replace BlockLSTM with LSTMSequence manually instead of calling
-        // ov::replace_node(m.get_match_root(), lstm_sequence);
-        // because BlockLSTM has 7 outputs and LSTMSequence has three outputs
-        m.get_match_root()->output(6).replace(output_hidden_states->output(0));
+        if (block_lstm_node->output(1).get_target_inputs().size() > 0) {
+            // adjust output with the last state cell and connect to the main graph
+            // squeeze extra dimension - num_directions
+            auto squeeze_axis = rg.make<Constant>(element::i64, Shape{1}, std::vector<int64_t>{1});
+            auto squeeze_last_state_cell = rg.make<Squeeze>(lstm_sequence->output(2), squeeze_axis);
+
+            // preserve names of the node and the output tensor
+            squeeze_last_state_cell->set_friendly_name(last_state_c_node->get_friendly_name());
+
+            ov::replace_node(last_state_c_node, squeeze_last_state_cell);
+        }
+
+        if (block_lstm_node->output(6).get_target_inputs().size() > 0) {
+            // adjust output of concatenated of hidden states from LSTMSequence
+            // to have it in a format [time_len, batch_size, hidden_size]
+            // 1. squeeze extra dimension - num_directions
+            auto squeeze_axis = rg.make<Constant>(element::i64, Shape{1}, std::vector<int64_t>{1});
+            auto squeeze_output_hidden_states = rg.make<Squeeze>(lstm_sequence->output(0), squeeze_axis);
+            // 2. transpose the output to rotate batch and time dimensions
+            auto output_hidden_states_order = rg.make<Constant>(element::i64, Shape{3}, std::vector<int64_t>{1, 0, 2});
+            auto output_hidden_states = rg.make<Transpose>(squeeze_output_hidden_states, output_hidden_states_order);
+
+            // preserve names of the node and the output tensor
+            output_hidden_states->set_friendly_name(block_lstm_node->get_friendly_name() + ":6");
+
+            // replace BlockLSTM with LSTMSequence manually instead of calling
+            // ov::replace_node(m.get_match_root(), lstm_sequence);
+            // because BlockLSTM has 7 outputs and LSTMSequence has three outputs
+            block_lstm_node->output(6).replace(output_hidden_states->output(0));
+        }
+
+        copy_runtime_info(rt_info_from, rg.get());
+
         return true;
     };
 
-    auto m =
-        std::make_shared<ngraph::pattern::Matcher>(block_lstm,
-                                                   "ov::frontend::tensorflow::pass::BlockLSTMToLSTMSequenceOneOutput");
+    auto m = std::make_shared<ngraph::pattern::Matcher>(root, "ov::frontend::tensorflow::pass::BlockLSTMReplacer");
     register_matcher(m, callback);
 }
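With the dimensions used by the tests further down (batch 2, time 10, hidden 120), the output adjustment works out as follows; a worked sketch of the shape flow, assuming OpenVINO's LSTMSequence output layout:

    // lstm_sequence->output(0) (Y):  [batch, num_directions, time, hidden] = [2, 1, 10, 120]
    //   Squeeze(axis = 1)         -> [2, 10, 120]
    //   Transpose({1, 0, 2})      -> [10, 2, 120]  // BlockLSTM's time-major output 6
    // lstm_sequence->output(2) (Co): [batch, num_directions, hidden] = [2, 1, 120]
    //   Squeeze(axis = 1)         -> [2, 120]      // cell state from the last time step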
@@ -16,12 +16,12 @@ namespace frontend {
 namespace tensorflow {
 namespace pass {
 
-// This transformation handles BlockLSTM with just one output, concatenation of all the intermediate
-// output values of the hidden.
-class TENSORFLOW_API BlockLSTMToLSTMSequenceOneOutput : public ov::pass::MatcherPass {
+// This transformation replaces BlockLSTM with such outputs as concatenated hidden states
+// and cell state from the last time step.
+class TENSORFLOW_API BlockLSTMReplacer : public ov::pass::MatcherPass {
 public:
-    OPENVINO_RTTI("ov::frontend::tensorflow::pass::BlockLSTMToLSTMSequenceOneOutput");
-    BlockLSTMToLSTMSequenceOneOutput();
+    OPENVINO_RTTI("ov::frontend::tensorflow::pass::BlockLSTMReplacer");
+    BlockLSTMReplacer();
 };
 
 } // namespace pass
@@ -29,7 +29,8 @@ shared_ptr<Model> gen_model(Dimension batch_size,
                             Dimension input_size,
                             float forget_bias,
                             float cell_clip,
-                            bool use_peephole) {
+                            bool use_peephole,
+                            bool with_two_outputs = false) {
     auto seq_len_max = make_shared<Parameter>(i64, Shape{});
     auto x = make_shared<Parameter>(f32, PartialShape{time_len, batch_size, input_size});
     auto cs_prev = make_shared<Parameter>(f32, PartialShape::dynamic());
@@ -43,6 +44,20 @@ shared_ptr<Model> gen_model(Dimension batch_size,
     auto block_lstm = make_shared<
         BlockLSTM>(seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b, forget_bias, cell_clip, use_peephole);
 
+    if (with_two_outputs) {
+        auto prev_cell_states = make_shared<Constant>(
+            ov::element::f32,
+            ov::Shape{1, static_cast<uint64_t>(batch_size.get_length()), static_cast<uint64_t>(hidden_size)},
+            0);
+        auto concat = make_shared<Concat>(OutputVector{prev_cell_states, block_lstm->output(1)}, 0);
+        auto indices_const = make_shared<Constant>(ov::element::i32,
+                                                   ov::Shape{2},
+                                                   vector<int32_t>{static_cast<int32_t>(time_len.get_length()), 0});
+        auto gather_nd = make_shared<GatherND>(concat, indices_const);
+        return make_shared<Model>(OutputVector{gather_nd->output(0), block_lstm->output(6)},
+                                  ParameterVector{seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b});
+    }
+
     return make_shared<Model>(OutputVector{block_lstm->output(6)},
                               ParameterVector{seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b});
 }
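The indices constant deserves a note: the Concat prepends one row of initial cell states, so row time_len of the concatenation holds the cell state after the final step, and the trailing 0 selects the first batch element. A sketch of calling the updated helper with the arguments the tests below use:

    // batch = 2, time = 10, hidden = 120, input = 20,
    // forget_bias = 1.0, cell_clip = -1.0 (disabled), no peephole
    auto model = gen_model(2, 10, 120, 20, 1.0f, -1.0f, false, /*with_two_outputs=*/true);
    // model outputs: {last-step cell state via GatherND, concatenated hidden states}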
@@ -51,7 +66,8 @@ shared_ptr<Model> gen_model_ref(Dimension m_batch_size,
                                 Dimension m_time_len,
                                 int64_t m_hidden_size,
                                 Dimension m_input_size,
-                                float forget_bias) {
+                                float forget_bias,
+                                bool with_two_outputs = false) {
     auto seq_len_max = make_shared<Parameter>(i64, Shape{});
     auto x = make_shared<Parameter>(f32, PartialShape{m_time_len, m_batch_size, m_input_size});
     auto cs_prev = make_shared<Parameter>(f32, PartialShape::dynamic());
@@ -150,24 +166,41 @@ shared_ptr<Model> gen_model_ref(Dimension m_batch_size,
     auto output_hidden_states_order = make_shared<Constant>(element::i64, Shape{3}, std::vector<int64_t>{1, 0, 2});
     auto output_hidden_states = make_shared<Transpose>(squeeze_output_hidden_states, output_hidden_states_order);
 
+    if (with_two_outputs) {
+        // adjust output with the last state cell and connect to the main graph
+        // squeeze extra dimension - num_directions
+        auto squeeze_axis = make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{1});
+        auto squeeze_last_state_cell = make_shared<Squeeze>(lstm_sequence->output(2), squeeze_axis);
+        return make_shared<Model>(OutputVector{squeeze_last_state_cell->output(0), output_hidden_states->output(0)},
+                                  ParameterVector{seq_len_max, x, cs_prev, h_prev, weights, bias});
+    }
+
     return make_shared<Model>(OutputVector{output_hidden_states->output(0)},
                               ParameterVector{seq_len_max, x, cs_prev, h_prev, weights, bias});
 }
 
 } // namespace
 
-TEST_F(TransformationTestsF, BlockLSTMReplacerOneOutput) {
+TEST_F(TransformationTestsF, BlockLSTMReplacerWithHiddenOutput) {
     {
         function = gen_model(2, 10, 120, 20, 1.0f, -1.0f, false);
-        manager.register_pass<BlockLSTMToLSTMSequenceOneOutput>();
+        manager.register_pass<BlockLSTMReplacer>();
     }
     { function_ref = gen_model_ref(2, 10, 120, 20, 1.0f); }
 }
 
-TEST_F(TransformationTestsF, BlockLSTMReplacerOneOutputPeepHole) {
+TEST_F(TransformationTestsF, BlockLSTMReplacerWithHiddenOutputAndLastCellState) {
+    {
+        function = gen_model(2, 10, 120, 20, 1.0f, -1.0f, false, true);
+        manager.register_pass<BlockLSTMReplacer>();
+    }
+    { function_ref = gen_model_ref(2, 10, 120, 20, 1.0f, true); }
+}
+
+TEST_F(TransformationTestsF, BlockLSTMReplacerWithPeepHole) {
     {
         function = gen_model(2, 10, 120, 20, 1.0f, -1.0f, true);
-        manager.register_pass<BlockLSTMToLSTMSequenceOneOutput>();
+        manager.register_pass<BlockLSTMReplacer>();
     }
     {
         // the transformation is not applied for the peep hole case