[PT FE]: support aten::t and inplace tril/triu (#18040)
This commit is contained in:
parent
b69c11d8ef
commit
d3461074ea
@ -50,6 +50,16 @@ OutputVector translate_transpose(const NodeContext& context) {
|
||||
return {context.mark_node(std::make_shared<v1::Transpose>(context.get_input(0), scatter))};
|
||||
};
|
||||
|
||||
OutputVector translate_t(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto input = context.get_input(0);
|
||||
if (input.get_partial_shape().rank().is_dynamic() || input.get_partial_shape().rank().get_length() < 2) {
|
||||
return {input};
|
||||
}
|
||||
auto dims = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {1, 0}));
|
||||
return {context.mark_node(std::make_shared<v1::Transpose>(input, dims))};
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
|
@ -130,6 +130,7 @@ OP_CONVERTER(translate_square);
|
||||
OP_CONVERTER(translate_squeeze);
|
||||
OP_CONVERTER(translate_sub);
|
||||
OP_CONVERTER(translate_sum);
|
||||
OP_CONVERTER(translate_t);
|
||||
OP_CONVERTER(translate_to);
|
||||
OP_CONVERTER(translate_topk);
|
||||
OP_CONVERTER(translate_transpose);
|
||||
@ -348,6 +349,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::squeeze", op::translate_squeeze},
|
||||
{"aten::sub", op::translate_sub},
|
||||
{"aten::sum", op::translate_sum},
|
||||
{"aten::t", op::translate_t},
|
||||
{"aten::t_", op::inplace_op<op::translate_t>},
|
||||
{"aten::tan", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Tan>},
|
||||
{"aten::tan_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Tan>>},
|
||||
{"aten::tanh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Tanh>},
|
||||
@ -357,7 +360,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::topk", op::translate_topk},
|
||||
{"aten::transpose", op::translate_transpose},
|
||||
{"aten::tril", op::translate_tril},
|
||||
{"aten::tril_", op::inplace_op<op::translate_tril>},
|
||||
{"aten::triu", op::translate_triu},
|
||||
{"aten::triu_", op::inplace_op<op::translate_triu>},
|
||||
{"aten::type_as",
|
||||
op::translate_1to1_match_2_inputs<opset10::ConvertLike>}, // TODO: overflow semantics is different
|
||||
{"aten::unflatten", op::translate_unflatten},
|
||||
|
@ -31,6 +31,65 @@ class TestTranspose(PytorchLayerTest):
|
||||
@pytest.mark.parametrize("dim1", [0, 1, 2, 3, -1, -2, -3, -4])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_relu(self, dim0, dim1, ie_device, precision, ir_version):
|
||||
def test_transpose(self, dim0, dim1, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(dim0, dim1),
|
||||
ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestTSmall(PytorchLayerTest):
|
||||
def _prepare_input(self, num_dims=2, input_dtype="float32"):
|
||||
import numpy as np
|
||||
shape = (2, 3)
|
||||
if num_dims == 0:
|
||||
return (np.array(num_dims).astype(input_dtype), )
|
||||
return (np.random.randn(*shape[:num_dims]).astype(input_dtype),)
|
||||
|
||||
def create_model(self, num_dims=2, inplace=False):
|
||||
import torch
|
||||
|
||||
class aten_transpose(torch.nn.Module):
|
||||
def __init__(self, num_dims, inplace):
|
||||
super(aten_transpose, self).__init__()
|
||||
if num_dims == 2:
|
||||
self.forward = self.forward_2d if not inplace else self.forward_2d_inplace
|
||||
elif num_dims == 1:
|
||||
self.forward = self.forward_1d if not inplace else self.forward_1d_inplace
|
||||
else:
|
||||
if inplace:
|
||||
self.forward = self.forward_inplace
|
||||
|
||||
def forward_2d(self, x):
|
||||
x = torch.reshape(x, (2, -1))
|
||||
return x.t(), x
|
||||
|
||||
def forward_2d_inplace(self, x):
|
||||
x = torch.reshape(x, (2, -1))
|
||||
return x.t_(), x
|
||||
|
||||
def forward_1d(self, x):
|
||||
x = torch.reshape(x, (-1, ))
|
||||
return x.t(), x
|
||||
|
||||
def forward_1d_inplace(self, x):
|
||||
x = torch.reshape(x, (-1, ))
|
||||
return x.t_(), x
|
||||
|
||||
def forward(self, x):
|
||||
return x.t(), x
|
||||
|
||||
def forward_inplace(self, x):
|
||||
return x.t_(), x
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_transpose(num_dims, inplace), ref_net, "aten::t" if not inplace else "aten::t_"
|
||||
|
||||
@pytest.mark.parametrize("num_dims", [0, 1, 2])
|
||||
@pytest.mark.parametrize("input_dtype", ["float32", "int32"])
|
||||
@pytest.mark.parametrize("inplace", [True, False])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_t_small(self, num_dims, input_dtype, inplace, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(num_dims, inplace),
|
||||
ie_device, precision, ir_version,
|
||||
kwargs_to_prepare_input={"num_dims": num_dims, "input_dtype": input_dtype})
|
||||
|
@ -42,4 +42,53 @@ class TestTriuTril(PytorchLayerTest):
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_trilu(self, input_shape, dtype, diagonal, op, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(op, diagonal), ie_device, precision, ir_version, kwargs_to_prepare_input={"shape": input_shape, "dtype": dtype})
|
||||
self._test(*self.create_model(op, diagonal), ie_device, precision, ir_version,
|
||||
kwargs_to_prepare_input={"shape": input_shape, "dtype": dtype})
|
||||
|
||||
|
||||
class TestTriuTrilTensor(PytorchLayerTest):
|
||||
def _prepare_input(self, shape, dtype):
|
||||
import numpy as np
|
||||
return (np.random.randn(*shape).astype(dtype),)
|
||||
|
||||
def create_model(self, op, diagonal):
|
||||
|
||||
import torch
|
||||
|
||||
class aten_trilu(torch.nn.Module):
|
||||
def __init__(self, op, diagonal):
|
||||
super(aten_trilu, self).__init__()
|
||||
op_map = {
|
||||
"tril": self.tril,
|
||||
"tril_": self.tril_,
|
||||
"triu": self.triu,
|
||||
"triu_": self.triu_
|
||||
}
|
||||
self.diagonal = diagonal
|
||||
self.forward = op_map[op]
|
||||
|
||||
def tril(self, x):
|
||||
return x.tril(self.diagonal), x
|
||||
|
||||
def tril_(self, x):
|
||||
return x.tril_(self.diagonal), x
|
||||
|
||||
def triu(self, x):
|
||||
return x.triu(self.diagonal), x
|
||||
|
||||
def triu_(self, x):
|
||||
return x.triu_(self.diagonal), x
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_trilu(op, diagonal), ref_net, f"aten::{op}"
|
||||
|
||||
@pytest.mark.parametrize("input_shape", [(5, 5), (6, 4), (4, 6)])
|
||||
@pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8", "uint8", "bool"])
|
||||
@pytest.mark.parametrize("diagonal", [0, 1, 2, -1, -2])
|
||||
@pytest.mark.parametrize("op", ["triu", "tril", "triu_", "tril_"])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_trilu(self, input_shape, dtype, diagonal, op, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(op, diagonal), ie_device, precision, ir_version,
|
||||
kwargs_to_prepare_input={"shape": input_shape, "dtype": dtype})
|
Loading…
Reference in New Issue
Block a user