diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index f475c2cd186..9cb68a3ea5c 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -393,7 +393,7 @@ const std::map get_supported_ops_ts() { {"aten::relu_", op::inplace_op>}, {"aten::relu6", op::translate_relu6}, {"aten::remainder", op::translate_remainder}, - {"aten::repeat", op::inplace_op>}, + {"aten::repeat", op::translate_1to1_match_2_inputs}, {"aten::repeat_interleave", op::translate_repeat_interleave}, {"aten::reshape", op::translate_reshape}, {"aten::reshape_as", op::translate_reshape_as}, diff --git a/tests/layer_tests/pytorch_tests/test_repeat.py b/tests/layer_tests/pytorch_tests/test_repeat.py index 71c79c32d81..45263366c76 100644 --- a/tests/layer_tests/pytorch_tests/test_repeat.py +++ b/tests/layer_tests/pytorch_tests/test_repeat.py @@ -32,6 +32,7 @@ class TestRepeat(PytorchLayerTest): def test_repeat(self, repeats, ie_device, precision, ir_version): self._test(*self.create_model(repeats), ie_device, precision, ir_version) + class TestRepeatList(PytorchLayerTest): def _prepare_input(self, repeats_shape): import numpy as np @@ -54,4 +55,26 @@ class TestRepeatList(PytorchLayerTest): @pytest.mark.nightly @pytest.mark.precommit def test_repeat(self, repeats, ie_device, precision, ir_version): - self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={"repeats_shape": repeats}) + self._test(*self.create_model(), ie_device, precision, ir_version, + kwargs_to_prepare_input={"repeats_shape": repeats}) + + +class TestRepeatFromFlanT5(PytorchLayerTest): + def _prepare_input(self): + import numpy as np + return (np.random.randn(1, 15).astype(np.float32),) + + def create_model(self): + import torch + from transformers.modeling_utils import ModuleUtilsMixin + + class aten_repeat(torch.nn.Module): + def forward(self, x): + return ModuleUtilsMixin.create_extended_attention_mask_for_decoder(x.size(), x) + + return aten_repeat(), None, "aten::repeat" + + @pytest.mark.nightly + @pytest.mark.precommit + def test_repeat_t5(self, ie_device, precision, ir_version): + self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True) diff --git a/tests/layer_tests/requirements.txt b/tests/layer_tests/requirements.txt index 30a92216833..b1b76c54b92 100644 --- a/tests/layer_tests/requirements.txt +++ b/tests/layer_tests/requirements.txt @@ -4,6 +4,7 @@ numpy requests torch torchvision +transformers pytest tensorflow-addons; python_version <= '3.10' jax; sys_platform == "linux" diff --git a/tests/model_hub_tests/torch_tests/test_transformers.py b/tests/model_hub_tests/torch_tests/test_transformers.py index f4cea932d4a..1df75b45502 100644 --- a/tests/model_hub_tests/torch_tests/test_transformers.py +++ b/tests/model_hub_tests/torch_tests/test_transformers.py @@ -274,6 +274,7 @@ class TestTransformersModel(TestConvertModel): @pytest.mark.parametrize("name,type", [("bert-base-uncased", "bert"), ("facebook/bart-large-mnli", "bart"), + ("google/flan-t5-base","t5"), ("gpt2", "gpt2"), ("openai/clip-vit-large-patch14", "clip")]) @pytest.mark.precommit