diff --git a/docs/IE_DG/Samples_Overview.md b/docs/IE_DG/Samples_Overview.md
index 11bfb1d51c3..db39cbfc5b4 100644
--- a/docs/IE_DG/Samples_Overview.md
+++ b/docs/IE_DG/Samples_Overview.md
@@ -9,7 +9,9 @@ After installation of Intel® Distribution of OpenVINO™ toolkit, С, C++ and P
Inference Engine sample applications include the following:
-- **[Automatic Speech Recognition C++ Sample](../../inference-engine/samples/speech_sample/README.md)** – Acoustic model inference based on Kaldi neural networks and speech feature vectors.
+- **Speech Sample** – Acoustic model inference based on Kaldi neural networks and speech feature vectors.
+ - [Automatic Speech Recognition C++ Sample](../../inference-engine/samples/speech_sample/README.md)
+ - [Automatic Speech Recognition Python Sample](../../inference-engine/ie_bridges/python/sample/speech_sample/README.md)
- **Benchmark Application** – Estimates deep learning inference performance on supported devices for synchronous and asynchronous modes.
- [Benchmark C++ Tool](../../inference-engine/samples/benchmark_app/README.md)
- [Benchmark Python Tool](../../inference-engine/tools/benchmark_tool/README.md)
diff --git a/docs/doxygen/openvino_docs.xml b/docs/doxygen/openvino_docs.xml
index 34539a41246..a47a08e6777 100644
--- a/docs/doxygen/openvino_docs.xml
+++ b/docs/doxygen/openvino_docs.xml
@@ -180,6 +180,7 @@ limitations under the License.
+
diff --git a/inference-engine/ie_bridges/python/sample/speech_sample/README.md b/inference-engine/ie_bridges/python/sample/speech_sample/README.md
new file mode 100644
index 00000000000..6f02696e085
--- /dev/null
+++ b/inference-engine/ie_bridges/python/sample/speech_sample/README.md
@@ -0,0 +1,205 @@
+# Automatic Speech Recognition Python* Sample {#openvino_inference_engine_ie_bridges_python_sample_speech_sample_README}
+
+This sample demonstrates how to do Synchronous Inference of an acoustic model based on Kaldi\* neural networks and speech feature vectors.
+
+The sample works with Kaldi ARK or NumPy\* uncompressed NPZ files, so it does not cover an end-to-end speech recognition scenario (speech to text): additional preprocessing (feature extraction) is required to get a feature vector from a speech signal, as well as postprocessing (decoding) to produce text from scores.
+
+The Automatic Speech Recognition Python sample application demonstrates how to use the following Inference Engine Python API in applications:
+
+| Feature | API | Description |
+| :------------------ | :---------------------------------------------------------------------------------------------------- | :-------------------------------------------------------------------- |
+| Import/Export Model | [IECore.import_network], [ExecutableNetwork.export] | The GNA plugin supports loading and saving of the GNA-optimized model |
+| Network Operations  | [IENetwork.batch_size], [CDataPtr.shape], [ExecutableNetwork.input_info], [ExecutableNetwork.outputs] | Manage the network: configure input and output blobs |
+| Network Operations  | [IENetwork.add_outputs] | Manage the network: add output layers to the network |
+| InferRequest Operations | `InferRequest.query_state`, `VariableState.reset` | Get and reset the state control interface for a given executable network |
+
+Basic Inference Engine API is covered by [Hello Classification Python* Sample](../hello_classification/README.md).
+
+| Options | Values |
+| :------------------------- | :---------------------------------------------------------------------------------------------------- |
+| Validated Models | Acoustic model based on Kaldi* neural networks (see [Model Preparation](#model-preparation) section) |
+| Model Format | Inference Engine Intermediate Representation (.xml + .bin) |
+| Supported devices | See [Execution Modes](#execution-modes) section below and [List Supported Devices](../../../../../docs/IE_DG/supported_plugins/Supported_Devices.md) |
+| Other language realization | [C++](../../../../samples/speech_sample) |
+
+## How It Works
+
+At startup, the sample application reads command-line parameters, loads the specified model and input data to the Inference Engine plugin, performs synchronous inference on all speech utterances stored in the input file, and logs each step in a standard output stream.
+
+You can find an explicit description of each sample step in the [Integration Steps](../../../../../docs/IE_DG/Integrate_with_customer_application_new_API.md) section of the "Integrate the Inference Engine with Your Application" guide.
+
+## GNA-specific details
+
+### Quantization
+
+If a GNA device is selected (for example, using the `-d GNA_AUTO` flag), the GNA Inference Engine plugin quantizes the model and input feature vector sequence to integer representation before performing inference.
+
+The `-qb` flag provides a hint to the GNA plugin regarding the preferred target weight resolution for all layers.
+For example, when `-qb 8` is specified, the plugin will use 8-bit weights wherever possible in the
+network.
+
+> **NOTE**:
+>
+> - It is not always possible to use 8-bit weights due to GNA hardware limitations. For example, convolutional layers always use 16-bit weights (GNA hardware versions 1 and 2). This limitation will be removed in GNA hardware version 3 and higher.
+>
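+
+For example, the following command (using the model and input file from the [Speech Inference](#speech-inference) section below) hints the plugin to prefer 8-bit weights:
+
+```sh
+python speech_sample.py -d GNA_AUTO -qb 8 -m wsj_dnn5b.xml -i dev93_10.ark
+```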
+
+### Execution Modes
+
+Several execution modes are supported via the `-d` flag:
+
+- `CPU` - All calculations are performed on a CPU device using the CPU Plugin.
+- `GPU` - All calculations are performed on a GPU device using the GPU Plugin.
+- `MYRIAD` - All calculations are performed on an Intel® Neural Compute Stick 2 device using the VPU MYRIAD Plugin.
+- `GNA_AUTO` - GNA hardware is used if available and the driver is installed. Otherwise, the GNA device is emulated in fast-but-not-bit-exact mode.
+- `GNA_HW` - GNA hardware is used if available and the driver is installed. Otherwise, an error will occur.
+- `GNA_SW` - Deprecated. The GNA device is emulated in fast-but-not-bit-exact mode.
+- `GNA_SW_FP32` - Substitutes parameters and calculations from low precision to floating point (FP32).
+- `GNA_SW_EXACT` - GNA device is emulated in bit-exact mode.
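+
+For example, the following commands (using the model and input file from the [Speech Inference](#speech-inference) section below) run the sample in bit-exact GNA emulation mode and with `HETERO` using GNA as the primary device and CPU as the secondary one:
+
+```sh
+python speech_sample.py -d GNA_SW_EXACT -m wsj_dnn5b.xml -i dev93_10.ark
+python speech_sample.py -d HETERO:GNA,CPU -m wsj_dnn5b.xml -i dev93_10.ark
+```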
+
+### Loading and Saving Models
+
+The GNA plugin supports loading and saving of the GNA-optimized model (non-IR) via the `-rg` and `-wg` flags.
+Thereby, it is possible to avoid the cost of full model quantization at run time.
+
+In addition to performing inference directly from a GNA model file, this option makes it possible to:
+
+- Convert from IR format to GNA format model file (`-m`, `-wg`)
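+
+A minimal sketch of such a conversion followed by inference from the exported file (the output name `wsj_dnn5b.gna` is an illustrative assumption; the other file names are taken from the [Speech Inference](#speech-inference) section below):
+
+```sh
+python speech_sample.py -d GNA_AUTO -m wsj_dnn5b.xml -i dev93_10.ark -wg wsj_dnn5b.gna
+python speech_sample.py -d GNA_AUTO -rg wsj_dnn5b.gna -i dev93_10.ark
+```
+
+Note that a run with `-wg` only writes the model file and exits without performing inference.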
+
+## Running
+
+Run the application with the `-h` option to see the usage message:
+
+```sh
+python speech_sample.py -h
+```
+
+Usage message:
+
+```sh
+usage: speech_sample.py [-h] (-m MODEL | -rg IMPORT_GNA_MODEL) -i INPUT
+ [-o OUTPUT] [-r REFERENCE] [-d DEVICE]
+ [-bs BATCH_SIZE] [-qb QUANTIZATION_BITS]
+ [-wg EXPORT_GNA_MODEL] [-iname INPUT_LAYERS]
+ [-oname OUTPUT_LAYERS]
+
+optional arguments:
+ -m MODEL, --model MODEL
+ Path to an .xml file with a trained model (required if
+ -rg is missing).
+ -rg IMPORT_GNA_MODEL, --import_gna_model IMPORT_GNA_MODEL
+ Read GNA model from file using path/filename provided
+ (required if -m is missing).
+
+Options:
+ -h, --help Show this help message and exit.
+ -i INPUT, --input INPUT
+ Required. Path to an input file (.ark or .npz).
+ -o OUTPUT, --output OUTPUT
+ Optional. Output file name to save inference results (.ark or .npz).
+ -r REFERENCE, --reference REFERENCE
+ Optional. Read reference score file and compare
+ scores.
+ -d DEVICE, --device DEVICE
+ Optional. Specify a target device to infer on. CPU,
+ GPU, MYRIAD, GNA_AUTO, GNA_HW, GNA_SW_FP32,
+ GNA_SW_EXACT and HETERO with combination of GNA as the
+ primary device and CPU as a secondary (e.g.
+ HETERO:GNA,CPU) are supported. The sample will look
+ for a suitable plugin for device specified. Default
+ value is CPU.
+ -bs BATCH_SIZE, --batch_size BATCH_SIZE
+ Optional. Batch size 1-8 (default 1).
+ -qb QUANTIZATION_BITS, --quantization_bits QUANTIZATION_BITS
+ Optional. Weight bits for quantization: 8 or 16
+ (default 16).
+ -wg EXPORT_GNA_MODEL, --export_gna_model EXPORT_GNA_MODEL
+ Optional. Write GNA model to file using path/filename
+ provided.
+ -iname INPUT_LAYERS, --input_layers INPUT_LAYERS
+ Optional. Layer names for input blobs. The names are
+ separated with ",". Allows to change the order of
+ input layers for -i flag. Example: Input1,Input2
+ -oname OUTPUT_LAYERS, --output_layers OUTPUT_LAYERS
+ Optional. Layer names for output blobs. The names are
+ separated with ",". Allows to change the order of
+ output layers for -o flag. Example:
+ Output1:port,Output2:port.
+```
+
+## Model Preparation
+
+You can use the following Model Optimizer command to convert a Kaldi nnet1 or nnet2 neural network to the Inference Engine Intermediate Representation (IR) format:
+
+```sh
+python mo.py --framework kaldi --input_model wsj_dnn5b.nnet --counts wsj_dnn5b.counts --remove_output_softmax --output_dir <OUTPUT_MODEL_DIR>
+```
+
+The following pre-trained models are available:
+
+- wsj_dnn5b_smbr
+- rm_lstm4f
+- rm_cnn4a_smbr
+
+All of them can be downloaded from [https://storage.openvinotoolkit.org/models_contrib/speech/2021.2](https://storage.openvinotoolkit.org/models_contrib/speech/2021.2).
+
+## Speech Inference
+
+You can do inference on Intel® Processors with the GNA co-processor (or emulation library):
+
+```sh
+python speech_sample.py -d GNA_AUTO -m wsj_dnn5b.xml -i dev93_10.ark -r dev93_scores_10.ark -o result.npz
+```
+
+> **NOTES**:
+>
+> - Before running the sample with a trained model, make sure the model is converted to the Inference Engine format (\*.xml + \*.bin) using the [Model Optimizer tool](../../../../../docs/MO_DG/Deep_Learning_Model_Optimizer_DevGuide.md).
+>
+> - The sample supports input and output in the NumPy\* file format (.npz).
+
+## Sample Output
+
+The sample application logs each step in a standard output stream.
+
+```sh
+[ INFO ] Creating Inference Engine
+[ INFO ] Reading the network: wsj_dnn5b.xml
+[ INFO ] Configuring input and output blobs
+[ INFO ] Using scale factor of 2175.4322417 calculated from first utterance.
+[ INFO ] Loading the model to the plugin
+[ INFO ] Starting inference in synchronous mode
+[ INFO ] Utterance 0 (4k0c0301)
+[ INFO ] Frames in utterance: 1294
+[ INFO ] Total time in Infer (HW and SW): 5305.47ms
+[ INFO ] max error: 0.7051839
+[ INFO ] avg error: 0.0448387
+[ INFO ] avg rms error: 0.0582387
+[ INFO ] stdev error: 0.0371649
+[ INFO ]
+[ INFO ] Utterance 1 (4k0c0302)
+[ INFO ] Frames in utterance: 1005
+[ INFO ] Total time in Infer (HW and SW): 5031.53ms
+[ INFO ] max error: 0.7575974
+[ INFO ] avg error: 0.0452166
+[ INFO ] avg rms error: 0.0586013
+[ INFO ] stdev error: 0.0372769
+[ INFO ]
+...
+[ INFO ] Total sample time: 38033.09ms
+[ INFO ] This sample is an API example, for any performance measurements please use the dedicated benchmark_app tool
+```
+
+## See Also
+
+- [Integrate the Inference Engine with Your Application](../../../../../docs/IE_DG/Integrate_with_customer_application_new_API.md)
+- [Using Inference Engine Samples](../../../../../docs/IE_DG/Samples_Overview.md)
+- [Model Downloader](@ref omz_tools_downloader_README)
+- [Model Optimizer](../../../../../docs/MO_DG/Deep_Learning_Model_Optimizer_DevGuide.md)
+
+[IENetwork.batch_size]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1IENetwork.html#a79a647cb1b49645616eaeb2ca255ef2e
+[IENetwork.add_outputs]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1IENetwork.html#ae8024b07f3301d6d5de5c0d153e2e6e6
+[CDataPtr.shape]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1CDataPtr.html#aa6fd459edb323d1c6215dc7a970ebf7f
+[ExecutableNetwork.input_info]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1ExecutableNetwork.html#ac76a04c2918607874018d2e15a2f274f
+[ExecutableNetwork.outputs]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1ExecutableNetwork.html#a4a631776df195004b1523e6ae91a65c1
+[IECore.import_network]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1IECore.html#afdeac5192bb1d9e64722f1071fb0a64a
+[ExecutableNetwork.export]:https://docs.openvinotoolkit.org/latest/ie_python_api/classie__api_1_1ExecutableNetwork.html#afa78158252f0d8070181bafec4318413
\ No newline at end of file
diff --git a/inference-engine/ie_bridges/python/sample/speech_sample/arg_parser.py b/inference-engine/ie_bridges/python/sample/speech_sample/arg_parser.py
new file mode 100644
index 00000000000..cfc20dfb425
--- /dev/null
+++ b/inference-engine/ie_bridges/python/sample/speech_sample/arg_parser.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+import argparse
+
+
+def parse_args() -> argparse.Namespace:
+ """Parse and return command line arguments"""
+ parser = argparse.ArgumentParser(add_help=False)
+ args = parser.add_argument_group('Options')
+ model = parser.add_mutually_exclusive_group(required=True)
+
+ args.add_argument('-h', '--help', action='help', help='Show this help message and exit.')
+ model.add_argument('-m', '--model', type=str,
+ help='Path to an .xml file with a trained model (required if -rg is missing).')
+ model.add_argument('-rg', '--import_gna_model', type=str,
+ help='Read GNA model from file using path/filename provided (required if -m is missing).')
+ args.add_argument('-i', '--input', required=True, type=str, help='Required. Path to an input file (.ark or .npz).')
+ args.add_argument('-o', '--output', type=str,
+ help='Optional. Output file name to save inference results (.ark or .npz).')
+ args.add_argument('-r', '--reference', type=str,
+ help='Optional. Read reference score file and compare scores.')
+ args.add_argument('-d', '--device', default='CPU', type=str,
+ help='Optional. Specify a target device to infer on. '
+ 'CPU, GPU, MYRIAD, GNA_AUTO, GNA_HW, GNA_SW_FP32, GNA_SW_EXACT and HETERO with combination of GNA'
+ ' as the primary device and CPU as a secondary (e.g. HETERO:GNA,CPU) are supported. '
+ 'The sample will look for a suitable plugin for device specified. Default value is CPU.')
+ args.add_argument('-bs', '--batch_size', default=1, type=int, help='Optional. Batch size 1-8 (default 1).')
+ args.add_argument('-qb', '--quantization_bits', default=16, type=int,
+ help='Optional. Weight bits for quantization: 8 or 16 (default 16).')
+ args.add_argument('-wg', '--export_gna_model', type=str,
+ help='Optional. Write GNA model to file using path/filename provided.')
+ args.add_argument('-we', '--export_embedded_gna_model', type=str, help=argparse.SUPPRESS)
+ args.add_argument('-we_gen', '--embedded_gna_configuration', default='GNA1', type=str, help=argparse.SUPPRESS)
+ args.add_argument('-iname', '--input_layers', type=str,
+ help='Optional. Layer names for input blobs. The names are separated with ",". '
+ 'Allows to change the order of input layers for -i flag. Example: Input1,Input2')
+ args.add_argument('-oname', '--output_layers', type=str,
+ help='Optional. Layer names for output blobs. The names are separated with ",". '
+ 'Allows to change the order of output layers for -o flag. Example: Output1:port,Output2:port.')
+
+ return parser.parse_args()
diff --git a/inference-engine/ie_bridges/python/sample/speech_sample/file_options.py b/inference-engine/ie_bridges/python/sample/speech_sample/file_options.py
new file mode 100644
index 00000000000..f0d911343d9
--- /dev/null
+++ b/inference-engine/ie_bridges/python/sample/speech_sample/file_options.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+import logging as log
+import sys
+from typing import IO, Any
+
+import numpy as np
+
+
+def read_ark_file(file_name: str) -> dict:
+ """Read utterance matrices from a .ark file"""
+ def read_key(input_file: IO[Any]) -> str:
+ """Read a identifier of utterance matrix"""
+ key = ''
+ while True:
+ char = input_file.read(1).decode()
+ if char in ('', ' '):
+ break
+ else:
+ key += char
+
+ return key
+
+ def read_matrix(input_file: IO[Any]) -> np.ndarray:
+ """Read a utterance matrix"""
+ header = input_file.read(5).decode()
+ if 'FM' in header:
+ num_of_bytes = 4
+ dtype = 'float32'
+ elif 'DM' in header:
+ num_of_bytes = 8
+ dtype = 'float64'
+ else:
+ log.error(f'The utterance header "{header}" does not contain information about a type of elements.')
+ sys.exit(-7)
+
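+ # The matrix dimensions follow as '\04' + rows (int32) + '\04' + cols (int32), 10 bytes in total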
+ _, rows, _, cols = np.frombuffer(input_file.read(10), 'int8, int32, int8, int32')[0]
+ buffer = input_file.read(rows * cols * num_of_bytes)
+ vector = np.frombuffer(buffer, dtype)
+ matrix = np.reshape(vector, (rows, cols))
+
+ return matrix
+
+ utterances = {}
+ with open(file_name, 'rb') as input_file:
+ while True:
+ key = read_key(input_file)
+ if not key:
+ break
+ utterances[key] = read_matrix(input_file)
+
+ return utterances
+
+
+def write_ark_file(file_name: str, utterances: dict):
+ """Write utterance matrices to a .ark file"""
+ with open(file_name, 'wb') as output_file:
+ for key, matrix in sorted(utterances.items()):
+ # write a matrix key
+ output_file.write(key.encode())
+ output_file.write(' '.encode())
+ output_file.write('\0B'.encode())
+
+ # write a matrix precision
+ if matrix.dtype == 'float32':
+ output_file.write('FM '.encode())
+ elif matrix.dtype == 'float64':
+ output_file.write('DM '.encode())
+
+ # write a matrix shape
+ output_file.write('\04'.encode())
+ output_file.write(matrix.shape[0].to_bytes(4, byteorder='little', signed=False))
+ output_file.write('\04'.encode())
+ output_file.write(matrix.shape[1].to_bytes(4, byteorder='little', signed=False))
+
+ # write a matrix data
+ output_file.write(matrix.tobytes())
+
+
+def read_utterance_file(file_name: str) -> dict:
+ """Read utterance matrices from a file"""
+ file_extension = file_name.split('.')[-1]
+
+ if file_extension == 'ark':
+ return read_ark_file(file_name)
+ elif file_extension == 'npz':
+ return dict(np.load(file_name))
+ else:
+ log.error(f'The file {file_name} cannot be read. The sample supports only .ark and .npz files.')
+ sys.exit(-1)
+
+
+def write_utterance_file(file_name: str, utterances: dict):
+ """Write utterance matrices to a file"""
+ file_extension = file_name.split('.')[-1]
+
+ if file_extension == 'ark':
+ write_ark_file(file_name, utterances)
+ elif file_extension == 'npz':
+ np.savez(file_name, **utterances)
+ else:
+ log.error(f'The file {file_name} cannot be written. The sample supports only .ark and .npz files.')
+ sys.exit(-2)
diff --git a/inference-engine/ie_bridges/python/sample/speech_sample/speech_sample.py b/inference-engine/ie_bridges/python/sample/speech_sample/speech_sample.py
new file mode 100755
index 00000000000..858223f5e40
--- /dev/null
+++ b/inference-engine/ie_bridges/python/sample/speech_sample/speech_sample.py
@@ -0,0 +1,255 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+import logging as log
+import re
+import sys
+from timeit import default_timer
+
+import numpy as np
+from arg_parser import parse_args
+from file_options import read_utterance_file, write_utterance_file
+from openvino.inference_engine import ExecutableNetwork, IECore
+
+
+def get_scale_factor(matrix: np.ndarray) -> float:
+ """Get scale factor for quantization using utterance matrix"""
+ # Target maximum value used to calculate the scale factor
+ target_max = 16384
+ max_val = np.max(matrix)
+ if max_val == 0:
+ return 1.0
+ else:
+ return target_max / max_val
+
+
+def infer_data(data: dict, exec_net: ExecutableNetwork, input_blobs: list, output_blobs: list) -> np.ndarray:
+ """Do a synchronous matrix inference"""
+ matrix_shape = next(iter(data.values())).shape
+ result = {}
+
+ for blob_name in output_blobs:
+ batch_size, num_of_dims = exec_net.outputs[blob_name].shape
+ result[blob_name] = np.ndarray((matrix_shape[0], num_of_dims))
+
+ slice_begin = 0
+ slice_end = batch_size
+
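+ # Process the utterance in slices of batch_size frames; the last slice is zero-padded to the full batch size before inference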
+ while True:
+ vectors = {blob_name: data[blob_name][slice_begin:slice_end] for blob_name in input_blobs}
+ vector_shape = next(iter(vectors.values())).shape
+
+ if vector_shape[0] < batch_size:
+ temp = {blob_name: np.zeros((batch_size, vector_shape[1])) for blob_name in input_blobs}
+
+ for blob_name in input_blobs:
+ temp[blob_name][:vector_shape[0]] = vectors[blob_name]
+
+ vectors = temp
+
+ vector_results = exec_net.infer(vectors)
+
+ for blob_name in output_blobs:
+ result[blob_name][slice_begin:slice_end] = vector_results[blob_name][:vector_shape[0]]
+
+ slice_begin += batch_size
+ slice_end += batch_size
+
+ if slice_begin >= matrix_shape[0]:
+ return result
+
+
+def compare_with_reference(result: np.ndarray, reference: np.ndarray):
+ error_matrix = np.absolute(result - reference)
+
+ max_error = np.max(error_matrix)
+ sum_error = np.sum(error_matrix)
+ avg_error = sum_error / error_matrix.size
+ sum_square_error = np.sum(np.square(error_matrix))
+ avg_rms_error = np.sqrt(sum_square_error / error_matrix.size)
+ stdev_error = np.sqrt(sum_square_error / error_matrix.size - avg_error * avg_error)
+
+ log.info(f'max error: {max_error:.7f}')
+ log.info(f'avg error: {avg_error:.7f}')
+ log.info(f'avg rms error: {avg_rms_error:.7f}')
+ log.info(f'stdev error: {stdev_error:.7f}')
+
+
+def main():
+ log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)
+ args = parse_args()
+
+# ---------------------------Step 1. Initialize inference engine core--------------------------------------------------
+ log.info('Creating Inference Engine')
+ ie = IECore()
+
+# ---------------------------Step 2. Read a model in OpenVINO Intermediate Representation---------------
+ if args.model:
+ log.info(f'Reading the network: {args.model}')
+ # .xml and .bin files
+ net = ie.read_network(model=args.model)
+
+# ---------------------------Step 3. Configure input & output----------------------------------------------------------
+ log.info('Configuring input and output blobs')
+ # Get names of input and output blobs
+ if args.input_layers:
+ input_blobs = re.split(', |,', args.input_layers)
+ else:
+ input_blobs = [next(iter(net.input_info))]
+
+ if args.output_layers:
+ output_name_port = [output.split(':') for output in re.split(', |,', args.output_layers)]
+ try:
+ output_name_port = [(blob_name, int(port)) for blob_name, port in output_name_port]
+ except ValueError:
+ log.error('Output Parameter does not have a port.')
+ sys.exit(-4)
+
+ net.add_outputs(output_name_port)
+
+ output_blobs = [blob_name for blob_name, port in output_name_port]
+ else:
+ output_blobs = [list(net.outputs.keys())[-1]]
+
+ # Set input and output precision manually
+ for blob_name in input_blobs:
+ net.input_info[blob_name].precision = 'FP32'
+
+ for blob_name in output_blobs:
+ net.outputs[blob_name].precision = 'FP32'
+
+ net.batch_size = args.batch_size
+
+# ---------------------------Step 4. Loading model to the device-------------------------------------------------------
+ devices = args.device.replace('HETERO:', '').split(',')
+ plugin_config = {}
+
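+ # For GNA devices, the mode part of the device name (e.g. GNA_SW_EXACT) is passed via the GNA_DEVICE_MODE config key, while the plugin itself is addressed as 'GNA'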
+ if 'GNA' in args.device:
+ gna_device_mode = devices[0] if '_' in devices[0] else 'GNA_AUTO'
+ devices[0] = 'GNA'
+
+ plugin_config['GNA_DEVICE_MODE'] = gna_device_mode
+ plugin_config['GNA_PRECISION'] = f'I{args.quantization_bits}'
+
+ # Get a GNA scale factor
+ if args.import_gna_model:
+ log.info(f'Using scale factor from the imported GNA model: {args.import_gna_model}')
+ else:
+ utterances = read_utterance_file(args.input.split(',')[0])
+ key = sorted(utterances)[0]
+ scale_factor = get_scale_factor(utterances[key])
+ log.info(f'Using scale factor of {scale_factor:.7f} calculated from first utterance.')
+
+ plugin_config['GNA_SCALE_FACTOR'] = str(scale_factor)
+
+ if args.export_embedded_gna_model:
+ plugin_config['GNA_FIRMWARE_MODEL_IMAGE'] = args.export_embedded_gna_model
+ plugin_config['GNA_FIRMWARE_MODEL_IMAGE_GENERATION'] = args.embedded_gna_configuration
+
+ device_str = f'HETERO:{",".join(devices)}' if 'HETERO' in args.device else devices[0]
+
+ log.info('Loading the model to the plugin')
+ if args.model:
+ exec_net = ie.load_network(net, device_str, plugin_config)
+ else:
+ exec_net = ie.import_network(args.import_gna_model, device_str, plugin_config)
+ input_blobs = [next(iter(exec_net.input_info))]
+ output_blobs = [list(exec_net.outputs.keys())[-1]]
+
+ if args.input:
+ input_files = re.split(', |,', args.input)
+
+ if len(input_blobs) != len(input_files):
+ log.error(f'Number of network inputs ({len(input_blobs)}) is not equal '
+ f'to number of ark files ({len(input_files)})')
+ sys.exit(-3)
+
+ if args.reference:
+ reference_files = re.split(', |,', args.reference)
+
+ if len(output_blobs) != len(reference_files):
+ log.error('The number of reference files is not equal to the number of network outputs.')
+ sys.exit(-5)
+
+ if args.output:
+ output_files = re.split(', |,', args.output)
+
+ if len(output_blobs) != len(output_files):
+ log.error('The number of output files is not equal to the number of network outputs.')
+ sys.exit(-6)
+
+ if args.export_gna_model:
+ log.info(f'Writing GNA Model to {args.export_gna_model}')
+ exec_net.export(args.export_gna_model)
+ return 0
+
+ if args.export_embedded_gna_model:
+ log.info(f'Exported GNA embedded model to file {args.export_embedded_gna_model}')
+ log.info(f'GNA embedded model export done for GNA generation {args.embedded_gna_configuration}')
+ return 0
+
+# ---------------------------Step 5. Create infer request--------------------------------------------------------------
+# The load_network() method of the IECore class, called with a specified number of requests (default 1), returns an ExecutableNetwork
+# instance that stores infer requests. So infer requests were already created in the previous step.
+
+# ---------------------------Step 6. Prepare input---------------------------------------------------------------------
+ file_data = [read_utterance_file(file_name) for file_name in input_files]
+ input_data = {
+ utterance_name: {
+ input_blobs[i]: file_data[i][utterance_name] for i in range(len(input_blobs))
+ }
+ for utterance_name in file_data[0].keys()
+ }
+
+ if args.reference:
+ references = {output_blobs[i]: read_utterance_file(reference_files[i]) for i in range(len(output_blobs))}
+
+# ---------------------------Step 7. Do inference----------------------------------------------------------------------
+ log.info('Starting inference in synchronous mode')
+ results = {blob_name: {} for blob_name in output_blobs}
+ infer_times = []
+
+ for key in sorted(input_data):
+ start_infer_time = default_timer()
+
+ # Reset states between utterance inferences so that a previous utterance does not affect the results of the next one
+ for request in exec_net.requests:
+ for state in request.query_state():
+ state.reset()
+
+ result = infer_data(input_data[key], exec_net, input_blobs, output_blobs)
+
+ for blob_name in result.keys():
+ results[blob_name][key] = result[blob_name]
+
+ infer_times.append(default_timer() - start_infer_time)
+
+# ---------------------------Step 8. Process output--------------------------------------------------------------------
+ for blob_name in output_blobs:
+ for i, key in enumerate(sorted(results[blob_name])):
+ log.info(f'Utterance {i} ({key})')
+ log.info(f'Output blob name: {blob_name}')
+ log.info(f'Frames in utterance: {results[blob_name][key].shape[0]}')
+ log.info(f'Total time in Infer (HW and SW): {infer_times[i] * 1000:.2f}ms')
+
+ if args.reference:
+ compare_with_reference(results[blob_name][key], references[blob_name][key])
+
+ log.info('')
+
+ log.info(f'Total sample time: {sum(infer_times) * 1000:.2f}ms')
+
+ if args.output:
+ for i, blob_name in enumerate(results):
+ write_utterance_file(output_files[i], results[blob_name])
+ log.info(f'File {output_files[i]} was created!')
+
+# ----------------------------------------------------------------------------------------------------------------------
+ log.info('This sample is an API example, '
+ 'for any performance measurements please use the dedicated benchmark_app tool\n')
+ return 0
+
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/inference-engine/samples/speech_sample/README.md b/inference-engine/samples/speech_sample/README.md
index 293f3240f64..4b6eabfd5d4 100644
--- a/inference-engine/samples/speech_sample/README.md
+++ b/inference-engine/samples/speech_sample/README.md
@@ -22,7 +22,7 @@ Basic Inference Engine API is covered by [Hello Classification C++ sample](../he
| Options | Values |
|:--- |:---
| Validated Models | Acoustic model based on Kaldi\* neural networks (see [Model Preparation](#model-preparation) section)
-| Model Format | Inference Engine Intermediate Representation (\*.xml + \*.bin), ONNX (\*.onnx)
+| Model Format | Inference Engine Intermediate Representation (\*.xml + \*.bin)
| Supported devices | See [Execution Modes](#execution-modes) section below and [List Supported Devices](../../../docs/IE_DG/supported_plugins/Supported_Devices.md) |
## How It Works
@@ -164,8 +164,7 @@ All of them can be downloaded from [https://storage.openvinotoolkit.org/models_c
> **NOTES**:
>
> - Before running the sample with a trained model, make sure the model is converted to the Inference Engine format (\*.xml + \*.bin) using the [Model Optimizer tool](../../../docs/MO_DG/Deep_Learning_Model_Optimizer_DevGuide.md).
->
-> - The sample accepts models in ONNX format (.onnx) that do not require preprocessing.
+
## Sample Output
diff --git a/inference-engine/samples/speech_sample/main.cpp b/inference-engine/samples/speech_sample/main.cpp
index c7a32057254..ef1b96375d8 100644
--- a/inference-engine/samples/speech_sample/main.cpp
+++ b/inference-engine/samples/speech_sample/main.cpp
@@ -681,7 +681,7 @@ int main(int argc, char* argv[]) {
std::cout << ie.GetVersions(deviceStr) << std::endl;
// -----------------------------------------------------------------------------------------------------
- // --------------------------- Step 2. Read a model in OpenVINO Intermediate Representation (.xml and .bin files) or ONNX (.onnx file) format
+ // --------------------------- Step 2. Read a model in OpenVINO Intermediate Representation (.xml and .bin files)
slog::info << "Loading network files:" << slog::endl << FLAGS_m << slog::endl;
uint32_t batchSize = (FLAGS_cw_r > 0 || FLAGS_cw_l > 0) ? 1 : (uint32_t)FLAGS_bs;