[PT FE] Add torchvision::deform_conv2d translation (#16450)
* Initial commit
* Cleanup
* Improve tests
* Make NodeContext const
parent bb9de29062
commit 7d16ee1835
83  src/frontends/pytorch/src/op/deform_conv.cpp  Normal file
@@ -0,0 +1,83 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/deformable_convolution.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_deform_conv(const NodeContext& context) {
    // torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset,
    // Tensor mask, Tensor bias, int64_t stride_h, int64_t stride_w,
    // int64_t pad_h, int64_t pad_w, int64_t dilation_h, int64_t dilation_w,
    // int64_t n_weight_grps, int64_t n_offset_grps, bool use_mask) -> Tensor
    num_inputs_check(context, 14, 14);
    auto pt_input = context.get_input(0);
    auto pt_weight = context.get_input(1);
    auto pt_offset = context.get_input(2);
    auto pt_mask = context.get_input(3);

    int32_t pt_stride_h = context.const_input<int32_t>(5);
    int32_t pt_stride_w = context.const_input<int32_t>(6);
    auto strides = Strides({(size_t)pt_stride_h, (size_t)pt_stride_w});

    int32_t pt_pad_h = context.const_input<int32_t>(7);
    int32_t pt_pad_w = context.const_input<int32_t>(8);
    auto pads = CoordinateDiff({pt_pad_h, pt_pad_w});

    int32_t pt_dilation_h = context.const_input<int32_t>(9);
    int32_t pt_dilation_w = context.const_input<int32_t>(10);
    auto dilations = Strides({(size_t)pt_dilation_h, (size_t)pt_dilation_w});

    int32_t pt_n_weight_grps = context.const_input<int32_t>(11);
    int32_t pt_n_offset_grps = context.const_input<int32_t>(12);
    bool pt_use_mask = context.const_input<bool>(13);

    std::shared_ptr<ov::Node> deformable_convolution;
    if (!pt_use_mask) {
        deformable_convolution = context.mark_node(std::make_shared<v8::DeformableConvolution>(pt_input,
                                                                                               pt_offset,
                                                                                               pt_weight,
                                                                                               strides,
                                                                                               pads,
                                                                                               pads,
                                                                                               dilations,
                                                                                               PadType::EXPLICIT,
                                                                                               pt_n_weight_grps,
                                                                                               pt_n_offset_grps,
                                                                                               true));
    } else {
        deformable_convolution = context.mark_node(std::make_shared<v8::DeformableConvolution>(pt_input,
                                                                                               pt_offset,
                                                                                               pt_weight,
                                                                                               pt_mask,
                                                                                               strides,
                                                                                               pads,
                                                                                               pads,
                                                                                               dilations,
                                                                                               PadType::EXPLICIT,
                                                                                               pt_n_weight_grps,
                                                                                               pt_n_offset_grps,
                                                                                               true));
    }

    if (!context.input_is_none(4)) {
        auto bias = context.get_input(4);
        bias = reshape_channelwise(context, bias, deformable_convolution);
        deformable_convolution = context.mark_node(std::make_shared<v1::Add>(deformable_convolution, bias));
    }
    return {context.mark_output(deformable_convolution)};
}

}  // namespace op
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov
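For reference, a minimal Python sketch (not part of this diff) of the eager torchvision call whose traced graph produces the 14-input torchvision::deform_conv2d node that translate_deform_conv consumes. The shape conventions follow torchvision (2 * kH * kW offset channels and kH * kW mask channels per offset group); all tensor names and sizes here are illustrative only.

# Illustrative sketch only: the eager call that, when traced, yields the
# 14-input torchvision::deform_conv2d node handled by translate_deform_conv.
import torch
from torchvision.ops import deform_conv2d

x = torch.rand(1, 8, 16, 16)               # input: N, C_in, H, W
weight = torch.rand(4, 8, 3, 3)            # C_out, C_in / n_weight_grps, kH, kW
offset = torch.rand(1, 2 * 3 * 3, 16, 16)  # 2 * kH * kW * n_offset_grps channels
mask = torch.rand(1, 3 * 3, 16, 16)        # kH * kW * n_offset_grps channels
bias = torch.rand(4)

y = deform_conv2d(x, offset, weight, bias=bias, stride=(1, 1),
                  padding=(1, 1), dilation=(1, 1), mask=mask)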
@@ -34,6 +34,7 @@ OP_CONVERTER(translate_convnd);
OP_CONVERTER(translate_convolution);
OP_CONVERTER(translate_convolution_mode);
OP_CONVERTER(translate_cumsum);
OP_CONVERTER(translate_deform_conv);
OP_CONVERTER(translate_dim);
OP_CONVERTER(translate_div);
OP_CONVERTER(translate_elu);
@@ -348,6 +349,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
        {"prim::requires_grad", op::return_false_scalar},
        {"prim::PythonOp", op::translate_pythonop},
        {"prim::type", op::skip_node},  // Used with prim::device, pass PtFrameworkNode.
        {"torchvision::deform_conv2d", op::translate_deform_conv},
        {"torchvision::nms", op::translate_nms},
        {"torchvision::roi_align", op::translate_roi_align},
    };
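With the op registered in the table above, a traced model that uses torchvision's DeformConv2d should pass through the PyTorch frontend. A hedged usage sketch follows, assuming the openvino.convert_model entry point (the exact conversion API may differ between OpenVINO releases); the model and shapes are illustrative, not part of this change.

# Hypothetical end-to-end sketch, not part of this diff.
import torch
import openvino as ov  # assumes a build that exposes openvino.convert_model
from torchvision.ops import DeformConv2d

class SmallModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = DeformConv2d(8, 4, kernel_size=3, padding=1)

    def forward(self, x, offset):
        # DeformConv2d.forward dispatches to torchvision::deform_conv2d internally.
        return self.conv(x, offset)

example = (torch.rand(1, 8, 16, 16), torch.rand(1, 2 * 3 * 3, 16, 16))
ov_model = ov.convert_model(SmallModel(), example_input=example)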
189  tests/layer_tests/pytorch_tests/test_deformable_convolution.py  Normal file
@@ -0,0 +1,189 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from copy import deepcopy

import numpy as np
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest
from torchvision.ops import deform_conv2d


def xfail_106712(test_param):
    return pytest.param(
        test_param,
        marks=pytest.mark.xfail(
            reason="Depending on number of groups and number of output channels, deformable convolution may return incorrect results. Ticket 106712"
        ),
    )


params = [
    {
        "weights_shape": [64, 64, 3, 3],
        "offset_shape": [1, 18, 64, 64],
        "stride": (1, 1),
        "padding": (1, 1),
        "dilation": (1, 1),
    },
    {
        "weights_shape": [64, 64, 3, 3],
        "offset_shape": [1, 18, 62, 64],
        "stride": (1, 1),
        "padding": (1, 1),
        "dilation": (2, 1),
    },
    {
        "weights_shape": [64, 64, 3, 3],
        "offset_shape": [1, 18, 66, 64],
        "stride": (1, 1),
        "padding": (2, 1),
        "dilation": (1, 1),
    },
    {
        "weights_shape": [64, 64, 3, 3],
        "offset_shape": [1, 18, 32, 64],
        "stride": (2, 1),
        "padding": (1, 1),
        "dilation": (1, 1),
    },
    {
        "weights_shape": [64, 64, 3, 3],
        "offset_shape": [1, 18, 62, 62],
        "stride": (1, 1),
        "padding": (0, 0),
        "dilation": (1, 1),
    },
    {
        "weights_shape": [64, 64, 3, 3],
        "offset_shape": [1, 18, 66, 66],
        "stride": (1, 1),
        "padding": (2, 2),
        "dilation": (1, 1),
    },
    xfail_106712(
        {
            "weights_shape": [64, 16, 3, 3],
            "offset_shape": [1, 18, 64, 64],
            "stride": (1, 1),
            "padding": (1, 1),
            "dilation": (1, 1),
        }
    ),
    {
        "weights_shape": [60, 16, 3, 3],
        "offset_shape": [1, 18, 64, 64],
        "stride": (1, 1),
        "padding": (1, 1),
        "dilation": (1, 1),
    },
    {
        "weights_shape": [64, 1, 3, 3],
        "offset_shape": [1, 18, 64, 64],
        "stride": (1, 1),
        "padding": (1, 1),
        "dilation": (1, 1),
    },
    {
        "weights_shape": [64, 64, 3, 3],
        "offset_shape": [1, 36, 64, 64],
        "stride": (1, 1),
        "padding": (1, 1),
        "dilation": (1, 1),
    },
    xfail_106712(
        {
            "weights_shape": [64, 32, 3, 3],
            "offset_shape": [1, 36, 68, 68],
            "stride": (1, 1),
            "padding": (3, 3),
            "dilation": (1, 1),
        }
    ),
    {
        "weights_shape": [62, 32, 3, 3],
        "offset_shape": [1, 36, 68, 68],
        "stride": (1, 1),
        "padding": (3, 3),
        "dilation": (1, 1),
    },
    {
        "weights_shape": [2, 32, 3, 3],
        "offset_shape": [1, 36, 68, 68],
        "stride": (1, 1),
        "padding": (3, 3),
        "dilation": (1, 1),
    },
    {
        "weights_shape": [1, 64, 3, 3],
        "offset_shape": [1, 18, 68, 68],
        "stride": (1, 1),
        "padding": (3, 3),
        "dilation": (1, 1),
    },
]


class TestDeformableConvolution(PytorchLayerTest):
    def _prepare_input(self):
        return (np.random.rand(1, 64, 64, 64).astype(np.float32),)

    def create_model(
        self,
        offset_shape,
        weights_shape,
        stride,
        padding,
        dilation,
        bias,
        mask,
        mask_shape=None,
        bias_shape=None,
    ):
        class aten_deform_convolution(torch.nn.Module):
            def __init__(self):
                super(aten_deform_convolution, self).__init__()
                self.weight = torch.rand(weights_shape)
                self.offset = torch.rand(offset_shape)
                if mask_shape is None:
                    self.mask_shape = deepcopy(offset_shape)
                    self.mask_shape[1] = self.mask_shape[1] // 2
                else:
                    self.mask_shape = mask_shape
                if mask:
                    self.mask = torch.rand(self.mask_shape)
                else:
                    self.mask = None
                self.stride = stride
                self.padding = padding
                self.dilation = dilation
                self.bias_shape = bias_shape
                if self.bias_shape is None:
                    self.bias_shape = weights_shape[0]
                self.bias = torch.rand(self.bias_shape) if bias else None

            def forward(self, x):
                return deform_conv2d(
                    x,
                    self.offset,
                    self.weight,
                    bias=self.bias,
                    mask=self.mask,
                    stride=self.stride,
                    dilation=self.dilation,
                    padding=self.padding,
                )

        ref_net = None
        return aten_deform_convolution(), ref_net, "torchvision::deform_conv2d"

    @pytest.mark.parametrize("params", params)
    @pytest.mark.parametrize("bias", [True, False])
    @pytest.mark.parametrize("mask", [True, False])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_deformable_convolution2d(self, params, bias, mask, ie_device, precision, ir_version):
        self._test(
            *self.create_model(**params, bias=bias, mask=mask), ie_device, precision, ir_version, trace_model=True
        )