Handle PlaceOpONNX and PlaceOutputEdgeONNX by extract_subgraph (#7908)

This commit is contained in:
Mateusz Bencer 2021-10-12 11:54:41 +02:00 committed by GitHub
parent a10f40d6d4
commit 29eaa0af60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 139 additions and 8 deletions

View File

@ -147,16 +147,42 @@ void InputModelONNX::extract_subgraph(const std::vector<Place::Ptr>& inputs, con
[](const onnx_editor::InputEdge& edge) {
return edge;
});
} else if (const auto op = std::dynamic_pointer_cast<PlaceOpONNX>(input)) {
const auto editor_node = op->get_editor_node();
const auto op_inputs = m_editor->get_input_ports(editor_node);
int node_idx = m_editor->get_node_index(editor_node);
int port_idx = 0;
std::transform(std::begin(op_inputs),
std::end(op_inputs),
std::back_inserter(onnx_inputs),
[&node_idx, &port_idx](const std::string&) {
return onnx_editor::InputEdge{node_idx, port_idx++};
});
}
}
std::vector<onnx_editor::OutputEdge> onnx_outputs;
onnx_outputs.reserve(outputs.size());
for (const auto& output : outputs) {
const auto output_port = output->get_producing_port();
const auto onnx_output_edge = std::dynamic_pointer_cast<PlaceOutputEdgeONNX>(output_port);
NGRAPH_CHECK(onnx_output_edge, "Non-onnx output place was passed as extraction subgraph argument");
onnx_outputs.push_back(onnx_output_edge->get_output_edge());
if (const auto output_port = std::dynamic_pointer_cast<PlaceOutputEdgeONNX>(output)) {
onnx_outputs.push_back(output_port->get_output_edge());
} else if (const auto tensor = std::dynamic_pointer_cast<PlaceTensorONNX>(output)) {
const auto output_port = tensor->get_producing_port();
const auto onnx_output_edge = std::dynamic_pointer_cast<PlaceOutputEdgeONNX>(output_port);
NGRAPH_CHECK(onnx_output_edge, "Non-onnx output place was passed as extraction subgraph argument");
onnx_outputs.push_back(onnx_output_edge->get_output_edge());
} else if (const auto op = std::dynamic_pointer_cast<PlaceOpONNX>(output)) {
const auto editor_node = op->get_editor_node();
const auto op_outputs = m_editor->get_output_ports(editor_node);
int node_idx = m_editor->get_node_index(editor_node);
int port_idx = 0;
std::transform(std::begin(op_outputs),
std::end(op_outputs),
std::back_inserter(onnx_outputs),
[&node_idx, &port_idx](const std::string&) {
return onnx_editor::OutputEdge{node_idx, port_idx++};
});
}
}
m_editor->cut_graph_fragment(onnx_inputs, onnx_outputs);
}

View File

@ -132,9 +132,9 @@ def create_test_onnx_models():
# Expected for extract_subgraph 4
input_tensors = [
make_tensor_value_info("out1/placeholder_port_0", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("out4/placeholder_port_0", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("out4/placeholder_port_1", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("out1/placeholder_port_0", onnx.TensorProto.FLOAT, (2, 2)),
]
output_tensors = [
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (1, 2)),
@ -149,6 +149,18 @@ def create_test_onnx_models():
models["extract_subgraph_4.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)])
# Expected for extract_subgraph 5
input_tensors = [
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
]
output_tensors = [
make_tensor_value_info("add_out", onnx.TensorProto.FLOAT, (2, 2)),
]
graph = make_graph([add], "test_graph", input_tensors, output_tensors)
models["extract_subgraph_5.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)])
# Expected for test_override_all_outputs
input_tensors = [
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
@ -355,9 +367,9 @@ def test_extract_subgraph_4():
assert model
out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
place1 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
place2 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
place3 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
place1 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
place2 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
place3 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
place4 = model.get_place_by_tensor_name(tensorName="out1")
place5 = model.get_place_by_tensor_name(tensorName="out2")
place6 = model.get_place_by_tensor_name(tensorName="out4")
@ -371,6 +383,99 @@ def test_extract_subgraph_4():
assert res
def test_extract_subgraph_by_op_place_as_input():
    """Cut a subgraph using operation places (rather than edge/tensor places) as inputs.

    The extracted model must convert to the same function as the reference
    ``extract_subgraph_4.onnx`` model.
    """
    skip_if_onnx_frontend_is_disabled()
    frontend = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
    assert frontend
    model = frontend.load("input_model.onnx")
    assert model
    # Input places: the "split1" op and the producer of the "out4" tensor.
    split_op = model.get_place_by_operation_name(operationName="split1")
    out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
    mul_op = out4_tensor.get_producing_operation()
    # Output places: the three result tensors kept by the subgraph.
    out1_tensor = model.get_place_by_tensor_name(tensorName="out1")
    out2_tensor = model.get_place_by_tensor_name(tensorName="out2")
    model.extract_subgraph(inputs=[split_op, mul_op], outputs=[out1_tensor, out2_tensor, out4_tensor])
    converted = frontend.convert(model)
    reference = frontend.convert(frontend.load("extract_subgraph_4.onnx"))
    matched = compare_functions(converted, reference)
    assert matched
def test_extract_subgraph_by_op_place_as_output():
    """Cut a subgraph using an operation place as the output.

    Passing the producing op of "add_out" as the single output must yield the
    same function as the reference ``extract_subgraph_5.onnx`` model.
    """
    skip_if_onnx_frontend_is_disabled()
    frontend = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
    assert frontend
    model = frontend.load("input_model.onnx")
    assert model
    # Input places are plain tensor places; the output is an op place.
    in1_place = model.get_place_by_tensor_name(tensorName="in1")
    in2_place = model.get_place_by_tensor_name(tensorName="in2")
    add_op = model.get_place_by_tensor_name(tensorName="add_out").get_producing_operation()
    model.extract_subgraph(inputs=[in1_place, in2_place], outputs=[add_op])
    converted = frontend.convert(model)
    reference = frontend.convert(frontend.load("extract_subgraph_5.onnx"))
    matched = compare_functions(converted, reference)
    assert matched
def test_extract_subgraph_by_op_place_as_output_2():
    """Cut a subgraph with op places as inputs and an empty outputs list.

    With no explicit outputs, the result must match the reference
    ``test_override_all_inputs.onnx`` model.
    """
    skip_if_onnx_frontend_is_disabled()
    frontend = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
    assert frontend
    model = frontend.load("input_model.onnx")
    assert model
    # Inputs: the "split1" op place and the producer op of tensor "out4".
    split_op = model.get_place_by_operation_name(operationName="split1")
    mul_op = model.get_place_by_tensor_name(tensorName="out4").get_producing_operation()
    model.extract_subgraph(inputs=[split_op, mul_op], outputs=[])
    converted = frontend.convert(model)
    reference = frontend.convert(frontend.load("test_override_all_inputs.onnx"))
    matched = compare_functions(converted, reference)
    assert matched
def test_extract_subgraph_by_port_place_as_output():
    """Cut a subgraph using an output *port* place as the output.

    Output port 0 of the op producing "add_out" is passed as the output; the
    result must match the reference ``extract_subgraph.onnx`` model.
    """
    skip_if_onnx_frontend_is_disabled()
    frontend = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
    assert frontend
    model = frontend.load("input_model.onnx")
    assert model
    # Resolve the producing op of "add_out", then take its first output port.
    add_op = model.get_place_by_tensor_name(tensorName="add_out").get_producing_operation()
    add_out_port = add_op.get_output_port(outputPortIndex=0)
    in1_place = model.get_place_by_tensor_name(tensorName="in1")
    in2_place = model.get_place_by_tensor_name(tensorName="in2")
    model.extract_subgraph(inputs=[in1_place, in2_place], outputs=[add_out_port])
    converted = frontend.convert(model)
    reference = frontend.convert(frontend.load("extract_subgraph.onnx"))
    matched = compare_functions(converted, reference)
    assert matched
def test_override_all_outputs():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)