[GPU] Implement CTCLoss-4 (#13122)

This commit is contained in:
Mykhailo Hnap 2022-10-14 05:13:32 +03:00 committed by GitHub
parent b08fa945bc
commit 30774036ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 852 additions and 1 deletions

View File

@ -182,9 +182,9 @@ REGISTER_FACTORY(v4, ReduceL1);
REGISTER_FACTORY(v4, ReduceL2);
REGISTER_FACTORY(v4, SoftPlus);
REGISTER_FACTORY(v4, Swish);
REGISTER_FACTORY(v4, CTCLoss);
// ----------------------------- Unsupported v4 ops ----------------------------- //
// REGISTER_FACTORY(v4, CTCLoss);
// REGISTER_FACTORY(v4, Range);
// ------------------------------ Supported v5 ops ------------------------------ //

View File

@ -0,0 +1,47 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
#include "primitive.hpp"
namespace cldnn {
/// @addtogroup cpp_api C++ API
/// @{
/// @addtogroup cpp_topology Network Topology
/// @{
/// @addtogroup cpp_primitives Primitives
/// @{
/// @brief CTCLoss-4 primitive.
/// @details Carries the three boolean attributes of the operation; the tensors themselves
/// are supplied through the input primitive ids.
struct ctc_loss : primitive_base<ctc_loss> {
CLDNN_DECLARE_PRIMITIVE(ctc_loss)
/// @brief Constructs ctc_loss primitive.
/// @param id This primitive id.
/// @param inputs Input primitives ids.
/// @param preprocess_collapse_repeated Flag for preprocessing labels before loss calculation.
/// @param ctc_merge_repeated Flag for merging repeated characters in a potential alignment.
/// @param unique Flag to find unique elements in a target.
/// @param output_padding Output padding forwarded to primitive_base (defaults to an empty padding).
ctc_loss(const primitive_id& id,
const std::vector<primitive_id>& inputs,
bool preprocess_collapse_repeated,
bool ctc_merge_repeated,
bool unique,
const padding& output_padding = {})
: primitive_base(id, inputs, output_padding),
preprocess_collapse_repeated(preprocess_collapse_repeated),
ctc_merge_repeated(ctc_merge_repeated),
unique(unique) {}
/// @brief Flag for preprocessing labels before loss calculation.
bool preprocess_collapse_repeated;
/// @brief Flag for merging repeated characters in a potential alignment.
bool ctc_merge_repeated;
/// @brief Flag to find unique elements in a target.
bool unique;
};
/// @}
/// @}
/// @}
} // namespace cldnn

View File

@ -0,0 +1,44 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <sstream>
#include <string>
#include "ctc_loss_inst.hpp"
#include "json_object.h"
#include "primitive_type_base.h"
#include "to_string_utils.h"
namespace cldnn {
/// Returns the address of the single, lazily constructed type descriptor
/// shared by all ctc_loss primitives.
primitive_type_id ctc_loss::type_id() {
    static primitive_type_base<ctc_loss> type_descriptor;
    return &type_descriptor;
}
/// Computes the output layout: the first entry of the input tensor's sizes is kept,
/// every other dimension is 1; data type and format are taken from the input layout.
layout ctc_loss_inst::calc_output_layout(const ctc_loss_node& node, const kernel_impl_params& impl_param) {
    const auto logits_layout = impl_param.get_input_layout();
    const auto leading_dim = logits_layout.get_tensor().sizes().front();
    const std::vector<tensor::value_type> output_dims = {leading_dim, 1, 1, 1};
    return layout{logits_layout.data_type, logits_layout.format, tensor(logits_layout.format, output_dims)};
}
std::string ctc_loss_inst::to_string(const ctc_loss_node& node) {
auto primitive = node.get_primitive();
json_composite ctc_loss_info;
for (size_t i = 0; i < primitive->input_size(); ++i) {
ctc_loss_info.add("input_" + std::to_string(i), node.input(i).id());
}
ctc_loss_info.add("preprocess_collapse_repeated", primitive->preprocess_collapse_repeated);
ctc_loss_info.add("ctc_merge_repeated", primitive->ctc_merge_repeated);
ctc_loss_info.add("unique", primitive->unique);
auto node_info = node.desc_to_json();
node_info->add("ctc_loss info", ctc_loss_info);
std::ostringstream primitive_description;
node_info->dump(primitive_description);
return primitive_description.str();
}
} // namespace cldnn

View File

@ -0,0 +1,64 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ctc_loss/ctc_loss_kernel_ref.hpp"
#include "ctc_loss/ctc_loss_kernel_selector.hpp"
#include "ctc_loss_inst.hpp"
#include "impls/implementation_map.hpp"
#include "primitive_base.hpp"
namespace cldnn {
namespace ocl {
/// OpenCL implementation wrapper for the ctc_loss primitive.
struct ctc_loss_impl : typed_primitive_impl_ocl<ctc_loss> {
    using parent = typed_primitive_impl_ocl<ctc_loss>;
    using parent::parent;

    /// Creates a copy of this implementation.
    std::unique_ptr<primitive_impl> clone() const override {
        return make_unique<ctc_loss_impl>(*this);
    }

    /// Fills the kernel-selector parameters from the primitive description and
    /// wraps the first kernel returned by the selector in a new implementation.
    static primitive_impl* create(const ctc_loss_node& arg, const kernel_impl_params& impl_param) {
        auto params = get_default_params<kernel_selector::ctc_loss_params>(impl_param);
        auto optional_params =
            get_default_optional_params<kernel_selector::ctc_loss_optional_params>(arg.get_program());

        const auto& desc = impl_param.typed_desc<ctc_loss>();
        params.preprocess_collapse_repeated = desc->preprocess_collapse_repeated;
        params.ctc_merge_repeated = desc->ctc_merge_repeated;
        params.unique = desc->unique;

        // Only inputs 1..N are appended here; presumably get_default_params() has already
        // registered input 0 -- confirm against its implementation.
        const auto& layouts = impl_param.input_layouts;
        for (size_t idx = 1; idx < layouts.size(); ++idx) {
            params.inputs.push_back(convert_data_tensor(layouts[idx]));
        }

        const auto& selector = kernel_selector::ctc_loss_kernel_selector::Instance();
        const auto candidates = selector.GetBestKernels(params, optional_params);
        CLDNN_ERROR_BOOL(arg.id(),
                         "Best_kernel.empty()",
                         candidates.empty(),
                         "Cannot find a proper kernel with this arguments");

        return new ctc_loss_impl(arg, candidates.front());
    }
};
namespace detail {
/// Registers ctc_loss_impl::create in the OCL implementation map for the
/// f16/f32 data types and the listed formats.
attach_ctc_loss_impl::attach_ctc_loss_impl() {
    auto supported_types = {data_types::f16, data_types::f32};
    auto supported_formats = {
        format::bfyx,
        format::b_fs_yx_fsv16,
        format::b_fs_yx_fsv32,
        format::bs_fs_yx_bsv16_fsv16,
        format::bs_fs_yx_bsv32_fsv32,
        format::bs_fs_yx_bsv32_fsv16,
    };
    implementation_map<ctc_loss>::add(impl_types::ocl,
                                      ctc_loss_impl::create,
                                      supported_types,
                                      supported_formats);
}
} // namespace detail
} // namespace ocl
} // namespace cldnn

View File

@ -87,6 +87,7 @@ void register_implementations() {
REGISTER_OCL(resample);
REGISTER_OCL(grn);
REGISTER_OCL(ctc_greedy_decoder);
REGISTER_OCL(ctc_loss);
REGISTER_OCL(cum_sum);
REGISTER_OCL(embedding_bag);
REGISTER_OCL(extract_image_patches);

View File

@ -19,6 +19,7 @@
#include "intel_gpu/primitives/convolution.hpp"
#include "intel_gpu/primitives/crop.hpp"
#include "intel_gpu/primitives/ctc_greedy_decoder.hpp"
#include "intel_gpu/primitives/ctc_loss.hpp"
#include "intel_gpu/primitives/custom_gpu_primitive.hpp"
#include "intel_gpu/primitives/deconvolution.hpp"
#include "intel_gpu/primitives/depth_to_space.hpp"
@ -165,6 +166,7 @@ REGISTER_OCL(gather_tree);
REGISTER_OCL(resample);
REGISTER_OCL(grn);
REGISTER_OCL(ctc_greedy_decoder);
REGISTER_OCL(ctc_loss);
REGISTER_OCL(cum_sum);
REGISTER_OCL(embedding_bag);
REGISTER_OCL(extract_image_patches);

View File

@ -0,0 +1,35 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "intel_gpu/primitives/ctc_loss.hpp"
#include "primitive_inst.h"
namespace cldnn {
/// Program-node specialization for the ctc_loss primitive.
template <>
struct typed_program_node<ctc_loss> : typed_program_node_base<ctc_loss> {
using parent = typed_program_node_base<ctc_loss>;
// Reuse the base-class constructors.
using parent::parent;
/// Returns the dependency node at position `index`.
program_node& input(size_t index) const {
return get_dependency(index);
}
};
// Short alias used by the instance and implementation code.
using ctc_loss_node = typed_program_node<ctc_loss>;
/// Primitive-instance specialization for the ctc_loss primitive.
template <>
class typed_primitive_inst<ctc_loss> : public typed_primitive_inst_base<ctc_loss> {
public:
using parent = typed_primitive_inst_base<ctc_loss>;
// Reuse the base-class constructors.
using parent::parent;
/// Computes the output layout of the primitive from the kernel implementation parameters.
static layout calc_output_layout(const ctc_loss_node& node, kernel_impl_params const& impl_param);
/// Returns a textual (JSON) description of the node.
static std::string to_string(const ctc_loss_node& node);
};
// Short alias for the specialization above.
using ctc_loss_inst = typed_primitive_inst<ctc_loss>;
} // namespace cldnn

View File

@ -1426,6 +1426,7 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
prim.type() != cldnn::gather::type_id() &&
prim.type() != cldnn::scatter_nd_update::type_id() &&
prim.type() != cldnn::broadcast::type_id() &&
prim.type() != cldnn::ctc_loss::type_id() &&
prim.type() != cldnn::non_max_suppression::type_id() &&
prim.type() != cldnn::roi_align::type_id() &&
prim.type() != cldnn::adaptive_pooling::type_id() &&
@ -1463,6 +1464,7 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
prim.type() != cldnn::scatter_nd_update::type_id() &&
prim.type() != cldnn::broadcast::type_id() &&
prim.type() != cldnn::quantize::type_id() &&
prim.type() != cldnn::ctc_loss::type_id() &&
prim.type() != cldnn::non_max_suppression::type_id() &&
prim.type() != cldnn::roi_align::type_id() &&
prim.type() != cldnn::adaptive_pooling::type_id() &&

View File

@ -74,6 +74,7 @@ enum class KernelType {
SPACE_TO_BATCH,
GRN,
CTC_GREEDY_DECODER,
CTC_LOSS,
CUM_SUM,
EMBEDDING_BAG,
EXTRACT_IMAGE_PATCHES,

View File

@ -0,0 +1,94 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ctc_loss_kernel_ref.hpp"
#include "kernel_selector_utils.h"
namespace kernel_selector {
namespace {
// Builds the dispatch configuration: one work-item per batch entry of the first output,
// with the remaining two global dimensions fixed to 1.
CommonDispatchData SetDefault(const ctc_loss_params& kernel_params) {
    const auto batch = kernel_params.outputs.front().Batch().v;

    CommonDispatchData dispatch;
    dispatch.gws = {batch, 1, 1};
    dispatch.lws = GetOptimalLocalWorkGroupSizes(dispatch.gws, kernel_params.engineInfo);
    return dispatch;
}
} // namespace
// Produces the kernel data for the reference CTCLoss kernel, or an empty list
// when the parameters do not pass Validate().
KernelsData CTCLossKernelRef::GetKernelsData(const Params& params, const optional_params& options) const {
    if (!Validate(params, options)) {
        return {};
    }

    auto kd = KernelData::Default<ctc_loss_params>(params);
    const auto& ctc_params = dynamic_cast<const ctc_loss_params&>(*kd.params);

    const auto dispatch = SetDefault(ctc_params);
    const auto entry_point = GetEntryPoint(kernelName, ctc_params.layerID, params, options);
    const auto jit = CreateJit(kernelName, GetJitConstants(ctc_params), entry_point);

    auto& cl_kernel = kd.kernels.front();
    // The last argument passes the number of inputs (4 or 5, see Validate()).
    FillCLKernelData(cl_kernel,
                     dispatch,
                     params.engineInfo,
                     kernelName,
                     jit,
                     entry_point,
                     {},
                     false,
                     false,
                     ctc_params.inputs.size());

    return {kd};
}
// Describes the parameter combinations accepted by the reference kernel.
ParamsKey CTCLossKernelRef::GetSupportedKey() const {
ParamsKey key;
// Inputs: both integer and floating-point types are enabled.
key.EnableInputDataType(Datatype::INT32);
key.EnableInputDataType(Datatype::INT64);
key.EnableInputDataType(Datatype::F16);
key.EnableInputDataType(Datatype::F32);
// Output: floating-point only.
key.EnableOutputDataType(Datatype::F16);
key.EnableOutputDataType(Datatype::F32);
// Presumably permits tensors of different data types in one call (integer and
// floating-point inputs are both enabled above) -- confirm in ParamsKey.
key.EnableDifferentTypes();
// No restriction on input/output layouts.
key.EnableAllInputLayout();
key.EnableAllOutputLayout();
key.EnableTensorOffset();
key.EnableTensorPitches();
key.EnableBatching();
return key;
}
bool CTCLossKernelRef::Validate(const Params& params, const optional_params& options) const {
if (params.GetType() != KernelType::CTC_LOSS || options.GetType() != KernelType::CTC_LOSS) {
return false;
}
const auto& kernel_params = dynamic_cast<const ctc_loss_params&>(params);
if (kernel_params.inputs.size() != 4 && kernel_params.inputs.size() != 5) {
return false;
}
return true;
}
// Extends the base JIT constants with the three CTCLoss attributes, which the
// OpenCL kernel consumes as PREPROCESS_COLLAPSE_REPEATED, CTC_MERGE_REPEATED and UNIQUE.
JitConstants CTCLossKernelRef::GetJitConstants(const ctc_loss_params& kernel_params) const {
    auto jit = MakeBaseParamsJitConstants(kernel_params);

    jit.AddConstants({
        MakeJitConstant("PREPROCESS_COLLAPSE_REPEATED", kernel_params.preprocess_collapse_repeated),
        MakeJitConstant("CTC_MERGE_REPEATED", kernel_params.ctc_merge_repeated),
        MakeJitConstant("UNIQUE", kernel_params.unique),
    });

    return jit;
}
} // namespace kernel_selector

View File

@ -0,0 +1,41 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "kernel_base_opencl.h"
namespace kernel_selector {
/**
* CTCLoss reference kernel parameters.
*
* The three flags are exported to the kernel as JIT constants; their defaults
* are the member initializers below.
*/
struct ctc_loss_params : base_params {
ctc_loss_params() : base_params(KernelType::CTC_LOSS) {}
// Exported as PREPROCESS_COLLAPSE_REPEATED; off by default.
bool preprocess_collapse_repeated = false;
// Exported as CTC_MERGE_REPEATED; on by default.
bool ctc_merge_repeated = true;
// Exported as UNIQUE; off by default.
bool unique = false;
};
/**
* CTCLoss reference kernel optional parameters.
*
* Adds no fields of its own; it only tags the optional parameters with KernelType::CTC_LOSS.
*/
struct ctc_loss_optional_params : optional_params {
ctc_loss_optional_params() : optional_params(KernelType::CTC_LOSS) {}
};
/**
* Reference kernel for CTCLoss.
*
* Registered under the kernel name "ctc_loss_ref".
*/
class CTCLossKernelRef : public KernelBaseOpenCL {
public:
CTCLossKernelRef() : KernelBaseOpenCL{"ctc_loss_ref"} {}
private:
// Builds the kernel data; returns an empty list when Validate() rejects the parameters.
KernelsData GetKernelsData(const Params& params, const optional_params& options) const override;
// Data types, layouts and tensor features accepted by this kernel.
ParamsKey GetSupportedKey() const override;
// Checks the kernel type of params/options and the number of inputs.
bool Validate(const Params& params, const optional_params& options) const override;
// Base JIT constants plus the three CTCLoss attribute flags.
JitConstants GetJitConstants(const ctc_loss_params& kernel_params) const;
};
} // namespace kernel_selector

View File

@ -0,0 +1,24 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ctc_loss_kernel_selector.hpp"
#include "ctc_loss_kernel_ref.hpp"
namespace kernel_selector {
// Attaches the reference kernel, the only CTCLoss kernel attached by this constructor.
ctc_loss_kernel_selector::ctc_loss_kernel_selector() {
Attach<CTCLossKernelRef>();
}
// Delegates kernel selection to GetNaiveBestKernel for the CTC_LOSS kernel type.
KernelsData ctc_loss_kernel_selector::GetBestKernels(const Params& params, const optional_params& options) const {
return GetNaiveBestKernel(params, options, KernelType::CTC_LOSS);
}
// Returns the selector instance held in a function-local static (constructed on first call).
ctc_loss_kernel_selector& ctc_loss_kernel_selector::Instance() {
static ctc_loss_kernel_selector instance;
return instance;
}
} // namespace kernel_selector

View File

@ -0,0 +1,20 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "kernel_selector.h"
namespace kernel_selector {
/*
* CTCLoss kernel selector.
*/
class ctc_loss_kernel_selector : public kernel_selector_base {
public:
// Attaches the kernels that can serve CTCLoss.
ctc_loss_kernel_selector();
// Returns the kernels selected for the given parameters.
KernelsData GetBestKernels(const Params& params, const optional_params& options) const override;
// Accessor for the shared selector instance.
static ctc_loss_kernel_selector& Instance();
};
} // namespace kernel_selector

View File

@ -0,0 +1,136 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
// Aliases over the dimensions of input 0 (logits), which the kernel indexes as
// INPUT0_GET_INDEX(batch, time, class, 0).
#define BATCH_NUM INPUT0_BATCH_NUM
#define MAX_TIME INPUT0_FEATURE_NUM
#define CLASSES_NUM INPUT0_SIZE_Y
// Capacity of the decoded target: a blank before every label plus one trailing blank.
// NOTE(review): assumes at most MAX_TIME labels per batch entry -- confirm.
#define MAX_DECODED_LENGTH (MAX_TIME * 2 + 1)
// The blank index is read from the optional fifth input when it is present;
// otherwise the last class is used.
#ifdef INPUT4_TYPE
# define BLANK_INDEX blank_index[INPUT4_GET_INDEX(0, 0, 0, 0)]
#else
# define BLANK_INDEX (CLASSES_NUM - 1)
#endif
// Returns log(exp(log1) + exp(log2)) computed in log space.
// -INFINITY stands for log(0), so an operand equal to it leaves the other one unchanged.
inline OUTPUT_TYPE FUNC(sumLogs)(OUTPUT_TYPE log1, OUTPUT_TYPE log2) {
    if (log1 == -INFINITY) {
        return log2;
    }
    if (log2 == -INFINITY) {
        return log1;
    }

    // Factor out the larger operand so that exp() only sees a non-positive argument.
    const bool first_is_larger = log1 > log2;
    const OUTPUT_TYPE larger = first_is_larger ? log1 : log2;
    const OUTPUT_TYPE smaller = first_is_larger ? log2 : log1;
    return larger + log1p(exp(smaller - larger));
}
// Based on the CPU plugin implementation that uses the backward dynamic programming algorithm.
//
// Each work-item (global id 0) handles one batch entry:
//   1. Builds the decoded target: the labels interleaved with BLANK_INDEX. UNIQUE takes
//      precedence over PREPROCESS_COLLAPSE_REPEATED when both are set.
//   2. Computes the log-softmax of the logits for every symbol of the decoded target.
//   3. Runs the backward recursion in log space and writes the negated log-likelihood.
//
// An empty target (label_length == 0) is decoded as a single blank in every mode.
//
// NOTE(review): lengths and label values are not validated here. The code assumes
// logit_length in [1, MAX_TIME], label_length in [0, MAX_TIME] and labels in
// [0, CLASSES_NUM) -- confirm that callers guarantee this.
KERNEL(ctc_loss_ref)
(const __global INPUT0_TYPE* logits,
 const __global INPUT1_TYPE* logit_length,
 const __global INPUT2_TYPE* labels,
 const __global INPUT3_TYPE* label_length,
#ifdef INPUT4_TYPE
 const __global INPUT4_TYPE* blank_index,
#endif
 __global OUTPUT_TYPE* output) {
    const uint b = get_global_id(0);

    const INPUT1_TYPE actual_logit_length = logit_length[INPUT1_GET_INDEX(0, b, 0, 0)];
    const INPUT3_TYPE actual_target_length = label_length[INPUT3_GET_INDEX(0, b, 0, 0)];

    // Step 1: decoded target = blank, label, blank, label, ..., blank.
    INPUT2_TYPE decoded_target[MAX_DECODED_LENGTH] = {};
    INPUT3_TYPE decoded_target_length = 0;

#if UNIQUE
    // Keep only the first occurrence of every label value.
    bool found_values[CLASSES_NUM] = {};
    for (uint t = 0; t < actual_target_length; ++t) {
        const INPUT2_TYPE value = labels[INPUT2_GET_INDEX(b, t, 0, 0)];
        if (found_values[value]) {
            continue;
        }
        found_values[value] = true;
        decoded_target[decoded_target_length++] = BLANK_INDEX;
        decoded_target[decoded_target_length++] = value;
    }
    decoded_target[decoded_target_length++] = BLANK_INDEX;
#elif PREPROCESS_COLLAPSE_REPEATED
    // Collapse runs of equal adjacent labels into one label.
    // The first label is read only when the target is non-empty, so that an empty target
    // decodes to a single blank exactly like in the other two modes.
    if (actual_target_length > 0) {
        INPUT2_TYPE previous_value = labels[INPUT2_GET_INDEX(b, 0, 0, 0)];
        decoded_target[decoded_target_length++] = BLANK_INDEX;
        decoded_target[decoded_target_length++] = previous_value;
        for (uint t = 1; t < actual_target_length; ++t) {
            const INPUT2_TYPE value = labels[INPUT2_GET_INDEX(b, t, 0, 0)];
            if (value == previous_value) {
                continue;
            }
            decoded_target[decoded_target_length++] = BLANK_INDEX;
            decoded_target[decoded_target_length++] = value;
            previous_value = value;
        }
    }
    decoded_target[decoded_target_length++] = BLANK_INDEX;
#else
    for (uint t = 0; t < actual_target_length; ++t) {
        decoded_target[decoded_target_length++] = BLANK_INDEX;
        decoded_target[decoded_target_length++] = labels[INPUT2_GET_INDEX(b, t, 0, 0)];
    }
    decoded_target[decoded_target_length++] = BLANK_INDEX;
#endif

    // Step 2: log_probabilities[t][s] = log_softmax(logits[b, t, :])[decoded_target[s]].
    // Entries outside the actual lengths keep their zero initialization.
    OUTPUT_TYPE log_probabilities[MAX_TIME][MAX_DECODED_LENGTH] = {};
    for (uint t = 0; t < actual_logit_length; ++t) {
        OUTPUT_TYPE exp_sum = OUTPUT_VAL_ZERO;
        for (uint c = 0; c < CLASSES_NUM; ++c) {
            exp_sum += exp(logits[INPUT0_GET_INDEX(b, t, c, 0)]);
        }
        for (uint s = 0; s < decoded_target_length; ++s) {
            log_probabilities[t][s] = logits[INPUT0_GET_INDEX(b, t, decoded_target[s], 0)] - log(exp_sum);
        }
    }

    // Step 3: backward variables in log space; -INFINITY stands for probability zero.
    OUTPUT_TYPE log_backward[MAX_DECODED_LENGTH][MAX_TIME] = {};
    for (uint i = 0; i < MAX_DECODED_LENGTH; ++i) {
        for (uint j = 0; j < MAX_TIME; ++j) {
            log_backward[i][j] = -INFINITY;
        }
    }

    // A valid alignment ends at the last time step either on the trailing blank or on the last label.
    log_backward[decoded_target_length - 1][actual_logit_length - 1] = OUTPUT_VAL_ZERO;
    // The "last label" slot exists only when the decoded target holds more than the single blank;
    // without this check an empty target would write to index -1.
    if (decoded_target_length > 1) {
        log_backward[decoded_target_length - 2][actual_logit_length - 1] = OUTPUT_VAL_ZERO;
    }

    for (INPUT1_TYPE t = actual_logit_length - 2; t >= 0; t--) {
        const INPUT1_TYPE t_1 = t + 1;
        for (INPUT1_TYPE s = max(INPUT1_VAL_ZERO, decoded_target_length - (2 * (actual_logit_length - t)));
             s < min(decoded_target_length, 2 * (t_1));
             s++) {
            // Stay on the same symbol: always allowed for blanks, for labels only when merging repeats.
            if (CTC_MERGE_REPEATED || decoded_target[s] == BLANK_INDEX) {
                log_backward[s][t] =
                    FUNC_CALL(sumLogs)(log_backward[s][t], log_backward[s][t_1] + log_probabilities[t_1][s]);
            }

            // Advance to the next symbol.
            if (s + 1 < decoded_target_length) {
                log_backward[s][t] =
                    FUNC_CALL(sumLogs)(log_backward[s][t], log_backward[s + 1][t_1] + log_probabilities[t_1][s + 1]);
            }

            // Skip the blank between two labels.
            if (s + 2 < decoded_target_length) {
                if (decoded_target[s] != BLANK_INDEX &&
                    (!CTC_MERGE_REPEATED || (decoded_target[s] != decoded_target[s + 2]))) {
                    log_backward[s][t] = FUNC_CALL(sumLogs)(log_backward[s][t],
                                                            log_backward[s + 2][t_1] + log_probabilities[t_1][s + 2]);
                }
            }
        }
    }

    // An alignment may start on the leading blank or on the first label.
    log_backward[0][0] += log_probabilities[0][0];
    log_backward[1][0] += log_probabilities[0][1];

    output[OUTPUT_GET_INDEX(b, 0, 0, 0)] = -FUNC_CALL(sumLogs)(log_backward[0][0], log_backward[1][0]);
}
#undef BATCH_NUM
#undef MAX_TIME
#undef CLASSES_NUM
#undef MAX_DECODED_LENGTH
#undef BLANK_INDEX

View File

@ -0,0 +1,33 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "intel_gpu/primitives/ctc_loss.hpp"
#include <ngraph/op/ctc_loss.hpp>
#include "intel_gpu/plugin/program.hpp"
namespace ov {
namespace intel_gpu {
namespace {
// Translates an ngraph v4::CTCLoss operation (4 or 5 inputs) into a cldnn::ctc_loss
// primitive and adds it to the program.
void CreateCTCLossOp(Program& p, const std::shared_ptr<ngraph::op::v4::CTCLoss>& op) {
    validate_inputs_count(op, {4, 5});

    const cldnn::ctc_loss primitive(layer_type_name_ID(op),
                                    p.GetInputPrimitiveIDs(op),
                                    op->get_preprocess_collapse_repeated(),
                                    op->get_ctc_merge_repeated(),
                                    op->get_unique());

    p.add_primitive(*op, primitive);
}
} // namespace
REGISTER_FACTORY_IMPL(v4, CTCLoss);
} // namespace intel_gpu
} // namespace ov

View File

@ -0,0 +1,242 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <test_utils/test_utils.h>
#include <intel_gpu/primitives/ctc_loss.hpp>
#include <vector>
using namespace cldnn;
using namespace tests;
namespace {
// Formats a vector as "(a.b.c)": elements are separated by '.', and an empty
// vector is rendered as "()".
template <typename vecElementType>
std::string vec2str(const std::vector<vecElementType>& vec) {
    if (vec.empty()) {
        return "()";
    }

    std::ostringstream text;
    text << "(";
    // Every element except the last one is followed by the separator.
    for (auto it = vec.begin(); it != vec.end() - 1; ++it) {
        text << *it << ".";
    }
    text << vec.back() << ")";
    return text.str();
}
// One CTCLoss test case: attribute flags, input tensors and the expected output.
// TF is the floating-point element type (logits, expected values),
// TI is the integer element type (lengths, labels, blank index).
template <class TF, class TI>
struct ctc_loss_test_inputs {
// Primitive attributes.
bool preprocess_collapse_repeated;
bool ctc_merge_repeated;
bool unique;
// {batch, max time, classes} of the logits tensor.
std::vector<int> logits_shape;
// Flattened input data.
std::vector<TF> logits;
std::vector<TI> logit_length;
std::vector<TI> labels;
std::vector<TI> label_length;
TI blank_index;
// Expected output values, compared element-wise with the primitive's output.
std::vector<TF> expected_values;
};
// A test parameter is a test case combined with the format the inputs are reordered to.
template <class TF, class TI>
using ctc_loss_test_params = std::tuple<ctc_loss_test_inputs<TF, TI>, format::type>;
// Parameterized fixture that runs the ctc_loss primitive on the test engine and
// compares its output with the expected values of the test case.
template <class TF, class TI>
struct ctc_loss_gpu_test : public testing::TestWithParam<ctc_loss_test_params<TF, TI>> {
public:
    // Builds the topology input -> reorder(fmt) -> ctc_loss -> reorder(bfyx), executes it
    // and checks every output element against the expected value.
    void test() {
        format::type fmt;
        ctc_loss_test_inputs<TF, TI> p;
        std::tie(p, fmt) = testing::TestWithParam<ctc_loss_test_params<TF, TI>>::GetParam();

        auto& engine = get_test_engine();
        const auto float_data_type = type_to_data_type<TF>::value;
        const auto int_data_type = type_to_data_type<TI>::value;
        const auto plane_format = format::bfyx;

        // (primitive id, memory, data type) for every network input.
        std::vector<std::tuple<primitive_id, memory::ptr, data_types>> inputs;

        const auto batch_num = p.logits_shape[0];
        const auto max_time = p.logits_shape[1];
        const auto classes_num = p.logits_shape[2];

        const layout logits_layout(float_data_type,
                                   plane_format,
                                   tensor(plane_format, {batch_num, max_time, classes_num, 1}));
        auto logits = engine.allocate_memory(logits_layout);
        set_values(logits, p.logits);
        inputs.emplace_back("logits", logits, float_data_type);

        const layout logit_length_layout(int_data_type, plane_format, tensor(plane_format, {1, batch_num, 1, 1}));
        auto logit_length = engine.allocate_memory(logit_length_layout);
        set_values(logit_length, p.logit_length);
        inputs.emplace_back("logit_length", logit_length, int_data_type);

        const layout labels_layout(int_data_type, plane_format, tensor(plane_format, {batch_num, max_time, 1, 1}));
        auto labels = engine.allocate_memory(labels_layout);
        set_values(labels, p.labels);
        inputs.emplace_back("labels", labels, int_data_type);

        const layout label_length_layout(int_data_type, plane_format, tensor(plane_format, {1, batch_num, 1, 1}));
        auto label_length = engine.allocate_memory(label_length_layout);
        set_values(label_length, p.label_length);
        inputs.emplace_back("label_length", label_length, int_data_type);

        const layout blank_index_layout(int_data_type, plane_format, tensor(plane_format, {1, 1, 1, 1}));
        auto blank_index = engine.allocate_memory(blank_index_layout);
        set_values(blank_index, {p.blank_index});
        inputs.emplace_back("blank_index", blank_index, int_data_type);

        // The primitive consumes the reordered inputs, not the raw ones.
        std::vector<primitive_id> inputs_ids;
        std::transform(inputs.begin(),
                       inputs.end(),
                       std::back_inserter(inputs_ids),
                       [](const decltype(inputs)::value_type& input) {
                           return "reordered_" + std::get<0>(input);
                       });

        topology topology;
        for (const auto& input : inputs) {
            topology.add(input_layout(std::get<0>(input), std::get<1>(input)->get_layout()));
            topology.add(reorder("reordered_" + std::get<0>(input), std::get<0>(input), fmt, std::get<2>(input)));
        }

        topology.add(ctc_loss("ctc_loss", inputs_ids, p.preprocess_collapse_repeated, p.ctc_merge_repeated, p.unique));
        topology.add(reorder("reordered_ctc_loss", "ctc_loss", plane_format, float_data_type));

        network network(engine, topology);
        for (auto& input : inputs) {
            network.set_input_data(std::get<0>(input), std::get<1>(input));
        }
        const auto outputs = network.execute();

        // Fatal on purpose: the next line dereferences outputs.begin(), which must not
        // happen when the output map is empty.
        ASSERT_EQ(outputs.size(), size_t(1));
        EXPECT_EQ(outputs.begin()->first, "reordered_ctc_loss");

        auto output = outputs.at("reordered_ctc_loss").get_memory();
        cldnn::mem_lock<TF> output_ptr(output, get_test_stream());

        ASSERT_EQ(output_ptr.size(), p.expected_values.size());
        for (size_t i = 0; i < output_ptr.size(); ++i) {
            EXPECT_NEAR(p.expected_values[i], output_ptr[i], 0.1);
        }
    }

    // Builds the test name from the attributes, shapes, lengths, labels and format of the case.
    static std::string PrintToStringParamName(const testing::TestParamInfo<ctc_loss_test_params<TF, TI>>& info) {
        format::type fmt;
        ctc_loss_test_inputs<TF, TI> p;
        std::tie(p, fmt) = info.param;

        std::ostringstream result;
        result << "PreprocessCollapseRepeated=" << p.preprocess_collapse_repeated << "_";
        result << "CtcMergeRepeated=" << p.ctc_merge_repeated << "_";
        result << "Unique=" << p.unique << "_";
        result << "LogitsShape=" << vec2str(p.logits_shape) << "_";
        result << "LogitsLength=" << vec2str(p.logit_length) << "_";
        result << "Labels=" << vec2str(p.labels) << "_";
        result << "LabelLength=" << vec2str(p.label_length) << "_";
        result << "BlankIndex=" << p.blank_index << "_";
        result << "Format=" << fmt_to_str(fmt);
        return result.str();
    }
};
// Returns the CTCLoss test cases. All cases share the same tensors and differ only in the
// attribute flags and in the expected losses.
//
// For the first batch entry the target is {0, 1}: it contains no repeated labels, so only
// ctc_merge_repeated affects its loss (1.41223 when false, 1.41156 when true).
template <class TF, class TI>
std::vector<ctc_loss_test_inputs<TF, TI>> getCTCLossParams() {
    const std::vector<int> logits_shape = {2, 3, 3};
    const std::vector<TF> logits = {0, 1, 8, 5, 5, 2, 0, 7, 7, 10, 4, 5, 9, 0, 0, 5, 7, 0};
    const std::vector<TI> logit_length = {3, 3};
    const std::vector<TI> labels = {0, 1, 2, 1, 1, 1};
    const std::vector<TI> label_length = {2, 1};
    const TI blank_index = 2;

    // Expected losses, selected by the value of ctc_merge_repeated.
    const std::vector<TF> expected_no_merge = {1.41223f, 14.1359f};
    const std::vector<TF> expected_merge = {1.41156f, 13.2745f};

    // Field order: preprocess_collapse_repeated, ctc_merge_repeated, unique, logits_shape, logits,
    // logit_length, labels, label_length, blank_index, expected_values.
    return {
        {false, false, false, logits_shape, logits, logit_length, labels, label_length, blank_index, expected_no_merge},
        {false, false, true, logits_shape, logits, logit_length, labels, label_length, blank_index, expected_no_merge},
        {false, true, false, logits_shape, logits, logit_length, labels, label_length, blank_index, expected_merge},
        {true, false, false, logits_shape, logits, logit_length, labels, label_length, blank_index, expected_no_merge},
        {false, true, true, logits_shape, logits, logit_length, labels, label_length, blank_index, expected_merge},
        // The first expected value used to be 1.41223 (the ctc_merge_repeated == false loss);
        // the mismatch was hidden by the comparison tolerance.
        {true, true, true, logits_shape, logits, logit_length, labels, label_length, blank_index, expected_merge},
    };
}
// Formats every test case is run with: the inputs are reordered to the given
// format before they reach the ctc_loss primitive.
const std::vector<format::type> layout_formats = {
format::bfyx,
format::b_fs_yx_fsv16,
format::b_fs_yx_fsv32,
format::bs_fs_yx_bsv16_fsv16,
format::bs_fs_yx_bsv32_fsv32,
format::bs_fs_yx_bsv32_fsv16,
};
// Declares a fixture alias for the given (float_type, int_type) pair, a TEST_P that runs
// test(), and a test-suite instantiation over all test cases combined with all formats.
#define INSTANTIATE_CTC_LOSS_TEST_SUITE(float_type, int_type) \
using ctc_loss_gpu_test_##float_type##int_type = ctc_loss_gpu_test<float_type, int_type>; \
TEST_P(ctc_loss_gpu_test_##float_type##int_type, test) { ASSERT_NO_FATAL_FAILURE(test()); } \
INSTANTIATE_TEST_SUITE_P(smoke_ctc_loss_##float_type##int_type, \
ctc_loss_gpu_test_##float_type##int_type, \
testing::Combine(testing::ValuesIn(getCTCLossParams<float_type, int_type>()), \
testing::ValuesIn(layout_formats)), \
ctc_loss_gpu_test_##float_type##int_type::PrintToStringParamName);
// Instantiated type pairs: float with int64_t, and FLOAT16 with int32_t.
INSTANTIATE_CTC_LOSS_TEST_SUITE(float, int64_t);
INSTANTIATE_CTC_LOSS_TEST_SUITE(FLOAT16, int32_t);
} // namespace

View File

@ -0,0 +1,65 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "single_layer_tests/ctc_loss.hpp"
#include <vector>
using namespace LayerTestsDefinitions;
namespace {
// Floating-point precisions the tests are instantiated with.
const std::vector<InferenceEngine::Precision> fPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16,
};
// Integer precisions the tests are instantiated with.
const std::vector<InferenceEngine::Precision> iPrecisions = {
InferenceEngine::Precision::I32,
InferenceEngine::Precision::I64,
};
// Every attribute is tested with both values.
const std::vector<bool> preprocessCollapseRepeated = {true, false};
const std::vector<bool> ctcMergeRepeated = {true, false};
const std::vector<bool> unique = {true, false};
// First parameter set: logits of shape {2, 3, 3} with blank index 2, combined with
// every value of the three attributes.
const auto ctcLossArgsSubset1 = testing::Combine(
testing::Values(std::vector<size_t>({2, 3, 3})), // logits shape
testing::ValuesIn(std::vector<std::vector<int>>({{2, 3}, {3, 3}})), // logits length
testing::ValuesIn(
std::vector<std::vector<std::vector<int>>>({{{0, 1, 0}, {1, 0, 1}}, {{0, 1, 2}, {1, 1, 1}}})), // labels
testing::ValuesIn(std::vector<std::vector<int>>({{2, 2}, {2, 1}})), // labels length
testing::Values(2), // blank index
testing::ValuesIn(preprocessCollapseRepeated),
testing::ValuesIn(ctcMergeRepeated),
testing::ValuesIn(unique));
// Runs the first parameter set on the GPU device for all float/int precision pairs.
INSTANTIATE_TEST_SUITE_P(smoke_CTCLoss_Set1,
CTCLossLayerTest,
testing::Combine(ctcLossArgsSubset1,
testing::ValuesIn(fPrecisions),
testing::ValuesIn(iPrecisions),
testing::Values(CommonTestUtils::DEVICE_GPU)),
CTCLossLayerTest::getTestCaseName);
// Second parameter set: logits of shape {3, 6, 8} with blank index 0 or 7, combined with
// every value of the three attributes.
const auto ctcLossArgsSubset2 =
testing::Combine(testing::Values(std::vector<size_t>({3, 6, 8})), // logits shape
testing::ValuesIn(std::vector<std::vector<int>>({{6, 5, 6}, {5, 5, 5}})), // logits length
testing::ValuesIn(std::vector<std::vector<std::vector<int>>>(
{{{4, 1, 2, 3, 4, 5}, {5, 4, 3, 0, 1, 0}, {2, 1, 3, 1, 3, 0}},
{{2, 1, 5, 3, 2, 6}, {3, 3, 3, 3, 3, 3}, {6, 5, 6, 5, 6, 5}}})), // labels
testing::ValuesIn(std::vector<std::vector<int>>({{4, 3, 5}, {3, 3, 5}})), // labels length
testing::ValuesIn(std::vector<int>({0, 7})), // blank index
testing::ValuesIn(preprocessCollapseRepeated),
testing::ValuesIn(ctcMergeRepeated),
testing::ValuesIn(unique));
// Runs the second parameter set on the GPU device for all float/int precision pairs.
INSTANTIATE_TEST_SUITE_P(smoke_CTCLoss_Set2,
CTCLossLayerTest,
testing::Combine(ctcLossArgsSubset2,
testing::ValuesIn(fPrecisions),
testing::ValuesIn(iPrecisions),
testing::Values(CommonTestUtils::DEVICE_GPU)),
CTCLossLayerTest::getTestCaseName);
} // namespace