[PT FE] Improved MaxPool convert by PyTorch FE (#18965)

* Improved MaxPool convert by PyTorch FE

* Mark xfail corner cases for AvgPool

* Update src/frontends/pytorch/src/op/max_poolnd.cpp

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

* Update src/frontends/pytorch/src/op/max_poolnd.cpp

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

* Fix build issues

---------

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
Pawel Raasz
2023-08-04 17:07:33 +02:00
committed by GitHub
parent 36309938d9
commit 2385c2769e
2 changed files with 73 additions and 5 deletions

View File

@@ -15,6 +15,8 @@ d2_params = [{'kernel_size': [3, 3], 'stride': 1, 'padding': 0},
{'kernel_size': [2, 1], 'stride': [], 'padding': 0},
]
d2_params_corner_case = [{'kernel_size': [8, 8], 'stride': [8,4], 'padding': 1}]
d1_params = [{'kernel_size': 3, 'stride': 1, 'padding': 0},
{'kernel_size': (4,), 'stride': 1, 'padding': 1},
{'kernel_size': 4, 'stride': (5,), 'padding': 2},
@@ -117,7 +119,15 @@ class TestPooling(PytorchLayerTest):
ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 3}, trace_model=True,
dynamic_shapes=False)
@pytest.mark.parametrize("params", d2_params)
@pytest.mark.parametrize(
"params",
d2_params
+ [
pytest.param(
{"kernel_size": [8, 8], "stride": [8, 4], "padding": 1},
marks=pytest.mark.xfail(reason="Sliding windows that would start in the right padded are ignored.")
)
])
@pytest.mark.parametrize("ceil_mode", [True, False])
@pytest.mark.parametrize("count_include_pad", [True, False])
@pytest.mark.nightly
@@ -145,7 +155,7 @@ class TestPooling(PytorchLayerTest):
self._test(*self.create_model("max_pool1d", **params, ceil_mode=ceil_mode, dilation=dilation),
ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 3}, dynamic_shapes=False)
@pytest.mark.parametrize("params", d2_params)
@pytest.mark.parametrize("params", d2_params + d2_params_corner_case)
@pytest.mark.parametrize("ceil_mode", [True, False])
@pytest.mark.parametrize("dilation", [1, 2])
@pytest.mark.nightly