diff --git a/src/frontends/tensorflow/src/decoder_proto.cpp b/src/frontends/tensorflow/src/decoder_proto.cpp index 309f75ed886..f42c9bfbb4c 100644 --- a/src/frontends/tensorflow/src/decoder_proto.cpp +++ b/src/frontends/tensorflow/src/decoder_proto.cpp @@ -67,6 +67,9 @@ void extract_compressed_tensor_content(const ::tensorflow::TensorProto& tensor_p case ov::element::i64: val_i = tensor_proto.int64_val()[i]; break; + case ov::element::f16: + val_i = float16::from_bits(tensor_proto.half_val()[i]); + break; case ov::element::f32: val_i = tensor_proto.float_val()[i]; break; @@ -227,6 +230,10 @@ ov::Any DecoderProto::get_attribute(const std::string& name) const { val_size = tensor_proto.int64_val_size(); extract_compressed_tensor_content(tensor_proto, val_size, &res); break; + case ov::element::f16: + val_size = tensor_proto.half_val_size(); + extract_compressed_tensor_content(tensor_proto, val_size, &res); + break; case ov::element::f32: val_size = tensor_proto.float_val_size(); extract_compressed_tensor_content(tensor_proto, val_size, &res); diff --git a/tests/layer_tests/tensorflow_tests/test_tf_BiasAdd.py b/tests/layer_tests/tensorflow_tests/test_tf_BiasAdd.py index f851820afe2..90f74819f92 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_BiasAdd.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_BiasAdd.py @@ -5,10 +5,11 @@ import pytest from common.tf_layer_test_class import CommonTFLayerTest from common.utils.tf_utils import permute_nchw_to_nhwc - +import tensorflow as tf +import numpy as np class TestBiasAdd(CommonTFLayerTest): - def create_bias_add_placeholder_const_net(self, shape, ir_version, use_new_frontend): + def create_bias_add_placeholder_const_net(self, shape, ir_version, use_new_frontend, output_type=tf.float32): """ Tensorflow net IR net @@ -18,9 +19,6 @@ class TestBiasAdd(CommonTFLayerTest): """ - import tensorflow as tf - import numpy as np - tf.compat.v1.reset_default_graph() # Create the graph and model @@ -30,8 +28,8 @@ class TestBiasAdd(CommonTFLayerTest): tf_x_shape = permute_nchw_to_nhwc(tf_x_shape, use_new_frontend) tf_y_shape = tf_x_shape[-1:] - x = tf.compat.v1.placeholder(tf.float32, tf_x_shape, 'Input') - constant_value = np.random.randint(0, 1, tf_y_shape).astype(np.float32) + x = tf.compat.v1.placeholder(output_type, tf_x_shape, 'Input') + constant_value = np.random.randint(0, 1, tf_y_shape).astype(output_type.as_numpy_dtype()) if (constant_value == 0).all(): # Avoid elimination of the layer from IR constant_value = constant_value + 1 @@ -42,17 +40,11 @@ class TestBiasAdd(CommonTFLayerTest): tf.compat.v1.global_variables_initializer() tf_net = sess.graph_def - # - # Create reference IR net - # Please, specify 'type': 'Input' for input node - # Moreover, do not forget to validate ALL layer attributes!!! - # - ref_net = None return tf_net, ref_net - def create_bias_add_2_consts_net(self, shape, ir_version, use_new_frontend): + def create_bias_add_2_consts_net(self, shape, ir_version, use_new_frontend, output_type=tf.float32): """ Tensorflow net IR net @@ -68,9 +60,6 @@ class TestBiasAdd(CommonTFLayerTest): # Create Tensorflow model # - import tensorflow as tf - import numpy as np - tf.compat.v1.reset_default_graph() tf_concat_axis = -1 @@ -82,14 +71,14 @@ class TestBiasAdd(CommonTFLayerTest): tf_x_shape = permute_nchw_to_nhwc(tf_x_shape, use_new_frontend) tf_y_shape = tf_x_shape[-1:] - constant_value_x = np.random.randint(-256, 256, tf_x_shape).astype(np.float32) + constant_value_x = np.random.randint(-256, 256, tf_x_shape).astype(output_type.as_numpy_dtype()) x = tf.constant(constant_value_x) - constant_value_y = np.random.randint(-256, 256, tf_y_shape).astype(np.float32) + constant_value_y = np.random.randint(-256, 256, tf_y_shape).astype(output_type.as_numpy_dtype()) y = tf.constant(constant_value_y) add = tf.nn.bias_add(x, y, name="Operation") - placeholder = tf.compat.v1.placeholder(tf.float32, tf_x_shape, + placeholder = tf.compat.v1.placeholder(output_type, tf_x_shape, 'Input') # Input_1 in graph_def concat = tf.concat([placeholder, add], axis=tf_concat_axis, name='Operation') @@ -97,12 +86,6 @@ class TestBiasAdd(CommonTFLayerTest): tf.compat.v1.global_variables_initializer() tf_net = sess.graph_def - # - # Create reference IR net - # Please, specify 'type': 'Input' for input node - # Moreover, do not forget to validate ALL layer attributes!!! - # - ref_net = None return tf_net, ref_net @@ -155,7 +138,8 @@ class TestBiasAdd(CommonTFLayerTest): test_data_4D = [ dict(shape=[1, 1, 100, 224]), - pytest.param(dict(shape=[1, 3, 100, 224]), marks=pytest.mark.precommit_tf_fe) + pytest.param(dict(shape=[1, 3, 100, 224]), marks=pytest.mark.precommit_tf_fe), + pytest.param(dict(shape=[1, 3, 100, 224], output_type=tf.float16), marks=pytest.mark.precommit_tf_fe) ] @pytest.mark.parametrize("params", test_data_4D)