[PT FE] Fix pad value dtype for pooling (#21139)

* [PT FE] Fix pad value dtype for pooling

* Apply suggestions from code review

* Apply suggestions from code review
Maxim Vafin 2023-11-17 16:09:56 +01:00 committed by GitHub
parent 71af0eef38
commit b2e14f9fad
3 changed files with 11 additions and 8 deletions

@@ -53,6 +53,7 @@ OutputVector translate_avg_poolnd(const NodeContext& context) {
     // More detail on https://github.com/pytorch/pytorch/issues/57178
     if (count_include_pad) {
         auto zero = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
+        zero = context.mark_node(std::make_shared<v1::ConvertLike>(zero, input));
         auto zero_i32 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
         Output<Node> rank;
         std::tie(std::ignore, rank) = get_shape_rank(context, input);
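
The branch above implements count_include_pad=True by zero-padding the input explicitly (see the linked PyTorch issue); the added ConvertLike makes that zero constant match the input's element type instead of always being f32. A quick PyTorch check of the underlying equivalence, as an illustration only (not part of the change):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 6, 6)
# avg_pool2d with count_include_pad=True divides by the full kernel size,
# so it matches zero-padding the input first and pooling without padding.
a = F.avg_pool2d(x, kernel_size=3, padding=1, count_include_pad=True)
b = F.avg_pool2d(F.pad(x, (1, 1, 1, 1), value=0.0), kernel_size=3,
                 count_include_pad=False)
assert torch.allclose(a, b)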

@@ -94,8 +94,9 @@ OutputVector translate_max_poolnd(const NodeContext& context) {
         // apply padding on input clear pads attribute
         const auto pb = context.mark_node(std::make_shared<v0::Concat>(OutputVector{pads_remaining, padding}, 0));
         const auto pe = context.mark_node(std::make_shared<v0::Concat>(OutputVector{pads_remaining, selected_pads}, 0));
-        const auto minus_inf =
+        auto minus_inf =
             context.mark_node(v0::Constant::create(element::f32, Shape{}, {-std::numeric_limits<float>::infinity()}));
+        minus_inf = context.mark_node(std::make_shared<v1::ConvertLike>(minus_inf, input));
         input = context.mark_node(std::make_shared<v12::Pad>(input, pb, pe, minus_inf, op::PadMode::CONSTANT));
         std::fill_n(pads.begin(), pads.size(), 0);
     }
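
The max-pool path gets the same treatment: the -inf constant used to pad the input is now converted to the input's element type before it reaches Pad. A minimal sketch of that subgraph using the OpenVINO Python opset bindings; this is illustrative only (the frontend builds the equivalent nodes in C++ through NodeContext), and the shapes and names below are made up:

import numpy as np
import openvino.runtime.opset12 as ops

# Create the pad value as f32, then cast it to the element type of the data
# node with ConvertLike, mirroring the change above.
data = ops.parameter([1, 3, 8, 8], dtype=np.int32, name="x")
minus_inf = ops.constant(-np.inf, dtype=np.float32)
pad_value = ops.convert_like(minus_inf, data)

pads_begin = ops.constant(np.array([0, 0, 1, 1], dtype=np.int32))
pads_end = ops.constant(np.array([0, 0, 1, 1], dtype=np.int32))
# Without convert_like, Pad would receive an f32 value for an i32 input,
# which is the dtype mismatch this commit fixes.
padded = ops.pad(data, pads_begin, pads_end, "constant", pad_value)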

@@ -4,6 +4,7 @@
 import platform
 
 import pytest
+import torch
 from pytorch_layer_test_class import PytorchLayerTest
 import numpy as np
@@ -18,7 +19,7 @@ d2_params = [{'kernel_size': [3, 3], 'stride': 1, 'padding': 0},
              {'kernel_size': [2, 1], 'stride': [], 'padding': 0},
              ]
-d2_params_corner_case = [{'kernel_size': [8, 8], 'stride': [8,4], 'padding': 1}]
+d2_params_corner_case = [{'kernel_size': [8, 8], 'stride': [8, 4], 'padding': 1}]
 d1_params = [{'kernel_size': 3, 'stride': 1, 'padding': 0},
              {'kernel_size': (4,), 'stride': 1, 'padding': 1},
@@ -41,9 +42,7 @@ class TestPooling(PytorchLayerTest):
             shape = (1, 3, 15, 15, 15)
         return (np.random.randn(*shape[:ndim]).astype(np.float32),)
 
-    def create_model(self, op_type, kernel_size, stride, padding, dilation=1, ceil_mode=True, count_include_pad=True):
-        import torch
-
+    def create_model(self, op_type, kernel_size, stride, padding, dilation=1, ceil_mode=True, count_include_pad=True, dtype=torch.float32):
         class aten_avg_pooling_base(torch.nn.Module):
             def __init__(self):
                 super(aten_avg_pooling_base, self).__init__()
@@ -64,6 +63,7 @@ class TestPooling(PytorchLayerTest):
                 self.padding = padding
                 self.dilation = dilation
                 self.ceil_mode = ceil_mode
+                self.dtype = dtype
 
             def forward(self, x):
                 pass
@@ -85,7 +85,7 @@ class TestPooling(PytorchLayerTest):
 
         class aten_max_pool2d(aten_max_pooling_base):
             def forward(self, x):
-                return torch.nn.functional.max_pool2d(x, self.kernel_size, self.stride, self.padding, self.dilation,
+                return torch.nn.functional.max_pool2d(x.to(self.dtype), self.kernel_size, self.stride, self.padding, self.dilation,
                                                       self.ceil_mode)
 
         class aten_max_pool3d(aten_max_pooling_base):
@@ -187,15 +187,16 @@ class TestPooling(PytorchLayerTest):
 
     @pytest.mark.parametrize("params", d2_params + d2_params_corner_case)
     @pytest.mark.parametrize("ceil_mode", [True, False])
     @pytest.mark.parametrize("dilation", [1, 2])
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.int32])
     @pytest.mark.nightly
     @pytest.mark.precommit
     @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64',
                        reason='Ticket - 122715')
-    def test_max_pool2d(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
+    def test_max_pool2d(self, params, ceil_mode, dilation, dtype, ie_device, precision, ir_version):
         to_trace = False
         if params["stride"] == []:
             to_trace = True
-        self._test(*self.create_model("max_pool2d", **params, ceil_mode=ceil_mode, dilation=dilation),
+        self._test(*self.create_model("max_pool2d", **params, ceil_mode=ceil_mode, dilation=dilation, dtype=dtype),
                    ie_device, precision, ir_version, dynamic_shapes=False, trace_model=to_trace)
 
     @pytest.mark.parametrize("params", d3_params)