[PT FE] Fix pad value dtype for pooling (#21139)
* [PT FE] Fix pad value dtype for pooling * Apply suggestions from code review * Apply suggestions from code review
This commit is contained in:
parent
71af0eef38
commit
b2e14f9fad
@ -53,6 +53,7 @@ OutputVector translate_avg_poolnd(const NodeContext& context) {
|
||||
// More detail on https://github.com/pytorch/pytorch/issues/57178
|
||||
if (count_include_pad) {
|
||||
auto zero = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
|
||||
zero = context.mark_node(std::make_shared<v1::ConvertLike>(zero, input));
|
||||
auto zero_i32 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
Output<Node> rank;
|
||||
std::tie(std::ignore, rank) = get_shape_rank(context, input);
|
||||
|
@ -94,8 +94,9 @@ OutputVector translate_max_poolnd(const NodeContext& context) {
|
||||
// apply padding on input clear pads attribute
|
||||
const auto pb = context.mark_node(std::make_shared<v0::Concat>(OutputVector{pads_remaining, padding}, 0));
|
||||
const auto pe = context.mark_node(std::make_shared<v0::Concat>(OutputVector{pads_remaining, selected_pads}, 0));
|
||||
const auto minus_inf =
|
||||
auto minus_inf =
|
||||
context.mark_node(v0::Constant::create(element::f32, Shape{}, {-std::numeric_limits<float>::infinity()}));
|
||||
minus_inf = context.mark_node(std::make_shared<v1::ConvertLike>(minus_inf, input));
|
||||
input = context.mark_node(std::make_shared<v12::Pad>(input, pb, pe, minus_inf, op::PadMode::CONSTANT));
|
||||
std::fill_n(pads.begin(), pads.size(), 0);
|
||||
}
|
||||
|
@ -4,6 +4,7 @@
|
||||
import platform
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
import numpy as np
|
||||
@ -18,7 +19,7 @@ d2_params = [{'kernel_size': [3, 3], 'stride': 1, 'padding': 0},
|
||||
{'kernel_size': [2, 1], 'stride': [], 'padding': 0},
|
||||
]
|
||||
|
||||
d2_params_corner_case = [{'kernel_size': [8, 8], 'stride': [8,4], 'padding': 1}]
|
||||
d2_params_corner_case = [{'kernel_size': [8, 8], 'stride': [8, 4], 'padding': 1}]
|
||||
|
||||
d1_params = [{'kernel_size': 3, 'stride': 1, 'padding': 0},
|
||||
{'kernel_size': (4,), 'stride': 1, 'padding': 1},
|
||||
@ -41,9 +42,7 @@ class TestPooling(PytorchLayerTest):
|
||||
shape = (1, 3, 15, 15, 15)
|
||||
return (np.random.randn(*shape[:ndim]).astype(np.float32),)
|
||||
|
||||
def create_model(self, op_type, kernel_size, stride, padding, dilation=1, ceil_mode=True, count_include_pad=True):
|
||||
import torch
|
||||
|
||||
def create_model(self, op_type, kernel_size, stride, padding, dilation=1, ceil_mode=True, count_include_pad=True, dtype=torch.float32):
|
||||
class aten_avg_pooling_base(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(aten_avg_pooling_base, self).__init__()
|
||||
@ -64,6 +63,7 @@ class TestPooling(PytorchLayerTest):
|
||||
self.padding = padding
|
||||
self.dilation = dilation
|
||||
self.ceil_mode = ceil_mode
|
||||
self.dtype = dtype
|
||||
|
||||
def forward(self, x):
|
||||
pass
|
||||
@ -85,7 +85,7 @@ class TestPooling(PytorchLayerTest):
|
||||
|
||||
class aten_max_pool2d(aten_max_pooling_base):
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.max_pool2d(x, self.kernel_size, self.stride, self.padding, self.dilation,
|
||||
return torch.nn.functional.max_pool2d(x.to(self.dtype), self.kernel_size, self.stride, self.padding, self.dilation,
|
||||
self.ceil_mode)
|
||||
|
||||
class aten_max_pool3d(aten_max_pooling_base):
|
||||
@ -187,15 +187,16 @@ class TestPooling(PytorchLayerTest):
|
||||
@pytest.mark.parametrize("params", d2_params + d2_params_corner_case)
|
||||
@pytest.mark.parametrize("ceil_mode", [True, False])
|
||||
@pytest.mark.parametrize("dilation", [1, 2])
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.int32])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64',
|
||||
reason='Ticket - 122715')
|
||||
def test_max_pool2d(self, params, ceil_mode, dilation, ie_device, precision, ir_version):
|
||||
def test_max_pool2d(self, params, ceil_mode, dilation, dtype, ie_device, precision, ir_version):
|
||||
to_trace = False
|
||||
if params["stride"] == []:
|
||||
to_trace = True
|
||||
self._test(*self.create_model("max_pool2d", **params, ceil_mode=ceil_mode, dilation=dilation),
|
||||
self._test(*self.create_model("max_pool2d", **params, ceil_mode=ceil_mode, dilation=dilation, dtype=dtype),
|
||||
ie_device, precision, ir_version, dynamic_shapes=False, trace_model=to_trace)
|
||||
|
||||
@pytest.mark.parametrize("params", d3_params)
|
||||
|
Loading…
Reference in New Issue
Block a user