[TF FE] Fix CTCLoss translator (#20775)

* Fix CTCLoss translator

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Expend layer tests for CTCLoss

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2023-10-31 08:51:18 +04:00 committed by GitHub
parent 0076f7fc00
commit fc4fe07a0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 6 deletions

View File

@ -36,7 +36,7 @@ OutputVector translate_ctc_loss_op(const NodeContext& node) {
// retrieve all attributes for CTCLoss // retrieve all attributes for CTCLoss
auto preprocess_collapse_repeated = node.get_attribute<bool>("preprocess_collapse_repeated", false); auto preprocess_collapse_repeated = node.get_attribute<bool>("preprocess_collapse_repeated", false);
auto ctc_merge_repeated = node.get_attribute<bool>("preprocess_collapse_repeated", true); auto ctc_merge_repeated = node.get_attribute<bool>("ctc_merge_repeated", true);
auto time_major = node.get_attribute<bool>("time_major", true); auto time_major = node.get_attribute<bool>("time_major", true);
if (time_major) { if (time_major) {

View File

@ -18,7 +18,7 @@ class TestCTCLoss(CommonTFLayerTest):
inputs_dict[input] = np.random.randint(0, 5, inputs_dict[input]).astype(np.float32) inputs_dict[input] = np.random.randint(0, 5, inputs_dict[input]).astype(np.float32)
return inputs_dict return inputs_dict
def create_ctcloss_placeholder_const_net(self, inputs, targets): def create_ctcloss_placeholder_const_net(self, inputs, targets, preprocess_collapse_repeated, ctc_merge_repeated):
seq_lens = np.array([inputs[2]], dtype=np.int32) seq_lens = np.array([inputs[2]], dtype=np.int32)
x = [targets] x = [targets]
@ -36,7 +36,9 @@ class TestCTCLoss(CommonTFLayerTest):
tf_inputs = tf.compat.v1.placeholder(tf.float32, inputs, "inputs") tf_inputs = tf.compat.v1.placeholder(tf.float32, inputs, "inputs")
ctc_loss = tf.raw_ops.CTCLoss(inputs=tf_inputs, labels_indices=indices, labels_values=vals, ctc_loss = tf.raw_ops.CTCLoss(inputs=tf_inputs, labels_indices=indices, labels_values=vals,
sequence_length=seq_lens) sequence_length=seq_lens,
preprocess_collapse_repeated=preprocess_collapse_repeated,
ctc_merge_repeated=ctc_merge_repeated)
# compute exponent since CTCLoss value is -ln(prob) # compute exponent since CTCLoss value is -ln(prob)
tf.math.exp(-ctc_loss[0]) tf.math.exp(-ctc_loss[0])
@ -54,11 +56,16 @@ class TestCTCLoss(CommonTFLayerTest):
] ]
@pytest.mark.parametrize("params", test_data) @pytest.mark.parametrize("params", test_data)
@pytest.mark.parametrize("preprocess_collapse_repeated", [True, False, None])
@pytest.mark.parametrize("ctc_merge_repeated", [True, False, None])
@pytest.mark.precommit_tf_fe @pytest.mark.precommit_tf_fe
@pytest.mark.nightly @pytest.mark.nightly
@pytest.mark.skipif(platform == 'darwin', reason="Ticket - 122182") @pytest.mark.skipif(platform == 'darwin', reason="Ticket - 122182")
def test_ctcloss_placeholder_const(self, params, ie_device, precision, ir_version, temp_dir, def test_ctcloss_placeholder_const(self, params, preprocess_collapse_repeated, ctc_merge_repeated,
ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api): use_new_frontend, use_old_api):
self._test(*self.create_ctcloss_placeholder_const_net(**params), self._test(*self.create_ctcloss_placeholder_const_net(**params,
preprocess_collapse_repeated=preprocess_collapse_repeated,
ctc_merge_repeated=ctc_merge_repeated),
ie_device, precision, ir_version, temp_dir=temp_dir, ie_device, precision, ir_version, temp_dir=temp_dir,
use_new_frontend=use_new_frontend, use_old_api=use_old_api, custom_eps=1e-2) use_new_frontend=use_new_frontend, use_old_api=use_old_api)