[PT FE] Quantized Add, Add_ReLU, Mul, Hardswish (#18510)

* Support GetAttr with packed params

* Apply suggestions from code review

* [PT FE] Add quantized types as normal types to decoder

* [PT FE] Add decoder dequantize, add dtypes to quantize

* [PT FE] Add dequantize example

* [PT FE] Implement replacer for quantized nodes

* [PT FE] Register replacer for quantize/dequantize

* [PT FE] Remove unwanted junk from previous version

* [PT FE] Fix building mistakes for frontend

* [PT FE] Clang fix

* [PT FE] Ease of use upgrade to quantize funcs

* [PT FE] Clang format

* [PT FE] Introduce new version of quantize/dequantize

* [PT FE] Remove unwanted files from new version

* [PT FE] Fix style

* [PT FE] Add QuantizedPtNode replacer, fix accuracy error

* [PT FE] Quantized add

* [PT FE] Add improved version of quantize/dequantize with shared_ptrs

* [PT FE] Quantized Add Relu

* [PT FE] Fix utils shared ptr reference error

* [PT FE] Quantized Hardswish & tests

* [PT FE] Add quantized_add_relu test

* [PT FE] Quantize now takes correct input for operations

* [PT FE] Add dtype to tests

* [PT FE] Increase matrix size in the tests to cover more edge cases

* [PT FE] Upgrade quantize method

* [PT FE] Add BFS for dequantize, add quantize_per_channel

* [PT FE] Add missing replacer to frontend, improve tests

* [PT FE] Rename replacer -> remover, remove unwanted header files

* [PT FE] Change function declarations to return ov::Output instead of shared ptr

* [PT FE] Add missing context mark node

* [PT FE] Remove unknown modifications to ie_c_api

* [PT FE] Remove fp16 support, turn off int32 tests

* [PT FE] Clang format

* [PT FE] Fix quantize_per_tensor

* [PT FE] Minor fixes from review

* [PT FE] Remove dequantize, remove helpers, replacer now removes nodes instead

* [PT FE] Rename Replacer to Remover for dequantize nodes

* [PT FE] Clang format

* [PT FE] Move comments to header files, minor import fixes

* [PT FE] Enable add, add_relu, mul, hardswish with newest quantize ops

* [PT FE] Fix clang format

* [PT FE] Fix dtype issue

* [PT FE] Reenable hardswish tests

* [PT FE] Fix building error for quantized ops

* [PT FE] Tests now use rand instead of randn

* [PT FE] Remove int32 support from tests

* [PT FE] Fix quantize_per_channel tests

* Apply suggestions from code review

Removing sporadic tests from precommit

* Apply suggestions from code review

* [PT FE] Mark qint32 with skip, remove precommit flag due to sporadic errors

* [PT FE] Introduce quantized test for PytorchTestClass

* [PT FE] Add missing tensor size to the error count computation

* Update pytorch_layer_test_class.py

* [PT FE] Apply suggestions from review - remove contexts, log100 for errors, add precommit flags

* Apply suggestions from code review

* [PT FE] Fix quantize per channel test after merge

* Update tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py

* Update pytorch_layer_test_class.py

* Update test_quantize.py

* Apply suggestions from code review

* [PT FE] Fix quantized mul by changing default quantization dtype

* Update pytorch_layer_test_class.py

* Update tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py

* Update src/frontends/pytorch/src/utils_quantize.cpp

* Update src/frontends/pytorch/src/utils_quantize.cpp

* Update tests/layer_tests/pytorch_tests/test_quantized_mul.py

---------

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
Piotr Krzemiński 2023-07-21 12:27:35 +02:00 committed by GitHub
parent a6e25d69ac
commit e076ed4726
10 changed files with 318 additions and 14 deletions
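The three new translator files and four new layer tests below all follow the same decomposition: run the floating-point operation and re-quantize its result with the scale and zero point carried by the quantized op. The graphs they target come from eager-mode modules like the following minimal sketch (the same pattern as the new test_quantized_*.py files, shown here only for orientation; the class name is illustrative):

import torch

class QuantizedAddExample(torch.nn.Module):
    """Quantize both inputs, run the quantized kernel, dequantize the result."""

    def __init__(self, scale=0.21, zero_point=4, dtype=torch.qint8):
        super().__init__()
        self.scale, self.zero_point, self.dtype = scale, zero_point, dtype

    def forward(self, a, b):
        qa = torch.quantize_per_tensor(a, self.scale, self.zero_point, self.dtype)
        qb = torch.quantize_per_tensor(b, self.scale, self.zero_point, self.dtype)
        q_sum = torch.ops.quantized.add(qa, qb, self.scale, self.zero_point)
        return torch.dequantize(q_sum)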


@@ -0,0 +1,46 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/relu.hpp"
#include "utils_quantize.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_quantized_add(const NodeContext& context) {
    num_inputs_check(context, 4, 4);
    const auto x = context.get_input(0);
    const auto y = context.get_input(1);
    const auto scale = context.get_input(2);
    const auto zero_point = context.get_input(3);
    const auto quantized_add = context.mark_node(std::make_shared<v1::Add>(x, y));
    return {quantize(context, quantized_add, scale, zero_point, x)};
}

OutputVector translate_quantized_add_relu(const NodeContext& context) {
    num_inputs_check(context, 4, 4);
    const auto x = context.get_input(0);
    const auto y = context.get_input(1);
    const auto scale = context.get_input(2);
    const auto zero_point = context.get_input(3);
    const auto quantized_add = context.mark_node(std::make_shared<v1::Add>(x, y));
    const auto quantized_add_relu = context.mark_node(std::make_shared<v0::Relu>(quantized_add));
    return {quantize(context, quantized_add_relu, scale, zero_point, x)};
}

}  // namespace op
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov
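The quantize() helper called above comes from utils_quantize.hpp/.cpp (updated elsewhere in this commit, not shown in this hunk); passing the original quantized input x lets it pick up the target quantized element type, and the result appears to be expressed via a FakeQuantize subgraph (the test skips below reference FakeQuantize). A reference sketch of the intended per-tensor numerics, as an illustrative assumption rather than the frontend implementation:

import numpy as np

def requantize_ref(x, scale, zero_point, low=-128, high=127):
    # Snap to the integer grid defined by (scale, zero_point), clamp to the
    # quantized range, then map back to float. low/high here assume qint8.
    q = np.clip(np.round(x / scale) + zero_point, low, high)
    return (q - zero_point) * scale

# quantized::add(a_q, b_q, scale, zero_point) on already-dequantized inputs:
a = np.array([0.42, 1.68, 3.15], dtype=np.float32)
b = np.array([0.21, 0.42, 0.63], dtype=np.float32)
print(requantize_ref(a + b, scale=0.21, zero_point=4))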


@@ -0,0 +1,31 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/hswish.hpp"
#include "utils_quantize.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_quantized_hardswish(const NodeContext& context) {
    num_inputs_check(context, 3, 3);
    const auto x = context.get_input(0);
    const auto scale = context.get_input(1);
    const auto zero_point = context.get_input(2);
    const auto quantized_hardswish = context.mark_node(std::make_shared<v4::HSwish>(x));
    return {quantize(context, quantized_hardswish, scale, zero_point, x)};
}

}  // namespace op
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov


@@ -0,0 +1,32 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "utils_quantize.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_quantized_mul(const NodeContext& context) {
    num_inputs_check(context, 4, 4);
    const auto x = context.get_input(0);
    const auto y = context.get_input(1);
    const auto scale = context.get_input(2);
    const auto zero_point = context.get_input(3);
    const auto quantized_mul = context.mark_node(std::make_shared<v1::Multiply>(x, y));
    return {quantize(context, quantized_mul, scale, zero_point, x)};
}

}  // namespace op
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov


@@ -108,6 +108,10 @@ OP_CONVERTER(translate_pow);
 OP_CONVERTER(translate_pythonop);
 OP_CONVERTER(translate_quantize_per_channel);
 OP_CONVERTER(translate_quantize_per_tensor);
+OP_CONVERTER(translate_quantized_add);
+OP_CONVERTER(translate_quantized_add_relu);
+OP_CONVERTER(translate_quantized_hardswish);
+OP_CONVERTER(translate_quantized_mul);
 OP_CONVERTER(translate_range_length);
 OP_CONVERTER(translate_rand);
 OP_CONVERTER(translate_randn);
@@ -429,8 +433,12 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"prim::requires_grad", op::return_false_scalar},
         {"prim::PythonOp", op::translate_pythonop},
         {"prim::type", op::skip_node},  // Used with prim::device, pass PtFrameworkNode.
+        {"quantized::add", op::translate_quantized_add},
+        {"quantized::add_relu", op::translate_quantized_add_relu},
         {"quantized::conv2d", op::translate_quantized_convnd},
         {"quantized::conv2d_relu", op::translate_quantized_convnd_relu},
+        {"quantized::hardswish", op::translate_quantized_hardswish},
+        {"quantized::mul", op::translate_quantized_mul},
         {"quantized::linear", op::translate_quantized_linear},
         {"torchvision::deform_conv2d", op::translate_deform_conv},
         {"torchvision::nms", op::translate_nms},


@@ -125,21 +125,34 @@ class PytorchLayerTest:
         # Compare Ie results with Framework results
         fw_eps = custom_eps if precision == 'FP32' else 5e-2
         is_ok = True
+        quantized_ops = False
+        if 'quantized_ops' in kwargs and kwargs['quantized_ops'] is not None:
+            quantized_ops = kwargs['quantized_ops']
+        if quantized_ops:
+            assert 'quant_size' in kwargs, "quant size must be specified for quantized_ops flag"
+            quant_size = kwargs['quant_size']
         for i in range(len(infer_res)):
             cur_fw_res = flatten_fw_res[i].to(memory_format=torch.contiguous_format).numpy(
             ) if isinstance(flatten_fw_res[i], torch.Tensor) else flatten_fw_res[i]
             if np.array(cur_fw_res).size == 0:
                 continue
             cur_ov_res = infer_res[compiled.output(i)]
-            print(f"fw_re: {cur_fw_res};\n ov_res: {cur_ov_res}")
-            if not np.allclose(cur_ov_res, cur_fw_res,
-                               atol=fw_eps,
-                               rtol=fw_eps, equal_nan=True):
+            print(f"fw_res: {cur_fw_res};\n ov_res: {cur_ov_res}")
+            n_is_not_close = np.array(cur_fw_res).size - np.isclose(cur_ov_res, cur_fw_res,
+                                                                    atol=fw_eps,
+                                                                    rtol=fw_eps, equal_nan=True).sum()
+            max_diff = np.array(abs(np.array(cur_ov_res, dtype=np.float32) - np.array(cur_fw_res, dtype=np.float32))).max()
+            if not quantized_ops and n_is_not_close > 0:
                 is_ok = False
-                print("Max diff is {}".format(
-                    np.array(
-                        abs(cur_ov_res - cur_fw_res)).max()))
+                print("Max diff is {}".format(max_diff))
+            elif quantized_ops and (n_is_not_close > int(np.log10(cur_fw_res.size)) or max_diff > np.array(quant_size + fw_eps).max()):
+                is_ok = False
+                print("Errors outside threshold range: {} with max diff {}, expected at most {} with max diff {}".format(
+                    n_is_not_close, max_diff, int(np.log10(cur_fw_res.size)), quant_size + fw_eps))
             else:
                 print("Accuracy validation successful!\n")
-                print("absolute eps: {}, relative eps: {}".format(fw_eps, fw_eps))
+                print("absolute eps: {}, relative eps: {}".format(
+                    fw_eps, fw_eps))
         assert is_ok, "Accuracy validation failed"
     # Each model should specify inputs
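With quantized_ops=True the comparison above no longer requires element-wise closeness: at most int(np.log10(N)) of the N output elements may fail the atol/rtol check, and the largest absolute difference must stay within quant_size + fw_eps, where quant_size is the quantization scale each test passes in. A worked example of those thresholds for the 100x100 tensors used by the new tests, assuming the non-FP32 eps of 5e-2 from the code above:

import numpy as np

n = 100 * 100                  # elements per output tensor in the new quantized tests
fw_eps = 5e-2                  # value used when precision is not FP32 (see above)
quant_size = 0.62              # largest scale in the test parametrization

allowed_outliers = int(np.log10(n))     # log10(10000) = 4 -> up to 4 elements may miss atol/rtol
allowed_max_diff = quant_size + fw_eps  # 0.67: roughly one quantization step of slack
print(allowed_outliers, allowed_max_diff)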


@@ -48,11 +48,11 @@ class TestQuantizePerTensorDequantize(PytorchLayerTest):
                                      reason="Not supported with FakeQuantize."))
     ])
     @pytest.mark.nightly
-    # @pytest.mark.precommit - sporadic issue
+    @pytest.mark.precommit
     def test_quantize_per_tensor_dequantize(self, scale, zero_point, dtype, ie_device, precision, ir_version):
         if dtype == torch.quint8: zero_point = abs(zero_point)
         self._test(aten_quantize_per_tensor_aten_dequantize(scale, zero_point, dtype), None, ["aten::quantize_per_tensor", "aten::dequantize"],
-                   ie_device, precision, ir_version, )
+                   ie_device, precision, ir_version, quantized_ops=True, quant_size=scale)
 
 class TestQuantizePerChannelDequantize(PytorchLayerTest):
     def _prepare_input(self):
@@ -87,9 +87,9 @@ class TestQuantizePerChannelDequantize(PytorchLayerTest):
                                      reason="Not supported with FakeQuantize."))
     ])
     @pytest.mark.nightly
-    # @pytest.mark.precommit - sporadic issue
-    def test_quantize_per_channel_dequantize(self, scale, zero_point, axis, dtype, ie_device, precision, ir_version):
+    @pytest.mark.precommit
+    def test_quantize_per_channel_dequantize(self, scale, zero_point, dtype, axis, ie_device, precision, ir_version):
         np.random.shuffle(scale), np.random.shuffle(zero_point)
         if dtype == torch.quint8: zero_point = abs(zero_point)
-        self._test(aten_quantize_per_channel_aten_dequantize(scale, zero_point, axis, dtype), None, ["aten::quantize_per_channel", "aten::dequantize"],
-                   ie_device, precision, ir_version, )
+        self._test(aten_quantize_per_channel_aten_dequantize(scale, zero_point, dtype, axis), None, ["aten::quantize_per_channel", "aten::dequantize"],
+                   ie_device, precision, ir_version, quantized_ops=True, quant_size=scale)


@@ -0,0 +1,44 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class quantized_add(torch.nn.Module):
    def __init__(self, scale, zero_point, dtype) -> None:
        torch.nn.Module.__init__(self)
        self.scale = scale
        self.zero_point = zero_point
        self.dtype = dtype

    def forward(self, input_tensor1, input_tensor2):
        quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
        quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, self.scale, self.zero_point, self.dtype)
        quantized_add = torch.ops.quantized.add(quantized_tensor1, quantized_tensor2, self.scale, self.zero_point)
        dequantized_tensor = torch.dequantize(quantized_add)
        return dequantized_tensor


class TestQuantizedAdd(PytorchLayerTest):
    def _prepare_input(self):
        return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),
                np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32))

    @pytest.mark.parametrize("scale", [
        1.0, 0.21, 0.62
    ])
    @pytest.mark.parametrize("zero_point", [
        0, 4, -7
    ])
    @pytest.mark.parametrize("dtype", [
        torch.quint8,
        torch.qint8
    ])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_quantized_add(self, scale, zero_point, dtype, ie_device, precision, ir_version):
        if dtype == torch.quint8: zero_point = abs(zero_point)
        self._test(quantized_add(scale, zero_point, dtype), None, ["quantized::add"],
                   ie_device, precision, ir_version, quantized_ops=True, quant_size=scale)


@@ -0,0 +1,44 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class quantized_add_relu(torch.nn.Module):
    def __init__(self, scale, zero_point, dtype) -> None:
        torch.nn.Module.__init__(self)
        self.scale = scale
        self.zero_point = zero_point
        self.dtype = dtype

    def forward(self, input_tensor1, input_tensor2):
        quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
        quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, self.scale, self.zero_point, self.dtype)
        quantized_add_relu = torch.ops.quantized.add_relu(quantized_tensor1, quantized_tensor2, self.scale, self.zero_point)
        dequantized_tensor = torch.dequantize(quantized_add_relu)
        return dequantized_tensor


class TestQuantizedAddReLU(PytorchLayerTest):
    def _prepare_input(self):
        return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),
                np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32))

    @pytest.mark.parametrize("scale", [
        1.0, 0.21, 0.62
    ])
    @pytest.mark.parametrize("zero_point", [
        0, 4, -7
    ])
    @pytest.mark.parametrize("dtype", [
        torch.quint8,
        torch.qint8
    ])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_quantized_add_relu(self, scale, zero_point, dtype, ie_device, precision, ir_version):
        if dtype == torch.quint8: zero_point = abs(zero_point)
        self._test(quantized_add_relu(scale, zero_point, dtype), None, ["quantized::add_relu"],
                   ie_device, precision, ir_version, quantized_ops=True, quant_size=scale)


@@ -0,0 +1,42 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class quantized_hardswish(torch.nn.Module):
    def __init__(self, scale, zero_point, dtype) -> None:
        torch.nn.Module.__init__(self)
        self.scale = scale
        self.zero_point = zero_point
        self.dtype = dtype

    def forward(self, input_tensor1):
        quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
        quantized_hardswish = torch.ops.quantized.hardswish(quantized_tensor1, self.scale, self.zero_point)
        dequantized_tensor = torch.dequantize(quantized_hardswish)
        return dequantized_tensor


class TestQuantizedHardswish(PytorchLayerTest):
    def _prepare_input(self):
        return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),)

    @pytest.mark.parametrize("scale", [
        1.0, 0.21, 0.62,
    ])
    @pytest.mark.parametrize("zero_point", [
        0, 4, -7
    ])
    @pytest.mark.parametrize("dtype", [
        torch.quint8,
        torch.qint8
    ])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_quantized_hardswish(self, scale, zero_point, dtype, ie_device, precision, ir_version):
        if dtype == torch.quint8: zero_point = abs(zero_point)
        self._test(quantized_hardswish(scale, zero_point, dtype), None, ["quantized::hardswish"],
                   ie_device, precision, ir_version, quantized_ops=True, quant_size=scale)


@@ -0,0 +1,44 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class quantized_mul(torch.nn.Module):
    def __init__(self, scale, zero_point, dtype) -> None:
        torch.nn.Module.__init__(self)
        self.scale = scale
        self.zero_point = zero_point
        self.dtype = dtype

    def forward(self, input_tensor1, input_tensor2):
        quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
        quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, self.scale, self.zero_point, self.dtype)
        quantized_mul = torch.ops.quantized.mul(quantized_tensor1, quantized_tensor2, self.scale, self.zero_point)
        dequantized_tensor = torch.dequantize(quantized_mul)
        return dequantized_tensor


class TestQuantizedMul(PytorchLayerTest):
    def _prepare_input(self):
        return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),
                np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32))

    @pytest.mark.parametrize("scale", [
        1.0, 0.21, 0.62
    ])
    @pytest.mark.parametrize("zero_point", [
        0, 4, -7
    ])
    @pytest.mark.parametrize("dtype", [
        torch.quint8,
        torch.qint8
    ])
    @pytest.mark.nightly
    # @pytest.mark.precommit - accuracy problem
    def test_quantized_mul(self, scale, zero_point, dtype, ie_device, precision, ir_version):
        if dtype == torch.quint8: zero_point = abs(zero_point)
        self._test(quantized_mul(scale, zero_point, dtype), None, ["quantized::mul"],
                   ie_device, precision, ir_version, quantized_ops=True, quant_size=scale)
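quantized::mul is the one op here that keeps @pytest.mark.precommit commented out ("accuracy problem"). A plausible reason is that the requantized product is the most rounding-sensitive of the four ops: when the exact result sits near a boundary of the output grid, tiny differences in how PyTorch's integer kernel and the decomposed float graph compute the intermediate value can push the result to adjacent quantization steps, i.e. a full quant_size apart. A small illustration of that boundary effect (an assumption-level sketch, not part of the tests):

import numpy as np

def requant(v, scale, zero_point=0, low=-128, high=127):
    # Same reference rounding as the sketch shown after the quantized_add translator.
    q = np.clip(np.round(v / scale) + zero_point, low, high)
    return (q - zero_point) * scale

scale = 0.21
v = 228.5 * scale              # a product that lands exactly on a rounding boundary
print(requant(v - 1e-6, scale), requant(v + 1e-6, scale))   # ~47.88 vs ~48.09, one step apart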