ONNX model validator enhancements (#10456)

This commit is contained in:
Tomasz Dołbniak
2022-02-17 11:01:47 +01:00
committed by GitHub
parent 61f915b4f6
commit 83a8ac800c
4 changed files with 32 additions and 10 deletions

View File

@@ -26,6 +26,7 @@ from google.protobuf import text_format
import onnx
from onnx.external_data_helper import convert_model_to_external_data
import os
import sys
ONNX_SUFFX = '.onnx'
PROTOTXT_SUFFX = '.prototxt'

View File

@@ -8,6 +8,7 @@
#include <array>
#include <exception>
#include <map>
#include <unordered_set>
#include <vector>
namespace {
@@ -128,7 +129,7 @@ ONNXField decode_next_field(std::istream& model) {
switch (decoded_key.second) {
case VARINT: {
// the decoded varint is the payload in this case but its value does not matter
// in the fast check process so you can discard it
// in the fast check process so it can be discarded
decode_varint(model);
return {onnx_field, 0};
}
@@ -198,21 +199,23 @@ namespace ngraph {
namespace onnx_common {
bool is_valid_model(std::istream& model) {
// the model usually starts with a 0x08 byte indicating the ir_version value
// so this checker expects at least 2 valid ONNX keys to be found in the validated model
const unsigned int EXPECTED_FIELDS_FOUND = 2u;
unsigned int valid_fields_found = 0u;
// so this checker expects at least 3 valid ONNX keys to be found in the validated model
const size_t EXPECTED_FIELDS_FOUND = 3u;
std::unordered_set<onnx::Field, std::hash<int>> onnx_fields_found = {};
try {
while (!model.eof() && valid_fields_found < EXPECTED_FIELDS_FOUND) {
while (!model.eof() && onnx_fields_found.size() < EXPECTED_FIELDS_FOUND) {
const auto field = ::onnx::decode_next_field(model);
++valid_fields_found;
if (field.second > 0) {
::onnx::skip_payload(model, field.second);
if (onnx_fields_found.count(field.first) > 0) {
// if the same field is found twice, this is not a valid ONNX model
return false;
} else {
onnx_fields_found.insert(field.first);
onnx::skip_payload(model, field.second);
}
}
return valid_fields_found == EXPECTED_FIELDS_FOUND;
return onnx_fields_found.size() == EXPECTED_FIELDS_FOUND;
} catch (...) {
return false;
}

View File

@@ -63,3 +63,9 @@ TEST(ONNXReader_ModelUnsupported, unknown_wire_type) {
EXPECT_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("unsupported/unknown_wire_type.onnx")),
InferenceEngine::NetworkNotRead);
}
TEST(ONNXReader_ModelUnsupported, duplicate_fields) {
// the model contains the IR_VERSION field twice - this is not correct
EXPECT_THROW(InferenceEngine::Core{}.ReadNetwork(model_path("unsupported/duplicate_onnx_fields.onnx")),
std::exception);
}

View File

@@ -0,0 +1,12 @@
ONNX Reader test2Doc string for this model:D
xy"Cosh
cosh_graphZ
x


b
y


B