[PT FE]: fix unflatten for list construct sizes (#18039)
This commit is contained in:
parent
d66e322529
commit
838d792d96
@ -25,6 +25,9 @@ OutputVector translate_unflatten(const NodeContext& context) {
|
||||
auto input = context.get_input(0);
|
||||
auto dim = context.get_input(1);
|
||||
auto sizes = context.get_input(2);
|
||||
if (context.get_input_type(2).is<type::List>()) {
|
||||
sizes = concat_list_construct(sizes);
|
||||
}
|
||||
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
|
||||
auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
|
||||
auto one_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
|
||||
|
@ -32,4 +32,38 @@ class TestUnflatten(PytorchLayerTest):
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_unflatten(self, dim, shape, dtype, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(dim, shape), ie_device, precision, ir_version, kwargs_to_prepare_input={"dtype": dtype})
|
||||
self._test(*self.create_model(dim, shape), ie_device, precision, ir_version, kwargs_to_prepare_input={"dtype": dtype})
|
||||
|
||||
|
||||
class TestUnflattenListSizes(PytorchLayerTest):
|
||||
def _prepare_input(self, dtype):
|
||||
return (np.random.uniform(0, 50, (6, 2, 4)).astype(dtype),)
|
||||
|
||||
def create_model(self, dim):
|
||||
import torch
|
||||
|
||||
class aten_unflatten(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super(aten_unflatten, self).__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
dim1, dim2, dim3 = x.shape
|
||||
if self.dim == 0:
|
||||
sizes = [dim1, -1]
|
||||
elif self.dim == 1:
|
||||
sizes = [dim2 // 2, -1]
|
||||
else:
|
||||
sizes = [2, dim3 // 2, -1]
|
||||
return x.unflatten(self.dim, sizes)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_unflatten(dim), ref_net, "aten::unflatten"
|
||||
|
||||
@pytest.mark.parametrize("dim", [0, 1, 2])
|
||||
@pytest.mark.parametrize("dtype", ["float32", "int32"])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_unflatten(self, dim, dtype, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(dim), ie_device, precision, ir_version, kwargs_to_prepare_input={"dtype": dtype})
|
Loading…
Reference in New Issue
Block a user