Fixed 'output_model' logic in OVC. (#19171)
* Fixed output_model logic. * Removed not needed code. * Used os.path.basename, added comments. * Removed loop. * Test corrections.
This commit is contained in:
parent
3e23908983
commit
7a566313e5
@ -174,8 +174,8 @@ def transpose_nhwc_to_nchw(data, use_new_frontend, use_old_api):
|
||||
return data
|
||||
|
||||
|
||||
def save_to_pb(tf_model, path_to_saved_tf_model):
|
||||
tf.io.write_graph(tf_model, path_to_saved_tf_model, 'model.pb', False)
|
||||
assert os.path.isfile(os.path.join(path_to_saved_tf_model, 'model.pb')), "model.pb haven't been saved " \
|
||||
def save_to_pb(tf_model, path_to_saved_tf_model, model_name = 'model.pb'):
|
||||
tf.io.write_graph(tf_model, path_to_saved_tf_model, model_name, False)
|
||||
assert os.path.isfile(os.path.join(path_to_saved_tf_model, model_name)), "model.pb haven't been saved " \
|
||||
"here: {}".format(path_to_saved_tf_model)
|
||||
return os.path.join(path_to_saved_tf_model, 'model.pb')
|
||||
return os.path.join(path_to_saved_tf_model, model_name)
|
||||
|
@ -72,7 +72,32 @@ class TestOVCTool(CommonMOConvertTest):
|
||||
tf_net = sess.graph_def
|
||||
|
||||
# save model to .pb and return path to the model
|
||||
return save_to_pb(tf_net, tmp_dir)
|
||||
return save_to_pb(tf_net, tmp_dir, 'model2.pb')
|
||||
|
||||
def create_tf_saved_model_dir(self, temp_dir):
|
||||
import tensorflow as tf
|
||||
|
||||
input_names = ["Input1", "Input2"]
|
||||
input_shape = [1, 2, 3]
|
||||
|
||||
x1 = tf.keras.Input(shape=input_shape, name=input_names[0])
|
||||
x2 = tf.keras.Input(shape=input_shape, name=input_names[1])
|
||||
y = tf.nn.sigmoid(tf.nn.relu(x1 + x2))
|
||||
keras_net = tf.keras.Model(inputs=[x1, x2], outputs=[y])
|
||||
|
||||
tf.saved_model.save(keras_net, temp_dir + "/test_model")
|
||||
|
||||
shape = PartialShape([-1, 1, 2, 3])
|
||||
param1 = ov.opset8.parameter(shape, name="Input1:0", dtype=np.float32)
|
||||
param2 = ov.opset8.parameter(shape, name="Input2:0", dtype=np.float32)
|
||||
add = ov.opset8.add(param1, param2)
|
||||
relu = ov.opset8.relu(add)
|
||||
sigm = ov.opset8.sigmoid(relu)
|
||||
|
||||
parameter_list = [param1, param2]
|
||||
model_ref = Model([sigm], parameter_list, "test")
|
||||
|
||||
return temp_dir + "/test_model", model_ref
|
||||
|
||||
|
||||
def test_ovc_tool(self, ie_device, precision, ir_version, temp_dir, use_new_frontend, use_old_api):
|
||||
@ -83,10 +108,10 @@ class TestOVCTool(CommonMOConvertTest):
|
||||
core = Core()
|
||||
|
||||
# tests for MO cli tool
|
||||
exit_code, stderr = generate_ir_ovc(coverage=False, **{"input_model": model_path, "output_model": temp_dir + os.sep + "model"})
|
||||
exit_code, stderr = generate_ir_ovc(coverage=False, **{"input_model": model_path, "output_model": temp_dir + os.sep + "model1"})
|
||||
assert not exit_code
|
||||
|
||||
ov_model = core.read_model(os.path.join(temp_dir, "model.xml"))
|
||||
ov_model = core.read_model(os.path.join(temp_dir, "model1.xml"))
|
||||
flag, msg = compare_functions(ov_model, create_ref_graph(), False)
|
||||
assert flag, msg
|
||||
|
||||
@ -101,6 +126,32 @@ class TestOVCTool(CommonMOConvertTest):
|
||||
exit_code, stderr = generate_ir_ovc(coverage=False, **{"input_model": model_path, "output_model": temp_dir})
|
||||
assert not exit_code
|
||||
|
||||
ov_model = core.read_model(os.path.join(temp_dir, "model.xml"))
|
||||
ov_model = core.read_model(os.path.join(temp_dir, "model2.xml"))
|
||||
flag, msg = compare_functions(ov_model, create_ref_graph(), False)
|
||||
assert flag, msg
|
||||
assert flag, msg
|
||||
|
||||
def test_ovc_tool_saved_model_dir(self, ie_device, precision, ir_version, temp_dir, use_new_frontend, use_old_api):
|
||||
from openvino.runtime import Core
|
||||
core = Core()
|
||||
|
||||
model_dir, ref_model = self.create_tf_saved_model_dir(temp_dir)
|
||||
|
||||
exit_code, stderr = generate_ir_ovc(coverage=False, **{"input_model": model_dir, "output_model": temp_dir})
|
||||
assert not exit_code
|
||||
|
||||
ov_model = core.read_model(os.path.join(temp_dir, "test_model.xml"))
|
||||
flag, msg = compare_functions(ov_model, ref_model, False)
|
||||
assert flag, msg
|
||||
|
||||
def test_ovc_tool_saved_model_dir_with_sep_at_path_end(self, ie_device, precision, ir_version, temp_dir, use_new_frontend, use_old_api):
|
||||
from openvino.runtime import Core
|
||||
core = Core()
|
||||
|
||||
model_dir, ref_model = self.create_tf_saved_model_dir(temp_dir)
|
||||
|
||||
exit_code, stderr = generate_ir_ovc(coverage=False, **{"input_model": model_dir + os.sep, "output_model": temp_dir})
|
||||
assert not exit_code
|
||||
|
||||
ov_model = core.read_model(os.path.join(temp_dir, "test_model.xml"))
|
||||
flag, msg = compare_functions(ov_model, ref_model, False)
|
||||
assert flag, msg
|
||||
|
@ -693,7 +693,18 @@ def get_model_name_from_args(argv: argparse.Namespace):
|
||||
if not isinstance(input_model, (str, pathlib.Path)):
|
||||
return output_dir
|
||||
|
||||
input_model_name = os.path.splitext(os.path.split(input_model)[1])[0]
|
||||
input_model_name = os.path.basename(input_model)
|
||||
if input_model_name == '':
|
||||
input_model_name = os.path.basename(os.path.dirname(input_model))
|
||||
|
||||
# remove extension if exists
|
||||
input_model_name = os.path.splitext(input_model_name)[0]
|
||||
|
||||
# if no valid name exists in input path set name to 'model'
|
||||
if input_model_name == '' or input_model_name == '.':
|
||||
input_model_name = "model"
|
||||
|
||||
# add .xml extension
|
||||
return os.path.join(output_dir, input_model_name + ".xml")
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user