Add reset_state method to Python API (#21660)
* Add reset_state method to Python API * fix missprint * codestyle * extend python tests * delete debug code * resolve review comments
This commit is contained in:
@@ -612,6 +612,16 @@ void regclass_InferRequest(py::module m) {
|
||||
:rtype: List[openvino.runtime.VariableState]
|
||||
)");
|
||||
|
||||
cls.def(
|
||||
"reset_state",
|
||||
[](InferRequestWrapper& self) {
|
||||
return self.m_request.reset_state();
|
||||
},
|
||||
R"(
|
||||
Resets all internal variable states for relevant infer request to
|
||||
a value specified as default for the corresponding `ReadValue` node
|
||||
)");
|
||||
|
||||
cls.def(
|
||||
"get_compiled_model",
|
||||
[](InferRequestWrapper& self) {
|
||||
|
||||
@@ -395,12 +395,12 @@ def test_get_compiled_model(device):
|
||||
[np.float32,
|
||||
np.int32,
|
||||
np.float16])
|
||||
@pytest.mark.parametrize("mode", ["set_init_memory_state", "reset_memory_state", "normal"])
|
||||
@pytest.mark.parametrize("mode", ["set_init_memory_state", "reset_memory_state", "normal", "reset_via_infer_request"])
|
||||
@pytest.mark.parametrize("input_shape", [[10], [10, 10], [10, 10, 10], [2, 10, 10, 10]])
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("TEST_DEVICE", "CPU") != "CPU",
|
||||
os.environ.get("TEST_DEVICE", "CPU") not in ["CPU", "GPU"],
|
||||
reason=f"Can't run test on device {os.environ.get('TEST_DEVICE', 'CPU')}, "
|
||||
"Memory layers fully supported only on CPU",
|
||||
"Memory layers fully supported only on CPU and GPU",
|
||||
)
|
||||
def test_query_state_write_buffer(device, input_shape, data_type, mode):
|
||||
core = Core()
|
||||
@@ -434,6 +434,12 @@ def test_query_state_write_buffer(device, input_shape, data_type, mode):
|
||||
res = request.infer({0: np.full(input_shape, 1, dtype=data_type)})
|
||||
# always ones
|
||||
expected_res = np.full(input_shape, 1, dtype=data_type)
|
||||
elif mode == "reset_via_infer_request":
|
||||
# reset initial state of ReadValue to zero
|
||||
request.reset_state()
|
||||
res = request.infer({0: np.full(input_shape, 1, dtype=data_type)})
|
||||
# always ones
|
||||
expected_res = np.full(input_shape, 1, dtype=data_type)
|
||||
else:
|
||||
res = request.infer({0: np.full(input_shape, 1, dtype=data_type)})
|
||||
expected_res = np.full(input_shape, i, dtype=data_type)
|
||||
|
||||
Reference in New Issue
Block a user