[PT FE] Fix aten::conv2d conversion issues related to bias (#17228)
* Allow for conversion
* Check for const
* Apply requested changes
* Improve tests
parent 4ccc6e3034
commit 9b52a77531
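For context, a minimal reproduction sketch of the pattern this commit addresses, assembled from the new test below rather than taken verbatim from it: biased convolutions inside a submodule that is reused under a condition, scripted and frozen, then converted through the PyTorch frontend. The module name and the exact decoder arguments are illustrative assumptions; the frontend calls mirror the test's convert_directly_via_frontend helper.

import torch
from torchvision.ops import Conv2dNormActivation

from openvino.frontend import FrontEndManager
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder


class ConvInSubgraph(torch.nn.Module):
    # Hypothetical repro module: two biased convs reused in a conditional branch,
    # mirroring the aten_conv2d module defined in the new test below.
    def __init__(self):
        super().__init__()
        self.convs = torch.nn.Sequential(
            Conv2dNormActivation(3, 3, 3, norm_layer=None),
            Conv2dNormActivation(3, 3, 3, norm_layer=None),
        )

    def forward(self, x, y):
        acc = self.convs(x)
        if y:
            acc += self.convs(x)
        return acc


model = ConvInSubgraph().eval()
with torch.no_grad():
    scripted = torch.jit.freeze(torch.jit.script(model))

# Convert through the PyTorch frontend, the same way the test's helper does.
fe = FrontEndManager().load_by_framework('pytorch')
decoder = TorchScriptPythonDecoder(scripted, freeze=True)
ov_model = fe.convert(fe.load(decoder))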
@@ -51,6 +51,10 @@ OutputVector translate_convnd(const NodeContext& context) {
    }
    if (!context.input_is_none(2)) {
        auto bias = context.get_input(2);
        auto bias_from_visible_context = context.get_input_from_visible_context(2);
        if (std::dynamic_pointer_cast<v0::Constant>(bias_from_visible_context.get_node_shared_ptr())) {
            bias = bias_from_visible_context;
        }
        auto bias_rank = bias.get_partial_shape().rank();
        if (bias_rank == 1) {
            bias = reshape_channelwise(context, bias, conv);
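The hunk above prefers the constant found via the visible context (when one exists) over the raw input, so the rank-1 handling that follows operates on an actual constant. reshape_channelwise itself is not shown in the diff; assuming it reshapes a rank-1 bias so it broadcasts over the channel dimension of the NCHW convolution output, here is a plain numpy sketch of that effect:

import numpy as np

# Hypothetical shapes: a conv output in NCHW layout and a rank-1 bias with one
# value per output channel.
conv_out = np.zeros((2, 3, 25, 25), dtype=np.float32)
bias = np.array([0.1, 0.2, 0.3], dtype=np.float32)

# Reshape the bias to [1, C, 1, 1] so element-wise addition broadcasts it
# across the batch and spatial dimensions.
bias_channelwise = bias.reshape(1, -1, 1, 1)
result = conv_out + bias_channelwise
assert result.shape == conv_out.shape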
@@ -2,7 +2,8 @@
# SPDX-License-Identifier: Apache-2.0

import pytest

from openvino.frontend import FrontEndManager
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
from pytorch_layer_test_class import PytorchLayerTest

@@ -159,3 +160,62 @@ class TestConv3D(PytorchLayerTest):
    def test_conv3d(self, params, bias, ie_device, precision, ir_version):
        self._test(*self.create_model(**params, bias=bias),
                   ie_device, precision, ir_version)

class TestConv2DInSubgraph(PytorchLayerTest):
    def _prepare_input(self):
        import numpy as np
        return (np.random.randn(2, 3, 25, 25).astype(np.float32), np.array([1], dtype=np.int32))

    def convert_directly_via_frontend(self, model, example_input, trace_model, dynamic_shapes, ov_inputs, freeze_model):
        # Overload function to allow reproduction of issue caused by additional freeze.
        import torch

        fe_manager = FrontEndManager()
        fe = fe_manager.load_by_framework('pytorch')

        model.eval()
        with torch.no_grad():
            if trace_model:
                model = torch.jit.trace(model, example_input)
            else:
                model = torch.jit.script(model)
            model = torch.jit.freeze(model)
        print(model.inlined_graph)
        decoder = TorchScriptPythonDecoder(model, freeze=freeze_model)
        im = fe.load(decoder)
        om = fe.convert(im)
        self._resolve_input_shape_dtype(om, ov_inputs, dynamic_shapes)
        return model, om


    def create_model(self):
        import torch
        from torchvision.ops import Conv2dNormActivation

        class aten_conv2d(torch.nn.Module):
            def __init__(self):
                super().__init__()
                convs = []
                conv_depth = 2
                for _ in range(conv_depth):
                    convs.append(Conv2dNormActivation(3, 3, 3, norm_layer=None))
                self.convs = torch.nn.Sequential(*convs)
                for layer in self.modules():
                    if isinstance(layer, torch.nn.Conv2d):
                        torch.nn.init.normal_(layer.weight)  # type: ignore[arg-type]
                        torch.nn.init.constant_(layer.bias, 0)  # type: ignore[arg-type]

            def forward(self, x, y):
                acc = self.convs(x)
                if y:
                    acc += self.convs(x)
                return acc
        ref_net = None

        return aten_conv2d(), ref_net, "aten::conv2d"

    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_conv2d(self, ie_device, precision, ir_version):
        self._test(*self.create_model(),
                   ie_device, precision, ir_version, freeze_model=True, dynamic_shapes=False)
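As a follow-up diagnostic, the sketch below (an illustrative helper, not part of the commit) inspects a scripted-and-frozen module the way the test's print(model.inlined_graph) call does, but programmatically: it walks the top-level nodes of the inlined graph and reports which kind of node produces the bias input of each aten::conv2d. Whether that bias appears as a direct prim::Constant after freezing is exactly the situation the frontend change above guards against; the module here is a simplified stand-in for the test model.

import torch
from torchvision.ops import Conv2dNormActivation


class SmallConvModel(torch.nn.Module):
    # Simplified stand-in for the test model: two biased convs reused conditionally.
    def __init__(self):
        super().__init__()
        self.convs = torch.nn.Sequential(
            Conv2dNormActivation(3, 3, 3, norm_layer=None),
            Conv2dNormActivation(3, 3, 3, norm_layer=None),
        )

    def forward(self, x, y):
        acc = self.convs(x)
        if y:
            acc += self.convs(x)
        return acc


frozen = torch.jit.freeze(torch.jit.script(SmallConvModel().eval()))

# aten::conv2d takes (input, weight, bias, stride, padding, dilation, groups);
# report what produces input index 2 for each top-level conv node.
for node in frozen.inlined_graph.nodes():
    if node.kind() == "aten::conv2d":
        bias_value = list(node.inputs())[2]
        print("aten::conv2d bias produced by", bias_value.node().kind())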