Add reset_state method to Python API (#21660)

* Add reset_state method to Python API

* fix missprint

* codestyle

* extend python tests

* delete debug code

* resolve review comments
This commit is contained in:
Ivan Tikhonov 2023-12-18 10:56:53 +03:30 committed by GitHub
parent 0e7248430c
commit c4a49a3987
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 22 additions and 9 deletions

View File

@ -44,8 +44,7 @@ def main():
output_tensors.append(infer_request.get_tensor(output))
# 7. Initialize memory state before starting
for state in infer_request.query_state():
state.reset()
infer_request.reset_state()
#! [ov:part1]
# input data
@ -66,8 +65,7 @@ def main():
log.info(state_buf[0])
log.info("\nReset state between utterances...\n")
for state in infer_request.query_state():
state.reset()
infer_request.reset_state()
log.info("Infer the second utterance")
for next_input in range(int(len(input_data)/2), len(input_data)):

View File

@ -218,8 +218,7 @@ def main():
start_infer_time = default_timer()
# Reset states between utterance inferences to remove a memory impact
for state in infer_request.query_state():
state.reset()
infer_request.reset_state()
results.append(do_inference(
infer_data[i],

View File

@ -612,6 +612,16 @@ void regclass_InferRequest(py::module m) {
:rtype: List[openvino.runtime.VariableState]
)");
cls.def(
"reset_state",
[](InferRequestWrapper& self) {
return self.m_request.reset_state();
},
R"(
Resets all internal variable states for relevant infer request to
a value specified as default for the corresponding `ReadValue` node
)");
cls.def(
"get_compiled_model",
[](InferRequestWrapper& self) {

View File

@ -395,12 +395,12 @@ def test_get_compiled_model(device):
[np.float32,
np.int32,
np.float16])
@pytest.mark.parametrize("mode", ["set_init_memory_state", "reset_memory_state", "normal"])
@pytest.mark.parametrize("mode", ["set_init_memory_state", "reset_memory_state", "normal", "reset_via_infer_request"])
@pytest.mark.parametrize("input_shape", [[10], [10, 10], [10, 10, 10], [2, 10, 10, 10]])
@pytest.mark.skipif(
os.environ.get("TEST_DEVICE", "CPU") != "CPU",
os.environ.get("TEST_DEVICE", "CPU") not in ["CPU", "GPU"],
reason=f"Can't run test on device {os.environ.get('TEST_DEVICE', 'CPU')}, "
"Memory layers fully supported only on CPU",
"Memory layers fully supported only on CPU and GPU",
)
def test_query_state_write_buffer(device, input_shape, data_type, mode):
core = Core()
@ -434,6 +434,12 @@ def test_query_state_write_buffer(device, input_shape, data_type, mode):
res = request.infer({0: np.full(input_shape, 1, dtype=data_type)})
# always ones
expected_res = np.full(input_shape, 1, dtype=data_type)
elif mode == "reset_via_infer_request":
# reset initial state of ReadValue to zero
request.reset_state()
res = request.infer({0: np.full(input_shape, 1, dtype=data_type)})
# always ones
expected_res = np.full(input_shape, 1, dtype=data_type)
else:
res = request.infer({0: np.full(input_shape, 1, dtype=data_type)})
expected_res = np.full(input_shape, i, dtype=data_type)