[ONNX] Exception handling refinements. (#1266)
This commit is contained in:
parent
382b442ab3
commit
173ce2c907
@ -14,11 +14,14 @@
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <exception>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "exceptions.hpp"
|
||||
#include "graph.hpp"
|
||||
#include "ngraph/log.hpp"
|
||||
#include "node.hpp"
|
||||
#include "provenance.hpp"
|
||||
#include "utils/common.hpp"
|
||||
@ -190,8 +193,29 @@ namespace ngraph
|
||||
{
|
||||
const auto ng_node_factory =
|
||||
m_model->get_operator(onnx_node.op_type(), onnx_node.domain());
|
||||
|
||||
const auto ng_node_vector = ng_node_factory(onnx_node);
|
||||
NodeVector ng_node_vector;
|
||||
try
|
||||
{
|
||||
ng_node_vector = ng_node_factory(onnx_node);
|
||||
}
|
||||
catch (const ::ngraph::onnx_import::error::OnnxNodeValidationFailure& exc)
|
||||
{
|
||||
// Do nothing OnnxNodeValidationFailure exception already has ONNX node information.
|
||||
throw;
|
||||
}
|
||||
catch (const std::exception& exc)
|
||||
{
|
||||
std::string msg_prefix = error::detail::get_error_msg_prefix(onnx_node);
|
||||
throw ngraph_error(msg_prefix + ":\n" + std::string(exc.what()));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
std::string msg_prefix = error::detail::get_error_msg_prefix(onnx_node);
|
||||
// Since we do not know anything about current exception data type we can only
|
||||
// notify user in this way.
|
||||
NGRAPH_ERR << msg_prefix + "Unhandled exception type. \n";
|
||||
std::rethrow_exception(std::current_exception());
|
||||
}
|
||||
set_friendly_names(onnx_node, ng_node_vector);
|
||||
add_provenance_tags(onnx_node, ng_node_vector);
|
||||
|
||||
|
@ -34,22 +34,6 @@ namespace ngraph
|
||||
std::string get_error_msg_prefix(const Node& node);
|
||||
}
|
||||
|
||||
struct NotSupported : AssertionFailure
|
||||
{
|
||||
explicit NotSupported(const std::string& what_arg)
|
||||
: AssertionFailure(what_arg)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
struct InvalidArgument : AssertionFailure
|
||||
{
|
||||
explicit InvalidArgument(const std::string& what_arg)
|
||||
: AssertionFailure(what_arg)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
class OnnxNodeValidationFailure : public CheckFailure
|
||||
{
|
||||
public:
|
||||
@ -67,14 +51,6 @@ namespace ngraph
|
||||
|
||||
} // namespace ngraph
|
||||
|
||||
#define ASSERT_IS_SUPPORTED(node_, cond_) \
|
||||
NGRAPH_ASSERT_STREAM_DO_NOT_USE_IN_NEW_CODE(ngraph::onnx_import::error::NotSupported, cond_) \
|
||||
<< (node_) << " "
|
||||
#define ASSERT_VALID_ARGUMENT(node_, cond_) \
|
||||
NGRAPH_ASSERT_STREAM_DO_NOT_USE_IN_NEW_CODE(ngraph::onnx_import::error::InvalidArgument, \
|
||||
cond_) \
|
||||
<< (node_) << " "
|
||||
|
||||
#define CHECK_VALID_NODE(node_, cond_, ...) \
|
||||
NGRAPH_CHECK_HELPER( \
|
||||
::ngraph::onnx_import::error::OnnxNodeValidationFailure, (node_), (cond_), ##__VA_ARGS__)
|
||||
|
@ -44,7 +44,7 @@ namespace ngraph
|
||||
|
||||
// TODO: Implement learning mode support
|
||||
// float momentum{node.get_attribute_value<float>("momentum", 0.9f)};
|
||||
ASSERT_IS_SUPPORTED(node, is_test) << "only 'is_test' mode is supported.";
|
||||
CHECK_VALID_NODE(node, is_test, "only 'is_test' mode is supported.");
|
||||
|
||||
// optional outputs
|
||||
auto after_bn_mean = std::make_shared<NullNode>();
|
||||
|
@ -39,9 +39,11 @@ namespace ngraph
|
||||
auto filters = inputs.at(1);
|
||||
|
||||
int64_t groups{node.get_attribute_value<int64_t>("group", 1)};
|
||||
ASSERT_VALID_ARGUMENT(node, (groups == 1))
|
||||
<< "Only value of 1 for 'group' supported for ConvInteger. Given: "
|
||||
<< groups;
|
||||
CHECK_VALID_NODE(
|
||||
node,
|
||||
groups == 1,
|
||||
"Only value of 1 for 'group' supported for ConvInteger. Given: ",
|
||||
groups);
|
||||
|
||||
auto window_movement_strides = convpool::get_strides(node);
|
||||
auto window_dilation_strides = convpool::get_dilations(node);
|
||||
|
@ -47,9 +47,11 @@ namespace ngraph
|
||||
target_type = input->get_element_type();
|
||||
}
|
||||
|
||||
ASSERT_VALID_ARGUMENT(node, input_shape.size() == 2)
|
||||
<< "The provided shape rank: " << input_shape.size()
|
||||
<< " is unsupported, only 2D shapes are supported";
|
||||
CHECK_VALID_NODE(node,
|
||||
input_shape.size() == 2,
|
||||
"The provided shape rank: ",
|
||||
input_shape.size(),
|
||||
" is unsupported, only 2D shapes are supported");
|
||||
|
||||
std::shared_ptr<ngraph::Node> eye_like_matrix =
|
||||
common::shifted_square_identity(input_shape, target_type, shift);
|
||||
|
@ -33,8 +33,8 @@ namespace ngraph
|
||||
auto data = node.get_ng_inputs().at(0);
|
||||
double alpha = node.get_attribute_value<double>("alpha", 0.01);
|
||||
|
||||
ASSERT_VALID_ARGUMENT(node, ((alpha >= 0) && (alpha <= 1)))
|
||||
<< " alpha value should be in range (0,1)";
|
||||
CHECK_VALID_NODE(
|
||||
node, alpha >= 0 && alpha <= 1, " alpha value should be in range (0,1)");
|
||||
|
||||
std::shared_ptr<ngraph::Node> alpha_node =
|
||||
default_opset::Constant::create(data->get_element_type(), Shape{}, {alpha});
|
||||
|
@ -51,9 +51,11 @@ namespace ngraph
|
||||
const size_t normalize_axis =
|
||||
ngraph::normalize_axis(node.get_description(), axis, data_rank);
|
||||
|
||||
ASSERT_VALID_ARGUMENT(node, p_norm == 1 || p_norm == 2)
|
||||
<< "Invalid `p` attribute value: " << p_norm
|
||||
<< "Only normalization of 1st or 2nd order is supported.";
|
||||
CHECK_VALID_NODE(node,
|
||||
p_norm == 1 || p_norm == 2,
|
||||
"Invalid `p` attribute value: ",
|
||||
p_norm,
|
||||
"Only normalization of 1st or 2nd order is supported.");
|
||||
|
||||
const auto normalize_axis_const =
|
||||
default_opset::Constant::create(element::i64, {}, {normalize_axis});
|
||||
|
@ -53,8 +53,10 @@ namespace ngraph
|
||||
const std::size_t channels_count = data_shape[channel_axis].get_length();
|
||||
const std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)};
|
||||
|
||||
ASSERT_VALID_ARGUMENT(node, p_norm >= 0)
|
||||
<< "Only positive (including zero) values are supported for 'p' attribute.";
|
||||
CHECK_VALID_NODE(
|
||||
node,
|
||||
p_norm >= 0,
|
||||
"Only positive (including zero) values are supported for 'p' attribute.");
|
||||
|
||||
NodeVector slices =
|
||||
ngraph::builder::opset1::split(data, channels_count, channel_axis);
|
||||
|
@ -37,8 +37,8 @@ namespace ngraph
|
||||
std::shared_ptr<ngraph::Node> divisor{node.get_ng_inputs().at(1)};
|
||||
|
||||
std::int64_t fmod = node.get_attribute_value<std::int64_t>("fmod", 0);
|
||||
ASSERT_IS_SUPPORTED(node, fmod == 1)
|
||||
<< "Only 'fmod=1' mode is supported for mod operator.";
|
||||
CHECK_VALID_NODE(
|
||||
node, fmod == 1, "Only 'fmod=1' mode is supported for mod operator.");
|
||||
|
||||
return {std::make_shared<default_opset::Mod>(dividend, divisor)};
|
||||
}
|
||||
|
@ -79,8 +79,10 @@ namespace ngraph
|
||||
const auto center_point_box =
|
||||
node.get_attribute_value<std::int64_t>("center_point_box", 0);
|
||||
|
||||
ASSERT_IS_SUPPORTED(node, center_point_box == 0 || center_point_box == 1)
|
||||
<< "Allowed values of the 'center_point_box' attribute are 0 and 1.";
|
||||
CHECK_VALID_NODE(
|
||||
node,
|
||||
center_point_box == 0 || center_point_box == 1,
|
||||
"Allowed values of the 'center_point_box' attribute are 0 and 1.");
|
||||
|
||||
const auto box_encoding =
|
||||
center_point_box == 0
|
||||
|
@ -47,8 +47,7 @@ namespace
|
||||
}
|
||||
else
|
||||
{
|
||||
throw ngraph::onnx_import::error::InvalidArgument("Unsupported padding mode: [" + mode +
|
||||
"]");
|
||||
throw ngraph::ngraph_error("Unsupported padding mode: [" + mode + "]");
|
||||
}
|
||||
|
||||
return pad_mode;
|
||||
|
@ -109,7 +109,7 @@ namespace ngraph
|
||||
|
||||
if (bias)
|
||||
{
|
||||
throw error::NotSupported(
|
||||
throw ngraph_error(
|
||||
"Groups != 1 not supported for Quantized Convolution with "
|
||||
"bias.");
|
||||
}
|
||||
@ -198,22 +198,26 @@ namespace ngraph
|
||||
auto output_scale = inputs.at(6);
|
||||
auto output_zero_point = inputs.at(7);
|
||||
|
||||
ASSERT_VALID_ARGUMENT(
|
||||
node,
|
||||
((groups >= 0) &&
|
||||
(groups <= static_cast<int64_t>(data->get_shape().at(1))) &&
|
||||
(groups <= static_cast<int64_t>(filters->get_shape().at(0)))))
|
||||
<< "incorrect value of 'group' attribute: " << groups;
|
||||
CHECK_VALID_NODE(node,
|
||||
((groups >= 0) &&
|
||||
(groups <= static_cast<int64_t>(data->get_shape().at(1))) &&
|
||||
(groups <= static_cast<int64_t>(filters->get_shape().at(0)))),
|
||||
"incorrect value of 'group' attribute: ",
|
||||
groups);
|
||||
|
||||
std::size_t n_data_channels{data->get_shape().at(1)};
|
||||
std::size_t n_filters_channels{filters->get_shape().at(0)};
|
||||
|
||||
ASSERT_VALID_ARGUMENT(node, n_data_channels % groups == 0)
|
||||
<< "provided group attribute value must be a multiple of data channels "
|
||||
"count.";
|
||||
ASSERT_VALID_ARGUMENT(node, n_filters_channels % groups == 0)
|
||||
<< "provided group attribute value must be a multiple of filter channels "
|
||||
"count.";
|
||||
CHECK_VALID_NODE(
|
||||
node,
|
||||
n_data_channels % groups == 0,
|
||||
"provided group attribute value must be a multiple of data channels "
|
||||
"count.");
|
||||
CHECK_VALID_NODE(
|
||||
node,
|
||||
n_filters_channels % groups == 0,
|
||||
"provided group attribute value must be a multiple of filter channels "
|
||||
"count.");
|
||||
|
||||
Strides strides = convpool::get_strides(node);
|
||||
Strides filter_dilations = convpool::get_dilations(node);
|
||||
|
@ -34,8 +34,11 @@ namespace ngraph
|
||||
const float bias = node.get_attribute_value<float>("bias", 0.0f);
|
||||
const float lambd = node.get_attribute_value<float>("lambd", 0.5f);
|
||||
|
||||
ASSERT_VALID_ARGUMENT(node, !(lambd < 0.0f))
|
||||
<< " The provided 'lambd' value:" << lambd << " must not be negative.";
|
||||
CHECK_VALID_NODE(node,
|
||||
!(lambd < 0.0f),
|
||||
" The provided 'lambd' value: ",
|
||||
lambd,
|
||||
" must not be negative.");
|
||||
|
||||
std::shared_ptr<default_opset::Constant> negative_lambd;
|
||||
const auto input_element_type = input->get_element_type();
|
||||
|
@ -64,9 +64,13 @@ namespace ngraph
|
||||
|
||||
auto reduction_axes = detail::get_reduction_axes(node);
|
||||
|
||||
ASSERT_VALID_ARGUMENT(node, reduction_axes.size() <= data_shape.size())
|
||||
<< "provided reduction axes count (" << reduction_axes.size()
|
||||
<< ") is larger than input tensor rank (" << data_shape.size() << ")";
|
||||
CHECK_VALID_NODE(node,
|
||||
reduction_axes.size() <= data_shape.size(),
|
||||
"provided reduction axes count (",
|
||||
reduction_axes.size(),
|
||||
") is larger than input tensor rank (",
|
||||
data_shape.size(),
|
||||
")");
|
||||
|
||||
std::shared_ptr<ngraph::Node> op_node =
|
||||
reduction_function(ng_input, reduction_axes);
|
||||
@ -99,9 +103,13 @@ namespace ngraph
|
||||
|
||||
const auto reduction_axes = detail::get_reduction_axes(node);
|
||||
|
||||
ASSERT_VALID_ARGUMENT(node, reduction_axes.size() <= data_rank)
|
||||
<< "provided reduction axes count (" << reduction_axes.size()
|
||||
<< ") is larger than input tensor rank (" << data_rank << ")";
|
||||
CHECK_VALID_NODE(node,
|
||||
reduction_axes.size() <= data_rank,
|
||||
"provided reduction axes count (",
|
||||
reduction_axes.size(),
|
||||
") is larger than input tensor rank (",
|
||||
data_rank,
|
||||
")");
|
||||
|
||||
std::int64_t keepdims = node.get_attribute_value<std::int64_t>("keepdims", 1);
|
||||
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
#include "ngraph/op/pad.hpp"
|
||||
#include "ngraph/attribute_visitor.hpp"
|
||||
#include "ngraph/except.hpp"
|
||||
#include "ngraph/op/broadcast.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
|
||||
|
@ -381,6 +381,8 @@ if (NGRAPH_ONNX_IMPORT_ENABLE AND NOT NGRAPH_USE_PROTOBUF_LITE)
|
||||
onnx/onnx_import_reshape.in.cpp
|
||||
onnx/onnx_import_rnn.in.cpp
|
||||
onnx/onnx_import_quant.in.cpp)
|
||||
list(APPEND SRC
|
||||
onnx/onnx_import_exceptions.cpp)
|
||||
endif()
|
||||
|
||||
foreach(BACKEND_NAME ${ACTIVE_BACKEND_LIST})
|
||||
|
@ -0,0 +1,38 @@
|
||||
ir_version: 3
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "x"
|
||||
input: "y"
|
||||
output: "sum"
|
||||
op_type: "Add"
|
||||
}
|
||||
name: "test_add_dyn_shapes"
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "sum"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 1
|
||||
}
|
@ -0,0 +1,92 @@
|
||||
ir_version: 3
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "x"
|
||||
input: "scale"
|
||||
input: "bias"
|
||||
output: "y"
|
||||
op_type: "InstanceNormalization"
|
||||
attribute {
|
||||
name: "epsilon"
|
||||
f: 0.01
|
||||
type: FLOAT
|
||||
}
|
||||
}
|
||||
name: "instance_norm_graph"
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "scale"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 4
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "bias"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 1
|
||||
}
|
102
ngraph/test/onnx/onnx_import_exceptions.cpp
Normal file
102
ngraph/test/onnx/onnx_import_exceptions.cpp
Normal file
@ -0,0 +1,102 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <exception>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/file_util.hpp"
|
||||
#include "ngraph/frontend/onnx_import/exceptions.hpp"
|
||||
#include "ngraph/frontend/onnx_import/onnx.hpp"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
TEST(onnx_importer, exception_throws_ngraph_error)
|
||||
{
|
||||
EXPECT_THROW(onnx_import::import_onnx_model(file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/depth_to_space_bad_blocksize.prototxt")),
|
||||
ngraph_error);
|
||||
}
|
||||
|
||||
TEST(onnx_importer, exception_msg_ngraph_error)
|
||||
{
|
||||
try
|
||||
{
|
||||
onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/depth_to_space_bad_blocksize.prototxt"));
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "ONNX Importer did not detected incorrect model!";
|
||||
}
|
||||
catch (const ngraph_error& e)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(e.what(),
|
||||
std::string("While validating ONNX node '<Node(DepthToSpace)"));
|
||||
EXPECT_HAS_SUBSTRING(e.what(), std::string("While validating node 'v0::DepthToSpace"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "The ONNX model importer failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(onnx_importer, exception_msg_onnx_node_validation_failure)
|
||||
{
|
||||
try
|
||||
{
|
||||
onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/instance_norm_bad_scale_type.prototxt"));
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "ONNX Importer did not detected incorrect model!";
|
||||
}
|
||||
catch (const ::ngraph::onnx_import::error::OnnxNodeValidationFailure& e)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
e.what(), std::string("While validating ONNX node '<Node(InstanceNormalization)"));
|
||||
}
|
||||
// On MacOS after we re-throw OnnxNodeValidationFailure exception, we couldn't catch it as is,
|
||||
// thus below workaround.
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
e.what(), std::string("While validating ONNX node '<Node(InstanceNormalization)"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "The ONNX model importer failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
// This test aims to check for wrapping all std::exception not deriving from ngraph_error.
|
||||
// This test should throw a std error because of attempt to access shape from dynamic tensor.
|
||||
TEST(onnx_importer, exception_msg_std_err_wrapped)
|
||||
{
|
||||
try
|
||||
{
|
||||
onnx_import::import_onnx_model(file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/dynamic_shapes/add_opset6_dyn_shape.prototxt"));
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "ONNX Importer did not detected incorrect model!";
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(e.what(), std::string("While validating ONNX node '<Node(Add)"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "The ONNX model importer failed for unexpected reason";
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user