[TF FE] Fix CTCLoss translator (#20775)
* Fix CTCLoss translator Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Expend layer tests for CTCLoss --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
0076f7fc00
commit
fc4fe07a0e
@ -36,7 +36,7 @@ OutputVector translate_ctc_loss_op(const NodeContext& node) {
|
|||||||
|
|
||||||
// retrieve all attributes for CTCLoss
|
// retrieve all attributes for CTCLoss
|
||||||
auto preprocess_collapse_repeated = node.get_attribute<bool>("preprocess_collapse_repeated", false);
|
auto preprocess_collapse_repeated = node.get_attribute<bool>("preprocess_collapse_repeated", false);
|
||||||
auto ctc_merge_repeated = node.get_attribute<bool>("preprocess_collapse_repeated", true);
|
auto ctc_merge_repeated = node.get_attribute<bool>("ctc_merge_repeated", true);
|
||||||
auto time_major = node.get_attribute<bool>("time_major", true);
|
auto time_major = node.get_attribute<bool>("time_major", true);
|
||||||
|
|
||||||
if (time_major) {
|
if (time_major) {
|
||||||
|
@ -18,7 +18,7 @@ class TestCTCLoss(CommonTFLayerTest):
|
|||||||
inputs_dict[input] = np.random.randint(0, 5, inputs_dict[input]).astype(np.float32)
|
inputs_dict[input] = np.random.randint(0, 5, inputs_dict[input]).astype(np.float32)
|
||||||
return inputs_dict
|
return inputs_dict
|
||||||
|
|
||||||
def create_ctcloss_placeholder_const_net(self, inputs, targets):
|
def create_ctcloss_placeholder_const_net(self, inputs, targets, preprocess_collapse_repeated, ctc_merge_repeated):
|
||||||
seq_lens = np.array([inputs[2]], dtype=np.int32)
|
seq_lens = np.array([inputs[2]], dtype=np.int32)
|
||||||
x = [targets]
|
x = [targets]
|
||||||
|
|
||||||
@ -36,7 +36,9 @@ class TestCTCLoss(CommonTFLayerTest):
|
|||||||
tf_inputs = tf.compat.v1.placeholder(tf.float32, inputs, "inputs")
|
tf_inputs = tf.compat.v1.placeholder(tf.float32, inputs, "inputs")
|
||||||
|
|
||||||
ctc_loss = tf.raw_ops.CTCLoss(inputs=tf_inputs, labels_indices=indices, labels_values=vals,
|
ctc_loss = tf.raw_ops.CTCLoss(inputs=tf_inputs, labels_indices=indices, labels_values=vals,
|
||||||
sequence_length=seq_lens)
|
sequence_length=seq_lens,
|
||||||
|
preprocess_collapse_repeated=preprocess_collapse_repeated,
|
||||||
|
ctc_merge_repeated=ctc_merge_repeated)
|
||||||
# compute exponent since CTCLoss value is -ln(prob)
|
# compute exponent since CTCLoss value is -ln(prob)
|
||||||
tf.math.exp(-ctc_loss[0])
|
tf.math.exp(-ctc_loss[0])
|
||||||
|
|
||||||
@ -54,11 +56,16 @@ class TestCTCLoss(CommonTFLayerTest):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@pytest.mark.parametrize("params", test_data)
|
@pytest.mark.parametrize("params", test_data)
|
||||||
|
@pytest.mark.parametrize("preprocess_collapse_repeated", [True, False, None])
|
||||||
|
@pytest.mark.parametrize("ctc_merge_repeated", [True, False, None])
|
||||||
@pytest.mark.precommit_tf_fe
|
@pytest.mark.precommit_tf_fe
|
||||||
@pytest.mark.nightly
|
@pytest.mark.nightly
|
||||||
@pytest.mark.skipif(platform == 'darwin', reason="Ticket - 122182")
|
@pytest.mark.skipif(platform == 'darwin', reason="Ticket - 122182")
|
||||||
def test_ctcloss_placeholder_const(self, params, ie_device, precision, ir_version, temp_dir,
|
def test_ctcloss_placeholder_const(self, params, preprocess_collapse_repeated, ctc_merge_repeated,
|
||||||
|
ie_device, precision, ir_version, temp_dir,
|
||||||
use_new_frontend, use_old_api):
|
use_new_frontend, use_old_api):
|
||||||
self._test(*self.create_ctcloss_placeholder_const_net(**params),
|
self._test(*self.create_ctcloss_placeholder_const_net(**params,
|
||||||
|
preprocess_collapse_repeated=preprocess_collapse_repeated,
|
||||||
|
ctc_merge_repeated=ctc_merge_repeated),
|
||||||
ie_device, precision, ir_version, temp_dir=temp_dir,
|
ie_device, precision, ir_version, temp_dir=temp_dir,
|
||||||
use_new_frontend=use_new_frontend, use_old_api=use_old_api, custom_eps=1e-2)
|
use_new_frontend=use_new_frontend, use_old_api=use_old_api)
|
||||||
|
Loading…
Reference in New Issue
Block a user