Python speech sample (#5322)
* Upload Python speech sample draft
* Add function docstrings
* Add the ability to save results to a file
* Add a plugin configuration for GNA to get better accuracy
* Add errors calculation (comparison with a reference)
* Fix flake8 issues
* Add ability to run in HETERO mode
* Add ability to load and save NumPy format files (.npz)
* Add an error for wrong file extensions & update help message
* Add import and export GNA model options
* Add -we option to export embedded GNA model
* Add README
* Add -oname command line option (layer names for output blobs)
* Add -iname command line option (layer names for input blobs)
* Add info about -iname option to README.md
* doc: update README, fix style
* Add a state reset between inferences
* Add reset API to speech README
* doc: remove extra output from README
* Remove ONNX and TODO, format output
* Add an else branch to the if statement that checks an utterance data type
* Add dummy data for inference if a number of vectors < batch size
* Split the sample into separate files

Co-authored-by: Kate Generalova <kate.generalova@intel.com>
parent d40636f835
commit b2abf25218
@@ -9,7 +9,9 @@ After installation of Intel® Distribution of OpenVINO™ toolkit, C, C++ and P

 Inference Engine sample applications include the following:

-- **[Automatic Speech Recognition C++ Sample](../../inference-engine/samples/speech_sample/README.md)** – Acoustic model inference based on Kaldi neural networks and speech feature vectors.
+- **Speech Sample** - Acoustic model inference based on Kaldi neural networks and speech feature vectors.
+  - [Automatic Speech Recognition C++ Sample](../../inference-engine/samples/speech_sample/README.md)
+  - [Automatic Speech Recognition Python Sample](../../inference-engine/ie_bridges/python/sample/speech_sample/README.md)
 - **Benchmark Application** – Estimates deep learning inference performance on supported devices for synchronous and asynchronous modes.
   - [Benchmark C++ Tool](../../inference-engine/samples/benchmark_app/README.md)
   - [Benchmark Python Tool](../../inference-engine/tools/benchmark_tool/README.md)
@@ -180,6 +180,7 @@ limitations under the License.

     <tab type="user" title="Object Detection SSD Python* Sample" url="@ref openvino_inference_engine_ie_bridges_python_sample_object_detection_sample_ssd_README"/>
     <tab type="user" title="Object Detection SSD C Sample" url="@ref openvino_inference_engine_ie_bridges_c_samples_object_detection_sample_ssd_README"/>
     <tab type="user" title="Automatic Speech Recognition C++ Sample" url="@ref openvino_inference_engine_samples_speech_sample_README"/>
+    <tab type="user" title="Automatic Speech Recognition Python Sample" url="@ref openvino_inference_engine_ie_bridges_python_sample_speech_sample_README"/>
     <tab type="user" title="Style Transfer C++ Sample" url="@ref openvino_inference_engine_samples_style_transfer_sample_README"/>
     <tab type="user" title="Style Transfer Python* Sample" url="@ref openvino_inference_engine_ie_bridges_python_sample_style_transfer_sample_README"/>
     <tab type="user" title="Benchmark C++ Tool" url="@ref openvino_inference_engine_samples_benchmark_app_README"/>
@@ -0,0 +1,205 @@ (new file: inference-engine/ie_bridges/python/sample/speech_sample/README.md)
# Automatic Speech Recognition Python* Sample {#openvino_inference_engine_ie_bridges_python_sample_speech_sample_README}

This sample demonstrates how to do synchronous inference of an acoustic model based on Kaldi\* neural networks and speech feature vectors.

The sample works with Kaldi ARK or NumPy* uncompressed NPZ files, so it does not cover an end-to-end speech recognition scenario (speech-to-text): additional preprocessing (feature extraction) is required to get a feature vector from a speech signal, and postprocessing (decoding) is required to produce text from scores.

The Automatic Speech Recognition Python sample application demonstrates how to use the following Inference Engine Python API in applications:

| Feature                 | API                                                                                                    | Description                                                                 |
| :---------------------- | :----------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------- |
| Import/Export Model     | [IECore.import_network], [ExecutableNetwork.export]                                                   | The GNA plugin supports loading and saving of the GNA-optimized model       |
| Network Operations      | [IENetwork.batch_size], [CDataPtr.shape], [ExecutableNetwork.input_info], [ExecutableNetwork.outputs] | Managing of network: configure input and output blobs                       |
| Network Operations      | [IENetwork.add_outputs]                                                                                | Managing of network: change names of output layers in the network           |
| InferRequest Operations | InferRequest.query_state, VariableState.reset                                                         | Gets and resets the state control interface for a given executable network  |

Basic Inference Engine API is covered by [Hello Classification Python* Sample](../hello_classification/README.md).
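
For instance, the state control API from the table above is what lets the sample reset network state between utterances. A minimal sketch, assuming `exec_net` is an already loaded `ExecutableNetwork`:

```python
# Reset all memory states of the executable network before inferring the next utterance
for request in exec_net.requests:
    for state in request.query_state():
        state.reset()
```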

| Options                    | Values                                                                                                 |
| :------------------------- | :------------------------------------------------------------------------------------------------------ |
| Validated Models           | Acoustic model based on Kaldi* neural networks (see [Model Preparation](#model-preparation) section)   |
| Model Format               | Inference Engine Intermediate Representation (.xml + .bin)                                             |
| Supported devices          | See [Execution Modes](#execution-modes) section below and [List Supported Devices](../../../../../docs/IE_DG/supported_plugins/Supported_Devices.md) |
| Other language realization | [C++](../../../../samples/speech_sample)                                                               |

## How It Works

At startup, the sample application reads command-line parameters, loads the specified model and input data to the Inference Engine plugin, performs synchronous inference on all speech utterances stored in the input file, and logs each step in a standard output stream.

You can see the explicit description of each sample step at the [Integration Steps](../../../../../docs/IE_DG/Integrate_with_customer_application_new_API.md) section of the "Integrate the Inference Engine with Your Application" guide.

## GNA-Specific Details

### Quantization

If the GNA device is selected (for example, using the `-d GNA` flag), the GNA Inference Engine plugin quantizes the model and input feature vector sequence to integer representation before performing inference.

The `-qb` flag provides a hint to the GNA plugin regarding the preferred target weight resolution for all layers. For example, when `-qb 8` is specified, the plugin will use 8-bit weights wherever possible in the network.

> **NOTE**:
>
> - It is not always possible to use 8-bit weights due to GNA hardware limitations. For example, convolutional layers always use 16-bit weights (GNA hardware version 1 and 2). This limitation will be removed in GNA hardware version 3 and higher.
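
Under the hood, the `-qb` hint becomes a GNA plugin configuration key. A minimal sketch of what `-qb 8` translates to, assuming `ie` and `net` are set up as in the sample:

```python
plugin_config = {
    'GNA_DEVICE_MODE': 'GNA_AUTO',
    'GNA_PRECISION': 'I8',  # '-qb 8'; 'I16' for the default '-qb 16'
}
exec_net = ie.load_network(net, 'GNA', plugin_config)
```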

### Execution Modes

Several execution modes are supported via the `-d` flag (a sketch after this list shows how the sample maps the mode to the GNA plugin configuration):

- `CPU` - All calculations are performed on a CPU device using the CPU plugin.
- `GPU` - All calculations are performed on a GPU device using the GPU plugin.
- `MYRIAD` - All calculations are performed on an Intel® Neural Compute Stick 2 device using the VPU MYRIAD plugin.
- `GNA_AUTO` - GNA hardware is used if available and the driver is installed. Otherwise, the GNA device is emulated in fast-but-not-bit-exact mode.
- `GNA_HW` - GNA hardware is used if available and the driver is installed. Otherwise, an error will occur.
- `GNA_SW` - Deprecated. The GNA device is emulated in fast-but-not-bit-exact mode.
- `GNA_SW_FP32` - Substitutes parameters and calculations from low precision to floating point (FP32).
- `GNA_SW_EXACT` - The GNA device is emulated in bit-exact mode.
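
A condensed sketch of that mapping, mirroring Step 4 of `speech_sample.py` (the `-d` value here is illustrative):

```python
device = 'HETERO:GNA_SW_EXACT,CPU'  # example '-d' value

devices = device.replace('HETERO:', '').split(',')
# A bare 'GNA' (no underscore in the name) falls back to 'GNA_AUTO'
gna_device_mode = devices[0] if '_' in devices[0] else 'GNA_AUTO'
devices[0] = 'GNA'

plugin_config = {'GNA_DEVICE_MODE': gna_device_mode}
device_str = f'HETERO:{",".join(devices)}' if 'HETERO' in device else devices[0]
# device_str == 'HETERO:GNA,CPU', plugin_config['GNA_DEVICE_MODE'] == 'GNA_SW_EXACT'
```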

### Loading and Saving Models

The GNA plugin supports loading and saving of the GNA-optimized model (non-IR) via the `-rg` and `-wg` flags. This way, you can avoid the cost of full model quantization at run time (a short sketch follows the list below).

In addition to performing inference directly from a GNA model file, this option makes it possible to:

- Convert an IR format model file to a GNA format model file (`-m`, `-wg`)
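
In code, saving and loading come down to two calls, as in this sketch based on the sample's Step 4 (assuming `ie`, `net` and `plugin_config` already exist):

```python
# -m model.xml -wg model.gna: quantize once, then save the GNA-optimized model
exec_net = ie.load_network(net, 'GNA', plugin_config)
exec_net.export('model.gna')

# -rg model.gna: load the pre-quantized model directly, skipping quantization
exec_net = ie.import_network('model.gna', 'GNA', plugin_config)
```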

## Running

Run the application with the `-h` option to see the usage message:

```sh
python speech_sample.py -h
```

Usage message:

```sh
usage: speech_sample.py [-h] (-m MODEL | -rg IMPORT_GNA_MODEL) -i INPUT
                        [-o OUTPUT] [-r REFERENCE] [-d DEVICE]
                        [-bs BATCH_SIZE] [-qb QUANTIZATION_BITS]
                        [-wg EXPORT_GNA_MODEL] [-iname INPUT_LAYERS]
                        [-oname OUTPUT_LAYERS]

optional arguments:
  -m MODEL, --model MODEL
                        Path to an .xml file with a trained model (required if
                        -rg is missing).
  -rg IMPORT_GNA_MODEL, --import_gna_model IMPORT_GNA_MODEL
                        Read GNA model from file using path/filename provided
                        (required if -m is missing).

Options:
  -h, --help            Show this help message and exit.
  -i INPUT, --input INPUT
                        Required. Path to an input file (.ark or .npz).
  -o OUTPUT, --output OUTPUT
                        Optional. Output file name to save inference results (.ark or .npz).
  -r REFERENCE, --reference REFERENCE
                        Optional. Read reference score file and compare
                        scores.
  -d DEVICE, --device DEVICE
                        Optional. Specify a target device to infer on. CPU,
                        GPU, MYRIAD, GNA_AUTO, GNA_HW, GNA_SW_FP32,
                        GNA_SW_EXACT and HETERO with combination of GNA as the
                        primary device and CPU as a secondary (e.g.
                        HETERO:GNA,CPU) are supported. The sample will look
                        for a suitable plugin for device specified. Default
                        value is CPU.
  -bs BATCH_SIZE, --batch_size BATCH_SIZE
                        Optional. Batch size 1-8 (default 1).
  -qb QUANTIZATION_BITS, --quantization_bits QUANTIZATION_BITS
                        Optional. Weight bits for quantization: 8 or 16
                        (default 16).
  -wg EXPORT_GNA_MODEL, --export_gna_model EXPORT_GNA_MODEL
                        Optional. Write GNA model to file using path/filename
                        provided.
  -iname INPUT_LAYERS, --input_layers INPUT_LAYERS
                        Optional. Layer names for input blobs. The names are
                        separated with ",". Allows to change the order of
                        input layers for -i flag. Example: Input1,Input2
  -oname OUTPUT_LAYERS, --output_layers OUTPUT_LAYERS
                        Optional. Layer names for output blobs. The names are
                        separated with ",". Allows to change the order of
                        output layers for -o flag. Example:
                        Output1:port,Output2:port.
```

## Model Preparation

You can use the following Model Optimizer command to convert a Kaldi nnet1 or nnet2 neural network to Inference Engine Intermediate Representation format:

```sh
python mo.py --framework kaldi --input_model wsj_dnn5b.nnet --counts wsj_dnn5b.counts --remove_output_softmax --output_dir <OUTPUT_MODEL_DIR>
```

The following pre-trained models are available:

- wsj_dnn5b_smbr
- rm_lstm4f
- rm_cnn4a_smbr

All of them can be downloaded from [https://storage.openvinotoolkit.org/models_contrib/speech/2021.2](https://storage.openvinotoolkit.org/models_contrib/speech/2021.2).

## Speech Inference

You can do inference on Intel® processors with the GNA co-processor (or emulation library):

```sh
python speech_sample.py -d GNA_AUTO -m wsj_dnn5b.xml -i dev93_10.ark -r dev93_scores_10.ark -o result.npz
```

> **NOTES**:
>
> - Before running the sample with a trained model, make sure the model is converted to the Inference Engine format (\*.xml + \*.bin) using the [Model Optimizer tool](../../../../../docs/MO_DG/Deep_Learning_Model_Optimizer_DevGuide.md).
>
> - The sample supports input and output in the NumPy file format (.npz).
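
For example, you can pack feature matrices into an NPZ input file with NumPy alone (the utterance names and shapes below are illustrative):

```python
import numpy as np

# One array per utterance; the sample loads the file with np.load() into a dict
utt1 = np.random.rand(1294, 40).astype(np.float32)
utt2 = np.random.rand(1005, 40).astype(np.float32)
np.savez('input.npz', utt1=utt1, utt2=utt2)
```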

## Sample Output

The sample application logs each step in a standard output stream.

```sh
[ INFO ] Creating Inference Engine
[ INFO ] Reading the network: wsj_dnn5b.xml
[ INFO ] Configuring input and output blobs
[ INFO ] Using scale factor of 2175.4322417 calculated from first utterance.
[ INFO ] Loading the model to the plugin
[ INFO ] Starting inference in synchronous mode
[ INFO ] Utterance 0 (4k0c0301)
[ INFO ] Frames in utterance: 1294
[ INFO ] Total time in Infer (HW and SW): 5305.47ms
[ INFO ] max error: 0.7051839
[ INFO ] avg error: 0.0448387
[ INFO ] avg rms error: 0.0582387
[ INFO ] stdev error: 0.0371649
[ INFO ]
[ INFO ] Utterance 1 (4k0c0302)
[ INFO ] Frames in utterance: 1005
[ INFO ] Total time in Infer (HW and SW): 5031.53ms
[ INFO ] max error: 0.7575974
[ INFO ] avg error: 0.0452166
[ INFO ] avg rms error: 0.0586013
[ INFO ] stdev error: 0.0372769
[ INFO ]
...
[ INFO ] Total sample time: 38033.09ms
[ INFO ] This sample is an API example, for any performance measurements please use the dedicated benchmark_app tool
```

## See Also

- [Integrate the Inference Engine with Your Application](../../../../../docs/IE_DG/Integrate_with_customer_application_new_API.md)
- [Using Inference Engine Samples](../../../../../docs/IE_DG/Samples_Overview.md)
- [Model Downloader](@ref omz_tools_downloader_README)
- [Model Optimizer](../../../../../docs/MO_DG/Deep_Learning_Model_Optimizer_DevGuide.md)

[IENetwork.batch_size]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1IENetwork.html#a79a647cb1b49645616eaeb2ca255ef2e
[IENetwork.add_outputs]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1IENetwork.html#ae8024b07f3301d6d5de5c0d153e2e6e6
[CDataPtr.shape]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1CDataPtr.html#aa6fd459edb323d1c6215dc7a970ebf7f
[ExecutableNetwork.input_info]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1ExecutableNetwork.html#ac76a04c2918607874018d2e15a2f274f
[ExecutableNetwork.outputs]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1ExecutableNetwork.html#a4a631776df195004b1523e6ae91a65c1
[IECore.import_network]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1IECore.html#afdeac5192bb1d9e64722f1071fb0a64a
[ExecutableNetwork.export]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1ExecutableNetwork.html#afa78158252f0d8070181bafec4318413

@@ -0,0 +1,42 @@ (new file: inference-engine/ie_bridges/python/sample/speech_sample/arg_parser.py)
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import argparse


def parse_args() -> argparse.Namespace:
    """Parse and return command line arguments"""
    parser = argparse.ArgumentParser(add_help=False)
    args = parser.add_argument_group('Options')
    model = parser.add_mutually_exclusive_group(required=True)

    args.add_argument('-h', '--help', action='help', help='Show this help message and exit.')
    model.add_argument('-m', '--model', type=str,
                       help='Path to an .xml file with a trained model (required if -rg is missing).')
    model.add_argument('-rg', '--import_gna_model', type=str,
                       help='Read GNA model from file using path/filename provided (required if -m is missing).')
    args.add_argument('-i', '--input', required=True, type=str, help='Required. Path to an input file (.ark or .npz).')
    args.add_argument('-o', '--output', type=str,
                      help='Optional. Output file name to save inference results (.ark or .npz).')
    args.add_argument('-r', '--reference', type=str,
                      help='Optional. Read reference score file and compare scores.')
    args.add_argument('-d', '--device', default='CPU', type=str,
                      help='Optional. Specify a target device to infer on. '
                           'CPU, GPU, MYRIAD, GNA_AUTO, GNA_HW, GNA_SW_FP32, GNA_SW_EXACT and HETERO with combination of GNA'
                           ' as the primary device and CPU as a secondary (e.g. HETERO:GNA,CPU) are supported. '
                           'The sample will look for a suitable plugin for device specified. Default value is CPU.')
    args.add_argument('-bs', '--batch_size', default=1, type=int, help='Optional. Batch size 1-8 (default 1).')
    args.add_argument('-qb', '--quantization_bits', default=16, type=int,
                      help='Optional. Weight bits for quantization: 8 or 16 (default 16).')
    args.add_argument('-wg', '--export_gna_model', type=str,
                      help='Optional. Write GNA model to file using path/filename provided.')
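    # The following two options are hidden from the usage message on purpose (help=argparse.SUPPRESS)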
    args.add_argument('-we', '--export_embedded_gna_model', type=str, help=argparse.SUPPRESS)
    args.add_argument('-we_gen', '--embedded_gna_configuration', default='GNA1', type=str, help=argparse.SUPPRESS)
    args.add_argument('-iname', '--input_layers', type=str,
                      help='Optional. Layer names for input blobs. The names are separated with ",". '
                           'Allows to change the order of input layers for -i flag. Example: Input1,Input2')
    args.add_argument('-oname', '--output_layers', type=str,
                      help='Optional. Layer names for output blobs. The names are separated with ",". '
                           'Allows to change the order of output layers for -o flag. Example: Output1:port,Output2:port.')

    return parser.parse_args()

@@ -0,0 +1,104 @@ (new file: inference-engine/ie_bridges/python/sample/speech_sample/file_options.py)
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import logging as log
import sys
from typing import IO, Any

import numpy as np


def read_ark_file(file_name: str) -> dict:
    """Read utterance matrices from a .ark file"""
    def read_key(input_file: IO[Any]) -> str:
        """Read an identifier of an utterance matrix"""
        key = ''
        while True:
            char = input_file.read(1).decode()
            if char in ('', ' '):
                break
            else:
                key += char

        return key

    def read_matrix(input_file: IO[Any]) -> np.ndarray:
        """Read an utterance matrix"""
        header = input_file.read(5).decode()
        if 'FM' in header:
            num_of_bytes = 4
            dtype = 'float32'
        elif 'DM' in header:
            num_of_bytes = 8
            dtype = 'float64'
        else:
            log.error(f'The utterance header "{header}" does not contain information about a type of elements.')
            sys.exit(-7)
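
        # Each matrix dimension is stored as a 1-byte size marker followed by a little-endian int32,
        # so both dimensions occupy 10 bytes in total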
        _, rows, _, cols = np.frombuffer(input_file.read(10), 'int8, int32, int8, int32')[0]
        buffer = input_file.read(rows * cols * num_of_bytes)
        vector = np.frombuffer(buffer, dtype)
        matrix = np.reshape(vector, (rows, cols))

        return matrix

    utterances = {}
    with open(file_name, 'rb') as input_file:
        while True:
            key = read_key(input_file)
            if not key:
                break
            utterances[key] = read_matrix(input_file)

    return utterances


def write_ark_file(file_name: str, utterances: dict):
    """Write utterance matrices to a .ark file"""
    with open(file_name, 'wb') as output_file:
        for key, matrix in sorted(utterances.items()):
            # write a matrix key
            output_file.write(key.encode())
            output_file.write(' '.encode())
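            # '\0B' marks the start of binary-mode data in a Kaldi .ark file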
            output_file.write('\0B'.encode())

            # write a matrix precision
            if matrix.dtype == 'float32':
                output_file.write('FM '.encode())
            elif matrix.dtype == 'float64':
                output_file.write('DM '.encode())

            # write a matrix shape
            output_file.write('\04'.encode())
            output_file.write(matrix.shape[0].to_bytes(4, byteorder='little', signed=False))
            output_file.write('\04'.encode())
            output_file.write(matrix.shape[1].to_bytes(4, byteorder='little', signed=False))

            # write a matrix data
            output_file.write(matrix.tobytes())


def read_utterance_file(file_name: str) -> dict:
    """Read utterance matrices from a file"""
    file_extension = file_name.split('.')[-1]

    if file_extension == 'ark':
        return read_ark_file(file_name)
    elif file_extension == 'npz':
        return dict(np.load(file_name))
    else:
        log.error(f'The file {file_name} cannot be read. The sample supports only .ark and .npz files.')
        sys.exit(-1)


def write_utterance_file(file_name: str, utterances: dict):
    """Write utterance matrices to a file"""
    file_extension = file_name.split('.')[-1]

    if file_extension == 'ark':
        write_ark_file(file_name, utterances)
    elif file_extension == 'npz':
        np.savez(file_name, **utterances)
    else:
        log.error(f'The file {file_name} cannot be written. The sample supports only .ark and .npz files.')
        sys.exit(-2)
inference-engine/ie_bridges/python/sample/speech_sample/speech_sample.py (new executable file, 255 lines)

@@ -0,0 +1,255 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import logging as log
import re
import sys
from timeit import default_timer

import numpy as np
from arg_parser import parse_args
from file_options import read_utterance_file, write_utterance_file
from openvino.inference_engine import ExecutableNetwork, IECore


def get_scale_factor(matrix: np.ndarray) -> float:
    """Get scale factor for quantization using utterance matrix"""
    # Max to find scale factor
    target_max = 16384
    max_val = np.max(matrix)
    if max_val == 0:
        return 1.0
    else:
        return target_max / max_val


def infer_data(data: dict, exec_net: ExecutableNetwork, input_blobs: list, output_blobs: list) -> np.ndarray:
    """Do a synchronous matrix inference"""
    matrix_shape = next(iter(data.values())).shape
    result = {}

    for blob_name in output_blobs:
        batch_size, num_of_dims = exec_net.outputs[blob_name].shape
        result[blob_name] = np.ndarray((matrix_shape[0], num_of_dims))

    slice_begin = 0
    slice_end = batch_size

    while True:
        vectors = {blob_name: data[blob_name][slice_begin:slice_end] for blob_name in input_blobs}
        vector_shape = next(iter(vectors.values())).shape
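
        # The last slice may contain fewer vectors than the batch size, so pad it with zero (dummy) vectors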
        if vector_shape[0] < batch_size:
            temp = {blob_name: np.zeros((batch_size, vector_shape[1])) for blob_name in input_blobs}

            for blob_name in input_blobs:
                temp[blob_name][:vector_shape[0]] = vectors[blob_name]

            vectors = temp

        vector_results = exec_net.infer(vectors)

        for blob_name in output_blobs:
            result[blob_name][slice_begin:slice_end] = vector_results[blob_name][:vector_shape[0]]

        slice_begin += batch_size
        slice_end += batch_size

        if slice_begin >= matrix_shape[0]:
            return result


def compare_with_reference(result: np.ndarray, reference: np.ndarray):
    error_matrix = np.absolute(result - reference)

    max_error = np.max(error_matrix)
    sum_error = np.sum(error_matrix)
    avg_error = sum_error / error_matrix.size
    sum_square_error = np.sum(np.square(error_matrix))
    avg_rms_error = np.sqrt(sum_square_error / error_matrix.size)
    stdev_error = np.sqrt(sum_square_error / error_matrix.size - avg_error * avg_error)

    log.info(f'max error: {max_error:.7f}')
    log.info(f'avg error: {avg_error:.7f}')
    log.info(f'avg rms error: {avg_rms_error:.7f}')
    log.info(f'stdev error: {stdev_error:.7f}')

def main():
    log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)
    args = parse_args()

    # ---------------------------Step 1. Initialize inference engine core--------------------------------------------------
    log.info('Creating Inference Engine')
    ie = IECore()

    # ---------------------------Step 2. Read a model in OpenVINO Intermediate Representation---------------
    if args.model:
        log.info(f'Reading the network: {args.model}')
        # .xml and .bin files
        net = ie.read_network(model=args.model)

        # ---------------------------Step 3. Configure input & output----------------------------------------------------------
        log.info('Configuring input and output blobs')
        # Get names of input and output blobs
        if args.input_layers:
            input_blobs = re.split(', |,', args.input_layers)
        else:
            input_blobs = [next(iter(net.input_info))]

        if args.output_layers:
            output_name_port = [output.split(':') for output in re.split(', |,', args.output_layers)]
            try:
                output_name_port = [(blob_name, int(port)) for blob_name, port in output_name_port]
            except ValueError:
                log.error('Output Parameter does not have a port.')
                sys.exit(-4)

            net.add_outputs(output_name_port)

            output_blobs = [blob_name for blob_name, port in output_name_port]
        else:
            output_blobs = [list(net.outputs.keys())[-1]]

        # Set input and output precision manually
        for blob_name in input_blobs:
            net.input_info[blob_name].precision = 'FP32'

        for blob_name in output_blobs:
            net.outputs[blob_name].precision = 'FP32'

        net.batch_size = args.batch_size

    # ---------------------------Step 4. Loading model to the device-------------------------------------------------------
    devices = args.device.replace('HETERO:', '').split(',')
    plugin_config = {}

    if 'GNA' in args.device:
        gna_device_mode = devices[0] if '_' in devices[0] else 'GNA_AUTO'
        devices[0] = 'GNA'

        plugin_config['GNA_DEVICE_MODE'] = gna_device_mode
        plugin_config['GNA_PRECISION'] = f'I{args.quantization_bits}'

        # Get a GNA scale factor
        if args.import_gna_model:
            log.info(f'Using scale factor from the imported GNA model: {args.import_gna_model}')
        else:
            utterances = read_utterance_file(args.input.split(',')[0])
            key = sorted(utterances)[0]
            scale_factor = get_scale_factor(utterances[key])
            log.info(f'Using scale factor of {scale_factor:.7f} calculated from first utterance.')

            plugin_config['GNA_SCALE_FACTOR'] = str(scale_factor)

        if args.export_embedded_gna_model:
            plugin_config['GNA_FIRMWARE_MODEL_IMAGE'] = args.export_embedded_gna_model
            plugin_config['GNA_FIRMWARE_MODEL_IMAGE_GENERATION'] = args.embedded_gna_configuration

    device_str = f'HETERO:{",".join(devices)}' if 'HETERO' in args.device else devices[0]

    log.info('Loading the model to the plugin')
    if args.model:
        exec_net = ie.load_network(net, device_str, plugin_config)
    else:
        exec_net = ie.import_network(args.import_gna_model, device_str, plugin_config)
        input_blobs = [next(iter(exec_net.input_info))]
        output_blobs = [list(exec_net.outputs.keys())[-1]]

    if args.input:
        input_files = re.split(', |,', args.input)

        if len(input_blobs) != len(input_files):
            log.error(f'Number of network inputs ({len(input_blobs)}) is not equal '
                      f'to number of ark files ({len(input_files)})')
            sys.exit(-3)

    if args.reference:
        reference_files = re.split(', |,', args.reference)

        if len(output_blobs) != len(reference_files):
            log.error('The number of reference files is not equal to the number of network outputs.')
            sys.exit(-5)

    if args.output:
        output_files = re.split(', |,', args.output)

        if len(output_blobs) != len(output_files):
            log.error('The number of output files is not equal to the number of network outputs.')
            sys.exit(-6)

    if args.export_gna_model:
        log.info(f'Writing GNA Model to {args.export_gna_model}')
        exec_net.export(args.export_gna_model)
        return 0

    if args.export_embedded_gna_model:
        log.info(f'Exported GNA embedded model to file {args.export_embedded_gna_model}')
        log.info(f'GNA embedded model export done for GNA generation {args.embedded_gna_configuration}')
        return 0

    # ---------------------------Step 5. Create infer request--------------------------------------------------------------
    # load_network() method of the IECore class with a specified number of requests (default 1) returns an ExecutableNetwork
    # instance which stores infer requests. So you already created Infer requests in the previous step.

    # ---------------------------Step 6. Prepare input---------------------------------------------------------------------
    file_data = [read_utterance_file(file_name) for file_name in input_files]
    input_data = {
        utterance_name: {
            input_blobs[i]: file_data[i][utterance_name] for i in range(len(input_blobs))
        }
        for utterance_name in file_data[0].keys()
    }

    if args.reference:
        references = {output_blobs[i]: read_utterance_file(reference_files[i]) for i in range(len(output_blobs))}

    # ---------------------------Step 7. Do inference----------------------------------------------------------------------
    log.info('Starting inference in synchronous mode')
    results = {blob_name: {} for blob_name in output_blobs}
    infer_times = []

    for key in sorted(input_data):
        start_infer_time = default_timer()

        # Reset states between utterance inferences to remove a memory impact
        for request in exec_net.requests:
            for state in request.query_state():
                state.reset()

        result = infer_data(input_data[key], exec_net, input_blobs, output_blobs)

        for blob_name in result.keys():
            results[blob_name][key] = result[blob_name]

        infer_times.append(default_timer() - start_infer_time)

    # ---------------------------Step 8. Process output--------------------------------------------------------------------
    for blob_name in output_blobs:
        for i, key in enumerate(sorted(results[blob_name])):
            log.info(f'Utterance {i} ({key})')
            log.info(f'Output blob name: {blob_name}')
            log.info(f'Frames in utterance: {results[blob_name][key].shape[0]}')
            log.info(f'Total time in Infer (HW and SW): {infer_times[i] * 1000:.2f}ms')

            if args.reference:
                compare_with_reference(results[blob_name][key], references[blob_name][key])

            log.info('')

    log.info(f'Total sample time: {sum(infer_times) * 1000:.2f}ms')

    if args.output:
        for i, blob_name in enumerate(results):
            write_utterance_file(output_files[i], results[blob_name])
            log.info(f'File {output_files[i]} was created!')

    # ----------------------------------------------------------------------------------------------------------------------
    log.info('This sample is an API example, '
             'for any performance measurements please use the dedicated benchmark_app tool\n')
    return 0


if __name__ == '__main__':
    sys.exit(main())
(in inference-engine/samples/speech_sample/README.md)

@@ -22,7 +22,7 @@ Basic Inference Engine API is covered by [Hello Classification C++ sample](../he
 | Options | Values |
 |:--- |:---
 | Validated Models | Acoustic model based on Kaldi\* neural networks (see [Model Preparation](#model-preparation) section)
-| Model Format | Inference Engine Intermediate Representation (\*.xml + \*.bin), ONNX (\*.onnx)
+| Model Format | Inference Engine Intermediate Representation (\*.xml + \*.bin)
 | Supported devices | See [Execution Modes](#execution-modes) section below and [List Supported Devices](../../../docs/IE_DG/supported_plugins/Supported_Devices.md) |

 ## How It Works
@@ -164,8 +164,7 @@ All of them can be downloaded from [https://storage.openvinotoolkit.org/models_c
 > **NOTES**:
 >
 > - Before running the sample with a trained model, make sure the model is converted to the Inference Engine format (\*.xml + \*.bin) using the [Model Optimizer tool](../../../docs/MO_DG/Deep_Learning_Model_Optimizer_DevGuide.md).
 >
-> - The sample accepts models in ONNX format (.onnx) that do not require preprocessing.

 ## Sample Output
(in inference-engine/samples/speech_sample/main.cpp)

@@ -681,7 +681,7 @@ int main(int argc, char* argv[]) {
     std::cout << ie.GetVersions(deviceStr) << std::endl;
     // -----------------------------------------------------------------------------------------------------

-    // --------------------------- Step 2. Read a model in OpenVINO Intermediate Representation (.xml and .bin files) or ONNX (.onnx file) format
+    // --------------------------- Step 2. Read a model in OpenVINO Intermediate Representation (.xml and .bin files)
     slog::info << "Loading network files:" << slog::endl << FLAGS_m << slog::endl;

     uint32_t batchSize = (FLAGS_cw_r > 0 || FLAGS_cw_l > 0) ? 1 : (uint32_t)FLAGS_bs;