[ONNX] Exception handling refinements. (#1266)

Adam Osewski 2020-07-15 14:02:18 +02:00 committed by GitHub
parent 382b442ab3
commit 173ce2c907
19 changed files with 326 additions and 67 deletions


@@ -14,11 +14,14 @@
// limitations under the License.
//*****************************************************************************
#include <exception>
#include <functional>
#include <numeric>
#include <sstream>
#include "exceptions.hpp"
#include "graph.hpp"
#include "ngraph/log.hpp"
#include "node.hpp"
#include "provenance.hpp"
#include "utils/common.hpp"
@@ -190,8 +193,29 @@ namespace ngraph
{
const auto ng_node_factory =
m_model->get_operator(onnx_node.op_type(), onnx_node.domain());
const auto ng_node_vector = ng_node_factory(onnx_node);
NodeVector ng_node_vector;
try
{
ng_node_vector = ng_node_factory(onnx_node);
}
catch (const ::ngraph::onnx_import::error::OnnxNodeValidationFailure& exc)
{
// Do nothing: the OnnxNodeValidationFailure exception already carries the ONNX node information.
throw;
}
catch (const std::exception& exc)
{
std::string msg_prefix = error::detail::get_error_msg_prefix(onnx_node);
throw ngraph_error(msg_prefix + ":\n" + std::string(exc.what()));
}
catch (...)
{
std::string msg_prefix = error::detail::get_error_msg_prefix(onnx_node);
// Since we do not know anything about the current exception's type, we can only
// notify the user this way before rethrowing it.
NGRAPH_ERR << msg_prefix + ": Unhandled exception type.\n";
std::rethrow_exception(std::current_exception());
}
set_friendly_names(onnx_node, ng_node_vector);
add_provenance_tags(onnx_node, ng_node_vector);
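
For reference, a minimal caller-side sketch of how the wrapped errors surface through the public onnx_import::import_onnx_model entry point (illustrative only; the model path is a placeholder and is not part of this change):

#include <iostream>
#include "ngraph/frontend/onnx_import/onnx.hpp"
#include "ngraph/ngraph.hpp"

int main()
{
    try
    {
        // Placeholder path; any model whose node factory throws will do.
        auto function = ngraph::onnx_import::import_onnx_model("model.onnx");
    }
    catch (const ngraph::ngraph_error& e)
    {
        // std::exceptions thrown inside a node factory are rethrown as ngraph_error
        // with the "While validating ONNX node ..." prefix added above.
        std::cerr << e.what() << "\n";
    }
    return 0;
}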


@@ -34,22 +34,6 @@ namespace ngraph
std::string get_error_msg_prefix(const Node& node);
}
struct NotSupported : AssertionFailure
{
explicit NotSupported(const std::string& what_arg)
: AssertionFailure(what_arg)
{
}
};
struct InvalidArgument : AssertionFailure
{
explicit InvalidArgument(const std::string& what_arg)
: AssertionFailure(what_arg)
{
}
};
class OnnxNodeValidationFailure : public CheckFailure
{
public:
@@ -67,14 +51,6 @@ namespace ngraph
} // namespace ngraph
#define ASSERT_IS_SUPPORTED(node_, cond_) \
NGRAPH_ASSERT_STREAM_DO_NOT_USE_IN_NEW_CODE(ngraph::onnx_import::error::NotSupported, cond_) \
<< (node_) << " "
#define ASSERT_VALID_ARGUMENT(node_, cond_) \
NGRAPH_ASSERT_STREAM_DO_NOT_USE_IN_NEW_CODE(ngraph::onnx_import::error::InvalidArgument, \
cond_) \
<< (node_) << " "
#define CHECK_VALID_NODE(node_, cond_, ...) \
NGRAPH_CHECK_HELPER( \
::ngraph::onnx_import::error::OnnxNodeValidationFailure, (node_), (cond_), ##__VA_ARGS__)
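
With the stream-style ASSERT_IS_SUPPORTED / ASSERT_VALID_ARGUMENT macros gone, operator implementations pass message fragments as trailing arguments to CHECK_VALID_NODE, which throws OnnxNodeValidationFailure with the ONNX node context already prepended. A minimal usage sketch, as it might appear inside the onnx_import operator sources (the operator and its "mode" attribute are hypothetical, for illustration only):

#include <cstdint>
#include "exceptions.hpp"
#include "node.hpp"

namespace ngraph
{
    namespace onnx_import
    {
        NodeVector example_op(const Node& node)
        {
            const auto mode = node.get_attribute_value<std::int64_t>("mode", 0);
            // Message parts follow the condition instead of being streamed with '<<'.
            CHECK_VALID_NODE(node, mode == 0 || mode == 1, "Unsupported 'mode' value: ", mode);
            return {node.get_ng_inputs().at(0)};
        }
    }
}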


@@ -44,7 +44,7 @@ namespace ngraph
// TODO: Implement learning mode support
// float momentum{node.get_attribute_value<float>("momentum", 0.9f)};
ASSERT_IS_SUPPORTED(node, is_test) << "only 'is_test' mode is supported.";
CHECK_VALID_NODE(node, is_test, "only 'is_test' mode is supported.");
// optional outputs
auto after_bn_mean = std::make_shared<NullNode>();


@@ -39,9 +39,11 @@ namespace ngraph
auto filters = inputs.at(1);
int64_t groups{node.get_attribute_value<int64_t>("group", 1)};
ASSERT_VALID_ARGUMENT(node, (groups == 1))
<< "Only value of 1 for 'group' supported for ConvInteger. Given: "
<< groups;
CHECK_VALID_NODE(
node,
groups == 1,
"Only value of 1 for 'group' supported for ConvInteger. Given: ",
groups);
auto window_movement_strides = convpool::get_strides(node);
auto window_dilation_strides = convpool::get_dilations(node);


@@ -47,9 +47,11 @@ namespace ngraph
target_type = input->get_element_type();
}
ASSERT_VALID_ARGUMENT(node, input_shape.size() == 2)
<< "The provided shape rank: " << input_shape.size()
<< " is unsupported, only 2D shapes are supported";
CHECK_VALID_NODE(node,
input_shape.size() == 2,
"The provided shape rank: ",
input_shape.size(),
" is unsupported, only 2D shapes are supported");
std::shared_ptr<ngraph::Node> eye_like_matrix =
common::shifted_square_identity(input_shape, target_type, shift);


@@ -33,8 +33,8 @@ namespace ngraph
auto data = node.get_ng_inputs().at(0);
double alpha = node.get_attribute_value<double>("alpha", 0.01);
ASSERT_VALID_ARGUMENT(node, ((alpha >= 0) && (alpha <= 1)))
<< " alpha value should be in range (0,1)";
CHECK_VALID_NODE(
node, alpha >= 0 && alpha <= 1, " alpha value should be in range [0, 1]");
std::shared_ptr<ngraph::Node> alpha_node =
default_opset::Constant::create(data->get_element_type(), Shape{}, {alpha});


@@ -51,9 +51,11 @@ namespace ngraph
const size_t normalize_axis =
ngraph::normalize_axis(node.get_description(), axis, data_rank);
ASSERT_VALID_ARGUMENT(node, p_norm == 1 || p_norm == 2)
<< "Invalid `p` attribute value: " << p_norm
<< "Only normalization of 1st or 2nd order is supported.";
CHECK_VALID_NODE(node,
p_norm == 1 || p_norm == 2,
"Invalid `p` attribute value: ",
p_norm,
"Only normalization of 1st or 2nd order is supported.");
const auto normalize_axis_const =
default_opset::Constant::create(element::i64, {}, {normalize_axis});


@@ -53,8 +53,10 @@ namespace ngraph
const std::size_t channels_count = data_shape[channel_axis].get_length();
const std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)};
ASSERT_VALID_ARGUMENT(node, p_norm >= 0)
<< "Only positive (including zero) values are supported for 'p' attribute.";
CHECK_VALID_NODE(
node,
p_norm >= 0,
"Only positive (including zero) values are supported for 'p' attribute.");
NodeVector slices =
ngraph::builder::opset1::split(data, channels_count, channel_axis);


@@ -37,8 +37,8 @@ namespace ngraph
std::shared_ptr<ngraph::Node> divisor{node.get_ng_inputs().at(1)};
std::int64_t fmod = node.get_attribute_value<std::int64_t>("fmod", 0);
ASSERT_IS_SUPPORTED(node, fmod == 1)
<< "Only 'fmod=1' mode is supported for mod operator.";
CHECK_VALID_NODE(
node, fmod == 1, "Only 'fmod=1' mode is supported for mod operator.");
return {std::make_shared<default_opset::Mod>(dividend, divisor)};
}


@@ -79,8 +79,10 @@ namespace ngraph
const auto center_point_box =
node.get_attribute_value<std::int64_t>("center_point_box", 0);
ASSERT_IS_SUPPORTED(node, center_point_box == 0 || center_point_box == 1)
<< "Allowed values of the 'center_point_box' attribute are 0 and 1.";
CHECK_VALID_NODE(
node,
center_point_box == 0 || center_point_box == 1,
"Allowed values of the 'center_point_box' attribute are 0 and 1.");
const auto box_encoding =
center_point_box == 0


@@ -47,8 +47,7 @@ namespace
}
else
{
throw ngraph::onnx_import::error::InvalidArgument("Unsupported padding mode: [" + mode +
"]");
throw ngraph::ngraph_error("Unsupported padding mode: [" + mode + "]");
}
return pad_mode;


@@ -109,7 +109,7 @@ namespace ngraph
if (bias)
{
throw error::NotSupported(
throw ngraph_error(
"Groups != 1 not supported for Quantized Convolution with "
"bias.");
}
@@ -198,22 +198,26 @@ namespace ngraph
auto output_scale = inputs.at(6);
auto output_zero_point = inputs.at(7);
ASSERT_VALID_ARGUMENT(
node,
((groups >= 0) &&
(groups <= static_cast<int64_t>(data->get_shape().at(1))) &&
(groups <= static_cast<int64_t>(filters->get_shape().at(0)))))
<< "incorrect value of 'group' attribute: " << groups;
CHECK_VALID_NODE(node,
((groups >= 0) &&
(groups <= static_cast<int64_t>(data->get_shape().at(1))) &&
(groups <= static_cast<int64_t>(filters->get_shape().at(0)))),
"incorrect value of 'group' attribute: ",
groups);
std::size_t n_data_channels{data->get_shape().at(1)};
std::size_t n_filters_channels{filters->get_shape().at(0)};
ASSERT_VALID_ARGUMENT(node, n_data_channels % groups == 0)
<< "provided group attribute value must be a multiple of data channels "
"count.";
ASSERT_VALID_ARGUMENT(node, n_filters_channels % groups == 0)
<< "provided group attribute value must be a multiple of filter channels "
"count.";
CHECK_VALID_NODE(
node,
n_data_channels % groups == 0,
"provided group attribute value must be a multiple of data channels "
"count.");
CHECK_VALID_NODE(
node,
n_filters_channels % groups == 0,
"provided group attribute value must be a multiple of filter channels "
"count.");
Strides strides = convpool::get_strides(node);
Strides filter_dilations = convpool::get_dilations(node);


@@ -34,8 +34,11 @@ namespace ngraph
const float bias = node.get_attribute_value<float>("bias", 0.0f);
const float lambd = node.get_attribute_value<float>("lambd", 0.5f);
ASSERT_VALID_ARGUMENT(node, !(lambd < 0.0f))
<< " The provided 'lambd' value:" << lambd << " must not be negative.";
CHECK_VALID_NODE(node,
!(lambd < 0.0f),
" The provided 'lambd' value: ",
lambd,
" must not be negative.");
std::shared_ptr<default_opset::Constant> negative_lambd;
const auto input_element_type = input->get_element_type();


@@ -64,9 +64,13 @@ namespace ngraph
auto reduction_axes = detail::get_reduction_axes(node);
ASSERT_VALID_ARGUMENT(node, reduction_axes.size() <= data_shape.size())
<< "provided reduction axes count (" << reduction_axes.size()
<< ") is larger than input tensor rank (" << data_shape.size() << ")";
CHECK_VALID_NODE(node,
reduction_axes.size() <= data_shape.size(),
"provided reduction axes count (",
reduction_axes.size(),
") is larger than input tensor rank (",
data_shape.size(),
")");
std::shared_ptr<ngraph::Node> op_node =
reduction_function(ng_input, reduction_axes);
@@ -99,9 +103,13 @@ namespace ngraph
const auto reduction_axes = detail::get_reduction_axes(node);
ASSERT_VALID_ARGUMENT(node, reduction_axes.size() <= data_rank)
<< "provided reduction axes count (" << reduction_axes.size()
<< ") is larger than input tensor rank (" << data_rank << ")";
CHECK_VALID_NODE(node,
reduction_axes.size() <= data_rank,
"provided reduction axes count (",
reduction_axes.size(),
") is larger than input tensor rank (",
data_rank,
")");
std::int64_t keepdims = node.get_attribute_value<std::int64_t>("keepdims", 1);


@@ -16,6 +16,7 @@
#include "ngraph/op/pad.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/except.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"


@@ -381,6 +381,8 @@ if (NGRAPH_ONNX_IMPORT_ENABLE AND NOT NGRAPH_USE_PROTOBUF_LITE)
onnx/onnx_import_reshape.in.cpp
onnx/onnx_import_rnn.in.cpp
onnx/onnx_import_quant.in.cpp)
list(APPEND SRC
onnx/onnx_import_exceptions.cpp)
endif()
foreach(BACKEND_NAME ${ACTIVE_BACKEND_LIST})


@@ -0,0 +1,38 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
input: "y"
output: "sum"
op_type: "Add"
}
name: "test_add_dyn_shapes"
input {
name: "x"
type {
tensor_type {
elem_type: 1
}
}
}
input {
name: "y"
type {
tensor_type {
elem_type: 1
}
}
}
output {
name: "sum"
type {
tensor_type {
elem_type: 1
}
}
}
}
opset_import {
version: 1
}


@@ -0,0 +1,92 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
input: "scale"
input: "bias"
output: "y"
op_type: "InstanceNormalization"
attribute {
name: "epsilon"
f: 0.01
type: FLOAT
}
}
name: "instance_norm_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "scale"
type {
tensor_type {
elem_type: 4
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "bias"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
version: 1
}


@@ -0,0 +1,102 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <exception>
#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/frontend/onnx_import/exceptions.hpp"
#include "ngraph/frontend/onnx_import/onnx.hpp"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace ngraph;
TEST(onnx_importer, exception_throws_ngraph_error)
{
EXPECT_THROW(onnx_import::import_onnx_model(file_util::path_join(
SERIALIZED_ZOO, "onnx/depth_to_space_bad_blocksize.prototxt")),
ngraph_error);
}
TEST(onnx_importer, exception_msg_ngraph_error)
{
try
{
onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/depth_to_space_bad_blocksize.prototxt"));
// Should have thrown, so fail if it didn't
FAIL() << "ONNX Importer did not detected incorrect model!";
}
catch (const ngraph_error& e)
{
EXPECT_HAS_SUBSTRING(e.what(),
std::string("While validating ONNX node '<Node(DepthToSpace)"));
EXPECT_HAS_SUBSTRING(e.what(), std::string("While validating node 'v0::DepthToSpace"));
}
catch (...)
{
FAIL() << "The ONNX model importer failed for unexpected reason";
}
}
TEST(onnx_importer, exception_msg_onnx_node_validation_failure)
{
try
{
onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/instance_norm_bad_scale_type.prototxt"));
// Should have thrown, so fail if it didn't
FAIL() << "ONNX Importer did not detected incorrect model!";
}
catch (const ::ngraph::onnx_import::error::OnnxNodeValidationFailure& e)
{
EXPECT_HAS_SUBSTRING(
e.what(), std::string("While validating ONNX node '<Node(InstanceNormalization)"));
}
// On macOS, after re-throwing the OnnxNodeValidationFailure exception we cannot catch it
// by its exact type again, hence the workaround below.
catch (const std::exception& e)
{
EXPECT_HAS_SUBSTRING(
e.what(), std::string("While validating ONNX node '<Node(InstanceNormalization)"));
}
catch (...)
{
FAIL() << "The ONNX model importer failed for unexpected reason";
}
}
// This test checks that every std::exception not deriving from ngraph_error gets wrapped.
// Importing this model throws a std::exception due to an attempt to access the shape of a dynamic tensor.
TEST(onnx_importer, exception_msg_std_err_wrapped)
{
try
{
onnx_import::import_onnx_model(file_util::path_join(
SERIALIZED_ZOO, "onnx/dynamic_shapes/add_opset6_dyn_shape.prototxt"));
// Should have thrown, so fail if it didn't
FAIL() << "ONNX Importer did not detected incorrect model!";
}
catch (const std::exception& e)
{
EXPECT_HAS_SUBSTRING(e.what(), std::string("While validating ONNX node '<Node(Add)"));
}
catch (...)
{
FAIL() << "The ONNX model importer failed for unexpected reason";
}
}
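
The last test exercises the plain std::exception wrapping path: the add_opset6_dyn_shape model shown earlier declares its inputs without shapes, so the importer trips over an attempt to access the shape of a dynamic tensor. A rough standalone sketch of that underlying failure mode (assuming nGraph's Parameter, PartialShape, and get_shape behaviour in this version):

#include <iostream>
#include <memory>
#include "ngraph/ngraph.hpp"

int main()
{
    using namespace ngraph;
    // An input with no shape information, like "x" and "y" in the model above.
    auto x = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    try
    {
        // Asking a dynamic tensor for a static shape throws; inside the importer
        // this is now wrapped into ngraph_error with the ONNX node prefix.
        Shape s = x->get_shape();
    }
    catch (const std::exception& e)
    {
        std::cerr << e.what() << "\n";
    }
    return 0;
}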