Test fix import of ONNX model in serialized Protobuf binary format. (#1355)

* Try fix parsing error.

* Small exception refinements during importing model.

* More exception refinements.

* Skip segfaulting tests.

* More clear error types and messages. Func rename.

* Fix typo.

* Check on CI whether test_onnx will work.

* Add only those file which pass tests or have failing ones skipped.
This commit is contained in:
Adam Osewski 2020-07-22 12:52:53 +02:00 committed by GitHub
parent 6ccc025a43
commit 093a02fcef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 63 additions and 24 deletions

View File

@ -259,19 +259,15 @@ cdef class IECore:
# net = ie.read_network(model=path_to_xml_file, weights=path_to_bin_file)
# ```
cpdef IENetwork read_network(self, model: [str, bytes, Path], weights: [str, bytes, Path] = "", init_from_buffer: bool = False):
cdef char*xml_buffer
cdef uint8_t*bin_buffer
cdef string weights_
cdef string model_
cdef IENetwork net = IENetwork()
if init_from_buffer:
xml_buffer = <char*> malloc(len(model)+1)
bin_buffer = <uint8_t *> malloc(len(weights))
memcpy(xml_buffer, <char*> model, len(model))
memcpy(bin_buffer, <uint8_t *> weights, len(weights))
xml_buffer[len(model)] = b'\0'
net.impl = self.impl.readNetwork(xml_buffer, bin_buffer, len(weights))
free(xml_buffer)
model_ = bytes(model)
net.impl = self.impl.readNetwork(model_, bin_buffer, len(weights))
else:
weights_ = "".encode()
if isinstance(model, Path) and (isinstance(weights, Path) or not weights):

View File

@ -37,6 +37,7 @@ def import_and_compute(op_type, input_data_left, input_data_right, opset=7, **no
return run_model(model, inputs)[0]
@pytest.mark.skip(reason="Causes segmentation fault")
def test_add_opset4():
assert np.array_equal(import_and_compute("Add", 1, 2, opset=4), np.array(3, dtype=np.float32))
@ -109,6 +110,7 @@ def test_add_opset7(left_shape, right_shape):
assert np.array_equal(import_and_compute("Add", left_input, right_input), left_input + right_input)
@pytest.mark.skip(reason="Causes segmentation fault")
def test_sub():
assert np.array_equal(import_and_compute("Sub", 20, 1), np.array(19, dtype=np.float32))
@ -122,6 +124,7 @@ def test_sub():
)
@pytest.mark.skip(reason="Causes segmentation fault")
def test_mul():
assert np.array_equal(import_and_compute("Mul", 2, 3), np.array(6, dtype=np.float32))
@ -135,6 +138,7 @@ def test_mul():
)
@pytest.mark.skip(reason="Causes segmentation fault")
def test_div():
assert np.array_equal(import_and_compute("Div", 6, 3), np.array(2, dtype=np.float32))

View File

@ -25,7 +25,8 @@ commands=
mypy --config-file=tox.ini {posargs:src/}
; TODO: uncomment the line below when all test are ready (and delete the following line)
; pytest --backend={env:NGRAPH_BACKEND} {posargs:tests/}
pytest --backend={env:NGRAPH_BACKEND} tests/test_ngraph/test_core.py tests/test_onnx/test_onnx_import.py
pytest --backend={env:NGRAPH_BACKEND} tests/test_ngraph/test_core.py
pytest --backend={env:NGRAPH_BACKEND} tests/test_onnx/test_onnx_import.py tests/test_onnx/test_ops_binary.py
[testenv:devenv]
envdir = devenv

View File

@ -36,30 +36,77 @@ namespace ngraph
struct file_open : ngraph_error
{
explicit file_open(const std::string& path)
: ngraph_error{"Failure opening file: " + path}
: ngraph_error{
"Error during import of ONNX model expected to be in file: " + path +
". Could not open the file."}
{
}
};
struct stream_parse : ngraph_error
struct stream_parse_binary : ngraph_error
{
explicit stream_parse(std::istream&)
: ngraph_error{"Failure parsing data from the provided input stream"}
explicit stream_parse_binary()
: ngraph_error{
"Error during import of ONNX model provided as input stream "
" with binary protobuf message."}
{
}
};
struct stream_parse_text : ngraph_error
{
explicit stream_parse_text()
: ngraph_error{
"Error during import of ONNX model provided as input stream "
" with prototxt protobuf message."}
{
}
};
struct stream_corrupted : ngraph_error
{
explicit stream_corrupted()
: ngraph_error{"Provided input stream has incorrect state."}
{
}
};
} // namespace error
} // namespace detail
std::shared_ptr<Function>
convert_to_ng_function(const ONNX_NAMESPACE::ModelProto& model_proto)
{
Model model{model_proto};
Graph graph{model_proto.graph(), model};
auto function = std::make_shared<Function>(
graph.get_ng_outputs(), graph.get_ng_parameters(), graph.get_name());
for (std::size_t i{0}; i < function->get_output_size(); ++i)
{
function->get_output_op(i)->set_friendly_name(
graph.get_outputs().at(i).get_name());
}
return function;
}
} // namespace detail
std::shared_ptr<Function> import_onnx_model(std::istream& stream)
{
if (!stream.good())
{
stream.clear();
stream.seekg(0);
if (!stream.good())
{
throw detail::error::stream_corrupted();
}
}
ONNX_NAMESPACE::ModelProto model_proto;
// Try parsing input as a binary protobuf message
if (!model_proto.ParseFromIstream(&stream))
{
#ifdef NGRAPH_USE_PROTOBUF_LITE
throw detail::error::stream_parse{stream};
throw detail::error::stream_parse_binary();
#else
// Rewind to the beginning and clear stream state.
stream.clear();
@ -68,20 +115,11 @@ namespace ngraph
// Try parsing input as a prototxt message
if (!google::protobuf::TextFormat::Parse(&iistream, &model_proto))
{
throw detail::error::stream_parse{stream};
throw detail::error::stream_parse_text();
}
#endif
}
Model model{model_proto};
Graph graph{model_proto.graph(), model};
auto function = std::make_shared<Function>(
graph.get_ng_outputs(), graph.get_ng_parameters(), graph.get_name());
for (std::size_t i{0}; i < function->get_output_size(); ++i)
{
function->get_output_op(i)->set_friendly_name(graph.get_outputs().at(i).get_name());
}
return function;
return detail::convert_to_ng_function(model_proto);
}
std::shared_ptr<Function> import_onnx_model(const std::string& file_path)