[PT FE]: fix constant folding for dequantization (#18190)

* [PT FE]: fix constant folding for dequantization

* add test
Ekaterina Aidova
2023-06-23 08:41:32 +04:00
committed by GitHub
parent d13adf7ae8
commit df0bd18ed2
4 changed files with 103 additions and 1 deletions
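Note: the file shown below is only the new regression test; the dequantization fix itself lives in the frontend sources among the four changed files. As a rough sketch of the kind of pattern involved (the module name and quantization parameters here are hypothetical, not taken from this commit), an aten::dequantize fed by a quantized constant can be built like this:

import torch

class DequantFromConstant(torch.nn.Module):
    # Hypothetical example: a quantized constant dequantized at runtime.
    def __init__(self):
        super().__init__()
        self.weight = torch.quantize_per_tensor(
            torch.randn(3, 3), scale=0.1, zero_point=0, dtype=torch.qint8
        )

    def forward(self, x):
        # How constant folding treats this dequantize subgraph during
        # conversion is what the commit adjusts.
        return x + self.weight.dequantize()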


@@ -54,4 +54,30 @@ class TestTupleConstruct(PytorchLayerTest):
    @pytest.mark.parametrize("case", ["single", "multiple", "none", "list", "list_and_tuple"])
    @pytest.mark.nightly
    def test_tuple_construct(self, case, ie_device, precision, ir_version):
        self._test(*self.create_model(case), ie_device, precision, ir_version)


class TestTupleConstructTupleUnpack(PytorchLayerTest):
    def _prepare_input(self):
        return (np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32),)

    def create_model(self):
        import torch

        class prim_tuple_construct_tuple_unpack(torch.nn.Module):
            def forward(self, x):
                x1, x2, x3, x4, x5 = self.prepare_input(x)
                return x1, x2, x3, x4, x5

            def prepare_input(self, x):
                return x, x + 2, None, x.reshape(-1), (x * 10).to(torch.int32)

        ref_net = None
        return prim_tuple_construct_tuple_unpack(), ref_net, ["prim::TupleConstruct", "prim::TupleUnpack"]

    @pytest.mark.nightly
    def test_tuple_construct_unpack(self, ie_device, precision, ir_version):
        self._test(*self.create_model(), ie_device, precision, ir_version, freeze_model=False)
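A minimal sketch (not part of the commit) of why the test passes freeze_model=False: assuming the test harness applies torch.jit.freeze when freezing is enabled, freezing inlines prepare_input and can simplify away the matched TupleConstruct/TupleUnpack pair the test asserts on, so the unfrozen scripted graph is used instead:

import torch

class prim_tuple_construct_tuple_unpack(torch.nn.Module):
    def forward(self, x):
        x1, x2, x3, x4, x5 = self.prepare_input(x)
        return x1, x2, x3, x4, x5

    def prepare_input(self, x):
        return x, x + 2, None, x.reshape(-1), (x * 10).to(torch.int32)

scripted = torch.jit.script(prim_tuple_construct_tuple_unpack())
# The unfrozen graph still contains the ops the test looks for:
print(scripted.graph)  # shows prim::TupleUnpack and prim::TupleConstruct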