Fixed 'output_model' logic in OVC. (#19171)

* Fixed output_model logic.

* Removed not needed code.

* Used os.path.basename, added comments.

* Removed loop.

* Test corrections.
This commit is contained in:
Anastasiia Pnevskaia 2023-08-17 07:47:18 +02:00 committed by GitHub
parent 3e23908983
commit 7a566313e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 72 additions and 10 deletions

View File

@ -174,8 +174,8 @@ def transpose_nhwc_to_nchw(data, use_new_frontend, use_old_api):
return data
def save_to_pb(tf_model, path_to_saved_tf_model):
tf.io.write_graph(tf_model, path_to_saved_tf_model, 'model.pb', False)
assert os.path.isfile(os.path.join(path_to_saved_tf_model, 'model.pb')), "model.pb haven't been saved " \
def save_to_pb(tf_model, path_to_saved_tf_model, model_name = 'model.pb'):
tf.io.write_graph(tf_model, path_to_saved_tf_model, model_name, False)
assert os.path.isfile(os.path.join(path_to_saved_tf_model, model_name)), "model.pb haven't been saved " \
"here: {}".format(path_to_saved_tf_model)
return os.path.join(path_to_saved_tf_model, 'model.pb')
return os.path.join(path_to_saved_tf_model, model_name)

View File

@ -72,7 +72,32 @@ class TestOVCTool(CommonMOConvertTest):
tf_net = sess.graph_def
# save model to .pb and return path to the model
return save_to_pb(tf_net, tmp_dir)
return save_to_pb(tf_net, tmp_dir, 'model2.pb')
def create_tf_saved_model_dir(self, temp_dir):
import tensorflow as tf
input_names = ["Input1", "Input2"]
input_shape = [1, 2, 3]
x1 = tf.keras.Input(shape=input_shape, name=input_names[0])
x2 = tf.keras.Input(shape=input_shape, name=input_names[1])
y = tf.nn.sigmoid(tf.nn.relu(x1 + x2))
keras_net = tf.keras.Model(inputs=[x1, x2], outputs=[y])
tf.saved_model.save(keras_net, temp_dir + "/test_model")
shape = PartialShape([-1, 1, 2, 3])
param1 = ov.opset8.parameter(shape, name="Input1:0", dtype=np.float32)
param2 = ov.opset8.parameter(shape, name="Input2:0", dtype=np.float32)
add = ov.opset8.add(param1, param2)
relu = ov.opset8.relu(add)
sigm = ov.opset8.sigmoid(relu)
parameter_list = [param1, param2]
model_ref = Model([sigm], parameter_list, "test")
return temp_dir + "/test_model", model_ref
def test_ovc_tool(self, ie_device, precision, ir_version, temp_dir, use_new_frontend, use_old_api):
@ -83,10 +108,10 @@ class TestOVCTool(CommonMOConvertTest):
core = Core()
# tests for MO cli tool
exit_code, stderr = generate_ir_ovc(coverage=False, **{"input_model": model_path, "output_model": temp_dir + os.sep + "model"})
exit_code, stderr = generate_ir_ovc(coverage=False, **{"input_model": model_path, "output_model": temp_dir + os.sep + "model1"})
assert not exit_code
ov_model = core.read_model(os.path.join(temp_dir, "model.xml"))
ov_model = core.read_model(os.path.join(temp_dir, "model1.xml"))
flag, msg = compare_functions(ov_model, create_ref_graph(), False)
assert flag, msg
@ -101,6 +126,32 @@ class TestOVCTool(CommonMOConvertTest):
exit_code, stderr = generate_ir_ovc(coverage=False, **{"input_model": model_path, "output_model": temp_dir})
assert not exit_code
ov_model = core.read_model(os.path.join(temp_dir, "model.xml"))
ov_model = core.read_model(os.path.join(temp_dir, "model2.xml"))
flag, msg = compare_functions(ov_model, create_ref_graph(), False)
assert flag, msg
assert flag, msg
def test_ovc_tool_saved_model_dir(self, ie_device, precision, ir_version, temp_dir, use_new_frontend, use_old_api):
from openvino.runtime import Core
core = Core()
model_dir, ref_model = self.create_tf_saved_model_dir(temp_dir)
exit_code, stderr = generate_ir_ovc(coverage=False, **{"input_model": model_dir, "output_model": temp_dir})
assert not exit_code
ov_model = core.read_model(os.path.join(temp_dir, "test_model.xml"))
flag, msg = compare_functions(ov_model, ref_model, False)
assert flag, msg
def test_ovc_tool_saved_model_dir_with_sep_at_path_end(self, ie_device, precision, ir_version, temp_dir, use_new_frontend, use_old_api):
from openvino.runtime import Core
core = Core()
model_dir, ref_model = self.create_tf_saved_model_dir(temp_dir)
exit_code, stderr = generate_ir_ovc(coverage=False, **{"input_model": model_dir + os.sep, "output_model": temp_dir})
assert not exit_code
ov_model = core.read_model(os.path.join(temp_dir, "test_model.xml"))
flag, msg = compare_functions(ov_model, ref_model, False)
assert flag, msg

View File

@ -693,7 +693,18 @@ def get_model_name_from_args(argv: argparse.Namespace):
if not isinstance(input_model, (str, pathlib.Path)):
return output_dir
input_model_name = os.path.splitext(os.path.split(input_model)[1])[0]
input_model_name = os.path.basename(input_model)
if input_model_name == '':
input_model_name = os.path.basename(os.path.dirname(input_model))
# remove extension if exists
input_model_name = os.path.splitext(input_model_name)[0]
# if no valid name exists in input path set name to 'model'
if input_model_name == '' or input_model_name == '.':
input_model_name = "model"
# add .xml extension
return os.path.join(output_dir, input_model_name + ".xml")