Added more comments into speech sample (#5049)

This commit is contained in:
Anton Romanov 2021-04-08 14:00:22 +03:00 committed by GitHub
parent 37893caa36
commit d30740af66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -263,11 +263,6 @@ float StdDevError(score_error_t error) {
- (error.sumError / error.numScores) * (error.sumError / error.numScores))); - (error.sumError / error.numScores) * (error.sumError / error.numScores)));
} }
/**
 * Standard deviation of the relative error, derived from the accumulated sums
 * in `error` as sqrt(mean of squares - square of mean).
 * NOTE(review): there is no guard for error.numScores == 0 here - confirm that
 * callers never pass an empty accumulator.
 */
float StdDevRelError(score_error_t error) {
    const auto meanOfSquares = error.sumSquaredRelError / error.numScores;
    const auto meanRelError = error.sumRelError / error.numScores;
    return (sqrt(meanOfSquares - meanRelError * meanRelError));
}
#if !defined(__arm__) && !defined(_M_ARM) && !defined(__aarch64__) && !defined(_M_ARM64) #if !defined(__arm__) && !defined(_M_ARM) && !defined(__aarch64__) && !defined(_M_ARM64)
#ifdef _WIN32 #ifdef _WIN32
#include <intrin.h> #include <intrin.h>
@ -579,23 +574,24 @@ int main(int argc, char *argv[]) {
// --------------------------- 1. Load inference engine ------------------------------------- // --------------------------- 1. Load inference engine -------------------------------------
slog::info << "Loading Inference Engine" << slog::endl; slog::info << "Loading Inference Engine" << slog::endl;
Core ie; Core ie;
CNNNetwork network;
ExecutableNetwork executableNet;
/** Printing device version **/ /** Printing device version **/
slog::info << "Device info: " << slog::endl; slog::info << "Device info: " << slog::endl;
std::cout << ie.GetVersions(deviceStr) << std::endl; std::cout << ie.GetVersions(deviceStr) << std::endl;
// ----------------------------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------------------------
// 2. Read a model in OpenVINO Intermediate Representation (.xml and .bin files) or ONNX (.onnx file) format // --------------------------- 2. Read a model in OpenVINO Intermediate Representation (.xml and .bin files) or ONNX (.onnx file) format
slog::info << "Loading network files" << slog::endl; slog::info << "Loading network files" << slog::endl;
CNNNetwork network;
if (!FLAGS_m.empty()) { if (!FLAGS_m.empty()) {
/** Read network model **/ /** Read network model **/
network = ie.ReadNetwork(FLAGS_m); network = ie.ReadNetwork(FLAGS_m);
CheckNumberOfInputs(network.getInputsInfo().size(), numInputArkFiles); CheckNumberOfInputs(network.getInputsInfo().size(), numInputArkFiles);
// ------------------------------------------------------------------------------------------------- // -------------------------------------------------------------------------------------------------
// --------------------------- 3. Set batch size --------------------------------------------------- // --------------------------- Set batch size ---------------------------------------------------
/** Set batch size. Unlike in imaging, batching in time (rather than space) is done for speech recognition. **/ /** Set batch size. Unlike in imaging, batching in time (rather than space) is done for speech recognition. **/
network.setBatchSize(batchSize); network.setBatchSize(batchSize);
slog::info << "Batch size is " << std::to_string(network.getBatchSize()) slog::info << "Batch size is " << std::to_string(network.getBatchSize())
@ -604,7 +600,7 @@ int main(int argc, char *argv[]) {
// ----------------------------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------------------------
// --------------------------- 4. Set parameters and scale factors ------------------------------------- // --------------------------- Set parameters and scale factors -------------------------------------
/** Setting parameter for per layer metrics **/ /** Setting parameter for per layer metrics **/
std::map<std::string, std::string> gnaPluginConfig; std::map<std::string, std::string> gnaPluginConfig;
std::map<std::string, std::string> genericPluginConfig; std::map<std::string, std::string> genericPluginConfig;
@ -678,7 +674,7 @@ int main(int argc, char *argv[]) {
gnaPluginConfig[GNA_CONFIG_KEY(PWL_MAX_ERROR_PERCENT)] = std::to_string(FLAGS_pwl_me); gnaPluginConfig[GNA_CONFIG_KEY(PWL_MAX_ERROR_PERCENT)] = std::to_string(FLAGS_pwl_me);
// ----------------------------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------------------------
// --------------------------- 5. Write model to file -------------------------------------------------- // --------------------------- Write model to file --------------------------------------------------
// Embedded GNA model dumping (for Intel(R) Speech Enabling Developer Kit) // Embedded GNA model dumping (for Intel(R) Speech Enabling Developer Kit)
if (!FLAGS_we.empty()) { if (!FLAGS_we.empty()) {
gnaPluginConfig[GNAConfigParams::KEY_GNA_FIRMWARE_MODEL_IMAGE] = FLAGS_we; gnaPluginConfig[GNAConfigParams::KEY_GNA_FIRMWARE_MODEL_IMAGE] = FLAGS_we;
@ -686,14 +682,13 @@ int main(int argc, char *argv[]) {
} }
// ----------------------------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------------------------
// --------------------------- 6. Loading model to the device ------------------------------------------ // --------------------------- 3. Loading model to the device ------------------------------------------
if (useGna) { if (useGna) {
genericPluginConfig.insert(std::begin(gnaPluginConfig), std::end(gnaPluginConfig)); genericPluginConfig.insert(std::begin(gnaPluginConfig), std::end(gnaPluginConfig));
} }
auto t0 = Time::now(); auto t0 = Time::now();
std::vector<std::string> outputs; std::vector<std::string> outputs;
ExecutableNetwork executableNet;
if (!FLAGS_oname.empty()) { if (!FLAGS_oname.empty()) {
std::vector<std::string> output_names = ParseBlobName(FLAGS_oname); std::vector<std::string> output_names = ParseBlobName(FLAGS_oname);
@ -726,7 +721,7 @@ int main(int argc, char *argv[]) {
ms loadTime = std::chrono::duration_cast<ms>(Time::now() - t0); ms loadTime = std::chrono::duration_cast<ms>(Time::now() - t0);
slog::info << "Model loading time " << loadTime.count() << " ms" << slog::endl; slog::info << "Model loading time " << loadTime.count() << " ms" << slog::endl;
// --------------------------- 7. Exporting gna model using InferenceEngine AOT API--------------------- // --------------------------- Exporting gna model using InferenceEngine AOT API---------------------
if (!FLAGS_wg.empty()) { if (!FLAGS_wg.empty()) {
slog::info << "Writing GNA Model to file " << FLAGS_wg << slog::endl; slog::info << "Writing GNA Model to file " << FLAGS_wg << slog::endl;
t0 = Time::now(); t0 = Time::now();
@ -744,13 +739,17 @@ int main(int argc, char *argv[]) {
return 0; return 0;
} }
// --------------------------- 4. Create infer request --------------------------------------------------
std::vector<InferRequestStruct> inferRequests((FLAGS_cw_r > 0 || FLAGS_cw_l > 0) ? 1 : FLAGS_nthreads); std::vector<InferRequestStruct> inferRequests((FLAGS_cw_r > 0 || FLAGS_cw_l > 0) ? 1 : FLAGS_nthreads);
for (auto& inferRequest : inferRequests) { for (auto& inferRequest : inferRequests) {
inferRequest = {executableNet.CreateInferRequest(), -1, batchSize}; inferRequest = {executableNet.CreateInferRequest(), -1, batchSize};
} }
// ----------------------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------------------------
// --------------------------- 8. Prepare input blobs -------------------------------------------------- // --------------------------- 5. Configure input & output --------------------------------------------------
//--- Prepare input blobs ----------------------------------------------
/** Taking information about all topology inputs **/ /** Taking information about all topology inputs **/
ConstInputsDataMap cInputInfo = executableNet.GetInputsInfo(); ConstInputsDataMap cInputInfo = executableNet.GetInputsInfo();
CheckNumberOfInputs(cInputInfo.size(), numInputArkFiles); CheckNumberOfInputs(cInputInfo.size(), numInputArkFiles);
@ -788,9 +787,9 @@ int main(int argc, char *argv[]) {
item.second->setPrecision(inputPrecision); item.second->setPrecision(inputPrecision);
} }
// ----------------------------------------------------------------------------------------------------- // ---------------------------------------------------------------------
// --------------------------- 9. Prepare output blobs ------------------------------------------------- //--- Prepare output blobs ---------------------------------------------
ConstOutputsDataMap cOutputInfo(executableNet.GetOutputsInfo()); ConstOutputsDataMap cOutputInfo(executableNet.GetOutputsInfo());
OutputsDataMap outputInfo; OutputsDataMap outputInfo;
if (!FLAGS_m.empty()) { if (!FLAGS_m.empty()) {
@ -821,9 +820,10 @@ int main(int argc, char *argv[]) {
Precision outputPrecision = Precision::FP32; // specify Precision::I32 to retrieve quantized outputs Precision outputPrecision = Precision::FP32; // specify Precision::I32 to retrieve quantized outputs
outData->setPrecision(outputPrecision); outData->setPrecision(outputPrecision);
} }
// ---------------------------------------------------------------------
// ----------------------------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------------------------
// --------------------------- 10. Do inference -------------------------------------------------------- // --------------------------- 6. Do inference --------------------------------------------------------
std::vector<std::string> output_name_files; std::vector<std::string> output_name_files;
std::vector<std::string> reference_name_files; std::vector<std::string> reference_name_files;
size_t count_file = 1; size_t count_file = 1;
@ -854,6 +854,7 @@ int main(int argc, char *argv[]) {
state.Reset(); state.Reset();
} }
/** Work with each utterance **/
for (uint32_t utteranceIndex = 0; utteranceIndex < numUtterances; ++utteranceIndex) { for (uint32_t utteranceIndex = 0; utteranceIndex < numUtterances; ++utteranceIndex) {
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> utterancePerfMap; std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> utterancePerfMap;
std::string uttName; std::string uttName;
@ -867,6 +868,7 @@ int main(int argc, char *argv[]) {
slog::info << "Number scores per frame : " << numScoresPerFrame << slog::endl; slog::info << "Number scores per frame : " << numScoresPerFrame << slog::endl;
/** Get information from ark file for current utterance **/
numFrameElementsInput.resize(numInputArkFiles); numFrameElementsInput.resize(numInputArkFiles);
for (size_t i = 0; i < inputArkFiles.size(); i++) { for (size_t i = 0; i < inputArkFiles.size(); i++) {
std::vector<uint8_t> ptrUtterance; std::vector<uint8_t> ptrUtterance;
@ -905,6 +907,7 @@ int main(int argc, char *argv[]) {
ptrScores.resize(numFrames * numScoresPerFrame * sizeof(float)); ptrScores.resize(numFrames * numScoresPerFrame * sizeof(float));
if (!FLAGS_r.empty()) { if (!FLAGS_r.empty()) {
/** Read ark file with reference scores **/
std::string refUtteranceName; std::string refUtteranceName;
GetKaldiArkInfo(reference_name_files[next_output].c_str(), utteranceIndex, &n, &numBytesReferenceScoreThisUtterance); GetKaldiArkInfo(reference_name_files[next_output].c_str(), utteranceIndex, &n, &numBytesReferenceScoreThisUtterance);
ptrReferenceScores.resize(numBytesReferenceScoreThisUtterance); ptrReferenceScores.resize(numBytesReferenceScoreThisUtterance);
@ -950,6 +953,7 @@ int main(int argc, char *argv[]) {
} }
bool inferRequestFetched = false; bool inferRequestFetched = false;
/** Start inference loop **/
for (auto &inferRequest : inferRequests) { for (auto &inferRequest : inferRequests) {
if (frameIndex == numFrames) { if (frameIndex == numFrames) {
numFramesThisBatch = 1; numFramesThisBatch = 1;
@ -969,6 +973,7 @@ int main(int argc, char *argv[]) {
ConstOutputsDataMap newOutputInfo; ConstOutputsDataMap newOutputInfo;
if (inferRequest.frameIndex >= 0) { if (inferRequest.frameIndex >= 0) {
if (!FLAGS_o.empty()) { if (!FLAGS_o.empty()) {
/* Prepare output data to be saved to a file later */
outputFrame = outputFrame =
&ptrScores.front() + &ptrScores.front() +
numScoresPerFrame * sizeof(float) * (inferRequest.frameIndex); numScoresPerFrame * sizeof(float) * (inferRequest.frameIndex);
@ -993,6 +998,7 @@ int main(int argc, char *argv[]) {
byteSize); byteSize);
} }
if (!FLAGS_r.empty()) { if (!FLAGS_r.empty()) {
/** Compare output data with reference scores **/
if (!outputs.empty()) { if (!outputs.empty()) {
newOutputInfo[outputs[next_output]] = cOutputInfo[outputs[next_output]]; newOutputInfo[outputs[next_output]] = cOutputInfo[outputs[next_output]];
} else { } else {
@ -1029,6 +1035,7 @@ int main(int argc, char *argv[]) {
continue; continue;
} }
/** Prepare input blobs **/
ptrInputBlobs.clear(); ptrInputBlobs.clear();
if (FLAGS_iname.empty()) { if (FLAGS_iname.empty()) {
for (auto &input : cInputInfo) { for (auto &input : cInputInfo) {
@ -1063,6 +1070,7 @@ int main(int argc, char *argv[]) {
} }
int index = static_cast<int>(frameIndex) - (FLAGS_cw_l + FLAGS_cw_r); int index = static_cast<int>(frameIndex) - (FLAGS_cw_l + FLAGS_cw_r);
/** Start inference **/
inferRequest.inferRequest.StartAsync(); inferRequest.inferRequest.StartAsync();
inferRequest.frameIndex = index < 0 ? -2 : index; inferRequest.frameIndex = index < 0 ? -2 : index;
inferRequest.numFramesThisBatch = numFramesThisBatch; inferRequest.numFramesThisBatch = numFramesThisBatch;
@ -1086,6 +1094,7 @@ int main(int argc, char *argv[]) {
} }
inferRequestFetched |= true; inferRequestFetched |= true;
} }
/** Inference was finished for the current frame **/
if (!inferRequestFetched) { if (!inferRequestFetched) {
std::this_thread::sleep_for(std::chrono::milliseconds(1)); std::this_thread::sleep_for(std::chrono::milliseconds(1));
continue; continue;
@ -1103,6 +1112,7 @@ int main(int argc, char *argv[]) {
} }
if (!FLAGS_o.empty()) { if (!FLAGS_o.empty()) {
/* Save output data to file */
bool shouldAppend = (utteranceIndex == 0) ? false : true; bool shouldAppend = (utteranceIndex == 0) ? false : true;
SaveKaldiArkArray(output_name_files[next_output].c_str(), shouldAppend, uttName, &ptrScores.front(), SaveKaldiArkArray(output_name_files[next_output].c_str(), shouldAppend, uttName, &ptrScores.front(),
numFramesArkFile, numScoresPerFrame); numFramesArkFile, numScoresPerFrame);