Files
openvino/tests/layer_tests/pytorch_tests/test_index.py
Ekaterina Aidova d407bc1b3b [PT FE] fix invalid reshape shape after aten::index (#16821)
* [PT FE] fix invalid reshape shape after aten::index

* support aten::index_select
2023-04-11 12:41:59 +03:00

128 lines
4.4 KiB
Python

# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
import numpy as np
from pytorch_layer_test_class import PytorchLayerTest
class TestIndex(PytorchLayerTest):
    """Layer tests for ``aten::index`` with integer and boolean index inputs.

    Every test builds a small ``torch.nn.Module`` that indexes its first
    input with its second one and passes it, together with the expected op
    kind ``"aten::index"``, to the inherited ``self._test`` runner.
    """

    def _prepare_input(self, input_shape, idx):
        """Return the model inputs for one test case.

        Args:
            input_shape: Shape of the random float32 data tensor to generate.
            idx: Index array; returned unchanged as the second input.

        Returns:
            A ``(data, idx)`` tuple.
        """
        return (np.random.randn(*input_shape).astype(np.float32), idx)

    def create_model(self, model="list"):
        """Create the module under test.

        Args:
            model: One of ``"list"``, ``"getitem"``, ``"list_with_bool"`` or
                ``"getitem_with_bool"``. The ``*_with_bool`` variants cast the
                index to ``torch.bool`` before indexing.

        Returns:
            A ``(module_instance, None, "aten::index")`` tuple; the middle
            element is the (unused) reference network.

        Raises:
            KeyError: If ``model`` is not one of the names listed above.
        """
        import torch

        class aten_index_list(torch.nn.Module):
            def forward(self, x, idx):
                return x[idx]

        class aten_index_getitem(torch.nn.Module):
            def forward(self, x, idx):
                return x.__getitem__(idx)

        class aten_index_list_bool(torch.nn.Module):
            def forward(self, x, idx):
                return x[idx.to(torch.bool)]

        class aten_index_getitem_bool(torch.nn.Module):
            def forward(self, x, idx):
                return x.__getitem__(idx.to(torch.bool))

        cases = {
            "list": aten_index_list,
            "getitem": aten_index_getitem,
            "list_with_bool": aten_index_list_bool,
            "getitem_with_bool": aten_index_getitem_bool,
        }

        aten_index = cases[model]
        ref_net = None

        return aten_index(), ref_net, "aten::index"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("case", ["list", "getitem"])
    @pytest.mark.parametrize(("input_shape", "idx"), [
        ((1,), np.array(0).astype(int)),
        ([2, 3], np.array(-1).astype(int)),
        ([4, 5, 6], np.array((1, 2)).astype(int)),
        ([7, 8, 9], np.array((-1, 2, -3)).astype(int)),
        ([2, 2, 3, 4], np.array((1,)).astype(int))])
    def test_index(self, input_shape, idx, case, ie_device, precision, ir_version):
        """Index a random tensor with an integer index array."""
        self._test(*self.create_model(case), ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("case", ["getitem_with_bool", "list_with_bool"])
    @pytest.mark.parametrize(("input_shape", "idx"), [
        ((1, 2), np.array([[1, 0]]).astype(bool)),
        ((2, 2, 5), np.zeros([2, 2, 5]).astype(bool)),
        ((2, 2, 5), np.ones([2, 2, 5]).astype(bool)),
        # Samples lie in [0, 1), so a "> 0" threshold would yield an all-True
        # mask that merely repeats the np.ones case above. Thresholding at 0.5
        # gives a mask mixing True and False, and the fixed seed keeps it
        # identical between runs so failures are reproducible.
        ((2, 2, 5), np.random.default_rng(0).random((2, 2, 5)) > 0.5)
    ])
    def test_index_bool(self, input_shape, idx, case, ie_device, precision, ir_version):
        """Index a random tensor with a boolean mask of the same shape."""
        self._test(*self.create_model(case), ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})
class TestIndexRange(PytorchLayerTest):
    """Layer tests for ``aten::index`` where the first index is ``torch.arange``.

    The modules reshape their input and then index it with a pair
    ``(torch.arange(rows), y)``; both are run through the inherited
    ``self._test`` with tracing enabled and dynamic shapes disabled.
    """

    def _prepare_input(self, input_shape, idx):
        """Return ``(data, indices)``: random float32 data and int32 indices.

        Args:
            input_shape: Shape of the random data tensor.
            idx: Sequence of integer indices, converted to an int32 array.
        """
        data = np.random.randn(*input_shape).astype(np.float32)
        indices = np.array(idx).astype(np.int32)
        return (data, indices)

    def create_model(self):
        """Build a module that flattens the input to 2D before indexing.

        Returns:
            A ``(module_instance, None, "aten::index")`` tuple.
        """
        import torch

        class aten_index_arange(torch.nn.Module):
            def forward(self, x, y):
                # Collapse every trailing dimension into one.
                flat = x.reshape(x.shape[0], -1)
                rows = torch.arange(flat.shape[0])
                return flat[rows, y]

        return aten_index_arange(), None, "aten::index"

    def create_model2(self):
        """Build a module that reshapes the input to 4D before indexing.

        Only the first two dimensions are indexed, so the last two remain
        free in the result.

        Returns:
            A ``(module_instance, None, "aten::index")`` tuple.
        """
        import torch

        class aten_index_arange(torch.nn.Module):
            def forward(self, x, y):
                # Keep the first two dims, fold the rest, append a unit dim.
                folded = x.reshape(x.shape[0], x.shape[1], -1, 1)
                rows = torch.arange(folded.shape[0])
                return folded[rows, y]

        return aten_index_arange(), None, "aten::index"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize(("input_shape", "idx"), [
        ((1, 1), [0]),
        ([2, 3], [1, 2]),
        ([7, 8, 9], [1]),
        ([2, 2, 3, 4], [0]),
    ])
    def test_index_range(self, input_shape, idx, ie_device, precision, ir_version):
        """Run the 2D-reshape model on each shape/index pair."""
        module, ref_net, op_kind = self.create_model()
        prepare_kwargs = {"input_shape": input_shape, "idx": idx}
        self._test(module, ref_net, op_kind, ie_device, precision, ir_version,
                   kwargs_to_prepare_input=prepare_kwargs,
                   trace_model=True, dynamic_shapes=False)

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize(("input_shape", "idx"), [
        ((1, 1), [0]),
        ([2, 3], [1, 2]),
        ([7, 8, 9], [1]),
        ([2, 2, 3, 4], [0]),
    ])
    def test_index_range_free_dims(self, input_shape, idx, ie_device, precision, ir_version):
        """Run the 4D-reshape model on each shape/index pair."""
        module, ref_net, op_kind = self.create_model2()
        prepare_kwargs = {"input_shape": input_shape, "idx": idx}
        self._test(module, ref_net, op_kind, ie_device, precision, ir_version,
                   kwargs_to_prepare_input=prepare_kwargs,
                   trace_model=True, dynamic_shapes=False)