2020-02-11 22:48:49 +03:00
|
|
|
// Copyright (C) 2018-2020 Intel Corporation
|
2019-01-21 21:31:31 +03:00
|
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
//
|
|
|
|
|
|
|
|
|
|
#include "speech_sample.hpp"
|
|
|
|
|
|
|
|
|
|
#include <gflags/gflags.h>
|
|
|
|
|
#include <functional>
|
|
|
|
|
#include <iostream>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <fstream>
|
|
|
|
|
#include <random>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <time.h>
|
|
|
|
|
#include <thread>
|
|
|
|
|
#include <chrono>
|
|
|
|
|
#include <limits>
|
|
|
|
|
#include <iomanip>
|
|
|
|
|
#include <inference_engine.hpp>
|
|
|
|
|
#include <gna/gna_config.hpp>
|
|
|
|
|
|
|
|
|
|
#include <samples/common.hpp>
|
|
|
|
|
#include <samples/slog.hpp>
|
|
|
|
|
#include <samples/args_helper.hpp>
|
|
|
|
|
|
|
|
|
|
// Maximum tolerated absolute score difference. It is not referenced in the visible part of this file;
// presumably main() uses it as score_error_t::threshold when comparing against reference scores.
constexpr float MAX_SCORE_DIFFERENCE = 0.0001f;
// Target peak magnitude (2^14) passed to ScaleFactorForQuantization when deriving the
// static-quantization scale factor for 16-bit input features.
constexpr int MAX_VAL_2B_FEAT = 16384;
|
|
|
|
|
|
|
|
|
|
using namespace InferenceEngine;
|
|
|
|
|
|
|
|
|
|
// Clock and duration aliases used for timing (e.g. model loading time reported in milliseconds).
using Time = std::chrono::high_resolution_clock;
using ms = std::chrono::duration<double, std::milli>;
using fsec = std::chrono::duration<float>;
|
|
|
|
|
/**
 * @brief Accumulated statistics from comparing produced scores against reference scores.
 *
 * Filled per call by CompareScores and summed into a running total by UpdateScoreError.
 */
struct score_error_t {
    uint32_t numScores;        // number of individual scores compared
    uint32_t numErrors;        // number of scores whose absolute error exceeded `threshold`
    float threshold;           // absolute-error limit used to count errors; set by the caller, not by ClearScoreError
    float maxError;            // largest absolute error seen
    float rmsError;            // root-mean-square error of the most recent CompareScores call
    float sumError;            // sum of absolute errors
    float sumRmsError;         // sum of per-call rmsError values
    float sumSquaredError;     // sum of squared absolute errors
    float maxRelError;         // largest relative error seen
    float sumRelError;         // sum of relative errors
    float sumSquaredRelError;  // sum of squared relative errors
};
|
|
|
|
|
|
2019-04-12 18:25:53 +03:00
|
|
|
struct InferRequestStruct {
|
|
|
|
|
InferRequest inferRequest;
|
|
|
|
|
int frameIndex;
|
|
|
|
|
uint32_t numFramesThisBatch;
|
|
|
|
|
};
|
|
|
|
|
|
2020-04-13 21:17:23 +03:00
|
|
|
/**
 * @brief Ensures the network exposes exactly one input per provided ark file.
 * @param numInputs        number of network inputs
 * @param numInputArkFiles number of input ark files given on the command line
 * @throws std::logic_error when the two counts differ
 */
void CheckNumberOfInputs(size_t numInputs, size_t numInputArkFiles) {
    if (numInputs == numInputArkFiles) {
        return;
    }
    std::string message = "Number of network inputs (";
    message += std::to_string(numInputs);
    message += ") is not equal to number of ark files (";
    message += std::to_string(numInputArkFiles);
    message += ")";
    throw std::logic_error(message);
}
|
|
|
|
|
|
2019-01-21 21:31:31 +03:00
|
|
|
/**
 * @brief Scans a Kaldi binary ark file, counting its arrays and sizing one of them.
 *
 * Each array is expected as: name, NUL, "BFM ", control-D, uint32 rows, control-D, uint32 cols,
 * followed by rows*cols 32-bit floats. Scanning stops at the first record without a "BFM " marker.
 *
 * @param fileName           path to the ark file
 * @param numArrayToFindSize zero-based index of the array whose payload size is reported
 * @param ptrNumArrays       out: number of arrays found (ignored if NULL)
 * @param ptrNumMemoryBytes  out: payload bytes of array numArrayToFindSize, 0 if there is no such array (ignored if NULL)
 *
 * Prints a message to stderr and terminates with exit(-1) if the file cannot be opened.
 */
void GetKaldiArkInfo(const char *fileName,
                     uint32_t numArrayToFindSize,
                     uint32_t *ptrNumArrays,
                     uint32_t *ptrNumMemoryBytes) {
    std::ifstream arkStream(fileName, std::ios::binary);
    if (!arkStream.good()) {
        fprintf(stderr, "Failed to open %s for reading in GetKaldiArkInfo()!\n", fileName);
        exit(-1);
    }

    uint32_t arrayCount = 0;
    uint32_t requestedBytes = 0;
    while (!arkStream.eof()) {
        std::string token;
        uint32_t rows = 0u;
        uint32_t cols = 0u;
        std::getline(arkStream, token, '\0');  // array name, terminated by NUL
        std::getline(arkStream, token, '\4');  // "BFM " marker, terminated by control-D
        if (token != "BFM ") {
            break;  // no further array header
        }
        arkStream.read(reinterpret_cast<char *>(&rows), sizeof(uint32_t));
        std::getline(arkStream, token, '\4');  // control-D separating rows from columns
        arkStream.read(reinterpret_cast<char *>(&cols), sizeof(uint32_t));

        const uint32_t dataBytes = rows * cols * sizeof(float);
        arkStream.seekg(dataBytes, arkStream.cur);  // skip the matrix payload

        if (arrayCount == numArrayToFindSize) {
            requestedBytes += dataBytes;
        }
        ++arrayCount;
    }

    if (ptrNumArrays != nullptr) {
        *ptrNumArrays = arrayCount;
    }
    if (ptrNumMemoryBytes != nullptr) {
        *ptrNumMemoryBytes = requestedBytes;
    }
}
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Reads one array (a matrix of 32-bit floats) from a Kaldi binary ark file.
 *
 * @param fileName              path to the ark file
 * @param arrayIndex            zero-based index of the array to read; earlier arrays are skipped
 * @param ptrName               receives everything before the array's NUL terminator (its stored name)
 * @param memory                receives the raw float payload; grown if smaller than the payload, never shrunk
 * @param ptrNumRows            receives the number of rows
 * @param ptrNumColumns         receives the number of columns
 * @param ptrNumBytesPerElement receives sizeof(float)
 *
 * Prints a message to stderr and terminates with exit(-1) if the file cannot be opened or the
 * requested array's "BFM " header is missing. Callers should stay below the array count reported by
 * GetKaldiArkInfo; behaviour past the end depends on where reading stops.
 */
void LoadKaldiArkArray(const char *fileName, uint32_t arrayIndex, std::string &ptrName, std::vector<uint8_t> &memory,
                       uint32_t *ptrNumRows, uint32_t *ptrNumColumns, uint32_t *ptrNumBytesPerElement) {
    std::ifstream in_file(fileName, std::ios::binary);
    if (!in_file.good()) {
        fprintf(stderr, "Failed to open %s for reading in LoadKaldiArkArray()!\n", fileName);
        exit(-1);
    }

    // Skip the arrays preceding the requested one.
    for (uint32_t i = 0; i < arrayIndex; ++i) {
        std::string line;
        uint32_t numRows = 0u, numCols = 0u;
        std::getline(in_file, line, '\0');  // read variable length name followed by NUL
        std::getline(in_file, line, '\4');  // read "BFM" followed by space and control-D
        if (line.compare("BFM ") != 0) {
            break;
        }
        in_file.read(reinterpret_cast<char *>(&numRows), sizeof(uint32_t));  // read number of rows
        std::getline(in_file, line, '\4');                                     // read control-D
        in_file.read(reinterpret_cast<char *>(&numCols), sizeof(uint32_t));  // read number of columns
        // Widen before multiplying so large matrices do not overflow 32-bit arithmetic.
        const size_t skipBytes = static_cast<size_t>(numRows) * numCols * sizeof(float);
        in_file.seekg(static_cast<std::streamoff>(skipBytes), std::ios::cur);  // skip data
    }

    if (!in_file.eof()) {
        std::string line;
        std::getline(in_file, ptrName, '\0');  // read variable length name followed by NUL
        std::getline(in_file, line, '\4');     // read "BFM" followed by space and control-D
        if (line.compare("BFM ") != 0) {
            fprintf(stderr, "Cannot find array specifier in file %s in LoadKaldiArkArray()!\n", fileName);
            exit(-1);
        }
        in_file.read(reinterpret_cast<char *>(ptrNumRows), sizeof(uint32_t));     // read number of rows
        std::getline(in_file, line, '\4');                                        // read control-D
        in_file.read(reinterpret_cast<char *>(ptrNumColumns), sizeof(uint32_t));  // read number of columns

        const size_t numBytes = static_cast<size_t>(*ptrNumRows) * *ptrNumColumns * sizeof(float);
        // The payload must fit in the caller's buffer; growing it avoids writing past its end
        // (and avoids front() on an empty vector, which is undefined behaviour).
        if (memory.size() < numBytes) {
            memory.resize(numBytes);
        }
        if (numBytes > 0) {
            in_file.read(reinterpret_cast<char *>(memory.data()), static_cast<std::streamsize>(numBytes));
        }
    }

    *ptrNumBytesPerElement = sizeof(float);
}
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Writes one float matrix to a Kaldi binary ark file.
 *
 * Record layout: name, NUL, "BFM ", control-D, uint32 rows, control-D, uint32 cols, rows*cols floats.
 *
 * @param fileName     destination path
 * @param shouldAppend append to an existing file instead of truncating it
 * @param name         array name written before the NUL terminator
 * @param ptrMemory    numRows*numColumns 32-bit floats
 * @param numRows      number of rows
 * @param numColumns   number of columns
 * @throws std::runtime_error if the file cannot be opened or the write does not complete
 */
void SaveKaldiArkArray(const char *fileName,
                       bool shouldAppend,
                       std::string name,
                       void *ptrMemory,
                       uint32_t numRows,
                       uint32_t numColumns) {
    std::ios_base::openmode mode = std::ios::binary;
    if (shouldAppend) {
        mode |= std::ios::app;
    }
    std::ofstream out_file(fileName, mode);
    if (!out_file.good()) {
        // Build the message by concatenation (a printf-style "%s" template does not work with std::string).
        throw std::runtime_error(std::string("Failed to open ") + fileName + " for writing in SaveKaldiArkArray()!");
    }

    out_file.write(name.c_str(), static_cast<std::streamsize>(name.length()));  // write name
    out_file.write("\0", 1);
    out_file.write("BFM ", 4);
    out_file.write("\4", 1);
    out_file.write(reinterpret_cast<const char *>(&numRows), sizeof(uint32_t));
    out_file.write("\4", 1);
    out_file.write(reinterpret_cast<const char *>(&numColumns), sizeof(uint32_t));
    // Widen before multiplying so large matrices do not overflow 32-bit arithmetic.
    const size_t numBytes = static_cast<size_t>(numRows) * numColumns * sizeof(float);
    out_file.write(reinterpret_cast<const char *>(ptrMemory), static_cast<std::streamsize>(numBytes));
    out_file.close();
    if (out_file.fail()) {
        throw std::runtime_error(std::string("Failed to write array to ") + fileName + " in SaveKaldiArkArray()!");
    }
}
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Computes the factor that maps the largest-magnitude value onto targetMax.
 * @param ptrFloatMemory buffer of numElements 32-bit floats
 * @param targetMax      magnitude the peak value should be scaled to
 * @param numElements    number of floats to inspect
 * @return targetMax / max|x|, or 1.0 when every inspected value is zero (or numElements is 0)
 */
float ScaleFactorForQuantization(void *ptrFloatMemory, float targetMax, uint32_t numElements) {
    const float *values = static_cast<const float *>(ptrFloatMemory);
    float peak = 0.0f;
    for (uint32_t idx = 0; idx < numElements; ++idx) {
        const float magnitude = fabs(values[idx]);
        if (magnitude > peak) {
            peak = magnitude;
        }
    }
    // An all-zero input has no meaningful peak; fall back to an identity scale.
    return (peak == 0) ? 1.0f : targetMax / peak;
}
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Resets every counter and accumulator in *error to zero.
 *
 * The `threshold` field is deliberately left unchanged so a caller-configured limit survives resets.
 */
void ClearScoreError(score_error_t *error) {
    error->numScores = error->numErrors = 0;
    error->maxError = error->rmsError = error->sumError = 0.0f;
    error->sumRmsError = error->sumSquaredError = 0.0f;
    error->maxRelError = error->sumRelError = error->sumSquaredRelError = 0.0f;
}
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Folds one comparison result into a running total.
 * @param error      statistics from a single CompareScores call
 * @param totalError running totals; sums are accumulated, maxima are kept,
 *                   and error->rmsError is added to totalError->sumRmsError
 */
void UpdateScoreError(score_error_t *error, score_error_t *totalError) {
    const score_error_t &part = *error;
    score_error_t &total = *totalError;

    total.numErrors += part.numErrors;
    total.numScores += part.numScores;
    total.sumRmsError += part.rmsError;
    total.sumError += part.sumError;
    total.sumSquaredError += part.sumSquaredError;
    total.maxError = (part.maxError > total.maxError) ? part.maxError : total.maxError;

    total.sumRelError += part.sumRelError;
    total.sumSquaredRelError += part.sumSquaredRelError;
    total.maxRelError = (part.maxRelError > total.maxRelError) ? part.maxRelError : total.maxRelError;
}
|
|
|
|
|
|
|
|
|
|
uint32_t CompareScores(float *ptrScoreArray,
|
|
|
|
|
void *ptrRefScoreArray,
|
|
|
|
|
score_error_t *scoreError,
|
|
|
|
|
uint32_t numRows,
|
|
|
|
|
uint32_t numColumns) {
|
|
|
|
|
uint32_t numErrors = 0;
|
|
|
|
|
|
|
|
|
|
ClearScoreError(scoreError);
|
|
|
|
|
|
|
|
|
|
float *A = ptrScoreArray;
|
|
|
|
|
float *B = reinterpret_cast<float *>(ptrRefScoreArray);
|
|
|
|
|
for (uint32_t i = 0; i < numRows; i++) {
|
|
|
|
|
for (uint32_t j = 0; j < numColumns; j++) {
|
|
|
|
|
float score = A[i * numColumns + j];
|
|
|
|
|
float refscore = B[i * numColumns + j];
|
|
|
|
|
float error = fabs(refscore - score);
|
|
|
|
|
float rel_error = error / (static_cast<float>(fabs(refscore)) + 1e-20f);
|
|
|
|
|
float squared_error = error * error;
|
|
|
|
|
float squared_rel_error = rel_error * rel_error;
|
|
|
|
|
scoreError->numScores++;
|
|
|
|
|
scoreError->sumError += error;
|
|
|
|
|
scoreError->sumSquaredError += squared_error;
|
|
|
|
|
if (error > scoreError->maxError) {
|
|
|
|
|
scoreError->maxError = error;
|
|
|
|
|
}
|
|
|
|
|
scoreError->sumRelError += rel_error;
|
|
|
|
|
scoreError->sumSquaredRelError += squared_rel_error;
|
|
|
|
|
if (rel_error > scoreError->maxRelError) {
|
|
|
|
|
scoreError->maxRelError = rel_error;
|
|
|
|
|
}
|
|
|
|
|
if (error > scoreError->threshold) {
|
|
|
|
|
numErrors++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
scoreError->rmsError = sqrt(scoreError->sumSquaredError / (numRows * numColumns));
|
|
|
|
|
scoreError->sumRmsError += scoreError->rmsError;
|
|
|
|
|
scoreError->numErrors = numErrors;
|
|
|
|
|
|
|
|
|
|
return (numErrors);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Standard deviation of the absolute errors, computed as sqrt(E[e^2] - E[e]^2).
 * @param error accumulated statistics (uses sumError, sumSquaredError, numScores)
 * @return the standard deviation; 0 when no scores were accumulated
 */
float StdDevError(score_error_t error) {
    if (error.numScores == 0) {
        return 0.0f;  // avoid 0/0
    }
    const float mean = error.sumError / error.numScores;
    const float variance = error.sumSquaredError / error.numScores - mean * mean;
    // Float rounding can push a (near-)zero variance slightly below 0; clamp so sqrt does not yield NaN.
    // A NaN variance (from NaN inputs) is still propagated.
    return static_cast<float>(sqrt(variance < 0.0f ? 0.0f : variance));
}
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Standard deviation of the relative errors, computed as sqrt(E[r^2] - E[r]^2).
 * @param error accumulated statistics (uses sumRelError, sumSquaredRelError, numScores)
 * @return the standard deviation; 0 when no scores were accumulated
 */
float StdDevRelError(score_error_t error) {
    if (error.numScores == 0) {
        return 0.0f;  // avoid 0/0
    }
    const float mean = error.sumRelError / error.numScores;
    const float variance = error.sumSquaredRelError / error.numScores - mean * mean;
    // Float rounding can push a (near-)zero variance slightly below 0; clamp so sqrt does not yield NaN.
    // A NaN variance (from NaN inputs) is still propagated.
    return static_cast<float>(sqrt(variance < 0.0f ? 0.0f : variance));
}
|
|
|
|
|
|
2019-08-09 19:02:42 +03:00
|
|
|
#if !defined(__arm__) && !defined(_M_ARM) && !defined(__aarch64__) && !defined(_M_ARM64)
|
2021-02-02 20:11:40 +03:00
|
|
|
#ifdef _WIN32
|
2019-01-21 21:31:31 +03:00
|
|
|
#include <intrin.h>
|
|
|
|
|
#include <windows.h>
|
|
|
|
|
#else
|
|
|
|
|
|
|
|
|
|
#include <cpuid.h>
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Executes CPUID for the leaf passed in *eax and stores the resulting registers.
 *
 * On entry *eax holds the requested leaf; on return all four pointers hold EAX..EDX.
 * Uses the MSVC __cpuid intrinsic on Windows and GCC/Clang's __get_cpuid elsewhere
 * (which leaves the outputs untouched if the leaf is unsupported).
 */
inline void native_cpuid(unsigned int *eax, unsigned int *ebx,
                         unsigned int *ecx, unsigned int *edx) {
    const unsigned int leaf = *eax;
#ifdef _WIN32
    int registers[4] = {0, 0, 0, 0};  // fully overwritten by __cpuid
    __cpuid(registers, static_cast<int>(leaf));
    *eax = static_cast<uint32_t>(registers[0]);
    *ebx = static_cast<uint32_t>(registers[1]);
    *ecx = static_cast<uint32_t>(registers[2]);
    *edx = static_cast<uint32_t>(registers[3]);
#else
    __get_cpuid(leaf, eax, ebx, ecx, edx);
#endif
}
|
|
|
|
|
|
|
|
|
|
// return GNA module frequency in MHz
|
|
|
|
|
float getGnaFrequencyMHz() {
|
|
|
|
|
uint32_t eax = 1;
|
|
|
|
|
uint32_t ebx = 0;
|
|
|
|
|
uint32_t ecx = 0;
|
|
|
|
|
uint32_t edx = 0;
|
|
|
|
|
uint32_t family = 0;
|
|
|
|
|
uint32_t model = 0;
|
|
|
|
|
const uint8_t sixth_family = 6;
|
|
|
|
|
const uint8_t cannon_lake_model = 102;
|
|
|
|
|
const uint8_t gemini_lake_model = 122;
|
2020-02-11 22:48:49 +03:00
|
|
|
const uint8_t ice_lake_model = 126;
|
2020-04-13 21:17:23 +03:00
|
|
|
const uint8_t next_model = 140;
|
2019-01-21 21:31:31 +03:00
|
|
|
|
|
|
|
|
native_cpuid(&eax, &ebx, &ecx, &edx);
|
|
|
|
|
family = (eax >> 8) & 0xF;
|
|
|
|
|
|
|
|
|
|
// model is the concatenation of two fields
|
|
|
|
|
// | extended model | model |
|
|
|
|
|
// copy extended model data
|
|
|
|
|
model = (eax >> 16) & 0xF;
|
|
|
|
|
// shift
|
|
|
|
|
model <<= 4;
|
|
|
|
|
// copy model data
|
|
|
|
|
model += (eax >> 4) & 0xF;
|
|
|
|
|
|
2020-02-11 22:48:49 +03:00
|
|
|
if (family == sixth_family) {
|
|
|
|
|
switch (model) {
|
|
|
|
|
case cannon_lake_model:
|
|
|
|
|
case ice_lake_model:
|
2020-04-13 21:17:23 +03:00
|
|
|
case next_model:
|
2020-02-11 22:48:49 +03:00
|
|
|
return 400;
|
|
|
|
|
case gemini_lake_model:
|
|
|
|
|
return 200;
|
|
|
|
|
default:
|
|
|
|
|
return 1;
|
|
|
|
|
}
|
2019-01-21 21:31:31 +03:00
|
|
|
} else {
|
2020-02-11 22:48:49 +03:00
|
|
|
// counters not supported and we returns just default value
|
2019-01-21 21:31:31 +03:00
|
|
|
return 1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2019-08-09 19:02:42 +03:00
|
|
|
#endif // if not ARM
|
2019-01-21 21:31:31 +03:00
|
|
|
|
|
|
|
|
/**
 * @brief Prints summary statistics of the comparison against reference scores.
 * @param totalError accumulated statistics over all compared frames
 * @param framesNum  number of frames, used to average sumRmsError
 * @param stream     destination stream
 */
void printReferenceCompareResults(score_error_t const &totalError,
                                  size_t framesNum,
                                  std::ostream &stream) {
    const float averageError = totalError.sumError / totalError.numScores;
    const float averageRmsError = totalError.sumRmsError / framesNum;
    stream << " max error: " << totalError.maxError << std::endl
           << " avg error: " << averageError << std::endl
           << " avg rms error: " << averageRmsError << std::endl
           << " stdev error: " << StdDevError(totalError) << std::endl << std::endl
           << std::endl;
}
|
|
|
|
|
|
|
|
|
|
void printPerformanceCounters(std::map<std::string,
|
|
|
|
|
InferenceEngine::InferenceEngineProfileInfo> const &utterancePerfMap,
|
|
|
|
|
size_t callsNum,
|
2019-08-09 19:02:42 +03:00
|
|
|
std::ostream &stream, std::string fullDeviceName) {
|
|
|
|
|
#if !defined(__arm__) && !defined(_M_ARM) && !defined(__aarch64__) && !defined(_M_ARM64)
|
2019-01-21 21:31:31 +03:00
|
|
|
stream << std::endl << "Performance counts:" << std::endl;
|
|
|
|
|
stream << std::setw(10) << std::right << "" << "Counter descriptions";
|
|
|
|
|
stream << std::setw(22) << "Utt scoring time";
|
|
|
|
|
stream << std::setw(18) << "Avg infer time";
|
|
|
|
|
stream << std::endl;
|
|
|
|
|
|
|
|
|
|
stream << std::setw(46) << "(ms)";
|
|
|
|
|
stream << std::setw(24) << "(us per call)";
|
|
|
|
|
stream << std::endl;
|
|
|
|
|
|
|
|
|
|
for (const auto &it : utterancePerfMap) {
|
|
|
|
|
std::string const &counter_name = it.first;
|
2019-04-12 18:25:53 +03:00
|
|
|
float current_units = static_cast<float>(it.second.realTime_uSec);
|
2019-01-21 21:31:31 +03:00
|
|
|
float call_units = current_units / callsNum;
|
|
|
|
|
// if GNA HW counters
|
|
|
|
|
// get frequency of GNA module
|
2019-04-12 18:25:53 +03:00
|
|
|
float freq = getGnaFrequencyMHz();
|
2019-01-21 21:31:31 +03:00
|
|
|
current_units /= freq * 1000;
|
|
|
|
|
call_units /= freq;
|
|
|
|
|
stream << std::setw(30) << std::left << counter_name.substr(4, counter_name.size() - 1);
|
|
|
|
|
stream << std::setw(16) << std::right << current_units;
|
|
|
|
|
stream << std::setw(21) << std::right << call_units;
|
|
|
|
|
stream << std::endl;
|
|
|
|
|
}
|
|
|
|
|
stream << std::endl;
|
2019-08-09 19:02:42 +03:00
|
|
|
std::cout << std::endl;
|
|
|
|
|
std::cout << "Full device name: " << fullDeviceName << std::endl;
|
|
|
|
|
std::cout << std::endl;
|
2019-01-21 21:31:31 +03:00
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Copies the request's current performance counters into perfCounters.
 *
 * Entries with the same name are overwritten; entries not reported by the request are kept.
 */
void getPerformanceCounters(InferenceEngine::InferRequest &request,
                            std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &perfCounters) {
    const auto latestCounters = request.GetPerformanceCounts();
    for (auto entry = latestCounters.begin(); entry != latestCounters.end(); ++entry) {
        perfCounters[entry->first] = entry->second;
    }
}
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Adds each counter's realTime_uSec into the same-named entry of totalPerfCounters.
 *
 * Missing entries in totalPerfCounters are default-constructed by operator[] before the addition.
 */
void sumPerformanceCounters(std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> const &perfCounters,
                            std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &totalPerfCounters) {
    for (auto entry = perfCounters.begin(); entry != perfCounters.end(); ++entry) {
        auto &accumulated = totalPerfCounters[entry->first];
        accumulated.realTime_uSec += entry->second.realTime_uSec;
    }
}
|
|
|
|
|
|
2020-06-09 14:36:28 +03:00
|
|
|
/**
 * @brief Splits a comma-separated list of per-input scale factors and validates each one.
 * @param str comma-separated scale factors, one per network input (e.g. "2048,1024")
 * @return the individual scale factors, as strings, in input order
 * @throws std::logic_error if str is empty, an entry is not a number, or an entry is not strictly positive
 */
std::vector<std::string> ParseScaleFactors(const std::string& str) {
    if (str.empty()) {
        throw std::logic_error("Scale factor need to be specified via -sf option if you are using -q user");
    }

    std::vector<std::string> scaleFactorInput;
    std::istringstream stream(str);
    std::string outStr;
    for (size_t i = 0; getline(stream, outStr, ','); ++i) {
        float floatScaleFactor = 0.0f;
        try {
            floatScaleFactor = std::stof(outStr);
        } catch (const std::invalid_argument &) {
            throw std::logic_error("Scale factor for input #" + std::to_string(i)
                                   + " (counting from zero) is not a valid number: '" + outStr + "'.");
        } catch (const std::out_of_range &) {
            throw std::logic_error("Scale factor for input #" + std::to_string(i)
                                   + " (counting from zero) is not representable as a float: '" + outStr + "'.");
        }
        // Written as !(x > 0) rather than x <= 0 so that NaN is rejected as well.
        if (!(floatScaleFactor > 0.0f)) {
            throw std::logic_error("Scale factor for input #" + std::to_string(i)
                                   + " (counting from zero) is out of range (must be positive).");
        }
        scaleFactorInput.push_back(outStr);
    }
    return scaleFactorInput;
}
|
|
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
/**
 * @brief Splits a comma-separated list of blob names.
 *
 * Empty fields are preserved ("a,,b," yields "a", "", "b", ""); an empty input yields an empty list.
 */
std::vector<std::string> ParseBlobName(std::string str) {
    std::vector<std::string> names;
    if (str.empty()) {
        return names;
    }
    std::string::size_type start = 0;
    for (;;) {
        const std::string::size_type comma = str.find(',', start);
        if (comma == std::string::npos) {
            names.push_back(str.substr(start));  // last field (possibly empty)
            break;
        }
        names.push_back(str.substr(start, comma - start));
        start = comma + 1;
    }
    return names;
}
|
|
|
|
|
|
2019-01-21 21:31:31 +03:00
|
|
|
bool ParseAndCheckCommandLine(int argc, char *argv[]) {
|
|
|
|
|
// ---------------------------Parsing and validation of input args--------------------------------------
|
|
|
|
|
slog::info << "Parsing input parameters" << slog::endl;
|
|
|
|
|
|
|
|
|
|
gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
|
|
|
|
|
if (FLAGS_h) {
|
|
|
|
|
showUsage();
|
2019-08-09 19:02:42 +03:00
|
|
|
showAvailableDevices();
|
2019-01-21 21:31:31 +03:00
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
bool isDumpMode = !FLAGS_wg.empty() || !FLAGS_we.empty();
|
|
|
|
|
|
|
|
|
|
// input not required only in dump mode and if external scale factor provided
|
|
|
|
|
if (FLAGS_i.empty() && (!isDumpMode || FLAGS_q.compare("user") != 0)) {
|
|
|
|
|
if (isDumpMode) {
|
|
|
|
|
throw std::logic_error("In model dump mode either static quantization is used (-i) or user scale"
|
|
|
|
|
" factor need to be provided. See -q user option");
|
|
|
|
|
}
|
|
|
|
|
throw std::logic_error("Input file not set. Please use -i.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (FLAGS_m.empty() && FLAGS_rg.empty()) {
|
|
|
|
|
throw std::logic_error("Either IR file (-m) or GNAModel file (-rg) need to be set.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if ((!FLAGS_m.empty() && !FLAGS_rg.empty())) {
|
|
|
|
|
throw std::logic_error("Only one of -m and -rg is allowed.");
|
|
|
|
|
}
|
|
|
|
|
|
2019-08-09 19:02:42 +03:00
|
|
|
std::vector<std::string> supportedDevices = {
|
2019-04-12 18:25:53 +03:00
|
|
|
"CPU",
|
|
|
|
|
"GPU",
|
|
|
|
|
"GNA_AUTO",
|
|
|
|
|
"GNA_HW",
|
|
|
|
|
"GNA_SW_EXACT",
|
|
|
|
|
"GNA_SW",
|
2019-08-09 19:02:42 +03:00
|
|
|
"GNA_SW_FP32",
|
2019-04-12 18:25:53 +03:00
|
|
|
"HETERO:GNA,CPU",
|
|
|
|
|
"HETERO:GNA_HW,CPU",
|
|
|
|
|
"HETERO:GNA_SW_EXACT,CPU",
|
|
|
|
|
"HETERO:GNA_SW,CPU",
|
2019-08-09 19:02:42 +03:00
|
|
|
"HETERO:GNA_SW_FP32,CPU",
|
|
|
|
|
"MYRIAD"
|
2019-04-12 18:25:53 +03:00
|
|
|
};
|
|
|
|
|
|
2019-08-09 19:02:42 +03:00
|
|
|
if (std::find(supportedDevices.begin(), supportedDevices.end(), FLAGS_d) == supportedDevices.end()) {
|
2019-01-21 21:31:31 +03:00
|
|
|
throw std::logic_error("Specified device is not supported.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint32_t batchSize = (uint32_t) FLAGS_bs;
|
|
|
|
|
if ((batchSize < 1) || (batchSize > 8)) {
|
|
|
|
|
throw std::logic_error("Batch size out of range (1..8).");
|
|
|
|
|
}
|
|
|
|
|
|
2020-02-11 22:48:49 +03:00
|
|
|
/** default is a static quantization **/
|
2019-01-21 21:31:31 +03:00
|
|
|
if ((FLAGS_q.compare("static") != 0) && (FLAGS_q.compare("dynamic") != 0) && (FLAGS_q.compare("user") != 0)) {
|
|
|
|
|
throw std::logic_error("Quantization mode not supported (static, dynamic, user).");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (FLAGS_q.compare("dynamic") == 0) {
|
|
|
|
|
throw std::logic_error("Dynamic quantization not yet supported.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (FLAGS_qb != 16 && FLAGS_qb != 8) {
|
|
|
|
|
throw std::logic_error("Only 8 or 16 bits supported.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (FLAGS_nthreads <= 0) {
|
2019-08-09 19:02:42 +03:00
|
|
|
throw std::logic_error("Invalid value for 'nthreads' argument. It must be greater that or equal to 0");
|
2019-01-21 21:31:31 +03:00
|
|
|
}
|
|
|
|
|
|
2019-08-09 19:02:42 +03:00
|
|
|
if (FLAGS_cw_r < 0) {
|
|
|
|
|
throw std::logic_error("Invalid value for 'cw_r' argument. It must be greater than or equal to 0");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (FLAGS_cw_l < 0) {
|
|
|
|
|
throw std::logic_error("Invalid value for 'cw_l' argument. It must be greater than or equal to 0");
|
2019-04-12 18:25:53 +03:00
|
|
|
}
|
|
|
|
|
|
2021-03-04 14:10:01 +01:00
|
|
|
if (FLAGS_pwl_me < 0.0 || FLAGS_pwl_me > 100.0) {
|
|
|
|
|
throw std::logic_error("Invalid value for 'pwl_me' argument. It must be greater than 0.0 and less than 100.0");
|
|
|
|
|
}
|
|
|
|
|
|
2019-01-21 21:31:31 +03:00
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief The entry point for inference engine automatic speech recognition sample
|
|
|
|
|
* @file speech_sample/main.cpp
|
|
|
|
|
* @example speech_sample/main.cpp
|
|
|
|
|
*/
|
|
|
|
|
int main(int argc, char *argv[]) {
|
|
|
|
|
try {
|
|
|
|
|
slog::info << "InferenceEngine: " << GetInferenceEngineVersion() << slog::endl;
|
|
|
|
|
|
|
|
|
|
// ------------------------------ Parsing and validation of input args ---------------------------------
|
|
|
|
|
if (!ParseAndCheckCommandLine(argc, argv)) {
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (FLAGS_l.empty()) {
|
|
|
|
|
slog::info << "No extensions provided" << slog::endl;
|
|
|
|
|
}
|
|
|
|
|
|
2019-04-12 18:25:53 +03:00
|
|
|
auto isFeature = [&](const std::string xFeature) { return FLAGS_d.find(xFeature) != std::string::npos; };
|
|
|
|
|
|
|
|
|
|
bool useGna = isFeature("GNA");
|
|
|
|
|
bool useHetero = isFeature("HETERO");
|
|
|
|
|
std::string deviceStr =
|
|
|
|
|
useHetero && useGna ? "HETERO:GNA,CPU" : FLAGS_d.substr(0, (FLAGS_d.find("_")));
|
2019-08-09 19:02:42 +03:00
|
|
|
uint32_t batchSize = (FLAGS_cw_r > 0 || FLAGS_cw_l > 0) ? 1 : (uint32_t) FLAGS_bs;
|
2019-01-21 21:31:31 +03:00
|
|
|
|
2019-08-09 19:02:42 +03:00
|
|
|
std::vector<std::string> inputArkFiles;
|
|
|
|
|
std::vector<uint32_t> numBytesThisUtterance;
|
|
|
|
|
uint32_t numUtterances(0);
|
2019-01-21 21:31:31 +03:00
|
|
|
if (!FLAGS_i.empty()) {
|
2019-08-09 19:02:42 +03:00
|
|
|
std::string outStr;
|
|
|
|
|
std::istringstream stream(FLAGS_i);
|
|
|
|
|
|
|
|
|
|
uint32_t currentNumUtterances(0), currentNumBytesThisUtterance(0);
|
|
|
|
|
while (getline(stream, outStr, ',')) {
|
|
|
|
|
std::string filename(fileNameNoExt(outStr) + ".ark");
|
|
|
|
|
inputArkFiles.push_back(filename);
|
|
|
|
|
|
|
|
|
|
GetKaldiArkInfo(filename.c_str(), 0, ¤tNumUtterances, ¤tNumBytesThisUtterance);
|
|
|
|
|
if (numUtterances == 0) {
|
|
|
|
|
numUtterances = currentNumUtterances;
|
|
|
|
|
} else if (currentNumUtterances != numUtterances) {
|
|
|
|
|
throw std::logic_error("Incorrect input files. Number of utterance must be the same for all ark files");
|
|
|
|
|
}
|
|
|
|
|
numBytesThisUtterance.push_back(currentNumBytesThisUtterance);
|
|
|
|
|
}
|
2019-01-21 21:31:31 +03:00
|
|
|
}
|
2019-08-09 19:02:42 +03:00
|
|
|
size_t numInputArkFiles(inputArkFiles.size());
|
2019-01-21 21:31:31 +03:00
|
|
|
// -----------------------------------------------------------------------------------------------------
|
|
|
|
|
|
2019-08-09 19:02:42 +03:00
|
|
|
// --------------------------- 1. Load inference engine -------------------------------------
|
|
|
|
|
slog::info << "Loading Inference Engine" << slog::endl;
|
|
|
|
|
Core ie;
|
2019-01-21 21:31:31 +03:00
|
|
|
|
2019-08-09 19:02:42 +03:00
|
|
|
/** Printing device version **/
|
|
|
|
|
slog::info << "Device info: " << slog::endl;
|
|
|
|
|
std::cout << ie.GetVersions(deviceStr) << std::endl;
|
2019-01-21 21:31:31 +03:00
|
|
|
// -----------------------------------------------------------------------------------------------------
|
|
|
|
|
|
2020-08-26 18:53:24 +03:00
|
|
|
// 2. Read a model in OpenVINO Intermediate Representation (.xml and .bin files) or ONNX (.onnx file) format
|
2019-01-21 21:31:31 +03:00
|
|
|
slog::info << "Loading network files" << slog::endl;
|
|
|
|
|
|
2020-02-11 22:48:49 +03:00
|
|
|
CNNNetwork network;
|
2019-01-21 21:31:31 +03:00
|
|
|
if (!FLAGS_m.empty()) {
|
|
|
|
|
/** Read network model **/
|
2020-02-11 22:48:49 +03:00
|
|
|
network = ie.ReadNetwork(FLAGS_m);
|
2020-04-13 21:17:23 +03:00
|
|
|
CheckNumberOfInputs(network.getInputsInfo().size(), numInputArkFiles);
|
2019-01-21 21:31:31 +03:00
|
|
|
// -------------------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
// --------------------------- 3. Set batch size ---------------------------------------------------
|
|
|
|
|
/** Set batch size. Unlike in imaging, batching in time (rather than space) is done for speech recognition. **/
|
2020-02-11 22:48:49 +03:00
|
|
|
network.setBatchSize(batchSize);
|
|
|
|
|
slog::info << "Batch size is " << std::to_string(network.getBatchSize())
|
2019-01-21 21:31:31 +03:00
|
|
|
<< slog::endl;
|
|
|
|
|
}
|
|
|
|
|
|
2020-04-13 21:17:23 +03:00
|
|
|
// -----------------------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
// --------------------------- 4. Set parameters and scale factors -------------------------------------
|
2019-08-09 19:02:42 +03:00
|
|
|
/** Setting parameter for per layer metrics **/
|
2019-01-21 21:31:31 +03:00
|
|
|
std::map<std::string, std::string> gnaPluginConfig;
|
|
|
|
|
std::map<std::string, std::string> genericPluginConfig;
|
2019-04-12 18:25:53 +03:00
|
|
|
if (useGna) {
|
|
|
|
|
std::string gnaDevice =
|
|
|
|
|
useHetero ? FLAGS_d.substr(FLAGS_d.find("GNA"), FLAGS_d.find(",") - FLAGS_d.find("GNA")) : FLAGS_d;
|
|
|
|
|
gnaPluginConfig[GNAConfigParams::KEY_GNA_DEVICE_MODE] =
|
|
|
|
|
gnaDevice.find("_") == std::string::npos ? "GNA_AUTO" : gnaDevice;
|
2019-01-21 21:31:31 +03:00
|
|
|
}
|
2019-04-12 18:25:53 +03:00
|
|
|
|
2019-01-21 21:31:31 +03:00
|
|
|
if (FLAGS_pc) {
|
|
|
|
|
genericPluginConfig[PluginConfigParams::KEY_PERF_COUNT] = PluginConfigParams::YES;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (FLAGS_q.compare("user") == 0) {
|
2020-09-29 17:32:09 +02:00
|
|
|
if (!FLAGS_rg.empty()) {
|
|
|
|
|
slog::warn << "Custom scale factor will be ignored - using scale factor from provided imported gna model: "
|
|
|
|
|
<< FLAGS_rg << slog::endl;
|
|
|
|
|
} else {
|
|
|
|
|
auto scaleFactorInput = ParseScaleFactors(FLAGS_sf);
|
|
|
|
|
if (numInputArkFiles != scaleFactorInput.size()) {
|
|
|
|
|
std::string errMessage("Incorrect command line for multiple inputs: "
|
|
|
|
|
+ std::to_string(scaleFactorInput.size()) + " scale factors provided for "
|
|
|
|
|
+ std::to_string(numInputArkFiles) + " input files.");
|
|
|
|
|
throw std::logic_error(errMessage);
|
|
|
|
|
}
|
2020-06-25 11:43:47 +02:00
|
|
|
|
2020-09-29 17:32:09 +02:00
|
|
|
for (size_t i = 0; i < scaleFactorInput.size(); ++i) {
|
|
|
|
|
slog::info << "For input " << i << " using scale factor of " << scaleFactorInput[i] << slog::endl;
|
|
|
|
|
std::string scaleFactorConfigKey = GNA_CONFIG_KEY(SCALE_FACTOR) + std::string("_") + std::to_string(i);
|
|
|
|
|
gnaPluginConfig[scaleFactorConfigKey] = scaleFactorInput[i];
|
|
|
|
|
}
|
2020-06-09 14:36:28 +03:00
|
|
|
}
|
2019-08-09 19:02:42 +03:00
|
|
|
} else {
|
|
|
|
|
// "static" quantization with calculated scale factor
|
2020-09-29 17:32:09 +02:00
|
|
|
if (!FLAGS_rg.empty()) {
|
|
|
|
|
slog::info << "Using scale factor from provided imported gna model: " << FLAGS_rg << slog::endl;
|
|
|
|
|
} else {
|
|
|
|
|
for (size_t i = 0; i < numInputArkFiles; i++) {
|
|
|
|
|
auto inputArkName = inputArkFiles[i].c_str();
|
|
|
|
|
std::string name;
|
|
|
|
|
std::vector<uint8_t> ptrFeatures;
|
|
|
|
|
uint32_t numArrays(0), numBytes(0), numFrames(0), numFrameElements(0), numBytesPerElement(0);
|
|
|
|
|
GetKaldiArkInfo(inputArkName, 0, &numArrays, &numBytes);
|
|
|
|
|
ptrFeatures.resize(numBytes);
|
|
|
|
|
LoadKaldiArkArray(inputArkName,
|
|
|
|
|
0,
|
|
|
|
|
name,
|
|
|
|
|
ptrFeatures,
|
|
|
|
|
&numFrames,
|
|
|
|
|
&numFrameElements,
|
|
|
|
|
&numBytesPerElement);
|
|
|
|
|
auto floatScaleFactor =
|
2019-08-09 19:02:42 +03:00
|
|
|
ScaleFactorForQuantization(ptrFeatures.data(), MAX_VAL_2B_FEAT, numFrames * numFrameElements);
|
2020-09-29 17:32:09 +02:00
|
|
|
slog::info << "Using scale factor of " << floatScaleFactor << " calculated from first utterance."
|
|
|
|
|
<< slog::endl;
|
|
|
|
|
std::string scaleFactorConfigKey = GNA_CONFIG_KEY(SCALE_FACTOR) + std::string("_") + std::to_string(i);
|
|
|
|
|
gnaPluginConfig[scaleFactorConfigKey] = std::to_string(floatScaleFactor);
|
|
|
|
|
}
|
2019-08-09 19:02:42 +03:00
|
|
|
}
|
2019-01-21 21:31:31 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (FLAGS_qb == 8) {
|
|
|
|
|
gnaPluginConfig[GNAConfigParams::KEY_GNA_PRECISION] = "I8";
|
|
|
|
|
} else {
|
|
|
|
|
gnaPluginConfig[GNAConfigParams::KEY_GNA_PRECISION] = "I16";
|
|
|
|
|
}
|
|
|
|
|
|
2019-08-09 19:02:42 +03:00
|
|
|
gnaPluginConfig[GNAConfigParams::KEY_GNA_LIB_N_THREADS] = std::to_string((FLAGS_cw_r > 0 || FLAGS_cw_l > 0) ? 1 : FLAGS_nthreads);
|
2019-01-21 21:31:31 +03:00
|
|
|
gnaPluginConfig[GNA_CONFIG_KEY(COMPACT_MODE)] = CONFIG_VALUE(NO);
|
2021-03-04 14:10:01 +01:00
|
|
|
gnaPluginConfig[GNA_CONFIG_KEY(PWL_MAX_ERROR_PERCENT)] = std::to_string(FLAGS_pwl_me);
|
2019-01-21 21:31:31 +03:00
|
|
|
// -----------------------------------------------------------------------------------------------------
|
|
|
|
|
|
2020-04-13 21:17:23 +03:00
|
|
|
// --------------------------- 5. Write model to file --------------------------------------------------
|
2019-01-21 21:31:31 +03:00
|
|
|
// Embedded GNA model dumping (for Intel(R) Speech Enabling Developer Kit)
|
|
|
|
|
if (!FLAGS_we.empty()) {
|
|
|
|
|
gnaPluginConfig[GNAConfigParams::KEY_GNA_FIRMWARE_MODEL_IMAGE] = FLAGS_we;
|
2020-02-11 22:48:49 +03:00
|
|
|
gnaPluginConfig[GNAConfigParams::KEY_GNA_FIRMWARE_MODEL_IMAGE_GENERATION] = FLAGS_we_gen;
|
2019-01-21 21:31:31 +03:00
|
|
|
}
|
|
|
|
|
// -----------------------------------------------------------------------------------------------------
|
|
|
|
|
|
2020-04-13 21:17:23 +03:00
|
|
|
// --------------------------- 6. Loading model to the device ------------------------------------------
|
2019-01-21 21:31:31 +03:00
|
|
|
|
|
|
|
|
if (useGna) {
|
|
|
|
|
genericPluginConfig.insert(std::begin(gnaPluginConfig), std::end(gnaPluginConfig));
|
|
|
|
|
}
|
|
|
|
|
auto t0 = Time::now();
|
2020-10-16 15:34:22 +03:00
|
|
|
std::vector<std::string> outputs;
|
2019-01-21 21:31:31 +03:00
|
|
|
ExecutableNetwork executableNet;
|
2019-04-12 18:25:53 +03:00
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
if (!FLAGS_oname.empty()) {
|
|
|
|
|
std::vector<std::string> output_names = ParseBlobName(FLAGS_oname);
|
|
|
|
|
std::vector<size_t> ports;
|
|
|
|
|
for (const auto& outBlobName : output_names) {
|
|
|
|
|
int pos_layer = outBlobName.rfind(":");
|
|
|
|
|
if (pos_layer == -1) {
|
|
|
|
|
throw std::logic_error(std::string("Output ") + std::string(outBlobName)
|
|
|
|
|
+ std::string(" doesn't have a port"));
|
|
|
|
|
}
|
|
|
|
|
outputs.push_back(outBlobName.substr(0, pos_layer));
|
|
|
|
|
try {
|
|
|
|
|
ports.push_back(std::stoi(outBlobName.substr(pos_layer + 1)));
|
2020-10-23 08:54:48 +03:00
|
|
|
} catch (const std::exception &) {
|
2020-10-16 15:34:22 +03:00
|
|
|
throw std::logic_error("Ports should have integer type");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < outputs.size(); i++) {
|
|
|
|
|
network.addOutput(outputs[i], ports[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
2019-01-21 21:31:31 +03:00
|
|
|
if (!FLAGS_m.empty()) {
|
2019-08-09 19:02:42 +03:00
|
|
|
slog::info << "Loading model to the device" << slog::endl;
|
2020-06-05 11:54:03 +03:00
|
|
|
executableNet = ie.LoadNetwork(network, deviceStr, genericPluginConfig);
|
2019-01-21 21:31:31 +03:00
|
|
|
} else {
|
2019-08-09 19:02:42 +03:00
|
|
|
slog::info << "Importing model to the device" << slog::endl;
|
2020-06-05 11:54:03 +03:00
|
|
|
executableNet = ie.ImportNetwork(FLAGS_rg.c_str(), deviceStr, genericPluginConfig);
|
2019-01-21 21:31:31 +03:00
|
|
|
}
|
|
|
|
|
ms loadTime = std::chrono::duration_cast<ms>(Time::now() - t0);
|
|
|
|
|
slog::info << "Model loading time " << loadTime.count() << " ms" << slog::endl;
|
|
|
|
|
|
2020-04-13 21:17:23 +03:00
|
|
|
// --------------------------- 7. Exporting gna model using InferenceEngine AOT API---------------------
|
2019-01-21 21:31:31 +03:00
|
|
|
if (!FLAGS_wg.empty()) {
|
|
|
|
|
slog::info << "Writing GNA Model to file " << FLAGS_wg << slog::endl;
|
|
|
|
|
t0 = Time::now();
|
|
|
|
|
executableNet.Export(FLAGS_wg);
|
|
|
|
|
ms exportTime = std::chrono::duration_cast<ms>(Time::now() - t0);
|
|
|
|
|
slog::info << "Exporting time " << exportTime.count() << " ms" << slog::endl;
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!FLAGS_we.empty()) {
|
|
|
|
|
slog::info << "Exported GNA embedded model to file " << FLAGS_we << slog::endl;
|
2020-02-11 22:48:49 +03:00
|
|
|
if (!FLAGS_we_gen.empty()) {
|
|
|
|
|
slog::info << "GNA embedded model export done for GNA generation: " << FLAGS_we_gen << slog::endl;
|
|
|
|
|
}
|
2019-01-21 21:31:31 +03:00
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
2019-08-09 19:02:42 +03:00
|
|
|
std::vector<InferRequestStruct> inferRequests((FLAGS_cw_r > 0 || FLAGS_cw_l > 0) ? 1 : FLAGS_nthreads);
|
2019-01-21 21:31:31 +03:00
|
|
|
for (auto& inferRequest : inferRequests) {
|
2019-04-12 18:25:53 +03:00
|
|
|
inferRequest = {executableNet.CreateInferRequest(), -1, batchSize};
|
2019-01-21 21:31:31 +03:00
|
|
|
}
|
|
|
|
|
// -----------------------------------------------------------------------------------------------------
|
|
|
|
|
|
2020-04-13 21:17:23 +03:00
|
|
|
// --------------------------- 8. Prepare input blobs --------------------------------------------------
|
2019-01-21 21:31:31 +03:00
|
|
|
/** Taking information about all topology inputs **/
|
|
|
|
|
ConstInputsDataMap cInputInfo = executableNet.GetInputsInfo();
|
2020-04-13 21:17:23 +03:00
|
|
|
CheckNumberOfInputs(cInputInfo.size(), numInputArkFiles);
|
2019-01-21 21:31:31 +03:00
|
|
|
|
2020-04-13 21:17:23 +03:00
|
|
|
/** Stores all input blobs data **/
|
2019-08-09 19:02:42 +03:00
|
|
|
std::vector<Blob::Ptr> ptrInputBlobs;
|
2020-10-16 15:34:22 +03:00
|
|
|
if (!FLAGS_iname.empty()) {
|
|
|
|
|
std::vector<std::string> inputNameBlobs = ParseBlobName(FLAGS_iname);
|
|
|
|
|
if (inputNameBlobs.size() != cInputInfo.size()) {
|
|
|
|
|
std::string errMessage(std::string("Number of network inputs ( ") + std::to_string(cInputInfo.size()) +
|
|
|
|
|
" ) is not equal to the number of inputs entered in the -iname argument ( " +
|
|
|
|
|
std::to_string(inputNameBlobs.size()) + " ).");
|
|
|
|
|
throw std::logic_error(errMessage);
|
|
|
|
|
}
|
|
|
|
|
for (const auto& input : inputNameBlobs) {
|
|
|
|
|
Blob::Ptr blob = inferRequests.begin()->inferRequest.GetBlob(input);
|
|
|
|
|
if (!blob) {
|
|
|
|
|
std::string errMessage("No blob with name : " + input);
|
|
|
|
|
throw std::logic_error(errMessage);
|
|
|
|
|
}
|
|
|
|
|
ptrInputBlobs.push_back(blob);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (const auto& input : cInputInfo) {
|
|
|
|
|
ptrInputBlobs.push_back(inferRequests.begin()->inferRequest.GetBlob(input.first));
|
|
|
|
|
}
|
2019-08-09 19:02:42 +03:00
|
|
|
}
|
|
|
|
|
InputsDataMap inputInfo;
|
|
|
|
|
if (!FLAGS_m.empty()) {
|
2020-02-11 22:48:49 +03:00
|
|
|
inputInfo = network.getInputsInfo();
|
2019-08-09 19:02:42 +03:00
|
|
|
}
|
2020-04-13 21:17:23 +03:00
|
|
|
/** Configure input precision if model is loaded from IR **/
|
2019-01-21 21:31:31 +03:00
|
|
|
for (auto &item : inputInfo) {
|
|
|
|
|
Precision inputPrecision = Precision::FP32; // specify Precision::I16 to provide quantized inputs
|
|
|
|
|
item.second->setPrecision(inputPrecision);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// -----------------------------------------------------------------------------------------------------
|
|
|
|
|
|
2020-04-13 21:17:23 +03:00
|
|
|
// --------------------------- 9. Prepare output blobs -------------------------------------------------
|
2019-01-21 21:31:31 +03:00
|
|
|
ConstOutputsDataMap cOutputInfo(executableNet.GetOutputsInfo());
|
|
|
|
|
OutputsDataMap outputInfo;
|
|
|
|
|
if (!FLAGS_m.empty()) {
|
2020-02-11 22:48:49 +03:00
|
|
|
outputInfo = network.getOutputsInfo();
|
2019-01-21 21:31:31 +03:00
|
|
|
}
|
2020-10-16 15:34:22 +03:00
|
|
|
std::vector<Blob::Ptr> ptrOutputBlob;
|
|
|
|
|
if (!outputs.empty()) {
|
|
|
|
|
for (const auto& output : outputs) {
|
|
|
|
|
Blob::Ptr blob = inferRequests.begin()->inferRequest.GetBlob(output);
|
|
|
|
|
if (!blob) {
|
|
|
|
|
std::string errMessage("No blob with name : " + output);
|
|
|
|
|
throw std::logic_error(errMessage);
|
|
|
|
|
}
|
|
|
|
|
ptrOutputBlob.push_back(blob);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (auto& output : cOutputInfo) {
|
|
|
|
|
ptrOutputBlob.push_back(inferRequests.begin()->inferRequest.GetBlob(output.first));
|
|
|
|
|
}
|
|
|
|
|
}
|
2019-01-21 21:31:31 +03:00
|
|
|
|
|
|
|
|
for (auto &item : outputInfo) {
|
|
|
|
|
DataPtr outData = item.second;
|
|
|
|
|
if (!outData) {
|
|
|
|
|
throw std::logic_error("output data pointer is not valid");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Precision outputPrecision = Precision::FP32; // specify Precision::I32 to retrieve quantized outputs
|
|
|
|
|
outData->setPrecision(outputPrecision);
|
|
|
|
|
}
|
|
|
|
|
// -----------------------------------------------------------------------------------------------------
|
|
|
|
|
|
2020-04-13 21:17:23 +03:00
|
|
|
// --------------------------- 10. Do inference --------------------------------------------------------
|
2020-10-16 15:34:22 +03:00
|
|
|
std::vector<std::string> output_name_files;
|
|
|
|
|
std::vector<std::string> reference_name_files;
|
|
|
|
|
size_t count_file = 1;
|
|
|
|
|
if (!FLAGS_o.empty()) {
|
|
|
|
|
output_name_files = ParseBlobName(FLAGS_o);
|
|
|
|
|
if (output_name_files.size() != outputs.size() && !outputs.empty()) {
|
|
|
|
|
throw std::logic_error("The number of output files is not equal to the number of network outputs.");
|
2019-08-09 19:02:42 +03:00
|
|
|
}
|
2020-10-16 15:34:22 +03:00
|
|
|
count_file = output_name_files.empty() ? 1 : output_name_files.size();
|
|
|
|
|
}
|
|
|
|
|
if (!FLAGS_r.empty()) {
|
|
|
|
|
reference_name_files = ParseBlobName(FLAGS_r);
|
|
|
|
|
if (reference_name_files.size() != outputs.size() && !outputs.empty()) {
|
|
|
|
|
throw std::logic_error("The number of reference files is not equal to the number of network outputs.");
|
2019-01-21 21:31:31 +03:00
|
|
|
}
|
2020-10-16 15:34:22 +03:00
|
|
|
count_file = reference_name_files.empty() ? 1 : reference_name_files.size();
|
|
|
|
|
}
|
|
|
|
|
for (size_t next_output = 0; next_output < count_file; next_output++) {
|
|
|
|
|
std::vector<std::vector<uint8_t>> ptrUtterances;
|
|
|
|
|
std::vector<uint8_t> ptrScores;
|
|
|
|
|
std::vector<uint8_t> ptrReferenceScores;
|
|
|
|
|
score_error_t frameError, totalError;
|
|
|
|
|
|
|
|
|
|
ptrUtterances.resize(inputArkFiles.size());
|
2019-08-09 19:02:42 +03:00
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
// initialize memory state before starting
|
2020-11-12 12:40:43 +03:00
|
|
|
for (auto &&state : inferRequests.begin()->inferRequest.QueryState()) {
|
2020-10-16 15:34:22 +03:00
|
|
|
state.Reset();
|
2019-01-21 21:31:31 +03:00
|
|
|
}
|
|
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
for (uint32_t utteranceIndex = 0; utteranceIndex < numUtterances; ++utteranceIndex) {
|
|
|
|
|
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> utterancePerfMap;
|
|
|
|
|
std::string uttName;
|
|
|
|
|
uint32_t numFrames(0), n(0);
|
|
|
|
|
std::vector<uint32_t> numFrameElementsInput;
|
|
|
|
|
|
|
|
|
|
uint32_t numFramesReference(0), numFrameElementsReference(0), numBytesPerElementReference(0),
|
|
|
|
|
numBytesReferenceScoreThisUtterance(0);
|
2020-11-10 16:33:41 +03:00
|
|
|
auto dims = outputs.empty() ? cOutputInfo.rbegin()->second->getDims() : cOutputInfo[outputs[next_output]]->getDims();
|
|
|
|
|
const auto numScoresPerFrame = std::accumulate(std::begin(dims), std::end(dims), size_t{1}, std::multiplies<size_t>());
|
|
|
|
|
|
|
|
|
|
slog::info << "Number scores per frame : " << numScoresPerFrame << slog::endl;
|
2020-10-16 15:34:22 +03:00
|
|
|
|
|
|
|
|
numFrameElementsInput.resize(numInputArkFiles);
|
|
|
|
|
for (size_t i = 0; i < inputArkFiles.size(); i++) {
|
|
|
|
|
std::vector<uint8_t> ptrUtterance;
|
|
|
|
|
auto inputArkFilename = inputArkFiles[i].c_str();
|
|
|
|
|
uint32_t currentNumFrames(0), currentNumFrameElementsInput(0), currentNumBytesPerElementInput(0);
|
|
|
|
|
GetKaldiArkInfo(inputArkFilename, utteranceIndex, &n, &numBytesThisUtterance[i]);
|
|
|
|
|
ptrUtterance.resize(numBytesThisUtterance[i]);
|
|
|
|
|
LoadKaldiArkArray(inputArkFilename,
|
|
|
|
|
utteranceIndex,
|
|
|
|
|
uttName,
|
|
|
|
|
ptrUtterance,
|
|
|
|
|
¤tNumFrames,
|
|
|
|
|
¤tNumFrameElementsInput,
|
|
|
|
|
¤tNumBytesPerElementInput);
|
|
|
|
|
if (numFrames == 0) {
|
|
|
|
|
numFrames = currentNumFrames;
|
|
|
|
|
} else if (numFrames != currentNumFrames) {
|
|
|
|
|
std::string errMessage(
|
|
|
|
|
"Number of frames in ark files is different: " + std::to_string(numFrames) +
|
|
|
|
|
" and " + std::to_string(currentNumFrames));
|
|
|
|
|
throw std::logic_error(errMessage);
|
|
|
|
|
}
|
2019-01-21 21:31:31 +03:00
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
ptrUtterances[i] = ptrUtterance;
|
|
|
|
|
numFrameElementsInput[i] = currentNumFrameElementsInput;
|
|
|
|
|
}
|
2019-01-21 21:31:31 +03:00
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
int i = 0;
|
|
|
|
|
for (auto &ptrInputBlob : ptrInputBlobs) {
|
|
|
|
|
if (ptrInputBlob->size() != numFrameElementsInput[i++] * batchSize) {
|
|
|
|
|
throw std::logic_error("network input size(" + std::to_string(ptrInputBlob->size()) +
|
|
|
|
|
") mismatch to ark file size (" +
|
|
|
|
|
std::to_string(numFrameElementsInput[i - 1] * batchSize) + ")");
|
|
|
|
|
}
|
|
|
|
|
}
|
2019-01-21 21:31:31 +03:00
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
ptrScores.resize(numFrames * numScoresPerFrame * sizeof(float));
|
|
|
|
|
if (!FLAGS_r.empty()) {
|
|
|
|
|
std::string refUtteranceName;
|
|
|
|
|
GetKaldiArkInfo(reference_name_files[next_output].c_str(), utteranceIndex, &n, &numBytesReferenceScoreThisUtterance);
|
|
|
|
|
ptrReferenceScores.resize(numBytesReferenceScoreThisUtterance);
|
|
|
|
|
LoadKaldiArkArray(reference_name_files[next_output].c_str(),
|
|
|
|
|
utteranceIndex,
|
|
|
|
|
refUtteranceName,
|
|
|
|
|
ptrReferenceScores,
|
|
|
|
|
&numFramesReference,
|
|
|
|
|
&numFrameElementsReference,
|
|
|
|
|
&numBytesPerElementReference);
|
|
|
|
|
}
|
2019-04-12 18:25:53 +03:00
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
double totalTime = 0.0;
|
2019-01-21 21:31:31 +03:00
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
std::cout << "Utterance " << utteranceIndex << ": " << std::endl;
|
2019-01-21 21:31:31 +03:00
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
ClearScoreError(&totalError);
|
|
|
|
|
totalError.threshold = frameError.threshold = MAX_SCORE_DIFFERENCE;
|
|
|
|
|
auto outputFrame = &ptrScores.front();
|
|
|
|
|
std::vector<uint8_t *> inputFrame;
|
|
|
|
|
for (auto &ut : ptrUtterances) {
|
|
|
|
|
inputFrame.push_back(&ut.front());
|
2019-01-21 21:31:31 +03:00
|
|
|
}
|
|
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> callPerfMap;
|
|
|
|
|
|
|
|
|
|
size_t frameIndex = 0;
|
|
|
|
|
uint32_t numFramesArkFile = numFrames;
|
|
|
|
|
numFrames += FLAGS_cw_l + FLAGS_cw_r;
|
|
|
|
|
uint32_t numFramesThisBatch{batchSize};
|
|
|
|
|
|
|
|
|
|
auto t0 = Time::now();
|
|
|
|
|
auto t1 = t0;
|
|
|
|
|
|
|
|
|
|
while (frameIndex <= numFrames) {
|
2019-01-21 21:31:31 +03:00
|
|
|
if (frameIndex == numFrames) {
|
2020-10-16 15:34:22 +03:00
|
|
|
if (std::find_if(inferRequests.begin(),
|
|
|
|
|
inferRequests.end(),
|
|
|
|
|
[&](InferRequestStruct x) { return (x.frameIndex != -1); }) ==
|
|
|
|
|
inferRequests.end()) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
2019-01-21 21:31:31 +03:00
|
|
|
}
|
|
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
bool inferRequestFetched = false;
|
|
|
|
|
for (auto &inferRequest : inferRequests) {
|
|
|
|
|
if (frameIndex == numFrames) {
|
|
|
|
|
numFramesThisBatch = 1;
|
|
|
|
|
} else {
|
|
|
|
|
numFramesThisBatch = (numFrames - frameIndex < batchSize) ? (numFrames - frameIndex)
|
|
|
|
|
: batchSize;
|
2019-01-21 21:31:31 +03:00
|
|
|
}
|
|
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
if (inferRequest.frameIndex != -1) {
|
|
|
|
|
StatusCode code = inferRequest.inferRequest.Wait(
|
|
|
|
|
InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
|
2019-04-12 18:25:53 +03:00
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
if (code != StatusCode::OK) {
|
|
|
|
|
if (!useHetero) continue;
|
|
|
|
|
if (code != StatusCode::INFER_NOT_STARTED) continue;
|
2019-04-12 18:25:53 +03:00
|
|
|
}
|
2020-10-16 15:34:22 +03:00
|
|
|
ConstOutputsDataMap newOutputInfo;
|
|
|
|
|
if (inferRequest.frameIndex >= 0) {
|
|
|
|
|
if (!FLAGS_o.empty()) {
|
|
|
|
|
outputFrame =
|
|
|
|
|
&ptrScores.front() +
|
|
|
|
|
numScoresPerFrame * sizeof(float) * (inferRequest.frameIndex);
|
|
|
|
|
if (!outputs.empty()) {
|
|
|
|
|
newOutputInfo[outputs[next_output]] = cOutputInfo[outputs[next_output]];
|
|
|
|
|
} else {
|
|
|
|
|
newOutputInfo = cOutputInfo;
|
|
|
|
|
}
|
|
|
|
|
Blob::Ptr outputBlob = inferRequest.inferRequest.GetBlob(newOutputInfo.rbegin()->first);
|
|
|
|
|
MemoryBlob::CPtr moutput = as<MemoryBlob>(outputBlob);
|
|
|
|
|
|
|
|
|
|
if (!moutput) {
|
|
|
|
|
throw std::logic_error("We expect output to be inherited from MemoryBlob, "
|
|
|
|
|
"but in fact we were not able to cast output to MemoryBlob");
|
|
|
|
|
}
|
|
|
|
|
// locked memory holder should be alive all time while access to its buffer happens
|
|
|
|
|
auto moutputHolder = moutput->rmap();
|
|
|
|
|
auto byteSize =
|
2021-02-18 18:31:40 +03:00
|
|
|
numScoresPerFrame * sizeof(float);
|
2020-10-16 15:34:22 +03:00
|
|
|
std::memcpy(outputFrame,
|
|
|
|
|
moutputHolder.as<const void *>(),
|
|
|
|
|
byteSize);
|
|
|
|
|
}
|
|
|
|
|
if (!FLAGS_r.empty()) {
|
|
|
|
|
if (!outputs.empty()) {
|
|
|
|
|
newOutputInfo[outputs[next_output]] = cOutputInfo[outputs[next_output]];
|
|
|
|
|
} else {
|
|
|
|
|
newOutputInfo = cOutputInfo;
|
|
|
|
|
}
|
|
|
|
|
Blob::Ptr outputBlob = inferRequest.inferRequest.GetBlob(newOutputInfo.rbegin()->first);
|
|
|
|
|
MemoryBlob::CPtr moutput = as<MemoryBlob>(outputBlob);
|
|
|
|
|
if (!moutput) {
|
|
|
|
|
throw std::logic_error("We expect output to be inherited from MemoryBlob, "
|
|
|
|
|
"but in fact we were not able to cast output to MemoryBlob");
|
|
|
|
|
}
|
|
|
|
|
// locked memory holder should be alive all time while access to its buffer happens
|
|
|
|
|
auto moutputHolder = moutput->rmap();
|
|
|
|
|
CompareScores(moutputHolder.as<float *>(),
|
|
|
|
|
&ptrReferenceScores[inferRequest.frameIndex *
|
|
|
|
|
numFrameElementsReference *
|
|
|
|
|
numBytesPerElementReference],
|
|
|
|
|
&frameError,
|
|
|
|
|
inferRequest.numFramesThisBatch,
|
|
|
|
|
numFrameElementsReference);
|
|
|
|
|
UpdateScoreError(&frameError, &totalError);
|
|
|
|
|
}
|
|
|
|
|
if (FLAGS_pc) {
|
|
|
|
|
// retrieve new counters
|
|
|
|
|
getPerformanceCounters(inferRequest.inferRequest, callPerfMap);
|
|
|
|
|
// summarize retrieved counters with all previous
|
|
|
|
|
sumPerformanceCounters(callPerfMap, utterancePerfMap);
|
|
|
|
|
}
|
2019-04-12 18:25:53 +03:00
|
|
|
}
|
2019-01-21 21:31:31 +03:00
|
|
|
}
|
2019-04-12 18:25:53 +03:00
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
if (frameIndex == numFrames) {
|
|
|
|
|
inferRequest.frameIndex = -1;
|
|
|
|
|
continue;
|
2020-02-11 22:48:49 +03:00
|
|
|
}
|
|
|
|
|
|
2021-03-22 17:14:15 +03:00
|
|
|
ptrInputBlobs.clear();
|
2020-10-16 15:34:22 +03:00
|
|
|
if (FLAGS_iname.empty()) {
|
2021-03-22 17:14:15 +03:00
|
|
|
for (auto &input : cInputInfo) {
|
|
|
|
|
ptrInputBlobs.push_back(inferRequest.inferRequest.GetBlob(input.first));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
std::vector<std::string> inputNameBlobs = ParseBlobName(FLAGS_iname);
|
|
|
|
|
for (const auto& input : inputNameBlobs) {
|
|
|
|
|
Blob::Ptr blob = inferRequests.begin()->inferRequest.GetBlob(input);
|
|
|
|
|
if (!blob) {
|
|
|
|
|
std::string errMessage("No blob with name : " + input);
|
|
|
|
|
throw std::logic_error(errMessage);
|
2020-10-16 15:34:22 +03:00
|
|
|
}
|
2021-03-22 17:14:15 +03:00
|
|
|
ptrInputBlobs.push_back(blob);
|
|
|
|
|
}
|
|
|
|
|
}
|
2019-01-21 21:31:31 +03:00
|
|
|
|
2021-03-22 17:14:15 +03:00
|
|
|
for (size_t i = 0; i < numInputArkFiles; ++i) {
|
|
|
|
|
MemoryBlob::Ptr minput = as<MemoryBlob>(ptrInputBlobs[i]);
|
|
|
|
|
if (!minput) {
|
|
|
|
|
std::string errMessage("We expect ptrInputBlobs[" + std::to_string(i) +
|
|
|
|
|
"] to be inherited from MemoryBlob, " +
|
|
|
|
|
"but in fact we were not able to cast input blob to MemoryBlob");
|
|
|
|
|
throw std::logic_error(errMessage);
|
2020-10-16 15:34:22 +03:00
|
|
|
}
|
2021-03-22 17:14:15 +03:00
|
|
|
// locked memory holder should be alive all time while access to its buffer happens
|
|
|
|
|
auto minputHolder = minput->wmap();
|
|
|
|
|
|
|
|
|
|
std::memcpy(minputHolder.as<void *>(),
|
|
|
|
|
inputFrame[i],
|
|
|
|
|
minput->byteSize());
|
2020-10-16 15:34:22 +03:00
|
|
|
}
|
2019-01-21 21:31:31 +03:00
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
int index = static_cast<int>(frameIndex) - (FLAGS_cw_l + FLAGS_cw_r);
|
|
|
|
|
inferRequest.inferRequest.StartAsync();
|
|
|
|
|
inferRequest.frameIndex = index < 0 ? -2 : index;
|
|
|
|
|
inferRequest.numFramesThisBatch = numFramesThisBatch;
|
|
|
|
|
|
|
|
|
|
frameIndex += numFramesThisBatch;
|
|
|
|
|
for (size_t j = 0; j < inputArkFiles.size(); j++) {
|
|
|
|
|
if (FLAGS_cw_l > 0 || FLAGS_cw_r > 0) {
|
|
|
|
|
int idx = frameIndex - FLAGS_cw_l;
|
|
|
|
|
if (idx > 0 && idx < static_cast<int>(numFramesArkFile)) {
|
|
|
|
|
inputFrame[j] += sizeof(float) * numFrameElementsInput[j] * numFramesThisBatch;
|
|
|
|
|
} else if (idx >= static_cast<int>(numFramesArkFile)) {
|
|
|
|
|
inputFrame[j] = &ptrUtterances[j].front() +
|
|
|
|
|
(numFramesArkFile - 1) * sizeof(float) * numFrameElementsInput[j] *
|
|
|
|
|
numFramesThisBatch;
|
|
|
|
|
} else if (idx <= 0) {
|
|
|
|
|
inputFrame[j] = &ptrUtterances[j].front();
|
|
|
|
|
}
|
|
|
|
|
} else {
|
2019-08-09 19:02:42 +03:00
|
|
|
inputFrame[j] += sizeof(float) * numFrameElementsInput[j] * numFramesThisBatch;
|
|
|
|
|
}
|
2019-04-12 18:25:53 +03:00
|
|
|
}
|
2020-10-16 15:34:22 +03:00
|
|
|
inferRequestFetched |= true;
|
|
|
|
|
}
|
|
|
|
|
if (!inferRequestFetched) {
|
|
|
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(1));
|
|
|
|
|
continue;
|
2019-04-12 18:25:53 +03:00
|
|
|
}
|
2019-01-21 21:31:31 +03:00
|
|
|
}
|
2020-10-16 15:34:22 +03:00
|
|
|
t1 = Time::now();
|
2019-01-21 21:31:31 +03:00
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
fsec fs = t1 - t0;
|
|
|
|
|
ms d = std::chrono::duration_cast<ms>(fs);
|
|
|
|
|
totalTime += d.count();
|
2019-01-21 21:31:31 +03:00
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
// resetting state between utterances
|
2020-11-12 12:40:43 +03:00
|
|
|
for (auto &&state : inferRequests.begin()->inferRequest.QueryState()) {
|
2020-10-16 15:34:22 +03:00
|
|
|
state.Reset();
|
|
|
|
|
}
|
2019-01-21 21:31:31 +03:00
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
if (!FLAGS_o.empty()) {
|
|
|
|
|
bool shouldAppend = (utteranceIndex == 0) ? false : true;
|
|
|
|
|
SaveKaldiArkArray(output_name_files[next_output].c_str(), shouldAppend, uttName, &ptrScores.front(),
|
|
|
|
|
numFramesArkFile, numScoresPerFrame);
|
|
|
|
|
}
|
2019-01-21 21:31:31 +03:00
|
|
|
|
2020-10-16 15:34:22 +03:00
|
|
|
/** Show performance results **/
|
|
|
|
|
std::cout << "Total time in Infer (HW and SW):\t" << totalTime << " ms"
|
|
|
|
|
<< std::endl;
|
|
|
|
|
std::cout << "Frames in utterance:\t\t\t" << numFrames << " frames"
|
|
|
|
|
<< std::endl;
|
|
|
|
|
std::cout << "Average Infer time per frame:\t\t" << totalTime / static_cast<double>(numFrames) << " ms"
|
|
|
|
|
<< std::endl;
|
|
|
|
|
if (FLAGS_pc) {
|
|
|
|
|
// print
|
|
|
|
|
printPerformanceCounters(utterancePerfMap, frameIndex, std::cout, getFullDeviceName(ie, FLAGS_d));
|
|
|
|
|
}
|
|
|
|
|
if (!FLAGS_r.empty()) {
|
|
|
|
|
printReferenceCompareResults(totalError, numFrames, std::cout);
|
|
|
|
|
}
|
|
|
|
|
std::cout << "End of Utterance " << utteranceIndex << std::endl << std::endl;
|
2019-01-21 21:31:31 +03:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// -----------------------------------------------------------------------------------------------------
|
|
|
|
|
}
|
|
|
|
|
catch (const std::exception &error) {
|
|
|
|
|
slog::err << error.what() << slog::endl;
|
|
|
|
|
return 1;
|
|
|
|
|
}
|
|
|
|
|
catch (...) {
|
|
|
|
|
slog::err << "Unknown/internal exception happened" << slog::endl;
|
|
|
|
|
return 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
slog::info << "Execution successful" << slog::endl;
|
|
|
|
|
return 0;
|
|
|
|
|
}
|