Imironov/ref ngraph ctc gready decoder (#3867)

* Add CTC greedy decoder seq len op to ngraph

* Remove some comments

* Add second constructor

* Fix code style

* Fix code style

* Add unit tests

* Add tests to cmake

* Fix according to review

* Fix code style

* fix

* Change input layout

* Fix code style

* Add unit tests

* Add 3 input tensor check

* Update shell impl

* Fix code style

* Fix code style

* Add doxy gen

* Fix code style

* Update doxygen

* Update constructor description

* Fix code style

* Refactoring code

* fix code style

* Optimize op constructor

* Add macros. Optimize code for validate_and_infer_types

* Refactoring code

* Fix code style

* Fix code style

* Fix check of blank_index shape

* Fix code style

* Add ref impl

* Fix unit test for dynamic case

* Fix code style

* Fix copyright

* reverse changes

* Update copyright

* Add ref implementation

* rollback

* Fix code style

* Fix code style

* Fix

* Add unit tests

* Refactoring ref impl

* Refactoring code style

* Fix code style

* Fix code style

* fix unit tests

* Refactoring code

* Refactoring code

* Fix code style

* Refactoring unit tests

* Fix style

* Fix style
This commit is contained in:
iliya mironov 2021-01-25 08:19:03 +03:00 committed by GitHub
parent 94b2cc1dad
commit c083e7fb63
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 347 additions and 0 deletions

View File

@ -0,0 +1,69 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <algorithm>
#include <limits>
#include <vector>
#include "ngraph/coordinate_transform.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
/// \brief Reference implementation of greedy (best-path) CTC decoding with
///        per-batch sequence lengths.
///
/// For every batch item the class with the highest score is picked at each
/// time step. Blank classes are skipped and, when \p ctc_merge_repeated is
/// set, a class equal to the class picked at the previous time step is
/// skipped as well. The surviving class indices are written contiguously at
/// the start of the batch's output row; the rest of the row stays -1.
///
/// \param data               Scores laid out as [batch, time, class]
///                           (dimensions taken from \p data_shape).
/// \param sequence_length    Number of valid time steps for each batch item.
/// \param blank_index        Pointer to a single value: the blank class index.
/// \param out1               Decoded classes; one row of data_shape[1]
///                           elements per batch item, padded with -1.
/// \param out2               Number of decoded classes written for each
///                           batch item.
/// \param data_shape         Shape of \p data.
/// \param out_shape          Shape of \p out1 (used to size the -1 fill).
/// \param ctc_merge_repeated Collapse consecutive repeated classes when true.
template <typename TF, typename TI, typename TCI, typename TSL>
void ctc_greedy_decoder_seq_len(const TF* data,
                                const TI* sequence_length,
                                const TI* blank_index,
                                TCI* out1,
                                TSL* out2,
                                const Shape& data_shape,
                                const Shape& out_shape,
                                const bool ctc_merge_repeated)
{
    const auto batch_size = data_shape[0];
    const auto seq_len_max = data_shape[1];
    const auto class_count = data_shape[2];

    // Every slot that does not receive a decoded class must read as -1.
    std::fill_n(out1, shape_size(out_shape), static_cast<TCI>(-1));

    for (std::size_t batch_ind = 0; batch_ind < batch_size; ++batch_ind)
    {
        TI previous_class_index = static_cast<TI>(-1);
        const auto out_begin = batch_ind * seq_len_max;
        auto out_index = out_begin;

        // Clamp to the time dimension so an oversized length can neither read
        // past this batch's scores nor write past its output row.
        const auto seq_len = std::min<std::size_t>(
            static_cast<std::size_t>(sequence_length[batch_ind]), seq_len_max);

        for (std::size_t seq_ind = 0; seq_ind < seq_len; ++seq_ind)
        {
            const auto data_index =
                batch_ind * seq_len_max * class_count + seq_ind * class_count;
            const auto class_begin = data + data_index;
            const auto class_max_element =
                std::max_element(class_begin, class_begin + class_count);
            const auto max_class_ind = std::distance(class_begin, class_max_element);

            // Only the blank class itself is suppressed. The blank may sit
            // anywhere in [0, class_count), so classes with a larger index
            // than the blank are still valid output.
            if (max_class_ind != blank_index[0] &&
                !(ctc_merge_repeated && previous_class_index == max_class_ind))
            {
                out1[out_index++] = static_cast<TCI>(max_class_ind);
            }
            // Blanks are remembered too, so "a, blank, a" decodes to "a, a".
            previous_class_index = static_cast<TI>(max_class_ind);
        }

        // Report the decoded length, not the input length: blanks and merged
        // repeats make the decoded sequence shorter than seq_len.
        out2[batch_ind] = static_cast<TSL>(out_index - out_begin);
    }
}
} // namespace reference
} // namespace runtime
} // namespace ngraph

View File

@ -273,6 +273,7 @@ set(MULTI_TEST_SRC
backend/cos.in.cpp
backend/cosh.in.cpp
backend/ctc_greedy_decoder.in.cpp
backend/ctc_greedy_decoder_seq_len.in.cpp
backend/cum_sum.in.cpp
backend/detection_output.in.cpp
backend/divide.in.cpp

View File

@ -0,0 +1,184 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <cstdlib>
#include <random>
#include <string>
// clang-format off
#ifdef ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
#define DEFAULT_FLOAT_TOLERANCE_BITS ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
#endif
#ifdef ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
#define DEFAULT_DOUBLE_TOLERANCE_BITS ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
#endif
// clang-format on
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/engine/test_engines.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
using namespace std;
using namespace ngraph;
static string s_manifest = "${MANIFEST}";
using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});
// Single batch, merging disabled: only the first two of three time steps are
// valid, so two classes are decoded and the last output slot stays -1.
NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len)
{
    // Data is laid out as Shape{N, T, C}.
    const int N = 1;
    const int T = 3;
    const int C = 3;

    auto logits = make_shared<op::Parameter>(element::f32, Shape{N, T, C});
    auto lengths = make_shared<op::Parameter>(element::i32, Shape{N});
    auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
    auto decoder_op =
        make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits, lengths, blank_index, false);
    auto fn = make_shared<Function>(decoder_op, ParameterVector{logits, lengths});

    auto tc = test::TestCase<TestEngine>(fn);
    // Per-step argmax: step 0 -> class 1, step 1 -> class 0, step 2 -> class 1.
    tc.add_input<float>({0.1f, 0.2f, 0.f, 0.4f, 0.3f, 0.f, 0.5f, 0.6f, 0.f});
    tc.add_input<int32_t>({2});
    tc.add_expected_output(Shape{N, T}, vector<int32_t>{1, 0, -1});
    tc.add_expected_output(Shape{N}, vector<int32_t>{2});
    tc.run();
}
// Single batch with merge_repeated enabled.
// NOTE(review): the two valid time steps decode to classes 1 and 0, which are
// not repeated, so the merge path is not actually exercised by this input.
NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_merge)
{
    // Data is laid out as Shape{N, T, C}.
    const int N = 1;
    const int T = 3;
    const int C = 3;

    auto logits = make_shared<op::Parameter>(element::f32, Shape{N, T, C});
    auto lengths = make_shared<op::Parameter>(element::i32, Shape{N});
    auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
    auto decoder_op =
        make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits, lengths, blank_index, true);
    auto fn = make_shared<Function>(decoder_op, ParameterVector{logits, lengths});

    auto tc = test::TestCase<TestEngine>(fn);
    // Per-step argmax: step 0 -> class 1, step 1 -> class 0, step 2 -> class 1.
    tc.add_input<float>({0.1f, 0.2f, 0.f, 0.4f, 0.3f, 0.f, 0.5f, 0.6f, 0.f});
    tc.add_input<int32_t>({2});
    tc.add_expected_output(Shape{N, T}, vector<int32_t>{1, 0, -1});
    tc.add_expected_output(Shape{N}, vector<int32_t>{2});
    tc.run();
}
// Same scenario as the merge test, but the scores are supplied as float16.
NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_f16)
{
    // Data is laid out as Shape{N, T, C}.
    const int N = 1;
    const int T = 3;
    const int C = 3;

    auto logits = make_shared<op::Parameter>(element::f16, Shape{N, T, C});
    auto lengths = make_shared<op::Parameter>(element::i32, Shape{N});
    auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
    auto decoder_op =
        make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits, lengths, blank_index, true);
    auto fn = make_shared<Function>(decoder_op, ParameterVector{logits, lengths});

    auto tc = test::TestCase<TestEngine>(fn);
    // Per-step argmax: step 0 -> class 1, step 1 -> class 0, step 2 -> class 1.
    tc.add_input<float16>({0.1f, 0.2f, 0.f, 0.4f, 0.3f, 0.f, 0.5f, 0.6f, 0.f});
    tc.add_input<int32_t>({2});
    tc.add_expected_output(Shape{N, T}, vector<int32_t>{1, 0, -1});
    tc.add_expected_output(Shape{N}, vector<int32_t>{2});
    tc.run();
}
// Two batch items, each with a single valid time step: every output row holds
// one decoded class followed by -1 padding.
NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_multiple_batches)
{
    // Data is laid out as Shape{N, T, C}.
    const int N = 2;
    const int T = 3;
    const int C = 3;

    auto logits = make_shared<op::Parameter>(element::f32, Shape{N, T, C});
    auto lengths = make_shared<op::Parameter>(element::i32, Shape{N});
    auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
    auto decoder_op =
        make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits, lengths, blank_index, false);
    auto fn = make_shared<Function>(decoder_op, ParameterVector{logits, lengths});

    auto tc = test::TestCase<TestEngine>(fn);
    // One line per time step (C scores each); three steps per batch item.
    // clang-format off
    tc.add_input<float>({0.1f,  0.2f,  0.f,
                         0.15f, 0.25f, 0.f,
                         0.4f,  0.3f,  0.f,
                         0.45f, 0.35f, 0.f,
                         0.5f,  0.6f,  0.f,
                         0.55f, 0.65f, 0.f});
    // clang-format on
    tc.add_input<int32_t>({1, 1});
    // Batch 0 step 0 -> class 1; batch 1 step 0 -> class 0.
    tc.add_expected_output(Shape{N, T}, vector<int32_t>{1, -1, -1, 0, -1, -1});
    tc.add_expected_output(Shape{N}, vector<int32_t>{1, 1});
    tc.run();
}
// Three batch items with different valid lengths (2, 3 and 1 time steps),
// merging disabled, so repeated classes are kept in the output.
NGRAPH_TEST(${BACKEND_NAME}, evaluate_ctc_greedy_decoder_seq_len_multiple_batches2)
{
    // Data is laid out as Shape{N, T, C}.
    const int N = 3;
    const int T = 3;
    const int C = 3;

    auto logits = make_shared<op::Parameter>(element::f32, Shape{N, T, C});
    auto lengths = make_shared<op::Parameter>(element::i32, Shape{N});
    auto blank_index = op::Constant::create<int32_t>(element::i32, Shape{}, {2});
    auto decoder_op =
        make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits, lengths, blank_index, false);
    auto fn = make_shared<Function>(decoder_op, ParameterVector{logits, lengths});

    auto tc = test::TestCase<TestEngine>(fn);
    // One line per batch item: three time steps of C scores each.
    // clang-format off
    tc.add_input<float>({0.1f,  0.2f,  0.f, 0.15f, 0.25f, 0.f, 0.4f,  0.3f,  0.f,
                         0.45f, 0.35f, 0.f, 0.5f,  0.6f,  0.f, 0.55f, 0.65f, 0.f,
                         0.1f,  0.2f,  0.f, 0.15f, 0.25f, 0.f, 0.4f,  0.3f,  0.f});
    // clang-format on
    tc.add_input<int32_t>({2, 3, 1});
    // Rows: {1, 1, pad}, {0, 1, 1}, {1, pad, pad}.
    tc.add_expected_output(Shape{N, T}, vector<int32_t>{1, 1, -1, 0, 1, 1, 1, -1, -1});
    tc.add_expected_output(Shape{N}, vector<int32_t>{2, 3, 1});
    tc.run();
}

View File

@ -1595,6 +1595,17 @@ IE_GPU.onnx_model_gather_elements_int32_axis_0
IE_GPU.onnx_model_gather_elements_int8_axis_1
IE_GPU.onnx_model_gather_elements_float_3D_axis_2
IE_CPU.evaluate_ctc_greedy_decoder_seq_len
IE_GPU.evaluate_ctc_greedy_decoder_seq_len
IE_CPU.evaluate_ctc_greedy_decoder_seq_len_f16
IE_GPU.evaluate_ctc_greedy_decoder_seq_len_f16
IE_CPU.evaluate_ctc_greedy_decoder_seq_len_merge
IE_GPU.evaluate_ctc_greedy_decoder_seq_len_merge
IE_CPU.evaluate_ctc_greedy_decoder_seq_len_multiple_batches
IE_GPU.evaluate_ctc_greedy_decoder_seq_len_multiple_batches
IE_CPU.evaluate_ctc_greedy_decoder_seq_len_multiple_batches2
IE_GPU.evaluate_ctc_greedy_decoder_seq_len_multiple_batches2
# incorrect result for Minimum if u16 type is used
minimum_u16
IE_CPU/ElemTypesTests/1.onnx_test_add_abc_set_precission

View File

@ -27,6 +27,7 @@
#include <ngraph/runtime/reference/convert.hpp>
#include <ngraph/runtime/reference/convolution.hpp>
#include <ngraph/runtime/reference/ctc_greedy_decoder.hpp>
#include <ngraph/runtime/reference/ctc_greedy_decoder_seq_len.hpp>
#include <ngraph/runtime/reference/ctc_loss.hpp>
#include <ngraph/runtime/reference/cum_sum.hpp>
#include <ngraph/runtime/reference/detection_output.hpp>
@ -1732,6 +1733,86 @@ namespace
return true;
}
namespace ctc_greedy_decoder_v6
{
template <element::Type_t T1, element::Type_t T2, element::Type_t TOUT>
inline void evaluate(const shared_ptr<op::v6::CTCGreedyDecoderSeqLen>& op,
const HostTensorVector& outputs,
const HostTensorVector& inputs)
{
using TF = typename element_type_traits<T1>::value_type;
using TI = typename element_type_traits<T2>::value_type;
using TIND1 = typename element_type_traits<TOUT>::value_type;
if (op->get_sequence_length_type() == element::i32)
{
runtime::reference::ctc_greedy_decoder_seq_len<TF>(
inputs[0]->get_data_ptr<const TF>(),
inputs[1]->get_data_ptr<const TI>(),
inputs[2]->get_data_ptr<const TI>(),
outputs[0]->get_data_ptr<TIND1>(),
outputs[1]->get_data_ptr<int32_t>(),
inputs[0]->get_shape(),
outputs[0]->get_shape(),
op->get_merge_repeated());
}
else if (op->get_sequence_length_type() == element::i64)
{
runtime::reference::ctc_greedy_decoder_seq_len<TF>(
inputs[0]->get_data_ptr<const TF>(),
inputs[1]->get_data_ptr<const TI>(),
inputs[2]->get_data_ptr<const TI>(),
outputs[0]->get_data_ptr<TIND1>(),
outputs[1]->get_data_ptr<int64_t>(),
inputs[0]->get_shape(),
outputs[0]->get_shape(),
op->get_merge_repeated());
}
}
}
template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v6::CTCGreedyDecoderSeqLen>& op,
const HostTensorVector& outputs,
const HostTensorVector& inputs)
{
const auto& dataType = inputs[0]->get_element_type();
const auto& seqLenType = inputs[1]->get_element_type();
if (dataType == element::Type_t::f16 && seqLenType == element::Type_t::i32)
{
ctc_greedy_decoder_v6::evaluate<element::Type_t::f16, element::Type_t::i32, ET>(
op, outputs, inputs);
}
else if (dataType == element::Type_t::f32 && seqLenType == element::Type_t::i32)
{
ctc_greedy_decoder_v6::evaluate<element::Type_t::f32, element::Type_t::i32, ET>(
op, outputs, inputs);
}
else if (dataType == element::Type_t::f64 && seqLenType == element::Type_t::i32)
{
ctc_greedy_decoder_v6::evaluate<element::Type_t::f64, element::Type_t::i32, ET>(
op, outputs, inputs);
}
else if (dataType == element::Type_t::f16 && seqLenType == element::Type_t::i64)
{
ctc_greedy_decoder_v6::evaluate<element::Type_t::f16, element::Type_t::i64, ET>(
op, outputs, inputs);
}
else if (dataType == element::Type_t::f32 && seqLenType == element::Type_t::i64)
{
ctc_greedy_decoder_v6::evaluate<element::Type_t::f32, element::Type_t::i64, ET>(
op, outputs, inputs);
}
else if (dataType == element::Type_t::f64 && seqLenType == element::Type_t::i64)
{
ctc_greedy_decoder_v6::evaluate<element::Type_t::f64, element::Type_t::i64, ET>(
op, outputs, inputs);
}
else
{
return false;
}
return true;
}
template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v0::SquaredDifference>& op,
const HostTensorVector& outputs,

View File

@ -89,5 +89,6 @@ NGRAPH_OP(NonMaxSuppression, op::v5)
NGRAPH_OP(RNNSequence, op::v5)
NGRAPH_OP(Round, op::v5)
NGRAPH_OP(CTCGreedyDecoderSeqLen, op::v6)
NGRAPH_OP(GatherElements, op::v6)
NGRAPH_OP(MVN, ngraph::op::v6)