[PT FE]: support aten::t and inplace tril/triu (#18040)

Ekaterina Aidova 2023-06-14 15:08:45 +04:00 committed by GitHub
parent b69c11d8ef
commit d3461074ea
4 changed files with 125 additions and 2 deletions

@@ -50,6 +50,16 @@ OutputVector translate_transpose(const NodeContext& context) {
    return {context.mark_node(std::make_shared<v1::Transpose>(context.get_input(0), scatter))};
};

OutputVector translate_t(const NodeContext& context) {
    num_inputs_check(context, 1, 1);
    auto input = context.get_input(0);
    if (input.get_partial_shape().rank().is_dynamic() || input.get_partial_shape().rank().get_length() < 2) {
        return {input};
    }
    auto dims = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {1, 0}));
    return {context.mark_node(std::make_shared<v1::Transpose>(input, dims))};
}
} // namespace op
} // namespace pytorch
} // namespace frontend
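For reference, `translate_t` mirrors the PyTorch semantics of `aten::t`/`aten::t_`: a 2-D input is transposed with permutation {1, 0}, while 0-D and 1-D inputs are passed through unchanged. A minimal sketch of that behavior in plain PyTorch (illustration only, not part of this change):

```python
import torch

x2 = torch.arange(6.).reshape(2, 3)
assert x2.t().shape == (3, 2)      # 2-D: dimensions 0 and 1 are swapped

x1 = torch.arange(3.)
assert torch.equal(x1.t(), x1)     # 1-D (and 0-D): tensor is returned as-is

y = x1.t_()                        # in-place variant mutates and returns the same tensor
assert y is x1
```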

@@ -130,6 +130,7 @@ OP_CONVERTER(translate_square);
OP_CONVERTER(translate_squeeze);
OP_CONVERTER(translate_sub);
OP_CONVERTER(translate_sum);
OP_CONVERTER(translate_t);
OP_CONVERTER(translate_to);
OP_CONVERTER(translate_topk);
OP_CONVERTER(translate_transpose);
@@ -348,6 +349,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::squeeze", op::translate_squeeze},
{"aten::sub", op::translate_sub},
{"aten::sum", op::translate_sum},
{"aten::t", op::translate_t},
{"aten::t_", op::inplace_op<op::translate_t>},
{"aten::tan", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Tan>},
{"aten::tan_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Tan>>},
{"aten::tanh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Tanh>},
@@ -357,7 +360,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::topk", op::translate_topk},
{"aten::transpose", op::translate_transpose},
{"aten::tril", op::translate_tril},
{"aten::tril_", op::inplace_op<op::translate_tril>},
{"aten::triu", op::translate_triu},
{"aten::triu_", op::inplace_op<op::translate_triu>},
{"aten::type_as",
op::translate_1to1_match_2_inputs<opset10::ConvertLike>}, // TODO: overflow semantics is different
{"aten::unflatten", op::translate_unflatten},

@@ -31,6 +31,65 @@ class TestTranspose(PytorchLayerTest):
    @pytest.mark.parametrize("dim1", [0, 1, 2, 3, -1, -2, -3, -4])
    @pytest.mark.nightly
    @pytest.mark.precommit
-    def test_relu(self, dim0, dim1, ie_device, precision, ir_version):
+    def test_transpose(self, dim0, dim1, ie_device, precision, ir_version):
        self._test(*self.create_model(dim0, dim1),
                   ie_device, precision, ir_version)

class TestTSmall(PytorchLayerTest):
    def _prepare_input(self, num_dims=2, input_dtype="float32"):
        import numpy as np
        shape = (2, 3)
        if num_dims == 0:
            return (np.array(num_dims).astype(input_dtype), )
        return (np.random.randn(*shape[:num_dims]).astype(input_dtype),)

    def create_model(self, num_dims=2, inplace=False):
        import torch

        class aten_transpose(torch.nn.Module):
            def __init__(self, num_dims, inplace):
                super(aten_transpose, self).__init__()
                if num_dims == 2:
                    self.forward = self.forward_2d if not inplace else self.forward_2d_inplace
                elif num_dims == 1:
                    self.forward = self.forward_1d if not inplace else self.forward_1d_inplace
                else:
                    if inplace:
                        self.forward = self.forward_inplace

            def forward_2d(self, x):
                x = torch.reshape(x, (2, -1))
                return x.t(), x

            def forward_2d_inplace(self, x):
                x = torch.reshape(x, (2, -1))
                return x.t_(), x

            def forward_1d(self, x):
                x = torch.reshape(x, (-1, ))
                return x.t(), x

            def forward_1d_inplace(self, x):
                x = torch.reshape(x, (-1, ))
                return x.t_(), x

            def forward(self, x):
                return x.t(), x

            def forward_inplace(self, x):
                return x.t_(), x

        ref_net = None
        return aten_transpose(num_dims, inplace), ref_net, "aten::t" if not inplace else "aten::t_"

    @pytest.mark.parametrize("num_dims", [0, 1, 2])
    @pytest.mark.parametrize("input_dtype", ["float32", "int32"])
    @pytest.mark.parametrize("inplace", [True, False])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_t_small(self, num_dims, input_dtype, inplace, ie_device, precision, ir_version):
        self._test(*self.create_model(num_dims, inplace),
                   ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"num_dims": num_dims, "input_dtype": input_dtype})

@@ -42,4 +42,53 @@ class TestTriuTril(PytorchLayerTest):
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_trilu(self, input_shape, dtype, diagonal, op, ie_device, precision, ir_version):
-        self._test(*self.create_model(op, diagonal), ie_device, precision, ir_version, kwargs_to_prepare_input={"shape": input_shape, "dtype": dtype})
+        self._test(*self.create_model(op, diagonal), ie_device, precision, ir_version,
+                   kwargs_to_prepare_input={"shape": input_shape, "dtype": dtype})

class TestTriuTrilTensor(PytorchLayerTest):
    def _prepare_input(self, shape, dtype):
        import numpy as np
        return (np.random.randn(*shape).astype(dtype),)

    def create_model(self, op, diagonal):
        import torch

        class aten_trilu(torch.nn.Module):
            def __init__(self, op, diagonal):
                super(aten_trilu, self).__init__()
                op_map = {
                    "tril": self.tril,
                    "tril_": self.tril_,
                    "triu": self.triu,
                    "triu_": self.triu_
                }
                self.diagonal = diagonal
                self.forward = op_map[op]

            def tril(self, x):
                return x.tril(self.diagonal), x

            def tril_(self, x):
                return x.tril_(self.diagonal), x

            def triu(self, x):
                return x.triu(self.diagonal), x

            def triu_(self, x):
                return x.triu_(self.diagonal), x

        ref_net = None
        return aten_trilu(op, diagonal), ref_net, f"aten::{op}"

    @pytest.mark.parametrize("input_shape", [(5, 5), (6, 4), (4, 6)])
    @pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8", "uint8", "bool"])
    @pytest.mark.parametrize("diagonal", [0, 1, 2, -1, -2])
    @pytest.mark.parametrize("op", ["triu", "tril", "triu_", "tril_"])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_trilu(self, input_shape, dtype, diagonal, op, ie_device, precision, ir_version):
        self._test(*self.create_model(op, diagonal), ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"shape": input_shape, "dtype": dtype})
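
The in-place variants exercised above (`tril_`, `triu_`, and `t_` in the earlier test) modify the tensor and return it, which is why each forward returns both the call result and `x`. A small illustration of that contract in plain PyTorch (illustration only, not part of the commit):

```python
import torch

x = torch.ones(4, 4)
y = x.triu_(1)                     # zero out the main diagonal and everything below it, in place
assert y is x                      # the in-place op returns the very same tensor object
assert x[0, 0] == 0 and x[0, 1] == 1
```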