[IE PYTHON] dynamic shape api for python (#7282)

* Added nGraph as a public dependency

* Fixed Windows warning

* Fixed CMake

* Fixed constant op

* Fixed typo

* Added reshape to PartialShape to CNNNetwork

* Added SetShape to InferRequest

* Enable support of DynamicShape in IE Data

* Add support of dynamic shapes to template plugin

* Fixed build

* Enable support dynamic rank

* Fixed test for dynamic rank

* Fixed some tests

* Fix preprocess tests

* Fixed SetBlob

* Fixed code style

* Add more tests

* Fixed accuracy tests

* Fixed documentation

* Added tests for custom operation

* Added new tests

* WIP: move setShape from infer request to Blob

* Returned isApplicable check back

* Removed obsolete tests for InferRequest::SetShape and add new test for Blob::setShape (a trivial one)

* Fixed artifacts

* Break code style

* Revert "Break code style"

This reverts commit 71ee638cd0.

* Added -j8 for fix_all

* Applied code style fixes

* Adde doxygen items

* Fixed style

* Applied codestyle patch

* Reverted unrelevant commit with template extension

* Fixed cmake file for shared func tests (pick from master)

* Revert all changes in template_extension

* Deleted some old stuff that commented and won't be used

* Fixed wrong exception throwing

* Code style fix

* Fixed preprocessing part

* Fixed incorrect blob reshape in GetBlob

* Deleted incorrect assert in GAPI that prevents passing some tests in Debug

* Fixed issues identified during review

* Removed SetShape, replace getLayoutByDims by getLayoutByRank and removed problematic modification from IE preprocessing

* Fixed comments

* Removed obsolete setShape

* [VPU] Fixed allocating dynamic blobs in myriad_infer_request

* Fixed comments

* Fixed CNNNgraphImpl and comments

* add partial reshape for IENetwork

* Add getPartialShape

* Add setShape for Blob

* Add tests

* Add partial_shape property for CDataPtr

* Add partial_shape property for data

* Fix code style

* Fix code style

* Fix test

* Fix code style

* Fix code style

* Fix code style

* Add tests

* Add new lines

* Mark tests

* Fix tests

* call set_shape implicit  only for dynamic inputs

* skip tests on ARM

* remove ngraph dependence from ie_api.pyx

* expand only shape, not array in expand_dims_to_corresponding_layout

* Mark ngraph dependent tests

* remove skip inside test

* code refactoring

* add new line

* Add docstring

* Fix code style

Co-authored-by: Ilya Churaev <ilya.churaev@intel.com>
Co-authored-by: Lyalin, Sergey <sergey.lyalin@intel.com>
Co-authored-by: Polina <polina.brzezinskaya@intel.com>
This commit is contained in:
Alexey Lebedev 2021-09-08 19:41:20 +03:00 committed by GitHub
parent 42b93bed42
commit f89b3d770b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 476 additions and 24 deletions

View File

@ -39,7 +39,7 @@ cdef class InferRequest:
cpdef get_perf_counts(self) cpdef get_perf_counts(self)
cdef void user_callback(self, int status) with gil cdef void user_callback(self, int status) with gil
cdef public: cdef public:
_inputs_list, _outputs_list, _py_callback, _py_data, _py_callback_used, _py_callback_called, _user_blobs _inputs_list, _outputs_list, _py_callback, _py_data, _py_callback_used, _py_callback_called, _user_blobs, _inputs_is_dynamic
cdef class IENetwork: cdef class IENetwork:
cdef C.IENetwork impl cdef C.IENetwork impl

View File

@ -29,7 +29,6 @@ from .constants import WaitMode, StatusCode, MeanVariant, layout_str_to_enum, fo
import numpy as np import numpy as np
warnings.filterwarnings(action="module", category=DeprecationWarning) warnings.filterwarnings(action="module", category=DeprecationWarning)
cdef extern from "<utility>" namespace "std" nogil: cdef extern from "<utility>" namespace "std" nogil:
@ -53,6 +52,11 @@ cdef c_map_to_dict(map[string, string] c_map):
return py_dict return py_dict
cdef expand_dims_to_corresponding_layout(shape, layout):
single_axes = [1] * (len(layout) - len(shape))
return single_axes + list(shape)
def get_version(): def get_version():
return C.get_version().decode() return C.get_version().decode()
@ -271,6 +275,10 @@ cdef class Blob:
tensor_desc = TensorDesc(precision, dims, layout_int_to_str_map[layout]) tensor_desc = TensorDesc(precision, dims, layout_int_to_str_map[layout])
return tensor_desc return tensor_desc
def set_shape(self, new_shape):
self._initial_shape = new_shape
deref(self._ptr).setShape(new_shape)
## This class represents an Inference Engine entity and allows you to manipulate with plugins using unified interfaces. ## This class represents an Inference Engine entity and allows you to manipulate with plugins using unified interfaces.
cdef class IECore: cdef class IECore:
## Class constructor ## Class constructor
@ -815,6 +823,14 @@ cdef class DataPtr:
def initialized(self): def initialized(self):
return deref(self._ptr).isInitialized() return deref(self._ptr).isInitialized()
@property
def is_dynamic(self):
return deref(self._ptr).isDynamic()
## get capsule with ngraph::PartialShape
def _get_partial_shape_capsule(self):
return C.getPartialShape_capsule(self._ptr)
## This class is the layer constant data representation. Provides same interface as DataPtr object except properties setters ## This class is the layer constant data representation. Provides same interface as DataPtr object except properties setters
cdef class CDataPtr: cdef class CDataPtr:
@ -843,6 +859,14 @@ cdef class CDataPtr:
def initialized(self): def initialized(self):
return deref(self._ptr).isInitialized() return deref(self._ptr).isInitialized()
@property
def is_dynamic(self):
return deref(self._ptr).isDynamic()
## get capsule with ngraph::PartialShape
def _get_partial_shape_capsule(self):
return C.getPartialShape_capsule(self._ptr)
## This class represents a network instance loaded to plugin and ready for inference. ## This class represents a network instance loaded to plugin and ready for inference.
cdef class ExecutableNetwork: cdef class ExecutableNetwork:
@ -912,6 +936,8 @@ cdef class ExecutableNetwork:
infer_request.impl = &(deref(self.impl).infer_requests[i]) infer_request.impl = &(deref(self.impl).infer_requests[i])
infer_request._inputs_list = list(self.input_info.keys()) infer_request._inputs_list = list(self.input_info.keys())
infer_request._outputs_list = list(self.outputs.keys()) infer_request._outputs_list = list(self.outputs.keys())
for input_name in infer_request._inputs_list:
infer_request._inputs_is_dynamic[input_name] = self.input_info[input_name].input_data.is_dynamic
self._infer_requests.append(infer_request) self._infer_requests.append(infer_request)
if len(self._infer_requests) != c_infer_requests_size: if len(self._infer_requests) != c_infer_requests_size:
@ -1048,6 +1074,7 @@ cdef class InferRequest:
self._py_callback_used = False self._py_callback_used = False
self._py_callback_called = threading.Event() self._py_callback_called = threading.Event()
self._py_data = None self._py_data = None
self._inputs_is_dynamic = {}
cdef void user_callback(self, int status) with gil: cdef void user_callback(self, int status) with gil:
if self._py_callback: if self._py_callback:
@ -1308,6 +1335,9 @@ cdef class InferRequest:
def _fill_inputs(self, inputs): def _fill_inputs(self, inputs):
for k, v in inputs.items(): for k, v in inputs.items():
assert k in self._inputs_list, f"No input with name {k} found in network" assert k in self._inputs_list, f"No input with name {k} found in network"
if self._inputs_is_dynamic[k]:
shape = expand_dims_to_corresponding_layout(v.shape, self.input_blobs[k].tensor_desc.layout)
self.input_blobs[k].set_shape(shape)
if self.input_blobs[k].tensor_desc.precision == "FP16": if self.input_blobs[k].tensor_desc.precision == "FP16":
self.input_blobs[k].buffer[:] = v.view(dtype=np.int16) self.input_blobs[k].buffer[:] = v.view(dtype=np.int16)
else: else:
@ -1452,15 +1482,25 @@ cdef class IENetwork:
# net.reshape({input_layer: (n, c, h*2, w*2)}) # net.reshape({input_layer: (n, c, h*2, w*2)})
# ``` # ```
def reshape(self, input_shapes: dict): def reshape(self, input_shapes: dict):
cdef map[string, vector[size_t]] c_input_shapes cdef map[string, vector[vector[int64_t]]] c_input_shapes
cdef vector[size_t] c_shape cdef vector[vector[int64_t]] c_shape
cdef vector[int64_t] dim
net_inputs = self.input_info net_inputs = self.input_info
for input, shape in input_shapes.items(): for input, shape in input_shapes.items():
c_shape = [] c_shape = []
if input not in net_inputs: if input not in net_inputs:
raise AttributeError(f"Specified '{input}' layer not in network inputs '{net_inputs}'! ") raise AttributeError(f"Specified '{input}' layer not in network inputs '{net_inputs}'! ")
for v in shape: for v in shape:
c_shape.push_back(v) if isinstance(v, list) or isinstance(v, tuple):
if len(v) < 1 or len(v) > 2:
raise ValueError(f"Incorrect PartialShape dimension definition '{v}' "
f"in shape '{shape}', expected one or two values for a dimension! ")
for d in v:
dim.push_back(d)
else:
dim.push_back(v)
c_shape.push_back(dim)
dim.clear()
c_input_shapes[input.encode()] = c_shape c_input_shapes[input.encode()] = c_shape
self.impl.reshape(c_input_shapes) self.impl.reshape(c_input_shapes)

View File

@ -4,6 +4,8 @@
#include "ie_api_impl.hpp" #include "ie_api_impl.hpp"
#include <ngraph/partial_shape.hpp>
#include "ie_iinfer_request.hpp" #include "ie_iinfer_request.hpp"
#include "ie_plugin_config.hpp" #include "ie_plugin_config.hpp"
@ -206,6 +208,24 @@ InferenceEnginePython::IENetwork InferenceEnginePython::read_network(std::string
return InferenceEnginePython::IENetwork(std::make_shared<InferenceEngine::CNNNetwork>(net)); return InferenceEnginePython::IENetwork(std::make_shared<InferenceEngine::CNNNetwork>(net));
} }
PyObject* InferenceEnginePython::getPartialShape_capsule(InferenceEngine::CDataPtr data) {
const char* py_capsule_name = "ngraph_partial_shape";
auto ngraph_pShape_ptr = std::make_shared<ngraph::PartialShape>(data->getPartialShape());
auto* sp_copy = new std::shared_ptr<const ngraph::PartialShape>(ngraph_pShape_ptr);
auto sp_deleter = [](PyObject* capsule) {
auto* capsule_ptr = PyCapsule_GetPointer(capsule, "ngraph_partial_shape");
auto* function_sp = static_cast<std::shared_ptr<ngraph::PartialShape>*>(capsule_ptr);
if (function_sp) {
delete function_sp;
}
};
if (ngraph_pShape_ptr) {
return PyCapsule_New(sp_copy, py_capsule_name, sp_deleter);
} else {
return nullptr;
}
}
InferenceEnginePython::IENetwork::IENetwork(const std::shared_ptr<InferenceEngine::CNNNetwork>& cnn_network) InferenceEnginePython::IENetwork::IENetwork(const std::shared_ptr<InferenceEngine::CNNNetwork>& cnn_network)
: actual(cnn_network) { : actual(cnn_network) {
if (actual == nullptr) if (actual == nullptr)
@ -289,8 +309,21 @@ size_t InferenceEnginePython::IENetwork::getBatch() {
return actual->getBatchSize(); return actual->getBatchSize();
} }
void InferenceEnginePython::IENetwork::reshape(const std::map<std::string, std::vector<size_t>>& input_shapes) { void InferenceEnginePython::IENetwork::reshape(
actual->reshape(input_shapes); const std::map<std::string, std::vector<std::vector<int64_t>>>& input_shapes) {
std::map<std::string, ngraph::PartialShape> inputShapes;
for (auto const& input : input_shapes) {
using ngraph::Dimension;
std::vector<Dimension> dims;
for (auto const& d : input.second) {
if (d.size() == 1)
dims.push_back(Dimension(d[0]));
else if (d.size() == 2)
dims.push_back(Dimension(d[0], d[1]));
}
inputShapes[input.first] = ngraph::PartialShape(dims);
}
actual->reshape(inputShapes);
} }
InferenceEnginePython::IEExecNetwork::IEExecNetwork(const std::string& name, size_t num_requests) InferenceEnginePython::IEExecNetwork::IEExecNetwork(const std::string& name, size_t num_requests)

View File

@ -62,7 +62,7 @@ struct IENetwork {
const std::map<std::string, InferenceEngine::DataPtr> getOutputs(); const std::map<std::string, InferenceEngine::DataPtr> getOutputs();
void reshape(const std::map<std::string, std::vector<size_t>>& input_shapes); void reshape(const std::map<std::string, std::vector<std::vector<int64_t>>>& input_shapes);
void serialize(const std::string& path_to_xml, const std::string& path_to_bin); void serialize(const std::string& path_to_xml, const std::string& path_to_bin);
@ -203,4 +203,6 @@ std::string get_version();
InferenceEnginePython::IENetwork read_network(std::string path_to_xml, std::string path_to_bin); InferenceEnginePython::IENetwork read_network(std::string path_to_xml, std::string path_to_bin);
PyObject* getPartialShape_capsule(InferenceEngine::CDataPtr data);
}; // namespace InferenceEnginePython }; // namespace InferenceEnginePython

View File

@ -23,6 +23,7 @@ cdef extern from "<inference_engine.hpp>" namespace "InferenceEngine":
const CTensorDesc& getTensorDesc() except + const CTensorDesc& getTensorDesc() except +
size_t element_size() except + size_t element_size() except +
void allocate() void allocate()
void setShape(const SizeVector& dims) except +
cdef TBlob[Type].Ptr make_shared_blob[Type](const CTensorDesc& tensorDesc) cdef TBlob[Type].Ptr make_shared_blob[Type](const CTensorDesc& tensorDesc)
@ -47,6 +48,7 @@ cdef extern from "<inference_engine.hpp>" namespace "InferenceEngine":
const Layout getLayout() except + const Layout getLayout() except +
void setLayout(Layout layout) except + void setLayout(Layout layout) except +
const bool isInitialized() except + const bool isInitialized() except +
bool isDynamic() except +
ctypedef shared_ptr[Data] DataPtr ctypedef shared_ptr[Data] DataPtr
ctypedef weak_ptr[Data] DataWeakPtr ctypedef weak_ptr[Data] DataWeakPtr
@ -178,7 +180,7 @@ cdef extern from "ie_api_impl.hpp" namespace "InferenceEnginePython":
size_t getBatch() except + size_t getBatch() except +
void setLayerParams(map[string, map[string, string]] params_map) except + void setLayerParams(map[string, map[string, string]] params_map) except +
void serialize(const string& path_to_xml, const string& path_to_bin) except + void serialize(const string& path_to_xml, const string& path_to_bin) except +
void reshape(map[string, vector[size_t]] input_shapes) except + void reshape(map[string, vector[vector[int64_t]]] input_shapes) except +
object getFunction() except + object getFunction() except +
void convertToOldRepresentation() except + void convertToOldRepresentation() except +
string getOVNameForTensor(const string &) except + string getOVNameForTensor(const string &) except +
@ -226,3 +228,5 @@ cdef extern from "ie_api_impl.hpp" namespace "InferenceEnginePython":
cdef string get_version() cdef string get_version()
cdef IENetwork read_network(string path_to_xml, string path_to_bin) cdef IENetwork read_network(string path_to_xml, string path_to_bin)
cdef object getPartialShape_capsule(DataPtr)

View File

@ -3,6 +3,7 @@
import os import os
import pytest import pytest
import numpy as np
def model_path(is_myriad=False): def model_path(is_myriad=False):
@ -41,7 +42,19 @@ def device():
def pytest_configure(config): def pytest_configure(config):
# register an additional marker for ngraph dependent tests # register an additional markers
config.addinivalue_line( config.addinivalue_line(
"markers", "ngraph_dependent_test" "markers", "ngraph_dependent_test"
) )
config.addinivalue_line(
"markers", "template_plugin"
)
def create_ngraph_function(inputShape):
import ngraph as ng
inputShape = ng.impl.PartialShape(inputShape)
param = ng.parameter(inputShape, dtype=np.float32, name="data")
result = ng.relu(param, name='out')
function = ng.Function(result, [param], "TestFunction")
return function

View File

@ -121,3 +121,34 @@ def test_buffer_values_after_add_outputs(device):
result = exec_net.infer(feed_dict) result = exec_net.infer(feed_dict)
assert np.all(abs(result[output_layer])<30) assert np.all(abs(result[output_layer])<30)
assert result[output_layer].dtype == np.float16 assert result[output_layer].dtype == np.float16
def test_set_shape():
tensor_desc = TensorDesc("FP32", [1, 3, 127, 127], "NHWC")
blob = Blob(tensor_desc)
blob.set_shape([1, 4, 128, 128])
assert blob.tensor_desc.dims == [1, 4, 128, 128]
assert blob.buffer.shape == (1, 4, 128, 128)
array = np.ones([1, 3, 127, 127], dtype=np.float32)
blob = Blob(tensor_desc, array)
blob.set_shape([1, 4, 128, 128])
assert blob.tensor_desc.dims == [1, 4, 128, 128]
assert blob.buffer.shape == (1, 4, 128, 128)
@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
def test_blob_set_shape_after_async_infer():
from conftest import create_ngraph_function
import ngraph as ng
function = create_ngraph_function([ng.Dimension(0,5), ng.Dimension(4), ng.Dimension(20), ng.Dimension(20)])
net = ng.function_to_cnn(function)
ie_core = IECore()
ie_core.register_plugin("templatePlugin", "TEMPLATE")
exec_net = ie_core.load_network(net, "TEMPLATE")
request = exec_net.requests[0]
request.async_infer({"data": np.ones([4, 4, 20, 20])})
with pytest.raises(RuntimeError) as e:
request.input_blobs['data'].set_shape([3, 4, 20, 20])
assert "REQUEST_BUSY" in str(e.value)

View File

@ -56,3 +56,21 @@ def test_initialized(device):
net = ie.read_network(model=test_net_xml, weights=test_net_bin) net = ie.read_network(model=test_net_xml, weights=test_net_bin)
exec_net = ie.load_network(net, device, num_requests=5) exec_net = ie.load_network(net, device, num_requests=5)
assert exec_net.outputs['fc_out'].initialized, "Incorrect value for initialized property for layer 'fc_out" assert exec_net.outputs['fc_out'].initialized, "Incorrect value for initialized property for layer 'fc_out"
@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
def test_is_dynamic():
from conftest import create_ngraph_function
import ngraph as ng
function = create_ngraph_function([-1, 3, 20, 20])
net = ng.function_to_cnn(function)
ie = IECore()
ie.register_plugin("templatePlugin", "TEMPLATE")
exec_net = ie.load_network(net, "TEMPLATE")
assert exec_net.outputs["out"].is_dynamic
p_shape = ng.partial_shape_from_data(exec_net.outputs["out"])
assert isinstance(p_shape, ng.impl.PartialShape)
with pytest.raises(RuntimeError) as e:
exec_net.outputs["out"].shape
assert "Cannot return dims for Data with dynamic shapes!" in str(e.value)

View File

@ -43,3 +43,27 @@ def test_layout():
def test_initialized(): def test_initialized():
assert layer_out_data().initialized, "Incorrect value for initialized property for layer 'fc_out'" assert layer_out_data().initialized, "Incorrect value for initialized property for layer 'fc_out'"
@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
def test_is_dynamic():
from conftest import create_ngraph_function
import ngraph as ng
function = create_ngraph_function([-1, 3, 20, 20])
net = ng.function_to_cnn(function)
assert net.input_info["data"].input_data.is_dynamic
assert net.outputs["out"].is_dynamic
p_shape = ng.partial_shape_from_data(net.input_info["data"].input_data)
assert isinstance(p_shape, ng.impl.PartialShape)
p_shape = ng.partial_shape_from_data(net.outputs["out"])
assert isinstance(p_shape, ng.impl.PartialShape)
with pytest.raises(RuntimeError) as e:
net.input_info["data"].input_data.shape
assert "Cannot return dims for Data with dynamic shapes!" in str(e.value)
ie = IECore()
ie.register_plugin("templatePlugin", "TEMPLATE")
exec_net = ie.load_network(net, "TEMPLATE")
assert exec_net.input_info["data"].input_data.is_dynamic
p_shape = ng.partial_shape_from_data(exec_net.input_info["data"].input_data)
assert isinstance(p_shape, ng.impl.PartialShape)

View File

@ -156,6 +156,43 @@ def test_reshape():
ie = IECore() ie = IECore()
net = ie.read_network(model=test_net_xml, weights=test_net_bin) net = ie.read_network(model=test_net_xml, weights=test_net_bin)
net.reshape({"data": (2, 3, 32, 32)}) net.reshape({"data": (2, 3, 32, 32)})
assert net.input_info["data"].input_data.shape == [2, 3, 32, 32]
@pytest.mark.ngraph_dependent_test
@pytest.mark.parametrize("shape, p_shape", [
([1, 3, 22, 22], [1, 3, -1, 25]),
([1, 3, 22, 22], [-1, -1, -1, -1]),
([1, 3, -1, 25], [1, 3, 22, -1])
])
def test_reshape_with_partial_shape(device, shape, p_shape):
from conftest import create_ngraph_function
import ngraph as ng
function = create_ngraph_function(shape)
net = ng.function_to_cnn(function)
net.reshape({"data": p_shape})
changedFunction = ng.function_from_cnn(net)
p_shape = ng.impl.PartialShape(p_shape)
assert changedFunction.get_parameters()[0].get_partial_shape().is_dynamic
assert changedFunction.get_results()[0].get_output_partial_shape(0).is_dynamic
assert function.get_parameters()[0].get_partial_shape().is_dynamic
assert function.get_results()[0].get_output_partial_shape(0).is_dynamic
assert changedFunction.get_parameters()[0].get_partial_shape() == p_shape
assert changedFunction.get_results()[0].get_output_partial_shape(0) == p_shape
assert function.get_parameters()[0].get_partial_shape() == p_shape
assert function.get_results()[0].get_output_partial_shape(0) == p_shape
@pytest.mark.ngraph_dependent_test
def test_incorrect_reshape(device):
from conftest import create_ngraph_function
import ngraph as ng
function = create_ngraph_function([1, 3, 22, 22])
net = ng.function_to_cnn(function)
with pytest.raises(ValueError) as e:
net.reshape({"data": [(2, 4, 6), 3, 22, 22]})
assert "Incorrect PartialShape dimension definition '(2, 4, 6)' " \
"in shape '[(2, 4, 6), 3, 22, 22]', expected one or two values for a dimension! " in str(e.value)
def test_net_from_buffer_valid(): def test_net_from_buffer_valid():
@ -245,3 +282,18 @@ def test_tensor_names():
assert net.get_ov_name_for_tensor("relu_t") == "activation" assert net.get_ov_name_for_tensor("relu_t") == "activation"
assert net.get_ov_name_for_tensor("identity_t") == "activation" assert net.get_ov_name_for_tensor("identity_t") == "activation"
assert net.get_ov_name_for_tensor("input") == "in1" assert net.get_ov_name_for_tensor("input") == "in1"
@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
def test_create_two_exec_net():
from conftest import create_ngraph_function
import ngraph as ng
function = create_ngraph_function([ng.Dimension(0,5), ng.Dimension(4), ng.Dimension(20), ng.Dimension(20)])
net = ng.function_to_cnn(function)
ie_core = IECore()
ie_core.register_plugin("templatePlugin", "TEMPLATE")
exec_net1 = ie_core.load_network(net, "TEMPLATE", num_requests=2)
assert ng.function_from_cnn(net) != None
exec_net2 = ie_core.load_network(net, "TEMPLATE", num_requests=2)
assert ng.function_from_cnn(net) != None

View File

@ -17,9 +17,8 @@ path_to_img = image_path()
def create_function_with_memory(input_shape, data_type): def create_function_with_memory(input_shape, data_type):
import ngraph as ng
from ngraph.impl import Function, Type from ngraph.impl import Function, Type
import ngraph as ng
input_data = ng.parameter(input_shape, name="input_data", dtype=data_type) input_data = ng.parameter(input_shape, name="input_data", dtype=data_type)
rv = ng.read_value(input_data, "var_id_667") rv = ng.read_value(input_data, "var_id_667")
add = ng.add(rv, input_data, name="MemoryAdd") add = ng.add(rv, input_data, name="MemoryAdd")
@ -563,4 +562,221 @@ def test_query_state_write_buffer(device, input_shape, data_type, mode):
expected_res = np.full(input_shape, i, dtype=format_map[data_type]) expected_res = np.full(input_shape, i, dtype=format_map[data_type])
assert np.allclose(res['MemoryAdd'], expected_res, atol=1e-6), \ assert np.allclose(res['MemoryAdd'], expected_res, atol=1e-6), \
"Expected values: {} \n Actual values: {} \n".format(expected_res, res) "Expected values: {} \n Actual values: {} \n".format(expected_res, res)
@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
@pytest.mark.parametrize("shape, p_shape, ref_shape", [
([1, 4, 20, 20], [-1, 4, 20, 20], [5, 4, 20, 20]),
([1, 4, 20, 20], [(0,5), 4, 20, 20], [3, 4, 20, 20]),
([1, 4, 20, 20], [(3,5), 3, 20, 20], [2, 4, 20, 20]),
([1, 4, 20, 20], [(3,5), 3, 20, 20], [6, 4, 20, 20]),
])
def test_infer_dynamic_network_with_set_shape(shape, p_shape, ref_shape):
from conftest import create_ngraph_function
import ngraph as ng
function = create_ngraph_function(shape)
net = ng.function_to_cnn(function)
net.reshape({"data": p_shape})
ie_core = ie.IECore()
ie_core.register_plugin("templatePlugin", "TEMPLATE")
exec_net = ie_core.load_network(net, "TEMPLATE")
exec_net.requests[0].input_blobs["data"].set_shape(ref_shape)
assert exec_net.requests[0].input_blobs["data"].tensor_desc.dims == ref_shape
exec_net.infer({"data": np.ones(ref_shape)})
request = exec_net.requests[0]
request.async_infer({"data": np.ones(ref_shape)})
status = request.wait(ie.WaitMode.RESULT_READY)
assert status == ie.StatusCode.OK
assert request.output_blobs['out'].tensor_desc.dims == ref_shape
@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
@pytest.mark.parametrize("shape, p_shape, ref_shape", [
([1, 4, 20, 20], [-1, 4, 20, 20], [5, 4, 20, 20]),
([1, 4, 20, 20], [(0,5), 4, 20, 20], [3, 4, 20, 20]),
([1, 4, 20, 20], [(3,5), 3, 20, 20], [2, 4, 20, 20]),
([1, 4, 20, 20], [(3,5), 3, 20, 20], [6, 4, 20, 20]),
])
def test_infer_dynamic_network_without_set_shape(shape, p_shape, ref_shape):
from conftest import create_ngraph_function
import ngraph as ng
function = create_ngraph_function(shape)
net = ng.function_to_cnn(function)
net.reshape({"data": p_shape})
ie_core = ie.IECore()
ie_core.register_plugin("templatePlugin", "TEMPLATE")
exec_net = ie_core.load_network(net, "TEMPLATE")
exec_net.infer({"data": np.ones(ref_shape)})
assert exec_net.requests[0].input_blobs["data"].tensor_desc.dims == ref_shape
request = exec_net.requests[0]
request.async_infer({"data": np.ones(ref_shape)})
status = request.wait(ie.WaitMode.RESULT_READY)
assert status == ie.StatusCode.OK
assert request.output_blobs['out'].tensor_desc.dims == ref_shape
@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
@pytest.mark.parametrize("shape, p_shape, ref_shape", [
([1, 4, 20, 20], [-1, 4, 20, 20], [5, 4, 20, 20]),
([1, 4, 20, 20], [(0,5), 4, 20, 20], [3, 4, 20, 20]),
([1, 4, 20, 20], [(3,5), 3, 20, 20], [2, 4, 20, 20]),
([1, 4, 20, 20], [(3,5), 3, 20, 20], [6, 4, 20, 20]),
])
def test_infer_dynamic_network_with_set_blob(shape, p_shape, ref_shape):
from conftest import create_ngraph_function
import ngraph as ng
function = create_ngraph_function(shape)
net = ng.function_to_cnn(function)
net.reshape({"data": p_shape})
ie_core = ie.IECore()
ie_core.register_plugin("templatePlugin", "TEMPLATE")
exec_net = ie_core.load_network(net, "TEMPLATE")
tensor_desc = exec_net.requests[0].input_blobs["data"].tensor_desc
tensor_desc.dims = ref_shape
blob = ie.Blob(tensor_desc)
exec_net.requests[0].set_blob("data", blob)
assert exec_net.requests[0].input_blobs["data"].tensor_desc.dims == ref_shape
request = exec_net.requests[0]
request.infer({"data": np.ones(ref_shape)})
request.async_infer({"data": np.ones(ref_shape)})
status = request.wait(ie.WaitMode.RESULT_READY)
assert status == ie.StatusCode.OK
assert request.output_blobs["out"].tensor_desc.dims == ref_shape
@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
def test_infer_dynamic_network_twice():
from conftest import create_ngraph_function
import ngraph as ng
shape, p_shape = [1, 4, 20, 20], [(0,5), 4, 20, 20]
ref_shape1, ref_shape2 = [2, 4, 20, 20], [3, 4, 20, 20]
function = create_ngraph_function(shape)
net = ng.function_to_cnn(function)
net.reshape({"data": p_shape})
ie_core = ie.IECore()
ie_core.register_plugin("templatePlugin", "TEMPLATE")
exec_net = ie_core.load_network(net, "TEMPLATE")
request = exec_net.requests[0]
request.infer({"data": np.ones(ref_shape1)})
assert exec_net.requests[0].input_blobs["data"].tensor_desc.dims == ref_shape1
assert request.output_blobs['out'].tensor_desc.dims == ref_shape1
request.infer({"data": np.ones(ref_shape2)})
assert exec_net.requests[0].input_blobs["data"].tensor_desc.dims == ref_shape2
assert request.output_blobs['out'].tensor_desc.dims == ref_shape2
@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
def test_infer_dynamic_network_with_set_blob_twice():
from conftest import create_ngraph_function
import ngraph as ng
shape, p_shape = [1, 4, 20, 20], [(0,5), 4, 20, 20]
ref_shape1, ref_shape2 = [2, 4, 20, 20], [3, 4, 20, 20]
function = create_ngraph_function(shape)
net = ng.function_to_cnn(function)
net.reshape({"data": p_shape})
ie_core = ie.IECore()
ie_core.register_plugin("templatePlugin", "TEMPLATE")
exec_net = ie_core.load_network(net, "TEMPLATE")
request = exec_net.requests[0]
td = request.input_blobs['data'].tensor_desc
td.dims = ref_shape1
blob = ie.Blob(td)
request.set_blob("data", blob)
request.infer({"data": np.ones(ref_shape1)})
assert exec_net.requests[0].input_blobs["data"].tensor_desc.dims == ref_shape1
assert request.output_blobs['out'].tensor_desc.dims == ref_shape1
td = request.input_blobs['data'].tensor_desc
td.dims = ref_shape2
blob = ie.Blob(td)
request.set_blob("data", blob)
request.infer({"data": np.ones(ref_shape2)})
assert exec_net.requests[0].input_blobs["data"].tensor_desc.dims == ref_shape2
assert request.output_blobs['out'].tensor_desc.dims == ref_shape2
@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
@pytest.mark.parametrize("shapes", [
([3, 4, 20, 20], [3, 4, 20, 20], [3, 4, 20, 20]),
([3, 4, 20, 20], [3, 6, 20, 20], [3, 8, 20, 20]),
])
def test_async_infer_dynamic_network_3_requests(shapes):
from conftest import create_ngraph_function
import ngraph as ng
function = create_ngraph_function([3, 4, 20, 20])
net = ng.function_to_cnn(function)
net.reshape({"data": [3, (2, 10), 20, 20]})
ie_core = ie.IECore()
ie_core.register_plugin("templatePlugin", "TEMPLATE")
exec_net = ie_core.load_network(net, "TEMPLATE", num_requests=3)
for i,request in enumerate(exec_net.requests):
request.async_infer({"data": np.ones(shapes[i])})
for i,request in enumerate(exec_net.requests):
status = request.wait(ie.WaitMode.RESULT_READY)
assert status == ie.StatusCode.OK
assert request.output_blobs['out'].tensor_desc.dims == shapes[i]
@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
def test_set_blob_with_incorrect_name():
from conftest import create_ngraph_function
import ngraph as ng
function = create_ngraph_function([4, 4, 20, 20])
net = ng.function_to_cnn(function)
ie_core = ie.IECore()
ie_core.register_plugin("templatePlugin", "TEMPLATE")
exec_net = ie_core.load_network(net, "TEMPLATE")
tensor_desc = exec_net.requests[0].input_blobs["data"].tensor_desc
tensor_desc.dims = [4, 4, 20, 20]
blob = ie.Blob(tensor_desc)
with pytest.raises(RuntimeError) as e:
exec_net.requests[0].set_blob("incorrect_name", blob)
assert f"Failed to find input or output with name: 'incorrect_name'" in str(e.value)
@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
def test_set_blob_with_incorrect_size():
from conftest import create_ngraph_function
import ngraph as ng
function = create_ngraph_function([4, 4, 20, 20])
net = ng.function_to_cnn(function)
ie_core = ie.IECore()
ie_core.register_plugin("templatePlugin", "TEMPLATE")
exec_net = ie_core.load_network(net, "TEMPLATE")
tensor_desc = exec_net.requests[0].input_blobs["data"].tensor_desc
tensor_desc.dims = [tensor_desc.dims[0]*2, 4, 20, 20]
blob = ie.Blob(tensor_desc)
with pytest.raises(RuntimeError) as e:
exec_net.requests[0].set_blob("data", blob)
assert f"Input blob size is not equal network input size" in str(e.value)
with pytest.raises(RuntimeError) as e:
exec_net.requests[0].set_blob("out", blob)
assert f"Output blob size is not equal network output size" in str(e.value)
@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
def test_set_blob_after_async_infer():
from conftest import create_ngraph_function
import ngraph as ng
function = create_ngraph_function([ng.Dimension(0,5), ng.Dimension(4), ng.Dimension(20), ng.Dimension(20)])
net = ng.function_to_cnn(function)
ie_core = ie.IECore()
ie_core.register_plugin("templatePlugin", "TEMPLATE")
exec_net = ie_core.load_network(net, "TEMPLATE")
request = exec_net.requests[0]
tensor_desc = request.input_blobs['data'].tensor_desc
tensor_desc.dims = [2, 4, 20, 20]
blob = ie.Blob(tensor_desc)
request.async_infer({"data": np.ones([4, 4, 20, 20])})
with pytest.raises(RuntimeError) as e:
request.set_blob("data", blob)
assert "REQUEST_BUSY" in str(e.value)

View File

@ -6,17 +6,14 @@ import ngraph as ng
from ngraph.impl.op import Parameter from ngraph.impl.op import Parameter
from ngraph.impl import Function, Shape, Type from ngraph.impl import Function, Shape, Type
from conftest import model_path from conftest import model_path, create_ngraph_function
test_net_xml, test_net_bin = model_path() test_net_xml, test_net_bin = model_path()
def test_create_IENetwork_from_nGraph(): def test_create_IENetwork_from_nGraph():
element_type = Type.f32 func = create_ngraph_function([1, 3, 22, 22])
param = Parameter(element_type, Shape([1, 3, 22, 22]))
relu = ng.relu(param)
func = Function([relu], [param], 'test')
caps = Function.to_capsule(func) caps = Function.to_capsule(func)
cnnNetwork = IENetwork(caps) cnnNetwork = IENetwork(caps)
assert cnnNetwork != None assert cnnNetwork != None
@ -26,10 +23,7 @@ def test_create_IENetwork_from_nGraph():
def test_get_IENetwork_from_nGraph(): def test_get_IENetwork_from_nGraph():
element_type = Type.f32 func = create_ngraph_function([1, 3, 22, 22])
param = Parameter(element_type, Shape([1, 3, 22, 22]))
relu = ng.relu(param)
func = Function([relu], [param], 'test')
caps = Function.to_capsule(func) caps = Function.to_capsule(func)
cnnNetwork = IENetwork(caps) cnnNetwork = IENetwork(caps)
assert cnnNetwork != None assert cnnNetwork != None

View File

@ -27,6 +27,7 @@ from ngraph.frontend import OpValidationFailure
from ngraph.frontend import Place from ngraph.frontend import Place
from ngraph.helpers import function_from_cnn from ngraph.helpers import function_from_cnn
from ngraph.helpers import function_to_cnn from ngraph.helpers import function_to_cnn
from ngraph.helpers import partial_shape_from_data
from ngraph.opset8 import absolute from ngraph.opset8 import absolute
from ngraph.opset8 import absolute as abs from ngraph.opset8 import absolute as abs
from ngraph.opset8 import acos from ngraph.opset8 import acos

View File

@ -3,8 +3,10 @@
"""nGraph helper functions.""" """nGraph helper functions."""
from ngraph.impl import Function from typing import Union
from openvino.inference_engine import IENetwork
from ngraph.impl import Function, PartialShape
from openvino.inference_engine import IENetwork, DataPtr, CDataPtr
def function_from_cnn(cnn_network: IENetwork) -> Function: def function_from_cnn(cnn_network: IENetwork) -> Function:
@ -18,3 +20,9 @@ def function_to_cnn(ng_function: Function) -> Function:
"""Get Inference Engine CNN network from nGraph function.""" """Get Inference Engine CNN network from nGraph function."""
capsule = Function.to_capsule(ng_function) capsule = Function.to_capsule(ng_function)
return IENetwork(capsule) return IENetwork(capsule)
def partial_shape_from_data(data: Union[DataPtr, CDataPtr]) -> PartialShape:
"""Get nGraph PartialShape from Inference Engine Data."""
capsule = data._get_partial_shape_capsule()
return PartialShape.from_capsule(capsule)

View File

@ -17,6 +17,8 @@
namespace py = pybind11; namespace py = pybind11;
static const char* CAPSULE_NAME = "ngraph_partial_shape";
void regclass_pyngraph_PartialShape(py::module m) { void regclass_pyngraph_PartialShape(py::module m) {
py::class_<ngraph::PartialShape, std::shared_ptr<ngraph::PartialShape>> shape(m, "PartialShape"); py::class_<ngraph::PartialShape, std::shared_ptr<ngraph::PartialShape>> shape(m, "PartialShape");
shape.doc() = "ngraph.impl.PartialShape wraps ngraph::PartialShape"; shape.doc() = "ngraph.impl.PartialShape wraps ngraph::PartialShape";
@ -199,4 +201,18 @@ void regclass_pyngraph_PartialShape(py::module m) {
shape.def("__repr__", [](const ngraph::PartialShape& self) -> std::string { shape.def("__repr__", [](const ngraph::PartialShape& self) -> std::string {
return "<PartialShape: " + py::cast(self).attr("__str__")().cast<std::string>() + ">"; return "<PartialShape: " + py::cast(self).attr("__str__")().cast<std::string>() + ">";
}); });
shape.def_static("from_capsule", [](py::object* capsule) {
// get the underlying PyObject* which is a PyCapsule pointer
auto* pybind_capsule_ptr = capsule->ptr();
// extract the pointer stored in the PyCapsule under the name CAPSULE_NAME
auto* capsule_ptr = PyCapsule_GetPointer(pybind_capsule_ptr, CAPSULE_NAME);
auto* ngraph_pShape = static_cast<std::shared_ptr<ngraph::PartialShape>*>(capsule_ptr);
if (ngraph_pShape && *ngraph_pShape) {
return *ngraph_pShape;
} else {
throw std::runtime_error("The provided capsule does not contain an ngraph::PartialShape");
}
});
} }