diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 0c9baad6733..5a0e28a8a31 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -617,6 +617,7 @@ jobs: python3 -m pytest ${{ env.LAYER_TESTS_INSTALL_DIR }}/pytorch_tests -m precommit --junitxml=${{ env.INSTALL_TEST_DIR }}/TEST-pytorch.xml env: TEST_DEVICE: CPU + TEST_PRECISION: FP16 - name: TensorFlow 1 Layer Tests - TF FE run: | @@ -629,6 +630,7 @@ jobs: python3 -m pytest ${{ env.LAYER_TESTS_INSTALL_DIR }}/tensorflow_tests/ --use_new_frontend -m precommit_tf_fe --junitxml=${{ env.INSTALL_TEST_DIR }}/TEST-tf_fe.xml env: TEST_DEVICE: CPU + TEST_PRECISION: FP16 - name: TensorFlow 2 Layer Tests - TF FE run: | @@ -640,6 +642,7 @@ jobs: python3 -m pytest ${{ env.LAYER_TESTS_INSTALL_DIR }}/tensorflow2_keras_tests/ --use_new_frontend -m precommit_tf_fe --junitxml=${{ env.INSTALL_TEST_DIR }}/TEST-tf2_fe.xml env: TEST_DEVICE: CPU + TEST_PRECISION: FP16 - name: JAX Layer Tests - TF FE run: | @@ -670,6 +673,7 @@ jobs: --ir_version=11 --junitxml=${{ env.INSTALL_TEST_DIR }}/TEST-tf2_Activation.xml -k "sigmoid" env: TEST_DEVICE: CPU + TEST_PRECISION: FP16 - name: TensorFlow Lite Layer Tests - TFL FE run: | @@ -680,6 +684,7 @@ jobs: python3 -m pytest ${{ env.LAYER_TESTS_INSTALL_DIR }}/tensorflow_lite_tests/ --junitxml=${{ env.INSTALL_TEST_DIR }}/TEST-tfl_fe.xml env: TEST_DEVICE: CPU + TEST_PRECISION: FP16 - name: MO Python API Tests run: | @@ -690,6 +695,7 @@ jobs: python3 -m pytest ${{ env.LAYER_TESTS_INSTALL_DIR }}/mo_python_api_tests --junitxml=${{ env.INSTALL_TEST_DIR }}/TEST-test_mo_convert.xml env: TEST_DEVICE: CPU + TEST_PRECISION: FP16 - name: Python Frontend tests run: | diff --git a/tests/layer_tests/common/layer_test_class.py b/tests/layer_tests/common/layer_test_class.py index c49a23627bf..6faa5d6db6a 100644 --- a/tests/layer_tests/common/layer_test_class.py +++ b/tests/layer_tests/common/layer_test_class.py @@ -181,8 +181,6 @@ def get_params(ie_device=None, precision=None): test_args = [] for element in itertools.product(ie_device_params, precision_params): - if element[0] == 'CPU' and element[1] == 'FP16': - continue test_args.append(element) return test_args diff --git a/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py b/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py index ea29e9c94cc..9a863a12d70 100644 --- a/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py +++ b/tests/layer_tests/mo_python_api_tests/test_mo_convert_pytorch.py @@ -1,18 +1,19 @@ # Copyright (C) 2018-2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import os - +import unittest from typing import Tuple -import numpy + import numpy as np import openvino.runtime as ov import pytest import torch -import unittest -from openvino.runtime import PartialShape, Dimension, Model, Type -from openvino.tools.mo import InputCutInfo +from openvino.runtime import PartialShape, Dimension, Model, Type, Core, save_model +from openvino.test_utils import compare_functions + from common.mo_convert_test_class import CommonMOConvertTest +from openvino.tools.mo import InputCutInfo +from openvino.tools.ovc import convert_model class MyTorchOp(torch.autograd.Function): @@ -1022,9 +1023,6 @@ class TestMoConvertPyTorch(CommonMOConvertTest): @ pytest.mark.precommit def test_sharing_memory_switched_off(self, ie_device, precision, ir_version, temp_dir): - from openvino.tools.ovc import convert_model - from openvino.runtime import Core - class DataModel(torch.nn.Module): def __init__(self): super(DataModel, self).__init__() @@ -1127,3 +1125,152 @@ class ConvertRaises(unittest.TestCase): with self.assertRaisesRegex(Exception, ".*PyTorch Frontend doesn't support provided model type.*"): with tempfile.NamedTemporaryFile() as tmpfile: convert_model(tmpfile.name, framework="pytorch") + + +def create_pytorch_layer_norm(tmp_dir): + class aten_layer_norm(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.layer_norm(x, normalized_shape=[3]) + + shape = PartialShape(PartialShape([-1, -1])) + param1 = ov.opset8.parameter(shape, name="input_0", dtype=np.float32) + const1 = ov.opset8.constant([-1], dtype=np.int32) + mvn1 = ov.opset8.mvn(param1, const1, True, 1e-5, "inside_sqrt") + ref_model = Model([mvn1], [param1], "test") + + test_params = {'example_input': 300 + np.random.randn(2, 3).astype(np.float32)} + return aten_layer_norm(), ref_model, test_params + + +def create_pytorch_normalize(tmp_dir): + class aten_normalize(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.normalize(x) + + test_params = {'example_input': 300 + np.random.randn(2, 3).astype(np.float32)} + return aten_normalize(), None, test_params + + +def create_pytorch_precision_sensitive_with_div(tmp_dir): + class precision_sensitive_with_div(torch.nn.Module): + def forward(self, x): + eps = 1.0e-8 + return 2.0 / (torch.sqrt(torch.sum(torch.pow(x + 2, 2.0), 1)) + eps) + test_params = {'example_input': 300 + np.random.randn(2, 3).astype(np.float32)} + return precision_sensitive_with_div(), None, test_params + + +def create_pytorch_precision_sensitive_for_exp_reduce(tmp_dir): + class precision_sensitive_for_exp_reduce(torch.nn.Module): + def forward(self, x): + return torch.sum(torch.exp(x + 10), 1) + + test_params = {'example_input': 300 + np.random.randn(2, 3).astype(np.float32)} + return precision_sensitive_for_exp_reduce(), None, test_params + + +def create_pytorch_precision_sensitive_div_as_pow(tmp_dir): + class precision_sensitive_div_as_pow(torch.nn.Module): + def forward(self, x): + eps = 1.0e-8 + return 2.0 * (torch.sqrt(torch.sum(torch.pow(x + 2, 2.0), 1)) + eps)**(-1) + + test_params = {'example_input': 300 + np.random.randn(2, 3).astype(np.float32)} + return precision_sensitive_div_as_pow(), None, test_params + + +def create_pytorch_precision_sensitive_two_inp_1(tmp_dir): + class precision_sensitive_two_inp_1(torch.nn.Module): + def forward(self, x, y): + eps = 1.0e-8 + return x / (torch.sqrt(torch.sum(torch.pow(y + 2, 2.0), 2)) + eps) + test_params = {'example_input': (10000 + np.ones((2, 10), dtype=np.float32), + 300 + np.ones((2, 10, 3), dtype=np.float32))} + return precision_sensitive_two_inp_1(), None, test_params + + +def create_pytorch_precision_sensitive_two_inp_2(tmp_dir): + class precision_sensitive_two_inp_2(torch.nn.Module): + def forward(self, x, y): + eps = 1.0e-8 + return x * (torch.sqrt(torch.sum(torch.pow(y + 2, 2.0), 2)) + eps)**(-1) + test_params = {'example_input': (10000 + np.ones((2, 10), dtype=np.float32), + 300 + np.ones((2, 10, 3), dtype=np.float32))} + return precision_sensitive_two_inp_2(), None, test_params + + +def create_pytorch_precision_sensitive_with_matmul(tmp_dir): + class precision_sensitive_with_matmul(torch.nn.Module): + def forward(self, x, y): + eps = 1.0e-8 + interm_res = x / (torch.sqrt(torch.sum(torch.pow(y + 2, 2.0), 2)) + eps) + print(f"interm_res shpae: {interm_res.shape}") + print(interm_res) + weights = 1024.0 + torch.zeros(10, 2) + return torch.mm(interm_res, weights) + test_params = {'example_input': (10000 + np.ones((2, 10), dtype=np.float32), + 300 + np.ones((2, 10, 3), dtype=np.float32))} + return precision_sensitive_with_matmul(), None, test_params + + +def create_pytorch_not_precision_sensitive(tmp_dir): + class not_precision_sensitive(torch.nn.Module): + def forward(self, x): + return torch.sum(x, 1) + + test_params = 10000.0 + np.zeros((2, 20), dtype=np.float32), # 10 000 * 20 = 200 000 > 65504 (fp16_max) + return not_precision_sensitive(), None, test_params + + +class TestPrecisionSensitive(): + test_data = [ + create_pytorch_layer_norm, + create_pytorch_normalize, + create_pytorch_precision_sensitive_with_div, + create_pytorch_precision_sensitive_div_as_pow, + create_pytorch_precision_sensitive_for_exp_reduce, + create_pytorch_precision_sensitive_two_inp_1, + create_pytorch_precision_sensitive_two_inp_2, + ] + + @pytest.mark.parametrize("create_model", test_data) + @pytest.mark.nightly + @pytest.mark.precommit + def test_precision_sensitive(self, create_model, ie_device, precision, ir_version, temp_dir, use_new_frontend, use_old_api): + import numpy.testing as npt + from pathlib import Path + + fw_model, ref_model, mo_params = create_model(temp_dir) + + test_params = {'input_model': fw_model} + if mo_params is not None: + test_params.update(mo_params) + + model = convert_model(**test_params) + model_name = 'model_test.xml' + + save_model(model, str(Path(temp_dir, model_name)), True) + + core = Core() + ir_test = core.read_model(Path(temp_dir, model_name)) + if ref_model is not None: + flag, msg = compare_functions(ir_test, ref_model, compare_tensor_names=False) + assert flag, msg + + example_inputs = test_params['example_input'] + torch_inp_tensors = [] + if isinstance(example_inputs, tuple): + for input_arr in example_inputs: + torch_inp_tensors.append(torch.tensor(input_arr)) + else: + torch_inp_tensors.append(torch.tensor(example_inputs)) + + fw_res = fw_model(*torch_inp_tensors) + ov_res = core.compile_model(ir_test)(example_inputs) + + if precision == 'FP32': + custom_eps = 1e-4 + else: + custom_eps = 1e-3 + + npt.assert_allclose(ov_res[0], fw_res.numpy(), atol=custom_eps) diff --git a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py index 7d366d1b19f..283a90942b4 100644 --- a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py +++ b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py @@ -44,7 +44,6 @@ class PytorchLayerTest: :param enabled_transforms/disabled_transforms: string with idxs of transforms that should be enabled/disabled. Example: "transform_1,transform_2" """ - import torch if 'kwargs_to_prepare_input' in kwargs and kwargs['kwargs_to_prepare_input']: inputs = self._prepare_input(**kwargs['kwargs_to_prepare_input']) else: @@ -286,8 +285,6 @@ def get_params(ie_device=None, precision=None): test_args = [] for element in itertools.product(ie_device_params, precision_params): - if element[0] == 'CPU' and element[1] == 'FP16': - continue test_args.append(element) return test_args diff --git a/tests/layer_tests/tensorflow2_keras_tests/test_tf2_keras_conv_1d.py b/tests/layer_tests/tensorflow2_keras_tests/test_tf2_keras_conv_1d.py index 3710fa51e67..2364f429349 100644 --- a/tests/layer_tests/tensorflow2_keras_tests/test_tf2_keras_conv_1d.py +++ b/tests/layer_tests/tensorflow2_keras_tests/test_tf2_keras_conv_1d.py @@ -16,6 +16,7 @@ class TestKerasConv1D(CommonTF2LayerTest): "softmax": tf.nn.softmax, "swish": tf.nn.swish } + conv_params = conv_params.copy() if "activation" in conv_params: conv_params["activation"] = activation_func_structure[conv_params["activation"]] diff --git a/tests/layer_tests/tensorflow2_keras_tests/test_tf2_keras_conv_2d_transpose.py b/tests/layer_tests/tensorflow2_keras_tests/test_tf2_keras_conv_2d_transpose.py index 1bb2c13455a..8d785c0d7b5 100644 --- a/tests/layer_tests/tensorflow2_keras_tests/test_tf2_keras_conv_2d_transpose.py +++ b/tests/layer_tests/tensorflow2_keras_tests/test_tf2_keras_conv_2d_transpose.py @@ -16,6 +16,7 @@ class TestKerasConv2DTranspose(CommonTF2LayerTest): "relu": tf.nn.relu, "sigmoid": tf.nn.sigmoid } + conv_params = conv_params.copy() if "activation" in conv_params: conv_params["activation"] = activation_func_structure[conv_params["activation"]] diff --git a/tests/layer_tests/tensorflow_tests/test_tf_BinaryOps.py b/tests/layer_tests/tensorflow_tests/test_tf_BinaryOps.py index 41a43b61f79..62689f5609c 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_BinaryOps.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_BinaryOps.py @@ -128,9 +128,9 @@ class TestBinaryOps(CommonTFLayerTest): @pytest.mark.precommit def test_binary_op(self, params, ie_device, precision, ir_version, temp_dir, op_type, use_new_frontend, use_old_api): - if ie_device == 'GPU' and precision == "FP16": - pytest.skip("BinaryOps tests temporary skipped on GPU with FP16 precision." - "Several tests don't pass accuracy checks.") + if precision == "FP16": + pytest.skip("BinaryOps tests are skipped with FP16 precision." + "They don't pass accuracy checks because chaotic output.") self._test( *self.create_add_placeholder_const_net(**params, ir_version=ir_version, op_type=op_type, use_new_frontend=use_new_frontend), ie_device,