From 17914d516a2b58b946df75ff4c34a5eed0676b1c Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Wed, 20 Oct 2021 10:02:14 +0200 Subject: [PATCH] [ONNX] QLinearMatMul (#7418) * Add qlinear_matmul files * Extend matmul helpers * register op in bridge * Add test for 2d * newlines at end of files * Codestyle changes * Re-use of nodes inputs * Remove xfails * Register as set1 * reshape inputs as scalars * remove fixed xfail issue residues * update changes to fit the newest master Co-authored-by: dkozykowski --- .../frontend/onnx/frontend/src/op/matmul.hpp | 10 ++-- .../onnx/frontend/src/op/qlinear_matmul.cpp | 55 +++++++++++++++++++ .../onnx/frontend/src/op/qlinear_matmul.hpp | 24 ++++++++ .../frontend/onnx/frontend/src/ops_bridge.cpp | 2 + ngraph/test/onnx/onnx_import_quant.in.cpp | 18 ++++++ ngraph/test/runtime/ie/unit_test.manifest | 4 -- .../runtime/interpreter/unit_test.manifest | 4 -- runtime/bindings/python/tests/__init__.py | 3 - .../python/tests/test_onnx/test_backend.py | 6 -- .../test_onnx/test_backend.py | 6 -- 10 files changed, 105 insertions(+), 27 deletions(-) create mode 100644 ngraph/frontend/onnx/frontend/src/op/qlinear_matmul.cpp create mode 100644 ngraph/frontend/onnx/frontend/src/op/qlinear_matmul.hpp diff --git a/ngraph/frontend/onnx/frontend/src/op/matmul.hpp b/ngraph/frontend/onnx/frontend/src/op/matmul.hpp index 0bcbdc94c59..eb0f0fcc7d0 100644 --- a/ngraph/frontend/onnx/frontend/src/op/matmul.hpp +++ b/ngraph/frontend/onnx/frontend/src/op/matmul.hpp @@ -13,14 +13,16 @@ namespace ngraph { namespace onnx_import { namespace op { +namespace detail { +inline OutputVector matmul(const Output& a, const Output& b) { + return {std::make_shared(a, b)}; +} +} // namespace detail namespace set_1 { -OutputVector matmul(const Node& node) { +inline OutputVector matmul(const Node& node) { return {std::make_shared(node.get_ng_inputs().at(0), node.get_ng_inputs().at(1))}; } } // namespace set_1 - } // namespace op - } // namespace onnx_import - } // namespace ngraph diff --git a/ngraph/frontend/onnx/frontend/src/op/qlinear_matmul.cpp b/ngraph/frontend/onnx/frontend/src/op/qlinear_matmul.cpp new file mode 100644 index 00000000000..c0a23e7c8c8 --- /dev/null +++ b/ngraph/frontend/onnx/frontend/src/op/qlinear_matmul.cpp @@ -0,0 +1,55 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "op/qlinear_matmul.hpp" + +#include +#include +#include + +#include "dequantize_linear.hpp" +#include "matmul.hpp" +#include "ngraph/opsets/opset6.hpp" +#include "quantize_linear.hpp" +#include "utils/reshape.hpp" + +namespace ngraph { +namespace onnx_import { +namespace op { +namespace set_1 { +OutputVector qlinear_matmul(const Node& node) { + const OutputVector& inputs = node.get_ng_inputs(); + + const auto& a = inputs.at(0); + const auto& a_scale = reshape::interpret_as_scalar(inputs.at(1)); + const auto& a_zero_point = reshape::interpret_as_scalar(inputs.at(2)); + const auto& b = inputs.at(3); + const auto& b_scale = reshape::interpret_as_scalar(inputs.at(4)); + const auto& b_zero_point = reshape::interpret_as_scalar(inputs.at(5)); + const auto& y_scale = inputs.at(6); + const auto& y_zero_point = inputs.at(7); + + const auto& dequnatize_a = + set_13::detail::dequantize_linear(a, + a_scale, + std::make_shared(a_zero_point, element::f32), + 1, + node); + const auto& dequnatize_b = + set_13::detail::dequantize_linear(b, + b_scale, + std::make_shared(b_zero_point, element::f32), + 1, + node); + + const auto& result = op::detail::matmul(dequnatize_a[0], dequnatize_b[0]); + + const auto& quantized_result = op::detail::make_fake_quantize(y_scale, y_zero_point, result[0]); + + return {quantized_result}; +} +} // namespace set_1 +} // namespace op +} // namespace onnx_import +} // namespace ngraph diff --git a/ngraph/frontend/onnx/frontend/src/op/qlinear_matmul.hpp b/ngraph/frontend/onnx/frontend/src/op/qlinear_matmul.hpp new file mode 100644 index 00000000000..9700e999746 --- /dev/null +++ b/ngraph/frontend/onnx/frontend/src/op/qlinear_matmul.hpp @@ -0,0 +1,24 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "ngraph/node.hpp" +#include "onnx_import/core/node.hpp" + +namespace ngraph { +namespace onnx_import { +namespace op { +namespace set_1 { +/// \brief Performs ONNX QLinearMatMul operation. +/// +/// \param node The ONNX node object representing this operation. +/// +/// \return The vector containing Ngraph nodes producing output of ONNX quantizied +/// matrix multiplication operation. +OutputVector qlinear_matmul(const Node& node); +} // namespace set_1 +} // namespace op +} // namespace onnx_import +} // namespace ngraph diff --git a/ngraph/frontend/onnx/frontend/src/ops_bridge.cpp b/ngraph/frontend/onnx/frontend/src/ops_bridge.cpp index 539cc7fc0d7..768db860144 100644 --- a/ngraph/frontend/onnx/frontend/src/ops_bridge.cpp +++ b/ngraph/frontend/onnx/frontend/src/ops_bridge.cpp @@ -112,6 +112,7 @@ #include "op/pow.hpp" #include "op/prelu.hpp" #include "op/qlinear_conv.hpp" +#include "op/qlinear_matmul.hpp" #include "op/quantize_linear.hpp" #include "op/random_uniform.hpp" #include "op/random_uniform_like.hpp" @@ -378,6 +379,7 @@ OperatorsBridge::OperatorsBridge() { REGISTER_OPERATOR("Pow", 1, pow); REGISTER_OPERATOR("PRelu", 1, prelu); REGISTER_OPERATOR("QLinearConv", 1, qlinear_conv); + REGISTER_OPERATOR("QLinearMatMul", 1, qlinear_matmul); REGISTER_OPERATOR("QuantizeLinear", 1, quantize_linear); REGISTER_OPERATOR("QuantizeLinear", 13, quantize_linear); REGISTER_OPERATOR("Range", 1, range); diff --git a/ngraph/test/onnx/onnx_import_quant.in.cpp b/ngraph/test/onnx/onnx_import_quant.in.cpp index 049a1712442..40b04a5f1cf 100644 --- a/ngraph/test/onnx/onnx_import_quant.in.cpp +++ b/ngraph/test/onnx/onnx_import_quant.in.cpp @@ -402,6 +402,24 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_quant_conv_linear_onnx_example) { test_case.run(); } +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_qlinear_matmul_2d) { + auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/qlinear_matmul.onnx")); + + auto test_case = test::TestCase(function); + + test_case.add_input(std::vector{208, 236, 0, 238, 3, 214, 255, 29}); // T1 + test_case.add_input(std::vector{0.0066f}); // a_scale + test_case.add_input(std::vector{113}); // a_zero_point + test_case.add_input(std::vector{152, 51, 244, 60, 26, 255, 0, 127, 246, 127, 254, 247}); // T2 + test_case.add_input(std::vector{0.00705f}); // b_scale + test_case.add_input(std::vector{114}); // b_zero_point + test_case.add_input(std::vector{0.0107f}); // y_scale + test_case.add_input(std::vector{118}); // y_zero_point + + test_case.add_expected_output({2, 3}, std::vector{168, 115, 255, 1, 66, 151}); // T3 + test_case.run(); +} + NGRAPH_TEST(${BACKEND_NAME}, onnx_model_matmul_integer_2d_simple_zero_point) { auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/matmul_integer.onnx")); diff --git a/ngraph/test/runtime/ie/unit_test.manifest b/ngraph/test/runtime/ie/unit_test.manifest index 6aa3b9fb30b..c89881120c0 100644 --- a/ngraph/test/runtime/ie/unit_test.manifest +++ b/ngraph/test/runtime/ie/unit_test.manifest @@ -30,10 +30,6 @@ onnx_model_conv_integer_zero_point_zero onnx_model_conv_integer_no_zero_point onnx_model_conv_integer_pads -# Unsupported operator detected in the graph: QuantizedDot -onnx_model_qlinear_matmul -onnx_model_qlinear_matmul_3d - # No support yet for RandomUniform onnx_model_random_uniform onnx_model_random_uniform_like diff --git a/ngraph/test/runtime/interpreter/unit_test.manifest b/ngraph/test/runtime/interpreter/unit_test.manifest index 4b04419aae4..f5303fe192a 100644 --- a/ngraph/test/runtime/interpreter/unit_test.manifest +++ b/ngraph/test/runtime/interpreter/unit_test.manifest @@ -11,10 +11,6 @@ INTERPRETER.onnx_resize10_down_scales_const_nearest # Failed in MacOS: INTERPRETER.onnx_resize11_sizes_nearest_asymmetric_floor -# nGraph does not support the following ONNX operations -INTERPRETER.onnx_model_qlinear_matmul -INTERPRETER.onnx_model_qlinear_matmul_3d - # Disabled tests for disabled reference implementations INTERPRETER.onnx_dyn_shapes_expand_uint16_dyn_shape INTERPRETER.sum_2d_to_scalar_int8 diff --git a/runtime/bindings/python/tests/__init__.py b/runtime/bindings/python/tests/__init__.py index 244ce7b4016..fad3d7367ff 100644 --- a/runtime/bindings/python/tests/__init__.py +++ b/runtime/bindings/python/tests/__init__.py @@ -67,9 +67,6 @@ xfail_issue_38713 = xfail_test(reason="RuntimeError: nGraph does not support the "ai.onnx.preview.training.Momentum") xfail_issue_45457 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v5::Loop " "Not constant termination condition body output is not supported") -xfail_issue_38722 = xfail_test(reason="RuntimeError: While validating ONNX nodes MatMulInteger " - "and QLinearMatMul " - "Input0 scale and input0 zero point shape must be same and 1") xfail_issue_38724 = xfail_test(reason="RuntimeError: While validating ONNX node '': " "tf_crop_and_resize - this type of coordinate transformation mode " "is not supported. Choose one of the following modes: " diff --git a/runtime/bindings/python/tests/test_onnx/test_backend.py b/runtime/bindings/python/tests/test_onnx/test_backend.py index b92e64dd71e..49c6589c20c 100644 --- a/runtime/bindings/python/tests/test_onnx/test_backend.py +++ b/runtime/bindings/python/tests/test_onnx/test_backend.py @@ -25,7 +25,6 @@ from tests import ( xfail_issue_38708, xfail_issue_38710, xfail_issue_38713, - xfail_issue_38722, xfail_issue_38724, xfail_issue_38732, xfail_issue_38734, @@ -372,11 +371,6 @@ tests_expected_to_fail = [ "OnnxBackendNodeModelTest.test_isinf_negative_cpu", "OnnxBackendNodeModelTest.test_isinf_cpu", ), - ( - xfail_issue_38722, - "OnnxBackendNodeModelTest.test_qlinearmatmul_2D_cpu", - "OnnxBackendNodeModelTest.test_qlinearmatmul_3D_cpu", - ), (xfail_issue_38724, "OnnxBackendNodeModelTest.test_resize_tf_crop_and_resize_cpu"), ( xfail_issue_33606, diff --git a/runtime/bindings/python/tests_compatibility/test_onnx/test_backend.py b/runtime/bindings/python/tests_compatibility/test_onnx/test_backend.py index c850f9d5045..675ca9be114 100644 --- a/runtime/bindings/python/tests_compatibility/test_onnx/test_backend.py +++ b/runtime/bindings/python/tests_compatibility/test_onnx/test_backend.py @@ -24,7 +24,6 @@ from tests_compatibility import ( xfail_issue_38708, xfail_issue_38710, xfail_issue_38713, - xfail_issue_38722, xfail_issue_38724, xfail_issue_38732, xfail_issue_38734, @@ -331,11 +330,6 @@ tests_expected_to_fail = [ "OnnxBackendNodeModelTest.test_isinf_negative_cpu", "OnnxBackendNodeModelTest.test_isinf_cpu", ), - ( - xfail_issue_38722, - "OnnxBackendNodeModelTest.test_qlinearmatmul_2D_cpu", - "OnnxBackendNodeModelTest.test_qlinearmatmul_3D_cpu", - ), (xfail_issue_38724, "OnnxBackendNodeModelTest.test_resize_tf_crop_and_resize_cpu"), ( xfail_issue_33606,