[PT FE] Add aten::swapaxes (#19483)
* Add aten::swapaxes * Add comment * Improve swapaxes tests
This commit is contained in:
committed by
GitHub
parent
511f06f9ba
commit
c46f6bf115
@@ -427,6 +427,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
{"aten::sub", op::translate_sub},
|
||||
{"aten::sub_", op::inplace_op<op::translate_sub>},
|
||||
{"aten::sum", op::translate_sum},
|
||||
{"aten::swapaxes", op::quantizable_op<op::translate_transpose>},
|
||||
{"aten::t", op::translate_t},
|
||||
{"aten::t_", op::inplace_op<op::translate_t>},
|
||||
{"aten::tan", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Tan>},
|
||||
|
||||
@@ -1,52 +1,63 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestTranspose(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
import numpy as np
|
||||
return (np.random.randn(2, 3, 4, 5).astype(np.float32),)
|
||||
|
||||
def create_model(self, dim0, dim1):
|
||||
import torch
|
||||
|
||||
class aten_transpose(torch.nn.Module):
|
||||
def create_model(self, dim0, dim1, op_type):
|
||||
class swapaxes(torch.nn.Module):
|
||||
def __init__(self, dim0, dim1):
|
||||
super(aten_transpose, self).__init__()
|
||||
super().__init__()
|
||||
self.dim0 = dim0
|
||||
self.dim1 = dim1
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.swapaxes(x, self.dim0, self.dim1)
|
||||
|
||||
class aten_transpose(torch.nn.Module):
|
||||
def __init__(self, dim0, dim1, op_type):
|
||||
super(aten_transpose, self).__init__()
|
||||
self.dim0 = dim0
|
||||
self.dim1 = dim1
|
||||
op_types = {"transpose": self.forward_transpose, "swapaxes": self.forward_swapaxes}
|
||||
self.swapaxes = swapaxes(dim0, dim1)
|
||||
self.forward = op_types.get(op_type)
|
||||
|
||||
def forward_transpose(self, x):
|
||||
return torch.transpose(x, self.dim0, self.dim1)
|
||||
|
||||
def forward_swapaxes(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# To reproduce aten::swapaxes in graph, swapaxes need to be in separate graph and tracing need to be used.
|
||||
return self.swapaxes(x)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_transpose(dim0, dim1), ref_net, "aten::transpose"
|
||||
return aten_transpose(dim0, dim1, op_type), ref_net, f"aten::{op_type}"
|
||||
|
||||
@pytest.mark.parametrize("dim0", [0, 1, 2, 3, -1, -2, -3, -4])
|
||||
@pytest.mark.parametrize("dim1", [0, 1, 2, 3, -1, -2, -3, -4])
|
||||
@pytest.mark.parametrize("op_type", ["transpose", "swapaxes"])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_transpose(self, dim0, dim1, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(dim0, dim1),
|
||||
ie_device, precision, ir_version)
|
||||
def test_transpose(self, dim0, dim1, op_type, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(dim0, dim1, op_type), ie_device, precision, ir_version, trace_model=True)
|
||||
|
||||
|
||||
class TestTSmall(PytorchLayerTest):
|
||||
def _prepare_input(self, num_dims=2, input_dtype="float32"):
|
||||
import numpy as np
|
||||
shape = (2, 3)
|
||||
if num_dims == 0:
|
||||
return (np.array(num_dims).astype(input_dtype), )
|
||||
return (np.array(num_dims).astype(input_dtype),)
|
||||
return (np.random.randn(*shape[:num_dims]).astype(input_dtype),)
|
||||
|
||||
def create_model(self, num_dims=2, inplace=False):
|
||||
import torch
|
||||
|
||||
class aten_transpose(torch.nn.Module):
|
||||
def __init__(self, inplace):
|
||||
super(aten_transpose, self).__init__()
|
||||
@@ -61,7 +72,7 @@ class TestTSmall(PytorchLayerTest):
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_transpose(inplace), ref_net, "aten::t" if not inplace else "aten::t_"
|
||||
return aten_transpose(inplace), ref_net, "aten::t" if not inplace else "aten::t_"
|
||||
|
||||
@pytest.mark.parametrize("num_dims", [0, 1, 2])
|
||||
@pytest.mark.parametrize("input_dtype", ["float32", "int32"])
|
||||
@@ -69,6 +80,10 @@ class TestTSmall(PytorchLayerTest):
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_t_small(self, num_dims, input_dtype, inplace, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(num_dims, inplace),
|
||||
ie_device, precision, ir_version,
|
||||
kwargs_to_prepare_input={"num_dims": num_dims, "input_dtype": input_dtype})
|
||||
self._test(
|
||||
*self.create_model(num_dims, inplace),
|
||||
ie_device,
|
||||
precision,
|
||||
ir_version,
|
||||
kwargs_to_prepare_input={"num_dims": num_dims, "input_dtype": input_dtype},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user