[PT FE] Fix aten::conv2d conversion issues related to bias (#17228)

* Allow for conversion

* Check for const

* Apply requested changes

* Improve tests
This commit is contained in:
Mateusz Mikolajczyk 2023-05-22 09:59:02 +02:00 committed by GitHub
parent 4ccc6e3034
commit 9b52a77531
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 65 additions and 1 deletions

View File

@ -51,6 +51,10 @@ OutputVector translate_convnd(const NodeContext& context) {
}
if (!context.input_is_none(2)) {
auto bias = context.get_input(2);
auto bias_from_visible_context = context.get_input_from_visible_context(2);
if (std::dynamic_pointer_cast<v0::Constant>(bias_from_visible_context.get_node_shared_ptr())) {
bias = bias_from_visible_context;
}
auto bias_rank = bias.get_partial_shape().rank();
if (bias_rank == 1) {
bias = reshape_channelwise(context, bias, conv);

View File

@ -2,7 +2,8 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
from openvino.frontend import FrontEndManager
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
from pytorch_layer_test_class import PytorchLayerTest
@ -159,3 +160,62 @@ class TestConv3D(PytorchLayerTest):
def test_conv3d(self, params, bias, ie_device, precision, ir_version):
self._test(*self.create_model(**params, bias=bias),
ie_device, precision, ir_version)
class TestConv2DInSubgraph(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randn(2, 3, 25, 25).astype(np.float32), np.array([1], dtype=np.int32))
def convert_directly_via_frontend(self, model, example_input, trace_model, dynamic_shapes, ov_inputs, freeze_model):
# Overload function to allow reproduction of issue caused by additional freeze.
import torch
fe_manager = FrontEndManager()
fe = fe_manager.load_by_framework('pytorch')
model.eval()
with torch.no_grad():
if trace_model:
model = torch.jit.trace(model, example_input)
else:
model = torch.jit.script(model)
model = torch.jit.freeze(model)
print(model.inlined_graph)
decoder = TorchScriptPythonDecoder(model, freeze=freeze_model)
im = fe.load(decoder)
om = fe.convert(im)
self._resolve_input_shape_dtype(om, ov_inputs, dynamic_shapes)
return model, om
def create_model(self):
import torch
from torchvision.ops import Conv2dNormActivation
class aten_conv2d(torch.nn.Module):
def __init__(self):
super().__init__()
convs = []
conv_depth=2
for _ in range(conv_depth):
convs.append(Conv2dNormActivation(3, 3, 3, norm_layer=None))
self.convs = torch.nn.Sequential(*convs)
for layer in self.modules():
if isinstance(layer, torch.nn.Conv2d):
torch.nn.init.normal_(layer.weight) # type: ignore[arg-type]
torch.nn.init.constant_(layer.bias, 0) # type: ignore[arg-type]
def forward(self, x, y):
acc = self.convs(x)
if y:
acc += self.convs(x)
return acc
ref_net = None
return aten_conv2d(), ref_net, "aten::conv2d"
@pytest.mark.nightly
@pytest.mark.precommit
def test_conv2d(self, ie_device, precision, ir_version):
self._test(*self.create_model(),
ie_device, precision, ir_version, freeze_model=True, dynamic_shapes=False)