[PT FE]: support grid sampler (#15243)
This commit is contained in:
parent
7e3e0ff003
commit
758a0dea56
@ -26,9 +26,11 @@ bool op::v9::GridSample::visit_attributes(AttributeVisitor& visitor) {
|
||||
|
||||
void op::v9::GridSample::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v9_GridSample_validate_and_infer_types);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
get_input_element_type(1).is_real(),
|
||||
"The element type of the grid input tensor must be a floating point type.");
|
||||
if (!get_input_element_type(1).is_dynamic()) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
get_input_element_type(1).is_real(),
|
||||
"The element type of the grid input tensor must be a floating point type.");
|
||||
}
|
||||
|
||||
std::vector<PartialShape> out_shapes(1);
|
||||
shape_infer(this, {get_input_partial_shape(0), get_input_partial_shape(1)}, out_shapes);
|
||||
|
46
src/frontends/pytorch/src/op/grid_sampler.cpp
Normal file
46
src/frontends/pytorch/src/op/grid_sampler.cpp
Normal file
@ -0,0 +1,46 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/grid_sample.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_grid_sampler(NodeContext& context) {
|
||||
auto x = context.get_input(0);
|
||||
auto grid = context.get_input(1);
|
||||
ov::op::v9::GridSample::Attributes attrs{};
|
||||
const std::unordered_map<int64_t, ov::op::v9::GridSample::InterpolationMode> grid_sample_mode_map{
|
||||
{0, ov::op::v9::GridSample::InterpolationMode::BILINEAR},
|
||||
{1, ov::op::v9::GridSample::InterpolationMode::NEAREST},
|
||||
{2, ov::op::v9::GridSample::InterpolationMode::BICUBIC},
|
||||
};
|
||||
const std::unordered_map<int64_t, ov::op::v9::GridSample::PaddingMode> grid_sample_padding_mode_map{
|
||||
{0, ov::op::v9::GridSample::PaddingMode::ZEROS},
|
||||
{1, ov::op::v9::GridSample::PaddingMode::BORDER},
|
||||
{2, ov::op::v9::GridSample::PaddingMode::REFLECTION}};
|
||||
auto mode = context.const_input<int64_t>(2);
|
||||
FRONT_END_OP_CONVERSION_CHECK(grid_sample_mode_map.count(mode), "Unknown interpolation mode: ", mode);
|
||||
attrs.mode = grid_sample_mode_map.at(mode);
|
||||
auto padding_mode = context.const_input<int64_t>(3);
|
||||
FRONT_END_OP_CONVERSION_CHECK(grid_sample_padding_mode_map.count(padding_mode),
|
||||
"Unknown padding mode: ",
|
||||
padding_mode);
|
||||
attrs.padding_mode = grid_sample_padding_mode_map.at(padding_mode);
|
||||
bool align_corners = false;
|
||||
if (!context.input_is_none(4)) {
|
||||
align_corners = context.const_input<int64_t>(4);
|
||||
}
|
||||
attrs.align_corners = align_corners;
|
||||
|
||||
return {context.mark_node(std::make_shared<ov::op::v9::GridSample>(x, grid, attrs))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -45,6 +45,7 @@ OP_CONVERTER(translate_full_like);
|
||||
OP_CONVERTER(translate_gelu);
|
||||
OP_CONVERTER(translate_get_attr);
|
||||
OP_CONVERTER(translate_glu);
|
||||
OP_CONVERTER(translate_grid_sampler);
|
||||
OP_CONVERTER(translate_group_norm);
|
||||
OP_CONVERTER(translate_hardtanh);
|
||||
OP_CONVERTER(translate_if);
|
||||
@ -189,6 +190,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::group_norm", op::translate_group_norm},
|
||||
{"aten::ge", op::translate_1to1_match_2_inputs<opset10::GreaterEqual>},
|
||||
{"aten::gt", op::translate_1to1_match_2_inputs<opset10::Greater>},
|
||||
{"aten::grid_sampler", op::translate_grid_sampler},
|
||||
{"aten::hardsigmoid", op::translate_1to1_match_1_inputs<opset10::HSigmoid>},
|
||||
{"aten::hardswish", op::translate_1to1_match_1_inputs<opset10::HSwish>},
|
||||
{"aten::hardswish_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::HSwish>>},
|
||||
|
41
tests/layer_tests/pytorch_tests/test_grid_sampler.py
Normal file
41
tests/layer_tests/pytorch_tests/test_grid_sampler.py
Normal file
@ -0,0 +1,41 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestGridSampler(PytorchLayerTest):
|
||||
def _prepare_input(self, h_in, w_in, h_out, w_out):
|
||||
import numpy as np
|
||||
return (np.random.randn(1, 3, h_in, w_in).astype(np.float32), np.random.randn(1, h_out, w_out, 2).astype(np.float32))
|
||||
|
||||
def create_model(self, mode, padding_mode, align_corners):
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class aten_grid_sampler(torch.nn.Module):
|
||||
def __init__(self, mode, padding_mode, align_corners):
|
||||
super(aten_grid_sampler, self).__init__()
|
||||
self.mode = mode
|
||||
self.padding_mode = padding_mode
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, input, grid):
|
||||
return F.grid_sample(input, grid, self.mode, self.padding_mode, self.align_corners)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_grid_sampler(mode, padding_mode, align_corners), ref_net, "aten::grid_sampler"
|
||||
|
||||
@pytest.mark.parametrize(["h_in", "w_in", "h_out", "w_out"], ([10, 10, 5, 5], [10, 15, 3, 5]))
|
||||
@pytest.mark.parametrize("mode", ["bilinear", "nearest", "bicubic"])
|
||||
@pytest.mark.parametrize("padding_mode", ["zeros", "border", "reflection"])
|
||||
@pytest.mark.parametrize("align_corners", [True, False, None])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_grid_sampler(self, h_in, w_in, h_out, w_out, mode, padding_mode, align_corners, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(mode, padding_mode, align_corners), ie_device, precision, ir_version, kwargs_to_prepare_input={
|
||||
"h_in": h_in, "w_in": w_in, "h_out": h_out, "w_out": w_out
|
||||
})
|
Loading…
Reference in New Issue
Block a user