Handle PlaceOpONNX and PlaceOutputEdgeONNX by extract_subgraph (#7908)
This commit is contained in:
parent
a10f40d6d4
commit
29eaa0af60
@ -147,16 +147,42 @@ void InputModelONNX::extract_subgraph(const std::vector<Place::Ptr>& inputs, con
|
||||
[](const onnx_editor::InputEdge& edge) {
|
||||
return edge;
|
||||
});
|
||||
} else if (const auto op = std::dynamic_pointer_cast<PlaceOpONNX>(input)) {
|
||||
const auto editor_node = op->get_editor_node();
|
||||
const auto op_inputs = m_editor->get_input_ports(editor_node);
|
||||
int node_idx = m_editor->get_node_index(editor_node);
|
||||
int port_idx = 0;
|
||||
std::transform(std::begin(op_inputs),
|
||||
std::end(op_inputs),
|
||||
std::back_inserter(onnx_inputs),
|
||||
[&node_idx, &port_idx](const std::string&) {
|
||||
return onnx_editor::InputEdge{node_idx, port_idx++};
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<onnx_editor::OutputEdge> onnx_outputs;
|
||||
onnx_outputs.reserve(outputs.size());
|
||||
for (const auto& output : outputs) {
|
||||
const auto output_port = output->get_producing_port();
|
||||
if (const auto output_port = std::dynamic_pointer_cast<PlaceOutputEdgeONNX>(output)) {
|
||||
onnx_outputs.push_back(output_port->get_output_edge());
|
||||
} else if (const auto tensor = std::dynamic_pointer_cast<PlaceTensorONNX>(output)) {
|
||||
const auto output_port = tensor->get_producing_port();
|
||||
const auto onnx_output_edge = std::dynamic_pointer_cast<PlaceOutputEdgeONNX>(output_port);
|
||||
NGRAPH_CHECK(onnx_output_edge, "Non-onnx output place was passed as extraction subgraph argument");
|
||||
onnx_outputs.push_back(onnx_output_edge->get_output_edge());
|
||||
} else if (const auto op = std::dynamic_pointer_cast<PlaceOpONNX>(output)) {
|
||||
const auto editor_node = op->get_editor_node();
|
||||
const auto op_outputs = m_editor->get_output_ports(editor_node);
|
||||
int node_idx = m_editor->get_node_index(editor_node);
|
||||
int port_idx = 0;
|
||||
std::transform(std::begin(op_outputs),
|
||||
std::end(op_outputs),
|
||||
std::back_inserter(onnx_outputs),
|
||||
[&node_idx, &port_idx](const std::string&) {
|
||||
return onnx_editor::OutputEdge{node_idx, port_idx++};
|
||||
});
|
||||
}
|
||||
}
|
||||
m_editor->cut_graph_fragment(onnx_inputs, onnx_outputs);
|
||||
}
|
||||
|
@ -132,9 +132,9 @@ def create_test_onnx_models():
|
||||
|
||||
# Expected for extract_subgraph 4
|
||||
input_tensors = [
|
||||
make_tensor_value_info("out1/placeholder_port_0", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("out4/placeholder_port_0", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("out4/placeholder_port_1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("out1/placeholder_port_0", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
output_tensors = [
|
||||
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (1, 2)),
|
||||
@ -149,6 +149,18 @@ def create_test_onnx_models():
|
||||
models["extract_subgraph_4.onnx"] = make_model(graph, producer_name="ONNX Importer",
|
||||
opset_imports=[onnx.helper.make_opsetid("", 13)])
|
||||
|
||||
# Expected for extract_subgraph 5
|
||||
input_tensors = [
|
||||
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
output_tensors = [
|
||||
make_tensor_value_info("add_out", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
graph = make_graph([add], "test_graph", input_tensors, output_tensors)
|
||||
models["extract_subgraph_5.onnx"] = make_model(graph, producer_name="ONNX Importer",
|
||||
opset_imports=[onnx.helper.make_opsetid("", 13)])
|
||||
|
||||
# Expected for test_override_all_outputs
|
||||
input_tensors = [
|
||||
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
@ -355,9 +367,9 @@ def test_extract_subgraph_4():
|
||||
assert model
|
||||
|
||||
out4_tensor = model.get_place_by_tensor_name(tensorName="out4")
|
||||
place1 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
|
||||
place2 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
|
||||
place3 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
|
||||
place1 = model.get_place_by_operation_name_and_input_port(operationName="split1", inputPortIndex=0)
|
||||
place2 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=0)
|
||||
place3 = out4_tensor.get_producing_operation().get_input_port(inputPortIndex=1)
|
||||
place4 = model.get_place_by_tensor_name(tensorName="out1")
|
||||
place5 = model.get_place_by_tensor_name(tensorName="out2")
|
||||
place6 = model.get_place_by_tensor_name(tensorName="out4")
|
||||
@ -371,6 +383,99 @@ def test_extract_subgraph_4():
|
||||
assert res
|
||||
|
||||
|
||||
def test_extract_subgraph_by_op_place_as_input():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
split_op = model.get_place_by_operation_name(operationName="split1")
|
||||
out4 = model.get_place_by_tensor_name(tensorName="out4")
|
||||
mul_op = out4.get_producing_operation()
|
||||
out1 = model.get_place_by_tensor_name(tensorName="out1")
|
||||
out2 = model.get_place_by_tensor_name(tensorName="out2")
|
||||
|
||||
model.extract_subgraph(inputs=[split_op, mul_op], outputs=[out1, out2, out4])
|
||||
result_func = fe.convert(model)
|
||||
|
||||
expected_model = fe.load("extract_subgraph_4.onnx")
|
||||
expected_func = fe.convert(expected_model)
|
||||
|
||||
res = compare_functions(result_func, expected_func)
|
||||
assert res
|
||||
|
||||
|
||||
def test_extract_subgraph_by_op_place_as_output():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
in1_tensor = model.get_place_by_tensor_name(tensorName="in1")
|
||||
in2_tensor = model.get_place_by_tensor_name(tensorName="in2")
|
||||
add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out")
|
||||
add_op = add_out_tensor.get_producing_operation()
|
||||
|
||||
model.extract_subgraph(inputs=[in1_tensor, in2_tensor], outputs=[add_op])
|
||||
result_func = fe.convert(model)
|
||||
|
||||
expected_model = fe.load("extract_subgraph_5.onnx")
|
||||
expected_func = fe.convert(expected_model)
|
||||
|
||||
res = compare_functions(result_func, expected_func)
|
||||
assert res
|
||||
|
||||
|
||||
def test_extract_subgraph_by_op_place_as_output_2():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
split_op = model.get_place_by_operation_name(operationName="split1")
|
||||
out4 = model.get_place_by_tensor_name(tensorName="out4")
|
||||
mul_op = out4.get_producing_operation()
|
||||
|
||||
model.extract_subgraph(inputs=[split_op, mul_op], outputs=[])
|
||||
result_func = fe.convert(model)
|
||||
|
||||
expected_model = fe.load("test_override_all_inputs.onnx")
|
||||
expected_func = fe.convert(expected_model)
|
||||
|
||||
res = compare_functions(result_func, expected_func)
|
||||
assert res
|
||||
|
||||
|
||||
def test_extract_subgraph_by_port_place_as_output():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
assert fe
|
||||
|
||||
model = fe.load("input_model.onnx")
|
||||
assert model
|
||||
|
||||
add_out_tensor = model.get_place_by_tensor_name(tensorName="add_out")
|
||||
add_op = add_out_tensor.get_producing_operation()
|
||||
add_op_out_port = add_op.get_output_port(outputPortIndex=0)
|
||||
in1_tensor = model.get_place_by_tensor_name(tensorName="in1")
|
||||
in2_tensor = model.get_place_by_tensor_name(tensorName="in2")
|
||||
|
||||
model.extract_subgraph(inputs=[in1_tensor, in2_tensor], outputs=[add_op_out_port])
|
||||
result_func = fe.convert(model)
|
||||
|
||||
expected_model = fe.load("extract_subgraph.onnx")
|
||||
expected_func = fe.convert(expected_model)
|
||||
|
||||
res = compare_functions(result_func, expected_func)
|
||||
assert res
|
||||
|
||||
|
||||
def test_override_all_outputs():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
|
||||
|
Loading…
Reference in New Issue
Block a user