[PT FE]: support grid sampler (#15243)

This commit is contained in:
Ekaterina Aidova 2023-01-31 14:04:37 +04:00 committed by GitHub
parent 7e3e0ff003
commit 758a0dea56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 94 additions and 3 deletions

View File

@ -26,9 +26,11 @@ bool op::v9::GridSample::visit_attributes(AttributeVisitor& visitor) {
void op::v9::GridSample::validate_and_infer_types() {
OV_OP_SCOPE(v9_GridSample_validate_and_infer_types);
NODE_VALIDATION_CHECK(this,
get_input_element_type(1).is_real(),
"The element type of the grid input tensor must be a floating point type.");
if (!get_input_element_type(1).is_dynamic()) {
NODE_VALIDATION_CHECK(this,
get_input_element_type(1).is_real(),
"The element type of the grid input tensor must be a floating point type.");
}
std::vector<PartialShape> out_shapes(1);
shape_infer(this, {get_input_partial_shape(0), get_input_partial_shape(1)}, out_shapes);

View File

@ -0,0 +1,46 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/grid_sample.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
OutputVector translate_grid_sampler(NodeContext& context) {
auto x = context.get_input(0);
auto grid = context.get_input(1);
ov::op::v9::GridSample::Attributes attrs{};
const std::unordered_map<int64_t, ov::op::v9::GridSample::InterpolationMode> grid_sample_mode_map{
{0, ov::op::v9::GridSample::InterpolationMode::BILINEAR},
{1, ov::op::v9::GridSample::InterpolationMode::NEAREST},
{2, ov::op::v9::GridSample::InterpolationMode::BICUBIC},
};
const std::unordered_map<int64_t, ov::op::v9::GridSample::PaddingMode> grid_sample_padding_mode_map{
{0, ov::op::v9::GridSample::PaddingMode::ZEROS},
{1, ov::op::v9::GridSample::PaddingMode::BORDER},
{2, ov::op::v9::GridSample::PaddingMode::REFLECTION}};
auto mode = context.const_input<int64_t>(2);
FRONT_END_OP_CONVERSION_CHECK(grid_sample_mode_map.count(mode), "Unknown interpolation mode: ", mode);
attrs.mode = grid_sample_mode_map.at(mode);
auto padding_mode = context.const_input<int64_t>(3);
FRONT_END_OP_CONVERSION_CHECK(grid_sample_padding_mode_map.count(padding_mode),
"Unknown padding mode: ",
padding_mode);
attrs.padding_mode = grid_sample_padding_mode_map.at(padding_mode);
bool align_corners = false;
if (!context.input_is_none(4)) {
align_corners = context.const_input<int64_t>(4);
}
attrs.align_corners = align_corners;
return {context.mark_node(std::make_shared<ov::op::v9::GridSample>(x, grid, attrs))};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -45,6 +45,7 @@ OP_CONVERTER(translate_full_like);
OP_CONVERTER(translate_gelu);
OP_CONVERTER(translate_get_attr);
OP_CONVERTER(translate_glu);
OP_CONVERTER(translate_grid_sampler);
OP_CONVERTER(translate_group_norm);
OP_CONVERTER(translate_hardtanh);
OP_CONVERTER(translate_if);
@ -189,6 +190,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::group_norm", op::translate_group_norm},
{"aten::ge", op::translate_1to1_match_2_inputs<opset10::GreaterEqual>},
{"aten::gt", op::translate_1to1_match_2_inputs<opset10::Greater>},
{"aten::grid_sampler", op::translate_grid_sampler},
{"aten::hardsigmoid", op::translate_1to1_match_1_inputs<opset10::HSigmoid>},
{"aten::hardswish", op::translate_1to1_match_1_inputs<opset10::HSwish>},
{"aten::hardswish_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::HSwish>>},

View File

@ -0,0 +1,41 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from pytorch_layer_test_class import PytorchLayerTest
class TestGridSampler(PytorchLayerTest):
def _prepare_input(self, h_in, w_in, h_out, w_out):
import numpy as np
return (np.random.randn(1, 3, h_in, w_in).astype(np.float32), np.random.randn(1, h_out, w_out, 2).astype(np.float32))
def create_model(self, mode, padding_mode, align_corners):
import torch
import torch.nn.functional as F
class aten_grid_sampler(torch.nn.Module):
def __init__(self, mode, padding_mode, align_corners):
super(aten_grid_sampler, self).__init__()
self.mode = mode
self.padding_mode = padding_mode
self.align_corners = align_corners
def forward(self, input, grid):
return F.grid_sample(input, grid, self.mode, self.padding_mode, self.align_corners)
ref_net = None
return aten_grid_sampler(mode, padding_mode, align_corners), ref_net, "aten::grid_sampler"
@pytest.mark.parametrize(["h_in", "w_in", "h_out", "w_out"], ([10, 10, 5, 5], [10, 15, 3, 5]))
@pytest.mark.parametrize("mode", ["bilinear", "nearest", "bicubic"])
@pytest.mark.parametrize("padding_mode", ["zeros", "border", "reflection"])
@pytest.mark.parametrize("align_corners", [True, False, None])
@pytest.mark.nightly
@pytest.mark.precommit
def test_grid_sampler(self, h_in, w_in, h_out, w_out, mode, padding_mode, align_corners, ie_device, precision, ir_version):
self._test(*self.create_model(mode, padding_mode, align_corners), ie_device, precision, ir_version, kwargs_to_prepare_input={
"h_in": h_in, "w_in": w_in, "h_out": h_out, "w_out": w_out
})