[IE Python Speech Sample] Add context windows feature (#7801)
* Add `context_window_left` feature
* Add a check of positive context window args
* Add `context_window_right` feature
This commit is contained in:
@@ -25,8 +25,9 @@ def parse_args() -> argparse.Namespace:
|
||||
'CPU, GPU, MYRIAD, GNA_AUTO, GNA_HW, GNA_SW_FP32, GNA_SW_EXACT and HETERO with combination of GNA'
|
||||
' as the primary device and CPU as a secondary (e.g. HETERO:GNA,CPU) are supported. '
|
||||
'The sample will look for a suitable plugin for device specified. Default value is CPU.')
|
||||
args.add_argument('-bs', '--batch_size', default=1, type=int, help='Optional. Batch size 1-8 (default 1).')
|
||||
args.add_argument('-qb', '--quantization_bits', default=16, type=int,
|
||||
args.add_argument('-bs', '--batch_size', default=1, type=int, choices=range(1, 9), metavar='[1-8]',
|
||||
help='Optional. Batch size 1-8 (default 1).')
|
||||
args.add_argument('-qb', '--quantization_bits', default=16, type=int, choices=(8, 16), metavar='[8, 16]',
|
||||
help='Optional. Weight bits for quantization: 8 or 16 (default 16).')
|
||||
args.add_argument('-sf', '--scale_factor', type=str,
|
||||
help='Optional. The user-specified input scale factor for quantization. '
|
||||
@@ -37,7 +38,7 @@ def parse_args() -> argparse.Namespace:
|
||||
args.add_argument('-we_gen', '--embedded_gna_configuration', default='GNA1', type=str, help=argparse.SUPPRESS)
|
||||
args.add_argument('-pc', '--performance_counter', action='store_true',
|
||||
help='Optional. Enables performance report (specify -a to ensure arch accurate results).')
|
||||
args.add_argument('-a', '--arch', default='CORE', type=str.upper, choices=['CORE', 'ATOM'],
|
||||
args.add_argument('-a', '--arch', default='CORE', type=str.upper, choices=('CORE', 'ATOM'), metavar='[CORE, ATOM]',
|
||||
help='Optional. Specify architecture. CORE, ATOM with the combination of -pc.')
|
||||
args.add_argument('-iname', '--input_layers', type=str,
|
||||
help='Optional. Layer names for input blobs. The names are separated with ",". '
|
||||
@@ -45,5 +46,37 @@ def parse_args() -> argparse.Namespace:
|
||||
args.add_argument('-oname', '--output_layers', type=str,
|
||||
help='Optional. Layer names for output blobs. The names are separated with ",". '
|
||||
'Allows to change the order of output layers for -o flag. Example: Output1:port,Output2:port.')
|
||||
args.add_argument('-cw_l', '--context_window_left', type=IntRange(0), default=0,
|
||||
help='Optional. Number of frames for left context windows (default is 0). '
|
||||
'Works only with context window networks. '
|
||||
'If you use the cw_l or cw_r flag, then batch size argument is ignored.')
|
||||
args.add_argument('-cw_r', '--context_window_right', type=IntRange(0), default=0,
|
||||
help='Optional. Number of frames for right context windows (default is 0). '
|
||||
'Works only with context window networks. '
|
||||
'If you use the cw_l or cw_r flag, then batch size argument is ignored.')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
class IntRange:
    """Argparse-compatible ``type`` callable for a bounded integer.

    An instance is handed to ``argparse`` as ``type=IntRange(...)``; calling it
    converts the raw command-line string to ``int`` and checks the configured
    bounds, raising ``argparse.ArgumentTypeError`` with a human-readable
    message when conversion or validation fails.
    """

    def __init__(self, _min=None, _max=None):
        # Either bound may be None, meaning "unbounded" on that side.
        self._min = _min
        self._max = _max

    def __call__(self, arg):
        try:
            value = int(arg)
        except ValueError:
            raise argparse.ArgumentTypeError('Must be an integer.')

        too_small = self._min is not None and value < self._min
        too_large = self._max is not None and value > self._max

        if too_small or too_large:
            # Pick the message matching which bounds were configured.
            if self._min is not None and self._max is not None:
                message = f'Must be an integer in the range [{self._min}, {self._max}].'
            elif self._min is not None:
                message = f'Must be an integer >= {self._min}.'
            else:
                message = f'Must be an integer <= {self._max}.'
            raise argparse.ArgumentTypeError(message)

        return value
|
||||
|
||||
@@ -30,7 +30,9 @@ def get_scale_factor(matrix: np.ndarray) -> float:
|
||||
return target_max / max_val
|
||||
|
||||
|
||||
def infer_data(data: dict, exec_net: ExecutableNetwork, input_blobs: list, output_blobs: list) -> np.ndarray:
|
||||
def infer_data(
|
||||
data: dict, exec_net: ExecutableNetwork, input_blobs: list, output_blobs: list, cw_l: int = 0, cw_r: int = 0,
|
||||
) -> np.ndarray:
|
||||
"""Do a synchronous matrix inference"""
|
||||
matrix_shape = next(iter(data.values())).shape
|
||||
result = {}
|
||||
@@ -40,11 +42,16 @@ def infer_data(data: dict, exec_net: ExecutableNetwork, input_blobs: list, outpu
|
||||
batch_size = shape[0]
|
||||
result[blob_name] = np.ndarray((matrix_shape[0], shape[-1]))
|
||||
|
||||
slice_begin = 0
|
||||
slice_end = batch_size
|
||||
for i in range(-cw_l, matrix_shape[0] + cw_r, batch_size):
|
||||
if i < 0:
|
||||
index = 0
|
||||
elif i >= matrix_shape[0]:
|
||||
index = matrix_shape[0] - 1
|
||||
else:
|
||||
index = i
|
||||
|
||||
vectors = {blob_name: data[blob_name][index:index + batch_size] for blob_name in input_blobs}
|
||||
|
||||
while slice_begin < matrix_shape[0]:
|
||||
vectors = {blob_name: data[blob_name][slice_begin:slice_end] for blob_name in input_blobs}
|
||||
num_of_vectors = next(iter(vectors.values())).shape[0]
|
||||
|
||||
if num_of_vectors < batch_size:
|
||||
@@ -57,11 +64,11 @@ def infer_data(data: dict, exec_net: ExecutableNetwork, input_blobs: list, outpu
|
||||
|
||||
vector_results = exec_net.infer(vectors)
|
||||
|
||||
for blob_name in output_blobs:
|
||||
result[blob_name][slice_begin:slice_end] = vector_results[blob_name][:num_of_vectors]
|
||||
if i - cw_r < 0:
|
||||
continue
|
||||
|
||||
slice_begin += batch_size
|
||||
slice_end += batch_size
|
||||
for blob_name in output_blobs:
|
||||
result[blob_name][i - cw_r:i - cw_r + batch_size] = vector_results[blob_name][:num_of_vectors]
|
||||
|
||||
return result
|
||||
|
||||
@@ -161,7 +168,7 @@ def main():
|
||||
for blob_name in output_blobs:
|
||||
net.outputs[blob_name].precision = 'FP32'
|
||||
|
||||
net.batch_size = args.batch_size
|
||||
net.batch_size = args.batch_size if args.context_window_left + args.context_window_right == 0 else 1
|
||||
|
||||
# ---------------------------Step 4. Loading model to the device-------------------------------------------------------
|
||||
devices = args.device.replace('HETERO:', '').split(',')
|
||||
@@ -272,7 +279,9 @@ def main():
|
||||
for state in request.query_state():
|
||||
state.reset()
|
||||
|
||||
result = infer_data(input_data[key], exec_net, input_blobs, output_blobs)
|
||||
result = infer_data(
|
||||
input_data[key], exec_net, input_blobs, output_blobs, args.context_window_left, args.context_window_right,
|
||||
)
|
||||
|
||||
for blob_name in result.keys():
|
||||
results[blob_name][key] = result[blob_name]
|
||||
|
||||
Reference in New Issue
Block a user