[PT FE] Fix Sporadic Quantized Ops Errors (#18962)
* [PT FE] Fix sporadics with round & 0 zero_pt, 1.0 scale
* [PT FE] Change scale and round input in quantized cat tests
* [PT FE] Add rounding to conv & linear tests
* Update test_quantized_cat.py
* Update test_quantized_cat.py
* [PT FE] Replace randn with rand for consistency in convnd
This commit is contained in:
parent
2830b20d73
commit
e77070890a
@@ -155,8 +155,8 @@ class PytorchLayerTest:
|
||||
n_is_not_close, max_diff, int(np.log10(cur_fw_res.size)), quant_size + fw_eps))
|
||||
else:
|
||||
print("Accuracy validation successful!\n")
|
||||
print("absolute eps: {}, relative eps: {}".format(
|
||||
fw_eps, fw_eps))
|
||||
print("absolute eps: {}, relative eps: {}, errors: {}".format(
|
||||
fw_eps, fw_eps, n_is_not_close))
|
||||
assert is_ok, "Accuracy validation failed"
|
||||
|
||||
# Each model should specify inputs
|
||||
|
@@ -15,19 +15,19 @@ class quantized_add(torch.nn.Module):
|
||||
self.dtype = dtype
|
||||
|
||||
def forward(self, input_tensor1, input_tensor2):
|
||||
quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
|
||||
quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, self.scale, self.zero_point, self.dtype)
|
||||
quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, 1.0, 0, self.dtype)
|
||||
quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, 1.0, 0, self.dtype)
|
||||
quantized_add = torch.ops.quantized.add(quantized_tensor1, quantized_tensor2, self.scale, self.zero_point)
|
||||
dequantized_tensor = torch.dequantize(quantized_add)
|
||||
return dequantized_tensor
|
||||
|
||||
class TestQuantizedAdd(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),
|
||||
np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32))
|
||||
return (np.round(np.array(5.00 * np.random.rand(10, 10) - 2.50, dtype=np.float32), 4),
|
||||
np.round(np.array(5.00 * np.random.rand(10, 10) - 2.50, dtype=np.float32), 4))
|
||||
|
||||
@pytest.mark.parametrize("scale", [
|
||||
1.0, 0.21, 0.62
|
||||
1.0, 0.21, 0.62, 0.9999
|
||||
])
|
||||
@pytest.mark.parametrize("zero_point", [
|
||||
0, 4, -7
|
||||
|
@@ -15,19 +15,19 @@ class quantized_add_relu(torch.nn.Module):
|
||||
self.dtype = dtype
|
||||
|
||||
def forward(self, input_tensor1, input_tensor2):
|
||||
quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
|
||||
quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, self.scale, self.zero_point, self.dtype)
|
||||
quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, 1.0, 0, self.dtype)
|
||||
quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, 1.0, 0, self.dtype)
|
||||
quantized_add_relu = torch.ops.quantized.add_relu(quantized_tensor1, quantized_tensor2, self.scale, self.zero_point)
|
||||
dequantized_tensor = torch.dequantize(quantized_add_relu)
|
||||
return dequantized_tensor
|
||||
|
||||
class TestQuantizedAddReLU(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),
|
||||
np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32))
|
||||
return (np.round(np.array(5.00 * np.random.rand(10, 10) - 2.50, dtype=np.float32), 4),
|
||||
np.round(np.array(5.00 * np.random.rand(10, 10) - 2.50, dtype=np.float32), 4))
|
||||
|
||||
@pytest.mark.parametrize("scale", [
|
||||
1.0, 0.21, 0.62
|
||||
1.0, 0.21, 0.62, 0.9999
|
||||
])
|
||||
@pytest.mark.parametrize("zero_point", [
|
||||
0, 4, -7
|
||||
|
@@ -15,8 +15,8 @@ class aten_quantized_cat(torch.nn.Module):
|
||||
self.dtype = dtype
|
||||
|
||||
def forward(self, inp):
|
||||
x = torch.quantize_per_tensor(inp, 1.3, 0, self.dtype)
|
||||
y = torch.quantize_per_tensor(inp, 1.0, 1, self.dtype)
|
||||
x = torch.quantize_per_tensor(inp, 1.0, 0, self.dtype)
|
||||
y = torch.quantize_per_tensor(inp, 1.0, 0, self.dtype)
|
||||
return torch.dequantize(torch.ops.quantized.cat([x, y], 1, self.scale, self.zero_point))
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ class aten_add_quantized_cat(torch.nn.Module):
|
||||
|
||||
class TestQuantizedCat(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.random.rand(2, 1, 3).astype(np.float32),)
|
||||
return (np.round(np.random.rand(2, 1, 3).astype(np.float32), 4),)
|
||||
|
||||
@pytest.mark.parametrize("scale", [1.0, 0.3, 1.3])
|
||||
@pytest.mark.parametrize("zero_point", [0, 1])
|
||||
|
@@ -12,7 +12,7 @@ from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
class TestQuantizedConv2D(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(2, 3, 25, 25).astype(np.float32),)
|
||||
return (np.round(np.random.rand(2, 3, 25, 25).astype(np.float32), 4),)
|
||||
|
||||
def create_model(self, weights_shape, strides, pads, dilations, groups, bias, relu, scale, zero_point):
|
||||
class quantized_conv2d(torch.nn.Module):
|
||||
|
@@ -15,17 +15,17 @@ class quantized_hardswish(torch.nn.Module):
|
||||
self.dtype = dtype
|
||||
|
||||
def forward(self, input_tensor1):
|
||||
quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
|
||||
quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, 1.0, 0, self.dtype)
|
||||
quantized_hardswish = torch.ops.quantized.hardswish(quantized_tensor1, self.scale, self.zero_point)
|
||||
dequantized_tensor = torch.dequantize(quantized_hardswish)
|
||||
return dequantized_tensor
|
||||
|
||||
class TestQuantizedHardswish(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),)
|
||||
return (np.round(np.array(5.00 * np.random.rand(10, 10) - 2.50, dtype=np.float32), 4),)
|
||||
|
||||
@pytest.mark.parametrize("scale", [
|
||||
1.0, 0.21, 0.62,
|
||||
1.0, 0.21, 0.62, 0.9999
|
||||
])
|
||||
@pytest.mark.parametrize("zero_point", [
|
||||
0, 4, -7
|
||||
|
@@ -9,7 +9,7 @@ from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
class TestQuantizedLinear(PytorchLayerTest):
|
||||
def _prepare_input(self, input_shape=(2, 2)):
|
||||
return (np.random.randn(*input_shape).astype(np.float32),)
|
||||
return (np.round(np.random.rand(*input_shape).astype(np.float32), 4),)
|
||||
|
||||
def create_model(self, weight_shape, is_bias, scale, zero_point):
|
||||
|
||||
@@ -27,7 +27,7 @@ class TestQuantizedLinear(PytorchLayerTest):
|
||||
self.linear.zero_point = int(zero_point)
|
||||
|
||||
def forward(self, inp):
|
||||
inp_q = torch.quantize_per_tensor(inp, 1., 0, torch.quint8)
|
||||
inp_q = torch.quantize_per_tensor(inp, 1.0, 0, torch.quint8)
|
||||
return torch.dequantize(self.linear(inp_q))
|
||||
|
||||
ref_net = None
|
||||
|
@@ -15,19 +15,19 @@ class quantized_mul(torch.nn.Module):
|
||||
self.dtype = dtype
|
||||
|
||||
def forward(self, input_tensor1, input_tensor2):
|
||||
quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, self.scale, self.zero_point, self.dtype)
|
||||
quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, self.scale, self.zero_point, self.dtype)
|
||||
quantized_tensor1 = torch.quantize_per_tensor(input_tensor1, 1.0, 0, self.dtype)
|
||||
quantized_tensor2 = torch.quantize_per_tensor(input_tensor2, 1.0, 0, self.dtype)
|
||||
quantized_mul = torch.ops.quantized.mul(quantized_tensor1, quantized_tensor2, self.scale, self.zero_point)
|
||||
dequantized_tensor = torch.dequantize(quantized_mul)
|
||||
return dequantized_tensor
|
||||
|
||||
class TestQuantizedMul(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),
|
||||
np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32))
|
||||
return (np.round(np.array(5.00 * np.random.rand(10, 10) - 2.50, dtype=np.float32), 4),
|
||||
np.round(np.array(5.00 * np.random.rand(10, 10) - 2.50, dtype=np.float32), 4))
|
||||
|
||||
@pytest.mark.parametrize("scale", [
|
||||
1.0, 0.21, 0.62
|
||||
1.0, 0.21, 0.62, 0.9999
|
||||
])
|
||||
@pytest.mark.parametrize("zero_point", [
|
||||
0, 4, -7
|
||||
@@ -37,7 +37,7 @@ class TestQuantizedMul(PytorchLayerTest):
|
||||
torch.qint8
|
||||
])
|
||||
@pytest.mark.nightly
|
||||
# @pytest.mark.precommit - accuracy problem
|
||||
@pytest.mark.precommit
|
||||
def test_quantized_mul(self, scale, zero_point, dtype, ie_device, precision, ir_version):
|
||||
if dtype == torch.quint8: zero_point = abs(zero_point)
|
||||
self._test(quantized_mul(scale, zero_point, dtype), None, ["quantized::mul"],
|
||||
|
Loading…
Reference in New Issue
Block a user