diff --git a/ngraph/src/ngraph/frontend/onnx_import/CMakeLists.txt b/ngraph/src/ngraph/frontend/onnx_import/CMakeLists.txt index a1e5d40338c..82b6a181029 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/CMakeLists.txt +++ b/ngraph/src/ngraph/frontend/onnx_import/CMakeLists.txt @@ -14,7 +14,7 @@ # limitations under the License. # ****************************************************************************** -set(ONNX_OPSET_VERSION 11 CACHE INTERNAL "Supported version of ONNX operator set") +set(ONNX_OPSET_VERSION 13 CACHE INTERNAL "Supported version of ONNX operator set") add_library(onnx_importer SHARED core/node.cpp @@ -86,8 +86,8 @@ add_library(onnx_importer SHARED op/cum_sum.hpp op/depth_to_space.cpp op/depth_to_space.hpp - # op/dequantize_linear.cpp - # op/dequantize_linear.hpp + op/dequantize_linear.cpp + op/dequantize_linear.hpp op/div.hpp op/dropout.hpp op/elu.cpp diff --git a/ngraph/src/ngraph/frontend/onnx_import/op/dequantize_linear.cpp b/ngraph/src/ngraph/frontend/onnx_import/op/dequantize_linear.cpp index 93d4851e0c8..aebea5a2957 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/op/dequantize_linear.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/op/dequantize_linear.cpp @@ -23,9 +23,9 @@ #include "ngraph/builder/make_constant.hpp" #include "ngraph/op/convert.hpp" #include "ngraph/op/dequantize.hpp" -#include "ngraph/opsets/opset0.hpp" #include "ngraph/shape.hpp" #include "ngraph/validation_util.hpp" +#include "utils/common.hpp" namespace ngraph { @@ -33,54 +33,177 @@ namespace ngraph { namespace op { + namespace + { + std::shared_ptr get_zero_point(const NodeVector& inputs) + { + if (inputs.size() == 3 && !inputs[2]->is_null()) + { + auto zero_point = inputs[2]; + + if (zero_point->get_element_type() != element::f32) + { + zero_point = + std::make_shared(zero_point, element::f32); + } + + return zero_point; + } + else + { + return default_opset::Constant::create(element::f32, Shape{}, {0}); + } + } + } namespace set_1 { NodeVector dequantize_linear(const Node& node) { - NodeVector inputs{node.get_ng_inputs()}; - std::shared_ptr x = inputs.at(0); - std::shared_ptr x_scale = inputs.at(1); - std::shared_ptr zero_point; - if (inputs.size() == 3 && !inputs.at(2)->is_null()) + const NodeVector inputs{node.get_ng_inputs()}; + + NGRAPH_CHECK( + 2 <= inputs.size() && inputs.size() <= 3, + "The DequantizeLinear op expects 2 required and one optional input. Got: ", + inputs.size()); + + const auto x = inputs[0]; + const auto scale = inputs[1]; + const auto zero_point = get_zero_point(inputs); + + common::validate_scalar_input("Dequantization scale", scale, {element::f32}); + common::validate_scalar_input("Zero point", zero_point); + + const auto converted_x = + std::make_shared(x, element::f32); + + return {std::make_shared( + std::make_shared(converted_x, zero_point), scale)}; + } + } + + namespace set_13 + { + namespace + { + void validate_scale(const std::shared_ptr scale, + const std::shared_ptr x, + const int64_t axis) { - zero_point = inputs.at(2); - } - else - { - zero_point = - ngraph::builder::make_constant(x->get_element_type(), Shape{}, 0); + const auto& scale_shape = scale->get_output_partial_shape(0); + NGRAPH_CHECK(scale_shape.rank().get_length() == 0 || + scale_shape.rank().get_length() == 1, + "Dequantization scale needs to be a scalar or a vector."); + + if (scale_shape.rank().get_length() == 1) + { + const auto& scale_dim = scale_shape[0]; + const auto& x_shape = x->get_output_partial_shape(0); + const auto& x_dim_at_axis = x_shape[axis]; + + NGRAPH_CHECK(scale_dim.same_scheme(x_dim_at_axis), + "The number of dequantization scale elements '", + scale_dim, + "' must match the input shape dimension '", + x_dim_at_axis, + " pointed to by the axis attribute: ", + axis); + } } - Shape y_scale_shape = x_scale->get_shape(); - Shape y_zero_point_shape = zero_point->get_shape(); - - // get axis twice with two default values to see if it is set - int64_t axis_0{node.get_attribute_value("axis", 0)}; - int64_t axis_1{node.get_attribute_value("axis", 1)}; - - const auto data_rank = x->get_output_partial_shape(0).rank(); - AxisSet axes; - // if axis attribute is set - if (axis_0 == axis_1) + void validate_zero_point(const std::shared_ptr zero_point, + const std::shared_ptr x, + const int64_t axis) { - axes.insert( - ngraph::normalize_axis(node.get_description(), axis_0, data_rank)); + const auto& zero_point_shape = zero_point->get_output_partial_shape(0); + NGRAPH_CHECK(zero_point_shape.rank().get_length() == 0 || + zero_point_shape.rank().get_length() == 1, + "Zero point needs to be a scalar or a vector."); + + if (zero_point_shape.rank().get_length() == 1) + { + const auto& zero_point_dim = zero_point_shape[0]; + const auto& x_shape = x->get_output_partial_shape(0); + const auto& x_dim_at_axis = x_shape[axis]; + + NGRAPH_CHECK(zero_point_dim.same_scheme(x_dim_at_axis), + "The number of zero point elements '", + zero_point_dim, + "' must match the input shape dimension '", + x_dim_at_axis, + " pointed to by the axis attribute: ", + axis); + } } - if (x->get_element_type() != zero_point->get_element_type()) + std::shared_ptr + reshape_input(const std::shared_ptr input, + const int64_t axis, + const PartialShape& x_shape) { - zero_point = std::make_shared( - zero_point, x->get_element_type()); - } + std::vector target_dims; - return {std::make_shared( - x, x_scale, zero_point, x_scale->get_element_type(), axes)}; + for (size_t i = 0; i < axis; ++i) + { + target_dims.push_back(1); + } + + // copy dimension at axis from input X + if (x_shape[axis].is_static()) + { + target_dims.push_back(x_shape[axis].get_length()); + } + else + { + target_dims.push_back(0); + } + + for (size_t i = axis + 1; i < x_shape.rank().get_length(); ++i) + { + target_dims.push_back(1); + } + + const auto target_shape = default_opset::Constant::create( + element::i64, Shape{target_dims.size()}, target_dims); + + return std::make_shared(input, target_shape, true); + } } - } // namespace set_1 + NodeVector dequantize_linear(const Node& node) + { + const NodeVector inputs{node.get_ng_inputs()}; - } // namespace op + NGRAPH_CHECK(2 <= inputs.size() && inputs.size() <= 3, + "The DequantizeLinear op expects 2 required and one optional " + "input. Got: ", + inputs.size()); - } // namespace onnx_import + const auto x = inputs[0]; + auto scale = inputs[1]; + auto zero_point = get_zero_point(inputs); -} // namespace ngraph + const auto x_shape = x->get_output_partial_shape(0); + + NGRAPH_CHECK(x_shape.rank().is_static(), + "Rank of the input data tensor has to be known (static)."); + + int64_t axis{node.get_attribute_value("axis", 1)}; + axis = ngraph::normalize_axis(node.get_description(), axis, x_shape.rank()); + + validate_scale(scale, x, axis); + validate_zero_point(zero_point, x, axis); + + // these reshapes make sure that dequantization happens over the specified axis + scale = reshape_input(scale, axis, x_shape); + zero_point = reshape_input(zero_point, axis, x_shape); + + const auto converted_x = + std::make_shared(x, element::f32); + + return {std::make_shared( + std::make_shared(converted_x, zero_point), scale)}; + } + } + } + } +} diff --git a/ngraph/src/ngraph/frontend/onnx_import/op/dequantize_linear.hpp b/ngraph/src/ngraph/frontend/onnx_import/op/dequantize_linear.hpp index 41dd150a4b6..dc6709358cd 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/op/dequantize_linear.hpp +++ b/ngraph/src/ngraph/frontend/onnx_import/op/dequantize_linear.hpp @@ -31,6 +31,11 @@ namespace ngraph } // namespace set_1 + namespace set_13 + { + NodeVector dequantize_linear(const Node& node); + } + } // namespace op } // namespace onnx_import diff --git a/ngraph/src/ngraph/frontend/onnx_import/ops_bridge.cpp b/ngraph/src/ngraph/frontend/onnx_import/ops_bridge.cpp index 6c7c4c004ad..1fee117f410 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/ops_bridge.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/ops_bridge.cpp @@ -48,7 +48,7 @@ #include "op/cosh.hpp" #include "op/cum_sum.hpp" #include "op/depth_to_space.hpp" -// #include "op/dequantize_linear.hpp" +#include "op/dequantize_linear.hpp" #include "op/div.hpp" #include "op/dropout.hpp" #include "op/elu.hpp" @@ -278,7 +278,8 @@ namespace ngraph REGISTER_OPERATOR("Cosh", 1, cosh); REGISTER_OPERATOR("CumSum", 1, cum_sum); REGISTER_OPERATOR("DepthToSpace", 1, depth_to_space); - // REGISTER_OPERATOR("DequantizeLinear", 1, dequantize_linear); + REGISTER_OPERATOR("DequantizeLinear", 1, dequantize_linear); + REGISTER_OPERATOR("DequantizeLinear", 13, dequantize_linear); REGISTER_OPERATOR("Div", 1, div); REGISTER_OPERATOR("Div", 7, div); REGISTER_OPERATOR("Dropout", 1, dropout); diff --git a/ngraph/src/ngraph/frontend/onnx_import/utils/common.cpp b/ngraph/src/ngraph/frontend/onnx_import/utils/common.cpp index a803f65d15b..ffe432d4f9c 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/utils/common.cpp +++ b/ngraph/src/ngraph/frontend/onnx_import/utils/common.cpp @@ -67,6 +67,26 @@ namespace ngraph default_opset::Constant::create(element::i64, {}, {step})); } + void validate_scalar_input(const char* input_name, + const std::shared_ptr input, + const std::set allowed_types) + { + const auto validated_input_rank = input->get_output_partial_shape(0).rank(); + + NGRAPH_CHECK( + validated_input_rank.same_scheme({0}), input_name, " needs to be a scalar."); + + if (!allowed_types.empty()) + { + const bool data_type_ok = allowed_types.count(input->get_element_type()); + NGRAPH_CHECK(data_type_ok, + "Incorrect data type of the ", + input_name, + " input: ", + input->get_element_type()); + } + } + } // namespace common } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/src/ngraph/frontend/onnx_import/utils/common.hpp b/ngraph/src/ngraph/frontend/onnx_import/utils/common.hpp index 2aaa14eee90..69af5b52a06 100644 --- a/ngraph/src/ngraph/frontend/onnx_import/utils/common.hpp +++ b/ngraph/src/ngraph/frontend/onnx_import/utils/common.hpp @@ -127,6 +127,16 @@ namespace ngraph return shifted_square_identity(Shape{n, n}, type, 0); } + /// \brief Performs validation of an input that is expected to be a scalar. + /// \note This function throws an exception if any of the validation steps fails. + /// + /// \param[in] input_name A human-readable name of an input (used for logging) + /// \param[in] input An input node to be validated + /// \param[in] allowed_types An optional set of allowed element types for this input + void validate_scalar_input(const char* input_name, + const std::shared_ptr input, + const std::set allowed_types = {}); + } // namespace common } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/test/models/onnx/dequantize_linear_2.prototxt b/ngraph/test/models/onnx/dequantize_linear_2.prototxt index d5ffdda754f..4629456efdd 100644 --- a/ngraph/test/models/onnx/dequantize_linear_2.prototxt +++ b/ngraph/test/models/onnx/dequantize_linear_2.prototxt @@ -75,5 +75,5 @@ graph { } } opset_import { - version: 10 + version: 13 } diff --git a/ngraph/test/models/onnx/dequantize_linear_3.prototxt b/ngraph/test/models/onnx/dequantize_linear_3.prototxt index 1c01ba3b1b6..112312fd08e 100644 --- a/ngraph/test/models/onnx/dequantize_linear_3.prototxt +++ b/ngraph/test/models/onnx/dequantize_linear_3.prototxt @@ -75,5 +75,5 @@ graph { } } opset_import { - version: 10 + version: 13 } diff --git a/ngraph/test/models/onnx/dequantize_linear_4.prototxt b/ngraph/test/models/onnx/dequantize_linear_4.prototxt index 2f081f59f5a..422046dea2c 100644 --- a/ngraph/test/models/onnx/dequantize_linear_4.prototxt +++ b/ngraph/test/models/onnx/dequantize_linear_4.prototxt @@ -87,5 +87,5 @@ graph { } } opset_import { - version: 10 + version: 13 } diff --git a/ngraph/test/models/onnx/dequantize_linear_5.prototxt b/ngraph/test/models/onnx/dequantize_linear_5.prototxt index 0b5ee82dd04..2d9466c5af0 100644 --- a/ngraph/test/models/onnx/dequantize_linear_5.prototxt +++ b/ngraph/test/models/onnx/dequantize_linear_5.prototxt @@ -75,5 +75,5 @@ graph { } } opset_import { - version: 10 + version: 13 } diff --git a/ngraph/test/onnx/onnx_import_quant.in.cpp b/ngraph/test/onnx/onnx_import_quant.in.cpp index 02f2b86a2b2..3bd50870fda 100644 --- a/ngraph/test/onnx/onnx_import_quant.in.cpp +++ b/ngraph/test/onnx/onnx_import_quant.in.cpp @@ -156,7 +156,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dequantize_linear_1d_zero_scale_uint8) auto function = onnx_import::import_onnx_model( file_util::path_join(SERIALIZED_ZOO, "onnx/dequantize_linear_2.prototxt")); - auto test_case = test::TestCase(function); + auto test_case = ngraph::test::TestCase(function); test_case.add_input(std::vector{0, 1, 2, 3, 0, 1, 2, 3, 0, 10, 20, 30}); // x test_case.add_input(std::vector{1.0f, 2.0f, 4.0f}); // scale @@ -174,7 +174,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dequantize_linear_1d_zero_scale_int8) auto function = onnx_import::import_onnx_model( file_util::path_join(SERIALIZED_ZOO, "onnx/dequantize_linear_3.prototxt")); - auto test_case = test::TestCase(function); + auto test_case = ngraph::test::TestCase(function); test_case.add_input(std::vector{0, 1, 2, 3, 0, 2, 4, 6, 0, 10, 20, 30}); // x test_case.add_input(std::vector{1.0f, 2.0f, 4.0f, 8.0f}); // scale @@ -192,7 +192,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dequantize_linear_1d_zero_scale_int8_4d) auto function = onnx_import::import_onnx_model( file_util::path_join(SERIALIZED_ZOO, "onnx/dequantize_linear_4.prototxt")); - auto test_case = test::TestCase(function); + auto test_case = ngraph::test::TestCase(function); test_case.add_input(std::vector{7, 9, 10, 10, 5, 8, 9, 1, 8, 6, 7, 9, 10, 0, 7, 10, 8, 2, 6, 0, 5, 9, 8, 1, 2, 7, 5, 3, 2, 4, 1, 3, 8, 7, @@ -216,7 +216,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dequantize_linear_1d_zero_scale_uint8_ne auto function = onnx_import::import_onnx_model( file_util::path_join(SERIALIZED_ZOO, "onnx/dequantize_linear_5.prototxt")); - auto test_case = test::TestCase(function); + auto test_case = ngraph::test::TestCase(function); test_case.add_input(std::vector{0, 1, 2, 3, 0, 1, 2, 3, 0, 10, 20, 30}); // x test_case.add_input(std::vector{1.0f, 2.0f, 4.0f}); // scale diff --git a/ngraph/test/runtime/ie/unit_test.manifest b/ngraph/test/runtime/ie/unit_test.manifest index eb905e31c72..e411c917670 100644 --- a/ngraph/test/runtime/ie/unit_test.manifest +++ b/ngraph/test/runtime/ie/unit_test.manifest @@ -17,14 +17,12 @@ onnx_model_quantize_linear_zero_point onnx_model_quantize_linear_axis_zero onnx_model_quantize_linear_axis_negative -# Not supported ONNX op: DequantizeLinear -onnx_model_dequantize_linear -onnx_model_dequantize_linear_scalar_zero_scale_uint8 -onnx_model_dequantize_linear_scalar_zero_scale_int8 -onnx_model_dequantize_linear_1d_zero_scale_uint8 -onnx_model_dequantize_linear_1d_zero_scale_int8 -onnx_model_dequantize_linear_1d_zero_scale_int8_4d -onnx_model_dequantize_linear_1d_zero_scale_uint8_negative_axis +# DequantizeLinear: +# C++ exception with description "Unsupported precisions! +IE_CPU.onnx_model_dequantize_linear_scalar_zero_scale_int8 +IE_CPU.onnx_model_dequantize_linear_1d_zero_scale_int8 +# C++ exception with description "Input data precision not supported. Expected float. +IE_CPU.onnx_model_dequantize_linear_1d_zero_scale_int8_4d # Not supported ONNX op: QLinearConv onnx_model_quant_conv_linear diff --git a/ngraph/test/runtime/interpreter/unit_test.manifest b/ngraph/test/runtime/interpreter/unit_test.manifest index 643b0aac1d5..b8365841e15 100644 --- a/ngraph/test/runtime/interpreter/unit_test.manifest +++ b/ngraph/test/runtime/interpreter/unit_test.manifest @@ -91,13 +91,6 @@ INTERPRETER.onnx_model_quantize_linear INTERPRETER.onnx_model_quantize_linear_zero_point INTERPRETER.onnx_model_quantize_linear_axis_zero INTERPRETER.onnx_model_quantize_linear_axis_negative -INTERPRETER.onnx_model_dequantize_linear -INTERPRETER.onnx_model_dequantize_linear_scalar_zero_scale_uint8 -INTERPRETER.onnx_model_dequantize_linear_scalar_zero_scale_int8 -INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_uint8 -INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_int8 -INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_int8_4d -INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_uint8_negative_axis INTERPRETER.onnx_model_quant_conv_linear_2d INTERPRETER.onnx_model_quant_conv_linear_3d INTERPRETER.onnx_model_conv_integer