Fix "Unexpected number of outputs after override_all_outputs" (#9454)

This commit is contained in:
Dawid Kożykowski 2022-01-05 12:10:16 +01:00 committed by GitHub
parent f255c195c5
commit e89db1c6de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 173 additions and 20 deletions

View File

@@ -49,6 +49,16 @@ from openvino.frontend import FrontEndManager
# | |
# out1 out2
#
#
# ------Test input model 3------
# in1 in2
# | / \
# +--------+ +------+
# | Add | | Relu |
# +--------+ +------+
# | |
# out1 out2
#
def create_test_onnx_models():
models = {}
# Input model 1
@@ -91,6 +101,23 @@ def create_test_onnx_models():
models["input_model_2.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)])
# Input model 3
add_2 = onnx.helper.make_node("Add", inputs=["in1", "in2"], outputs=["out1"], name="onnx_add_op")
relu_2 = onnx.helper.make_node("Relu", inputs=["in2"], outputs=["out2"])
input_tensors = [
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
]
output_tensors = [
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("out2", onnx.TensorProto.FLOAT, (2, 2)),
]
graph = make_graph([add_2, relu_2], "test_graph_3", input_tensors, output_tensors)
models["input_model_3.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)])
# Expected for extract_subgraph
input_tensors = [
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
@@ -188,6 +215,19 @@ def create_test_onnx_models():
models["test_override_all_outputs_2.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)])
# Expected for test_override_all_outputs 3
input_tensors = [
make_tensor_value_info("in1", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("in2", onnx.TensorProto.FLOAT, (2, 2)),
]
output_tensors = [
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("out1", onnx.TensorProto.FLOAT, (2, 2)),
]
graph = make_graph([add_2], "test_graph_3", input_tensors, output_tensors)
models["test_override_all_outputs_3.onnx"] = make_model(graph, producer_name="ONNX Importer",
opset_imports=[onnx.helper.make_opsetid("", 13)])
# Expected for test_override_all_inputs
input_tensors = [
make_tensor_value_info("in3", onnx.TensorProto.FLOAT, (2, 2)),
@@ -594,6 +634,50 @@ def test_override_all_outputs_2():
assert res
def test_override_all_outputs_3():
    # Overriding outputs with two places that both refer to the same tensor
    # ("out1") must collapse to a single model output; the converted function
    # is compared against the pre-built reference model.
    skip_if_onnx_frontend_is_disabled()

    frontend = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
    assert frontend
    input_model = frontend.load("input_model_3.onnx")
    assert input_model

    out1_first = input_model.get_place_by_tensor_name(tensor_name="out1")
    out1_second = input_model.get_place_by_tensor_name(tensor_name="out1")
    input_model.override_all_outputs(outputs=[out1_first, out1_second])
    converted = frontend.convert(input_model)

    reference_model = frontend.load("test_override_all_outputs_3.onnx")
    reference = frontend.convert(reference_model)
    assert compare_functions(converted, reference)
def test_override_all_outputs_invalid_place():
    # A place taken from a DIFFERENT model ("out3" of input_model.onnx) is not
    # resolvable in model 3; override_all_outputs is expected to warn and
    # silently ignore it, producing the same result as without it.
    skip_if_onnx_frontend_is_disabled()

    frontend = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
    assert frontend
    input_model = frontend.load("input_model_3.onnx")
    assert input_model
    other_model = frontend.load("input_model.onnx")
    assert other_model

    foreign_place = other_model.get_place_by_tensor_name(tensor_name="out3")
    out1_first = input_model.get_place_by_tensor_name(tensor_name="out1")
    out1_second = input_model.get_place_by_tensor_name(tensor_name="out1")
    input_model.override_all_outputs(outputs=[out1_first, out1_second, foreign_place])
    converted = frontend.convert(input_model)

    reference_model = frontend.load("test_override_all_outputs_3.onnx")
    reference = frontend.convert(reference_model)
    assert compare_functions(converted, reference)
def test_override_all_inputs():
skip_if_onnx_frontend_is_disabled()
fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
@@ -618,26 +702,31 @@ def test_override_all_inputs():
assert res
def test_override_all_inputs_invalid_place():
    """override_all_inputs must ignore a place that belongs to another model.

    The span previously contained the raw diff with removed and added lines
    interleaved (duplicate ``def``/``fe.load`` lines and leftover
    ``pytest.raises`` blocks from the deleted ``test_override_all_inputs_exceptions``),
    which is not valid Python; this is the reconstructed post-change test.
    """
    skip_if_onnx_frontend_is_disabled()

    fe = fem.load_by_framework(framework=ONNX_FRONTEND_NAME)
    assert fe
    model = fe.load("input_model_3.onnx")
    assert model
    model2 = fe.load("input_model.onnx")
    assert model2

    # Input port of an operation from a DIFFERENT model — must be ignored.
    out3_tensor = model2.get_place_by_tensor_name(tensor_name="out3")
    invalid_place = out3_tensor.get_producing_operation().get_input_port(input_port_index=0)

    # Both input ports of the Add node producing "out1" in model 3.
    out1_tensor = model.get_place_by_tensor_name(tensor_name="out1")
    place1 = out1_tensor.get_producing_operation().get_input_port(input_port_index=0)
    place2 = out1_tensor.get_producing_operation().get_input_port(input_port_index=1)

    model.override_all_inputs(inputs=[place1, place2, invalid_place])
    result_func = fe.convert(model)

    # Overriding with the model's own inputs (plus an ignored foreign place)
    # must leave the model unchanged.
    expected_model = fe.load("input_model_3.onnx")
    expected_func = fe.convert(expected_model)
    res = compare_functions(result_func, expected_func)
    assert res
def test_is_input_output():

View File

@@ -7,6 +7,7 @@
#include <openvino/frontend/exception.hpp>
#include <openvino/util/file_util.hpp>
#include "ngraph/log.hpp"
#include "place.hpp"
using namespace ov;
@@ -202,28 +203,90 @@ std::shared_ptr<Model> InputModel::convert() {
}
// Editor features
// Returns true when `place` can be resolved against the ONNX model held by
// m_editor. Each supported Place subtype is probed in turn via dynamic cast;
// anything else (or an unresolvable name) yields false.
bool InputModel::is_correct_place(const ov::frontend::Place::Ptr& place) const {
// Tensor place: valid iff its (first) name exists in the model.
if (const auto tensor = std::dynamic_pointer_cast<PlaceTensor>(place)) {
return m_editor->is_correct_tensor_name(tensor->get_names()[0]);
}
// Operation place: valid iff the editor can match exactly one node.
if (const auto op = std::dynamic_pointer_cast<PlaceOp>(place)) {
return m_editor->is_correct_and_unambiguous_node(op->get_editor_node());
}
// Input edge: validated through the tensor feeding the edge.
if (const auto input_edge = std::dynamic_pointer_cast<PlaceInputEdge>(place)) {
if (auto tensor = std::dynamic_pointer_cast<PlaceTensor>(input_edge->get_source_tensor())) {
return m_editor->is_correct_tensor_name(tensor->get_names()[0]);
}
}
// Output edge: validated through the tensor the edge produces.
if (const auto output_edge = std::dynamic_pointer_cast<PlaceOutputEdge>(place)) {
if (auto tensor = std::dynamic_pointer_cast<PlaceTensor>(output_edge->get_target_tensor())) {
return m_editor->is_correct_tensor_name(tensor->get_names()[0]);
}
}
// Unknown place type or the source/target tensor was not a PlaceTensor.
return false;
}
// Replaces the model outputs with the given places. Places that cannot be
// resolved in the model (see is_correct_place) are reported with a warning
// and skipped instead of tripping the former strict
// "Unexpected number of outputs" size check (#9454). Duplicated places that
// refer to the same tensor are therefore also tolerated.
void InputModel::override_all_outputs(const std::vector<ov::frontend::Place::Ptr>& outputs) {
    std::vector<Place::Ptr> expected_valid_outputs;
    for (const auto& output : outputs) {
        bool is_correct = is_correct_place(output);
        if (!is_correct)
            NGRAPH_WARN << "Name " << output->get_names().at(0)
                        << " of output node is not a correct node name. Ignoring this parameter.";
        else
            expected_valid_outputs.push_back(output);
    }

    extract_subgraph({}, expected_valid_outputs);

    // Every accepted place must now actually be an output of the edited model.
    NGRAPH_CHECK(std::all_of(std::begin(expected_valid_outputs),
                             std::end(expected_valid_outputs),
                             [](const ov::frontend::Place::Ptr& place) {
                                 return place->is_output();
                             }),
                 "Not all provided arguments of override_all_outputs are new outputs of the model");

    // Conversely, the model must not expose any output that was not requested.
    // NOTE: the find_if result must be compared against
    // std::end(expected_valid_outputs) — the container being searched;
    // comparing it with another container's end iterator is undefined behavior.
    const auto current_outputs = get_outputs();
    NGRAPH_CHECK(std::all_of(std::begin(current_outputs),
                             std::end(current_outputs),
                             [&](const Place::Ptr& current_out) {
                                 return std::find_if(std::begin(expected_valid_outputs),
                                                     std::end(expected_valid_outputs),
                                                     [&](const Place::Ptr& expected_out) {
                                                         return expected_out->is_equal(current_out);
                                                     }) != std::end(expected_valid_outputs);
                             }),
                 "Some other than expected outputs were created during override_all_outputs");
}
// Replaces the model inputs with the given places. Mirrors
// override_all_outputs: unresolvable places are warned about and skipped
// rather than failing the former strict input-count check (#9454).
void InputModel::override_all_inputs(const std::vector<ov::frontend::Place::Ptr>& inputs) {
    std::vector<Place::Ptr> expected_valid_inputs;
    for (const auto& input : inputs) {
        bool is_correct = is_correct_place(input);
        if (!is_correct)
            NGRAPH_WARN << "Name " << input->get_names().at(0)
                        << " of input node is not a correct node. Ignoring this parameter.";
        else
            expected_valid_inputs.push_back(input);
    }

    // Cutting the graph at the new inputs must not drop any model output.
    const auto outputs_before_extraction = m_editor->model_outputs();
    extract_subgraph({expected_valid_inputs}, {});

    NGRAPH_CHECK(std::equal(std::begin(outputs_before_extraction),
                            std::end(outputs_before_extraction),
                            std::begin(m_editor->model_outputs())),
                 "All outputs should be preserved after override_all_inputs. Provided inputs does "
                 "not satisfy all outputs");

    // The edited model must not expose any input that was not requested.
    // NOTE: the find_if result must be compared against
    // std::end(expected_valid_inputs) — the container being searched;
    // comparing it with another container's end iterator is undefined behavior.
    const auto current_inputs = get_inputs();
    NGRAPH_CHECK(std::all_of(std::begin(current_inputs),
                             std::end(current_inputs),
                             [&](const Place::Ptr& current_in) {
                                 return std::find_if(std::begin(expected_valid_inputs),
                                                     std::end(expected_valid_inputs),
                                                     [&](const Place::Ptr& expected_in) {
                                                         return expected_in->is_equal(current_in);
                                                     }) != std::end(expected_valid_inputs);
                             }),
                 "Some other than expected inputs were created during override_all_inputs");
}
void InputModel::extract_subgraph(const std::vector<ov::frontend::Place::Ptr>& inputs,

View File

@@ -78,6 +78,7 @@ public:
private:
std::shared_ptr<ov::onnx_editor::ONNXModelEditor> m_editor;
bool is_correct_place(const ov::frontend::Place::Ptr& place) const;
std::unordered_map<std::string, std::unordered_set<std::string>> m_additional_tensor_names;
void add_tensor_names(std::shared_ptr<Model>& model);