update pytorch layer tests for torch 2.1 compatibility (#20264)
* update pytorch layer tests for torch 2.1 compatibility
This commit is contained in:
@@ -23,4 +23,4 @@ pytest-html==3.2.0
|
||||
pytest-timeout==2.1.0
|
||||
jax<=0.4.14
|
||||
jaxlib<=0.4.14
|
||||
torch<2.1.0,>=1.13
|
||||
torch>=1.13
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from packaging.version import parse as parse_version
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
@@ -48,7 +50,7 @@ class TestMaskedFill(PytorchLayerTest):
|
||||
@pytest.mark.parametrize(
|
||||
"mask_fill", ['zeros', 'ones', 'random'])
|
||||
@pytest.mark.parametrize("input_dtype", [np.float32, np.float64, int, np.int32])
|
||||
@pytest.mark.parametrize("mask_dtype", [np.uint8, np.int32, bool]) # np.float32 incorrectly casted to bool
|
||||
@pytest.mark.parametrize("mask_dtype", [bool]) # np.float32 incorrectly casted to bool
|
||||
@pytest.mark.parametrize("inplace", [True, False])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@@ -56,3 +58,17 @@ class TestMaskedFill(PytorchLayerTest):
|
||||
self._test(*self.create_model(value, inplace),
|
||||
ie_device, precision, ir_version,
|
||||
kwargs_to_prepare_input={'mask_fill': mask_fill, 'mask_dtype': mask_dtype, "input_dtype": input_dtype})
|
||||
|
||||
@pytest.mark.skipif(parse_version(torch.__version__) >= parse_version("2.1.0"), reason="pytorch 2.1 and above does not support nonboolean mask")
|
||||
@pytest.mark.parametrize("value", [0.0, 1.0, -1.0, 2])
|
||||
@pytest.mark.parametrize(
|
||||
"mask_fill", ['zeros', 'ones', 'random'])
|
||||
@pytest.mark.parametrize("input_dtype", [np.float32, np.float64, int, np.int32])
|
||||
@pytest.mark.parametrize("mask_dtype", [np.uint8, np.int32]) # np.float32 incorrectly casted to bool
|
||||
@pytest.mark.parametrize("inplace", [True, False])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_masked_fill_non_bool_mask(self, value, mask_fill, mask_dtype, input_dtype, inplace, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(value, inplace),
|
||||
ie_device, precision, ir_version,
|
||||
kwargs_to_prepare_input={'mask_fill': mask_fill, 'mask_dtype': mask_dtype, "input_dtype": input_dtype})
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from packaging.version import parse as parse_version
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
@@ -45,7 +47,7 @@ class TestMaskedScatter(PytorchLayerTest):
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.parametrize("shape", [[2, 5], [10, 10], [2, 3, 4], [10, 5, 10, 3], [2, 6, 4, 1]])
|
||||
@pytest.mark.parametrize("input_dtype", ["float32", "int32", "float", "int", "uint8"])
|
||||
@pytest.mark.parametrize("mask_dtype", ["bool", "uint8"])
|
||||
@pytest.mark.parametrize("mask_dtype", ["bool"])
|
||||
@pytest.mark.parametrize("out", [True, False])
|
||||
def test_masked_scatter(self, shape, input_dtype, mask_dtype, out, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(out), ie_device, precision, ir_version,
|
||||
@@ -55,7 +57,28 @@ class TestMaskedScatter(PytorchLayerTest):
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.parametrize("shape", [[2, 5], [10, 10], [2, 3, 4], [10, 5, 10, 3], [2, 6, 4, 1]])
|
||||
@pytest.mark.parametrize("input_dtype", ["float32", "int32", "float", "int", "uint8"])
|
||||
@pytest.mark.parametrize("mask_dtype", ["bool", "uint8"])
|
||||
@pytest.mark.parametrize("mask_dtype", ["bool"])
|
||||
def test_masked_scatter_inplace(self, shape, input_dtype, mask_dtype, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(inplace=True), ie_device, precision, ir_version,
|
||||
kwargs_to_prepare_input={"shape": shape, "x_dtype": input_dtype, "mask_dtype": mask_dtype})
|
||||
|
||||
@pytest.mark.skipif(parse_version(torch.__version__) >= parse_version("2.1.0"), reason="pytorch 2.1 and above does not support nonboolean mask")
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.parametrize("shape", [[2, 5], [10, 10], [2, 3, 4], [10, 5, 10, 3], [2, 6, 4, 1]])
|
||||
@pytest.mark.parametrize("input_dtype", ["float32", "int32", "float", "int", "uint8"])
|
||||
@pytest.mark.parametrize("mask_dtype", ["uint8"])
|
||||
@pytest.mark.parametrize("out", [True, False])
|
||||
def test_masked_scatter_u8(self, shape, input_dtype, mask_dtype, out, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(out), ie_device, precision, ir_version,
|
||||
kwargs_to_prepare_input={"shape": shape, "x_dtype": input_dtype, "mask_dtype": mask_dtype, "out": out})
|
||||
|
||||
@pytest.mark.skipif(parse_version(torch.__version__) >= parse_version("2.1.0"), reason="pytorch 2.1 and above does not support nonboolean mask")
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.parametrize("shape", [[2, 5], [10, 10], [2, 3, 4], [10, 5, 10, 3], [2, 6, 4, 1]])
|
||||
@pytest.mark.parametrize("input_dtype", ["float32", "int32", "float", "int", "uint8"])
|
||||
@pytest.mark.parametrize("mask_dtype", ["uint8"])
|
||||
def test_masked_scatter_inplace_u8(self, shape, input_dtype, mask_dtype, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(inplace=True), ie_device, precision, ir_version,
|
||||
kwargs_to_prepare_input={"shape": shape, "x_dtype": input_dtype, "mask_dtype": mask_dtype})
|
||||
@@ -2,7 +2,7 @@
|
||||
numpy
|
||||
pytest
|
||||
pytest-html
|
||||
torch
|
||||
torch<2.1
|
||||
torchvision
|
||||
av
|
||||
transformers
|
||||
|
||||
Reference in New Issue
Block a user