Temporary revert "[TF FE] Support different types" (#21656)
This reverts commit f9d20d5aa0.

parent 825778308b
commit bf00569ae1
@@ -18,12 +18,25 @@ namespace frontend {
namespace tensorflow {

bool GraphIteratorMeta::is_valid_signature(const ::tensorflow::SignatureDef& signature) const {
    const std::map<::tensorflow::DataType, ov::element::Type> types{
        {::tensorflow::DataType::DT_BOOL, ov::element::boolean},
        {::tensorflow::DataType::DT_INT16, ov::element::i16},
        {::tensorflow::DataType::DT_INT32, ov::element::i32},
        {::tensorflow::DataType::DT_INT64, ov::element::i64},
        {::tensorflow::DataType::DT_HALF, ov::element::f16},
        {::tensorflow::DataType::DT_FLOAT, ov::element::f32},
        {::tensorflow::DataType::DT_DOUBLE, ov::element::f64},
        {::tensorflow::DataType::DT_UINT8, ov::element::u8},
        {::tensorflow::DataType::DT_INT8, ov::element::i8},
        {::tensorflow::DataType::DT_BFLOAT16, ov::element::bf16},
        {::tensorflow::DataType::DT_STRING, ov::element::dynamic}};

    for (const auto& it : signature.inputs()) {
        if (it.second.name().empty())
        if (it.second.name().empty() || types.find(it.second.dtype()) == types.end())
            return false;
    }
    for (const auto& it : signature.outputs()) {
        if (it.second.name().empty())
        if (it.second.name().empty() || types.find(it.second.dtype()) == types.end())
            return false;
    }
    return true;
@@ -18,12 +18,25 @@ namespace frontend {
namespace tensorflow {

bool GraphIteratorSavedModel::is_valid_signature(const ::tensorflow::SignatureDef& signature) const {
    const std::map<::tensorflow::DataType, ov::element::Type> types{
        {::tensorflow::DataType::DT_BOOL, ov::element::boolean},
        {::tensorflow::DataType::DT_INT16, ov::element::i16},
        {::tensorflow::DataType::DT_INT32, ov::element::i32},
        {::tensorflow::DataType::DT_INT64, ov::element::i64},
        {::tensorflow::DataType::DT_HALF, ov::element::f16},
        {::tensorflow::DataType::DT_FLOAT, ov::element::f32},
        {::tensorflow::DataType::DT_DOUBLE, ov::element::f64},
        {::tensorflow::DataType::DT_UINT8, ov::element::u8},
        {::tensorflow::DataType::DT_INT8, ov::element::i8},
        {::tensorflow::DataType::DT_BFLOAT16, ov::element::bf16},
        {::tensorflow::DataType::DT_STRING, ov::element::dynamic}};

    for (const auto& it : signature.inputs()) {
        if (it.second.name().empty())
        if (it.second.name().empty() || types.find(it.second.dtype()) == types.end())
            return false;
    }
    for (const auto& it : signature.outputs()) {
        if (it.second.name().empty())
        if (it.second.name().empty() || types.find(it.second.dtype()) == types.end())
            return false;
    }
    return true;
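Both hunks above restore the same validation rule in GraphIteratorMeta and GraphIteratorSavedModel: a signature entry is accepted only when its tensor name is non-empty and its dtype appears in the local supported-type map. The following is a minimal, self-contained sketch of that lookup pattern; the enum values and function name are illustrative stand-ins, not the frontend's actual API.

#include <map>
#include <string>

// Illustrative stand-ins for ::tensorflow::DataType and ov::element::Type.
enum class DType { DT_FLOAT, DT_INT32, DT_STRING, DT_VARIANT };
enum class OvType { f32, i32, dynamic };

// Mirrors the restored condition: reject an empty name or an unmapped dtype.
bool is_supported_entry(const std::string& name, DType dtype) {
    static const std::map<DType, OvType> types{{DType::DT_FLOAT, OvType::f32},
                                               {DType::DT_INT32, OvType::i32},
                                               {DType::DT_STRING, OvType::dynamic}};
    return !name.empty() && types.find(dtype) != types.end();
}

With such a map, a DT_VARIANT entry fails the lookup and the whole signature is reported as invalid, which is how the restored is_valid_signature filters signatures.
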
@@ -82,17 +82,17 @@ void extract_tensor_content(const std::string& tensor_content, Tensor* values) {
#    pragma warning(disable : 4244)  // possible loss of data
#    pragma warning(disable : 4267)  // possible loss of data
#endif
template <typename SRC_T, typename DST_T = SRC_T>
template <typename T>
void extract_compressed_tensor_content(const ::tensorflow::TensorProto& tensor_proto,
                                       int64_t val_size,
                                       Tensor* values) {
    auto val_lastsaved = static_cast<SRC_T>(0);
    auto values_data = values->data<DST_T>();
    auto val_lastsaved = static_cast<T>(0);
    auto values_data = values->data<T>();
    for (size_t i = 0; i < values->get_size(); i++) {
        if (val_size == 0) {
            values_data[i] = static_cast<DST_T>(0);
            values_data[i] = static_cast<T>(0);
        } else if (static_cast<int64_t>(i) < val_size) {
            auto val_i = static_cast<SRC_T>(0);
            auto val_i = static_cast<T>(0);
            switch (values->get_element_type()) {
            // TODO: there are more element types to support here
            case boolean:
@@ -113,34 +113,13 @@ void extract_compressed_tensor_content(const ::tensorflow::TensorProto& tensor_p
            case f64:
                val_i = tensor_proto.double_val()[i];
                break;
            case u8:
                val_i = tensor_proto.int_val()[i];
                break;
            case u16:
                val_i = tensor_proto.int_val()[i];
                break;
            case u64:
                val_i = tensor_proto.uint64_val()[i];
                break;
            case i8:
                val_i = tensor_proto.int_val()[i];
                break;
            case bf16:
                val_i = bfloat16::from_bits(tensor_proto.half_val()[i]);
                break;
            case u32:
                val_i = tensor_proto.uint32_val()[i];
                break;
            case i16:
                val_i = tensor_proto.int_val()[i];
                break;
            default:
                FRONT_END_THROW("Encountered unknown element type " + values->get_element_type().get_type_name());
            }
            values_data[i] = static_cast<DST_T>(val_i);
            values_data[i] = val_i;
            val_lastsaved = val_i;
        } else {
            values_data[i] = static_cast<DST_T>(val_lastsaved);
            values_data[i] = val_lastsaved;
        }
    }
}
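The two hunks above are the heart of the revert. The restored helper reads and writes a single type T, while the reverted-out variant separated the protobuf storage type (SRC_T) from the destination element type (DST_T): TensorProto keeps narrow integer values such as DT_UINT8, DT_INT8 and DT_INT16 in its repeated int32 int_val field, so one type is not enough for those cases. Below is a hedged, standalone sketch of the two-parameter copy loop; FakeProto and copy_compressed are illustrative names, not OpenVINO code.

#include <cstdint>
#include <vector>

// Stand-in for a TensorProto whose uint8 data sits in the repeated int32 field int_val.
struct FakeProto {
    std::vector<int32_t> int_val;
};

// SRC_T: how a value is stored in the proto; DST_T: the tensor's element type.
template <typename SRC_T, typename DST_T = SRC_T>
void copy_compressed(const FakeProto& proto, int64_t val_size, DST_T* out, int64_t out_size) {
    auto last = static_cast<SRC_T>(0);
    for (int64_t i = 0; i < out_size; ++i) {
        if (val_size == 0) {
            out[i] = static_cast<DST_T>(0);                 // no stored values at all
        } else if (i < val_size) {
            auto v = static_cast<SRC_T>(proto.int_val[i]);  // read in the storage type
            out[i] = static_cast<DST_T>(v);                 // narrow to the element type
            last = v;
        } else {
            out[i] = static_cast<DST_T>(last);              // proto omits trailing repeats of the last value
        }
    }
}

A call such as copy_compressed<int32_t, uint8_t>(proto, proto.int_val.size(), dst, n) corresponds to the extract_compressed_tensor_content<int32_t, uint8_t> calls visible later in this diff.
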
@@ -171,18 +150,16 @@ bool CfMarkerType::is_copyable() const {
}

Type get_ov_type(const ::tensorflow::DataType& type) {
    using ::tensorflow::DataType;

    static map<DataType, Type> type_map{
        {DataType::DT_FLOAT, f32}, {DataType::DT_DOUBLE, f64}, {DataType::DT_INT32, i32},
        {DataType::DT_UINT8, u8}, {DataType::DT_INT16, i16}, {DataType::DT_INT8, i8},
        {DataType::DT_INT64, i64}, {DataType::DT_BOOL, boolean}, {DataType::DT_BFLOAT16, bf16},
        {DataType::DT_UINT16, u16}, {DataType::DT_HALF, f16}, {DataType::DT_UINT32, u32},
        {DataType::DT_UINT64, u64}, {DataType::DT_FLOAT_REF, f32}, {DataType::DT_DOUBLE_REF, f64},
        {DataType::DT_INT32_REF, i32}, {DataType::DT_UINT8_REF, u8}, {DataType::DT_INT16_REF, i16},
        {DataType::DT_INT8_REF, i8}, {DataType::DT_INT64_REF, i64}, {DataType::DT_BOOL_REF, boolean},
        {DataType::DT_BFLOAT16_REF, bf16}, {DataType::DT_UINT16_REF, u16}, {DataType::DT_HALF_REF, f16},
        {DataType::DT_UINT32_REF, u32}, {DataType::DT_UINT64_REF, u64}};
    static const map<::tensorflow::DataType, Type> type_map{{::tensorflow::DataType::DT_BOOL, boolean},
                                                            {::tensorflow::DataType::DT_INT16, i16},
                                                            {::tensorflow::DataType::DT_INT32, i32},
                                                            {::tensorflow::DataType::DT_INT64, i64},
                                                            {::tensorflow::DataType::DT_HALF, f16},
                                                            {::tensorflow::DataType::DT_FLOAT, f32},
                                                            {::tensorflow::DataType::DT_DOUBLE, f64},
                                                            {::tensorflow::DataType::DT_UINT8, u8},
                                                            {::tensorflow::DataType::DT_INT8, i8},
                                                            {::tensorflow::DataType::DT_BFLOAT16, bf16}};

    auto it = type_map.find(type);
    // for all unsupported types return dynamic type
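Both versions of get_ov_type are straight map lookups; they differ only in which DataType keys are present (the reverted-out map additionally covered the *_REF reference types and the unsigned 16/32/64-bit integers). The comment closing the hunk points at the fallback behaviour: a miss yields a dynamic type instead of an error. A minimal sketch of that lookup-with-fallback, with illustrative enum names rather than the real types:

#include <map>

enum class DataType { DT_FLOAT, DT_INT32, DT_STRING };
enum class Type { f32, i32, dynamic };

Type to_ov_type(DataType type) {
    static const std::map<DataType, Type> type_map{{DataType::DT_FLOAT, Type::f32},
                                                   {DataType::DT_INT32, Type::i32}};
    auto it = type_map.find(type);
    // For all unsupported types return a dynamic type instead of failing.
    return it == type_map.end() ? Type::dynamic : it->second;
}
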
@@ -214,49 +191,36 @@ Any unpack_tensor_proto(const ::tensorflow::TensorProto& tensor_proto,
        }
        return data;
    }

    Tensor res(ov_type, pshape.get_shape());
    auto tensor_content = tensor_proto.tensor_content();
    if (!tensor_content.empty() && tensor_proto.has_tensor_shape()) {
        switch (ov_type) {
        case f32:
            extract_tensor_content<float>(tensor_content, &res);
            break;
        case u8:
            extract_tensor_content<uint8_t>(tensor_content, &res);
            break;
        case i64:
            extract_tensor_content<int64_t>(tensor_content, &res);
            break;
        case u16:
            extract_tensor_content<uint16_t>(tensor_content, &res);
            break;
        case u64:
            extract_tensor_content<uint64_t>(tensor_content, &res);
            break;
        case i32:
            extract_tensor_content<int32_t>(tensor_content, &res);
            break;
        case i8:
            extract_tensor_content<int8_t>(tensor_content, &res);
            break;
        case bf16:
            extract_tensor_content<bfloat16>(tensor_content, &res);
        case i16:
            extract_tensor_content<int16_t>(tensor_content, &res);
            break;
        case u32:
            extract_tensor_content<uint32_t>(tensor_content, &res);
        case i32:
            extract_tensor_content<int32_t>(tensor_content, &res);
            break;
        case i64:
            extract_tensor_content<int64_t>(tensor_content, &res);
            break;
        case f16:
            extract_tensor_content<float16>(tensor_content, &res);
            break;
        case f32:
            extract_tensor_content<float>(tensor_content, &res);
            break;
        case f64:
            extract_tensor_content<double>(tensor_content, &res);
            break;
        case i16:
            extract_tensor_content<int16_t>(tensor_content, &res);
            break;
        case boolean:
            extract_tensor_content<bool>(tensor_content, &res);
            break;
        case f16:
            extract_tensor_content<float16>(tensor_content, &res);
        case bf16:
            extract_tensor_content<bfloat16>(tensor_content, &res);
            break;
        default:
            FRONT_END_THROW("Encountered unknown element type " + ov_type.get_type_name());
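In the branch above, the constant's data arrives as the packed byte string tensor_content, and extract_tensor_content<T> copies those raw bytes into the destination tensor as elements of the selected type. A hedged sketch of that byte-to-element copy, not the frontend's exact implementation:

#include <cstring>
#include <string>
#include <vector>

// Sketch: copy a packed little-endian byte string into a typed buffer,
// taking as many whole elements as the string contains.
template <typename T>
void bytes_to_elements(const std::string& tensor_content, std::vector<T>& out) {
    const size_t count = tensor_content.size() / sizeof(T);
    out.resize(count);
    std::memcpy(out.data(), tensor_content.data(), count * sizeof(T));
}
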
@@ -264,58 +228,30 @@ Any unpack_tensor_proto(const ::tensorflow::TensorProto& tensor_proto,
    } else {
        int64_t val_size = 0;
        switch (ov_type) {
        case f32:
            val_size = tensor_proto.float_val_size();
            extract_compressed_tensor_content<float>(tensor_proto, val_size, &res);
            break;
        case u8:
            val_size = tensor_proto.int_val_size();
            extract_compressed_tensor_content<int32_t, uint8_t>(tensor_proto, val_size, &res);
            break;
        case i64:
            val_size = tensor_proto.int64_val_size();
            extract_compressed_tensor_content<int64_t>(tensor_proto, val_size, &res);
            break;
        case u16:
            val_size = tensor_proto.int_val_size();
            extract_compressed_tensor_content<uint16_t, int32_t>(tensor_proto, val_size, &res);
            break;
        case u64:
            val_size = tensor_proto.uint64_val_size();
            extract_compressed_tensor_content<uint64_t>(tensor_proto, val_size, &res);
        case boolean:
            val_size = tensor_proto.bool_val_size();
            extract_compressed_tensor_content<bool>(tensor_proto, val_size, &res);
            break;
        case i32:
            val_size = tensor_proto.int_val_size();
            extract_compressed_tensor_content<int32_t>(tensor_proto, val_size, &res);
            break;
        case i8:
            val_size = tensor_proto.int_val_size();
            extract_compressed_tensor_content<int32_t, int8_t>(tensor_proto, val_size, &res);
        case i64:
            val_size = tensor_proto.int64_val_size();
            extract_compressed_tensor_content<int64_t>(tensor_proto, val_size, &res);
            break;
        case bf16:
        case f16:
            val_size = tensor_proto.half_val_size();
            extract_compressed_tensor_content<int32_t, bfloat16>(tensor_proto, val_size, &res);
            extract_compressed_tensor_content<float16>(tensor_proto, val_size, &res);
            break;
        case u32:
            val_size = tensor_proto.uint32_val_size();
            extract_compressed_tensor_content<uint32_t>(tensor_proto, val_size, &res);
        case f32:
            val_size = tensor_proto.float_val_size();
            extract_compressed_tensor_content<float>(tensor_proto, val_size, &res);
            break;
        case f64:
            val_size = tensor_proto.double_val_size();
            extract_compressed_tensor_content<double>(tensor_proto, val_size, &res);
            break;
        case i16:
            val_size = tensor_proto.int_val_size();
            extract_compressed_tensor_content<int32_t, int16_t>(tensor_proto, val_size, &res);
            break;
        case boolean:
            val_size = tensor_proto.bool_val_size();
            extract_compressed_tensor_content<bool>(tensor_proto, val_size, &res);
            break;
        case f16:
            val_size = tensor_proto.half_val_size();
            extract_compressed_tensor_content<int32_t, float16>(tensor_proto, val_size, &res);
            break;
        default:
            FRONT_END_THROW("Encountered unknown element type " + ov_type.get_type_name());
        }
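The else-branch handles constants that carry no tensor_content: their values live in the typed repeated fields of the proto (float_val, int_val, int64_val, half_val, bool_val, ...), val_size is the length of the matching field, and the extraction helper fills any remaining elements by repeating the last stored value. The sketch below only illustrates how the field is chosen per element type; FakeProto, ElemType and stored_value_count are assumed names for illustration, not the frontend's API.

#include <cstdint>
#include <vector>

// Stand-in for the typed repeated fields of a TensorProto.
struct FakeProto {
    std::vector<float> float_val;
    std::vector<int32_t> int_val;
    std::vector<int64_t> int64_val;
};

enum class ElemType { f32, i32, i64, u8 };

// Pick the repeated field whose length drives val_size for the given element type;
// the extraction helper then broadcasts short data over the full tensor.
int64_t stored_value_count(const FakeProto& proto, ElemType type) {
    switch (type) {
    case ElemType::f32:
        return static_cast<int64_t>(proto.float_val.size());
    case ElemType::i64:
        return static_cast<int64_t>(proto.int64_val.size());
    case ElemType::i32:
    case ElemType::u8:  // narrow integer types share the int32 field
        return static_cast<int64_t>(proto.int_val.size());
    }
    return 0;
}
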
@@ -1,51 +0,0 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import tensorflow as tf
from common.tf_layer_test_class import CommonTFLayerTest

rng = np.random.default_rng()


class TestAddTypes(CommonTFLayerTest):
    def _prepare_input(self, inputs_info):
        assert 'x' in inputs_info, "Test error: inputs_info must contain `x`"
        x_shape = inputs_info['x']
        inputs_data = {}
        if np.issubdtype(self.input_type, np.signedinteger):
            inputs_data['x'] = rng.integers(-8, 8, x_shape).astype(self.input_type)
        else:
            inputs_data['x'] = rng.integers(0, 8, x_shape).astype(self.input_type)
        return inputs_data

    def create_add_types_net(self, const_shape, input_type):
        self.input_type = input_type
        tf.compat.v1.reset_default_graph()
        # Create the graph and model
        with tf.compat.v1.Session() as sess:
            x = tf.compat.v1.placeholder(input_type, [], 'x')
            if np.issubdtype(self.input_type, np.signedinteger):
                const_value = rng.integers(-8, 8, const_shape).astype(self.input_type)
            else:
                const_value = rng.integers(0, 8, const_shape).astype(self.input_type)
            const_input = tf.constant(const_value, dtype=input_type)
            tf.raw_ops.Add(x=x, y=const_input)
            tf.compat.v1.global_variables_initializer()

            tf_net = sess.graph_def

        return tf_net, None

    @pytest.mark.parametrize("const_shape", [[], [2], [3, 4], [3, 2, 1, 4]])
    @pytest.mark.parametrize("input_type", [np.int8, np.uint8, np.int16,
                                            np.int32, np.int64,
                                            np.float16, np.float32, np.float64])
    @pytest.mark.precommit_tf_fe
    @pytest.mark.nightly
    def test_add_types(self, const_shape, input_type, ie_device, precision, ir_version, temp_dir,
                       use_new_frontend, use_old_api):
        self._test(*self.create_add_types_net(const_shape, input_type),
                   ie_device, precision, ir_version, temp_dir=temp_dir,
                   use_new_frontend=use_new_frontend, use_old_api=use_old_api)