Handle 3rd optional input in CTCGreedyDecoderSeqLen evaluate (#5669)

* Handle 3rd optional input in CTCGreedyDecoderSeqLen evaluate

* Add test for OP without blank_index input

* Fix default blank_index calculation
This commit is contained in:
Mikhail Treskin 2021-05-24 12:24:46 +03:00 committed by GitHub
parent 97f020481a
commit 66d98530d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 12 deletions

View File

@ -43,8 +43,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len)
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
auto blanck_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blanck_index, false);
auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blank_index, false);
auto function = make_shared<Function>(decoder, ParameterVector{data, seq_len});
auto test_case = test::TestCase<TestEngine>(function);
@ -66,8 +66,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_merge)
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
auto blanck_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blanck_index, true);
auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blank_index, true);
auto function = make_shared<Function>(decoder, ParameterVector{data, seq_len});
auto test_case = test::TestCase<TestEngine>(function);
@ -89,8 +89,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_f16)
auto data = make_shared<op::Parameter>(element::f16, data_shape);
auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
auto blanck_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blanck_index, true);
auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blank_index, true);
auto function = make_shared<Function>(decoder, ParameterVector{data, seq_len});
auto test_case = test::TestCase<TestEngine>(function);
@ -112,8 +112,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_multiple_batche
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
auto blanck_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blanck_index, false);
auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blank_index, false);
auto function = make_shared<Function>(decoder, ParameterVector{data, seq_len});
auto test_case = test::TestCase<TestEngine>(function);
@ -154,8 +154,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_multiple_batche
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
auto blanck_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blanck_index, false);
auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blank_index, false);
auto function = make_shared<Function>(decoder, ParameterVector{data, seq_len});
auto test_case = test::TestCase<TestEngine>(function);
@ -170,3 +170,26 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_multiple_batche
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_no_optional_input)
{
const int N = 1;
const int T = 3;
const int C = 3;
const auto data_shape = Shape{N, T, C};
const auto seq_len_shape = Shape{N};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, false);
auto function = make_shared<Function>(decoder, ParameterVector{data, seq_len});
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>({0.1f, 0.2f, 0.f, 0.4f, 0.3f, 0.f, 0.5f, 0.6f, 0.f});
test_case.add_input<int32_t>({2});
test_case.add_expected_output(Shape{N, T}, vector<int32_t>{1, 0, -1});
test_case.add_expected_output(Shape{N}, vector<int32_t>{2});
test_case.run();
}

View File

@ -2067,12 +2067,17 @@ namespace
using TF = typename element_type_traits<T1>::value_type;
using TI = typename element_type_traits<T2>::value_type;
using TIND1 = typename element_type_traits<TOUT>::value_type;
TI blank_index_val = inputs[0]->get_shape().back() - 1;
const TI *blank_index = &blank_index_val;
if (inputs.size() == 3) {
blank_index = inputs[2]->get_data_ptr<const TI>();
}
if (op->get_sequence_length_type() == element::i32)
{
runtime::reference::ctc_greedy_decoder_seq_len<TF>(
inputs[0]->get_data_ptr<const TF>(),
inputs[1]->get_data_ptr<const TI>(),
inputs[2]->get_data_ptr<const TI>(),
blank_index,
outputs[0]->get_data_ptr<TIND1>(),
outputs[1]->get_data_ptr<int32_t>(),
inputs[0]->get_shape(),
@ -2084,7 +2089,7 @@ namespace
runtime::reference::ctc_greedy_decoder_seq_len<TF>(
inputs[0]->get_data_ptr<const TF>(),
inputs[1]->get_data_ptr<const TI>(),
inputs[2]->get_data_ptr<const TI>(),
blank_index,
outputs[0]->get_data_ptr<TIND1>(),
outputs[1]->get_data_ptr<int64_t>(),
inputs[0]->get_shape(),