From 28d2e77a926e96f7091ae1592b190d40633d4415 Mon Sep 17 00:00:00 2001 From: Ruslan Nugmanov Date: Mon, 29 May 2023 13:50:34 +0200 Subject: [PATCH] TFLite layer tests second part (#17688) * tfl - removes redundant params * tfl - batch matmul * tfl - expand_dims * tfl - squeeze * tfl - hardswish * tfl - batch matmul * tfl - padv2 * tfl - fixes for dynamic shapes * tfl - where * tfl - zeros_like * tfl - zeros_like * tfl - precommit fix * tfl - shape and xfail for expand dims --- tests/layer_tests/common/layer_utils.py | 5 +- .../layer_tests/common/utils/tflite_utils.py | 8 +++ .../test_tfl_BatchMatmul.py | 65 +++++++++++++++++++ .../tensorflow_lite_tests/test_tfl_Conv2D.py | 16 ++--- .../test_tfl_DepthToSpace.py | 8 +-- .../test_tfl_DepthwiseConv2D.py | 8 +-- .../test_tfl_ExpandDims.py | 46 +++++++++++++ .../test_tfl_FullyConnected.py | 13 ++-- .../test_tfl_HardSwish.py | 40 ++++++++++++ .../tensorflow_lite_tests/test_tfl_Pad.py | 50 +++++++++----- .../tensorflow_lite_tests/test_tfl_Pool.py | 13 ++-- .../tensorflow_lite_tests/test_tfl_Shape.py | 51 +++++++++++++++ .../test_tfl_SpaceToDepth.py | 8 +-- .../tensorflow_lite_tests/test_tfl_Squeeze.py | 36 ++++++++++ .../test_tfl_TransposeConv.py | 15 ++--- .../tensorflow_lite_tests/test_tfl_Unary.py | 3 - .../tensorflow_lite_tests/test_tfl_Where.py | 39 +++++++++++ .../test_tfl_ZerosLike.py | 41 ++++++++++++ .../tensorflow_tests/test_tf_GatherNd.py | 2 +- 19 files changed, 402 insertions(+), 65 deletions(-) create mode 100644 tests/layer_tests/tensorflow_lite_tests/test_tfl_BatchMatmul.py create mode 100644 tests/layer_tests/tensorflow_lite_tests/test_tfl_ExpandDims.py create mode 100644 tests/layer_tests/tensorflow_lite_tests/test_tfl_HardSwish.py create mode 100644 tests/layer_tests/tensorflow_lite_tests/test_tfl_Shape.py create mode 100644 tests/layer_tests/tensorflow_lite_tests/test_tfl_Squeeze.py create mode 100644 tests/layer_tests/tensorflow_lite_tests/test_tfl_Where.py create mode 100644 tests/layer_tests/tensorflow_lite_tests/test_tfl_ZerosLike.py diff --git a/tests/layer_tests/common/layer_utils.py b/tests/layer_tests/common/layer_utils.py index 007839743c5..95fa18d196a 100644 --- a/tests/layer_tests/common/layer_utils.py +++ b/tests/layer_tests/common/layer_utils.py @@ -119,6 +119,9 @@ class InferAPI20(BaseInfer): net = core.read_model(self.model, self.weights) inputs_info = {} for item in net.inputs: - inputs_info[item.get_any_name()] = list(item.shape) + if item.partial_shape.is_dynamic: + inputs_info[item.get_any_name()] = item.partial_shape + else: + inputs_info[item.get_any_name()] = item.partial_shape.to_shape() return inputs_info diff --git a/tests/layer_tests/common/utils/tflite_utils.py b/tests/layer_tests/common/utils/tflite_utils.py index 6e5f98c1475..887cded30e5 100644 --- a/tests/layer_tests/common/utils/tflite_utils.py +++ b/tests/layer_tests/common/utils/tflite_utils.py @@ -1,5 +1,6 @@ import itertools import os +import warnings import numpy as np import tensorflow as tf @@ -90,6 +91,13 @@ def get_tflite_results(use_new_frontend, use_old_api, inputs_dict, model_path): for layer, data in inputs_dict.items(): tensor_index = input_name_to_id_mapping[layer] tensor_id = next(i for i, tensor in enumerate(input_details) if tensor['index'] == tensor_index) + + if list(input_details[tensor_id]['shape']) != list(data.shape): + warnings.warn(f'Model and data have different shapes:\nModel {tensor_id} ' + f'input shape{input_details[tensor_id]["shape"]}\nInput data shape: {data.shape}') + interpreter.resize_tensor_input(input_details[tensor_id]['index'], data.shape) + interpreter.allocate_tensors() + interpreter.set_tensor(input_details[tensor_id]['index'], data) interpreter.invoke() diff --git a/tests/layer_tests/tensorflow_lite_tests/test_tfl_BatchMatmul.py b/tests/layer_tests/tensorflow_lite_tests/test_tfl_BatchMatmul.py new file mode 100644 index 00000000000..46e78b3cbab --- /dev/null +++ b/tests/layer_tests/tensorflow_lite_tests/test_tfl_BatchMatmul.py @@ -0,0 +1,65 @@ +import pytest +import tensorflow as tf +import numpy as np + +from common.tflite_layer_test_class import TFLiteLayerTest + +test_params = [ + {'shapes': ((None, 4, 5), (None, 5, 6), (3, 4, 5), (3, 5, 6)), 'adjoint_a': True, 'adjoint_b': True}, + {'shapes': ((None, 1, 3, 4), (None, 4, 2), (2, 1, 3, 4), (5, 4, 2)), 'adjoint_a': False, 'adjoint_b': False}, + {'shapes': ((None, None, None, 3, 4), (None, None, None, 4, 3), + (2, 2, 2, 3, 4), (2, 2, 2, 4, 3)), 'adjoint_a': False, 'adjoint_b': False}, +] + + +class TestTFLiteBatchMatmulLayerTest(TFLiteLayerTest): + inputs = ["Input", "Input1"] + outputs = ["BatchMatmul"] + allowed_ops = ['BATCH_MATMUL'] + + def _prepare_input(self, inputs_dict, generator=None): + input0_shape = self.shapes[2] + adj_a = self.adjoint_a + adj_b = self.adjoint_b + if adj_a: + input0_shape = self._swap_last_two_dims(*input0_shape) + inputs_dict['Input'] = np.float32((1.0 - (-1.0)) * np.random.random_sample(input0_shape) + (-1.0)) + + input1_shape = self.shapes[3] if not adj_b else self._swap_last_two_dims(*self.shapes[3]) + inputs_dict['Input1'] = np.float32((1.0 - (-1.0)) * np.random.random_sample(input1_shape) + (-1.0)) + + return inputs_dict + + def _swap_last_two_dims(self, *args): + """Return a tuple with the last two dimensions swapped.""" + return args[:-2] + (args[-1],) + (args[-2],) + + def make_model(self, params): + assert len(set(params.keys()).intersection({'shapes', 'adjoint_a', 'adjoint_b'})) == 3, \ + 'Unexpected parameters for test: ' + ','.join(params.keys()) + self.shapes = params['shapes'] + self.adjoint_a = params['adjoint_a'] + self.adjoint_b = params['adjoint_b'] + tf.compat.v1.reset_default_graph() + with tf.compat.v1.Session() as sess: + placeholder0_shape = self.shapes[0] + adj_a = params["adjoint_a"] + adj_b = params["adjoint_b"] + + if adj_a: + placeholder0_shape = self._swap_last_two_dims(*placeholder0_shape) + input0_tensor = tf.compat.v1.placeholder(dtype=tf.float32, shape=placeholder0_shape, name=self.inputs[0]) + if adj_b: + placeholder1_shape = self._swap_last_two_dims(*self.shapes[1]) + else: + placeholder1_shape = self.shapes[1] + input1_tensor = tf.compat.v1.placeholder(dtype=tf.float32, shape=placeholder1_shape, name=self.inputs[1]) + + tf.matmul(input0_tensor, input1_tensor, adjoint_a=adj_a, adjoint_b=adj_b, name=self.outputs[0]) + net = sess.graph_def + return net + + @pytest.mark.parametrize("params", test_params) + @pytest.mark.nightly + def test_batch_matmul(self, params, ie_device, precision, temp_dir): + self._test(ie_device, precision, temp_dir, params) diff --git a/tests/layer_tests/tensorflow_lite_tests/test_tfl_Conv2D.py b/tests/layer_tests/tensorflow_lite_tests/test_tfl_Conv2D.py index 52a03a982f6..cbed6411f5e 100644 --- a/tests/layer_tests/tensorflow_lite_tests/test_tfl_Conv2D.py +++ b/tests/layer_tests/tensorflow_lite_tests/test_tfl_Conv2D.py @@ -7,14 +7,10 @@ from common.tflite_layer_test_class import TFLiteLayerTest np.random.seed(42) test_params = [ - {'shape': [1, 22, 22, 8], 'ksize': [32, 3, 4, 4], 'strides': 2, 'padding': 'SAME', 'data_format': 'NHWC', - 'dilations': [1, 1, 1, 1]}, - {'shape': [1, 22, 22, 9], 'ksize': [32, 3, 3, 3], 'strides': (2, 1), 'padding': 'SAME', 'data_format': 'NHWC', - 'dilations': [1, 2, 2, 1]}, - {'shape': [1, 22, 22, 8], 'ksize': [1, 3, 4, 4], 'strides': 2, 'padding': 'VALID', 'data_format': 'NHWC', - 'dilations': [1, 1, 1, 1]}, - {'shape': [1, 22, 22, 3], 'ksize': [1, 3, 3, 3], 'strides': (3, 4), 'padding': 'VALID', 'data_format': 'NHWC', - 'dilations': [1, 2, 2, 1]}, + {'shape': [1, 22, 22, 8], 'ksize': [32, 3, 4, 4], 'strides': 2, 'padding': 'SAME', 'dilations': [1, 1, 1, 1]}, + {'shape': [1, 22, 22, 9], 'ksize': [32, 3, 3, 3], 'strides': (2, 1), 'padding': 'SAME', 'dilations': [1, 2, 2, 1]}, + {'shape': [1, 22, 22, 8], 'ksize': [1, 3, 4, 4], 'strides': 2, 'padding': 'VALID', 'dilations': [1, 1, 1, 1]}, + {'shape': [1, 22, 22, 3], 'ksize': [1, 3, 3, 3], 'strides': (3, 4), 'padding': 'VALID', 'dilations': [1, 2, 2, 1]}, ] @@ -25,14 +21,14 @@ class TestTFLiteConv2DLayerTest(TFLiteLayerTest): def make_model(self, params): assert len(set(params.keys()).intersection({'shape', 'ksize', 'strides', - 'padding', 'data_format', 'dilations'})) == 6, \ + 'padding', 'dilations'})) == 5, \ 'Unexpected parameters for test: ' + ','.join(params.keys()) tf.compat.v1.reset_default_graph() with tf.compat.v1.Session() as sess: weights = tf.constant(np.random.randint(-1, 1, params['ksize']), dtype=tf.float32) place_holder = tf.compat.v1.placeholder(params.get('dtype', tf.float32), params['shape'], name=self.inputs[0]) - tf.nn.conv2d(place_holder, weights, params['strides'], params['padding'], params['data_format'], + tf.nn.conv2d(place_holder, weights, params['strides'], params['padding'], 'NHWC', params['dilations'], name=self.outputs[0]) net = sess.graph_def return net diff --git a/tests/layer_tests/tensorflow_lite_tests/test_tfl_DepthToSpace.py b/tests/layer_tests/tensorflow_lite_tests/test_tfl_DepthToSpace.py index e50b22fa132..d92ea652d01 100644 --- a/tests/layer_tests/tensorflow_lite_tests/test_tfl_DepthToSpace.py +++ b/tests/layer_tests/tensorflow_lite_tests/test_tfl_DepthToSpace.py @@ -4,8 +4,8 @@ import tensorflow as tf from common.tflite_layer_test_class import TFLiteLayerTest test_params = [ - {'shape': [8, 10, 10, 16], 'block_size': 2, 'data_format': 'NHWC'}, - {'shape': [24, 10, 10, 50], 'block_size': 5, 'data_format': 'NHWC'}, + {'shape': [8, 10, 10, 16], 'block_size': 2}, + {'shape': [24, 10, 10, 50], 'block_size': 5}, ] @@ -15,13 +15,13 @@ class TestTFLiteDepthToSpaceLayerTest(TFLiteLayerTest): allowed_ops = ['DEPTH_TO_SPACE'] def make_model(self, params): - assert len(set(params.keys()).intersection({'shape', 'block_size', 'data_format'})) == 3, \ + assert len(set(params.keys()).intersection({'shape', 'block_size'})) == 2, \ 'Unexpected parameters for test: ' + ','.join(params.keys()) tf.compat.v1.reset_default_graph() with tf.compat.v1.Session() as sess: place_holder = tf.compat.v1.placeholder(params.get('dtype', tf.float32), params['shape'], name=self.inputs[0]) - tf.nn.depth_to_space(place_holder, params['block_size'], params['data_format'], name=self.outputs[0]) + tf.nn.depth_to_space(place_holder, params['block_size'], 'NHWC', name=self.outputs[0]) net = sess.graph_def return net diff --git a/tests/layer_tests/tensorflow_lite_tests/test_tfl_DepthwiseConv2D.py b/tests/layer_tests/tensorflow_lite_tests/test_tfl_DepthwiseConv2D.py index 24f8023975c..c330553c79e 100644 --- a/tests/layer_tests/tensorflow_lite_tests/test_tfl_DepthwiseConv2D.py +++ b/tests/layer_tests/tensorflow_lite_tests/test_tfl_DepthwiseConv2D.py @@ -7,9 +7,9 @@ from common.tflite_layer_test_class import TFLiteLayerTest np.random.seed(42) test_params = [ - {'shape': [1, 22, 22, 8], 'ksize': [3, 3, 8, 2], 'strides': [1, 2, 2, 1], 'padding': 'SAME', 'data_format': 'NHWC', + {'shape': [1, 22, 22, 8], 'ksize': [3, 3, 8, 2], 'strides': [1, 2, 2, 1], 'padding': 'SAME', 'dilations': [1, 1]}, - {'shape': [1, 22, 22, 9], 'ksize': [3, 3, 9, 1], 'strides': [1, 1, 1, 1], 'padding': 'SAME', 'data_format': 'NHWC', + {'shape': [1, 22, 22, 9], 'ksize': [3, 3, 9, 1], 'strides': [1, 1, 1, 1], 'padding': 'SAME', 'dilations': [1, 1]}, ] @@ -21,14 +21,14 @@ class TestTFLiteDepthwiseConv2DLayerTest(TFLiteLayerTest): def make_model(self, params): assert len(set(params.keys()).intersection({'shape', 'ksize', 'strides', - 'padding', 'data_format', 'dilations'})) == 6, \ + 'padding', 'dilations'})) == 5, \ 'Unexpected parameters for test: ' + ','.join(params.keys()) tf.compat.v1.reset_default_graph() with tf.compat.v1.Session() as sess: weights = tf.constant(np.random.randint(-1, 1, params['ksize']), dtype=tf.float32) place_holder = tf.compat.v1.placeholder(params.get('dtype', tf.float32), params['shape'], name=self.inputs[0]) - tf.nn.depthwise_conv2d(place_holder, weights, params['strides'], params['padding'], params['data_format'], + tf.nn.depthwise_conv2d(place_holder, weights, params['strides'], params['padding'], 'NHWC', params['dilations'], name=self.outputs[0]) net = sess.graph_def return net diff --git a/tests/layer_tests/tensorflow_lite_tests/test_tfl_ExpandDims.py b/tests/layer_tests/tensorflow_lite_tests/test_tfl_ExpandDims.py new file mode 100644 index 00000000000..711862ef319 --- /dev/null +++ b/tests/layer_tests/tensorflow_lite_tests/test_tfl_ExpandDims.py @@ -0,0 +1,46 @@ +import pytest +import tensorflow as tf +import numpy as np + +from common.tflite_layer_test_class import TFLiteLayerTest + +test_params = [ + {'shape': [2], 'axis': [0]}, + {'shape': [2, 2], 'axis': [2]}, + {'shape': [2, 2], 'axis': [0, 2]}, + {'shape': [5, 2, 2, 2], 'axis': [1]}, + {'shape': [5, 2, 2, 2], 'axis': [1, 1]}, + {'shape': [5, 2, 2, 2], 'axis': [-1, -2, -3]}, +] + + +class TestTFLiteExpandDimsLayerTest(TFLiteLayerTest): + inputs = ["Input", "Input1"] + outputs = ["ExpandDims"] + allowed_ops = ['EXPAND_DIMS'] + + def _prepare_input(self, inputs_dict, generator=None): + inputs_dict['Input'] = (1.0 - (-1.0)) * np.random.random_sample(inputs_dict['Input']) + (-1.0) + inputs_dict['Input1'] = self.axis + + return inputs_dict + + def make_model(self, params): + assert len(set(params.keys()).intersection({'shape', 'axis'})) == 2, \ + 'Unexpected parameters for test: ' + ','.join(params.keys()) + tf.compat.v1.reset_default_graph() + self.axis = params['axis'] + + with tf.compat.v1.Session() as sess: + input_value1 = tf.compat.v1.placeholder(tf.float32, params['shape'], name=self.inputs[0]) + axis = tf.compat.v1.placeholder(tf.int32, [len(params['axis'])], name=self.inputs[1]) + + tf.expand_dims(input_value1, axis, name=self.outputs[0]) + net = sess.graph_def + return net + + @pytest.mark.parametrize("params", test_params) + @pytest.mark.nightly + def test_expand_dims(self, params, ie_device, precision, temp_dir): + pytest.xfail("CVS-111983") + self._test(ie_device, precision, temp_dir, params) diff --git a/tests/layer_tests/tensorflow_lite_tests/test_tfl_FullyConnected.py b/tests/layer_tests/tensorflow_lite_tests/test_tfl_FullyConnected.py index 5b5d522c5fe..a8119813dd4 100644 --- a/tests/layer_tests/tensorflow_lite_tests/test_tfl_FullyConnected.py +++ b/tests/layer_tests/tensorflow_lite_tests/test_tfl_FullyConnected.py @@ -4,9 +4,9 @@ import tensorflow as tf from common.tflite_layer_test_class import TFLiteLayerTest test_params = [ - {'shape_x': [40, 37], 'shape_y': [37, 37], 'transpose_a': False, 'transpose_b': True}, - {'shape_x': [5, 5], 'shape_y': [4, 5], 'transpose_a': False, 'transpose_b': True}, - {'shape_x': [1, 5, 5], 'shape_y': [4, 5], 'transpose_a': False, 'transpose_b': True}, + {'shape_x': [40, 37], 'shape_y': [37, 37]}, + {'shape_x': [5, 5], 'shape_y': [4, 5]}, + {'shape_x': [1, 5, 5], 'shape_y': [4, 5]}, ] @@ -16,16 +16,15 @@ class TestTFLiteFullyConnectedLayerTest(TFLiteLayerTest): allowed_ops = ['FULLY_CONNECTED'] def make_model(self, params): - assert len(set(params.keys()).intersection({'shape_x', 'shape_y', - 'transpose_a', 'transpose_b'})) == 4, \ + assert len(set(params.keys()).intersection({'shape_x', 'shape_y'})) == 2, \ 'Unexpected parameters for test: ' + ','.join(params.keys()) + tf.compat.v1.reset_default_graph() with tf.compat.v1.Session() as sess: x = tf.compat.v1.placeholder(params.get('dtype', tf.float32), params['shape_x'], name=self.inputs[0]) y = tf.compat.v1.placeholder(params.get('dtype', tf.float32), params['shape_y'], name=self.inputs[1]) - tf.matmul(x, y, transpose_a=params['transpose_a'], transpose_b=params['transpose_b'], - name=self.outputs[0]) + tf.matmul(x, y, transpose_a=False, transpose_b=True, name=self.outputs[0]) net = sess.graph_def return net diff --git a/tests/layer_tests/tensorflow_lite_tests/test_tfl_HardSwish.py b/tests/layer_tests/tensorflow_lite_tests/test_tfl_HardSwish.py new file mode 100644 index 00000000000..a2d7cc53c55 --- /dev/null +++ b/tests/layer_tests/tensorflow_lite_tests/test_tfl_HardSwish.py @@ -0,0 +1,40 @@ +import pytest +import tensorflow as tf +import numpy as np + +from common.tflite_layer_test_class import TFLiteLayerTest + +test_params = [ + {'shape': [1]}, + {'shape': [2, 3]}, + {'shape': [1, 1, 1, 1]}, + {'shape': [1, 3, 4, 3]}, + {'shape': [3, 15, 14, 3]}, +] + + +class TestTFLiteHardSwishLayerTest(TFLiteLayerTest): + inputs = ["Input"] + outputs = ["HardSwish"] + allowed_ops = ['HARD_SWISH'] + + def _prepare_input(self, inputs_dict, generator=None): + inputs_dict['Input'] = np.float32((10 - (-10)) * np.random.random_sample(inputs_dict['Input']) + (-10)) + return inputs_dict + + def make_model(self, params): + assert len(set(params.keys()).intersection({'shape'})) == 1, \ + 'Unexpected parameters for test: ' + ','.join(params.keys()) + tf.compat.v1.reset_default_graph() + with tf.compat.v1.Session() as sess: + placeholder = tf.compat.v1.placeholder(tf.float32, params['shape'], + name=self.inputs[0]) + hs = placeholder * tf.nn.relu6(placeholder + np.float32(3)) * np.float32(1. / 6.) + tf.identity(hs, name=self.outputs[0]) + net = sess.graph_def + return net + + @pytest.mark.parametrize("params", test_params) + @pytest.mark.nightly + def test_hardswish(self, params, ie_device, precision, temp_dir): + self._test(ie_device, precision, temp_dir, params) diff --git a/tests/layer_tests/tensorflow_lite_tests/test_tfl_Pad.py b/tests/layer_tests/tensorflow_lite_tests/test_tfl_Pad.py index 37a7362311b..a17e9d0e5a7 100644 --- a/tests/layer_tests/tensorflow_lite_tests/test_tfl_Pad.py +++ b/tests/layer_tests/tensorflow_lite_tests/test_tfl_Pad.py @@ -4,40 +4,58 @@ import tensorflow as tf from common.tflite_layer_test_class import TFLiteLayerTest test_params = [ - {'shape': [1, 1, 2, 1, 1], 'paddings': [[0, 0], [0, 1], [2, 3], [0, 0], [1, 0]]}, - {'shape': [2, 1, 1, 1, 1], 'paddings': [[0, 1], [0, 0], [0, 0], [2, 3], [1, 0]]}, + {'op_name': 'PAD', 'shape': [1, 1, 2, 1, 1], 'paddings': [[0, 0], [0, 1], [2, 3], [0, 0], [1, 0]]}, + {'op_name': 'PAD', 'shape': [2, 1, 1, 1, 1], 'paddings': [[0, 1], [0, 0], [0, 0], [2, 3], [1, 0]]}, - {'shape': [1, 1, 2, 1], 'paddings': [[0, 0], [0, 1], [2, 3], [0, 0]]}, - {'shape': [1, 1, 2, 1], 'paddings': [[0, 0], [0, 1], [2, 3], [0, 0]]}, + {'op_name': 'PAD', 'shape': [1, 1, 2, 1], 'paddings': [[0, 0], [0, 1], [2, 3], [0, 0]]}, + {'op_name': 'PAD', 'shape': [1, 1, 2, 1], 'paddings': [[0, 0], [0, 1], [2, 3], [0, 0]]}, - {'shape': [1, 2], 'paddings': [[0, 1], [2, 1]]}, - {'shape': [1, 2], 'paddings': [[2, 3], [0, 1]]}, + {'op_name': 'PAD', 'shape': [1, 2], 'paddings': [[0, 1], [2, 1]]}, + {'op_name': 'PAD', 'shape': [1, 2], 'paddings': [[2, 3], [0, 1]]}, - {'shape': [1], 'paddings': [[1, 2]]}, + {'op_name': 'PAD', 'shape': [1], 'paddings': [[1, 2]]}, + + {'op_name': 'PADV2', 'shape': [1, 1, 2, 1, 1], 'paddings': [[0, 0], [0, 1], [2, 3], [0, 0], [1, 0]], + 'constant_value': -1}, + {'op_name': 'PADV2', 'shape': [2, 1, 1, 1, 1], 'paddings': [[0, 1], [0, 0], [0, 0], [2, 3], [1, 0]], + 'constant_value': 1}, + + {'op_name': 'PADV2', 'shape': [1, 1, 2, 1], 'paddings': [[0, 0], [0, 1], [2, 3], [0, 0]], 'constant_value': -1}, + {'op_name': 'PADV2', 'shape': [1, 1, 2, 1], 'paddings': [[0, 0], [0, 1], [2, 3], [0, 0]], 'constant_value': 1}, + + {'op_name': 'PADV2', 'shape': [1, 2], 'paddings': [[0, 1], [2, 1]], 'constant_value': 1}, + {'op_name': 'PADV2', 'shape': [1, 2], 'paddings': [[2, 3], [0, 1]], 'constant_value': -1}, + + {'op_name': 'PADV2', 'shape': [1], 'paddings': [[1, 2]], 'constant_value': -1}, ] class TestTFLitePadLayerTest(TFLiteLayerTest): - inputs = ["Input", 'Paddings'] outputs = ["Pad"] - allowed_ops = ['PAD'] def make_model(self, params): - assert len(set(params.keys()).intersection({'shape', 'paddings'})) == 2, \ + assert len(set(params.keys()).intersection({'op_name', 'shape', 'paddings'})) == 3, \ 'Unexpected parameters for test: ' + ','.join(params.keys()) tf.compat.v1.reset_default_graph() + self.allowed_ops = [params['op_name']] + self.inputs = ["Input"] with tf.compat.v1.Session() as sess: - place_holder = tf.compat.v1.placeholder(params.get('dtype', tf.float32), params['shape'], - name=self.inputs[0]) - shape = [len(params["paddings"]), 2] - paddings = tf.compat.v1.placeholder(dtype=tf.int32, name=self.inputs[1], shape=shape) - tf.pad(tensor=place_holder, paddings=paddings, name=self.outputs[0]) + place_holder = tf.compat.v1.placeholder(tf.float32, params['shape'], name=self.inputs[0]) + if params['op_name'] == 'PADV2': + tf.pad(tensor=place_holder, paddings=params['paddings'], constant_values=params['constant_value'], + name=self.outputs[0]) + else: + self.inputs.append('Paddings') + shape = [len(params["paddings"]), 2] + paddings = tf.compat.v1.placeholder(dtype=tf.int32, name=self.inputs[1], shape=shape) + tf.pad(tensor=place_holder, paddings=paddings, name=self.outputs[0]) net = sess.graph_def return net @pytest.mark.parametrize("params", test_params) @pytest.mark.nightly def test_pad(self, params, ie_device, precision, temp_dir): - pytest.xfail("CVS-110828") + if params['op_name'] == 'PAD': + pytest.xfail("CVS-110828") self._test(ie_device, precision, temp_dir, params) diff --git a/tests/layer_tests/tensorflow_lite_tests/test_tfl_Pool.py b/tests/layer_tests/tensorflow_lite_tests/test_tfl_Pool.py index a67097628e2..fd96bcf08cd 100644 --- a/tests/layer_tests/tensorflow_lite_tests/test_tfl_Pool.py +++ b/tests/layer_tests/tensorflow_lite_tests/test_tfl_Pool.py @@ -10,11 +10,10 @@ test_ops = [ ] test_params = [ - # TFLite doesn't support avgpool with 'NCHW' format - {'shape': [1, 22, 22, 8], 'ksize': [3, 3], 'strides': 2, 'padding': 'SAME', 'data_format': 'NHWC'}, - {'shape': [1, 22, 22, 8], 'ksize': [3, 3], 'strides': (2, 1), 'padding': 'SAME', 'data_format': 'NHWC'}, - {'shape': [1, 22, 22, 8], 'ksize': [3, 3], 'strides': 2, 'padding': 'VALID', 'data_format': 'NHWC'}, - {'shape': [1, 22, 22, 8], 'ksize': [3, 3], 'strides': (3, 4), 'padding': 'VALID', 'data_format': 'NHWC'}, + {'shape': [1, 22, 22, 8], 'ksize': [3, 3], 'strides': 2, 'padding': 'SAME'}, + {'shape': [1, 22, 22, 8], 'ksize': [3, 3], 'strides': (2, 1), 'padding': 'SAME'}, + {'shape': [1, 22, 22, 8], 'ksize': [3, 3], 'strides': 2, 'padding': 'VALID'}, + {'shape': [1, 22, 22, 8], 'ksize': [3, 3], 'strides': (3, 4), 'padding': 'VALID'}, ] test_data = parametrize_tests(test_ops, test_params) @@ -26,7 +25,7 @@ class TestTFLitePoolLayerTest(TFLiteLayerTest): def make_model(self, params): assert len(set(params.keys()).intersection({'op_name', 'op_func', 'shape', 'ksize', 'strides', - 'padding', 'data_format'})) == 7, \ + 'padding'})) == 6, \ 'Unexpected parameters for test: ' + ','.join(params.keys()) self.allowed_ops = [params['op_name']] tf.compat.v1.reset_default_graph() @@ -34,7 +33,7 @@ class TestTFLitePoolLayerTest(TFLiteLayerTest): place_holder = tf.compat.v1.placeholder(params.get('dtype', tf.float32), params['shape'], name=self.inputs[0]) params['op_func'](place_holder, params['ksize'], params['strides'], - params['padding'], params['data_format'], name=self.outputs[0]) + params['padding'], 'NHWC', name=self.outputs[0]) net = sess.graph_def return net diff --git a/tests/layer_tests/tensorflow_lite_tests/test_tfl_Shape.py b/tests/layer_tests/tensorflow_lite_tests/test_tfl_Shape.py new file mode 100644 index 00000000000..8e132944786 --- /dev/null +++ b/tests/layer_tests/tensorflow_lite_tests/test_tfl_Shape.py @@ -0,0 +1,51 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +from functools import partial + +import numpy as np +import pytest +import tensorflow as tf + +from common.tflite_layer_test_class import TFLiteLayerTest +from common.utils.tflite_utils import parametrize_tests + + +test_params = [ + {'shape': [1, 4], 'new_shape': [1, 4]}, + {'shape': [1, 4], 'new_shape': [4, 1]}, + {'shape': [1, 4], 'new_shape': [2, 2]}, +] + + +class TestTFLiteShapeLayerTest(TFLiteLayerTest): + inputs = ["Input", "Input1"] + outputs = ["Shape"] + allowed_ops = ['RESHAPE', 'SHAPE'] + + def _prepare_input(self, inputs_dict, generator=None): + inputs_dict['Input'] = np.random.randint(0, 100, self.shape).astype(np.int32) + inputs_dict['Input1'] = np.array(self.new_shape).astype(np.int32) + return inputs_dict + + def make_model(self, params): + assert len(set(params.keys()).intersection({'shape', 'new_shape'})) == 2, \ + 'Unexpected parameters for test: ' + ','.join(params.keys()) + tf.compat.v1.reset_default_graph() + self.shape = params['shape'] + self.new_shape = params['new_shape'] + + with tf.compat.v1.Session() as sess: + placeholder = tf.compat.v1.placeholder(tf.int32, self.shape, name=self.inputs[0]) + shape_of_new_shape = [len(self.new_shape)] + new_shape = tf.compat.v1.placeholder(tf.int32, shape_of_new_shape, name=self.inputs[1]) + + reshaped = tf.reshape(placeholder, shape=new_shape) + tf.shape(input=reshaped) + + net = sess.graph_def + return net + + @pytest.mark.parametrize("params", test_params) + @pytest.mark.nightly + def test_shape(self, params, ie_device, precision, temp_dir): + self._test(ie_device, precision, temp_dir, params) diff --git a/tests/layer_tests/tensorflow_lite_tests/test_tfl_SpaceToDepth.py b/tests/layer_tests/tensorflow_lite_tests/test_tfl_SpaceToDepth.py index 1c9612160a8..3aa9ff9a32d 100644 --- a/tests/layer_tests/tensorflow_lite_tests/test_tfl_SpaceToDepth.py +++ b/tests/layer_tests/tensorflow_lite_tests/test_tfl_SpaceToDepth.py @@ -4,8 +4,8 @@ import tensorflow as tf from common.tflite_layer_test_class import TFLiteLayerTest test_params = [ - {'shape': [8, 10, 10, 16], 'block_size': 2, 'data_format': 'NHWC'}, - {'shape': [24, 10, 10, 50], 'block_size': 5, 'data_format': 'NHWC'}, + {'shape': [8, 10, 10, 16], 'block_size': 2}, + {'shape': [24, 10, 10, 50], 'block_size': 5}, ] @@ -15,13 +15,13 @@ class TestTFLiteSpaceToDepthLayerTest(TFLiteLayerTest): allowed_ops = ['SPACE_TO_DEPTH'] def make_model(self, params): - assert len(set(params.keys()).intersection({'shape', 'block_size', 'data_format'})) == 3, \ + assert len(set(params.keys()).intersection({'shape', 'block_size'})) == 2, \ 'Unexpected parameters for test: ' + ','.join(params.keys()) tf.compat.v1.reset_default_graph() with tf.compat.v1.Session() as sess: place_holder = tf.compat.v1.placeholder(params.get('dtype', tf.float32), params['shape'], name=self.inputs[0]) - tf.nn.space_to_depth(place_holder, params['block_size'], params['data_format'], name=self.outputs[0]) + tf.nn.space_to_depth(place_holder, params['block_size'], 'NHWC', name=self.outputs[0]) net = sess.graph_def return net diff --git a/tests/layer_tests/tensorflow_lite_tests/test_tfl_Squeeze.py b/tests/layer_tests/tensorflow_lite_tests/test_tfl_Squeeze.py new file mode 100644 index 00000000000..0fb56be3342 --- /dev/null +++ b/tests/layer_tests/tensorflow_lite_tests/test_tfl_Squeeze.py @@ -0,0 +1,36 @@ +import pytest +import tensorflow as tf + +from common.tflite_layer_test_class import TFLiteLayerTest + +test_params = [ + {'shape': [1, 3], 'axis': [0]}, + {'shape': [2, 1], 'axis': [1]}, + {'shape': [1, 1, 2], 'axis': [0, 1]}, + {'shape': [1, 1, 2, 2], 'axis': [1]}, + {'shape': [1, 1, 2, 2], 'axis': [1, 1]}, + {'shape': [5, 1, 1, 1], 'axis': [-1, -2, -3]}, +] + + +class TestTFLiteSqueezeLayerTest(TFLiteLayerTest): + inputs = ["Input"] + outputs = ["Squeeze"] + # TFLite returns SQUEEZE only when it has undetermined rank, but OV doesn't support SQUEEZE op with such rank + allowed_ops = ['RESHAPE'] + + def make_model(self, params): + assert len(set(params.keys()).intersection({'shape', 'axis'})) == 2, \ + 'Unexpected parameters for test: ' + ','.join(params.keys()) + tf.compat.v1.reset_default_graph() + + with tf.compat.v1.Session() as sess: + input_value1 = tf.compat.v1.placeholder(tf.float32, params['shape'], name=self.inputs[0]) + tf.squeeze(input_value1, params['axis'], name=self.outputs[0]) + net = sess.graph_def + return net + + @pytest.mark.parametrize("params", test_params) + @pytest.mark.nightly + def test_squeeze_dims(self, params, ie_device, precision, temp_dir): + self._test(ie_device, precision, temp_dir, params) diff --git a/tests/layer_tests/tensorflow_lite_tests/test_tfl_TransposeConv.py b/tests/layer_tests/tensorflow_lite_tests/test_tfl_TransposeConv.py index a56976cef99..2d485845334 100644 --- a/tests/layer_tests/tensorflow_lite_tests/test_tfl_TransposeConv.py +++ b/tests/layer_tests/tensorflow_lite_tests/test_tfl_TransposeConv.py @@ -8,14 +8,13 @@ np.random.seed(42) test_params = [ {'shape': [1, 3, 4, 1], 'ksize': [1, 1, 1, 1], 'output_shape': [1, 3, 4, 1], 'strides': [1, 1, 1, 1], - 'padding': 'SAME', 'data_format': 'NHWC', 'dilations': [1, 1, 1, 1]}, + 'padding': 'SAME', 'dilations': [1, 1, 1, 1]}, {'shape': [1, 4, 4, 1], 'ksize': [1, 1, 1, 1], 'output_shape': [1, 4, 4, 1], 'strides': [1, 1], 'padding': 'SAME', - 'data_format': 'NHWC', 'dilations': [1, 2, 2, 1]}, - # + 'dilations': [1, 2, 2, 1]}, {'shape': [1, 22, 22, 3], 'ksize': [1, 1, 6, 3], 'output_shape': [1, 22, 22, 6], 'strides': [1, 1], - 'padding': 'VALID', 'data_format': 'NHWC', 'dilations': [1, 1, 1, 1]}, + 'padding': 'VALID', 'dilations': [1, 1, 1, 1]}, {'shape': [1, 22, 22, 3], 'ksize': [3, 3, 3, 3], 'output_shape': [1, 24, 24, 3], 'strides': [1, 1], - 'padding': 'VALID', 'data_format': 'NHWC', 'dilations': [1, 1, 1, 1]}, + 'padding': 'VALID', 'dilations': [1, 1, 1, 1]}, ] @@ -26,7 +25,7 @@ class TestTFLiteTransposeConvLayerTest(TFLiteLayerTest): def make_model(self, params): assert len(set(params.keys()).intersection({'shape', 'ksize', 'strides', - 'padding', 'data_format', 'dilations', 'output_shape'})) == 7, \ + 'padding', 'dilations', 'output_shape'})) == 6, \ 'Unexpected parameters for test: ' + ','.join(params.keys()) tf.compat.v1.reset_default_graph() with tf.compat.v1.Session() as sess: @@ -34,8 +33,8 @@ class TestTFLiteTransposeConvLayerTest(TFLiteLayerTest): name=self.inputs[0]) filter_input = tf.constant(np.random.randint(-1, 1, size=(params['ksize'])), dtype=tf.float32) tf.nn.conv2d_transpose(placeholder, filter_input, params['output_shape'], params["strides"], - params["padding"], - params["data_format"], name=self.outputs[0]) + params["padding"], 'NHWC', name=self.outputs[0]) + net = sess.graph_def return net diff --git a/tests/layer_tests/tensorflow_lite_tests/test_tfl_Unary.py b/tests/layer_tests/tensorflow_lite_tests/test_tfl_Unary.py index ac45dc4283d..e6a4789a34c 100644 --- a/tests/layer_tests/tensorflow_lite_tests/test_tfl_Unary.py +++ b/tests/layer_tests/tensorflow_lite_tests/test_tfl_Unary.py @@ -41,9 +41,6 @@ test_ops = [ # This op could not be converted standalone -- tries to become FlexOp (offload from tfl to tf) # {'op_name': 'SIGN', 'op_func': tf.math.sign}, - - # TF has no such standalone operation - # {'op_name': 'HARD_SWISH'} ] test_params = [ diff --git a/tests/layer_tests/tensorflow_lite_tests/test_tfl_Where.py b/tests/layer_tests/tensorflow_lite_tests/test_tfl_Where.py new file mode 100644 index 00000000000..f444a45e938 --- /dev/null +++ b/tests/layer_tests/tensorflow_lite_tests/test_tfl_Where.py @@ -0,0 +1,39 @@ +import numpy as np +import pytest +import tensorflow as tf + +from common.tflite_layer_test_class import TFLiteLayerTest + +np.random.seed(42) + +test_params = [ + {'shape': [10]}, + {'shape': [1, 2, 3, 4]}, + {'shape': [8, 7, 6, 5, 4, 3, 2, 1]} +] + + +class TestTFLiteWhereLayerTest(TFLiteLayerTest): + inputs = ["Input"] + outputs = ["Where"] + allowed_ops = ['WHERE'] + + def _prepare_input(self, inputs_dict, generator=None): + inputs_dict['Input'] = np.random.randint(0, 1, inputs_dict['Input']) < 1 + return inputs_dict + + def make_model(self, params): + assert len(set(params.keys()).intersection({'shape'})) == 1, \ + 'Unexpected parameters for test: ' + ','.join(params.keys()) + tf.compat.v1.reset_default_graph() + + with tf.compat.v1.Session() as sess: + input_value1 = tf.compat.v1.placeholder(tf.bool, params['shape'], name=self.inputs[0]) + tf.where(input_value1, name=self.outputs[0]) + net = sess.graph_def + return net + + @pytest.mark.parametrize("params", test_params) + @pytest.mark.nightly + def test_where(self, params, ie_device, precision, temp_dir): + self._test(ie_device, precision, temp_dir, params) diff --git a/tests/layer_tests/tensorflow_lite_tests/test_tfl_ZerosLike.py b/tests/layer_tests/tensorflow_lite_tests/test_tfl_ZerosLike.py new file mode 100644 index 00000000000..f685009ccc4 --- /dev/null +++ b/tests/layer_tests/tensorflow_lite_tests/test_tfl_ZerosLike.py @@ -0,0 +1,41 @@ +import pytest +import tensorflow as tf + +from common.tflite_layer_test_class import TFLiteLayerTest + +test_params = [ + {'shape': ([None], [10])}, + {'shape': ([None, 10], [1, 10])}, + {'shape': ([1, 10, None], [1, 10, 10])}, + {'shape': ([None, 2, 3, 4], [5, 2, 3, 4])}, + {'shape': ([5, 2, None, 4], [5, 2, 3, 4])} +] + + +class TestTFLiteBatchMatmulLayerTest(TFLiteLayerTest): + inputs = ["Input"] + outputs = ["ZerosLike"] + allowed_ops = ['MAXIMUM', 'ZEROS_LIKE'] + + def _prepare_input(self, inputs_dict, generator=None): + import numpy as np + inputs_dict['Input'] = np.float32(np.random.randint(0, 100, self.shapes_set[1])) + return inputs_dict + + def make_model(self, params): + assert len(set(params.keys()).intersection({'shape'})) == 1, \ + 'Unexpected parameters for test: ' + ','.join(params.keys()) + tf.compat.v1.reset_default_graph() + self.shapes_set = params['shape'] + + with tf.compat.v1.Session() as sess: + placeholder = tf.compat.v1.placeholder(dtype=tf.float32, shape=self.shapes_set[0], name=self.inputs[0]) + zeros = tf.zeros_like(placeholder) + tf.maximum(zeros, placeholder, name=self.outputs[0]) + net = sess.graph_def + return net + + @pytest.mark.parametrize("params", test_params) + @pytest.mark.nightly + def test_zeros_like(self, params, ie_device, precision, temp_dir): + self._test(ie_device, precision, temp_dir, params) diff --git a/tests/layer_tests/tensorflow_tests/test_tf_GatherNd.py b/tests/layer_tests/tensorflow_tests/test_tf_GatherNd.py index 352cca1c6ff..e9d2eefab2b 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_GatherNd.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_GatherNd.py @@ -12,7 +12,7 @@ class TestGatherNd(CommonTFLayerTest): assert 'params' in inputs_info assert 'indices' in inputs_info params_shape = inputs_info['params'] - indices_shape = inputs_info['indices'] + indices_shape = list(inputs_info['indices']) inputs_data = {} inputs_data['params'] = np.random.randint(-50, 50, params_shape).astype(self.params_type) # generate indices for each slice and concatenate