[TF FE] Support GRUBlockCell (#13202)

* [TF FE] Support GRUBlockCell

Currently, we support only the hidden state output from GRUBlockCell because
OpenVINO GRUCell produces only the hidden state.
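
For reference, one GRUBlockCell step computes the following (per the TensorFlow
kernel documentation; reproduced here for clarity, not part of the diff):

\begin{aligned}
[\bar{r}\ \bar{u}] &= [x\ h_{prev}]\, w_{ru} + b_{ru}, \qquad r = \sigma(\bar{r}), \quad u = \sigma(\bar{u}) \\
\bar{c} &= [x\ (h_{prev} \circ r)]\, w_{c} + b_{c}, \qquad c = \tanh(\bar{c}) \\
h &= u \circ h_{prev} + (1 - u) \circ c
\end{aligned}

Only h is mapped to OpenVINO GRUCell; r, u, and c remain unavailable.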

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Fix code-style issue

* Return Softsign translator

* Add tests for GRUBlockCellReplacer

* Fix issue for bias in GRUBlockCellReplacer

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Roman Kazantsev 2022-09-27 01:30:55 +03:00 committed by GitHub
parent e47b8858aa
commit 8406983849
7 changed files with 425 additions and 0 deletions


@@ -7,6 +7,7 @@
#include "graph_iterator_proto.hpp"
#include "helper_transforms/block_lstm_replacer.hpp"
#include "helper_transforms/embedding_segments_feature_fusing.hpp"
#include "helper_transforms/gru_block_cell_replacer.hpp"
#include "input_model.hpp"
#include "op_table.hpp"
#include "openvino/frontend/tensorflow/extension/conversion.hpp"
@@ -430,6 +431,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& function) const {
// into sub-graphs with only OpenVINO operations
manager.register_pass<ov::frontend::tensorflow::pass::EmbeddingSegmentSingleFeatureFusion>();
manager.register_pass<ov::frontend::tensorflow::pass::BlockLSTMToLSTMSequenceOneOutput>();
manager.register_pass<ov::frontend::tensorflow::pass::GRUBlockCellReplacer>();
// TODO: reimplement TransposeSinking that does not corrupt filters for Convolution
// and preserve tensor names in case of sinking


@@ -0,0 +1,115 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <string>
#include <vector>
#include "internal_operation.hpp"
namespace ov {
namespace frontend {
namespace tensorflow {
class GRUBlockCell : public InternalOperation {
public:
OPENVINO_OP("GRUBlockCell", "ov::frontend::tensorflow::util", InternalOperation);
GRUBlockCell(const Output<Node>& x,
const Output<Node>& h_prev,
const Output<Node>& w_ru,
const Output<Node>& w_c,
const Output<Node>& b_ru,
const Output<Node>& b_c,
const std::shared_ptr<DecoderBase>& decoder = std::make_shared<DecoderFake>())
: InternalOperation(decoder, OutputVector{x, h_prev, w_ru, w_c, b_ru, b_c}, 4),
m_hidden_size(ov::Dimension::dynamic()) {
validate_and_infer_types();
}
void validate_and_infer_types() override {
// GRUBlockCell computes the GRU cell forward propagation for 1 time step
// Inputs:
// 0) x: Input to the GRU cell
// 1) h_prev: State input from the previous GRU cell
// 2) w_ru: Weight matrix for the reset and update gate
// 3) w_c: Weight matrix for the cell connection gate
// 4) b_ru: Bias vector for the reset and update gate
// 5) b_c: Bias vector for the cell connection gate
//
// Outputs:
// 0) r: Output of the reset gate
// 1) u: Output of the update gate
// 2) c: Output of the cell connection gate
// 3) h: Current state of the GRU cell
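//
// Expected input shapes (implied by the shape deduction below):
// x: [batch_size, input_size], h_prev: [batch_size, hidden_size],
// w_ru: [input_size + hidden_size, 2 * hidden_size], w_c: [input_size + hidden_size, hidden_size],
// b_ru: [2 * hidden_size], b_c: [hidden_size]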
// try to deduce static hidden_size
// 1. use h_prev shape
auto h_prev_shape = get_input_partial_shape(1);
auto h_prev_rank = h_prev_shape.rank();
if (h_prev_rank.is_static()) {
FRONT_END_OP_CONVERSION_CHECK(h_prev_rank.get_length() == 2,
"Internal error in OpenVINO TensorFlow Frontend: initial hidden state for "
"GRUBlockCell must be of rank equal to 2.");
m_hidden_size = h_prev_shape[1].is_static() ? h_prev_shape[1].get_length() : m_hidden_size;
}
// 2. use w_ru shape
auto w_ru_shape = get_input_partial_shape(2);
auto w_ru_rank = w_ru_shape.rank();
if (w_ru_rank.is_static()) {
FRONT_END_OP_CONVERSION_CHECK(
w_ru_rank.get_length() == 2,
"Internal error in OpenVINO TensorFlow Frontend: weights for GRUBlockCell must be of rank equal to 2.");
m_hidden_size = w_ru_shape[1].is_static() ? w_ru_shape[1].get_length() / 2 : m_hidden_size;
}
// 3. use w_c shape
auto w_c_shape = get_input_partial_shape(3);
auto w_c_rank = w_c_shape.rank();
if (w_c_rank.is_static()) {
FRONT_END_OP_CONVERSION_CHECK(
w_c_rank.get_length() == 2,
"Internal error in OpenVINO TensorFlow Frontend: weights for GRUBlockCell must be of rank equal to 2.");
m_hidden_size = w_c_shape[1].is_static() ? w_c_shape[1].get_length() : m_hidden_size;
}
// 4. use b_ru shape
auto b_ru_shape = get_input_partial_shape(4);
auto b_ru_rank = b_ru_shape.rank();
if (b_ru_rank.is_static()) {
FRONT_END_OP_CONVERSION_CHECK(
b_ru_rank.get_length() == 1,
"Internal error in OpenVINO TensorFlow Frontend: bias for GRUBlockCell must be of rank equal to 1.");
m_hidden_size = b_ru_shape[0].is_static() ? b_ru_shape[0].get_length() / 2 : m_hidden_size;
}
// 5. use b_c shape
auto b_c_shape = get_input_partial_shape(5);
auto b_c_rank = b_c_shape.rank();
if (b_c_rank.is_static()) {
FRONT_END_OP_CONVERSION_CHECK(
b_c_rank.get_length() == 1,
"Internal error in OpenVINO TensorFlow Frontend: bias for GRUBlockCell must be of rank equal to 1.");
m_hidden_size = b_c_shape[0].is_static() ? b_c_shape[0].get_length() : m_hidden_size;
}
// set the defined shape only for the fourth output since
// OpenVINO GRUCell supports only the hidden state output
auto x_type = get_input_element_type(0);
set_output_type(0, x_type, ov::PartialShape::dynamic());
set_output_type(1, x_type, ov::PartialShape::dynamic());
set_output_type(2, x_type, ov::PartialShape::dynamic());
set_output_type(3, x_type, h_prev_shape);
}
ov::Dimension get_hidden_size() const {
// TODO: remove this method once hidden_size is no longer kept as an attribute
return m_hidden_size;
}
private:
ov::Dimension m_hidden_size;
};
} // namespace tensorflow
} // namespace frontend
} // namespace ov
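
A minimal sketch of the hidden_size deduction (shapes borrowed from the unit tests at the end of this commit; illustrative, not part of the diff):

// hidden_size is deduced from h_prev of shape [batch, hidden]; the weight and
// bias shapes are left dynamic here, so h_prev is the only deduction source
auto x = std::make_shared<ov::opset9::Parameter>(ov::element::f32, ov::PartialShape{2, 120});
auto h_prev = std::make_shared<ov::opset9::Parameter>(ov::element::f32, ov::PartialShape{2, 10});
auto w_ru = std::make_shared<ov::opset9::Parameter>(ov::element::f32, ov::PartialShape::dynamic());
auto w_c = std::make_shared<ov::opset9::Parameter>(ov::element::f32, ov::PartialShape::dynamic());
auto b_ru = std::make_shared<ov::opset9::Parameter>(ov::element::f32, ov::PartialShape::dynamic());
auto b_c = std::make_shared<ov::opset9::Parameter>(ov::element::f32, ov::PartialShape::dynamic());
auto cell = std::make_shared<ov::frontend::tensorflow::GRUBlockCell>(x, h_prev, w_ru, w_c, b_ru, b_c);
// cell->get_hidden_size() is now the static Dimension 10, taken from h_prev[1]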


@@ -0,0 +1,108 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "helper_transforms/gru_block_cell_replacer.hpp"
#include <memory>
#include <vector>
#include "helper_ops/gru_block_cell.hpp"
#include "ngraph/rt_info.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
#include "utils.hpp"
using namespace std;
using namespace ov::pass;
using namespace ov::opset9;
using namespace ov::frontend::tensorflow;
pass::GRUBlockCellReplacer::GRUBlockCellReplacer() {
auto gru_block_cell = pattern::wrap_type<GRUBlockCell>();
matcher_pass_callback callback = [=](pattern::Matcher& m) {
NodeRegistry rg;
auto gru_block_cell_node = std::dynamic_pointer_cast<GRUBlockCell>(m.get_match_root());
if (!gru_block_cell_node) {
return false;
}
// this transformation expects only one consumed output (the fourth output) - the hidden state,
// which is the only output supported by OpenVINO GRUCell
std::vector<int> restricted_output_indices = {0, 1, 2};
for (size_t output_ind : restricted_output_indices) {
if (gru_block_cell_node->output(output_ind).get_target_inputs().size() > 0) {
return false;
}
}
// currently, OpenVINO supports only a static hidden_size
auto m_hidden_size = gru_block_cell_node->get_hidden_size();
if (m_hidden_size.is_dynamic()) {
return false;
}
auto x = gru_block_cell_node->input_value(0);
auto h_prev = gru_block_cell_node->input_value(1);
auto w_ru = gru_block_cell_node->input_value(2);
auto w_c = gru_block_cell_node->input_value(3);
auto b_ru = gru_block_cell_node->input_value(4);
auto b_c = gru_block_cell_node->input_value(5);
// retrieve input_size and hidden_size
auto x_shape = rg.make<ShapeOf>(x, element::i64);
auto ss_start = rg.make<Constant>(element::i64, Shape{1}, 1);
auto ss_end = rg.make<Constant>(element::i64, Shape{1}, 2);
auto ss_step = rg.make<Constant>(element::i64, Shape{1}, 1);
auto input_size = rg.make<Slice>(x_shape, ss_start, ss_end, ss_step);
auto h_prev_shape = rg.make<ShapeOf>(h_prev, element::i64);
auto hidden_size = rg.make<Slice>(h_prev_shape, ss_start, ss_end, ss_step);
// prepare weights input
// TensorFlow provides weights as w_ru and w_c, where
// z or u - update, r - reset, c or h - hidden (connection)
// OpenVINO GRUCell accepts weights in the format w_zrh (or w_urc)
// 1. split w_ru into w_r and w_u
auto split_w_ru = rg.make<Split>(w_ru, rg.make<Constant>(element::i64, Shape{}, 1), 2);
// 2. concatenate different parts of weights into w_zrh (or w_urc)
auto w_urc = rg.make<Concat>(OutputVector{split_w_ru->output(1), split_w_ru->output(0), w_c}, 1);
// prepare bias in the same way
auto split_b_ru = rg.make<Split>(b_ru, rg.make<Constant>(element::i64, Shape{}, 0), 2);
auto b_urc = rg.make<Concat>(OutputVector{split_b_ru->output(1), split_b_ru->output(0), b_c}, 0);
// transpose weights
// the current shape - [input_size + hidden_size, 3 * hidden_size]
// we need the shape [3 * hidden_size, input_size + hidden_size]
// in order to split WR into W and R
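// worked example (illustrative): input_size = 3, hidden_size = 2 gives
// w_ru [5, 4] and w_c [5, 2] -> w_urc [5, 6] -> transposed [6, 5]
// -> VariadicSplit over axis 1 into W [6, 3] and R [6, 2]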
auto transpose_order = rg.make<Constant>(element::i64, Shape{2}, std::vector<int64_t>{1, 0});
auto w_urc_transpose = rg.make<Transpose>(w_urc, transpose_order);
// split combined weights WR into W and R
auto split_axis = rg.make<Constant>(element::i64, Shape{}, 1);
auto split_nums = rg.make<Concat>(OutputVector{input_size, hidden_size}, 0);
auto split_WR = rg.make<VariadicSplit>(w_urc_transpose, split_axis, split_nums);
auto gru_cell =
rg.make<GRUCell>(x, h_prev, split_WR->output(0), split_WR->output(1), b_urc, m_hidden_size.get_length());
// preserve names of the node and the output tensor
gru_cell->set_friendly_name(m.get_match_root()->get_friendly_name() + ":3");
copy_runtime_info(gru_block_cell_node, rg.get());
// replace GRUBlockCell with GRUCell manually instead of calling
// ov::replace_node(m.get_match_root(), gru_cell);
// because GRUBlockCell has 4 outputs and GRUCell has just one
m.get_match_root()->output(3).replace(gru_cell->output(0));
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(gru_block_cell,
"ov::frontend::tensorflow::pass::GRUBlockCellReplacer");
register_matcher(m, callback);
}


@@ -0,0 +1,29 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <utility>
#include "openvino/frontend/tensorflow/visibility.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
namespace ov {
namespace frontend {
namespace tensorflow {
namespace pass {
// This transformation handles GRUBlockCell with just one consumed output - the hidden state
class TENSORFLOW_API GRUBlockCellReplacer : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::frontend::tensorflow::pass::GRUBlockCellReplacer");
GRUBlockCellReplacer();
};
} // namespace pass
} // namespace tensorflow
} // namespace frontend
} // namespace ov
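
A minimal sketch of running the pass standalone (registration inside FrontEnd::normalize is shown in the first hunk above; the explicit Manager usage here is illustrative):

// model is a std::shared_ptr<ov::Model> that may contain internal GRUBlockCell nodes
ov::pass::Manager manager;
manager.register_pass<ov::frontend::tensorflow::pass::GRUBlockCellReplacer>();
manager.run_passes(model);
// afterwards, any GRUBlockCell whose only consumed output was the hidden state
// has been replaced with an OpenVINO GRUCell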


@@ -0,0 +1,49 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "helper_ops/gru_block_cell.hpp"
#include "op_table.hpp"
using namespace std;
using namespace ov;
using namespace ov::frontend::tensorflow;
namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
OutputVector translate_gru_block_cell_op(const NodeContext& node) {
// GRUBlockCell computes the GRU cell forward propagation for 1 time step
// Inputs:
// 0) x: Input to the GRU cell
// 1) h_prev: State input from the previous GRU cell
// 2) w_ru: Weight matrix for the reset and update gate
// 3) w_c: Weight matrix for the cell connection gate
// 4) b_ru: Bias vector for the reset and update gate
// 5) b_c: Bias vector for the cell connection gate
//
// Outputs:
// 0) r: Output of the reset gate
// 1) u: Output of the update gate
// 2) c: Output of the cell connection gate
// 3) h: Current state of the GRU cell
default_op_checks(node, 6, {"GRUBlockCell"});
auto x = node.get_input(0);
auto h_prev = node.get_input(1);
auto w_ru = node.get_input(2);
auto w_c = node.get_input(3);
auto b_ru = node.get_input(4);
auto b_c = node.get_input(5);
auto gru_block_cell_node = make_shared<GRUBlockCell>(x, h_prev, w_ru, w_c, b_ru, b_c, node.get_decoder());
set_node_name(node.get_name(), gru_block_cell_node);
return gru_block_cell_node->outputs();
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov


@@ -58,6 +58,7 @@ OP_CONVERTER(translate_fused_batch_norm_op);
OP_CONVERTER(translate_gather_op);
OP_CONVERTER(translate_gather_v2_op);
OP_CONVERTER(translate_gather_nd_op);
OP_CONVERTER(translate_gru_block_cell_op);
OP_CONVERTER(translate_identity_op);
OP_CONVERTER(translate_identity_n_op);
OP_CONVERTER(translate_interpolate_op);
@@ -302,6 +303,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
// Translators for internal operations
{"BlockLSTM", translate_block_lstm_op},
{"GRUBlockCell", translate_gru_block_cell_op},
{"SparseFillEmptyRows", translate_sparse_fill_empty_rows_op},
{"SparseSegmentSum", translate_sparse_segment_sum_op},
{"Unique", translate_unique_op},


@@ -0,0 +1,120 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "helper_transforms/gru_block_cell_replacer.hpp"
#include <gtest/gtest.h>
#include <frontend/shared/include/utils.hpp>
#include <openvino/frontend/manager.hpp>
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "gtest/gtest.h"
#include "helper_ops/gru_block_cell.hpp"
using namespace std;
using namespace ov;
using namespace opset9;
using namespace element;
using namespace frontend::tensorflow;
using namespace frontend::tensorflow::pass;
namespace {
shared_ptr<Model> gen_model(Dimension batch_size, int64_t hidden_size, Dimension input_size) {
auto x = make_shared<Parameter>(f32, PartialShape{batch_size, input_size});
auto h_prev = make_shared<Parameter>(f32, PartialShape{batch_size, hidden_size});
auto w_ru = make_shared<Parameter>(f32, PartialShape::dynamic());
auto w_c = make_shared<Parameter>(f32, PartialShape::dynamic());
auto b_ru = make_shared<Parameter>(f32, PartialShape::dynamic());
auto b_c = make_shared<Parameter>(f32, PartialShape::dynamic());
auto gru_block_cell = make_shared<GRUBlockCell>(x, h_prev, w_ru, w_c, b_ru, b_c);
return make_shared<Model>(OutputVector{gru_block_cell->output(3)},
ParameterVector{x, h_prev, w_ru, w_c, b_ru, b_c});
}
shared_ptr<Model> gen_model_with_two_outputs(Dimension batch_size, int64_t hidden_size, Dimension input_size) {
auto x = make_shared<Parameter>(f32, PartialShape{batch_size, input_size});
auto h_prev = make_shared<Parameter>(f32, PartialShape{batch_size, hidden_size});
auto w_ru = make_shared<Parameter>(f32, PartialShape::dynamic());
auto w_c = make_shared<Parameter>(f32, PartialShape::dynamic());
auto b_ru = make_shared<Parameter>(f32, PartialShape::dynamic());
auto b_c = make_shared<Parameter>(f32, PartialShape::dynamic());
auto gru_block_cell = make_shared<GRUBlockCell>(x, h_prev, w_ru, w_c, b_ru, b_c);
return make_shared<Model>(OutputVector{gru_block_cell->output(0), gru_block_cell->output(3)},
ParameterVector{x, h_prev, w_ru, w_c, b_ru, b_c});
}
shared_ptr<Model> gen_model_ref(Dimension m_batch_size, int64_t m_hidden_size, Dimension m_input_size) {
auto x = make_shared<Parameter>(f32, PartialShape{m_batch_size, m_input_size});
auto h_prev = make_shared<Parameter>(f32, PartialShape{m_batch_size, m_hidden_size});
auto w_ru = make_shared<Parameter>(f32, PartialShape::dynamic());
auto w_c = make_shared<Parameter>(f32, PartialShape::dynamic());
auto b_ru = make_shared<Parameter>(f32, PartialShape::dynamic());
auto b_c = make_shared<Parameter>(f32, PartialShape::dynamic());
// retrieve input_size and hidden_size
auto x_shape = make_shared<ShapeOf>(x, element::i64);
auto ss_start = make_shared<Constant>(element::i64, Shape{1}, 1);
auto ss_end = make_shared<Constant>(element::i64, Shape{1}, 2);
auto ss_step = make_shared<Constant>(element::i64, Shape{1}, 1);
auto input_size = make_shared<Slice>(x_shape, ss_start, ss_end, ss_step);
auto h_prev_shape = make_shared<ShapeOf>(h_prev, element::i64);
auto hidden_size = make_shared<Slice>(h_prev_shape, ss_start, ss_end, ss_step);
// prepare weights input
// TensorFlow provides weights as w_ru and w_c, where
// z or u - update, r - reset, c or h - hidden (connection)
// OpenVINO GRUCell accepts weights in the format w_zrh (or w_urc)
// 1. split w_ru into w_r and w_u
auto split_w_ru = make_shared<Split>(w_ru, make_shared<Constant>(element::i64, Shape{}, 1), 2);
// 2. concatenate different parts of weights into w_zrh (or w_urc)
auto w_urc = make_shared<Concat>(OutputVector{split_w_ru->output(1), split_w_ru->output(0), w_c}, 1);
// prepare bias in the same way
auto split_b_ru = make_shared<Split>(b_ru, make_shared<Constant>(element::i64, Shape{}, 0), 2);
auto b_urc = make_shared<Concat>(OutputVector{split_b_ru->output(1), split_b_ru->output(0), b_c}, 0);
// transpose weights
// the current shape - [input_size + hidden_size, 3 * hidden_size]
// we need the shape [3 * hidden_size, input_size + hidden_size]
// in order to split WR into W and R
auto transpose_order = make_shared<Constant>(element::i64, Shape{2}, std::vector<int64_t>{1, 0});
auto w_urc_transpose = make_shared<Transpose>(w_urc, transpose_order);
// split combined weights WR into W and R
auto split_axis = make_shared<Constant>(element::i64, Shape{}, 1);
auto split_nums = make_shared<Concat>(OutputVector{input_size, hidden_size}, 0);
auto split_WR = make_shared<VariadicSplit>(w_urc_transpose, split_axis, split_nums);
auto gru_cell = make_shared<GRUCell>(x, h_prev, split_WR->output(0), split_WR->output(1), b_urc, m_hidden_size);
return make_shared<Model>(OutputVector{gru_cell->output(0)}, ParameterVector{x, h_prev, w_ru, w_c, b_ru, b_c});
}
} // namespace
TEST_F(TransformationTestsF, GRUBlockCellReplacerOneOutput) {
{
function = gen_model(2, 10, 120);
manager.register_pass<GRUBlockCellReplacer>();
}
{ function_ref = gen_model_ref(2, 10, 120); }
}
TEST_F(TransformationTestsF, GRUBlockCellReplacerTwoOutputs) {
{
function = gen_model_with_two_outputs(2, 10, 120);
manager.register_pass<GRUBlockCellReplacer>();
}
{
// the transformation is not applied because the first output has a consumer
function_ref = gen_model_with_two_outputs(2, 10, 120);
}
}