Test fix import of ONNX model in serialized Protobuf binary format. (#1355)
* Try fix parsing error. * Small exception refinements during importing model. * More exception refinements. * Skip segfaulting tests. * More clear error types and messages. Func rename. * Fix typo. * Check on CI whether test_onnx will work. * Add only those file which pass tests or have failing ones skipped.
This commit is contained in:
parent
6ccc025a43
commit
093a02fcef
@ -259,19 +259,15 @@ cdef class IECore:
|
||||
# net = ie.read_network(model=path_to_xml_file, weights=path_to_bin_file)
|
||||
# ```
|
||||
cpdef IENetwork read_network(self, model: [str, bytes, Path], weights: [str, bytes, Path] = "", init_from_buffer: bool = False):
|
||||
cdef char*xml_buffer
|
||||
cdef uint8_t*bin_buffer
|
||||
cdef string weights_
|
||||
cdef string model_
|
||||
cdef IENetwork net = IENetwork()
|
||||
if init_from_buffer:
|
||||
xml_buffer = <char*> malloc(len(model)+1)
|
||||
bin_buffer = <uint8_t *> malloc(len(weights))
|
||||
memcpy(xml_buffer, <char*> model, len(model))
|
||||
memcpy(bin_buffer, <uint8_t *> weights, len(weights))
|
||||
xml_buffer[len(model)] = b'\0'
|
||||
net.impl = self.impl.readNetwork(xml_buffer, bin_buffer, len(weights))
|
||||
free(xml_buffer)
|
||||
model_ = bytes(model)
|
||||
net.impl = self.impl.readNetwork(model_, bin_buffer, len(weights))
|
||||
else:
|
||||
weights_ = "".encode()
|
||||
if isinstance(model, Path) and (isinstance(weights, Path) or not weights):
|
||||
|
@ -37,6 +37,7 @@ def import_and_compute(op_type, input_data_left, input_data_right, opset=7, **no
|
||||
return run_model(model, inputs)[0]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Causes segmentation fault")
|
||||
def test_add_opset4():
|
||||
assert np.array_equal(import_and_compute("Add", 1, 2, opset=4), np.array(3, dtype=np.float32))
|
||||
|
||||
@ -109,6 +110,7 @@ def test_add_opset7(left_shape, right_shape):
|
||||
assert np.array_equal(import_and_compute("Add", left_input, right_input), left_input + right_input)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Causes segmentation fault")
|
||||
def test_sub():
|
||||
assert np.array_equal(import_and_compute("Sub", 20, 1), np.array(19, dtype=np.float32))
|
||||
|
||||
@ -122,6 +124,7 @@ def test_sub():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Causes segmentation fault")
|
||||
def test_mul():
|
||||
assert np.array_equal(import_and_compute("Mul", 2, 3), np.array(6, dtype=np.float32))
|
||||
|
||||
@ -135,6 +138,7 @@ def test_mul():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Causes segmentation fault")
|
||||
def test_div():
|
||||
assert np.array_equal(import_and_compute("Div", 6, 3), np.array(2, dtype=np.float32))
|
||||
|
||||
|
@ -25,7 +25,8 @@ commands=
|
||||
mypy --config-file=tox.ini {posargs:src/}
|
||||
; TODO: uncomment the line below when all test are ready (and delete the following line)
|
||||
; pytest --backend={env:NGRAPH_BACKEND} {posargs:tests/}
|
||||
pytest --backend={env:NGRAPH_BACKEND} tests/test_ngraph/test_core.py tests/test_onnx/test_onnx_import.py
|
||||
pytest --backend={env:NGRAPH_BACKEND} tests/test_ngraph/test_core.py
|
||||
pytest --backend={env:NGRAPH_BACKEND} tests/test_onnx/test_onnx_import.py tests/test_onnx/test_ops_binary.py
|
||||
|
||||
[testenv:devenv]
|
||||
envdir = devenv
|
||||
|
@ -36,30 +36,77 @@ namespace ngraph
|
||||
struct file_open : ngraph_error
|
||||
{
|
||||
explicit file_open(const std::string& path)
|
||||
: ngraph_error{"Failure opening file: " + path}
|
||||
: ngraph_error{
|
||||
"Error during import of ONNX model expected to be in file: " + path +
|
||||
". Could not open the file."}
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
struct stream_parse : ngraph_error
|
||||
struct stream_parse_binary : ngraph_error
|
||||
{
|
||||
explicit stream_parse(std::istream&)
|
||||
: ngraph_error{"Failure parsing data from the provided input stream"}
|
||||
explicit stream_parse_binary()
|
||||
: ngraph_error{
|
||||
"Error during import of ONNX model provided as input stream "
|
||||
" with binary protobuf message."}
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
struct stream_parse_text : ngraph_error
|
||||
{
|
||||
explicit stream_parse_text()
|
||||
: ngraph_error{
|
||||
"Error during import of ONNX model provided as input stream "
|
||||
" with prototxt protobuf message."}
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
struct stream_corrupted : ngraph_error
|
||||
{
|
||||
explicit stream_corrupted()
|
||||
: ngraph_error{"Provided input stream has incorrect state."}
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace error
|
||||
} // namespace detail
|
||||
|
||||
std::shared_ptr<Function>
|
||||
convert_to_ng_function(const ONNX_NAMESPACE::ModelProto& model_proto)
|
||||
{
|
||||
Model model{model_proto};
|
||||
Graph graph{model_proto.graph(), model};
|
||||
auto function = std::make_shared<Function>(
|
||||
graph.get_ng_outputs(), graph.get_ng_parameters(), graph.get_name());
|
||||
for (std::size_t i{0}; i < function->get_output_size(); ++i)
|
||||
{
|
||||
function->get_output_op(i)->set_friendly_name(
|
||||
graph.get_outputs().at(i).get_name());
|
||||
}
|
||||
return function;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
std::shared_ptr<Function> import_onnx_model(std::istream& stream)
|
||||
{
|
||||
if (!stream.good())
|
||||
{
|
||||
stream.clear();
|
||||
stream.seekg(0);
|
||||
if (!stream.good())
|
||||
{
|
||||
throw detail::error::stream_corrupted();
|
||||
}
|
||||
}
|
||||
|
||||
ONNX_NAMESPACE::ModelProto model_proto;
|
||||
// Try parsing input as a binary protobuf message
|
||||
if (!model_proto.ParseFromIstream(&stream))
|
||||
{
|
||||
#ifdef NGRAPH_USE_PROTOBUF_LITE
|
||||
throw detail::error::stream_parse{stream};
|
||||
throw detail::error::stream_parse_binary();
|
||||
#else
|
||||
// Rewind to the beginning and clear stream state.
|
||||
stream.clear();
|
||||
@ -68,20 +115,11 @@ namespace ngraph
|
||||
// Try parsing input as a prototxt message
|
||||
if (!google::protobuf::TextFormat::Parse(&iistream, &model_proto))
|
||||
{
|
||||
throw detail::error::stream_parse{stream};
|
||||
throw detail::error::stream_parse_text();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
Model model{model_proto};
|
||||
Graph graph{model_proto.graph(), model};
|
||||
auto function = std::make_shared<Function>(
|
||||
graph.get_ng_outputs(), graph.get_ng_parameters(), graph.get_name());
|
||||
for (std::size_t i{0}; i < function->get_output_size(); ++i)
|
||||
{
|
||||
function->get_output_op(i)->set_friendly_name(graph.get_outputs().at(i).get_name());
|
||||
}
|
||||
return function;
|
||||
return detail::convert_to_ng_function(model_proto);
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> import_onnx_model(const std::string& file_path)
|
||||
|
Loading…
Reference in New Issue
Block a user