77 lines
3.2 KiB
Python
77 lines
3.2 KiB
Python
# Copyright (C) 2018-2023 Intel Corporation
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import pytest
|
|
from pytorch_layer_test_class import PytorchLayerTest
|
|
import numpy as np
|
|
import random
|
|
import torch
|
|
|
|
|
|
@pytest.mark.parametrize('input_data', ({'repeats': 1, 'dim': 0},
                                        {'repeats': 2, 'dim': 2},
                                        {'repeats': [2, 3], 'dim': 1},
                                        {'repeats': [3, 2, 1], 'dim': 3},
                                        {'repeats': [3, 2, 1], 'dim': 2},
                                        {'repeats': 2, 'dim': None},
                                        # dim=None flattens the (2, 2, 3, 3) input to 36 elements,
                                        # so one repeat count per element is needed. A seeded RNG
                                        # keeps this case identical from run to run.
                                        {'repeats': random.Random(42).choices(range(1, 6), k=36),
                                         'dim': None}))
class TestRepeatInterleaveConstRepeats(PytorchLayerTest):
    """Test ``aten::repeat_interleave`` where ``repeats`` is a constant baked into the model.

    Every parametrized case must agree with the input shape produced by
    ``_prepare_input``: a list of repeats has one entry per element along
    ``dim`` (or per element of the flattened tensor when ``dim`` is None).
    """

    def _prepare_input(self):
        # Sizes along dims 0..3 are 2, 2, 3, 3 (36 elements total); the
        # parametrized repeats lists above rely on these sizes.
        return (np.random.randn(2, 2, 3, 3),)

    def create_model_const_repeat(self, repeats, dim):
        """Build a module that stores ``repeats`` as an int32 tensor attribute.

        :param repeats: int or list of ints used as the repeat counts.
        :param dim: axis to repeat along, or None to operate on the flattened input.
        :return: tuple of (module, ref_net, op name); ref_net is always None here.
        """

        class aten_repeat_interleave_const_repeat(torch.nn.Module):

            def __init__(self) -> None:
                super().__init__()
                # Held on the module (not passed to forward), so repeats is a
                # constant of the model rather than a runtime input.
                self.repeats = torch.tensor(repeats, dtype=torch.int)
                self.dim = dim

            def forward(self, input_tensor):
                return input_tensor.repeat_interleave(self.repeats, self.dim)

        ref_net = None

        return aten_repeat_interleave_const_repeat(), ref_net, "aten::repeat_interleave"

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_repeat_interleave_const_repeats(self, ie_device, precision, ir_version, input_data):
        repeats = input_data['repeats']
        dim = input_data['dim']
        self._test(*self.create_model_const_repeat(repeats, dim),
                   ie_device, precision, ir_version)
|
|
|
|
@pytest.mark.parametrize('input_data', (
    {'repeats': np.array([1], dtype=np.int32), 'dim': 0},
    {'repeats': np.array(1, dtype=np.int32), 'dim': 1},
    {'repeats': np.array([2], dtype=np.int32), 'dim': 2},
    {'repeats': np.array(2, dtype=np.int32), 'dim': 1},
    {'repeats': np.array([3], dtype=np.int32), 'dim': None},
))
class TestRepeatInterleaveNonConstRepeats(PytorchLayerTest):
    """Test ``aten::repeat_interleave`` where ``repeats`` is fed as a second model input.

    The int32 ``repeats`` array for the current case is stashed on the test
    instance by the test method and handed out by ``_prepare_input``.
    """

    def _prepare_input(self):
        # Second input is whatever the test method stored in self.repeats.
        data = np.random.randn(2, 2, 3, 3)
        return data, self.repeats

    def create_model_non_const_repeat(self, dim):
        """Build a module whose ``forward`` receives ``repeats`` at runtime.

        :param dim: axis to repeat along, or None to operate on the flattened input.
        :return: tuple of (module, ref_net, op name); ref_net is always None here.
        """

        class aten_repeat_interleave_non_const_repeat(torch.nn.Module):

            def __init__(self) -> None:
                super().__init__()
                self.dim = dim

            def forward(self, input_tensor, repeats):
                # repeats arrives as a model input rather than a stored constant.
                return input_tensor.repeat_interleave(repeats, self.dim)

        model = aten_repeat_interleave_non_const_repeat()
        return model, None, "aten::repeat_interleave"

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_repeat_interleave_non_const_repeats(self, ie_device, precision, ir_version, input_data):
        # Must be assigned before _test, which presumably calls _prepare_input.
        self.repeats = input_data['repeats']
        model, ref_net, op_kind = self.create_model_non_const_repeat(input_data['dim'])
        self._test(model, ref_net, op_kind, ie_device, precision, ir_version,
                   dynamic_shapes=False, use_mo_convert=False)
|