[PT FE]: support conv transpose (#15191)
* [PT FE]: support conv transpose * apply comments
This commit is contained in:
parent
ed6282935b
commit
2ec116f592
64
src/frontends/pytorch/src/op/conv_transposend.cpp
Normal file
64
src/frontends/pytorch/src/op/conv_transposend.cpp
Normal file
@ -0,0 +1,64 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/convolution.hpp"
|
||||
#include "openvino/op/group_conv.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_conv_transposend(NodeContext& context) {
|
||||
auto num_inputs = context.get_input_size();
|
||||
FRONT_END_OP_CONVERSION_CHECK(num_inputs == 8, "Unsupported number of inputs: ", num_inputs);
|
||||
auto strides = context.const_input<Strides>(3);
|
||||
// PyTorch support only symmetric padding, padding sizes are the same for begins and ends for each dimension
|
||||
auto pads = context.const_input<CoordinateDiff>(4);
|
||||
auto output_padding = context.const_input<CoordinateDiff>(5);
|
||||
auto pad_type = ov::op::PadType::EXPLICIT;
|
||||
auto dilations = context.const_input<Strides>(7);
|
||||
auto groups = context.const_input<int64_t>(6);
|
||||
FRONT_END_OP_CONVERSION_CHECK(groups > 0, "Number of groups for convolution_transpose should be >= 1");
|
||||
|
||||
std::shared_ptr<ov::Node> conv;
|
||||
if (groups == 1) {
|
||||
conv = std::make_shared<ov::op::v1::ConvolutionBackpropData>(context.get_input(0),
|
||||
context.get_input(1),
|
||||
strides,
|
||||
pads,
|
||||
pads,
|
||||
dilations,
|
||||
pad_type,
|
||||
output_padding);
|
||||
} else {
|
||||
conv = std::make_shared<ov::op::v1::GroupConvolutionBackpropData>(
|
||||
context.get_input(0),
|
||||
reshape_kernel_for_group(context, context.get_input(0), context.get_input(1), groups),
|
||||
strides,
|
||||
pads,
|
||||
pads,
|
||||
dilations,
|
||||
pad_type,
|
||||
output_padding);
|
||||
}
|
||||
if (!context.input_is_none(2)) {
|
||||
auto bias = context.get_input(2);
|
||||
auto bias_rank = bias.get_partial_shape().rank();
|
||||
if (bias_rank == 1) {
|
||||
bias = reshape_conv_bias(context, bias, conv);
|
||||
}
|
||||
conv = context.mark_node(std::make_shared<ov::op::v1::Add>(conv, bias));
|
||||
}
|
||||
|
||||
return {conv};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -26,14 +26,15 @@ OP_CONVERTER(translate_batch_norm);
|
||||
OP_CONVERTER(translate_clamp);
|
||||
OP_CONVERTER(translate_constant);
|
||||
OP_CONVERTER(translate_convnd);
|
||||
OP_CONVERTER(translate_conv_transposend);
|
||||
OP_CONVERTER(translate_convolution);
|
||||
OP_CONVERTER(translate_convolution_mode);
|
||||
OP_CONVERTER(translate_dim);
|
||||
OP_CONVERTER(translate_div);
|
||||
OP_CONVERTER(translate_elu);
|
||||
OP_CONVERTER(translate_embedding);
|
||||
OP_CONVERTER(translate_expand);
|
||||
OP_CONVERTER(translate_expand_as);
|
||||
OP_CONVERTER(translate_embedding);
|
||||
OP_CONVERTER(translate_flatten);
|
||||
OP_CONVERTER(translate_floordiv);
|
||||
OP_CONVERTER(translate_floor_divide);
|
||||
@ -144,6 +145,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::conv1d", op::translate_convnd},
|
||||
{"aten::conv2d", op::translate_convnd},
|
||||
{"aten::conv3d", op::translate_convnd},
|
||||
{"aten::conv_transpose1d", op::translate_conv_transposend},
|
||||
{"aten::conv_transpose2d", op::translate_conv_transposend},
|
||||
{"aten::conv_transpose3d", op::translate_conv_transposend},
|
||||
{"aten::convolution", op::translate_convolution},
|
||||
{"aten::cos", op::translate_1to1_match_1_inputs<opset10::Cos>},
|
||||
{"aten::cos_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Cos>>},
|
||||
|
196
tests/layer_tests/pytorch_tests/test_conv_transposend.py
Normal file
196
tests/layer_tests/pytorch_tests/test_conv_transposend.py
Normal file
@ -0,0 +1,196 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestConvTranspose2D(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
import numpy as np
|
||||
return (np.random.randn(1, 3, 10, 10).astype(np.float32),)
|
||||
|
||||
def create_model(self, weights_shape, strides, pads, dilations, groups, bias, output_padding):
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class aten_conv_transpose2d(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(aten_conv_transpose2d, self).__init__()
|
||||
self.weight = torch.randn(weights_shape)
|
||||
self.bias = None
|
||||
if bias:
|
||||
self.bias = torch.randn(groups)
|
||||
self.strides = strides
|
||||
self.pads = pads
|
||||
self.dilations = dilations
|
||||
self.groups = groups
|
||||
self.output_padding = output_padding
|
||||
|
||||
def forward(self, x):
|
||||
return F.conv_transpose2d(x, weight=self.weight, bias=self.bias, stride=self.strides, padding=self.pads, output_padding=self.output_padding, dilation=self.dilations, groups=self.groups)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_conv_transpose2d(), ref_net, "aten::conv_transpose2d"
|
||||
|
||||
@pytest.mark.parametrize("params",
|
||||
[{'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [0, 0],
|
||||
'dilations': [2, 2], 'groups': 1, 'output_padding': [0, 0]},
|
||||
{'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [
|
||||
0, 0], 'dilations': [1, 1], 'groups': 3, 'output_padding': [0, 0]},
|
||||
{'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [
|
||||
1, 1], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0]},
|
||||
{'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [
|
||||
3, 1], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0]},
|
||||
{'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [
|
||||
1, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0]},
|
||||
{'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [
|
||||
1, 0], 'dilations': [1, 1], 'groups': 3, 'output_padding': [0, 0]},
|
||||
{'weights_shape': [3, 1, 1, 1], 'strides': [1, 1], 'pads': [
|
||||
1, 0], 'dilations': [2, 2], 'groups': 3, 'output_padding': [0, 0]},
|
||||
{'weights_shape': [3, 1, 1, 1], 'strides': [2, 1], 'pads': [
|
||||
1, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0]},
|
||||
{'weights_shape': [3, 1, 1, 1], 'strides': [2, 2], 'pads': [
|
||||
0, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0]},
|
||||
{'weights_shape': [3, 1, 1, 1], 'strides': [2, 2], 'pads': [
|
||||
0, 0], 'dilations': [1, 1], 'groups': 1, 'output_padding': [0, 0]},
|
||||
{'weights_shape': [3, 1, 1, 1], 'strides': [2, 2], 'pads': [
|
||||
1, 1], 'dilations': [2, 2], 'groups': 1, 'output_padding': [1, 1]},
|
||||
])
|
||||
@pytest.mark.parametrize("bias", [True, False])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_conv_transpose2d(self, params, bias, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(**params, bias=bias),
|
||||
ie_device, precision, ir_version, dynamic_shapes=params['groups'] == 1)
|
||||
|
||||
|
||||
class TestConvTranspose1D(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
import numpy as np
|
||||
return (np.random.randn(1, 3, 10).astype(np.float32),)
|
||||
|
||||
def create_model(self, weights_shape, strides, pads, dilations, groups, bias, output_padding):
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class aten_conv_transpose1d(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(aten_conv_transpose1d, self).__init__()
|
||||
self.weight = torch.randn(weights_shape)
|
||||
self.bias = None
|
||||
if bias:
|
||||
self.bias = torch.randn(groups)
|
||||
self.strides = strides
|
||||
self.pads = pads
|
||||
self.dilations = dilations
|
||||
self.groups = groups
|
||||
self.output_padding = output_padding
|
||||
|
||||
def forward(self, x):
|
||||
return F.conv_transpose1d(
|
||||
x,
|
||||
weight=self.weight,
|
||||
bias=self.bias,
|
||||
stride=self.strides,
|
||||
padding=self.pads,
|
||||
output_padding=self.output_padding,
|
||||
dilation=self.dilations,
|
||||
groups=self.groups
|
||||
)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_conv_transpose1d(), ref_net, "aten::conv_transpose1d"
|
||||
|
||||
@pytest.mark.parametrize("params",
|
||||
[{'weights_shape': [3, 1, 1], 'strides': 1, 'pads': 0, 'dilations': 1, 'groups': 1, 'output_padding': 0},
|
||||
{'weights_shape': [3, 1, 1], 'strides': 1, 'pads': 0,
|
||||
'dilations': 1, 'groups': 3, 'output_padding': 0},
|
||||
{'weights_shape': [3, 1, 1], 'strides': 1, 'pads': 1,
|
||||
'dilations': 1, 'groups': 1, 'output_padding': 0},
|
||||
{'weights_shape': [3, 1, 1], 'strides': 1, 'pads': 1,
|
||||
'dilations': 1, 'groups': 3, 'output_padding': 0},
|
||||
{'weights_shape': [3, 1, 1], 'strides': 1, 'pads': 3,
|
||||
'dilations': 2, 'groups': 1, 'output_padding': 1},
|
||||
{'weights_shape': [3, 1, 1], 'strides': 1, 'pads': 3,
|
||||
'dilations': 2, 'groups': 3, 'output_padding': 1},
|
||||
])
|
||||
@pytest.mark.parametrize("bias", [True, False])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_conv_transpose1d(self, params, bias, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(**params, bias=bias),
|
||||
ie_device, precision, ir_version, dynamic_shapes=params['groups'] == 1)
|
||||
|
||||
|
||||
class TestConvTranspose3D(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
import numpy as np
|
||||
return (np.random.randn(1, 3, 10, 10, 4).astype(np.float32),)
|
||||
|
||||
def create_model(self, weights_shape, strides, pads, dilations, groups, bias, output_padding):
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class aten_conv_transpose3d(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(aten_conv_transpose3d, self).__init__()
|
||||
self.weight = torch.randn(weights_shape)
|
||||
self.bias = None
|
||||
if bias:
|
||||
self.bias = torch.randn(groups)
|
||||
self.strides = strides
|
||||
self.pads = pads
|
||||
self.dilations = dilations
|
||||
self.groups = groups
|
||||
self.output_padding = output_padding
|
||||
|
||||
def forward(self, x):
|
||||
return F.conv_transpose3d(
|
||||
x,
|
||||
weight=self.weight,
|
||||
bias=self.bias,
|
||||
stride=self.strides,
|
||||
padding=self.pads,
|
||||
output_padding=self.output_padding,
|
||||
dilation=self.dilations,
|
||||
groups=self.groups
|
||||
)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_conv_transpose3d(), ref_net, "aten::conv_transpose3d"
|
||||
|
||||
@pytest.mark.parametrize("params",
|
||||
[{'weights_shape': [3, 1, 1, 1, 1], 'strides': [1, 1, 1], 'pads': [0, 0, 0],
|
||||
'dilations': [2, 2, 2], 'groups': 1, 'output_padding': [0, 0, 0]},
|
||||
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [1, 1, 1], 'pads': [
|
||||
0, 0, 0], 'dilations': [1, 1, 1], 'groups': 3, 'output_padding': [0, 0, 0]},
|
||||
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [1, 1, 1], 'pads': [
|
||||
1, 1, 0], 'dilations': [1, 1, 2], 'groups': 1, 'output_padding': [0, 0, 1]},
|
||||
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [1, 1, 2], 'pads': [
|
||||
3, 1, 0], 'dilations': [4, 4, 4], 'groups': 1, 'output_padding': [1, 1, 1]},
|
||||
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [1, 1, 1], 'pads': [
|
||||
1, 0, 1], 'dilations': [1, 2, 1], 'groups': 1, 'output_padding': [0, 1, 0]},
|
||||
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [1, 1, 1], 'pads': [
|
||||
1, 0, 0], 'dilations': [1, 1, 2], 'groups': 3, 'output_padding': [0, 0, 0]},
|
||||
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [1, 1, 1], 'pads': [
|
||||
1, 0, 0], 'dilations': [2, 2, 1], 'groups': 3, 'output_padding': [0, 0, 0]},
|
||||
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [2, 1, 2], 'pads': [
|
||||
1, 0, 0], 'dilations': [3, 4, 2], 'groups': 1, 'output_padding': [2, 0, 0]},
|
||||
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [2, 2, 2], 'pads': [
|
||||
0, 0, 0], 'dilations': [1, 1, 1], 'groups': 1, 'output_padding': [0, 0, 0]},
|
||||
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [2, 2, 2], 'pads': [
|
||||
0, 0, 0], 'dilations': [1, 1, 1], 'groups': 1, 'output_padding': [0, 0, 0]},
|
||||
{'weights_shape': [3, 1, 1, 1, 1], 'strides': [2, 2, 2], 'pads': [
|
||||
1, 1, 2], 'dilations': [2, 2, 2], 'groups': 1, 'output_padding': [1, 1, 0]},
|
||||
])
|
||||
@pytest.mark.parametrize("bias", [True, False])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_conv_transpose3d(self, params, bias, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(**params, bias=bias),
|
||||
ie_device, precision, ir_version, dynamic_shapes=params['groups'] == 1)
|
Loading…
Reference in New Issue
Block a user