[PYTHON] Allow ports as inputs for infer methods (#9839)
* Use of ports for input values * Refactor tests * Fix inputs overwritting * Fixed for review comments * Remove unused branch and refactor ie_api.py
This commit is contained in:
parent
f45991bd64
commit
1e58c55678
@ -3,14 +3,14 @@
|
||||
|
||||
import numpy as np
|
||||
import copy
|
||||
from typing import Any, List, Union
|
||||
from typing import Any, List, Type, Union
|
||||
|
||||
from openvino.pyopenvino import Model
|
||||
from openvino.pyopenvino import Core as CoreBase
|
||||
from openvino.pyopenvino import CompiledModel as CompiledModelBase
|
||||
from openvino.pyopenvino import InferRequest as InferRequestBase
|
||||
from openvino.pyopenvino import AsyncInferQueue as AsyncInferQueueBase
|
||||
from openvino.pyopenvino import Output
|
||||
from openvino.pyopenvino import ConstOutput
|
||||
from openvino.pyopenvino import Tensor
|
||||
from openvino.pyopenvino import OVAny as OVAnyBase
|
||||
|
||||
@ -22,23 +22,43 @@ def tensor_from_file(path: str) -> Tensor:
|
||||
return Tensor(np.fromfile(path, dtype=np.uint8))
|
||||
|
||||
|
||||
def normalize_inputs(inputs: Union[dict, list], py_types: dict) -> dict:
|
||||
"""Normalize a dictionary of inputs to Tensors."""
|
||||
if isinstance(inputs, list):
|
||||
inputs = {index: input for index, input in enumerate(inputs)}
|
||||
def convert_dict_items(inputs: dict, py_types: dict) -> dict:
|
||||
"""Helper function converting dictionary items to Tensors."""
|
||||
# Create new temporary dictionary.
|
||||
# new_inputs will be used to transfer data to inference calls,
|
||||
# ensuring that original inputs are not overwritten with Tensors.
|
||||
new_inputs = {}
|
||||
for k, val in inputs.items():
|
||||
if not isinstance(k, (str, int)):
|
||||
raise TypeError("Incompatible key type for tensor named: {}".format(k))
|
||||
if not isinstance(k, (str, int, ConstOutput)):
|
||||
raise TypeError("Incompatible key type for tensor: {}".format(k))
|
||||
try:
|
||||
ov_type = py_types[k]
|
||||
except KeyError:
|
||||
raise KeyError("Port for tensor named {} was not found!".format(k))
|
||||
inputs[k] = (
|
||||
raise KeyError("Port for tensor {} was not found!".format(k))
|
||||
# Convert numpy arrays or copy Tensors
|
||||
new_inputs[k] = (
|
||||
val
|
||||
if isinstance(val, Tensor)
|
||||
else Tensor(np.array(val, get_dtype(ov_type)))
|
||||
)
|
||||
return inputs
|
||||
return new_inputs
|
||||
|
||||
|
||||
def normalize_inputs(inputs: Union[dict, list], py_types: dict) -> dict:
|
||||
"""Normalize a dictionary of inputs to Tensors."""
|
||||
if isinstance(inputs, dict):
|
||||
return convert_dict_items(inputs, py_types)
|
||||
elif isinstance(inputs, list):
|
||||
# Lists are required to be represented as dictionaries with int keys
|
||||
return convert_dict_items(
|
||||
{index: input for index, input in enumerate(inputs)}, py_types
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
"Inputs should be either list or dict! Current type: {}".format(
|
||||
type(inputs)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_input_types(obj: Union[InferRequestBase, CompiledModelBase]) -> dict:
|
||||
@ -47,13 +67,18 @@ def get_input_types(obj: Union[InferRequestBase, CompiledModelBase]) -> dict:
|
||||
def get_inputs(obj: Union[InferRequestBase, CompiledModelBase]) -> list:
|
||||
return obj.model_inputs if isinstance(obj, InferRequestBase) else obj.inputs
|
||||
|
||||
def map_tensor_names_to_types(input: Output) -> dict:
|
||||
def map_tensor_names_to_types(input: ConstOutput) -> dict:
|
||||
return {n: input.get_element_type() for n in input.get_names()}
|
||||
|
||||
input_types: dict = {}
|
||||
for idx, input in enumerate(get_inputs(obj)):
|
||||
input_types.update(map_tensor_names_to_types(input))
|
||||
# Add all possible "accessing aliases" to dictionary
|
||||
# Key as a ConstOutput port
|
||||
input_types[input] = input.get_element_type()
|
||||
# Key as an integer
|
||||
input_types[idx] = input.get_element_type()
|
||||
# Multiple possible keys as Tensor names
|
||||
input_types.update(map_tensor_names_to_types(input))
|
||||
return input_types
|
||||
|
||||
|
||||
@ -62,17 +87,18 @@ class InferRequest(InferRequestBase):
|
||||
|
||||
def infer(self, inputs: Union[dict, list] = None) -> dict:
|
||||
"""Infer wrapper for InferRequest."""
|
||||
inputs = (
|
||||
return super().infer(
|
||||
{} if inputs is None else normalize_inputs(inputs, get_input_types(self))
|
||||
)
|
||||
return super().infer(inputs)
|
||||
|
||||
def start_async(self, inputs: Union[dict, list] = None, userdata: Any = None) -> None:
|
||||
def start_async(
|
||||
self, inputs: Union[dict, list] = None, userdata: Any = None
|
||||
) -> None:
|
||||
"""Asynchronous infer wrapper for InferRequest."""
|
||||
inputs = (
|
||||
{} if inputs is None else normalize_inputs(inputs, get_input_types(self))
|
||||
super().start_async(
|
||||
{} if inputs is None else normalize_inputs(inputs, get_input_types(self)),
|
||||
userdata,
|
||||
)
|
||||
super().start_async(inputs, userdata)
|
||||
|
||||
|
||||
class CompiledModel(CompiledModelBase):
|
||||
@ -84,10 +110,9 @@ class CompiledModel(CompiledModelBase):
|
||||
|
||||
def infer_new_request(self, inputs: Union[dict, list] = None) -> dict:
|
||||
"""Infer wrapper for CompiledModel."""
|
||||
inputs = (
|
||||
return super().infer_new_request(
|
||||
{} if inputs is None else normalize_inputs(inputs, get_input_types(self))
|
||||
)
|
||||
return super().infer_new_request(inputs)
|
||||
|
||||
def __call__(self, inputs: Union[dict, list] = None) -> dict:
|
||||
"""Callable infer wrapper for CompiledModel."""
|
||||
@ -101,16 +126,18 @@ class AsyncInferQueue(AsyncInferQueueBase):
|
||||
"""Return i-th InferRequest from AsyncInferQueue."""
|
||||
return InferRequest(super().__getitem__(i))
|
||||
|
||||
def start_async(self, inputs: Union[dict, list] = None, userdata: Any = None) -> None:
|
||||
def start_async(
|
||||
self, inputs: Union[dict, list] = None, userdata: Any = None
|
||||
) -> None:
|
||||
"""Asynchronous infer wrapper for AsyncInferQueue."""
|
||||
inputs = (
|
||||
super().start_async(
|
||||
{}
|
||||
if inputs is None
|
||||
else normalize_inputs(
|
||||
inputs, get_input_types(self[self.get_idle_request_id()])
|
||||
)
|
||||
),
|
||||
userdata,
|
||||
)
|
||||
super().start_async(inputs, userdata)
|
||||
|
||||
|
||||
class Core(CoreBase):
|
||||
|
@ -162,10 +162,15 @@ const Containers::TensorIndexMap cast_to_tensor_index_map(const py::dict& inputs
|
||||
void set_request_tensors(ov::InferRequest& request, const py::dict& inputs) {
|
||||
if (!inputs.empty()) {
|
||||
for (auto&& input : inputs) {
|
||||
if (py::isinstance<py::str>(input.first)) {
|
||||
request.set_tensor(input.first.cast<std::string>(), Common::cast_to_tensor(input.second));
|
||||
// Cast second argument to tensor
|
||||
auto tensor = Common::cast_to_tensor(input.second);
|
||||
// Check if key is compatible, should be port/string/integer
|
||||
if (py::isinstance<ov::Output<const ov::Node>>(input.first)) {
|
||||
request.set_tensor(input.first.cast<ov::Output<const ov::Node>>(), tensor);
|
||||
} else if (py::isinstance<py::str>(input.first)) {
|
||||
request.set_tensor(input.first.cast<std::string>(), tensor);
|
||||
} else if (py::isinstance<py::int_>(input.first)) {
|
||||
request.set_input_tensor(input.first.cast<size_t>(), Common::cast_to_tensor(input.second));
|
||||
request.set_input_tensor(input.first.cast<size_t>(), tensor);
|
||||
} else {
|
||||
throw py::type_error("Incompatible key type for tensor named: " + input.first.cast<std::string>());
|
||||
}
|
||||
|
@ -267,7 +267,7 @@ def test_infer_new_request_wrong_port_name(device):
|
||||
exec_net = ie.compile_model(func, device)
|
||||
with pytest.raises(KeyError) as e:
|
||||
exec_net.infer_new_request({"_data_": tensor})
|
||||
assert "Port for tensor named _data_ was not found!" in str(e.value)
|
||||
assert "Port for tensor _data_ was not found!" in str(e.value)
|
||||
|
||||
|
||||
def test_infer_tensor_wrong_input_data(device):
|
||||
@ -279,7 +279,7 @@ def test_infer_tensor_wrong_input_data(device):
|
||||
exec_net = ie.compile_model(func, device)
|
||||
with pytest.raises(TypeError) as e:
|
||||
exec_net.infer_new_request({0.: tensor})
|
||||
assert "Incompatible key type for tensor named: 0." in str(e.value)
|
||||
assert "Incompatible key type for tensor: 0." in str(e.value)
|
||||
|
||||
|
||||
def test_infer_numpy_model_from_buffer(device):
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from copy import deepcopy
|
||||
import numpy as np
|
||||
import os
|
||||
import pytest
|
||||
@ -17,24 +18,40 @@ is_myriad = os.environ.get("TEST_DEVICE") == "MYRIAD"
|
||||
test_net_xml, test_net_bin = model_path(is_myriad)
|
||||
|
||||
|
||||
def create_function_with_memory(input_shape, data_type):
|
||||
def create_model_with_memory(input_shape, data_type):
|
||||
input_data = ops.parameter(input_shape, name="input_data", dtype=data_type)
|
||||
rv = ops.read_value(input_data, "var_id_667")
|
||||
add = ops.add(rv, input_data, name="MemoryAdd")
|
||||
node = ops.assign(add, "var_id_667")
|
||||
res = ops.result(add, "res")
|
||||
func = Model(results=[res], sinks=[node], parameters=[input_data], name="name")
|
||||
return func
|
||||
model = Model(results=[res], sinks=[node], parameters=[input_data], name="name")
|
||||
return model
|
||||
|
||||
|
||||
def create_simple_request_and_inputs(device):
|
||||
input_shape = [2, 2]
|
||||
param_a = ops.parameter(input_shape, np.float32)
|
||||
param_b = ops.parameter(input_shape, np.float32)
|
||||
model = Model(ops.add(param_a, param_b), [param_a, param_b])
|
||||
|
||||
core = Core()
|
||||
compiled = core.compile_model(model, device)
|
||||
request = compiled.create_infer_request()
|
||||
|
||||
arr_1 = np.array([[1, 2], [3, 4]], dtype=np.float32)
|
||||
arr_2 = np.array([[3, 4], [1, 2]], dtype=np.float32)
|
||||
|
||||
return request, arr_1, arr_2
|
||||
|
||||
|
||||
def test_get_profiling_info(device):
|
||||
core = Core()
|
||||
func = core.read_model(test_net_xml, test_net_bin)
|
||||
model = core.read_model(test_net_xml, test_net_bin)
|
||||
core.set_config({"PERF_COUNT": "YES"}, device)
|
||||
exec_net = core.compile_model(func, device)
|
||||
compiled = core.compile_model(model, device)
|
||||
img = read_image()
|
||||
request = exec_net.create_infer_request()
|
||||
tensor_name = exec_net.input("data").any_name
|
||||
request = compiled.create_infer_request()
|
||||
tensor_name = compiled.input("data").any_name
|
||||
request.infer({tensor_name: img})
|
||||
assert request.latency > 0
|
||||
prof_info = request.get_profiling_info()
|
||||
@ -48,14 +65,15 @@ def test_get_profiling_info(device):
|
||||
|
||||
def test_tensor_setter(device):
|
||||
core = Core()
|
||||
func = core.read_model(test_net_xml, test_net_bin)
|
||||
exec_net_1 = core.compile_model(model=func, device_name=device)
|
||||
exec_net_2 = core.compile_model(model=func, device_name=device)
|
||||
model = core.read_model(test_net_xml, test_net_bin)
|
||||
compiled_1 = core.compile_model(model=model, device_name=device)
|
||||
compiled_2 = core.compile_model(model=model, device_name=device)
|
||||
compiled_3 = core.compile_model(model=model, device_name=device)
|
||||
|
||||
img = read_image()
|
||||
tensor = Tensor(img)
|
||||
|
||||
request1 = exec_net_1.create_infer_request()
|
||||
request1 = compiled_1.create_infer_request()
|
||||
request1.set_tensor("data", tensor)
|
||||
t1 = request1.get_tensor("data")
|
||||
|
||||
@ -67,7 +85,7 @@ def test_tensor_setter(device):
|
||||
t2 = request1.get_tensor("fc_out")
|
||||
assert np.allclose(t2.data, res[k].data, atol=1e-2, rtol=1e-2)
|
||||
|
||||
request = exec_net_2.create_infer_request()
|
||||
request = compiled_2.create_infer_request()
|
||||
res = request.infer({"data": tensor})
|
||||
res_2 = np.sort(request.get_tensor("fc_out").data)
|
||||
assert np.allclose(res_1, res_2, atol=1e-2, rtol=1e-2)
|
||||
@ -76,11 +94,23 @@ def test_tensor_setter(device):
|
||||
t3 = request.get_tensor("data")
|
||||
assert np.allclose(t3.data, t1.data, atol=1e-2, rtol=1e-2)
|
||||
|
||||
request = compiled_3.create_infer_request()
|
||||
request.set_tensor(model.inputs[0], tensor)
|
||||
t1 = request1.get_tensor(model.inputs[0])
|
||||
|
||||
assert np.allclose(tensor.data, t1.data, atol=1e-2, rtol=1e-2)
|
||||
|
||||
res = request.infer()
|
||||
k = list(res)[0]
|
||||
res_1 = np.sort(res[k])
|
||||
t2 = request1.get_tensor(model.outputs[0])
|
||||
assert np.allclose(t2.data, res[k].data, atol=1e-2, rtol=1e-2)
|
||||
|
||||
|
||||
def test_set_tensors(device):
|
||||
core = Core()
|
||||
func = core.read_model(test_net_xml, test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
model = core.read_model(test_net_xml, test_net_bin)
|
||||
compiled = core.compile_model(model, device)
|
||||
|
||||
data1 = read_image()
|
||||
tensor1 = Tensor(data1)
|
||||
@ -91,7 +121,7 @@ def test_set_tensors(device):
|
||||
data4 = np.zeros(shape=(1, 10), dtype=np.float32)
|
||||
tensor4 = Tensor(data4)
|
||||
|
||||
request = exec_net.create_infer_request()
|
||||
request = compiled.create_infer_request()
|
||||
request.set_tensors({"data": tensor1, "fc_out": tensor2})
|
||||
t1 = request.get_tensor("data")
|
||||
t2 = request.get_tensor("fc_out")
|
||||
@ -99,16 +129,16 @@ def test_set_tensors(device):
|
||||
assert np.allclose(tensor2.data, t2.data, atol=1e-2, rtol=1e-2)
|
||||
|
||||
request.set_output_tensors({0: tensor2})
|
||||
output_node = exec_net.outputs[0]
|
||||
output_node = compiled.outputs[0]
|
||||
t3 = request.get_tensor(output_node)
|
||||
assert np.allclose(tensor2.data, t3.data, atol=1e-2, rtol=1e-2)
|
||||
|
||||
request.set_input_tensors({0: tensor1})
|
||||
output_node = exec_net.inputs[0]
|
||||
output_node = compiled.inputs[0]
|
||||
t4 = request.get_tensor(output_node)
|
||||
assert np.allclose(tensor1.data, t4.data, atol=1e-2, rtol=1e-2)
|
||||
|
||||
output_node = exec_net.inputs[0]
|
||||
output_node = compiled.inputs[0]
|
||||
request.set_tensor(output_node, tensor3)
|
||||
t5 = request.get_tensor(output_node)
|
||||
assert np.allclose(tensor3.data, t5.data, atol=1e-2, rtol=1e-2)
|
||||
@ -148,10 +178,10 @@ def test_inputs_outputs_property(device):
|
||||
|
||||
def test_cancel(device):
|
||||
core = Core()
|
||||
func = core.read_model(test_net_xml, test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
model = core.read_model(test_net_xml, test_net_bin)
|
||||
compiled = core.compile_model(model, device)
|
||||
img = read_image()
|
||||
request = exec_net.create_infer_request()
|
||||
request = compiled.create_infer_request()
|
||||
|
||||
request.start_async({0: img})
|
||||
request.cancel()
|
||||
@ -168,13 +198,13 @@ def test_cancel(device):
|
||||
|
||||
def test_start_async(device):
|
||||
core = Core()
|
||||
func = core.read_model(test_net_xml, test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
model = core.read_model(test_net_xml, test_net_bin)
|
||||
compiled = core.compile_model(model, device)
|
||||
img = read_image()
|
||||
jobs = 3
|
||||
requests = []
|
||||
for _ in range(jobs):
|
||||
requests.append(exec_net.create_infer_request())
|
||||
requests.append(compiled.create_infer_request())
|
||||
|
||||
def callback(callbacks_info):
|
||||
time.sleep(0.01)
|
||||
@ -217,9 +247,9 @@ def test_infer_list_as_inputs(device):
|
||||
|
||||
def test_infer_mixed_keys(device):
|
||||
core = Core()
|
||||
func = core.read_model(test_net_xml, test_net_bin)
|
||||
model = core.read_model(test_net_xml, test_net_bin)
|
||||
core.set_config({"PERF_COUNT": "YES"}, device)
|
||||
model = core.compile_model(func, device)
|
||||
model = core.compile_model(model, device)
|
||||
|
||||
img = read_image()
|
||||
tensor = Tensor(img)
|
||||
@ -236,9 +266,9 @@ def test_infer_queue(device):
|
||||
jobs = 8
|
||||
num_request = 4
|
||||
core = Core()
|
||||
func = core.read_model(test_net_xml, test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
infer_queue = AsyncInferQueue(exec_net, num_request)
|
||||
model = core.read_model(test_net_xml, test_net_bin)
|
||||
compiled = core.compile_model(model, device)
|
||||
infer_queue = AsyncInferQueue(compiled, num_request)
|
||||
jobs_done = [{"finished": False, "latency": 0} for _ in range(jobs)]
|
||||
|
||||
def callback(request, job_id):
|
||||
@ -255,13 +285,13 @@ def test_infer_queue(device):
|
||||
assert all(job["latency"] > 0 for job in jobs_done)
|
||||
|
||||
|
||||
def test_infer_queue_fail_on_cpp_func(device):
|
||||
def test_infer_queue_fail_on_cpp_model(device):
|
||||
jobs = 6
|
||||
num_request = 4
|
||||
core = Core()
|
||||
func = core.read_model(test_net_xml, test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
infer_queue = AsyncInferQueue(exec_net, num_request)
|
||||
model = core.read_model(test_net_xml, test_net_bin)
|
||||
compiled = core.compile_model(model, device)
|
||||
infer_queue = AsyncInferQueue(compiled, num_request)
|
||||
|
||||
def callback(request, _):
|
||||
request.get_tensor("Unknown")
|
||||
@ -278,13 +308,13 @@ def test_infer_queue_fail_on_cpp_func(device):
|
||||
assert "Port for tensor name Unknown was not found" in str(e.value)
|
||||
|
||||
|
||||
def test_infer_queue_fail_on_py_func(device):
|
||||
def test_infer_queue_fail_on_py_model(device):
|
||||
jobs = 1
|
||||
num_request = 1
|
||||
core = Core()
|
||||
func = core.read_model(test_net_xml, test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
infer_queue = AsyncInferQueue(exec_net, num_request)
|
||||
model = core.read_model(test_net_xml, test_net_bin)
|
||||
compiled = core.compile_model(model, device)
|
||||
infer_queue = AsyncInferQueue(compiled, num_request)
|
||||
|
||||
def callback(request, _):
|
||||
request = request + 21
|
||||
@ -321,9 +351,9 @@ def test_query_state_write_buffer(device, input_shape, data_type, mode):
|
||||
from openvino.runtime import Tensor
|
||||
from openvino.runtime.utils.types import get_dtype
|
||||
|
||||
function = create_function_with_memory(input_shape, data_type)
|
||||
exec_net = core.compile_model(model=function, device_name=device)
|
||||
request = exec_net.create_infer_request()
|
||||
model = create_model_with_memory(input_shape, data_type)
|
||||
compiled = core.compile_model(model=model, device_name=device)
|
||||
request = compiled.create_infer_request()
|
||||
mem_states = request.query_state()
|
||||
mem_state = mem_states[0]
|
||||
|
||||
@ -372,9 +402,9 @@ def test_results_async_infer(device):
|
||||
jobs = 8
|
||||
num_request = 4
|
||||
core = Core()
|
||||
func = core.read_model(test_net_xml, test_net_bin)
|
||||
exec_net = core.compile_model(func, device)
|
||||
infer_queue = AsyncInferQueue(exec_net, num_request)
|
||||
model = core.read_model(test_net_xml, test_net_bin)
|
||||
compiled = core.compile_model(model, device)
|
||||
infer_queue = AsyncInferQueue(compiled, num_request)
|
||||
jobs_done = [{"finished": False, "latency": 0} for _ in range(jobs)]
|
||||
|
||||
def callback(request, job_id):
|
||||
@ -388,7 +418,7 @@ def test_results_async_infer(device):
|
||||
infer_queue.start_async({"data": img}, i)
|
||||
infer_queue.wait_all()
|
||||
|
||||
request = exec_net.create_infer_request()
|
||||
request = compiled.create_infer_request()
|
||||
outputs = request.infer({0: img})
|
||||
|
||||
for i in range(num_request):
|
||||
@ -458,8 +488,8 @@ def test_infer_float16(device):
|
||||
</edges>
|
||||
</net>""")
|
||||
core = Core()
|
||||
func = core.read_model(model=model)
|
||||
p = PrePostProcessor(func)
|
||||
model = core.read_model(model=model)
|
||||
p = PrePostProcessor(model)
|
||||
p.input(0).tensor().set_element_type(Type.f16)
|
||||
p.input(0).preprocess().convert_element_type(Type.f16)
|
||||
p.input(1).tensor().set_element_type(Type.f16)
|
||||
@ -467,10 +497,67 @@ def test_infer_float16(device):
|
||||
p.output(0).tensor().set_element_type(Type.f16)
|
||||
p.output(0).postprocess().convert_element_type(Type.f16)
|
||||
|
||||
func = p.build()
|
||||
exec_net = core.compile_model(func, device)
|
||||
model = p.build()
|
||||
compiled = core.compile_model(model, device)
|
||||
input_data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]).astype(np.float16)
|
||||
request = exec_net.create_infer_request()
|
||||
request = compiled.create_infer_request()
|
||||
outputs = request.infer({0: input_data, 1: input_data})
|
||||
assert np.allclose(list(outputs.values()), list(request.results.values()))
|
||||
assert np.allclose(list(outputs.values()), input_data + input_data)
|
||||
|
||||
|
||||
def test_ports_as_inputs(device):
|
||||
input_shape = [2, 2]
|
||||
param_a = ops.parameter(input_shape, np.float32)
|
||||
param_b = ops.parameter(input_shape, np.float32)
|
||||
model = Model(ops.add(param_a, param_b), [param_a, param_b])
|
||||
|
||||
core = Core()
|
||||
compiled = core.compile_model(model, device)
|
||||
request = compiled.create_infer_request()
|
||||
|
||||
arr_1 = np.array([[1, 2], [3, 4]], dtype=np.float32)
|
||||
arr_2 = np.array([[3, 4], [1, 2]], dtype=np.float32)
|
||||
|
||||
tensor1 = Tensor(arr_1)
|
||||
tensor2 = Tensor(arr_2)
|
||||
|
||||
res = request.infer({compiled.inputs[0]: tensor1, compiled.inputs[1]: tensor2})
|
||||
assert np.array_equal(res[compiled.outputs[0]], tensor1.data + tensor2.data)
|
||||
|
||||
res = request.infer({request.model_inputs[0]: tensor1, request.model_inputs[1]: tensor2})
|
||||
assert np.array_equal(res[request.model_outputs[0]], tensor1.data + tensor2.data)
|
||||
|
||||
|
||||
def test_inputs_dict_not_replaced(device):
|
||||
request, arr_1, arr_2 = create_simple_request_and_inputs(device)
|
||||
|
||||
inputs = {0: arr_1, 1: arr_2}
|
||||
inputs_copy = deepcopy(inputs)
|
||||
|
||||
res = request.infer(inputs)
|
||||
|
||||
np.testing.assert_equal(inputs, inputs_copy)
|
||||
assert np.array_equal(res[request.model_outputs[0]], arr_1 + arr_2)
|
||||
|
||||
|
||||
def test_inputs_list_not_replaced(device):
|
||||
request, arr_1, arr_2 = create_simple_request_and_inputs(device)
|
||||
|
||||
inputs = [arr_1, arr_2]
|
||||
inputs_copy = deepcopy(inputs)
|
||||
|
||||
res = request.infer(inputs)
|
||||
|
||||
assert np.array_equal(inputs, inputs_copy)
|
||||
assert np.array_equal(res[request.model_outputs[0]], arr_1 + arr_2)
|
||||
|
||||
|
||||
def test_invalid_inputs_container(device):
|
||||
request, arr_1, arr_2 = create_simple_request_and_inputs(device)
|
||||
|
||||
inputs = (arr_1, arr_2)
|
||||
|
||||
with pytest.raises(TypeError) as e:
|
||||
request.infer(inputs)
|
||||
assert "Inputs should be either list or dict! Current type:" in str(e.value)
|
||||
|
Loading…
Reference in New Issue
Block a user