From 951be58dc344d6a8a7de8503d1d6e7ee636ce88e Mon Sep 17 00:00:00 2001 From: Georgy Krivoruchko Date: Mon, 14 Nov 2022 15:50:43 +0400 Subject: [PATCH] Extended coverage of TensorFlow layer tests (#13547) --- .../tensorflow_tests/test_tf_ArgMinMax.py | 70 ++++++++++++++ .../tensorflow_tests/test_tf_BinaryOps.py | 23 +++-- .../tensorflow_tests/test_tf_Cast.py | 73 +++++++++++++++ .../tensorflow_tests/test_tf_Fill.py | 59 ++++++++++++ .../tensorflow_tests/test_tf_LogicalOps.py | 75 +++++++++++++++ .../tensorflow_tests/test_tf_MatMul.py | 91 +++++++++++++++++++ .../tensorflow_tests/test_tf_MinMax.py | 66 ++++++++++++++ .../tensorflow_tests/test_tf_Pooling.py | 56 ++++-------- .../tensorflow_tests/test_tf_TopK.py | 40 +------- .../tensorflow_tests/test_tf_UnaryOps.py | 9 +- 10 files changed, 474 insertions(+), 88 deletions(-) create mode 100644 tests/layer_tests/tensorflow_tests/test_tf_ArgMinMax.py create mode 100644 tests/layer_tests/tensorflow_tests/test_tf_Cast.py create mode 100644 tests/layer_tests/tensorflow_tests/test_tf_Fill.py create mode 100644 tests/layer_tests/tensorflow_tests/test_tf_LogicalOps.py create mode 100644 tests/layer_tests/tensorflow_tests/test_tf_MatMul.py create mode 100644 tests/layer_tests/tensorflow_tests/test_tf_MinMax.py diff --git a/tests/layer_tests/tensorflow_tests/test_tf_ArgMinMax.py b/tests/layer_tests/tensorflow_tests/test_tf_ArgMinMax.py new file mode 100644 index 00000000000..4a82d51b689 --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_ArgMinMax.py @@ -0,0 +1,70 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest + +from common.tf_layer_test_class import CommonTFLayerTest + +# Testing operation ArgMin, ArgMax (Initial Implementation) +# Documentation: https://www.tensorflow.org/api_docs/python/tf/raw_ops/ArgMin +# https://www.tensorflow.org/api_docs/python/tf/raw_ops/ArgMax + +class TestArgMinMax(CommonTFLayerTest): + # input_shape - should be an array + # dimension - dimension to be used, for vector should be 0 + # op_type - type of testing operation + # ir_version - common parameter + # use_new_frontend - common parameter + def create_argminmax_placeholder_const_net(self, input_shape, dimension, op_type, ir_version, use_new_frontend): + """ + Tensorflow net IR net + + Placeholder->op_type => Placeholder->TopK->Squeeze + / / / + Const-------/ Const-------/-----/ + + """ + + import tensorflow as tf + + tf.compat.v1.reset_default_graph() + + # Create the graph and model + with tf.compat.v1.Session() as sess: + op_type_to_tf = { + 'ArgMax': tf.raw_ops.ArgMax, + 'ArgMin': tf.raw_ops.ArgMin, + } + tf_input_shape = input_shape.copy() + tf_input = tf.compat.v1.placeholder(tf.float32, tf_input_shape, 'Input') + tf_dimension = tf.constant(dimension) + + op_type_to_tf[op_type](input = tf_input, dimension = tf_dimension) + + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + ref_net = None + + return tf_net, ref_net + + test_data = [ + dict(input_shape=[5], dimension=0), #Simple test of vector + pytest.param( + dict(input_shape=[2, 3], dimension=1), #Simple test + marks=pytest.mark.precommit_tf_fe), + dict(input_shape=[2, 3, 3, 4], dimension=2), #Simple test with possible nchw/nhcw + ] + + @pytest.mark.parametrize("params", test_data) + @pytest.mark.parametrize("op_type", ['ArgMin', 'ArgMax']) + @pytest.mark.precommit + @pytest.mark.nightly + def test_argminmax_placeholder_const(self, params, op_type, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_argminmax_placeholder_const_net(**params, op_type=op_type, + ir_version=ir_version, + use_new_frontend=use_new_frontend), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) diff --git a/tests/layer_tests/tensorflow_tests/test_tf_BinaryOps.py b/tests/layer_tests/tensorflow_tests/test_tf_BinaryOps.py index e54fa371640..cace1072348 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_BinaryOps.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_BinaryOps.py @@ -46,6 +46,8 @@ class TestBinaryOps(CommonTFLayerTest): Const-------/ Const-------/ """ + if not use_new_frontend and op_type == "Xdivy": + pytest.xfail(reason="95499") self.current_op_type = op_type @@ -53,6 +55,7 @@ class TestBinaryOps(CommonTFLayerTest): op_type_to_tf = { 'Add': tf.math.add, + 'AddV2': tf.raw_ops.AddV2, 'Sub': tf.math.subtract, 'Mul': tf.math.multiply, 'Div': tf.math.divide, @@ -72,8 +75,12 @@ class TestBinaryOps(CommonTFLayerTest): 'LogicalOr': tf.math.logical_or, 'LogicalXor': tf.math.logical_xor, 'FloorMod': tf.math.floormod, + 'FloorDiv': tf.math.floordiv, + 'Xdivy': tf.raw_ops.Xdivy, } + op_type_kw_args = [ 'AddV2', 'Xdivy' ] + type = np.float32 if op_type in ["LogicalAnd", "LogicalOr", "LogicalXor"]: type = np.bool @@ -93,17 +100,14 @@ class TestBinaryOps(CommonTFLayerTest): constant_value = constant_value + 1 y = tf.constant(constant_value, dtype=type) - op = op_type_to_tf[op_type](x, y, name="Operation") + if not op_type in op_type_kw_args: + op = op_type_to_tf[op_type](x, y, name="Operation") + else: + op = op_type_to_tf[op_type](x = x, y = y, name="Operation") tf.compat.v1.global_variables_initializer() tf_net = sess.graph_def - # - # Create reference IR net - # Please, specify 'type': 'Input' for input node - # Moreover, do not forget to validate ALL layer attributes!!! - # - ref_net = None return tf_net, ref_net @@ -114,11 +118,12 @@ class TestBinaryOps(CommonTFLayerTest): @pytest.mark.parametrize("params", test_data_precommits) @pytest.mark.parametrize("op_type", - ['Add', 'Sub', 'Mul', 'Div', 'RealDiv', 'SquaredDifference', 'Pow', + ['Add', 'AddV2', 'Sub', 'Mul', 'Div', 'RealDiv', 'SquaredDifference', 'Pow', 'Maximum', 'Minimum', 'Equal', 'NotEqual', 'Mod', 'Greater', 'GreaterEqual', 'Less', 'LessEqual', - 'LogicalAnd', 'LogicalOr', 'LogicalXor', 'FloorMod']) + 'LogicalAnd', 'LogicalOr', 'LogicalXor', 'FloorMod', 'FloorDiv', + 'Xdivy']) @pytest.mark.nightly @pytest.mark.precommit def test_binary_op(self, params, ie_device, precision, ir_version, temp_dir, op_type, diff --git a/tests/layer_tests/tensorflow_tests/test_tf_Cast.py b/tests/layer_tests/tensorflow_tests/test_tf_Cast.py new file mode 100644 index 00000000000..3f0fa147d18 --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_Cast.py @@ -0,0 +1,73 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest + +from common.tf_layer_test_class import CommonTFLayerTest + +# Testing Cast operation (Initial Implementation) +# Documentation: https://www.tensorflow.org/api_docs/python/tf/raw_ops/Cast + +class TestCastOp(CommonTFLayerTest): + input_type = np.float32 + + # Overload inputs generation to fill dummy input + def _prepare_input(self, inputs_dict): + for input in inputs_dict.keys(): + inputs_dict[input] = np.random.randint(-256, 256, inputs_dict[input]).astype(self.input_type) + return inputs_dict + + # input_shape - should be an array + # input_type - type of input value + # output_type - type of output value + # truncate - boolean flag of truncation + # ir_version - common parameter + # use_new_frontend - common parameter + def create_cast_op_placeholder_const_net(self, input_shape, input_type, output_type, truncate, ir_version, use_new_frontend): + if(input_type == output_type): + pytest.skip("Input and output types shouldn't be equal") + + import tensorflow as tf + + tf.compat.v1.reset_default_graph() + + self.input_type = input_type + + tf_types = { + np.int32: tf.int32, np.int64: tf.int64, + np.float16: tf.float16, np.float32: tf.float32, np.float64: tf.float64 + } + + # Create the graph and model + with tf.compat.v1.Session() as sess: + tf_input = tf.compat.v1.placeholder(tf_types[input_type], input_shape, 'Input') + + tf.raw_ops.Cast(x = input_shape, DstT = output_type, Truncate = truncate) + + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + ref_net = None + + return tf_net, ref_net + + test_data = [ + pytest.param( + dict(input_shape=[2, 3]), #Simple test + marks=pytest.mark.precommit_tf_fe), + dict(input_shape=[2, 3, 3, 4]), #Simple test with possible nchw/nhwc + ] + + @pytest.mark.parametrize("params", test_data) + @pytest.mark.parametrize("input_type", [ np.int32, np.int64, np.float16, np.float32, np.float64 ]) + @pytest.mark.parametrize("output_type", [ np.int32, np.int64, np.float16, np.float32, np.float64 ]) + @pytest.mark.parametrize("truncate", [ False, True ]) + @pytest.mark.nightly + def test_cast_op_placeholder_const(self, params, input_type, output_type, truncate, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_cast_op_placeholder_const_net(**params, ir_version=ir_version, + use_new_frontend=use_new_frontend, input_type=input_type, + output_type=output_type, truncate=truncate), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) \ No newline at end of file diff --git a/tests/layer_tests/tensorflow_tests/test_tf_Fill.py b/tests/layer_tests/tensorflow_tests/test_tf_Fill.py new file mode 100644 index 00000000000..46a4be61d0c --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_Fill.py @@ -0,0 +1,59 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest + +from common.tf_layer_test_class import CommonTFLayerTest + +# Testing Fill operation (Initial Implementation) +# Documentation: https://www.tensorflow.org/api_docs/python/tf/raw_ops/Fill + +class TestFillOps(CommonTFLayerTest): + stored_shape = [] + + # Overload inputs generation to fill dummy input + def _prepare_input(self, inputs_dict): + for input in inputs_dict.keys(): + inputs_dict[input] = np.ndarray([len(self.stored_shape)], buffer=np.array(self.stored_shape), dtype=np.int32) + # Return shape as is + return inputs_dict + + # input_shape - should be an array + # value - value which should be set to tensor + # ir_version - common parameter + # use_new_frontend - common parameter + def create_fill_ops_placeholder_const_net(self, input_shape, value, ir_version, use_new_frontend): + self.stored_shape = input_shape + + import tensorflow as tf + + tf.compat.v1.reset_default_graph() + + # Create the graph and model + with tf.compat.v1.Session() as sess: + tf_input = tf.compat.v1.placeholder(tf.int32, [len(input_shape)], 'Input') + + tf.raw_ops.Fill(dims = tf_input, value = value) + + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + ref_net = None + + return tf_net, ref_net + + test_data = [ + dict(input_shape=[2, 3], value=123), + dict(input_shape=[2, 3, 3, 4], value=123), + ] + + @pytest.mark.parametrize("params", test_data) + @pytest.mark.precommit_tf_fe + @pytest.mark.nightly + def test_fill_ops_placeholder_const(self, params, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_fill_ops_placeholder_const_net(**params, ir_version=ir_version, + use_new_frontend=use_new_frontend), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) \ No newline at end of file diff --git a/tests/layer_tests/tensorflow_tests/test_tf_LogicalOps.py b/tests/layer_tests/tensorflow_tests/test_tf_LogicalOps.py new file mode 100644 index 00000000000..f7adbc60c70 --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_LogicalOps.py @@ -0,0 +1,75 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest + +from common.tf_layer_test_class import CommonTFLayerTest + +# Testing Logical operations (Initial Implementation) +# Documentation: https://www.tensorflow.org/api_docs/python/tf/raw_ops/All +# https://www.tensorflow.org/api_docs/python/tf/raw_ops/Any + +class TestLogicalOps(CommonTFLayerTest): + # Overload inputs generation to fill dummy input + def _prepare_input(self, inputs_dict): + for input in inputs_dict.keys(): + inputs_dict[input] = np.random.randint(0, 2, inputs_dict[input]).astype(bool) + return inputs_dict + + # input_shape - should be an array + # axis - array which points on axis for the operation + # op_type - type of tested operation + # ir_version - common parameter + # use_new_frontend - common parameter + def create_logical_ops_placeholder_const_net(self, input_shape, axis, op_type, ir_version, use_new_frontend): + """ + Tensorflow net IR net + + Placeholder->op_type => Placeholder->ReduceLogicalAnd/Or + / / + Const-------/ Const-------/ + + """ + if not use_new_frontend and op_type == "Any": + pytest.xfail(reason="95499") + + import tensorflow as tf + + tf.compat.v1.reset_default_graph() + + # Create the graph and model + with tf.compat.v1.Session() as sess: + op_type_to_tf = { + 'All': tf.raw_ops.All, + 'Any': tf.raw_ops.Any, + } + tf_input = tf.compat.v1.placeholder(tf.bool, input_shape, 'Input') + tf_axis = tf.constant(axis) + + op_type_to_tf[op_type](input = tf_input, axis = tf_axis) + + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + ref_net = None + + return tf_net, ref_net + + test_data = [ + pytest.param( + dict(input_shape=[2, 3], axis=[1]), #Simple test + marks=pytest.mark.precommit_tf_fe), + dict(input_shape=[2, 3, 3, 4], axis=[2]), #Simple test with possible nchw/nhwc + ] + + @pytest.mark.parametrize("params", test_data) + @pytest.mark.parametrize("op_type", ['All', 'Any']) + @pytest.mark.precommit + @pytest.mark.nightly + def test_logical_ops_placeholder_const(self, params, op_type, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_logical_ops_placeholder_const_net(**params, op_type=op_type, ir_version=ir_version, + use_new_frontend=use_new_frontend), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) diff --git a/tests/layer_tests/tensorflow_tests/test_tf_MatMul.py b/tests/layer_tests/tensorflow_tests/test_tf_MatMul.py new file mode 100644 index 00000000000..208283b35a3 --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_MatMul.py @@ -0,0 +1,91 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +from common.tf_layer_test_class import CommonTFLayerTest + + +class TestMatMul(CommonTFLayerTest): + + def create_net_with_matmul_op(self, x_shape, y_shape, x_bool, y_bool, op_type, ir_version, use_new_frontend): + import tensorflow as tf + op_type_to_tf = { + 'BatchMatMul': tf.raw_ops.BatchMatMul, + 'BatchMatMulV2': tf.raw_ops.BatchMatMulV2, + 'BatchMatMulV3': tf.raw_ops.BatchMatMulV3, + 'MatMul': tf.raw_ops.MatMul, + } + + tf.compat.v1.reset_default_graph() + + # Create the graph and model + with tf.compat.v1.Session() as sess: + tf_x = tf.compat.v1.placeholder(tf.float32, x_shape, 'InputX') + tf_y = tf.compat.v1.placeholder(tf.float32, y_shape, 'InputY') + if op_type == 'MatMul': + if len(x_shape) != 2 or len(y_shape) != 2: + pytest.skip("MatMul doesn't support rank != 2") + op_type_to_tf[op_type](a=tf_x, b=tf_y, transpose_a=x_bool, transpose_b=y_bool, name='Operation') + elif op_type == 'BatchMatMul': + if len(x_shape) != len(y_shape): + pytest.skip("BatchMatMul doesn't support broadcast") + op_type_to_tf[op_type](x=tf_x, y=tf_y, adj_x=x_bool, adj_y=y_bool, name='Operation') + elif op_type == 'BatchMatMulV2': + op_type_to_tf[op_type](x=tf_x, y=tf_y, adj_x=x_bool, adj_y=y_bool, name='Operation') + elif op_type == 'BatchMatMulV3': + op_type_to_tf[op_type](x=tf_x, y=tf_y, Tout=tf.float32, adj_x=x_bool, adj_y=y_bool, name='Operation') + else: + raise RuntimeError("Undknown operation") + + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + ref_net = None + + return tf_net, ref_net + + test_data_precommit = [ + dict(x_shape=[2, 4, 4], y_shape=[2, 4, 4]), #Tests 2D shapes + dict(x_shape=[2, 3, 4, 4], y_shape=[4, 4]), #Tests broadcast + ] + + @pytest.mark.parametrize("params", test_data_precommit) + @pytest.mark.parametrize("op_type", ['BatchMatMul', + 'BatchMatMulV2', + #'BatchMatMulV3', #Isn't supported + 'MatMul', + ]) + @pytest.mark.precommit_tf_fe + def test_matmul_op_precommit(self, params, ie_device, precision, ir_version, temp_dir, op_type, + use_new_frontend, use_old_api): + self._test(*self.create_net_with_matmul_op(**params, ir_version=ir_version, op_type=op_type, + use_new_frontend=use_new_frontend, x_bool=False, y_bool=False), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) + + test_data = test_data_precommit + [ + dict(x_shape=[2, 3, 4, 4], y_shape=[2, 3, 4, 4]), #Tests 4D shapes + ] + + @pytest.mark.parametrize("params", test_data) + @pytest.mark.parametrize("op_type", ['BatchMatMul', + 'BatchMatMulV2', + #'BatchMatMulV3', #Isn't supported + 'MatMul', + ]) + @pytest.mark.parametrize("x_bool", [ + False, + True + ]) + @pytest.mark.parametrize("y_bool", [ + False, + True + ]) + @pytest.mark.nightly + def test_matmul_op_nightly(self, params, ie_device, precision, ir_version, temp_dir, op_type, + x_bool, y_bool, use_new_frontend, use_old_api): + self._test(*self.create_net_with_matmul_op(**params, ir_version=ir_version, op_type=op_type, + use_new_frontend=use_new_frontend, x_bool=x_bool, y_bool=y_bool), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) diff --git a/tests/layer_tests/tensorflow_tests/test_tf_MinMax.py b/tests/layer_tests/tensorflow_tests/test_tf_MinMax.py new file mode 100644 index 00000000000..06ff92cb14f --- /dev/null +++ b/tests/layer_tests/tensorflow_tests/test_tf_MinMax.py @@ -0,0 +1,66 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest + +from common.tf_layer_test_class import CommonTFLayerTest + +# Testing Min, Max operations (Initial Implementation) +# Documentation: https://www.tensorflow.org/api_docs/python/tf/raw_ops/Min +# https://www.tensorflow.org/api_docs/python/tf/raw_ops/Max + +class TestMinMaxOps(CommonTFLayerTest): + # input_shape - should be an array + # axis - array which points on axis for the operation + # op_type - type of tested operation + # ir_version - common parameter + # use_new_frontend - common parameter + def create_minmax_ops_placeholder_const_net(self, input_shape, axis, op_type, keep_dims, ir_version, use_new_frontend): + """ + Tensorflow net IR net + + Placeholder->op_type => Placeholder->ReduceMin/Max + + """ + + import tensorflow as tf + + tf.compat.v1.reset_default_graph() + + # Create the graph and model + with tf.compat.v1.Session() as sess: + op_type_to_tf = { + 'Min': tf.raw_ops.Min, + 'Max': tf.raw_ops.Max, + } + tf_input = tf.compat.v1.placeholder(tf.float32, input_shape, 'Input') + tf_axis = tf.constant(axis) + + op_type_to_tf[op_type](input = tf_input, axis = tf_axis, keep_dims = keep_dims) + + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + + ref_net = None + + return tf_net, ref_net + + test_data = [ + pytest.param( + dict(input_shape=[2, 3], axis=[1]), #Simple test + marks=pytest.mark.precommit_tf_fe), + dict(input_shape=[2, 3, 3, 4], axis=[2]), #Simple test with possible nchw/nhwc + ] + + @pytest.mark.parametrize("params", test_data) + @pytest.mark.parametrize("op_type", ['Min', 'Max']) + @pytest.mark.parametrize("keep_dims", [False, True]) + @pytest.mark.precommit + @pytest.mark.nightly + def test_minmax_ops_placeholder_const(self, params, op_type, keep_dims, ie_device, precision, ir_version, temp_dir, + use_new_frontend, use_old_api): + self._test(*self.create_minmax_ops_placeholder_const_net(**params, op_type=op_type, ir_version=ir_version, + use_new_frontend=use_new_frontend, keep_dims=keep_dims), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_new_frontend=use_new_frontend, use_old_api=use_old_api) \ No newline at end of file diff --git a/tests/layer_tests/tensorflow_tests/test_tf_Pooling.py b/tests/layer_tests/tensorflow_tests/test_tf_Pooling.py index 2ae54cd1a3d..daa8094d9bc 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_Pooling.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_Pooling.py @@ -36,10 +36,10 @@ class TestPooling(CommonTFLayerTest): kernel = [1, kernel_size[0], kernel_size[1], 1] if method == 'max': - tf.nn.max_pool2d(input=input, ksize=kernel, strides=stride, padding=padding, + tf.raw_ops.MaxPool(input=input, ksize=kernel, strides=stride, padding=padding, name='Operation') elif method == 'avg': - tf.nn.avg_pool2d(input=input, ksize=kernel, strides=stride, padding=padding, + tf.raw_ops.AvgPool(value=input, ksize=kernel, strides=stride, padding=padding, name='Operation') # 5D tensors @@ -51,51 +51,17 @@ class TestPooling(CommonTFLayerTest): kernel = [1, kernel_size[0], kernel_size[1], kernel_size[2], 1] if method == 'max': - tf.nn.max_pool3d(input, kernel, stride, padding, - name='Operation') # , data_format='NCHW') + tf.raw_ops.MaxPool3D(input=input, ksize=kernel, strides=stride, padding=padding, + name='Operation') elif method == 'avg': - tf.nn.avg_pool3d(input, kernel, stride, padding, - name='Operation') # , data_format='NCHW') + tf.raw_ops.AvgPool3D(input=input, ksize=kernel, strides=stride, padding=padding, + name='Operation') tf.compat.v1.global_variables_initializer() tf_net = sess.graph_def - # - # Create reference IR net - # Please, specify 'type': 'Input' for input node - # Moreover, do not forget to validate ALL layer attributes!!! - # - ref_net = None - if check_ir_version(10, None, ir_version) and not use_new_frontend: - nodes_attributes = { - 'input': {'kind': 'op', 'type': 'Parameter'}, - 'input_data': {'shape': in_shape, 'kind': 'data'}, - 'pooling': {'kernel': kernel_size, 'pads_begin': pads_begin, 'pads_end': pads_end, - 'strides': strides, 'kind': 'op', 'type': None}, - 'pooling_data': {'shape': out_shape, 'kind': 'data'}, - 'result': {'kind': 'op', 'type': 'Result'}, - 'pooling_indicies_data': {'kind': 'data', 'shape': out_shape} - } - - if method == 'avg': - nodes_attributes['pooling']['type'] = 'AvgPool' - elif method == 'max': - nodes_attributes['pooling']['type'] = 'MaxPool' - - edges = [('input', 'input_data'), - ('input_data', 'pooling'), - ('pooling', 'pooling_data', {'out': 0}), - ('pooling_data', 'result')] - - if method == 'max': - edges.append(('pooling', 'pooling_indicies_data', {'out': 1})) - - ref_net = build_graph(nodes_attributes, - edges=edges, - nodes_with_edges_only=True) - return tf_net, ref_net test_data_4D = [] @@ -103,9 +69,11 @@ class TestPooling(CommonTFLayerTest): test_data_4D.extend([dict(kernel_size=[1, 1], strides=[1, 1], pads=[[0, 0], [0, 0], 'SAME'], in_shape=[1, 3, 224, 224], out_shape=[1, 3, 224, 224], method=method), + pytest.param( dict(kernel_size=[2, 2], strides=[2, 2], pads=[[0, 0], [0, 0], 'SAME'], in_shape=[1, 3, 224, 224], out_shape=[1, 3, 112, 112], method=method), + marks=pytest.mark.precommit_tf_fe), dict(kernel_size=[2, 4], strides=[2, 4], pads=[[0, 0], [0, 0], 'SAME'], in_shape=[1, 3, 224, 224], out_shape=[1, 3, 112, 56], method=method), @@ -127,9 +95,11 @@ class TestPooling(CommonTFLayerTest): dict(kernel_size=[2, 3], strides=[2, 3], pads=[[0, 0], [0, 1], 'SAME'], in_shape=[1, 3, 224, 224], out_shape=[1, 3, 112, 75], method=method), + pytest.param( dict(kernel_size=[111, 111], strides=[111, 111], pads=[[54, 54], [55, 55], 'SAME'], in_shape=[1, 3, 224, 224], out_shape=[1, 3, 3, 3], method=method), + marks=pytest.mark.precommit_tf_fe), dict(kernel_size=[111, 113], strides=[111, 113], pads=[[54, 1], [55, 1], 'SAME'], in_shape=[1, 3, 224, 224], out_shape=[1, 3, 3, 2], method=method), @@ -146,8 +116,10 @@ class TestPooling(CommonTFLayerTest): in_shape=[1, 3, 224, 224], out_shape=[1, 3, 224, 224], method=method), dict(kernel_size=[2, 2], strides=[2, 2], pads=[[0, 0], [0, 0], 'VALID'], in_shape=[1, 3, 224, 224], out_shape=[1, 3, 112, 112], method=method), + pytest.param( dict(kernel_size=[2, 4], strides=[2, 4], pads=[[0, 0], [0, 0], 'VALID'], in_shape=[1, 3, 224, 224], out_shape=[1, 3, 112, 56], method=method), + marks=pytest.mark.precommit_tf_fe), dict(kernel_size=[4, 2], strides=[4, 2], pads=[[0, 0], [0, 0], 'VALID'], in_shape=[1, 3, 224, 224], out_shape=[1, 3, 56, 112], method=method), dict(kernel_size=[2, 3], strides=[2, 3], pads=[[0, 0], [0, 0], 'VALID'], @@ -185,8 +157,10 @@ class TestPooling(CommonTFLayerTest): test_data_5D.extend( [dict(kernel_size=[1, 1, 1], strides=[1, 1, 1], pads=[[0, 0, 0], [0, 0, 0], 'SAME'], in_shape=[1, 3, 224, 224, 224], out_shape=[1, 3, 224, 224, 224], method=method), + pytest.param( dict(kernel_size=[2, 2, 2], strides=[2, 2, 2], pads=[[0, 0, 0], [0, 0, 0], 'SAME'], in_shape=[1, 3, 224, 224, 224], out_shape=[1, 3, 112, 112, 112], method=method), + marks=pytest.mark.precommit_tf_fe), dict(kernel_size=[2, 2, 4], strides=[2, 2, 4], pads=[[0, 0, 0], [0, 0, 0], 'SAME'], in_shape=[1, 3, 224, 224, 224], out_shape=[1, 3, 112, 112, 56], method=method), dict(kernel_size=[4, 2, 2], strides=[4, 2, 2], pads=[[0, 0, 0], [0, 0, 0], 'SAME'], @@ -217,8 +191,10 @@ class TestPooling(CommonTFLayerTest): test_data_5D.extend( [dict(kernel_size=[1, 1, 1], strides=[1, 1, 1], pads=[[0, 0, 0], [0, 0, 0], 'VALID'], in_shape=[1, 3, 224, 224, 224], out_shape=[1, 3, 224, 224, 224], method=method), + pytest.param( dict(kernel_size=[2, 2, 2], strides=[2, 2, 2], pads=[[0, 0, 0], [0, 0, 0], 'VALID'], in_shape=[1, 3, 224, 224, 224], out_shape=[1, 3, 112, 112, 112], method=method), + marks=pytest.mark.precommit_tf_fe), dict(kernel_size=[2, 2, 4], strides=[2, 2, 4], pads=[[0, 0, 0], [0, 0, 0], 'VALID'], in_shape=[1, 3, 224, 224, 224], out_shape=[1, 3, 112, 112, 56], method=method), dict(kernel_size=[4, 2, 2], strides=[4, 2, 2], pads=[[0, 0, 0], [0, 0, 0], 'VALID'], diff --git a/tests/layer_tests/tensorflow_tests/test_tf_TopK.py b/tests/layer_tests/tensorflow_tests/test_tf_TopK.py index a2c707b4592..feb19110dac 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_TopK.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_TopK.py @@ -29,6 +29,8 @@ class Test_TopK(CommonTFLayerTest): """ + pytest.xfail(reason="95063") + import tensorflow as tf tf.compat.v1.reset_default_graph() @@ -43,46 +45,8 @@ class Test_TopK(CommonTFLayerTest): tf.compat.v1.global_variables_initializer() tf_net = sess.graph_def - # - # Create reference IR net - # - topk_output_shape = shape.copy() - inverse_nhwc_nchw = PermuteAttrs.get_nhwc_to_nchw_permutation(len(topk_output_shape)).inv - topk_axis = permute_axis(len(topk_output_shape) - 1, - inverse_nhwc_nchw) # we need to permute axis attribute - topk_output_shape[topk_axis] = k - ref_net = None - if check_ir_version(10, None, ir_version) and not use_new_frontend: - nodes_attributes = { - 'input': {'kind': 'op', 'type': 'Parameter'}, - 'input_data': {'shape': shape, 'kind': 'data'}, - 'Const_k_input_data': {'shape': [], 'kind': 'data'}, - 'Const_k': {'kind': 'op', 'type': 'Const'}, - 'Const_k_data': {'shape': [], 'kind': 'data'}, - 'TopK': {'kind': 'op', 'type': 'TopK', 'axis': topk_axis, 'mode': 'max', - 'sort': 'value'}, - 'TopK_data_1': {'shape': topk_output_shape, 'kind': 'data'}, - 'TopK_data_2': {'shape': topk_output_shape, 'kind': 'data'}, - 'result_1': {'kind': 'op', 'type': 'Result'}, - 'result_2': {'kind': 'op', 'type': 'Result'}, - } - - ref_net = build_graph(nodes_attributes, - [('input', 'input_data'), - ('input_data', 'TopK', {'in': 0}), - - ('Const_k_input_data', 'Const_k'), - ('Const_k', 'Const_k_data'), - ('Const_k_data', 'TopK', {'in': 1}), - - ('TopK', 'TopK_data_1', {'out': 0}), - ('TopK', 'TopK_data_2', {'out': 1}), - ('TopK_data_1', 'result_1'), - ('TopK_data_2', 'result_2'), - ]) - return tf_net, ref_net test_data_1D = [ diff --git a/tests/layer_tests/tensorflow_tests/test_tf_UnaryOps.py b/tests/layer_tests/tensorflow_tests/test_tf_UnaryOps.py index 72f2cc88ede..ff7aec52415 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_UnaryOps.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_UnaryOps.py @@ -67,6 +67,7 @@ class TestUnaryOps(CommonTFLayerTest): 'Cos': tf.math.cos, 'Cosh': tf.math.cosh, 'Elu': tf.nn.elu, + 'Erf': tf.math.erf, 'Exp': tf.math.exp, 'Floor': tf.math.floor, 'Log': tf.math.log, @@ -77,6 +78,7 @@ class TestUnaryOps(CommonTFLayerTest): 'Sin': tf.math.sin, 'Sinh': tf.math.sinh, 'SoftPlus': tf.nn.softplus, + 'Square': tf.math.square, 'Tan': tf.math.tan, 'Tanh': tf.math.tanh, 'ReLU': tf.nn.relu, @@ -151,6 +153,8 @@ class TestUnaryOps(CommonTFLayerTest): 'Acosh', 'Asinh', 'LogicalNot', + 'Square', + 'Erf', ]) @pytest.mark.precommit def test_unary_op_precommit(self, params, ie_device, precision, ir_version, temp_dir, op_type, @@ -191,7 +195,10 @@ class TestUnaryOps(CommonTFLayerTest): 'SoftPlus', 'Atanh', 'Acosh', - 'Asinh']) + 'Asinh', + 'Square', + 'Erf', + ]) @pytest.mark.nightly def test_unary_op(self, params, ie_device, precision, ir_version, temp_dir, op_type, use_new_frontend, use_old_api):