[IE PYTHON] Dynamic shape API for Python (#7282)
* Added nGraph as a public dependency
* Fixed Windows warning
* Fixed CMake
* Fixed constant op
* Fixed typo
* Added reshape to PartialShape to CNNNetwork
* Added SetShape to InferRequest
* Enable support of DynamicShape in IE Data
* Add support of dynamic shapes to template plugin
* Fixed build
* Enable support for dynamic rank
* Fixed test for dynamic rank
* Fixed some tests
* Fix preprocess tests
* Fixed SetBlob
* Fixed code style
* Add more tests
* Fixed accuracy tests
* Fixed documentation
* Added tests for custom operation
* Added new tests
* WIP: move setShape from infer request to Blob
* Brought the isApplicable check back
* Removed obsolete tests for InferRequest::SetShape and added a new (trivial) test for Blob::setShape
* Fixed artifacts
* Break code style
* Revert "Break code style"
This reverts commit 71ee638cd0.
* Added -j8 for fix_all
* Applied code style fixes
* Added doxygen items
* Fixed style
* Applied codestyle patch
* Reverted irrelevant commit with template extension
* Fixed cmake file for shared func tests (pick from master)
* Revert all changes in template_extension
* Deleted some old commented-out code that won't be used
* Fixed wrong exception throwing
* Code style fix
* Fixed preprocessing part
* Fixed incorrect blob reshape in GetBlob
* Deleted incorrect assert in GAPI that prevented some tests from passing in Debug
* Fixed issues identified during review
* Removed SetShape, replaced getLayoutByDims with getLayoutByRank, and removed problematic modification from IE preprocessing
* Fixed comments
* Removed obsolete setShape
* [VPU] Fixed allocating dynamic blobs in myriad_infer_request
* Fixed comments
* Fixed CNNNgraphImpl and comments
* Add partial reshape for IENetwork
* Add getPartialShape
* Add setShape for Blob
* Add tests
* Add partial_shape property for CDataPtr
* Add partial_shape property for data
* Fix code style
* Fix code style
* Fix test
* Fix code style
* Fix code style
* Fix code style
* Add tests
* Add new lines
* Mark tests
* Fix tests
* Call set_shape implicitly only for dynamic inputs
* Skip tests on ARM
* Remove ngraph dependency from ie_api.pyx
* Expand only the shape, not the array, in expand_dims_to_corresponding_layout
* Mark ngraph-dependent tests
* Remove skip inside test
* Code refactoring
* Add new line
* Add docstring
* Fix code style
Co-authored-by: Ilya Churaev <ilya.churaev@intel.com>
Co-authored-by: Lyalin, Sergey <sergey.lyalin@intel.com>
Co-authored-by: Polina <polina.brzezinskaya@intel.com>
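
For orientation, a minimal usage sketch of the workflow this change enables (not part of the diff; the ngraph function, the template plugin registration and the shapes below simply mirror the new tests):

import numpy as np
import ngraph as ng
from openvino.inference_engine import IECore

# Build a small static function, then make its batch dimension dynamic via reshape();
# -1 or a (min, max) pair marks a dimension as dynamic.
param = ng.parameter(ng.impl.PartialShape([1, 4, 20, 20]), dtype=np.float32, name="data")
function = ng.Function(ng.relu(param, name="out"), [param], "TestFunction")
net = ng.function_to_cnn(function)
net.reshape({"data": [(0, 5), 4, 20, 20]})

ie = IECore()
ie.register_plugin("templatePlugin", "TEMPLATE")
exec_net = ie.load_network(net, "TEMPLATE")
# Dynamic inputs are resized implicitly from the array passed to infer(); alternatively,
# call exec_net.requests[0].input_blobs["data"].set_shape([3, 4, 20, 20]) beforehand.
res = exec_net.infer({"data": np.ones([3, 4, 20, 20], dtype=np.float32)})
assert exec_net.requests[0].output_blobs["out"].tensor_desc.dims == [3, 4, 20, 20]
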
@@ -39,7 +39,7 @@ cdef class InferRequest:
    cpdef get_perf_counts(self)
    cdef void user_callback(self, int status) with gil
    cdef public:
        _inputs_list, _outputs_list, _py_callback, _py_data, _py_callback_used, _py_callback_called, _user_blobs
        _inputs_list, _outputs_list, _py_callback, _py_data, _py_callback_used, _py_callback_called, _user_blobs, _inputs_is_dynamic

cdef class IENetwork:
    cdef C.IENetwork impl

@@ -29,7 +29,6 @@ from .constants import WaitMode, StatusCode, MeanVariant, layout_str_to_enum, fo
import numpy as np

warnings.filterwarnings(action="module", category=DeprecationWarning)

cdef extern from "<utility>" namespace "std" nogil:

@@ -53,6 +52,11 @@ cdef c_map_to_dict(map[string, string] c_map):
    return py_dict


cdef expand_dims_to_corresponding_layout(shape, layout):
    single_axes = [1] * (len(layout) - len(shape))
    return single_axes + list(shape)


def get_version():
    return C.get_version().decode()

@@ -271,6 +275,10 @@ cdef class Blob:
        tensor_desc = TensorDesc(precision, dims, layout_int_to_str_map[layout])
        return tensor_desc

    def set_shape(self, new_shape):
        self._initial_shape = new_shape
        deref(self._ptr).setShape(new_shape)

## This class represents an Inference Engine entity and allows you to manipulate with plugins using unified interfaces.
cdef class IECore:
    ## Class constructor

@@ -815,6 +823,14 @@ cdef class DataPtr:
    def initialized(self):
        return deref(self._ptr).isInitialized()

    @property
    def is_dynamic(self):
        return deref(self._ptr).isDynamic()

    ## get capsule with ngraph::PartialShape
    def _get_partial_shape_capsule(self):
        return C.getPartialShape_capsule(self._ptr)


## This class is the layer constant data representation. Provides same interface as DataPtr object except properties setters
cdef class CDataPtr:

@@ -843,6 +859,14 @@ cdef class CDataPtr:
    def initialized(self):
        return deref(self._ptr).isInitialized()

    @property
    def is_dynamic(self):
        return deref(self._ptr).isDynamic()

    ## get capsule with ngraph::PartialShape
    def _get_partial_shape_capsule(self):
        return C.getPartialShape_capsule(self._ptr)


## This class represents a network instance loaded to plugin and ready for inference.
cdef class ExecutableNetwork:

@@ -912,6 +936,8 @@ cdef class ExecutableNetwork:
            infer_request.impl = &(deref(self.impl).infer_requests[i])
            infer_request._inputs_list = list(self.input_info.keys())
            infer_request._outputs_list = list(self.outputs.keys())
            for input_name in infer_request._inputs_list:
                infer_request._inputs_is_dynamic[input_name] = self.input_info[input_name].input_data.is_dynamic
            self._infer_requests.append(infer_request)

        if len(self._infer_requests) != c_infer_requests_size:

@@ -1048,6 +1074,7 @@ cdef class InferRequest:
        self._py_callback_used = False
        self._py_callback_called = threading.Event()
        self._py_data = None
        self._inputs_is_dynamic = {}

    cdef void user_callback(self, int status) with gil:
        if self._py_callback:

@@ -1308,6 +1335,9 @@ cdef class InferRequest:
    def _fill_inputs(self, inputs):
        for k, v in inputs.items():
            assert k in self._inputs_list, f"No input with name {k} found in network"
            if self._inputs_is_dynamic[k]:
                shape = expand_dims_to_corresponding_layout(v.shape, self.input_blobs[k].tensor_desc.layout)
                self.input_blobs[k].set_shape(shape)
            if self.input_blobs[k].tensor_desc.precision == "FP16":
                self.input_blobs[k].buffer[:] = v.view(dtype=np.int16)
            else:

@@ -1452,15 +1482,25 @@ cdef class IENetwork:
    #  net.reshape({input_layer: (n, c, h*2, w*2)})
    #  ```
    def reshape(self, input_shapes: dict):
        cdef map[string, vector[size_t]] c_input_shapes
        cdef vector[size_t] c_shape
        cdef map[string, vector[vector[int64_t]]] c_input_shapes
        cdef vector[vector[int64_t]] c_shape
        cdef vector[int64_t] dim
        net_inputs = self.input_info
        for input, shape in input_shapes.items():
            c_shape = []
            if input not in net_inputs:
                raise AttributeError(f"Specified '{input}' layer not in network inputs '{net_inputs}'! ")
            for v in shape:
                c_shape.push_back(v)
                if isinstance(v, list) or isinstance(v, tuple):
                    if len(v) < 1 or len(v) > 2:
                        raise ValueError(f"Incorrect PartialShape dimension definition '{v}' "
                                         f"in shape '{shape}', expected one or two values for a dimension! ")
                    for d in v:
                        dim.push_back(d)
                else:
                    dim.push_back(v)
                c_shape.push_back(dim)
                dim.clear()
            c_input_shapes[input.encode()] = c_shape
        self.impl.reshape(c_input_shapes)
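
As an aside, a standalone sketch of the expand_dims_to_corresponding_layout helper used by _fill_inputs above (behaviour copied verbatim from the diff): a lower-rank user array is padded with leading 1s up to the rank of the blob layout before set_shape() is called on a dynamic input.

def expand_dims_to_corresponding_layout(shape, layout):
    single_axes = [1] * (len(layout) - len(shape))
    return single_axes + list(shape)

# A 3-D array fed to a 4-D "NCHW" input gets an implicit batch dimension of 1.
assert expand_dims_to_corresponding_layout((3, 227, 227), "NCHW") == [1, 3, 227, 227]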

@@ -4,6 +4,8 @@
#include "ie_api_impl.hpp"

#include <ngraph/partial_shape.hpp>

#include "ie_iinfer_request.hpp"
#include "ie_plugin_config.hpp"

@@ -206,6 +208,24 @@ InferenceEnginePython::IENetwork InferenceEnginePython::read_network(std::string
    return InferenceEnginePython::IENetwork(std::make_shared<InferenceEngine::CNNNetwork>(net));
}

PyObject* InferenceEnginePython::getPartialShape_capsule(InferenceEngine::CDataPtr data) {
    const char* py_capsule_name = "ngraph_partial_shape";
    auto ngraph_pShape_ptr = std::make_shared<ngraph::PartialShape>(data->getPartialShape());
    auto* sp_copy = new std::shared_ptr<const ngraph::PartialShape>(ngraph_pShape_ptr);
    auto sp_deleter = [](PyObject* capsule) {
        auto* capsule_ptr = PyCapsule_GetPointer(capsule, "ngraph_partial_shape");
        auto* function_sp = static_cast<std::shared_ptr<ngraph::PartialShape>*>(capsule_ptr);
        if (function_sp) {
            delete function_sp;
        }
    };
    if (ngraph_pShape_ptr) {
        return PyCapsule_New(sp_copy, py_capsule_name, sp_deleter);
    } else {
        return nullptr;
    }
}

InferenceEnginePython::IENetwork::IENetwork(const std::shared_ptr<InferenceEngine::CNNNetwork>& cnn_network)
    : actual(cnn_network) {
    if (actual == nullptr)

@@ -289,8 +309,21 @@ size_t InferenceEnginePython::IENetwork::getBatch() {
    return actual->getBatchSize();
}

void InferenceEnginePython::IENetwork::reshape(const std::map<std::string, std::vector<size_t>>& input_shapes) {
    actual->reshape(input_shapes);
void InferenceEnginePython::IENetwork::reshape(
    const std::map<std::string, std::vector<std::vector<int64_t>>>& input_shapes) {
    std::map<std::string, ngraph::PartialShape> inputShapes;
    for (auto const& input : input_shapes) {
        using ngraph::Dimension;
        std::vector<Dimension> dims;
        for (auto const& d : input.second) {
            if (d.size() == 1)
                dims.push_back(Dimension(d[0]));
            else if (d.size() == 2)
                dims.push_back(Dimension(d[0], d[1]));
        }
        inputShapes[input.first] = ngraph::PartialShape(dims);
    }
    actual->reshape(inputShapes);
}

InferenceEnginePython::IEExecNetwork::IEExecNetwork(const std::string& name, size_t num_requests)

@@ -62,7 +62,7 @@ struct IENetwork {
    const std::map<std::string, InferenceEngine::DataPtr> getOutputs();

    void reshape(const std::map<std::string, std::vector<size_t>>& input_shapes);
    void reshape(const std::map<std::string, std::vector<std::vector<int64_t>>>& input_shapes);

    void serialize(const std::string& path_to_xml, const std::string& path_to_bin);

@@ -203,4 +203,6 @@ std::string get_version();
InferenceEnginePython::IENetwork read_network(std::string path_to_xml, std::string path_to_bin);

PyObject* getPartialShape_capsule(InferenceEngine::CDataPtr data);

};  // namespace InferenceEnginePython

@@ -23,6 +23,7 @@ cdef extern from "<inference_engine.hpp>" namespace "InferenceEngine":
        const CTensorDesc& getTensorDesc() except +
        size_t element_size() except +
        void allocate()
        void setShape(const SizeVector& dims) except +

    cdef TBlob[Type].Ptr make_shared_blob[Type](const CTensorDesc& tensorDesc)

@@ -47,6 +48,7 @@ cdef extern from "<inference_engine.hpp>" namespace "InferenceEngine":
        const Layout getLayout() except +
        void setLayout(Layout layout) except +
        const bool isInitialized() except +
        bool isDynamic() except +

    ctypedef shared_ptr[Data] DataPtr
    ctypedef weak_ptr[Data] DataWeakPtr

@@ -178,7 +180,7 @@ cdef extern from "ie_api_impl.hpp" namespace "InferenceEnginePython":
        size_t getBatch() except +
        void setLayerParams(map[string, map[string, string]] params_map) except +
        void serialize(const string& path_to_xml, const string& path_to_bin) except +
        void reshape(map[string, vector[size_t]] input_shapes) except +
        void reshape(map[string, vector[vector[int64_t]]] input_shapes) except +
        object getFunction() except +
        void convertToOldRepresentation() except +
        string getOVNameForTensor(const string &) except +

@@ -226,3 +228,5 @@ cdef extern from "ie_api_impl.hpp" namespace "InferenceEnginePython":
    cdef string get_version()

    cdef IENetwork read_network(string path_to_xml, string path_to_bin)

    cdef object getPartialShape_capsule(DataPtr)

@@ -3,6 +3,7 @@
import os
import pytest
import numpy as np


def model_path(is_myriad=False):

@@ -41,7 +42,19 @@ def device():


def pytest_configure(config):
    # register an additional marker for ngraph dependent tests
    # register an additional markers
    config.addinivalue_line(
        "markers", "ngraph_dependent_test"
    )
    config.addinivalue_line(
        "markers", "template_plugin"
    )


def create_ngraph_function(inputShape):
    import ngraph as ng
    inputShape = ng.impl.PartialShape(inputShape)
    param = ng.parameter(inputShape, dtype=np.float32, name="data")
    result = ng.relu(param, name='out')
    function = ng.Function(result, [param], "TestFunction")
    return function

@@ -121,3 +121,34 @@ def test_buffer_values_after_add_outputs(device):
    result = exec_net.infer(feed_dict)
    assert np.all(abs(result[output_layer])<30)
    assert result[output_layer].dtype == np.float16


def test_set_shape():
    tensor_desc = TensorDesc("FP32", [1, 3, 127, 127], "NHWC")
    blob = Blob(tensor_desc)
    blob.set_shape([1, 4, 128, 128])
    assert blob.tensor_desc.dims == [1, 4, 128, 128]
    assert blob.buffer.shape == (1, 4, 128, 128)

    array = np.ones([1, 3, 127, 127], dtype=np.float32)
    blob = Blob(tensor_desc, array)
    blob.set_shape([1, 4, 128, 128])
    assert blob.tensor_desc.dims == [1, 4, 128, 128]
    assert blob.buffer.shape == (1, 4, 128, 128)


@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
def test_blob_set_shape_after_async_infer():
    from conftest import create_ngraph_function
    import ngraph as ng
    function = create_ngraph_function([ng.Dimension(0,5), ng.Dimension(4), ng.Dimension(20), ng.Dimension(20)])
    net = ng.function_to_cnn(function)
    ie_core = IECore()
    ie_core.register_plugin("templatePlugin", "TEMPLATE")
    exec_net = ie_core.load_network(net, "TEMPLATE")
    request = exec_net.requests[0]
    request.async_infer({"data": np.ones([4, 4, 20, 20])})
    with pytest.raises(RuntimeError) as e:
        request.input_blobs['data'].set_shape([3, 4, 20, 20])
    assert "REQUEST_BUSY" in str(e.value)

@@ -56,3 +56,21 @@ def test_initialized(device):
    net = ie.read_network(model=test_net_xml, weights=test_net_bin)
    exec_net = ie.load_network(net, device, num_requests=5)
    assert exec_net.outputs['fc_out'].initialized, "Incorrect value for initialized property for layer 'fc_out"


@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
def test_is_dynamic():
    from conftest import create_ngraph_function
    import ngraph as ng
    function = create_ngraph_function([-1, 3, 20, 20])
    net = ng.function_to_cnn(function)
    ie = IECore()
    ie.register_plugin("templatePlugin", "TEMPLATE")
    exec_net = ie.load_network(net, "TEMPLATE")
    assert exec_net.outputs["out"].is_dynamic
    p_shape = ng.partial_shape_from_data(exec_net.outputs["out"])
    assert isinstance(p_shape, ng.impl.PartialShape)
    with pytest.raises(RuntimeError) as e:
        exec_net.outputs["out"].shape
    assert "Cannot return dims for Data with dynamic shapes!" in str(e.value)

@@ -43,3 +43,27 @@ def test_layout():


def test_initialized():
    assert layer_out_data().initialized, "Incorrect value for initialized property for layer 'fc_out'"


@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
def test_is_dynamic():
    from conftest import create_ngraph_function
    import ngraph as ng
    function = create_ngraph_function([-1, 3, 20, 20])
    net = ng.function_to_cnn(function)
    assert net.input_info["data"].input_data.is_dynamic
    assert net.outputs["out"].is_dynamic
    p_shape = ng.partial_shape_from_data(net.input_info["data"].input_data)
    assert isinstance(p_shape, ng.impl.PartialShape)
    p_shape = ng.partial_shape_from_data(net.outputs["out"])
    assert isinstance(p_shape, ng.impl.PartialShape)
    with pytest.raises(RuntimeError) as e:
        net.input_info["data"].input_data.shape
    assert "Cannot return dims for Data with dynamic shapes!" in str(e.value)
    ie = IECore()
    ie.register_plugin("templatePlugin", "TEMPLATE")
    exec_net = ie.load_network(net, "TEMPLATE")
    assert exec_net.input_info["data"].input_data.is_dynamic
    p_shape = ng.partial_shape_from_data(exec_net.input_info["data"].input_data)
    assert isinstance(p_shape, ng.impl.PartialShape)

@@ -156,6 +156,43 @@ def test_reshape():
    ie = IECore()
    net = ie.read_network(model=test_net_xml, weights=test_net_bin)
    net.reshape({"data": (2, 3, 32, 32)})
    assert net.input_info["data"].input_data.shape == [2, 3, 32, 32]


@pytest.mark.ngraph_dependent_test
@pytest.mark.parametrize("shape, p_shape", [
    ([1, 3, 22, 22], [1, 3, -1, 25]),
    ([1, 3, 22, 22], [-1, -1, -1, -1]),
    ([1, 3, -1, 25], [1, 3, 22, -1])
])
def test_reshape_with_partial_shape(device, shape, p_shape):
    from conftest import create_ngraph_function
    import ngraph as ng
    function = create_ngraph_function(shape)
    net = ng.function_to_cnn(function)
    net.reshape({"data": p_shape})
    changedFunction = ng.function_from_cnn(net)
    p_shape = ng.impl.PartialShape(p_shape)
    assert changedFunction.get_parameters()[0].get_partial_shape().is_dynamic
    assert changedFunction.get_results()[0].get_output_partial_shape(0).is_dynamic
    assert function.get_parameters()[0].get_partial_shape().is_dynamic
    assert function.get_results()[0].get_output_partial_shape(0).is_dynamic
    assert changedFunction.get_parameters()[0].get_partial_shape() == p_shape
    assert changedFunction.get_results()[0].get_output_partial_shape(0) == p_shape
    assert function.get_parameters()[0].get_partial_shape() == p_shape
    assert function.get_results()[0].get_output_partial_shape(0) == p_shape


@pytest.mark.ngraph_dependent_test
def test_incorrect_reshape(device):
    from conftest import create_ngraph_function
    import ngraph as ng
    function = create_ngraph_function([1, 3, 22, 22])
    net = ng.function_to_cnn(function)
    with pytest.raises(ValueError) as e:
        net.reshape({"data": [(2, 4, 6), 3, 22, 22]})
    assert "Incorrect PartialShape dimension definition '(2, 4, 6)' " \
           "in shape '[(2, 4, 6), 3, 22, 22]', expected one or two values for a dimension! " in str(e.value)


def test_net_from_buffer_valid():

@@ -245,3 +282,18 @@ def test_tensor_names():
    assert net.get_ov_name_for_tensor("relu_t") == "activation"
    assert net.get_ov_name_for_tensor("identity_t") == "activation"
    assert net.get_ov_name_for_tensor("input") == "in1"


@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
def test_create_two_exec_net():
    from conftest import create_ngraph_function
    import ngraph as ng
    function = create_ngraph_function([ng.Dimension(0,5), ng.Dimension(4), ng.Dimension(20), ng.Dimension(20)])
    net = ng.function_to_cnn(function)
    ie_core = IECore()
    ie_core.register_plugin("templatePlugin", "TEMPLATE")
    exec_net1 = ie_core.load_network(net, "TEMPLATE", num_requests=2)
    assert ng.function_from_cnn(net) != None
    exec_net2 = ie_core.load_network(net, "TEMPLATE", num_requests=2)
    assert ng.function_from_cnn(net) != None

@@ -17,9 +17,8 @@ path_to_img = image_path()


def create_function_with_memory(input_shape, data_type):
    import ngraph as ng
    from ngraph.impl import Function, Type

    import ngraph as ng
    input_data = ng.parameter(input_shape, name="input_data", dtype=data_type)
    rv = ng.read_value(input_data, "var_id_667")
    add = ng.add(rv, input_data, name="MemoryAdd")

@@ -564,3 +563,220 @@ def test_query_state_write_buffer(device, input_shape, data_type, mode):

    assert np.allclose(res['MemoryAdd'], expected_res, atol=1e-6), \
        "Expected values: {} \n Actual values: {} \n".format(expected_res, res)


@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
@pytest.mark.parametrize("shape, p_shape, ref_shape", [
    ([1, 4, 20, 20], [-1, 4, 20, 20], [5, 4, 20, 20]),
    ([1, 4, 20, 20], [(0,5), 4, 20, 20], [3, 4, 20, 20]),
    ([1, 4, 20, 20], [(3,5), 3, 20, 20], [2, 4, 20, 20]),
    ([1, 4, 20, 20], [(3,5), 3, 20, 20], [6, 4, 20, 20]),
])
def test_infer_dynamic_network_with_set_shape(shape, p_shape, ref_shape):
    from conftest import create_ngraph_function
    import ngraph as ng
    function = create_ngraph_function(shape)
    net = ng.function_to_cnn(function)
    net.reshape({"data": p_shape})
    ie_core = ie.IECore()
    ie_core.register_plugin("templatePlugin", "TEMPLATE")
    exec_net = ie_core.load_network(net, "TEMPLATE")
    exec_net.requests[0].input_blobs["data"].set_shape(ref_shape)
    assert exec_net.requests[0].input_blobs["data"].tensor_desc.dims == ref_shape
    exec_net.infer({"data": np.ones(ref_shape)})
    request = exec_net.requests[0]
    request.async_infer({"data": np.ones(ref_shape)})
    status = request.wait(ie.WaitMode.RESULT_READY)
    assert status == ie.StatusCode.OK
    assert request.output_blobs['out'].tensor_desc.dims == ref_shape


@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
@pytest.mark.parametrize("shape, p_shape, ref_shape", [
    ([1, 4, 20, 20], [-1, 4, 20, 20], [5, 4, 20, 20]),
    ([1, 4, 20, 20], [(0,5), 4, 20, 20], [3, 4, 20, 20]),
    ([1, 4, 20, 20], [(3,5), 3, 20, 20], [2, 4, 20, 20]),
    ([1, 4, 20, 20], [(3,5), 3, 20, 20], [6, 4, 20, 20]),
])
def test_infer_dynamic_network_without_set_shape(shape, p_shape, ref_shape):
    from conftest import create_ngraph_function
    import ngraph as ng
    function = create_ngraph_function(shape)
    net = ng.function_to_cnn(function)
    net.reshape({"data": p_shape})
    ie_core = ie.IECore()
    ie_core.register_plugin("templatePlugin", "TEMPLATE")
    exec_net = ie_core.load_network(net, "TEMPLATE")
    exec_net.infer({"data": np.ones(ref_shape)})
    assert exec_net.requests[0].input_blobs["data"].tensor_desc.dims == ref_shape
    request = exec_net.requests[0]
    request.async_infer({"data": np.ones(ref_shape)})
    status = request.wait(ie.WaitMode.RESULT_READY)
    assert status == ie.StatusCode.OK
    assert request.output_blobs['out'].tensor_desc.dims == ref_shape


@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
@pytest.mark.parametrize("shape, p_shape, ref_shape", [
    ([1, 4, 20, 20], [-1, 4, 20, 20], [5, 4, 20, 20]),
    ([1, 4, 20, 20], [(0,5), 4, 20, 20], [3, 4, 20, 20]),
    ([1, 4, 20, 20], [(3,5), 3, 20, 20], [2, 4, 20, 20]),
    ([1, 4, 20, 20], [(3,5), 3, 20, 20], [6, 4, 20, 20]),
])
def test_infer_dynamic_network_with_set_blob(shape, p_shape, ref_shape):
    from conftest import create_ngraph_function
    import ngraph as ng
    function = create_ngraph_function(shape)
    net = ng.function_to_cnn(function)
    net.reshape({"data": p_shape})
    ie_core = ie.IECore()
    ie_core.register_plugin("templatePlugin", "TEMPLATE")
    exec_net = ie_core.load_network(net, "TEMPLATE")
    tensor_desc = exec_net.requests[0].input_blobs["data"].tensor_desc
    tensor_desc.dims = ref_shape
    blob = ie.Blob(tensor_desc)
    exec_net.requests[0].set_blob("data", blob)
    assert exec_net.requests[0].input_blobs["data"].tensor_desc.dims == ref_shape
    request = exec_net.requests[0]
    request.infer({"data": np.ones(ref_shape)})
    request.async_infer({"data": np.ones(ref_shape)})
    status = request.wait(ie.WaitMode.RESULT_READY)
    assert status == ie.StatusCode.OK
    assert request.output_blobs["out"].tensor_desc.dims == ref_shape


@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
def test_infer_dynamic_network_twice():
    from conftest import create_ngraph_function
    import ngraph as ng
    shape, p_shape = [1, 4, 20, 20], [(0,5), 4, 20, 20]
    ref_shape1, ref_shape2 = [2, 4, 20, 20], [3, 4, 20, 20]
    function = create_ngraph_function(shape)
    net = ng.function_to_cnn(function)
    net.reshape({"data": p_shape})
    ie_core = ie.IECore()
    ie_core.register_plugin("templatePlugin", "TEMPLATE")
    exec_net = ie_core.load_network(net, "TEMPLATE")
    request = exec_net.requests[0]
    request.infer({"data": np.ones(ref_shape1)})
    assert exec_net.requests[0].input_blobs["data"].tensor_desc.dims == ref_shape1
    assert request.output_blobs['out'].tensor_desc.dims == ref_shape1
    request.infer({"data": np.ones(ref_shape2)})
    assert exec_net.requests[0].input_blobs["data"].tensor_desc.dims == ref_shape2
    assert request.output_blobs['out'].tensor_desc.dims == ref_shape2


@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
def test_infer_dynamic_network_with_set_blob_twice():
    from conftest import create_ngraph_function
    import ngraph as ng
    shape, p_shape = [1, 4, 20, 20], [(0,5), 4, 20, 20]
    ref_shape1, ref_shape2 = [2, 4, 20, 20], [3, 4, 20, 20]
    function = create_ngraph_function(shape)
    net = ng.function_to_cnn(function)
    net.reshape({"data": p_shape})
    ie_core = ie.IECore()
    ie_core.register_plugin("templatePlugin", "TEMPLATE")
    exec_net = ie_core.load_network(net, "TEMPLATE")
    request = exec_net.requests[0]
    td = request.input_blobs['data'].tensor_desc
    td.dims = ref_shape1
    blob = ie.Blob(td)
    request.set_blob("data", blob)
    request.infer({"data": np.ones(ref_shape1)})
    assert exec_net.requests[0].input_blobs["data"].tensor_desc.dims == ref_shape1
    assert request.output_blobs['out'].tensor_desc.dims == ref_shape1
    td = request.input_blobs['data'].tensor_desc
    td.dims = ref_shape2
    blob = ie.Blob(td)
    request.set_blob("data", blob)
    request.infer({"data": np.ones(ref_shape2)})
    assert exec_net.requests[0].input_blobs["data"].tensor_desc.dims == ref_shape2
    assert request.output_blobs['out'].tensor_desc.dims == ref_shape2


@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
@pytest.mark.parametrize("shapes", [
    ([3, 4, 20, 20], [3, 4, 20, 20], [3, 4, 20, 20]),
    ([3, 4, 20, 20], [3, 6, 20, 20], [3, 8, 20, 20]),
])
def test_async_infer_dynamic_network_3_requests(shapes):
    from conftest import create_ngraph_function
    import ngraph as ng
    function = create_ngraph_function([3, 4, 20, 20])
    net = ng.function_to_cnn(function)
    net.reshape({"data": [3, (2, 10), 20, 20]})
    ie_core = ie.IECore()
    ie_core.register_plugin("templatePlugin", "TEMPLATE")
    exec_net = ie_core.load_network(net, "TEMPLATE", num_requests=3)
    for i,request in enumerate(exec_net.requests):
        request.async_infer({"data": np.ones(shapes[i])})
    for i,request in enumerate(exec_net.requests):
        status = request.wait(ie.WaitMode.RESULT_READY)
        assert status == ie.StatusCode.OK
        assert request.output_blobs['out'].tensor_desc.dims == shapes[i]


@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
def test_set_blob_with_incorrect_name():
    from conftest import create_ngraph_function
    import ngraph as ng
    function = create_ngraph_function([4, 4, 20, 20])
    net = ng.function_to_cnn(function)
    ie_core = ie.IECore()
    ie_core.register_plugin("templatePlugin", "TEMPLATE")
    exec_net = ie_core.load_network(net, "TEMPLATE")
    tensor_desc = exec_net.requests[0].input_blobs["data"].tensor_desc
    tensor_desc.dims = [4, 4, 20, 20]
    blob = ie.Blob(tensor_desc)
    with pytest.raises(RuntimeError) as e:
        exec_net.requests[0].set_blob("incorrect_name", blob)
    assert f"Failed to find input or output with name: 'incorrect_name'" in str(e.value)


@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
def test_set_blob_with_incorrect_size():
    from conftest import create_ngraph_function
    import ngraph as ng
    function = create_ngraph_function([4, 4, 20, 20])
    net = ng.function_to_cnn(function)
    ie_core = ie.IECore()
    ie_core.register_plugin("templatePlugin", "TEMPLATE")
    exec_net = ie_core.load_network(net, "TEMPLATE")
    tensor_desc = exec_net.requests[0].input_blobs["data"].tensor_desc
    tensor_desc.dims = [tensor_desc.dims[0]*2, 4, 20, 20]
    blob = ie.Blob(tensor_desc)
    with pytest.raises(RuntimeError) as e:
        exec_net.requests[0].set_blob("data", blob)
    assert f"Input blob size is not equal network input size" in str(e.value)
    with pytest.raises(RuntimeError) as e:
        exec_net.requests[0].set_blob("out", blob)
    assert f"Output blob size is not equal network output size" in str(e.value)


@pytest.mark.ngraph_dependent_test
@pytest.mark.template_plugin
def test_set_blob_after_async_infer():
    from conftest import create_ngraph_function
    import ngraph as ng
    function = create_ngraph_function([ng.Dimension(0,5), ng.Dimension(4), ng.Dimension(20), ng.Dimension(20)])
    net = ng.function_to_cnn(function)
    ie_core = ie.IECore()
    ie_core.register_plugin("templatePlugin", "TEMPLATE")
    exec_net = ie_core.load_network(net, "TEMPLATE")
    request = exec_net.requests[0]
    tensor_desc = request.input_blobs['data'].tensor_desc
    tensor_desc.dims = [2, 4, 20, 20]
    blob = ie.Blob(tensor_desc)
    request.async_infer({"data": np.ones([4, 4, 20, 20])})
    with pytest.raises(RuntimeError) as e:
        request.set_blob("data", blob)
    assert "REQUEST_BUSY" in str(e.value)

@@ -6,17 +6,14 @@ import ngraph as ng
from ngraph.impl.op import Parameter
from ngraph.impl import Function, Shape, Type

from conftest import model_path
from conftest import model_path, create_ngraph_function


test_net_xml, test_net_bin = model_path()


def test_create_IENetwork_from_nGraph():
    element_type = Type.f32
    param = Parameter(element_type, Shape([1, 3, 22, 22]))
    relu = ng.relu(param)
    func = Function([relu], [param], 'test')
    func = create_ngraph_function([1, 3, 22, 22])
    caps = Function.to_capsule(func)
    cnnNetwork = IENetwork(caps)
    assert cnnNetwork != None

@@ -26,10 +23,7 @@ def test_create_IENetwork_from_nGraph():


def test_get_IENetwork_from_nGraph():
    element_type = Type.f32
    param = Parameter(element_type, Shape([1, 3, 22, 22]))
    relu = ng.relu(param)
    func = Function([relu], [param], 'test')
    func = create_ngraph_function([1, 3, 22, 22])
    caps = Function.to_capsule(func)
    cnnNetwork = IENetwork(caps)
    assert cnnNetwork != None

@@ -27,6 +27,7 @@ from ngraph.frontend import OpValidationFailure
from ngraph.frontend import Place
from ngraph.helpers import function_from_cnn
from ngraph.helpers import function_to_cnn
from ngraph.helpers import partial_shape_from_data
from ngraph.opset8 import absolute
from ngraph.opset8 import absolute as abs
from ngraph.opset8 import acos

@@ -3,8 +3,10 @@

"""nGraph helper functions."""

from ngraph.impl import Function
from openvino.inference_engine import IENetwork
from typing import Union

from ngraph.impl import Function, PartialShape
from openvino.inference_engine import IENetwork, DataPtr, CDataPtr


def function_from_cnn(cnn_network: IENetwork) -> Function:

@@ -18,3 +20,9 @@ def function_to_cnn(ng_function: Function) -> Function:
    """Get Inference Engine CNN network from nGraph function."""
    capsule = Function.to_capsule(ng_function)
    return IENetwork(capsule)


def partial_shape_from_data(data: Union[DataPtr, CDataPtr]) -> PartialShape:
    """Get nGraph PartialShape from Inference Engine Data."""
    capsule = data._get_partial_shape_capsule()
    return PartialShape.from_capsule(capsule)
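
A short usage sketch of the partial_shape_from_data helper above (assumed, mirroring the new DataPtr tests; not part of the diff):

import numpy as np
import ngraph as ng

# Build a function with a dynamic batch dimension, as conftest.create_ngraph_function does.
param = ng.parameter(ng.impl.PartialShape([-1, 3, 20, 20]), dtype=np.float32, name="data")
function = ng.Function(ng.relu(param, name="out"), [param], "TestFunction")
net = ng.function_to_cnn(function)

assert net.input_info["data"].input_data.is_dynamic
p_shape = ng.partial_shape_from_data(net.input_info["data"].input_data)
assert isinstance(p_shape, ng.impl.PartialShape)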

@@ -17,6 +17,8 @@

namespace py = pybind11;

static const char* CAPSULE_NAME = "ngraph_partial_shape";

void regclass_pyngraph_PartialShape(py::module m) {
    py::class_<ngraph::PartialShape, std::shared_ptr<ngraph::PartialShape>> shape(m, "PartialShape");
    shape.doc() = "ngraph.impl.PartialShape wraps ngraph::PartialShape";

@@ -199,4 +201,18 @@ void regclass_pyngraph_PartialShape(py::module m) {
    shape.def("__repr__", [](const ngraph::PartialShape& self) -> std::string {
        return "<PartialShape: " + py::cast(self).attr("__str__")().cast<std::string>() + ">";
    });

    shape.def_static("from_capsule", [](py::object* capsule) {
        // get the underlying PyObject* which is a PyCapsule pointer
        auto* pybind_capsule_ptr = capsule->ptr();
        // extract the pointer stored in the PyCapsule under the name CAPSULE_NAME
        auto* capsule_ptr = PyCapsule_GetPointer(pybind_capsule_ptr, CAPSULE_NAME);

        auto* ngraph_pShape = static_cast<std::shared_ptr<ngraph::PartialShape>*>(capsule_ptr);
        if (ngraph_pShape && *ngraph_pShape) {
            return *ngraph_pShape;
        } else {
            throw std::runtime_error("The provided capsule does not contain an ngraph::PartialShape");
        }
    });
}