Added support of float16 in decoder proto, test (#13891)

This commit is contained in:
Georgy Krivoruchko 2022-11-29 17:17:04 +04:00 committed by GitHub
parent 090261c852
commit c4e15e0883
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 27 deletions

View File

@ -67,6 +67,9 @@ void extract_compressed_tensor_content(const ::tensorflow::TensorProto& tensor_p
case ov::element::i64:
val_i = tensor_proto.int64_val()[i];
break;
case ov::element::f16:
val_i = float16::from_bits(tensor_proto.half_val()[i]);
break;
case ov::element::f32:
val_i = tensor_proto.float_val()[i];
break;
@ -227,6 +230,10 @@ ov::Any DecoderProto::get_attribute(const std::string& name) const {
val_size = tensor_proto.int64_val_size();
extract_compressed_tensor_content<int64_t>(tensor_proto, val_size, &res);
break;
case ov::element::f16:
val_size = tensor_proto.half_val_size();
extract_compressed_tensor_content<float16>(tensor_proto, val_size, &res);
break;
case ov::element::f32:
val_size = tensor_proto.float_val_size();
extract_compressed_tensor_content<float>(tensor_proto, val_size, &res);

View File

@ -5,10 +5,11 @@ import pytest
from common.tf_layer_test_class import CommonTFLayerTest
from common.utils.tf_utils import permute_nchw_to_nhwc
import tensorflow as tf
import numpy as np
class TestBiasAdd(CommonTFLayerTest):
def create_bias_add_placeholder_const_net(self, shape, ir_version, use_new_frontend):
def create_bias_add_placeholder_const_net(self, shape, ir_version, use_new_frontend, output_type=tf.float32):
"""
Tensorflow net IR net
@ -18,9 +19,6 @@ class TestBiasAdd(CommonTFLayerTest):
"""
import tensorflow as tf
import numpy as np
tf.compat.v1.reset_default_graph()
# Create the graph and model
@ -30,8 +28,8 @@ class TestBiasAdd(CommonTFLayerTest):
tf_x_shape = permute_nchw_to_nhwc(tf_x_shape, use_new_frontend)
tf_y_shape = tf_x_shape[-1:]
x = tf.compat.v1.placeholder(tf.float32, tf_x_shape, 'Input')
constant_value = np.random.randint(0, 1, tf_y_shape).astype(np.float32)
x = tf.compat.v1.placeholder(output_type, tf_x_shape, 'Input')
constant_value = np.random.randint(0, 1, tf_y_shape).astype(output_type.as_numpy_dtype())
if (constant_value == 0).all():
# Avoid elimination of the layer from IR
constant_value = constant_value + 1
@ -42,17 +40,11 @@ class TestBiasAdd(CommonTFLayerTest):
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
#
# Create reference IR net
# Please, specify 'type': 'Input' for input node
# Moreover, do not forget to validate ALL layer attributes!!!
#
ref_net = None
return tf_net, ref_net
def create_bias_add_2_consts_net(self, shape, ir_version, use_new_frontend):
def create_bias_add_2_consts_net(self, shape, ir_version, use_new_frontend, output_type=tf.float32):
"""
Tensorflow net IR net
@ -68,9 +60,6 @@ class TestBiasAdd(CommonTFLayerTest):
# Create Tensorflow model
#
import tensorflow as tf
import numpy as np
tf.compat.v1.reset_default_graph()
tf_concat_axis = -1
@ -82,14 +71,14 @@ class TestBiasAdd(CommonTFLayerTest):
tf_x_shape = permute_nchw_to_nhwc(tf_x_shape, use_new_frontend)
tf_y_shape = tf_x_shape[-1:]
constant_value_x = np.random.randint(-256, 256, tf_x_shape).astype(np.float32)
constant_value_x = np.random.randint(-256, 256, tf_x_shape).astype(output_type.as_numpy_dtype())
x = tf.constant(constant_value_x)
constant_value_y = np.random.randint(-256, 256, tf_y_shape).astype(np.float32)
constant_value_y = np.random.randint(-256, 256, tf_y_shape).astype(output_type.as_numpy_dtype())
y = tf.constant(constant_value_y)
add = tf.nn.bias_add(x, y, name="Operation")
placeholder = tf.compat.v1.placeholder(tf.float32, tf_x_shape,
placeholder = tf.compat.v1.placeholder(output_type, tf_x_shape,
'Input') # Input_1 in graph_def
concat = tf.concat([placeholder, add], axis=tf_concat_axis, name='Operation')
@ -97,12 +86,6 @@ class TestBiasAdd(CommonTFLayerTest):
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
#
# Create reference IR net
# Please, specify 'type': 'Input' for input node
# Moreover, do not forget to validate ALL layer attributes!!!
#
ref_net = None
return tf_net, ref_net
@ -155,7 +138,8 @@ class TestBiasAdd(CommonTFLayerTest):
test_data_4D = [
dict(shape=[1, 1, 100, 224]),
pytest.param(dict(shape=[1, 3, 100, 224]), marks=pytest.mark.precommit_tf_fe)
pytest.param(dict(shape=[1, 3, 100, 224]), marks=pytest.mark.precommit_tf_fe),
pytest.param(dict(shape=[1, 3, 100, 224], output_type=tf.float16), marks=pytest.mark.precommit_tf_fe)
]
@pytest.mark.parametrize("params", test_data_4D)