Fixed names for GraphDef. (#21799)
This commit is contained in:
parent
80618b0498
commit
3ab5ee861d
|
@ -363,7 +363,7 @@ def extract_model_graph(argv):
|
|||
if isinstance(model, tf.compat.v1.GraphDef):
|
||||
graph = tf.Graph()
|
||||
with graph.as_default():
|
||||
tf.graph_util.import_graph_def(model)
|
||||
tf.graph_util.import_graph_def(model, name='')
|
||||
argv["input_model"] = graph
|
||||
return True
|
||||
if isinstance(model, tf.compat.v1.Session):
|
||||
|
|
|
@ -938,6 +938,29 @@ class TestMoConvertTF(CommonMOConvertTest):
|
|||
assert CommonLayerTest().compare_ie_results_with_framework(ov_infer, {"Identity:0": fw_infer}, eps)
|
||||
assert CommonLayerTest().compare_ie_results_with_framework(ov_infer, {"Identity:0": [-1.8, -4.4]}, eps)
|
||||
|
||||
|
||||
class TFGraphDefNames(unittest.TestCase):
|
||||
def test_graph_def_names(self):
|
||||
from openvino.tools.ovc import convert_model
|
||||
|
||||
tf_model, model_ref, _ = create_tf_graph_def(None)
|
||||
ov_model = convert_model(tf_model)
|
||||
|
||||
input_list = []
|
||||
with tf.Graph().as_default() as graph:
|
||||
tf.import_graph_def(tf_model, name='')
|
||||
for op in graph.get_operations():
|
||||
if op.type == "Placeholder":
|
||||
input_list.append(op.name)
|
||||
|
||||
for input in input_list:
|
||||
found = False
|
||||
for ov_input in ov_model.inputs:
|
||||
if input in ov_input.get_names():
|
||||
found = True
|
||||
assert found, "Could not found input {} in resulting model.".format(input)
|
||||
|
||||
|
||||
class TFConvertTest(unittest.TestCase):
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
|
|
Loading…
Reference in New Issue
Block a user