[PT FE] Add torchvision::deform_conv2d translation (#16450)

* Initial commit

* Initial commit

* Cleanup

* Improve tests

* Make NodeContext const
This commit is contained in:
Mateusz Mikolajczyk 2023-03-27 11:13:32 +02:00 committed by GitHub
parent bb9de29062
commit 7d16ee1835
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 274 additions and 0 deletions

View File

@ -0,0 +1,83 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/deformable_convolution.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
using namespace ov::op;
OutputVector translate_deform_conv(const NodeContext& context) {
// torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset,
// Tensor mask, Tensor bias, int64_t stride_h, int64_t stride_w,
// int64_t pad_h, int64_t pad_w, int64_t dilation_h, int64_t dilation_w,
// int64_t n_weight_grps, int64_t n_offset_grps, bool use_mask) -> Tensor
num_inputs_check(context, 14, 14);
auto pt_input = context.get_input(0);
auto pt_weight = context.get_input(1);
auto pt_offset = context.get_input(2);
auto pt_mask = context.get_input(3);
int32_t pt_stride_h = context.const_input<int32_t>(5);
int32_t pt_stride_w = context.const_input<int32_t>(6);
auto strides = Strides({(size_t)pt_stride_h, (size_t)pt_stride_w});
int32_t pt_pad_h = context.const_input<int32_t>(7);
int32_t pt_pad_w = context.const_input<int32_t>(8);
auto pads = CoordinateDiff({pt_pad_h, pt_pad_w});
int32_t pt_dilation_h = context.const_input<int32_t>(9);
int32_t pt_dilation_w = context.const_input<int32_t>(10);
auto dilations = Strides({(size_t)pt_dilation_h, (size_t)pt_dilation_w});
int32_t pt_n_weight_grps = context.const_input<int32_t>(11);
int32_t pt_n_offset_grps = context.const_input<int32_t>(12);
bool pt_use_mask = context.const_input<bool>(13);
std::shared_ptr<ov::Node> deformable_convolution;
if (!pt_use_mask) {
deformable_convolution = context.mark_node(std::make_shared<v8::DeformableConvolution>(pt_input,
pt_offset,
pt_weight,
strides,
pads,
pads,
dilations,
PadType::EXPLICIT,
pt_n_weight_grps,
pt_n_offset_grps,
true));
} else {
deformable_convolution = context.mark_node(std::make_shared<v8::DeformableConvolution>(pt_input,
pt_offset,
pt_weight,
pt_mask,
strides,
pads,
pads,
dilations,
PadType::EXPLICIT,
pt_n_weight_grps,
pt_n_offset_grps,
true));
}
if (!context.input_is_none(4)) {
auto bias = context.get_input(4);
bias = reshape_channelwise(context, bias, deformable_convolution);
deformable_convolution = context.mark_node(std::make_shared<v1::Add>(deformable_convolution, bias));
}
return {context.mark_output(deformable_convolution)};
}
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -34,6 +34,7 @@ OP_CONVERTER(translate_convnd);
OP_CONVERTER(translate_convolution);
OP_CONVERTER(translate_convolution_mode);
OP_CONVERTER(translate_cumsum);
OP_CONVERTER(translate_deform_conv);
OP_CONVERTER(translate_dim);
OP_CONVERTER(translate_div);
OP_CONVERTER(translate_elu);
@ -348,6 +349,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"prim::requires_grad", op::return_false_scalar},
{"prim::PythonOp", op::translate_pythonop},
{"prim::type", op::skip_node}, // Used with prim::device, pass PtFrameworkNode.
{"torchvision::deform_conv2d", op::translate_deform_conv},
{"torchvision::nms", op::translate_nms},
{"torchvision::roi_align", op::translate_roi_align},
};

View File

@ -0,0 +1,189 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from copy import deepcopy
import numpy as np
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest
from torchvision.ops import deform_conv2d
def xfail_106712(test_param):
return pytest.param(
test_param,
marks=pytest.mark.xfail(
reason="Depending on number of groups and number of output channels, deformable convolution may return incorrect reasults. Ticket 106712"
),
)
params = [
{
"weights_shape": [64, 64, 3, 3],
"offset_shape": [1, 18, 64, 64],
"stride": (1, 1),
"padding": (1, 1),
"dilation": (1, 1),
},
{
"weights_shape": [64, 64, 3, 3],
"offset_shape": [1, 18, 62, 64],
"stride": (1, 1),
"padding": (1, 1),
"dilation": (2, 1),
},
{
"weights_shape": [64, 64, 3, 3],
"offset_shape": [1, 18, 66, 64],
"stride": (1, 1),
"padding": (2, 1),
"dilation": (1, 1),
},
{
"weights_shape": [64, 64, 3, 3],
"offset_shape": [1, 18, 32, 64],
"stride": (2, 1),
"padding": (1, 1),
"dilation": (1, 1),
},
{
"weights_shape": [64, 64, 3, 3],
"offset_shape": [1, 18, 62, 62],
"stride": (1, 1),
"padding": (0, 0),
"dilation": (1, 1),
},
{
"weights_shape": [64, 64, 3, 3],
"offset_shape": [1, 18, 66, 66],
"stride": (1, 1),
"padding": (2, 2),
"dilation": (1, 1),
},
xfail_106712(
{
"weights_shape": [64, 16, 3, 3],
"offset_shape": [1, 18, 64, 64],
"stride": (1, 1),
"padding": (1, 1),
"dilation": (1, 1),
}
),
{
"weights_shape": [60, 16, 3, 3],
"offset_shape": [1, 18, 64, 64],
"stride": (1, 1),
"padding": (1, 1),
"dilation": (1, 1),
},
{
"weights_shape": [64, 1, 3, 3],
"offset_shape": [1, 18, 64, 64],
"stride": (1, 1),
"padding": (1, 1),
"dilation": (1, 1),
},
{
"weights_shape": [64, 64, 3, 3],
"offset_shape": [1, 36, 64, 64],
"stride": (1, 1),
"padding": (1, 1),
"dilation": (1, 1),
},
xfail_106712(
{
"weights_shape": [64, 32, 3, 3],
"offset_shape": [1, 36, 68, 68],
"stride": (1, 1),
"padding": (3, 3),
"dilation": (1, 1),
}
),
{
"weights_shape": [62, 32, 3, 3],
"offset_shape": [1, 36, 68, 68],
"stride": (1, 1),
"padding": (3, 3),
"dilation": (1, 1),
},
{
"weights_shape": [2, 32, 3, 3],
"offset_shape": [1, 36, 68, 68],
"stride": (1, 1),
"padding": (3, 3),
"dilation": (1, 1),
},
{
"weights_shape": [1, 64, 3, 3],
"offset_shape": [1, 18, 68, 68],
"stride": (1, 1),
"padding": (3, 3),
"dilation": (1, 1),
},
]
class TestDeformableConvolution(PytorchLayerTest):
def _prepare_input(self):
return (np.random.rand(1, 64, 64, 64).astype(np.float32),)
def create_model(
self,
offset_shape,
weights_shape,
stride,
padding,
dilation,
bias,
mask,
mask_shape=None,
bias_shape=None,
):
class aten_deform_convolution(torch.nn.Module):
def __init__(self):
super(aten_deform_convolution, self).__init__()
self.weight = torch.rand(weights_shape)
self.offset = torch.rand(offset_shape)
if mask_shape is None:
self.mask_shape = deepcopy(offset_shape)
self.mask_shape[1] = self.mask_shape[1] // 2
else:
self.mask_shape = mask_shape
if mask:
self.mask = torch.rand(self.mask_shape)
else:
self.mask = None
self.stride = stride
self.padding = padding
self.dilation = dilation
self.bias_shape = bias_shape
if self.bias_shape is None:
self.bias_shape = weights_shape[0]
self.bias = torch.rand(self.bias_shape) if bias else None
def forward(self, x):
return deform_conv2d(
x,
self.offset,
self.weight,
bias=self.bias,
mask=self.mask,
stride=self.stride,
dilation=self.dilation,
padding=self.padding,
)
ref_net = None
return aten_deform_convolution(), ref_net, "torchvision::deform_conv2d"
@pytest.mark.parametrize("params", params)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("mask", [True, False])
@pytest.mark.nightly
@pytest.mark.precommit
def test_deformable_convolution2d(self, params, bias, mask, ie_device, precision, ir_version):
self._test(
*self.create_model(**params, bias=bias, mask=mask), ie_device, precision, ir_version, trace_model=True
)