[PT FE] Fix Sporadic Quantized Ops Errors (#18962)

* [PT FE] Fix sporadics with round & 0 zero_pt, 1.0 scale

* [PT FE] Change scale and round input in quantized cat tests

* [PT FE] Add rounding to conv & linear tests

* Update test_quantized_cat.py

* Update test_quantized_cat.py

* [PT FE] Replace randn with rand for consistency in convnd
This commit is contained in:
Piotr Krzemiński 2023-08-14 15:41:36 +02:00 committed by GitHub
parent 2830b20d73
commit e77070890a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 27 additions and 27 deletions

View File

@ -155,8 +155,8 @@ class PytorchLayerTest:
n_is_not_close, max_diff, int(np.log10(cur_fw_res.size)), quant_size + fw_eps))
else:
print("Accuracy validation successful!\n")
print("absolute eps: {}, relative eps: {}".format(
fw_eps, fw_eps))
print("absolute eps: {}, relative eps: {}, errors: {}".format(
fw_eps, fw_eps, n_is_not_close))
assert is_ok, "Accuracy validation failed"
# Each model should specify inputs

View File

@ -15,19 +15,19 @@ class quantized_add(torch.nn.Module):
self.dtype = dtype
def forward(self, input_tensor1, input_tensor2):
quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, self.scale, self.zero_point, self.dtype)
quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, 1.0, 0, self.dtype)
quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, 1.0, 0, self.dtype)
quantized_add = torch.ops.quantized.add(quantized_tensor1, quantized_tensor2, self.scale, self.zero_point)
dequantized_tensor = torch.dequantize(quantized_add)
return dequantized_tensor
class TestQuantizedAdd(PytorchLayerTest):
def _prepare_input(self):
return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),
np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32))
return (np.round(np.array(5.00 * np.random.rand(10, 10) - 2.50, dtype=np.float32), 4),
np.round(np.array(5.00 * np.random.rand(10, 10) - 2.50, dtype=np.float32), 4))
@pytest.mark.parametrize("scale", [
1.0, 0.21, 0.62
1.0, 0.21, 0.62, 0.9999
])
@pytest.mark.parametrize("zero_point", [
0, 4, -7

View File

@ -15,19 +15,19 @@ class quantized_add_relu(torch.nn.Module):
self.dtype = dtype
def forward(self, input_tensor1, input_tensor2):
quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, self.scale, self.zero_point, self.dtype)
quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, 1.0, 0, self.dtype)
quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, 1.0, 0, self.dtype)
quantized_add_relu = torch.ops.quantized.add_relu(quantized_tensor1, quantized_tensor2, self.scale, self.zero_point)
dequantized_tensor = torch.dequantize(quantized_add_relu)
return dequantized_tensor
class TestQuantizedAddReLU(PytorchLayerTest):
def _prepare_input(self):
return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),
np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32))
return (np.round(np.array(5.00 * np.random.rand(10, 10) - 2.50, dtype=np.float32), 4),
np.round(np.array(5.00 * np.random.rand(10, 10) - 2.50, dtype=np.float32), 4))
@pytest.mark.parametrize("scale", [
1.0, 0.21, 0.62
1.0, 0.21, 0.62, 0.9999
])
@pytest.mark.parametrize("zero_point", [
0, 4, -7

View File

@ -15,8 +15,8 @@ class aten_quantized_cat(torch.nn.Module):
self.dtype = dtype
def forward(self, inp):
x = torch.quantize_per_tensor(inp, 1.3, 0, self.dtype)
y = torch.quantize_per_tensor(inp, 1.0, 1, self.dtype)
x = torch.quantize_per_tensor(inp, 1.0, 0, self.dtype)
y = torch.quantize_per_tensor(inp, 1.0, 0, self.dtype)
return torch.dequantize(torch.ops.quantized.cat([x, y], 1, self.scale, self.zero_point))
@ -66,7 +66,7 @@ class aten_add_quantized_cat(torch.nn.Module):
class TestQuantizedCat(PytorchLayerTest):
def _prepare_input(self):
return (np.random.rand(2, 1, 3).astype(np.float32),)
return (np.round(np.random.rand(2, 1, 3).astype(np.float32), 4),)
@pytest.mark.parametrize("scale", [1.0, 0.3, 1.3])
@pytest.mark.parametrize("zero_point", [0, 1])

View File

@ -12,7 +12,7 @@ from pytorch_layer_test_class import PytorchLayerTest
class TestQuantizedConv2D(PytorchLayerTest):
def _prepare_input(self):
return (np.random.randn(2, 3, 25, 25).astype(np.float32),)
return (np.round(np.random.rand(2, 3, 25, 25).astype(np.float32), 4),)
def create_model(self, weights_shape, strides, pads, dilations, groups, bias, relu, scale, zero_point):
class quantized_conv2d(torch.nn.Module):

View File

@ -15,17 +15,17 @@ class quantized_hardswish(torch.nn.Module):
self.dtype = dtype
def forward(self, input_tensor1):
quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, 1.0, 0, self.dtype)
quantized_hardswish = torch.ops.quantized.hardswish(quantized_tensor1, self.scale, self.zero_point)
dequantized_tensor = torch.dequantize(quantized_hardswish)
return dequantized_tensor
class TestQuantizedHardswish(PytorchLayerTest):
def _prepare_input(self):
return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),)
return (np.round(np.array(5.00 * np.random.rand(10, 10) - 2.50, dtype=np.float32), 4),)
@pytest.mark.parametrize("scale", [
1.0, 0.21, 0.62,
1.0, 0.21, 0.62, 0.9999
])
@pytest.mark.parametrize("zero_point", [
0, 4, -7

View File

@ -9,7 +9,7 @@ from pytorch_layer_test_class import PytorchLayerTest
class TestQuantizedLinear(PytorchLayerTest):
def _prepare_input(self, input_shape=(2, 2)):
return (np.random.randn(*input_shape).astype(np.float32),)
return (np.round(np.random.rand(*input_shape).astype(np.float32), 4),)
def create_model(self, weight_shape, is_bias, scale, zero_point):
@ -27,7 +27,7 @@ class TestQuantizedLinear(PytorchLayerTest):
self.linear.zero_point = int(zero_point)
def forward(self, inp):
inp_q = torch.quantize_per_tensor(inp, 1., 0, torch.quint8)
inp_q = torch.quantize_per_tensor(inp, 1.0, 0, torch.quint8)
return torch.dequantize(self.linear(inp_q))
ref_net = None

View File

@ -15,19 +15,19 @@ class quantized_mul(torch.nn.Module):
self.dtype = dtype
def forward(self, input_tensor1, input_tensor2):
quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, self.scale, self.zero_point, self.dtype)
quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, 1.0, 0, self.dtype)
quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, 1.0, 0, self.dtype)
quantized_mul = torch.ops.quantized.mul(quantized_tensor1, quantized_tensor2, self.scale, self.zero_point)
dequantized_tensor = torch.dequantize(quantized_mul)
return dequantized_tensor
class TestQuantizedMul(PytorchLayerTest):
def _prepare_input(self):
return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),
np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32))
return (np.round(np.array(5.00 * np.random.rand(10, 10) - 2.50, dtype=np.float32), 4),
np.round(np.array(5.00 * np.random.rand(10, 10) - 2.50, dtype=np.float32), 4))
@pytest.mark.parametrize("scale", [
1.0, 0.21, 0.62
1.0, 0.21, 0.62, 0.9999
])
@pytest.mark.parametrize("zero_point", [
0, 4, -7
@ -37,7 +37,7 @@ class TestQuantizedMul(PytorchLayerTest):
torch.qint8
])
@pytest.mark.nightly
# @pytest.mark.precommit - accuracy problem
@pytest.mark.precommit
def test_quantized_mul(self, scale, zero_point, dtype, ie_device, precision, ir_version):
if dtype == torch.quint8: zero_point = abs(zero_point)
self._test(quantized_mul(scale, zero_point, dtype), None, ["quantized::mul"],