[MO TESTS] PReLU tests fix (#13960)

This commit is contained in:
Roman Lyamin 2022-11-11 23:39:17 +04:00 committed by GitHub
parent 3ec40c490b
commit 5eae673220
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -94,9 +94,9 @@ class TestPRelu(OnnxRuntimeLayerTest):
# Note: IE only support slopes of one element or of size equal to number of channels. # Note: IE only support slopes of one element or of size equal to number of channels.
test_data_shared_channels = [ test_data_shared_channels = [
dict(shape=[10, 12], slope_shape=[12]), dict(shape=[10, 12], slope_shape=[12]),
dict(shape=[8, 10, 12], slope_shape=[10]), dict(shape=[8, 10, 12], slope_shape=[10, 1]),
dict(shape=[6, 8, 10, 12], slope_shape=[8]), dict(shape=[6, 8, 10, 12], slope_shape=[8, 1, 1]),
dict(shape=[4, 6, 8, 10, 12], slope_shape=[6])] dict(shape=[4, 6, 8, 10, 12], slope_shape=[6, 1, 1, 1])]
test_data_scalar_precommit = [ test_data_scalar_precommit = [
dict(shape=[2, 4, 6, 8], slope_shape=[1]), dict(shape=[2, 4, 6, 8], slope_shape=[1]),
@ -111,19 +111,6 @@ class TestPRelu(OnnxRuntimeLayerTest):
test_data_precommit = [dict(shape=[8, 10, 12], slope_shape=[12])] test_data_precommit = [dict(shape=[8, 10, 12], slope_shape=[12])]
@pytest.mark.parametrize("params", test_data_scalar)
@pytest.mark.nightly
def test_prelu_opset6_scalar(self, params, ie_device, precision, ir_version, temp_dir, use_old_api):
self._test(*self.create_net(**params, precision=precision, opset=6, ir_version=ir_version),
ie_device, precision, ir_version, temp_dir=temp_dir, use_old_api=use_old_api)
@pytest.mark.parametrize("params", test_data_shared_channels)
@pytest.mark.nightly
def test_prelu_opset6_shared_channels(self, params, ie_device, precision, ir_version, temp_dir,
use_old_api):
self._test(*self.create_net(**params, precision=precision, opset=6, ir_version=ir_version),
ie_device, precision, ir_version, temp_dir=temp_dir, use_old_api=use_old_api)
@pytest.mark.parametrize("params", test_data_scalar) @pytest.mark.parametrize("params", test_data_scalar)
@pytest.mark.nightly @pytest.mark.nightly
def test_prelu_opset7_scalar(self, params, ie_device, precision, ir_version, temp_dir, use_old_api): def test_prelu_opset7_scalar(self, params, ie_device, precision, ir_version, temp_dir, use_old_api):