[PT FE] Fix aten::repeat regression (#19991)
* Revert "[PT FE] Simplify repeat operation (#19926)"
This reverts commit f926e0e392
.
* Fix aten::repeats regression
* Simplify
* Update src/frontends/pytorch/src/op_table.cpp
* Add impacted model
This commit is contained in:
parent
b1bf16c7cf
commit
058b45e608
@ -393,7 +393,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
{"aten::relu_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Relu>>},
|
||||
{"aten::relu6", op::translate_relu6},
|
||||
{"aten::remainder", op::translate_remainder},
|
||||
{"aten::repeat", op::inplace_op<op::translate_1to1_match_2_inputs<opset10::Tile>>},
|
||||
{"aten::repeat", op::translate_1to1_match_2_inputs<opset10::Tile>},
|
||||
{"aten::repeat_interleave", op::translate_repeat_interleave},
|
||||
{"aten::reshape", op::translate_reshape},
|
||||
{"aten::reshape_as", op::translate_reshape_as},
|
||||
|
@ -32,6 +32,7 @@ class TestRepeat(PytorchLayerTest):
|
||||
def test_repeat(self, repeats, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(repeats), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestRepeatList(PytorchLayerTest):
|
||||
def _prepare_input(self, repeats_shape):
|
||||
import numpy as np
|
||||
@ -54,4 +55,26 @@ class TestRepeatList(PytorchLayerTest):
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_repeat(self, repeats, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={"repeats_shape": repeats})
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version,
|
||||
kwargs_to_prepare_input={"repeats_shape": repeats})
|
||||
|
||||
|
||||
class TestRepeatFromFlanT5(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
import numpy as np
|
||||
return (np.random.randn(1, 15).astype(np.float32),)
|
||||
|
||||
def create_model(self):
|
||||
import torch
|
||||
from transformers.modeling_utils import ModuleUtilsMixin
|
||||
|
||||
class aten_repeat(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return ModuleUtilsMixin.create_extended_attention_mask_for_decoder(x.size(), x)
|
||||
|
||||
return aten_repeat(), None, "aten::repeat"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_repeat_t5(self, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True)
|
||||
|
@ -4,6 +4,7 @@ numpy
|
||||
requests
|
||||
torch
|
||||
torchvision
|
||||
transformers
|
||||
pytest
|
||||
tensorflow-addons; python_version <= '3.10'
|
||||
jax; sys_platform == "linux"
|
||||
|
@ -274,6 +274,7 @@ class TestTransformersModel(TestConvertModel):
|
||||
|
||||
@pytest.mark.parametrize("name,type", [("bert-base-uncased", "bert"),
|
||||
("facebook/bart-large-mnli", "bart"),
|
||||
("google/flan-t5-base","t5"),
|
||||
("gpt2", "gpt2"),
|
||||
("openai/clip-vit-large-patch14", "clip")])
|
||||
@pytest.mark.precommit
|
||||
|
Loading…
Reference in New Issue
Block a user