use input parameter for building example_inputs (#17207)
* use input parameter for building example_inputs * Update tools/mo/openvino/tools/mo/moc_frontend/pytorch_frontend_utils.py
This commit is contained in:
@@ -577,6 +577,13 @@ def create_pytorch_nn_module_shapes_list_static(tmp_dir):
|
||||
return pt_model, ref_model, {'input_shape': [[1, 3, 20, 20], [1, 3, 20, 20]], "input": [np.float32, np.float32]}
|
||||
|
||||
|
||||
def create_pytorch_nn_module_shapes_list_static_via_input(tmp_dir):
|
||||
pt_model = make_pt_model_two_inputs()
|
||||
ref_model = make_ref_pt_model_two_inputs([1, 3, 20, 20])
|
||||
|
||||
return pt_model, ref_model, {"input": [([1, 3, 20, 20], np.float32), ([1, 3, 20, 20], np.float32)]}
|
||||
|
||||
|
||||
def create_pytorch_nn_module_shapes_list_dynamic(tmp_dir):
|
||||
pt_model = make_pt_model_two_inputs()
|
||||
inp_shapes = [[Dimension(-1), 3, 20, Dimension(20, -1)],
|
||||
@@ -595,6 +602,24 @@ def create_pytorch_nn_module_shapes_list_dynamic(tmp_dir):
|
||||
return pt_model, ref_model, {'input_shape': inp_shapes, "input": [np.float32, np.float32]}
|
||||
|
||||
|
||||
def create_pytorch_nn_module_shapes_list_dynamic_via_input(tmp_dir):
|
||||
pt_model = make_pt_model_two_inputs()
|
||||
inp_shapes = [[Dimension(-1), 3, 20, Dimension(20, -1)],
|
||||
[-1, 3, 20, Dimension(-1, 20)]]
|
||||
|
||||
param1 = ov.opset8.parameter(PartialShape(
|
||||
inp_shapes[0]), name="x", dtype=np.float32)
|
||||
param2 = ov.opset8.parameter(PartialShape(
|
||||
inp_shapes[1]), name="y", dtype=np.float32)
|
||||
add = ov.opset8.add(param1, param2)
|
||||
relu = ov.opset8.relu(add)
|
||||
sigm = ov.opset8.sigmoid(relu)
|
||||
|
||||
parameter_list = [param1, param2]
|
||||
ref_model = Model([sigm], parameter_list, "test")
|
||||
return pt_model, ref_model, {"input": [(inp_shapes[0], np.float32), (inp_shapes[1], np.float32)]}
|
||||
|
||||
|
||||
def create_pytorch_nn_module_shapes_list_dynamic_single_input(tmp_dir):
|
||||
pt_model = make_pt_model_one_input()
|
||||
inp_shapes = [[Dimension(-1), 3, 20, Dimension(20, -1)]]
|
||||
@@ -602,6 +627,13 @@ def create_pytorch_nn_module_shapes_list_dynamic_single_input(tmp_dir):
|
||||
return pt_model, ref_model, {'input_shape': inp_shapes, "input": np.float32}
|
||||
|
||||
|
||||
def create_pytorch_nn_module_shapes_list_dynamic_single_input_via_input(tmp_dir):
|
||||
pt_model = make_pt_model_one_input()
|
||||
inp_shapes = [Dimension(-1), 3, 20, Dimension(20, -1)]
|
||||
ref_model = make_ref_pt_model_one_input(inp_shapes)
|
||||
return pt_model, ref_model, {"input": InputCutInfo(shape=inp_shapes, type=np.float32)}
|
||||
|
||||
|
||||
def create_pytorch_nn_module_shapes_list_static_single_input(tmp_dir):
|
||||
pt_model = make_pt_model_one_input()
|
||||
inp_shapes = [[1, 3, 20, 20]]
|
||||
@@ -609,6 +641,13 @@ def create_pytorch_nn_module_shapes_list_static_single_input(tmp_dir):
|
||||
return pt_model, ref_model, {'input_shape': inp_shapes, "input": np.float32}
|
||||
|
||||
|
||||
def create_pytorch_nn_module_shapes_list_static_single_input_via_input(tmp_dir):
|
||||
pt_model = make_pt_model_one_input()
|
||||
inp_shapes = [1, 3, 20, 20]
|
||||
ref_model = make_ref_pt_model_one_input(inp_shapes)
|
||||
return pt_model, ref_model, {"input": (inp_shapes, np.float32)}
|
||||
|
||||
|
||||
def create_pytorch_nn_module_convert_pytorch_frontend1(tmp_dir):
|
||||
pt_model = make_pt_model_one_input()
|
||||
shape = [-1, -1, -1, -1]
|
||||
@@ -750,9 +789,13 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
|
||||
create_pytorch_nn_module_scale_list_default_no_compression,
|
||||
create_pytorch_nn_module_scale_list_compression_enabled,
|
||||
create_pytorch_nn_module_shapes_list_static,
|
||||
create_pytorch_nn_module_shapes_list_static_via_input,
|
||||
create_pytorch_nn_module_shapes_list_dynamic,
|
||||
create_pytorch_nn_module_shapes_list_dynamic_via_input,
|
||||
create_pytorch_nn_module_shapes_list_dynamic_single_input,
|
||||
create_pytorch_nn_module_shapes_list_static_single_input,
|
||||
create_pytorch_nn_module_shapes_list_dynamic_single_input_via_input,
|
||||
create_pytorch_nn_module_shapes_list_static_single_input_via_input,
|
||||
create_pytorch_nn_module_convert_pytorch_frontend1,
|
||||
create_pytorch_nn_module_convert_pytorch_frontend2,
|
||||
create_pytorch_nn_module_convert_pytorch_frontend3,
|
||||
|
||||
Reference in New Issue
Block a user