[PT FE] Fix aten::repeat regression (#19991)
* Revert "[PT FE] Simplify repeat operation (#19926)"
This reverts commit f926e0e392
.
* Fix aten::repeats regression
* Simplify
* Update src/frontends/pytorch/src/op_table.cpp
* Add impacted model
This commit is contained in:
parent
b1bf16c7cf
commit
058b45e608
@ -393,7 +393,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
|||||||
{"aten::relu_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Relu>>},
|
{"aten::relu_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Relu>>},
|
||||||
{"aten::relu6", op::translate_relu6},
|
{"aten::relu6", op::translate_relu6},
|
||||||
{"aten::remainder", op::translate_remainder},
|
{"aten::remainder", op::translate_remainder},
|
||||||
{"aten::repeat", op::inplace_op<op::translate_1to1_match_2_inputs<opset10::Tile>>},
|
{"aten::repeat", op::translate_1to1_match_2_inputs<opset10::Tile>},
|
||||||
{"aten::repeat_interleave", op::translate_repeat_interleave},
|
{"aten::repeat_interleave", op::translate_repeat_interleave},
|
||||||
{"aten::reshape", op::translate_reshape},
|
{"aten::reshape", op::translate_reshape},
|
||||||
{"aten::reshape_as", op::translate_reshape_as},
|
{"aten::reshape_as", op::translate_reshape_as},
|
||||||
|
@ -32,6 +32,7 @@ class TestRepeat(PytorchLayerTest):
|
|||||||
def test_repeat(self, repeats, ie_device, precision, ir_version):
|
def test_repeat(self, repeats, ie_device, precision, ir_version):
|
||||||
self._test(*self.create_model(repeats), ie_device, precision, ir_version)
|
self._test(*self.create_model(repeats), ie_device, precision, ir_version)
|
||||||
|
|
||||||
|
|
||||||
class TestRepeatList(PytorchLayerTest):
|
class TestRepeatList(PytorchLayerTest):
|
||||||
def _prepare_input(self, repeats_shape):
|
def _prepare_input(self, repeats_shape):
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -54,4 +55,26 @@ class TestRepeatList(PytorchLayerTest):
|
|||||||
@pytest.mark.nightly
|
@pytest.mark.nightly
|
||||||
@pytest.mark.precommit
|
@pytest.mark.precommit
|
||||||
def test_repeat(self, repeats, ie_device, precision, ir_version):
|
def test_repeat(self, repeats, ie_device, precision, ir_version):
|
||||||
self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={"repeats_shape": repeats})
|
self._test(*self.create_model(), ie_device, precision, ir_version,
|
||||||
|
kwargs_to_prepare_input={"repeats_shape": repeats})
|
||||||
|
|
||||||
|
|
||||||
|
class TestRepeatFromFlanT5(PytorchLayerTest):
|
||||||
|
def _prepare_input(self):
|
||||||
|
import numpy as np
|
||||||
|
return (np.random.randn(1, 15).astype(np.float32),)
|
||||||
|
|
||||||
|
def create_model(self):
|
||||||
|
import torch
|
||||||
|
from transformers.modeling_utils import ModuleUtilsMixin
|
||||||
|
|
||||||
|
class aten_repeat(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return ModuleUtilsMixin.create_extended_attention_mask_for_decoder(x.size(), x)
|
||||||
|
|
||||||
|
return aten_repeat(), None, "aten::repeat"
|
||||||
|
|
||||||
|
@pytest.mark.nightly
|
||||||
|
@pytest.mark.precommit
|
||||||
|
def test_repeat_t5(self, ie_device, precision, ir_version):
|
||||||
|
self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True)
|
||||||
|
@ -4,6 +4,7 @@ numpy
|
|||||||
requests
|
requests
|
||||||
torch
|
torch
|
||||||
torchvision
|
torchvision
|
||||||
|
transformers
|
||||||
pytest
|
pytest
|
||||||
tensorflow-addons; python_version <= '3.10'
|
tensorflow-addons; python_version <= '3.10'
|
||||||
jax; sys_platform == "linux"
|
jax; sys_platform == "linux"
|
||||||
|
@ -274,6 +274,7 @@ class TestTransformersModel(TestConvertModel):
|
|||||||
|
|
||||||
@pytest.mark.parametrize("name,type", [("bert-base-uncased", "bert"),
|
@pytest.mark.parametrize("name,type", [("bert-base-uncased", "bert"),
|
||||||
("facebook/bart-large-mnli", "bart"),
|
("facebook/bart-large-mnli", "bart"),
|
||||||
|
("google/flan-t5-base","t5"),
|
||||||
("gpt2", "gpt2"),
|
("gpt2", "gpt2"),
|
||||||
("openai/clip-vit-large-patch14", "clip")])
|
("openai/clip-vit-large-patch14", "clip")])
|
||||||
@pytest.mark.precommit
|
@pytest.mark.precommit
|
||||||
|
Loading…
Reference in New Issue
Block a user