[IE CLDNN] Added CTCGreedyDecoderSeqLen operation (#4119)
This commit is contained in:
parent
3f5ff2cfe5
commit
d406a5af50
@ -201,5 +201,8 @@ REGISTER_FACTORY(v5, Round);
|
||||
// REGISTER_FACTORY(v5, Loop);
|
||||
// REGISTER_FACTORY(v5, RNNSequence);
|
||||
|
||||
// ------------------------------ Supported v6 ops ------------------------------ //
|
||||
REGISTER_FACTORY(v6, CTCGreedyDecoderSeqLen);
|
||||
|
||||
// --------------------------- Supported internal ops --------------------------- //
|
||||
REGISTER_FACTORY(internal, NonMaxSuppressionIEInternal);
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// Copyright (C) 2020-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
@ -6,27 +6,117 @@
|
||||
#include "cldnn_common_utils.h"
|
||||
|
||||
#include "ngraph/op/ctc_greedy_decoder.hpp"
|
||||
#include "ngraph/op/ctc_greedy_decoder_seq_len.hpp"
|
||||
|
||||
#include "api/ctc_greedy_decoder.hpp"
|
||||
#include "api/reorder.hpp"
|
||||
#include "api/mutable_data.hpp"
|
||||
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
namespace CLDNNPlugin {
|
||||
|
||||
void CreateCTCGreedyDecoderOp(Program& p, const std::shared_ptr<ngraph::op::v0::CTCGreedyDecoder>& op) {
|
||||
p.ValidateInputs(op, {2});
|
||||
// Shared builder used by both CTC greedy decoder factories in this file: turns the ngraph node
// into a cldnn::ctc_greedy_decoder primitive and adds it, together with any helper reorder /
// mutable_data primitives, to the program.
// @param p                   Program that receives the created primitives.
// @param op                  Source node; validated to have 2 or 3 inputs.
// @param ctc_merge_repeated  Forwarded unchanged to the primitive's ctc_merge_repeated argument.
// Raises THROW_IE_EXCEPTION when a third input is present but is not a Constant holding exactly
// one value.
// NOTE(review): this text is a rendered diff, so pre-change and post-change statements are
// interleaved below without +/- markers (e.g. two `auto primitive = ...` definitions and two
// profiler registrations) - confirm against the real file before relying on a single statement.
void CreateCommonCTCGreedyDecoderOp(Program& p, const std::shared_ptr<ngraph::Node>& op, bool ctc_merge_repeated) {
|
||||
p.ValidateInputs(op, {2, 3});
|
||||
auto inputPrimitives = p.GetInputPrimitiveIDs(op);
|
||||
std::string layerName = layer_type_name_ID(op);
|
||||
|
||||
auto primitive = cldnn::ctc_greedy_decoder(layerName,
|
||||
inputPrimitives[0],
|
||||
inputPrimitives[1],
|
||||
op->get_ctc_merge_repeated(),
|
||||
DataTypeFromPrecision(op->get_output_element_type(0)),
|
||||
CldnnTensorFromIEDims(op->get_output_shape(0)));
|
||||
// Input ids actually handed to the primitive: same order as the node inputs, with every i64
// input replaced by the id of an i32 reorder inserted in front of it.
std::vector<cldnn::primitive_id> reorderedInputs;
|
||||
reorderedInputs.resize(inputPrimitives.size());
|
||||
|
||||
for (size_t portIndex = 0; portIndex < inputPrimitives.size(); portIndex++) {
|
||||
auto inputDataType = DataTypeFromPrecision(op->get_input_element_type(portIndex));
|
||||
if (inputDataType == cldnn::data_types::i64) {
|
||||
// clDNN primitive supports only i32 data type for 'sequence_length' and 'blank_index' inputs
|
||||
// so we need additional reorder if it's provided as i64
|
||||
auto reorderPrimName = inputPrimitives[portIndex] + "_" + op->get_friendly_name() + Program::m_preProcessTag;
|
||||
auto targetFormat = DefaultFormatForDims(op->get_input_shape(portIndex).size());
|
||||
auto preprocessPrim = cldnn::reorder(reorderPrimName,
|
||||
inputPrimitives[portIndex],
|
||||
targetFormat,
|
||||
cldnn::data_types::i32);
|
||||
p.AddPrimitive(preprocessPrim);
|
||||
p.AddInnerPrimitiveToProfiler(reorderPrimName, layer_type_name_ID(op), op);
|
||||
reorderedInputs[portIndex] = (reorderPrimName);
|
||||
} else {
|
||||
reorderedInputs[portIndex] = inputPrimitives[portIndex];
|
||||
}
|
||||
}
|
||||
|
||||
// Default blank index: last valid index of input 0's innermost dimension.
uint32_t blank_index = op->get_input_shape(0).back() - 1;
|
||||
// An explicit third input overrides the default. It must be a Constant with exactly one value;
// the value is folded into blank_index and the input is dropped from the primitive's input list.
if (reorderedInputs.size() == 3) {
|
||||
auto blank_index_node = std::dynamic_pointer_cast<ngraph::op::v0::Constant>(op->get_input_node_shared_ptr(2));
|
||||
if (!blank_index_node) {
|
||||
THROW_IE_EXCEPTION << "Unsupported blank_index node type in " << op->get_friendly_name() << " (" << op->get_type_name() << ")";
|
||||
}
|
||||
float val;
|
||||
if (ngraph::shape_size(blank_index_node->get_output_shape(0)) != 1 || !ngraph::op::util::get_single_value(blank_index_node, val)) {
|
||||
THROW_IE_EXCEPTION << "Unsupported parameter size in " << op->get_friendly_name() << " (" << op->get_type_name() << ")";
|
||||
}
|
||||
blank_index = static_cast<uint32_t>(val);
|
||||
reorderedInputs.pop_back();
|
||||
}
|
||||
|
||||
std::size_t num_output = op->get_output_size();
|
||||
|
||||
// For a two-output op, allocate a buffer shaped like output 1 (i64 narrowed to i32), wrap it
// in a "<id>_md_write" mutable_data primitive and append that id as an extra primitive input.
std::vector<cldnn::memory> shared_memory;
|
||||
if (num_output == 2) {
|
||||
auto mutable_precision = op->get_output_element_type(1);
|
||||
if (mutable_precision == ngraph::element::i64) {
|
||||
mutable_precision = ngraph::element::i32;
|
||||
}
|
||||
|
||||
cldnn::layout mutableLayout = cldnn::layout(
|
||||
DataTypeFromPrecision(mutable_precision),
|
||||
DefaultFormatForDims(op->get_output_shape(1).size()),
|
||||
CldnnTensorFromIEDims(op->get_output_shape(1)));
|
||||
|
||||
shared_memory.emplace_back(cldnn::memory::allocate(p.GetEngine(), mutableLayout));
|
||||
|
||||
cldnn::primitive_id ctc_gd_mutable_id_w = layer_type_name_ID(op) + "_md_write";
|
||||
auto ctc_gd_mutable_prim = cldnn::mutable_data(ctc_gd_mutable_id_w, shared_memory[0]);
|
||||
p.primitivesToIRLayersMap[ctc_gd_mutable_id_w] = { op->get_friendly_name() };
|
||||
p.primitiveIDs[ctc_gd_mutable_id_w] = ctc_gd_mutable_id_w;
|
||||
p.AddPrimitive(ctc_gd_mutable_prim);
|
||||
reorderedInputs.push_back(ctc_gd_mutable_id_w);
|
||||
}
|
||||
|
||||
// With two outputs the decoder primitive is named "<id>.0"; "<id>.1" is created further down.
auto CTCGreedyDecoderLayerName = num_output == 2 ? layer_type_name_ID(op) + ".0" : layer_type_name_ID(op);
|
||||
auto primitive = cldnn::ctc_greedy_decoder(
|
||||
CTCGreedyDecoderLayerName,
|
||||
reorderedInputs,
|
||||
blank_index,
|
||||
ctc_merge_repeated,
|
||||
CldnnTensorFromIEDims(op->get_output_shape(0)));
|
||||
|
||||
// clDNN primitive supports only i32 as output data type
|
||||
primitive.output_data_type = DataTypeFromPrecision(ngraph::element::i32);
|
||||
|
||||
// The last entry of reorderedInputs is the "_md_write" id pushed above.
if (num_output == 2) {
|
||||
primitive.second_output = reorderedInputs.back();
|
||||
}
|
||||
|
||||
p.AddPrimitive(primitive);
|
||||
p.AddPrimitiveToProfiler(op);
|
||||
|
||||
// Second mutable_data primitive, named "<id>.1", over the same memory; it lists the decoder
// primitive as its input.
if (num_output == 2) {
|
||||
cldnn::primitive_id ctc_gd_mutable_id_r = layer_type_name_ID(op) + ".1";
|
||||
auto ctc_gd_mutable_prim_r = cldnn::mutable_data(ctc_gd_mutable_id_r, { CTCGreedyDecoderLayerName }, shared_memory[0]);
|
||||
p.primitivesToIRLayersMap[ctc_gd_mutable_id_r] = { op->get_friendly_name() };
|
||||
p.primitiveIDs[ctc_gd_mutable_id_r] = ctc_gd_mutable_id_r;
|
||||
p.AddPrimitive(ctc_gd_mutable_prim_r);
|
||||
}
|
||||
|
||||
p.AddPrimitiveToProfiler(CTCGreedyDecoderLayerName, op);
|
||||
}
|
||||
|
||||
// Factory for the v0 CTCGreedyDecoder op: forwards to CreateCommonCTCGreedyDecoderOp, taking the
// merge flag from op->get_ctc_merge_repeated().
void CreateCTCGreedyDecoderOp(Program& p, const std::shared_ptr<ngraph::op::v0::CTCGreedyDecoder>& op) {
|
||||
CreateCommonCTCGreedyDecoderOp(p, op, op->get_ctc_merge_repeated());
|
||||
}
|
||||
|
||||
// Factory for the v6 CTCGreedyDecoderSeqLen op: forwards to CreateCommonCTCGreedyDecoderOp, taking
// the merge flag from op->get_merge_repeated().
void CreateCTCGreedyDecoderSeqLenOp(Program& p, const std::shared_ptr<ngraph::op::v6::CTCGreedyDecoderSeqLen>& op) {
|
||||
CreateCommonCTCGreedyDecoderOp(p, op, op->get_merge_repeated());
|
||||
}
|
||||
|
||||
REGISTER_FACTORY_IMPL(v0, CTCGreedyDecoder);
|
||||
REGISTER_FACTORY_IMPL(v6, CTCGreedyDecoderSeqLen);
|
||||
|
||||
} // namespace CLDNNPlugin
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// Copyright (C) 2020-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
@ -10,24 +10,31 @@ using namespace LayerTestsDefinitions;
|
||||
using namespace ngraph::helpers;
|
||||
|
||||
// Parameter sets and instantiation of the GPU CTCGreedyDecoderLayerTest cases.
// NOTE(review): this is a rendered diff - netPrecisions, basicCases and the INSTANTIATE_TEST_CASE_P
// call each appear twice (pre- and post-change); only one copy of each can exist in the real file.
namespace {
|
||||
// Common params
|
||||
const std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||
InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::FP16
|
||||
};
|
||||
// Common params
|
||||
const std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||
InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::FP16
|
||||
};
|
||||
// Both settings of the merge-repeated flag are exercised.
std::vector<bool> mergeRepeated{true, false};
|
||||
|
||||
// Variant with two input shapes and an inline true/false merge flag.
const auto basicCases = ::testing::Combine(
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
|
||||
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
|
||||
::testing::Values(InferenceEngine::Layout::ANY),
|
||||
::testing::Values(InferenceEngine::Layout::ANY),
|
||||
::testing::Values(std::vector<size_t>({ 10, 1, 16 }),
|
||||
std::vector<size_t>({ 20, 2, 8 })),
|
||||
::testing::Values(true, false),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU));
|
||||
// Variant with eight input shapes (last dimension from 3 up to 128) and the shared mergeRepeated list.
const auto basicCases = ::testing::Combine(
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
|
||||
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
|
||||
::testing::Values(InferenceEngine::Layout::ANY),
|
||||
::testing::Values(InferenceEngine::Layout::ANY),
|
||||
::testing::Values(std::vector<size_t>({ 50, 3, 3 }),
|
||||
std::vector<size_t>({ 50, 3, 7 }),
|
||||
std::vector<size_t>({ 50, 3, 8 }),
|
||||
std::vector<size_t>({ 50, 3, 16 }),
|
||||
std::vector<size_t>({ 50, 3, 128 }),
|
||||
std::vector<size_t>({ 50, 3, 49 }),
|
||||
std::vector<size_t>({ 50, 3, 55 }),
|
||||
std::vector<size_t>({ 1, 1, 16 })),
|
||||
::testing::ValuesIn(mergeRepeated),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_CTC_Greedy_decoder_Basic, CTCGreedyDecoderLayerTest,
|
||||
basicCases,
|
||||
CTCGreedyDecoderLayerTest::getTestCaseName);
|
||||
INSTANTIATE_TEST_CASE_P(smoke_CtcGreedyDecoderBasic, CTCGreedyDecoderLayerTest,
|
||||
basicCases,
|
||||
CTCGreedyDecoderLayerTest::getTestCaseName);
|
||||
} // namespace
|
||||
|
@ -0,0 +1,48 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
#include "single_layer_tests/ctc_greedy_decoder_seq_len.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
using namespace ngraph::helpers;
|
||||
|
||||
// Parameter sets and instantiation of the GPU CTCGreedyDecoderSeqLenLayerTest cases.
namespace {
|
||||
|
||||
// Input shapes for the first test set.
// NOTE(review): presumably laid out as {batch, time, classes} - confirm against the test class.
std::vector<std::vector<size_t>> inputShape{{1, 1, 1}, {1, 6, 10}, {3, 3, 16}, {5, 3, 55}};
|
||||
|
||||
const std::vector<InferenceEngine::Precision> probPrecisions = {
|
||||
InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::FP16
|
||||
};
|
||||
const std::vector<InferenceEngine::Precision> idxPrecisions = {
|
||||
InferenceEngine::Precision::I32,
|
||||
InferenceEngine::Precision::I64
|
||||
};
|
||||
|
||||
// Both settings of the merge-repeated flag are exercised.
std::vector<bool> mergeRepeated{true, false};
|
||||
|
||||
// First set: every shape x both precision lists x merge flag, with the fourth parameter fixed at 0.
// NOTE(review): the fourth parameter is presumably the blank index - confirm against the test class.
const auto basicCases = ::testing::Combine(
|
||||
::testing::ValuesIn(inputShape),
|
||||
::testing::ValuesIn(probPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::Values(0),
|
||||
::testing::ValuesIn(mergeRepeated),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_set1, CTCGreedyDecoderSeqLenLayerTest,
|
||||
basicCases,
|
||||
CTCGreedyDecoderSeqLenLayerTest::getTestCaseName);
|
||||
|
||||
// Second set: a single {2, 8, 11} shape with the fourth parameter swept over 0, 5 and 10.
INSTANTIATE_TEST_CASE_P(smoke_set2, CTCGreedyDecoderSeqLenLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::Values(std::vector<size_t>{2, 8, 11}),
|
||||
::testing::ValuesIn(probPrecisions),
|
||||
::testing::ValuesIn(idxPrecisions),
|
||||
::testing::ValuesIn(std::vector<int>{0, 5, 10}),
|
||||
::testing::ValuesIn(mergeRepeated),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
CTCGreedyDecoderSeqLenLayerTest::getTestCaseName);
|
||||
} // namespace
|
@ -1,5 +1,5 @@
|
||||
/*
|
||||
// Copyright (c) 2020 Intel Corporation
|
||||
// Copyright (c) 2020-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -32,24 +32,24 @@ struct ctc_greedy_decoder : public primitive_base<ctc_greedy_decoder> {
|
||||
|
||||
/// @brief Constructs ctc_greedy_decoder primitive.
|
||||
/// @param id This primitive id.
|
||||
/// @param input Input primitive id.
|
||||
/// @param input sequence_indicators primitive id.
|
||||
/// @param ctc_merge_repeated int
|
||||
/// @param input Input primitive ids, in order: input, sequence_indicators, second_output (optional).
|
||||
/// @param blank_index Specifies the class index to use for the blank class.
|
||||
/// @param ctc_merge_repeated Flag for merging repeated labels during the CTC calculation
/// @param output_tensor Tensor stored in the primitive's output_tensor member.
/// @param output_padding Padding forwarded to primitive_base (defaults to padding()).
/// NOTE(review): this is a rendered diff - two parameter lists and two initializer lists follow
/// the single constructor header below; only one of each exists in the real header.
|
||||
ctc_greedy_decoder(const primitive_id& id,
|
||||
const primitive_id& input,
|
||||
const primitive_id& sequence_indicators,
|
||||
const bool ctc_merge_repeated,
|
||||
const data_types data_type,
|
||||
const tensor output_tensor,
|
||||
const padding& output_padding = padding())
|
||||
: primitive_base(id, { input, sequence_indicators },
|
||||
output_padding, optional_data_type{ data_type }),
|
||||
ctc_merge_repeated(ctc_merge_repeated),
|
||||
output_tensor(output_tensor)
|
||||
{}
|
||||
const std::vector<primitive_id>& input,
|
||||
const uint32_t blank_index,
|
||||
const bool ctc_merge_repeated,
|
||||
const tensor output_tensor,
|
||||
const padding& output_padding = padding())
|
||||
: primitive_base(id, input, output_padding)
|
||||
, blank_index(blank_index)
|
||||
, ctc_merge_repeated(ctc_merge_repeated)
|
||||
, output_tensor(output_tensor) {}
|
||||
|
||||
uint32_t blank_index;
|
||||
bool ctc_merge_repeated;
|
||||
tensor output_tensor;
|
||||
primitive_id second_output;
|
||||
};
|
||||
/// @}
|
||||
/// @}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2020 Intel Corporation
|
||||
// Copyright (c) 2020-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -24,11 +24,23 @@ JitConstants CTCGreedyDecoderKernelBase::GetJitConstants(const ctc_greedy_decode
|
||||
|
||||
jit.AddConstants({
|
||||
MakeJitConstant("ctc_merge_repeated_", params.merge_repeated),
|
||||
MakeJitConstant("T_", inp.Batch().v),
|
||||
MakeJitConstant("N_", inp.Feature().v),
|
||||
MakeJitConstant("blank_index_", params.blank_index),
|
||||
MakeJitConstant("C_", inp.Y().v)
|
||||
});
|
||||
|
||||
if (params.outputs_num == 2) {
|
||||
jit.AddConstants({
|
||||
MakeJitConstant("SECOND_OUTPUT_EXIST", 1),
|
||||
MakeJitConstant("N_", inp.Batch().v),
|
||||
MakeJitConstant("T_", inp.Feature().v)
|
||||
});
|
||||
} else {
|
||||
jit.AddConstants({
|
||||
MakeJitConstant("T_", inp.Batch().v),
|
||||
MakeJitConstant("N_", inp.Feature().v)
|
||||
});
|
||||
};
|
||||
|
||||
return jit;
|
||||
}
|
||||
|
||||
@ -71,6 +83,10 @@ KernelsData CTCGreedyDecoderKernelBase::GetCommonKernelsData(const Params& param
|
||||
2, // input and sequence indicators
|
||||
GetFusedPrimitiveInputsCount(params));
|
||||
|
||||
if (orgParams.outputs_num == 2) {
|
||||
kernel.arguments.push_back({ArgumentDescriptor::Types::INPUT, 2});
|
||||
}
|
||||
|
||||
return {kd};
|
||||
}
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2020 Intel Corporation
|
||||
// Copyright (c) 2020-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -25,6 +25,8 @@ struct ctc_greedy_decoder_params : public base_params {
|
||||
ctc_greedy_decoder_params() : base_params(KernelType::CTC_GREEDY_DECODER) {}
|
||||
|
||||
bool merge_repeated = true;
|
||||
// Class index treated as the blank class by the kernel. Zero-initialised so that a
// default-constructed params object never carries an indeterminate value, matching the
// in-class defaults on the neighbouring members.
uint32_t blank_index = 0;
|
||||
uint32_t outputs_num = 1;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2020 Intel Corporation
|
||||
// Copyright (c) 2020-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -21,9 +21,13 @@ ParamsKey CTCGreedyDecoderKernelRef::GetSupportedKey() const {
|
||||
ParamsKey k;
|
||||
k.EnableInputDataType(Datatype::F16);
|
||||
k.EnableInputDataType(Datatype::F32);
|
||||
k.EnableInputDataType(Datatype::INT32);
|
||||
k.EnableInputDataType(Datatype::INT64);
|
||||
|
||||
k.EnableOutputDataType(Datatype::F16);
|
||||
k.EnableOutputDataType(Datatype::F32);
|
||||
k.EnableOutputDataType(Datatype::INT32);
|
||||
k.EnableOutputDataType(Datatype::INT64);
|
||||
|
||||
k.EnableInputLayout(DataLayout::bfyx);
|
||||
k.EnableOutputLayout(DataLayout::bfyx);
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// Copyright (C) 2020-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -13,10 +13,13 @@
|
||||
// limitations under the License.
|
||||
#include "include/include_all.cl"
|
||||
|
||||
KERNEL(ctc_greedy_decoder_ref)(
|
||||
const __global INPUT0_TYPE* probabilities,
|
||||
const __global INPUT1_TYPE* sequence_indicators,
|
||||
__global OUTPUT_TYPE* output_sequences)
|
||||
KERNEL(ctc_greedy_decoder_ref)(const __global INPUT0_TYPE* probabilities
|
||||
,const __global INPUT1_TYPE* sequence_indicators
|
||||
,__global OUTPUT_TYPE* output_sequences
|
||||
#ifdef SECOND_OUTPUT_EXIST
|
||||
,__global INPUT2_TYPE* second_output
|
||||
#endif
|
||||
)
|
||||
{
|
||||
// Fill output_sequences with -1
|
||||
for (int ii = 0; ii < T_ * N_; ii++) {
|
||||
@ -27,11 +30,19 @@ KERNEL(ctc_greedy_decoder_ref)(
|
||||
int prev_class_idx = -1;
|
||||
int output_index = n * T_;
|
||||
|
||||
for (int t = 0; /* check at end */; ++t) {
|
||||
for (int t = 0; t < T_; ++t) {
|
||||
// get maximum probability and its index
|
||||
#ifdef SECOND_OUTPUT_EXIST
|
||||
if (t >= sequence_indicators[n]) break;
|
||||
#else
|
||||
if (sequence_indicators[t * N_ + n] == 0) break;
|
||||
#endif
|
||||
int max_class_idx = 0;
|
||||
|
||||
#ifdef SECOND_OUTPUT_EXIST
|
||||
const __global INPUT0_TYPE* probs = probabilities + n * C_ * T_ + t * C_;
|
||||
#else
|
||||
const __global INPUT0_TYPE* probs = probabilities + t * C_ * N_ + n * C_;
|
||||
#endif
|
||||
INPUT0_TYPE max_prob = probs[0];
|
||||
++probs;
|
||||
|
||||
@ -42,15 +53,15 @@ KERNEL(ctc_greedy_decoder_ref)(
|
||||
}
|
||||
}
|
||||
|
||||
if (max_class_idx != C_ - 1 && !(ctc_merge_repeated_ && max_class_idx == prev_class_idx)) {
|
||||
if (max_class_idx != blank_index_ && !(ctc_merge_repeated_ && max_class_idx == prev_class_idx)) {
|
||||
output_sequences[output_index] = max_class_idx;
|
||||
output_index++;
|
||||
}
|
||||
|
||||
prev_class_idx = max_class_idx;
|
||||
if (t + 1 == T_ || sequence_indicators[(t + 1) * N_ + n] == 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
#ifdef SECOND_OUTPUT_EXIST
|
||||
second_output[n] = output_index - n * T_;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
/*
|
||||
// Copyright (c) 2020 Intel Corporation
|
||||
// Copyright (c) 2020-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -27,16 +27,17 @@ primitive_type_id ctc_greedy_decoder::type_id() {
|
||||
|
||||
// Computes the output layout of a ctc_greedy_decoder node:
//  - data type: the primitive's explicit output_data_type when set, otherwise the input's type;
//  - format:    copied from the (non-padded) layout of input 0;
//  - tensor:    the primitive's output_tensor.
// NOTE(review): this is a rendered diff - the pre-change and post-change bodies are interleaved
// below (two `output_type` definitions, two `return` statements); both compute the same result.
layout ctc_greedy_decoder_inst::calc_output_layout(ctc_greedy_decoder_node const& node) {
|
||||
auto input_node_layout = node.input().get_non_padded_output_layout();
|
||||
auto output_type = node.get_primitive()->output_data_type ?
|
||||
*node.get_primitive()->output_data_type : input_node_layout.data_type;
|
||||
auto output_tensor = node.get_primitive()->output_tensor;
|
||||
return layout(output_type, input_node_layout.format, output_tensor);
|
||||
auto prim = node.get_primitive();
|
||||
auto output_type = prim->output_data_type ? *prim->output_data_type : input_node_layout.data_type;
|
||||
|
||||
return layout(output_type, input_node_layout.format, prim->output_tensor);
|
||||
}
|
||||
|
||||
std::string ctc_greedy_decoder_inst::to_string(ctc_greedy_decoder_node const& node) {
|
||||
auto node_info = node.desc_to_json();
|
||||
auto desc = node.get_primitive();
|
||||
auto ctc_mr = desc->ctc_merge_repeated;
|
||||
auto blank_index = desc->blank_index;
|
||||
auto& input = node.input();
|
||||
auto& seq_ind = node.seq_indicators();
|
||||
|
||||
@ -46,6 +47,7 @@ std::string ctc_greedy_decoder_inst::to_string(ctc_greedy_decoder_node const& no
|
||||
ctc_gd_info.add("input id", input.id());
|
||||
ctc_gd_info.add("seq inidicatior id", seq_ind.id());
|
||||
ctc_gd_info.add("ctc_mr", ctc_mr);
|
||||
ctc_gd_info.add("blank_index", blank_index);
|
||||
|
||||
node_info->add("ctc_greedy_decoder info", ctc_gd_info);
|
||||
node_info->dump(primitive_description);
|
||||
|
@ -1,5 +1,5 @@
|
||||
/*
|
||||
// Copyright (c) 2020 Intel Corporation
|
||||
// Copyright (c) 2020-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -37,10 +37,18 @@ public:
|
||||
static primitive_impl* create(const ctc_greedy_decoder_node& arg) {
|
||||
auto ctc_gd_params = get_default_params<kernel_selector::ctc_greedy_decoder_params>(arg);
|
||||
auto ctc_gd_optional_params = get_default_optional_params<kernel_selector::ctc_greedy_decoder_optional_params>(arg.get_program());
|
||||
auto prim = arg.get_primitive();
|
||||
|
||||
ctc_gd_params.inputs.push_back(
|
||||
convert_data_tensor(arg.seq_indicators().get_output_layout()));
|
||||
ctc_gd_params.merge_repeated = arg.get_primitive()->ctc_merge_repeated;
|
||||
ctc_gd_params.merge_repeated = prim->ctc_merge_repeated;
|
||||
ctc_gd_params.blank_index = prim->blank_index;
|
||||
ctc_gd_params.outputs_num = arg.has_second_output() ? 2 : 1;
|
||||
|
||||
if (ctc_gd_params.outputs_num == 2) {
|
||||
ctc_gd_params.inputs.push_back(
|
||||
convert_data_tensor(arg.second_output().get_output_layout()));
|
||||
}
|
||||
|
||||
auto& kernel_selector = kernel_selector::ctc_greedy_decoder_kernel_selector::Instance();
|
||||
auto best_kernels = kernel_selector.GetBestKernels(
|
||||
@ -62,6 +70,8 @@ namespace detail {
|
||||
// Registers ctc_greedy_decoder_gpu::create as the OCL implementation of ctc_greedy_decoder for the
// bfyx format with f32, f16, i32 and i64 data types.
attach_ctc_greedy_decoder_gpu::attach_ctc_greedy_decoder_gpu() {
|
||||
implementation_map<ctc_greedy_decoder>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), ctc_greedy_decoder_gpu::create);
|
||||
implementation_map<ctc_greedy_decoder>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), ctc_greedy_decoder_gpu::create);
|
||||
implementation_map<ctc_greedy_decoder>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfyx), ctc_greedy_decoder_gpu::create);
|
||||
implementation_map<ctc_greedy_decoder>::add(std::make_tuple(engine_types::ocl, data_types::i64, format::bfyx), ctc_greedy_decoder_gpu::create);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
@ -168,7 +168,7 @@ void remove_redundant_reorders::run(program_impl& p) {
|
||||
|
||||
bool no_output_optimization = remove_output_reorders ?
|
||||
r_node.is_output() && (r_node.get_dependency(0).is_output() || r_node.get_dependency(0).is_type<input_layout>() ||
|
||||
r_node.get_dependency(0).can_be_optimized()) : r_node.is_output();
|
||||
r_node.get_dependency(0).can_be_optimized() || r_node.get_dependency(0).get_users().size() != 1) : r_node.is_output();
|
||||
|
||||
if (r_node.has_mean() ||
|
||||
!r_node.get_primitive()->subtract_per_feature.empty() ||
|
||||
|
@ -1,5 +1,5 @@
|
||||
/*
|
||||
// Copyright (c) 2020 Intel Corporation
|
||||
// Copyright (c) 2020-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -31,6 +31,9 @@ public:
|
||||
|
||||
// Returns dependency 0 of this node.
program_node& input() const { return get_dependency(0); }
|
||||
// Returns dependency 1 of this node.
program_node& seq_indicators() const { return get_dependency(1); }
|
||||
|
||||
// True when the primitive's second_output id is non-empty.
bool has_second_output() const { return !get_primitive()->second_output.empty(); }
|
||||
// Returns dependency 2 of this node.
// NOTE(review): presumably only valid when has_second_output() is true - verify callers check first.
program_node& second_output() const { return get_dependency(2); }
|
||||
};
|
||||
|
||||
using ctc_greedy_decoder_node = typed_program_node<ctc_greedy_decoder>;
|
||||
|
Loading…
Reference in New Issue
Block a user