diff --git a/src/frontends/pytorch/src/op/quantized_add.cpp b/src/frontends/pytorch/src/op/quantized_add.cpp
new file mode 100644
index 00000000000..2c6541e5962
--- /dev/null
+++ b/src/frontends/pytorch/src/op/quantized_add.cpp
@@ -0,0 +1,46 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/relu.hpp"
+#include "utils_quantize.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+using namespace ov::op;
+
+OutputVector translate_quantized_add(const NodeContext& context) {
+    num_inputs_check(context, 4, 4);
+    const auto x = context.get_input(0);
+    const auto y = context.get_input(1);
+    const auto scale = context.get_input(2);
+    const auto zero_point = context.get_input(3);
+
+    const auto quantized_add = context.mark_node(std::make_shared<v1::Add>(x, y));
+
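+    // Note: quantized tensors reach the frontend as dequantized f32 values, so the
+    // f32 Add result is re-quantized via the quantize() helper from utils_quantize.hpp,
+    // which applies the requested scale/zero_point and takes the quantization type from x.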
+    return {quantize(context, quantized_add, scale, zero_point, x)};
+}
+
+OutputVector translate_quantized_add_relu(const NodeContext& context) {
+    num_inputs_check(context, 4, 4);
+    const auto x = context.get_input(0);
+    const auto y = context.get_input(1);
+    const auto scale = context.get_input(2);
+    const auto zero_point = context.get_input(3);
+
+    const auto quantized_add = context.mark_node(std::make_shared<v1::Add>(x, y));
+    const auto quantized_add_relu = context.mark_node(std::make_shared<v0::Relu>(quantized_add));
+
+    return {quantize(context, quantized_add_relu, scale, zero_point, x)};
+}
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/pytorch/src/op/quantized_hardswish.cpp b/src/frontends/pytorch/src/op/quantized_hardswish.cpp
new file mode 100644
index 00000000000..3494ebe77a1
--- /dev/null
+++ b/src/frontends/pytorch/src/op/quantized_hardswish.cpp
@@ -0,0 +1,31 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/hswish.hpp"
+#include "utils_quantize.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+using namespace ov::op;
+
+OutputVector translate_quantized_hardswish(const NodeContext& context) {
+    num_inputs_check(context, 3, 3);
+    const auto x = context.get_input(0);
+    const auto scale = context.get_input(1);
+    const auto zero_point = context.get_input(2);
+
+    const auto quantized_hardswish = context.mark_node(std::make_shared<v4::HSwish>(x));
+
+    return {quantize(context, quantized_hardswish, scale, zero_point, x)};
+}
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/pytorch/src/op/quantized_mul.cpp b/src/frontends/pytorch/src/op/quantized_mul.cpp
new file mode 100644
index 00000000000..81575a67c11
--- /dev/null
+++ b/src/frontends/pytorch/src/op/quantized_mul.cpp
@@ -0,0 +1,32 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/multiply.hpp"
+#include "utils_quantize.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+using namespace ov::op;
+
+OutputVector translate_quantized_mul(const NodeContext& context) {
+    num_inputs_check(context, 4, 4);
+    const auto x = context.get_input(0);
+    const auto y = context.get_input(1);
+    const auto scale = context.get_input(2);
+    const auto zero_point = context.get_input(3);
+
+    const auto quantized_mul = context.mark_node(std::make_shared<v1::Multiply>(x, y));
+
+    return {quantize(context, quantized_mul, scale, zero_point, x)};
+}
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 0548a222548..ff914b33b5d 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -108,6 +108,10 @@ OP_CONVERTER(translate_pow);
 OP_CONVERTER(translate_pythonop);
 OP_CONVERTER(translate_quantize_per_channel);
 OP_CONVERTER(translate_quantize_per_tensor);
+OP_CONVERTER(translate_quantized_add);
+OP_CONVERTER(translate_quantized_add_relu);
+OP_CONVERTER(translate_quantized_hardswish);
+OP_CONVERTER(translate_quantized_mul);
 OP_CONVERTER(translate_range_length);
 OP_CONVERTER(translate_rand);
 OP_CONVERTER(translate_randn);
@@ -429,8 +433,12 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"prim::requires_grad", op::return_false_scalar},
         {"prim::PythonOp", op::translate_pythonop},
         {"prim::type", op::skip_node},  // Used with prim::device, pass PtFrameworkNode.
+        {"quantized::add", op::translate_quantized_add},
+        {"quantized::add_relu", op::translate_quantized_add_relu},
         {"quantized::conv2d", op::translate_quantized_convnd},
         {"quantized::conv2d_relu", op::translate_quantized_convnd_relu},
+        {"quantized::hardswish", op::translate_quantized_hardswish},
+        {"quantized::mul", op::translate_quantized_mul},
         {"quantized::linear", op::translate_quantized_linear},
         {"torchvision::deform_conv2d", op::translate_deform_conv},
         {"torchvision::nms", op::translate_nms},
diff --git a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
index 81707836482..838cce35460 100644
--- a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
+++ b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
@@ -125,21 +125,34 @@ class PytorchLayerTest:
         # Compare Ie results with Framework results
         fw_eps = custom_eps if precision == 'FP32' else 5e-2
         is_ok = True
+        quantized_ops = False
+        if 'quantized_ops' in kwargs and kwargs['quantized_ops'] is not None:
+            quantized_ops = kwargs['quantized_ops']
+        if quantized_ops:
+            assert 'quant_size' in kwargs, "quant size must be specified for quantized_ops flag"
+            quant_size = kwargs['quant_size']
         for i in range(len(infer_res)):
             cur_fw_res = flatten_fw_res[i].to(memory_format=torch.contiguous_format).numpy(
             ) if isinstance(flatten_fw_res[i], torch.Tensor) else flatten_fw_res[i]
+            if np.array(cur_fw_res).size == 0:
+                continue
             cur_ov_res = infer_res[compiled.output(i)]
-            print(f"fw_re: {cur_fw_res};\n ov_res: {cur_ov_res}")
-            if not np.allclose(cur_ov_res, cur_fw_res,
-                               atol=fw_eps,
-                               rtol=fw_eps, equal_nan=True):
+            print(f"fw_res: {cur_fw_res};\n ov_res: {cur_ov_res}")
+            n_is_not_close = np.array(cur_fw_res).size - np.isclose(cur_ov_res, cur_fw_res,
+                                                                    atol=fw_eps,
+                                                                    rtol=fw_eps, equal_nan=True).sum()
+            max_diff = np.array(abs(np.array(cur_ov_res, dtype=np.float32) - np.array(cur_fw_res, dtype=np.float32))).max()
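+            # For quantized ops, results may legitimately differ from the framework
+            # by one quantization step, so tolerate up to log10(size) mismatched
+            # elements and a max abs diff of quant_size + fw_eps.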
+            if not quantized_ops and n_is_not_close > 0:
                 is_ok = False
-                print("Max diff is {}".format(
-                    np.array(
-                        abs(cur_ov_res - cur_fw_res)).max()))
+                print("Max diff is {}".format(max_diff))
+            elif quantized_ops and (n_is_not_close > int(np.log10(cur_fw_res.size)) or max_diff > np.array(quant_size + fw_eps).max()):
+                is_ok = False
+                print("Errors outside threshold range: {} with max diff {}, expected at most {} with max diff {}".format(
+                    n_is_not_close, max_diff, int(np.log10(cur_fw_res.size)), quant_size + fw_eps))
             else:
                 print("Accuracy validation successful!\n")
-                print("absolute eps: {}, relative eps: {}".format(fw_eps, fw_eps))
+                print("absolute eps: {}, relative eps: {}".format(
+                    fw_eps, fw_eps))
         assert is_ok, "Accuracy validation failed"
 
     # Each model should specify inputs
diff --git a/tests/layer_tests/pytorch_tests/test_quantize.py b/tests/layer_tests/pytorch_tests/test_quantize.py
index ecd06792925..f1a75221590 100644
--- a/tests/layer_tests/pytorch_tests/test_quantize.py
+++ b/tests/layer_tests/pytorch_tests/test_quantize.py
@@ -48,11 +48,11 @@ class TestQuantizePerTensorDequantize(PytorchLayerTest):
                                   reason="Not supported with FakeQuantize."))
     ])
     @pytest.mark.nightly
-    # @pytest.mark.precommit - sporadic issue
+    @pytest.mark.precommit
     def test_quantize_per_tensor_dequantize(self, scale, zero_point, dtype, ie_device, precision, ir_version):
         if dtype == torch.quint8: zero_point = abs(zero_point)
         self._test(aten_quantize_per_tensor_aten_dequantize(scale, zero_point, dtype), None, ["aten::quantize_per_tensor", "aten::dequantize"],
-                   ie_device, precision, ir_version, )
+                   ie_device, precision, ir_version, quantized_ops=True, quant_size=scale)
 
 class TestQuantizePerChannelDequantize(PytorchLayerTest):
     def _prepare_input(self):
@@ -87,9 +87,9 @@ class TestQuantizePerChannelDequantize(PytorchLayerTest):
                                   reason="Not supported with FakeQuantize."))
     ])
     @pytest.mark.nightly
-    # @pytest.mark.precommit - sporadic issue
-    def test_quantize_per_channel_dequantize(self, scale, zero_point, axis, dtype, ie_device, precision, ir_version):
+    @pytest.mark.precommit
+    def test_quantize_per_channel_dequantize(self, scale, zero_point, dtype, axis, ie_device, precision, ir_version):
         np.random.shuffle(scale), np.random.shuffle(zero_point)
         if dtype == torch.quint8: zero_point = abs(zero_point)
-        self._test(aten_quantize_per_channel_aten_dequantize(scale, zero_point, axis, dtype), None, ["aten::quantize_per_channel", "aten::dequantize"],
-                   ie_device, precision, ir_version, )
+        self._test(aten_quantize_per_channel_aten_dequantize(scale, zero_point, dtype, axis), None, ["aten::quantize_per_channel", "aten::dequantize"],
+                   ie_device, precision, ir_version, quantized_ops=True, quant_size=scale)
diff --git a/tests/layer_tests/pytorch_tests/test_quantized_add.py b/tests/layer_tests/pytorch_tests/test_quantized_add.py
new file mode 100644
index 00000000000..ba0776fa19a
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_quantized_add.py
@@ -0,0 +1,44 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+class quantized_add(torch.nn.Module):
+    def __init__(self, scale, zero_point, dtype) -> None:
+        torch.nn.Module.__init__(self)
+        self.scale = scale
+        self.zero_point = zero_point
+        self.dtype = dtype
+
+    def forward(self, input_tensor1, input_tensor2):
+        quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
+        quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, self.scale, self.zero_point, self.dtype)
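+        # quantized::add takes the output scale/zero_point as explicit arguments,
+        # so the sum is re-quantized inside the single fused op.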
+        quantized_add = torch.ops.quantized.add(quantized_tensor1, quantized_tensor2, self.scale, self.zero_point)
+        dequantized_tensor = torch.dequantize(quantized_add)
+        return dequantized_tensor
+
+class TestQuantizedAdd(PytorchLayerTest):
+    def _prepare_input(self):
+        return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),
+                np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32))
+
+    @pytest.mark.parametrize("scale", [
+        1.0, 0.21, 0.62
+    ])
+    @pytest.mark.parametrize("zero_point", [
+        0, 4, -7
+    ])
+    @pytest.mark.parametrize("dtype", [
+        torch.quint8,
+        torch.qint8
+    ])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_quantized_add(self, scale, zero_point, dtype, ie_device, precision, ir_version):
+        if dtype == torch.quint8: zero_point = abs(zero_point)
+        self._test(quantized_add(scale, zero_point, dtype), None, ["quantized::add"],
+                   ie_device, precision, ir_version, quantized_ops=True, quant_size=scale)
diff --git a/tests/layer_tests/pytorch_tests/test_quantized_add_relu.py b/tests/layer_tests/pytorch_tests/test_quantized_add_relu.py
new file mode 100644
index 00000000000..4502c0c1973
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_quantized_add_relu.py
@@ -0,0 +1,44 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+class quantized_add_relu(torch.nn.Module):
+    def __init__(self, scale, zero_point, dtype) -> None:
+        torch.nn.Module.__init__(self)
+        self.scale = scale
+        self.zero_point = zero_point
+        self.dtype = dtype
+
+    def forward(self, input_tensor1, input_tensor2):
+        quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
+        quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, self.scale, self.zero_point, self.dtype)
+        quantized_add_relu = torch.ops.quantized.add_relu(quantized_tensor1, quantized_tensor2, self.scale, self.zero_point)
+        dequantized_tensor = torch.dequantize(quantized_add_relu)
+        return dequantized_tensor
+
+class TestQuantizedAddReLU(PytorchLayerTest):
+    def _prepare_input(self):
+        return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),
+                np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32))
+
+    @pytest.mark.parametrize("scale", [
+        1.0, 0.21, 0.62
+    ])
+    @pytest.mark.parametrize("zero_point", [
+        0, 4, -7
+    ])
+    @pytest.mark.parametrize("dtype", [
+        torch.quint8,
+        torch.qint8
+    ])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_quantized_add_relu(self, scale, zero_point, dtype, ie_device, precision, ir_version):
+        if dtype == torch.quint8: zero_point = abs(zero_point)
+        self._test(quantized_add_relu(scale, zero_point, dtype), None, ["quantized::add_relu"],
+                   ie_device, precision, ir_version, quantized_ops=True, quant_size=scale)
diff --git a/tests/layer_tests/pytorch_tests/test_quantized_hardswish.py b/tests/layer_tests/pytorch_tests/test_quantized_hardswish.py
new file mode 100644
index 00000000000..254b43818e4
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_quantized_hardswish.py
@@ -0,0 +1,42 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+class quantized_hardswish(torch.nn.Module):
+    def __init__(self, scale, zero_point, dtype) -> None:
+        torch.nn.Module.__init__(self)
+        self.scale = scale
+        self.zero_point = zero_point
+        self.dtype = dtype
+
+    def forward(self, input_tensor1):
+        quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
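+        # quantized::hardswish likewise takes the output scale/zero_point explicitly,
+        # mirroring the re-quantization the OpenVINO translator performs.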
+        quantized_hardswish = torch.ops.quantized.hardswish(quantized_tensor1, self.scale, self.zero_point)
+        dequantized_tensor = torch.dequantize(quantized_hardswish)
+        return dequantized_tensor
+
+class TestQuantizedHardswish(PytorchLayerTest):
+    def _prepare_input(self):
+        return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),)
+
+    @pytest.mark.parametrize("scale", [
+        1.0, 0.21, 0.62,
+    ])
+    @pytest.mark.parametrize("zero_point", [
+        0, 4, -7
+    ])
+    @pytest.mark.parametrize("dtype", [
+        torch.quint8,
+        torch.qint8
+    ])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_quantized_hardswish(self, scale, zero_point, dtype, ie_device, precision, ir_version):
+        if dtype == torch.quint8: zero_point = abs(zero_point)
+        self._test(quantized_hardswish(scale, zero_point, dtype), None, ["quantized::hardswish"],
+                   ie_device, precision, ir_version, quantized_ops=True, quant_size=scale)
diff --git a/tests/layer_tests/pytorch_tests/test_quantized_mul.py b/tests/layer_tests/pytorch_tests/test_quantized_mul.py
new file mode 100644
index 00000000000..ab6418ed449
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_quantized_mul.py
@@ -0,0 +1,44 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+class quantized_mul(torch.nn.Module):
+    def __init__(self, scale, zero_point, dtype) -> None:
+        torch.nn.Module.__init__(self)
+        self.scale = scale
+        self.zero_point = zero_point
+        self.dtype = dtype
+
+    def forward(self, input_tensor1, input_tensor2):
+        quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
+        quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, self.scale, self.zero_point, self.dtype)
+        quantized_mul = torch.ops.quantized.mul(quantized_tensor1, quantized_tensor2, self.scale, self.zero_point)
+        dequantized_tensor = torch.dequantize(quantized_mul)
+        return dequantized_tensor
+
+class TestQuantizedMul(PytorchLayerTest):
+    def _prepare_input(self):
+        return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),
+                np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32))
+
+    @pytest.mark.parametrize("scale", [
+        1.0, 0.21, 0.62
+    ])
+    @pytest.mark.parametrize("zero_point", [
+        0, 4, -7
+    ])
+    @pytest.mark.parametrize("dtype", [
+        torch.quint8,
+        torch.qint8
+    ])
+    @pytest.mark.nightly
+    # @pytest.mark.precommit - accuracy problem
+    def test_quantized_mul(self, scale, zero_point, dtype, ie_device, precision, ir_version):
+        if dtype == torch.quint8: zero_point = abs(zero_point)
+        self._test(quantized_mul(scale, zero_point, dtype), None, ["quantized::mul"],
+                   ie_device, precision, ir_version, quantized_ops=True, quant_size=scale)