ONNX model validator enhancements (#10456)
This commit is contained in:
@@ -26,6 +26,7 @@ from google.protobuf import text_format
|
||||
import onnx
|
||||
from onnx.external_data_helper import convert_model_to_external_data
|
||||
import os
|
||||
import sys
|
||||
|
||||
ONNX_SUFFX = '.onnx'
|
||||
PROTOTXT_SUFFX = '.prototxt'
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <array>
|
||||
#include <exception>
|
||||
#include <map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace {
|
||||
@@ -128,7 +129,7 @@ ONNXField decode_next_field(std::istream& model) {
|
||||
switch (decoded_key.second) {
|
||||
case VARINT: {
|
||||
// the decoded varint is the payload in this case but its value does not matter
|
||||
// in the fast check process so you can discard it
|
||||
// in the fast check process so it can be discarded
|
||||
decode_varint(model);
|
||||
return {onnx_field, 0};
|
||||
}
|
||||
@@ -198,21 +199,23 @@ namespace ngraph {
|
||||
namespace onnx_common {
|
||||
bool is_valid_model(std::istream& model) {
|
||||
// the model usually starts with a 0x08 byte indicating the ir_version value
|
||||
// so this checker expects at least 2 valid ONNX keys to be found in the validated model
|
||||
const unsigned int EXPECTED_FIELDS_FOUND = 2u;
|
||||
unsigned int valid_fields_found = 0u;
|
||||
// so this checker expects at least 3 valid ONNX keys to be found in the validated model
|
||||
const size_t EXPECTED_FIELDS_FOUND = 3u;
|
||||
std::unordered_set<onnx::Field, std::hash<int>> onnx_fields_found = {};
|
||||
try {
|
||||
while (!model.eof() && valid_fields_found < EXPECTED_FIELDS_FOUND) {
|
||||
while (!model.eof() && onnx_fields_found.size() < EXPECTED_FIELDS_FOUND) {
|
||||
const auto field = ::onnx::decode_next_field(model);
|
||||
|
||||
++valid_fields_found;
|
||||
|
||||
if (field.second > 0) {
|
||||
::onnx::skip_payload(model, field.second);
|
||||
if (onnx_fields_found.count(field.first) > 0) {
|
||||
// if the same field is found twice, this is not a valid ONNX model
|
||||
return false;
|
||||
} else {
|
||||
onnx_fields_found.insert(field.first);
|
||||
onnx::skip_payload(model, field.second);
|
||||
}
|
||||
}
|
||||
|
||||
return valid_fields_found == EXPECTED_FIELDS_FOUND;
|
||||
return onnx_fields_found.size() == EXPECTED_FIELDS_FOUND;
|
||||
} catch (...) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -63,3 +63,9 @@ TEST(ONNXReader_ModelUnsupported, unknown_wire_type) {
|
||||
EXPECT_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("unsupported/unknown_wire_type.onnx")),
|
||||
InferenceEngine::NetworkNotRead);
|
||||
}
|
||||
|
||||
TEST(ONNXReader_ModelUnsupported, duplicate_fields) {
|
||||
// the model contains the IR_VERSION field twice - this is not correct
|
||||
EXPECT_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("unsupported/duplicate_onnx_fields.onnx")),
|
||||
std::exception);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
ONNX Reader test2Doc string for this model:D
|
||||
|
||||
xy"Cosh
|
||||
cosh_graphZ
|
||||
x
|
||||
|
||||
|
||||
b
|
||||
y
|
||||
|
||||
|
||||
B
|
||||
Reference in New Issue
Block a user