diff --git a/docs/IE_DG/Samples_Overview.md b/docs/IE_DG/Samples_Overview.md
index 11bfb1d51c3..db39cbfc5b4 100644
--- a/docs/IE_DG/Samples_Overview.md
+++ b/docs/IE_DG/Samples_Overview.md
@@ -9,7 +9,9 @@ After installation of Intel® Distribution of OpenVINO™ toolkit, С, C++ and P
 Inference Engine sample applications include the following:
 
-- **[Automatic Speech Recognition C++ Sample](../../inference-engine/samples/speech_sample/README.md)** – Acoustic model inference based on Kaldi neural networks and speech feature vectors.
+- **Speech Sample** – Acoustic model inference based on Kaldi neural networks and speech feature vectors.
+  - [Automatic Speech Recognition C++ Sample](../../inference-engine/samples/speech_sample/README.md)
+  - [Automatic Speech Recognition Python Sample](../../inference-engine/ie_bridges/python/sample/speech_sample/README.md)
 - **Benchmark Application** – Estimates deep learning inference performance on supported devices for synchronous and asynchronous modes.
    - [Benchmark C++ Tool](../../inference-engine/samples/benchmark_app/README.md)
    - [Benchmark Python Tool](../../inference-engine/tools/benchmark_tool/README.md)
diff --git a/docs/doxygen/openvino_docs.xml b/docs/doxygen/openvino_docs.xml
index 34539a41246..a47a08e6777 100644
--- a/docs/doxygen/openvino_docs.xml
+++ b/docs/doxygen/openvino_docs.xml
@@ -180,6 +180,7 @@ limitations under the License.
+
diff --git a/inference-engine/ie_bridges/python/sample/speech_sample/README.md b/inference-engine/ie_bridges/python/sample/speech_sample/README.md
new file mode 100644
index 00000000000..6f02696e085
--- /dev/null
+++ b/inference-engine/ie_bridges/python/sample/speech_sample/README.md
@@ -0,0 +1,205 @@
+# Automatic Speech Recognition Python* Sample {#openvino_inference_engine_ie_bridges_python_sample_speech_sample_README}
+
+This sample demonstrates how to do synchronous inference of an acoustic model based on Kaldi\* neural networks and speech feature vectors.
+
+The sample works with Kaldi ARK or NumPy* uncompressed NPZ files, so it does not cover an end-to-end speech recognition scenario (speech to text): it requires additional preprocessing (feature extraction) to get a feature vector from a speech signal, as well as postprocessing (decoding) to produce text from scores.
+
+The Automatic Speech Recognition Python sample application demonstrates how to use the following Inference Engine Python API in applications:
+
+| Feature                 | API                                                                                                    | Description                                                            |
+| :---------------------- | :----------------------------------------------------------------------------------------------------- | :---------------------------------------------------------------------- |
+| Import/Export Model     | [IECore.import_network], [ExecutableNetwork.export]                                                   | The GNA plugin supports loading and saving of the GNA-optimized model |
+| Network Operations      | [IENetwork.batch_size], [CDataPtr.shape], [ExecutableNetwork.input_info], [ExecutableNetwork.outputs] | Manage the network: configure input and output blobs                  |
+| Network Operations      | [IENetwork.add_outputs]                                                                                | Manage the network: change the set of output layers in the network    |
+| InferRequest Operations | InferRequest.query_state, VariableState.reset                                                         | Get and reset the state control interface of the executable network   |
+
+Basic Inference Engine API is covered by [Hello Classification Python* Sample](../hello_classification/README.md).
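+
+For reference, the following minimal sketch (not part of the sample) shows how these Python API calls fit together; the model file, layer name, and feature values are placeholders, and error handling is omitted:
+
+```python
+import numpy as np
+from openvino.inference_engine import IECore
+
+ie = IECore()
+net = ie.read_network(model='wsj_dnn5b.xml')                # IR produced by the Model Optimizer
+
+net.batch_size = 8                                          # IENetwork.batch_size
+# net.add_outputs('affinetransform12')                      # IENetwork.add_outputs (hypothetical layer name)
+
+input_name = next(iter(net.input_info))
+output_name = list(net.outputs)[-1]
+
+exec_net = ie.load_network(net, 'GNA', {'GNA_DEVICE_MODE': 'GNA_AUTO'})
+
+# ExecutableNetwork.input_info / ExecutableNetwork.outputs describe the loaded blobs;
+# CDataPtr.shape of the input blob is (batch_size, number_of_features)
+batch, input_size = exec_net.input_info[input_name].input_data.shape
+features = np.zeros((batch, input_size), dtype=np.float32)  # placeholder feature vectors
+scores = exec_net.infer({input_name: features})[output_name]
+
+# InferRequest.query_state / VariableState.reset: reset GNA states between utterances
+for request in exec_net.requests:
+    for state in request.query_state():
+        state.reset()
+```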
+
+| Options                    | Values                                                                                                |
+| :------------------------- | :---------------------------------------------------------------------------------------------------- |
+| Validated Models           | Acoustic model based on Kaldi* neural networks (see [Model Preparation](#model-preparation) section)  |
+| Model Format               | Inference Engine Intermediate Representation (.xml + .bin)                                            |
+| Supported devices          | See [Execution Modes](#execution-modes) section below and [List Supported Devices](../../../../../docs/IE_DG/supported_plugins/Supported_Devices.md) |
+| Other language realization | [C++](../../../../samples/speech_sample)                                                              |
+
+## How It Works
+
+At startup, the sample application reads command-line parameters, loads the specified model and input data to the Inference Engine plugin, and performs synchronous inference on all speech utterances stored in the input file, logging each step in a standard output stream.
+
+You can see the explicit description of each sample step in the [Integration Steps](../../../../../docs/IE_DG/Integrate_with_customer_application_new_API.md) section of the "Integrate the Inference Engine with Your Application" guide.
+
+## GNA-specific details
+
+### Quantization
+
+If the GNA device is selected (for example, using the `-d GNA_AUTO` flag), the GNA Inference Engine plugin quantizes the model and input feature vector sequence to integer representation before performing inference.
+
+The `-qb` flag provides a hint to the GNA plugin regarding the preferred target weight resolution for all layers.
+For example, when `-qb 8` is specified, the plugin will use 8-bit weights wherever possible in the network.
+
+> **NOTE**:
+>
+> - It is not always possible to use 8-bit weights due to GNA hardware limitations. For example, convolutional layers always use 16-bit weights (GNA hardware versions 1 and 2). This limitation will be removed in GNA hardware version 3 and higher.
+>
+
+### Execution Modes
+
+Several execution modes are supported via the `-d` flag:
+
+- `CPU` - All calculations are performed on the CPU device using the CPU Plugin.
+- `GPU` - All calculations are performed on the GPU device using the GPU Plugin.
+- `MYRIAD` - All calculations are performed on the Intel® Neural Compute Stick 2 device using the VPU MYRIAD Plugin.
+- `GNA_AUTO` - GNA hardware is used if available and the driver is installed. Otherwise, the GNA device is emulated in fast-but-not-bit-exact mode.
+- `GNA_HW` - GNA hardware is used if available and the driver is installed. Otherwise, an error will occur.
+- `GNA_SW` - Deprecated. The GNA device is emulated in fast-but-not-bit-exact mode.
+- `GNA_SW_FP32` - Substitutes parameters and calculations from low precision to floating point (FP32).
+- `GNA_SW_EXACT` - The GNA device is emulated in bit-exact mode.
+
+### Loading and Saving Models
+
+The GNA plugin supports loading and saving of the GNA-optimized model (non-IR) via the `-rg` and `-wg` flags.
+This makes it possible to avoid the cost of full model quantization at run time.
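+
+For reference, a minimal sketch (not part of the sample) of what the `-wg` and `-rg` flags do through the Python API, using placeholder file names:
+
+```python
+from openvino.inference_engine import IECore
+
+ie = IECore()
+gna_config = {'GNA_DEVICE_MODE': 'GNA_SW_EXACT'}
+
+# -m together with -wg: quantize and load the IR model once, then save the GNA-optimized blob
+net = ie.read_network(model='wsj_dnn5b.xml')
+exec_net = ie.load_network(net, 'GNA', gna_config)
+exec_net.export('wsj_dnn5b.gna')                                   # ExecutableNetwork.export
+
+# -rg on later runs: import the saved blob directly, skipping quantization of the IR model
+exec_net = ie.import_network('wsj_dnn5b.gna', 'GNA', gna_config)   # IECore.import_network
+```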
+ +In addition to performing inference directly from a GNA model file, this option makes it possible to: + +- Convert from IR format to GNA format model file (`-m`, `-wg`) + +## Running + +Run the application with the -h option to see the usage message: + +```sh +python speech_sample.py -h +``` + +Usage message: + +```sh +usage: speech_sample.py [-h] (-m MODEL | -rg IMPORT_GNA_MODEL) -i INPUT + [-o OUTPUT] [-r REFERENCE] [-d DEVICE] + [-bs BATCH_SIZE] [-qb QUANTIZATION_BITS] + [-wg EXPORT_GNA_MODEL] [-iname INPUT_LAYERS] + [-oname OUTPUT_LAYERS] + +optional arguments: + -m MODEL, --model MODEL + Path to an .xml file with a trained model (required if + -rg is missing). + -rg IMPORT_GNA_MODEL, --import_gna_model IMPORT_GNA_MODEL + Read GNA model from file using path/filename provided + (required if -m is missing). + +Options: + -h, --help Show this help message and exit. + -i INPUT, --input INPUT + Required. Path to an input file (.ark or .npz). + -o OUTPUT, --output OUTPUT + Optional. Output file name to save inference results (.ark or .npz). + -r REFERENCE, --reference REFERENCE + Optional. Read reference score file and compare + scores. + -d DEVICE, --device DEVICE + Optional. Specify a target device to infer on. CPU, + GPU, MYRIAD, GNA_AUTO, GNA_HW, GNA_SW_FP32, + GNA_SW_EXACT and HETERO with combination of GNA as the + primary device and CPU as a secondary (e.g. + HETERO:GNA,CPU) are supported. The sample will look + for a suitable plugin for device specified. Default + value is CPU. + -bs BATCH_SIZE, --batch_size BATCH_SIZE + Optional. Batch size 1-8 (default 1). + -qb QUANTIZATION_BITS, --quantization_bits QUANTIZATION_BITS + Optional. Weight bits for quantization: 8 or 16 + (default 16). + -wg EXPORT_GNA_MODEL, --export_gna_model EXPORT_GNA_MODEL + Optional. Write GNA model to file using path/filename + provided. + -iname INPUT_LAYERS, --input_layers INPUT_LAYERS + Optional. Layer names for input blobs. The names are + separated with ",". Allows to change the order of + input layers for -i flag. Example: Input1,Input2 + -oname OUTPUT_LAYERS, --output_layers OUTPUT_LAYERS + Optional. Layer names for output blobs. The names are + separated with ",". Allows to change the order of + output layers for -o flag. Example: + Output1:port,Output2:port. +``` + +## Model Preparation + +You can use the following model optimizer command to convert a Kaldi nnet1 or nnet2 neural network to Inference Engine Intermediate Representation format: + +```sh +python mo.py --framework kaldi --input_model wsj_dnn5b.nnet --counts wsj_dnn5b.counts --remove_output_softmax --output_dir +``` + +The following pre-trained models are available: + +- wsj_dnn5b_smbr +- rm_lstm4f +- rm_cnn4a_smbr + +All of them can be downloaded from [https://storage.openvinotoolkit.org/models_contrib/speech/2021.2](https://storage.openvinotoolkit.org/models_contrib/speech/2021.2). + +## Speech Inference + +You can do inference on Intel® Processors with the GNA co-processor (or emulation library): + +```sh +python speech_sample.py -d GNA_AUTO -m wsj_dnn5b.xml -i dev93_10.ark -r dev93_scores_10.ark -o result.npz +``` + +> **NOTES**: +> +> - Before running the sample with a trained model, make sure the model is converted to the Inference Engine format (\*.xml + \*.bin) using the [Model Optimizer tool](../../../../../docs/MO_DG/Deep_Learning_Model_Optimizer_DevGuide.md). +> +> - The sample supports input and output in numpy file format (.npz) + +## Sample Output + +The sample application logs each step in a standard output stream. 
+ +```sh +[ INFO ] Creating Inference Engine +[ INFO ] Reading the network: wsj_dnn5b.xml +[ INFO ] Configuring input and output blobs +[ INFO ] Using scale factor of 2175.4322417 calculated from first utterance. +[ INFO ] Loading the model to the plugin +[ INFO ] Starting inference in synchronous mode +[ INFO ] Utterance 0 (4k0c0301) +[ INFO ] Frames in utterance: 1294 +[ INFO ] Total time in Infer (HW and SW): 5305.47ms +[ INFO ] max error: 0.7051839 +[ INFO ] avg error: 0.0448387 +[ INFO ] avg rms error: 0.0582387 +[ INFO ] stdev error: 0.0371649 +[ INFO ] +[ INFO ] Utterance 1 (4k0c0302) +[ INFO ] Frames in utterance: 1005 +[ INFO ] Total time in Infer (HW and SW): 5031.53ms +[ INFO ] max error: 0.7575974 +[ INFO ] avg error: 0.0452166 +[ INFO ] avg rms error: 0.0586013 +[ INFO ] stdev error: 0.0372769 +[ INFO ] +... +[ INFO ] Total sample time: 38033.09ms +[ INFO ] This sample is an API example, for any performance measurements please use the dedicated benchmark_app tool +``` + +## See Also + +- [Integrate the Inference Engine with Your Application](../../../../../docs/IE_DG/Integrate_with_customer_application_new_API.md) +- [Using Inference Engine Samples](../../../../../docs/IE_DG/Samples_Overview.md) +- [Model Downloader](@ref omz_tools_downloader_README) +- [Model Optimizer](../../../../../docs/MO_DG/Deep_Learning_Model_Optimizer_DevGuide.md) + +[IENetwork.batch_size]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1IENetwork.html#a79a647cb1b49645616eaeb2ca255ef2e +[IENetwork.add_outputs]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1IENetwork.html#ae8024b07f3301d6d5de5c0d153e2e6e6 +[CDataPtr.shape]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1CDataPtr.html#aa6fd459edb323d1c6215dc7a970ebf7f +[ExecutableNetwork.input_info]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1ExecutableNetwork.html#ac76a04c2918607874018d2e15a2f274f +[ExecutableNetwork.outputs]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1ExecutableNetwork.html#a4a631776df195004b1523e6ae91a65c1 +[IECore.import_network]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1IECore.html#afdeac5192bb1d9e64722f1071fb0a64a +[ExecutableNetwork.export]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1ExecutableNetwork.html#afa78158252f0d8070181bafec4318413 \ No newline at end of file diff --git a/inference-engine/ie_bridges/python/sample/speech_sample/arg_parser.py b/inference-engine/ie_bridges/python/sample/speech_sample/arg_parser.py new file mode 100644 index 00000000000..cfc20dfb425 --- /dev/null +++ b/inference-engine/ie_bridges/python/sample/speech_sample/arg_parser.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2018-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +import argparse + + +def parse_args() -> argparse.Namespace: + """Parse and return command line arguments""" + parser = argparse.ArgumentParser(add_help=False) + args = parser.add_argument_group('Options') + model = parser.add_mutually_exclusive_group(required=True) + + args.add_argument('-h', '--help', action='help', help='Show this help message and exit.') + model.add_argument('-m', '--model', type=str, + help='Path to an .xml file with a trained model (required if -rg is missing).') + model.add_argument('-rg', '--import_gna_model', type=str, + help='Read GNA model from file using path/filename provided (required if -m is missing).') + args.add_argument('-i', '--input', 
required=True, type=str, help='Required. Path to an input file (.ark or .npz).') + args.add_argument('-o', '--output', type=str, + help='Optional. Output file name to save inference results (.ark or .npz).') + args.add_argument('-r', '--reference', type=str, + help='Optional. Read reference score file and compare scores.') + args.add_argument('-d', '--device', default='CPU', type=str, + help='Optional. Specify a target device to infer on. ' + 'CPU, GPU, MYRIAD, GNA_AUTO, GNA_HW, GNA_SW_FP32, GNA_SW_EXACT and HETERO with combination of GNA' + ' as the primary device and CPU as a secondary (e.g. HETERO:GNA,CPU) are supported. ' + 'The sample will look for a suitable plugin for device specified. Default value is CPU.') + args.add_argument('-bs', '--batch_size', default=1, type=int, help='Optional. Batch size 1-8 (default 1).') + args.add_argument('-qb', '--quantization_bits', default=16, type=int, + help='Optional. Weight bits for quantization: 8 or 16 (default 16).') + args.add_argument('-wg', '--export_gna_model', type=str, + help='Optional. Write GNA model to file using path/filename provided.') + args.add_argument('-we', '--export_embedded_gna_model', type=str, help=argparse.SUPPRESS) + args.add_argument('-we_gen', '--embedded_gna_configuration', default='GNA1', type=str, help=argparse.SUPPRESS) + args.add_argument('-iname', '--input_layers', type=str, + help='Optional. Layer names for input blobs. The names are separated with ",". ' + 'Allows to change the order of input layers for -i flag. Example: Input1,Input2') + args.add_argument('-oname', '--output_layers', type=str, + help='Optional. Layer names for output blobs. The names are separated with ",". ' + 'Allows to change the order of output layers for -o flag. Example: Output1:port,Output2:port.') + + return parser.parse_args() diff --git a/inference-engine/ie_bridges/python/sample/speech_sample/file_options.py b/inference-engine/ie_bridges/python/sample/speech_sample/file_options.py new file mode 100644 index 00000000000..f0d911343d9 --- /dev/null +++ b/inference-engine/ie_bridges/python/sample/speech_sample/file_options.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2018-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +import logging as log +import sys +from typing import IO, Any + +import numpy as np + + +def read_ark_file(file_name: str) -> dict: + """Read utterance matrices from a .ark file""" + def read_key(input_file: IO[Any]) -> str: + """Read a identifier of utterance matrix""" + key = '' + while True: + char = input_file.read(1).decode() + if char in ('', ' '): + break + else: + key += char + + return key + + def read_matrix(input_file: IO[Any]) -> np.ndarray: + """Read a utterance matrix""" + header = input_file.read(5).decode() + if 'FM' in header: + num_of_bytes = 4 + dtype = 'float32' + elif 'DM' in header: + num_of_bytes = 8 + dtype = 'float64' + else: + log.error(f'The utterance header "{header}" does not contain information about a type of elements.') + sys.exit(-7) + + _, rows, _, cols = np.frombuffer(input_file.read(10), 'int8, int32, int8, int32')[0] + buffer = input_file.read(rows * cols * num_of_bytes) + vector = np.frombuffer(buffer, dtype) + matrix = np.reshape(vector, (rows, cols)) + + return matrix + + utterances = {} + with open(file_name, 'rb') as input_file: + while True: + key = read_key(input_file) + if not key: + break + utterances[key] = read_matrix(input_file) + + return utterances + + +def write_ark_file(file_name: str, utterances: dict): + """Write utterance matrices 
to a .ark file""" + with open(file_name, 'wb') as output_file: + for key, matrix in sorted(utterances.items()): + # write a matrix key + output_file.write(key.encode()) + output_file.write(' '.encode()) + output_file.write('\0B'.encode()) + + # write a matrix precision + if matrix.dtype == 'float32': + output_file.write('FM '.encode()) + elif matrix.dtype == 'float64': + output_file.write('DM '.encode()) + + # write a matrix shape + output_file.write('\04'.encode()) + output_file.write(matrix.shape[0].to_bytes(4, byteorder='little', signed=False)) + output_file.write('\04'.encode()) + output_file.write(matrix.shape[1].to_bytes(4, byteorder='little', signed=False)) + + # write a matrix data + output_file.write(matrix.tobytes()) + + +def read_utterance_file(file_name: str) -> dict: + """Read utterance matrices from a file""" + file_extension = file_name.split('.')[-1] + + if file_extension == 'ark': + return read_ark_file(file_name) + elif file_extension == 'npz': + return dict(np.load(file_name)) + else: + log.error(f'The file {file_name} cannot be read. The sample supports only .ark and .npz files.') + sys.exit(-1) + + +def write_utterance_file(file_name: str, utterances: dict): + """Write utterance matrices to a file""" + file_extension = file_name.split('.')[-1] + + if file_extension == 'ark': + write_ark_file(file_name, utterances) + elif file_extension == 'npz': + np.savez(file_name, **utterances) + else: + log.error(f'The file {file_name} cannot be written. The sample supports only .ark and .npz files.') + sys.exit(-2) diff --git a/inference-engine/ie_bridges/python/sample/speech_sample/speech_sample.py b/inference-engine/ie_bridges/python/sample/speech_sample/speech_sample.py new file mode 100755 index 00000000000..858223f5e40 --- /dev/null +++ b/inference-engine/ie_bridges/python/sample/speech_sample/speech_sample.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (C) 2018-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +import logging as log +import re +import sys +from timeit import default_timer + +import numpy as np +from arg_parser import parse_args +from file_options import read_utterance_file, write_utterance_file +from openvino.inference_engine import ExecutableNetwork, IECore + + +def get_scale_factor(matrix: np.ndarray) -> float: + """Get scale factor for quantization using utterance matrix""" + # Max to find scale factor + target_max = 16384 + max_val = np.max(matrix) + if max_val == 0: + return 1.0 + else: + return target_max / max_val + + +def infer_data(data: dict, exec_net: ExecutableNetwork, input_blobs: list, output_blobs: list) -> np.ndarray: + """Do a synchronous matrix inference""" + matrix_shape = next(iter(data.values())).shape + result = {} + + for blob_name in output_blobs: + batch_size, num_of_dims = exec_net.outputs[blob_name].shape + result[blob_name] = np.ndarray((matrix_shape[0], num_of_dims)) + + slice_begin = 0 + slice_end = batch_size + + while True: + vectors = {blob_name: data[blob_name][slice_begin:slice_end] for blob_name in input_blobs} + vector_shape = next(iter(vectors.values())).shape + + if vector_shape[0] < batch_size: + temp = {blob_name: np.zeros((batch_size, vector_shape[1])) for blob_name in input_blobs} + + for blob_name in input_blobs: + temp[blob_name][:vector_shape[0]] = vectors[blob_name] + + vectors = temp + + vector_results = exec_net.infer(vectors) + + for blob_name in output_blobs: + result[blob_name][slice_begin:slice_end] = vector_results[blob_name][:vector_shape[0]] + + 
slice_begin += batch_size + slice_end += batch_size + + if slice_begin >= matrix_shape[0]: + return result + + +def compare_with_reference(result: np.ndarray, reference: np.ndarray): + error_matrix = np.absolute(result - reference) + + max_error = np.max(error_matrix) + sum_error = np.sum(error_matrix) + avg_error = sum_error / error_matrix.size + sum_square_error = np.sum(np.square(error_matrix)) + avg_rms_error = np.sqrt(sum_square_error / error_matrix.size) + stdev_error = np.sqrt(sum_square_error / error_matrix.size - avg_error * avg_error) + + log.info(f'max error: {max_error:.7f}') + log.info(f'avg error: {avg_error:.7f}') + log.info(f'avg rms error: {avg_rms_error:.7f}') + log.info(f'stdev error: {stdev_error:.7f}') + + +def main(): + log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout) + args = parse_args() + +# ---------------------------Step 1. Initialize inference engine core-------------------------------------------------- + log.info('Creating Inference Engine') + ie = IECore() + +# ---------------------------Step 2. Read a model in OpenVINO Intermediate Representation--------------- + if args.model: + log.info(f'Reading the network: {args.model}') + # .xml and .bin files + net = ie.read_network(model=args.model) + +# ---------------------------Step 3. Configure input & output---------------------------------------------------------- + log.info('Configuring input and output blobs') + # Get names of input and output blobs + if args.input_layers: + input_blobs = re.split(', |,', args.input_layers) + else: + input_blobs = [next(iter(net.input_info))] + + if args.output_layers: + output_name_port = [output.split(':') for output in re.split(', |,', args.output_layers)] + try: + output_name_port = [(blob_name, int(port)) for blob_name, port in output_name_port] + except ValueError: + log.error('Output Parameter does not have a port.') + sys.exit(-4) + + net.add_outputs(output_name_port) + + output_blobs = [blob_name for blob_name, port in output_name_port] + else: + output_blobs = [list(net.outputs.keys())[-1]] + + # Set input and output precision manually + for blob_name in input_blobs: + net.input_info[blob_name].precision = 'FP32' + + for blob_name in output_blobs: + net.outputs[blob_name].precision = 'FP32' + + net.batch_size = args.batch_size + +# ---------------------------Step 4. 
Loading model to the device------------------------------------------------------- + devices = args.device.replace('HETERO:', '').split(',') + plugin_config = {} + + if 'GNA' in args.device: + gna_device_mode = devices[0] if '_' in devices[0] else 'GNA_AUTO' + devices[0] = 'GNA' + + plugin_config['GNA_DEVICE_MODE'] = gna_device_mode + plugin_config['GNA_PRECISION'] = f'I{args.quantization_bits}' + + # Get a GNA scale factor + if args.import_gna_model: + log.info(f'Using scale factor from the imported GNA model: {args.import_gna_model}') + else: + utterances = read_utterance_file(args.input.split(',')[0]) + key = sorted(utterances)[0] + scale_factor = get_scale_factor(utterances[key]) + log.info(f'Using scale factor of {scale_factor:.7f} calculated from first utterance.') + + plugin_config['GNA_SCALE_FACTOR'] = str(scale_factor) + + if args.export_embedded_gna_model: + plugin_config['GNA_FIRMWARE_MODEL_IMAGE'] = args.export_embedded_gna_model + plugin_config['GNA_FIRMWARE_MODEL_IMAGE_GENERATION'] = args.embedded_gna_configuration + + device_str = f'HETERO:{",".join(devices)}' if 'HETERO' in args.device else devices[0] + + log.info('Loading the model to the plugin') + if args.model: + exec_net = ie.load_network(net, device_str, plugin_config) + else: + exec_net = ie.import_network(args.import_gna_model, device_str, plugin_config) + input_blobs = [next(iter(exec_net.input_info))] + output_blobs = [list(exec_net.outputs.keys())[-1]] + + if args.input: + input_files = re.split(', |,', args.input) + + if len(input_blobs) != len(input_files): + log.error(f'Number of network inputs ({len(input_blobs)}) is not equal ' + f'to number of ark files ({len(input_files)})') + sys.exit(-3) + + if args.reference: + reference_files = re.split(', |,', args.reference) + + if len(output_blobs) != len(reference_files): + log.error('The number of reference files is not equal to the number of network outputs.') + sys.exit(-5) + + if args.output: + output_files = re.split(', |,', args.output) + + if len(output_blobs) != len(output_files): + log.error('The number of output files is not equal to the number of network outputs.') + sys.exit(-6) + + if args.export_gna_model: + log.info(f'Writing GNA Model to {args.export_gna_model}') + exec_net.export(args.export_gna_model) + return 0 + + if args.export_embedded_gna_model: + log.info(f'Exported GNA embedded model to file {args.export_embedded_gna_model}') + log.info(f'GNA embedded model export done for GNA generation {args.embedded_gna_configuration}') + return 0 + +# ---------------------------Step 5. Create infer request-------------------------------------------------------------- +# load_network() method of the IECore class with a specified number of requests (default 1) returns an ExecutableNetwork +# instance which stores infer requests. So you already created Infer requests in the previous step. + +# ---------------------------Step 6. Prepare input--------------------------------------------------------------------- + file_data = [read_utterance_file(file_name) for file_name in input_files] + input_data = { + utterance_name: { + input_blobs[i]: file_data[i][utterance_name] for i in range(len(input_blobs)) + } + for utterance_name in file_data[0].keys() + } + + if args.reference: + references = {output_blobs[i]: read_utterance_file(reference_files[i]) for i in range(len(output_blobs))} + +# ---------------------------Step 7. 
Do inference---------------------------------------------------------------------- + log.info('Starting inference in synchronous mode') + results = {blob_name: {} for blob_name in output_blobs} + infer_times = [] + + for key in sorted(input_data): + start_infer_time = default_timer() + + # Reset states between utterance inferences to remove a memory impact + for request in exec_net.requests: + for state in request.query_state(): + state.reset() + + result = infer_data(input_data[key], exec_net, input_blobs, output_blobs) + + for blob_name in result.keys(): + results[blob_name][key] = result[blob_name] + + infer_times.append(default_timer() - start_infer_time) + +# ---------------------------Step 8. Process output-------------------------------------------------------------------- + for blob_name in output_blobs: + for i, key in enumerate(sorted(results[blob_name])): + log.info(f'Utterance {i} ({key})') + log.info(f'Output blob name: {blob_name}') + log.info(f'Frames in utterance: {results[blob_name][key].shape[0]}') + log.info(f'Total time in Infer (HW and SW): {infer_times[i] * 1000:.2f}ms') + + if args.reference: + compare_with_reference(results[blob_name][key], references[blob_name][key]) + + log.info('') + + log.info(f'Total sample time: {sum(infer_times) * 1000:.2f}ms') + + if args.output: + for i, blob_name in enumerate(results): + write_utterance_file(output_files[i], results[blob_name]) + log.info(f'File {output_files[i]} was created!') + +# ---------------------------------------------------------------------------------------------------------------------- + log.info('This sample is an API example, ' + 'for any performance measurements please use the dedicated benchmark_app tool\n') + return 0 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/inference-engine/samples/speech_sample/README.md b/inference-engine/samples/speech_sample/README.md index 293f3240f64..4b6eabfd5d4 100644 --- a/inference-engine/samples/speech_sample/README.md +++ b/inference-engine/samples/speech_sample/README.md @@ -22,7 +22,7 @@ Basic Inference Engine API is covered by [Hello Classification C++ sample](../he | Options | Values | |:--- |:--- | Validated Models | Acoustic model based on Kaldi\* neural networks (see [Model Preparation](#model-preparation) section) -| Model Format | Inference Engine Intermediate Representation (\*.xml + \*.bin), ONNX (\*.onnx) +| Model Format | Inference Engine Intermediate Representation (\*.xml + \*.bin) | Supported devices | See [Execution Modes](#execution-modes) section below and [List Supported Devices](../../../docs/IE_DG/supported_plugins/Supported_Devices.md) | ## How It Works @@ -164,8 +164,7 @@ All of them can be downloaded from [https://storage.openvinotoolkit.org/models_c > **NOTES**: > > - Before running the sample with a trained model, make sure the model is converted to the Inference Engine format (\*.xml + \*.bin) using the [Model Optimizer tool](../../../docs/MO_DG/Deep_Learning_Model_Optimizer_DevGuide.md). -> -> - The sample accepts models in ONNX format (.onnx) that do not require preprocessing. 
+ ## Sample Output diff --git a/inference-engine/samples/speech_sample/main.cpp b/inference-engine/samples/speech_sample/main.cpp index c7a32057254..ef1b96375d8 100644 --- a/inference-engine/samples/speech_sample/main.cpp +++ b/inference-engine/samples/speech_sample/main.cpp @@ -681,7 +681,7 @@ int main(int argc, char* argv[]) { std::cout << ie.GetVersions(deviceStr) << std::endl; // ----------------------------------------------------------------------------------------------------- - // --------------------------- Step 2. Read a model in OpenVINO Intermediate Representation (.xml and .bin files) or ONNX (.onnx file) format + // --------------------------- Step 2. Read a model in OpenVINO Intermediate Representation (.xml and .bin files) slog::info << "Loading network files:" << slog::endl << FLAGS_m << slog::endl; uint32_t batchSize = (FLAGS_cw_r > 0 || FLAGS_cw_l > 0) ? 1 : (uint32_t)FLAGS_bs;