Fixed names for GraphDef. (#21799)

This commit is contained in:
Anastasiia Pnevskaia 2023-12-21 14:04:05 +01:00 committed by GitHub
parent 80618b0498
commit 3ab5ee861d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 1 deletions

View File

@ -363,7 +363,7 @@ def extract_model_graph(argv):
if isinstance(model, tf.compat.v1.GraphDef):
graph = tf.Graph()
with graph.as_default():
tf.graph_util.import_graph_def(model)
tf.graph_util.import_graph_def(model, name='')
argv["input_model"] = graph
return True
if isinstance(model, tf.compat.v1.Session):

View File

@ -938,6 +938,29 @@ class TestMoConvertTF(CommonMOConvertTest):
assert CommonLayerTest().compare_ie_results_with_framework(ov_infer, {"Identity:0": fw_infer}, eps)
assert CommonLayerTest().compare_ie_results_with_framework(ov_infer, {"Identity:0": [-1.8, -4.4]}, eps)
class TFGraphDefNames(unittest.TestCase):
def test_graph_def_names(self):
from openvino.tools.ovc import convert_model
tf_model, model_ref, _ = create_tf_graph_def(None)
ov_model = convert_model(tf_model)
input_list = []
with tf.Graph().as_default() as graph:
tf.import_graph_def(tf_model, name='')
for op in graph.get_operations():
if op.type == "Placeholder":
input_list.append(op.name)
for input in input_list:
found = False
for ov_input in ov_model.inputs:
if input in ov_input.get_names():
found = True
assert found, "Could not found input {} in resulting model.".format(input)
class TFConvertTest(unittest.TestCase):
@pytest.mark.nightly
@pytest.mark.precommit