From 8406983849d502f756264daf61701512353b25f7 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Tue, 27 Sep 2022 01:30:55 +0300 Subject: [PATCH] [TF FE] Support GRUBlockCell (#13202) * [TF FE] Support GRUBlockCell Currently, we support only hidden state output from GRUBlockCell due to OpenVINO GRUCell capability Signed-off-by: Kazantsev, Roman * Fix code-style issue * Return Softsign translator * Add tests for GRUBlockCellReplacer * Fix issue for bias in GRUBlockCellReplacer Signed-off-by: Kazantsev, Roman --- src/frontends/tensorflow/src/frontend.cpp | 2 + .../src/helper_ops/gru_block_cell.hpp | 115 +++++++++++++++++ .../gru_block_cell_replacer.cpp | 108 ++++++++++++++++ .../gru_block_cell_replacer.hpp | 29 +++++ .../tensorflow/src/op/gru_block_cell.cpp | 49 +++++++ src/frontends/tensorflow/src/op_table.cpp | 2 + .../tests/gru_block_cell_replacer.cpp | 120 ++++++++++++++++++ 7 files changed, 425 insertions(+) create mode 100644 src/frontends/tensorflow/src/helper_ops/gru_block_cell.hpp create mode 100644 src/frontends/tensorflow/src/helper_transforms/gru_block_cell_replacer.cpp create mode 100644 src/frontends/tensorflow/src/helper_transforms/gru_block_cell_replacer.hpp create mode 100644 src/frontends/tensorflow/src/op/gru_block_cell.cpp create mode 100644 src/frontends/tensorflow/tests/gru_block_cell_replacer.cpp diff --git a/src/frontends/tensorflow/src/frontend.cpp b/src/frontends/tensorflow/src/frontend.cpp index 4c0bee01065..20618a8234f 100644 --- a/src/frontends/tensorflow/src/frontend.cpp +++ b/src/frontends/tensorflow/src/frontend.cpp @@ -7,6 +7,7 @@ #include "graph_iterator_proto.hpp" #include "helper_transforms/block_lstm_replacer.hpp" #include "helper_transforms/embedding_segments_feature_fusing.hpp" +#include "helper_transforms/gru_block_cell_replacer.hpp" #include "input_model.hpp" #include "op_table.hpp" #include "openvino/frontend/tensorflow/extension/conversion.hpp" @@ -430,6 +431,7 @@ void FrontEnd::normalize(const std::shared_ptr& function) 
const { // into sub-graphs with only OpenVINO operations manager.register_pass(); manager.register_pass(); + manager.register_pass(); // TODO: reimplement TransposeSinking that does not corrupt filters for Convolution // and preserve tensor names in case of sinking diff --git a/src/frontends/tensorflow/src/helper_ops/gru_block_cell.hpp b/src/frontends/tensorflow/src/helper_ops/gru_block_cell.hpp new file mode 100644 index 00000000000..f9125b270b1 --- /dev/null +++ b/src/frontends/tensorflow/src/helper_ops/gru_block_cell.hpp @@ -0,0 +1,115 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "internal_operation.hpp" + +namespace ov { +namespace frontend { +namespace tensorflow { + +class GRUBlockCell : public InternalOperation { +public: + OPENVINO_OP("GRUBlockCell", "ov::frontend::tensorflow::util", InternalOperation); + + GRUBlockCell(const Output& x, + const Output& h_prev, + const Output& w_ru, + const Output& w_c, + const Output& b_ru, + const Output& b_c, + const std::shared_ptr& decoder = std::make_shared()) + : InternalOperation(decoder, OutputVector{x, h_prev, w_ru, w_c, b_ru, b_c}, 4), + m_hidden_size(ov::Dimension::dynamic()) { + validate_and_infer_types(); + } + + void validate_and_infer_types() override { + // GRUBlockCell computes the GRU cell forward propagation for 1 time step + // Inputs: + // 0) x: Input to the GRU cell + // 1) h_prev: State input from the previous GRU cell + // 2) w_ru: Weight matrix for the reset and update gate + // 3) w_c: Weight matrix for the cell connection gate + // 4) b_ru: Bias vector for the reset and update gate + // 5) b_c: Bias vector for the cell connection gate + // + // Outputs: + // 0) r: Output of the reset gate + // 1) u: Output of the update gate + // 2) c: Output of the cell connection gate + // 3) h: Current state of the GRU cell + + // try to deduce static hidden_size + // 1. 
use h_prev shape + auto h_prev_shape = get_input_partial_shape(1); + auto h_prev_rank = h_prev_shape.rank(); + if (h_prev_rank.is_static()) { + FRONT_END_OP_CONVERSION_CHECK(h_prev_rank.get_length() == 2, + "Internal error in OpenVINO TensorFlow Frontend: initial hidden state for " + "GRUBlockCell must be of rank equal to 2."); + m_hidden_size = h_prev_shape[1].is_static() ? h_prev_shape[1].get_length() : m_hidden_size; + } + // 2. use w_ru shape + auto w_ru_shape = get_input_partial_shape(2); + auto w_ru_rank = w_ru_shape.rank(); + if (w_ru_rank.is_static()) { + FRONT_END_OP_CONVERSION_CHECK( + w_ru_rank.get_length() == 2, + "Internal error in OpenVINO TensorFlow Frontend: weights for GRUBlockCell must be of rank equal to 2."); + m_hidden_size = w_ru_shape[1].is_static() ? w_ru_shape[1].get_length() / 2 : m_hidden_size; + } + // 3. use w_c shape + auto w_c_shape = get_input_partial_shape(3); + auto w_c_rank = w_c_shape.rank(); + if (w_c_rank.is_static()) { + FRONT_END_OP_CONVERSION_CHECK( + w_c_rank.get_length() == 2, + "Internal error in OpenVINO TensorFlow Frontend: weights for GRUBlockCell must be of rank equal to 2."); + m_hidden_size = w_c_shape[1].is_static() ? w_c_shape[1].get_length() : m_hidden_size; + } + // 4. use b_ru shape + auto b_ru_shape = get_input_partial_shape(4); + auto b_ru_rank = b_ru_shape.rank(); + if (b_ru_rank.is_static()) { + FRONT_END_OP_CONVERSION_CHECK( + b_ru_rank.get_length() == 1, + "Internal error in OpenVINO TensorFlow Frontend: bias for GRUBlockCell must be of rank equal to 1."); + m_hidden_size = b_ru_shape[0].is_static() ? b_ru_shape[0].get_length() / 2 : m_hidden_size; + } + // 5. use b_c shape + auto b_c_shape = get_input_partial_shape(5); + auto b_c_rank = b_c_shape.rank(); + if (b_c_rank.is_static()) { + FRONT_END_OP_CONVERSION_CHECK( + b_c_rank.get_length() == 1, + "Internal error in OpenVINO TensorFlow Frontend: bias for GRUBlockCell must be of rank equal to 1."); + m_hidden_size = b_c_shape[0].is_static() ?
b_c_shape[0].get_length() : m_hidden_size; + } + + // set the defined shape only for the fourth output since + // OpenVINO GRUCell supports hidden state output + auto x_type = get_input_element_type(0); + set_output_type(0, x_type, ov::PartialShape::dynamic()); + set_output_type(1, x_type, ov::PartialShape::dynamic()); + set_output_type(2, x_type, ov::PartialShape::dynamic()); + set_output_type(3, x_type, h_prev_shape); + } + + ov::Dimension get_hidden_size() const { + // TODO: it must be deleted once hidden_size is gone from attributes + return m_hidden_size; + } + +private: + ov::Dimension m_hidden_size; +}; + +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/src/frontends/tensorflow/src/helper_transforms/gru_block_cell_replacer.cpp b/src/frontends/tensorflow/src/helper_transforms/gru_block_cell_replacer.cpp new file mode 100644 index 00000000000..f34ed7b789b --- /dev/null +++ b/src/frontends/tensorflow/src/helper_transforms/gru_block_cell_replacer.cpp @@ -0,0 +1,108 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "helper_transforms/gru_block_cell_replacer.hpp" + +#include +#include + +#include "helper_ops/gru_block_cell.hpp" +#include "ngraph/rt_info.hpp" +#include "openvino/opsets/opset9.hpp" +#include "openvino/pass/pattern/matcher.hpp" +#include "openvino/pass/pattern/op/or.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/utils/utils.hpp" +#include "utils.hpp" + +using namespace std; +using namespace ov::pass; +using namespace ov::opset9; +using namespace ov::frontend::tensorflow; + +pass::GRUBlockCellReplacer::GRUBlockCellReplacer() { + auto gru_block_cell = pattern::wrap_type(); + + matcher_pass_callback callback = [=](pattern::Matcher& m) { + NodeRegistry rg; + + auto gru_block_cell_node = std::dynamic_pointer_cast(m.get_match_root()); + if (!gru_block_cell_node) { + return false; + } + + // this transformation expects only one 
output (the fourth output) - hidden state + // that is the only output supported by OpenVINO GRUCell + std::vector restricted_output_indices = {0, 1, 2}; + for (size_t output_ind : restricted_output_indices) { + if (gru_block_cell_node->output(output_ind).get_target_inputs().size() > 0) { + return false; + } + } + + // currently, OpenVINO supports only static hidden_size + auto m_hidden_size = gru_block_cell_node->get_hidden_size(); + if (m_hidden_size.is_dynamic()) { + return false; + } + + auto x = gru_block_cell_node->input_value(0); + auto h_prev = gru_block_cell_node->input_value(1); + auto w_ru = gru_block_cell_node->input_value(2); + auto w_c = gru_block_cell_node->input_value(3); + auto b_ru = gru_block_cell_node->input_value(4); + auto b_c = gru_block_cell_node->input_value(5); + + // retrieve input_size and hidden_size + auto x_shape = rg.make(x, element::i64); + auto ss_start = rg.make(element::i64, Shape{1}, 1); + auto ss_end = rg.make(element::i64, Shape{1}, 2); + auto ss_step = rg.make(element::i64, Shape{1}, 1); + auto input_size = rg.make(x_shape, ss_start, ss_end, ss_step); + auto h_prev_shape = rg.make(h_prev, element::i64); + auto hidden_size = rg.make(h_prev_shape, ss_start, ss_end, ss_step); + + // prepare weights input + // TensorFlow provides weights in a format w_ru and w_c, where + // z or u - update, r - reset, c or h - hidden (connection) + // OpenVINO GRUCell accepts weights in a format w_zrh (or w_urc) + // 1. split w_ru into w_r and w_u + auto split_w_ru = rg.make(w_ru, rg.make(element::i64, Shape{}, 1), 2); + // 2.
concatenate different parts of weights into w_zrh (or w_urc) + auto w_urc = rg.make(OutputVector{split_w_ru->output(1), split_w_ru->output(0), w_c}, 1); + + // prepare bias in the same way + auto split_b_ru = rg.make(b_ru, rg.make(element::i64, Shape{}, 0), 2); + auto b_urc = rg.make(OutputVector{split_b_ru->output(1), split_b_ru->output(0), b_c}, 0); + + // transpose weights + // the current shape - [input_size + hidden_size, 3 * hidden_size] + // we need the shape [3 * hidden_size, input_size + hidden_size] + // in order to split WR into W and R + auto transpose_order = rg.make(element::i64, Shape{2}, std::vector{1, 0}); + auto w_urc_transpose = rg.make(w_urc, transpose_order); + + // split combined weights WR into W and R + auto split_axis = rg.make(element::i64, Shape{}, 1); + auto split_nums = rg.make(OutputVector{input_size, hidden_size}, 0); + auto split_WR = rg.make(w_urc_transpose, split_axis, split_nums); + + auto gru_cell = + rg.make(x, h_prev, split_WR->output(0), split_WR->output(1), b_urc, m_hidden_size.get_length()); + + // preserve names of the node and the output tensor + gru_cell->set_friendly_name(m.get_match_root()->get_friendly_name() + ":3"); + copy_runtime_info(gru_block_cell_node, rg.get()); + + // replace GRUBlockCell with GRUCell manually instead of calling + // ov::replace_node(m.get_match_root(), gru_cell); + // because GRUBlockCell has 4 outputs and GRUCell has just one + m.get_match_root()->output(3).replace(gru_cell->output(0)); + return true; + }; + + auto m = std::make_shared(gru_block_cell, + "ov::frontend::tensorflow::pass::GRUBlockCellReplacer"); + register_matcher(m, callback); +} diff --git a/src/frontends/tensorflow/src/helper_transforms/gru_block_cell_replacer.hpp b/src/frontends/tensorflow/src/helper_transforms/gru_block_cell_replacer.hpp new file mode 100644 index 00000000000..c9204620473 --- /dev/null +++ b/src/frontends/tensorflow/src/helper_transforms/gru_block_cell_replacer.hpp @@ -0,0 +1,29 @@ +// Copyright (C)
2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "openvino/frontend/tensorflow/visibility.hpp" +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" + +namespace ov { +namespace frontend { +namespace tensorflow { +namespace pass { + +// This transformation handles GRUBlockCell with just one output - hidden state +class TENSORFLOW_API GRUBlockCellReplacer : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::frontend::tensorflow::pass::GRUBlockCellReplacer"); + GRUBlockCellReplacer(); +}; + +} // namespace pass +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/src/frontends/tensorflow/src/op/gru_block_cell.cpp b/src/frontends/tensorflow/src/op/gru_block_cell.cpp new file mode 100644 index 00000000000..08c36902bf0 --- /dev/null +++ b/src/frontends/tensorflow/src/op/gru_block_cell.cpp @@ -0,0 +1,49 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "helper_ops/gru_block_cell.hpp" + +#include "op_table.hpp" + +using namespace std; +using namespace ov; +using namespace ov::frontend::tensorflow; + +namespace ov { +namespace frontend { +namespace tensorflow { +namespace op { + +OutputVector translate_gru_block_cell_op(const NodeContext& node) { + // GRUBlockCell computes the GRU cell forward propagation for 1 time step + // Inputs: + // 0) x: Input to the GRU cell + // 1) h_prev: State input from the previous GRU cell + // 2) w_ru: Weight matrix for the reset and update gate + // 3) w_c: Weight matrix for the cell connection gate + // 4) b_ru: Bias vector for the reset and update gate + // 5) b_c: Bias vector for the cell connection gate + // + // Outputs: + // 0) r: Output of the reset gate + // 1) u: Output of the update gate + // 2) c: Output of the cell connection gate + // 3) h: Current state of the GRU cell + default_op_checks(node, 6, {"GRUBlockCell"}); + auto x = 
node.get_input(0); + auto h_prev = node.get_input(1); + auto w_ru = node.get_input(2); + auto w_c = node.get_input(3); + auto b_ru = node.get_input(4); + auto b_c = node.get_input(5); + + auto gru_block_cell_node = make_shared(x, h_prev, w_ru, w_c, b_ru, b_c, node.get_decoder()); + set_node_name(node.get_name(), gru_block_cell_node); + return gru_block_cell_node->outputs(); +} + +} // namespace op +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/src/frontends/tensorflow/src/op_table.cpp b/src/frontends/tensorflow/src/op_table.cpp index ae4565d112e..e6ff8ec65c3 100644 --- a/src/frontends/tensorflow/src/op_table.cpp +++ b/src/frontends/tensorflow/src/op_table.cpp @@ -58,6 +58,7 @@ OP_CONVERTER(translate_fused_batch_norm_op); OP_CONVERTER(translate_gather_op); OP_CONVERTER(translate_gather_v2_op); OP_CONVERTER(translate_gather_nd_op); +OP_CONVERTER(translate_gru_block_cell_op); OP_CONVERTER(translate_identity_op); OP_CONVERTER(translate_identity_n_op); OP_CONVERTER(translate_interpolate_op); @@ -302,6 +303,7 @@ const std::map get_supported_ops() { // Translators for internal operations {"BlockLSTM", translate_block_lstm_op}, + {"GRUBlockCell", translate_gru_block_cell_op}, {"SparseFillEmptyRows", translate_sparse_fill_empty_rows_op}, {"SparseSegmentSum", translate_sparse_segment_sum_op}, {"Unique", translate_unique_op}, diff --git a/src/frontends/tensorflow/tests/gru_block_cell_replacer.cpp b/src/frontends/tensorflow/tests/gru_block_cell_replacer.cpp new file mode 100644 index 00000000000..30c27f69898 --- /dev/null +++ b/src/frontends/tensorflow/tests/gru_block_cell_replacer.cpp @@ -0,0 +1,120 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "helper_transforms/gru_block_cell_replacer.hpp" + +#include + +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" +#include "gtest/gtest.h" +#include "helper_ops/gru_block_cell.hpp" + +using 
namespace std; +using namespace ov; +using namespace opset9; +using namespace element; +using namespace frontend::tensorflow; +using namespace frontend::tensorflow::pass; + +namespace { +shared_ptr gen_model(Dimension batch_size, int64_t hidden_size, Dimension input_size) { + auto x = make_shared(f32, PartialShape{batch_size, input_size}); + auto h_prev = make_shared(f32, PartialShape{batch_size, hidden_size}); + auto w_ru = make_shared(f32, PartialShape::dynamic()); + auto w_c = make_shared(f32, PartialShape::dynamic()); + auto b_ru = make_shared(f32, PartialShape::dynamic()); + auto b_c = make_shared(f32, PartialShape::dynamic()); + + auto gru_block_cell = make_shared(x, h_prev, w_ru, w_c, b_ru, b_c); + + return make_shared(OutputVector{gru_block_cell->output(3)}, + ParameterVector{x, h_prev, w_ru, w_c, b_ru, b_c}); +} + +shared_ptr gen_model_with_two_outputs(Dimension batch_size, int64_t hidden_size, Dimension input_size) { + auto x = make_shared(f32, PartialShape{batch_size, input_size}); + auto h_prev = make_shared(f32, PartialShape{batch_size, hidden_size}); + auto w_ru = make_shared(f32, PartialShape::dynamic()); + auto w_c = make_shared(f32, PartialShape::dynamic()); + auto b_ru = make_shared(f32, PartialShape::dynamic()); + auto b_c = make_shared(f32, PartialShape::dynamic()); + + auto gru_block_cell = make_shared(x, h_prev, w_ru, w_c, b_ru, b_c); + + return make_shared(OutputVector{gru_block_cell->output(0), gru_block_cell->output(3)}, + ParameterVector{x, h_prev, w_ru, w_c, b_ru, b_c}); +} + +shared_ptr gen_model_ref(Dimension m_batch_size, int64_t m_hidden_size, Dimension m_input_size) { + auto x = make_shared(f32, PartialShape{m_batch_size, m_input_size}); + auto h_prev = make_shared(f32, PartialShape{m_batch_size, m_hidden_size}); + auto w_ru = make_shared(f32, PartialShape::dynamic()); + auto w_c = make_shared(f32, PartialShape::dynamic()); + auto b_ru = make_shared(f32, PartialShape::dynamic()); + auto b_c = make_shared(f32, 
PartialShape::dynamic()); + + // retrieve input_size and hidden_size + auto x_shape = make_shared(x, element::i64); + auto ss_start = make_shared(element::i64, Shape{1}, 1); + auto ss_end = make_shared(element::i64, Shape{1}, 2); + auto ss_step = make_shared(element::i64, Shape{1}, 1); + auto input_size = make_shared(x_shape, ss_start, ss_end, ss_step); + auto h_prev_shape = make_shared(h_prev, element::i64); + auto hidden_size = make_shared(h_prev_shape, ss_start, ss_end, ss_step); + + // prepare weights input + // TensorFlow provides weights in a format w_ru and w_c, where + // z or u - update, r - reset, c or h - hidden (connection) + // OpenVINO GRUCell accepts weights in a format w_zrh (or w_urc) + // 1. split w_ru into w_r and w_u + auto split_w_ru = make_shared(w_ru, make_shared(element::i64, Shape{}, 1), 2); + // 2. concatenate different parts of weights into w_zrh (or w_urc) + auto w_urc = make_shared(OutputVector{split_w_ru->output(1), split_w_ru->output(0), w_c}, 1); + + // prepare bias in the same way + auto split_b_ru = make_shared(b_ru, make_shared(element::i64, Shape{}, 0), 2); + auto b_urc = make_shared(OutputVector{split_b_ru->output(1), split_b_ru->output(0), b_c}, 0); + + // transpose weights + // the current shape - [input_size + hidden_size, 3 * hidden_size] + // we need the shape [3 * hidden_size, input_size + hidden_size] + // in order to split WR into W and R + auto transpose_order = make_shared(element::i64, Shape{2}, std::vector{1, 0}); + auto w_urc_transpose = make_shared(w_urc, transpose_order); + + // split combined weights WR into W and R + auto split_axis = make_shared(element::i64, Shape{}, 1); + auto split_nums = make_shared(OutputVector{input_size, hidden_size}, 0); + auto split_WR = make_shared(w_urc_transpose, split_axis, split_nums); + + auto gru_cell = make_shared(x, h_prev, split_WR->output(0), split_WR->output(1), b_urc, m_hidden_size); + + return make_shared(OutputVector{gru_cell->output(0)}, ParameterVector{x, h_prev, w_ru,
w_c, b_ru, b_c}); +} + +} // namespace + +TEST_F(TransformationTestsF, GRUBlockCellReplacerOneOutput) { + { + function = gen_model(2, 10, 120); + manager.register_pass(); + } + { function_ref = gen_model_ref(2, 10, 120); } +} + +TEST_F(TransformationTestsF, GRUBlockCellReplacerTwoOutputs) { + { + function = gen_model_with_two_outputs(2, 10, 120); + manager.register_pass(); + } + { + // transformation is not applied due to presence of the first output + function_ref = gen_model_with_two_outputs(2, 10, 120); + } +}