[TF Hub API][TF FE] Support TF Keras Model OOB without example_input (#19892)
* [TF Hub] Cover TF Hub use cases with adoption to OpenVINO This is necessarily to demonstrate support of models programmed with TF Hub API through OV notebooks. Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Preserve original keras input and output tensor names * Add tests with TF Hub API models * No KerasLayer handling * Handle specific signature --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
a4cbac3dee
commit
df19699e3a
@ -63,7 +63,7 @@ def get_imported_module_version(imported_module):
|
||||
for attr in version_attrs:
|
||||
installed_version = getattr(imported_module, attr, None)
|
||||
if isinstance(installed_version, str):
|
||||
return installed_version
|
||||
return installed_version
|
||||
else:
|
||||
installed_version = None
|
||||
|
||||
@ -98,7 +98,8 @@ def get_environment_setup(framework):
|
||||
|
||||
def trace_tf_model_if_needed(input_model, placeholder_shapes, placeholder_data_types, example_input):
|
||||
import tensorflow as tf
|
||||
if not isinstance(input_model, (tf.keras.layers.Layer, tf.Module, tf.keras.Model, tf.types.experimental.GenericFunction)):
|
||||
if not isinstance(input_model,
|
||||
(tf.keras.layers.Layer, tf.Module, tf.keras.Model, tf.types.experimental.GenericFunction)):
|
||||
return input_model
|
||||
return trace_tf_model(input_model, placeholder_shapes, placeholder_data_types, example_input)
|
||||
|
||||
@ -175,6 +176,65 @@ def get_concrete_func(tf_function, example_input, input_needs_packing, error_mes
|
||||
return concrete_func
|
||||
|
||||
|
||||
def create_generic_function_from_keras_model(keras_model):
|
||||
import tensorflow as tf
|
||||
assert isinstance(keras_model, tf.keras.Model), \
|
||||
"[TensorFlow Frontend] internal error: the input model must be of Keras model type"
|
||||
if not hasattr(keras_model, 'input') or getattr(keras_model, 'input') is None:
|
||||
return None
|
||||
keras_input_signature = getattr(keras_model, 'input')
|
||||
tf_input_signature = None
|
||||
wrapper_function = None
|
||||
if isinstance(keras_input_signature, dict):
|
||||
tf_input_signature = []
|
||||
for tensor_name, tensor_spec in keras_input_signature.items():
|
||||
tf_input_signature.append(tf.TensorSpec(shape=tensor_spec.shape,
|
||||
dtype=tensor_spec.dtype,
|
||||
name=tensor_name))
|
||||
elif isinstance(keras_input_signature, list):
|
||||
tf_input_signature = []
|
||||
for tensor_spec in keras_input_signature:
|
||||
tf_input_signature.append(tf.TensorSpec(shape=tensor_spec.shape,
|
||||
dtype=tensor_spec.dtype,
|
||||
name=tensor_spec.name))
|
||||
else:
|
||||
try:
|
||||
# single KerasTensor case
|
||||
tf_input_signature = []
|
||||
tf_input_signature.append(tf.TensorSpec(shape=keras_input_signature.shape,
|
||||
dtype=keras_input_signature.dtype,
|
||||
name=keras_input_signature.name))
|
||||
except:
|
||||
tf_input_signature = None
|
||||
if tf_input_signature is not None:
|
||||
@tf.function(input_signature=tf_input_signature)
|
||||
def wrapper_function_dict(*args):
|
||||
input_dict = {}
|
||||
for ind, tensor_spec in enumerate(tf_input_signature):
|
||||
input_dict[tensor_spec.name] = args[ind]
|
||||
outputs = keras_model(input_dict)
|
||||
# need to wrap the output into dictionary
|
||||
# it helps to preserve original keras tensor names
|
||||
post_outputs = {}
|
||||
if isinstance(outputs, dict):
|
||||
for output_name, output_value in outputs.items():
|
||||
post_outputs[output_name] = output_value
|
||||
else:
|
||||
try:
|
||||
if isinstance(outputs, list) and isinstance(keras_model.outputs, list) and \
|
||||
len(outputs) == len(keras_model.outputs):
|
||||
for output_value, output_tensor in zip(outputs, keras_model.outputs):
|
||||
post_outputs[output_tensor.name] = output_value
|
||||
else:
|
||||
post_outputs[keras_model.output.name] = outputs
|
||||
except:
|
||||
post_outputs = outputs
|
||||
return post_outputs
|
||||
|
||||
wrapper_function = wrapper_function_dict
|
||||
return wrapper_function
|
||||
|
||||
|
||||
def trace_tf_model(model, input_shapes, input_types, example_input):
|
||||
import tensorflow as tf
|
||||
if isinstance(model.__call__, tf.types.experimental.GenericFunction):
|
||||
@ -183,12 +243,25 @@ def trace_tf_model(model, input_shapes, input_types, example_input):
|
||||
elif isinstance(model, tf.types.experimental.GenericFunction):
|
||||
tf_function = model
|
||||
input_needs_packing = False
|
||||
elif isinstance(model, tf.keras.Model):
|
||||
tf_function = create_generic_function_from_keras_model(model)
|
||||
if tf_function is not None:
|
||||
input_needs_packing = False
|
||||
else:
|
||||
# Wrap model to tf.Function.
|
||||
# In this case we loose input/output tensor names.
|
||||
@tf.function
|
||||
def tf_function(args):
|
||||
return model(*args)
|
||||
|
||||
input_needs_packing = True
|
||||
else:
|
||||
# Wrap model to tf.Function.
|
||||
# In this case we loose input/output tensor names.
|
||||
@tf.function
|
||||
def tf_function(args):
|
||||
return model(*args)
|
||||
|
||||
input_needs_packing = True
|
||||
|
||||
if example_input is not None:
|
||||
@ -216,7 +289,8 @@ def trace_tf_model(model, input_shapes, input_types, example_input):
|
||||
def type_supported_by_tf_fe(input_model):
|
||||
import tensorflow as tf
|
||||
# Types that require tracing
|
||||
if isinstance(input_model, (tf.keras.layers.Layer, tf.Module, tf.keras.Model, tf.types.experimental.GenericFunction)):
|
||||
if isinstance(input_model,
|
||||
(tf.keras.layers.Layer, tf.Module, tf.keras.Model, tf.types.experimental.GenericFunction)):
|
||||
return True
|
||||
# Types that do not require tracing
|
||||
if isinstance(input_model, (tf.Graph, tf.types.experimental.ConcreteFunction)):
|
||||
@ -246,7 +320,15 @@ def create_tf_graph_iterator(input_model, placeholder_shapes, placeholder_data_t
|
||||
if func_input.dtype == tf.resource:
|
||||
continue
|
||||
internal_tensor_names.append(func_input.name)
|
||||
if len(input_model.structured_input_signature) > 1 and \
|
||||
if len(input_model.structured_input_signature) > 0 and \
|
||||
len(internal_tensor_names) == len(input_model.structured_input_signature[0]):
|
||||
for internal_name, tensor_spec in zip(internal_tensor_names, input_model.structured_input_signature[0]):
|
||||
input_names_map = input_names_map or {}
|
||||
if not isinstance(tensor_spec, tf.TensorSpec):
|
||||
input_names_map = None
|
||||
break
|
||||
input_names_map[internal_name] = tensor_spec.name
|
||||
elif len(input_model.structured_input_signature) > 1 and \
|
||||
len(internal_tensor_names) == len(input_model.structured_input_signature[1]):
|
||||
external_tensor_names = sorted(input_model.structured_input_signature[1].keys())
|
||||
for internal_name, external_name in zip(internal_tensor_names, external_tensor_names):
|
||||
@ -262,6 +344,19 @@ def create_tf_graph_iterator(input_model, placeholder_shapes, placeholder_data_t
|
||||
for external_name, internal_name in zip(external_names, internal_names):
|
||||
output_names_map = output_names_map or {}
|
||||
output_names_map[internal_name] = external_name
|
||||
else:
|
||||
for external_name, internal_tensor in input_model.structured_outputs.items():
|
||||
internal_tf_tensor = None
|
||||
if isinstance(internal_tensor, tf.Tensor):
|
||||
internal_tf_tensor = internal_tensor
|
||||
if isinstance(internal_tensor, list) and len(internal_tensor) > 0 and \
|
||||
isinstance(internal_tensor[0], tf.Tensor):
|
||||
internal_tf_tensor = internal_tensor[0]
|
||||
if internal_tf_tensor is None:
|
||||
output_names_map = None
|
||||
break
|
||||
output_names_map = output_names_map or {}
|
||||
output_names_map[internal_tf_tensor.name] = external_name
|
||||
return GraphIteratorTFGraph(input_model.graph, share_weights, False, input_names_map, output_names_map)
|
||||
raise Exception("Could not wrap model of type {} to GraphIteratorTFGraph.".format(type(input_model)))
|
||||
|
||||
@ -271,7 +366,7 @@ def extract_model_graph(argv):
|
||||
import tensorflow as tf
|
||||
trackable_is_imported = False
|
||||
try:
|
||||
from tensorflow.python.training.tracking.base import Trackable # pylint: disable=no-name-in-module,import-error
|
||||
from tensorflow.python.training.tracking.base import Trackable # pylint: disable=no-name-in-module,import-error
|
||||
trackable_is_imported = True
|
||||
except:
|
||||
log.warning("Could not import tensorflow.python.training.tracking.base.Trackable type.")
|
||||
|
@ -35,9 +35,15 @@ class TestConvertModel:
|
||||
assert False, "Unsupported type {}".format(input_type)
|
||||
|
||||
def prepare_inputs(self, inputs_info):
|
||||
inputs = {}
|
||||
for input_name, input_shape, input_type in inputs_info:
|
||||
inputs[input_name] = self.prepare_input(input_shape, input_type)
|
||||
if len(inputs_info) > 0 and inputs_info[0] == 'list':
|
||||
inputs = []
|
||||
inputs_info = inputs_info[1:]
|
||||
for input_name, input_shape, input_type in inputs_info:
|
||||
inputs.append(self.prepare_input(input_shape, input_type))
|
||||
else:
|
||||
inputs = {}
|
||||
for input_name, input_shape, input_type in inputs_info:
|
||||
inputs[input_name] = self.prepare_input(input_shape, input_type)
|
||||
return inputs
|
||||
|
||||
def convert_model(self, model_obj):
|
||||
|
@ -0,0 +1,63 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import tensorflow as tf
|
||||
import tensorflow_hub as hub
|
||||
from models_hub_common.test_convert_model import TestConvertModel
|
||||
from tf_hub_tests.utils import get_input_info
|
||||
|
||||
|
||||
class TestTFHubApiNotebooks(TestConvertModel):
|
||||
def load_model(self, model_name, model_link):
|
||||
if model_name == 'mobilenet_v2_100_224_dict':
|
||||
image = tf.keras.layers.Input(shape=(224, 224, 3), dtype=tf.float32, name="image")
|
||||
feature_vector = hub.KerasLayer("https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/5",
|
||||
trainable=False)(image)
|
||||
softmax = tf.keras.layers.Dense(20, activation='softmax')(feature_vector)
|
||||
classification_model = tf.keras.Model(inputs={'image': image}, outputs={'softmax': softmax})
|
||||
return classification_model
|
||||
elif model_name == 'mobilenet_v2_100_224_list':
|
||||
image = tf.keras.layers.Input(shape=(224, 224, 3), dtype=tf.float32, name="image")
|
||||
feature_vector = hub.KerasLayer("https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/5",
|
||||
trainable=False)(image)
|
||||
softmax = tf.keras.layers.Dense(20, activation='softmax')(feature_vector)
|
||||
classification_model = tf.keras.Model(inputs=[image], outputs=[softmax])
|
||||
return classification_model
|
||||
else:
|
||||
raise "Unknown input model: {}".format(model_name)
|
||||
|
||||
def get_inputs_info(self, keras_model):
|
||||
inputs_info = []
|
||||
if isinstance(keras_model.input, dict):
|
||||
for input_name, input_tensor in keras_model.input.items():
|
||||
inputs_info.append(get_input_info(input_tensor, input_name))
|
||||
elif isinstance(keras_model.input, list):
|
||||
inputs_info.append('list')
|
||||
for input_tensor in keras_model.input:
|
||||
inputs_info.append(get_input_info(input_tensor, input_tensor.name))
|
||||
else:
|
||||
inputs_info.append('list')
|
||||
input_tensor = keras_model.input
|
||||
inputs_info.append(get_input_info(input_tensor, input_tensor.name))
|
||||
return inputs_info
|
||||
|
||||
def infer_fw_model(self, model_obj, inputs):
|
||||
outputs = model_obj(inputs)
|
||||
if isinstance(outputs, dict):
|
||||
post_outputs = {}
|
||||
for out_name, out_value in outputs.items():
|
||||
post_outputs[out_name] = out_value.numpy()
|
||||
elif isinstance(outputs, list):
|
||||
post_outputs = []
|
||||
for out_value in outputs:
|
||||
post_outputs.append(out_value.numpy())
|
||||
else:
|
||||
post_outputs = [outputs.numpy()]
|
||||
|
||||
return post_outputs
|
||||
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.parametrize("model_name", ['mobilenet_v2_100_224_dict', 'mobilenet_v2_100_224_list'])
|
||||
def test_tf_hub_api_notebook1(self, model_name, ie_device):
|
||||
self.run(model_name, '', ie_device)
|
@ -5,7 +5,6 @@ import gc
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import tensorflow as tf
|
||||
import tensorflow_hub as hub
|
||||
@ -14,6 +13,7 @@ import tensorflow_text # do not delete, needed for text models
|
||||
from models_hub_common.constants import tf_hub_cache_dir
|
||||
from models_hub_common.test_convert_model import TestConvertModel
|
||||
from models_hub_common.utils import get_models_list
|
||||
from tf_hub_tests.utils import type_map
|
||||
|
||||
|
||||
class TestTFHubConvertModel(TestConvertModel):
|
||||
@ -49,18 +49,6 @@ class TestTFHubConvertModel(TestConvertModel):
|
||||
except ValueError:
|
||||
# unknown rank case
|
||||
pass
|
||||
type_map = {
|
||||
tf.float64: np.float64,
|
||||
tf.float32: np.float32,
|
||||
tf.int8: np.int8,
|
||||
tf.int16: np.int16,
|
||||
tf.int32: np.int32,
|
||||
tf.int64: np.int64,
|
||||
tf.uint8: np.uint8,
|
||||
tf.uint16: np.uint16,
|
||||
tf.string: str,
|
||||
tf.bool: bool,
|
||||
}
|
||||
if input_info.dtype == tf.resource:
|
||||
# skip inputs corresponding to variables
|
||||
continue
|
||||
|
33
tests/model_hub_tests/tf_hub_tests/utils.py
Normal file
33
tests/model_hub_tests/tf_hub_tests/utils.py
Normal file
@ -0,0 +1,33 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
type_map = {
|
||||
tf.float64: np.float64,
|
||||
tf.float32: np.float32,
|
||||
tf.int8: np.int8,
|
||||
tf.int16: np.int16,
|
||||
tf.int32: np.int32,
|
||||
tf.int64: np.int64,
|
||||
tf.uint8: np.uint8,
|
||||
tf.uint16: np.uint16,
|
||||
tf.string: str,
|
||||
tf.bool: bool,
|
||||
}
|
||||
|
||||
|
||||
def get_input_info(input_tensor, input_name):
|
||||
input_shape = []
|
||||
try:
|
||||
for dim in input_tensor.shape.as_list():
|
||||
if dim is None:
|
||||
input_shape.append(1)
|
||||
else:
|
||||
input_shape.append(dim)
|
||||
except ValueError:
|
||||
# unknown rank case
|
||||
pass
|
||||
assert input_tensor.dtype in type_map, "Unsupported input type: {}".format(input_tensor.dtype)
|
||||
return input_name, input_shape, type_map[input_tensor.dtype]
|
Loading…
Reference in New Issue
Block a user