diff --git a/src/bindings/python/tests/test_frontend/test_frontend_onnx_editor.py b/src/bindings/python/tests/test_frontend/test_frontend_onnx_editor.py index fe02533d05e..2e9acaf9021 100644 --- a/src/bindings/python/tests/test_frontend/test_frontend_onnx_editor.py +++ b/src/bindings/python/tests/test_frontend/test_frontend_onnx_editor.py @@ -49,6 +49,16 @@ from openvino.frontend import FrontEndManager # | | # out1 out2 # +# +# ------Test input model 3------ +# in1 in2 +# | / \ +# +--------+ +------+ +# | Add | | Relu | +# +--------+ +------+ +# | | +# out1 out2 +# def create_test_onnx_models(): models = {} # Input model 1 @@ -91,6 +101,23 @@ def create_test_onnx_models(): models["input_model_2.onnx"] = make_model(graph, producer_name="ONNX Importer", opset_imports=[onnx.helper.make_opsetid("", 13)]) + # Input model 3 + add_2 = onnx.helper.make_node("Add", inputs=["in1", "in2"], outputs=["out1"], name="onnx_add_op") + relu_2 = onnx.helper.make_node("Relu", inputs=["in2"], outputs=["out2"]) + + input_tensors = [ + make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)), + make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)), + ] + output_tensors = [ + make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)), + make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)), + make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (2, 2)), + ] + graph = make_graph([add_2, relu_2], "test_graph_3", input_tensors, output_tensors) + models["input_model_3.onnx"] = make_model(graph, producer_name="ONNX Importer", + opset_imports=[onnx.helper.make_opsetid("", 13)]) + # Expected for extract_subgraph input_tensors = [ make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)), @@ -188,6 +215,19 @@ def create_test_onnx_models(): models["test_override_all_outputs_2.onnx"] = make_model(graph, producer_name="ONNX Importer", opset_imports=[onnx.helper.make_opsetid("", 13)]) + # Expected for test_override_all_outputs 3 + input_tensors = [ + make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)), + make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)), + ] + output_tensors = [ + make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)), + make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)), + ] + graph = make_graph([add_2], "test_graph_3", input_tensors, output_tensors) + models["test_override_all_outputs_3.onnx"] = make_model(graph, producer_name="ONNX Importer", + opset_imports=[onnx.helper.make_opsetid("", 13)]) + # Expected for test_override_all_inputs input_tensors = [ make_tensor_value_info("in3", onnx.TensorProto.FLOAT, (2, 2)), @@ -594,6 +634,50 @@ def test_override_all_outputs_2(): assert res +def test_override_all_outputs_3(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + assert fe + + model = fe.load("input_model_3.onnx") + assert model + + place1 = model.get_place_by_tensor_name(tensor_name="out1") + place2 = model.get_place_by_tensor_name(tensor_name="out1") + model.override_all_outputs(outputs=[place1, place2]) + result_func = fe.convert(model) + + expected_model = fe.load("test_override_all_outputs_3.onnx") + expected_func = fe.convert(expected_model) + + res = compare_functions(result_func, expected_func) + assert res + + +def test_override_all_outputs_invalid_place(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + assert fe + + model = fe.load("input_model_3.onnx") + assert model + + model2 = fe.load("input_model.onnx") + assert model2 + invalid_place = model2.get_place_by_tensor_name(tensor_name="out3") + + place1 = model.get_place_by_tensor_name(tensor_name="out1") + place2 = model.get_place_by_tensor_name(tensor_name="out1") + model.override_all_outputs(outputs=[place1, place2, invalid_place]) + result_func = fe.convert(model) + + expected_model = fe.load("test_override_all_outputs_3.onnx") + expected_func = fe.convert(expected_model) + + res = compare_functions(result_func, expected_func) + assert res + + def test_override_all_inputs(): skip_if_onnx_frontend_is_disabled() fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) @@ -618,26 +702,31 @@ def test_override_all_inputs(): assert res -def test_override_all_inputs_exceptions(): +def test_override_all_inputs_invalid_place(): skip_if_onnx_frontend_is_disabled() fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) assert fe - model = fe.load("input_model.onnx") + model = fe.load("input_model_3.onnx") assert model - place1 = model.get_place_by_tensor_name(tensor_name="in1") - place2 = model.get_place_by_tensor_name(tensor_name="in2") - place3 = model.get_place_by_operation_name_and_input_port(operation_name="split1", input_port_index=0) - place4 = model.get_place_by_tensor_name(tensor_name="in3") + model2 = fe.load("input_model.onnx") + assert model2 - with pytest.raises(Exception) as e: - model.override_all_inputs(inputs=[place1, place2]) - assert "Unexpected number of inputs after override_all_inputs" in str(e) + out3_tensor = model2.get_place_by_tensor_name(tensor_name="out3") + invalid_place = out3_tensor.get_producing_operation().get_input_port(input_port_index=0) - with pytest.raises(Exception) as e: - model.override_all_inputs(inputs=[place3, place4]) - assert "Unexpected number of inputs after override_all_inputs" in str(e) + out1_tensor = model.get_place_by_tensor_name(tensor_name="out1") + place1 = out1_tensor.get_producing_operation().get_input_port(input_port_index=0) + place2 = out1_tensor.get_producing_operation().get_input_port(input_port_index=1) + model.override_all_inputs(inputs=[place1, place2, invalid_place]) + result_func = fe.convert(model) + + expected_model = fe.load("input_model_3.onnx") + expected_func = fe.convert(expected_model) + + res = compare_functions(result_func, expected_func) + assert res def test_is_input_output(): diff --git a/src/frontends/onnx/frontend/src/input_model.cpp b/src/frontends/onnx/frontend/src/input_model.cpp index d712b6c3f8f..81cef6f947f 100644 --- a/src/frontends/onnx/frontend/src/input_model.cpp +++ b/src/frontends/onnx/frontend/src/input_model.cpp @@ -7,6 +7,7 @@ #include #include +#include "ngraph/log.hpp" #include "place.hpp" using namespace ov; @@ -202,28 +203,90 @@ std::shared_ptr InputModel::convert() { } // Editor features +bool InputModel::is_correct_place(const ov::frontend::Place::Ptr& place) const { + if (const auto tensor = std::dynamic_pointer_cast(place)) { + return m_editor->is_correct_tensor_name(tensor->get_names()[0]); + } + if (const auto op = std::dynamic_pointer_cast(place)) { + return m_editor->is_correct_and_unambiguous_node(op->get_editor_node()); + } + if (const auto input_edge = std::dynamic_pointer_cast(place)) { + if (auto tensor = std::dynamic_pointer_cast(input_edge->get_source_tensor())) { + return m_editor->is_correct_tensor_name(tensor->get_names()[0]); + } + } + if (const auto output_edge = std::dynamic_pointer_cast(place)) { + if (auto tensor = std::dynamic_pointer_cast(output_edge->get_target_tensor())) { + return m_editor->is_correct_tensor_name(tensor->get_names()[0]); + } + } + return false; +} + void InputModel::override_all_outputs(const std::vector& outputs) { - extract_subgraph({}, outputs); - NGRAPH_CHECK(m_editor->model_outputs().size() == outputs.size(), - "Unexpected number of outputs after override_all_outputs"); - NGRAPH_CHECK(std::all_of(std::begin(outputs), - std::end(outputs), + std::vector expected_valid_outputs; + for (const auto& output : outputs) { + bool is_correct = is_correct_place(output); + if (!is_correct) + NGRAPH_WARN << "Name " << output->get_names().at(0) + << " of output node is not a correct node name. Ignoring this parameter."; + else + expected_valid_outputs.push_back(output); + } + + extract_subgraph({}, expected_valid_outputs); + + NGRAPH_CHECK(std::all_of(std::begin(expected_valid_outputs), + std::end(expected_valid_outputs), [](const ov::frontend::Place::Ptr& place) { return place->is_output(); }), "Not all provided arguments of override_all_outputs are new outputs of the model"); + + const auto current_outputs = get_outputs(); + NGRAPH_CHECK(std::all_of(std::begin(current_outputs), + std::end(current_outputs), + [&](const Place::Ptr& current_out) { + return std::find_if(std::begin(expected_valid_outputs), + std::end(expected_valid_outputs), + [&](const Place::Ptr& expected_out) { + return expected_out->is_equal(current_out); + }) != std::end(current_outputs); + }), + "Some other than expected outputs were created during override_all_outputs"); } void InputModel::override_all_inputs(const std::vector& inputs) { + std::vector expected_valid_inputs; + for (const auto& input : inputs) { + bool is_correct = is_correct_place(input); + if (!is_correct) + NGRAPH_WARN << "Name " << input->get_names().at(0) + << " of input node is not a correct node. Ignoring this parameter."; + else + expected_valid_inputs.push_back(input); + } + const auto outputs_before_extraction = m_editor->model_outputs(); - extract_subgraph({inputs}, {}); + extract_subgraph({expected_valid_inputs}, {}); + NGRAPH_CHECK(std::equal(std::begin(outputs_before_extraction), std::end(outputs_before_extraction), std::begin(m_editor->model_outputs())), "All outputs should be preserved after override_all_inputs. Provided inputs does " "not satisfy all outputs"); - NGRAPH_CHECK(m_editor->model_inputs().size() == inputs.size(), - "Unexpected number of inputs after override_all_inputs"); + + const auto current_inputs = get_inputs(); + NGRAPH_CHECK(std::all_of(std::begin(current_inputs), + std::end(current_inputs), + [&](const Place::Ptr& current_in) { + return std::find_if(std::begin(expected_valid_inputs), + std::end(expected_valid_inputs), + [&](const Place::Ptr& expected_in) { + return expected_in->is_equal(current_in); + }) != std::end(current_inputs); + }), + "Some other than expected inputs were created during override_all_inputs"); } void InputModel::extract_subgraph(const std::vector& inputs, diff --git a/src/frontends/onnx/frontend/src/input_model.hpp b/src/frontends/onnx/frontend/src/input_model.hpp index 863a253f02e..e8fe258259a 100644 --- a/src/frontends/onnx/frontend/src/input_model.hpp +++ b/src/frontends/onnx/frontend/src/input_model.hpp @@ -78,6 +78,7 @@ public: private: std::shared_ptr m_editor; + bool is_correct_place(const ov::frontend::Place::Ptr& place) const; std::unordered_map> m_additional_tensor_names; void add_tensor_names(std::shared_ptr& model);