[IE Python Speech Sample] Add context windows feature (#7801)

* Add `context_window_left` feature

* Add a check of positive context window args

* Add `context_window_right` feature
This commit is contained in:
Dmitry Pigasin
2021-10-28 00:17:34 +03:00
committed by GitHub
parent fe457aa59c
commit 054a2f8d9c
2 changed files with 56 additions and 14 deletions

View File

@@ -30,7 +30,9 @@ def get_scale_factor(matrix: np.ndarray) -> float:
return target_max / max_val
def infer_data(data: dict, exec_net: ExecutableNetwork, input_blobs: list, output_blobs: list) -> np.ndarray:
def infer_data(
data: dict, exec_net: ExecutableNetwork, input_blobs: list, output_blobs: list, cw_l: int = 0, cw_r: int = 0,
) -> np.ndarray:
"""Do a synchronous matrix inference"""
matrix_shape = next(iter(data.values())).shape
result = {}
@@ -40,11 +42,16 @@ def infer_data(data: dict, exec_net: ExecutableNetwork, input_blobs: list, outpu
batch_size = shape[0]
result[blob_name] = np.ndarray((matrix_shape[0], shape[-1]))
slice_begin = 0
slice_end = batch_size
for i in range(-cw_l, matrix_shape[0] + cw_r, batch_size):
if i < 0:
index = 0
elif i >= matrix_shape[0]:
index = matrix_shape[0] - 1
else:
index = i
vectors = {blob_name: data[blob_name][index:index + batch_size] for blob_name in input_blobs}
while slice_begin < matrix_shape[0]:
vectors = {blob_name: data[blob_name][slice_begin:slice_end] for blob_name in input_blobs}
num_of_vectors = next(iter(vectors.values())).shape[0]
if num_of_vectors < batch_size:
@@ -57,11 +64,11 @@ def infer_data(data: dict, exec_net: ExecutableNetwork, input_blobs: list, outpu
vector_results = exec_net.infer(vectors)
for blob_name in output_blobs:
result[blob_name][slice_begin:slice_end] = vector_results[blob_name][:num_of_vectors]
if i - cw_r < 0:
continue
slice_begin += batch_size
slice_end += batch_size
for blob_name in output_blobs:
result[blob_name][i - cw_r:i - cw_r + batch_size] = vector_results[blob_name][:num_of_vectors]
return result
@@ -161,7 +168,7 @@ def main():
for blob_name in output_blobs:
net.outputs[blob_name].precision = 'FP32'
net.batch_size = args.batch_size
net.batch_size = args.batch_size if args.context_window_left + args.context_window_right == 0 else 1
# ---------------------------Step 4. Loading model to the device-------------------------------------------------------
devices = args.device.replace('HETERO:', '').split(',')
@@ -272,7 +279,9 @@ def main():
for state in request.query_state():
state.reset()
result = infer_data(input_data[key], exec_net, input_blobs, output_blobs)
result = infer_data(
input_data[key], exec_net, input_blobs, output_blobs, args.context_window_left, args.context_window_right,
)
for blob_name in result.keys():
results[blob_name][key] = result[blob_name]