[PyOV] Remove deprecated py api (#21675)

* [PyOV][Draft] Remove deprecated py api

* clean up

* remove symbol

* fix failure

* fix ci

* remove shared_memory from arguments

* update tests
This commit is contained in:
Anastasia Kuporosova 2023-12-21 16:57:22 +01:00 committed by GitHub
parent 243602929f
commit d10f49441d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 9 additions and 209 deletions

View File

@ -6,7 +6,7 @@
from openvino._pyopenvino.properties.hint import Priority
from openvino._pyopenvino.properties.hint import SchedulingCoreType
from openvino._pyopenvino.properties.hint import ExecutionMode
from openvino.runtime.properties.hint.overloads import PerformanceMode
from openvino._pyopenvino.properties.hint import PerformanceMode
# Properties
import openvino._pyopenvino.properties.hint as __hint

View File

@ -23,19 +23,6 @@ from openvino.runtime.utils.data_helpers import (
)
def _deprecated_memory_arg(shared_memory: bool, share_inputs: bool) -> bool:
if shared_memory is not None:
warnings.warn(
"`shared_memory` is deprecated and will be removed in 2024.0. "
"Value of `shared_memory` is going to override `share_inputs` value. "
"Please use only `share_inputs` explicitly.",
FutureWarning,
stacklevel=3,
)
return shared_memory
return share_inputs
class Model(ModelBase):
def __init__(self, *args: Any, **kwargs: Any) -> None:
if args and not kwargs:
@ -70,8 +57,6 @@ class InferRequest(_InferRequestWrapper):
inputs: Any = None,
share_inputs: bool = False,
share_outputs: bool = False,
*,
shared_memory: Any = None,
) -> OVDict:
"""Infers specified input(s) in synchronous mode.
@ -129,22 +114,14 @@ class InferRequest(_InferRequestWrapper):
Default value: False
:type share_outputs: bool, optional
:param shared_memory: Deprecated. Works like `share_inputs` mode.
If not specified, function uses `share_inputs` value.
Note: Will be removed in 2024.0 release!
Note: This is keyword-only argument.
Default value: None
:type shared_memory: bool, optional
:return: Dictionary of results from output tensors with port/int/str keys.
:rtype: OVDict
"""
return OVDict(super().infer(_data_dispatch(
self,
inputs,
is_shared=_deprecated_memory_arg(shared_memory, share_inputs),
is_shared=share_inputs,
), share_outputs=share_outputs))
def start_async(
@ -152,8 +129,6 @@ class InferRequest(_InferRequestWrapper):
inputs: Any = None,
userdata: Any = None,
share_inputs: bool = False,
*,
shared_memory: Any = None,
) -> None:
"""Starts inference of specified input(s) in asynchronous mode.
@ -202,21 +177,12 @@ class InferRequest(_InferRequestWrapper):
Default value: False
:type share_inputs: bool, optional
:param shared_memory: Deprecated. Works like `share_inputs` mode.
If not specified, function uses `share_inputs` value.
Note: Will be removed in 2024.0 release!
Note: This is keyword-only argument.
Default value: None
:type shared_memory: bool, optional
"""
super().start_async(
_data_dispatch(
self,
inputs,
is_shared=_deprecated_memory_arg(shared_memory, share_inputs),
is_shared=share_inputs,
),
userdata,
)
@ -302,8 +268,6 @@ class CompiledModel(CompiledModelBase):
inputs: Any = None,
share_inputs: bool = True,
share_outputs: bool = False,
*,
shared_memory: Any = None,
) -> OVDict:
"""Callable infer wrapper for CompiledModel.
@ -369,15 +333,7 @@ class CompiledModel(CompiledModelBase):
Default value: False
:type share_outputs: bool, optional
:param shared_memory: Deprecated. Works like `share_inputs` mode.
If not specified, function uses `share_inputs` value.
Note: Will be removed in 2024.0 release!
Note: This is keyword-only argument.
Default value: None
:type shared_memory: bool, optional
:return: Dictionary of results from output tensors with port/int/str as keys.
:rtype: OVDict
"""
@ -386,7 +342,7 @@ class CompiledModel(CompiledModelBase):
return self._infer_request.infer(
inputs,
share_inputs=_deprecated_memory_arg(shared_memory, share_inputs),
share_inputs=share_inputs,
share_outputs=share_outputs,
)
@ -430,8 +386,6 @@ class AsyncInferQueue(AsyncInferQueueBase):
inputs: Any = None,
userdata: Any = None,
share_inputs: bool = False,
*,
shared_memory: Any = None,
) -> None:
"""Run asynchronous inference using the next available InferRequest from the pool.
@ -476,21 +430,12 @@ class AsyncInferQueue(AsyncInferQueueBase):
Default value: False
:type share_inputs: bool, optional
:param shared_memory: Deprecated. Works like `share_inputs` mode.
If not specified, function uses `share_inputs` value.
Note: Will be removed in 2024.0 release!
Note: This is keyword-only argument.
Default value: None
:type shared_memory: bool, optional
"""
super().start_async(
_data_dispatch(
self[self.get_idle_request_id()],
inputs,
is_shared=_deprecated_memory_arg(shared_memory, share_inputs),
is_shared=share_inputs,
),
userdata,
)

View File

@ -6,7 +6,7 @@
from openvino._pyopenvino.properties.hint import Priority
from openvino._pyopenvino.properties.hint import SchedulingCoreType
from openvino._pyopenvino.properties.hint import ExecutionMode
from openvino.runtime.properties.hint.overloads import PerformanceMode
from openvino._pyopenvino.properties.hint import PerformanceMode
# Properties
from openvino._pyopenvino.properties.hint import inference_precision

View File

@ -1,19 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.utils import deprecatedclassproperty
from openvino._pyopenvino.properties.hint import PerformanceMode as PerformanceModeBase
class PerformanceMode(PerformanceModeBase):
@deprecatedclassproperty(
name="PerformanceMode.UNDEFINED", # noqa: N802, N805
version="2024.0",
message="Please use actual value instead.",
stacklevel=2,
)
def UNDEFINED(cls) -> PerformanceModeBase: # noqa: N802, N805
return super().UNDEFINED

View File

@ -4,6 +4,4 @@
"""Generic utilities. Factor related functions out to separate files."""
from openvino._pyopenvino.util import numpy_to_c
from openvino._pyopenvino.util import get_constant_from_source, replace_node, replace_output_update_name
from openvino.runtime.utils.util import clone_model
from openvino._pyopenvino.util import numpy_to_c, replace_node, replace_output_update_name

View File

@ -1,17 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino._pyopenvino.util import clone_model as clone_model_base
from openvino.utils import deprecated
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from openvino.runtime import Model
@deprecated(version="2024.0")
def clone_model(model: "Model") -> "Model":
from openvino.runtime import Model
return Model(clone_model_base(model))

View File

@ -667,30 +667,6 @@ void regclass_InferRequest(py::module m) {
:rtype: List[openvino.runtime.ConstOutput]
)");
cls.def_property_readonly(
"inputs",
[](InferRequestWrapper& self) {
Common::utils::deprecation_warning("inputs", "2024.0", "Please use 'input_tensors' property instead.");
return self.get_input_tensors();
},
R"(
Gets all input tensors of this InferRequest.
:rtype: List[openvino.runtime.Tensor]
)");
cls.def_property_readonly(
"outputs",
[](InferRequestWrapper& self) {
Common::utils::deprecation_warning("outputs", "2024.0", "Please use 'output_tensors' property instead.");
return self.get_output_tensors();
},
R"(
Gets all output tensors of this InferRequest.
:rtype: List[openvino.runtime.Tensor]
)");
cls.def_property_readonly("input_tensors",
&InferRequestWrapper::get_input_tensors,
R"(

View File

@ -55,12 +55,10 @@ void regmodule_properties(py::module m) {
.value("HIGH", ov::hint::Priority::HIGH)
.value("DEFAULT", ov::hint::Priority::DEFAULT);
OPENVINO_SUPPRESS_DEPRECATED_START
py::enum_<ov::hint::PerformanceMode>(m_hint, "PerformanceMode", py::arithmetic())
.value("LATENCY", ov::hint::PerformanceMode::LATENCY)
.value("THROUGHPUT", ov::hint::PerformanceMode::THROUGHPUT)
.value("CUMULATIVE_THROUGHPUT", ov::hint::PerformanceMode::CUMULATIVE_THROUGHPUT);
OPENVINO_SUPPRESS_DEPRECATED_END
py::enum_<ov::hint::SchedulingCoreType>(m_hint, "SchedulingCoreType", py::arithmetic())
.value("ANY_CORE", ov::hint::SchedulingCoreType::ANY_CORE)

View File

@ -27,35 +27,6 @@ inline void* numpy_to_c(py::array a) {
void regmodule_graph_util(py::module m) {
py::module mod = m.def_submodule("util", "openvino.runtime.utils");
mod.def("numpy_to_c", &numpy_to_c);
OPENVINO_SUPPRESS_DEPRECATED_START
mod.def("get_constant_from_source",
&ov::get_constant_from_source,
py::arg("output"),
R"(
Runs an estimation of source tensor.
:param index: Output node.
:type index: openvino.runtime.Output
:return: If it succeeded to calculate both bounds and
they are the same, returns Constant operation
from the resulting bound, otherwise Null.
:rtype: openvino.runtime.op.Constant or openvino.runtime.Node
)");
OPENVINO_SUPPRESS_DEPRECATED_END
mod.def(
"clone_model",
[](ov::Model& model) {
return model.clone();
},
py::arg("model"),
R"(
Creates a copy of a model object.
:param model: Model to copy.
:type model: openvino.runtime.Model
:return: A copy of Model.
:rtype: openvino.runtime.Model
)");
mod.def("replace_output_update_name", &ov::replace_output_update_name, py::arg("output"), py::arg("target_output"));

View File

@ -318,30 +318,22 @@ def test_clone_model():
assert isinstance(model_original, Model)
# Make copies of it
with pytest.deprecated_call():
model_copy1 = ov.utils.clone_model(model_original)
model_copy2 = model_original.clone()
model_copy3 = deepcopy(model_original)
assert isinstance(model_copy1, Model)
assert isinstance(model_copy2, Model)
assert isinstance(model_copy3, Model)
# Make changes to the copied models' inputs
model_copy1.reshape({"A": [3, 3], "B": [3, 3]})
model_copy2.reshape({"A": [3, 3], "B": [3, 3]})
model_copy3.reshape({"A": [3, 3], "B": [3, 3]})
original_model_shapes = [single_input.get_shape() for single_input in model_original.inputs]
model_copy1_shapes = [single_input.get_shape() for single_input in model_copy1.inputs]
model_copy2_shapes = [single_input.get_shape() for single_input in model_copy2.inputs]
model_copy3_shapes = [single_input.get_shape() for single_input in model_copy3.inputs]
assert original_model_shapes != model_copy1_shapes
assert original_model_shapes != model_copy2_shapes
assert original_model_shapes != model_copy3_shapes
assert model_copy1_shapes == model_copy2_shapes
assert model_copy1_shapes == model_copy3_shapes
assert model_copy2_shapes == model_copy3_shapes

View File

@ -2,30 +2,8 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import openvino.runtime as ov
import pytest
from openvino._pyopenvino.util import deprecation_warning
from openvino import Shape
def test_get_constant_from_source_success():
input1 = ov.opset8.parameter(Shape([5, 5]), dtype=int, name="input_1")
input2 = ov.opset8.parameter(Shape([25]), dtype=int, name="input_2")
shape_of = ov.opset8.shape_of(input2, name="shape_of")
reshape = ov.opset8.reshape(input1, shape_of, special_zero=True)
folded_const = ov.utils.get_constant_from_source(reshape.input(1).get_source_output())
assert folded_const is not None
assert folded_const.get_vector() == [25]
def test_get_constant_from_source_failed():
input1 = ov.opset8.parameter(Shape([5, 5]), dtype=int, name="input_1")
input2 = ov.opset8.parameter(Shape([1]), dtype=int, name="input_2")
reshape = ov.opset8.reshape(input1, input2, special_zero=True)
folded_const = ov.utils.get_constant_from_source(reshape.input(1).get_source_output())
assert folded_const is None
def test_deprecation_warning():

View File

@ -296,28 +296,6 @@ def test_array_like_input_async_infer_queue(device, share_inputs):
infer_queue_list[i].get_output_tensor().data, np.abs(input_data))
@pytest.mark.parametrize("shared_flag", [True, False])
def test_shared_memory_deprecation(device, shared_flag):
compiled, request, _, input_data = abs_model_with_data(
device, Type.f32, np.float32)
with pytest.warns(FutureWarning, match="`shared_memory` is deprecated and will be removed in 2024.0"):
_ = compiled(input_data, shared_memory=shared_flag)
with pytest.warns(FutureWarning, match="`shared_memory` is deprecated and will be removed in 2024.0"):
_ = request.infer(input_data, shared_memory=shared_flag)
with pytest.warns(FutureWarning, match="`shared_memory` is deprecated and will be removed in 2024.0"):
request.start_async(input_data, shared_memory=shared_flag)
request.wait()
queue = AsyncInferQueue(compiled, jobs=1)
with pytest.warns(FutureWarning, match="`shared_memory` is deprecated and will be removed in 2024.0"):
queue.start_async(input_data, shared_memory=shared_flag)
queue.wait_all()
@pytest.mark.skip(reason="Sporadically failed. Need further investigation. Ticket - 95967")
def test_cancel(device):
core = Core()

View File

@ -213,7 +213,7 @@ def test_direct_infer(device, shared_flag):
compiled_model, img = generate_model_and_image(device)
tensor = Tensor(img)
res = compiled_model({"data": tensor}, shared_memory=shared_flag)
res = compiled_model({"data": tensor}, share_inputs=shared_flag)
assert np.argmax(res[compiled_model.outputs[0]]) == 531
ref = compiled_model.infer_new_request({"data": tensor})
assert np.array_equal(ref[compiled_model.outputs[0]], res[compiled_model.outputs[0]])

View File

@ -245,7 +245,7 @@ def test_properties_ro(ov_property_ro, expected_value):
(
hints.performance_mode,
"PERFORMANCE_HINT",
((hints.PerformanceMode.THROUGHPUT, hints.PerformanceMode.THROUGHPUT),),
((hints.PerformanceMode.LATENCY, hints.PerformanceMode.LATENCY),),
),
(
hints.enable_cpu_pinning,