diff --git a/ngraph/frontend/onnx/frontend/src/input_model.cpp b/ngraph/frontend/onnx/frontend/src/input_model.cpp index 5dc2ff54c76..ac136915bcf 100644 --- a/ngraph/frontend/onnx/frontend/src/input_model.cpp +++ b/ngraph/frontend/onnx/frontend/src/input_model.cpp @@ -147,16 +147,42 @@ void InputModelONNX::extract_subgraph(const std::vector& inputs, con [](const onnx_editor::InputEdge& edge) { return edge; }); + } else if (const auto op = std::dynamic_pointer_cast(input)) { + const auto editor_node = op->get_editor_node(); + const auto op_inputs = m_editor->get_input_ports(editor_node); + int node_idx = m_editor->get_node_index(editor_node); + int port_idx = 0; + std::transform(std::begin(op_inputs), + std::end(op_inputs), + std::back_inserter(onnx_inputs), + [&node_idx, &port_idx](const std::string&) { + return onnx_editor::InputEdge{node_idx, port_idx++}; + }); } } std::vector onnx_outputs; onnx_outputs.reserve(outputs.size()); for (const auto& output : outputs) { - const auto output_port = output->get_producing_port(); - const auto onnx_output_edge = std::dynamic_pointer_cast(output_port); - NGRAPH_CHECK(onnx_output_edge, "Non-onnx output place was passed as extraction subgraph argument"); - onnx_outputs.push_back(onnx_output_edge->get_output_edge()); + if (const auto output_port = std::dynamic_pointer_cast(output)) { + onnx_outputs.push_back(output_port->get_output_edge()); + } else if (const auto tensor = std::dynamic_pointer_cast(output)) { + const auto output_port = tensor->get_producing_port(); + const auto onnx_output_edge = std::dynamic_pointer_cast(output_port); + NGRAPH_CHECK(onnx_output_edge, "Non-onnx output place was passed as extraction subgraph argument"); + onnx_outputs.push_back(onnx_output_edge->get_output_edge()); + } else if (const auto op = std::dynamic_pointer_cast(output)) { + const auto editor_node = op->get_editor_node(); + const auto op_outputs = m_editor->get_output_ports(editor_node); + int node_idx = m_editor->get_node_index(editor_node); + int port_idx = 0; + std::transform(std::begin(op_outputs), + std::end(op_outputs), + std::back_inserter(onnx_outputs), + [&node_idx, &port_idx](const std::string&) { + return onnx_editor::OutputEdge{node_idx, port_idx++}; + }); + } } m_editor->cut_graph_fragment(onnx_inputs, onnx_outputs); } diff --git a/runtime/bindings/python/tests/test_frontend/test_frontend_onnx_editor.py b/runtime/bindings/python/tests/test_frontend/test_frontend_onnx_editor.py index 0c9b7837a6e..cd8bdabd3e5 100644 --- a/runtime/bindings/python/tests/test_frontend/test_frontend_onnx_editor.py +++ b/runtime/bindings/python/tests/test_frontend/test_frontend_onnx_editor.py @@ -132,9 +132,9 @@ def create_test_onnx_models(): # Expected for extract_subgraph 4 input_tensors = [ + make_tensor_value_info("out1/placeholder_port_0", onnx.TensorProto.FLOAT, (2, 2)), make_tensor_value_info("out4/placeholder_port_0", onnx.TensorProto.FLOAT, (2, 2)), make_tensor_value_info("out4/placeholder_port_1", onnx.TensorProto.FLOAT, (2, 2)), - make_tensor_value_info("out1/placeholder_port_0", onnx.TensorProto.FLOAT, (2, 2)), ] output_tensors = [ make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (1, 2)), @@ -149,6 +149,18 @@ def create_test_onnx_models(): models["extract_subgraph_4.onnx"] = make_model(graph, producer_name="ONNX Importer", opset_imports=[onnx.helper.make_opsetid("", 13)]) + # Expected for extract_subgraph 5 + input_tensors = [ + make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)), + make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)), + ] + output_tensors = [ + make_tensor_value_info("add_out", onnx.TensorProto.FLOAT, (2, 2)), + ] + graph = make_graph([add], "test_graph", input_tensors, output_tensors) + models["extract_subgraph_5.onnx"] = make_model(graph, producer_name="ONNX Importer", + opset_imports=[onnx.helper.make_opsetid("", 13)]) + # Expected for test_override_all_outputs input_tensors = [ make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)), @@ -355,9 +367,9 @@ def test_extract_subgraph_4(): assert model out4_tensor = model.get_place_by_tensor_name(tensorName="out4") - place1 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0) - place2 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1) - place3 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0) + place1 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0) + place2 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0) + place3 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1) place4 = model.get_place_by_tensor_name(tensorName="out1") place5 = model.get_place_by_tensor_name(tensorName="out2") place6 = model.get_place_by_tensor_name(tensorName="out4") @@ -371,6 +383,99 @@ def test_extract_subgraph_4(): assert res +def test_extract_subgraph_by_op_place_as_input(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + assert fe + + model = fe.load("input_model.onnx") + assert model + + split_op = model.get_place_by_operation_name(operationName="split1") + out4 = model.get_place_by_tensor_name(tensorName="out4") + mul_op = out4.get_producing_operation() + out1 = model.get_place_by_tensor_name(tensorName="out1") + out2 = model.get_place_by_tensor_name(tensorName="out2") + + model.extract_subgraph(inputs=[split_op, mul_op], outputs=[out1, out2, out4]) + result_func = fe.convert(model) + + expected_model = fe.load("extract_subgraph_4.onnx") + expected_func = fe.convert(expected_model) + + res = compare_functions(result_func, expected_func) + assert res + + +def test_extract_subgraph_by_op_place_as_output(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + assert fe + + model = fe.load("input_model.onnx") + assert model + + in1_tensor = model.get_place_by_tensor_name(tensorName="in1") + in2_tensor = model.get_place_by_tensor_name(tensorName="in2") + add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out") + add_op = add_out_tensor.get_producing_operation() + + model.extract_subgraph(inputs=[in1_tensor, in2_tensor], outputs=[add_op]) + result_func = fe.convert(model) + + expected_model = fe.load("extract_subgraph_5.onnx") + expected_func = fe.convert(expected_model) + + res = compare_functions(result_func, expected_func) + assert res + + +def test_extract_subgraph_by_op_place_as_output_2(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + assert fe + + model = fe.load("input_model.onnx") + assert model + + split_op = model.get_place_by_operation_name(operationName="split1") + out4 = model.get_place_by_tensor_name(tensorName="out4") + mul_op = out4.get_producing_operation() + + model.extract_subgraph(inputs=[split_op, mul_op], outputs=[]) + result_func = fe.convert(model) + + expected_model = fe.load("test_override_all_inputs.onnx") + expected_func = fe.convert(expected_model) + + res = compare_functions(result_func, expected_func) + assert res + + +def test_extract_subgraph_by_port_place_as_output(): + skip_if_onnx_frontend_is_disabled() + fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME) + assert fe + + model = fe.load("input_model.onnx") + assert model + + add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out") + add_op = add_out_tensor.get_producing_operation() + add_op_out_port = add_op.get_output_port(outputPortIndex=0) + in1_tensor = model.get_place_by_tensor_name(tensorName="in1") + in2_tensor = model.get_place_by_tensor_name(tensorName="in2") + + model.extract_subgraph(inputs=[in1_tensor, in2_tensor], outputs=[add_op_out_port]) + result_func = fe.convert(model) + + expected_model = fe.load("extract_subgraph.onnx") + expected_func = fe.convert(expected_model) + + res = compare_functions(result_func, expected_func) + assert res + + def test_override_all_outputs(): skip_if_onnx_frontend_is_disabled() fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)