CTCGreedyDecoder reference implementation (#2284)

This commit is contained in:
Bartosz Lesniewski 2020-10-07 13:44:56 +02:00 committed by GitHub
parent 7a389b7ef5
commit 8f95e22a5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 313 additions and 8 deletions

View File

@ -48,7 +48,7 @@ namespace ngraph
private:
bool m_ctc_merge_repeated;
};
}
} // namespace v0
using v0::CTCGreedyDecoder;
}
}
} // namespace op
} // namespace ngraph

View File

@ -0,0 +1,84 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <algorithm>
#include <limits>
#include <vector>
#include "ngraph/coordinate_transform.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T>
void ctc_greedy_decoder(const T* data,
const T* sequence_masks,
T* out,
const Shape& data_shape,
const Shape& sequence_masks_shape,
const Shape& out_shape,
const bool ctc_merge_repeated)
{
auto max_seq_len = data_shape[0];
auto batch_size = data_shape[1];
auto class_count = data_shape[2];
CoordinateTransform out_transform = CoordinateTransform(out_shape);
CoordinateTransform data_transform = CoordinateTransform(data_shape);
CoordinateTransform seq_masks_transform = CoordinateTransform(sequence_masks_shape);
// final sequences don't have to fill the whole output, elements that don't store
// information are set to -1
std::vector<T> tmp_out(shape_size(out_shape));
std::fill(tmp_out.begin(), tmp_out.end(), static_cast<T>(-1.0));
for (unsigned int batch_ind = 0; batch_ind < batch_size; batch_ind++)
{
T previous_class_index = static_cast<T>(-1);
auto out_index = out_transform.index({batch_ind, 0, 0, 0});
for (unsigned int seq_ind = 0; seq_ind < max_seq_len; seq_ind++)
{
auto data_index = data_transform.index({seq_ind, batch_ind, 0});
auto mask_index = seq_masks_transform.index({seq_ind, batch_ind});
// first 0 marks the end of a sequence
if (std::abs(static_cast<double>(sequence_masks[mask_index] -
static_cast<T>(1))) >
std::numeric_limits<double>::epsilon())
{
continue;
}
auto class_index = data + data_index;
auto class_max_element =
std::max_element(class_index, class_index + class_count);
unsigned int max_class_ind = std::distance(class_index, class_max_element);
if (!(previous_class_index == max_class_ind && ctc_merge_repeated))
{
tmp_out[out_index++] = max_class_ind;
}
previous_class_index = max_class_ind;
}
}
std::copy(tmp_out.begin(), tmp_out.end(), out);
}
} // namespace reference
} // namespace runtime
} // namespace ngraph

View File

@ -273,6 +273,7 @@ set(MULTI_TEST_SRC
backend/convolution.in.cpp
backend/cos.in.cpp
backend/cosh.in.cpp
backend/ctc_greedy_decoder.in.cpp
backend/cum_sum.in.cpp
backend/divide.in.cpp
backend/dot.in.cpp

View File

@ -0,0 +1,197 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <cstdlib>
#include <random>
#include <string>
// clang-format off
#ifdef ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
#define DEFAULT_FLOAT_TOLERANCE_BITS ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
#endif
#ifdef ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
#define DEFAULT_DOUBLE_TOLERANCE_BITS ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
#endif
// clang-format on
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/engine/test_engines.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
using namespace std;
using namespace ngraph;
static string s_manifest = "${MANIFEST}";
using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});
NGRAPH_TEST(${BACKEND_NAME}, ctc_greedy_decoder)
{
const int T = 3;
const int N = 1;
const int C = 2;
const auto data_shape = Shape{T, N, C};
const auto masks_shape = Shape{T, N};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto masks = make_shared<op::Parameter>(element::f32, masks_shape);
auto decoder = make_shared<op::CTCGreedyDecoder>(data, masks, false);
auto function = make_shared<Function>(decoder, ParameterVector{data, masks});
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>({0.1f, 0.2f, 0.4f, 0.3f, 0.5f, 0.6f});
test_case.add_input<float>({1.0f, 1.0f, 1.0f});
test_case.add_expected_output(Shape{N, T, 1, 1}, vector<float>{1.0f, 0.0f, 1.0f});
test_case.run_with_tolerance_as_fp(1.0e-4f);
}
NGRAPH_TEST(${BACKEND_NAME}, ctc_greedy_decoder_f16)
{
const int T = 3;
const int N = 1;
const int C = 2;
const auto data_shape = Shape{T, N, C};
const auto masks_shape = Shape{T, N};
auto data = make_shared<op::Parameter>(element::f16, data_shape);
auto masks = make_shared<op::Parameter>(element::f16, masks_shape);
auto decoder = make_shared<op::CTCGreedyDecoder>(data, masks, false);
auto function = make_shared<Function>(decoder, ParameterVector{data, masks});
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float16>({0.1f, 0.2f, 0.4f, 0.3f, 0.5f, 0.6f});
test_case.add_input<float16>({1.0f, 1.0f, 1.0f});
test_case.add_expected_output(Shape{N, T, 1, 1}, vector<float16>{1.0f, 0.0f, 1.0f});
test_case.run_with_tolerance_as_fp(1.0e-4f);
}
NGRAPH_TEST(${BACKEND_NAME}, ctc_greedy_decoder_multiple_batches)
{
const int T = 3;
const int N = 2;
const int C = 2;
const auto data_shape = Shape{T, N, C};
const auto masks_shape = Shape{T, N};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto masks = make_shared<op::Parameter>(element::f32, masks_shape);
auto decoder = make_shared<op::CTCGreedyDecoder>(data, masks, false);
auto function = make_shared<Function>(decoder, ParameterVector{data, masks});
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>(
{0.1f, 0.2f, 0.15f, 0.25f, 0.4f, 0.3f, 0.45f, 0.35f, 0.5f, 0.6f, 0.55f, 0.65f});
test_case.add_input<float>({1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
test_case.add_expected_output(Shape{N, T, 1, 1},
vector<float>{1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f});
test_case.run_with_tolerance_as_fp(1.0e-4f);
}
NGRAPH_TEST(${BACKEND_NAME}, ctc_greedy_decoder_single_batch_short_sequence)
{
const int T = 3;
const int N = 1;
const int C = 2;
const auto data_shape = Shape{T, N, C};
const auto masks_shape = Shape{T, N};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto masks = make_shared<op::Parameter>(element::f32, masks_shape);
auto decoder = make_shared<op::CTCGreedyDecoder>(data, masks, false);
auto function = make_shared<Function>(decoder, ParameterVector{data, masks});
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>({0.1f, 0.2f, 0.4f, 0.3f, 0.5f, 0.6f});
test_case.add_input<float>({1.0f, 1.0f, 0.0f});
test_case.add_expected_output(Shape{N, T, 1, 1}, vector<float>{1.0f, 0.0f, -1.0f});
test_case.run_with_tolerance_as_fp(1.0e-4f);
}
NGRAPH_TEST(${BACKEND_NAME}, ctc_greedy_decoder_merge)
{
const int T = 3;
const int N = 1;
const int C = 2;
const auto data_shape = Shape{T, N, C};
const auto masks_shape = Shape{T, N};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto masks = make_shared<op::Parameter>(element::f32, masks_shape);
auto decoder = make_shared<op::CTCGreedyDecoder>(data, masks, true);
auto function = make_shared<Function>(decoder, ParameterVector{data, masks});
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>({0.1f, 0.2f, 0.3f, 0.4f, 0.6f, 0.5f});
test_case.add_input<float>({1.0f, 1.0f, 1.0f});
test_case.add_expected_output(Shape{N, T, 1, 1}, vector<float>{1.0f, 0.0f, -1.0f});
test_case.run_with_tolerance_as_fp(1.0e-4f);
}
NGRAPH_TEST(${BACKEND_NAME}, ctc_greedy_decoder_single_no_merge)
{
const int T = 3;
const int N = 1;
const int C = 2;
const auto data_shape = Shape{T, N, C};
const auto masks_shape = Shape{T, N};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto masks = make_shared<op::Parameter>(element::f32, masks_shape);
auto decoder = make_shared<op::CTCGreedyDecoder>(data, masks, false);
auto function = make_shared<Function>(decoder, ParameterVector{data, masks});
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>({0.1f, 0.2f, 0.3f, 0.4f, 0.6f, 0.5f});
test_case.add_input<float>({1.0f, 1.0f, 1.0f});
test_case.add_expected_output(Shape{N, T, 1, 1}, vector<float>{1.0f, 1.0f, 0.0f});
test_case.run_with_tolerance_as_fp(1.0e-4f);
}
NGRAPH_TEST(${BACKEND_NAME}, ctc_greedy_decoder_multiple_sequences)
{
const int T = 2;
const int N = 2;
const int C = 2;
const auto data_shape = Shape{T, N, C};
const auto masks_shape = Shape{T, N};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
auto masks = make_shared<op::Parameter>(element::f32, masks_shape);
auto decoder = make_shared<op::CTCGreedyDecoder>(data, masks, false);
auto function = make_shared<Function>(decoder, ParameterVector{data, masks});
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>({0.1f, 0.2f, 0.4f, 0.3f, 0.5f, 0.6f, 0.7f, 0.8f});
test_case.add_input<float>({1.0f, 1.0f, 1.0f, 0.0f});
test_case.add_expected_output(Shape{N, T, 1, 1}, vector<float>{1.0f, 1.0f, 0.0f, -1.0f});
test_case.run_with_tolerance_as_fp(1.0e-4f);
}

View File

@ -982,7 +982,7 @@ namespace
EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
EXPECT_TRUE(op::is_binary_elementwise_logical(&node));
}
}
} // namespace
TEST(op_is, check)
{

View File

@ -1116,6 +1116,9 @@ IE_CPU.builder_opset1_collapse_dyn_shape
# Y because output with index 0 contains dynamic shapes: {?,?,?,?}
IE_CPU.onnx_resize11_scales_nearest_asymmetric_floor_dynamic_sizes
# Input data precision not supported. Expected float.
ctc_greedy_decoder_f16
#-------------------------------------------------------------------------------
#
# Inference Engine GPU plugin excludes
@ -1201,6 +1204,12 @@ IE_GPU.onnx_dyn_shapes_avg_pool_dyn_shape
IE_GPU.onnx_dyn_shapes_max_pool_dyn_shape
IE_GPU.onnx_dyn_shapes_global_avg_pool_dyn_shape
IE_GPU.onnx_dyn_shapes_global_max_pool_dyn_shape
IE_GPU.ctc_greedy_decoder
IE_GPU.ctc_greedy_decoder_multiple_batches
IE_GPU.ctc_greedy_decoder_single_batch_short_sequence
IE_GPU.ctc_greedy_decoder_merge
IE_GPU.ctc_greedy_decoder_single_no_merge
IE_GPU.ctc_greedy_decoder_multiple_sequences
IE_GPU.onnx_roi_align_f32
IE_GPU.tanh
IE_GPU.tan

View File

@ -44,6 +44,7 @@
#include "ngraph/runtime/reference/convolution.hpp"
#include "ngraph/runtime/reference/cos.hpp"
#include "ngraph/runtime/reference/cosh.hpp"
#include "ngraph/runtime/reference/ctc_greedy_decoder.hpp"
#include "ngraph/runtime/reference/ctc_loss.hpp"
#include "ngraph/runtime/reference/cum_sum.hpp"
#include "ngraph/runtime/reference/dequantize.hpp"
@ -60,7 +61,6 @@
#include "ngraph/runtime/reference/gather.hpp"
#include "ngraph/runtime/reference/gather_nd.hpp"
#include "ngraph/runtime/reference/gather_tree.hpp"
#include "ngraph/runtime/reference/gather_tree.hpp"
#include "ngraph/runtime/reference/gru_cell.hpp"
#include "ngraph/runtime/reference/log.hpp"
#include "ngraph/runtime/reference/lrn.hpp"
@ -397,6 +397,18 @@ protected:
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::CTCGreedyDecoder_v0:
{
const auto ctc_greedy_dec = static_cast<const op::v0::CTCGreedyDecoder*>(&node);
reference::ctc_greedy_decoder<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
args[0]->get_shape(),
args[1]->get_shape(),
out[0]->get_shape(),
ctc_greedy_dec->get_ctc_merge_repeated());
break;
}
case OP_TYPEID::CTCLoss_v4:
{
const op::v4::CTCLoss* ctc_loss = static_cast<const op::v4::CTCLoss*>(&node);
@ -809,9 +821,7 @@ protected:
gru_seq->get_activations()[1],
gru_seq->get_clip(),
gru_seq->get_direction(),
gru_seq->get_linear_before_reset()
);
gru_seq->get_linear_before_reset());
break;
}
case OP_TYPEID::RNNSequence_v5:

View File

@ -19,6 +19,7 @@
#undef ID_SUFFIX
#define ID_SUFFIX(NAME) NAME##_v0
NGRAPH_OP(CTCGreedyDecoder, ngraph::op::v0)
NGRAPH_OP(DetectionOutput, op::v0)
NGRAPH_OP(RNNCell, op::v0)
#undef ID_SUFFIX

View File

@ -138,3 +138,6 @@ onnx_model_gru_fwd_activations
# Peepholes, input_forget are not supported
lstm_cell_bias_peepholes
lstm_cell_bias_peepholes_clip_input_forget
# unsupported element type f16
INTERPRETER.ctc_greedy_decoder_f16