[PT FE] Improved MaxPool convert by PyTorch FE (#18965)

* Improved MaxPool convert by PyTorch FE

* Mark xfail corner cases for AvgPool

* Update src/frontends/pytorch/src/op/max_poolnd.cpp

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

* Update src/frontends/pytorch/src/op/max_poolnd.cpp

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

* Fix build issues

---------

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
Pawel Raasz, 2023-08-04 17:07:33 +02:00, committed by GitHub
commit 2385c2769e (parent 36309938d9)
2 changed files with 73 additions and 5 deletions

src/frontends/pytorch/src/op/max_poolnd.cpp

@@ -3,7 +3,18 @@
 //
 
 #include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/greater.hpp"
 #include "openvino/op/max_pool.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/pad.hpp"
+#include "openvino/op/range.hpp"
+#include "openvino/op/select.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/subtract.hpp"
 #include "openvino/op/util/framework_node.hpp"
 #include "utils.hpp"
@@ -21,7 +32,8 @@ OutputVector translate_max_poolnd(const NodeContext& context) {
     if (!context.input_is_none(2)) {
         strides = context.const_input<Strides>(2);
     }
-    if (context.input_is_none(2) || strides.size() == 0) {
+    const bool use_kernel = context.input_is_none(2) || (strides.size() == 0);
+    if (use_kernel) {
         // In case strides are not provided, the default is the kernel size
         strides = kernel;
     }
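The use_kernel branch above mirrors PyTorch's own convention: when the stride argument is omitted (or traced as an empty list), pooling falls back to the kernel size as the stride. A minimal Python sketch of that behavior (input shape chosen arbitrarily for illustration):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 8, 8)
# Omitting the stride makes PyTorch use the kernel size as the stride,
# which is what the converter does when input 2 is absent or empty.
assert torch.equal(F.max_pool2d(x, kernel_size=2),
                   F.max_pool2d(x, kernel_size=2, stride=2))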
@@ -42,8 +54,54 @@ OutputVector translate_max_poolnd(const NodeContext& context) {
         rounding_type = context.const_input<bool>(5) ? RoundingType::CEIL : RoundingType::FLOOR;
     }
 
-    return {context.mark_node(
-        std::make_shared<v8::MaxPool>(context.get_input(0), strides, dilations, pads, pads, kernel, rounding_type))};
+    auto input = context.get_input(0);
+    if (rounding_type == RoundingType::CEIL) {
+        // Corner case of max pooling with ceil_mode on:
+        // PyTorch allows the sliding window to go off-bound, which leads to this accommodation.
+        // More detail in https://github.com/pytorch/pytorch/issues/57178
+        const auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+        const auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
+        const auto two = context.mark_node(v0::Constant::create(element::i32, Shape{}, {2}));
+        const auto padding =
+            context.input_is_none(3)
+                ? context.mark_node(std::make_shared<v0::Constant>(element::i32, Shape{pads.size()}, 0))->output(0)
+                : context.get_input(3);
+        const auto pads_len = context.mark_node(v0::Constant::create(element::i32, Shape{}, {pads.size()}));
+        const auto pads_remaining = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {0, 0}));
+
+        // gather the input spatial dims and prepare them for comparison as (in_dim + pad)
+        const auto input_shape_rank = get_shape_rank(context, input);
+        const auto end = context.mark_node(v0::Constant::create(element::i32, Shape{}, {pads.size() + 2}));
+        const auto dim_idxs = context.mark_node(std::make_shared<v4::Range>(two, end, one, element::i32));
+        const auto gth_in_dims =
+            context.mark_node(std::make_shared<v8::Gather>(std::get<0>(input_shape_rank), dim_idxs, zero));
+        const auto in_left_padded = context.mark_node(std::make_shared<v1::Add>(gth_in_dims, padding));
+
+        // gather the output spatial dims and prepare them for comparison as (out_dim - 1) * stride
+        const auto mp = context.mark_node(
+            std::make_shared<v8::MaxPool>(input, strides, dilations, pads, pads, kernel, rounding_type));
+        const auto shape_of_mp = context.mark_node(std::make_shared<v3::ShapeOf>(mp, element::i32));
+        const auto gth_out_dims = context.mark_node(std::make_shared<v8::Gather>(shape_of_mp, dim_idxs, zero));
+        const auto out_sub_one = context.mark_node(std::make_shared<v1::Subtract>(gth_out_dims, one));
+        const auto stride_node = use_kernel ? context.get_input(1) : context.get_input(2);
+        const auto out_mul_stride = context.mark_node(std::make_shared<v1::Multiply>(out_sub_one, stride_node));
+
+        // if (in_dim + pad) > ((out_dim - 1) * stride), the sliding window stays in bounds, so keep the end padding
+        const auto in_gt_out = context.mark_node(std::make_shared<v1::Greater>(in_left_padded, out_mul_stride));
+        const auto selected_pads = context.mark_node(std::make_shared<v1::Select>(in_gt_out, padding, zero));
+
+        // apply the padding to the input and clear the pads attribute
+        const auto pb = context.mark_node(std::make_shared<v0::Concat>(OutputVector{pads_remaining, padding}, 0));
+        const auto pe = context.mark_node(std::make_shared<v0::Concat>(OutputVector{pads_remaining, selected_pads}, 0));
+        const auto minus_inf =
+            context.mark_node(v0::Constant::create(element::f32, Shape{}, {-std::numeric_limits<float>::infinity()}));
+        input = context.mark_node(std::make_shared<v12::Pad>(input, pb, pe, minus_inf, op::PadMode::CONSTANT));
+        std::fill_n(pads.begin(), pads.size(), 0);
+    }
+
+    return {
+        context.mark_node(std::make_shared<v8::MaxPool>(input, strides, dilations, pads, pads, kernel, rounding_type))};
 };
 
 OutputVector translate_max_poolnd_fx(const NodeContext& context) {
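The new CEIL branch exists because PyTorch's ceil_mode shape rule has a subtlety: an output position whose sliding window would start entirely inside the right padding is dropped (see the linked issue pytorch/pytorch#57178). A small sketch of the rule the Greater/Select/Pad subgraph reproduces, with shapes chosen purely for illustration:

import torch

x = torch.arange(9.0).reshape(1, 1, 3, 3)
pool = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=1, ceil_mode=True)

# The plain ceil formula gives ceil((3 + 2*1 - 2) / 2) + 1 = 3 per dim, but
# the third window would start at padded index 4, entirely inside the right
# padding ((out - 1) * stride >= in_dim + pad), so PyTorch drops it.
print(pool(x).shape)  # torch.Size([1, 1, 2, 2])

The converter checks this per spatial dim with Greater(in_dim + pad, (out_dim - 1) * stride), keeps the end padding only where the window stays in bounds, pads the input explicitly with -inf so padded elements can never win the max, and then clears the pads attribute.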

tests/layer_tests/pytorch_tests/test_pooling.py

@@ -15,6 +15,8 @@ d2_params = [{'kernel_size': [3, 3], 'stride': 1, 'padding': 0},
              {'kernel_size': [2, 1], 'stride': [], 'padding': 0},
              ]
+d2_params_corner_case = [{'kernel_size': [8, 8], 'stride': [8, 4], 'padding': 1}]
+
 d1_params = [{'kernel_size': 3, 'stride': 1, 'padding': 0},
              {'kernel_size': (4,), 'stride': 1, 'padding': 1},
              {'kernel_size': 4, 'stride': (5,), 'padding': 2},
@@ -117,7 +119,15 @@ class TestPooling(PytorchLayerTest):
                    ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 3}, trace_model=True,
                    dynamic_shapes=False)
 
-    @pytest.mark.parametrize("params", d2_params)
+    @pytest.mark.parametrize(
+        "params",
+        d2_params
+        + [
+            pytest.param(
+                {"kernel_size": [8, 8], "stride": [8, 4], "padding": 1},
+                marks=pytest.mark.xfail(reason="Sliding windows that would start in the right padded region are ignored.")
+            )
+        ])
     @pytest.mark.parametrize("ceil_mode", [True, False])
     @pytest.mark.parametrize("count_include_pad", [True, False])
     @pytest.mark.nightly
@@ -145,7 +155,7 @@ class TestPooling(PytorchLayerTest):
         self._test(*self.create_model("max_pool1d", **params, ceil_mode=ceil_mode, dilation=dilation),
                    ie_device, precision, ir_version, kwargs_to_prepare_input={'ndim': 3}, dynamic_shapes=False)
 
-    @pytest.mark.parametrize("params", d2_params)
+    @pytest.mark.parametrize("params", d2_params + d2_params_corner_case)
     @pytest.mark.parametrize("ceil_mode", [True, False])
     @pytest.mark.parametrize("dilation", [1, 2])
     @pytest.mark.nightly
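The d2_params_corner_case entry exercises exactly that off-bounds rule. Assuming a 15x15 spatial input (the real shape comes from the test harness, so this size is an illustrative guess), the height dimension hits the skip condition while the width does not:

import torch

x = torch.randn(1, 3, 15, 15)  # assumed input size, for illustration only
pool = torch.nn.MaxPool2d(kernel_size=[8, 8], stride=[8, 4], padding=1, ceil_mode=True)

# Height: ceil((15 + 2 - 8) / 8) + 1 = 3, but (3 - 1) * 8 = 16 >= 15 + 1,
# so the last window is skipped -> 2. Width: (4 - 1) * 4 = 12 < 16 -> 4 stays.
print(pool(x).shape)  # torch.Size([1, 3, 2, 4])

For average pooling the same shape rule applies, but the -inf padding trick used for MaxPool has no analogue (constant padding would corrupt the averages), which is why the AvgPool corner case is marked xfail rather than fixed.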