Add reset_state method to Python API (#21660)

* Add reset_state method to Python API
* fix misprint
* codestyle
* extend python tests
* delete debug code
* resolve review comments
parent 0e7248430c
commit c4a49a3987
@@ -44,8 +44,7 @@ def main():
         output_tensors.append(infer_request.get_tensor(output))
 
     # 7. Initialize memory state before starting
-    for state in infer_request.query_state():
-        state.reset()
+    infer_request.reset_state()
 
 #! [ov:part1]
     # input data
@@ -66,8 +65,7 @@ def main():
                 log.info(state_buf[0])
 
     log.info("\nReset state between utterances...\n")
-    for state in infer_request.query_state():
-        state.reset()
+    infer_request.reset_state()
 
     log.info("Infer the second utterance")
     for next_input in range(int(len(input_data)/2), len(input_data)):
@@ -218,8 +218,7 @@ def main():
         start_infer_time = default_timer()
 
         # Reset states between utterance inferences to remove a memory impact
-        for state in infer_request.query_state():
-            state.reset()
+        infer_request.reset_state()
 
         results.append(do_inference(
             infer_data[i],
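The three hunks above apply the same simplification in the Python samples: the per-variable reset loop is collapsed into a single call on the infer request. Below is a minimal sketch of the before/after pattern, assuming a stateful model; "stateful_model.xml" and the device name are placeholders, not part of the commit.

    from openvino.runtime import Core

    core = Core()
    compiled_model = core.compile_model("stateful_model.xml", "CPU")  # hypothetical model path
    infer_request = compiled_model.create_infer_request()

    # Previous pattern: reset every VariableState returned by query_state() one by one.
    for state in infer_request.query_state():
        state.reset()

    # Pattern introduced by this commit: one call resets all states of the request.
    infer_request.reset_state()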
@@ -612,6 +612,16 @@ void regclass_InferRequest(py::module m) {
             :rtype: List[openvino.runtime.VariableState]
         )");
 
+    cls.def(
+        "reset_state",
+        [](InferRequestWrapper& self) {
+            return self.m_request.reset_state();
+        },
+        R"(
+            Resets all internal variable states for relevant infer request to
+            a value specified as default for the corresponding `ReadValue` node
+        )");
+
     cls.def(
         "get_compiled_model",
         [](InferRequestWrapper& self) {
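The binding above forwards to the underlying infer request's reset_state() and documents that every variable goes back to the default of its ReadValue node. A hedged sketch of how that surfaces on the Python side; the model, its input shape, and the device are assumptions, only reset_state() itself comes from this commit.

    import numpy as np
    from openvino.runtime import Core

    core = Core()
    compiled_model = core.compile_model("stateful_model.xml", "CPU")  # hypothetical model path
    request = compiled_model.create_infer_request()

    request.infer({0: np.ones((1, 10), dtype=np.float32)})  # states now hold accumulated values
    request.reset_state()  # all ReadValue variables return to their default values

    for state in request.query_state():
        print(state.name, state.state.data)  # shows the defaults again after the reset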
@@ -395,12 +395,12 @@ def test_get_compiled_model(device):
                          [np.float32,
                           np.int32,
                           np.float16])
-@pytest.mark.parametrize("mode", ["set_init_memory_state", "reset_memory_state", "normal"])
+@pytest.mark.parametrize("mode", ["set_init_memory_state", "reset_memory_state", "normal", "reset_via_infer_request"])
 @pytest.mark.parametrize("input_shape", [[10], [10, 10], [10, 10, 10], [2, 10, 10, 10]])
 @pytest.mark.skipif(
-    os.environ.get("TEST_DEVICE", "CPU") != "CPU",
+    os.environ.get("TEST_DEVICE", "CPU") not in ["CPU", "GPU"],
     reason=f"Can't run test on device {os.environ.get('TEST_DEVICE', 'CPU')}, "
-    "Memory layers fully supported only on CPU",
+    "Memory layers fully supported only on CPU and GPU",
 )
 def test_query_state_write_buffer(device, input_shape, data_type, mode):
     core = Core()
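For orientation, the extended parametrization multiplies out as shown below; this is only arithmetic over the lists visible in the hunk, not code from the repository.

    import itertools

    data_types = ["float32", "int32", "float16"]
    modes = ["set_init_memory_state", "reset_memory_state", "normal", "reset_via_infer_request"]
    input_shapes = [[10], [10, 10], [10, 10, 10], [2, 10, 10, 10]]

    # 3 data types x 4 modes x 4 shapes = 48 parametrized cases per device
    print(len(list(itertools.product(data_types, modes, input_shapes))))  # 48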
@@ -434,6 +434,12 @@ def test_query_state_write_buffer(device, input_shape, data_type, mode):
             res = request.infer({0: np.full(input_shape, 1, dtype=data_type)})
             # always ones
             expected_res = np.full(input_shape, 1, dtype=data_type)
+        elif mode == "reset_via_infer_request":
+            # reset initial state of ReadValue to zero
+            request.reset_state()
+            res = request.infer({0: np.full(input_shape, 1, dtype=data_type)})
+            # always ones
+            expected_res = np.full(input_shape, 1, dtype=data_type)
         else:
             res = request.infer({0: np.full(input_shape, 1, dtype=data_type)})
             expected_res = np.full(input_shape, i, dtype=data_type)
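The new elif branch mirrors the existing "reset_memory_state" branch but triggers the reset through the request instead of through each state. A toy, pure-Python sketch of the behaviour the test asserts (the model under test accumulates its input into a ReadValue/Assign state); none of this code is in the repository.

    state = 0  # ReadValue default

    # "normal" mode: the state keeps accumulating, so the i-th inference of all-ones returns i.
    for i in range(1, 4):
        state = state + 1
        assert state == i

    # reset modes (state.reset() per variable, or request.reset_state() after this commit):
    # the state returns to its default, so the next inference of all-ones returns ones again.
    state = 0
    state = state + 1
    assert state == 1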