[CPU] CTCLoss performance improvement.
This commit is contained in:
parent
8715b60d88
commit
ff7fc01c76
@ -60,6 +60,8 @@ public:
|
|||||||
StatusCode execute(std::vector<Blob::Ptr>& inputs,
|
StatusCode execute(std::vector<Blob::Ptr>& inputs,
|
||||||
std::vector<Blob::Ptr>& outputs,
|
std::vector<Blob::Ptr>& outputs,
|
||||||
ResponseDesc *resp) noexcept override {
|
ResponseDesc *resp) noexcept override {
|
||||||
|
StatusCode returnCode = OK;
|
||||||
|
|
||||||
const float* logits = inputs[0]->cbuffer().as<const float*>() +
|
const float* logits = inputs[0]->cbuffer().as<const float*>() +
|
||||||
inputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
inputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||||
const int* logitsLength = inputs[1]->cbuffer().as<const int*>() +
|
const int* logitsLength = inputs[1]->cbuffer().as<const int*>() +
|
||||||
@ -72,37 +74,47 @@ public:
|
|||||||
outputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
outputs[0]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||||
|
|
||||||
const auto& logitsShape = inputs[0]->getTensorDesc().getDims();
|
const auto& logitsShape = inputs[0]->getTensorDesc().getDims();
|
||||||
const auto batchNum = logitsShape[0];
|
const size_t batchNum = logitsShape[0];
|
||||||
const auto maxTime = logitsShape[1];
|
const size_t maxTime = logitsShape[1];
|
||||||
const auto classesNum = logitsShape[2];
|
const size_t classesNum = logitsShape[2];
|
||||||
|
|
||||||
int blankIndex = classesNum - 1;
|
int blankIndex = classesNum - 1;
|
||||||
if (inputs.size() > 4) {
|
if (inputs.size() > 4) {
|
||||||
blankIndex = inputs[4]->cbuffer().as<const int*>()[0];
|
blankIndex = inputs[4]->cbuffer().as<const int*>()[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int> targetD(maxTime);
|
std::vector<int> decodedTargetLenB(batchNum, 0);
|
||||||
|
std::vector<std::vector<int>> targetDB(batchNum);
|
||||||
|
std::vector<std::vector<std::vector<float>>> logProbabilitiesB(batchNum);
|
||||||
|
size_t workAmount2 = 0lu;
|
||||||
|
std::vector<std::string> errorMsgB(parallel_get_max_threads());
|
||||||
|
|
||||||
const size_t TC = maxTime * classesNum;
|
auto threadBody_1 = [&](const int ithr, const int nthr) {
|
||||||
|
size_t start(0lu), end(0lu);
|
||||||
|
splitter(batchNum, nthr, ithr, start, end);
|
||||||
|
if (start >= end)
|
||||||
|
return;
|
||||||
|
|
||||||
for (size_t b = 0; b < batchNum; b++) {
|
for (size_t b = start; b < end; b++) {
|
||||||
const int actualLogitLen = logitsLength[b];
|
if (logitsLength[b] < 0 || labelsLength[b] < 0 || logitsLength[b] > maxTime || labelsLength[b] > logitsLength[b]) {
|
||||||
const int actualTargetLen = labelsLength[b];
|
errorMsgB[ithr] = _logPrefix + ". Logit length cannot be greater than max sequence length. "
|
||||||
if (actualLogitLen < 0 || actualTargetLen < 0 || actualLogitLen > maxTime || actualTargetLen > maxTime
|
+ "Label length cannot be greater than a logit length"
|
||||||
|| actualTargetLen > actualLogitLen) {
|
|
||||||
std::string errorMsg = _logPrefix + ". Logit or label length cannot be greater than max sequence length. "
|
|
||||||
+ "Also a label length cannot be greater than a logit length"
|
|
||||||
+ " and both cannot be negative.\nMaxSeqLen: "
|
+ " and both cannot be negative.\nMaxSeqLen: "
|
||||||
+ std::to_string(maxTime) + "; Logit len: " + std::to_string(actualLogitLen)
|
+ std::to_string(maxTime) + "; Logit len: " + std::to_string(logitsLength[b])
|
||||||
+ "; Label len: " + std::to_string(actualTargetLen);
|
+ "; Label len: " + std::to_string(labelsLength[b]);
|
||||||
errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
|
returnCode = GENERAL_ERROR;
|
||||||
return GENERAL_ERROR;
|
return;
|
||||||
}
|
}
|
||||||
|
const size_t actualLogitLen = logitsLength[b];
|
||||||
const int* target = &labels[b * maxTime];
|
const size_t actualTargetLen = labelsLength[b];
|
||||||
// Decoding target: merge repeated characters if preprocess_collapse_repeated == True,
|
|
||||||
// find unique elements if unique == True
|
|
||||||
size_t decodedTargetLen = 0lu;
|
size_t decodedTargetLen = 0lu;
|
||||||
|
|
||||||
|
// Decoding target: merge repeated characters if preprocess_collapse_repeated == True,
|
||||||
|
// find unique elements if unique == True.
|
||||||
|
// Inserts blanks before each index and a blank at the end.
|
||||||
|
const int* target = &labels[b * maxTime];
|
||||||
|
targetDB[b].resize(actualTargetLen * 2 + 1);
|
||||||
|
auto& targetD = targetDB[b];
|
||||||
if (_unique) {
|
if (_unique) {
|
||||||
std::unordered_set<int> uniqVals;
|
std::unordered_set<int> uniqVals;
|
||||||
for (size_t t = 0lu; t < actualTargetLen; t++) {
|
for (size_t t = 0lu; t < actualTargetLen; t++) {
|
||||||
@ -110,219 +122,162 @@ public:
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
uniqVals.insert(target[t]);
|
uniqVals.insert(target[t]);
|
||||||
|
targetD[decodedTargetLen++] = blankIndex;
|
||||||
targetD[decodedTargetLen++] = target[t];
|
targetD[decodedTargetLen++] = target[t];
|
||||||
}
|
}
|
||||||
|
targetD[decodedTargetLen++] = blankIndex;
|
||||||
} else if (_preprocessCollapseRepeated) {
|
} else if (_preprocessCollapseRepeated) {
|
||||||
int prevValue = target[0];
|
auto prevValue = target[0];
|
||||||
|
targetD[decodedTargetLen++] = blankIndex;
|
||||||
targetD[decodedTargetLen++] = target[0];
|
targetD[decodedTargetLen++] = target[0];
|
||||||
for (size_t t = 1lu; t < actualTargetLen; t++) {
|
for (size_t t = 1lu; t < actualTargetLen; t++) {
|
||||||
if (target[t] == prevValue) {
|
if (target[t] == prevValue) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
targetD[decodedTargetLen++] = target[t];
|
targetD[decodedTargetLen++] = blankIndex;
|
||||||
prevValue = target[t];
|
targetD[decodedTargetLen++] = prevValue = target[t];
|
||||||
}
|
}
|
||||||
|
targetD[decodedTargetLen++] = blankIndex;
|
||||||
} else {
|
} else {
|
||||||
std::copy(target, target + actualTargetLen, targetD.data());
|
for (size_t t = 0lu; t < actualTargetLen; t++) {
|
||||||
decodedTargetLen = actualTargetLen;
|
targetD[decodedTargetLen++] = blankIndex;
|
||||||
|
targetD[decodedTargetLen++] = target[t];
|
||||||
|
}
|
||||||
|
targetD[decodedTargetLen++] = blankIndex;
|
||||||
|
}
|
||||||
|
decodedTargetLenB[b] = decodedTargetLen;
|
||||||
|
|
||||||
|
auto& logProbabilities = logProbabilitiesB[b];
|
||||||
|
logProbabilities.resize(actualLogitLen);
|
||||||
|
for (size_t ll = 0; ll < actualLogitLen; ll++) {
|
||||||
|
logProbabilities[ll].resize(decodedTargetLen);
|
||||||
|
}
|
||||||
|
workAmount2 += actualLogitLen;
|
||||||
|
} // for batch
|
||||||
|
}; // threadBody_1
|
||||||
|
|
||||||
|
parallel_nt(0, threadBody_1);
|
||||||
|
if (returnCode != OK) {
|
||||||
|
std::string resErr("");
|
||||||
|
for (auto& err : errorMsgB) {
|
||||||
|
if (!err.empty())
|
||||||
|
resErr += err + "\n";
|
||||||
|
resErr.copy(resp->msg, sizeof(resp->msg) - 1);
|
||||||
|
}
|
||||||
|
return returnCode;
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t BTC = b * TC;
|
const size_t TC = maxTime * classesNum;
|
||||||
|
|
||||||
std::vector<std::unordered_map<size_t, float>> logProbabilities(actualLogitLen);
|
auto threadBody_2 = [&](const int ithr, const int nthr) {
|
||||||
float logProb = 0.f, kExp = 0.f;
|
size_t start(0lu), end(0lu);
|
||||||
for (size_t t = 0; t < actualLogitLen; t++) {
|
size_t sB(0lu), sT(0lu);
|
||||||
kExp = 0.f;
|
splitter(workAmount2, nthr, ithr, start, end);
|
||||||
const size_t btcT = BTC + classesNum * t;
|
if (start >= end)
|
||||||
for (size_t c = 0; c < classesNum; c++) {
|
return;
|
||||||
kExp += std::exp(logits[btcT + c]);
|
int64_t cw = 0, st = start;
|
||||||
|
for (; sB < batchNum; sB++) {
|
||||||
|
cw += logitsLength[sB];
|
||||||
|
if (cw >= st) {
|
||||||
|
sT = logitsLength[sB] + st - cw;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
for (size_t s = 0; s < decodedTargetLen; s++) {
|
|
||||||
logProb = logits[btcT + targetD[s]] - std::log(kExp);
|
|
||||||
logProbabilities[t].insert({targetD[s], logProb});
|
|
||||||
}
|
}
|
||||||
logProb = logits[btcT + blankIndex] - std::log(kExp);
|
size_t workCounter = start;
|
||||||
logProbabilities[t].insert({blankIndex, logProb});
|
|
||||||
|
for (size_t b = sB; b < batchNum; b++) {
|
||||||
|
const size_t actualLogitLen = logitsLength[b];
|
||||||
|
const size_t decodedTargetLen = decodedTargetLenB[b];
|
||||||
|
auto& logProbabilities = logProbabilitiesB[b];
|
||||||
|
auto& targetD = targetDB[b];
|
||||||
|
|
||||||
|
double expSum = 0.0;
|
||||||
|
size_t btcT = b * TC + sT * classesNum;
|
||||||
|
// logProbabilities = logSoftmax = logits[b][t][c] - ln(sum_c(exp(logits[b][t])))
|
||||||
|
for (size_t t = sT; t < actualLogitLen; t++) {
|
||||||
|
expSum = 0.0;
|
||||||
|
for (size_t c = 0lu; c < classesNum; c++) {
|
||||||
|
expSum += std::exp(logits[btcT + c]);
|
||||||
}
|
}
|
||||||
|
for (size_t s = 0lu; s < decodedTargetLen; s++) {
|
||||||
|
logProbabilities[t][s] = logits[btcT + targetD[s]] - std::log(expSum);
|
||||||
|
}
|
||||||
|
btcT += classesNum;
|
||||||
|
if (++workCounter >= end) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sT = 0lu;
|
||||||
|
} // for batch
|
||||||
|
}; // threadBody_2
|
||||||
|
|
||||||
|
parallel_nt(0, threadBody_2);
|
||||||
|
|
||||||
const auto float_inf = std::numeric_limits<float>::infinity();
|
const auto float_inf = std::numeric_limits<float>::infinity();
|
||||||
size_t work_amount = actualLogitLen - decodedTargetLen + 1lu;
|
|
||||||
std::vector<float> sumPerThread(parallel_get_max_threads(), -float_inf);
|
|
||||||
|
|
||||||
// Looking for aligned paths
|
auto sumLogs = [&float_inf](float log1, float log2) {
|
||||||
auto thread_body = [&](const int ithr, const int nthr) {
|
if (log1 == -float_inf) {
|
||||||
size_t start0(0lu), end0(0lu);
|
return log2;
|
||||||
splitter(work_amount, nthr, ithr, start0, end0);
|
} else if (log2 == -float_inf) {
|
||||||
if (start0 >= end0)
|
return log1;
|
||||||
return;
|
} else {
|
||||||
if (ithr >= sumPerThread.size())
|
if (log1 > log2)
|
||||||
sumPerThread.push_back(-float_inf);
|
return log1 + std::log1pf(std::exp(log2 - log1));
|
||||||
|
|
||||||
std::function<void(size_t, size_t, size_t, float)> findPaths =
|
|
||||||
[&](size_t targetIdx, size_t start, size_t end, float prevLogProb) {
|
|
||||||
if (end > actualLogitLen) {
|
|
||||||
if (sumPerThread[ithr] == -float_inf) {
|
|
||||||
sumPerThread[ithr] = prevLogProb;
|
|
||||||
} else if (prevLogProb != -float_inf) {
|
|
||||||
if (sumPerThread[ithr] > prevLogProb)
|
|
||||||
sumPerThread[ithr] = sumPerThread[ithr] + std::log1pf(std::exp(prevLogProb - sumPerThread[ithr]));
|
|
||||||
else
|
else
|
||||||
sumPerThread[ithr] = prevLogProb + std::log1pf(std::exp(sumPerThread[ithr] - prevLogProb));
|
return log2 + std::log1pf(std::exp(log1 - log2));
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto threadBody_3 = [&](const int ithr, const int nthr) {
|
||||||
|
size_t start(0lu), end(0lu);
|
||||||
|
splitter(batchNum, nthr, ithr, start, end);
|
||||||
|
if (start >= end)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
|
// As per Connectionist Temporal Classification - Labeling Unsegmented Sequence Data with Recurrent Neural Networks:
|
||||||
|
// Graves et al., 2006, paragraph 4.1 (10)
|
||||||
|
for (size_t b = start; b < end; b++) {
|
||||||
|
auto& targetD = targetDB[b];
|
||||||
|
auto& logProbabilities = logProbabilitiesB[b];
|
||||||
|
const int actualLogitLen = logitsLength[b];
|
||||||
|
const int decodedTargetLen = decodedTargetLenB[b];
|
||||||
|
std::vector<std::vector<float>> logBwd(decodedTargetLen, std::vector<float>(actualLogitLen, -float_inf));
|
||||||
|
for (int s = decodedTargetLen - 2; s < decodedTargetLen; s++)
|
||||||
|
logBwd[s][actualLogitLen - 1] = 0.f;
|
||||||
|
|
||||||
|
for (int t = actualLogitLen - 2; t >= 0; t--) {
|
||||||
|
const int t_1 = t + 1;
|
||||||
|
for (int s = std::max(0, decodedTargetLen - (2 * (actualLogitLen - t)));
|
||||||
|
s < std::min(decodedTargetLen, 2 * (t_1)); s++) {
|
||||||
|
if (_ctcMergeRepeated || targetD[s] == blankIndex) {
|
||||||
|
logBwd[s][t] = sumLogs(logBwd[s][t],
|
||||||
|
logBwd[s][t_1] + logProbabilities[t_1][s]);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t nextIdx = targetIdx + 1;
|
if (s + 1 < decodedTargetLen) {
|
||||||
int64_t st64 = start;
|
logBwd[s][t] = sumLogs(logBwd[s][t],
|
||||||
float newLogProb = prevLogProb;
|
logBwd[s + 1][t_1] + logProbabilities[t_1][s + 1]);
|
||||||
if (!_ctcMergeRepeated) {
|
|
||||||
for (size_t pos = start; pos < end; pos++) {
|
|
||||||
newLogProb = prevLogProb;
|
|
||||||
for (size_t bl = start; bl < pos; bl++) {
|
|
||||||
auto lnProbIt = logProbabilities[bl].find(blankIndex);
|
|
||||||
if (lnProbIt != logProbabilities[bl].end())
|
|
||||||
newLogProb += lnProbIt->second;
|
|
||||||
}
|
}
|
||||||
auto lnProbIt = logProbabilities[pos].find(targetD[targetIdx]);
|
|
||||||
if (lnProbIt != logProbabilities[pos].end())
|
|
||||||
newLogProb += lnProbIt->second;
|
|
||||||
if (end == actualLogitLen) {
|
|
||||||
for (int64_t ble = pos + 1; ble < actualLogitLen; ble++) {
|
|
||||||
auto lnProbIt = logProbabilities[ble].find(blankIndex);
|
|
||||||
if (lnProbIt != logProbabilities[ble].end())
|
|
||||||
newLogProb += lnProbIt->second;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
findPaths(nextIdx, pos + 1, end + 1, newLogProb);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (size_t pos = start; pos < end; pos++) {
|
|
||||||
newLogProb = prevLogProb;
|
|
||||||
size_t next_start = pos + 1;
|
|
||||||
for (size_t bl = start; bl < pos; bl++) {
|
|
||||||
auto lnProbIt = logProbabilities[bl].find(blankIndex);
|
|
||||||
if (lnProbIt != logProbabilities[bl].end())
|
|
||||||
newLogProb += lnProbIt->second;
|
|
||||||
}
|
|
||||||
if (end == actualLogitLen) {
|
|
||||||
for (int64_t ble = pos + 1; ble < actualLogitLen; ble++) {
|
|
||||||
auto lnProbIt = logProbabilities[ble].find(blankIndex);
|
|
||||||
if (lnProbIt != logProbabilities[ble].end())
|
|
||||||
newLogProb += lnProbIt->second;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (targetIdx < decodedTargetLen - 1
|
|
||||||
&& targetD[targetIdx] == targetD[targetIdx + 1]) {
|
|
||||||
auto lnProbIt = logProbabilities[next_start++].find(blankIndex);
|
|
||||||
if (lnProbIt != logProbabilities[next_start].end())
|
|
||||||
newLogProb += lnProbIt->second;
|
|
||||||
}
|
|
||||||
for (int64_t bl = pos; bl >= st64; bl--) {
|
|
||||||
newLogProb += logProbabilities[bl].find(targetD[targetIdx])->second;
|
|
||||||
findPaths(nextIdx, next_start, end + 1, newLogProb);
|
|
||||||
if (bl > 0) {
|
|
||||||
auto lnProbIt = logProbabilities[bl - 1].find(blankIndex);
|
|
||||||
if (lnProbIt != logProbabilities[bl - 1].end())
|
|
||||||
newLogProb -= lnProbIt->second;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}; // findPaths
|
|
||||||
|
|
||||||
// First target symbol
|
if (s + 2 < decodedTargetLen) {
|
||||||
int64_t st64 = start0;
|
if (targetD[s] != blankIndex && (!_ctcMergeRepeated || (targetD[s] != targetD[s + 2]))) {
|
||||||
float newLogProb = 0.f;
|
logBwd[s][t] = sumLogs(logBwd[s][t],
|
||||||
if (!_ctcMergeRepeated) {
|
logBwd[s + 2][t_1] + logProbabilities[t_1][s + 2]);
|
||||||
for (size_t pos = start0; pos < end0; pos++) {
|
|
||||||
newLogProb = 0.f;
|
|
||||||
for (size_t bl = 0; bl < pos; bl++) {
|
|
||||||
auto lnProbIt = logProbabilities[bl].find(blankIndex);
|
|
||||||
if (lnProbIt != logProbabilities[bl].end())
|
|
||||||
newLogProb += lnProbIt->second;
|
|
||||||
}
|
|
||||||
auto lnProbIt = logProbabilities[pos].find(targetD[0]);
|
|
||||||
if (lnProbIt != logProbabilities[pos].end())
|
|
||||||
newLogProb += lnProbIt->second;
|
|
||||||
if (work_amount == actualLogitLen) {
|
|
||||||
for (int64_t ble = pos + 1; ble < actualLogitLen; ble++) {
|
|
||||||
auto lnProbIt = logProbabilities[ble].find(blankIndex);
|
|
||||||
if (lnProbIt != logProbabilities[ble].end())
|
|
||||||
newLogProb += lnProbIt->second;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (decodedTargetLen > 1) {
|
|
||||||
findPaths(1, pos + 1, work_amount + 1, newLogProb);
|
|
||||||
} else {
|
|
||||||
if (sumPerThread[ithr] == -float_inf)
|
|
||||||
sumPerThread[ithr] = newLogProb;
|
|
||||||
else if (newLogProb != -float_inf)
|
|
||||||
sumPerThread[ithr] = sumPerThread[ithr] + std::log1pf(std::exp(newLogProb - sumPerThread[ithr]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (size_t pos = start0; pos < end0; pos++) {
|
|
||||||
newLogProb = 0.f;
|
|
||||||
size_t next_start = pos + 1;
|
|
||||||
for (size_t bl = 0; bl < pos; bl++) {
|
|
||||||
auto lnProbIt = logProbabilities[bl].find(blankIndex);
|
|
||||||
if (lnProbIt != logProbabilities[bl].end())
|
|
||||||
newLogProb += lnProbIt->second;
|
|
||||||
}
|
|
||||||
if (work_amount == actualLogitLen) {
|
|
||||||
for (int64_t ble = pos + 1; ble < actualLogitLen; ble++) {
|
|
||||||
auto lnProbIt = logProbabilities[ble].find(blankIndex);
|
|
||||||
if (lnProbIt != logProbabilities[ble].end())
|
|
||||||
newLogProb += lnProbIt->second;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (decodedTargetLen > 1
|
|
||||||
&& targetD[0] == targetD[1]) {
|
|
||||||
auto lnProbIt = logProbabilities[next_start++].find(blankIndex);
|
|
||||||
if (lnProbIt != logProbabilities[next_start].end())
|
|
||||||
newLogProb += lnProbIt->second;
|
|
||||||
}
|
|
||||||
for (int64_t bl = pos; bl >= 0; bl--) {
|
|
||||||
auto lnProbIt = logProbabilities[bl].find(targetD[0]);
|
|
||||||
if (lnProbIt != logProbabilities[bl].end())
|
|
||||||
newLogProb += lnProbIt->second;
|
|
||||||
if (decodedTargetLen > 1) {
|
|
||||||
findPaths(1, next_start, work_amount + 1, newLogProb);
|
|
||||||
} else {
|
|
||||||
if (sumPerThread[ithr] == -float_inf)
|
|
||||||
sumPerThread[ithr] = newLogProb;
|
|
||||||
else if (newLogProb != -float_inf)
|
|
||||||
sumPerThread[ithr] = sumPerThread[ithr] + std::log1pf(std::exp(newLogProb - sumPerThread[ithr]));
|
|
||||||
}
|
|
||||||
if (bl > 0) {
|
|
||||||
auto lnProbIt = logProbabilities[bl - 1].find(blankIndex);
|
|
||||||
if (lnProbIt != logProbabilities[bl - 1].end())
|
|
||||||
newLogProb -= lnProbIt->second;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}; // thread_body
|
|
||||||
|
|
||||||
parallel_nt(0, thread_body);
|
|
||||||
|
|
||||||
float res = -float_inf;
|
|
||||||
|
|
||||||
for (auto sum : sumPerThread) {
|
|
||||||
if (res == -float_inf) {
|
|
||||||
res = sum;
|
|
||||||
} else if (sum != -float_inf) {
|
|
||||||
if (res > sum)
|
|
||||||
res = res + std::log1pf(std::exp(sum - res));
|
|
||||||
else
|
|
||||||
res = sum + std::log1pf(std::exp(res - sum));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dstData[b] = -res;
|
logBwd[0][0] += logProbabilities[0][0];
|
||||||
} // for (size_t b = 0; b < batchNum; b++)
|
logBwd[1][0] += logProbabilities[0][(decodedTargetLen > 1) ? 1 : 0];
|
||||||
|
|
||||||
return OK;
|
dstData[b] = -sumLogs(logBwd[0][0], logBwd[1][0]);
|
||||||
|
} // for batch
|
||||||
|
}; // threadBody_3
|
||||||
|
|
||||||
|
parallel_nt(0, threadBody_3);
|
||||||
|
|
||||||
|
return returnCode;
|
||||||
} // execute
|
} // execute
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
@ -334,8 +289,6 @@ protected:
|
|||||||
};
|
};
|
||||||
|
|
||||||
REG_FACTORY_FOR(CTCLossImpl, CTCLoss);
|
REG_FACTORY_FOR(CTCLossImpl, CTCLoss);
|
||||||
|
|
||||||
} // namespace Cpu
|
} // namespace Cpu
|
||||||
} // namespace Extensions
|
} // namespace Extensions
|
||||||
} // namespace InferenceEngine
|
} // namespace InferenceEngine
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user