[IE Python Speech Sample] Add context windows feature (#7801)
* Add `context_window_left` feature * Add a check of positive context window args * Add `context_window_right` feature
This commit is contained in:
@@ -30,7 +30,9 @@ def get_scale_factor(matrix: np.ndarray) -> float:
|
||||
return target_max / max_val
|
||||
|
||||
|
||||
def infer_data(data: dict, exec_net: ExecutableNetwork, input_blobs: list, output_blobs: list) -> np.ndarray:
|
||||
def infer_data(
|
||||
data: dict, exec_net: ExecutableNetwork, input_blobs: list, output_blobs: list, cw_l: int = 0, cw_r: int = 0,
|
||||
) -> np.ndarray:
|
||||
"""Do a synchronous matrix inference"""
|
||||
matrix_shape = next(iter(data.values())).shape
|
||||
result = {}
|
||||
@@ -40,11 +42,16 @@ def infer_data(data: dict, exec_net: ExecutableNetwork, input_blobs: list, outpu
|
||||
batch_size = shape[0]
|
||||
result[blob_name] = np.ndarray((matrix_shape[0], shape[-1]))
|
||||
|
||||
slice_begin = 0
|
||||
slice_end = batch_size
|
||||
for i in range(-cw_l, matrix_shape[0] + cw_r, batch_size):
|
||||
if i < 0:
|
||||
index = 0
|
||||
elif i >= matrix_shape[0]:
|
||||
index = matrix_shape[0] - 1
|
||||
else:
|
||||
index = i
|
||||
|
||||
vectors = {blob_name: data[blob_name][index:index + batch_size] for blob_name in input_blobs}
|
||||
|
||||
while slice_begin < matrix_shape[0]:
|
||||
vectors = {blob_name: data[blob_name][slice_begin:slice_end] for blob_name in input_blobs}
|
||||
num_of_vectors = next(iter(vectors.values())).shape[0]
|
||||
|
||||
if num_of_vectors < batch_size:
|
||||
@@ -57,11 +64,11 @@ def infer_data(data: dict, exec_net: ExecutableNetwork, input_blobs: list, outpu
|
||||
|
||||
vector_results = exec_net.infer(vectors)
|
||||
|
||||
for blob_name in output_blobs:
|
||||
result[blob_name][slice_begin:slice_end] = vector_results[blob_name][:num_of_vectors]
|
||||
if i - cw_r < 0:
|
||||
continue
|
||||
|
||||
slice_begin += batch_size
|
||||
slice_end += batch_size
|
||||
for blob_name in output_blobs:
|
||||
result[blob_name][i - cw_r:i - cw_r + batch_size] = vector_results[blob_name][:num_of_vectors]
|
||||
|
||||
return result
|
||||
|
||||
@@ -161,7 +168,7 @@ def main():
|
||||
for blob_name in output_blobs:
|
||||
net.outputs[blob_name].precision = 'FP32'
|
||||
|
||||
net.batch_size = args.batch_size
|
||||
net.batch_size = args.batch_size if args.context_window_left + args.context_window_right == 0 else 1
|
||||
|
||||
# ---------------------------Step 4. Loading model to the device-------------------------------------------------------
|
||||
devices = args.device.replace('HETERO:', '').split(',')
|
||||
@@ -272,7 +279,9 @@ def main():
|
||||
for state in request.query_state():
|
||||
state.reset()
|
||||
|
||||
result = infer_data(input_data[key], exec_net, input_blobs, output_blobs)
|
||||
result = infer_data(
|
||||
input_data[key], exec_net, input_blobs, output_blobs, args.context_window_left, args.context_window_right,
|
||||
)
|
||||
|
||||
for blob_name in result.keys():
|
||||
results[blob_name][key] = result[blob_name]
|
||||
|
||||
Reference in New Issue
Block a user