[TF FE] Added CTC Greedy Decoder layer test with fixes (#15296)

* [TF FE] Added CTC Greedy Decoder layer test with fixes

* Update src/frontends/tensorflow_common/src/op/ctc_greedy_decoder.cpp

* Update src/frontends/tensorflow_common/src/op/ctc_greedy_decoder.cpp

* Update src/frontends/tensorflow_common/src/op/ctc_greedy_decoder.cpp

* Update src/frontends/tensorflow_common/src/op/ctc_greedy_decoder.cpp

* Update src/frontends/tensorflow_common/src/op/ctc_greedy_decoder.cpp

* Update src/frontends/tensorflow_common/src/op/ctc_greedy_decoder.cpp

* Removed ov:: and fixed codestyle

Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
Georgy Krivoruchko 2023-01-25 18:08:39 +04:00 committed by GitHub
parent 623de0fdb4
commit be3ed31513
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 87 additions and 19 deletions

View File

@ -8,7 +8,7 @@
using namespace std;
using namespace ov;
using namespace opset8;
using namespace ov::frontend;
using namespace frontend;
using namespace frontend::tensorflow::detail;
namespace ov {
@ -22,57 +22,57 @@ OutputVector translate_ctc_greedy_decoder_op(const NodeContext& node) {
auto sequence_length = node.get_input(1);
// retrieve attribute for CTCGreedyDecoder
auto merge_repeated = node.get_attribute<bool>("merge_repeated", true);
auto merge_repeated = node.get_attribute<bool>("merge_repeated", false);
auto blank_index = node.get_attribute<int64_t>("blank_index", -1);
// In TensorFlow the input is going in a format [time_size, batch_size, num_classes]
// CTCGreedyDecoder expects inputs in a format [batch_size, time_size, num_classes]
ov::AxisVector inputs_order = {1, 0, 2};
inputs = ov::frontend::tensorflow::make_transpose(inputs, inputs_order);
AxisVector inputs_order = {1, 0, 2};
inputs = frontend::tensorflow::make_transpose(inputs, inputs_order);
shared_ptr<CTCGreedyDecoderSeqLen> ctc_greedy_decoder = nullptr;
if (blank_index == -1) {
// default value for blank index means it should be equal to num_classes - 1
// in this case it is not required to specify the third input for OpenVINO CTCGreedyDecoderSeqLen
ctc_greedy_decoder =
make_shared<CTCGreedyDecoderSeqLen>(inputs, sequence_length, merge_repeated, ov::element::i64);
make_shared<CTCGreedyDecoderSeqLen>(inputs, sequence_length, merge_repeated, element::i64, element::i64);
} else {
auto blank_index_const = make_shared<Constant>(sequence_length.get_element_type(), ov::Shape{}, blank_index);
auto blank_index_const = make_shared<Constant>(sequence_length.get_element_type(), Shape{}, blank_index);
ctc_greedy_decoder = make_shared<CTCGreedyDecoderSeqLen>(inputs,
sequence_length,
blank_index_const,
merge_repeated,
ov::element::i64,
ov::element::i64);
element::i64,
element::i64);
}
// CTCGreedyDecoderSeqLen returns dense tensor holding the decoded results.
// We need to transform this output into a sparse format.
auto minus_one_const = make_shared<Constant>(ctc_greedy_decoder->output(0).get_element_type(), ov::Shape{}, -1);
auto minus_one_const = make_shared<Constant>(ctc_greedy_decoder->output(0).get_element_type(), Shape{}, -1);
auto decoded_mask = make_shared<NotEqual>(ctc_greedy_decoder->output(0), minus_one_const);
auto decoded_indices = make_shared<NonZero>(decoded_mask, ov::element::i64)->output(0);
auto decoded_indices = make_shared<NonZero>(decoded_mask, element::i64)->output(0);
// Since the indices in row-major format, we need to transpose them before gathering values
auto decoded_indices_transposed = ov::frontend::tensorflow::make_transpose(decoded_indices, {1, 0});
auto decoded_indices_transposed = frontend::tensorflow::make_transpose(decoded_indices, {1, 0});
auto decoded_values = make_shared<GatherND>(ctc_greedy_decoder->output(0), decoded_indices_transposed);
// Compute the shape of the smallest dense tensor that can contain the sparse
// matrix represented by ng_indices and ng_values.
auto max_seq_len_axis = make_shared<Constant>(ov::element::i64, ov::Shape{}, 0);
auto max_seq_len_axis = make_shared<Constant>(element::i64, Shape{}, 0);
auto max_seq_len = make_shared<ReduceMax>(ctc_greedy_decoder->output(1), max_seq_len_axis, true);
// inputs shape is in the form [batch_size, time_size, num_classes]
auto inputs_shape = make_shared<ShapeOf>(inputs, ov::element::i64);
auto slice_start = make_shared<Constant>(ov::element::i64, ov::Shape{}, 0);
auto slice_end = make_shared<Constant>(ov::element::i64, ov::Shape{}, 1);
auto slice_step = make_shared<Constant>(ov::element::i64, ov::Shape{}, 1);
auto inputs_shape = make_shared<ShapeOf>(inputs, element::i64);
auto slice_start = make_shared<Constant>(element::i64, Shape{1}, 0);
auto slice_end = make_shared<Constant>(element::i64, Shape{1}, 1);
auto slice_step = make_shared<Constant>(element::i64, Shape{1}, 1);
auto batch_size = make_shared<Slice>(inputs_shape, slice_start, slice_end, slice_step);
auto dense_shape = make_shared<Concat>(OutputVector{batch_size, max_seq_len}, 0);
// Compute the negative of the sum of the greatest logit at each timeframe
// the inputs are in a form [batch_size, time_size, num_classes]
auto max_log_probs_axis = make_shared<Constant>(ov::element::i64, ov::Shape{}, 2);
auto max_log_probs_axis = make_shared<Constant>(element::i64, Shape{}, 2);
auto max_log_probs = make_shared<ReduceMax>(inputs, max_log_probs_axis, false);
auto sum_max_log_probs_axis = make_shared<Constant>(ov::element::i64, ov::Shape{}, 1);
auto sum_max_log_probs_axis = make_shared<Constant>(element::i64, Shape{}, 1);
auto sum_max_log_probs = make_shared<ReduceSum>(max_log_probs, sum_max_log_probs_axis, false);
auto neg_sum_logits = make_shared<Negative>(sum_max_log_probs);
@ -81,7 +81,7 @@ OutputVector translate_ctc_greedy_decoder_op(const NodeContext& node) {
set_node_name(node.get_name() + ":2", dense_shape);
set_node_name(node.get_name() + ":3", neg_sum_logits);
return {decoded_indices, decoded_values, dense_shape, neg_sum_logits};
return {decoded_indices_transposed, decoded_values, dense_shape, neg_sum_logits};
}
} // namespace op
} // namespace tensorflow

View File

@ -0,0 +1,68 @@
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from common.tf_layer_test_class import CommonTFLayerTest
import numpy as np
import tensorflow as tf
# Testing operation CTCGreedyDecoder
# Documentation: https://www.tensorflow.org/api_docs/python/tf/raw_ops/CTCGreedyDecoder
class TestCTCGreedyDecoder(CommonTFLayerTest):
# input_shape - shape a tensor for a decoder
# merge_repeated - bool, enables/disable merge repeated classes in decoder
# ir_version - common parameter
# use_new_frontend - common parameter
def create_ctcgreedydecoder_placeholder_const_net(self, input_shape, merge_repeated,
ir_version, use_new_frontend):
"""
Tensorflow net IR net
Placeholder->CTCLoss => Placeholder->Transpose->CTCGreedyDecoder->NotEqual->NonZero->Transpose
"""
if use_new_frontend == False:
pytest.skip('Legacy path isn\'t supported by CTCGreedyDecoder')
seq_lens = np.array([input_shape[2]], dtype=np.int32)
tf.compat.v1.reset_default_graph()
# Create the graph and model
with tf.compat.v1.Session() as sess:
tf_inputs = tf.compat.v1.placeholder(tf.float32, input_shape, "inputs")
ctc_gd = tf.raw_ops.CTCGreedyDecoder(inputs = tf_inputs, sequence_length = seq_lens, merge_repeated=merge_repeated)
tf.identity(ctc_gd[0], name='decoded_indices')
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
ref_net = None
return tf_net, ref_net
test_data = [
pytest.param(
dict(
input_shape = [6, 1, 4],
),
marks=pytest.mark.precommit_tf_fe),
dict(
input_shape = [10, 1, 7],
),
]
@pytest.mark.parametrize("params", test_data)
@pytest.mark.parametrize("merge_repeated", [False, True])
@pytest.mark.nightly
def test_ctcgreedydecoder_placeholder_const(self, params, merge_repeated, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):
self._test(*self.create_ctcgreedydecoder_placeholder_const_net(**params, ir_version=ir_version,
use_new_frontend=use_new_frontend, merge_repeated=merge_repeated),
ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api, merge_repeated=merge_repeated)