[PT FE] Support aten::max_poolnd_with_indices (#20322)
This commit is contained in:
parent
a31ed6ad19
commit
9adfaca1a8
@ -100,8 +100,23 @@ OutputVector translate_max_poolnd(const NodeContext& context) {
|
||||
std::fill_n(pads.begin(), pads.size(), 0);
|
||||
}
|
||||
|
||||
return {
|
||||
context.mark_node(std::make_shared<v8::MaxPool>(input, strides, dilations, pads, pads, kernel, rounding_type))};
|
||||
auto res = context.mark_node(std::make_shared<v8::MaxPool>(input,
|
||||
strides,
|
||||
dilations,
|
||||
pads,
|
||||
pads,
|
||||
kernel,
|
||||
rounding_type,
|
||||
PadType::EXPLICIT,
|
||||
element::i64,
|
||||
2));
|
||||
if (context.get_output_size() == 2) {
|
||||
auto out1 = res->output(0);
|
||||
auto out2 = res->output(1);
|
||||
return {out1, out2};
|
||||
} else {
|
||||
return {res};
|
||||
}
|
||||
};
|
||||
|
||||
OutputVector translate_max_poolnd_fx(const NodeContext& context) {
|
||||
|
@ -367,8 +367,11 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
{"aten::max", op::translate_max},
|
||||
{"aten::maximum", op::translate_maximum},
|
||||
{"aten::max_pool1d", op::quantizable_op<op::translate_max_poolnd>},
|
||||
{"aten::max_pool1d_with_indices", op::quantizable_op<op::translate_max_poolnd>},
|
||||
{"aten::max_pool2d", op::quantizable_op<op::translate_max_poolnd>},
|
||||
{"aten::max_pool2d_with_indices", op::quantizable_op<op::translate_max_poolnd>},
|
||||
{"aten::max_pool3d", op::quantizable_op<op::translate_max_poolnd>},
|
||||
{"aten::max_pool3d_with_indices", op::quantizable_op<op::translate_max_poolnd>},
|
||||
{"aten::mean", op::quantizable_op<op::translate_mean>},
|
||||
{"aten::meshgrid", op::translate_meshgrid},
|
||||
{"aten::min", op::translate_min},
|
||||
|
@ -4,6 +4,7 @@
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
import numpy as np
|
||||
|
||||
d2_params = [{'kernel_size': [3, 3], 'stride': 1, 'padding': 0},
|
||||
{'kernel_size': [3, 3], 'stride': [1, 1], 'padding': 1},
|
||||
@ -95,13 +96,31 @@ class TestPooling(PytorchLayerTest):
|
||||
return torch.nn.functional.max_pool1d(x, self.kernel_size, self.stride, self.padding, self.dilation,
|
||||
self.ceil_mode)
|
||||
|
||||
class aten_max_pool2d_indices(aten_max_pooling_base):
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.max_pool2d(x, self.kernel_size, self.stride, self.padding, self.dilation,
|
||||
self.ceil_mode, return_indices=True)
|
||||
|
||||
class aten_max_pool3d_indices(aten_max_pooling_base):
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.max_pool3d(x, self.kernel_size, self.stride, self.padding, self.dilation,
|
||||
self.ceil_mode, return_indices=True)
|
||||
|
||||
class aten_max_pool1d_indices(aten_max_pooling_base):
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.max_pool1d(x, self.kernel_size, self.stride, self.padding, self.dilation,
|
||||
self.ceil_mode, return_indices=True)
|
||||
|
||||
ops = {
|
||||
"max_pool1d": aten_max_pool1d,
|
||||
"max_pool2d": aten_max_pool2d,
|
||||
"max_pool3d": aten_max_pool3d,
|
||||
"avg_pool1d": aten_avg_pool1d,
|
||||
"avg_pool2d": aten_avg_pool2d,
|
||||
"avg_pool3d": aten_avg_pool3d
|
||||
"avg_pool3d": aten_avg_pool3d,
|
||||
"max_pool1d_with_indices": aten_max_pool1d_indices,
|
||||
"max_pool2d_with_indices": aten_max_pool2d_indices,
|
||||
"max_pool3d_with_indices": aten_max_pool3d_indices,
|
||||
}
|
||||
|
||||
ref_net = None
|
||||
@ -160,7 +179,7 @@ class TestPooling(PytorchLayerTest):
|
||||
@pytest.mark.parametrize("dilation", [1, 2])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_max_pool2d(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
|
||||
def test_max_pool2d(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
|
||||
to_trace = False
|
||||
if params["stride"] == []:
|
||||
to_trace = True
|
||||
@ -175,3 +194,39 @@ class TestPooling(PytorchLayerTest):
|
||||
def test_max_pool3d(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model("max_pool3d", **params, ceil_mode=ceil_mode, dilation=dilation),
|
||||
ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 5}, dynamic_shapes=False)
|
||||
|
||||
@pytest.mark.parametrize("params", d1_params)
|
||||
@pytest.mark.parametrize("ceil_mode", [True, False])
|
||||
@pytest.mark.parametrize("dilation", [1, 2])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_max_pool1d_indices(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
|
||||
if ceil_mode and (np.array(params["padding"]).any() != 0):
|
||||
pytest.skip("ticket 122418")
|
||||
self._test(*self.create_model("max_pool1d_with_indices", **params, ceil_mode=ceil_mode, dilation=dilation),
|
||||
ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 3}, dynamic_shapes=False)
|
||||
|
||||
@pytest.mark.parametrize("params", d2_params + d2_params_corner_case)
|
||||
@pytest.mark.parametrize("ceil_mode", [True, False])
|
||||
@pytest.mark.parametrize("dilation", [1, 2])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_max_pool2d_indices(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
|
||||
if ceil_mode and (np.array(params["padding"]).any() != 0):
|
||||
pytest.skip("ticket 122418")
|
||||
to_trace = False
|
||||
if params["stride"] == []:
|
||||
to_trace = True
|
||||
self._test(*self.create_model("max_pool2d_with_indices", **params, ceil_mode=ceil_mode, dilation=dilation),
|
||||
ie_device, precision, ir_version, dynamic_shapes=False, trace_model=to_trace)
|
||||
|
||||
@pytest.mark.parametrize("params", d3_params)
|
||||
@pytest.mark.parametrize("ceil_mode", [True, False])
|
||||
@pytest.mark.parametrize("dilation", [1, 2])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_max_pool3d_indices(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
|
||||
if ceil_mode and (np.array(params["padding"]).any() != 0):
|
||||
pytest.skip("ticket 122418")
|
||||
self._test(*self.create_model("max_pool3d_with_indices", **params, ceil_mode=ceil_mode, dilation=dilation),
|
||||
ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 5}, dynamic_shapes=False)
|
||||
|
Loading…
Reference in New Issue
Block a user