CTCGreedyDecoder reference implementation (#2284)
This commit is contained in:
parent
7a389b7ef5
commit
8f95e22a5c
@ -48,7 +48,7 @@ namespace ngraph
|
||||
private:
|
||||
bool m_ctc_merge_repeated;
|
||||
};
|
||||
}
|
||||
} // namespace v0
|
||||
using v0::CTCGreedyDecoder;
|
||||
}
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -0,0 +1,84 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
#include "ngraph/coordinate_transform.hpp"
|
||||
namespace ngraph
|
||||
{
|
||||
namespace runtime
|
||||
{
|
||||
namespace reference
|
||||
{
|
||||
template <typename T>
|
||||
void ctc_greedy_decoder(const T* data,
|
||||
const T* sequence_masks,
|
||||
T* out,
|
||||
const Shape& data_shape,
|
||||
const Shape& sequence_masks_shape,
|
||||
const Shape& out_shape,
|
||||
const bool ctc_merge_repeated)
|
||||
{
|
||||
auto max_seq_len = data_shape[0];
|
||||
auto batch_size = data_shape[1];
|
||||
auto class_count = data_shape[2];
|
||||
|
||||
CoordinateTransform out_transform = CoordinateTransform(out_shape);
|
||||
CoordinateTransform data_transform = CoordinateTransform(data_shape);
|
||||
CoordinateTransform seq_masks_transform = CoordinateTransform(sequence_masks_shape);
|
||||
|
||||
// final sequences don't have to fill the whole output, elements that don't store
|
||||
// information are set to -1
|
||||
|
||||
std::vector<T> tmp_out(shape_size(out_shape));
|
||||
std::fill(tmp_out.begin(), tmp_out.end(), static_cast<T>(-1.0));
|
||||
|
||||
for (unsigned int batch_ind = 0; batch_ind < batch_size; batch_ind++)
|
||||
{
|
||||
T previous_class_index = static_cast<T>(-1);
|
||||
auto out_index = out_transform.index({batch_ind, 0, 0, 0});
|
||||
for (unsigned int seq_ind = 0; seq_ind < max_seq_len; seq_ind++)
|
||||
{
|
||||
auto data_index = data_transform.index({seq_ind, batch_ind, 0});
|
||||
auto mask_index = seq_masks_transform.index({seq_ind, batch_ind});
|
||||
|
||||
// first 0 marks the end of a sequence
|
||||
if (std::abs(static_cast<double>(sequence_masks[mask_index] -
|
||||
static_cast<T>(1))) >
|
||||
std::numeric_limits<double>::epsilon())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
auto class_index = data + data_index;
|
||||
auto class_max_element =
|
||||
std::max_element(class_index, class_index + class_count);
|
||||
unsigned int max_class_ind = std::distance(class_index, class_max_element);
|
||||
if (!(previous_class_index == max_class_ind && ctc_merge_repeated))
|
||||
{
|
||||
tmp_out[out_index++] = max_class_ind;
|
||||
}
|
||||
previous_class_index = max_class_ind;
|
||||
}
|
||||
}
|
||||
std::copy(tmp_out.begin(), tmp_out.end(), out);
|
||||
}
|
||||
} // namespace reference
|
||||
} // namespace runtime
|
||||
} // namespace ngraph
|
@ -273,6 +273,7 @@ set(MULTI_TEST_SRC
|
||||
backend/convolution.in.cpp
|
||||
backend/cos.in.cpp
|
||||
backend/cosh.in.cpp
|
||||
backend/ctc_greedy_decoder.in.cpp
|
||||
backend/cum_sum.in.cpp
|
||||
backend/divide.in.cpp
|
||||
backend/dot.in.cpp
|
||||
|
197
ngraph/test/backend/ctc_greedy_decoder.in.cpp
Normal file
197
ngraph/test/backend/ctc_greedy_decoder.in.cpp
Normal file
@ -0,0 +1,197 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <random>
|
||||
#include <string>
|
||||
|
||||
// clang-format off
|
||||
#ifdef ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
|
||||
#define DEFAULT_FLOAT_TOLERANCE_BITS ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
|
||||
#endif
|
||||
|
||||
#ifdef ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
|
||||
#define DEFAULT_DOUBLE_TOLERANCE_BITS ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/engine/test_engines.hpp"
|
||||
#include "util/test_case.hpp"
|
||||
#include "util/test_control.hpp"
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
static string s_manifest = "${MANIFEST}";
|
||||
using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});
|
||||
|
||||
// Single batch entry, three unmasked time steps, repeats not merged:
// the per-step argmax over the two classes (1, 0, 1) is emitted as-is.
NGRAPH_TEST(${BACKEND_NAME}, ctc_greedy_decoder)
{
    const size_t seq_len = 3;
    const size_t batch = 1;
    const size_t classes = 2;
    const Shape scores_shape{seq_len, batch, classes};
    const Shape mask_shape{seq_len, batch};
    const Shape decoded_shape{batch, seq_len, 1, 1};

    auto scores = make_shared<op::Parameter>(element::f32, scores_shape);
    auto mask = make_shared<op::Parameter>(element::f32, mask_shape);
    auto decoder = make_shared<op::CTCGreedyDecoder>(scores, mask, false);
    auto function = make_shared<Function>(decoder, ParameterVector{scores, mask});

    const vector<float> score_values{0.1f, 0.2f, 0.4f, 0.3f, 0.5f, 0.6f};
    const vector<float> mask_values{1.0f, 1.0f, 1.0f};
    const vector<float> expected{1.0f, 0.0f, 1.0f};

    test::TestCase<TestEngine> test_case(function);
    test_case.add_input<float>(score_values);
    test_case.add_input<float>(mask_values);
    test_case.add_expected_output(decoded_shape, expected);

    test_case.run_with_tolerance_as_fp(1.0e-4f);
}
|
||||
|
||||
// Same scenario as the basic f32 case (one batch entry, three unmasked
// steps, no merging) but with f16 inputs and outputs.
NGRAPH_TEST(${BACKEND_NAME}, ctc_greedy_decoder_f16)
{
    const size_t seq_len = 3;
    const size_t batch = 1;
    const size_t classes = 2;
    const Shape scores_shape{seq_len, batch, classes};
    const Shape mask_shape{seq_len, batch};
    const Shape decoded_shape{batch, seq_len, 1, 1};

    auto scores = make_shared<op::Parameter>(element::f16, scores_shape);
    auto mask = make_shared<op::Parameter>(element::f16, mask_shape);
    auto decoder = make_shared<op::CTCGreedyDecoder>(scores, mask, false);
    auto function = make_shared<Function>(decoder, ParameterVector{scores, mask});

    const vector<float16> score_values{0.1f, 0.2f, 0.4f, 0.3f, 0.5f, 0.6f};
    const vector<float16> mask_values{1.0f, 1.0f, 1.0f};
    const vector<float16> expected{1.0f, 0.0f, 1.0f};

    test::TestCase<TestEngine> test_case(function);
    test_case.add_input<float16>(score_values);
    test_case.add_input<float16>(mask_values);
    test_case.add_expected_output(decoded_shape, expected);

    test_case.run_with_tolerance_as_fp(1.0e-4f);
}
|
||||
|
||||
// Two batch entries, three unmasked time steps each, repeats not merged.
// Both entries have the per-step argmax 1, 0, 1, so both output rows match.
NGRAPH_TEST(${BACKEND_NAME}, ctc_greedy_decoder_multiple_batches)
{
    const size_t seq_len = 3;
    const size_t batch = 2;
    const size_t classes = 2;
    const Shape scores_shape{seq_len, batch, classes};
    const Shape mask_shape{seq_len, batch};
    const Shape decoded_shape{batch, seq_len, 1, 1};

    auto scores = make_shared<op::Parameter>(element::f32, scores_shape);
    auto mask = make_shared<op::Parameter>(element::f32, mask_shape);
    auto decoder = make_shared<op::CTCGreedyDecoder>(scores, mask, false);
    auto function = make_shared<Function>(decoder, ParameterVector{scores, mask});

    const vector<float> score_values{
        0.1f, 0.2f, 0.15f, 0.25f, 0.4f, 0.3f, 0.45f, 0.35f, 0.5f, 0.6f, 0.55f, 0.65f};
    const vector<float> mask_values{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
    const vector<float> expected{1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f};

    test::TestCase<TestEngine> test_case(function);
    test_case.add_input<float>(score_values);
    test_case.add_input<float>(mask_values);
    test_case.add_expected_output(decoded_shape, expected);

    test_case.run_with_tolerance_as_fp(1.0e-4f);
}
|
||||
|
||||
// One batch entry whose last time step is masked out (mask 1, 1, 0):
// only the first two steps are decoded and the remaining slot stays at -1.
NGRAPH_TEST(${BACKEND_NAME}, ctc_greedy_decoder_single_batch_short_sequence)
{
    const size_t seq_len = 3;
    const size_t batch = 1;
    const size_t classes = 2;
    const Shape scores_shape{seq_len, batch, classes};
    const Shape mask_shape{seq_len, batch};
    const Shape decoded_shape{batch, seq_len, 1, 1};

    auto scores = make_shared<op::Parameter>(element::f32, scores_shape);
    auto mask = make_shared<op::Parameter>(element::f32, mask_shape);
    auto decoder = make_shared<op::CTCGreedyDecoder>(scores, mask, false);
    auto function = make_shared<Function>(decoder, ParameterVector{scores, mask});

    const vector<float> score_values{0.1f, 0.2f, 0.4f, 0.3f, 0.5f, 0.6f};
    const vector<float> mask_values{1.0f, 1.0f, 0.0f};
    const vector<float> expected{1.0f, 0.0f, -1.0f};

    test::TestCase<TestEngine> test_case(function);
    test_case.add_input<float>(score_values);
    test_case.add_input<float>(mask_values);
    test_case.add_expected_output(decoded_shape, expected);

    test_case.run_with_tolerance_as_fp(1.0e-4f);
}
|
||||
|
||||
// Merging enabled: the per-step argmax is 1, 1, 0, so the repeated 1 is
// collapsed and the output row reads 1, 0 followed by a -1 filler.
NGRAPH_TEST(${BACKEND_NAME}, ctc_greedy_decoder_merge)
{
    const size_t seq_len = 3;
    const size_t batch = 1;
    const size_t classes = 2;
    const Shape scores_shape{seq_len, batch, classes};
    const Shape mask_shape{seq_len, batch};
    const Shape decoded_shape{batch, seq_len, 1, 1};

    auto scores = make_shared<op::Parameter>(element::f32, scores_shape);
    auto mask = make_shared<op::Parameter>(element::f32, mask_shape);
    auto decoder = make_shared<op::CTCGreedyDecoder>(scores, mask, true);
    auto function = make_shared<Function>(decoder, ParameterVector{scores, mask});

    const vector<float> score_values{0.1f, 0.2f, 0.3f, 0.4f, 0.6f, 0.5f};
    const vector<float> mask_values{1.0f, 1.0f, 1.0f};
    const vector<float> expected{1.0f, 0.0f, -1.0f};

    test::TestCase<TestEngine> test_case(function);
    test_case.add_input<float>(score_values);
    test_case.add_input<float>(mask_values);
    test_case.add_expected_output(decoded_shape, expected);

    test_case.run_with_tolerance_as_fp(1.0e-4f);
}
|
||||
|
||||
// Same scores as the merge case (per-step argmax 1, 1, 0) but with merging
// disabled, so the repeated class is kept and the row reads 1, 1, 0.
NGRAPH_TEST(${BACKEND_NAME}, ctc_greedy_decoder_single_no_merge)
{
    const size_t seq_len = 3;
    const size_t batch = 1;
    const size_t classes = 2;
    const Shape scores_shape{seq_len, batch, classes};
    const Shape mask_shape{seq_len, batch};
    const Shape decoded_shape{batch, seq_len, 1, 1};

    auto scores = make_shared<op::Parameter>(element::f32, scores_shape);
    auto mask = make_shared<op::Parameter>(element::f32, mask_shape);
    auto decoder = make_shared<op::CTCGreedyDecoder>(scores, mask, false);
    auto function = make_shared<Function>(decoder, ParameterVector{scores, mask});

    const vector<float> score_values{0.1f, 0.2f, 0.3f, 0.4f, 0.6f, 0.5f};
    const vector<float> mask_values{1.0f, 1.0f, 1.0f};
    const vector<float> expected{1.0f, 1.0f, 0.0f};

    test::TestCase<TestEngine> test_case(function);
    test_case.add_input<float>(score_values);
    test_case.add_input<float>(mask_values);
    test_case.add_expected_output(decoded_shape, expected);

    test_case.run_with_tolerance_as_fp(1.0e-4f);
}
|
||||
|
||||
// Two batch entries with different effective lengths. The first entry is
// fully unmasked (argmax 1, 1); the second is masked at its last step, so
// it decodes a single class (0) and its remaining slot stays at -1.
NGRAPH_TEST(${BACKEND_NAME}, ctc_greedy_decoder_multiple_sequences)
{
    const size_t seq_len = 2;
    const size_t batch = 2;
    const size_t classes = 2;
    const Shape scores_shape{seq_len, batch, classes};
    const Shape mask_shape{seq_len, batch};
    const Shape decoded_shape{batch, seq_len, 1, 1};

    auto scores = make_shared<op::Parameter>(element::f32, scores_shape);
    auto mask = make_shared<op::Parameter>(element::f32, mask_shape);
    auto decoder = make_shared<op::CTCGreedyDecoder>(scores, mask, false);
    auto function = make_shared<Function>(decoder, ParameterVector{scores, mask});

    const vector<float> score_values{0.1f, 0.2f, 0.4f, 0.3f, 0.5f, 0.6f, 0.7f, 0.8f};
    const vector<float> mask_values{1.0f, 1.0f, 1.0f, 0.0f};
    const vector<float> expected{1.0f, 1.0f, 0.0f, -1.0f};

    test::TestCase<TestEngine> test_case(function);
    test_case.add_input<float>(score_values);
    test_case.add_input<float>(mask_values);
    test_case.add_expected_output(decoded_shape, expected);

    test_case.run_with_tolerance_as_fp(1.0e-4f);
}
|
@ -982,7 +982,7 @@ namespace
|
||||
EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
|
||||
EXPECT_TRUE(op::is_binary_elementwise_logical(&node));
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(op_is, check)
|
||||
{
|
||||
|
@ -1116,6 +1116,9 @@ IE_CPU.builder_opset1_collapse_dyn_shape
|
||||
# Y because output with index 0 contains dynamic shapes: {?,?,?,?}
|
||||
IE_CPU.onnx_resize11_scales_nearest_asymmetric_floor_dynamic_sizes
|
||||
|
||||
# Input data precision not supported. Expected float.
|
||||
ctc_greedy_decoder_f16
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
#
|
||||
# Inference Engine GPU plugin excludes
|
||||
@ -1201,6 +1204,12 @@ IE_GPU.onnx_dyn_shapes_avg_pool_dyn_shape
|
||||
IE_GPU.onnx_dyn_shapes_max_pool_dyn_shape
|
||||
IE_GPU.onnx_dyn_shapes_global_avg_pool_dyn_shape
|
||||
IE_GPU.onnx_dyn_shapes_global_max_pool_dyn_shape
|
||||
IE_GPU.ctc_greedy_decoder
|
||||
IE_GPU.ctc_greedy_decoder_multiple_batches
|
||||
IE_GPU.ctc_greedy_decoder_single_batch_short_sequence
|
||||
IE_GPU.ctc_greedy_decoder_merge
|
||||
IE_GPU.ctc_greedy_decoder_single_no_merge
|
||||
IE_GPU.ctc_greedy_decoder_multiple_sequences
|
||||
IE_GPU.onnx_roi_align_f32
|
||||
IE_GPU.tanh
|
||||
IE_GPU.tan
|
||||
|
@ -44,6 +44,7 @@
|
||||
#include "ngraph/runtime/reference/convolution.hpp"
|
||||
#include "ngraph/runtime/reference/cos.hpp"
|
||||
#include "ngraph/runtime/reference/cosh.hpp"
|
||||
#include "ngraph/runtime/reference/ctc_greedy_decoder.hpp"
|
||||
#include "ngraph/runtime/reference/ctc_loss.hpp"
|
||||
#include "ngraph/runtime/reference/cum_sum.hpp"
|
||||
#include "ngraph/runtime/reference/dequantize.hpp"
|
||||
@ -60,7 +61,6 @@
|
||||
#include "ngraph/runtime/reference/gather.hpp"
|
||||
#include "ngraph/runtime/reference/gather_nd.hpp"
|
||||
#include "ngraph/runtime/reference/gather_tree.hpp"
|
||||
#include "ngraph/runtime/reference/gather_tree.hpp"
|
||||
#include "ngraph/runtime/reference/gru_cell.hpp"
|
||||
#include "ngraph/runtime/reference/log.hpp"
|
||||
#include "ngraph/runtime/reference/lrn.hpp"
|
||||
@ -397,6 +397,18 @@ protected:
|
||||
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::CTCGreedyDecoder_v0:
|
||||
{
|
||||
const auto ctc_greedy_dec = static_cast<const op::v0::CTCGreedyDecoder*>(&node);
|
||||
reference::ctc_greedy_decoder<T>(args[0]->get_data_ptr<const T>(),
|
||||
args[1]->get_data_ptr<const T>(),
|
||||
out[0]->get_data_ptr<T>(),
|
||||
args[0]->get_shape(),
|
||||
args[1]->get_shape(),
|
||||
out[0]->get_shape(),
|
||||
ctc_greedy_dec->get_ctc_merge_repeated());
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::CTCLoss_v4:
|
||||
{
|
||||
const op::v4::CTCLoss* ctc_loss = static_cast<const op::v4::CTCLoss*>(&node);
|
||||
@ -809,9 +821,7 @@ protected:
|
||||
gru_seq->get_activations()[1],
|
||||
gru_seq->get_clip(),
|
||||
gru_seq->get_direction(),
|
||||
gru_seq->get_linear_before_reset()
|
||||
|
||||
);
|
||||
gru_seq->get_linear_before_reset());
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::RNNSequence_v5:
|
||||
|
@ -19,6 +19,7 @@
|
||||
#undef ID_SUFFIX
|
||||
|
||||
#define ID_SUFFIX(NAME) NAME##_v0
|
||||
NGRAPH_OP(CTCGreedyDecoder, ngraph::op::v0)
|
||||
NGRAPH_OP(DetectionOutput, op::v0)
|
||||
NGRAPH_OP(RNNCell, op::v0)
|
||||
#undef ID_SUFFIX
|
||||
|
@ -138,3 +138,6 @@ onnx_model_gru_fwd_activations
|
||||
# Peepholes, input_forget are not supported
|
||||
lstm_cell_bias_peepholes
|
||||
lstm_cell_bias_peepholes_clip_input_forget
|
||||
|
||||
# unsupported element type f16
|
||||
INTERPRETER.ctc_greedy_decoder_f16
|
||||
|
Loading…
Reference in New Issue
Block a user