[PT FE]: fix unflatten for list construct sizes (#18039)
This commit is contained in:
parent d66e322529
commit 838d792d96
@@ -25,6 +25,9 @@ OutputVector translate_unflatten(const NodeContext& context) {
     auto input = context.get_input(0);
     auto dim = context.get_input(1);
     auto sizes = context.get_input(2);
+    if (context.get_input_type(2).is<type::List>()) {
+        sizes = concat_list_construct(sizes);
+    }
     auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
     auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
     auto one_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
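The added branch handles sizes that arrive as a list construct rather than a constant tensor, which is what tracing produces when the sizes are built from tensor shape values. A minimal sketch of that situation (illustrative only, not part of the commit; module and variable names are made up, and standard torch.jit.trace behaviour is assumed):

    import torch

    # Hypothetical module reproducing the pattern the fix targets:
    # unflatten sizes assembled as a Python list from shape values.
    class Unflatten0(torch.nn.Module):
        def forward(self, x):
            d0, _, _ = x.shape                 # shape values recorded by the tracer
            return x.unflatten(0, [d0, -1])    # sizes arrive as a list, not a constant tensor

    # In the traced graph the sizes list appears as a prim::ListConstruct feeding
    # aten::unflatten; the new is<type::List>() branch concatenates those elements
    # into a single 1-D sizes tensor before the shape arithmetic that follows.
    traced = torch.jit.trace(Unflatten0(), torch.zeros(6, 2, 4))
    print(traced.graph)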
@@ -33,3 +33,37 @@ class TestUnflatten(PytorchLayerTest):
     @pytest.mark.precommit
     def test_unflatten(self, dim, shape, dtype, ie_device, precision, ir_version):
         self._test(*self.create_model(dim, shape), ie_device, precision, ir_version, kwargs_to_prepare_input={"dtype": dtype})
+
+
+class TestUnflattenListSizes(PytorchLayerTest):
+    def _prepare_input(self, dtype):
+        return (np.random.uniform(0, 50, (6, 2, 4)).astype(dtype),)
+
+    def create_model(self, dim):
+        import torch
+
+        class aten_unflatten(torch.nn.Module):
+            def __init__(self, dim):
+                super(aten_unflatten, self).__init__()
+                self.dim = dim
+
+            def forward(self, x):
+                dim1, dim2, dim3 = x.shape
+                if self.dim == 0:
+                    sizes = [dim1, -1]
+                elif self.dim == 1:
+                    sizes = [dim2 // 2, -1]
+                else:
+                    sizes = [2, dim3 // 2, -1]
+                return x.unflatten(self.dim, sizes)
+
+        ref_net = None
+
+        return aten_unflatten(dim), ref_net, "aten::unflatten"
+
+    @pytest.mark.parametrize("dim", [0, 1, 2])
+    @pytest.mark.parametrize("dtype", ["float32", "int32"])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_unflatten(self, dim, dtype, ie_device, precision, ir_version):
+        self._test(*self.create_model(dim), ie_device, precision, ir_version, kwargs_to_prepare_input={"dtype": dtype})
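Usage note: the new TestUnflattenListSizes cases run under the same layer-test harness as the existing tests; a typical local invocation (file path assumed, since the diff view does not show it) would be:

    pytest tests/layer_tests/pytorch_tests/test_unflatten.py -k ListSizes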