Fix of tf.GenericFunction conversion in convert_model() (#17125)

* Added GenericFunction support, fixed tf.Function test.

* Added test, added TF version checks.

* Small correction

* Removed Trackable type support.

* Small correction.
This commit is contained in:
Anastasiia Pnevskaia 2023-04-25 00:57:56 +02:00 committed by GitHub
parent ce23ce00f1
commit 00847cba7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 76 additions and 37 deletions

View File

@ -1,13 +1,15 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
from common.mo_convert_test_class import CommonMOConvertTest
import unittest
import numpy as np
import openvino.runtime as ov
import pytest
from openvino.runtime import PartialShape, Model, Dimension
from common.mo_convert_test_class import CommonMOConvertTest
def create_tf_graph_def(tmp_dir):
import tensorflow as tf
@ -269,21 +271,14 @@ def create_tf_checkpoint(tmp_dir):
def create_tf_function(temp_dir):
import tensorflow as tf
input_names = ["Input1", "Input2"]
input_shape = [1, 2, 3]
x1 = tf.keras.Input(shape=input_shape, name=input_names[0])
x2 = tf.keras.Input(shape=input_shape, name=input_names[1])
y = tf.nn.sigmoid(tf.nn.relu(x1 + x2))
keras_net = tf.keras.Model(inputs=[x1, x2], outputs=[y])
@tf.function(
input_signature=[tf.TensorSpec(shape=[1, 2, 3], dtype=tf.float32),
tf.TensorSpec(shape=[1, 2, 3], dtype=tf.float32)])
def f(x):
return keras_net(x)
def f(x1, x2):
y = tf.nn.sigmoid(tf.nn.relu(x1 + x2))
return y
shape = PartialShape([-1, 1, 2, 3])
shape = PartialShape([1, 2, 3])
param1 = ov.opset8.parameter(shape, dtype=np.float32)
param2 = ov.opset8.parameter(shape, dtype=np.float32)
add = ov.opset8.add(param1, param2)
@ -293,7 +288,35 @@ def create_tf_function(temp_dir):
parameter_list = [param1, param2]
model_ref = Model([sigm], parameter_list, "test")
return keras_net, model_ref, None
return f, model_ref, None
def create_tf_graph(temp_dir):
import tensorflow as tf
tf.compat.v1.reset_default_graph()
with tf.compat.v1.Session() as sess:
inp1 = tf.compat.v1.placeholder(tf.float32, [1, 2, 3], 'Input')
inp2 = tf.compat.v1.placeholder(tf.float32, [1, 2, 3], 'Input')
relu = tf.nn.relu(inp1 + inp2, name='Relu')
output = tf.nn.sigmoid(relu, name='Sigmoid')
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph
shape = PartialShape([1, 2, 3])
param1 = ov.opset8.parameter(shape, dtype=np.float32)
param2 = ov.opset8.parameter(shape, dtype=np.float32)
add = ov.opset8.add(param1, param2)
relu = ov.opset8.relu(add)
sigm = ov.opset8.sigmoid(relu)
parameter_list = [param1, param2]
model_ref = Model([sigm], parameter_list, "test")
return tf_net, model_ref, None
def create_tf_saved_model_dir(temp_dir):
@ -322,29 +345,20 @@ def create_tf_saved_model_dir(temp_dir):
return temp_dir + "/model", model_ref
def create_tf_saved_model(temp_dir):
import tensorflow as tf
saved_model_dir, model_ref = create_tf_saved_model_dir(temp_dir)
saved_model = tf.saved_model.load(saved_model_dir)
return saved_model, model_ref, None
class TestMoConvertTF(CommonMOConvertTest):
test_data = [
# TF2
create_keras_model,
create_keras_layer,
#create_tf_function, # skip, ticket 106247
create_tf_function,
create_tf_module,
create_tf_checkpoint,
create_tf_saved_model,
create_keras_layer_dynamic,
create_tf_module_dynamic,
create_tf_module_layout_list,
# TF1
create_tf_graph,
create_tf_graph_def,
create_tf1_wrap_function,
create_tf_session,
@ -386,3 +400,20 @@ class TestMoConvertTF(CommonMOConvertTest):
test_params = {'input_model': saved_model_dir, 'use_new_frontend': False}
self._test_by_ref_graph(temp_dir, test_params, graph_ref, compare_tensor_names=False)
class TFConvertTest(unittest.TestCase):
@pytest.mark.nightly
@pytest.mark.precommit
def test_tf_function_no_signature(self):
import tensorflow as tf
from openvino.tools.mo import convert_model
@tf.function()
def function(x1, x2):
y = tf.nn.sigmoid(tf.nn.relu(x1 + x2))
return y
with self.assertRaisesRegex(AssertionError,
".*'input_signature' needs to be set for model conversion.*"):
convert_model(function)

View File

@ -103,6 +103,7 @@ def convert_model(
torch.jit.ScriptFunction
TF
tf.compat.v1.Graph
tf.compat.v1.GraphDef
tf.compat.v1.wrap_function
tf.compat.v1.session
@ -113,7 +114,6 @@ def convert_model(
tf.function
tf.Module
tf.train.checkpoint
tf.python.training.tracking.base.Trackable for case when it is output from tf.saved_model.load()
:param input:
Input can be set by passing a list of InputCutInfo objects or by a list

View File

@ -9,6 +9,7 @@ import platform
import sys
from collections import OrderedDict
from copy import deepcopy
from distutils.version import LooseVersion
from pathlib import Path
try:
@ -41,7 +42,7 @@ from openvino.tools.mo.utils.guess_framework import deduce_legacy_frontend_by_na
from openvino.tools.mo.utils.logger import init_logger, progress_printer
from openvino.tools.mo.utils.utils import refer_to_faq_msg
from openvino.tools.mo.utils.telemetry_utils import send_params_info, send_framework_info
from openvino.tools.mo.utils.versions_checker import check_requirements # pylint: disable=no-name-in-module
from openvino.tools.mo.utils.versions_checker import check_requirements, get_environment_setup # pylint: disable=no-name-in-module
from openvino.tools.mo.utils.telemetry_utils import get_tid
from openvino.tools.mo.moc_frontend.check_config import legacy_extensions_used
from openvino.tools.mo.moc_frontend.pytorch_frontend_utils import get_pytorch_decoder, convert_pytorch_via_onnx
@ -523,16 +524,22 @@ def check_model_object(argv):
model = argv['input_model']
if 'tensorflow' in sys.modules:
import tensorflow as tf
from tensorflow.python.training.tracking.base import Trackable
env_setup = get_environment_setup("tf")
if isinstance(model, tf.compat.v1.GraphDef):
return "tf"
if isinstance(model, tf.compat.v1.Graph):
argv['input_model'] = model.as_graph_def()
return "tf"
if isinstance(model, tf.compat.v1.Session):
argv['input_model'] = model.graph_def
return "tf"
if isinstance(model, tf.types.experimental.ConcreteFunction):
if env_setup["tensorflow"] >= LooseVersion("2.6.0") and isinstance(model, tf.types.experimental.ConcreteFunction):
argv['input_model'] = model.graph.as_graph_def()
return "tf"
if env_setup["tensorflow"] >= LooseVersion("2.6.0") and isinstance(model, tf.types.experimental.GenericFunction):
argv['input_model'] = model
return "tf"
if isinstance(model, tf.keras.Model):
return "tf"
if isinstance(model, tf.train.Checkpoint):
@ -558,8 +565,6 @@ def check_model_object(argv):
argv['input_model'] = tf.keras.Model(inputs, outputs)
argv['input_shape'] = None
return "tf"
if isinstance(model, Trackable):
return "tf"
if 'torch' in sys.modules:
import torch
if isinstance(model, torch.nn.Module) or isinstance(model, torch.jit.ScriptFunction):

View File

@ -200,14 +200,13 @@ def freeze_tf2_concrete_function(model, concrete_func, env_setup):
def prepare_graph_def(model):
from tensorflow.python.training.tracking.base import Trackable # pylint: disable=no-name-in-module,import-error
env_setup = get_environment_setup("tf")
if isinstance(model, tf_v1.GraphDef):
nodes_to_clear_device = model.node
for node in nodes_to_clear_device:
node.device = ""
return model, {}, "tf", None
if isinstance(model, tf.keras.Model):
env_setup = get_environment_setup("tf")
assert hasattr(model, "inputs") and model.inputs is not None, "Model inputs specification is required."
@ -226,9 +225,13 @@ def prepare_graph_def(model):
conc_func = tf_function.get_concrete_function(model_inputs)
return freeze_tf2_concrete_function(model, conc_func, env_setup)
if isinstance(model, Trackable):
env_setup = get_environment_setup("tf")
return saved_model_load(model, env_setup)
if env_setup["tensorflow"] >= LooseVersion("2.6.0") and isinstance(model, tf.types.experimental.GenericFunction):
assert hasattr(model, "input_signature") and model.input_signature is not None, \
"'input_signature' needs to be set for model conversion."
conc_func = model.get_concrete_function(*tuple(model.input_signature))
return freeze_tf2_concrete_function(model, conc_func, env_setup)
raise Exception("Unknown model type {}.".format(type(model)))