[PT FE] Support aten::max_poolnd_with_indices (#20322)

This commit is contained in:
Maxim Vafin 2023-10-10 08:00:02 +02:00 committed by Alexander Nesterov
parent a31ed6ad19
commit 9adfaca1a8
3 changed files with 77 additions and 4 deletions

View File

@ -100,8 +100,23 @@ OutputVector translate_max_poolnd(const NodeContext& context) {
std::fill_n(pads.begin(), pads.size(), 0);
}
return {
context.mark_node(std::make_shared<v8::MaxPool>(input, strides, dilations, pads, pads, kernel, rounding_type))};
auto res = context.mark_node(std::make_shared<v8::MaxPool>(input,
strides,
dilations,
pads,
pads,
kernel,
rounding_type,
PadType::EXPLICIT,
element::i64,
2));
if (context.get_output_size() == 2) {
auto out1 = res->output(0);
auto out2 = res->output(1);
return {out1, out2};
} else {
return {res};
}
};
OutputVector translate_max_poolnd_fx(const NodeContext& context) {

View File

@ -367,8 +367,11 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::max", op::translate_max},
{"aten::maximum", op::translate_maximum},
{"aten::max_pool1d", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool1d_with_indices", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool2d", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool2d_with_indices", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool3d", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool3d_with_indices", op::quantizable_op<op::translate_max_poolnd>},
{"aten::mean", op::quantizable_op<op::translate_mean>},
{"aten::meshgrid", op::translate_meshgrid},
{"aten::min", op::translate_min},

View File

@ -4,6 +4,7 @@
import pytest
from pytorch_layer_test_class import PytorchLayerTest
import numpy as np
d2_params = [{'kernel_size': [3, 3], 'stride': 1, 'padding': 0},
{'kernel_size': [3, 3], 'stride': [1, 1], 'padding': 1},
@ -95,13 +96,31 @@ class TestPooling(PytorchLayerTest):
return torch.nn.functional.max_pool1d(x, self.kernel_size, self.stride, self.padding, self.dilation,
self.ceil_mode)
class aten_max_pool2d_indices(aten_max_pooling_base):
def forward(self, x):
return torch.nn.functional.max_pool2d(x, self.kernel_size, self.stride, self.padding, self.dilation,
self.ceil_mode, return_indices=True)
class aten_max_pool3d_indices(aten_max_pooling_base):
def forward(self, x):
return torch.nn.functional.max_pool3d(x, self.kernel_size, self.stride, self.padding, self.dilation,
self.ceil_mode, return_indices=True)
class aten_max_pool1d_indices(aten_max_pooling_base):
def forward(self, x):
return torch.nn.functional.max_pool1d(x, self.kernel_size, self.stride, self.padding, self.dilation,
self.ceil_mode, return_indices=True)
ops = {
"max_pool1d": aten_max_pool1d,
"max_pool2d": aten_max_pool2d,
"max_pool3d": aten_max_pool3d,
"avg_pool1d": aten_avg_pool1d,
"avg_pool2d": aten_avg_pool2d,
"avg_pool3d": aten_avg_pool3d
"avg_pool3d": aten_avg_pool3d,
"max_pool1d_with_indices": aten_max_pool1d_indices,
"max_pool2d_with_indices": aten_max_pool2d_indices,
"max_pool3d_with_indices": aten_max_pool3d_indices,
}
ref_net = None
@ -160,7 +179,7 @@ class TestPooling(PytorchLayerTest):
@pytest.mark.parametrize("dilation", [1, 2])
@pytest.mark.nightly
@pytest.mark.precommit
def test_max_pool2d(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
def test_max_pool2d(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
to_trace = False
if params["stride"] == []:
to_trace = True
@ -175,3 +194,39 @@ class TestPooling(PytorchLayerTest):
def test_max_pool3d(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
self._test(*self.create_model("max_pool3d", **params, ceil_mode=ceil_mode, dilation=dilation),
ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 5}, dynamic_shapes=False)
@pytest.mark.parametrize("params", d1_params)
@pytest.mark.parametrize("ceil_mode", [True, False])
@pytest.mark.parametrize("dilation", [1, 2])
@pytest.mark.nightly
@pytest.mark.precommit
def test_max_pool1d_indices(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
if ceil_mode and (np.array(params["padding"]).any() != 0):
pytest.skip("ticket 122418")
self._test(*self.create_model("max_pool1d_with_indices", **params, ceil_mode=ceil_mode, dilation=dilation),
ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 3}, dynamic_shapes=False)
@pytest.mark.parametrize("params", d2_params + d2_params_corner_case)
@pytest.mark.parametrize("ceil_mode", [True, False])
@pytest.mark.parametrize("dilation", [1, 2])
@pytest.mark.nightly
@pytest.mark.precommit
def test_max_pool2d_indices(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
if ceil_mode and (np.array(params["padding"]).any() != 0):
pytest.skip("ticket 122418")
to_trace = False
if params["stride"] == []:
to_trace = True
self._test(*self.create_model("max_pool2d_with_indices", **params, ceil_mode=ceil_mode, dilation=dilation),
ie_device, precision, ir_version, dynamic_shapes=False, trace_model=to_trace)
@pytest.mark.parametrize("params", d3_params)
@pytest.mark.parametrize("ceil_mode", [True, False])
@pytest.mark.parametrize("dilation", [1, 2])
@pytest.mark.nightly
@pytest.mark.precommit
def test_max_pool3d_indices(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
if ceil_mode and (np.array(params["padding"]).any() != 0):
pytest.skip("ticket 122418")
self._test(*self.create_model("max_pool3d_with_indices", **params, ceil_mode=ceil_mode, dilation=dilation),
ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 5}, dynamic_shapes=False)