From 173ce2c907e3dc69663e79a91d1cc19abd20ca87 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 15 Jul 2020 14:02:18 +0200 Subject: [PATCH] [ONNX] Exception handling refinements. (#1266) --- .../frontend/onnx_import/core/graph.cpp | 28 ++++- .../frontend/onnx_import/exceptions.hpp | 24 ----- .../frontend/onnx_import/op/batch_norm.cpp | 2 +- .../frontend/onnx_import/op/conv_integer.cpp | 8 +- .../frontend/onnx_import/op/eye_like.cpp | 8 +- .../frontend/onnx_import/op/leaky_relu.cpp | 4 +- .../frontend/onnx_import/op/lp_norm.cpp | 8 +- .../frontend/onnx_import/op/lp_pool.cpp | 6 +- .../ngraph/frontend/onnx_import/op/mod.cpp | 4 +- .../onnx_import/op/non_max_suppression.cpp | 6 +- .../ngraph/frontend/onnx_import/op/pad.cpp | 3 +- .../frontend/onnx_import/op/quant_conv.cpp | 30 +++--- .../ngraph/frontend/onnx_import/op/shrink.cpp | 7 +- .../frontend/onnx_import/utils/reduction.cpp | 20 ++-- ngraph/src/ngraph/op/pad.cpp | 1 + ngraph/test/CMakeLists.txt | 2 + .../add_opset6_dyn_shape.prototxt | 38 +++++++ .../instance_norm_bad_scale_type.prototxt | 92 ++++++++++++++++ ngraph/test/onnx/onnx_import_exceptions.cpp | 102 ++++++++++++++++++ 19 files changed, 326 insertions(+), 67 deletions(-) create mode 100644 ngraph/test/models/onnx/dynamic_shapes/add_opset6_dyn_shape.prototxt create mode 100644 ngraph/test/models/onnx/instance_norm_bad_scale_type.prototxt create mode 100644 ngraph/test/onnx/onnx_import_exceptions.cpp diff --git a/ngraph/src/ngraph/frontend/onnx_import/core/graph.cpp b/ngraph/src/ngraph/frontend/onnx_import/core/graph.cpp index 9ad9e8ed685..780fb60d5ad 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/core/graph.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/core/graph.cpp @@ -14,11 +14,14 @@ // limitations under the License. //***************************************************************************** +#include #include #include #include +#include "exceptions.hpp" #include "graph.hpp" +#include "ngraph/log.hpp" #include "node.hpp" #include "provenance.hpp" #include "utils/common.hpp" @@ -190,8 +193,29 @@ namespace ngraph { const auto ng_node_factory = m_model->get_operator(onnx_node.op_type(), onnx_node.domain()); - - const auto ng_node_vector = ng_node_factory(onnx_node); + NodeVector ng_node_vector; + try + { + ng_node_vector = ng_node_factory(onnx_node); + } + catch (const ::ngraph::onnx_import::error::OnnxNodeValidationFailure& exc) + { + // Do nothing OnnxNodeValidationFailure exception already has ONNX node information. + throw; + } + catch (const std::exception& exc) + { + std::string msg_prefix = error::detail::get_error_msg_prefix(onnx_node); + throw ngraph_error(msg_prefix + ":\n" + std::string(exc.what())); + } + catch (...) + { + std::string msg_prefix = error::detail::get_error_msg_prefix(onnx_node); + // Since we do not know anything about current exception data type we can only + // notify user in this way. + NGRAPH_ERR << msg_prefix + "Unhandled exception type. \n"; + std::rethrow_exception(std::current_exception()); + } set_friendly_names(onnx_node, ng_node_vector); add_provenance_tags(onnx_node, ng_node_vector); diff --git a/ngraph/src/ngraph/frontend/onnx_import/exceptions.hpp b/ngraph/src/ngraph/frontend/onnx_import/exceptions.hpp index bb346561157..714da81dbb5 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/exceptions.hpp +++ b/ngraph/src/ngraph/frontend/onnx_import/exceptions.hpp @@ -34,22 +34,6 @@ namespace ngraph std::string get_error_msg_prefix(const Node& node); } - struct NotSupported : AssertionFailure - { - explicit NotSupported(const std::string& what_arg) - : AssertionFailure(what_arg) - { - } - }; - - struct InvalidArgument : AssertionFailure - { - explicit InvalidArgument(const std::string& what_arg) - : AssertionFailure(what_arg) - { - } - }; - class OnnxNodeValidationFailure : public CheckFailure { public: @@ -67,14 +51,6 @@ namespace ngraph } // namespace ngraph -#define ASSERT_IS_SUPPORTED(node_, cond_) \ - NGRAPH_ASSERT_STREAM_DO_NOT_USE_IN_NEW_CODE(ngraph::onnx_import::error::NotSupported, cond_) \ - << (node_) << " " -#define ASSERT_VALID_ARGUMENT(node_, cond_) \ - NGRAPH_ASSERT_STREAM_DO_NOT_USE_IN_NEW_CODE(ngraph::onnx_import::error::InvalidArgument, \ - cond_) \ - << (node_) << " " - #define CHECK_VALID_NODE(node_, cond_, ...) \ NGRAPH_CHECK_HELPER( \ ::ngraph::onnx_import::error::OnnxNodeValidationFailure, (node_), (cond_), ##__VA_ARGS__) diff --git a/ngraph/src/ngraph/frontend/onnx_import/op/batch_norm.cpp b/ngraph/src/ngraph/frontend/onnx_import/op/batch_norm.cpp index 898cedd2b73..7e5a77f28bd 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/op/batch_norm.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/op/batch_norm.cpp @@ -44,7 +44,7 @@ namespace ngraph // TODO: Implement learning mode support // float momentum{node.get_attribute_value("momentum", 0.9f)}; - ASSERT_IS_SUPPORTED(node, is_test) << "only 'is_test' mode is supported."; + CHECK_VALID_NODE(node, is_test, "only 'is_test' mode is supported."); // optional outputs auto after_bn_mean = std::make_shared(); diff --git a/ngraph/src/ngraph/frontend/onnx_import/op/conv_integer.cpp b/ngraph/src/ngraph/frontend/onnx_import/op/conv_integer.cpp index 7b15f449ef2..5d076043282 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/op/conv_integer.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/op/conv_integer.cpp @@ -39,9 +39,11 @@ namespace ngraph auto filters = inputs.at(1); int64_t groups{node.get_attribute_value("group", 1)}; - ASSERT_VALID_ARGUMENT(node, (groups == 1)) - << "Only value of 1 for 'group' supported for ConvInteger. Given: " - << groups; + CHECK_VALID_NODE( + node, + groups == 1, + "Only value of 1 for 'group' supported for ConvInteger. Given: ", + groups); auto window_movement_strides = convpool::get_strides(node); auto window_dilation_strides = convpool::get_dilations(node); diff --git a/ngraph/src/ngraph/frontend/onnx_import/op/eye_like.cpp b/ngraph/src/ngraph/frontend/onnx_import/op/eye_like.cpp index 9da9a8fc104..3cbde4a5d21 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/op/eye_like.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/op/eye_like.cpp @@ -47,9 +47,11 @@ namespace ngraph target_type = input->get_element_type(); } - ASSERT_VALID_ARGUMENT(node, input_shape.size() == 2) - << "The provided shape rank: " << input_shape.size() - << " is unsupported, only 2D shapes are supported"; + CHECK_VALID_NODE(node, + input_shape.size() == 2, + "The provided shape rank: ", + input_shape.size(), + " is unsupported, only 2D shapes are supported"); std::shared_ptr eye_like_matrix = common::shifted_square_identity(input_shape, target_type, shift); diff --git a/ngraph/src/ngraph/frontend/onnx_import/op/leaky_relu.cpp b/ngraph/src/ngraph/frontend/onnx_import/op/leaky_relu.cpp index 28242f0b580..30a08faecdd 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/op/leaky_relu.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/op/leaky_relu.cpp @@ -33,8 +33,8 @@ namespace ngraph auto data = node.get_ng_inputs().at(0); double alpha = node.get_attribute_value("alpha", 0.01); - ASSERT_VALID_ARGUMENT(node, ((alpha >= 0) && (alpha <= 1))) - << " alpha value should be in range (0,1)"; + CHECK_VALID_NODE( + node, alpha >= 0 && alpha <= 1, " alpha value should be in range (0,1)"); std::shared_ptr alpha_node = default_opset::Constant::create(data->get_element_type(), Shape{}, {alpha}); diff --git a/ngraph/src/ngraph/frontend/onnx_import/op/lp_norm.cpp b/ngraph/src/ngraph/frontend/onnx_import/op/lp_norm.cpp index 9b20a1f5b0a..3d8eaa7dabc 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/op/lp_norm.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/op/lp_norm.cpp @@ -51,9 +51,11 @@ namespace ngraph const size_t normalize_axis = ngraph::normalize_axis(node.get_description(), axis, data_rank); - ASSERT_VALID_ARGUMENT(node, p_norm == 1 || p_norm == 2) - << "Invalid `p` attribute value: " << p_norm - << "Only normalization of 1st or 2nd order is supported."; + CHECK_VALID_NODE(node, + p_norm == 1 || p_norm == 2, + "Invalid `p` attribute value: ", + p_norm, + "Only normalization of 1st or 2nd order is supported."); const auto normalize_axis_const = default_opset::Constant::create(element::i64, {}, {normalize_axis}); diff --git a/ngraph/src/ngraph/frontend/onnx_import/op/lp_pool.cpp b/ngraph/src/ngraph/frontend/onnx_import/op/lp_pool.cpp index e8d7d534a77..319491d9412 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/op/lp_pool.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/op/lp_pool.cpp @@ -53,8 +53,10 @@ namespace ngraph const std::size_t channels_count = data_shape[channel_axis].get_length(); const std::int64_t p_norm{node.get_attribute_value("p", 2)}; - ASSERT_VALID_ARGUMENT(node, p_norm >= 0) - << "Only positive (including zero) values are supported for 'p' attribute."; + CHECK_VALID_NODE( + node, + p_norm >= 0, + "Only positive (including zero) values are supported for 'p' attribute."); NodeVector slices = ngraph::builder::opset1::split(data, channels_count, channel_axis); diff --git a/ngraph/src/ngraph/frontend/onnx_import/op/mod.cpp b/ngraph/src/ngraph/frontend/onnx_import/op/mod.cpp index 551f12b28ef..492af78d2c5 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/op/mod.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/op/mod.cpp @@ -37,8 +37,8 @@ namespace ngraph std::shared_ptr divisor{node.get_ng_inputs().at(1)}; std::int64_t fmod = node.get_attribute_value("fmod", 0); - ASSERT_IS_SUPPORTED(node, fmod == 1) - << "Only 'fmod=1' mode is supported for mod operator."; + CHECK_VALID_NODE( + node, fmod == 1, "Only 'fmod=1' mode is supported for mod operator."); return {std::make_shared(dividend, divisor)}; } diff --git a/ngraph/src/ngraph/frontend/onnx_import/op/non_max_suppression.cpp b/ngraph/src/ngraph/frontend/onnx_import/op/non_max_suppression.cpp index 4c88dcc9eac..1563807206e 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/op/non_max_suppression.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/op/non_max_suppression.cpp @@ -79,8 +79,10 @@ namespace ngraph const auto center_point_box = node.get_attribute_value("center_point_box", 0); - ASSERT_IS_SUPPORTED(node, center_point_box == 0 || center_point_box == 1) - << "Allowed values of the 'center_point_box' attribute are 0 and 1."; + CHECK_VALID_NODE( + node, + center_point_box == 0 || center_point_box == 1, + "Allowed values of the 'center_point_box' attribute are 0 and 1."); const auto box_encoding = center_point_box == 0 diff --git a/ngraph/src/ngraph/frontend/onnx_import/op/pad.cpp b/ngraph/src/ngraph/frontend/onnx_import/op/pad.cpp index aef415439a4..7c70f55b0ca 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/op/pad.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/op/pad.cpp @@ -47,8 +47,7 @@ namespace } else { - throw ngraph::onnx_import::error::InvalidArgument("Unsupported padding mode: [" + mode + - "]"); + throw ngraph::ngraph_error("Unsupported padding mode: [" + mode + "]"); } return pad_mode; diff --git a/ngraph/src/ngraph/frontend/onnx_import/op/quant_conv.cpp b/ngraph/src/ngraph/frontend/onnx_import/op/quant_conv.cpp index 8e960f7531f..fe16635d3ac 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/op/quant_conv.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/op/quant_conv.cpp @@ -109,7 +109,7 @@ namespace ngraph if (bias) { - throw error::NotSupported( + throw ngraph_error( "Groups != 1 not supported for Quantized Convolution with " "bias."); } @@ -198,22 +198,26 @@ namespace ngraph auto output_scale = inputs.at(6); auto output_zero_point = inputs.at(7); - ASSERT_VALID_ARGUMENT( - node, - ((groups >= 0) && - (groups <= static_cast(data->get_shape().at(1))) && - (groups <= static_cast(filters->get_shape().at(0))))) - << "incorrect value of 'group' attribute: " << groups; + CHECK_VALID_NODE(node, + ((groups >= 0) && + (groups <= static_cast(data->get_shape().at(1))) && + (groups <= static_cast(filters->get_shape().at(0)))), + "incorrect value of 'group' attribute: ", + groups); std::size_t n_data_channels{data->get_shape().at(1)}; std::size_t n_filters_channels{filters->get_shape().at(0)}; - ASSERT_VALID_ARGUMENT(node, n_data_channels % groups == 0) - << "provided group attribute value must be a multiple of data channels " - "count."; - ASSERT_VALID_ARGUMENT(node, n_filters_channels % groups == 0) - << "provided group attribute value must be a multiple of filter channels " - "count."; + CHECK_VALID_NODE( + node, + n_data_channels % groups == 0, + "provided group attribute value must be a multiple of data channels " + "count."); + CHECK_VALID_NODE( + node, + n_filters_channels % groups == 0, + "provided group attribute value must be a multiple of filter channels " + "count."); Strides strides = convpool::get_strides(node); Strides filter_dilations = convpool::get_dilations(node); diff --git a/ngraph/src/ngraph/frontend/onnx_import/op/shrink.cpp b/ngraph/src/ngraph/frontend/onnx_import/op/shrink.cpp index 3daa2fbbe2c..ba8c29f026d 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/op/shrink.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/op/shrink.cpp @@ -34,8 +34,11 @@ namespace ngraph const float bias = node.get_attribute_value("bias", 0.0f); const float lambd = node.get_attribute_value("lambd", 0.5f); - ASSERT_VALID_ARGUMENT(node, !(lambd < 0.0f)) - << " The provided 'lambd' value:" << lambd << " must not be negative."; + CHECK_VALID_NODE(node, + !(lambd < 0.0f), + " The provided 'lambd' value: ", + lambd, + " must not be negative."); std::shared_ptr negative_lambd; const auto input_element_type = input->get_element_type(); diff --git a/ngraph/src/ngraph/frontend/onnx_import/utils/reduction.cpp b/ngraph/src/ngraph/frontend/onnx_import/utils/reduction.cpp index 97c2a1e0c98..9d01398df5a 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/utils/reduction.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/utils/reduction.cpp @@ -64,9 +64,13 @@ namespace ngraph auto reduction_axes = detail::get_reduction_axes(node); - ASSERT_VALID_ARGUMENT(node, reduction_axes.size() <= data_shape.size()) - << "provided reduction axes count (" << reduction_axes.size() - << ") is larger than input tensor rank (" << data_shape.size() << ")"; + CHECK_VALID_NODE(node, + reduction_axes.size() <= data_shape.size(), + "provided reduction axes count (", + reduction_axes.size(), + ") is larger than input tensor rank (", + data_shape.size(), + ")"); std::shared_ptr op_node = reduction_function(ng_input, reduction_axes); @@ -99,9 +103,13 @@ namespace ngraph const auto reduction_axes = detail::get_reduction_axes(node); - ASSERT_VALID_ARGUMENT(node, reduction_axes.size() <= data_rank) - << "provided reduction axes count (" << reduction_axes.size() - << ") is larger than input tensor rank (" << data_rank << ")"; + CHECK_VALID_NODE(node, + reduction_axes.size() <= data_rank, + "provided reduction axes count (", + reduction_axes.size(), + ") is larger than input tensor rank (", + data_rank, + ")"); std::int64_t keepdims = node.get_attribute_value("keepdims", 1); diff --git a/ngraph/src/ngraph/op/pad.cpp b/ngraph/src/ngraph/op/pad.cpp index e80e30d18a8..a1afc54a9e0 100644 --- a/ngraph/src/ngraph/op/pad.cpp +++ b/ngraph/src/ngraph/op/pad.cpp @@ -16,6 +16,7 @@ #include "ngraph/op/pad.hpp" #include "ngraph/attribute_visitor.hpp" +#include "ngraph/except.hpp" #include "ngraph/op/broadcast.hpp" #include "ngraph/op/constant.hpp" diff --git a/ngraph/test/CMakeLists.txt b/ngraph/test/CMakeLists.txt index 81dc83652f0..1e708ed180d 100644 --- a/ngraph/test/CMakeLists.txt +++ b/ngraph/test/CMakeLists.txt @@ -381,6 +381,8 @@ if (NGRAPH_ONNX_IMPORT_ENABLE AND NOT NGRAPH_USE_PROTOBUF_LITE) onnx/onnx_import_reshape.in.cpp onnx/onnx_import_rnn.in.cpp onnx/onnx_import_quant.in.cpp) + list(APPEND SRC + onnx/onnx_import_exceptions.cpp) endif() foreach(BACKEND_NAME ${ACTIVE_BACKEND_LIST}) diff --git a/ngraph/test/models/onnx/dynamic_shapes/add_opset6_dyn_shape.prototxt b/ngraph/test/models/onnx/dynamic_shapes/add_opset6_dyn_shape.prototxt new file mode 100644 index 00000000000..a8ddd863e01 --- /dev/null +++ b/ngraph/test/models/onnx/dynamic_shapes/add_opset6_dyn_shape.prototxt @@ -0,0 +1,38 @@ +ir_version: 3 +producer_name: "nGraph ONNX Importer" +graph { + node { + input: "x" + input: "y" + output: "sum" + op_type: "Add" + } + name: "test_add_dyn_shapes" + input { + name: "x" + type { + tensor_type { + elem_type: 1 + } + } + } + input { + name: "y" + type { + tensor_type { + elem_type: 1 + } + } + } + output { + name: "sum" + type { + tensor_type { + elem_type: 1 + } + } + } +} +opset_import { + version: 1 +} diff --git a/ngraph/test/models/onnx/instance_norm_bad_scale_type.prototxt b/ngraph/test/models/onnx/instance_norm_bad_scale_type.prototxt new file mode 100644 index 00000000000..8a13b675ffb --- /dev/null +++ b/ngraph/test/models/onnx/instance_norm_bad_scale_type.prototxt @@ -0,0 +1,92 @@ +ir_version: 3 +producer_name: "nGraph ONNX Importer" +graph { + node { + input: "x" + input: "scale" + input: "bias" + output: "y" + op_type: "InstanceNormalization" + attribute { + name: "epsilon" + f: 0.01 + type: FLOAT + } + } + name: "instance_norm_graph" + input { + name: "x" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "scale" + type { + tensor_type { + elem_type: 4 + shape { + dim { + dim_value: 2 + } + + } + } + } + } + input { + name: "bias" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + + } + } + } + } + output { + name: "y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } +} +opset_import { + version: 1 +} diff --git a/ngraph/test/onnx/onnx_import_exceptions.cpp b/ngraph/test/onnx/onnx_import_exceptions.cpp new file mode 100644 index 00000000000..e1aa60ac3ef --- /dev/null +++ b/ngraph/test/onnx/onnx_import_exceptions.cpp @@ -0,0 +1,102 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include + +#include "gtest/gtest.h" +#include "ngraph/file_util.hpp" +#include "ngraph/frontend/onnx_import/exceptions.hpp" +#include "ngraph/frontend/onnx_import/onnx.hpp" +#include "ngraph/ngraph.hpp" +#include "util/type_prop.hpp" + +using namespace ngraph; + +TEST(onnx_importer, exception_throws_ngraph_error) +{ + EXPECT_THROW(onnx_import::import_onnx_model(file_util::path_join( + SERIALIZED_ZOO, "onnx/depth_to_space_bad_blocksize.prototxt")), + ngraph_error); +} + +TEST(onnx_importer, exception_msg_ngraph_error) +{ + try + { + onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/depth_to_space_bad_blocksize.prototxt")); + // Should have thrown, so fail if it didn't + FAIL() << "ONNX Importer did not detected incorrect model!"; + } + catch (const ngraph_error& e) + { + EXPECT_HAS_SUBSTRING(e.what(), + std::string("While validating ONNX node '