diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 18063f87c0a..1ac094d991a 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -468,6 +468,11 @@ const std::map get_supported_ops_ts() { {"aten::repeat_interleave", op::translate_repeat_interleave}, {"aten::reshape", op::translate_reshape}, {"aten::reshape_as", op::translate_reshape_as}, + // TO DO: enable behaviour for resolve_conj and resolve_neg complex tensors, + // when complex dtype will be supported + // for real dtypes, these operations return input tensor without changes and can be skipped + {"aten::resolve_conj", op::skip_node}, + {"aten::resolve_neg", op::skip_node}, {"aten::roll", op::translate_roll}, {"aten::round", op::translate_round}, {"aten::rsqrt", op::translate_rsqrt}, diff --git a/tests/layer_tests/pytorch_tests/test_resolve_conj_neg.py b/tests/layer_tests/pytorch_tests/test_resolve_conj_neg.py new file mode 100644 index 00000000000..87097b88af3 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_resolve_conj_neg.py @@ -0,0 +1,54 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestResolveConjNeg(PytorchLayerTest): + def _prepare_input(self, dtype="float32"): + import numpy as np + return (np.random.randn(2, 4).astype(dtype),) + + def _prepare_input_complex(self): + import numpy as np + return (np.array([[2+3j, 3-2j, 4-9j,10+1j], [1-3j, 3+2j, 4+9j,10-5j]]), ) + + + def create_model(self, op_type): + import torch + + ops = { + "resolve_conj": torch.resolve_conj, + "resolve_neg": torch.resolve_neg + } + + op = ops[op_type] + + class aten_resolve(torch.nn.Module): + def __init__(self, op): + super(aten_resolve, self).__init__() + self.op = op + + def forward(self, x): + return self.op(x) + + ref_net = None + + return aten_resolve(op), ref_net, f"aten::{op_type}" + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("op_type", ["resolve_neg", "resolve_conj"]) + @pytest.mark.parametrize("dtype", ["float32", "int32"]) + def test_reslove(self, op_type, dtype, ie_device, precision, ir_version): + self._test(*self.create_model(op_type), ie_device, precision, ir_version, kwargs_to_prepare_input={"dtype": dtype}) + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("op_type", ["resolve_neg", "resolve_conj"]) + @pytest.mark.xfail(reason="complex dtype is not supported yet") + def test_resolve_complex(self, op_type, ie_device, precision, ir_version): + self._prepare_input = self._prepare_input_complex + self._test(*self.create_model(op_type), ie_device, precision, ir_version)