[PT FE] Fix aten::repeat regression (#19991)

* Revert "[PT FE] Simplify repeat operation (#19926)"

This reverts commit f926e0e392.

* Fix aten::repeat regression

* Simplify

* Update src/frontends/pytorch/src/op_table.cpp

* Add impacted model
Maxim Vafin 2023-09-21 23:58:09 +02:00 committed by GitHub
parent b1bf16c7cf
commit 058b45e608
4 changed files with 27 additions and 2 deletions


@@ -393,7 +393,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"aten::relu_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Relu>>},
         {"aten::relu6", op::translate_relu6},
         {"aten::remainder", op::translate_remainder},
-        {"aten::repeat", op::inplace_op<op::translate_1to1_match_2_inputs<opset10::Tile>>},
+        {"aten::repeat", op::translate_1to1_match_2_inputs<opset10::Tile>},
         {"aten::repeat_interleave", op::translate_repeat_interleave},
         {"aten::reshape", op::translate_reshape},
         {"aten::reshape_as", op::translate_reshape_as},

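The net change in op_table.cpp drops the op::inplace_op wrapper, so aten::repeat is mapped straight onto opset10::Tile. That matches PyTorch semantics: Tensor.repeat tiles the data into a new tensor and never modifies its input in place. A minimal sketch (not part of the commit) illustrating that behaviour:

import torch

x = torch.arange(3)      # tensor([0, 1, 2])
y = x.repeat(2)          # aten::repeat: tile the data along dim 0
print(y)                 # tensor([0, 1, 2, 0, 1, 2])
print(x)                 # tensor([0, 1, 2]) -- unchanged, repeat is not an in-place op
print(x.data_ptr() == y.data_ptr())  # False: y owns new storage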

@@ -32,6 +32,7 @@ class TestRepeat(PytorchLayerTest):
     def test_repeat(self, repeats, ie_device, precision, ir_version):
         self._test(*self.create_model(repeats), ie_device, precision, ir_version)
 
+
 class TestRepeatList(PytorchLayerTest):
     def _prepare_input(self, repeats_shape):
         import numpy as np
@@ -54,4 +55,26 @@ class TestRepeatList(PytorchLayerTest):
     @pytest.mark.nightly
     @pytest.mark.precommit
     def test_repeat(self, repeats, ie_device, precision, ir_version):
-        self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={"repeats_shape": repeats})
+        self._test(*self.create_model(), ie_device, precision, ir_version,
+                   kwargs_to_prepare_input={"repeats_shape": repeats})
+
+
+class TestRepeatFromFlanT5(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randn(1, 15).astype(np.float32),)
+
+    def create_model(self):
+        import torch
+        from transformers.modeling_utils import ModuleUtilsMixin
+
+        class aten_repeat(torch.nn.Module):
+            def forward(self, x):
+                return ModuleUtilsMixin.create_extended_attention_mask_for_decoder(x.size(), x)
+
+        return aten_repeat(), None, "aten::repeat"
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_repeat_t5(self, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True)

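The new TestRepeatFromFlanT5 case reproduces the regression through transformers' ModuleUtilsMixin.create_extended_attention_mask_for_decoder, where the repeat counts are taken from a traced x.size() rather than being Python constants. A rough, hypothetical reproducer of that pattern (an assumption for illustration, not the transformers source) looks like this:

import torch

class RepeatFromTracedSize(torch.nn.Module):
    # Hypothetical reproducer (assumption): the repeat counts come from
    # x.size() during tracing, so aten::repeat receives dynamic values
    # instead of a constant list -- the pattern FLAN-T5's decoder mask hits.
    def forward(self, x):
        batch, length = x.size()
        seq_ids = torch.arange(length)
        return seq_ids[None, None, :].repeat(batch, length, 1)

traced = torch.jit.trace(RepeatFromTracedSize(), torch.randn(1, 15))
print(traced(torch.randn(1, 15)).shape)  # torch.Size([1, 15, 15])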

@@ -4,6 +4,7 @@ numpy
 requests
 torch
 torchvision
+transformers
 pytest
 tensorflow-addons; python_version <= '3.10'
 jax; sys_platform == "linux"


@@ -274,6 +274,7 @@ class TestTransformersModel(TestConvertModel):
     @pytest.mark.parametrize("name,type", [("bert-base-uncased", "bert"),
                                            ("facebook/bart-large-mnli", "bart"),
+                                           ("google/flan-t5-base","t5"),
                                            ("gpt2", "gpt2"),
                                            ("openai/clip-vit-large-patch14", "clip")])
     @pytest.mark.precommit
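
The last hunk registers google/flan-t5-base as an impacted model in the transformers model-hub precommit list. Outside the test harness, one hedged way to sanity-check the fix on that model is to convert it with openvino.convert_model; the example inputs below are assumptions for illustration, not the harness code:

import torch
import openvino as ov
from transformers import AutoTokenizer, T5ForConditionalGeneration

tok = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base").eval()

enc = tok("translate English to German: hello", return_tensors="pt")
example = {
    "input_ids": enc["input_ids"],
    "attention_mask": enc["attention_mask"],
    "decoder_input_ids": torch.tensor([[model.config.decoder_start_token_id]]),
}
# The decoder mask path of this model exercises aten::repeat during tracing,
# so a regression in that translator shows up in this conversion.
ov_model = ov.convert_model(model, example_input=example)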