[PT FE]: support conv transpose (#15191)

* [PT FE]: support conv transpose

* apply comments
This commit is contained in:
Ekaterina Aidova 2023-01-22 11:03:54 +04:00 committed by GitHub
parent ed6282935b
commit 2ec116f592
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 265 additions and 1 deletions

View File

@ -0,0 +1,64 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/convolution.hpp"
#include "openvino/op/group_conv.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
OutputVector translate_conv_transposend(NodeContext& context) {
auto num_inputs = context.get_input_size();
FRONT_END_OP_CONVERSION_CHECK(num_inputs == 8, "Unsupported number of inputs: ", num_inputs);
auto strides = context.const_input<Strides>(3);
// PyTorch support only symmetric padding, padding sizes are the same for begins and ends for each dimension
auto pads = context.const_input<CoordinateDiff>(4);
auto output_padding = context.const_input<CoordinateDiff>(5);
auto pad_type = ov::op::PadType::EXPLICIT;
auto dilations = context.const_input<Strides>(7);
auto groups = context.const_input<int64_t>(6);
FRONT_END_OP_CONVERSION_CHECK(groups > 0, "Number of groups for convolution_transpose should be >= 1");
std::shared_ptr<ov::Node> conv;
if (groups == 1) {
conv = std::make_shared<ov::op::v1::ConvolutionBackpropData>(context.get_input(0),
context.get_input(1),
strides,
pads,
pads,
dilations,
pad_type,
output_padding);
} else {
conv = std::make_shared<ov::op::v1::GroupConvolutionBackpropData>(
context.get_input(0),
reshape_kernel_for_group(context, context.get_input(0), context.get_input(1), groups),
strides,
pads,
pads,
dilations,
pad_type,
output_padding);
}
if (!context.input_is_none(2)) {
auto bias = context.get_input(2);
auto bias_rank = bias.get_partial_shape().rank();
if (bias_rank == 1) {
bias = reshape_conv_bias(context, bias, conv);
}
conv = context.mark_node(std::make_shared<ov::op::v1::Add>(conv, bias));
}
return {conv};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -26,14 +26,15 @@ OP_CONVERTER(translate_batch_norm);
OP_CONVERTER(translate_clamp);
OP_CONVERTER(translate_constant);
OP_CONVERTER(translate_convnd);
OP_CONVERTER(translate_conv_transposend);
OP_CONVERTER(translate_convolution);
OP_CONVERTER(translate_convolution_mode);
OP_CONVERTER(translate_dim);
OP_CONVERTER(translate_div);
OP_CONVERTER(translate_elu);
OP_CONVERTER(translate_embedding);
OP_CONVERTER(translate_expand);
OP_CONVERTER(translate_expand_as);
OP_CONVERTER(translate_embedding);
OP_CONVERTER(translate_flatten);
OP_CONVERTER(translate_floordiv);
OP_CONVERTER(translate_floor_divide);
@ -144,6 +145,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::conv1d", op::translate_convnd},
{"aten::conv2d", op::translate_convnd},
{"aten::conv3d", op::translate_convnd},
{"aten::conv_transpose1d", op::translate_conv_transposend},
{"aten::conv_transpose2d", op::translate_conv_transposend},
{"aten::conv_transpose3d", op::translate_conv_transposend},
{"aten::convolution", op::translate_convolution},
{"aten::cos", op::translate_1to1_match_1_inputs<opset10::Cos>},
{"aten::cos_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Cos>>},

View File

@ -0,0 +1,196 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from pytorch_layer_test_class import PytorchLayerTest
class TestConvTranspose2D(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randn(1, 3, 10, 10).astype(np.float32),)
def create_model(self, weights_shape, strides, pads, dilations, groups, bias, output_padding):
import torch
import torch.nn.functional as F
class aten_conv_transpose2d(torch.nn.Module):
def __init__(self):
super(aten_conv_transpose2d, self).__init__()
self.weight = torch.randn(weights_shape)
self.bias = None
if bias:
self.bias = torch.randn(groups)
self.strides = strides
self.pads = pads
self.dilations = dilations
self.groups = groups
self.output_padding = output_padding
def forward(self, x):
return F.conv_transpose2d(x, weight=self.weight, bias=self.bias, stride=self.strides, padding=self.pads, output_padding=self.output_padding, dilation=self.dilations, groups=self.groups)
ref_net = None
return aten_conv_transpose2d(), ref_net, "aten::conv_transpose2d"
@pytest.mark.parametrize("params",
[{'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [0, 0],
'dilations': [2, 2], 'groups': 1, 'output_padding': [0, 0]},
{'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [
0, 0], 'dilations': [1, 1], 'groups': 3, 'output_padding': [0, 0]},
{'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [
1, 1], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0]},
{'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [
3, 1], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0]},
{'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [
1, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0]},
{'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [
1, 0], 'dilations': [1, 1], 'groups': 3, 'output_padding': [0, 0]},
{'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [
1, 0], 'dilations': [2, 2], 'groups': 3, 'output_padding': [0, 0]},
{'weights_shape': [3, 1, 1, 1], 'strides': [2, 1], 'pads': [
1, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0]},
{'weights_shape': [3, 1, 1, 1], 'strides': [2, 2], 'pads': [
0, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0]},
{'weights_shape': [3, 1, 1, 1], 'strides': [2, 2], 'pads': [
0, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0]},
{'weights_shape': [3, 1, 1, 1], 'strides': [2, 2], 'pads': [
1, 1], 'dilations': [2, 2], 'groups': 1, 'output_padding': [1, 1]},
])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.nightly
@pytest.mark.precommit
def test_conv_transpose2d(self, params, bias, ie_device, precision, ir_version):
self._test(*self.create_model(**params, bias=bias),
ie_device, precision, ir_version, dynamic_shapes=params['groups'] == 1)
class TestConvTranspose1D(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randn(1, 3, 10).astype(np.float32),)
def create_model(self, weights_shape, strides, pads, dilations, groups, bias, output_padding):
import torch
import torch.nn.functional as F
class aten_conv_transpose1d(torch.nn.Module):
def __init__(self):
super(aten_conv_transpose1d, self).__init__()
self.weight = torch.randn(weights_shape)
self.bias = None
if bias:
self.bias = torch.randn(groups)
self.strides = strides
self.pads = pads
self.dilations = dilations
self.groups = groups
self.output_padding = output_padding
def forward(self, x):
return F.conv_transpose1d(
x,
weight=self.weight,
bias=self.bias,
stride=self.strides,
padding=self.pads,
output_padding=self.output_padding,
dilation=self.dilations,
groups=self.groups
)
ref_net = None
return aten_conv_transpose1d(), ref_net, "aten::conv_transpose1d"
@pytest.mark.parametrize("params",
[{'weights_shape': [3, 1, 1], 'strides': 1, 'pads': 0, 'dilations': 1, 'groups': 1, 'output_padding': 0},
{'weights_shape': [3, 1, 1], 'strides': 1, 'pads': 0,
'dilations': 1, 'groups': 3, 'output_padding': 0},
{'weights_shape': [3, 1, 1], 'strides': 1, 'pads': 1,
'dilations': 1, 'groups': 1, 'output_padding': 0},
{'weights_shape': [3, 1, 1], 'strides': 1, 'pads': 1,
'dilations': 1, 'groups': 3, 'output_padding': 0},
{'weights_shape': [3, 1, 1], 'strides': 1, 'pads': 3,
'dilations': 2, 'groups': 1, 'output_padding': 1},
{'weights_shape': [3, 1, 1], 'strides': 1, 'pads': 3,
'dilations': 2, 'groups': 3, 'output_padding': 1},
])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.nightly
@pytest.mark.precommit
def test_conv_transpose1d(self, params, bias, ie_device, precision, ir_version):
self._test(*self.create_model(**params, bias=bias),
ie_device, precision, ir_version, dynamic_shapes=params['groups'] == 1)
class TestConvTranspose3D(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randn(1, 3, 10, 10, 4).astype(np.float32),)
def create_model(self, weights_shape, strides, pads, dilations, groups, bias, output_padding):
import torch
import torch.nn.functional as F
class aten_conv_transpose3d(torch.nn.Module):
def __init__(self):
super(aten_conv_transpose3d, self).__init__()
self.weight = torch.randn(weights_shape)
self.bias = None
if bias:
self.bias = torch.randn(groups)
self.strides = strides
self.pads = pads
self.dilations = dilations
self.groups = groups
self.output_padding = output_padding
def forward(self, x):
return F.conv_transpose3d(
x,
weight=self.weight,
bias=self.bias,
stride=self.strides,
padding=self.pads,
output_padding=self.output_padding,
dilation=self.dilations,
groups=self.groups
)
ref_net = None
return aten_conv_transpose3d(), ref_net, "aten::conv_transpose3d"
@pytest.mark.parametrize("params",
[{'weights_shape': [3, 1, 1, 1, 1], 'strides': [1, 1, 1], 'pads': [0, 0, 0],
'dilations': [2, 2, 2], 'groups': 1, 'output_padding': [0, 0, 0]},
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [1, 1, 1], 'pads': [
0, 0, 0], 'dilations': [1, 1, 1], 'groups': 3, 'output_padding': [0, 0, 0]},
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [1, 1, 1], 'pads': [
1, 1, 0], 'dilations': [1, 1, 2], 'groups': 1, 'output_padding': [0, 0, 1]},
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [1, 1, 2], 'pads': [
3, 1, 0], 'dilations': [4, 4, 4], 'groups': 1, 'output_padding': [1, 1, 1]},
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [1, 1, 1], 'pads': [
1, 0, 1], 'dilations': [1, 2, 1], 'groups': 1, 'output_padding': [0, 1, 0]},
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [1, 1, 1], 'pads': [
1, 0, 0], 'dilations': [1, 1, 2], 'groups': 3, 'output_padding': [0, 0, 0]},
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [1, 1, 1], 'pads': [
1, 0, 0], 'dilations': [2, 2, 1], 'groups': 3, 'output_padding': [0, 0, 0]},
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [2, 1, 2], 'pads': [
1, 0, 0], 'dilations': [3, 4, 2], 'groups': 1, 'output_padding': [2, 0, 0]},
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [2, 2, 2], 'pads': [
0, 0, 0], 'dilations': [1, 1, 1], 'groups': 1, 'output_padding': [0, 0, 0]},
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [2, 2, 2], 'pads': [
0, 0, 0], 'dilations': [1, 1, 1], 'groups': 1, 'output_padding': [0, 0, 0]},
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [2, 2, 2], 'pads': [
1, 1, 2], 'dilations': [2, 2, 2], 'groups': 1, 'output_padding': [1, 1, 0]},
])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.nightly
@pytest.mark.precommit
def test_conv_transpose3d(self, params, bias, ie_device, precision, ir_version):
self._test(*self.create_model(**params, bias=bias),
ie_device, precision, ir_version, dynamic_shapes=params['groups'] == 1)