[PT FE]: fix constant folding for dequantization (#18190)
* [PT FE]: fix constant folding for dequantization * add test
This commit is contained in:
@@ -54,4 +54,30 @@ class TestTupleConstruct(PytorchLayerTest):
|
||||
@pytest.mark.parametrize("case", ["single", "multiple", "none", "list", "list_and_tuple"])
|
||||
@pytest.mark.nightly
|
||||
def test_tuple_construct(self, case, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(case), ie_device, precision, ir_version)
|
||||
self._test(*self.create_model(case), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestTupleConstructTupleUnpack(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32),)
|
||||
|
||||
def create_model(self):
|
||||
import torch
|
||||
|
||||
class prim_tuple_construct_tuple_unpack(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
x1, x2, x3, x4, x5 = self.prepare_input(x)
|
||||
return x1, x2, x3, x4, x5
|
||||
|
||||
def prepare_input(self, x):
|
||||
return x, x + 2, None, x.reshape(-1), (x * 10).to(torch.int32)
|
||||
|
||||
|
||||
ref_net = None
|
||||
|
||||
return prim_tuple_construct_tuple_unpack(), ref_net, ["prim::TupleConstruct", "prim::TupleUnpack"]
|
||||
|
||||
@pytest.mark.nightly
|
||||
def test_tuple_construct_unpack(self, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version, freeze_model=False)
|
||||
Reference in New Issue
Block a user