[TF FE] Fix CTCLoss translator (#20775)
* Fix CTCLoss translator Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Expend layer tests for CTCLoss --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
0076f7fc00
commit
fc4fe07a0e
@ -36,7 +36,7 @@ OutputVector translate_ctc_loss_op(const NodeContext& node) {
|
||||
|
||||
// retrieve all attributes for CTCLoss
|
||||
auto preprocess_collapse_repeated = node.get_attribute<bool>("preprocess_collapse_repeated", false);
|
||||
auto ctc_merge_repeated = node.get_attribute<bool>("preprocess_collapse_repeated", true);
|
||||
auto ctc_merge_repeated = node.get_attribute<bool>("ctc_merge_repeated", true);
|
||||
auto time_major = node.get_attribute<bool>("time_major", true);
|
||||
|
||||
if (time_major) {
|
||||
|
@ -18,7 +18,7 @@ class TestCTCLoss(CommonTFLayerTest):
|
||||
inputs_dict[input] = np.random.randint(0, 5, inputs_dict[input]).astype(np.float32)
|
||||
return inputs_dict
|
||||
|
||||
def create_ctcloss_placeholder_const_net(self, inputs, targets):
|
||||
def create_ctcloss_placeholder_const_net(self, inputs, targets, preprocess_collapse_repeated, ctc_merge_repeated):
|
||||
seq_lens = np.array([inputs[2]], dtype=np.int32)
|
||||
x = [targets]
|
||||
|
||||
@ -36,7 +36,9 @@ class TestCTCLoss(CommonTFLayerTest):
|
||||
tf_inputs = tf.compat.v1.placeholder(tf.float32, inputs, "inputs")
|
||||
|
||||
ctc_loss = tf.raw_ops.CTCLoss(inputs=tf_inputs, labels_indices=indices, labels_values=vals,
|
||||
sequence_length=seq_lens)
|
||||
sequence_length=seq_lens,
|
||||
preprocess_collapse_repeated=preprocess_collapse_repeated,
|
||||
ctc_merge_repeated=ctc_merge_repeated)
|
||||
# compute exponent since CTCLoss value is -ln(prob)
|
||||
tf.math.exp(-ctc_loss[0])
|
||||
|
||||
@ -54,11 +56,16 @@ class TestCTCLoss(CommonTFLayerTest):
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data)
|
||||
@pytest.mark.parametrize("preprocess_collapse_repeated", [True, False, None])
|
||||
@pytest.mark.parametrize("ctc_merge_repeated", [True, False, None])
|
||||
@pytest.mark.precommit_tf_fe
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.skipif(platform == 'darwin', reason="Ticket - 122182")
|
||||
def test_ctcloss_placeholder_const(self, params, ie_device, precision, ir_version, temp_dir,
|
||||
def test_ctcloss_placeholder_const(self, params, preprocess_collapse_repeated, ctc_merge_repeated,
|
||||
ie_device, precision, ir_version, temp_dir,
|
||||
use_new_frontend, use_old_api):
|
||||
self._test(*self.create_ctcloss_placeholder_const_net(**params),
|
||||
self._test(*self.create_ctcloss_placeholder_const_net(**params,
|
||||
preprocess_collapse_repeated=preprocess_collapse_repeated,
|
||||
ctc_merge_repeated=ctc_merge_repeated),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api, custom_eps=1e-2)
|
||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
||||
|
Loading…
Reference in New Issue
Block a user