[CPU] CTCLoss operation implementation. (#1482)
parent 54a24b0e40
commit 43ec4a5695
@@ -51,6 +51,7 @@ set(LAYERS
     ${CMAKE_CURRENT_SOURCE_DIR}/nodes/broadcast.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/nodes/convert.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/nodes/ctc_greedy.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/nodes/ctc_loss.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/nodes/depth_to_space.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/nodes/detectionoutput.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/nodes/detectionoutput_onnx.cpp
inference-engine/src/mkldnn_plugin/nodes/ctc_loss.cpp (new file, 341 lines)
@@ -0,0 +1,341 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "base.hpp"
#include "ie_parallel.hpp"

#include <cmath>


namespace InferenceEngine {
namespace Extensions {
namespace Cpu {

class CTCLossImpl : public ExtLayerBase {
public:
    explicit CTCLossImpl(const CNNLayer* layer) {
        _logPrefix = std::string("CTCLoss layer with name '") + layer->name + "'";

        if (layer->insData.size() != 4 && layer->insData.size() != 5)
            THROW_IE_EXCEPTION << _logPrefix << " has invalid inputs number.";

        _ctcMergeRepeated = layer->GetParamAsBool("ctc_merge_repeated", true);
        _preprocessCollapseRepeated = layer->GetParamAsBool("preprocess_collapse_repeated", false);
        _unique = layer->GetParamAsBool("unique", false);

        auto logitsData = layer->insData[0].lock();
        if (logitsData == nullptr)
            THROW_IE_EXCEPTION << _logPrefix << " has nullable logits data";
        auto logitsPrecision = logitsData->getTensorDesc().getPrecision();
        if (logitsPrecision == Precision::BF16)
            logitsPrecision = Precision::FP32;

        LayerConfig config;
        config.inConfs.resize(layer->insData.size());
        config.inConfs[0].desc = TensorDesc(logitsPrecision,
            logitsData->getTensorDesc().getDims(),
            TensorDesc::getLayoutByDims(logitsData->getTensorDesc().getDims()));
        auto intPrecision = Precision::I32;
        for (int i = 1; i < layer->insData.size(); i++) {
            auto data = layer->insData[i].lock();
            if (data == nullptr)
                THROW_IE_EXCEPTION << _logPrefix << " has nullable input data at " << i;
            config.inConfs[i].desc = TensorDesc(intPrecision,
                data->getTensorDesc().getDims(),
                TensorDesc::getLayoutByDims(data->getTensorDesc().getDims()));
        }

        DataConfig outConfig;
        auto& outDims = layer->outData[0]->getTensorDesc().getDims();
        outConfig.desc = TensorDesc(logitsPrecision,
            outDims,
            TensorDesc::getLayoutByDims(outDims));
        config.outConfs.push_back(outConfig);
        config.dynBatchSupport = false;

        confs.push_back(config);
    }

    StatusCode execute(std::vector<Blob::Ptr>& inputs,
                       std::vector<Blob::Ptr>& outputs,
                       ResponseDesc *resp) noexcept override {
        const float* logits = inputs[0]->cbuffer().as<const float*>() +
            inputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
        const int* logitsLength = inputs[1]->cbuffer().as<const int*>() +
            inputs[1]->getTensorDesc().getBlockingDesc().getOffsetPadding();
        const int* labels = inputs[2]->cbuffer().as<const int*>() +
            inputs[2]->getTensorDesc().getBlockingDesc().getOffsetPadding();
        const int* labelsLength = inputs[3]->cbuffer().as<const int*>() +
            inputs[3]->getTensorDesc().getBlockingDesc().getOffsetPadding();
        float* dstData = outputs[0]->buffer().as<float*>() +
            outputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();

        const auto& logitsShape = inputs[0]->getTensorDesc().getDims();
        const auto batchNum = logitsShape[0];
        const auto maxTime = logitsShape[1];
        const auto classesNum = logitsShape[2];

        int blankIndex = classesNum - 1;
        if (inputs.size() > 4) {
            blankIndex = inputs[4]->cbuffer().as<const int*>()[0];
        }

        std::vector<int> targetD(maxTime);

        const size_t TC = maxTime * classesNum;

        for (size_t b = 0; b < batchNum; b++) {
            const int actualLogitLen = logitsLength[b];
            const int actualTargetLen = labelsLength[b];
            if (actualLogitLen < 0 || actualTargetLen < 0 || actualLogitLen > maxTime || actualTargetLen > maxTime
                    || actualTargetLen > actualLogitLen) {
                std::string errorMsg = _logPrefix + ". Logit or label length cannot be greater than max sequence length. "
                    + "Also a label length cannot be greater than a logit length"
                    + " and both cannot be negative.\nMaxSeqLen: "
                    + std::to_string(maxTime) + "; Logit len: " + std::to_string(actualLogitLen)
                    + "; Label len: " + std::to_string(actualTargetLen);
                errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
                return GENERAL_ERROR;
            }

            const int* target = &labels[b * maxTime];
            // Decoding target: merge repeated characters if preprocess_collapse_repeated == True,
            // find unique elements if unique == True
            size_t decodedTargetLen = 0lu;
            if (_unique) {
                std::unordered_set<int> uniqVals;
                for (size_t t = 0lu; t < actualTargetLen; t++) {
                    if (uniqVals.find(target[t]) != uniqVals.end()) {
                        continue;
                    }
                    uniqVals.insert(target[t]);
                    targetD[decodedTargetLen++] = target[t];
                }
            } else if (_preprocessCollapseRepeated) {
                int prevValue = target[0];
                targetD[decodedTargetLen++] = target[0];
                for (size_t t = 1lu; t < actualTargetLen; t++) {
                    if (target[t] == prevValue) {
                        continue;
                    }
                    targetD[decodedTargetLen++] = target[t];
                    prevValue = target[t];
                }
            } else {
                std::copy(target, target + actualTargetLen, targetD.data());
                decodedTargetLen = actualTargetLen;
            }

            const size_t BTC = b * TC;

            std::vector<std::unordered_map<size_t, float>> logProbabilities(actualLogitLen);
            float logProb = 0.f, kExp = 0.f;
            for (size_t t = 0; t < actualLogitLen; t++) {
                kExp = 0.f;
                const size_t btcT = BTC + classesNum * t;
                for (size_t c = 0; c < classesNum; c++) {
                    kExp += std::exp(logits[btcT + c]);
                }
                for (size_t s = 0; s < decodedTargetLen; s++) {
                    logProb = logits[btcT + targetD[s]] - std::log(kExp);
                    logProbabilities[t].insert({targetD[s], logProb});
                }
                logProb = logits[btcT + blankIndex] - std::log(kExp);
                logProbabilities[t].insert({blankIndex, logProb});
            }

            const auto float_inf = std::numeric_limits<float>::infinity();
            size_t work_amount = actualLogitLen - decodedTargetLen + 1lu;
            std::vector<float> sumPerThread(parallel_get_max_threads(), -float_inf);

            // Looking for aligned paths
            auto thread_body = [&](const int ithr, const int nthr) {
                size_t start0(0lu), end0(0lu);
                splitter(work_amount, nthr, ithr, start0, end0);
                if (start0 >= end0)
                    return;
                if (ithr >= sumPerThread.size())
                    sumPerThread.push_back(-float_inf);

                std::function<void(size_t, size_t, size_t, float)> findPaths =
                        [&](size_t targetIdx, size_t start, size_t end, float prevLogProb) {
                    if (end > actualLogitLen) {
                        if (sumPerThread[ithr] == -float_inf) {
                            sumPerThread[ithr] = prevLogProb;
                        } else if (prevLogProb != -float_inf) {
                            if (sumPerThread[ithr] > prevLogProb)
                                sumPerThread[ithr] = sumPerThread[ithr] + std::log1pf(std::exp(prevLogProb - sumPerThread[ithr]));
                            else
                                sumPerThread[ithr] = prevLogProb + std::log1pf(std::exp(sumPerThread[ithr] - prevLogProb));
                        }
                        return;
                    }

                    size_t nextIdx = targetIdx + 1;
                    int64_t st64 = start;
                    float newLogProb = prevLogProb;
                    if (!_ctcMergeRepeated) {
                        for (size_t pos = start; pos < end; pos++) {
                            newLogProb = prevLogProb;
                            for (size_t bl = start; bl < pos; bl++) {
                                auto lnProbIt = logProbabilities[bl].find(blankIndex);
                                if (lnProbIt != logProbabilities[bl].end())
                                    newLogProb += lnProbIt->second;
                            }
                            auto lnProbIt = logProbabilities[pos].find(targetD[targetIdx]);
                            if (lnProbIt != logProbabilities[pos].end())
                                newLogProb += lnProbIt->second;
                            if (end == actualLogitLen) {
                                for (int64_t ble = pos + 1; ble < actualLogitLen; ble++) {
                                    auto lnProbIt = logProbabilities[ble].find(blankIndex);
                                    if (lnProbIt != logProbabilities[ble].end())
                                        newLogProb += lnProbIt->second;
                                }
                            }
                            findPaths(nextIdx, pos + 1, end + 1, newLogProb);
                        }
                    } else {
                        for (size_t pos = start; pos < end; pos++) {
                            newLogProb = prevLogProb;
                            size_t next_start = pos + 1;
                            for (size_t bl = start; bl < pos; bl++) {
                                auto lnProbIt = logProbabilities[bl].find(blankIndex);
                                if (lnProbIt != logProbabilities[bl].end())
                                    newLogProb += lnProbIt->second;
                            }
                            if (end == actualLogitLen) {
                                for (int64_t ble = pos + 1; ble < actualLogitLen; ble++) {
                                    auto lnProbIt = logProbabilities[ble].find(blankIndex);
                                    if (lnProbIt != logProbabilities[ble].end())
                                        newLogProb += lnProbIt->second;
                                }
                            }
                            if (targetIdx < decodedTargetLen - 1
                                    && targetD[targetIdx] == targetD[targetIdx + 1]) {
                                auto lnProbIt = logProbabilities[next_start].find(blankIndex);
                                if (lnProbIt != logProbabilities[next_start].end())
                                    newLogProb += lnProbIt->second;
                                next_start++;
                            }
                            for (int64_t bl = pos; bl >= st64; bl--) {
                                newLogProb += logProbabilities[bl].find(targetD[targetIdx])->second;
                                findPaths(nextIdx, next_start, end + 1, newLogProb);
                                if (bl > 0) {
                                    auto lnProbIt = logProbabilities[bl - 1].find(blankIndex);
                                    if (lnProbIt != logProbabilities[bl - 1].end())
                                        newLogProb -= lnProbIt->second;
                                }
                            }
                        }
                    }
                }; // findPaths

                // First target symbol
                int64_t st64 = start0;
                float newLogProb = 0.f;
                if (!_ctcMergeRepeated) {
                    for (size_t pos = start0; pos < end0; pos++) {
                        newLogProb = 0.f;
                        for (size_t bl = 0; bl < pos; bl++) {
                            auto lnProbIt = logProbabilities[bl].find(blankIndex);
                            if (lnProbIt != logProbabilities[bl].end())
                                newLogProb += lnProbIt->second;
                        }
                        auto lnProbIt = logProbabilities[pos].find(targetD[0]);
                        if (lnProbIt != logProbabilities[pos].end())
                            newLogProb += lnProbIt->second;
                        if (work_amount == actualLogitLen) {
                            for (int64_t ble = pos + 1; ble < actualLogitLen; ble++) {
                                auto lnProbIt = logProbabilities[ble].find(blankIndex);
                                if (lnProbIt != logProbabilities[ble].end())
                                    newLogProb += lnProbIt->second;
                            }
                        }
                        if (decodedTargetLen > 1) {
                            findPaths(1, pos + 1, work_amount + 1, newLogProb);
                        } else {
                            if (sumPerThread[ithr] == -float_inf)
                                sumPerThread[ithr] = newLogProb;
                            else if (newLogProb != -float_inf)
                                sumPerThread[ithr] = sumPerThread[ithr] + std::log1pf(std::exp(newLogProb - sumPerThread[ithr]));
                        }
                    }
                } else {
                    for (size_t pos = start0; pos < end0; pos++) {
                        newLogProb = 0.f;
                        size_t next_start = pos + 1;
                        for (size_t bl = 0; bl < pos; bl++) {
                            auto lnProbIt = logProbabilities[bl].find(blankIndex);
                            if (lnProbIt != logProbabilities[bl].end())
                                newLogProb += lnProbIt->second;
                        }
                        if (work_amount == actualLogitLen) {
                            for (int64_t ble = pos + 1; ble < actualLogitLen; ble++) {
                                auto lnProbIt = logProbabilities[ble].find(blankIndex);
                                if (lnProbIt != logProbabilities[ble].end())
                                    newLogProb += lnProbIt->second;
                            }
                        }
                        if (decodedTargetLen > 1
                                && targetD[0] == targetD[1]) {
                            auto lnProbIt = logProbabilities[next_start].find(blankIndex);
                            if (lnProbIt != logProbabilities[next_start].end())
                                newLogProb += lnProbIt->second;
                            next_start++;
                        }
                        for (int64_t bl = pos; bl >= 0; bl--) {
                            auto lnProbIt = logProbabilities[bl].find(targetD[0]);
                            if (lnProbIt != logProbabilities[bl].end())
                                newLogProb += lnProbIt->second;
                            if (decodedTargetLen > 1) {
                                findPaths(1, next_start, work_amount + 1, newLogProb);
                            } else {
                                if (sumPerThread[ithr] == -float_inf)
                                    sumPerThread[ithr] = newLogProb;
                                else if (newLogProb != -float_inf)
                                    sumPerThread[ithr] = sumPerThread[ithr] + std::log1pf(std::exp(newLogProb - sumPerThread[ithr]));
                            }
                            if (bl > 0) {
                                auto lnProbIt = logProbabilities[bl - 1].find(blankIndex);
                                if (lnProbIt != logProbabilities[bl - 1].end())
                                    newLogProb -= lnProbIt->second;
                            }
                        }
                    }
                }
            }; // thread_body

            parallel_nt(0, thread_body);

            float res = -float_inf;

            for (auto sum : sumPerThread) {
                if (res == -float_inf) {
                    res = sum;
                } else if (sum != -float_inf) {
                    if (res > sum)
                        res = res + std::log1pf(std::exp(sum - res));
                    else
                        res = sum + std::log1pf(std::exp(res - sum));
                }
            }

            dstData[b] = -res;
        } // for (size_t b = 0; b < batchNum; b++)

        return OK;
    } // execute

protected:
    bool _ctcMergeRepeated;
    bool _preprocessCollapseRepeated;
    bool _unique;

    std::string _logPrefix;
};

REG_FACTORY_FOR(CTCLossImpl, CTCLoss);

}  // namespace Cpu
}  // namespace Extensions
}  // namespace InferenceEngine
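A reading note on the kernel above (derived from the code, not part of the commit): the per-class loop computes a softmax normalizer, so each stored value is

    log p_t(c) = logits[t*classesNum + c] - log( sum_c' exp(logits[t*classesNum + c']) )

findPaths then enumerates every admissible alignment pi of the decoded target over the actualLogitLen time steps, accumulating sum_t log p_t(pi_t) in newLogProb, and merges finished paths with the numerically stable log-add

    logadd(a, b) = max(a, b) + log1p(exp(-|a - b|))

which is exactly the sumPerThread[ithr] + std::log1pf(std::exp(...)) pattern. The result dstData[b] = -res is therefore the negative log of the summed path probabilities, i.e. the CTC loss for batch item b.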
@@ -10,6 +10,7 @@
 MKLDNN_EXTENSION_NODE(EmbeddingBagOffsetsSumImpl, EmbeddingBagOffsetsSum);
 MKLDNN_EXTENSION_NODE(EmbeddingBagPackedSumImpl, EmbeddingBagPackedSum);
 MKLDNN_EXTENSION_NODE(EmbeddingSegmentsSumImpl, EmbeddingSegmentsSum);
+MKLDNN_EXTENSION_NODE(CTCLossImpl, CTCLoss);
 MKLDNN_EXTENSION_NODE(PriorBoxImpl, PriorBox);
 MKLDNN_EXTENSION_NODE(MathImpl, Abs);
 MKLDNN_EXTENSION_NODE(MathImpl, Acos);
@@ -0,0 +1,66 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>

#include "single_layer_tests/ctc_loss.hpp"

using namespace LayerTestsDefinitions;

namespace {

const std::vector<InferenceEngine::Precision> fPrecisions = {
    InferenceEngine::Precision::FP32,
    InferenceEngine::Precision::FP16
};
const std::vector<InferenceEngine::Precision> iPrecisions = {
    InferenceEngine::Precision::I32,
    InferenceEngine::Precision::I64
};

const std::vector<bool> preprocessCollapseRepeated = {true, false};
const std::vector<bool> ctcMergeRepeated = {true, false};
const std::vector<bool> unique = {true, false};

const auto ctcLossArgsSubset1 = ::testing::Combine(
    ::testing::Values(std::vector<size_t>({2, 3, 3})),                          // logits shape
    ::testing::ValuesIn(std::vector<std::vector<int>>({{2, 3}, {3, 3}})),       // logits length
    ::testing::ValuesIn(std::vector<std::vector<std::vector<int>>>(
        {{{0, 1, 0}, {1, 0, 1}}, {{0, 1, 2}, {1, 1, 1}}})),                     // labels
    ::testing::ValuesIn(std::vector<std::vector<int>>({{2, 2}, {2, 1}})),       // labels length
    ::testing::Values(2),                                                       // blank index
    ::testing::ValuesIn(preprocessCollapseRepeated),
    ::testing::ValuesIn(ctcMergeRepeated),
    ::testing::ValuesIn(unique)
);

INSTANTIATE_TEST_CASE_P(Set1, CTCLossLayerTest,
                        ::testing::Combine(
                            ctcLossArgsSubset1,
                            ::testing::ValuesIn(fPrecisions),
                            ::testing::ValuesIn(iPrecisions),
                            ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                        CTCLossLayerTest::getTestCaseName);

const auto ctcLossArgsSubset2 = ::testing::Combine(
    ::testing::Values(std::vector<size_t>({3, 6, 8})),                          // logits shape
    ::testing::ValuesIn(std::vector<std::vector<int>>({{6, 5, 6}, {5, 5, 5}})), // logits length
    ::testing::ValuesIn(std::vector<std::vector<std::vector<int>>>(
        {{{4, 1, 2, 3, 4, 5}, {5, 4, 3, 0, 1, 0}, {2, 1, 3, 1, 3, 0}},
         {{2, 1, 5, 3, 2, 6}, {3, 3, 3, 3, 3, 3}, {6, 5, 6, 5, 6, 5}}})),       // labels
    ::testing::ValuesIn(std::vector<std::vector<int>>({{4, 3, 5}, {3, 3, 5}})), // labels length
    ::testing::ValuesIn(std::vector<int>({0, 7})),                              // blank index
    ::testing::ValuesIn(preprocessCollapseRepeated),
    ::testing::ValuesIn(ctcMergeRepeated),
    ::testing::ValuesIn(unique)
);

INSTANTIATE_TEST_CASE_P(Set2, CTCLossLayerTest,
                        ::testing::Combine(
                            ctcLossArgsSubset2,
                            ::testing::ValuesIn(fPrecisions),
                            ::testing::ValuesIn(iPrecisions),
                            ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                        CTCLossLayerTest::getTestCaseName);
}  // namespace
@@ -63,7 +63,7 @@ class Eltwise4dBroadcast : public testing::WithParamInterface<eltwiseParams>,
         auto pattern1 = std::make_shared<ngraph::opset1::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{ 4 }, outFormShapes1);
         auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(params[0], pattern1, false);

-        auto constant1 = ngraph::builder::makeConstant(ngPrc, { 1, 1, 1, 12 }, {}, true);
+        auto constant1 = ngraph::builder::makeConstant<float>(ngPrc, { 1, 1, 1, 12 }, {}, true);
         auto eltwise = ngraph::builder::makeEltwise(reshape1, constant1, eltwiseType);

         std::vector<size_t> outFormShapes2 = { 1, 72 };
@@ -48,7 +48,7 @@ protected:
         auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);

         auto params = ngraph::builder::makeParams(ngPrc, { {1, 67000} });
-        auto const_mult2 = ngraph::builder::makeConstant(ngPrc, {1, 67000}, {-1.0f});
+        auto const_mult2 = ngraph::builder::makeConstant<float>(ngPrc, {1, 67000}, {-1.0f});

         auto sum = ngraph::builder::makeEltwise(params[0], const_mult2, ngraph::helpers::EltwiseTypes::MULTIPLY);
         function = std::make_shared<ngraph::Function>(sum, params, "RemovePermutationPass");
@@ -0,0 +1,43 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <string>
#include <tuple>
#include <vector>

#include "functional_test_utils/layer_test_utils.hpp"


typedef std::tuple<
    std::vector<size_t>,            // logits shape
    std::vector<int>,               // logits length
    std::vector<std::vector<int>>,  // labels
    std::vector<int>,               // labels length
    int,                            // blank index
    bool,                           // preprocessCollapseRepeated
    bool,                           // ctcMergeRepeated
    bool                            // unique
> CTCLossParamsSubset;

typedef std::tuple<
    CTCLossParamsSubset,
    InferenceEngine::Precision,     // Floating point precision
    InferenceEngine::Precision,     // Integer precision
    LayerTestsUtils::TargetDevice   // Device name
> CTCLossParams;

namespace LayerTestsDefinitions {

class CTCLossLayerTest : public testing::WithParamInterface<CTCLossParams>,
                         public LayerTestsUtils::LayerTestsCommon {
public:
    static std::string getTestCaseName(const testing::TestParamInfo<CTCLossParams> &obj);

protected:
    void SetUp() override;
};

}  // namespace LayerTestsDefinitions
@@ -0,0 +1,71 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <memory>
#include <string>
#include <tuple>
#include <vector>

#include "ngraph_functions/builders.hpp"
#include "single_layer_tests/ctc_loss.hpp"

namespace LayerTestsDefinitions {

std::string CTCLossLayerTest::getTestCaseName(const testing::TestParamInfo<CTCLossParams>& obj) {
    InferenceEngine::SizeVector logitsShapes;
    InferenceEngine::Precision fpPrecision, intPrecision;
    bool preprocessCollapseRepeated, ctcMergeRepeated, unique;
    std::vector<int> logitsLength, labelsLength;
    std::vector<std::vector<int>> labels;
    int blankIndex;
    std::string targetDevice;
    CTCLossParamsSubset ctcLossArgsSubset;
    std::tie(ctcLossArgsSubset, fpPrecision, intPrecision, targetDevice) = obj.param;
    std::tie(logitsShapes, logitsLength, labels, labelsLength, blankIndex, preprocessCollapseRepeated,
        ctcMergeRepeated, unique) = ctcLossArgsSubset;

    std::ostringstream result;
    result << "IS=" << CommonTestUtils::vec2str(logitsShapes) << "_";
    result << "LL=" << CommonTestUtils::vec2str(logitsLength) << "_";
    result << "A=" << CommonTestUtils::vec2str(labels) << "_";
    result << "AL=" << CommonTestUtils::vec2str(labelsLength) << "_";
    result << "BI=" << blankIndex << "_";
    result << "PCR=" << preprocessCollapseRepeated << "_";
    result << "CMR=" << ctcMergeRepeated << "_";
    result << "U=" << unique << "_";
    result << "PF=" << fpPrecision.name() << "_";
    result << "PI=" << intPrecision.name() << "_";
    result << "targetDevice=" << targetDevice;
    return result.str();
}

void CTCLossLayerTest::SetUp() {
    std::vector<size_t> logitsShapes;
    InferenceEngine::Precision fpPrecision, intPrecision;
    bool preprocessCollapseRepeated, ctcMergeRepeated, unique;
    std::vector<int> logitsLength, labelsLength;
    std::vector<std::vector<int>> labels;
    int blankIndex;
    CTCLossParamsSubset ctcLossArgsSubset;
    std::tie(ctcLossArgsSubset, fpPrecision, intPrecision, targetDevice) = this->GetParam();
    std::tie(logitsShapes, logitsLength, labels, labelsLength, blankIndex, preprocessCollapseRepeated,
        ctcMergeRepeated, unique) = ctcLossArgsSubset;

    auto ngFpPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(fpPrecision);
    auto ngIntPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(intPrecision);

    auto params = ngraph::builder::makeParams(ngFpPrc, {logitsShapes});
    auto paramOuts = ngraph::helpers::convert2OutputVector(
        ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(params));
    auto conv = std::dynamic_pointer_cast<ngraph::opset4::CTCLoss>(
        ngraph::builder::makeCTCLoss(paramOuts[0], logitsLength, labels, labelsLength, blankIndex,
            ngFpPrc, ngIntPrc, preprocessCollapseRepeated, ctcMergeRepeated, unique));
    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(conv)};
    function = std::make_shared<ngraph::Function>(results, params, "CTCLoss");
}

TEST_P(CTCLossLayerTest, CompareWithRefs) {
    Run();
}
}  // namespace LayerTestsDefinitions
@@ -58,16 +58,16 @@ void Basic_LSTM_S::SetUp() {
     auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(params[0], pattern1, false);

     auto reshape1_shape = reshape1->output(0).get_shape();
-    auto H_init = ngraph::builder::makeConstant(ngPrc, { batch_size, hidden_size }, {}, true);
-    auto C_init = ngraph::builder::makeConstant(ngPrc, { batch_size, hidden_size }, {}, true);
+    auto H_init = ngraph::builder::makeConstant<float>(ngPrc, { batch_size, hidden_size }, {}, true);
+    auto C_init = ngraph::builder::makeConstant<float>(ngPrc, { batch_size, hidden_size }, {}, true);

     auto H_t = std::make_shared<ngraph::opset1::Parameter>(ngPrc, ngraph::Shape{ batch_size, hidden_size });
     auto C_t = std::make_shared<ngraph::opset1::Parameter>(ngPrc, ngraph::Shape{ batch_size, hidden_size });

     //Body
     auto X = std::make_shared<ngraph::opset1::Parameter>(ngPrc, ngraph::Shape{ batch_size, 1, reshape1_shape[2] });
-    auto weightsNode = ngraph::builder::makeConstant(ngPrc, { 4 * hidden_size, reshape1_shape[2] }, {}, true);
-    auto reccurrenceWeightsNode = ngraph::builder::makeConstant(ngPrc, { 4 * hidden_size, hidden_size }, {}, true);
+    auto weightsNode = ngraph::builder::makeConstant<float>(ngPrc, { 4 * hidden_size, reshape1_shape[2] }, {}, true);
+    auto reccurrenceWeightsNode = ngraph::builder::makeConstant<float>(ngPrc, { 4 * hidden_size, hidden_size }, {}, true);

     //lstm [1, 10], [1, 118], [1, 118] -> [1, 118], [1, 118]
     outFormShapes1 = { batch_size, reshape1_shape[2] };
@@ -138,11 +138,11 @@ std::shared_ptr<ngraph::Function> Basic_LSTM_S::CreateGraphWithUnrolledTI() {
     ngraph::Output<ngraph::Node> H[iterations + 1];
     ngraph::Output<ngraph::Node> C[iterations + 1];
     std::shared_ptr<ngraph::opset1::LSTMCell> lstm[iterations];
-    H[0] = ngraph::builder::makeConstant(ngPrc, { batch_size, hidden_size }, {}, true);
-    C[0] = ngraph::builder::makeConstant(ngPrc, { batch_size, hidden_size }, {}, true);
+    H[0] = ngraph::builder::makeConstant<float>(ngPrc, { batch_size, hidden_size }, {}, true);
+    C[0] = ngraph::builder::makeConstant<float>(ngPrc, { batch_size, hidden_size }, {}, true);
     auto reshape1_shape = reshape1->output(0).get_shape();
-    auto weightsNode = ngraph::builder::makeConstant(ngPrc, { 4 * hidden_size, reshape1_shape[2] }, {}, true);
-    auto reccurrenceWeightsNode = ngraph::builder::makeConstant(ngPrc, { 4 * hidden_size, hidden_size }, {}, true);
+    auto weightsNode = ngraph::builder::makeConstant<float>(ngPrc, { 4 * hidden_size, reshape1_shape[2] }, {}, true);
+    auto reccurrenceWeightsNode = ngraph::builder::makeConstant<float>(ngPrc, { 4 * hidden_size, hidden_size }, {}, true);

     outFormShapes1 = { batch_size, reshape1_shape[2] };
     auto constantX = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{ 2 }, outFormShapes1);
@@ -51,7 +51,7 @@ void ConcatQuantization::SetUp() {
     std::vector<size_t> outFormShapes2 = { 1, 160 };
     auto pattern2 = std::make_shared<ngraph::opset1::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{ 2 }, outFormShapes2);
     auto reshape2 = std::make_shared<ngraph::opset1::Reshape>(tanh, pattern2, false);
-    auto scale = ngraph::builder::makeConstant(ngPrc, outFormShapes2, {}, true);
+    auto scale = ngraph::builder::makeConstant<float>(ngPrc, outFormShapes2, {}, true);
     //For ngraph::op::ScaleShift: Cannot cast ngraph node ScaleShift to CNNLayer!
     auto scale_shift = std::make_shared<ngraph::opset1::Multiply>(reshape2, scale);

@@ -43,8 +43,8 @@ void ConvMultiply::SetUp() {

     ngraph::Shape strides(spatial_dims, 1);
     std::vector<ptrdiff_t> pad_begin(spatial_dims, 0), pad_end(spatial_dims, 0);
-    auto weights = ngraph::builder::makeConstant(precision, weights_shape, {}, true);
-    auto mul_const = ngraph::builder::makeConstant(precision, const_shape, {}, true);
+    auto weights = ngraph::builder::makeConstant<float>(precision, weights_shape, {}, true);
+    auto mul_const = ngraph::builder::makeConstant<float>(precision, const_shape, {}, true);
     std::shared_ptr<ngraph::Node> conv;
     if (conv_type == ngraph::opset4::Convolution::type_info) {
         conv = std::make_shared<ngraph::opset4::Convolution>(param, weights, strides, pad_begin, pad_end, strides);
@@ -22,8 +22,48 @@ ngraph::ParameterVector makeParams(const element::Type &type, const std::vector<
 ngraph::ParameterVector
 makeParams(const element::Type &type, const std::vector<std::pair<std::string, std::vector<size_t>>> &inputs);

-std::shared_ptr<ngraph::Node> makeConstant(const element::Type &type, const std::vector<size_t> &shape,
-                                           const std::vector<float> &data, bool random = false);
+template<typename T>
+std::shared_ptr<Node> makeConstant(const element::Type &type, const std::vector<size_t> &shape,
+                                   const std::vector<T> &data, bool random = false) {
+    std::shared_ptr<ngraph::Node> weightsNode;
+
+#define makeNode(TYPE) \
+        case TYPE: \
+            weightsNode = std::make_shared<ngraph::opset1::Constant>( \
+                    type, shape, \
+                    random ? NGraphFunctions::Utils::generateVector<TYPE>(ngraph::shape_size(shape)) : \
+                             NGraphFunctions::Utils::castVector<T, ngraph::helpers::nGraphTypesTrait<TYPE>::value_type >(data)); \
+            break;
+    switch (type) {
+        case ngraph::element::Type_t::bf16:
+            weightsNode = std::make_shared<ngraph::opset1::Constant>(
+                    type, shape,
+                    random ? NGraphFunctions::Utils::generateBF16Vector(ngraph::shape_size(shape)) :
+                             NGraphFunctions::Utils::castVector<T, ngraph::bfloat16>(data));
+            break;
+        case ngraph::element::Type_t::f16:
+            weightsNode = std::make_shared<ngraph::opset1::Constant>(
+                    type, shape,
+                    random ? NGraphFunctions::Utils::generateF16Vector(ngraph::shape_size(shape)) :
+                             NGraphFunctions::Utils::castVector<T, ngraph::float16>(data));
+            break;
+        makeNode(ngraph::element::Type_t::f32);
+        makeNode(ngraph::element::Type_t::f64);
+        makeNode(ngraph::element::Type_t::i8);
+        makeNode(ngraph::element::Type_t::i16);
+        makeNode(ngraph::element::Type_t::i32);
+        makeNode(ngraph::element::Type_t::i64);
+        makeNode(ngraph::element::Type_t::u8);
+        makeNode(ngraph::element::Type_t::u16);
+        makeNode(ngraph::element::Type_t::u32);
+        makeNode(ngraph::element::Type_t::u64);
+        makeNode(ngraph::element::Type_t::boolean);
+#undef makeNode
+        default:
+            throw std::runtime_error("Unhandled precision");
+    }
+    return weightsNode;
+}

 std::shared_ptr<ngraph::Node> makeInputLayer(const element::Type& type, ngraph::helpers::InputLayerType inputType,
                                              const std::vector<size_t>& shape);
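With the float-only overload gone (its old .cpp implementation is deleted further below), call sites now name the payload type explicitly. A minimal usage sketch, where the wrapper function is hypothetical but the makeConstant calls mirror call sites updated in this commit:

    #include "ngraph_functions/builders.hpp"

    void makeConstantExamples(const ngraph::element::Type& ngPrc) {
        // random == true with empty data: the constant is filled with generated values
        auto randomConst = ngraph::builder::makeConstant<float>(ngPrc, {1, 1, 1, 12}, {}, true);
        // explicit data; a single value is broadcast over the whole shape
        auto fixedConst = ngraph::builder::makeConstant<float>(ngPrc, {1, 67000}, {-1.0f});
        // integer payloads no longer round-trip through std::vector<float>
        auto scalarConst = ngraph::builder::makeConstant<int>(ngraph::element::i32, {}, {2});
    }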
@@ -90,6 +130,18 @@ std::shared_ptr<ngraph::Node> makeConvolutionBackpropData(const ngraph::Output<N
                                                           bool addBiases = false,
                                                           const std::vector<float> &biasesWeights = {});

+std::shared_ptr<ngraph::Node> makeCTCLoss(
+        const ngraph::Output<Node>& logitsNode,
+        std::vector<int>& logitsLength,
+        std::vector<std::vector<int>>& labels,
+        std::vector<int>& labelsLength,
+        int blankIndex,
+        const element::Type& fType,
+        const element::Type& iType,
+        const bool preprocessCollapseRepeated,
+        const bool ctcMergeRepeated,
+        const bool unique);
+
 std::shared_ptr<ngraph::Node> makeGroupConvolutionBackpropData(const ngraph::Output<Node> &in,
                                                                const element::Type &type,
                                                                const std::vector<size_t> &filterSize,
|
||||
template<typename fromType, typename toType>
|
||||
std::vector<toType> castVector(const std::vector<fromType> &vec) {
|
||||
std::vector<toType> resVec;
|
||||
for (auto el : vec) {
|
||||
resVec.reserve(vec.size());
|
||||
for (auto& el : vec) {
|
||||
resVec.push_back(static_cast<toType>(el));
|
||||
}
|
||||
return resVec;
|
||||
}
|
||||
|
||||
} // namespace Utils
|
||||
} // namespace NGraphFunctions
|
||||
} // namespace NGraphFunctions
|
||||
|
@@ -1,57 +0,0 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>
#include <memory>

#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp"


namespace ngraph {
namespace builder {

std::shared_ptr<Node> makeConstant(const element::Type &type, const std::vector<size_t> &shape,
                                   const std::vector<float> &data, bool random) {
    std::shared_ptr<ngraph::Node> weightsNode;

#define makeNode(TYPE) \
        case TYPE: \
            weightsNode = std::make_shared<ngraph::opset1::Constant>( \
                    type, shape, \
                    random ? NGraphFunctions::Utils::generateVector<TYPE>(ngraph::shape_size(shape)) : \
                             NGraphFunctions::Utils::castVector<float, ngraph::helpers::nGraphTypesTrait<TYPE>::value_type >(data)); \
            break;
    switch (type) {
        case ngraph::element::Type_t::bf16:
            weightsNode = std::make_shared<ngraph::opset1::Constant>(
                    type, shape,
                    random ? NGraphFunctions::Utils::generateBF16Vector(ngraph::shape_size(shape)) :
                             NGraphFunctions::Utils::castVector<float, ngraph::bfloat16>(data));
            break;
        case ngraph::element::Type_t::f16:
            weightsNode = std::make_shared<ngraph::opset1::Constant>(
                    type, shape,
                    random ? NGraphFunctions::Utils::generateF16Vector(ngraph::shape_size(shape)) :
                             NGraphFunctions::Utils::castVector<float, ngraph::float16>(data));
            break;
        makeNode(ngraph::element::Type_t::f32);
        makeNode(ngraph::element::Type_t::f64);
        makeNode(ngraph::element::Type_t::i8);
        makeNode(ngraph::element::Type_t::i16);
        makeNode(ngraph::element::Type_t::i32);
        makeNode(ngraph::element::Type_t::i64);
        makeNode(ngraph::element::Type_t::u8);
        makeNode(ngraph::element::Type_t::u16);
        makeNode(ngraph::element::Type_t::u32);
        makeNode(ngraph::element::Type_t::u64);
        makeNode(ngraph::element::Type_t::boolean);
#undef makeNode
        default:
            throw std::runtime_error("Unhandled precision");
    }
    return weightsNode;
}
}  // namespace builder
}  // namespace ngraph
inference-engine/tests/ngraph_functions/src/ctc_loss.cpp (new file, 44 lines)
@@ -0,0 +1,44 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>
#include <memory>

#include "ngraph_functions/builders.hpp"

namespace ngraph {
namespace builder {

std::shared_ptr<Node> makeCTCLoss(
        const ngraph::Output<Node>& logitsNode,
        std::vector<int>& logitsLength,
        std::vector<std::vector<int>>& labels,
        std::vector<int>& labelsLength,
        int blankIndex,
        const element::Type& fType,
        const element::Type& iType,
        const bool preprocessCollapseRepeated,
        const bool ctcMergeRepeated,
        const bool unique) {
    auto logitsShape = logitsNode.get_shape();
    size_t N = logitsShape[0];
    size_t T = logitsShape[1];

    std::vector<int> labelsOneD(N * T);
    for (int i = 0; i < labels.size(); i++)
        std::copy(labels[i].begin(), labels[i].end(), labelsOneD.data() + i * T);

    auto logitsLengthNode = makeConstant(iType, {N}, logitsLength);
    auto labelsNode = makeConstant(iType, {N, T}, labelsOneD);
    auto labelsLengthNode = makeConstant(iType, {N}, labelsLength);
    auto blankIndexNode = makeConstant<int>(iType, {}, {blankIndex});

    auto ctcLossNode = std::make_shared<opset4::CTCLoss>(logitsNode, logitsLengthNode, labelsNode,
        labelsLengthNode, blankIndexNode, preprocessCollapseRepeated, ctcMergeRepeated, unique);

    return ctcLossNode;
}

}  // namespace builder
}  // namespace ngraph
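A condensed usage sketch of this builder, with illustrative values taken from the Set1 test instance above (equivalent to what CTCLossLayerTest::SetUp assembles):

    auto logits = std::make_shared<ngraph::opset1::Parameter>(
            ngraph::element::f32, ngraph::Shape{2, 3, 3});          // [N, T, C]
    std::vector<int> logitsLength{3, 3};                            // valid time steps per batch item
    std::vector<std::vector<int>> labels{{0, 1, 0}, {1, 0, 1}};     // padded to T
    std::vector<int> labelsLength{2, 2};
    auto ctcLoss = ngraph::builder::makeCTCLoss(logits, logitsLength, labels, labelsLength,
            /*blankIndex=*/2, ngraph::element::f32, ngraph::element::i32,
            /*preprocessCollapseRepeated=*/false, /*ctcMergeRepeated=*/true, /*unique=*/false);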
@@ -29,7 +29,7 @@ std::shared_ptr<Node> makeEmbeddingBagOffsetsSum(
         std::vector<size_t> d_shape = {};
         auto defIdxNode = std::make_shared<ngraph::opset1::Constant>(indicesType, d_shape, default_index);
         if (with_weights) {
-            auto weightsNode = makeConstant(dataType, {indices.size()}, {}, true);
+            auto weightsNode = makeConstant<float>(dataType, {indices.size()}, {}, true);

             embBag = std::make_shared<opset3::EmbeddingBagOffsetsSum>(
                 embTableNode, indicesNode, offsetsNode, defIdxNode, weightsNode);
@@ -25,7 +25,7 @@ std::shared_ptr<Node> makeEmbeddingBagPackedSum(

     std::shared_ptr<Node> embBag;
     if (with_weights) {
-        auto weightsNode = makeConstant(dataType, i_shape, {}, true);
+        auto weightsNode = makeConstant<float>(dataType, i_shape, {}, true);

         embBag = std::make_shared<opset3::EmbeddingBagPackedSum>(
             embTableNode, indicesNode, weightsNode);
@@ -31,7 +31,7 @@ std::shared_ptr<Node> makeEmbeddingSegmentsSum(
     if (with_default_index) {
         auto defIdxNode = std::make_shared<ngraph::opset1::Constant>(indicesType, shape_0, default_index);
         if (with_weights) {
-            auto weightsNode = makeConstant(dataType, {indices.size()}, {}, true);
+            auto weightsNode = makeConstant<float>(dataType, {indices.size()}, {}, true);

             embBag = std::make_shared<opset3::EmbeddingSegmentsSum>(
                 embTableNode, indicesNode, segmentIdNode, segmentNumNode, defIdxNode, weightsNode);
@@ -43,11 +43,11 @@ namespace ngraph
                {
                    U actualLogitLen = logitsLength[b];
                    U actualTargetLen = labelsLength[b];
-                    if (actualLogitLen >= maxTime || actualTargetLen >= maxTime ||
+                    if (actualLogitLen > maxTime || actualTargetLen > maxTime ||
                        actualTargetLen > actualLogitLen)
                    {
                        throw ngraph_error(
-                            std::string("Logit or label length cannot be more than max sequence"
+                            std::string("Logit or label length cannot be greater than max sequence"
                                "length. Also a label length cannot be greater than a"
                                "logit length.\nMaxSeqLen: ") +
                            std::to_string(maxTime) + "; Logit len: " +
@@ -95,82 +95,110 @@ namespace ngraph

                const size_t BTC = b * TC;

-                std::vector<T> kExp(actualLogitLen, 0);
+                std::vector<std::unordered_map<size_t, T>> logProbabilities(actualLogitLen);
+                T logProb = 0.f, kExp = 0.f;
                for (size_t t = 0; t < actualLogitLen; t++)
                {
-                    size_t btcT = BTC + classesNum * t;
+                    kExp = 0.f;
+                    const size_t btcT = BTC + classesNum * t;
                    for (size_t c = 0; c < classesNum; c++)
                    {
-                        kExp[t] += std::exp(logits[btcT + c]);
+                        kExp += std::exp(logits[btcT + c]);
                    }
+                    for (size_t s = 0; s < decodedTargetLen; s++)
+                    {
+                        logProb = logits[btcT + targetD[s]] - std::log(kExp);
+                        logProbabilities[t].insert({targetD[s], logProb});
+                    }
+                    logProb = logits[btcT + blankIndex] - std::log(kExp);
+                    logProbabilities[t].insert({blankIndex, logProb});
                }

-                T res = -std::numeric_limits<T>::infinity();
+                const auto type_inf = std::numeric_limits<T>::infinity();
+                T res = -type_inf;

                // Looking for aligned paths
-                std::function<void(size_t targetIdx, size_t start, size_t end)> findPaths = [&](
-                    size_t targetIdx, size_t start, size_t end) {
+                std::function<void(size_t, size_t, size_t, T)> findPaths = [&](
+                    size_t targetIdx, size_t start, size_t end, T prevLogProb) {
                    if (end > actualLogitLen)
                    {
-                        T prod = 0;
-                        for (size_t t = 0; t < actualLogitLen; t++)
+                        if (res == -type_inf)
                        {
-                            prod += std::log(std::exp(logits[BTC + classesNum * t + pathS[t]]) /
-                                             kExp[t]);
+                            res = prevLogProb;
                        }
+                        else if (prevLogProb != -type_inf)
+                        {
+                            if (res > prevLogProb)
+                                res = res + std::log1pf(std::exp(prevLogProb - res));
+                            else
+                                res = prevLogProb + std::log1pf(std::exp(res - prevLogProb));
+                        }
-                        if (res == -std::numeric_limits<T>::infinity())
-                            res = prod;
-                        else if (prod != -std::numeric_limits<T>::infinity())
-                            res = res + std::log1pf(std::exp(prod - res));
-
                        return;
                    }

                    size_t nextIdx = targetIdx + 1;
                    int64_t st64 = start;
+                    T newLogProb = prevLogProb;
                    if (!ctcMergeRepeated)
                    {
                        for (size_t pos = start; pos < end; pos++)
                        {
+                            newLogProb = prevLogProb;
                            for (size_t bl = start; bl < pos; bl++)
                            {
-                                pathS[bl] = blankIndex;
+                                newLogProb += logProbabilities[bl].find(blankIndex)->second;
                            }
-                            pathS[pos] = targetD[targetIdx];
-                            findPaths(nextIdx, pos + 1, end + 1);
+                            newLogProb +=
+                                logProbabilities[pos].find(targetD[targetIdx])->second;
+                            if (end == actualLogitLen)
+                            {
+                                for (int64_t ble = pos + 1; ble < actualLogitLen; ble++)
+                                {
+                                    newLogProb +=
+                                        logProbabilities[ble].find(blankIndex)->second;
+                                }
+                            }
+                            findPaths(nextIdx, pos + 1, end + 1, newLogProb);
                        }
                    }
                    else
                    {
                        for (size_t pos = start; pos < end; pos++)
                        {
+                            newLogProb = prevLogProb;
+                            size_t next_start = pos + 1;
                            for (size_t bl = start; bl < pos; bl++)
                            {
-                                pathS[bl] = blankIndex;
+                                newLogProb += logProbabilities[bl].find(blankIndex)->second;
                            }
+                            if (end == actualLogitLen)
+                            {
+                                for (int64_t ble = pos + 1; ble < actualLogitLen; ble++)
+                                {
+                                    newLogProb +=
+                                        logProbabilities[ble].find(blankIndex)->second;
+                                }
+                            }
+                            if (targetIdx < decodedTargetLen - 1 &&
+                                targetD[targetIdx] == targetD[targetIdx + 1])
+                            {
+                                newLogProb +=
+                                    logProbabilities[next_start++].find(blankIndex)->second;
+                            }
                            for (int64_t bl = pos; bl >= st64; bl--)
                            {
-                                pathS[bl] = targetD[targetIdx];
-                                if (end == actualLogitLen)
-                                {
-                                    for (int64_t ble = pos + 1; ble < actualLogitLen; ble++)
-                                    {
-                                        pathS[ble] = blankIndex;
-                                    }
-                                }
-                                size_t next_start = pos + 1;
-                                if (targetIdx < decodedTargetLen - 1 &&
-                                    targetD[targetIdx] == targetD[targetIdx + 1])
-                                {
-                                    pathS[next_start++] = blankIndex;
-                                }
-                                findPaths(nextIdx, next_start, end + 1);
+                                newLogProb +=
+                                    logProbabilities[bl].find(targetD[targetIdx])->second;
+                                findPaths(nextIdx, next_start, end + 1, newLogProb);
+                                if (bl > 0)
+                                    newLogProb -=
+                                        logProbabilities[bl - 1].find(blankIndex)->second;
                            }
                        }
                    }
                }; // findPaths

-                findPaths(0lu, 0lu, actualLogitLen - decodedTargetLen + 1lu);
+                findPaths(0lu, 0lu, actualLogitLen - decodedTargetLen + 1lu, 0.f);

                output[b] = -res;
