Handle 3rd optional input in CTCGreedyDecoderSeqLen evaluate (#5669)
* Handle 3rd optional input in CTCGreedyDecoderSeqLen evaluate * Add test for OP without blank_index input * Fix default blank_index calculation
This commit is contained in:
parent
97f020481a
commit
66d98530d9
@ -43,8 +43,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len)
|
||||
|
||||
auto data = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||
auto blanck_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
|
||||
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blanck_index, false);
|
||||
auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
|
||||
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blank_index, false);
|
||||
auto function = make_shared<Function>(decoder, ParameterVector{data, seq_len});
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
||||
@ -66,8 +66,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_merge)
|
||||
|
||||
auto data = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||
auto blanck_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
|
||||
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blanck_index, true);
|
||||
auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
|
||||
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blank_index, true);
|
||||
auto function = make_shared<Function>(decoder, ParameterVector{data, seq_len});
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
||||
@ -89,8 +89,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_f16)
|
||||
|
||||
auto data = make_shared<op::Parameter>(element::f16, data_shape);
|
||||
auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||
auto blanck_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
|
||||
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blanck_index, true);
|
||||
auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
|
||||
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blank_index, true);
|
||||
auto function = make_shared<Function>(decoder, ParameterVector{data, seq_len});
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
||||
@ -112,8 +112,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_multiple_batche
|
||||
|
||||
auto data = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||
auto blanck_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
|
||||
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blanck_index, false);
|
||||
auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
|
||||
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blank_index, false);
|
||||
auto function = make_shared<Function>(decoder, ParameterVector{data, seq_len});
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
||||
@ -154,8 +154,8 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_multiple_batche
|
||||
|
||||
auto data = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||
auto blanck_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
|
||||
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blanck_index, false);
|
||||
auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
|
||||
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blank_index, false);
|
||||
auto function = make_shared<Function>(decoder, ParameterVector{data, seq_len});
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
||||
@ -170,3 +170,26 @@ NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_multiple_batche
|
||||
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_no_optional_input)
|
||||
{
|
||||
const int N = 1;
|
||||
const int T = 3;
|
||||
const int C = 3;
|
||||
const auto data_shape = Shape{N, T, C};
|
||||
const auto seq_len_shape = Shape{N};
|
||||
|
||||
auto data = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||
auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, false);
|
||||
auto function = make_shared<Function>(decoder, ParameterVector{data, seq_len});
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
||||
test_case.add_input<float>({0.1f, 0.2f, 0.f, 0.4f, 0.3f, 0.f, 0.5f, 0.6f, 0.f});
|
||||
test_case.add_input<int32_t>({2});
|
||||
test_case.add_expected_output(Shape{N, T}, vector<int32_t>{1, 0, -1});
|
||||
test_case.add_expected_output(Shape{N}, vector<int32_t>{2});
|
||||
|
||||
test_case.run();
|
||||
}
|
@ -2067,12 +2067,17 @@ namespace
|
||||
using TF = typename element_type_traits<T1>::value_type;
|
||||
using TI = typename element_type_traits<T2>::value_type;
|
||||
using TIND1 = typename element_type_traits<TOUT>::value_type;
|
||||
TI blank_index_val = inputs[0]->get_shape().back() - 1;
|
||||
const TI *blank_index = &blank_index_val;
|
||||
if (inputs.size() == 3) {
|
||||
blank_index = inputs[2]->get_data_ptr<const TI>();
|
||||
}
|
||||
if (op->get_sequence_length_type() == element::i32)
|
||||
{
|
||||
runtime::reference::ctc_greedy_decoder_seq_len<TF>(
|
||||
inputs[0]->get_data_ptr<const TF>(),
|
||||
inputs[1]->get_data_ptr<const TI>(),
|
||||
inputs[2]->get_data_ptr<const TI>(),
|
||||
blank_index,
|
||||
outputs[0]->get_data_ptr<TIND1>(),
|
||||
outputs[1]->get_data_ptr<int32_t>(),
|
||||
inputs[0]->get_shape(),
|
||||
@ -2084,7 +2089,7 @@ namespace
|
||||
runtime::reference::ctc_greedy_decoder_seq_len<TF>(
|
||||
inputs[0]->get_data_ptr<const TF>(),
|
||||
inputs[1]->get_data_ptr<const TI>(),
|
||||
inputs[2]->get_data_ptr<const TI>(),
|
||||
blank_index,
|
||||
outputs[0]->get_data_ptr<TIND1>(),
|
||||
outputs[1]->get_data_ptr<int64_t>(),
|
||||
inputs[0]->get_shape(),
|
||||
|
Loading…
Reference in New Issue
Block a user