[PT FE] Quantized Add, Add_ReLU, Mul, Hardswish (#18510)
* Support GetAttr with packed params
* Apply suggestions from code review
* [PT FE] Add quantized types as normal types to decoder
* [PT FE] Add decoder dequantize, add dtypes to quantize
* [PT FE] Add dequantize example
* [PT FE] Implement replacer for quantized nodes
* [PT FE] Register replacer for quantize/dequantize
* [PT FE] Remove unwanted junk from previous version
* [PT FE] Fix building mistakes for frontend
* [PT FE] Clang fix
* [PT FE] Ease of use upgrade to quantize funcs
* [PT FE] Clang format
* [PT FE] Introduce new version of quantize/dequantize
* [PT FE] Remove unwanted files from new version
* [PT FE] Fix style
* [PT FE] Add QuantizedPtNode replacer, fix accuracy error
* [PT FE] Quantized add
* [PT FE] Add improved version of quantize/dequantize with shared_ptrs
* [PT FE] Quantized Add Relu
* [PT FE] Fix utils shared ptr reference error
* [PT FE] Quantized Hardswish & tests
* [PT FE] Add quantized_add_relu test
* [PT FE] Quantize now takes correct input for operations
* [PT FE] Add dtype to tests
* [PT FE] Increase matrix size in the tests to cover more edge cases
* [PT FE] Upgrade quantize method
* [PT FE] Add BFS for dequantize, add quantize_per_channel
* [PT FE] Add missing replacer to frontend, improve tests
* [PT FE] Rename replacer -> remover, remove unwanted header files
* [PT FE] Change function declarations to return ov::Output instead of shared ptr
* [PT FE] Add missing context mark node
* [PT FE] Remove unknown modifications to ie_c_api
* [PT FE] Remove fp16 support, turn off int32 tests
* [PT FE] Clang format
* [PT FE] Fix quantize_per_tensor
* [PT FE] Minor fixes from review
* [PT FE] Remove dequantize, remove helpers, replacer now removes nodes instead
* [PT FE] Rename Replacer to Remover for dequantize nodes
* [PT FE] Clang format
* [PT FE] Move comments to header files, minor import fixes
* [PT FE] Enable add, add_relu, mul, hardswish with newest quantize ops
* [PT FE] Fix clang format
* [PT FE] Fix dtype issue
* [PT FE] Reenable hardswish tests
* [PT FE] Fix building error for quantized ops
* [PT FE] Tests now use rand instead of randn
* [PT FE] Remove int32 support from tests
* [PT FE] Fix quantize_per_channel tests
* Apply suggestions from code review: removing sporadic tests from precommit
* Apply suggestions from code review
* [PT FE] Mark qint32 with skip, remove precommit flag due to sporadic errors
* [PT FE] Introduce quantized test for PytorchTestClass
* [PT FE] Add missing tensor size in the errors count computation
* Update pytorch_layer_test_class.py
* [PT FE] Apply suggestions from review - remove contexts, log100 for errors, add precommit flags
* Apply suggestions from code review
* [PT FE] Fix quantize per channel test after merge
* Update tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
* Update pytorch_layer_test_class.py
* Update test_quantize.py
* Apply suggestions from code review
* [PT FE] Fix quantized mul by changing default quantization dtype
* Update pytorch_layer_test_class.py
* Update tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
* Update src/frontends/pytorch/src/utils_quantize.cpp
* Update src/frontends/pytorch/src/utils_quantize.cpp
* Update tests/layer_tests/pytorch_tests/test_quantized_mul.py

---------

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
parent a6e25d69ac
commit e076ed4726
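For context (not part of this diff): the quantized ops handled below follow PyTorch's standard affine quantization, q = clamp(round(x / scale) + zero_point, qmin, qmax), with dequantization x' = (q - zero_point) * scale, so each element is reconstructed with an error of at most about scale / 2. That is why the test changes further down compare results against a quant_size (= scale) dependent tolerance. A minimal NumPy sketch of this round-trip, with illustrative helper names only:

import numpy as np

def quantize(x, scale, zero_point, qmin=0, qmax=255):
    # Affine per-tensor quantization; qmin/qmax shown for torch.quint8.
    return np.clip(np.round(x / scale) + zero_point, qmin, qmax)

def dequantize(q, scale, zero_point):
    # Reconstruction; per-element error is bounded by ~scale / 2 when no clipping occurs.
    return (q - zero_point) * scale

x = np.random.rand(4).astype(np.float32)
q = quantize(x, scale=0.21, zero_point=4)
assert np.all(np.abs(dequantize(q, 0.21, 4) - x) <= 0.21 / 2 + 1e-6)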
src/frontends/pytorch/src/op/quantized_add.cpp (new file, 46 lines)
@@ -0,0 +1,46 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/relu.hpp"
#include "utils_quantize.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_quantized_add(const NodeContext& context) {
    num_inputs_check(context, 4, 4);
    const auto x = context.get_input(0);
    const auto y = context.get_input(1);
    const auto scale = context.get_input(2);
    const auto zero_point = context.get_input(3);

    const auto quantized_add = context.mark_node(std::make_shared<v1::Add>(x, y));

    return {quantize(context, quantized_add, scale, zero_point, x)};
}

OutputVector translate_quantized_add_relu(const NodeContext& context) {
    num_inputs_check(context, 4, 4);
    const auto x = context.get_input(0);
    const auto y = context.get_input(1);
    const auto scale = context.get_input(2);
    const auto zero_point = context.get_input(3);

    const auto quantized_add = context.mark_node(std::make_shared<v1::Add>(x, y));
    const auto quantized_add_relu = context.mark_node(std::make_shared<v0::Relu>(quantized_add));

    return {quantize(context, quantized_add_relu, scale, zero_point, x)};
}

}  // namespace op
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov
src/frontends/pytorch/src/op/quantized_hardswish.cpp (new file, 31 lines)
@@ -0,0 +1,31 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/hswish.hpp"
#include "utils_quantize.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_quantized_hardswish(const NodeContext& context) {
    num_inputs_check(context, 3, 3);
    const auto x = context.get_input(0);
    const auto scale = context.get_input(1);
    const auto zero_point = context.get_input(2);

    const auto quantized_hardswish = context.mark_node(std::make_shared<v4::HSwish>(x));

    return {quantize(context, quantized_hardswish, scale, zero_point, x)};
}

}  // namespace op
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov
src/frontends/pytorch/src/op/quantized_mul.cpp (new file, 32 lines)
@@ -0,0 +1,32 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "utils_quantize.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_quantized_mul(const NodeContext& context) {
    num_inputs_check(context, 4, 4);
    const auto x = context.get_input(0);
    const auto y = context.get_input(1);
    const auto scale = context.get_input(2);
    const auto zero_point = context.get_input(3);

    const auto quantized_mul = context.mark_node(std::make_shared<v1::Multiply>(x, y));

    return {quantize(context, quantized_mul, scale, zero_point, x)};
}

}  // namespace op
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov
@@ -108,6 +108,10 @@ OP_CONVERTER(translate_pow);
 OP_CONVERTER(translate_pythonop);
 OP_CONVERTER(translate_quantize_per_channel);
 OP_CONVERTER(translate_quantize_per_tensor);
+OP_CONVERTER(translate_quantized_add);
+OP_CONVERTER(translate_quantized_add_relu);
+OP_CONVERTER(translate_quantized_hardswish);
+OP_CONVERTER(translate_quantized_mul);
 OP_CONVERTER(translate_range_length);
 OP_CONVERTER(translate_rand);
 OP_CONVERTER(translate_randn);
@@ -429,8 +433,12 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"prim::requires_grad", op::return_false_scalar},
         {"prim::PythonOp", op::translate_pythonop},
         {"prim::type", op::skip_node},  // Used with prim::device, pass PtFrameworkNode.
+        {"quantized::add", op::translate_quantized_add},
+        {"quantized::add_relu", op::translate_quantized_add_relu},
         {"quantized::conv2d", op::translate_quantized_convnd},
         {"quantized::conv2d_relu", op::translate_quantized_convnd_relu},
+        {"quantized::hardswish", op::translate_quantized_hardswish},
+        {"quantized::mul", op::translate_quantized_mul},
         {"quantized::linear", op::translate_quantized_linear},
         {"torchvision::deform_conv2d", op::translate_deform_conv},
         {"torchvision::nms", op::translate_nms},
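To exercise the new mappings end to end, a quick conversion check could look like the sketch below. This is illustrative only; it assumes the openvino.convert_model Python API available in recent OpenVINO releases and uses a toy module similar to the ones in the tests further down.

import torch
import openvino as ov

class QAdd(torch.nn.Module):
    def forward(self, a, b):
        # quantized::add is the op that translate_quantized_add handles
        qa = torch.quantize_per_tensor(a, 0.5, 0, torch.quint8)
        qb = torch.quantize_per_tensor(b, 0.5, 0, torch.quint8)
        return torch.dequantize(torch.ops.quantized.add(qa, qb, 0.5, 0))

example = (torch.rand(2, 3), torch.rand(2, 3))
# With the converters registered above, this conversion should now succeed.
ov_model = ov.convert_model(QAdd(), example_input=example)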
tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
@@ -125,21 +125,34 @@ class PytorchLayerTest:
         # Compare Ie results with Framework results
         fw_eps = custom_eps if precision == 'FP32' else 5e-2
         is_ok = True
+        quantized_ops = False
+        if 'quantized_ops' in kwargs and kwargs['quantized_ops'] is not None:
+            quantized_ops = kwargs['quantized_ops']
+            if quantized_ops:
+                assert 'quant_size' in kwargs, "quant size must be specified for quantized_ops flag"
+                quant_size = kwargs['quant_size']
         for i in range(len(infer_res)):
             cur_fw_res = flatten_fw_res[i].to(memory_format=torch.contiguous_format).numpy(
             ) if isinstance(flatten_fw_res[i], torch.Tensor) else flatten_fw_res[i]
+            if np.array(cur_fw_res).size == 0:
+                continue
             cur_ov_res = infer_res[compiled.output(i)]
-            print(f"fw_re: {cur_fw_res};\n ov_res: {cur_ov_res}")
-            if not np.allclose(cur_ov_res, cur_fw_res,
-                               atol=fw_eps,
-                               rtol=fw_eps, equal_nan=True):
+            print(f"fw_res: {cur_fw_res};\n ov_res: {cur_ov_res}")
+            n_is_not_close = np.array(cur_fw_res).size - np.isclose(cur_ov_res, cur_fw_res,
+                                                                    atol=fw_eps,
+                                                                    rtol=fw_eps, equal_nan=True).sum()
+            max_diff = np.array(abs(np.array(cur_ov_res, dtype=np.float32) - np.array(cur_fw_res, dtype=np.float32))).max()
+            if not quantized_ops and n_is_not_close > 0:
                 is_ok = False
-                print("Max diff is {}".format(
-                    np.array(
-                        abs(cur_ov_res - cur_fw_res)).max()))
+                print("Max diff is {}".format(max_diff))
+            elif quantized_ops and (n_is_not_close > int(np.log10(cur_fw_res.size)) or max_diff > np.array(quant_size + fw_eps).max()):
+                is_ok = False
+                print("Errors outside threshold range: {} with max diff {}, expected at most {} with max diff {}".format(
+                    n_is_not_close, max_diff, int(np.log10(cur_fw_res.size)), quant_size + fw_eps))
             else:
                 print("Accuracy validation successful!\n")
-                print("absolute eps: {}, relative eps: {}".format(fw_eps, fw_eps))
+                print("absolute eps: {}, relative eps: {}".format(
+                    fw_eps, fw_eps))
         assert is_ok, "Accuracy validation failed"

         # Each model should specify inputs
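In short, when quantized_ops is set the comparison above allows up to log10(N) elements to fall outside the usual isclose tolerance, as long as the largest absolute difference stays within quant_size (the quantization scale) plus fw_eps. A minimal standalone restatement of that check, with illustrative names (not the test-class API itself):

import numpy as np

def quantized_results_match(ov_res, fw_res, quant_size, fw_eps=5e-2):
    # Count elements outside the elementwise tolerance and track the worst deviation.
    ov_res = np.asarray(ov_res, dtype=np.float32)
    fw_res = np.asarray(fw_res, dtype=np.float32)
    n_not_close = fw_res.size - np.isclose(ov_res, fw_res, atol=fw_eps, rtol=fw_eps, equal_nan=True).sum()
    max_diff = np.abs(ov_res - fw_res).max()
    # Tolerate a handful of off-by-one-quantization-step elements, but never a diff beyond one step + eps.
    return n_not_close <= int(np.log10(fw_res.size)) and max_diff <= np.max(quant_size) + fw_eps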
tests/layer_tests/pytorch_tests/test_quantize.py
@@ -48,11 +48,11 @@ class TestQuantizePerTensorDequantize(PytorchLayerTest):
                                       reason="Not supported with FakeQuantize."))
     ])
     @pytest.mark.nightly
-    # @pytest.mark.precommit - sporadic issue
+    @pytest.mark.precommit
     def test_quantize_per_tensor_dequantize(self, scale, zero_point, dtype, ie_device, precision, ir_version):
         if dtype == torch.quint8: zero_point = abs(zero_point)
         self._test(aten_quantize_per_tensor_aten_dequantize(scale, zero_point, dtype), None, ["aten::quantize_per_tensor", "aten::dequantize"],
-                   ie_device, precision, ir_version, )
+                   ie_device, precision, ir_version, quantized_ops=True, quant_size=scale)

 class TestQuantizePerChannelDequantize(PytorchLayerTest):
     def _prepare_input(self):
@@ -87,9 +87,9 @@ class TestQuantizePerChannelDequantize(PytorchLayerTest):
                                       reason="Not supported with FakeQuantize."))
     ])
     @pytest.mark.nightly
-    # @pytest.mark.precommit - sporadic issue
-    def test_quantize_per_channel_dequantize(self, scale, zero_point, axis, dtype, ie_device, precision, ir_version):
+    @pytest.mark.precommit
+    def test_quantize_per_channel_dequantize(self, scale, zero_point, dtype, axis, ie_device, precision, ir_version):
         np.random.shuffle(scale), np.random.shuffle(zero_point)
         if dtype == torch.quint8: zero_point = abs(zero_point)
-        self._test(aten_quantize_per_channel_aten_dequantize(scale, zero_point, axis, dtype), None, ["aten::quantize_per_channel", "aten::dequantize"],
-                   ie_device, precision, ir_version, )
+        self._test(aten_quantize_per_channel_aten_dequantize(scale, zero_point, dtype, axis), None, ["aten::quantize_per_channel", "aten::dequantize"],
+                   ie_device, precision, ir_version, quantized_ops=True, quant_size=scale)
tests/layer_tests/pytorch_tests/test_quantized_add.py (new file, 44 lines)
@@ -0,0 +1,44 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class quantized_add(torch.nn.Module):
    def __init__(self, scale, zero_point, dtype) -> None:
        torch.nn.Module.__init__(self)
        self.scale = scale
        self.zero_point = zero_point
        self.dtype = dtype

    def forward(self, input_tensor1, input_tensor2):
        quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
        quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, self.scale, self.zero_point, self.dtype)
        quantized_add = torch.ops.quantized.add(quantized_tensor1, quantized_tensor2, self.scale, self.zero_point)
        dequantized_tensor = torch.dequantize(quantized_add)
        return dequantized_tensor


class TestQuantizedAdd(PytorchLayerTest):
    def _prepare_input(self):
        return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),
                np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32))

    @pytest.mark.parametrize("scale", [
        1.0, 0.21, 0.62
    ])
    @pytest.mark.parametrize("zero_point", [
        0, 4, -7
    ])
    @pytest.mark.parametrize("dtype", [
        torch.quint8,
        torch.qint8
    ])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_quantized_add(self, scale, zero_point, dtype, ie_device, precision, ir_version):
        if dtype == torch.quint8: zero_point = abs(zero_point)
        self._test(quantized_add(scale, zero_point, dtype), None, ["quantized::add"],
                   ie_device, precision, ir_version, quantized_ops=True, quant_size=scale)
tests/layer_tests/pytorch_tests/test_quantized_add_relu.py (new file, 44 lines)
@@ -0,0 +1,44 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class quantized_add_relu(torch.nn.Module):
    def __init__(self, scale, zero_point, dtype) -> None:
        torch.nn.Module.__init__(self)
        self.scale = scale
        self.zero_point = zero_point
        self.dtype = dtype

    def forward(self, input_tensor1, input_tensor2):
        quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
        quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, self.scale, self.zero_point, self.dtype)
        quantized_add_relu = torch.ops.quantized.add_relu(quantized_tensor1, quantized_tensor2, self.scale, self.zero_point)
        dequantized_tensor = torch.dequantize(quantized_add_relu)
        return dequantized_tensor


class TestQuantizedAddReLU(PytorchLayerTest):
    def _prepare_input(self):
        return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),
                np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32))

    @pytest.mark.parametrize("scale", [
        1.0, 0.21, 0.62
    ])
    @pytest.mark.parametrize("zero_point", [
        0, 4, -7
    ])
    @pytest.mark.parametrize("dtype", [
        torch.quint8,
        torch.qint8
    ])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_quantized_add_relu(self, scale, zero_point, dtype, ie_device, precision, ir_version):
        if dtype == torch.quint8: zero_point = abs(zero_point)
        self._test(quantized_add_relu(scale, zero_point, dtype), None, ["quantized::add_relu"],
                   ie_device, precision, ir_version, quantized_ops=True, quant_size=scale)
tests/layer_tests/pytorch_tests/test_quantized_hardswish.py (new file, 42 lines)
@@ -0,0 +1,42 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class quantized_hardswish(torch.nn.Module):
    def __init__(self, scale, zero_point, dtype) -> None:
        torch.nn.Module.__init__(self)
        self.scale = scale
        self.zero_point = zero_point
        self.dtype = dtype

    def forward(self, input_tensor1):
        quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
        quantized_hardswish = torch.ops.quantized.hardswish(quantized_tensor1, self.scale, self.zero_point)
        dequantized_tensor = torch.dequantize(quantized_hardswish)
        return dequantized_tensor


class TestQuantizedHardswish(PytorchLayerTest):
    def _prepare_input(self):
        return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),)

    @pytest.mark.parametrize("scale", [
        1.0, 0.21, 0.62,
    ])
    @pytest.mark.parametrize("zero_point", [
        0, 4, -7
    ])
    @pytest.mark.parametrize("dtype", [
        torch.quint8,
        torch.qint8
    ])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_quantized_hardswish(self, scale, zero_point, dtype, ie_device, precision, ir_version):
        if dtype == torch.quint8: zero_point = abs(zero_point)
        self._test(quantized_hardswish(scale, zero_point, dtype), None, ["quantized::hardswish"],
                   ie_device, precision, ir_version, quantized_ops=True, quant_size=scale)
tests/layer_tests/pytorch_tests/test_quantized_mul.py (new file, 44 lines)
@@ -0,0 +1,44 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class quantized_mul(torch.nn.Module):
    def __init__(self, scale, zero_point, dtype) -> None:
        torch.nn.Module.__init__(self)
        self.scale = scale
        self.zero_point = zero_point
        self.dtype = dtype

    def forward(self, input_tensor1, input_tensor2):
        quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
        quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, self.scale, self.zero_point, self.dtype)
        quantized_mul = torch.ops.quantized.mul(quantized_tensor1, quantized_tensor2, self.scale, self.zero_point)
        dequantized_tensor = torch.dequantize(quantized_mul)
        return dequantized_tensor


class TestQuantizedMul(PytorchLayerTest):
    def _prepare_input(self):
        return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),
                np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32))

    @pytest.mark.parametrize("scale", [
        1.0, 0.21, 0.62
    ])
    @pytest.mark.parametrize("zero_point", [
        0, 4, -7
    ])
    @pytest.mark.parametrize("dtype", [
        torch.quint8,
        torch.qint8
    ])
    @pytest.mark.nightly
    # @pytest.mark.precommit - accuracy problem
    def test_quantized_mul(self, scale, zero_point, dtype, ie_device, precision, ir_version):
        if dtype == torch.quint8: zero_point = abs(zero_point)
        self._test(quantized_mul(scale, zero_point, dtype), None, ["quantized::mul"],
                   ie_device, precision, ir_version, quantized_ops=True, quant_size=scale)