[TF FE] Added CTC Greedy Decoder layer test with fixes (#15296)
* [TF FE] Added CTC Greedy Decoder layer test with fixes * Update src/frontends/tensorflow_common/src/op/ctc_greedy_decoder.cpp * Update src/frontends/tensorflow_common/src/op/ctc_greedy_decoder.cpp * Update src/frontends/tensorflow_common/src/op/ctc_greedy_decoder.cpp * Update src/frontends/tensorflow_common/src/op/ctc_greedy_decoder.cpp * Update src/frontends/tensorflow_common/src/op/ctc_greedy_decoder.cpp * Update src/frontends/tensorflow_common/src/op/ctc_greedy_decoder.cpp * Removed ov:: and fixed codestyle Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
parent
623de0fdb4
commit
be3ed31513
@ -8,7 +8,7 @@
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
using namespace opset8;
|
||||
using namespace ov::frontend;
|
||||
using namespace frontend;
|
||||
using namespace frontend::tensorflow::detail;
|
||||
|
||||
namespace ov {
|
||||
@ -22,57 +22,57 @@ OutputVector translate_ctc_greedy_decoder_op(const NodeContext& node) {
|
||||
auto sequence_length = node.get_input(1);
|
||||
|
||||
// retrieve attribute for CTCGreedyDecoder
|
||||
auto merge_repeated = node.get_attribute<bool>("merge_repeated", true);
|
||||
auto merge_repeated = node.get_attribute<bool>("merge_repeated", false);
|
||||
auto blank_index = node.get_attribute<int64_t>("blank_index", -1);
|
||||
|
||||
// In TensorFlow the input comes in the format [time_size, batch_size, num_classes]
|
||||
// CTCGreedyDecoder expects inputs in a format [batch_size, time_size, num_classes]
|
||||
ov::AxisVector inputs_order = {1, 0, 2};
|
||||
inputs = ov::frontend::tensorflow::make_transpose(inputs, inputs_order);
|
||||
AxisVector inputs_order = {1, 0, 2};
|
||||
inputs = frontend::tensorflow::make_transpose(inputs, inputs_order);
|
||||
|
||||
shared_ptr<CTCGreedyDecoderSeqLen> ctc_greedy_decoder = nullptr;
|
||||
if (blank_index == -1) {
|
||||
// default value for blank index means it should be equal to num_classes - 1
|
||||
// in this case it is not required to specify the third input for OpenVINO CTCGreedyDecoderSeqLen
|
||||
ctc_greedy_decoder =
|
||||
make_shared<CTCGreedyDecoderSeqLen>(inputs, sequence_length, merge_repeated, ov::element::i64);
|
||||
make_shared<CTCGreedyDecoderSeqLen>(inputs, sequence_length, merge_repeated, element::i64, element::i64);
|
||||
} else {
|
||||
auto blank_index_const = make_shared<Constant>(sequence_length.get_element_type(), ov::Shape{}, blank_index);
|
||||
auto blank_index_const = make_shared<Constant>(sequence_length.get_element_type(), Shape{}, blank_index);
|
||||
ctc_greedy_decoder = make_shared<CTCGreedyDecoderSeqLen>(inputs,
|
||||
sequence_length,
|
||||
blank_index_const,
|
||||
merge_repeated,
|
||||
ov::element::i64,
|
||||
ov::element::i64);
|
||||
element::i64,
|
||||
element::i64);
|
||||
}
|
||||
|
||||
// CTCGreedyDecoderSeqLen returns dense tensor holding the decoded results.
|
||||
// We need to transform this output into a sparse format.
|
||||
auto minus_one_const = make_shared<Constant>(ctc_greedy_decoder->output(0).get_element_type(), ov::Shape{}, -1);
|
||||
auto minus_one_const = make_shared<Constant>(ctc_greedy_decoder->output(0).get_element_type(), Shape{}, -1);
|
||||
auto decoded_mask = make_shared<NotEqual>(ctc_greedy_decoder->output(0), minus_one_const);
|
||||
auto decoded_indices = make_shared<NonZero>(decoded_mask, ov::element::i64)->output(0);
|
||||
auto decoded_indices = make_shared<NonZero>(decoded_mask, element::i64)->output(0);
|
||||
|
||||
// Since the indices are in row-major format, we need to transpose them before gathering values
|
||||
auto decoded_indices_transposed = ov::frontend::tensorflow::make_transpose(decoded_indices, {1, 0});
|
||||
auto decoded_indices_transposed = frontend::tensorflow::make_transpose(decoded_indices, {1, 0});
|
||||
auto decoded_values = make_shared<GatherND>(ctc_greedy_decoder->output(0), decoded_indices_transposed);
|
||||
|
||||
// Compute the shape of the smallest dense tensor that can contain the sparse
|
||||
// matrix represented by decoded_indices_transposed and decoded_values.
|
||||
auto max_seq_len_axis = make_shared<Constant>(ov::element::i64, ov::Shape{}, 0);
|
||||
auto max_seq_len_axis = make_shared<Constant>(element::i64, Shape{}, 0);
|
||||
auto max_seq_len = make_shared<ReduceMax>(ctc_greedy_decoder->output(1), max_seq_len_axis, true);
|
||||
// inputs shape is in the form [batch_size, time_size, num_classes]
|
||||
auto inputs_shape = make_shared<ShapeOf>(inputs, ov::element::i64);
|
||||
auto slice_start = make_shared<Constant>(ov::element::i64, ov::Shape{}, 0);
|
||||
auto slice_end = make_shared<Constant>(ov::element::i64, ov::Shape{}, 1);
|
||||
auto slice_step = make_shared<Constant>(ov::element::i64, ov::Shape{}, 1);
|
||||
auto inputs_shape = make_shared<ShapeOf>(inputs, element::i64);
|
||||
auto slice_start = make_shared<Constant>(element::i64, Shape{1}, 0);
|
||||
auto slice_end = make_shared<Constant>(element::i64, Shape{1}, 1);
|
||||
auto slice_step = make_shared<Constant>(element::i64, Shape{1}, 1);
|
||||
auto batch_size = make_shared<Slice>(inputs_shape, slice_start, slice_end, slice_step);
|
||||
auto dense_shape = make_shared<Concat>(OutputVector{batch_size, max_seq_len}, 0);
|
||||
|
||||
// Compute the negative of the sum of the greatest logit at each timeframe
|
||||
// the inputs are in a form [batch_size, time_size, num_classes]
|
||||
auto max_log_probs_axis = make_shared<Constant>(ov::element::i64, ov::Shape{}, 2);
|
||||
auto max_log_probs_axis = make_shared<Constant>(element::i64, Shape{}, 2);
|
||||
auto max_log_probs = make_shared<ReduceMax>(inputs, max_log_probs_axis, false);
|
||||
auto sum_max_log_probs_axis = make_shared<Constant>(ov::element::i64, ov::Shape{}, 1);
|
||||
auto sum_max_log_probs_axis = make_shared<Constant>(element::i64, Shape{}, 1);
|
||||
auto sum_max_log_probs = make_shared<ReduceSum>(max_log_probs, sum_max_log_probs_axis, false);
|
||||
auto neg_sum_logits = make_shared<Negative>(sum_max_log_probs);
|
||||
|
||||
@ -81,7 +81,7 @@ OutputVector translate_ctc_greedy_decoder_op(const NodeContext& node) {
|
||||
set_node_name(node.get_name() + ":2", dense_shape);
|
||||
set_node_name(node.get_name() + ":3", neg_sum_logits);
|
||||
|
||||
return {decoded_indices, decoded_values, dense_shape, neg_sum_logits};
|
||||
return {decoded_indices_transposed, decoded_values, dense_shape, neg_sum_logits};
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
|
@ -0,0 +1,68 @@
|
||||
# Copyright (C) 2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from common.tf_layer_test_class import CommonTFLayerTest
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
# Testing operation CTCGreedyDecoder
|
||||
# Documentation: https://www.tensorflow.org/api_docs/python/tf/raw_ops/CTCGreedyDecoder
|
||||
|
||||
class TestCTCGreedyDecoder(CommonTFLayerTest):
    """Layer test for the TensorFlow ``CTCGreedyDecoder`` raw operation."""

    # input_shape - shape of the logits tensor fed to the decoder
    # merge_repeated - bool, enables/disables merging of repeated classes in the decoder
    # ir_version - common parameter
    # use_new_frontend - common parameter
    def create_ctcgreedydecoder_placeholder_const_net(self, input_shape, merge_repeated,
                                                      ir_version, use_new_frontend):
        """
            Tensorflow net                       IR net

            Placeholder->CTCGreedyDecoder   =>   Placeholder->Transpose->CTCGreedyDecoder->NotEqual->NonZero->Transpose

        Returns a ``(tf_net, ref_net)`` pair; ``ref_net`` is always ``None``.
        """

        if use_new_frontend == False:
            pytest.skip('Legacy path isn\'t supported by CTCGreedyDecoder')

        # sequence_length must hold one entry per batch element, so size it from the
        # batch dimension (input_shape[1]) instead of hard-coding a single element.
        # NOTE(review): the length value is taken from input_shape[2]; this only works
        # while it does not exceed the time dimension input_shape[0] - confirm intent.
        batch_size = input_shape[1]
        seq_lens = np.full([batch_size], input_shape[2], dtype=np.int32)

        tf.compat.v1.reset_default_graph()

        # Create the graph and model
        with tf.compat.v1.Session() as sess:
            tf_inputs = tf.compat.v1.placeholder(tf.float32, input_shape, "inputs")

            ctc_gd = tf.raw_ops.CTCGreedyDecoder(inputs=tf_inputs, sequence_length=seq_lens,
                                                 merge_repeated=merge_repeated)
            # Only the first decoder output is exposed as a named graph output.
            tf.identity(ctc_gd[0], name='decoded_indices')

            tf.compat.v1.global_variables_initializer()
            tf_net = sess.graph_def

        ref_net = None

        return tf_net, ref_net

    # The first case carries the precommit_tf_fe mark; the second is a plain parameter set.
    test_data = [
        pytest.param(
            dict(
                input_shape=[6, 1, 4],
            ),
            marks=pytest.mark.precommit_tf_fe),
        dict(
            input_shape=[10, 1, 7],
        ),
    ]

    @pytest.mark.parametrize("params", test_data)
    @pytest.mark.parametrize("merge_repeated", [False, True])
    @pytest.mark.nightly
    def test_ctcgreedydecoder_placeholder_const(self, params, merge_repeated, ie_device, precision, ir_version, temp_dir,
                                                use_new_frontend, use_old_api):
        # Build the TF graph for the given parameters and hand it to the common test runner.
        self._test(*self.create_ctcgreedydecoder_placeholder_const_net(**params, ir_version=ir_version,
                                                                       use_new_frontend=use_new_frontend,
                                                                       merge_repeated=merge_repeated),
                   ie_device, precision, ir_version, temp_dir=temp_dir,
                   use_new_frontend=use_new_frontend, use_old_api=use_old_api, merge_repeated=merge_repeated)
|
Loading…
Reference in New Issue
Block a user